diff mbox series

[Part2,v6,09/49] x86/fault: Add support to handle the RMP fault for user address

Message ID 0ecb0a4781be933fcadeb56a85070818ef3566e7.1655761627.git.ashish.kalra@amd.com
State New
Headers show
Series Add AMD Secure Nested Paging (SEV-SNP) | expand

Commit Message

Ashish Kalra June 20, 2022, 11:03 p.m. UTC
From: Brijesh Singh <brijesh.singh@amd.com>

When SEV-SNP is enabled globally, a write from the host goes through the
RMP check. When the host writes to pages, hardware checks the following
conditions at the end of page walk:

1. Assigned bit in the RMP table is zero (i.e page is shared).
2. If the page table entry that gives the sPA indicates that the target
   page size is a large page, then all RMP entries for the 4KB
   constituting pages of the target must have the assigned bit 0.
3. Immutable bit in the RMP table is not zero.

The hardware will raise page fault if one of the above conditions is not
met. Try resolving the fault instead of taking fault again and again. If
the host attempts to write to the guest private memory then send the
SIGBUS signal to kill the process. If the page level between the host and
RMP entry does not match, then split the address to keep the RMP and host
page levels in sync.

Signed-off-by: Brijesh Singh <brijesh.singh@amd.com>
---
 arch/x86/mm/fault.c      | 66 ++++++++++++++++++++++++++++++++++++++++
 include/linux/mm.h       |  3 +-
 include/linux/mm_types.h |  3 ++
 mm/memory.c              | 13 ++++++++
 4 files changed, 84 insertions(+), 1 deletion(-)

Comments

Jeremi Piotrowski June 22, 2022, 2:29 p.m. UTC | #1
On Mon, Jun 20, 2022 at 11:03:43PM +0000, Ashish Kalra wrote:
> From: Brijesh Singh <brijesh.singh@amd.com>
> 
> When SEV-SNP is enabled globally, a write from the host goes through the
> RMP check. When the host writes to pages, hardware checks the following
> conditions at the end of page walk:
> 
> 1. Assigned bit in the RMP table is zero (i.e page is shared).
> 2. If the page table entry that gives the sPA indicates that the target
>    page size is a large page, then all RMP entries for the 4KB
>    constituting pages of the target must have the assigned bit 0.
> 3. Immutable bit in the RMP table is not zero.
> 
> The hardware will raise page fault if one of the above conditions is not
> met. Try resolving the fault instead of taking fault again and again. If
> the host attempts to write to the guest private memory then send the
> SIGBUS signal to kill the process. If the page level between the host and
> RMP entry does not match, then split the address to keep the RMP and host
> page levels in sync.
> 
> Signed-off-by: Brijesh Singh <brijesh.singh@amd.com>
> ---
>  arch/x86/mm/fault.c      | 66 ++++++++++++++++++++++++++++++++++++++++
>  include/linux/mm.h       |  3 +-
>  include/linux/mm_types.h |  3 ++
>  mm/memory.c              | 13 ++++++++
>  4 files changed, 84 insertions(+), 1 deletion(-)
> 
> diff --git a/arch/x86/mm/fault.c b/arch/x86/mm/fault.c
> index a4c270e99f7f..f5de9673093a 100644
> --- a/arch/x86/mm/fault.c
> +++ b/arch/x86/mm/fault.c
> @@ -19,6 +19,7 @@
>  #include <linux/uaccess.h>		/* faulthandler_disabled()	*/
>  #include <linux/efi.h>			/* efi_crash_gracefully_on_page_fault()*/
>  #include <linux/mm_types.h>
> +#include <linux/sev.h>			/* snp_lookup_rmpentry()	*/
>  
>  #include <asm/cpufeature.h>		/* boot_cpu_has, ...		*/
>  #include <asm/traps.h>			/* dotraplinkage, ...		*/
> @@ -1209,6 +1210,60 @@ do_kern_addr_fault(struct pt_regs *regs, unsigned long hw_error_code,
>  }
>  NOKPROBE_SYMBOL(do_kern_addr_fault);
>  
> +static inline size_t pages_per_hpage(int level)
> +{
> +	return page_level_size(level) / PAGE_SIZE;
> +}
> +
> +/*
> + * Return 1 if the caller need to retry, 0 if it the address need to be split
> + * in order to resolve the fault.
> + */
> +static int handle_user_rmp_page_fault(struct pt_regs *regs, unsigned long error_code,
> +				      unsigned long address)
> +{
> +	int rmp_level, level;
> +	pte_t *pte;
> +	u64 pfn;
> +
> +	pte = lookup_address_in_mm(current->mm, address, &level);
> +
> +	/*
> +	 * It can happen if there was a race between an unmap event and
> +	 * the RMP fault delivery.
> +	 */
> +	if (!pte || !pte_present(*pte))
> +		return 1;
> +
> +	pfn = pte_pfn(*pte);
> +
> +	/* If its large page then calculte the fault pfn */
> +	if (level > PG_LEVEL_4K) {
> +		unsigned long mask;
> +
> +		mask = pages_per_hpage(level) - pages_per_hpage(level - 1);
> +		pfn |= (address >> PAGE_SHIFT) & mask;
> +	}
> +
> +	/*
> +	 * If its a guest private page, then the fault cannot be resolved.
> +	 * Send a SIGBUS to terminate the process.
> +	 */
> +	if (snp_lookup_rmpentry(pfn, &rmp_level)) {

snp_lookup_rmpentry returns 0, 1 or -errno, so this should likely be:

  if (snp_lookup_rmpentry(pfn, &rmp_level) != 1)) {

> +		do_sigbus(regs, error_code, address, VM_FAULT_SIGBUS);
> +		return 1;
> +	}
> +
> +	/*
> +	 * The backing page level is higher than the RMP page level, request
> +	 * to split the page.
> +	 */
> +	if (level > rmp_level)
> +		return 0;
> +
> +	return 1;
> +}
> +
>  /*
>   * Handle faults in the user portion of the address space.  Nothing in here
>   * should check X86_PF_USER without a specific justification: for almost
> @@ -1306,6 +1361,17 @@ void do_user_addr_fault(struct pt_regs *regs,
>  	if (error_code & X86_PF_INSTR)
>  		flags |= FAULT_FLAG_INSTRUCTION;
>  
> +	/*
> +	 * If its an RMP violation, try resolving it.
> +	 */
> +	if (error_code & X86_PF_RMP) {
> +		if (handle_user_rmp_page_fault(regs, error_code, address))
> +			return;
> +
> +		/* Ask to split the page */
> +		flags |= FAULT_FLAG_PAGE_SPLIT;
> +	}
> +
>  #ifdef CONFIG_X86_64
>  	/*
>  	 * Faults in the vsyscall page might need emulation.  The
> diff --git a/include/linux/mm.h b/include/linux/mm.h
> index de32c0383387..2ccc562d166f 100644
> --- a/include/linux/mm.h
> +++ b/include/linux/mm.h
> @@ -463,7 +463,8 @@ static inline bool fault_flag_allow_retry_first(enum fault_flag flags)
>  	{ FAULT_FLAG_USER,		"USER" }, \
>  	{ FAULT_FLAG_REMOTE,		"REMOTE" }, \
>  	{ FAULT_FLAG_INSTRUCTION,	"INSTRUCTION" }, \
> -	{ FAULT_FLAG_INTERRUPTIBLE,	"INTERRUPTIBLE" }
> +	{ FAULT_FLAG_INTERRUPTIBLE,	"INTERRUPTIBLE" }, \
> +	{ FAULT_FLAG_PAGE_SPLIT,	"PAGESPLIT" }
>  
>  /*
>   * vm_fault is filled by the pagefault handler and passed to the vma's
> diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
> index 6dfaf271ebf8..aa2d8d48ce3e 100644
> --- a/include/linux/mm_types.h
> +++ b/include/linux/mm_types.h
> @@ -818,6 +818,8 @@ typedef struct {
>   *                      mapped R/O.
>   * @FAULT_FLAG_ORIG_PTE_VALID: whether the fault has vmf->orig_pte cached.
>   *                        We should only access orig_pte if this flag set.
> + * @FAULT_FLAG_PAGE_SPLIT: The fault was due page size mismatch, split the
> + *                         region to smaller page size and retry.
>   *
>   * About @FAULT_FLAG_ALLOW_RETRY and @FAULT_FLAG_TRIED: we can specify
>   * whether we would allow page faults to retry by specifying these two
> @@ -855,6 +857,7 @@ enum fault_flag {
>  	FAULT_FLAG_INTERRUPTIBLE =	1 << 9,
>  	FAULT_FLAG_UNSHARE =		1 << 10,
>  	FAULT_FLAG_ORIG_PTE_VALID =	1 << 11,
> +	FAULT_FLAG_PAGE_SPLIT =		1 << 12,
>  };
>  
>  typedef unsigned int __bitwise zap_flags_t;
> diff --git a/mm/memory.c b/mm/memory.c
> index 7274f2b52bca..c2187ffcbb8e 100644
> --- a/mm/memory.c
> +++ b/mm/memory.c
> @@ -4945,6 +4945,15 @@ static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
>  	return 0;
>  }
>  
> +static int handle_split_page_fault(struct vm_fault *vmf)
> +{
> +	if (!IS_ENABLED(CONFIG_AMD_MEM_ENCRYPT))
> +		return VM_FAULT_SIGBUS;
> +
> +	__split_huge_pmd(vmf->vma, vmf->pmd, vmf->address, false, NULL);
> +	return 0;
> +}
> +
>  /*
>   * By the time we get here, we already hold the mm semaphore
>   *
> @@ -5024,6 +5033,10 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
>  				pmd_migration_entry_wait(mm, vmf.pmd);
>  			return 0;
>  		}
> +
> +		if (flags & FAULT_FLAG_PAGE_SPLIT)
> +			return handle_split_page_fault(&vmf);
> +
>  		if (pmd_trans_huge(vmf.orig_pmd) || pmd_devmap(vmf.orig_pmd)) {
>  			if (pmd_protnone(vmf.orig_pmd) && vma_is_accessible(vma))
>  				return do_huge_pmd_numa_page(&vmf);
> -- 
> 2.25.1
>
Borislav Petkov Aug. 9, 2022, 4:55 p.m. UTC | #2
On Mon, Jun 20, 2022 at 11:03:43PM +0000, Ashish Kalra wrote:
> From: Brijesh Singh <brijesh.singh@amd.com>
> 
> When SEV-SNP is enabled globally, a write from the host goes through the

globally?

Can SNP be even enabled any other way?

I see the APM talks about it being enabled globally, I guess this means
the RMP represents *all* system memory?

> @@ -1209,6 +1210,60 @@ do_kern_addr_fault(struct pt_regs *regs, unsigned long hw_error_code,
>  }
>  NOKPROBE_SYMBOL(do_kern_addr_fault);
>  
> +static inline size_t pages_per_hpage(int level)
> +{
> +	return page_level_size(level) / PAGE_SIZE;
> +}
> +
> +/*
> + * Return 1 if the caller need to retry, 0 if it the address need to be split
> + * in order to resolve the fault.
> + */

