Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 11 additions & 9 deletions src/interpreter/Interpreter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2213,23 +2213,25 @@ NEVER_INLINE bool Interpreter::testRefGeneric(void* refPtr, Value::Type type)
return type == Value::AnyRef;
}

ASSERT(type == Value::I31Ref || type == Value::StructRef || type == Value::ArrayRef);
ASSERT(type == Value::I31Ref || type == Value::StructRef
|| type == Value::ArrayRef || type == Value::EqRef);

if (Value::isI31Value(refPtr)) {
return type == Value::I31Ref;
}

if (type == Value::I31Ref) {
return false;
return type == Value::I31Ref || type == Value::EqRef;
}

Object::Kind kind = reinterpret_cast<Object*>(refPtr)->kind();

if (type == Value::StructRef) {
switch (type) {
case Value::I31Ref:
return false;
case Value::StructRef:
return kind == Object::StructKind;
case Value::ArrayRef:
return kind == Object::ArrayKind;
default:
return kind == Object::StructKind || kind == Object::ArrayKind;
}

return kind == Object::ArrayKind;
}

NEVER_INLINE bool Interpreter::testRefDefined(void* refPtr, const CompositeType** typeInfo)
Expand Down
85 changes: 81 additions & 4 deletions src/jit/GarbageCollectorInl.h
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,89 @@ static void emitGCCastGeneric(sljit_compiler* compiler, Instruction* instr)
if ((srcInfo & (JumpIfCastGeneric::IsSrcNullable | JumpIfCastGeneric::IsSrcTagged)) == 0) {
sljit_emit_op1(compiler, SLJIT_MOV_P, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(srcReg), JITFieldAccessor::objectTypeInfo());
if (label != nullptr) {
label->jumpFrom(sljit_emit_cmp(compiler, isTestOrCastFail ? SLJIT_NOT_EQUAL : SLJIT_EQUAL, SLJIT_MEM1(SLJIT_TMP_DEST_REG), -static_cast<sljit_sw>(sizeof(sljit_up)), SLJIT_IMM, kind));
sljit_s32 type = (genericType == Value::EqRef) ? SLJIT_LESS_EQUAL : SLJIT_EQUAL;
if (isTestOrCastFail) {
type ^= 0x1;
}
label->jumpFrom(sljit_emit_cmp(compiler, type, SLJIT_MEM1(SLJIT_TMP_DEST_REG), -static_cast<sljit_sw>(sizeof(sljit_up)), SLJIT_IMM, kind));
} else if (!isTestOrCastFail) {
context->appendTrapJump(ExecutionContext::CastFailureError, sljit_emit_cmp(compiler, SLJIT_NOT_EQUAL, SLJIT_MEM1(SLJIT_TMP_DEST_REG), -static_cast<sljit_sw>(sizeof(sljit_up)), SLJIT_IMM, kind));
sljit_s32 type = (genericType == Value::EqRef) ? SLJIT_GREATER : SLJIT_NOT_EQUAL;
context->appendTrapJump(ExecutionContext::CastFailureError, sljit_emit_cmp(compiler, type, SLJIT_MEM1(SLJIT_TMP_DEST_REG), -static_cast<sljit_sw>(sizeof(sljit_up)), SLJIT_IMM, kind));
} else {
sljit_emit_op2u(compiler, SLJIT_SUB | SLJIT_SET_Z, SLJIT_MEM1(SLJIT_TMP_DEST_REG), -static_cast<sljit_sw>(sizeof(sljit_up)), SLJIT_IMM, kind);
sljit_emit_op_flags(compiler, SLJIT_MOV, args[1].arg, args[1].argw, SLJIT_EQUAL);
sljit_s32 type = (genericType == Value::EqRef) ? SLJIT_SET_LESS_EQUAL : SLJIT_SET_Z;
sljit_emit_op2u(compiler, SLJIT_SUB | type, SLJIT_MEM1(SLJIT_TMP_DEST_REG), -static_cast<sljit_sw>(sizeof(sljit_up)), SLJIT_IMM, kind);
type = (genericType == Value::EqRef) ? SLJIT_LESS_EQUAL : SLJIT_EQUAL;
sljit_emit_op_flags(compiler, SLJIT_MOV, args[1].arg, args[1].argw, type);
}
return;
}

