zhiqiulin committed · Commit 1112f1b · verified · 1 Parent(s): d258d19

Update app.py

Files changed (1):
  app.py +11 -3
app.py CHANGED
@@ -6,14 +6,21 @@ torch.jit.script = lambda f: f # Avoid script error in lambda
 from t2v_metrics import VQAScore, list_all_vqascore_models
 
 # Global model variable, but do not initialize or move to CUDA here
-model_pipe = VQAScore(model="clip-flant5-xl", device="cuda") # our recommended scoring model
+cur_model_name = "clip-flant5-xl"
+model_pipe = update_model(cur_model_name)
+
+def update_model(model_name):
+    if model_nm
+    return VQAScore(model=model_name, device="cuda")
 
 @spaces.GPU(duration = 20)
 def generate(model_name, image, text):
-    print(list_all_vqascore_models()) # Debug: List available models
+    if model_name != cur_model_name:
+        model_pipe = update_model(model_name)
+
     print("Image:", image) # Debug: Print image path
     print("Text:", text) # Debug: Print text input
-    print("Generating!")
+    print("Using model:", model_name)
     # Wrap the model call in a try-except block to capture and debug CUDA errors
     try:
         result = model_pipe(images=[image], texts=[text]).cpu()[0][0].item() # Perform the model inference
@@ -26,6 +33,7 @@ def generate(model_name, image, text):
 
 demo = gr.Interface(
     fn=generate, # function to call
+    # ['clip-flant5-xxl', 'clip-flant5-xl', 'clip-flant5-xxl-no-system', 'clip-flant5-xxl-no-system-no-user', 'llava-v1.5-13b', 'llava-v1.5-7b', 'sharegpt4v-7b', 'sharegpt4v-13b', 'llava-v1.6-13b', 'instructblip-flant5-xxl', 'instructblip-flant5-xl']
     inputs=[gr.Dropdown(["clip-flant5-xl", "clip-flant5-xxl"], label="Model Name"), gr.Image(type="filepath"), gr.Textbox(label="Prompt")], # define the types of inputs
     outputs="number", # define the type of output
     title="VQAScore", # title of the app