diff --git a/src/encord_active/app/common/components/metric_summary.py b/src/encord_active/app/common/components/metric_summary.py index e5a530745..59f275ff3 100644 --- a/src/encord_active/app/common/components/metric_summary.py +++ b/src/encord_active/app/common/components/metric_summary.py @@ -329,7 +329,7 @@ def render_summary_item(row, metric_name: str, iqr_outliers: IqrOutliers, metric image = show_image_and_draw_polygons(row, get_state().project_paths.data, get_state().object_drawing_configurations) st.image(image) - multiselect_tag(row, f"{metric_name}_summary", metric_scope) + multiselect_tag(row, f"{metric_name}_summary") # === Write scores and link to editor === # diff --git a/src/encord_active/app/common/components/prediction_grid.py b/src/encord_active/app/common/components/prediction_grid.py index d43ef8d5b..678fda831 100644 --- a/src/encord_active/app/common/components/prediction_grid.py +++ b/src/encord_active/app/common/components/prediction_grid.py @@ -21,7 +21,6 @@ show_image_and_draw_polygons, show_image_with_predictions_and_label, ) -from encord_active.lib.metrics.utils import MetricScope from encord_active.lib.model_predictions.reader import ( LabelMatchSchema, PredictionMatchSchema, @@ -46,7 +45,7 @@ def build_card_for_labels( draw_configurations=get_state().object_drawing_configurations, ) st.image(image) - multiselect_tag(label, "false_negatives", MetricScope.MODEL_QUALITY) + multiselect_tag(label, "false_negatives") cls = get_state().predictions.all_classes[str(label["class_id"])]["name"] label = label.copy() @@ -60,7 +59,7 @@ def build_card_for_predictions(row: pd.Series, data_dir: Path, box_color=Color.G image = show_image_and_draw_polygons(row, data_dir, draw_configurations=conf, skip_object_hash=True) image = draw_object(image, row, draw_configuration=conf, color=box_color, with_box=True) st.image(image) - multiselect_tag(row, "metric_view", MetricScope.MODEL_QUALITY) + multiselect_tag(row, "metric_view", is_predictions=True) # === Write scores and link to editor === # build_data_tags(row, get_state().predictions.metric_datas.selected_predicion) @@ -101,7 +100,7 @@ def prediction_grid( subset = render_df_slicer(df, selected_metric) paginated_subset = render_pagination(subset, n_cols, n_rows, selected_metric) - form = bulk_tagging_form(MetricScope.MODEL_QUALITY) + form = bulk_tagging_form(subset, is_predictions=True) if form and form.submitted: df = paginated_subset if form.level == BulkLevel.PAGE else subset action_bulk_tags(df, form.tags, form.action) diff --git a/src/encord_active/app/common/components/tags/bulk_tagging_form.py b/src/encord_active/app/common/components/tags/bulk_tagging_form.py index c1b578df1..0843d3dec 100644 --- a/src/encord_active/app/common/components/tags/bulk_tagging_form.py +++ b/src/encord_active/app/common/components/tags/bulk_tagging_form.py @@ -2,7 +2,7 @@ from typing import List, NamedTuple, Optional import streamlit as st -from pandas import DataFrame +from pandera.typing import DataFrame from encord_active.app.common.components.tags.individual_tagging import ( target_identifier, @@ -10,8 +10,8 @@ from encord_active.app.common.components.tags.tag_creator import scoped_tags from encord_active.app.common.state import get_state from encord_active.lib.db.merged_metrics import MergedMetrics -from encord_active.lib.db.tags import METRIC_SCOPE_TAG_SCOPES, Tag -from encord_active.lib.metrics.utils import MetricScope +from encord_active.lib.db.tags import Tag, TagScope +from encord_active.lib.metrics.utils import IdentifierSchema class TagAction(str, Enum): @@ -31,7 +31,7 @@ class TaggingFormResult(NamedTuple): action: TagAction -def action_bulk_tags(subset: DataFrame, selected_tags: List[Tag], action: TagAction): +def action_bulk_tags(subset: DataFrame[IdentifierSchema], selected_tags: List[Tag], action: TagAction): if not selected_tags: return @@ -39,25 +39,26 @@ def action_bulk_tags(subset: DataFrame, selected_tags: List[Tag], action: TagAct for selected_tag in selected_tags: target_ids = [target_identifier(id, selected_tag.scope) for id in subset.identifier.to_list()] - for id, tags in all_df.loc[target_ids, "tags"].items(): - if action == TagAction.ADD: - next = list(set(tags + [selected_tag])) - elif action == TagAction.REMOVE: - next = list(set(tag for tag in tags if tag != selected_tag)) - else: - raise Exception(f"Action {action} is not supported") - - all_df.at[id, "tags"] = next + for _, tags in all_df.loc[target_ids, "tags"].items(): + if action == TagAction.ADD and selected_tag not in tags: + tags.append(selected_tag) + elif action == TagAction.REMOVE and selected_tag in tags: + tags.remove(selected_tag) get_state().merged_metrics = all_df MergedMetrics().replace_all(all_df) -def bulk_tagging_form(metric_type: MetricScope) -> Optional[TaggingFormResult]: +def bulk_tagging_form(df: DataFrame[IdentifierSchema], is_predictions: bool = False) -> Optional[TaggingFormResult]: with st.expander("Bulk Tagging"): with st.form("bulk_tagging"): select, level_radio, action_radio, button = st.columns([6, 2, 2, 1]) - allowed_tags = scoped_tags(METRIC_SCOPE_TAG_SCOPES[metric_type]) + allowed_tag_scopes = {TagScope.DATA} + all_rows_have_objects = df[IdentifierSchema.identifier].map(lambda x: len(x.split("_")) > 3).all() + if all_rows_have_objects and not is_predictions: + allowed_tag_scopes.add(TagScope.LABEL) + allowed_tags = scoped_tags(allowed_tag_scopes) + selected_tags = select.multiselect( label="Tags", options=allowed_tags, format_func=lambda x: x[0], label_visibility="collapsed" ) diff --git a/src/encord_active/app/common/components/tags/individual_tagging.py b/src/encord_active/app/common/components/tags/individual_tagging.py index c16eda2fb..966f308b9 100644 --- a/src/encord_active/app/common/components/tags/individual_tagging.py +++ b/src/encord_active/app/common/components/tags/individual_tagging.py @@ -1,4 +1,5 @@ -from typing import Dict, List, Optional +from functools import partial +from typing import Dict, List, Optional, Set import streamlit as st from pandas import Series @@ -6,13 +7,12 @@ from encord_active.app.common.components.tags.tag_creator import scoped_tags from encord_active.app.common.state import get_state from encord_active.lib.db.merged_metrics import MergedMetrics -from encord_active.lib.db.tags import METRIC_SCOPE_TAG_SCOPES, Tag, TagScope -from encord_active.lib.metrics.utils import MetricScope +from encord_active.lib.db.tags import Tag, TagScope def target_identifier(identifier: str, scope: TagScope) -> Optional[str]: parts = identifier.split("_") - data, object = parts[:3], parts[3:] + data = parts[:3] if scope == TagScope.DATA: return "_".join(data) @@ -22,28 +22,33 @@ def target_identifier(identifier: str, scope: TagScope) -> Optional[str]: return None -def update_tags(identifier: str, key: str): +def update_tags(identifier: str, key: str, scopes: Set[TagScope]): tags_for_update: List[Tag] = st.session_state[key] targeted_tags: Dict[str, List[Tag]] = {} - for scope in TagScope: + for scope in scopes: target_id = target_identifier(identifier, scope) if target_id: targeted_tags[target_id] = [tag for tag in tags_for_update if tag.scope == scope] for id, tags in targeted_tags.items(): - get_state().merged_metrics.at[id, "tags"] = tags + tag_arr = get_state().merged_metrics.at[id, "tags"] + tag_arr.clear() + tag_arr.extend(tags) MergedMetrics().update_tags(id, tags) -def multiselect_tag(row: Series, key_prefix: str, metric_type: MetricScope): +def multiselect_tag(row: Series, key_prefix: str, is_predictions=False): identifier = row["identifier"] if not isinstance(identifier, str): st.error("Multiple rows with the same identifier were found. Please create a new issue.") return - metric_scopes = METRIC_SCOPE_TAG_SCOPES[metric_type] + _, _, _, *objects = identifier.split("_") + metric_scopes = {TagScope.DATA} + if objects and not is_predictions: + metric_scopes.add(TagScope.LABEL) tag_status = [] merged_metrics = get_state().merged_metrics @@ -61,6 +66,6 @@ def multiselect_tag(row: Series, key_prefix: str, metric_type: MetricScope): default=tag_status if len(tag_status) else None, key=key, label_visibility="collapsed", - on_change=update_tags, + on_change=partial(update_tags, scopes=metric_scopes), args=(identifier, key), ) diff --git a/src/encord_active/app/data_quality/sub_pages/explorer.py b/src/encord_active/app/data_quality/sub_pages/explorer.py index 5ab2cca78..c8bfa62c1 100644 --- a/src/encord_active/app/data_quality/sub_pages/explorer.py +++ b/src/encord_active/app/data_quality/sub_pages/explorer.py @@ -41,6 +41,7 @@ ) from encord_active.lib.metrics.metric import EmbeddingType from encord_active.lib.metrics.utils import ( + IdentifierSchema, MetricData, MetricSchema, MetricScope, @@ -223,7 +224,7 @@ def fill_data_quality_window( paginated_subset = render_pagination(subset, n_cols, n_rows, "score") - form = bulk_tagging_form(metric_scope) + form = bulk_tagging_form(subset.pipe(DataFrame[IdentifierSchema])) if form and form.submitted: df = paginated_subset if form.level == BulkLevel.PAGE else subset @@ -286,7 +287,7 @@ def build_card( description = re.sub(r"(\d+\.\d{0,3})\d*", r"\1", row["description"]) st.write(f"Description: {description}") - multiselect_tag(row, "explorer", metric_scope) + multiselect_tag(row, "explorer") target_expander = similarity_expanders[card_no // get_state().page_grid_settings.columns] diff --git a/src/encord_active/lib/db/tags.py b/src/encord_active/lib/db/tags.py index 36bb5e24a..e3b5cd7d2 100644 --- a/src/encord_active/lib/db/tags.py +++ b/src/encord_active/lib/db/tags.py @@ -3,7 +3,6 @@ from typing import Callable, List, NamedTuple from encord_active.lib.db.connection import DBConnection -from encord_active.lib.metrics.utils import MetricScope TABLE_NAME = "tags" @@ -23,12 +22,6 @@ class Tag(NamedTuple): TagScope.LABEL.value: "✏️", } -METRIC_SCOPE_TAG_SCOPES = { - MetricScope.DATA_QUALITY: {TagScope.DATA}, - MetricScope.LABEL_QUALITY: {TagScope.DATA, TagScope.LABEL}, - MetricScope.MODEL_QUALITY: {TagScope.DATA}, -} - def ensure_existence(fn: Callable): def wrapper(*args, **kwargs):