Fixed batch sizes #68
Conversation
It is not plugged into anything yet. Working on issue CUNY-CL#50.
This is not strictly related to the issue, but it came up, so I did it. Currying is implemented at a very low level in CPython, and eliminates 3N dictionary lookups.
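For context, a minimal sketch of the kind of currying this refers to, using `functools.partial` (the helper `pad` and its arguments here are hypothetical, not from the PR):

```python
import functools


def pad(sequence, pad_idx, length):
    """Pad a list out to `length` with `pad_idx` (hypothetical helper)."""
    return sequence + [pad_idx] * (length - len(sequence))


# functools.partial is implemented in C: the bound arguments are stored
# once at construction time, so repeated calls avoid the per-call name
# and attribute lookups that a lambda or plain closure would incur.
pad_fn = functools.partial(pad, pad_idx=0, length=5)

print(pad_fn([1, 2, 3]))  # [1, 2, 3, 0, 0]
```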
This looks good! I left a couple small nits. Regarding future things:
I assume that we are always padding to the max (for source and target); I can make it optional later, though the performance penalty is quite small in my experiments.
So we would add an option to dynamically pad to the max of a given batch? Yeah, since we already have the code to do this, we might as well add that option back later. But true, this is less of a concern these days.
I am ignoring feature models for now; that can be incorporated later.
Sounds good. I guess we would just want another argument for max feature size in models that have separate features? The API for this could be a little confusing, probably worth thinking through a bit.
@@ -27,27 +27,32 @@ def __init__(
        pad_idx,
        config: dataconfig.DataConfig,
        arch: str,
        max_source_length: int = defaults.MAX_SOURCE_LENGTH,
        max_target_length: int = defaults.MAX_TARGET_LENGTH,
        max_source_length=defaults.MAX_SOURCE_LENGTH,
Should this have an annotation Optional[int]?
        max_source_length: int = defaults.MAX_SOURCE_LENGTH,
        max_target_length: int = defaults.MAX_TARGET_LENGTH,
        max_source_length=defaults.MAX_SOURCE_LENGTH,
        max_target_length=defaults.MAX_TARGET_LENGTH,
Should this have an annotation Optional[int]?
Not if it's specified as a kwarg with a default value that isn't None.
yoyodyne/datasets.py
Outdated
@functools.cached_property
def max_source_length(self) -> int:
    # " + 2" for start and end tag.
    return max(len(source) + 2 for source, _, *_ in self.samples)
Tiny thing, but is max(...) + 2 preferable?
Yes!
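For concreteness, a minimal sketch of the equivalence behind this nit (the sample data is hypothetical): adding 2 once after taking the max gives the same result as adding it per sample inside the generator, with one fewer addition per sample.

```python
# Hypothetical (source, target) samples; only source lengths matter here.
samples = [([1, 2], "a"), ([1, 2, 3], "b"), ([1], "c")]

# Original form: " + 2" inside the generator, applied to every sample.
before = max(len(source) + 2 for source, *_ in samples)

# Suggested form: take the max first, then add 2 once.
after = max(len(source) for source, *_ in samples) + 2

print(before, after)  # 5 5
```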
@@ -81,7 +81,7 @@ def encode(
        packed_outs,
        batch_first=True,
        padding_value=self.pad_idx,
        total_length=None,
        total_length=self.max_source_length,
Good catch on this!
LGTM
This is really stale, so I'm going to close it and reopen it next time, if and when I tackle it in the future.
This is a draft. Important simplifications while I get the basics right:
I set the actual max_source_length to min(max(longest source string in train, longest source string in dev), --max_source_length), and similarly for the target length. I then lightly modify the LSTM (it needs to be told the max source length) and the transformer (it needs to make the positional embedding as large as the max of the longest source and target string). Everything else is just plumbing.
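A minimal sketch of that length computation, with hypothetical function and variable names (only the min/max logic is from the description above):

```python
def effective_max_length(train_lengths, dev_lengths, cli_max_length):
    """Pad to the longest string observed across train and dev,
    capped by the command-line maximum (e.g. --max_source_length).

    Hypothetical helper illustrating the computation described in
    the PR: min(max(train max, dev max), CLI max).
    """
    observed = max(max(train_lengths), max(dev_lengths))
    return min(observed, cli_max_length)


# The cap only kicks in when the observed max exceeds the flag value.
print(effective_max_length([4, 9, 7], [5, 11], 128))  # 11
print(effective_max_length([4, 9, 7], [5, 11], 8))    # 8
```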
Closes #50.