
Logging Hyperparameters for list of dicts #19957

Open
vork opened this issue Jun 7, 2024 · 0 comments · May be fixed by #19963
Labels: bug, needs triage, ver: 2.2.x


vork commented Jun 7, 2024

Bug description

Currently, when hyperparameters are logged with `log_hyperparams`, the function calls `_flatten_dict` to collapse the dict into a single level. However, when the config contains a list of dicts, that list is not flattened and ends up logged as a single string. Instead, I would propose logging each element under its index, e.g. `key/0/item`, `key/1/item`, and so on.
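To make the current behavior concrete, here is a minimal, self-contained sketch. `flatten_current` and `stringify_non_primitives` are hypothetical stand-ins, not Lightning code. The first mimics flattening that recurses into nested dicts but keeps a list as one opaque value. The second assumes the logger then converts every value that is not a `bool`/`int`/`float`/`str` to its `str()` form, since hparam backends typically accept only scalars and strings.

```python
from typing import Any, Dict


def flatten_current(params: Dict[str, Any], parent: str = "") -> Dict[str, Any]:
    """Hypothetical stand-in for the current behavior: nested dicts are
    flattened with '/', but a list is stored as a single value."""
    out: Dict[str, Any] = {}
    for k, v in params.items():
        key = f"{parent}/{k}" if parent else str(k)
        if isinstance(v, dict):
            out.update(flatten_current(v, key))
        else:
            out[key] = v
    return out


def stringify_non_primitives(params: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical sanitizer (an assumption about the logger): any value that is
    not a bool/int/float/str is replaced by str(value)."""
    return {k: v if isinstance(v, (bool, int, float, str)) else str(v) for k, v in params.items()}


config = {"optim": {"lr": 0.1}, "layers": [{"dim": 64}, {"dim": 32}]}
print(stringify_non_primitives(flatten_current(config)))
# {'optim/lr': 0.1, 'layers': "[{'dim': 64}, {'dim': 32}]"}
```

The nested `optim` dict becomes a clean `optim/lr` key, while the whole `layers` list collapses into one string. Under the proposal, the same config would instead yield separate `layers/0/dim` and `layers/1/dim` entries.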

A simple fix could look like this:

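```python
from argparse import Namespace
from dataclasses import asdict, is_dataclass
from typing import Any, Dict, MutableMapping


def _flatten_dict(params: MutableMapping[Any, Any], delimiter: str = "/", parent_key: str = "") -> Dict[str, Any]:
    """Flatten hierarchical dict, e.g. ``{'a': {'b': 'c'}} -> {'a/b': 'c'}``.

    Args:
        params: Dictionary containing the hyperparameters
        delimiter: Delimiter to express the hierarchy. Defaults to ``'/'``.
        parent_key: Key prefix inherited from the enclosing level (used when recursing).

    Returns:
        Flattened dict.

    Examples:
        >>> _flatten_dict({'a': {'b': 'c'}})
        {'a/b': 'c'}
        >>> _flatten_dict({'a': {'b': 123}})
        {'a/b': 123}
        >>> _flatten_dict({5: {'a': 123}})
        {'5/a': 123}
        >>> _flatten_dict({"dl": [{"a": 1, "c": 3}, {"b": 2, "d": 5}], "l": [1, 2, 3, 4]})
        {'dl/0/a': 1, 'dl/0/c': 3, 'dl/1/b': 2, 'dl/1/d': 5, 'l': [1, 2, 3, 4]}
        >>> _flatten_dict({"dl": [{"a": 1}]}, delimiter=".")
        {'dl.0.a': 1}
        >>> _flatten_dict({"e": []})
        {'e': []}

    """
    result: Dict[str, Any] = {}
    for k, v in params.items():
        new_key = parent_key + delimiter + str(k) if parent_key else str(k)
        # Convert dataclass instances and argparse Namespaces to plain dicts first.
        if is_dataclass(v):
            v = asdict(v)
        elif isinstance(v, Namespace):
            v = vars(v)

        if isinstance(v, MutableMapping):
            result.update(_flatten_dict(v, parent_key=new_key, delimiter=delimiter))
        # Also handle a non-empty list of dictionaries: index each element, so that
        # [{"a": 1}, {"b": 2}] under "key" becomes "key/0/a" and "key/1/b".
        # The emptiness check stops all() from vacuously matching [] and dropping the key.
        elif isinstance(v, list) and v and all(isinstance(item, MutableMapping) for item in v):
            for i, item in enumerate(v):
                # Use the configured delimiter (not a hardcoded "/") between key and index.
                result.update(_flatten_dict(item, parent_key=f"{new_key}{delimiter}{i}", delimiter=delimiter))
        else:
            result[new_key] = v
    return result
```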

What version are you seeing the problem on?

master

How to reproduce the bug

No response

More info

No response

vork added the bug and needs triage labels on Jun 7, 2024.
vork linked a pull request on Jun 10, 2024 that will close this issue.