Kalbe-x-Bangkit's picture
Change input of detection from enhanced_image to enhancement_type.
b437143 verified
raw
history blame
19.2 kB
import streamlit as st
import cv2
import numpy as np
import pydicom
import tensorflow as tf
import keras
from pydicom.dataset import Dataset, FileDataset
from pydicom.uid import generate_uid
from google.cloud import storage
import os
import io
from PIL import Image
import uuid
import pandas as pd
import tensorflow as tf
from datetime import datetime
from tensorflow import image
from tensorflow.python.keras.models import load_model
from pydicom.pixel_data_handlers.util import apply_voi_lut
# Environment Configuration
# Point the Google Cloud client libraries at a service-account key file.
# NOTE(review): hard-coded relative path, and it overwrites any value already
# present in the environment — confirm this is intended for every deployment.
os.environ['GOOGLE_APPLICATION_CREDENTIALS'] = "./da-kalbe-63ee33c9cdbb.json"
# Bucket that upload_to_gcs() / upload_folder_images() write results to.
bucket_name = "da-kalbe-ml-result-png"
storage_client = storage.Client()
bucket_result = storage_client.bucket(bucket_name)
# Bucket that load_dicom_from_gcs() downloads the template DICOM from.
bucket_name_load = "da-ml-models"
bucket_load = storage_client.bucket(bucket_name_load)
# Sidebar widgets: the uploaded file and the selected enhancement are read as
# module-level globals by the functions and UI code further down.
st.sidebar.title("Configuration")
uploaded_file = st.sidebar.file_uploader("Upload Original Image", type=["png", "jpg", "jpeg", "dcm"])
enhancement_type = st.sidebar.selectbox(
"Enhancement Type",
["Invert", "High Pass Filter", "Unsharp Masking", "Histogram Equalization", "CLAHE"]
)
# Height / width used as the resize target by the (commented-out)
# preprocess_image helper; presumably the detector's input size — TODO confirm.
H = 224
W = 224
@st.cache_resource
def load_model():
    """Load ``model-detection.h5`` and compile it for its two output heads.

    The file is loaded with ``compile=False`` and then compiled here with an
    MSE loss/metric for the ``"bbox"`` output and sparse categorical
    cross-entropy plus accuracy for the ``"class"`` output.

    Returns:
        The compiled Keras model. The function is wrapped in
        ``st.cache_resource``.
    """
    detector = tf.keras.models.load_model("model-detection.h5", compile=False)

    head_losses = {"bbox": "mse", "class": "sparse_categorical_crossentropy"}
    head_metrics = {"bbox": ['mse'], "class": ['accuracy']}
    detector.compile(
        loss=head_losses,
        optimizer=tf.keras.optimizers.Adam(),
        metrics=head_metrics,
    )
    return detector
# def preprocess_image(image):
# """ Preprocess the image to the required size and normalization. """
# image = cv2.resize(image, (W, H))
# image = (image - 127.5) / 127.5 # Normalize to [-1, +1]
# image = np.expand_dims(image, axis=0).astype(np.float32)
# return image
# st.title("Chest X-ray Disease Detection")
# st.write("Upload a chest X-ray image and click on 'Detect' to find out if there's any disease.")
# Module-level detection model used by the "Auto Detect" UI code below;
# load_model() is wrapped in st.cache_resource.
model = load_model()
# uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
# if uploaded_file is not None:
# file_bytes = np.asarray(bytearray(uploaded_file.read()), dtype=np.uint8)
# image = cv2.imdecode(file_bytes, 1)
# st.image(image, caption='Uploaded Image.', use_column_width=True)
# Utility Functions
def upload_to_gcs(image_data: io.BytesIO, filename: str, content_type='application/dicom'):
    """Upload a file-like object to the results bucket and report the outcome.

    Args:
        image_data: Stream whose contents are uploaded.
        filename: Object name to create inside ``bucket_result``.
        content_type: Content type recorded for the uploaded object.

    Any exception is caught and shown through ``st.error``; nothing is
    re-raised and nothing is returned.
    """
    try:
        bucket_result.blob(filename).upload_from_file(image_data, content_type=content_type)
        st.write("File ready to be seen in OHIF Viewer.")
    except Exception as upload_error:
        st.error(f"An unexpected error occurred: {upload_error}")
