Add part: lstm block #66
base: main
Conversation
One minor Q.
i6_models/parts/lstm.py (outdated)

enforce_sorted: bool

@classmethod
def from_dict(cls, model_cfg_dict: Dict):
Same Q as in the other PR: why is this necessary now, and hasn't been for the other assemblies?
Co-authored-by: Albert Zeyer <[email protected]>
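For context on the from_dict question above, here is a minimal sketch of that classmethod pattern; apart from enforce_sorted, the field names are placeholders and not taken from the actual i6_models config:

```python
from dataclasses import dataclass
from typing import Dict


@dataclass
class LstmBlockConfigSketch:
    # Hypothetical fields; only enforce_sorted appears in the quoted diff.
    input_dim: int
    hidden_dim: int
    enforce_sorted: bool

    @classmethod
    def from_dict(cls, model_cfg_dict: Dict) -> "LstmBlockConfigSketch":
        # Rebuild the config from a plain dict, e.g. one coming from a serialized setup.
        return cls(**model_cfg_dict)
```

Whether such a helper is needed here, or the dict could simply be unpacked into the constructor at the call site, is exactly the question raised above.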
if seq_len.get_device() >= 0:
    seq_len = seq_len.cpu()
Suggested change:

- if seq_len.get_device() >= 0:
-     seq_len = seq_len.cpu()
+ seq_len = seq_len.cpu()
)

def forward(self, x: torch.Tensor, seq_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    if not torch.jit.is_scripting() and not torch.jit.is_tracing():
Why only when not scripting? Don't you want seq_len to always be on the CPU?
I followed the example in the blstm part.
if not torch.jit.is_scripting() and not torch.jit.is_tracing():
    # during graph mode we have to assume all Tensors are on the correct device,
    # otherwise move lengths to the CPU if they are on GPU
    if seq_len.get_device() >= 0:
        seq_len = seq_len.cpu()
I did not copy the comment over... I have not yet had a chance to look into why this is necessary.
@JackTemaki you implemented the BLSTM, IIRC. Do you remember why it was done this way?
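For reference, a self-contained sketch of the pattern under discussion (the jit guard plus the CPU move as quoted from the BLSTM part); the module name, layer sizes, and packing/unpacking details are illustrative, not the actual i6_models implementation:

```python
import torch
from torch import nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from typing import Tuple


class LstmSketch(nn.Module):
    """Illustrative module, not the actual i6_models LSTM part."""

    def __init__(self, input_dim: int = 80, hidden_dim: int = 512):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True)

    def forward(self, x: torch.Tensor, seq_len: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        if not torch.jit.is_scripting() and not torch.jit.is_tracing():
            # pack_padded_sequence expects the lengths on the CPU; in eager mode they may
            # arrive on the GPU together with the features, so move them over first.
            # get_device() is -1 for CPU tensors and the device index (>= 0) for CUDA tensors.
            if seq_len.get_device() >= 0:
                seq_len = seq_len.cpu()
        packed = pack_padded_sequence(x, seq_len, batch_first=True, enforce_sorted=False)
        out, _ = self.lstm(packed)
        out, _ = pad_packed_sequence(out, batch_first=True)
        return out, seq_len


# Usage sketch:
model = LstmSketch()
x = torch.randn(3, 5, 80)          # [batch, time, feature]
seq_len = torch.tensor([5, 3, 2])  # on CUDA this is what the guard moves to the CPU
out, seq_len = model(x, seq_len)   # out: [3, 5, 512]
```

In eager mode this works whether seq_len arrives on the CPU or the GPU, since torch.nn.utils.rnn.pack_padded_sequence requires its lengths argument on the CPU; under scripting or tracing the guard is skipped and the lengths are assumed to already be on the right device.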
Co-authored-by: Albert Zeyer <[email protected]>
Adds LSTM Block