feat: Unique Experiment ID #181

Merged 4 commits on Jul 11, 2024
13 changes: 8 additions & 5 deletions src/modalities/config/config.py
@@ -1,3 +1,4 @@
from functools import partial
import os
from pathlib import Path
from typing import Annotated, Callable, Dict, List, Literal, Optional, Tuple
@@ -26,7 +27,7 @@
)
from modalities.config.utils import parse_torch_device
from modalities.running_env.env_utils import MixedPrecisionSettings, has_bfloat_support
from modalities.util import get_date_of_run, parse_enum_by_name
from modalities.util import get_experiment_id_of_run, parse_enum_by_name


class ProcessGroupBackendType(LookupEnum):
@@ -356,18 +357,20 @@ def cuda_env_resolver_fun(var_name: str) -> int:
int_env_variable_names = ["LOCAL_RANK", "WORLD_SIZE", "RANK"]
return int(os.getenv(var_name)) if var_name in int_env_variable_names else os.getenv(var_name)

def modalities_env_resolver_fun(var_name: str) -> int:
def modalities_env_resolver_fun(var_name: str, config_file_path: Path) -> str | Path:
if var_name == "experiment_id":
return get_date_of_run()
if var_name == "config_file_path":
return get_experiment_id_of_run(config_file_path=config_file_path)
elif var_name == "config_file_path":
return config_file_path
else:
raise ValueError(f"Unknown modalities_env variable: {var_name}.")

def node_env_resolver_fun(var_name: str) -> int:
if var_name == "num_cpus":
return os.cpu_count()

OmegaConf.register_new_resolver("cuda_env", cuda_env_resolver_fun, replace=True)
OmegaConf.register_new_resolver("modalities_env", modalities_env_resolver_fun, replace=True)
OmegaConf.register_new_resolver("modalities_env", partial(modalities_env_resolver_fun, config_file_path=config_file_path), replace=True)
OmegaConf.register_new_resolver("node_env", node_env_resolver_fun, replace=True)

cfg = OmegaConf.load(config_file_path)
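For context, the following is a minimal, self-contained sketch of how the partially applied resolver behaves once registered. The function body mirrors the diff above, but the config path, the stand-in return value, and the interpolation key are illustrative and not taken verbatim from the repository's configs.

from functools import partial
from pathlib import Path

from omegaconf import OmegaConf


def modalities_env_resolver_fun(var_name: str, config_file_path: Path) -> str:
    # Mirrors the diff: "experiment_id" is derived from the config file path,
    # "config_file_path" is returned as-is; anything else is rejected.
    if var_name == "experiment_id":
        return f"<id derived from {config_file_path}>"  # stand-in for get_experiment_id_of_run
    elif var_name == "config_file_path":
        return str(config_file_path)
    else:
        raise ValueError(f"Unknown modalities_env variable: {var_name}.")


config_file_path = Path("configs/example_config.yaml")  # illustrative path
OmegaConf.register_new_resolver(
    "modalities_env",
    partial(modalities_env_resolver_fun, config_file_path=config_file_path),
    replace=True,
)

# A value such as "${modalities_env:experiment_id}" is resolved lazily on access:
cfg = OmegaConf.create({"experiment_id": "${modalities_env:experiment_id}"})
print(cfg.experiment_id)

The partial application is what makes this work: an OmegaConf resolver only receives the arguments written in the interpolation, so the config file path has to be bound at registration time.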
21 changes: 13 additions & 8 deletions src/modalities/util.py
@@ -1,9 +1,11 @@
import hashlib
import time
import warnings
from datetime import datetime
from enum import Enum
from pathlib import Path
from types import TracebackType
from typing import Callable, Dict, Generic, Type, TypeVar
from typing import Callable, Dict, Generic, Optional, Type, TypeVar

import torch
import torch.distributed as dist
@@ -36,12 +38,14 @@ def get_callback_interval_in_batches_per_rank(
return num_local_train_micro_batches_ret


def get_date_of_run():
"""create date and time for file save uniqueness
example: 2022-05-07__14-31-22'
def get_experiment_id_of_run(config_file_path: Path, hash_length: Optional[int] = 8) -> str:
"""create experiment ID including the date and time for file save uniqueness
example: 2022-05-07__14-31-22_3f9a1c4e
"""
hash = hashlib.sha256(str(config_file_path).encode()).hexdigest()[:hash_length]
date_of_run = datetime.now().strftime("%Y-%m-%d__%H-%M-%S")
return date_of_run
experiment_id = f"{date_of_run}_{hash}"
return experiment_id


def format_metrics_to_gb(item):
@@ -137,8 +141,9 @@ def get_all_reduced_value(
)
return value

def get_module_class_from_name(module: torch.nn.Module, name:str) -> Type[torch.nn.Module] | None:
""" From Accelerate source code

def get_module_class_from_name(module: torch.nn.Module, name: str) -> Type[torch.nn.Module] | None:
"""From Accelerate source code
(https://github.com/huggingface/accelerate/blob/1f7a79b428749f45187ec69485f2c966fe21926e/src/accelerate/utils/dataclasses.py#L1902)
Gets a class from a module by its name.

@@ -155,4 +160,4 @@ def get_module_class_from_name(module: torch.nn.Module, name:str) -> Type[torch.
for child_module in modules_children:
module_class = get_module_class_from_name(child_module, name)
if module_class is not None:
return module_class
return module_class
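For completeness, a short usage sketch of the new helper; the config path and the hash shown in the comment are illustrative.

from pathlib import Path

from modalities.util import get_experiment_id_of_run

# The experiment ID combines a timestamp with the first 8 hex characters of the
# SHA-256 hash of the config file path, e.g. "2024-07-11__14-31-22_9c1d2f3a".
experiment_id = get_experiment_id_of_run(config_file_path=Path("configs/example_config.yaml"))
print(experiment_id)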