casemaker_demo / utils_casemaker.py
ljyflores
Add nltk download
16939ac
raw
history blame
9.28 kB
import json
import nltk
import pandas as pd
import re
nltk.download('punkt')
from dataclasses import dataclass
from nltk.tokenize import sent_tokenize
from typing import Dict, List, Sequence
from utils_report_parser import get_section_from_report
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
pipeline,
)
@dataclass
class Report:
patient_id: str|int
text: str
date: str
summary: str|None = None
def clean(s: str) -> str:
    """Flatten a piece of report text into a single tidy line.

    Newlines become spaces, underscores are dropped, any ``[...]`` or ``(...)``
    span is removed together with its delimiters, and runs of whitespace are
    collapsed to a single space.

    Args:
        s (str): Raw text

    Returns:
        str: Cleaned, single-line text
    """
    # Newlines are replaced first so a bracketed span that wrapped across
    # lines is still matched by the patterns below
    flattened = s.replace("\n", " ").replace("_", "")
    # Non-greedy matches so each bracketed/parenthesized span is removed on its own
    for pattern in (r"\[.*?\]", r"\(.*?\)"):
        flattened = re.sub(pattern, "", flattened)
    # split() with no arguments also strips leading/trailing whitespace
    return " ".join(flattened.split())
def split_paragraphs(text: str) -> List[str]:
    """Split text on blank lines and keep only the substantial paragraphs.

    Each paragraph is passed through ``clean``; paragraphs with 10 words or
    fewer after cleaning are discarded.

    Args:
        text (str): Input text with paragraphs separated by a blank line

    Returns:
        List[str]: Cleaned paragraphs that have more than 10 words
    """
    cleaned_chunks = (clean(chunk) for chunk in text.split("\n\n"))
    return [chunk for chunk in cleaned_chunks if len(chunk.split()) > 10]
def format_casemaker_data(
    df: pd.DataFrame, patient_id_column: str, text_column: str, date_column: str
) -> "Dict[str, List[Report]]":
    """Take in a pandas dataframe where each row corresponds to one report for a patient,
    and output a dictionary mapping each patient ID to the list of that patient's
    reports sorted by date

    Args:
        df (pd.DataFrame): Input dataframe on report level
        patient_id_column (str): Patient ID
        text_column (str): Text/Report
        date_column (str): Date (will be used to sort)

    Returns:
        Dict[str, List[Report]]: Patient ID (converted to a string) mapped to that
        patient's reports in ascending date order
    """
    renamed = df.rename(
        columns={
            patient_id_column: "patient_id",
            text_column: "text",
            date_column: "date",
        }
    )
    ordered = renamed.sort_values(by=["patient_id", "date"])

    reports_by_patient: Dict[str, List[Report]] = {}
    # Iterate over the groups directly instead of using groupby().apply():
    # apply() operating on the grouping column is deprecated in recent pandas,
    # and the callback here needs "patient_id" to build each Report
    for patient_id, patient_rows in ordered.groupby("patient_id"):
        records = patient_rows[["patient_id", "text", "date"]].to_dict("records")
        # The key is always a string; each Report keeps the original ID value
        reports_by_patient[str(patient_id)] = [Report(**record) for record in records]
    return reports_by_patient
