@@ -104,61 +104,50 @@ static void do_aes_decrypt_ecb(const void *vctx,
}
}
-static void AES_cbc_encrypt(const unsigned char *in, unsigned char *out,
- const unsigned long length, const AES_KEY *key,
- unsigned char *ivec, const int enc)
+static void do_aes_encrypt_cbc(const AES_KEY *key,
+ size_t len,
+ uint8_t *out,
+ const uint8_t *in,
+ uint8_t *ivec)
{
- unsigned long n;
- unsigned long len = length;
- unsigned char tmp[AES_BLOCK_SIZE];
-
- assert(in && out && key && ivec);
-
- if (enc) {
- while (len >= AES_BLOCK_SIZE) {
- for (n = 0; n < AES_BLOCK_SIZE; ++n) {
- tmp[n] = in[n] ^ ivec[n];
- }
- AES_encrypt(tmp, out, key);
- memcpy(ivec, out, AES_BLOCK_SIZE);
- len -= AES_BLOCK_SIZE;
- in += AES_BLOCK_SIZE;
- out += AES_BLOCK_SIZE;
- }
- if (len) {
- for (n = 0; n < len; ++n) {
- tmp[n] = in[n] ^ ivec[n];
- }
- for (n = len; n < AES_BLOCK_SIZE; ++n) {
- tmp[n] = ivec[n];
- }
- AES_encrypt(tmp, tmp, key);
- memcpy(out, tmp, AES_BLOCK_SIZE);
- memcpy(ivec, tmp, AES_BLOCK_SIZE);
- }
- } else {
- while (len >= AES_BLOCK_SIZE) {
- memcpy(tmp, in, AES_BLOCK_SIZE);
- AES_decrypt(in, out, key);
- for (n = 0; n < AES_BLOCK_SIZE; ++n) {
- out[n] ^= ivec[n];
- }
- memcpy(ivec, tmp, AES_BLOCK_SIZE);
- len -= AES_BLOCK_SIZE;
- in += AES_BLOCK_SIZE;
- out += AES_BLOCK_SIZE;
- }
- if (len) {
- memcpy(tmp, in, AES_BLOCK_SIZE);
- AES_decrypt(tmp, tmp, key);
- for (n = 0; n < len; ++n) {
- out[n] = tmp[n] ^ ivec[n];
- }
- memcpy(ivec, tmp, AES_BLOCK_SIZE);
+ uint8_t tmp[AES_BLOCK_SIZE];
+ size_t n;
+
+ /* We have already verified that len % AES_BLOCK_SIZE == 0. */
+ while (len) {
+ for (n = 0; n < AES_BLOCK_SIZE; ++n) {
+ tmp[n] = in[n] ^ ivec[n];
}
+ AES_encrypt(tmp, out, key);
+ memcpy(ivec, out, AES_BLOCK_SIZE);
+ len -= AES_BLOCK_SIZE;
+ in += AES_BLOCK_SIZE;
+ out += AES_BLOCK_SIZE;
}
}
+static void do_aes_decrypt_cbc(const AES_KEY *key,
+ size_t len,
+ uint8_t *out,
+ const uint8_t *in,
+ uint8_t *ivec)
+{
+ uint8_t tmp[AES_BLOCK_SIZE];
+ size_t n;
+
+ /* We have already verified that len % AES_BLOCK_SIZE == 0. */
+ while (len) {
+ memcpy(tmp, in, AES_BLOCK_SIZE);
+ AES_decrypt(in, out, key);
+ for (n = 0; n < AES_BLOCK_SIZE; ++n) {
+ out[n] ^= ivec[n];
+ }
+ memcpy(ivec, tmp, AES_BLOCK_SIZE);
+ len -= AES_BLOCK_SIZE;
+ in += AES_BLOCK_SIZE;
+ out += AES_BLOCK_SIZE;
+ }
+}
static int qcrypto_cipher_encrypt_aes(QCryptoCipher *cipher,
const void *in,
@@ -174,9 +163,8 @@ static int qcrypto_cipher_encrypt_aes(QCryptoCipher *cipher,
do_aes_encrypt_ecb(&ctxt->state.aes.key, len, out, in);
break;
case QCRYPTO_CIPHER_MODE_CBC:
- AES_cbc_encrypt(in, out, len,
- &ctxt->state.aes.key.enc,
- ctxt->state.aes.iv, 1);
+ do_aes_encrypt_cbc(&ctxt->state.aes.key.enc, len, out, in,
+ ctxt->state.aes.iv);
break;
case QCRYPTO_CIPHER_MODE_XTS:
xts_encrypt(&ctxt->state.aes.key,
@@ -208,9 +196,8 @@ static int qcrypto_cipher_decrypt_aes(QCryptoCipher *cipher,
do_aes_decrypt_ecb(&ctxt->state.aes.key, len, out, in);
break;
case QCRYPTO_CIPHER_MODE_CBC:
- AES_cbc_encrypt(in, out, len,
- &ctxt->state.aes.key.dec,
- ctxt->state.aes.iv, 0);
+ do_aes_decrypt_cbc(&ctxt->state.aes.key.dec, len, out, in,
+ ctxt->state.aes.iv);
break;
case QCRYPTO_CIPHER_MODE_XTS:
xts_decrypt(&ctxt->state.aes.key,