From e01104f8998163ddb3f893e8c5d2d1691989e36e Mon Sep 17 00:00:00 2001 From: Simon Tatham Date: Fri, 2 Aug 2013 06:27:54 +0000 Subject: [PATCH] Fix an array-size bug in modmul, and add some tests for it. [originally from svn r9977] --- sshbn.c | 45 +++++++++++++++++++++++++++++++++++++++++++++ testdata/bignum.py | 10 ++++++++++ 2 files changed, 55 insertions(+) diff --git a/sshbn.c b/sshbn.c index 24f3ca6b..677b121c 100644 --- a/sshbn.c +++ b/sshbn.c @@ -1018,6 +1018,13 @@ Bignum modmul(Bignum p, Bignum q, Bignum mod) pqlen = (p[0] > q[0] ? p[0] : q[0]); + /* + * Make sure that we're allowing enough space. The shifting below + * will underflow the vectors we allocate if pqlen is too small. + */ + if (2*pqlen <= mlen) + pqlen = mlen/2 + 1; + /* Allocate n of size pqlen, copy p to n */ n = snewn(pqlen, BignumInt); i = pqlen - p[0]; @@ -1864,6 +1871,44 @@ int main(int argc, char **argv) freebn(b); freebn(c); freebn(p); + } else if (!strcmp(buf, "modmul")) { + Bignum a, b, m, c, p; + + if (ptrnum != 4) { + printf("%d: modmul with %d parameters, expected 4\n", + line, ptrnum); + exit(1); + } + a = bignum_from_bytes(ptrs[0], ptrs[1]-ptrs[0]); + b = bignum_from_bytes(ptrs[1], ptrs[2]-ptrs[1]); + m = bignum_from_bytes(ptrs[2], ptrs[3]-ptrs[2]); + c = bignum_from_bytes(ptrs[3], ptrs[4]-ptrs[3]); + p = modmul(a, b, m); + + if (bignum_cmp(c, p) == 0) { + passes++; + } else { + char *as = bignum_decimal(a); + char *bs = bignum_decimal(b); + char *ms = bignum_decimal(m); + char *cs = bignum_decimal(c); + char *ps = bignum_decimal(p); + + printf("%d: fail: %s * %s mod %s gave %s expected %s\n", + line, as, bs, ms, ps, cs); + fails++; + + sfree(as); + sfree(bs); + sfree(ms); + sfree(cs); + sfree(ps); + } + freebn(a); + freebn(b); + freebn(m); + freebn(c); + freebn(p); } else if (!strcmp(buf, "pow")) { Bignum base, expt, modulus, expected, answer; diff --git a/testdata/bignum.py b/testdata/bignum.py index 05ca4528..b2a6614b 100644 --- a/testdata/bignum.py +++ b/testdata/bignum.py @@ -103,6 +103,15 @@ for i in range(1,4200): a, b, p = findprod((1<