Mirror of https://github.com/github/ruby.git
YJIT: implement fast path for integer multiplication in opt_mult (#8204)
* YJIT: implement fast path for integer multiplication in opt_mult
* Update yjit/src/codegen.rs
  Co-authored-by: Alan Wu <XrXr@users.noreply.github.com>
* Implement mul with overflow checking on arm64
* Fix missing semicolon
* Add arm splitting for lshift, rshift, urshift
---------
Co-authored-by: Alan Wu <XrXr@users.noreply.github.com>
Parent: 724223b4ca
Commit: 314eed8a5e
@@ -4101,3 +4101,18 @@ assert_equal '6', %q{
 
   Sub.new.number { 3 }
 }
+
+# Integer multiplication and overflow
+assert_equal '[6, -6, 9671406556917033397649408, -9671406556917033397649408, 21267647932558653966460912964485513216]', %q{
+  def foo(a, b)
+    a * b
+  end
+
+  r1 = foo(2, 3)
+  r2 = foo(2, -3)
+  r3 = foo(2 << 40, 2 << 41)
+  r4 = foo(2 << 40, -2 << 41)
+  r5 = foo(1 << 62, 1 << 62)
+
+  [r1, r2, r3, r4, r5]
+}
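The large operands in this test are chosen to force the overflow path: (2 << 40) * (2 << 41) is 2^83 and (1 << 62) * (1 << 62) is 2^124, both far outside the 63-bit Fixnum range on 64-bit platforms, so the JIT-compiled multiply has to detect overflow and fall back to Bignum arithmetic. A minimal sketch (not part of the commit) confirming the expected literals:

```rust
// Illustrative check only: verify that the expected values in the test above
// are the exact powers of two the multiplications produce.
fn main() {
    // (2 << 40) * (2 << 41) == 2^41 * 2^42 == 2^83
    assert_eq!(1u128 << 83, 9671406556917033397649408);
    // (1 << 62) * (1 << 62) == 2^124
    assert_eq!(1u128 << 124, 21267647932558653966460912964485513216);
}
```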
@@ -17,6 +17,7 @@ mod load_store_exclusive;
 mod logical_imm;
 mod logical_reg;
 mod madd;
+mod smulh;
 mod mov;
 mod nop;
 mod pc_rel;
@@ -42,6 +43,7 @@ pub use load_store_exclusive::LoadStoreExclusive;
 pub use logical_imm::LogicalImm;
 pub use logical_reg::LogicalReg;
 pub use madd::MAdd;
+pub use smulh::SMulH;
 pub use mov::Mov;
 pub use nop::Nop;
 pub use pc_rel::PCRelative;
@@ -0,0 +1,60 @@
+/// The struct that represents an A64 signed multiply high instruction
+///
+/// +-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+
+/// | 31 30 29 28 | 27 26 25 24 | 23 22 21 20 | 19 18 17 16 | 15 14 13 12 | 11 10 09 08 | 07 06 05 04 | 03 02 01 00 |
+/// |  1  0  0  1    1  0  1  1    0  1  0                      0                                                    |
+/// |                                        rm..............    ra..............   rn..............   rd.......... |
+/// +-------------+-------------+-------------+-------------+-------------+-------------+-------------+-------------+
+///
+pub struct SMulH {
+    /// The number of the general-purpose destination register.
+    rd: u8,
+
+    /// The number of the first general-purpose source register.
+    rn: u8,
+
+    /// The number of the third general-purpose source register.
+    ra: u8,
+
+    /// The number of the second general-purpose source register.
+    rm: u8,
+}
+
+impl SMulH {
+    /// SMULH
+    /// https://developer.arm.com/documentation/ddi0602/2023-06/Base-Instructions/SMULH--Signed-Multiply-High-
+    pub fn smulh(rd: u8, rn: u8, rm: u8) -> Self {
+        Self { rd, rn, ra: 0b11111, rm }
+    }
+}
+
+impl From<SMulH> for u32 {
+    /// Convert an instruction into a 32-bit value.
+    fn from(inst: SMulH) -> Self {
+        0
+        | (0b10011011010 << 21)
+        | ((inst.rm as u32) << 16)
+        | ((inst.ra as u32) << 10)
+        | ((inst.rn as u32) << 5)
+        | (inst.rd as u32)
+    }
+}
+
+impl From<SMulH> for [u8; 4] {
+    /// Convert an instruction into a 4 byte array.
+    fn from(inst: SMulH) -> [u8; 4] {
+        let result: u32 = inst.into();
+        result.to_le_bytes()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_smulh() {
+        let result: u32 = SMulH::smulh(0, 1, 2).into();
+        assert_eq!(0x9b427c20, result);
+    }
+}
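For reference, the word asserted in test_smulh() decodes cleanly into the fields documented above. A small standalone sketch (not part of the commit) that unpacks 0x9b427c20 using the same layout as the From<SMulH> impl:

```rust
// Illustrative only: unpack 0x9b427c20 (SMulH::smulh(0, 1, 2)) into its fields
// using the bit layout documented in the struct comment above.
fn main() {
    let word: u32 = 0x9b427c20;
    let opcode = word >> 21;           // bits 31-21, the fixed SMULH opcode
    let rm = (word >> 16) & 0b11111;   // bits 20-16
    let ra = (word >> 10) & 0b11111;   // bits 14-10, hardwired to 0b11111
    let rn = (word >> 5) & 0b11111;    // bits 9-5
    let rd = word & 0b11111;           // bits 4-0
    assert_eq!(opcode, 0b10011011010);
    assert_eq!((rd, rn, ra, rm), (0, 1, 0b11111, 2));
}
```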
@@ -186,7 +186,7 @@ pub fn asr(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, shift: A64Opnd) {
 
             SBFM::asr(rd.reg_no, rn.reg_no, shift.try_into().unwrap(), rd.num_bits).into()
         },
-        _ => panic!("Invalid operand combination to asr instruction."),
+        _ => panic!("Invalid operand combination to asr instruction: asr {:?}, {:?}, {:?}", rd, rn, shift),
     };
 
     cb.write_bytes(&bytes);
@@ -713,6 +713,21 @@ pub fn mul(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, rm: A64Opnd) {
     cb.write_bytes(&bytes);
 }
 
+/// SMULH - multiply two 64-bit registers to produce a 128-bit result, put the high 64-bits of the result into rd
+pub fn smulh(cb: &mut CodeBlock, rd: A64Opnd, rn: A64Opnd, rm: A64Opnd) {
+    let bytes: [u8; 4] = match (rd, rn, rm) {
+        (A64Opnd::Reg(rd), A64Opnd::Reg(rn), A64Opnd::Reg(rm)) => {
+            assert!(rd.num_bits == rn.num_bits && rn.num_bits == rm.num_bits, "Expected registers to be the same size");
+            assert!(rd.num_bits == 64, "smulh only applicable to 64-bit registers");
+
+            SMulH::smulh(rd.reg_no, rn.reg_no, rm.reg_no).into()
+        },
+        _ => panic!("Invalid operand combination to mul instruction")
+    };
+
+    cb.write_bytes(&bytes);
+}
+
 /// MVN - move a value in a register to another register, negating it
 pub fn mvn(cb: &mut CodeBlock, rd: A64Opnd, rm: A64Opnd) {
     let bytes: [u8; 4] = match (rd, rm) {
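SMULH here fills the role of a 64x64-to-128-bit signed multiply: it yields only the upper 64 bits of the full product, while the regular MUL yields the lower 64. A reference sketch (not the commit's code) of the same computation using Rust's i128:

```rust
// Reference model of what the smulh() helper above emits: the high 64 bits
// of the signed 128-bit product of two 64-bit values.
fn smulh_reference(rn: i64, rm: i64) -> i64 {
    (((rn as i128) * (rm as i128)) >> 64) as i64
}

fn main() {
    assert_eq!(smulh_reference(1 << 62, 4), 1);  // 2^64: the high word becomes 1
    assert_eq!(smulh_reference(-1, 1), -1);      // negative products sign-extend into the high word
    assert_eq!(smulh_reference(123, 456), 0);    // small products leave the high word at 0
}
```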
@@ -612,6 +612,19 @@ impl Assembler
 
                     asm.not(opnd0);
                 },
+                Insn::LShift { opnd, shift, .. } |
+                Insn::RShift { opnd, shift, .. } |
+                Insn::URShift { opnd, shift, .. } => {
+                    // The operand must be in a register, so
+                    // if we get anything else we need to load it first.
+                    let opnd0 = match opnd {
+                        Opnd::Mem(_) => split_load_operand(asm, *opnd),
+                        _ => *opnd
+                    };
+
+                    *opnd = opnd0;
+                    asm.push_insn(insn);
+                },
                 Insn::Store { dest, src } => {
                     // The value being stored must be in a register, so if it's
                     // not already one we'll load it first.
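The new split arm mirrors the existing ones: A64 shift instructions only operate on registers, so a shift whose value operand lives in memory gets an explicit load inserted in front of it. A toy model of the rewrite (not YJIT's real IR types, names are hypothetical):

```rust
// Toy model (not YJIT's actual types) of the split rule added above: if the
// shifted value is a memory operand, load it into a scratch register first
// and rewrite the shift to use that register.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Opnd { Reg(u8), Mem(i32), UImm(u64) }

#[derive(Debug, PartialEq)]
enum Insn { Load { from: Opnd, to: Opnd }, RShift { opnd: Opnd, shift: Opnd } }

fn split_rshift(opnd: Opnd, shift: Opnd, scratch: Opnd, out: &mut Vec<Insn>) {
    // The shifted value must be in a register; anything in memory is loaded first.
    let opnd = match opnd {
        Opnd::Mem(_) => {
            out.push(Insn::Load { from: opnd, to: scratch });
            scratch
        }
        _ => opnd,
    };
    out.push(Insn::RShift { opnd, shift });
}

fn main() {
    let mut insns = Vec::new();
    split_rshift(Opnd::Mem(0), Opnd::UImm(1), Opnd::Reg(16), &mut insns);
    assert_eq!(insns.len(), 2); // a Load was inserted ahead of the RShift
    println!("{:?}", insns);
}
```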
@@ -811,6 +824,7 @@ impl Assembler
         let start_write_pos = cb.get_write_pos();
         let mut insn_idx: usize = 0;
         while let Some(insn) = self.insns.get(insn_idx) {
+            let mut next_insn_idx = insn_idx + 1;
             let src_ptr = cb.get_write_ptr();
             let had_dropped_bytes = cb.has_dropped_bytes();
             let old_label_state = cb.get_label_state();
@@ -863,7 +877,32 @@ impl Assembler
                     subs(cb, out.into(), left.into(), right.into());
                 },
                 Insn::Mul { left, right, out } => {
-                    mul(cb, out.into(), left.into(), right.into());
+                    // If the next instruction is jo (jump on overflow)
+                    match self.insns.get(insn_idx + 1) {
+                        Some(Insn::Jo(target)) => {
+                            // Compute the high 64 bits
+                            smulh(cb, Self::SCRATCH0, left.into(), right.into());
+
+                            // Compute the low 64 bits
+                            // This may clobber one of the input registers,
+                            // so we do it after smulh
+                            mul(cb, out.into(), left.into(), right.into());
+
+                            // Produce a register that is all zeros or all ones
+                            // Based on the sign bit of the 64-bit mul result
+                            asr(cb, Self::SCRATCH1, out.into(), A64Opnd::UImm(63));
+
+                            // If the high 64-bits are not all zeros or all ones,
+                            // matching the sign bit, then we have an overflow
+                            cmp(cb, Self::SCRATCH0, Self::SCRATCH1);
+                            emit_conditional_jump::<{Condition::NE}>(cb, compile_side_exit(*target, self, ocb));
+
+                            next_insn_idx += 1;
+                        }
+                        _ => {
+                            mul(cb, out.into(), left.into(), right.into());
+                        }
+                    }
                 },
                 Insn::And { left, right, out } => {
                     and(cb, out.into(), left.into(), right.into());
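The emitted sequence is the standard signed-overflow test for a 64-bit multiply: the product overflows exactly when the high 64 bits (from SMULH) are not a sign extension of the low 64 bits (from MUL). A minimal sketch of that condition (not the commit's code), checked against checked_mul:

```rust
// Sketch of the overflow condition the smulh/asr/cmp sequence above checks:
// overflow iff the high 64 bits of the product differ from the low word's
// sign bit broadcast across a register (the `asr #63` result).
fn mul_overflows(a: i64, b: i64) -> bool {
    let full = (a as i128) * (b as i128);
    let lo = full as i64;          // what `mul` leaves in the destination register
    let hi = (full >> 64) as i64;  // what `smulh` leaves in SCRATCH0
    hi != (lo >> 63)               // compare against the broadcast sign bit, branch if not equal
}

fn main() {
    for &(a, b) in &[(3, -2), (i64::MAX, 2), (1 << 62, 2), (-1, i64::MIN)] {
        assert_eq!(mul_overflows(a, b), a.checked_mul(b).is_none());
    }
}
```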
@@ -1158,7 +1197,7 @@ impl Assembler
                     return Err(());
                 }
             } else {
-                insn_idx += 1;
+                insn_idx = next_insn_idx;
                 gc_offsets.append(&mut insn_gc_offsets);
             }
         }
@@ -3398,8 +3398,42 @@ fn gen_opt_mult(
     asm: &mut Assembler,
     ocb: &mut OutlinedCb,
 ) -> Option<CodegenStatus> {
-    // Delegate to send, call the method on the recv
-    gen_opt_send_without_block(jit, asm, ocb)
+    let two_fixnums = match asm.ctx.two_fixnums_on_stack(jit) {
+        Some(two_fixnums) => two_fixnums,
+        None => {
+            defer_compilation(jit, asm, ocb);
+            return Some(EndBlock);
+        }
+    };
+
+    if two_fixnums {
+        if !assume_bop_not_redefined(jit, asm, ocb, INTEGER_REDEFINED_OP_FLAG, BOP_MULT) {
+            return None;
+        }
+
+        // Check that both operands are fixnums
+        guard_two_fixnums(jit, asm, ocb);
+
+        // Get the operands from the stack
+        let arg1 = asm.stack_pop(1);
+        let arg0 = asm.stack_pop(1);
+
+        // Do some bitwise gymnastics to handle tag bits
+        // x * y is translated to (x >> 1) * (y - 1) + 1
+        let arg0_untag = asm.rshift(arg0, Opnd::UImm(1));
+        let arg1_untag = asm.sub(arg1, Opnd::UImm(1));
+        let out_val = asm.mul(arg0_untag, arg1_untag);
+        asm.jo(Target::side_exit(Counter::opt_mult_overflow));
+        let out_val = asm.add(out_val, Opnd::UImm(1));
+
+        // Push the output on the stack
+        let dst = asm.stack_push(Type::Fixnum);
+        asm.mov(dst, out_val);
+
+        Some(KeepCompiling)
+    } else {
+        gen_opt_send_without_block(jit, asm, ocb)
+    }
 }
 
 fn gen_opt_div(
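The "bitwise gymnastics" comment is the core of the fast path: CRuby encodes a Fixnum n as the machine word 2n + 1, and (x >> 1) * (y - 1) + 1 turns two such encoded words directly into the encoded product without fully untagging both sides. Since any product outside the Fixnum range also overflows the 64-bit tagged multiply, the single jo check suffices. A standalone sketch of the arithmetic (not YJIT code):

```rust
// Fixnum tag arithmetic behind the fast path above. A Fixnum n is stored as
// the odd machine word 2n + 1; the identity used by gen_opt_mult is:
//   (x >> 1) * (y - 1) + 1  ==  n * 2m + 1  ==  2(n*m) + 1  ==  encode(n * m)
fn encode(n: i64) -> i64 { 2 * n + 1 }

fn tagged_mul(x: i64, y: i64) -> i64 {
    (x >> 1) * (y - 1) + 1
}

fn main() {
    for &(n, m) in &[(2, 3), (2, -3), (-7, 0), (123456, 789)] {
        assert_eq!(tagged_mul(encode(n), encode(m)), encode(n * m));
    }
}
```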
@@ -343,6 +343,7 @@ make_counters! {
 
     opt_plus_overflow,
     opt_minus_overflow,
+    opt_mult_overflow,
 
     opt_mod_zero,
     opt_div_zero,