diff mbox series

[v3,4/8] crypto: x86/sha256-ni - add support for finup_mb

Message ID 20240507002343.239552-5-ebiggers@kernel.org
State Superseded
Headers show
Series Optimize dm-verity and fsverity using multibuffer hashing | expand

Commit Message

Eric Biggers May 7, 2024, 12:23 a.m. UTC
From: Eric Biggers <ebiggers@google.com>

Add an implementation of finup_mb to sha256-ni, using an interleaving
factor of 2.  It interleaves a finup operation for two equal-length
messages that share a common prefix.  dm-verity and fs-verity will take
advantage of this for greatly improved performance on capable CPUs.

This increases the throughput of SHA-256 hashing 4096-byte messages by
the following amounts on the following CPUs:

    AMD Zen 1:                  84%
    AMD Zen 4:                  98%
    Intel Ice Lake:              4%
    Intel Sapphire Rapids:      20%

For now, this seems to benefit AMD much more than Intel.  This seems to
be because current AMD CPUs support concurrent execution of the SHA-NI
instructions, but unfortunately current Intel CPUs don't, except for the
sha256msg2 instruction.  Hopefully future Intel CPUs will support SHA-NI
on more execution ports.  Zen 1 supports 2 concurrent sha256rnds2, and
Zen 4 supports 4 concurrent sha256rnds2, which suggests that even better
performance may be achievable on Zen 4 by interleaving more than two
hashes; however, doing so poses a number of trade-offs.

It's been reported that the method that achieves the highest SHA-256
throughput on Intel CPUs is actually computing 16 hashes simultaneously
using AVX512.  That method would be quite different to the SHA-NI method
used in this patch.  However, such a high interleaving factor isn't
practical for the use cases being targeted in the kernel.

Signed-off-by: Eric Biggers <ebiggers@google.com>
---
 arch/x86/crypto/sha256_ni_asm.S     | 368 ++++++++++++++++++++++++++++
 arch/x86/crypto/sha256_ssse3_glue.c |  39 +++
 2 files changed, 407 insertions(+)
diff mbox series

Patch

diff --git a/arch/x86/crypto/sha256_ni_asm.S b/arch/x86/crypto/sha256_ni_asm.S
index d515a55a3bc1..5e97922a24e4 100644
--- a/arch/x86/crypto/sha256_ni_asm.S
+++ b/arch/x86/crypto/sha256_ni_asm.S
@@ -172,10 +172,378 @@  SYM_TYPED_FUNC_START(sha256_ni_transform)
 .Ldone_hash:
 
 	RET
 SYM_FUNC_END(sha256_ni_transform)
 
