Spaces:
Sleeping
Sleeping
import json | |
import nltk | |
import pandas as pd | |
import re | |
nltk.download('punkt') | |
from dataclasses import dataclass | |
from nltk.tokenize import sent_tokenize | |
from typing import Dict, List, Sequence | |
from utils_report_parser import get_section_from_report | |
from transformers import ( | |
AutoModelForTokenClassification, | |
AutoTokenizer, | |
pipeline, | |
) | |
class Report: | |
patient_id: str | int | |
text: str | |
date: str | |
summary: str | None = None | |
def clean(s: str) -> str:
    """Normalise a chunk of report text into a single tidy line.

    Steps, in order:
      * newlines become spaces and underscores are dropped (this also erases
        long "_____" separator lines);
      * every ``[...]`` span, then every ``(...)`` span, is deleted together
        with its contents (shortest match, not nesting-aware);
      * runs of whitespace collapse to a single space and the ends are trimmed.
    """
    flattened = s.translate({ord("\n"): " ", ord("_"): None})
    # Brackets are stripped before parentheses; the order matters when the
    # two kinds of span overlap.
    for span_pattern in (r"\[.*?\]", r"\(.*?\)"):
        flattened = re.sub(span_pattern, "", flattened)
    return " ".join(flattened.split())
def split_paragraphs(text: str) -> List[str]:
    """Break *text* into cleaned paragraphs, discarding short ones.

    Paragraphs are delimited by two consecutive newlines. Each one is
    normalised with ``clean`` and kept only if it contains more than 10
    whitespace-separated words.
    """
    cleaned_chunks = (clean(chunk) for chunk in text.split("\n\n"))
    return [chunk for chunk in cleaned_chunks if len(chunk.split()) > 10]
def format_casemaker_data(
    df: pd.DataFrame, patient_id_column: str, text_column: str, date_column: str
) -> Dict[str, List["Report"]]:
    """Group a report-level dataframe into per-patient lists of ``Report`` objects.

    Each row of ``df`` is one report. The three named columns are renamed to
    ``patient_id``, ``text`` and ``date``; every other column is ignored.

    Args:
        df (pd.DataFrame): Input dataframe on report level.
        patient_id_column (str): Name of the column holding the patient ID.
        text_column (str): Name of the column holding the report text.
        date_column (str): Name of the column holding the report date. Reports
            are sorted by this column's raw values, so string dates sort
            lexicographically.

    Returns:
        Dict[str, List[Report]]: Mapping from ``str(patient_id)`` to that
        patient's reports in ascending date order. Each Report keeps the
        original (unconverted) patient_id value.
    """
    df = df.rename(
        columns={
            patient_id_column: "patient_id",
            text_column: "text",
            date_column: "date",
        }
    )
    ordered = df.sort_values(by=["patient_id", "date"])[["patient_id", "text", "date"]]

    reports_by_patient: Dict[str, List["Report"]] = {}
    # Iterate the groups directly instead of groupby().apply(). apply() acting
    # on the grouping column is deprecated in recent pandas, and this avoids an
    # intermediate Series of lists. Row order within each group is preserved.
    for patient_id, group in ordered.groupby("patient_id"):
        reports_by_patient[str(patient_id)] = [
            Report(**record) for record in group.to_dict("records")
        ]
    return reports_by_patient
