Spaces:
Sleeping
Sleeping
import streamlit as st | |
from PIL import Image | |
import cv2 | |
import pydicom | |
import numpy as np | |
from streamlit_image_zoom import image_zoom | |
import time | |
import pandas as pd | |
import os | |
import subprocess | |
import sys | |
# Optional dependencies that may be missing from the deployment image; install
# them on first run. NOTE(review): torchmcubes / torch / torchvision are not
# referenced in the visible code of this page — presumably the chestXray14
# pipeline needs them; confirm.
try:
    import torchmcubes
    import torch
    import torchvision
    import fpdf
except ImportError:
    # `sys.executable -m pip` installs into the interpreter running this app,
    # not whichever `pip` happens to be first on PATH.
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'git+https://github.com/tatsy/torchmcubes.git'])
    subprocess.check_call([sys.executable, '-m', 'pip', 'install', 'fpdf'])
    # Packages installed after start-up can be hidden by the import system's
    # cached directory listings; drop the caches so the import below finds them.
    import importlib
    importlib.invalidate_caches()
from fpdf import FPDF
# --- Make the sibling `chestXray14` directory importable -------------------
# Paths are resolved from this file's absolute location rather than the CWD.
# NOTE(review): `from chestXray14.test import ...` needs the *parent* of this
# directory on sys.path; appending the directory itself presumably serves
# imports made from inside chestXray14 — confirm.
script_dir = os.path.dirname(os.path.abspath(__file__))
chestXray14_path = os.path.join(script_dir, os.pardir, "chestXray14")
sys.path.append(chestXray14_path)
def convert_dcm_to_png(input_image_path, output_image_path='a.png'):
    """Convert a DICOM image to an 8-bit PNG.

    Args:
        input_image_path: Anything ``pydicom.dcmread`` accepts — a path or a
            file-like object (the page passes a Streamlit upload directly).
        output_image_path: Destination PNG path; defaults to ``'a.png'`` in the
            current working directory.
    """
    ds = pydicom.dcmread(input_image_path)
    img = ds.pixel_array
    # Min-max stretch the stored values into 0..255 so they fit an 8-bit PNG.
    img = cv2.normalize(img, None, 0, 255, cv2.NORM_MINMAX).astype(np.uint8)
    # MONOCHROME1 means the *minimum* value is displayed as white — the reverse
    # of ordinary PNG grayscale. Invert so the PNG looks like a normal
    # radiograph instead of a negative.
    if getattr(ds, 'PhotometricInterpretation', None) == 'MONOCHROME1':
        img = 255 - img
    cv2.imwrite(output_image_path, img)
from chestXray14.test import process_image | |
def predictAll(image_path):
    """Run the chestXray14 pipeline on ``image_path`` and log its outputs.

    ``process_image`` returns the prediction result followed by the paths of
    the segmentation and CAM images it saved; here they are only printed.
    """
    prediction, segment_path, cam_path = process_image(image_path)
    print("Prediction Results:", prediction)
    for label, saved_path in (("Segmentation", segment_path), ("CAM", cam_path)):
        print(f"{label} Result Saved to: {saved_path}")
def load_report(): | |
df = pd.read_csv('pages/images/prediction_results.csv') | |
df.columns = ['Bệnh lý', 'Xác suất'] | |
translation_dict = { | |
'Infiltration': 'Thâm nhiễm', | |
'Nodule': 'Nốt', | |
'Pleural Thickening': 'Dày màng phổi', | |
'Cardiomegaly': 'Tim to', | |
'Effusion': 'Tràn dịch', | |
'Pneumonia': 'Viêm phổi', | |
'Atelectasis': 'Xẹp phổi', | |
'Mass': 'Khối u', | |
'Fibrosis': 'Xơ phổi', | |
'Pneumothorax': 'Tràn khí màng phổi' | |
} | |
df['Bệnh lý'] = df['Bệnh lý'].map(translation_dict) | |
df['Xác suất'] = df['Xác suất'].astype(float) * 100 | |
df['Xác suất'] = df['Xác suất'].round(2).astype(str) + '%' | |
def highlight_rows(row): | |
if row.name == 0: | |
return ['background-color: darkred; color: white'] * len(row) | |
if row.name == 1: | |
return ['background-color: darkblue; color: white'] * len(row) | |
if row.name == 2: | |
return ['background-color: lightblue; color: white'] * len(row) | |
else: | |
return [''] * len(row) | |
df_styled = df.style.apply(highlight_rows, axis=1).set_table_styles( | |
[{'selector': 'thead th', 'props': [('background-color', '#d3d3d3')]}] | |
) | |
return df_styled | |
st.markdown("<h1 style='text-align: center;'>Welcome to Thoracic Classification 🎈</h1>", unsafe_allow_html=True)

