NickKolok commited on
Commit
789ee84
·
verified ·
1 Parent(s): 22250a9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +24 -51
app.py CHANGED
@@ -1,63 +1,36 @@
 
1
  import gradio as gr
2
  import torch
3
- import torchvision.transforms as transforms
4
  from PIL import Image
5
- import os
6
- import requests
7
- import numpy as np
8
-
9
- # Define the model URL and local path
10
- MODEL_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4plus_anime_6B.pth"
11
- MODEL_PATH = "RealESRGAN_x4_Anime6B.pth"
12
-
13
- # Download the model if it doesn't exist
14
- if not os.path.exists(MODEL_PATH):
15
- response = requests.get(MODEL_URL)
16
- if response.status_code == 200:
17
- with open(MODEL_PATH, 'wb') as f:
18
- f.write(response.content)
19
- else:
20
- raise Exception(f"Could not download model: {response.status_code}")
21
-
22
- # Define the Real-ESRGAN class
23
- class RealESRGAN:
24
- def __init__(self, model_path, device):
25
- self.model = torch.hub.load('xinntao/Real-ESRGAN', 'real_esrgan', model_path, device=device, trust_repo=True)
26
- self.model.eval() # Set the model to eval mode
27
- self.device = device # Store the device
28
-
29
- def upscale(self, image_tensor):
30
- with torch.no_grad():
31
- sr_image_tensor = self.model(image_tensor).clamp(0, 1)
32
- return sr_image_tensor
33
 
34
- device = "cuda" if torch.cuda.is_available() else "cpu"
35
- model = RealESRGAN(MODEL_PATH, device)
36
 
37
- # Define the image upscaling function
38
- def upscale_image(image):
39
- # Pre-process the input image
40
- image = image.convert("RGB")
41
- transform = transforms.ToTensor()
42
- image_tensor = transform(image).unsqueeze(0).to(device) # Add batch dimension and move to device
43
 
44
- # Perform upscaling
45
- sr_image_tensor = model.upscale(image_tensor)
 
46
 
47
- # Post-process the output image and convert back to PIL
48
- sr_image = (sr_image_tensor.squeeze(0).cpu().numpy() * 255).astype(np.uint8)
49
- sr_image = Image.fromarray(sr_image.transpose(1, 2, 0)) # Change to HWC format
50
- return sr_image
51
 
52
- # Create a Gradio interface
53
- interface = gr.Interface(
54
  fn=upscale_image,
55
- inputs=gr.Image(type="pil"),
56
- outputs=gr.Image(type="pil"),
57
- title="Image Upscaler with R-ESRGAN Anime 6B",
58
- description="Upload an image, and it'll be upscaled using the R-ESRGAN Anime model. Supports inputs in various formats."
 
 
 
59
  )
60
 
61
- # Launch the interface
62
  if __name__ == "__main__":
63
- interface.launch()
 
1
+ import os
2
  import gradio as gr
3
  import torch
 
4
  from PIL import Image
5
+ from torchvision.transforms import ToTensor, ToPILImage
6
+ from torchvision import models
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
 
8
+ # Load the R-ESRGAN Anime model (you need to set up this model properly)
9
+ model = torch.hub.load('xinntao/Real-ESRGAN', 'restorer', model='R-ESRGAN_Anime_X6')
10
 
11
+ def upscale_image(image, scale_factor):
12
+ # Ensure image is in PIL format
13
+ if not isinstance(image, Image.Image):
14
+ image = Image.fromarray(image)
 
 
15
 
16
+ # Upscale using the model
17
+ with torch.no_grad():
18
+ upscaled_image = model(image, scale=scale_factor)
19
 
20
+ return upscaled_image
 
 
 
21
 
22
+ # Create Gradio interface
23
+ iface = gr.Interface(
24
  fn=upscale_image,
25
+ inputs=[
26
+ gr.inputs.Image(type="pil", label="Input Image"),
27
+ gr.inputs.Slider(minimum=1, maximum=4, step=1, default=2, label="Scale Factor")
28
+ ],
29
+ outputs=gr.outputs.Image(type="pil", label="Upscaled Image"),
30
+ title="R-ESRGAN Anime 6B Image Upscaler",
31
+ description="Upload an image and select a scale factor to upscale the image using R-ESRGAN Anime 6B model."
32
  )
33
 
34
+ # Launch the Gradio app
35
  if __name__ == "__main__":
36
+ iface.launch()