if (genericType == Value::EqRef) {
if ((srcInfo & JumpIfCastGeneric::IsSrcTagged) != 0) {
sljit_emit_op2(compiler, SLJIT_ROTR, SLJIT_TMP_DEST_REG, 0, srcReg, 0, SLJIT_IMM, 1);
sljit_jump* jump = sljit_emit_cmp(compiler, SLJIT_SIG_LESS_EQUAL, SLJIT_TMP_DEST_REG, 0, SLJIT_IMM, 0);
sljit_emit_op2(compiler, SLJIT_SHL, SLJIT_TMP_DEST_REG, 0, SLJIT_TMP_DEST_REG, 0, SLJIT_IMM, 1);

sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(SLJIT_TMP_DEST_REG), JITFieldAccessor::objectTypeInfo());
sljit_emit_op2(compiler, SLJIT_SUB, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(SLJIT_TMP_DEST_REG), -static_cast<sljit_sw>(sizeof(sljit_up)), SLJIT_IMM, kind + 1);
sljit_set_label(jump, sljit_emit_label(compiler));

if (label != nullptr) {
sljit_s32 type = ((srcInfo & JumpIfCastGeneric::IsNullable) == 0) ? SLJIT_SIG_LESS : SLJIT_SIG_LESS_EQUAL;
if (isTestOrCastFail) {
type ^= 0x1;
}
label->jumpFrom(sljit_emit_cmp(compiler, type, SLJIT_TMP_DEST_REG, 0, SLJIT_IMM, 0));
} else if (!isTestOrCastFail) {
sljit_s32 type = ((srcInfo & JumpIfCastGeneric::IsNullable) == 0) ? SLJIT_SIG_GREATER_EQUAL : SLJIT_SIG_GREATER;
context->appendTrapJump(ExecutionContext::CastFailureError, sljit_emit_cmp(compiler, type, SLJIT_TMP_DEST_REG, 0, SLJIT_IMM, 0));
} else {
sljit_s32 type = ((srcInfo & JumpIfCastGeneric::IsNullable) == 0) ? SLJIT_SET_SIG_LESS : SLJIT_SET_SIG_LESS_EQUAL;
sljit_emit_op2u(compiler, SLJIT_SUB | type, SLJIT_TMP_DEST_REG, 0, SLJIT_IMM, 0);
type = ((srcInfo & JumpIfCastGeneric::IsNullable) == 0) ? SLJIT_SIG_LESS : SLJIT_SIG_LESS_EQUAL;
sljit_emit_op_flags(compiler, SLJIT_MOV, args[1].arg, args[1].argw, type);
}
return;
}

if (srcReg != SLJIT_TMP_DEST_REG && label == nullptr && isTestOrCastFail) {
sljit_emit_op1(compiler, SLJIT_MOV, SLJIT_TMP_DEST_REG, 0, srcReg, 0);
srcReg = SLJIT_TMP_DEST_REG;
}

sljit_jump* jump = sljit_emit_cmp(compiler, SLJIT_EQUAL, srcReg, 0, SLJIT_IMM, 0);
sljit_emit_op1(compiler, SLJIT_MOV_P, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(srcReg), JITFieldAccessor::objectTypeInfo());

