diff --git a/src/pytorch_tabular/config/config.py b/src/pytorch_tabular/config/config.py index 999c2c4a..ce6e1d0d 100644 --- a/src/pytorch_tabular/config/config.py +++ b/src/pytorch_tabular/config/config.py @@ -192,9 +192,9 @@ class DataConfig: ) def __post_init__(self): - assert ( - len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0 - ), "There should be at-least one feature defined in categorical, continuous, or date columns" + assert len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0, ( + "There should be at-least one feature defined in categorical, continuous, or date columns" + ) _validate_choices(self) if os.name == "nt" and self.num_workers != 0: print("Windows does not support num_workers > 0. Setting num_workers to 0") @@ -255,9 +255,9 @@ class InferredConfig: def __post_init__(self): if self.embedding_dims is not None: - assert all( - (isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims - ), "embedding_dims must be a list of tuples (cardinality, embedding_dim)" + assert all((isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims), ( + "embedding_dims must be a list of tuples (cardinality, embedding_dim)" + ) self.embedded_cat_dim = sum([t[1] for t in self.embedding_dims]) else: self.embedded_cat_dim = 0 @@ -287,8 +287,8 @@ class TrainerConfig: 'cpu','gpu','tpu','ipu', 'mps', 'auto'. Defaults to 'auto'. Choices are: [`cpu`,`gpu`,`tpu`,`ipu`,'mps',`auto`]. - devices (Optional[int]): Number of devices to train on (int). -1 uses all available devices. By - default, uses all available devices (-1) + devices (Union[int, List[int]]): Number of devices to train on (int), or list of device indices. + -1 uses all available devices. By default, uses all available devices (-1) devices_list (Optional[List[int]]): List of devices to train on (list). If specified, takes precedence over `devices` argument. Defaults to None @@ -400,7 +400,7 @@ class TrainerConfig: "choices": ["cpu", "gpu", "tpu", "ipu", "mps", "auto"], }, ) - devices: Optional[int] = field( + devices: Any = field( default=-1, metadata={ "help": "Number of devices to train on. -1 uses all available devices." diff --git a/tests/test_config.py b/tests/test_config.py new file mode 100644 index 00000000..a59dab16 --- /dev/null +++ b/tests/test_config.py @@ -0,0 +1,73 @@ +#!/usr/bin/env python +"""Tests for config classes.""" + +from omegaconf import OmegaConf + +from pytorch_tabular.config import TrainerConfig + + +class TestTrainerConfig: + """Tests for TrainerConfig class.""" + + def test_devices_list_to_devices_conversion(self): + """Test that devices_list is properly converted to devices.""" + # Test with a list of devices + trainer_config = TrainerConfig(devices_list=[0, 1]) + assert trainer_config.devices == [0, 1] + + # Wrap with OmegaConf as done in TabularModel + config = OmegaConf.structured(trainer_config) + assert config.devices == [0, 1] + + def test_devices_list_multiple_gpus(self): + """Test devices_list with multiple GPU IDs as documented.""" + trainer_config = TrainerConfig(devices_list=[1, 2, 3, 4]) + assert trainer_config.devices == [1, 2, 3, 4] + + config = OmegaConf.structured(trainer_config) + assert config.devices == [1, 2, 3, 4] + + def test_devices_int_value(self): + """Test that devices accepts integer values.""" + trainer_config = TrainerConfig(devices=2) + assert trainer_config.devices == 2 + + config = OmegaConf.structured(trainer_config) + assert config.devices == 2 + + def test_devices_default_value(self): + """Test that devices has default value of -1.""" + trainer_config = TrainerConfig() + assert trainer_config.devices == -1 + + config = OmegaConf.structured(trainer_config) + assert config.devices == -1 + + def test_devices_list_single_device(self): + """Test devices_list with a single device.""" + trainer_config = TrainerConfig(devices_list=[0]) + assert trainer_config.devices == [0] + + config = OmegaConf.structured(trainer_config) + assert config.devices == [0] + + def test_devices_list_precedence(self): + """Test that devices_list takes precedence over devices.""" + # When both are provided, devices_list should take precedence + trainer_config = TrainerConfig(devices=2, devices_list=[0, 1]) + assert trainer_config.devices == [0, 1] + + config = OmegaConf.structured(trainer_config) + assert config.devices == [0, 1] + + def test_omegaconf_merge_compatibility(self): + """Test that config works correctly with OmegaConf.merge.""" + trainer_config = TrainerConfig(devices_list=[0, 1], max_epochs=10) + config = OmegaConf.structured(trainer_config) + + # Simulate merging as done in TabularModel + merged = OmegaConf.merge(OmegaConf.to_container(config), {"accelerator": "gpu"}) + + assert merged.devices == [0, 1] + assert merged.max_epochs == 10 + assert merged.accelerator == "gpu"