diff mbox series

[v3,4/8] vfio/pci: Add support for setting driver data inside core layer

Message ID 20220425092615.10133-5-abhsahu@nvidia.com
State New
Headers show
Series vfio/pci: power management changes | expand

Commit Message

Abhishek Sahu April 25, 2022, 9:26 a.m. UTC
The vfio driver is divided into two layers: the core layer (implemented in
vfio_pci_core.c) and the parent drivers (for example, vfio_pci, mlx5_vfio_pci,
hisi_acc_vfio_pci, etc.). Each parent driver calls dev_set_drvdata()
and assigns its own structure as driver data. Some of the callback
functions are implemented in the core layer, and these callbacks are
passed a reference to 'struct pci_dev' or 'struct device'. Currently,
we use vfio_device_get_from_dev() to obtain the vfio_device for a
device, but this function takes a long path to look it up. In a few
cases, we can avoid this long path by retrieving the device through
drvdata instead.

This patch moves the setting of drvdata into the core layer. In the
current parent driver structure definitions, 'struct vfio_pci_core_device'
is the first member, so a pointer to the parent structure and a pointer
to its 'struct vfio_pci_core_device' are the same.

struct hisi_acc_vf_core_device {
    struct vfio_pci_core_device core_device;
    ...
};

struct mlx5vf_pci_core_device {
    struct vfio_pci_core_device core_device;
    ...
};

vfio_pci.c uses 'struct vfio_pci_core_device' directly.

To support getting the drvdata in both layers, we require
'struct vfio_pci_core_device' to be the first member of the parent
structure. vfio_pci_core_register_device() also validates this
requirement, so the prerequisite is always satisfied.

Signed-off-by: Abhishek Sahu <abhsahu@nvidia.com>
---
 .../vfio/pci/hisilicon/hisi_acc_vfio_pci.c    |  4 ++--
 drivers/vfio/pci/mlx5/main.c                  |  3 +--
 drivers/vfio/pci/vfio_pci.c                   |  4 ++--
 drivers/vfio/pci/vfio_pci_core.c              | 24 ++++++++++++++++---
 include/linux/vfio_pci_core.h                 |  7 +++++-
 5 files changed, 32 insertions(+), 10 deletions(-)

Comments

Alex Williamson May 3, 2022, 5:11 p.m. UTC | #1
On Mon, 25 Apr 2022 14:56:11 +0530
Abhishek Sahu <abhsahu@nvidia.com> wrote:

> The vfio driver is divided into two layers: core layer (implemented in
> vfio_pci_core.c) and parent driver (For example, vfio_pci, mlx5_vfio_pci,
> hisi_acc_vfio_pci, etc.). All the parent driver calls dev_set_drvdata()
> and assigns its own structure as driver data. Some of the callback
> functions are implemented in the core layer and these callback functions
> provide the reference of 'struct pci_dev' or 'struct device'. Currently,
> we use vfio_device_get_from_dev() which provides reference to the
> vfio_device for a device. But this function follows long path to extract
> the same. There are few cases, where we don't need to go through this
> long path if we get this through drvdata.
> 
> This patch moves the setting of drvdata inside the core layer. If we see
> the current implementation of parent driver structure implementation,
> then 'struct vfio_pci_core_device' is a first member so the pointer of
> the parent structure and 'struct vfio_pci_core_device' should be the same.
> 
> struct hisi_acc_vf_core_device {
>     struct vfio_pci_core_device core_device;
>     ...
> };
> 
> struct mlx5vf_pci_core_device {
>     struct vfio_pci_core_device core_device;
>     ...
> };
> 
> The vfio_pci.c uses 'struct vfio_pci_core_device' itself.
> 
> To support getting the drvdata in both the layers, we can put the
> restriction to make 'struct vfio_pci_core_device' as a first member.
> Also, vfio_pci_core_register_device() has this validation which makes sure
> that this prerequisite is always satisfied.
> 
> Signed-off-by: Abhishek Sahu <abhsahu@nvidia.com>
> ---
>  .../vfio/pci/hisilicon/hisi_acc_vfio_pci.c    |  4 ++--
>  drivers/vfio/pci/mlx5/main.c                  |  3 +--
>  drivers/vfio/pci/vfio_pci.c                   |  4 ++--
>  drivers/vfio/pci/vfio_pci_core.c              | 24 ++++++++++++++++---
>  include/linux/vfio_pci_core.h                 |  7 +++++-
>  5 files changed, 32 insertions(+), 10 deletions(-)
> 
> diff --git a/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c b/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c
> index 767b5d47631a..c76c09302a8f 100644
> --- a/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c
> +++ b/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c
> @@ -1274,11 +1274,11 @@ static int hisi_acc_vfio_pci_probe(struct pci_dev *pdev, const struct pci_device
>  					  &hisi_acc_vfio_pci_ops);
>  	}
>  
> -	ret = vfio_pci_core_register_device(&hisi_acc_vdev->core_device);
> +	ret = vfio_pci_core_register_device(&hisi_acc_vdev->core_device,
> +					    hisi_acc_vdev);
>  	if (ret)
>  		goto out_free;
>  
> -	dev_set_drvdata(&pdev->dev, hisi_acc_vdev);
>  	return 0;
>  
>  out_free:
> diff --git a/drivers/vfio/pci/mlx5/main.c b/drivers/vfio/pci/mlx5/main.c
> index bbec5d288fee..8689248f66f3 100644
> --- a/drivers/vfio/pci/mlx5/main.c
> +++ b/drivers/vfio/pci/mlx5/main.c
> @@ -614,11 +614,10 @@ static int mlx5vf_pci_probe(struct pci_dev *pdev,
>  		}
>  	}
>  
> -	ret = vfio_pci_core_register_device(&mvdev->core_device);
> +	ret = vfio_pci_core_register_device(&mvdev->core_device, mvdev);
>  	if (ret)
>  		goto out_free;
>  
> -	dev_set_drvdata(&pdev->dev, mvdev);
>  	return 0;
>  
>  out_free:
> diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c
> index 2b047469e02f..e0f8027c5cd8 100644
> --- a/drivers/vfio/pci/vfio_pci.c
> +++ b/drivers/vfio/pci/vfio_pci.c
> @@ -151,10 +151,10 @@ static int vfio_pci_probe(struct pci_dev *pdev, const struct pci_device_id *id)
>  		return -ENOMEM;
>  	vfio_pci_core_init_device(vdev, pdev, &vfio_pci_ops);
>  
> -	ret = vfio_pci_core_register_device(vdev);
> +	ret = vfio_pci_core_register_device(vdev, vdev);
>  	if (ret)
>  		goto out_free;
> -	dev_set_drvdata(&pdev->dev, vdev);
> +
>  	return 0;
>  
>  out_free:
> diff --git a/drivers/vfio/pci/vfio_pci_core.c b/drivers/vfio/pci/vfio_pci_core.c
> index 1271728a09db..953ac33b2f5f 100644
> --- a/drivers/vfio/pci/vfio_pci_core.c
> +++ b/drivers/vfio/pci/vfio_pci_core.c
> @@ -1822,9 +1822,11 @@ void vfio_pci_core_uninit_device(struct vfio_pci_core_device *vdev)
>  }
>  EXPORT_SYMBOL_GPL(vfio_pci_core_uninit_device);
>  
> -int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev)
> +int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev,
> +				  void *driver_data)
>  {
>  	struct pci_dev *pdev = vdev->pdev;
> +	struct device *dev = &pdev->dev;
>  	int ret;
>  
>  	if (pdev->hdr_type != PCI_HEADER_TYPE_NORMAL)
> @@ -1843,6 +1845,17 @@ int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev)
>  		return -EBUSY;
>  	}
>  
> +	/*
> +	 * The 'struct vfio_pci_core_device' should be the first member of the
> +	 * of the structure referenced by 'driver_data' so that it can be
> +	 * retrieved with dev_get_drvdata() inside vfio-pci core layer.
> +	 */
> +	if ((struct vfio_pci_core_device *)driver_data != vdev) {
> +		pci_warn(pdev, "Invalid driver data\n");
> +		return -EINVAL;
> +	}

It seems a bit odd to me to add a driver_data arg to the function,
which is actually required to point to the same thing as the existing
function arg.  Is this just to codify the requirement?  Maybe others
can suggest alternatives.

We also need to collaborate with Jason's patch:

https://lore.kernel.org/all/0-v2-0f36bcf6ec1e+64d-vfio_get_from_dev_jgg@nvidia.com/

(and maybe others)

If we implement a change like proposed here that vfio-pci-core sets
drvdata then we don't need for each variant driver to implement their
own wrapper around err_handler or err_detected as Jason proposes in the
linked patch.  Thanks,

