diff mbox series

[2/2] crypto: arm64/gcm-ce - implement 4 way interleave

Message ID 20190910231900.25445-3-ard.biesheuvel@linaro.org
State Accepted
Commit 11031c0d7d6e9bca0df233a8acfd6708d2b89470
Headers show
Series crypto: faster GCM-AES for arm64 | expand

Commit Message

Ard Biesheuvel Sept. 10, 2019, 11:19 p.m. UTC
To improve performance on cores with deep pipelines such as ThunderX2,
reimplement gcm(aes) using a 4-way interleave rather than the 2-way
interleave we use currently.

This comes down to a complete rewrite of the GCM part of the combined
GCM/GHASH driver, and instead of interleaving two invocations of AES
with the GHASH handling at the instruction level, the new version
uses a more coarse grained approach where each chunk of 64 bytes is
encrypted first and then ghashed (or ghashed and then decrypted in
the converse case).

The core NEON routine is now able to consume inputs of any size,
and tail blocks of less than 64 bytes are handled using overlapping
loads and stores, and processed by the same 4-way encryption and
hashing routines. This gets rid of most of the branches, and avoids
having to return to the C code to handle the tail block using a
stack buffer.

The table below compares the performance of the old driver and the new
one on various micro-architectures and running in various modes.

        |     AES-128      |     AES-192      |     AES-256      |
 #bytes | 512 | 1500 |  4k | 512 | 1500 |  4k | 512 | 1500 |  4k |
 -------+-----+------+-----+-----+------+-----+-----+------+-----+
    TX2 | 35% |  23% | 11% | 34% |  20% |  9% | 38% |  25% | 16% |
   EMAG | 11% |   6% |  3% | 12% |   4% |  2% | 11% |   4% |  2% |
    A72 |  8% |   5% | -4% |  9% |   4% | -5% |  7% |   4% | -5% |
    A53 | 11% |   6% | -1% | 10% |   8% | -1% | 10% |   8% | -2% |

Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>

---
 arch/arm64/crypto/ghash-ce-core.S | 501 ++++++++++++++------
 arch/arm64/crypto/ghash-ce-glue.c | 293 +++++-------
 2 files changed, 467 insertions(+), 327 deletions(-)

-- 
2.17.1
diff mbox series

Patch

diff --git a/arch/arm64/crypto/ghash-ce-core.S b/arch/arm64/crypto/ghash-ce-core.S
index 410e8afcf5a7..a791c4adf8e6 100644
--- a/arch/arm64/crypto/ghash-ce-core.S
+++ b/arch/arm64/crypto/ghash-ce-core.S
@@ -13,8 +13,8 @@ 
 	T1		.req	v2
 	T2		.req	v3
 	MASK		.req	v4
-	XL		.req	v5
-	XM		.req	v6
+	XM		.req	v5
+	XL		.req	v6
 	XH		.req	v7
 	IN1		.req	v7
 
@@ -358,20 +358,37 @@  ENTRY(pmull_ghash_update_p8)
 	__pmull_ghash	p8
 ENDPROC(pmull_ghash_update_p8)
 
-	KS0		.req	v12
-	KS1		.req	v13
-	INP0		.req	v14
-	INP1		.req	v15
-
-	.macro		load_round_keys, rounds, rk
-	cmp		\rounds, #12
-	blo		2222f		/* 128 bits */
-	beq		1111f		/* 192 bits */
-	ld1		{v17.4s-v18.4s}, [\rk], #32
-1111:	ld1		{v19.4s-v20.4s}, [\rk], #32
-2222:	ld1		{v21.4s-v24.4s}, [\rk], #64
-	ld1		{v25.4s-v28.4s}, [\rk], #64
-	ld1		{v29.4s-v31.4s}, [\rk]
+	KS0		.req	v8
+	KS1		.req	v9
+	KS2		.req	v10
+	KS3		.req	v11
+
+	INP0		.req	v21
+	INP1		.req	v22
+	INP2		.req	v23
+	INP3		.req	v24
+
+	K0		.req	v25
+	K1		.req	v26
+	K2		.req	v27
+	K3		.req	v28
+	K4		.req	v12
+	K5		.req	v13
+	K6		.req	v4
+	K7		.req	v5
+	K8		.req	v14
+	K9		.req	v15
+	KK		.req	v29
+	KL		.req	v30
+	KM		.req	v31
+
+	.macro		load_round_keys, rounds, rk, tmp
+	add		\tmp, \rk, #64
+	ld1		{K0.4s-K3.4s}, [\rk]
+	ld1		{K4.4s-K5.4s}, [\tmp]
+	add		\tmp, \rk, \rounds, lsl #4
+	sub		\tmp, \tmp, #32
+	ld1		{KK.4s-KM.4s}, [\tmp]
 	.endm
 
 	.macro		enc_round, state, key
@@ -379,197 +396,367 @@  ENDPROC(pmull_ghash_update_p8)
 	aesmc		\state\().16b, \state\().16b
 	.endm
 
-	.macro		enc_block, state, rounds
-	cmp		\rounds, #12
-	b.lo		2222f		/* 128 bits */
-	b.eq		1111f		/* 192 bits */
-	enc_round	\state, v17
-	enc_round	\state, v18
-1111:	enc_round	\state, v19
-	enc_round	\state, v20
-2222:	.irp		key, v21, v22, v23, v24, v25, v26, v27, v28, v29
+	.macro		enc_qround, s0, s1, s2, s3, key
+	enc_round	\s0, \key
+	enc_round	\s1, \key
+	enc_round	\s2, \key
+	enc_round	\s3, \key
+	.endm
+
+	.macro		enc_block, state, rounds, rk, tmp
+	add		\tmp, \rk, #96
+	ld1		{K6.4s-K7.4s}, [\tmp], #32
+	.irp		key, K0, K1, K2, K3, K4 K5
 	enc_round	\state, \key
 	.endr