def load_dicom_from_gcs(file_name: str = "dicom_00000001_000.dcm"):
    """Download ``file_name`` from ``bucket_load`` and parse it with pydicom.

    Args:
        file_name: Object name inside the bucket.

    Returns:
        The dataset produced by ``pydicom.dcmread`` on the downloaded bytes.
    """
    raw_bytes = bucket_load.blob(file_name).download_as_bytes()
    # dcmread needs a file-like object, so wrap the in-memory bytes.
    return pydicom.dcmread(io.BytesIO(raw_bytes))
def png_to_dicom(image_path: str, image_name: str, dicom: str = None):
    """Wrap an 8-bit raster image in a DICOM dataset built from a template.

    A template dataset is fetched with ``load_dicom_from_gcs``; its image
    pixel attributes are overwritten with the pixels of ``image_path`` and the
    result is written to ``image_name`` on local disk.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        image_name: Local path the DICOM file is saved to.
        dicom: Template object name; ``None`` uses the default of
            ``load_dicom_from_gcs``.

    Returns:
        The modified dataset.

    Raises:
        ValueError: If the image mode is not ``'L'``, ``'RGB'`` or ``'RGBA'``.
    """
    ds = load_dicom_from_gcs() if dicom is None else load_dicom_from_gcs(dicom)

    # PIL mode -> (PhotometricInterpretation, SamplesPerPixel)
    # NOTE(review): 'L' is tagged MONOCHROME1 (minimum value displayed as
    # white). If the source PNGs store black as 0, MONOCHROME2 would be the
    # matching value — confirm against how the viewer renders these files.
    pixel_layouts = {
        'L': ("MONOCHROME1", 1),
        'RGB': ("RGB", 3),
        'RGBA': ("RGB", 3),
    }

    # The context manager releases the underlying file handle even when the
    # mode is rejected.
    with Image.open(image_path) as jpg_image:
        print("Image Mode:", jpg_image.mode)
        if jpg_image.mode not in pixel_layouts:
            raise ValueError(f"Unsupported image mode: {jpg_image.mode}")
        photometric, samples_per_pixel = pixel_layouts[jpg_image.mode]
        rows, columns = jpg_image.height, jpg_image.width
        np_image = np.array(jpg_image, dtype=np.uint8)

    if samples_per_pixel == 3:
        # Drop the alpha channel of RGBA input; a no-op for RGB.
        np_image = np_image[..., :3]
        # Colour pixel data is written interleaved (R1 G1 B1 R2 ...).
        ds.PlanarConfiguration = 0

    ds.Rows = rows
    ds.Columns = columns
    ds.PhotometricInterpretation = photometric
    ds.SamplesPerPixel = samples_per_pixel
    ds.BitsStored = 8
    ds.BitsAllocated = 8
    ds.HighBit = 7
    ds.PixelRepresentation = 0
    ds.PixelData = np_image.tobytes()
    ds.save_as(image_name)
    return ds
def save_dicom_to_bytes(dicom):
    """Serialise ``dicom`` into an in-memory stream.

    Args:
        dicom: Object exposing ``save_as(file_like)``.

    Returns:
        An ``io.BytesIO`` holding the written bytes, rewound to position 0.
    """
    buffer = io.BytesIO()
    dicom.save_as(buffer)
    buffer.seek(0)  # rewind so the next reader starts at the first byte
    return buffer
def upload_folder_images(original_image_path, enhanced_image_path):
    """Convert the original and enhanced images to DICOM and upload both.

    The destination "folder" is named after the uploaded file (module-level
    ``uploaded_file``) without its extension; the enhanced object is named
    after the module-level ``enhancement_type``.

    Args:
        original_image_path: Path of the original image on local disk.
        enhanced_image_path: Path of the enhanced image on local disk.
    """
    folder_name = os.path.splitext(uploaded_file.name)[0]
    prefix = folder_name + '/'

    # An empty object whose name ends in '/' acts as the folder placeholder.
    bucket_result.blob(prefix).upload_from_string('', content_type='application/x-www-form-urlencoded')

    enhanced_file = enhancement_type.split('_')[-1] + ".dcm"

    # Both conversions run before anything is uploaded.
    original_dicom = png_to_dicom(original_image_path, "original_image.dcm")
    enhanced_dicom = png_to_dicom(enhanced_image_path, enhanced_file)

    payloads = (
        ('original_image.dcm', save_dicom_to_bytes(original_dicom)),
        (enhanced_file, save_dicom_to_bytes(enhanced_dicom)),
    )
    for object_name, stream in payloads:
        upload_to_gcs(stream, prefix + object_name, content_type='application/dicom')
