Bug 1846762 - Implement return_call_ref. r=jseward

Differential Revision: https://phabricator.services.mozilla.com/D185706
This commit is contained in:
Yury Delendik 2023-08-31 19:32:44 +00:00
Родитель 4dd4213c22
Коммит 45d9cac5e5
12 изменённых файлов: 369 добавлений и 0 удалений

Просмотреть файл

@ -65,6 +65,7 @@ const CallCode = 0x10;
const CallIndirectCode = 0x11; const CallIndirectCode = 0x11;
const ReturnCallCode = 0x12; const ReturnCallCode = 0x12;
const ReturnCallIndirectCode = 0x13; const ReturnCallIndirectCode = 0x13;
const ReturnCallRefCode = 0x15;
const DelegateCode = 0x18; const DelegateCode = 0x18;
const DropCode = 0x1a; const DropCode = 0x1a;
const SelectCode = 0x1b; const SelectCode = 0x1b;

Просмотреть файл

@ -0,0 +1,94 @@
// |jit-test| --wasm-gc; skip-if: !wasmGcEnabled()
var ins = wasmEvalText(`(module
(type $t (func (param i64 i64 funcref) (result i64)))
(elem declare func $fac-acc $fac-acc-broken)
(func $fac-acc (export "fac-acc") (param i64 i64 funcref) (result i64)
(if (result i64) (i64.eqz (local.get 0))
(then (local.get 1))
(else
(return_call $vis
(i64.sub (local.get 0) (i64.const 1))
(i64.mul (local.get 0) (local.get 1))
(local.get 2)
)
)
)
)
;; same as $fac-acc but fails on i == 6
(func $fac-acc-broken (param i64 i64 funcref) (result i64)
(if (result i64) (i64.eqz (local.get 0))
(then (local.get 1))
(else
(return_call $vis
(i64.sub (local.get 0) (i64.const 1))
(i64.mul (local.get 0) (local.get 1))
(select (result funcref)
(ref.null func) (local.get 2)
(i64.eq (local.get 0) (i64.const 6)))
)
)
)
)
(func $vis (export "vis") (param i64 i64 funcref) (result i64)
local.get 0
local.get 1
local.get 2
local.get 2
ref.cast (ref null $t)
return_call_ref $t
)
(func $trap (export "trap") (param i64 i64 funcref) (result i64)
unreachable
)
(func (export "main") (param i64) (result i64)
(call $vis (local.get 0) (i64.const 1) (ref.func $fac-acc))
)
(func (export "main_null") (param i64) (result i64)
(return_call $vis (local.get 0) (i64.const 1) (ref.null $t))
)
(func (export "main_broken") (param i64) (result i64)
(return_call $vis (local.get 0) (i64.const 1) (ref.func $fac-acc-broken))
)
)`);
// Check return call via wasm function
assertEq(ins.exports.main(5n), 120n);
// Check return call directly via interpreter stub
const fac = ins.exports["fac-acc"];
const vis = ins.exports["vis"];
assertEq(vis(4n, 1n, fac), 24n);
// Calling into JavaScript (and back).
if ("Function" in WebAssembly) {
const visFn = new WebAssembly.Function({
parameters: ["i64", "i64", "funcref"],
results: ["i64"]
}, function (i, n, fn) {
if (i <= 0n) {
return n;
}
return vis(i - 1n, i * n, fn);
});
assertEq(vis(3n, 1n, visFn), 6n);
}
// Check return call directly via jit stub
check_stub1: {
let options = getJitCompilerOptions();
if (!options["baseline.enable"]) break check_stub1;
const check = function() {
vis(4n, 1n, fac);
};
for (let i = options["baseline.warmup.trigger"] + 1; i--;)
check();
}
// Handling traps.
const trap = ins.exports["trap"];
assertErrorMessage(() => vis(4n, 1n, trap), WebAssembly.RuntimeError, /unreachable executed/);
const main_broken = ins.exports["main_broken"];
assertErrorMessage(() => main_broken(8n), WebAssembly.RuntimeError, /dereferencing null pointer/);
const main_null = ins.exports["main_null"];
assertErrorMessage(() => main_null(5n), WebAssembly.RuntimeError, /dereferencing null pointer/);

