File size: 3,475 Bytes
8e0b903
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2
import os

from chexnet import ChexNet
from unet import Unet
from heatmap import HeatmapGenerator
from constant import IMAGENET_MEAN, IMAGENET_STD, CLASS_NAMES

import sys
# Add this script's own directory to the module search path, so sibling modules
# resolve no matter what the current working directory is. This is presumably
# needed for the `chestXray_utils` import that follows; confirm.
# Joining with '.' is effectively a no-op: the appended path refers to script_dir itself.
# sys.path.append puts it at the END of the search path, so same-named modules
# found earlier on sys.path take precedence.
script_dir = os.path.dirname(os.path.abspath(__file__))
imgto3d_path = os.path.join(script_dir, '.')
sys.path.append(imgto3d_path)

from chestXray_utils import blend_segmentation
import torch
import pandas as pd


# Directory where process_image() writes its PNG results and prediction CSV.
# NOTE(review): this path is relative to the current working directory, not to
# this script's directory. Confirm that is intended when launched elsewhere.
output_dir = "pages/images"
os.makedirs(output_dir, exist_ok=True)

# Checkpoint identifiers passed to the model constructors as `model_name`.
# They look like timestamped training runs; verify against the model store.
unet_model = '20190211-101020'
chexnet_model = '20180429-130928'
# Class labels as a NumPy array, so they can be fancy-indexed with the index
# array produced by np.argsort in process_image().
DISEASES = np.array(CLASS_NAMES)


# Initialize models once at import time; every process_image() call reuses them.
unet = Unet(trained=True, model_name=unet_model)
chexnet = ChexNet(trained=True, model_name=chexnet_model)
# Class-activation-map generator built on top of the classifier.
heatmap_generator = HeatmapGenerator(chexnet, mode='cam')
# Put both networks in evaluation mode for inference. This assumes they follow
# the torch.nn.Module eval() convention; confirm in unet.py / chexnet.py.
unet.eval()
chexnet.eval()


def _to_uint8(arr):
    """Linearly rescale ``arr`` so it spans 0-255 and return it as a uint8 array.

    A constant array (max == min) has no range to stretch. Instead of dividing
    by zero, which produces NaNs and garbage pixels, it maps to all zeros.
    """
    arr = np.asarray(arr, dtype=np.float64)
    lo, hi = arr.min(), arr.max()
    if hi == lo:
        return np.zeros(arr.shape, dtype=np.uint8)
    return ((arr - lo) / (hi - lo) * 255).astype(np.uint8)


def _paste_into_frame(patch, box, width, height):
    """Return a ``height`` x ``width`` zero array with ``patch`` copied into ``box``.

    ``box`` is ``(left, top, right, bottom)`` in frame pixel coordinates, and
    ``patch`` must be ``(bottom - top, right - left)`` in size. Any part of the
    box that lies outside the frame is clipped away.
    """
    left, top, right, bottom = box
    frame = np.zeros((height, width), dtype=patch.dtype)
    x0, y0 = max(left, 0), max(top, 0)
    x1, y1 = min(right, width), min(bottom, height)
    if x1 > x0 and y1 > y0:
        frame[y0:y1, x0:x1] = patch[y0 - top:y1 - top, x0 - left:x1 - left]
    return frame


def _top_predictions(prob, diseases, k=10):
    """Map the ``k`` most probable class names to their probabilities.

    Probabilities are formatted as strings with 3 significant digits. The
    returned dict is ordered by descending probability.
    """
    prob = np.asarray(prob)
    idx = np.argsort(-prob)[:k]
    return {name: f'{p:.3}' for name, p in zip(diseases[idx], prob[idx])}


def process_image(image_path):
    """Segment, classify and localize findings on a single chest X-ray.

    The lung bounding box found by the U-Net is cropped and classified by
    CheXNet. Three files are written into ``output_dir``, overwriting any
    previous run:

    * ``segment_result.png``: segmentation overlay with the crop box drawn on it.
    * ``cam_result.png``: class-activation heatmap blended over the image.
    * ``prediction_results.csv``: the ten most probable diseases.

    Args:
        image_path: Path of the image file to analyse (any format PIL can open).

    Returns:
        ``(result, segment_result_path, cam_result_path)``. ``result`` is
        ``{'result': {disease: probability_str}}``, ordered by descending
        probability. The two paths point at the PNG files written above.
    """
    # The context manager closes the file handle, even for multi-frame formats.
    with Image.open(image_path) as source:
        image = source.convert('RGB')

    # Segment the lungs, then classify only the bounding-box crop.
    (t, l, b, r), mask = unet.segment(image)
    # cv2 drawing and array slicing need plain ints.
    t, l, b, r = (int(round(v)) for v in (t, l, b, r))
    cropped_image = image.crop((l, t, r, b))
    prob = chexnet.predict(cropped_image)

    # Save segmentation result. It is written with plt.imsave, which interprets
    # channels as RGB, so (255, 0, 0) draws a red box.
    blended = _to_uint8(blend_segmentation(image, mask))
    cv2.rectangle(blended, (l, t), (r, b), (255, 0, 0), 5)
    segment_result_path = os.path.join(output_dir, 'segment_result.png')
    plt.imsave(segment_result_path, blended)

    # Save CAM result. The heatmap describes the cropped region only, so put it
    # back at the crop's position instead of stretching it over the whole
    # image (stretching misaligns it with the anatomy).
    w, h = cropped_image.size
    heatmap, _ = heatmap_generator.from_prob(prob, w, h)
    heatmap = _to_uint8(cv2.resize(heatmap, (w, h)))
    heatmap = _paste_into_frame(heatmap, (l, t, r, b), image.width, image.height)
    colored = cv2.applyColorMap(heatmap, cv2.COLORMAP_JET)  # 3-channel BGR
    # OpenCV works in BGR. Convert the RGB image so both blend inputs share the
    # same channel order and cv2.imwrite stores the correct colours.
    background = cv2.cvtColor(np.array(image), cv2.COLOR_RGB2BGR)
    cam = cv2.addWeighted(colored, 0.4, background, 0.6, 0)
    cam_result_path = os.path.join(output_dir, 'cam_result.png')
    cv2.imwrite(cam_result_path, cam)

    # Top-10 diseases, returned to the caller and also persisted as CSV.
    prediction = _top_predictions(prob, DISEASES, k=10)
    result = {'result': prediction}
    df = pd.DataFrame(list(prediction.items()), columns=['Disease', 'Probability'])
    df.to_csv(os.path.join(output_dir, 'prediction_results.csv'), index=False)

    return result, segment_result_path, cam_result_path


# if __name__ == '__main__':
#     image_path = r'E:\NLP\KN2024\chestX-ray-14\src\fibrosis.jpg'  # Replace with your image path
#     result, segment_result_path, cam_result_path = process_image(image_path)
#     print("Prediction Results:", result)
#     print(f"Segmentation Result Saved to: {segment_result_path}")
#     print(f"CAM Result Saved to: {cam_result_path}")