def get_mean_std_per_batch(image_path, df, H=320, W=320):
    """Return the pixel mean and standard deviation of one resized image.

    Args:
        image_path: Image location passed to ``keras.utils.load_img``.
        df: Unused; kept so existing positional callers keep working. The
            statistics only ever depended on ``image_path``, so requiring a
            DataFrame with an ``"Image Index"`` column and at least 100 rows
            served no purpose and made the call fail for other frames.
        H: Target height the image is resized to before measuring.
        W: Target width the image is resized to before measuring.

    Returns:
        ``(mean, std)`` computed over every pixel and channel of the image.
    """
    # One load is enough: every sample used to be this same image.
    pixels = np.array(keras.utils.load_img(image_path, target_size=(H, W)))
    return np.mean(pixels), np.std(pixels)
def load_image(img_path, preprocess=True, height=320, width=320):
    """Load an image, optionally standardised and batched for the model.

    Args:
        img_path: Image location passed to ``keras.utils.load_img``.
        preprocess: When true, subtract the image mean, divide by the image
            standard deviation and add a leading batch axis.
        height: Target height.
        width: Target width.

    Returns:
        A float array as produced by ``keras.utils.img_to_array``; with
        ``preprocess=True`` it is standardised and has a leading axis of
        length 1.
    """
    x = keras.utils.img_to_array(
        keras.utils.load_img(img_path, target_size=(height, width)))
    if preprocess:
        # Statistics come from the image itself, so no module-level DataFrame
        # has to exist for this to work.
        mean = np.mean(x, dtype=np.float64)
        std = np.std(x, dtype=np.float64)
        x -= mean
        if std > 0:  # a constant image would otherwise divide by zero
            x /= std
        x = np.expand_dims(x, axis=0)
    return x
def grad_cam(input_model, img_array, cls, layer_name, output_size=(320, 320)):
    """Compute a guided Grad-CAM heat map for one class.

    Args:
        input_model: Keras model with a single prediction output.
        img_array: Batched model input; only the first item is visualised.
        cls: Index of the class whose score is differentiated.
        layer_name: Name of the layer whose activations are weighted.
        output_size: ``(width, height)`` the map is resized to.

    Returns:
        A 2-D float array clipped at zero and scaled so its maximum is 1.
        When the map has no positive value it is returned as all zeros.
    """
    grad_model = tf.keras.models.Model(
        [input_model.inputs],
        [input_model.get_layer(layer_name).output, input_model.output]
    )
    with tf.GradientTape() as tape:
        conv_outputs, predictions = grad_model(img_array)
        loss = predictions[:, cls]

    output = conv_outputs[0]
    grads = tape.gradient(loss, conv_outputs)[0]

    # Guided back-propagation gates: keep gradients only where both the
    # activation and the gradient are positive.
    gate_f = tf.cast(output > 0, 'float32')
    gate_r = tf.cast(grads > 0, 'float32')
    guided_grads = gate_f * gate_r * grads

    # One weight per channel, then a single weighted sum over channels.
    weights = tf.reduce_mean(guided_grads, axis=(0, 1))
    cam = tf.reduce_sum(output * weights, axis=-1).numpy().astype(np.float32)

    # The interpolation flag must be passed by keyword: the third positional
    # parameter of cv2.resize is the destination array.
    cam = cv2.resize(cam, output_size, interpolation=cv2.INTER_LINEAR)
    cam = np.maximum(cam, 0)
    peak = cam.max()
    if peak > 0:  # avoid 0/0 when nothing is positive
        cam = cam / peak
    return cam