Просмотреть файл

@ -8954,6 +8954,16 @@ void CodeGenerator::visitWasmCall(LWasmCall* lir) {
switchRealm = false; switchRealm = false;
break; break;
case wasm::CalleeDesc::FuncRef: case wasm::CalleeDesc::FuncRef:
#ifdef ENABLE_WASM_TAIL_CALLS
if (isReturnCall) {
ReturnCallAdjustmentInfo retCallInfo(
callBase->stackArgAreaSizeUnaligned(), inboundStackArgBytes_);
masm.wasmReturnCallRef(desc, callee, retCallInfo);
// The rest of the method is unnecessary for a return call.
return;
}
#endif
MOZ_ASSERT(!isReturnCall);
// Register reloading and realm switching are handled dynamically inside // Register reloading and realm switching are handled dynamically inside
// wasmCallRef. There are two return offsets, one for each call // wasmCallRef. There are two return offsets, one for each call
// instruction (fast path and slow path). // instruction (fast path and slow path).

Просмотреть файл

@ -5552,6 +5552,63 @@ void MacroAssembler::wasmCallRef(const wasm::CallSiteDesc& desc,
bind(&done); bind(&done);
} }
#ifdef ENABLE_WASM_TAIL_CALLS
void MacroAssembler::wasmReturnCallRef(
const wasm::CallSiteDesc& desc, const wasm::CalleeDesc& callee,
const ReturnCallAdjustmentInfo& retCallInfo) {
MOZ_ASSERT(callee.which() == wasm::CalleeDesc::FuncRef);
const Register calleeScratch = WasmCallRefCallScratchReg0;
const Register calleeFnObj = WasmCallRefReg;
// Load from the function's WASM_INSTANCE_SLOT extended slot, and decide
// whether to take the fast path or the slow path. Register this load
// instruction to be source of a trap -- null pointer check.
Label fastCall;
Label done;
const Register newInstanceTemp = WasmCallRefCallScratchReg1;
size_t instanceSlotOffset = FunctionExtended::offsetOfExtendedSlot(
FunctionExtended::WASM_INSTANCE_SLOT);
static_assert(FunctionExtended::WASM_INSTANCE_SLOT < wasm::NullPtrGuardSize);
wasm::BytecodeOffset trapOffset(desc.lineOrBytecode());
append(wasm::Trap::NullPointerDereference,
wasm::TrapSite(currentOffset(), trapOffset));
loadPtr(Address(calleeFnObj, instanceSlotOffset), newInstanceTemp);
branchPtr(Assembler::Equal, InstanceReg, newInstanceTemp, &fastCall);
storePtr(InstanceReg,
Address(getStackPointer(), WasmCallerInstanceOffsetBeforeCall));
movePtr(newInstanceTemp, InstanceReg);
storePtr(InstanceReg,
Address(getStackPointer(), WasmCalleeInstanceOffsetBeforeCall));
loadWasmPinnedRegsFromInstance();
switchToWasmInstanceRealm(WasmCallRefCallScratchReg0,
WasmCallRefCallScratchReg1);
// Get funcUncheckedCallEntry() from the function's
// WASM_FUNC_UNCHECKED_ENTRY_SLOT extended slot.
size_t uncheckedEntrySlotOffset = FunctionExtended::offsetOfExtendedSlot(
FunctionExtended::WASM_FUNC_UNCHECKED_ENTRY_SLOT);
loadPtr(Address(calleeFnObj, uncheckedEntrySlotOffset), calleeScratch);
wasm::CallSiteDesc stubDesc(desc.lineOrBytecode(),
wasm::CallSiteDesc::ReturnStub);
wasmCollapseFrameSlow(retCallInfo, stubDesc);
jump(calleeScratch);
// Fast path: just load WASM_FUNC_UNCHECKED_ENTRY_SLOT value and go.
// The instance and pinned registers are the same as in the caller.
bind(&fastCall);
loadPtr(Address(calleeFnObj, uncheckedEntrySlotOffset), calleeScratch);
wasmCollapseFrameFast(retCallInfo);
jump(calleeScratch);
}
#endif
bool MacroAssembler::needScratch1ForBranchWasmRefIsSubtypeAny( bool MacroAssembler::needScratch1ForBranchWasmRefIsSubtypeAny(
wasm::RefType type) { wasm::RefType type) {
MOZ_ASSERT(type.isValid()); MOZ_ASSERT(type.isValid());

Просмотреть файл

@ -3882,6 +3882,12 @@ class MacroAssembler : public MacroAssemblerSpecific {
const wasm::CalleeDesc& callee, CodeOffset* fastCallOffset, const wasm::CalleeDesc& callee, CodeOffset* fastCallOffset,
CodeOffset* slowCallOffset); CodeOffset* slowCallOffset);
#ifdef ENABLE_WASM_TAIL_CALLS
void wasmReturnCallRef(const wasm::CallSiteDesc& desc,
const wasm::CalleeDesc& callee,
const ReturnCallAdjustmentInfo& retCallInfo);
#endif // ENABLE_WASM_TAIL_CALLS
// WasmTableCallIndexReg must contain the index of the indirect call. // WasmTableCallIndexReg must contain the index of the indirect call.
// This is for asm.js calls only. // This is for asm.js calls only.
CodeOffset asmCallIndirect(const wasm::CallSiteDesc& desc, CodeOffset asmCallIndirect(const wasm::CallSiteDesc& desc,

Просмотреть файл

@ -961,6 +961,10 @@ struct BaseCompiler final {
#ifdef ENABLE_WASM_FUNCTION_REFERENCES #ifdef ENABLE_WASM_FUNCTION_REFERENCES
void callRef(const Stk& calleeRef, const FunctionCall& call, void callRef(const Stk& calleeRef, const FunctionCall& call,
CodeOffset* fastCallOffset, CodeOffset* slowCallOffset); CodeOffset* fastCallOffset, CodeOffset* slowCallOffset);
# ifdef ENABLE_WASM_TAIL_CALLS
void returnCallRef(const Stk& calleeRef, const FunctionCall& call,
const FuncType* funcType);
# endif
#endif #endif
CodeOffset builtinCall(SymbolicAddress builtin, const FunctionCall& call); CodeOffset builtinCall(SymbolicAddress builtin, const FunctionCall& call);
CodeOffset builtinInstanceMethodCall(const SymbolicAddressSignature& builtin, CodeOffset builtinInstanceMethodCall(const SymbolicAddressSignature& builtin,
@ -1613,6 +1617,7 @@ struct BaseCompiler final {
[[nodiscard]] bool emitBrOnNull(); [[nodiscard]] bool emitBrOnNull();
[[nodiscard]] bool emitBrOnNonNull(); [[nodiscard]] bool emitBrOnNonNull();
[[nodiscard]] bool emitCallRef(); [[nodiscard]] bool emitCallRef();
[[nodiscard]] bool emitReturnCallRef();
#endif #endif
[[nodiscard]] bool emitAtomicCmpXchg(ValType type, Scalar::Type viewType); [[nodiscard]] bool emitAtomicCmpXchg(ValType type, Scalar::Type viewType);

Просмотреть файл

@ -1639,6 +1639,20 @@ void BaseCompiler::callRef(const Stk& calleeRef, const FunctionCall& call,
loadRef(calleeRef, RegRef(WasmCallRefReg)); loadRef(calleeRef, RegRef(WasmCallRefReg));
masm.wasmCallRef(desc, callee, fastCallOffset, slowCallOffset); masm.wasmCallRef(desc, callee, fastCallOffset, slowCallOffset);
} }
# ifdef ENABLE_WASM_TAIL_CALLS
void BaseCompiler::returnCallRef(const Stk& calleeRef, const FunctionCall& call,
const FuncType* funcType) {
CallSiteDesc desc(bytecodeOffset(), CallSiteDesc::FuncRef);
CalleeDesc callee = CalleeDesc::wasmFuncRef();
loadRef(calleeRef, RegRef(WasmCallRefReg));
ReturnCallAdjustmentInfo retCallInfo =
BuildReturnCallAdjustmentInfo(this->funcType(), *funcType);
masm.wasmReturnCallRef(desc, callee, retCallInfo);
}
# endif
#endif #endif
// Precondition: sync() // Precondition: sync()
@ -5022,6 +5036,57 @@ bool BaseCompiler::emitCallRef() {
captureCallResultRegisters(resultType); captureCallResultRegisters(resultType);
return pushCallResults(baselineCall, resultType, results); return pushCallResults(baselineCall, resultType, results);
} }
# ifdef ENABLE_WASM_TAIL_CALLS
bool BaseCompiler::emitReturnCallRef() {
const FuncType* funcType;
Nothing unused_callee;
BaseNothingVector unused_args{};
BaseNothingVector unused_values{};
if (!iter_.readReturnCallRef(&funcType, &unused_callee, &unused_args,
&unused_values)) {
return false;
}
if (deadCode_) {
return true;
}
sync();
// Stack: ... arg1 .. argn callee
uint32_t numArgs = funcType->args().length() + 1;
ResultType resultType(ResultType::Vector(funcType->results()));
StackResultsLoc results;
if (!pushStackResultsForCall(resultType, RegPtr(ABINonArgReg0), &results)) {
return false;
}
FunctionCall baselineCall{};
// State and realm are restored as needed by by callRef (really by
// MacroAssembler::wasmCallRef).
beginCall(baselineCall, UseABI::Wasm, RestoreRegisterStateAndRealm::False);
if (!emitCallArgs(funcType->args(), NormalCallResults(results), &baselineCall,
CalleeOnStack::True)) {
return false;
}
const Stk& callee = peek(results.count());
returnCallRef(callee, baselineCall, funcType);
MOZ_ASSERT(stackMapGenerator_.framePushedExcludingOutboundCallArgs.isSome());
stackMapGenerator_.framePushedExcludingOutboundCallArgs.reset();
popValueStackBy(numArgs);
deadCode_ = true;
return true;
}
# endif
#endif #endif
void BaseCompiler::emitRound(RoundingMode roundingMode, ValType operandType) { void BaseCompiler::emitRound(RoundingMode roundingMode, ValType operandType) {
@ -9374,6 +9439,14 @@ bool BaseCompiler::emitBody() {
return iter_.unrecognizedOpcode(&op); return iter_.unrecognizedOpcode(&op);
} }
CHECK_NEXT(emitCallRef()); CHECK_NEXT(emitCallRef());
# ifdef ENABLE_WASM_TAIL_CALLS
case uint16_t(Op::ReturnCallRef):
if (!moduleEnv_.functionReferencesEnabled() ||
!moduleEnv_.tailCallsEnabled()) {
return iter_.unrecognizedOpcode(&op);
}
CHECK_NEXT(emitReturnCallRef());
# endif
#endif #endif
// Locals and globals // Locals and globals

Просмотреть файл

@ -264,6 +264,7 @@ enum class Op {
ReturnCall = 0x12, ReturnCall = 0x12,
ReturnCallIndirect = 0x13, ReturnCallIndirect = 0x13,
CallRef = 0x14, CallRef = 0x14,
ReturnCallRef = 0x15,
// Additional exception operators // Additional exception operators
Delegate = 0x18, Delegate = 0x18,

Просмотреть файл

@ -2403,6 +2403,28 @@ class FunctionCompiler {
return collectCallResults(resultType, call.stackResultArea_, results); return collectCallResults(resultType, call.stackResultArea_, results);
} }
# ifdef ENABLE_WASM_TAIL_CALLS
[[nodiscard]] bool returnCallRef(const FuncType& funcType, MDefinition* ref,
uint32_t lineOrBytecode,
const CallCompileState& call,
DefVector* results) {
CalleeDesc callee = CalleeDesc::wasmFuncRef();
CallSiteDesc desc(lineOrBytecode, CallSiteDesc::FuncRef);
ArgTypeVector args(funcType);
auto* ins = MWasmReturnCall::New(alloc(), desc, callee, call.regArgs_,
StackArgAreaSizeUnaligned(args), ref);
if (!ins) {
return false;
}
curBlock_->end(ins);
curBlock_ = nullptr;
return true;
}
# endif // ENABLE_WASM_TAIL_CALLS
#endif // ENABLE_WASM_FUNCTION_REFERENCES #endif // ENABLE_WASM_FUNCTION_REFERENCES
/*********************************************** Control flow generation */ /*********************************************** Control flow generation */
@ -5147,6 +5169,34 @@ static bool EmitReturnCallIndirect(FunctionCompiler& f) {
} }
#endif #endif
#if defined(ENABLE_WASM_TAIL_CALLS) && defined(ENABLE_WASM_FUNCTION_REFERENCES)
static bool EmitReturnCallRef(FunctionCompiler& f) {
uint32_t lineOrBytecode = f.readCallSiteLineOrBytecode();
const FuncType* funcType;
MDefinition* callee;
DefVector args;
DefVector unused_values;
if (!f.iter().readReturnCallRef(&funcType, &callee, &args, &unused_values)) {
return false;
}
if (f.inDeadCode()) {
return true;
}
CallCompileState call;
f.markReturnCall(&call);
if (!EmitCallArgs(f, *funcType, args, &call)) {
return false;
}
DefVector results;
return f.returnCallRef(*funcType, callee, lineOrBytecode, call, &results);
}
#endif
static bool EmitGetLocal(FunctionCompiler& f) { static bool EmitGetLocal(FunctionCompiler& f) {
uint32_t id; uint32_t id;
if (!f.iter().readGetLocal(f.locals(), &id)) { if (!f.iter().readGetLocal(f.locals(), &id)) {
@ -8158,6 +8208,16 @@ static bool EmitBodyExprs(FunctionCompiler& f) {
} }
#endif #endif
#if defined(ENABLE_WASM_TAIL_CALLS) && defined(ENABLE_WASM_FUNCTION_REFERENCES)
case uint16_t(Op::ReturnCallRef): {
if (!f.moduleEnv().functionReferencesEnabled() ||
!f.moduleEnv().tailCallsEnabled()) {
return f.iter().unrecognizedOpcode(&op);
}
CHECK(EmitReturnCallRef(f));
}
#endif
// Gc operations // Gc operations
#ifdef ENABLE_WASM_GC #ifdef ENABLE_WASM_GC
case uint16_t(Op::GcPrefix): { case uint16_t(Op::GcPrefix): {

Просмотреть файл

@ -258,6 +258,8 @@ OpKind wasm::Classify(OpBytes op) {
return OpKind::ReturnCallIndirect; return OpKind::ReturnCallIndirect;
case Op::CallRef: case Op::CallRef:
WASM_FUNCTION_REFERENCES_OP(OpKind::CallRef); WASM_FUNCTION_REFERENCES_OP(OpKind::CallRef);
case Op::ReturnCallRef:
WASM_FUNCTION_REFERENCES_OP(OpKind::ReturnCallRef);
case Op::Return: case Op::Return:
case Op::Limit: case Op::Limit:
// Accept Limit, for use in decoding the end of a function after the body. // Accept Limit, for use in decoding the end of a function after the body.

Просмотреть файл

@ -166,6 +166,7 @@ enum class OpKind {
ReturnCallIndirect, ReturnCallIndirect,
# ifdef ENABLE_WASM_FUNCTION_REFERENCES # ifdef ENABLE_WASM_FUNCTION_REFERENCES
CallRef, CallRef,
ReturnCallRef,
# endif # endif
OldCallDirect, OldCallDirect,
OldCallIndirect, OldCallIndirect,
@ -699,6 +700,12 @@ class MOZ_STACK_CLASS OpIter : private Policy {
#ifdef ENABLE_WASM_FUNCTION_REFERENCES #ifdef ENABLE_WASM_FUNCTION_REFERENCES
[[nodiscard]] bool readCallRef(const FuncType** funcType, Value* callee, [[nodiscard]] bool readCallRef(const FuncType** funcType, Value* callee,
ValueVector* argValues); ValueVector* argValues);
# ifdef ENABLE_WASM_TAIL_CALLS
[[nodiscard]] bool readReturnCallRef(const FuncType** funcType, Value* callee,
ValueVector* argValues,
ValueVector* values);
# endif
#endif #endif
[[nodiscard]] bool readOldCallDirect(uint32_t numFuncImports, [[nodiscard]] bool readOldCallDirect(uint32_t numFuncImports,
uint32_t* funcTypeIndex, uint32_t* funcTypeIndex,
@ -2625,6 +2632,47 @@ inline bool OpIter<Policy>::readCallRef(const FuncType** funcType,
} }
#endif #endif
#if defined(ENABLE_WASM_TAIL_CALLS) && defined(ENABLE_WASM_FUNCTION_REFERENCES)
template <typename Policy>
inline bool OpIter<Policy>::readReturnCallRef(const FuncType** funcType,
Value* callee,
ValueVector* argValues,
ValueVector* values) {
MOZ_ASSERT(Classify(op_) == OpKind::ReturnCallRef);
uint32_t funcTypeIndex;
if (!readFuncTypeIndex(&funcTypeIndex)) {
return false;
}
const TypeDef& typeDef = env_.types->type(funcTypeIndex);
*funcType = &typeDef.funcType();
if (!popWithType(ValType(RefType::fromTypeDef(&typeDef, true)), callee)) {
return false;
}
if (!popCallArgs((*funcType)->args(), argValues)) {
return false;
}
if (!push(ResultType::Vector((*funcType)->results()))) {
return false;
}
Control& body = controlStack_[0];
MOZ_ASSERT(body.kind() == LabelKind::Body);
// Pop function results as the instruction will cause a return.
if (!popWithType(body.resultType(), values)) {
return false;
}
afterUnconditionalBranch();
return true;
}
#endif
template <typename Policy> template <typename Policy>
inline bool OpIter<Policy>::readOldCallDirect(uint32_t numFuncImports, inline bool OpIter<Policy>::readOldCallDirect(uint32_t numFuncImports,
uint32_t* funcTypeIndex, uint32_t* funcTypeIndex,

Просмотреть файл

@ -244,6 +244,18 @@ static bool DecodeFunctionBodyExprs(const ModuleEnvironment& env,
NothingVector unusedArgs{}; NothingVector unusedArgs{};
CHECK(iter.readCallRef(&unusedType, &nothing, &unusedArgs)); CHECK(iter.readCallRef(&unusedType, &nothing, &unusedArgs));
} }
# ifdef ENABLE_WASM_TAIL_CALLS
case uint16_t(Op::ReturnCallRef): {
if (!env.functionReferencesEnabled() || !env.tailCallsEnabled()) {
return iter.unrecognizedOpcode(&op);
}
const FuncType* unusedType;
NothingVector unusedArgs{};
NothingVector unusedValues{};
CHECK(iter.readReturnCallRef(&unusedType, &nothing, &unusedArgs,
&unusedValues));
}
# endif
#endif #endif
case uint16_t(Op::I32Const): { case uint16_t(Op::I32Const): {
int32_t unused; int32_t unused;