Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixed batch sizes #68

Closed
wants to merge 6 commits into from
Closed

Fixed batch sizes #68

wants to merge 6 commits into from

Conversation

kylebgorman
Copy link
Contributor

@kylebgorman kylebgorman commented May 29, 2023

This is a draft. Important simplifications while I get the basics right:

  • I assume that we are always padding to the max (for source and target); I can make it optional later, though the performance penalty is quite small in my experiments.
  • I am ignoring feature models for now; that can be incorporated later.
  • Prediction isn't tested yet.

I set the actual source_max_length to the min(max(longest source string in train, longest source string in dev), --max_source_length)), and similarly for target length. I then lightly modify the LSTM (it needs to be told the max source length) and the transformer (it needs to make the positional embedding as large as max of the longest source and target string). Everything else is just plumbing.

Closes #50.

It is not plugged into anything yet.

Working on issue CUNY-CL#50.
This is not strictly related to the issue but it came up so I did it. Currying is implemented at a very low level in CPython, and eliminates 3N dictionary lookups.
@kylebgorman kylebgorman requested a review from Adamits May 29, 2023 22:50
Copy link
Collaborator

@Adamits Adamits left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This looks good! I left a couple small nits. Regarding future things:

I assume that we are always padding to the max (for source and target); I can make it optional later, though the performance penalty is quite small in my experiments.

So we would add an option to dynamically pad to the max of a given batch? Yeah since we have the code already to this, we might as well add that option back later. But true, this is less of a concern these days.

I am ignoring feature models for now; that can be incorporated later.

Sounds good. I guess we would just want another argument for max feature size in models that have separate features? The API for this could be a little confusing, probably worth thinking through a bit.

@@ -27,27 +27,32 @@ def __init__(
pad_idx,
config: dataconfig.DataConfig,
arch: str,
max_source_length: int = defaults.MAX_SOURCE_LENGTH,
max_target_length: int = defaults.MAX_TARGET_LENGTH,
max_source_length=defaults.MAX_SOURCE_LENGTH,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this have an annotation Optional[int] ?

max_source_length: int = defaults.MAX_SOURCE_LENGTH,
max_target_length: int = defaults.MAX_TARGET_LENGTH,
max_source_length=defaults.MAX_SOURCE_LENGTH,
max_target_length=defaults.MAX_TARGET_LENGTH,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should this have an annotation Optional[int] ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not if it's specified a kwargs with a default value that isn't None.

@functools.cached_property
def max_source_length(self) -> int:
# " + 2" for start and end tag.
return max(len(source) + 2 for source, _, *_ in self.samples)
Copy link
Collaborator

@Adamits Adamits May 30, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tiny thing, but is max(...) + 2 preferable?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes!

@@ -81,7 +81,7 @@ def encode(
packed_outs,
batch_first=True,
padding_value=self.pad_idx,
total_length=None,
total_length=self.max_source_length,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch on this!

Copy link
Collaborator

@Adamits Adamits left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@kylebgorman
Copy link
Contributor Author

This is really stale so I'm going to close it and reopen next time / if I tackle it in the future.

@kylebgorman kylebgorman closed this Jul 3, 2024
@kylebgorman kylebgorman deleted the max branch December 9, 2024 22:46
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Consistent batch size
2 participants