import time
import os
import re
import torch
import torchaudio
import gradio as gr
import spaces
from transformers import AutoFeatureExtractor, AutoTokenizer, WhisperForConditionalGeneration, WhisperProcessor, pipeline
from huggingface_hub import model_info
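
# Prefer FlashAttention 2 when the optional flash-attn package is installed;
# otherwise the model falls back to PyTorch's SDPA attention.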
try:
    import flash_attn
    FLASH_ATTENTION = True
except ImportError:
    FLASH_ATTENTION = False
import yt_dlp  # for downloading audio from YouTube URLs
MODEL_NAME = "NbAiLab/nb-whisper-large"
max_audio_length = 30 * 60  # maximum input length in seconds (30 minutes)

# share becomes True when SHARE starts with "t", "y" or "1"; otherwise None (Gradio's default).
share = (os.environ.get("SHARE", "False")[0].lower() in "ty1") or None
# Use AUTH_TOKEN if set, else fall back to the locally cached Hugging Face token.
auth_token = os.environ.get("AUTH_TOKEN") or True

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print(f"Bruker enhet: {device}")
@spaces.GPU(duration=60 * 2)
def pipe(file, return_timestamps=False, lang="no"):
    asr = pipeline(
        task="automatic-speech-recognition",
        model=MODEL_NAME,
        chunk_length_s=28,
        device=device,
        token=auth_token,
        torch_dtype=torch.float16,
        model_kwargs={"attn_implementation": "flash_attention_2" if FLASH_ATTENTION else "sdpa"},
    )
    asr.model.config.forced_decoder_ids = asr.tokenizer.get_decoder_prompt_ids(
        language=lang,
        task="transcribe",
        no_timestamps=not return_timestamps,
    )
    # num_beams and language are generation options, so they belong in generate_kwargs,
    # not model_kwargs (the original sdpa branch silently dropped the language setting).
    return asr(
        file,
        return_timestamps=return_timestamps,
        batch_size=24,
        generate_kwargs={"task": "transcribe", "language": lang, "num_beams": 5},
    )
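
# pipe() returns the standard ASR-pipeline dict: {"text": ...}, plus a "chunks" list
# (each chunk holding "text" and a (start, end) "timestamp") when return_timestamps=True.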
def format_output(text):
    # Insert a line break tag after ellipses and sentence-ending punctuation.
    text = re.sub(r'(\.{3,}|[.!:?])', lambda m: m.group() + '<br>', text)
    return text
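
# Example (hypothetical input): format_output("Hei. Hvordan går det?")
# -> 'Hei.<br> Hvordan går det?<br>'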
def transcribe(file, return_timestamps=False, lang_nn=False):
    waveform, sample_rate = torchaudio.load(file)
    audio_duration = waveform.size(1) / sample_rate
    warning_message = None
    if audio_duration > max_audio_length:
        warning_message = (
            "⚠️ Advarsel: "
            "Lydfilen er lengre enn 30 minutter. Kun de første 30 minuttene vil bli transkribert."
        )
        # Keep only the first 30 minutes and transcribe from a temporary file.
        waveform = waveform[:, :int(max_audio_length * sample_rate)]
        truncated_file = "truncated_audio.wav"
        torchaudio.save(truncated_file, waveform, sample_rate)
        file_to_transcribe = truncated_file
        truncated = True
    else:
        file_to_transcribe = file
        truncated = False
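    # Bokmål ("no") is the default; lang_nn switches decoding to Nynorsk ("nn").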
    lang = "nn" if lang_nn else "no"
    if not return_timestamps:
        text = pipe(file_to_transcribe, lang=lang)["text"]
        formatted_text = format_output(text)
    else:
        chunks = pipe(file_to_transcribe, return_timestamps=True, lang=lang)["chunks"]
        lines = []
        for chunk in chunks:
            start, end = chunk["timestamp"]
            start_time = time.strftime("%H:%M:%S", time.gmtime(start)) if start is not None else "??:??:??"
            end_time = time.strftime("%H:%M:%S", time.gmtime(end)) if end is not None else "??:??:??"
            lines.append(f"[{start_time} -> {end_time}] {chunk['text']}")
        formatted_text = "<br>".join(lines)
    # Write a plain-text copy, converting the HTML line breaks back to newlines.
    output_file = "transcription.txt"
    with open(output_file, "w") as f:
        f.write(re.sub("<br>", "\n", formatted_text))
    if truncated:
        link = "https://github.com/NbAiLab/nostram/blob/main/leverandorer.md"
        disclaimer = (
            "\n\n Dette er en demo. Det er ikke tillatt å bruke denne teksten i profesjonell sammenheng. "
            "Vi anbefaler at hvis du trenger å transkribere lengre opptak, så kjører du enten modellen lokalt "
            f'eller sjekker <a href="{link}">denne siden</a> for å se hvem som leverer løsninger basert på NB-Whisper.'
        )
        formatted_text += f"<br>{disclaimer}"
    formatted_text += "<br>Transkribert med NB-Whisper demo"
    return warning_message, formatted_text, output_file
def _return_yt_html_embed(yt_url):
    video_id = yt_url.split("?v=")[-1]
    # Standard YouTube iframe embed for display in the Gradio UI.
    HTML_str = (
        f'<center> <iframe width="500" height="320" src="https://www.youtube.com/embed/{video_id}"> </iframe> </center>'
    )
    return HTML_str