@@ -2405,6 +2405,9 @@ int iommu_fwspec_init(struct device *dev, struct fwnode_handle *iommu_fwnode,
if (fwspec)
return ops == fwspec->ops ? 0 : -EINVAL;
+ if (!dev_iommu_get(dev))
+ return -ENOMEM;
+
fwspec = kzalloc(sizeof(*fwspec), GFP_KERNEL);
if (!fwspec)
return -ENOMEM;
@@ -42,7 +42,6 @@ struct device_node;
struct fwnode_handle;
struct iommu_ops;
struct iommu_group;
-struct iommu_fwspec;
struct dev_pin_info;
struct dev_iommu;
@@ -513,7 +512,6 @@ struct dev_links_info {
* gone away. This should be set by the allocator of the
* device (i.e. the bus driver that discovered the device).
* @iommu_group: IOMMU group the device belongs to.
- * @iommu_fwspec: IOMMU-specific properties supplied by firmware.
* @iommu: Per device generic IOMMU runtime data
*
* @offline_disabled: If set, the device is permanently online.
@@ -613,7 +611,6 @@ struct device {
void (*release)(struct device *dev);
struct iommu_group *iommu_group;
- struct iommu_fwspec *iommu_fwspec;
struct dev_iommu *iommu;
bool offline_disabled:1;
@@ -368,14 +368,15 @@ struct iommu_fault_param {
* struct dev_iommu - Collection of per-device IOMMU data
*
* @fault_param: IOMMU detected device fault reporting data
+ * @fwspec: IOMMU fwspec data
*
* TODO: migrate other per device data pointers under iommu_dev_data, e.g.
* struct iommu_group *iommu_group;
- * struct iommu_fwspec *iommu_fwspec;
*/
struct dev_iommu {
struct mutex lock;
- struct iommu_fault_param *fault_param;
+ struct iommu_fault_param *fault_param;
+ struct iommu_fwspec *fwspec;
};
int iommu_device_register(struct iommu_device *iommu);
@@ -614,13 +615,16 @@ const struct iommu_ops *iommu_ops_from_fwnode(struct fwnode_handle *fwnode);
static inline struct iommu_fwspec *dev_iommu_fwspec_get(struct device *dev)
{
- return dev->iommu_fwspec;
+ if (dev->iommu)
+ return dev->iommu->fwspec;
+ else
+ return NULL;
}
static inline void dev_iommu_fwspec_set(struct device *dev,
struct iommu_fwspec *fwspec)
{
- dev->iommu_fwspec = fwspec;
+ dev->iommu->fwspec = fwspec;
}
int iommu_probe_device(struct device *dev);