if (label != nullptr) {
label->jumpFrom(sljit_emit_cmp(compiler, isTestOrCastFail ? SLJIT_GREATER : SLJIT_LESS_EQUAL, SLJIT_MEM1(SLJIT_TMP_DEST_REG), -static_cast<sljit_sw>(sizeof(sljit_up)), SLJIT_IMM, kind));
if ((srcInfo & JumpIfCastGeneric::IsNullable) != 0) {
isTestOrCastFail = !isTestOrCastFail;
}

if (isTestOrCastFail) {
label->jumpFrom(jump);
} else {
sljit_set_label(jump, sljit_emit_label(compiler));
}
} else if (!isTestOrCastFail) {
context->appendTrapJump(ExecutionContext::CastFailureError, sljit_emit_cmp(compiler, SLJIT_GREATER, SLJIT_MEM1(SLJIT_TMP_DEST_REG), -static_cast<sljit_sw>(sizeof(sljit_up)), SLJIT_IMM, kind));
if ((srcInfo & JumpIfCastGeneric::IsNullable) != 0) {
sljit_set_label(jump, sljit_emit_label(compiler));
} else {
context->appendTrapJump(ExecutionContext::CastFailureError, jump);
}
} else {
ASSERT(srcReg == SLJIT_TMP_DEST_REG);
kind++;
if ((srcInfo & JumpIfCastGeneric::IsNullable) != 0) {
sljit_emit_op1(compiler, SLJIT_MOV_P, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(SLJIT_TMP_DEST_REG), -static_cast<sljit_sw>(sizeof(sljit_up)));
} else {
sljit_emit_op2(compiler, SLJIT_SUB, SLJIT_TMP_DEST_REG, 0, SLJIT_MEM1(SLJIT_TMP_DEST_REG), -static_cast<sljit_sw>(sizeof(sljit_up)), SLJIT_IMM, kind);
kind = 0;
}
sljit_set_label(jump, sljit_emit_label(compiler));
sljit_emit_op2u(compiler, SLJIT_SUB | SLJIT_SET_SIG_LESS, SLJIT_TMP_DEST_REG, 0, SLJIT_IMM, kind);
sljit_emit_op_flags(compiler, SLJIT_MOV, args[1].arg, args[1].argw, SLJIT_SIG_LESS);
}
return;
}
Expand Down
22 changes: 18 additions & 4 deletions src/parser/WASMParser.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2594,6 +2594,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
case Walrus::Value::I31Ref:
case Walrus::Value::StructRef:
case Walrus::Value::ArrayRef:
case Walrus::Value::EqRef:
break;
case Walrus::Value::NoAnyRef:
case Walrus::Value::NoExternRef:
Expand Down Expand Up @@ -2693,8 +2694,17 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
virtual void OnGCUnaryExpr(int opcode) override
{
switch (opcode) {
case Opcode::RefEq:
case Opcode::RefEq: {
auto src1 = popVMStack();
auto src0 = popVMStack();
auto dst = computeExprResultPosition(Walrus::Value::Type::I32);
if (sizeof(void*) == 4) {
pushByteCode(Walrus::I32Eq(src0, src1, dst), WASMOpcode::RefEqOpcode);
} else {
pushByteCode(Walrus::I64Eq(src0, src1, dst), WASMOpcode::RefEqOpcode);
}
break;
}
case Opcode::ArrayLen: {
bool isNullable = Walrus::Value::isNullableRefType(peekVMStackInfo().valueType());
auto src = popVMStack();
Expand All @@ -2703,9 +2713,13 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
break;
}
case Opcode::AnyConvertExtern:
case Opcode::ExternConvertAny: {
Walrus::Value::Type type = (opcode == Opcode::AnyConvertExtern) ? Walrus::Value::Type::AnyRef : Walrus::Value::Type::ExternRef;
auto src = popVMStack();
auto dst = computeExprResultPosition(type);
generateMoveCodeIfNeeds(src, dst, type);
break;
case Opcode::ExternConvertAny:
break;
}
case Opcode::RefI31: {
auto src = popVMStack();
auto dst = computeExprResultPosition(Walrus::Value::Type::I31Ref);
Expand Down Expand Up @@ -3008,7 +3022,7 @@ class WASMBinaryReader : public wabt::WASMBinaryReaderDelegate {
}
if (m_shouldContinueToGenerateByteCode) {
for (size_t i = 0; i < m_currentFunctionType->result().size() && m_vmStack.size(); i++) {
ASSERT(popVMStackInfo().valueType() == m_currentFunctionType->result()[m_currentFunctionType->result().size() - i - 1]);
ASSERT(toDebugType(popVMStackInfo().valueType()) == toDebugType(m_currentFunctionType->result()[m_currentFunctionType->result().size() - i - 1]));
}
ASSERT(m_vmStack.empty());
}
Expand Down
168 changes: 168 additions & 0 deletions test/extended/gc/ref_eq.wast
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
(module
(type $st (sub (struct)))
(type $st' (sub (struct (field i32))))
(type $at (array i8))
(type $st-sub1 (sub $st (struct)))
(type $st-sub2 (sub $st (struct)))
(type $st'-sub1 (sub $st' (struct (field i32))))
(type $st'-sub2 (sub $st' (struct (field i32))))

(table 20 (ref null eq))

(func (export "init")
(table.set (i32.const 0) (ref.null eq))
(table.set (i32.const 1) (ref.null i31))
(table.set (i32.const 2) (ref.i31 (i32.const 7)))
(table.set (i32.const 3) (ref.i31 (i32.const 7)))
(table.set (i32.const 4) (ref.i31 (i32.const 8)))
(table.set (i32.const 5) (struct.new_default $st))
(table.set (i32.const 6) (struct.new_default $st))
(table.set (i32.const 7) (array.new_default $at (i32.const 0)))
(table.set (i32.const 8) (array.new_default $at (i32.const 0)))
)

(func (export "eq") (param $i i32) (param $j i32) (result i32)
(ref.eq (table.get (local.get $i)) (table.get (local.get $j)))
)
)

(invoke "init")

(assert_return (invoke "eq" (i32.const 0) (i32.const 0)) (i32.const 1))
(assert_return (invoke "eq" (i32.const 0) (i32.const 1)) (i32.const 1))
(assert_return (invoke "eq" (i32.const 0) (i32.const 2)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 0) (i32.const 3)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 0) (i32.const 4)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 0) (i32.const 5)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 0) (i32.const 6)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 0) (i32.const 7)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 0) (i32.const 8)) (i32.const 0))

