
Commit

Merge pull request #149 from asappresearch/smaller-weight-c
weight_c_init
taoleicn authored Dec 17, 2020
2 parents 318ab42 + 7c27bea commit 4f691ec
Showing 1 changed file with 17 additions and 6 deletions.
sru/modules.py (17 additions, 6 deletions)
@@ -22,7 +22,7 @@ class SRUCell(nn.Module):
'dropout', 'bidirectional', 'has_skip_term', 'highway_bias',
'v1', 'rescale', 'activation_type', 'activation', 'custom_m',
'projection_size', 'num_matrices', 'layer_norm', 'weight_proj',
'scale_x']
'scale_x', 'normalize_after', 'weight_c_init',]

scale_x: Tensor
weight_proj: Optional[Tensor]
@@ -42,7 +42,8 @@ def __init__(self,
v1: bool = False,
custom_m: Optional[nn.Module] = None,
amp_recurrence_fp16: bool = False,
normalize_after: bool = False):
normalize_after: bool = False,
weight_c_init: Optional[float] = None):
"""Initialize the SRUCell module.
Parameters
@@ -97,6 +98,8 @@ def __init__(self,
False: torch.float32, True: torch.float16
normalize_after: bool
if True use post layer norm, else pre layer norm
weight_c_init: Optional[float]
if not None, then the bound of the uniform initialization of weight_c
"""
super(SRUCell, self).__init__()
self.input_size = input_size
@@ -117,6 +120,7 @@ def __init__(self,
self.activation = 'tanh'
self.amp_recurrence_fp16 = amp_recurrence_fp16
self.normalize_after = normalize_after
self.weight_c_init = weight_c_init

# projection dimension
self.projection_size = 0
@@ -214,13 +218,16 @@ def reset_parameters(self):

if not self.v1:
# initialize weight_c such that E[w]=0 and Var[w]=1
self.weight_c.data.uniform_(-3.0**0.5, 3.0**0.5)
if self.weight_c_init is None:
self.weight_c.data.uniform_(-3.0**0.5, 3.0**0.5)
self.weight_c.data.mul_(0.5**0.5)
else:
self.weight_c.data.uniform_(-self.weight_c_init, self.weight_c_init)

# rescale weight_c and the weight of sigmoid gates with a factor of sqrt(0.5)
if self.custom_m is None:
w[:, :, :, 1].mul_(0.5**0.5)
w[:, :, :, 2].mul_(0.5**0.5)
self.weight_c.data.mul_(0.5**0.5)
else:
self.weight_c.data.zero_()
self.weight_c.requires_grad = False
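
For reference, Uniform(-a, a) has mean 0 and variance a**2 / 3, so the default Uniform(-sqrt(3), sqrt(3)) draw has unit variance and the sqrt(0.5) rescaling halves it, while the new weight_c_init path yields variance weight_c_init**2 / 3. A minimal sketch (not part of this diff) checking that reasoning:

```python
import torch

# Default path: Uniform(-sqrt(3), sqrt(3)) has mean 0 and variance 1;
# the subsequent mul_(0.5**0.5) halves the variance to 0.5.
w = torch.empty(1_000_000).uniform_(-3.0**0.5, 3.0**0.5)
w.mul_(0.5**0.5)
print(w.mean().item(), w.var().item())  # ~0.0, ~0.5

# New path: Uniform(-b, b) has variance b**2 / 3, so a small bound
# (0.05 here is an illustrative value) gives a much tighter init.
b = 0.05
w2 = torch.empty(1_000_000).uniform_(-b, b)
print(w2.var().item(), b**2 / 3)  # both ~8.3e-4
```
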
@@ -439,7 +446,8 @@ def __init__(self,
custom_m: Optional[Union[nn.Module, List[nn.Module]]] = None,
proj_input_to_hidden_first: bool = False,
amp_recurrence_fp16: bool = False,
normalize_after: bool = False):
normalize_after: bool = False,
weight_c_init: Optional[float] = None):
"""Initialize the SRU module.
Parameters
@@ -500,6 +508,8 @@ def __init__(self,
False: torch.float32, True: torch.float16
normalize_after: bool
if True use post layer norm, else use pre layer norm
weight_c_init: Optional[float]
if not None, then the bound of the uniform initialization of weight_c
"""

super(SRU, self).__init__()
@@ -552,7 +562,8 @@ def __init__(self,
v1=v1,
custom_m=custom_m_i,
amp_recurrence_fp16=amp_recurrence_fp16,
normalize_after=normalize_after
normalize_after=normalize_after,
weight_c_init=weight_c_init,
)
rnn_lst.append(layer_i)
self.rnn_lst = rnn_lst
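
A hedged usage sketch of the new argument (the hyperparameter values below are illustrative, not from this commit); omitting weight_c_init keeps the default initialization:

```python
import torch
from sru import SRU

# weight_c_init=0.05 is an illustrative bound; passing None (the default)
# keeps the original Uniform(-sqrt(3), sqrt(3)) * sqrt(0.5) initialization.
rnn = SRU(input_size=128, hidden_size=128, num_layers=2, weight_c_init=0.05)

x = torch.randn(10, 4, 128)   # (seq_len, batch, input_size)
output, state = rnn(x)
print(output.shape)           # torch.Size([10, 4, 128])
```
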