class CaseMaker: | |
def __init__(self, organ_keywords_dict_path: str = "../assets/terms.json"): | |
self.organ_keyword_dict = json.load(open(organ_keywords_dict_path, "r")) | |
self.ner_pipe = pipeline( | |
"ner", | |
model=AutoModelForTokenClassification.from_pretrained( | |
"d4data/biomedical-ner-all" | |
), | |
tokenizer=AutoTokenizer.from_pretrained("d4data/biomedical-ner-all"), | |
aggregation_strategy="simple", | |
device_map="auto", | |
) | |
# self.summ_pipe = pipeline( | |
# "text2text-generation", model="starmpcc/Asclepius-7B", device_map="auto" | |
# ) | |
def standardize_organ(self, organ_entity: Dict) -> Dict: | |
"""Given an entity, map its name to a set of recognized entities provided in | |
organ_keyword_dict if it matches any of the keywords; otherwise set it as "Other" | |
Args: | |
organ_entity (Dict): Dictionary corresponding to entity; should contain "word" key | |
which is the entity | |
Returns: | |
Dict: Same dictionary where the "word" key has been updated to either a set of standard | |
body organs or "Other" | |
""" | |
# If the organ matches any of the keys or their synonyms, replace the name and return | |
for key in self.organ_keyword_dict: | |
if (organ_entity["word"].lower() == key.lower()) or ( | |
organ_entity["word"].lower() in self.organ_keyword_dict[key] | |
): | |
organ_entity["word"] = key | |
return organ_entity | |
# Otherwise, it's a bad match so set the score to 0 and return other | |
organ_entity["word"] = "Other" | |
organ_entity["score"] = 0.0 | |
return organ_entity | |
def pick_organ_by_keyword(self, s: str): | |
words = s.lower() | |
for organ in self.organ_keyword_dict.keys(): | |
if any( | |
[ | |
keyword.lower() in words | |
for keyword in [organ] + self.organ_keyword_dict[organ] | |
] | |
): | |
return organ | |
return "other" | |
def parse_report_by_organ(self, report: str): | |
"""Take in a text report and output a dictionary of body organs | |
and a list of all the sentences corresponding to that organ | |
Args: | |
report (str): Input report | |
""" | |
report_string_by_organ = dict[str, str]() | |
# Split the report into a list of paragraphs | |
paragraphs = split_paragraphs(report) | |
# Collect a list of paragraphs related to each organ | |
for p in paragraphs: | |
# Figure out which organ is being referenced | |
selected_organ = self.pick_organ_by_keyword(p) | |
# Concatenate the report to its corresponding organ | |
if selected_organ not in report_string_by_organ: | |
report_string_by_organ[selected_organ] = p | |
else: | |
report_string_by_organ[selected_organ] += p | |
return report_string_by_organ | |
def trim_to_relevant_portion(self, report: str): | |
# Only keep sentences with symptoms and disease descriptions | |
relevant_sentences = list[str]() | |
for sentence in sent_tokenize(report): | |
if any( | |
[ | |
ent["entity_group"] in ["Sign_symptom", "Disease_disorder"] | |
for ent in self.ner_pipe(sentence) | |
] | |
): | |
relevant_sentences.append(str(sentence)) | |
return "\n".join(relevant_sentences) | |
def summarize_report(self, text: str) -> str: | |
"""Format text into prompt and summarize clinical text | |
Args: | |
text (str): Input report | |
Returns: | |
str: Output | |
""" | |
question = ( | |
"Can you provide a succinct summary of the key clinical findings " | |
"and treatment recommendations outlined in this discharge summary?" | |
) | |
prompt = """ | |
You are an intelligent clinical languge model. | |
Below is a snippet of patient's discharge summary and a following instruction from healthcare professional. | |
Write a response that appropriately completes the instruction. | |
The response should provide the accurate answer to the instruction, while being concise. | |
[Discharge Summary Begin] | |
{note} | |
[Discharge Summary End] | |
[Instruction Begin] | |
{question} | |
[Instruction End] | |
""".format( | |
question=question, note=text | |
) | |
output = self.summ_pipe(prompt, max_new_tokens=len(text.split()) // 2)[0][ | |
"generated_text" | |
] | |
answer = output.split("[Instruction End]")[-1] | |
answer = clean(answer) | |
return answer | |
def parse_records( | |
self, | |
reports: Sequence[Report], | |
): | |
"""Given a list of reports (represented by dictionaries), split each of them | |
by body part using parse_report_by_organ, then compile all the text for the same | |
organ across different reports | |
(i.e. for each body part, have a list of dicts which contain the text from various reports) | |
Args: | |
records (Sequence[Report]): List of reports represented by dictionaries; each dictionary | |
must contain "text" and "date" keys | |
""" | |
# Split the reports by organ | |
reports_by_organ = dict[str, Sequence[Report]]() | |
for report in reports: | |
# Cut the report to the findings | |
report_findings = get_section_from_report(report.text, "findings") | |
# For each organ, collect a list of relevant records containing the text and date | |
report_by_organ = self.parse_report_by_organ(report_findings) | |
for organ, report_text in report_by_organ.items(): | |
organ_level_record = Report( | |
text=report_text, date=report.date, patient_id=report.patient_id | |
) | |
if organ in reports_by_organ: | |
reports_by_organ[organ].append(organ_level_record) | |
else: | |
reports_by_organ[organ] = [organ_level_record] | |
# For each organ, then filter only to the relevant reports and summarize them | |
summarized_reports_by_organ = dict[str, Sequence[Report]]() | |
for organ in reports_by_organ.keys(): | |
cleaned_reports = list[Report]() | |
for report in reports_by_organ[organ]: | |
# Trim the report | |
report_text = self.trim_to_relevant_portion(report.text) | |
if report_text: | |
report.summary = report_text | |
cleaned_reports.append(report) | |
summarized_reports_by_organ[organ] = cleaned_reports | |
return summarized_reports_by_organ | |
def format_reports(self, all_reports: Dict[str, List[Dict]]): | |
new_reports = {} | |
for organ, organ_reports in all_reports.items(): | |
new_reports[organ] = "\n\n".join( | |
[f"**Report {str(r.date)}**\n\n{str(r.summary)}" for r in organ_reports] | |
) | |
return new_reports | |