Jonny001 commited on
Commit
7be1ca4
·
verified ·
1 Parent(s): 31eea9e

Upload Image_Enhancer.ipynb

Browse files
Files changed (1) hide show
  1. Image_Enhancer.ipynb +148 -0
Image_Enhancer.ipynb ADDED
@@ -0,0 +1,148 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "d8d2437e",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": [
10
+ "import torch\n",
11
+ "from PIL import Image\n",
12
+ "from RealESRGAN import RealESRGAN\n",
13
+ "import gradio as gr\n",
14
+ "import numpy as np\n",
15
+ "import tempfile\n",
16
+ "import time\n",
17
+ "\n",
18
+ "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n"
19
+ ]
20
+ },
21
+ {
22
+ "cell_type": "code",
23
+ "execution_count": null,
24
+ "id": "871d9b94",
25
+ "metadata": {},
26
+ "outputs": [],
27
+ "source": [
28
+ "def load_model(scale):\n",
29
+ " model = RealESRGAN(device, scale=scale)\n",
30
+ " weights_path = f'weights/RealESRGAN_x{scale}.pth'\n",
31
+ " try:\n",
32
+ " model.load_weights(weights_path, download=True)\n",
33
+ " print(f\"Weights for scale {scale} loaded successfully.\")\n",
34
+ " except Exception as e:\n",
35
+ " print(f\"Error loading weights for scale {scale}: {e}\")\n",
36
+ " model.load_weights(weights_path, download=False)\n",
37
+ " return model\n",
38
+ "\n",
39
+ "model2 = load_model(2)\n",
40
+ "model4 = load_model(4)\n",
41
+ "model8 = load_model(8)\n"
42
+ ]
43
+ },
44
+ {
45
+ "cell_type": "code",
46
+ "execution_count": null,
47
+ "id": "c891d4b3",
48
+ "metadata": {},
49
+ "outputs": [],
50
+ "source": [
51
+ "def enhance_image(image, scale):\n",
52
+ " try:\n",
53
+ " print(f\"Enhancing image with scale {scale}...\")\n",
54
+ " start_time = time.time()\n",
55
+ " image_np = np.array(image.convert('RGB'))\n",
56
+ " print(f\"Image converted to numpy array: shape {image_np.shape}, dtype {image_np.dtype}\")\n",
57
+ " \n",
58
+ " if scale == '2x':\n",
59
+ " result = model2.predict(image_np)\n",
60
+ " elif scale == '4x':\n",
61
+ " result = model4.predict(image_np)\n",
62
+ " else:\n",
63
+ " result = model8.predict(image_np)\n",
64
+ " \n",
65
+ " enhanced_image = Image.fromarray(np.uint8(result))\n",
66
+ " print(f\"Image enhanced in {time.time() - start_time:.2f} seconds\")\n",
67
+ " return enhanced_image\n",
68
+ " except Exception as e:\n",
69
+ " print(f\"Error enhancing image: {e}\")\n",
70
+ " return image\n"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "id": "9073bff6",
77
+ "metadata": {},
78
+ "outputs": [],
79
+ "source": [
80
+ "def muda_dpi(input_image, dpi):\n",
81
+ " dpi_tuple = (dpi, dpi)\n",
82
+ " image = Image.fromarray(input_image.astype('uint8'), 'RGB')\n",
83
+ " temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')\n",
84
+ " image.save(temp_file, format='PNG', dpi=dpi_tuple)\n",
85
+ " temp_file.close()\n",
86
+ " return Image.open(temp_file.name)\n",
87
+ "\n",
88
+ "def resize_image(input_image, width, height):\n",
89
+ " image = Image.fromarray(input_image.astype('uint8'), 'RGB')\n",
90
+ " resized_image = image.resize((width, height))\n",
91
+ " temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')\n",
92
+ " resized_image.save(temp_file, format='PNG')\n",
93
+ " temp_file.close()\n",
94
+ " return Image.open(temp_file.name)\n"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": null,
100
+ "id": "e470926d",
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "def process_image(input_image, enhance, scale, adjust_dpi, dpi, resize, width, height):\n",
105
+ " original_image = Image.fromarray(input_image.astype('uint8'), 'RGB')\n",
106
+ " \n",
107
+ " if enhance:\n",
108
+ " original_image = enhance_image(original_image, scale)\n",
109
+ " \n",
110
+ " if adjust_dpi:\n",
111
+ " original_image = muda_dpi(np.array(original_image), dpi)\n",
112
+ " \n",
113
+ " if resize:\n",
114
+ " original_image = resize_image(np.array(original_image), width, height)\n",
115
+ " \n",
116
+ " temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')\n",
117
+ " original_image.save(temp_file.name)\n",
118
+ " return original_image, temp_file.name\n",
119
+ "\n",
120
+ "iface = gr.Interface(\n",
121
+ " fn=process_image,\n",
122
+ " inputs=[\n",
123
+ " gr.Image(label=\"Upload\"),\n",
124
+ " gr.Checkbox(label=\"Enhance Image\"),\n",
125
+ " gr.Radio(['2x', '4x', '8x'], type=\"value\", value='2x', label='Select Resolution model'),\n",
126
+ " gr.Checkbox(label=\"Apply DPI\"),\n",
127
+ " gr.Number(label=\"DPI\", value=300),\n",
128
+ " gr.Checkbox(label=\"Apply Resize\"),\n",
129
+ " gr.Number(label=\"Width\", value=512),\n",
130
+ " gr.Number(label=\"Height\", value=512)\n",
131
+ " ],\n",
132
+ " outputs=[\n",
133
+ " gr.Image(label=\"Final Image\"),\n",
134
+ " gr.File(label=\"Download Final Image\")\n",
135
+ " ],\n",
136
+ " title=\"Image Enhancer\",\n",
137
+ " description=\"Sorry for the inconvenience. The model is currently running on the CPU, which might affect performance. We appreciate your understanding.\",\n",
138
+ " theme=\"Yntec/HaleyCH_Theme_Orange\"\n",
139
+ ")\n",
140
+ "\n",
141
+ "iface.launch(debug=True)\n"
142
+ ]
143
+ }
144
+ ],
145
+ "metadata": {},
146
+ "nbformat": 4,
147
+ "nbformat_minor": 5
148
+ }