diff --git a/include/raf/pass.h b/include/raf/pass.h
index 605c6d5f..e08a40fe 100644
--- a/include/raf/pass.h
+++ b/include/raf/pass.h
@@ -17,6 +17,7 @@
 #include "raf/value.h"
 #include "raf/ir_ext.h"
 #include "raf/pass_manager.h"
+#include "raf/sharding.h"
 
 namespace raf {
 namespace pass {
@@ -342,6 +343,28 @@ Pass WavefrontStreamSchedule();
  */
 Pass ASAPStreamSchedule();
 
+/*!
+ * \brief Set the ShardOpCallAttrs of the annotated Relay op calls.
+ *
+ * \return The created pass.
+ */
+Pass AnnotateShardOpCall(const ir::Map<ir::Expr, ir::Attrs>& attrs_map);
+
+/*!
+ * \brief Expand an op call carrying ShardOpCallAttrs into a series of expressions
+ * according to the corresponding expansion rule.
+ *
+ * \return The created pass.
+ */
+Pass ExpandShardOpCall();
+
+/*!
+ * \brief Infer the sharding specs of op calls and insert reshard op calls where needed.
+ *
+ * \return The created pass.
+ */
+Pass InferShardSpec();
+
 /*!
  * \brief This pass transforms BBNF into ANF and schedules operators to improve overlapping
  * between computation and communication.
diff --git a/include/raf/sharding.h b/include/raf/sharding.h
index b9c24147..1ca9c8ad 100644
--- a/include/raf/sharding.h
+++ b/include/raf/sharding.h
@@ -100,8 +100,8 @@ class ShardSpecObj final : public BaseShardSpecObj {
     v->Visit("ranks", &ranks);
     v->Visit("logic_shape", &logic_shape);
     v->Visit("logic_index", &logic_index_);
-    v->Visit("phy_shape", &logic_shape);
-    v->Visit("phy_index", &logic_index_);
+    v->Visit("phy_shape", &phy_shape);
+    v->Visit("phy_index", &phy_index_);
     v->Visit("subgroup_shape", &subgroup_shape);
     v->Visit("subgroup_index", &subgroup_index_);
   }
diff --git a/python/raf/distributed/sharding/__init__.py b/python/raf/distributed/sharding/__init__.py
index f86bc27a..19ebff03 100644
--- a/python/raf/distributed/sharding/__init__.py
+++ b/python/raf/distributed/sharding/__init__.py
@@ -5,3 +5,5 @@
 from raf._ffi.sharding._make import ShardOpCallAttrs
 from .shardspec import BaseShardSpec, ShardSpec, UnsetShardSpec
 from .utils import make_replicated_spec, make_shard_spec, make_unset_spec
+from .expandrule import expand_opcall
+from .inferhint import infer_shardspec
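
Note for reviewers: the three passes declared above (and the helpers exported through this package) are meant to be chained. The sketch below is illustrative only and not part of the patch; it mirrors the tests at the end of this diff, and the model, tensor shapes, and sharding choices are assumptions made for the example.

import numpy as np
import raf
from raf.distributed.sharding import ShardOpCallAttrs, make_shard_spec, make_unset_spec
from raf._ffi.pass_ import (
    AnnotateShardOpCall,
    ToGraphNormalForm,
    InferType,
    InferShardSpec,
    ExpandShardOpCall,
)
from raf._lib import relay
from tvm.relay.analysis.analysis import post_order_visit


class TwoOpModel(raf.Model):
    def build(self):
        pass

    @raf.model.trace
    def forward(self, x, y):
        return raf.relu(raf.add(x, y))


model = TwoOpModel()
m_x = raf.array(np.arange(16, dtype="float").reshape((4, 4)))
m_y = raf.array(np.zeros(16, dtype="float").reshape((4, 4)))
mod = InferType()(model._internal(m_x, m_y).mod)

# Collect the op calls so that individual calls can be annotated.
call_list = []
post_order_visit(
    mod["main"].body,
    lambda op: call_list.append(op) if isinstance(op, relay.Call) else None,
)

# Annotate the relu call: its input spec is left unset, its output is sharded over 4 ranks.
attrs_map = {
    call_list[1]: ShardOpCallAttrs(
        [make_unset_spec()], [make_shard_spec([4, 1], ranks=4, mutable=False)]
    )
}

# Annotate -> graph normal form -> type inference -> spec inference -> expansion.
mod = AnnotateShardOpCall(attrs_map)(mod)
mod = ToGraphNormalForm()(mod)
mod = InferType()(mod)
mod = InferShardSpec()(mod)
mod = InferType()(mod)
mod = ExpandShardOpCall()(mod)
print(raf._ffi.ir.AsText(mod))
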
diff --git a/python/raf/distributed/sharding/expandrule.py b/python/raf/distributed/sharding/expandrule.py
new file mode 100644
index 00000000..77f87f1c
--- /dev/null
+++ b/python/raf/distributed/sharding/expandrule.py
@@ -0,0 +1,283 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+# pylint: disable=invalid-name, unused-argument
+"""Implementation of Expansion Rules"""
+import functools
+from queue import PriorityQueue
+from typing import Callable, List
+import numpy as np
+import raf
+import tvm
+
+
+from raf._ffi.op import GetOp
+from raf._lib import _register_func, relay
+from raf.distributed.sharding import (
+    ShardSpec,
+    BaseShardSpec,
+    ShardOpCallAttrs,
+)
+from tvm.relay import Call, Expr
+from tvm.ir import Op
+from tvm.runtime.object import Object
+
+pattern_map = {
+    0: "kElemWise",
+    1: "kBroadcast",
+    2: "kInjective",
+    3: "kCommReduce",
+    4: "kOutEWiseFusable",
+    7: "kTuple",
+    8: "kOpaque",
+}
+# TODO: this pattern map is replicated multiple times in the source code
+
+
+class ShardInfo:
+    """Helper for parsing ShardSpec."""
+
+    # pylint: disable=too-few-public-methods
+    call: relay.Call
+    op: Op
+    args: List[Expr]
+    attrs: Object
+    sin: List[BaseShardSpec]
+    sout: List[BaseShardSpec]
+
+    def __init__(self, call: relay.Call):
+        assert isinstance(call, relay.Call)
+        self.call = call
+        self.op = call.op
+        self.args = call.args
+        self.attrs = call.attrs
+        self.sin = call.attrs.sin
+        self.sout = call.attrs.sout
+
+    def make_updated(self, op=None, args=None, sin=None, sout=None, attrs=None):
+        # pylint: disable=too-many-arguments
+        """Make a new ShardInfo based on this one with a few fields modified."""
+        op = op if op else self.op
+        args = args if args else self.args
+        if sin or sout:
+            sin = sin if sin else self.sin
+            sout = sout if sout else self.sout
+            attrs = ShardOpCallAttrs(sin, sout)
+        elif not attrs:
+            attrs = self.attrs
+        call = Call(op, args, attrs)
+        return ShardInfo(call)
+
+
+def all_satisfied(conds: List[Callable[[ShardInfo], bool]]):
+    """Return True only when all conditions are satisfied."""
+
+    def func(s: ShardInfo):
+        for c in conds:
+            if not c(s):
+                return False
+        return True
+
+    return func
+
+
+def is_exact_same_spec(*args):
+    """Check whether the given ShardSpecs are exactly the same."""
+    for e in args[1:]:
+        if not tvm.ir.structural_equal(args[0], e):
+            return False
+    return True
+
+
+def is_same_spec(*args):
+    """Check whether the given ShardSpecs are the same, ignoring the mutable attribute."""
+    if is_sharded(args[0]):
+        for e in args[1:]:
+            if not is_sharded(e):
+                return False
+            if not tvm.ir.structural_equal(args[0].ranks, e.ranks):
+                return False
+            if not tvm.ir.structural_equal(args[0].phy_shape, e.phy_shape):
+                return False
+            if not tvm.ir.structural_equal(args[0].subgroup_shape, e.subgroup_shape):
+                return False
+    else:
+        return is_exact_same_spec(*args)
+    return True
+
+
+def is_sharded(s: BaseShardSpec):
+    """Check whether it is a ShardSpec."""
+    return isinstance(s, ShardSpec)
+
+
+def is_replicated(s: BaseShardSpec):
+    """Check whether it is a replicated ShardSpec."""
+    if not isinstance(s, ShardSpec):
+        return False
+    return s.nshard == 1
+
+
+def no_subgroup(s: BaseShardSpec):
+    """Check whether the subgrouping feature is disabled."""
+    if not isinstance(s, ShardSpec):
+        return False
+    return s.ngroup == 1
+
+
+def always_apply(s: ShardInfo):
+    """Always return True."""
+    return True
+
+
+def expand_when(cond: Callable[[ShardInfo], bool], priority=1):
+    """Specify the priority and the condition under which this expansion rule should be used.
+
+    Parameters
+    ----------
+    cond : function(ShardInfo) -> bool
+        A function validating that this expansion rule is eligible to apply.
+    """
+
+    if not hasattr(expand_when, "counter"):
+        expand_when.counter = 0
+    if not hasattr(expand_when, "rules"):
+        expand_when.rules = {}
+
+    def decorator(pyfunc):
+        if not hasattr(pyfunc, "op_names"):
+            raise ValueError("Must register the expansion rule (register_expansion_rule) first")
+        for op_name in pyfunc.op_names:
+            op = GetOp(op_name)
+            if op not in expand_when.rules:
+                expand_when.rules[op] = PriorityQueue()
+            expand_when.rules[op].put((-priority, expand_when.counter, cond, pyfunc))
+            expand_when.counter += 1
+        return pyfunc
+
+    return decorator
+
+
+def register_expansion_rule(op_name):
+    """Register an expansion rule that converts a full-sized op into its partitioned counterpart.
+
+    Parameters
+    ----------
+    op_name: str or List[str]
+        Names of the ops to register.
+    """
+    op_names = [op_name] if isinstance(op_name, str) else op_name
+    assert isinstance(op_names, list)
+
+    def decorator(pyfunc):
+        @functools.wraps(pyfunc)
+        def new_pyfunc(call: relay.Call):
+            return pyfunc(call)
+
+        setattr(new_pyfunc, "op_names", op_names)
+        return new_pyfunc
+
+    return decorator
+
+
+@_register_func("raf.sharding._match_expansion_rule")
+def expand_opcall(call: relay.Call):
+    """Match an eligible expansion rule and return the expanded IR expression."""
+    rules = expand_when.rules[call.op]
+    s = ShardInfo(call)
+    for rule in rules.queue:
+        _, _, cond, irgen = rule
+        if cond(s):
+            return irgen(s)
+    return None
+
+
+@expand_when(
+    all_satisfied([lambda s: is_replicated(s.sin[0]), lambda s: is_sharded(s.sout[0])]),
+    priority=1,
+)
+@register_expansion_rule("raf.op._reshard")
+def reshard_replicated_to_sharded(s: ShardInfo):
+    """_reshard (R to S) -> strided_slice"""
+    begin, end = [], []
+    shape = s.args[0].checked_type.concrete_shape
+    spec = s.sout[0]
+    for idx, dim_nshard, dim_size in zip(spec.logic_index, spec.logic_shape, shape):
+        assert dim_size % dim_nshard == 0
+        begin.append(int((dim_size // dim_nshard) * idx))
+        end.append(int((dim_size // dim_nshard) * (idx + 1)))
+    return relay.Call(
+        GetOp("raf.op.strided_slice"),
+        [
+            s.args[0],
+            raf.ir.const(begin),
+            raf.ir.const(end),
+            raf.ir.const([1] * spec.ndim),
+            raf.ir.const("end"),
+        ],
+    )
+
+
+@expand_when(
+    all_satisfied(
+        [
+            lambda s: is_sharded(s.sin[0]),
+            lambda s: is_replicated(s.sout[0]),
+        ]
+    ),
+    priority=1,
+)
+@register_expansion_rule("raf.op._reshard")
+def reshard_sharded_to_replicated(s: ShardInfo):
+    """_reshard (S to R) -> allgather"""
+    spec = s.sin[0]
+    axis = []
+    full_shape = []
+    for i in range(spec.ndim):
+        if spec.logic_shape[i] > 1:
+            axis.append(i)
+        full_shape.append(int(spec.logic_shape[i]))
+        full_shape.append(int(spec.subgroup_shape[i]))
+    assert len(axis) == 1  # TODO: remove this constraint
+    ranks = np.array([int(e) for e in spec.ranks]).reshape(full_shape)
+    nshard_on_dim = int(spec.logic_shape[axis[0]])
+    rank_list = np.moveaxis(ranks, axis[0], -1).reshape(
+        (ranks.size // nshard_on_dim, nshard_on_dim)
+    )
+    return relay.Call(
+        GetOp("raf.op._allgather"),
+        [s.args[0], raf.ir.const(axis[0]), raf.ir.const(rank_list.tolist())],
+    )
+
+
+@expand_when(lambda s: is_same_spec(s.sin[0], s.sin[1], s.sout[0]))
+@register_expansion_rule(["raf.op.add", "raf.op.subtract"])
+def add_or_sub(s: ShardInfo):
+    """add/sub -> add/sub"""
+    return relay.Call(s.op, s.args)
+
+
+@expand_when(lambda s: is_same_spec(s.sin[0], s.sout[0]))
+@register_expansion_rule(["raf.op.relu"])  # TODO: should use a generated list instead
+def element_wise(s: ShardInfo):
+    """element-wise -> element-wise"""
+    return relay.Call(s.op, s.args)
+
+
+@expand_when(
+    all_satisfied(
+        [
+            lambda s: is_sharded(s.sin[0]) and is_sharded(s.sin[1]),
+            lambda s: no_subgroup(s.sin[0]) and no_subgroup(s.sin[1]),
+            lambda s: is_replicated(s.sout[0]),
+            lambda s: s.sin[0].logic_shape[1] == s.sin[1].logic_shape[0],
+        ]
+    )
+)
+@register_expansion_rule(["raf.op.matmul"])
+def matmul_algor1(s: ShardInfo):
+    """matmul -> matmul + allreduce"""
+    y_1 = relay.Call(s.op, s.args)
+    y_2 = tvm.relay.Tuple([y_1])
+    return relay.Call(GetOp("raf.op._allreduce"), [y_2, raf.ir.const("sum"), raf.ir.const(None)])
diff --git a/python/raf/distributed/sharding/inferhint.py b/python/raf/distributed/sharding/inferhint.py
new file mode 100644
index 00000000..a5593e9a
--- /dev/null
+++ b/python/raf/distributed/sharding/inferhint.py
@@ -0,0 +1,188 @@
+# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+# SPDX-License-Identifier: Apache-2.0
+
+# pylint: disable=invalid-name, unused-argument, missing-function-docstring
+"""Implementation of Infer Hints"""
+from queue import PriorityQueue
+from typing import Callable, List
+
+from raf._ffi.sharding._make import ShardOpCallAttrs
+from raf._ffi.op import GetOp
+from raf._lib import _register_func, relay
+from raf.distributed.sharding.shardspec import BaseShardSpec, UnsetShardSpec, ShardSpec
+from raf.distributed.sharding.utils import make_replicated_spec
+
+from .expandrule import (
+    ShardInfo,
+    always_apply,
+    expand_opcall,
+    is_exact_same_spec,
+    is_same_spec,
+    is_sharded,
+)
+from .expandrule import register_expansion_rule as register_infer_hint
+
+
+def try_when(cond: Callable[[ShardInfo], bool], priority=1):
+    """Specify the priority and the condition under which this infer hint should be used.
+
+    Parameters
+    ----------
+    cond : function(ShardInfo) -> bool
+        A function validating that this infer hint is eligible to apply.
+    """
+
+    if not hasattr(try_when, "counter"):
+        try_when.counter = 0
+    if not hasattr(try_when, "rules"):
+        try_when.rules = {}
+
+    def decorator(pyfunc):
+        if not hasattr(pyfunc, "op_names"):
+            raise ValueError("Must register the infer hint (register_infer_hint) first")
+        for op_name in pyfunc.op_names:
+            op = GetOp(op_name)
+            if op not in try_when.rules:
+                try_when.rules[op] = PriorityQueue()
+            try_when.rules[op].put((-priority, try_when.counter, cond, pyfunc))
+            try_when.counter += 1
+        return pyfunc
+
+    return decorator
+
+
+@_register_func("raf.sharding._infer_shardspec")
+def infer_shardspec(call: relay.Call):
+    # pylint: disable=too-many-locals, too-many-branches, too-many-statements
+    """Fill the placeholders of ShardSpec with infer hints."""
+    rules = try_when.rules[call.op]
+    s = ShardInfo(call)
+
+    # Step 1: Inherit the input specs from the previous outputs.
+
+    # inherit_sin holds the correct specs of the current inputs
+    inherit_sin = []
+    # specified_sin holds the user-specified specs with the unset shard specs filled in
+    specified_sin = []
+
+    for i in range(len(s.sin)):
+        if isinstance(s.args[i], relay.Call) and hasattr(s.args[i].attrs, "sin"):
+            # cannot use isinstance to check the type of the op call attrs,
+            # so directly inherit the ShardSpec
+            prev_sinfo = ShardInfo(s.args[i])
+            inherit_sin.append(prev_sinfo.sout[0])
+        else:
+            # the previous output isn't annotated with a ShardSpec
+            if isinstance(s.sin[i], ShardSpec):
+                # a specified ShardSpec already exists
+                inherit_sin.append(s.sin[i])
+            else:
+                # assume the previous output is replicated on all ranks
+                ndim = len(s.args[i].checked_type.concrete_shape)
+                inherit_sin.append(make_replicated_spec(ndim))
+
+        if isinstance(s.sin[i], UnsetShardSpec):
+            specified_sin.append(inherit_sin[-1])
+        else:
+            specified_sin.append(s.sin[i])
+
+    inherit_s = s.make_updated(sin=inherit_sin)
+    specified_s = s.make_updated(sin=specified_sin)
+
+    # Step 2: Match the infer hints.
+
+    filled_s_list: List[ShardInfo] = []  # TODO: try to remove duplicated solutions
+    for rule in rules.queue:
+        _, _, cond, irgen = rule
+        if cond(specified_s):
+            filled_s_list.extend([s.make_updated(attrs=a) for a in irgen(specified_s)])
+        if cond(inherit_s):
+            filled_s_list.extend([s.make_updated(attrs=a) for a in irgen(inherit_s)])
+
+    if not filled_s_list:
+        raise ValueError("Failed to match an InferHint")
+
+    # Step 3: Check that each solution is feasible.
+    ninputs = len(s.sin)
+    noutputs = len(s.sout)
+    immut_in_idx = [i for i in range(ninputs) if is_sharded(s.sin[i]) and not s.sin[i].mutable]
+    immut_out_idx = [i for i in range(noutputs) if is_sharded(s.sout[i]) and not s.sout[i].mutable]
+
+    possible_s_list: List[ShardInfo] = []
+    for filled_s in filled_s_list:
+        if not expand_opcall(filled_s.call):
+            # no expansion rule accepts this sharding solution
+            continue
+        immut_args = [(inherit_s.sin[i], filled_s.sin[i]) for i in immut_in_idx] + [
+            (inherit_s.sout[i], filled_s.sout[i]) for i in immut_out_idx
+        ]
+        for pair in immut_args:
+            if not is_same_spec(pair[0], pair[1]):
+                # violates the immutable attribute of the shard spec
+                break
+        else:
+            # reset the outputs to mutable so that the immutable flag does not
+            # spread to downstream op calls by mistake
+            sout = [
+                spec if spec.mutable else spec.make_updated(mutable=True) for spec in filled_s.sout
+            ]
+            possible_s_list.append(filled_s.make_updated(sout=sout))
+
+    # Step 4: Pick an op call with a full ShardSpec.
+    # TODO: should use a graph-searching algorithm with a cost map here.
+    # For now, always select the first solution.
+    inferred_s = possible_s_list[0]
+
+    # Step 5: Insert Reshard op calls.
+    resharded_args = []
+    for i in range(ninputs):
+        if is_same_spec(inherit_s.sin[i], inferred_s.sin[i]):
+            resharded_args.append(inferred_s.args[i])
+        else:
+            resharded_args.append(
+                relay.Call(
+                    GetOp("raf.op._reshard"),
+                    [inferred_s.args[i]],
+                    ShardOpCallAttrs([inherit_s.sin[i]], [inferred_s.sin[i]]),
+                )
+            )
+
+    print("[Sharding Infer] OpCall: %s" % s.op)
+    for phase in ("In", "Out"):
+        for i in range(ninputs if phase == "In" else noutputs):
+            if phase == "In":
+                a_spec, b_spec, c_spec = s.sin[i], inherit_s.sin[i], inferred_s.sin[i]
+            else:
+                a_spec, b_spec, c_spec = s.sout[i], inherit_s.sout[i], inferred_s.sout[i]
+            print(" %sArg %s: %s" % (phase, i, a_spec), end="")
+            if not is_exact_same_spec(a_spec, b_spec):
+                print(" -> %s" % b_spec, end="")
+            if not is_exact_same_spec(b_spec, c_spec):
+                print(" -> %s" % c_spec, end="")
+            print()
+
+    return relay.Call(inferred_s.op, resharded_args, inferred_s.attrs)
+
+
+def is_unset(s: BaseShardSpec):
+    """Check whether it is an UnsetShardSpec (a placeholder for ShardSpec)."""
+    return isinstance(s, UnsetShardSpec)
+
+
+@try_when(always_apply)
+@register_infer_hint(["raf.op.add", "raf.op.subtract"])
+def element_wise_op_with_2in_1out(s: ShardInfo) -> List[ShardOpCallAttrs]:
+    specs = []
+    for e in (s.sin[0], s.sin[1], s.sout[0]):
+        if not is_unset(e):
+            specs.append(e)
+    return [ShardOpCallAttrs([e, e], [e]) for e in specs]
+
+
+@try_when(always_apply)
+@register_infer_hint(["raf.op.relu"])
+def element_wise_op_with_1in_1out(s: ShardInfo) -> List[ShardOpCallAttrs]:
+    specs = []
+    for e in (s.sin[0], s.sout[0]):
+        if not is_unset(e):
+            specs.append(e)
+    return [ShardOpCallAttrs([e], [e]) for e in specs]
diff --git a/python/raf/distributed/sharding/shardspec.py b/python/raf/distributed/sharding/shardspec.py
index 60872ad2..b3abd90f 100644
--- a/python/raf/distributed/sharding/shardspec.py
+++ b/python/raf/distributed/sharding/shardspec.py
@@ -35,6 +35,14 @@ def __init__(self, ranks, phy_shape, subgroup_shape, mutable):
             _make.ShardSpec, ranks, phy_shape, subgroup_shape, mutable
         )
 
+    def make_updated(self, ranks=None, phy_shape=None, subgroup_shape=None, mutable=None):
+        """Make a new spec based on this spec with a few fields modified."""
+        ranks = ranks if ranks else self.ranks
+        phy_shape = phy_shape if phy_shape else self.phy_shape
+        subgroup_shape = subgroup_shape if subgroup_shape else self.subgroup_shape
+        mutable = mutable if mutable is not None else self.mutable
+        return ShardSpec(ranks, phy_shape, subgroup_shape, mutable)
+
 
 @register_node("raf.sharding.UnsetShardSpec")
 class UnsetShardSpec(BaseShardSpec):
diff --git a/src/impl/sharding.cc b/src/impl/sharding.cc
index 5a62c8cc..d97854c7 100644
--- a/src/impl/sharding.cc
+++ b/src/impl/sharding.cc
@@ -170,7 +170,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       auto r = Downcast<ShardSpec>(ref);
       auto ndim = r->ndim_;
       if (r->nshard_ == 1) {
-        p->stream << "ShardSpec(Replicated)";
+        p->stream << "ShardSpec(Replicated, " << (r->mutable_ ? "Mut)" : "Immut)");
       } else {
         p->stream << "ShardSpec("
                   << "[";
@@ -179,9 +179,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
          auto ngroup_on_dim = r->subgroup_shape[i]->value;
          p->stream << (nshard_on_dim == 1 ? ":" : std::to_string(nshard_on_dim))
                    << (ngroup_on_dim == 1 ? "" : "(x" + std::to_string(ngroup_on_dim) + ")")
-                   << (i != ndim - 1 ? ", " : "");
+                   << (i != ndim - 1 ? ", " : "], ");
        }
-       p->stream << "])";
+       p->stream << (r->mutable_ ? "Mut)" : "Immut)");
      }
    });
 
diff --git a/src/pass/sharding.cc b/src/pass/sharding.cc
new file mode 100644
index 00000000..4f19ce34
--- /dev/null
+++ b/src/pass/sharding.cc
@@ -0,0 +1,150 @@
+/*
+ * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+/*!
+ * Copyright (c) 2021 by Contributors
+ * \file sharding.cc
+ * \brief Sharding-related passes (C++ side)
+ */
+#include
+#include "raf/op.h"
+#include "raf/ir.h"
+#include "raf/pass.h"
+#include "raf/sharding.h"
+#include
+#include
+
+namespace raf {
+namespace pass {
+
+using namespace raf::ir;
+using namespace raf::op;
+using namespace raf::value;
+using namespace raf::sharding;
+
+namespace shard_pass {
+
+class ShardOpCallAttrsSetter : public ExprMutator {
+ public:
+  explicit ShardOpCallAttrsSetter(const Map<Expr, Attrs>& attrs_map) : _attrs_map(attrs_map) {
+  }
+
+  Expr VisitExpr_(const CallNode* node) override {
+    Call call = Downcast<Call>(ExprMutator::VisitExpr_(node));
+    const Expr& op = call->op;
+    if (op->IsInstance<OpNode>()) {
+      if (_attrs_map.count(call)) {
+        return Call(node->op, node->args, Attrs(_attrs_map[call]));
+      }
+    }
+    return call;
+  }
+
+ private:
+  const Map<Expr, Attrs>& _attrs_map;
+};
+
+class ShardOpCallExpander : public ExprMutator {
+ public:
+  Expr VisitExpr_(const FunctionNode* node) override {
+    // remove the inferred function return type, as the IR has changed
+    Expr new_body = VisitExpr(node->body);
+    return Function(node->params, new_body, {}, {});
+  }
+
+  Expr VisitExpr_(const CallNode* node) override {
+    Call call = GetRef<Call>(node);
+    const Expr& op = call->op;
+    const Attrs& attrs = call->attrs;
+    const auto* f = tvm::runtime::Registry::Get("raf.sharding._match_expansion_rule");
+
+    if (attrs.defined() && op->IsInstance<OpNode>() && attrs->IsInstance<ShardOpCallAttrs>()) {
+      Call new_opcall = (*f)(call);
+      return ExprMutator::VisitExpr_(new_opcall.as<CallNode>());
+    }
+
+    return ExprMutator::VisitExpr_(node);
+  }
+};
+
+class ShardSpecPropagator : public ExprMutator {
+ public:
+  Expr VisitExpr_(const CallNode* node) override {
+    Call call = Downcast<Call>(ExprMutator::VisitExpr_(node));
+    const Expr& op = call->op;
+    const Attrs& attrs = call->attrs;
+    const Array<Expr>& args = call->args;
+    const auto* f = tvm::runtime::Registry::Get("raf.sharding._infer_shardspec");
+
+    if (attrs.defined() && op->IsInstance<OpNode>() && attrs->IsInstance<ShardOpCallAttrs>()) {
+      Call new_opcall = (*f)(call);
+      return new_opcall;
+    }
+
+    return call;
+  }
+};
+
+}  // namespace shard_pass
+
+Pass AnnotateShardOpCall(const Map<Expr, Attrs>& attrs_map) {
+  return CreateModulePass(
+      [=](IRModule mod, const PassContext& pass_ctx) {
+        DLOG(INFO) << "pass::AnnotateShardOpCall";
+        IRModule updated_mod = IRModule(mod->functions);
+        for (auto kv : updated_mod->functions) {
+          if (kv.second.as<FunctionNode>()) {
+            auto setter = shard_pass::ShardOpCallAttrsSetter(attrs_map);
+            auto func = tvm::runtime::Downcast<Function>(setter.VisitExpr(kv.second));
+            updated_mod->Add(kv.first, func, true);
+          }
+        }
+        return updated_mod;
+      },
+      0, "AnnotateShardOpCall", {});
+}
+
+RAF_REGISTER_GLOBAL("raf.pass_.AnnotateShardOpCall").set_body_typed(AnnotateShardOpCall);
+
+Pass ExpandShardOpCall() {
+  return CreateModulePass(
+      [=](IRModule mod, const PassContext& pass_ctx) {
+        DLOG(INFO) << "pass::ExpandShardOpCall";
+        IRModule updated_mod = IRModule(mod->functions);
+        for (auto kv : updated_mod->functions) {
+          if (kv.second.as<FunctionNode>()) {
+            auto setter = shard_pass::ShardOpCallExpander();
+            auto func = tvm::runtime::Downcast<Function>(setter.VisitExpr(kv.second));
+            updated_mod->Add(kv.first, func, true);
+          }
+        }
+        return updated_mod;
+      },
+      0, "ExpandShardOpCall", {});
+}
+
+RAF_REGISTER_GLOBAL("raf.pass_.ExpandShardOpCall").set_body_typed(ExpandShardOpCall);
+
+Pass InferShardSpec() {
+  return CreateModulePass(
+      [=](IRModule mod, const PassContext& pass_ctx) {
+        DLOG(INFO) << "pass::InferShardSpec";
+        IRModule updated_mod = IRModule(mod->functions);
+        for (auto kv : updated_mod->functions) {
+          if (kv.second.as<FunctionNode>()) {
+            auto setter = shard_pass::ShardSpecPropagator();
+            auto func = tvm::runtime::Downcast<Function>(setter.VisitExpr(kv.second));
+            updated_mod->Add(kv.first, func, true);
+          }
+        }
+        return updated_mod;
+      },
+      0, "InferShardSpec", {});
+}
+
+RAF_REGISTER_GLOBAL("raf.pass_.InferShardSpec").set_body_typed(InferShardSpec);
+
+}  // namespace pass
+}  // namespace raf
diff --git a/tests/python/pass/test_pass_sharding.py b/tests/python/pass/test_pass_sharding.py
index c50c316f..e465f50a 100644
--- a/tests/python/pass/test_pass_sharding.py
+++ b/tests/python/pass/test_pass_sharding.py
@@ -1,10 +1,22 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # SPDX-License-Identifier: Apache-2.0
 
-# pylint: disable=missing-function-docstring, missing-class-docstring, invalid-name, protected-access
+# pylint: disable=missing-function-docstring, missing-class-docstring, invalid-name, protected-access, no-self-use, too-many-locals
+import numpy as np
 import pytest
+import raf
+from raf.distributed.sharding import ShardOpCallAttrs
+from raf._ffi.pass_ import (
+    AnnotateShardOpCall,
+    ToGraphNormalForm,
+    ExpandShardOpCall,
+    InferType,
+    InferShardSpec,
+)
+from raf._lib import relay
 from raf.distributed.sharding import make_replicated_spec, make_shard_spec, make_unset_spec
 from tvm.ir import structural_equal
+from tvm.relay.analysis.analysis import post_order_visit
 
 
 def test_shardspec():
@@ -32,5 +44,97 @@ def test_shardspec():
     assert not structural_equal(a, i)
 
 
+def test_infer_hint_without_prev_spec():
+    class Model(raf.Model):
+        def build(self):
+            pass
+
+        @raf.model.trace
+        def forward(self, x, y):
+            z = raf.add(x, y)
+            a = raf.relu(z)
+            b = raf.relu(a)
+            return b
+
+    model = Model()
+    m_x = raf.array(np.arange(16, dtype="float").reshape((4, 4)))
+    m_y = raf.array(np.zeros(16, dtype="float").reshape((4, 4)))
+    record = model._internal(m_x, m_y)
+    mod_before = record.mod
+    mod_before = InferType()(mod_before)
+
+    call_list = []
+    post_order_visit(
+        mod_before["main"].body,
+        lambda op: call_list.append(op) if isinstance(op, relay.Call) else None,
+    )
+
+    attrs_map = {
+        call_list[1]: ShardOpCallAttrs(
+            [make_unset_spec()], [make_shard_spec([4, 1], ranks=4, mutable=False)]
+        ),
+        call_list[2]: ShardOpCallAttrs(
+            [make_unset_spec()], [make_replicated_spec(2, mutable=False)]
+        ),
+    }
+
+    mod0 = AnnotateShardOpCall(attrs_map)(mod_before)
+    mod1 = ToGraphNormalForm()(mod0)
+    mod2 = InferType()(mod1)
+    mod3 = InferShardSpec()(mod2)
+    mod4 = InferType()(mod3)
+    mod5 = ExpandShardOpCall()(mod4)
+    print("after expand shard opcall")
+    print(raf._ffi.ir.AsText(mod5))
+
+
+def test_infer_hint_inserting_reshard():
+    class Model(raf.Model):
+        def build(self):
+            pass
+
+        @raf.model.trace
+        def forward(self, x, y):
+            z = raf.add(x, y)
+            a = raf.relu(z)
+            b = raf.relu(a)
+            return b
+
+    model = Model()
+    m_x = raf.array(np.arange(16, dtype="float").reshape((4, 4)))
+    m_y = raf.array(np.zeros(16, dtype="float").reshape((4, 4)))
+    record = model._internal(m_x, m_y)
+    mod_before = record.mod
+    mod_before = InferType()(mod_before)
+
+    print(m_x)
+    call_list = []
+    post_order_visit(
+        mod_before["main"].body,
+        lambda op: call_list.append(op) if isinstance(op, relay.Call) else None,
+    )
+
+    spec = make_shard_spec([2, 2], [1, 2], 4, mutable=False)
+
+    attrs_map = {
+        call_list[0]: ShardOpCallAttrs([make_unset_spec(), make_unset_spec()], [make_unset_spec()]),
+        call_list[1]: ShardOpCallAttrs([make_unset_spec()], [spec]),
+    }
+
+    mod0 = AnnotateShardOpCall(attrs_map)(mod_before)
+    mod1 = ToGraphNormalForm()(mod0)
+    mod2 = InferType()(mod1)
+    mod3 = InferShardSpec()(mod2)
+    mod4 = InferType()(mod3)
+    print("after infer type")
+    print(raf._ffi.ir.AsText(mod4))
+    mod5 = ExpandShardOpCall()(mod4)
+    print("after expand shard opcall")
+    print(raf._ffi.ir.AsText(mod5))
+    mod6 = InferType()(mod5)
+    print("after infer type2")
+    print(raf._ffi.ir.AsText(mod6))
+
+
 if __name__ == "__main__":
     pytest.main([__file__])