(assert_return (invoke "eq" (i32.const 1) (i32.const 0)) (i32.const 1))
(assert_return (invoke "eq" (i32.const 1) (i32.const 1)) (i32.const 1))
(assert_return (invoke "eq" (i32.const 1) (i32.const 2)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 1) (i32.const 3)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 1) (i32.const 4)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 1) (i32.const 5)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 1) (i32.const 6)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 1) (i32.const 7)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 1) (i32.const 8)) (i32.const 0))

(assert_return (invoke "eq" (i32.const 2) (i32.const 0)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 2) (i32.const 1)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 2) (i32.const 2)) (i32.const 1))
(assert_return (invoke "eq" (i32.const 2) (i32.const 3)) (i32.const 1))
(assert_return (invoke "eq" (i32.const 2) (i32.const 4)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 2) (i32.const 5)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 2) (i32.const 6)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 2) (i32.const 7)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 2) (i32.const 8)) (i32.const 0))

(assert_return (invoke "eq" (i32.const 3) (i32.const 0)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 3) (i32.const 1)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 3) (i32.const 2)) (i32.const 1))
(assert_return (invoke "eq" (i32.const 3) (i32.const 3)) (i32.const 1))
(assert_return (invoke "eq" (i32.const 3) (i32.const 4)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 3) (i32.const 5)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 3) (i32.const 6)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 3) (i32.const 7)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 3) (i32.const 8)) (i32.const 0))

(assert_return (invoke "eq" (i32.const 4) (i32.const 0)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 4) (i32.const 1)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 4) (i32.const 2)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 4) (i32.const 3)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 4) (i32.const 4)) (i32.const 1))
(assert_return (invoke "eq" (i32.const 4) (i32.const 5)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 4) (i32.const 6)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 4) (i32.const 7)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 4) (i32.const 8)) (i32.const 0))

(assert_return (invoke "eq" (i32.const 5) (i32.const 0)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 5) (i32.const 1)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 5) (i32.const 2)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 5) (i32.const 3)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 5) (i32.const 4)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 5) (i32.const 5)) (i32.const 1))
(assert_return (invoke "eq" (i32.const 5) (i32.const 6)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 5) (i32.const 7)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 5) (i32.const 8)) (i32.const 0))

(assert_return (invoke "eq" (i32.const 6) (i32.const 0)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 6) (i32.const 1)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 6) (i32.const 2)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 6) (i32.const 3)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 6) (i32.const 4)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 6) (i32.const 5)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 6) (i32.const 6)) (i32.const 1))
(assert_return (invoke "eq" (i32.const 6) (i32.const 7)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 6) (i32.const 8)) (i32.const 0))

(assert_return (invoke "eq" (i32.const 7) (i32.const 0)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 7) (i32.const 1)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 7) (i32.const 2)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 7) (i32.const 3)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 7) (i32.const 4)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 7) (i32.const 5)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 7) (i32.const 6)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 7) (i32.const 7)) (i32.const 1))
(assert_return (invoke "eq" (i32.const 7) (i32.const 8)) (i32.const 0))

(assert_return (invoke "eq" (i32.const 8) (i32.const 0)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 8) (i32.const 1)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 8) (i32.const 2)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 8) (i32.const 3)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 8) (i32.const 4)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 8) (i32.const 5)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 8) (i32.const 6)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 8) (i32.const 7)) (i32.const 0))
(assert_return (invoke "eq" (i32.const 8) (i32.const 8)) (i32.const 1))

(assert_invalid
(module
(func (export "eq") (param $r (ref any)) (result i32)
(ref.eq (local.get $r) (local.get $r))
)
)
"type mismatch"
)
(assert_invalid
(module
(func (export "eq") (param $r (ref null any)) (result i32)
(ref.eq (local.get $r) (local.get $r))
)
)
"type mismatch"
)
(assert_invalid
(module
(func (export "eq") (param $r (ref func)) (result i32)
(ref.eq (local.get $r) (local.get $r))
)
)
"type mismatch"
)
(assert_invalid
(module
(func (export "eq") (param $r (ref null func)) (result i32)
(ref.eq (local.get $r) (local.get $r))
)
)
"type mismatch"
)
(assert_invalid
(module
(func (export "eq") (param $r (ref extern)) (result i32)
(ref.eq (local.get $r) (local.get $r))
)
)
"type mismatch"
)
(assert_invalid
(module
(func (export "eq") (param $r (ref null extern)) (result i32)
(ref.eq (local.get $r) (local.get $r))
)
)
"type mismatch"
)
Loading
Loading