xiaoming32236046 commited on
Commit
6065bc2
1 Parent(s): 90ed8b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +78 -32
app.py CHANGED
@@ -1,45 +1,91 @@
1
- import gradio as gr
2
  import numpy as np
3
- from PIL import Image, ImageDraw
4
  import cv2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- def calculate_snr(image, roi):
7
- # Convert image to grayscale if it's not already
8
- if len(image.shape) == 3:
9
- image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
10
- roi = cv2.cvtColor(roi, cv2.COLOR_BGR2GRAY)
11
 
12
- # Create a mask for the ROI
13
- mask = np.zeros_like(image)
14
- cv2.fillPoly(mask, [roi], 255)
15
 
16
- # Calculate signal and noise
17
- signal = np.mean(image[mask > 0])
18
- noise = np.std(image[mask == 0])
19
 
20
- # Calculate SNR
21
- snr = signal / noise
 
22
 
23
- return signal, noise, snr
 
24
 
25
- def process_image(input_image, roi_points):
26
- # Load the input image
27
- image = Image.open(input_image)
28
- image = np.array(image)
29
 
30
- # Load the ROI image
31
- roi_image = Image.new('L', (image.shape[1], image.shape[0]))
32
- draw = ImageDraw.Draw(roi_image)
33
- draw.polygon(roi_points, fill=255)
34
- roi_image = np.array(roi_image)
 
 
 
 
35
 
36
- # Calculate SNR for each channel
37
  results = []
38
- for i in range(image.shape[2]):
39
- signal, noise, snr = calculate_snr(image[:, :, i], roi_image)
40
- results.append((signal, noise, snr))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- return results, roi_image
 
43
 
44
- iface = gr.Interface(fn=process_image, inputs=[gr.inputs.Image(), gr.inputs.Image()], outputs=[gr.outputs.Textbox(), gr.outputs.Image()], title="图像SNR计算器")
45
- iface.launch()
 
 
1
  import numpy as np
 
2
  import cv2
3
+ import matplotlib.pyplot as plt
4
+ import gradio as gr
5
+ import tempfile
6
+ import os
7
+ from PIL import Image, ImageDraw
8
+
9
+ def calculate_snr(channel, rois):
10
+ signal_mask = np.zeros(channel.shape, np.uint8)
11
+ for roi in rois:
12
+ cv2.fillPoly(signal_mask, [np.array(roi)], 255)
13
+ signal = cv2.bitwise_and(channel, channel, mask=signal_mask)
14
+
15
+ background_mask = cv2.bitwise_not(signal_mask)
16
+ background = cv2.bitwise_and(channel, channel, mask=background_mask)
17
+
18
+ signal_mean = np.mean(signal[signal_mask == 255])
19
+ background_std = np.std(background[background_mask == 255])
20
 
21
+ snr = signal_mean / background_std if background_std != 0 else float('inf')
 
 
 
 
22
 
23
+ return signal_mean, background_std, snr, signal, background
 
 
24
 
25
+ def process_image(image, roi_sketch):
26
+ if image is None or roi_sketch is None:
27
+ return None, None
28
 
29
+ # Convert sketch to binary mask
30
+ mask = np.array(roi_sketch)
31
+ _, binary_mask = cv2.threshold(mask, 1, 255, cv2.THRESH_BINARY)
32
 
33
+ # Find contours in the binary mask
34
+ contours, _ = cv2.findContours(binary_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
35
 
36
+ # Convert contours to list of ROIs
37
+ rois = [contour.squeeze().tolist() for contour in contours if len(contour) > 2]
 
 
38
 
39
+ # Convert image to numpy array
40
+ img_array = np.array(image)
41
+
42
+ if len(img_array.shape) == 2:
43
+ channels = [img_array]
44
+ channel_names = ['灰度']
45
+ else:
46
+ channels = cv2.split(img_array)
47
+ channel_names = ['Red', 'Green', 'Blue']
48
 
 
49
  results = []
50
+ plt.figure(figsize=(5*len(channels), 10))
51
+
52
+ for i, (channel, name) in enumerate(zip(channels, channel_names)):
53
+ signal_mean, background_std, snr, signal, background = calculate_snr(channel, rois)
54
+
55
+ results.append(f"{name} channel:\n"
56
+ f"信号平均强度(Signal): {signal_mean:.2f}\n"
57
+ f"背景标准差(Noise): {background_std:.2f}\n"
58
+ f"信噪比(SNR): {snr:.2f}\n")
59
+
60
+ plt.subplot(2, len(channels), i+1)
61
+ plt.imshow(signal, cmap='gray')
62
+ plt.title(f'{name} Signal ROI')
63
+ plt.subplot(2, len(channels), i+1+len(channels))
64
+ plt.imshow(background, cmap='gray')
65
+ plt.title(f'{name} Background ROI')
66
+
67
+ result_filename = tempfile.mktemp(suffix='.png')
68
+ plt.savefig(result_filename)
69
+ plt.close()
70
+
71
+ return ("\n".join(results), result_filename)
72
+
73
+ with gr.Blocks() as demo:
74
+ gr.Markdown("# 图像SNR计算器")
75
+ gr.Markdown("上传一张图像,在第二个框中绘制多个感兴趣区域,结果会实时更新。支持单通道和多通道图像。")
76
+
77
+ with gr.Row():
78
+ input_image = gr.Image(label="上传图像", type="numpy")
79
+ roi_sketch = gr.Image(label="绘制ROI", source="upload", tool="sketch", type="numpy", interactive=True)
80
+
81
+ with gr.Row():
82
+ result_text = gr.Textbox(label="结果")
83
+ result_image = gr.Image(label="ROI 图像")
84
+
85
+ inputs = [input_image, roi_sketch]
86
+ outputs = [result_text, result_image]
87
 
88
+ input_image.change(process_image, inputs=inputs, outputs=outputs)
89
+ roi_sketch.change(process_image, inputs=inputs, outputs=outputs)
90
 
91
+ demo.launch()