diff --git a/mpint.c b/mpint.c index 4c9fd86e..774d744e 100644 --- a/mpint.c +++ b/mpint.c @@ -36,6 +36,7 @@ static inline BignumInt mp_word(mp_int *x, size_t i) static mp_int *mp_make_sized(size_t nw) { mp_int *x = snew_plus(mp_int, nw * sizeof(BignumInt)); + assert(nw); /* we outlaw the zero-word mp_int */ x->nw = nw; x->w = snew_plus_get_aux(x); mp_clear(x); @@ -140,8 +141,9 @@ void mp_cond_clear(mp_int *x, unsigned clear) */ static mp_int *mp_from_bytes_int(ptrlen bytes, size_t m, size_t c) { - mp_int *n = mp_make_sized( - (bytes.len + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES); + size_t nw = (bytes.len + BIGNUM_INT_BYTES - 1) / BIGNUM_INT_BYTES; + nw = size_t_max(nw, 1); + mp_int *n = mp_make_sized(nw); for (size_t i = 0; i < bytes.len; i++) n->w[i / BIGNUM_INT_BYTES] |= (BignumInt)(((const unsigned char *)bytes.ptr)[m*i+c]) << @@ -211,6 +213,7 @@ mp_int *mp_from_hex_pl(ptrlen hex) assert(hex.len <= (~(size_t)0) / 4); size_t bits = hex.len * 4; size_t words = (bits + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS; + words = size_t_max(words, 1); mp_int *x = mp_make_sized(words); for (size_t nibble = 0; nibble < hex.len; nibble++) { BignumInt digit = ((char *)hex.ptr)[hex.len-1 - nibble]; @@ -1077,7 +1080,8 @@ void mp_rshift_fixed_into(mp_int *r, mp_int *a, size_t bits) mp_int *mp_rshift_fixed(mp_int *x, size_t bits) { size_t words = bits / BIGNUM_INT_BITS; - mp_int *r = mp_make_sized(x->nw - size_t_min(x->nw, words)); + size_t nw = x->nw - size_t_min(x->nw, words); + mp_int *r = mp_make_sized(size_t_max(nw, 1)); mp_rshift_fixed_into(r, x, bits); return r; } @@ -1148,6 +1152,7 @@ mp_int *mp_invert_mod_2to(mp_int *x, size_t p) assert(p > 0); size_t rw = (p + BIGNUM_INT_BITS - 1) / BIGNUM_INT_BITS; + rw = size_t_max(rw, 1); mp_int *r = mp_make_sized(rw); size_t mul_scratchsize = mp_mul_scratchspace(2*rw, rw, rw); diff --git a/test/cryptsuite.py b/test/cryptsuite.py index 317137f2..8082f56c 100755 --- a/test/cryptsuite.py +++ b/test/cryptsuite.py @@ -159,6 +159,7 @@ class mpint(MyTestBase): hexstr = 'ea7cb89f409ae845215822e37D32D0C63EC43E1381C2FF8094' self.assertEqual(int(mp_from_hex_pl(hexstr)), int(hexstr, 16)) self.assertEqual(int(mp_from_hex(hexstr)), int(hexstr, 16)) + self.assertEqual(int(mp_from_hex("")), 0) p2 = mp_power_2(123) self.assertEqual(int(p2), 1 << 123) p2c = mp_copy(p2) @@ -319,7 +320,7 @@ class mpint(MyTestBase): diff = mp_sub(am, bm) self.assertEqual(int(diff), (ai - bi) & mp_mask(diff)) - for bits in range(0, 512, 64): + for bits in range(64, 512, 64): cm = mp_new(bits) mp_add_into(cm, am, bm) self.assertEqual(int(cm), (ai + bi) & mp_mask(cm)) @@ -357,8 +358,8 @@ class mpint(MyTestBase): if r >= d: continue # silly cases with tiny divisors n = q*d + r - mq = mp_new(nbits(q)) - mr = mp_new(nbits(r)) + mq = mp_new(max(nbits(q), 1)) + mr = mp_new(max(nbits(r), 1)) mp_divmod_into(n, d, mq, mr) self.assertEqual(int(mq), q) self.assertEqual(int(mr), r)