# Source snapshot: commit b84be05 ("tweak") by cm107.
import flask
import os
# Flask application object; page templates (index.html) live in ./templates/.
app = flask.Flask(import_name=__name__, template_folder="./templates/")
import numpy as np
import torch
from diffusers import LDMSuperResolutionPipeline
from diffusers.utils import PIL_INTERPOLATION, load_image, torch_device
from pkg.util import img_binary_data_to_pil, resizePilToMaxSide, pil_to_base64
# Inference device: defaults to the ``torch_device`` imported from
# diffusers.utils. Setting the SUPERRES_DEVICE environment variable (e.g.
# "cpu") overrides it, replacing a hard-coded ``if False:`` toggle that could
# only be flipped by editing this file.
torch_device = os.environ.get('SUPERRES_DEVICE', torch_device)
print(f'Running inference on {torch_device}')
# Keep CUDA matmuls in full fp32 precision (TF32 disabled).
torch.backends.cuda.matmul.allow_tf32 = False

# Load the pipeline, then place it explicitly with ``.to()``. No ``device_map``
# is passed: the ``.to(torch_device)`` call already decides placement, and
# newer diffusers releases reject a pipeline-level ``device_map="auto"`` and
# refuse ``.to()`` on a device-mapped pipeline.
ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution")
ldm.to(torch_device)

# Optional diagnostics, enabled by setting SUPERRES_DEBUG to any non-empty
# value (previously an ``if False:`` block).
if os.environ.get('SUPERRES_DEBUG'):
    print(f"{ldm.device=}")
    print(f"{type(ldm.components)=}")
    print(f"{ldm.components.keys()=}")
    print(f"{ldm.components['vqvae'].device=}")
    print(f"{ldm.components['unet'].device=}")
    print(f"{ldm.components['scheduler'].config=}")

# disable=None defers to tqdm, which turns the bar off on non-TTY output.
ldm.set_progress_bar_config(disable=None)
# Module-level generator seeded once at import time; every pipeline call that
# draws from it advances its state.
generator = torch.Generator(device=torch_device).manual_seed(0)
@app.route('/')
def index():
    """Render and return the ``index.html`` template for the root URL."""
    print('Route: /')
    page = flask.render_template('index.html')
    return page
def _superres_error(message, status_code=400):
    """Build the ``/superres`` JSON error payload and its matching HTTP status.

    The body uses the same keys as the success payload (``isError``,
    ``message``, ``statusCode``), and the HTTP status equals ``statusCode``.
    """
    payload = flask.jsonify(isError=True, message=message, statusCode=status_code)
    return payload, status_code


@app.route('/superres', methods=['POST'])
def superres():
    """Upscale the image posted as the raw request body with the LDM pipeline.

    The body is decoded with ``img_binary_data_to_pil``, resized with
    ``resizePilToMaxSide`` and run through ``ldm``.

    Query parameters:
        maxSideLength (int, default 100): passed to ``resizePilToMaxSide``
            before inference; must be positive.
        numIterations (int, default 20): ``num_inference_steps`` for the
            pipeline; must be positive.
        seed (int, default 0): seed for the per-request random generator.

    Returns:
        JSON with ``isError``, ``message`` and ``statusCode``. On success,
        ``data`` holds the result encoded by ``pil_to_base64``. An empty body
        or a non-positive parameter yields ``isError=True`` with HTTP 400.
    """
    print('Route: /superres')
    # No method check is needed: ``methods=['POST']`` makes Flask answer any
    # other method with 405 before this view runs.

    # get_data() returns the raw body even for form Content-Types (curl's
    # --data-binary defaults to application/x-www-form-urlencoded).
    # ``request.data`` is empty in that case because Werkzeug parses the body
    # as form data.
    imgBinary = flask.request.get_data()
    if not imgBinary:
        return _superres_error('Request body is empty; expected image bytes.')

    # With ``type=int``, unparsable values silently fall back to the default.
    maxSideLength = flask.request.args.get('maxSideLength', default=100, type=int)
    numIterations = flask.request.args.get('numIterations', default=20, type=int)
    seed = flask.request.args.get('seed', default=0, type=int)
    if maxSideLength <= 0:
        return _superres_error(f'maxSideLength must be positive, got {maxSideLength}.')
    if numIterations <= 0:
        return _superres_error(f'numIterations must be positive, got {numIterations}.')

    img = img_binary_data_to_pil(imgBinary)
    img = resizePilToMaxSide(img, maxSideLength=maxSideLength)

    # A fresh generator per request makes the output depend only on the
    # input, the parameters and ``seed``, not on how many requests came
    # before. It also keeps concurrent requests from sharing RNG state.
    request_generator = torch.Generator(device=torch_device).manual_seed(seed)
    result = ldm(
        image=img,
        generator=request_generator,
        num_inference_steps=numIterations,
        output_type="pil",
    ).images[0]

    resultEncoded = pil_to_base64(result)
    return flask.jsonify(
        isError=False,
        message='Success',
        statusCode=200,
        data=resultEncoded,
    )
if __name__ == '__main__':
    # Listen on every interface; the PORT variable, if set, overrides 7860.
    listen_port = int(os.getenv('PORT', '7860'))
    app.run(host='0.0.0.0', port=listen_port)