xianbao HF staff commited on
Commit
72583bd
1 Parent(s): a9c5082

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -5
app.py CHANGED
@@ -20,7 +20,7 @@ import spaces
20
  import gradio as gr
21
  from qwen_vl_utils import process_vision_info
22
  from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, TextIteratorStreamer
23
-
24
  DEFAULT_CKPT_PATH = 'Qwen/Qwen2-VL-7B-Instruct'
25
 
26
 
@@ -50,10 +50,12 @@ def _get_args():
50
 
51
 
52
  def _load_model_processor(args):
53
- if args.cpu_only:
54
- device_map = 'cpu'
55
- else:
56
- device_map = 'auto'
 
 
57
 
58
  # default: Load the model on the available device(s)
59
  # model = Qwen2VLForConditionalGeneration.from_pretrained(args.checkpoint_path, device_map=device_map)
 
20
  import gradio as gr
21
  from qwen_vl_utils import process_vision_info
22
  from transformers import AutoProcessor, Qwen2VLForConditionalGeneration, TextIteratorStreamer
23
+ import torch
24
  DEFAULT_CKPT_PATH = 'Qwen/Qwen2-VL-7B-Instruct'
25
 
26
 
 
50
 
51
 
52
  def _load_model_processor(args):
53
+ # if args.cpu_only:
54
+ # device_map = 'cpu'
55
+ # else:
56
+ # device_map = 'auto'
57
+
58
+ device_map = "cuda" if torch.cuda.is_available() else "cpu"
59
 
60
  # default: Load the model on the available device(s)
61
  # model = Qwen2VLForConditionalGeneration.from_pretrained(args.checkpoint_path, device_map=device_map)