Arifzyn19 commited on
Commit
6cd4595
·
1 Parent(s): 3d5c454

Add application file

Browse files
Files changed (1) hide show
  1. app.py +100 -0
app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from PIL import Image
3
+ from RealESRGAN import RealESRGAN
4
+ from flask import Flask, request, jsonify, send_file
5
+ import io
6
+ import logging
7
+
8
+ # Setup logging
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+
12
+ app = Flask(__name__)
13
+
14
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
15
+ logger.info(f'Using device: {device}')
16
+
17
+ model2 = RealESRGAN(device, scale=2)
18
+ model2.load_weights('weights/RealESRGAN_x2.pth', download=True)
19
+ logger.info('Model x2 loaded successfully')
20
+
21
+ model4 = RealESRGAN(device, scale=4)
22
+ model4.load_weights('weights/RealESRGAN_x4.pth', download=True)
23
+ logger.info('Model x4 loaded successfully')
24
+
25
+ model8 = RealESRGAN(device, scale=8)
26
+ model8.load_weights('weights/RealESRGAN_x8.pth', download=True)
27
+ logger.info('Model x8 loaded successfully')
28
+
29
+ def inference(image, size):
30
+ global model2, model4, model8
31
+
32
+ if torch.cuda.is_available():
33
+ torch.cuda.empty_cache()
34
+ logger.info('CUDA cache cleared')
35
+
36
+ logger.info(f'Starting inference with scale {size}')
37
+
38
+ try:
39
+ if size == '2x':
40
+ result = model2.predict(image.convert('RGB'))
41
+ elif size == '4x':
42
+ result = model4.predict(image.convert('RGB'))
43
+ else:
44
+ width, height = image.size
45
+ if width >= 5000 or height >= 5000:
46
+ return None, "The image is too large."
47
+ result = model8.predict(image.convert('RGB'))
48
+ logger.info(f'Inference completed for scale {size}')
49
+ except torch.cuda.OutOfMemoryError as e:
50
+ logger.error(f'OutOfMemoryError: {e}')
51
+ logger.info(f'Reloading model for scale {size}')
52
+
53
+ if size == '2x':
54
+ model2 = RealESRGAN(device, scale=2)
55
+ model2.load_weights('weights/RealESRGAN_x2.pth', download=False)
56
+ result = model2.predict(image.convert('RGB'))
57
+ elif size == '4x':
58
+ model4 = RealESRGAN(device, scale=4)
59
+ model4.load_weights('weights/RealESRGAN_x4.pth', download=False)
60
+ result = model4.predict(image.convert('RGB'))
61
+ else:
62
+ model8 = RealESRGAN(device, scale=8)
63
+ model8.load_weights('weights/RealESRGAN_x8.pth', download=False)
64
+ result = model8.predict(image.convert('RGB'))
65
+ logger.info(f'Model reloaded and inference completed for scale {size}')
66
+
67
+ return result, None
68
+
69
+ @app.route('/upscale', methods=['POST'])
70
+ def upscale():
71
+ if 'image' not in request.files:
72
+ logger.warning('No image uploaded')
73
+ return jsonify({"error": "No image uploaded"}), 400
74
+
75
+ image_file = request.files['image']
76
+ size = request.form.get('size', '2x')
77
+
78
+ try:
79
+ image = Image.open(image_file)
80
+ logger.info(f'Image uploaded and opened successfully')
81
+ except Exception as e:
82
+ logger.error(f'Invalid image file: {e}')
83
+ return jsonify({"error": "Invalid image file"}), 400
84
+
85
+ result, error = inference(image, size)
86
+
87
+ if error:
88
+ logger.error(f'Error during inference: {error}')
89
+ return jsonify({"error": error}), 400
90
+
91
+ img_io = io.BytesIO()
92
+ result.save(img_io, 'PNG')
93
+ img_io.seek(0)
94
+ logger.info('Image processing completed and ready to be sent back')
95
+
96
+ return send_file(img_io, mimetype='image/png')
97
+
98
+ if __name__ == '__main__':
99
+ logger.info('Starting the Flask server...')
100
+ app.run(host='0.0.0.0', port=5000)