"""Flask front end for an LDM super-resolution pipeline.

Serves an HTML page at ``/`` and exposes ``POST /superres``, which takes raw
image bytes in the request body, downsizes the image, runs it through the
``LDMSuperResolutionPipeline`` and returns the result as JSON.
"""
import os

import flask
import numpy as np
import torch
from diffusers import LDMSuperResolutionPipeline
from diffusers.utils import PIL_INTERPOLATION, load_image, torch_device

from pkg.util import img_binary_data_to_pil, resizePilToMaxSide, pil_to_base64

app = flask.Flask(__name__, template_folder="./templates/")

# Debug toggle: change to ``if True`` to force inference onto the CPU.
if False:
    torch_device = 'cpu'
print(f'Running inference on {torch_device}')

torch.backends.cuda.matmul.allow_tf32 = False

# The pipeline is loaded once at import time and shared by every request.
# NOTE(review): passing device_map="auto" and then calling .to() may be
# rejected or redundant depending on the installed diffusers version - confirm.
ldm = LDMSuperResolutionPipeline.from_pretrained(
    "duongna/ldm-super-resolution", device_map="auto"
)
ldm.to(torch_device)

# Debug toggle: change to ``if True`` to dump pipeline/device details.
if False:
    print(f"{ldm.device=}")
    print(f"{type(ldm.components)=}")
    print(f"{ldm.components.keys()=}")
    print(f"{ldm.components['vqvae'].device=}")
    print(f"{ldm.components['unet'].device=}")
    print(f"{ldm.components['scheduler'].config=}")

ldm.set_progress_bar_config(disable=None)

# Seeded once at start-up and reused by every request, so its state advances
# from call to call.
# NOTE(review): this means identical requests are presumably not reproducible
# after the first one - confirm whether that is intended.
generator = torch.Generator(device=torch_device).manual_seed(0)


def _error_response(message, status_code=400):
    """Return a ``(response, status)`` pair describing a failed request.

    The JSON body uses the same ``isError`` / ``message`` / ``statusCode``
    keys as the success payload so clients can handle both uniformly.
    """
    payload = flask.jsonify(isError=True, message=message, statusCode=status_code)
    return payload, status_code


@app.route('/')
def index():
    """Render the landing page from ``index.html``."""
    print('Route: /')
    return flask.render_template('index.html')


@app.route('/superres', methods=['POST'])
def superres():
    """Upscale the image sent in the request body.

    Request:
        body: raw image bytes.
        query ``maxSideLength`` (int, default 100): passed to
            ``resizePilToMaxSide`` before inference.
        query ``numIterations`` (int, default 20): passed to the pipeline as
            ``num_inference_steps``.

    Returns:
        JSON with ``isError``, ``message``, ``statusCode`` and, on success,
        ``data`` holding the output of ``pil_to_base64``. Invalid requests get
        an error payload with HTTP status 400.
    """
    print('Route: /superres')

    # No manual method check: the route is registered for POST only, so Flask
    # answers any other verb with 405 before this function is entered.
    imgBinary = flask.request.data
    if not imgBinary:
        return _error_response('Request body is empty; expected raw image bytes.')

    # With ``type=int`` a value that fails conversion falls back to the default.
    maxSideLength = flask.request.args.get('maxSideLength', default=100, type=int)
    numIterations = flask.request.args.get('numIterations', default=20, type=int)
    if maxSideLength < 1 or numIterations < 1:
        return _error_response(
            'maxSideLength and numIterations must be positive integers.'
        )

    img = img_binary_data_to_pil(imgBinary)
    img = resizePilToMaxSide(img, maxSideLength=maxSideLength)

    output = ldm(
        image=img,
        generator=generator,
        num_inference_steps=numIterations,
        output_type="pil",
    )
    result = output.images[0]
    resultBinary = pil_to_base64(result)

    return flask.jsonify(
        isError=False,
        message='Success',
        statusCode=200,
        data=resultBinary
    )


if __name__ == '__main__':
    app.run(host='0.0.0.0', port=int(os.environ.get('PORT', 7860)))