-	aese		\state\().16b, v30.16b
-	eor		\state\().16b, \state\().16b, v31.16b
+
+	tbnz		\rounds, #2, .Lnot128_\@
+.Lout256_\@:
+	enc_round	\state, K6
+	enc_round	\state, K7
+
+.Lout192_\@:
+	enc_round	\state, KK
+	aese		\state\().16b, KL.16b
+	eor		\state\().16b, \state\().16b, KM.16b
+
+	.subsection	1
+.Lnot128_\@:
+	ld1		{K8.4s-K9.4s}, [\tmp], #32
+	enc_round	\state, K6
+	enc_round	\state, K7
+	ld1		{K6.4s-K7.4s}, [\tmp]
+	enc_round	\state, K8
+	enc_round	\state, K9
+	tbz		\rounds, #1, .Lout192_\@
+	b		.Lout256_\@
+	.previous
 	.endm
 
+	.align		6
 	.macro		pmull_gcm_do_crypt, enc
-	ld1		{SHASH.2d}, [x4], #16
-	ld1		{HH.2d}, [x4]
-	ld1		{XL.2d}, [x1]
-	ldr		x8, [x5, #8]			// load lower counter
+	stp		x29, x30, [sp, #-32]!
+	mov		x29, sp
+	str		x19, [sp, #24]
+
+	load_round_keys	x7, x6, x8
+
+	ld1		{SHASH.2d}, [x3], #16
+	ld1		{HH.2d-HH4.2d}, [x3]
 
-	movi		MASK.16b, #0xe1
 	trn1		SHASH2.2d, SHASH.2d, HH.2d
 	trn2		T1.2d, SHASH.2d, HH.2d
-CPU_LE(	rev		x8, x8		)
-	shl		MASK.2d, MASK.2d, #57
 	eor		SHASH2.16b, SHASH2.16b, T1.16b
 
-	.if		\enc == 1
-	ldr		x10, [sp]
-	ld1		{KS0.16b-KS1.16b}, [x10]
-	.endif
+	trn1		HH34.2d, HH3.2d, HH4.2d
+	trn2		T1.2d, HH3.2d, HH4.2d
+	eor		HH34.16b, HH34.16b, T1.16b
 
-	cbnz		x6, 4f
+	ld1		{XL.2d}, [x4]
 
-0:	ld1		{INP0.16b-INP1.16b}, [x3], #32
+	cbz		x0, 3f				// tag only?
 
-	rev		x9, x8
-	add		x11, x8, #1
-	add		x8, x8, #2
+	ldr		w8, [x5, #12]			// load lower counter
+CPU_LE(	rev		w8, w8		)
 
-	.if		\enc == 1
-	eor		INP0.16b, INP0.16b, KS0.16b	// encrypt input
-	eor		INP1.16b, INP1.16b, KS1.16b
+0:	mov		w9, #4				// max blocks per round
+	add		x10, x0, #0xf
+	lsr		x10, x10, #4			// remaining blocks
+
+	subs		x0, x0, #64
+	csel		w9, w10, w9, mi
+	add		w8, w8, w9
+
+	bmi		1f
+	ld1		{INP0.16b-INP3.16b}, [x2], #64
+	.subsection	1
+	/*
+	 * Populate the four input registers right to left with up to 63 bytes
+	 * of data, using overlapping loads to avoid branches.
+	 *
+	 *                INP0     INP1     INP2     INP3
+	 *  1 byte     |        |        |        |x       |
+	 * 16 bytes    |        |        |        |xxxxxxxx|
+	 * 17 bytes    |        |        |xxxxxxxx|x       |
+	 * 47 bytes    |        |xxxxxxxx|xxxxxxxx|xxxxxxx |
+	 * etc etc
+	 *
+	 * Note that this code may read up to 15 bytes before the start of
+	 * the input. It is up to the calling code to ensure this is safe if
+	 * this happens in the first iteration of the loop (i.e., when the
+	 * input size is < 16 bytes)
+	 */
+1:	mov		x15, #16
+	ands		x19, x0, #0xf
+	csel		x19, x19, x15, ne
+	adr_l		x17, .Lpermute_table + 16
+
+	sub		x11, x15, x19
+	add		x12, x17, x11
+	sub		x17, x17, x11
+	ld1		{T1.16b}, [x12]
+	sub		x10, x1, x11
+	sub		x11, x2, x11
+
+	cmp		x0, #-16
+	csel		x14, x15, xzr, gt
+	cmp		x0, #-32
+	csel		x15, x15, xzr, gt
+	cmp		x0, #-48
+	csel		x16, x19, xzr, gt
+	csel		x1, x1, x10, gt
+	csel		x2, x2, x11, gt
+
+	ld1		{INP0.16b}, [x2], x14
+	ld1		{INP1.16b}, [x2], x15
+	ld1		{INP2.16b}, [x2], x16
+	ld1		{INP3.16b}, [x2]
+	tbl		INP3.16b, {INP3.16b}, T1.16b
+	b		2f
+	.previous
+
+2:	.if		\enc == 0
+	bl		pmull_gcm_ghash_4x
 	.endif
 
-	ld1		{KS0.8b}, [x5]			// load upper counter
-	rev		x11, x11
-	sub		w0, w0, #2
-	mov		KS1.8b, KS0.8b
-	ins		KS0.d[1], x9			// set lower counter
-	ins		KS1.d[1], x11
+	bl		pmull_gcm_enc_4x
 
-	rev64		T1.16b, INP1.16b
+	tbnz		x0, #63, 6f
+	st1		{INP0.16b-INP3.16b}, [x1], #64
+	.if		\enc == 1
+	bl		pmull_gcm_ghash_4x
+	.endif
+	bne		0b
 
-	cmp		w7, #12
-	b.ge		2f				// AES-192/256?
+3:	ldp		x19, x10, [sp, #24]
+	cbz		x10, 5f				// output tag?
 
-1:	enc_round	KS0, v21
-	ext		IN1.16b, T1.16b, T1.16b, #8
+	ld1		{INP3.16b}, [x10]		// load lengths[]
+	mov		w9, #1
+	bl		pmull_gcm_ghash_4x
 
-	enc_round	KS1, v21
-	pmull2		XH2.1q, SHASH.2d, IN1.2d	// a1 * b1
+	mov		w11, #(0x1 << 24)		// BE '1U'
+	ld1		{KS0.16b}, [x5]
+	mov		KS0.s[3], w11
 
-	enc_round	KS0, v22
-	eor		T1.16b, T1.16b, IN1.16b
+	enc_block	KS0, x7, x6, x12
 
-	enc_round	KS1, v22
-	pmull		XL2.1q, SHASH.1d, IN1.1d	// a0 * b0
+	ext		XL.16b, XL.16b, XL.16b, #8
+	rev64		XL.16b, XL.16b
+	eor		XL.16b, XL.16b, KS0.16b
+	st1		{XL.16b}, [x10]			// store tag
 
-	enc_round	KS0, v23
-	pmull		XM2.1q, SHASH2.1d, T1.1d	// (a1 + a0)(b1 + b0)
+4:	ldp		x29, x30, [sp], #32
+	ret
 
-	enc_round	KS1, v23
-	rev64		T1.16b, INP0.16b
-	ext		T2.16b, XL.16b, XL.16b, #8
+5:
+CPU_LE(	rev		w8, w8		)
+	str		w8, [x5, #12]			// store lower counter
+	st1		{XL.2d}, [x4]
+	b		4b
+
+6:	ld1		{T1.16b-T2.16b}, [x17], #32	// permute vectors
+	sub		x17, x17, x19, lsl #1
+
+	cmp		w9, #1
+	beq		7f
+	.subsection	1
+7:	ld1		{INP2.16b}, [x1]
+	tbx		INP2.16b, {INP3.16b}, T1.16b
+	mov		INP3.16b, INP2.16b
+	b		8f
+	.previous
+
+	st1		{INP0.16b}, [x1], x14
+	st1		{INP1.16b}, [x1], x15
+	st1		{INP2.16b}, [x1], x16
+	tbl		INP3.16b, {INP3.16b}, T1.16b
+	tbx		INP3.16b, {INP2.16b}, T2.16b
+8:	st1		{INP3.16b}, [x1]
 
-	enc_round	KS0, v24
-	ext		IN1.16b, T1.16b, T1.16b, #8
-	eor		T1.16b, T1.16b, T2.16b
+	.if		\enc == 1
+	ld1		{T1.16b}, [x17]
+	tbl		INP3.16b, {INP3.16b}, T1.16b	// clear non-data bits
+	bl		pmull_gcm_ghash_4x
+	.endif
+	b		3b
+	.endm
 
-	enc_round	KS1, v24
-	eor		XL.16b, XL.16b, IN1.16b
+	/*
+	 * void pmull_gcm_encrypt(int blocks, u8 dst[], const u8 src[],
+	 *			  struct ghash_key const *k, u64 dg[], u8 ctr[],
+	 *			  int rounds, u8 tag)
+	 */
+ENTRY(pmull_gcm_encrypt)
+	pmull_gcm_do_crypt	1
+ENDPROC(pmull_gcm_encrypt)
 
-	enc_round	KS0, v25
-	eor		T1.16b, T1.16b, XL.16b
+	/*
+	 * void pmull_gcm_decrypt(int blocks, u8 dst[], const u8 src[],
+	 *			  struct ghash_key const *k, u64 dg[], u8 ctr[],
+	 *			  int rounds, u8 tag)
+	 */
+ENTRY(pmull_gcm_decrypt)
+	pmull_gcm_do_crypt	0
+ENDPROC(pmull_gcm_decrypt)
 
-	enc_round	KS1, v25
-	pmull2		XH.1q, HH.2d, XL.2d		// a1 * b1
+pmull_gcm_ghash_4x:
+	movi		MASK.16b, #0xe1
+	shl		MASK.2d, MASK.2d, #57
 
-	enc_round	KS0, v26
-	pmull		XL.1q, HH.1d, XL.1d		// a0 * b0
+	rev64		T1.16b, INP0.16b
+	rev64		T2.16b, INP1.16b
+	rev64		TT3.16b, INP2.16b
+	rev64		TT4.16b, INP3.16b
 
-	enc_round	KS1, v26
-	pmull2		XM.1q, SHASH2.2d, T1.2d		// (a1 + a0)(b1 + b0)
+	ext		XL.16b, XL.16b, XL.16b, #8
 
-	enc_round	KS0, v27
-	eor		XL.16b, XL.16b, XL2.16b
-	eor		XH.16b, XH.16b, XH2.16b
+	tbz		w9, #2, 0f			// <4 blocks?
+	.subsection	1
+0:	movi		XH2.16b, #0
+	movi		XM2.16b, #0
+	movi		XL2.16b, #0
 
-	enc_round	KS1, v27
-	eor		XM.16b, XM.16b, XM2.16b
-	ext		T1.16b, XL.16b, XH.16b, #8
+	tbz		w9, #0, 1f			// 2 blocks?
+	tbz		w9, #1, 2f			// 1 block?
 
-	enc_round	KS0, v28
-	eor		T2.16b, XL.16b, XH.16b
-	eor		XM.16b, XM.16b, T1.16b
+	eor		T2.16b, T2.16b, XL.16b
+	ext		T1.16b, T2.16b, T2.16b, #8
+	b		.Lgh3
 
-	enc_round	KS1, v28
-	eor		XM.16b, XM.16b, T2.16b
+1:	eor		TT3.16b, TT3.16b, XL.16b
+	ext		T2.16b, TT3.16b, TT3.16b, #8
+	b		.Lgh2
 
-	enc_round	KS0, v29
-	pmull		T2.1q, XL.1d, MASK.1d
+2:	eor		TT4.16b, TT4.16b, XL.16b
+	ext		IN1.16b, TT4.16b, TT4.16b, #8
+	b		.Lgh1
+	.previous
 
-	enc_round	KS1, v29
-	mov		XH.d[0], XM.d[1]
-	mov		XM.d[1], XL.d[0]
+	eor		T1.16b, T1.16b, XL.16b
+	ext		IN1.16b, T1.16b, T1.16b, #8
 
-	aese		KS0.16b, v30.16b
-	eor		XL.16b, XM.16b, T2.16b
+	pmull2		XH2.1q, HH4.2d, IN1.2d		// a1 * b1
+	eor		T1.16b, T1.16b, IN1.16b
+	pmull		XL2.1q, HH4.1d, IN1.1d		// a0 * b0
+	pmull2		XM2.1q, HH34.2d, T1.2d		// (a1 + a0)(b1 + b0)
 
-	aese		KS1.16b, v30.16b
-	ext		T2.16b, XL.16b, XL.16b, #8
+	ext		T1.16b, T2.16b, T2.16b, #8
+.Lgh3:	eor		T2.16b, T2.16b, T1.16b
+	pmull2		XH.1q, HH3.2d, T1.2d		// a1 * b1
+	pmull		XL.1q, HH3.1d, T1.1d		// a0 * b0
+	pmull		XM.1q, HH34.1d, T2.1d		// (a1 + a0)(b1 + b0)
 
-	eor		KS0.16b, KS0.16b, v31.16b
-	pmull		XL.1q, XL.1d, MASK.1d
-	eor		T2.16b, T2.16b, XH.16b
+	eor		XH2.16b, XH2.16b, XH.16b
+	eor		XL2.16b, XL2.16b, XL.16b
+	eor		XM2.16b, XM2.16b, XM.16b
 
-	eor		KS1.16b, KS1.16b, v31.16b
-	eor		XL.16b, XL.16b, T2.16b
+	ext		T2.16b, TT3.16b, TT3.16b, #8
+.Lgh2:	eor		TT3.16b, TT3.16b, T2.16b
+	pmull2		XH.1q, HH.2d, T2.2d		// a1 * b1
+	pmull		XL.1q, HH.1d, T2.1d		// a0 * b0
+	pmull2		XM.1q, SHASH2.2d, TT3.2d	// (a1 + a0)(b1 + b0)
 
-	.if		\enc == 0
-	eor		INP0.16b, INP0.16b, KS0.16b
-	eor		INP1.16b, INP1.16b, KS1.16b
-	.endif
+	eor		XH2.16b, XH2.16b, XH.16b
+	eor		XL2.16b, XL2.16b, XL.16b
+	eor		XM2.16b, XM2.16b, XM.16b
 
-	st1		{INP0.16b-INP1.16b}, [x2], #32
+	ext		IN1.16b, TT4.16b, TT4.16b, #8
+.Lgh1:	eor		TT4.16b, TT4.16b, IN1.16b
+	pmull		XL.1q, SHASH.1d, IN1.1d		// a0 * b0
+	pmull2		XH.1q, SHASH.2d, IN1.2d		// a1 * b1
+	pmull		XM.1q, SHASH2.1d, TT4.1d	// (a1 + a0)(b1 + b0)
 
-	cbnz		w0, 0b
+	eor		XH.16b, XH.16b, XH2.16b
+	eor		XL.16b, XL.16b, XL2.16b
+	eor		XM.16b, XM.16b, XM2.16b
 
-CPU_LE(	rev		x8, x8		)
-	st1		{XL.2d}, [x1]
-	str		x8, [x5, #8]			// store lower counter
+	eor		T2.16b, XL.16b, XH.16b
+	ext		T1.16b, XL.16b, XH.16b, #8
+	eor		XM.16b, XM.16b, T2.16b
 
-	.if		\enc == 1
-	st1		{KS0.16b-KS1.16b}, [x10]
-	.endif
+	__pmull_reduce_p64
+
+	eor		T2.16b, T2.16b, XH.16b
+	eor		XL.16b, XL.16b, T2.16b
 
 	ret
+ENDPROC(pmull_gcm_ghash_4x)
+
+pmull_gcm_enc_4x:
+	ld1		{KS0.16b}, [x5]			// load upper counter
+	sub		w10, w8, #4
+	sub		w11, w8, #3
+	sub		w12, w8, #2
+	sub		w13, w8, #1
+	rev		w10, w10
+	rev		w11, w11
+	rev		w12, w12
+	rev		w13, w13
+	mov		KS1.16b, KS0.16b
+	mov		KS2.16b, KS0.16b
+	mov		KS3.16b, KS0.16b
+	ins		KS0.s[3], w10			// set lower counter
+	ins		KS1.s[3], w11
+	ins		KS2.s[3], w12
+	ins		KS3.s[3], w13
+
+	add		x10, x6, #96			// round key pointer
+	ld1		{K6.4s-K7.4s}, [x10], #32
+	.irp		key, K0, K1, K2, K3, K4, K5
+	enc_qround	KS0, KS1, KS2, KS3, \key
+	.endr
 
-2:	b.eq		3f				// AES-192?
-	enc_round	KS0, v17
-	enc_round	KS1, v17
-	enc_round	KS0, v18
-	enc_round	KS1, v18
-3:	enc_round	KS0, v19
-	enc_round	KS1, v19
-	enc_round	KS0, v20
-	enc_round	KS1, v20
-	b		1b
+	tbnz		x7, #2, .Lnot128
+	.subsection	1
+.Lnot128:
+	ld1		{K8.4s-K9.4s}, [x10], #32
+	.irp		key, K6, K7
+	enc_qround	KS0, KS1, KS2, KS3, \key
+	.endr
+	ld1		{K6.4s-K7.4s}, [x10]
+	.irp		key, K8, K9
+	enc_qround	KS0, KS1, KS2, KS3, \key
+	.endr
+	tbz		x7, #1, .Lout192
+	b		.Lout256
+	.previous
 
-4:	load_round_keys	w7, x6
-	b		0b
-	.endm
+.Lout256:
+	.irp		key, K6, K7
+	enc_qround	KS0, KS1, KS2, KS3, \key
+	.endr
 
-	/*
-	 * void pmull_gcm_encrypt(int blocks, u64 dg[], u8 dst[], const u8 src[],
-	 *			  struct ghash_key const *k, u8 ctr[],
-	 *			  int rounds, u8 ks[])
-	 */
-ENTRY(pmull_gcm_encrypt)
-	pmull_gcm_do_crypt	1
-ENDPROC(pmull_gcm_encrypt)
+.Lout192:
+	enc_qround	KS0, KS1, KS2, KS3, KK
 
-	/*
-	 * void pmull_gcm_decrypt(int blocks, u64 dg[], u8 dst[], const u8 src[],
-	 *			  struct ghash_key const *k, u8 ctr[],
-	 *			  int rounds)
-	 */
-ENTRY(pmull_gcm_decrypt)
-	pmull_gcm_do_crypt	0
-ENDPROC(pmull_gcm_decrypt)
+	aese		KS0.16b, KL.16b
+	aese		KS1.16b, KL.16b
+	aese		KS2.16b, KL.16b
+	aese		KS3.16b, KL.16b
+
+	eor		KS0.16b, KS0.16b, KM.16b
+	eor		KS1.16b, KS1.16b, KM.16b
+	eor		KS2.16b, KS2.16b, KM.16b
+	eor		KS3.16b, KS3.16b, KM.16b
+
+	eor		INP0.16b, INP0.16b, KS0.16b
+	eor		INP1.16b, INP1.16b, KS1.16b
+	eor		INP2.16b, INP2.16b, KS2.16b
+	eor		INP3.16b, INP3.16b, KS3.16b
 
-	/*
-	 * void pmull_gcm_encrypt_block(u8 dst[], u8 src[], u8 rk[], int rounds)
-	 */
-ENTRY(pmull_gcm_encrypt_block)
-	cbz		x2, 0f
-	load_round_keys	w3, x2
-0:	ld1		{v0.16b}, [x1]
-	enc_block	v0, w3
-	st1		{v0.16b}, [x0]
 	ret
-ENDPROC(pmull_gcm_encrypt_block)
+ENDPROC(pmull_gcm_enc_4x)
+
+	.section	".rodata", "a"
+	.align		6
+.Lpermute_table:
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+	.byte		 0x0,  0x1,  0x2,  0x3,  0x4,  0x5,  0x6,  0x7
+	.byte		 0x8,  0x9,  0xa,  0xb,  0xc,  0xd,  0xe,  0xf
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
+	.byte		 0x0,  0x1,  0x2,  0x3,  0x4,  0x5,  0x6,  0x7
+	.byte		 0x8,  0x9,  0xa,  0xb,  0xc,  0xd,  0xe,  0xf
+	.previous
diff --git a/arch/arm64/crypto/ghash-ce-glue.c b/arch/arm64/crypto/ghash-ce-glue.c
index 70b1469783f9..522cf004ce65 100644
--- a/arch/arm64/crypto/ghash-ce-glue.c
+++ b/arch/arm64/crypto/ghash-ce-glue.c
@@ -58,17 +58,15 @@  asmlinkage void pmull_ghash_update_p8(int blocks, u64 dg[], const char *src,
 				      struct ghash_key const *k,
 				      const char *head);
 
-asmlinkage void pmull_gcm_encrypt(int blocks, u64 dg[], u8 dst[],
-				  const u8 src[], struct ghash_key const *k,
+asmlinkage void pmull_gcm_encrypt(int bytes, u8 dst[], const u8 src[],
+				  struct ghash_key const *k, u64 dg[],
 				  u8 ctr[], u32 const rk[], int rounds,
-				  u8 ks[]);
+				  u8 tag[]);
 
-asmlinkage void pmull_gcm_decrypt(int blocks, u64 dg[], u8 dst[],
-				  const u8 src[], struct ghash_key const *k,
-				  u8 ctr[], u32 const rk[], int rounds);
-
-asmlinkage void pmull_gcm_encrypt_block(u8 dst[], u8 const src[],
-					u32 const rk[], int rounds);
+asmlinkage void pmull_gcm_decrypt(int bytes, u8 dst[], const u8 src[],
+				  struct ghash_key const *k, u64 dg[],
+				  u8 ctr[], u32 const rk[], int rounds,
+				  u8 tag[]);
 
 static int ghash_init(struct shash_desc *desc)
 {
@@ -85,7 +83,7 @@  static void ghash_do_update(int blocks, u64 dg[], const char *src,
 						struct ghash_key const *k,
 						const char *head))
 {
-	if (likely(crypto_simd_usable())) {
+	if (likely(crypto_simd_usable() && simd_update)) {
 		kernel_neon_begin();
 		simd_update(blocks, dg, src, key, head);
 		kernel_neon_end();
@@ -398,136 +396,112 @@  static void gcm_calculate_auth_mac(struct aead_request *req, u64 dg[])
 	}
 }
 
-static void gcm_final(struct aead_request *req, struct gcm_aes_ctx *ctx,
-		      u64 dg[], u8 tag[], int cryptlen)
-{
-	u8 mac[AES_BLOCK_SIZE];
-	u128 lengths;
-
-	lengths.a = cpu_to_be64(req->assoclen * 8);
-	lengths.b = cpu_to_be64(cryptlen * 8);
-
-	ghash_do_update(1, dg, (void *)&lengths, &ctx->ghash_key, NULL,
-			pmull_ghash_update_p64);
-
-	put_unaligned_be64(dg[1], mac);
-	put_unaligned_be64(dg[0], mac + 8);
-
-	crypto_xor(tag, mac, AES_BLOCK_SIZE);
-}
-
 static int gcm_encrypt(struct aead_request *req)
 {
 	struct crypto_aead *aead = crypto_aead_reqtfm(req);
 	struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead);
+	int nrounds = num_rounds(&ctx->aes_key);
 	struct skcipher_walk walk;
+	u8 buf[AES_BLOCK_SIZE];
 	u8 iv[AES_BLOCK_SIZE];
-	u8 ks[2 * AES_BLOCK_SIZE];
-	u8 tag[AES_BLOCK_SIZE];
 	u64 dg[2] = {};
-	int nrounds = num_rounds(&ctx->aes_key);
+	u128 lengths;
+	u8 *tag;
 	int err;
 
+	lengths.a = cpu_to_be64(req->assoclen * 8);
+	lengths.b = cpu_to_be64(req->cryptlen * 8);
+
 	if (req->assoclen)
 		gcm_calculate_auth_mac(req, dg);
 
 	memcpy(iv, req->iv, GCM_IV_SIZE);
-	put_unaligned_be32(1, iv + GCM_IV_SIZE);
+	put_unaligned_be32(2, iv + GCM_IV_SIZE);
 
 	err = skcipher_walk_aead_encrypt(&walk, req, false);
 
-	if (likely(crypto_simd_usable() && walk.total >= 2 * AES_BLOCK_SIZE)) {
-		u32 const *rk = NULL;
-
-		kernel_neon_begin();
-		pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc, nrounds);
-		put_unaligned_be32(2, iv + GCM_IV_SIZE);
-		pmull_gcm_encrypt_block(ks, iv, NULL, nrounds);
-		put_unaligned_be32(3, iv + GCM_IV_SIZE);
-		pmull_gcm_encrypt_block(ks + AES_BLOCK_SIZE, iv, NULL, nrounds);
-		put_unaligned_be32(4, iv + GCM_IV_SIZE);
-
+	if (likely(crypto_simd_usable())) {
 		do {
-			int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2;
+			const u8 *src = walk.src.virt.addr;
+			u8 *dst = walk.dst.virt.addr;
+			int nbytes = walk.nbytes;
+
+			tag = (u8 *)&lengths;
 
-			if (rk)
-				kernel_neon_begin();
+			if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) {
+				src = dst = memcpy(buf + sizeof(buf) - nbytes,
+						   src, nbytes);
+			} else if (nbytes < walk.total) {
+				nbytes &= ~(AES_BLOCK_SIZE - 1);
+				tag = NULL;
+			}
 
-			pmull_gcm_encrypt(blocks, dg, walk.dst.virt.addr,
-					  walk.src.virt.addr, &ctx->ghash_key,
-					  iv, rk, nrounds, ks);
+			kernel_neon_begin();
+			pmull_gcm_encrypt(nbytes, dst, src, &ctx->ghash_key, dg,
+					  iv, ctx->aes_key.key_enc, nrounds,
+					  tag);
 			kernel_neon_end();
 
-			err = skcipher_walk_done(&walk,
-					walk.nbytes % (2 * AES_BLOCK_SIZE));
+			if (unlikely(!nbytes))
+				break;
 
-			rk = ctx->aes_key.key_enc;
-		} while (walk.nbytes >= 2 * AES_BLOCK_SIZE);
-	} else {
-		aes_encrypt(&ctx->aes_key, tag, iv);
-		put_unaligned_be32(2, iv + GCM_IV_SIZE);
+			if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE))
+				memcpy(walk.dst.virt.addr,
+				       buf + sizeof(buf) - nbytes, nbytes);
 
-		while (walk.nbytes >= (2 * AES_BLOCK_SIZE)) {
-			const int blocks =
-				walk.nbytes / (2 * AES_BLOCK_SIZE) * 2;
+			err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
+		} while (walk.nbytes);
+	} else {
+		while (walk.nbytes >= AES_BLOCK_SIZE) {
+			int blocks = walk.nbytes / AES_BLOCK_SIZE;
+			const u8 *src = walk.src.virt.addr;
 			u8 *dst = walk.dst.virt.addr;
-			u8 *src = walk.src.virt.addr;
 			int remaining = blocks;
 
 			do {
-				aes_encrypt(&ctx->aes_key, ks, iv);
-				crypto_xor_cpy(dst, src, ks, AES_BLOCK_SIZE);
+				aes_encrypt(&ctx->aes_key, buf, iv);
+				crypto_xor_cpy(dst, src, buf, AES_BLOCK_SIZE);
 				crypto_inc(iv, AES_BLOCK_SIZE);
 
 				dst += AES_BLOCK_SIZE;
 				src += AES_BLOCK_SIZE;
 			} while (--remaining > 0);
 
-			ghash_do_update(blocks, dg,
-					walk.dst.virt.addr, &ctx->ghash_key,
-					NULL, pmull_ghash_update_p64);
+			ghash_do_update(blocks, dg, walk.dst.virt.addr,
+					&ctx->ghash_key, NULL, NULL);
 
 			err = skcipher_walk_done(&walk,
-						 walk.nbytes % (2 * AES_BLOCK_SIZE));
-		}
-		if (walk.nbytes) {
-			aes_encrypt(&ctx->aes_key, ks, iv);
-			if (walk.nbytes > AES_BLOCK_SIZE) {
-				crypto_inc(iv, AES_BLOCK_SIZE);
-				aes_encrypt(&ctx->aes_key, ks + AES_BLOCK_SIZE, iv);
-			}
+						 walk.nbytes % AES_BLOCK_SIZE);
 		}
-	}
 
-	/* handle the tail */
-	if (walk.nbytes) {
-		u8 buf[GHASH_BLOCK_SIZE];
-		unsigned int nbytes = walk.nbytes;
-		u8 *dst = walk.dst.virt.addr;
-		u8 *head = NULL;
+		/* handle the tail */
+		if (walk.nbytes) {
+			aes_encrypt(&ctx->aes_key, buf, iv);
 
-		crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, ks,
-			       walk.nbytes);
+			crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr,
+				       buf, walk.nbytes);
 
-		if (walk.nbytes > GHASH_BLOCK_SIZE) {
-			head = dst;
-			dst += GHASH_BLOCK_SIZE;
-			nbytes %= GHASH_BLOCK_SIZE;
+			memcpy(buf, walk.dst.virt.addr, walk.nbytes);
+			memset(buf + walk.nbytes, 0, sizeof(buf) - walk.nbytes);
 		}
 
-		memcpy(buf, dst, nbytes);
-		memset(buf + nbytes, 0, GHASH_BLOCK_SIZE - nbytes);
-		ghash_do_update(!!nbytes, dg, buf, &ctx->ghash_key, head,
-				pmull_ghash_update_p64);
+		tag = (u8 *)&lengths;
+		ghash_do_update(1, dg, tag, &ctx->ghash_key,
+				walk.nbytes ? buf : NULL, NULL);
 
-		err = skcipher_walk_done(&walk, 0);
+		if (walk.nbytes)
+			err = skcipher_walk_done(&walk, 0);
+
+		put_unaligned_be64(dg[1], tag);
+		put_unaligned_be64(dg[0], tag + 8);
+		put_unaligned_be32(1, iv + GCM_IV_SIZE);
+		aes_encrypt(&ctx->aes_key, iv, iv);
+		crypto_xor(tag, iv, AES_BLOCK_SIZE);
 	}
 
 	if (err)
 		return err;
 
-	gcm_final(req, ctx, dg, tag, req->cryptlen);
-
 	/* copy authtag to end of dst */
 	scatterwalk_map_and_copy(tag, req->dst, req->assoclen + req->cryptlen,
 				 crypto_aead_authsize(aead), 1);
@@ -540,75 +514,65 @@  static int gcm_decrypt(struct aead_request *req)
 	struct crypto_aead *aead = crypto_aead_reqtfm(req);
 	struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead);
 	unsigned int authsize = crypto_aead_authsize(aead);
+	int nrounds = num_rounds(&ctx->aes_key);
 	struct skcipher_walk walk;
-	u8 iv[2 * AES_BLOCK_SIZE];
-	u8 tag[AES_BLOCK_SIZE];
-	u8 buf[2 * GHASH_BLOCK_SIZE];
+	u8 buf[AES_BLOCK_SIZE];
+	u8 iv[AES_BLOCK_SIZE];
 	u64 dg[2] = {};
-	int nrounds = num_rounds(&ctx->aes_key);
+	u128 lengths;
+	u8 *tag;
 	int err;
 
+	lengths.a = cpu_to_be64(req->assoclen * 8);
+	lengths.b = cpu_to_be64((req->cryptlen - authsize) * 8);
+
 	if (req->assoclen)
 		gcm_calculate_auth_mac(req, dg);
 
 	memcpy(iv, req->iv, GCM_IV_SIZE);
-	put_unaligned_be32(1, iv + GCM_IV_SIZE);
+	put_unaligned_be32(2, iv + GCM_IV_SIZE);
 
 	err = skcipher_walk_aead_decrypt(&walk, req, false);
 
-	if (likely(crypto_simd_usable() && walk.total >= 2 * AES_BLOCK_SIZE)) {
-		u32 const *rk = NULL;
-
-		kernel_neon_begin();
-		pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc, nrounds);
-		put_unaligned_be32(2, iv + GCM_IV_SIZE);
-
+	if (likely(crypto_simd_usable())) {
 		do {
-			int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2;
-			int rem = walk.total - blocks * AES_BLOCK_SIZE;
-
-			if (rk)
-				kernel_neon_begin();
-
-			pmull_gcm_decrypt(blocks, dg, walk.dst.virt.addr,
-					  walk.src.virt.addr, &ctx->ghash_key,
-					  iv, rk, nrounds);
-
-			/* check if this is the final iteration of the loop */
-			if (rem < (2 * AES_BLOCK_SIZE)) {
-				u8 *iv2 = iv + AES_BLOCK_SIZE;
-
-				if (rem > AES_BLOCK_SIZE) {
-					memcpy(iv2, iv, AES_BLOCK_SIZE);
-					crypto_inc(iv2, AES_BLOCK_SIZE);
-				}
+			const u8 *src = walk.src.virt.addr;
+			u8 *dst = walk.dst.virt.addr;
+			int nbytes = walk.nbytes;
 
-				pmull_gcm_encrypt_block(iv, iv, NULL, nrounds);
+			tag = (u8 *)&lengths;
 
-				if (rem > AES_BLOCK_SIZE)
-					pmull_gcm_encrypt_block(iv2, iv2, NULL,
-								nrounds);
+			if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE)) {
+				src = dst = memcpy(buf + sizeof(buf) - nbytes,
+						   src, nbytes);
+			} else if (nbytes < walk.total) {
+				nbytes &= ~(AES_BLOCK_SIZE - 1);
+				tag = NULL;
 			}
 
+			kernel_neon_begin();
+			pmull_gcm_decrypt(nbytes, dst, src, &ctx->ghash_key, dg,
+					  iv, ctx->aes_key.key_enc, nrounds,
+					  tag);
 			kernel_neon_end();
 
-			err = skcipher_walk_done(&walk,
-					walk.nbytes % (2 * AES_BLOCK_SIZE));
+			if (unlikely(!nbytes))
+				break;
 
-			rk = ctx->aes_key.key_enc;
-		} while (walk.nbytes >= 2 * AES_BLOCK_SIZE);
-	} else {
-		aes_encrypt(&ctx->aes_key, tag, iv);
-		put_unaligned_be32(2, iv + GCM_IV_SIZE);
+			if (unlikely(nbytes > 0 && nbytes < AES_BLOCK_SIZE))
+				memcpy(walk.dst.virt.addr,
+				       buf + sizeof(buf) - nbytes, nbytes);
 
-		while (walk.nbytes >= (2 * AES_BLOCK_SIZE)) {
-			int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2;
+			err = skcipher_walk_done(&walk, walk.nbytes - nbytes);
+		} while (walk.nbytes);
+	} else {
+		while (walk.nbytes >= AES_BLOCK_SIZE) {
+			int blocks = walk.nbytes / AES_BLOCK_SIZE;
+			const u8 *src = walk.src.virt.addr;
 			u8 *dst = walk.dst.virt.addr;
-			u8 *src = walk.src.virt.addr;
 
 			ghash_do_update(blocks, dg, walk.src.virt.addr,
-					&ctx->ghash_key, NULL,
-					pmull_ghash_update_p64);
+					&ctx->ghash_key, NULL, NULL);
 
 			do {
 				aes_encrypt(&ctx->aes_key, buf, iv);
@@ -620,49 +584,38 @@  static int gcm_decrypt(struct aead_request *req)
 			} while (--blocks > 0);
 
 			err = skcipher_walk_done(&walk,
-						 walk.nbytes % (2 * AES_BLOCK_SIZE));
+						 walk.nbytes % AES_BLOCK_SIZE);
 		}
