Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feat] custom ring all-reduce #23

Merged
merged 26 commits into from
Oct 4, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
26 commits
Select commit Hold shift + click to select a range
da8615a
feat: custom ring all-reduce
Jackmin801 Sep 30, 2024
e5b6d25
Feat allreduce bis (#24)
samsja Sep 30, 2024
761f6af
flatten storage in reduce
Jackmin801 Sep 30, 2024
dec9620
do compute comm interleave
Jackmin801 Sep 30, 2024
db4a11f
test quantized
Jackmin801 Oct 2, 2024
996ab21
add scripts bandwith
samsja Oct 2, 2024
be63b4f
halfway point
Jackmin801 Oct 3, 2024
c99d812
do the all gather in full prec for now
Jackmin801 Oct 3, 2024
5068c70
add quant test
Jackmin801 Oct 3, 2024
c2554ed
hivemind quantization attribution
Jackmin801 Oct 3, 2024
664f906
allow setting transfer_dtype in bench script
Jackmin801 Oct 3, 2024
a5318b9
first draft of speedups
Jackmin801 Oct 4, 2024
c4dc41f
eliminate expensive long cast in average tensors
Jackmin801 Oct 4, 2024
f93b19a
cleanup compress.cpp
Jackmin801 Oct 4, 2024
7279863
python wrap compression ops
Jackmin801 Oct 4, 2024
4da4c97
type hint fix
Jackmin801 Oct 4, 2024
7808c27
add tests
Jackmin801 Oct 4, 2024
5d0e704
fix naming
Jackmin801 Oct 4, 2024
b290a76
allow bf16 transfer dtype in bench script
Jackmin801 Oct 4, 2024
4246818
use c implementation of quantization
Jackmin801 Oct 4, 2024
5ccb3ca
gather in uint8
Jackmin801 Oct 4, 2024
5e616a4
self quantize to make reduce consistant across all nodes
Jackmin801 Oct 4, 2024
0b9f458
refactor: move csrc into C folder
Jackmin801 Oct 4, 2024
58d9b64
make tests faster
Jackmin801 Oct 4, 2024
c724f2a
refactor consistent module naming
Jackmin801 Oct 4, 2024
edcec90
even faster tests
Jackmin801 Oct 4, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions scripts/all_reduce_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from pydantic_config import BaseConfig, parse_argv
import torch
from torch.distributed import destroy_process_group, init_process_group
import torch.utils.benchmark as benchmark

from zeroband.collectives import AllReduceBackend, ALL_REDUCE_FN
from zeroband.utils.world_info import get_world_info
from zeroband.utils.logging import get_logger


class Config(BaseConfig):
    """Options for the all-reduce benchmark script, built from ``parse_argv()``."""

    # Number of elements in the benchmarked tensor (main() builds a (1, size_model) tensor).
    size_model: int = int(1e9)
    # Number of timed repetitions handed to the benchmark timer.
    n_iters: int = 5
    # Which entry of ALL_REDUCE_FN to benchmark.
    backend: AllReduceBackend = AllReduceBackend.GLOO


def main(config: Config):
    """Time the configured all-reduce backend and log the mean duration and bandwidth.

    A random ``(1, config.size_model)`` tensor is all-reduced ``config.n_iters``
    times; the bandwidth figure counts 4 bytes per element of that tensor.
    Relies on the module-level ``logger`` set up by the script entry point.
    """
    world_info = get_world_info()
    payload = torch.rand(1, config.size_model)

    logger.info(
        f"\n ======== Benchmark all reduce between {world_info.world_size} gpus over {world_info.nnodes} nodes =========\n"
    )

    reduce_fn = ALL_REDUCE_FN[config.backend]

    # The timer evaluates the statement string against the names supplied in `globals`.
    timer = benchmark.Timer(
        stmt="all_reduce(mat)",
        globals={"all_reduce": reduce_fn, "mat": payload},
    )
    measured_time = timer.timeit(config.n_iters).mean

    # elements * 4 bytes -> GB, divided by seconds per iteration
    bandwidth = config.size_model * 4 / 1e9 / measured_time

    logger.info(f"Average time per iteration: {measured_time:.2f} seconds, Average bandwidth: {bandwidth:.2f} GB/s")


if __name__ == "__main__":
    # Script entry point: parse CLI options, bring up the gloo process group,
    # run the benchmark, and always tear the process group down afterwards.
    config = Config(**parse_argv())

    torch.set_float32_matmul_precision("high")
    init_process_group(backend="gloo")

    try:
        # `logger` is intentionally a module-level name: main() reads it as a global.
        logger = get_logger()
        main(config)
    finally:
        # Release the process group even if the benchmark raises.
        destroy_process_group()
24 changes: 24 additions & 0 deletions scripts/bandwith/down.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
#!/bin/bash
#
# Throttle an interface with a token-bucket-filter (tbf) qdisc so that
# bandwidth-limited links can be simulated locally.

# Check if the script is run as root
if [ "$EUID" -ne 0 ]; then
    echo "Please run as root"
    exit 1
fi

# Define variables
INTERFACE="lo"   # localhost interface
RATE="500mbit"   # 500 Mbps
BURST="500k"     # burst size
LATENCY="50ms"   # maximum latency

# Remove any existing traffic control rules on the interface.
# Failure is expected (and ignored) when no rule is installed yet.
tc qdisc del dev "$INTERFACE" root 2>/dev/null

# Add the rate limiting rule; do not claim success if tc rejected it.
if ! tc qdisc add dev "$INTERFACE" root tbf rate "$RATE" burst "$BURST" latency "$LATENCY"; then
    echo "Failed to set bandwidth limit on $INTERFACE" >&2
    exit 1
fi

echo "Bandwidth limit of $RATE has been set on $INTERFACE"

# To remove the limit, run:
# tc qdisc del dev $INTERFACE root
1 change: 1 addition & 0 deletions scripts/bandwith/up.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
# Remove the root qdisc from the loopback interface, i.e. the removal command
# that down.sh documents for lifting its bandwidth limit.
tc qdisc del dev lo root
164 changes: 164 additions & 0 deletions src/zeroband/collectives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
from enum import Enum
from typing import Callable, Optional, TypeAlias
import torch
import torch.distributed as dist

# Call signature of an all-reduce implementation as stored in ALL_REDUCE_FN:
# (tensor, op, group, transfer_dtype) -> None.
AllReduceFunc: TypeAlias = Callable[
    [torch.Tensor, dist.ReduceOp, Optional[dist.ProcessGroup], Optional[torch.dtype]], None
]


def gloo_all_reduce(
    tensor: torch.Tensor,
    op: dist.ReduceOp = dist.ReduceOp.SUM,
    group: Optional[dist.ProcessGroup] = None,
    transfer_dtype: Optional[torch.dtype] = None,
) -> None:
    """Wrap gloo all reduce.

    The result is written in-place into ``tensor``.

    Args:
        tensor: Tensor to reduce; overwritten with the reduced values.
        op: ``ReduceOp.SUM`` or ``ReduceOp.AVG``.
        group: Process group to reduce over; the default group when ``None``.
        transfer_dtype: dtype used for the communication; ``tensor.dtype``
            when ``None``.

    Raises:
        ValueError: If ``op`` is neither SUM nor AVG.
    """
    if transfer_dtype is None:
        transfer_dtype = tensor.dtype
    if group is None:
        group = dist.distributed_c10d._get_default_group()
    if op not in [dist.ReduceOp.SUM, dist.ReduceOp.AVG]:
        raise ValueError(f"Unsupported reduce operation {op}. Only SUM and AVG are supported.")

    if op == dist.ReduceOp.AVG:
        # todo check numerical stability of doing post or pre div
        tensor.div_(group.size())

    # `.to` returns the tensor itself when the dtype already matches, otherwise a copy.
    transfer = tensor.to(transfer_dtype)

    # The average is obtained by pre-dividing and then summing, so the collective
    # itself is always a SUM (gloo has no AVG, and passing AVG would divide twice).
    dist.all_reduce(transfer, dist.ReduceOp.SUM, group=group)

    if transfer.dtype != tensor.dtype:
        # The reduction ran on a converted copy: bring the result back to the caller's tensor.
        tensor.copy_(transfer)


# Number of send/receive buffer slots ring_allreduce cycles through; the tensor
# is split into world_size * BUFFER_COUNT chunks.
BUFFER_COUNT = 2


def ring_allreduce(
    tensor: torch.Tensor,
    op: dist.ReduceOp = dist.ReduceOp.SUM,
    group: Optional[dist.ProcessGroup] = None,
    transfer_dtype: Optional[torch.dtype] = None,
    quantization_func: Optional[Callable] = None,
) -> None:
    """
    Perform all-reduce on a tensor using ring algorithm.
    The accumulation will be done in-place on the input tensor.
    The transfers will be done using the specified transfer_dtype.

    The function runs two loops over ``world_size * BUFFER_COUNT`` steps: the
    first accumulates chunks received from the previous rank into the local
    chunks, the second overwrites local chunks with the chunks received from
    the previous rank. Sends and receives are non-blocking and rotate over
    ``BUFFER_COUNT`` buffer slots.

    Args:
        tensor: Tensor reduced in-place. Its element count must be divisible
            by ``world_size * BUFFER_COUNT``.
        op: ``ReduceOp.SUM`` or ``ReduceOp.AVG``.
        group: Process group to use; the default group when ``None``.
        transfer_dtype: dtype of the send/receive buffers; ``tensor.dtype``
            when ``None``. Must be ``None`` when ``quantization_func`` is given.
        quantization_func: Optional callable invoked on a chunk and expected to
            return ``(payload, lookup)``; the receiver reconstructs values as
            ``lookup[payload.long()]``. Only used in the first (accumulation) loop.

    Raises:
        ValueError: If both ``quantization_func`` and ``transfer_dtype`` are
            given, or if ``op`` is neither SUM nor AVG.
        AssertionError: If the element count is not divisible by
            ``world_size * BUFFER_COUNT``.
    """
    if quantization_func is not None:
        if transfer_dtype is not None:
            raise ValueError("Quantization and transfer_dtype cannot be used together")
        transfer_dtype = tensor.dtype
    if transfer_dtype is None:
        transfer_dtype = tensor.dtype
    if group is None:
        group = dist.distributed_c10d._get_default_group()
    if op not in [dist.ReduceOp.SUM, dist.ReduceOp.AVG]:
        raise ValueError(f"Unsupported reduce operation {op}. Only SUM and AVG are supported.")

    world_size = group.size()
    rank = group.rank()

    # Divide the tensor into chunks
    # NOTE(review): as_strided with stride 1 reinterprets the storage as flat;
    # presumably the tensor is expected to be contiguous — confirm with callers.
    flat_tensor = tensor.as_strided((tensor.numel(),), (1,))
    chunks = flat_tensor.chunk(world_size * BUFFER_COUNT)

    # NOTE(review): the condition checks world_size * BUFFER_COUNT, while the
    # message only mentions world size.
    assert flat_tensor.size(0) % (world_size * BUFFER_COUNT) == 0, "Tensor size must be divisible by world size"

    # Temporary buffers for transferring data
    num_buffers = BUFFER_COUNT * world_size
    if quantization_func is not None:
        # Quantized path: receive uint8 payloads plus a 256-entry lookup table per slot;
        # the send-side buffers are filled by quantization_func at send time.
        recv_buffer = [torch.empty_like(chunks[0], dtype=torch.uint8) for _ in range(BUFFER_COUNT)]
        send_buffer = [None for _ in range(BUFFER_COUNT)]
        send_lookup_buffer = [None for _ in range(BUFFER_COUNT)]
        recv_lookup_buffer = [torch.empty(256, dtype=chunks[0].dtype) for _ in range(BUFFER_COUNT)]
        send_lookup_work = [None for _ in range(BUFFER_COUNT)]
        recv_lookup_work = [None for _ in range(BUFFER_COUNT)]
    else:
        recv_buffer = [torch.empty_like(chunks[0], dtype=transfer_dtype) for _ in range(BUFFER_COUNT)]
        send_buffer = [torch.empty_like(chunks[0], dtype=transfer_dtype) for _ in range(BUFFER_COUNT)]
    # Handles of the outstanding isend/irecv per slot; None until the slot is first used.
    send_work = [None] * BUFFER_COUNT
    recv_work = [None] * BUFFER_COUNT

    # Ring neighbours: always send to the next rank and receive from the previous one.
    send_rank = (rank + 1) % world_size
    recv_rank = (rank - 1) % world_size
    for step in range(1, world_size * BUFFER_COUNT + 1):
        send_chunk = (rank * BUFFER_COUNT - step) % num_buffers

        # If this slot has a transfer in flight (posted BUFFER_COUNT steps ago),
        # finish it and fold the received data into the current chunk.
        if send_work[step % BUFFER_COUNT] is not None:
            send_work[step % BUFFER_COUNT].wait()
            recv_work[step % BUFFER_COUNT].wait()
            if quantization_func is not None:
                send_lookup_work[step % BUFFER_COUNT].wait()
                recv_lookup_work[step % BUFFER_COUNT].wait()
                # Leftover debug statement:
                # print(recv_lookup_buffer[step % BUFFER_COUNT][recv_buffer[step % BUFFER_COUNT].long()])
                # Dequantize by indexing the received lookup table with the received codes.
                chunks[send_chunk].add_(
                    recv_lookup_buffer[step % BUFFER_COUNT][recv_buffer[step % BUFFER_COUNT].long()]
                )
            else:
                chunks[send_chunk].add_(recv_buffer[step % BUFFER_COUNT])

        # The last BUFFER_COUNT steps only drain pending receives; no new transfers are posted.
        if step <= (world_size - 1) * BUFFER_COUNT:
            # Send and receive
            if quantization_func is not None:
                send_buffer[step % BUFFER_COUNT], send_lookup_buffer[step % BUFFER_COUNT] = quantization_func(
                    chunks[send_chunk]
                )
                # Lookup tables travel on a separate tag range (step + 1000) from the payload.
                send_lookup_work[step % BUFFER_COUNT] = dist.isend(
                    send_lookup_buffer[step % BUFFER_COUNT], dst=send_rank, group=group, tag=step + 1000
                )
                recv_lookup_work[step % BUFFER_COUNT] = dist.irecv(
                    recv_lookup_buffer[step % BUFFER_COUNT], src=recv_rank, group=group, tag=step + 1000
                )
            else:
                send_buffer[step % BUFFER_COUNT].copy_(chunks[send_chunk])
            send_work[step % BUFFER_COUNT] = dist.isend(
                send_buffer[step % BUFFER_COUNT], dst=send_rank, group=group, tag=step
            )
            recv_work[step % BUFFER_COUNT] = dist.irecv(
                recv_buffer[step % BUFFER_COUNT], src=recv_rank, group=group, tag=step
            )

    if op == dist.ReduceOp.AVG:
        # Chunks rank*BUFFER_COUNT .. rank*BUFFER_COUNT + BUFFER_COUNT - 1 are the ones the
        # final accumulation steps above wrote to, and the first ones sent by the loop below.
        for i in range(BUFFER_COUNT):
            chunks[i + rank * BUFFER_COUNT].divide_(world_size)

    # TODO: Maybe have an option to all gather in lower precision

    if quantization_func is not None:
        # NOTE(review): these two lists are not read by the loop below.
        send_lookup_work = [None for _ in range(BUFFER_COUNT)]
        recv_lookup_work = [None for _ in range(BUFFER_COUNT)]
    # Fresh transfer_dtype buffers and cleared work handles for the second loop.
    recv_buffer = [torch.empty_like(chunks[0], dtype=transfer_dtype) for _ in range(BUFFER_COUNT)]
    send_buffer = [torch.empty_like(chunks[0], dtype=transfer_dtype) for _ in range(BUFFER_COUNT)]
    send_work = [None] * BUFFER_COUNT
    recv_work = [None] * BUFFER_COUNT
    for step in range(1, world_size * BUFFER_COUNT + 1):
        send_chunk = (rank * BUFFER_COUNT + BUFFER_COUNT - step) % num_buffers

        if send_work[step % BUFFER_COUNT] is not None:
            send_work[step % BUFFER_COUNT].wait()
            recv_work[step % BUFFER_COUNT].wait()
            # Unlike the first loop, received data replaces the local chunk instead of being added.
            chunks[send_chunk].copy_(recv_buffer[step % BUFFER_COUNT])

        if step <= (world_size - 1) * BUFFER_COUNT:
            # Send and receive
            send_buffer[step % BUFFER_COUNT].copy_(chunks[send_chunk])
            send_work[step % BUFFER_COUNT] = dist.isend(
                send_buffer[step % BUFFER_COUNT], dst=send_rank, group=group, tag=step
            )
            recv_work[step % BUFFER_COUNT] = dist.irecv(
                recv_buffer[step % BUFFER_COUNT], src=recv_rank, group=group, tag=step
            )


class AllReduceBackend(Enum):
    """Names of the all-reduce implementations selectable through ALL_REDUCE_FN."""

    GLOO = "gloo"  # mapped to gloo_all_reduce
    CUSTOM = "custom"  # mapped to ring_allreduce


# Dispatch table from backend enum member to the function implementing it.
ALL_REDUCE_FN: dict[AllReduceBackend, AllReduceFunc] = {
    AllReduceBackend.GLOO: gloo_all_reduce,
    AllReduceBackend.CUSTOM: ring_allreduce,
}
68 changes: 68 additions & 0 deletions src/zeroband/compression.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import torch
import numpy as np
from typing import Tuple
import math
from concurrent.futures import ThreadPoolExecutor
import os

Jackmin801 marked this conversation as resolved.
Show resolved Hide resolved
# Multiplier applied to the standard-deviation estimate when uniform_8bit_quantize
# derives its quantization scale.
RANGE_IN_SIGMAS: int = 6
# Shared thread pool used by quantile_qq_approximation; its size comes from the
# QUANTIZATION_THREADS environment variable (default 128).
EXECUTOR = ThreadPoolExecutor(max_workers=int(os.environ.get("QUANTIZATION_THREADS", 128)))


def average_buckets(tensor: torch.Tensor, quant_weight: torch.Tensor, n_bins: int):
    """Return the average value in each bucket.

    Args:
        tensor: Values to average; any shape (flattened internally).
        quant_weight: Integer bucket index of every element of ``tensor``.
        n_bins: Number of buckets; the result has this many entries.

    Returns:
        1-D tensor of length ``n_bins`` holding the mean of the values assigned
        to each bucket. Empty buckets yield 0.
    """
    values = tensor.flatten()
    bucket_ids = quant_weight.flatten()

    # The accumulator must share dtype and device with the values, otherwise
    # scatter_add_ rejects the call for anything that is not float32 on the default device.
    bin_sums = torch.zeros(n_bins, dtype=values.dtype, device=values.device)
    bin_sums.scatter_add_(0, bucket_ids.long(), values)

    # Clamp counts to at least 1 so empty buckets divide 0 by 1 instead of by 0.
    bin_counts = torch.clamp_min_(torch.bincount(bucket_ids, minlength=n_bins), 1)
    return bin_sums / bin_counts


def get_chunk_size(num_elements: int, min_chunk_size: int) -> int:
    """Grow ``min_chunk_size`` just enough to spread the leftover elements over the chunks.

    When the input fits in a single minimum-size chunk, ``min_chunk_size`` is
    returned unchanged.
    """
    if num_elements <= min_chunk_size:
        return min_chunk_size
    whole_chunks, remainder = divmod(num_elements, min_chunk_size)
    # Ceiling of remainder / whole_chunks, written with floor division.
    extra_per_chunk = (remainder - 1) // whole_chunks + 1
    return min_chunk_size + extra_per_chunk


def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size: int = 10**5) -> np.ndarray:
    """Approximate ``n_quantiles`` evenly spaced quantiles via quantile-of-quantiles.

    The flattened data is cut into chunks, the quantiles of every chunk are
    computed on the shared thread pool, and the quantiles of those per-chunk
    quantiles are returned.
    """
    # Transpose Fortran-ordered input before flattening.
    if not array.data.c_contiguous and array.data.f_contiguous:
        array = array.T
    flat = np.ascontiguousarray(array.reshape(-1))
    total = len(flat)

    quantiles = np.linspace(0.0, 1.0, num=n_quantiles, dtype=flat.dtype)
    chunk_size = get_chunk_size(total, min_chunk_size)
    num_chunks = (total - 1) // chunk_size + 1
    partition_quantiles = np.empty((num_chunks, len(quantiles)), dtype=flat.dtype)

    # One job per chunk; each writes its quantiles straight into its row of the result matrix.
    pending = [
        EXECUTOR.submit(
            np.quantile,
            flat[slice(chunk_size * idx, chunk_size * (idx + 1))],
            quantiles,
            out=partition_quantiles[idx],
        )
        for idx in range(num_chunks)
    ]

    # Wait for every job; result() re-raises any exception from the worker.
    for future in pending:
        future.result()
    return np.quantile(partition_quantiles, quantiles)


