added support for using multiple losses and metrics in evaluator #8

Closed
Changes from 31 commits (44 commits total)
7ef1362
feat(evaluator): added support for using multiple losses and metrics …
BlueCrescent Jan 15, 2024
69134ad
chore: Merge branch 'main' into feat/multiple_losses_and_metrics_in_e…
le1nux Jan 19, 2024
af47020
chore: Merge branch 'main' into feat/multiple_losses_and_metrics_in_e…
le1nux Jan 21, 2024
abd2268
fix: num_samples added to throughput_aggregator were not of type tensor
le1nux Jan 21, 2024
3519217
feat: started implementation of Aggregator and StatefulMetrics
le1nux Jan 23, 2024
48e4a89
feat: towards generic measure evaluation
le1nux Jan 29, 2024
cd15bbd
feat(evaluation): Adding generic evaluation measure for training
Jan 29, 2024
317978b
feat(evaluation): various updates
BlueCrescent Feb 5, 2024
2678d05
Merge remote-tracking branch 'origin/main' into feat/multiple_losses_…
BlueCrescent Feb 5, 2024
efd97e6
refactor(evaluation): integration of evaluators in training
Feb 5, 2024
5034018
refactor(evaluation): adaptions to loss function; using conftest.py f…
Feb 5, 2024
c2da315
test(evaluation): added comparison implementation for perplexity test
BlueCrescent Feb 5, 2024
09f1078
feat(config): added first version of evaluation losses in config
BlueCrescent Feb 5, 2024
c6b0a1b
fix(config): added some missing code for reading the new eval measure…
BlueCrescent Feb 6, 2024
66a9971
feat(evaluation): add factory for ThroughputAggregator, finish perple…
Feb 12, 2024
00f0598
fix(evaluation): Fixed perplexity and finalized batch size two test.
BlueCrescent Feb 16, 2024
59e8996
fix: Minor fixes and reverted changes to Reducer.
BlueCrescent Feb 19, 2024
9af55e4
test(evaluation): Added multiple dataloadres to evaluator tests.
BlueCrescent Feb 19, 2024
6215257
fix(loss): Reverted change to cross entropy reduction default back to…
BlueCrescent Feb 19, 2024
bbf4d4f
fix(evaluation): Fixed parameter for throughput aggregator in Evaluator.
BlueCrescent Feb 19, 2024
a113152
fix(evaluation): Bug fix (mutable default arg) and minor refactoring.
BlueCrescent Feb 19, 2024
d898e84
test(evaluation): Added tests for ThroughputAggregator.
BlueCrescent Feb 19, 2024
7ea5ed9
Merge branch 'main' into feat/multiple_losses_and_metrics_in_evaluator
BlueCrescent Feb 19, 2024
c24cf26
feat(config): Added validation_measure_factories to all config files.
BlueCrescent Feb 19, 2024
547c0d0
refactor: Ran isort.
BlueCrescent Feb 19, 2024
066f092
Merge remote-tracking branch 'remotes/origin/main' into feat/multiple…
BlueCrescent Mar 4, 2024
fa004be
fix(trainer): Fixed errors from merging.
BlueCrescent Mar 4, 2024
67e77ee
feat(config): Adapted eval measure configs to new config scheme.
BlueCrescent Mar 4, 2024
bb0bc62
fix(config): Fixed configs for eval measures.
BlueCrescent Mar 4, 2024
eb5e792
fix(evaluation): data_loader.sampler_batch_size to data_loader.batch_…
BlueCrescent Mar 4, 2024
cc76bdf
chore(merge): Merge remote-tracking branch 'origin/main' into feat/mu…
BlueCrescent Mar 11, 2024
4e91766
refactor(evaluation): Renamed AggregativeMeasure to AggregatedMeasure.
BlueCrescent Mar 15, 2024
e12dcd7
refactor(evaluation): Renamed batch_result to result_batch.
BlueCrescent Mar 15, 2024
a8cd8b2
refactor(evaluation): Explicit reduce_operation parameter.
BlueCrescent Mar 15, 2024
27b0fa8
refactor(evaluation): Fixed typos.
BlueCrescent Mar 15, 2024
c886473
refactor(evaluation): In AggregatedMeasure, renamed compute() to aggr…
BlueCrescent Mar 15, 2024
fdd4fa2
refactor(evaluation): Turned _extract_num_samples into a static metho…
BlueCrescent Mar 15, 2024
82f5458
docs(config): Improved code comment.
BlueCrescent Mar 15, 2024
356c4da
refactor(utilities): Removed unused imports.
BlueCrescent Mar 15, 2024
dba5680
refactor(utilities): Removed method only used by tests.
BlueCrescent Mar 15, 2024
381caff
refactor(evaluation): Adapted evaluator tests to previous renaming of…
BlueCrescent Mar 15, 2024
362ff84
feat(evaluation): Changed ThroughputAggregationContext to be usable i…
BlueCrescent Mar 15, 2024
00b8c95
chore(merge): Merge remote-tracking branch 'origin/main' into feat/mu…
BlueCrescent Mar 18, 2024
b955e98
chore(merge): Merge remote-tracking branch 'origin/main' into feat/mu…
BlueCrescent Mar 18, 2024
6 changes: 6 additions & 0 deletions config_files/config.yaml
Member: config is outdated w.r.t. component_key and variant_key

@@ -173,3 +173,9 @@ loss:
  config:
    target_key: ${data.target_key}
    prediction_key: ${model.config.prediction_key}

+validation_measure_factories:
+  - type_hint: AggregativeCLMCrossEntropyLossFactory
+    config:
+      target_key: ${data.target_key}
+      prediction_key: ${model.config.prediction_key}
7 changes: 7 additions & 0 deletions config_files/config_example_hf_meditron_7B_instruction.yaml
@@ -159,6 +159,13 @@ loss_fn:
    target_key: ${settings.referencing_keys.target_key}
    prediction_key: ${settings.referencing_keys.prediction_key}

+evaluation_measures:
+  - component_key: eval_measures
+    variant_key: clm_cross_entropy_loss
+    config:
+      target_key: ${data.target_key}
+      prediction_key: ${model.config.prediction_key}

# scheduler:
#   type_hint: StepLR
#   config:
9 changes: 8 additions & 1 deletion config_files/config_example_mem_map_dataset.yaml
@@ -177,6 +177,13 @@ loss_fn:
    target_key: ${settings.referencing_keys.target_key}
    prediction_key: ${settings.referencing_keys.prediction_key}

+evaluation_measures:
+  - component_key: eval_measures
+    variant_key: clm_cross_entropy_loss
+    config:
+      target_key: ${data.target_key}
+      prediction_key: ${model.config.prediction_key}

optimizer:
  component_key: optimizer
  variant_key: adam_w
Expand Down Expand Up @@ -209,4 +216,4 @@ evaluation_subscriber:
    project: modalities
    mode: ONLINE
    experiment_id: ${settings.experiment_id}
-    directory: "."
+    directory: "."
28 changes: 18 additions & 10 deletions config_files/config_example_openGPTx_dataset.yaml
Member: config is outdated w.r.t. component_key and variant_key

@@ -2,6 +2,11 @@ modalities_setup:
  run_mode: FROM_SCRATCH
  settings:
    global_num_seen_samples: 0

+wandb:
+  project_name: modalities
+  mode: ONLINE

data:
  sample_key: "input_ids"
  target_key: "target_ids"
@@ -22,7 +27,7 @@ data:
      sampler:
        type_hint: DistributedSampler
        config:
-          rank: ${training.local_rank}
+          rank: ${training.global_rank}
          num_replicas: ${training.world_size}
          shuffle: true
      dataset:
@@ -53,7 +58,7 @@ data:
      sampler:
        type_hint: DistributedSampler
        config:
-          rank: ${training.local_rank}
+          rank: ${training.global_rank}
          num_replicas: ${training.world_size}
          shuffle: false
      dataset:
@@ -82,7 +87,7 @@ data:
      sampler:
        type_hint: DistributedSampler
        config:
-          rank: ${training.local_rank}
+          rank: ${training.global_rank}
          num_replicas: ${training.world_size}
          shuffle: false
      dataset:
@@ -98,11 +103,6 @@ data:
      sample_key: ${data.sample_key}
      target_key: ${data.target_key}

-wandb:
-  project_name: modalities
-  mode: ONLINE

training:
  process_group_backend: "nccl"
  global_num_training_samples: 2048
@@ -114,7 +114,8 @@ training:
  local_train_micro_batch_size: ${data.train_dataloader.config.batch_sampler.config.batch_size}
  global_num_seen_samples: ${modalities_setup.settings.global_num_seen_samples}
  gradient_acc_step: 1
-  do_apply_activation_checkpointing: True
+  do_apply_activation_checkpointing: false


checkpointing:
  checkpointing_strategy:
@@ -175,4 +176,11 @@ loss:
  type_hint: CLMCrossEntropyLoss
  config:
    target_key: ${data.target_key}
-    prediction_key: ${model.config.prediction_key}
+    prediction_key: ${model.config.prediction_key}

+evaluation_measures:
+  - component_key: eval_measures
+    variant_key: clm_cross_entropy_loss
+    config:
+      target_key: ${data.target_key}
+      prediction_key: ${model.config.prediction_key}
7 changes: 7 additions & 0 deletions config_files/config_lorem_ipsum.yaml
@@ -170,6 +170,13 @@ loss_fn:
    target_key: target_ids
    prediction_key: logits

+evaluation_measures:
+  - component_key: eval_measures
+    variant_key: clm_cross_entropy_loss
+    config:
+      target_key: target_ids
+      prediction_key: logits

wrapped_model:
  component_key: model
  variant_key: fsdp_wrapped
2 changes: 2 additions & 0 deletions src/modalities/__main__.py
@@ -243,6 +243,8 @@ def run(self):
            local_rank=components.settings.cuda_env.local_rank,
            batch_progress_publisher=batch_processed_publisher,
            evaluation_result_publisher=evaluation_result_publisher,
+            loss_factories=components.evaluation_measures,
+            metric_factories=[],
Member: do we really need metric_factories or could we just have a more generic evaluation_measure_factories?

Collaborator (Author): I agree that it would probably be better to have just one evaluation_measure_factories and use it for both metrics and losses (and whatever other measures we want to compute). The reason for splitting metrics and losses was that the output object already had that distinction.

As a comparison, Hugging Face's Trainer also has only one field for eval measures. One of its uses there is deciding the best checkpoint(s): you select one of the eval measures you're logging (via an additional parameter) and set a "higher is better" flag, which essentially differentiates between losses and metrics/scores.
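For reference, a minimal sketch of that Hugging Face pattern (a single compute_metrics field plus best-checkpoint selection flags); the accuracy metric and the argument values here are illustrative assumptions, not part of this PR:

```python
import numpy as np
from transformers import Trainer, TrainingArguments

def compute_metrics(eval_pred):
    # One callable returns all eval measures as a flat dict.
    logits, labels = eval_pred
    predictions = np.argmax(logits, axis=-1)
    return {"accuracy": float((predictions == labels).mean())}

args = TrainingArguments(
    output_dir="out",
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # pick one of the logged measures
    greater_is_better=True,            # False would suit a loss
)
# trainer = Trainer(model=model, args=args, compute_metrics=compute_metrics, ...)
```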

        )

        # Gym
3 changes: 3 additions & 0 deletions src/modalities/config/config.py
@@ -19,6 +19,7 @@
from modalities.checkpointing.checkpointing_strategies import CheckpointingStrategyIF
from modalities.config.lookup_enum import LookupEnum
from modalities.dataloader.dataloader import LLMDataLoader
+from modalities.evaluation.measure import AggregativeMeasureFactory
from modalities.logging_broker.subscriber import MessageSubscriberIF
from modalities.loss_functions import Loss
from modalities.models.gpt2.collator import CollateFnIF
@@ -60,6 +61,7 @@ def __get_pydantic_core_schema__(
PydanticLLMDataLoaderIFType = Annotated[LLMDataLoader, PydanticThirdPartyTypeIF(LLMDataLoader)]
PydanticOptimizerIFType = Annotated[Optimizer, PydanticThirdPartyTypeIF(Optimizer)]
PydanticLossIFType = Annotated[Loss, PydanticThirdPartyTypeIF(Loss)]
+PydanticMeasureFactoryIFType = Annotated[AggregativeMeasureFactory, PydanticThirdPartyTypeIF(AggregativeMeasureFactory)]
PydanticMessageSubscriberIFType = Annotated[MessageSubscriberIF, PydanticThirdPartyTypeIF(MessageSubscriberIF)]


@@ -306,6 +308,7 @@ class ComponentsModel(BaseModel):
    wrapped_model: PydanticModelIFType
    optimizer: PydanticOptimizerIFType
    loss_fn: PydanticLossIFType
+    evaluation_measures: List[PydanticMeasureFactoryIFType]
    train_dataloader: PydanticLLMDataLoaderIFType
    eval_dataloaders: List[PydanticLLMDataLoaderIFType]
    batch_progress_subscriber: PydanticMessageSubscriberIFType
Empty file.
34 changes: 34 additions & 0 deletions src/modalities/evaluation/aggregator.py
@@ -0,0 +1,34 @@
from __future__ import annotations

from typing import Dict, Generic, Hashable, Optional, TypeVar

import torch
import torch.distributed as dist

from modalities.running_env.fsdp.reducer import Reducer

KeyType = TypeVar("KeyType", bound=Hashable)


class Aggregator(Generic[KeyType]):

    def __init__(self, initial_values: Optional[Dict[KeyType, torch.Tensor]] = None) -> None:
        self._key_to_value = initial_values if initial_values else {}

    def add_values(self, value_dict: Dict[KeyType, torch.Tensor]):
        for key, value in value_dict.items():
            self.add_value(key, value)

    def add_value(self, key: KeyType, value: torch.Tensor):
        if key not in self._key_to_value:
            self._key_to_value[key] = value
        else:
            self._key_to_value[key] += value

    def get_all_reduced_value(
        self, key: KeyType, reduce_operation: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM
    ) -> torch.Tensor:
        # we clone the value so that we can always resync the value without side-effects
        cloned_value = self._key_to_value[key].clone()
        value = Reducer.reduce(tensor=cloned_value, operation=reduce_operation)
Member: Since we have the hierarchical instantiation up and running now, we should pass in the reducer via the constructor. We can think of different reducers, e.g., a torch distributed reducer, which reduces the tensors across ranks. Another reducer for single-GPU training without FSDP (which is still a todo) could just call torch.mean(). What do you think?

Collaborator (Author): Yes, I agree that this makes sense. But I would probably postpone such a change until we actually have a second Reducer.

        return value
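For illustration, a minimal sketch of the constructor-injected reducer idea from this thread; ReducerIF, DistAllReducer, and LocalReducer are hypothetical names, not part of this PR:

```python
from typing import Protocol

import torch
import torch.distributed as dist


class ReducerIF(Protocol):
    def reduce(self, tensor: torch.Tensor, operation: dist.ReduceOp.RedOpType) -> torch.Tensor:
        ...


class DistAllReducer:
    # FSDP / multi-GPU case: reduce the tensor across all ranks.
    def reduce(self, tensor: torch.Tensor, operation: dist.ReduceOp.RedOpType) -> torch.Tensor:
        dist.all_reduce(tensor, op=operation)
        return tensor


class LocalReducer:
    # Single-GPU case: nothing to synchronize, so the tensor passes through
    # (the review above suggests e.g. torch.mean() here, depending on the measure).
    def reduce(self, tensor: torch.Tensor, operation: dist.ReduceOp.RedOpType) -> torch.Tensor:
        return tensor


# The Aggregator would then take `reducer: ReducerIF` in its constructor and call
# `self._reducer.reduce(...)` instead of the static `Reducer.reduce(...)`.
```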
51 changes: 51 additions & 0 deletions src/modalities/evaluation/clm_cross_entropy_loss.py
@@ -0,0 +1,51 @@
from __future__ import annotations

from enum import Enum
from typing import Dict

import torch
import torch.distributed as dist

from modalities.batch import InferenceResultBatch
from modalities.evaluation.measure import AggregativeMeasure, AggregativeMeasureFactory
from modalities.loss_functions import CLMCrossEntropyLoss


class LossKeys(Enum):
    CLM_CROSS_ENTROPY = "clm_cross_entropy"
    NUM_SAMPLES = "num_samples"
Comment on lines +14 to +16
Member: previously, we always defined those in the config yaml. I'd suggest doing that here as well, for consistency.

Collaborator (Author): In this case, these keys are more like internals of the class. Rather than making them configurable, we should probably turn this enum into an inner class.
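A hedged sketch of that suggestion (not part of this PR; the surrounding diff still uses the module-level LossKeys):

```python
from enum import Enum

import torch.distributed as dist

from modalities.evaluation.measure import AggregativeMeasure
from modalities.loss_functions import CLMCrossEntropyLoss


class AggregativeCLMCrossEntropyLoss(AggregativeMeasure):
    class Keys(Enum):
        # Former module-level LossKeys, now an implementation detail of the class.
        CLM_CROSS_ENTROPY = "clm_cross_entropy"
        NUM_SAMPLES = "num_samples"

    def __init__(self, target_key: str, prediction_key: str, local_rank: int) -> None:
        super().__init__(
            aggregate_keys=list(self.Keys),
            reduce_ops={k: dist.ReduceOp.SUM for k in self.Keys},
            tag="CLMCrossEntropyLoss",
            local_rank=local_rank,
        )
        self._loss = CLMCrossEntropyLoss(target_key=target_key, prediction_key=prediction_key, reduction="sum")
```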



class AggregativeCLMCrossEntropyLoss(AggregativeMeasure[LossKeys]):

    def __init__(self, target_key: str, prediction_key: str, local_rank: int) -> None:
        super().__init__(
            aggregate_keys=list(LossKeys),
            reduce_ops={k: dist.ReduceOp.SUM for k in LossKeys},
            tag="CLMCrossEntropyLoss",
            local_rank=local_rank,
        )
        self._loss = CLMCrossEntropyLoss(target_key=target_key, prediction_key=prediction_key, reduction="sum")

    def _postprocess_result_batch(self, batch_result: InferenceResultBatch) -> Dict[LossKeys, torch.Tensor]:
        loss = self._loss(batch_result)
        return {
            LossKeys.CLM_CROSS_ENTROPY: loss,
            LossKeys.NUM_SAMPLES: torch.tensor(len(batch_result)),
        }

    def _calc_measure(self, values: Dict[LossKeys, torch.Tensor]) -> torch.Tensor:
        return values[LossKeys.CLM_CROSS_ENTROPY] / values[LossKeys.NUM_SAMPLES]


class AggregativeCLMCrossEntropyLossFactory(AggregativeMeasureFactory[LossKeys]):
Member: What do we need the factory for?

Collaborator (Author): The reason for using factories here is that the AggregatedMeasure objects should only exist in one context. Since they are stateful objects, only the context that constructs them should be using them. In order to still parameterize which measure to use, the context is given the required factory.

    def __init__(self, target_key: str, prediction_key: str) -> None:
        self._target_key = target_key
        self._prediction_key = prediction_key

    def create(self, local_rank: int) -> AggregativeMeasure:
        return AggregativeCLMCrossEntropyLoss(
            target_key=self._target_key,
            prediction_key=self._prediction_key,
            local_rank=local_rank,
        )
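To make the lifecycle described above concrete, here is a hedged sketch of how an evaluation context might consume such factories; evaluate_split and its loop are illustrative assumptions, not the PR's actual Evaluator code:

```python
from typing import List

from modalities.evaluation.measure import AggregativeMeasure, AggregativeMeasureFactory


def evaluate_split(measure_factories: List[AggregativeMeasureFactory], result_batches, local_rank: int):
    # Fresh, stateful measure objects per evaluation run; the factories stay stateless.
    measures: List[AggregativeMeasure] = [f.create(local_rank=local_rank) for f in measure_factories]
    for batch_result in result_batches:
        for measure in measures:
            measure.add(batch_result)
    # compute() all-reduces the aggregated values across ranks before the final calculation.
    return {m.tag: m.compute() for m in measures}
```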
63 changes: 63 additions & 0 deletions src/modalities/evaluation/measure.py
@@ -0,0 +1,63 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Dict, Generic, List

import torch
import torch.distributed as dist

from modalities.batch import InferenceResultBatch
from modalities.evaluation.aggregator import Aggregator, KeyType


class AggregativeMeasureFactory(Generic[KeyType]):
    def create(self, local_rank: int) -> AggregativeMeasure:
        raise NotImplementedError


class AggregativeMeasure(Generic[KeyType], ABC):
    def __init__(
        self,
        aggregate_keys: List[KeyType],
        reduce_ops: Dict[KeyType, dist.ReduceOp.RedOpType],
        tag: str,
        local_rank: int,
    ) -> None:
        if torch.cuda.is_available():
            self._device = torch.device(local_rank)
        else:
            self._device = "cpu"
        self._aggregator = Aggregator[KeyType](
            initial_values={k: torch.zeros(1).to(self._device) for k in aggregate_keys}
        )
        self._aggregate_keys = aggregate_keys
        self._reduce_ops = reduce_ops
        self._tag = tag

    @property
    def tag(self) -> str:
        return self._tag

    def add(self, batch_result: InferenceResultBatch) -> None:
        res = self._postprocess_result_batch(batch_result)

        for key, value in res.items():
            self._aggregator.add_value(key, value.to(self._device))

    def compute(self) -> torch.Tensor:
        synced_vals: Dict[KeyType, torch.Tensor] = {}
        for key in self._aggregate_keys:
            synced_vals[key] = self._aggregator.get_all_reduced_value(
                key,
                self._reduce_ops[key],
            )

        return self._calc_measure(synced_vals)

    @abstractmethod
    def _postprocess_result_batch(self, batch_result: InferenceResultBatch) -> Dict[KeyType, torch.Tensor]:
        raise NotImplementedError

    @abstractmethod
    def _calc_measure(self, values: Dict[KeyType, torch.Tensor]) -> torch.Tensor:
        raise NotImplementedError
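As a usage illustration of this interface, a hedged sketch of a hypothetical additional measure (token-level accuracy, not part of this PR) built on the two abstract hooks; get_predictions is assumed to mirror the get_targets accessor used elsewhere in the PR:

```python
from enum import Enum
from typing import Dict

import torch
import torch.distributed as dist

from modalities.batch import InferenceResultBatch
from modalities.evaluation.measure import AggregativeMeasure


class AccuracyKeys(Enum):
    NUM_CORRECT = "num_correct"
    NUM_TOKENS = "num_tokens"


class AggregativeTokenAccuracy(AggregativeMeasure[AccuracyKeys]):
    def __init__(self, target_key: str, prediction_key: str, local_rank: int) -> None:
        super().__init__(
            aggregate_keys=list(AccuracyKeys),
            reduce_ops={k: dist.ReduceOp.SUM for k in AccuracyKeys},
            tag="TokenAccuracy",
            local_rank=local_rank,
        )
        self._target_key = target_key
        self._prediction_key = prediction_key

    def _postprocess_result_batch(self, batch_result: InferenceResultBatch) -> Dict[AccuracyKeys, torch.Tensor]:
        targets = batch_result.get_targets(self._target_key)
        logits = batch_result.get_predictions(self._prediction_key)  # assumed accessor
        predictions = logits.argmax(dim=-1)
        return {
            AccuracyKeys.NUM_CORRECT: (predictions == targets).sum(),
            AccuracyKeys.NUM_TOKENS: torch.tensor(targets.numel()),
        }

    def _calc_measure(self, values: Dict[AccuracyKeys, torch.Tensor]) -> torch.Tensor:
        return values[AccuracyKeys.NUM_CORRECT] / values[AccuracyKeys.NUM_TOKENS]
```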
54 changes: 54 additions & 0 deletions src/modalities/evaluation/perplexity.py
@@ -0,0 +1,54 @@
from __future__ import annotations

from enum import Enum
from typing import Dict

import torch
import torch.distributed as dist

from modalities.batch import InferenceResultBatch
from modalities.evaluation.measure import AggregativeMeasure, AggregativeMeasureFactory
from modalities.loss_functions import CLMCrossEntropyLoss


class PerplexityKeys(Enum):
    PERPLEXITY = "loss"
    NUM_SAMPLES = "num_samples"


class AggregativePerplexity(AggregativeMeasure[PerplexityKeys]):
    def __init__(self, target_key: str, prediction_key: str, local_rank: int) -> None:
        super().__init__(
            aggregate_keys=list(PerplexityKeys),
            reduce_ops={k: dist.ReduceOp.SUM for k in PerplexityKeys},
            tag="Perplexity",
            local_rank=local_rank,
        )
        self._target_key = target_key
        self._loss = CLMCrossEntropyLoss(target_key=target_key, prediction_key=prediction_key, reduction="none")

    def _postprocess_result_batch(self, batch_result: InferenceResultBatch) -> Dict[PerplexityKeys, torch.Tensor]:
        loss = self._loss(batch_result)  # shape: (batch_size * seq_len)
        batch_size, seq_len = batch_result.get_targets(self._target_key).shape
        loss = loss.view(batch_size, seq_len)  # shape: (batch_size, seq_len)
        perplexity = torch.exp(loss.sum(-1) / seq_len)
        return {
            PerplexityKeys.PERPLEXITY: perplexity.sum(),
            PerplexityKeys.NUM_SAMPLES: torch.tensor(len(batch_result)),
        }

    def _calc_measure(self, values: Dict[PerplexityKeys, torch.Tensor]) -> torch.Tensor:
        return values[PerplexityKeys.PERPLEXITY] / values[PerplexityKeys.NUM_SAMPLES]


class AggregativePerplexityFactory(AggregativeMeasureFactory[PerplexityKeys]):
    def __init__(self, target_key: str, prediction_key: str) -> None:
        self._target_key = target_key
        self._prediction_key = prediction_key

    def create(self, local_rank: int) -> AggregativeMeasure[PerplexityKeys]:
Member: as mentioned earlier, I think the factories are overkill.

        return AggregativePerplexity(
            target_key=self._target_key,
            prediction_key=self._prediction_key,
            local_rank=local_rank,
        )
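The computation above is the per-sample perplexity exp(mean of token cross-entropies over the sequence), summed per batch and divided by the sample count after the distributed reduction. A tiny toy check of the formula (numbers are illustrative, not from the PR):

```python
import torch

# Per-sample perplexity = exp(mean over sequence positions of token cross-entropy).
token_losses = torch.tensor([[0.5, 1.0, 1.5],    # sample 1: mean loss 1.0
                             [2.0, 2.0, 2.0]])   # sample 2: mean loss 2.0
per_sample_ppl = torch.exp(token_losses.mean(dim=-1))
print(per_sample_ppl)         # tensor([2.7183, 7.3891]), i.e. exp(1.0) and exp(2.0)
print(per_sample_ppl.mean())  # what _calc_measure yields: summed ppl / num_samples
```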