diff mbox series

[v10,14/50] crypto: ccp: Add support to initialize the AMD-SP for SEV-SNP

Message ID 20231016132819.1002933-15-michael.roth@amd.com
State New
Headers show
Series Add AMD Secure Nested Paging (SEV-SNP) Hypervisor Support | expand

Commit Message

Michael Roth Oct. 16, 2023, 1:27 p.m. UTC
From: Brijesh Singh <brijesh.singh@amd.com>

Before SNP VMs can be launched, the platform must be appropriately
configured and initialized. Platform initialization is accomplished via
the SNP_INIT command. Make sure to do a WBINVD and issue DF_FLUSH
command to prepare for the first SNP guest launch after INIT.

During the execution of SNP_INIT command, the firmware configures
and enables SNP security policy enforcement in many system components.
Some system components write to regions of memory reserved by early
x86 firmware (e.g. UEFI). Other system components write to regions
provided by the operation system, hypervisor, or x86 firmware.
Such system components can only write to HV-fixed pages or Default
pages. They will error when attempting to write to other page states
after SNP_INIT enables their SNP enforcement.

Starting in SNP firmware v1.52, the SNP_INIT_EX command takes a list of
system physical address ranges to convert into the HV-fixed page states
during the RMP initialization. If INIT_RMP is 1, hypervisors should
provide all system physical address ranges that the hypervisor will
never assign to a guest until the next RMP re-initialization.
For instance, the memory that UEFI reserves should be included in the
range list. This allows system components that occasionally write to
memory (e.g. logging to UEFI reserved regions) to not fail due to
RMP initialization and SNP enablement.

Note that SNP_INIT(_EX) must not be executed while non-SEV guests are
executing, otherwise it is possible that the system could reset or hang.
The psp_init_on_probe module parameter was added for SEV/SEV-ES support
and the init_ex_path module parameter to allow for time for the
necessary file system to be mounted/available. SNP_INIT(_EX) does not
use the file associated with init_ex_path. So, to avoid running into
issues where SNP_INIT(_EX) is called while there are other running
guests, issue it during module probe regardless of the psp_init_on_probe
setting, but maintain the previous deferrable handling for SEV/SEV-ES
initialization.

Co-developed-by: Ashish Kalra <ashish.kalra@amd.com>
Signed-off-by: Ashish Kalra <ashish.kalra@amd.com>
Co-developed-by: Jarkko Sakkinen <jarkko@profian.com>
Signed-off-by: Jarkko Sakkinen <jarkko@profian.com>
Signed-off-by: Brijesh Singh <brijesh.singh@amd.com>
Signed-off-by: Tom Lendacky <thomas.lendacky@amd.com>
[mdr: squash in psp_init_on_probe changes from Tom]
Signed-off-by: Michael Roth <michael.roth@amd.com>
---
 drivers/crypto/ccp/sev-dev.c | 272 +++++++++++++++++++++++++++++++++--
 drivers/crypto/ccp/sev-dev.h |   2 +
 2 files changed, 259 insertions(+), 15 deletions(-)

Comments

Borislav Petkov Nov. 27, 2023, 9:59 a.m. UTC | #1
On Mon, Oct 16, 2023 at 08:27:43AM -0500, Michael Roth wrote:
> +/*
> + * SEV_DATA_RANGE_LIST:
> + *   Array containing range of pages that firmware transitions to HV-fixed
> + *   page state.
> + */
> +struct sev_data_range_list *snp_range_list;
> +static int __sev_snp_init_locked(int *error);

Put the function above the caller instead of doing a forward
declaration.

