Spaces:
Runtime error
Runtime error
import os

import flask
import numpy as np
import torch
from diffusers import LDMSuperResolutionPipeline
from diffusers.utils import PIL_INTERPOLATION, load_image
from pkg.util import img_binary_data_to_pil, resizePilToMaxSide, pil_to_base64

try:
    # `torch_device` comes from diffusers' *testing* utilities. Older releases
    # re-exported it from `diffusers.utils`; newer ones do not. When it is
    # missing, fall back to CUDA if available and the CPU otherwise.
    from diffusers.utils import torch_device
except ImportError:
    torch_device = "cuda" if torch.cuda.is_available() else "cpu"

app = flask.Flask(__name__, template_folder="./templates/")


def _env_flag(name):
    """Return True when environment variable *name* is 1/true/yes (case-insensitive)."""
    return os.environ.get(name, "").strip().lower() in ("1", "true", "yes")


# FORCE_CPU=1 pins inference to the CPU (replaces a hard-coded `if False:` toggle).
if _env_flag("FORCE_CPU"):
    torch_device = 'cpu'
print(f'Running inference on {torch_device}')

# Keep CUDA matmuls in full FP32 instead of allowing TF32.
torch.backends.cuda.matmul.allow_tf32 = False

# Place the whole pipeline on one device with an explicit `.to()`.
# `device_map="auto"` was dropped because it conflicts with the `.to()` call.
# Recent diffusers releases reject "auto" as a pipeline device map, and they
# also reject `.to()` on a device-mapped pipeline.
ldm = LDMSuperResolutionPipeline.from_pretrained("duongna/ldm-super-resolution")
ldm.to(torch_device)

# DEBUG_PIPELINE=1 prints where each pipeline component was placed
# (replaces a hard-coded `if False:` block).
if _env_flag("DEBUG_PIPELINE"):
    print(f"{ldm.device=}")
    print(f"{type(ldm.components)=}")
    print(f"{ldm.components.keys()=}")
    print(f"{ldm.components['vqvae'].device=}")
    print(f"{ldm.components['unet'].device=}")
    print(f"{ldm.components['scheduler'].config=}")

# NOTE(review): disable=None is presumably forwarded to tqdm, which then hides
# the bar on non-TTY output. Confirm against the installed diffusers version.
ldm.set_progress_bar_config(disable=None)

# Shared generator, seeded once at import time. It is kept for any code that
# still references it; request handlers should seed their own generator.
generator = torch.Generator(device=torch_device).manual_seed(0)
@app.route('/')
def index():
    """Serve the front-end page, rendered from ``index.html`` in the app's template folder."""
    print('Route: /')
    return flask.render_template('index.html')
def _superres_error(message, status_code):
    """Build the JSON error payload for ``superres``, paired with a matching HTTP status."""
    return (
        flask.jsonify(isError=True, message=message, statusCode=status_code),
        status_code,
    )


# GET is accepted on purpose so non-POST callers receive the JSON error below
# instead of Flask's default HTML 405 page.
@app.route('/superres', methods=['GET', 'POST'])
def superres():
    """Upscale the image sent in the request body and return it base64-encoded.

    The raw request body holds the encoded image bytes, which are decoded by
    ``img_binary_data_to_pil``. Optional query parameters:

    * ``maxSideLength`` (int, default 100): passed to ``resizePilToMaxSide``
      before inference.
    * ``numIterations`` (int, default 20): ``num_inference_steps`` for the
      pipeline.
    * ``seed`` (int, default 0): seed for a generator created per request.
      Identical requests therefore get identical results, regardless of how
      many requests were served before.

    Returns:
        On success, JSON with ``isError=False``, ``message``, ``statusCode=200``
        and ``data`` (the result of ``pil_to_base64`` on the upscaled image).
        On failure, JSON with ``isError=True``, ``message`` and ``statusCode``,
        sent with that HTTP status (405 for non-POST, 400 for bad input).
    """
    print('Route: /superres')
    if flask.request.method != 'POST':
        return _superres_error(
            f"This route doesn't support {flask.request.method} method.", 405
        )

    imgBinary = flask.request.data
    if not imgBinary:
        return _superres_error("Request body must contain image data.", 400)

    # `type=int` makes Flask fall back to the default when a value is not an int.
    maxSideLength = flask.request.args.get('maxSideLength', default=100, type=int)
    numIterations = flask.request.args.get('numIterations', default=20, type=int)
    seed = flask.request.args.get('seed', default=0, type=int)
    if maxSideLength < 1 or numIterations < 1:
        return _superres_error(
            "maxSideLength and numIterations must be positive integers.", 400
        )
    # torch.Generator.manual_seed raises for values outside the 64-bit range.
    if not 0 <= seed <= 0xFFFF_FFFF_FFFF_FFFF:
        return _superres_error("seed must be an unsigned 64-bit integer.", 400)

    img = img_binary_data_to_pil(imgBinary)
    img = resizePilToMaxSide(img, maxSideLength=maxSideLength)

    # Use a fresh generator per request. A single module-level generator would
    # make each output depend on the number of earlier requests, and it could
    # be advanced concurrently if requests are served on multiple threads.
    requestGenerator = torch.Generator(device=torch_device).manual_seed(seed)
    result = ldm(
        image=img,
        generator=requestGenerator,
        num_inference_steps=numIterations,
        output_type="pil",
    ).images[0]

    resultBinary = pil_to_base64(result)
    return flask.jsonify(
        isError=False,
        message='Success',
        statusCode=200,
        data=resultBinary
    )
if __name__ == '__main__':
    # Listen on every interface; the port is read from $PORT, falling back to 7860.
    listenPort = int(os.environ.get('PORT', 7860))
    app.run(host='0.0.0.0', port=listenPort)