diff mbox series

[v3,09/10] iommu: Split iommu_group_add_device()

Message ID 9-v3-328044aa278c+45e49-iommu_probe_jgg@nvidia.com
State Accepted
Commit fa08280364882c42993dc5d8394c467d76803fdb
Headers show
Series [v3,01/10] iommu: Have __iommu_probe_device() check for already probed devices | expand

Commit Message

Jason Gunthorpe June 6, 2023, 12:59 a.m. UTC
Move the list_add_tail() for the group_device into the critical region
that immediately follows in __iommu_probe_device(). This avoids one case
of unlocking and immediately re-locking the group->mutex.

Consistently make the caller responsible for setting dev->iommu_group,
prior patches moved this into iommu_init_device(), make the no-driver path
do this in iommu_group_add_device().

This completes making __iommu_group_free_device() and
iommu_group_alloc_device() into pair'd functions.

Reviewed-by: Lu Baolu <baolu.lu@linux.intel.com>
Reviewed-by: Kevin Tian <kevin.tian@intel.com>
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>
---
 drivers/iommu/iommu.c | 66 ++++++++++++++++++++++++++++---------------
 1 file changed, 43 insertions(+), 23 deletions(-)
diff mbox series

Patch

diff --git a/drivers/iommu/iommu.c b/drivers/iommu/iommu.c
index 45d69462ddedab..4da3623d7686a0 100644
--- a/drivers/iommu/iommu.c
+++ b/drivers/iommu/iommu.c
@@ -129,6 +129,8 @@  static int iommu_create_device_direct_mappings(struct iommu_domain *domain,
 					       struct device *dev);
 static ssize_t iommu_group_store_type(struct iommu_group *group,
 				      const char *buf, size_t count);
+static struct group_device *iommu_group_alloc_device(struct iommu_group *group,
+						     struct device *dev);
 
 #define IOMMU_GROUP_ATTR(_name, _mode, _show, _store)		\
 struct iommu_group_attribute iommu_group_attr_##_name =		\
@@ -435,6 +437,7 @@  static int __iommu_probe_device(struct device *dev, struct list_head *group_list
 	const struct iommu_ops *ops = dev->bus->iommu_ops;
 	struct iommu_group *group;
 	static DEFINE_MUTEX(iommu_probe_device_lock);
+	struct group_device *gdev;
 	int ret;
 
 	if (!ops)
@@ -459,16 +462,17 @@  static int __iommu_probe_device(struct device *dev, struct list_head *group_list
 		goto out_unlock;
 
 	group = dev->iommu_group;
-	ret = iommu_group_add_device(group, dev);
+	gdev = iommu_group_alloc_device(group, dev);
 	mutex_lock(&group->mutex);
-	if (ret)
+	if (IS_ERR(gdev)) {
+		ret = PTR_ERR(gdev);
 		goto err_put_group;
+	}
 
+	list_add_tail(&gdev->list, &group->devices);
 	if (group_list && !group->default_domain && list_empty(&group->entry))
 		list_add_tail(&group->entry, group_list);
 	mutex_unlock(&group->mutex);
-	iommu_group_put(group);
-
 	mutex_unlock(&iommu_probe_device_lock);
 
 	return 0;
@@ -578,7 +582,10 @@  static void __iommu_group_remove_device(struct device *dev)
 	}
 	mutex_unlock(&group->mutex);
 
-	/* Pairs with the get in iommu_group_add_device() */
+	/*
+	 * Pairs with the get in iommu_init_device() or
+	 * iommu_group_add_device()
+	 */
 	iommu_group_put(group);
 }
 
@@ -1067,22 +1074,16 @@  static int iommu_create_device_direct_mappings(struct iommu_domain *domain,
 	return ret;
 }
 
-/**
- * iommu_group_add_device - add a device to an iommu group
- * @group: the group into which to add the device (reference should be held)
- * @dev: the device
- *
- * This function is called by an iommu driver to add a device into a
- * group.  Adding a device increments the group reference count.
- */
-int iommu_group_add_device(struct iommu_group *group, struct device *dev)
+/* This is undone by __iommu_group_free_device() */
+static struct group_device *iommu_group_alloc_device(struct iommu_group *group,
+						     struct device *dev)
 {
 	int ret, i = 0;
 	struct group_device *device;
 
 	device = kzalloc(sizeof(*device), GFP_KERNEL);
 	if (!device)
-		return -ENOMEM;
+		return ERR_PTR(-ENOMEM);
 
 	device->dev = dev;
 
@@ -1113,17 +1114,11 @@  int iommu_group_add_device(struct iommu_group *group, struct device *dev)
 		goto err_free_name;
 	}
 
-	iommu_group_ref_get(group);
-	dev->iommu_group = group;
-
-	mutex_lock(&group->mutex);
-	list_add_tail(&device->list, &group->devices);
-	mutex_unlock(&group->mutex);
 	trace_add_device_to_group(group->id, dev);
 
 	dev_info(dev, "Adding to iommu group %d\n", group->id);
 
-	return 0;
+	return device;
 
 err_free_name:
 	kfree(device->name);
@@ -1132,7 +1127,32 @@  int iommu_group_add_device(struct iommu_group *group, struct device *dev)
 err_free_device:
 	kfree(device);
 	dev_err(dev, "Failed to add to iommu group %d: %d\n", group->id, ret);
-	return ret;
+	return ERR_PTR(ret);
+}
+
+/**
+ * iommu_group_add_device - add a device to an iommu group
+ * @group: the group into which to add the device (reference should be held)
+ * @dev: the device
+ *
+ * This function is called by an iommu driver to add a device into a
+ * group.  Adding a device increments the group reference count.
+ */
+int iommu_group_add_device(struct iommu_group *group, struct device *dev)
+{
+	struct group_device *gdev;
+
+	gdev = iommu_group_alloc_device(group, dev);
+	if (IS_ERR(gdev))
+		return PTR_ERR(gdev);
+
+	iommu_group_ref_get(group);
+	dev->iommu_group = group;
+
+	mutex_lock(&group->mutex);
+	list_add_tail(&gdev->list, &group->devices);
+	mutex_unlock(&group->mutex);
+	return 0;
 }
 EXPORT_SYMBOL_GPL(iommu_group_add_device);