import torch
from PIL import Image
from RealESRGAN import RealESRGAN
from flask import Flask, request, jsonify, send_file
import io
import logging
# Setup logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = Flask(__name__)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
logger.info(f'Using device: {device}')
# Pre-load one RealESRGAN model per supported upscaling factor (2x, 4x, 8x)
model2 = RealESRGAN(device, scale=2)
model2.load_weights('weights/RealESRGAN_x2.pth', download=True)
logger.info('Model x2 loaded successfully')

model4 = RealESRGAN(device, scale=4)
model4.load_weights('weights/RealESRGAN_x4.pth', download=True)
logger.info('Model x4 loaded successfully')

model8 = RealESRGAN(device, scale=8)
model8.load_weights('weights/RealESRGAN_x8.pth', download=True)
logger.info('Model x8 loaded successfully')
def inference(image, size):
    """Upscale a PIL image with the model matching `size` ('2x', '4x', or anything else for 8x)."""
    global model2, model4, model8
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        logger.info('CUDA cache cleared')
    logger.info(f'Starting inference with scale {size}')
    try:
        if size == '2x':
            result = model2.predict(image.convert('RGB'))
        elif size == '4x':
            result = model4.predict(image.convert('RGB'))
        else:
            # 8x is the most memory-hungry path, so reject very large inputs up front
            width, height = image.size
            if width >= 5000 or height >= 5000:
                return None, "The image is too large."
            result = model8.predict(image.convert('RGB'))
        logger.info(f'Inference completed for scale {size}')
    except torch.cuda.OutOfMemoryError as e:
        # On CUDA OOM, rebuild the affected model from its saved weights and retry once
        logger.error(f'OutOfMemoryError: {e}')
        logger.info(f'Reloading model for scale {size}')
        if size == '2x':
            model2 = RealESRGAN(device, scale=2)
            model2.load_weights('weights/RealESRGAN_x2.pth', download=False)
            result = model2.predict(image.convert('RGB'))
        elif size == '4x':
            model4 = RealESRGAN(device, scale=4)
            model4.load_weights('weights/RealESRGAN_x4.pth', download=False)
            result = model4.predict(image.convert('RGB'))
        else:
            model8 = RealESRGAN(device, scale=8)
            model8.load_weights('weights/RealESRGAN_x8.pth', download=False)
            result = model8.predict(image.convert('RGB'))
        logger.info(f'Model reloaded and inference completed for scale {size}')
    return result, None
# The route path below is an assumption; adjust it to match your deployment.
@app.route('/upscale', methods=['POST'])
def upscale():
    if 'image' not in request.files:
        logger.warning('No image uploaded')
        return jsonify({"error": "No image uploaded"}), 400

    image_file = request.files['image']
    size = request.form.get('size', '2x')

    try:
        image = Image.open(image_file)
        logger.info('Image uploaded and opened successfully')
    except Exception as e:
        logger.error(f'Invalid image file: {e}')
        return jsonify({"error": "Invalid image file"}), 400

    result, error = inference(image, size)
    if error:
        logger.error(f'Error during inference: {error}')
        return jsonify({"error": error}), 400

    # Stream the upscaled result back to the client as a PNG
    img_io = io.BytesIO()
    result.save(img_io, 'PNG')
    img_io.seek(0)
    logger.info('Image processing completed and ready to be sent back')
    return send_file(img_io, mimetype='image/png')
if __name__ == '__main__':
    logger.info('Starting the Flask server...')
    app.run(host='0.0.0.0', port=5000)
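
# Example client (a minimal sketch): it assumes the '/upscale' route declared above,
# the default port 5000, and the third-party 'requests' package, which is not a
# dependency of this app. The form fields 'image' and 'size' match the handler.
#
#   import requests
#
#   with open('input.jpg', 'rb') as f:
#       resp = requests.post(
#           'http://localhost:5000/upscale',
#           files={'image': f},
#           data={'size': '4x'},
#       )
#   resp.raise_for_status()
#   with open('upscaled.png', 'wb') as out:
#       out.write(resp.content)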