Alex

> +	dev_set_drvdata(dev, driver_data);
> +
>  	if (pci_is_root_bus(pdev->bus)) {
>  		ret = vfio_assign_device_set(&vdev->vdev, vdev);
>  	} else if (!pci_probe_reset_slot(pdev->slot)) {
> @@ -1856,10 +1869,10 @@ int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev)
>  	}
>  
>  	if (ret)
> -		return ret;
> +		goto out_drvdata;
>  	ret = vfio_pci_vf_init(vdev);
>  	if (ret)
> -		return ret;
> +		goto out_drvdata;
>  	ret = vfio_pci_vga_init(vdev);
>  	if (ret)
>  		goto out_vf;
> @@ -1890,6 +1903,8 @@ int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev)
>  		vfio_pci_set_power_state(vdev, PCI_D0);
>  out_vf:
>  	vfio_pci_vf_uninit(vdev);
> +out_drvdata:
> +	dev_set_drvdata(dev, NULL);
>  	return ret;
>  }
>  EXPORT_SYMBOL_GPL(vfio_pci_core_register_device);
> @@ -1897,6 +1912,7 @@ EXPORT_SYMBOL_GPL(vfio_pci_core_register_device);
>  void vfio_pci_core_unregister_device(struct vfio_pci_core_device *vdev)
>  {
>  	struct pci_dev *pdev = vdev->pdev;
> +	struct device *dev = &pdev->dev;
>  
>  	vfio_pci_core_sriov_configure(pdev, 0);
>  
> @@ -1907,6 +1923,8 @@ void vfio_pci_core_unregister_device(struct vfio_pci_core_device *vdev)
>  
>  	if (!disable_idle_d3)
>  		vfio_pci_set_power_state(vdev, PCI_D0);
> +
> +	dev_set_drvdata(dev, NULL);
>  }
>  EXPORT_SYMBOL_GPL(vfio_pci_core_unregister_device);
>  
> diff --git a/include/linux/vfio_pci_core.h b/include/linux/vfio_pci_core.h
> index 505b2a74a479..3c7d65e68340 100644
> --- a/include/linux/vfio_pci_core.h
> +++ b/include/linux/vfio_pci_core.h
> @@ -225,7 +225,12 @@ void vfio_pci_core_close_device(struct vfio_device *core_vdev);
>  void vfio_pci_core_init_device(struct vfio_pci_core_device *vdev,
>  			       struct pci_dev *pdev,
>  			       const struct vfio_device_ops *vfio_pci_ops);
> -int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev);
> +/*
> + * The 'struct vfio_pci_core_device' should be the first member
> + * of the structure referenced by 'driver_data'.
> + */
> +int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev,
> +				  void *driver_data);
>  void vfio_pci_core_uninit_device(struct vfio_pci_core_device *vdev);
>  void vfio_pci_core_unregister_device(struct vfio_pci_core_device *vdev);
>  int vfio_pci_core_sriov_configure(struct pci_dev *pdev, int nr_virtfn);
Jason Gunthorpe May 4, 2022, 12:20 a.m. UTC | #2
On Tue, May 03, 2022 at 11:11:24AM -0600, Alex Williamson wrote:
> > @@ -1843,6 +1845,17 @@ int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev)
> >  		return -EBUSY;
> >  	}
> >  
> > +	/*
> > +	 * The 'struct vfio_pci_core_device' should be the first member of the
> > +	 * of the structure referenced by 'driver_data' so that it can be
> > +	 * retrieved with dev_get_drvdata() inside vfio-pci core layer.
> > +	 */
> > +	if ((struct vfio_pci_core_device *)driver_data != vdev) {
> > +		pci_warn(pdev, "Invalid driver data\n");
> > +		return -EINVAL;
> > +	}
> 
> It seems a bit odd to me to add a driver_data arg to the function,
> which is actually required to point to the same thing as the existing
> function arg.  Is this just to codify the requirement?  Maybe others
> can suggest alternatives.
> 
> We also need to collaborate with Jason's patch:
> 
> https://lore.kernel.org/all/0-v2-0f36bcf6ec1e+64d-vfio_get_from_dev_jgg@nvidia.com/
> 
> (and maybe others)
> 
> If we implement a change like proposed here that vfio-pci-core sets
> drvdata then we don't need for each variant driver to implement their
> own wrapper around err_handler or err_detected as Jason proposes in the
> linked patch.  Thanks,

