File size: 2,297 Bytes
ffa317c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
import re

import torch
from datasets import load_dataset
from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer
)
def load_nlu_model(model_name="Beomseok-LEE/NLU-Speech-MASSIVE-finetune"):
    """Load a seq2seq NLU checkpoint together with its tokenizer.

    Args:
        model_name: Hugging Face Hub id or local directory of the checkpoint.
            Defaults to ``Beomseok-LEE/NLU-Speech-MASSIVE-finetune``.

    Returns:
        A ``(model, tokenizer)`` tuple.
    """
    # One source of truth for the checkpoint id instead of three copies.
    config = AutoConfig.from_pretrained(model_name)
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, config=config)
    return model, tokenizer
def _split_prediction(prediction):
    """Split decoded model output into ``(slot_labels, intent_label)``.

    The last whitespace-separated token is the intent and all preceding
    tokens are slot labels. A single token is kept as a slot with no intent.
    An empty prediction yields no slots and no intent.
    """
    tokens = prediction.split()
    if len(tokens) >= 2:
        return tokens[:-1], tokens[-1]
    return tokens, ''


def _build_structured_output(example, slots_prediction, intent_prediction):
    """Render the ASR/slot/intent report and the entity spans that index into it.

    The returned ``'text'`` contains ``example`` twice: once under the ASR
    header and once under the slot-filling header. Slot entities point into
    the second copy, and the intent entity points at the trailing intent
    label. Every entity's ``start``/``end`` is a character offset into
    ``'text'``.
    """
    title_1 = '[ASR output]\n'
    title_2 = '\n\n[NLU - slot filling]\n'
    title_3 = '\n\n[NLU - intent classification]\n'

    asr_start = len(title_1)
    slots_start = asr_start + len(example) + len(title_2)
    intent_start = slots_start + len(example) + len(title_3)
    text = title_1 + example + title_2 + example + title_3 + intent_prediction

    entities = [{
        'entity': 'ASR output',
        'word': example,
        'start': asr_start,
        'end': asr_start + len(example),
    }]
    # Pair slot labels with words in order. finditer yields each word's exact
    # offset within `example`, so leading/trailing/repeated whitespace or tabs
    # cannot produce empty "words" that shift the labels. zip stops at the
    # shorter side if the model emits too few or too many slot labels.
    for slot, match in zip(slots_prediction, re.finditer(r'\S+', example)):
        entities.append({
            'entity': slot,
            'word': match.group(),
            'start': slots_start + match.start(),
            'end': slots_start + match.end(),
        })
    if intent_prediction:
        entities.append({
            'entity': 'Classified Intent',
            'word': intent_prediction,
            'start': intent_start,
            'end': intent_start + len(intent_prediction),
        })
    return {'text': text, 'entities': entities}


def run_nlu_inference(model, tokenizer, example):
    """Run joint slot filling and intent classification on one utterance.

    The model is prompted with ``"Annotate: " + example``. Its decoded output
    is parsed as slot labels (one per whitespace-separated word of
    ``example``) followed by a single intent label.

    Args:
        model: Seq2seq model exposing ``generate`` and ``device``
            (e.g. as returned by ``load_nlu_model``).
        tokenizer: Tokenizer matching ``model``.
        example: Utterance text to annotate (e.g. an ASR transcript).

    Returns:
        A dict with ``'text'`` (the rendered report) and ``'entities'``, a list
        of ``{'entity', 'word', 'start', 'end'}`` dicts whose offsets index
        into ``'text'``.
    """
    print(example)
    formatted_example = "Annotate: " + example
    input_ids = tokenizer(
        formatted_example,
        max_length=128,
        padding=False,
        truncation=True,
        return_tensors="pt",
    ).input_ids
    # Inputs must sit on the model's device, otherwise a GPU model fails.
    input_ids = input_ids.to(model.device)
    with torch.no_grad():
        pred_ids = model.generate(input_ids)
    prediction = tokenizer.decode(pred_ids[0], skip_special_tokens=True)
    print(prediction)

    slots_prediction, intent_prediction = _split_prediction(prediction)
    return _build_structured_output(example, slots_prediction, intent_prediction)