diff mbox series

[14/18] crypto: Implement MGF1 Mask Generation Function for RSASSA-PSS

Message ID 20210330202829.4825-15-varad.gautam@suse.com
State Superseded
Headers show
Series Implement RSASSA-PSS signature verification | expand

Commit Message

Varad Gautam March 30, 2021, 8:28 p.m. UTC
This generates a "mask" byte array of size mask_len bytes as a
concatenation of digests, where each digest is calculated on a
concatenation of an input seed and a running counter to fill up
mask_len bytes - as described by RFC8017 sec B.2.1. "MGF1".

The mask is useful for RSA signing/verification process with
encoding RSASSA-PSS.

Reference: https://tools.ietf.org/html/rfc8017#appendix-B.2.1
Signed-off-by: Varad Gautam <varad.gautam@suse.com>
---
 crypto/rsa-psspad.c | 54 +++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 54 insertions(+)
diff mbox series

Patch

diff --git a/crypto/rsa-psspad.c b/crypto/rsa-psspad.c
index bb8052821982..46578b8b14b1 100644
--- a/crypto/rsa-psspad.c
+++ b/crypto/rsa-psspad.c
@@ -50,6 +50,60 @@  static int psspad_set_sig_params(struct crypto_akcipher *tfm,
 	return 0;
 }
 
+/* MGF1 per RFC8017 B.2.1. */
+static int pkcs1_mgf1(u8 *seed, unsigned int seed_len,
+		      struct shash_desc *desc,
+		      u8 *mask, unsigned int mask_len)
+{
+	unsigned int pos, h_len, i, c;
+	u8 *tmp;
+	int ret = 0;
+
+	h_len = crypto_shash_digestsize(desc->tfm);
+
+	pos = i = 0;
+	while ((i < (mask_len / h_len) + 1) && pos < mask_len) {
+		/* Compute T = T || Hash(mgfSeed || C) into mask at pos. */
+		c = cpu_to_be32(i);
+
+		ret = crypto_shash_init(desc);
+		if (ret < 0)
+			goto out_err;
+
+		ret = crypto_shash_update(desc, seed, seed_len);
+		if (ret < 0)
+			goto out_err;
+
+		ret = crypto_shash_update(desc, (u8 *) &c, sizeof(c));
+		if (ret < 0)
+			goto out_err;
+
+		if (mask_len - pos >= h_len) {
+			ret = crypto_shash_final(desc, mask + pos);
+			pos += h_len;
+		} else {
+			tmp = kzalloc(h_len, GFP_KERNEL);
+			if (!tmp) {
+				ret = -ENOMEM;
+				goto out_err;
+			}
+			ret = crypto_shash_final(desc, tmp);
+			/* copy the last hash */
+			memcpy(mask + pos, tmp, mask_len - pos);
+			kfree(tmp);
+			pos = mask_len;
+		}
+		if (ret < 0) {
+			goto out_err;
+		}
+
+		i++;
+	}
+
+out_err:
+	return ret;
+}
+
 static int psspad_s_v_e_d(struct akcipher_request *req)
 {
 	return -EOPNOTSUPP;