
Commit 0f5ef65

Omar committed: massive bgd updates
1 parent cc93737

File tree: 2 files changed (+90, -86 lines)


bouncing_gd_np.py

Lines changed: 9 additions & 7 deletions
@@ -9,10 +9,10 @@
 type Vec = npt.NDArray[float]  # can represent a matrix as well
 type Input = csr_matrix | Vec
 type LossFunc = Callable[[Input, Vec], float]
-type Gradient = Callable[[Input, Vec], Vec]
+type GradFunc = Callable[[Input, Vec], Vec]


-def gd(data: Input, loss_f: LossFunc, gradient: Gradient, lr=1, epochs=50, seed=0) -> tuple[Vec, list[float]]:
+def gd(data: Input, loss_f: LossFunc, gradient: GradFunc, lr=1, epochs=50, seed=0) -> tuple[Vec, list[float]]:
     np.random.seed(seed)
     weight = np.random.randn(data.shape[-1])
     losses = [loss_f(data, weight)]
@@ -25,22 +25,24 @@ def bouncy_gd(data: Input, loss_f: LossFunc, gradient: Gradient, lr=1, epochs=50, seed=
     return weight, losses


-def bouncy_gd(data: Input, loss_f: LossFunc, gradient: Gradient, lr=1, epochs=50, TH=0.7, seed=0, beta=1) -> tuple[Vec, list[float]]:
+def bouncy_gd(data: Input, loss_f: LossFunc, gradient: GradFunc, lr=1, epochs=50, TH=0.7, seed=0, beta=0.995) -> tuple[Vec, list[float]]:
     np.random.seed(seed)
-    weight = np.random.randn(data.shape[-1])
+    feature_dimensions = data.shape[-1]
+    weight = np.random.randn(feature_dimensions)
     losses = [loss_f(data, weight)]
-    lr = np.ones(data.shape[-1]) * lr  # alpha
+    lr = np.ones(feature_dimensions) * lr  # per weight (parameter) adaptive learning rate
     sw = 1  # v_t
+    e = 1e-08

     def dist(g1: Vec, g2: Vec) -> Vec:
-        e = 1e-05
         flatness_1, flatness_2 = np.linalg.norm(g1), np.linalg.norm(g2)
         dists = np.array([flatness_2, flatness_1])
         return dists / (dists.sum() + e)

     for _ in trange(epochs):
         g = gradient(data, weight)
-        sw = beta * sw + abs(g)
+        sw = beta * sw + np.abs(g)  # second moment
+
         oracle = weight - g * lr
         g_orc = gradient(data, oracle)
         if g @ g_orc < 0:
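
For context, a minimal usage sketch of the NumPy version on a toy least-squares problem. The import line assumes bouncing_gd_np is importable as a module, and mse_loss / mse_grad are hypothetical helpers written to match the LossFunc and GradFunc signatures above; none of this is part of the commit.

import numpy as np
from bouncing_gd_np import gd, bouncy_gd  # assumption: the file is importable as a module

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 5))   # toy design matrix
y = X @ rng.standard_normal(5)      # targets from a known linear model

def mse_loss(data, w):
    # hypothetical LossFunc: mean squared error of data @ w against y
    return float(np.mean((data @ w - y) ** 2))

def mse_grad(data, w):
    # hypothetical GradFunc: gradient of mse_loss with respect to w
    return 2 * data.T @ (data @ w - y) / data.shape[0]

w_gd, losses_gd = gd(X, mse_loss, mse_grad, lr=0.1, epochs=50)
w_bgd, losses_bgd = bouncy_gd(X, mse_loss, mse_grad, lr=0.1, epochs=50, beta=0.995)
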
bouncing_gd_torch.py

Lines changed: 81 additions & 79 deletions
@@ -5,56 +5,62 @@
 from torch.utils.data import DataLoader
 from tqdm import trange
 import numpy as np
-import numpy.typing as npt
 import torch.nn.functional as F
 from matplotlib import pyplot as plt
+import random

 # plt.style.use('seaborn')
-plt.rcParams['figure.autolayout'] = True
+# plt.rcParams['figure.autolayout'] = True
+
+torch.cuda.empty_cache()
+device = torch.device('cpu')

 norm = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
-T = transforms.Compose([transforms.ToTensor(), transforms.Lambda(torch.flatten)])
-train_data = datasets.MNIST(root=f'datasets/MNIST', download=True, transform=T)  # (BS, 784)
+T = transforms.Compose([transforms.ToTensor(), transforms.Lambda(lambda x: torch.cat((x.flatten(), torch.Tensor([1]))))])
+train_data = datasets.MNIST(root=f'datasets/MNIST', download=True, transform=T)  # (BS, 785)
 test_data = datasets.MNIST(root=f'datasets/MNIST', train=False, transform=T)

-BS = 100
-LR = 10
-EPOCHS = 5
+BS: int = 100
+LR: float = 10.
+EPOCHS: int = 5
+SEED: int = 0

 # SGD:
-torch.manual_seed(0)
-np.random.seed(0)
+torch.manual_seed(SEED)
+np.random.seed(SEED)
+random.seed(SEED)

-train_loader = DataLoader(train_data, batch_size=BS, shuffle=True, pin_memory=True)
-test_loader = DataLoader(test_data, batch_size=BS, pin_memory=True)
+train_loader = DataLoader(train_data, batch_size=BS, num_workers=0, shuffle=True, pin_memory=True, generator=torch.Generator().manual_seed(SEED))
+test_loader = DataLoader(test_data, batch_size=BS, num_workers=0, pin_memory=True)

-w = torch.randn(784, 100, requires_grad=True)
-b = torch.randn(100, requires_grad=True)
-wh = torch.randn(100, 10, requires_grad=True)
-bh = torch.randn(10, requires_grad=True)
+w = torch.randn(784, 100, requires_grad=True, device=device)
+b = torch.randn(100, requires_grad=True, device=device)
+wh = torch.randn(100, 10, requires_grad=True, device=device)
+bh = torch.randn(10, requires_grad=True, device=device)

 losses_sgd = []
 for _ in trange(EPOCHS):
     loss_batch = []
     for x, y in train_loader:
-        out = torch.softmax(torch.tanh(x @ w + b) @ wh + bh, dim=1)
+        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
+        out = torch.tanh(x @ w + b) @ wh + bh  # raw logits
         loss = F.cross_entropy(out, y)
         loss_batch.append(loss.item())
-        grad_w = torch.autograd.grad(loss, w, retain_graph=True)[0]  # (784, 100)
-        grad_b = torch.autograd.grad(loss, b, retain_graph=True)[0]  # (100)
-        grad_wh = torch.autograd.grad(loss, wh, retain_graph=True)[0]  # (100, 10)
-        grad_bh = torch.autograd.grad(loss, bh)[0]  # (10)
-        with torch.no_grad():
-            w -= grad_w * LR
-            b -= grad_b * LR
-            wh -= grad_wh * LR
-            bh -= grad_bh * LR
+        grad_w, grad_b, grad_wh, grad_bh = torch.autograd.grad(loss, (w, b, wh, bh))
+
+        with torch.no_grad():  # computations below are untracked as tensors are treated as "detached" tensors
+            w -= grad_w * LR  # (784, 100)
+            b -= grad_b * LR  # (100)
+            wh -= grad_wh * LR  # (100, 10)
+            bh -= grad_bh * LR  # (10)
+
     losses_sgd.append(np.mean(loss_batch))
 plt.semilogy(losses_sgd, label='SGD')

 acc = []
 with torch.no_grad():
     for x, y in test_loader:
+        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
         out = torch.tanh(x @ w + b) @ wh + bh
         _, pred = out.max(1)
         acc.append((pred == y).float().mean().item())
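
The new transform above appends a constant 1 to every flattened image, making inputs 785-dimensional, so a bias term can be folded directly into a weight matrix; the BGD block below relies on this by dropping b and bh and using w of shape 785x100 and wh of shape 101x10. A minimal sketch of the equivalence being used; the tensors here are illustrative, not the training variables from the diff.

import torch

torch.manual_seed(0)
x = torch.randn(4, 784)                                    # a small batch of flattened images
W, b = torch.randn(784, 100), torch.randn(100)

# Fold the bias into the weight matrix: append b as an extra row of W
# and a constant 1 to each input (this mirrors the transform added above).
W_aug = torch.cat((W, b.unsqueeze(0)), dim=0)              # (785, 100)
x_aug = torch.cat((x, torch.ones(x.shape[0], 1)), dim=1)   # (4, 785)

assert torch.allclose(x @ W + b, x_aug @ W_aug, atol=1e-5)
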
@@ -63,86 +69,82 @@
 ############################################################################################################################################################

 # BGD:
-torch.manual_seed(0)
-np.random.seed(0)
+torch.manual_seed(SEED)
+np.random.seed(SEED)
+random.seed(SEED)
+
+train_loader = DataLoader(train_data, batch_size=BS, num_workers=0, shuffle=True, pin_memory=True, generator=torch.Generator().manual_seed(SEED))
+test_loader = DataLoader(test_data, batch_size=BS, num_workers=0, pin_memory=True)

-train_loader = DataLoader(train_data, batch_size=BS, shuffle=True, pin_memory=True)
-test_loader = DataLoader(test_data, batch_size=BS, pin_memory=True)
+w = torch.randn(785, 100, requires_grad=True, device=device)  # each column corresponds to an output node (neuron) in the network
+wh = torch.randn(101, 10, requires_grad=True, device=device)

-w = torch.randn(784, 100, requires_grad=True)
-b = torch.randn(100, requires_grad=True)
-wh = torch.randn(100, 10, requires_grad=True)
-bh = torch.randn(10, requires_grad=True)
+TH = 0.9  # threshold value is inversely proportional to the initial learning rate
+EPS = torch.Tensor([1e-11])
+BIAS = torch.ones(100, device=device).unsqueeze(1)
+SHRINK = torch.tensor(1.1, device=device)
+LRw = torch.ones(100, device=device) * LR  # per output node (neuron) adaptive learning rate (not per parameter/weight)
+LRwh = torch.ones(10, device=device) * LR


-def dist(g1: torch.Tensor, g2: torch.Tensor) -> npt.NDArray[float]:
-    e = 1e-11
+def dist(g1: torch.Tensor, g2: torch.Tensor) -> torch.Tensor:
     flatness_1, flatness_2 = g1.norm().item(), g2.norm().item()
-    dists = np.array([flatness_2, flatness_1])
-    return dists / (dists.sum() + e)
+    dists = torch.Tensor([flatness_2, flatness_1])
+    return dists / (dists.sum() + EPS)
+

+def bounce_update(weight, oracle, weight_gradient, oracle_gradient, weight_LR) -> None:
+    dot_prods: torch.Tensor = torch.einsum("kj,kj->j", weight_gradient, oracle_gradient)
+    for i, dot_prod in enumerate(dot_prods):
+        if dot_prod.item() < 0:
+            # print('bounce')
+            d1, d2 = dist(weight_gradient[:, i], oracle_gradient[:, i])
+            if d1.item() > TH:
+                weight_LR[i] /= SHRINK
+
+            weight[:, i] = weight[:, i] * d1 + oracle[:, i] * d2
+
+        else:
+            weight[:, i] = oracle[:, i] - oracle_gradient[:, i] * weight_LR[i]

-TH = 0.9  # TH value is inversely proportional to "ini"
-shrink = 1.1
-LRw = torch.ones(100) * LR
-LRwh = torch.ones(10) * LR

 losses_bgd = []
 for _ in trange(EPOCHS):
     loss_batch = []
     for x, y in train_loader:
-        out = torch.softmax(torch.tanh(x @ w + b) @ wh + bh, dim=1)
+        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
+        out = torch.cat(((x @ w).tanh(), BIAS), dim=1) @ wh
         loss = F.cross_entropy(out, y)
         loss_batch.append(loss.item())
-        grad_w = torch.autograd.grad(loss, w, retain_graph=True)[0]  # (784, 100)
-        grad_b = torch.autograd.grad(loss, b, retain_graph=True)[0]  # (100)
-        grad_wh = torch.autograd.grad(loss, wh, retain_graph=True)[0]  # (100, 10)
-        grad_bh = torch.autograd.grad(loss, bh)[0]  # (10)
-
-        # Find oracle weights:
-        oracle_w = w - grad_w * LRw
-        oracle_wh = wh - grad_wh * LRwh
+        loss.backward()
+        grad_w, grad_wh = w.grad, wh.grad

-        # Update bias vectors:
-        with torch.no_grad():
-            b -= grad_b * LR
-            bh -= grad_bh * LR
+        with torch.no_grad():  # computations below are untracked as tensors are treated as "detached" tensors
+            # Find oracle weights:
+            w_oracle = (w - grad_w * LRw).requires_grad_()  # (785, 100)
+            wh_oracle = (wh - grad_wh * LRwh).requires_grad_()  # (101, 10)

-        out = torch.softmax(torch.tanh(x @ oracle_w + b) @ oracle_wh + bh, dim=1)
+        out = torch.cat(((x @ w_oracle).tanh(), BIAS), dim=1) @ wh_oracle
         loss = F.cross_entropy(out, y)
-        grad_orc_w = torch.autograd.grad(loss, oracle_w, retain_graph=True)[0]  # (784, 100)
-        grad_orc_wh = torch.autograd.grad(loss, oracle_wh, retain_graph=True)[0]  # (100, 10)
+        loss.backward()
+        grad_w_orc, grad_wh_orc = w_oracle.grad, wh_oracle.grad

         with torch.no_grad():
-            # Update w:
-            for i, (g, g_orc) in enumerate(zip(grad_w.T, grad_orc_w.T)):  # (100, 784)
-                if g @ g_orc < 0:  # (784)
-                    # print('bounce')
-                    d1, d2 = dist(g, g_orc)
-                    if d1 > TH:
-                        LRw[i] /= shrink
-                    w[:, i] = w[:, i] * d1 + oracle_w[:, i] * d2
-                else:
-                    w[:, i] = oracle_w[:, i] - g_orc * LRw[i]
-
-            # Update wh:
-            for i, (g, g_orc) in enumerate(zip(grad_wh.T, grad_orc_wh.T)):  # (10, 100)
-                if g @ g_orc < 0:  # (100)
-                    # print('bounce')
-                    d1, d2 = dist(g, g_orc)
-                    if d1 > TH:
-                        LRwh[i] /= shrink
-                    wh[:, i] = wh[:, i] * d1 + oracle_wh[:, i] * d2
-                else:
-                    wh[:, i] = oracle_wh[:, i] - g_orc * LRwh[i]
+            bounce_update(w, w_oracle, grad_w, grad_w_orc, LRw)
+            bounce_update(wh, wh_oracle, grad_wh, grad_wh_orc, LRwh)
+
+            # Zero each parameter's gradients to avoid gradient accumulation across iterations:
+            w.grad.zero_()
+            wh.grad.zero_()

     losses_bgd.append(np.mean(loss_batch))
 plt.semilogy(losses_bgd, label='BGD')

 acc = []
 with torch.no_grad():
     for x, y in test_loader:
-        out = torch.tanh(x @ w + b) @ wh + bh
+        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
+        out = torch.cat(((x @ w).tanh(), BIAS), dim=1) @ wh
         _, pred = out.max(1)
         acc.append((pred == y).float().mean().item())
 print(f'Accuracy: {np.mean(acc) * 100:.4}%')
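
The bounce rule in bounce_update works per column of a weight matrix (one output neuron). It compares the gradient at the current weights with the gradient at the "oracle" point reached by one plain gradient step. A negative dot product means the step overshot a minimum along that column, so the new column is an interpolation of the current and oracle columns weighted by their relative gradient norms (the interpolation leans toward the flatter side), and the per-neuron learning rate is shrunk when that interpolation is very lopsided; otherwise the update keeps descending past the oracle. A minimal single-column NumPy sketch of the same decision rule; grad_fn and bounce_column are illustrative names, not part of the commit.

import numpy as np

def bounce_column(w_col, lr, grad_fn, TH=0.9, shrink=1.1, eps=1e-11):
    # Sketch of the per-column bounce rule, assuming grad_fn(w_col) returns the
    # gradient of the loss with respect to this single column.
    g = grad_fn(w_col)
    oracle = w_col - lr * g                   # one plain gradient step ahead
    g_orc = grad_fn(oracle)
    if g @ g_orc < 0:                         # gradients disagree in direction: the step overshot a minimum
        d = np.array([np.linalg.norm(g_orc), np.linalg.norm(g)])
        d1, d2 = d / (d.sum() + eps)          # normalized gradient norms, as in dist() above
        if d1 > TH:
            lr /= shrink                      # very lopsided interpolation: damp this column's learning rate
        new_col = d1 * w_col + d2 * oracle    # bounce: interpolate between the two points
    else:
        new_col = oracle - lr * g_orc         # no sign flip: keep descending past the oracle
    return new_col, lr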
