* bignum.c (rb_big_mul): faster multiplication by Karatsuba method and

twice faster square than normal multiplication.

* random.c (rb_rand_internal): used by Bignum#*.

* test/ruby/test_bignum.rb: add some tests for above.


git-svn-id: svn+ssh://ci.ruby-lang.org/ruby/trunk@20733 b2dd03c8-39d4-4d8f-98ff-823fe69b080e
This commit is contained in:
mame 2008-12-14 03:59:02 +00:00
Родитель 529ad093d4
Коммит 19f45f853c
4 изменённых файлов: 377 добавлений и 79 удалений

Просмотреть файл

@ -1,3 +1,12 @@
$BF|(B 12$B7n(B 14 12:51:48 2008 Yusuke Endoh <mame@tsg.ne.jp>
* bignum.c (rb_big_mul): faster multiplication by Karatsuba method and
twice faster square than normal multiplication.
* random.c (rb_rand_internal): used by Bignum#*.
* test/ruby/test_bignum.rb: add some tests for above.
Sun Dec 14 09:14:37 2008 Yuki Sonoda (Yugui) <yugui@yugui.jp>
* reverts r20713.

417
bignum.c
Просмотреть файл

@ -17,6 +17,7 @@
#ifdef HAVE_IEEEFP_H
#include <ieeefp.h>
#endif
#include <assert.h>
VALUE rb_cBignum;
@ -1380,12 +1381,36 @@ rb_big_neg(VALUE x)
return bignorm(z);
}
static void
bigsub_core(BDIGIT *xds, long xn, BDIGIT *yds, long yn, BDIGIT *zds, long zn)
{
BDIGIT_DBL_SIGNED num;
long i;
for (i = 0, num = 0; i < yn; i++) {
num += (BDIGIT_DBL_SIGNED)xds[i] - yds[i];
zds[i] = BIGLO(num);
num = BIGDN(num);
}
while (num && i < xn) {
num += xds[i];
zds[i++] = BIGLO(num);
num = BIGDN(num);
}
while (i < xn) {
zds[i] = xds[i];
i++;
}
assert(i <= zn);
while (i < zn) {
zds[i++] = 0;
}
}
static VALUE
bigsub(VALUE x, VALUE y)
{
VALUE z = 0;
BDIGIT *zds;
BDIGIT_DBL_SIGNED num;
long i = RBIGNUM_LEN(x);
/* if x is larger than y, swap */
@ -1406,32 +1431,52 @@ bigsub(VALUE x, VALUE y)
}
z = bignew(RBIGNUM_LEN(x), z==0);
zds = BDIGITS(z);
bigsub_core(BDIGITS(x), RBIGNUM_LEN(x),
BDIGITS(y), RBIGNUM_LEN(y),
BDIGITS(z), RBIGNUM_LEN(z));
for (i = 0, num = 0; i < RBIGNUM_LEN(y); i++) {
num += (BDIGIT_DBL_SIGNED)BDIGITS(x)[i] - BDIGITS(y)[i];
zds[i] = BIGLO(num);
num = BIGDN(num);
return z;
}
static void
bigadd_core(BDIGIT *xds, long xn, BDIGIT *yds, long yn, BDIGIT *zds, long zn)
{
BDIGIT_DBL num = 0;
long i;
if (xn > yn) {
BDIGIT *tds;
tds = xds; xds = yds; yds = tds;
i = xn; xn = yn; yn = i;
}
while (num && i < RBIGNUM_LEN(x)) {
num += BDIGITS(x)[i];
i = 0;
while (i < xn) {
num += (BDIGIT_DBL)xds[i] + yds[i];
zds[i++] = BIGLO(num);
num = BIGDN(num);
}
while (i < RBIGNUM_LEN(x)) {
zds[i] = BDIGITS(x)[i];
while (num && i < yn) {
num += yds[i];
zds[i++] = BIGLO(num);
num = BIGDN(num);
}
while (i < yn) {
zds[i] = yds[i];
i++;
}
return z;
if (num) zds[i++] = (BDIGIT)num;
assert(i <= zn);
while (i < zn) {
zds[i++] = 0;
}
}
static VALUE
bigadd(VALUE x, VALUE y, int sign)
{
VALUE z;
BDIGIT_DBL num;
long i, len;
long len;
sign = (sign == RBIGNUM_SIGN(y));
if (RBIGNUM_SIGN(x) != sign) {
@ -1441,30 +1486,15 @@ bigadd(VALUE x, VALUE y, int sign)
if (RBIGNUM_LEN(x) > RBIGNUM_LEN(y)) {
len = RBIGNUM_LEN(x) + 1;
z = x; x = y; y = z;
}
else {
len = RBIGNUM_LEN(y) + 1;
}
z = bignew(len, sign);
len = RBIGNUM_LEN(x);
for (i = 0, num = 0; i < len; i++) {
num += (BDIGIT_DBL)BDIGITS(x)[i] + BDIGITS(y)[i];
BDIGITS(z)[i] = BIGLO(num);
num = BIGDN(num);
}
len = RBIGNUM_LEN(y);
while (num && i < len) {
num += BDIGITS(y)[i];
BDIGITS(z)[i++] = BIGLO(num);
num = BIGDN(num);
}
while (i < len) {
BDIGITS(z)[i] = BDIGITS(y)[i];
i++;
}
BDIGITS(z)[i] = (BDIGIT)num;
bigadd_core(BDIGITS(x), RBIGNUM_LEN(x),
BDIGITS(y), RBIGNUM_LEN(y),
BDIGITS(z), RBIGNUM_LEN(z));
return z;
}
@ -1519,24 +1549,20 @@ rb_big_minus(VALUE x, VALUE y)
}
}
static void
rb_big_stop(void *ptr)
static long
big_real_len(VALUE x)
{
VALUE *stop = (VALUE*)ptr;
*stop = Qtrue;
long i = RBIGNUM_LEN(x);
while (--i && !BDIGITS(x)[i]);
return i + 1;
}
struct big_mul_struct {
VALUE x, y, z, stop;
};
static VALUE
bigmul1(void *ptr)
bigmul1_normal(VALUE x, VALUE y)
{
struct big_mul_struct *bms = (struct big_mul_struct*)ptr;
long i, j;
BDIGIT_DBL n = 0;
VALUE x = bms->x, y = bms->y, z = bms->z;
VALUE z = bignew(RBIGNUM_LEN(x) + RBIGNUM_LEN(y) + 1, RBIGNUM_SIGN(x)==RBIGNUM_SIGN(y));
BDIGIT *zds;
j = RBIGNUM_LEN(x) + RBIGNUM_LEN(y) + 1;
@ -1544,7 +1570,6 @@ bigmul1(void *ptr)
while (j--) zds[j] = 0;
for (i = 0; i < RBIGNUM_LEN(x); i++) {
BDIGIT_DBL dd;
if (bms->stop) return Qnil;
dd = BDIGITS(x)[i];
if (dd == 0) continue;
n = 0;
@ -1558,15 +1583,267 @@ bigmul1(void *ptr)
zds[i + j] = n;
}
}
rb_thread_check_ints();
return z;
}
static VALUE
rb_big_mul0(VALUE x, VALUE y)
{
struct big_mul_struct bms;
volatile VALUE z;
static VALUE bigmul0(VALUE x, VALUE y);
/* balancing multiplication by slicing larger argument */
static VALUE
bigmul1_balance(VALUE x, VALUE y)
{
VALUE z, t1, t2;
long i, xn, yn, r, n;
xn = RBIGNUM_LEN(x);
yn = RBIGNUM_LEN(y);
assert(2 * xn <= yn);
z = bignew(xn + yn, RBIGNUM_SIGN(x)==RBIGNUM_SIGN(y));
t1 = bignew(xn, 1);
for (i = 0; i < xn + yn; i++) BDIGITS(z)[i] = 0;
n = 0;
while (yn > 0) {
r = xn > yn ? yn : xn;
MEMCPY(BDIGITS(t1), BDIGITS(y) + n, BDIGIT, r);
RBIGNUM_SET_LEN(t1, r);
t2 = bigmul0(x, t1);
bigadd_core(BDIGITS(z) + n, RBIGNUM_LEN(z) - n,
BDIGITS(t2), big_real_len(t2),
BDIGITS(z) + n, RBIGNUM_LEN(z) - n);
yn -= r;
n += r;
}
rb_gc_force_recycle(t1);
return z;
}
/* split a bignum into high and low bignums */
static void
big_split(VALUE v, long n, VALUE *ph, VALUE *pl)
{
long hn, ln;
VALUE h, l;
ln = RBIGNUM_LEN(v) > n ? n : RBIGNUM_LEN(v);
hn = RBIGNUM_LEN(v) - ln;
while (--hn && !BDIGITS(v)[hn + ln]);
h = bignew(++hn, 1);
MEMCPY(BDIGITS(h), BDIGITS(v) + ln, BDIGIT, hn);
while (--ln && !BDIGITS(v)[ln]);
l = bignew(++ln, 1);
MEMCPY(BDIGITS(l), BDIGITS(v), BDIGIT, ln);
*pl = l;
*ph = h;
}
/* multiplication by karatsuba method */
static VALUE
bigmul1_karatsuba(VALUE x, VALUE y)
{
long i, n, xn, yn, t1n, t2n;
VALUE xh, xl, yh, yl, z, t1, t2, t3;
BDIGIT *zds;
xn = RBIGNUM_LEN(x);
yn = RBIGNUM_LEN(y);
n = yn / 2;
big_split(x, n, &xh, &xl);
if (x == y) {
yh = xh; yl = xl;
}
else big_split(y, n, &yh, &yl);
/* x = xh * b + xl
* y = yh * b + yl
*
* Karatsuba method:
* x * y = z2 * b^2 + z1 * b + z0
* where
* z2 = xh * yh
* z0 = xl * yl
* z1 = (xh + xl) * (yh + yl) - x2 - x0
*
* ref: http://en.wikipedia.org/wiki/Karatsuba_algorithm
*/
/* allocate a result bignum */
z = bignew(xn + yn, RBIGNUM_SIGN(x)==RBIGNUM_SIGN(y));
zds = BDIGITS(z);
/* t1 <- xh * yh */
t1 = bigmul0(xh, yh);
t1n = big_real_len(t1);
/* copy t1 into high bytes of the result (z2) */
MEMCPY(zds + 2 * n, BDIGITS(t1), BDIGIT, t1n);
for (i = 2 * n + t1n; i < xn + yn; i++) BDIGITS(z)[i] = 0;
if (!BIGZEROP(xl) && !BIGZEROP(yl)) {
/* t2 <- xl * yl */
t2 = bigmul0(xl, yl);
t2n = big_real_len(t2);
/* copy t2 into low bytes of the result (z0) */
MEMCPY(zds, BDIGITS(t2), BDIGIT, t2n);
for (i = t2n; i < 2 * n; i++) BDIGITS(z)[i] = 0;
/* subtract t2 from middle bytes of the result (z1) */
i = xn + yn - n;
bigsub_core(zds + n, i, BDIGITS(t2), t2n, zds + n, i);
rb_gc_force_recycle(t2);
}
else {
/* copy 0 into low bytes of the result (z0) */
for (i = 0; i < 2 * n; i++) BDIGITS(z)[i] = 0;
}
/* subtract t1 from middle bytes of the result (z1) */
i = xn + yn - n;
bigsub_core(zds + n, i, BDIGITS(t1), t1n, zds + n, i);
rb_gc_force_recycle(t1);
/* t1 <- xh + xl */
t1 = bigadd(xh, xl, 1);
if (xh != yh) rb_gc_force_recycle(xh);
if (xl != yl) rb_gc_force_recycle(xl);
/* t2 <- yh + yl */
t2 = (x == y) ? t1 : bigadd(yh, yl, 1);
rb_gc_force_recycle(yh);
rb_gc_force_recycle(yl);
/* t3 <- t1 * t2 */
t3 = bigmul0(t1, t2);
rb_gc_force_recycle(t1);
if (t1 != t2) rb_gc_force_recycle(t2);
/* add t3 to middle bytes of the result (z1) */
bigadd_core(zds + n, i, BDIGITS(t3), big_real_len(t3), zds + n, i);
rb_gc_force_recycle(t3);
return z;
}
/* efficient squaring (2 times faster than normal multiplication)
* ref: Handbook of Applied Cryptography, Algorithm 14.16
* http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf
*/
static VALUE
bigsqr_fast(VALUE x)
{
long len = RBIGNUM_LEN(x), i, j;
VALUE z = bignew(2 * len + 1, 1);
BDIGIT *xds = BDIGITS(x), *zds = BDIGITS(z);
BDIGIT_DBL c, v, w;
for (i = 2 * len + 1; i--; ) zds[i] = 0;
for (i = 0; i < len; i++) {
v = (BDIGIT_DBL)xds[i];
if (!v) continue;
c = (BDIGIT_DBL)zds[i + i] + v * v;
zds[i + i] = BIGLO(c);
c = BIGDN(c);
v *= 2;
for (j = i + 1; j < len; j++) {
w = (BDIGIT_DBL)xds[j];
c += (BDIGIT_DBL)zds[i + j] + BIGLO(v) * w;
zds[i + j] = BIGLO(c);
c = BIGDN(c);
if (BIGDN(v)) c += w;
}
if (c) {
c += (BDIGIT_DBL)zds[i + len];
zds[i + len] = BIGLO(c);
c = BIGDN(c);
}
if (c) zds[i + len + 1] += c;
}
return z;
}
#define KARATSUBA_MUL_DIGITS 70
/* determine whether a bignum is sparse or not by random sampling */
static inline VALUE
big_sparse_p(VALUE x)
{
long c = 0, n = RBIGNUM_LEN(x);
unsigned long rb_rand_internal(unsigned long i);
if ( BDIGITS(x)[rb_rand_internal(n / 2) + n / 4]) c++;
if (c <= 1 && BDIGITS(x)[rb_rand_internal(n / 2) + n / 4]) c++;
if (c <= 1 && BDIGITS(x)[rb_rand_internal(n / 2) + n / 4]) c++;
return (c <= 1) ? Qtrue : Qfalse;
}
#if 0
static void
dump_bignum(VALUE x)
{
long i;
printf("0x0");
for (i = RBIGNUM_LEN(x); i--; ) {
printf("_%08x", BDIGITS(x)[i]);
}
puts("");
}
#endif
static VALUE
bigmul0(VALUE x, VALUE y)
{
long xn, yn;
xn = RBIGNUM_LEN(x);
yn = RBIGNUM_LEN(y);
/* make sure that y is longer than x */
if (xn > yn) {
VALUE t;
long tn;
t = x; x = y; y = t;
tn = xn; xn = yn; yn = tn;
}
assert(xn <= yn);
/* normal multiplication when x is small */
if (xn < KARATSUBA_MUL_DIGITS) {
normal:
if (x == y) return bigsqr_fast(x);
return bigmul1_normal(x, y);
}
/* normal multiplication when x or y is a sparse bignum */
if (big_sparse_p(x)) goto normal;
if (big_sparse_p(y)) return bigmul1_normal(y, x);
/* balance multiplication by slicing y when x is much smaller than y */
if (2 * xn <= yn) return bigmul1_balance(x, y);
/* multiplication by karatsuba method */
return bigmul1_karatsuba(x, y);
}
/*
* call-seq:
* big * other => Numeric
*
* Multiplies big and other, returning the result.
*/
VALUE
rb_big_mul(VALUE x, VALUE y)
{
switch (TYPE(y)) {
case T_FIXNUM:
y = rb_int2big(FIX2LONG(y));
@ -1582,32 +1859,7 @@ rb_big_mul0(VALUE x, VALUE y)
return rb_num_coerce_bin(x, y, '*');
}
bms.x = x;
bms.y = y;
bms.z = bignew(RBIGNUM_LEN(x) + RBIGNUM_LEN(y) + 1, RBIGNUM_SIGN(x)==RBIGNUM_SIGN(y));
bms.stop = Qfalse;
if (RBIGNUM_LEN(x) + RBIGNUM_LEN(y) > 10000) {
z = rb_thread_blocking_region(bigmul1, &bms, rb_big_stop, &bms.stop);
}
else {
z = bigmul1(&bms);
}
return z;
}
/*
* call-seq:
* big * other => Numeric
*
* Multiplies big and other, returning the result.
*/
VALUE
rb_big_mul(VALUE x, VALUE y)
{
return bignorm(rb_big_mul0(x, y));
return bignorm(bigmul0(x, y));
}
struct big_div_struct {
@ -1661,6 +1913,13 @@ bigdivrem1(void *ptr)
return Qnil;
}
static void
rb_big_stop(void *ptr)
{
VALUE *stop = (VALUE*)ptr;
*stop = Qtrue;
}
static VALUE
bigdivrem(VALUE x, VALUE y, VALUE *divp, VALUE *modp)
{
@ -2037,7 +2296,7 @@ bigsqr(VALUE x)
BDIGIT_DBL num;
if (len < 4000 / BITSPERDIG) {
return bigtrunc(rb_big_mul0(x, x));
return bigtrunc(bigmul0(x, x));
}
a = bignew(len - k, 1);
@ -2054,7 +2313,7 @@ bigsqr(VALUE x)
}
MEMCPY(BDIGITS(z) + 2 * k, BDIGITS(a2), BDIGIT, RBIGNUM_LEN(a2));
RBIGNUM_SET_LEN(z, len);
a2 = bigtrunc(rb_big_mul0(a, b));
a2 = bigtrunc(bigmul0(a, b));
len = RBIGNUM_LEN(a2);
for (i = 0, num = 0; i < len; i++) {
num += (BDIGIT_DBL)BDIGITS(z)[i + k] + ((BDIGIT_DBL)BDIGITS(a2)[i] << 1);
@ -2125,7 +2384,7 @@ rb_big_pow(VALUE x, VALUE y)
for (mask = FIXNUM_MAX + 1; mask; mask >>= 1) {
if (z) z = bigtrunc(bigsqr(z));
if (yy & mask) {
z = z ? bigtrunc(rb_big_mul0(z, x)) : x;
z = z ? bigtrunc(bigmul0(z, x)) : x;
}
}
return bignorm(z);

Просмотреть файл

@ -452,6 +452,16 @@ limited_big_rand(struct MT *mt, struct RBignum *limit)
return rb_big_norm((VALUE)val);
}
unsigned long
rb_rand_internal(unsigned long i)
{
struct MT *mt = &default_mt.mt;
if (!genrand_initialized(mt)) {
rand_init(mt, random_seed());
}
return limited_rand(mt, i);
}
/*
* call-seq:
* rand(max=0) => number

Просмотреть файл

@ -200,11 +200,24 @@ class TestBignum < Test::Unit::TestCase
def test_sub
assert_equal(-T31, T32 - (T32 + T31))
x = 2**100
assert_equal(1, (x+2) - (x+1))
assert_equal(-1, (x+1) - (x+2))
assert_equal(0, (2**100) - (2.0**100))
o = Object.new
def o.coerce(x); [2**100+2, x]; end
assert_equal(1, (2**100+1) - o)
end
def test_plus
assert_equal(T32.to_f, T32P + 1.0)
assert_raise(TypeError) { T32 + "foo" }
assert_equal(1267651809154049016125877911552, (2**100) + (2**80))
assert_equal(1267651809154049016125877911552, (2**80) + (2**100))
assert_equal(2**101, (2**100) + (2.0**100))
o = Object.new
def o.coerce(x); [2**80, x]; end
assert_equal(1267651809154049016125877911552, (2**100) + o)
end
def test_minus
@ -215,6 +228,13 @@ class TestBignum < Test::Unit::TestCase
def test_mul
assert_equal(T32.to_f, T32 * 1.0)
assert_raise(TypeError) { T32 * "foo" }
o = Object.new
def o.coerce(x); [2**100, x]; end
assert_equal(2**180, (2**80) * o)
end
def test_mul_balance
assert_equal(3**7000, (3**5000) * (3**2000))
end
def test_divrem