-		if (walk.nbytes) {
-			if (walk.nbytes > AES_BLOCK_SIZE) {
-				u8 *iv2 = iv + AES_BLOCK_SIZE;
-
-				memcpy(iv2, iv, AES_BLOCK_SIZE);
-				crypto_inc(iv2, AES_BLOCK_SIZE);
 
-				aes_encrypt(&ctx->aes_key, iv2, iv2);
-			}
-			aes_encrypt(&ctx->aes_key, iv, iv);
+		/* handle the tail */
+		if (walk.nbytes) {
+			memcpy(buf, walk.src.virt.addr, walk.nbytes);
+			memset(buf + walk.nbytes, 0, sizeof(buf) - walk.nbytes);
 		}
-	}
 
-	/* handle the tail */
-	if (walk.nbytes) {
-		const u8 *src = walk.src.virt.addr;
-		const u8 *head = NULL;
-		unsigned int nbytes = walk.nbytes;
+		tag = (u8 *)&lengths;
+		ghash_do_update(1, dg, tag, &ctx->ghash_key,
+				walk.nbytes ? buf : NULL, NULL);
 
-		if (walk.nbytes > GHASH_BLOCK_SIZE) {
-			head = src;
-			src += GHASH_BLOCK_SIZE;
-			nbytes %= GHASH_BLOCK_SIZE;
-		}
+		if (walk.nbytes) {
+			aes_encrypt(&ctx->aes_key, buf, iv);
 
-		memcpy(buf, src, nbytes);
-		memset(buf + nbytes, 0, GHASH_BLOCK_SIZE - nbytes);
-		ghash_do_update(!!nbytes, dg, buf, &ctx->ghash_key, head,
-				pmull_ghash_update_p64);
+			crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr,
+				       buf, walk.nbytes);
 
-		crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, iv,
-			       walk.nbytes);
+			err = skcipher_walk_done(&walk, 0);
+		}
 
-		err = skcipher_walk_done(&walk, 0);
+		put_unaligned_be64(dg[1], tag);
+		put_unaligned_be64(dg[0], tag + 8);
+		put_unaligned_be32(1, iv + GCM_IV_SIZE);
+		aes_encrypt(&ctx->aes_key, iv, iv);
+		crypto_xor(tag, iv, AES_BLOCK_SIZE);
 	}
 
 	if (err)
 		return err;
 
-	gcm_final(req, ctx, dg, tag, req->cryptlen - authsize);
-
 	/* compare calculated auth tag with the stored one */
 	scatterwalk_map_and_copy(buf, req->src,
 				 req->assoclen + req->cryptlen - authsize,
@@ -675,7 +628,7 @@  static int gcm_decrypt(struct aead_request *req)
 
 static struct aead_alg gcm_aes_alg = {
 	.ivsize			= GCM_IV_SIZE,
-	.chunksize		= 2 * AES_BLOCK_SIZE,
+	.chunksize		= AES_BLOCK_SIZE,
 	.maxauthsize		= AES_BLOCK_SIZE,
 	.setkey			= gcm_setkey,
 	.setauthsize		= gcm_setauthsize,