From b5fed434259e8d616b4f215b0202fe529ea6b8da Mon Sep 17 00:00:00 2001 From: Gonzalo Chiarlone Date: Thu, 10 Nov 2022 09:45:35 -0300 Subject: [PATCH] missing labels problem fixed --- vectory/visualization/main.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/vectory/visualization/main.py b/vectory/visualization/main.py index 6e04bf0..a62305c 100644 --- a/vectory/visualization/main.py +++ b/vectory/visualization/main.py @@ -4,6 +4,7 @@ from bokeh.plotting import figure from streamlit_bokeh_events import streamlit_bokeh_events from vectory.db.models import DatasetModel, EmbeddingSpaceModel +from vectory.es.utils import load_csv_with_headers from vectory.visualization.utils import ( calculate_indices, calculate_points, @@ -24,11 +25,11 @@ def cached_calculate_points(model, embeddings, rows): return calculate_points(model, embeddings, rows) -def selection(dataset): +def selection(dataset_query): embedding_spaces = [ emb_space.name for emb_space in EmbeddingSpaceModel.select().where( - EmbeddingSpaceModel.dataset == dataset + EmbeddingSpaceModel.dataset == dataset_query ) ] selected_emb_space = st.selectbox( @@ -41,7 +42,11 @@ def selection(dataset): index = None similarity = None selected_vector = None - headers = [] + + dataset = next(iter(dataset_query)) + _, _, headers = load_csv_with_headers(dataset.csv_path) + headers.remove("_idx") + headers.remove(str(dataset.id_field)) model = st.selectbox( "Choose a model to use for embedding kNN search", @@ -325,7 +330,8 @@ def main(): col1, col2 = st.columns(2) if dataset is not None: - dataset = DatasetModel.select().where(DatasetModel.name == dataset) + dataset_query = DatasetModel.select().where(DatasetModel.name == dataset) + with col1: ( dimensional_reduction_model_1, @@ -337,7 +343,7 @@ def main(): selected_vector_1, similarity_1, embedding_space_1_name, - ) = selection(dataset) + ) = selection(dataset_query) with col2: ( dimensional_reduction_model_2, @@ -349,7 +355,7 @@ def main(): selected_vector_2, similarity_2, embedding_space_2_name, - ) = selection(dataset) + ) = selection(dataset_query) if "submit" not in st.session_state: st.session_state.submit = False