|
|
|
|
|
""" |
|
This file demonstrates how to run inference with the Voice Safety Classifier model.
|
""" |
|
|
|
import torch |
|
import librosa |
|
import numpy as np |
|
import argparse |
|
from transformers import WavLMForSequenceClassification |
|
|
|
|
|
def feature_extract_simple( |
|
wav, |
|
sr=16_000, |
|
win_len=15.0, |
|
win_stride=15.0, |
|
do_normalize=False, |
|
): |
|
"""simple feature extraction for wavLM |
|
Parameters |
|
---------- |
|
wav : str or array-like |
|
        path to the wav file, or the waveform itself
|
sr : int, optional |
|
        sample rate in Hz, by default 16_000
|
win_len : float, optional |
|
        window length in seconds, by default 15.0
|
win_stride : float, optional |
|
        window stride in seconds, by default 15.0
|
do_normalize: bool, optional |
|
whether to normalize the input, by default False. |
|
Returns |
|
------- |
|
np.ndarray |
|
        batched input to WavLM
|
""" |
|
    if isinstance(wav, str):
        signal, _ = librosa.load(wav, sr=sr)
    else:
        try:
            signal = np.array(wav).squeeze()
        except Exception as e:
            raise RuntimeError(f"Could not convert input to a waveform array: {e}") from e
    batched_input = []
    stride = int(win_stride * sr)
    win_samples = int(win_len * sr)
    if len(signal) / sr > win_len:
        # Slide a fixed-length window over the signal; the last window is
        # zero-padded on the right so every chunk has exactly win_samples samples.
        for i in range(0, len(signal), stride):
            if i + win_samples > len(signal):
                chunked = np.pad(signal[i:], (0, win_samples - len(signal[i:])))
            else:
                chunked = signal[i : i + win_samples]
            if do_normalize:
                chunked = (chunked - np.mean(chunked)) / (np.std(chunked) + 1e-7)
            batched_input.append(chunked)
            if i + win_samples > len(signal):
                break
    else:
        # Clip is no longer than one window: use it as a single chunk.
        if do_normalize:
            signal = (signal - np.mean(signal)) / (np.std(signal) + 1e-7)
        batched_input.append(signal)
    return np.stack(batched_input)
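

# Quick shape check (a hypothetical example, not part of the original script):
# a 40 s clip at 16 kHz splits into two full 15 s windows plus one
# zero-padded remainder window:
#
#   >>> feature_extract_simple(np.zeros(16_000 * 40)).shape
#   (3, 240000)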
|
|
|
|
|
def infer(model, inputs):
    # The classifier is multi-label, so a per-class sigmoid (not a softmax)
    # converts logits to independent probabilities.
    with torch.no_grad():
        output = model(inputs)
    return torch.sigmoid(output.logits)
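
# For a batch of N windows, `infer` returns an (N, num_labels) tensor; with the
# six labels used below, two 15 s windows give a (2, 6) tensor.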
|
|
|
|
|
if __name__ == "__main__": |
|
parser = argparse.ArgumentParser() |
|
    parser.add_argument(
        "--audio_file",
        type=str,
        required=True,
        help="Audio file to run inference on",
    )
|
    parser.add_argument(
        "--model_path",
        type=str,
        default="roblox/voice-safety-classifier",
        help="Model checkpoint: a Hugging Face model ID or a local path",
    )
|
args = parser.parse_args() |
|
labels_name_list = [ |
|
"Profanity", |
|
"DatingAndSexting", |
|
"Racist", |
|
"Bullying", |
|
"Other", |
|
"NoViolation", |
|
] |
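    # This order must match the classifier head of the checkpoint: column i of
    # the model output corresponds to labels_name_list[i].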
|
|
|
    audio, _ = librosa.load(args.audio_file, sr=16000)
|
input_np = feature_extract_simple(audio, sr=16000) |
|
input_pt = torch.Tensor(input_np) |
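
    # from_pretrained accepts either a Hugging Face Hub model ID (the default
    # here) or a local checkpoint directory.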
|
model = WavLMForSequenceClassification.from_pretrained( |
|
args.model_path, num_labels=len(labels_name_list) |
|
) |
|
probs = infer(model, input_pt) |
|
    probs = probs.reshape(-1, len(labels_name_list)).tolist()
|
print(f"Probabilities for {args.audio_file} is:") |
|
    for chunk_idx, chunk_probs in enumerate(probs):
        print(f"\nSegment {chunk_idx}:")
        for label, prob in zip(labels_name_list, chunk_probs):
            print(f"{label} : {prob}")
|
|