diff options
author | Michael Brown <mcb30@ipxe.org> | 2024-08-21 16:25:10 +0100 |
---|---|---|
committer | Michael Brown <mcb30@ipxe.org> | 2024-08-21 21:00:57 +0100 |
commit | 46937a9df622d1e9fb5b1e926a04176b8855fdce (patch) | |
tree | 05287931d7afaad1f6eb3294fcddda4118484c79 /src/crypto/rsa.c | |
parent | acbabdb335f47eb8246188a23ed7e3997da6e8ba (diff) | |
download | ipxe-46937a9df622d1e9fb5b1e926a04176b8855fdce.tar.gz |
[crypto] Remove the concept of a public-key algorithm reusable context
Instances of cipher and digest algorithms tend to get called
repeatedly to process substantial amounts of data. This is not true
for public-key algorithms, which tend to get called only once or twice
for a given key.
Simplify the public-key algorithm API so that there is no reusable
algorithm context. In particular, this allows callers to omit the
error handling currently required to handle memory allocation (or key
parsing) errors from pubkey_init(), and to omit the cleanup calls to
pubkey_final().
This change does remove the ability for a caller to distinguish
between a verification failure due to a memory allocation failure and
a verification failure due to a bad signature. This difference is not
material in practice: in both cases, for whatever reason, the caller
was unable to verify the signature and so cannot proceed further, and
the cause of the error will be visible to the user via the return
status code.
Signed-off-by: Michael Brown <mcb30@ipxe.org>
Diffstat (limited to 'src/crypto/rsa.c')
-rw-r--r-- | src/crypto/rsa.c | 295 |
1 files changed, 188 insertions, 107 deletions
diff --git a/src/crypto/rsa.c b/src/crypto/rsa.c index 2d288a953..19472c121 100644 --- a/src/crypto/rsa.c +++ b/src/crypto/rsa.c @@ -47,6 +47,28 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL ); #define EINFO_EACCES_VERIFY \ __einfo_uniqify ( EINFO_EACCES, 0x01, "RSA signature incorrect" ) +/** An RSA context */ +struct rsa_context { + /** Allocated memory */ + void *dynamic; + /** Modulus */ + bigint_element_t *modulus0; + /** Modulus size */ + unsigned int size; + /** Modulus length */ + size_t max_len; + /** Exponent */ + bigint_element_t *exponent0; + /** Exponent size */ + unsigned int exponent_size; + /** Input buffer */ + bigint_element_t *input0; + /** Output buffer */ + bigint_element_t *output0; + /** Temporary working space for modular exponentiation */ + void *tmp; +}; + /** * Identify RSA prefix * @@ -69,10 +91,9 @@ rsa_find_prefix ( struct digest_algorithm *digest ) { * * @v context RSA context */ -static void rsa_free ( struct rsa_context *context ) { +static inline void rsa_free ( struct rsa_context *context ) { free ( context->dynamic ); - context->dynamic = NULL; } /** @@ -98,9 +119,6 @@ static int rsa_alloc ( struct rsa_context *context, size_t modulus_len, uint8_t tmp[tmp_len]; } __attribute__ (( packed )) *dynamic; - /* Free any existing dynamic storage */ - rsa_free ( context ); - /* Allocate dynamic storage */ dynamic = malloc ( sizeof ( *dynamic ) ); if ( ! dynamic ) @@ -231,12 +249,12 @@ static int rsa_parse_mod_exp ( struct asn1_cursor *modulus, /** * Initialise RSA cipher * - * @v ctx RSA context + * @v context RSA context * @v key Key * @ret rc Return status code */ -static int rsa_init ( void *ctx, const struct asn1_cursor *key ) { - struct rsa_context *context = ctx; +static int rsa_init ( struct rsa_context *context, + const struct asn1_cursor *key ) { struct asn1_cursor modulus; struct asn1_cursor exponent; int rc; @@ -277,13 +295,22 @@ static int rsa_init ( void *ctx, const struct asn1_cursor *key ) { /** * Calculate RSA maximum output length * - * @v ctx RSA context + * @v key Key * @ret max_len Maximum output length */ -static size_t rsa_max_len ( void *ctx ) { - struct rsa_context *context = ctx; +static size_t rsa_max_len ( const struct asn1_cursor *key ) { + struct asn1_cursor modulus; + struct asn1_cursor exponent; + int rc; - return context->max_len; + /* Parse moduli and exponents */ + if ( ( rc = rsa_parse_mod_exp ( &modulus, &exponent, key ) ) != 0 ) { + /* Return a zero maximum length on error */ + return 0; + } + + /* Output length can never exceed modulus length */ + return modulus.len; } /** @@ -314,111 +341,147 @@ static void rsa_cipher ( struct rsa_context *context, /** * Encrypt using RSA * - * @v ctx RSA context + * @v key Key * @v plaintext Plaintext * @v plaintext_len Length of plaintext * @v ciphertext Ciphertext * @ret ciphertext_len Length of ciphertext, or negative error */ -static int rsa_encrypt ( void *ctx, const void *plaintext, +static int rsa_encrypt ( const struct asn1_cursor *key, const void *plaintext, size_t plaintext_len, void *ciphertext ) { - struct rsa_context *context = ctx; + struct rsa_context context; void *temp; uint8_t *encoded; - size_t max_len = ( context->max_len - 11 ); - size_t random_nz_len = ( max_len - plaintext_len + 8 ); + size_t max_len; + size_t random_nz_len; int rc; + DBGC ( &context, "RSA %p encrypting:\n", &context ); + DBGC_HDA ( &context, 0, plaintext, plaintext_len ); + + /* Initialise context */ + if ( ( rc = rsa_init ( &context, key ) ) != 0 ) + goto err_init; + + /* Calculate lengths */ + max_len = ( context.max_len - 11 ); + random_nz_len = ( max_len - plaintext_len + 8 ); + /* Sanity check */ if ( plaintext_len > max_len ) { - DBGC ( context, "RSA %p plaintext too long (%zd bytes, max " - "%zd)\n", context, plaintext_len, max_len ); - return -ERANGE; + DBGC ( &context, "RSA %p plaintext too long (%zd bytes, max " + "%zd)\n", &context, plaintext_len, max_len ); + rc = -ERANGE; + goto err_sanity; } - DBGC ( context, "RSA %p encrypting:\n", context ); - DBGC_HDA ( context, 0, plaintext, plaintext_len ); /* Construct encoded message (using the big integer output * buffer as temporary storage) */ - temp = context->output0; + temp = context.output0; encoded = temp; encoded[0] = 0x00; encoded[1] = 0x02; if ( ( rc = get_random_nz ( &encoded[2], random_nz_len ) ) != 0 ) { - DBGC ( context, "RSA %p could not generate random data: %s\n", - context, strerror ( rc ) ); - return rc; + DBGC ( &context, "RSA %p could not generate random data: %s\n", + &context, strerror ( rc ) ); + goto err_random; } encoded[ 2 + random_nz_len ] = 0x00; - memcpy ( &encoded[ context->max_len - plaintext_len ], + memcpy ( &encoded[ context.max_len - plaintext_len ], plaintext, plaintext_len ); /* Encipher the encoded message */ - rsa_cipher ( context, encoded, ciphertext ); - DBGC ( context, "RSA %p encrypted:\n", context ); - DBGC_HDA ( context, 0, ciphertext, context->max_len ); + rsa_cipher ( &context, encoded, ciphertext ); + DBGC ( &context, "RSA %p encrypted:\n", &context ); + DBGC_HDA ( &context, 0, ciphertext, context.max_len ); + + /* Free context */ + rsa_free ( &context ); - return context->max_len; + return context.max_len; + + err_random: + err_sanity: + rsa_free ( &context ); + err_init: + return rc; } /** * Decrypt using RSA * - * @v ctx RSA context + * @v key Key * @v ciphertext Ciphertext * @v ciphertext_len Ciphertext length * @v plaintext Plaintext * @ret plaintext_len Plaintext length, or negative error */ -static int rsa_decrypt ( void *ctx, const void *ciphertext, +static int rsa_decrypt ( const struct asn1_cursor *key, const void *ciphertext, size_t ciphertext_len, void *plaintext ) { - struct rsa_context *context = ctx; + struct rsa_context context; void *temp; uint8_t *encoded; uint8_t *end; uint8_t *zero; uint8_t *start; size_t plaintext_len; + int rc; + + DBGC ( &context, "RSA %p decrypting:\n", &context ); + DBGC_HDA ( &context, 0, ciphertext, ciphertext_len ); + + /* Initialise context */ + if ( ( rc = rsa_init ( &context, key ) ) != 0 ) + goto err_init; /* Sanity check */ - if ( ciphertext_len != context->max_len ) { - DBGC ( context, "RSA %p ciphertext incorrect length (%zd " + if ( ciphertext_len != context.max_len ) { + DBGC ( &context, "RSA %p ciphertext incorrect length (%zd " "bytes, should be %zd)\n", - context, ciphertext_len, context->max_len ); - return -ERANGE; + &context, ciphertext_len, context.max_len ); + rc = -ERANGE; + goto err_sanity; } - DBGC ( context, "RSA %p decrypting:\n", context ); - DBGC_HDA ( context, 0, ciphertext, ciphertext_len ); /* Decipher the message (using the big integer input buffer as * temporary storage) */ - temp = context->input0; + temp = context.input0; encoded = temp; - rsa_cipher ( context, ciphertext, encoded ); + rsa_cipher ( &context, ciphertext, encoded ); /* Parse the message */ - end = ( encoded + context->max_len ); - if ( ( encoded[0] != 0x00 ) || ( encoded[1] != 0x02 ) ) - goto invalid; + end = ( encoded + context.max_len ); + if ( ( encoded[0] != 0x00 ) || ( encoded[1] != 0x02 ) ) { + rc = -EINVAL; + goto err_invalid; + } zero = memchr ( &encoded[2], 0, ( end - &encoded[2] ) ); - if ( ! zero ) - goto invalid; + if ( ! zero ) { + rc = -EINVAL; + goto err_invalid; + } start = ( zero + 1 ); plaintext_len = ( end - start ); /* Copy out message */ memcpy ( plaintext, start, plaintext_len ); - DBGC ( context, "RSA %p decrypted:\n", context ); - DBGC_HDA ( context, 0, plaintext, plaintext_len ); + DBGC ( &context, "RSA %p decrypted:\n", &context ); + DBGC_HDA ( &context, 0, plaintext, plaintext_len ); + + /* Free context */ + rsa_free ( &context ); return plaintext_len; - invalid: - DBGC ( context, "RSA %p invalid decrypted message:\n", context ); - DBGC_HDA ( context, 0, encoded, context->max_len ); - return -EINVAL; + err_invalid: + DBGC ( &context, "RSA %p invalid decrypted message:\n", &context ); + DBGC_HDA ( &context, 0, encoded, context.max_len ); + err_sanity: + rsa_free ( &context ); + err_init: + return rc; } /** @@ -452,9 +515,9 @@ static int rsa_encode_digest ( struct rsa_context *context, /* Sanity check */ max_len = ( context->max_len - 11 ); if ( digestinfo_len > max_len ) { - DBGC ( context, "RSA %p %s digestInfo too long (%zd bytes, max" - "%zd)\n", - context, digest->name, digestinfo_len, max_len ); + DBGC ( context, "RSA %p %s digestInfo too long (%zd bytes, " + "max %zd)\n", context, digest->name, digestinfo_len, + max_len ); return -ERANGE; } DBGC ( context, "RSA %p encoding %s digest:\n", @@ -482,104 +545,125 @@ static int rsa_encode_digest ( struct rsa_context *context, /** * Sign digest value using RSA * - * @v ctx RSA context + * @v key Key * @v digest Digest algorithm * @v value Digest value * @v signature Signature * @ret signature_len Signature length, or negative error */ -static int rsa_sign ( void *ctx, struct digest_algorithm *digest, - const void *value, void *signature ) { - struct rsa_context *context = ctx; +static int rsa_sign ( const struct asn1_cursor *key, + struct digest_algorithm *digest, const void *value, + void *signature ) { + struct rsa_context context; void *temp; int rc; - DBGC ( context, "RSA %p signing %s digest:\n", context, digest->name ); - DBGC_HDA ( context, 0, value, digest->digestsize ); + DBGC ( &context, "RSA %p signing %s digest:\n", + &context, digest->name ); + DBGC_HDA ( &context, 0, value, digest->digestsize ); + + /* Initialise context */ + if ( ( rc = rsa_init ( &context, key ) ) != 0 ) + goto err_init; /* Encode digest (using the big integer output buffer as * temporary storage) */ - temp = context->output0; - if ( ( rc = rsa_encode_digest ( context, digest, value, temp ) ) != 0 ) - return rc; + temp = context.output0; + if ( ( rc = rsa_encode_digest ( &context, digest, value, temp ) ) != 0 ) + goto err_encode; /* Encipher the encoded digest */ - rsa_cipher ( context, temp, signature ); - DBGC ( context, "RSA %p signed %s digest:\n", context, digest->name ); - DBGC_HDA ( context, 0, signature, context->max_len ); + rsa_cipher ( &context, temp, signature ); + DBGC ( &context, "RSA %p signed %s digest:\n", &context, digest->name ); + DBGC_HDA ( &context, 0, signature, context.max_len ); + + /* Free context */ + rsa_free ( &context ); - return context->max_len; + return context.max_len; + + err_encode: + rsa_free ( &context ); + err_init: + return rc; } /** * Verify signed digest value using RSA * - * @v ctx RSA context + * @v key Key * @v digest Digest algorithm * @v value Digest value * @v signature Signature * @v signature_len Signature length * @ret rc Return status code */ -static int rsa_verify ( void *ctx, struct digest_algorithm *digest, - const void *value, const void *signature, - size_t signature_len ) { - struct rsa_context *context = ctx; +static int rsa_verify ( const struct asn1_cursor *key, + struct digest_algorithm *digest, const void *value, + const void *signature, size_t signature_len ) { + struct rsa_context context; void *temp; void *expected; void *actual; int rc; + DBGC ( &context, "RSA %p verifying %s digest:\n", + &context, digest->name ); + DBGC_HDA ( &context, 0, value, digest->digestsize ); + DBGC_HDA ( &context, 0, signature, signature_len ); + + /* Initialise context */ + if ( ( rc = rsa_init ( &context, key ) ) != 0 ) + goto err_init; + /* Sanity check */ - if ( signature_len != context->max_len ) { - DBGC ( context, "RSA %p signature incorrect length (%zd " + if ( signature_len != context.max_len ) { + DBGC ( &context, "RSA %p signature incorrect length (%zd " "bytes, should be %zd)\n", - context, signature_len, context->max_len ); - return -ERANGE; + &context, signature_len, context.max_len ); + rc = -ERANGE; + goto err_sanity; } - DBGC ( context, "RSA %p verifying %s digest:\n", - context, digest->name ); - DBGC_HDA ( context, 0, value, digest->digestsize ); - DBGC_HDA ( context, 0, signature, signature_len ); /* Decipher the signature (using the big integer input buffer * as temporary storage) */ - temp = context->input0; + temp = context.input0; expected = temp; - rsa_cipher ( context, signature, expected ); - DBGC ( context, "RSA %p deciphered signature:\n", context ); - DBGC_HDA ( context, 0, expected, context->max_len ); + rsa_cipher ( &context, signature, expected ); + DBGC ( &context, "RSA %p deciphered signature:\n", &context ); + DBGC_HDA ( &context, 0, expected, context.max_len ); /* Encode digest (using the big integer output buffer as * temporary storage) */ - temp = context->output0; + temp = context.output0; actual = temp; - if ( ( rc = rsa_encode_digest ( context, digest, value, actual ) ) !=0 ) - return rc; + if ( ( rc = rsa_encode_digest ( &context, digest, value, + actual ) ) != 0 ) + goto err_encode; /* Verify the signature */ - if ( memcmp ( actual, expected, context->max_len ) != 0 ) { - DBGC ( context, "RSA %p signature verification failed\n", - context ); - return -EACCES_VERIFY; + if ( memcmp ( actual, expected, context.max_len ) != 0 ) { + DBGC ( &context, "RSA %p signature verification failed\n", + &context ); + rc = -EACCES_VERIFY; + goto err_verify; } - DBGC ( context, "RSA %p signature verified successfully\n", context ); - return 0; -} + /* Free context */ + rsa_free ( &context ); -/** - * Finalise RSA cipher - * - * @v ctx RSA context - */ -static void rsa_final ( void *ctx ) { - struct rsa_context *context = ctx; + DBGC ( &context, "RSA %p signature verified successfully\n", &context ); + return 0; - rsa_free ( context ); + err_verify: + err_encode: + err_sanity: + rsa_free ( &context ); + err_init: + return rc; } /** @@ -615,14 +699,11 @@ static int rsa_match ( const struct asn1_cursor *private_key, /** RSA public-key algorithm */ struct pubkey_algorithm rsa_algorithm = { .name = "rsa", - .ctxsize = RSA_CTX_SIZE, - .init = rsa_init, .max_len = rsa_max_len, .encrypt = rsa_encrypt, .decrypt = rsa_decrypt, .sign = rsa_sign, .verify = rsa_verify, - .final = rsa_final, .match = rsa_match, }; |