# Compute Grad-CAM
def compute_gradcam(model, img_path, layer_name='bn'):
    """Render one Grad-CAM overlay per class label in the Streamlit page.

    A DenseNet121 classifier is built here, loaded with
    ``./pretrained_model.h5`` and used for both the class probabilities and
    the heat maps.

    Args:
        model: Unused; kept so existing callers keep working. The weights
            file and the default ``layer_name`` belong to the DenseNet built
            below, so that network is the one that is visualised.
        img_path: Image location passed to ``load_image``.
        layer_name: Layer of the DenseNet classifier whose activations are
            used by ``grad_cam``.
    """
    labels = ['Cardiomegaly', 'Emphysema', 'Effusion', 'Hernia', 'Infiltration', 'Mass',
              'Nodule', 'Atelectasis', 'Pneumothorax', 'Pleural_Thickening',
              'Pneumonia', 'Fibrosis', 'Edema', 'Consolidation']

    base_model = keras.applications.DenseNet121(weights='./densenet.hdf5', include_top=False)
    x = base_model.output
    x = keras.layers.GlobalAveragePooling2D()(x)
    predictions = keras.layers.Dense(len(labels), activation="sigmoid")(x)
    model_gradcam = keras.Model(inputs=base_model.input, outputs=predictions)
    model_gradcam.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
                          loss='sparse_categorical_crossentropy')
    # Load the classifier weights into the classifier that was just built.
    model_gradcam.load_weights('./pretrained_model.h5')

    preprocessed_input = load_image(img_path)
    predictions = model_gradcam.predict(preprocessed_input)
    original_image = load_image(img_path, preprocess=False)
    # squeeze() tolerates the image with or without a leading batch axis.
    background = original_image.squeeze().astype(np.uint8)

    for i, label in enumerate(labels):
        st.write(f"Generating gradcam for class {label}")
        gradcam = grad_cam(model_gradcam, preprocessed_input, i, layer_name)
        gradcam = (gradcam * 255).astype(np.uint8)
        gradcam = cv2.applyColorMap(gradcam, cv2.COLORMAP_JET)
        # applyColorMap returns BGR; the background and st.image use RGB.
        gradcam = cv2.cvtColor(gradcam, cv2.COLOR_BGR2RGB)
        gradcam = cv2.addWeighted(gradcam, 0.5, background, 0.5, 0)
        st.image(gradcam, caption=f"{label}: p={predictions[0][i]:.3f}", use_column_width=True)
def calculate_mse(original_image, enhanced_image):
    """Return the mean squared error between two equally shaped images.

    Both inputs are converted to float64 first. With unsigned integer arrays
    (for example uint8) the subtraction and the square would otherwise wrap
    around and produce a wrong result.
    """
    difference = (np.asarray(original_image, dtype=np.float64)
                  - np.asarray(enhanced_image, dtype=np.float64))
    return np.mean(difference ** 2)
def calculate_psnr(original_image, enhanced_image, max_pixel_value=255.0):
    """Return the peak signal-to-noise ratio in decibels.

    Args:
        original_image: Reference image.
        enhanced_image: Image compared against the reference.
        max_pixel_value: Largest value a pixel can take. The default of 255
            suits 8-bit data; pass a larger peak for deeper images.

    Returns:
        ``20 * log10(max_pixel_value / sqrt(mse))``, or ``inf`` when the two
        images are identical.
    """
    mse = calculate_mse(original_image, enhanced_image)
    if mse == 0:
        return float('inf')
    return 20 * np.log10(max_pixel_value / np.sqrt(mse))
def calculate_maxerr(original_image, enhanced_image):
    """Return the largest squared per-pixel difference between two images.

    Inputs are converted to float64 before subtracting so that unsigned
    integer arrays cannot wrap around.
    """
    difference = (np.asarray(original_image, dtype=np.float64)
                  - np.asarray(enhanced_image, dtype=np.float64))
    return np.max(difference ** 2)
def calculate_l2rat(original_image, enhanced_image):
    """Return the ratio of the original's energy to the error energy.

    Computed as ``sum(original ** 2) / sum((original - enhanced) ** 2)`` in
    float64, so unsigned integer inputs cannot wrap around.

    Returns:
        The ratio, or ``inf`` when the two images are identical (zero error
        energy).
    """
    reference = np.asarray(original_image, dtype=np.float64)
    residual = reference - np.asarray(enhanced_image, dtype=np.float64)
    error_energy = np.sum(residual ** 2)
    if error_energy == 0:
        return float('inf')
    return np.sum(reference ** 2) / error_energy
