Ahsen Khaliq commited on
Commit
178d84c
1 Parent(s): d0b8090

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +9 -7
app.py CHANGED
@@ -1,5 +1,6 @@
1
  import os
2
  os.system("gdown https://drive.google.com/uc?id=14pXWwB4Zm82rsDdvbGguLfx9F8aM7ovT")
 
3
  import clip
4
  import os
5
  from torch import nn
@@ -32,7 +33,6 @@ TA = Union[T, ARRAY]
32
  D = torch.device
33
  CPU = torch.device('cpu')
34
 
35
- model_path = 'conceptual_weights.pt'
36
 
37
  def get_device(device_id: int) -> D:
38
  if not torch.cuda.is_available():
@@ -234,14 +234,16 @@ prefix_length = 10
234
 
235
  model = ClipCaptionModel(prefix_length)
236
 
237
- model.load_state_dict(torch.load(model_path, map_location=CPU))
238
-
239
- model = model.eval()
240
- device = CUDA(0) if is_gpu else "cpu"
241
  model = model.to(device)
242
 
243
 
244
- def inference(img):
 
 
 
 
 
 
245
  use_beam_search = False
246
  image = io.imread(img.name)
247
  pil_image = PIL.Image.fromarray(image)
@@ -262,7 +264,7 @@ article = "<p style='text-align: center'><a href='https://github.com/rmokady/CLI
262
  examples=[['water.jpeg']]
263
  gr.Interface(
264
  inference,
265
- gr.inputs.Image(type="file", label="Input"),
266
  gr.outputs.Textbox(label="Output"),
267
  title=title,
268
  description=description,
 
1
  import os
2
  os.system("gdown https://drive.google.com/uc?id=14pXWwB4Zm82rsDdvbGguLfx9F8aM7ovT")
3
+ os.system("gdown https://drive.google.com/uc?id=1IdaBtMSvtyzF0ByVaBHtvM0JYSXRExRX")
4
  import clip
5
  import os
6
  from torch import nn
 
33
  D = torch.device
34
  CPU = torch.device('cpu')
35
 
 
36
 
37
  def get_device(device_id: int) -> D:
38
  if not torch.cuda.is_available():
 
234
 
235
  model = ClipCaptionModel(prefix_length)
236
 
 
 
 
 
237
  model = model.to(device)
238
 
239
 
240
+ def inference(img,model):
241
+ if model == "COCO":
242
+ model_path = 'coco_weights.pt'
243
+ else:
244
+ model_path = 'conceptual_weights.pt'
245
+ model.load_state_dict(torch.load(model_path, map_location=CPU))
246
+
247
  use_beam_search = False
248
  image = io.imread(img.name)
249
  pil_image = PIL.Image.fromarray(image)
 
264
  examples=[['water.jpeg']]
265
  gr.Interface(
266
  inference,
267
+ [gr.inputs.Image(type="file", label="Input"),gr.inputs.Radio(choices["COCO","Conceptual captions"], type="value", default="COCO", label="Model")],
268
  gr.outputs.Textbox(label="Output"),
269
  title=title,
270
  description=description,