import gradio as gr
from transformers import WhisperProcessor, WhisperForConditionalGeneration
import torch
import torchaudio
import spaces
import re
# Initialize device (GPU if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the WhisperNER model and processor
processor = WhisperProcessor.from_pretrained("aiola/whisper-ner-tag-and-mask-v1")
model = WhisperForConditionalGeneration.from_pretrained("aiola/whisper-ner-tag-and-mask-v1")
model = model.to(device)
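
# Each example is [audio file path, comma-separated entity labels, ner_mask flag];
# when the flag is True, matched entity text is masked in the output instead of shown.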
examples = [
    [
        "audio/sports.wav",
        "football-club, football-player, referee",
        False
    ],
    [
        "audio/entertainment.wav",
        "movie, date, actor, tv-show, musician",
        True
    ],
    [
        "audio/personal_info.wav",
        "address, name, phone-number",
        True
    ],
    [
        "audio/672-122797-0026.wav",
        "biological-classification, desire, demographic-group, object-category, relationship-role, reflexive-pronoun, furniture-type",
        False
    ],
    [
        "audio/672-122797-0027.wav",
        "action, emotional-resilience, comparative-path-characteristic, social-role",
        True
    ],
    [
        "audio/672-122797-0024.wav",
        "health-warning, importance-indicator, event, sentiment",
        False
    ],
    [
        "audio/672-122797-0048.wav",
        "weapon, emotional-state, household-chore, atmosphere-quality",
        False
    ],
]


def unify_ner_text(text, symbols_to_replace=("/", " ", ":", "_")):
    """Process and standardize entity text by replacing certain symbols and normalizing spaces."""
    text = " ".join(text.split())
    for symbol in symbols_to_replace:
        text = text.replace(symbol, "-")
    return text.lower()
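
# For example (behavior of unify_ner_text as defined above):
#   unify_ner_text("Phone Number")   -> "phone-number"
#   unify_ner_text("Football_Club")  -> "football-club"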


def extract_entities_and_clean_text_fixed(text, ner_mask=False):
    # In tag mode entities appear as <type>text<type>>, in mask mode as <type>>
    entity_pattern = r"<(.*?)>(.*?)<\1>>" if not ner_mask else r"<(.*?)>>"
    entities = []
    clean_text = []
    current_pos = 0
    # Iterate through the matches for entity tags
    for match in re.finditer(entity_pattern, text):
        # Add text before the entity to the clean text
        clean_text.append(text[current_pos:match.start()])
        entity_type = match.group(1)
        entity_text = "-" if ner_mask else match.group(2)
        start_pos = len("".join(clean_text))  # Start position in the clean text
        end_pos = start_pos + len(entity_text)
        # Append the entity text to the clean text
        clean_text.append(entity_text)
        # Add the entity details to the list
        entities.append({
            "entity": entity_type,
            "text": entity_text,
            "start": start_pos,
            "end": end_pos
        })
        # Update the current position to the end of the match
        current_pos = match.end()
    # Append the remaining part of the text after the last entity
    clean_text.append(text[current_pos:])
    # Join all parts of the clean text
    clean_text_str = "".join(clean_text)
    return clean_text_str, entities
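
# Illustrative parse, assuming the model emits entities in the formats matched above:
#   tag mode:  extract_entities_and_clean_text_fixed("I met <person>John<person>> today")
#       -> ("I met John today", [{"entity": "person", "text": "John", "start": 6, "end": 10}])
#   mask mode: extract_entities_and_clean_text_fixed("I met <person>> today", ner_mask=True)
#       -> ("I met - today", [{"entity": "person", "text": "-", "start": 6, "end": 7}])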


# The @spaces.GPU decorator lets this function use the GPU on Hugging Face Spaces (ZeroGPU)
@spaces.GPU
def transcribe_and_recognize_entities(audio_file, prompt, ner_mask=False):
    target_sample_rate = 16000
    signal, sampling_rate = torchaudio.load(audio_file)
    resampler = torchaudio.transforms.Resample(orig_freq=sampling_rate, new_freq=target_sample_rate)
    signal = resampler(signal)
    if signal.ndim == 2:
        signal = torch.mean(signal, dim=0)

    input_features = processor(signal, sampling_rate=target_sample_rate, return_tensors="pt").input_features
    input_features = input_features.to(device)

    ner_types = prompt.split(',')
    processed_ner_types = [unify_ner_text(ner_type.strip()) for ner_type in ner_types]
    prompt = ", ".join(processed_ner_types)
    if ner_mask:
        prompt = f"<|mask|>{prompt}"
    print(f"Prompt after unify_ner_text: {prompt}")

    prompt_ids = processor.get_prompt_ids(prompt, return_tensors="pt")
    prompt_ids = prompt_ids.to(device)

    predicted_ids = model.generate(
        input_features,
        max_new_tokens=256,
        prompt_ids=prompt_ids,
        language='en',
        generation_config=model.generation_config,
    )
    transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)[0]

    clean_text_fixed, extracted_entities_fixed = extract_entities_and_clean_text_fixed(transcription, ner_mask=ner_mask)
    return transcription, {"text": clean_text_fixed, "entities": extracted_entities_fixed}
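
# Minimal standalone usage sketch (outside the Gradio UI), using one of the bundled examples:
#   raw, highlighted = transcribe_and_recognize_entities(
#       "audio/sports.wav", "football-club, football-player, referee", ner_mask=False
#   )
#   raw is the tagged transcript (entities wrapped as <label>text<label>>), and
#   highlighted is the dict consumed by gr.HighlightedText: {"text": ..., "entities": [...]}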


with gr.Blocks(title="WhisperNER v1") as demo:
    gr.Markdown(
        """
# Whisper-NER: ASR with zero-shot NER

WhisperNER is a unified model for automatic speech recognition (ASR) and named entity recognition (NER), with zero-shot capabilities.
It is designed as a strong base model for the downstream task of ASR with NER and can be fine-tuned on specific datasets for improved performance.

The [aiola/whisper-ner-tag-and-mask-v1](https://huggingface.co/aiola/whisper-ner-tag-and-mask-v1) model was fine-tuned from
the [aiola/whisper-ner-v1](https://huggingface.co/aiola/whisper-ner-v1) checkpoint on the NuNER dataset to perform joint audio transcription and NER tagging or NER masking.
The model was not trained on PII-specific datasets, so it performs general, open-type entity masking; it should be further fine-tuned before being used for PII detection.
The model was trained and evaluated only on English data. See the paper for full details.

## Links

* Paper: [WhisperNER: Unified Open Named Entity and Speech Recognition](https://arxiv.org/abs/2409.08107)
* [WhisperNER model collection](https://huggingface.co/collections/aiola/whisperner-6723f14506f3662cf3a73df2)
* Code: https://github.com/aiola-lab/whisper-ner
        """
    )

    with gr.Row() as row1:
        with gr.Column() as col1:
            audio_input = gr.Audio(value=examples[0][0], label="Audio Example", type="filepath")
        with gr.Column() as col2:
            label_input = gr.Textbox(label="Entity Labels", value=examples[0][1])
            ner_mask = gr.Checkbox(
                value=examples[0][2],
                label="Entity Mask",
                info="Mask or tag entities in the transcription.",
                scale=0,
            )
            submit_btn = gr.Button("Submit")

    gr.Markdown("## Output")
    with gr.Row() as row3:
        transcript_output = gr.Textbox(label="Transcription and Entities")
    with gr.Row() as row4:
        highlighted_text_output = gr.HighlightedText(label="Predicted Highlighted Entities")

    examples = gr.Examples(
        examples,
        fn=transcribe_and_recognize_entities,
        inputs=[audio_input, label_input, ner_mask],
        outputs=[transcript_output, highlighted_text_output],
        cache_examples=True,
        run_on_click=True,
    )

    # Submitting
    label_input.submit(
        fn=transcribe_and_recognize_entities,
        inputs=[audio_input, label_input, ner_mask],
        outputs=[transcript_output, highlighted_text_output],
    )
    submit_btn.click(
        fn=transcribe_and_recognize_entities,
        inputs=[audio_input, label_input, ner_mask],
        outputs=[transcript_output, highlighted_text_output],
    )

demo.launch()