Skip to content

Commit

Permalink
Merge pull request #34 from smart-on-fhir/mikix/info
Browse files Browse the repository at this point in the history
feat: add `info` subcommand to show computed ranges & labels
  • Loading branch information
mikix authored Jun 5, 2024
2 parents 5f4fd9e + f575512 commit 41fcc7d
Show file tree
Hide file tree
Showing 11 changed files with 230 additions and 36 deletions.
2 changes: 1 addition & 1 deletion chart_review/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Chart Review public entry point"""

__version__ = "1.1.0"
__version__ = "1.2.0"
21 changes: 21 additions & 0 deletions chart_review/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from chart_review import cohort, config
from chart_review.commands.accuracy import accuracy
from chart_review.commands.info import info


###############################################################################
Expand Down Expand Up @@ -35,6 +36,7 @@ def define_parser() -> argparse.ArgumentParser:
subparsers = parser.add_subparsers(required=True)

add_accuracy_subparser(subparsers)
add_info_subparser(subparsers)

return parser

Expand All @@ -61,6 +63,25 @@ def run_accuracy(args: argparse.Namespace) -> None:
accuracy(reader, args.truth_annotator, args.annotator, save=args.save)


###############################################################################
#
# Info
#
###############################################################################


def add_info_subparser(subparsers) -> None:
parser = subparsers.add_parser("info")
add_project_args(parser)
parser.set_defaults(func=run_info)


def run_info(args: argparse.Namespace) -> None:
proj_config = config.ProjectConfig(args.project_dir, config_path=args.config)
reader = cohort.CohortReader(proj_config)
info(reader)


###############################################################################
#
# Main CLI entrypoints
Expand Down
46 changes: 29 additions & 17 deletions chart_review/cohort.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,31 +32,43 @@ def __init__(self, proj_config: config.ProjectConfig):
for name, value in self.config.external_annotations.items():
external.merge_external(self.annotations, saved, self.project_dir, name, value)

# Parse ignored IDs (might be note IDs, might be external IDs)
self.ignored_notes: set[int] = set()
for ignore_id in self.config.ignore:
ls_id = external.external_id_to_label_studio_id(saved, str(ignore_id))
if ls_id is None:
if isinstance(ignore_id, int):
ls_id = ignore_id # must be direct note ID
else:
# Must just be over-zealous excluding (like automatically from SQL)
continue
self.ignored_notes.add(ls_id)

# Consolidate/expand mentions based on config
simplify.simplify_mentions(
self.annotations,
implied_labels=self.config.implied_labels,
grouped_labels=self.config.grouped_labels,
)

# Calculate the final set of note ranges for each annotator
self.note_range = self._collect_note_ranges(saved)

def _collect_note_ranges(self, exported_json: list[dict]) -> dict[str, set[int]]:
# Detect note ranges if they were not defined in the project config
# (i.e. default to the full set of annotated notes)
self.note_range = self.config.note_ranges
note_ranges = {k: set(v) for k, v in self.config.note_ranges.items()}
for annotator, annotator_mentions in self.annotations.mentions.items():
if annotator not in self.note_range:
self.note_range[annotator] = sorted(annotator_mentions.keys())
if annotator not in note_ranges:
note_ranges[annotator] = set(annotator_mentions.keys())

# Parse ignored IDs (might be note IDs, might be external IDs)
ignored_notes: set[int] = set()
for ignore_id in self.config.ignore:
ls_id = external.external_id_to_label_studio_id(exported_json, str(ignore_id))
if ls_id is None:
if isinstance(ignore_id, int):
ls_id = ignore_id # must be direct note ID
else:
# Must just be over-zealous excluding (like automatically from SQL)
continue
ignored_notes.add(ls_id)

# Remove any invalid (ignored, non-existent) notes from the range sets
all_ls_notes = {int(entry["id"]) for entry in exported_json if "id" in entry}
for note_ids in note_ranges.values():
note_ids.difference_update(ignored_notes)
note_ids.intersection_update(all_ls_notes)

return note_ranges

@property
def class_labels(self):
Expand Down Expand Up @@ -103,7 +115,7 @@ def confusion_matrix(
:return: dict
"""
labels = self._select_labels(label_pick)
note_range = set(guard_iter(note_range)) - self.ignored_notes
note_range = set(guard_iter(note_range))
return agree.confusion_matrix(
self.annotations,
truth,
Expand All @@ -122,7 +134,7 @@ def score_reviewer(self, truth: str, annotator: str, note_range, label_pick: str
:return: dict, keys f1, precision, recall and vals= %score
"""
labels = self._select_labels(label_pick)
note_range = set(guard_iter(note_range)) - self.ignored_notes
note_range = set(guard_iter(note_range))
return agree.score_reviewer(self.annotations, truth, annotator, note_range, labels=labels)

def score_reviewer_table_csv(self, truth: str, annotator: str, note_range) -> str:
Expand Down
64 changes: 64 additions & 0 deletions chart_review/commands/info.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
"""Methods for showing config & calculated setup info."""

import rich
import rich.box
import rich.table

from chart_review import cohort


def info(reader: cohort.CohortReader) -> None:
"""
Show project information on the console.
:param reader: the cohort configuration
"""
console = rich.get_console()

# Charts
chart_table = rich.table.Table(
"Annotator",
"Chart Count",
"Chart IDs",
box=rich.box.ROUNDED,
pad_edge=False,
title="Annotations:",
title_justify="left",
title_style="bold",
)
for annotator in sorted(reader.note_range):
notes = reader.note_range[annotator]
chart_table.add_row(annotator, str(len(notes)), pretty_note_range(notes))
console.print(chart_table)
console.print()

# Labels
console.print("Labels:", style="bold")
if reader.class_labels:
console.print(", ".join(sorted(reader.class_labels, key=str.casefold)))
else:
console.print("None", style="italic", highlight=False)


def pretty_note_range(notes: set[int]) -> str:
ranges = []
range_start = None
prev_note = None

def end_range() -> None:
if prev_note is None:
return
if range_start == prev_note:
ranges.append(str(prev_note))
else:
ranges.append(f"{range_start}{prev_note}") # en dash

for note in sorted(notes):
if prev_note is None or prev_note + 1 != note:
end_range()
range_start = note
prev_note = note

end_range()

return ", ".join(ranges)
2 changes: 1 addition & 1 deletion chart_review/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def read_json(path: str) -> Union[dict, list[dict]]:
return json.load(f, strict=False)


def write_json(path: str, data: dict, indent: Optional[int] = 4) -> None:
def write_json(path: str, data: dict | list, indent: Optional[int] = 4) -> None:
"""
Writes data to the given path, in json format
:param path: filesystem path
Expand Down
2 changes: 1 addition & 1 deletion chart_review/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def __init__(self, project_dir: str, config_path: Optional[str] = None):

# ** Note ranges **
# Handle some extra syntax like 1-3 == [1, 2, 3]
self.note_ranges = self._data.get("ranges", {})
self.note_ranges: dict[str, list[int]] = self._data.get("ranges", {})
for key, values in self.note_ranges.items():
self.note_ranges[key] = list(self._parse_note_range(values))

Expand Down
4 changes: 2 additions & 2 deletions chart_review/simplify.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,15 @@ def simplify_export(
for entry in exported_json:
note_id = int(entry.get("id"))

for annot in entry.get("annotations"):
for annot in entry.get("annotations", []):
completed_by = annot.get("completed_by")
if completed_by not in proj_config.annotators:
continue # we don't know who this is!

# Grab all valid mentions for this annotator & note
labels = types.LabelSet()
text_tags = []
for result in annot.get("result"):
for result in annot.get("result", []):
result_value = result.get("value", {})
result_text = result_value.get("text")
result_labels = set(result_value.get("labels", []))
Expand Down
42 changes: 42 additions & 0 deletions docs/info.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
---
title: Info Command
parent: Chart Review
nav_order: 6
# audience: lightly technical folks
# type: how-to
---

# The Info Command

The `info` command will print information about your current project.

This is helpful to examine the computed list of chart ID ranges or labels.

## Example

```shell
$ chart-review info
Annotations:
╭──────────┬─────────────┬──────────╮
│Annotator │ Chart Count │ Chart IDs│
├──────────┼─────────────┼──────────┤
│jane │ 3 │ 1, 3–4 │
│jill │ 4 │ 1–4 │
│john │ 3 │ 1–2, 4 │
╰──────────┴─────────────┴──────────╯

Labels:
Cough, Fatigue, Headache
```

## Options

### `--config=PATH`

Use this to point to a secondary (non-default) config file.
Useful if you have multiple label setups (e.g. one grouped into a binary label and one not).

### `--project-dir=DIR`

Use this to run `chart-review` outside of your project dir.
Config files, external annotations, etc will be looked for in that directory.
32 changes: 21 additions & 11 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Tests for cli.py"""

import contextlib
import io
import os
import shutil
import tempfile
Expand Down Expand Up @@ -86,18 +88,26 @@ def test_accuracy(self):
accuracy_csv,
)

def test_ignored_ids(self):
with tempfile.TemporaryDirectory() as tmpdir:
shutil.copytree(f"{DATA_DIR}/ignore", tmpdir, dirs_exist_ok=True)
cli.main_cli(["accuracy", "--project-dir", tmpdir, "--save", "allison", "adam"])
def test_info(self):
stdout = io.StringIO()
with contextlib.redirect_stdout(stdout):
cli.main_cli(["info", "--project-dir", f"{DATA_DIR}/cold"])

self.assertEqual(
"""Annotations:
╭──────────┬─────────────┬──────────╮
│Annotator │ Chart Count │ Chart IDs│
├──────────┼─────────────┼──────────┤
│jane │ 3 │ 1, 3–4 │
│jill │ 4 │ 1–4 │
│john │ 3 │ 1–2, 4 │
╰──────────┴─────────────┴──────────╯
# Only two of the five notes should be considered, and we should have full agreement.
accuracy_json = common.read_json(f"{tmpdir}/accuracy-allison-adam.json")
self.assertEqual(1, accuracy_json["F1"])
self.assertEqual(2, accuracy_json["TP"])
self.assertEqual(0, accuracy_json["FN"])
self.assertEqual(2, accuracy_json["TN"])
self.assertEqual(0, accuracy_json["FP"])
Labels:
Cough, Fatigue, Headache
""", # noqa: W291
stdout.getvalue(),
)

def test_custom_config(self):
with tempfile.TemporaryDirectory() as tmpdir:
Expand Down
45 changes: 45 additions & 0 deletions tests/test_cohort.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
"""Tests for cohort.py"""

import os
import tempfile
import unittest

from chart_review import cohort, common, config

DATA_DIR = os.path.join(os.path.dirname(__file__), "data")


class TestCohort(unittest.TestCase):
"""Test case for basic cohort management"""

def setUp(self):
super().setUp()
self.maxDiff = None

def test_ignored_ids(self):
reader = cohort.CohortReader(config.ProjectConfig(f"{DATA_DIR}/ignore"))

# Confirm 3, 4, and 5 got ignored
self.assertEqual(
{
"adam": {1, 2},
"allison": {1, 2},
},
reader.note_range,
)

def test_non_existent_ids(self):
with tempfile.TemporaryDirectory() as tmpdir:
common.write_json(
f"{tmpdir}/config.json", {"annotators": {"bob": 1}, "ranges": {"bob": ["1-5"]}}
)
common.write_json(
f"{tmpdir}/labelstudio-export.json",
[
{"id": 1, "annotations": [{"completed_by": 1}]}, # done by bob
{"id": 3}, # not done by bob, but we are explicitly told it was
],
)
reader = cohort.CohortReader(config.ProjectConfig(tmpdir))

self.assertEqual({"bob": {1, 3}}, reader.note_range)
6 changes: 3 additions & 3 deletions tests/test_external.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,9 @@ def test_basic_read(self):
# Confirm ranges got auto-detected for both human and icd10
self.assertEqual(
{
"human": [1, 2, 3],
"icd10-doc": [1, 3],
"icd10-enc": [1, 3],
"human": {1, 2, 3},
"icd10-doc": {1, 3},
"icd10-enc": {1, 3},
},
reader.note_range,
)

0 comments on commit 41fcc7d

Please sign in to comment.