|
from flask import Flask, request, jsonify, render_template, send_file
|
|
from werkzeug.utils import secure_filename
|
|
import torch
|
|
from torchvision import transforms
|
|
import tensorflow as tf
|
|
from PIL import Image
|
|
import numpy as np
|
|
import io
|
|
import base64
|
|
import cv2
|
|
from model import tumor_model
|
|
|
|
|
|
# The Flask application object; routes below register themselves on it.
app = Flask(import_name=__name__)
|
|
|
|
|
|
# Model checkpoint locations, resolved relative to the working directory the
# app is started from. Forward slashes are accepted on Windows as well as on
# Linux/macOS; backslash separators only resolve on Windows.
CLASSIFICATION_MODEL_PATH = 'models/tumor_model_statedict_f.pth'  # PyTorch state_dict
SEGMENTATION_MODEL_PATH = 'models/unet_model.h5'  # Keras model file (HDF5)
|
|
|
|
|
|
class MultiTaskModelWrapper:
    """Run a PyTorch tumour classifier and a Keras segmentation model on one image.

    Both models are loaded once, at construction time, from
    ``CLASSIFICATION_MODEL_PATH`` and ``SEGMENTATION_MODEL_PATH``.
    """

    # Classifier output index -> human-readable label.
    # NOTE(review): order must match the label encoding used at training time.
    CLASS_NAMES = (
        'Glioma Tumor',
        'Meningioma Tumor',
        'No Tumor',
        'Pituitary Tumor',
    )

    # Spatial input size fed to each model.
    CLASSIFICATION_SIZE = (224, 224)
    SEGMENTATION_SIZE = (128, 128)

    # Segmentation outputs strictly above this value are marked foreground.
    MASK_THRESHOLD = 0.5

    def __init__(self):
        self.segmentation_model = self.load_segmentation_model()
        self.classification_model = self.load_classification_model()
        # Built once here rather than on every predict() call.
        self.classification_transform = transforms.Compose([
            transforms.Resize(self.CLASSIFICATION_SIZE),
            transforms.ToTensor(),
        ])

    def load_segmentation_model(self):
        """Load and return the Keras segmentation model from SEGMENTATION_MODEL_PATH."""
        # custom_objects maps the serialized 'conv2d_transpose' name onto the
        # Keras layer class so deserialization can resolve it.
        model = tf.keras.models.load_model(
            SEGMENTATION_MODEL_PATH,
            custom_objects={'conv2d_transpose': tf.keras.layers.Conv2DTranspose},
        )
        return model

    def load_classification_model(self):
        """Load weights into the shared ``tumor_model`` and return it in eval mode."""
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints (torch>=1.13 supports weights_only=True for state dicts).
        state_dict = torch.load(CLASSIFICATION_MODEL_PATH, map_location=torch.device('cpu'))
        tumor_model.load_state_dict(state_dict)
        tumor_model.eval()  # inference mode: disables dropout / batch-norm updates
        return tumor_model

    def predict(self, image):
        """Classify and segment a single image.

        Args:
            image: a PIL image. It is converted to RGB here, matching what the
                /predict route does, so greyscale or RGBA inputs still yield a
                3-channel tensor for the classifier.

        Returns:
            A tuple ``(input_image_base64, mask_image_base64, class_name,
            probability)``: the input image and the binary mask as base64 PNG
            strings, the predicted class name, and its softmax probability.
        """
        image = image.convert('RGB')
        class_name, probability = self._classify(image)
        mask = self._segment(image)
        return (
            self._encode_png(image),
            self._encode_png(Image.fromarray(mask)),
            class_name,
            probability,
        )

    def _classify(self, image):
        """Return ``(class_name, probability)`` for the classifier's top prediction."""
        img_tensor = self.classification_transform(image).unsqueeze(0)  # add batch dim
        with torch.no_grad():
            logits = self.classification_model(img_tensor)
        probabilities = torch.nn.functional.softmax(logits, dim=1)
        class_index = int(torch.argmax(probabilities, dim=1)[0])
        probability = probabilities[0, class_index].item()
        return self.CLASS_NAMES[class_index], probability

    def _segment(self, image):
        """Return the thresholded segmentation mask as a 2-D uint8 array of 0/255."""
        img_array = np.array(image.resize(self.SEGMENTATION_SIZE))
        img_array = np.expand_dims(img_array, axis=0) / 255.0  # batch of 1, scaled to [0, 1]
        # verbose=0: otherwise Keras prints a progress bar for every request.
        segmentation_output = self.segmentation_model.predict(img_array, verbose=0)
        # NOTE(review): assumes channels-last output (batch, H, W, C); only channel 0 is used.
        foreground = segmentation_output[0, :, :, 0] > self.MASK_THRESHOLD
        return np.where(foreground, 255, 0).astype(np.uint8)

    @staticmethod
    def _encode_png(pil_image):
        """Serialize a PIL image to a base64-encoded PNG string."""
        buffer = io.BytesIO()
        pil_image.save(buffer, format='PNG')
        return base64.b64encode(buffer.getvalue()).decode('utf-8')


# Built once at import time so every request reuses the loaded models.
model_wrapper = MultiTaskModelWrapper()
|
|
|
|
@app.route('/')
def home():
    """Serve the upload page."""
    page = 'index.html'
    return render_template(page)
|
|
|
|
@app.route('/predict', methods=['POST'])
def predict():
    """Classify and segment an image uploaded in the ``file`` form field.

    On success, returns JSON with ``input_image`` and ``mask_image`` (base64 PNG
    strings), ``class_label`` and ``probability``. On failure, returns JSON
    ``{'error': ...}``:
    - HTTP 400 when the upload is missing or cannot be decoded as an image;
    - HTTP 500 when the model pipeline fails. The traceback is logged
      server-side and not sent to the client.
    """
    if 'file' not in request.files:
        return jsonify({'error': 'No file part'}), 400

    upload = request.files['file']
    if upload.filename == '':
        return jsonify({'error': 'No selected file'}), 400

    # Decoding untrusted bytes: a failure here is a bad upload, not a server fault.
    try:
        image = Image.open(upload.stream).convert('RGB')
    except (OSError, Image.DecompressionBombError) as e:
        return jsonify({'error': str(e)}), 400

    try:
        input_image_base64, mask_image_base64, class_label, probability = model_wrapper.predict(image)
    except Exception:
        app.logger.exception('Prediction failed')
        return jsonify({'error': 'Prediction failed'}), 500

    return jsonify({
        'input_image': input_image_base64,
        'mask_image': mask_image_base64,
        'class_label': class_label,
        'probability': probability,
    })
|
|
|
|
if __name__ == '__main__':
    import os

    # The Werkzeug debugger permits arbitrary code execution from the browser,
    # so debug mode is off unless explicitly requested via FLASK_DEBUG.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').strip().lower() in {'1', 'true', 'yes'}
    app.run(debug=debug_enabled)
|
|
|