Skip to content

Commit

Permalink
feat(config): Adapted eval measure configs to new config scheme.
Browse files Browse the repository at this point in the history
  • Loading branch information
BlueCrescent committed Mar 4, 2024
1 parent fa004be commit 67e77ee
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 19 deletions.
6 changes: 3 additions & 3 deletions config_files/config_example_hf_meditron_7B_instruction.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,9 @@ loss_fn:
target_key: ${settings.referencing_keys.target_key}
prediction_key: ${settings.referencing_keys.prediction_key}

# TODO adapt this to new config scheme
validation_measure_factories:
- type_hint: AggregativeCLMCrossEntropyLossFactory
evaluation_measures:
- component_key: eval_measures
variant_key: clm_cross_entropy_loss
config:
target_key: ${data.target_key}
prediction_key: ${model.config.prediction_key}
Expand Down
6 changes: 3 additions & 3 deletions config_files/config_example_mem_map_dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -177,9 +177,9 @@ loss_fn:
target_key: ${settings.referencing_keys.target_key}
prediction_key: ${settings.referencing_keys.prediction_key}

# TODO adapt this to new config scheme
validation_measure_factories:
- type_hint: AggregativeCLMCrossEntropyLossFactory
evaluation_measures:
- component_key: eval_measures
variant_key: clm_cross_entropy_loss
config:
target_key: ${data.target_key}
prediction_key: ${model.config.prediction_key}
Expand Down
5 changes: 3 additions & 2 deletions config_files/config_example_openGPTx_dataset.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -178,8 +178,9 @@ loss:
target_key: ${data.target_key}
prediction_key: ${model.config.prediction_key}

validation_measure_factories:
- type_hint: AggregativeCLMCrossEntropyLossFactory
evaluation_measures:
- component_key: eval_measures
variant_key: clm_cross_entropy_loss
config:
target_key: ${data.target_key}
prediction_key: ${model.config.prediction_key}
10 changes: 5 additions & 5 deletions config_files/config_lorem_ipsum.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -170,12 +170,12 @@ loss_fn:
target_key: target_ids
prediction_key: logits

# TODO adapt this to new config scheme
validation_measure_factories:
- type_hint: AggregativeCLMCrossEntropyLossFactory
evaluation_measures:
- component_key: eval_measures
variant_key: clm_cross_entropy_loss
config:
target_key: ${data.target_key}
prediction_key: ${model.config.prediction_key}
target_key: target_ids
prediction_key: logits

wrapped_model:
component_key: model
Expand Down
6 changes: 0 additions & 6 deletions src/modalities/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,6 @@ class CLMCrossEntropyLossConfig(BaseModel):
target_key: str
prediction_key: str

# TODO adapt to new config scheme
class ValidationMeasureFactoryConfig(BaseModel):
type_hint: ValidationMeasureFactoryTypes
config: CLMCrossEntropyLossConfig


# Checkpointing
class SaveEveryKStepsCheckpointingStrategyConfig(BaseModel):
Expand Down Expand Up @@ -286,7 +281,6 @@ class CudaEnv(BaseModel):
world_size: Annotated[int, Field(strict=True, ge=1)]
global_rank: Annotated[int, Field(strict=True, ge=0)]

validation_measure_factories: List[ValidationMeasureFactoryConfig]

class Settings(BaseModel):
class Training(BaseModel):
Expand Down
10 changes: 10 additions & 0 deletions src/modalities/registry/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
)
from modalities.dataloader.dataloader_factory import DataloaderFactory
from modalities.dataloader.dataset_factory import DatasetFactory
from modalities.evaluation.clm_cross_entropy_loss import AggregativeCLMCrossEntropyLossFactory
from modalities.evaluation.perplexity import AggregativePerplexityFactory
from modalities.logging_broker.subscriber_impl.subscriber_factory import (
ProgressSubscriberFactory,
ResultsSubscriberFactory,
Expand Down Expand Up @@ -71,6 +73,14 @@ class ComponentEntity:
ComponentEntity("model", "fsdp_wrapped", ModelFactory.get_fsdp_wrapped_model, FSDPWrappedModelConfig),
# losses
ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig),
# EvalLosses
ComponentEntity(
"eval_measures",
"clm_cross_entropy_loss",
AggregativeCLMCrossEntropyLossFactory,
CLMCrossEntropyLossConfig,
),
ComponentEntity("eval_measures", "perplexity", AggregativePerplexityFactory, CLMCrossEntropyLossConfig),
# optmizers
ComponentEntity("optimizer", "adam_w", OptimizerFactory.get_adam_w, AdamWOptimizerConfig),
ComponentEntity(
Expand Down

0 comments on commit 67e77ee

Please sign in to comment.