Oh, I forgot about this series completely.

Yes, we need to pick a method, either drvdata always points at the
core struct, or we wrapper the core functions.

I have an independent version of the above patch that uses the
drvdata, but I chucked it because it was unnecessary for just a couple
of AER functions. 

We should probably go back to it though if we are adding more
functions, as the wrapping is a bit repetitive. I'll go and respin
that series then. Abhishek can base on top of it.

My approach was more type-sane though:

commit 12ba94a72d7aa134af8752d6ff78193acdac93ae
Author: Jason Gunthorpe <jgg@ziepe.ca>
Date:   Tue Mar 29 16:32:32 2022 -0300

    vfio/pci: Have all VFIO PCI drivers store the vfio_pci_core_device in drvdata
    
    Having a consistent pointer in the drvdata will allow the next patch to
    make use of the drvdata from some of the core code helpers.
    
    Use a WARN_ON inside vfio_pci_core_unregister_device() to detect drivers
    that miss this.
    
    Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>

diff --git a/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c b/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c
index 767b5d47631a49..665691967a030c 100644
--- a/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c
+++ b/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c
@@ -337,6 +337,14 @@ static int vf_qm_cache_wb(struct hisi_qm *qm)
 	return 0;
 }
 
+static struct hisi_acc_vf_core_device *hssi_acc_drvdata(struct pci_dev *pdev)
+{
+	struct vfio_pci_core_device *core_device = dev_get_drvdata(&pdev->dev);
+
+	return container_of(core_device, struct hisi_acc_vf_core_device,
+			    core_device);
+}
+
 static void vf_qm_fun_reset(struct hisi_acc_vf_core_device *hisi_acc_vdev,
 			    struct hisi_qm *qm)
 {
@@ -962,7 +970,7 @@ hisi_acc_vfio_pci_get_device_state(struct vfio_device *vdev,
 
 static void hisi_acc_vf_pci_aer_reset_done(struct pci_dev *pdev)
 {
-	struct hisi_acc_vf_core_device *hisi_acc_vdev = dev_get_drvdata(&pdev->dev);
+	struct hisi_acc_vf_core_device *hisi_acc_vdev = hssi_acc_drvdata(pdev);
 
 	if (hisi_acc_vdev->core_device.vdev.migration_flags !=
 				VFIO_MIGRATION_STOP_COPY)
@@ -1278,7 +1286,7 @@ static int hisi_acc_vfio_pci_probe(struct pci_dev *pdev, const struct pci_device
 	if (ret)
 		goto out_free;
 
-	dev_set_drvdata(&pdev->dev, hisi_acc_vdev);
+	dev_set_drvdata(&pdev->dev, &hisi_acc_vdev->core_device);
 	return 0;
 
 out_free:
@@ -1289,7 +1297,7 @@ static int hisi_acc_vfio_pci_probe(struct pci_dev *pdev, const struct pci_device
 
 static void hisi_acc_vfio_pci_remove(struct pci_dev *pdev)
 {
-	struct hisi_acc_vf_core_device *hisi_acc_vdev = dev_get_drvdata(&pdev->dev);
+	struct hisi_acc_vf_core_device *hisi_acc_vdev = hssi_acc_drvdata(pdev);
 
 	vfio_pci_core_unregister_device(&hisi_acc_vdev->core_device);
 	vfio_pci_core_uninit_device(&hisi_acc_vdev->core_device);
diff --git a/drivers/vfio/pci/mlx5/main.c b/drivers/vfio/pci/mlx5/main.c
index bbec5d288fee97..3391f965abd9f0 100644
--- a/drivers/vfio/pci/mlx5/main.c
+++ b/drivers/vfio/pci/mlx5/main.c
@@ -39,6 +39,14 @@ struct mlx5vf_pci_core_device {
 	struct mlx5_vf_migration_file *saving_migf;
 };
 
+static struct mlx5vf_pci_core_device *mlx5vf_drvdata(struct pci_dev *pdev)
+{
+	struct vfio_pci_core_device *core_device = dev_get_drvdata(&pdev->dev);
+
+	return container_of(core_device, struct mlx5vf_pci_core_device,
+			    core_device);
+}
+
 static struct page *
 mlx5vf_get_migration_page(struct mlx5_vf_migration_file *migf,
 			  unsigned long offset)
@@ -505,7 +513,7 @@ static int mlx5vf_pci_get_device_state(struct vfio_device *vdev,
 
 static void mlx5vf_pci_aer_reset_done(struct pci_dev *pdev)
 {
-	struct mlx5vf_pci_core_device *mvdev = dev_get_drvdata(&pdev->dev);
+	struct mlx5vf_pci_core_device *mvdev = mlx5vf_drvdata(pdev);
 
 	if (!mvdev->migrate_cap)
 		return;
@@ -618,7 +626,7 @@ static int mlx5vf_pci_probe(struct pci_dev *pdev,
 	if (ret)
 		goto out_free;
 
-	dev_set_drvdata(&pdev->dev, mvdev);
+	dev_set_drvdata(&pdev->dev, &mvdev->core_device);
 	return 0;
 
 out_free:
@@ -629,7 +637,7 @@ static int mlx5vf_pci_probe(struct pci_dev *pdev,
 
 static void mlx5vf_pci_remove(struct pci_dev *pdev)
 {
-	struct mlx5vf_pci_core_device *mvdev = dev_get_drvdata(&pdev->dev);
+	struct mlx5vf_pci_core_device *mvdev = mlx5vf_drvdata(pdev);
 
 	vfio_pci_core_unregister_device(&mvdev->core_device);
 	vfio_pci_core_uninit_device(&mvdev->core_device);
diff --git a/drivers/vfio/pci/vfio_pci_core.c b/drivers/vfio/pci/vfio_pci_core.c
index 06b6f3594a1316..53ad39d617653d 100644
--- a/drivers/vfio/pci/vfio_pci_core.c
+++ b/drivers/vfio/pci/vfio_pci_core.c
@@ -262,6 +262,10 @@ int vfio_pci_core_enable(struct vfio_pci_core_device *vdev)
 	u16 cmd;
 	u8 msix_pos;
 
+	/* Drivers must set the vfio_pci_core_device to their drvdata */
+	if (WARN_ON(vdev != dev_get_drvdata(&vdev->pdev->dev)))
+		return -EINVAL;
+
 	vfio_pci_set_power_state(vdev, PCI_D0);
 
 	/* Don't allow our initial saved state to include busmaster */
diff mbox series

Patch

diff --git a/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c b/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c
index 767b5d47631a..c76c09302a8f 100644
--- a/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c
+++ b/drivers/vfio/pci/hisilicon/hisi_acc_vfio_pci.c
@@ -1274,11 +1274,11 @@  static int hisi_acc_vfio_pci_probe(struct pci_dev *pdev, const struct pci_device
 					  &hisi_acc_vfio_pci_ops);
 	}
 
-	ret = vfio_pci_core_register_device(&hisi_acc_vdev->core_device);
+	ret = vfio_pci_core_register_device(&hisi_acc_vdev->core_device,
+					    hisi_acc_vdev);
 	if (ret)
 		goto out_free;
 
-	dev_set_drvdata(&pdev->dev, hisi_acc_vdev);
 	return 0;
 
 out_free:
diff --git a/drivers/vfio/pci/mlx5/main.c b/drivers/vfio/pci/mlx5/main.c
index bbec5d288fee..8689248f66f3 100644
--- a/drivers/vfio/pci/mlx5/main.c
+++ b/drivers/vfio/pci/mlx5/main.c
@@ -614,11 +614,10 @@  static int mlx5vf_pci_probe(struct pci_dev *pdev,
 		}
 	}
 
-	ret = vfio_pci_core_register_device(&mvdev->core_device);
+	ret = vfio_pci_core_register_device(&mvdev->core_device, mvdev);
 	if (ret)
 		goto out_free;
 
-	dev_set_drvdata(&pdev->dev, mvdev);
 	return 0;
 
 out_free:
diff --git a/drivers/vfio/pci/vfio_pci.c b/drivers/vfio/pci/vfio_pci.c
index 2b047469e02f..e0f8027c5cd8 100644
--- a/drivers/vfio/pci/vfio_pci.c
+++ b/drivers/vfio/pci/vfio_pci.c
@@ -151,10 +151,10 @@  static int vfio_pci_probe(struct pci_dev *pdev, const struct pci_device_id *id)
 		return -ENOMEM;
 	vfio_pci_core_init_device(vdev, pdev, &vfio_pci_ops);
 
-	ret = vfio_pci_core_register_device(vdev);
+	ret = vfio_pci_core_register_device(vdev, vdev);
 	if (ret)
 		goto out_free;
-	dev_set_drvdata(&pdev->dev, vdev);
+
 	return 0;
 
 out_free:
diff --git a/drivers/vfio/pci/vfio_pci_core.c b/drivers/vfio/pci/vfio_pci_core.c
index 1271728a09db..953ac33b2f5f 100644
--- a/drivers/vfio/pci/vfio_pci_core.c
+++ b/drivers/vfio/pci/vfio_pci_core.c
@@ -1822,9 +1822,11 @@  void vfio_pci_core_uninit_device(struct vfio_pci_core_device *vdev)
 }
 EXPORT_SYMBOL_GPL(vfio_pci_core_uninit_device);
 
-int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev)
+int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev,
+				  void *driver_data)
 {
 	struct pci_dev *pdev = vdev->pdev;
+	struct device *dev = &pdev->dev;
 	int ret;
 
 	if (pdev->hdr_type != PCI_HEADER_TYPE_NORMAL)
@@ -1843,6 +1845,17 @@  int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev)
 		return -EBUSY;
 	}
 
+	/*
+	 * The 'struct vfio_pci_core_device' should be the first member
+	 * of the structure referenced by 'driver_data' so that it can be
+	 * retrieved with dev_get_drvdata() inside vfio-pci core layer.
+	 */
+	if ((struct vfio_pci_core_device *)driver_data != vdev) {
+		pci_warn(pdev, "Invalid driver data\n");
+		return -EINVAL;
+	}
+	dev_set_drvdata(dev, driver_data);
+
 	if (pci_is_root_bus(pdev->bus)) {
 		ret = vfio_assign_device_set(&vdev->vdev, vdev);
 	} else if (!pci_probe_reset_slot(pdev->slot)) {
@@ -1856,10 +1869,10 @@  int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev)
 	}
 
 	if (ret)
-		return ret;
+		goto out_drvdata;
 	ret = vfio_pci_vf_init(vdev);
 	if (ret)
-		return ret;
+		goto out_drvdata;
 	ret = vfio_pci_vga_init(vdev);
 	if (ret)
 		goto out_vf;
@@ -1890,6 +1903,8 @@  int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev)
 		vfio_pci_set_power_state(vdev, PCI_D0);
 out_vf:
 	vfio_pci_vf_uninit(vdev);
+out_drvdata:
+	dev_set_drvdata(dev, NULL);
 	return ret;
 }
 EXPORT_SYMBOL_GPL(vfio_pci_core_register_device);
@@ -1897,6 +1912,7 @@  EXPORT_SYMBOL_GPL(vfio_pci_core_register_device);
 void vfio_pci_core_unregister_device(struct vfio_pci_core_device *vdev)
 {
 	struct pci_dev *pdev = vdev->pdev;
+	struct device *dev = &pdev->dev;
 
 	vfio_pci_core_sriov_configure(pdev, 0);
 
@@ -1907,6 +1923,8 @@  void vfio_pci_core_unregister_device(struct vfio_pci_core_device *vdev)
 
 	if (!disable_idle_d3)
 		vfio_pci_set_power_state(vdev, PCI_D0);
+
+	dev_set_drvdata(dev, NULL);
 }
 EXPORT_SYMBOL_GPL(vfio_pci_core_unregister_device);
 
diff --git a/include/linux/vfio_pci_core.h b/include/linux/vfio_pci_core.h
index 505b2a74a479..3c7d65e68340 100644
--- a/include/linux/vfio_pci_core.h
+++ b/include/linux/vfio_pci_core.h
@@ -225,7 +225,12 @@  void vfio_pci_core_close_device(struct vfio_device *core_vdev);
 void vfio_pci_core_init_device(struct vfio_pci_core_device *vdev,
 			       struct pci_dev *pdev,
 			       const struct vfio_device_ops *vfio_pci_ops);
-int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev);
+/*
+ * The 'struct vfio_pci_core_device' should be the first member
+ * of the structure referenced by 'driver_data'.
+ */
+int vfio_pci_core_register_device(struct vfio_pci_core_device *vdev,
+				  void *driver_data);
 void vfio_pci_core_uninit_device(struct vfio_pci_core_device *vdev);
 void vfio_pci_core_unregister_device(struct vfio_pci_core_device *vdev);
 int vfio_pci_core_sriov_configure(struct pci_dev *pdev, int nr_virtfn);