jlonsako committed on
Commit
da7f7e0
1 Parent(s): f09834e

Update to chunking and half precision

Browse files
Files changed (1) hide show
  1. app.py +51 -20
app.py CHANGED
@@ -31,11 +31,15 @@ def Transcribe(file):
31
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
32
  start_time = time.time()
33
  model.load_adapter("amh")
 
34
 
35
  preprocessAudio(file)
36
- block_size = 30 #30 second chunks of audio
 
37
 
38
  transcripts = []
 
 
39
  stream = librosa.stream(
40
  "./audioToConvert.wav",
41
  block_length=block_size,
@@ -43,9 +47,8 @@ def Transcribe(file):
43
  hop_length=16000
44
  )
45
 
46
- model.half()
47
  model.to(device)
48
- print(f"Model loaded to {device}: Entering transcription phase")
49
 
50
  #Code for timestamping
51
  encoding_start = 0
@@ -54,42 +57,70 @@ def Transcribe(file):
54
  for speech_segment in stream:
55
  if len(speech_segment.shape) > 1:
56
  speech_segment = speech_segment[:,0] + speech_segment[:,1]
57
- input_values = processor(speech_segment, sampling_rate=16_000, return_tensors="pt").input_values.to(device)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
  input_values = input_values.half()
59
  with torch.no_grad():
60
  logits = model(input_values).logits
61
- if len(logits.shape) == 1:
62
- logits = logits.unsqueeze(0)
63
- transcription = processor.batch_decode(logits.cpu().numpy()).text
64
- transcripts.append(transcription[0])
65
 
66
- #Generate timestamps
67
- encoding_end = encoding_start + block_size
68
- formatted_start = format_time(encoding_start)
69
- formatted_end = format_time(encoding_end)
70
-
71
- #Write to the .sbv file
72
- sbv_file.write(f"{formatted_start},{formatted_end}\n")
73
- sbv_file.write(f"{transcription[0]}\n\n")
74
- encoding_start = encoding_end
75
 
76
  # Freeing up memory
77
  del input_values
78
  del logits
79
- del transcription
80
  torch.cuda.empty_cache()
81
  gc.collect()
82
 
 
83
  # Join all transcripts into a single transcript
84
  transcript = ' '.join(transcripts)
85
  sbv_file.close()
86
 
87
  end_time = time.time()
88
- os.system("rm ./audio.wav")
89
  print(f"The script ran for {end_time - start_time} seconds.")
90
  return("./subtitle.sbv")
91
 
92
  demo = gr.Interface(fn=Transcribe, inputs=gr.File(), outputs="file")
93
  demo.launch()
94
 
95
-
 
31
  device = "cuda:0" if torch.cuda.is_available() else "cpu"
32
  start_time = time.time()
33
  model.load_adapter("amh")
34
+ model.half()
35
 
36
  preprocessAudio(file)
37
+ block_size = 30
38
+ batch_size = 22 # or whatever number you choose
39
 
40
  transcripts = []
41
+ speech_segments = []
42
+
43
  stream = librosa.stream(
44
  "./audioToConvert.wav",
45
  block_length=block_size,
 
47
  hop_length=16000
48
  )
49
 
 
50
  model.to(device)
51
+ print("Model loaded to gpu: Entering transcription phase")
52
 
53
  #Code for timestamping
54
  encoding_start = 0
 
57
  for speech_segment in stream:
58
  if len(speech_segment.shape) > 1:
59
  speech_segment = speech_segment[:,0] + speech_segment[:,1]
60
+ speech_segments.append(speech_segment)
61
+
62
+ if len(speech_segments) == batch_size:
63
+ input_values = processor(speech_segments, sampling_rate=16_000, return_tensors="pt", padding=True).input_values.to(device)
64
+ input_values = input_values.half()
65
+ with torch.no_grad():
66
+ logits = model(input_values).logits
67
+ if len(logits.shape) == 1:
68
+ logits = logits.unsqueeze(0)
69
+ #predicted_ids = torch.argmax(logits, dim=-1)
70
+ transcriptions = processor.batch_decode(logits.cpu().numpy()).text
71
+ transcripts.extend(transcriptions)
72
+
73
+ # Write to the .sbv file
74
+ for i, transcription in enumerate(transcriptions):
75
+ encoding_start = (i * block_size)
76
+ encoding_end = encoding_start + block_size
77
+ formatted_start = format_time(encoding_start)
78
+ formatted_end = format_time(encoding_end)
79
+ sbv_file.write(f"{formatted_start},{formatted_end}\n")
80
+ sbv_file.write(f"{transcription}\n\n")
81
+
82
+ # Clear the batch
83
+ speech_segments = []
84
+
85
+ # Freeing up memory
86
+ del input_values
87
+ del logits
88
+ del transcriptions
89
+ torch.cuda.empty_cache()
90
+ gc.collect()
91
+
92
+ if speech_segments:
93
+ input_values = processor(speech_segments, sampling_rate=16_000, return_tensors="pt", padding=True).input_values.to(device)
94
  input_values = input_values.half()
95
  with torch.no_grad():
96
  logits = model(input_values).logits
97
+ transcriptions = processor.batch_decode(logits.cpu().numpy()).text
98
+ transcripts.extend(transcriptions)
 
 
99
 
100
+ for i in range(len(speech_segments)):
101
+ encoding_end = encoding_start + block_size
102
+ formatted_start = format_time(encoding_start)
103
+ formatted_end = format_time(encoding_end)
104
+ sbv_file.write(f"{formatted_start},{formatted_end}\n")
105
+ sbv_file.write(f"{transcriptions[i]}\n\n")
106
+ encoding_start = encoding_end
 
 
107
 
108
  # Freeing up memory
109
  del input_values
110
  del logits
111
+ del transcriptions
112
  torch.cuda.empty_cache()
113
  gc.collect()
114
 
115
+
116
  # Join all transcripts into a single transcript
117
  transcript = ' '.join(transcripts)
118
  sbv_file.close()
119
 
120
  end_time = time.time()
 
121
  print(f"The script ran for {end_time - start_time} seconds.")
122
  return("./subtitle.sbv")
123
 
124
  demo = gr.Interface(fn=Transcribe, inputs=gr.File(), outputs="file")
125
  demo.launch()
126