s3nh commited on
Commit
9702e1d
·
verified ·
1 Parent(s): 1c1372e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -4
app.py CHANGED
@@ -52,9 +52,13 @@ def _load_model_tokenizer():
52
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
53
  model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto",trust_remote_code=True, fp16=True).eval()
54
  return model, tokenizer
55
- whisper_model = load_whisper()
56
- tts_model, tokenizer_tss = load_tts()
57
- model, tokenizer = _load_model_tokenizer()
 
 
 
 
58
  def postprocess(self, y):
59
  if y is None:
60
  return []
@@ -96,7 +100,7 @@ def _parse_text(text):
96
  lines[i] = "<br>" + line
97
  text = "".join(lines)
98
  return text
99
-
100
  def predict(_query, _chatbot, _task_history):
101
  print(f"User: {_parse_text(_query)}")
102
  _chatbot.append((_parse_text(_query), ""))
 
52
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
53
  model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto",trust_remote_code=True, fp16=True).eval()
54
  return model, tokenizer
55
+
56
+ if torch.cuda.is_available():
57
+ whisper_model = load_whisper()
58
+ tts_model, tokenizer_tss = load_tts()
59
+ model, tokenizer = _load_model_tokenizer()
60
+
61
+
62
  def postprocess(self, y):
63
  if y is None:
64
  return []
 
100
  lines[i] = "<br>" + line
101
  text = "".join(lines)
102
  return text
103
+ @spaces.GPU
104
  def predict(_query, _chatbot, _task_history):
105
  print(f"User: {_parse_text(_query)}")
106
  _chatbot.append((_parse_text(_query), ""))