NickKolok commited on
Commit
c08ae31
·
verified ·
1 Parent(s): f72a4f5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -9
app.py CHANGED
@@ -3,25 +3,35 @@ import torch
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
  import os
 
 
6
 
7
- # Download the required model checkpoint if it doesn't exist
8
- MODEL_PATH = "weights/RealESRGAN_x4 Anime6B.pth"
9
- if not os.path.exists(MODEL_PATH):
10
- os.system(f"wget https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4_Anime6B.pth -P weights")
11
 
12
- # Define the Real-ESRGAN class using the pre-trained model path
 
 
 
 
 
 
 
 
 
13
  class RealESRGAN:
14
- def __init__(self, device):
15
- self.model = torch.hub.load("xinntao/Real-ESRGAN", "restoration", model="RealESRGAN_x4_Anime6B.pth", device=device)
16
  self.model.eval()
17
-
18
  def upscale(self, image_tensor):
19
  with torch.no_grad():
20
  sr_image_tensor = self.model(image_tensor).clamp(0, 1)
21
  return sr_image_tensor
22
 
23
  device = "cuda" if torch.cuda.is_available() else "cpu"
24
- model = RealESRGAN(device)
25
 
26
  # Define the image upscaling function
27
  def upscale_image(image):
 
3
  import torchvision.transforms as transforms
4
  from PIL import Image
5
  import os
6
+ import requests
7
+ from io import BytesIO
8
 
9
+ # Define the model URL and local path
10
+ MODEL_URL = "https://github.com/xinntao/Real-ESRGAN/releases/download/v0.2.2.4/RealESRGAN_x4_Anime6B.pth"
11
+ MODEL_PATH = "RealESRGAN_x4_Anime6B.pth"
 
12
 
13
+ # Download the model if it doesn't exist
14
+ if not os.path.exists(MODEL_PATH):
15
+ response = requests.get(MODEL_URL)
16
+ if response.status_code == 200:
17
+ with open(MODEL_PATH, 'wb') as f:
18
+ f.write(response.content)
19
+ else:
20
+ raise Exception(f"Could not download model: {response.status_code}")
21
+
22
+ # Define the Real-ESRGAN class
23
  class RealESRGAN:
24
+ def __init__(self, model_path, device):
25
+ self.model = torch.hub.load('xinntao/Real-ESRGAN', 'restoration', model=model_path, device=device)
26
  self.model.eval()
27
+
28
  def upscale(self, image_tensor):
29
  with torch.no_grad():
30
  sr_image_tensor = self.model(image_tensor).clamp(0, 1)
31
  return sr_image_tensor
32
 
33
  device = "cuda" if torch.cuda.is_available() else "cpu"
34
+ model = RealESRGAN(MODEL_PATH, device)
35
 
36
  # Define the image upscaling function
37
  def upscale_image(image):