-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodels.py
99 lines (81 loc) · 2.78 KB
/
models.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
from functools import partial
from typing import Any, Callable, Sequence, Tuple
from flax import linen as nn
import jax.numpy as jnp
ModuleDef = Any
class ResNetBlock(nn.Module):
filters: int
conv: ModuleDef
norm: ModuleDef
act: Callable
strides: Tuple[int, int] = (1, 1)
@nn.compact
def __call__(self, x):
residual = x
y = self.conv(self.filters, (3, 3), self.strides)(x)
y = self.norm()(y)
y = self.act(y)
y = self.conv(self.filters, (3, 3))(y)
y = self.norm(scale_init=nn.initializers.zeros)(y)
if residual.shape != y.shape:
residual = self.conv(self.filters, (1, 1),
self.strides, name='conv_proj')(residual)
residual = self.norm(name='norm_proj')(residual)
return self.act(residual + y)
class BottleneckResNetBlock(nn.Module):
filters: int
conv: ModuleDef
norm: ModuleDef
act: Callable
strides: Tuple[int, int] = (1, 1)
@nn.compact
def __call__(self, x):
residual = x
y = self.conv(self.filters, (1, 1))(x)
y = self.norm()(y)
y = self.act(y)
y = self.conv(self.filters, (3, 3), self.strides)(y)
y = self.norm()(y)
y = self.act(y)
y = self.conv(self.filters * 4, (1, 1))(y)
y = self.norm(scale_init=nn.initializers.zeros)(y)
if residual.shape != y.shape:
residual = self.conv(self.filters * 4, (1, 1),
self.strides, name='conv_proj')(residual)
residual = self.norm(name='norm_proj')(residual)
return self.act(residual + y)
class ResNet(nn.Module):
stage_sizes: Sequence[int]
block_cls: ModuleDef
num_classes: int
num_filters: int = 64
dtype: Any = jnp.float32
act: Callable = nn.relu
@nn.compact
def __call__(self, x, train: bool = True):
conv = partial(nn.Conv, use_bias=False, dtype=self.dtype)
norm = partial(nn.BatchNorm,
use_running_average=not train,
momentum=0.8,
epsilon=1e-5,
dtype=self.dtype)
x = conv(self.num_filters, (7, 7), (2, 2),
padding=[(3, 3), (3, 3)],
name='conv_init')(x)
x = norm(name='bn_init')(x)
x = nn.relu(x)
x = nn.max_pool(x, (3, 3), strides=(2, 2), padding='SAME')
for i, block_size in enumerate(self.stage_sizes):
for j in range(block_size):
strides = (2, 2) if i > 0 and j == 0 else (1, 1)
x = self.block_cls(self.num_filters * 2 ** i,
strides=strides,
conv=conv,
norm=norm,
act=self.act)(x)
x = jnp.mean(x, axis=(1, 2))
x = nn.Dense(self.num_classes, dtype=self.dtype)(x)
x = jnp.asarray(x, self.dtype)
return x
ResNet50 = partial(ResNet, stage_sizes=[3, 4, 6, 3],
block_cls=BottleneckResNetBlock)