diff --git a/drivers/iommu/arm-smmu-v3.c b/drivers/iommu/arm-smmu-v3.c index b7b3b0ff8ed6..d7c65dfe42dc 100644 --- a/drivers/iommu/arm-smmu-v3.c +++ b/drivers/iommu/arm-smmu-v3.c @@ -2286,44 +2286,52 @@ static void arm_smmu_install_ste_for_dev(struct arm_smmu_master *master) } } -static int arm_smmu_enable_ats(struct arm_smmu_master *master) +static bool arm_smmu_ats_supported(struct arm_smmu_master *master) { - int ret; - size_t stu; struct pci_dev *pdev; struct arm_smmu_device *smmu = master->smmu; struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(master->dev); if (!(smmu->features & ARM_SMMU_FEAT_ATS) || !dev_is_pci(master->dev) || !(fwspec->flags & IOMMU_FWSPEC_PCI_RC_ATS) || pci_ats_disabled()) - return -ENXIO; + return false; pdev = to_pci_dev(master->dev); - if (pdev->untrusted) - return -EPERM; + return !pdev->untrusted && pdev->ats_cap; +} + +static void arm_smmu_enable_ats(struct arm_smmu_master *master) +{ + size_t stu; + struct pci_dev *pdev; + struct arm_smmu_device *smmu = master->smmu; + + /* Don't enable ATS at the endpoint if it's not enabled in the STE */ + if (!master->ats_enabled) + return; /* Smallest Translation Unit: log2 of the smallest supported granule */ stu = __ffs(smmu->pgsize_bitmap); - - ret = pci_enable_ats(pdev, stu); - if (ret) - return ret; - - master->ats_enabled = true; - return 0; + pdev = to_pci_dev(master->dev); + if (pci_enable_ats(pdev, stu)) + dev_err(master->dev, "Failed to enable ATS (STU %zu)\n", stu); } static void arm_smmu_disable_ats(struct arm_smmu_master *master) { struct arm_smmu_cmdq_ent cmd; - if (!master->ats_enabled || !dev_is_pci(master->dev)) + if (!master->ats_enabled) return; + pci_disable_ats(to_pci_dev(master->dev)); + /* + * Ensure ATS is disabled at the endpoint before we issue the + * ATC invalidation via the SMMU. + */ + wmb(); arm_smmu_atc_inv_to_cmd(0, 0, 0, &cmd); arm_smmu_atc_inv_master(master, &cmd); - pci_disable_ats(to_pci_dev(master->dev)); - master->ats_enabled = false; } static void arm_smmu_detach_dev(struct arm_smmu_master *master) @@ -2338,10 +2346,10 @@ static void arm_smmu_detach_dev(struct arm_smmu_master *master) list_del(&master->domain_head); spin_unlock_irqrestore(&smmu_domain->devices_lock, flags); - master->domain = NULL; - arm_smmu_install_ste_for_dev(master); - arm_smmu_disable_ats(master); + master->domain = NULL; + master->ats_enabled = false; + arm_smmu_install_ste_for_dev(master); } static int arm_smmu_attach_dev(struct iommu_domain *domain, struct device *dev) @@ -2386,12 +2394,13 @@ static int arm_smmu_attach_dev(struct iommu_domain *domain, struct device *dev) spin_unlock_irqrestore(&smmu_domain->devices_lock, flags); if (smmu_domain->stage != ARM_SMMU_DOMAIN_BYPASS) - arm_smmu_enable_ats(master); + master->ats_enabled = arm_smmu_ats_supported(master); if (smmu_domain->stage == ARM_SMMU_DOMAIN_S1) arm_smmu_write_ctx_desc(smmu, &smmu_domain->s1_cfg); arm_smmu_install_ste_for_dev(master); + arm_smmu_enable_ats(master); out_unlock: mutex_unlock(&smmu_domain->init_mutex); return ret;