+#undef DIGEST_PTR
+#undef DATA_PTR
+#undef NUM_BLKS
+#undef SHA256CONSTANTS
+#undef MSG
+#undef STATE0
+#undef STATE1
+#undef MSG0
+#undef MSG1
+#undef MSG2
+#undef MSG3
+#undef TMP
+#undef SHUF_MASK
+#undef ABEF_SAVE
+#undef CDGH_SAVE
+
+// parameters for __sha256_ni_finup2x()
+#define SCTX		%rdi
+#define DATA1		%rsi
+#define DATA2		%rdx
+#define LEN		%ecx
+#define LEN8		%cl
+#define LEN64		%rcx
+#define OUT1		%r8
+#define OUT2		%r9
+
+// other scalar variables
+#define SHA256CONSTANTS	%rax
+#define COUNT		%r10
+#define COUNT32		%r10d
+#define FINAL_STEP	%r11d
+
+// rbx is used as a temporary.
+
+#define MSG		%xmm0	// sha256rnds2 implicit operand
+#define STATE0_A	%xmm1
+#define STATE1_A	%xmm2
+#define STATE0_B	%xmm3
+#define STATE1_B	%xmm4
+#define TMP_A		%xmm5
+#define TMP_B		%xmm6
+#define MSG0_A		%xmm7
+#define MSG1_A		%xmm8
+#define MSG2_A		%xmm9
+#define MSG3_A		%xmm10
+#define MSG0_B		%xmm11
+#define MSG1_B		%xmm12
+#define MSG2_B		%xmm13
+#define MSG3_B		%xmm14
+#define SHUF_MASK	%xmm15
+
+#define OFFSETOF_STATE	0	// offsetof(struct sha256_state, state)
+#define OFFSETOF_COUNT	32	// offsetof(struct sha256_state, count)
+#define OFFSETOF_BUF	40	// offsetof(struct sha256_state, buf)
+
+// Do 4 rounds of SHA-256 for each of two messages (interleaved).  m0_a and m0_b
+// contain the current 4 message schedule words for the first and second message
+// respectively.
+//
+// If not all the message schedule words have been computed yet, then this also
+// computes 4 more message schedule words for each message.  m1_a-m3_a contain
+// the next 3 groups of 4 message schedule words for the first message, and
+// likewise m1_b-m3_b for the second.  After consuming the current value of
+// m0_a, this macro computes the group after m3_a and writes it to m0_a, and
+// likewise for *_b.  This means that the next (m0_a, m1_a, m2_a, m3_a) is the
+// current (m1_a, m2_a, m3_a, m0_a), and likewise for *_b, so the caller must
+// cycle through the registers accordingly.
+.macro	do_4rounds_2x	i, m0_a, m1_a, m2_a, m3_a,  m0_b, m1_b, m2_b, m3_b
+	movdqa		(\i-32)*4(SHA256CONSTANTS), TMP_A
+	movdqa		TMP_A, TMP_B
+	paddd		\m0_a, TMP_A
+	paddd		\m0_b, TMP_B
+.if \i < 48
+	sha256msg1	\m1_a, \m0_a
+	sha256msg1	\m1_b, \m0_b
+.endif
+	movdqa		TMP_A, MSG
+	sha256rnds2	STATE0_A, STATE1_A
+	movdqa		TMP_B, MSG
+	sha256rnds2	STATE0_B, STATE1_B
+	pshufd 		$0x0E, TMP_A, MSG
+	sha256rnds2	STATE1_A, STATE0_A
+	pshufd 		$0x0E, TMP_B, MSG
+	sha256rnds2	STATE1_B, STATE0_B
+.if \i < 48
+	movdqa		\m3_a, TMP_A
+	movdqa		\m3_b, TMP_B
+	palignr		$4, \m2_a, TMP_A
+	palignr		$4, \m2_b, TMP_B
+	paddd		TMP_A, \m0_a
+	paddd		TMP_B, \m0_b
+	sha256msg2	\m3_a, \m0_a
+	sha256msg2	\m3_b, \m0_b
+.endif
+.endm
+
+//
+// void __sha256_ni_finup2x(const struct sha256_state *sctx,
+//			    const u8 *data1, const u8 *data2, int len,
+//			    u8 out1[SHA256_DIGEST_SIZE],
+//			    u8 out2[SHA256_DIGEST_SIZE]);
+//
+// This function computes the SHA-256 digests of two messages |data1| and
+// |data2| that are both |len| bytes long, starting from the initial state
+// |sctx|.  |len| must be at least SHA256_BLOCK_SIZE.
+//
+// The instructions for the two SHA-256 operations are interleaved.  On many
+// CPUs, this is almost twice as fast as hashing each message individually due
+// to taking better advantage of the CPU's SHA-256 and SIMD throughput.
+//
+SYM_FUNC_START(__sha256_ni_finup2x)
+	// Allocate 128 bytes of stack space, 16-byte aligned.
+	push		%rbx
+	push		%rbp
+	mov		%rsp, %rbp
+	sub		$128, %rsp
+	and		$~15, %rsp
+
+	// Load the shuffle mask for swapping the endianness of 32-bit words.
+	movdqa		PSHUFFLE_BYTE_FLIP_MASK(%rip), SHUF_MASK
+
+	// Set up pointer to the round constants.
+	lea		K256+32*4(%rip), SHA256CONSTANTS
+
+	// Initially we're not processing the final blocks.
+	xor		FINAL_STEP, FINAL_STEP
+
+	// Load the initial state from sctx->state.
+	movdqu		OFFSETOF_STATE+0*16(SCTX), STATE0_A	// DCBA
+	movdqu		OFFSETOF_STATE+1*16(SCTX), STATE1_A	// HGFE
+	movdqa		STATE0_A, TMP_A
+	punpcklqdq	STATE1_A, STATE0_A			// FEBA
+	punpckhqdq	TMP_A, STATE1_A				// DCHG
+	pshufd		$0x1B, STATE0_A, STATE0_A		// ABEF
+	pshufd		$0xB1, STATE1_A, STATE1_A		// CDGH
+
+	// Load sctx->count.  Take the mod 64 of it to get the number of bytes
+	// that are buffered in sctx->buf.  Also save it in a register with LEN
+	// added to it.
+	mov		LEN, LEN
+	mov		OFFSETOF_COUNT(SCTX), %rbx
+	lea		(%rbx, LEN64, 1), COUNT
+	and		$63, %ebx
+	jz		.Lfinup2x_enter_loop	// No bytes buffered?
+
+	// %ebx bytes (1 to 63) are currently buffered in sctx->buf.  Load them
+	// followed by the first 64 - %ebx bytes of data.  Since LEN >= 64, we
+	// just load 64 bytes from each of sctx->buf, DATA1, and DATA2
+	// unconditionally and rearrange the data as needed.
+
+	movdqu		OFFSETOF_BUF+0*16(SCTX), MSG0_A
+	movdqu		OFFSETOF_BUF+1*16(SCTX), MSG1_A
+	movdqu		OFFSETOF_BUF+2*16(SCTX), MSG2_A
+	movdqu		OFFSETOF_BUF+3*16(SCTX), MSG3_A
+	movdqa		MSG0_A, 0*16(%rsp)
+	movdqa		MSG1_A, 1*16(%rsp)
+	movdqa		MSG2_A, 2*16(%rsp)
+	movdqa		MSG3_A, 3*16(%rsp)
+
+	movdqu		0*16(DATA1), MSG0_A
+	movdqu		1*16(DATA1), MSG1_A
+	movdqu		2*16(DATA1), MSG2_A
+	movdqu		3*16(DATA1), MSG3_A
+	movdqu		MSG0_A, 0*16(%rsp,%rbx)
+	movdqu		MSG1_A, 1*16(%rsp,%rbx)
+	movdqu		MSG2_A, 2*16(%rsp,%rbx)
+	movdqu		MSG3_A, 3*16(%rsp,%rbx)
+	movdqa		0*16(%rsp), MSG0_A
+	movdqa		1*16(%rsp), MSG1_A
+	movdqa		2*16(%rsp), MSG2_A
+	movdqa		3*16(%rsp), MSG3_A
+
+	movdqu		0*16(DATA2), MSG0_B
+	movdqu		1*16(DATA2), MSG1_B
+	movdqu		2*16(DATA2), MSG2_B
+	movdqu		3*16(DATA2), MSG3_B
+	movdqu		MSG0_B, 0*16(%rsp,%rbx)
+	movdqu		MSG1_B, 1*16(%rsp,%rbx)
+	movdqu		MSG2_B, 2*16(%rsp,%rbx)
+	movdqu		MSG3_B, 3*16(%rsp,%rbx)
+	movdqa		0*16(%rsp), MSG0_B
+	movdqa		1*16(%rsp), MSG1_B
+	movdqa		2*16(%rsp), MSG2_B
+	movdqa		3*16(%rsp), MSG3_B
+
+	sub		$64, %rbx 	// rbx = buffered - 64
+	sub		%rbx, DATA1	// DATA1 += 64 - buffered
+	sub		%rbx, DATA2	// DATA2 += 64 - buffered
+	add		%ebx, LEN	// LEN += buffered - 64
+	movdqa		STATE0_A, STATE0_B
+	movdqa		STATE1_A, STATE1_B
+	jmp		.Lfinup2x_loop_have_data
+
+.Lfinup2x_enter_loop:
+	sub		$64, LEN
+	movdqa		STATE0_A, STATE0_B
+	movdqa		STATE1_A, STATE1_B
+.Lfinup2x_loop:
+	// Load the next two data blocks.
+	movdqu		0*16(DATA1), MSG0_A
+	movdqu		0*16(DATA2), MSG0_B
+	movdqu		1*16(DATA1), MSG1_A
+	movdqu		1*16(DATA2), MSG1_B
+	movdqu		2*16(DATA1), MSG2_A
+	movdqu		2*16(DATA2), MSG2_B
+	movdqu		3*16(DATA1), MSG3_A
+	movdqu		3*16(DATA2), MSG3_B
+	add		$64, DATA1
+	add		$64, DATA2
+.Lfinup2x_loop_have_data:
+	// Convert the words of the data blocks from big endian.
+	pshufb		SHUF_MASK, MSG0_A
+	pshufb		SHUF_MASK, MSG0_B
+	pshufb		SHUF_MASK, MSG1_A
+	pshufb		SHUF_MASK, MSG1_B
+	pshufb		SHUF_MASK, MSG2_A
+	pshufb		SHUF_MASK, MSG2_B
+	pshufb		SHUF_MASK, MSG3_A
+	pshufb		SHUF_MASK, MSG3_B
+.Lfinup2x_loop_have_bswapped_data:
+
+	// Save the original state for each block.
+	movdqa		STATE0_A, 0*16(%rsp)
+	movdqa		STATE0_B, 1*16(%rsp)
+	movdqa		STATE1_A, 2*16(%rsp)
+	movdqa		STATE1_B, 3*16(%rsp)
+
+	// Do the SHA-256 rounds on each block.
+.irp i, 0, 16, 32, 48
+	do_4rounds_2x	(\i + 0),  MSG0_A, MSG1_A, MSG2_A, MSG3_A, \
+				   MSG0_B, MSG1_B, MSG2_B, MSG3_B
+	do_4rounds_2x	(\i + 4),  MSG1_A, MSG2_A, MSG3_A, MSG0_A, \
+				   MSG1_B, MSG2_B, MSG3_B, MSG0_B
+	do_4rounds_2x	(\i + 8),  MSG2_A, MSG3_A, MSG0_A, MSG1_A, \
+				   MSG2_B, MSG3_B, MSG0_B, MSG1_B
+	do_4rounds_2x	(\i + 12), MSG3_A, MSG0_A, MSG1_A, MSG2_A, \
+				   MSG3_B, MSG0_B, MSG1_B, MSG2_B
+.endr
+
+	// Add the original state for each block.
+	paddd		0*16(%rsp), STATE0_A
+	paddd		1*16(%rsp), STATE0_B
+	paddd		2*16(%rsp), STATE1_A
+	paddd		3*16(%rsp), STATE1_B
+
+	// Update LEN and loop back if more blocks remain.
+	sub		$64, LEN
+	jge		.Lfinup2x_loop
+
+	// Check if any final blocks need to be handled.
+	// FINAL_STEP = 2: all done
+	// FINAL_STEP = 1: need to do count-only padding block
+	// FINAL_STEP = 0: need to do the block with 0x80 padding byte
+	cmp		$1, FINAL_STEP
+	jg		.Lfinup2x_done
+	je		.Lfinup2x_finalize_countonly
+	add		$64, LEN
+	jz		.Lfinup2x_finalize_blockaligned
+
+	// Not block-aligned; 1 <= LEN <= 63 data bytes remain.  Pad the block.
+	// To do this, write the padding starting with the 0x80 byte to
+	// &sp[64].  Then for each message, copy the last 64 data bytes to sp
+	// and load from &sp[64 - LEN] to get the needed padding block.  This
+	// code relies on the data buffers being >= 64 bytes in length.
+	mov		$64, %ebx
+	sub		LEN, %ebx		// ebx = 64 - LEN
+	sub		%rbx, DATA1		// DATA1 -= 64 - LEN
+	sub		%rbx, DATA2		// DATA2 -= 64 - LEN
+	mov		$0x80, FINAL_STEP   // using FINAL_STEP as a temporary
+	movd		FINAL_STEP, MSG0_A
+	pxor		MSG1_A, MSG1_A
+	movdqa		MSG0_A, 4*16(%rsp)
+	movdqa		MSG1_A, 5*16(%rsp)
+	movdqa		MSG1_A, 6*16(%rsp)
+	movdqa		MSG1_A, 7*16(%rsp)
+	cmp		$56, LEN
+	jge		1f	// will COUNT spill into its own block?
+	shl		$3, COUNT
+	bswap		COUNT
+	mov		COUNT, 56(%rsp,%rbx)
+	mov		$2, FINAL_STEP	// won't need count-only block
+	jmp		2f
+1:
+	mov		$1, FINAL_STEP	// will need count-only block
+2:
+	movdqu		0*16(DATA1), MSG0_A
+	movdqu		1*16(DATA1), MSG1_A
+	movdqu		2*16(DATA1), MSG2_A
+	movdqu		3*16(DATA1), MSG3_A
+	movdqa		MSG0_A, 0*16(%rsp)
+	movdqa		MSG1_A, 1*16(%rsp)
+	movdqa		MSG2_A, 2*16(%rsp)
+	movdqa		MSG3_A, 3*16(%rsp)
+	movdqu		0*16(%rsp,%rbx), MSG0_A
+	movdqu		1*16(%rsp,%rbx), MSG1_A
+	movdqu		2*16(%rsp,%rbx), MSG2_A
+	movdqu		3*16(%rsp,%rbx), MSG3_A
+
+	movdqu		0*16(DATA2), MSG0_B
+	movdqu		1*16(DATA2), MSG1_B
+	movdqu		2*16(DATA2), MSG2_B
+	movdqu		3*16(DATA2), MSG3_B
+	movdqa		MSG0_B, 0*16(%rsp)
+	movdqa		MSG1_B, 1*16(%rsp)
+	movdqa		MSG2_B, 2*16(%rsp)
+	movdqa		MSG3_B, 3*16(%rsp)
+	movdqu		0*16(%rsp,%rbx), MSG0_B
+	movdqu		1*16(%rsp,%rbx), MSG1_B
+	movdqu		2*16(%rsp,%rbx), MSG2_B
+	movdqu		3*16(%rsp,%rbx), MSG3_B
+	jmp		.Lfinup2x_loop_have_data
+
+	// Prepare a padding block, either:
+	//
+	//	{0x80, 0, 0, 0, ..., count (as __be64)}
+	//	This is for a block aligned message.
+	//
+	//	{   0, 0, 0, 0, ..., count (as __be64)}
+	//	This is for a message whose length mod 64 is >= 56.
+	//
+	// Pre-swap the endianness of the words.
+.Lfinup2x_finalize_countonly:
+	pxor		MSG0_A, MSG0_A
+	jmp		1f
+
+.Lfinup2x_finalize_blockaligned:
+	mov		$0x80000000, %ebx
+	movd		%ebx, MSG0_A
+1:
+	pxor		MSG1_A, MSG1_A
+	pxor		MSG2_A, MSG2_A
+	ror		$29, COUNT
+	movq		COUNT, MSG3_A
+	pslldq		$8, MSG3_A
+	movdqa		MSG0_A, MSG0_B
+	pxor		MSG1_B, MSG1_B
+	pxor		MSG2_B, MSG2_B
+	movdqa		MSG3_A, MSG3_B
+	mov		$2, FINAL_STEP
+	jmp		.Lfinup2x_loop_have_bswapped_data
+
+.Lfinup2x_done:
+	// Write the two digests with all bytes in the correct order.
+	movdqa		STATE0_A, TMP_A
+	movdqa		STATE0_B, TMP_B
+	punpcklqdq	STATE1_A, STATE0_A		// GHEF
+	punpcklqdq	STATE1_B, STATE0_B
+	punpckhqdq	TMP_A, STATE1_A			// ABCD
+	punpckhqdq	TMP_B, STATE1_B
+	pshufd		$0xB1, STATE0_A, STATE0_A	// HGFE
+	pshufd		$0xB1, STATE0_B, STATE0_B
+	pshufd		$0x1B, STATE1_A, STATE1_A	// DCBA
+	pshufd		$0x1B, STATE1_B, STATE1_B
+	pshufb		SHUF_MASK, STATE0_A
+	pshufb		SHUF_MASK, STATE0_B
+	pshufb		SHUF_MASK, STATE1_A
+	pshufb		SHUF_MASK, STATE1_B
+	movdqu		STATE0_A, 1*16(OUT1)
+	movdqu		STATE0_B, 1*16(OUT2)
+	movdqu		STATE1_A, 0*16(OUT1)
+	movdqu		STATE1_B, 0*16(OUT2)
+
+	mov		%rbp, %rsp
+	pop		%rbp
+	pop		%rbx
+	RET
+SYM_FUNC_END(__sha256_ni_finup2x)
+
 .section	.rodata.cst256.K256, "aM", @progbits, 256
 .align 64
 K256:
 	.long	0x428a2f98,0x71374491,0xb5c0fbcf,0xe9b5dba5
 	.long	0x3956c25b,0x59f111f1,0x923f82a4,0xab1c5ed5
