diff mbox series

[08/15] crypto: skcipher - Add incremental support to lskcipher wrapper

Message ID dcd973a33a21bda3f8ce2aa7030fa7a3391b5ce0.1707815065.git.herbert@gondor.apana.org.au
State New
Headers show
Series crypto: Add twopass lskcipher for adiantum | expand

Commit Message

Herbert Xu Dec. 6, 2023, 4:46 a.m. UTC
Execute a second pass for incremental lskcipher algorithms when the
skcipher request contains all the data and when the SG list itself
cannot be passed to the lskcipher in one go.

If the SG list can be processed in one go, there is no need for a
second pass.  If the skcipher request itself is incremental, then
the expectation is for the user to execute a second pass on the
skcipher request.

Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
---
 crypto/lskcipher.c | 29 ++++++++++++++++++++---------
 1 file changed, 20 insertions(+), 9 deletions(-)
diff mbox series

Patch

diff --git a/crypto/lskcipher.c b/crypto/lskcipher.c
index bc54cfc2734d..10e082f3cde6 100644
--- a/crypto/lskcipher.c
+++ b/crypto/lskcipher.c
@@ -206,11 +206,15 @@  static int crypto_lskcipher_crypt_sg(struct skcipher_request *req,
 	u8 *ivs = skcipher_request_ctx(req);
 	struct crypto_lskcipher *tfm = *ctx;
 	struct skcipher_walk walk;
+	int secondpass = 0;
+	bool isincremental;
+	bool morethanone;
 	unsigned ivsize;
 	u32 flags;
 	int err;
 
 	ivsize = crypto_lskcipher_ivsize(tfm);
+	isincremental = crypto_lskcipher_isincremental(tfm);
 	ivs = PTR_ALIGN(ivs, crypto_skcipher_alignmask(skcipher) + 1);
 
 	flags = req->base.flags & CRYPTO_TFM_REQ_MAY_SLEEP;
@@ -223,16 +227,23 @@  static int crypto_lskcipher_crypt_sg(struct skcipher_request *req,
 	if (!(req->base.flags & CRYPTO_SKCIPHER_REQ_NOTFINAL))
 		flags |= CRYPTO_LSKCIPHER_FLAG_FINAL;
 
-	err = skcipher_walk_virt(&walk, req, false);
+	do {
+		err = skcipher_walk_virt(&walk, req, false);
+		morethanone = walk.nbytes != walk.total;
 
-	while (walk.nbytes) {
-		err = crypt(tfm, walk.src.virt.addr, walk.dst.virt.addr,
-			    walk.nbytes, ivs,
-			    flags & ~(walk.nbytes == walk.total ?
-			    0 : CRYPTO_LSKCIPHER_FLAG_FINAL));
-		err = skcipher_walk_done(&walk, err);
-		flags |= CRYPTO_LSKCIPHER_FLAG_CONT;
-	}
+		while (walk.nbytes) {
+			err = crypt(tfm, walk.src.virt.addr,
+				    walk.dst.virt.addr,
+				    walk.nbytes, ivs,
+				    flags & ~(walk.nbytes == walk.total ?
+				    0 : CRYPTO_LSKCIPHER_FLAG_FINAL));
+			err = skcipher_walk_done(&walk, err);
+			flags |= CRYPTO_LSKCIPHER_FLAG_CONT;
+		}
+
+		if (err)
+			return err;
+	} while (!secondpass++ && !isincremental && morethanone);
 
 	if (flags & CRYPTO_LSKCIPHER_FLAG_FINAL)
 		memcpy(req->iv, ivs, ivsize);