Magic numbers.

Pls do instead:

enum rmp_pf_ret {
	RMP_PF_SPLIT	= 0,
	RMP_PF_RETRY	= 1,
};

and use those instead.

> +static int handle_user_rmp_page_fault(struct pt_regs *regs, unsigned long error_code,
> +				      unsigned long address)
> +{
> +	int rmp_level, level;
> +	pte_t *pte;
> +	u64 pfn;
> +
> +	pte = lookup_address_in_mm(current->mm, address, &level);
> +
> +	/*
> +	 * It can happen if there was a race between an unmap event and
> +	 * the RMP fault delivery.
> +	 */

You need to elaborate more here: a RMP fault can happen and then the
page can get unmapped? What is the exact scenario here?

> +	if (!pte || !pte_present(*pte))
> +		return 1;
> +
> +	pfn = pte_pfn(*pte);
> +
> +	/* If its large page then calculte the fault pfn */
> +	if (level > PG_LEVEL_4K) {
> +		unsigned long mask;
> +
> +		mask = pages_per_hpage(level) - pages_per_hpage(level - 1);
> +		pfn |= (address >> PAGE_SHIFT) & mask;

Oh boy, this is unnecessarily complicated. Isn't this

	pfn |= pud_index(address);

or
	pfn |= pmd_index(address);

depending on the level?

I think it is but it needs more explaining.

In any case, those are two static masks exactly and they don't need to
be computed for each #PF.

> diff --git a/mm/memory.c b/mm/memory.c
> index 7274f2b52bca..c2187ffcbb8e 100644
> --- a/mm/memory.c
> +++ b/mm/memory.c
> @@ -4945,6 +4945,15 @@ static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
>  	return 0;
>  }
>  
> +static int handle_split_page_fault(struct vm_fault *vmf)
> +{
> +	if (!IS_ENABLED(CONFIG_AMD_MEM_ENCRYPT))
> +		return VM_FAULT_SIGBUS;

Yah, this looks weird: generic code implies that page splitting after a
#PF makes sense only when SEV is present and none otherwise.

Why?
Ashish Kalra Aug. 10, 2022, 3:59 a.m. UTC | #3
[AMD Official Use Only - General]

Hello Boris,

>> When SEV-SNP is enabled globally, a write from the host goes through 
>> the

>globally?

>Can SNP be even enabled any other way?

>I see the APM talks about it being enabled globally, I guess this means the RMP represents *all* system memory?

Actually SNP feature can be enabled globally, but SNP is activated on a per VM basis.
Borislav Petkov Aug. 10, 2022, 9:42 a.m. UTC | #4
On Wed, Aug 10, 2022 at 03:59:34AM +0000, Kalra, Ashish wrote:
> Actually SNP feature can be enabled globally, but SNP is activated on a per VM basis.
> 
> From the APM:
> The term SNP-enabled indicates that SEV-SNP is globally enabled in the SYSCFG 
> MSR. The term SNP-active indicates that SEV-SNP is enabled for a specific VM in the 
> SEV_FEATURES field of its VMSA

Aha, and I was wondering whether "globally" meant the RMP needs to cover
all physical memory but I guess that isn't the case:

"RMP-Covered: Checks that the target page is covered by the RMP. A page
is covered by the RMP if its corresponding RMP entry is below RMP_END.
Any page not covered by the RMP is considered a Hypervisor-Owned page."

> >You need to elaborate more here: a RMP fault can happen and then the
> >page can get unmapped? What is the exact scenario here?
>
> Yes, if the page gets unmapped while the RMP fault was being handled,
> will add more explanation here.

So what's the logic here to return 1, i.e., retry?

Why should a fault for a page that gets unmapped be retried? The fault
in that case should be ignored, IMO. It'll have the same effect to
return from do_user_addr_fault() there, without splitting but you need
to have a separate return value definition so that it is clear what
needs to happen. And that return value should be != 0 so that the
current check still works.

> Actually, the above computes an index into the RMP table.

What index in the RMP table?

> It is basically an index into the 4K page within the hugepage mapped
> in the RMP table or in other words an index into the RMP table entry
> for 4K page(s) corresponding to a hugepage.

So pte_index(address) and for 1G pages, pmd_index(address).

So no reinventing the wheel if we already have helpers for that.

> It is mainly a wrapper around__split_huge_pmd() for SNP use case where
> the host hugepage is split to be in sync with the RMP table.

I see what it is. And I'm saying this looks wrong. You're enforcing page
splitting to be a valid thing to do only for SEV machines. Why?

Why is

        if (!IS_ENABLED(CONFIG_AMD_MEM_ENCRYPT))
                return VM_FAULT_SIGBUS;

there at all?

This is generic code you're touching - not arch/x86/.
Ashish Kalra Aug. 10, 2022, 10 p.m. UTC | #5
[AMD Official Use Only - General]

Hello Boris,

>> >You need to elaborate more here: a RMP fault can happen and then the 
>> >page can get unmapped? What is the exact scenario here?
>>
>> Yes, if the page gets unmapped while the RMP fault was being handled, 
>> will add more explanation here.

>So what's the logic here to return 1, i.e., retry?

>Why should a fault for a page that gets unmapped be retried? The fault in that case should be ignored, IMO. It'll have the same effect to return from do_user_addr_fault() there, without splitting but you need to have a separate return value >definition so that it is clear what needs to happen. And that return value should be != 0 so that the current check still works.

if (!pte || !pte_present(*pte))
                return 1;

This is more like a sanity check and returning 1 will cause the fault handler to return and ignore the fault for current #PF case.
If the page got unmapped, the fault will not happen again and there will be no retry, so the fault in this case is
being ignored.
The other case where 1 is returned is RMP table lookup failure, in that case the faulting process is being terminated,
that resolves the fault. 

>> Actually, the above computes an index into the RMP table.

>What index in the RMP table?

>> It is basically an index into the 4K page within the hugepage mapped 
>> in the RMP table or in other words an index into the RMP table entry 
>> for 4K page(s) corresponding to a hugepage.

>So pte_index(address) and for 1G pages, pmd_index(address).

>So no reinventing the wheel if we already have helpers for that.

Yes that makes sense and pte_index(address) is exactly what is
required for 2M hugepages.

Will use pte_index() for 2M pages and pmd_index() for 1G pages. 

>> It is mainly a wrapper around__split_huge_pmd() for SNP use case where 
>> the host hugepage is split to be in sync with the RMP table.

>I see what it is. And I'm saying this looks wrong. You're enforcing page splitting to be a valid thing to do only for SEV machines. Why?

>Why is

>        if (!IS_ENABLED(CONFIG_AMD_MEM_ENCRYPT))
>                return VM_FAULT_SIGBUS;

>there at all?

>This is generic code you're touching - not arch/x86/.

Ok, so you are suggesting that we remove this check and simply keep this function wrapping around __split_huge_pmd(). 
This becomes a generic utility function. 

Thanks,
Ashish
Borislav Petkov Aug. 11, 2022, 2:27 p.m. UTC | #6
On Wed, Aug 10, 2022 at 10:00:57PM +0000, Kalra, Ashish wrote:
> This is more like a sanity check and returning 1 will cause the fault
> handler to return and ignore the fault for current #PF case. If the
> page got unmapped, the fault will not happen again and there will be
> no retry, so the fault in this case is being ignored.

I know what will happen. I'm asking you to make this explicit in the
code because this separate define documents the situation.

One more return type != 0 won't hurt.

> Ok, so you are suggesting that we remove this check and simply keep
> this function wrapping around __split_huge_pmd(). This becomes a
> generic utility function.

Yes, it is in generic code so it better be generic function. That's why
I'm questioning the vendor-specific check there.
Vlastimil Babka Aug. 11, 2022, 3:15 p.m. UTC | #7
On 6/21/22 01:03, Ashish Kalra wrote:
> From: Brijesh Singh <brijesh.singh@amd.com>
> 
> When SEV-SNP is enabled globally, a write from the host goes through the
> RMP check. When the host writes to pages, hardware checks the following
> conditions at the end of page walk:
> 
> 1. Assigned bit in the RMP table is zero (i.e page is shared).
> 2. If the page table entry that gives the sPA indicates that the target
>    page size is a large page, then all RMP entries for the 4KB
>    constituting pages of the target must have the assigned bit 0.
> 3. Immutable bit in the RMP table is not zero.
> 
> The hardware will raise page fault if one of the above conditions is not
> met. Try resolving the fault instead of taking fault again and again. If
> the host attempts to write to the guest private memory then send the
> SIGBUS signal to kill the process. If the page level between the host and
> RMP entry does not match, then split the address to keep the RMP and host
> page levels in sync.
> 
> Signed-off-by: Brijesh Singh <brijesh.singh@amd.com>
> ---
>  arch/x86/mm/fault.c      | 66 ++++++++++++++++++++++++++++++++++++++++
>  include/linux/mm.h       |  3 +-
>  include/linux/mm_types.h |  3 ++
>  mm/memory.c              | 13 ++++++++
>  4 files changed, 84 insertions(+), 1 deletion(-)
> 
> diff --git a/arch/x86/mm/fault.c b/arch/x86/mm/fault.c
> index a4c270e99f7f..f5de9673093a 100644
> --- a/arch/x86/mm/fault.c
> +++ b/arch/x86/mm/fault.c
> @@ -19,6 +19,7 @@
>  #include <linux/uaccess.h>		/* faulthandler_disabled()	*/
>  #include <linux/efi.h>			/* efi_crash_gracefully_on_page_fault()*/
>  #include <linux/mm_types.h>
> +#include <linux/sev.h>			/* snp_lookup_rmpentry()	*/
>  
>  #include <asm/cpufeature.h>		/* boot_cpu_has, ...		*/
>  #include <asm/traps.h>			/* dotraplinkage, ...		*/
> @@ -1209,6 +1210,60 @@ do_kern_addr_fault(struct pt_regs *regs, unsigned long hw_error_code,
>  }
>  NOKPROBE_SYMBOL(do_kern_addr_fault);
>  
> +static inline size_t pages_per_hpage(int level)
> +{
> +	return page_level_size(level) / PAGE_SIZE;
> +}
> +
> +/*
> + * Return 1 if the caller need to retry, 0 if it the address need to be split
> + * in order to resolve the fault.
> + */
> +static int handle_user_rmp_page_fault(struct pt_regs *regs, unsigned long error_code,
> +				      unsigned long address)
> +{
> +	int rmp_level, level;
> +	pte_t *pte;
> +	u64 pfn;
> +
> +	pte = lookup_address_in_mm(current->mm, address, &level);
> +
> +	/*
> +	 * It can happen if there was a race between an unmap event and
> +	 * the RMP fault delivery.
> +	 */
> +	if (!pte || !pte_present(*pte))
> +		return 1;
> +
> +	pfn = pte_pfn(*pte);
> +
> +	/* If its large page then calculte the fault pfn */
> +	if (level > PG_LEVEL_4K) {
> +		unsigned long mask;
> +
> +		mask = pages_per_hpage(level) - pages_per_hpage(level - 1);
> +		pfn |= (address >> PAGE_SHIFT) & mask;
> +	}
> +
> +	/*
> +	 * If its a guest private page, then the fault cannot be resolved.
> +	 * Send a SIGBUS to terminate the process.
> +	 */
> +	if (snp_lookup_rmpentry(pfn, &rmp_level)) {
> +		do_sigbus(regs, error_code, address, VM_FAULT_SIGBUS);
> +		return 1;
> +	}
> +
> +	/*
> +	 * The backing page level is higher than the RMP page level, request
> +	 * to split the page.
> +	 */
> +	if (level > rmp_level)
> +		return 0;

I don't see any checks that make sure this is in fact a THP, and not e.g.
hugetlb (which is disallowed only later in patch 25/49), or even something
else unexpected. Calling blindly __split_huge_pmd() in
handle_split_page_fault() on anything that's not a THP will just make it
return without splitting anything, and then this will result in a page fault
loop? Some kind of warning and a SIGBUS would be more safe I think.

> +
> +	return 1;
> +}
> +
>  /*
>   * Handle faults in the user portion of the address space.  Nothing in here
>   * should check X86_PF_USER without a specific justification: for almost
> @@ -1306,6 +1361,17 @@ void do_user_addr_fault(struct pt_regs *regs,
>  	if (error_code & X86_PF_INSTR)
>  		flags |= FAULT_FLAG_INSTRUCTION;
>  
> +	/*
> +	 * If its an RMP violation, try resolving it.
> +	 */
> +	if (error_code & X86_PF_RMP) {
> +		if (handle_user_rmp_page_fault(regs, error_code, address))
> +			return;
> +
> +		/* Ask to split the page */
> +		flags |= FAULT_FLAG_PAGE_SPLIT;
> +	}
> +
>  #ifdef CONFIG_X86_64
>  	/*
>  	 * Faults in the vsyscall page might need emulation.  The
> diff --git a/include/linux/mm.h b/include/linux/mm.h
> index de32c0383387..2ccc562d166f 100644
> --- a/include/linux/mm.h
> +++ b/include/linux/mm.h
> @@ -463,7 +463,8 @@ static inline bool fault_flag_allow_retry_first(enum fault_flag flags)
>  	{ FAULT_FLAG_USER,		"USER" }, \
>  	{ FAULT_FLAG_REMOTE,		"REMOTE" }, \
>  	{ FAULT_FLAG_INSTRUCTION,	"INSTRUCTION" }, \
> -	{ FAULT_FLAG_INTERRUPTIBLE,	"INTERRUPTIBLE" }
> +	{ FAULT_FLAG_INTERRUPTIBLE,	"INTERRUPTIBLE" }, \
> +	{ FAULT_FLAG_PAGE_SPLIT,	"PAGESPLIT" }
>  
>  /*
>   * vm_fault is filled by the pagefault handler and passed to the vma's
> diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
> index 6dfaf271ebf8..aa2d8d48ce3e 100644
> --- a/include/linux/mm_types.h
> +++ b/include/linux/mm_types.h
> @@ -818,6 +818,8 @@ typedef struct {
>   *                      mapped R/O.
>   * @FAULT_FLAG_ORIG_PTE_VALID: whether the fault has vmf->orig_pte cached.
>   *                        We should only access orig_pte if this flag set.
> + * @FAULT_FLAG_PAGE_SPLIT: The fault was due page size mismatch, split the
> + *                         region to smaller page size and retry.
>   *
>   * About @FAULT_FLAG_ALLOW_RETRY and @FAULT_FLAG_TRIED: we can specify
>   * whether we would allow page faults to retry by specifying these two
> @@ -855,6 +857,7 @@ enum fault_flag {
>  	FAULT_FLAG_INTERRUPTIBLE =	1 << 9,
>  	FAULT_FLAG_UNSHARE =		1 << 10,
>  	FAULT_FLAG_ORIG_PTE_VALID =	1 << 11,
> +	FAULT_FLAG_PAGE_SPLIT =		1 << 12,
>  };
>  
>  typedef unsigned int __bitwise zap_flags_t;
> diff --git a/mm/memory.c b/mm/memory.c
> index 7274f2b52bca..c2187ffcbb8e 100644
> --- a/mm/memory.c
> +++ b/mm/memory.c
> @@ -4945,6 +4945,15 @@ static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
>  	return 0;
>  }
>  
> +static int handle_split_page_fault(struct vm_fault *vmf)
> +{
> +	if (!IS_ENABLED(CONFIG_AMD_MEM_ENCRYPT))
> +		return VM_FAULT_SIGBUS;
> +
> +	__split_huge_pmd(vmf->vma, vmf->pmd, vmf->address, false, NULL);
> +	return 0;
> +}
> +
>  /*
>   * By the time we get here, we already hold the mm semaphore
>   *
> @@ -5024,6 +5033,10 @@ static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
>  				pmd_migration_entry_wait(mm, vmf.pmd);
>  			return 0;
>  		}
> +
> +		if (flags & FAULT_FLAG_PAGE_SPLIT)
> +			return handle_split_page_fault(&vmf);
> +
>  		if (pmd_trans_huge(vmf.orig_pmd) || pmd_devmap(vmf.orig_pmd)) {
>  			if (pmd_protnone(vmf.orig_pmd) && vma_is_accessible(vma))
>  				return do_huge_pmd_numa_page(&vmf);
Ashish Kalra Sept. 1, 2022, 8:32 p.m. UTC | #8
[AMD Official Use Only - General]

Hello Boris,

>> It is basically an index into the 4K page within the hugepage mapped 
>> in the RMP table or in other words an index into the RMP table entry 
>> for 4K page(s) corresponding to a hugepage.

>So pte_index(address) and for 1G pages, pmd_index(address).

>So no reinventing the wheel if we already have helpers for that.

>Yes that makes sense and pte_index(address) is exactly what is required for 2M hugepages.

>Will use pte_index() for 2M pages and pmd_index() for 1G pages. 

Had a relook into this. 

As I mentioned earlier, this is computing an index into a 4K page within a hugepage mapping,
therefore, though pte_index() works for 2M pages, but pmd_index() will not work for 1G pages.

We basically need to do :
pfn |= (address >> PAGE_SHIFT) & mask;

where mask is the (number of 4K pages per hugepage) - 1

So this still needs the original code but with a fix for mask computation as following : 

static inline size_t pages_per_hpage(int level)
        return page_level_size(level) / PAGE_SIZE;
 }
        
static int handle_user_rmp_page_fault(struct pt_regs *regs, unsigned long error_code,
                                      unsigned long address)
 {      
       ... 
       pfn = pte_pfn(*pte);
        
        /* If its large page then calculte the fault pfn */
        if (level > PG_LEVEL_4K) {
+               /*
+                * index into the 4K page within the hugepage mapping
+                * in the RMP table
+                */
                unsigned long mask;
        
-               mask = pages_per_hpage(level) - pages_per_hpage(level - 1);
+              mask = pages_per_hpage(level) - 1;
                pfn |= (address >> PAGE_SHIFT) & mask;


Thanks,
Ashish
Borislav Petkov Sept. 2, 2022, 6:52 a.m. UTC | #9
On Thu, Sep 01, 2022 at 08:32:35PM +0000, Kalra, Ashish wrote:
> As I mentioned earlier, this is computing an index into a 4K page
> within a hugepage mapping, therefore, though pte_index() works for 2M
> pages, but pmd_index() will not work for 1G pages.

Why not? What exactly do you need to get here?

So the way I understand it is, you want to map the faulting address to a
RMP entry. And that is either the 2M PMD entry when the page is a 1G one
and the 4K PTE entry when the page is a 2M one?

Why doesn't pmd_index() work?

Also, why isn't the lookup function's signature:

int snp_lookup_rmpentry(unsigned long address, int *level)

and all that logic to do the conversion to a PFN also not in it?

Thx.
Ashish Kalra Sept. 2, 2022, 3:33 p.m. UTC | #10
[AMD Official Use Only - General]

Hello Boris,

>> As I mentioned earlier, this is computing an index into a 4K page 
>> within a hugepage mapping, therefore, though pte_index() works for 2M 
>> pages, but pmd_index() will not work for 1G pages.

>Why not? What exactly do you need to get here?

>So the way I understand it is, you want to map the faulting address to a RMP entry. And that is either the 2M PMD entry when the page is a 1G one and the 4K PTE entry when the page is a 2M one?

>Why doesn't pmd_index() work?

Yes we want to map the faulting address to a RMP entry, but hugepage entries in RMP table are basically subpage 4K entries. So it is a 4K entry when the page is a 2M one 
and also a 4K entry when the page is a 1G one.

That's why the computation to get a 4K page index within a 2M/1G hugepage mapping is required.

>Also, why isn't the lookup function's signature:

>int snp_lookup_rmpentry(unsigned long address, int *level)

>and all that logic to do the conversion to a PFN also not in it?

Thanks,
Ashish
Borislav Petkov Sept. 3, 2022, 4:25 a.m. UTC | #11
On Fri, Sep 02, 2022 at 03:33:20PM +0000, Kalra, Ashish wrote:
> Yes we want to map the faulting address to a RMP entry, but hugepage
> entries in RMP table are basically subpage 4K entries. So it is a 4K
> entry when the page is a 2M one and also a 4K entry when the page is a
> 1G one.

Wait, what?!

APM v2 section "15.36.11 Large Page Management" and PSMASH are then for
what exactly?

> That's why the computation to get a 4K page index within a 2M/1G
> hugepage mapping is required.

What if a guest RMP-faults on a 2M page and there's a corresponding 2M
RMP entry? What do you need the 4K entry then for?

Hell, __snp_lookup_rmpentry() even tries to return the proper page
level...

/me looks in disbelief in your direction...

Thx.
Ashish Kalra Sept. 3, 2022, 5:51 a.m. UTC | #12
[AMD Official Use Only - General]

Hello Boris,

>> Yes we want to map the faulting address to a RMP entry, but hugepage 
>> entries in RMP table are basically subpage 4K entries. So it is a 4K 
>> entry when the page is a 2M one and also a 4K entry when the page is a 
>> 1G one.

>Wait, what?!

>APM v2 section "15.36.11 Large Page Management" and PSMASH are then for what exactly?

This is what exactly PSMASH is for, in case the 2MB RMP entry needs to be smashed if guest PVALIDATES a 4K page,
the HV will need to PSMASH the 2MB RMP entry to corresponding 4K RMP entries during #VMEXIT(NPF).

What I meant above is that 4K RMP table entries need to be available in case the 2MB RMP entry needs to be
smashed. 

>> That's why the computation to get a 4K page index within a 2M/1G 
>> hugepage mapping is required.

>What if a guest RMP-faults on a 2M page and there's a corresponding 2M RMP entry? What do you need the 4K entry then for?

There is no fault here, if guest pvalidates a 2M page that is backed by a 2MB RMP entry.
We need the 4K entries in case the guest pvalidates a 4K page that is mapped by a 2MB RMP entry.

>Hell, __snp_lookup_rmpentry() even tries to return the proper page level...

>/me looks in disbelief in your direction...

Thanks,
Ashish
Ashish Kalra Sept. 3, 2022, 6:57 a.m. UTC | #13
[AMD Official Use Only - General]

So essentially we want to map the faulting address to a RMP entry, considering the fact that a 2M host hugepage can be mapped as 
4K RMP table entries and 1G host hugepage can be mapped as 2M RMP table entries.

Hence, this mask computation is done as:
mask = pages_per_hpage(level) - pages_per_hpage(level -1);

and the final faulting pfn is computed as:
pfn |= (address >> PAGE_SHIFT) & mask;
      
Thanks,
Ashish    

-----Original Message-----
From: Kalra, Ashish 
Sent: Saturday, September 3, 2022 12:51 AM
To: Borislav Petkov <bp@alien8.de>
Cc: x86@kernel.org; linux-kernel@vger.kernel.org; kvm@vger.kernel.org; linux-coco@lists.linux.dev; linux-mm@kvack.org; linux-crypto@vger.kernel.org; tglx@linutronix.de; mingo@redhat.com; jroedel@suse.de; Lendacky, Thomas <Thomas.Lendacky@amd.com>; hpa@zytor.com; ardb@kernel.org; pbonzini@redhat.com; seanjc@google.com; vkuznets@redhat.com; jmattson@google.com; luto@kernel.org; dave.hansen@linux.intel.com; slp@redhat.com; pgonda@google.com; peterz@infradead.org; srinivas.pandruvada@linux.intel.com; rientjes@google.com; dovmurik@linux.ibm.com; tobin@ibm.com; Roth, Michael <Michael.Roth@amd.com>; vbabka@suse.cz; kirill@shutemov.name; ak@linux.intel.com; tony.luck@intel.com; marcorr@google.com; sathyanarayanan.kuppuswamy@linux.intel.com; alpergun@google.com; dgilbert@redhat.com; jarkko@kernel.org
Subject: RE: [PATCH Part2 v6 09/49] x86/fault: Add support to handle the RMP fault for user address

[AMD Official Use Only - General]

Hello Boris,

>> Yes we want to map the faulting address to a RMP entry, but hugepage 
>> entries in RMP table are basically subpage 4K entries. So it is a 4K 
>> entry when the page is a 2M one and also a 4K entry when the page is 
>> a 1G one.

>Wait, what?!

>APM v2 section "15.36.11 Large Page Management" and PSMASH are then for what exactly?

This is what exactly PSMASH is for, in case the 2MB RMP entry needs to be smashed if guest PVALIDATES a 4K page, the HV will need to PSMASH the 2MB RMP entry to corresponding 4K RMP entries during #VMEXIT(NPF).

What I meant above is that 4K RMP table entries need to be available in case the 2MB RMP entry needs to be smashed. 

>> That's why the computation to get a 4K page index within a 2M/1G 
>> hugepage mapping is required.

>What if a guest RMP-faults on a 2M page and there's a corresponding 2M RMP entry? What do you need the 4K entry then for?

There is no fault here, if guest pvalidates a 2M page that is backed by a 2MB RMP entry.
We need the 4K entries in case the guest pvalidates a 4K page that is mapped by a 2MB RMP entry.

>Hell, __snp_lookup_rmpentry() even tries to return the proper page level...

>/me looks in disbelief in your direction...

Thanks,
Ashish
Borislav Petkov Sept. 3, 2022, 8:31 a.m. UTC | #14
On September 3, 2022 6:57:51 AM UTC, "Kalra, Ashish" <Ashish.Kalra@amd.com> wrote:
>[AMD Official Use Only - General]
>
>So essentially we want to map the faulting address to a RMP entry, considering the fact that a 2M host hugepage can be mapped as 
>4K RMP table entries and 1G host hugepage can be mapped as 2M RMP table entries.

So something's seriously confusing or missing here because if you fault on a 2M host page and the underlying RMP entries are 4K then you can use pte_index().

If the host page is 1G and the underlying RMP entries are 2M, pmd_index() should work here too.

But this piecemeal back'n'forth doesn't seem to resolve this so I'd like to ask you pls to sit down, take your time and give a detailed example of the two possible cases and what the difference is between pte_/pmd_index and your way. Feel free to add actual debug output and paste it here.

Thanks.
Ashish Kalra Sept. 3, 2022, 5:30 p.m. UTC | #15
[AMD Official Use Only - General]

Hello Boris,

>>So essentially we want to map the faulting address to a RMP entry, 
>>considering the fact that a 2M host hugepage can be mapped as 4K RMP table entries and 1G host hugepage can be mapped as 2M RMP table entries.

>So something's seriously confusing or missing here because if you fault on a 2M host page and the underlying RMP entries are 4K then you can use pte_index().

>If the host page is 1G and the underlying RMP entries are 2M, pmd_index() should work here too.

>But this piecemeal back'n'forth doesn't seem to resolve this so I'd like to ask you pls to sit down, take your time and give a detailed example of the two possible cases and what the difference is between pte_/pmd_index and your way. Feel free to >add actual debug output and paste it here.

There is 1 64-bit RMP entry for every physical 4k page of DRAM, so essentially every 4K page of DRAM is represented by a RMP entry.

So even if host page is 1G and underlying (smashed/split) RMP entries are 2M, the RMP table entry has to be indexed to a 4K entry
corresponding to that.

If it was simply a 2M entry in the RMP table, then pmd_index() will work correctly.

Considering the following example: 

address = 0x40200000; 
level = PG_LEVEL_1G;
pfn  = 0x40000;
pfn |= pmd_index(address);
This will give the RMP table index as 0x40001.
And it will work if the RMP table entry was simply a 2MB entry, but we need to map this further to its corresponding 4K entry.

With the same example as above: 
level = PG_LEVEL_1G;
mask = pages_per_hpage(level) - pages_per_hpage(level - 1);
pfn |= (address >> PAGE_SHIFT) & mask;
This will give the RMP table index as 0x40200.       
Which is the correct RMP table entry for a 2MB smashed/split 1G page mapped further to its corresponding 4K entry.

Hopefully this clarifies why pmd_index() can't be used here.

Thanks,
Ashish
Borislav Petkov Sept. 4, 2022, 6:37 a.m. UTC | #16
On Sat, Sep 03, 2022 at 05:30:28PM +0000, Kalra, Ashish wrote:
> There is 1 64-bit RMP entry for every physical 4k page of DRAM, so
> essentially every 4K page of DRAM is represented by a RMP entry.

Before we get to the rest - this sounds wrong to me. My APM has:

"PSMASH	Page Smash

Expands a 2MB-page RMP entry into a corresponding set of contiguous
4KB-page RMP entries. The 2MB page’s system physical address is
specified in the RAX register. The new entries inherit the attributes
of the original entry. Upon completion, a return code is stored in EAX.
rFLAGS bits OF, ZF, AF, PF and SF are set based on this return code..."

So there *are* 2M entries in the RMP table.

> So even if host page is 1G and underlying (smashed/split) RMP
> entries are 2M, the RMP table entry has to be indexed to a 4K entry
> corresponding to that.

So if there are 2M entries in the RMP table, how does that indexing with
4K entries is supposed to work?

Hell, even PSMASH pseudocode shows how you go and write all those 512 4K
entries using the 2M entry as a template. So *before* you have smashed
that 2M entry, it *is* an *actual* 2M entry.

So if you fault on a page which is backed by that 2M RMP entry, you will
get that 2M RMP entry.

> If it was simply a 2M entry in the RMP table, then pmd_index() will
> work correctly.

Judging by the above text, it *can* *be* a 2M RMP entry!

By reading your example you're trying to tell me that a RMP #PF will
always need to work on 4K entries. Which would then need for a 2M entry
as above to be PSMASHed in order to get the 4K thing. But that would be
silly - RMP PFs will this way gradually break all 2M pages and degrage
performance for no real reason.

So this still looks real wrong to me.

Thx.
Dave Hansen Sept. 6, 2022, 2:30 a.m. UTC | #17
On 6/20/22 16:03, Ashish Kalra wrote:
> 
> When SEV-SNP is enabled globally, a write from the host goes through the
> RMP check. When the host writes to pages, hardware checks the following
> conditions at the end of page walk:
> 
> 1. Assigned bit in the RMP table is zero (i.e page is shared).
> 2. If the page table entry that gives the sPA indicates that the target
>    page size is a large page, then all RMP entries for the 4KB
>    constituting pages of the target must have the assigned bit 0.
> 3. Immutable bit in the RMP table is not zero.
> 
> The hardware will raise page fault if one of the above conditions is not
> met. Try resolving the fault instead of taking fault again and again. If
> the host attempts to write to the guest private memory then send the
> SIGBUS signal to kill the process. If the page level between the host and
> RMP entry does not match, then split the address to keep the RMP and host
> page levels in sync.

When you're working on this changelog for Borislav, I'd like to make one
other suggestion:  Please write it more logically and _less_ about what
the hardware is doing.  We don't need about the internal details of what
hardware is doing in the changelog.  Mentioning whether an RMP bit is 0
or 1 is kinda silly unless it matters to the code.

For instance, what does the immutable bit have to do with all of this?
There's no specific handling for it.  There are really only faults that
you can handle and faults that you can't.

There's also some major missing context here about how it guarantees
that pages that can't be handled *CAN* be split.  I think it has to do
with disallowing hugetlbfs which implies that the only pages that might
need splitting are THP's.+	/*
> +	 * If its an RMP violation, try resolving it.
> +	 */
> +	if (error_code & X86_PF_RMP) {
> +		if (handle_user_rmp_page_fault(regs, error_code, address))
> +			return;
> +
> +		/* Ask to split the page */
> +		flags |= FAULT_FLAG_PAGE_SPLIT;
> +	}

This also needs some chatter about why any failure to handle the fault
automatically means splitting a page.
Jarkko Sakkinen Sept. 6, 2022, 10:25 a.m. UTC | #18
On Tue, Aug 09, 2022 at 06:55:43PM +0200, Borislav Petkov wrote:
> On Mon, Jun 20, 2022 at 11:03:43PM +0000, Ashish Kalra wrote:
> > +	pfn = pte_pfn(*pte);
> > +
> > +	/* If its large page then calculte the fault pfn */
> > +	if (level > PG_LEVEL_4K) {
> > +		unsigned long mask;
> > +
> > +		mask = pages_per_hpage(level) - pages_per_hpage(level - 1);
> > +		pfn |= (address >> PAGE_SHIFT) & mask;
> 
> Oh boy, this is unnecessarily complicated. Isn't this
> 
> 	pfn |= pud_index(address);
> 
> or
> 	pfn |= pmd_index(address);

I played with this a bit and ended up with

        pfn = pte_pfn(*pte) | PFN_DOWN(address & page_level_mask(level - 1));

Unless I got something terribly wrong, this should do the
same (see the attached patch) as the existing calculations.

BR, Jarkko
Jarkko Sakkinen Sept. 6, 2022, 10:33 a.m. UTC | #19
On Tue, Sep 06, 2022 at 01:25:10PM +0300, Jarkko Sakkinen wrote:
> On Tue, Aug 09, 2022 at 06:55:43PM +0200, Borislav Petkov wrote:
> > On Mon, Jun 20, 2022 at 11:03:43PM +0000, Ashish Kalra wrote:
> > > +	pfn = pte_pfn(*pte);
> > > +
> > > +	/* If its large page then calculte the fault pfn */
> > > +	if (level > PG_LEVEL_4K) {
> > > +		unsigned long mask;
> > > +
> > > +		mask = pages_per_hpage(level) - pages_per_hpage(level - 1);
> > > +		pfn |= (address >> PAGE_SHIFT) & mask;
> > 
> > Oh boy, this is unnecessarily complicated. Isn't this
> > 
> > 	pfn |= pud_index(address);
> > 
> > or
> > 	pfn |= pmd_index(address);
> 
> I played with this a bit and ended up with
> 
>         pfn = pte_pfn(*pte) | PFN_DOWN(address & page_level_mask(level - 1));
> 
> Unless I got something terribly wrong, this should do the
> same (see the attached patch) as the existing calculations.

IMHO a better name for this function would be do_user_rmp_addr_fault() as
it is more consistent with the existing function names.

BR, Jarkko
Marc Orr Sept. 6, 2022, 1:54 p.m. UTC | #20
On Tue, Sep 6, 2022 at 3:25 AM Jarkko Sakkinen <jarkko@kernel.org> wrote:
>
> On Tue, Aug 09, 2022 at 06:55:43PM +0200, Borislav Petkov wrote:
> > On Mon, Jun 20, 2022 at 11:03:43PM +0000, Ashish Kalra wrote:
> > > +   pfn = pte_pfn(*pte);
> > > +
> > > +   /* If its large page then calculte the fault pfn */
> > > +   if (level > PG_LEVEL_4K) {
> > > +           unsigned long mask;
> > > +
> > > +           mask = pages_per_hpage(level) - pages_per_hpage(level - 1);
> > > +           pfn |= (address >> PAGE_SHIFT) & mask;
> >
> > Oh boy, this is unnecessarily complicated. Isn't this
> >
> >       pfn |= pud_index(address);
> >
> > or
> >       pfn |= pmd_index(address);
>
> I played with this a bit and ended up with
>
>         pfn = pte_pfn(*pte) | PFN_DOWN(address & page_level_mask(level - 1));
>
> Unless I got something terribly wrong, this should do the
> same (see the attached patch) as the existing calculations.

Actually, I don't think they're the same. I think Jarkko's version is
correct. Specifically:
- For level = PG_LEVEL_2M they're the same.
- For level = PG_LEVEL_1G:
The current code calculates a garbage mask:
mask = pages_per_hpage(level) - pages_per_hpage(level - 1);
translates to:
>>> hex(262144 - 512)
'0x3fe00'

But I believe Jarkko's version calculates the correct mask (below),
incorporating all 18 offset bits into the 1G page.
>>> hex(262144 -1)
'0x3ffff'
Ashish Kalra Sept. 6, 2022, 2:06 p.m. UTC | #21
[AMD Official Use Only - General]

Hello Boris,

>> There is 1 64-bit RMP entry for every physical 4k page of DRAM, so 
>> essentially every 4K page of DRAM is represented by a RMP entry.

>Before we get to the rest - this sounds wrong to me. My APM has:

>"PSMASH	Page Smash

>Expands a 2MB-page RMP entry into a corresponding set of contiguous 4KB-page RMP entries. The 2MB page's system physical address is specified in the RAX register. The new entries inherit the attributes of the original entry. Upon completion, a >return code is stored in EAX.
>rFLAGS bits OF, ZF, AF, PF and SF are set based on this return code..."

>So there *are* 2M entries in the RMP table.

> So even if host page is 1G and underlying (smashed/split) RMP entries 
> are 2M, the RMP table entry has to be indexed to a 4K entry 
> corresponding to that.

>So if there are 2M entries in the RMP table, how does that indexing with 4K entries is supposed to work?

>Hell, even PSMASH pseudocode shows how you go and write all those 512 4K entries using the 2M entry as a template. So *before* you have smashed that 2M entry, it *is* an *actual* 2M entry.

>So if you fault on a page which is backed by that 2M RMP entry, you will get that 2M RMP entry.

> If it was simply a 2M entry in the RMP table, then pmd_index() will 
> work correctly.

>Judging by the above text, it *can* *be* a 2M RMP entry!

>By reading your example you're trying to tell me that a RMP #PF will always need to work on 4K entries. Which would then need for a 2M entry as above to be PSMASHed in order to get the 4K thing. But that would be silly - RMP PFs will this way >gradually break all 2M pages and degrage performance for no real reason.

>So this still looks real wrong to me.

Please note that RMP table entries have only 2 page size indicators 4k and 2M, so it covers a max physical address range of 2MB.
In all cases, there is one RMP entry per 4K page and the index into the RMP table is basically address /PAGE_SIZE, and that does
not change for hugepages. Therefore we need to capture the address bits (from address) so that we index into the 
4K entry in the RMP table. 

An important point to note here is that RMPUPDATE instruction sets the Assigned bit for all the sub-page entries for
a hugepage mapping in RMP table, so we will get the correct "assigned" page information when we index into the 4K entry
in the RMP table and additionally,  __snp_lookup_rmpentry() gets the 2MB aligned entry in the RMP table to get the correct Page size.
(as below)

static struct rmpentry *__snp_lookup_rmpentry(u64 pfn, int *level)
{
         ..
        /* Read a large RMP entry to get the correct page level used in RMP entry. */
        large_entry = rmptable_entry(paddr & PMD_MASK);
        *level = RMP_TO_X86_PG_LEVEL(rmpentry_pagesize(large_entry));
        ..

Therefore, the 2M entry and it's subpages in the RMP table will always exist because of the RMPUPDATE instruction even
without smashing/splitting of the hugepage, so we really don't need the 2MB entry to be PSMASHed in order to get the 4K thing. 

Thanks,
Ashish
Ashish Kalra Sept. 6, 2022, 2:17 p.m. UTC | #22
[AMD Official Use Only - General]

>> On Tue, Aug 09, 2022 at 06:55:43PM +0200, Borislav Petkov wrote:
>> > On Mon, Jun 20, 2022 at 11:03:43PM +0000, Ashish Kalra wrote:
>> > > +   pfn = pte_pfn(*pte);
>> > > +
>> > > +   /* If its large page then calculte the fault pfn */
>> > > +   if (level > PG_LEVEL_4K) {
>> > > +           unsigned long mask;
>> > > +
>> > > +           mask = pages_per_hpage(level) - pages_per_hpage(level - 1);
>> > > +           pfn |= (address >> PAGE_SHIFT) & mask;
>> >
>> > Oh boy, this is unnecessarily complicated. Isn't this
>> >
>> >       pfn |= pud_index(address);
>> >
>> > or
>> >       pfn |= pmd_index(address);
>>
>> I played with this a bit and ended up with
>>
>>         pfn = pte_pfn(*pte) | PFN_DOWN(address & page_level_mask(level 
>> - 1));
>>
>> Unless I got something terribly wrong, this should do the same (see 
>> the attached patch) as the existing calculations.

>Actually, I don't think they're the same. I think Jarkko's version is correct. Specifically:
>- For level = PG_LEVEL_2M they're the same.
>- For level = PG_LEVEL_1G:
>The current code calculates a garbage mask:
>mask = pages_per_hpage(level) - pages_per_hpage(level - 1); translates to:
>>> hex(262144 - 512)
>'0x3fe00'

No actually this is not a garbage mask, as I explained in earlier responses we need to capture the address bits 
to get to the correct 4K index into the RMP table.
Therefore, for level = PG_LEVEL_1G:
mask = pages_per_hpage(level) - pages_per_hpage(level - 1) => 0x3fe00 (which is the correct mask).

>But I believe Jarkko's version calculates the correct mask (below), incorporating all 18 offset bits into the 1G page.
>>> hex(262144 -1)
>'0x3ffff'

We can get this simply by doing (page_per_hpage(level)-1), but as I mentioned above this is not what we need.

Thanks,
Ashish
Michael Roth Sept. 6, 2022, 3:06 p.m. UTC | #23
On Tue, Sep 06, 2022 at 09:17:15AM -0500, Kalra, Ashish wrote:
> [AMD Official Use Only - General]
> 
> >> On Tue, Aug 09, 2022 at 06:55:43PM +0200, Borislav Petkov wrote:
> >> > On Mon, Jun 20, 2022 at 11:03:43PM +0000, Ashish Kalra wrote:
> >> > > +   pfn = pte_pfn(*pte);
> >> > > +
> >> > > +   /* If its large page then calculte the fault pfn */
> >> > > +   if (level > PG_LEVEL_4K) {
> >> > > +           unsigned long mask;
> >> > > +
> >> > > +           mask = pages_per_hpage(level) - pages_per_hpage(level - 1);
> >> > > +           pfn |= (address >> PAGE_SHIFT) & mask;
> >> >
> >> > Oh boy, this is unnecessarily complicated. Isn't this
> >> >
> >> >       pfn |= pud_index(address);
> >> >
> >> > or
> >> >       pfn |= pmd_index(address);
> >>
> >> I played with this a bit and ended up with
> >>
> >>         pfn = pte_pfn(*pte) | PFN_DOWN(address & page_level_mask(level 
> >> - 1));
> >>
> >> Unless I got something terribly wrong, this should do the same (see 
> >> the attached patch) as the existing calculations.
> 
> >Actually, I don't think they're the same. I think Jarkko's version is correct. Specifically:
> >- For level = PG_LEVEL_2M they're the same.
> >- For level = PG_LEVEL_1G:
> >The current code calculates a garbage mask:
> >mask = pages_per_hpage(level) - pages_per_hpage(level - 1); translates to:
> >>> hex(262144 - 512)
> >'0x3fe00'
> 
> No actually this is not a garbage mask, as I explained in earlier responses we need to capture the address bits 
> to get to the correct 4K index into the RMP table.
> Therefore, for level = PG_LEVEL_1G:
> mask = pages_per_hpage(level) - pages_per_hpage(level - 1) => 0x3fe00 (which is the correct mask).

That's the correct mask to grab the 2M-aligned address bits, e.g:

  pfn_mask = 3fe00h = 11 1111 1110 0000 0000b
  
  So the last 9 bits are ignored, e.g. anything PFNs that are multiples
  of 512 (2^9), and the upper bits comes from the 1GB PTE entry.

But there is an open question of whether we actually want to index using
2M-aligned or specific 4K-aligned PFN indicated by the faulting address.

> 
> >But I believe Jarkko's version calculates the correct mask (below), incorporating all 18 offset bits into the 1G page.
> >>> hex(262144 -1)
> >'0x3ffff'
> 
> We can get this simply by doing (page_per_hpage(level)-1), but as I mentioned above this is not what we need.

If we actually want the 4K page, I think we would want to use the 0x3ffff
mask as Marc suggested to get to the specific 4K RMP entry, which I don't
think the current code is trying to do. But maybe that *should* be what we
should be doing.

Based on your earlier explanation, if we index into the RMP table using
2M-aligned address, we might find that the entry does not have the
page-size bit set (maybe it was PSMASH'd for some reason). If that's the
cause we'd then have to calculate the index for the specific RMP entry
for the specific 4K address that caused the fault, and then check that
instead.

If however we simply index directly in the 4K RMP entry from the start,
snp_lookup_rmpentry() should still tell us whether the page is private
or not, because RMPUPDATE/PSMASH are both documented to also update the
assigned bits for each 4K RMP entry even if you're using a 2M RMP entry
and setting the page-size bit to cover the whole 2M range.

Additionally, snp_lookup_rmpentry() already has logic to also go check
the 2M-aligned RMP entry to provide an indication of what level it is
mapped at in the RMP table, so we can still use that to determine if the
host mapping needs to be split or not.

One thing that could use some confirmation is what happens if you do an
RMPUPDATE for a 2MB RMP entry, and then go back and try to RMPUPDATE a
sub-page and change the assigned bit so it's not consistent with 2MB RMP
entry. I would assume that would fail the instruction, but we should
confirm that before relying on this logic.

-Mike

> 
> Thanks,
> Ashish
Jarkko Sakkinen Sept. 6, 2022, 3:44 p.m. UTC | #24
On Tue, Sep 06, 2022 at 02:17:15PM +0000, Kalra, Ashish wrote:
> [AMD Official Use Only - General]
> 
> >> On Tue, Aug 09, 2022 at 06:55:43PM +0200, Borislav Petkov wrote:
> >> > On Mon, Jun 20, 2022 at 11:03:43PM +0000, Ashish Kalra wrote:
> >> > > +   pfn = pte_pfn(*pte);
> >> > > +
> >> > > +   /* If its large page then calculte the fault pfn */
> >> > > +   if (level > PG_LEVEL_4K) {
> >> > > +           unsigned long mask;
> >> > > +
> >> > > +           mask = pages_per_hpage(level) - pages_per_hpage(level - 1);
> >> > > +           pfn |= (address >> PAGE_SHIFT) & mask;
> >> >
> >> > Oh boy, this is unnecessarily complicated. Isn't this
> >> >
> >> >       pfn |= pud_index(address);
> >> >
> >> > or
> >> >       pfn |= pmd_index(address);
> >>
> >> I played with this a bit and ended up with
> >>
> >>         pfn = pte_pfn(*pte) | PFN_DOWN(address & page_level_mask(level 
> >> - 1));
> >>
> >> Unless I got something terribly wrong, this should do the same (see 
> >> the attached patch) as the existing calculations.
> 
> >Actually, I don't think they're the same. I think Jarkko's version is correct. Specifically:
> >- For level = PG_LEVEL_2M they're the same.
> >- For level = PG_LEVEL_1G:
> >The current code calculates a garbage mask:
> >mask = pages_per_hpage(level) - pages_per_hpage(level - 1); translates to:
> >>> hex(262144 - 512)
> >'0x3fe00'
> 
> No actually this is not a garbage mask, as I explained in earlier responses we need to capture the address bits 
> to get to the correct 4K index into the RMP table.
> Therefore, for level = PG_LEVEL_1G:
> mask = pages_per_hpage(level) - pages_per_hpage(level - 1) => 0x3fe00 (which is the correct mask).
> 
> >But I believe Jarkko's version calculates the correct mask (below), incorporating all 18 offset bits into the 1G page.
> >>> hex(262144 -1)
> >'0x3ffff'
> 
> We can get this simply by doing (page_per_hpage(level)-1), but as I mentioned above this is not what we need.

I think you're correct, so I'll retry:

(address / PAGE_SIZE) & (pages_per_hpage(level) - pages_per_hpage(level - 1)) =

(address / PAGE_SIZE) & ((page_level_size(level) / PAGE_SIZE) - (page_level_size(level - 1) / PAGE_SIZE)) =

[ factor out 1 / PAGE_SIZE ]

(address & (page_level_size(level) - page_level_size(level - 1))) / PAGE_SIZE  =

[ Substitute with PFN_DOWN() ] 

PFN_DOWN(address & (page_level_size(level) - page_level_size(level - 1)))

So you can just:

pfn = pte_pfn(*pte) | PFN_DOWN(address & (page_level_size(level) - page_level_size(level - 1)));

Which is IMHO way better still what it is now because no branching
and no ad-hoc helpers (the current is essentially just page_level_size
wrapper).

BR, Jarkko
Ashish Kalra Sept. 6, 2022, 4:39 p.m. UTC | #25
[AMD Official Use Only - General]

>> >Actually, I don't think they're the same. I think Jarkko's version is correct. Specifically:
>> >- For level = PG_LEVEL_2M they're the same.
>> >- For level = PG_LEVEL_1G:
>> >The current code calculates a garbage mask:
>> >mask = pages_per_hpage(level) - pages_per_hpage(level - 1); translates to:
>> >>> hex(262144 - 512)
>> >'0x3fe00'
>> 
>> No actually this is not a garbage mask, as I explained in earlier 
>> responses we need to capture the address bits to get to the correct 4K index into the RMP table.
>> Therefore, for level = PG_LEVEL_1G:
>> mask = pages_per_hpage(level) - pages_per_hpage(level - 1) => 0x3fe00 (which is the correct mask).

>That's the correct mask to grab the 2M-aligned address bits, e.g:

>  pfn_mask = 3fe00h = 11 1111 1110 0000 0000b
  
>  So the last 9 bits are ignored, e.g. anything PFNs that are multiples
>  of 512 (2^9), and the upper bits comes from the 1GB PTE entry.

> But there is an open question of whether we actually want to index using 2M-aligned or specific 4K-aligned PFN indicated by the faulting address.

>> 
>> >But I believe Jarkko's version calculates the correct mask (below), incorporating all 18 offset bits into the 1G page.
>> >>> hex(262144 -1)
>> >'0x3ffff'
>> 
>> We can get this simply by doing (page_per_hpage(level)-1), but as I mentioned above this is not what we need.

>If we actually want the 4K page, I think we would want to use the 0x3ffff mask as Marc suggested to get to the specific 4K RMP entry, which I don't think the current code is trying to do. But maybe that *should* be what we should be doing.

Ok, I agree to get to the specific 4K RMP entry.

>Based on your earlier explanation, if we index into the RMP table using 2M-aligned address, we might find that the entry does not have the page-size bit set (maybe it was PSMASH'd for some reason). 

I believe that PSMASH does update the 2M-aligned RMP table entry to the smashed page size.
It sets all the 4K intermediate/smashed pages size to 4K and changes the page size of the base RMP table (2M-aligned) entry  to 4K.

>If that's the cause we'd then have to calculate the index for the specific RMP entry for the specific 4K address that caused the fault, and then check that instead.

>If however we simply index directly in the 4K RMP entry from the start,
>snp_lookup_rmpentry() should still tell us whether the page is private or not, because RMPUPDATE/PSMASH are both documented to also update the assigned bits for each 4K RMP entry even if you're using a 2M RMP entry and setting the page-size >bit to cover the whole 2M range.

I think it does make sense to index directly into the 4K RMP entry, as we should be indexing into the most granular entry in the RMP table, and that will have the page "assigned" information as both RMPUPDATE/PSMASH would update
the assigned bits for each 4K RMP entry even if we using a 2MB RMP entry (this is an important point to note).

>Additionally, snp_lookup_rmpentry() already has logic to also go check the 2M-aligned RMP entry to provide an indication of what level it is mapped at in the RMP table, so we can still use that to determine if the host mapping needs to be split or >not.

Yes.

>One thing that could use some confirmation is what happens if you do an RMPUPDATE for a 2MB RMP entry, and then go back and try to RMPUPDATE a sub-page and change the assigned bit so it's not consistent with 2MB RMP entry. I would assume >that would fail the instruction, but we should confirm that before relying on this logic.

I agree.

Thanks,
Ashish
Marc Orr Sept. 7, 2022, 5:14 a.m. UTC | #26
> >> >But I believe Jarkko's version calculates the correct mask (below), incorporating all 18 offset bits into the 1G page.
> >> >>> hex(262144 -1)
> >> >'0x3ffff'
> >>
> >> We can get this simply by doing (page_per_hpage(level)-1), but as I mentioned above this is not what we need.
>
> >If we actually want the 4K page, I think we would want to use the 0x3ffff mask as Marc suggested to get to the specific 4K RMP entry, which I don't think the current code is trying to do. But maybe that *should* be what we should be doing.
>
> Ok, I agree to get to the specific 4K RMP entry.

Thanks, Michael, for a thorough and complete reply! I have to admit,
there was some nuance I missed in my earlier reply. But after reading
through what you wrote, I agree, always going to the 4k-entry to get
the "assigned" bit and also leveraging the implementation of
snp_lookup_rmpentry() to lookup the size bit in the 2M-aligned entry
seems like an elegant way to code this up. Assuming this suggestion
becomes the consensus, we might consider a comment in the source code
to capture this discussion. Otherwise, I think I'll forget all of this
the next time I'm reading this code :-). Something like:

/*
 * The guest-assigned bit is always propagated to the paddr's respective 4k RMP
 * entry -- regardless of the actual RMP page size. In contrast, the RMP page
 * size, handled in snp_lookup_rmpentry(), is indicated by the 2M-aligned RMP
 * entry.
 */
Jarkko Sakkinen Sept. 8, 2022, 7:46 a.m. UTC | #27
On Tue, Sep 06, 2022 at 06:44:23PM +0300, Jarkko Sakkinen wrote:
> On Tue, Sep 06, 2022 at 02:17:15PM +0000, Kalra, Ashish wrote:
> > [AMD Official Use Only - General]
> > 
> > >> On Tue, Aug 09, 2022 at 06:55:43PM +0200, Borislav Petkov wrote:
> > >> > On Mon, Jun 20, 2022 at 11:03:43PM +0000, Ashish Kalra wrote:
> > >> > > +   pfn = pte_pfn(*pte);
> > >> > > +
> > >> > > +   /* If its large page then calculte the fault pfn */
> > >> > > +   if (level > PG_LEVEL_4K) {
> > >> > > +           unsigned long mask;
> > >> > > +
> > >> > > +           mask = pages_per_hpage(level) - pages_per_hpage(level - 1);
> > >> > > +           pfn |= (address >> PAGE_SHIFT) & mask;
> > >> >
> > >> > Oh boy, this is unnecessarily complicated. Isn't this
> > >> >
> > >> >       pfn |= pud_index(address);
> > >> >
> > >> > or
> > >> >       pfn |= pmd_index(address);
> > >>
> > >> I played with this a bit and ended up with
> > >>
> > >>         pfn = pte_pfn(*pte) | PFN_DOWN(address & page_level_mask(level 
> > >> - 1));
> > >>
> > >> Unless I got something terribly wrong, this should do the same (see 
> > >> the attached patch) as the existing calculations.
> > 
> > >Actually, I don't think they're the same. I think Jarkko's version is correct. Specifically:
> > >- For level = PG_LEVEL_2M they're the same.
> > >- For level = PG_LEVEL_1G:
> > >The current code calculates a garbage mask:
> > >mask = pages_per_hpage(level) - pages_per_hpage(level - 1); translates to:
> > >>> hex(262144 - 512)
> > >'0x3fe00'
> > 
> > No actually this is not a garbage mask, as I explained in earlier responses we need to capture the address bits 
> > to get to the correct 4K index into the RMP table.
> > Therefore, for level = PG_LEVEL_1G:
> > mask = pages_per_hpage(level) - pages_per_hpage(level - 1) => 0x3fe00 (which is the correct mask).
> > 
> > >But I believe Jarkko's version calculates the correct mask (below), incorporating all 18 offset bits into the 1G page.
> > >>> hex(262144 -1)
> > >'0x3ffff'
> > 
> > We can get this simply by doing (page_per_hpage(level)-1), but as I mentioned above this is not what we need.
> 
> I think you're correct, so I'll retry:
> 
> (address / PAGE_SIZE) & (pages_per_hpage(level) - pages_per_hpage(level - 1)) =
> 
> (address / PAGE_SIZE) & ((page_level_size(level) / PAGE_SIZE) - (page_level_size(level - 1) / PAGE_SIZE)) =
> 
> [ factor out 1 / PAGE_SIZE ]
> 
> (address & (page_level_size(level) - page_level_size(level - 1))) / PAGE_SIZE  =
> 
> [ Substitute with PFN_DOWN() ] 
> 
> PFN_DOWN(address & (page_level_size(level) - page_level_size(level - 1)))
> 
> So you can just:
> 
> pfn = pte_pfn(*pte) | PFN_DOWN(address & (page_level_size(level) - page_level_size(level - 1)));
> 
> Which is IMHO way better still what it is now because no branching
> and no ad-hoc helpers (the current is essentially just page_level_size
> wrapper).

I created a small test program:

$ cat test.c
#include <stdio.h>
int main(void)
{
        unsigned long arr[] = {0x8, 0x1000, 0x200000, 0x40000000, 0x8000000000};
        int i;

        for (i = 1; i < sizeof(arr)/sizeof(unsigned long); i++) {
                printf("%048b\n", arr[i] - arr[i - 1]);
                printf("%048b\n", (arr[i] - 1) ^ (arr[i - 1] - 1));
        }
}

kultaheltta in linux on  host-snp-v7 [?]
$ gcc -o test test.c

kultaheltta in linux on  host-snp-v7 [?]
$ ./test
000000000000000000000000000000000000111111111000
000000000000000000000000000000000000111111111000
000000000000000000000000000111111111000000000000
000000000000000000000000000111111111000000000000
000000000000000000111111111000000000000000000000
000000000000000000111111111000000000000000000000
000000000000000011000000000000000000000000000000
000000000000000011000000000000000000000000000000

So the operation could be described as:

        pfn = PFN_DOWN(address & (~page_level_mask(level) ^ ~page_level_mask(level - 1)));

Which IMHO already documents itself quite well: index
with the granularity of PGD by removing bits used for
PGD's below it.

BR, Jarkko
Jarkko Sakkinen Sept. 8, 2022, 7:57 a.m. UTC | #28
On Thu, Sep 08, 2022 at 10:46:51AM +0300, Jarkko Sakkinen wrote:
> On Tue, Sep 06, 2022 at 06:44:23PM +0300, Jarkko Sakkinen wrote:
> > On Tue, Sep 06, 2022 at 02:17:15PM +0000, Kalra, Ashish wrote:
> > > [AMD Official Use Only - General]
> > > 
> > > >> On Tue, Aug 09, 2022 at 06:55:43PM +0200, Borislav Petkov wrote:
> > > >> > On Mon, Jun 20, 2022 at 11:03:43PM +0000, Ashish Kalra wrote:
> > > >> > > +   pfn = pte_pfn(*pte);
> > > >> > > +
> > > >> > > +   /* If its large page then calculte the fault pfn */
> > > >> > > +   if (level > PG_LEVEL_4K) {
> > > >> > > +           unsigned long mask;
> > > >> > > +
> > > >> > > +           mask = pages_per_hpage(level) - pages_per_hpage(level - 1);
> > > >> > > +           pfn |= (address >> PAGE_SHIFT) & mask;
> > > >> >
> > > >> > Oh boy, this is unnecessarily complicated. Isn't this
> > > >> >
> > > >> >       pfn |= pud_index(address);
> > > >> >
> > > >> > or
> > > >> >       pfn |= pmd_index(address);
> > > >>
> > > >> I played with this a bit and ended up with
> > > >>
> > > >>         pfn = pte_pfn(*pte) | PFN_DOWN(address & page_level_mask(level 
> > > >> - 1));
> > > >>
> > > >> Unless I got something terribly wrong, this should do the same (see 
> > > >> the attached patch) as the existing calculations.
> > > 
> > > >Actually, I don't think they're the same. I think Jarkko's version is correct. Specifically:
> > > >- For level = PG_LEVEL_2M they're the same.
> > > >- For level = PG_LEVEL_1G:
> > > >The current code calculates a garbage mask:
> > > >mask = pages_per_hpage(level) - pages_per_hpage(level - 1); translates to:
> > > >>> hex(262144 - 512)
> > > >'0x3fe00'
> > > 
> > > No actually this is not a garbage mask, as I explained in earlier responses we need to capture the address bits 
> > > to get to the correct 4K index into the RMP table.
> > > Therefore, for level = PG_LEVEL_1G:
> > > mask = pages_per_hpage(level) - pages_per_hpage(level - 1) => 0x3fe00 (which is the correct mask).
> > > 
> > > >But I believe Jarkko's version calculates the correct mask (below), incorporating all 18 offset bits into the 1G page.
> > > >>> hex(262144 -1)
> > > >'0x3ffff'
> > > 
> > > We can get this simply by doing (page_per_hpage(level)-1), but as I mentioned above this is not what we need.
> > 
> > I think you're correct, so I'll retry:
> > 
> > (address / PAGE_SIZE) & (pages_per_hpage(level) - pages_per_hpage(level - 1)) =
> > 
> > (address / PAGE_SIZE) & ((page_level_size(level) / PAGE_SIZE) - (page_level_size(level - 1) / PAGE_SIZE)) =
> > 
> > [ factor out 1 / PAGE_SIZE ]
> > 
> > (address & (page_level_size(level) - page_level_size(level - 1))) / PAGE_SIZE  =
> > 
> > [ Substitute with PFN_DOWN() ] 
> > 
> > PFN_DOWN(address & (page_level_size(level) - page_level_size(level - 1)))
> > 
> > So you can just:
> > 
> > pfn = pte_pfn(*pte) | PFN_DOWN(address & (page_level_size(level) - page_level_size(level - 1)));
> > 
> > Which is IMHO way better still what it is now because no branching
> > and no ad-hoc helpers (the current is essentially just page_level_size
> > wrapper).
> 
> I created a small test program:
> 
> $ cat test.c
> #include <stdio.h>
> int main(void)
> {
>         unsigned long arr[] = {0x8, 0x1000, 0x200000, 0x40000000, 0x8000000000};
>         int i;
> 
>         for (i = 1; i < sizeof(arr)/sizeof(unsigned long); i++) {
>                 printf("%048b\n", arr[i] - arr[i - 1]);
>                 printf("%048b\n", (arr[i] - 1) ^ (arr[i - 1] - 1));
>         }
> }
> 
> kultaheltta in linux on  host-snp-v7 [?]
> $ gcc -o test test.c
> 
> kultaheltta in linux on  host-snp-v7 [?]
> $ ./test
> 000000000000000000000000000000000000111111111000
> 000000000000000000000000000000000000111111111000
> 000000000000000000000000000111111111000000000000
> 000000000000000000000000000111111111000000000000
> 000000000000000000111111111000000000000000000000
> 000000000000000000111111111000000000000000000000
> 000000000000000011000000000000000000000000000000
> 000000000000000011000000000000000000000000000000
> 
> So the operation could be described as:
> 
>         pfn = PFN_DOWN(address & (~page_level_mask(level) ^ ~page_level_mask(level - 1)));
> 
> Which IMHO already documents itself quite well: index
> with the granularity of PGD by removing bits used for
> PGD's below it.

I mean:

       pfn =  pte_pfn(*pte) | PFN_DOWN(address & (~page_level_mask(level) ^ ~page_level_mask(level - 1)));

Note that PG_LEVEL_4K check is unnecessary as the result
will be zero after PFN_DOWN().

BR, Jarkko
diff mbox series

Patch

diff --git a/arch/x86/mm/fault.c b/arch/x86/mm/fault.c
index a4c270e99f7f..f5de9673093a 100644
--- a/arch/x86/mm/fault.c
+++ b/arch/x86/mm/fault.c
@@ -19,6 +19,7 @@ 
 #include <linux/uaccess.h>		/* faulthandler_disabled()	*/
 #include <linux/efi.h>			/* efi_crash_gracefully_on_page_fault()*/
 #include <linux/mm_types.h>
+#include <linux/sev.h>			/* snp_lookup_rmpentry()	*/
 
 #include <asm/cpufeature.h>		/* boot_cpu_has, ...		*/
 #include <asm/traps.h>			/* dotraplinkage, ...		*/
@@ -1209,6 +1210,60 @@  do_kern_addr_fault(struct pt_regs *regs, unsigned long hw_error_code,
 }
 NOKPROBE_SYMBOL(do_kern_addr_fault);
 
+static inline size_t pages_per_hpage(int level)
+{
+	return page_level_size(level) / PAGE_SIZE;
+}
+
+/*
+ * Return 1 if the caller need to retry, 0 if it the address need to be split
+ * in order to resolve the fault.
+ */
+static int handle_user_rmp_page_fault(struct pt_regs *regs, unsigned long error_code,
+				      unsigned long address)
+{
+	int rmp_level, level;
+	pte_t *pte;
+	u64 pfn;
+
+	pte = lookup_address_in_mm(current->mm, address, &level);
+
+	/*
+	 * It can happen if there was a race between an unmap event and
+	 * the RMP fault delivery.
+	 */
+	if (!pte || !pte_present(*pte))
+		return 1;
+
+	pfn = pte_pfn(*pte);
+
+	/* If its large page then calculte the fault pfn */
+	if (level > PG_LEVEL_4K) {
+		unsigned long mask;
+
+		mask = pages_per_hpage(level) - pages_per_hpage(level - 1);
+		pfn |= (address >> PAGE_SHIFT) & mask;
+	}
+
+	/*
+	 * If its a guest private page, then the fault cannot be resolved.
+	 * Send a SIGBUS to terminate the process.
+	 */
+	if (snp_lookup_rmpentry(pfn, &rmp_level)) {
+		do_sigbus(regs, error_code, address, VM_FAULT_SIGBUS);
+		return 1;
+	}
+
+	/*
+	 * The backing page level is higher than the RMP page level, request
+	 * to split the page.
+	 */
+	if (level > rmp_level)
+		return 0;
+
+	return 1;
+}
+
 /*
  * Handle faults in the user portion of the address space.  Nothing in here
  * should check X86_PF_USER without a specific justification: for almost
@@ -1306,6 +1361,17 @@  void do_user_addr_fault(struct pt_regs *regs,
 	if (error_code & X86_PF_INSTR)
 		flags |= FAULT_FLAG_INSTRUCTION;
 
+	/*
+	 * If its an RMP violation, try resolving it.
+	 */
+	if (error_code & X86_PF_RMP) {
+		if (handle_user_rmp_page_fault(regs, error_code, address))
+			return;
+
+		/* Ask to split the page */
+		flags |= FAULT_FLAG_PAGE_SPLIT;
+	}
+
 #ifdef CONFIG_X86_64
 	/*
 	 * Faults in the vsyscall page might need emulation.  The
diff --git a/include/linux/mm.h b/include/linux/mm.h
index de32c0383387..2ccc562d166f 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -463,7 +463,8 @@  static inline bool fault_flag_allow_retry_first(enum fault_flag flags)
 	{ FAULT_FLAG_USER,		"USER" }, \
 	{ FAULT_FLAG_REMOTE,		"REMOTE" }, \
 	{ FAULT_FLAG_INSTRUCTION,	"INSTRUCTION" }, \
-	{ FAULT_FLAG_INTERRUPTIBLE,	"INTERRUPTIBLE" }
+	{ FAULT_FLAG_INTERRUPTIBLE,	"INTERRUPTIBLE" }, \
+	{ FAULT_FLAG_PAGE_SPLIT,	"PAGESPLIT" }
 
 /*
  * vm_fault is filled by the pagefault handler and passed to the vma's
diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index 6dfaf271ebf8..aa2d8d48ce3e 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -818,6 +818,8 @@  typedef struct {
  *                      mapped R/O.
  * @FAULT_FLAG_ORIG_PTE_VALID: whether the fault has vmf->orig_pte cached.
  *                        We should only access orig_pte if this flag set.
+ * @FAULT_FLAG_PAGE_SPLIT: The fault was due page size mismatch, split the
+ *                         region to smaller page size and retry.
  *
  * About @FAULT_FLAG_ALLOW_RETRY and @FAULT_FLAG_TRIED: we can specify
  * whether we would allow page faults to retry by specifying these two
@@ -855,6 +857,7 @@  enum fault_flag {
 	FAULT_FLAG_INTERRUPTIBLE =	1 << 9,
 	FAULT_FLAG_UNSHARE =		1 << 10,
 	FAULT_FLAG_ORIG_PTE_VALID =	1 << 11,
+	FAULT_FLAG_PAGE_SPLIT =		1 << 12,
 };
 
 typedef unsigned int __bitwise zap_flags_t;
diff --git a/mm/memory.c b/mm/memory.c
index 7274f2b52bca..c2187ffcbb8e 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -4945,6 +4945,15 @@  static vm_fault_t handle_pte_fault(struct vm_fault *vmf)
 	return 0;
 }
 
+static int handle_split_page_fault(struct vm_fault *vmf)
+{
+	if (!IS_ENABLED(CONFIG_AMD_MEM_ENCRYPT))
+		return VM_FAULT_SIGBUS;
+
+	__split_huge_pmd(vmf->vma, vmf->pmd, vmf->address, false, NULL);
+	return 0;
+}
+
 /*
  * By the time we get here, we already hold the mm semaphore
  *
@@ -5024,6 +5033,10 @@  static vm_fault_t __handle_mm_fault(struct vm_area_struct *vma,
 				pmd_migration_entry_wait(mm, vmf.pmd);
 			return 0;
 		}
+
+		if (flags & FAULT_FLAG_PAGE_SPLIT)
+			return handle_split_page_fault(&vmf);
+
 		if (pmd_trans_huge(vmf.orig_pmd) || pmd_devmap(vmf.orig_pmd)) {
 			if (pmd_protnone(vmf.orig_pmd) && vma_is_accessible(vma))
 				return do_huge_pmd_numa_page(&vmf);