added support for using multiple losses and metrics in evaluator #8
Changes from 31 commits
Comment: config is outdated w.r.t. `component_key` and `variant_key`
@@ -243,6 +243,8 @@ def run(self):
            local_rank=components.settings.cuda_env.local_rank,
            batch_progress_publisher=batch_processed_publisher,
            evaluation_result_publisher=evaluation_result_publisher,
            loss_factories=components.evaluation_measures,
            metric_factories=[],
Comment: do we really need […]?
Reply: I agree that it would probably be better to just have one […]. As a comparison, what huggingface's Trainer does: […]
        )

        # Gym
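Picking up the thread above about merging `loss_factories` and `metric_factories`: a minimal sketch of what a single-list variant could look like. `Evaluator` is a stand-in dataclass and `measure_factories` is an assumed parameter name; neither is taken from the PR.

```python
# Sketch only: a single-list variant of the evaluator wiring discussed above.
# `Evaluator` and `measure_factories` are illustrative stand-ins, not PR code.
from dataclasses import dataclass, field
from typing import List

from modalities.evaluation.measure import AggregativeMeasureFactory


@dataclass
class Evaluator:
    local_rank: int
    # losses and metrics are passed together; both produce AggregativeMeasure
    # objects via create(), so the evaluator never needs to tell them apart
    measure_factories: List[AggregativeMeasureFactory] = field(default_factory=list)


def build_evaluator(local_rank: int, evaluation_measures: List[AggregativeMeasureFactory]) -> Evaluator:
    return Evaluator(local_rank=local_rank, measure_factories=evaluation_measures)
```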
@@ -0,0 +1,34 @@
from __future__ import annotations

from typing import Dict, Generic, Hashable, Optional, TypeVar

import torch
import torch.distributed as dist

from modalities.running_env.fsdp.reducer import Reducer

KeyType = TypeVar("KeyType", bound=Hashable)


class Aggregator(Generic[KeyType]):

    def __init__(self, initial_values: Optional[Dict[KeyType, torch.Tensor]] = None) -> None:
        self._key_to_value = initial_values if initial_values else {}

    def add_values(self, value_dict: Dict[KeyType, torch.Tensor]):
        for key, value in value_dict.items():
            self.add_value(key, value)

    def add_value(self, key: KeyType, value: torch.Tensor):
        if key not in self._key_to_value:
            self._key_to_value[key] = value
        else:
            self._key_to_value[key] += value

    def get_all_reduced_value(
        self, key: KeyType, reduce_operation: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM
    ) -> torch.Tensor:
        # we clone the value so that we can always resync the value without side-effects
        cloned_value = self._key_to_value[key].clone()
        value = Reducer.reduce(tensor=cloned_value, operation=reduce_operation)
Comment: Since we have the hierarchical instantiation up and running now, we should pass in the reducer via the constructor.
Reply: Yes, I agree that this makes sense. But I would probably postpone such a change until we actually have a second Reducer.
        return value
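Following up on the reducer discussion above, a minimal sketch of what constructor injection could look like once a second reducer exists. The `ReducerProtocol` name and the injected-constructor variant are assumptions for illustration, not part of this PR.

```python
# Sketch only: passing the reducer into the Aggregator via the constructor,
# as suggested in the review. `ReducerProtocol` is a hypothetical name.
from typing import Dict, Generic, Hashable, Optional, Protocol, TypeVar

import torch
import torch.distributed as dist

KeyType = TypeVar("KeyType", bound=Hashable)


class ReducerProtocol(Protocol):
    def reduce(self, tensor: torch.Tensor, operation: dist.ReduceOp.RedOpType) -> torch.Tensor:
        ...


class AggregatorWithInjectedReducer(Generic[KeyType]):
    def __init__(self, reducer: ReducerProtocol, initial_values: Optional[Dict[KeyType, torch.Tensor]] = None) -> None:
        self._reducer = reducer  # injected dependency instead of a hard-coded Reducer class
        self._key_to_value: Dict[KeyType, torch.Tensor] = initial_values if initial_values else {}

    def get_all_reduced_value(
        self, key: KeyType, reduce_operation: dist.ReduceOp.RedOpType = dist.ReduceOp.SUM
    ) -> torch.Tensor:
        # clone so repeated syncs do not mutate the accumulated value
        cloned_value = self._key_to_value[key].clone()
        return self._reducer.reduce(tensor=cloned_value, operation=reduce_operation)
```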
@@ -0,0 +1,51 @@
from __future__ import annotations

from enum import Enum
from typing import Dict

import torch
import torch.distributed as dist

from modalities.batch import InferenceResultBatch
from modalities.evaluation.measure import AggregativeMeasure, AggregativeMeasureFactory
from modalities.loss_functions import CLMCrossEntropyLoss


class LossKeys(Enum):
    CLM_CROSS_ENTROPY = "clm_cross_entropy"
    NUM_SAMPLES = "num_samples"
Comment (on lines +14 to +16): previously, we always defined those in the config yaml. I'd suggest to also do that here for consistency.
Reply: In this case, these keys are more like internals of the class. Rather than making them configurable, we should probably make this enum an inner class.
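To make the reply above concrete, a short sketch of how the class further down in this diff could nest the key enum. The nesting is illustrative; the generic parameter on the base class is omitted here just to keep the example self-contained.

```python
# Sketch only: nesting the key enum inside the measure class, as suggested in
# the reply above, so the keys stay an implementation detail of the class.
from enum import Enum

import torch.distributed as dist

from modalities.evaluation.measure import AggregativeMeasure
from modalities.loss_functions import CLMCrossEntropyLoss


class AggregativeCLMCrossEntropyLoss(AggregativeMeasure):
    class Keys(Enum):
        CLM_CROSS_ENTROPY = "clm_cross_entropy"
        NUM_SAMPLES = "num_samples"

    def __init__(self, target_key: str, prediction_key: str, local_rank: int) -> None:
        super().__init__(
            aggregate_keys=list(self.Keys),
            reduce_ops={k: dist.ReduceOp.SUM for k in self.Keys},
            tag="CLMCrossEntropyLoss",
            local_rank=local_rank,
        )
        self._loss = CLMCrossEntropyLoss(target_key=target_key, prediction_key=prediction_key, reduction="sum")

    # _postprocess_result_batch and _calc_measure would stay as in the diff below,
    # with LossKeys replaced by self.Keys
```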
class AggregativeCLMCrossEntropyLoss(AggregativeMeasure[LossKeys]):

    def __init__(self, target_key: str, prediction_key: str, local_rank: int) -> None:
        super().__init__(
            aggregate_keys=list(LossKeys),
            reduce_ops={k: dist.ReduceOp.SUM for k in LossKeys},
            tag="CLMCrossEntropyLoss",
            local_rank=local_rank,
        )
        self._loss = CLMCrossEntropyLoss(target_key=target_key, prediction_key=prediction_key, reduction="sum")

    def _postprocess_result_batch(self, batch_result: InferenceResultBatch) -> Dict[LossKeys, torch.Tensor]:
        loss = self._loss(batch_result)
        return {
            LossKeys.CLM_CROSS_ENTROPY: loss,
            LossKeys.NUM_SAMPLES: torch.tensor(len(batch_result)),
        }

    def _calc_measure(self, values: Dict[LossKeys, torch.Tensor]) -> torch.Tensor:
        return values[LossKeys.CLM_CROSS_ENTROPY] / values[LossKeys.NUM_SAMPLES]


class AggregativeCLMCrossEntropyLossFactory(AggregativeMeasureFactory[LossKeys]):
Comment: What do we need the factory for?
Reply: The reason for using factories here is so that the AggregativeMeasure objects only exist in one context. Since they are stateful objects, only the context that constructs them should be using them. In order to still parameterize which measure to use, the context is given the required factory.
    def __init__(self, target_key: str, prediction_key: str) -> None:
        self._target_key = target_key
        self._prediction_key = prediction_key

    def create(self, local_rank: int) -> AggregativeMeasure:
        return AggregativeCLMCrossEntropyLoss(
            target_key=self._target_key,
            prediction_key=self._prediction_key,
            local_rank=local_rank,
        )
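To illustrate the factory argument from the thread above, a rough sketch of how an evaluation loop could consume the factories so that each run works on fresh, stateful measure objects. The function and parameter names are assumptions; the PR's actual evaluator wiring may differ.

```python
# Sketch only: consuming AggregativeMeasureFactory objects in an evaluation
# loop. Function and parameter names are illustrative, not from the PR.
from typing import Dict, Iterable, List

import torch

from modalities.batch import InferenceResultBatch
from modalities.evaluation.measure import AggregativeMeasure, AggregativeMeasureFactory


def run_evaluation(
    loss_factories: List[AggregativeMeasureFactory],
    metric_factories: List[AggregativeMeasureFactory],
    result_batches: Iterable[InferenceResultBatch],
    local_rank: int,
) -> Dict[str, torch.Tensor]:
    # fresh measures per run, so no aggregation state leaks between evaluation passes
    measures: List[AggregativeMeasure] = [
        factory.create(local_rank=local_rank) for factory in loss_factories + metric_factories
    ]
    for batch_result in result_batches:
        for measure in measures:
            measure.add(batch_result)
    # compute() all-reduces the aggregated values across ranks before combining them
    return {measure.tag: measure.compute() for measure in measures}
```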
@@ -0,0 +1,63 @@
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Dict, Generic, List

import torch
import torch.distributed as dist

from modalities.batch import InferenceResultBatch
from modalities.evaluation.aggregator import Aggregator, KeyType


class AggregativeMeasureFactory(Generic[KeyType]):
    def create(self, local_rank: int) -> AggregativeMeasure:
        raise NotImplementedError


class AggregativeMeasure(Generic[KeyType], ABC):
    def __init__(
        self,
        aggregate_keys: List[KeyType],
        reduce_ops: Dict[KeyType, dist.ReduceOp.RedOpType],
        tag: str,
        local_rank: int,
    ) -> None:
        if torch.cuda.is_available():
            self._device = torch.device(local_rank)
        else:
            self._device = "cpu"
        self._aggregator = Aggregator[KeyType](
            initial_values={k: torch.zeros(1).to(self._device) for k in aggregate_keys}
        )
        self._aggregate_keys = aggregate_keys
        self._reduce_ops = reduce_ops
        self._tag = tag

    @property
    def tag(self) -> str:
        return self._tag

    def add(self, batch_result: InferenceResultBatch) -> None:
        res = self._postprocess_result_batch(batch_result)

        for key, value in res.items():
            self._aggregator.add_value(key, value.to(self._device))

    def compute(self) -> torch.Tensor:
        synced_vals: Dict[KeyType, torch.Tensor] = {}
        for key in self._aggregate_keys:
            synced_vals[key] = self._aggregator.get_all_reduced_value(
                key,
                self._reduce_ops[key],
            )

        return self._calc_measure(synced_vals)

    @abstractmethod
    def _postprocess_result_batch(self, batch_result: InferenceResultBatch) -> Dict[KeyType, torch.Tensor]:
        raise NotImplementedError

    @abstractmethod
    def _calc_measure(self, values: Dict[KeyType, torch.Tensor]) -> torch.Tensor:
        raise NotImplementedError
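Since the PR title also covers metrics, here is a hedged sketch of what a metric on top of this base class could look like, for example token-level accuracy. The key enum, the tag, and the assumption that `InferenceResultBatch` exposes `get_predictions` analogously to `get_targets` are illustrative, not part of this diff.

```python
# Sketch only: a token-accuracy metric built on AggregativeMeasure, to show how
# metrics would plug in next to losses. get_predictions() is assumed to exist
# analogously to get_targets(); all names here are illustrative.
from enum import Enum
from typing import Dict

import torch
import torch.distributed as dist

from modalities.batch import InferenceResultBatch
from modalities.evaluation.measure import AggregativeMeasure


class AccuracyKeys(Enum):
    NUM_CORRECT = "num_correct"
    NUM_TOKENS = "num_tokens"


class AggregativeTokenAccuracy(AggregativeMeasure[AccuracyKeys]):
    def __init__(self, target_key: str, prediction_key: str, local_rank: int) -> None:
        super().__init__(
            aggregate_keys=list(AccuracyKeys),
            reduce_ops={k: dist.ReduceOp.SUM for k in AccuracyKeys},
            tag="TokenAccuracy",
            local_rank=local_rank,
        )
        self._target_key = target_key
        self._prediction_key = prediction_key

    def _postprocess_result_batch(self, batch_result: InferenceResultBatch) -> Dict[AccuracyKeys, torch.Tensor]:
        targets = batch_result.get_targets(self._target_key)  # (batch_size, seq_len)
        logits = batch_result.get_predictions(self._prediction_key)  # (batch_size, seq_len, vocab_size), assumed accessor
        predicted_tokens = logits.argmax(dim=-1)
        return {
            AccuracyKeys.NUM_CORRECT: (predicted_tokens == targets).sum(),
            AccuracyKeys.NUM_TOKENS: torch.tensor(targets.numel()),
        }

    def _calc_measure(self, values: Dict[AccuracyKeys, torch.Tensor]) -> torch.Tensor:
        return values[AccuracyKeys.NUM_CORRECT] / values[AccuracyKeys.NUM_TOKENS]
```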
@@ -0,0 +1,54 @@
from __future__ import annotations

from enum import Enum
from typing import Dict

import torch
import torch.distributed as dist

from modalities.batch import InferenceResultBatch
from modalities.evaluation.measure import AggregativeMeasure, AggregativeMeasureFactory
from modalities.loss_functions import CLMCrossEntropyLoss


class PerplexityKeys(Enum):
    PERPLEXITY = "loss"
    NUM_SAMPLES = "num_samples"


class AggregativePerplexity(AggregativeMeasure[PerplexityKeys]):
    def __init__(self, target_key: str, prediction_key: str, local_rank: int) -> None:
        super().__init__(
            aggregate_keys=list(PerplexityKeys),
            reduce_ops={k: dist.ReduceOp.SUM for k in PerplexityKeys},
            tag="Perplexity",
            local_rank=local_rank,
        )
        self._target_key = target_key
        self._loss = CLMCrossEntropyLoss(target_key=target_key, prediction_key=prediction_key, reduction="none")

    def _postprocess_result_batch(self, batch_result: InferenceResultBatch) -> Dict[PerplexityKeys, torch.Tensor]:
        loss = self._loss(batch_result)  # shape: (batch_size * seq_len)
        batch_size, seq_len = batch_result.get_targets(self._target_key).shape
        loss = loss.view(batch_size, seq_len)  # shape: (batch_size, seq_len)
        perplexity = torch.exp(loss.sum(-1) / seq_len)
        return {
            PerplexityKeys.PERPLEXITY: perplexity.sum(),
            PerplexityKeys.NUM_SAMPLES: torch.tensor(len(batch_result)),
        }

    def _calc_measure(self, values: Dict[PerplexityKeys, torch.Tensor]) -> torch.Tensor:
        return values[PerplexityKeys.PERPLEXITY] / values[PerplexityKeys.NUM_SAMPLES]


class AggregativePerplexityFactory(AggregativeMeasureFactory[PerplexityKeys]):
    def __init__(self, target_key: str, prediction_key: str) -> None:
        self._target_key = target_key
        self._prediction_key = prediction_key

    def create(self, local_rank: int) -> AggregativeMeasure[PerplexityKeys]:
Comment: as mentioned earlier, I think the factories are overkill.
        return AggregativePerplexity(
            target_key=self._target_key,
            prediction_key=self._prediction_key,
            local_rank=local_rank,
        )
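A small numeric check of the aggregation above, purely illustrative: per-sample perplexity is the exponential of the mean token loss, and `_calc_measure` then averages the per-sample perplexities over the total sample count.

```python
# Toy check of the perplexity aggregation (illustrative, not from the PR):
# per-sample perplexity = exp(mean token loss); compute() then averages the
# per-sample perplexities over all evaluated samples.
import torch

# two samples, three tokens each, with per-token cross-entropy losses
token_losses = torch.tensor([[0.5, 0.7, 0.3],
                             [1.2, 0.9, 1.5]])
seq_len = token_losses.shape[1]

per_sample_perplexity = torch.exp(token_losses.sum(-1) / seq_len)
# accumulated across batches as PERPLEXITY (sum) and NUM_SAMPLES (count)
perplexity_sum = per_sample_perplexity.sum()
num_samples = torch.tensor(token_losses.shape[0])

mean_perplexity = perplexity_sum / num_samples  # what compute() would return
print(per_sample_perplexity)  # tensor([1.6487, 3.3201])
print(mean_perplexity)        # tensor(2.4844)
```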