feat: add new frequency command to show term frequencies
This shows a count of each text/label combination, with a combined
total across all annotators. It also flags when the same text has
been given more than one label.
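
As a rough illustration of the counting approach (a minimal standalone
sketch with made-up data, not the project's API - the real implementation
is in chart_review/commands/frequency.py below):

from collections import defaultdict

# Made-up (annotator, label, text) mention tuples, already lower-cased:
mentions = [
    ("alice", "Cough", "cough"),
    ("alice", "Cough", "coughing"),
    ("bob", "Fever", "cough"),  # same text as above, different label
]

per_annotator = defaultdict(lambda: defaultdict(lambda: defaultdict(int)))
combined = defaultdict(lambda: defaultdict(int))  # label -> text -> count
text_labels = defaultdict(set)  # text -> set of labels seen anywhere

for annotator, label, text in mentions:
    per_annotator[annotator][label][text] += 1  # per-annotator count
    combined[label][text] += 1  # running total across all annotators
    text_labels[text].add(label)

# Texts that picked up more than one label get flagged as "term confusion":
confused = {text for text, labels in text_labels.items() if len(labels) > 1}
print(confused)  # {'cough'}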

Also:
- Delete some unused code, notably the old term_freq module.
- For the `mentions` command, don't show duplicate mentions - they're
  just noise.
- For the `mentions` command, sort output by text rather than by the
  order in which the text appears in the document, which makes less
  sense now that duplicates are removed.
mikix committed Jun 26, 2024
1 parent a8d9848 commit 766967e
Showing 17 changed files with 465 additions and 265 deletions.
51 changes: 0 additions & 51 deletions chart_review/agree.py
@@ -151,38 +151,6 @@ def score_matrix(matrix: dict, sig_digits=3) -> dict:
}


def avg_scores(first: dict, second: dict, sig_digits=3) -> dict:
merged = {}
for header in csv_header():
added = first[header] + second[header]
if header in ["TP", "FP", "FN", "TN"]:
merged[header] = added
else:
merged[header] = round(added / 2, sig_digits)
return merged


def score_reviewer(
annotations: types.ProjectAnnotations,
truth: str,
annotator: str,
note_range: Collection[int],
labels: Iterable[str] = None,
) -> dict:
"""
Score reliability of an annotator against a truth annotator.
:param annotations: prepared map of annotators and mentions
:param truth: annotator to use as the ground truth
:param annotator: another annotator to compare with truth
:param note_range: collection of LabelStudio document ID
:param labels: (optional) set of labels to score
:return: dict, keys f1, precision, recall and vals= %score
"""
truth_matrix = confusion_matrix(annotations, truth, annotator, note_range, labels=labels)
return score_matrix(truth_matrix)


def csv_table(score: dict, class_labels: types.LabelSet):
table = list()
table.append(csv_header(False, True))
@@ -229,22 +197,3 @@ def csv_row_score(

row.append(pick_label if pick_label else "*")
return "\t".join(row)


def true_prevalence(prevalence_apparent: float, sensitivity: float, specificity: float):
"""
See paper: "The apparent prevalence, the true prevalence"
https://www.ncbi.nlm.nih.gov/pmc/articles/PMC9195606
Using Eq. 4. it can be calculated:
True prevalence = (Apparent prevalence + Sp - 1)/(Se + Sp - 1)
:param prevalence_apparent: estimated prevalence, concretely:
    the %NLP labeled positives / cohort
:param: sensitivity: of the class label (where prevalence was measured)
:param: specificity: of the class label (where prevalence was measured)
:return: float adjusted prevalence
"""
return round((prevalence_apparent + specificity - 1) / (sensitivity + specificity - 1), 5)
5 changes: 4 additions & 1 deletion chart_review/cli.py
@@ -3,7 +3,7 @@
import argparse
import sys

from chart_review.commands import accuracy, default, ids, labels, mentions
from chart_review.commands import accuracy, default, frequency, ids, labels, mentions


def define_parser() -> argparse.ArgumentParser:
@@ -13,6 +13,9 @@ def define_parser() -> argparse.ArgumentParser:

subparsers = parser.add_subparsers()
accuracy.make_subparser(subparsers.add_parser("accuracy", help="calculate F1 and Kappa scores"))
frequency.make_subparser(
subparsers.add_parser("frequency", help="show counts of each text mention")
)
ids.make_subparser(subparsers.add_parser("ids", help="map Label Studio IDs to FHIR IDs"))
labels.make_subparser(subparsers.add_parser("labels", help="show label usage by annotator"))
mentions.make_subparser(subparsers.add_parser("mentions", help="show each mention of a label"))
57 changes: 2 additions & 55 deletions chart_review/cohort.py
@@ -1,7 +1,7 @@
from typing import Iterable

from chart_review.common import guard_str, guard_iter, guard_in
from chart_review import agree, common, config, errors, external, term_freq, simplify, types
from chart_review.common import guard_iter, guard_in
from chart_review import agree, common, config, errors, external, simplify, types


class CohortReader:
@@ -84,25 +84,6 @@ def _collect_note_ranges(
def class_labels(self):
return self.annotations.labels

def calc_term_freq(self, annotator) -> dict:
"""
Calculate Term Frequency of highlighted mentions.
:param annotator: an annotator name
:return: dict key=TERM val= {label, list of chart_id}
"""
return term_freq.calc_term_freq(self.annotations, guard_str(annotator))

def calc_label_freq(self, annotator) -> dict:
"""
Calculate Term Frequency of highlighted mentions.
:param annotator: an annotator name
:return: dict key=TERM val= {label, list of chart_id}
"""
return term_freq.calc_label_freq(self.calc_term_freq(annotator))

def calc_term_label_confusion(self, annotator) -> dict:
return term_freq.calc_term_label_confusion(self.calc_term_freq(annotator))

def _select_labels(self, label_pick: str = None) -> Iterable[str]:
if label_pick:
guard_in(label_pick, self.class_labels)
@@ -131,37 +112,3 @@ def confusion_matrix(
note_range,
labels=labels,
)

def score_reviewer(self, truth: str, annotator: str, note_range, label_pick: str = None):
"""
Score reliability of rater at the level of all symptom *PREVALENCE*
:param truth: annotator to use as the ground truth
:param annotator: another annotator to compare with truth
:param note_range: default= all in corpus
:param label_pick: (optional) of the CLASS_LABEL to score separately
:return: dict, keys f1, precision, recall and vals= %score
"""
labels = self._select_labels(label_pick)
note_range = set(guard_iter(note_range))
return agree.score_reviewer(self.annotations, truth, annotator, note_range, labels=labels)

def score_reviewer_table_csv(self, truth: str, annotator: str, note_range) -> str:
table = list()
table.append(agree.csv_header(False, True))

score = self.score_reviewer(truth, annotator, note_range)
table.append(agree.csv_row_score(score, as_string=True))

for label in self.class_labels:
score = self.score_reviewer(truth, annotator, note_range, label)
table.append(agree.csv_row_score(score, label, as_string=True))

return "\n".join(table) + "\n"

def score_reviewer_table_dict(self, truth, annotator, note_range) -> dict:
table = self.score_reviewer(truth, annotator, note_range)

for label in self.class_labels:
table[label] = self.score_reviewer(truth, annotator, note_range, label)

return table
2 changes: 1 addition & 1 deletion chart_review/commands/default.py
@@ -26,7 +26,7 @@ def print_info(args: argparse.Namespace) -> None:
notes = reader.note_range[annotator]
chart_table.add_row(
annotator,
str(len(notes)),
f"{len(notes):,}",
console_utils.pretty_note_range(notes),
)

77 changes: 77 additions & 0 deletions chart_review/commands/frequency.py
@@ -0,0 +1,77 @@
import argparse

import rich
import rich.box
import rich.table
import rich.text

from chart_review import cli_utils, console_utils, types


def make_subparser(parser: argparse.ArgumentParser) -> None:
cli_utils.add_project_args(parser)
cli_utils.add_output_args(parser)
parser.set_defaults(func=print_frequency)


def print_frequency(args: argparse.Namespace) -> None:
"""
Print counts of each text mention.
"""
reader = cli_utils.get_cohort_reader(args)

frequencies = {} # annotator -> label -> text -> count
all_annotator_frequencies = {} # label -> text -> count
text_labels = {} # text -> labelset (to flag term confusion)
for annotator in reader.annotations.original_text_mentions:
annotator_mentions = reader.annotations.original_text_mentions[annotator]
for labeled_texts in annotator_mentions.values():
for labeled_text in labeled_texts:
text = (labeled_text.text or "").strip().casefold()
for label in labeled_text.labels:
if label in reader.annotations.labels:
# Count the mention for this annotator
label_to_text = frequencies.setdefault(annotator, {})
text_to_count = label_to_text.setdefault(label, {})
text_to_count[text] = text_to_count.get(text, 0) + 1

# Count the mention for our running all-annotators total
all_text_to_count = all_annotator_frequencies.setdefault(label, {})
all_text_to_count[text] = all_text_to_count.get(text, 0) + 1

# And finally, add it to our running term-confusion tracker
text_labels.setdefault(text, types.LabelSet()).add(label)

# Now group up the data into a formatted table
table = cli_utils.create_table("Annotator", "Label", "Mention", "Count")
has_term_confusion = False # whether multiple labels are used for the same text

# Helper method to add all the info for a single annotator to our table
def add_annotator_to_table(name, label_to_text: dict) -> None:
nonlocal has_term_confusion
table.add_section()
for label in sorted(label_to_text, key=str.casefold):
text_to_count = label_to_text[label]
for text, count in sorted(
text_to_count.items(), key=lambda t: (t[1], t[0]), reverse=True
):
is_confused = not args.csv and text and len(text_labels[text]) > 1
if is_confused:
text = rich.text.Text(text + "*", style="bold")
has_term_confusion = True
table.add_row(name, label, text, f"{count:,}")

# Add each annotator
add_annotator_to_table(rich.text.Text("All", style="italic"), all_annotator_frequencies)
for annotator in sorted(frequencies, key=str.casefold):
add_annotator_to_table(annotator, frequencies[annotator])

if args.csv:
cli_utils.print_table_as_csv(table)
else:
rich.get_console().print(table)
console_utils.print_ignored_charts(reader)
if has_term_confusion:
rich.get_console().print(
f" * This text has multiple associated labels.", style="italic"
)
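
A quick way to exercise the new subcommand from Python - a rough sketch
that assumes the `--csv` flag comes from `add_output_args()` and that the
project-directory arguments added by `add_project_args()` have usable
defaults:

from chart_review import cli

parser = cli.define_parser()
args = parser.parse_args(["frequency", "--csv"])  # "--csv" is an assumption here
args.func(args)  # dispatches to print_frequency() via set_defaults() above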
2 changes: 1 addition & 1 deletion chart_review/commands/labels.py
@@ -33,7 +33,7 @@ def print_labels(args: argparse.Namespace) -> None:

# First add summary entries, for counts across the union of all annotators
for name in label_names:
count = str(len(any_annotator_note_sets.get(name, {})))
count = f"{len(any_annotator_note_sets.get(name, {})):,}"
label_table.add_row(rich.text.Text("Any", style="italic"), name, count)

# Now do each annotator as their own little boxed section
19 changes: 13 additions & 6 deletions chart_review/commands/mentions.py
@@ -5,7 +5,7 @@
import rich.table
import rich.text

from chart_review import cli_utils, console_utils, types
from chart_review import cli_utils, console_utils


def make_subparser(parser: argparse.ArgumentParser) -> None:
@@ -24,12 +24,19 @@ def print_mentions(args: argparse.Namespace) -> None:

for annotator in sorted(reader.annotations.original_text_mentions, key=str.casefold):
table.add_section()
mentions = reader.annotations.original_text_mentions[annotator]
for note_id, labeled_texts in mentions.items():
for label_text in labeled_texts:
for label in sorted(label_text.labels, key=str.casefold):
annotator_mentions = reader.annotations.original_text_mentions[annotator]
for note_id, labeled_texts in annotator_mentions.items():
# Gather all combos of text/label (i.e. all mentions) in this note
note_mentions = set()
for labeled_text in labeled_texts:
text = labeled_text.text and labeled_text.text.casefold()
for label in labeled_text.labels:
if label in reader.annotations.labels:
table.add_row(annotator, str(note_id), label_text.text, label)
note_mentions.add((text, label))

# Now add each mention to the table
for note_mention in sorted(note_mentions, key=lambda m: (m[0], m[1].casefold())):
table.add_row(annotator, str(note_id), note_mention[0], note_mention[1])

if args.csv:
cli_utils.print_table_as_csv(table)
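In isolation, the dedup-and-sort behavior added above looks roughly like
this (a small sketch with made-up data, not a call into the project):

# Mentions gathered for one note, including an exact duplicate:
labeled = [("chest pain", "Pain"), ("chest pain", "Pain"), ("cough", "Cough")]
note_mentions = set(labeled)  # exact duplicates collapse to one entry
for text, label in sorted(note_mentions, key=lambda m: (m[0], m[1].casefold())):
    print(text, label)
# chest pain Pain
# cough Cough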
11 changes: 0 additions & 11 deletions chart_review/common.py
@@ -104,17 +104,6 @@ def print_line(heading=None) -> None:
###############################################################################
# Helper Functions: enum type smoothing
###############################################################################
def guard_str(object) -> str:
if isinstance(object, Enum):
return str(object.name)
elif isinstance(object, EnumMeta):
return str(object.name)
elif isinstance(object, str):
return object
else:
raise Exception(f"expected str|Enum but got {type(object)}")


def guard_iter(object) -> Iterable:
if isinstance(object, Enum):
return guard_iter(object.value)