@@ -41,6 +41,7 @@
static struct kset *iommu_group_kset;
static DEFINE_IDA(iommu_group_ida);
static DEFINE_IDA(iommu_global_pasid_ida);
+static DEFINE_MUTEX(iommu_probe_device_lock);
static unsigned int iommu_def_domain_type __read_mostly;
static bool iommu_dma_strict __read_mostly = IS_ENABLED(CONFIG_IOMMU_DEFAULT_DMA_STRICT);
@@ -498,7 +499,6 @@ static int __iommu_probe_device(struct device *dev,
struct iommu_fwspec *fwspec = caller_fwspec;
const struct iommu_ops *ops;
struct iommu_group *group;
- static DEFINE_MUTEX(iommu_probe_device_lock);
struct group_device *gdev;
int ret;
@@ -2985,8 +2985,11 @@ int iommu_fwspec_of_xlate(struct iommu_fwspec *fwspec, struct device *dev,
if (!fwspec->ops->of_xlate)
return -ENODEV;
- if (!dev_iommu_get(dev))
+ mutex_lock(&iommu_probe_device_lock);
+ if (!dev_iommu_get(dev)) {
+ mutex_unlock(&iommu_probe_device_lock);
return -ENOMEM;
+ }
/*
* ops->of_xlate() requires the fwspec to be passed through dev->iommu,
@@ -2998,6 +3001,7 @@ int iommu_fwspec_of_xlate(struct iommu_fwspec *fwspec, struct device *dev,
ret = fwspec->ops->of_xlate(dev, iommu_spec);
if (dev->iommu->fwspec == fwspec)
dev->iommu->fwspec = NULL;
+ mutex_unlock(&iommu_probe_device_lock);
return ret;
}
@@ -3027,6 +3031,8 @@ int iommu_fwspec_init(struct device *dev, struct fwnode_handle *iommu_fwnode,
struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
int ret;
+ lockdep_assert_held(&iommu_probe_device_lock);
+
if (fwspec)
return ops == fwspec->ops ? 0 : -EINVAL;
@@ -3080,6 +3086,8 @@ int iommu_fwspec_add_ids(struct device *dev, u32 *ids, int num_ids)
{
struct iommu_fwspec *fwspec = dev_iommu_fwspec_get(dev);
+ lockdep_assert_held(&iommu_probe_device_lock);
+
if (!fwspec)
return -EINVAL;
return iommu_fwspec_append_ids(fwspec, ids, num_ids);