
[Distributed][auto-parallel] Sharding Specification and rule discovery #336

Merged: 54 commits, merged Aug 4, 2023. Changes shown are from 23 commits.

Commits (54):
a491cb3  [CI] List package versions in ci (#334)  (yaoyaoding, Jul 29, 2023)
0bc0f2d  [CI] List package versions in ci (#334)  (yaoyaoding, Jul 29, 2023)
8a3cdec  update  (soodoshll, Jul 30, 2023)
6a064bf  update  (soodoshll, Jul 30, 2023)
78fbf83  update  (soodoshll, Jul 30, 2023)
924d9d4  update  (soodoshll, Jul 30, 2023)
317d7b4  update  (soodoshll, Jul 30, 2023)
8748c2e  update  (soodoshll, Jul 30, 2023)
28dcec9  update  (soodoshll, Jul 31, 2023)
0532154  [Models] Llama2 fix (#333)  (Aalanli, Jul 31, 2023)
be37b1f  update  (soodoshll, Jul 31, 2023)
6790b82  format  (soodoshll, Jul 31, 2023)
bb9e1cb  [Models] Llama2 fix (#333)  (Aalanli, Jul 31, 2023)
09f340e  update  (soodoshll, Jul 31, 2023)
3a37a1e  update  (soodoshll, Jul 31, 2023)
33bcff8  fix  (soodoshll, Jul 31, 2023)
138f007  update  (soodoshll, Jul 31, 2023)
4eaa65c  update  (soodoshll, Aug 1, 2023)
1eb41ad  [Operator] Composite Elementwise Operation (#337)  (hjjq, Aug 1, 2023)
a395fb1  [Fixbug] Clear the intermediate object files for kernel tuning (#339)  (yaoyaoding, Aug 2, 2023)
7e798e1  update  (soodoshll, Aug 3, 2023)
476acaa  format  (soodoshll, Aug 3, 2023)
42b3d0f  update  (soodoshll, Aug 3, 2023)
56197c9  [Hidet script] Add `hidet.lang.types` submodule (#340)  (yaoyaoding, Aug 3, 2023)
3213abb  refactor  (soodoshll, Aug 3, 2023)
e9bfa99  [CPU][Scheduler] Use mutli-threads for autl-scheduler (#341)  (yaoyaoding, Aug 3, 2023)
d50f78a  fix  (soodoshll, Aug 4, 2023)
1cf263f  update  (soodoshll, Aug 4, 2023)
bb18a18  update  (soodoshll, Aug 4, 2023)
0b61aec  shard spec data structure  (soodoshll, Jul 27, 2023)
012748d  update  (soodoshll, Jul 30, 2023)
dbf6edf  update  (soodoshll, Jul 30, 2023)
6b2e4a5  update  (soodoshll, Jul 30, 2023)
f28c935  update  (soodoshll, Jul 30, 2023)
fcd1381  update  (soodoshll, Jul 30, 2023)
d00e34d  update  (soodoshll, Jul 30, 2023)
8a36b44  update  (soodoshll, Jul 31, 2023)
6d2ee5a  update  (soodoshll, Jul 31, 2023)
bef294b  format  (soodoshll, Jul 31, 2023)
2e68239  update  (soodoshll, Jul 31, 2023)
600d643  update  (soodoshll, Jul 31, 2023)
2f8aca7  fix  (soodoshll, Jul 31, 2023)
091b4d5  update  (soodoshll, Jul 31, 2023)
a513d6a  update  (soodoshll, Aug 1, 2023)
c58effe  update  (soodoshll, Aug 3, 2023)
2152e8e  format  (soodoshll, Aug 3, 2023)
87443c8  update  (soodoshll, Aug 3, 2023)
5c69558  refactor  (soodoshll, Aug 3, 2023)
de4daef  fix  (soodoshll, Aug 4, 2023)
16aecb2  update  (soodoshll, Aug 4, 2023)
c47a41c  update  (soodoshll, Aug 4, 2023)
36611a4  fix  (soodoshll, Aug 4, 2023)
7d529b5  fix  (soodoshll, Aug 4, 2023)
0e860be  merge  (soodoshll, Aug 4, 2023)
4 changes: 4 additions & 0 deletions .github/workflows/tests.yaml
@@ -52,6 +52,10 @@ jobs:
run: |
pip install --no-deps --force-reinstall ${{ env.WHEEL_NAME }}

- name: List installed packages
run: |
pip list

# Run tests

- name: Run tests
20 changes: 20 additions & 0 deletions python/hidet/distributed/partition/__init__.py
@@ -0,0 +1,20 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
**Please note that most features under this module only support 1-D partitioning**, because the
module is mainly designed for single-machine-multi-GPU settings. It automatically searches for
intra-op partition plans (data parallelism and tensor parallelism), while pipeline parallelism is
handled by other modules.

We plan to extend it to 2-D partitioning (multi-machine-multi-GPU) in the future.
"""
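As a concrete illustration of what 1-D partitioning means, here is a minimal sketch using nested Python lists as stand-in tensors. The `shard_1d` helper is hypothetical (the real module operates on hidet graph tensors via `TensorShardSpec` and `get_tile`), but the tiling arithmetic is the same idea: split one dimension evenly across shards.

```python
# Sketch: 1-D partition of a "tensor" (nested list) along one dimension.
# shard_1d is a hypothetical helper for illustration only.

def shard_1d(tensor, dim, num_shards):
    """Split `tensor` into `num_shards` equal tiles along `dim` (0 or 1)."""
    if dim == 0:
        size = len(tensor)
        assert size % num_shards == 0, "dimension must divide evenly"
        tile = size // num_shards
        return [tensor[r * tile:(r + 1) * tile] for r in range(num_shards)]
    elif dim == 1:
        size = len(tensor[0])
        assert size % num_shards == 0, "dimension must divide evenly"
        tile = size // num_shards
        return [[row[r * tile:(r + 1) * tile] for row in tensor]
                for r in range(num_shards)]
    raise NotImplementedError("only 2-D nested lists in this sketch")

matrix = [[1, 2, 3, 4], [5, 6, 7, 8]]
print(shard_1d(matrix, 1, 2))  # two 2x2 tiles, one per shard
```

Each rank then holds only its tile; dimensions that do not divide evenly are rejected, mirroring the divisibility checks later in `op_shard_rule_search`.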
227 changes: 227 additions & 0 deletions python/hidet/distributed/partition/rule.py
@@ -0,0 +1,227 @@
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Dict
from itertools import product

from hidet.graph import Operator
from hidet.ir import TensorElement, Var, Expr
from hidet.ir.compute import GridCompute, ReduceCompute, ArgReduceCompute, TensorInput
from hidet.ir.functors import ExprRewriter, ComputeRewriter
from hidet.ir.analyzers.bound_analyzer import infer_bound, BoundInfo


from .shard import TensorShardSpec, OpShardSpec, get_tile, AllReduce, ReduceScatter


class IndexRewriter(ExprRewriter, ComputeRewriter):
def __init__(self):
super().__init__()
Collaborator comment:
I believe this is implicitly just calling `ExprRewriter.__init__`?

Maybe a nitpick, but relying on the MRO might be a bit error-prone (if, say, the inheritance order changes in the future, or a method with the same name is added to the first base class). It might be easier to just explicitly call `Base.__init__(self)`? I guess the same goes for `super().visit`, etc.

Member reply:

Hi @xinli-git, the current code looks fine. This function does not require any explicit ordering when calling its parent classes' `__init__` methods. The default `super().__init__()` calls them in MRO order, which works as expected. Explicitly calling one parent class's `__init__` may skip the other parent classes and is itself error-prone.
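A toy sketch (not hidet code; class names are hypothetical) of the point about cooperative `super().__init__()`: each base's `__init__` chains to the next class in the MRO, so both bases get initialized, whereas an explicit `ExprState.__init__(self)` in `Rewriter` would silently skip `ComputeState`.

```python
class ExprState:
    def __init__(self):
        super().__init__()   # cooperative: continue along the MRO
        self.expr_ready = True

class ComputeState:
    def __init__(self):
        super().__init__()
        self.compute_ready = True

class Rewriter(ExprState, ComputeState):
    def __init__(self):
        # MRO: Rewriter -> ExprState -> ComputeState -> object,
        # so one super() call initializes both bases.
        super().__init__()

r = Rewriter()
print(r.expr_ready, r.compute_ready)  # True True
```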

self.var2idx: Dict[Var, Expr] = {}
self.input_accessed_indices: Dict[TensorInput, List[Expr]] = {}
Member comment:
Dict[TensorInput, List[List[Expr]]] ?

self.reduce_axis_shapes: Dict[Var, Expr] = {}

def visit_Var(self, e: Var):
if e in self.var2idx:
return self.var2idx[e]
else:
return super().visit_Var(e)

def visit_TensorElement(self, e: TensorElement):
tensor = e.base
indices = tuple((self.visit(idx) for idx in e.indices))
if isinstance(tensor, GridCompute):
# pass the indices down to the depending grid's axes
for axis, index in zip(tensor.axes, indices):
self.var2idx[axis] = index
elif isinstance(tensor, TensorInput):
if tensor not in self.input_accessed_indices:
self.input_accessed_indices[tensor] = []
self.input_accessed_indices[tensor].append(indices)
else:
raise ValueError(f"unsupported tensor type: {type(tensor)}")
tensor = self.visit(tensor)
return TensorElement(tensor, indices)

def visit_ReduceCompute(self, node: ReduceCompute):
for axis, axis_len in zip(node.axes, node.shape):
self.reduce_axis_shapes[axis] = axis_len
return super().visit_ReduceCompute(node)

def visit_ArgReduceCompute(self, node: ArgReduceCompute):
self.reduce_axis_shapes[node.axis] = node.extent
return super().visit_ArgReduceCompute(node)


class DataDependencyAnalyzer:
def __init__(self, op: Operator, num_shards: int):
self.op: Operator = op
self.num_shards: int = num_shards

index_rewriter = IndexRewriter()
for o in self.op.task.outputs:
if not isinstance(o, GridCompute):
self.valid = False
return # we don't support scalar output
index_rewriter.visit(o)
Member comment:
Suggested change:

Replace:
index_rewriter = IndexRewriter()
for o in self.op.task.outputs:
if not isinstance(o, GridCompute):
self.valid = False
return # we don't support scalar output
index_rewriter.visit(o)

With:
index_rewriter = IndexRewriter()
index_rewriter.visit(self.op.task.outputs)
self.valid = all(isinstance(out, GridCompute) for out in self.op.task.outputs)

self.input_accessed_indices = index_rewriter.input_accessed_indices
self.reduce_axis_shapes = index_rewriter.reduce_axis_shapes
self.valid = True

def check(self, in_shard_dim: List[int], out_shard_dim: List[int], shard_reduce: bool = False):
# Now only supports 1D partition
task = self.op.task

def _bound_info(rank, shape, is_shard=False):
shape = int(shape)
if is_shard:
shard_size = shape // self.num_shards
return BoundInfo(min_value=rank * shard_size, max_value=(rank + 1) * shard_size - 1)
else:
return BoundInfo(min_value=0, max_value=shape - 1)

def _check(sharded_reduce_axes=None):
if sharded_reduce_axes is None:
sharded_reduce_axes = []
# Enumerate all ranks
for rank in range(self.num_shards):
out_and_reduce_bound = {}
# Generate the premises of output indices
for o, shard_dim in zip(task.outputs, out_shard_dim):
if not isinstance(o, GridCompute):
return False
for i, (axis, shape) in enumerate(zip(o.axes, o.shape)):
out_and_reduce_bound[axis] = _bound_info(rank, shape, i == shard_dim)

for axis, shape in self.reduce_axis_shapes.items():
out_and_reduce_bound[axis] = _bound_info(rank, shape, axis in sharded_reduce_axes)
# here we assume reduction axis won't be reused

for input_tensor, accessed_indices in self.input_accessed_indices.items():
shard_dim = in_shard_dim[task.inputs.index(input_tensor)]
if shard_dim < 0:
continue
for indices in accessed_indices:
idx = indices[shard_dim]
shape = int(input_tensor.shape[shard_dim])
shard_size = shape // self.num_shards
lower, upper = rank * shard_size, (rank + 1) * shard_size
bound = infer_bound(idx, out_and_reduce_bound)[idx]
if bound.possible_min_value() is None or bound.possible_max_value() is None:
return False
if bound.possible_min_value() < lower or bound.possible_max_value() >= upper:
return False
return True

if shard_reduce:
# We only analyze sharding the reduction axis for single-output ops
if len(out_shard_dim) > 1:
return False

o = task.outputs[0]
# We only consider the outermost reduction axis
if not isinstance(o.value, ReduceCompute):
return False

if str(o.value.reduce_operation) not in ('sum', 'prod', 'max', 'min', 'avg'):
# other reduce ops are not supported by NCCL
return False

# Try reduction on each axis to see if the result can be fixed.
for shard_axis in o.value.axes:
if _check([shard_axis]):
return True
return False
else:
return _check()


def op_shard_rule_search(op: Operator, num_shards: int) -> List[OpShardSpec]:
# Now we only search for 1D partition
inputs = op.inputs
outputs = op.outputs
found_rules = []

# We do not allow dynamic shapes
assert len(op.task.symbols) == 0
Collaborator comment:
Can we raise `NotImplementedError` instead?


# num_shards == 1 will break following logic
if num_shards == 1:
return [
OpShardSpec(
[TensorShardSpec(len(t.shape)) for t in op.inputs], [TensorShardSpec(len(t.shape)) for t in op.outputs]
)
]

# Initialize data dependency analyzer
# If it is not valid, we hit a scenario we cannot analyze,
# and we will not shard this op
data_dependency_analyzer = DataDependencyAnalyzer(op, num_shards)
if not data_dependency_analyzer.valid:
return []

# Enumerate all possible input shardings; -1 means the input is duplicated (not sharded)
for in_shard_dims in product(*(range(-1, len(i.shape)) for i in inputs)):
# Get a tile of each input tensor according to the input sharding
# And compute the output shape
try:
new_inputs = []
for i, shard_dim in enumerate(in_shard_dims):
if shard_dim >= 0:
if inputs[i].shape[shard_dim] % num_shards != 0:
break
new_inputs.append(get_tile(inputs[i], shard_dim, num_shards))
else:
new_inputs.append(inputs[i])
if len(new_inputs) < len(inputs):
continue
outputs = op.reforward(new_inputs)
except (ValueError, RuntimeError, AssertionError, IndexError):
Collaborator comment:
how do we arrive at these exceptions? I assume through trial and error?

Collaborator (author) reply:

Yes. For example, if we have an op that adds up two 4x4 matrices, and we try to shard one of them along the first dimension, then it becomes 2x4 + 4x4, which is illegal.
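The author's example can be sketched with a plain shape check (a stand-in for what `op.reforward` effectively enforces; the helper below is hypothetical, not hidet's API): an elementwise add requires identical shapes, so sharding only one operand of a 4x4 + 4x4 turns it into an illegal 2x4 + 4x4.

```python
# Sketch: the kind of failure reforward hits when a trial sharding is illegal.
def elementwise_add_shape(a_shape, b_shape):
    """Shape rule for elementwise add: both operands must match exactly."""
    if a_shape != b_shape:
        raise ValueError(f"shape mismatch: {a_shape} vs {b_shape}")
    return a_shape

print(elementwise_add_shape((4, 4), (4, 4)))   # fine: (4, 4)
try:
    elementwise_add_shape((2, 4), (4, 4))      # one input sharded along dim 0
except ValueError as e:
    print("rejected:", e)
```

Catching the exception and moving on (`continue`) is exactly how the search loop prunes such candidates.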

continue

# find the sharded output dimension
# for each output tensor, there should be at most one dimension being sharded
out_shard_dims = []
for i, out in enumerate(outputs):
shard_dim = -1
origin_shape = op.outputs[i].shape
for dim in range(len(out.shape)):
if out.shape[dim] == origin_shape[dim] // num_shards:
if shard_dim != -1:
shard_dim = None # invalid output shape
# Because there should be at most one sharded dimension
else:
shard_dim = dim
elif out.shape[dim] != origin_shape[dim]:
shard_dim = None # invalid output shape
if shard_dim is not None:
out_shard_dims.append(shard_dim)
else: # the output is not a valid result of this sharding
break

if len(out_shard_dims) == len(outputs): # shape analysis valid
# data dependency analysis
in_specs = [TensorShardSpec(len(i.shape), shard_dim) for i, shard_dim in zip(inputs, in_shard_dims)]
out_specs = [TensorShardSpec(len(o.shape), shard_dim) for o, shard_dim in zip(outputs, out_shard_dims)]
if data_dependency_analyzer.check(in_shard_dims, out_shard_dims):
found_rules.append(OpShardSpec(in_specs, out_specs))
else: # Failed, but it still can be partial results. Try to fix it with reduction.
if data_dependency_analyzer.check(in_shard_dims, out_shard_dims, shard_reduce=True):
reduce_op = op.task.outputs[0].value.reduce_operation
found_rules.append(OpShardSpec(in_specs, out_specs, reduce_fn=[AllReduce(reduce_op)]))
# if all_reduce is valid, then reduce_scatter is also valid
output = op.task.outputs[0]
if output.shape[0] % num_shards == 0:
out_spec = TensorShardSpec(len(output.shape), 0)
found_rules.append(OpShardSpec(in_specs, [out_spec], reduce_fn=[ReduceScatter(reduce_op)]))

return found_rules
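The enumeration step can be pictured in isolation: `itertools.product` over `range(-1, ndim)` per input yields every candidate assignment of one shard dimension (or -1 for "duplicate on every shard") to each input, and divisibility prunes infeasible ones. The shapes below are hypothetical, chosen so that one dimension fails the divisibility check.

```python
from itertools import product

# Sketch of the candidate enumeration in op_shard_rule_search.
# Hypothetical op with two inputs: a (3, 6) matrix and a (6,) vector.
num_shards = 2
input_shapes = [(3, 6), (6,)]

# One choice per input: -1 (duplicate) or a dimension index to shard.
candidates = list(product(*(range(-1, len(s)) for s in input_shapes)))

def feasible(assignment):
    # A sharded dimension must divide evenly among the shards.
    return all(d < 0 or input_shapes[i][d] % num_shards == 0
               for i, d in enumerate(assignment))

print(len(candidates))                        # 3 choices x 2 choices = 6
print([a for a in candidates if feasible(a)]) # dim 0 of (3, 6) is pruned
```

Each surviving candidate is then handed to the shape analysis and the data-dependency check above before it becomes a rule.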