def process_image(original_image, enhancement_type, fix_monochrome=True):
    """Enhance an image and measure how far the result is from the input.

    Args:
        original_image: NumPy image array, 2-D or H x W x 3.
        enhancement_type: Name understood by ``enhance_image``.
        fix_monochrome: When true, a 3-channel input is reduced to a single
            channel before enhancement.

    Returns:
        ``(enhanced_image, mse, psnr, maxerr, l2rat)`` where the four metrics
        compare the (possibly grey-converted) input with the enhanced image.
    """
    # Require three dimensions so a 2-D image that happens to be 3 pixels
    # wide is not mistaken for a colour image.
    if fix_monochrome and original_image.ndim == 3 and original_image.shape[-1] == 3:
        # NOTE(review): treats the channels as BGR; arrays loaded through
        # keras/PIL are RGB. Identical for grey-valued X-rays — confirm for
        # true colour input.
        original_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)

    # Min-max scale to the full 8-bit range: shift to zero, then divide by
    # the span of the shifted data.
    shifted = original_image.astype(np.float64) - np.min(original_image)
    span = np.max(shifted)
    if span > 0:  # a constant image stays all zeros instead of dividing by 0
        shifted = shifted / span
    image = (shifted * 255).astype(np.uint8)

    enhanced_image = enhance_image(image, enhancement_type)

    # NOTE(review): metrics compare the un-normalised input with the 8-bit
    # result; for inputs deeper than 8 bits the two are on different scales.
    mse = calculate_mse(original_image, enhanced_image)
    psnr = calculate_psnr(original_image, enhanced_image)
    maxerr = calculate_maxerr(original_image, enhanced_image)
    l2rat = calculate_l2rat(original_image, enhanced_image)
    return enhanced_image, mse, psnr, maxerr, l2rat
def apply_clahe(image, clip_limit=40.0, tile_grid_size=(8, 8)):
    """Apply contrast-limited adaptive histogram equalisation.

    Args:
        image: Single-channel image accepted by OpenCV's CLAHE.
        clip_limit: Value passed as ``clipLimit`` to ``cv2.createCLAHE``.
        tile_grid_size: Value passed as ``tileGridSize``.

    Returns:
        The equalised image.
    """
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    return clahe.apply(image)
def invert(image):
    """Return ``image`` with every bit flipped (``cv2.bitwise_not``)."""
    inverted = cv2.bitwise_not(image)
    return inverted
def hp_filter(image, kernel=None):
    """Convolve ``image`` with ``kernel`` using ``cv2.filter2D``.

    Args:
        image: Input image.
        kernel: Convolution kernel; when ``None`` a 3x3 kernel with 9 in the
            centre and -1 elsewhere is used.

    Returns:
        The filtered image, same depth as the input (``ddepth=-1``).
    """
    default_kernel = [[-1, -1, -1], [-1, 9, -1], [-1, -1, -1]]
    active_kernel = np.array(default_kernel) if kernel is None else kernel
    return cv2.filter2D(image, -1, active_kernel)
def unsharp_mask(image, radius=5, amount=2):
    """Sharpen ``image`` by subtracting a Gaussian-blurred copy.

    Args:
        image: Input image.
        radius: Sigma passed to ``cv2.GaussianBlur`` (kernel size is derived
            from it because ``(0, 0)`` is given as the size).
        amount: Strength; the result is
            ``(1 + amount) * image - amount * blurred``.

    Returns:
        The sharpened image.
    """
    soft = cv2.GaussianBlur(image, (0, 0), radius)
    return cv2.addWeighted(image, 1.0 + amount, soft, -amount, 0)
def hist_eq(image):
    """Return the histogram-equalised image (``cv2.equalizeHist``)."""
    equalised = cv2.equalizeHist(image)
    return equalised
def enhance_image(image, enhancement_type):
    """Apply the enhancement selected by name.

    Args:
        image: Image handed to the chosen enhancement function.
        enhancement_type: One of ``"Invert"``, ``"High Pass Filter"``,
            ``"Unsharp Masking"``, ``"Histogram Equalization"`` or
            ``"CLAHE"``.

    Returns:
        The enhanced image.

    Raises:
        ValueError: If ``enhancement_type`` is not one of the names above.
    """
    # Lambdas keep the lookup lazy: only the selected function is touched.
    dispatch = (
        ("Invert", lambda img: invert(img)),
        ("High Pass Filter", lambda img: hp_filter(img)),
        ("Unsharp Masking", lambda img: unsharp_mask(img)),
        ("Histogram Equalization", lambda img: hist_eq(img)),
        ("CLAHE", lambda img: apply_clahe(img)),
    )
    for name, handler in dispatch:
        if enhancement_type == name:
            return handler(image)
    raise ValueError(f"Unknown enhancement type: {enhancement_type}")
