Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions src/pytorch_tabular/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,9 @@ class DataConfig:
)

def __post_init__(self):
assert (
len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0
), "There should be at-least one feature defined in categorical, continuous, or date columns"
assert len(self.categorical_cols) + len(self.continuous_cols) + len(self.date_columns) > 0, (
"There should be at-least one feature defined in categorical, continuous, or date columns"
)
_validate_choices(self)
if os.name == "nt" and self.num_workers != 0:
print("Windows does not support num_workers > 0. Setting num_workers to 0")
Expand Down Expand Up @@ -255,9 +255,9 @@ class InferredConfig:

def __post_init__(self):
if self.embedding_dims is not None:
assert all(
(isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims
), "embedding_dims must be a list of tuples (cardinality, embedding_dim)"
assert all((isinstance(t, Iterable) and len(t) == 2) for t in self.embedding_dims), (
"embedding_dims must be a list of tuples (cardinality, embedding_dim)"
)
self.embedded_cat_dim = sum([t[1] for t in self.embedding_dims])
else:
self.embedded_cat_dim = 0
Expand Down Expand Up @@ -287,8 +287,8 @@ class TrainerConfig:
'cpu','gpu','tpu','ipu', 'mps', 'auto'. Defaults to 'auto'.
Choices are: [`cpu`,`gpu`,`tpu`,`ipu`,'mps',`auto`].

devices (Optional[int]): Number of devices to train on (int). -1 uses all available devices. By
default, uses all available devices (-1)
devices (Union[int, List[int]]): Number of devices to train on (int), or list of device indices.
-1 uses all available devices. By default, uses all available devices (-1)

devices_list (Optional[List[int]]): List of devices to train on (list). If specified, takes
precedence over `devices` argument. Defaults to None
Expand Down Expand Up @@ -400,7 +400,7 @@ class TrainerConfig:
"choices": ["cpu", "gpu", "tpu", "ipu", "mps", "auto"],
},
)
devices: Optional[int] = field(
devices: Any = field(
default=-1,
metadata={
"help": "Number of devices to train on. -1 uses all available devices."
Expand Down
73 changes: 73 additions & 0 deletions tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#!/usr/bin/env python
"""Tests for config classes."""

from omegaconf import OmegaConf

from pytorch_tabular.config import TrainerConfig


class TestTrainerConfig:
"""Tests for TrainerConfig class."""

def test_devices_list_to_devices_conversion(self):
"""Test that devices_list is properly converted to devices."""
# Test with a list of devices
trainer_config = TrainerConfig(devices_list=[0, 1])
assert trainer_config.devices == [0, 1]

# Wrap with OmegaConf as done in TabularModel
config = OmegaConf.structured(trainer_config)
assert config.devices == [0, 1]

def test_devices_list_multiple_gpus(self):
"""Test devices_list with multiple GPU IDs as documented."""
trainer_config = TrainerConfig(devices_list=[1, 2, 3, 4])
assert trainer_config.devices == [1, 2, 3, 4]

config = OmegaConf.structured(trainer_config)
assert config.devices == [1, 2, 3, 4]

def test_devices_int_value(self):
"""Test that devices accepts integer values."""
trainer_config = TrainerConfig(devices=2)
assert trainer_config.devices == 2

config = OmegaConf.structured(trainer_config)
assert config.devices == 2

def test_devices_default_value(self):
"""Test that devices has default value of -1."""
trainer_config = TrainerConfig()
assert trainer_config.devices == -1

config = OmegaConf.structured(trainer_config)
assert config.devices == -1

def test_devices_list_single_device(self):
"""Test devices_list with a single device."""
trainer_config = TrainerConfig(devices_list=[0])
assert trainer_config.devices == [0]

config = OmegaConf.structured(trainer_config)
assert config.devices == [0]

def test_devices_list_precedence(self):
"""Test that devices_list takes precedence over devices."""
# When both are provided, devices_list should take precedence
trainer_config = TrainerConfig(devices=2, devices_list=[0, 1])
assert trainer_config.devices == [0, 1]

config = OmegaConf.structured(trainer_config)
assert config.devices == [0, 1]

def test_omegaconf_merge_compatibility(self):
"""Test that config works correctly with OmegaConf.merge."""
trainer_config = TrainerConfig(devices_list=[0, 1], max_epochs=10)
config = OmegaConf.structured(trainer_config)

# Simulate merging as done in TabularModel
merged = OmegaConf.merge(OmegaConf.to_container(config), {"accelerator": "gpu"})

assert merged.devices == [0, 1]
assert merged.max_epochs == 10
assert merged.accelerator == "gpu"