casemaker_demo / utils_casemaker.py
ljyflores
Add cached files and update app
850fcc9
raw
history blame
9.41 kB
import json
import nltk
import pandas as pd
import re
nltk.download('punkt')
from dataclasses import dataclass
from nltk.tokenize import sent_tokenize
from typing import Dict, List, Sequence
from utils_report_parser import get_section_from_report
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
pipeline,
)
@dataclass
class Report:
patient_id: str | int
text: str
date: str
summary: str | None = None
def clean(s: str) -> str:
    """Normalise a raw report fragment into a single-line string.

    Newlines become spaces and underscores (often used for horizontal rules)
    are dropped. Any ``[...]`` span and then any ``(...)`` span is removed
    (shortest match). Finally, runs of whitespace collapse to a single space
    with no leading or trailing whitespace.
    """
    flattened = s.replace("\n", " ").replace("_", "")
    # Square brackets are stripped before parentheses, matching the original order.
    for pattern in (r"\[.*?\]", r"\(.*?\)"):
        flattened = re.sub(pattern, "", flattened)
    return " ".join(flattened.split())
def split_paragraphs(text: str) -> List[str]:
    """Split ``text`` on double newlines and keep the substantive paragraphs.

    Each chunk is normalised with ``clean``. Chunks that end up with 10 or
    fewer whitespace-separated words are discarded.
    """
    cleaned_chunks = (clean(chunk) for chunk in text.split("\n\n"))
    return [chunk for chunk in cleaned_chunks if len(chunk.split()) > 10]
def format_casemaker_data(
    df: pd.DataFrame, patient_id_column: str, text_column: str, date_column: str
) -> "Dict[str, List[Report]]":
    """Group a report-level dataframe into per-patient, date-ordered Report lists.

    Each input row is one report. The three named columns are renamed to
    ``patient_id`` / ``text`` / ``date``, and rows are sorted by patient then
    date. Each patient's rows become a list of ``Report`` objects.

    Args:
        df (pd.DataFrame): Input dataframe on report level.
        patient_id_column (str): Name of the patient ID column.
        text_column (str): Name of the report text column.
        date_column (str): Name of the date column (used to sort each patient's reports).

    Returns:
        Dict[str, List[Report]]: Mapping from the patient ID (converted with
        ``str``) to that patient's reports in ascending ``date`` order. Rows
        whose patient ID is missing are dropped by ``groupby``.
    """
    renamed = df.rename(
        columns={
            patient_id_column: "patient_id",
            text_column: "text",
            date_column: "date",
        }
    )
    ordered = renamed.sort_values(by=["patient_id", "date"])

    reports_by_patient: Dict[str, List[Report]] = {}
    # Iterate the groups directly rather than using groupby(...).apply(...).
    # apply() handing the grouping column to the function is deprecated in
    # pandas 2.2 and removed afterwards, which would make the "patient_id"
    # selection below raise KeyError.
    for patient_id, group in ordered.groupby("patient_id"):
        records = group[["patient_id", "text", "date"]].to_dict("records")
        reports_by_patient[str(patient_id)] = [Report(**record) for record in records]
    return reports_by_patient
