diff options
author | Michael Brown <mcb30@ipxe.org> | 2024-11-25 15:59:22 +0000 |
---|---|---|
committer | Michael Brown <mcb30@ipxe.org> | 2024-11-28 15:06:01 +0000 |
commit | 83ac98ce22b5b735cba4d1a21db8cc8e8648dfa4 (patch) | |
tree | e226bd3863e9b0a1d666a7f5656431f6b069b881 | |
parent | 4f7dd7fbba205d413cf9b989f7cdc928fa02caf2 (diff) | |
download | ipxe-83ac98ce22b5b735cba4d1a21db8cc8e8648dfa4.tar.gz |
[crypto] Use Montgomery reduction for modular exponentiation
Speed up modular exponentiation by using Montgomery reduction rather
than direct modular reduction.
Montgomery reduction in base 2^n requires the modulus to be coprime to
2^n, which would limit us to requiring that the modulus is an odd
number. Extend the implementation to include support for
exponentiation with even moduli via Garner's algorithm as described in
"Montgomery reduction with even modulus" (KoƧ, 1994).
Since almost all use cases for modular exponentation require a large
prime (and hence odd) modulus, the support for even moduli could
potentially be removed in future.
Signed-off-by: Michael Brown <mcb30@ipxe.org>
-rw-r--r-- | src/crypto/bigint.c | 147 | ||||
-rw-r--r-- | src/crypto/dhe.c | 3 | ||||
-rw-r--r-- | src/crypto/rsa.c | 3 | ||||
-rw-r--r-- | src/include/ipxe/bigint.h | 10 | ||||
-rw-r--r-- | src/tests/bigint_test.c | 30 |
5 files changed, 164 insertions, 29 deletions
diff --git a/src/crypto/bigint.c b/src/crypto/bigint.c index 6d75fbe9b..39e1a25cd 100644 --- a/src/crypto/bigint.c +++ b/src/crypto/bigint.c @@ -505,25 +505,142 @@ void bigint_mod_exp_raw ( const bigint_element_t *base0, *exponent = ( ( const void * ) exponent0 ); bigint_t ( size ) __attribute__ (( may_alias )) *result = ( ( void * ) result0 ); - size_t mod_multiply_len = bigint_mod_multiply_tmp_len ( modulus ); + const unsigned int width = ( 8 * sizeof ( bigint_element_t ) ); struct { - bigint_t ( size ) base; - bigint_t ( exponent_size ) exponent; - uint8_t mod_multiply[mod_multiply_len]; + union { + bigint_t ( 2 * size ) padded_modulus; + struct { + bigint_t ( size ) modulus; + bigint_t ( size ) stash; + }; + }; + union { + bigint_t ( 2 * size ) full; + bigint_t ( size ) low; + } product; } *temp = tmp; - static const uint8_t start[1] = { 0x01 }; + const uint8_t one[1] = { 1 }; + bigint_t ( 1 ) modinv; + bigint_element_t submask; + unsigned int subsize; + unsigned int scale; + unsigned int max; + unsigned int bit; + + /* Sanity check */ + assert ( sizeof ( *temp ) == bigint_mod_exp_tmp_len ( modulus ) ); + + /* Handle degenerate case of zero modulus */ + if ( ! bigint_max_set_bit ( modulus ) ) { + memset ( result, 0, sizeof ( *result ) ); + return; + } - memcpy ( &temp->base, base, sizeof ( temp->base ) ); - memcpy ( &temp->exponent, exponent, sizeof ( temp->exponent ) ); - bigint_init ( result, start, sizeof ( start ) ); + /* Factor modulus as (N * 2^scale) where N is odd */ + bigint_grow ( modulus, &temp->padded_modulus ); + for ( scale = 0 ; ( ! bigint_bit_is_set ( &temp->modulus, 0 ) ) ; + scale++ ) { + bigint_shr ( &temp->modulus ); + } + subsize = ( ( scale + width - 1 ) / width ); + submask = ( ( 1UL << ( scale % width ) ) - 1 ); + if ( ! submask ) + submask = ~submask; + + /* Calculate inverse of (scaled) modulus N modulo element size */ + bigint_mod_invert ( &temp->modulus, &modinv ); + + /* Calculate (R^2 mod N) via direct reduction of (R^2 - N) */ + memset ( &temp->product.full, 0, sizeof ( temp->product.full ) ); + bigint_subtract ( &temp->padded_modulus, &temp->product.full ); + bigint_reduce ( &temp->padded_modulus, &temp->product.full ); + bigint_copy ( &temp->product.low, &temp->stash ); + + /* Initialise result = Montgomery(1, R^2 mod N) */ + bigint_montgomery ( &temp->modulus, &modinv, + &temp->product.full, result ); + + /* Convert base into Montgomery form */ + bigint_multiply ( base, &temp->stash, &temp->product.full ); + bigint_montgomery ( &temp->modulus, &modinv, &temp->product.full, + &temp->stash ); + + /* Calculate x1 = base^exponent modulo N */ + max = bigint_max_set_bit ( exponent ); + for ( bit = 1 ; bit <= max ; bit++ ) { + + /* Square (and reduce) */ + bigint_multiply ( result, result, &temp->product.full ); + bigint_montgomery ( &temp->modulus, &modinv, + &temp->product.full, result ); + + /* Multiply (and reduce) */ + bigint_multiply ( &temp->stash, result, &temp->product.full ); + bigint_montgomery ( &temp->modulus, &modinv, + &temp->product.full, &temp->product.low ); + + /* Conditionally swap the multiplied result */ + bigint_swap ( result, &temp->product.low, + bigint_bit_is_set ( exponent, ( max - bit ) ) ); + } - while ( ! bigint_is_zero ( &temp->exponent ) ) { - if ( bigint_bit_is_set ( &temp->exponent, 0 ) ) { - bigint_mod_multiply ( result, &temp->base, modulus, - result, temp->mod_multiply ); + /* Convert back out of Montgomery form */ + bigint_grow ( result, &temp->product.full ); + bigint_montgomery ( &temp->modulus, &modinv, &temp->product.full, + result ); + + /* Handle even moduli via Garner's algorithm */ + if ( subsize ) { + const bigint_t ( subsize ) __attribute__ (( may_alias )) + *subbase = ( ( const void * ) base ); + bigint_t ( subsize ) __attribute__ (( may_alias )) + *submodulus = ( ( void * ) &temp->modulus ); + bigint_t ( subsize ) __attribute__ (( may_alias )) + *substash = ( ( void * ) &temp->stash ); + bigint_t ( subsize ) __attribute__ (( may_alias )) + *subresult = ( ( void * ) result ); + union { + bigint_t ( 2 * subsize ) full; + bigint_t ( subsize ) low; + } __attribute__ (( may_alias )) + *subproduct = ( ( void * ) &temp->product.full ); + + /* Calculate x2 = base^exponent modulo 2^k */ + bigint_init ( substash, one, sizeof ( one ) ); + for ( bit = 1 ; bit <= max ; bit++ ) { + + /* Square (and reduce) */ + bigint_multiply ( substash, substash, + &subproduct->full ); + bigint_copy ( &subproduct->low, substash ); + + /* Multiply (and reduce) */ + bigint_multiply ( subbase, substash, + &subproduct->full ); + + /* Conditionally swap the multiplied result */ + bigint_swap ( substash, &subproduct->low, + bigint_bit_is_set ( exponent, + ( max - bit ) ) ); } - bigint_shr ( &temp->exponent ); - bigint_mod_multiply ( &temp->base, &temp->base, modulus, - &temp->base, temp->mod_multiply ); + + /* Calculate N^-1 modulo 2^k */ + bigint_mod_invert ( submodulus, &subproduct->low ); + bigint_copy ( &subproduct->low, submodulus ); + + /* Calculate y = (x2 - x1) * N^-1 modulo 2^k */ + bigint_subtract ( subresult, substash ); + bigint_multiply ( substash, submodulus, &subproduct->full ); + subproduct->low.element[ subsize - 1 ] &= submask; + bigint_grow ( &subproduct->low, &temp->stash ); + + /* Reconstruct N */ + bigint_mod_invert ( submodulus, &subproduct->low ); + bigint_copy ( &subproduct->low, submodulus ); + + /* Calculate x = x1 + N * y */ + bigint_multiply ( &temp->modulus, &temp->stash, + &temp->product.full ); + bigint_add ( &temp->product.low, result ); } } diff --git a/src/crypto/dhe.c b/src/crypto/dhe.c index 2da107d24..a249f9b40 100644 --- a/src/crypto/dhe.c +++ b/src/crypto/dhe.c @@ -57,8 +57,7 @@ int dhe_key ( const void *modulus, size_t len, const void *generator, unsigned int size = bigint_required_size ( len ); unsigned int private_size = bigint_required_size ( private_len ); bigint_t ( size ) *mod; - bigint_t ( private_size ) *exp; - size_t tmp_len = bigint_mod_exp_tmp_len ( mod, exp ); + size_t tmp_len = bigint_mod_exp_tmp_len ( mod ); struct { bigint_t ( size ) modulus; bigint_t ( size ) generator; diff --git a/src/crypto/rsa.c b/src/crypto/rsa.c index 19472c121..44041da3e 100644 --- a/src/crypto/rsa.c +++ b/src/crypto/rsa.c @@ -109,8 +109,7 @@ static int rsa_alloc ( struct rsa_context *context, size_t modulus_len, unsigned int size = bigint_required_size ( modulus_len ); unsigned int exponent_size = bigint_required_size ( exponent_len ); bigint_t ( size ) *modulus; - bigint_t ( exponent_size ) *exponent; - size_t tmp_len = bigint_mod_exp_tmp_len ( modulus, exponent ); + size_t tmp_len = bigint_mod_exp_tmp_len ( modulus ); struct { bigint_t ( size ) modulus; bigint_t ( exponent_size ) exponent; diff --git a/src/include/ipxe/bigint.h b/src/include/ipxe/bigint.h index 6c9730252..3ca871962 100644 --- a/src/include/ipxe/bigint.h +++ b/src/include/ipxe/bigint.h @@ -322,18 +322,12 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL ); * Calculate temporary working space required for moduluar exponentiation * * @v modulus Big integer modulus - * @v exponent Big integer exponent * @ret len Length of temporary working space */ -#define bigint_mod_exp_tmp_len( modulus, exponent ) ( { \ +#define bigint_mod_exp_tmp_len( modulus ) ( { \ unsigned int size = bigint_size (modulus); \ - unsigned int exponent_size = bigint_size (exponent); \ - size_t mod_multiply_len = \ - bigint_mod_multiply_tmp_len (modulus); \ sizeof ( struct { \ - bigint_t ( size ) temp_base; \ - bigint_t ( exponent_size ) temp_exponent; \ - uint8_t mod_multiply[mod_multiply_len]; \ + bigint_t ( size ) temp[4]; \ } ); } ) #include <bits/bigint.h> diff --git a/src/tests/bigint_test.c b/src/tests/bigint_test.c index 1f2f5f244..f3291f6a6 100644 --- a/src/tests/bigint_test.c +++ b/src/tests/bigint_test.c @@ -746,8 +746,7 @@ void bigint_mod_exp_sample ( const bigint_element_t *base0, bigint_t ( size ) modulus_temp; \ bigint_t ( exponent_size ) exponent_temp; \ bigint_t ( size ) result_temp; \ - size_t tmp_len = bigint_mod_exp_tmp_len ( &modulus_temp, \ - &exponent_temp ); \ + size_t tmp_len = bigint_mod_exp_tmp_len ( &modulus_temp ); \ uint8_t tmp[tmp_len]; \ {} /* Fix emacs alignment */ \ \ @@ -2070,6 +2069,14 @@ static void bigint_test_exec ( void ) { BIGINT ( 0xb9 ), BIGINT ( 0x39, 0x68, 0xba, 0x7d ), BIGINT ( 0x17 ) ); + bigint_mod_exp_ok ( BIGINT ( 0x71, 0x4d, 0x02, 0xe9 ), + BIGINT ( 0x00, 0x00, 0x00, 0x00 ), + BIGINT ( 0x91, 0x7f, 0x4e, 0x3a, 0x5d, 0x5c ), + BIGINT ( 0x00, 0x00, 0x00, 0x00 ) ); + bigint_mod_exp_ok ( BIGINT ( 0x2b, 0xf5, 0x07, 0xaf ), + BIGINT ( 0x6e, 0xb5, 0xda, 0x5a ), + BIGINT ( 0x00, 0x00, 0x00, 0x00, 0x00 ), + BIGINT ( 0x00, 0x00, 0x00, 0x01 ) ); bigint_mod_exp_ok ( BIGINT ( 0x2e ), BIGINT ( 0xb7 ), BIGINT ( 0x39, 0x07, 0x1b, 0x49, 0x5b, 0xea, @@ -2774,6 +2781,25 @@ static void bigint_test_exec ( void ) { 0xfa, 0x83, 0xd4, 0x7c, 0xe9, 0x77, 0x46, 0x91, 0x3a, 0x50, 0x0d, 0x6a, 0x25, 0xd0 ) ); + bigint_mod_exp_ok ( BIGINT ( 0x5b, 0x80, 0xc5, 0x03, 0xb3, 0x1e, + 0x46, 0x9b, 0xa3, 0x0a, 0x70, 0x43, + 0x51, 0x2a, 0x4a, 0x44, 0xcb, 0x87, + 0x3e, 0x00, 0x2a, 0x48, 0x46, 0xf5, + 0xb3, 0xb9, 0x73, 0xa7, 0x77, 0xfc, + 0x2a, 0x1d ), + BIGINT ( 0x5e, 0x8c, 0x80, 0x03, 0xe7, 0xb0, + 0x45, 0x23, 0x8f, 0xe0, 0x77, 0x02, + 0xc0, 0x7e, 0xfb, 0xc4, 0xbe, 0x7b, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00 ), + BIGINT ( 0x71, 0xd9, 0x38, 0xb6 ), + BIGINT ( 0x52, 0xfc, 0x73, 0x55, 0x2f, 0x86, + 0x0f, 0xde, 0x04, 0xbc, 0x6d, 0xb8, + 0xfd, 0x48, 0xf8, 0x8c, 0x91, 0x1c, + 0xa0, 0x8a, 0x70, 0xa8, 0xc6, 0x20, + 0x0a, 0x0d, 0x3b, 0x2a, 0x92, 0x65, + 0x9c, 0x59 ) ); } /** Big integer self-test */ |