diff --git a/config_files/config_example_hf_meditron_7B_instruction.yaml b/config_files/config_example_hf_meditron_7B_instruction.yaml index ad018508..94c08c79 100644 --- a/config_files/config_example_hf_meditron_7B_instruction.yaml +++ b/config_files/config_example_hf_meditron_7B_instruction.yaml @@ -159,9 +159,9 @@ loss_fn: target_key: ${settings.referencing_keys.target_key} prediction_key: ${settings.referencing_keys.prediction_key} -# TODO adapt this to new config scheme -validation_measure_factories: - - type_hint: AggregativeCLMCrossEntropyLossFactory +evaluation_measures: + - component_key: eval_measures + variant_key: clm_cross_entropy_loss config: target_key: ${data.target_key} prediction_key: ${model.config.prediction_key} diff --git a/config_files/config_example_mem_map_dataset.yaml b/config_files/config_example_mem_map_dataset.yaml index 0c2af2d7..84535292 100644 --- a/config_files/config_example_mem_map_dataset.yaml +++ b/config_files/config_example_mem_map_dataset.yaml @@ -177,9 +177,9 @@ loss_fn: target_key: ${settings.referencing_keys.target_key} prediction_key: ${settings.referencing_keys.prediction_key} -# TODO adapt this to new config scheme -validation_measure_factories: - - type_hint: AggregativeCLMCrossEntropyLossFactory +evaluation_measures: + - component_key: eval_measures + variant_key: clm_cross_entropy_loss config: target_key: ${data.target_key} prediction_key: ${model.config.prediction_key} diff --git a/config_files/config_example_openGPTx_dataset.yaml b/config_files/config_example_openGPTx_dataset.yaml index 715c4166..e0ca6eaf 100644 --- a/config_files/config_example_openGPTx_dataset.yaml +++ b/config_files/config_example_openGPTx_dataset.yaml @@ -178,8 +178,9 @@ loss: target_key: ${data.target_key} prediction_key: ${model.config.prediction_key} -validation_measure_factories: - - type_hint: AggregativeCLMCrossEntropyLossFactory +evaluation_measures: + - component_key: eval_measures + variant_key: clm_cross_entropy_loss config: target_key: ${data.target_key} prediction_key: ${model.config.prediction_key} diff --git a/config_files/config_lorem_ipsum.yaml b/config_files/config_lorem_ipsum.yaml index 8b29951b..6e4e3934 100644 --- a/config_files/config_lorem_ipsum.yaml +++ b/config_files/config_lorem_ipsum.yaml @@ -170,12 +170,12 @@ loss_fn: target_key: target_ids prediction_key: logits -# TODO adapt this to new config scheme -validation_measure_factories: - - type_hint: AggregativeCLMCrossEntropyLossFactory +evaluation_measures: + - component_key: eval_measures + variant_key: clm_cross_entropy_loss config: - target_key: ${data.target_key} - prediction_key: ${model.config.prediction_key} + target_key: target_ids + prediction_key: logits wrapped_model: component_key: model diff --git a/src/modalities/config/config.py b/src/modalities/config/config.py index 7a558122..353cf14e 100644 --- a/src/modalities/config/config.py +++ b/src/modalities/config/config.py @@ -92,11 +92,6 @@ class CLMCrossEntropyLossConfig(BaseModel): target_key: str prediction_key: str -# TODO adapt to new config scheme -class ValidationMeasureFactoryConfig(BaseModel): - type_hint: ValidationMeasureFactoryTypes - config: CLMCrossEntropyLossConfig - # Checkpointing class SaveEveryKStepsCheckpointingStrategyConfig(BaseModel): @@ -286,7 +281,6 @@ class CudaEnv(BaseModel): world_size: Annotated[int, Field(strict=True, ge=1)] global_rank: Annotated[int, Field(strict=True, ge=0)] - validation_measure_factories: List[ValidationMeasureFactoryConfig] class Settings(BaseModel): class Training(BaseModel): diff --git a/src/modalities/registry/components.py b/src/modalities/registry/components.py index 40dfccfd..9c5cc4d2 100644 --- a/src/modalities/registry/components.py +++ b/src/modalities/registry/components.py @@ -38,6 +38,8 @@ ) from modalities.dataloader.dataloader_factory import DataloaderFactory from modalities.dataloader.dataset_factory import DatasetFactory +from modalities.evaluation.clm_cross_entropy_loss import AggregativeCLMCrossEntropyLossFactory +from modalities.evaluation.perplexity import AggregativePerplexityFactory from modalities.logging_broker.subscriber_impl.subscriber_factory import ( ProgressSubscriberFactory, ResultsSubscriberFactory, @@ -71,6 +73,14 @@ class ComponentEntity: ComponentEntity("model", "fsdp_wrapped", ModelFactory.get_fsdp_wrapped_model, FSDPWrappedModelConfig), # losses ComponentEntity("loss", "clm_cross_entropy_loss", CLMCrossEntropyLoss, CLMCrossEntropyLossConfig), + # EvalLosses + ComponentEntity( + "eval_measures", + "clm_cross_entropy_loss", + AggregativeCLMCrossEntropyLossFactory, + CLMCrossEntropyLossConfig, + ), + ComponentEntity("eval_measures", "perplexity", AggregativePerplexityFactory, CLMCrossEntropyLossConfig), # optmizers ComponentEntity("optimizer", "adam_w", OptimizerFactory.get_adam_w, AdamWOptimizerConfig), ComponentEntity(