aboutsummaryrefslogtreecommitdiffstats
path: root/src/crypto/bigint.c
diff options
context:
space:
mode:
authorMichael Brown <mcb30@ipxe.org>2024-11-25 15:59:22 +0000
committerMichael Brown <mcb30@ipxe.org>2024-11-28 15:06:01 +0000
commit83ac98ce22b5b735cba4d1a21db8cc8e8648dfa4 (patch)
treee226bd3863e9b0a1d666a7f5656431f6b069b881 /src/crypto/bigint.c
parent4f7dd7fbba205d413cf9b989f7cdc928fa02caf2 (diff)
downloadipxe-83ac98ce22b5b735cba4d1a21db8cc8e8648dfa4.tar.gz
[crypto] Use Montgomery reduction for modular exponentiation
Speed up modular exponentiation by using Montgomery reduction rather than direct modular reduction. Montgomery reduction in base 2^n requires the modulus to be coprime to 2^n, which would limit us to requiring that the modulus is an odd number. Extend the implementation to include support for exponentiation with even moduli via Garner's algorithm as described in "Montgomery reduction with even modulus" (KoƧ, 1994). Since almost all use cases for modular exponentation require a large prime (and hence odd) modulus, the support for even moduli could potentially be removed in future. Signed-off-by: Michael Brown <mcb30@ipxe.org>
Diffstat (limited to 'src/crypto/bigint.c')
-rw-r--r--src/crypto/bigint.c147
1 files changed, 132 insertions, 15 deletions
diff --git a/src/crypto/bigint.c b/src/crypto/bigint.c
index 6d75fbe9b..39e1a25cd 100644
--- a/src/crypto/bigint.c
+++ b/src/crypto/bigint.c
@@ -505,25 +505,142 @@ void bigint_mod_exp_raw ( const bigint_element_t *base0,
*exponent = ( ( const void * ) exponent0 );
bigint_t ( size ) __attribute__ (( may_alias )) *result =
( ( void * ) result0 );
- size_t mod_multiply_len = bigint_mod_multiply_tmp_len ( modulus );
+ const unsigned int width = ( 8 * sizeof ( bigint_element_t ) );
struct {
- bigint_t ( size ) base;
- bigint_t ( exponent_size ) exponent;
- uint8_t mod_multiply[mod_multiply_len];
+ union {
+ bigint_t ( 2 * size ) padded_modulus;
+ struct {
+ bigint_t ( size ) modulus;
+ bigint_t ( size ) stash;
+ };
+ };
+ union {
+ bigint_t ( 2 * size ) full;
+ bigint_t ( size ) low;
+ } product;
} *temp = tmp;
- static const uint8_t start[1] = { 0x01 };
+ const uint8_t one[1] = { 1 };
+ bigint_t ( 1 ) modinv;
+ bigint_element_t submask;
+ unsigned int subsize;
+ unsigned int scale;
+ unsigned int max;
+ unsigned int bit;
+
+ /* Sanity check */
+ assert ( sizeof ( *temp ) == bigint_mod_exp_tmp_len ( modulus ) );
+
+ /* Handle degenerate case of zero modulus */
+ if ( ! bigint_max_set_bit ( modulus ) ) {
+ memset ( result, 0, sizeof ( *result ) );
+ return;
+ }
- memcpy ( &temp->base, base, sizeof ( temp->base ) );
- memcpy ( &temp->exponent, exponent, sizeof ( temp->exponent ) );
- bigint_init ( result, start, sizeof ( start ) );
+ /* Factor modulus as (N * 2^scale) where N is odd */
+ bigint_grow ( modulus, &temp->padded_modulus );
+ for ( scale = 0 ; ( ! bigint_bit_is_set ( &temp->modulus, 0 ) ) ;
+ scale++ ) {
+ bigint_shr ( &temp->modulus );
+ }
+ subsize = ( ( scale + width - 1 ) / width );
+ submask = ( ( 1UL << ( scale % width ) ) - 1 );
+ if ( ! submask )
+ submask = ~submask;
+
+ /* Calculate inverse of (scaled) modulus N modulo element size */
+ bigint_mod_invert ( &temp->modulus, &modinv );
+
+ /* Calculate (R^2 mod N) via direct reduction of (R^2 - N) */
+ memset ( &temp->product.full, 0, sizeof ( temp->product.full ) );
+ bigint_subtract ( &temp->padded_modulus, &temp->product.full );
+ bigint_reduce ( &temp->padded_modulus, &temp->product.full );
+ bigint_copy ( &temp->product.low, &temp->stash );
+
+ /* Initialise result = Montgomery(1, R^2 mod N) */
+ bigint_montgomery ( &temp->modulus, &modinv,
+ &temp->product.full, result );
+
+ /* Convert base into Montgomery form */
+ bigint_multiply ( base, &temp->stash, &temp->product.full );
+ bigint_montgomery ( &temp->modulus, &modinv, &temp->product.full,
+ &temp->stash );
+
+ /* Calculate x1 = base^exponent modulo N */
+ max = bigint_max_set_bit ( exponent );
+ for ( bit = 1 ; bit <= max ; bit++ ) {
+
+ /* Square (and reduce) */
+ bigint_multiply ( result, result, &temp->product.full );
+ bigint_montgomery ( &temp->modulus, &modinv,
+ &temp->product.full, result );
+
+ /* Multiply (and reduce) */
+ bigint_multiply ( &temp->stash, result, &temp->product.full );
+ bigint_montgomery ( &temp->modulus, &modinv,
+ &temp->product.full, &temp->product.low );
+
+ /* Conditionally swap the multiplied result */
+ bigint_swap ( result, &temp->product.low,
+ bigint_bit_is_set ( exponent, ( max - bit ) ) );
+ }
- while ( ! bigint_is_zero ( &temp->exponent ) ) {
- if ( bigint_bit_is_set ( &temp->exponent, 0 ) ) {
- bigint_mod_multiply ( result, &temp->base, modulus,
- result, temp->mod_multiply );
+ /* Convert back out of Montgomery form */
+ bigint_grow ( result, &temp->product.full );
+ bigint_montgomery ( &temp->modulus, &modinv, &temp->product.full,
+ result );
+
+ /* Handle even moduli via Garner's algorithm */
+ if ( subsize ) {
+ const bigint_t ( subsize ) __attribute__ (( may_alias ))
+ *subbase = ( ( const void * ) base );
+ bigint_t ( subsize ) __attribute__ (( may_alias ))
+ *submodulus = ( ( void * ) &temp->modulus );
+ bigint_t ( subsize ) __attribute__ (( may_alias ))
+ *substash = ( ( void * ) &temp->stash );
+ bigint_t ( subsize ) __attribute__ (( may_alias ))
+ *subresult = ( ( void * ) result );
+ union {
+ bigint_t ( 2 * subsize ) full;
+ bigint_t ( subsize ) low;
+ } __attribute__ (( may_alias ))
+ *subproduct = ( ( void * ) &temp->product.full );
+
+ /* Calculate x2 = base^exponent modulo 2^k */
+ bigint_init ( substash, one, sizeof ( one ) );
+ for ( bit = 1 ; bit <= max ; bit++ ) {
+
+ /* Square (and reduce) */
+ bigint_multiply ( substash, substash,
+ &subproduct->full );
+ bigint_copy ( &subproduct->low, substash );
+
+ /* Multiply (and reduce) */
+ bigint_multiply ( subbase, substash,
+ &subproduct->full );
+
+ /* Conditionally swap the multiplied result */
+ bigint_swap ( substash, &subproduct->low,
+ bigint_bit_is_set ( exponent,
+ ( max - bit ) ) );
}
- bigint_shr ( &temp->exponent );
- bigint_mod_multiply ( &temp->base, &temp->base, modulus,
- &temp->base, temp->mod_multiply );
+
+ /* Calculate N^-1 modulo 2^k */
+ bigint_mod_invert ( submodulus, &subproduct->low );
+ bigint_copy ( &subproduct->low, submodulus );
+
+ /* Calculate y = (x2 - x1) * N^-1 modulo 2^k */
+ bigint_subtract ( subresult, substash );
+ bigint_multiply ( substash, submodulus, &subproduct->full );
+ subproduct->low.element[ subsize - 1 ] &= submask;
+ bigint_grow ( &subproduct->low, &temp->stash );
+
+ /* Reconstruct N */
+ bigint_mod_invert ( submodulus, &subproduct->low );
+ bigint_copy ( &subproduct->low, submodulus );
+
+ /* Calculate x = x1 + N * y */
+ bigint_multiply ( &temp->modulus, &temp->stash,
+ &temp->product.full );
+ bigint_add ( &temp->product.low, result );
}
}