
Commit: weight_c_init
hpasapp committed Dec 15, 2020
1 parent 1569f31 commit 5688728
Showing 1 changed file with 16 additions and 5 deletions.
21 changes: 16 additions & 5 deletions sru/modules.py
@@ -41,7 +41,8 @@ def __init__(self,
                  rescale: bool = True,
                  v1: bool = False,
                  custom_m: Optional[nn.Module] = None,
-                 amp_recurrence_fp16: bool = False):
+                 amp_recurrence_fp16: bool = False,
+                 weight_c_init: Optional[float] = None):
         """Initialize the SRUCell module.
         Parameters
@@ -94,6 +95,8 @@ def __init__(self,
             When using AMP autocast, selects which type to use
             for recurrence custom kernel.
             False: torch.float32, True: torch.float16
+        weight_c_init: Optional[float]
+            if not None, initialize weight_c uniformly in [-weight_c_init, weight_c_init]
         """
         super(SRUCell, self).__init__()
         self.input_size = input_size
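
A minimal usage sketch of the new argument (illustrative only, not part of this commit; the size values are arbitrary examples):

    from sru.modules import SRUCell

    # default behavior: weight_c drawn from U(-sqrt(3), sqrt(3)), then rescaled by sqrt(0.5)
    cell_default = SRUCell(input_size=256, hidden_size=256)

    # new behavior: weight_c drawn uniformly from [-0.05, 0.05]
    cell_custom = SRUCell(input_size=256, hidden_size=256, weight_c_init=0.05)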
@@ -113,6 +116,7 @@ def __init__(self,
             self.activation_type = 1
             self.activation = 'tanh'
         self.amp_recurrence_fp16 = amp_recurrence_fp16
+        self.weight_c_init = weight_c_init
 
         # projection dimension
         self.projection_size = 0
@@ -207,13 +211,16 @@ def reset_parameters(self):
 
         if not self.v1:
             # initialize weight_c such that E[w]=0 and Var[w]=1
-            self.weight_c.data.uniform_(-3.0**0.5, 3.0**0.5)
+            if self.weight_c_init is None:
+                self.weight_c.data.uniform_(-3.0**0.5, 3.0**0.5)
+                self.weight_c.data.mul_(0.5**0.5)
+            else:
+                self.weight_c.data.uniform_(-self.weight_c_init, self.weight_c_init)
 
             # rescale weight_c and the weight of sigmoid gates with a factor of sqrt(0.5)
             if self.custom_m is None:
                 w[:, :, :, 1].mul_(0.5**0.5)
                 w[:, :, :, 2].mul_(0.5**0.5)
-                self.weight_c.data.mul_(0.5**0.5)
         else:
             self.weight_c.data.zero_()
             self.weight_c.requires_grad = False
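
For reference (not part of the commit): U(-a, a) has mean 0 and variance a**2/3, so drawing from U(-sqrt(3), sqrt(3)) gives Var[w] = 1, and the subsequent mul_(0.5**0.5) halves it to 0.5. A standalone numerical check, assuming only PyTorch:

    import torch

    # sample the default initialization used above
    w = torch.empty(1_000_000).uniform_(-3.0**0.5, 3.0**0.5)
    print(w.mean().item(), w.var().item())   # approx. 0.0 and 1.0: E[w]=0, Var[w]=1

    # apply the sqrt(0.5) rescaling
    w.mul_(0.5**0.5)
    print(w.var().item())                    # approx. 0.5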
@@ -427,7 +434,8 @@ def __init__(self,
                  nn_rnn_compatible_return: bool = False,
                  custom_m: Optional[Union[nn.Module, List[nn.Module]]] = None,
                  proj_input_to_hidden_first: bool = False,
-                 amp_recurrence_fp16: bool = False):
+                 amp_recurrence_fp16: bool = False,
+                 weight_c_init: Optional[float] = None):
         """Initialize the SRU module.
         Parameters
@@ -486,6 +494,8 @@ def __init__(self,
             When using AMP autocast, selects which type to use
             for recurrence custom kernel.
             False: torch.float32, True: torch.float16
+        weight_c_init: Optional[float]
+            if not None, initialize weight_c uniformly in [-weight_c_init, weight_c_init]
         """
 
@@ -538,7 +548,8 @@ def __init__(self,
                 rescale=rescale,
                 v1=v1,
                 custom_m=custom_m_i,
-                amp_recurrence_fp16=amp_recurrence_fp16
+                amp_recurrence_fp16=amp_recurrence_fp16,
+                weight_c_init=weight_c_init,
             )
             rnn_lst.append(layer_i)
         self.rnn_lst = rnn_lst
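
An end-to-end sketch of the multi-layer module (illustrative only; shapes and values are arbitrary examples). The same weight_c_init value is forwarded to every SRUCell built in the loop above:

    import torch
    from sru import SRU

    rnn = SRU(input_size=128, hidden_size=128, num_layers=2, weight_c_init=0.05)
    x = torch.randn(32, 8, 128)      # (seq_len, batch, input_size)
    output, state = rnn(x)           # output: (seq_len, batch, hidden_size)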