>  static inline bool sev_version_greater_or_equal(u8 maj, u8 min)
>  {
>  	struct sev_device *sev = psp_master->sev_data;
> @@ -466,9 +479,9 @@ static inline int __sev_do_init_locked(int *psp_ret)
>  		return __sev_init_locked(psp_ret);
>  }
>  
> -static int __sev_platform_init_locked(int *error)
> +static int ___sev_platform_init_locked(int *error, bool probe)
>  {
> -	int rc = 0, psp_ret = SEV_RET_NO_FW_CALL;
> +	int rc, psp_ret = SEV_RET_NO_FW_CALL;
>  	struct psp_device *psp = psp_master;
>  	struct sev_device *sev;
>  
> @@ -480,6 +493,34 @@ static int __sev_platform_init_locked(int *error)
>  	if (sev->state == SEV_STATE_INIT)
>  		return 0;
>  
> +	/*
> +	 * Legacy guests cannot be running while SNP_INIT(_EX) is executing,
> +	 * so perform SEV-SNP initialization at probe time.
> +	 */
> +	rc = __sev_snp_init_locked(error);
> +	if (rc && rc != -ENODEV) {
> +		/*
> +		 * Don't abort the probe if SNP INIT failed,
> +		 * continue to initialize the legacy SEV firmware.
> +		 */
> +		dev_err(sev->dev, "SEV-SNP: failed to INIT rc %d, error %#x\n", rc, *error);
> +	}
> +
> +	/* Delay SEV/SEV-ES support initialization */
> +	if (probe && !psp_init_on_probe)
> +		return 0;
> +
> +	if (!sev_es_tmr) {
> +		/* Obtain the TMR memory area for SEV-ES use */
> +		sev_es_tmr = sev_fw_alloc(SEV_ES_TMR_SIZE);
> +		if (sev_es_tmr)
> +			/* Must flush the cache before giving it to the firmware */
> +			clflush_cache_range(sev_es_tmr, SEV_ES_TMR_SIZE);
> +		else
> +			dev_warn(sev->dev,
> +				 "SEV: TMR allocation failed, SEV-ES support unavailable\n");
> +		}
> +
>  	if (sev_init_ex_buffer) {
>  		rc = sev_read_init_ex_file();
>  		if (rc)
> @@ -522,6 +563,11 @@ static int __sev_platform_init_locked(int *error)
>  	return 0;
>  }
>  
> +static int __sev_platform_init_locked(int *error)
> +{
> +	return ___sev_platform_init_locked(error, false);
> +}

Uff, this is silly. And it makes the code hard to follow and that meat
of the platform init functionality in the ___-prefixed function a mess.

And the problem is that that "probe" functionality is replicated from
the one place where it is actually needed - sev_pci_init() which calls
that new sev_platform_init_on_probe() function - to everything that
calls __sev_platform_init_locked() for which you've added a wrapper.

What you should do, instead, is split the code around
__sev_snp_init_locked() in a separate function which does only that and
is called something like __sev_platform_init_snp_locked() or so which
does that unconditional work. And then you define:

_sev_platform_init_locked(int *error, bool probe)

note the *one* '_' - i.e., first layer:

_sev_platform_init_locked(int *error, bool probe):
{
	__sev_platform_init_snp_locked(error);

	if (!probe)
		return 0;

	if (psp_init_on_probe)
		__sev_platform_init_locked(error);

	...
}

and you do the probing in that function only so that it doesn't get lost
in the bunch of things __sev_platform_init_locked() does.

And then you call _sev_platform_init_locked() everywhere and no need for
a second sev_platform_init_on_probe().

> +
>  int sev_platform_init(int *error)
>  {
>  	int rc;
> @@ -534,6 +580,17 @@ int sev_platform_init(int *error)
>  }
>  EXPORT_SYMBOL_GPL(sev_platform_init);
>  
> +static int sev_platform_init_on_probe(int *error)
> +{
> +	int rc;
> +
> +	mutex_lock(&sev_cmd_mutex);
> +	rc = ___sev_platform_init_locked(error, true);
> +	mutex_unlock(&sev_cmd_mutex);
> +
> +	return rc;
> +}
> +
>  static int __sev_platform_shutdown_locked(int *error)
>  {
>  	struct sev_device *sev = psp_master->sev_data;
> @@ -838,6 +895,191 @@ static int sev_update_firmware(struct device *dev)
>  	return ret;
>  }
>  
> +static void snp_set_hsave_pa(void *arg)
> +{
> +	wrmsrl(MSR_VM_HSAVE_PA, 0);
> +}
> +
> +static int snp_filter_reserved_mem_regions(struct resource *rs, void *arg)
> +{
> +	struct sev_data_range_list *range_list = arg;
> +	struct sev_data_range *range = &range_list->ranges[range_list->num_elements];
> +	size_t size;
> +
> +	if ((range_list->num_elements * sizeof(struct sev_data_range) +
> +	     sizeof(struct sev_data_range_list)) > PAGE_SIZE)
> +		return -E2BIG;

Why? A comment would be helpful like with the rest this patch adds.

> +	switch (rs->desc) {
> +	case E820_TYPE_RESERVED:
> +	case E820_TYPE_PMEM:
> +	case E820_TYPE_ACPI:
> +		range->base = rs->start & PAGE_MASK;
> +		size = (rs->end + 1) - rs->start;
> +		range->page_count = size >> PAGE_SHIFT;
> +		range_list->num_elements++;
> +		break;
> +	default:
> +		break;
> +	}
> +
> +	return 0;
> +}
> +
> +static int __sev_snp_init_locked(int *error)
> +{
> +	struct psp_device *psp = psp_master;
> +	struct sev_data_snp_init_ex data;
> +	struct sev_device *sev;
> +	int rc = 0;
> +
> +	if (!cpu_feature_enabled(X86_FEATURE_SEV_SNP))
> +		return -ENODEV;
> +
> +	if (!psp || !psp->sev_data)
> +		return -ENODEV;

Only caller checks this already.

> +	sev = psp->sev_data;
> +
> +	if (sev->snp_initialized)

Do we really need this silly boolean or is there a way to query the
platform whether SNP has been initialized?

> +		return 0;
> +
> +	if (!sev_version_greater_or_equal(SNP_MIN_API_MAJOR, SNP_MIN_API_MINOR)) {
> +		dev_dbg(sev->dev, "SEV-SNP support requires firmware version >= %d:%d\n",
> +			SNP_MIN_API_MAJOR, SNP_MIN_API_MINOR);
> +		return 0;
> +	}
> +
> +	/*
> +	 * The SNP_INIT requires the MSR_VM_HSAVE_PA must be set to 0h
> +	 * across all cores.
> +	 */
> +	on_each_cpu(snp_set_hsave_pa, NULL, 1);
> +
> +	/*
> +	 * Starting in SNP firmware v1.52, the SNP_INIT_EX command takes a list of
> +	 * system physical address ranges to convert into the HV-fixed page states
> +	 * during the RMP initialization.  For instance, the memory that UEFI
> +	 * reserves should be included in the range list. This allows system
> +	 * components that occasionally write to memory (e.g. logging to UEFI
> +	 * reserved regions) to not fail due to RMP initialization and SNP enablement.
> +	 */
> +	if (sev_version_greater_or_equal(SNP_MIN_API_MAJOR, 52)) {

Is there a generic way to probe SNP_INIT_EX presence in the firmware or
are FW version numbers the only way?

> +		/*
> +		 * Firmware checks that the pages containing the ranges enumerated
> +		 * in the RANGES structure are either in the Default page state or in the

"default"

> +		 * firmware page state.
> +		 */
> +		snp_range_list = kzalloc(PAGE_SIZE, GFP_KERNEL);
> +		if (!snp_range_list) {
> +			dev_err(sev->dev,
> +				"SEV: SNP_INIT_EX range list memory allocation failed\n");
> +			return -ENOMEM;
> +		}
> +
> +		/*
> +		 * Retrieve all reserved memory regions setup by UEFI from the e820 memory map
> +		 * to be setup as HV-fixed pages.
> +		 */
> +


^ Superfluous newline.

> +		rc = walk_iomem_res_desc(IORES_DESC_NONE, IORESOURCE_MEM, 0, ~0,
> +					 snp_range_list, snp_filter_reserved_mem_regions);
> +		if (rc) {
> +			dev_err(sev->dev,
> +				"SEV: SNP_INIT_EX walk_iomem_res_desc failed rc = %d\n", rc);
> +			return rc;
> +		}
> +
> +		memset(&data, 0, sizeof(data));
> +		data.init_rmp = 1;
> +		data.list_paddr_en = 1;
> +		data.list_paddr = __psp_pa(snp_range_list);
> +
> +		/*
> +		 * Before invoking SNP_INIT_EX with INIT_RMP=1, make sure that
> +		 * all dirty cache lines containing the RMP are flushed.
> +		 *
> +		 * NOTE: that includes writes via RMPUPDATE instructions, which
> +		 * are also cacheable writes.
> +		 */
> +		wbinvd_on_all_cpus();
> +
> +		rc = __sev_do_cmd_locked(SEV_CMD_SNP_INIT_EX, &data, error);
> +		if (rc)
> +			return rc;
> +	} else {
> +		/*
> +		 * SNP_INIT is equivalent to SNP_INIT_EX with INIT_RMP=1, so
> +		 * just as with that case, make sure all dirty cache lines
> +		 * containing the RMP are flushed.
> +		 */
> +		wbinvd_on_all_cpus();
> +
> +		rc = __sev_do_cmd_locked(SEV_CMD_SNP_INIT, NULL, error);
> +		if (rc)
> +			return rc;
> +	}

So instead of duplicating the code here at the end of the if-else
branching, you can do:

	void *arg = &data;

	if () {
		...
		cmd = SEV_CMD_SNP_INIT_EX;
	} else {
		cmd = SEV_CMD_SNP_INIT;
		arg = NULL;
	}

	wbinvd_on_all_cpus();
	rc = __sev_do_cmd_locked(cmd, arg, error);
	if (rc)
		return rc;

> +	/* Prepare for first SNP guest launch after INIT */
> +	wbinvd_on_all_cpus();

Why is that WBINVD needed?

> +	rc = __sev_do_cmd_locked(SEV_CMD_SNP_DF_FLUSH, NULL, error);
> +	if (rc)
> +		return rc;
> +
> +	sev->snp_initialized = true;
> +	dev_dbg(sev->dev, "SEV-SNP firmware initialized\n");
> +
> +	return rc;
> +}
> +
> +static int __sev_snp_shutdown_locked(int *error)
> +{
> +	struct sev_device *sev = psp_master->sev_data;
> +	struct sev_data_snp_shutdown_ex data;
> +	int ret;
> +
> +	if (!sev->snp_initialized)
> +		return 0;
> +
> +	memset(&data, 0, sizeof(data));
> +	data.length = sizeof(data);
> +	data.iommu_snp_shutdown = 1;
> +
> +	wbinvd_on_all_cpus();
> +
> +retry:
> +	ret = __sev_do_cmd_locked(SEV_CMD_SNP_SHUTDOWN_EX, &data, error);
> +	/* SHUTDOWN may require DF_FLUSH */
> +	if (*error == SEV_RET_DFFLUSH_REQUIRED) {
> +		ret = __sev_do_cmd_locked(SEV_CMD_SNP_DF_FLUSH, NULL, NULL);
> +		if (ret) {
> +			dev_err(sev->dev, "SEV-SNP DF_FLUSH failed\n");
> +			return ret;

When you return here,  sev->snp_initialized is still true but, in
reality, it probably is in some half-broken state after issuing those
commands you it is not really initialized anymore.

> +		}
> +		goto retry;

This needs an upper limit from which to break out and not potentially
endless-loop.

> +	}
> +	if (ret) {
> +		dev_err(sev->dev, "SEV-SNP firmware shutdown failed\n");
> +		return ret;
> +	}
> +
> +	sev->snp_initialized = false;
> +	dev_dbg(sev->dev, "SEV-SNP firmware shutdown\n");
> +
> +	return ret;
> +}
> +
> +static int sev_snp_shutdown(int *error)
> +{
> +	int rc;
> +
> +	mutex_lock(&sev_cmd_mutex);
> +	rc = __sev_snp_shutdown_locked(error);

Why is this "locked" version even there if it is called only here?

IOW, put all the logic in here - no need for
__sev_snp_shutdown_locked().

> +	mutex_unlock(&sev_cmd_mutex);
> +
> +	return rc;
> +}

...
Ashish Kalra Nov. 30, 2023, 2:13 a.m. UTC | #2
Hello Boris,

>> +static int ___sev_platform_init_locked(int *error, bool probe)
>>   {
>> -	int rc = 0, psp_ret = SEV_RET_NO_FW_CALL;
>> +	int rc, psp_ret = SEV_RET_NO_FW_CALL;
>>   	struct psp_device *psp = psp_master;
>>   	struct sev_device *sev;
>>   
>> @@ -480,6 +493,34 @@ static int __sev_platform_init_locked(int *error)
>>   	if (sev->state == SEV_STATE_INIT)
>>   		return 0;
>>   
>> +	/*
>> +	 * Legacy guests cannot be running while SNP_INIT(_EX) is executing,
>> +	 * so perform SEV-SNP initialization at probe time.
>> +	 */
>> +	rc = __sev_snp_init_locked(error);
>> +	if (rc && rc != -ENODEV) {
>> +		/*
>> +		 * Don't abort the probe if SNP INIT failed,
>> +		 * continue to initialize the legacy SEV firmware.
>> +		 */
>> +		dev_err(sev->dev, "SEV-SNP: failed to INIT rc %d, error %#x\n", rc, *error);
>> +	}
>> +
>> +	/* Delay SEV/SEV-ES support initialization */
>> +	if (probe && !psp_init_on_probe)
>> +		return 0;
>> +
>> +	if (!sev_es_tmr) {
>> +		/* Obtain the TMR memory area for SEV-ES use */
>> +		sev_es_tmr = sev_fw_alloc(SEV_ES_TMR_SIZE);
>> +		if (sev_es_tmr)
>> +			/* Must flush the cache before giving it to the firmware */
>> +			clflush_cache_range(sev_es_tmr, SEV_ES_TMR_SIZE);
>> +		else
>> +			dev_warn(sev->dev,
>> +				 "SEV: TMR allocation failed, SEV-ES support unavailable\n");
>> +		}
>> +
>>   	if (sev_init_ex_buffer) {
>>   		rc = sev_read_init_ex_file();
>>   		if (rc)
>> @@ -522,6 +563,11 @@ static int __sev_platform_init_locked(int *error)
>>   	return 0;
>>   }
>>   
>> +static int __sev_platform_init_locked(int *error)
>> +{
>> +	return ___sev_platform_init_locked(error, false);
>> +}
> 
> Uff, this is silly. And it makes the code hard to follow and that meat
> of the platform init functionality in the ___-prefixed function a mess.
> 
> And the problem is that that "probe" functionality is replicated from
> the one place where it is actually needed - sev_pci_init() which calls
> that new sev_platform_init_on_probe() function - to everything that
> calls __sev_platform_init_locked() for which you've added a wrapper.
> 
> What you should do, instead, is split the code around
> __sev_snp_init_locked() in a separate function which does only that and
> is called something like __sev_platform_init_snp_locked() or so which
> does that unconditional work. And then you define:
> 
> _sev_platform_init_locked(int *error, bool probe)
> 
> note the *one* '_' - i.e., first layer:
> 
> _sev_platform_init_locked(int *error, bool probe):
> {
> 	__sev_platform_init_snp_locked(error);
> 
> 	if (!probe)
> 		return 0;
> 
> 	if (psp_init_on_probe)
> 		__sev_platform_init_locked(error);
> 
> 	...
> }
> 
> and you do the probing in that function only so that it doesn't get lost
> in the bunch of things __sev_platform_init_locked() does.
> 
> And then you call _sev_platform_init_locked() everywhere and no need for
> a second sev_platform_init_on_probe().
>

It surely seems hard to follow up, so i am anyway going to clean it up by:

Adding the "probe" parameter to sev_platform_init() where the parameter 
being true indicates that we only want to do SNP initialization on 
probe, the same parameter will get passed on to
__sev_platform_init_locked().

So eventually there won't be a second sev_platform_init_on_probe() and 
also there is no need for a ___sev_platform_init_locked().

We will only have sev_platform_init() and _sev_platform_init_locked().

>> +
>> +static int snp_filter_reserved_mem_regions(struct resource *rs, void *arg)
>> +{
>> +	struct sev_data_range_list *range_list = arg;
>> +	struct sev_data_range *range = &range_list->ranges[range_list->num_elements];
>> +	size_t size;
>> +
>> +	if ((range_list->num_elements * sizeof(struct sev_data_range) +
>> +	     sizeof(struct sev_data_range_list)) > PAGE_SIZE)
>> +		return -E2BIG;
> 
> Why? A comment would be helpful like with the rest this patch adds.
>
Ok.

>> +	switch (rs->desc) {
>> +	case E820_TYPE_RESERVED:
>> +	case E820_TYPE_PMEM:
>> +	case E820_TYPE_ACPI:
>> +		range->base = rs->start & PAGE_MASK;
>> +		size = (rs->end + 1) - rs->start;
>> +		range->page_count = size >> PAGE_SHIFT;
>> +		range_list->num_elements++;
>> +		break;
>> +	default:
>> +		break;
>> +	}
>> +
>> +	return 0;
>> +}
>> +
>> +static int __sev_snp_init_locked(int *error)
>> +{
>> +	struct psp_device *psp = psp_master;
>> +	struct sev_data_snp_init_ex data;
>> +	struct sev_device *sev;
>> +	int rc = 0;
>> +
>> +	if (!cpu_feature_enabled(X86_FEATURE_SEV_SNP))
>> +		return -ENODEV;
>> +
>> +	if (!psp || !psp->sev_data)
>> +		return -ENODEV;
> 
> Only caller checks this already.
> 
Ok.

>> +	sev = psp->sev_data;
>> +
>> +	if (sev->snp_initialized)
> 
> Do we really need this silly boolean or is there a way to query the
> platform whether SNP has been initialized?
> 

Yes it makes sense to have it as any platform specific way to query 
whether the SNP has been initialized will be much more expensive then 
simply checking this boolean.

>> +		return 0;
>> +
>> +	if (!sev_version_greater_or_equal(SNP_MIN_API_MAJOR, SNP_MIN_API_MINOR)) {
>> +		dev_dbg(sev->dev, "SEV-SNP support requires firmware version >= %d:%d\n",
>> +			SNP_MIN_API_MAJOR, SNP_MIN_API_MINOR);
>> +		return 0;
>> +	}
>> +
>> +	/*
>> +	 * The SNP_INIT requires the MSR_VM_HSAVE_PA must be set to 0h
>> +	 * across all cores.
>> +	 */
>> +	on_each_cpu(snp_set_hsave_pa, NULL, 1);
>> +
>> +	/*
>> +	 * Starting in SNP firmware v1.52, the SNP_INIT_EX command takes a list of
>> +	 * system physical address ranges to convert into the HV-fixed page states
>> +	 * during the RMP initialization.  For instance, the memory that UEFI
>> +	 * reserves should be included in the range list. This allows system
>> +	 * components that occasionally write to memory (e.g. logging to UEFI
>> +	 * reserved regions) to not fail due to RMP initialization and SNP enablement.
>> +	 */
>> +	if (sev_version_greater_or_equal(SNP_MIN_API_MAJOR, 52)) {
> 
> Is there a generic way to probe SNP_INIT_EX presence in the firmware or
> are FW version numbers the only way?

It is not only the presence of SNP_INIT_EX but this check is more 
specific to passing the HV_Fixed pages list to SNP_INIT_EX and that is 
only supported with SNP FW versions 1.52 and beyond, so the FW version 
check is the only way.

> 
>> +		/*
>> +		 * Firmware checks that the pages containing the ranges enumerated
>> +		 * in the RANGES structure are either in the Default page state or in the
> 
> "default"
> 
>> +		 * firmware page state.
>> +		 */
>> +		snp_range_list = kzalloc(PAGE_SIZE, GFP_KERNEL);
>> +		if (!snp_range_list) {
>> +			dev_err(sev->dev,
>> +				"SEV: SNP_INIT_EX range list memory allocation failed\n");
>> +			return -ENOMEM;
>> +		}
>> +
>> +		/*
>> +		 * Retrieve all reserved memory regions setup by UEFI from the e820 memory map
>> +		 * to be setup as HV-fixed pages.
>> +		 */
>> +
> 
> 
> ^ Superfluous newline.
> 
>> +		rc = walk_iomem_res_desc(IORES_DESC_NONE, IORESOURCE_MEM, 0, ~0,
>> +					 snp_range_list, snp_filter_reserved_mem_regions);
>> +		if (rc) {
>> +			dev_err(sev->dev,
>> +				"SEV: SNP_INIT_EX walk_iomem_res_desc failed rc = %d\n", rc);
>> +			return rc;
>> +		}
>> +
>> +		memset(&data, 0, sizeof(data));
>> +		data.init_rmp = 1;
>> +		data.list_paddr_en = 1;
>> +		data.list_paddr = __psp_pa(snp_range_list);
>> +
>> +		/*
>> +		 * Before invoking SNP_INIT_EX with INIT_RMP=1, make sure that
>> +		 * all dirty cache lines containing the RMP are flushed.
>> +		 *
>> +		 * NOTE: that includes writes via RMPUPDATE instructions, which
>> +		 * are also cacheable writes.
>> +		 */
>> +		wbinvd_on_all_cpus();
>> +
>> +		rc = __sev_do_cmd_locked(SEV_CMD_SNP_INIT_EX, &data, error);
>> +		if (rc)
>> +			return rc;
>> +	} else {
>> +		/*
>> +		 * SNP_INIT is equivalent to SNP_INIT_EX with INIT_RMP=1, so
>> +		 * just as with that case, make sure all dirty cache lines
>> +		 * containing the RMP are flushed.
>> +		 */
>> +		wbinvd_on_all_cpus();
>> +
>> +		rc = __sev_do_cmd_locked(SEV_CMD_SNP_INIT, NULL, error);
>> +		if (rc)
>> +			return rc;
>> +	}
> 
> So instead of duplicating the code here at the end of the if-else
> branching, you can do:
> 
> 	void *arg = &data;
> 
> 	if () {
> 		...
> 		cmd = SEV_CMD_SNP_INIT_EX;
> 	} else {
> 		cmd = SEV_CMD_SNP_INIT;
> 		arg = NULL;
> 	}
> 
> 	wbinvd_on_all_cpus();
> 	rc = __sev_do_cmd_locked(cmd, arg, error);
> 	if (rc)
> 		return rc;

Yes, makes sense, will fix it.

> 
>> +	/* Prepare for first SNP guest launch after INIT */
>> +	wbinvd_on_all_cpus();
> 
> Why is that WBINVD needed?

As the comment above mentions, WBINVD + DF_FLUSH is needed before the 
first guest launch.

> 
>> +	rc = __sev_do_cmd_locked(SEV_CMD_SNP_DF_FLUSH, NULL, error);
>> +	if (rc)
>> +		return rc;
>> +
>> +	sev->snp_initialized = true;
>> +	dev_dbg(sev->dev, "SEV-SNP firmware initialized\n");
>> +
>> +	return rc;
>> +}
>> +
>> +static int __sev_snp_shutdown_locked(int *error)
>> +{
>> +	struct sev_device *sev = psp_master->sev_data;
>> +	struct sev_data_snp_shutdown_ex data;
>> +	int ret;
>> +
>> +	if (!sev->snp_initialized)
>> +		return 0;
>> +
>> +	memset(&data, 0, sizeof(data));
>> +	data.length = sizeof(data);
>> +	data.iommu_snp_shutdown = 1;
>> +
>> +	wbinvd_on_all_cpus();
>> +
>> +retry:
>> +	ret = __sev_do_cmd_locked(SEV_CMD_SNP_SHUTDOWN_EX, &data, error);
>> +	/* SHUTDOWN may require DF_FLUSH */
>> +	if (*error == SEV_RET_DFFLUSH_REQUIRED) {
>> +		ret = __sev_do_cmd_locked(SEV_CMD_SNP_DF_FLUSH, NULL, NULL);
>> +		if (ret) {
>> +			dev_err(sev->dev, "SEV-SNP DF_FLUSH failed\n");
>> +			return ret;
> 
> When you return here,  sev->snp_initialized is still true but, in
> reality, it probably is in some half-broken state after issuing those
> commands you it is not really initialized anymore.

Yes, this needs to be fixed.

> 
>> +		}
>> +		goto retry;
> 
> This needs an upper limit from which to break out and not potentially
> endless-loop.
>

Ok.

>> +	}
>> +	if (ret) {
>> +		dev_err(sev->dev, "SEV-SNP firmware shutdown failed\n");
>> +		return ret;
>> +	}
>> +
>> +	sev->snp_initialized = false;
>> +	dev_dbg(sev->dev, "SEV-SNP firmware shutdown\n");
>> +
>> +	return ret;
>> +}
>> +
>> +static int sev_snp_shutdown(int *error)
>> +{
>> +	int rc;
>> +
>> +	mutex_lock(&sev_cmd_mutex);
>> +	rc = __sev_snp_shutdown_locked(error);
> 
> Why is this "locked" version even there if it is called only here?
> 
> IOW, put all the logic in here - no need for
> __sev_snp_shutdown_locked().

In the latest code base, _sev_snp_shutdown_locked() is called from
__sev_firmware_shutdown().

Thanks,
Ashish

> 
>> +	mutex_unlock(&sev_cmd_mutex);
>> +
>> +	return rc;
>> +}
> 
> ...
>
Borislav Petkov Dec. 6, 2023, 5:08 p.m. UTC | #3
On Wed, Nov 29, 2023 at 08:13:52PM -0600, Kalra, Ashish wrote:
> It surely seems hard to follow up, so i am anyway going to clean it up by:
> 
> Adding the "probe" parameter to sev_platform_init() where the parameter
> being true indicates that we only want to do SNP initialization on probe,
> the same parameter will get passed on to
> __sev_platform_init_locked().

That's exactly what you should *not* do - the probe parameter controls
whether

	if (psp_init_on_probe)
		__sev_platform_init_locked(error);

and so on should get executed or not.

If it is unclear, lemme know and I'll do a diff to show you what I mean.
	
> > > +	/* Prepare for first SNP guest launch after INIT */
> > > +	wbinvd_on_all_cpus();
> > 
> > Why is that WBINVD needed?
> 
> As the comment above mentions, WBINVD + DF_FLUSH is needed before the first
> guest launch.

Lemme see if I get this straight. The correct order is:

	WBINVD
	SNP_INIT_*
	WBINVD
	DF_FLUSH

If so, do a comment which goes like this:

	/*
	 * The order of commands to execute before the first guest
	 * launch is the following:
	 *
	 * bla...
	 */


> In the latest code base, _sev_snp_shutdown_locked() is called from
> __sev_firmware_shutdown().

Then carve that function out only when needed - do not do changes
preemptively. This is not helping during review.

Thx.
Ashish Kalra Dec. 6, 2023, 8:35 p.m. UTC | #4
Hello Boris,

On 12/6/2023 11:08 AM, Borislav Petkov wrote:
> On Wed, Nov 29, 2023 at 08:13:52PM -0600, Kalra, Ashish wrote:
>> It surely seems hard to follow up, so i am anyway going to clean it up by:
>>
>> Adding the "probe" parameter to sev_platform_init() where the parameter
>> being true indicates that we only want to do SNP initialization on probe,
>> the same parameter will get passed on to
>> __sev_platform_init_locked().
> 
> That's exactly what you should *not* do - the probe parameter controls
> whether
> 
> 	if (psp_init_on_probe)
> 		__sev_platform_init_locked(error);
> 
> and so on should get executed or not.
>

Not actually.

The main use case for the probe parameter is to control if we want to do 
legacy SEV/SEV-ES INIT during probe. There is a usage case where we want 
to delay legacy SEV INIT till an actual SEV/SEV-ES guest is being 
launched. So essentially the probe parameter controls if we want to
execute __sev_do_init_locked() or not.

We always want to do SNP INIT at probe time.

Thanks,
Ashish
Borislav Petkov Dec. 9, 2023, 4:20 p.m. UTC | #5
On Wed, Dec 06, 2023 at 02:35:28PM -0600, Kalra, Ashish wrote:
> The main use case for the probe parameter is to control if we want to do
> legacy SEV/SEV-ES INIT during probe. There is a usage case where we want to
> delay legacy SEV INIT till an actual SEV/SEV-ES guest is being launched. So
> essentially the probe parameter controls if we want to
> execute __sev_do_init_locked() or not.
> 
> We always want to do SNP INIT at probe time.

Here's what I mean (diff ontop):

diff --git a/drivers/crypto/ccp/sev-dev.c b/drivers/crypto/ccp/sev-dev.c
index fae1fd45eccd..830d74fcf950 100644
--- a/drivers/crypto/ccp/sev-dev.c
+++ b/drivers/crypto/ccp/sev-dev.c
@@ -479,11 +479,16 @@ static inline int __sev_do_init_locked(int *psp_ret)
 		return __sev_init_locked(psp_ret);
 }
 
-static int ___sev_platform_init_locked(int *error, bool probe)
+/*
+ * Legacy guests cannot be running while SNP_INIT(_EX) is executing,
+ * so perform SEV-SNP initialization at probe time.
+ */
+static int __sev_platform_init_snp_locked(int *error)
 {
-	int rc, psp_ret = SEV_RET_NO_FW_CALL;
+
 	struct psp_device *psp = psp_master;
 	struct sev_device *sev;
+	int rc;
 
 	if (!psp || !psp->sev_data)
 		return -ENODEV;
@@ -493,10 +498,6 @@ static int ___sev_platform_init_locked(int *error, bool probe)
 	if (sev->state == SEV_STATE_INIT)
 		return 0;
 
-	/*
-	 * Legacy guests cannot be running while SNP_INIT(_EX) is executing,
-	 * so perform SEV-SNP initialization at probe time.
-	 */
 	rc = __sev_snp_init_locked(error);
 	if (rc && rc != -ENODEV) {
 		/*
@@ -506,8 +507,21 @@ static int ___sev_platform_init_locked(int *error, bool probe)
 		dev_err(sev->dev, "SEV-SNP: failed to INIT rc %d, error %#x\n", rc, *error);
 	}
 
-	/* Delay SEV/SEV-ES support initialization */
-	if (probe && !psp_init_on_probe)
+	return rc;
+}
+
+static int __sev_platform_init_locked(int *error)
+{
+	int rc, psp_ret = SEV_RET_NO_FW_CALL;
+	struct psp_device *psp = psp_master;
+	struct sev_device *sev;
+
+	if (!psp || !psp->sev_data)
+		return -ENODEV;
+
+	sev = psp->sev_data;
+
+	if (sev->state == SEV_STATE_INIT)
 		return 0;
 
 	if (!sev_es_tmr) {
@@ -563,33 +577,32 @@ static int ___sev_platform_init_locked(int *error, bool probe)
 	return 0;
 }
 
-static int __sev_platform_init_locked(int *error)
-{
-	return ___sev_platform_init_locked(error, false);
-}
-
-int sev_platform_init(int *error)
+static int _sev_platform_init_locked(int *error, bool probe)
 {
 	int rc;
 
-	mutex_lock(&sev_cmd_mutex);
-	rc = __sev_platform_init_locked(error);
-	mutex_unlock(&sev_cmd_mutex);
+	rc = __sev_platform_init_snp_locked(error);
+	if (rc)
+		return rc;
 
-	return rc;
+	/* Delay SEV/SEV-ES support initialization */
+	if (probe && !psp_init_on_probe)
+		return 0;
+
+	return __sev_platform_init_locked(error);
 }
-EXPORT_SYMBOL_GPL(sev_platform_init);
 
-static int sev_platform_init_on_probe(int *error)
+int sev_platform_init(int *error)
 {
 	int rc;
 
 	mutex_lock(&sev_cmd_mutex);
-	rc = ___sev_platform_init_locked(error, true);
+	rc = _sev_platform_init_locked(error, false);
 	mutex_unlock(&sev_cmd_mutex);
 
 	return rc;
 }
+EXPORT_SYMBOL_GPL(sev_platform_init);
 
 static int __sev_platform_shutdown_locked(int *error)
 {
@@ -691,7 +704,7 @@ static int sev_ioctl_do_pek_pdh_gen(int cmd, struct sev_issue_cmd *argp, bool wr
 		return -EPERM;
 
 	if (sev->state == SEV_STATE_UNINIT) {
-		rc = __sev_platform_init_locked(&argp->error);
+		rc = _sev_platform_init_locked(&argp->error, false);
 		if (rc)
 			return rc;
 	}
@@ -734,7 +747,7 @@ static int sev_ioctl_do_pek_csr(struct sev_issue_cmd *argp, bool writable)
 
 cmd:
 	if (sev->state == SEV_STATE_UNINIT) {
-		ret = __sev_platform_init_locked(&argp->error);
+		ret = _sev_platform_init_locked(&argp->error, false);
 		if (ret)
 			goto e_free_blob;
 	}
@@ -1115,7 +1128,7 @@ static int sev_ioctl_do_pek_import(struct sev_issue_cmd *argp, bool writable)
 
 	/* If platform is not in INIT state then transition it to INIT */
 	if (sev->state != SEV_STATE_INIT) {
-		ret = __sev_platform_init_locked(&argp->error);
+		ret = _sev_platform_init_locked(&argp->error, false);
 		if (ret)
 			goto e_free_oca;
 	}
@@ -1246,7 +1259,7 @@ static int sev_ioctl_do_pdh_export(struct sev_issue_cmd *argp, bool writable)
 		if (!writable)
 			return -EPERM;
 
-		ret = __sev_platform_init_locked(&argp->error);
+		ret = _sev_platform_init_locked(&argp->error, false);
 		if (ret)
 			return ret;
 	}
@@ -1608,7 +1621,9 @@ void sev_pci_init(void)
 	}
 
 	/* Initialize the platform */
-	rc = sev_platform_init_on_probe(&error);
+	mutex_lock(&sev_cmd_mutex);
+	rc = _sev_platform_init_locked(&error, true);
+	mutex_unlock(&sev_cmd_mutex);
 	if (rc)
 		dev_err(sev->dev, "SEV: failed to INIT error %#x, rc %d\n",
 			error, rc);
Ashish Kalra Dec. 11, 2023, 9:11 p.m. UTC | #6
Hello Boris,

On 12/9/2023 10:20 AM, Borislav Petkov wrote:
> On Wed, Dec 06, 2023 at 02:35:28PM -0600, Kalra, Ashish wrote:
>> The main use case for the probe parameter is to control if we want to doHl
>> legacy SEV/SEV-ES INIT during probe. There is a usage case where we want to
>> delay legacy SEV INIT till an actual SEV/SEV-ES guest is being launched. So
>> essentially the probe parameter controls if we want to
>> execute __sev_do_init_locked() or not.
>>
>> We always want to do SNP INIT at probe time.
> 
> Here's what I mean (diff ontop):
> 

See my comments below on this patch:

> +int sev_platform_init(int *error)
>   {
>   	int rc;
>   
>   	mutex_lock(&sev_cmd_mutex);
> -	rc = ___sev_platform_init_locked(error, true);
> +	rc = _sev_platform_init_locked(error, false);
>   	mutex_unlock(&sev_cmd_mutex);
>   
>   	return rc;
>   }
> +EXPORT_SYMBOL_GPL(sev_platform_init);
>   

What we need is a mechanism to do legacy SEV/SEV-ES INIT only if a 
SEV/SEV-ES guest is being launched, hence, we want an additional 
parameter added to sev_platform_init() exported interface so that 
kvm_amd module can call this interface during guest launch and indicate 
if SNP/legacy guest is being launched.

That's the reason we want to add the probe parameter to
sev_platform_init().

And to address your previous comments, this will remain a clean 
interface, there are going to be only two functions:
sev_platform_init() & __sev_platform_init_locked().

Thanks,
Ashish
Borislav Petkov Dec. 12, 2023, 6:52 a.m. UTC | #7
On Mon, Dec 11, 2023 at 03:11:17PM -0600, Kalra, Ashish wrote:
> What we need is a mechanism to do legacy SEV/SEV-ES INIT only if a
> SEV/SEV-ES guest is being launched, hence, we want an additional parameter
> added to sev_platform_init() exported interface so that kvm_amd module can
> call this interface during guest launch and indicate if SNP/legacy guest is
> being launched.
> 
> That's the reason we want to add the probe parameter to
> sev_platform_init().

That's not what your original patch does and nowhere in the whole
patchset do I see this new requirement for KVM to be able to control the
probing.

The probe param is added to ___sev_platform_init_locked() which is
called by this new sev_platform_init_on_probe() thing to signal that
whatever calls this, it wants the probing.

And "whatever" is sev_pci_init() which is called from the bowels of the
secure processor drivers. Suffice it to say, this is some sort of an
init path.

So, it wants to init SNP stuff which is unconditional during driver init
- not when KVM starts guests - and probe too on driver init time, *iff*
that psp_init_on_probe thing is set. Which looks suspicious to me:

  "Add psp_init_on_probe module parameter that allows for skipping the
  PSP's SEV platform initialization during module init. User may decouple
  module init from PSP init due to use of the INIT_EX support in upcoming
  patch which allows for users to save PSP's internal state to file."
diff mbox series

Patch

diff --git a/drivers/crypto/ccp/sev-dev.c b/drivers/crypto/ccp/sev-dev.c
index c2da92f19ccd..fae1fd45eccd 100644
--- a/drivers/crypto/ccp/sev-dev.c
+++ b/drivers/crypto/ccp/sev-dev.c
@@ -29,6 +29,7 @@ 
 
 #include <asm/smp.h>
 #include <asm/cacheflush.h>
+#include <asm/e820/types.h>
 
 #include "psp-dev.h"
 #include "sev-dev.h"
@@ -37,6 +38,10 @@ 
 #define SEV_FW_FILE		"amd/sev.fw"
 #define SEV_FW_NAME_SIZE	64
 
+/* Minimum firmware version required for the SEV-SNP support */
+#define SNP_MIN_API_MAJOR	1
+#define SNP_MIN_API_MINOR	51
+
 static DEFINE_MUTEX(sev_cmd_mutex);
 static struct sev_misc_dev *misc_dev;
 
@@ -80,6 +85,14 @@  static void *sev_es_tmr;
 #define NV_LENGTH (32 * 1024)
 static void *sev_init_ex_buffer;
 
+/*
+ * SEV_DATA_RANGE_LIST:
+ *   Array containing range of pages that firmware transitions to HV-fixed
+ *   page state.
+ */
+struct sev_data_range_list *snp_range_list;
+static int __sev_snp_init_locked(int *error);
+
 static inline bool sev_version_greater_or_equal(u8 maj, u8 min)
 {
 	struct sev_device *sev = psp_master->sev_data;
@@ -466,9 +479,9 @@  static inline int __sev_do_init_locked(int *psp_ret)
 		return __sev_init_locked(psp_ret);
 }
 
-static int __sev_platform_init_locked(int *error)
+static int ___sev_platform_init_locked(int *error, bool probe)
 {
-	int rc = 0, psp_ret = SEV_RET_NO_FW_CALL;
+	int rc, psp_ret = SEV_RET_NO_FW_CALL;
 	struct psp_device *psp = psp_master;
 	struct sev_device *sev;
 
@@ -480,6 +493,34 @@  static int __sev_platform_init_locked(int *error)
 	if (sev->state == SEV_STATE_INIT)
 		return 0;
 
+	/*
+	 * Legacy guests cannot be running while SNP_INIT(_EX) is executing,
+	 * so perform SEV-SNP initialization at probe time.
+	 */
+	rc = __sev_snp_init_locked(error);
+	if (rc && rc != -ENODEV) {
+		/*
+		 * Don't abort the probe if SNP INIT failed,
+		 * continue to initialize the legacy SEV firmware.
+		 */
+		dev_err(sev->dev, "SEV-SNP: failed to INIT rc %d, error %#x\n", rc, *error);
+	}
+
+	/* Delay SEV/SEV-ES support initialization */
+	if (probe && !psp_init_on_probe)
+		return 0;
+
+	if (!sev_es_tmr) {
+		/* Obtain the TMR memory area for SEV-ES use */
+		sev_es_tmr = sev_fw_alloc(SEV_ES_TMR_SIZE);
+		if (sev_es_tmr)
+			/* Must flush the cache before giving it to the firmware */
+			clflush_cache_range(sev_es_tmr, SEV_ES_TMR_SIZE);
+		else
+			dev_warn(sev->dev,
+				 "SEV: TMR allocation failed, SEV-ES support unavailable\n");
+		}
+
 	if (sev_init_ex_buffer) {
 		rc = sev_read_init_ex_file();
 		if (rc)
@@ -522,6 +563,11 @@  static int __sev_platform_init_locked(int *error)
 	return 0;
 }
 
+static int __sev_platform_init_locked(int *error)
+{
+	return ___sev_platform_init_locked(error, false);
+}
+
 int sev_platform_init(int *error)
 {
 	int rc;
@@ -534,6 +580,17 @@  int sev_platform_init(int *error)
 }
 EXPORT_SYMBOL_GPL(sev_platform_init);
 
+static int sev_platform_init_on_probe(int *error)
+{
+	int rc;
+
+	mutex_lock(&sev_cmd_mutex);
+	rc = ___sev_platform_init_locked(error, true);
+	mutex_unlock(&sev_cmd_mutex);
+
+	return rc;
+}
+
 static int __sev_platform_shutdown_locked(int *error)
 {
 	struct sev_device *sev = psp_master->sev_data;
@@ -838,6 +895,191 @@  static int sev_update_firmware(struct device *dev)
 	return ret;
 }
 
+static void snp_set_hsave_pa(void *arg)
+{
+	wrmsrl(MSR_VM_HSAVE_PA, 0);
+}
+
+static int snp_filter_reserved_mem_regions(struct resource *rs, void *arg)
+{
+	struct sev_data_range_list *range_list = arg;
+	struct sev_data_range *range = &range_list->ranges[range_list->num_elements];
+	size_t size;
+
+	if ((range_list->num_elements * sizeof(struct sev_data_range) +
+	     sizeof(struct sev_data_range_list)) > PAGE_SIZE)
+		return -E2BIG;
+
+	switch (rs->desc) {
+	case E820_TYPE_RESERVED:
+	case E820_TYPE_PMEM:
+	case E820_TYPE_ACPI:
+		range->base = rs->start & PAGE_MASK;
+		size = (rs->end + 1) - rs->start;
+		range->page_count = size >> PAGE_SHIFT;
+		range_list->num_elements++;
+		break;
+	default:
+		break;
+	}
+
+	return 0;
+}
+
+static int __sev_snp_init_locked(int *error)
+{
+	struct psp_device *psp = psp_master;
+	struct sev_data_snp_init_ex data;
+	struct sev_device *sev;
+	int rc = 0;
+
+	if (!cpu_feature_enabled(X86_FEATURE_SEV_SNP))
+		return -ENODEV;
+
+	if (!psp || !psp->sev_data)
+		return -ENODEV;
+
+	sev = psp->sev_data;
+
+	if (sev->snp_initialized)
+		return 0;
+
+	if (!sev_version_greater_or_equal(SNP_MIN_API_MAJOR, SNP_MIN_API_MINOR)) {
+		dev_dbg(sev->dev, "SEV-SNP support requires firmware version >= %d:%d\n",
+			SNP_MIN_API_MAJOR, SNP_MIN_API_MINOR);
+		return 0;
+	}
+
+	/*
+	 * The SNP_INIT requires the MSR_VM_HSAVE_PA must be set to 0h
+	 * across all cores.
+	 */
+	on_each_cpu(snp_set_hsave_pa, NULL, 1);
+
+	/*
+	 * Starting in SNP firmware v1.52, the SNP_INIT_EX command takes a list of
+	 * system physical address ranges to convert into the HV-fixed page states
+	 * during the RMP initialization.  For instance, the memory that UEFI
+	 * reserves should be included in the range list. This allows system
+	 * components that occasionally write to memory (e.g. logging to UEFI
+	 * reserved regions) to not fail due to RMP initialization and SNP enablement.
+	 */
+	if (sev_version_greater_or_equal(SNP_MIN_API_MAJOR, 52)) {
+		/*
+		 * Firmware checks that the pages containing the ranges enumerated
+		 * in the RANGES structure are either in the Default page state or in the
+		 * firmware page state.
+		 */
+		snp_range_list = kzalloc(PAGE_SIZE, GFP_KERNEL);
+		if (!snp_range_list) {
+			dev_err(sev->dev,
+				"SEV: SNP_INIT_EX range list memory allocation failed\n");
+			return -ENOMEM;
+		}
+
+		/*
+		 * Retrieve all reserved memory regions setup by UEFI from the e820 memory map
+		 * to be setup as HV-fixed pages.
+		 */
+
+		rc = walk_iomem_res_desc(IORES_DESC_NONE, IORESOURCE_MEM, 0, ~0,
+					 snp_range_list, snp_filter_reserved_mem_regions);
+		if (rc) {
+			dev_err(sev->dev,
+				"SEV: SNP_INIT_EX walk_iomem_res_desc failed rc = %d\n", rc);
+			return rc;
+		}
+
+		memset(&data, 0, sizeof(data));
+		data.init_rmp = 1;
+		data.list_paddr_en = 1;
+		data.list_paddr = __psp_pa(snp_range_list);
+
+		/*
+		 * Before invoking SNP_INIT_EX with INIT_RMP=1, make sure that
+		 * all dirty cache lines containing the RMP are flushed.
+		 *
+		 * NOTE: that includes writes via RMPUPDATE instructions, which
+		 * are also cacheable writes.
+		 */
+		wbinvd_on_all_cpus();
+
+		rc = __sev_do_cmd_locked(SEV_CMD_SNP_INIT_EX, &data, error);
+		if (rc)
+			return rc;
+	} else {
+		/*
+		 * SNP_INIT is equivalent to SNP_INIT_EX with INIT_RMP=1, so
+		 * just as with that case, make sure all dirty cache lines
+		 * containing the RMP are flushed.
+		 */
+		wbinvd_on_all_cpus();
+
+		rc = __sev_do_cmd_locked(SEV_CMD_SNP_INIT, NULL, error);
+		if (rc)
+			return rc;
+	}
+
+	/* Prepare for first SNP guest launch after INIT */
+	wbinvd_on_all_cpus();
+	rc = __sev_do_cmd_locked(SEV_CMD_SNP_DF_FLUSH, NULL, error);
+	if (rc)
+		return rc;
+
+	sev->snp_initialized = true;
+	dev_dbg(sev->dev, "SEV-SNP firmware initialized\n");
+
+	return rc;
+}
+
+static int __sev_snp_shutdown_locked(int *error)
+{
+	struct sev_device *sev = psp_master->sev_data;
+	struct sev_data_snp_shutdown_ex data;
+	int ret;
+
+	if (!sev->snp_initialized)
+		return 0;
+
+	memset(&data, 0, sizeof(data));
+	data.length = sizeof(data);
+	data.iommu_snp_shutdown = 1;
+
+	wbinvd_on_all_cpus();
+
+retry:
+	ret = __sev_do_cmd_locked(SEV_CMD_SNP_SHUTDOWN_EX, &data, error);
+	/* SHUTDOWN may require DF_FLUSH */
+	if (*error == SEV_RET_DFFLUSH_REQUIRED) {
+		ret = __sev_do_cmd_locked(SEV_CMD_SNP_DF_FLUSH, NULL, NULL);
+		if (ret) {
+			dev_err(sev->dev, "SEV-SNP DF_FLUSH failed\n");
+			return ret;
+		}
+		goto retry;
+	}
+	if (ret) {
+		dev_err(sev->dev, "SEV-SNP firmware shutdown failed\n");
+		return ret;
+	}
+
+	sev->snp_initialized = false;
+	dev_dbg(sev->dev, "SEV-SNP firmware shutdown\n");
+
+	return ret;
+}
+
+static int sev_snp_shutdown(int *error)
+{
+	int rc;
+
+	mutex_lock(&sev_cmd_mutex);
+	rc = __sev_snp_shutdown_locked(error);
+	mutex_unlock(&sev_cmd_mutex);
+
+	return rc;
+}
+
 static int sev_ioctl_do_pek_import(struct sev_issue_cmd *argp, bool writable)
 {
 	struct sev_device *sev = psp_master->sev_data;
@@ -1285,6 +1527,8 @@  int sev_dev_init(struct psp_device *psp)
 
 static void sev_firmware_shutdown(struct sev_device *sev)
 {
+	int error;
+
 	sev_platform_shutdown(NULL);
 
 	if (sev_es_tmr) {
@@ -1301,6 +1545,13 @@  static void sev_firmware_shutdown(struct sev_device *sev)
 			   get_order(NV_LENGTH));
 		sev_init_ex_buffer = NULL;
 	}
+
+	if (snp_range_list) {
+		kfree(snp_range_list);
+		snp_range_list = NULL;
+	}
+
+	sev_snp_shutdown(&error);
 }
 
 void sev_dev_destroy(struct psp_device *psp)
@@ -1356,24 +1607,15 @@  void sev_pci_init(void)
 		}
 	}
 
-	/* Obtain the TMR memory area for SEV-ES use */
-	sev_es_tmr = sev_fw_alloc(SEV_ES_TMR_SIZE);
-	if (sev_es_tmr)
-		/* Must flush the cache before giving it to the firmware */
-		clflush_cache_range(sev_es_tmr, SEV_ES_TMR_SIZE);
-	else
-		dev_warn(sev->dev,
-			 "SEV: TMR allocation failed, SEV-ES support unavailable\n");
-
-	if (!psp_init_on_probe)
-		return;
-
 	/* Initialize the platform */
-	rc = sev_platform_init(&error);
+	rc = sev_platform_init_on_probe(&error);
 	if (rc)
 		dev_err(sev->dev, "SEV: failed to INIT error %#x, rc %d\n",
 			error, rc);
 
+	dev_info(sev->dev, "SEV%s API:%d.%d build:%d\n", sev->snp_initialized ?
+		"-SNP" : "", sev->api_major, sev->api_minor, sev->build);
+
 	return;
 
 err:
diff --git a/drivers/crypto/ccp/sev-dev.h b/drivers/crypto/ccp/sev-dev.h
index 778c95155e74..85506325051a 100644
--- a/drivers/crypto/ccp/sev-dev.h
+++ b/drivers/crypto/ccp/sev-dev.h
@@ -52,6 +52,8 @@  struct sev_device {
 	u8 build;
 
 	void *cmd_buf;
+
+	bool snp_initialized;
 };
 
 int sev_dev_init(struct psp_device *psp);