diff --git a/src/passes/Asyncify.cpp b/src/passes/Asyncify.cpp index 7b2d57574db..b6bf6e2ae01 100644 --- a/src/passes/Asyncify.cpp +++ b/src/passes/Asyncify.cpp @@ -96,6 +96,11 @@ // Overall, this should allow good performance with small overhead that is // mostly noticed at rewind time. // +// Exceptions handling (-fwasm-exceptions) is partially supported. Asyncify +// can't start unwind operation from a catch block. When assertions mode is +// enabled this pass will check if unwind called from catch block or not, and +// if so throw an unreachable exception. +// // After this pass is run a new i32 global "__asyncify_state" is added, which // has the following values: // @@ -163,6 +168,10 @@ // calls, so that you know when to start an asynchronous operation and // when to propagate results back. // +// * asyncify_get_catch_counter(): call this to get the current value of the +// internal "__asyncify_catch_counter" variable (only when assertions +// enabled). +// // These four functions are exported so that you can call them from the // outside. If you want to manage things from inside the wasm, then you // couldn't have called them before they were created by this pass. To work @@ -305,7 +314,9 @@ #include "asmjs/shared-constants.h" #include "cfg/liveness-traversal.h" +#include "ir/branch-utils.h" #include "ir/effects.h" +#include "ir/eh-utils.h" #include "ir/find_all.h" #include "ir/linear-execution.h" #include "ir/literal-utils.h" @@ -325,6 +336,8 @@ namespace { static const Name ASYNCIFY_STATE = "__asyncify_state"; static const Name ASYNCIFY_GET_STATE = "asyncify_get_state"; +static const Name ASYNCIFY_CATCH_COUNTER = "__asyncify_catch_counter"; +static const Name ASYNCIFY_GET_CATCH_COUNTER = "asyncify_get_catch_counter"; static const Name ASYNCIFY_DATA = "__asyncify_data"; static const Name ASYNCIFY_START_UNWIND = "asyncify_start_unwind"; static const Name ASYNCIFY_STOP_UNWIND = "asyncify_stop_unwind"; @@ -1134,6 +1147,16 @@ struct AsyncifyFlow : public Pass { } else if (doesCall(curr)) { results.push_back(makeCallSupport(curr)); continue; + } else if (auto* iTry = curr->dynCast()) { + if (item.phase == Work::Scan) { + work.push_back(Work{curr, Work::Finish}); + work.push_back(Work{iTry->body, Work::Scan}); + continue; + } + iTry->body = results.back(); + results.pop_back(); + results.push_back(iTry); + continue; } // We must handle all control flow above, and all things that can change // the state, so there should be nothing that can reach here - add it @@ -1214,6 +1237,221 @@ struct AsyncifyFlow : public Pass { } }; +// Add catch block counters to verify that unwind is not called from catch block +struct AsyncifyAddCatchCounters : public Pass { + bool isFunctionParallel() override { return true; } + + std::unique_ptr create() override { + return std::make_unique(); + } + + void runOnFunction(Module* module_, Function* func) override { + class CountersBuilder : public Builder { + public: + CountersBuilder(Module& wasm) : Builder(wasm) {} + Expression* makeInc(int amount = 1) { + return makeGlobalSet( + ASYNCIFY_CATCH_COUNTER, + makeBinary(AddInt32, + makeGlobalGet(ASYNCIFY_CATCH_COUNTER, Type::i32), + makeConst(int32_t(amount)))); + } + Expression* makeDec(int amount = 1) { + return makeGlobalSet( + ASYNCIFY_CATCH_COUNTER, + makeBinary(SubInt32, + makeGlobalGet(ASYNCIFY_CATCH_COUNTER, Type::i32), + makeConst(int32_t(amount)))); + }; + }; + CountersBuilder builder(*module_); + BranchUtils::BranchTargets branchTargets(func->body); + + // with this walker we found level of "nesting" for each expression + // ... - +0 + // catch + // ... - +1 + // catch + // ... - +2 + std::unordered_map expressionNestedLevel; + struct NestedLevelWalker + : public PostWalker> { + std::unordered_map* expressionNestedLevel; + int nestedLevel = 0; + + static void doStartCatch(NestedLevelWalker* self, Expression** currp) { + self->nestedLevel++; + } + + static void doEndCatch(NestedLevelWalker* self, Expression** currp) { + self->nestedLevel--; + } + + static void scan(NestedLevelWalker* self, Expression** currp) { + auto curr = *currp; + if (curr->_id == Expression::Id::TryId) { + auto& catchBodies = curr->cast()->catchBodies; + for (Index i = 0; i < catchBodies.size(); i++) { + self->pushTask(doEndCatch, currp); + self->pushTask(NestedLevelWalker::scan, &catchBodies[i]); + self->pushTask(doStartCatch, currp); + } + self->pushTask(NestedLevelWalker::scan, &curr->cast()->body); + return; + } + + PostWalker>::scan(self, + currp); + } + + void visitExpression(Expression* curr) { + expressionNestedLevel->insert(std::make_pair<>(curr, nestedLevel)); + } + }; + NestedLevelWalker nestedLevelWalker; + nestedLevelWalker.expressionNestedLevel = &expressionNestedLevel; + nestedLevelWalker.walk(func->body); + + // with this walker we handle those counters: + // - entering into catch (= pop) +1 + // - return -N (nested catch count up to root) + // - break -N (nested catch count up to label) + // - exiting from catch -1 + struct AddCountersWalker : public PostWalker { + Function* func; + CountersBuilder* builder; + BranchUtils::BranchTargets* branchTargets; + std::unordered_map* expressionNestedLevel; + int labelNum = 0; + + // Each catch block except catch_all should have pop instruction + // We increment counter each time when pop happens (= entering catch + // block) + void visitPop(Pop* pop) { + replaceCurrent(builder->makeSequence(pop, builder->makeInc())); + } + void visitLocalSet(LocalSet* set) { + auto block = set->value->dynCast(); // from visitPop above + if (block) { + auto pop = block->list[0]->dynCast(); + if (pop) { + set->value = pop; + replaceCurrent(builder->makeSequence(set, builder->makeInc())); + } + } + } + + // When return happens we decrement counter on amount of nested catch + // blocks up to root catch + // +1 + // catch + // +1 + // ... + // -2 + // return + // ... + void visitReturn(Return* ret) { + auto it = expressionNestedLevel->find(ret); + assert(it != expressionNestedLevel->end()); + auto nestedLevel = it->second; + if (nestedLevel > 0) { + replaceCurrent( + builder->makeSequence(builder->makeDec(nestedLevel), ret)); + } + } + + // When break happens we decrement counter on amount of nested catch + // blocks up to label + void visitBreak(Break* br) { + auto it = expressionNestedLevel->find(br); + assert(it != expressionNestedLevel->end()); + auto nestedLevel = it->second; + + Expression* target = branchTargets->getTarget(br->name); + assert(target != nullptr); + + it = expressionNestedLevel->find(target); + assert(it != expressionNestedLevel->end()); + + auto amount = nestedLevel - it->second; + assert(amount >= 0); + + if (amount > 0) { + if (br->condition == nullptr) { + replaceCurrent(builder->makeSequence(builder->makeDec(amount), br)); + } else { + auto decIf = builder->makeIf( + br->condition, + builder->makeSequence(builder->makeDec(amount), br), + br->value); + br->condition = nullptr; + replaceCurrent(decIf); + } + } + } + + // Replacing each catch block with try/finally and increase counter for + // catch_all blocks (not handled by visitPop); dec counter at the end + // of catch block + // try {fn}-finally-{label} + // +1 + // {catch body} + // -1 + // catch + // -1 + // rethrow {fn}-finally-{label} + void visitTry(Try* curr) { + for (size_t i = 0; i < curr->catchBodies.size(); ++i) { + curr->catchBodies[i] = + addCatchCounters(curr->catchBodies[i], i == curr->catchTags.size()); + } + } + Expression* addCatchCounters(Expression* expression, bool catchAll) { + // catch_all case is not covered by PopWalker + if (catchAll) { + auto block = expression->dynCast(); + assert(block != nullptr); + block->list.insertAt(0, builder->makeInc()); + } + + // dec counters at the end of catch + if (expression->type == Type::none) { + if (auto block = expression->dynCast()) { + auto last = block->list[block->list.size() - 1]; + if (!last->dynCast()) { + block->list.push_back(builder->makeDec()); + block->finalize(); + } + } else { + WASM_UNREACHABLE("Unexpected expression type"); + } + } + + auto name = + func->name.toString() + "-finally-" + std::to_string(++labelNum); + return builder->makeTry( + name, + expression, + {}, + {builder->makeSequence(builder->makeDec(), + builder->makeRethrow(name))}, + expression->type); + } + }; + + AddCountersWalker addCountersWalker; + addCountersWalker.func = func; + addCountersWalker.builder = &builder; + addCountersWalker.branchTargets = &branchTargets; + addCountersWalker.expressionNestedLevel = &expressionNestedLevel; + addCountersWalker.walk(func->body); + + EHUtils::handleBlockNestedPops(func, *module_); + } +}; + // Add asserts in non-instrumented code. struct AsyncifyAssertInNonInstrumented : public Pass { bool isFunctionParallel() override { return true; } @@ -1692,7 +1930,7 @@ struct Asyncify : public Pass { verbose); // Add necessary globals before we emit code to use them. - addGlobals(module, relocatable); + addGlobals(module, relocatable, asserts); // Instrument the flow of code, adding code instrumentation and // skips for when rewinding. We do this on flat IR so that it is @@ -1728,6 +1966,7 @@ struct Asyncify : public Pass { // Add asserts in non-instrumented code. Note we do not use an // instrumented pass runner here as we do want to run on all functions. PassRunner runner(module); + runner.add(make_unique()); runner.add(make_unique( &analyzer, pointerType, asyncifyMemory)); runner.setIsNested(true); @@ -1755,11 +1994,11 @@ struct Asyncify : public Pass { } // Finally, add function support (that should not have been seen by // the previous passes). - addFunctions(module); + addFunctions(module, asserts); } private: - void addGlobals(Module* module, bool imported) { + void addGlobals(Module* module, bool imported, bool asserts) { Builder builder(*module); auto asyncifyState = builder.makeGlobal(ASYNCIFY_STATE, @@ -1772,6 +2011,19 @@ struct Asyncify : public Pass { } module->addGlobal(std::move(asyncifyState)); + if (asserts) { + auto asyncifyCatchCounter = + builder.makeGlobal(ASYNCIFY_CATCH_COUNTER, + Type::i32, + builder.makeConst(int32_t(0)), + Builder::Mutable); + if (imported) { + asyncifyCatchCounter->module = ENV; + asyncifyCatchCounter->base = ASYNCIFY_CATCH_COUNTER; + } + module->addGlobal(std::move(asyncifyCatchCounter)); + } + auto asyncifyData = builder.makeGlobal(ASYNCIFY_DATA, pointerType, builder.makeConst(pointerType), @@ -1783,14 +2035,23 @@ struct Asyncify : public Pass { module->addGlobal(std::move(asyncifyData)); } - void addFunctions(Module* module) { + void addFunctions(Module* module, bool asserts) { Builder builder(*module); auto makeFunction = [&](Name name, bool setData, State state) { + auto* body = builder.makeBlock(); + if (asserts && name == ASYNCIFY_START_UNWIND) { + auto* check = builder.makeIf( + builder.makeBinary( + NeInt32, + builder.makeGlobalGet(ASYNCIFY_CATCH_COUNTER, Type::i32), + builder.makeConst(int32_t(0))), + builder.makeUnreachable()); + body->list.push_back(check); + } std::vector params; if (setData) { params.push_back(pointerType); } - auto* body = builder.makeBlock(); body->list.push_back(builder.makeGlobalSet( ASYNCIFY_STATE, builder.makeConst(int32_t(state)))); if (setData) { @@ -1838,6 +2099,17 @@ struct Asyncify : public Pass { builder.makeGlobalGet(ASYNCIFY_STATE, Type::i32))); module->addExport(builder.makeExport( ASYNCIFY_GET_STATE, ASYNCIFY_GET_STATE, ExternalKind::Function)); + + if (asserts) { + module->addFunction(builder.makeFunction( + ASYNCIFY_GET_CATCH_COUNTER, + Signature(Type::none, Type::i32), + {}, + builder.makeGlobalGet(ASYNCIFY_CATCH_COUNTER, Type::i32))); + module->addExport(builder.makeExport(ASYNCIFY_GET_CATCH_COUNTER, + ASYNCIFY_GET_CATCH_COUNTER, + ExternalKind::Function)); + } } Name createSecondaryMemory(Module* module, Address secondaryMemorySize) {