from transformers import (
    AutoConfig,
    AutoModelForSeq2SeqLM,
    AutoTokenizer
)

import torch

def load_nlu_model():
    # Load the fine-tuned NLU (slot filling + intent classification) seq2seq model
    # and its tokenizer from the Hugging Face Hub.
    config = AutoConfig.from_pretrained("Beomseok-LEE/NLU-Speech-MASSIVE-finetune")
    tokenizer = AutoTokenizer.from_pretrained("Beomseok-LEE/NLU-Speech-MASSIVE-finetune")
    model = AutoModelForSeq2SeqLM.from_pretrained("Beomseok-LEE/NLU-Speech-MASSIVE-finetune", config=config)

    return model, tokenizer

def run_nlu_inference(model, tokenizer, example):
    print(example)
    # The NLU model expects the "Annotate: " task prefix in front of the utterance.
    formatted_example = "Annotate: " + example
    input_values = tokenizer(formatted_example, max_length=128, padding=False, truncation=True, return_tensors="pt").input_ids

    with torch.no_grad():
        pred_ids = model.generate(input_values)

    prediction = tokenizer.decode(pred_ids[0], skip_special_tokens=True)
    print(prediction)

    # The prediction is one slot label per input word, followed by the intent label.
    splitted_pred = prediction.strip().split()

    slots_prediction = []
    intent_prediction = ''

    if len(splitted_pred) >= 2:
        slots_prediction = splitted_pred[:-1]
        intent_prediction = splitted_pred[-1]
    elif len(splitted_pred) == 1:
        slots_prediction = splitted_pred

    words = example.split(' ')

    title_1 = '[ASR output]\n'
    title_2 = '\n\n[NLU - slot filling]\n'
    title_3 = '\n\n[NLU - intent classification]\n'

    prefix_str_1 = title_1 + example + title_2
    prefix_str_2 = title_3

    # Build a highlighted-text style payload: the full display text plus
    # character-offset entities for the ASR output, each slot-labelled word,
    # and the classified intent.
    structured_output = {
        'text': prefix_str_1 + example + prefix_str_2 + intent_prediction,
        'entities': []}

    structured_output['entities'].append({
        'entity': 'ASR output',
        'word': example,
        'start': len(title_1),
        'end': len(title_1) + len(example)
    })

    # Slot annotations refer to the second copy of the utterance, which starts
    # right after prefix_str_1.
    idx = len(prefix_str_1)

    for slot, word in zip(slots_prediction, words):
        _entity = slot
        _word = word
        _start = idx
        _end = idx + len(word)
        idx = _end + 1  # skip the space separating consecutive words

        structured_output['entities'].append({
            'entity': _entity,
            'word': _word,
            'start': _start,
            'end': _end
        })

    # The intent label sits at the very end of the display text.
    idx = len(prefix_str_1 + example + prefix_str_2)

    if intent_prediction:
        structured_output['entities'].append({
            'entity': 'Classified Intent',
            'word': intent_prediction,
            'start': idx,
            'end': idx + len(intent_prediction)
        })

    return structured_output
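

# Minimal usage sketch (not part of the original module): it wires the two helpers
# above together on an invented utterance; a real caller would pass the ASR
# transcript instead of the hard-coded string.
if __name__ == "__main__":
    nlu_model, nlu_tokenizer = load_nlu_model()
    result = run_nlu_inference(nlu_model, nlu_tokenizer, "wake me up at nine am on friday")
    for entity in result['entities']:
        print(entity['entity'], '->', entity['word'])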