-
Notifications
You must be signed in to change notification settings - Fork 0
/
distributed_utils.py
106 lines (89 loc) · 3.19 KB
/
distributed_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import torch
# The following code is borrowed from Ross Wightman implementation at
# https://github.com/mlfoundations/open_clip/blob/a5ba05f7cab5ddab7c9967bfb8bbef303be6f3aa/src/open_clip/loss.py
# The code is borrowed for the purpose of testing the correctness of the Sigmoid Loss
def neighbour_exchange(from_rank, to_rank, tensor, group=None):
    """Exchange a tensor with two ring neighbours in one shot.

    Sends ``tensor`` to ``to_rank`` while simultaneously receiving a
    same-shaped tensor from ``from_rank``.  Both transfers are issued as a
    single batched isend/irecv pair and waited on before returning.

    Args:
        from_rank: rank we receive from.
        to_rank: rank we send to.
        tensor: tensor to send; the receive buffer is shaped like it.
        group: optional process group (defaults to the world group).

    Returns:
        The tensor received from ``from_rank``.
    """
    # Receive buffer must pre-exist for irecv; zeros_like matches dtype/device.
    received = torch.zeros_like(tensor)
    ops = [
        torch.distributed.P2POp(torch.distributed.isend, tensor, to_rank, group=group),
        torch.distributed.P2POp(torch.distributed.irecv, received, from_rank, group=group),
    ]
    # Launch both transfers together, then block until each completes.
    for work in torch.distributed.batch_isend_irecv(ops):
        work.wait()
    return received
def neighbour_exchange_bidir(left_rank, right_rank, tensor_to_left, tensor_to_right, group=None):
    """Bidirectional ring exchange with both neighbours at once.

    Sends ``tensor_to_left`` to ``left_rank`` and ``tensor_to_right`` to
    ``right_rank``, while receiving one tensor from each neighbour.  The
    buffer received from the left is shaped like ``tensor_to_right`` and
    vice versa, matching the symmetric exchange pattern on the ring.

    Args:
        left_rank: left neighbour's rank.
        right_rank: right neighbour's rank.
        tensor_to_left: tensor sent to the left neighbour.
        tensor_to_right: tensor sent to the right neighbour.
        group: optional process group (defaults to the world group).

    Returns:
        Tuple ``(tensor_from_right, tensor_from_left)``.
    """
    # Pre-allocate receive buffers; shapes mirror what the peer sends us.
    tensor_from_left = torch.zeros_like(tensor_to_right)
    tensor_from_right = torch.zeros_like(tensor_to_left)
    p2p = torch.distributed.P2POp
    ops = [
        p2p(torch.distributed.isend, tensor_to_right, right_rank, group=group),
        p2p(torch.distributed.isend, tensor_to_left, left_rank, group=group),
        p2p(torch.distributed.irecv, tensor_from_right, right_rank, group=group),
        p2p(torch.distributed.irecv, tensor_from_left, left_rank, group=group),
    ]
    # One batched launch for all four transfers, then wait on every handle.
    for work in torch.distributed.batch_isend_irecv(ops):
        work.wait()
    return tensor_from_right, tensor_from_left
class NeighbourExchange(torch.autograd.Function):
    """Autograd wrapper around :func:`neighbour_exchange`.

    The backward pass routes the incoming gradient back along the reversed
    path: what was sent to ``to_rank`` receives its gradient from
    ``to_rank``, so the ranks are swapped when re-applying the exchange.
    """

    @staticmethod
    def forward(ctx, from_rank, to_rank, group, tensor):
        # Stash routing info for the reversed exchange in backward.
        ctx.group = group
        ctx.from_rank = from_rank
        ctx.to_rank = to_rank
        return neighbour_exchange(from_rank, to_rank, tensor, group=group)

    @staticmethod
    def backward(ctx, grad_output):
        # Gradients flow in the opposite direction: swap from/to ranks.
        grad_tensor = NeighbourExchange.apply(ctx.to_rank, ctx.from_rank, ctx.group, grad_output)
        # None gradients for the three non-tensor inputs (ranks and group).
        return None, None, None, grad_tensor
def neighbour_exchange_with_grad(from_rank, to_rank, tensor, group=None):
    """Differentiable neighbour exchange.

    Convenience wrapper that invokes :class:`NeighbourExchange` so the
    exchanged tensor participates in autograd.
    """
    return NeighbourExchange.apply(from_rank, to_rank, group, tensor)
class NeighbourExchangeBidir(torch.autograd.Function):
    """Autograd wrapper around :func:`neighbour_exchange_bidir`.

    Backward mirrors forward with the left/right roles swapped, so each
    outgoing tensor's gradient is exchanged back along the reverse path.
    """

    @staticmethod
    def forward(ctx, left_rank, right_rank, group, tensor_to_left, tensor_to_right):
        # Remember the ring topology for the reversed exchange in backward.
        ctx.group = group
        ctx.left_rank = left_rank
        ctx.right_rank = right_rank
        return neighbour_exchange_bidir(
            left_rank, right_rank, tensor_to_left, tensor_to_right, group=group
        )

    @staticmethod
    def backward(ctx, *grad_outputs):
        # Re-run the bidirectional exchange with left/right swapped; the
        # returned pair supplies gradients for both tensor inputs.
        grads = NeighbourExchangeBidir.apply(
            ctx.right_rank, ctx.left_rank, ctx.group, *grad_outputs
        )
        # Three leading Nones cover the non-tensor inputs (ranks and group).
        return (None, None, None) + grads
def neighbour_exchange_bidir_with_grad(
    left_rank, right_rank, tensor_to_left, tensor_to_right, group=None
):
    """Differentiable bidirectional neighbour exchange.

    Convenience wrapper that invokes :class:`NeighbourExchangeBidir` so both
    exchanged tensors participate in autograd.
    """
    return NeighbourExchangeBidir.apply(
        left_rank, right_rank, group, tensor_to_left, tensor_to_right
    )