Skip to content

Commit

Permalink
fix: disappearing tags and update tags inplace (#238)
Browse files (browse the repository at this point in the history)
  • Loading branch information
frederik-encord authored Mar 8, 2023
1 parent 27a9f72 commit 7629288
Show file tree
Hide file tree
Showing 6 changed files with 38 additions and 39 deletions.
2 changes: 1 addition & 1 deletion src/encord_active/app/common/components/metric_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ def render_summary_item(row, metric_name: str, iqr_outliers: IqrOutliers, metric
image = show_image_and_draw_polygons(row, get_state().project_paths.data, get_state().object_drawing_configurations)
st.image(image)

multiselect_tag(row, f"{metric_name}_summary", metric_scope)
multiselect_tag(row, f"{metric_name}_summary")

# === Write scores and link to editor === #

Expand Down
7 changes: 3 additions & 4 deletions src/encord_active/app/common/components/prediction_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
show_image_and_draw_polygons,
show_image_with_predictions_and_label,
)
from encord_active.lib.metrics.utils import MetricScope
from encord_active.lib.model_predictions.reader import (
LabelMatchSchema,
PredictionMatchSchema,
Expand All @@ -46,7 +45,7 @@ def build_card_for_labels(
draw_configurations=get_state().object_drawing_configurations,
)
st.image(image)
multiselect_tag(label, "false_negatives", MetricScope.MODEL_QUALITY)
multiselect_tag(label, "false_negatives")

cls = get_state().predictions.all_classes[str(label["class_id"])]["name"]
label = label.copy()
Expand All @@ -60,7 +59,7 @@ def build_card_for_predictions(row: pd.Series, data_dir: Path, box_color=Color.G
image = show_image_and_draw_polygons(row, data_dir, draw_configurations=conf, skip_object_hash=True)
image = draw_object(image, row, draw_configuration=conf, color=box_color, with_box=True)
st.image(image)
multiselect_tag(row, "metric_view", MetricScope.MODEL_QUALITY)
multiselect_tag(row, "metric_view", is_predictions=True)

# === Write scores and link to editor === #
build_data_tags(row, get_state().predictions.metric_datas.selected_predicion)
Expand Down Expand Up @@ -101,7 +100,7 @@ def prediction_grid(
subset = render_df_slicer(df, selected_metric)
paginated_subset = render_pagination(subset, n_cols, n_rows, selected_metric)

form = bulk_tagging_form(MetricScope.MODEL_QUALITY)
form = bulk_tagging_form(subset, is_predictions=True)
if form and form.submitted:
df = paginated_subset if form.level == BulkLevel.PAGE else subset
action_bulk_tags(df, form.tags, form.action)
Expand Down
31 changes: 16 additions & 15 deletions src/encord_active/app/common/components/tags/bulk_tagging_form.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,16 @@
from typing import List, NamedTuple, Optional

import streamlit as st
from pandas import DataFrame
from pandera.typing import DataFrame

from encord_active.app.common.components.tags.individual_tagging import (
target_identifier,
)
from encord_active.app.common.components.tags.tag_creator import scoped_tags
from encord_active.app.common.state import get_state
from encord_active.lib.db.merged_metrics import MergedMetrics
from encord_active.lib.db.tags import METRIC_SCOPE_TAG_SCOPES, Tag
from encord_active.lib.metrics.utils import MetricScope
from encord_active.lib.db.tags import Tag, TagScope
from encord_active.lib.metrics.utils import IdentifierSchema


class TagAction(str, Enum):
Expand All @@ -31,33 +31,34 @@ class TaggingFormResult(NamedTuple):
action: TagAction


def action_bulk_tags(subset: DataFrame, selected_tags: List[Tag], action: TagAction):
def action_bulk_tags(subset: DataFrame[IdentifierSchema], selected_tags: List[Tag], action: TagAction):
if not selected_tags:
return

all_df = get_state().merged_metrics.copy()

for selected_tag in selected_tags:
target_ids = [target_identifier(id, selected_tag.scope) for id in subset.identifier.to_list()]
for id, tags in all_df.loc[target_ids, "tags"].items():
if action == TagAction.ADD:
next = list(set(tags + [selected_tag]))
elif action == TagAction.REMOVE:
next = list(set(tag for tag in tags if tag != selected_tag))
else:
raise Exception(f"Action {action} is not supported")

all_df.at[id, "tags"] = next
for _, tags in all_df.loc[target_ids, "tags"].items():
if action == TagAction.ADD and selected_tag not in tags:
tags.append(selected_tag)
elif action == TagAction.REMOVE and selected_tag in tags:
tags.remove(selected_tag)

get_state().merged_metrics = all_df
MergedMetrics().replace_all(all_df)


def bulk_tagging_form(metric_type: MetricScope) -> Optional[TaggingFormResult]:
def bulk_tagging_form(df: DataFrame[IdentifierSchema], is_predictions: bool = False) -> Optional[TaggingFormResult]:
with st.expander("Bulk Tagging"):
with st.form("bulk_tagging"):
select, level_radio, action_radio, button = st.columns([6, 2, 2, 1])
allowed_tags = scoped_tags(METRIC_SCOPE_TAG_SCOPES[metric_type])
allowed_tag_scopes = {TagScope.DATA}
all_rows_have_objects = df[IdentifierSchema.identifier].map(lambda x: len(x.split("_")) > 3).all()
if all_rows_have_objects and not is_predictions:
allowed_tag_scopes.add(TagScope.LABEL)
allowed_tags = scoped_tags(allowed_tag_scopes)

selected_tags = select.multiselect(
label="Tags", options=allowed_tags, format_func=lambda x: x[0], label_visibility="collapsed"
)
Expand Down
25 changes: 15 additions & 10 deletions src/encord_active/app/common/components/tags/individual_tagging.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
from typing import Dict, List, Optional
from functools import partial
from typing import Dict, List, Optional, Set

import streamlit as st
from pandas import Series

from encord_active.app.common.components.tags.tag_creator import scoped_tags
from encord_active.app.common.state import get_state
from encord_active.lib.db.merged_metrics import MergedMetrics
from encord_active.lib.db.tags import METRIC_SCOPE_TAG_SCOPES, Tag, TagScope
from encord_active.lib.metrics.utils import MetricScope
from encord_active.lib.db.tags import Tag, TagScope


def target_identifier(identifier: str, scope: TagScope) -> Optional[str]:
parts = identifier.split("_")
data, object = parts[:3], parts[3:]
data = parts[:3]

if scope == TagScope.DATA:
return "_".join(data)
Expand All @@ -22,28 +22,33 @@ def target_identifier(identifier: str, scope: TagScope) -> Optional[str]:
return None


def update_tags(identifier: str, key: str):
def update_tags(identifier: str, key: str, scopes: Set[TagScope]):
tags_for_update: List[Tag] = st.session_state[key]

targeted_tags: Dict[str, List[Tag]] = {}
for scope in TagScope:
for scope in scopes:
target_id = target_identifier(identifier, scope)
if target_id:
targeted_tags[target_id] = [tag for tag in tags_for_update if tag.scope == scope]

for id, tags in targeted_tags.items():
get_state().merged_metrics.at[id, "tags"] = tags
tag_arr = get_state().merged_metrics.at[id, "tags"]
tag_arr.clear()
tag_arr.extend(tags)
MergedMetrics().update_tags(id, tags)


def multiselect_tag(row: Series, key_prefix: str, metric_type: MetricScope):
def multiselect_tag(row: Series, key_prefix: str, is_predictions=False):
identifier = row["identifier"]

if not isinstance(identifier, str):
st.error("Multiple rows with the same identifier were found. Please create a new issue.")
return

metric_scopes = METRIC_SCOPE_TAG_SCOPES[metric_type]
_, _, _, *objects = identifier.split("_")
metric_scopes = {TagScope.DATA}
if objects and not is_predictions:
metric_scopes.add(TagScope.LABEL)

tag_status = []
merged_metrics = get_state().merged_metrics
Expand All @@ -61,6 +66,6 @@ def multiselect_tag(row: Series, key_prefix: str, metric_type: MetricScope):
default=tag_status if len(tag_status) else None,
key=key,
label_visibility="collapsed",
on_change=update_tags,
on_change=partial(update_tags, scopes=metric_scopes),
args=(identifier, key),
)
5 changes: 3 additions & 2 deletions src/encord_active/app/data_quality/sub_pages/explorer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
)
from encord_active.lib.metrics.metric import EmbeddingType
from encord_active.lib.metrics.utils import (
IdentifierSchema,
MetricData,
MetricSchema,
MetricScope,
Expand Down Expand Up @@ -223,7 +224,7 @@ def fill_data_quality_window(

paginated_subset = render_pagination(subset, n_cols, n_rows, "score")

form = bulk_tagging_form(metric_scope)
form = bulk_tagging_form(subset.pipe(DataFrame[IdentifierSchema]))

if form and form.submitted:
df = paginated_subset if form.level == BulkLevel.PAGE else subset
Expand Down Expand Up @@ -286,7 +287,7 @@ def build_card(
description = re.sub(r"(\d+\.\d{0,3})\d*", r"\1", row["description"])
st.write(f"Description: {description}")

multiselect_tag(row, "explorer", metric_scope)
multiselect_tag(row, "explorer")

target_expander = similarity_expanders[card_no // get_state().page_grid_settings.columns]

Expand Down
7 changes: 0 additions & 7 deletions src/encord_active/lib/db/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from typing import Callable, List, NamedTuple

from encord_active.lib.db.connection import DBConnection
from encord_active.lib.metrics.utils import MetricScope

TABLE_NAME = "tags"

Expand All @@ -23,12 +22,6 @@ class Tag(NamedTuple):
TagScope.LABEL.value: "✏️",
}

METRIC_SCOPE_TAG_SCOPES = {
MetricScope.DATA_QUALITY: {TagScope.DATA},
MetricScope.LABEL_QUALITY: {TagScope.DATA, TagScope.LABEL},
MetricScope.MODEL_QUALITY: {TagScope.DATA},
}


def ensure_existence(fn: Callable):
def wrapper(*args, **kwargs):
Expand Down

0 comments on commit 7629288

Please sign in to comment.