diff --git a/arch/x86/crypto/sha256_ssse3_glue.c b/arch/x86/crypto/sha256_ssse3_glue.c
index e04a43d9f7d5..f5e6cc7afac7 100644
--- a/arch/x86/crypto/sha256_ssse3_glue.c
+++ b/arch/x86/crypto/sha256_ssse3_glue.c
@@ -331,10 +331,15 @@  static void unregister_sha256_avx2(void)
 
 #ifdef CONFIG_AS_SHA256_NI
 asmlinkage void sha256_ni_transform(struct sha256_state *digest,
 				    const u8 *data, int rounds);
 
+asmlinkage void __sha256_ni_finup2x(const struct sha256_state *sctx,
+				    const u8 *data1, const u8 *data2, int len,
+				    u8 out1[SHA256_DIGEST_SIZE],
+				    u8 out2[SHA256_DIGEST_SIZE]);
+
 static int sha256_ni_update(struct shash_desc *desc, const u8 *data,
 			 unsigned int len)
 {
 	return _sha256_update(desc, data, len, sha256_ni_transform);
 }
@@ -355,18 +360,52 @@  static int sha256_ni_digest(struct shash_desc *desc, const u8 *data,
 {
 	return sha256_base_init(desc) ?:
 	       sha256_ni_finup(desc, data, len, out);
 }
 
+static int sha256_ni_finup_mb(struct shash_desc *desc,
+			      const u8 * const data[], unsigned int len,
+			      u8 * const outs[], unsigned int num_msgs)
+{
+	struct sha256_state *sctx = shash_desc_ctx(desc);
+
+	/*
+	 * num_msgs != 2 should not happen here, since this algorithm sets
+	 * mb_max_msgs=2, and the crypto API handles num_msgs <= 1 before
+	 * calling into the algorithm's finup_mb method.
+	 */
+	if (WARN_ON_ONCE(num_msgs != 2))
+		return -EOPNOTSUPP;
+
+	if (unlikely(!crypto_simd_usable()))
+		return -EOPNOTSUPP;
+
+	/* __sha256_ni_finup2x() assumes SHA256_BLOCK_SIZE <= len <= INT_MAX. */
+	if (unlikely(len < SHA256_BLOCK_SIZE || len > PAGE_SIZE))
+		return -EOPNOTSUPP;
+
+	/* __sha256_ni_finup2x() assumes the following offsets. */
+	BUILD_BUG_ON(offsetof(struct sha256_state, state) != 0);
+	BUILD_BUG_ON(offsetof(struct sha256_state, count) != 32);
+	BUILD_BUG_ON(offsetof(struct sha256_state, buf) != 40);
+
+	kernel_fpu_begin();
+	__sha256_ni_finup2x(sctx, data[0], data[1], len, outs[0], outs[1]);
+	kernel_fpu_end();
+	return 0;
+}
+
 static struct shash_alg sha256_ni_algs[] = { {
 	.digestsize	=	SHA256_DIGEST_SIZE,
 	.init		=	sha256_base_init,
 	.update		=	sha256_ni_update,
 	.final		=	sha256_ni_final,
 	.finup		=	sha256_ni_finup,
 	.digest		=	sha256_ni_digest,
+	.finup_mb	=	sha256_ni_finup_mb,
 	.descsize	=	sizeof(struct sha256_state),
+	.mb_max_msgs	=	2,
 	.base		=	{
 		.cra_name	=	"sha256",
 		.cra_driver_name =	"sha256-ni",
 		.cra_priority	=	250,
 		.cra_blocksize	=	SHA256_BLOCK_SIZE,