aboutsummaryrefslogtreecommitdiffstats
path: root/src/crypto/bigint.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/crypto/bigint.c')
-rw-r--r--src/crypto/bigint.c248
1 files changed, 102 insertions, 146 deletions
diff --git a/src/crypto/bigint.c b/src/crypto/bigint.c
index ad22af771..9ccd9ff88 100644
--- a/src/crypto/bigint.c
+++ b/src/crypto/bigint.c
@@ -27,7 +27,6 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL );
#include <string.h>
#include <assert.h>
#include <stdio.h>
-#include <ipxe/profile.h>
#include <ipxe/bigint.h>
/** @file
@@ -35,10 +34,6 @@ FILE_LICENCE ( GPL2_OR_LATER_OR_UBDL );
* Big integer support
*/
-/** Modular direct reduction profiler */
-static struct profiler bigint_mod_profiler __profiler =
- { .name = "bigint_mod" };
-
/** Minimum number of least significant bytes included in transcription */
#define BIGINT_NTOA_LSB_MIN 16
@@ -180,172 +175,136 @@ void bigint_multiply_raw ( const bigint_element_t *multiplicand0,
}
/**
- * Reduce big integer
+ * Reduce big integer R^2 modulo N
*
* @v modulus0 Element 0 of big integer modulus
- * @v value0 Element 0 of big integer to be reduced
- * @v size Number of elements in modulus and value
+ * @v result0 Element 0 of big integer to hold result
+ * @v size Number of elements in modulus and result
+ *
+ * Reduce the value R^2 modulo N, where R=2^n and n is the number of
+ * bits in the representation of the modulus N, including any leading
+ * zero bits.
*/
-void bigint_reduce_raw ( bigint_element_t *modulus0, bigint_element_t *value0,
- unsigned int size ) {
- bigint_t ( size ) __attribute__ (( may_alias ))
- *modulus = ( ( void * ) modulus0 );
+void bigint_reduce_raw ( const bigint_element_t *modulus0,
+ bigint_element_t *result0, unsigned int size ) {
+ const bigint_t ( size ) __attribute__ (( may_alias ))
+ *modulus = ( ( const void * ) modulus0 );
bigint_t ( size ) __attribute__ (( may_alias ))
- *value = ( ( void * ) value0 );
+ *result = ( ( void * ) result0 );
const unsigned int width = ( 8 * sizeof ( bigint_element_t ) );
- bigint_element_t *element;
- unsigned int modulus_max;
- unsigned int value_max;
- unsigned int subshift;
- int offset;
- int shift;
+ unsigned int shift;
+ int max;
+ int sign;
int msb;
- int i;
-
- /* Start profiling */
- profile_start ( &bigint_mod_profiler );
+ int carry;
- /* Normalise the modulus
+ /* We have the constants:
*
- * Scale the modulus by shifting left such that both modulus
- * "m" and value "x" have the same most significant set bit.
- * (If this is not possible, then the value is already less
- * than the modulus, and we may therefore skip reduction
- * completely.)
- */
- value_max = bigint_max_set_bit ( value );
- modulus_max = bigint_max_set_bit ( modulus );
- shift = ( value_max - modulus_max );
- if ( shift < 0 )
- goto skip;
- subshift = ( shift & ( width - 1 ) );
- offset = ( shift / width );
- element = modulus->element;
- for ( i = ( ( value_max - 1 ) / width ) ; ; i-- ) {
- element[i] = ( element[ i - offset ] << subshift );
- if ( i <= offset )
- break;
- if ( subshift ) {
- element[i] |= ( element[ i - offset - 1 ]
- >> ( width - subshift ) );
- }
- }
- for ( i-- ; i >= 0 ; i-- )
- element[i] = 0;
-
- /* Reduce the value "x" by iteratively adding or subtracting
- * the scaled modulus "m".
+ * N = modulus
*
- * On each loop iteration, we maintain the invariant:
+ * n = number of bits in the modulus (including any leading zeros)
*
- * -2m <= x < 2m
+ * R = 2^n
*
- * If x is positive, we obtain the new value x' by
- * subtracting m, otherwise we add m:
+ * Let r be the extension of the n-bit result register by a
+ * separate two's complement sign bit, such that -R <= r < R,
+ * and define:
*
- * 0 <= x < 2m => x' := x - m => -m <= x' < m
- * -2m <= x < 0 => x' := x + m => -m <= x' < m
+ * x = r * 2^k
*
- * and then halve the modulus (by shifting right):
+ * as the value being reduced modulo N, where k is a
+ * non-negative integer bit shift.
*
- * m' = m/2
+ * We want to reduce the initial value R^2=2^(2n), which we
+ * may trivially represent using r=1 and k=2n.
*
- * We therefore end up with:
+ * We then iterate over decrementing k, maintaining the loop
+ * invariant:
*
- * -m <= x' < m => -2m' <= x' < 2m'
+ * -N <= r < N
*
- * i.e. we have preseved the invariant while reducing the
- * bounds on x' by one power of two.
+ * On each iteration we must first double r, to compensate for
+ * having decremented k:
*
- * The issue remains of how to determine on each iteration
- * whether or not x is currently positive, given that both
- * input values are unsigned big integers that may use all
- * available bits (including the MSB).
+ * k' = k - 1
*
- * On the first loop iteration, we may simply assume that x is
- * positive, since it is unmodified from the input value and
- * so is positive by definition (even if the MSB is set). We
- * therefore unconditionally perform a subtraction on the
- * first loop iteration.
+ * r' = 2r
*
- * Let k be the MSB after normalisation. We then have:
+ * x = r * 2^k = 2r * 2^(k-1) = r' * 2^k'
*
- * 2^k <= m < 2^(k+1)
- * 2^k <= x < 2^(k+1)
+ * Note that doubling the n-bit result register will create a
+ * value of n+1 bits: this extra bit needs to be handled
+ * separately during the calculation.
*
- * On the first loop iteration, we therefore have:
+ * We then subtract N (if r is currently non-negative) or add
+ * N (if r is currently negative) to restore the loop
+ * invariant:
*
- * x' = (x - m)
- * < 2^(k+1) - 2^k
- * < 2^k
+ * 0 <= r < N => r" = 2r - N => -N <= r" < N
+ * -N <= r < 0 => r" = 2r + N => -N <= r" < N
*
- * Any positive value of x' therefore has its MSB set to zero,
- * and so we may validly treat the MSB of x' as a sign bit at
- * the end of the first loop iteration.
+ * Note that since N may use all n bits, the most significant
+ * bit of the n-bit result register is not a valid two's
+ * complement sign bit for r: the extra sign bit therefore
+ * also needs to be handled separately.
*
- * On all subsequent loop iterations, the starting value m is
- * guaranteed to have its MSB set to zero (since it has
- * already been shifted right at least once). Since we know
- * from above that we preserve the loop invariant:
+ * Once we reach k=0, we have x=r and therefore:
*
- * -m <= x' < m
+ * -N <= x < N
*
- * we immediately know that any positive value of x' also has
- * its MSB set to zero, and so we may validly treat the MSB of
- * x' as a sign bit at the end of all subsequent loop
- * iterations.
+ * After this last loop iteration (with k=0), we may need to
+ * add a single multiple of N to ensure that x is positive,
+ * i.e. lies within the range 0 <= x < N.
*
- * After the last loop iteration (when m' has been shifted
- * back down to the original value of the modulus), we may
- * need to add a single multiple of m' to ensure that x' is
- * positive, i.e. lies within the range 0 <= x' < m'. To
- * allow for reusing the (inlined) expansion of
- * bigint_subtract(), we achieve this via a potential
- * additional loop iteration that performs the addition and is
- * then guaranteed to terminate (since the result will be
- * positive).
+ * Since neither the modulus nor the value R^2 are secret, we
+ * may elide approximately half of the total number of
+ * iterations by constructing the initial representation of
+ * R^2 as r=2^m and k=2n-m (for some m such that 2^m < N).
*/
- for ( msb = 0 ; ( msb || ( shift >= 0 ) ) ; shift-- ) {
- if ( msb ) {
- bigint_add ( modulus, value );
- } else {
- bigint_subtract ( modulus, value );
- }
- msb = bigint_msb_is_set ( value );
- if ( shift > 0 )
- bigint_shr ( modulus );
+
+ /* Initialise x=R^2 */
+ memset ( result, 0, sizeof ( *result ) );
+ max = ( bigint_max_set_bit ( modulus ) - 2 );
+ if ( max < 0 ) {
+ /* Degenerate case of N=0 or N=1: return a zero result */
+ return;
}
+ bigint_set_bit ( result, max );
+ shift = ( ( 2 * size * width ) - max );
+ sign = 0;
- skip:
- /* Sanity check */
- assert ( ! bigint_is_geq ( value, modulus ) );
+ /* Iterate as described above */
+ while ( shift-- ) {
- /* Stop profiling */
- profile_stop ( &bigint_mod_profiler );
-}
+ /* Calculate 2r, storing extra bit separately */
+ msb = bigint_shl ( result );
-/**
- * Reduce supremum of big integer representation
- *
- * @v modulus0 Element 0 of big integer modulus
- * @v result0 Element 0 of big integer to hold result
- * @v size Number of elements in modulus and value
- *
- * Reduce the value 2^k (where k is the bit width of the big integer
- * representation) modulo the specified modulus.
- */
-void bigint_reduce_supremum_raw ( bigint_element_t *modulus0,
- bigint_element_t *result0,
- unsigned int size ) {
- bigint_t ( size ) __attribute__ (( may_alias ))
- *modulus = ( ( void * ) modulus0 );
- bigint_t ( size ) __attribute__ (( may_alias ))
- *result = ( ( void * ) result0 );
+ /* Add or subtract N according to current sign */
+ if ( sign ) {
+ carry = bigint_add ( modulus, result );
+ } else {
+ carry = bigint_subtract ( modulus, result );
+ }
- /* Calculate (2^k) mod N via direct reduction of (2^k - N) mod N */
- memset ( result, 0, sizeof ( *result ) );
- bigint_subtract ( modulus, result );
- bigint_reduce ( modulus, result );
+ /* Calculate new sign of result
+ *
+ * We know the result lies in the range -N <= r < N
+ * and so the tuple (old sign, msb, carry) cannot ever
+ * take the values (0, 1, 0) or (1, 0, 0). We can
+ * therefore treat these as don't-care inputs, which
+ * allows us to simplify the boolean expression by
+ * ignoring the old sign completely.
+ */
+ assert ( ( sign == msb ) || carry );
+ sign = ( msb ^ carry );
+ }
+
+ /* Add N to make result positive if necessary */
+ if ( sign )
+ bigint_add ( modulus, result );
+
+ /* Sanity check */
+ assert ( ! bigint_is_geq ( result, modulus ) );
}
/**
@@ -805,12 +764,9 @@ void bigint_mod_exp_raw ( const bigint_element_t *base0,
( ( void * ) result0 );
const unsigned int width = ( 8 * sizeof ( bigint_element_t ) );
struct {
- union {
- bigint_t ( 2 * size ) padded_modulus;
- struct {
- bigint_t ( size ) modulus;
- bigint_t ( size ) stash;
- };
+ struct {
+ bigint_t ( size ) modulus;
+ bigint_t ( size ) stash;
};
union {
bigint_t ( 2 * size ) full;
@@ -833,7 +789,7 @@ void bigint_mod_exp_raw ( const bigint_element_t *base0,
}
/* Factor modulus as (N * 2^scale) where N is odd */
- bigint_grow ( modulus, &temp->padded_modulus );
+ bigint_copy ( modulus, &temp->modulus );
for ( scale = 0 ; ( ! bigint_bit_is_set ( &temp->modulus, 0 ) ) ;
scale++ ) {
bigint_shr ( &temp->modulus );
@@ -844,10 +800,10 @@ void bigint_mod_exp_raw ( const bigint_element_t *base0,
submask = ~submask;
/* Calculate (R^2 mod N) */
- bigint_reduce_supremum ( &temp->padded_modulus, &temp->product.full );
- bigint_copy ( &temp->product.low, &temp->stash );
+ bigint_reduce ( &temp->modulus, &temp->stash );
/* Initialise result = Montgomery(1, R^2 mod N) */
+ bigint_grow ( &temp->stash, &temp->product.full );
bigint_montgomery ( &temp->modulus, &temp->product.full, result );
/* Convert base into Montgomery form */