class CaseMaker:
    """Split clinical reports by body organ and keep the clinically relevant sentences.

    Organ assignment is keyword based. It is driven by a JSON mapping from each
    organ name to a list of synonym keywords. Relevance filtering uses a
    token-classification (NER) pipeline.
    """

    def __init__(self, organ_keywords_dict_path: str = "../assets/terms.json"):
        """Load the organ keyword mapping and build the NER pipeline.

        Args:
            organ_keywords_dict_path (str): Path to a JSON file mapping each organ
                name to a list of keyword synonyms.
        """
        # Context manager so the file handle is closed deterministically.
        with open(organ_keywords_dict_path, "r", encoding="utf-8") as f:
            self.organ_keyword_dict = json.load(f)
        self.ner_pipe = pipeline(
            "ner",
            model=AutoModelForTokenClassification.from_pretrained(
                "d4data/biomedical-ner-all"
            ),
            tokenizer=AutoTokenizer.from_pretrained("d4data/biomedical-ner-all"),
            aggregation_strategy="simple",
            device_map="auto",
        )
        # The summarisation pipeline is not loaded by default. The original
        # setup used pipeline("text2text-generation", model="starmpcc/Asclepius-7B").
        # Assign a pipeline here before calling summarize_report.
        self.summ_pipe = None

    def standardize_organ(self, organ_entity: Dict) -> Dict:
        """Map an entity's name onto a known organ, or mark it as "Other".

        Matching is case-insensitive on both sides. It compares against each
        organ key and its synonyms in ``organ_keyword_dict``.

        Args:
            organ_entity (Dict): Entity dictionary; must contain a "word" key
                holding the entity text.

        Returns:
            Dict: The same dictionary, mutated in place. "word" is set to the
            matching organ key. If nothing matches, "word" is set to "Other"
            and "score" to 0.0.
        """
        word = organ_entity["word"].lower()
        for key, synonyms in self.organ_keyword_dict.items():
            # Lower-case the synonyms too, so mixed-case entries in the JSON
            # still match (consistent with pick_organ_by_keyword).
            if word == key.lower() or word in (syn.lower() for syn in synonyms):
                organ_entity["word"] = key
                return organ_entity
        # No match: treat it as a bad match and zero its score.
        organ_entity["word"] = "Other"
        organ_entity["score"] = 0.0
        return organ_entity

    def pick_organ_by_keyword(self, s: str) -> str:
        """Return the first organ whose name or synonym occurs as a substring of ``s``.

        Matching is case-insensitive and follows the key order of
        ``organ_keyword_dict``. Returns ``"other"`` (lower-case) when nothing matches.
        """
        text = s.lower()
        for organ, synonyms in self.organ_keyword_dict.items():
            if any(keyword.lower() in text for keyword in [organ, *synonyms]):
                return organ
        return "other"

    def parse_report_by_organ(self, report: str) -> Dict[str, str]:
        """Group a report's paragraphs by the organ each one mentions.

        Args:
            report (str): Input report text; paragraphs are separated by blank lines.

        Returns:
            Dict[str, str]: Organ name to the concatenation of all paragraphs
            assigned to it, joined with single spaces.
        """
        report_string_by_organ: Dict[str, str] = {}
        for paragraph in split_paragraphs(report):
            organ = self.pick_organ_by_keyword(paragraph)
            if organ in report_string_by_organ:
                # A space separator keeps sentence boundaries visible to
                # sent_tokenize; gluing "...normal.The" would merge sentences.
                report_string_by_organ[organ] += " " + paragraph
            else:
                report_string_by_organ[organ] = paragraph
        return report_string_by_organ

    def trim_to_relevant_portion(self, report: str) -> str:
        """Keep only sentences in which the NER pipeline finds a symptom or disease.

        Returns:
            str: The kept sentences joined by newlines ("" if none qualify).
        """
        relevant_groups = ("Sign_symptom", "Disease_disorder")
        relevant_sentences = [
            str(sentence)
            for sentence in sent_tokenize(report)
            if any(ent["entity_group"] in relevant_groups for ent in self.ner_pipe(sentence))
        ]
        return "\n".join(relevant_sentences)

    def summarize_report(self, text: str) -> str:
        """Format text into a prompt and summarize it with ``self.summ_pipe``.

        Args:
            text (str): Input report.

        Returns:
            str: The generated text after the "[Instruction End]" marker, cleaned.

        Raises:
            AttributeError: If no summarisation pipeline has been assigned to
                ``self.summ_pipe``.
        """
        summ_pipe = getattr(self, "summ_pipe", None)
        if summ_pipe is None:
            raise AttributeError(
                "CaseMaker.summ_pipe is not initialised; assign a text-generation "
                "pipeline to it before calling summarize_report"
            )
        question = (
            "Can you provide a succinct summary of the key clinical findings "
            "and treatment recommendations outlined in this discharge summary?"
        )
        prompt = """
You are an intelligent clinical languge model.
Below is a snippet of patient's discharge summary and a following instruction from healthcare professional.
Write a response that appropriately completes the instruction.
The response should provide the accurate answer to the instruction, while being concise.
[Discharge Summary Begin]
{note}
[Discharge Summary End]
[Instruction Begin]
{question}
[Instruction End]
""".format(
            question=question, note=text
        )
        # Budget about half the input length in new tokens, but never 0 for
        # very short inputs.
        max_new_tokens = max(1, len(text.split()) // 2)
        output = summ_pipe(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
        # The pipeline output may echo the prompt; keep only what follows it.
        answer = output.split("[Instruction End]")[-1]
        return clean(answer)

    def parse_records(self, reports: "Sequence[Report]") -> "Dict[str, List[Report]]":
        """Split reports by body organ and keep only the relevant sentences per organ.

        Each report is cut to its "findings" section and split by organ with
        ``parse_report_by_organ``. For every organ, one organ-level ``Report``
        is created per source report that mentioned it, carrying the same date
        and patient ID.

        Args:
            reports (Sequence[Report]): Reports for one patient; each needs
                ``text``, ``date`` and ``patient_id``.

        Returns:
            Dict[str, List[Report]]: Organ name to its organ-level reports. Each
            report's ``summary`` holds the relevant sentences. Reports with no
            relevant sentences are dropped, though an organ may map to an
            empty list.
        """
        reports_by_organ: Dict[str, List[Report]] = {}
        for report in reports:
            # Restrict to the findings section before splitting by organ.
            report_findings = get_section_from_report(report.text, "findings")
            for organ, organ_text in self.parse_report_by_organ(report_findings).items():
                reports_by_organ.setdefault(organ, []).append(
                    Report(text=organ_text, date=report.date, patient_id=report.patient_id)
                )

        summarized_reports_by_organ: Dict[str, List[Report]] = {}
        for organ, organ_reports in reports_by_organ.items():
            cleaned_reports: List[Report] = []
            for organ_report in organ_reports:
                relevant_text = self.trim_to_relevant_portion(organ_report.text)
                if relevant_text:
                    organ_report.summary = relevant_text
                    cleaned_reports.append(organ_report)
            summarized_reports_by_organ[organ] = cleaned_reports
        return summarized_reports_by_organ

    def format_reports(self, all_reports: "Dict[str, Sequence[Report]]") -> Dict[str, str]:
        """Render each organ's reports as one display string.

        Each report becomes a bold ``**Report <date>**`` heading followed by its
        summary. Entries are separated by blank lines.

        Args:
            all_reports: Organ name to reports (objects with ``date`` and
                ``summary`` attributes), e.g. the output of ``parse_records``.

        Returns:
            Dict[str, str]: Organ name to its rendered text.
        """
        return {
            organ: "\n\n".join(
                f"**Report {str(r.date)}**\n\n{str(r.summary)}" for r in organ_reports
            )
            for organ, organ_reports in all_reports.items()
        }