# Sidebar: scan uploader plus a short user guide (user-facing text is Vietnamese).
with st.sidebar:
    st.markdown("## Upload your scans")
    # ".dcm" is the usual DICOM extension and the display code branches on it,
    # so the uploader must accept it alongside ".dicom".
    uploaded_files = st.file_uploader("Choose scans...", type=["jpg", "jpeg", "png", "dicom", "dcm"], accept_multiple_files=True)
    with st.expander("Hướng dẫn"):
        st.markdown("1. Tải lên ảnh Scan của bạn bằng cách ấn vào **Browse files** hoặc có thể **Kéo và thả** file ảnh của bạn vào phần browse files. Các định dạng cho phép bao gồm **DICOM, PNG, JPG, JPEG**, các định dạng khác cần phải chuyển về các định dạng được chấp nhận.")
        st.markdown("2. Sau đó ảnh sẽ tự được mở lên")
        st.markdown("3. Để phóng to ảnh, bạn chuyển chuột trái vào trong ảnh, dùng lăn chuột để thực hiện phóng to- thu nhỏ ảnh")
        st.markdown("4. Để kéo xuống xem ảnh phía dưới, bạn di chuột ra ngoài vùng ảnh và dùng lăn chuột cuộn trang như bình thường.")
# Left column: show every uploaded scan; `status_images` records whether at
# least one scan was displayed (and saved as the model input).
status_images = False
col_1, col_2 = st.columns([7, 5.5])
with col_1:
    if uploaded_files:
        # NOTE(review): every upload is written to the same 'temp_image.png', so
        # only the last file is what "Predict All Scans" actually analyses.
        for uploaded_file in uploaded_files:
            file_type = uploaded_file.name.split('.')[-1].lower()
            if file_type in ["jpg", "jpeg", "png"]:
                img = Image.open(uploaded_file)
                # PNG cannot store CMYK (used by some JPEGs); convert first so
                # save() does not raise.
                if img.mode == "CMYK":
                    img = img.convert("RGB")
                img.save('temp_image.png')
                # NOTE(review): each st.markdown call is rendered as its own
                # element, so this div pair likely does not wrap the image.
                st.markdown("<div style='display: flex; justify-content: center;'>", unsafe_allow_html=True)
                image_zoom(img, mode="both")
                st.markdown("</div>", unsafe_allow_html=True)
                status_images = True
            elif file_type in ["dicom", "dcm"]:
                # DICOM goes through an intermediate PNG ('a.png') first.
                convert_dcm_to_png(uploaded_file)
                img = Image.open('a.png').convert('RGB')
                img.save('temp_image.png')
                st.markdown("<div style='display: flex; justify-content: center;'>", unsafe_allow_html=True)
                width, height = img.size
                image_zoom(img, mode="both", size=(width // 4, height // 4), keep_aspect_ratio=True, zoom_factor=4.0, increment=0.2)
                st.markdown("</div>", unsafe_allow_html=True)
                status_images = True
    else:
        st.info("Please upload some scans to view them.")
############ CREATE PDF | |
import io | |
def generate_pdf(name, age, gender, address, phone): | |
pdf = FPDF() | |
pdf.add_page() | |
pdf.set_font("Arial", size=12) | |
# Title | |
pdf.set_font("Arial", style='B', size=16) | |
pdf.cell(200, 10, txt="Patient Report", ln=True, align='C') | |
pdf.ln(10) | |
# Patient details | |
pdf.set_font("Arial", size=12) | |
pdf.cell(200, 10, txt=f"Name: {name}", ln=True, align='L') | |
pdf.cell(200, 10, txt=f"Age: {age}", ln=True, align='L') | |
pdf.cell(200, 10, txt=f"Gender: {gender}", ln=True, align='L') | |
pdf.cell(200, 10, txt=f"Address: {address}", ln=True, align='L') | |
pdf.cell(200, 10, txt=f"Phone: {phone}", ln=True, align='L') | |
pdf.ln(10) | |
# Placeholder for additional content | |
pdf.cell(200, 10, txt="Predicted Disease Probabilities:", ln=True, align='L') | |
pdf.ln(10) | |
# Simulate adding prediction data (replace this with actual data) | |
diseases = ['Disease A', 'Disease B', 'Disease C'] | |
probabilities = ['70%', '50%', '30%'] | |
for disease, probability in zip(diseases, probabilities): | |
pdf.cell(200, 10, txt=f"{disease}: {probability}", ln=True, align='L') | |
# Add image (optional) | |
pdf.ln(10) | |
pdf.cell(200, 10, txt="Class Activation Map (CAM):", ln=True, align='L') | |
image_path = 'pages/images/cam_result.png' # Adjust this path as necessary | |
if os.path.exists(image_path): | |
pdf.image(image_path, x=10, y=pdf.get_y(), w=100) | |
# Save the PDF to a bytes buffer | |
pdf_buffer = io.BytesIO() | |
pdf.output(pdf_buffer) | |
pdf_buffer.seek(0) # Move the cursor to the beginning of the buffer | |
return pdf_buffer | |
def download_report(name, age, gender, address, phone):
    """Render the patient report on the page and offer the prediction CSV for download.

    Shows the patient details, the styled prediction table from ``load_report``,
    the CAM visualisation (when present), and a download button for
    ``pages/images/prediction_results.csv``.
    """
    st.markdown("<h2 style='text-align: center;'>Patient Report</h2>", unsafe_allow_html=True)
    st.write(f"**Name:** {name}")
    st.write(f"**Age:** {age}")
    st.write(f"**Gender:** {gender}")
    st.write(f"**Address:** {address}")
    st.write(f"**Phone:** {phone}")

    # Prediction table
    df_styled = load_report()
    st.markdown("<div style='display: flex; justify-content: center;'>", unsafe_allow_html=True)
    st.write(df_styled.to_html(), unsafe_allow_html=True)

    # CAM visualisation, only if the prediction step produced it.
    cam_path = 'pages/images/cam_result.png'
    if os.path.exists(cam_path):
        img = Image.open(cam_path).convert('RGB')
        st.image(img, caption="Class Activation Map (CAM) Visualization", use_column_width=True)

    # A plain <a href> to a repository-relative path is not served by Streamlit,
    # so stream the file's bytes through a download button instead.
    with open('pages/images/prediction_results.csv', 'rb') as report_file:
        st.download_button(
            label="Click here to download the report",
            data=report_file.read(),
            file_name="prediction_results.csv",
            mime="text/csv",
        )
def _show_zoomable_result(heading_html, image_path):
    """Render an HTML heading plus a zoomable view of a result image, or a warning if it is missing."""
    st.markdown(heading_html, unsafe_allow_html=True)
    if not os.path.exists(image_path):
        st.warning('Result image not found - run "Predict All Scans" first.')
        return
    img = Image.open(image_path).convert('RGB')
    st.markdown("<div style='display: flex; justify-content: center;'>", unsafe_allow_html=True)
    width, height = img.size
    image_zoom(img, mode="both", size=(width // 4, height // 4), keep_aspect_ratio=True, zoom_factor=4.0, increment=0.2)


# NOTE(review): the original indentation was lost. Everything below reads the
# btn_* names, which are defined only when a scan is shown, so it is all assumed
# to sit under `if status_images:`.
if status_images:
    with col_2:
        st.markdown("<h2 style='text-align: center;'>Function</h2>", unsafe_allow_html=True)
        btn_predictAll_Scans = st.button("Predict All Scans")
        btn_CAM_Visualization = st.button("CAM Visualization")
        btn_Segment_Lung = st.button("Segmentation Visualization for Lung")
        btn_View_Report = st.button("View Report")
        btn_Download_Report = st.button("Download Report")
        if btn_predictAll_Scans:
            start_time = time.time()
            predictAll('temp_image.png')
            elapsed_time = time.time() - start_time
            st.success(f"Predicted all Scans success - ⏳ {int(elapsed_time)} seconds. You can use CAM, Segmentation, View, and Download Report", icon="✅")

    st.divider()
    col_3, col_4, col_5 = st.columns([4, 7.5, 6])  # col_3 / col_5 are spacers
    with col_4:
        if btn_Segment_Lung:
            _show_zoomable_result(
                "<h2 style='text-align: center;color:red;margin-left: 100px'>Segmentation Image for Lung </h2>",
                'pages/images/segment_result.png',
            )
        if btn_CAM_Visualization:
            # `color` (not the invalid `text-color`), matching the heading above.
            _show_zoomable_result(
                "<h2 style='text-align: center;color:red;margin-left: 100px'>Class Activation Map(CAM) Visualization </h2>",
                'pages/images/cam_result.png',
            )

        # st.button is True only on the rerun triggered by its click. Submitting
        # the form causes another rerun in which btn_Download_Report is False, so
        # the request is remembered in session_state; otherwise the form would
        # disappear before `submit` could ever be handled.
        if btn_Download_Report:
            st.session_state["show_patient_form"] = True
        if st.session_state.get("show_patient_form", False):
            with st.form("patient_info_form"):
                st.write("Please provide patient details before downloading the report:")
                name = st.text_input("Name")
                age = st.number_input("Age", min_value=0, max_value=130)
                gender = st.selectbox("Gender", ["Male", "Female", "Other"])
                address = st.text_input("Address")
                phone = st.text_input("Phone")
                submit = st.form_submit_button("Submit")
            # st.download_button is not allowed inside a form, so render it after.
            if submit:
                pdf_buffer = generate_pdf(name, age, gender, address, phone)
                st.download_button(
                    label="Download Report",
                    data=pdf_buffer,
                    file_name="patient_report.pdf",
                    mime="application/pdf"
                )

    col_6, col_7, col_8 = st.columns([7.8, 4.5, 8])  # col_6 / col_8 are spacers
    with col_7:
        if btn_View_Report:
            st.markdown("<h2 style='text-align: center;'>Prediction Report</h2>", unsafe_allow_html=True)
            if os.path.exists('pages/images/prediction_results.csv'):
                df_styled = load_report()
                st.markdown("<div style='display: flex; justify-content: center;'>", unsafe_allow_html=True)
                st.write(df_styled.to_html(), unsafe_allow_html=True)
            else:
                st.warning('No prediction results yet - run "Predict All Scans" first.')