# Number of quantization buckets (256) used by the 8-bit quantizers below.
n_bins = 2**8


def uniform_8bit_quantize(tensor: torch.Tensor, inplace: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``tensor`` onto a uniform 8-bit grid.

    The step size is ``RANGE_IN_SIGMAS`` times a standard-deviation estimate,
    divided by the number of bins; the estimate is ``norm / sqrt(numel - 1)``
    with no mean subtraction. The zero point sits in the middle of the bins.
    ``inplace`` is accepted but not used.

    Returns:
        ``(codes, lookup)``: the integer representation of the quantized
        tensor, and the per-bucket mean of the original values.
    """
    zero_point = n_bins // 2
    sigma = tensor.norm() / math.sqrt(tensor.numel() - 1)
    step = RANGE_IN_SIGMAS * sigma / n_bins
    codes = torch.quantize_per_tensor(tensor, step, zero_point, torch.quint8).int_repr()
    return codes, average_buckets(tensor, codes, n_bins)


def quantile_8bit_quantize(tensor: torch.Tensor, inplace: bool = True) -> Tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``tensor`` into ``n_bins`` buckets bounded by approximate quantiles.

    Bucket borders are the interior values of ``n_bins + 1`` approximate
    quantiles of the data. ``inplace`` is accepted but not used.

    Returns:
        ``(codes, lookup)``: uint8 bucket indices with the same shape as
        ``tensor``, and the per-bucket mean of the original values.
    """
    borders = torch.as_tensor(quantile_qq_approximation(tensor.numpy(), n_bins + 1)[1:-1])
    # bucketize yields int64 indices; after clamping to [0, n_bins - 1] they fit in a byte,
    # so store them as uint8 like uniform_8bit_quantize does.
    quantized = torch.clamp_(torch.bucketize(tensor, borders), 0, n_bins - 1).to(torch.uint8)
    lookup = average_buckets(tensor, quantized, n_bins)
    return quantized, lookup
4 changes: 2 additions & 2 deletions tests/test_dist/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def random_available_port():
def dist_environment() -> callable:
@contextmanager
def dist_environment(
random_available_port, rank=0, local_rank=0, world_size=1, local_world_size=1, global_unique_id=""
random_available_port, backend=None, rank=0, local_rank=0, world_size=1, local_world_size=1, global_unique_id=""
):
with mock.patch.dict(
os.environ,
Expand All @@ -62,7 +62,7 @@ def dist_environment(
},
):
try:
init_process_group()
init_process_group(backend=backend)
torch.cuda.set_device(local_rank)
yield
finally:
Expand Down
30 changes: 0 additions & 30 deletions tests/test_dist/test_all_reduce.py

This file was deleted.

Loading