Skip to content

Commit

Permalink
refactoring; lit tests added
Browse files Browse the repository at this point in the history
  • Loading branch information
caiiiycuk committed Feb 8, 2023
1 parent 6f1e8b5 commit 6447962
Show file tree
Hide file tree
Showing 2 changed files with 1,139 additions and 79 deletions.
170 changes: 91 additions & 79 deletions src/passes/Asyncify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1274,32 +1274,37 @@ struct AsyncifyAddCatchCounters : public Pass {
CountersBuilder builder(*module_);
BranchUtils::BranchTargets branchTargets(func->body);

// with this walker we will find level of "nesting" for each expression
// ... - +0
// with this walker we will assign count of enclosing catch block to
// each expression
// ... - 0
// catch
// ... - +1
// ... - 1
// catch
// ... - +2
std::unordered_map<Expression*, int> expressionNestedLevel;
// ... - 2
std::unordered_map<Expression*, int> expressionCatchCount;
struct NestedLevelWalker
: public PostWalker<NestedLevelWalker,
UnifiedExpressionVisitor<NestedLevelWalker>> {
std::unordered_map<Expression*, int>* expressionNestedLevel;
int nestedLevel = 0;
std::unordered_map<Expression*, int>* expressionCatchCount;
int catchCount = 0;

static void doStartCatch(NestedLevelWalker* self, Expression** currp) {
self->nestedLevel++;
self->catchCount++;
}

static void doEndCatch(NestedLevelWalker* self, Expression** currp) {
self->nestedLevel--;
self->catchCount--;
}

static void scan(NestedLevelWalker* self, Expression** currp) {
auto curr = *currp;
if (curr->_id == Expression::Id::TryId) {
self->expressionCatchCount->insert(
std::make_pair<>(curr, self->catchCount));
auto& catchBodies = curr->cast<Try>()->catchBodies;
for (Index i = 0; i < catchBodies.size(); i++) {
self->expressionCatchCount->insert(
std::make_pair<>(catchBodies[i], self->catchCount));
self->pushTask(doEndCatch, currp);
self->pushTask(NestedLevelWalker::scan, &catchBodies[i]);
self->pushTask(doStartCatch, currp);
Expand All @@ -1314,146 +1319,153 @@ struct AsyncifyAddCatchCounters : public Pass {
}

void visitExpression(Expression* curr) {
expressionNestedLevel->insert(std::make_pair<>(curr, nestedLevel));
expressionCatchCount->insert(std::make_pair<>(curr, catchCount));
}
};
NestedLevelWalker nestedLevelWalker;
nestedLevelWalker.expressionNestedLevel = &expressionNestedLevel;
nestedLevelWalker.expressionCatchCount = &expressionCatchCount;
nestedLevelWalker.walk(func->body);

// with this walker we will handle those counters:
// with this walker we will handle those changes of counter:
// - entering into catch (= pop) +1
// - return -N (nested catch count up to root)
// - break -N (nested catch count up to label)
// - return -1
// - break -1
// - exiting from catch -1
struct AddCountersWalker : public PostWalker<AddCountersWalker> {
Function* func;
CountersBuilder* builder;
BranchUtils::BranchTargets* branchTargets;
std::unordered_map<Expression*, int>* expressionNestedLevel;
int labelNum = 0;
std::unordered_map<Expression*, int>* expressionCatchCount;
int finallyNum = 0;
int popNum = 0;

int getCatchCount(Expression* expression) {
auto it = expressionCatchCount->find(expression);
assert(it != expressionCatchCount->end());
return it->second;
}

// Each catch block except catch_all should have pop instruction
// We increment counter each time when pop happens (= entering catch
// block)
// We increment counter each time when we enter top-level catch block
void visitPop(Pop* pop) {
replaceCurrent(builder->makeSequence(pop, builder->makeInc()));
if (getCatchCount(pop) == 1) {
auto name =
func->name.toString() + "-pop-" + std::to_string(++popNum);
replaceCurrent(
builder->makeBlock(name, {pop, builder->makeInc()}, Type::none));
}
}
void visitLocalSet(LocalSet* set) {
auto block = set->value->dynCast<Block>(); // from visitPop above
if (block) {
if (block && block->name.hasSubstring("-pop-")) {
auto pop = block->list[0]->dynCast<Pop>();
if (pop) {
set->value = pop;
replaceCurrent(builder->makeSequence(set, builder->makeInc()));
}
assert(pop && getCatchCount(pop) == 1);
set->value = pop;
replaceCurrent(builder->makeBlock(
block->name, {set, builder->makeInc()}, Type::none));
}
}

// When return happens we decrement counter on amount of nested catch
// blocks up to root catch
// When return happens we decrement counter on 1, because we account
// only top-level catch blocks
// catch
// +1
// catch
// +1
// ...
// -2
// ;; not counted
// -1
// return
// ...
void visitReturn(Return* ret) {
auto it = expressionNestedLevel->find(ret);
assert(it != expressionNestedLevel->end());
auto nestedLevel = it->second;
if (nestedLevel > 0) {
replaceCurrent(
builder->makeSequence(builder->makeDec(nestedLevel), ret));
if (getCatchCount(ret) > 0) {
replaceCurrent(builder->makeSequence(builder->makeDec(), ret));
}
}

// When break happens we decrement counter on amount of nested catch
// blocks up to label
// When break happens we decrement counter only if it goes out
// from top-level catch block
void visitBreak(Break* br) {
auto it = expressionNestedLevel->find(br);
assert(it != expressionNestedLevel->end());
auto nestedLevel = it->second;

Expression* target = branchTargets->getTarget(br->name);
assert(target != nullptr);

it = expressionNestedLevel->find(target);
assert(it != expressionNestedLevel->end());

auto amount = nestedLevel - it->second;
assert(amount >= 0);

if (amount > 0) {
if (getCatchCount(br) > 0 && getCatchCount(target) == 0) {
if (br->condition == nullptr) {
replaceCurrent(builder->makeSequence(builder->makeDec(amount), br));
} else {
auto decIf = builder->makeIf(
br->condition,
builder->makeSequence(builder->makeDec(amount), br),
br->value);
replaceCurrent(builder->makeSequence(builder->makeDec(), br));
} else if (br->value == nullptr) {
auto decIf =
builder->makeIf(br->condition,
builder->makeSequence(builder->makeDec(), br),
nullptr);
br->condition = nullptr;
replaceCurrent(decIf);
} else {
Index newLocal = builder->addVar(func, br->value->type);
auto setLocal = builder->makeLocalSet(newLocal, br->value);
auto getLocal = builder->makeLocalGet(newLocal, br->value->type);
auto condition = br->condition;
br->condition = nullptr;
br->value = getLocal;
auto decIf =
builder->makeIf(condition,
builder->makeSequence(builder->makeDec(), br),
getLocal);
replaceCurrent(builder->makeSequence(setLocal, decIf));
}
}
}

// Replacing each catch block with try/finally and increase counter for
// catch_all blocks (not handled by visitPop); dec counter at the end
// of catch block
// try {fn}-finally-{label}
// Replacing each top-level catch block with try/catch_all(finally) and
// increase counter for catch_all blocks (not handled by visitPop); dec
// counter at the end of catch block try ({fn}-finally-{label})
// +1
// {catch body}
// -1
// catch
// catch_all
// -1
// rethrow {fn}-finally-{label}
void visitTry(Try* curr) {
for (size_t i = 0; i < curr->catchBodies.size(); ++i) {
curr->catchBodies[i] =
addCatchCounters(curr->catchBodies[i], i == curr->catchTags.size());
if (getCatchCount(curr) == 0) {
for (size_t i = 0; i < curr->catchBodies.size(); ++i) {
curr->catchBodies[i] = addCatchCounters(
curr->catchBodies[i], i == curr->catchTags.size());
}
}
}
Expression* addCatchCounters(Expression* expression, bool catchAll) {
// catch_all case is not covered by PopWalker
auto block = expression->dynCast<Block>();
if (block == nullptr) {
block = builder->makeBlock(expression);
}

// catch_all case is not covered by visitPop
if (catchAll) {
auto block = expression->dynCast<Block>();
assert(block != nullptr);
block->list.insertAt(0, builder->makeInc());
}

// dec counters at the end of catch
if (expression->type == Type::none) {
if (auto block = expression->dynCast<Block>()) {
auto last = block->list[block->list.size() - 1];
if (!last->dynCast<Return>()) {
block->list.push_back(builder->makeDec());
block->finalize();
}
} else {
WASM_UNREACHABLE("Unexpected expression type");
if (block->type == Type::none) {
auto last = block->list[block->list.size() - 1];
if (!last->dynCast<Return>()) {
block->list.push_back(builder->makeDec());
block->finalize();
}
}

auto name =
func->name.toString() + "-finally-" + std::to_string(++labelNum);
func->name.toString() + "-finally-" + std::to_string(++finallyNum);
return builder->makeTry(
name,
expression,
block,
{},
{builder->makeSequence(builder->makeDec(),
builder->makeRethrow(name))},
expression->type);
block->type);
}
};

AddCountersWalker addCountersWalker;
addCountersWalker.func = func;
addCountersWalker.builder = &builder;
addCountersWalker.branchTargets = &branchTargets;
addCountersWalker.expressionNestedLevel = &expressionNestedLevel;
addCountersWalker.expressionCatchCount = &expressionCatchCount;
addCountersWalker.walk(func->body);

EHUtils::handleBlockNestedPops(func, *module_);
Expand Down
Loading

0 comments on commit 6447962

Please sign in to comment.