# Function to add a button to redirect to the URL
def redirect_button(url):
    """Render a 'Go to OHIF Viewer' button that redirects the page to ``url``.

    When the button is pressed a ``<meta http-equiv="refresh">`` tag pointing
    at ``url`` is emitted through ``st.markdown`` with raw HTML enabled.
    """
    if st.button('Go to OHIF Viewer'):
        st.markdown(f'<meta http-equiv="refresh" content="0;url={url}" />', unsafe_allow_html=True)
###########################################################################################
########################### Bounding Box Function ###########################################
###########################################################################################
def predict(model, image):
    """ Predict bounding box and label for the input image.

    Args:
        model: Detection model returning ``(bbox, class_scores)`` from
            ``predict``.
        image: Either a single image array (H x W x C) or an already
            prepared 4-D batch. A single image is resized to ``(W, H)``,
            scaled to [-1, 1] and batched first.

    Returns:
        ``(bbox, label_index, confidence)`` for the first item in the batch.
    """
    if getattr(image, "ndim", 0) == 4:
        batch = image
    else:
        batch = _prepare_detection_input(image)
    pred_bbox, pred_class = model.predict(batch)
    pred_label_confidence = np.max(pred_class, axis=1)[0]
    pred_label = np.argmax(pred_class, axis=1)[0]
    return pred_bbox[0], pred_label, pred_label_confidence


def _prepare_detection_input(image):
    """Resize to (W, H), scale pixel values to [-1, 1] and add a batch axis."""
    # assumes the detector was trained on inputs prepared this way (mirrors
    # the commented-out preprocess_image helper) — TODO confirm
    resized = cv2.resize(image, (W, H))
    scaled = (resized.astype(np.float32) - 127.5) / 127.5
    return np.expand_dims(scaled, axis=0)


def draw_bbox(image, bbox):
    """ Draw bounding box on the image.

    Args:
        image: Image array, 2-D or H x W x C; it is drawn on in place.
        bbox: ``(x1, y1, x2, y2)`` as fractions of the image width/height.

    Returns:
        The image with the rectangle drawn.
    """
    # shape[:2] works for greyscale (2-D) as well as colour images.
    h, w = image.shape[:2]
    x1, y1, x2, y2 = bbox
    x1, y1, x2, y2 = int(x1 * w), int(y1 * h), int(x2 * w), int(y2 * h)
    image = cv2.rectangle(image, (x1, y1), (x2, y2), (255, 0, 0), 2)
    return image


