
[v6,10/12] crypto: x86/aes - Prepare for a new AES implementation

Message ID 20230410225936.8940-11-chang.seok.bae@intel.com
State Superseded
Series x86: Support Key Locker

Commit Message

Chang S. Bae April 10, 2023, 10:59 p.m. UTC
Key Locker's AES instruction set ('AES-KL') has a similar programming
interface to AES-NI. The internal ABI in the assembly code will have the
same prototype as AES-NI's, so the glue code will be the same as AES-NI's.

Refactor the common C code to avoid code duplication. The AES-NI code uses
it with a function pointer argument to call back the AES-NI assembly code.
So will the AES-KL code.

Introduce wrappers for the data transformation functions so that they
return an error code, since AES-KL may report an error during data
transformation.

Export those refactored functions and new wrappers that the AES-KL code
will reference.

Also, move some constant values to a common assembly file.

No functional change intended.

Signed-off-by: Chang S. Bae <chang.seok.bae@intel.com>
Acked-by: Dan Williams <dan.j.williams@intel.com>
Cc: Herbert Xu <herbert@gondor.apana.org.au>
Cc: "David S. Miller" <davem@davemloft.net>
Cc: Thomas Gleixner <tglx@linutronix.de>
Cc: Ingo Molnar <mingo@redhat.com>
Cc: Borislav Petkov <bp@alien8.de>
Cc: Dave Hansen <dave.hansen@linux.intel.com>
Cc: "H. Peter Anvin" <hpa@zytor.com>
Cc: Ard Biesheuvel <ardb@kernel.org>
Cc: x86@kernel.org
Cc: linux-crypto@vger.kernel.org
Cc: linux-kernel@vger.kernel.org
---
Changes from v5:
* Clean up the stale function definition -- cbc_crypt_common().
* Ensure kernel_fpu_end() is called on the error return path from
  xts_crypt_common()->crypt1_fn().

Changes from v4:
* Drop CBC mode changes. (Eric Biggers)

Changes from v3:
* Drop ECB and CTR mode changes. (Eric Biggers)
* Export symbols. (Eric Biggers)

Changes from RFC v2:
* Massage the changelog. (Dan Williams)

Changes from RFC v1:
* Added as a new patch. (Ard Biesheuvel)
  Link: https://lore.kernel.org/lkml/CAMj1kXGa4f21eH0mdxd1pQsZMUjUr1Btq+Dgw-gC=O-yYft7xw@mail.gmail.com/
---
 arch/x86/crypto/Makefile           |   2 +-
 arch/x86/crypto/aes-intel_asm.S    |  26 ++++
 arch/x86/crypto/aes-intel_glue.c   | 127 ++++++++++++++++
 arch/x86/crypto/aes-intel_glue.h   |  44 ++++++
 arch/x86/crypto/aesni-intel_asm.S  |  58 +++----
 arch/x86/crypto/aesni-intel_glue.c | 235 +++++++++--------------------
 arch/x86/crypto/aesni-intel_glue.h |  17 +++
 7 files changed, 305 insertions(+), 204 deletions(-)
 create mode 100644 arch/x86/crypto/aes-intel_asm.S
 create mode 100644 arch/x86/crypto/aes-intel_glue.c
 create mode 100644 arch/x86/crypto/aes-intel_glue.h
 create mode 100644 arch/x86/crypto/aesni-intel_glue.h

Comments

Eric Biggers May 5, 2023, 11:27 p.m. UTC | #1
On Mon, Apr 10, 2023 at 03:59:34PM -0700, Chang S. Bae wrote:
> Refactor the common C code to avoid code duplication. The AES-NI code uses
> it with a function pointer argument to call back the AES-NI assembly code.
> So will the AES-KL code.

Actually, the AES-NI XTS glue code currently makes direct calls to the assembly
code.  This patch changes it to make indirect calls.  Indirect calls are very
expensive these days, partly due to all the speculative execution mitigations.
So this patch likely causes a performance regression.  How about making
xts_crypt_common() and xts_setkey_common() be inline functions?

Another issue with having the above be exported symbols is that their names are
too generic, so they could easily collide with other symbols in the kernel.
To be exported symbols, they would need something x86-specific in their names.

>  arch/x86/crypto/Makefile           |   2 +-
>  arch/x86/crypto/aes-intel_asm.S    |  26 ++++
>  arch/x86/crypto/aes-intel_glue.c   | 127 ++++++++++++++++
>  arch/x86/crypto/aes-intel_glue.h   |  44 ++++++
>  arch/x86/crypto/aesni-intel_asm.S  |  58 +++----
>  arch/x86/crypto/aesni-intel_glue.c | 235 +++++++++--------------------
>  arch/x86/crypto/aesni-intel_glue.h |  17 +++

It's confusing having aes-intel, aesni-intel, *and* aeskl-intel.  Maybe call the
first one "aes-helpers" or "aes-common" instead?

> +struct aes_xts_ctx {
> +	u8 raw_tweak_ctx[sizeof(struct crypto_aes_ctx)] AES_ALIGN_ATTR;
> +	u8 raw_crypt_ctx[sizeof(struct crypto_aes_ctx)] AES_ALIGN_ATTR;
> +};

This struct does not make sense.  It should look like:

    struct aes_xts_ctx {
            struct crypto_aes_ctx tweak_ctx AES_ALIGN_ATTR;
            struct crypto_aes_ctx crypt_ctx AES_ALIGN_ATTR;
    };

The runtime alignment to a 16-byte boundary should happen when translating the
raw crypto_skcipher_ctx() into the pointer to the aes_xts_ctx.  It should not
happen when accessing each individual field in the aes_xts_ctx.
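
For illustration, such a helper could mirror the existing aes_ctx() but
operate on the whole tfm context -- a rough, untested sketch:

    static inline struct aes_xts_ctx *aes_xts_ctx(struct crypto_skcipher *tfm)
    {
            /* align the raw tfm context once; fields are then used directly */
            unsigned long addr = (unsigned long)crypto_skcipher_ctx(tfm);
            unsigned long align = AES_ALIGN;

            if (align <= crypto_tfm_ctx_alignment())
                    align = 1;

            return (struct aes_xts_ctx *)ALIGN(addr, align);
    }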

>  /*
> - * int aesni_set_key(struct crypto_aes_ctx *ctx, const u8 *in_key,
> - *                   unsigned int key_len)
> + * int _aesni_set_key(struct crypto_aes_ctx *ctx, const u8 *in_key,
> + *                    unsigned int key_len)
>   */

It's conventional to use two leading underscores, not one.

- Eric
Chang S. Bae May 9, 2023, 12:55 a.m. UTC | #2
On 5/5/2023 4:27 PM, Eric Biggers wrote:
> On Mon, Apr 10, 2023 at 03:59:34PM -0700, Chang S. Bae wrote:
>> Refactor the common C code to avoid code duplication. The AES-NI code uses
>> it with a function pointer argument to call back the AES-NI assembly code.
>> So will the AES-KL code.
> 
> Actually, the AES-NI XTS glue code currently makes direct calls to the assembly
> code.  This patch changes it to make indirect calls.  Indirect calls are very
> expensive these days, partly due to all the speculative execution mitigations.
> So this patch likely causes a performance regression.  How about making
> xts_crypt_common() and xts_setkey_common() be inline functions?

I guess this series is relevant:
   https://lore.kernel.org/lkml/20201222160629.22268-1-ardb@kernel.org/#t

Yeah, inlining those looks to be just cut-and-paste work. Still, I was
curious about the performance impact.

So I picked one of my old machines and quickly ran through these cases:

$ cryptsetup benchmark -c aes-xts -s 256

# Tests are approximate using memory only (no storage IO).
# Algorithm |       Key |      Encryption |      Decryption

Upstream (6.4-rc1)
     aes-xts        256b      3949.3 MiB/s      4014.2 MiB/s
     aes-xts        256b      4016.1 MiB/s      4011.6 MiB/s
     aes-xts        256b      4026.2 MiB/s      4018.4 MiB/s
     aes-xts        256b      4009.2 MiB/s      4006.9 MiB/s
     aes-xts        256b      4025.0 MiB/s      4016.4 MiB/s

Upstream + V6
     aes-xts        256b      3876.1 MiB/s      3963.6 MiB/s
     aes-xts        256b      3926.3 MiB/s      3984.2 MiB/s
     aes-xts        256b      3940.8 MiB/s      3961.2 MiB/s
     aes-xts        256b      3929.7 MiB/s      3984.7 MiB/s
     aes-xts        256b      3892.5 MiB/s      3942.5 MiB/s

Upstream + V6 + {inlined helpers}
     aes-xts        256b      3996.9 MiB/s      4085.4 MiB/s
     aes-xts        256b      4087.6 MiB/s      4104.9 MiB/s
     aes-xts        256b      4117.9 MiB/s      4130.2 MiB/s
     aes-xts        256b      4098.4 MiB/s      4120.6 MiB/s
     aes-xts        256b      4095.5 MiB/s      4111.5 MiB/s

Okay, I'm a bit more convinced about this inlining. Will try to note
this along with the change.
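
For reference, a rough sketch of the direction (untested) -- the same
body as v6's xts_setkey_common(), just moved into the header and marked
static inline so the constant 'fn' callback can be folded into a direct
call; xts_crypt_common() would follow the same pattern:

/* in aes-intel_glue.h, next to the existing declarations */
static inline int
xts_setkey_common(struct crypto_skcipher *tfm, const u8 *key, unsigned int keylen,
		  int (*fn)(struct crypto_tfm *tfm, void *raw_ctx,
			    const u8 *in_key, unsigned int key_len))
{
	struct aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
	int err;

	err = xts_verify_key(tfm, key, keylen);
	if (err)
		return err;

	keylen /= 2;

	/* first half of xts-key is for crypt */
	err = fn(crypto_skcipher_tfm(tfm), ctx->raw_crypt_ctx, key, keylen);
	if (err)
		return err;

	/* second half of xts-key is for tweak */
	return fn(crypto_skcipher_tfm(tfm), ctx->raw_tweak_ctx, key + keylen, keylen);
}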

> Another issue with having the above be exported symbols is that their names are
> too generic, so they could easily collide with other symbols in the kernel.
> To be exported symbols, they would need something x86-specific in their names.

I think that's another good point, though they won't need to be
exported once they're moved into the header and inlined.

>>   arch/x86/crypto/Makefile           |   2 +-
>>   arch/x86/crypto/aes-intel_asm.S    |  26 ++++
>>   arch/x86/crypto/aes-intel_glue.c   | 127 ++++++++++++++++
>>   arch/x86/crypto/aes-intel_glue.h   |  44 ++++++
>>   arch/x86/crypto/aesni-intel_asm.S  |  58 +++----
>>   arch/x86/crypto/aesni-intel_glue.c | 235 +++++++++--------------------
>>   arch/x86/crypto/aesni-intel_glue.h |  17 +++
> 
> It's confusing having aes-intel, aesni-intel, *and* aeskl-intel.  Maybe call the
> first one "aes-helpers" or "aes-common" instead?

Yeah, I can see a few existing files named with 'helpers'. So, maybe
s/aes-intel/aes-helpers/.

>> +struct aes_xts_ctx {
>> +	u8 raw_tweak_ctx[sizeof(struct crypto_aes_ctx)] AES_ALIGN_ATTR;
>> +	u8 raw_crypt_ctx[sizeof(struct crypto_aes_ctx)] AES_ALIGN_ATTR;
>> +};
>
> This struct does not make sense.  It should look like:
> 
>      struct aes_xts_ctx {
>              struct crypto_aes_ctx tweak_ctx AES_ALIGN_ATTR;
>              struct crypto_aes_ctx crypt_ctx AES_ALIGN_ATTR;
>      };
> 
> The runtime alignment to a 16-byte boundary should happen when translating the
> raw crypto_skcipher_ctx() into the pointer to the aes_xts_ctx.  It should not
> happen when accessing each individual field in the aes_xts_ctx.

Oh, ugly. This came from mindless copy & paste here. I guess the fix 
could be a standalone patch. Or, it can be fixed along with this mess.

>>   /*
>> - * int aesni_set_key(struct crypto_aes_ctx *ctx, const u8 *in_key,
>> - *                   unsigned int key_len)
>> + * int _aesni_set_key(struct crypto_aes_ctx *ctx, const u8 *in_key,
>> + *                    unsigned int key_len)
>>    */
> 
> It's conventional to use two leading underscores, not one.

Yes, will fix.

Thanks,
Chang
Eric Biggers May 11, 2023, 9:39 p.m. UTC | #3
On Thu, May 11, 2023 at 12:05:17PM -0700, Chang S. Bae wrote:
> +
> +struct aes_xts_ctx {
> +	struct crypto_aes_ctx tweak_ctx AES_ALIGN_ATTR;
> +	struct crypto_aes_ctx crypt_ctx AES_ALIGN_ATTR;
> +};
> +
> +static inline struct crypto_aes_ctx *aes_ctx(void *raw_ctx)
> +{
> +	unsigned long addr = (unsigned long)raw_ctx;
> +	unsigned long align = AES_ALIGN;
> +
> +	if (align <= crypto_tfm_ctx_alignment())
> +		align = 1;
> +
> +	return (struct crypto_aes_ctx *)ALIGN(addr, align);
> +}

It seems you took my suggestion to fix the definition of struct aes_xts_ctx, but
you didn't make the corresponding change to the runtime alignment code.  There
should be a helper function aes_xts_ctx() that is used like:

    struct aes_xts_ctx *ctx = aes_xts_ctx(tfm);

It would do the runtime alignment.  Then, aes_ctx() should be removed.

- Eric
Chang S. Bae May 11, 2023, 11:19 p.m. UTC | #4
On 5/11/2023 2:39 PM, Eric Biggers wrote:
> On Thu, May 11, 2023 at 12:05:17PM -0700, Chang S. Bae wrote:
>> +
>> +struct aes_xts_ctx {
>> +	struct crypto_aes_ctx tweak_ctx AES_ALIGN_ATTR;
>> +	struct crypto_aes_ctx crypt_ctx AES_ALIGN_ATTR;
>> +};
>> +
>> +static inline struct crypto_aes_ctx *aes_ctx(void *raw_ctx)
>> +{
>> +	unsigned long addr = (unsigned long)raw_ctx;
>> +	unsigned long align = AES_ALIGN;
>> +
>> +	if (align <= crypto_tfm_ctx_alignment())
>> +		align = 1;
>> +
>> +	return (struct crypto_aes_ctx *)ALIGN(addr, align);
>> +}
> 
> It seems you took my suggestion to fix the definition of struct aes_xts_ctx, but
> you didn't make the corresponding change to the runtime alignment code.  

Sigh. This particular change was unintentionally leaked here from my WIP 
code. Yes, I'm aware of your comment:

 > The runtime alignment to a 16-byte boundary should happen when translating the
 > raw crypto_skcipher_ctx() into the pointer to the aes_xts_ctx.  It should not
 > happen when accessing each individual field in the aes_xts_ctx.

> There should be a helper function aes_xts_ctx() that is used like:
> 
>      struct aes_xts_ctx *ctx = aes_xts_ctx(tfm);
> 
> It would do the runtime alignment.  Then, aes_ctx() should be removed.

Yes, I could think of some changes like the one below. I guess the aeskl
code can live with it. The aesni glue code still wants aes_ctx() as it
deals with other modes, so that can be left there as it is.

diff --git a/arch/x86/crypto/aes-intel_glue.h b/arch/x86/crypto/aes-intel_glue.h
index 5877d0988e36..b22de77594fe 100644
--- a/arch/x86/crypto/aes-intel_glue.h
+++ b/arch/x86/crypto/aes-intel_glue.h
@@ -31,15 +31,15 @@ struct aes_xts_ctx {
         struct crypto_aes_ctx crypt_ctx AES_ALIGN_ATTR;
  };

-static inline struct crypto_aes_ctx *aes_ctx(void *raw_ctx)
+static inline struct aes_xts_ctx *aes_xts_ctx(struct crypto_skcipher *tfm)
  {
-       unsigned long addr = (unsigned long)raw_ctx;
+       unsigned long addr = (unsigned long)crypto_skcipher_ctx(tfm);
         unsigned long align = AES_ALIGN;

         if (align <= crypto_tfm_ctx_alignment())
                 align = 1;

-       return (struct crypto_aes_ctx *)ALIGN(addr, align);
+       return (struct aes_xts_ctx *)ALIGN(addr, align);
  }

  static inline int
@@ -47,7 +47,7 @@ xts_setkey_common(struct crypto_skcipher *tfm, const u8 *key, unsigned int keyle
                   int (*fn)(struct crypto_tfm *tfm, void *ctx, const u8 *in_key,
                             unsigned int key_len))
  {
-       struct aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
+       struct aes_xts_ctx *ctx = aes_xts_ctx(tfm);
         int err;

         err = xts_verify_key(tfm, key, keylen);
@@ -72,7 +72,7 @@ xts_crypt_common(struct skcipher_request *req,
                  int (*crypt1_fn)(const void *ctx, u8 *out, const u8 *in))
  {
         struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
-       struct aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
+       struct aes_xts_ctx *ctx = aes_xts_ctx(tfm);
         int tail = req->cryptlen % AES_BLOCK_SIZE;
         struct skcipher_request subreq;
         struct skcipher_walk walk;
@@ -108,7 +108,7 @@ xts_crypt_common(struct skcipher_request *req,
         kernel_fpu_begin();

         /* calculate first value of T */
-       err = crypt1_fn(aes_ctx(&ctx->tweak_ctx), walk.iv, walk.iv);
+       err = crypt1_fn(&ctx->tweak_ctx, walk.iv, walk.iv);
         if (err) {
                 kernel_fpu_end();
                 return err;
@@ -120,7 +120,7 @@ xts_crypt_common(struct skcipher_request *req,
                 if (nbytes < walk.total)
                         nbytes &= ~(AES_BLOCK_SIZE - 1);

-               err = crypt_fn(aes_ctx(&ctx->crypt_ctx), walk.dst.virt.addr, walk.src.virt.addr,
+               err = crypt_fn(&ctx->crypt_ctx, walk.dst.virt.addr, walk.src.virt.addr,
                                nbytes, walk.iv);
                 kernel_fpu_end();
                 if (err)
@@ -148,7 +148,7 @@ xts_crypt_common(struct skcipher_request *req,
                         return err;

                 kernel_fpu_begin();
-               err = crypt_fn(aes_ctx(&ctx->crypt_ctx), walk.dst.virt.addr, walk.src.virt.addr,
+               err = crypt_fn(&ctx->crypt_ctx, walk.dst.virt.addr, walk.src.virt.addr,
                                walk.nbytes, walk.iv);
                 kernel_fpu_end();
                 if (err)

Thanks,
Chang

Patch

diff --git a/arch/x86/crypto/Makefile b/arch/x86/crypto/Makefile
index 9aa46093c91b..8469b0a09cb5 100644
--- a/arch/x86/crypto/Makefile
+++ b/arch/x86/crypto/Makefile
@@ -47,7 +47,7 @@  chacha-x86_64-y := chacha-avx2-x86_64.o chacha-ssse3-x86_64.o chacha_glue.o
 chacha-x86_64-$(CONFIG_AS_AVX512) += chacha-avx512vl-x86_64.o
 
 obj-$(CONFIG_CRYPTO_AES_NI_INTEL) += aesni-intel.o
-aesni-intel-y := aesni-intel_asm.o aesni-intel_glue.o
+aesni-intel-y := aesni-intel_asm.o aesni-intel_glue.o aes-intel_glue.o
 aesni-intel-$(CONFIG_64BIT) += aesni-intel_avx-x86_64.o aes_ctrby8_avx-x86_64.o
 
 obj-$(CONFIG_CRYPTO_SHA1_SSSE3) += sha1-ssse3.o
diff --git a/arch/x86/crypto/aes-intel_asm.S b/arch/x86/crypto/aes-intel_asm.S
new file mode 100644
index 000000000000..98abf875af79
--- /dev/null
+++ b/arch/x86/crypto/aes-intel_asm.S
@@ -0,0 +1,26 @@ 
+/* SPDX-License-Identifier: GPL-2.0-or-later */
+
+/*
+ * Constant values shared between AES implementations:
+ */
+
+.pushsection .rodata
+.align 16
+.Lcts_permute_table:
+	.byte		0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
+	.byte		0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
+	.byte		0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07
+	.byte		0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f
+	.byte		0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
+	.byte		0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
+#ifdef __x86_64__
+.Lbswap_mask:
+	.byte 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
+#endif
+.popsection
+
+.section	.rodata.cst16.gf128mul_x_ble_mask, "aM", @progbits, 16
+.align 16
+.Lgf128mul_x_ble_mask:
+	.octa 0x00000000000000010000000000000087
+.previous
diff --git a/arch/x86/crypto/aes-intel_glue.c b/arch/x86/crypto/aes-intel_glue.c
new file mode 100644
index 000000000000..d485577e121b
--- /dev/null
+++ b/arch/x86/crypto/aes-intel_glue.c
@@ -0,0 +1,127 @@ 
+// SPDX-License-Identifier: GPL-2.0-or-later
+
+#include <linux/err.h>
+#include <crypto/algapi.h>
+#include <crypto/aes.h>
+#include <crypto/xts.h>
+#include <crypto/scatterwalk.h>
+#include <crypto/internal/aead.h>
+#include <crypto/internal/simd.h>
+#include "aes-intel_glue.h"
+
+int xts_setkey_common(struct crypto_skcipher *tfm, const u8 *key, unsigned int keylen,
+		      int (*fn)(struct crypto_tfm *tfm, void *raw_ctx,
+				const u8 *in_key, unsigned int key_len))
+{
+	struct aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
+	int err;
+
+	err = xts_verify_key(tfm, key, keylen);
+	if (err)
+		return err;
+
+	keylen /= 2;
+
+	/* first half of xts-key is for crypt */
+	err = fn(crypto_skcipher_tfm(tfm), ctx->raw_crypt_ctx, key, keylen);
+	if (err)
+		return err;
+
+	/* second half of xts-key is for tweak */
+	return fn(crypto_skcipher_tfm(tfm), ctx->raw_tweak_ctx, key + keylen, keylen);
+}
+EXPORT_SYMBOL_GPL(xts_setkey_common);
+
+int xts_crypt_common(struct skcipher_request *req,
+		     int (*crypt_fn)(const struct crypto_aes_ctx *ctx, u8 *out, const u8 *in,
+				     unsigned int len, u8 *iv),
+		     int (*crypt1_fn)(const void *ctx, u8 *out, const u8 *in))
+{
+	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
+	struct aes_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
+	int tail = req->cryptlen % AES_BLOCK_SIZE;
+	struct skcipher_request subreq;
+	struct skcipher_walk walk;
+	int err;
+
+	if (req->cryptlen < AES_BLOCK_SIZE)
+		return -EINVAL;
+
+	err = skcipher_walk_virt(&walk, req, false);
+	if (!walk.nbytes)
+		return err;
+
+	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
+		int blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
+
+		skcipher_walk_abort(&walk);
+
+		skcipher_request_set_tfm(&subreq, tfm);
+		skcipher_request_set_callback(&subreq,
+					      skcipher_request_flags(req),
+					      NULL, NULL);
+		skcipher_request_set_crypt(&subreq, req->src, req->dst,
+					   blocks * AES_BLOCK_SIZE, req->iv);
+		req = &subreq;
+
+		err = skcipher_walk_virt(&walk, req, false);
+		if (!walk.nbytes)
+			return err;
+	} else {
+		tail = 0;
+	}
+
+	kernel_fpu_begin();
+
+	/* calculate first value of T */
+	err = crypt1_fn(aes_ctx(ctx->raw_tweak_ctx), walk.iv, walk.iv);
+	if (err) {
+		kernel_fpu_end();
+		return err;
+	}
+
+	while (walk.nbytes > 0) {
+		int nbytes = walk.nbytes;
+
+		if (nbytes < walk.total)
+			nbytes &= ~(AES_BLOCK_SIZE - 1);
+
+		err = crypt_fn(aes_ctx(ctx->raw_crypt_ctx), walk.dst.virt.addr, walk.src.virt.addr,
+			       nbytes, walk.iv);
+		kernel_fpu_end();
+		if (err)
+			return err;
+
+		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
+
+		if (walk.nbytes > 0)
+			kernel_fpu_begin();
+	}
+
+	if (unlikely(tail > 0 && !err)) {
+		struct scatterlist sg_src[2], sg_dst[2];
+		struct scatterlist *src, *dst;
+
+		dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
+		if (req->dst != req->src)
+			dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
+
+		skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
+					   req->iv);
+
+		err = skcipher_walk_virt(&walk, &subreq, false);
+		if (err)
+			return err;
+
+		kernel_fpu_begin();
+		err = crypt_fn(aes_ctx(ctx->raw_crypt_ctx), walk.dst.virt.addr, walk.src.virt.addr,
+			       walk.nbytes, walk.iv);
+		kernel_fpu_end();
+		if (err)
+			return err;
+
+		err = skcipher_walk_done(&walk, 0);
+	}
+	return err;
+}
+EXPORT_SYMBOL_GPL(xts_crypt_common);
diff --git a/arch/x86/crypto/aes-intel_glue.h b/arch/x86/crypto/aes-intel_glue.h
new file mode 100644
index 000000000000..a5bd528ac072
--- /dev/null
+++ b/arch/x86/crypto/aes-intel_glue.h
@@ -0,0 +1,44 @@ 
+/* SPDX-License-Identifier: GPL-2.0 */
+
+/*
+ * Shared glue code between AES implementations, refactored from the AES-NI's.
+ */
+
+#ifndef _AES_INTEL_GLUE_H
+#define _AES_INTEL_GLUE_H
+
+#include <crypto/aes.h>
+
+#define AES_ALIGN			16
+#define AES_ALIGN_ATTR			__attribute__((__aligned__(AES_ALIGN)))
+#define AES_BLOCK_MASK			(~(AES_BLOCK_SIZE - 1))
+#define AES_ALIGN_EXTRA			((AES_ALIGN - 1) & ~(CRYPTO_MINALIGN - 1))
+#define CRYPTO_AES_CTX_SIZE		(sizeof(struct crypto_aes_ctx) + AES_ALIGN_EXTRA)
+#define XTS_AES_CTX_SIZE		(sizeof(struct aes_xts_ctx) + AES_ALIGN_EXTRA)
+
+struct aes_xts_ctx {
+	u8 raw_tweak_ctx[sizeof(struct crypto_aes_ctx)] AES_ALIGN_ATTR;
+	u8 raw_crypt_ctx[sizeof(struct crypto_aes_ctx)] AES_ALIGN_ATTR;
+};
+
+static inline struct crypto_aes_ctx *aes_ctx(void *raw_ctx)
+{
+	unsigned long addr = (unsigned long)raw_ctx;
+	unsigned long align = AES_ALIGN;
+
+	if (align <= crypto_tfm_ctx_alignment())
+		align = 1;
+
+	return (struct crypto_aes_ctx *)ALIGN(addr, align);
+}
+
+int xts_setkey_common(struct crypto_skcipher *tfm, const u8 *key, unsigned int keylen,
+		      int (*fn)(struct crypto_tfm *tfm, void *raw_ctx,
+				const u8 *in_key, unsigned int key_len));
+
+int xts_crypt_common(struct skcipher_request *req,
+		     int (*crypt_fn)(const struct crypto_aes_ctx *ctx, u8 *out, const u8 *in,
+				     unsigned int len, u8 *iv),
+		     int (*crypt1_fn)(const void *ctx, u8 *out, const u8 *in));
+
+#endif /* _AES_INTEL_GLUE_H */
diff --git a/arch/x86/crypto/aesni-intel_asm.S b/arch/x86/crypto/aesni-intel_asm.S
index 837c1e0aa021..9ea4f13464d5 100644
--- a/arch/x86/crypto/aesni-intel_asm.S
+++ b/arch/x86/crypto/aesni-intel_asm.S
@@ -28,6 +28,7 @@ 
 #include <linux/linkage.h>
 #include <asm/frame.h>
 #include <asm/nospec-branch.h>
+#include "aes-intel_asm.S"
 
 /*
  * The following macros are used to move an (un)aligned 16 byte value to/from
@@ -1820,10 +1821,10 @@  SYM_FUNC_START_LOCAL(_key_expansion_256b)
 SYM_FUNC_END(_key_expansion_256b)
 
 /*
- * int aesni_set_key(struct crypto_aes_ctx *ctx, const u8 *in_key,
- *                   unsigned int key_len)
+ * int _aesni_set_key(struct crypto_aes_ctx *ctx, const u8 *in_key,
+ *                    unsigned int key_len)
  */
-SYM_FUNC_START(aesni_set_key)
+SYM_FUNC_START(_aesni_set_key)
 	FRAME_BEGIN
 #ifndef __x86_64__
 	pushl KEYP
@@ -1932,12 +1933,12 @@  SYM_FUNC_START(aesni_set_key)
 #endif
 	FRAME_END
 	RET
-SYM_FUNC_END(aesni_set_key)
+SYM_FUNC_END(_aesni_set_key)
 
 /*
- * void aesni_enc(const void *ctx, u8 *dst, const u8 *src)
+ * void _aesni_enc(const void *ctx, u8 *dst, const u8 *src)
  */
-SYM_FUNC_START(aesni_enc)
+SYM_FUNC_START(_aesni_enc)
 	FRAME_BEGIN
 #ifndef __x86_64__
 	pushl KEYP
@@ -1956,7 +1957,7 @@  SYM_FUNC_START(aesni_enc)
 #endif
 	FRAME_END
 	RET
-SYM_FUNC_END(aesni_enc)
+SYM_FUNC_END(_aesni_enc)
 
 /*
  * _aesni_enc1:		internal ABI
@@ -2124,9 +2125,9 @@  SYM_FUNC_START_LOCAL(_aesni_enc4)
 SYM_FUNC_END(_aesni_enc4)
 
 /*
- * void aesni_dec (const void *ctx, u8 *dst, const u8 *src)
+ * void _aesni_dec (const void *ctx, u8 *dst, const u8 *src)
  */
-SYM_FUNC_START(aesni_dec)
+SYM_FUNC_START(_aesni_dec)
 	FRAME_BEGIN
 #ifndef __x86_64__
 	pushl KEYP
@@ -2146,7 +2147,7 @@  SYM_FUNC_START(aesni_dec)
 #endif
 	FRAME_END
 	RET
-SYM_FUNC_END(aesni_dec)
+SYM_FUNC_END(_aesni_dec)
 
 /*
  * _aesni_dec1:		internal ABI
@@ -2689,21 +2690,6 @@  SYM_FUNC_START(aesni_cts_cbc_dec)
 	RET
 SYM_FUNC_END(aesni_cts_cbc_dec)
 
-.pushsection .rodata
-.align 16
-.Lcts_permute_table:
-	.byte		0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
-	.byte		0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
-	.byte		0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07
-	.byte		0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x0f
-	.byte		0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
-	.byte		0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80, 0x80
-#ifdef __x86_64__
-.Lbswap_mask:
-	.byte 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0
-#endif
-.popsection
-
 #ifdef __x86_64__
 /*
  * _aesni_inc_init:	internal ABI
@@ -2819,12 +2805,6 @@  SYM_FUNC_END(aesni_ctr_enc)
 
 #endif
 
-.section	.rodata.cst16.gf128mul_x_ble_mask, "aM", @progbits, 16
-.align 16
-.Lgf128mul_x_ble_mask:
-	.octa 0x00000000000000010000000000000087
-.previous
-
 /*
  * _aesni_gf128mul_x_ble:		internal ABI
  *	Multiply in GF(2^128) for XTS IVs
@@ -2844,10 +2824,10 @@  SYM_FUNC_END(aesni_ctr_enc)
 	pxor KEY, IV;
 
 /*
- * void aesni_xts_encrypt(const struct crypto_aes_ctx *ctx, u8 *dst,
- *			  const u8 *src, unsigned int len, le128 *iv)
+ * void _aesni_xts_encrypt(const struct crypto_aes_ctx *ctx, u8 *dst,
+ *			   const u8 *src, unsigned int len, le128 *iv)
  */
-SYM_FUNC_START(aesni_xts_encrypt)
+SYM_FUNC_START(_aesni_xts_encrypt)
 	FRAME_BEGIN
 #ifndef __x86_64__
 	pushl IVP
@@ -2996,13 +2976,13 @@  SYM_FUNC_START(aesni_xts_encrypt)
 
 	movups STATE, (OUTP)
 	jmp .Lxts_enc_ret
-SYM_FUNC_END(aesni_xts_encrypt)
+SYM_FUNC_END(_aesni_xts_encrypt)
 
 /*
- * void aesni_xts_decrypt(const struct crypto_aes_ctx *ctx, u8 *dst,
- *			  const u8 *src, unsigned int len, le128 *iv)
+ * void _aesni_xts_decrypt(const struct crypto_aes_ctx *ctx, u8 *dst,
+ *			   const u8 *src, unsigned int len, le128 *iv)
  */
-SYM_FUNC_START(aesni_xts_decrypt)
+SYM_FUNC_START(_aesni_xts_decrypt)
 	FRAME_BEGIN
 #ifndef __x86_64__
 	pushl IVP
@@ -3158,4 +3138,4 @@  SYM_FUNC_START(aesni_xts_decrypt)
 
 	movups STATE, (OUTP)
 	jmp .Lxts_dec_ret
-SYM_FUNC_END(aesni_xts_decrypt)
+SYM_FUNC_END(_aesni_xts_decrypt)
diff --git a/arch/x86/crypto/aesni-intel_glue.c b/arch/x86/crypto/aesni-intel_glue.c
index a5b0cb3efeba..e7a9e8e01680 100644
--- a/arch/x86/crypto/aesni-intel_glue.c
+++ b/arch/x86/crypto/aesni-intel_glue.c
@@ -36,33 +36,24 @@ 
 #include <linux/spinlock.h>
 #include <linux/static_call.h>
 
+#include "aes-intel_glue.h"
+#include "aesni-intel_glue.h"
 
-#define AESNI_ALIGN	16
-#define AESNI_ALIGN_ATTR __attribute__ ((__aligned__(AESNI_ALIGN)))
-#define AES_BLOCK_MASK	(~(AES_BLOCK_SIZE - 1))
 #define RFC4106_HASH_SUBKEY_SIZE 16
-#define AESNI_ALIGN_EXTRA ((AESNI_ALIGN - 1) & ~(CRYPTO_MINALIGN - 1))
-#define CRYPTO_AES_CTX_SIZE (sizeof(struct crypto_aes_ctx) + AESNI_ALIGN_EXTRA)
-#define XTS_AES_CTX_SIZE (sizeof(struct aesni_xts_ctx) + AESNI_ALIGN_EXTRA)
 
 /* This data is stored at the end of the crypto_tfm struct.
  * It's a type of per "session" data storage location.
  * This needs to be 16 byte aligned.
  */
 struct aesni_rfc4106_gcm_ctx {
-	u8 hash_subkey[16] AESNI_ALIGN_ATTR;
-	struct crypto_aes_ctx aes_key_expanded AESNI_ALIGN_ATTR;
+	u8 hash_subkey[16] AES_ALIGN_ATTR;
+	struct crypto_aes_ctx aes_key_expanded AES_ALIGN_ATTR;
 	u8 nonce[4];
 };
 
 struct generic_gcmaes_ctx {
-	u8 hash_subkey[16] AESNI_ALIGN_ATTR;
-	struct crypto_aes_ctx aes_key_expanded AESNI_ALIGN_ATTR;
-};
-
-struct aesni_xts_ctx {
-	u8 raw_tweak_ctx[sizeof(struct crypto_aes_ctx)] AESNI_ALIGN_ATTR;
-	u8 raw_crypt_ctx[sizeof(struct crypto_aes_ctx)] AESNI_ALIGN_ATTR;
+	u8 hash_subkey[16] AES_ALIGN_ATTR;
+	struct crypto_aes_ctx aes_key_expanded AES_ALIGN_ATTR;
 };
 
 #define GCM_BLOCK_LEN 16
@@ -80,10 +71,10 @@  struct gcm_context_data {
 	u8 hash_keys[GCM_BLOCK_LEN * 16];
 };
 
-asmlinkage int aesni_set_key(struct crypto_aes_ctx *ctx, const u8 *in_key,
-			     unsigned int key_len);
-asmlinkage void aesni_enc(const void *ctx, u8 *out, const u8 *in);
-asmlinkage void aesni_dec(const void *ctx, u8 *out, const u8 *in);
+asmlinkage int _aesni_set_key(struct crypto_aes_ctx *ctx, const u8 *in_key,
+			      unsigned int key_len);
+asmlinkage void _aesni_enc(const void *ctx, u8 *out, const u8 *in);
+asmlinkage void _aesni_dec(const void *ctx, u8 *out, const u8 *in);
 asmlinkage void aesni_ecb_enc(struct crypto_aes_ctx *ctx, u8 *out,
 			      const u8 *in, unsigned int len);
 asmlinkage void aesni_ecb_dec(struct crypto_aes_ctx *ctx, u8 *out,
@@ -97,14 +88,51 @@  asmlinkage void aesni_cts_cbc_enc(struct crypto_aes_ctx *ctx, u8 *out,
 asmlinkage void aesni_cts_cbc_dec(struct crypto_aes_ctx *ctx, u8 *out,
 				  const u8 *in, unsigned int len, u8 *iv);
 
+int aesni_set_key(struct crypto_aes_ctx *ctx, const u8 *in_key,
+		  unsigned int key_len)
+{
+	return _aesni_set_key(ctx, in_key, key_len);
+}
+EXPORT_SYMBOL_GPL(aesni_set_key);
+
+int aesni_enc(const void *ctx, u8 *out, const u8 *in)
+{
+	_aesni_enc(ctx, out, in);
+	return 0;
+}
+EXPORT_SYMBOL_GPL(aesni_enc);
+
+int aesni_dec(const void *ctx, u8 *out, const u8 *in)
+{
+	_aesni_dec(ctx, out, in);
+	return 0;
+}
+EXPORT_SYMBOL_GPL(aesni_dec);
+
 #define AVX_GEN2_OPTSIZE 640
 #define AVX_GEN4_OPTSIZE 4096
 
-asmlinkage void aesni_xts_encrypt(const struct crypto_aes_ctx *ctx, u8 *out,
-				  const u8 *in, unsigned int len, u8 *iv);
+asmlinkage void _aesni_xts_encrypt(const struct crypto_aes_ctx *ctx, u8 *out,
+				   const u8 *in, unsigned int len, u8 *iv);
 
-asmlinkage void aesni_xts_decrypt(const struct crypto_aes_ctx *ctx, u8 *out,
-				  const u8 *in, unsigned int len, u8 *iv);
+asmlinkage void _aesni_xts_decrypt(const struct crypto_aes_ctx *ctx, u8 *out,
+				   const u8 *in, unsigned int len, u8 *iv);
+
+int aesni_xts_encrypt(const struct crypto_aes_ctx *ctx, u8 *out, const u8 *in,
+		      unsigned int len, u8 *iv)
+{
+	_aesni_xts_encrypt(ctx, out, in, len, iv);
+	return 0;
+}
+EXPORT_SYMBOL_GPL(aesni_xts_encrypt);
+
+int aesni_xts_decrypt(const struct crypto_aes_ctx *ctx, u8 *out, const u8 *in,
+		      unsigned int len, u8 *iv)
+{
+	_aesni_xts_decrypt(ctx, out, in, len, iv);
+	return 0;
+}
+EXPORT_SYMBOL_GPL(aesni_xts_decrypt);
 
 #ifdef CONFIG_X86_64
 
@@ -201,7 +229,7 @@  static __ro_after_init DEFINE_STATIC_KEY_FALSE(gcm_use_avx2);
 static inline struct
 aesni_rfc4106_gcm_ctx *aesni_rfc4106_gcm_ctx_get(struct crypto_aead *tfm)
 {
-	unsigned long align = AESNI_ALIGN;
+	unsigned long align = AES_ALIGN;
 
 	if (align <= crypto_tfm_ctx_alignment())
 		align = 1;
@@ -211,7 +239,7 @@  aesni_rfc4106_gcm_ctx *aesni_rfc4106_gcm_ctx_get(struct crypto_aead *tfm)
 static inline struct
 generic_gcmaes_ctx *generic_gcmaes_ctx_get(struct crypto_aead *tfm)
 {
-	unsigned long align = AESNI_ALIGN;
+	unsigned long align = AES_ALIGN;
 
 	if (align <= crypto_tfm_ctx_alignment())
 		align = 1;
@@ -219,16 +247,6 @@  generic_gcmaes_ctx *generic_gcmaes_ctx_get(struct crypto_aead *tfm)
 }
 #endif
 
-static inline struct crypto_aes_ctx *aes_ctx(void *raw_ctx)
-{
-	unsigned long addr = (unsigned long)raw_ctx;
-	unsigned long align = AESNI_ALIGN;
-
-	if (align <= crypto_tfm_ctx_alignment())
-		align = 1;
-	return (struct crypto_aes_ctx *)ALIGN(addr, align);
-}
-
 static int aes_set_key_common(struct crypto_tfm *tfm, void *raw_ctx,
 			      const u8 *in_key, unsigned int key_len)
 {
@@ -243,7 +261,7 @@  static int aes_set_key_common(struct crypto_tfm *tfm, void *raw_ctx,
 		err = aes_expandkey(ctx, in_key, key_len);
 	else {
 		kernel_fpu_begin();
-		err = aesni_set_key(ctx, in_key, key_len);
+		err = _aesni_set_key(ctx, in_key, key_len);
 		kernel_fpu_end();
 	}
 
@@ -264,7 +282,7 @@  static void aesni_encrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src)
 		aes_encrypt(ctx, dst, src);
 	} else {
 		kernel_fpu_begin();
-		aesni_enc(ctx, dst, src);
+		_aesni_enc(ctx, dst, src);
 		kernel_fpu_end();
 	}
 }
@@ -277,7 +295,7 @@  static void aesni_decrypt(struct crypto_tfm *tfm, u8 *dst, const u8 *src)
 		aes_decrypt(ctx, dst, src);
 	} else {
 		kernel_fpu_begin();
-		aesni_dec(ctx, dst, src);
+		_aesni_dec(ctx, dst, src);
 		kernel_fpu_end();
 	}
 }
@@ -491,7 +509,7 @@  static int cts_cbc_decrypt(struct skcipher_request *req)
 
 #ifdef CONFIG_X86_64
 static void aesni_ctr_enc_avx_tfm(struct crypto_aes_ctx *ctx, u8 *out,
-			      const u8 *in, unsigned int len, u8 *iv)
+				 const u8 *in, unsigned int len, u8 *iv)
 {
 	/*
 	 * based on key length, override with the by8 version
@@ -528,7 +546,7 @@  static int ctr_crypt(struct skcipher_request *req)
 		nbytes &= ~AES_BLOCK_MASK;
 
 		if (walk.nbytes == walk.total && nbytes > 0) {
-			aesni_enc(ctx, keystream, walk.iv);
+			_aesni_enc(ctx, keystream, walk.iv);
 			crypto_xor_cpy(walk.dst.virt.addr + walk.nbytes - nbytes,
 				       walk.src.virt.addr + walk.nbytes - nbytes,
 				       keystream, nbytes);
@@ -673,8 +691,8 @@  static int gcmaes_crypt_by_sg(bool enc, struct aead_request *req,
 			      u8 *iv, void *aes_ctx, u8 *auth_tag,
 			      unsigned long auth_tag_len)
 {
-	u8 databuf[sizeof(struct gcm_context_data) + (AESNI_ALIGN - 8)] __aligned(8);
-	struct gcm_context_data *data = PTR_ALIGN((void *)databuf, AESNI_ALIGN);
+	u8 databuf[sizeof(struct gcm_context_data) + (AES_ALIGN - 8)] __aligned(8);
+	struct gcm_context_data *data = PTR_ALIGN((void *)databuf, AES_ALIGN);
 	unsigned long left = req->cryptlen;
 	struct scatter_walk assoc_sg_walk;
 	struct skcipher_walk walk;
@@ -829,8 +847,8 @@  static int helper_rfc4106_encrypt(struct aead_request *req)
 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
 	struct aesni_rfc4106_gcm_ctx *ctx = aesni_rfc4106_gcm_ctx_get(tfm);
 	void *aes_ctx = &(ctx->aes_key_expanded);
-	u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8);
-	u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN);
+	u8 ivbuf[16 + (AES_ALIGN - 8)] __aligned(8);
+	u8 *iv = PTR_ALIGN(&ivbuf[0], AES_ALIGN);
 	unsigned int i;
 	__be32 counter = cpu_to_be32(1);
 
@@ -857,8 +875,8 @@  static int helper_rfc4106_decrypt(struct aead_request *req)
 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
 	struct aesni_rfc4106_gcm_ctx *ctx = aesni_rfc4106_gcm_ctx_get(tfm);
 	void *aes_ctx = &(ctx->aes_key_expanded);
-	u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8);
-	u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN);
+	u8 ivbuf[16 + (AES_ALIGN - 8)] __aligned(8);
+	u8 *iv = PTR_ALIGN(&ivbuf[0], AES_ALIGN);
 	unsigned int i;
 
 	if (unlikely(req->assoclen != 16 && req->assoclen != 20))
@@ -883,128 +901,17 @@  static int helper_rfc4106_decrypt(struct aead_request *req)
 static int xts_aesni_setkey(struct crypto_skcipher *tfm, const u8 *key,
 			    unsigned int keylen)
 {
-	struct aesni_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
-	int err;
-
-	err = xts_verify_key(tfm, key, keylen);
-	if (err)
-		return err;
-
-	keylen /= 2;
-
-	/* first half of xts-key is for crypt */
-	err = aes_set_key_common(crypto_skcipher_tfm(tfm), ctx->raw_crypt_ctx,
-				 key, keylen);
-	if (err)
-		return err;
-
-	/* second half of xts-key is for tweak */
-	return aes_set_key_common(crypto_skcipher_tfm(tfm), ctx->raw_tweak_ctx,
-				  key + keylen, keylen);
-}
-
-static int xts_crypt(struct skcipher_request *req, bool encrypt)
-{
-	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
-	struct aesni_xts_ctx *ctx = crypto_skcipher_ctx(tfm);
-	int tail = req->cryptlen % AES_BLOCK_SIZE;
-	struct skcipher_request subreq;
-	struct skcipher_walk walk;
-	int err;
-
-	if (req->cryptlen < AES_BLOCK_SIZE)
-		return -EINVAL;
-
-	err = skcipher_walk_virt(&walk, req, false);
-	if (!walk.nbytes)
-		return err;
-
-	if (unlikely(tail > 0 && walk.nbytes < walk.total)) {
-		int blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
-
-		skcipher_walk_abort(&walk);
-
-		skcipher_request_set_tfm(&subreq, tfm);
-		skcipher_request_set_callback(&subreq,
-					      skcipher_request_flags(req),
-					      NULL, NULL);
-		skcipher_request_set_crypt(&subreq, req->src, req->dst,
-					   blocks * AES_BLOCK_SIZE, req->iv);
-		req = &subreq;
-
-		err = skcipher_walk_virt(&walk, req, false);
-		if (!walk.nbytes)
-			return err;
-	} else {
-		tail = 0;
-	}
-
-	kernel_fpu_begin();
-
-	/* calculate first value of T */
-	aesni_enc(aes_ctx(ctx->raw_tweak_ctx), walk.iv, walk.iv);
-
-	while (walk.nbytes > 0) {
-		int nbytes = walk.nbytes;
-
-		if (nbytes < walk.total)
-			nbytes &= ~(AES_BLOCK_SIZE - 1);
-
-		if (encrypt)
-			aesni_xts_encrypt(aes_ctx(ctx->raw_crypt_ctx),
-					  walk.dst.virt.addr, walk.src.virt.addr,
-					  nbytes, walk.iv);
-		else
-			aesni_xts_decrypt(aes_ctx(ctx->raw_crypt_ctx),
-					  walk.dst.virt.addr, walk.src.virt.addr,
-					  nbytes, walk.iv);
-		kernel_fpu_end();
-
-		err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
-
-		if (walk.nbytes > 0)
-			kernel_fpu_begin();
-	}
-
-	if (unlikely(tail > 0 && !err)) {
-		struct scatterlist sg_src[2], sg_dst[2];
-		struct scatterlist *src, *dst;
-
-		dst = src = scatterwalk_ffwd(sg_src, req->src, req->cryptlen);
-		if (req->dst != req->src)
-			dst = scatterwalk_ffwd(sg_dst, req->dst, req->cryptlen);
-
-		skcipher_request_set_crypt(req, src, dst, AES_BLOCK_SIZE + tail,
-					   req->iv);
-
-		err = skcipher_walk_virt(&walk, &subreq, false);
-		if (err)
-			return err;
-
-		kernel_fpu_begin();
-		if (encrypt)
-			aesni_xts_encrypt(aes_ctx(ctx->raw_crypt_ctx),
-					  walk.dst.virt.addr, walk.src.virt.addr,
-					  walk.nbytes, walk.iv);
-		else
-			aesni_xts_decrypt(aes_ctx(ctx->raw_crypt_ctx),
-					  walk.dst.virt.addr, walk.src.virt.addr,
-					  walk.nbytes, walk.iv);
-		kernel_fpu_end();
-
-		err = skcipher_walk_done(&walk, 0);
-	}
-	return err;
+	return xts_setkey_common(tfm, key, keylen, aes_set_key_common);
 }
 
 static int xts_encrypt(struct skcipher_request *req)
 {
-	return xts_crypt(req, true);
+	return xts_crypt_common(req, aesni_xts_encrypt, aesni_enc);
 }
 
 static int xts_decrypt(struct skcipher_request *req)
 {
-	return xts_crypt(req, false);
+	return xts_crypt_common(req, aesni_xts_decrypt, aesni_enc);
 }
 
 static struct crypto_alg aesni_cipher_alg = {
@@ -1160,8 +1067,8 @@  static int generic_gcmaes_encrypt(struct aead_request *req)
 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
 	struct generic_gcmaes_ctx *ctx = generic_gcmaes_ctx_get(tfm);
 	void *aes_ctx = &(ctx->aes_key_expanded);
-	u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8);
-	u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN);
+	u8 ivbuf[16 + (AES_ALIGN - 8)] __aligned(8);
+	u8 *iv = PTR_ALIGN(&ivbuf[0], AES_ALIGN);
 	__be32 counter = cpu_to_be32(1);
 
 	memcpy(iv, req->iv, 12);
@@ -1177,8 +1084,8 @@  static int generic_gcmaes_decrypt(struct aead_request *req)
 	struct crypto_aead *tfm = crypto_aead_reqtfm(req);
 	struct generic_gcmaes_ctx *ctx = generic_gcmaes_ctx_get(tfm);
 	void *aes_ctx = &(ctx->aes_key_expanded);
-	u8 ivbuf[16 + (AESNI_ALIGN - 8)] __aligned(8);
-	u8 *iv = PTR_ALIGN(&ivbuf[0], AESNI_ALIGN);
+	u8 ivbuf[16 + (AES_ALIGN - 8)] __aligned(8);
+	u8 *iv = PTR_ALIGN(&ivbuf[0], AES_ALIGN);
 
 	memcpy(iv, req->iv, 12);
 	*((__be32 *)(iv+12)) = counter;
diff --git a/arch/x86/crypto/aesni-intel_glue.h b/arch/x86/crypto/aesni-intel_glue.h
new file mode 100644
index 000000000000..81ecacb4e54c
--- /dev/null
+++ b/arch/x86/crypto/aesni-intel_glue.h
@@ -0,0 +1,17 @@ 
+/* SPDX-License-Identifier: GPL-2.0 */
+
+/*
+ * Support for Intel AES-NI instructions. This file contains function
+ * prototypes to be referenced for other AES implementations
+ */
+
+int aesni_set_key(struct crypto_aes_ctx *ctx, const u8 *in_key, unsigned int key_len);
+
+int aesni_enc(const void *ctx, u8 *out, const u8 *in);
+int aesni_dec(const void *ctx, u8 *out, const u8 *in);
+
+int aesni_xts_encrypt(const struct crypto_aes_ctx *ctx, u8 *out, const u8 *in,
+		      unsigned int len, u8 *iv);
+int aesni_xts_decrypt(const struct crypto_aes_ctx *ctx, u8 *out, const u8 *in,
+		      unsigned int len, u8 *iv);
+