diff options
Diffstat (limited to 'src/crypto/bigint.c')
-rw-r--r-- | src/crypto/bigint.c | 248 |
1 files changed, 102 insertions, 146 deletions
diff --git a/src/crypto/bigint.c b/src/crypto/bigint.c index ad22af771..9ccd9ff88 100644 --- a/src/crypto/bigint.c +++ b/src/crypto/bigint.c @@ -27,7 +27,6 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL ); #include <string.h> #include <assert.h> #include <stdio.h> -#include <ipxe/profile.h> #include <ipxe/bigint.h> /** @file @@ -35,10 +34,6 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL ); * Big integer support */ -/** Modular direct reduction profiler */ -static struct profiler bigint_mod_profiler __profiler = - { .name = "bigint_mod" }; - /** Minimum number of least significant bytes included in transcription */ #define BIGINT_NTOA_LSB_MIN 16 @@ -180,172 +175,136 @@ void bigint_multiply_raw ( const bigint_element_t *multiplicand0, } /** - * Reduce big integer + * Reduce big integer R^2 modulo N * * @v modulus0 Element 0 of big integer modulus - * @v value0 Element 0 of big integer to be reduced - * @v size Number of elements in modulus and value + * @v result0 Element 0 of big integer to hold result + * @v size Number of elements in modulus and result + * + * Reduce the value R^2 modulo N, where R=2^n and n is the number of + * bits in the representation of the modulus N, including any leading + * zero bits. */ -void bigint_reduce_raw ( bigint_element_t *modulus0, bigint_element_t *value0, - unsigned int size ) { - bigint_t ( size ) __attribute__ (( may_alias )) - *modulus = ( ( void * ) modulus0 ); +void bigint_reduce_raw ( const bigint_element_t *modulus0, + bigint_element_t *result0, unsigned int size ) { + const bigint_t ( size ) __attribute__ (( may_alias )) + *modulus = ( ( const void * ) modulus0 ); bigint_t ( size ) __attribute__ (( may_alias )) - *value = ( ( void * ) value0 ); + *result = ( ( void * ) result0 ); const unsigned int width = ( 8 * sizeof ( bigint_element_t ) ); - bigint_element_t *element; - unsigned int modulus_max; - unsigned int value_max; - unsigned int subshift; - int offset; - int shift; + unsigned int shift; + int max; + int sign; int msb; - int i; - - /* Start profiling */ - profile_start ( &bigint_mod_profiler ); + int carry; - /* Normalise the modulus + /* We have the constants: * - * Scale the modulus by shifting left such that both modulus - * "m" and value "x" have the same most significant set bit. - * (If this is not possible, then the value is already less - * than the modulus, and we may therefore skip reduction - * completely.) - */ - value_max = bigint_max_set_bit ( value ); - modulus_max = bigint_max_set_bit ( modulus ); - shift = ( value_max - modulus_max ); - if ( shift < 0 ) - goto skip; - subshift = ( shift & ( width - 1 ) ); - offset = ( shift / width ); - element = modulus->element; - for ( i = ( ( value_max - 1 ) / width ) ; ; i-- ) { - element[i] = ( element[ i - offset ] << subshift ); - if ( i <= offset ) - break; - if ( subshift ) { - element[i] |= ( element[ i - offset - 1 ] - >> ( width - subshift ) ); - } - } - for ( i-- ; i >= 0 ; i-- ) - element[i] = 0; - - /* Reduce the value "x" by iteratively adding or subtracting - * the scaled modulus "m". + * N = modulus * - * On each loop iteration, we maintain the invariant: + * n = number of bits in the modulus (including any leading zeros) * - * -2m <= x < 2m + * R = 2^n * - * If x is positive, we obtain the new value x' by - * subtracting m, otherwise we add m: + * Let r be the extension of the n-bit result register by a + * separate two's complement sign bit, such that -R <= r < R, + * and define: * - * 0 <= x < 2m => x' := x - m => -m <= x' < m - * -2m <= x < 0 => x' := x + m => -m <= x' < m + * x = r * 2^k * - * and then halve the modulus (by shifting right): + * as the value being reduced modulo N, where k is a + * non-negative integer bit shift. * - * m' = m/2 + * We want to reduce the initial value R^2=2^(2n), which we + * may trivially represent using r=1 and k=2n. * - * We therefore end up with: + * We then iterate over decrementing k, maintaining the loop + * invariant: * - * -m <= x' < m => -2m' <= x' < 2m' + * -N <= r < N * - * i.e. we have preseved the invariant while reducing the - * bounds on x' by one power of two. + * On each iteration we must first double r, to compensate for + * having decremented k: * - * The issue remains of how to determine on each iteration - * whether or not x is currently positive, given that both - * input values are unsigned big integers that may use all - * available bits (including the MSB). + * k' = k - 1 * - * On the first loop iteration, we may simply assume that x is - * positive, since it is unmodified from the input value and - * so is positive by definition (even if the MSB is set). We - * therefore unconditionally perform a subtraction on the - * first loop iteration. + * r' = 2r * - * Let k be the MSB after normalisation. We then have: + * x = r * 2^k = 2r * 2^(k-1) = r' * 2^k' * - * 2^k <= m < 2^(k+1) - * 2^k <= x < 2^(k+1) + * Note that doubling the n-bit result register will create a + * value of n+1 bits: this extra bit needs to be handled + * separately during the calculation. * - * On the first loop iteration, we therefore have: + * We then subtract N (if r is currently non-negative) or add + * N (if r is currently negative) to restore the loop + * invariant: * - * x' = (x - m) - * < 2^(k+1) - 2^k - * < 2^k + * 0 <= r < N => r" = 2r - N => -N <= r" < N + * -N <= r < 0 => r" = 2r + N => -N <= r" < N * - * Any positive value of x' therefore has its MSB set to zero, - * and so we may validly treat the MSB of x' as a sign bit at - * the end of the first loop iteration. + * Note that since N may use all n bits, the most significant + * bit of the n-bit result register is not a valid two's + * complement sign bit for r: the extra sign bit therefore + * also needs to be handled separately. * - * On all subsequent loop iterations, the starting value m is - * guaranteed to have its MSB set to zero (since it has - * already been shifted right at least once). Since we know - * from above that we preserve the loop invariant: + * Once we reach k=0, we have x=r and therefore: * - * -m <= x' < m + * -N <= x < N * - * we immediately know that any positive value of x' also has - * its MSB set to zero, and so we may validly treat the MSB of - * x' as a sign bit at the end of all subsequent loop - * iterations. + * After this last loop iteration (with k=0), we may need to + * add a single multiple of N to ensure that x is positive, + * i.e. lies within the range 0 <= x < N. * - * After the last loop iteration (when m' has been shifted - * back down to the original value of the modulus), we may - * need to add a single multiple of m' to ensure that x' is - * positive, i.e. lies within the range 0 <= x' < m'. To - * allow for reusing the (inlined) expansion of - * bigint_subtract(), we achieve this via a potential - * additional loop iteration that performs the addition and is - * then guaranteed to terminate (since the result will be - * positive). + * Since neither the modulus nor the value R^2 are secret, we + * may elide approximately half of the total number of + * iterations by constructing the initial representation of + * R^2 as r=2^m and k=2n-m (for some m such that 2^m < N). */ - for ( msb = 0 ; ( msb || ( shift >= 0 ) ) ; shift-- ) { - if ( msb ) { - bigint_add ( modulus, value ); - } else { - bigint_subtract ( modulus, value ); - } - msb = bigint_msb_is_set ( value ); - if ( shift > 0 ) - bigint_shr ( modulus ); + + /* Initialise x=R^2 */ + memset ( result, 0, sizeof ( *result ) ); + max = ( bigint_max_set_bit ( modulus ) - 2 ); + if ( max < 0 ) { + /* Degenerate case of N=0 or N=1: return a zero result */ + return; } + bigint_set_bit ( result, max ); + shift = ( ( 2 * size * width ) - max ); + sign = 0; - skip: - /* Sanity check */ - assert ( ! bigint_is_geq ( value, modulus ) ); + /* Iterate as described above */ + while ( shift-- ) { - /* Stop profiling */ - profile_stop ( &bigint_mod_profiler ); -} + /* Calculate 2r, storing extra bit separately */ + msb = bigint_shl ( result ); -/** - * Reduce supremum of big integer representation - * - * @v modulus0 Element 0 of big integer modulus - * @v result0 Element 0 of big integer to hold result - * @v size Number of elements in modulus and value - * - * Reduce the value 2^k (where k is the bit width of the big integer - * representation) modulo the specified modulus. - */ -void bigint_reduce_supremum_raw ( bigint_element_t *modulus0, - bigint_element_t *result0, - unsigned int size ) { - bigint_t ( size ) __attribute__ (( may_alias )) - *modulus = ( ( void * ) modulus0 ); - bigint_t ( size ) __attribute__ (( may_alias )) - *result = ( ( void * ) result0 ); + /* Add or subtract N according to current sign */ + if ( sign ) { + carry = bigint_add ( modulus, result ); + } else { + carry = bigint_subtract ( modulus, result ); + } - /* Calculate (2^k) mod N via direct reduction of (2^k - N) mod N */ - memset ( result, 0, sizeof ( *result ) ); - bigint_subtract ( modulus, result ); - bigint_reduce ( modulus, result ); + /* Calculate new sign of result + * + * We know the result lies in the range -N <= r < N + * and so the tuple (old sign, msb, carry) cannot ever + * take the values (0, 1, 0) or (1, 0, 0). We can + * therefore treat these as don't-care inputs, which + * allows us to simplify the boolean expression by + * ignoring the old sign completely. + */ + assert ( ( sign == msb ) || carry ); + sign = ( msb ^ carry ); + } + + /* Add N to make result positive if necessary */ + if ( sign ) + bigint_add ( modulus, result ); + + /* Sanity check */ + assert ( ! bigint_is_geq ( result, modulus ) ); } /** @@ -805,12 +764,9 @@ void bigint_mod_exp_raw ( const bigint_element_t *base0, ( ( void * ) result0 ); const unsigned int width = ( 8 * sizeof ( bigint_element_t ) ); struct { - union { - bigint_t ( 2 * size ) padded_modulus; - struct { - bigint_t ( size ) modulus; - bigint_t ( size ) stash; - }; + struct { + bigint_t ( size ) modulus; + bigint_t ( size ) stash; }; union { bigint_t ( 2 * size ) full; @@ -833,7 +789,7 @@ void bigint_mod_exp_raw ( const bigint_element_t *base0, } /* Factor modulus as (N * 2^scale) where N is odd */ - bigint_grow ( modulus, &temp->padded_modulus ); + bigint_copy ( modulus, &temp->modulus ); for ( scale = 0 ; ( ! bigint_bit_is_set ( &temp->modulus, 0 ) ) ; scale++ ) { bigint_shr ( &temp->modulus ); @@ -844,10 +800,10 @@ void bigint_mod_exp_raw ( const bigint_element_t *base0, submask = ~submask; /* Calculate (R^2 mod N) */ - bigint_reduce_supremum ( &temp->padded_modulus, &temp->product.full ); - bigint_copy ( &temp->product.low, &temp->stash ); + bigint_reduce ( &temp->modulus, &temp->stash ); /* Initialise result = Montgomery(1, R^2 mod N) */ + bigint_grow ( &temp->stash, &temp->product.full ); bigint_montgomery ( &temp->modulus, &temp->product.full, result ); /* Convert base into Montgomery form */ |