Fixed batch sizes #68

Closed
wants to merge 6 commits
26 changes: 14 additions & 12 deletions README.md
@@ -85,11 +85,11 @@ Checkpointing is handled by
[Lightning](https://pytorch-lightning.readthedocs.io/en/stable/common/checkpointing_basic.html).
The path for model information, including checkpoints, is specified by a
combination of `--model_dir` and `--experiment`, such that we build the path
`model_dir/experiment/version_n`, where each run of an experiment with the
same `model_dir` and `experiment` is namespaced with a new version number.
A version stores everything needed to reload the model, including the
hyperparameters (`model_dir/experiment_name/version_n/hparams.yaml`) and the
checkpoints directory (`model_dir/experiment_name/version_n/checkpoints`).
`model_dir/experiment/version_n`, where each run of an experiment with the same
`model_dir` and `experiment` is namespaced with a new version number. A version
stores everything needed to reload the model, including the hyperparameters
(`model_dir/experiment_name/version_n/hparams.yaml`) and the checkpoints
directory (`model_dir/experiment_name/version_n/checkpoints`).
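
As an aside, here is a tiny sketch of how that layout resolves in practice; all names below are made-up placeholders, not values from this diff:

```python
from pathlib import Path

model_dir = Path("results")          # --model_dir (placeholder)
experiment = "transformer-baseline"  # --experiment (placeholder)
version = 0                          # assigned automatically per run

run_dir = model_dir / experiment / f"version_{version}"
hparams_path = run_dir / "hparams.yaml"    # hyperparameters
checkpoints_dir = run_dir / "checkpoints"  # checkpoints
print(run_dir)  # results/transformer-baseline/version_0
```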

By default, each run initializes a new model from scratch, unless the
`--train_from` argument is specified. To continue training from a specific
@@ -151,31 +151,33 @@ By default, the `attentive_lstm`, `lstm`, `pointer_generator_lstm`, and

A non-exhaustive list includes:

- Batch size:
- `--batch_size` (default: `32`)
- Regularization:
- `--dropout` (default: `.2`)
- `--label_smoothing` (default: not enabled)
- `--gradient_clip_val` (default: not enabled)
- Optimizer:
- Optimization:
- `--learning_rate` (default: `.001`)
- `--optimizer` (default: `"adam"`)
- `--beta1` (default: `.9`): $\beta_1$ hyperparameter for the Adam
optimizer (`--optimizer adam`)
- `--beta2` (default: `.99`): $\beta_2$ hyperparameter for the Adam
optimizer (`--optimizer adam`)
- `--scheduler` (default: not enabled)
- Duration:
- Training duration:
- `--max_epochs`
- `--min_epochs`
- `--max_steps`
- `--min_steps`
- `--max_time`
- `--patience`
- Seeding:
- Sequence length:
- `--max_source_length` (default: `128`)
- `--max_target_length` (default: `128`)
- Other:
- `--batch_size` (default: `32`)
- `--seed`
- [Weights & Biases](https://wandb.ai/site):
- `--wandb` (default: `False`): enables Weights & Biases tracking
- `--wandb` (default: `False`): enables [Weights &
Biases](https://wandb.ai/site) tracking

**No neural model should be deployed without proper hyperparameter tuning.**
However, the default options give reasonable initial settings for an attentive
23 changes: 11 additions & 12 deletions yoyodyne/batches.py
@@ -23,34 +23,33 @@ def __init__(
self,
tensorlist: List[torch.Tensor],
pad_idx: int,
length: int,
length_msg_callback: Optional[Callable[[int], None]] = None,
pad_len: Optional[int] = None,
):
"""Constructs the padded tensor from a list of tensors.

The optional pad_len argument can be used, e.g., to keep all batches
The optional length argument can be used, e.g., to keep all batches
the exact same length, which improves performance on certain
accelerators. If not specified, it will be computed using the length
of the longest input tensor.

Args:
tensorlist (List[torch.Tensor]): a list of tensors.
pad_idx (int): padding index.
length_msg_callback (Callable[[int], None]): callback for catching
a violating tensor length.
pad_len (int, optional): desired length for padding.
length (int): desired padded length.
length_msg_callback (Callable[[int], None]): callback which flags
invalid tensor lengths.

"""
super().__init__()
if pad_len is None:
pad_len = max(len(tensor) for tensor in tensorlist)
if length_msg_callback is not None:
length_msg_callback(pad_len)
batch_length = max(len(tensor) for tensor in tensorlist)
length_msg_callback(batch_length)
self.register_buffer(
"padded",
torch.stack(
[
self.pad_tensor(tensor, pad_idx, pad_len)
self.pad_tensor(tensor, pad_idx, length)
for tensor in tensorlist
],
),
@@ -59,19 +58,19 @@ def __init__(

@staticmethod
def pad_tensor(
tensor: torch.Tensor, pad_idx: int, pad_max: int
tensor: torch.Tensor, pad_idx: int, length: int
) -> torch.Tensor:
"""Pads a tensor.

Args:
tensor (torch.Tensor).
pad_idx (int): padding index.
pad_max (int): desired tensor length.
length (int): desired tensor length.

Returns:
torch.Tensor.
"""
padding = pad_max - len(tensor)
padding = length - len(tensor)
return nn.functional.pad(tensor, (0, padding), "constant", pad_idx)

def __len__(self) -> int:
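As a quick illustration of the fixed-length padding this PR introduces, the following self-contained sketch mirrors `pad_tensor` and the stacking step; it is an editorial example, not part of the module:

```python
import torch
from torch import nn


def pad_tensor(tensor: torch.Tensor, pad_idx: int, length: int) -> torch.Tensor:
    # Right-pads `tensor` with `pad_idx` up to `length` elements.
    return nn.functional.pad(tensor, (0, length - len(tensor)), "constant", pad_idx)


tensorlist = [torch.tensor([4, 5, 6]), torch.tensor([7, 8])]
fixed_length = 6  # e.g., supplied from --max_source_length
padded = torch.stack(
    [pad_tensor(t, pad_idx=0, length=fixed_length) for t in tensorlist]
)
print(padded.shape)  # torch.Size([2, 6])
```

Padding every batch to the same requested length is what yields the "exact same length" batches described in the docstring above.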
32 changes: 22 additions & 10 deletions yoyodyne/collators.py
@@ -27,27 +27,32 @@ def __init__(
pad_idx,
config: dataconfig.DataConfig,
arch: str,
max_source_length: int = defaults.MAX_SOURCE_LENGTH,
max_target_length: int = defaults.MAX_TARGET_LENGTH,
max_source_length=defaults.MAX_SOURCE_LENGTH,
Collaborator comment: Should this have an annotation `Optional[int]`?

max_target_length=defaults.MAX_TARGET_LENGTH,
Collaborator comment: Should this have an annotation `Optional[int]`?
Author reply: Not if it's specified as a kwarg with a default value that isn't None.
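
To illustrate the convention at stake in this exchange (hypothetical functions, not from this repository): `Optional[int]` is warranted when `None` is a legal value, whereas a keyword argument with a non-None default can stay plain `int`:

```python
from typing import Optional


def pad(length: int = 128) -> int:
    # Default is a real int, so the annotation remains `int`.
    return length


def pad_maybe(length: Optional[int] = None) -> int:
    # Here None means "compute it later", so Optional[int] is appropriate.
    return 128 if length is None else length
```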

# FIXME: max features length.
):
"""Initializes the collator.

If max source and/or target length are not specified, they are
computed dynamically (i.e., for each source and target batch).

Args:
pad_idx (int).
config (dataconfig.DataConfig).
arch (str).
max_source_length (int).
max_target_length (int).
max_source_length (int, optional).
max_target_length (int, optional).
"""
self.pad_idx = pad_idx
self.has_features = config.has_features
self.has_target = config.has_target
self.max_source_length = max_source_length
self.max_target_length = max_target_length
self.separate_features = config.has_features and arch in [
"pointer_generator_lstm",
"transducer",
]
self.max_source_length = max_source_length
self.max_target_length = max_target_length
# FIXME: max features length.

def _source_length_error(self, padded_length: int):
"""Callback function to raise the error when the padded length of the
@@ -69,8 +74,8 @@ def _target_length_warning(self, padded_length: int):
"""Callback function to log a message when the padded length of the
target batch is greater than the `max_target_length` allowed.

Since `max_target_length` just truncates during inference, this is
simply a suggestion.
Since `max_target_length` truncates during inference, this is only a
suggestion; it is logged rather than converted into an error.

Args:
padded_length (int): The length of the padded tensor.
@@ -79,7 +84,7 @@ def _target_length_warning(self, padded_length: int):
msg = f"The length of a batch ({padded_length}) "
msg += "is greater than the `--max_target_length` specified "
msg += f"({self.max_target_length}). This means that "
msg += "decoding at inference time will likely be truncated. "
msg += "truncation may occur at inference time. "
msg += "Consider increasing `--max_target_length`."
util.log_info(msg)
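
A minimal standalone sketch of the distinction drawn by these two callbacks; the constants, names, and logging call are simplified stand-ins for the module's actual helpers:

```python
import logging

MAX_SOURCE_LENGTH = 128  # stand-in for --max_source_length
MAX_TARGET_LENGTH = 128  # stand-in for --max_target_length


def source_length_error(padded_length: int) -> None:
    # Overlong sources cannot be truncated safely, so this is fatal.
    if padded_length > MAX_SOURCE_LENGTH:
        raise ValueError(
            f"Batch length {padded_length} exceeds --max_source_length "
            f"({MAX_SOURCE_LENGTH})"
        )


def target_length_warning(padded_length: int) -> None:
    # Overlong targets are merely truncated at inference, so only log.
    if padded_length > MAX_TARGET_LENGTH:
        logging.warning(
            "Batch length %d exceeds --max_target_length (%d); "
            "truncation may occur at inference time.",
            padded_length,
            MAX_TARGET_LENGTH,
        )
```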

@@ -111,6 +116,7 @@ def pad_source(
return batches.PaddedTensor(
[item.source for item in itemlist],
self.pad_idx,
self.max_source_length,
self._source_length_error,
)

@@ -129,6 +135,8 @@ def pad_source_features(
return batches.PaddedTensor(
self.concatenate_source_and_features(itemlist),
self.pad_idx,
self.max_source_length,
# FIXME: max features length.
self._source_length_error,
)

@@ -145,7 +153,9 @@ def pad_features(
batches.PaddedTensor.
"""
return batches.PaddedTensor(
[item.features for item in itemlist], self.pad_idx
[item.features for item in itemlist],
self.pad_idx,
# FIXME: max features length.
)

def pad_target(
@@ -162,6 +172,7 @@ def pad_target(
return batches.PaddedTensor(
[item.target for item in itemlist],
self.pad_idx,
self.max_target_length,
self._target_length_warning,
)

@@ -205,3 +216,4 @@ def add_argparse_args(parser: argparse.ArgumentParser) -> None:
default=defaults.MAX_TARGET_LENGTH,
help="Maximum target string length. Default: %(default)s.",
)
# FIXME: max features length.
58 changes: 33 additions & 25 deletions yoyodyne/dataconfig.py
@@ -3,6 +3,7 @@
import argparse
import csv
import dataclasses
import functools
import inspect
from typing import Iterator, List, Tuple

@@ -44,6 +45,22 @@ class DataConfig:
features_sep: str = defaults.FEATURES_SEP
tied_vocabulary: bool = defaults.TIED_VOCABULARY

@staticmethod
def _get_cell(row: List[str], col: int, sep: str) -> List[str]:
"""Returns the split cell of a row.

Args:
row (List[str]): the split row.
col (int): the column index
sep (str): the string to split the column on; if the empty string,
the column is split into characters instead.

Returns:
List[str]: symbols from that cell.
"""
cell = row[col - 1] # -1 because we're using one-based indexing.
return list(cell) if not sep else cell.split(sep)

def __post_init__(self) -> None:
# This is automatically called after initialization.
if self.source_col < 1:
@@ -56,6 +73,16 @@ def __post_init__(self) -> None:
raise Error(f"Invalid features column: {self.features_col}")
if self.features_col != 0:
util.log_info("Including features")
# Curries repeatedly-used lookup functions.
self._get_source_cell = functools.partial(
self._get_cell, col=self.source_col, sep=self.source_sep
)
self._get_target_cell = functools.partial(
self._get_cell, col=self.target_col, sep=self.target_sep
)
self._get_features_cell = functools.partial(
self._get_cell, col=self.features_col, sep=self.features_sep
)
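
For readers unfamiliar with the currying above, here is a tiny self-contained illustration of what `functools.partial` buys; the row and separators are invented:

```python
import functools
from typing import List


def get_cell(row: List[str], col: int, sep: str) -> List[str]:
    # One-based column index; an empty separator splits into characters.
    cell = row[col - 1]
    return list(cell) if not sep else cell.split(sep)


row = ["walk", "walked", "V;PST"]
get_source_cell = functools.partial(get_cell, col=1, sep="")
get_features_cell = functools.partial(get_cell, col=3, sep=";")
print(get_source_cell(row))    # ['w', 'a', 'l', 'k']
print(get_features_cell(row))  # ['V', 'PST']
```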

@classmethod
def from_argparse_args(cls, args, **kwargs):
@@ -68,22 +95,6 @@ def from_argparse_args(cls, args, **kwargs):
dataconfig_kwargs.update(**kwargs)
return cls(**dataconfig_kwargs)

@staticmethod
def _get_cell(row: List[str], col: int, sep: str) -> List[str]:
"""Returns the split cell of a row.

Args:
row (List[str]): the split row.
col (int): the column index
sep (str): the string to split the column on; if the empty string,
the column is split into characters instead.

Returns:
List[str]: symbol from that cell.
"""
cell = row[col - 1] # -1 because we're using one-based indexing.
return list(cell) if not sep else cell.split(sep)

# Source is always present.

@property
@@ -99,7 +110,7 @@ def source_samples(self, filename: str) -> Iterator[List[str]]:
with open(filename, "r") as source:
tsv_reader = csv.reader(source, delimiter="\t")
for row in tsv_reader:
yield self._get_cell(row, self.source_col, self.source_sep)
yield self._get_source_cell(row)

def source_target_samples(
self, filename: str
@@ -108,8 +119,8 @@ def source_target_samples(
with open(filename, "r") as source:
tsv_reader = csv.reader(source, delimiter="\t")
for row in tsv_reader:
source = self._get_cell(row, self.source_col, self.source_sep)
target = self._get_cell(row, self.target_col, self.target_sep)
source = self._get_source_cell(row)
target = self._get_target_cell(row)
yield source, target

def source_features_target_samples(
@@ -119,15 +130,12 @@ def source_features_target_samples(
with open(filename, "r") as source:
tsv_reader = csv.reader(source, delimiter="\t")
for row in tsv_reader:
source = self._get_cell(row, self.source_col, self.source_sep)
source = self._get_source_cell(row)
# Avoids overlap with source.
features = [
f"[{feature}]"
for feature in self._get_cell(
row, self.features_col, self.features_sep
)
f"[{feature}]" for feature in self._get_features_cell(row)
]
target = self._get_cell(row, self.target_col, self.target_sep)
target = self._get_target_cell(row)
yield source, features, target

def samples(self, filename: str) -> Iterator[Tuple[List[str], ...]]:
17 changes: 16 additions & 1 deletion yoyodyne/datasets.py
@@ -5,6 +5,7 @@
Trainer to move them to the appropriate device."""

import abc
import functools
from typing import List, Optional, Set, Union

import torch
@@ -75,7 +76,7 @@ def __init__(
Args:
filename (str): input filename.
config (dataconfig.DataConfig): dataset configuration.
other (indexes.IndexNoFeatures, optional): if provided,
index (indexes.IndexNoFeatures, optional): if provided,
use this index to avoid recomputing it.
"""
super().__init__()
@@ -113,6 +114,16 @@ def _make_index(self) -> indexes.IndexNoFeatures:
sorted(source_vocabulary), sorted(target_vocabulary)
)

@functools.cached_property
def max_source_length(self) -> int:
# " + 2" for start and end tag.
return max(len(source) for source, _, *_ in self.samples) + 2

@functools.cached_property
def max_target_length(self) -> int:
# " + 1" for end tag.
return max(len(target) for _, target, *_ in self.samples) + 1

def encode(
self,
symbol_map: indexes.SymbolMap,
@@ -307,6 +318,10 @@ def _make_index(self) -> indexes.IndexFeatures:
sorted(target_vocabulary),
)

@functools.cached_property
def max_features_length(self) -> int:
return max(len(features) for _, _, features in self.samples)

def __getitem__(self, idx: int) -> Item:
"""Retrieves item by index.

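Finally, a toy sketch of the cached length computation added in this file; the sample data and class are illustrative stand-ins, not the real dataset classes (the +2/+1 offsets follow the start/end-tag comments above):

```python
import functools
from typing import List, Tuple


class ToyDataset:
    def __init__(self, samples: List[Tuple[str, str]]):
        # Each sample is (source, target); symbols are characters here.
        self.samples = samples

    @functools.cached_property
    def max_source_length(self) -> int:
        # + 2 for the start and end tags.
        return max(len(source) for source, _ in self.samples) + 2

    @functools.cached_property
    def max_target_length(self) -> int:
        # + 1 for the end tag.
        return max(len(target) for _, target in self.samples) + 1


data = ToyDataset([("walk", "walked"), ("run", "ran")])
print(data.max_source_length)  # 6
print(data.max_target_length)  # 7
```

Because these are `cached_property` attributes, the scan over `samples` happens once per dataset rather than once per batch.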