csikasote committed
Commit 158da7b
1 Parent(s): 00fca47

Update README.md

Files changed (1)
  1. README.md +3 -3
README.md CHANGED

@@ -40,7 +40,7 @@ import torchaudio
 from datasets import load_dataset
 from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
 
-test_dataset = load_dataset("csv", data_files={"test": "/content/test.csv"}, delimiter="\t")["test"] # Adapt the path to test.csv
+test_dataset = load_dataset("csv", data_files={"test": "/content/test.csv"}, delimiter="\\t")["test"] # Adapt the path to test.csv
 
 processor = Wav2Vec2Processor.from_pretrained("csikasote/wav2vec2-large-xlsr-bemba")
 model = Wav2Vec2ForCTC.from_pretrained("csikasote/wav2vec2-large-xlsr-bemba")
@@ -59,7 +59,7 @@ test_dataset = test_dataset.map(speech_file_to_array_fn)
 inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)
 
 with torch.no_grad():
-	logits = model(inputs.input_values, attention_mask=inputs.attention_mask).logits
+	logits = model(inputs.input_values.to("cuda"), attention_mask=inputs.attention_mask.to("cuda")).logits
 
 predicted_ids = torch.argmax(logits, dim=-1)
 
@@ -86,7 +86,7 @@ processor = Wav2Vec2Processor.from_pretrained("csikasote/wav2vec2-large-xlsr-bem
 model = Wav2Vec2ForCTC.from_pretrained("csikasote/wav2vec2-large-xlsr-bemba")
 model.to("cuda")
 
-chars_to_ignore_regex = '[\\,\\?\\.\\!\\;\\:\\"\\“]'
+chars_to_ignore_regex = '[\\\\,\\\\?\\\\.\\\\!\\\\;\\\\:\\\\"\\\\“]'
 #resampler = torchaudio.transforms.Resample(48_000, 16_000)
 
 # Preprocessing the datasets.
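
For context, below is a minimal end-to-end sketch of the evaluation snippet these three edits touch, with the device fix applied. It is an assumption-laden sketch, not the exact README: it assumes a CUDA-capable GPU, that `/content/test.csv` is tab-separated with `path` and `sentence` columns pointing at 16 kHz audio, and that the regex is written in plain Python form (the doubled backslashes in the diff appear to be escaping inside the README source rather than a change to the regex itself).

```python
# Hypothetical end-to-end sketch; column names and file locations are assumptions.
import re
import torch
import torchaudio
from datasets import load_dataset
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

# Tab-separated test manifest, assumed to have "path" and "sentence" columns.
test_dataset = load_dataset(
    "csv", data_files={"test": "/content/test.csv"}, delimiter="\t"
)["test"]

processor = Wav2Vec2Processor.from_pretrained("csikasote/wav2vec2-large-xlsr-bemba")
model = Wav2Vec2ForCTC.from_pretrained("csikasote/wav2vec2-large-xlsr-bemba")
model.to("cuda")

# Punctuation stripped from the reference transcriptions before scoring.
chars_to_ignore_regex = '[\\,\\?\\.\\!\\;\\:\\"\\“]'

def speech_file_to_array_fn(batch):
    # Normalise the reference text and load the audio as a 1-D float array.
    batch["sentence"] = re.sub(chars_to_ignore_regex, "", batch["sentence"]).lower()
    speech_array, sampling_rate = torchaudio.load(batch["path"])
    batch["speech"] = speech_array.squeeze().numpy()  # audio assumed to already be 16 kHz
    return batch

test_dataset = test_dataset.map(speech_file_to_array_fn)

inputs = processor(test_dataset["speech"][:2], sampling_rate=16_000, return_tensors="pt", padding=True)

with torch.no_grad():
    # The commit's fix: move the inputs onto the same (CUDA) device as the model.
    logits = model(
        inputs.input_values.to("cuda"),
        attention_mask=inputs.attention_mask.to("cuda"),
    ).logits

predicted_ids = torch.argmax(logits, dim=-1)
print("Prediction:", processor.batch_decode(predicted_ids))
print("Reference:", test_dataset["sentence"][:2])
```

Without the `.to("cuda")` calls, the model sits on the GPU while the processor output stays on the CPU, and the forward pass fails with a device-mismatch error; moving the input tensors alongside the model is what this commit addresses.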