class CaseMaker:
    """Split patient reports by organ and keep the sentences that mention findings.

    Attributes:
        organ_keyword_dict: Mapping loaded from JSON of organ name -> list of
            keywords/synonyms for that organ.
        ner_pipe: Hugging Face "ner" pipeline built from "d4data/biomedical-ner-all".
        summ_pipe: Summarization pipeline used by ``summarize_report``; ``None``
            unless one is assigned after construction.
    """

    def __init__(self, organ_keywords_dict_path: str = "../assets/terms.json"):
        """Load the organ keyword dictionary and build the NER pipeline

        Args:
            organ_keywords_dict_path (str): Path to the JSON file of organ keywords
        """
        # Context manager so the file handle is closed as soon as the JSON is read
        with open(organ_keywords_dict_path, "r", encoding="utf-8") as keywords_file:
            self.organ_keyword_dict = json.load(keywords_file)
        self.ner_pipe = pipeline(
            "ner",
            model=AutoModelForTokenClassification.from_pretrained(
                "d4data/biomedical-ner-all"
            ),
            tokenizer=AutoTokenizer.from_pretrained("d4data/biomedical-ner-all"),
            aggregation_strategy="simple",
            device_map="auto",
        )
        # The summarization model is not loaded by default; summarize_report raises
        # a RuntimeError until a pipeline is assigned here, e.g.:
        # self.summ_pipe = pipeline(
        #     "text2text-generation", model="starmpcc/Asclepius-7B", device_map="auto"
        # )
        self.summ_pipe = None

    def standardize_organ(self, organ_entity: Dict) -> Dict:
        """Given an entity, map its name to a set of recognized entities provided in
        organ_keyword_dict if it matches any of the keywords; otherwise set it as "Other"

        Matching is case-insensitive on both the entity and the keywords. The input
        dictionary is modified in place and returned.

        Args:
            organ_entity (Dict): Dictionary corresponding to entity; should contain "word" key
            which is the entity

        Returns:
            Dict: Same dictionary where the "word" key has been updated to either a set of standard
            body organs or "Other"
        """
        word = organ_entity["word"].lower()
        # If the organ matches any of the keys or their synonyms, replace the name and return
        for organ, keywords in self.organ_keyword_dict.items():
            if word == organ.lower() or any(word == kw.lower() for kw in keywords):
                organ_entity["word"] = organ
                return organ_entity
        # Otherwise, it's a bad match so set the score to 0 and return other
        organ_entity["word"] = "Other"
        organ_entity["score"] = 0.0
        return organ_entity

    def pick_organ_by_keyword(self, s: str) -> str:
        """Return the first organ whose name or keywords occur in the text

        Matching is a case-insensitive substring test, and organs are tried in the
        order they appear in organ_keyword_dict.

        Args:
            s (str): Text to search

        Returns:
            str: Matching organ name, or "other" if nothing matches
        """
        lowered = s.lower()
        for organ, keywords in self.organ_keyword_dict.items():
            if any(keyword.lower() in lowered for keyword in [organ, *keywords]):
                return organ
        return "other"

    def parse_report_by_organ(self, report: str) -> Dict[str, str]:
        """Take in a text report and output a dictionary of body organs
        and the text of all the paragraphs corresponding to that organ

        Args:
            report (str): Input report

        Returns:
            Dict[str, str]: Organ name mapped to its paragraphs joined by a space
        """
        paragraphs_by_organ: Dict[str, List[str]] = {}
        # Split the report into paragraphs and assign each to the organ it references
        for paragraph in split_paragraphs(report):
            selected_organ = self.pick_organ_by_keyword(paragraph)
            paragraphs_by_organ.setdefault(selected_organ, []).append(paragraph)
        # Join with a space so the last word of one paragraph is not fused with
        # the first word of the next
        return {
            organ: " ".join(paragraphs)
            for organ, paragraphs in paragraphs_by_organ.items()
        }

    def trim_to_relevant_portion(self, report: str) -> str:
        """Keep only the sentences of the findings section that mention a sign,
        symptom, disease or disorder according to the NER pipeline

        Args:
            report (str): Input report

        Returns:
            str: Relevant sentences joined by newlines; empty if there are none
        """
        # Cut the report to the findings
        # NOTE(review): assumes get_section_from_report returns a string - confirm
        findings = get_section_from_report(report, "findings")
        relevant_groups = ("Sign_symptom", "Disease_disorder")
        # Only keep sentences with symptoms and disease descriptions
        relevant_sentences = [
            sentence
            for sentence in sent_tokenize(findings)
            if any(
                ent["entity_group"] in relevant_groups
                for ent in self.ner_pipe(sentence)
            )
        ]
        return "\n".join(relevant_sentences)

    def summarize_report(self, text: str) -> str:
        """Format text into prompt and summarize clinical text

        Args:
            text (str): Input report

        Returns:
            str: Output

        Raises:
            RuntimeError: If no summarization pipeline has been assigned to summ_pipe
        """
        summ_pipe = getattr(self, "summ_pipe", None)
        if summ_pipe is None:
            raise RuntimeError(
                "No summarization pipeline is configured; assign one to "
                "`summ_pipe` before calling summarize_report."
            )
        question = (
            "Can you provide a succinct summary of the key clinical findings "
            "and treatment recommendations outlined in this discharge summary?"
        )
        # NOTE(review): the prompt wording (including the spelling "languge") looks
        # like a fixed template for the summarization model; kept verbatim - confirm
        # against the model card before editing
        prompt = """
You are an intelligent clinical languge model.
Below is a snippet of patient's discharge summary and a following instruction from healthcare professional.
Write a response that appropriately completes the instruction.
The response should provide the accurate answer to the instruction, while being concise.
[Discharge Summary Begin]
{note}
[Discharge Summary End]
[Instruction Begin]
{question}
[Instruction End]
""".format(
            question=question, note=text
        )
        # Budget roughly half the input length, but never ask for zero new tokens
        max_new_tokens = max(1, len(text.split()) // 2)
        output = summ_pipe(prompt, max_new_tokens=max_new_tokens)[0]["generated_text"]
        # Keep only what follows the last instruction marker
        answer = output.split("[Instruction End]")[-1]
        return clean(answer)

    def parse_records(
        self,
        reports: "Sequence[Report]",
    ) -> "Dict[str, List[Report]]":
        """Given a list of reports, split each of them by body part using
        parse_report_by_organ, then compile all the text for the same
        organ across different reports
        (i.e. for each body part, have a list of reports which contain the text from various reports)

        Args:
            reports (Sequence[Report]): List of reports; the "text", "date" and
            "patient_id" fields of each one are used

        Returns:
            Dict[str, List[Report]]: Organ name mapped to the organ-level reports whose
            trimmed text is non-empty; that trimmed text is stored in "summary"
        """
        # For each organ, collect a list of relevant records containing the text and date
        reports_by_organ: Dict[str, List[Report]] = {}
        for report in reports:
            for organ, organ_text in self.parse_report_by_organ(report.text).items():
                organ_level_record = Report(
                    text=organ_text, date=report.date, patient_id=report.patient_id
                )
                reports_by_organ.setdefault(organ, []).append(organ_level_record)

        # For each organ, then filter only to the relevant reports and summarize them
        summarized_reports_by_organ: Dict[str, List[Report]] = {}
        for organ, organ_reports in reports_by_organ.items():
            cleaned_reports: List[Report] = []
            for organ_report in organ_reports:
                # Trim the report; drop it if nothing relevant remains
                trimmed_text = self.trim_to_relevant_portion(organ_report.text)
                if trimmed_text:
                    organ_report.summary = trimmed_text
                    cleaned_reports.append(organ_report)
            summarized_reports_by_organ[organ] = cleaned_reports
        return summarized_reports_by_organ

    def format_reports(
        self, all_reports: "Dict[str, Sequence[Report]]"
    ) -> Dict[str, str]:
        """Render the reports of each organ as one Markdown string

        Args:
            all_reports (Dict[str, Sequence[Report]]): Organ name mapped to its reports;
            the "date" and "summary" attributes of each report are used

        Returns:
            Dict[str, str]: Organ name mapped to its reports rendered as a bold
            "Report <date>" heading followed by the summary, separated by blank lines
        """
        return {
            organ: "\n\n".join(
                f"**Report {str(r.date)}**\n\n{str(r.summary)}" for r in organ_reports
            )
            for organ, organ_reports in all_reports.items()
        }