Message ID | 6b8bb8f2319bf26ead928321f609105e4e5eecf4.1729897278.git.nicolinc@nvidia.com |
---|---|
State | New |
Headers | show |
Series | iommufd: Add vIOMMU infrastructure (Part-2: vDEVICE) | expand |
On Fri, Oct 25, 2024 at 04:50:34PM -0700, Nicolin Chen wrote: > @@ -497,17 +497,35 @@ int iommufd_hwpt_invalidate(struct iommufd_ucmd *ucmd) > goto out; > } > > - hwpt = iommufd_get_hwpt_nested(ucmd, cmd->hwpt_id); > - if (IS_ERR(hwpt)) { > - rc = PTR_ERR(hwpt); > + pt_obj = iommufd_get_object(ucmd->ictx, cmd->hwpt_id, IOMMUFD_OBJ_ANY); > + if (IS_ERR(pt_obj)) { > + rc = PTR_ERR(pt_obj); > goto out; > } > + if (pt_obj->type == IOMMUFD_OBJ_HWPT_NESTED) { > + struct iommufd_hw_pagetable *hwpt = > + container_of(pt_obj, struct iommufd_hw_pagetable, obj); > + > + rc = hwpt->domain->ops->cache_invalidate_user(hwpt->domain, > + &data_array); > + } else if (pt_obj->type == IOMMUFD_OBJ_VIOMMU) { > + struct iommufd_viommu *viommu = > + container_of(pt_obj, struct iommufd_viommu, obj); > + > + if (!viommu->ops || !viommu->ops->cache_invalidate) { > + rc = -EOPNOTSUPP; > + goto out_put_pt; > + } > + rc = viommu->ops->cache_invalidate(viommu, &data_array); > + } else { > + rc = -EINVAL; > + goto out_put_pt; > + } Given the test in iommufd_viommu_alloc_hwpt_nested() is: if (WARN_ON_ONCE(hwpt->domain->type != IOMMU_DOMAIN_NESTED || (!viommu->ops->cache_invalidate && !hwpt->domain->ops->cache_invalidate_user))) { We will crash if the user passes a viommu allocated domain as IOMMUFD_OBJ_HWPT_NESTED since the above doesn't check it. I suggest we put the required if (ops..) -EOPNOTSUPP above and remove the ops->cache_invalidate checks from both WARN_ONs. Jason
On Tue, Oct 29, 2024 at 04:09:41PM -0300, Jason Gunthorpe wrote: > On Fri, Oct 25, 2024 at 04:50:34PM -0700, Nicolin Chen wrote: > > @@ -497,17 +497,35 @@ int iommufd_hwpt_invalidate(struct iommufd_ucmd *ucmd) > > goto out; > > } > > > > - hwpt = iommufd_get_hwpt_nested(ucmd, cmd->hwpt_id); > > - if (IS_ERR(hwpt)) { > > - rc = PTR_ERR(hwpt); > > + pt_obj = iommufd_get_object(ucmd->ictx, cmd->hwpt_id, IOMMUFD_OBJ_ANY); > > + if (IS_ERR(pt_obj)) { > > + rc = PTR_ERR(pt_obj); > > goto out; > > } > > + if (pt_obj->type == IOMMUFD_OBJ_HWPT_NESTED) { > > + struct iommufd_hw_pagetable *hwpt = > > + container_of(pt_obj, struct iommufd_hw_pagetable, obj); > > + > > + rc = hwpt->domain->ops->cache_invalidate_user(hwpt->domain, > > + &data_array); > > + } else if (pt_obj->type == IOMMUFD_OBJ_VIOMMU) { > > + struct iommufd_viommu *viommu = > > + container_of(pt_obj, struct iommufd_viommu, obj); > > + > > + if (!viommu->ops || !viommu->ops->cache_invalidate) { > > + rc = -EOPNOTSUPP; > > + goto out_put_pt; > > + } > > + rc = viommu->ops->cache_invalidate(viommu, &data_array); > > + } else { > > + rc = -EINVAL; > > + goto out_put_pt; > > + } > > Given the test in iommufd_viommu_alloc_hwpt_nested() is: > > if (WARN_ON_ONCE(hwpt->domain->type != IOMMU_DOMAIN_NESTED || > (!viommu->ops->cache_invalidate && > !hwpt->domain->ops->cache_invalidate_user))) > { > > We will crash if the user passes a viommu allocated domain as > IOMMUFD_OBJ_HWPT_NESTED since the above doesn't check it. Ah, that was missed. > I suggest we put the required if (ops..) -EOPNOTSUPP above and remove > the ops->cache_invalidate checks from both WARN_ONs. Ack. I will add hwpt->domain->ops check: --------------------------------------------------------------------- if (pt_obj->type == IOMMUFD_OBJ_HWPT_NESTED) { struct iommufd_hw_pagetable *hwpt = container_of(pt_obj, struct iommufd_hw_pagetable, obj); if (!hwpt->domain->ops || !hwpt->domain->ops->cache_invalidate_user) { rc = -EOPNOTSUPP; goto out_put_pt; } rc = hwpt->domain->ops->cache_invalidate_user(hwpt->domain, &data_array); } else if (pt_obj->type == IOMMUFD_OBJ_VIOMMU) { struct iommufd_viommu *viommu = container_of(pt_obj, struct iommufd_viommu, obj); if (!viommu->ops || !viommu->ops->cache_invalidate) { rc = -EOPNOTSUPP; goto out_put_pt; } rc = viommu->ops->cache_invalidate(viommu, &data_array); } else { --------------------------------------------------------------------- Thanks Nicolin
diff --git a/include/uapi/linux/iommufd.h b/include/uapi/linux/iommufd.h index b699ecb7aa9c..c2c5f49fdf17 100644 --- a/include/uapi/linux/iommufd.h +++ b/include/uapi/linux/iommufd.h @@ -730,7 +730,7 @@ struct iommu_hwpt_vtd_s1_invalidate { /** * struct iommu_hwpt_invalidate - ioctl(IOMMU_HWPT_INVALIDATE) * @size: sizeof(struct iommu_hwpt_invalidate) - * @hwpt_id: ID of a nested HWPT for cache invalidation + * @hwpt_id: ID of a nested HWPT or a vIOMMU, for cache invalidation * @data_uptr: User pointer to an array of driver-specific cache invalidation * data. * @data_type: One of enum iommu_hwpt_invalidate_data_type, defining the data @@ -741,8 +741,11 @@ struct iommu_hwpt_vtd_s1_invalidate { * Output the number of requests successfully handled by kernel. * @__reserved: Must be 0. * - * Invalidate the iommu cache for user-managed page table. Modifications on a - * user-managed page table should be followed by this operation to sync cache. + * Invalidate iommu cache for user-managed page table or vIOMMU. Modifications + * on a user-managed page table should be followed by this operation, if a HWPT + * is passed in via @hwpt_id. Other caches, such as device cache or descriptor + * cache can be flushed if a vIOMMU is passed in via the @hwpt_id field. + * * Each ioctl can support one or more cache invalidation requests in the array * that has a total size of @entry_len * @entry_num. * diff --git a/drivers/iommu/iommufd/hw_pagetable.c b/drivers/iommu/iommufd/hw_pagetable.c index fd260a67b82c..5301ba69fb8a 100644 --- a/drivers/iommu/iommufd/hw_pagetable.c +++ b/drivers/iommu/iommufd/hw_pagetable.c @@ -483,7 +483,7 @@ int iommufd_hwpt_invalidate(struct iommufd_ucmd *ucmd) .entry_len = cmd->entry_len, .entry_num = cmd->entry_num, }; - struct iommufd_hw_pagetable *hwpt; + struct iommufd_object *pt_obj; u32 done_num = 0; int rc; @@ -497,17 +497,35 @@ int iommufd_hwpt_invalidate(struct iommufd_ucmd *ucmd) goto out; } - hwpt = iommufd_get_hwpt_nested(ucmd, cmd->hwpt_id); - if (IS_ERR(hwpt)) { - rc = PTR_ERR(hwpt); + pt_obj = iommufd_get_object(ucmd->ictx, cmd->hwpt_id, IOMMUFD_OBJ_ANY); + if (IS_ERR(pt_obj)) { + rc = PTR_ERR(pt_obj); goto out; } + if (pt_obj->type == IOMMUFD_OBJ_HWPT_NESTED) { + struct iommufd_hw_pagetable *hwpt = + container_of(pt_obj, struct iommufd_hw_pagetable, obj); + + rc = hwpt->domain->ops->cache_invalidate_user(hwpt->domain, + &data_array); + } else if (pt_obj->type == IOMMUFD_OBJ_VIOMMU) { + struct iommufd_viommu *viommu = + container_of(pt_obj, struct iommufd_viommu, obj); + + if (!viommu->ops || !viommu->ops->cache_invalidate) { + rc = -EOPNOTSUPP; + goto out_put_pt; + } + rc = viommu->ops->cache_invalidate(viommu, &data_array); + } else { + rc = -EINVAL; + goto out_put_pt; + } - rc = hwpt->domain->ops->cache_invalidate_user(hwpt->domain, - &data_array); done_num = data_array.entry_num; - iommufd_put_object(ucmd->ictx, &hwpt->obj); +out_put_pt: + iommufd_put_object(ucmd->ictx, pt_obj); out: cmd->entry_num = done_num; if (iommufd_ucmd_respond(ucmd, sizeof(*cmd))) diff --git a/tools/testing/selftests/iommu/iommufd.c b/tools/testing/selftests/iommu/iommufd.c index 93255403dee4..44fbc7e5aa2e 100644 --- a/tools/testing/selftests/iommu/iommufd.c +++ b/tools/testing/selftests/iommu/iommufd.c @@ -362,9 +362,9 @@ TEST_F(iommufd_ioas, alloc_hwpt_nested) EXPECT_ERRNO(EBUSY, _test_ioctl_destroy(self->fd, parent_hwpt_id)); - /* hwpt_invalidate only supports a user-managed hwpt (nested) */ + /* hwpt_invalidate does not support a parent hwpt */ num_inv = 1; - test_err_hwpt_invalidate(ENOENT, parent_hwpt_id, inv_reqs, + test_err_hwpt_invalidate(EINVAL, parent_hwpt_id, inv_reqs, IOMMU_HWPT_INVALIDATE_DATA_SELFTEST, sizeof(*inv_reqs), &num_inv); assert(!num_inv);