Skip to content

Commit

Permalink
Add partial support for -fwasm-exceptions in Asyncify (WebAssembly#5343)
Browse files Browse the repository at this point in the history
  • Loading branch information
caiiiycuk committed Feb 2, 2023
1 parent b5476c6 commit ad85db6
Showing 1 changed file with 277 additions and 5 deletions.
282 changes: 277 additions & 5 deletions src/passes/Asyncify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,11 @@
// Overall, this should allow good performance with small overhead that is
// mostly noticed at rewind time.
//
// Exceptions handling (-fwasm-exceptions) is partially supported. Asyncify
// can't start unwind operation from a catch block. When assertions mode is
// enabled this pass will check if unwind called from catch block or not, and
// if so throw an unreachable exception.
//
// After this pass is run a new i32 global "__asyncify_state" is added, which
// has the following values:
//
Expand Down Expand Up @@ -163,6 +168,10 @@
// calls, so that you know when to start an asynchronous operation and
// when to propagate results back.
//
// * asyncify_get_catch_counter(): call this to get the current value of the
// internal "__asyncify_catch_counter" variable (only when assertions
// enabled).
//
// These four functions are exported so that you can call them from the
// outside. If you want to manage things from inside the wasm, then you
// couldn't have called them before they were created by this pass. To work
Expand Down Expand Up @@ -305,7 +314,9 @@

#include "asmjs/shared-constants.h"
#include "cfg/liveness-traversal.h"
#include "ir/branch-utils.h"
#include "ir/effects.h"
#include "ir/eh-utils.h"
#include "ir/find_all.h"
#include "ir/linear-execution.h"
#include "ir/literal-utils.h"
Expand All @@ -325,6 +336,8 @@ namespace {

static const Name ASYNCIFY_STATE = "__asyncify_state";
static const Name ASYNCIFY_GET_STATE = "asyncify_get_state";
static const Name ASYNCIFY_CATCH_COUNTER = "__asyncify_catch_counter";
static const Name ASYNCIFY_GET_CATCH_COUNTER = "asyncify_get_catch_counter";
static const Name ASYNCIFY_DATA = "__asyncify_data";
static const Name ASYNCIFY_START_UNWIND = "asyncify_start_unwind";
static const Name ASYNCIFY_STOP_UNWIND = "asyncify_stop_unwind";
Expand Down Expand Up @@ -1134,6 +1147,16 @@ struct AsyncifyFlow : public Pass {
} else if (doesCall(curr)) {
results.push_back(makeCallSupport(curr));
continue;
} else if (auto* iTry = curr->dynCast<Try>()) {
if (item.phase == Work::Scan) {
work.push_back(Work{curr, Work::Finish});
work.push_back(Work{iTry->body, Work::Scan});
continue;
}
iTry->body = results.back();
results.pop_back();
results.push_back(iTry);
continue;
}
// We must handle all control flow above, and all things that can change
// the state, so there should be nothing that can reach here - add it
Expand Down Expand Up @@ -1214,6 +1237,221 @@ struct AsyncifyFlow : public Pass {
}
};

// Add catch block counters to verify that unwind is not called from catch block
struct AsyncifyAddCatchCounters : public Pass {
bool isFunctionParallel() override { return true; }

std::unique_ptr<Pass> create() override {
return std::make_unique<AsyncifyAddCatchCounters>();
}

void runOnFunction(Module* module_, Function* func) override {
class CountersBuilder : public Builder {
public:
CountersBuilder(Module& wasm) : Builder(wasm) {}
Expression* makeInc(int amount = 1) {
return makeGlobalSet(
ASYNCIFY_CATCH_COUNTER,
makeBinary(AddInt32,
makeGlobalGet(ASYNCIFY_CATCH_COUNTER, Type::i32),
makeConst(int32_t(amount))));
}
Expression* makeDec(int amount = 1) {
return makeGlobalSet(
ASYNCIFY_CATCH_COUNTER,
makeBinary(SubInt32,
makeGlobalGet(ASYNCIFY_CATCH_COUNTER, Type::i32),
makeConst(int32_t(amount))));
};
};
CountersBuilder builder(*module_);
BranchUtils::BranchTargets branchTargets(func->body);

// with this walker we found level of "nesting" for each expression
// ... - +0
// catch
// ... - +1
// catch
// ... - +2
std::unordered_map<Expression*, int> expressionNestedLevel;
struct NestedLevelWalker
: public PostWalker<NestedLevelWalker,
UnifiedExpressionVisitor<NestedLevelWalker>> {
std::unordered_map<Expression*, int>* expressionNestedLevel;
int nestedLevel = 0;

static void doStartCatch(NestedLevelWalker* self, Expression** currp) {
self->nestedLevel++;
}

static void doEndCatch(NestedLevelWalker* self, Expression** currp) {
self->nestedLevel--;
}

static void scan(NestedLevelWalker* self, Expression** currp) {
auto curr = *currp;
if (curr->_id == Expression::Id::TryId) {
auto& catchBodies = curr->cast<Try>()->catchBodies;
for (Index i = 0; i < catchBodies.size(); i++) {
self->pushTask(doEndCatch, currp);
self->pushTask(NestedLevelWalker::scan, &catchBodies[i]);
self->pushTask(doStartCatch, currp);
}
self->pushTask(NestedLevelWalker::scan, &curr->cast<Try>()->body);
return;
}

PostWalker<NestedLevelWalker,
UnifiedExpressionVisitor<NestedLevelWalker>>::scan(self,
currp);
}

void visitExpression(Expression* curr) {
expressionNestedLevel->insert(std::make_pair<>(curr, nestedLevel));
}
};
NestedLevelWalker nestedLevelWalker;
nestedLevelWalker.expressionNestedLevel = &expressionNestedLevel;
nestedLevelWalker.walk(func->body);

// with this walker we handle those counters:
// - entering into catch (= pop) +1
// - return -N (nested catch count up to root)
// - break -N (nested catch count up to label)
// - exiting from catch -1
struct AddCountersWalker : public PostWalker<AddCountersWalker> {
Function* func;
CountersBuilder* builder;
BranchUtils::BranchTargets* branchTargets;
std::unordered_map<Expression*, int>* expressionNestedLevel;
int labelNum = 0;

// Each catch block except catch_all should have pop instruction
// We increment counter each time when pop happens (= entering catch
// block)
void visitPop(Pop* pop) {
replaceCurrent(builder->makeSequence(pop, builder->makeInc()));
}
void visitLocalSet(LocalSet* set) {
auto block = set->value->dynCast<Block>(); // from visitPop above
if (block) {
auto pop = block->list[0]->dynCast<Pop>();
if (pop) {
set->value = pop;
replaceCurrent(builder->makeSequence(set, builder->makeInc()));
}
}
}

// When return happens we decrement counter on amount of nested catch
// blocks up to root catch
// +1
// catch
// +1
// ...
// -2
// return
// ...
void visitReturn(Return* ret) {
auto it = expressionNestedLevel->find(ret);
assert(it != expressionNestedLevel->end());
auto nestedLevel = it->second;
if (nestedLevel > 0) {
replaceCurrent(
builder->makeSequence(builder->makeDec(nestedLevel), ret));
}
}

// When break happens we decrement counter on amount of nested catch
// blocks up to label
void visitBreak(Break* br) {
auto it = expressionNestedLevel->find(br);
assert(it != expressionNestedLevel->end());
auto nestedLevel = it->second;

Expression* target = branchTargets->getTarget(br->name);
assert(target != nullptr);

it = expressionNestedLevel->find(target);
assert(it != expressionNestedLevel->end());

auto amount = nestedLevel - it->second;
assert(amount >= 0);

if (amount > 0) {
if (br->condition == nullptr) {
replaceCurrent(builder->makeSequence(builder->makeDec(amount), br));
} else {
auto decIf = builder->makeIf(
br->condition,
builder->makeSequence(builder->makeDec(amount), br),
br->value);
br->condition = nullptr;
replaceCurrent(decIf);
}
}
}

// Replacing each catch block with try/finally and increase counter for
// catch_all blocks (not handled by visitPop); dec counter at the end
// of catch block
// try {fn}-finally-{label}
// +1
// {catch body}
// -1
// catch
// -1
// rethrow {fn}-finally-{label}
void visitTry(Try* curr) {
for (size_t i = 0; i < curr->catchBodies.size(); ++i) {
curr->catchBodies[i] =
addCatchCounters(curr->catchBodies[i], i == curr->catchTags.size());
}
}
Expression* addCatchCounters(Expression* expression, bool catchAll) {
// catch_all case is not covered by PopWalker
if (catchAll) {
auto block = expression->dynCast<Block>();
assert(block != nullptr);
block->list.insertAt(0, builder->makeInc());
}

// dec counters at the end of catch
if (expression->type == Type::none) {
if (auto block = expression->dynCast<Block>()) {
auto last = block->list[block->list.size() - 1];
if (!last->dynCast<Return>()) {
block->list.push_back(builder->makeDec());
block->finalize();
}
} else {
WASM_UNREACHABLE("Unexpected expression type");
}
}

auto name =
func->name.toString() + "-finally-" + std::to_string(++labelNum);
return builder->makeTry(
name,
expression,
{},
{builder->makeSequence(builder->makeDec(),
builder->makeRethrow(name))},
expression->type);
}
};

AddCountersWalker addCountersWalker;
addCountersWalker.func = func;
addCountersWalker.builder = &builder;
addCountersWalker.branchTargets = &branchTargets;
addCountersWalker.expressionNestedLevel = &expressionNestedLevel;
addCountersWalker.walk(func->body);

EHUtils::handleBlockNestedPops(func, *module_);
}
};

// Add asserts in non-instrumented code.
struct AsyncifyAssertInNonInstrumented : public Pass {
bool isFunctionParallel() override { return true; }
Expand Down Expand Up @@ -1692,7 +1930,7 @@ struct Asyncify : public Pass {
verbose);

// Add necessary globals before we emit code to use them.
addGlobals(module, relocatable);
addGlobals(module, relocatable, asserts);

// Instrument the flow of code, adding code instrumentation and
// skips for when rewinding. We do this on flat IR so that it is
Expand Down Expand Up @@ -1728,6 +1966,7 @@ struct Asyncify : public Pass {
// Add asserts in non-instrumented code. Note we do not use an
// instrumented pass runner here as we do want to run on all functions.
PassRunner runner(module);
runner.add(make_unique<AsyncifyAddCatchCounters>());
runner.add(make_unique<AsyncifyAssertInNonInstrumented>(
&analyzer, pointerType, asyncifyMemory));
runner.setIsNested(true);
Expand Down Expand Up @@ -1755,11 +1994,11 @@ struct Asyncify : public Pass {
}
// Finally, add function support (that should not have been seen by
// the previous passes).
addFunctions(module);
addFunctions(module, asserts);
}

private:
void addGlobals(Module* module, bool imported) {
void addGlobals(Module* module, bool imported, bool asserts) {
Builder builder(*module);

auto asyncifyState = builder.makeGlobal(ASYNCIFY_STATE,
Expand All @@ -1772,6 +2011,19 @@ struct Asyncify : public Pass {
}
module->addGlobal(std::move(asyncifyState));

if (asserts) {
auto asyncifyCatchCounter =
builder.makeGlobal(ASYNCIFY_CATCH_COUNTER,
Type::i32,
builder.makeConst(int32_t(0)),
Builder::Mutable);
if (imported) {
asyncifyCatchCounter->module = ENV;
asyncifyCatchCounter->base = ASYNCIFY_CATCH_COUNTER;
}
module->addGlobal(std::move(asyncifyCatchCounter));
}

auto asyncifyData = builder.makeGlobal(ASYNCIFY_DATA,
pointerType,
builder.makeConst(pointerType),
Expand All @@ -1783,14 +2035,23 @@ struct Asyncify : public Pass {
module->addGlobal(std::move(asyncifyData));
}

void addFunctions(Module* module) {
void addFunctions(Module* module, bool asserts) {
Builder builder(*module);
auto makeFunction = [&](Name name, bool setData, State state) {
auto* body = builder.makeBlock();
if (asserts && name == ASYNCIFY_START_UNWIND) {
auto* check = builder.makeIf(
builder.makeBinary(
NeInt32,
builder.makeGlobalGet(ASYNCIFY_CATCH_COUNTER, Type::i32),
builder.makeConst(int32_t(0))),
builder.makeUnreachable());
body->list.push_back(check);
}
std::vector<Type> params;
if (setData) {
params.push_back(pointerType);
}
auto* body = builder.makeBlock();
body->list.push_back(builder.makeGlobalSet(
ASYNCIFY_STATE, builder.makeConst(int32_t(state))));
if (setData) {
Expand Down Expand Up @@ -1838,6 +2099,17 @@ struct Asyncify : public Pass {
builder.makeGlobalGet(ASYNCIFY_STATE, Type::i32)));
module->addExport(builder.makeExport(
ASYNCIFY_GET_STATE, ASYNCIFY_GET_STATE, ExternalKind::Function));

if (asserts) {
module->addFunction(builder.makeFunction(
ASYNCIFY_GET_CATCH_COUNTER,
Signature(Type::none, Type::i32),
{},
builder.makeGlobalGet(ASYNCIFY_CATCH_COUNTER, Type::i32)));
module->addExport(builder.makeExport(ASYNCIFY_GET_CATCH_COUNTER,
ASYNCIFY_GET_CATCH_COUNTER,
ExternalKind::Function));
}
}

Name createSecondaryMemory(Module* module, Address secondaryMemorySize) {
Expand Down

0 comments on commit ad85db6

Please sign in to comment.