
[v2,1/8] crypto: arm64/aes-ccm - Revert "Rewrite skcipher walker loop"

Message ID 20240118170628.3049797-11-ardb+git@google.com
State New
Series crypto: Clean up arm64 AES-CCM code

Commit Message

Ard Biesheuvel Jan. 18, 2024, 5:06 p.m. UTC
From: Ard Biesheuvel <ardb@kernel.org>

This reverts commit 57ead1bf1c54, which updated the CCM code to only
rely on walk.nbytes to check for failures returned from the skcipher
walk API, mostly for the common good rather than to fix a particular
problem in the code.

This change introduces a problem of its own: the skcipher walk is
started with the 'atomic' argument set to false, which means that the
skcipher walk API is permitted to sleep. Subsequently, it invokes
skcipher_walk_done() with preemption disabled on the final iteration of
the loop. This appears to work by accident, but it is arguably a bad
example, and providing a better example was the point of the original
patch.
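
For reference, this is roughly the shape of the loop being removed again
here, condensed from the '-' lines of the encrypt hunk below with comments
added (not a complete function; ctx, mac and buf are set up earlier in
ccm_encrypt()):

	err = skcipher_walk_aead_encrypt(&walk, req, false);	/* atomic == false: may sleep */

	kernel_neon_begin();			/* disables preemption */

	while (walk.nbytes) {
		u32 tail = walk.nbytes % AES_BLOCK_SIZE;
		bool final = walk.nbytes == walk.total;

		if (final)
			tail = 0;

		ce_aes_ccm_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				   walk.nbytes - tail, ctx->key_enc,
				   num_rounds(ctx), mac, walk.iv);

		if (!final)
			kernel_neon_end();	/* preemption re-enabled on intermediate steps... */
		err = skcipher_walk_done(&walk, tail);	/* ...but on the final step this runs with
							   preemption still disabled, even though
							   the walk was started as non-atomic */
		if (!final)
			kernel_neon_begin();
	}

	ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx));

	kernel_neon_end();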

Given that future changes to the CCM code will rely on the original
behavior of entering the loop even for zero sized inputs, let's just
revert this change entirely, and proceed from there.
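
To make the zero-length case concrete: with the do/while shape restored
below, the loop body still runs exactly once when walk.nbytes ==
walk.total == 0, which is the behaviour the later patches depend on. An
annotated copy of the restored encrypt loop for an empty payload (the
comments are the only addition):

	do {
		u32 tail = walk.nbytes % AES_BLOCK_SIZE;	/* 0 */

		if (walk.nbytes == walk.total)		/* 0 == 0: this is the final pass */
			tail = 0;

		ce_aes_ccm_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
				   walk.nbytes - tail, ctx->key_enc,	/* zero bytes of payload */
				   num_rounds(ctx), mac, walk.iv);

		if (walk.nbytes == walk.total)
			ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx));

		kernel_neon_end();

		if (walk.nbytes) {		/* false: skcipher_walk_done() is not called */
			err = skcipher_walk_done(&walk, tail);
			if (unlikely(err))
				return err;
			if (unlikely(walk.nbytes))
				kernel_neon_begin();
		}
	} while (walk.nbytes);			/* false: exactly one pass */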

Signed-off-by: Ard Biesheuvel <ardb@kernel.org>
---
 arch/arm64/crypto/aes-ce-ccm-glue.c | 57 +++++++++++---------
 1 file changed, 31 insertions(+), 26 deletions(-)

Patch

diff --git a/arch/arm64/crypto/aes-ce-ccm-glue.c b/arch/arm64/crypto/aes-ce-ccm-glue.c
index 25cd3808ecbe..c4f14415f5f0 100644
--- a/arch/arm64/crypto/aes-ce-ccm-glue.c
+++ b/arch/arm64/crypto/aes-ce-ccm-glue.c
@@ -161,39 +161,43 @@  static int ccm_encrypt(struct aead_request *req)
 	memcpy(buf, req->iv, AES_BLOCK_SIZE);
 
 	err = skcipher_walk_aead_encrypt(&walk, req, false);
+	if (unlikely(err))
+		return err;
 
 	kernel_neon_begin();
 
 	if (req->assoclen)
 		ccm_calculate_auth_mac(req, mac);
 
-	while (walk.nbytes) {
+	do {
 		u32 tail = walk.nbytes % AES_BLOCK_SIZE;
-		bool final = walk.nbytes == walk.total;
 
-		if (final)
+		if (walk.nbytes == walk.total)
 			tail = 0;
 
 		ce_aes_ccm_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
 				   walk.nbytes - tail, ctx->key_enc,
 				   num_rounds(ctx), mac, walk.iv);
 
-		if (!final)
-			kernel_neon_end();
-		err = skcipher_walk_done(&walk, tail);
-		if (!final)
-			kernel_neon_begin();
-	}
+		if (walk.nbytes == walk.total)
+			ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx));
 
-	ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx));
+		kernel_neon_end();
 
-	kernel_neon_end();
+		if (walk.nbytes) {
+			err = skcipher_walk_done(&walk, tail);
+			if (unlikely(err))
+				return err;
+			if (unlikely(walk.nbytes))
+				kernel_neon_begin();
+		}
+	} while (walk.nbytes);
 
 	/* copy authtag to end of dst */
 	scatterwalk_map_and_copy(mac, req->dst, req->assoclen + req->cryptlen,
 				 crypto_aead_authsize(aead), 1);
 
-	return err;
+	return 0;
 }
 
 static int ccm_decrypt(struct aead_request *req)
@@ -215,36 +219,37 @@  static int ccm_decrypt(struct aead_request *req)
 	memcpy(buf, req->iv, AES_BLOCK_SIZE);
 
 	err = skcipher_walk_aead_decrypt(&walk, req, false);
+	if (unlikely(err))
+		return err;
 
 	kernel_neon_begin();
 
 	if (req->assoclen)
 		ccm_calculate_auth_mac(req, mac);
 
-	while (walk.nbytes) {
+	do {
 		u32 tail = walk.nbytes % AES_BLOCK_SIZE;
-		bool final = walk.nbytes == walk.total;
 
-		if (final)
+		if (walk.nbytes == walk.total)
 			tail = 0;
 
 		ce_aes_ccm_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
 				   walk.nbytes - tail, ctx->key_enc,
 				   num_rounds(ctx), mac, walk.iv);
 
-		if (!final)
-			kernel_neon_end();
-		err = skcipher_walk_done(&walk, tail);
-		if (!final)
-			kernel_neon_begin();
-	}
+		if (walk.nbytes == walk.total)
+			ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx));
 
-	ce_aes_ccm_final(mac, buf, ctx->key_enc, num_rounds(ctx));
+		kernel_neon_end();
 
-	kernel_neon_end();
-
-	if (unlikely(err))
-		return err;
+		if (walk.nbytes) {
+			err = skcipher_walk_done(&walk, tail);
+			if (unlikely(err))
+				return err;
+			if (unlikely(walk.nbytes))
+				kernel_neon_begin();
+		}
+	} while (walk.nbytes);
 
 	/* compare calculated auth tag with the stored one */
 	scatterwalk_map_and_copy(buf, req->src,