s3nh commited on
Commit
bdd07df
·
verified ·
1 Parent(s): 129d868

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -28,7 +28,7 @@ This Space demonstrates LAIONBOT functionalities,
28
 
29
 
30
  def load_whisper():
31
- return whisper.load_model("medium")
32
 
33
  def load_tts():
34
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-pol")
@@ -52,9 +52,10 @@ def _load_model_tokenizer():
52
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
53
  model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto",trust_remote_code=True, fp16=True).eval()
54
  return model, tokenizer
55
-
56
  if torch.cuda.is_available():
57
- whisper_model = load_whisper()
 
58
  tts_model, tokenizer_tss = load_tts()
59
  model, tokenizer = _load_model_tokenizer()
60
 
 
28
 
29
 
30
  def load_whisper():
31
+ return whisper.load_model("medium", device = 'cpu')
32
 
33
  def load_tts():
34
  tts_model = VitsModel.from_pretrained("facebook/mms-tts-pol")
 
52
  tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
53
  model = AutoModelForCausalLM.from_pretrained(model_id, device_map="auto",trust_remote_code=True, fp16=True).eval()
54
  return model, tokenizer
55
+ whisper_model = load_whisper()
56
  if torch.cuda.is_available():
57
+ whisper_model = whisper_model.to(device='cuda')
58
+ #whisper_model = load_whisper()
59
  tts_model, tokenizer_tss = load_tts()
60
  model, tokenizer = _load_model_tokenizer()
61