|
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor |
|
import torchaudio |
|
import torch |
|
|
|
|
|
# Hugging Face Hub repository holding both the fine-tuned CTC model and the
# matching processor (feature extractor + tokenizer) used by transcribe().
repo_id = "eddiegulay/wav2vec2-large-xlsr-mvc-swahili"

# Load both artifacts once at import time (model first, then processor) so
# every call to transcribe() reuses them.
model, processor = (
    Wav2Vec2ForCTC.from_pretrained(repo_id),
    Wav2Vec2Processor.from_pretrained(repo_id),
)
|
|
|
|
|
|
|
def transcribe(audio_path):
    """Transcribe the speech in an audio file to text.

    The recording is loaded with ``torchaudio.load`` and downmixed to a single
    channel. It is then resampled to the sampling rate the module-level
    ``processor`` is configured for and run through the module-level CTC
    ``model``. Finally it is decoded greedily (arg-max token per frame, then
    ``processor.decode``).

    Args:
        audio_path: Path of the audio file; passed unchanged to
            ``torchaudio.load``.

    Returns:
        The decoded transcription string.
    """
    waveform, sample_rate = torchaudio.load(audio_path)

    # torchaudio.load returns (channels, frames). Average the channels so a
    # stereo/multichannel file contributes all of its audio instead of only
    # channel 0; for mono input this just drops the channel axis.
    waveform = waveform.mean(dim=0)

    # Use the rate the feature extractor was configured with rather than a
    # hard-coded constant, and skip resampling when the file already matches.
    target_sample_rate = processor.feature_extractor.sampling_rate
    if sample_rate != target_sample_rate:
        waveform = torchaudio.functional.resample(
            waveform, orig_freq=sample_rate, new_freq=target_sample_rate
        )

    inputs = processor(
        waveform,
        sampling_rate=target_sample_rate,
        return_tensors="pt",
        padding=True,
    )

    # Pure inference: disable autograd so no graph is recorded for a
    # backward pass that never happens.
    with torch.no_grad():
        logits = model(inputs.input_values).logits

    # Greedy CTC decoding of the single item in the batch.
    pred_ids = torch.argmax(logits, dim=-1)[0]
    return processor.decode(pred_ids)
|
|
|
|
|
|
|
if __name__ == "__main__":
    import sys

    # Transcribe the file named on the command line (defaulting to the
    # original "download.wav") and show the result instead of discarding it.
    # The guard keeps importing this module from triggering a transcription.
    print(transcribe(sys.argv[1] if len(sys.argv) > 1 else "download.wav"))