NickKolok commited on
Commit
ea89a8b
·
verified ·
1 Parent(s): 55425f6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -6
app.py CHANGED
@@ -4,7 +4,7 @@ import torchvision.transforms as transforms
4
  from PIL import Image
5
  import os
6
  import requests
7
- from io import BytesIO
8
 
9
  # Define the model URL and local path
10
  MODEL_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
@@ -22,8 +22,9 @@ if not os.path.exists(MODEL_PATH):
22
  # Define the Real-ESRGAN class
23
  class RealESRGAN:
24
  def __init__(self, model_path, device):
25
- self.model = torch.hub.load('xinntao/Real-ESRGAN', 'restoration', model=model_path, device=device)
26
- self.model.eval()
 
27
 
28
  def upscale(self, image_tensor):
29
  with torch.no_grad():
@@ -44,15 +45,15 @@ def upscale_image(image):
44
  sr_image_tensor = model.upscale(image_tensor)
45
 
46
  # Post-process the output image and convert back to PIL
47
- sr_image = (sr_image_tensor.squeeze(0).numpy() * 255).astype(np.uint8)
48
  sr_image = Image.fromarray(sr_image.transpose(1, 2, 0)) # Change to HWC format
49
  return sr_image
50
 
51
  # Create a Gradio interface
52
  interface = gr.Interface(
53
  fn=upscale_image,
54
- inputs=gr.inputs.Image(type="pil"),
55
- outputs=gr.outputs.Image(type="pil"),
56
  title="Image Upscaler with R-ESRGAN Anime 6B",
57
  description="Upload an image, and it'll be upscaled using the R-ESRGAN Anime model. Supports inputs in various formats."
58
  )
 
4
  from PIL import Image
5
  import os
6
  import requests
7
+ import numpy as np
8
 
9
  # Define the model URL and local path
10
  MODEL_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
 
22
  # Define the Real-ESRGAN class
23
  class RealESRGAN:
24
  def __init__(self, model_path, device):
25
+ # Load the model using torch.hub
26
+ self.model = torch.hub.load('xinntao/Real-ESRGAN', 'real_esrgan', model=model_path, force_reload=True)
27
+ self.model.eval().to(device) # Put model in evaluation mode and move to device
28
 
29
  def upscale(self, image_tensor):
30
  with torch.no_grad():
 
45
  sr_image_tensor = model.upscale(image_tensor)
46
 
47
  # Post-process the output image and convert back to PIL
48
+ sr_image = (sr_image_tensor.squeeze(0).cpu().numpy() * 255).astype(np.uint8)
49
  sr_image = Image.fromarray(sr_image.transpose(1, 2, 0)) # Change to HWC format
50
  return sr_image
51
 
52
  # Create a Gradio interface
53
  interface = gr.Interface(
54
  fn=upscale_image,
55
+ inputs=gr.Image(type="pil"),
56
+ outputs=gr.Image(type="pil"),
57
  title="Image Upscaler with R-ESRGAN Anime 6B",
58
  description="Upload an image, and it'll be upscaled using the R-ESRGAN Anime model. Supports inputs in various formats."
59
  )