Skip to content

Commit

Permalink
Fix incorrect all_groups order configuration in HLabelInfo (#4067)
Browse files Browse the repository at this point in the history
* Fix all_labels

* Update CHAGELOG

* label_groups change
  • Loading branch information
harimkang authored Oct 25, 2024
1 parent 8bba44c commit a2d2c81
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 11 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4018>)
- Update HPO interface
(<https://github.com/openvinotoolkit/training_extensions/pull/4035>)
- Bump onnx to 1.17.0 to omit CVE-2024-5187
(<https://github.com/openvinotoolkit/training_extensions/pull/4063>)

### Bug fixes

Expand Down Expand Up @@ -104,6 +106,8 @@ All notable changes to this project will be documented in this file.
(<https://github.com/openvinotoolkit/training_extensions/pull/4052>)
- Fix applying model's hparams when loading model from checkpoint
(<https://github.com/openvinotoolkit/training_extensions/pull/4057>)
- Fix incorrect all_groups order configuration in HLabelInfo
(<https://github.com/openvinotoolkit/training_extensions/pull/4067>)

## \[v2.1.0\]

Expand Down
2 changes: 1 addition & 1 deletion src/otx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# Copyright (C) 2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

__version__ = "2.2.0rc11"
__version__ = "2.2.0rc12"

import os
from pathlib import Path
Expand Down
21 changes: 11 additions & 10 deletions src/otx/core/types/label.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,10 +169,8 @@ def from_dm_label_groups(cls, dm_label_categories: LabelCategories) -> HLabelInf
dm_label_categories (LabelCategories): the label categories of datumaro.
"""

def get_exclusive_group_info(all_groups: list[Label | list[Label]]) -> dict[str, Any]:
def get_exclusive_group_info(exclusive_groups: list[Label | list[Label]]) -> dict[str, Any]:
"""Get exclusive group information."""
exclusive_groups = [g for g in all_groups if len(g) > 1]

last_logits_pos = 0
num_single_label_classes = 0
head_idx_to_logits_range = {}
Expand All @@ -193,12 +191,10 @@ def get_exclusive_group_info(all_groups: list[Label | list[Label]]) -> dict[str,
}

def get_single_label_group_info(
all_groups: list[Label | list[Label]],
single_label_groups: list[Label | list[Label]],
num_exclusive_groups: int,
) -> dict[str, Any]:
"""Get single label group information."""
single_label_groups = [g for g in all_groups if len(g) == 1]

class_to_idx = {}

for i, group in enumerate(single_label_groups):
Expand Down Expand Up @@ -256,8 +252,13 @@ def convert_labels_if_needed(
label_names = [item.name for item in dm_label_categories.items]
all_groups = convert_labels_if_needed(dm_label_categories, label_names)

exclusive_group_info = get_exclusive_group_info(all_groups)
single_label_group_info = get_single_label_group_info(all_groups, exclusive_group_info["num_multiclass_heads"])
exclusive_groups = [g for g in all_groups if len(g) > 1]
exclusive_group_info = get_exclusive_group_info(exclusive_groups)
single_label_groups = [g for g in all_groups if len(g) == 1]
single_label_group_info = get_single_label_group_info(
single_label_groups,
exclusive_group_info["num_multiclass_heads"],
)

merged_class_to_idx = merge_class_to_idx(
exclusive_group_info["class_to_idx"],
Expand All @@ -268,13 +269,13 @@ def convert_labels_if_needed(

return HLabelInfo(
label_names=label_names,
label_groups=all_groups,
label_groups=exclusive_groups + single_label_groups,
num_multiclass_heads=exclusive_group_info["num_multiclass_heads"],
num_multilabel_classes=single_label_group_info["num_multilabel_classes"],
head_idx_to_logits_range=exclusive_group_info["head_idx_to_logits_range"],
num_single_label_classes=exclusive_group_info["num_single_label_classes"],
class_to_group_idx=merged_class_to_idx,
all_groups=all_groups,
all_groups=exclusive_groups + single_label_groups,
label_to_idx=label_to_idx,
label_tree_edges=get_label_tree_edges(dm_label_categories.items),
empty_multiclass_head_indices=[], # consider the label removing case
Expand Down

0 comments on commit a2d2c81

Please sign in to comment.