Bug 1642687 - land NSS 699541a7793b UPGRADE_NSS_RELEASE, r=jcj

2020-06-16  Sohaib ul Hassan  <sohaibulhassan@tuni.fi>

	* lib/freebl/mpi/mpi.c, lib/freebl/mpi/mpi.h,
	lib/freebl/mpi/mplogic.c:
	Bug 1631597 - Constant-time GCD and modular inversion
	r=rrelyea,kjacobs

	The implementation is based on the work by Bernstein and Yang
	(https://eprint.iacr.org/2019/266) "Fast constant-time gcd
	computation and modular inversion". It fixes the old mp_gcd and
	s_mp_invmod_odd_m functions.

	The patch also fix mpl_significant_bits s_mp_div_2d and s_mp_mul_2d
	by having less control flow to reduce side-channel leaks.

	Co Author : Billy Bob Brumley

	[699541a7793b] [tip]

Differential Revision: https://phabricator.services.mozilla.com/D80120
This commit is contained in:
Kevin Jacobs 2020-06-18 15:48:05 +00:00
Родитель f446e2ad19
Коммит bc02cf3e36
5 изменённых файлов: 294 добавлений и 135 удалений

Просмотреть файл

@ -1 +1 @@
6dcd00c13ffc
699541a7793b

Просмотреть файл

@ -10,4 +10,3 @@
*/
#error "Do not include this header file."

Просмотреть файл

@ -8,6 +8,7 @@
* file, You can obtain one at http://mozilla.org/MPL/2.0/. */
#include "mpi-priv.h"
#include "mplogic.h"
#if defined(OSF1)
#include <c_asm.h>
#endif
@ -1688,98 +1689,112 @@ mp_iseven(const mp_int *a)
/* {{{ mp_gcd(a, b, c) */
/*
Like the old mp_gcd() function, except computes the GCD using the
binary algorithm due to Josef Stein in 1961 (via Knuth).
Computes the GCD using the constant-time algorithm
by Bernstein and Yang (https://eprint.iacr.org/2019/266)
"Fast constant-time gcd computation and modular inversion"
*/
mp_err
mp_gcd(mp_int *a, mp_int *b, mp_int *c)
{
mp_err res;
mp_int u, v, t;
mp_size k = 0;
mp_digit cond = 0, mask = 0;
mp_int g, temp, f;
int i, j, m, bit = 1, delta = 1, shifts = 0, last = -1;
mp_size top, flen, glen;
mp_int *clear[3];
ARGCHK(a != NULL && b != NULL && c != NULL, MP_BADARG);
if (mp_cmp_z(a) == MP_EQ && mp_cmp_z(b) == MP_EQ)
return MP_RANGE;
/*
Early exit if either of the inputs is zero.
Caller is responsible for the proper handling of inputs.
*/
if (mp_cmp_z(a) == MP_EQ) {
return mp_copy(b, c);
} else if (mp_cmp_z(b) == MP_EQ) {
return mp_copy(a, c);
}
if ((res = mp_init(&t)) != MP_OKAY)
res = mp_copy(b, c);
SIGN(c) = ZPOS;
return res;
} else if (mp_cmp_z(b) == MP_EQ) {
res = mp_copy(a, c);
SIGN(c) = ZPOS;
return res;
if ((res = mp_init_copy(&u, a)) != MP_OKAY)
goto U;
if ((res = mp_init_copy(&v, b)) != MP_OKAY)
goto V;
SIGN(&u) = ZPOS;
SIGN(&v) = ZPOS;
/* Divide out common factors of 2 until at least 1 of a, b is even */
while (mp_iseven(&u) && mp_iseven(&v)) {
s_mp_div_2(&u);
s_mp_div_2(&v);
++k;
}
/* Initialize t */
if (mp_isodd(&u)) {
if ((res = mp_copy(&v, &t)) != MP_OKAY)
goto CLEANUP;
MP_CHECKOK(mp_init(&temp));
clear[++last] = &temp;
MP_CHECKOK(mp_init_copy(&g, a));
clear[++last] = &g;
MP_CHECKOK(mp_init_copy(&f, b));
clear[++last] = &f;
/* t = -v */
if (SIGN(&v) == ZPOS)
SIGN(&t) = NEG;
else
SIGN(&t) = ZPOS;
} else {
if ((res = mp_copy(&u, &t)) != MP_OKAY)
goto CLEANUP;
}
for (;;) {
while (mp_iseven(&t)) {
s_mp_div_2(&t);
/*
For even case compute the number of
shared powers of 2 in f and g.
*/
for (i = 0; i < USED(&f) && i < USED(&g); i++) {
mask = ~(DIGIT(&f, i) | DIGIT(&g, i));
for (j = 0; j < MP_DIGIT_BIT; j++) {
bit &= mask;
shifts += bit;
mask >>= 1;
}
}
/* Reduce to the odd case by removing the powers of 2. */
s_mp_div_2d(&f, shifts);
s_mp_div_2d(&g, shifts);
if (mp_cmp_z(&t) == MP_GT) {
if ((res = mp_copy(&t, &u)) != MP_OKAY)
goto CLEANUP;
/* Allocate to the size of largest mp_int. */
top = (mp_size)1 + ((USED(&f) >= USED(&g)) ? USED(&f) : USED(&g));
MP_CHECKOK(s_mp_grow(&f, top));
MP_CHECKOK(s_mp_grow(&g, top));
MP_CHECKOK(s_mp_grow(&temp, top));
} else {
if ((res = mp_copy(&t, &v)) != MP_OKAY)
goto CLEANUP;
/* Make sure f contains the odd value. */
MP_CHECKOK(mp_cswap((~DIGIT(&f, 0) & 1), &f, &g, top));
/* v = -t */
if (SIGN(&t) == ZPOS)
SIGN(&v) = NEG;
else
SIGN(&v) = ZPOS;
}
/* Upper bound for the total iterations. */
flen = mpl_significant_bits(&f);
glen = mpl_significant_bits(&g);
m = 4 + 3 * ((flen >= glen) ? flen : glen);
if ((res = mp_sub(&u, &v, &t)) != MP_OKAY)
goto CLEANUP;
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4146) // Thanks MSVC, we know what we're negating an unsigned mp_digit
#endif
if (s_mp_cmp_d(&t, 0) == MP_EQ)
break;
for (i = 0; i < m; i++) {
/* Step 1: conditional swap. */
/* Set cond if delta > 0 and g is odd. */
cond = (-delta >> (8 * sizeof(delta) - 1)) & DIGIT(&g, 0) & 1;
/* If cond is set replace (delta,f) with (-delta,-f). */
delta = (-cond & -delta) | ((cond - 1) & delta);
SIGN(&f) ^= cond;
/* If cond is set swap f with g. */
MP_CHECKOK(mp_cswap(cond, &f, &g, top));
/* Step 2: elemination. */
/* Update delta. */
delta++;
/* If g is odd, right shift (g+f) else right shift g. */
MP_CHECKOK(mp_add(&g, &f, &temp));
MP_CHECKOK(mp_cswap((DIGIT(&g, 0) & 1), &g, &temp, top));
s_mp_div_2(&g);
}
s_mp_2expt(&v, k); /* v = 2^k */
res = mp_mul(&u, &v, c); /* c = u * v */
#if defined(_MSC_VER)
#pragma warning(pop)
#endif
/* GCD is in f, take the absolute value. */
SIGN(&f) = ZPOS;
/* Add back the removed powers of 2. */
MP_CHECKOK(s_mp_mul_2d(&f, shifts));
MP_CHECKOK(mp_copy(&f, c));
CLEANUP:
mp_clear(&v);
V:
mp_clear(&u);
U:
mp_clear(&t);
while (last >= 0)
mp_clear(clear[last--]);
return res;
} /* end mp_gcd() */
/* }}} */
@ -2131,42 +2146,114 @@ CLEANUP:
return res;
}
/* compute mod inverse using Schroeppel's method, only if m is odd */
/*
Computes the modular inverse using the constant-time algorithm
by Bernstein and Yang (https://eprint.iacr.org/2019/266)
"Fast constant-time gcd computation and modular inversion"
*/
mp_err
s_mp_invmod_odd_m(const mp_int *a, const mp_int *m, mp_int *c)
{
int k;
mp_err res;
mp_int x;
mp_digit cond = 0;
mp_int g, f, v, r, temp;
int i, its, delta = 1, last = -1;
mp_size top, flen, glen;
mp_int *clear[6];
ARGCHK(a != NULL && m != NULL && c != NULL, MP_BADARG);
if (mp_cmp_z(a) == 0 || mp_cmp_z(m) == 0)
/* Check for invalid inputs. */
if (mp_cmp_z(a) == MP_EQ || mp_cmp_d(m, 2) == MP_LT)
return MP_RANGE;
if (mp_iseven(m))
if (a == m || mp_iseven(m))
return MP_UNDEF;
MP_DIGITS(&x) = 0;
MP_CHECKOK(mp_init(&temp));
clear[++last] = &temp;
MP_CHECKOK(mp_init(&v));
clear[++last] = &v;
MP_CHECKOK(mp_init(&r));
clear[++last] = &r;
MP_CHECKOK(mp_init_copy(&g, a));
clear[++last] = &g;
MP_CHECKOK(mp_init_copy(&f, m));
clear[++last] = &f;
if (a == c) {
if ((res = mp_init_copy(&x, a)) != MP_OKAY)
return res;
if (a == m)
m = &x;
a = &x;
} else if (m == c) {
if ((res = mp_init_copy(&x, m)) != MP_OKAY)
return res;
m = &x;
} else {
MP_DIGITS(&x) = 0;
mp_set(&v, 0);
mp_set(&r, 1);
/* Allocate to the size of largest mp_int. */
top = (mp_size)1 + ((USED(&f) >= USED(&g)) ? USED(&f) : USED(&g));
MP_CHECKOK(s_mp_grow(&f, top));
MP_CHECKOK(s_mp_grow(&g, top));
MP_CHECKOK(s_mp_grow(&temp, top));
MP_CHECKOK(s_mp_grow(&v, top));
MP_CHECKOK(s_mp_grow(&r, top));
/* Upper bound for the total iterations. */
flen = mpl_significant_bits(&f);
glen = mpl_significant_bits(&g);
its = 4 + 3 * ((flen >= glen) ? flen : glen);
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4146) // Thanks MSVC, we know what we're negating an unsigned mp_digit
#endif
for (i = 0; i < its; i++) {
/* Step 1: conditional swap. */
/* Set cond if delta > 0 and g is odd. */
cond = (-delta >> (8 * sizeof(delta) - 1)) & DIGIT(&g, 0) & 1;
/* If cond is set replace (delta,f,v) with (-delta,-f,-v). */
delta = (-cond & -delta) | ((cond - 1) & delta);
SIGN(&f) ^= cond;
SIGN(&v) ^= cond;
/* If cond is set swap (f,v) with (g,r). */
MP_CHECKOK(mp_cswap(cond, &f, &g, top));
MP_CHECKOK(mp_cswap(cond, &v, &r, top));
/* Step 2: elemination. */
/* Update delta */
delta++;
/* If g is odd replace r with (r+v). */
MP_CHECKOK(mp_add(&r, &v, &temp));
MP_CHECKOK(mp_cswap((DIGIT(&g, 0) & 1), &r, &temp, top));
/* If g is odd, right shift (g+f) else right shift g. */
MP_CHECKOK(mp_add(&g, &f, &temp));
MP_CHECKOK(mp_cswap((DIGIT(&g, 0) & 1), &g, &temp, top));
s_mp_div_2(&g);
/*
If r is even, right shift it.
If r is odd, right shift (r+m) which is even because m is odd.
We want the result modulo m so adding in multiples of m here vanish.
*/
MP_CHECKOK(mp_add(&r, m, &temp));
MP_CHECKOK(mp_cswap((DIGIT(&r, 0) & 1), &r, &temp, top));
s_mp_div_2(&r);
}
MP_CHECKOK(s_mp_almost_inverse(a, m, c));
k = res;
MP_CHECKOK(s_mp_fixup_reciprocal(c, m, k, c));
#if defined(_MSC_VER)
#pragma warning(pop)
#endif
/* We have the inverse in v, propagate sign from f. */
SIGN(&v) ^= SIGN(&f);
/* GCD is in f, take the absolute value. */
SIGN(&f) = ZPOS;
/* If gcd != 1, not invertible. */
if (mp_cmp_d(&f, 1) != MP_EQ) {
res = MP_UNDEF;
goto CLEANUP;
}
/* Return inverse modulo m. */
MP_CHECKOK(mp_mod(&v, m, c));
CLEANUP:
mp_clear(&x);
while (last >= 0)
mp_clear(clear[last--]);
return res;
}
@ -2218,13 +2305,24 @@ s_mp_invmod_2d(const mp_int *a, mp_size k, mp_int *c)
if (mp_iseven(a))
return MP_UNDEF;
#if defined(_MSC_VER)
#pragma warning(push)
#pragma warning(disable : 4146) // Thanks MSVC, we know what we're negating an unsigned mp_digit
#endif
if (k <= MP_DIGIT_BIT) {
mp_digit i = s_mp_invmod_radix(MP_DIGIT(a, 0));
/* propagate the sign from mp_int */
i = (i ^ -(mp_digit)SIGN(a)) + (mp_digit)SIGN(a);
if (k < MP_DIGIT_BIT)
i &= ((mp_digit)1 << k) - (mp_digit)1;
mp_set(c, i);
return MP_OKAY;
}
#if defined(_MSC_VER)
#pragma warning(pop)
#endif
MP_DIGITS(&t0) = 0;
MP_DIGITS(&t1) = 0;
MP_DIGITS(&val) = 0;
@ -2831,6 +2929,8 @@ s_mp_clamp(mp_int *mp)
while (used > 1 && DIGIT(mp, used - 1) == 0)
--used;
MP_USED(mp) = used;
if (used == 1 && DIGIT(mp, 0) == 0)
MP_SIGN(mp) = ZPOS;
} /* end s_mp_clamp() */
/* }}} */
@ -2908,37 +3008,36 @@ mp_err
s_mp_mul_2d(mp_int *mp, mp_digit d)
{
mp_err res;
mp_digit dshift, bshift;
mp_digit mask;
mp_digit dshift, rshift, mask, x, prev = 0;
mp_digit *pa = NULL;
int i;
ARGCHK(mp != NULL, MP_BADARG);
dshift = d / MP_DIGIT_BIT;
bshift = d % MP_DIGIT_BIT;
d %= MP_DIGIT_BIT;
/* mp_digit >> rshift is undefined behavior for rshift >= MP_DIGIT_BIT */
/* mod and corresponding mask logic avoid that when d = 0 */
rshift = MP_DIGIT_BIT - d;
rshift %= MP_DIGIT_BIT;
/* mask = (2**d - 1) * 2**(w-d) mod 2**w */
mask = (DIGIT_MAX << rshift) + 1;
mask &= DIGIT_MAX - 1;
/* bits to be shifted out of the top word */
if (bshift) {
mask = (mp_digit)~0 << (MP_DIGIT_BIT - bshift);
mask &= MP_DIGIT(mp, MP_USED(mp) - 1);
} else {
mask = 0;
}
x = MP_DIGIT(mp, MP_USED(mp) - 1) & mask;
if (MP_OKAY != (res = s_mp_pad(mp, MP_USED(mp) + dshift + (mask != 0))))
if (MP_OKAY != (res = s_mp_pad(mp, MP_USED(mp) + dshift + (x != 0))))
return res;
if (dshift && MP_OKAY != (res = s_mp_lshd(mp, dshift)))
return res;
if (bshift) {
mp_digit *pa = MP_DIGITS(mp);
mp_digit *alim = pa + MP_USED(mp);
mp_digit prev = 0;
pa = MP_DIGITS(mp) + dshift;
for (pa += dshift; pa < alim;) {
mp_digit x = *pa;
*pa++ = (x << bshift) | prev;
prev = x >> (DIGIT_BIT - bshift);
}
for (i = MP_USED(mp) - dshift; i > 0; i--) {
x = *pa;
*pa++ = (x << d) | prev;
prev = (x & mask) >> rshift;
}
s_mp_clamp(mp);
@ -3077,18 +3176,20 @@ void
s_mp_div_2d(mp_int *mp, mp_digit d)
{
int ix;
mp_digit save, next, mask;
mp_digit save, next, mask, lshift;
s_mp_rshd(mp, d / DIGIT_BIT);
d %= DIGIT_BIT;
if (d) {
mask = ((mp_digit)1 << d) - 1;
save = 0;
for (ix = USED(mp) - 1; ix >= 0; ix--) {
next = DIGIT(mp, ix) & mask;
DIGIT(mp, ix) = (DIGIT(mp, ix) >> d) | (save << (DIGIT_BIT - d));
save = next;
}
/* mp_digit << lshift is undefined behavior for lshift >= MP_DIGIT_BIT */
/* mod and corresponding mask logic avoid that when d = 0 */
lshift = DIGIT_BIT - d;
lshift %= DIGIT_BIT;
mask = ((mp_digit)1 << d) - 1;
save = 0;
for (ix = USED(mp) - 1; ix >= 0; ix--) {
next = DIGIT(mp, ix) & mask;
DIGIT(mp, ix) = (save << lshift) | (DIGIT(mp, ix) >> d);
save = next;
}
s_mp_clamp(mp);
@ -4841,5 +4942,44 @@ mp_to_fixlen_octets(const mp_int *mp, unsigned char *str, mp_size length)
} /* end mp_to_fixlen_octets() */
/* }}} */
/* {{{ mp_cswap(condition, a, b, numdigits) */
/* performs a conditional swap between mp_int. */
mp_err
mp_cswap(mp_digit condition, mp_int *a, mp_int *b, mp_size numdigits)
{
mp_digit x;
unsigned int i;
mp_err res = 0;
/* if pointers are equal return */
if (a == b)
return res;
if (MP_ALLOC(a) < numdigits || MP_ALLOC(b) < numdigits) {
MP_CHECKOK(s_mp_grow(a, numdigits));
MP_CHECKOK(s_mp_grow(b, numdigits));
}
condition = ((~condition & ((condition - 1))) >> (MP_DIGIT_BIT - 1)) - 1;
x = (USED(a) ^ USED(b)) & condition;
USED(a) ^= x;
USED(b) ^= x;
x = (SIGN(a) ^ SIGN(b)) & condition;
SIGN(a) ^= x;
SIGN(b) ^= x;
for (i = 0; i < numdigits; i++) {
x = (DIGIT(a, i) ^ DIGIT(b, i)) & condition;
DIGIT(a, i) ^= x;
DIGIT(b, i) ^= x;
}
CLEANUP:
return res;
} /* end mp_cswap() */
/* }}} */
/*------------------------------------------------------------------------*/
/* HERE THERE BE DRAGONS */

Просмотреть файл

@ -267,6 +267,7 @@ mp_size mp_trailing_zeros(const mp_int *mp);
void freebl_cpuid(unsigned long op, unsigned long *eax,
unsigned long *ebx, unsigned long *ecx,
unsigned long *edx);
mp_err mp_cswap(mp_digit condition, mp_int *a, mp_int *b, mp_size numdigits);
#define MP_CHECKOK(x) \
if (MP_OKAY > (res = (x))) \

Просмотреть файл

@ -407,35 +407,54 @@ mpl_get_bits(const mp_int *a, mp_size lsbNum, mp_size numBits)
return (mp_err)mask;
}
#define LZCNTLOOP(i) \
do { \
x = d >> (i); \
mask = (0 - x); \
mask = (0 - (mask >> (MP_DIGIT_BIT - 1))); \
bits += (i)&mask; \
d ^= (x ^ d) & mask; \
} while (0)
/*
mpl_significant_bits
returns number of significnant bits in abs(a).
returns number of significant bits in abs(a).
In other words: floor(lg(abs(a))) + 1.
returns 1 if value is zero.
*/
mp_size
mpl_significant_bits(const mp_int *a)
{
mp_size bits = 0;
/*
start bits at 1.
lg(0) = 0 => bits = 1 by function semantics.
below does a binary search for the _position_ of the top bit set,
which is floor(lg(abs(a))) for a != 0.
*/
mp_size bits = 1;
int ix;
ARGCHK(a != NULL, MP_BADARG);
for (ix = MP_USED(a); ix > 0;) {
mp_digit d;
d = MP_DIGIT(a, --ix);
if (d) {
while (d) {
++bits;
d >>= 1;
}
break;
}
mp_digit d, x, mask;
if ((d = MP_DIGIT(a, --ix)) == 0)
continue;
#if !defined(MP_USE_UINT_DIGIT)
LZCNTLOOP(32);
#endif
LZCNTLOOP(16);
LZCNTLOOP(8);
LZCNTLOOP(4);
LZCNTLOOP(2);
LZCNTLOOP(1);
break;
}
bits += ix * MP_DIGIT_BIT;
if (!bits)
bits = 1;
return bits;
}
#undef LZCNTLOOP
/*------------------------------------------------------------------------*/
/* HERE THERE BE DRAGONS */