YJIT: implement fast path for integer multiplication in opt_mult (#8204)

* YJIT: implement fast path for integer multiplication in opt_mult

* Update yjit/src/codegen.rs

Co-authored-by: Alan Wu <XrXr@users.noreply.github.com>

* Implement mul with overflow checking on arm64

* Fix missing semicolon

* Add arm splitting for lshift, rshift, urshift

---------

Co-authored-by: Alan Wu <XrXr@users.noreply.github.com>
This commit is contained in:
Maxime Chevalier-Boisvert 2023-08-18 10:05:32 -04:00 коммит произвёл GitHub
Родитель 724223b4ca
Коммит 314eed8a5e
Не найден ключ, соответствующий данной подписи
Идентификатор ключа GPG: 4AEE18F83AFDEB23
7 изменённых файлов: 171 добавлений и 5 удалений

Просмотреть файл

@ -4101,3 +4101,18 @@ assert_equal '6', %q{
Sub.new.number { 3 }
}
# Integer multiplication and overflow
assert_equal '[6, -6, 9671406556917033397649408, -9671406556917033397649408, 21267647932558653966460912964485513216]', %q{
def foo(a, b)
a * b
end
r1 = foo(2, 3)
r2 = foo(2, -3)
r3 = foo(2 << 40, 2 << 41)
r4 = foo(2 << 40, -2 << 41)
r5 = foo(1 << 62, 1 << 62)
[r1, r2, r3, r4, r5]
}

Просмотреть файл

@ -17,6 +17,7 @@ mod load_store_exclusive;
mod logical_imm;
mod logical_reg;
mod madd;
mod smulh;
mod mov;
mod nop;
mod pc_rel;
@ -42,6 +43,7 @@ pub use load_store_exclusive::LoadStoreExclusive;
pub use logical_imm::LogicalImm;
pub use logical_reg::LogicalReg;
pub use madd::MAdd;
pub use smulh::SMulH;
pub use mov::Mov;
pub use nop::Nop;
pub use pc_rel::PCRelative;

Просмотреть файл

@ -0,0 +1,60 @@
/// The struct that represents an A64 signed multipy high instruction
///
/// +-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+
/// | 31 30 29 28 | 27 26 25 24 | 23 22 21 20 | 19 18 17 16 | 15 14 13 12 | 11 10 09 08 | 07 06 05 04 | 03 02 01 00 |
/// | 1 0 0 1 1 0 1 1 0 1 0 0 |
/// | rm.............. ra.............. rn.............. rd.............. |
/// +-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+
///
pub struct SMulH {
/// The number of the general-purpose destination register.
rd: u8,
/// The number of the first general-purpose source register.
rn: u8,
/// The number of the third general-purpose source register.
ra: u8,
/// The number of the second general-purpose source register.
rm: u8,
}
impl SMulH {
/// SMULH
/// https://developer.arm.com/documentation/ddi0602/2023-06/Base-Instructions/SMULH--Signed-Multiply-High-
pub fn smulh(rd: u8, rn: u8, rm: u8) -> Self {
Self { rd, rn, ra: 0b11111, rm }
}
}
impl From<SMulH> for u32 {
/// Convert an instruction into a 32-bit value.
fn from(inst: SMulH) -> Self {
0
| (0b10011011010 << 21)
| ((inst.rm as u32) << 16)
| ((inst.ra as u32) << 10)
| ((inst.rn as u32) << 5)
| (inst.rd as u32)
}
}
impl From<SMulH> for [u8; 4] {
/// Convert an instruction into a 4 byte array.
fn from(inst: SMulH) -> [u8; 4] {
let result: u32 = inst.into();
result.to_le_bytes()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_smulh() {
let result: u32 = SMulH::smulh(0, 1, 2).into();
assert_eq!(0x9b427c20, result);
}
}

Просмотреть файл

@ -186,7 +186,7 @@ pub fn asr(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, shift: A64Opnd) {
SBFM::asr(rd.reg_no, rn.reg_no, shift.try_into().unwrap(), rd.num_bits).into()
},
_ => panic!("Invalid operand combination to asr instruction."),
_ => panic!("Invalid operand combination to asr instruction: asr {:?}, {:?}, {:?}", rd, rn, shift),
};
cb.write_bytes(&bytes);
@ -713,6 +713,21 @@ pub fn mul(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, rm: A64Opnd) {
cb.write_bytes(&bytes);
}
/// SMULH - multiply two 64-bit registers to produce a 128-bit result, put the high 64-bits of the result into rd
pub fn smulh(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, rm: A64Opnd) {
let bytes: [u8; 4] = match (rd, rn, rm) {
(A64Opnd::Reg(rd), A64Opnd::Reg(rn), A64Opnd::Reg(rm)) => {
assert!(rd.num_bits == rn.num_bits && rn.num_bits == rm.num_bits, "Expected registers to be the same size");
assert!(rd.num_bits == 64, "smulh only applicable to 64-bit registers");
SMulH::smulh(rd.reg_no, rn.reg_no, rm.reg_no).into()
},
_ => panic!("Invalid operand combination to mul instruction")
};
cb.write_bytes(&bytes);
}
/// MVN - move a value in a register to another register, negating it
pub fn mvn(cb: &mut CodeBlock, rd: A64Opnd, rm: A64Opnd) {
let bytes: [u8; 4] = match (rd, rm) {

Просмотреть файл

@ -612,6 +612,19 @@ impl Assembler
asm.not(opnd0);
},
Insn::LShift { opnd, shift, .. } |
Insn::RShift { opnd, shift, .. } |
Insn::URShift { opnd, shift, .. } => {
// The operand must be in a register, so
// if we get anything else we need to load it first.
let opnd0 = match opnd {
Opnd::Mem(_) => split_load_operand(asm, *opnd),
_ => *opnd
};
*opnd = opnd0;
asm.push_insn(insn);
},
Insn::Store { dest, src } => {
// The value being stored must be in a register, so if it's
// not already one we'll load it first.
@ -811,6 +824,7 @@ impl Assembler
let start_write_pos = cb.get_write_pos();
let mut insn_idx: usize = 0;
while let Some(insn) = self.insns.get(insn_idx) {
let mut next_insn_idx = insn_idx + 1;
let src_ptr = cb.get_write_ptr();
let had_dropped_bytes = cb.has_dropped_bytes();
let old_label_state = cb.get_label_state();
@ -863,7 +877,32 @@ impl Assembler
subs(cb, out.into(), left.into(), right.into());
},
Insn::Mul { left, right, out } => {
mul(cb, out.into(), left.into(), right.into());
// If the next instruction is jo (jump on overflow)
match self.insns.get(insn_idx + 1) {
Some(Insn::Jo(target)) => {
// Compute the high 64 bits
smulh(cb, Self::SCRATCH0, left.into(), right.into());
// Compute the low 64 bits
// This may clobber one of the input registers,
// so we do it after smulh
mul(cb, out.into(), left.into(), right.into());
// Produce a register that is all zeros or all ones
// Based on the sign bit of the 64-bit mul result
asr(cb, Self::SCRATCH1, out.into(), A64Opnd::UImm(63));
// If the high 64-bits are not all zeros or all ones,
// matching the sign bit, then we have an overflow
cmp(cb, Self::SCRATCH0, Self::SCRATCH1);
emit_conditional_jump::<{Condition::NE}>(cb, compile_side_exit(*target, self, ocb));
next_insn_idx += 1;
}
_ => {
mul(cb, out.into(), left.into(), right.into());
}
}
},
Insn::And { left, right, out } => {
and(cb, out.into(), left.into(), right.into());
@ -1158,7 +1197,7 @@ impl Assembler
return Err(());
}
} else {
insn_idx += 1;
insn_idx = next_insn_idx;
gc_offsets.append(&mut insn_gc_offsets);
}
}

Просмотреть файл

@ -3398,8 +3398,42 @@ fn gen_opt_mult(
asm: &mut Assembler,
ocb: &mut OutlinedCb,
) -> Option<CodegenStatus> {
// Delegate to send, call the method on the recv
gen_opt_send_without_block(jit, asm, ocb)
let two_fixnums = match asm.ctx.two_fixnums_on_stack(jit) {
Some(two_fixnums) => two_fixnums,
None => {
defer_compilation(jit, asm, ocb);
return Some(EndBlock);
}
};
if two_fixnums {
if !assume_bop_not_redefined(jit, asm, ocb, INTEGER_REDEFINED_OP_FLAG, BOP_MULT) {
return None;
}
// Check that both operands are fixnums
guard_two_fixnums(jit, asm, ocb);
// Get the operands from the stack
let arg1 = asm.stack_pop(1);
let arg0 = asm.stack_pop(1);
// Do some bitwise gymnastics to handle tag bits
// x * y is translated to (x >> 1) * (y - 1) + 1
let arg0_untag = asm.rshift(arg0, Opnd::UImm(1));
let arg1_untag = asm.sub(arg1, Opnd::UImm(1));
let out_val = asm.mul(arg0_untag, arg1_untag);
asm.jo(Target::side_exit(Counter::opt_mult_overflow));
let out_val = asm.add(out_val, Opnd::UImm(1));
// Push the output on the stack
let dst = asm.stack_push(Type::Fixnum);
asm.mov(dst, out_val);
Some(KeepCompiling)
} else {
gen_opt_send_without_block(jit, asm, ocb)
}
}
fn gen_opt_div(

Просмотреть файл

@ -343,6 +343,7 @@ make_counters! {
opt_plus_overflow,
opt_minus_overflow,
opt_mult_overflow,
opt_mod_zero,
opt_div_zero,