Skip to content

Commit

Permalink
refactor: consolidated state and bugfixes (#56)
Browse files Browse the repository at this point in the history
* refactor: consolidate state

* state cleanup and state mutator

* adds a lazy state initialization and some more cleanup

* refactor: organize state files

* removes global state const

* remove generic multiselect

* remove mutliselect

* more cleanup of predictions state

* fix multiselect values

* completely gets rid of predictions state

* Fix: convert imported polygon points to absolute values (#52)

fix: unnormalize polygon points

* adds db module

* fix missing similarities missing imports

* fix: double annotator stats title (#55)

fix: remove duplicate title

Co-authored-by: Gorkem Polat <[email protected]>
Co-authored-by: Frederik Hvilshøj <[email protected]>
  • Loading branch information
3 people authored Jan 4, 2023
1 parent 9bb2c7d commit 9d744be
Show file tree
Hide file tree
Showing 35 changed files with 525 additions and 433 deletions.
40 changes: 17 additions & 23 deletions src/encord_active/app/actions_page/export_balance.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
import pandas as pd
import streamlit as st

import encord_active.app.common.state as state
from encord_active.app.common.components import multiselect_with_all_option
from encord_active.app.common.state import get_state
from encord_active.app.common.state_hooks import use_state
from encord_active.app.common.utils import set_page_config, setup_page
from encord_active.lib.charts.partition_histogram import get_partition_histogram
from encord_active.lib.dataset.balance import balance_dataframe, get_partitions_zip
Expand All @@ -16,14 +16,6 @@
)


def add_partition():
st.session_state[state.NUMBER_OF_PARTITIONS] += 1


def remove_partition():
st.session_state[state.NUMBER_OF_PARTITIONS] -= 1


def metrics_panel() -> Tuple[List[MetricData], int]:
"""
Panel for selecting the metrics to balance over.
Expand All @@ -33,20 +25,19 @@ def metrics_panel() -> Tuple[List[MetricData], int]:
seed (int): The seed for the random sampling.
"""
# TODO - add label metrics
metrics = load_available_metrics(st.session_state.metric_dir, MetricScope.DATA_QUALITY)
metrics = load_available_metrics(get_state().project_paths.metrics, MetricScope.DATA_QUALITY)
metric_names = [metric.name for metric in metrics]

col1, col2 = st.columns([6, 1])
with col1:
selected_metric_names = multiselect_with_all_option(
label="Metrics to balance",
selected_metric_names = st.multiselect(
label="Filter by metric",
options=metric_names,
key="balance_metrics",
default=["All"],
)
seed = col2.number_input("Seed", value=42, step=1, key="seed")

if "All" in selected_metric_names:
if not selected_metric_names:
selected_metric_names = metric_names
selected_metrics = [metric for metric in metrics if metric.name in selected_metric_names]
return selected_metrics, int(seed)
Expand All @@ -59,8 +50,16 @@ def partitions_panel() -> Dict[str, int]:
Returns:
A dictionary with the partition names as keys and the partition sizes as values.
"""
get_partitions_number, set_partitions_number = use_state(1)

def add_partition():
set_partitions_number(lambda prev: prev + 1)

def remove_partition():
set_partitions_number(lambda prev: prev - 1)

partition_sizes = {}
for i in range(st.session_state[state.NUMBER_OF_PARTITIONS]):
for i in range(get_partitions_number()):
partition_columns = st.columns((4, 12, 1))
partition_name = partition_columns[0].text_input(
f"Name of partition {i + 1}", key=f"name_partition_{i + 1}", value=f"Partition {i + 1}"
Expand All @@ -70,7 +69,7 @@ def partitions_panel() -> Dict[str, int]:
key=f"size_partition_{i + 1}",
min_value=1,
max_value=100,
value=100 // st.session_state[state.NUMBER_OF_PARTITIONS],
value=100 // get_partitions_number(),
step=1,
)
if i > 0:
Expand All @@ -95,9 +94,6 @@ def export_balance():
"Here you can create balanced partitions of your dataset over a set of metrics and export them as a CSV file."
)

if not st.session_state.get(state.NUMBER_OF_PARTITIONS):
st.session_state[state.NUMBER_OF_PARTITIONS] = 1

selected_metrics, seed = metrics_panel()
partition_sizes = partitions_panel()

Expand All @@ -124,9 +120,7 @@ def export_balance():
)

with st.spinner("Generating COCO files"):
partitions_zip_file = (
get_partitions_zip(partition_dict, st.session_state.project_file_structure) if is_pressed else ""
)
partitions_zip_file = get_partitions_zip(partition_dict, get_state().project_paths) if is_pressed else ""

action_columns[1].download_button(
"⬇ Download filtered data",
Expand Down
34 changes: 18 additions & 16 deletions src/encord_active/app/actions_page/export_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
is_numeric_dtype,
)

import encord_active.app.common.state as state
from encord_active.app.common.state import get_state
from encord_active.app.common.state_hooks import use_state
from encord_active.app.common.utils import set_page_config, setup_page
from encord_active.lib.coco.encoder import generate_coco_file
from encord_active.lib.common.utils import ProjectNotFound
Expand Down Expand Up @@ -101,14 +102,18 @@ def filter_dataframe(df: pd.DataFrame) -> pd.DataFrame:


def export_filter():
get_filtered_row_count, set_filtered_row_count = use_state(0)
get_clone_button, set_clone_button = use_state(False)

setup_page()
message_placeholder = st.empty()

st.header("Filter & Export")

filtered_df = filter_dataframe(st.session_state[state.MERGED_DATAFRAME].copy())
filtered_df = filter_dataframe(get_state().merged_metrics.copy())
filtered_df.reset_index(inplace=True)
st.markdown(f"**Total row:** {filtered_df.shape[0]}")
row_count = filtered_df.shape[0]
st.markdown(f"**Total row:** {row_count}")
st.dataframe(filtered_df, use_container_width=True)

action_columns = st.columns((3, 3, 2, 2, 2, 2, 2))
Expand All @@ -119,7 +124,7 @@ def export_filter():

with st.spinner(text="Generating COCO file"):
coco_json = (
generate_coco_file(filtered_df, st.session_state.project_dir, st.session_state.ontology_file)
generate_coco_file(filtered_df, get_state().project_paths.project_dir, get_state().project_paths.ontology)
if is_pressed
else ""
)
Expand All @@ -132,11 +137,10 @@ def export_filter():
help="Ensure you have generated an updated COCO file before downloading",
)

def _clone_button_pressed():
st.session_state[state.ACTION_PAGE_CLONE_BUTTON] = True

action_columns[3].button(
"🏗 Clone", on_click=_clone_button_pressed, help="Clone the filtered data into a new Encord dataset and project"
"🏗 Clone",
on_click=lambda: set_clone_button(True),
help="Clone the filtered data into a new Encord dataset and project",
)
delete_btn = action_columns[4].button("❌ Review", help="Assign the filtered data for review on the Encord platform")
edit_btn = action_columns[5].button(
Expand All @@ -145,7 +149,7 @@ def _clone_button_pressed():
augment_btn = action_columns[6].button("➕ Augment", help="Augment your dataset based on the filered data")

if any([delete_btn, edit_btn, augment_btn]):
st.session_state[state.ACTION_PAGE_CLONE_BUTTON] = False
set_clone_button(False)
message_placeholder.markdown(
"""
<div class="encord-active-info-box">
Expand All @@ -159,14 +163,12 @@ def _clone_button_pressed():
unsafe_allow_html=True,
)

if (
state.ACTION_PAGE_PREVIOUS_FILTERED_NUM not in st.session_state
or st.session_state[state.ACTION_PAGE_PREVIOUS_FILTERED_NUM] != filtered_df.shape[0]
):
st.session_state[state.ACTION_PAGE_PREVIOUS_FILTERED_NUM] = filtered_df.shape[0]
st.session_state[state.ACTION_PAGE_CLONE_BUTTON] = False
prev_row_count = get_filtered_row_count()
if prev_row_count != row_count:
set_filtered_row_count(row_count)
set_clone_button(False)

if st.session_state[state.ACTION_PAGE_CLONE_BUTTON]:
if get_clone_button():
with st.form("new_project_form"):
st.subheader("Create a new project with the selected items")
l_column, r_column = st.columns(2)
Expand Down
1 change: 0 additions & 1 deletion src/encord_active/app/common/components/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,2 @@
from .data_tags import build_data_tags
from .multi_select import multiselect_with_all_option
from .sticky.sticky import sticky_header
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ def render_annotator_properties(df: DataFrame[MetricSchema]):
annotators_df = pd.DataFrame(annotators.values())

fig = px.pie(annotators_df, values="total_annotations", names="name", hover_data=["mean_score"])
fig.update_layout(title_text="Distribution of the annotations", title_x=0.5, title_font_size=20)

left_col.plotly_chart(fig, use_container_width=True)

Expand Down
4 changes: 2 additions & 2 deletions src/encord_active/app/common/components/data_tags.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import List
from typing import List, Optional

import pandas as pd
import streamlit as st
Expand Down Expand Up @@ -61,7 +61,7 @@ def get_icon_color(self, value):
</div>"""


def build_data_tags(row: pd.Series, metric_name: str):
def build_data_tags(row: pd.Series, metric_name: Optional[str] = None):
tag_list = []
for p in properties:
key = p.get_key(metric_name)
Expand Down
8 changes: 4 additions & 4 deletions src/encord_active/app/common/components/metric_summary.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@
import streamlit as st
from pandera.typing import DataFrame

import encord_active.app.common.state as state
from encord_active.app.common.components import build_data_tags
from encord_active.app.common.components.tags.individual_tagging import multiselect_tag
from encord_active.app.common.state import get_state
from encord_active.lib.common.image_utils import show_image_and_draw_polygons
from encord_active.lib.dataset.outliers import IqrOutliers, MetricWithDistanceSchema
from encord_active.lib.metrics.utils import MetricData, MetricScope
Expand All @@ -17,8 +17,8 @@
def render_metric_summary(
metric: MetricData, df: DataFrame[MetricWithDistanceSchema], iqr_outliers: IqrOutliers, metric_scope: MetricScope
):
n_cols = int(st.session_state[state.MAIN_VIEW_COLUMN_NUM])
n_rows = int(st.session_state[state.MAIN_VIEW_ROW_NUM])
n_cols = get_state().page_grid_settings.columns
n_rows = get_state().page_grid_settings.rows
page_size = n_cols * n_rows

st.markdown(metric.meta["long_description"])
Expand Down Expand Up @@ -53,7 +53,7 @@ def render_metric_summary(


def render_summary_item(row, metric_name: str, iqr_outliers: IqrOutliers, metric_scope: MetricScope):
image = show_image_and_draw_polygons(row, st.session_state.data_dir)
image = show_image_and_draw_polygons(row, get_state().project_paths.data)
st.image(image)

multiselect_tag(row, f"{metric_name}_summary", metric_scope)
Expand Down
39 changes: 0 additions & 39 deletions src/encord_active/app/common/components/multi_select.py

This file was deleted.

18 changes: 8 additions & 10 deletions src/encord_active/app/common/components/prediction_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
bulk_tagging_form,
)
from encord_active.app.common.components.tags.individual_tagging import multiselect_tag
from encord_active.app.common.state import get_state
from encord_active.lib.common.colors import Color
from encord_active.lib.common.image_utils import (
draw_object,
Expand All @@ -34,18 +35,18 @@ def build_card_for_labels(
data_dir: Path,
label_color: Color = Color.RED,
):
class_colors = {int(k): idx["color"] for k, idx in st.session_state.full_class_idx.items()}
class_colors = {int(k): idx["color"] for k, idx in get_state().predictions.all_classes.items()}
image = show_image_with_predictions_and_label(
label, predictions, data_dir, label_color=label_color, class_colors=class_colors
)
st.image(image)
multiselect_tag(label, "false_negatives", MetricScope.MODEL_QUALITY)

cls = st.session_state.full_class_idx[str(label["class_id"])]["name"]
cls = get_state().predictions.all_classes[str(label["class_id"])]["name"]
label = label.copy()
label["label_class_name"] = cls
# === Write scores and link to editor === #
build_data_tags(label, st.session_state[state.PREDICTIONS_LABEL_METRIC])
build_data_tags(label, get_state().predictions.metric_datas.selected_label)


def build_card_for_predictions(row: pd.Series, data_dir: Path, box_color=Color.GREEN):
Expand All @@ -55,7 +56,7 @@ def build_card_for_predictions(row: pd.Series, data_dir: Path, box_color=Color.G
multiselect_tag(row, "metric_view", MetricScope.MODEL_QUALITY)

# === Write scores and link to editor === #
build_data_tags(row, st.session_state.predictions_metric)
build_data_tags(row, get_state().predictions.metric_datas.selected_predicion)

if row[PredictionMatchSchema.false_positive_reason] and not row[PredictionMatchSchema.is_true_positive]:
st.write(f"Reason: {row[PredictionMatchSchema.false_positive_reason]}")
Expand Down Expand Up @@ -83,16 +84,13 @@ def prediction_grid(
if use_labels:
df = labels
additionals = model_predictions
selected_metric = st.session_state.get(state.PREDICTIONS_LABEL_METRIC, "")
selected_metric = get_state().predictions.metric_datas.selected_label or ""
else:
df = model_predictions
additionals = None
selected_metric = st.session_state.get(state.PREDICTIONS_METRIC, "")
selected_metric = get_state().predictions.metric_datas.selected_predicion or ""

if state.METRIC_VIEW_PAGE_NUMBER not in st.session_state:
st.session_state[state.METRIC_VIEW_PAGE_NUMBER] = 1

n_cols, n_rows = int(st.session_state[state.MAIN_VIEW_COLUMN_NUM]), int(st.session_state[state.MAIN_VIEW_ROW_NUM])
n_cols, n_rows = get_state().page_grid_settings.columns, get_state().page_grid_settings.rows
subset = render_df_slicer(df, selected_metric)
paginated_subset = render_pagination(subset, n_cols, n_rows, selected_metric)

Expand Down
Loading

0 comments on commit 9d744be

Please sign in to comment.