Kartikeyssj2 committed on
Commit 78ab44f
1 Parent(s): c033def

Update fast_api.py

Files changed (1)
  1. fast_api.py +24 -10
fast_api.py CHANGED
@@ -8,16 +8,32 @@ import pickle
 import numpy as np
 from gensim.models import KeyedVectors
 
+def load_whisper_model(model_path, device='cpu'):
+    # Load model architecture
+    model = whisper.model.Whisper(
+        whisper.model.ModelDimensions(
+            n_mels=80,
+            n_audio_ctx=1500,
+            n_audio_state=384,
+            n_audio_head=6,
+            n_audio_layer=4,
+            n_vocab=51865,
+            n_text_ctx=448,
+            n_text_state=384,
+            n_text_head=6,
+            n_text_layer=4
+        )
+    )
+
+    # Load state dict
+    state_dict = torch.load(model_path, map_location=device)
+    model.load_state_dict(state_dict)
+    model.eval()
+    return model
+
 # Load the saved Word2Vec model
 word2vec_model = KeyedVectors.load("word2vec-google-news-300.model")
-
-model = whisper.load_model("tiny")
-
-# Load the saved state dictionary
-model_state = torch.load("whisper_tiny_model.pt")
-
-# Load the state dictionary into the model
-model.load_state_dict(model_state)
+model = load_whisper_model("whisper_tiny_model.pt")
 
 def load_model(pickle_file_path: str):
     """Load a model from a pickle file."""
@@ -25,12 +41,10 @@ def load_model(pickle_file_path: str):
     model = pickle.load(file)
     return model
 
-
 pronunciation_fluency_model = load_model("pronunciation_fluency_v2.pkl")
 
 app = FastAPI()
 
-
 def transcribe(audio_file_path: str, model):
     # Load audio and run inference
     result = model.transcribe(audio_file_path)
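
In the new version, instead of calling whisper.load_model("tiny") and then overwriting its weights, fast_api.py rebuilds the tiny architecture from whisper.model.ModelDimensions (the hard-coded values match the standard multilingual "tiny" configuration) and restores a locally stored state dict on the chosen device before switching to eval mode. Below is a minimal sketch of how a checkpoint like whisper_tiny_model.pt could be produced for this loader; the export step is an assumption and is not part of the commit.

# Hypothetical one-time export that would produce a state-dict checkpoint
# compatible with load_whisper_model() above (an assumption; the commit only
# consumes the file and does not show how it was created).
import torch
import whisper

tiny = whisper.load_model("tiny", device="cpu")    # official multilingual tiny weights
print(tiny.dims)                                   # same values as the hard-coded ModelDimensions
torch.save(tiny.state_dict(), "whisper_tiny_model.pt")

# load_whisper_model("whisper_tiny_model.pt") can then rebuild the architecture
# and restore these weights without re-downloading the checkpoint at startup.

Keeping only the state dict and rebuilding the module from ModelDimensions means the API does not depend on Whisper's download cache at startup; map_location=device lets a checkpoint saved from a GPU load on CPU, and model.eval() switches the module to inference mode.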