diff options
author | Michael Brown <mcb30@ipxe.org> | 2024-11-25 15:59:22 +0000 |
---|---|---|
committer | Michael Brown <mcb30@ipxe.org> | 2024-11-28 15:06:01 +0000 |
commit | 83ac98ce22b5b735cba4d1a21db8cc8e8648dfa4 (patch) | |
tree | e226bd3863e9b0a1d666a7f5656431f6b069b881 /src/crypto/bigint.c | |
parent | 4f7dd7fbba205d413cf9b989f7cdc928fa02caf2 (diff) | |
download | ipxe-83ac98ce22b5b735cba4d1a21db8cc8e8648dfa4.tar.gz |
[crypto] Use Montgomery reduction for modular exponentiation
Speed up modular exponentiation by using Montgomery reduction rather
than direct modular reduction.
Montgomery reduction in base 2^n requires the modulus to be coprime to
2^n, which would limit us to requiring that the modulus is an odd
number. Extend the implementation to include support for
exponentiation with even moduli via Garner's algorithm as described in
"Montgomery reduction with even modulus" (KoƧ, 1994).
Since almost all use cases for modular exponentation require a large
prime (and hence odd) modulus, the support for even moduli could
potentially be removed in future.
Signed-off-by: Michael Brown <mcb30@ipxe.org>
Diffstat (limited to 'src/crypto/bigint.c')
-rw-r--r-- | src/crypto/bigint.c | 147 |
1 files changed, 132 insertions, 15 deletions
diff --git a/src/crypto/bigint.c b/src/crypto/bigint.c index 6d75fbe9b..39e1a25cd 100644 --- a/src/crypto/bigint.c +++ b/src/crypto/bigint.c @@ -505,25 +505,142 @@ void bigint_mod_exp_raw ( const bigint_element_t *base0, *exponent = ( ( const void * ) exponent0 ); bigint_t ( size ) __attribute__ (( may_alias )) *result = ( ( void * ) result0 ); - size_t mod_multiply_len = bigint_mod_multiply_tmp_len ( modulus ); + const unsigned int width = ( 8 * sizeof ( bigint_element_t ) ); struct { - bigint_t ( size ) base; - bigint_t ( exponent_size ) exponent; - uint8_t mod_multiply[mod_multiply_len]; + union { + bigint_t ( 2 * size ) padded_modulus; + struct { + bigint_t ( size ) modulus; + bigint_t ( size ) stash; + }; + }; + union { + bigint_t ( 2 * size ) full; + bigint_t ( size ) low; + } product; } *temp = tmp; - static const uint8_t start[1] = { 0x01 }; + const uint8_t one[1] = { 1 }; + bigint_t ( 1 ) modinv; + bigint_element_t submask; + unsigned int subsize; + unsigned int scale; + unsigned int max; + unsigned int bit; + + /* Sanity check */ + assert ( sizeof ( *temp ) == bigint_mod_exp_tmp_len ( modulus ) ); + + /* Handle degenerate case of zero modulus */ + if ( ! bigint_max_set_bit ( modulus ) ) { + memset ( result, 0, sizeof ( *result ) ); + return; + } - memcpy ( &temp->base, base, sizeof ( temp->base ) ); - memcpy ( &temp->exponent, exponent, sizeof ( temp->exponent ) ); - bigint_init ( result, start, sizeof ( start ) ); + /* Factor modulus as (N * 2^scale) where N is odd */ + bigint_grow ( modulus, &temp->padded_modulus ); + for ( scale = 0 ; ( ! bigint_bit_is_set ( &temp->modulus, 0 ) ) ; + scale++ ) { + bigint_shr ( &temp->modulus ); + } + subsize = ( ( scale + width - 1 ) / width ); + submask = ( ( 1UL << ( scale % width ) ) - 1 ); + if ( ! submask ) + submask = ~submask; + + /* Calculate inverse of (scaled) modulus N modulo element size */ + bigint_mod_invert ( &temp->modulus, &modinv ); + + /* Calculate (R^2 mod N) via direct reduction of (R^2 - N) */ + memset ( &temp->product.full, 0, sizeof ( temp->product.full ) ); + bigint_subtract ( &temp->padded_modulus, &temp->product.full ); + bigint_reduce ( &temp->padded_modulus, &temp->product.full ); + bigint_copy ( &temp->product.low, &temp->stash ); + + /* Initialise result = Montgomery(1, R^2 mod N) */ + bigint_montgomery ( &temp->modulus, &modinv, + &temp->product.full, result ); + + /* Convert base into Montgomery form */ + bigint_multiply ( base, &temp->stash, &temp->product.full ); + bigint_montgomery ( &temp->modulus, &modinv, &temp->product.full, + &temp->stash ); + + /* Calculate x1 = base^exponent modulo N */ + max = bigint_max_set_bit ( exponent ); + for ( bit = 1 ; bit <= max ; bit++ ) { + + /* Square (and reduce) */ + bigint_multiply ( result, result, &temp->product.full ); + bigint_montgomery ( &temp->modulus, &modinv, + &temp->product.full, result ); + + /* Multiply (and reduce) */ + bigint_multiply ( &temp->stash, result, &temp->product.full ); + bigint_montgomery ( &temp->modulus, &modinv, + &temp->product.full, &temp->product.low ); + + /* Conditionally swap the multiplied result */ + bigint_swap ( result, &temp->product.low, + bigint_bit_is_set ( exponent, ( max - bit ) ) ); + } - while ( ! bigint_is_zero ( &temp->exponent ) ) { - if ( bigint_bit_is_set ( &temp->exponent, 0 ) ) { - bigint_mod_multiply ( result, &temp->base, modulus, - result, temp->mod_multiply ); + /* Convert back out of Montgomery form */ + bigint_grow ( result, &temp->product.full ); + bigint_montgomery ( &temp->modulus, &modinv, &temp->product.full, + result ); + + /* Handle even moduli via Garner's algorithm */ + if ( subsize ) { + const bigint_t ( subsize ) __attribute__ (( may_alias )) + *subbase = ( ( const void * ) base ); + bigint_t ( subsize ) __attribute__ (( may_alias )) + *submodulus = ( ( void * ) &temp->modulus ); + bigint_t ( subsize ) __attribute__ (( may_alias )) + *substash = ( ( void * ) &temp->stash ); + bigint_t ( subsize ) __attribute__ (( may_alias )) + *subresult = ( ( void * ) result ); + union { + bigint_t ( 2 * subsize ) full; + bigint_t ( subsize ) low; + } __attribute__ (( may_alias )) + *subproduct = ( ( void * ) &temp->product.full ); + + /* Calculate x2 = base^exponent modulo 2^k */ + bigint_init ( substash, one, sizeof ( one ) ); + for ( bit = 1 ; bit <= max ; bit++ ) { + + /* Square (and reduce) */ + bigint_multiply ( substash, substash, + &subproduct->full ); + bigint_copy ( &subproduct->low, substash ); + + /* Multiply (and reduce) */ + bigint_multiply ( subbase, substash, + &subproduct->full ); + + /* Conditionally swap the multiplied result */ + bigint_swap ( substash, &subproduct->low, + bigint_bit_is_set ( exponent, + ( max - bit ) ) ); } - bigint_shr ( &temp->exponent ); - bigint_mod_multiply ( &temp->base, &temp->base, modulus, - &temp->base, temp->mod_multiply ); + + /* Calculate N^-1 modulo 2^k */ + bigint_mod_invert ( submodulus, &subproduct->low ); + bigint_copy ( &subproduct->low, submodulus ); + + /* Calculate y = (x2 - x1) * N^-1 modulo 2^k */ + bigint_subtract ( subresult, substash ); + bigint_multiply ( substash, submodulus, &subproduct->full ); + subproduct->low.element[ subsize - 1 ] &= submask; + bigint_grow ( &subproduct->low, &temp->stash ); + + /* Reconstruct N */ + bigint_mod_invert ( submodulus, &subproduct->low ); + bigint_copy ( &subproduct->low, submodulus ); + + /* Calculate x = x1 + N * y */ + bigint_multiply ( &temp->modulus, &temp->stash, + &temp->product.full ); + bigint_add ( &temp->product.low, result ); } } |