Skip to content

Commit

Permalink
Merge pull request #1008 from PrimozGodec/ldavis-fix
Browse files Browse the repository at this point in the history
[FIX] LDAVis - Fix fail after None data
  • Loading branch information
PrimozGodec authored Sep 21, 2023
2 parents 96a9548 + 6994a5b commit b225ac7
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 6 deletions.
15 changes: 9 additions & 6 deletions orangecontrib/text/widgets/owldavis.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Optional, Tuple, Set

import numpy as np
import pyqtgraph as pg
Expand Down Expand Up @@ -230,9 +230,9 @@ class OWLDAvis(OWWidget):
icon = "icons/LDAvis.svg"
keywords = "ldavis"

selected_topic = Setting(0, schema_only=True)
relevance = Setting(0.5)
visual_settings = Setting({}, schema_only=True)
selected_topic: int = Setting(0, schema_only=True)
relevance: float = Setting(0.5)
visual_settings: Set = Setting({}, schema_only=True)

graph = SettingProvider(BarPlotGraph)
graph_name = "graph.plotItem"
Expand Down Expand Up @@ -329,7 +329,6 @@ def on_params_change(self):

@Inputs.topics
def set_data(self, data: Optional[Topics]):
prev_topic = self.selected_topic
self.clear()
if data is None:
return
Expand All @@ -343,7 +342,8 @@ def set_data(self, data: Optional[Topics]):
self.term_topic_matrix = self.compute_distributions(data)
self.term_frequency = np.sum(self.term_topic_matrix, axis=0)

self.selected_topic = prev_topic if prev_topic < len(self.topic_list) else 0
st = self.selected_topic
self.selected_topic = st if st < len(self.topic_list) else 0
self.on_params_change()

def set_visual_settings(self, key: KeyType, value: ValueType):
Expand All @@ -354,7 +354,10 @@ def clear(self):
self.Error.clear()
self.graph.clear_all()
self.data = None
prev_topic = self.selected_topic
self.topic_list = []
# resting topic_list resets selected_topic to None - setting back to prev value
self.selected_topic = prev_topic
self.term_topic_matrix = None
self.term_frequency = None
self.num_tokens = None
Expand Down
27 changes: 27 additions & 0 deletions orangecontrib/text/widgets/tests/test_owldavis.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,33 @@ def test_report(self, mocked_plot, mocked_items: Mock):
mocked_items.assert_called_once()
mocked_plot. assert_called_once()

def test_wrong_model(self):
lsi_topic = self.topics.copy()
lsi_topic.attributes["Model"] = "Latent Sematic Indexing"
self.send_signal(self.widget.Inputs.topics, lsi_topic)
self.assertTrue(self.widget.Error.wrong_model.is_shown())

self.send_signal(self.widget.Inputs.topics, self.topics)
self.assertFalse(self.widget.Error.wrong_model.is_shown())
self.assertListEqual(
["Topic 1", "Topic 2", "Topic 3", "Topic 4", "Topic 5"],
self.widget.topic_list,
)
self.widget.topic_box.setCurrentRow(2)
self.assertEqual(2, self.widget.selected_topic)

self.send_signal(self.widget.Inputs.topics, lsi_topic)
self.assertTrue(self.widget.Error.wrong_model.is_shown())

self.send_signal(self.widget.Inputs.topics, self.topics)
self.assertFalse(self.widget.Error.wrong_model.is_shown())
self.assertListEqual(
["Topic 1", "Topic 2", "Topic 3", "Topic 4", "Topic 5"],
self.widget.topic_list,
)
# should be remembered from before
self.assertEqual(2, self.widget.selected_topic)


if __name__ == "__main__":
unittest.main()

0 comments on commit b225ac7

Please sign in to comment.