###########################################################################################
########################### Streamlit Interface ###########################################
###########################################################################################
# Nothing below is rendered until a file has been uploaded in the sidebar.
if uploaded_file is not None:
    if hasattr(uploaded_file, 'name'):
        file_extension = uploaded_file.name.split(".")[-1]  # Get the file extension
        if file_extension.lower() == "dcm":
            # Process DICOM file
            dicom_data = pydicom.dcmread(uploaded_file)
            pixel_array = dicom_data.pixel_array

            # Extract all metadata
            metadata = {elem.keyword: elem.value for elem in dicom_data if elem.keyword}
            metadata_dict = {str(key): str(value) for key, value in metadata.items()}
            df = pd.DataFrame.from_dict(metadata_dict, orient='index', columns=['Value'])

            # Display metadata in an expander
            with st.expander("Lihat Metadata"):
                st.write("Metadata:")
                st.dataframe(df)

            # Normalize the pixel data to 0-255 for display
            img_array = np.maximum(pixel_array.astype(float), 0)
            peak = img_array.max()
            if peak > 0:  # an all-zero image would otherwise divide by zero
                img_array = img_array / peak * 255.0
            img_array = np.uint8(img_array)

            col1, col2 = st.columns(2)
            # Check the number of dimensions of the image
            if img_array.ndim == 3:
                # NOTE(review): axis 0 is treated as the slice axis here, but
                # the enhancement step below takes channel 0 of the last
                # axis — confirm the expected layout of 3-D DICOM input.
                n_slices = img_array.shape[0]
                if n_slices > 1:
                    slice_ix = st.sidebar.slider('Slice', 0, n_slices - 1, int(n_slices / 2))
                    # Display the selected slice
                    st.image(img_array[slice_ix, :, :], caption=f"Slice {slice_ix}", use_column_width=True)
                else:
                    # If there's only one slice, just display it
                    st.image(img_array[0, :, :], caption="Single Slice Image", use_column_width=True)
            elif img_array.ndim == 2:
                # If the image is 2D, just display it
                with col1:
                    st.image(img_array, caption="Original Image", use_column_width=True)
            else:
                st.error("Unsupported image dimensions")
            original_image = img_array

            if len(pixel_array.shape) > 2:
                pixel_array = pixel_array[:, :, 0]  # Take only the first channel
            # Perform image enhancement and evaluation on pixel_array
            enhanced_image, mse, psnr, maxerr, l2rat = process_image(pixel_array, enhancement_type)
        else:
            # Process regular image file
            original_image = np.array(keras.utils.load_img(uploaded_file, color_mode='rgb' if enhancement_type == "Invert" else 'grayscale'))
            # Perform image enhancement and evaluation on original_image
            enhanced_image, mse, psnr, maxerr, l2rat = process_image(original_image, enhancement_type)
            col1, col2 = st.columns(2)
            with col1:
                st.image(original_image, caption="Original Image", use_column_width=True)

        with col2:
            st.image(enhanced_image, caption='Enhanced Image', use_column_width=True)

        col1, col2 = st.columns(2)
        col3, col4 = st.columns(2)
        col1.metric("MSE", round(mse,3))
        col2.metric("PSNR", round(psnr,3))
        col3.metric("Maxerr", round(maxerr,3))
        col4.metric("L2Rat", round(l2rat,3))

        # Save enhanced image to a file
        enhanced_image_path = "enhanced_image.png"
        cv2.imwrite(enhanced_image_path, enhanced_image)

        # Save original image to a file. cv2.imwrite expects BGR channel
        # order, while a 3-channel original loaded above is RGB.
        original_image_path = "original_image.png"
        if original_image.ndim == 3 and original_image.shape[-1] == 3:
            cv2.imwrite(original_image_path, cv2.cvtColor(original_image, cv2.COLOR_RGB2BGR))
        else:
            cv2.imwrite(original_image_path, original_image)

        # Add the redirect button
        col1, col2, col3 = st.columns(3)
        with col1:
            redirect_button("https://new-ohif-viewer-k7c3gdlxua-et.a.run.app/")

        with col2:
            # getvalue() returns the whole upload regardless of how far the
            # stream was already read by the loaders above.
            file_bytes = np.frombuffer(uploaded_file.getvalue(), dtype=np.uint8)
            image = cv2.imdecode(file_bytes, cv2.IMREAD_COLOR)
            if image is None:
                # Not decodable by OpenCV (e.g. DICOM): fall back to the
                # 8-bit image that was already extracted.
                image = original_image
                if image.ndim == 2:
                    image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)

            if st.button('Auto Detect'):
                st.write("Processing...")
                # The detector receives the image itself, not the name of the
                # selected enhancement.
                pred_bbox, pred_label, pred_label_confidence = predict(model, image)

                # Updated label mapping based on the dataset
                label_mapping = {
                    0: 'Atelectasis',
                    1: 'Cardiomegaly',
                    2: 'Effusion',
                    3: 'Infiltrate',
                    4: 'Mass',
                    5: 'Nodule',
                    6: 'Pneumonia',
                    7: 'Pneumothorax'
                }

                if pred_label_confidence < 0.2:
                    st.write("May not detect a disease.")
                else:
                    pred_label_name = label_mapping[pred_label]
                    st.write(f"Prediction Label: {pred_label_name}")
                    st.write(f"Prediction Bounding Box: {pred_bbox}")
                    st.write(f"Prediction Confidence: {pred_label_confidence:.2f}")
                    output_image = draw_bbox(image.copy(), pred_bbox)
                    st.image(output_image, caption='Detected Image.', use_column_width=True)

        with col3:
            if st.button('Generate Grad-CAM'):
                st.write("Loading model...")
                model = load_model()
                # Compute and show Grad-CAM
                st.write("Generating Grad-CAM visualizations")
                try:
                    # Use the PNG saved above: it exists for every upload
                    # type, whereas the raw upload may be a DICOM file that
                    # the image loader cannot open.
                    compute_gradcam(model, original_image_path)
                except Exception as e:
                    st.error(f"Error generating Grad-CAM: {e}")