Spaces:
Runtime error
Runtime error
alistairmcleay
commited on
Commit
·
b16a132
1
Parent(s):
6aeedda
Added dialogue system code
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- app.py +167 -4
- scripts/UBAR_code/__init__.py +0 -0
- scripts/UBAR_code/data_analysis.py +170 -0
- scripts/UBAR_code/interaction/UBAR_interact.py +457 -0
- scripts/UBAR_code/interaction/__init__.py +0 -0
- scripts/UBAR_code/interaction/config.yaml +23 -0
- scripts/UBAR_code/preprocess.py +576 -0
- scripts/UBAR_code/preprocess2.1.py +585 -0
- scripts/UBAR_code/train_ubar.py +697 -0
- scripts/agent_agent.yaml +0 -0
- scripts/crazyneuraluser.egg-info/PKG-INFO +171 -0
- scripts/crazyneuraluser.egg-info/SOURCES.txt +0 -0
- scripts/crazyneuraluser.egg-info/dependency_links.txt +1 -0
- scripts/crazyneuraluser.egg-info/not-zip-safe +1 -0
- scripts/crazyneuraluser.egg-info/requires.txt +15 -0
- scripts/crazyneuraluser.egg-info/top_level.txt +1 -0
- scripts/simulate_interaction.py +171 -0
- scripts/template_train_model.py +45 -0
- scripts/user_model_code/__init__.py +0 -0
- scripts/user_model_code/decode.sh +37 -0
- scripts/user_model_code/interaction/__init__.py +0 -0
- scripts/user_model_code/interaction/config.yaml +12 -0
- scripts/user_model_code/interaction/multiwoz_interact.py +1034 -0
- scripts/user_model_code/interaction/schema.json +712 -0
- scripts/user_model_code/interaction/utils.py +308 -0
- scripts/user_model_code/main_user_model.py +347 -0
- scripts/user_model_code/preprocess_multiwoz.py +528 -0
- scripts/user_model_code/preprocess_sgd.py +431 -0
- scripts/user_model_code/train.sh +51 -0
- src/crazyneuraluser.egg-info/PKG-INFO +173 -0
- src/crazyneuraluser.egg-info/SOURCES.txt +76 -0
- src/crazyneuraluser.egg-info/dependency_links.txt +1 -0
- src/crazyneuraluser.egg-info/not-zip-safe +1 -0
- src/crazyneuraluser.egg-info/requires.txt +15 -0
- src/crazyneuraluser.egg-info/top_level.txt +1 -0
- src/crazyneuraluser/UBAR_code/__init__.py +16 -0
- src/crazyneuraluser/UBAR_code/clean_dataset.py +334 -0
- src/crazyneuraluser/UBAR_code/config.py +164 -0
- src/crazyneuraluser/UBAR_code/config21.py +169 -0
- src/crazyneuraluser/UBAR_code/db_ops.py +314 -0
- src/crazyneuraluser/UBAR_code/eval.py +932 -0
- src/crazyneuraluser/UBAR_code/ontology.py +328 -0
- src/crazyneuraluser/UBAR_code/reader.py +1262 -0
- src/crazyneuraluser/UBAR_code/utils.py +292 -0
- src/crazyneuraluser/user_model_code/analysis_multiwoz.py +119 -0
- src/crazyneuraluser/user_model_code/analysis_sgd.py +483 -0
- src/crazyneuraluser/user_model_code/argument.py +153 -0
- src/crazyneuraluser/user_model_code/dataset.py +297 -0
- src/crazyneuraluser/user_model_code/utils_generation.py +210 -0
- src/crazyneuraluser/user_model_code/utils_multiwoz.py +204 -0
app.py
CHANGED
@@ -1,7 +1,170 @@
|
|
1 |
import gradio as gr
|
|
|
|
|
|
|
|
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
import scripts.simulate_interaction as si
|
3 |
+
import sys
|
4 |
+
import traceback
|
5 |
+
import pandas as pd
|
6 |
|
7 |
+
# from tqdm import tqdm
|
8 |
+
from scripts.UBAR_code.interaction import UBAR_interact
|
9 |
+
from scripts.user_model_code.interaction import multiwoz_interact
|
10 |
+
from scripts.UBAR_code.interaction.UBAR_interact import bcolors
|
11 |
|
12 |
+
|
13 |
+
def instantiate_agents():
|
14 |
+
|
15 |
+
UBAR_checkpoint_path = "models/UBAR/experiments/distilgpt-2_sd11_lr0.0001_bs16_ga2/epoch50_trloss0.59_gpt2"
|
16 |
+
user_model_checkpoint_path = "models/user_model/MultiWOZ-full_checkpoint_step340k"
|
17 |
+
|
18 |
+
sys_model = UBAR_interact.UbarSystemModel(
|
19 |
+
"UBAR_sys_model", UBAR_checkpoint_path, "scripts/UBAR_code/interaction/config.yaml"
|
20 |
+
)
|
21 |
+
|
22 |
+
user_model = multiwoz_interact.NeuralAgent(
|
23 |
+
"user", user_model_checkpoint_path, "scripts/user_model_code/interaction/config.yaml"
|
24 |
+
)
|
25 |
+
|
26 |
+
return sys_model, user_model
|
27 |
+
|
28 |
+
|
29 |
+
def read_multiwoz_data():
|
30 |
+
"""
|
31 |
+
Read the multiwoz 2.0 raw data from the .json file
|
32 |
+
"""
|
33 |
+
raw_mwoz_20_path = "data/raw/UBAR/multi-woz/data.json"
|
34 |
+
df_raw_mwoz = pd.read_json(raw_mwoz_20_path)
|
35 |
+
return df_raw_mwoz
|
36 |
+
|
37 |
+
|
38 |
+
def load_test_val_lists():
|
39 |
+
val_list_file = "data/raw/UBAR/multi-woz/valListFile.json"
|
40 |
+
test_list_file = "data/raw/UBAR/multi-woz/testListFile.json"
|
41 |
+
|
42 |
+
with open(val_list_file, "r") as f:
|
43 |
+
val_list = f.readlines()
|
44 |
+
val_list = [x.strip() for x in val_list]
|
45 |
+
|
46 |
+
with open(test_list_file, "r") as f:
|
47 |
+
test_list = f.readlines()
|
48 |
+
test_list = [x.strip() for x in test_list]
|
49 |
+
|
50 |
+
return val_list, test_list
|
51 |
+
|
52 |
+
|
53 |
+
def main(
|
54 |
+
write_to_file=False, ground_truth_system_responses=False, train_only=True, n_dialogues="all", log_successes=False
|
55 |
+
):
|
56 |
+
sys_model, user_model = instantiate_agents()
|
57 |
+
|
58 |
+
# TODO: move hardcoded vars into config file
|
59 |
+
raw_mwoz_20_path = "data/raw/UBAR/multi-woz/data.json"
|
60 |
+
user_utterances_out_path = "data/preprocessed/UBAR/user_utterances_from_simulator.txt"
|
61 |
+
logging_successes_path = "data/preprocessed/UBAR/logging_successes"
|
62 |
+
sys_model.print_intermediary_info = False
|
63 |
+
user_model.print_intermediary_info = False
|
64 |
+
|
65 |
+
df_raw_mwoz = pd.read_json(raw_mwoz_20_path)
|
66 |
+
if n_dialogues == "all":
|
67 |
+
n_dialogues = len(df_raw_mwoz.columns)
|
68 |
+
|
69 |
+
curr_dialogue_user_utterances_formatted = []
|
70 |
+
|
71 |
+
print("Loading goals...")
|
72 |
+
goals = multiwoz_interact.read_multiWOZ_20_goals(raw_mwoz_20_path, n_dialogues)
|
73 |
+
|
74 |
+
# Write column headers
|
75 |
+
if write_to_file:
|
76 |
+
with open(user_utterances_out_path, "w") as f:
|
77 |
+
f.write("Dialogue #\tDialogue ID\tTurn #\tSystem Response\n")
|
78 |
+
|
79 |
+
print("Loading data...")
|
80 |
+
df_mwoz_data = read_multiwoz_data()
|
81 |
+
val_list, test_list = load_test_val_lists()
|
82 |
+
|
83 |
+
successful_dialogues = 0
|
84 |
+
total_dialogues_generated = 0 # train dialogues only
|
85 |
+
for dialogue_idx, (goal, dialogue_filename) in enumerate(zip(goals, df_mwoz_data.columns)):
|
86 |
+
if log_successes:
|
87 |
+
# log successful_dialogues to logging_successes_path every 100 dialogues
|
88 |
+
if dialogue_idx % 100 == 0:
|
89 |
+
with open(logging_successes_path, "w") as f:
|
90 |
+
f.write(str(successful_dialogues) + " / " + str(total_dialogues_generated))
|
91 |
+
|
92 |
+
curr_dialogue_user_utterances_formatted = []
|
93 |
+
if train_only:
|
94 |
+
if dialogue_filename in val_list or dialogue_filename in test_list:
|
95 |
+
continue
|
96 |
+
|
97 |
+
total_dialogues_generated += 1
|
98 |
+
print("Dialogue: {}".format(dialogue_filename))
|
99 |
+
|
100 |
+
# There are occasionally exceptions thrown from one of the agents, usually the user
|
101 |
+
# In this case we simply continue to the next dialogue
|
102 |
+
try:
|
103 |
+
# Reset state after each dialogue
|
104 |
+
sys_model.init_session()
|
105 |
+
user_model.init_session(ini_goal=goal)
|
106 |
+
sys_response = ""
|
107 |
+
|
108 |
+
for turn_idx in range(50):
|
109 |
+
# Turn idx in this case represents the turn as one user utterance AND one system response
|
110 |
+
usr_response_raw_data_idx = turn_idx * 2
|
111 |
+
sys_response_raw_data_idx = turn_idx * 2 + 1
|
112 |
+
|
113 |
+
user_utterance = user_model.response(sys_response)
|
114 |
+
print(bcolors.OKBLUE + "User: " + bcolors.ENDC + user_utterance)
|
115 |
+
|
116 |
+
if write_to_file:
|
117 |
+
user_utterance = user_utterance.replace("\n", " ")
|
118 |
+
curr_dialogue_user_utterances_formatted.append(
|
119 |
+
str(dialogue_idx)
|
120 |
+
+ "\t"
|
121 |
+
+ dialogue_filename
|
122 |
+
+ "\t"
|
123 |
+
+ str(usr_response_raw_data_idx)
|
124 |
+
+ "\t"
|
125 |
+
+ user_utterance
|
126 |
+
+ "\n"
|
127 |
+
)
|
128 |
+
|
129 |
+
if user_model.is_terminated():
|
130 |
+
successful_dialogues += 1
|
131 |
+
print(bcolors.OKCYAN + "Dialogue terminated successfully!" + bcolors.ENDC)
|
132 |
+
print(bcolors.OKCYAN + "---" * 30 + bcolors.ENDC + "\n")
|
133 |
+
if write_to_file:
|
134 |
+
# Write whole dialogue to file
|
135 |
+
with open(user_utterances_out_path, "a") as f:
|
136 |
+
for line in curr_dialogue_user_utterances_formatted:
|
137 |
+
f.write(line)
|
138 |
+
break
|
139 |
+
|
140 |
+
# Next turn materials
|
141 |
+
if ground_truth_system_responses:
|
142 |
+
# If we are at the end of the ground truth dialogues
|
143 |
+
if len(df_mwoz_data.iloc[:, dialogue_idx].log) <= sys_response_raw_data_idx:
|
144 |
+
print(bcolors.RED + "Dialogue terminated unsuccessfully!" + bcolors.ENDC)
|
145 |
+
print(bcolors.RED + "---" * 30 + bcolors.ENDC + "\n")
|
146 |
+
break
|
147 |
+
sys_response = df_mwoz_data.iloc[:, dialogue_idx].log[sys_response_raw_data_idx]["text"]
|
148 |
+
else:
|
149 |
+
sys_response = sys_model.response(user_utterance, turn_idx)
|
150 |
+
capitalised_sys_response = sys_response[0].upper() + sys_response[1:]
|
151 |
+
print(bcolors.GREEN + "System: " + bcolors.ENDC + capitalised_sys_response)
|
152 |
+
|
153 |
+
except Exception:
|
154 |
+
print(bcolors.RED + "*" * 30 + bcolors.ENDC)
|
155 |
+
print(bcolors.RED + "Error in dialogue {}".format(dialogue_filename) + bcolors.ENDC)
|
156 |
+
print(bcolors.RED + "*" * 30 + bcolors.ENDC)
|
157 |
+
traceback.print_exc()
|
158 |
+
continue
|
159 |
+
|
160 |
+
print("Successful dialogues: {}".format(successful_dialogues))
|
161 |
+
print("Total dialogues: {}".format(n_dialogues))
|
162 |
+
print("% Successful Dialopues: {}".format(successful_dialogues / n_dialogues))
|
163 |
+
|
164 |
+
|
165 |
+
def test():
|
166 |
+
return "SUCCESS"
|
167 |
+
|
168 |
+
|
169 |
+
iface = gr.Interface(fn=test, outputs="text")
|
170 |
+
iface.launch()
|
scripts/UBAR_code/__init__.py
ADDED
File without changes
|
scripts/UBAR_code/data_analysis.py
ADDED
@@ -0,0 +1,170 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import zipfile
|
6 |
+
from collections import OrderedDict
|
7 |
+
|
8 |
+
from crazyneuraluser.UBAR_code.ontology import all_domains
|
9 |
+
|
10 |
+
# 2.0
|
11 |
+
data_path = "data/preprocessed/UBAR/gen_usr_utt_experiment_data.json"
|
12 |
+
save_path = "data/interim/gen_usr_utts/multi-woz-analysis/"
|
13 |
+
save_path_exp = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/"
|
14 |
+
# 2.1
|
15 |
+
# data_path = 'data/raw/UBAR/MultiWOZ_2.1/'
|
16 |
+
# save_path = 'data/interim/multi-woz-2.1-analysis/'
|
17 |
+
# save_path_exp = 'data/preprocessed/multi-woz-2.1-processed/'
|
18 |
+
data_file = "data.json"
|
19 |
+
domains = all_domains
|
20 |
+
# all_domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital']
|
21 |
+
|
22 |
+
|
23 |
+
def analysis():
|
24 |
+
compressed_raw_data = {}
|
25 |
+
goal_of_dials = {}
|
26 |
+
req_slots = {}
|
27 |
+
info_slots = {}
|
28 |
+
dom_count = {}
|
29 |
+
dom_fnlist = {}
|
30 |
+
all_domain_specific_slots = set()
|
31 |
+
for domain in domains:
|
32 |
+
req_slots[domain] = []
|
33 |
+
info_slots[domain] = []
|
34 |
+
|
35 |
+
# archive = zipfile.ZipFile(data_path + data_file + ".zip", "r")
|
36 |
+
# data = archive.open(data_file, "r").read().decode("utf-8").lower()
|
37 |
+
data = open(data_path, "r").read().lower()
|
38 |
+
ref_nos = list(set(re.findall(r"\"reference\"\: \"(\w+)\"", data)))
|
39 |
+
data = json.loads(data)
|
40 |
+
|
41 |
+
for fn, dial in data.items():
|
42 |
+
goals = dial["goal"]
|
43 |
+
logs = dial["log"]
|
44 |
+
|
45 |
+
# get compressed_raw_data and goal_of_dials
|
46 |
+
compressed_raw_data[fn] = {"goal": {}, "log": []}
|
47 |
+
goal_of_dials[fn] = {}
|
48 |
+
for dom, goal in goals.items(): # get goal of domains that are in demmand
|
49 |
+
if dom != "topic" and dom != "message" and goal:
|
50 |
+
compressed_raw_data[fn]["goal"][dom] = goal
|
51 |
+
goal_of_dials[fn][dom] = goal
|
52 |
+
|
53 |
+
for turn in logs:
|
54 |
+
if not turn["metadata"]: # user's turn
|
55 |
+
compressed_raw_data[fn]["log"].append({"text": turn["text"]})
|
56 |
+
else: # system's turn
|
57 |
+
meta = turn["metadata"]
|
58 |
+
turn_dict = {"text": turn["text"], "metadata": {}}
|
59 |
+
for (
|
60 |
+
dom,
|
61 |
+
book_semi,
|
62 |
+
) in meta.items(): # for every domain, sys updates "book" and "semi"
|
63 |
+
book, semi = book_semi["book"], book_semi["semi"]
|
64 |
+
record = False
|
65 |
+
for (
|
66 |
+
slot,
|
67 |
+
value,
|
68 |
+
) in book.items(): # record indicates non-empty-book domain
|
69 |
+
if value not in ["", []]:
|
70 |
+
record = True
|
71 |
+
if record:
|
72 |
+
turn_dict["metadata"][dom] = {}
|
73 |
+
turn_dict["metadata"][dom]["book"] = book # add that domain's book
|
74 |
+
record = False
|
75 |
+
for (
|
76 |
+
slot,
|
77 |
+
value,
|
78 |
+
) in semi.items(): # here record indicates non-empty-semi domain
|
79 |
+
if value not in ["", []]:
|
80 |
+
record = True
|
81 |
+
break
|
82 |
+
if record:
|
83 |
+
for s, v in copy.deepcopy(semi).items():
|
84 |
+
if v == "not mentioned":
|
85 |
+
del semi[s]
|
86 |
+
if not turn_dict["metadata"].get(dom):
|
87 |
+
turn_dict["metadata"][dom] = {}
|
88 |
+
turn_dict["metadata"][dom]["semi"] = semi # add that domain's semi
|
89 |
+
compressed_raw_data[fn]["log"].append(turn_dict) # add to log the compressed turn_dict
|
90 |
+
|
91 |
+
# get domain statistics
|
92 |
+
dial_type = (
|
93 |
+
"multi" if "mul" in fn or "MUL" in fn else "single"
|
94 |
+
) # determine the dialog's type: sinle or multi
|
95 |
+
if fn in ["pmul2756.json", "pmul4958.json", "pmul3599.json"]:
|
96 |
+
dial_type = "single"
|
97 |
+
dial_domains = [dom for dom in domains if goals[dom]] # domains that are in demmand
|
98 |
+
dom_str = ""
|
99 |
+
for dom in dial_domains:
|
100 |
+
if not dom_count.get(dom + "_" + dial_type): # count each domain type, with single or multi considered
|
101 |
+
dom_count[dom + "_" + dial_type] = 1
|
102 |
+
else:
|
103 |
+
dom_count[dom + "_" + dial_type] += 1
|
104 |
+
if not dom_fnlist.get(dom + "_" + dial_type): # keep track the file number of each domain type
|
105 |
+
dom_fnlist[dom + "_" + dial_type] = [fn]
|
106 |
+
else:
|
107 |
+
dom_fnlist[dom + "_" + dial_type].append(fn)
|
108 |
+
dom_str += "%s_" % dom
|
109 |
+
dom_str = dom_str[:-1] # substract the last char in dom_str
|
110 |
+
if dial_type == "multi": # count multi-domains
|
111 |
+
if not dom_count.get(dom_str):
|
112 |
+
dom_count[dom_str] = 1
|
113 |
+
else:
|
114 |
+
dom_count[dom_str] += 1
|
115 |
+
if not dom_fnlist.get(dom_str):
|
116 |
+
dom_fnlist[dom_str] = [fn]
|
117 |
+
else:
|
118 |
+
dom_fnlist[dom_str].append(fn)
|
119 |
+
######
|
120 |
+
|
121 |
+
# get informable and requestable slots statistics
|
122 |
+
for domain in domains:
|
123 |
+
info_ss = goals[domain].get("info", {})
|
124 |
+
book_ss = goals[domain].get("book", {})
|
125 |
+
req_ss = goals[domain].get("reqt", {})
|
126 |
+
for info_s in info_ss:
|
127 |
+
all_domain_specific_slots.add(domain + "-" + info_s)
|
128 |
+
if info_s not in info_slots[domain]:
|
129 |
+
info_slots[domain] += [info_s]
|
130 |
+
for book_s in book_ss:
|
131 |
+
if "book_" + book_s not in info_slots[domain] and book_s not in [
|
132 |
+
"invalid",
|
133 |
+
"pre_invalid",
|
134 |
+
]:
|
135 |
+
all_domain_specific_slots.add(domain + "-" + book_s)
|
136 |
+
info_slots[domain] += ["book_" + book_s]
|
137 |
+
for req_s in req_ss:
|
138 |
+
if req_s not in req_slots[domain]:
|
139 |
+
req_slots[domain] += [req_s]
|
140 |
+
|
141 |
+
# result statistics
|
142 |
+
if not os.path.exists(save_path):
|
143 |
+
os.mkdir(save_path)
|
144 |
+
if not os.path.exists(save_path_exp):
|
145 |
+
os.mkdir(save_path_exp)
|
146 |
+
with open(save_path + "req_slots.json", "w") as sf:
|
147 |
+
json.dump(req_slots, sf, indent=2)
|
148 |
+
with open(save_path + "info_slots.json", "w") as sf:
|
149 |
+
json.dump(info_slots, sf, indent=2)
|
150 |
+
with open(save_path + "all_domain_specific_info_slots.json", "w") as sf:
|
151 |
+
json.dump(list(all_domain_specific_slots), sf, indent=2)
|
152 |
+
print("slot num:", len(list(all_domain_specific_slots)))
|
153 |
+
with open(save_path + "goal_of_each_dials.json", "w") as sf:
|
154 |
+
json.dump(goal_of_dials, sf, indent=2)
|
155 |
+
with open(save_path + "compressed_data.json", "w") as sf:
|
156 |
+
json.dump(compressed_raw_data, sf, indent=2)
|
157 |
+
with open(save_path + "domain_count.json", "w") as sf:
|
158 |
+
single_count = [d for d in dom_count.items() if "single" in d[0]]
|
159 |
+
multi_count = [d for d in dom_count.items() if "multi" in d[0]]
|
160 |
+
other_count = [d for d in dom_count.items() if "multi" not in d[0] and "single" not in d[0]]
|
161 |
+
dom_count_od = OrderedDict(single_count + multi_count + other_count)
|
162 |
+
json.dump(dom_count_od, sf, indent=2)
|
163 |
+
with open(save_path_exp + "reference_no.json", "w") as sf:
|
164 |
+
json.dump(ref_nos, sf, indent=2)
|
165 |
+
with open(save_path_exp + "domain_files.json", "w") as sf:
|
166 |
+
json.dump(dom_fnlist, sf, indent=2)
|
167 |
+
|
168 |
+
|
169 |
+
if __name__ == "__main__":
|
170 |
+
analysis()
|
scripts/UBAR_code/interaction/UBAR_interact.py
ADDED
@@ -0,0 +1,457 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
import random
|
4 |
+
import string
|
5 |
+
|
6 |
+
# import bcolors
|
7 |
+
from omegaconf import OmegaConf
|
8 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
9 |
+
|
10 |
+
from crazyneuraluser.UBAR_code.config import global_config as cfg
|
11 |
+
from crazyneuraluser.UBAR_code.reader import MultiWozReader
|
12 |
+
from crazyneuraluser.UBAR_code.db_ops import MultiWozDB
|
13 |
+
|
14 |
+
from typing import List
|
15 |
+
|
16 |
+
|
17 |
+
class bcolors:
|
18 |
+
HEADER = "\033[95m"
|
19 |
+
OKBLUE = "\033[94m"
|
20 |
+
OKCYAN = "\033[96m"
|
21 |
+
GREEN = "\033[92m"
|
22 |
+
YELLOW = "\033[93m"
|
23 |
+
RED = "\033[91m"
|
24 |
+
ENDC = "\033[0m"
|
25 |
+
BOLD = "\033[1m"
|
26 |
+
UNDERLINE = "\033[4m"
|
27 |
+
|
28 |
+
|
29 |
+
class UbarSystemModel: # may inherit convlab or not, just like andy's
|
30 |
+
def __init__(self, name: str, checkpoint_path: str, model_config_path: str):
|
31 |
+
|
32 |
+
self.tokenizer = GPT2Tokenizer.from_pretrained(checkpoint_path)
|
33 |
+
self.model = GPT2LMHeadModel.from_pretrained(checkpoint_path)
|
34 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
35 |
+
self.name = name
|
36 |
+
self.turn_domain = ["general"] # returns a list of one string that is the domain e.g. 'taxi'
|
37 |
+
# (this is because of the way the db_ops.py deals with the domain. It should really be a string.)
|
38 |
+
|
39 |
+
self.ubar_status = {"dialogue_terminate": False}
|
40 |
+
|
41 |
+
self.print_intermediary_info = False
|
42 |
+
|
43 |
+
self.config = OmegaConf.load(model_config_path)
|
44 |
+
self.previous_turn = {"user": [], "bspn": [], "aspn": [], "db": []}
|
45 |
+
|
46 |
+
# NB: best to use corpus goals to guide interactions - baselines/simulate_agent.py allows that.
|
47 |
+
|
48 |
+
# initialize multiwoz reader and db_ops
|
49 |
+
self.reader = MultiWozReader(self.tokenizer)
|
50 |
+
self.db = MultiWozDB(self.config.dbs_path)
|
51 |
+
|
52 |
+
def lexicalize_sys_response(self, sys_response, domain_hits, decoded_belief_state_subseq) -> str:
|
53 |
+
lexicalized_sys_response = ""
|
54 |
+
|
55 |
+
# Track entities already filled e.g. if there are 3 restaurants track which have already been added to a slot
|
56 |
+
max_idx_of_added_entities = -1
|
57 |
+
|
58 |
+
# Fill slots with values from the DB (lexicalization)
|
59 |
+
for token in sys_response.split():
|
60 |
+
token = token.strip(" .,;:")
|
61 |
+
if token.startswith("["): # It is a slot to be filled
|
62 |
+
|
63 |
+
# Note in hotel there is specific price data too but to simplify things
|
64 |
+
# we just use the price range (e.g. moderate)
|
65 |
+
# TODO: there are different uses of price in different databases ('price' vs 'pricerange':
|
66 |
+
# need to deal with this appropriately below)
|
67 |
+
slots_to_db_keys_map = {
|
68 |
+
"[value_price]": "price",
|
69 |
+
"[value_pricerange]": "pricerange",
|
70 |
+
"[value_food]": "food",
|
71 |
+
"[value_area]": "area",
|
72 |
+
"[value_type]": "type",
|
73 |
+
"[value_phone]": "phone",
|
74 |
+
"[value_address]": "address",
|
75 |
+
"[value_leave]": "leave",
|
76 |
+
"[value_postcode]": "postcode",
|
77 |
+
"[value_id]": "id",
|
78 |
+
"[value_arrive]": "arrive",
|
79 |
+
"[value_stars]": "stars",
|
80 |
+
"[value_day]": "day",
|
81 |
+
"[value_destination]": "destination",
|
82 |
+
"[value_car]": "taxi_types",
|
83 |
+
"[value_departure]": "departure",
|
84 |
+
"[value_people]": "people",
|
85 |
+
"[value_stay]": "stay",
|
86 |
+
"[value_department]": "department",
|
87 |
+
"[value_time]": "time",
|
88 |
+
"[value_name]": "name",
|
89 |
+
"[value_reference]": "reference",
|
90 |
+
}
|
91 |
+
# Hospital domain is a strange outlier data structure
|
92 |
+
if self.turn_domain == ["hospital"] and token == "[value_address]":
|
93 |
+
token = "1 Addenbrooks Street"
|
94 |
+
elif self.turn_domain == ["hospital"] and token == "[value_postcode]":
|
95 |
+
token = "CB11QD"
|
96 |
+
|
97 |
+
# So does taxi
|
98 |
+
elif self.turn_domain == ["taxi"] and token == "[value_phone]" and domain_hits != []:
|
99 |
+
token = domain_hits[0]["taxi_phone"]
|
100 |
+
|
101 |
+
# Deal with value_name differently because there can be multiple
|
102 |
+
elif token == "[value_name]" and domain_hits != []:
|
103 |
+
token = domain_hits[max_idx_of_added_entities + 1]["name"]
|
104 |
+
max_idx_of_added_entities += 1
|
105 |
+
|
106 |
+
# This slot tells the user how many db hits there were matching their constraints
|
107 |
+
elif token == "[value_choice]" and domain_hits != []:
|
108 |
+
token = len(domain_hits)
|
109 |
+
|
110 |
+
# Randomly generate the reference
|
111 |
+
elif token == "[value_reference]" and domain_hits != []:
|
112 |
+
token = "".join(random.choices(string.ascii_uppercase, k=10))
|
113 |
+
|
114 |
+
else:
|
115 |
+
# First check can we fill the token from the db results
|
116 |
+
db_success = False
|
117 |
+
if domain_hits != []:
|
118 |
+
for slot, db_key in slots_to_db_keys_map.items():
|
119 |
+
if token == slot and db_key in domain_hits[0]:
|
120 |
+
token = domain_hits[0][db_key]
|
121 |
+
db_success = True
|
122 |
+
|
123 |
+
# If we cannot, then try to fill it from the belief state by looking for a match
|
124 |
+
# in the belief state and then if there is a match adding the next token.
|
125 |
+
# This is not perfect as some are more than one word but its probably good enough.
|
126 |
+
if not db_success:
|
127 |
+
decoded_belief_states = decoded_belief_state_subseq.split()
|
128 |
+
for idx, belief_state_slot in enumerate(decoded_belief_states):
|
129 |
+
if token in slots_to_db_keys_map.keys():
|
130 |
+
if slots_to_db_keys_map[token] == belief_state_slot:
|
131 |
+
token == decoded_belief_states[idx + 1]
|
132 |
+
|
133 |
+
# Otherwise just leave the slot as it is as we have failed to fill it
|
134 |
+
|
135 |
+
lexicalized_sys_response += str(token)
|
136 |
+
lexicalized_sys_response += " "
|
137 |
+
|
138 |
+
return lexicalized_sys_response
|
139 |
+
|
140 |
+
def set_turn_domain(self, belief_span_ids_subseq, sys_act_span_ids_subseq=None) -> None:
|
141 |
+
"""
|
142 |
+
IMPORTANT: use_system_act is not None when actually querying the DB to
|
143 |
+
lexicalise the system response. When it is None the Belief state NOT the system act is used to determine
|
144 |
+
the domain. In self.response() the DB is queried twice. The first time is using the Belief state as the system
|
145 |
+
act has not yet been generated, and it is only used to find out if there are matches in the DB for the current
|
146 |
+
domain + constraints. Then, after the system act is generated, we call the DB to actually get the results to
|
147 |
+
lexicalise the system response. It is much more important that the domain is correct for the second call, and
|
148 |
+
the system act is much more accurate at determining the domain.
|
149 |
+
"""
|
150 |
+
|
151 |
+
if sys_act_span_ids_subseq is None:
|
152 |
+
decoded_belief_state_subseq = self.tokenizer.decode(belief_span_ids_subseq[1:-1])
|
153 |
+
decoded_prev_belief_state_subseq = self.tokenizer.decode(self.previous_turn["bspn"][1:-1])
|
154 |
+
|
155 |
+
# If it is the first turn and the belief state is empty then set the domain to general
|
156 |
+
if self.previous_turn["bspn"] == [] and len(belief_span_ids_subseq) == 2:
|
157 |
+
self.turn_domain = ["general"]
|
158 |
+
return
|
159 |
+
|
160 |
+
# If the belief state doesn't change then keep the same domain
|
161 |
+
if belief_span_ids_subseq == self.previous_turn["bspn"]:
|
162 |
+
return
|
163 |
+
|
164 |
+
# The domain has changed, get the new one (from the right)
|
165 |
+
else:
|
166 |
+
# remove substring from string
|
167 |
+
if decoded_prev_belief_state_subseq in decoded_belief_state_subseq:
|
168 |
+
decoded_new_tokens = decoded_belief_state_subseq.replace("decoded_prev_belief_state_subseq", "")
|
169 |
+
most_recent_domain_in_belief_state = [
|
170 |
+
[token.strip("[]") for token in decoded_new_tokens.split() if token.startswith("[")][-1]
|
171 |
+
]
|
172 |
+
self.turn_domain = most_recent_domain_in_belief_state
|
173 |
+
else:
|
174 |
+
# Sometimes the previous belief state is not in the current belief state as
|
175 |
+
# the output changes very slightly (say by one word) - in this case just keep the same domain
|
176 |
+
# TODO: Could probably handle this better.
|
177 |
+
if self.print_intermediary_info:
|
178 |
+
print(
|
179 |
+
bcolors.YELLOW
|
180 |
+
+ "!Previous belief state not in current belief state! Details below:"
|
181 |
+
+ bcolors.ENDC
|
182 |
+
)
|
183 |
+
print("Previous Belief State: " + decoded_prev_belief_state_subseq)
|
184 |
+
print("Current Belief State: " + decoded_belief_state_subseq)
|
185 |
+
|
186 |
+
else:
|
187 |
+
decoded_sys_act_subseq = self.tokenizer.decode(sys_act_span_ids_subseq[1:-1])
|
188 |
+
|
189 |
+
most_recent_domain_in_sys_act = [
|
190 |
+
[token.strip("[]") for token in decoded_sys_act_subseq.split() if token.startswith("[")][0]
|
191 |
+
]
|
192 |
+
self.turn_domain = most_recent_domain_in_sys_act
|
193 |
+
|
194 |
+
def get_domain_hits(self, decoded_belief_state_subseq) -> dict:
|
195 |
+
# Get hits from db based on belief state, unless its a general turn (no hits then)
|
196 |
+
constraint_dict = self.reader.bspan_to_constraint_dict(decoded_belief_state_subseq)
|
197 |
+
query_turn_domain = self.turn_domain[0] # db.queryJsons needs a string not a list (single domain)
|
198 |
+
# If the constraint dict doesn't contain any constraints for the current domain then pass an empty dict
|
199 |
+
if query_turn_domain in constraint_dict:
|
200 |
+
domain_hits = self.db.queryJsons(query_turn_domain, constraint_dict[query_turn_domain])
|
201 |
+
else:
|
202 |
+
domain_hits = self.db.queryJsons(query_turn_domain, {})
|
203 |
+
|
204 |
+
return domain_hits
|
205 |
+
|
206 |
+
def print_turn_intermediate_info(self, generated_subseq_ids_map) -> None:
|
207 |
+
print(bcolors.OKCYAN + "Turn domain: " + bcolors.ENDC + "[" + str(self.turn_domain[0]) + "]")
|
208 |
+
|
209 |
+
belief_state = self.tokenizer.decode(generated_subseq_ids_map["bspn"])
|
210 |
+
print(bcolors.OKCYAN + "Belief state: " + bcolors.ENDC + belief_state)
|
211 |
+
|
212 |
+
db_output = self.tokenizer.decode(generated_subseq_ids_map["db"])
|
213 |
+
print(bcolors.OKCYAN + "DB Output: " + bcolors.ENDC + db_output)
|
214 |
+
|
215 |
+
sys_act = self.tokenizer.decode(generated_subseq_ids_map["aspn"])
|
216 |
+
print(bcolors.OKCYAN + "System Act: " + bcolors.ENDC + sys_act)
|
217 |
+
|
218 |
+
def _init_ubar_status(self) -> dict:
|
219 |
+
return {"dialogue_terminate": False}
|
220 |
+
|
221 |
+
def init_session(self):
|
222 |
+
self.ubar_status = self._init_ubar_status()
|
223 |
+
self.previous_turn = {"user": [], "bspn": [], "aspn": [], "db": []}
|
224 |
+
self.turn_domain = ["general"]
|
225 |
+
|
226 |
+
def is_terminated(self) -> bool:
|
227 |
+
"""This should tell an external client whether the user model considers they have completed the task."""
|
228 |
+
# return False
|
229 |
+
return self.ubar_status["dialogue_terminate"]
|
230 |
+
|
231 |
+
def _activate_dialogue_terminate(self) -> None:
|
232 |
+
"""Turn on the ubar status about dialogue termination"""
|
233 |
+
self.ubar_status["dialogue_terminate"] = True
|
234 |
+
|
235 |
+
def add_torch_input_eval(self, inputs):
|
236 |
+
# inputs: context
|
237 |
+
inputs["context_tensor"] = torch.tensor([inputs["context"]]).to(self.device)
|
238 |
+
return inputs
|
239 |
+
|
240 |
+
def prepare_input_for_model(self, user_utterance: str, turn_id: int) -> torch.Tensor:
|
241 |
+
# TODO: CONVERT DIALOGUE HISTORY TO TOKEN IDS
|
242 |
+
|
243 |
+
tokenised_user_utterance = self.tokenizer.encode("<sos_u> " + user_utterance + " <eos_u>")
|
244 |
+
# In this application turn always only contains ["user"], not ["bspn", "aspn", "db"] etc.
|
245 |
+
turn = {"user": tokenised_user_utterance}
|
246 |
+
|
247 |
+
first_turn = turn_id == 0
|
248 |
+
inputs = self.reader.convert_turn_eval(turn, self.previous_turn, first_turn)
|
249 |
+
inputs = self.add_torch_input_eval(inputs)
|
250 |
+
|
251 |
+
return inputs
|
252 |
+
|
253 |
+
def decode_generated_bspn(self, generated) -> List[int]:
|
254 |
+
eos_b_id = self.tokenizer.encode(["<eos_b>"])[0]
|
255 |
+
if eos_b_id in generated:
|
256 |
+
eos_b_idx = generated.index(eos_b_id)
|
257 |
+
else:
|
258 |
+
eos_b_idx = len(generated) - 1
|
259 |
+
return generated[: eos_b_idx + 1]
|
260 |
+
|
261 |
+
def decode_grenerated_act_resp(self, generated) -> dict:
|
262 |
+
"""
|
263 |
+
decode generated
|
264 |
+
return decoded['resp'] ('bspn', 'aspn')
|
265 |
+
"""
|
266 |
+
decoded = {}
|
267 |
+
eos_a_id = self.tokenizer.encode(["<eos_a>"])[0]
|
268 |
+
eos_r_id = self.tokenizer.encode(["<eos_r>"])[0]
|
269 |
+
# eos_b_id = self.tokenizer.encode(["<eos_b>"])[0]
|
270 |
+
|
271 |
+
# eos_r may not exists if gpt2 generated repetitive words.
|
272 |
+
if eos_r_id in generated:
|
273 |
+
eos_r_idx = generated.index(eos_r_id)
|
274 |
+
else:
|
275 |
+
eos_r_idx = len(generated) - 1
|
276 |
+
|
277 |
+
if cfg.use_true_curr_aspn: # only predict resp
|
278 |
+
decoded["resp"] = generated[: eos_r_idx + 1]
|
279 |
+
else: # predicted aspn, resp
|
280 |
+
eos_a_idx = generated.index(eos_a_id)
|
281 |
+
decoded["aspn"] = generated[: eos_a_idx + 1]
|
282 |
+
decoded["resp"] = generated[eos_a_idx + 1 : eos_r_idx + 1]
|
283 |
+
return decoded
|
284 |
+
|
285 |
+
def generate_ids_subseq_map(self, inputs):
|
286 |
+
|
287 |
+
context_input_subseq = inputs["context"]
|
288 |
+
# decoded_context_input_subseq = self.tokenizer.decode(context_input_subseq)
|
289 |
+
# Check if model has put duplicate tags in the context and if so remove one of the duplicates
|
290 |
+
# Yes this is kind of hacky, but UBAR seems to learn to duplicate certain tags - I don't know why
|
291 |
+
# Also instead of decoding and encoding here tags could be checked with their ids - but time is short...
|
292 |
+
# cleaned_decoded_list = []
|
293 |
+
# prev_token = ""
|
294 |
+
# for token in decoded_context_input_subseq.split():
|
295 |
+
# if token.startswith("<") and token.endswith(">"): # It is a tag
|
296 |
+
# if token == prev_token: # It is a duplicate tag
|
297 |
+
# continue
|
298 |
+
# cleaned_decoded_list.append(token)
|
299 |
+
# prev_token = token
|
300 |
+
# decoded_context_input_subseq = " ".join(cleaned_decoded_list)
|
301 |
+
# context_input_subseq = self.tokenizer.encode(decoded_context_input_subseq)
|
302 |
+
|
303 |
+
context_input_subeq_tensor = inputs["context_tensor"]
|
304 |
+
|
305 |
+
# TODO: FIND OUT BY COMPARING WITH MODEL.VALIDATE() how to calculate context_length
|
306 |
+
context_length = len(context_input_subseq)
|
307 |
+
|
308 |
+
belief_state_ids = self.model.generate(
|
309 |
+
input_ids=context_input_subeq_tensor,
|
310 |
+
max_length=context_length + 60,
|
311 |
+
temperature=0.7,
|
312 |
+
top_p=1,
|
313 |
+
num_beams=1,
|
314 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
315 |
+
eos_token_id=self.tokenizer.encode(["<eos_b>"])[0],
|
316 |
+
)
|
317 |
+
gen_belief_state_token_ids = belief_state_ids[0].cpu().numpy().tolist() # type: list[int]
|
318 |
+
belief_span_ids_subseq = self.decode_generated_bspn(
|
319 |
+
gen_belief_state_token_ids[context_length - 1 :]
|
320 |
+
) # type: list[int]
|
321 |
+
|
322 |
+
self.set_turn_domain(belief_span_ids_subseq)
|
323 |
+
|
324 |
+
db_result = self.reader.bspan_to_DBpointer(
|
325 |
+
self.tokenizer.decode(belief_span_ids_subseq), self.turn_domain
|
326 |
+
) # type: str
|
327 |
+
db_ids_subseq = self.tokenizer.convert_tokens_to_ids(
|
328 |
+
self.tokenizer.tokenize("<sos_db> " + db_result + " <eos_db>")
|
329 |
+
) + self.tokenizer.encode(["<sos_a>"])
|
330 |
+
|
331 |
+
# TODO: context_input_subseq is already a tensor but the other two subseqs aren't - why?
|
332 |
+
act_response_gen_input_subseq = context_input_subseq + belief_span_ids_subseq + db_ids_subseq
|
333 |
+
act_response_gen_input_subseq_tensor = torch.tensor([act_response_gen_input_subseq]).to(self.device)
|
334 |
+
context_length = len(act_response_gen_input_subseq)
|
335 |
+
|
336 |
+
outputs_db = self.model.generate(
|
337 |
+
input_ids=act_response_gen_input_subseq_tensor,
|
338 |
+
max_length=context_length + 80,
|
339 |
+
temperature=0.7,
|
340 |
+
top_p=1,
|
341 |
+
num_beams=1,
|
342 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
343 |
+
eos_token_id=self.tokenizer.encode(["<eos_r>"])[0],
|
344 |
+
)
|
345 |
+
generated_act_resp_token_ids = outputs_db[0].cpu().numpy().tolist() # type: list[int]
|
346 |
+
generated_act_resp_token_ids = generated_act_resp_token_ids[context_length - 1 :]
|
347 |
+
|
348 |
+
try:
|
349 |
+
generated_subseq_ids_map = self.decode_grenerated_act_resp(generated_act_resp_token_ids)
|
350 |
+
# TODO: IF YOU WANT Option b) then you just read the ['resp'] key and convert to string using huggingface;
|
351 |
+
# that would be sys_response; Obviously, this applies to Option a as well
|
352 |
+
generated_subseq_ids_map["bspn"] = belief_span_ids_subseq
|
353 |
+
# TODO: Option a) STORE THESE MAPPINGS IN SELF.CONTEXT IF YOU WANT TO HAVE
|
354 |
+
# {U_1, BS_1, DB_1, A_1, R_1, U_2, BS_2... history}
|
355 |
+
|
356 |
+
generated_subseq_ids_map["db"] = db_ids_subseq
|
357 |
+
generated_subseq_ids_map["labels"] = context_input_subseq
|
358 |
+
|
359 |
+
except ValueError:
|
360 |
+
generated_subseq_ids_map = {"resp": [], "bspn": [], "aspn": [], "db": [], "labels": []}
|
361 |
+
|
362 |
+
# IMPORTANT: this is how all of the previous state is updated (appended) after each turn
|
363 |
+
# Update self.previous_turn to track state to be fed into GPT2
|
364 |
+
for k, v in generated_subseq_ids_map.items():
|
365 |
+
self.previous_turn[k] = v
|
366 |
+
|
367 |
+
if self.print_intermediary_info:
|
368 |
+
self.print_turn_intermediate_info(generated_subseq_ids_map)
|
369 |
+
|
370 |
+
return generated_subseq_ids_map
|
371 |
+
|
372 |
+
def response(self, usr_utterance: str, turn_id: int) -> str:
|
373 |
+
|
374 |
+
if usr_utterance == "Goodbye":
|
375 |
+
self._activate_dialogue_terminate()
|
376 |
+
return "Session Terminated by User"
|
377 |
+
|
378 |
+
inputs = self.prepare_input_for_model(usr_utterance, turn_id)
|
379 |
+
|
380 |
+
generated_subseq_ids_map = self.generate_ids_subseq_map(inputs)
|
381 |
+
belief_span_ids_subseq = generated_subseq_ids_map["bspn"]
|
382 |
+
|
383 |
+
sys_response = self.tokenizer.decode(generated_subseq_ids_map["resp"][1:-1])
|
384 |
+
|
385 |
+
prev_turn_domain = self.turn_domain
|
386 |
+
sys_act_span_ids_subseq = generated_subseq_ids_map["aspn"]
|
387 |
+
self.set_turn_domain(belief_span_ids_subseq, sys_act_span_ids_subseq)
|
388 |
+
|
389 |
+
if self.turn_domain != ["general"]:
|
390 |
+
# If the domain changes when reading the system response, then we need to re-do the generation process
|
391 |
+
# for both the belief state and the system action and response. We do this because self.get_domain_hits()
|
392 |
+
# will break if the domain is different when querying the DB for the second time here than when it was
|
393 |
+
# originally queried above, due to the constraint dict it uses that is generated from the belief state
|
394 |
+
# How can the belief state domain and the system act domain be different? Bunch of things, for example:
|
395 |
+
# When asking for the police the belief state may be empty (so 'general' domain)
|
396 |
+
# but then the system action will have [police].
|
397 |
+
if prev_turn_domain != self.turn_domain:
|
398 |
+
if self.print_intermediary_info:
|
399 |
+
print(
|
400 |
+
bcolors.RED
|
401 |
+
+ "Domain changed from {} to {}".format(prev_turn_domain, self.turn_domain)
|
402 |
+
+ bcolors.RED
|
403 |
+
)
|
404 |
+
generated_subseq_ids_map = self.generate_ids_subseq_map(inputs)
|
405 |
+
sys_response = self.tokenizer.decode(generated_subseq_ids_map["resp"][1:-1])
|
406 |
+
|
407 |
+
decoded_belief_state_subseq = self.tokenizer.decode(belief_span_ids_subseq)
|
408 |
+
domain_hits = self.get_domain_hits(decoded_belief_state_subseq)
|
409 |
+
# print(bcolors.UNDERLINE + "Domain hits: \n" + bcolors.ENDC, domain_hits) # for debugging
|
410 |
+
|
411 |
+
sys_response = self.lexicalize_sys_response(sys_response, domain_hits, decoded_belief_state_subseq)
|
412 |
+
|
413 |
+
return sys_response
|
414 |
+
|
415 |
+
|
416 |
+
def interact(checkpoint_path):
|
417 |
+
sys_model = UbarSystemModel("UBAR_sys_model", checkpoint_path, "scripts/UBAR_code/interaction/config.yaml")
|
418 |
+
# TODO: Fix this hardcoded variable (should be in config)
|
419 |
+
sys_model.print_intermediary_info = True
|
420 |
+
|
421 |
+
for dial_id in range(1, 11):
|
422 |
+
print(f"In dialogue {dial_id}")
|
423 |
+
|
424 |
+
# Reset state after each dialog
|
425 |
+
sys_model.init_session()
|
426 |
+
|
427 |
+
user_utt = input(bcolors.GREEN + "Enter user response here: " + bcolors.ENDC)
|
428 |
+
|
429 |
+
for turn_id in range(100):
|
430 |
+
try:
|
431 |
+
sys_response = sys_model.response(user_utt, turn_id)
|
432 |
+
# There are a lot of edge case bugs that are possible that could break the current turn. If so, continue
|
433 |
+
# to ensure a large run across the dataset isn't ruined by a single bad turn.
|
434 |
+
except Exception() as e:
|
435 |
+
print(bcolors.RED + "Exception: {}".format(e) + bcolors.ENDC)
|
436 |
+
continue
|
437 |
+
|
438 |
+
if sys_model.is_terminated():
|
439 |
+
print(bcolors.RED + sys_response + bcolors.ENDC)
|
440 |
+
print(bcolors.RED + "---" * 30 + bcolors.ENDC)
|
441 |
+
break
|
442 |
+
|
443 |
+
print(bcolors.YELLOW + "System: " + bcolors.ENDC + sys_response)
|
444 |
+
print("---" * 30)
|
445 |
+
|
446 |
+
# next turn materials
|
447 |
+
user_utt = input(bcolors.GREEN + "Enter user response here: " + bcolors.ENDC)
|
448 |
+
|
449 |
+
|
450 |
+
if __name__ == "__main__":
|
451 |
+
if len(sys.argv) == 1:
|
452 |
+
print("Wrong argument!")
|
453 |
+
print("Usage: python UBAR_interact.py checkpoint_path")
|
454 |
+
sys.exit(1)
|
455 |
+
|
456 |
+
checkpoint_path = sys.argv[1]
|
457 |
+
interact(checkpoint_path)
|
scripts/UBAR_code/interaction/__init__.py
ADDED
File without changes
|
scripts/UBAR_code/interaction/config.yaml
ADDED
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
path: "./models/UBAR/experiments/distilgpt-2_sd11_lr0.0001_bs16_ga2/epoch50_trloss0.59_gpt2"
|
3 |
+
goal_update:
|
4 |
+
finish_inform: "loose" # loose or strict
|
5 |
+
|
6 |
+
schema_path: "scripts/user_model_code/interaction/schema.json"
|
7 |
+
|
8 |
+
decode:
|
9 |
+
dec_max_len: 1024
|
10 |
+
num_beams: 1
|
11 |
+
temperature: 1.0
|
12 |
+
do_sample: False
|
13 |
+
|
14 |
+
use_all_previous_context: False
|
15 |
+
|
16 |
+
dbs_path:
|
17 |
+
"attraction": "data/preprocessed/UBAR/db_processed/attraction_db_processed.json"
|
18 |
+
"hospital": "data/preprocessed/UBAR/db_processed/hospital_db_processed.json"
|
19 |
+
"hotel": "data/preprocessed/UBAR/db_processed/hotel_db_processed.json"
|
20 |
+
"police": "data/preprocessed/UBAR/db_processed/police_db_processed.json"
|
21 |
+
"restaurant": "data/preprocessed/UBAR/db_processed/restaurant_db_processed.json"
|
22 |
+
"taxi": "data/preprocessed/UBAR/db_processed/taxi_db_processed.json"
|
23 |
+
"train": "data/preprocessed/UBAR/db_processed/train_db_processed.json"
|
scripts/UBAR_code/preprocess.py
ADDED
@@ -0,0 +1,576 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import zipfile
|
6 |
+
from collections import OrderedDict
|
7 |
+
|
8 |
+
import spacy
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from crazyneuraluser.UBAR_code import ontology, utils
|
12 |
+
from crazyneuraluser.UBAR_code.clean_dataset import clean_slot_values, clean_text
|
13 |
+
from crazyneuraluser.UBAR_code.config import global_config as cfg
|
14 |
+
from crazyneuraluser.UBAR_code.db_ops import MultiWozDB
|
15 |
+
|
16 |
+
|
17 |
+
# value_set.json, all the domain[slot] values in datasets
|
18 |
+
def get_db_values(value_set_path):
|
19 |
+
processed = {}
|
20 |
+
bspn_word = []
|
21 |
+
nlp = spacy.load("en_core_web_sm")
|
22 |
+
|
23 |
+
with open(value_set_path, "r") as f: # read value set file in lower
|
24 |
+
value_set = json.loads(f.read().lower())
|
25 |
+
|
26 |
+
with open("data/raw/UBAR/db/ontology.json", "r") as f: # read ontology in lower, all the domain-slot values
|
27 |
+
otlg = json.loads(f.read().lower())
|
28 |
+
|
29 |
+
for (
|
30 |
+
domain,
|
31 |
+
slots,
|
32 |
+
) in value_set.items(): # add all informable slots to bspn_word, create lists holder for values
|
33 |
+
processed[domain] = {}
|
34 |
+
bspn_word.append("[" + domain + "]")
|
35 |
+
for slot, values in slots.items():
|
36 |
+
s_p = ontology.normlize_slot_names.get(slot, slot)
|
37 |
+
if s_p in ontology.informable_slots[domain]:
|
38 |
+
bspn_word.append(s_p)
|
39 |
+
processed[domain][s_p] = []
|
40 |
+
|
41 |
+
for (
|
42 |
+
domain,
|
43 |
+
slots,
|
44 |
+
) in value_set.items(): # add all words of values of informable slots to bspn_word
|
45 |
+
for slot, values in slots.items():
|
46 |
+
s_p = ontology.normlize_slot_names.get(slot, slot)
|
47 |
+
if s_p in ontology.informable_slots[domain]:
|
48 |
+
for v in values:
|
49 |
+
_, v_p = clean_slot_values(domain, slot, v)
|
50 |
+
v_p = " ".join([token.text for token in nlp(v_p)]).strip()
|
51 |
+
processed[domain][s_p].append(v_p)
|
52 |
+
for x in v_p.split():
|
53 |
+
if x not in bspn_word:
|
54 |
+
bspn_word.append(x)
|
55 |
+
|
56 |
+
for domain_slot, values in otlg.items(): # split domain-slots to domains and slots
|
57 |
+
domain, slot = domain_slot.split("-")
|
58 |
+
if domain == "bus":
|
59 |
+
domain = "taxi"
|
60 |
+
if slot == "price range":
|
61 |
+
slot = "pricerange"
|
62 |
+
if slot == "book stay":
|
63 |
+
slot = "stay"
|
64 |
+
if slot == "book day":
|
65 |
+
slot = "day"
|
66 |
+
if slot == "book people":
|
67 |
+
slot = "people"
|
68 |
+
if slot == "book time":
|
69 |
+
slot = "time"
|
70 |
+
if slot == "arrive by":
|
71 |
+
slot = "arrive"
|
72 |
+
if slot == "leave at":
|
73 |
+
slot = "leave"
|
74 |
+
if slot == "leaveat":
|
75 |
+
slot = "leave"
|
76 |
+
# add all slots and words of values if not already in processed and bspn_word
|
77 |
+
if slot not in processed[domain]:
|
78 |
+
processed[domain][slot] = []
|
79 |
+
bspn_word.append(slot)
|
80 |
+
for v in values:
|
81 |
+
_, v_p = clean_slot_values(domain, slot, v)
|
82 |
+
v_p = " ".join([token.text for token in nlp(v_p)]).strip()
|
83 |
+
if v_p not in processed[domain][slot]:
|
84 |
+
processed[domain][slot].append(v_p)
|
85 |
+
for x in v_p.split():
|
86 |
+
if x not in bspn_word:
|
87 |
+
bspn_word.append(x)
|
88 |
+
|
89 |
+
with open(value_set_path.replace(".json", "_processed.json"), "w") as f:
|
90 |
+
json.dump(processed, f, indent=2) # save processed.json
|
91 |
+
with open("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/bspn_word_collection.json", "w") as f:
|
92 |
+
json.dump(bspn_word, f, indent=2) # save bspn_word
|
93 |
+
|
94 |
+
print("DB value set processed! ")
|
95 |
+
|
96 |
+
|
97 |
+
def preprocess_db(db_paths): # apply clean_slot_values to all dbs
|
98 |
+
dbs = {}
|
99 |
+
nlp = spacy.load("en_core_web_sm")
|
100 |
+
for domain in ontology.all_domains:
|
101 |
+
with open(db_paths[domain], "r") as f: # for every db_domain, read json file
|
102 |
+
dbs[domain] = json.loads(f.read().lower())
|
103 |
+
# entry has information about slots of said domain
|
104 |
+
for idx, entry in enumerate(dbs[domain]):
|
105 |
+
new_entry = copy.deepcopy(entry)
|
106 |
+
for key, value in entry.items(): # key = slot
|
107 |
+
if type(value) is not str:
|
108 |
+
continue
|
109 |
+
del new_entry[key]
|
110 |
+
key, value = clean_slot_values(domain, key, value)
|
111 |
+
tokenize_and_back = " ".join([token.text for token in nlp(value)]).strip()
|
112 |
+
new_entry[key] = tokenize_and_back
|
113 |
+
dbs[domain][idx] = new_entry
|
114 |
+
with open(db_paths[domain].replace(".json", "_processed.json"), "w") as f:
|
115 |
+
json.dump(dbs[domain], f, indent=2)
|
116 |
+
print("[%s] DB processed! " % domain)
|
117 |
+
|
118 |
+
|
119 |
+
class DataPreprocessor(object):
|
120 |
+
def __init__(self):
|
121 |
+
self.nlp = spacy.load("en_core_web_sm")
|
122 |
+
self.db = MultiWozDB(cfg.dbs) # load all processed dbs
|
123 |
+
data_path = "data/preprocessed/UBAR/gen_usr_utt_experiment_data_with_span_full.json"
|
124 |
+
# archive = zipfile.ZipFile(data_path + ".zip", "r")
|
125 |
+
# self.convlab_data = json.loads(archive.open(data_path.split("/")[-1], "r").read().lower())
|
126 |
+
self.convlab_data = json.loads(open(data_path, "r").read().lower())
|
127 |
+
self.delex_sg_valdict_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/delex_single_valdict.json"
|
128 |
+
self.delex_mt_valdict_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/delex_multi_valdict.json"
|
129 |
+
self.ambiguous_val_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/ambiguous_values.json"
|
130 |
+
self.delex_refs_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/reference_no.json"
|
131 |
+
self.delex_refs = json.loads(open(self.delex_refs_path, "r").read())
|
132 |
+
if not os.path.exists(self.delex_sg_valdict_path):
|
133 |
+
(
|
134 |
+
self.delex_sg_valdict,
|
135 |
+
self.delex_mt_valdict,
|
136 |
+
self.ambiguous_vals,
|
137 |
+
) = self.get_delex_valdict()
|
138 |
+
else:
|
139 |
+
self.delex_sg_valdict = json.loads(open(self.delex_sg_valdict_path, "r").read())
|
140 |
+
self.delex_mt_valdict = json.loads(open(self.delex_mt_valdict_path, "r").read())
|
141 |
+
self.ambiguous_vals = json.loads(open(self.ambiguous_val_path, "r").read())
|
142 |
+
|
143 |
+
self.vocab = utils.Vocab(cfg.vocab_size)
|
144 |
+
|
145 |
+
def delex_by_annotation(self, dial_turn):
|
146 |
+
u = dial_turn["text"].split()
|
147 |
+
span = dial_turn["span_info"]
|
148 |
+
for s in span:
|
149 |
+
slot = s[1]
|
150 |
+
if slot == "open":
|
151 |
+
continue
|
152 |
+
if ontology.da_abbr_to_slot_name.get(slot):
|
153 |
+
slot = ontology.da_abbr_to_slot_name[slot]
|
154 |
+
for idx in range(s[3], s[4] + 1):
|
155 |
+
u[idx] = ""
|
156 |
+
try:
|
157 |
+
u[s[3]] = "[value_" + slot + "]"
|
158 |
+
except Exception:
|
159 |
+
u[5] = "[value_" + slot + "]"
|
160 |
+
u_delex = " ".join([t for t in u if t != ""])
|
161 |
+
u_delex = u_delex.replace("[value_address] , [value_address] , [value_address]", "[value_address]")
|
162 |
+
u_delex = u_delex.replace("[value_address] , [value_address]", "[value_address]")
|
163 |
+
u_delex = u_delex.replace("[value_name] [value_name]", "[value_name]")
|
164 |
+
u_delex = u_delex.replace("[value_name]([value_phone] )", "[value_name] ( [value_phone] )")
|
165 |
+
return u_delex
|
166 |
+
|
167 |
+
def delex_by_valdict(self, text):
|
168 |
+
text = clean_text(text)
|
169 |
+
|
170 |
+
text = re.sub(r"\d{5}\s?\d{5,7}", "[value_phone]", text)
|
171 |
+
text = re.sub(r"\d[\s-]stars?", "[value_stars]", text)
|
172 |
+
text = re.sub(r"\$\d+|\$?\d+.?(\d+)?\s(pounds?|gbps?)", "[value_price]", text)
|
173 |
+
text = re.sub(r"tr[\d]{4}", "[value_id]", text)
|
174 |
+
text = re.sub(
|
175 |
+
r"([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})",
|
176 |
+
"[value_postcode]",
|
177 |
+
text,
|
178 |
+
)
|
179 |
+
|
180 |
+
for value, slot in self.delex_mt_valdict.items():
|
181 |
+
text = text.replace(value, "[value_%s]" % slot)
|
182 |
+
|
183 |
+
for value, slot in self.delex_sg_valdict.items():
|
184 |
+
tokens = text.split()
|
185 |
+
for idx, tk in enumerate(tokens):
|
186 |
+
if tk == value:
|
187 |
+
tokens[idx] = "[value_%s]" % slot
|
188 |
+
text = " ".join(tokens)
|
189 |
+
|
190 |
+
for ambg_ent in self.ambiguous_vals:
|
191 |
+
# ely is a place, but appears in words like moderately
|
192 |
+
start_idx = text.find(" " + ambg_ent)
|
193 |
+
if start_idx == -1:
|
194 |
+
continue
|
195 |
+
front_words = text[:start_idx].split()
|
196 |
+
ent_type = "time" if ":" in ambg_ent else "place"
|
197 |
+
|
198 |
+
for fw in front_words[::-1]:
|
199 |
+
if fw in [
|
200 |
+
"arrive",
|
201 |
+
"arrives",
|
202 |
+
"arrived",
|
203 |
+
"arriving",
|
204 |
+
"arrival",
|
205 |
+
"destination",
|
206 |
+
"there",
|
207 |
+
"reach",
|
208 |
+
"to",
|
209 |
+
"by",
|
210 |
+
"before",
|
211 |
+
]:
|
212 |
+
slot = "[value_arrive]" if ent_type == "time" else "[value_destination]"
|
213 |
+
text = re.sub(" " + ambg_ent, " " + slot, text)
|
214 |
+
elif fw in [
|
215 |
+
"leave",
|
216 |
+
"leaves",
|
217 |
+
"leaving",
|
218 |
+
"depart",
|
219 |
+
"departs",
|
220 |
+
"departing",
|
221 |
+
"departure",
|
222 |
+
"from",
|
223 |
+
"after",
|
224 |
+
"pulls",
|
225 |
+
]:
|
226 |
+
slot = "[value_leave]" if ent_type == "time" else "[value_departure]"
|
227 |
+
text = re.sub(" " + ambg_ent, " " + slot, text)
|
228 |
+
|
229 |
+
text = text.replace("[value_car] [value_car]", "[value_car]")
|
230 |
+
return text
|
231 |
+
|
232 |
+
def get_delex_valdict(
|
233 |
+
self,
|
234 |
+
):
|
235 |
+
skip_entry_type = {
|
236 |
+
"taxi": ["taxi_phone"],
|
237 |
+
"police": ["id"],
|
238 |
+
"hospital": ["id"],
|
239 |
+
"hotel": [
|
240 |
+
"id",
|
241 |
+
"location",
|
242 |
+
"internet",
|
243 |
+
"parking",
|
244 |
+
"takesbookings",
|
245 |
+
"stars",
|
246 |
+
"price",
|
247 |
+
"n",
|
248 |
+
"postcode",
|
249 |
+
"phone",
|
250 |
+
],
|
251 |
+
"attraction": [
|
252 |
+
"id",
|
253 |
+
"location",
|
254 |
+
"pricerange",
|
255 |
+
"price",
|
256 |
+
"openhours",
|
257 |
+
"postcode",
|
258 |
+
"phone",
|
259 |
+
],
|
260 |
+
"train": ["price", "id"],
|
261 |
+
"restaurant": [
|
262 |
+
"id",
|
263 |
+
"location",
|
264 |
+
"introduction",
|
265 |
+
"signature",
|
266 |
+
"type",
|
267 |
+
"postcode",
|
268 |
+
"phone",
|
269 |
+
],
|
270 |
+
}
|
271 |
+
entity_value_to_slot = {}
|
272 |
+
ambiguous_entities = []
|
273 |
+
for domain, db_data in self.db.dbs.items():
|
274 |
+
print("Processing entity values in [%s]" % domain)
|
275 |
+
if domain != "taxi":
|
276 |
+
for db_entry in db_data:
|
277 |
+
for slot, value in db_entry.items():
|
278 |
+
if slot not in skip_entry_type[domain]:
|
279 |
+
if type(value) is not str:
|
280 |
+
raise TypeError("value '%s' in domain '%s' should be rechecked" % (slot, domain))
|
281 |
+
else:
|
282 |
+
slot, value = clean_slot_values(domain, slot, value)
|
283 |
+
value = " ".join([token.text for token in self.nlp(value)]).strip()
|
284 |
+
if value in entity_value_to_slot and entity_value_to_slot[value] != slot:
|
285 |
+
# print(value, ": ",entity_value_to_slot[value], slot)
|
286 |
+
ambiguous_entities.append(value)
|
287 |
+
entity_value_to_slot[value] = slot
|
288 |
+
else: # taxi db specific
|
289 |
+
db_entry = db_data[0]
|
290 |
+
for slot, ent_list in db_entry.items():
|
291 |
+
if slot not in skip_entry_type[domain]:
|
292 |
+
for ent in ent_list:
|
293 |
+
entity_value_to_slot[ent] = "car"
|
294 |
+
ambiguous_entities = set(ambiguous_entities)
|
295 |
+
ambiguous_entities.remove("cambridge")
|
296 |
+
ambiguous_entities = list(ambiguous_entities)
|
297 |
+
for amb_ent in ambiguous_entities: # departure or destination? arrive time or leave time?
|
298 |
+
entity_value_to_slot.pop(amb_ent)
|
299 |
+
entity_value_to_slot["parkside"] = "address"
|
300 |
+
entity_value_to_slot["parkside, cambridge"] = "address"
|
301 |
+
entity_value_to_slot["cambridge belfry"] = "name"
|
302 |
+
entity_value_to_slot["hills road"] = "address"
|
303 |
+
entity_value_to_slot["hills rd"] = "address"
|
304 |
+
entity_value_to_slot["Parkside Police Station"] = "name"
|
305 |
+
|
306 |
+
single_token_values = {}
|
307 |
+
multi_token_values = {}
|
308 |
+
for val, slt in entity_value_to_slot.items():
|
309 |
+
if val in ["cambridge"]:
|
310 |
+
continue
|
311 |
+
if len(val.split()) > 1:
|
312 |
+
multi_token_values[val] = slt
|
313 |
+
else:
|
314 |
+
single_token_values[val] = slt
|
315 |
+
|
316 |
+
with open(self.delex_sg_valdict_path, "w") as f:
|
317 |
+
single_token_values = OrderedDict(
|
318 |
+
sorted(single_token_values.items(), key=lambda kv: len(kv[0]), reverse=True)
|
319 |
+
)
|
320 |
+
json.dump(single_token_values, f, indent=2)
|
321 |
+
print("single delex value dict saved!")
|
322 |
+
with open(self.delex_mt_valdict_path, "w") as f:
|
323 |
+
multi_token_values = OrderedDict(
|
324 |
+
sorted(multi_token_values.items(), key=lambda kv: len(kv[0]), reverse=True)
|
325 |
+
)
|
326 |
+
json.dump(multi_token_values, f, indent=2)
|
327 |
+
print("multi delex value dict saved!")
|
328 |
+
with open(self.ambiguous_val_path, "w") as f:
|
329 |
+
json.dump(ambiguous_entities, f, indent=2)
|
330 |
+
print("ambiguous value dict saved!")
|
331 |
+
|
332 |
+
return single_token_values, multi_token_values, ambiguous_entities
|
333 |
+
|
334 |
+
def preprocess_main(self, save_path=None, is_test=False):
|
335 |
+
""" """
|
336 |
+
data = {}
|
337 |
+
count = 0
|
338 |
+
self.unique_da = {}
|
339 |
+
ordered_sysact_dict = {}
|
340 |
+
for fn, raw_dial in tqdm(list(self.convlab_data.items())):
|
341 |
+
count += 1
|
342 |
+
# if count == 100:
|
343 |
+
# break
|
344 |
+
|
345 |
+
compressed_goal = {} # for every dialog, keep track the goal, domains, requests
|
346 |
+
dial_domains, dial_reqs = [], []
|
347 |
+
for dom, g in raw_dial["goal"].items():
|
348 |
+
if dom != "topic" and dom != "message" and g:
|
349 |
+
if g.get("reqt"): # request info. eg. postcode/address/phone
|
350 |
+
# normalize request slots
|
351 |
+
for i, req_slot in enumerate(g["reqt"]):
|
352 |
+
if ontology.normlize_slot_names.get(req_slot):
|
353 |
+
g["reqt"][i] = ontology.normlize_slot_names[req_slot]
|
354 |
+
dial_reqs.append(g["reqt"][i])
|
355 |
+
compressed_goal[dom] = g
|
356 |
+
if dom in ontology.all_domains:
|
357 |
+
dial_domains.append(dom)
|
358 |
+
|
359 |
+
dial_reqs = list(set(dial_reqs))
|
360 |
+
|
361 |
+
dial = {"goal": compressed_goal, "log": []}
|
362 |
+
single_turn = {}
|
363 |
+
constraint_dict = OrderedDict()
|
364 |
+
prev_constraint_dict = {}
|
365 |
+
prev_turn_domain = ["general"]
|
366 |
+
ordered_sysact_dict[fn] = {}
|
367 |
+
|
368 |
+
for turn_num, dial_turn in enumerate(raw_dial["log"]):
|
369 |
+
# for user turn, have text
|
370 |
+
# sys turn: text, belief states(metadata), dialog_act, span_info
|
371 |
+
dial_state = dial_turn["metadata"]
|
372 |
+
if not dial_state: # user
|
373 |
+
# delexicalize user utterance, either by annotation or by val_dict
|
374 |
+
u = " ".join(clean_text(dial_turn["text"]).split())
|
375 |
+
|
376 |
+
# NOTE: Commenting out delexicalisation because it is not used and
|
377 |
+
# breaks when I use generated user dialogues for some reason
|
378 |
+
|
379 |
+
# if dial_turn["span_info"]:
|
380 |
+
# u_delex = clean_text(self.delex_by_annotation(dial_turn))
|
381 |
+
# else:
|
382 |
+
# u_delex = self.delex_by_valdict(dial_turn["text"])
|
383 |
+
|
384 |
+
single_turn["user"] = u
|
385 |
+
# single_turn["user_delex"] = u_delex
|
386 |
+
|
387 |
+
else: # system
|
388 |
+
# delexicalize system response, either by annotation or by val_dict
|
389 |
+
if dial_turn["span_info"]:
|
390 |
+
s_delex = clean_text(self.delex_by_annotation(dial_turn))
|
391 |
+
else:
|
392 |
+
if not dial_turn["text"]:
|
393 |
+
print(fn)
|
394 |
+
s_delex = self.delex_by_valdict(dial_turn["text"])
|
395 |
+
single_turn["resp"] = s_delex
|
396 |
+
|
397 |
+
# get belief state, semi=informable/book=requestable, put into constraint_dict
|
398 |
+
for domain in dial_domains:
|
399 |
+
if not constraint_dict.get(domain):
|
400 |
+
constraint_dict[domain] = OrderedDict()
|
401 |
+
info_sv = dial_state[domain]["semi"]
|
402 |
+
for s, v in info_sv.items():
|
403 |
+
s, v = clean_slot_values(domain, s, v)
|
404 |
+
if len(v.split()) > 1:
|
405 |
+
v = " ".join([token.text for token in self.nlp(v)]).strip()
|
406 |
+
if v != "":
|
407 |
+
constraint_dict[domain][s] = v
|
408 |
+
book_sv = dial_state[domain]["book"]
|
409 |
+
for s, v in book_sv.items():
|
410 |
+
if s == "booked":
|
411 |
+
continue
|
412 |
+
s, v = clean_slot_values(domain, s, v)
|
413 |
+
if len(v.split()) > 1:
|
414 |
+
v = " ".join([token.text for token in self.nlp(v)]).strip()
|
415 |
+
if v != "":
|
416 |
+
constraint_dict[domain][s] = v
|
417 |
+
|
418 |
+
constraints = [] # list in format of [domain] slot value
|
419 |
+
cons_delex = []
|
420 |
+
turn_dom_bs = []
|
421 |
+
for domain, info_slots in constraint_dict.items():
|
422 |
+
if info_slots:
|
423 |
+
constraints.append("[" + domain + "]")
|
424 |
+
cons_delex.append("[" + domain + "]")
|
425 |
+
for slot, value in info_slots.items():
|
426 |
+
constraints.append(slot)
|
427 |
+
constraints.extend(value.split())
|
428 |
+
cons_delex.append(slot)
|
429 |
+
if domain not in prev_constraint_dict:
|
430 |
+
turn_dom_bs.append(domain)
|
431 |
+
elif prev_constraint_dict[domain] != constraint_dict[domain]:
|
432 |
+
turn_dom_bs.append(domain)
|
433 |
+
|
434 |
+
sys_act_dict = {}
|
435 |
+
turn_dom_da = set()
|
436 |
+
for act in dial_turn["dialog_act"]:
|
437 |
+
d, a = act.split("-") # split domain-act
|
438 |
+
turn_dom_da.add(d)
|
439 |
+
turn_dom_da = list(turn_dom_da)
|
440 |
+
if len(turn_dom_da) != 1 and "general" in turn_dom_da:
|
441 |
+
turn_dom_da.remove("general")
|
442 |
+
if len(turn_dom_da) != 1 and "booking" in turn_dom_da:
|
443 |
+
turn_dom_da.remove("booking")
|
444 |
+
|
445 |
+
# get turn domain
|
446 |
+
turn_domain = turn_dom_bs
|
447 |
+
for dom in turn_dom_da:
|
448 |
+
if dom != "booking" and dom not in turn_domain:
|
449 |
+
turn_domain.append(dom)
|
450 |
+
if not turn_domain:
|
451 |
+
turn_domain = prev_turn_domain
|
452 |
+
if len(turn_domain) == 2 and "general" in turn_domain:
|
453 |
+
turn_domain.remove("general")
|
454 |
+
if len(turn_domain) == 2:
|
455 |
+
if len(prev_turn_domain) == 1 and prev_turn_domain[0] == turn_domain[1]:
|
456 |
+
turn_domain = turn_domain[::-1]
|
457 |
+
|
458 |
+
# get system action
|
459 |
+
for dom in turn_domain:
|
460 |
+
sys_act_dict[dom] = {}
|
461 |
+
add_to_last_collect = []
|
462 |
+
booking_act_map = {"inform": "offerbook", "book": "offerbooked"}
|
463 |
+
for act, params in dial_turn["dialog_act"].items():
|
464 |
+
if act == "general-greet":
|
465 |
+
continue
|
466 |
+
d, a = act.split("-")
|
467 |
+
if d == "general" and d not in sys_act_dict:
|
468 |
+
sys_act_dict[d] = {}
|
469 |
+
if d == "booking":
|
470 |
+
d = turn_domain[0]
|
471 |
+
a = booking_act_map.get(a, a)
|
472 |
+
add_p = []
|
473 |
+
for param in params:
|
474 |
+
p = param[0]
|
475 |
+
if p == "none":
|
476 |
+
continue
|
477 |
+
elif ontology.da_abbr_to_slot_name.get(p):
|
478 |
+
p = ontology.da_abbr_to_slot_name[p]
|
479 |
+
if p not in add_p:
|
480 |
+
add_p.append(p)
|
481 |
+
add_to_last = True if a in ["request", "reqmore", "bye", "offerbook"] else False
|
482 |
+
if add_to_last:
|
483 |
+
add_to_last_collect.append((d, a, add_p))
|
484 |
+
else:
|
485 |
+
sys_act_dict[d][a] = add_p
|
486 |
+
for d, a, add_p in add_to_last_collect:
|
487 |
+
sys_act_dict[d][a] = add_p
|
488 |
+
|
489 |
+
for d in copy.copy(sys_act_dict):
|
490 |
+
acts = sys_act_dict[d]
|
491 |
+
if not acts:
|
492 |
+
del sys_act_dict[d]
|
493 |
+
if "inform" in acts and "offerbooked" in acts:
|
494 |
+
for s in sys_act_dict[d]["inform"]:
|
495 |
+
sys_act_dict[d]["offerbooked"].append(s)
|
496 |
+
del sys_act_dict[d]["inform"]
|
497 |
+
|
498 |
+
ordered_sysact_dict[fn][len(dial["log"])] = sys_act_dict
|
499 |
+
|
500 |
+
sys_act = []
|
501 |
+
if "general-greet" in dial_turn["dialog_act"]:
|
502 |
+
sys_act.extend(["[general]", "[greet]"])
|
503 |
+
for d, acts in sys_act_dict.items():
|
504 |
+
sys_act += ["[" + d + "]"]
|
505 |
+
for a, slots in acts.items():
|
506 |
+
self.unique_da[d + "-" + a] = 1
|
507 |
+
sys_act += ["[" + a + "]"]
|
508 |
+
sys_act += slots
|
509 |
+
|
510 |
+
# get db pointers
|
511 |
+
matnums = self.db.get_match_num(constraint_dict)
|
512 |
+
match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1]
|
513 |
+
match = matnums[match_dom]
|
514 |
+
dbvec = self.db.addDBPointer(match_dom, match)
|
515 |
+
bkvec = self.db.addBookingPointer(dial_turn["dialog_act"])
|
516 |
+
|
517 |
+
# 4 database pointer for domains, 2 for booking
|
518 |
+
single_turn["pointer"] = ",".join([str(d) for d in dbvec + bkvec])
|
519 |
+
single_turn["match"] = str(match)
|
520 |
+
single_turn["constraint"] = " ".join(constraints)
|
521 |
+
single_turn["cons_delex"] = " ".join(cons_delex)
|
522 |
+
single_turn["sys_act"] = " ".join(sys_act)
|
523 |
+
single_turn["turn_num"] = len(dial["log"])
|
524 |
+
single_turn["turn_domain"] = " ".join(["[" + d + "]" for d in turn_domain])
|
525 |
+
|
526 |
+
prev_turn_domain = copy.deepcopy(turn_domain)
|
527 |
+
prev_constraint_dict = copy.deepcopy(constraint_dict)
|
528 |
+
|
529 |
+
if "user" in single_turn:
|
530 |
+
dial["log"].append(single_turn)
|
531 |
+
for t in single_turn["user"].split() + single_turn["resp"].split() + constraints + sys_act:
|
532 |
+
self.vocab.add_word(t)
|
533 |
+
|
534 |
+
# NOTE: Commenting out delexicalisation because it is not used and
|
535 |
+
# breaks when I use generated user dialogues for some reason
|
536 |
+
|
537 |
+
# for t in single_turn["user_delex"].split():
|
538 |
+
# if "[" in t and "]" in t and not t.startswith("[") and not t.endswith("]"):
|
539 |
+
# single_turn["user_delex"].replace(t, t[t.index("[") : t.index("]") + 1])
|
540 |
+
# elif not self.vocab.has_word(t):
|
541 |
+
# self.vocab.add_word(t)
|
542 |
+
|
543 |
+
single_turn = {}
|
544 |
+
|
545 |
+
data[fn] = dial
|
546 |
+
# pprint(dial)
|
547 |
+
# if count == 20:
|
548 |
+
# break
|
549 |
+
self.vocab.construct()
|
550 |
+
self.vocab.save_vocab("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/vocab")
|
551 |
+
with open("data/interim/gen_usr_utts/multi-woz-analysis/dialog_acts.json", "w") as f:
|
552 |
+
json.dump(ordered_sysact_dict, f, indent=2)
|
553 |
+
with open("data/interim/gen_usr_utts/multi-woz-analysis/dialog_act_type.json", "w") as f:
|
554 |
+
json.dump(self.unique_da, f, indent=2)
|
555 |
+
return data
|
556 |
+
|
557 |
+
|
558 |
+
if __name__ == "__main__":
|
559 |
+
db_paths = {
|
560 |
+
"attraction": "data/raw/UBAR/db/attraction_db.json",
|
561 |
+
"hospital": "data/raw/UBAR/db/hospital_db.json",
|
562 |
+
"hotel": "data/raw/UBAR/db/hotel_db.json",
|
563 |
+
"police": "data/raw/UBAR/db/police_db.json",
|
564 |
+
"restaurant": "data/raw/UBAR/db/restaurant_db.json",
|
565 |
+
"taxi": "data/raw/UBAR/db/taxi_db.json",
|
566 |
+
"train": "data/raw/UBAR/db/train_db.json",
|
567 |
+
}
|
568 |
+
get_db_values("data/raw/UBAR/db/value_set.json")
|
569 |
+
preprocess_db(db_paths)
|
570 |
+
dh = DataPreprocessor()
|
571 |
+
data = dh.preprocess_main()
|
572 |
+
if not os.path.exists("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed"):
|
573 |
+
os.mkdir("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed")
|
574 |
+
|
575 |
+
with open("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/data_for_ubar.json", "w") as f:
|
576 |
+
json.dump(data, f, indent=2)
|
scripts/UBAR_code/preprocess2.1.py
ADDED
@@ -0,0 +1,585 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import re
|
5 |
+
import zipfile
|
6 |
+
from collections import OrderedDict
|
7 |
+
|
8 |
+
import spacy
|
9 |
+
from tqdm import tqdm
|
10 |
+
|
11 |
+
from crazyneuraluser.UBAR_code import ontology, utils
|
12 |
+
from crazyneuraluser.UBAR_code.clean_dataset import clean_slot_values, clean_text
|
13 |
+
from crazyneuraluser.UBAR_code.config import global_config as cfg
|
14 |
+
from crazyneuraluser.UBAR_code.db_ops import MultiWozDB
|
15 |
+
|
16 |
+
|
17 |
+
def get_db_values(
|
18 |
+
value_set_path,
|
19 |
+
): # value_set.json, all the domain[slot] values in datasets
|
20 |
+
processed = {}
|
21 |
+
bspn_word = []
|
22 |
+
nlp = spacy.load("en_core_web_sm")
|
23 |
+
|
24 |
+
with open(value_set_path, "r") as f: # read value set file in lower
|
25 |
+
value_set = json.loads(f.read().lower())
|
26 |
+
|
27 |
+
with open("db/ontology.json", "r") as f: # read ontology in lower, all the domain-slot values
|
28 |
+
otlg = json.loads(f.read().lower())
|
29 |
+
|
30 |
+
for (
|
31 |
+
domain,
|
32 |
+
slots,
|
33 |
+
) in value_set.items(): # add all informable slots to bspn_word, create lists holder for values
|
34 |
+
processed[domain] = {}
|
35 |
+
bspn_word.append("[" + domain + "]")
|
36 |
+
for slot, values in slots.items():
|
37 |
+
s_p = ontology.normlize_slot_names.get(slot, slot)
|
38 |
+
if s_p in ontology.informable_slots[domain]:
|
39 |
+
bspn_word.append(s_p)
|
40 |
+
processed[domain][s_p] = []
|
41 |
+
|
42 |
+
for (
|
43 |
+
domain,
|
44 |
+
slots,
|
45 |
+
) in value_set.items(): # add all words of values of informable slots to bspn_word
|
46 |
+
for slot, values in slots.items():
|
47 |
+
s_p = ontology.normlize_slot_names.get(slot, slot)
|
48 |
+
if s_p in ontology.informable_slots[domain]:
|
49 |
+
for v in values:
|
50 |
+
_, v_p = clean_slot_values(domain, slot, v)
|
51 |
+
v_p = " ".join([token.text for token in nlp(v_p)]).strip()
|
52 |
+
processed[domain][s_p].append(v_p)
|
53 |
+
for x in v_p.split():
|
54 |
+
if x not in bspn_word:
|
55 |
+
bspn_word.append(x)
|
56 |
+
|
57 |
+
for domain_slot, values in otlg.items(): # split domain-slots to domains and slots
|
58 |
+
domain, slot = domain_slot.split("-")
|
59 |
+
if domain == "bus":
|
60 |
+
domain = "taxi"
|
61 |
+
if slot == "price range":
|
62 |
+
slot = "pricerange"
|
63 |
+
if slot == "book stay":
|
64 |
+
slot = "stay"
|
65 |
+
if slot == "book day":
|
66 |
+
slot = "day"
|
67 |
+
if slot == "book people":
|
68 |
+
slot = "people"
|
69 |
+
if slot == "book time":
|
70 |
+
slot = "time"
|
71 |
+
if slot == "arrive by":
|
72 |
+
slot = "arrive"
|
73 |
+
if slot == "leave at":
|
74 |
+
slot = "leave"
|
75 |
+
if slot == "leaveat":
|
76 |
+
slot = "leave"
|
77 |
+
if slot not in processed[domain]: # add all slots and words of values if not already in processed and bspn_word
|
78 |
+
processed[domain][slot] = []
|
79 |
+
bspn_word.append(slot)
|
80 |
+
for v in values:
|
81 |
+
_, v_p = clean_slot_values(domain, slot, v)
|
82 |
+
v_p = " ".join([token.text for token in nlp(v_p)]).strip()
|
83 |
+
if v_p not in processed[domain][slot]:
|
84 |
+
processed[domain][slot].append(v_p)
|
85 |
+
for x in v_p.split():
|
86 |
+
if x not in bspn_word:
|
87 |
+
bspn_word.append(x)
|
88 |
+
|
89 |
+
with open(value_set_path.replace(".json", "_processed.json"), "w") as f:
|
90 |
+
json.dump(processed, f, indent=2) # save processed.json
|
91 |
+
with open("data/preprocessed/UBAR/multi-woz-processed/bspn_word_collection.json", "w") as f:
|
92 |
+
json.dump(bspn_word, f, indent=2) # save bspn_word
|
93 |
+
|
94 |
+
print("DB value set processed! ")
|
95 |
+
|
96 |
+
|
97 |
+
def preprocess_db(db_paths): # apply clean_slot_values to all dbs
|
98 |
+
dbs = {}
|
99 |
+
nlp = spacy.load("en_core_web_sm")
|
100 |
+
for domain in ontology.all_domains:
|
101 |
+
with open(db_paths[domain], "r") as f: # for every db_domain, read json file
|
102 |
+
dbs[domain] = json.loads(f.read().lower())
|
103 |
+
for idx, entry in enumerate(dbs[domain]): # entry has information about slots of said domain
|
104 |
+
new_entry = copy.deepcopy(entry)
|
105 |
+
for key, value in entry.items(): # key = slot
|
106 |
+
if type(value) is not str:
|
107 |
+
continue
|
108 |
+
del new_entry[key]
|
109 |
+
key, value = clean_slot_values(domain, key, value)
|
110 |
+
tokenize_and_back = " ".join([token.text for token in nlp(value)]).strip()
|
111 |
+
new_entry[key] = tokenize_and_back
|
112 |
+
dbs[domain][idx] = new_entry
|
113 |
+
with open(db_paths[domain].replace(".json", "_processed.json"), "w") as f:
|
114 |
+
json.dump(dbs[domain], f, indent=2)
|
115 |
+
print("[%s] DB processed! " % domain)
|
116 |
+
|
117 |
+
|
118 |
+
# 2.1
|
119 |
+
class DataPreprocessor(object):
|
120 |
+
def __init__(self):
|
121 |
+
self.nlp = spacy.load("en_core_web_sm")
|
122 |
+
self.db = MultiWozDB(cfg.dbs) # load all processed dbs
|
123 |
+
# data_path = 'data/multi-woz/annotated_user_da_with_span_full.json'
|
124 |
+
data_path = "data/raw/UBAR/MultiWOZ_2.1/data.json"
|
125 |
+
archive = zipfile.ZipFile(data_path + ".zip", "r")
|
126 |
+
self.convlab_data = json.loads(archive.open(data_path.split("/")[-1], "r").read().lower())
|
127 |
+
# self.delex_sg_valdict_path = 'data/multi-woz-processed/delex_single_valdict.json'
|
128 |
+
# self.delex_mt_valdict_path = 'data/multi-woz-processed/delex_multi_valdict.json'
|
129 |
+
# self.ambiguous_val_path = 'data/multi-woz-processed/ambiguous_values.json'
|
130 |
+
# self.delex_refs_path = 'data/multi-woz-processed/reference_no.json'
|
131 |
+
self.delex_sg_valdict_path = "data/preprocessed/UBAR/multi-woz-2.1-processed/delex_single_valdict.json"
|
132 |
+
self.delex_mt_valdict_path = "data/preprocessed/UBAR/multi-woz-2.1-processed/delex_multi_valdict.json"
|
133 |
+
self.ambiguous_val_path = "data/preprocessed/UBAR/multi-woz-2.1-processed/ambiguous_values.json"
|
134 |
+
self.delex_refs_path = "data/preprocessed/UBAR/multi-woz-2.1-processed/reference_no.json"
|
135 |
+
self.delex_refs = json.loads(open(self.delex_refs_path, "r").read())
|
136 |
+
if not os.path.exists(self.delex_sg_valdict_path):
|
137 |
+
(
|
138 |
+
self.delex_sg_valdict,
|
139 |
+
self.delex_mt_valdict,
|
140 |
+
self.ambiguous_vals,
|
141 |
+
) = self.get_delex_valdict()
|
142 |
+
else:
|
143 |
+
self.delex_sg_valdict = json.loads(open(self.delex_sg_valdict_path, "r").read())
|
144 |
+
self.delex_mt_valdict = json.loads(open(self.delex_mt_valdict_path, "r").read())
|
145 |
+
self.ambiguous_vals = json.loads(open(self.ambiguous_val_path, "r").read())
|
146 |
+
|
147 |
+
self.vocab = utils.Vocab(cfg.vocab_size)
|
148 |
+
|
149 |
+
def delex_by_annotation(self, dial_turn):
|
150 |
+
# add by yyy in 13:48 0803
|
151 |
+
u = dial_turn["text"].split()
|
152 |
+
# u = my_clean_text(dial_turn['text']).split()
|
153 |
+
##
|
154 |
+
span = dial_turn["span_info"]
|
155 |
+
for s in span:
|
156 |
+
slot = s[1]
|
157 |
+
if slot == "open":
|
158 |
+
continue
|
159 |
+
if ontology.da_abbr_to_slot_name.get(slot):
|
160 |
+
slot = ontology.da_abbr_to_slot_name[slot]
|
161 |
+
for idx in range(s[3], s[4] + 1):
|
162 |
+
u[idx] = ""
|
163 |
+
try:
|
164 |
+
u[s[3]] = "[value_" + slot + "]"
|
165 |
+
except Exception:
|
166 |
+
u[5] = "[value_" + slot + "]"
|
167 |
+
u_delex = " ".join([t for t in u if t != ""])
|
168 |
+
u_delex = u_delex.replace("[value_address] , [value_address] , [value_address]", "[value_address]")
|
169 |
+
u_delex = u_delex.replace("[value_address] , [value_address]", "[value_address]")
|
170 |
+
u_delex = u_delex.replace("[value_name] [value_name]", "[value_name]")
|
171 |
+
u_delex = u_delex.replace("[value_name]([value_phone] )", "[value_name] ( [value_phone] )")
|
172 |
+
return u_delex
|
173 |
+
|
174 |
+
def delex_by_valdict(self, text):
|
175 |
+
text = clean_text(text)
|
176 |
+
|
177 |
+
text = re.sub(r"\d{5}\s?\d{5,7}", "[value_phone]", text)
|
178 |
+
text = re.sub(r"\d[\s-]stars?", "[value_stars]", text)
|
179 |
+
text = re.sub(r"\$\d+|\$?\d+.?(\d+)?\s(pounds?|gbps?)", "[value_price]", text)
|
180 |
+
text = re.sub(r"tr[\d]{4}", "[value_id]", text)
|
181 |
+
text = re.sub(
|
182 |
+
r"([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})",
|
183 |
+
"[value_postcode]",
|
184 |
+
text,
|
185 |
+
)
|
186 |
+
|
187 |
+
for value, slot in self.delex_mt_valdict.items():
|
188 |
+
text = text.replace(value, "[value_%s]" % slot)
|
189 |
+
|
190 |
+
for value, slot in self.delex_sg_valdict.items():
|
191 |
+
tokens = text.split()
|
192 |
+
for idx, tk in enumerate(tokens):
|
193 |
+
if tk == value:
|
194 |
+
tokens[idx] = "[value_%s]" % slot
|
195 |
+
text = " ".join(tokens)
|
196 |
+
|
197 |
+
for ambg_ent in self.ambiguous_vals:
|
198 |
+
start_idx = text.find(" " + ambg_ent) # ely is a place, but appears in words like moderately
|
199 |
+
if start_idx == -1:
|
200 |
+
continue
|
201 |
+
front_words = text[:start_idx].split()
|
202 |
+
ent_type = "time" if ":" in ambg_ent else "place"
|
203 |
+
|
204 |
+
for fw in front_words[::-1]:
|
205 |
+
if fw in [
|
206 |
+
"arrive",
|
207 |
+
"arrives",
|
208 |
+
"arrived",
|
209 |
+
"arriving",
|
210 |
+
"arrival",
|
211 |
+
"destination",
|
212 |
+
"there",
|
213 |
+
"reach",
|
214 |
+
"to",
|
215 |
+
"by",
|
216 |
+
"before",
|
217 |
+
]:
|
218 |
+
slot = "[value_arrive]" if ent_type == "time" else "[value_destination]"
|
219 |
+
text = re.sub(" " + ambg_ent, " " + slot, text)
|
220 |
+
elif fw in [
|
221 |
+
"leave",
|
222 |
+
"leaves",
|
223 |
+
"leaving",
|
224 |
+
"depart",
|
225 |
+
"departs",
|
226 |
+
"departing",
|
227 |
+
"departure",
|
228 |
+
"from",
|
229 |
+
"after",
|
230 |
+
"pulls",
|
231 |
+
]:
|
232 |
+
slot = "[value_leave]" if ent_type == "time" else "[value_departure]"
|
233 |
+
text = re.sub(" " + ambg_ent, " " + slot, text)
|
234 |
+
|
235 |
+
text = text.replace("[value_car] [value_car]", "[value_car]")
|
236 |
+
return text
|
237 |
+
|
238 |
+
def get_delex_valdict(
|
239 |
+
self,
|
240 |
+
):
|
241 |
+
skip_entry_type = {
|
242 |
+
"taxi": ["taxi_phone"],
|
243 |
+
"police": ["id"],
|
244 |
+
"hospital": ["id"],
|
245 |
+
"hotel": [
|
246 |
+
"id",
|
247 |
+
"location",
|
248 |
+
"internet",
|
249 |
+
"parking",
|
250 |
+
"takesbookings",
|
251 |
+
"stars",
|
252 |
+
"price",
|
253 |
+
"n",
|
254 |
+
"postcode",
|
255 |
+
"phone",
|
256 |
+
],
|
257 |
+
"attraction": [
|
258 |
+
"id",
|
259 |
+
"location",
|
260 |
+
"pricerange",
|
261 |
+
"price",
|
262 |
+
"openhours",
|
263 |
+
"postcode",
|
264 |
+
"phone",
|
265 |
+
],
|
266 |
+
"train": ["price", "id"],
|
267 |
+
"restaurant": [
|
268 |
+
"id",
|
269 |
+
"location",
|
270 |
+
"introduction",
|
271 |
+
"signature",
|
272 |
+
"type",
|
273 |
+
"postcode",
|
274 |
+
"phone",
|
275 |
+
],
|
276 |
+
}
|
277 |
+
entity_value_to_slot = {}
|
278 |
+
ambiguous_entities = []
|
279 |
+
for domain, db_data in self.db.dbs.items():
|
280 |
+
print("Processing entity values in [%s]" % domain)
|
281 |
+
if domain != "taxi":
|
282 |
+
for db_entry in db_data:
|
283 |
+
for slot, value in db_entry.items():
|
284 |
+
if slot not in skip_entry_type[domain]:
|
285 |
+
if type(value) is not str:
|
286 |
+
raise TypeError("value '%s' in domain '%s' should be rechecked" % (slot, domain))
|
287 |
+
else:
|
288 |
+
slot, value = clean_slot_values(domain, slot, value)
|
289 |
+
value = " ".join([token.text for token in self.nlp(value)]).strip()
|
290 |
+
if value in entity_value_to_slot and entity_value_to_slot[value] != slot:
|
291 |
+
# print(value, ": ",entity_value_to_slot[value], slot)
|
292 |
+
ambiguous_entities.append(value)
|
293 |
+
entity_value_to_slot[value] = slot
|
294 |
+
else: # taxi db specific
|
295 |
+
db_entry = db_data[0]
|
296 |
+
for slot, ent_list in db_entry.items():
|
297 |
+
if slot not in skip_entry_type[domain]:
|
298 |
+
for ent in ent_list:
|
299 |
+
entity_value_to_slot[ent] = "car"
|
300 |
+
ambiguous_entities = set(ambiguous_entities)
|
301 |
+
ambiguous_entities.remove("cambridge")
|
302 |
+
ambiguous_entities = list(ambiguous_entities)
|
303 |
+
for amb_ent in ambiguous_entities: # departure or destination? arrive time or leave time?
|
304 |
+
entity_value_to_slot.pop(amb_ent)
|
305 |
+
entity_value_to_slot["parkside"] = "address"
|
306 |
+
entity_value_to_slot["parkside, cambridge"] = "address"
|
307 |
+
entity_value_to_slot["cambridge belfry"] = "name"
|
308 |
+
entity_value_to_slot["hills road"] = "address"
|
309 |
+
entity_value_to_slot["hills rd"] = "address"
|
310 |
+
entity_value_to_slot["Parkside Police Station"] = "name"
|
311 |
+
|
312 |
+
single_token_values = {}
|
313 |
+
multi_token_values = {}
|
314 |
+
for val, slt in entity_value_to_slot.items():
|
315 |
+
if val in ["cambridge"]:
|
316 |
+
continue
|
317 |
+
if len(val.split()) > 1:
|
318 |
+
multi_token_values[val] = slt
|
319 |
+
else:
|
320 |
+
single_token_values[val] = slt
|
321 |
+
|
322 |
+
with open(self.delex_sg_valdict_path, "w") as f:
|
323 |
+
single_token_values = OrderedDict(
|
324 |
+
sorted(single_token_values.items(), key=lambda kv: len(kv[0]), reverse=True)
|
325 |
+
)
|
326 |
+
json.dump(single_token_values, f, indent=2)
|
327 |
+
print("single delex value dict saved!")
|
328 |
+
with open(self.delex_mt_valdict_path, "w") as f:
|
329 |
+
multi_token_values = OrderedDict(
|
330 |
+
sorted(multi_token_values.items(), key=lambda kv: len(kv[0]), reverse=True)
|
331 |
+
)
|
332 |
+
json.dump(multi_token_values, f, indent=2)
|
333 |
+
print("multi delex value dict saved!")
|
334 |
+
with open(self.ambiguous_val_path, "w") as f:
|
335 |
+
json.dump(ambiguous_entities, f, indent=2)
|
336 |
+
print("ambiguous value dict saved!")
|
337 |
+
|
338 |
+
return single_token_values, multi_token_values, ambiguous_entities
|
339 |
+
|
340 |
+
def preprocess_main(self, save_path=None, is_test=False):
|
341 |
+
""" """
|
342 |
+
data = {}
|
343 |
+
count = 0
|
344 |
+
self.unique_da = {}
|
345 |
+
ordered_sysact_dict = {}
|
346 |
+
# yyy
|
347 |
+
for fn, raw_dial in tqdm(list(self.convlab_data.items())):
|
348 |
+
if fn in [
|
349 |
+
"pmul4707.json",
|
350 |
+
"pmul2245.json",
|
351 |
+
"pmul4776.json",
|
352 |
+
"pmul3872.json",
|
353 |
+
"pmul4859.json",
|
354 |
+
]:
|
355 |
+
continue
|
356 |
+
count += 1
|
357 |
+
# if count == 100:
|
358 |
+
# break
|
359 |
+
|
360 |
+
compressed_goal = {} # for every dialog, keep track the goal, domains, requests
|
361 |
+
dial_domains, dial_reqs = [], []
|
362 |
+
for dom, g in raw_dial["goal"].items():
|
363 |
+
if dom != "topic" and dom != "message" and g:
|
364 |
+
if g.get("reqt"): # request info. eg. postcode/address/phone
|
365 |
+
for i, req_slot in enumerate(g["reqt"]): # normalize request slots
|
366 |
+
if ontology.normlize_slot_names.get(req_slot):
|
367 |
+
g["reqt"][i] = ontology.normlize_slot_names[req_slot]
|
368 |
+
dial_reqs.append(g["reqt"][i])
|
369 |
+
compressed_goal[dom] = g
|
370 |
+
if dom in ontology.all_domains:
|
371 |
+
dial_domains.append(dom)
|
372 |
+
|
373 |
+
dial_reqs = list(set(dial_reqs))
|
374 |
+
|
375 |
+
dial = {"goal": compressed_goal, "log": []}
|
376 |
+
single_turn = {}
|
377 |
+
constraint_dict = OrderedDict()
|
378 |
+
prev_constraint_dict = {}
|
379 |
+
prev_turn_domain = ["general"]
|
380 |
+
ordered_sysact_dict[fn] = {}
|
381 |
+
|
382 |
+
for turn_num, dial_turn in enumerate(raw_dial["log"]):
|
383 |
+
# for user turn, have text
|
384 |
+
# sys turn: text, belief states(metadata), dialog_act, span_info
|
385 |
+
dial_state = dial_turn["metadata"]
|
386 |
+
dial_turn["text"] = " ".join([t.text for t in self.nlp(dial_turn["text"])])
|
387 |
+
if not dial_state: # user
|
388 |
+
# delexicalize user utterance, either by annotation or by val_dict
|
389 |
+
u = " ".join(clean_text(dial_turn["text"]).split())
|
390 |
+
if "span_info" in dial_turn and dial_turn["span_info"]:
|
391 |
+
u_delex = clean_text(self.delex_by_annotation(dial_turn))
|
392 |
+
else:
|
393 |
+
u_delex = self.delex_by_valdict(dial_turn["text"])
|
394 |
+
|
395 |
+
single_turn["user"] = u
|
396 |
+
single_turn["user_delex"] = u_delex
|
397 |
+
|
398 |
+
else: # system
|
399 |
+
# delexicalize system response, either by annotation or by val_dict
|
400 |
+
if "span_info" in dial_turn and dial_turn["span_info"]:
|
401 |
+
s_delex = clean_text(self.delex_by_annotation(dial_turn))
|
402 |
+
else:
|
403 |
+
if not dial_turn["text"]:
|
404 |
+
print(fn)
|
405 |
+
s_delex = self.delex_by_valdict(dial_turn["text"])
|
406 |
+
single_turn["resp"] = s_delex
|
407 |
+
single_turn["nodelx_resp"] = " ".join(clean_text(dial_turn["text"]).split())
|
408 |
+
|
409 |
+
# get belief state, semi=informable/book=requestable, put into constraint_dict
|
410 |
+
for domain in dial_domains:
|
411 |
+
if not constraint_dict.get(domain):
|
412 |
+
constraint_dict[domain] = OrderedDict()
|
413 |
+
info_sv = dial_state[domain]["semi"]
|
414 |
+
for s, v in info_sv.items():
|
415 |
+
s, v = clean_slot_values(domain, s, v)
|
416 |
+
if len(v.split()) > 1:
|
417 |
+
v = " ".join([token.text for token in self.nlp(v)]).strip()
|
418 |
+
if v != "":
|
419 |
+
constraint_dict[domain][s] = v
|
420 |
+
book_sv = dial_state[domain]["book"]
|
421 |
+
for s, v in book_sv.items():
|
422 |
+
if s == "booked":
|
423 |
+
continue
|
424 |
+
s, v = clean_slot_values(domain, s, v)
|
425 |
+
if len(v.split()) > 1:
|
426 |
+
v = " ".join([token.text for token in self.nlp(v)]).strip()
|
427 |
+
if v != "":
|
428 |
+
constraint_dict[domain][s] = v
|
429 |
+
|
430 |
+
constraints = [] # list in format of [domain] slot value
|
431 |
+
cons_delex = []
|
432 |
+
turn_dom_bs = []
|
433 |
+
for domain, info_slots in constraint_dict.items():
|
434 |
+
if info_slots:
|
435 |
+
constraints.append("[" + domain + "]")
|
436 |
+
cons_delex.append("[" + domain + "]")
|
437 |
+
for slot, value in info_slots.items():
|
438 |
+
constraints.append(slot)
|
439 |
+
constraints.extend(value.split())
|
440 |
+
cons_delex.append(slot)
|
441 |
+
if domain not in prev_constraint_dict:
|
442 |
+
turn_dom_bs.append(domain)
|
443 |
+
elif prev_constraint_dict[domain] != constraint_dict[domain]:
|
444 |
+
turn_dom_bs.append(domain)
|
445 |
+
|
446 |
+
sys_act_dict = {}
|
447 |
+
turn_dom_da = set()
|
448 |
+
for act in dial_turn["dialog_act"]:
|
449 |
+
d, a = act.split("-") # split domain-act
|
450 |
+
turn_dom_da.add(d)
|
451 |
+
turn_dom_da = list(turn_dom_da)
|
452 |
+
if len(turn_dom_da) != 1 and "general" in turn_dom_da:
|
453 |
+
turn_dom_da.remove("general")
|
454 |
+
if len(turn_dom_da) != 1 and "booking" in turn_dom_da:
|
455 |
+
turn_dom_da.remove("booking")
|
456 |
+
|
457 |
+
# get turn domain
|
458 |
+
turn_domain = turn_dom_bs
|
459 |
+
for dom in turn_dom_da:
|
460 |
+
if dom != "booking" and dom not in turn_domain:
|
461 |
+
turn_domain.append(dom)
|
462 |
+
if not turn_domain:
|
463 |
+
turn_domain = prev_turn_domain
|
464 |
+
if len(turn_domain) == 2 and "general" in turn_domain:
|
465 |
+
turn_domain.remove("general")
|
466 |
+
if len(turn_domain) == 2:
|
467 |
+
if len(prev_turn_domain) == 1 and prev_turn_domain[0] == turn_domain[1]:
|
468 |
+
turn_domain = turn_domain[::-1]
|
469 |
+
|
470 |
+
# get system action
|
471 |
+
for dom in turn_domain:
|
472 |
+
sys_act_dict[dom] = {}
|
473 |
+
add_to_last_collect = []
|
474 |
+
booking_act_map = {"inform": "offerbook", "book": "offerbooked"}
|
475 |
+
for act, params in dial_turn["dialog_act"].items():
|
476 |
+
if act == "general-greet":
|
477 |
+
continue
|
478 |
+
d, a = act.split("-")
|
479 |
+
if d == "general" and d not in sys_act_dict:
|
480 |
+
sys_act_dict[d] = {}
|
481 |
+
if d == "booking":
|
482 |
+
d = turn_domain[0]
|
483 |
+
a = booking_act_map.get(a, a)
|
484 |
+
add_p = []
|
485 |
+
for param in params:
|
486 |
+
p = param[0]
|
487 |
+
if p == "none":
|
488 |
+
continue
|
489 |
+
elif ontology.da_abbr_to_slot_name.get(p):
|
490 |
+
p = ontology.da_abbr_to_slot_name[p]
|
491 |
+
if p not in add_p:
|
492 |
+
add_p.append(p)
|
493 |
+
add_to_last = True if a in ["request", "reqmore", "bye", "offerbook"] else False
|
494 |
+
if add_to_last:
|
495 |
+
add_to_last_collect.append((d, a, add_p))
|
496 |
+
else:
|
497 |
+
sys_act_dict[d][a] = add_p
|
498 |
+
for d, a, add_p in add_to_last_collect:
|
499 |
+
sys_act_dict[d][a] = add_p
|
500 |
+
|
501 |
+
for d in copy.copy(sys_act_dict):
|
502 |
+
acts = sys_act_dict[d]
|
503 |
+
if not acts:
|
504 |
+
del sys_act_dict[d]
|
505 |
+
if "inform" in acts and "offerbooked" in acts:
|
506 |
+
for s in sys_act_dict[d]["inform"]:
|
507 |
+
sys_act_dict[d]["offerbooked"].append(s)
|
508 |
+
del sys_act_dict[d]["inform"]
|
509 |
+
|
510 |
+
ordered_sysact_dict[fn][len(dial["log"])] = sys_act_dict
|
511 |
+
|
512 |
+
sys_act = []
|
513 |
+
if "general-greet" in dial_turn["dialog_act"]:
|
514 |
+
sys_act.extend(["[general]", "[greet]"])
|
515 |
+
for d, acts in sys_act_dict.items():
|
516 |
+
sys_act += ["[" + d + "]"]
|
517 |
+
for a, slots in acts.items():
|
518 |
+
self.unique_da[d + "-" + a] = 1
|
519 |
+
sys_act += ["[" + a + "]"]
|
520 |
+
sys_act += slots
|
521 |
+
|
522 |
+
# get db pointers
|
523 |
+
matnums = self.db.get_match_num(constraint_dict)
|
524 |
+
match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1]
|
525 |
+
match = matnums[match_dom]
|
526 |
+
dbvec = self.db.addDBPointer(match_dom, match)
|
527 |
+
bkvec = self.db.addBookingPointer(dial_turn["dialog_act"])
|
528 |
+
|
529 |
+
single_turn["pointer"] = ",".join(
|
530 |
+
[str(d) for d in dbvec + bkvec]
|
531 |
+
) # 4 database pointer for domains, 2 for booking
|
532 |
+
single_turn["match"] = str(match)
|
533 |
+
single_turn["constraint"] = " ".join(constraints)
|
534 |
+
single_turn["cons_delex"] = " ".join(cons_delex)
|
535 |
+
single_turn["sys_act"] = " ".join(sys_act)
|
536 |
+
single_turn["turn_num"] = len(dial["log"])
|
537 |
+
single_turn["turn_domain"] = " ".join(["[" + d + "]" for d in turn_domain])
|
538 |
+
|
539 |
+
prev_turn_domain = copy.deepcopy(turn_domain)
|
540 |
+
prev_constraint_dict = copy.deepcopy(constraint_dict)
|
541 |
+
|
542 |
+
if "user" in single_turn:
|
543 |
+
dial["log"].append(single_turn)
|
544 |
+
for t in single_turn["user"].split() + single_turn["resp"].split() + constraints + sys_act:
|
545 |
+
self.vocab.add_word(t)
|
546 |
+
for t in single_turn["user_delex"].split():
|
547 |
+
if "[" in t and "]" in t and not t.startswith("[") and not t.endswith("]"):
|
548 |
+
single_turn["user_delex"].replace(t, t[t.index("[") : t.index("]") + 1])
|
549 |
+
elif not self.vocab.has_word(t):
|
550 |
+
self.vocab.add_word(t)
|
551 |
+
|
552 |
+
single_turn = {}
|
553 |
+
|
554 |
+
data[fn] = dial
|
555 |
+
# pprint(dial)
|
556 |
+
# if count == 20:
|
557 |
+
# break
|
558 |
+
self.vocab.construct()
|
559 |
+
self.vocab.save_vocab("data/preprocessed/UBAR/multi-woz-2.1-processed/vocab")
|
560 |
+
with open("data/interim/multi-woz-2.1-analysis/dialog_acts.json", "w") as f:
|
561 |
+
json.dump(ordered_sysact_dict, f, indent=2)
|
562 |
+
with open("data/interim/multi-woz-2.1-analysis/dialog_act_type.json", "w") as f:
|
563 |
+
json.dump(self.unique_da, f, indent=2)
|
564 |
+
return data
|
565 |
+
|
566 |
+
|
567 |
+
if __name__ == "__main__":
|
568 |
+
db_paths = {
|
569 |
+
"attraction": "db/raw/attraction_db.json",
|
570 |
+
"hospital": "db/raw/hospital_db.json",
|
571 |
+
"hotel": "db/raw/hotel_db.json",
|
572 |
+
"police": "db/raw/police_db.json",
|
573 |
+
"restaurant": "db/raw/restaurant_db.json",
|
574 |
+
"taxi": "db/raw/taxi_db.json",
|
575 |
+
"train": "db/raw/train_db.json",
|
576 |
+
}
|
577 |
+
# get_db_values('db/value_set.json') #
|
578 |
+
# preprocess_db(db_paths)
|
579 |
+
if not os.path.exists("data/preprocessed/UBAR/multi-woz-2.1-processed"):
|
580 |
+
os.mkdir("data/preprocessed/UBAR/multi-woz-2.1-processed")
|
581 |
+
dh = DataPreprocessor()
|
582 |
+
data = dh.preprocess_main()
|
583 |
+
|
584 |
+
with open("data/preprocessed/UBAR/multi-woz-2.1-processed/data_for_ubar.json", "w") as f:
|
585 |
+
json.dump(data, f, indent=2)
|
scripts/UBAR_code/train_ubar.py
ADDED
@@ -0,0 +1,697 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
import time
|
7 |
+
import warnings
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from torch.utils.tensorboard import SummaryWriter
|
13 |
+
from tqdm import tqdm
|
14 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
15 |
+
from transformers.optimization import AdamW, get_linear_schedule_with_warmup
|
16 |
+
|
17 |
+
import wandb
|
18 |
+
from crazyneuraluser.UBAR_code.config import global_config as cfg
|
19 |
+
from crazyneuraluser.UBAR_code.eval import MultiWozEvaluator
|
20 |
+
from crazyneuraluser.UBAR_code.reader import MultiWozReader
|
21 |
+
|
22 |
+
# from config21 import global_config as cfg # global, already initialized
|
23 |
+
|
24 |
+
|
25 |
+
warnings.filterwarnings("ignore")
|
26 |
+
|
27 |
+
|
28 |
+
class Model(object):
|
29 |
+
def __init__(self, device):
|
30 |
+
self.device = device
|
31 |
+
# initialize tokenizer
|
32 |
+
self.tokenizer = GPT2Tokenizer.from_pretrained(cfg.gpt_path)
|
33 |
+
# cfg.tokenizer = tokenizer
|
34 |
+
|
35 |
+
# initialize multiwoz reader
|
36 |
+
self.reader = MultiWozReader(self.tokenizer)
|
37 |
+
|
38 |
+
# create model: gpt2
|
39 |
+
self.model = GPT2LMHeadModel.from_pretrained(cfg.gpt_path)
|
40 |
+
if cfg.mode == "train":
|
41 |
+
self.model.resize_token_embeddings(len(self.tokenizer))
|
42 |
+
self.model.to(self.device) # single gpu
|
43 |
+
|
44 |
+
#
|
45 |
+
self.evaluator = MultiWozEvaluator(self.reader)
|
46 |
+
if cfg.save_log and cfg.mode == "train":
|
47 |
+
self.tb_writer = SummaryWriter(log_dir="./log")
|
48 |
+
else:
|
49 |
+
self.tb_writer = None
|
50 |
+
|
51 |
+
def get_optimizers(self):
|
52 |
+
"""
|
53 |
+
Setup the optimizer and the learning rate scheduler.
|
54 |
+
|
55 |
+
from transformers.Trainer
|
56 |
+
|
57 |
+
parameters from cfg: lr (1e-3); warmup_steps
|
58 |
+
"""
|
59 |
+
# Prepare optimizer and schedule (linear warmup and decay)
|
60 |
+
no_decay = ["bias", "LayerNorm.weight"]
|
61 |
+
optimizer_grouped_parameters = [
|
62 |
+
{
|
63 |
+
"params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
|
64 |
+
"weight_decay": cfg.weight_decay,
|
65 |
+
},
|
66 |
+
{
|
67 |
+
"params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
|
68 |
+
"weight_decay": 0.0,
|
69 |
+
},
|
70 |
+
]
|
71 |
+
optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.lr)
|
72 |
+
num_training_steps = (
|
73 |
+
self.reader.set_stats["train"]["num_dials"]
|
74 |
+
* cfg.epoch_num
|
75 |
+
// (cfg.gradient_accumulation_steps * cfg.batch_size)
|
76 |
+
)
|
77 |
+
num_warmup_steps = cfg.warmup_steps if cfg.warmup_steps >= 0 else int(num_training_steps * 0.2)
|
78 |
+
scheduler = get_linear_schedule_with_warmup(
|
79 |
+
optimizer,
|
80 |
+
num_warmup_steps=num_warmup_steps,
|
81 |
+
num_training_steps=num_training_steps,
|
82 |
+
)
|
83 |
+
return optimizer, scheduler
|
84 |
+
|
85 |
+
def log_first_inputs(self, inputs):
|
86 |
+
tokenizer = self.tokenizer
|
87 |
+
logging.info("**** Input Examples: ****")
|
88 |
+
for context in inputs["contexts"][:4]:
|
89 |
+
# ubar = tokenizer.convert_ids_to_tokens(context)
|
90 |
+
# ubar = tokenizer.convert_tokens_to_string(context)
|
91 |
+
# ubar = " ".join(ubar)
|
92 |
+
ubar = tokenizer.decode(context)
|
93 |
+
logging.info(ubar)
|
94 |
+
|
95 |
+
def add_torch_input(self, inputs):
|
96 |
+
# to tensor and to device
|
97 |
+
contexts_tensor = torch.from_numpy(inputs["contexts_np"]).long()
|
98 |
+
contexts_tensor = contexts_tensor.to(self.device)
|
99 |
+
inputs["contexts_tensor"] = contexts_tensor
|
100 |
+
return inputs
|
101 |
+
|
102 |
+
def add_torch_input_eval(self, inputs):
|
103 |
+
# inputs: context
|
104 |
+
inputs["context_tensor"] = torch.tensor([inputs["context"]]).to(self.device)
|
105 |
+
return inputs
|
106 |
+
|
107 |
+
def calculate_loss_and_accuracy(self, outputs, labels):
|
108 |
+
# GPT2-chicahat/train.py
|
109 |
+
lm_logits = outputs[0]
|
110 |
+
|
111 |
+
shift_logits = lm_logits[..., :-1, :].contiguous()
|
112 |
+
shift_labels = labels[..., 1:].contiguous()
|
113 |
+
|
114 |
+
pad_id = cfg.pad_id
|
115 |
+
loss_fct = nn.CrossEntropyLoss(ignore_index=pad_id, reduction="sum")
|
116 |
+
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
117 |
+
|
118 |
+
# avg loss
|
119 |
+
not_ignore = shift_labels.ne(pad_id)
|
120 |
+
num_targets = not_ignore.long().sum().item()
|
121 |
+
|
122 |
+
loss /= num_targets
|
123 |
+
return loss
|
124 |
+
|
125 |
+
def train(self):
|
126 |
+
"""
|
127 |
+
UBARU
|
128 |
+
"""
|
129 |
+
|
130 |
+
wandb.init(
|
131 |
+
# Set the project where this run will be logged
|
132 |
+
project="E2E User Simulator (Alistair)",
|
133 |
+
entity="byrne-lab",
|
134 |
+
# We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
|
135 |
+
name=cfg.wandb_train_run_name,
|
136 |
+
# Track hyperparameters and run metadata
|
137 |
+
config={
|
138 |
+
"dataset": cfg.data_path,
|
139 |
+
"gpt_path": cfg.gpt_path,
|
140 |
+
"learning_rate": cfg.lr,
|
141 |
+
"warmup_steps": cfg.warmup_steps,
|
142 |
+
"gradient_accumulation_steps": cfg.gradient_accumulation_steps,
|
143 |
+
"batch_size": cfg.batch_size,
|
144 |
+
"epochs": cfg.epoch_num,
|
145 |
+
},
|
146 |
+
)
|
147 |
+
|
148 |
+
all_batches = self.reader.get_batches("train")
|
149 |
+
# compute num_training_steps in get_batches()
|
150 |
+
optimizer, scheduler = self.get_optimizers()
|
151 |
+
|
152 |
+
# log info
|
153 |
+
set_stats = self.reader.set_stats["train"]
|
154 |
+
logging.info("***** Running training *****")
|
155 |
+
logging.info(
|
156 |
+
" Num Training steps(one turn in a batch of dialogs) per epoch = %d",
|
157 |
+
set_stats["num_training_steps_per_epoch"],
|
158 |
+
)
|
159 |
+
logging.info(" Num Turns = %d", set_stats["num_turns"])
|
160 |
+
logging.info(" Num Dialogs = %d", set_stats["num_dials"])
|
161 |
+
logging.info(" Num Epochs = %d", cfg.epoch_num)
|
162 |
+
logging.info(" Batch size = %d", cfg.batch_size)
|
163 |
+
logging.info(" Gradient Accumulation steps = %d", cfg.gradient_accumulation_steps)
|
164 |
+
logging.info(
|
165 |
+
" Total optimization steps = %d",
|
166 |
+
set_stats["num_dials"] * cfg.epoch_num // (cfg.gradient_accumulation_steps * cfg.batch_size),
|
167 |
+
)
|
168 |
+
|
169 |
+
# tb writer
|
170 |
+
if self.tb_writer is not None:
|
171 |
+
self.tb_writer.add_text("cfg", json.dumps(cfg.__dict__, indent=2))
|
172 |
+
# self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
|
173 |
+
|
174 |
+
log_inputs = 2
|
175 |
+
global_step = 0
|
176 |
+
# sw = time.time()
|
177 |
+
|
178 |
+
for epoch in range(cfg.epoch_num):
|
179 |
+
epoch_step = 0
|
180 |
+
tr_loss = 0.0
|
181 |
+
logging_loss = 0.0
|
182 |
+
btm = time.time()
|
183 |
+
oom_time = 0
|
184 |
+
self.model.zero_grad()
|
185 |
+
|
186 |
+
data_iterator = self.reader.get_nontranspose_data_iterator(all_batches)
|
187 |
+
|
188 |
+
for batch_idx, dial_batch in enumerate(data_iterator):
|
189 |
+
inputs = self.reader.convert_batch_session(dial_batch)
|
190 |
+
try: # avoid OOM
|
191 |
+
self.model.train()
|
192 |
+
if log_inputs > 0: # log inputs for the very first two turns
|
193 |
+
self.log_first_inputs(inputs)
|
194 |
+
log_inputs -= 1
|
195 |
+
|
196 |
+
# to tensor
|
197 |
+
inputs = self.add_torch_input(inputs)
|
198 |
+
# loss
|
199 |
+
outputs = self.model(inputs["contexts_tensor"])
|
200 |
+
# outputs = self.model(inputs['contexts_tensor']) # debugging with GPT2Model
|
201 |
+
loss = self.calculate_loss_and_accuracy(outputs, labels=inputs["contexts_tensor"])
|
202 |
+
loss.backward()
|
203 |
+
tr_loss += loss.item()
|
204 |
+
torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
|
205 |
+
epoch_step += 1
|
206 |
+
|
207 |
+
# step, wrt gradient_accumulation_steps, clip grad norm
|
208 |
+
if (epoch_step + 1) % cfg.gradient_accumulation_steps == 0 or (
|
209 |
+
# end of an epoch
|
210 |
+
(epoch_step + 1)
|
211 |
+
== set_stats["num_training_steps_per_epoch"]
|
212 |
+
):
|
213 |
+
optimizer.step()
|
214 |
+
scheduler.step()
|
215 |
+
optimizer.zero_grad()
|
216 |
+
# global_step: actual step the optimizer took
|
217 |
+
global_step += 1
|
218 |
+
|
219 |
+
logs = {} # for tb writer
|
220 |
+
# logging: loss, lr... after certain amount of steps
|
221 |
+
if cfg.report_interval > 0 and global_step % cfg.report_interval == 0:
|
222 |
+
loss_scalar = (tr_loss - logging_loss) / cfg.report_interval
|
223 |
+
logging_loss = tr_loss
|
224 |
+
logs["loss"] = loss_scalar
|
225 |
+
logging.info(
|
226 |
+
"Global step: {}, epoch step: {}, interval loss: {:.4f}".format(
|
227 |
+
global_step, epoch_step, loss_scalar
|
228 |
+
)
|
229 |
+
)
|
230 |
+
|
231 |
+
# validate
|
232 |
+
# add to tensorboard...
|
233 |
+
if cfg.evaluate_during_training and loss_scalar < 10:
|
234 |
+
results = self.validate(epoch)
|
235 |
+
for k, v in results.items():
|
236 |
+
eval_key = "eval_{}".format(k)
|
237 |
+
logs[eval_key] = v
|
238 |
+
|
239 |
+
if self.tb_writer:
|
240 |
+
for k, v in logs.items():
|
241 |
+
self.tb_writer.add_scalar(k, v, global_step)
|
242 |
+
# save model...
|
243 |
+
|
244 |
+
except RuntimeError as exception:
|
245 |
+
if "out of memory" in str(exception):
|
246 |
+
max_length = max(inputs["lengths"])
|
247 |
+
oom_time += 1
|
248 |
+
logging.info(
|
249 |
+
"WARNING: ran out of memory,times: {}, batch size: {}, max_len: {}".format(
|
250 |
+
oom_time, cfg.batch_size, max_length
|
251 |
+
)
|
252 |
+
)
|
253 |
+
if hasattr(torch.cuda, "empty_cache"):
|
254 |
+
torch.cuda.empty_cache()
|
255 |
+
else:
|
256 |
+
logging.info(str(exception))
|
257 |
+
raise exception
|
258 |
+
logging.info("Train epoch time: {:.2f} min, epoch loss: {:.4f}".format((time.time() - btm) / 60, tr_loss))
|
259 |
+
# save model after every epoch
|
260 |
+
# if epoch > 10 or tr_loss/epoch_step < 1:
|
261 |
+
self.save_model(epoch, tr_loss / epoch_step)
|
262 |
+
|
263 |
+
wandb.log({"epoch loss": tr_loss})
|
264 |
+
|
265 |
+
# Mark the run as finished on wandb
|
266 |
+
wandb.finish()
|
267 |
+
|
268 |
+
def save_model(self, epoch, loss):
|
269 |
+
save_path = os.path.join(cfg.exp_path, "epoch{}_trloss{:.2f}_gpt2".format(epoch + 1, loss))
|
270 |
+
if not os.path.exists(save_path):
|
271 |
+
os.mkdir(save_path)
|
272 |
+
logging.info("Saving model checkpoint to %s", save_path)
|
273 |
+
# save gpt2
|
274 |
+
self.model.save_pretrained(save_path)
|
275 |
+
# save tokenizer
|
276 |
+
self.tokenizer.save_pretrained(save_path)
|
277 |
+
# save cfg
|
278 |
+
|
279 |
+
def validate(self, data="dev", do_test=False, epoch=0):
|
280 |
+
|
281 |
+
if cfg.mode != "train":
|
282 |
+
wandb.init(
|
283 |
+
# Set the project where this run will be logged
|
284 |
+
project="E2E User Simulator (Alistair)",
|
285 |
+
entity="byrne-lab",
|
286 |
+
# We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
|
287 |
+
name=cfg.wandb_eval_run_name,
|
288 |
+
# Track hyperparameters and run metadata
|
289 |
+
config={
|
290 |
+
"eval_load_path": cfg.eval_load_path,
|
291 |
+
"dataset": cfg.data_path,
|
292 |
+
"gpt_path": cfg.gpt_path,
|
293 |
+
"learning_rate": cfg.lr,
|
294 |
+
"warmup_steps": cfg.warmup_steps,
|
295 |
+
"gradient_accumulation_steps": cfg.gradient_accumulation_steps,
|
296 |
+
"batch_size": cfg.batch_size,
|
297 |
+
"epochs": cfg.epoch_num,
|
298 |
+
"data": data,
|
299 |
+
},
|
300 |
+
)
|
301 |
+
|
302 |
+
test_data_at = wandb.Artifact(str(wandb.run.id + str(epoch)), type="predictions")
|
303 |
+
|
304 |
+
# Create your W&B Table
|
305 |
+
column_names = [
|
306 |
+
"dialog",
|
307 |
+
"turn_num",
|
308 |
+
"turn_domain",
|
309 |
+
"pointer",
|
310 |
+
"user",
|
311 |
+
"usdx",
|
312 |
+
"resp",
|
313 |
+
"bspn",
|
314 |
+
"bsdx",
|
315 |
+
"aspn",
|
316 |
+
"dspn",
|
317 |
+
"db",
|
318 |
+
"resp_gen",
|
319 |
+
"bspn_gen",
|
320 |
+
"aspn_gen",
|
321 |
+
"dspn_gen",
|
322 |
+
]
|
323 |
+
val_table = wandb.Table(columns=column_names)
|
324 |
+
|
325 |
+
# predict one dialog/ one turn at a time
|
326 |
+
self.model.eval()
|
327 |
+
|
328 |
+
# all_batches = self.reader.get_batches('dev')
|
329 |
+
# data_iterator = self.reader.get_data_iterator(all_batches)
|
330 |
+
eval_data = self.reader.get_eval_data(data)
|
331 |
+
|
332 |
+
set_stats = self.reader.set_stats[data]
|
333 |
+
logging.info("***** Running Evaluation *****")
|
334 |
+
logging.info(" Num Turns = %d", set_stats["num_turns"])
|
335 |
+
# logging.info(" Num Dialogs = %d", set_stats['num_dials'])
|
336 |
+
|
337 |
+
# valid_losses = []
|
338 |
+
btm = time.time()
|
339 |
+
result_collection = {}
|
340 |
+
with torch.no_grad():
|
341 |
+
# Adding this index to allow for quick testing of evaluation
|
342 |
+
dialogues_to_run = 1
|
343 |
+
for dial_idx, dialog in tqdm(enumerate(eval_data)):
|
344 |
+
if dialogues_to_run == 0:
|
345 |
+
break
|
346 |
+
dialogues_to_run -= 1
|
347 |
+
|
348 |
+
pv_turn = {}
|
349 |
+
for turn_idx, turn in enumerate(dialog):
|
350 |
+
first_turn = turn_idx == 0
|
351 |
+
inputs = self.reader.convert_turn_eval(turn, pv_turn, first_turn)
|
352 |
+
inputs = self.add_torch_input_eval(inputs)
|
353 |
+
|
354 |
+
# fail to generate new tokens, if max_length not set
|
355 |
+
context_length = len(inputs["context"])
|
356 |
+
if cfg.use_true_curr_bspn: # generate act, response
|
357 |
+
max_len = 60
|
358 |
+
if not cfg.use_true_curr_aspn:
|
359 |
+
max_len = 80
|
360 |
+
|
361 |
+
outputs = self.model.generate(
|
362 |
+
input_ids=inputs["context_tensor"],
|
363 |
+
max_length=context_length + max_len,
|
364 |
+
temperature=0.7, # top_p=0.9, num_beams=4,
|
365 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
366 |
+
eos_token_id=self.tokenizer.encode(["<eos_r>"])[0],
|
367 |
+
)
|
368 |
+
# no_repeat_ngram_size=4
|
369 |
+
# turn['generated'] = self.tokenizer.decode(outputs[0])
|
370 |
+
|
371 |
+
# resp_gen, need to trim previous context
|
372 |
+
generated = outputs[0].cpu().numpy().tolist()
|
373 |
+
generated = generated[context_length - 1 :]
|
374 |
+
|
375 |
+
try:
|
376 |
+
decoded = self.decode_generated_act_resp(generated)
|
377 |
+
except ValueError as exception:
|
378 |
+
logging.info(str(exception))
|
379 |
+
logging.info(self.tokenizer.decode(generated))
|
380 |
+
decoded = {"resp": [], "bspn": [], "aspn": []}
|
381 |
+
|
382 |
+
else: # predict bspn, access db, then generate act and resp
|
383 |
+
outputs = self.model.generate(
|
384 |
+
input_ids=inputs["context_tensor"],
|
385 |
+
max_length=context_length + 60,
|
386 |
+
temperature=0.7, # top_p=0.9, num_beams=4,
|
387 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
388 |
+
eos_token_id=self.tokenizer.encode(["<eos_b>"])[0],
|
389 |
+
)
|
390 |
+
generated_bs = outputs[0].cpu().numpy().tolist()
|
391 |
+
# generated_bs = generated_bs[context_length-1:]
|
392 |
+
bspn_gen = self.decode_generated_bspn(generated_bs[context_length - 1 :])
|
393 |
+
# check DB result
|
394 |
+
if cfg.use_true_db_pointer:
|
395 |
+
# db_result = self.reader.bspan_to_DBpointer(
|
396 |
+
# self.tokenizer.decode(turn['bspn']), turn['turn_domain'])
|
397 |
+
db = turn["db"]
|
398 |
+
else:
|
399 |
+
db_result = self.reader.bspan_to_DBpointer(
|
400 |
+
self.tokenizer.decode(bspn_gen), turn["turn_domain"]
|
401 |
+
)
|
402 |
+
db = self.tokenizer.convert_tokens_to_ids(
|
403 |
+
self.tokenizer.tokenize("<sos_db> " + db_result + " <eos_db>")
|
404 |
+
) + self.tokenizer.encode(["<sos_a>"])
|
405 |
+
inputs["context_tensor_db"] = torch.tensor([inputs["context"][:-1] + bspn_gen + db]).to(
|
406 |
+
self.device
|
407 |
+
)
|
408 |
+
context_length = len(inputs["context_tensor_db"][0])
|
409 |
+
outputs_db = self.model.generate(
|
410 |
+
input_ids=inputs["context_tensor_db"],
|
411 |
+
max_length=context_length + 80,
|
412 |
+
temperature=0.7, # top_p=0.9, num_beams=4,
|
413 |
+
pad_token_id=self.tokenizer.eos_token_id,
|
414 |
+
eos_token_id=self.tokenizer.encode(["<eos_r>"])[0],
|
415 |
+
)
|
416 |
+
generated_ar = outputs_db[0].cpu().numpy().tolist()
|
417 |
+
generated_ar = generated_ar[context_length - 1 :]
|
418 |
+
try:
|
419 |
+
decoded = self.decode_generated_act_resp(generated_ar)
|
420 |
+
decoded["bspn"] = bspn_gen
|
421 |
+
except ValueError:
|
422 |
+
# NOTE: the below logging is commented out because when running evaluation
|
423 |
+
# on early checkpoints of gpt2, the generated response is almost always
|
424 |
+
# missing <eos_b> and it kills the GPU due to constant decoding (plus it swamps the logs)
|
425 |
+
|
426 |
+
# logging.info(str(exception))
|
427 |
+
# logging.info(self.tokenizer.decode(generated_ar))
|
428 |
+
decoded = {"resp": [], "bspn": [], "aspn": []}
|
429 |
+
|
430 |
+
turn["resp_gen"] = decoded["resp"]
|
431 |
+
turn["bspn_gen"] = turn["bspn"] if cfg.use_true_curr_bspn else decoded["bspn"]
|
432 |
+
turn["aspn_gen"] = turn["aspn"] if cfg.use_true_curr_aspn else decoded["aspn"]
|
433 |
+
turn["dspn_gen"] = turn["dspn"]
|
434 |
+
|
435 |
+
# check DB results
|
436 |
+
# db_result = self.reader.bspan_to_DBpointer(self.tokenizer.decode(turn['bspn']),
|
437 |
+
# turn['turn_domain'])
|
438 |
+
# if db_result[0] == 1: # no match
|
439 |
+
# print('gt:', self.tokenizer.decode(turn['aspn']), '
|
440 |
+
# |gen:', self.tokenizer.decode(decoded['aspn']))
|
441 |
+
# print('gen_resp: ', self.tokenizer.decode(decoded['resp']))
|
442 |
+
# print('gt_resp: ', self.tokenizer.decode(turn['resp']), '\n')
|
443 |
+
|
444 |
+
# all true previous context
|
445 |
+
pv_turn["labels"] = inputs["labels"]
|
446 |
+
pv_turn["resp"] = turn["resp"] if cfg.use_true_prev_resp else decoded["resp"]
|
447 |
+
pv_turn["bspn"] = turn["bspn"] if cfg.use_true_prev_bspn else decoded["bspn"]
|
448 |
+
pv_turn["db"] = turn["db"] if cfg.use_true_curr_bspn else db
|
449 |
+
pv_turn["aspn"] = turn["aspn"] if cfg.use_true_prev_aspn else decoded["aspn"]
|
450 |
+
|
451 |
+
turn_result = self.reader.inverse_transpose_turn(dialog)
|
452 |
+
result_collection.update(turn_result)
|
453 |
+
|
454 |
+
for dialog, turns in turn_result.items():
|
455 |
+
for turn in turns:
|
456 |
+
curr_turn_plain = [
|
457 |
+
dialog,
|
458 |
+
turn["turn_num"],
|
459 |
+
turn["turn_domain"],
|
460 |
+
turn["pointer"],
|
461 |
+
]
|
462 |
+
curr_turn_tokenised = [
|
463 |
+
self.tokenizer.decode(turn[key])
|
464 |
+
for key in turn.keys()
|
465 |
+
if key != "pointer" and key != "turn_domain" and key != "turn_num"
|
466 |
+
]
|
467 |
+
curr_turn_data = curr_turn_plain + curr_turn_tokenised
|
468 |
+
val_table.add_data(*curr_turn_data)
|
469 |
+
|
470 |
+
logging.info("inference time: {:.2f} min".format((time.time() - btm) / 60))
|
471 |
+
# score
|
472 |
+
btm = time.time()
|
473 |
+
results, _ = self.reader.wrap_result_lm(result_collection)
|
474 |
+
bleu, success, match = self.evaluator.validation_metric(results)
|
475 |
+
logging.info("Scoring time: {:.2f} min".format((time.time() - btm) / 60))
|
476 |
+
score = 0.5 * (success + match) + bleu
|
477 |
+
# valid_loss = 130 - score
|
478 |
+
logging.info(
|
479 |
+
"validation [CTR] match: %2.2f success: %2.2f bleu: %2.2f score: %.2f" % (match, success, bleu, score)
|
480 |
+
)
|
481 |
+
eval_results = {}
|
482 |
+
eval_results["bleu"] = bleu
|
483 |
+
eval_results["success"] = success
|
484 |
+
eval_results["match"] = match
|
485 |
+
eval_results["score"] = score
|
486 |
+
eval_results["result"] = "validation [CTR] match: %2.2f success: %2.2f bleu: %2.2f score: %.2f" % (
|
487 |
+
match,
|
488 |
+
success,
|
489 |
+
bleu,
|
490 |
+
score,
|
491 |
+
)
|
492 |
+
|
493 |
+
wandb.log(
|
494 |
+
{
|
495 |
+
"bleu": eval_results["bleu"],
|
496 |
+
"success": eval_results["success"],
|
497 |
+
"match": eval_results["match"],
|
498 |
+
"score": eval_results["score"],
|
499 |
+
}
|
500 |
+
)
|
501 |
+
|
502 |
+
model_setting, epoch_setting = (
|
503 |
+
cfg.eval_load_path.split("/")[1],
|
504 |
+
cfg.eval_load_path.split("/")[2],
|
505 |
+
)
|
506 |
+
eval_on = "-".join(cfg.exp_domains)
|
507 |
+
if data == "test":
|
508 |
+
eval_on += "_test"
|
509 |
+
if not os.path.exists(cfg.log_path):
|
510 |
+
os.mkdir(cfg.log_path)
|
511 |
+
log_file_name = os.path.join(cfg.log_path, model_setting + "-" + eval_on + ".json")
|
512 |
+
if os.path.exists(log_file_name):
|
513 |
+
eval_to_json = json.load(open(log_file_name, "r"))
|
514 |
+
eval_to_json[epoch_setting] = eval_results
|
515 |
+
json.dump(eval_to_json, open(log_file_name, "w"), indent=2)
|
516 |
+
else:
|
517 |
+
eval_to_json = {}
|
518 |
+
eval_to_json[epoch_setting] = eval_results
|
519 |
+
json.dump(eval_to_json, open(log_file_name, "w"), indent=2)
|
520 |
+
logging.info("update eval results to {}".format(log_file_name))
|
521 |
+
|
522 |
+
# log predictions table to wandb, giving it a name
|
523 |
+
test_data_at.add(val_table, "predictions")
|
524 |
+
wandb.run.log_artifact(test_data_at)
|
525 |
+
|
526 |
+
if cfg.mode != "train":
|
527 |
+
# Mark the run as finished on wandb
|
528 |
+
wandb.finish()
|
529 |
+
|
530 |
+
return eval_results
|
531 |
+
|
532 |
+
def decode_generated_act_resp(self, generated):
|
533 |
+
"""
|
534 |
+
decode generated
|
535 |
+
return decoded['resp'] ('bspn', 'aspn')
|
536 |
+
"""
|
537 |
+
decoded = {}
|
538 |
+
eos_a_id = self.tokenizer.encode(["<eos_a>"])[0]
|
539 |
+
eos_r_id = self.tokenizer.encode(["<eos_r>"])[0]
|
540 |
+
# eos_b_id = self.tokenizer.encode(["<eos_b>"])[0]
|
541 |
+
|
542 |
+
# eos_r may not exists if gpt2 generated repetitive words.
|
543 |
+
if eos_r_id in generated:
|
544 |
+
eos_r_idx = generated.index(eos_r_id)
|
545 |
+
else:
|
546 |
+
eos_r_idx = len(generated) - 1
|
547 |
+
# NOTE: the below logging is commented out because when running evaluation
|
548 |
+
# on early checkpoints of gpt2, the generated response is almost always missing
|
549 |
+
# <eos_r> and it kills the GPU due to constant decoding (plus it swamps the logs)
|
550 |
+
|
551 |
+
# logging.info('eos_r not in generated: ' +
|
552 |
+
# self.tokenizer.decode(generated))
|
553 |
+
|
554 |
+
if cfg.use_true_curr_aspn: # only predict resp
|
555 |
+
decoded["resp"] = generated[: eos_r_idx + 1]
|
556 |
+
else: # predicted aspn, resp
|
557 |
+
eos_a_idx = generated.index(eos_a_id)
|
558 |
+
decoded["aspn"] = generated[: eos_a_idx + 1]
|
559 |
+
decoded["resp"] = generated[eos_a_idx + 1 : eos_r_idx + 1]
|
560 |
+
# if cfg.use_true_curr_bspn:
|
561 |
+
|
562 |
+
# else: # predict bspn aspn resp
|
563 |
+
# eos_b_idx = generated.index(eos_b_id)
|
564 |
+
# eos_a_idx = generated.index(eos_a_id)
|
565 |
+
# decoded['bspn'] = generated[: eos_b_idx+1]
|
566 |
+
# decoded['aspn'] = generated[eos_b_idx+1: eos_a_idx+1]
|
567 |
+
# decoded['resp'] = generated[eos_a_idx+1: eos_r_idx+1]
|
568 |
+
return decoded
|
569 |
+
|
570 |
+
def decode_generated_bspn(self, generated):
|
571 |
+
eos_b_id = self.tokenizer.encode(["<eos_b>"])[0]
|
572 |
+
if eos_b_id in generated:
|
573 |
+
eos_b_idx = generated.index(eos_b_id)
|
574 |
+
else:
|
575 |
+
eos_b_idx = len(generated) - 1
|
576 |
+
return generated[: eos_b_idx + 1]
|
577 |
+
|
578 |
+
|
579 |
+
def parse_arg_cfg(args):
|
580 |
+
# add args to cfg
|
581 |
+
if args.cfg:
|
582 |
+
for pair in args.cfg:
|
583 |
+
k, v = tuple(pair.split("="))
|
584 |
+
dtype = type(getattr(cfg, k))
|
585 |
+
if dtype == type(None):
|
586 |
+
raise ValueError()
|
587 |
+
if dtype is bool:
|
588 |
+
v = False if v == "False" else True
|
589 |
+
elif dtype is list:
|
590 |
+
v = v.split(",")
|
591 |
+
if k == "cuda_device":
|
592 |
+
v = [int(no) for no in v]
|
593 |
+
else:
|
594 |
+
v = dtype(v)
|
595 |
+
setattr(cfg, k, v)
|
596 |
+
return
|
597 |
+
|
598 |
+
|
599 |
+
def main():
|
600 |
+
if not os.path.exists("./models/UBAR/experiments"):
|
601 |
+
os.mkdir("./models/UBAR/experiments")
|
602 |
+
|
603 |
+
if not os.path.exists("./models/UBAR/experiments_21"):
|
604 |
+
os.mkdir("./models/UBAR/experiments_21")
|
605 |
+
|
606 |
+
parser = argparse.ArgumentParser()
|
607 |
+
parser.add_argument("-mode")
|
608 |
+
parser.add_argument("-cfg", nargs="*")
|
609 |
+
args = parser.parse_args()
|
610 |
+
|
611 |
+
cfg.mode = args.mode
|
612 |
+
if args.mode == "test" or args.mode == "adjust":
|
613 |
+
parse_arg_cfg(args)
|
614 |
+
# cfg.model_path = cfg.eval_load_path
|
615 |
+
cfg.gpt_path = cfg.eval_load_path
|
616 |
+
else: # train
|
617 |
+
|
618 |
+
parse_arg_cfg(args)
|
619 |
+
if cfg.exp_path in ["", "to be generated"]:
|
620 |
+
# log file path, control the factors: seed, learning_rate, batch_size,
|
621 |
+
# early_stop_count, weight decay... cfg.exp_path = 'experiments/
|
622 |
+
# {}_{}_sd{}_lr{}_bs{}_sp{}_dc{}/'.format('-'.join(cfg.exp_domains),
|
623 |
+
# cfg.exp_no, cfg.seed, cfg.lr, cfg.batch_size,
|
624 |
+
# cfg.early_stop_count, cfg.weight_decay_count)
|
625 |
+
|
626 |
+
experiments_path = (
|
627 |
+
"./models/UBAR/experiments" if "all" in cfg.exp_domains else "./models/experiments_Xdomain"
|
628 |
+
)
|
629 |
+
cfg.exp_path = os.path.join(
|
630 |
+
experiments_path,
|
631 |
+
"{}_{}_sd{}_lr{}_bs{}_ga{}".format(
|
632 |
+
"-".join(cfg.exp_domains),
|
633 |
+
cfg.exp_no,
|
634 |
+
cfg.seed,
|
635 |
+
cfg.lr,
|
636 |
+
cfg.batch_size,
|
637 |
+
cfg.gradient_accumulation_steps,
|
638 |
+
),
|
639 |
+
)
|
640 |
+
logging.info("save path:", cfg.exp_path)
|
641 |
+
if cfg.save_log:
|
642 |
+
if not os.path.exists(cfg.exp_path):
|
643 |
+
os.mkdir(cfg.exp_path)
|
644 |
+
|
645 |
+
# to gpt later
|
646 |
+
cfg.model_path = os.path.join(cfg.exp_path, "model.pkl")
|
647 |
+
cfg.result_path = os.path.join(cfg.exp_path, "result.csv")
|
648 |
+
cfg.vocab_path_eval = os.path.join(cfg.exp_path, "vocab")
|
649 |
+
cfg.eval_load_path = cfg.exp_path
|
650 |
+
|
651 |
+
cfg._init_logging_handler(args.mode)
|
652 |
+
if cfg.cuda:
|
653 |
+
if len(cfg.cuda_device) == 1:
|
654 |
+
cfg.multi_gpu = False
|
655 |
+
# torch.cuda.set_device(cfg.cuda_device[0])
|
656 |
+
device = torch.device("cuda:{}".format(cfg.cuda_device[0]))
|
657 |
+
else:
|
658 |
+
pass # multi-gpu
|
659 |
+
else:
|
660 |
+
device = torch.device("cpu")
|
661 |
+
# logging.info('Device: {}'.format(torch.cuda.current_device()))
|
662 |
+
|
663 |
+
# fix random seed
|
664 |
+
torch.manual_seed(cfg.seed)
|
665 |
+
torch.cuda.manual_seed(cfg.seed)
|
666 |
+
random.seed(cfg.seed)
|
667 |
+
np.random.seed(cfg.seed)
|
668 |
+
|
669 |
+
# initialize model
|
670 |
+
m = Model(device)
|
671 |
+
|
672 |
+
if args.mode == "train": # train
|
673 |
+
if cfg.save_log: # save cfg details.
|
674 |
+
pass
|
675 |
+
m.train()
|
676 |
+
else: # test
|
677 |
+
logging.info(
|
678 |
+
"Generate setting: \n\t use true_prev_bspn={} \n\t use true_prev_aspn={} \n\t use true_db_pointer={} \
|
679 |
+
\n\t use true_prev_resp={} \n\t use true_curr_bspn={} \n\t use true_curr_aspn={} \
|
680 |
+
\n\t use_all_previous_context={}".format(
|
681 |
+
cfg.use_true_prev_bspn,
|
682 |
+
cfg.use_true_prev_aspn,
|
683 |
+
cfg.use_true_db_pointer,
|
684 |
+
cfg.use_true_prev_resp,
|
685 |
+
cfg.use_true_curr_bspn,
|
686 |
+
cfg.use_true_curr_aspn,
|
687 |
+
cfg.use_all_previous_context,
|
688 |
+
)
|
689 |
+
)
|
690 |
+
|
691 |
+
logging.info("Running eval on test")
|
692 |
+
m.validate(cfg.eval_set)
|
693 |
+
logging.info("Evaluation finished")
|
694 |
+
|
695 |
+
|
696 |
+
if __name__ == "__main__":
|
697 |
+
main()
|
scripts/agent_agent.yaml
ADDED
File without changes
|
scripts/crazyneuraluser.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: crazyneuraluser
|
3 |
+
Version: 0.0.post1.dev55+g3c295fb.d20220606
|
4 |
+
Summary: Add a short description here!
|
5 |
+
Home-page: https://github.com/pyscaffold/pyscaffold/
|
6 |
+
Author: Extended by Alistair McLeay, original code by Alexandru Coca
|
7 |
+
Author-email: am@alistairmcleay.com and alexcoca23@yahoo.co.uk
|
8 |
+
License: MIT
|
9 |
+
Project-URL: Documentation, https://pyscaffold.org/
|
10 |
+
Platform: any
|
11 |
+
Classifier: Development Status :: 4 - Beta
|
12 |
+
Classifier: Programming Language :: Python
|
13 |
+
Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
|
14 |
+
Provides-Extra: testing
|
15 |
+
License-File: LICENSE.txt
|
16 |
+
License-File: AUTHORS.md
|
17 |
+
|
18 |
+
# Cambridge Masters Project
|
19 |
+
Joint Learning of Practical Dialogue Systems and User Simulators
|
20 |
+
|
21 |
+
## Environment setup
|
22 |
+
|
23 |
+
1. Create an environment `crazyneuraluser` with the help of [conda]
|
24 |
+
```
|
25 |
+
conda env create -f environment.yml
|
26 |
+
```
|
27 |
+
2. Activate the new environment with:
|
28 |
+
```
|
29 |
+
conda activate crazyneuraluser
|
30 |
+
```
|
31 |
+
3. Install a version of `pytorch` compatible with your hardware (see the [pytorch website](https://pytorch.org/get-started/previous-versions/)). E.g.:
|
32 |
+
```
|
33 |
+
pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
|
34 |
+
```
|
35 |
+
|
36 |
+
4. Install `spacy` and download the tokenization tool in spacy:
|
37 |
+
```
|
38 |
+
pip install spacy'
|
39 |
+
python -m spacy download en_core_web_sm
|
40 |
+
```
|
41 |
+
|
42 |
+
### Generating dialogues through agent-agent interaction
|
43 |
+
|
44 |
+
To generate dialogues, first change working directory to the `baselines` directory. Run the command
|
45 |
+
```
|
46 |
+
python baselines_setup.py
|
47 |
+
```
|
48 |
+
to prepare `convlab2` for running the baselines.
|
49 |
+
|
50 |
+
#### Generating dialogues conditioned on randomly sampled goals
|
51 |
+
|
52 |
+
Select one of the available configurations in the `configs` directory and run the command
|
53 |
+
```
|
54 |
+
python simulate_agent_interaction.py --config /rel/path/to/chosen/config
|
55 |
+
```
|
56 |
+
to generate dialogues conditioned on randomly sampled goals according to the `convlab2` goal model. The dialogues will be be saved automatically in the `models` directory, under a directory whose name depends on the configuration run. The `models` directory is located in the parent directory of the `baselines` directory. The `metadata.json` file saved with the dialogues contains information about the data generation process.
|
57 |
+
|
58 |
+
#### Generating dialogues conditioned on `MultiWOZ2.1` goals
|
59 |
+
|
60 |
+
To generate the entire corpus, simply pass the `--goals-path /path/to/multiwoz2.1/data.json/file` flag to `simulate_agent_interaction.py`. To generate the `test/val` split additionally pass the `--filter-path /path/to/multiwoz2.1/test-or-valListFile` argument to `simulate_agent_interaction.py`. You can use the `generate_multiwoz21_train_id_file` function in `baselines/utils.py` to generate `trainListFile` which can then be passed via the `--filter-path` argument to the dialogue generation script in order to generate dialogues conditioned on the `MultiWOZ2.1` training goals.
|
61 |
+
|
62 |
+
### Converting the generated dialogues to SGD-like format
|
63 |
+
|
64 |
+
The `create_data_from_multiwoz.py` script can be used to convert the generated dialogues to SGD format, necessary for evaluation. It is based on the script provided by Google for DSTC8, but with additional functionality such as:
|
65 |
+
|
66 |
+
- conversion of slot names as annotated in the MultiWOZ 2.1 dialogue acts to different slot names, specified through the `--slots_convention` argument. Options are `multiwoz22` to convert the slots to the same slots as defined in the MultiWOZ 2.2 dataset whreas the `multiwoz_goals` converts the slot names to the names used in the dialogue goal and state tracking annotations.
|
67 |
+
|
68 |
+
- addition of system and user `nlu` fields for every turn
|
69 |
+
|
70 |
+
- option to perform cleaning operations on the goals to ensure a standard format is received by the evaluator.
|
71 |
+
|
72 |
+
The conversion is done according to the `schema.json` file in the `baselines` directory, which is the same as used by `DSTC8` conversion except for the addition of the `police` domain. Type ``python create_data_from_multiwoz.py --helpfull`` to see a full list of flags and usage.
|
73 |
+
|
74 |
+
## Installation
|
75 |
+
|
76 |
+
The recommended way to use this repository is to develop the core code under `src/crazyneuraluser`. The experiments/exporatory analysis making use of the core package code should be placed outside the library and imported. See more guidance under the [Project Organisation](#project-organization) section below.
|
77 |
+
|
78 |
+
To create an environment for the package, make sure you have deactivated all `conda` environments. Then:
|
79 |
+
|
80 |
+
1. Create an environment `crazyneuraluser` with the help of [conda]:
|
81 |
+
```
|
82 |
+
conda env create -f environment.yml
|
83 |
+
```
|
84 |
+
2. Add the developer dependencies to this environment with the help of [conda]:
|
85 |
+
```
|
86 |
+
conda env update -f dev_environment.yml
|
87 |
+
```
|
88 |
+
|
89 |
+
Optional and needed only once after `git clone`:
|
90 |
+
|
91 |
+
3. install several [pre-commit] git hooks with:
|
92 |
+
```bash
|
93 |
+
pre-commit install
|
94 |
+
# You _are encouraged_ to run `pre-commit autoupdate`
|
95 |
+
```
|
96 |
+
and checkout the configuration under `.pre-commit-config.yaml`.
|
97 |
+
The `-n, --no-verify` flag of `git commit` can be used to deactivate pre-commit hooks temporarily.
|
98 |
+
|
99 |
+
4. install [nbstripout] git hooks to remove the output cells of committed notebooks with:
|
100 |
+
```bash
|
101 |
+
nbstripout --install --attributes notebooks/.gitattributes
|
102 |
+
```
|
103 |
+
This is useful to avoid large diffs due to plots in your notebooks.
|
104 |
+
A simple `nbstripout --uninstall` will revert these changes.
|
105 |
+
|
106 |
+
Then take a look into the `scripts` and `notebooks` folders.
|
107 |
+
|
108 |
+
## Dependency Management & Reproducibility
|
109 |
+
|
110 |
+
1. Always keep your abstract (unpinned) dependencies updated in `environment.yml` and eventually
|
111 |
+
in `setup.cfg` if you want to ship and install your package via `pip` later on.
|
112 |
+
2. Create concrete dependencies as `environment.lock.yml` for the exact reproduction of your
|
113 |
+
environment with:
|
114 |
+
```bash
|
115 |
+
conda env export -n crazyneuraluser -f environment.lock.yml
|
116 |
+
```
|
117 |
+
For multi-OS development, consider using `--no-builds` during the export.
|
118 |
+
3. Update your current environment with respect to a new `environment.lock.yml` using:
|
119 |
+
```bash
|
120 |
+
conda env update -f environment.lock.yml --prune
|
121 |
+
```
|
122 |
+
## Project Organization
|
123 |
+
|
124 |
+
```
|
125 |
+
├── AUTHORS.md <- List of developers and maintainers.
|
126 |
+
├── CHANGELOG.md <- Changelog to keep track of new features and fixes.
|
127 |
+
├── LICENSE.txt <- License as chosen on the command-line.
|
128 |
+
├── README.md <- The top-level README for developers.
|
129 |
+
├── configs <- Directory for configurations of model & application.
|
130 |
+
├── data
|
131 |
+
│ ├── external <- Data from third party sources.
|
132 |
+
│ ├── interim <- Intermediate data that has been transformed.
|
133 |
+
│ ├── processed <- The final, canonical data sets for modeling.
|
134 |
+
│ └── raw <- The original, immutable data dump.
|
135 |
+
├── docs <- Directory for Sphinx documentation in rst or md.
|
136 |
+
├── environment.yml <- The conda environment file for reproducibility.
|
137 |
+
├── models <- Trained and serialized models, model predictions,
|
138 |
+
│ or model summaries.
|
139 |
+
├── notebooks <- Jupyter notebooks. Naming convention is a number (for
|
140 |
+
│ ordering), the creator's initials and a description,
|
141 |
+
│ e.g. `1.0-fw-initial-data-exploration`.
|
142 |
+
├── pyproject.toml <- Build system configuration. Do not change!
|
143 |
+
├── references <- Data dictionaries, manuals, and all other materials.
|
144 |
+
├── reports <- Generated analysis as HTML, PDF, LaTeX, etc.
|
145 |
+
│ └── figures <- Generated plots and figures for reports.
|
146 |
+
├── scripts <- Analysis and production scripts which import the
|
147 |
+
│ actual Python package, e.g. train_model.py.
|
148 |
+
├── setup.cfg <- Declarative configuration of your project.
|
149 |
+
├── setup.py <- Use `pip install -e .` to install for development or
|
150 |
+
| or create a distribution with `tox -e build`.
|
151 |
+
├── src
|
152 |
+
│ └── crazyneuraluser <- Actual Python package where the main functionality goes.
|
153 |
+
├── tests <- Unit tests which can be run with `py.test`.
|
154 |
+
├── .coveragerc <- Configuration for coverage reports of unit tests.
|
155 |
+
├── .isort.cfg <- Configuration for git hook that sorts imports.
|
156 |
+
└── .pre-commit-config.yaml <- Configuration of pre-commit git hooks.
|
157 |
+
```
|
158 |
+
|
159 |
+
<!-- pyscaffold-notes -->
|
160 |
+
|
161 |
+
## Note
|
162 |
+
|
163 |
+
This project has been set up using [PyScaffold] 4.0.1 and the [dsproject extension] 0.6.1.
|
164 |
+
|
165 |
+
[conda]: https://docs.conda.io/
|
166 |
+
[pre-commit]: https://pre-commit.com/
|
167 |
+
[Jupyter]: https://jupyter.org/
|
168 |
+
[nbstripout]: https://github.com/kynan/nbstripout
|
169 |
+
[Google style]: http://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
|
170 |
+
[PyScaffold]: https://pyscaffold.org/
|
171 |
+
[dsproject extension]: https://github.com/pyscaffold/pyscaffoldext-dsproject
|
scripts/crazyneuraluser.egg-info/SOURCES.txt
ADDED
File without changes
|
scripts/crazyneuraluser.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
scripts/crazyneuraluser.egg-info/not-zip-safe
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
scripts/crazyneuraluser.egg-info/requires.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers==4.18.0
|
2 |
+
tqdm==4.64.0
|
3 |
+
wandb==0.12.16
|
4 |
+
nltk==3.7
|
5 |
+
sklearn==0.0
|
6 |
+
tensorboard==2.9.0
|
7 |
+
spacy==3.3.0
|
8 |
+
|
9 |
+
[:python_version < "3.8"]
|
10 |
+
importlib-metadata
|
11 |
+
|
12 |
+
[testing]
|
13 |
+
setuptools
|
14 |
+
pytest
|
15 |
+
pytest-cov
|
scripts/crazyneuraluser.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
crazyneuraluser
|
scripts/simulate_interaction.py
ADDED
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import traceback
|
3 |
+
import pandas as pd
|
4 |
+
|
5 |
+
# from tqdm import tqdm
|
6 |
+
from UBAR_code.interaction import UBAR_interact
|
7 |
+
from user_model_code.interaction import multiwoz_interact
|
8 |
+
from UBAR_code.interaction.UBAR_interact import bcolors
|
9 |
+
|
10 |
+
|
11 |
+
def instantiate_agents():
|
12 |
+
|
13 |
+
UBAR_checkpoint_path = "models/UBAR/experiments/distilgpt-2_sd11_lr0.0001_bs16_ga2/epoch50_trloss0.59_gpt2"
|
14 |
+
user_model_checkpoint_path = "models/user_model/MultiWOZ-full_checkpoint_step340k"
|
15 |
+
|
16 |
+
sys_model = UBAR_interact.UbarSystemModel(
|
17 |
+
"UBAR_sys_model", UBAR_checkpoint_path, "scripts/UBAR_code/interaction/config.yaml"
|
18 |
+
)
|
19 |
+
|
20 |
+
user_model = multiwoz_interact.NeuralAgent(
|
21 |
+
"user", user_model_checkpoint_path, "scripts/user_model_code/interaction/config.yaml"
|
22 |
+
)
|
23 |
+
|
24 |
+
return sys_model, user_model
|
25 |
+
|
26 |
+
|
27 |
+
def read_multiwoz_data():
|
28 |
+
"""
|
29 |
+
Read the multiwoz 2.0 raw data from the .json file
|
30 |
+
"""
|
31 |
+
raw_mwoz_20_path = "data/raw/UBAR/multi-woz/data.json"
|
32 |
+
df_raw_mwoz = pd.read_json(raw_mwoz_20_path)
|
33 |
+
return df_raw_mwoz
|
34 |
+
|
35 |
+
|
36 |
+
def load_test_val_lists():
|
37 |
+
val_list_file = "data/raw/UBAR/multi-woz/valListFile.json"
|
38 |
+
test_list_file = "data/raw/UBAR/multi-woz/testListFile.json"
|
39 |
+
|
40 |
+
with open(val_list_file, "r") as f:
|
41 |
+
val_list = f.readlines()
|
42 |
+
val_list = [x.strip() for x in val_list]
|
43 |
+
|
44 |
+
with open(test_list_file, "r") as f:
|
45 |
+
test_list = f.readlines()
|
46 |
+
test_list = [x.strip() for x in test_list]
|
47 |
+
|
48 |
+
return val_list, test_list
|
49 |
+
|
50 |
+
|
51 |
+
def main(
|
52 |
+
write_to_file=False, ground_truth_system_responses=False, train_only=True, n_dialogues="all", log_successes=False
|
53 |
+
):
|
54 |
+
sys_model, user_model = instantiate_agents()
|
55 |
+
|
56 |
+
# TODO: move hardcoded vars into config file
|
57 |
+
raw_mwoz_20_path = "data/raw/UBAR/multi-woz/data.json"
|
58 |
+
user_utterances_out_path = "data/preprocessed/UBAR/user_utterances_from_simulator.txt"
|
59 |
+
logging_successes_path = "data/preprocessed/UBAR/logging_successes"
|
60 |
+
sys_model.print_intermediary_info = False
|
61 |
+
user_model.print_intermediary_info = False
|
62 |
+
|
63 |
+
df_raw_mwoz = pd.read_json(raw_mwoz_20_path)
|
64 |
+
if n_dialogues == "all":
|
65 |
+
n_dialogues = len(df_raw_mwoz.columns)
|
66 |
+
|
67 |
+
curr_dialogue_user_utterances_formatted = []
|
68 |
+
|
69 |
+
print("Loading goals...")
|
70 |
+
goals = multiwoz_interact.read_multiWOZ_20_goals(raw_mwoz_20_path, n_dialogues)
|
71 |
+
|
72 |
+
# Write column headers
|
73 |
+
if write_to_file:
|
74 |
+
with open(user_utterances_out_path, "w") as f:
|
75 |
+
f.write("Dialogue #\tDialogue ID\tTurn #\tSystem Response\n")
|
76 |
+
|
77 |
+
print("Loading data...")
|
78 |
+
df_mwoz_data = read_multiwoz_data()
|
79 |
+
val_list, test_list = load_test_val_lists()
|
80 |
+
|
81 |
+
successful_dialogues = 0
|
82 |
+
total_dialogues_generated = 0 # train dialogues only
|
83 |
+
for dialogue_idx, (goal, dialogue_filename) in enumerate(zip(goals, df_mwoz_data.columns)):
|
84 |
+
if log_successes:
|
85 |
+
# log successful_dialogues to logging_successes_path every 100 dialogues
|
86 |
+
if dialogue_idx % 100 == 0:
|
87 |
+
with open(logging_successes_path, "w") as f:
|
88 |
+
f.write(str(successful_dialogues) + " / " + str(total_dialogues_generated))
|
89 |
+
|
90 |
+
curr_dialogue_user_utterances_formatted = []
|
91 |
+
if train_only:
|
92 |
+
if dialogue_filename in val_list or dialogue_filename in test_list:
|
93 |
+
continue
|
94 |
+
|
95 |
+
total_dialogues_generated += 1
|
96 |
+
print("Dialogue: {}".format(dialogue_filename))
|
97 |
+
|
98 |
+
# There are occasionally exceptions thrown from one of the agents, usually the user
|
99 |
+
# In this case we simply continue to the next dialogue
|
100 |
+
try:
|
101 |
+
# Reset state after each dialogue
|
102 |
+
sys_model.init_session()
|
103 |
+
user_model.init_session(ini_goal=goal)
|
104 |
+
sys_response = ""
|
105 |
+
|
106 |
+
for turn_idx in range(50):
|
107 |
+
# Turn idx in this case represents the turn as one user utterance AND one system response
|
108 |
+
usr_response_raw_data_idx = turn_idx * 2
|
109 |
+
sys_response_raw_data_idx = turn_idx * 2 + 1
|
110 |
+
|
111 |
+
user_utterance = user_model.response(sys_response)
|
112 |
+
print(bcolors.OKBLUE + "User: " + bcolors.ENDC + user_utterance)
|
113 |
+
|
114 |
+
if write_to_file:
|
115 |
+
user_utterance = user_utterance.replace("\n", " ")
|
116 |
+
curr_dialogue_user_utterances_formatted.append(
|
117 |
+
str(dialogue_idx)
|
118 |
+
+ "\t"
|
119 |
+
+ dialogue_filename
|
120 |
+
+ "\t"
|
121 |
+
+ str(usr_response_raw_data_idx)
|
122 |
+
+ "\t"
|
123 |
+
+ user_utterance
|
124 |
+
+ "\n"
|
125 |
+
)
|
126 |
+
|
127 |
+
if user_model.is_terminated():
|
128 |
+
successful_dialogues += 1
|
129 |
+
print(bcolors.OKCYAN + "Dialogue terminated successfully!" + bcolors.ENDC)
|
130 |
+
print(bcolors.OKCYAN + "---" * 30 + bcolors.ENDC + "\n")
|
131 |
+
if write_to_file:
|
132 |
+
# Write whole dialogue to file
|
133 |
+
with open(user_utterances_out_path, "a") as f:
|
134 |
+
for line in curr_dialogue_user_utterances_formatted:
|
135 |
+
f.write(line)
|
136 |
+
break
|
137 |
+
|
138 |
+
# Next turn materials
|
139 |
+
if ground_truth_system_responses:
|
140 |
+
# If we are at the end of the ground truth dialogues
|
141 |
+
if len(df_mwoz_data.iloc[:, dialogue_idx].log) <= sys_response_raw_data_idx:
|
142 |
+
print(bcolors.RED + "Dialogue terminated unsuccessfully!" + bcolors.ENDC)
|
143 |
+
print(bcolors.RED + "---" * 30 + bcolors.ENDC + "\n")
|
144 |
+
break
|
145 |
+
sys_response = df_mwoz_data.iloc[:, dialogue_idx].log[sys_response_raw_data_idx]["text"]
|
146 |
+
else:
|
147 |
+
sys_response = sys_model.response(user_utterance, turn_idx)
|
148 |
+
capitalised_sys_response = sys_response[0].upper() + sys_response[1:]
|
149 |
+
print(bcolors.GREEN + "System: " + bcolors.ENDC + capitalised_sys_response)
|
150 |
+
|
151 |
+
except Exception:
|
152 |
+
print(bcolors.RED + "*" * 30 + bcolors.ENDC)
|
153 |
+
print(bcolors.RED + "Error in dialogue {}".format(dialogue_filename) + bcolors.ENDC)
|
154 |
+
print(bcolors.RED + "*" * 30 + bcolors.ENDC)
|
155 |
+
traceback.print_exc()
|
156 |
+
continue
|
157 |
+
|
158 |
+
print("Successful dialogues: {}".format(successful_dialogues))
|
159 |
+
print("Total dialogues: {}".format(n_dialogues))
|
160 |
+
print("% Successful Dialopues: {}".format(successful_dialogues / n_dialogues))
|
161 |
+
|
162 |
+
|
163 |
+
if __name__ == "__main__":
|
164 |
+
# TODO: move parameters to config file
|
165 |
+
# Fix the hacky mess below
|
166 |
+
ground_truth_system_responses = sys.argv[1]
|
167 |
+
if ground_truth_system_responses == "False":
|
168 |
+
ground_truth_system_responses = False
|
169 |
+
else:
|
170 |
+
ground_truth_system_responses = True
|
171 |
+
main(write_to_file=False, ground_truth_system_responses=ground_truth_system_responses)
|
scripts/template_train_model.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
import logging
|
3 |
+
import sys
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
import click
|
7 |
+
from IPython.core import ultratb
|
8 |
+
|
9 |
+
import crazyneuraluser
|
10 |
+
|
11 |
+
# fallback to debugger on error
|
12 |
+
sys.excepthook = ultratb.FormattedTB(mode="Verbose", color_scheme="Linux", call_pdb=1)
|
13 |
+
# turn UserWarning messages to errors to find the actual cause
|
14 |
+
# import warnings
|
15 |
+
# warnings.simplefilter("error")
|
16 |
+
|
17 |
+
_logger = logging.getLogger(__name__)
|
18 |
+
|
19 |
+
|
20 |
+
@click.command()
|
21 |
+
@click.option(
|
22 |
+
"-c",
|
23 |
+
"--config",
|
24 |
+
"cfg_path",
|
25 |
+
required=True,
|
26 |
+
type=click.Path(exists=True),
|
27 |
+
help="path to config file",
|
28 |
+
)
|
29 |
+
@click.option("--quiet", "log_level", flag_value=logging.WARNING, default=True)
|
30 |
+
@click.option("-v", "--verbose", "log_level", flag_value=logging.INFO)
|
31 |
+
@click.option("-vv", "--very-verbose", "log_level", flag_value=logging.DEBUG)
|
32 |
+
@click.version_option(crazyneuraluser.__version__)
|
33 |
+
def main(cfg_path: Path, log_level: int):
|
34 |
+
logging.basicConfig(
|
35 |
+
stream=sys.stdout,
|
36 |
+
level=log_level,
|
37 |
+
datefmt="%Y-%m-%d %H:%M",
|
38 |
+
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
39 |
+
)
|
40 |
+
# YOUR CODE GOES HERE! Keep the main functionality in src/crazyneuraluser
|
41 |
+
# est = crazyneuraluser.models.Estimator()
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == "__main__":
|
45 |
+
main()
|
scripts/user_model_code/__init__.py
ADDED
File without changes
|
scripts/user_model_code/decode.sh
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
experiment=$1
|
2 |
+
checkpoint=$2
|
3 |
+
|
4 |
+
if [[ "$experiment" == "SGD" ]]; then
|
5 |
+
echo "Conduct experiment with SGD dataset"
|
6 |
+
job_name='SGD-full'
|
7 |
+
data_list="sgd" # 165k training examples
|
8 |
+
eval_interval=50000 # evaluation interval
|
9 |
+
|
10 |
+
elif [[ "$experiment" == "MultiWOZ" ]]; then
|
11 |
+
echo "Conduct experiment with MulwiWOZ dataset"
|
12 |
+
job_name='MultiWOZ-full'
|
13 |
+
data_list="multiwoz" # 56k training examples
|
14 |
+
eval_interval=20000
|
15 |
+
|
16 |
+
elif [[ "$experiment" == "Joint" ]]; then
|
17 |
+
echo "Conduct experiment with SGD + MulwiWOZ dataset"
|
18 |
+
job_name='Joint-full'
|
19 |
+
data_list="sgd multiwoz" # 221k training examples
|
20 |
+
eval_interval=70000
|
21 |
+
|
22 |
+
else
|
23 |
+
echo "Unrecognised argument"
|
24 |
+
exit
|
25 |
+
fi
|
26 |
+
|
27 |
+
mkdir -p log decode
|
28 |
+
decode_file='decode/'$job_name'.json'
|
29 |
+
eye_browse_output=true # set to false for storing generation results in file
|
30 |
+
|
31 |
+
python main.py --mode='testing' \
|
32 |
+
--model_name=$job_name \
|
33 |
+
--checkpoint=$checkpoint \
|
34 |
+
--decode_file=$decode_file \
|
35 |
+
--data_dir="processed_data" \
|
36 |
+
--data_list=$data_list \
|
37 |
+
--eye_browse_output=$eye_browse_output
|
scripts/user_model_code/interaction/__init__.py
ADDED
File without changes
|
scripts/user_model_code/interaction/config.yaml
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model:
|
2 |
+
path: "./models/user_model/MultiWOZ-full_checkpoint_step340k"
|
3 |
+
goal_update:
|
4 |
+
finish_inform: "loose" # loose or strict
|
5 |
+
|
6 |
+
schema_path: "scripts/user_model_code/interaction/schema.json"
|
7 |
+
|
8 |
+
decode:
|
9 |
+
dec_max_len: 1024
|
10 |
+
num_beams: 1
|
11 |
+
temperature: 0.7
|
12 |
+
do_sample: False
|
scripts/user_model_code/interaction/multiwoz_interact.py
ADDED
@@ -0,0 +1,1034 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import random
|
3 |
+
import re
|
4 |
+
import sys
|
5 |
+
import traceback
|
6 |
+
from typing import List
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import pandas as pd
|
10 |
+
import torch
|
11 |
+
import transformers
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
from transformers import GPT2LMHeadModel, GPT2Tokenizer
|
14 |
+
from .utils import add_str, bcolors, find_segment, load_schema, segment_gen, wrap_element
|
15 |
+
|
16 |
+
|
17 |
+
class DummyPolicy:
|
18 |
+
def init_session(self, ini_goal): # noqa
|
19 |
+
self.goal = ini_goal # noqa
|
20 |
+
|
21 |
+
def get_goal(self) -> dict:
|
22 |
+
"""Returns current user goal.
|
23 |
+
|
24 |
+
Notes
|
25 |
+
-----
|
26 |
+
``hasattr`` user works around the fact that ``convlab2`` initialises the dialogue session
|
27 |
+
before we can explicitly pass the goal to the user model.
|
28 |
+
"""
|
29 |
+
if hasattr(self.goal, "domain_goals"):
|
30 |
+
return self.goal.domain_goals
|
31 |
+
# return {}
|
32 |
+
return self.goal # for consistency
|
33 |
+
|
34 |
+
|
35 |
+
def generation_func(model, input_ids, eos_id, dec_max_len):
|
36 |
+
"""Generation method using greedy search for Transformer v2.x"""
|
37 |
+
|
38 |
+
def _extend_mask(mask):
|
39 |
+
mask = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1)
|
40 |
+
return mask
|
41 |
+
|
42 |
+
# input_ids, attention_mask, token_type_ids = batch['input_ids'], batch['attention_mask'], batch['token_type_ids']
|
43 |
+
batch_size = input_ids.size(0)
|
44 |
+
attention_mask = torch.ones_like(input_ids)
|
45 |
+
past = None
|
46 |
+
finish_sent = [False for _ in range(batch_size)]
|
47 |
+
for i in range(dec_max_len):
|
48 |
+
logits, past = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=None).values()
|
49 |
+
|
50 |
+
# logits: (B, T, V), T=1 when past is passed
|
51 |
+
next_token_logits = logits[:, -1, :]
|
52 |
+
next_token = torch.argmax(next_token_logits, dim=-1)
|
53 |
+
input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
|
54 |
+
attention_mask = _extend_mask(attention_mask)
|
55 |
+
|
56 |
+
for bs_idx, token_id in enumerate(next_token):
|
57 |
+
if finish_sent[bs_idx] is False and token_id.item() == eos_id: # first produce <eos>
|
58 |
+
finish_sent[bs_idx] = True
|
59 |
+
if sum(finish_sent) == batch_size:
|
60 |
+
break
|
61 |
+
return input_ids
|
62 |
+
|
63 |
+
|
64 |
+
class NeuralAgent: # crazyusermodel
|
65 |
+
def __init__(self, name: str, model_path: str, model_config_path: str):
|
66 |
+
"""User Simulator
|
67 |
+
Description
|
68 |
+
---------
|
69 |
+
A user model that is able to chat with the task-oriented dialogue system in an end-to-end manner
|
70 |
+
|
71 |
+
Parameters
|
72 |
+
----------
|
73 |
+
name
|
74 |
+
Should indicate the role played by the agent. It should be always user
|
75 |
+
"""
|
76 |
+
|
77 |
+
if name != "user":
|
78 |
+
raise ValueError(f"Expected name 'user' but got {name} instead.")
|
79 |
+
|
80 |
+
# load necessities
|
81 |
+
self.set_device()
|
82 |
+
self.config = OmegaConf.load(model_config_path)
|
83 |
+
|
84 |
+
self.print_intermediary_info = False
|
85 |
+
|
86 |
+
# get schema, which is dependent to dataset, only for providing task description here
|
87 |
+
self.service2meta, self.schema_intents, self.schema_slots = load_schema(self.config["schema_path"])
|
88 |
+
# self.load_checkpoint_and_tokenizer(self.config["model"]["path"])
|
89 |
+
self.load_checkpoint_and_tokenizer(model_path)
|
90 |
+
self.load_materials()
|
91 |
+
|
92 |
+
self.context = []
|
93 |
+
self.current_goal = {}
|
94 |
+
self.behaviour_params = {}
|
95 |
+
self.input_action = [] # type: list[list[str]]
|
96 |
+
self.output_action = [] # type: list[list[str]]
|
97 |
+
|
98 |
+
# for compatibility with convlab2 evaluator
|
99 |
+
self.policy = DummyPolicy()
|
100 |
+
|
101 |
+
""" for reproduction """
|
102 |
+
seed = 1130
|
103 |
+
random.seed(seed)
|
104 |
+
np.random.seed(seed)
|
105 |
+
torch.manual_seed(seed)
|
106 |
+
torch.cuda.manual_seed(seed)
|
107 |
+
torch.cuda.manual_seed_all(seed)
|
108 |
+
torch.backends.cudnn.deterministic = True
|
109 |
+
torch.backends.cudnn.enabled = False
|
110 |
+
torch.backends.cudnn.benchmark = False
|
111 |
+
|
112 |
+
def load_checkpoint_and_tokenizer(self, checkpoint_path: str) -> None:
|
113 |
+
"""Load model checkpoint with the model tokenizer, only for GPT2 for now"""
|
114 |
+
print("Load model, tokenizer from {}".format(checkpoint_path))
|
115 |
+
self.tokenizer = GPT2Tokenizer.from_pretrained(checkpoint_path)
|
116 |
+
self.model = GPT2LMHeadModel.from_pretrained(checkpoint_path)
|
117 |
+
self.model.to(self.device)
|
118 |
+
|
119 |
+
def set_device(self) -> None:
|
120 |
+
"""Set device to GPU/CPU"""
|
121 |
+
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
122 |
+
|
123 |
+
def load_materials(self):
|
124 |
+
"""Load useful materials used in generation"""
|
125 |
+
# model attributes
|
126 |
+
"""
|
127 |
+
finish_inform
|
128 |
+
how strict to finish an informable slot in goal: "strict" or "loose"
|
129 |
+
if "strict": the attribute are finished (removed from goal) only if both slot and value are produced in act
|
130 |
+
if "loose": the attribute are finished if the slot is produced
|
131 |
+
"""
|
132 |
+
self.finish_inform = (
|
133 |
+
self.config.model.goal_update.finish_inform
|
134 |
+
) # controls how strict to eliminate informed slots in goal
|
135 |
+
assert self.finish_inform in ["strict", "loose"]
|
136 |
+
|
137 |
+
# constants
|
138 |
+
self.bos_id, _, self.pad_id, self.sep_id = self.tokenizer.convert_tokens_to_ids(
|
139 |
+
["<BOS>", "<EOS>", "<PAD>", "<SEP>"]
|
140 |
+
)
|
141 |
+
self.bin_flags = {"true": "_True_", "false": "_False_"}
|
142 |
+
|
143 |
+
self.supported_services = ["train", "attraction", "hotel", "restaurant", "taxi", "police", "hospital"]
|
144 |
+
|
145 |
+
self.slot_types = {"search": "search", "book": "book"}
|
146 |
+
# important to change the corresponding act str name when using different tokenization methods,
|
147 |
+
# as they are used to control the user behaviours
|
148 |
+
self.const_act_str = {
|
149 |
+
"inform": "inform",
|
150 |
+
"recommend": "recommend",
|
151 |
+
"request": "request",
|
152 |
+
"fail_search": "no offer",
|
153 |
+
"fail_book": "no book",
|
154 |
+
}
|
155 |
+
|
156 |
+
def prepare_input_ids(self, data: dict, start_token: str) -> str:
|
157 |
+
assert start_token in ["<SYS_ACT/>", "<USR_ACT/>"]
|
158 |
+
input_seq = ""
|
159 |
+
for key in [
|
160 |
+
"CTX",
|
161 |
+
"SYS_UTT",
|
162 |
+
"SYS_ACT",
|
163 |
+
"SNT",
|
164 |
+
"RA",
|
165 |
+
"GC",
|
166 |
+
"GOAL",
|
167 |
+
]: # fixed order, consistent between training and inference
|
168 |
+
if key not in data:
|
169 |
+
continue
|
170 |
+
wrap = wrap_element(key, data[key])
|
171 |
+
input_seq = add_str(input_seq, wrap)
|
172 |
+
|
173 |
+
input_seq = add_str(input_seq, start_token)
|
174 |
+
if transformers.__version__.startswith("2."): # compatible with transformers v2.x used in convlab2
|
175 |
+
input_ids = self.tokenizer.encode(input_seq)
|
176 |
+
else:
|
177 |
+
input_ids = self.tokenizer(input_seq)["input_ids"] # convert to ids
|
178 |
+
input_ids = torch.tensor([input_ids]).long().to(self.device)
|
179 |
+
return input_ids
|
180 |
+
|
181 |
+
def update_internal_data(self, data: dict) -> None:
|
182 |
+
"""Maintain context and user act in the format of generation string for the next turn generation"""
|
183 |
+
# update context
|
184 |
+
sys_utt_wrap = wrap_element("SYS", data["SYS_UTT"]) # e.g., <SYS/> Which area would you prefer? </SYS>
|
185 |
+
usr_utt_wrap = wrap_element("USR", data["USR_UTT"]) # e.g., <USR/> I want to be in the centre. </USR>
|
186 |
+
self._context_str = add_str(self._context_str, sys_utt_wrap)
|
187 |
+
self._context_str = add_str(self._context_str, usr_utt_wrap)
|
188 |
+
|
189 |
+
# update prev usr act
|
190 |
+
self._prev_usr_act = data["USR_ACT"] # e.g., <ACT/> inform </ACT> <SLOT/> area </SLOT> <VALUE/> centre </VALUE>
|
191 |
+
|
192 |
+
def run_inference_once(self, input_ids: torch.tensor, eos_id: int) -> List:
|
193 |
+
if transformers.__version__.startswith("2."): # compatible with transformers v2.x used in convlab2
|
194 |
+
output = generation_func(self.model, input_ids, eos_id, self.config.decode.dec_max_len)
|
195 |
+
else:
|
196 |
+
output = self.model.generate(
|
197 |
+
input_ids,
|
198 |
+
max_length=self.config.decode.dec_max_len,
|
199 |
+
do_sample=self.config.decode.do_sample,
|
200 |
+
early_stopping=True,
|
201 |
+
temperature=self.config.decode.temperature,
|
202 |
+
use_cache=True,
|
203 |
+
num_beams=self.config.decode.num_beams,
|
204 |
+
bos_token_id=self.bos_id,
|
205 |
+
eos_token_id=eos_id,
|
206 |
+
pad_token_id=self.pad_id,
|
207 |
+
)
|
208 |
+
return output
|
209 |
+
|
210 |
+
def generate_whole_sequence(self, sys_utt: str) -> tuple:
|
211 |
+
# first forward pass: generate NLU output and three special flags #####
|
212 |
+
data = {"CTX": self._context_str, "SYS_UTT": sys_utt}
|
213 |
+
start_token, end_token = "<SYS_ACT/>", "</GC>"
|
214 |
+
input_ids = self.prepare_input_ids(data, start_token)
|
215 |
+
eos_id = self.tokenizer.convert_tokens_to_ids(end_token)
|
216 |
+
output = self.run_inference_once(input_ids, eos_id)
|
217 |
+
generation = self.tokenizer.decode(output[0]) # decode back to str, including the fed context
|
218 |
+
|
219 |
+
# parse first pass prediction
|
220 |
+
for key in ["SYS_ACT", "SNT", "GC", "RA"]:
|
221 |
+
value = find_segment(generation, key)
|
222 |
+
data[key] = value
|
223 |
+
|
224 |
+
# update dynamic goal
|
225 |
+
if self.print_intermediary_info:
|
226 |
+
print("SYS ACT ->", data["SYS_ACT"])
|
227 |
+
goal = self.prepare_turn_goal(self._prev_usr_act, data["SYS_ACT"], data["SNT"], data["GC"], data["RA"])
|
228 |
+
data["GOAL"] = goal
|
229 |
+
|
230 |
+
# second forward pass: generate dialogue act and NLG output #####
|
231 |
+
start_token, end_token = "<USR_ACT/>", "<EOS>"
|
232 |
+
input_ids = self.prepare_input_ids(data, start_token)
|
233 |
+
eos_id = self.tokenizer.convert_tokens_to_ids(end_token)
|
234 |
+
output = self.run_inference_once(input_ids, eos_id)
|
235 |
+
generation = self.tokenizer.decode(output[0]) # decode back to str, including the fed context
|
236 |
+
|
237 |
+
# parse second pass prediction
|
238 |
+
for key in ["USR_ACT", "USR_UTT"]:
|
239 |
+
value = find_segment(generation, key)
|
240 |
+
data[key] = value
|
241 |
+
return data, generation
|
242 |
+
|
243 |
+
def _format_complete_goal(self, input_goal: dict) -> dict:
|
244 |
+
"""Format the internal goal representation given a goal
|
245 |
+
|
246 |
+
:param input_goal: a goal that the user has in mind
|
247 |
+
either from the corpus or sampled randomly in a valid way (e.g., correct slot names)
|
248 |
+
:returns: complete_goal: an internal representation of the given goal, a dict with the keys "intents",
|
249 |
+
"constraints"
|
250 |
+
intents: list[str], list of intents in the dialogue, aka scenario
|
251 |
+
constraints: dict, intent as key, in the following format
|
252 |
+
dict(intent: intent_constraints)
|
253 |
+
intent_constraints: {"informable": dict(slot: value_list), "requestable": slot_set}
|
254 |
+
each slot has a value list in case of failure of searching
|
255 |
+
"""
|
256 |
+
# TODO: make the order of services more flexible (how does convlab2 decide the service order?)
|
257 |
+
constraints = dict()
|
258 |
+
intents = []
|
259 |
+
self.n_max_value = {
|
260 |
+
self.slot_types["book"]: 0,
|
261 |
+
self.slot_types["search"]: 0,
|
262 |
+
} # record the max length of value list of a slot
|
263 |
+
|
264 |
+
for service in input_goal["ordered_services"]:
|
265 |
+
if service not in self.supported_services:
|
266 |
+
continue
|
267 |
+
|
268 |
+
# record intent list (scenario), order matters
|
269 |
+
intent = self._map_service_to_intent(service)
|
270 |
+
assert intent not in intents and intent not in constraints
|
271 |
+
intents.append(intent)
|
272 |
+
constraints[intent] = {"informable": dict(), "requestable": set()}
|
273 |
+
|
274 |
+
# collect informable slots
|
275 |
+
assert "info" in input_goal[service] # info has to exist
|
276 |
+
for key in ["fail_info", "info", "fail_book", "book"]: # order matters
|
277 |
+
# assert key in input_goal[service]
|
278 |
+
if key not in input_goal[service]:
|
279 |
+
continue
|
280 |
+
for slot, value in input_goal[service][key].items():
|
281 |
+
self._add_info(constraints[intent]["informable"], slot, value)
|
282 |
+
|
283 |
+
# collect requestable slots
|
284 |
+
key = "reqt"
|
285 |
+
# assert key in input_goal[service]
|
286 |
+
# for slot in input_goal[service][key]:
|
287 |
+
if key in input_goal[service]:
|
288 |
+
for slot in input_goal[service][key].keys():
|
289 |
+
self._add_reqt(constraints[intent]["requestable"], slot)
|
290 |
+
|
291 |
+
# order intents by the order they are dealt with in the data so
|
292 |
+
# if using ground truth system responses the right order of the intents
|
293 |
+
# is preserved
|
294 |
+
|
295 |
+
complete_goal = {"intents": intents, "constraints": constraints}
|
296 |
+
return complete_goal
|
297 |
+
|
298 |
+
def _init_user_status(self) -> dict:
|
299 |
+
"""Initialise user status with intent and constraint
|
300 |
+
intent_idx: int, the index of current intent
|
301 |
+
constraint_idx: dict, intent as key, value is the constraint index used to record which value is used
|
302 |
+
in the slot value list
|
303 |
+
:return:
|
304 |
+
"""
|
305 |
+
intent_idx = 0 # -1
|
306 |
+
# constraint_idx = {intent: 0 for intent in self.complete_goal["intents"]}
|
307 |
+
constraint_idx = {
|
308 |
+
intent: {self.slot_types["search"]: 0, self.slot_types["book"]: 0}
|
309 |
+
for intent in self.complete_goal["intents"]
|
310 |
+
}
|
311 |
+
# TODO: entity provide records, one of the criteria to move to the next intents
|
312 |
+
entity_provided = {intent: False for intent in self.complete_goal["intents"]}
|
313 |
+
return {
|
314 |
+
"intent_idx": intent_idx,
|
315 |
+
"constraint_idx": constraint_idx,
|
316 |
+
"dialogue_terminate": False,
|
317 |
+
"entity_provided": entity_provided,
|
318 |
+
}
|
319 |
+
|
320 |
+
def _get_scenario_str(self) -> None:
|
321 |
+
"""Get a scenario str from a intent list
|
322 |
+
|
323 |
+
Description
|
324 |
+
convert a list of intents, aka scenario, into string with special marks
|
325 |
+
the scenario is determined at the start of dialogue and static during interaction
|
326 |
+
"""
|
327 |
+
intents = self.complete_goal["intents"]
|
328 |
+
_str = [wrap_element("INTENT", intent) for intent in intents]
|
329 |
+
_str = " ".join(_str)
|
330 |
+
self.scenario_str = wrap_element("SCENARIO", _str)
|
331 |
+
|
332 |
+
def _prepare_current_constraints(
|
333 |
+
self,
|
334 |
+
involved_intents: List[str],
|
335 |
+
involved_slot_types: List[str],
|
336 |
+
if_reset_reqt: bool,
|
337 |
+
) -> None:
|
338 |
+
"""Prepare the current constraints, copied the specified content from the complete goal
|
339 |
+
|
340 |
+
the current constraints is used as condition in the model generation
|
341 |
+
its content comes from the "constraints" in "complete goal",
|
342 |
+
but the current constraints only allows one value for a slot at a time
|
343 |
+
the value is chosen from the value list by the "constraint_idx" in user status
|
344 |
+
|
345 |
+
:param involved_intents: list[str], intent list
|
346 |
+
:return:
|
347 |
+
current_constraints: dict, similar format as constraints in the complete goal,
|
348 |
+
but a slot has only one value, e.g.,
|
349 |
+
dict(intent: intent_constraints)
|
350 |
+
intent_constraints: {"informable": dict(slot: value), "requestable": slot_set}
|
351 |
+
"""
|
352 |
+
# iterate the involved intents
|
353 |
+
for intent in involved_intents:
|
354 |
+
constraints = {"informable": dict(), "requestable": set()}
|
355 |
+
# informable slots value pairs
|
356 |
+
for slot, value_list in self.complete_goal["constraints"][intent]["informable"].items():
|
357 |
+
slot_type = self._get_slot_type(slot)
|
358 |
+
if slot_type not in involved_slot_types:
|
359 |
+
continue
|
360 |
+
value_idx = self.user_status["constraint_idx"][intent][slot_type]
|
361 |
+
if value_idx < len(value_list):
|
362 |
+
value = value_list[value_idx]
|
363 |
+
constraints["informable"][slot] = value
|
364 |
+
|
365 |
+
# requestable
|
366 |
+
if if_reset_reqt:
|
367 |
+
constraints["requestable"] = copy.deepcopy(self.complete_goal["constraints"][intent]["requestable"])
|
368 |
+
else:
|
369 |
+
constraints["requestable"] = copy.deepcopy(self.current_constraints[intent]["requestable"])
|
370 |
+
|
371 |
+
# overwrite intent constraints
|
372 |
+
self.current_constraints[intent] = constraints
|
373 |
+
|
374 |
+
@staticmethod
|
375 |
+
def _map_intent_to_service(intent: str) -> str:
|
376 |
+
# TODO: make it not dataset dependent?
|
377 |
+
"""map an intent into a service, multiwoz only"""
|
378 |
+
return intent.split()[1]
|
379 |
+
|
380 |
+
@staticmethod
|
381 |
+
def _map_service_to_intent(service: str) -> str:
|
382 |
+
# TODO: make it not dataset dependent?
|
383 |
+
"""map a service into an intent, multiwoz only"""
|
384 |
+
return f"find {service}"
|
385 |
+
|
386 |
+
def _get_slot_type(self, slot: str) -> str:
|
387 |
+
"""return search or book type of a slot"""
|
388 |
+
slot_type = "book" if "book" in slot else "search"
|
389 |
+
assert slot_type in self.slot_types.keys()
|
390 |
+
return slot_type
|
391 |
+
|
392 |
+
def _get_goal_str(self, intent: str) -> str:
|
393 |
+
"""prepare the proper goal sequence, same as used in training"""
|
394 |
+
goal_str = ""
|
395 |
+
# dialogue scenario
|
396 |
+
goal_str = add_str(goal_str, self.scenario_str)
|
397 |
+
|
398 |
+
# current task
|
399 |
+
goal_str = add_str(goal_str, wrap_element("TASK", intent))
|
400 |
+
|
401 |
+
# task description
|
402 |
+
service = self._map_intent_to_service(intent)
|
403 |
+
description = self.service2meta[service]["intents"][intent]["description"]
|
404 |
+
goal_str = add_str(goal_str, wrap_element("DESC", description))
|
405 |
+
|
406 |
+
# intent_constraints = self.dynamic_constraints[intent]
|
407 |
+
intent_constraints = self.current_constraints[intent]
|
408 |
+
# informable slots
|
409 |
+
info_str = ""
|
410 |
+
# for slot, value in intent_constraints["informable"].items():
|
411 |
+
for slot in sorted(intent_constraints["informable"].keys()): # sort by slot
|
412 |
+
value = intent_constraints["informable"][slot]
|
413 |
+
info_str = add_str(info_str, wrap_element("SLOT", slot))
|
414 |
+
info_str = add_str(info_str, wrap_element("VALUE", value))
|
415 |
+
goal_str = add_str(goal_str, wrap_element("INFORM", info_str))
|
416 |
+
|
417 |
+
# requestable slots
|
418 |
+
req_str = ""
|
419 |
+
for slot in sorted(list(intent_constraints["requestable"])):
|
420 |
+
req_str = add_str(req_str, wrap_element("SLOT", slot))
|
421 |
+
goal_str = add_str(goal_str, wrap_element("REQUEST", req_str))
|
422 |
+
return goal_str.strip()
|
423 |
+
|
424 |
+
def _start_new_intent(self, SNT_flag: str) -> bool:
|
425 |
+
"""decide whether to start a new intent"""
|
426 |
+
# SNT (start new task) is predicted as on
|
427 |
+
assert SNT_flag in list(self.bin_flags.values())
|
428 |
+
# intent = self.intents[self.intent_idx]
|
429 |
+
intent = self.complete_goal["intents"][self.user_status["intent_idx"]]
|
430 |
+
|
431 |
+
# TODO: need at least an entity provided (not really sure...
|
432 |
+
# if not self.intent_entity_provided[intent]: # no entities provided in the intent yet
|
433 |
+
# return False
|
434 |
+
|
435 |
+
# TODO: think about the priority of SNT prediction. It's should be less prioritised than
|
436 |
+
# the number of left constraints.
|
437 |
+
# if SNT_flag == self.bin_flags["true"]: # model prediction in first turn is true
|
438 |
+
# return True
|
439 |
+
|
440 |
+
# current intent has empty constraints
|
441 |
+
if (
|
442 |
+
len(self.current_constraints[intent]["informable"]) == 0
|
443 |
+
and len(self.current_constraints[intent]["requestable"]) == 0
|
444 |
+
):
|
445 |
+
return True
|
446 |
+
return False
|
447 |
+
|
448 |
+
def _check_entity_provided(self, sys_act, intent):
|
449 |
+
# TODO:
|
450 |
+
"""Check if an entity provided in system response (act)"""
|
451 |
+
assert intent in [
|
452 |
+
"find restaurant",
|
453 |
+
"find hotel",
|
454 |
+
"find attraction",
|
455 |
+
"find train",
|
456 |
+
"find taxi",
|
457 |
+
"find police",
|
458 |
+
"find hospital",
|
459 |
+
]
|
460 |
+
if intent in ["find restaurant", "find hotel", "find attraction"]:
|
461 |
+
if "<SLOT/> name </SLOT>" in sys_act:
|
462 |
+
self.intent_entity_provided[intent] = True
|
463 |
+
elif intent == "find train":
|
464 |
+
if "<SLOT/> train id </SLOT>" in sys_act:
|
465 |
+
self.intent_entity_provided[intent] = True
|
466 |
+
else: # taxi
|
467 |
+
if "<SLOT/> type </SLOT>" in sys_act:
|
468 |
+
self.intent_entity_provided[intent] = True
|
469 |
+
|
470 |
+
def _activate_dialogue_terminate(self) -> None:
|
471 |
+
"""Turn on the user status about dialogue termination"""
|
472 |
+
self.user_status["dialogue_terminate"] = True
|
473 |
+
|
474 |
+
def prepare_turn_goal(self, prev_usr_act: str, sys_act: str, SNT_flag: str, GC_flag: str, RA_flag: str) -> str:
|
475 |
+
"""prepare the goal sequence for the current turn"""
|
476 |
+
# TODO: more detailed instruction here
|
477 |
+
# TODO: Deal with empty intents (and figure out why they happen)
|
478 |
+
intent = self.complete_goal["intents"][self.user_status["intent_idx"]]
|
479 |
+
|
480 |
+
# TODO: check if at least one entity is provided in system act
|
481 |
+
# First thing to do, check if the system provides an entity
|
482 |
+
# self._check_entity_provided(sys_act, intent)
|
483 |
+
|
484 |
+
# update goal first then check if moves to next intent (task)
|
485 |
+
self._update_current_constraints(intent, "usr", prev_usr_act, sys_act)
|
486 |
+
self._update_current_constraints(
|
487 |
+
intent, "sys", prev_usr_act, sys_act
|
488 |
+
) # impact of sys_act overwrites that of usr_act
|
489 |
+
|
490 |
+
# check if new intent starts
|
491 |
+
if self._start_new_intent(SNT_flag):
|
492 |
+
self.user_status["intent_idx"] += 1
|
493 |
+
if self.user_status["intent_idx"] < len(self.complete_goal["intents"]):
|
494 |
+
intent = self.complete_goal["intents"][self.user_status["intent_idx"]]
|
495 |
+
else:
|
496 |
+
self._activate_dialogue_terminate()
|
497 |
+
# TODO: request alternative by setting <RA> for sgd
|
498 |
+
# TODO: sample new goal if goal change <GC> for sgd
|
499 |
+
goal_str = self._get_goal_str(intent)
|
500 |
+
|
501 |
+
# print("***** user status *****\n->", self.user_status, "\n")
|
502 |
+
# print("***** current intent *****\n->", intent, "\n") # BACK
|
503 |
+
# print("***** current intent constraint *****\n->", self.current_constraints[intent], "\n")
|
504 |
+
# print("***** corresponding goal str *****\n->", goal_str, "\n")
|
505 |
+
# print("***** current entities provided (empty) *****\n->", self.intent_entity_provided, "\n")
|
506 |
+
return goal_str
|
507 |
+
|
508 |
+
def _use_next_constraints(self, intent: str, slot_type: str) -> None:
|
509 |
+
"""move the constraint pointer to the next"""
|
510 |
+
# Another problem is that how to decide which slot type (search or book) to add when failure?
|
511 |
+
# one solution is that dont use act mapping to keep NoOffer and NoBook separate, if so, try use nl on act
|
512 |
+
self.user_status["constraint_idx"][intent][slot_type] += 1
|
513 |
+
if self.user_status["constraint_idx"][intent][slot_type] >= self.n_max_value[slot_type]:
|
514 |
+
# TODO: ask Alex, usually how to deal with this warning case? And make it as warning rather than just print
|
515 |
+
print(
|
516 |
+
f"Failure times on {slot_type} is more than the given value candidates, \
|
517 |
+
no new value to choose as alternative"
|
518 |
+
)
|
519 |
+
print("A valid goal should not enter here!")
|
520 |
+
self.user_status["constraint_idx"][intent][slot_type] = (
|
521 |
+
self.n_max_value[slot_type] - 1
|
522 |
+
) # let user use last values as they are supposed to be fine
|
523 |
+
|
524 |
+
def _update_current_constraints(self, intent: str, spk: str, usr_act: str, sys_act: str) -> None:
|
525 |
+
# TODO: complete instruction here
|
526 |
+
"""Update current constraints used for generation based on either previous usr or sys act
|
527 |
+
|
528 |
+
:param act:
|
529 |
+
:param spk:
|
530 |
+
:param intent:
|
531 |
+
:return:
|
532 |
+
"""
|
533 |
+
assert spk in ["usr", "sys"]
|
534 |
+
# act_dict = parse_act(act)
|
535 |
+
intent_constraints = self.current_constraints[intent]
|
536 |
+
|
537 |
+
if spk == "sys":
|
538 |
+
act_dict = self.parse_act(sys_act, self.print_intermediary_info)
|
539 |
+
# When the system provides information (in the act_dict) then remove it from the user's
|
540 |
+
# requestable constraints as the user has been provided with the info!
|
541 |
+
# NB: This was added by Alistair after the original code was shared by Andy,
|
542 |
+
# as it seems the original implementation missed this critical step.
|
543 |
+
if self.const_act_str["inform"] in act_dict:
|
544 |
+
for slot in act_dict[self.const_act_str["inform"]]:
|
545 |
+
if slot in intent_constraints["requestable"]:
|
546 |
+
intent_constraints["requestable"].remove(slot)
|
547 |
+
elif self.const_act_str["recommend"] in act_dict:
|
548 |
+
for slot in act_dict[self.const_act_str["recommend"]]:
|
549 |
+
if slot in intent_constraints["requestable"]:
|
550 |
+
intent_constraints["requestable"].remove(slot)
|
551 |
+
|
552 |
+
# when the system informs failure (search or book), use next set of constraints given in goal #####
|
553 |
+
# if "_NOTIFY_FAILURE_" in act_dict:
|
554 |
+
if self.const_act_str["fail_search"] in act_dict:
|
555 |
+
slot_type = self.slot_types["search"]
|
556 |
+
self._use_next_constraints(intent, slot_type)
|
557 |
+
keep_slot_types = [
|
558 |
+
self.slot_types["search"],
|
559 |
+
self.slot_types["book"],
|
560 |
+
] # still in search phase, book slots should be kept
|
561 |
+
self._prepare_current_constraints(
|
562 |
+
[intent], keep_slot_types, if_reset_reqt=False
|
563 |
+
) # only change constraints for this intent
|
564 |
+
|
565 |
+
elif self.const_act_str["fail_book"] in act_dict:
|
566 |
+
slot_type = self.slot_types["book"]
|
567 |
+
self._use_next_constraints(intent, slot_type)
|
568 |
+
keep_slot_types = [self.slot_types["book"]] # already found entities, no need to keep search slots
|
569 |
+
self._prepare_current_constraints([intent], keep_slot_types, if_reset_reqt=False)
|
570 |
+
|
571 |
+
# when the system request #
|
572 |
+
elif self.const_act_str["request"] in act_dict:
|
573 |
+
requested_slots = act_dict[self.const_act_str["request"]]
|
574 |
+
for slot in requested_slots.keys():
|
575 |
+
# requested slot in current constraint, do nothing
|
576 |
+
if slot in intent_constraints["informable"].keys():
|
577 |
+
continue
|
578 |
+
|
579 |
+
# slots that are beyond the current goal enter the following section
|
580 |
+
# case 1: requested slot in the complete goal,
|
581 |
+
# this should be entered if the system requests the informed slots
|
582 |
+
# if slot in self.complete_constraints[intent]["informable"].keys():
|
583 |
+
# value = self.complete_constraints[intent]["informable"][slot]
|
584 |
+
if (
|
585 |
+
slot in self.complete_goal["constraints"][intent]["informable"].keys()
|
586 |
+
): # dict of slot to value_list
|
587 |
+
slot_type = self._get_slot_type(slot)
|
588 |
+
value_idx = self.user_status["constraint_idx"][intent][slot_type]
|
589 |
+
value = self.complete_goal["constraints"][intent]["informable"][slot][value_idx]
|
590 |
+
|
591 |
+
# case 2: requested slot not in the complete goal, set the value to "dontcare"
|
592 |
+
# can sample a new value here for more interesting interactions
|
593 |
+
else:
|
594 |
+
value = "dontcare" # "no preference" # TODO: play around to see nlg output
|
595 |
+
intent_constraints["informable"][slot] = value
|
596 |
+
|
597 |
+
else: # usr
|
598 |
+
act_dict = self.parse_act(usr_act, self.print_intermediary_info)
|
599 |
+
# remove informed slot/value pair, if informed #
|
600 |
+
if self.const_act_str["inform"] in act_dict:
|
601 |
+
for slot, value_list in act_dict[self.const_act_str["inform"]].items():
|
602 |
+
# value = value_list[0]
|
603 |
+
for value in value_list: # possible to have multi-value slots in user act in corpus
|
604 |
+
if self.finish_inform == "loose" and slot in intent_constraints["informable"]:
|
605 |
+
del intent_constraints["informable"][slot]
|
606 |
+
if (
|
607 |
+
self.finish_inform == "strict"
|
608 |
+
and slot in intent_constraints["informable"]
|
609 |
+
and value == intent_constraints["informable"][slot]
|
610 |
+
):
|
611 |
+
del intent_constraints["informable"][slot]
|
612 |
+
|
613 |
+
# remove requested slot, if requested
|
614 |
+
if self.const_act_str["request"] in act_dict:
|
615 |
+
sys_act_dict = self.parse_act(sys_act, self.print_intermediary_info) # auxiliary check
|
616 |
+
for slot in act_dict[self.const_act_str["request"]].keys():
|
617 |
+
# if slot in intent_constraints["requestable"]: # one choice
|
618 |
+
if self.const_act_str["inform"] in sys_act_dict:
|
619 |
+
if (
|
620 |
+
slot in intent_constraints["requestable"]
|
621 |
+
and slot in sys_act_dict[self.const_act_str["inform"]].keys()
|
622 |
+
): # another choice, more strict
|
623 |
+
intent_constraints["requestable"].remove(slot)
|
624 |
+
|
625 |
+
def _add_info(self, slot_to_value_list, slot, value) -> None:
|
626 |
+
# print(slot)
|
627 |
+
# assert slot in self.schema_slots # SLOT_FORMAT
|
628 |
+
# constraints[intent]["informable"][slot] = value
|
629 |
+
if slot not in slot_to_value_list:
|
630 |
+
slot_to_value_list[slot] = []
|
631 |
+
# assert value not in slot_to_value_list[slot]
|
632 |
+
if value not in slot_to_value_list[slot]:
|
633 |
+
slot_to_value_list[slot].append(value)
|
634 |
+
slot_type = self._get_slot_type(slot)
|
635 |
+
if len(slot_to_value_list) > self.n_max_value[slot_type]:
|
636 |
+
self.n_max_value[slot_type] = len(slot_to_value_list)
|
637 |
+
|
638 |
+
def _add_reqt(self, slot_set, slot) -> None:
|
639 |
+
# assert slot in self.schema_slots # SLOT_FORMAT
|
640 |
+
# constraints[intent]["requestable"].add(slot)
|
641 |
+
slot_set.add(slot)
|
642 |
+
|
643 |
+
def _validate_input_goal(self):
|
644 |
+
"""validate the input goal"""
|
645 |
+
# TODO: finish the method
|
646 |
+
# assert all([intent in self.schema_intents for intent in intents]) # ensure intents are in schema
|
647 |
+
pass
|
648 |
+
|
649 |
+
@staticmethod
|
650 |
+
def parse_act(act_seq: str, print_intermediary_info: bool) -> dict:
|
651 |
+
"""parse usr/sys act string into dict(act: {slot=value_list}) (slots in act_request have '_Empty_' value)"""
|
652 |
+
act_dict = {}
|
653 |
+
assert isinstance(act_seq, str)
|
654 |
+
act_seq = act_seq.split("<ACT/>")
|
655 |
+
for act_seg in act_seq:
|
656 |
+
if act_seg == "":
|
657 |
+
continue
|
658 |
+
|
659 |
+
act_seg = act_seg.strip() # remove space at the start/end
|
660 |
+
act_seg = act_seg.split()
|
661 |
+
# get act in special token format #
|
662 |
+
# act = act_seg[0] # e.g., _INFORM_, _REQUEST_
|
663 |
+
# assert act[0] == "_" and act[-1] == "_"
|
664 |
+
# act_seg = " ".join(act_seg[2:]) # discard first two tokens, "_ACT_ </ACT>"
|
665 |
+
|
666 |
+
# get act in natural language format #
|
667 |
+
end_idx = act_seg.index("</ACT>")
|
668 |
+
act = " ".join(act_seg[:end_idx])
|
669 |
+
act_seg = " ".join(act_seg[end_idx + 1 :]) # act arguments (slot/value pairs)
|
670 |
+
# print(f"act: {act}\n", act_seg, "\n")
|
671 |
+
|
672 |
+
assert act not in act_dict
|
673 |
+
act_dict[act] = {}
|
674 |
+
|
675 |
+
# Sometimes the model bugs out and puts <ACT/> or </ACT> where there should be </VALUE> or <VALUE/>
|
676 |
+
if "ACT" in act_seg:
|
677 |
+
continue
|
678 |
+
|
679 |
+
for sv_seg in act_seg.split("</VALUE>"):
|
680 |
+
if sv_seg == "":
|
681 |
+
continue
|
682 |
+
|
683 |
+
try:
|
684 |
+
sv_seg = sv_seg.replace("<SLOT/>", "")
|
685 |
+
sv_seg = sv_seg.strip() # remove spaces at begin and end
|
686 |
+
# print("|{}|".format(sv_seg))
|
687 |
+
slot, value = sv_seg.split("</SLOT> <VALUE/>")
|
688 |
+
slot, value = slot.strip(), value.strip()
|
689 |
+
# print("act: |{}|, slot: |{}|, value: |{}|".format(act, slot, value))
|
690 |
+
# one slot one value
|
691 |
+
# act_dict[act][slot] = value
|
692 |
+
# one slot, multi-value is possible by system
|
693 |
+
if slot not in act_dict[act]:
|
694 |
+
act_dict[act][slot] = []
|
695 |
+
if value not in act_dict[act][slot]:
|
696 |
+
act_dict[act][slot].append(value)
|
697 |
+
|
698 |
+
except Exception:
|
699 |
+
if print_intermediary_info:
|
700 |
+
print(
|
701 |
+
bcolors.YELLOW
|
702 |
+
+ "!The User Agent got messed up the intermediate syntax! Exception:"
|
703 |
+
+ bcolors.ENDC
|
704 |
+
)
|
705 |
+
traceback.print_exc()
|
706 |
+
continue
|
707 |
+
|
708 |
+
# print(act_dict)
|
709 |
+
return act_dict
|
710 |
+
|
711 |
+
def convert_into_system_act_format(self):
|
712 |
+
# TODO
|
713 |
+
pass
|
714 |
+
|
715 |
+
# below methods need be implemented for convlab-2 to work #
|
716 |
+
def init_session(self, **kwargs):
|
717 |
+
"""Use this method to reset the agent state after each dialogue, if necessary.
|
718 |
+
This gets called before each dialogue.
|
719 |
+
|
720 |
+
Examples
|
721 |
+
--------
|
722 |
+
In `simulate_corpus_interaction.py` you will see that this is used, for example, to pass
|
723 |
+
the dialogue to the corpus agent so it knows what to talk about.
|
724 |
+
|
725 |
+
An example here would be to reset the dialogue context.
|
726 |
+
"""
|
727 |
+
# dialogue goal in MultiWOZ2.1-like format
|
728 |
+
self.current_goal = kwargs.get("ini_goal", {})
|
729 |
+
self.policy.init_session(ini_goal=self.current_goal)
|
730 |
+
self.current_goal = self.policy.get_goal()
|
731 |
+
# TODO: ANYTHING ELSE THAT NEEDS TO HAPPEN BEFORE EACH DIALOGUE?
|
732 |
+
self.context = []
|
733 |
+
self.input_action = []
|
734 |
+
self.output_action = []
|
735 |
+
|
736 |
+
# init internal data
|
737 |
+
self._context_str = "" # context string with special tags used in generation
|
738 |
+
self._prev_usr_act = "" # user act string used in generation
|
739 |
+
|
740 |
+
# goal process
|
741 |
+
self.complete_goal = self._format_complete_goal(self.current_goal)
|
742 |
+
self.user_status = self._init_user_status()
|
743 |
+
self._get_scenario_str()
|
744 |
+
self.current_constraints = {} # init
|
745 |
+
self._prepare_current_constraints(
|
746 |
+
self.complete_goal["intents"],
|
747 |
+
list(self.slot_types.keys()),
|
748 |
+
if_reset_reqt=True,
|
749 |
+
)
|
750 |
+
|
751 |
+
# print("input goal:\n", self.current_goal, "\n")
|
752 |
+
# print("complete goal:\n", self.complete_goal, "\n")
|
753 |
+
# print("current constraints:\n", self.current_constraints, "\n")
|
754 |
+
# sys.exit(1)
|
755 |
+
|
756 |
+
def response(self, sys_utterance: str) -> str:
|
757 |
+
"""Generate natural language response given the system response.
|
758 |
+
|
759 |
+
Parameters
|
760 |
+
---------
|
761 |
+
sys_utterance
|
762 |
+
Last system utterance. For first turn, sys_utterance is the empty string.
|
763 |
+
|
764 |
+
Returns
|
765 |
+
-------
|
766 |
+
response
|
767 |
+
A natural language response.
|
768 |
+
|
769 |
+
"""
|
770 |
+
|
771 |
+
# TODO: MAKE APPROPRIATE USE OF THE HISTORY, BEHAVIOUR_PARAMS, CURRENT_GOAL, UPDATE_GOAL TO GENERATE A RESPONSE
|
772 |
+
# TODO: DON'T FORGET TO UPDATE INPUT AND OUTPUT ACTIONS STATES
|
773 |
+
# response = "I want Italian."
|
774 |
+
gen_parse, gen_str = self.generate_whole_sequence(sys_utterance)
|
775 |
+
self.update_internal_data(gen_parse) # prepare for next turn
|
776 |
+
if self.print_intermediary_info:
|
777 |
+
segment_gen(gen_str, "example dialogue") # crazyusermodel
|
778 |
+
# TODO: update lists of context, da_in, da_out here
|
779 |
+
return gen_parse["USR_UTT"]
|
780 |
+
|
781 |
+
def get_in_da(self) -> List[List[str]]:
|
782 |
+
"""Used by clients to retrieve the user model NLU.
|
783 |
+
|
784 |
+
Returns
|
785 |
+
-------
|
786 |
+
NLU output, assumed to be a list of lists, each formatted as::
|
787 |
+
|
788 |
+
[[intention, domain, slot, value], ...]
|
789 |
+
|
790 |
+
Here ``intention`` refers to a dialogue act and the ``intention``, ``domain`` and ``slot`` strings should
|
791 |
+
follow the same convention as the corpus dialogue act annotations (i.e., capitalised, and using the correct
|
792 |
+
set of slot names).
|
793 |
+
"""
|
794 |
+
return self.input_action
|
795 |
+
|
796 |
+
def get_out_da(self) -> List[List[str]]:
|
797 |
+
"""Used by clients to retrieve the user model policy output.
|
798 |
+
|
799 |
+
Returns
|
800 |
+
-------
|
801 |
+
Policy output, following the same convention as the NLU output.
|
802 |
+
"""
|
803 |
+
return self.output_action
|
804 |
+
|
805 |
+
def get_reward(self) -> float:
|
806 |
+
"""Dummy method, used for API consistency."""
|
807 |
+
return -1
|
808 |
+
|
809 |
+
def is_terminated(self) -> bool:
|
810 |
+
"""This should tell an external client whether the user model considers they have completed the task."""
|
811 |
+
# return False
|
812 |
+
return self.user_status["dialogue_terminate"]
|
813 |
+
|
814 |
+
|
815 |
+
def parse_complete_gen(gen):
|
816 |
+
"""parse the complete generation output, return predictions of system act, user act and user utterance"""
|
817 |
+
output = {}
|
818 |
+
for key in ["SYS_ACT", "SNT", "GC", "RA", "USR_ACT", "USR_UTT"]:
|
819 |
+
value = find_segment(gen, key)
|
820 |
+
output[key] = value
|
821 |
+
# print("***** complete generation output *****\n->", gen, "\n") # BACK
|
822 |
+
# print("***** parse output *****\n->", output, "\n")
|
823 |
+
return output
|
824 |
+
|
825 |
+
|
826 |
+
def generate_example_goal() -> dict:
|
827 |
+
"""create an example goal for testing"""
|
828 |
+
# {service: service_meta},
|
829 |
+
# service_mate: {"info": {slot: value}, "fail_info": {slot: value},
|
830 |
+
# "book": {slot}: value, "fail_book": {slot: value}, "reqt": set(slot)}
|
831 |
+
goal = {}
|
832 |
+
services = ["restaurant", "hotel"]
|
833 |
+
# services = ["train", "attraction"]
|
834 |
+
# services = ["restaurant"]
|
835 |
+
|
836 |
+
# # restaurant
|
837 |
+
service = services[0]
|
838 |
+
goal[service] = {}
|
839 |
+
goal[service]["fail_info"] = {"food": "eastern european", "area": "south", "price range": "expensive"}
|
840 |
+
goal[service]["info"] = {"food": "chinese", "area": "south", "price range": "cheap"}
|
841 |
+
goal[service]["fail_book"] = {}
|
842 |
+
goal[service]["book"] = {"book day": "monday", "book people": "8", "book time": "13:15"}
|
843 |
+
goal[service]["reqt"] = {"address": "?"}
|
844 |
+
|
845 |
+
# hotel
|
846 |
+
service = services[1]
|
847 |
+
goal[service] = {}
|
848 |
+
goal[service]["fail_info"] = {"stars": "3", "price range": "cheap", "area": "centre", "internet": "_True_"}
|
849 |
+
goal[service]["info"] = {"stars": "5", "price range": "expensive", "area": "centre", "internet": "_True_"}
|
850 |
+
goal[service]["fail_book"] = {"book day": "sunday", "book stay": 3, "book people": 2}
|
851 |
+
goal[service]["book"] = {"book day": "monday", "book stay": 1, "book people": 2}
|
852 |
+
goal[service]["reqt"] = {"phone": "?", "postcode": "?"}
|
853 |
+
|
854 |
+
# # train
|
855 |
+
# service = services[1]
|
856 |
+
# goal[service] = {}
|
857 |
+
# goal[service]["info"] = {
|
858 |
+
# "destination": "ely",
|
859 |
+
# "day": "monday",
|
860 |
+
# "arrive by": "19:00",
|
861 |
+
# "departure": "cambridge",
|
862 |
+
# "book people": "8"
|
863 |
+
# }
|
864 |
+
# goal[service]["reqt"] = {"duration": "?", "leave at": "?", "train id": "?"}
|
865 |
+
|
866 |
+
# # attraction
|
867 |
+
# service = services[1]
|
868 |
+
# goal[service] = {}
|
869 |
+
# goal[service]["info"] = {
|
870 |
+
# "type": "college",
|
871 |
+
# "area": "west"
|
872 |
+
# }
|
873 |
+
# goal[service]["reqt"] = {"phone": "?", "postcode": "?"}
|
874 |
+
|
875 |
+
# taxi
|
876 |
+
# service = services[0]
|
877 |
+
# goal[service] = {}
|
878 |
+
# goal[service]["info"] = {
|
879 |
+
# "arrive by": "17:30",
|
880 |
+
# "departure": "city stop restaurant",
|
881 |
+
# "destination": "the cambridge punter"
|
882 |
+
# }
|
883 |
+
# goal[service]["reqt"] = {"phone": "?", "type": "?"}
|
884 |
+
# more services...
|
885 |
+
return goal
|
886 |
+
|
887 |
+
|
888 |
+
def set_sorted_services_for_current_goal(goal, goal_idx, df_raw_mwoz):
|
889 |
+
# Get the list of services in the goal as they appear in the data so they can be processed correctly
|
890 |
+
|
891 |
+
current_dialogue_services = []
|
892 |
+
for service_name in goal:
|
893 |
+
current_dialogue_services.append(service_name)
|
894 |
+
|
895 |
+
message = df_raw_mwoz.iloc[:, goal_idx].goal["message"]
|
896 |
+
|
897 |
+
ordered_current_dialogue_services = []
|
898 |
+
|
899 |
+
for instruction in message:
|
900 |
+
instruction_split = re.split(" |<|>", instruction)
|
901 |
+
for word in instruction_split:
|
902 |
+
if word in current_dialogue_services:
|
903 |
+
ordered_current_dialogue_services.append(word)
|
904 |
+
current_dialogue_services.remove(word)
|
905 |
+
|
906 |
+
# Make sure any words not mentioned in the message (e.g. it happens for 'police' in the second goal) are n#ot missed
|
907 |
+
for word in current_dialogue_services:
|
908 |
+
if word not in ordered_current_dialogue_services:
|
909 |
+
ordered_current_dialogue_services.append(word)
|
910 |
+
|
911 |
+
return ordered_current_dialogue_services
|
912 |
+
|
913 |
+
|
914 |
+
def read_multiWOZ_20_goals(file_path, n_goals):
|
915 |
+
df_raw_mwoz = pd.read_json(file_path)
|
916 |
+
|
917 |
+
goals = []
|
918 |
+
for i in range(n_goals):
|
919 |
+
parsed_goal = {}
|
920 |
+
goal = df_raw_mwoz.iloc[:, i].goal
|
921 |
+
|
922 |
+
# Determine relevant keys
|
923 |
+
for _ in goal.keys():
|
924 |
+
relevant_goals = {k: v for k, v in goal.items() if v != {} and k != "topic" and k != "message"}
|
925 |
+
services = [key for key in relevant_goals.keys()]
|
926 |
+
for service in services:
|
927 |
+
parsed_goal[service] = relevant_goals[service]
|
928 |
+
|
929 |
+
ordered_services = set_sorted_services_for_current_goal(parsed_goal, i, df_raw_mwoz)
|
930 |
+
parsed_goal["ordered_services"] = ordered_services
|
931 |
+
|
932 |
+
# Update the format of those relevant keys to match the format of this code
|
933 |
+
for service in parsed_goal.keys():
|
934 |
+
if service == "ordered_services":
|
935 |
+
continue
|
936 |
+
|
937 |
+
for service_key, service_value in parsed_goal[service].items():
|
938 |
+
|
939 |
+
# Handle 'reqt' key which is a list. Convert it to a dict. (and do the same for similar keys).
|
940 |
+
if type(parsed_goal[service][service_key]) is list and parsed_goal[service][service_key] != []:
|
941 |
+
replacement_dict = {}
|
942 |
+
for item in parsed_goal[service][service_key]:
|
943 |
+
replacement_dict[item] = "?"
|
944 |
+
parsed_goal[service][service_key] = replacement_dict
|
945 |
+
|
946 |
+
# Handle 'hotel' key which has a string value
|
947 |
+
# with the name of hotel - or other similar situations
|
948 |
+
elif type(parsed_goal[service][service_key]) is str:
|
949 |
+
continue
|
950 |
+
|
951 |
+
# Make sure the dictionary we are adding is not empty
|
952 |
+
if not parsed_goal[service][service_key]:
|
953 |
+
continue
|
954 |
+
|
955 |
+
# Remove any attributes that are "invalid" or "preinvalid"
|
956 |
+
# Also check if 'arriveBy' or 'leaveAt' or 'pricerange' is inside the attributes of the service_key
|
957 |
+
# If so reformat it according to the code in this file
|
958 |
+
|
959 |
+
list_of_attribute_keys = [k for k in parsed_goal[service][service_key].keys()]
|
960 |
+
for k in list_of_attribute_keys:
|
961 |
+
if k == "invalid" or k == "pre_invalid":
|
962 |
+
parsed_goal[service][service_key].pop(k)
|
963 |
+
if k == "arriveBy":
|
964 |
+
parsed_goal[service][service_key]["arrive by"] = parsed_goal[service][service_key].pop(k)
|
965 |
+
elif k == "leaveAt":
|
966 |
+
parsed_goal[service][service_key]["leave at"] = parsed_goal[service][service_key].pop(k)
|
967 |
+
elif k == "pricerange":
|
968 |
+
parsed_goal[service][service_key]["price range"] = parsed_goal[service][service_key].pop(k)
|
969 |
+
elif k == "car type":
|
970 |
+
parsed_goal[service][service_key]["type"] = parsed_goal[service][service_key].pop(k)
|
971 |
+
elif k == "trainID":
|
972 |
+
parsed_goal[service][service_key]["train id"] = parsed_goal[service][service_key].pop(k)
|
973 |
+
|
974 |
+
# Check if "book" is in the service info dict ("book" or "fail_book") then prepend
|
975 |
+
# 'book' to the keys inside the service (the attributes of the service)
|
976 |
+
if "book" in service_key:
|
977 |
+
list_of_attribute_keys = [k for k in parsed_goal[service][service_key].keys()]
|
978 |
+
for k in list_of_attribute_keys:
|
979 |
+
parsed_goal[service][service_key]["book {}".format(k)] = parsed_goal[service][service_key].pop(
|
980 |
+
k
|
981 |
+
)
|
982 |
+
|
983 |
+
# If True or False is in service.values, convert to "_True_" or "_False_"
|
984 |
+
for k, v in parsed_goal[service][service_key].items():
|
985 |
+
if v is True:
|
986 |
+
parsed_goal[service][service_key][k] = "_True_"
|
987 |
+
elif v is False:
|
988 |
+
parsed_goal[service][service_key][k] = "_False_"
|
989 |
+
|
990 |
+
goals.append(parsed_goal)
|
991 |
+
|
992 |
+
return goals
|
993 |
+
|
994 |
+
|
995 |
+
def interact(checkpoint_path):
|
996 |
+
user_model = NeuralAgent("user", checkpoint_path, "scripts/user_model_code/interaction/config.yaml")
|
997 |
+
|
998 |
+
# TODO: fix the hardcoded variables here
|
999 |
+
file_path = "data/raw/UBAR/multi-woz/data.json"
|
1000 |
+
user_model.print_intermediary_info = True
|
1001 |
+
n_goals = 50
|
1002 |
+
|
1003 |
+
for dialogue_number, goal in enumerate(read_multiWOZ_20_goals(file_path, n_goals)):
|
1004 |
+
try:
|
1005 |
+
# goal = generate_example_goal()
|
1006 |
+
user_model.init_session(ini_goal=goal)
|
1007 |
+
sys_utt = ""
|
1008 |
+
|
1009 |
+
for turn_id in range(100):
|
1010 |
+
user_model.response(sys_utt)
|
1011 |
+
|
1012 |
+
if user_model.is_terminated():
|
1013 |
+
print("Dialogue terminates!")
|
1014 |
+
break
|
1015 |
+
|
1016 |
+
# next turn materials
|
1017 |
+
sys_utt = input("Enter system response here: ")
|
1018 |
+
if sys_utt == "Goodbye":
|
1019 |
+
break
|
1020 |
+
|
1021 |
+
except Exception:
|
1022 |
+
print("Error in dialogue {}".format(dialogue_number))
|
1023 |
+
traceback.print_exc()
|
1024 |
+
continue
|
1025 |
+
|
1026 |
+
|
1027 |
+
if __name__ == "__main__":
|
1028 |
+
if len(sys.argv) == 1:
|
1029 |
+
print("Wrong argument!")
|
1030 |
+
print("Usage: python multiwoz_interact.py checkpoint_path")
|
1031 |
+
sys.exit(1)
|
1032 |
+
|
1033 |
+
checkpoint_path = sys.argv[1]
|
1034 |
+
interact(checkpoint_path)
|
scripts/user_model_code/interaction/schema.json
ADDED
@@ -0,0 +1,712 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"service_name": "hotel",
|
4 |
+
"slots": [
|
5 |
+
{
|
6 |
+
"name": "hotel-pricerange",
|
7 |
+
"description": "price budget of the hotel",
|
8 |
+
"possible_values": [
|
9 |
+
"expensive",
|
10 |
+
"cheap",
|
11 |
+
"moderate"
|
12 |
+
],
|
13 |
+
"is_categorical": true
|
14 |
+
},
|
15 |
+
{
|
16 |
+
"name": "hotel-type",
|
17 |
+
"description": "what is the type of the hotel",
|
18 |
+
"possible_values": [
|
19 |
+
"guesthouse",
|
20 |
+
"hotel"
|
21 |
+
],
|
22 |
+
"is_categorical": true
|
23 |
+
},
|
24 |
+
{
|
25 |
+
"name": "hotel-parking",
|
26 |
+
"description": "whether the hotel has parking",
|
27 |
+
"possible_values": [
|
28 |
+
"free",
|
29 |
+
"no",
|
30 |
+
"yes"
|
31 |
+
],
|
32 |
+
"is_categorical": true
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"name": "hotel-bookday",
|
36 |
+
"description": "day of the hotel booking",
|
37 |
+
"possible_values": [
|
38 |
+
"monday",
|
39 |
+
"tuesday",
|
40 |
+
"wednesday",
|
41 |
+
"thursday",
|
42 |
+
"friday",
|
43 |
+
"saturday",
|
44 |
+
"sunday"
|
45 |
+
],
|
46 |
+
"is_categorical": true
|
47 |
+
},
|
48 |
+
{
|
49 |
+
"name": "hotel-bookpeople",
|
50 |
+
"description": "number of people for the hotel booking",
|
51 |
+
"possible_values": [
|
52 |
+
"1",
|
53 |
+
"2",
|
54 |
+
"3",
|
55 |
+
"4",
|
56 |
+
"5",
|
57 |
+
"6",
|
58 |
+
"7",
|
59 |
+
"8"
|
60 |
+
],
|
61 |
+
"is_categorical": true
|
62 |
+
},
|
63 |
+
{
|
64 |
+
"name": "hotel-bookstay",
|
65 |
+
"description": "length of stay at the hotel",
|
66 |
+
"possible_values": [
|
67 |
+
"1",
|
68 |
+
"2",
|
69 |
+
"3",
|
70 |
+
"4",
|
71 |
+
"5",
|
72 |
+
"6",
|
73 |
+
"7",
|
74 |
+
"8"
|
75 |
+
],
|
76 |
+
"is_categorical": true
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"name": "hotel-stars",
|
80 |
+
"description": "star rating of the hotel",
|
81 |
+
"possible_values": [
|
82 |
+
"0",
|
83 |
+
"1",
|
84 |
+
"2",
|
85 |
+
"3",
|
86 |
+
"4",
|
87 |
+
"5"
|
88 |
+
],
|
89 |
+
"is_categorical": true
|
90 |
+
},
|
91 |
+
{
|
92 |
+
"name": "hotel-internet",
|
93 |
+
"description": "whether the hotel has internet",
|
94 |
+
"possible_values": [
|
95 |
+
"free",
|
96 |
+
"no",
|
97 |
+
"yes"
|
98 |
+
],
|
99 |
+
"is_categorical": true
|
100 |
+
},
|
101 |
+
{
|
102 |
+
"name": "hotel-name",
|
103 |
+
"description": "name of the hotel",
|
104 |
+
"possible_values": [],
|
105 |
+
"is_categorical": false
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"name": "hotel-area",
|
109 |
+
"description": "area or place of the hotel",
|
110 |
+
"possible_values": [
|
111 |
+
"centre",
|
112 |
+
"east",
|
113 |
+
"north",
|
114 |
+
"south",
|
115 |
+
"west"
|
116 |
+
],
|
117 |
+
"is_categorical": true
|
118 |
+
},
|
119 |
+
{
|
120 |
+
"name": "hotel-address",
|
121 |
+
"description": "address of the hotel",
|
122 |
+
"is_categorical": false
|
123 |
+
},
|
124 |
+
{
|
125 |
+
"name": "hotel-phone",
|
126 |
+
"description": "phone number of the hotel",
|
127 |
+
"is_categorical": false
|
128 |
+
},
|
129 |
+
{
|
130 |
+
"name": "hotel-postcode",
|
131 |
+
"description": "postal code of the hotel",
|
132 |
+
"is_categorical": false
|
133 |
+
},
|
134 |
+
{
|
135 |
+
"name": "hotel-ref",
|
136 |
+
"description": "reference number of the hotel booking",
|
137 |
+
"is_categorical": false
|
138 |
+
}
|
139 |
+
],
|
140 |
+
"description": "hotel reservations and vacation stays",
|
141 |
+
"intents": [
|
142 |
+
{
|
143 |
+
"name": "find_hotel",
|
144 |
+
"description": "search for a hotel to stay in",
|
145 |
+
"is_transactional": false,
|
146 |
+
"required_slots": [],
|
147 |
+
"optional_slots": {
|
148 |
+
"hotel-pricerange": "dontcare",
|
149 |
+
"hotel-type": "dontcare",
|
150 |
+
"hotel-parking": "dontcare",
|
151 |
+
"hotel-bookday": "dontcare",
|
152 |
+
"hotel-bookpeople": "dontcare",
|
153 |
+
"hotel-bookstay": "dontcare",
|
154 |
+
"hotel-stars": "dontcare",
|
155 |
+
"hotel-internet": "dontcare",
|
156 |
+
"hotel-name": "dontcare",
|
157 |
+
"hotel-area": "dontcare"
|
158 |
+
}
|
159 |
+
},
|
160 |
+
{
|
161 |
+
"name": "book_hotel",
|
162 |
+
"description": "book a hotel to stay in",
|
163 |
+
"is_transactional": true,
|
164 |
+
"required_slots": [],
|
165 |
+
"optional_slots": {
|
166 |
+
"hotel-pricerange": "dontcare",
|
167 |
+
"hotel-type": "dontcare",
|
168 |
+
"hotel-parking": "dontcare",
|
169 |
+
"hotel-bookday": "dontcare",
|
170 |
+
"hotel-bookpeople": "dontcare",
|
171 |
+
"hotel-bookstay": "dontcare",
|
172 |
+
"hotel-stars": "dontcare",
|
173 |
+
"hotel-internet": "dontcare",
|
174 |
+
"hotel-name": "dontcare",
|
175 |
+
"hotel-area": "dontcare"
|
176 |
+
}
|
177 |
+
}
|
178 |
+
]
|
179 |
+
},
|
180 |
+
{
|
181 |
+
"service_name": "train",
|
182 |
+
"slots": [
|
183 |
+
{
|
184 |
+
"name": "train-arriveby",
|
185 |
+
"description": "arrival time of the train",
|
186 |
+
"possible_values": [],
|
187 |
+
"is_categorical": false
|
188 |
+
},
|
189 |
+
{
|
190 |
+
"name": "train-departure",
|
191 |
+
"description": "departure location of the train",
|
192 |
+
"possible_values": [
|
193 |
+
"birmingham new street",
|
194 |
+
"bishops stortford",
|
195 |
+
"broxbourne",
|
196 |
+
"cambridge",
|
197 |
+
"ely",
|
198 |
+
"kings lynn",
|
199 |
+
"leicester",
|
200 |
+
"london kings cross",
|
201 |
+
"london liverpool street",
|
202 |
+
"norwich",
|
203 |
+
"peterborough",
|
204 |
+
"stansted airport",
|
205 |
+
"stevenage"
|
206 |
+
],
|
207 |
+
"is_categorical": true
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"name": "train-day",
|
211 |
+
"description": "day of the train",
|
212 |
+
"possible_values": [
|
213 |
+
"monday",
|
214 |
+
"tuesday",
|
215 |
+
"wednesday",
|
216 |
+
"thursday",
|
217 |
+
"friday",
|
218 |
+
"saturday",
|
219 |
+
"sunday"
|
220 |
+
],
|
221 |
+
"is_categorical": true
|
222 |
+
},
|
223 |
+
{
|
224 |
+
"name": "train-bookpeople",
|
225 |
+
"description": "how many train tickets you need",
|
226 |
+
"possible_values": [
|
227 |
+
"0",
|
228 |
+
"1",
|
229 |
+
"2",
|
230 |
+
"3",
|
231 |
+
"4",
|
232 |
+
"5",
|
233 |
+
"6",
|
234 |
+
"7",
|
235 |
+
"8",
|
236 |
+
"9",
|
237 |
+
"10",
|
238 |
+
"15"
|
239 |
+
],
|
240 |
+
"is_categorical": true
|
241 |
+
},
|
242 |
+
{
|
243 |
+
"name": "train-leaveat",
|
244 |
+
"description": "leaving time for the train",
|
245 |
+
"possible_values": [],
|
246 |
+
"is_categorical": false
|
247 |
+
},
|
248 |
+
{
|
249 |
+
"name": "train-destination",
|
250 |
+
"description": "destination of the train",
|
251 |
+
"possible_values": [
|
252 |
+
"birmingham new street",
|
253 |
+
"bishops stortford",
|
254 |
+
"broxbourne",
|
255 |
+
"cambridge",
|
256 |
+
"ely",
|
257 |
+
"kings lynn",
|
258 |
+
"leicester",
|
259 |
+
"london kings cross",
|
260 |
+
"london liverpool street",
|
261 |
+
"norwich",
|
262 |
+
"peterborough",
|
263 |
+
"stansted airport",
|
264 |
+
"stevenage"
|
265 |
+
],
|
266 |
+
"is_categorical": true
|
267 |
+
},
|
268 |
+
{
|
269 |
+
"name": "train-trainid",
|
270 |
+
"description": "id of the train",
|
271 |
+
"is_categorical": false
|
272 |
+
},
|
273 |
+
{
|
274 |
+
"name": "train-ref",
|
275 |
+
"description": "reference number of the train booking",
|
276 |
+
"is_categorical": false
|
277 |
+
},
|
278 |
+
{
|
279 |
+
"name": "train-price",
|
280 |
+
"description": "price of the train",
|
281 |
+
"is_categorical": false
|
282 |
+
},
|
283 |
+
{
|
284 |
+
"name": "train-duration",
|
285 |
+
"description": "duration of the travel",
|
286 |
+
"is_categorical": false
|
287 |
+
}
|
288 |
+
],
|
289 |
+
"description": "find trains that take you to places",
|
290 |
+
"intents": [
|
291 |
+
{
|
292 |
+
"name": "find_train",
|
293 |
+
"description": "search for trains that take you places",
|
294 |
+
"is_transactional": false,
|
295 |
+
"required_slots": [],
|
296 |
+
"optional_slots": {
|
297 |
+
"train-destination": "dontcare",
|
298 |
+
"train-arriveby": "dontcare",
|
299 |
+
"train-departure": "dontcare",
|
300 |
+
"train-day": "dontcare",
|
301 |
+
"train-bookpeople": "dontcare",
|
302 |
+
"train-leaveat": "dontcare"
|
303 |
+
}
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"name": "book_train",
|
307 |
+
"description": "book train tickets",
|
308 |
+
"is_transactional": true,
|
309 |
+
"required_slots": [],
|
310 |
+
"optional_slots": {
|
311 |
+
"train-destination": "dontcare",
|
312 |
+
"train-arriveby": "dontcare",
|
313 |
+
"train-departure": "dontcare",
|
314 |
+
"train-day": "dontcare",
|
315 |
+
"train-bookpeople": "dontcare",
|
316 |
+
"train-leaveat": "dontcare"
|
317 |
+
}
|
318 |
+
}
|
319 |
+
]
|
320 |
+
},
|
321 |
+
{
|
322 |
+
"service_name": "attraction",
|
323 |
+
"slots": [
|
324 |
+
{
|
325 |
+
"name": "attraction-area",
|
326 |
+
"description": "area to search for attractions",
|
327 |
+
"possible_values": [
|
328 |
+
"centre",
|
329 |
+
"east",
|
330 |
+
"north",
|
331 |
+
"south",
|
332 |
+
"west"
|
333 |
+
],
|
334 |
+
"is_categorical": true
|
335 |
+
},
|
336 |
+
{
|
337 |
+
"name": "attraction-name",
|
338 |
+
"description": "name of the attraction",
|
339 |
+
"possible_values": [],
|
340 |
+
"is_categorical": false
|
341 |
+
},
|
342 |
+
{
|
343 |
+
"name": "attraction-type",
|
344 |
+
"description": "type of the attraction",
|
345 |
+
"possible_values": [
|
346 |
+
"architecture",
|
347 |
+
"boat",
|
348 |
+
"cinema",
|
349 |
+
"college",
|
350 |
+
"concerthall",
|
351 |
+
"entertainment",
|
352 |
+
"museum",
|
353 |
+
"multiple sports",
|
354 |
+
"nightclub",
|
355 |
+
"park",
|
356 |
+
"swimmingpool",
|
357 |
+
"theatre"
|
358 |
+
],
|
359 |
+
"is_categorical": true
|
360 |
+
},
|
361 |
+
{
|
362 |
+
"name": "attraction-entrancefee",
|
363 |
+
"description": "how much is the entrance fee",
|
364 |
+
"is_categorical": false
|
365 |
+
},
|
366 |
+
{
|
367 |
+
"name": "attraction-openhours",
|
368 |
+
"description": "open hours of the attraction",
|
369 |
+
"is_categorical": false
|
370 |
+
},
|
371 |
+
{
|
372 |
+
"name": "attraction-address",
|
373 |
+
"description": "address of the attraction",
|
374 |
+
"is_categorical": false
|
375 |
+
},
|
376 |
+
{
|
377 |
+
"name": "attraction-phone",
|
378 |
+
"description": "phone number of the attraction",
|
379 |
+
"is_categorical": false
|
380 |
+
},
|
381 |
+
{
|
382 |
+
"name": "attraction-postcode",
|
383 |
+
"description": "postal code of the attraction",
|
384 |
+
"is_categorical": false
|
385 |
+
}
|
386 |
+
],
|
387 |
+
"description": "find touristy stuff to do around you",
|
388 |
+
"intents": [
|
389 |
+
{
|
390 |
+
"name": "find_attraction",
|
391 |
+
"description": "search for places to see for leisure",
|
392 |
+
"is_transactional": false,
|
393 |
+
"required_slots": [],
|
394 |
+
"optional_slots": {
|
395 |
+
"attraction-area": "dontcare",
|
396 |
+
"attraction-name": "dontcare",
|
397 |
+
"attraction-type": "dontcare"
|
398 |
+
}
|
399 |
+
}
|
400 |
+
]
|
401 |
+
},
|
402 |
+
{
|
403 |
+
"service_name": "restaurant",
|
404 |
+
"slots": [
|
405 |
+
{
|
406 |
+
"name": "restaurant-pricerange",
|
407 |
+
"description": "price budget for the restaurant",
|
408 |
+
"possible_values": [
|
409 |
+
"cheap",
|
410 |
+
"expensive",
|
411 |
+
"moderate"
|
412 |
+
],
|
413 |
+
"is_categorical": true
|
414 |
+
},
|
415 |
+
{
|
416 |
+
"name": "restaurant-area",
|
417 |
+
"description": "area or place of the restaurant",
|
418 |
+
"possible_values": [
|
419 |
+
"centre",
|
420 |
+
"east",
|
421 |
+
"north",
|
422 |
+
"south",
|
423 |
+
"west"
|
424 |
+
],
|
425 |
+
"is_categorical": true
|
426 |
+
},
|
427 |
+
{
|
428 |
+
"name": "restaurant-food",
|
429 |
+
"description": "the cuisine of the restaurant you are looking for",
|
430 |
+
"is_categorical": false
|
431 |
+
},
|
432 |
+
{
|
433 |
+
"name": "restaurant-name",
|
434 |
+
"description": "name of the restaurant",
|
435 |
+
"possible_values": [],
|
436 |
+
"is_categorical": false
|
437 |
+
},
|
438 |
+
{
|
439 |
+
"name": "restaurant-bookday",
|
440 |
+
"description": "day of the restaurant booking",
|
441 |
+
"possible_values": [
|
442 |
+
"monday",
|
443 |
+
"tuesday",
|
444 |
+
"wednesday",
|
445 |
+
"thursday",
|
446 |
+
"friday",
|
447 |
+
"saturday",
|
448 |
+
"sunday"
|
449 |
+
],
|
450 |
+
"is_categorical": true
|
451 |
+
},
|
452 |
+
{
|
453 |
+
"name": "restaurant-bookpeople",
|
454 |
+
"description": "how many people for the restaurant reservation",
|
455 |
+
"possible_values": [
|
456 |
+
"1",
|
457 |
+
"2",
|
458 |
+
"3",
|
459 |
+
"4",
|
460 |
+
"5",
|
461 |
+
"6",
|
462 |
+
"7",
|
463 |
+
"8"
|
464 |
+
],
|
465 |
+
"is_categorical": true
|
466 |
+
},
|
467 |
+
{
|
468 |
+
"name": "restaurant-booktime",
|
469 |
+
"description": "time of the restaurant booking",
|
470 |
+
"possible_values": [],
|
471 |
+
"is_categorical": false
|
472 |
+
},
|
473 |
+
{
|
474 |
+
"name": "restaurant-address",
|
475 |
+
"description": "address of the restaurant",
|
476 |
+
"is_categorical": false
|
477 |
+
},
|
478 |
+
{
|
479 |
+
"name": "restaurant-phone",
|
480 |
+
"description": "phone number of the restaurant",
|
481 |
+
"is_categorical": false
|
482 |
+
},
|
483 |
+
{
|
484 |
+
"name": "restaurant-postcode",
|
485 |
+
"description": "postal code of the restaurant",
|
486 |
+
"is_categorical": false
|
487 |
+
},
|
488 |
+
{
|
489 |
+
"name": "restaurant-ref",
|
490 |
+
"description": "reference number of the restaurant booking",
|
491 |
+
"is_categorical": false
|
492 |
+
}
|
493 |
+
],
|
494 |
+
"description": "find places to dine and whet your appetite",
|
495 |
+
"intents": [
|
496 |
+
{
|
497 |
+
"name": "find_restaurant",
|
498 |
+
"description": "search for places to wine and dine",
|
499 |
+
"is_transactional": false,
|
500 |
+
"required_slots": [],
|
501 |
+
"optional_slots": {
|
502 |
+
"restaurant-pricerange": "dontcare",
|
503 |
+
"restaurant-area": "dontcare",
|
504 |
+
"restaurant-food": "dontcare",
|
505 |
+
"restaurant-name": "dontcare",
|
506 |
+
"restaurant-bookday": "dontcare",
|
507 |
+
"restaurant-bookpeople": "dontcare",
|
508 |
+
"restaurant-booktime": "dontcare"
|
509 |
+
}
|
510 |
+
},
|
511 |
+
{
|
512 |
+
"name": "book_restaurant",
|
513 |
+
"description": "book a table at a restaurant",
|
514 |
+
"is_transactional": true,
|
515 |
+
"required_slots": [],
|
516 |
+
"optional_slots": {
|
517 |
+
"restaurant-pricerange": "dontcare",
|
518 |
+
"restaurant-area": "dontcare",
|
519 |
+
"restaurant-food": "dontcare",
|
520 |
+
"restaurant-name": "dontcare",
|
521 |
+
"restaurant-bookday": "dontcare",
|
522 |
+
"restaurant-bookpeople": "dontcare",
|
523 |
+
"restaurant-booktime": "dontcare"
|
524 |
+
}
|
525 |
+
}
|
526 |
+
]
|
527 |
+
},
|
528 |
+
{
|
529 |
+
"service_name": "hospital",
|
530 |
+
"slots": [
|
531 |
+
{
|
532 |
+
"name": "hospital-department",
|
533 |
+
"description": "type of medical care",
|
534 |
+
"possible_values": [],
|
535 |
+
"is_categorical": false
|
536 |
+
},
|
537 |
+
{
|
538 |
+
"name": "hospital-address",
|
539 |
+
"description": "address of the hospital",
|
540 |
+
"is_categorical": false
|
541 |
+
},
|
542 |
+
{
|
543 |
+
"name": "hospital-phone",
|
544 |
+
"description": "phone number of the hospital",
|
545 |
+
"is_categorical": false
|
546 |
+
},
|
547 |
+
{
|
548 |
+
"name": "hospital-postcode",
|
549 |
+
"description": "postal code of the hospital",
|
550 |
+
"is_categorical": false
|
551 |
+
}
|
552 |
+
],
|
553 |
+
"description": "making you feel better when you are ill",
|
554 |
+
"intents": [
|
555 |
+
{
|
556 |
+
"name": "find_hospital",
|
557 |
+
"description": "search for a medical facility or a doctor",
|
558 |
+
"is_transactional": false,
|
559 |
+
"required_slots": [],
|
560 |
+
"optional_slots": {
|
561 |
+
"hospital-department": "dontcare"
|
562 |
+
}
|
563 |
+
}
|
564 |
+
]
|
565 |
+
},
|
566 |
+
{
|
567 |
+
"service_name": "taxi",
|
568 |
+
"slots": [
|
569 |
+
{
|
570 |
+
"name": "taxi-leaveat",
|
571 |
+
"description": "leaving time of taxi",
|
572 |
+
"possible_values": [],
|
573 |
+
"is_categorical": false
|
574 |
+
},
|
575 |
+
{
|
576 |
+
"name": "taxi-destination",
|
577 |
+
"description": "destination of taxi",
|
578 |
+
"possible_values": [],
|
579 |
+
"is_categorical": false
|
580 |
+
},
|
581 |
+
{
|
582 |
+
"name": "taxi-departure",
|
583 |
+
"description": "departure location of taxi",
|
584 |
+
"possible_values": [],
|
585 |
+
"is_categorical": false
|
586 |
+
},
|
587 |
+
{
|
588 |
+
"name": "taxi-arriveby",
|
589 |
+
"description": "arrival time of taxi",
|
590 |
+
"possible_values": [],
|
591 |
+
"is_categorical": false
|
592 |
+
},
|
593 |
+
{
|
594 |
+
"name": "taxi-type",
|
595 |
+
"description": "car type of the taxi",
|
596 |
+
"is_categorical": false
|
597 |
+
},
|
598 |
+
{
|
599 |
+
"name": "taxi-phone",
|
600 |
+
"description": "phone number of the taxi",
|
601 |
+
"is_categorical": false
|
602 |
+
}
|
603 |
+
],
|
604 |
+
"description": "rent cheap cabs to avoid traffic",
|
605 |
+
"intents": [
|
606 |
+
{
|
607 |
+
"name": "book_taxi",
|
608 |
+
"description": "book taxis to travel between places",
|
609 |
+
"is_transactional": true,
|
610 |
+
"required_slots": [],
|
611 |
+
"optional_slots": {
|
612 |
+
"taxi-leaveat": "dontcare",
|
613 |
+
"taxi-destination": "dontcare",
|
614 |
+
"taxi-departure": "dontcare",
|
615 |
+
"taxi-arriveby": "dontcare"
|
616 |
+
}
|
617 |
+
}
|
618 |
+
]
|
619 |
+
},
|
620 |
+
{
|
621 |
+
"service_name": "bus",
|
622 |
+
"slots": [
|
623 |
+
{
|
624 |
+
"name": "bus-departure",
|
625 |
+
"description": "departure location of bus",
|
626 |
+
"possible_values": [
|
627 |
+
"cambridge"
|
628 |
+
],
|
629 |
+
"is_categorical": false
|
630 |
+
},
|
631 |
+
{
|
632 |
+
"name": "bus-destination",
|
633 |
+
"description": "destination of bus",
|
634 |
+
"possible_values": [
|
635 |
+
"london kings cross",
|
636 |
+
"bishops stortford",
|
637 |
+
"cambridge",
|
638 |
+
"kohinoor"
|
639 |
+
],
|
640 |
+
"is_categorical": false
|
641 |
+
},
|
642 |
+
{
|
643 |
+
"name": "bus-leaveat",
|
644 |
+
"description": "leaving time of bus",
|
645 |
+
"is_categorical": false
|
646 |
+
},
|
647 |
+
{
|
648 |
+
"name": "bus-day",
|
649 |
+
"description": "day to use the bus tickets",
|
650 |
+
"possible_values": [
|
651 |
+
"wednesday"
|
652 |
+
],
|
653 |
+
"is_categorical": true
|
654 |
+
}
|
655 |
+
],
|
656 |
+
"description": "bus service for traveling",
|
657 |
+
"intents": [
|
658 |
+
{
|
659 |
+
"name": "find_bus",
|
660 |
+
"description": "search for a bus",
|
661 |
+
"is_transactional": false,
|
662 |
+
"required_slots": [],
|
663 |
+
"optional_slots": {
|
664 |
+
"bus-departure": "dontcare",
|
665 |
+
"bus-destination": "dontcare",
|
666 |
+
"bus-day": "dontcare",
|
667 |
+
"bus-leaveat": "dontcare"
|
668 |
+
}
|
669 |
+
}
|
670 |
+
]
|
671 |
+
},
|
672 |
+
{
|
673 |
+
"service_name": "police",
|
674 |
+
"slots": [
|
675 |
+
{
|
676 |
+
"name": "police-address",
|
677 |
+
"description": "address of the police station",
|
678 |
+
"is_categorical": false
|
679 |
+
},
|
680 |
+
{
|
681 |
+
"name": "police-phone",
|
682 |
+
"description": "phone number of the police station",
|
683 |
+
"is_categorical": false
|
684 |
+
},
|
685 |
+
{
|
686 |
+
"name": "police-postcode",
|
687 |
+
"description": "postal code of the police station",
|
688 |
+
"is_categorical": false
|
689 |
+
},
|
690 |
+
{
|
691 |
+
"name": "police-name",
|
692 |
+
"description": "name of the police station",
|
693 |
+
"possible_values": [
|
694 |
+
"parkside police station"
|
695 |
+
],
|
696 |
+
"is_categorical": true
|
697 |
+
}
|
698 |
+
],
|
699 |
+
"description": "police station",
|
700 |
+
"intents": [
|
701 |
+
{
|
702 |
+
"name": "police",
|
703 |
+
"description": "search for police station",
|
704 |
+
"is_transactional": false,
|
705 |
+
"required_slots": [],
|
706 |
+
"optional_slots": {
|
707 |
+
"police-name": "dontcare"
|
708 |
+
}
|
709 |
+
}
|
710 |
+
]
|
711 |
+
}
|
712 |
+
]
|
scripts/user_model_code/interaction/utils.py
ADDED
@@ -0,0 +1,308 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
|
4 |
+
|
5 |
+
def segment_gen(gen, dial_id):
|
6 |
+
def _color(_segment):
|
7 |
+
if tag == "CTX":
|
8 |
+
_segment = _segment.replace(" </USR>", f"{bcolors.ENDC}")
|
9 |
+
_segment = _segment.replace(" </SYS>", f"{bcolors.ENDC}")
|
10 |
+
_segment = _segment.replace("<USR/> ", f"USR: {bcolors.OKCYAN}")
|
11 |
+
_segment = _segment.replace("<SYS/> ", f"SYS: {bcolors.OKBLUE}")
|
12 |
+
if tag == "SYS_UTT":
|
13 |
+
_segment = f"{bcolors.OKBLUE}" + _segment + f"{bcolors.ENDC}"
|
14 |
+
if tag == "USR_UTT":
|
15 |
+
_segment = f"{bcolors.OKCYAN}" + _segment + f"{bcolors.ENDC}"
|
16 |
+
if tag in ["SYS_ACT", "USR_ACT", "GOAL"]:
|
17 |
+
_segment = _segment.replace("<ACT/> ", f"{bcolors.RED}")
|
18 |
+
_segment = _segment.replace(" </ACT>", f"{bcolors.ENDC}")
|
19 |
+
_segment = _segment.replace("<SLOT/> ", f"{bcolors.YELLOW}")
|
20 |
+
_segment = _segment.replace(" </SLOT>", f"{bcolors.ENDC}")
|
21 |
+
_segment = _segment.replace("<VALUE/> ", f"{bcolors.GREEN}")
|
22 |
+
_segment = _segment.replace(" </VALUE>", f"{bcolors.ENDC}")
|
23 |
+
if tag == "GOAL":
|
24 |
+
_segment = _segment.replace(
|
25 |
+
"<SCENARIO/>", f"<SCENARIO/>{bcolors.UNDERLINE}"
|
26 |
+
)
|
27 |
+
_segment = _segment.replace("</SCENARIO>", f"{bcolors.ENDC}</SCENARIO>")
|
28 |
+
_segment = _segment.replace("<TASK/>", f"<TASK/>{bcolors.UNDERLINE}")
|
29 |
+
_segment = _segment.replace("</TASK>", f"{bcolors.ENDC}</TASK>")
|
30 |
+
# if tag in ["SNT", "GC"]:
|
31 |
+
# segment = segment.replace("<{}/> ".format(tag), "<{}/> *".format(tag))
|
32 |
+
# segment = segment.replace(" </{}>".format(tag), "* <{}/>".format(tag))
|
33 |
+
return _segment
|
34 |
+
|
35 |
+
assert isinstance(gen, str)
|
36 |
+
# gen = gen.split()
|
37 |
+
# print(gen)
|
38 |
+
print("*** Dial_id: {} ***".format(dial_id))
|
39 |
+
for tag in [
|
40 |
+
"CTX",
|
41 |
+
"SYS_UTT",
|
42 |
+
"SYS_ACT",
|
43 |
+
"GOAL",
|
44 |
+
"SNT",
|
45 |
+
"RA",
|
46 |
+
"GC",
|
47 |
+
"USR_ACT",
|
48 |
+
"USR_UTT",
|
49 |
+
]:
|
50 |
+
segment = find_segment(gen, tag)
|
51 |
+
if segment is not None:
|
52 |
+
print('{} -> "{}"'.format(tag, _color(segment)))
|
53 |
+
else:
|
54 |
+
print("Fail to find the segment...")
|
55 |
+
print("GEN:", gen)
|
56 |
+
print("---" * 30)
|
57 |
+
|
58 |
+
|
59 |
+
# input("press...")
|
60 |
+
|
61 |
+
|
62 |
+
def get_original_act_set():
|
63 |
+
# full act vocab:
|
64 |
+
# https://github.com/ConvLab/ConvLab/blob/master/data/multiwoz/annotation/Multiwoz%20data%20analysis.md#dialog-act
|
65 |
+
acts = set()
|
66 |
+
acts.add("Inform")
|
67 |
+
acts.add("Request")
|
68 |
+
acts.add(
|
69 |
+
"NoOffer"
|
70 |
+
) # equivalent to the concept of `no matching`, `cannot find` in database
|
71 |
+
acts.add("Recommend")
|
72 |
+
acts.add("Select")
|
73 |
+
acts.add(
|
74 |
+
"OfferBook"
|
75 |
+
) # only for `train` domain, ask if book is needed, equivalent to `Booking-Inform` with [[none, none]]
|
76 |
+
# args in restaurant/hotel domain
|
77 |
+
acts.add(
|
78 |
+
"OfferBooked"
|
79 |
+
) # only for `train` domain, inform booking is complete, with corresponding info (such as ref number)
|
80 |
+
acts.add("Book") # inform booking is successful, equivalent to `OfferBooked` above
|
81 |
+
acts.add(
|
82 |
+
"NoBook"
|
83 |
+
) # inform booking fails, might because of no availability, usually come together act `request`
|
84 |
+
acts.add("bye")
|
85 |
+
acts.add("greet")
|
86 |
+
acts.add("reqmore")
|
87 |
+
acts.add("welcome")
|
88 |
+
acts.add("thank")
|
89 |
+
return acts
|
90 |
+
|
91 |
+
|
92 |
+
def get_act_natural_language(act):
|
93 |
+
if act in ["bye", "greet", "reqmore", "welcome", "thank"]:
|
94 |
+
return act
|
95 |
+
|
96 |
+
assert act[0].isupper()
|
97 |
+
tokens = re.findall("[A-Z][^A-Z]*", act) # e.g., `FindEvents` -> `Find Events`
|
98 |
+
tokens = list(map(str.lower, tokens)) # lower case, -> `find events`
|
99 |
+
act_nl = " ".join(tokens)
|
100 |
+
return act_nl
|
101 |
+
|
102 |
+
|
103 |
+
def convert_act_into_sgd(act, SPECIAL_TOKENS):
|
104 |
+
# TODO: check inference result to see if mapping on NoOffer, OfferBook and NoBook are fine
|
105 |
+
"""
|
106 |
+
convert multiwoz acts (w/o domain info) into sgd acts ensure that acts with same concept use one name
|
107 |
+
e.g., Book (OfferBooked) -> NOTIFY_SUCCESS, NoBook -> NOTIFY_FAILURE
|
108 |
+
"""
|
109 |
+
if act == "NoOffer":
|
110 |
+
act = "NOTIFY_FAILURE"
|
111 |
+
|
112 |
+
elif act == "Recommend":
|
113 |
+
act = "OFFER"
|
114 |
+
|
115 |
+
# technically, `OfferBook` is equivalent to (`act=OFFER_INTENT, slot=intent, value=ReserveRestaurant`)
|
116 |
+
# on system side in sgd since (1) the conversion is not trivial (completely different representations)
|
117 |
+
# and (2) multiwoz has no slot called `intent` one cannot simply convert `OfferBook` to `OFFER_INTENT`
|
118 |
+
# we thus keep the act as is
|
119 |
+
# note that there is no slot `intent` and value conveying intents in multiwoz
|
120 |
+
elif act == "OfferBook":
|
121 |
+
act = "Offer_Book"
|
122 |
+
|
123 |
+
elif act == "OfferBooked":
|
124 |
+
act = "NOTIFY_SUCCESS"
|
125 |
+
|
126 |
+
elif act == "Book": # same as `OfferBooked`
|
127 |
+
act = "NOTIFY_SUCCESS"
|
128 |
+
|
129 |
+
elif act == "NoBook":
|
130 |
+
act = "NOTIFY_FAILURE"
|
131 |
+
|
132 |
+
elif act == "bye":
|
133 |
+
act = "GOODBYE"
|
134 |
+
|
135 |
+
elif act == "reqmore":
|
136 |
+
act = "REQ_MORE"
|
137 |
+
|
138 |
+
elif act == "thank":
|
139 |
+
act = "THANK_YOU"
|
140 |
+
# elif act == "greet":
|
141 |
+
# elif act == "welcome":
|
142 |
+
act = act.upper() # align with sgd acts, e.g., `Inform` -> `INFORM`
|
143 |
+
|
144 |
+
# check if valid
|
145 |
+
assert "_{}_".format(act) in SPECIAL_TOKENS["additional_special_tokens"]
|
146 |
+
return act
|
147 |
+
|
148 |
+
|
149 |
+
def load_schema(schema_file):
|
150 |
+
def _update(key, value, mapping):
|
151 |
+
if key in mapping:
|
152 |
+
assert (
|
153 |
+
value == mapping[key]
|
154 |
+
) # ensure service meta is the same between data splits
|
155 |
+
else:
|
156 |
+
mapping[key] = value
|
157 |
+
|
158 |
+
def _restructure_service_meta(service_meta, attribute):
|
159 |
+
""" "convert slot/intent metadata list into dict(slot/intent=metadata)"""
|
160 |
+
assert attribute in ["slots", "intents"]
|
161 |
+
mapping = {}
|
162 |
+
for value in service_meta[attribute]:
|
163 |
+
key = value["name"]
|
164 |
+
if attribute == "slots": # domain-slot in multiwoz
|
165 |
+
assert "-" in key
|
166 |
+
_, key = key.split("-") # domain, slot
|
167 |
+
key = normalise_slot(key)
|
168 |
+
else: # intent
|
169 |
+
key = normalise_intent(key)
|
170 |
+
mapping[key] = value
|
171 |
+
service_meta[attribute] = mapping
|
172 |
+
|
173 |
+
with open(schema_file) as f:
|
174 |
+
data = json.load(f)
|
175 |
+
|
176 |
+
SERVICE2META = {}
|
177 |
+
SLOTS, INTENTS = set(), set()
|
178 |
+
for service_meta in data:
|
179 |
+
service = service_meta["service_name"]
|
180 |
+
_restructure_service_meta(service_meta, "slots")
|
181 |
+
_restructure_service_meta(service_meta, "intents")
|
182 |
+
_update(service, service_meta, SERVICE2META)
|
183 |
+
|
184 |
+
# collect domain-independent slots
|
185 |
+
# for domain_slot in service_meta["slots"]:
|
186 |
+
# assert "-" in domain_slot
|
187 |
+
# domain, slot = domain_slot.split("-")
|
188 |
+
# slot = normalise_slot(slot)
|
189 |
+
# SLOTS.add(slot)
|
190 |
+
for slot in service_meta["slots"]:
|
191 |
+
SLOTS.add(slot)
|
192 |
+
|
193 |
+
for intent in service_meta["intents"]:
|
194 |
+
# intent = normalise_intent(intent)
|
195 |
+
INTENTS.add(intent)
|
196 |
+
|
197 |
+
print("Load schema, intents: {}, slots: {}".format(len(INTENTS), len(SLOTS)))
|
198 |
+
return SERVICE2META, INTENTS, SLOTS
|
199 |
+
|
200 |
+
|
201 |
+
def normalise_intent(intent):
|
202 |
+
"""convert intent into natural language, e.g., find_hotel -> find hotel"""
|
203 |
+
if intent == "police":
|
204 |
+
intent = "find_police"
|
205 |
+
if intent == "book_taxi":
|
206 |
+
intent = "find_taxi"
|
207 |
+
assert "_" in intent
|
208 |
+
return " ".join(intent.split("_"))
|
209 |
+
|
210 |
+
|
211 |
+
def normalise_slot(slot):
|
212 |
+
if slot == "pricerange":
|
213 |
+
return "price range"
|
214 |
+
|
215 |
+
elif slot == "bookday":
|
216 |
+
return "book day"
|
217 |
+
|
218 |
+
elif slot == "bookpeople":
|
219 |
+
return "book people"
|
220 |
+
|
221 |
+
elif slot == "booktime":
|
222 |
+
return "book time"
|
223 |
+
|
224 |
+
elif slot == "bookstay":
|
225 |
+
return "book stay"
|
226 |
+
|
227 |
+
elif slot == "ref":
|
228 |
+
return "reference"
|
229 |
+
|
230 |
+
elif slot == "arriveby":
|
231 |
+
return "arrive by"
|
232 |
+
|
233 |
+
elif slot == "leaveat":
|
234 |
+
return "leave at"
|
235 |
+
|
236 |
+
elif slot == "trainid":
|
237 |
+
return "train id"
|
238 |
+
|
239 |
+
elif slot == "openhours":
|
240 |
+
return "open hours"
|
241 |
+
|
242 |
+
elif slot == "entrancefee":
|
243 |
+
return "entrance fee"
|
244 |
+
|
245 |
+
elif slot in ["none", "?"]:
|
246 |
+
# return "_Empty_" # special token mark will be added during sequence linearlisation
|
247 |
+
return "Empty"
|
248 |
+
|
249 |
+
else:
|
250 |
+
return slot
|
251 |
+
|
252 |
+
|
253 |
+
def normalise_value(value):
|
254 |
+
# deal with binary and empty values
|
255 |
+
if value == "yes":
|
256 |
+
# return "_True_"
|
257 |
+
return "True"
|
258 |
+
|
259 |
+
elif value == "no":
|
260 |
+
# return "_False_"
|
261 |
+
return "False"
|
262 |
+
|
263 |
+
elif value in ["none", "?"]:
|
264 |
+
# return "_Empty_"
|
265 |
+
return "Empty"
|
266 |
+
|
267 |
+
# if value == "swimmingpool": # for simplicity, dont split
|
268 |
+
# return "swimming pool"
|
269 |
+
|
270 |
+
else:
|
271 |
+
return value
|
272 |
+
|
273 |
+
|
274 |
+
def wrap_element(content_type, content):
|
275 |
+
"""
|
276 |
+
wrap elements such as slot, value, e.g., <SLOT/> slot </SLOT>
|
277 |
+
"""
|
278 |
+
assert "/" not in content_type
|
279 |
+
return "<{}/> {} </{}>".format(content_type, content, content_type)
|
280 |
+
|
281 |
+
|
282 |
+
def add_str(str1, str2):
|
283 |
+
return str1 + " " + str2
|
284 |
+
|
285 |
+
|
286 |
+
def find_segment(gen, tag):
|
287 |
+
assert isinstance(gen, str)
|
288 |
+
gen = gen.split()
|
289 |
+
try:
|
290 |
+
start = gen.index("<{}/>".format(tag)) + 1
|
291 |
+
end = gen.index("</{}>".format(tag))
|
292 |
+
segment = " ".join(gen[start:end])
|
293 |
+
except Exception:
|
294 |
+
print("Missing {} tag in generated sequence".format(tag))
|
295 |
+
segment = None
|
296 |
+
return segment
|
297 |
+
|
298 |
+
|
299 |
+
class bcolors:
|
300 |
+
HEADER = "\033[95m"
|
301 |
+
OKBLUE = "\033[94m"
|
302 |
+
OKCYAN = "\033[96m"
|
303 |
+
GREEN = "\033[92m"
|
304 |
+
YELLOW = "\033[93m"
|
305 |
+
RED = "\033[91m"
|
306 |
+
ENDC = "\033[0m"
|
307 |
+
BOLD = "\033[1m"
|
308 |
+
UNDERLINE = "\033[4m"
|
scripts/user_model_code/main_user_model.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import random
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch
|
8 |
+
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
|
9 |
+
from tqdm import tqdm
|
10 |
+
from transformers import (
|
11 |
+
AdamW,
|
12 |
+
GPT2Config,
|
13 |
+
GPT2LMHeadModel,
|
14 |
+
GPT2Tokenizer,
|
15 |
+
get_linear_schedule_with_warmup,
|
16 |
+
)
|
17 |
+
|
18 |
+
import wandb
|
19 |
+
from crazyneuraluser.user_model_code.argument import get_args
|
20 |
+
|
21 |
+
# from interact import interact
|
22 |
+
from crazyneuraluser.user_model_code.dataset import SGD_Dataset
|
23 |
+
from crazyneuraluser.user_model_code.utils_generation import decode_e2e
|
24 |
+
from crazyneuraluser.user_model_code.utils_sgd import get_special_tokens
|
25 |
+
|
26 |
+
|
27 |
+
def print_loss(epoch, data_type, LOSS, t0):
|
28 |
+
print(
|
29 |
+
"Epoch: {} | {} loss: {:.3f} | time: {:.1f}".format(
|
30 |
+
epoch, data_type, LOSS, time.time() - t0
|
31 |
+
)
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
def print_score(epoch, data_type, res, t0):
|
36 |
+
print(
|
37 |
+
"Epoch: {} | {}: joint_acc: {:.2f}%, slot_acc: {:.2f}% | time: {:.1f}".format(
|
38 |
+
epoch,
|
39 |
+
data_type,
|
40 |
+
res["avg_joint_acc"],
|
41 |
+
res["avg_slot_acc"],
|
42 |
+
time.time() - t0,
|
43 |
+
)
|
44 |
+
)
|
45 |
+
|
46 |
+
|
47 |
+
def run_one_epoch(data_type, dataloader, trainer, epoch, run_type, collector=None):
|
48 |
+
t0 = time.time()
|
49 |
+
assert data_type in ["dev", "test"]
|
50 |
+
assert run_type in ["teacher_force", "generation"]
|
51 |
+
model, optimizer, scheduler, tokenizer = trainer
|
52 |
+
|
53 |
+
LOSS = 0
|
54 |
+
# result = {"slot_acc": [], "joint_acc": []}
|
55 |
+
# mention_match = 0
|
56 |
+
# coref_lines = []
|
57 |
+
iterator = enumerate(
|
58 |
+
tqdm(
|
59 |
+
dataloader,
|
60 |
+
desc="Epoch {} {}".format(epoch, run_type),
|
61 |
+
disable=args.disable_display,
|
62 |
+
)
|
63 |
+
)
|
64 |
+
for step, batch in iterator:
|
65 |
+
if run_type == "teacher_force":
|
66 |
+
loss, logits, _ = model(
|
67 |
+
input_ids=batch["input_ids"],
|
68 |
+
attention_mask=batch["attention_mask"],
|
69 |
+
token_type_ids=batch["token_type_ids"],
|
70 |
+
labels=batch["label_ids"],
|
71 |
+
).values()
|
72 |
+
LOSS += loss
|
73 |
+
else:
|
74 |
+
decode_e2e(args, batch, model, tokenizer, collector=collector)
|
75 |
+
|
76 |
+
# print log
|
77 |
+
if run_type == "teacher_force":
|
78 |
+
LOSS /= step + 1
|
79 |
+
print_loss(epoch, data_type, LOSS, t0)
|
80 |
+
return LOSS
|
81 |
+
else: # generation
|
82 |
+
# TODO: add evaluation code here
|
83 |
+
return None
|
84 |
+
|
85 |
+
|
86 |
+
def set_dataloader(args, tokenizer, data_type, run_type, data_size=-1):
|
87 |
+
dataset = SGD_Dataset(
|
88 |
+
args, tokenizer, data_type, run_type == "generation", data_size
|
89 |
+
)
|
90 |
+
# sys.exit(1)
|
91 |
+
if data_type == "train":
|
92 |
+
sampler = RandomSampler(
|
93 |
+
dataset
|
94 |
+
) # if args.local_rank == -1 else DistributedSampler(train_dataset)
|
95 |
+
else:
|
96 |
+
sampler = SequentialSampler(dataset)
|
97 |
+
|
98 |
+
dataloader = DataLoader(
|
99 |
+
dataset,
|
100 |
+
sampler=sampler,
|
101 |
+
batch_size=args.train_batch_size
|
102 |
+
if data_type == "train"
|
103 |
+
else args.eval_batch_size,
|
104 |
+
collate_fn=dataset.collate_fn,
|
105 |
+
)
|
106 |
+
return dataloader
|
107 |
+
|
108 |
+
|
109 |
+
def train(args, tokenizer, model):
|
110 |
+
|
111 |
+
wandb.init(
|
112 |
+
# Set the project where this run will be logged
|
113 |
+
project="E2E User Simulator (Alistair)",
|
114 |
+
entity="byrne-lab",
|
115 |
+
# We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
|
116 |
+
name=args.wandb_train_run_name,
|
117 |
+
# Track hyperparameters and run metadata
|
118 |
+
config={
|
119 |
+
"data_dir": args.data_dir,
|
120 |
+
"model_name": args.model_name,
|
121 |
+
"learning_rate": args.learning_rate,
|
122 |
+
"gradient_accumulation_steps": args.gradient_accumulation_steps,
|
123 |
+
"train_batch_size": args.train_batch_size,
|
124 |
+
"eval_batch_size": args.eval_batch_size,
|
125 |
+
},
|
126 |
+
)
|
127 |
+
|
128 |
+
# load data
|
129 |
+
train_dataloader = set_dataloader(
|
130 |
+
args, tokenizer, "train", "teacher_force", data_size=args.train_size
|
131 |
+
)
|
132 |
+
dev_dataloader = set_dataloader(
|
133 |
+
args, tokenizer, "dev", "teacher_force", data_size=args.eval_size
|
134 |
+
)
|
135 |
+
|
136 |
+
optimizer = AdamW(model.parameters(), lr=args.learning_rate, eps=args.adam_epsilon)
|
137 |
+
if args.use_scheduler:
|
138 |
+
t_total = (
|
139 |
+
len(train_dataloader) // args.gradient_accumulation_steps * args.max_epoch
|
140 |
+
)
|
141 |
+
scheduler = get_linear_schedule_with_warmup(
|
142 |
+
optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
|
143 |
+
)
|
144 |
+
else:
|
145 |
+
scheduler = None
|
146 |
+
trainer = (model, optimizer, scheduler, tokenizer)
|
147 |
+
|
148 |
+
print("Do evaluation before training!")
|
149 |
+
model.eval()
|
150 |
+
with torch.no_grad():
|
151 |
+
_ = run_one_epoch("dev", dev_dataloader, trainer, -1, "teacher_force")
|
152 |
+
|
153 |
+
print("Start training!\n{}".format("***" * 30))
|
154 |
+
eval_step = args.eval_interval // args.train_batch_size
|
155 |
+
best_score = -100
|
156 |
+
global_step = 0
|
157 |
+
no_improve_count = 0
|
158 |
+
for epoch in range(args.max_epoch):
|
159 |
+
# initialize for each epoch training
|
160 |
+
t0 = time.time()
|
161 |
+
model.train()
|
162 |
+
model.zero_grad()
|
163 |
+
LOSS = 0
|
164 |
+
iterator = enumerate(
|
165 |
+
tqdm(
|
166 |
+
train_dataloader,
|
167 |
+
desc="Epoch {}".format(epoch),
|
168 |
+
disable=args.disable_display,
|
169 |
+
)
|
170 |
+
)
|
171 |
+
for local_step, batch in iterator:
|
172 |
+
loss, logits, _ = model(
|
173 |
+
input_ids=batch["input_ids"],
|
174 |
+
attention_mask=batch["attention_mask"],
|
175 |
+
token_type_ids=batch["token_type_ids"],
|
176 |
+
labels=batch["label_ids"],
|
177 |
+
).values()
|
178 |
+
LOSS += loss
|
179 |
+
global_step += 1
|
180 |
+
|
181 |
+
wandb.log({"loss": loss})
|
182 |
+
|
183 |
+
# update model
|
184 |
+
if loss != 0:
|
185 |
+
loss = loss / args.gradient_accumulation_steps
|
186 |
+
loss.backward()
|
187 |
+
|
188 |
+
if global_step % args.gradient_accumulation_steps == 0:
|
189 |
+
# norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
|
190 |
+
optimizer.step()
|
191 |
+
if args.use_scheduler:
|
192 |
+
scheduler.step()
|
193 |
+
optimizer.zero_grad()
|
194 |
+
|
195 |
+
# evaluate model
|
196 |
+
if global_step % eval_step == 0:
|
197 |
+
model.eval()
|
198 |
+
with torch.no_grad():
|
199 |
+
loss = run_one_epoch(
|
200 |
+
"dev", dev_dataloader, trainer, epoch, "teacher_force"
|
201 |
+
)
|
202 |
+
score = -loss # dev loss as criterion for early training
|
203 |
+
wandb.log({"dev_loss": loss})
|
204 |
+
model.train()
|
205 |
+
|
206 |
+
save_checkpoint(
|
207 |
+
args, tokenizer, model, global_step * args.train_batch_size
|
208 |
+
)
|
209 |
+
if score > best_score:
|
210 |
+
best_score = score
|
211 |
+
print("Best score: {:.2f}".format(best_score))
|
212 |
+
no_improve_count = 0
|
213 |
+
else:
|
214 |
+
no_improve_count += 1
|
215 |
+
|
216 |
+
# early stop
|
217 |
+
if no_improve_count == args.no_improve_max:
|
218 |
+
print("Early stop!")
|
219 |
+
return
|
220 |
+
|
221 |
+
LOSS /= local_step + 1
|
222 |
+
print_loss(epoch, "train", LOSS, t0)
|
223 |
+
print("***" * 30)
|
224 |
+
|
225 |
+
wandb.log({"epoch": epoch, "epoch_loss": LOSS})
|
226 |
+
|
227 |
+
# Mark the run as finished on wandb
|
228 |
+
wandb.finish()
|
229 |
+
|
230 |
+
|
231 |
+
def test(args, tokenizer, model):
|
232 |
+
# load data
|
233 |
+
test_gen_dataloader = set_dataloader(args, tokenizer, "test", "generation")
|
234 |
+
|
235 |
+
trainer = (model, None, None, tokenizer)
|
236 |
+
model.eval()
|
237 |
+
collector = {"decode-dev": {}, "decode-test": {}}
|
238 |
+
with torch.no_grad():
|
239 |
+
# # evaluate on dev
|
240 |
+
# _ = run_one_epoch('dev', dev_dataloader, trainer, 'Eval', 'teacher_force')
|
241 |
+
|
242 |
+
# # generate on dev
|
243 |
+
# res_dev = run_one_epoch('dev', dev_gen_dataloader, trainer, 'Dev', 'generation',
|
244 |
+
# collector=collector['decode-dev'])
|
245 |
+
# collector['result-dev'] = res_dev
|
246 |
+
# print_qr_result(res_dev['qr'], 'dev')
|
247 |
+
|
248 |
+
# generate on test
|
249 |
+
res_test = run_one_epoch(
|
250 |
+
"test",
|
251 |
+
test_gen_dataloader,
|
252 |
+
trainer,
|
253 |
+
"Test",
|
254 |
+
"generation",
|
255 |
+
collector=collector["decode-test"],
|
256 |
+
)
|
257 |
+
collector["result-test"] = res_test
|
258 |
+
|
259 |
+
out_file = args.decode_file
|
260 |
+
with open(out_file, "w") as f:
|
261 |
+
json.dump(collector, f, indent=4, sort_keys=True)
|
262 |
+
print("Decode file is saved at {}".format(out_file))
|
263 |
+
print("Done decoding!")
|
264 |
+
|
265 |
+
|
266 |
+
def save_checkpoint(args, tokenizer, model, step):
|
267 |
+
save_path = args.checkpoint + "_step" + str(step)
|
268 |
+
print("Save model in {}!".format(save_path))
|
269 |
+
tokenizer.save_pretrained(save_path)
|
270 |
+
model.save_pretrained(save_path)
|
271 |
+
|
272 |
+
|
273 |
+
def load_checkpoint(args):
|
274 |
+
save_path = args.checkpoint # + '_step' + str(args.step)
|
275 |
+
print("Load model, tokenizer from {}".format(save_path))
|
276 |
+
tokenizer = GPT2Tokenizer.from_pretrained(save_path)
|
277 |
+
model = GPT2LMHeadModel.from_pretrained(save_path)
|
278 |
+
model.to(args.device)
|
279 |
+
return tokenizer, model
|
280 |
+
|
281 |
+
|
282 |
+
def load_pretrained_model(args):
|
283 |
+
save_path = args.pre_checkpoint
|
284 |
+
print("Load model, tokenizer from {}".format(save_path))
|
285 |
+
tokenizer = GPT2Tokenizer.from_pretrained(save_path)
|
286 |
+
model = GPT2LMHeadModel.from_pretrained(save_path)
|
287 |
+
model.to(args.device)
|
288 |
+
return tokenizer, model
|
289 |
+
|
290 |
+
|
291 |
+
def set_model(args, SPECIAL_TOKENS):
|
292 |
+
"""initiate config, tokenizer and model"""
|
293 |
+
# add special tokens into tokenizer
|
294 |
+
config = GPT2Config.from_pretrained(args.model_name_or_path)
|
295 |
+
tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
|
296 |
+
tokenizer.add_special_tokens(SPECIAL_TOKENS)
|
297 |
+
model = GPT2LMHeadModel.from_pretrained(
|
298 |
+
args.model_name_or_path, config=config
|
299 |
+
) # GPT2LMHeadModel
|
300 |
+
model.resize_token_embeddings(len(tokenizer))
|
301 |
+
model.to(args.device)
|
302 |
+
print("Done setting model")
|
303 |
+
return config, tokenizer, model
|
304 |
+
|
305 |
+
|
306 |
+
def set_seed(args):
|
307 |
+
"""for reproduction"""
|
308 |
+
random.seed(args.seed)
|
309 |
+
np.random.seed(args.seed)
|
310 |
+
torch.manual_seed(args.seed)
|
311 |
+
torch.cuda.manual_seed(args.seed)
|
312 |
+
torch.cuda.manual_seed_all(args.seed)
|
313 |
+
torch.backends.cudnn.deterministic = True
|
314 |
+
torch.backends.cudnn.enabled = False
|
315 |
+
torch.backends.cudnn.benchmark = False
|
316 |
+
|
317 |
+
|
318 |
+
if __name__ == "__main__":
|
319 |
+
# Load arguments
|
320 |
+
args = get_args()
|
321 |
+
|
322 |
+
# Set seed, device
|
323 |
+
set_seed(args)
|
324 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
325 |
+
args.device = device
|
326 |
+
|
327 |
+
# Load special tokens
|
328 |
+
SPECIAL_TOKENS = get_special_tokens()
|
329 |
+
|
330 |
+
if args.mode == "training":
|
331 |
+
config, tokenizer, model = set_model(args, SPECIAL_TOKENS)
|
332 |
+
train(args, tokenizer, model)
|
333 |
+
|
334 |
+
elif args.mode == "finetune":
|
335 |
+
tokenizer, model = load_pretrained_model(args)
|
336 |
+
train(args, tokenizer, model)
|
337 |
+
|
338 |
+
elif args.mode == "testing":
|
339 |
+
tokenizer, model = load_checkpoint(args)
|
340 |
+
test(args, tokenizer, model)
|
341 |
+
|
342 |
+
# elif args.mode == 'interact':
|
343 |
+
# tokenizer, model = load_checkpoint(args)
|
344 |
+
# interact(args, tokenizer, model)
|
345 |
+
|
346 |
+
else:
|
347 |
+
sys.exit(1)
|
scripts/user_model_code/preprocess_multiwoz.py
ADDED
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
from crazyneuraluser.user_model_code.analysis_multiwoz import DATA_SPLIT, collect_data
|
8 |
+
from crazyneuraluser.user_model_code.utils_multiwoz import (
|
9 |
+
get_act_natural_language,
|
10 |
+
get_original_act_set,
|
11 |
+
load_schema,
|
12 |
+
normalise_intent,
|
13 |
+
normalise_slot,
|
14 |
+
normalise_value,
|
15 |
+
)
|
16 |
+
from crazyneuraluser.user_model_code.utils_sgd import (
|
17 |
+
add_str,
|
18 |
+
compare_slot_values_in_state,
|
19 |
+
conv_special_token,
|
20 |
+
dict2list,
|
21 |
+
get_special_tokens,
|
22 |
+
wrap_element,
|
23 |
+
)
|
24 |
+
|
25 |
+
""" pre-process script for MultiWOZ v2.2 """
|
26 |
+
|
27 |
+
|
28 |
+
class DialMetaData:
|
29 |
+
def __init__(self, dial_id, dial_meta, dial_act, unify_act):
|
30 |
+
self.dial_id = dial_id
|
31 |
+
self.unify_act = unify_act
|
32 |
+
self.turn_meta_list, self.scenario = self.parse(
|
33 |
+
dial_meta, dial_act
|
34 |
+
) # None for system turn
|
35 |
+
self.linearise_turns()
|
36 |
+
|
37 |
+
def parse(self, dial_meta, dial_act):
|
38 |
+
global n, act_intent, non_intent
|
39 |
+
assert len(dial_meta["turns"]) == len(dial_act)
|
40 |
+
|
41 |
+
turn_meta_list = []
|
42 |
+
scenario = []
|
43 |
+
sys_turn = None # dummy sys turn for first usr turn
|
44 |
+
prev_intent = ""
|
45 |
+
prev_usr_turn, prev_usr_turn_meta = (
|
46 |
+
None,
|
47 |
+
None,
|
48 |
+
) # dummpy for tracing goal change at first turn
|
49 |
+
for turn_id, turn in enumerate(dial_meta["turns"]):
|
50 |
+
assert turn_id == int(turn["turn_id"])
|
51 |
+
|
52 |
+
if turn["speaker"] == "SYSTEM":
|
53 |
+
sys_turn = turn
|
54 |
+
turn_meta_list.append(None)
|
55 |
+
continue
|
56 |
+
|
57 |
+
# init turn meta
|
58 |
+
turn_meta = TurnMetaData(
|
59 |
+
prev_intent, sys_turn, turn, self.dial_id, self.unify_act
|
60 |
+
)
|
61 |
+
|
62 |
+
# get goal change label
|
63 |
+
turn_meta.get_goal_change_label(prev_usr_turn, prev_usr_turn_meta)
|
64 |
+
|
65 |
+
# update previous goal
|
66 |
+
for prev_turn_meta in reversed(turn_meta_list):
|
67 |
+
if prev_turn_meta is None:
|
68 |
+
continue
|
69 |
+
prev_turn_meta.accumulate_constraints(turn_meta) # TODO: check goal
|
70 |
+
|
71 |
+
# record task (intent) in scenario
|
72 |
+
if turn_meta.usr_intent not in scenario:
|
73 |
+
scenario.append(turn_meta.usr_intent)
|
74 |
+
|
75 |
+
turn_meta_list.append(turn_meta)
|
76 |
+
prev_intent = turn_meta.usr_intent
|
77 |
+
prev_usr_turn, prev_usr_turn_meta = turn, turn_meta
|
78 |
+
assert len(turn_meta_list) == len(dial_meta["turns"])
|
79 |
+
return turn_meta_list, scenario
|
80 |
+
|
81 |
+
def linearise_turns(self):
|
82 |
+
# linearise necessary meterials
|
83 |
+
for turn_meta in self.turn_meta_list:
|
84 |
+
if turn_meta is None:
|
85 |
+
continue
|
86 |
+
turn_meta._linearise(self.scenario, SERVICE2META)
|
87 |
+
|
88 |
+
|
89 |
+
class TurnMetaData:
|
90 |
+
def __init__(self, prev_intent, sys_turn, usr_turn, dial_id, unify_act):
|
91 |
+
self.dial_id = dial_id
|
92 |
+
self.unify_act = unify_act
|
93 |
+
self.original_act_set = get_original_act_set() # act set w/o domain information
|
94 |
+
self.sys_turn, self.usr_turn = sys_turn, usr_turn
|
95 |
+
|
96 |
+
# turn id
|
97 |
+
self.sys_turn_id, self.usr_turn_id = self._get_turn_id(sys_turn, usr_turn)
|
98 |
+
|
99 |
+
# intent
|
100 |
+
self.usr_intent = normalise_intent(self._get_intent(usr_turn, prev_intent))
|
101 |
+
if remove_book_intent:
|
102 |
+
self.usr_intent = self.usr_intent.replace("book", "find")
|
103 |
+
assert self.usr_intent in INTENTS # or self.usr_intent == "temp temp"
|
104 |
+
self.service = self.usr_intent.split()[1]
|
105 |
+
|
106 |
+
# utterances
|
107 |
+
self.utt = {}
|
108 |
+
self.utt["sys"], self.utt["usr"] = self._get_utt(sys_turn), self._get_utt(
|
109 |
+
usr_turn
|
110 |
+
)
|
111 |
+
|
112 |
+
# act
|
113 |
+
self.act2sv = {}
|
114 |
+
self.act2sv["sys"], _ = self._parse_action(self.sys_turn_id, self.sys_turn)
|
115 |
+
self.act2sv["usr"], self.usr_constraints = self._parse_action(
|
116 |
+
self.usr_turn_id, self.usr_turn
|
117 |
+
)
|
118 |
+
|
119 |
+
# task boundary
|
120 |
+
self._get_new_task_label(prev_intent)
|
121 |
+
|
122 |
+
# req_alts
|
123 |
+
self._get_req_alts_label()
|
124 |
+
|
125 |
+
def _get_turn_id(self, sys_turn, usr_turn):
|
126 |
+
usr_turn_id = int(usr_turn["turn_id"]) # 0, 2, 4 ...
|
127 |
+
sys_turn_id = int(sys_turn["turn_id"]) if sys_turn is not None else -1
|
128 |
+
assert sys_turn_id == (usr_turn_id - 1)
|
129 |
+
return sys_turn_id, usr_turn_id
|
130 |
+
|
131 |
+
def _get_utt(self, turn):
|
132 |
+
if turn is None:
|
133 |
+
return ""
|
134 |
+
return turn["utterance"]
|
135 |
+
|
136 |
+
def accumulate_constraints(self, new_turn_meta):
|
137 |
+
"""
|
138 |
+
Add slot, slot-value pairs from a given following turn
|
139 |
+
This function forms the user goal by accumulating constraints backward
|
140 |
+
"""
|
141 |
+
# only accumulate constraints with the same task/intent
|
142 |
+
if new_turn_meta.usr_intent != self.usr_intent:
|
143 |
+
return
|
144 |
+
|
145 |
+
if (
|
146 |
+
new_turn_meta.goal_change
|
147 |
+
): # if goal changes at a new turn, these constraints should not be put in previous turns
|
148 |
+
return
|
149 |
+
|
150 |
+
# only accumulate constraints without goal change
|
151 |
+
# if the value of a slot is changed (goal change) in a new turn,
|
152 |
+
# this slot-value pair is not part of initial goal and should not be added into the goal of previous turns
|
153 |
+
new_constraints = new_turn_meta.usr_constraints
|
154 |
+
self.usr_constraints["requestable"] = self.usr_constraints["requestable"].union(
|
155 |
+
new_constraints["requestable"]
|
156 |
+
)
|
157 |
+
for slot, value_list in new_constraints["informable"].items():
|
158 |
+
if slot not in self.usr_constraints["informable"]:
|
159 |
+
self.usr_constraints["informable"][slot] = value_list
|
160 |
+
|
161 |
+
def get_goal_change_label(self, prev_usr_turn, prev_turn_meta):
|
162 |
+
"""check if goal changed (value of slot changes) between two turn states"""
|
163 |
+
# first usr turn
|
164 |
+
if prev_usr_turn is None:
|
165 |
+
assert self.usr_turn_id == 0
|
166 |
+
self.goal_change = False
|
167 |
+
return
|
168 |
+
|
169 |
+
# last usr turn
|
170 |
+
if "GOODBYE" in self.act2sv["usr"] or "THANK_YOU" in self.act2sv["usr"]:
|
171 |
+
self.goal_change = False
|
172 |
+
return
|
173 |
+
|
174 |
+
assert self.usr_turn_id != 0
|
175 |
+
assert prev_usr_turn["speaker"] == "USER"
|
176 |
+
|
177 |
+
# new task
|
178 |
+
if self.usr_intent != prev_turn_meta.usr_intent:
|
179 |
+
self.goal_change = False
|
180 |
+
return
|
181 |
+
|
182 |
+
# compare two states to obtain goal change flag
|
183 |
+
curr_state, prev_state = None, None
|
184 |
+
for frame in self.usr_turn["frames"]:
|
185 |
+
if frame["service"] == self.service:
|
186 |
+
curr_state = frame["state"]["slot_values"]
|
187 |
+
|
188 |
+
for frame in prev_usr_turn["frames"]:
|
189 |
+
if frame["service"] == prev_turn_meta.service:
|
190 |
+
prev_state = frame["state"]["slot_values"]
|
191 |
+
|
192 |
+
# check if slot value has changed at current turn (new slot is not counted)
|
193 |
+
assert curr_state is not None and prev_state is not None
|
194 |
+
self.goal_change = compare_slot_values_in_state(curr_state, prev_state)
|
195 |
+
|
196 |
+
def _get_domain_from_act(self, dialogue_act):
|
197 |
+
"""
|
198 |
+
parse the raw dialouge act annotation to get domain info
|
199 |
+
number of doamin can be more than 1 for multi-domain turns
|
200 |
+
"""
|
201 |
+
domains = set()
|
202 |
+
book_flag = False
|
203 |
+
for dact, sv_pairs in dialogue_act.items():
|
204 |
+
assert "-" in dact
|
205 |
+
domain, _ = dact.split("-")
|
206 |
+
if domain not in ["Booking", "general"]:
|
207 |
+
domains.add(domain)
|
208 |
+
for slot, value in sv_pairs:
|
209 |
+
if "book" in slot: # e.g., bookday
|
210 |
+
book_flag = True
|
211 |
+
return domains, book_flag
|
212 |
+
|
213 |
+
def _get_intent(self, usr_turn, prev_intent):
|
214 |
+
intents = []
|
215 |
+
for frame in usr_turn["frames"]:
|
216 |
+
# service = frame["service"]
|
217 |
+
intent = frame["state"]["active_intent"]
|
218 |
+
if intent != "NONE":
|
219 |
+
intents.append(intent)
|
220 |
+
|
221 |
+
if len(intents) == 1:
|
222 |
+
intent = intents[0]
|
223 |
+
if intent == "find_taxi":
|
224 |
+
intent = "book_taxi"
|
225 |
+
return intent # tackle 51.5k out of 71.5k user turns
|
226 |
+
|
227 |
+
# if above fails (e.g., due to wrong label), leverage usr act to help determine main intent/service
|
228 |
+
# possible domains in da: {'Hospital', 'Taxi', 'Train', 'Police', 'Restaurant', 'Booking', 'general',
|
229 |
+
# 'Attraction', 'Hotel'}
|
230 |
+
usr_act = data_act[self.dial_id][str(self.usr_turn_id)]["dialog_act"]
|
231 |
+
domains, book_flag = self._get_domain_from_act(usr_act)
|
232 |
+
if len(domains) == 1:
|
233 |
+
domain = list(domains)[0].lower()
|
234 |
+
if book_flag and domain in ["restaurant", "hotel", "train"]:
|
235 |
+
intent = "book_{}".format(domain)
|
236 |
+
elif domain == "taxi":
|
237 |
+
intent = "book_{}".format(domain)
|
238 |
+
else:
|
239 |
+
intent = "find_{}".format(domain)
|
240 |
+
return intent # tackle 58.1k out of 71.5k user turns
|
241 |
+
|
242 |
+
if "Taxi" in domains:
|
243 |
+
return "book_taxi" # tackle 58.8k out of 71.5k user turns
|
244 |
+
|
245 |
+
if (
|
246 |
+
self.usr_turn_id == 0
|
247 |
+
): # wrong label at first turn, no previous intent to use, only 136 user turns here
|
248 |
+
utt = usr_turn["utterance"]
|
249 |
+
if (
|
250 |
+
"restaurant" in utt
|
251 |
+
or "Restaurant" in utt
|
252 |
+
or "eat" in utt
|
253 |
+
or "din" in utt
|
254 |
+
):
|
255 |
+
return "find_restaurant"
|
256 |
+
elif (
|
257 |
+
"hotel" in utt
|
258 |
+
or "room" in utt
|
259 |
+
or "house" in utt
|
260 |
+
or "stay" in utt
|
261 |
+
or "live" in utt
|
262 |
+
):
|
263 |
+
return "find_hotel"
|
264 |
+
else:
|
265 |
+
return "find_attraction" # tackle 58.9k out of 71.5k user turns
|
266 |
+
|
267 |
+
else: # not first turn, leverage sys act to help decide intent
|
268 |
+
sys_act = data_act[self.dial_id][str(self.sys_turn_id)]["dialog_act"]
|
269 |
+
sys_domains, _ = self._get_domain_from_act(sys_act)
|
270 |
+
if len(sys_domains) == 1:
|
271 |
+
domain = list(sys_domains)[0].lower()
|
272 |
+
if book_flag and domain in ["restaurant", "hotel", "train"]:
|
273 |
+
intent = "book_{}".format(domain)
|
274 |
+
elif domain == "taxi":
|
275 |
+
intent = "book_{}".format(domain)
|
276 |
+
else:
|
277 |
+
intent = "find_{}".format(domain)
|
278 |
+
return intent # tackle 67.3k out of 71.5k user turns
|
279 |
+
|
280 |
+
# two cases left enter here
|
281 |
+
# 1. turns with only general act, e.g., bye
|
282 |
+
# 2. turns have multiple intents (very few)
|
283 |
+
# both will be handled using previous intent
|
284 |
+
assert prev_intent != ""
|
285 |
+
intent = "_".join(
|
286 |
+
prev_intent.split()
|
287 |
+
) # as prev_intent has been normalised already
|
288 |
+
return intent
|
289 |
+
|
290 |
+
def _parse_action(self, turn_id, turn):
|
291 |
+
"""parse the `dialog_act` field in `dialog_acts.json`
|
292 |
+
|
293 |
+
Returns:
|
294 |
+
act2sv: act to slot value pairs, {act=sv}; sv: slot to value list, {slot=[v1, v2]}
|
295 |
+
"""
|
296 |
+
act2sv = dict()
|
297 |
+
constraints = {"informable": dict(), "requestable": set()}
|
298 |
+
if turn is None:
|
299 |
+
return None, constraints
|
300 |
+
|
301 |
+
# get da from data_act
|
302 |
+
dialogue_act = data_act[self.dial_id][str(turn_id)]["dialog_act"]
|
303 |
+
# domains = set()
|
304 |
+
for dact, svs in dialogue_act.items():
|
305 |
+
assert "-" in dact
|
306 |
+
if self.unify_act: # will use only act part without domain info
|
307 |
+
domain, act = dact.split(
|
308 |
+
"-"
|
309 |
+
) # split `domain-act`, e.g., `hotel-inform` -> hotel, inform
|
310 |
+
else: # keep original mwoz act
|
311 |
+
act = dact # use act with domain info
|
312 |
+
|
313 |
+
if self.unify_act:
|
314 |
+
# unify act: `Booking-Inform` with no args is equivalent to `OfferBook` in train domain
|
315 |
+
if dact == "Booking-Inform" and svs == [["none", "none"]]:
|
316 |
+
act = "OfferBook"
|
317 |
+
|
318 |
+
# deal with act
|
319 |
+
if self.unify_act:
|
320 |
+
assert act in self.original_act_set
|
321 |
+
if turn["speaker"] == "USER":
|
322 |
+
assert act in ["Inform", "Request", "bye", "thank", "greet"]
|
323 |
+
act = get_act_natural_language(act)
|
324 |
+
|
325 |
+
if act not in act2sv:
|
326 |
+
act2sv[act] = dict()
|
327 |
+
|
328 |
+
# iterate slot value pairs
|
329 |
+
for slot, value in svs:
|
330 |
+
slot = normalise_slot(slot)
|
331 |
+
value = normalise_value(value)
|
332 |
+
|
333 |
+
# act to slot value pairs
|
334 |
+
# NOTE: same slot might appear more than once per turn, e.g., when the system informs two hotels with
|
335 |
+
# their addresses so a value list is stored for each slot
|
336 |
+
if slot not in act2sv[act]:
|
337 |
+
act2sv[act][slot] = []
|
338 |
+
act2sv[act][slot].append(value)
|
339 |
+
|
340 |
+
# collect constraints
|
341 |
+
if act in ["REQUEST", "Request", "request"]:
|
342 |
+
constraints["requestable"].add(slot)
|
343 |
+
else:
|
344 |
+
if slot != "Empty":
|
345 |
+
if (
|
346 |
+
slot not in constraints["informable"]
|
347 |
+
): # NOTE: same reason as act, value list per slot
|
348 |
+
constraints["informable"][slot] = []
|
349 |
+
constraints["informable"][slot].append(value)
|
350 |
+
return act2sv, constraints
|
351 |
+
|
352 |
+
def _linearise(self, scenario, service2meta):
|
353 |
+
self.linear_act = {}
|
354 |
+
self.linear_act["sys"] = self._linearise_act(self.act2sv["sys"])
|
355 |
+
self.linear_act["usr"] = self._linearise_act(self.act2sv["usr"])
|
356 |
+
self.linear_goal = self._linearise_goal(
|
357 |
+
self.usr_constraints, scenario, service2meta
|
358 |
+
)
|
359 |
+
|
360 |
+
def _linearise_goal(self, constraints, scenario, service2meta):
|
361 |
+
"""
|
362 |
+
linearise goal representation which consists of several parts:
|
363 |
+
scenario, task (intent), task description, constraints with informable and requestable
|
364 |
+
e.g., <SCENARIO/> task1 task2 .. </SCENARIO>
|
365 |
+
<TASK/> current task </TASK> <DESC/> task description </DESC>
|
366 |
+
<INFORM/> <SLOT/> slot1 </SLOT> <VALUE> value1 </VALUE> .. </INFORM>
|
367 |
+
<REQUEST/> <SLOT> slot1 </SLOT> <SLOT> slot2 </SLOT> .. </REQUEST>
|
368 |
+
"""
|
369 |
+
res = ""
|
370 |
+
# scenario
|
371 |
+
assert isinstance(scenario, list) and len(scenario) > 0
|
372 |
+
scenario = " ".join(
|
373 |
+
[wrap_element("INTENT", intent) for intent in scenario]
|
374 |
+
) # treat intent as nl
|
375 |
+
scenario_wrap = wrap_element("SCENARIO", scenario)
|
376 |
+
res = add_str(res, scenario_wrap)
|
377 |
+
|
378 |
+
# task name
|
379 |
+
intent = self.usr_intent
|
380 |
+
assert intent in scenario
|
381 |
+
intent_wrap = wrap_element("TASK", intent)
|
382 |
+
res = add_str(res, intent_wrap)
|
383 |
+
|
384 |
+
# task description
|
385 |
+
description = service2meta[self.service]["intents"][intent]["description"]
|
386 |
+
description_warp = wrap_element("DESC", description)
|
387 |
+
res = add_str(res, description_warp)
|
388 |
+
|
389 |
+
# informable
|
390 |
+
informable = dict2list(
|
391 |
+
constraints["informable"]
|
392 |
+
) # sorted sv pair list [slot=value]
|
393 |
+
res = add_str(res, "<INFORM/>")
|
394 |
+
for sv_pair in informable:
|
395 |
+
slot, value = sv_pair.split("=")
|
396 |
+
if value in ["True", "False", "Empty"]:
|
397 |
+
value = conv_special_token(value, SPECIAL_TOKENS)
|
398 |
+
if slot in ["Empty"]:
|
399 |
+
slot = conv_special_token(slot, SPECIAL_TOKENS)
|
400 |
+
# slot
|
401 |
+
slot_wrap = wrap_element("SLOT", slot)
|
402 |
+
res = add_str(res, slot_wrap)
|
403 |
+
# value
|
404 |
+
value_wrap = wrap_element("VALUE", value)
|
405 |
+
res = add_str(res, value_wrap)
|
406 |
+
res = add_str(res, "</INFORM>")
|
407 |
+
|
408 |
+
# requestable
|
409 |
+
requestable = sorted(
|
410 |
+
list(constraints["requestable"])
|
411 |
+
) # sorted slot list [slot]
|
412 |
+
res = add_str(res, "<REQUEST/>")
|
413 |
+
for slot in requestable:
|
414 |
+
slot_wrap = wrap_element("SLOT", slot)
|
415 |
+
res = add_str(res, slot_wrap)
|
416 |
+
res = add_str(res, "</REQUEST>")
|
417 |
+
return res[1:] # remove first space
|
418 |
+
|
419 |
+
def _linearise_act(self, act2sv):
|
420 |
+
"""
|
421 |
+
NOTE: 1) split slot/value if "_"; 2) special tokens of acts; 3) empty slot or empty value
|
422 |
+
NOTE: filer too many values (e.g., 10 movie names) but make sure the one the user chose is present
|
423 |
+
|
424 |
+
Return: ordered (slots sorted within act, acts sorted) linearised act sequence,
|
425 |
+
e.g., <ACT/> <INFORM> </ACT> <SLOT/> area </SLOT> <VALUE/> Cambridge </VALUE> ...
|
426 |
+
e.g., <ACT/> <REQUEST> </ACT> <SLOT/> _Empty_ </SLOT> <VALUE/> _Empty_ </VALUE>
|
427 |
+
"""
|
428 |
+
res = ""
|
429 |
+
if act2sv is None:
|
430 |
+
return res
|
431 |
+
|
432 |
+
for act in sorted(act2sv.keys()): # sort act
|
433 |
+
sv = act2sv[act] # dict{slot: value_list}
|
434 |
+
act_wrap = wrap_element("ACT", act)
|
435 |
+
res = add_str(res, act_wrap)
|
436 |
+
|
437 |
+
sorted_sv = dict2list(
|
438 |
+
sv
|
439 |
+
) # sorted sv list, [s1=v1, s2=v2], note slot can repeat
|
440 |
+
for sv_pair in sorted_sv:
|
441 |
+
slot, value = sv_pair.split("=")
|
442 |
+
if value in ["True", "False", "Empty"]:
|
443 |
+
value = conv_special_token(value, SPECIAL_TOKENS)
|
444 |
+
if slot in ["Empty"]:
|
445 |
+
slot = conv_special_token(slot, SPECIAL_TOKENS)
|
446 |
+
|
447 |
+
# slot
|
448 |
+
slot_wrap = wrap_element("SLOT", slot)
|
449 |
+
res = add_str(res, slot_wrap)
|
450 |
+
|
451 |
+
# value
|
452 |
+
value_wrap = wrap_element("VALUE", value)
|
453 |
+
res = add_str(res, value_wrap)
|
454 |
+
|
455 |
+
return res[1:] # remove first space
|
456 |
+
|
457 |
+
def _get_new_task_label(self, prev_intent):
|
458 |
+
"""
|
459 |
+
get a binary label indicating if a turn starts a new task (intent) in dialogue
|
460 |
+
"""
|
461 |
+
assert prev_intent != "NONE" and self.usr_intent != "NONE"
|
462 |
+
if self.usr_intent != prev_intent:
|
463 |
+
self.start_new_task = True
|
464 |
+
else:
|
465 |
+
self.start_new_task = False
|
466 |
+
|
467 |
+
def _get_req_alts_label(self):
|
468 |
+
self.req_alts = False # no request alternative in mwoz
|
469 |
+
|
470 |
+
|
471 |
+
def collect_examples(dial_id, dial_meta, examples):
|
472 |
+
num = 0
|
473 |
+
examples[dial_id] = {}
|
474 |
+
for turn_meta in dial_meta.turn_meta_list:
|
475 |
+
if turn_meta is None: # sys turn
|
476 |
+
continue
|
477 |
+
|
478 |
+
example_id = "{}-{}".format(dial_id, num)
|
479 |
+
example = {
|
480 |
+
"utterances": turn_meta.utt,
|
481 |
+
"actions": turn_meta.linear_act,
|
482 |
+
"goal": turn_meta.linear_goal,
|
483 |
+
"service": turn_meta.service,
|
484 |
+
"intent": turn_meta.usr_intent,
|
485 |
+
"goal_change": turn_meta.goal_change,
|
486 |
+
"start_new_task": turn_meta.start_new_task,
|
487 |
+
"req_alts": turn_meta.req_alts,
|
488 |
+
}
|
489 |
+
examples[dial_id][example_id] = example
|
490 |
+
num += 1
|
491 |
+
|
492 |
+
|
493 |
+
def prepare_data_seq(unify_act, out_data_path):
|
494 |
+
for split in DATA_SPLIT:
|
495 |
+
examples = {}
|
496 |
+
for dial_num, dial_id in enumerate(tqdm(sorted(data[split].keys()))):
|
497 |
+
dial = data[split][dial_id]
|
498 |
+
dial_act = data_act[dial_id]
|
499 |
+
|
500 |
+
dial_meta = DialMetaData(dial_id, dial, dial_act, unify_act)
|
501 |
+
collect_examples(dial_id, dial_meta, examples)
|
502 |
+
|
503 |
+
with open("{}/{}.json".format(out_data_path, split), "w") as f:
|
504 |
+
json.dump(examples, f, sort_keys=True, indent=4)
|
505 |
+
print("Done process {} {} dialogues".format(split, len(examples)))
|
506 |
+
|
507 |
+
|
508 |
+
if __name__ == "__main__":
|
509 |
+
if len(sys.argv) == 1:
|
510 |
+
print("Wrong argument!")
|
511 |
+
print("usage: python utils/preprocess_multiwoz.py multiwoz2.2-data-path")
|
512 |
+
sys.exit(1)
|
513 |
+
|
514 |
+
# Set data path
|
515 |
+
data_path = sys.argv[1]
|
516 |
+
out_data_path = "./data/preprocessed/user_model"
|
517 |
+
os.makedirs(out_data_path, exist_ok=True)
|
518 |
+
|
519 |
+
# Control flags
|
520 |
+
unify_act = True
|
521 |
+
remove_book_intent = True
|
522 |
+
|
523 |
+
# Load data and material as global var
|
524 |
+
SERVICE2META, INTENTS, SLOTS = load_schema(os.path.join(data_path, "schema.json"))
|
525 |
+
SPECIAL_TOKENS = get_special_tokens()
|
526 |
+
data, data_act = collect_data(data_path, remove_dial_switch=False)
|
527 |
+
|
528 |
+
prepare_data_seq(unify_act, out_data_path)
|
scripts/user_model_code/preprocess_sgd.py
ADDED
@@ -0,0 +1,431 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
import sys
|
4 |
+
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
from crazyneuraluser.user_model_code.analysis_sgd import DATA_SPLIT, collect_data
|
8 |
+
from crazyneuraluser.user_model_code.utils_sgd import (
|
9 |
+
add_str,
|
10 |
+
compare_slot_values_in_state,
|
11 |
+
dict2list,
|
12 |
+
get_special_tokens,
|
13 |
+
get_turn_intent,
|
14 |
+
load_schema,
|
15 |
+
split_intent,
|
16 |
+
wrap_element,
|
17 |
+
)
|
18 |
+
|
19 |
+
"""pre-processing script for SGD
|
20 |
+
|
21 |
+
The annotations for a turn are grouped into frames, where each frame corresponds to a single service
|
22 |
+
The values of "slot_values" in user "state" is a list, where spoken variations are considered, e.g., tomorrow, 8/2
|
23 |
+
"""
|
24 |
+
|
25 |
+
|
26 |
+
class DialMetaData:
|
27 |
+
def __init__(self, dial_id, dial):
|
28 |
+
self.dial_id = dial_id
|
29 |
+
self.turn_meta_list, self.scenario = self.parse(dial) # None for system turn
|
30 |
+
self.linearise_turns()
|
31 |
+
|
32 |
+
def parse(self, dial):
|
33 |
+
turn_meta_list = []
|
34 |
+
scenario = []
|
35 |
+
sys_turn = None # dummy sys turn for first usr turn
|
36 |
+
prev_intent = ""
|
37 |
+
prev_usr_turn, prev_usr_turn_meta = (
|
38 |
+
None,
|
39 |
+
None,
|
40 |
+
) # dummpy for tracing goal change at first turn
|
41 |
+
for turn_id, turn in enumerate(dial["turns"]):
|
42 |
+
if turn["speaker"] == "SYSTEM":
|
43 |
+
sys_turn = turn
|
44 |
+
turn_meta_list.append(None)
|
45 |
+
continue
|
46 |
+
|
47 |
+
# init turn meta
|
48 |
+
turn_meta = TurnMetaData(prev_intent, sys_turn, turn, self.dial_id)
|
49 |
+
|
50 |
+
# get goal change label
|
51 |
+
turn_meta.get_goal_change_label(prev_usr_turn, prev_usr_turn_meta)
|
52 |
+
|
53 |
+
# update previous goal
|
54 |
+
for prev_turn_meta in reversed(turn_meta_list):
|
55 |
+
if prev_turn_meta is None:
|
56 |
+
continue
|
57 |
+
prev_turn_meta.accumulate_constraints(turn_meta)
|
58 |
+
|
59 |
+
# record task (intent) in scenario
|
60 |
+
prev_intent = turn_meta.usr_intent
|
61 |
+
if turn_meta.usr_intent not in scenario:
|
62 |
+
scenario.append(turn_meta.usr_intent)
|
63 |
+
|
64 |
+
turn_meta_list.append(turn_meta)
|
65 |
+
prev_usr_turn, prev_usr_turn_meta = turn, turn_meta
|
66 |
+
|
67 |
+
assert len(turn_meta_list) == len(dial["turns"])
|
68 |
+
return turn_meta_list, scenario
|
69 |
+
|
70 |
+
def linearise_turns(self):
|
71 |
+
# linearise necessary meterials
|
72 |
+
for turn_meta in self.turn_meta_list:
|
73 |
+
if turn_meta is None:
|
74 |
+
continue
|
75 |
+
turn_meta._linearise(self.scenario)
|
76 |
+
|
77 |
+
|
78 |
+
class TurnMetaData:
|
79 |
+
def __init__(self, prev_intent, sys_turn, usr_turn, dial_id):
|
80 |
+
self.dial_id = dial_id
|
81 |
+
self.sys_turn, self.usr_turn = sys_turn, usr_turn
|
82 |
+
self.empty_token = "_Empty_"
|
83 |
+
assert self.empty_token in SPECIAL_TOKENS["additional_special_tokens"]
|
84 |
+
|
85 |
+
# intent
|
86 |
+
self.usr_intent, self.service = self._get_intent(usr_turn, prev_intent)
|
87 |
+
|
88 |
+
# utterances
|
89 |
+
self.utt = {}
|
90 |
+
self.utt["sys"], self.utt["usr"] = self._get_utt(sys_turn), self._get_utt(
|
91 |
+
usr_turn
|
92 |
+
)
|
93 |
+
|
94 |
+
# action
|
95 |
+
self.act2sv = {}
|
96 |
+
self.act2sv["sys"], _ = self._parse_action(sys_turn)
|
97 |
+
self.act2sv["usr"], self.usr_constraints = self._parse_action(usr_turn)
|
98 |
+
|
99 |
+
# task boundary
|
100 |
+
self._get_new_task_label(prev_intent)
|
101 |
+
|
102 |
+
# req_alts
|
103 |
+
self._get_req_alts_label(self.act2sv["usr"])
|
104 |
+
|
105 |
+
def _get_intent(self, turn, prev_intent):
|
106 |
+
"""manually set the `NONE` intent to the intent of previous turn"""
|
107 |
+
active_intent, service = get_turn_intent(
|
108 |
+
turn
|
109 |
+
) # intent annotation (migt be `NONE`)
|
110 |
+
if active_intent == "NONE":
|
111 |
+
active_intent = prev_intent
|
112 |
+
return active_intent, service
|
113 |
+
|
114 |
+
def _get_utt(self, turn):
|
115 |
+
if turn is None:
|
116 |
+
return ""
|
117 |
+
return turn["utterance"]
|
118 |
+
|
119 |
+
def _parse_action(self, turn):
|
120 |
+
"""
|
121 |
+
parse action annotation to collect turn level information
|
122 |
+
1) act to slot-value pairs, dict{act: {slot: value}}
|
123 |
+
2) turn level constraints, dict{'informable': dict{slot: value}, 'requestable': set(slot)}
|
124 |
+
"""
|
125 |
+
# get mapping from act to slot-value pairs
|
126 |
+
act2sv = {}
|
127 |
+
info_req = {"informable": dict(), "requestable": set()} # constraints
|
128 |
+
|
129 |
+
if turn is None:
|
130 |
+
return None, info_req
|
131 |
+
|
132 |
+
for frame in turn["frames"]:
|
133 |
+
for action in frame["actions"]:
|
134 |
+
act, slot, values = action["act"], action["slot"], action["values"]
|
135 |
+
|
136 |
+
# deal with empty slot or value
|
137 |
+
if turn["speaker"] == "USER":
|
138 |
+
assert len(values) in [0, 1]
|
139 |
+
if slot == "":
|
140 |
+
slot = self.empty_token
|
141 |
+
value = values[0] if len(values) > 0 else self.empty_token
|
142 |
+
|
143 |
+
# act to slot-value pairs
|
144 |
+
if act not in act2sv:
|
145 |
+
act2sv[act] = {}
|
146 |
+
assert slot not in act2sv[act]
|
147 |
+
act2sv[act][slot] = value
|
148 |
+
|
149 |
+
# collect constraints
|
150 |
+
if slot in [
|
151 |
+
"",
|
152 |
+
self.empty_token,
|
153 |
+
]: # only act but no constraints, e.g., AFFIRM, NEGATE
|
154 |
+
continue
|
155 |
+
|
156 |
+
# turn level informalable and requestable info
|
157 |
+
if act == "REQUEST":
|
158 |
+
assert slot != ""
|
159 |
+
info_req["requestable"].add(slot)
|
160 |
+
else:
|
161 |
+
if turn["speaker"] == "USER":
|
162 |
+
assert act in [
|
163 |
+
"INFORM_INTENT",
|
164 |
+
"INFORM",
|
165 |
+
"SELECT",
|
166 |
+
] # not apply to system side
|
167 |
+
if (
|
168 |
+
act != "SELECT"
|
169 |
+
): # result offered by system is part of initial user goal
|
170 |
+
assert slot not in info_req["informable"]
|
171 |
+
info_req["informable"][slot] = value
|
172 |
+
return act2sv, info_req
|
173 |
+
|
174 |
+
def accumulate_constraints(self, new_turn_meta):
|
175 |
+
"""
|
176 |
+
Add slot, slot-value pairs from a given following turn
|
177 |
+
This function is used to form user goal by accumulating constraints backward
|
178 |
+
"""
|
179 |
+
# only accumulate constraints with the same task/intent
|
180 |
+
if new_turn_meta.usr_intent != self.usr_intent:
|
181 |
+
return
|
182 |
+
|
183 |
+
if (
|
184 |
+
new_turn_meta.goal_change
|
185 |
+
): # if goal changes at a new turn, these constraints should not be put in previous turns
|
186 |
+
return
|
187 |
+
|
188 |
+
# only accumulate constraints without goal change
|
189 |
+
# if the value of a slot is changed (goal change) in a new turn,
|
190 |
+
# this slot-value pair is not part of initial goal and should not be added into the goal of previous turns
|
191 |
+
new_constraints = new_turn_meta.usr_constraints
|
192 |
+
self.usr_constraints["requestable"] = self.usr_constraints["requestable"].union(
|
193 |
+
new_constraints["requestable"]
|
194 |
+
)
|
195 |
+
for slot, value in new_constraints["informable"].items():
|
196 |
+
if slot not in self.usr_constraints["informable"]:
|
197 |
+
self.usr_constraints["informable"][slot] = value
|
198 |
+
|
199 |
+
def _get_new_task_label(self, prev_intent):
|
200 |
+
"""get a binary label indicating if a turn starts a new task (intent) in dialogue"""
|
201 |
+
assert prev_intent != "NONE" and self.usr_intent != "NONE"
|
202 |
+
if self.usr_intent != prev_intent:
|
203 |
+
self.start_new_task = True
|
204 |
+
else:
|
205 |
+
self.start_new_task = False
|
206 |
+
|
207 |
+
def _get_req_alts_label(self, act2sv):
|
208 |
+
"""get a binary label indicating if usr requests alternatives"""
|
209 |
+
if "REQUEST_ALTS" in act2sv:
|
210 |
+
self.req_alts = True
|
211 |
+
else:
|
212 |
+
self.req_alts = False
|
213 |
+
|
214 |
+
def get_goal_change_label(self, prev_usr_turn, prev_turn_meta):
|
215 |
+
"""check if goal changed (value of slot changes) between two turn states"""
|
216 |
+
if prev_usr_turn is None: # first usr turn
|
217 |
+
self.goal_change = False
|
218 |
+
return
|
219 |
+
|
220 |
+
if (
|
221 |
+
len(self.usr_turn["frames"]) == 1
|
222 |
+
and self.usr_turn["frames"][0]["state"]["active_intent"] == "NONE"
|
223 |
+
): # `NONE` intent
|
224 |
+
self.goal_change = False
|
225 |
+
return
|
226 |
+
|
227 |
+
if self.usr_intent != prev_turn_meta.usr_intent: # new task
|
228 |
+
self.goal_change = False
|
229 |
+
return
|
230 |
+
|
231 |
+
assert prev_usr_turn["speaker"] == "USER"
|
232 |
+
prev_state_sv, curr_state_sv = None, None
|
233 |
+
for frame in prev_usr_turn["frames"]:
|
234 |
+
if frame["state"]["active_intent"] == self.usr_intent:
|
235 |
+
prev_state_sv = frame["state"]["slot_values"]
|
236 |
+
|
237 |
+
# fix some weird cases (count very few, around 30 turns)
|
238 |
+
if prev_state_sv is None:
|
239 |
+
assert (
|
240 |
+
len(prev_usr_turn["frames"]) == 1
|
241 |
+
and prev_usr_turn["frames"][0]["state"]["active_intent"] == "NONE"
|
242 |
+
)
|
243 |
+
prev_state_sv = prev_usr_turn["frames"][0]["state"]["slot_values"]
|
244 |
+
|
245 |
+
for frame in self.usr_turn["frames"]:
|
246 |
+
if frame["state"]["active_intent"] == self.usr_intent:
|
247 |
+
curr_state_sv = frame["state"]["slot_values"]
|
248 |
+
|
249 |
+
assert prev_state_sv is not None and curr_state_sv is not None
|
250 |
+
self.goal_change = compare_slot_values_in_state(
|
251 |
+
prev_state_sv, curr_state_sv
|
252 |
+
) # True if goal changes
|
253 |
+
|
254 |
+
def _linearise(self, scenario):
|
255 |
+
self.linear_act = {}
|
256 |
+
self.linear_act["sys"] = self._linearise_act(self.act2sv["sys"])
|
257 |
+
self.linear_act["usr"] = self._linearise_act(self.act2sv["usr"])
|
258 |
+
self.linear_goal = self._linearise_goal(self.usr_constraints, scenario)
|
259 |
+
|
260 |
+
def _linearise_act(self, act2sv):
|
261 |
+
"""
|
262 |
+
NOTE: 1) split slot/value if "_"; 2) special tokens of acts; 3) empty slot or empty value
|
263 |
+
NOTE: filer too many values (e.g., 10 movie names) but make sure the one the user chose is present
|
264 |
+
|
265 |
+
Return: ordered (slots sorted within act, acts sorted) linearised act sequence,
|
266 |
+
e.g., <ACT/> <INFORM> </ACT> <SLOT/> area </SLOT> <VALUE/> Cambridge </VALUE> ...
|
267 |
+
e.g., <ACT/> <REQUEST> </ACT> <SLOT/> _Empty_ </SLOT> <VALUE/> _Empty_ </VALUE>
|
268 |
+
"""
|
269 |
+
res = ""
|
270 |
+
if act2sv is None:
|
271 |
+
return res
|
272 |
+
|
273 |
+
for act in sorted(act2sv.keys()): # sort act
|
274 |
+
sv = act2sv[act] # dict{slot: value}
|
275 |
+
|
276 |
+
act = "_{}_".format(act) # act is special token
|
277 |
+
assert act in SPECIAL_TOKENS["additional_special_tokens"]
|
278 |
+
act_wrap = wrap_element("ACT", act)
|
279 |
+
res = add_str(res, act_wrap)
|
280 |
+
|
281 |
+
sorted_sv = dict2list(sv) # sorted sv list, [slot=value]
|
282 |
+
for sv_pair in sorted_sv:
|
283 |
+
slot, value = sv_pair.split("=")
|
284 |
+
slot, value = self._basic_normalise_slot(
|
285 |
+
slot
|
286 |
+
), self._basic_normalise_value(value, slot)
|
287 |
+
|
288 |
+
# slot
|
289 |
+
slot_wrap = wrap_element("SLOT", slot)
|
290 |
+
res = add_str(res, slot_wrap)
|
291 |
+
|
292 |
+
# value
|
293 |
+
value_wrap = wrap_element("VALUE", value)
|
294 |
+
res = add_str(res, value_wrap)
|
295 |
+
return res[1:] # remove first space
|
296 |
+
|
297 |
+
def _basic_normalise_value(self, value, slot):
|
298 |
+
# intent value
|
299 |
+
if slot == "intent":
|
300 |
+
value = split_intent(value)
|
301 |
+
return value
|
302 |
+
|
303 |
+
# special token value
|
304 |
+
if value in ["True", "False"]: # Empty is already in the form of "_Empty_"
|
305 |
+
value = "_{}_".format(value)
|
306 |
+
assert value in SPECIAL_TOKENS["additional_special_tokens"]
|
307 |
+
return value
|
308 |
+
return value
|
309 |
+
|
310 |
+
def _basic_normalise_slot(self, slot):
|
311 |
+
if slot not in SPECIAL_TOKENS["additional_special_tokens"]:
|
312 |
+
slot = slot.replace(
|
313 |
+
"_", " "
|
314 |
+
) # e.g., `date_of_journey` -> `date of journey`
|
315 |
+
return slot
|
316 |
+
|
317 |
+
def _linearise_goal(self, constraints, scenario):
|
318 |
+
"""
|
319 |
+
linearise goal representation which consists of several parts:
|
320 |
+
scenario, task (intent), task description, constraints with informable and requestable
|
321 |
+
e.g., <SCENARIO/> task1 task2 .. </SCENARIO>
|
322 |
+
<TASK/> current task </TASK> <DESC/> task description </DESC>
|
323 |
+
<INFORM/> <SLOT/> slot1 </SLOT> <VALUE> value1 </VALUE> .. </INFORM>
|
324 |
+
<REQUEST/> <SLOT> slot1 </SLOT> <SLOT> slot2 </SLOT> .. </REQUEST>
|
325 |
+
"""
|
326 |
+
res = ""
|
327 |
+
# scenario
|
328 |
+
assert isinstance(scenario, list) and len(scenario) > 0
|
329 |
+
scenario = " ".join(
|
330 |
+
[wrap_element("INTENT", split_intent(intent)) for intent in scenario]
|
331 |
+
)
|
332 |
+
scenario_wrap = wrap_element("SCENARIO", scenario)
|
333 |
+
res = add_str(res, scenario_wrap)
|
334 |
+
|
335 |
+
# task name
|
336 |
+
intent = split_intent(self.usr_intent)
|
337 |
+
assert intent in scenario
|
338 |
+
intent_wrap = wrap_element("TASK", intent)
|
339 |
+
res = add_str(res, intent_wrap)
|
340 |
+
|
341 |
+
# task description
|
342 |
+
description = SERVICE2META[self.service]["intents"][self.usr_intent][
|
343 |
+
"description"
|
344 |
+
]
|
345 |
+
description_warp = wrap_element("DESC", description)
|
346 |
+
res = add_str(res, description_warp)
|
347 |
+
|
348 |
+
# informable
|
349 |
+
informable = dict2list(
|
350 |
+
constraints["informable"]
|
351 |
+
) # sorted sv pair list [slot=value]
|
352 |
+
res = add_str(res, "<INFORM/>")
|
353 |
+
for sv_pair in informable:
|
354 |
+
slot, value = sv_pair.split("=")
|
355 |
+
slot, value = self._basic_normalise_slot(slot), self._basic_normalise_value(
|
356 |
+
value, slot
|
357 |
+
)
|
358 |
+
# slot
|
359 |
+
slot_wrap = wrap_element("SLOT", slot)
|
360 |
+
res = add_str(res, slot_wrap)
|
361 |
+
# value
|
362 |
+
value_wrap = wrap_element("VALUE", value)
|
363 |
+
res = add_str(res, value_wrap)
|
364 |
+
res = add_str(res, "</INFORM>")
|
365 |
+
|
366 |
+
# requestable
|
367 |
+
requestable = sorted(
|
368 |
+
list(constraints["requestable"])
|
369 |
+
) # sorted slot list [slot]
|
370 |
+
res = add_str(res, "<REQUEST/>")
|
371 |
+
for slot in requestable:
|
372 |
+
slot = self._basic_normalise_slot(slot)
|
373 |
+
slot_wrap = wrap_element("SLOT", slot)
|
374 |
+
res = add_str(res, slot_wrap)
|
375 |
+
res = add_str(res, "</REQUEST>")
|
376 |
+
return res[1:] # remove first space
|
377 |
+
|
378 |
+
|
379 |
+
def collect_examples(dial_id, dial_meta, examples):
|
380 |
+
num = 0
|
381 |
+
examples[dial_id] = {}
|
382 |
+
for turn_meta in dial_meta.turn_meta_list:
|
383 |
+
if turn_meta is None: # sys turn
|
384 |
+
continue
|
385 |
+
|
386 |
+
example_id = "{}-{}".format(dial_id, num)
|
387 |
+
example = {
|
388 |
+
"utterances": turn_meta.utt,
|
389 |
+
"actions": turn_meta.linear_act,
|
390 |
+
"goal": turn_meta.linear_goal,
|
391 |
+
"service": turn_meta.service,
|
392 |
+
"intent": turn_meta.usr_intent,
|
393 |
+
"goal_change": turn_meta.goal_change,
|
394 |
+
"start_new_task": turn_meta.start_new_task,
|
395 |
+
"req_alts": turn_meta.req_alts,
|
396 |
+
}
|
397 |
+
examples[dial_id][example_id] = example
|
398 |
+
num += 1
|
399 |
+
|
400 |
+
|
401 |
+
def prepare_data_seq(data, out_data_path):
|
402 |
+
for split in DATA_SPLIT:
|
403 |
+
examples = {}
|
404 |
+
for dial_num, dial_id in enumerate(tqdm(sorted(data[split].keys()))):
|
405 |
+
dial = data[split][dial_id]
|
406 |
+
dial_meta = DialMetaData(dial_id, dial)
|
407 |
+
collect_examples(dial_id, dial_meta, examples)
|
408 |
+
|
409 |
+
with open("{}/{}.json".format(out_data_path, split), "w") as f:
|
410 |
+
json.dump(examples, f, sort_keys=True, indent=4)
|
411 |
+
print("Done process {} {} dialogues".format(split, len(examples)))
|
412 |
+
|
413 |
+
|
414 |
+
if __name__ == "__main__":
|
415 |
+
if len(sys.argv) == 1:
|
416 |
+
print("wrong arguments!")
|
417 |
+
print("usage: python utils/preprocess_sgd.py sgd-data-path")
|
418 |
+
sys.exit(1)
|
419 |
+
|
420 |
+
# Set data path
|
421 |
+
data_path = sys.argv[1]
|
422 |
+
out_data_path = "./processed_data/sgd/"
|
423 |
+
os.makedirs(out_data_path, exist_ok=True)
|
424 |
+
|
425 |
+
# Load data and material as global var
|
426 |
+
SERVICE2META, INTENTS, SLOTS = load_schema(data_path)
|
427 |
+
SPECIAL_TOKENS = get_special_tokens()
|
428 |
+
data = collect_data(data_path, remove_dial_switch=True)
|
429 |
+
|
430 |
+
# Process data
|
431 |
+
prepare_data_seq(data, out_data_path)
|
scripts/user_model_code/train.sh
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
experiment=$1
|
2 |
+
|
3 |
+
# common setup
|
4 |
+
wandb_train_run_name="Full-user-model-training"
|
5 |
+
bs=16 # batch size for training
|
6 |
+
grad_step=2 # accumulated gradient steps
|
7 |
+
max_epoch=8 # max epoch for training
|
8 |
+
data_dir="./data/preprocessed/user_model"
|
9 |
+
train_size=-1 # number of examples used for training, -1 means all
|
10 |
+
eval_size=-1 # number of examples ued for evaluation, -1 means all
|
11 |
+
|
12 |
+
|
13 |
+
|
14 |
+
if [[ "$experiment" == "SGD" ]]; then
|
15 |
+
echo "Conduct experiment with SGD dataset"
|
16 |
+
job_name='SGD-full'
|
17 |
+
data_list="sgd" # 165k training examples
|
18 |
+
eval_interval=50000 # evaluation interval
|
19 |
+
|
20 |
+
elif [[ "$experiment" == "MultiWOZ" ]]; then
|
21 |
+
echo "Conduct experiment with MulwiWOZ dataset"
|
22 |
+
job_name='MultiWOZ-full'
|
23 |
+
data_list="multiwoz" # 56k training examples
|
24 |
+
eval_interval=20000
|
25 |
+
|
26 |
+
elif [[ "$experiment" == "Joint" ]]; then
|
27 |
+
echo "Conduct experiment with SGD + MulwiWOZ dataset"
|
28 |
+
job_name='Joint-full'
|
29 |
+
data_list="sgd multiwoz" # 221k training examples
|
30 |
+
eval_interval=70000
|
31 |
+
|
32 |
+
else
|
33 |
+
echo "Unrecognised argument"
|
34 |
+
exit
|
35 |
+
fi
|
36 |
+
|
37 |
+
mkdir -p checkpoint log
|
38 |
+
checkpoint='checkpoint/'$job_name
|
39 |
+
log='log/'$job_name'.log'
|
40 |
+
python ./scripts/user_model_code/main_user_model.py --mode='training' \
|
41 |
+
--wandb_train_run_name=$wandb_train_run_name \
|
42 |
+
--model_name=$job_name \
|
43 |
+
--checkpoint=$checkpoint \
|
44 |
+
--data_dir=$data_dir \
|
45 |
+
--data_list $data_list \
|
46 |
+
--train_size=$train_size \
|
47 |
+
--eval_size=$eval_size \
|
48 |
+
--eval_interval=$eval_interval \
|
49 |
+
--gradient_accumulation_steps=$grad_step \
|
50 |
+
--train_batch_size=$bs \
|
51 |
+
--max_epoch=$max_epoch
|
src/crazyneuraluser.egg-info/PKG-INFO
ADDED
@@ -0,0 +1,173 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Metadata-Version: 2.1
|
2 |
+
Name: crazyneuraluser
|
3 |
+
Version: 0.0.post1.dev47+g049b138.d20220509
|
4 |
+
Summary: Add a short description here!
|
5 |
+
Home-page: https://github.com/pyscaffold/pyscaffold/
|
6 |
+
Author: Extended by Alistair McLeay, original code by Alexandru Coca
|
7 |
+
Author-email: am@alistairmcleay.com and alexcoca23@yahoo.co.uk
|
8 |
+
License: MIT
|
9 |
+
Project-URL: Documentation, https://pyscaffold.org/
|
10 |
+
Platform: any
|
11 |
+
Classifier: Development Status :: 4 - Beta
|
12 |
+
Classifier: Programming Language :: Python
|
13 |
+
Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
|
14 |
+
Provides-Extra: testing
|
15 |
+
License-File: LICENSE.txt
|
16 |
+
License-File: AUTHORS.md
|
17 |
+
|
18 |
+
# Cambridge Masters Project
|
19 |
+
Joint Learning of Practical Dialogue Systems and User Simulators
|
20 |
+
|
21 |
+
## Environment setup
|
22 |
+
|
23 |
+
1. Create an environment `crazyneuraluser` with the help of [conda]
|
24 |
+
```
|
25 |
+
conda env create -f environment.yml
|
26 |
+
```
|
27 |
+
2. Activate the new environment with:
|
28 |
+
```
|
29 |
+
conda activate crazyneuraluser
|
30 |
+
```
|
31 |
+
3. Install a version of `pytorch` compatible with your hardware (see the [pytorch website](https://pytorch.org/get-started/previous-versions/)). E.g.:
|
32 |
+
```
|
33 |
+
pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
|
34 |
+
```
|
35 |
+
|
36 |
+
4. Install `spacy` and download the tokenization tool in spacy:
|
37 |
+
```
|
38 |
+
pip install spacy'
|
39 |
+
python -m spacy download en_core_web_sm
|
40 |
+
```
|
41 |
+
|
42 |
+
### Generating dialogues through agent-agent interaction
|
43 |
+
|
44 |
+
To generate dialogues, first change working directory to the `baselines` directory. Run the command
|
45 |
+
```
|
46 |
+
python baselines_setup.py
|
47 |
+
```
|
48 |
+
to prepare `convlab2` for running the baselines.
|
49 |
+
|
50 |
+
#### Generating dialogues conditioned on randomly sampled goals
|
51 |
+
|
52 |
+
Select one of the available configurations in the `configs` directory and run the command
|
53 |
+
```
|
54 |
+
python simulate_agent_interaction.py --config /rel/path/to/chosen/config
|
55 |
+
```
|
56 |
+
to generate dialogues conditioned on randomly sampled goals according to the `convlab2` goal model. The dialogues will be be saved automatically in the `models` directory, under a directory whose name depends on the configuration run. The `models` directory is located in the parent directory of the `baselines` directory. The `metadata.json` file saved with the dialogues contains information about the data generation process.
|
57 |
+
|
58 |
+
#### Generating dialogues conditioned on `MultiWOZ2.1` goals
|
59 |
+
|
60 |
+
To generate the entire corpus, simply pass the `--goals-path /path/to/multiwoz2.1/data.json/file` flag to `simulate_agent_interaction.py`. To generate the `test/val` split additionally pass the `--filter-path /path/to/multiwoz2.1/test-or-valListFile` argument to `simulate_agent_interaction.py`. You can use the `generate_multiwoz21_train_id_file` function in `baselines/utils.py` to generate `trainListFile` which can then be passed via the `--filter-path` argument to the dialogue generation script in order to generate dialogues conditioned on the `MultiWOZ2.1` training goals.
|
61 |
+
|
62 |
+
### Converting the generated dialogues to SGD-like format
|
63 |
+
|
64 |
+
The `create_data_from_multiwoz.py` script can be used to convert the generated dialogues to SGD format, necessary for evaluation. It is based on the script provided by Google for DSTC8, but with additional functionality such as:
|
65 |
+
|
66 |
+
- conversion of slot names as annotated in the MultiWOZ 2.1 dialogue acts to different slot names, specified through the `--slots_convention` argument. Options are `multiwoz22` to convert the slots to the same slots as defined in the MultiWOZ 2.2 dataset whreas the `multiwoz_goals` converts the slot names to the names used in the dialogue goal and state tracking annotations.
|
67 |
+
|
68 |
+
- addition of system and user `nlu` fields for every turn
|
69 |
+
|
70 |
+
- option to perform cleaning operations on the goals to ensure a standard format is received by the evaluator.
|
71 |
+
|
72 |
+
The conversion is done according to the `schema.json` file in the `baselines` directory, which is the same as used by `DSTC8` conversion except for the addition of the `police` domain. Type ``python create_data_from_multiwoz.py --helpfull`` to see a full list of flags and usage.
|
73 |
+
|
74 |
+
## Installation
|
75 |
+
|
76 |
+
The recommended way to use this repository is to develop the core code under `src/crazyneuraluser`. The experiments/exporatory analysis making use of the core package code should be placed outside the library and imported. See more guidance under the [Project Organisation](#project-organization) section below.
|
77 |
+
|
78 |
+
To create an environment for the package, make sure you have deactivated all `conda` environments. Then:
|
79 |
+
|
80 |
+
1. Create an environment `crazyneuraluser` with the help of [conda]:
|
81 |
+
```
|
82 |
+
conda env create -f environment.yml
|
83 |
+
```
|
84 |
+
2. Add the developer dependencies to this environment with the help of [conda]:
|
85 |
+
```
|
86 |
+
conda env update -f dev_environment.yml
|
87 |
+
```
|
88 |
+
|
89 |
+
Optional and needed only once after `git clone`:
|
90 |
+
|
91 |
+
3. install several [pre-commit] git hooks with:
|
92 |
+
```bash
|
93 |
+
pre-commit install
|
94 |
+
# You _are encouraged_ to run `pre-commit autoupdate`
|
95 |
+
```
|
96 |
+
and checkout the configuration under `.pre-commit-config.yaml`.
|
97 |
+
The `-n, --no-verify` flag of `git commit` can be used to deactivate pre-commit hooks temporarily.
|
98 |
+
|
99 |
+
4. install [nbstripout] git hooks to remove the output cells of committed notebooks with:
|
100 |
+
```bash
|
101 |
+
nbstripout --install --attributes notebooks/.gitattributes
|
102 |
+
```
|
103 |
+
This is useful to avoid large diffs due to plots in your notebooks.
|
104 |
+
A simple `nbstripout --uninstall` will revert these changes.
|
105 |
+
|
106 |
+
Then take a look into the `scripts` and `notebooks` folders.
|
107 |
+
|
108 |
+
## Dependency Management & Reproducibility
|
109 |
+
|
110 |
+
1. Always keep your abstract (unpinned) dependencies updated in `environment.yml` and eventually
|
111 |
+
in `setup.cfg` if you want to ship and install your package via `pip` later on.
|
112 |
+
2. Create concrete dependencies as `environment.lock.yml` for the exact reproduction of your
|
113 |
+
environment with:
|
114 |
+
```bash
|
115 |
+
conda env export -n crazyneuraluser -f environment.lock.yml
|
116 |
+
```
|
117 |
+
For multi-OS development, consider using `--no-builds` during the export.
|
118 |
+
3. Update your current environment with respect to a new `environment.lock.yml` using:
|
119 |
+
```bash
|
120 |
+
conda env update -f environment.lock.yml --prune
|
121 |
+
```
|
122 |
+
## Project Organization
|
123 |
+
|
124 |
+
```
|
125 |
+
├── AUTHORS.md <- List of developers and maintainers.
|
126 |
+
├── CHANGELOG.md <- Changelog to keep track of new features and fixes.
|
127 |
+
├── LICENSE.txt <- License as chosen on the command-line.
|
128 |
+
├── README.md <- The top-level README for developers.
|
129 |
+
├── configs <- Directory for configurations of model & application.
|
130 |
+
├── data
|
131 |
+
│ ├── external <- Data from third party sources.
|
132 |
+
│ ├── interim <- Intermediate data that has been transformed.
|
133 |
+
│ ├── processed <- The final, canonical data sets for modeling.
|
134 |
+
│ └── raw <- The original, immutable data dump.
|
135 |
+
├── docs <- Directory for Sphinx documentation in rst or md.
|
136 |
+
├── environment.yml <- The conda environment file for reproducibility.
|
137 |
+
├── models <- Trained and serialized models, model predictions,
|
138 |
+
│ or model summaries.
|
139 |
+
├── notebooks <- Jupyter notebooks. Naming convention is a number (for
|
140 |
+
│ ordering), the creator's initials and a description,
|
141 |
+
│ e.g. `1.0-fw-initial-data-exploration`.
|
142 |
+
├── pyproject.toml <- Build system configuration. Do not change!
|
143 |
+
├── references <- Data dictionaries, manuals, and all other materials.
|
144 |
+
├── reports <- Generated analysis as HTML, PDF, LaTeX, etc.
|
145 |
+
│ └── figures <- Generated plots and figures for reports.
|
146 |
+
├── scripts <- Analysis and production scripts which import the
|
147 |
+
│ actual Python package, e.g. train_model.py.
|
148 |
+
├── setup.cfg <- Declarative configuration of your project.
|
149 |
+
├── setup.py <- Use `pip install -e .` to install for development or
|
150 |
+
| or create a distribution with `tox -e build`.
|
151 |
+
├── src
|
152 |
+
│ └── crazyneuraluser <- Actual Python package where the main functionality goes.
|
153 |
+
├── tests <- Unit tests which can be run with `py.test`.
|
154 |
+
├── .coveragerc <- Configuration for coverage reports of unit tests.
|
155 |
+
├── .isort.cfg <- Configuration for git hook that sorts imports.
|
156 |
+
└── .pre-commit-config.yaml <- Configuration of pre-commit git hooks.
|
157 |
+
```
|
158 |
+
|
159 |
+
<!-- pyscaffold-notes -->
|
160 |
+
|
161 |
+
## Note
|
162 |
+
|
163 |
+
This project has been set up using [PyScaffold] 4.0.1 and the [dsproject extension] 0.6.1.
|
164 |
+
|
165 |
+
[conda]: https://docs.conda.io/
|
166 |
+
[pre-commit]: https://pre-commit.com/
|
167 |
+
[Jupyter]: https://jupyter.org/
|
168 |
+
[nbstripout]: https://github.com/kynan/nbstripout
|
169 |
+
[Google style]: http://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
|
170 |
+
[PyScaffold]: https://pyscaffold.org/
|
171 |
+
[dsproject extension]: https://github.com/pyscaffold/pyscaffoldext-dsproject
|
172 |
+
|
173 |
+
|
src/crazyneuraluser.egg-info/SOURCES.txt
ADDED
@@ -0,0 +1,76 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.coveragerc
|
2 |
+
.gitignore
|
3 |
+
.isort.cfg
|
4 |
+
.pre-commit-config.yaml
|
5 |
+
.readthedocs.yml
|
6 |
+
AUTHORS.md
|
7 |
+
CHANGELOG.md
|
8 |
+
CONVLAB_README.md
|
9 |
+
LICENSE.txt
|
10 |
+
README.md
|
11 |
+
baselines_environment.lock.yml
|
12 |
+
baselines_environment.yml
|
13 |
+
dev_environment.yml
|
14 |
+
environment.yml
|
15 |
+
pyproject.toml
|
16 |
+
setup.cfg
|
17 |
+
setup.py
|
18 |
+
tox.ini
|
19 |
+
baselines/__init__.py
|
20 |
+
baselines/_preprocess_raw_canonical_map.py
|
21 |
+
baselines/baseline_setup.py
|
22 |
+
baselines/canonical_map.json
|
23 |
+
baselines/canonical_map.py
|
24 |
+
baselines/correct_categorical_state_values.tsv
|
25 |
+
baselines/create_data_from_multiwoz.py
|
26 |
+
baselines/create_dbleu_reference_map.py
|
27 |
+
baselines/goal_new_values.json
|
28 |
+
baselines/sanity_checks.py
|
29 |
+
baselines/schema.json
|
30 |
+
baselines/simulate_agent_interaction.py
|
31 |
+
baselines/simulate_corpus_interaction.py
|
32 |
+
baselines/system_models.py
|
33 |
+
baselines/user_models.py
|
34 |
+
baselines/utils.py
|
35 |
+
baselines/configs/agent_agent.yaml
|
36 |
+
configs/.gitignore
|
37 |
+
data/.gitignore
|
38 |
+
data/external/.gitignore
|
39 |
+
data/interim/.gitignore
|
40 |
+
data/preprocessed/.gitignore
|
41 |
+
data/raw/.gitignore
|
42 |
+
docs/Makefile
|
43 |
+
docs/authors.md
|
44 |
+
docs/changelog.md
|
45 |
+
docs/conf.py
|
46 |
+
docs/index.md
|
47 |
+
docs/license.rst
|
48 |
+
docs/readme.md
|
49 |
+
docs/requirements.txt
|
50 |
+
docs/_static/.gitignore
|
51 |
+
models/.gitignore
|
52 |
+
notebooks/1.0-ac-goals_consistency_check.ipynb
|
53 |
+
notebooks/template.ipynb
|
54 |
+
references/.gitignore
|
55 |
+
reports/figures/.gitignore
|
56 |
+
scripts/data_analysis.py
|
57 |
+
scripts/preprocess.py
|
58 |
+
scripts/preprocess2.1.py
|
59 |
+
scripts/template_train_model.py
|
60 |
+
scripts/train_ubar.py
|
61 |
+
src/crazyneuraluser/__init__.py
|
62 |
+
src/crazyneuraluser/clean_dataset.py
|
63 |
+
src/crazyneuraluser/config.py
|
64 |
+
src/crazyneuraluser/config21.py
|
65 |
+
src/crazyneuraluser/db_ops.py
|
66 |
+
src/crazyneuraluser/eval.py
|
67 |
+
src/crazyneuraluser/ontology.py
|
68 |
+
src/crazyneuraluser/reader.py
|
69 |
+
src/crazyneuraluser/utils.py
|
70 |
+
src/crazyneuraluser.egg-info/PKG-INFO
|
71 |
+
src/crazyneuraluser.egg-info/SOURCES.txt
|
72 |
+
src/crazyneuraluser.egg-info/dependency_links.txt
|
73 |
+
src/crazyneuraluser.egg-info/not-zip-safe
|
74 |
+
src/crazyneuraluser.egg-info/requires.txt
|
75 |
+
src/crazyneuraluser.egg-info/top_level.txt
|
76 |
+
tests/conftest.py
|
src/crazyneuraluser.egg-info/dependency_links.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
src/crazyneuraluser.egg-info/not-zip-safe
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
|
src/crazyneuraluser.egg-info/requires.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
transformers==4.18.0
|
2 |
+
tqdm==4.64.0
|
3 |
+
wandb==0.12.16
|
4 |
+
nltk==3.7
|
5 |
+
sklearn==0.0
|
6 |
+
tensorboard==2.9.0
|
7 |
+
spacy==3.3.0
|
8 |
+
|
9 |
+
[:python_version < "3.8"]
|
10 |
+
importlib-metadata
|
11 |
+
|
12 |
+
[testing]
|
13 |
+
setuptools
|
14 |
+
pytest
|
15 |
+
pytest-cov
|
src/crazyneuraluser.egg-info/top_level.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
crazyneuraluser
|
src/crazyneuraluser/UBAR_code/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
|
3 |
+
if sys.version_info[:2] >= (3, 8):
|
4 |
+
# TODO: Import directly (no need for conditional) when `python_requires = >= 3.8`
|
5 |
+
from importlib.metadata import PackageNotFoundError, version # pragma: no cover
|
6 |
+
else:
|
7 |
+
from importlib_metadata import PackageNotFoundError, version # pragma: no cover
|
8 |
+
|
9 |
+
try:
|
10 |
+
# Change here if project is renamed and does not equal the package name
|
11 |
+
dist_name = __name__
|
12 |
+
__version__ = version(dist_name)
|
13 |
+
except PackageNotFoundError: # pragma: no cover
|
14 |
+
__version__ = "unknown"
|
15 |
+
finally:
|
16 |
+
del version, PackageNotFoundError
|
src/crazyneuraluser/UBAR_code/clean_dataset.py
ADDED
@@ -0,0 +1,334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
import re
|
3 |
+
|
4 |
+
from crazyneuraluser.UBAR_code import ontology
|
5 |
+
|
6 |
+
|
7 |
+
def my_clean_text(text):
|
8 |
+
text = re.sub(r"([a-zT]+)\.([a-z])", r"\1 . \2", text) # 'abc.xyz' -> 'abc . xyz'
|
9 |
+
text = re.sub(r"(\w+)\.\.? ", r"\1 . ", text) # if 'abc. ' -> 'abc . '
|
10 |
+
return text
|
11 |
+
|
12 |
+
|
13 |
+
def clean_text(text):
|
14 |
+
text = text.strip()
|
15 |
+
text = text.lower()
|
16 |
+
text = text.replace("’", "'")
|
17 |
+
text = text.replace("‘", "'")
|
18 |
+
text = text.replace(";", ",")
|
19 |
+
text = text.replace('"', " ")
|
20 |
+
text = text.replace("/", " and ")
|
21 |
+
text = text.replace("don't", "do n't")
|
22 |
+
text = clean_time(text)
|
23 |
+
baddata = {
|
24 |
+
r"c\.b (\d), (\d) ([a-z])\.([a-z])": r"cb\1\2\3\4",
|
25 |
+
"c.b. 1 7 d.y": "cb17dy",
|
26 |
+
"c.b.1 7 d.y": "cb17dy",
|
27 |
+
"c.b 25, 9 a.q": "cb259aq",
|
28 |
+
"isc.b 25, 9 a.q": "is cb259aq",
|
29 |
+
"c.b2, 1 u.f": "cb21uf",
|
30 |
+
"c.b 1,2 q.a": "cb12qa",
|
31 |
+
"0-122-336-5664": "01223365664",
|
32 |
+
"postcodecb21rs": "postcode cb21rs",
|
33 |
+
r"i\.d": "id",
|
34 |
+
" i d ": "id",
|
35 |
+
"Telephone:01223358966": "Telephone: 01223358966",
|
36 |
+
"depature": "departure",
|
37 |
+
"depearting": "departing",
|
38 |
+
"-type": " type",
|
39 |
+
r"b[\s]?&[\s]?b": "bed and breakfast",
|
40 |
+
"b and b": "bed and breakfast",
|
41 |
+
r"guesthouse[s]?": "guest house",
|
42 |
+
r"swimmingpool[s]?": "swimming pool",
|
43 |
+
"wo n't": "will not",
|
44 |
+
" 'd ": " would ",
|
45 |
+
" 'm ": " am ",
|
46 |
+
" 're' ": " are ",
|
47 |
+
" 'll' ": " will ",
|
48 |
+
" 've ": " have ",
|
49 |
+
r"^\'": "",
|
50 |
+
r"\'$": "",
|
51 |
+
}
|
52 |
+
for tmpl, good in baddata.items():
|
53 |
+
text = re.sub(tmpl, good, text)
|
54 |
+
|
55 |
+
text = re.sub(r"([a-zT]+)\.([a-z])", r"\1 . \2", text) # 'abc.xyz' -> 'abc . xyz'
|
56 |
+
text = re.sub(r"(\w+)\.\.? ", r"\1 . ", text) # if 'abc. ' -> 'abc . '
|
57 |
+
|
58 |
+
with open("data/raw/UBAR/multi-woz/mapping.pair", "r") as fin:
|
59 |
+
for line in fin.readlines():
|
60 |
+
fromx, tox = line.replace("\n", "").split("\t")
|
61 |
+
text = " " + text + " "
|
62 |
+
text = text.replace(" " + fromx + " ", " " + tox + " ")[1:-1]
|
63 |
+
|
64 |
+
return text
|
65 |
+
|
66 |
+
|
67 |
+
def clean_time(utter):
|
68 |
+
utter = re.sub(
|
69 |
+
r"(\d+) ([ap]\.?m)", lambda x: x.group(1) + x.group(2), utter
|
70 |
+
) # 9 am -> 9am
|
71 |
+
utter = re.sub(r"((?<!\d)\d:\d+)(am)?", r"0\1", utter)
|
72 |
+
utter = re.sub(r"((?<!\d)\d)am", r"0\1:00", utter)
|
73 |
+
utter = re.sub(r"((?<!\d)\d)pm", lambda x: str(int(x.group(1)) + 12) + ":00", utter)
|
74 |
+
utter = re.sub(
|
75 |
+
r"(\d+)(:\d+)pm", lambda x: str(int(x.group(1)) + 12) + x.group(2), utter
|
76 |
+
)
|
77 |
+
utter = re.sub(r"(\d+)a\.?m", r"\1", utter)
|
78 |
+
return utter
|
79 |
+
|
80 |
+
|
81 |
+
def clean_slot_values(domain, slot, value):
|
82 |
+
value = clean_text(value)
|
83 |
+
if not value:
|
84 |
+
value = ""
|
85 |
+
elif value == "not mentioned":
|
86 |
+
value = ""
|
87 |
+
# value = 'not mentioned' # if in DST setting
|
88 |
+
elif domain == "attraction":
|
89 |
+
if slot == "name":
|
90 |
+
if value == "t":
|
91 |
+
value = ""
|
92 |
+
if value == "trinity":
|
93 |
+
value = "trinity college"
|
94 |
+
elif slot == "area":
|
95 |
+
if value in ["town centre", "cent", "center", "ce"]:
|
96 |
+
value = "centre"
|
97 |
+
elif value in ["ely", "in town", "museum", "norwich", "same area as hotel"]:
|
98 |
+
value = ""
|
99 |
+
elif value in ["we"]:
|
100 |
+
value = "west"
|
101 |
+
elif slot == "type":
|
102 |
+
if value in ["m", "mus", "musuem"]:
|
103 |
+
value = "museum"
|
104 |
+
elif value in ["art", "architectural"]:
|
105 |
+
value = "architecture"
|
106 |
+
elif value in ["churches"]:
|
107 |
+
value = "church"
|
108 |
+
elif value in ["coll"]:
|
109 |
+
value = "college"
|
110 |
+
elif value in ["concert", "concerthall"]:
|
111 |
+
value = "concert hall"
|
112 |
+
elif value in ["night club"]:
|
113 |
+
value = "nightclub"
|
114 |
+
elif value in ["mutiple sports", "mutliple sports", "sports", "galleria"]:
|
115 |
+
value = "multiple sports"
|
116 |
+
elif value in ["ol", "science", "gastropub", "la raza"]:
|
117 |
+
value = ""
|
118 |
+
elif value in ["swimmingpool", "pool"]:
|
119 |
+
value = "swimming pool"
|
120 |
+
elif value in ["fun"]:
|
121 |
+
value = "entertainment"
|
122 |
+
|
123 |
+
elif domain == "hotel":
|
124 |
+
if slot == "area":
|
125 |
+
if value in ["cen", "centre of town", "near city center", "center"]:
|
126 |
+
value = "centre"
|
127 |
+
elif value in ["east area", "east side"]:
|
128 |
+
value = "east"
|
129 |
+
elif value in ["in the north", "north part of town"]:
|
130 |
+
value = "north"
|
131 |
+
elif value in ["we"]:
|
132 |
+
value = "west"
|
133 |
+
elif slot == "day":
|
134 |
+
if value == "monda":
|
135 |
+
value = "monday"
|
136 |
+
elif value == "t":
|
137 |
+
value = "tuesday"
|
138 |
+
elif slot == "name":
|
139 |
+
if value == "uni":
|
140 |
+
value = "university arms hotel"
|
141 |
+
elif value == "university arms":
|
142 |
+
value = "university arms hotel"
|
143 |
+
elif value == "acron":
|
144 |
+
value = "acorn guest house"
|
145 |
+
elif value == "ashley":
|
146 |
+
value = "ashley hotel"
|
147 |
+
elif value == "arbury lodge guesthouse":
|
148 |
+
value = "arbury lodge guest house"
|
149 |
+
elif value == "la":
|
150 |
+
value = "la margherit"
|
151 |
+
elif value == "no":
|
152 |
+
value = ""
|
153 |
+
elif slot == "internet":
|
154 |
+
if value == "does not":
|
155 |
+
value = "no"
|
156 |
+
elif value in ["y", "free", "free internet"]:
|
157 |
+
value = "yes"
|
158 |
+
elif value in ["4"]:
|
159 |
+
value = ""
|
160 |
+
elif slot == "parking":
|
161 |
+
if value == "n":
|
162 |
+
value = "no"
|
163 |
+
elif value in ["free parking"]:
|
164 |
+
value = "yes"
|
165 |
+
elif value in ["y"]:
|
166 |
+
value = "yes"
|
167 |
+
elif slot in ["pricerange", "price range"]:
|
168 |
+
slot = "pricerange"
|
169 |
+
if value == "moderately":
|
170 |
+
value = "moderate"
|
171 |
+
elif value in ["any"]:
|
172 |
+
value = "do n't care"
|
173 |
+
elif value in ["any"]:
|
174 |
+
value = "do n't care"
|
175 |
+
elif value in ["inexpensive"]:
|
176 |
+
value = "cheap"
|
177 |
+
elif value in ["2", "4"]:
|
178 |
+
value = ""
|
179 |
+
elif slot == "stars":
|
180 |
+
if value == "two":
|
181 |
+
value = "2"
|
182 |
+
elif value == "three":
|
183 |
+
value = "3"
|
184 |
+
elif value in ["4-star", "4 stars", "4 star", "four star", "four stars"]:
|
185 |
+
value = "4"
|
186 |
+
elif slot == "type":
|
187 |
+
if value == "0 star rarting":
|
188 |
+
value = ""
|
189 |
+
elif value == "guesthouse":
|
190 |
+
value = "guest house"
|
191 |
+
elif value not in ["hotel", "guest house", "do n't care"]:
|
192 |
+
value = ""
|
193 |
+
elif domain == "restaurant":
|
194 |
+
if slot == "area":
|
195 |
+
if value in [
|
196 |
+
"center",
|
197 |
+
"scentre",
|
198 |
+
"center of town",
|
199 |
+
"city center",
|
200 |
+
"cb30aq",
|
201 |
+
"town center",
|
202 |
+
"centre of cambridge",
|
203 |
+
"city centre",
|
204 |
+
]:
|
205 |
+
value = "centre"
|
206 |
+
elif value == "west part of town":
|
207 |
+
value = "west"
|
208 |
+
elif value == "n":
|
209 |
+
value = "north"
|
210 |
+
elif value in ["the south"]:
|
211 |
+
value = "south"
|
212 |
+
elif value not in [
|
213 |
+
"centre",
|
214 |
+
"south",
|
215 |
+
"do n't care",
|
216 |
+
"west",
|
217 |
+
"east",
|
218 |
+
"north",
|
219 |
+
]:
|
220 |
+
value = ""
|
221 |
+
elif slot == "day":
|
222 |
+
if value == "monda":
|
223 |
+
value = "monday"
|
224 |
+
elif value == "t":
|
225 |
+
value = "tuesday"
|
226 |
+
elif slot in ["pricerange", "price range"]:
|
227 |
+
slot = "pricerange"
|
228 |
+
if value in ["moderately", "mode", "mo"]:
|
229 |
+
value = "moderate"
|
230 |
+
elif value in ["not"]:
|
231 |
+
value = ""
|
232 |
+
elif value in ["inexpensive", "ch"]:
|
233 |
+
value = "cheap"
|
234 |
+
elif slot == "food":
|
235 |
+
if value == "barbecue":
|
236 |
+
value = "barbeque"
|
237 |
+
elif slot == "pricerange":
|
238 |
+
if value == "moderately":
|
239 |
+
value = "moderate"
|
240 |
+
elif slot == "time":
|
241 |
+
if value == "9:00":
|
242 |
+
value = "09:00"
|
243 |
+
elif value == "9:45":
|
244 |
+
value = "09:45"
|
245 |
+
elif value == "1330":
|
246 |
+
value = "13:30"
|
247 |
+
elif value == "1430":
|
248 |
+
value = "14:30"
|
249 |
+
elif value == "9:15":
|
250 |
+
value = "09:15"
|
251 |
+
elif value == "9:30":
|
252 |
+
value = "09:30"
|
253 |
+
elif value == "1830":
|
254 |
+
value = "18:30"
|
255 |
+
elif value == "9":
|
256 |
+
value = "09:00"
|
257 |
+
elif value == "2:00":
|
258 |
+
value = "14:00"
|
259 |
+
elif value == "1:00":
|
260 |
+
value = "13:00"
|
261 |
+
elif value == "3:00":
|
262 |
+
value = "15:00"
|
263 |
+
elif domain == "taxi":
|
264 |
+
if slot in ["arriveBy", "arrive by"]:
|
265 |
+
slot = "arriveby"
|
266 |
+
if value == "1530":
|
267 |
+
value = "15:30"
|
268 |
+
elif value == "15 minutes":
|
269 |
+
value = ""
|
270 |
+
elif slot in ["leaveAt", "leave at"]:
|
271 |
+
slot = "leaveat"
|
272 |
+
if value == "1:00":
|
273 |
+
value = "01:00"
|
274 |
+
elif value == "21:4":
|
275 |
+
value = "21:04"
|
276 |
+
elif value == "4:15":
|
277 |
+
value = "04:15"
|
278 |
+
elif value == "5:45":
|
279 |
+
value = "05:45"
|
280 |
+
elif value == "0700":
|
281 |
+
value = "07:00"
|
282 |
+
elif value == "4:45":
|
283 |
+
value = "04:45"
|
284 |
+
elif value == "8:30":
|
285 |
+
value = "08:30"
|
286 |
+
elif value == "9:30":
|
287 |
+
value = "09:30"
|
288 |
+
value = value.replace(".", ":")
|
289 |
+
|
290 |
+
elif domain == "train":
|
291 |
+
if slot in ["arriveBy", "arrive by"]:
|
292 |
+
slot = "arriveby"
|
293 |
+
if value == "1":
|
294 |
+
value = "01:00"
|
295 |
+
elif value in ["does not care", "doesnt care", "doesn't care"]:
|
296 |
+
value = "do n't care"
|
297 |
+
elif value == "8:30":
|
298 |
+
value = "08:30"
|
299 |
+
elif value == "not 15:45":
|
300 |
+
value = ""
|
301 |
+
value = value.replace(".", ":")
|
302 |
+
elif slot == "day":
|
303 |
+
if value == "doesnt care" or value == "doesn't care":
|
304 |
+
value = "do n't care"
|
305 |
+
elif slot in ["leaveAt", "leave at"]:
|
306 |
+
slot = "leaveat"
|
307 |
+
if value == "2:30":
|
308 |
+
value = "02:30"
|
309 |
+
elif value == "7:54":
|
310 |
+
value = "07:54"
|
311 |
+
elif value == "after 5:45 pm":
|
312 |
+
value = "17:45"
|
313 |
+
elif value in ["early evening", "friday", "sunday", "tuesday", "afternoon"]:
|
314 |
+
value = ""
|
315 |
+
elif value == "12":
|
316 |
+
value = "12:00"
|
317 |
+
elif value == "1030":
|
318 |
+
value = "10:30"
|
319 |
+
elif value == "1700":
|
320 |
+
value = "17:00"
|
321 |
+
elif value in [
|
322 |
+
"does not care",
|
323 |
+
"doesnt care",
|
324 |
+
"do nt care",
|
325 |
+
"doesn't care",
|
326 |
+
]:
|
327 |
+
value = "do n't care"
|
328 |
+
|
329 |
+
value = value.replace(".", ":")
|
330 |
+
if value in ["dont care", "don't care", "do nt care", "doesn't care"]:
|
331 |
+
value = "do n't care"
|
332 |
+
if ontology.normlize_slot_names.get(slot):
|
333 |
+
slot = ontology.normlize_slot_names[slot]
|
334 |
+
return slot, value
|
src/crazyneuraluser/UBAR_code/config.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
|
5 |
+
|
6 |
+
class _Config:
|
7 |
+
def __init__(self):
|
8 |
+
self._multiwoz_ubar_init()
|
9 |
+
|
10 |
+
def _multiwoz_ubar_init(self):
|
11 |
+
self.gpt_path = "distilgpt2"
|
12 |
+
|
13 |
+
self.vocab_path_train = "./data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/vocab"
|
14 |
+
self.vocab_path_eval = None
|
15 |
+
self.data_path = "./data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/"
|
16 |
+
self.data_file = "data_for_ubar.json"
|
17 |
+
self.dev_list = "data/raw/UBAR/multi-woz/valListFile.json"
|
18 |
+
self.test_list = "data/raw/UBAR/multi-woz/testListFile.json"
|
19 |
+
self.dbs = {
|
20 |
+
"attraction": "data/preprocessed_gen_usr_utts/UBAR/db_processed/attraction_db_processed.json",
|
21 |
+
"hospital": "data/preprocessed_gen_usr_utts/UBAR/db_processed/hospital_db_processed.json",
|
22 |
+
"hotel": "data/preprocessed_gen_usr_utts/UBAR/db_processed/hotel_db_processed.json",
|
23 |
+
"police": "data/preprocessed_gen_usr_utts/UBAR/db_processed/police_db_processed.json",
|
24 |
+
"restaurant": "data/preprocessed_gen_usr_utts/UBAR/db_processed/restaurant_db_processed.json",
|
25 |
+
"taxi": "data/preprocessed_gen_usr_utts/UBAR/db_processed/taxi_db_processed.json",
|
26 |
+
"train": "data/preprocessed_gen_usr_utts/UBAR/db_processed/train_db_processed.json",
|
27 |
+
}
|
28 |
+
self.glove_path = "./data/glove/glove.6B.50d.txt"
|
29 |
+
self.domain_file_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/domain_files.json"
|
30 |
+
self.slot_value_set_path = "data/preprocessed_gen_usr_utts/UBAR/db_processed/value_set_processed.json"
|
31 |
+
self.multi_acts_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/multi_act_mapping_train.json"
|
32 |
+
self.exp_path = "to be generated"
|
33 |
+
self.log_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
|
34 |
+
|
35 |
+
# experiment settings
|
36 |
+
self.mode = "unknown"
|
37 |
+
self.cuda = False
|
38 |
+
self.cuda_device = [0]
|
39 |
+
self.exp_no = ""
|
40 |
+
self.seed = 11
|
41 |
+
self.save_log = True # tensorboard
|
42 |
+
self.evaluate_during_training = False # evaluate during training
|
43 |
+
self.report_interval = 200 # 485 for bs 128
|
44 |
+
self.max_nl_length = 60
|
45 |
+
self.max_span_length = 30
|
46 |
+
self.truncated = False
|
47 |
+
|
48 |
+
# training settings
|
49 |
+
self.lr = 1e-4
|
50 |
+
self.warmup_steps = -1
|
51 |
+
self.weight_decay = 0.0
|
52 |
+
self.gradient_accumulation_steps = 16
|
53 |
+
self.batch_size = 2
|
54 |
+
|
55 |
+
self.label_smoothing = 0.0
|
56 |
+
self.lr_decay = 0.5
|
57 |
+
self.epoch_num = 40
|
58 |
+
self.early_stop_count = 5
|
59 |
+
self.weight_decay_count = 3
|
60 |
+
self.teacher_force = 100
|
61 |
+
self.multi_acts_training = False
|
62 |
+
self.multi_act_sampling_num = 1
|
63 |
+
self.valid_loss = "score"
|
64 |
+
|
65 |
+
self.wandb_train_run_name = "Train with generated usr utterances"
|
66 |
+
|
67 |
+
# evaluation settings
|
68 |
+
self.eval_load_path = "models/UBAR/experiments/distilgpt-2_sd11_lr0.0001_bs16_ga2/epoch53_trloss0.59_gpt2"
|
69 |
+
self.model_output = "model_output_e2e_FFFT_fix_bs.json"
|
70 |
+
self.eval_per_domain = False
|
71 |
+
self.eval_set = "test" # test, dev
|
72 |
+
|
73 |
+
self.wandb_eval_run_name = "US Generated usr utterances evaluation"
|
74 |
+
|
75 |
+
# my setting
|
76 |
+
self.use_true_prev_bspn = False
|
77 |
+
self.use_true_prev_aspn = False
|
78 |
+
self.use_true_db_pointer = False
|
79 |
+
self.use_true_prev_resp = False
|
80 |
+
|
81 |
+
self.use_true_curr_bspn = False
|
82 |
+
self.use_true_curr_aspn = False
|
83 |
+
self.use_all_previous_context = True
|
84 |
+
|
85 |
+
self.exp_domains = ["all"] # hotel,train, attraction, restaurant, taxi
|
86 |
+
self.log_path = "logs_test"
|
87 |
+
self.low_resource = False
|
88 |
+
###
|
89 |
+
|
90 |
+
# dst setting
|
91 |
+
self.fix_bs = True
|
92 |
+
self.use_nodelex_resp = True
|
93 |
+
self.max_context_length = 900
|
94 |
+
##
|
95 |
+
|
96 |
+
# model settings
|
97 |
+
self.vocab_size = 3000
|
98 |
+
self.embed_size = 50
|
99 |
+
self.hidden_size = 100
|
100 |
+
self.pointer_dim = 6 # fixed
|
101 |
+
self.enc_layer_num = 1
|
102 |
+
self.dec_layer_num = 1
|
103 |
+
self.dropout = 0
|
104 |
+
self.layer_norm = False
|
105 |
+
self.skip_connect = False
|
106 |
+
self.encoder_share = False
|
107 |
+
self.attn_param_share = False
|
108 |
+
self.copy_param_share = False
|
109 |
+
self.enable_aspn = True
|
110 |
+
self.use_pvaspn = False
|
111 |
+
self.enable_bspn = True
|
112 |
+
self.bspn_mode = "bspn" # 'bspn' or 'bsdx'
|
113 |
+
self.enable_dspn = False # removed
|
114 |
+
self.enable_dst = False
|
115 |
+
|
116 |
+
self.use_true_bspn_for_ctr_eval = True
|
117 |
+
self.use_true_domain_for_ctr_eval = True
|
118 |
+
self.limit_bspn_vocab = False
|
119 |
+
self.limit_aspn_vocab = False
|
120 |
+
self.same_eval_as_cambridge = True
|
121 |
+
self.same_eval_act_f1_as_hdsa = False
|
122 |
+
self.aspn_decode_mode = "greedy" # beam, greedy, nucleur_sampling, topk_sampling
|
123 |
+
self.beam_width = 5
|
124 |
+
self.nbest = 5
|
125 |
+
self.beam_diverse_param = 0.2
|
126 |
+
self.act_selection_scheme = "high_test_act_f1"
|
127 |
+
self.topk_num = 1
|
128 |
+
self.nucleur_p = 0.0
|
129 |
+
self.record_mode = False
|
130 |
+
|
131 |
+
def __str__(self):
|
132 |
+
s = ""
|
133 |
+
for k, v in self.__dict__.items():
|
134 |
+
s += "{} : {}\n".format(k, v)
|
135 |
+
return s
|
136 |
+
|
137 |
+
def _init_logging_handler(self, mode):
|
138 |
+
stderr_handler = logging.StreamHandler()
|
139 |
+
if not os.path.exists("./log"):
|
140 |
+
os.mkdir("./log")
|
141 |
+
if self.save_log and self.mode == "train":
|
142 |
+
file_handler = logging.FileHandler(
|
143 |
+
"./log/log_{}_{}_{}_{}_sd{}.txt".format(
|
144 |
+
self.log_time,
|
145 |
+
mode,
|
146 |
+
"-".join(self.exp_domains),
|
147 |
+
self.exp_no,
|
148 |
+
self.seed,
|
149 |
+
)
|
150 |
+
)
|
151 |
+
logging.basicConfig(handlers=[stderr_handler, file_handler])
|
152 |
+
elif self.mode == "test":
|
153 |
+
eval_log_path = os.path.join(self.eval_load_path, "eval_log.json")
|
154 |
+
# if os.path.exists(eval_log_path):
|
155 |
+
# os.remove(eval_log_path)
|
156 |
+
file_handler = logging.FileHandler(eval_log_path)
|
157 |
+
logging.basicConfig(handlers=[stderr_handler, file_handler])
|
158 |
+
else:
|
159 |
+
logging.basicConfig(handlers=[stderr_handler])
|
160 |
+
logger = logging.getLogger()
|
161 |
+
logger.setLevel(logging.INFO)
|
162 |
+
|
163 |
+
|
164 |
+
global_config = _Config()
|
src/crazyneuraluser/UBAR_code/config21.py
ADDED
@@ -0,0 +1,169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import logging
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
|
5 |
+
|
6 |
+
class _Config:
|
7 |
+
def __init__(self):
|
8 |
+
self._multiwoz_ubar_init()
|
9 |
+
|
10 |
+
def _multiwoz_ubar_init(self):
|
11 |
+
self.gpt_path = "/data/yangyy/BERT-models/huggingface/distilgpt2/"
|
12 |
+
|
13 |
+
self.vocab_path_train = "./data/multi-woz-2.1-processed/vocab"
|
14 |
+
self.vocab_path_eval = None
|
15 |
+
self.data_path = "./data/multi-woz-2.1-processed/"
|
16 |
+
self.data_file = "data_for_ubar.json"
|
17 |
+
self.dev_list = "data/multi-woz/valListFile.json"
|
18 |
+
self.test_list = "data/multi-woz/testListFile.json"
|
19 |
+
self.dbs = {
|
20 |
+
"attraction": "data/preprocessed/UBAR/db_processed/attraction_db_processed.json",
|
21 |
+
"hospital": "data/preprocessed/UBAR/db_processed/hospital_db_processed.json",
|
22 |
+
"hotel": "data/preprocessed/UBAR/db_processed/hotel_db_processed.json",
|
23 |
+
"police": "data/preprocessed/UBAR/db_processed/police_db_processed.json",
|
24 |
+
"restaurant": "data/preprocessed/UBAR/db_processed/restaurant_db_processed.json",
|
25 |
+
"taxi": "data/preprocessed/UBAR/db_processed/taxi_db_processed.json",
|
26 |
+
"train": "data/preprocessed/UBAR/db_processed/train_db_processed.json",
|
27 |
+
}
|
28 |
+
self.glove_path = "./data/glove/glove.6B.50d.txt"
|
29 |
+
self.domain_file_path = (
|
30 |
+
"data/preprocessed/UBAR/multi-woz-2.1-processed/domain_files.json"
|
31 |
+
)
|
32 |
+
self.slot_value_set_path = (
|
33 |
+
"data/preprocessed/UBAR/db_processed/value_set_processed.json"
|
34 |
+
)
|
35 |
+
self.multi_acts_path = "data/preprocessed/UBAR/multi-woz-2.1-processed/multi_act_mapping_train.json"
|
36 |
+
self.exp_path = "to be generated"
|
37 |
+
self.log_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
|
38 |
+
|
39 |
+
# experiment settings
|
40 |
+
self.mode = "unknown"
|
41 |
+
self.cuda = True
|
42 |
+
self.cuda_device = [1]
|
43 |
+
self.exp_no = ""
|
44 |
+
self.seed = 11
|
45 |
+
self.exp_domains = ["all"]
|
46 |
+
self.save_log = True # tensorboard
|
47 |
+
self.evaluate_during_training = False # evaluate during training
|
48 |
+
self.report_interval = 200 # 485 for bs 128
|
49 |
+
self.max_nl_length = 60
|
50 |
+
self.max_span_length = 30
|
51 |
+
self.truncated = False
|
52 |
+
|
53 |
+
# model settings
|
54 |
+
self.vocab_size = 3000
|
55 |
+
self.embed_size = 50
|
56 |
+
self.hidden_size = 100
|
57 |
+
self.pointer_dim = 6 # fixed
|
58 |
+
self.enc_layer_num = 1
|
59 |
+
self.dec_layer_num = 1
|
60 |
+
self.dropout = 0
|
61 |
+
self.layer_norm = False
|
62 |
+
self.skip_connect = False
|
63 |
+
self.encoder_share = False
|
64 |
+
self.attn_param_share = False
|
65 |
+
self.copy_param_share = False
|
66 |
+
self.enable_aspn = True
|
67 |
+
self.use_pvaspn = False
|
68 |
+
self.enable_bspn = True
|
69 |
+
self.bspn_mode = "bsdx" # 'bspn' or 'bsdx'
|
70 |
+
self.enable_dspn = False # removed
|
71 |
+
self.enable_dst = False
|
72 |
+
|
73 |
+
# training settings
|
74 |
+
self.lr = 5e-4
|
75 |
+
self.warmup_steps = 2000 # gpt tbd
|
76 |
+
self.weight_decay = 0.0 # gpt tbd
|
77 |
+
self.gradient_accumulation_steps = 16
|
78 |
+
self.batch_size = 2
|
79 |
+
|
80 |
+
self.label_smoothing = 0.0
|
81 |
+
self.lr_decay = 0.5
|
82 |
+
self.epoch_num = 60
|
83 |
+
self.early_stop_count = 5
|
84 |
+
self.weight_decay_count = 3
|
85 |
+
self.teacher_force = 100
|
86 |
+
self.multi_acts_training = False
|
87 |
+
self.multi_act_sampling_num = 1
|
88 |
+
self.valid_loss = "score"
|
89 |
+
|
90 |
+
self.wandb_train_run_name = "Name to be added"
|
91 |
+
|
92 |
+
# evaluation settings
|
93 |
+
self.eval_load_path = "models/UBAR/experiments/all_0729_sd11_lr0.0001_bs2_ga16/epoch43_trloss0.56_gpt2"
|
94 |
+
self.model_output = "model_output_e2e_FFFT_fix_bs.json"
|
95 |
+
self.eval_per_domain = False
|
96 |
+
self.eval_set = "test" # test, dev
|
97 |
+
|
98 |
+
self.wandb_eval_run_name = "Name to be added"
|
99 |
+
|
100 |
+
# generation setting
|
101 |
+
self.use_true_prev_bspn = True
|
102 |
+
self.use_true_prev_aspn = True
|
103 |
+
self.use_true_db_pointer = False
|
104 |
+
self.use_true_prev_resp = True
|
105 |
+
|
106 |
+
self.use_true_curr_bspn = True
|
107 |
+
self.use_true_curr_aspn = False
|
108 |
+
self.use_all_previous_context = True
|
109 |
+
|
110 |
+
self.exp_domains = ["all"] # hotel,train, attraction, restaurant, taxi
|
111 |
+
self.log_path = "logs2.1"
|
112 |
+
self.low_resource = False
|
113 |
+
|
114 |
+
# dst setting
|
115 |
+
self.fix_bs = True
|
116 |
+
self.use_nodelex_resp = True
|
117 |
+
self.max_context_length = 900
|
118 |
+
|
119 |
+
self.use_true_bspn_for_ctr_eval = True
|
120 |
+
self.use_true_domain_for_ctr_eval = True
|
121 |
+
self.limit_bspn_vocab = False
|
122 |
+
self.limit_aspn_vocab = False
|
123 |
+
self.same_eval_as_cambridge = True
|
124 |
+
self.same_eval_act_f1_as_hdsa = False
|
125 |
+
self.aspn_decode_mode = (
|
126 |
+
"greedy" # beam, greedy, nucleur_sampling, topk_sampling
|
127 |
+
)
|
128 |
+
self.beam_width = 5
|
129 |
+
self.nbest = 5
|
130 |
+
self.beam_diverse_param = 0.2
|
131 |
+
self.act_selection_scheme = "high_test_act_f1"
|
132 |
+
self.topk_num = 1
|
133 |
+
self.nucleur_p = 0.0
|
134 |
+
self.record_mode = False
|
135 |
+
|
136 |
+
def __str__(self):
|
137 |
+
s = ""
|
138 |
+
for k, v in self.__dict__.items():
|
139 |
+
s += "{} : {}\n".format(k, v)
|
140 |
+
return s
|
141 |
+
|
142 |
+
def _init_logging_handler(self, mode):
|
143 |
+
stderr_handler = logging.StreamHandler()
|
144 |
+
if not os.path.exists("./log"):
|
145 |
+
os.mkdir("./log")
|
146 |
+
if self.save_log and self.mode == "train":
|
147 |
+
file_handler = logging.FileHandler(
|
148 |
+
"./log/log_{}_{}_{}_{}_sd{}.txt".format(
|
149 |
+
self.log_time,
|
150 |
+
mode,
|
151 |
+
"-".join(self.exp_domains),
|
152 |
+
self.exp_no,
|
153 |
+
self.seed,
|
154 |
+
)
|
155 |
+
)
|
156 |
+
logging.basicConfig(handlers=[stderr_handler, file_handler])
|
157 |
+
elif self.mode == "test":
|
158 |
+
eval_log_path = os.path.join(self.eval_load_path, "eval_log.json")
|
159 |
+
# if os.path.exists(eval_log_path):
|
160 |
+
# os.remove(eval_log_path)
|
161 |
+
file_handler = logging.FileHandler(eval_log_path)
|
162 |
+
logging.basicConfig(handlers=[stderr_handler, file_handler])
|
163 |
+
else:
|
164 |
+
logging.basicConfig(handlers=[stderr_handler])
|
165 |
+
logger = logging.getLogger()
|
166 |
+
logger.setLevel(logging.INFO)
|
167 |
+
|
168 |
+
|
169 |
+
global_config = _Config()
|
src/crazyneuraluser/UBAR_code/db_ops.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import random
|
3 |
+
import sqlite3
|
4 |
+
import string
|
5 |
+
|
6 |
+
from crazyneuraluser.UBAR_code.ontology import all_domains, db_domains
|
7 |
+
|
8 |
+
|
9 |
+
class MultiWozDB(object):
|
10 |
+
def __init__(self, db_paths):
|
11 |
+
self.dbs = {}
|
12 |
+
self.sql_dbs = {}
|
13 |
+
for domain in all_domains:
|
14 |
+
with open(db_paths[domain], "r") as f:
|
15 |
+
self.dbs[domain] = json.loads(f.read().lower())
|
16 |
+
|
17 |
+
def oneHotVector(self, domain, num):
|
18 |
+
"""Return number of available entities for particular domain."""
|
19 |
+
vector = [0, 0, 0, 0]
|
20 |
+
if num == "":
|
21 |
+
return vector
|
22 |
+
if domain != "train":
|
23 |
+
if num == 0:
|
24 |
+
vector = [1, 0, 0, 0]
|
25 |
+
elif num == 1:
|
26 |
+
vector = [0, 1, 0, 0]
|
27 |
+
elif num <= 3:
|
28 |
+
vector = [0, 0, 1, 0]
|
29 |
+
else:
|
30 |
+
vector = [0, 0, 0, 1]
|
31 |
+
else:
|
32 |
+
if num == 0:
|
33 |
+
vector = [1, 0, 0, 0]
|
34 |
+
elif num <= 5:
|
35 |
+
vector = [0, 1, 0, 0]
|
36 |
+
elif num <= 10:
|
37 |
+
vector = [0, 0, 1, 0]
|
38 |
+
else:
|
39 |
+
vector = [0, 0, 0, 1]
|
40 |
+
return vector
|
41 |
+
|
42 |
+
def addBookingPointer(self, turn_da):
|
43 |
+
"""Add information about availability of the booking option."""
|
44 |
+
# Booking pointer
|
45 |
+
# Do not consider booking two things in a single turn.
|
46 |
+
vector = [0, 0]
|
47 |
+
if turn_da.get("booking-nobook"):
|
48 |
+
vector = [1, 0]
|
49 |
+
if turn_da.get("booking-book") or turn_da.get("train-offerbooked"):
|
50 |
+
vector = [0, 1]
|
51 |
+
return vector
|
52 |
+
|
53 |
+
def addDBPointer(self, domain, match_num, return_num=False):
|
54 |
+
"""Create database pointer for all related domains."""
|
55 |
+
# if turn_domains is None:
|
56 |
+
# turn_domains = db_domains
|
57 |
+
if domain in db_domains:
|
58 |
+
vector = self.oneHotVector(domain, match_num)
|
59 |
+
else:
|
60 |
+
vector = [0, 0, 0, 0]
|
61 |
+
return vector
|
62 |
+
|
63 |
+
def addDBIndicator(self, domain, match_num, return_num=False):
|
64 |
+
"""Create database indicator for all related domains."""
|
65 |
+
# if turn_domains is None:
|
66 |
+
# turn_domains = db_domains
|
67 |
+
if domain in db_domains:
|
68 |
+
vector = self.oneHotVector(domain, match_num)
|
69 |
+
else:
|
70 |
+
vector = [0, 0, 0, 0]
|
71 |
+
|
72 |
+
# '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
|
73 |
+
if vector == [0, 0, 0, 0]:
|
74 |
+
indicator = "[db_nores]"
|
75 |
+
else:
|
76 |
+
indicator = "[db_%s]" % vector.index(1)
|
77 |
+
return indicator
|
78 |
+
|
79 |
+
def get_match_num(self, constraints, return_entry=False):
|
80 |
+
"""Create database pointer for all related domains."""
|
81 |
+
match = {"general": ""}
|
82 |
+
entry = {}
|
83 |
+
# if turn_domains is None:
|
84 |
+
# turn_domains = db_domains
|
85 |
+
for domain in all_domains:
|
86 |
+
match[domain] = ""
|
87 |
+
if domain in db_domains and constraints.get(domain):
|
88 |
+
matched_ents = self.queryJsons(domain, constraints[domain])
|
89 |
+
match[domain] = len(matched_ents)
|
90 |
+
if return_entry:
|
91 |
+
entry[domain] = matched_ents
|
92 |
+
if return_entry:
|
93 |
+
return entry
|
94 |
+
return match
|
95 |
+
|
96 |
+
def pointerBack(self, vector, domain):
|
97 |
+
# multi domain implementation
|
98 |
+
# domnum = cfg.domain_num
|
99 |
+
if domain.endswith("]"):
|
100 |
+
domain = domain[1:-1]
|
101 |
+
if domain != "train":
|
102 |
+
nummap = {0: "0", 1: "1", 2: "2-3", 3: ">3"}
|
103 |
+
else:
|
104 |
+
nummap = {0: "0", 1: "1-5", 2: "6-10", 3: ">10"}
|
105 |
+
if vector[:4] == [0, 0, 0, 0]:
|
106 |
+
report = ""
|
107 |
+
else:
|
108 |
+
num = vector.index(1)
|
109 |
+
report = domain + ": " + nummap[num] + "; "
|
110 |
+
|
111 |
+
if vector[-2] == 0 and vector[-1] == 1:
|
112 |
+
report += "booking: ok"
|
113 |
+
if vector[-2] == 1 and vector[-1] == 0:
|
114 |
+
report += "booking: unable"
|
115 |
+
|
116 |
+
return report
|
117 |
+
|
118 |
+
def queryJsons(self, domain, constraints, exactly_match=True, return_name=False):
|
119 |
+
"""Returns the list of entities for a given domain
|
120 |
+
based on the annotation of the belief state
|
121 |
+
constraints: dict e.g. {'pricerange': 'cheap', 'area': 'west'}
|
122 |
+
"""
|
123 |
+
# query the db
|
124 |
+
if domain == "taxi":
|
125 |
+
return [
|
126 |
+
{
|
127 |
+
"taxi_colors": random.choice(self.dbs[domain][0]["taxi_colors"]),
|
128 |
+
"taxi_types": random.choice(self.dbs[domain][0]["taxi_types"]),
|
129 |
+
"taxi_phone": "".join(random.choices(string.digits, k=10)),
|
130 |
+
}
|
131 |
+
]
|
132 |
+
if domain == "police":
|
133 |
+
return self.dbs["police"]
|
134 |
+
if domain == "hospital":
|
135 |
+
if constraints.get("department"):
|
136 |
+
for entry in self.dbs["hospital"]:
|
137 |
+
if entry.get("department") == constraints.get("department"):
|
138 |
+
return [entry]
|
139 |
+
else:
|
140 |
+
# Instead of returning an empty list which breaks lexicalisation, when is no department constraint,
|
141 |
+
# return the first entry from the hospital db so the user still gets hospital information.
|
142 |
+
return [self.dbs["hospital"][0]]
|
143 |
+
|
144 |
+
valid_cons = False
|
145 |
+
for v in constraints.values():
|
146 |
+
if v not in ["not mentioned", ""]:
|
147 |
+
valid_cons = True
|
148 |
+
if not valid_cons:
|
149 |
+
return []
|
150 |
+
|
151 |
+
match_result = []
|
152 |
+
|
153 |
+
if "name" in constraints:
|
154 |
+
for db_ent in self.dbs[domain]:
|
155 |
+
if "name" in db_ent:
|
156 |
+
cons = constraints["name"]
|
157 |
+
dbn = db_ent["name"]
|
158 |
+
if cons == dbn:
|
159 |
+
db_ent = db_ent if not return_name else db_ent["name"]
|
160 |
+
match_result.append(db_ent)
|
161 |
+
return match_result
|
162 |
+
|
163 |
+
for db_ent in self.dbs[domain]:
|
164 |
+
match = True
|
165 |
+
for s, v in constraints.items():
|
166 |
+
if s == "name":
|
167 |
+
continue
|
168 |
+
if (
|
169 |
+
s in ["people", "stay"]
|
170 |
+
or (domain == "hotel" and s == "day")
|
171 |
+
or (domain == "restaurant" and s in ["day", "time"])
|
172 |
+
):
|
173 |
+
continue
|
174 |
+
|
175 |
+
skip_case = {
|
176 |
+
"don't care": 1,
|
177 |
+
"do n't care": 1,
|
178 |
+
"dont care": 1,
|
179 |
+
"not mentioned": 1,
|
180 |
+
"dontcare": 1,
|
181 |
+
"": 1,
|
182 |
+
}
|
183 |
+
if skip_case.get(v):
|
184 |
+
continue
|
185 |
+
|
186 |
+
if s not in db_ent:
|
187 |
+
# logging.warning('Searching warning: slot %s not in %s db'%(s, domain))
|
188 |
+
match = False
|
189 |
+
break
|
190 |
+
|
191 |
+
# v = 'guesthouse' if v == 'guest house' else v
|
192 |
+
# v = 'swimmingpool' if v == 'swimming pool' else v
|
193 |
+
v = "yes" if v == "free" else v
|
194 |
+
|
195 |
+
if s in ["arrive", "leave"]:
|
196 |
+
try:
|
197 |
+
h, m = v.split(":") # raise error if time value is not xx:xx format
|
198 |
+
v = int(h) * 60 + int(m)
|
199 |
+
except Exception:
|
200 |
+
match = False
|
201 |
+
break
|
202 |
+
time = int(db_ent[s].split(":")[0]) * 60 + int(db_ent[s].split(":")[1])
|
203 |
+
if s == "arrive" and v > time:
|
204 |
+
match = False
|
205 |
+
if s == "leave" and v < time:
|
206 |
+
match = False
|
207 |
+
else:
|
208 |
+
if exactly_match and v != db_ent[s]:
|
209 |
+
match = False
|
210 |
+
break
|
211 |
+
elif v not in db_ent[s]:
|
212 |
+
match = False
|
213 |
+
break
|
214 |
+
|
215 |
+
if match:
|
216 |
+
match_result.append(db_ent)
|
217 |
+
|
218 |
+
if not return_name:
|
219 |
+
return match_result
|
220 |
+
else:
|
221 |
+
if domain == "train":
|
222 |
+
match_result = [e["id"] for e in match_result]
|
223 |
+
else:
|
224 |
+
match_result = [e["name"] for e in match_result]
|
225 |
+
return match_result
|
226 |
+
|
227 |
+
def querySQL(self, domain, constraints):
|
228 |
+
if not self.sql_dbs:
|
229 |
+
for dom in db_domains:
|
230 |
+
db = "db/{}-dbase.db".format(dom)
|
231 |
+
conn = sqlite3.connect(db)
|
232 |
+
c = conn.cursor()
|
233 |
+
self.sql_dbs[dom] = c
|
234 |
+
|
235 |
+
sql_query = "select * from {}".format(domain)
|
236 |
+
|
237 |
+
flag = True
|
238 |
+
for key, val in constraints.items():
|
239 |
+
if (
|
240 |
+
val == ""
|
241 |
+
or val == "dontcare"
|
242 |
+
or val == "not mentioned"
|
243 |
+
or val == "don't care"
|
244 |
+
or val == "dont care"
|
245 |
+
or val == "do n't care"
|
246 |
+
):
|
247 |
+
pass
|
248 |
+
else:
|
249 |
+
if flag:
|
250 |
+
sql_query += " where "
|
251 |
+
val2 = val.replace("'", "''")
|
252 |
+
# val2 = normalize(val2)
|
253 |
+
if key == "leaveAt":
|
254 |
+
sql_query += r" " + key + " > " + r"'" + val2 + r"'"
|
255 |
+
elif key == "arriveBy":
|
256 |
+
sql_query += r" " + key + " < " + r"'" + val2 + r"'"
|
257 |
+
else:
|
258 |
+
sql_query += r" " + key + "=" + r"'" + val2 + r"'"
|
259 |
+
flag = False
|
260 |
+
else:
|
261 |
+
val2 = val.replace("'", "''")
|
262 |
+
# val2 = normalize(val2)
|
263 |
+
if key == "leaveAt":
|
264 |
+
sql_query += r" and " + key + " > " + r"'" + val2 + r"'"
|
265 |
+
elif key == "arriveBy":
|
266 |
+
sql_query += r" and " + key + " < " + r"'" + val2 + r"'"
|
267 |
+
else:
|
268 |
+
sql_query += r" and " + key + "=" + r"'" + val2 + r"'"
|
269 |
+
|
270 |
+
try: # "select * from attraction where name = 'queens college'"
|
271 |
+
print(sql_query)
|
272 |
+
return self.sql_dbs[domain].execute(sql_query).fetchall()
|
273 |
+
except Exception:
|
274 |
+
return [] # TODO test it
|
275 |
+
|
276 |
+
|
277 |
+
if __name__ == "__main__":
|
278 |
+
dbPATHs = {
|
279 |
+
"attraction": "db/attraction_db_processed.json",
|
280 |
+
"hospital": "db/hospital_db_processed.json",
|
281 |
+
"hotel": "db/hotel_db_processed.json",
|
282 |
+
"police": "db/police_db_processed.json",
|
283 |
+
"restaurant": "db/restaurant_db_processed.json",
|
284 |
+
"taxi": "db/taxi_db_processed.json",
|
285 |
+
"train": "db/train_db_processed.json",
|
286 |
+
}
|
287 |
+
db = MultiWozDB(dbPATHs)
|
288 |
+
while True:
|
289 |
+
constraints = {}
|
290 |
+
inp = input("input belief state in fomat: domain-slot1=value1;slot2=value2...\n")
|
291 |
+
domain, cons = inp.split("-")
|
292 |
+
for sv in cons.split(";"):
|
293 |
+
s, v = sv.split("=")
|
294 |
+
constraints[s] = v
|
295 |
+
# res = db.querySQL(domain, constraints)
|
296 |
+
res = db.queryJsons(domain, constraints, return_name=True)
|
297 |
+
report = []
|
298 |
+
reidx = {
|
299 |
+
"hotel": 8,
|
300 |
+
"restaurant": 6,
|
301 |
+
"attraction": 5,
|
302 |
+
"train": 1,
|
303 |
+
}
|
304 |
+
# for ent in res:
|
305 |
+
# if reidx.get(domain):
|
306 |
+
# report.append(ent[reidx[domain]])
|
307 |
+
# for ent in res:
|
308 |
+
# if 'name' in ent:
|
309 |
+
# report.append(ent['name'])
|
310 |
+
# if 'trainid' in ent:
|
311 |
+
# report.append(ent['trainid'])
|
312 |
+
print(constraints)
|
313 |
+
print(res)
|
314 |
+
print("count:", len(res), "\nnames:", report)
|
src/crazyneuraluser/UBAR_code/eval.py
ADDED
@@ -0,0 +1,932 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import copy
|
2 |
+
import logging
|
3 |
+
import math
|
4 |
+
from collections import Counter, OrderedDict
|
5 |
+
|
6 |
+
from nltk.util import ngrams
|
7 |
+
|
8 |
+
from crazyneuraluser.UBAR_code import ontology
|
9 |
+
from crazyneuraluser.UBAR_code.clean_dataset import clean_slot_values
|
10 |
+
from crazyneuraluser.UBAR_code.config import global_config as cfg
|
11 |
+
|
12 |
+
|
13 |
+
class BLEUScorer(object):
|
14 |
+
# BLEU score calculator via GentScorer interface
|
15 |
+
# it calculates the BLEU-4 by taking the entire corpus in
|
16 |
+
# Calulate based multiple candidates against multiple references
|
17 |
+
def __init__(self):
|
18 |
+
pass
|
19 |
+
|
20 |
+
def score(self, parallel_corpus):
|
21 |
+
|
22 |
+
# containers
|
23 |
+
count = [0, 0, 0, 0]
|
24 |
+
clip_count = [0, 0, 0, 0]
|
25 |
+
r = 0
|
26 |
+
c = 0
|
27 |
+
weights = [0.25, 0.25, 0.25, 0.25]
|
28 |
+
|
29 |
+
# accumulate ngram statistics
|
30 |
+
for hyps, refs in parallel_corpus:
|
31 |
+
hyps = [hyp.split() for hyp in hyps]
|
32 |
+
refs = [ref.split() for ref in refs]
|
33 |
+
for hyp in hyps:
|
34 |
+
|
35 |
+
for i in range(4):
|
36 |
+
# accumulate ngram counts
|
37 |
+
hypcnts = Counter(ngrams(hyp, i + 1))
|
38 |
+
cnt = sum(hypcnts.values())
|
39 |
+
count[i] += cnt
|
40 |
+
|
41 |
+
# compute clipped counts
|
42 |
+
max_counts = {}
|
43 |
+
for ref in refs:
|
44 |
+
refcnts = Counter(ngrams(ref, i + 1))
|
45 |
+
for ng in hypcnts:
|
46 |
+
max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng])
|
47 |
+
clipcnt = dict(
|
48 |
+
(ng, min(count, max_counts[ng]))
|
49 |
+
for ng, count in hypcnts.items()
|
50 |
+
)
|
51 |
+
clip_count[i] += sum(clipcnt.values())
|
52 |
+
|
53 |
+
# accumulate r & c
|
54 |
+
bestmatch = [1000, 1000]
|
55 |
+
for ref in refs:
|
56 |
+
if bestmatch[0] == 0:
|
57 |
+
break
|
58 |
+
diff = abs(len(ref) - len(hyp))
|
59 |
+
if diff < bestmatch[0]:
|
60 |
+
bestmatch[0] = diff
|
61 |
+
bestmatch[1] = len(ref)
|
62 |
+
r += bestmatch[1]
|
63 |
+
c += len(hyp)
|
64 |
+
|
65 |
+
# computing bleu score
|
66 |
+
p0 = 1e-7
|
67 |
+
bp = 1 if c > r else math.exp(1 - float(r) / float(c))
|
68 |
+
p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 for i in range(4)]
|
69 |
+
s = math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n)
|
70 |
+
bleu = bp * math.exp(s)
|
71 |
+
return bleu * 100
|
72 |
+
|
73 |
+
|
74 |
+
class MultiWozEvaluator(object):
|
75 |
+
def __init__(self, reader):
|
76 |
+
self.reader = reader
|
77 |
+
self.domains = ontology.all_domains
|
78 |
+
self.domain_files = self.reader.domain_files
|
79 |
+
self.all_data = self.reader.data
|
80 |
+
self.test_data = self.reader.test
|
81 |
+
|
82 |
+
self.bleu_scorer = BLEUScorer()
|
83 |
+
|
84 |
+
self.all_info_slot = []
|
85 |
+
for d, s_list in ontology.informable_slots.items():
|
86 |
+
for s in s_list:
|
87 |
+
self.all_info_slot.append(d + "-" + s)
|
88 |
+
|
89 |
+
# only evaluate these slots for dialog success
|
90 |
+
self.requestables = ["phone", "address", "postcode", "reference", "id"]
|
91 |
+
|
92 |
+
def pack_dial(self, data):
|
93 |
+
dials = {}
|
94 |
+
for turn in data:
|
95 |
+
dial_id = turn["dial_id"]
|
96 |
+
if dial_id not in dials:
|
97 |
+
dials[dial_id] = []
|
98 |
+
dials[dial_id].append(turn)
|
99 |
+
return dials
|
100 |
+
|
101 |
+
def run_metrics(self, data):
|
102 |
+
if "all" in cfg.exp_domains:
|
103 |
+
metric_results = []
|
104 |
+
metric_result = self._get_metric_results(data)
|
105 |
+
metric_results.append(metric_result)
|
106 |
+
|
107 |
+
if cfg.eval_per_domain:
|
108 |
+
# all domain experiments, sub domain evaluation
|
109 |
+
domains = [d + "_single" for d in ontology.all_domains]
|
110 |
+
domains = domains + [
|
111 |
+
"restaurant_train",
|
112 |
+
"restaurant_hotel",
|
113 |
+
"restaurant_attraction",
|
114 |
+
"hotel_train",
|
115 |
+
"hotel_attraction",
|
116 |
+
"attraction_train",
|
117 |
+
"restaurant_hotel_taxi",
|
118 |
+
"restaurant_attraction_taxi",
|
119 |
+
"hotel_attraction_taxi",
|
120 |
+
]
|
121 |
+
for domain in domains:
|
122 |
+
file_list = self.domain_files.get(domain, [])
|
123 |
+
if not file_list:
|
124 |
+
print("No sub domain [%s]" % domain)
|
125 |
+
metric_result = self._get_metric_results(data, domain, file_list)
|
126 |
+
if metric_result:
|
127 |
+
metric_results.append(metric_result)
|
128 |
+
|
129 |
+
else:
|
130 |
+
# sub domain experiments
|
131 |
+
metric_results = []
|
132 |
+
for domain, file_list in self.domain_files.items():
|
133 |
+
if domain not in cfg.exp_domains:
|
134 |
+
continue
|
135 |
+
metric_result = self._get_metric_results(data, domain, file_list)
|
136 |
+
if metric_result:
|
137 |
+
metric_results.append(metric_result)
|
138 |
+
|
139 |
+
return metric_results
|
140 |
+
|
141 |
+
def validation_metric(self, data):
|
142 |
+
bleu = self.bleu_metric(data)
|
143 |
+
# accu_single_dom, accu_multi_dom, multi_dom_num = self.domain_eval(data)
|
144 |
+
success, match, req_offer_counts, dial_num = self.context_to_response_eval(
|
145 |
+
data, same_eval_as_cambridge=cfg.same_eval_as_cambridge
|
146 |
+
)
|
147 |
+
return bleu, success, match
|
148 |
+
|
149 |
+
def _get_metric_results(self, data, domain="all", file_list=None):
|
150 |
+
metric_result = {"domain": domain}
|
151 |
+
bleu = self.bleu_metric(data, file_list)
|
152 |
+
if cfg.bspn_mode == "bspn" or cfg.enable_dst:
|
153 |
+
(
|
154 |
+
jg,
|
155 |
+
slot_f1,
|
156 |
+
slot_acc,
|
157 |
+
slot_cnt,
|
158 |
+
slot_corr,
|
159 |
+
) = self.dialog_state_tracking_eval(data, file_list)
|
160 |
+
jg_nn, sf1_nn, sac_nn, _, _ = self.dialog_state_tracking_eval(
|
161 |
+
data, file_list, no_name=True, no_book=False
|
162 |
+
)
|
163 |
+
jg_nb, sf1_nb, sac_nb, _, _ = self.dialog_state_tracking_eval(
|
164 |
+
data, file_list, no_name=False, no_book=True
|
165 |
+
)
|
166 |
+
jg_nnnb, sf1_nnnb, sac_nnnb, _, _ = self.dialog_state_tracking_eval(
|
167 |
+
data, file_list, no_name=True, no_book=True
|
168 |
+
)
|
169 |
+
metric_result.update(
|
170 |
+
{"joint_goal": jg, "slot_acc": slot_acc, "slot_f1": slot_f1}
|
171 |
+
)
|
172 |
+
if cfg.bspn_mode == "bsdx":
|
173 |
+
(
|
174 |
+
jg_,
|
175 |
+
slot_f1_,
|
176 |
+
slot_acc_,
|
177 |
+
slot_cnt,
|
178 |
+
slot_corr,
|
179 |
+
) = self.dialog_state_tracking_eval(data, file_list, bspn_mode="bsdx")
|
180 |
+
jg_nn_, sf1_nn_, sac_nn_, _, _ = self.dialog_state_tracking_eval(
|
181 |
+
data, file_list, bspn_mode="bsdx", no_name=True, no_book=False
|
182 |
+
)
|
183 |
+
metric_result.update(
|
184 |
+
{
|
185 |
+
"joint_goal_delex": jg_,
|
186 |
+
"slot_acc_delex": slot_acc_,
|
187 |
+
"slot_f1_delex": slot_f1_,
|
188 |
+
}
|
189 |
+
)
|
190 |
+
|
191 |
+
info_slots_acc = {}
|
192 |
+
for slot in slot_cnt:
|
193 |
+
correct = slot_corr.get(slot, 0)
|
194 |
+
info_slots_acc[slot] = correct / slot_cnt[slot] * 100
|
195 |
+
info_slots_acc = OrderedDict(sorted(info_slots_acc.items(), key=lambda x: x[1]))
|
196 |
+
|
197 |
+
act_f1 = self.aspn_eval(data, file_list)
|
198 |
+
avg_act_num, avg_diverse_score = self.multi_act_eval(data, file_list)
|
199 |
+
accu_single_dom, accu_multi_dom, multi_dom_num = self.domain_eval(
|
200 |
+
data, file_list
|
201 |
+
)
|
202 |
+
|
203 |
+
success, match, req_offer_counts, dial_num = self.context_to_response_eval(
|
204 |
+
data, file_list, same_eval_as_cambridge=cfg.same_eval_as_cambridge
|
205 |
+
)
|
206 |
+
req_slots_acc = {}
|
207 |
+
for req in self.requestables:
|
208 |
+
acc = req_offer_counts[req + "_offer"] / (
|
209 |
+
req_offer_counts[req + "_total"] + 1e-10
|
210 |
+
)
|
211 |
+
req_slots_acc[req] = acc * 100
|
212 |
+
req_slots_acc = OrderedDict(sorted(req_slots_acc.items(), key=lambda x: x[1]))
|
213 |
+
|
214 |
+
if dial_num:
|
215 |
+
metric_result.update(
|
216 |
+
{
|
217 |
+
"act_f1": act_f1,
|
218 |
+
"success": success,
|
219 |
+
"match": match,
|
220 |
+
"bleu": bleu,
|
221 |
+
"req_slots_acc": req_slots_acc,
|
222 |
+
"info_slots_acc": info_slots_acc,
|
223 |
+
"dial_num": dial_num,
|
224 |
+
"accu_single_dom": accu_single_dom,
|
225 |
+
"accu_multi_dom": accu_multi_dom,
|
226 |
+
"avg_act_num": avg_act_num,
|
227 |
+
"avg_diverse_score": avg_diverse_score,
|
228 |
+
}
|
229 |
+
)
|
230 |
+
if domain == "all":
|
231 |
+
logging.info(
|
232 |
+
"-------------------------- All DOMAINS --------------------------"
|
233 |
+
)
|
234 |
+
else:
|
235 |
+
logging.info(
|
236 |
+
"-------------------------- %s (# %d) -------------------------- "
|
237 |
+
% (domain.upper(), dial_num)
|
238 |
+
)
|
239 |
+
if cfg.bspn_mode == "bspn" or cfg.enable_dst:
|
240 |
+
logging.info(
|
241 |
+
"[DST] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f act f1: %2.1f"
|
242 |
+
% (jg, slot_acc, slot_f1, act_f1)
|
243 |
+
)
|
244 |
+
logging.info(
|
245 |
+
"[DST] [not eval name slots] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f"
|
246 |
+
% (jg_nn, sac_nn, sf1_nn)
|
247 |
+
)
|
248 |
+
logging.info(
|
249 |
+
"[DST] [not eval book slots] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f"
|
250 |
+
% (jg_nb, sac_nb, sf1_nb)
|
251 |
+
)
|
252 |
+
logging.info(
|
253 |
+
"[DST] [not eval name & book slots] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f"
|
254 |
+
% (jg_nnnb, sac_nnnb, sf1_nnnb)
|
255 |
+
)
|
256 |
+
if cfg.bspn_mode == "bsdx":
|
257 |
+
logging.info(
|
258 |
+
"[BDX] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f act f1: %2.1f"
|
259 |
+
% (jg_, slot_acc_, slot_f1_, act_f1)
|
260 |
+
)
|
261 |
+
logging.info(
|
262 |
+
"[BDX] [not eval name slots] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f"
|
263 |
+
% (jg_nn_, sac_nn_, sf1_nn_)
|
264 |
+
)
|
265 |
+
logging.info(
|
266 |
+
"[CTR] match: %2.1f success: %2.1f bleu: %2.1f"
|
267 |
+
% (match, success, bleu)
|
268 |
+
)
|
269 |
+
logging.info(
|
270 |
+
"[CTR] "
|
271 |
+
+ "; ".join(
|
272 |
+
["%s: %2.1f" % (req, acc) for req, acc in req_slots_acc.items()]
|
273 |
+
)
|
274 |
+
)
|
275 |
+
logging.info(
|
276 |
+
"[DOM] accuracy: single %2.1f / multi: %2.1f (%d)"
|
277 |
+
% (accu_single_dom, accu_multi_dom, multi_dom_num)
|
278 |
+
)
|
279 |
+
if self.reader.multi_acts_record is not None:
|
280 |
+
logging.info(
|
281 |
+
"[MA] avg acts num %2.1f avg slots num: %2.1f "
|
282 |
+
% (avg_act_num, avg_diverse_score)
|
283 |
+
)
|
284 |
+
return metric_result
|
285 |
+
else:
|
286 |
+
return None
|
287 |
+
|
288 |
+
def bleu_metric(self, data, eval_dial_list=None):
|
289 |
+
gen, truth = [], []
|
290 |
+
for row in data:
|
291 |
+
if eval_dial_list and row["dial_id"] + ".json" not in eval_dial_list:
|
292 |
+
continue
|
293 |
+
gen.append(row["resp_gen"])
|
294 |
+
truth.append(row["resp"])
|
295 |
+
wrap_generated = [[_] for _ in gen]
|
296 |
+
wrap_truth = [[_] for _ in truth]
|
297 |
+
if gen and truth:
|
298 |
+
sc = self.bleu_scorer.score(zip(wrap_generated, wrap_truth))
|
299 |
+
else:
|
300 |
+
sc = 0.0
|
301 |
+
return sc
|
302 |
+
|
303 |
+
def value_similar(self, a, b):
|
304 |
+
return True if a == b else False
|
305 |
+
|
306 |
+
# the value equal condition used in "Sequicity" is too loose
|
307 |
+
if (
|
308 |
+
a in b
|
309 |
+
or b in a
|
310 |
+
or a.split()[0] == b.split()[0]
|
311 |
+
or a.split()[-1] == b.split()[-1]
|
312 |
+
):
|
313 |
+
return True
|
314 |
+
return False
|
315 |
+
|
316 |
+
def _bspn_to_dict(self, bspn, no_name=False, no_book=False, bspn_mode="bspn"):
|
317 |
+
constraint_dict = self.reader.bspan_to_constraint_dict(
|
318 |
+
bspn, bspn_mode=bspn_mode
|
319 |
+
)
|
320 |
+
constraint_dict_flat = {}
|
321 |
+
for domain, cons in constraint_dict.items():
|
322 |
+
for s, v in cons.items():
|
323 |
+
key = domain + "-" + s
|
324 |
+
if no_name and s == "name":
|
325 |
+
continue
|
326 |
+
if no_book:
|
327 |
+
if s in ["people", "stay"] or key in [
|
328 |
+
"hotel-day",
|
329 |
+
"restaurant-day",
|
330 |
+
"restaurant-time",
|
331 |
+
]:
|
332 |
+
continue
|
333 |
+
constraint_dict_flat[key] = v
|
334 |
+
return constraint_dict_flat
|
335 |
+
|
336 |
+
def _constraint_compare(
|
337 |
+
self, truth_cons, gen_cons, slot_appear_num=None, slot_correct_num=None
|
338 |
+
):
|
339 |
+
tp, fp, fn = 0, 0, 0
|
340 |
+
false_slot = []
|
341 |
+
for slot in gen_cons:
|
342 |
+
v_gen = gen_cons[slot]
|
343 |
+
if slot in truth_cons and self.value_similar(
|
344 |
+
v_gen, truth_cons[slot]
|
345 |
+
): # v_truth = truth_cons[slot]
|
346 |
+
tp += 1
|
347 |
+
if slot_correct_num is not None:
|
348 |
+
slot_correct_num[slot] = (
|
349 |
+
1
|
350 |
+
if not slot_correct_num.get(slot)
|
351 |
+
else slot_correct_num.get(slot) + 1
|
352 |
+
)
|
353 |
+
else:
|
354 |
+
fp += 1
|
355 |
+
false_slot.append(slot)
|
356 |
+
for slot in truth_cons:
|
357 |
+
v_truth = truth_cons[slot]
|
358 |
+
if slot_appear_num is not None:
|
359 |
+
slot_appear_num[slot] = (
|
360 |
+
1
|
361 |
+
if not slot_appear_num.get(slot)
|
362 |
+
else slot_appear_num.get(slot) + 1
|
363 |
+
)
|
364 |
+
if slot not in gen_cons or not self.value_similar(v_truth, gen_cons[slot]):
|
365 |
+
fn += 1
|
366 |
+
false_slot.append(slot)
|
367 |
+
acc = len(self.all_info_slot) - fp - fn
|
368 |
+
return tp, fp, fn, acc, list(set(false_slot))
|
369 |
+
|
370 |
+
def domain_eval(self, data, eval_dial_list=None):
|
371 |
+
dials = self.pack_dial(data)
|
372 |
+
corr_single, total_single, corr_multi, total_multi = 0, 0, 0, 0
|
373 |
+
|
374 |
+
dial_num = 0
|
375 |
+
for dial_id in dials:
|
376 |
+
if eval_dial_list and dial_id + ".json" not in eval_dial_list:
|
377 |
+
continue
|
378 |
+
dial_num += 1
|
379 |
+
dial = dials[dial_id]
|
380 |
+
wrong_pred = []
|
381 |
+
|
382 |
+
prev_constraint_dict = {}
|
383 |
+
prev_turn_domain = ["general"]
|
384 |
+
|
385 |
+
for turn_num, turn in enumerate(dial):
|
386 |
+
if turn_num == 0:
|
387 |
+
continue
|
388 |
+
true_domains = self.reader.dspan_to_domain(turn["dspn"])
|
389 |
+
if cfg.enable_dspn:
|
390 |
+
pred_domains = self.reader.dspan_to_domain(turn["dspn_gen"])
|
391 |
+
else:
|
392 |
+
turn_dom_bs = []
|
393 |
+
if (
|
394 |
+
cfg.enable_bspn
|
395 |
+
and not cfg.use_true_bspn_for_ctr_eval
|
396 |
+
and (cfg.bspn_mode == "bspn" or cfg.enable_dst)
|
397 |
+
):
|
398 |
+
constraint_dict = self.reader.bspan_to_constraint_dict(
|
399 |
+
turn["bspn_gen"]
|
400 |
+
)
|
401 |
+
else:
|
402 |
+
constraint_dict = self.reader.bspan_to_constraint_dict(
|
403 |
+
turn["bspn"]
|
404 |
+
)
|
405 |
+
for domain in constraint_dict:
|
406 |
+
if domain not in prev_constraint_dict:
|
407 |
+
turn_dom_bs.append(domain)
|
408 |
+
elif prev_constraint_dict[domain] != constraint_dict[domain]:
|
409 |
+
turn_dom_bs.append(domain)
|
410 |
+
aspn = "aspn" if not cfg.enable_aspn else "aspn_gen"
|
411 |
+
turn_dom_da = []
|
412 |
+
for a in turn[aspn].split():
|
413 |
+
if a[1:-1] in ontology.all_domains + ["general"]:
|
414 |
+
turn_dom_da.append(a[1:-1])
|
415 |
+
|
416 |
+
# get turn domain
|
417 |
+
turn_domain = turn_dom_bs
|
418 |
+
for dom in turn_dom_da:
|
419 |
+
if dom != "booking" and dom not in turn_domain:
|
420 |
+
turn_domain.append(dom)
|
421 |
+
if not turn_domain:
|
422 |
+
turn_domain = prev_turn_domain
|
423 |
+
if len(turn_domain) == 2 and "general" in turn_domain:
|
424 |
+
turn_domain.remove("general")
|
425 |
+
if len(turn_domain) == 2:
|
426 |
+
if (
|
427 |
+
len(prev_turn_domain) == 1
|
428 |
+
and prev_turn_domain[0] == turn_domain[1]
|
429 |
+
):
|
430 |
+
turn_domain = turn_domain[::-1]
|
431 |
+
prev_turn_domain = copy.deepcopy(turn_domain)
|
432 |
+
prev_constraint_dict = copy.deepcopy(constraint_dict)
|
433 |
+
|
434 |
+
turn["dspn_gen"] = " ".join(["[" + d + "]" for d in turn_domain])
|
435 |
+
pred_domains = {}
|
436 |
+
for d in turn_domain:
|
437 |
+
pred_domains["[" + d + "]"] = 1
|
438 |
+
|
439 |
+
if len(true_domains) == 1:
|
440 |
+
total_single += 1
|
441 |
+
if pred_domains == true_domains:
|
442 |
+
corr_single += 1
|
443 |
+
else:
|
444 |
+
wrong_pred.append(str(turn["turn_num"]))
|
445 |
+
turn["wrong_domain"] = "x"
|
446 |
+
else:
|
447 |
+
total_multi += 1
|
448 |
+
if pred_domains == true_domains:
|
449 |
+
corr_multi += 1
|
450 |
+
else:
|
451 |
+
wrong_pred.append(str(turn["turn_num"]))
|
452 |
+
turn["wrong_domain"] = "x"
|
453 |
+
|
454 |
+
# dialog inform metric record
|
455 |
+
dial[0]["wrong_domain"] = " ".join(wrong_pred)
|
456 |
+
accu_single = corr_single / (total_single + 1e-10)
|
457 |
+
accu_multi = corr_multi / (total_multi + 1e-10)
|
458 |
+
return accu_single * 100, accu_multi * 100, total_multi
|
459 |
+
|
460 |
+
def dialog_state_tracking_eval(
|
461 |
+
self, data, eval_dial_list=None, bspn_mode="bspn", no_name=False, no_book=False
|
462 |
+
):
|
463 |
+
dials = self.pack_dial(data)
|
464 |
+
total_turn, joint_match, total_tp, total_fp, total_fn, total_acc = (
|
465 |
+
0,
|
466 |
+
0,
|
467 |
+
0,
|
468 |
+
0,
|
469 |
+
0,
|
470 |
+
0,
|
471 |
+
)
|
472 |
+
slot_appear_num, slot_correct_num = {}, {}
|
473 |
+
dial_num = 0
|
474 |
+
for dial_id in dials:
|
475 |
+
if eval_dial_list and dial_id + ".json" not in eval_dial_list:
|
476 |
+
continue
|
477 |
+
dial_num += 1
|
478 |
+
dial = dials[dial_id]
|
479 |
+
missed_jg_turn_id = []
|
480 |
+
for turn_num, turn in enumerate(dial):
|
481 |
+
if turn_num == 0:
|
482 |
+
continue
|
483 |
+
gen_cons = self._bspn_to_dict(
|
484 |
+
turn[bspn_mode + "_gen"],
|
485 |
+
no_name=no_name,
|
486 |
+
no_book=no_book,
|
487 |
+
bspn_mode=bspn_mode,
|
488 |
+
)
|
489 |
+
truth_cons = self._bspn_to_dict(
|
490 |
+
turn[bspn_mode],
|
491 |
+
no_name=no_name,
|
492 |
+
no_book=no_book,
|
493 |
+
bspn_mode=bspn_mode,
|
494 |
+
)
|
495 |
+
|
496 |
+
if truth_cons == gen_cons:
|
497 |
+
joint_match += 1
|
498 |
+
else:
|
499 |
+
missed_jg_turn_id.append(str(turn["turn_num"]))
|
500 |
+
|
501 |
+
if eval_dial_list is None:
|
502 |
+
tp, fp, fn, acc, false_slots = self._constraint_compare(
|
503 |
+
truth_cons, gen_cons, slot_appear_num, slot_correct_num
|
504 |
+
)
|
505 |
+
else:
|
506 |
+
tp, fp, fn, acc, false_slots = self._constraint_compare(
|
507 |
+
truth_cons,
|
508 |
+
gen_cons,
|
509 |
+
)
|
510 |
+
|
511 |
+
total_tp += tp
|
512 |
+
total_fp += fp
|
513 |
+
total_fn += fn
|
514 |
+
total_acc += acc
|
515 |
+
total_turn += 1
|
516 |
+
if not no_name and not no_book:
|
517 |
+
turn["wrong_inform"] = "; ".join(
|
518 |
+
false_slots
|
519 |
+
) # turn inform metric record
|
520 |
+
|
521 |
+
# dialog inform metric record
|
522 |
+
if not no_name and not no_book:
|
523 |
+
dial[0]["wrong_inform"] = " ".join(missed_jg_turn_id)
|
524 |
+
|
525 |
+
precision = total_tp / (total_tp + total_fp + 1e-10)
|
526 |
+
recall = total_tp / (total_tp + total_fn + 1e-10)
|
527 |
+
f1 = 2 * precision * recall / (precision + recall + 1e-10) * 100
|
528 |
+
accuracy = total_acc / (total_turn * len(self.all_info_slot) + 1e-10) * 100
|
529 |
+
joint_goal = joint_match / (total_turn + 1e-10) * 100
|
530 |
+
|
531 |
+
return joint_goal, f1, accuracy, slot_appear_num, slot_correct_num
|
532 |
+
|
533 |
+
def aspn_eval(self, data, eval_dial_list=None):
|
534 |
+
def _get_tp_fp_fn(label_list, pred_list):
|
535 |
+
tp = len([t for t in pred_list if t in label_list])
|
536 |
+
fp = max(0, len(pred_list) - tp)
|
537 |
+
fn = max(0, len(label_list) - tp)
|
538 |
+
return tp, fp, fn
|
539 |
+
|
540 |
+
dials = self.pack_dial(data)
|
541 |
+
total_tp, total_fp, total_fn = 0, 0, 0
|
542 |
+
|
543 |
+
dial_num = 0
|
544 |
+
for dial_id in dials:
|
545 |
+
if eval_dial_list and dial_id + ".json" not in eval_dial_list:
|
546 |
+
continue
|
547 |
+
dial_num += 1
|
548 |
+
dial = dials[dial_id]
|
549 |
+
wrong_act = []
|
550 |
+
for turn_num, turn in enumerate(dial):
|
551 |
+
if turn_num == 0:
|
552 |
+
continue
|
553 |
+
if cfg.same_eval_act_f1_as_hdsa:
|
554 |
+
pred_acts, true_acts = {}, {}
|
555 |
+
for t in turn["aspn_gen"]:
|
556 |
+
pred_acts[t] = 1
|
557 |
+
for t in turn["aspn"]:
|
558 |
+
true_acts[t] = 1
|
559 |
+
tp, fp, fn = _get_tp_fp_fn(true_acts, pred_acts)
|
560 |
+
else:
|
561 |
+
pred_acts = self.reader.aspan_to_act_list(turn["aspn_gen"])
|
562 |
+
true_acts = self.reader.aspan_to_act_list(turn["aspn"])
|
563 |
+
tp, fp, fn = _get_tp_fp_fn(true_acts, pred_acts)
|
564 |
+
if fp + fn != 0:
|
565 |
+
wrong_act.append(str(turn["turn_num"]))
|
566 |
+
turn["wrong_act"] = "x"
|
567 |
+
|
568 |
+
total_tp += tp
|
569 |
+
total_fp += fp
|
570 |
+
total_fn += fn
|
571 |
+
|
572 |
+
dial[0]["wrong_act"] = " ".join(wrong_act)
|
573 |
+
precision = total_tp / (total_tp + total_fp + 1e-10)
|
574 |
+
recall = total_tp / (total_tp + total_fn + 1e-10)
|
575 |
+
f1 = 2 * precision * recall / (precision + recall + 1e-10)
|
576 |
+
|
577 |
+
return f1 * 100
|
578 |
+
|
579 |
+
def multi_act_eval(self, data, eval_dial_list=None):
|
580 |
+
|
581 |
+
dials = self.pack_dial(data)
|
582 |
+
total_act_num, total_slot_num = 0, 0
|
583 |
+
|
584 |
+
dial_num = 0
|
585 |
+
turn_count = 0
|
586 |
+
for dial_id in dials:
|
587 |
+
if eval_dial_list and dial_id + ".json" not in eval_dial_list:
|
588 |
+
continue
|
589 |
+
dial_num += 1
|
590 |
+
dial = dials[dial_id]
|
591 |
+
for turn_num, turn in enumerate(dial):
|
592 |
+
if turn_num == 0:
|
593 |
+
continue
|
594 |
+
target = (
|
595 |
+
turn["multi_act_gen"]
|
596 |
+
if self.reader.multi_acts_record is not None
|
597 |
+
else turn["aspn_gen"]
|
598 |
+
)
|
599 |
+
|
600 |
+
# diversity
|
601 |
+
act_collect, slot_collect = {}, {}
|
602 |
+
act_type_collect = {}
|
603 |
+
slot_score = 0
|
604 |
+
for act_str in target.split(" | "):
|
605 |
+
pred_acts = self.reader.aspan_to_act_list(act_str)
|
606 |
+
act_type = ""
|
607 |
+
for act in pred_acts:
|
608 |
+
d, a, s = act.split("-")
|
609 |
+
if d + "-" + a not in act_collect:
|
610 |
+
act_collect[d + "-" + a] = {s: 1}
|
611 |
+
slot_score += 1
|
612 |
+
act_type += d + "-" + a + ";"
|
613 |
+
elif s not in act_collect:
|
614 |
+
act_collect[d + "-" + a][s] = 1
|
615 |
+
slot_score += 1
|
616 |
+
slot_collect[s] = 1
|
617 |
+
act_type_collect[act_type] = 1
|
618 |
+
total_act_num += len(act_collect)
|
619 |
+
total_slot_num += len(slot_collect)
|
620 |
+
turn_count += 1
|
621 |
+
|
622 |
+
total_act_num = total_act_num / (float(turn_count) + 1e-10)
|
623 |
+
total_slot_num = total_slot_num / (float(turn_count) + 1e-10)
|
624 |
+
return total_act_num, total_slot_num
|
625 |
+
|
626 |
+
def context_to_response_eval(
|
627 |
+
self, data, eval_dial_list=None, same_eval_as_cambridge=False
|
628 |
+
):
|
629 |
+
dials = self.pack_dial(data)
|
630 |
+
counts = {}
|
631 |
+
for req in self.requestables:
|
632 |
+
counts[req + "_total"] = 0
|
633 |
+
counts[req + "_offer"] = 0
|
634 |
+
|
635 |
+
dial_num, successes, matches = 0, 0, 0
|
636 |
+
|
637 |
+
for dial_id in dials:
|
638 |
+
if eval_dial_list and dial_id + ".json" not in eval_dial_list:
|
639 |
+
continue
|
640 |
+
dial = dials[dial_id]
|
641 |
+
reqs = {}
|
642 |
+
goal = {}
|
643 |
+
if ".json" not in dial_id and ".json" in list(self.all_data.keys())[0]:
|
644 |
+
dial_id = dial_id + ".json"
|
645 |
+
for domain in ontology.all_domains:
|
646 |
+
if self.all_data[dial_id]["goal"].get(domain):
|
647 |
+
true_goal = self.all_data[dial_id]["goal"]
|
648 |
+
goal = self._parseGoal(goal, true_goal, domain)
|
649 |
+
# print(goal)
|
650 |
+
for domain in goal.keys():
|
651 |
+
reqs[domain] = goal[domain]["requestable"]
|
652 |
+
|
653 |
+
# print('\n',dial_id)
|
654 |
+
success, match, stats, counts = self._evaluateGeneratedDialogue(
|
655 |
+
dial, goal, reqs, counts, same_eval_as_cambridge=same_eval_as_cambridge
|
656 |
+
)
|
657 |
+
|
658 |
+
successes += success
|
659 |
+
matches += match
|
660 |
+
dial_num += 1
|
661 |
+
|
662 |
+
# for domain in gen_stats.keys():
|
663 |
+
# gen_stats[domain][0] += stats[domain][0]
|
664 |
+
# gen_stats[domain][1] += stats[domain][1]
|
665 |
+
# gen_stats[domain][2] += stats[domain][2]
|
666 |
+
|
667 |
+
# if 'SNG' in filename:
|
668 |
+
# for domain in gen_stats.keys():
|
669 |
+
# sng_gen_stats[domain][0] += stats[domain][0]
|
670 |
+
# sng_gen_stats[domain][1] += stats[domain][1]
|
671 |
+
# sng_gen_stats[domain][2] += stats[domain][2]
|
672 |
+
|
673 |
+
# self.logger.info(report)
|
674 |
+
succ_rate = successes / (float(dial_num) + 1e-10) * 100
|
675 |
+
match_rate = matches / (float(dial_num) + 1e-10) * 100
|
676 |
+
return succ_rate, match_rate, counts, dial_num
|
677 |
+
|
678 |
+
def _evaluateGeneratedDialogue(
|
679 |
+
self,
|
680 |
+
dialog,
|
681 |
+
goal,
|
682 |
+
real_requestables,
|
683 |
+
counts,
|
684 |
+
soft_acc=False,
|
685 |
+
same_eval_as_cambridge=False,
|
686 |
+
):
|
687 |
+
"""Evaluates the dialogue created by the model.
|
688 |
+
First we load the user goal of the dialogue, then for each turn
|
689 |
+
generated by the system we look for key-words.
|
690 |
+
For the Inform rate we look whether the entity was proposed.
|
691 |
+
For the Success rate we look for requestables slots"""
|
692 |
+
# for computing corpus success 'id'
|
693 |
+
requestables = self.requestables
|
694 |
+
|
695 |
+
# CHECK IF MATCH HAPPENED
|
696 |
+
provided_requestables = {}
|
697 |
+
venue_offered = {}
|
698 |
+
domains_in_goal = []
|
699 |
+
bspans = {}
|
700 |
+
|
701 |
+
for domain in goal.keys():
|
702 |
+
venue_offered[domain] = []
|
703 |
+
provided_requestables[domain] = []
|
704 |
+
domains_in_goal.append(domain)
|
705 |
+
|
706 |
+
for t, turn in enumerate(dialog):
|
707 |
+
if t == 0:
|
708 |
+
continue
|
709 |
+
sent_t = turn["resp_gen"]
|
710 |
+
# sent_t = turn['resp']
|
711 |
+
for domain in goal.keys():
|
712 |
+
# for computing success
|
713 |
+
if same_eval_as_cambridge:
|
714 |
+
# [restaurant_name], [hotel_name] instead of [value_name]
|
715 |
+
if cfg.use_true_domain_for_ctr_eval:
|
716 |
+
dom_pred = [d[1:-1] for d in turn["dspn"].split()]
|
717 |
+
else:
|
718 |
+
dom_pred = [d[1:-1] for d in turn["dspn_gen"].split()]
|
719 |
+
# else:
|
720 |
+
# raise NotImplementedError('Just use true domain label')
|
721 |
+
if domain not in dom_pred: # fail
|
722 |
+
continue
|
723 |
+
if "[value_name]" in sent_t or "[value_id]" in sent_t:
|
724 |
+
if domain in ["restaurant", "hotel", "attraction", "train"]:
|
725 |
+
# HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION
|
726 |
+
if (
|
727 |
+
not cfg.use_true_curr_bspn
|
728 |
+
and not cfg.use_true_bspn_for_ctr_eval
|
729 |
+
):
|
730 |
+
bspn = turn["bspn_gen"]
|
731 |
+
else:
|
732 |
+
bspn = turn["bspn"]
|
733 |
+
# bspn = turn['bspn']
|
734 |
+
|
735 |
+
constraint_dict = self.reader.bspan_to_constraint_dict(bspn)
|
736 |
+
if constraint_dict.get(domain):
|
737 |
+
venues = self.reader.db.queryJsons(
|
738 |
+
domain, constraint_dict[domain], return_name=True
|
739 |
+
)
|
740 |
+
else:
|
741 |
+
venues = []
|
742 |
+
|
743 |
+
# if venue has changed
|
744 |
+
if len(venue_offered[domain]) == 0 and venues:
|
745 |
+
# venue_offered[domain] = random.sample(venues, 1)
|
746 |
+
venue_offered[domain] = venues
|
747 |
+
bspans[domain] = constraint_dict[domain]
|
748 |
+
else:
|
749 |
+
# flag = False
|
750 |
+
# for ven in venues:
|
751 |
+
# if venue_offered[domain][0] == ven:
|
752 |
+
# flag = True
|
753 |
+
# break
|
754 |
+
# if not flag and venues:
|
755 |
+
flag = False
|
756 |
+
for ven in venues:
|
757 |
+
if ven not in venue_offered[domain]:
|
758 |
+
# if ven not in venue_offered[domain]:
|
759 |
+
flag = True
|
760 |
+
break
|
761 |
+
# if flag and venues:
|
762 |
+
if (
|
763 |
+
flag and venues
|
764 |
+
): # sometimes there are no results so sample won't work
|
765 |
+
# print venues
|
766 |
+
# venue_offered[domain] = random.sample(venues, 1)
|
767 |
+
venue_offered[domain] = venues
|
768 |
+
bspans[domain] = constraint_dict[domain]
|
769 |
+
else: # not limited so we can provide one
|
770 |
+
venue_offered[domain] = "[value_name]"
|
771 |
+
|
772 |
+
# ATTENTION: assumption here - we didn't provide phone or address twice! etc
|
773 |
+
for requestable in requestables:
|
774 |
+
if requestable == "reference":
|
775 |
+
if "[value_reference]" in sent_t:
|
776 |
+
if (
|
777 |
+
"booked" in turn["pointer"] or "ok" in turn["pointer"]
|
778 |
+
): # if pointer was allowing for that?
|
779 |
+
provided_requestables[domain].append("reference")
|
780 |
+
# provided_requestables[domain].append('reference')
|
781 |
+
else:
|
782 |
+
if "[value_" + requestable + "]" in sent_t:
|
783 |
+
provided_requestables[domain].append(requestable)
|
784 |
+
|
785 |
+
# if name was given in the task
|
786 |
+
for domain in goal.keys():
|
787 |
+
# if name was provided for the user, the match is being done automatically
|
788 |
+
if "name" in goal[domain]["informable"]:
|
789 |
+
venue_offered[domain] = "[value_name]"
|
790 |
+
|
791 |
+
# special domains - entity does not need to be provided
|
792 |
+
if domain in ["taxi", "police", "hospital"]:
|
793 |
+
venue_offered[domain] = "[value_name]"
|
794 |
+
|
795 |
+
if domain == "train":
|
796 |
+
if (
|
797 |
+
not venue_offered[domain]
|
798 |
+
and "id" not in goal[domain]["requestable"]
|
799 |
+
):
|
800 |
+
venue_offered[domain] = "[value_name]"
|
801 |
+
|
802 |
+
"""
|
803 |
+
Given all inform and requestable slots
|
804 |
+
we go through each domain from the user goal
|
805 |
+
and check whether right entity was provided and
|
806 |
+
all requestable slots were given to the user.
|
807 |
+
The dialogue is successful if that's the case for all domains.
|
808 |
+
"""
|
809 |
+
# HARD EVAL
|
810 |
+
stats = {
|
811 |
+
"restaurant": [0, 0, 0],
|
812 |
+
"hotel": [0, 0, 0],
|
813 |
+
"attraction": [0, 0, 0],
|
814 |
+
"train": [0, 0, 0],
|
815 |
+
"taxi": [0, 0, 0],
|
816 |
+
"hospital": [0, 0, 0],
|
817 |
+
"police": [0, 0, 0],
|
818 |
+
}
|
819 |
+
|
820 |
+
match = 0
|
821 |
+
success = 0
|
822 |
+
# MATCH
|
823 |
+
for domain in goal.keys():
|
824 |
+
match_stat = 0
|
825 |
+
if domain in ["restaurant", "hotel", "attraction", "train"]:
|
826 |
+
goal_venues = self.reader.db.queryJsons(
|
827 |
+
domain, goal[domain]["informable"], return_name=True
|
828 |
+
)
|
829 |
+
if (
|
830 |
+
type(venue_offered[domain]) is str
|
831 |
+
and "_name" in venue_offered[domain]
|
832 |
+
):
|
833 |
+
match += 1
|
834 |
+
match_stat = 1
|
835 |
+
elif (
|
836 |
+
len(venue_offered[domain]) > 0
|
837 |
+
and len(set(venue_offered[domain]) & set(goal_venues)) > 0
|
838 |
+
):
|
839 |
+
match += 1
|
840 |
+
match_stat = 1
|
841 |
+
else:
|
842 |
+
if "_name]" in venue_offered[domain]:
|
843 |
+
match += 1
|
844 |
+
match_stat = 1
|
845 |
+
|
846 |
+
stats[domain][0] = match_stat
|
847 |
+
stats[domain][2] = 1
|
848 |
+
|
849 |
+
if soft_acc:
|
850 |
+
match = float(match) / len(goal.keys())
|
851 |
+
else:
|
852 |
+
if match == len(goal.keys()):
|
853 |
+
match = 1.0
|
854 |
+
else:
|
855 |
+
match = 0.0
|
856 |
+
|
857 |
+
for domain in domains_in_goal:
|
858 |
+
for request in real_requestables[domain]:
|
859 |
+
counts[request + "_total"] += 1
|
860 |
+
if request in provided_requestables[domain]:
|
861 |
+
counts[request + "_offer"] += 1
|
862 |
+
|
863 |
+
# SUCCESS
|
864 |
+
if match == 1.0:
|
865 |
+
for domain in domains_in_goal:
|
866 |
+
success_stat = 0
|
867 |
+
domain_success = 0
|
868 |
+
if len(real_requestables[domain]) == 0:
|
869 |
+
success += 1
|
870 |
+
success_stat = 1
|
871 |
+
stats[domain][1] = success_stat
|
872 |
+
continue
|
873 |
+
# if values in sentences are super set of requestables
|
874 |
+
# for request in set(provided_requestables[domain]):
|
875 |
+
# if request in real_requestables[domain]:
|
876 |
+
# domain_success += 1
|
877 |
+
for request in real_requestables[domain]:
|
878 |
+
if request in provided_requestables[domain]:
|
879 |
+
domain_success += 1
|
880 |
+
|
881 |
+
# if domain_success >= len(real_requestables[domain]):
|
882 |
+
if domain_success == len(real_requestables[domain]):
|
883 |
+
success += 1
|
884 |
+
success_stat = 1
|
885 |
+
|
886 |
+
stats[domain][1] = success_stat
|
887 |
+
|
888 |
+
# final eval
|
889 |
+
if soft_acc:
|
890 |
+
success = float(success) / len(real_requestables)
|
891 |
+
else:
|
892 |
+
if success >= len(real_requestables):
|
893 |
+
success = 1
|
894 |
+
else:
|
895 |
+
success = 0
|
896 |
+
|
897 |
+
return success, match, stats, counts
|
898 |
+
|
899 |
+
def _parseGoal(self, goal, true_goal, domain):
|
900 |
+
"""Parses user goal into dictionary format."""
|
901 |
+
goal[domain] = {}
|
902 |
+
goal[domain] = {"informable": {}, "requestable": [], "booking": []}
|
903 |
+
if "info" in true_goal[domain]:
|
904 |
+
if domain == "train":
|
905 |
+
# we consider dialogues only where train had to be booked!
|
906 |
+
if "book" in true_goal[domain]:
|
907 |
+
goal[domain]["requestable"].append("reference")
|
908 |
+
if "reqt" in true_goal[domain]:
|
909 |
+
if "id" in true_goal[domain]["reqt"]:
|
910 |
+
goal[domain]["requestable"].append("id")
|
911 |
+
else:
|
912 |
+
if "reqt" in true_goal[domain]:
|
913 |
+
for s in true_goal[domain]["reqt"]: # addtional requests:
|
914 |
+
if s in ["phone", "address", "postcode", "reference", "id"]:
|
915 |
+
# ones that can be easily delexicalized
|
916 |
+
goal[domain]["requestable"].append(s)
|
917 |
+
if "book" in true_goal[domain]:
|
918 |
+
goal[domain]["requestable"].append("reference")
|
919 |
+
|
920 |
+
for s, v in true_goal[domain]["info"].items():
|
921 |
+
s_, v_ = clean_slot_values(domain, s, v)
|
922 |
+
if len(v_.split()) > 1:
|
923 |
+
v_ = " ".join([token.text for token in self.reader.nlp(v_)]).strip()
|
924 |
+
goal[domain]["informable"][s_] = v_
|
925 |
+
|
926 |
+
if "book" in true_goal[domain]:
|
927 |
+
goal[domain]["booking"] = true_goal[domain]["book"]
|
928 |
+
return goal
|
929 |
+
|
930 |
+
|
931 |
+
if __name__ == "__main__":
|
932 |
+
pass
|
src/crazyneuraluser/UBAR_code/ontology.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
all_domains = [
|
2 |
+
"restaurant",
|
3 |
+
"hotel",
|
4 |
+
"attraction",
|
5 |
+
"train",
|
6 |
+
"taxi",
|
7 |
+
"police",
|
8 |
+
"hospital",
|
9 |
+
]
|
10 |
+
db_domains = ["restaurant", "hotel", "attraction", "train"]
|
11 |
+
|
12 |
+
# original slot names in goals (including booking slots)
|
13 |
+
# requestable_slots_in_goals = {
|
14 |
+
# "taxi": ["car type", "phone"],
|
15 |
+
# "police": ["postcode", "address", "phone"],
|
16 |
+
# "hospital": ["address", "phone", "postcode"],
|
17 |
+
# "hotel": ["address", "postcode", "internet", "phone", "parking",
|
18 |
+
# "type", "pricerange", "stars", "area", "reference"],
|
19 |
+
# "attraction": ["entrance fee", "type", "address", "postcode", "phone", "area", "reference"],
|
20 |
+
# "train": ["duration", "leaveat", "price", "arriveby", "id", "reference"],
|
21 |
+
# "restaurant": ["phone", "postcode", "address", "pricerange", "food", "area", "reference"]
|
22 |
+
# }
|
23 |
+
|
24 |
+
# informable_slots_in_goals = {
|
25 |
+
# "taxi": ["leaveat", "destination", "departure", "arriveby"],
|
26 |
+
# "police": [],
|
27 |
+
# "hospital": ["department"],
|
28 |
+
# "hotel": ["type", "parking", "pricerange", "internet", "stay", "day", "people", "area", "stars", "name"],
|
29 |
+
# "attraction": ["area", "type", "name"],
|
30 |
+
# "train": ["destination", "day", "arriveby", "departure", "people", "leaveat"],
|
31 |
+
# "restaurant": ["food", "pricerange", "area", "name", "time", "day", "people"]
|
32 |
+
# }
|
33 |
+
|
34 |
+
normlize_slot_names = {
|
35 |
+
"car type": "car",
|
36 |
+
"entrance fee": "price",
|
37 |
+
"duration": "time",
|
38 |
+
"leaveat": "leave",
|
39 |
+
"arriveby": "arrive",
|
40 |
+
"trainid": "id",
|
41 |
+
}
|
42 |
+
|
43 |
+
requestable_slots = {
|
44 |
+
"taxi": ["car", "phone"],
|
45 |
+
"police": ["postcode", "address", "phone"],
|
46 |
+
"hospital": ["address", "phone", "postcode"],
|
47 |
+
"hotel": [
|
48 |
+
"address",
|
49 |
+
"postcode",
|
50 |
+
"internet",
|
51 |
+
"phone",
|
52 |
+
"parking",
|
53 |
+
"type",
|
54 |
+
"pricerange",
|
55 |
+
"stars",
|
56 |
+
"area",
|
57 |
+
"reference",
|
58 |
+
],
|
59 |
+
"attraction": [
|
60 |
+
"price",
|
61 |
+
"type",
|
62 |
+
"address",
|
63 |
+
"postcode",
|
64 |
+
"phone",
|
65 |
+
"area",
|
66 |
+
"reference",
|
67 |
+
],
|
68 |
+
"train": ["time", "leave", "price", "arrive", "id", "reference"],
|
69 |
+
"restaurant": [
|
70 |
+
"phone",
|
71 |
+
"postcode",
|
72 |
+
"address",
|
73 |
+
"pricerange",
|
74 |
+
"food",
|
75 |
+
"area",
|
76 |
+
"reference",
|
77 |
+
],
|
78 |
+
}
|
79 |
+
all_reqslot = [
|
80 |
+
"car",
|
81 |
+
"address",
|
82 |
+
"postcode",
|
83 |
+
"phone",
|
84 |
+
"internet",
|
85 |
+
"parking",
|
86 |
+
"type",
|
87 |
+
"pricerange",
|
88 |
+
"food",
|
89 |
+
"stars",
|
90 |
+
"area",
|
91 |
+
"reference",
|
92 |
+
"time",
|
93 |
+
"leave",
|
94 |
+
"price",
|
95 |
+
"arrive",
|
96 |
+
"id",
|
97 |
+
]
|
98 |
+
# count: 17
|
99 |
+
|
100 |
+
informable_slots = {
|
101 |
+
"taxi": ["leave", "destination", "departure", "arrive"],
|
102 |
+
"police": [],
|
103 |
+
"hospital": ["department"],
|
104 |
+
"hotel": [
|
105 |
+
"type",
|
106 |
+
"parking",
|
107 |
+
"pricerange",
|
108 |
+
"internet",
|
109 |
+
"stay",
|
110 |
+
"day",
|
111 |
+
"people",
|
112 |
+
"area",
|
113 |
+
"stars",
|
114 |
+
"name",
|
115 |
+
],
|
116 |
+
"attraction": ["area", "type", "name"],
|
117 |
+
"train": ["destination", "day", "arrive", "departure", "people", "leave"],
|
118 |
+
"restaurant": ["food", "pricerange", "area", "name", "time", "day", "people"],
|
119 |
+
}
|
120 |
+
all_infslot = [
|
121 |
+
"type",
|
122 |
+
"parking",
|
123 |
+
"pricerange",
|
124 |
+
"internet",
|
125 |
+
"stay",
|
126 |
+
"day",
|
127 |
+
"people",
|
128 |
+
"area",
|
129 |
+
"stars",
|
130 |
+
"name",
|
131 |
+
"leave",
|
132 |
+
"destination",
|
133 |
+
"departure",
|
134 |
+
"arrive",
|
135 |
+
"department",
|
136 |
+
"food",
|
137 |
+
"time",
|
138 |
+
]
|
139 |
+
# count: 17
|
140 |
+
|
141 |
+
all_slots = all_reqslot + [
|
142 |
+
"stay",
|
143 |
+
"day",
|
144 |
+
"people",
|
145 |
+
"name",
|
146 |
+
"destination",
|
147 |
+
"departure",
|
148 |
+
"department",
|
149 |
+
]
|
150 |
+
get_slot = {}
|
151 |
+
for s in all_slots:
|
152 |
+
get_slot[s] = 1
|
153 |
+
# count: 24
|
154 |
+
|
155 |
+
|
156 |
+
# mapping slots in dialogue act to original goal slot names
|
157 |
+
da_abbr_to_slot_name = {
|
158 |
+
"addr": "address",
|
159 |
+
"fee": "price",
|
160 |
+
"post": "postcode",
|
161 |
+
"ref": "reference",
|
162 |
+
"ticket": "price",
|
163 |
+
"depart": "departure",
|
164 |
+
"dest": "destination",
|
165 |
+
}
|
166 |
+
|
167 |
+
# slot merging: not used currently
|
168 |
+
# slot_name_to_value_token = {
|
169 |
+
# 'entrance fee': 'price',
|
170 |
+
# 'pricerange': 'price',
|
171 |
+
# 'arrive': 'time',
|
172 |
+
# 'leave': 'time',
|
173 |
+
# 'departure': 'name',
|
174 |
+
# 'destination': 'name',
|
175 |
+
# 'stay': 'count',
|
176 |
+
# 'people': 'count',
|
177 |
+
# 'stars': 'count',
|
178 |
+
# }
|
179 |
+
# dialog_act_dom = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital', 'general', 'booking']
|
180 |
+
dialog_acts = {
|
181 |
+
"restaurant": [
|
182 |
+
"inform",
|
183 |
+
"request",
|
184 |
+
"nooffer",
|
185 |
+
"recommend",
|
186 |
+
"select",
|
187 |
+
"offerbook",
|
188 |
+
"offerbooked",
|
189 |
+
"nobook",
|
190 |
+
],
|
191 |
+
"hotel": [
|
192 |
+
"inform",
|
193 |
+
"request",
|
194 |
+
"nooffer",
|
195 |
+
"recommend",
|
196 |
+
"select",
|
197 |
+
"offerbook",
|
198 |
+
"offerbooked",
|
199 |
+
"nobook",
|
200 |
+
],
|
201 |
+
"attraction": ["inform", "request", "nooffer", "recommend", "select"],
|
202 |
+
"train": ["inform", "request", "nooffer", "offerbook", "offerbooked", "select"],
|
203 |
+
"taxi": ["inform", "request"],
|
204 |
+
"police": ["inform", "request"],
|
205 |
+
"hospital": ["inform", "request"],
|
206 |
+
# 'booking': ['book', 'inform', 'nobook', 'request'],
|
207 |
+
"general": ["bye", "greet", "reqmore", "welcome"],
|
208 |
+
}
|
209 |
+
all_acts = []
|
210 |
+
for acts in dialog_acts.values():
|
211 |
+
for act in acts:
|
212 |
+
if act not in all_acts:
|
213 |
+
all_acts.append(act)
|
214 |
+
# print(all_acts)
|
215 |
+
|
216 |
+
dialog_act_params = {
|
217 |
+
"inform": all_slots + ["choice", "open"],
|
218 |
+
"request": all_infslot + ["choice", "price"],
|
219 |
+
"nooffer": all_slots + ["choice"],
|
220 |
+
"recommend": all_reqslot + ["choice", "open"],
|
221 |
+
"select": all_slots + ["choice"],
|
222 |
+
# 'book': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'],
|
223 |
+
"nobook": ["time", "people", "stay", "reference", "day", "name", "choice"],
|
224 |
+
"offerbook": all_slots + ["choice"],
|
225 |
+
"offerbooked": all_slots + ["choice"],
|
226 |
+
"reqmore": [],
|
227 |
+
"welcome": [],
|
228 |
+
"bye": [],
|
229 |
+
"greet": [],
|
230 |
+
}
|
231 |
+
|
232 |
+
# dialog_acts = ['inform', 'request', 'nooffer', 'recommend', 'select', 'book', 'nobook', 'offerbook', 'offerbooked',
|
233 |
+
# 'reqmore', 'welcome', 'bye', 'greet'] # thank
|
234 |
+
dialog_act_all_slots = all_slots + ["choice", "open"]
|
235 |
+
# act_span_vocab = ['['+i+']' for i in dialog_act_dom] + ['['+i+']' for i in dialog_acts] + all_slots
|
236 |
+
|
237 |
+
# value_token_in_resp = ['address', 'name', 'phone', 'postcode', 'area', 'food', 'pricerange', 'id',
|
238 |
+
# 'department', 'place', 'day', 'count', 'car']
|
239 |
+
# count: 12
|
240 |
+
|
241 |
+
|
242 |
+
# special slot tokens in belief span
|
243 |
+
# no need of this, just covert slot to [slot] e.g. pricerange -> [pricerange]
|
244 |
+
slot_name_to_slot_token = {}
|
245 |
+
|
246 |
+
|
247 |
+
# special slot tokens in responses
|
248 |
+
# not use at the momoent
|
249 |
+
slot_name_to_value_token = {
|
250 |
+
# 'entrance fee': '[value_price]',
|
251 |
+
# 'pricerange': '[value_price]',
|
252 |
+
# 'arriveby': '[value_time]',
|
253 |
+
# 'leaveat': '[value_time]',
|
254 |
+
# 'departure': '[value_place]',
|
255 |
+
# 'destination': '[value_place]',
|
256 |
+
# 'stay': 'count',
|
257 |
+
# 'people': 'count'
|
258 |
+
}
|
259 |
+
|
260 |
+
|
261 |
+
db_tokens = [
|
262 |
+
"<sos_db>",
|
263 |
+
"<eos_db>",
|
264 |
+
"[db_nores]",
|
265 |
+
"[db_0]",
|
266 |
+
"[db_1]",
|
267 |
+
"[db_2]",
|
268 |
+
"[db_3]",
|
269 |
+
]
|
270 |
+
|
271 |
+
special_tokens = [
|
272 |
+
"<pad>",
|
273 |
+
"<go_r>",
|
274 |
+
"<unk>",
|
275 |
+
"<go_b>",
|
276 |
+
"<go_a>",
|
277 |
+
"<eos_u>",
|
278 |
+
"<eos_r>",
|
279 |
+
"<eos_b>",
|
280 |
+
"<eos_a>",
|
281 |
+
"<go_d>",
|
282 |
+
"<eos_d>",
|
283 |
+
"<sos_u>",
|
284 |
+
"<sos_r>",
|
285 |
+
"<sos_b>",
|
286 |
+
"<sos_a>",
|
287 |
+
"<sos_d>",
|
288 |
+
] + db_tokens
|
289 |
+
|
290 |
+
eos_tokens = {
|
291 |
+
"user": "<eos_u>",
|
292 |
+
"user_delex": "<eos_u>",
|
293 |
+
"resp": "<eos_r>",
|
294 |
+
"resp_gen": "<eos_r>",
|
295 |
+
"pv_resp": "<eos_r>",
|
296 |
+
"bspn": "<eos_b>",
|
297 |
+
"bspn_gen": "<eos_b>",
|
298 |
+
"pv_bspn": "<eos_b>",
|
299 |
+
"bsdx": "<eos_b>",
|
300 |
+
"bsdx_gen": "<eos_b>",
|
301 |
+
"pv_bsdx": "<eos_b>",
|
302 |
+
"aspn": "<eos_a>",
|
303 |
+
"aspn_gen": "<eos_a>",
|
304 |
+
"pv_aspn": "<eos_a>",
|
305 |
+
"dspn": "<eos_d>",
|
306 |
+
"dspn_gen": "<eos_d>",
|
307 |
+
"pv_dspn": "<eos_d>",
|
308 |
+
}
|
309 |
+
|
310 |
+
sos_tokens = {
|
311 |
+
"user": "<sos_u>",
|
312 |
+
"user_delex": "<sos_u>",
|
313 |
+
"resp": "<sos_r>",
|
314 |
+
"resp_gen": "<sos_r>",
|
315 |
+
"pv_resp": "<sos_r>",
|
316 |
+
"bspn": "<sos_b>",
|
317 |
+
"bspn_gen": "<sos_b>",
|
318 |
+
"pv_bspn": "<sos_b>",
|
319 |
+
"bsdx": "<sos_b>",
|
320 |
+
"bsdx_gen": "<sos_b>",
|
321 |
+
"pv_bsdx": "<sos_b>",
|
322 |
+
"aspn": "<sos_a>",
|
323 |
+
"aspn_gen": "<sos_a>",
|
324 |
+
"pv_aspn": "<sos_a>",
|
325 |
+
"dspn": "<sos_d>",
|
326 |
+
"dspn_gen": "<sos_d>",
|
327 |
+
"pv_dspn": "<sos_d>",
|
328 |
+
}
|
src/crazyneuraluser/UBAR_code/reader.py
ADDED
@@ -0,0 +1,1262 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import csv
|
2 |
+
import json
|
3 |
+
import logging
|
4 |
+
import os
|
5 |
+
import random
|
6 |
+
from collections import OrderedDict
|
7 |
+
from copy import deepcopy
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import spacy
|
11 |
+
from transformers import GPT2Tokenizer
|
12 |
+
|
13 |
+
from crazyneuraluser.UBAR_code import ontology, utils
|
14 |
+
from crazyneuraluser.UBAR_code.config import global_config as cfg
|
15 |
+
from crazyneuraluser.UBAR_code.db_ops import MultiWozDB
|
16 |
+
|
17 |
+
# from config21 import global_config as cfg
|
18 |
+
|
19 |
+
|
20 |
+
class _ReaderBase(object):
|
21 |
+
def __init__(self):
|
22 |
+
self.train, self.dev, self.test = [], [], []
|
23 |
+
self.vocab = None
|
24 |
+
self.db = None
|
25 |
+
self.set_stats = {}
|
26 |
+
|
27 |
+
def _bucket_by_turn(self, encoded_data):
|
28 |
+
turn_bucket = {}
|
29 |
+
for dial in encoded_data:
|
30 |
+
turn_len = len(dial)
|
31 |
+
if turn_len not in turn_bucket:
|
32 |
+
turn_bucket[turn_len] = []
|
33 |
+
turn_bucket[turn_len].append(dial)
|
34 |
+
del_l = []
|
35 |
+
for k in turn_bucket:
|
36 |
+
if k >= 5:
|
37 |
+
del_l.append(k)
|
38 |
+
logging.debug("bucket %d instance %d" % (k, len(turn_bucket[k])))
|
39 |
+
# for k in del_l:
|
40 |
+
# turn_bucket.pop(k)
|
41 |
+
return OrderedDict(sorted(turn_bucket.items(), key=lambda i: i[0]))
|
42 |
+
|
43 |
+
def _construct_mini_batch(self, data):
|
44 |
+
all_batches = []
|
45 |
+
batch = []
|
46 |
+
for dial in data:
|
47 |
+
batch.append(dial)
|
48 |
+
if len(batch) == cfg.batch_size:
|
49 |
+
# print('batch size: %d, batch num +1'%(len(batch)))
|
50 |
+
all_batches.append(batch)
|
51 |
+
batch = []
|
52 |
+
# if remainder > 1/2 batch_size, just put them in the previous batch, otherwise form a new batch
|
53 |
+
# print('last batch size: %d, batch num +1'%(len(batch)))
|
54 |
+
if (len(batch) % len(cfg.cuda_device)) != 0:
|
55 |
+
batch = batch[: -(len(batch) % len(cfg.cuda_device))]
|
56 |
+
if len(batch) > 0.5 * cfg.batch_size:
|
57 |
+
all_batches.append(batch)
|
58 |
+
elif len(all_batches):
|
59 |
+
all_batches[-1].extend(batch)
|
60 |
+
else:
|
61 |
+
all_batches.append(batch)
|
62 |
+
return all_batches
|
63 |
+
|
64 |
+
def transpose_batch(self, batch):
|
65 |
+
dial_batch = []
|
66 |
+
turn_num = len(batch[0])
|
67 |
+
for turn in range(turn_num):
|
68 |
+
turn_l = {}
|
69 |
+
for dial in batch:
|
70 |
+
this_turn = dial[turn]
|
71 |
+
for k in this_turn:
|
72 |
+
if k not in turn_l:
|
73 |
+
turn_l[k] = []
|
74 |
+
turn_l[k].append(this_turn[k])
|
75 |
+
dial_batch.append(turn_l)
|
76 |
+
return dial_batch
|
77 |
+
|
78 |
+
def inverse_transpose_turn(self, turn_list):
|
79 |
+
"""
|
80 |
+
eval, one dialog at a time
|
81 |
+
"""
|
82 |
+
dialogs = {}
|
83 |
+
turn_num = len(turn_list)
|
84 |
+
dial_id = turn_list[0]["dial_id"]
|
85 |
+
dialogs[dial_id] = []
|
86 |
+
for turn_idx in range(turn_num):
|
87 |
+
dial_turn = {}
|
88 |
+
turn = turn_list[turn_idx]
|
89 |
+
for key, value in turn.items():
|
90 |
+
if key == "dial_id":
|
91 |
+
continue
|
92 |
+
if key == "pointer" and self.db is not None:
|
93 |
+
turn_domain = turn["turn_domain"][-1]
|
94 |
+
value = self.db.pointerBack(value, turn_domain)
|
95 |
+
dial_turn[key] = value
|
96 |
+
dialogs[dial_id].append(dial_turn)
|
97 |
+
return dialogs
|
98 |
+
|
99 |
+
def inverse_transpose_batch(self, turn_batch_list):
|
100 |
+
"""
|
101 |
+
:param turn_batch_list: list of transpose dial batch
|
102 |
+
"""
|
103 |
+
dialogs = {}
|
104 |
+
total_turn_num = len(turn_batch_list)
|
105 |
+
# initialize
|
106 |
+
for idx_in_batch, dial_id in enumerate(turn_batch_list[0]["dial_id"]):
|
107 |
+
dialogs[dial_id] = []
|
108 |
+
for turn_n in range(total_turn_num):
|
109 |
+
dial_turn = {}
|
110 |
+
turn_batch = turn_batch_list[turn_n]
|
111 |
+
for key, v_list in turn_batch.items():
|
112 |
+
if key == "dial_id":
|
113 |
+
continue
|
114 |
+
value = v_list[idx_in_batch]
|
115 |
+
if key == "pointer" and self.db is not None:
|
116 |
+
turn_domain = turn_batch["turn_domain"][idx_in_batch][-1]
|
117 |
+
value = self.db.pointerBack(value, turn_domain)
|
118 |
+
dial_turn[key] = value
|
119 |
+
dialogs[dial_id].append(dial_turn)
|
120 |
+
return dialogs
|
121 |
+
|
122 |
+
def get_eval_data(self, set_name="dev"):
|
123 |
+
name_to_set = {"train": self.train, "test": self.test, "dev": self.dev}
|
124 |
+
dial = name_to_set[set_name]
|
125 |
+
|
126 |
+
if set_name not in self.set_stats:
|
127 |
+
self.set_stats[set_name] = {}
|
128 |
+
num_turns = 0
|
129 |
+
num_dials = len(dial)
|
130 |
+
for d in dial:
|
131 |
+
num_turns += len(d)
|
132 |
+
|
133 |
+
self.set_stats[set_name]["num_turns"] = num_turns
|
134 |
+
self.set_stats[set_name]["num_dials"] = num_dials
|
135 |
+
|
136 |
+
return dial
|
137 |
+
|
138 |
+
def get_batches(self, set_name):
|
139 |
+
"""
|
140 |
+
compute dataset stats.
|
141 |
+
"""
|
142 |
+
global dia_count
|
143 |
+
log_str = ""
|
144 |
+
name_to_set = {"train": self.train, "test": self.test, "dev": self.dev}
|
145 |
+
dial = name_to_set[set_name]
|
146 |
+
if cfg.low_resource and set_name == "train":
|
147 |
+
# dial = random.sample(dial, int(len(dial)*0.01))
|
148 |
+
dial = random.sample(dial, 100)
|
149 |
+
logging.info("Low Resource setting, finetuning size: {}".format(len(dial)))
|
150 |
+
turn_bucket = self._bucket_by_turn(dial)
|
151 |
+
# self._shuffle_turn_bucket(turn_bucket)
|
152 |
+
all_batches = []
|
153 |
+
|
154 |
+
if set_name not in self.set_stats:
|
155 |
+
self.set_stats[set_name] = {}
|
156 |
+
num_training_steps = 0
|
157 |
+
num_turns = 0
|
158 |
+
num_dials = 0
|
159 |
+
|
160 |
+
for k in turn_bucket:
|
161 |
+
if set_name != "test" and k == 1 or k >= 17:
|
162 |
+
continue
|
163 |
+
batches = self._construct_mini_batch(turn_bucket[k])
|
164 |
+
log_str += "turn num:%d, dial num: %d, batch num: %d last batch len: %d\n" % (
|
165 |
+
k,
|
166 |
+
len(turn_bucket[k]),
|
167 |
+
len(batches),
|
168 |
+
len(batches[-1]),
|
169 |
+
)
|
170 |
+
# print("turn num:%d, dial num:v%d, batch num: %d, "%(k, len(turn_bucket[k]), len(batches)))
|
171 |
+
num_training_steps += k * len(batches)
|
172 |
+
num_turns += k * len(turn_bucket[k])
|
173 |
+
num_dials += len(turn_bucket[k])
|
174 |
+
all_batches += batches
|
175 |
+
log_str += "total batch num: %d\n" % len(all_batches)
|
176 |
+
# print('total batch num: %d'%len(all_batches))
|
177 |
+
# print('dialog count: %d'%dia_count)
|
178 |
+
# return all_batches
|
179 |
+
|
180 |
+
# log stats
|
181 |
+
# logging.info(log_str)
|
182 |
+
# cfg.num_training_steps = num_training_steps * cfg.epoch_num
|
183 |
+
self.set_stats[set_name]["num_training_steps_per_epoch"] = num_training_steps
|
184 |
+
self.set_stats[set_name]["num_turns"] = num_turns
|
185 |
+
self.set_stats[set_name]["num_dials"] = num_dials
|
186 |
+
|
187 |
+
if set_name == "train":
|
188 |
+
random.shuffle(all_batches)
|
189 |
+
return all_batches
|
190 |
+
|
191 |
+
def get_nontranspose_data_iterator(self, all_batches):
|
192 |
+
for i, batch in enumerate(all_batches):
|
193 |
+
yield batch
|
194 |
+
|
195 |
+
def get_data_iterator(self, all_batches):
|
196 |
+
for i, batch in enumerate(all_batches):
|
197 |
+
yield self.transpose_batch(batch)
|
198 |
+
|
199 |
+
def save_result(self, write_mode, results, field, write_title=False):
|
200 |
+
with open(cfg.result_path, write_mode) as rf:
|
201 |
+
if write_title:
|
202 |
+
rf.write(write_title + "\n")
|
203 |
+
writer = csv.DictWriter(rf, fieldnames=field)
|
204 |
+
writer.writeheader()
|
205 |
+
writer.writerows(results)
|
206 |
+
return None
|
207 |
+
|
208 |
+
def save_result_report(self, results):
|
209 |
+
# if 'joint_goal' in results[0]:
|
210 |
+
# with open(cfg.result_path[:-4] + '_report_dst.txt', 'w') as rf:
|
211 |
+
# rf.write('joint goal\tslot_acc\tslot_f1\tact_f1\n')
|
212 |
+
# for res in results:
|
213 |
+
# a,b,c,d = res['joint_goal'], res['slot_acc'], res['slot_f1'], res['act_f1']
|
214 |
+
# rf.write('%2.1f\t%2.1f\t%2.1f\t%2.1f\n'%(a,b,c,d))
|
215 |
+
# elif 'joint_goal_delex' in results[0]:
|
216 |
+
# with open(cfg.result_path[:-4] + '_report_bsdx.txt', 'w') as rf:
|
217 |
+
# rf.write('joint goal\tslot_acc\tslot_f1\tact_f1\n')
|
218 |
+
# for res in results:
|
219 |
+
# a,b,c,d = res['joint_goal_delex'], res['slot_acc_delex'], res['slot_f1_delex'], res['act_f1']
|
220 |
+
# rf.write('%2.1f\t%2.1f\t%2.1f\t%2.1f\n'%(a,b,c,d))
|
221 |
+
ctr_save_path = cfg.result_path[:-4] + "_report_ctr%s.csv" % cfg.seed
|
222 |
+
write_title = False if os.path.exists(ctr_save_path) else True
|
223 |
+
if cfg.aspn_decode_mode == "greedy":
|
224 |
+
setting = ""
|
225 |
+
elif cfg.aspn_decode_mode == "beam":
|
226 |
+
setting = "width=%s" % str(cfg.beam_width)
|
227 |
+
if cfg.beam_diverse_param > 0:
|
228 |
+
setting += ", penalty=%s" % str(cfg.beam_diverse_param)
|
229 |
+
elif cfg.aspn_decode_mode == "topk_sampling":
|
230 |
+
setting = "topk=%s" % str(cfg.topk_num)
|
231 |
+
elif cfg.aspn_decode_mode == "nucleur_sampling":
|
232 |
+
setting = "p=%s" % str(cfg.nucleur_p)
|
233 |
+
res = {
|
234 |
+
"exp": cfg.eval_load_path,
|
235 |
+
"true_bspn": cfg.use_true_curr_bspn,
|
236 |
+
"true_aspn": cfg.use_true_curr_aspn,
|
237 |
+
"decode": cfg.aspn_decode_mode,
|
238 |
+
"param": setting,
|
239 |
+
"nbest": cfg.nbest,
|
240 |
+
"selection_sheme": cfg.act_selection_scheme,
|
241 |
+
"match": results[0]["match"],
|
242 |
+
"success": results[0]["success"],
|
243 |
+
"bleu": results[0]["bleu"],
|
244 |
+
"act_f1": results[0]["act_f1"],
|
245 |
+
"avg_act_num": results[0]["avg_act_num"],
|
246 |
+
"avg_diverse": results[0]["avg_diverse_score"],
|
247 |
+
}
|
248 |
+
with open(ctr_save_path, "a") as rf:
|
249 |
+
writer = csv.DictWriter(rf, fieldnames=list(res.keys()))
|
250 |
+
if write_title:
|
251 |
+
writer.writeheader()
|
252 |
+
writer.writerows([res])
|
253 |
+
|
254 |
+
|
255 |
+
class MultiWozReader(_ReaderBase):
|
256 |
+
def __init__(self, tokenizer):
|
257 |
+
super().__init__()
|
258 |
+
self.nlp = spacy.load("en_core_web_sm")
|
259 |
+
|
260 |
+
self.db = MultiWozDB(cfg.dbs)
|
261 |
+
self.vocab_size = self._build_vocab()
|
262 |
+
|
263 |
+
# self.tokenizer = GPT2Tokenizer.from_pretrained(cfg.gpt_path) # add special tokens later
|
264 |
+
self.tokenizer = tokenizer
|
265 |
+
if cfg.mode == "train":
|
266 |
+
self.add_sepcial_tokens()
|
267 |
+
|
268 |
+
self.domain_files = json.loads(open(cfg.domain_file_path, "r").read())
|
269 |
+
self.slot_value_set = json.loads(open(cfg.slot_value_set_path, "r").read())
|
270 |
+
if cfg.multi_acts_training:
|
271 |
+
self.multi_acts = json.loads(open(cfg.multi_acts_path, "r").read())
|
272 |
+
|
273 |
+
test_list = [test_list.strip().lower() for test_list in open(cfg.test_list, "r").readlines()]
|
274 |
+
dev_list = [dev_list.strip().lower() for dev_list in open(cfg.dev_list, "r").readlines()]
|
275 |
+
self.dev_files, self.test_files = {}, {}
|
276 |
+
for fn in test_list:
|
277 |
+
self.test_files[fn.replace(".json", "")] = 1
|
278 |
+
for fn in dev_list:
|
279 |
+
self.dev_files[fn.replace(".json", "")] = 1
|
280 |
+
|
281 |
+
# for domain expanse aka. Cross domain
|
282 |
+
self.exp_files = {}
|
283 |
+
# if 'all' not in cfg.exp_domains:
|
284 |
+
# for domain in cfg.exp_domains:
|
285 |
+
# fn_list = self.domain_files.get(domain)
|
286 |
+
# if not fn_list:
|
287 |
+
# raise ValueError(
|
288 |
+
# '[%s] is an invalid experiment setting' % domain)
|
289 |
+
# for fn in fn_list:
|
290 |
+
# self.exp_files[fn.replace('.json', '')] = 1
|
291 |
+
all_domains_list = list(self.domain_files.keys())
|
292 |
+
if "all" not in cfg.exp_domains:
|
293 |
+
domains = self.get_exp_domains(cfg.exp_domains, all_domains_list)
|
294 |
+
logging.info(domains)
|
295 |
+
for domain in domains:
|
296 |
+
fn_list = self.domain_files.get(domain)
|
297 |
+
if not fn_list:
|
298 |
+
raise ValueError("[%s] is an invalid experiment setting" % domain)
|
299 |
+
for fn in fn_list:
|
300 |
+
self.exp_files[fn.replace(".json", "")] = 1
|
301 |
+
#
|
302 |
+
|
303 |
+
self._load_data()
|
304 |
+
|
305 |
+
if cfg.limit_bspn_vocab:
|
306 |
+
self.bspn_masks = self._construct_bspn_constraint()
|
307 |
+
if cfg.limit_aspn_vocab:
|
308 |
+
self.aspn_masks = self._construct_aspn_constraint()
|
309 |
+
|
310 |
+
self.multi_acts_record = None
|
311 |
+
|
312 |
+
def get_exp_domains(self, exp_domains, all_domains_list):
|
313 |
+
if "hotel" in exp_domains:
|
314 |
+
if "except" in exp_domains:
|
315 |
+
# ['except', 'hotel']
|
316 |
+
domains = [d for d in all_domains_list if "hotel" not in d and "multi" not in d]
|
317 |
+
else:
|
318 |
+
# ['hotel']
|
319 |
+
domains = ["hotel_single", "hotel_multi"]
|
320 |
+
if "train" in exp_domains:
|
321 |
+
if "except" in exp_domains:
|
322 |
+
# ['except', 'train']
|
323 |
+
domains = [d for d in all_domains_list if "train" not in d and "multi" not in d]
|
324 |
+
else:
|
325 |
+
# ['train']
|
326 |
+
domains = ["train_single", "train_multi"]
|
327 |
+
if "attraction" in exp_domains:
|
328 |
+
if "except" in exp_domains:
|
329 |
+
# ['except', 'attraction']
|
330 |
+
domains = [d for d in all_domains_list if "attraction" not in d and "multi" not in d]
|
331 |
+
else:
|
332 |
+
# ['attraction']
|
333 |
+
domains = ["attraction_single", "attraction_multi"]
|
334 |
+
if "restaurant" in exp_domains:
|
335 |
+
if "except" in exp_domains:
|
336 |
+
# ['except', 'restaurant']
|
337 |
+
domains = [d for d in all_domains_list if "restaurant" not in d and "multi" not in d]
|
338 |
+
else:
|
339 |
+
# ['restaurant']
|
340 |
+
domains = ["restaurant_single", "restaurant_multi"]
|
341 |
+
if "taxi" in exp_domains:
|
342 |
+
if "except" in exp_domains:
|
343 |
+
# ['except', 'taxi']
|
344 |
+
domains = [d for d in all_domains_list if "taxi" not in d and "multi" not in d]
|
345 |
+
else:
|
346 |
+
# ['taxi']
|
347 |
+
domains = ["taxi_single", "taxi_multi"]
|
348 |
+
return domains
|
349 |
+
|
350 |
+
def add_sepcial_tokens(self):
|
351 |
+
"""
|
352 |
+
add special tokens to gpt tokenizer
|
353 |
+
serves a similar role of Vocab.construt()
|
354 |
+
make a dict of special tokens
|
355 |
+
"""
|
356 |
+
special_tokens = []
|
357 |
+
for word in ontology.all_domains + ["general"]:
|
358 |
+
word = "[" + word + "]"
|
359 |
+
special_tokens.append(word)
|
360 |
+
for word in ontology.all_acts:
|
361 |
+
word = "[" + word + "]"
|
362 |
+
special_tokens.append(word)
|
363 |
+
# for word in ontology.all_slots:
|
364 |
+
# to be determine whether slot should be [slot]
|
365 |
+
# if slot, tokenizer having trouble decoding.
|
366 |
+
# special_tokens.append(word)
|
367 |
+
for word in self.vocab._word2idx.keys():
|
368 |
+
if word.startswith("[value_") and word.endswith("]"):
|
369 |
+
special_tokens.append(word)
|
370 |
+
special_tokens.extend(ontology.special_tokens)
|
371 |
+
|
372 |
+
special_tokens_dict = {"additional_special_tokens": special_tokens}
|
373 |
+
self.tokenizer.add_special_tokens(special_tokens_dict)
|
374 |
+
logging.info("Added special tokens to gpt tokenizer.")
|
375 |
+
|
376 |
+
cfg.pad_id = self.tokenizer.encode("<pad>")[0]
|
377 |
+
|
378 |
+
def _build_vocab(self):
|
379 |
+
self.vocab = utils.Vocab(cfg.vocab_size)
|
380 |
+
vp = cfg.vocab_path_train if cfg.mode == "train" or cfg.vocab_path_eval is None else cfg.vocab_path_eval
|
381 |
+
# vp = cfg.vocab_path+'.json.freq.json'
|
382 |
+
self.vocab.load_vocab(vp)
|
383 |
+
return self.vocab.vocab_size
|
384 |
+
|
385 |
+
def _construct_bspn_constraint(self):
|
386 |
+
bspn_masks = {}
|
387 |
+
valid_domains = [
|
388 |
+
"restaurant",
|
389 |
+
"hotel",
|
390 |
+
"attraction",
|
391 |
+
"train",
|
392 |
+
"taxi",
|
393 |
+
"hospital",
|
394 |
+
]
|
395 |
+
all_dom_codes = [self.vocab.encode("[" + d + "]") for d in valid_domains]
|
396 |
+
all_slot_codes = [self.vocab.encode(s) for s in ontology.all_slots]
|
397 |
+
bspn_masks[self.vocab.encode("<go_b>")] = all_dom_codes + [
|
398 |
+
self.vocab.encode("<eos_b>"),
|
399 |
+
0,
|
400 |
+
]
|
401 |
+
bspn_masks[self.vocab.encode("<eos_b>")] = [self.vocab.encode("<pad>")]
|
402 |
+
bspn_masks[self.vocab.encode("<pad>")] = [self.vocab.encode("<pad>")]
|
403 |
+
for domain, slot_values in self.slot_value_set.items():
|
404 |
+
if domain == "police":
|
405 |
+
continue
|
406 |
+
dom_code = self.vocab.encode("[" + domain + "]")
|
407 |
+
bspn_masks[dom_code] = []
|
408 |
+
for slot, values in slot_values.items():
|
409 |
+
slot_code = self.vocab.encode(slot)
|
410 |
+
if slot_code not in bspn_masks:
|
411 |
+
bspn_masks[slot_code] = []
|
412 |
+
if slot_code not in bspn_masks[dom_code]:
|
413 |
+
bspn_masks[dom_code].append(slot_code)
|
414 |
+
for value in values:
|
415 |
+
for idx, v in enumerate(value.split()):
|
416 |
+
if not self.vocab.has_word(v):
|
417 |
+
continue
|
418 |
+
v_code = self.vocab.encode(v)
|
419 |
+
if v_code not in bspn_masks:
|
420 |
+
# print(self.vocab._word2idx)
|
421 |
+
bspn_masks[v_code] = []
|
422 |
+
if idx == 0 and v_code not in bspn_masks[slot_code]:
|
423 |
+
bspn_masks[slot_code].append(v_code)
|
424 |
+
if idx == (len(value.split()) - 1):
|
425 |
+
for w in all_dom_codes + all_slot_codes:
|
426 |
+
if self.vocab.encode("<eos_b>") not in bspn_masks[v_code]:
|
427 |
+
bspn_masks[v_code].append(self.vocab.encode("<eos_b>"))
|
428 |
+
if w not in bspn_masks[v_code]:
|
429 |
+
bspn_masks[v_code].append(w)
|
430 |
+
break
|
431 |
+
if not self.vocab.has_word(value.split()[idx + 1]):
|
432 |
+
continue
|
433 |
+
next_v_code = self.vocab.encode(value.split()[idx + 1])
|
434 |
+
if next_v_code not in bspn_masks[v_code]:
|
435 |
+
bspn_masks[v_code].append(next_v_code)
|
436 |
+
bspn_masks[self.vocab.encode("<unk>")] = list(bspn_masks.keys())
|
437 |
+
|
438 |
+
with open("data/processed/multi-woz-processed/bspn_masks.txt", "w") as f:
|
439 |
+
for i, j in bspn_masks.items():
|
440 |
+
f.write(self.vocab.decode(i) + ": " + " ".join([self.vocab.decode(int(m)) for m in j]) + "\n")
|
441 |
+
return bspn_masks
|
442 |
+
|
443 |
+
def _construct_aspn_constraint(self):
|
444 |
+
aspn_masks = {}
|
445 |
+
aspn_masks = {}
|
446 |
+
all_dom_codes = [self.vocab.encode("[" + d + "]") for d in ontology.dialog_acts.keys()]
|
447 |
+
all_act_codes = [self.vocab.encode("[" + a + "]") for a in ontology.dialog_act_params]
|
448 |
+
all_slot_codes = [self.vocab.encode(s) for s in ontology.dialog_act_all_slots]
|
449 |
+
aspn_masks[self.vocab.encode("<go_a>")] = all_dom_codes + [
|
450 |
+
self.vocab.encode("<eos_a>"),
|
451 |
+
0,
|
452 |
+
]
|
453 |
+
aspn_masks[self.vocab.encode("<eos_a>")] = [self.vocab.encode("<pad>")]
|
454 |
+
aspn_masks[self.vocab.encode("<pad>")] = [self.vocab.encode("<pad>")]
|
455 |
+
# for d in all_dom_codes:
|
456 |
+
# aspn_masks[d] = all_act_codes
|
457 |
+
for a in all_act_codes:
|
458 |
+
aspn_masks[a] = all_dom_codes + all_slot_codes + [self.vocab.encode("<eos_a>")]
|
459 |
+
for domain, acts in ontology.dialog_acts.items():
|
460 |
+
dom_code = self.vocab.encode("[" + domain + "]")
|
461 |
+
aspn_masks[dom_code] = []
|
462 |
+
for a in acts:
|
463 |
+
act_code = self.vocab.encode("[" + a + "]")
|
464 |
+
if act_code not in aspn_masks[dom_code]:
|
465 |
+
aspn_masks[dom_code].append(act_code)
|
466 |
+
# for a, slots in ontology.dialog_act_params.items():
|
467 |
+
# act_code = self.vocab.encode('['+a+']')
|
468 |
+
# slot_codes = [self.vocab.encode(s) for s in slots]
|
469 |
+
# aspn_masks[act_code] = all_dom_codes + slot_codes + [self.vocab.encode('<eos_a>')]
|
470 |
+
for s in all_slot_codes:
|
471 |
+
aspn_masks[s] = all_dom_codes + all_slot_codes + [self.vocab.encode("<eos_a>")]
|
472 |
+
aspn_masks[self.vocab.encode("<unk>")] = list(aspn_masks.keys())
|
473 |
+
|
474 |
+
with open("processed/multi-woz-processed/aspn_masks.txt", "w") as f:
|
475 |
+
for i, j in aspn_masks.items():
|
476 |
+
f.write(self.vocab.decode(i) + ": " + " ".join([self.vocab.decode(int(m)) for m in j]) + "\n")
|
477 |
+
return aspn_masks
|
478 |
+
|
479 |
+
def _load_data(self, save_temp=True):
|
480 |
+
"""
|
481 |
+
load processed data and encode, or load already encoded data
|
482 |
+
"""
|
483 |
+
if save_temp: # save encoded data
|
484 |
+
if "all" in cfg.exp_domains:
|
485 |
+
encoded_file = os.path.join(cfg.data_path, "new_db_se_blank_encoded.data.json")
|
486 |
+
# encoded: no sos, se_encoded: sos and eos
|
487 |
+
# db: add db results every turn
|
488 |
+
else:
|
489 |
+
xdomain_dir = "./models/UBAR/experiments_Xdomain/data"
|
490 |
+
if not os.path.exists(xdomain_dir):
|
491 |
+
os.makedirs(xdomain_dir)
|
492 |
+
encoded_file = os.path.join(
|
493 |
+
xdomain_dir,
|
494 |
+
"{}-encoded.data.json".format("-".join(cfg.exp_domains)),
|
495 |
+
)
|
496 |
+
|
497 |
+
if os.path.exists(encoded_file):
|
498 |
+
logging.info("Reading encoded data from {}".format(encoded_file))
|
499 |
+
self.data = json.loads(open(cfg.data_path + cfg.data_file, "r", encoding="utf-8").read().lower())
|
500 |
+
encoded_data = json.loads(open(encoded_file, "r", encoding="utf-8").read())
|
501 |
+
self.train = encoded_data["train"]
|
502 |
+
self.dev = encoded_data["dev"]
|
503 |
+
self.test = encoded_data["test"]
|
504 |
+
else:
|
505 |
+
logging.info("Encoding data now and save the encoded data in {}".format(encoded_file))
|
506 |
+
# not exists, encode data and save
|
507 |
+
self.data = json.loads(open(cfg.data_path + cfg.data_file, "r", encoding="utf-8").read().lower())
|
508 |
+
self.train, self.dev, self.test = [], [], []
|
509 |
+
for fn, dial in self.data.items():
|
510 |
+
if ".json" in fn:
|
511 |
+
fn = fn.replace(".json", "")
|
512 |
+
if "all" in cfg.exp_domains or self.exp_files.get(fn):
|
513 |
+
if self.dev_files.get(fn):
|
514 |
+
self.dev.append(self._get_encoded_data(fn, dial))
|
515 |
+
elif self.test_files.get(fn):
|
516 |
+
self.test.append(self._get_encoded_data(fn, dial))
|
517 |
+
else:
|
518 |
+
self.train.append(self._get_encoded_data(fn, dial))
|
519 |
+
|
520 |
+
# save encoded data
|
521 |
+
encoded_data = {"train": self.train, "dev": self.dev, "test": self.test}
|
522 |
+
json.dump(encoded_data, open(encoded_file, "w"), indent=2)
|
523 |
+
|
524 |
+
else: # directly read processed data and encode
|
525 |
+
self.data = json.loads(open(cfg.data_path + cfg.data_file, "r", encoding="utf-8").read().lower())
|
526 |
+
self.train, self.dev, self.test = [], [], []
|
527 |
+
for fn, dial in self.data.items():
|
528 |
+
if ".json" in fn:
|
529 |
+
fn = fn.replace(".json", "")
|
530 |
+
if "all" in cfg.exp_domains or self.exp_files.get(fn):
|
531 |
+
if self.dev_files.get(fn):
|
532 |
+
self.dev.append(self._get_encoded_data(fn, dial))
|
533 |
+
elif self.test_files.get(fn):
|
534 |
+
self.test.append(self._get_encoded_data(fn, dial))
|
535 |
+
else:
|
536 |
+
self.train.append(self._get_encoded_data(fn, dial))
|
537 |
+
# if save_temp:
|
538 |
+
# json.dump(self.test, open(
|
539 |
+
# 'data/multi-woz-analysis/test.encoded.json', 'w'), indent=2)
|
540 |
+
# self.vocab.save_vocab('data/multi-woz-analysis/vocab_temp')
|
541 |
+
|
542 |
+
random.shuffle(self.train)
|
543 |
+
# random.shuffle(self.dev)
|
544 |
+
# random.shuffle(self.test)
|
545 |
+
logging.info("train size:{}, dev size:{}, test size:{}".format(len(self.train), len(self.dev), len(self.test)))
|
546 |
+
|
547 |
+
def _get_encoded_data(self, fn, dial):
|
548 |
+
encoded_dial = []
|
549 |
+
for idx, t in enumerate(dial["log"]): # tokenize to list of ids
|
550 |
+
enc = {}
|
551 |
+
enc["dial_id"] = fn
|
552 |
+
|
553 |
+
# enc['user'] = self.vocab.sentence_encode(t['user'].split() + ['<eos_u>'])
|
554 |
+
# enc['usdx'] = self.vocab.sentence_encode(t['user_delex'].split() + ['<eos_u>'])
|
555 |
+
# enc['resp'] = self.vocab.sentence_encode(t['resp'].split() + ['<eos_r>'])
|
556 |
+
# enc['bspn'] = self.vocab.sentence_encode(t['constraint'].split() + ['<eos_b>'])
|
557 |
+
# enc['bsdx'] = self.vocab.sentence_encode(t['cons_delex'].split() + ['<eos_b>'])
|
558 |
+
# enc['aspn'] = self.vocab.sentence_encode(t['sys_act'].split() + ['<eos_a>'])
|
559 |
+
# enc['dspn'] = self.vocab.sentence_encode(t['turn_domain'].split() + ['<eos_d>'])
|
560 |
+
|
561 |
+
# use gpt tokenizer directly tokenize word list, prone to encode unknown words to |endoftext|
|
562 |
+
# enc['user'] = self.tokenizer.encode(
|
563 |
+
# t['user'].split() + ['<eos_u>'])
|
564 |
+
# enc['usdx'] = self.tokenizer.encode(
|
565 |
+
# t['user_delex'].split() + ['<eos_u>'])
|
566 |
+
# enc['resp'] = self.tokenizer.encode(
|
567 |
+
# t['resp'].split() + ['<eos_r>'])
|
568 |
+
# enc['bspn'] = self.tokenizer.encode(
|
569 |
+
# t['constraint'].split() + ['<eos_b>'])
|
570 |
+
# enc['bsdx'] = self.tokenizer.encode(
|
571 |
+
# t['cons_delex'].split() + ['<eos_b>'])
|
572 |
+
# enc['aspn'] = self.tokenizer.encode(
|
573 |
+
# t['sys_act'].split() + ['<eos_a>'])
|
574 |
+
# enc['dspn'] = self.tokenizer.encode(
|
575 |
+
# t['turn_domain'].split() + ['<eos_d>'])
|
576 |
+
|
577 |
+
# gpt use bpe to encode strings, very very slow. ~9min
|
578 |
+
# in tokenization_utils.encode I find encode can pad_to_max_length, and reutrn tensor
|
579 |
+
enc["user"] = self.tokenizer.convert_tokens_to_ids(
|
580 |
+
self.tokenizer.tokenize("<sos_u> " + t["user"] + " <eos_u>")
|
581 |
+
)
|
582 |
+
enc["usdx"] = self.tokenizer.convert_tokens_to_ids(
|
583 |
+
self.tokenizer.tokenize("<sos_u> " + t["user"] + " <eos_u>")
|
584 |
+
)
|
585 |
+
enc["resp"] = self.tokenizer.convert_tokens_to_ids(
|
586 |
+
self.tokenizer.tokenize("<sos_r> " + t["resp"] + " <eos_r>")
|
587 |
+
)
|
588 |
+
enc["bspn"] = self.tokenizer.convert_tokens_to_ids(
|
589 |
+
self.tokenizer.tokenize("<sos_b> " + t["constraint"] + " <eos_b>")
|
590 |
+
)
|
591 |
+
enc["bsdx"] = self.tokenizer.convert_tokens_to_ids(
|
592 |
+
self.tokenizer.tokenize("<sos_b> " + t["cons_delex"] + " <eos_b>")
|
593 |
+
)
|
594 |
+
enc["aspn"] = self.tokenizer.convert_tokens_to_ids(
|
595 |
+
self.tokenizer.tokenize("<sos_a> " + t["sys_act"] + " <eos_a>")
|
596 |
+
)
|
597 |
+
enc["dspn"] = self.tokenizer.convert_tokens_to_ids(
|
598 |
+
self.tokenizer.tokenize("<sos_d> " + t["turn_domain"] + " <eos_d>")
|
599 |
+
)
|
600 |
+
|
601 |
+
enc["pointer"] = [int(i) for i in t["pointer"].split(",")]
|
602 |
+
enc["turn_domain"] = t["turn_domain"].split()
|
603 |
+
enc["turn_num"] = t["turn_num"]
|
604 |
+
if cfg.multi_acts_training:
|
605 |
+
enc["aspn_aug"] = []
|
606 |
+
if fn in self.multi_acts:
|
607 |
+
turn_ma = self.multi_acts[fn].get(str(idx), {})
|
608 |
+
for act_type, act_spans in turn_ma.items():
|
609 |
+
enc["aspn_aug"].append([self.tokenizer.encode(a.split() + ["<eos_a>"]) for a in act_spans])
|
610 |
+
|
611 |
+
# add db results to enc, at every turn
|
612 |
+
db_pointer = self.bspan_to_DBpointer(t["constraint"], t["turn_domain"].split())
|
613 |
+
# db_tokens = ['<sos_db>', '<eos_db>', '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]']
|
614 |
+
enc["db"] = self.tokenizer.convert_tokens_to_ids(
|
615 |
+
self.tokenizer.tokenize("<sos_db> " + db_pointer + " <eos_db>")
|
616 |
+
)
|
617 |
+
|
618 |
+
encoded_dial.append(enc)
|
619 |
+
return encoded_dial
|
620 |
+
|
621 |
+
def bspan_to_constraint_dict(self, bspan, bspn_mode="bspn"):
|
622 |
+
bspan = bspan.split() if isinstance(bspan, str) else bspan
|
623 |
+
constraint_dict = {}
|
624 |
+
domain = None
|
625 |
+
conslen = len(bspan)
|
626 |
+
for idx, cons in enumerate(bspan):
|
627 |
+
cons = self.vocab.decode(cons) if type(cons) is not str else cons
|
628 |
+
if cons == "<eos_b>":
|
629 |
+
break
|
630 |
+
if "[" in cons:
|
631 |
+
if cons[1:-1] not in ontology.all_domains:
|
632 |
+
continue
|
633 |
+
domain = cons[1:-1]
|
634 |
+
elif cons in ontology.get_slot:
|
635 |
+
if domain is None:
|
636 |
+
continue
|
637 |
+
if cons == "people":
|
638 |
+
# handle confusion of value name "people's portraits..." and slot people
|
639 |
+
try:
|
640 |
+
ns = bspan[idx + 1]
|
641 |
+
ns = self.vocab.decode(ns) if type(ns) is not str else ns
|
642 |
+
if ns == "'s":
|
643 |
+
continue
|
644 |
+
except Exception:
|
645 |
+
continue
|
646 |
+
if not constraint_dict.get(domain):
|
647 |
+
constraint_dict[domain] = {}
|
648 |
+
if bspn_mode == "bsdx":
|
649 |
+
constraint_dict[domain][cons] = 1
|
650 |
+
continue
|
651 |
+
vidx = idx + 1
|
652 |
+
if vidx == conslen:
|
653 |
+
break
|
654 |
+
vt_collect = []
|
655 |
+
vt = bspan[vidx]
|
656 |
+
vt = self.vocab.decode(vt) if type(vt) is not str else vt
|
657 |
+
while vidx < conslen and vt != "<eos_b>" and "[" not in vt and vt not in ontology.get_slot:
|
658 |
+
vt_collect.append(vt)
|
659 |
+
vidx += 1
|
660 |
+
if vidx == conslen:
|
661 |
+
break
|
662 |
+
vt = bspan[vidx]
|
663 |
+
vt = self.vocab.decode(vt) if type(vt) is not str else vt
|
664 |
+
if vt_collect:
|
665 |
+
constraint_dict[domain][cons] = " ".join(vt_collect)
|
666 |
+
|
667 |
+
return constraint_dict
|
668 |
+
|
669 |
+
def bspan_to_DBpointer(self, bspan, turn_domain):
|
670 |
+
constraint_dict = self.bspan_to_constraint_dict(bspan)
|
671 |
+
# print(constraint_dict)
|
672 |
+
matnums = self.db.get_match_num(constraint_dict)
|
673 |
+
match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1]
|
674 |
+
match_dom = match_dom[1:-1] if match_dom.startswith("[") else match_dom
|
675 |
+
match = matnums[match_dom]
|
676 |
+
# vector = self.db.addDBPointer(match_dom, match)
|
677 |
+
vector = self.db.addDBIndicator(match_dom, match)
|
678 |
+
return vector
|
679 |
+
|
680 |
+
def aspan_to_act_list(self, aspan):
|
681 |
+
aspan = aspan.split() if isinstance(aspan, str) else aspan
|
682 |
+
acts = []
|
683 |
+
domain = None
|
684 |
+
conslen = len(aspan)
|
685 |
+
for idx, cons in enumerate(aspan):
|
686 |
+
cons = self.vocab.decode(cons) if type(cons) is not str else cons
|
687 |
+
if cons == "<eos_a>":
|
688 |
+
break
|
689 |
+
if "[" in cons and cons[1:-1] in ontology.dialog_acts:
|
690 |
+
domain = cons[1:-1]
|
691 |
+
|
692 |
+
elif "[" in cons and cons[1:-1] in ontology.dialog_act_params:
|
693 |
+
if domain is None:
|
694 |
+
continue
|
695 |
+
vidx = idx + 1
|
696 |
+
if vidx == conslen:
|
697 |
+
acts.append(domain + "-" + cons[1:-1] + "-none")
|
698 |
+
break
|
699 |
+
vt = aspan[vidx]
|
700 |
+
vt = self.vocab.decode(vt) if type(vt) is not str else vt
|
701 |
+
no_param_act = True
|
702 |
+
while vidx < conslen and vt != "<eos_a>" and "[" not in vt:
|
703 |
+
no_param_act = False
|
704 |
+
acts.append(domain + "-" + cons[1:-1] + "-" + vt)
|
705 |
+
vidx += 1
|
706 |
+
if vidx == conslen:
|
707 |
+
break
|
708 |
+
vt = aspan[vidx]
|
709 |
+
vt = self.vocab.decode(vt) if type(vt) is not str else vt
|
710 |
+
if no_param_act:
|
711 |
+
acts.append(domain + "-" + cons[1:-1] + "-none")
|
712 |
+
|
713 |
+
return acts
|
714 |
+
|
715 |
+
def dspan_to_domain(self, dspan):
|
716 |
+
domains = {}
|
717 |
+
dspan = dspan.split() if isinstance(dspan, str) else dspan
|
718 |
+
for d in dspan:
|
719 |
+
dom = self.vocab.decode(d) if type(d) is not str else d
|
720 |
+
if dom != "<eos_d>":
|
721 |
+
domains[dom] = 1
|
722 |
+
else:
|
723 |
+
break
|
724 |
+
return domains
|
725 |
+
|
726 |
+
def convert_turn_eval(self, turn, pv_turn, first_turn=False):
|
727 |
+
"""
|
728 |
+
input: [all previous ubar, U_t, B_t, A_t] predict R_t
|
729 |
+
firts turn: [U_t, B_t, A_t] predict R_t
|
730 |
+
|
731 |
+
regarding the context, all previous ubar is too slow, try the previous ubar
|
732 |
+
"""
|
733 |
+
inputs = {}
|
734 |
+
|
735 |
+
context_list = []
|
736 |
+
# predict_list = []
|
737 |
+
prompt = ""
|
738 |
+
if cfg.use_true_curr_bspn:
|
739 |
+
if cfg.use_true_curr_aspn: # only predict resp
|
740 |
+
context_list = ["user", "bspn", "db", "aspn"]
|
741 |
+
# context_list = ['user','aspn'] # predict resp based on current aspn and bspn
|
742 |
+
# predict_list = ['resp']
|
743 |
+
prompt = "<sos_r>"
|
744 |
+
else: # predicted aspn
|
745 |
+
context_list = ["user", "bspn", "db"]
|
746 |
+
# predict_list = ['aspn', 'resp']
|
747 |
+
prompt = "<sos_a>"
|
748 |
+
else: # predict bspn aspn resp. db are not predicted. this part tbd.
|
749 |
+
context_list = ["user"]
|
750 |
+
# predict_list = ['bspn', 'db','aspn', 'resp']
|
751 |
+
prompt = "<sos_b>"
|
752 |
+
|
753 |
+
if first_turn:
|
754 |
+
context = []
|
755 |
+
for c in context_list:
|
756 |
+
context += turn[c]
|
757 |
+
|
758 |
+
inputs["context"] = context + self.tokenizer.encode([prompt])
|
759 |
+
inputs["labels"] = context
|
760 |
+
# e43 with BABAU
|
761 |
+
# inputs['labels'] = []
|
762 |
+
|
763 |
+
else:
|
764 |
+
context = []
|
765 |
+
for c in context_list:
|
766 |
+
context += turn[c]
|
767 |
+
|
768 |
+
pv_context = pv_turn["labels"] + pv_turn["bspn"] + pv_turn["db"] + pv_turn["aspn"] + pv_turn["resp"]
|
769 |
+
# e43 with BABAU
|
770 |
+
# pv_context = pv_turn['labels'] + pv_turn['bspn'] + pv_turn['db'] + pv_turn['aspn']
|
771 |
+
|
772 |
+
# prompt response, add sos_r
|
773 |
+
inputs["context"] = pv_context + context + self.tokenizer.encode([prompt])
|
774 |
+
# context just the current turn
|
775 |
+
# inputs['context'] = context + self.tokenizer.encode([prompt])
|
776 |
+
# context just the current action
|
777 |
+
|
778 |
+
if cfg.use_all_previous_context:
|
779 |
+
inputs["labels"] = pv_context + context # use all previous ubar history
|
780 |
+
else:
|
781 |
+
inputs["labels"] = context # use previous trun
|
782 |
+
|
783 |
+
if len(inputs["context"]) > 900:
|
784 |
+
print("len exceeds 900")
|
785 |
+
inputs["context"] = inputs["context"][-900:]
|
786 |
+
|
787 |
+
return inputs
|
788 |
+
|
789 |
+
def convert_batch_session(self, dial_batch):
|
790 |
+
"""
|
791 |
+
convert the whole session for training
|
792 |
+
concat [U_0, B_0, A_0, R_0, ... , U_n, B_n, A_n, R_n]
|
793 |
+
|
794 |
+
try: [user, bspn, aspn, resp]
|
795 |
+
or
|
796 |
+
try: [user, bspn, db, aspn, resp]
|
797 |
+
"""
|
798 |
+
inputs = {}
|
799 |
+
contexts = []
|
800 |
+
cell_list = ["user", "bspn", "db", "aspn", "resp"]
|
801 |
+
for idx, dial in enumerate(dial_batch):
|
802 |
+
context = []
|
803 |
+
for turn_num, turn in enumerate(dial):
|
804 |
+
for cell in cell_list:
|
805 |
+
context.extend(turn[cell])
|
806 |
+
contexts.append(context)
|
807 |
+
|
808 |
+
inputs["contexts"] = contexts
|
809 |
+
inputs["contexts_np"], inputs["lengths"] = utils.padSeqs_gpt(inputs["contexts"], cfg.pad_id)
|
810 |
+
return inputs
|
811 |
+
|
812 |
+
def convert_batch_gpt(self, turn_batch, pv_batch, first_turn=False):
|
813 |
+
"""
|
814 |
+
convert the current and the last turn
|
815 |
+
concat [U_{t-1}, B_{t-1}, A_{t-1}, R_{t-1}, U_t, B_t, A_t, R_t]
|
816 |
+
firts turn: [U_t, B_t, A_t, R_t]
|
817 |
+
try: [usdx, bspn, aspn, resp]
|
818 |
+
|
819 |
+
"""
|
820 |
+
inputs = {}
|
821 |
+
if first_turn:
|
822 |
+
contexts = []
|
823 |
+
batch_zipped = zip(
|
824 |
+
turn_batch["usdx"],
|
825 |
+
turn_batch["bspn"],
|
826 |
+
turn_batch["aspn"],
|
827 |
+
turn_batch["resp"],
|
828 |
+
)
|
829 |
+
for u, b, a, r in batch_zipped:
|
830 |
+
context = u + b + a + r
|
831 |
+
contexts.append(context)
|
832 |
+
inputs["contexts"] = contexts
|
833 |
+
# padSeqs to make [UBAR] the same length
|
834 |
+
inputs["contexts_np"], inputs["lengths"] = utils.padSeqs_gpt(inputs["contexts"], cfg.pad_id)
|
835 |
+
else:
|
836 |
+
contexts = []
|
837 |
+
batch_zipped = zip(
|
838 |
+
pv_batch["pv_usdx"],
|
839 |
+
pv_batch["pv_bspn"],
|
840 |
+
pv_batch["pv_aspn"],
|
841 |
+
pv_batch["pv_resp"],
|
842 |
+
turn_batch["usdx"],
|
843 |
+
turn_batch["bspn"],
|
844 |
+
turn_batch["aspn"],
|
845 |
+
turn_batch["resp"],
|
846 |
+
)
|
847 |
+
for pu, pb, pa, pr, u, b, a, r in batch_zipped:
|
848 |
+
context = pu + pb + pa + pr + u + b + a + r
|
849 |
+
contexts.append(context)
|
850 |
+
inputs["contexts"] = contexts
|
851 |
+
contexts_np, lengths = utils.padSeqs_gpt(inputs["contexts"], cfg.pad_id)
|
852 |
+
inputs["contexts_np"] = contexts_np
|
853 |
+
inputs["lengths"] = lengths
|
854 |
+
return inputs
|
855 |
+
|
856 |
+
def convert_batch(self, py_batch, py_prev, first_turn=False):
|
857 |
+
inputs = {}
|
858 |
+
if first_turn:
|
859 |
+
for item, py_list in py_prev.items():
|
860 |
+
batch_size = len(py_batch["user"])
|
861 |
+
inputs[item + "_np"] = np.array([[1]] * batch_size)
|
862 |
+
inputs[item + "_unk_np"] = np.array([[1]] * batch_size)
|
863 |
+
else:
|
864 |
+
for item, py_list in py_prev.items():
|
865 |
+
if py_list is None:
|
866 |
+
continue
|
867 |
+
if not cfg.enable_aspn and "aspn" in item:
|
868 |
+
continue
|
869 |
+
if not cfg.enable_bspn and "bspn" in item:
|
870 |
+
continue
|
871 |
+
if not cfg.enable_dspn and "dspn" in item:
|
872 |
+
continue
|
873 |
+
prev_np = utils.padSeqs(py_list, truncated=cfg.truncated, trunc_method="pre")
|
874 |
+
inputs[item + "_np"] = prev_np
|
875 |
+
if item in ["pv_resp", "pv_bspn"]:
|
876 |
+
inputs[item + "_unk_np"] = deepcopy(inputs[item + "_np"])
|
877 |
+
# <unk>, restrict vocab size to 3k, map ids>3k to <unk>
|
878 |
+
inputs[item + "_unk_np"][inputs[item + "_unk_np"] >= self.vocab_size] = 2
|
879 |
+
else:
|
880 |
+
inputs[item + "_unk_np"] = inputs[item + "_np"]
|
881 |
+
|
882 |
+
for item in ["user", "usdx", "resp", "bspn", "aspn", "bsdx", "dspn"]:
|
883 |
+
if not cfg.enable_aspn and item == "aspn":
|
884 |
+
continue
|
885 |
+
if not cfg.enable_bspn and item == "bspn":
|
886 |
+
continue
|
887 |
+
|
888 |
+
if not cfg.enable_dspn and item == "dspn":
|
889 |
+
continue
|
890 |
+
py_list = py_batch[item]
|
891 |
+
trunc_method = "post" if item == "resp" else "pre"
|
892 |
+
# max_length = cfg.max_nl_length if item in ['user', 'usdx', 'resp'] else cfg.max_span_length
|
893 |
+
inputs[item + "_np"] = utils.padSeqs(py_list, truncated=cfg.truncated, trunc_method=trunc_method)
|
894 |
+
if item in ["user", "usdx", "resp", "bspn"]:
|
895 |
+
inputs[item + "_unk_np"] = deepcopy(inputs[item + "_np"])
|
896 |
+
inputs[item + "_unk_np"][inputs[item + "_unk_np"] >= self.vocab_size] = 2 # <unk>
|
897 |
+
else:
|
898 |
+
inputs[item + "_unk_np"] = inputs[item + "_np"]
|
899 |
+
|
900 |
+
if cfg.multi_acts_training and cfg.mode == "train":
|
901 |
+
inputs["aspn_bidx"], multi_aspn = [], []
|
902 |
+
for bidx, aspn_type_list in enumerate(py_batch["aspn_aug"]):
|
903 |
+
if aspn_type_list:
|
904 |
+
for aspn_list in aspn_type_list:
|
905 |
+
random.shuffle(aspn_list)
|
906 |
+
# choose one random act span in each act type
|
907 |
+
aspn = aspn_list[0]
|
908 |
+
multi_aspn.append(aspn)
|
909 |
+
inputs["aspn_bidx"].append(bidx)
|
910 |
+
if cfg.multi_act_sampling_num > 1:
|
911 |
+
for i in range(cfg.multi_act_sampling_num):
|
912 |
+
if len(aspn_list) >= i + 2:
|
913 |
+
# choose one random act span in each act type
|
914 |
+
aspn = aspn_list[i + 1]
|
915 |
+
multi_aspn.append(aspn)
|
916 |
+
inputs["aspn_bidx"].append(bidx)
|
917 |
+
|
918 |
+
if multi_aspn:
|
919 |
+
inputs["aspn_aug_np"] = utils.padSeqs(multi_aspn, truncated=cfg.truncated, trunc_method="pre")
|
920 |
+
# [all available aspn num in the batch, T]
|
921 |
+
inputs["aspn_aug_unk_np"] = inputs["aspn_aug_np"]
|
922 |
+
|
923 |
+
inputs["db_np"] = np.array(py_batch["pointer"])
|
924 |
+
inputs["turn_domain"] = py_batch["turn_domain"]
|
925 |
+
|
926 |
+
return inputs
|
927 |
+
|
928 |
+
def wrap_result_lm(self, result_dict, eos_syntax=None):
|
929 |
+
results = []
|
930 |
+
eos_syntax = ontology.eos_tokens if not eos_syntax else eos_syntax
|
931 |
+
sos_syntax = ontology.sos_tokens
|
932 |
+
# ground truth bs, as, ds.. generate response
|
933 |
+
field = [
|
934 |
+
"dial_id",
|
935 |
+
"turn_num",
|
936 |
+
"user",
|
937 |
+
"bspn_gen",
|
938 |
+
"bsdx",
|
939 |
+
"resp_gen",
|
940 |
+
"resp",
|
941 |
+
"aspn_gen",
|
942 |
+
"aspn",
|
943 |
+
"dspn_gen",
|
944 |
+
"dspn",
|
945 |
+
"bspn",
|
946 |
+
"pointer",
|
947 |
+
]
|
948 |
+
|
949 |
+
for dial_id, turns in result_dict.items():
|
950 |
+
entry = {"dial_id": dial_id, "trun_num": len(turns)}
|
951 |
+
for f in field[2:]:
|
952 |
+
entry[f] = "" # ???
|
953 |
+
results.append(entry)
|
954 |
+
for turn_idx, turn in enumerate(turns):
|
955 |
+
entry = {"dial_id": dial_id}
|
956 |
+
for key in field:
|
957 |
+
if key in ["dial_id"]:
|
958 |
+
continue
|
959 |
+
v = turn.get(key, "")
|
960 |
+
if key == "turn_domain":
|
961 |
+
v = " ".join(v)
|
962 |
+
|
963 |
+
if key in eos_syntax and v != "":
|
964 |
+
# remove eos tokens
|
965 |
+
v = self.tokenizer.decode(v)
|
966 |
+
v = v.split()
|
967 |
+
# remove eos/sos in span
|
968 |
+
if eos_syntax[key] in v:
|
969 |
+
v.remove(eos_syntax[key])
|
970 |
+
if sos_syntax[key] in v:
|
971 |
+
v.remove(sos_syntax[key])
|
972 |
+
# if key != 'resp_gen':
|
973 |
+
# # remove eos/sos in span
|
974 |
+
# if eos_syntax[key] in v:
|
975 |
+
# v.remove(eos_syntax[key])
|
976 |
+
# if sos_syntax[key] in v:
|
977 |
+
# v.remove(sos_syntax[key])
|
978 |
+
# else: # 'resp_gen'
|
979 |
+
# sos_index = 0
|
980 |
+
# eos_index = -1
|
981 |
+
# if sos_syntax[key] in v:
|
982 |
+
# sos_index = v.index(sos_syntax[key])
|
983 |
+
# if eos_syntax[key] in v:
|
984 |
+
# eos_index = v.index(eos_syntax[key])
|
985 |
+
# else:
|
986 |
+
# pass # take too long
|
987 |
+
# # no <eos_r> found, stop at any eos_tokens
|
988 |
+
# # for i in range(sos_index+1, len(v)):
|
989 |
+
# # if v[i] in sos_syntax.values() or v[i] in eos_syntax.values():
|
990 |
+
# # eos_index = i
|
991 |
+
# v = v[sos_index+1: eos_index]
|
992 |
+
|
993 |
+
# v = self.tokenizer.convert_tokens_to_string(v)
|
994 |
+
v = " ".join(v)
|
995 |
+
else:
|
996 |
+
pass # v = v
|
997 |
+
entry[key] = v
|
998 |
+
|
999 |
+
results.append(entry)
|
1000 |
+
|
1001 |
+
return results, field
|
1002 |
+
|
1003 |
+
def wrap_result(self, result_dict, eos_syntax=None):
|
1004 |
+
decode_fn = self.vocab.sentence_decode
|
1005 |
+
results = []
|
1006 |
+
eos_syntax = ontology.eos_tokens if not eos_syntax else eos_syntax
|
1007 |
+
|
1008 |
+
if cfg.bspn_mode == "bspn":
|
1009 |
+
field = [
|
1010 |
+
"dial_id",
|
1011 |
+
"turn_num",
|
1012 |
+
"user",
|
1013 |
+
"bspn_gen",
|
1014 |
+
"bspn",
|
1015 |
+
"resp_gen",
|
1016 |
+
"resp",
|
1017 |
+
"aspn_gen",
|
1018 |
+
"aspn",
|
1019 |
+
"dspn_gen",
|
1020 |
+
"dspn",
|
1021 |
+
"pointer",
|
1022 |
+
]
|
1023 |
+
elif not cfg.enable_dst: # this
|
1024 |
+
field = [
|
1025 |
+
"dial_id",
|
1026 |
+
"turn_num",
|
1027 |
+
"user",
|
1028 |
+
"bsdx_gen",
|
1029 |
+
"bsdx",
|
1030 |
+
"resp_gen",
|
1031 |
+
"resp",
|
1032 |
+
"aspn_gen",
|
1033 |
+
"aspn",
|
1034 |
+
"dspn_gen",
|
1035 |
+
"dspn",
|
1036 |
+
"bspn",
|
1037 |
+
"pointer",
|
1038 |
+
]
|
1039 |
+
else:
|
1040 |
+
field = [
|
1041 |
+
"dial_id",
|
1042 |
+
"turn_num",
|
1043 |
+
"user",
|
1044 |
+
"bsdx_gen",
|
1045 |
+
"bsdx",
|
1046 |
+
"resp_gen",
|
1047 |
+
"resp",
|
1048 |
+
"aspn_gen",
|
1049 |
+
"aspn",
|
1050 |
+
"dspn_gen",
|
1051 |
+
"dspn",
|
1052 |
+
"bspn_gen",
|
1053 |
+
"bspn",
|
1054 |
+
"pointer",
|
1055 |
+
]
|
1056 |
+
if self.multi_acts_record is not None:
|
1057 |
+
field.insert(7, "multi_act_gen")
|
1058 |
+
|
1059 |
+
for dial_id, turns in result_dict.items():
|
1060 |
+
entry = {"dial_id": dial_id, "turn_num": len(turns)}
|
1061 |
+
for prop in field[2:]:
|
1062 |
+
entry[prop] = ""
|
1063 |
+
results.append(entry)
|
1064 |
+
for turn_no, turn in enumerate(turns):
|
1065 |
+
entry = {"dial_id": dial_id}
|
1066 |
+
for key in field:
|
1067 |
+
if key in ["dial_id"]:
|
1068 |
+
continue
|
1069 |
+
v = turn.get(key, "")
|
1070 |
+
if key == "turn_domain":
|
1071 |
+
v = " ".join(v)
|
1072 |
+
entry[key] = decode_fn(v, eos=eos_syntax[key]) if key in eos_syntax and v != "" else v
|
1073 |
+
results.append(entry)
|
1074 |
+
return results, field
|
1075 |
+
|
1076 |
+
def restore(self, resp, domain, constraint_dict, mat_ents):
|
1077 |
+
restored = resp
|
1078 |
+
|
1079 |
+
restored = restored.replace("[value_reference]", "53022")
|
1080 |
+
restored = restored.replace("[value_car]", "BMW")
|
1081 |
+
|
1082 |
+
# restored.replace('[value_phone]', '830-430-6666')
|
1083 |
+
for d in domain:
|
1084 |
+
constraint = constraint_dict.get(d, None)
|
1085 |
+
if constraint:
|
1086 |
+
if "stay" in constraint:
|
1087 |
+
restored = restored.replace("[value_stay]", constraint["stay"])
|
1088 |
+
if "day" in constraint:
|
1089 |
+
restored = restored.replace("[value_day]", constraint["day"])
|
1090 |
+
if "people" in constraint:
|
1091 |
+
restored = restored.replace("[value_people]", constraint["people"])
|
1092 |
+
if "time" in constraint:
|
1093 |
+
restored = restored.replace("[value_time]", constraint["time"])
|
1094 |
+
if "type" in constraint:
|
1095 |
+
restored = restored.replace("[value_type]", constraint["type"])
|
1096 |
+
if d in mat_ents and len(mat_ents[d]) == 0:
|
1097 |
+
for s in constraint:
|
1098 |
+
if s == "pricerange" and d in ["hotel", "restaurant"] and "price]" in restored:
|
1099 |
+
restored = restored.replace("[value_price]", constraint["pricerange"])
|
1100 |
+
if s + "]" in restored:
|
1101 |
+
restored = restored.replace("[value_%s]" % s, constraint[s])
|
1102 |
+
|
1103 |
+
if "[value_choice" in restored and mat_ents.get(d):
|
1104 |
+
restored = restored.replace("[value_choice]", str(len(mat_ents[d])))
|
1105 |
+
if "[value_choice" in restored:
|
1106 |
+
restored = restored.replace("[value_choice]", "3")
|
1107 |
+
|
1108 |
+
# restored.replace('[value_car]', 'BMW')
|
1109 |
+
|
1110 |
+
try:
|
1111 |
+
ent = mat_ents.get(domain[-1], [])
|
1112 |
+
if ent:
|
1113 |
+
ent = ent[0]
|
1114 |
+
|
1115 |
+
for t in restored.split():
|
1116 |
+
if "[value" in t:
|
1117 |
+
slot = t[7:-1]
|
1118 |
+
if ent.get(slot):
|
1119 |
+
if domain[-1] == "hotel" and slot == "price":
|
1120 |
+
slot = "pricerange"
|
1121 |
+
restored = restored.replace(t, ent[slot])
|
1122 |
+
elif slot == "price":
|
1123 |
+
if ent.get("pricerange"):
|
1124 |
+
restored = restored.replace(t, ent["pricerange"])
|
1125 |
+
else:
|
1126 |
+
print(restored, domain)
|
1127 |
+
except Exception:
|
1128 |
+
print(resp)
|
1129 |
+
print(restored)
|
1130 |
+
quit()
|
1131 |
+
|
1132 |
+
restored = restored.replace("[value_phone]", "62781111")
|
1133 |
+
restored = restored.replace("[value_postcode]", "CG9566")
|
1134 |
+
restored = restored.replace("[value_address]", "Parkside, Cambridge")
|
1135 |
+
|
1136 |
+
# if '[value_' in restored:
|
1137 |
+
|
1138 |
+
# print(domain)
|
1139 |
+
# # print(mat_ents)
|
1140 |
+
# print(resp)
|
1141 |
+
# print(restored)
|
1142 |
+
return restored
|
1143 |
+
|
1144 |
+
def record_utterance(self, result_dict):
|
1145 |
+
decode_fn = self.vocab.sentence_decode
|
1146 |
+
|
1147 |
+
ordered_dial = {}
|
1148 |
+
for dial_id, turns in result_dict.items():
|
1149 |
+
diverse = 0
|
1150 |
+
turn_count = 0
|
1151 |
+
for turn_no, turn in enumerate(turns):
|
1152 |
+
act_collect = {}
|
1153 |
+
act_type_collect = {}
|
1154 |
+
slot_score = 0
|
1155 |
+
for i in range(cfg.nbest):
|
1156 |
+
aspn = decode_fn(turn["multi_act"][i], eos=ontology.eos_tokens["aspn"])
|
1157 |
+
pred_acts = self.aspan_to_act_list(" ".join(aspn))
|
1158 |
+
act_type = ""
|
1159 |
+
for act in pred_acts:
|
1160 |
+
d, a, s = act.split("-")
|
1161 |
+
if d + "-" + a not in act_collect:
|
1162 |
+
act_collect[d + "-" + a] = {s: 1}
|
1163 |
+
slot_score += 1
|
1164 |
+
act_type += d + "-" + a + ";"
|
1165 |
+
elif s not in act_collect:
|
1166 |
+
act_collect[d + "-" + a][s] = 1
|
1167 |
+
slot_score += 1
|
1168 |
+
act_type_collect[act_type] = 1
|
1169 |
+
turn_count += 1
|
1170 |
+
diverse += len(act_collect) * 3 + slot_score
|
1171 |
+
ordered_dial[dial_id] = diverse / turn_count
|
1172 |
+
|
1173 |
+
ordered_dial = sorted(ordered_dial.keys(), key=lambda x: -ordered_dial[x])
|
1174 |
+
|
1175 |
+
dialog_record = {}
|
1176 |
+
|
1177 |
+
with open(cfg.eval_load_path + "/dialogue_record.csv", "w") as rf:
|
1178 |
+
writer = csv.writer(rf)
|
1179 |
+
|
1180 |
+
for dial_id in ordered_dial:
|
1181 |
+
dialog_record[dial_id] = []
|
1182 |
+
turns = result_dict[dial_id]
|
1183 |
+
writer.writerow([dial_id])
|
1184 |
+
for turn_no, turn in enumerate(turns):
|
1185 |
+
user = decode_fn(turn["user"], eos=ontology.eos_tokens["user"])
|
1186 |
+
bspn = decode_fn(turn["bspn"], eos=ontology.eos_tokens["bspn"])
|
1187 |
+
aspn = decode_fn(turn["aspn"], eos=ontology.eos_tokens["aspn"])
|
1188 |
+
resp = decode_fn(turn["resp"], eos=ontology.eos_tokens["resp"])
|
1189 |
+
constraint_dict = self.bspan_to_constraint_dict(bspn)
|
1190 |
+
# print(constraint_dict)
|
1191 |
+
mat_ents = self.db.get_match_num(constraint_dict, True)
|
1192 |
+
domain = [i[1:-1] for i in self.dspan_to_domain(turn["dspn"]).keys()]
|
1193 |
+
restored = self.restore(resp, domain, constraint_dict, mat_ents)
|
1194 |
+
writer.writerow([turn_no, user, turn["pointer"], domain, restored, resp])
|
1195 |
+
turn_record = {
|
1196 |
+
"user": user,
|
1197 |
+
"bspn": bspn,
|
1198 |
+
"aspn": aspn,
|
1199 |
+
"dom": domain,
|
1200 |
+
"resp": resp,
|
1201 |
+
"resp_res": restored,
|
1202 |
+
}
|
1203 |
+
|
1204 |
+
resp_col = []
|
1205 |
+
aspn_col = []
|
1206 |
+
resp_restore_col = []
|
1207 |
+
for i in range(cfg.nbest):
|
1208 |
+
aspn = decode_fn(turn["multi_act"][i], eos=ontology.eos_tokens["aspn"])
|
1209 |
+
resp = decode_fn(turn["multi_resp"][i], eos=ontology.eos_tokens["resp"])
|
1210 |
+
|
1211 |
+
restored = self.restore(resp, domain, constraint_dict, mat_ents)
|
1212 |
+
resp_col.append(resp)
|
1213 |
+
resp_restore_col.append(restored)
|
1214 |
+
aspn_col.append(aspn)
|
1215 |
+
|
1216 |
+
zipped = list(zip(resp_restore_col, resp_col, aspn_col))
|
1217 |
+
zipped.sort(key=lambda s: len(s[0]))
|
1218 |
+
resp_restore_col = list(list(zip(*zipped))[0])
|
1219 |
+
aspn_col = list(list(zip(*zipped))[2])
|
1220 |
+
resp_col = list(list(zip(*zipped))[1])
|
1221 |
+
turn_record["aspn_col"] = aspn_col
|
1222 |
+
turn_record["resp_col"] = resp_col
|
1223 |
+
turn_record["resp_res_col"] = resp_restore_col
|
1224 |
+
for i in range(cfg.nbest):
|
1225 |
+
# aspn = decode_fn(turn['multi_act'][i], eos=ontology.eos_tokens['aspn'])
|
1226 |
+
resp = resp_col[i]
|
1227 |
+
aspn = aspn_col[i]
|
1228 |
+
resp_restore = resp_restore_col[i]
|
1229 |
+
|
1230 |
+
writer.writerow(["", resp_restore, resp, aspn])
|
1231 |
+
|
1232 |
+
dialog_record[dial_id].append(turn_record)
|
1233 |
+
|
1234 |
+
# json.dump(dialog_record, open(cfg.eval_load_path + '/resultdict.json','w'))
|
1235 |
+
|
1236 |
+
|
1237 |
+
if __name__ == "__main__":
|
1238 |
+
reader = MultiWozReader(GPT2Tokenizer)
|
1239 |
+
# for aspan in ["[general] [bye] [welcome] <eos_a>","[train] [inform] trainid destination \
|
1240 |
+
# arrive leave [offerbook] [general] [reqmore] <eos_a>",]:
|
1241 |
+
# act = reader.aspan_to_constraint_dict(aspan.split())
|
1242 |
+
# print('!!!')
|
1243 |
+
# print(act)
|
1244 |
+
|
1245 |
+
for bspan in [
|
1246 |
+
"[taxi] destination golden house departure broughton house gallery arrive 19:30 [attraction]"
|
1247 |
+
+ " type museum name whipple museum of the history of science people 5 day monday",
|
1248 |
+
"[taxi] destination golden house departure broughton house gallery arrive 19:30 [attraction]"
|
1249 |
+
+ " type museum name whipple museum of the history of science people 5 day monday <eos_b>",
|
1250 |
+
]:
|
1251 |
+
encoded = reader.vocab.sentence_encode(bspan.split())
|
1252 |
+
print(encoded)
|
1253 |
+
cons = reader.bspan_to_constraint_dict(encoded, bspn_mode="bspn")
|
1254 |
+
print(cons)
|
1255 |
+
for bspan in [
|
1256 |
+
"[taxi] destination departure leave [hotel] name [attraction] name people day",
|
1257 |
+
"[taxi] destination departure leave [hotel] name [attraction] name people day <eos_b>",
|
1258 |
+
]:
|
1259 |
+
encoded = reader.vocab.sentence_encode(bspan.split())
|
1260 |
+
print(encoded)
|
1261 |
+
cons = reader.bspan_to_constraint_dict(encoded, bspn_mode="bsdx")
|
1262 |
+
print(cons)
|
src/crazyneuraluser/UBAR_code/utils.py
ADDED
@@ -0,0 +1,292 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import logging
|
3 |
+
from collections import OrderedDict
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
|
7 |
+
from crazyneuraluser.UBAR_code import ontology
|
8 |
+
|
9 |
+
|
10 |
+
def py2np(list):
|
11 |
+
return np.array(list)
|
12 |
+
|
13 |
+
|
14 |
+
def write_dict(fn, dic):
|
15 |
+
with open(fn, "w") as f:
|
16 |
+
json.dump(dic, f, indent=2)
|
17 |
+
|
18 |
+
|
19 |
+
def f1_score(label_list, pred_list):
|
20 |
+
tp = len([t for t in pred_list if t in label_list])
|
21 |
+
fp = max(0, len(pred_list) - tp)
|
22 |
+
fn = max(0, len(label_list) - tp)
|
23 |
+
precision = tp / (tp + fp + 1e-10)
|
24 |
+
recall = tp / (tp + fn + 1e-10)
|
25 |
+
f1 = 2 * precision * recall / (precision + recall + 1e-10)
|
26 |
+
return f1
|
27 |
+
|
28 |
+
|
29 |
+
class Vocab(object):
|
30 |
+
def __init__(self, vocab_size=0):
|
31 |
+
self.vocab_size = vocab_size
|
32 |
+
self.vocab_size_oov = 0 # get after construction
|
33 |
+
self._idx2word = {} # word + oov
|
34 |
+
self._word2idx = {} # word
|
35 |
+
self._freq_dict = {} # word + oov
|
36 |
+
for w in [
|
37 |
+
"<pad>",
|
38 |
+
"<go_r>",
|
39 |
+
"<unk>",
|
40 |
+
"<go_b>",
|
41 |
+
"<go_a>",
|
42 |
+
"<eos_u>",
|
43 |
+
"<eos_r>",
|
44 |
+
"<eos_b>",
|
45 |
+
"<eos_a>",
|
46 |
+
"<go_d>",
|
47 |
+
"<eos_d>",
|
48 |
+
]:
|
49 |
+
self._absolute_add_word(w)
|
50 |
+
|
51 |
+
def _absolute_add_word(self, w):
|
52 |
+
idx = len(self._idx2word)
|
53 |
+
self._idx2word[idx] = w
|
54 |
+
self._word2idx[w] = idx
|
55 |
+
|
56 |
+
def add_word(self, word):
|
57 |
+
if word not in self._freq_dict:
|
58 |
+
self._freq_dict[word] = 0
|
59 |
+
self._freq_dict[word] += 1
|
60 |
+
|
61 |
+
def has_word(self, word):
|
62 |
+
return self._freq_dict.get(word)
|
63 |
+
|
64 |
+
def _add_to_vocab(self, word):
|
65 |
+
if word not in self._word2idx:
|
66 |
+
idx = len(self._idx2word)
|
67 |
+
self._idx2word[idx] = word
|
68 |
+
self._word2idx[word] = idx
|
69 |
+
|
70 |
+
def construct(self):
|
71 |
+
decoded = sorted(self._freq_dict.keys(), key=lambda x: -self._freq_dict[x])
|
72 |
+
print("Vocabulary size including oov: %d" % (len(decoded) + len(self._idx2word)))
|
73 |
+
if len(decoded) + len(self._idx2word) < self.vocab_size:
|
74 |
+
logging.warning(
|
75 |
+
"actual label set smaller than that configured: {}/{}".format(
|
76 |
+
len(decoded) + len(self._idx2word), self.vocab_size
|
77 |
+
)
|
78 |
+
)
|
79 |
+
for word in ontology.all_domains + ["general"]:
|
80 |
+
word = "[" + word + "]"
|
81 |
+
self._add_to_vocab(word)
|
82 |
+
for word in ontology.all_acts:
|
83 |
+
word = "[" + word + "]"
|
84 |
+
self._add_to_vocab(word)
|
85 |
+
for word in ontology.all_slots:
|
86 |
+
self._add_to_vocab(word)
|
87 |
+
for word in decoded:
|
88 |
+
if word.startswith("[value_") and word.endswith("]"):
|
89 |
+
self._add_to_vocab(word)
|
90 |
+
for word in decoded:
|
91 |
+
self._add_to_vocab(word)
|
92 |
+
self.vocab_size_oov = len(self._idx2word)
|
93 |
+
|
94 |
+
def load_vocab(self, vocab_path):
|
95 |
+
self._freq_dict = json.loads(open(vocab_path + ".freq.json", "r").read())
|
96 |
+
self._word2idx = json.loads(open(vocab_path + ".word2idx.json", "r").read())
|
97 |
+
self._idx2word = {}
|
98 |
+
for w, idx in self._word2idx.items():
|
99 |
+
self._idx2word[idx] = w
|
100 |
+
self.vocab_size_oov = len(self._idx2word)
|
101 |
+
print('vocab file loaded from "' + vocab_path + '"')
|
102 |
+
print("Vocabulary size including oov: %d" % (self.vocab_size_oov))
|
103 |
+
|
104 |
+
def save_vocab(self, vocab_path):
|
105 |
+
_freq_dict = OrderedDict(sorted(self._freq_dict.items(), key=lambda kv: kv[1], reverse=True))
|
106 |
+
|
107 |
+
write_dict(vocab_path + ".word2idx.json", self._word2idx)
|
108 |
+
write_dict(vocab_path + ".freq.json", _freq_dict)
|
109 |
+
|
110 |
+
def encode(self, word, include_oov=True):
|
111 |
+
if include_oov:
|
112 |
+
if self._word2idx.get(word, None) is None:
|
113 |
+
raise ValueError("Unknown word: %s. Vocabulary should include oovs here." % word)
|
114 |
+
return self._word2idx[word]
|
115 |
+
else:
|
116 |
+
word = "<unk>" if word not in self._word2idx else word
|
117 |
+
return self._word2idx[word]
|
118 |
+
|
119 |
+
def sentence_encode(self, word_list):
|
120 |
+
return [self.encode(_) for _ in word_list]
|
121 |
+
|
122 |
+
def oov_idx_map(self, idx):
|
123 |
+
return 2 if idx > self.vocab_size else idx
|
124 |
+
|
125 |
+
def sentence_oov_map(self, index_list):
|
126 |
+
return [self.oov_idx_map(_) for _ in index_list]
|
127 |
+
|
128 |
+
def decode(self, idx, indicate_oov=False):
|
129 |
+
if not self._idx2word.get(idx):
|
130 |
+
raise ValueError("Error idx: %d. Vocabulary should include oovs here." % idx)
|
131 |
+
if not indicate_oov or idx < self.vocab_size:
|
132 |
+
return self._idx2word[idx]
|
133 |
+
else:
|
134 |
+
return self._idx2word[idx] + "(o)"
|
135 |
+
|
136 |
+
def sentence_decode(self, index_list, eos=None, indicate_oov=False):
|
137 |
+
decoded = [self.decode(_, indicate_oov) for _ in index_list]
|
138 |
+
if not eos or eos not in decoded:
|
139 |
+
return " ".join(decoded)
|
140 |
+
else:
|
141 |
+
idx = decoded.index(eos)
|
142 |
+
return " ".join(decoded[:idx])
|
143 |
+
|
144 |
+
def nl_decode(self, decoded, eos=None):
|
145 |
+
return [self.sentence_decode(_, eos) + "\n" for _ in decoded]
|
146 |
+
|
147 |
+
|
148 |
+
def padSeqs_gpt(sequences, pad_id, maxlen=None):
|
149 |
+
lengths = []
|
150 |
+
for x in sequences:
|
151 |
+
lengths.append(len(x))
|
152 |
+
|
153 |
+
num_samples = len(sequences)
|
154 |
+
seq_mexlen = np.max(lengths)
|
155 |
+
|
156 |
+
# maxlen = 1024
|
157 |
+
if seq_mexlen > 1024: # gpt2.n_ctx
|
158 |
+
# print('maxlen exceeds 1024')
|
159 |
+
maxlen = 1024
|
160 |
+
else:
|
161 |
+
maxlen = seq_mexlen
|
162 |
+
|
163 |
+
# tokenizer.encode('<|endoftext|>') = ['50256']
|
164 |
+
# All labels set to ``-100`` are ignored (masked), the loss is only
|
165 |
+
# computed for labels in ``[0, ..., config.vocab_size]`` (from modeling_gpt2.GPT2LMHeadModel)
|
166 |
+
|
167 |
+
x = np.ones((num_samples, maxlen)) * pad_id
|
168 |
+
for idx, s in enumerate(sequences):
|
169 |
+
if not len(s):
|
170 |
+
print("empty list was found in padSeqs")
|
171 |
+
# trunc method = 'pre'
|
172 |
+
trunc = s[-maxlen:]
|
173 |
+
trunc = np.asarray(trunc)
|
174 |
+
|
175 |
+
# pad method = 'post'
|
176 |
+
x[idx, : len(trunc)] = trunc
|
177 |
+
|
178 |
+
return x, lengths
|
179 |
+
|
180 |
+
|
181 |
+
def padSeqs(
|
182 |
+
sequences,
|
183 |
+
maxlen=None,
|
184 |
+
truncated=False,
|
185 |
+
pad_method="post",
|
186 |
+
trunc_method="pre",
|
187 |
+
dtype="int32",
|
188 |
+
value=0.0,
|
189 |
+
):
|
190 |
+
if not hasattr(sequences, "__len__"):
|
191 |
+
raise ValueError("`sequences` must be iterable.")
|
192 |
+
lengths = []
|
193 |
+
for x in sequences:
|
194 |
+
if not hasattr(x, "__len__"):
|
195 |
+
raise ValueError("`sequences` must be a list of iterables. " "Found non-iterable: " + str(x))
|
196 |
+
lengths.append(len(x))
|
197 |
+
|
198 |
+
num_samples = len(sequences)
|
199 |
+
seq_maxlen = np.max(lengths)
|
200 |
+
|
201 |
+
if maxlen is not None and truncated:
|
202 |
+
maxlen = min(seq_maxlen, maxlen)
|
203 |
+
else:
|
204 |
+
maxlen = seq_maxlen
|
205 |
+
# take the sample shape from the first non empty sequence
|
206 |
+
# checking for consistency in the main loop below.
|
207 |
+
sample_shape = tuple()
|
208 |
+
for s in sequences:
|
209 |
+
if len(s) > 0:
|
210 |
+
sample_shape = np.asarray(s).shape[1:]
|
211 |
+
break
|
212 |
+
|
213 |
+
x = (np.ones((num_samples, maxlen) + sample_shape) * value).astype(dtype)
|
214 |
+
for idx, s in enumerate(sequences):
|
215 |
+
if not len(s):
|
216 |
+
print("empty list/array was found")
|
217 |
+
continue # empty list/array was found
|
218 |
+
if trunc_method == "pre":
|
219 |
+
trunc = s[-maxlen:]
|
220 |
+
elif trunc_method == "post":
|
221 |
+
trunc = s[:maxlen]
|
222 |
+
else:
|
223 |
+
raise ValueError('Truncating type "%s" not understood' % trunc_method)
|
224 |
+
|
225 |
+
# check `trunc` has expected shape
|
226 |
+
trunc = np.asarray(trunc, dtype=dtype)
|
227 |
+
if trunc.shape[1:] != sample_shape:
|
228 |
+
raise ValueError(
|
229 |
+
"Shape of sample %s of sequence at position %s is different from expected shape %s"
|
230 |
+
% (trunc.shape[1:], idx, sample_shape)
|
231 |
+
)
|
232 |
+
|
233 |
+
if pad_method == "post":
|
234 |
+
x[idx, : len(trunc)] = trunc
|
235 |
+
elif pad_method == "pre":
|
236 |
+
x[idx, -len(trunc) :] = trunc
|
237 |
+
else:
|
238 |
+
raise ValueError('Padding type "%s" not understood' % pad_method)
|
239 |
+
return x
|
240 |
+
|
241 |
+
|
242 |
+
def get_glove_matrix(glove_path, vocab, initial_embedding_np):
|
243 |
+
"""
|
244 |
+
return a glove embedding matrix
|
245 |
+
:param self:
|
246 |
+
:param glove_file:
|
247 |
+
:param initial_embedding_np:
|
248 |
+
:return: np array of [V,E]
|
249 |
+
"""
|
250 |
+
ef = open(glove_path, "r", encoding="UTF-8")
|
251 |
+
cnt = 0
|
252 |
+
vec_array = initial_embedding_np
|
253 |
+
old_avg = np.average(vec_array)
|
254 |
+
old_std = np.std(vec_array)
|
255 |
+
vec_array = vec_array.astype(np.float32)
|
256 |
+
new_avg, new_std = 0, 0
|
257 |
+
|
258 |
+
for line in ef.readlines():
|
259 |
+
line = line.strip().split(" ")
|
260 |
+
word, vec = line[0], line[1:]
|
261 |
+
vec = np.array(vec, np.float32)
|
262 |
+
if not vocab.has_word(word):
|
263 |
+
continue
|
264 |
+
word_idx = vocab.encode(word)
|
265 |
+
if word_idx < vocab.vocab_size:
|
266 |
+
cnt += 1
|
267 |
+
vec_array[word_idx] = vec
|
268 |
+
new_avg += np.average(vec)
|
269 |
+
new_std += np.std(vec)
|
270 |
+
new_avg /= cnt
|
271 |
+
new_std /= cnt
|
272 |
+
ef.close()
|
273 |
+
logging.info(
|
274 |
+
"%d known embedding. old mean: %f new mean %f, old std %f new std %f"
|
275 |
+
% (cnt, old_avg, new_avg, old_std, new_std)
|
276 |
+
)
|
277 |
+
return vec_array
|
278 |
+
|
279 |
+
|
280 |
+
def position_encoding_init(self, n_position, d_pos_vec):
|
281 |
+
position_enc = np.array(
|
282 |
+
[
|
283 |
+
[pos / np.power(10000, 2 * (j // 2) / d_pos_vec) for j in range(d_pos_vec)]
|
284 |
+
if pos != 0
|
285 |
+
else np.zeros(d_pos_vec)
|
286 |
+
for pos in range(n_position)
|
287 |
+
]
|
288 |
+
)
|
289 |
+
|
290 |
+
position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i
|
291 |
+
position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1
|
292 |
+
return position_enc
|
src/crazyneuraluser/user_model_code/analysis_multiwoz.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
DATA_SPLIT = ["train", "dev", "test"]
|
5 |
+
|
6 |
+
|
7 |
+
def _check_n_turns(data, data_act):
|
8 |
+
for split in DATA_SPLIT:
|
9 |
+
for dial_id, meta in data[split].items():
|
10 |
+
n_in_meta = len(meta["turns"])
|
11 |
+
|
12 |
+
assert dial_id in data_act
|
13 |
+
n_in_act = len(data_act[dial_id])
|
14 |
+
assert n_in_meta == n_in_act
|
15 |
+
|
16 |
+
|
17 |
+
def collect_data(data_path, remove_dial_switch=False):
|
18 |
+
# load act
|
19 |
+
act_file = os.path.join(data_path, "dialog_acts.json")
|
20 |
+
with open(act_file) as f:
|
21 |
+
data_act = json.load(f)
|
22 |
+
print("Load {} dialogues in act file".format(len(data_act)))
|
23 |
+
|
24 |
+
# load data
|
25 |
+
data = {}
|
26 |
+
for split in DATA_SPLIT:
|
27 |
+
data[split] = iter_data_folder(data_path, split, remove_dial_switch, data_act)
|
28 |
+
|
29 |
+
_check_n_turns(data, data_act)
|
30 |
+
return data, data_act
|
31 |
+
|
32 |
+
|
33 |
+
def remove_dial(dial_id, dial, dial_act):
|
34 |
+
# check services
|
35 |
+
services = dial["services"]
|
36 |
+
if "police" in services or "bus" in services or "hospital" in services:
|
37 |
+
return True
|
38 |
+
|
39 |
+
# check act
|
40 |
+
domains = set()
|
41 |
+
for turn_id, turn_act in dial_act.items():
|
42 |
+
dialogue_act = turn_act["dialog_act"]
|
43 |
+
for dact in dialogue_act:
|
44 |
+
assert "-" in dact
|
45 |
+
domain, act = dact.split("-")
|
46 |
+
domains.add(domain)
|
47 |
+
if "Police" in domains or "Bus" in domains or "Hospital" in domains:
|
48 |
+
return True
|
49 |
+
return False
|
50 |
+
|
51 |
+
|
52 |
+
def iter_data_folder(data_path, split, remove_dial_switch, data_act):
|
53 |
+
"""Iterate data folder"""
|
54 |
+
split_dir = os.path.join(data_path, split)
|
55 |
+
data_split = {}
|
56 |
+
remove_dial_ids = []
|
57 |
+
total_dial_ids = []
|
58 |
+
for f in os.listdir(split_dir):
|
59 |
+
if not f.startswith("dialogues"): # skip schema.json
|
60 |
+
continue
|
61 |
+
file_path = os.path.join(data_path, split, f)
|
62 |
+
iter_file(
|
63 |
+
file_path,
|
64 |
+
data_split,
|
65 |
+
remove_dial_ids,
|
66 |
+
total_dial_ids,
|
67 |
+
remove_dial_switch,
|
68 |
+
data_act,
|
69 |
+
)
|
70 |
+
print(
|
71 |
+
"Done collecting {} | total {} dialogues | load {} dialogues | remove {} dialogues".format(
|
72 |
+
split, len(total_dial_ids), len(data_split), len(remove_dial_ids)
|
73 |
+
)
|
74 |
+
)
|
75 |
+
return data_split
|
76 |
+
|
77 |
+
|
78 |
+
def iter_file(
|
79 |
+
file_path, data_split, remove_dial_ids, total_dial_ids, remove_dial_switch, data_act
|
80 |
+
):
|
81 |
+
with open(file_path) as f:
|
82 |
+
data_in = json.load(f) # list of dialouges in a json file
|
83 |
+
|
84 |
+
for dial in data_in:
|
85 |
+
dial_id = dial["dialogue_id"]
|
86 |
+
total_dial_ids.append(dial_id)
|
87 |
+
dial_act = data_act[dial_id]
|
88 |
+
|
89 |
+
if remove_dial_switch and remove_dial(dial_id, dial, dial_act):
|
90 |
+
remove_dial_ids.append(dial_id)
|
91 |
+
else:
|
92 |
+
data_split[dial_id] = dial
|
93 |
+
|
94 |
+
|
95 |
+
def show_dial(dial_id, data, data_act):
|
96 |
+
def simple_linearise_act(dialouge_act):
|
97 |
+
linear_act = ""
|
98 |
+
for domain_act, slot_value_list in dialouge_act.items():
|
99 |
+
linear_act += domain_act + " "
|
100 |
+
for slot_value in slot_value_list:
|
101 |
+
slot, value = slot_value[0], slot_value[1]
|
102 |
+
linear_act += slot + " "
|
103 |
+
linear_act += value + " "
|
104 |
+
return linear_act
|
105 |
+
|
106 |
+
split = None
|
107 |
+
for data_split in DATA_SPLIT:
|
108 |
+
if dial_id in data[data_split]:
|
109 |
+
split = data_split
|
110 |
+
break
|
111 |
+
|
112 |
+
print("dial_id: {}".format(dial_id))
|
113 |
+
for turn_id, turn in enumerate(data[split][dial_id]["turns"]):
|
114 |
+
dialouge_act = data_act[dial_id][str(turn_id)]["dialog_act"]
|
115 |
+
linear_act = simple_linearise_act(dialouge_act)
|
116 |
+
print("-----" * 15)
|
117 |
+
print("turn_id: {}, spk: {}".format(turn_id, turn["speaker"]))
|
118 |
+
print("act: |{}|".format(linear_act))
|
119 |
+
print("utt: |{}|".format(turn["utterance"]))
|
src/crazyneuraluser/user_model_code/analysis_sgd.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
from utils_sgd import (
|
5 |
+
bcolors,
|
6 |
+
compare_slot_values_in_state,
|
7 |
+
dict2str,
|
8 |
+
get_turn_act,
|
9 |
+
list2str,
|
10 |
+
)
|
11 |
+
|
12 |
+
""" This file contains some utilities for analysis and parsing SGD """
|
13 |
+
|
14 |
+
DATA_SPLIT = ["train", "dev", "test"]
|
15 |
+
|
16 |
+
|
17 |
+
def collect_data(data_path, remove_dial_switch=False):
|
18 |
+
data = {}
|
19 |
+
for split in DATA_SPLIT:
|
20 |
+
data[split] = iter_data_folder(data_path, split, remove_dial_switch)
|
21 |
+
return data
|
22 |
+
|
23 |
+
|
24 |
+
def _remove_dial(dial_id, dial):
|
25 |
+
# remove_flag = False
|
26 |
+
# removes service `Homes_2` in test set as the slot `intent` is the same name as the user intent,
|
27 |
+
# which causes problem in goal preparation
|
28 |
+
if "Homes_2" in dial["services"]:
|
29 |
+
return True
|
30 |
+
return False
|
31 |
+
|
32 |
+
|
33 |
+
def iter_data_folder(data_path, split, remove_dial_switch):
|
34 |
+
"""Iterate data split folder"""
|
35 |
+
split_dir = os.path.join(data_path, split)
|
36 |
+
data_split = {}
|
37 |
+
remove_dial_ids = []
|
38 |
+
total_dial_ids = []
|
39 |
+
for f in os.listdir(split_dir):
|
40 |
+
if not f.startswith("dialogues"): # skip schema.json
|
41 |
+
continue
|
42 |
+
file_path = os.path.join(data_path, split, f)
|
43 |
+
iter_file(
|
44 |
+
file_path, data_split, remove_dial_ids, total_dial_ids, remove_dial_switch
|
45 |
+
)
|
46 |
+
print(
|
47 |
+
"Done collecting {} | total {} dialogues | load {} dialogues | remove {} dialogues".format(
|
48 |
+
split, len(total_dial_ids), len(data_split), len(remove_dial_ids)
|
49 |
+
)
|
50 |
+
)
|
51 |
+
return data_split
|
52 |
+
|
53 |
+
|
54 |
+
def iter_file(
|
55 |
+
file_path, data_split, remove_dial_ids, total_dial_ids, remove_dial_switch
|
56 |
+
):
|
57 |
+
"""Iterate data file"""
|
58 |
+
with open(file_path) as f:
|
59 |
+
data_in = json.load(f) # list of dialouges in a json file
|
60 |
+
|
61 |
+
for dial in data_in:
|
62 |
+
dial_id = dial["dialogue_id"]
|
63 |
+
total_dial_ids.append(dial_id)
|
64 |
+
|
65 |
+
if remove_dial_switch and _remove_dial(dial_id, dial):
|
66 |
+
remove_dial_ids.append(dial_id)
|
67 |
+
else:
|
68 |
+
data_split[dial_id] = dial
|
69 |
+
|
70 |
+
|
71 |
+
def check_multiple_services_per_turn(data):
|
72 |
+
for split in DATA_SPLIT:
|
73 |
+
for dial_id in sorted(data[split].keys()):
|
74 |
+
dial = data[split][dial_id]
|
75 |
+
for turn_id, turn in enumerate(dial["turns"]):
|
76 |
+
frames = turn["frames"]
|
77 |
+
if len(frames) > 1:
|
78 |
+
print(split, dial_id, turn_id, turn["utterance"])
|
79 |
+
|
80 |
+
|
81 |
+
def show_actions(actions):
|
82 |
+
for action_id, action in enumerate(actions):
|
83 |
+
act, slot, values = action["act"], action["slot"], action["values"]
|
84 |
+
print(
|
85 |
+
f"====> ACTION | Act {action_id}: {bcolors.RED}{act}{bcolors.ENDC}, \
|
86 |
+
slot: {bcolors.YELLOW}{slot}{bcolors.ENDC}, values: {bcolors.GREEN}{values}{bcolors.ENDC}"
|
87 |
+
)
|
88 |
+
|
89 |
+
|
90 |
+
def show_user_state(frame):
|
91 |
+
state = frame["state"]
|
92 |
+
active_intent = state["active_intent"]
|
93 |
+
req_slots = list2str(state["requested_slots"])
|
94 |
+
slot2value = dict2str(state["slot_values"], colored=True)
|
95 |
+
print(
|
96 |
+
"====> STATE | intent: {}, req_slots: {}, slot2value: {}".format(
|
97 |
+
active_intent, req_slots, slot2value
|
98 |
+
)
|
99 |
+
)
|
100 |
+
|
101 |
+
|
102 |
+
def show_service_call(frame):
|
103 |
+
if "service_call" not in frame:
|
104 |
+
return
|
105 |
+
# system calls api
|
106 |
+
service_call, service_results = frame["service_call"], frame["service_results"]
|
107 |
+
print(
|
108 |
+
"====> API call | method: {}, args: {}, results: {}".format(
|
109 |
+
service_call["method"],
|
110 |
+
dict2str(service_call["parameters"]),
|
111 |
+
len(service_results),
|
112 |
+
)
|
113 |
+
)
|
114 |
+
|
115 |
+
|
116 |
+
def show_frame(spk, frame_id, frame):
|
117 |
+
service = frame["service"]
|
118 |
+
print("==> Frame_id: {}, service: {}".format(frame_id, service))
|
119 |
+
|
120 |
+
# actions (include all slots)
|
121 |
+
show_actions(frame["actions"])
|
122 |
+
|
123 |
+
# slots (only provide non-categorical slots with word span boundaries)
|
124 |
+
if spk == "USER":
|
125 |
+
show_user_state(frame)
|
126 |
+
else: # system
|
127 |
+
show_service_call(frame)
|
128 |
+
|
129 |
+
|
130 |
+
def show_turn(turn_id, turn):
|
131 |
+
if turn is None:
|
132 |
+
return
|
133 |
+
|
134 |
+
frames = turn["frames"]
|
135 |
+
spk = turn["speaker"]
|
136 |
+
utt = turn["utterance"]
|
137 |
+
assert spk in ["USER", "SYSTEM"]
|
138 |
+
print(f"{spk}: {bcolors.UNDERLINE}{utt}{bcolors.ENDC}")
|
139 |
+
for frame_id, frame in enumerate(frames):
|
140 |
+
show_frame(spk, frame_id, frame)
|
141 |
+
print("------" * 15)
|
142 |
+
|
143 |
+
|
144 |
+
def show_dial_info(dial_id, dial):
|
145 |
+
print("\n")
|
146 |
+
print("******" * 15)
|
147 |
+
print("Dialogue={} | Service={}".format(dial_id, list2str(dial["services"])))
|
148 |
+
print("******" * 15)
|
149 |
+
|
150 |
+
|
151 |
+
def show_dial(dial_id, dial):
|
152 |
+
show_dial_info(dial_id, dial)
|
153 |
+
for turn_id, turn in enumerate(dial["turns"]):
|
154 |
+
show_turn(turn_id, turn)
|
155 |
+
|
156 |
+
|
157 |
+
def show_data(data):
|
158 |
+
for split in DATA_SPLIT:
|
159 |
+
for dial_id in sorted(data[split].keys()):
|
160 |
+
dial = data[split][dial_id]
|
161 |
+
show_dial(dial_id, dial)
|
162 |
+
input("press...")
|
163 |
+
|
164 |
+
|
165 |
+
def identify_scenarios(data):
|
166 |
+
"""
|
167 |
+
According to dataset paper, a scenario is a sequence of intents, seeded at the start of a conversation
|
168 |
+
to the user agent
|
169 |
+
"""
|
170 |
+
# TODO: deal with NONE intent, check the # of intent seq conbinations
|
171 |
+
for split in DATA_SPLIT:
|
172 |
+
scenario2dialogues = {}
|
173 |
+
n_scenario_max, n_scenario_min = 0, 100
|
174 |
+
for dial_id in sorted(data[split].keys()):
|
175 |
+
dial = data[split][dial_id]
|
176 |
+
scenario = []
|
177 |
+
for turn in dial["turns"]:
|
178 |
+
if turn["speaker"] == "SYSTEM":
|
179 |
+
continue
|
180 |
+
# USER turn
|
181 |
+
# it's fine to consider only first frame (service) if the turn is at the bounrary between two services
|
182 |
+
frame = turn["frames"][0]
|
183 |
+
intent = frame["state"]["active_intent"]
|
184 |
+
if intent == "NONE":
|
185 |
+
continue
|
186 |
+
if len(scenario) == 0 or intent != scenario[-1]:
|
187 |
+
scenario.append(intent)
|
188 |
+
|
189 |
+
# update count
|
190 |
+
if len(scenario) > n_scenario_max:
|
191 |
+
n_scenario_max = len(scenario)
|
192 |
+
if len(scenario) < n_scenario_min:
|
193 |
+
n_scenario_min = len(scenario)
|
194 |
+
|
195 |
+
scenario = list2str(scenario)
|
196 |
+
if scenario not in scenario2dialogues:
|
197 |
+
scenario2dialogues[scenario] = []
|
198 |
+
scenario2dialogues[scenario].append(dial_id)
|
199 |
+
|
200 |
+
# done iter over split
|
201 |
+
print(
|
202 |
+
"Summary: split={}, unique_scenario={}, max_intent={}, min_intent={}".format(
|
203 |
+
split, len(scenario2dialogues), n_scenario_max, n_scenario_min
|
204 |
+
)
|
205 |
+
)
|
206 |
+
|
207 |
+
|
208 |
+
def _check_request_alts_type(prev_turn, sys_turn, curr_turn, curr_acts):
|
209 |
+
"""
|
210 |
+
check which of the following happens when request_alts
|
211 |
+
1. randomly change goal (state changes)
|
212 |
+
2. request_alts as system provides venue with missing slot-value (usr provides new info)
|
213 |
+
3. simply dislike the provided venue, change venue without new slot-value (same info)
|
214 |
+
|
215 |
+
Input:
|
216 |
+
prev_turn: previous user turn
|
217 |
+
curr_turn: current user turn
|
218 |
+
"""
|
219 |
+
|
220 |
+
def _get_intent2state(turn):
|
221 |
+
intent2state = {}
|
222 |
+
for frame in turn["frames"]:
|
223 |
+
state = frame["state"]
|
224 |
+
intent = state["active_intent"]
|
225 |
+
intent2state[intent] = state
|
226 |
+
return intent2state
|
227 |
+
|
228 |
+
assert "REQUEST_ALTS" in curr_acts
|
229 |
+
if len(curr_acts) == 1: # case 3
|
230 |
+
# return "_dislike_"
|
231 |
+
if "OFFER" in get_turn_act(sys_turn):
|
232 |
+
return "_dislike_offer_"
|
233 |
+
else:
|
234 |
+
return "_dislike_info_"
|
235 |
+
elif (
|
236 |
+
"INFORM" in curr_acts and len(set(curr_acts)) == 2
|
237 |
+
): # only inform and request_alts
|
238 |
+
assert len(curr_turn["frames"]) == 1
|
239 |
+
curr_slot_values = curr_turn["frames"][0]["state"]["slot_values"]
|
240 |
+
curr_intent = curr_turn["frames"][0]["state"]["active_intent"]
|
241 |
+
|
242 |
+
if len(prev_turn["frames"]) == 1:
|
243 |
+
prev_slot_values = prev_turn["frames"][0]["state"]["slot_values"]
|
244 |
+
else: # need to get the state with the same intent
|
245 |
+
intent2state = _get_intent2state(prev_turn)
|
246 |
+
prev_slot_values = intent2state[curr_intent]["slot_values"]
|
247 |
+
|
248 |
+
state_diff = compare_slot_values_in_state(prev_slot_values, curr_slot_values)
|
249 |
+
if state_diff: # case 1
|
250 |
+
return "_random_"
|
251 |
+
else: # case 2
|
252 |
+
return "_miss_"
|
253 |
+
else:
|
254 |
+
return "_unknown_"
|
255 |
+
|
256 |
+
|
257 |
+
def stats_request_alts_type(data):
|
258 |
+
for split in DATA_SPLIT:
|
259 |
+
stats = {
|
260 |
+
"_random_": 0,
|
261 |
+
"_miss_": 0,
|
262 |
+
"_dislike_offer_": 0,
|
263 |
+
"_dislike_info_": 0,
|
264 |
+
"_unknown_": 0,
|
265 |
+
}
|
266 |
+
n_all_usr_turn, n_request_alts = 0, 0
|
267 |
+
|
268 |
+
for dial_id in sorted(data[split].keys()):
|
269 |
+
dial = data[split][dial_id]
|
270 |
+
for turn_id, turn in enumerate(dial["turns"]):
|
271 |
+
prev_turn = turn
|
272 |
+
if turn["speaker"] == "SYSTEM":
|
273 |
+
sys_turn = turn
|
274 |
+
continue
|
275 |
+
acts = get_turn_act(turn)
|
276 |
+
if "REQUEST_ALTS" in acts:
|
277 |
+
n_request_alts += 1
|
278 |
+
type_result = _check_request_alts_type(
|
279 |
+
prev_turn, sys_turn, turn, acts
|
280 |
+
)
|
281 |
+
stats[type_result] += 1
|
282 |
+
if type_result == "_random_":
|
283 |
+
print("CASE {}".format(type_result))
|
284 |
+
show_turn(0, prev_turn)
|
285 |
+
show_turn(0, sys_turn)
|
286 |
+
show_turn(0, turn)
|
287 |
+
input("press...")
|
288 |
+
n_all_usr_turn += 1
|
289 |
+
prev_turn = turn
|
290 |
+
|
291 |
+
print("REQUEST_ALTS type statistics")
|
292 |
+
for k, v in stats.items():
|
293 |
+
print("{} => {}".format(k, v))
|
294 |
+
print(
|
295 |
+
"request_alts turns: {}, all usr turns: {}, dialogues: {}".format(
|
296 |
+
n_request_alts, n_all_usr_turn, len(data[split])
|
297 |
+
)
|
298 |
+
)
|
299 |
+
|
300 |
+
|
301 |
+
def show_utt_by_act(data):
|
302 |
+
target_act = "OFFER"
|
303 |
+
for split in DATA_SPLIT:
|
304 |
+
for dial_id in sorted(data[split].keys()):
|
305 |
+
dial = data[split][dial_id]
|
306 |
+
match_flag = False
|
307 |
+
for turn_id, turn in enumerate(dial["turns"]):
|
308 |
+
acts = get_turn_act(turn)
|
309 |
+
if target_act in acts:
|
310 |
+
match_flag = True
|
311 |
+
if match_flag:
|
312 |
+
show_dial(dial_id, dial)
|
313 |
+
input("press...")
|
314 |
+
|
315 |
+
|
316 |
+
def show_state_with_value_change(data):
|
317 |
+
for split in DATA_SPLIT:
|
318 |
+
for dial_id in sorted(data[split].keys()):
|
319 |
+
dial = data[split][dial_id]
|
320 |
+
intent2slot_values = {}
|
321 |
+
for turn_id, turn in enumerate(dial["turns"]):
|
322 |
+
utt, spk = turn["utterance"], turn["speaker"]
|
323 |
+
if spk != "USER":
|
324 |
+
prev_system_turn = turn
|
325 |
+
continue
|
326 |
+
for frame in turn["frames"]:
|
327 |
+
state = frame["state"]
|
328 |
+
active_intent = state["active_intent"]
|
329 |
+
slot_values = state["slot_values"]
|
330 |
+
if active_intent in intent2slot_values:
|
331 |
+
state_diff = compare_slot_values_in_state(
|
332 |
+
intent2slot_values[active_intent], slot_values
|
333 |
+
)
|
334 |
+
if state_diff:
|
335 |
+
print(
|
336 |
+
"Dial: {}, state change: {}".format(dial_id, state_diff)
|
337 |
+
)
|
338 |
+
print(
|
339 |
+
"==> Prev SYS: {}".format(prev_system_turn["utterance"])
|
340 |
+
)
|
341 |
+
for sys_frame in prev_system_turn["frames"]:
|
342 |
+
show_actions(sys_frame["actions"])
|
343 |
+
print("==> Curr USR: {}".format(utt))
|
344 |
+
show_actions(frame["actions"])
|
345 |
+
print(
|
346 |
+
"recorded state => intent: {}, slot2value: {}".format(
|
347 |
+
active_intent,
|
348 |
+
dict2str(intent2slot_values[active_intent]),
|
349 |
+
)
|
350 |
+
)
|
351 |
+
print(
|
352 |
+
"current state => intent: {}, slot2value: {}".format(
|
353 |
+
active_intent, dict2str(slot_values)
|
354 |
+
)
|
355 |
+
)
|
356 |
+
input("press...")
|
357 |
+
intent2slot_values[
|
358 |
+
active_intent
|
359 |
+
] = slot_values # overlap with new state, no matter values changed or not
|
360 |
+
|
361 |
+
|
362 |
+
def check_state_with_value_change(data, display=False):
|
363 |
+
for split in DATA_SPLIT:
|
364 |
+
n_diff = {"NOTIFY_FAILURE": 0, "NEGATE": 0, "REQUEST_ALTS": 0, "RANDOM": 0}
|
365 |
+
for dial_id in sorted(data[split].keys()):
|
366 |
+
dial = data[split][dial_id]
|
367 |
+
intent2slot_values = {}
|
368 |
+
diff_flag = False
|
369 |
+
for turn_id, turn in enumerate(dial["turns"]):
|
370 |
+
if diff_flag:
|
371 |
+
break
|
372 |
+
utt, spk = turn["utterance"], turn["speaker"]
|
373 |
+
if spk != "USER":
|
374 |
+
prev_system_turn = turn
|
375 |
+
continue
|
376 |
+
for frame in turn["frames"]:
|
377 |
+
state = frame["state"]
|
378 |
+
active_intent = state["active_intent"]
|
379 |
+
slot_values = state["slot_values"]
|
380 |
+
if active_intent in intent2slot_values:
|
381 |
+
state_diff = compare_slot_values_in_state(
|
382 |
+
intent2slot_values[active_intent], slot_values
|
383 |
+
)
|
384 |
+
if state_diff:
|
385 |
+
usr_acts = get_turn_act(turn)
|
386 |
+
if "NOTIFY_FAILURE" in get_turn_act(prev_system_turn):
|
387 |
+
if display:
|
388 |
+
print("FAILURE", dial_id, utt)
|
389 |
+
n_diff["NOTIFY_FAILURE"] += 1
|
390 |
+
elif "NEGATE" in usr_acts:
|
391 |
+
if display:
|
392 |
+
print("NEGATE", dial_id, utt)
|
393 |
+
n_diff["NEGATE"] += 1
|
394 |
+
elif "REQUEST_ALTS" in usr_acts:
|
395 |
+
if display:
|
396 |
+
print("REQUEST_ALTS", dial_id, utt)
|
397 |
+
n_diff["REQUEST_ALTS"] += 1
|
398 |
+
else:
|
399 |
+
if display:
|
400 |
+
print("RANDOM", dial_id, utt)
|
401 |
+
n_diff["RANDOM"] += 1
|
402 |
+
if display:
|
403 |
+
input("press...")
|
404 |
+
# n_diff += 1
|
405 |
+
diff_flag = True
|
406 |
+
intent2slot_values[
|
407 |
+
active_intent
|
408 |
+
] = slot_values # overlap with new state, no matter values changed or not
|
409 |
+
n = (
|
410 |
+
n_diff["NOTIFY_FAILURE"]
|
411 |
+
+ n_diff["NEGATE"]
|
412 |
+
+ n_diff["REQUEST_ALTS"]
|
413 |
+
+ n_diff["RANDOM"]
|
414 |
+
)
|
415 |
+
print(
|
416 |
+
"{} => total dials: {}, change goal dials: {} (total: {})".format(
|
417 |
+
split, len(data[split]), dict2str(n_diff), n
|
418 |
+
)
|
419 |
+
)
|
420 |
+
|
421 |
+
|
422 |
+
def stats_after_system(data):
|
423 |
+
"""
|
424 |
+
check the possible user behavior right after system offers/notify_failure
|
425 |
+
"""
|
426 |
+
n = 0
|
427 |
+
stats = {
|
428 |
+
"SELECT": 0,
|
429 |
+
"REQUEST_ALTS": 0,
|
430 |
+
"REQUEST": 0,
|
431 |
+
"AFFIRM": 0,
|
432 |
+
"unknown": 0,
|
433 |
+
} # if system offers
|
434 |
+
# stats = {"INFORM": 0, "AFFIRM": 0, "NEGATE": 0, "unknown": 0} # if system notify_failure
|
435 |
+
for split in DATA_SPLIT:
|
436 |
+
for dial_id in sorted(data[split].keys()):
|
437 |
+
dial = data[split][dial_id]
|
438 |
+
for turn_id, turn in enumerate(dial["turns"]):
|
439 |
+
if turn_id == 0:
|
440 |
+
prev_turn = turn
|
441 |
+
continue
|
442 |
+
if turn["speaker"] == "SYSTEM":
|
443 |
+
sys_turn = turn
|
444 |
+
continue
|
445 |
+
|
446 |
+
if "OFFER" in get_turn_act(sys_turn):
|
447 |
+
# if "OFFER" in get_turn_act(sys_turn) and "NOTIFY_FAILURE" in get_turn_act(sys_turn):
|
448 |
+
# if "NOTIFY_FAILURE" in get_turn_act(sys_turn):
|
449 |
+
n += 1
|
450 |
+
acts = get_turn_act(turn)
|
451 |
+
# OFFER
|
452 |
+
if "SELECT" in acts:
|
453 |
+
stats["SELECT"] += 1
|
454 |
+
elif "REQUEST_ALTS" in acts:
|
455 |
+
stats["REQUEST_ALTS"] += 1
|
456 |
+
elif "REQUEST" in acts:
|
457 |
+
stats["REQUEST"] += 1
|
458 |
+
elif (
|
459 |
+
"AFFIRM" in acts
|
460 |
+
): # cases fall into here are SYS_ACT: ["OFFER", "NOTIFY_FAILURE"], and USR_ACT: ["AFFIRM"],
|
461 |
+
# e.g., accept new proposal
|
462 |
+
show_turn(0, prev_turn)
|
463 |
+
show_turn(0, sys_turn)
|
464 |
+
show_turn(0, turn)
|
465 |
+
input("press...")
|
466 |
+
stats["AFFIRM"] += 1
|
467 |
+
else:
|
468 |
+
stats["unknown"] += 1
|
469 |
+
|
470 |
+
# NOTIFY_FAILURE
|
471 |
+
# if "INFORM" in acts:
|
472 |
+
# stats["INFORM"] += 1
|
473 |
+
# elif "AFFIRM" in acts:
|
474 |
+
# stats["AFFIRM"] += 1
|
475 |
+
# elif "NEGATE" in acts:
|
476 |
+
# stats["NEGATE"] += 1
|
477 |
+
# else:
|
478 |
+
# stats["unknown"] += 1
|
479 |
+
|
480 |
+
prev_turn = turn
|
481 |
+
for k, v in stats.items():
|
482 |
+
print("{} -> {}".format(k, v))
|
483 |
+
print("Total offer turns: {}".format(n))
|
src/crazyneuraluser/user_model_code/argument.py
ADDED
@@ -0,0 +1,153 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
|
4 |
+
def str2bool(v):
|
5 |
+
if v.lower() in ("yes", "true", "t", "y", "1"):
|
6 |
+
return True
|
7 |
+
elif v.lower() in ("no", "false", "f", "n", "0"):
|
8 |
+
return False
|
9 |
+
else:
|
10 |
+
raise argparse.ArgumentTypeError("Boolean value expected.")
|
11 |
+
|
12 |
+
|
13 |
+
def verify_args(args):
|
14 |
+
# datasets
|
15 |
+
assert isinstance(args.data_list, list) and len(args.data_list) > 0
|
16 |
+
for data_name in args.data_list:
|
17 |
+
assert data_name in ["sgd", "multiwoz"]
|
18 |
+
|
19 |
+
# mode
|
20 |
+
assert args.mode in ["training", "finetune", "testing", "interact"]
|
21 |
+
if args.mode == "finetune":
|
22 |
+
assert args.pre_checkpoint != ""
|
23 |
+
|
24 |
+
|
25 |
+
def get_args():
|
26 |
+
parser = argparse.ArgumentParser(description="")
|
27 |
+
# logging
|
28 |
+
parser.add_argument("--wandb_train_run_name", type=str, default="Default name")
|
29 |
+
# data
|
30 |
+
parser.add_argument(
|
31 |
+
"--data_dir",
|
32 |
+
type=str,
|
33 |
+
default="proc_data",
|
34 |
+
help="Directory of processed datasets",
|
35 |
+
)
|
36 |
+
parser.add_argument(
|
37 |
+
"--data_list",
|
38 |
+
type=str,
|
39 |
+
nargs="+",
|
40 |
+
default="",
|
41 |
+
help="Datasets involved, split by space, e.g., `sgd multiwoz`",
|
42 |
+
)
|
43 |
+
|
44 |
+
# design control
|
45 |
+
parser.add_argument(
|
46 |
+
"--use_ra_flag",
|
47 |
+
type=str2bool,
|
48 |
+
default=True,
|
49 |
+
help="Whether to use `request_alternatives` flag",
|
50 |
+
)
|
51 |
+
|
52 |
+
# training
|
53 |
+
parser.add_argument("--mode", type=str, required=True, help="")
|
54 |
+
parser.add_argument("--seed", type=int, default=1122)
|
55 |
+
parser.add_argument(
|
56 |
+
"--model_name", type=str, required=True, help="Unique name, e.g., job id"
|
57 |
+
)
|
58 |
+
parser.add_argument("--model_name_or_path", type=str, default="gpt2")
|
59 |
+
parser.add_argument(
|
60 |
+
"--train_batch_size", type=int, default=4, help="Batch size of training per gpu"
|
61 |
+
)
|
62 |
+
parser.add_argument(
|
63 |
+
"--eval_batch_size",
|
64 |
+
type=int,
|
65 |
+
default=1,
|
66 |
+
help="Batch size of evaluation per gpu",
|
67 |
+
) # TODO: make decoding parallel
|
68 |
+
parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
|
69 |
+
parser.add_argument("--learning_rate", type=float, default=6.25e-5) # tune
|
70 |
+
parser.add_argument("--adam_epsilon", type=float, default=1e-12)
|
71 |
+
parser.add_argument("--max_grad_norm", type=float, default=1.0)
|
72 |
+
parser.add_argument("--max_epoch", type=int, default=20)
|
73 |
+
parser.add_argument(
|
74 |
+
"--fp16", type=str2bool, default=False, help="Whether to use float16"
|
75 |
+
)
|
76 |
+
parser.add_argument(
|
77 |
+
"--use_scheduler",
|
78 |
+
type=str2bool,
|
79 |
+
default=True,
|
80 |
+
help="Whether to use lr scheduler",
|
81 |
+
)
|
82 |
+
parser.add_argument("--warmup_steps", type=int, default=0)
|
83 |
+
parser.add_argument(
|
84 |
+
"--checkpoint",
|
85 |
+
type=str,
|
86 |
+
default="",
|
87 |
+
required=True,
|
88 |
+
help="Path of your trained model",
|
89 |
+
)
|
90 |
+
parser.add_argument(
|
91 |
+
"--pre_checkpoint",
|
92 |
+
type=str,
|
93 |
+
default="",
|
94 |
+
help="Path of the pretrained model used for finetuning",
|
95 |
+
)
|
96 |
+
parser.add_argument(
|
97 |
+
"--train_size",
|
98 |
+
type=int,
|
99 |
+
default=-1,
|
100 |
+
help="How many examples used for training. -1 means all data",
|
101 |
+
)
|
102 |
+
parser.add_argument(
|
103 |
+
"--eval_size",
|
104 |
+
type=int,
|
105 |
+
default=-1,
|
106 |
+
help="How many examples used for evaluation. -1 means all data",
|
107 |
+
)
|
108 |
+
parser.add_argument(
|
109 |
+
"--eval_interval",
|
110 |
+
type=int,
|
111 |
+
default=1000,
|
112 |
+
help="During training, how frequent to evaluate the model in terms of training examples",
|
113 |
+
)
|
114 |
+
parser.add_argument(
|
115 |
+
"--no_improve_max",
|
116 |
+
type=int,
|
117 |
+
default=100,
|
118 |
+
help="The max tolerance for model not improving",
|
119 |
+
)
|
120 |
+
parser.add_argument("--eps", type=float, default=1e-12)
|
121 |
+
parser.add_argument(
|
122 |
+
"--disable_display", type=str2bool, default=False, help="display progress bar"
|
123 |
+
)
|
124 |
+
|
125 |
+
# decoding
|
126 |
+
# parser.add_argument('--step', type=int, default=-1) # load model trained at which specific step
|
127 |
+
parser.add_argument(
|
128 |
+
"--dec_max_len", type=int, default=2000
|
129 |
+
) # we use early stop to stop generation when hits <EOS>
|
130 |
+
parser.add_argument("--num_beams", type=int, default=1)
|
131 |
+
parser.add_argument("--temperature", type=float, default=1.0)
|
132 |
+
# parser.add_argument('--top_k', type=int, default=0)
|
133 |
+
# parser.add_argument('--top_p', type=int, default=0)
|
134 |
+
parser.add_argument("--decode_file", type=str, default="")
|
135 |
+
parser.add_argument(
|
136 |
+
"--eye_browse_output",
|
137 |
+
type=str2bool,
|
138 |
+
default=False,
|
139 |
+
help="Whether to eye browse decoded results",
|
140 |
+
)
|
141 |
+
|
142 |
+
# ddp
|
143 |
+
parser.add_argument(
|
144 |
+
"--local_rank",
|
145 |
+
type=int,
|
146 |
+
default=-1,
|
147 |
+
help="Local rank for distributed training (-1: not distributed)",
|
148 |
+
)
|
149 |
+
|
150 |
+
args = parser.parse_args()
|
151 |
+
verify_args(args)
|
152 |
+
print(args)
|
153 |
+
return args
|
src/crazyneuraluser/user_model_code/dataset.py
ADDED
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
from crazyneuraluser.user_model_code.utils_sgd import (
|
8 |
+
add_str,
|
9 |
+
get_special_tokens,
|
10 |
+
wrap_element,
|
11 |
+
)
|
12 |
+
|
13 |
+
|
14 |
+
class SGD_Dataset(torch.utils.data.Dataset):
|
15 |
+
def __init__(self, args, tokenizer, data_split, generation, data_size):
|
16 |
+
assert data_split in ["train", "dev", "test", "demo"]
|
17 |
+
self.args = args
|
18 |
+
self.data_size = data_size
|
19 |
+
self.tokenizer = tokenizer
|
20 |
+
self.data_split = data_split
|
21 |
+
self.generation = generation
|
22 |
+
self.n_trimmed = 0
|
23 |
+
|
24 |
+
self.SPECIAL_TOKENS = get_special_tokens()
|
25 |
+
self._get_special_token_ids()
|
26 |
+
|
27 |
+
# create examples
|
28 |
+
self.examples = []
|
29 |
+
for data_name in args.data_list:
|
30 |
+
examples = self._create_examples(data_name, data_split)
|
31 |
+
self.examples += examples
|
32 |
+
print("Total ({}) -> {} examples".format(data_split, len(self.examples)))
|
33 |
+
|
34 |
+
def _get_special_token_ids(self):
|
35 |
+
self.bos_id = self.tokenizer.convert_tokens_to_ids(
|
36 |
+
self.SPECIAL_TOKENS["bos_token"]
|
37 |
+
)
|
38 |
+
self.eos_id = self.tokenizer.convert_tokens_to_ids(
|
39 |
+
self.SPECIAL_TOKENS["eos_token"]
|
40 |
+
)
|
41 |
+
self.pad_id = self.tokenizer.convert_tokens_to_ids(
|
42 |
+
self.SPECIAL_TOKENS["pad_token"]
|
43 |
+
)
|
44 |
+
self.sep_id = self.tokenizer.convert_tokens_to_ids(
|
45 |
+
self.SPECIAL_TOKENS["sep_token"]
|
46 |
+
)
|
47 |
+
# print('SPECIAL TOKEN MAPPING:')
|
48 |
+
# print('bos:{} | eos:{} | pad:{} | sep:{}'.format(self.bos_id, self.eos_id, self.pad_id, self.sep_id))
|
49 |
+
|
50 |
+
self.add_special_token_ids = {}
|
51 |
+
for token in self.SPECIAL_TOKENS["additional_special_tokens"]:
|
52 |
+
self.add_special_token_ids[token] = self.tokenizer.convert_tokens_to_ids(
|
53 |
+
token
|
54 |
+
)
|
55 |
+
|
56 |
+
self.true_token, self.false_token = "_True_", "_False_"
|
57 |
+
assert self.true_token in self.SPECIAL_TOKENS["additional_special_tokens"]
|
58 |
+
assert self.false_token in self.SPECIAL_TOKENS["additional_special_tokens"]
|
59 |
+
"""
|
60 |
+
if using BPE (default method, simply call tokenizer(natural sentence)), no need unk_token
|
61 |
+
if using convert_tokens_to_ids, check which is correct way to handle oov:
|
62 |
+
a) simply use <endoftext> as unk_token (default setup) or
|
63 |
+
b) add unk_token into special tokens
|
64 |
+
"""
|
65 |
+
|
66 |
+
def _create_examples(self, data_name, data_split):
|
67 |
+
data_file = os.path.join(
|
68 |
+
self.args.data_dir, data_name, "{}.json".format(data_split)
|
69 |
+
)
|
70 |
+
with open(data_file) as f:
|
71 |
+
data = json.load(f)
|
72 |
+
|
73 |
+
examples = []
|
74 |
+
for dial_id in tqdm(sorted(data.keys())):
|
75 |
+
if self.data_size != -1 and len(examples) >= self.data_size:
|
76 |
+
break
|
77 |
+
dial_meta = data[dial_id]
|
78 |
+
context = ""
|
79 |
+
for i in range(100):
|
80 |
+
example_id = "{}-{}".format(dial_id, i)
|
81 |
+
self.example_id = example_id
|
82 |
+
if example_id not in dial_meta:
|
83 |
+
break
|
84 |
+
|
85 |
+
# testing #
|
86 |
+
# # SGD
|
87 |
+
# if data_split == "test" and dial_id not in ["10_00056", "10_00075"]: # seen, movie domain
|
88 |
+
# if data_split == "test" and dial_id not in ["16_00040"]: # seen
|
89 |
+
# if data_split == "test" and dial_id not in ["8_00066", "16_00095", "8_00065"]: # unseen
|
90 |
+
# if data_split == "test" and dial_id not in ["9_00121", "9_00122"]:
|
91 |
+
# # req_alts cases w/i, w/o inform
|
92 |
+
# continue
|
93 |
+
# # mwoz
|
94 |
+
# if data_split == "test" and dial_id not in ["MUL0071.json"]:
|
95 |
+
# # test predictions in no offer & no book
|
96 |
+
# continue
|
97 |
+
|
98 |
+
# turn info
|
99 |
+
goal = dial_meta[example_id]["goal"]
|
100 |
+
# service = dial_meta[example_id]["service"]
|
101 |
+
# intent = dial_meta[example_id]["intent"]
|
102 |
+
|
103 |
+
# utterances
|
104 |
+
usr_utt = dial_meta[example_id]["utterances"]["usr"]
|
105 |
+
sys_utt = dial_meta[example_id]["utterances"]["sys"]
|
106 |
+
|
107 |
+
# actions
|
108 |
+
usr_act = dial_meta[example_id]["actions"]["usr"]
|
109 |
+
sys_act = dial_meta[example_id]["actions"]["sys"]
|
110 |
+
|
111 |
+
# binary flags
|
112 |
+
snt = dial_meta[example_id]["start_new_task"]
|
113 |
+
gc = dial_meta[example_id]["goal_change"]
|
114 |
+
ra = dial_meta[example_id]["req_alts"]
|
115 |
+
|
116 |
+
# get input ids
|
117 |
+
(
|
118 |
+
input_seq,
|
119 |
+
input_ids,
|
120 |
+
label_ids,
|
121 |
+
valid_example,
|
122 |
+
) = self._prepare_input_ids(
|
123 |
+
goal, context, usr_utt, usr_act, sys_utt, sys_act, snt, gc, ra
|
124 |
+
)
|
125 |
+
|
126 |
+
if valid_example:
|
127 |
+
assert len(input_ids) < 1024
|
128 |
+
dial_meta[example_id]["context"] = context
|
129 |
+
examples.append(
|
130 |
+
{
|
131 |
+
"input_ids": input_ids, # list of ids
|
132 |
+
"label_ids": label_ids, # list of ids
|
133 |
+
"metadata": dial_meta[example_id],
|
134 |
+
"example_id": self.example_id,
|
135 |
+
"data_name": data_name,
|
136 |
+
}
|
137 |
+
)
|
138 |
+
|
139 |
+
# collect context
|
140 |
+
sys_utt_wrap = wrap_element("SYS", sys_utt)
|
141 |
+
usr_utt_wrap = wrap_element("USR", usr_utt)
|
142 |
+
context = add_str(context, sys_utt_wrap)
|
143 |
+
context = add_str(context, usr_utt_wrap)
|
144 |
+
|
145 |
+
print(
|
146 |
+
"Data Stat: {} ({}) -> {} examples ({} examples are trimmed)".format(
|
147 |
+
data_name, self.data_split, len(examples), self.n_trimmed
|
148 |
+
)
|
149 |
+
)
|
150 |
+
return examples
|
151 |
+
|
152 |
+
def _prepare_input_ids(
|
153 |
+
self, goal, context, usr_utt, usr_act, sys_utt, sys_act, snt, gc, ra
|
154 |
+
):
|
155 |
+
"""
|
156 |
+
prepare input sequence ids to GPT2
|
157 |
+
template: <CTX> <SYS_UTT> <SYS_ACT> <SNT> <RA> <GC> <GOAL> <USR_ACT> <USR_UTT>
|
158 |
+
"""
|
159 |
+
goal_wrap = wrap_element("GOAL", goal)
|
160 |
+
context_wrap = wrap_element("CTX", context)
|
161 |
+
usr_utt_wrap = wrap_element("USR_UTT", usr_utt)
|
162 |
+
usr_act_wrap = wrap_element("USR_ACT", usr_act)
|
163 |
+
sys_utt_wrap = wrap_element("SYS_UTT", sys_utt)
|
164 |
+
sys_act_wrap = wrap_element("SYS_ACT", sys_act)
|
165 |
+
|
166 |
+
snt = self.true_token if snt else self.false_token # `Start New Task` flag
|
167 |
+
snt_wrap = wrap_element("SNT", snt)
|
168 |
+
gc = self.true_token if gc else self.false_token # `Goal Change` flag
|
169 |
+
gc_wrap = wrap_element("GC", gc)
|
170 |
+
ra = self.true_token if ra else self.false_token # `Request Alternatives` flag
|
171 |
+
ra_wrap = wrap_element("RA", ra)
|
172 |
+
if self.args.use_ra_flag:
|
173 |
+
flags_wrap = snt_wrap + " " + ra_wrap + " " + gc_wrap
|
174 |
+
else:
|
175 |
+
flags_wrap = snt_wrap + " " + gc_wrap
|
176 |
+
|
177 |
+
if not self.generation: # supervised
|
178 |
+
input_seq = (
|
179 |
+
context_wrap
|
180 |
+
+ " "
|
181 |
+
+ sys_utt_wrap
|
182 |
+
+ " "
|
183 |
+
+ sys_act_wrap
|
184 |
+
+ " "
|
185 |
+
+ flags_wrap
|
186 |
+
+ " "
|
187 |
+
+ goal_wrap
|
188 |
+
+ " "
|
189 |
+
+ usr_act_wrap
|
190 |
+
+ " "
|
191 |
+
+ usr_utt_wrap
|
192 |
+
+ " "
|
193 |
+
+ self.SPECIAL_TOKENS["eos_token"]
|
194 |
+
)
|
195 |
+
input_ids = self.tokenizer(input_seq)["input_ids"] # convert to ids
|
196 |
+
label_ids = self._get_labels(input_ids)
|
197 |
+
else: # generation
|
198 |
+
input_seq = (
|
199 |
+
context_wrap
|
200 |
+
+ " "
|
201 |
+
+ sys_utt_wrap
|
202 |
+
+ " "
|
203 |
+
+ sys_act_wrap
|
204 |
+
+ " "
|
205 |
+
+ flags_wrap
|
206 |
+
+ " "
|
207 |
+
+ goal_wrap
|
208 |
+
+ " "
|
209 |
+
+ "<USR_ACT/>"
|
210 |
+
) # + " " + usr_act_wrap + " " + usr_utt_wrap
|
211 |
+
input_ids = self.tokenizer(input_seq)["input_ids"] # convert to ids
|
212 |
+
label_ids = None
|
213 |
+
|
214 |
+
valid_example = True
|
215 |
+
if len(input_ids) > 1023:
|
216 |
+
print("{}: {}".format(self.n_trimmed, self.example_id))
|
217 |
+
self.n_trimmed += 1
|
218 |
+
valid_example = False
|
219 |
+
|
220 |
+
return input_seq, input_ids, label_ids, valid_example
|
221 |
+
|
222 |
+
def _get_labels(self, input_ids):
|
223 |
+
for special_token in ["<SYS_ACT/>", "</GC>", "<USR_ACT/>"]:
|
224 |
+
special_token_id = self.add_special_token_ids[special_token]
|
225 |
+
assert input_ids.count(special_token_id) == 1
|
226 |
+
|
227 |
+
label_ids = [-100] * len(input_ids)
|
228 |
+
|
229 |
+
# sys act signal interval
|
230 |
+
start_position = input_ids.index(self.add_special_token_ids["<SYS_ACT/>"])
|
231 |
+
end_position = input_ids.index(self.add_special_token_ids["</GC>"]) + 1
|
232 |
+
label_ids[start_position:end_position] = input_ids[start_position:end_position]
|
233 |
+
|
234 |
+
# usr act and utt singal interval
|
235 |
+
start_position = input_ids.index(self.add_special_token_ids["<USR_ACT/>"])
|
236 |
+
assert self.eos_id == input_ids[-1]
|
237 |
+
label_ids[start_position:] = input_ids[start_position:]
|
238 |
+
assert len(label_ids) == len(input_ids)
|
239 |
+
return label_ids
|
240 |
+
|
241 |
+
def _pad(self, sentences, pad_id):
|
242 |
+
max_len = max((map(len, sentences)))
|
243 |
+
attention_mask = []
|
244 |
+
sentences_pad = []
|
245 |
+
for sent in sentences:
|
246 |
+
pad_len = max_len - len(sent)
|
247 |
+
sentences_pad.append(sent + [pad_id] * pad_len)
|
248 |
+
attention_mask.append([1] * len(sent) + [0] * pad_len)
|
249 |
+
return sentences_pad, attention_mask
|
250 |
+
|
251 |
+
def __len__(self): # required
|
252 |
+
return len(self.examples)
|
253 |
+
|
254 |
+
def __getitem__(self, index): # required
|
255 |
+
"""
|
256 |
+
index will be ramdomly sampled by the fed sampler, we dont need to worry about index
|
257 |
+
"""
|
258 |
+
return self.examples[index]
|
259 |
+
|
260 |
+
def collate_fn(self, batch): # optional but useful
|
261 |
+
"""
|
262 |
+
when collate_fn is given to the torch dataloader, we can do further actions to the batch, e.g.,
|
263 |
+
tensor can be formed here a batch is formed as a list where each element is a defined data returned
|
264 |
+
by __getitem__, andy
|
265 |
+
"""
|
266 |
+
input_ids = [example["input_ids"] for example in batch]
|
267 |
+
input_ids, attention_mask = self._pad(input_ids, self.pad_id)
|
268 |
+
input_ids, attention_mask = torch.tensor(input_ids).long().to(
|
269 |
+
self.args.device
|
270 |
+
), torch.tensor(attention_mask).long().to(self.args.device)
|
271 |
+
|
272 |
+
if not self.generation:
|
273 |
+
label_ids = [example["label_ids"] for example in batch]
|
274 |
+
label_ids, _ = self._pad(label_ids, -100)
|
275 |
+
label_ids = torch.tensor(label_ids).long().to(self.args.device)
|
276 |
+
else:
|
277 |
+
label_ids = None
|
278 |
+
token_type_ids = None
|
279 |
+
|
280 |
+
# store info for scoring
|
281 |
+
metadata = [ex["metadata"] for ex in batch]
|
282 |
+
example_id = [ex["example_id"] for ex in batch]
|
283 |
+
data_name = [ex["data_name"] for ex in batch]
|
284 |
+
|
285 |
+
return {
|
286 |
+
"input_ids": input_ids,
|
287 |
+
"attention_mask": attention_mask,
|
288 |
+
"token_type_ids": token_type_ids,
|
289 |
+
"label_ids": label_ids,
|
290 |
+
"metadata": metadata,
|
291 |
+
"example_id": example_id,
|
292 |
+
"data_name": data_name,
|
293 |
+
}
|
294 |
+
|
295 |
+
|
296 |
+
if __name__ == "__main__":
|
297 |
+
pass
|
src/crazyneuraluser/user_model_code/utils_generation.py
ADDED
@@ -0,0 +1,210 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
from crazyneuraluser.user_model_code.utils_sgd import add_str, bcolors, wrap_element
|
4 |
+
|
5 |
+
|
6 |
+
def find_segment(gen, tag):
|
7 |
+
assert isinstance(gen, str)
|
8 |
+
gen = gen.split()
|
9 |
+
try:
|
10 |
+
start = gen.index("<{}/>".format(tag)) + 1
|
11 |
+
end = gen.index("</{}>".format(tag))
|
12 |
+
segment = " ".join(gen[start:end])
|
13 |
+
except Exception:
|
14 |
+
print("Missing {} tag in generated sequence".format(tag))
|
15 |
+
segment = None
|
16 |
+
return segment
|
17 |
+
|
18 |
+
|
19 |
+
def segment_gen(gen, dial_id):
|
20 |
+
def _color(_segment):
|
21 |
+
if tag == "CTX":
|
22 |
+
_segment = _segment.replace(" </USR>", f"{bcolors.ENDC}")
|
23 |
+
_segment = _segment.replace(" </SYS>", f"{bcolors.ENDC}")
|
24 |
+
_segment = _segment.replace("<USR/> ", f"USR: {bcolors.OKCYAN}")
|
25 |
+
_segment = _segment.replace("<SYS/> ", f"SYS: {bcolors.OKBLUE}")
|
26 |
+
if tag == "SYS_UTT":
|
27 |
+
_segment = f"{bcolors.OKBLUE}" + _segment + f"{bcolors.ENDC}"
|
28 |
+
if tag == "USR_UTT":
|
29 |
+
_segment = f"{bcolors.OKCYAN}" + _segment + f"{bcolors.ENDC}"
|
30 |
+
if tag in ["SYS_ACT", "USR_ACT", "GOAL"]:
|
31 |
+
_segment = _segment.replace("<ACT/> ", f"{bcolors.RED}")
|
32 |
+
_segment = _segment.replace(" </ACT>", f"{bcolors.ENDC}")
|
33 |
+
_segment = _segment.replace("<SLOT/> ", f"{bcolors.YELLOW}")
|
34 |
+
_segment = _segment.replace(" </SLOT>", f"{bcolors.ENDC}")
|
35 |
+
_segment = _segment.replace("<VALUE/> ", f"{bcolors.GREEN}")
|
36 |
+
_segment = _segment.replace(" </VALUE>", f"{bcolors.ENDC}")
|
37 |
+
if tag == "GOAL":
|
38 |
+
_segment = _segment.replace(
|
39 |
+
"<SCENARIO/>", f"<SCENARIO/>{bcolors.UNDERLINE}"
|
40 |
+
)
|
41 |
+
_segment = _segment.replace("</SCENARIO>", f"{bcolors.ENDC}</SCENARIO>")
|
42 |
+
_segment = _segment.replace("<TASK/>", f"<TASK/>{bcolors.UNDERLINE}")
|
43 |
+
_segment = _segment.replace("</TASK>", f"{bcolors.ENDC}</TASK>")
|
44 |
+
# if tag in ["SNT", "GC"]:
|
45 |
+
# segment = segment.replace("<{}/> ".format(tag), "<{}/> *".format(tag))
|
46 |
+
# segment = segment.replace(" </{}>".format(tag), "* <{}/>".format(tag))
|
47 |
+
return _segment
|
48 |
+
|
49 |
+
assert isinstance(gen, str)
|
50 |
+
print("*** Dial_id: {} ***".format(dial_id))
|
51 |
+
for tag in [
|
52 |
+
"CTX",
|
53 |
+
"SYS_UTT",
|
54 |
+
"SYS_ACT",
|
55 |
+
"GOAL",
|
56 |
+
"SNT",
|
57 |
+
"RA",
|
58 |
+
"GC",
|
59 |
+
"USR_ACT",
|
60 |
+
"USR_UTT",
|
61 |
+
]:
|
62 |
+
segment = find_segment(gen, tag)
|
63 |
+
if segment is not None:
|
64 |
+
print('{} -> "{}"'.format(tag, _color(segment)))
|
65 |
+
else:
|
66 |
+
print("Fail to find the segment...")
|
67 |
+
print("GEN:", gen)
|
68 |
+
print("---" * 30)
|
69 |
+
input("press any key to continue...")
|
70 |
+
|
71 |
+
|
72 |
+
def save_gen(gen, dial_id, container):
|
73 |
+
output = {"raw_generation": gen}
|
74 |
+
parsed_generation = {}
|
75 |
+
|
76 |
+
assert isinstance(gen, str)
|
77 |
+
for tag in [
|
78 |
+
"CTX",
|
79 |
+
"SYS_UTT",
|
80 |
+
"SYS_ACT",
|
81 |
+
"GOAL",
|
82 |
+
"SNT",
|
83 |
+
"RA",
|
84 |
+
"GC",
|
85 |
+
"USR_ACT",
|
86 |
+
"USR_UTT",
|
87 |
+
]:
|
88 |
+
segment = find_segment(gen, tag)
|
89 |
+
if segment is not None:
|
90 |
+
parsed_generation[tag] = segment
|
91 |
+
else:
|
92 |
+
print("Fail to parse generation on example {}".format(dial_id))
|
93 |
+
parsed_generation[tag] = None
|
94 |
+
|
95 |
+
output["parsed_generation"] = parsed_generation
|
96 |
+
container[dial_id] = output
|
97 |
+
|
98 |
+
|
99 |
+
# def decode(args, batch, model, tokenizer):
|
100 |
+
# input_ids = batch['input_ids']
|
101 |
+
# batch_size, ctx_len = input_ids.size()
|
102 |
+
# assert batch_size == 1
|
103 |
+
# bos_id, eos_id, pad_id, sep_id = tokenizer.convert_tokens_to_ids(['<BOS>', '<EOS>', '<PAD>', '<SEP>'])
|
104 |
+
#
|
105 |
+
# # output size: (B, T)
|
106 |
+
# output = model.generate(input_ids, max_length=(ctx_len+args.dec_max_len), do_sample=False,
|
107 |
+
# temperature=args.temperature, use_cache=True, num_beams=args.num_beams, bos_token_id=bos_id,
|
108 |
+
# eos_token_id=eos_id, pad_token_id=pad_id, early_stopping=True)
|
109 |
+
#
|
110 |
+
# gen = tokenizer.decode(output[0]) # include context fed into model
|
111 |
+
# segment_gen(gen, batch["example_id"][0])
|
112 |
+
# return [gen]
|
113 |
+
|
114 |
+
|
115 |
+
def prepare_input_ids(
|
116 |
+
args: object, tokenizer: object, data: object, start_token: object
|
117 |
+
) -> object:
|
118 |
+
assert start_token in ["<SYS_ACT/>", "<USR_ACT/>"]
|
119 |
+
input_seq = ""
|
120 |
+
for key in [
|
121 |
+
"CTX",
|
122 |
+
"SYS_UTT",
|
123 |
+
"SYS_ACT",
|
124 |
+
"SNT",
|
125 |
+
"RA",
|
126 |
+
"GC",
|
127 |
+
"GOAL",
|
128 |
+
]: # fixed order, consistent between training and inference
|
129 |
+
if key not in data:
|
130 |
+
continue
|
131 |
+
wrap = wrap_element(key, data[key])
|
132 |
+
input_seq = add_str(input_seq, wrap)
|
133 |
+
|
134 |
+
input_seq = add_str(input_seq, start_token)
|
135 |
+
|
136 |
+
input_ids = tokenizer(input_seq)["input_ids"] # convert to ids
|
137 |
+
input_ids = torch.tensor([input_ids]).long().to(args.device)
|
138 |
+
return input_ids
|
139 |
+
|
140 |
+
|
141 |
+
def decode_e2e(
|
142 |
+
args, batch, model, tokenizer, user_goal=None, prev_usr_act=None, collector=None
|
143 |
+
):
|
144 |
+
"""decode with predicted sys act, goal can be random or from the corpus"""
|
145 |
+
assert len(batch["metadata"]) == 1
|
146 |
+
context = batch["metadata"][0]["context"]
|
147 |
+
sys_utt = batch["metadata"][0]["utterances"]["sys"]
|
148 |
+
bos_id, _, pad_id, sep_id = tokenizer.convert_tokens_to_ids(
|
149 |
+
["<BOS>", "<EOS>", "<PAD>", "<SEP>"]
|
150 |
+
)
|
151 |
+
|
152 |
+
# first forward pass
|
153 |
+
data = {"CTX": context, "SYS_UTT": sys_utt}
|
154 |
+
start_token, end_token = "<SYS_ACT/>", "</GC>"
|
155 |
+
input_ids = prepare_input_ids(args, tokenizer, data, start_token)
|
156 |
+
eos_id = tokenizer.convert_tokens_to_ids(end_token)
|
157 |
+
output = model.generate(
|
158 |
+
input_ids,
|
159 |
+
max_length=args.dec_max_len,
|
160 |
+
do_sample=False,
|
161 |
+
temperature=args.temperature,
|
162 |
+
use_cache=True,
|
163 |
+
num_beams=args.num_beams,
|
164 |
+
bos_token_id=bos_id,
|
165 |
+
eos_token_id=eos_id,
|
166 |
+
pad_token_id=pad_id,
|
167 |
+
early_stopping=True,
|
168 |
+
)
|
169 |
+
gen = tokenizer.decode(output[0]) # include context fed into model
|
170 |
+
|
171 |
+
# parse the first pass prediction
|
172 |
+
for key in ["SYS_ACT", "SNT", "GC", "RA"]:
|
173 |
+
value = find_segment(gen, key)
|
174 |
+
data[key] = value
|
175 |
+
# print("***** First run generation *****")
|
176 |
+
# print("SYS_ACT -> {}".format(data["SYS_ACT"]))
|
177 |
+
# print("FLAGS -> SNT: {}, GC: {}, RA: {} *****".format(data["SNT"], data["GC"], data["RA"]))
|
178 |
+
# print("********************************")
|
179 |
+
|
180 |
+
# prepare goal
|
181 |
+
if user_goal is None: # use ground truth goal from corpus
|
182 |
+
data["GOAL"] = batch["metadata"][0]["goal"]
|
183 |
+
else:
|
184 |
+
goal = user_goal.prepare_turn_goal(
|
185 |
+
prev_usr_act, data["SYS_ACT"], data["SNT"], data["GC"], data["RA"]
|
186 |
+
)
|
187 |
+
data["GOAL"] = goal
|
188 |
+
|
189 |
+
# second forward pass
|
190 |
+
start_token, end_token = "<USR_ACT/>", "<EOS>"
|
191 |
+
input_ids = prepare_input_ids(args, tokenizer, data, start_token)
|
192 |
+
eos_id = tokenizer.convert_tokens_to_ids(end_token)
|
193 |
+
output = model.generate(
|
194 |
+
input_ids,
|
195 |
+
max_length=args.dec_max_len,
|
196 |
+
do_sample=False,
|
197 |
+
temperature=args.temperature,
|
198 |
+
use_cache=True,
|
199 |
+
num_beams=args.num_beams,
|
200 |
+
bos_token_id=bos_id,
|
201 |
+
eos_token_id=eos_id,
|
202 |
+
pad_token_id=pad_id,
|
203 |
+
early_stopping=True,
|
204 |
+
)
|
205 |
+
gen = tokenizer.decode(output[0]) # include context fed into model
|
206 |
+
if args.eye_browse_output:
|
207 |
+
segment_gen(gen, batch["example_id"][0])
|
208 |
+
else:
|
209 |
+
save_gen(gen, batch["example_id"][0], collector)
|
210 |
+
return [gen]
|
src/crazyneuraluser/user_model_code/utils_multiwoz.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import re
|
3 |
+
|
4 |
+
|
5 |
+
def get_original_act_set():
|
6 |
+
# NOTE:
|
7 |
+
# act `Book` and `NoBook` belong to `Booking` domain by ontology,
|
8 |
+
# they contain information about either `restaurant` or `hotel` domain
|
9 |
+
# full act vocab: https://github.com/ConvLab/ConvLab/blob/master/data/multiwoz/ \
|
10 |
+
# annotation/Multiwoz%20data%20analysis.md#dialog-act
|
11 |
+
acts = set()
|
12 |
+
acts.add("Inform")
|
13 |
+
acts.add("Request")
|
14 |
+
acts.add(
|
15 |
+
"NoOffer"
|
16 |
+
) # equivalent to the concept of `no matching`, `cannot find` in database
|
17 |
+
acts.add("Recommend")
|
18 |
+
acts.add("Select")
|
19 |
+
acts.add(
|
20 |
+
"OfferBook"
|
21 |
+
) # only for `train` domain, ask if book is needed, equivalent to `Booking-Inform`
|
22 |
+
# with [[none, none]] args in restaurant/hotel domain
|
23 |
+
acts.add(
|
24 |
+
"OfferBooked"
|
25 |
+
) # only for `train` domain, inform booking is complete, with corresponding info (such as ref number)
|
26 |
+
acts.add("Book") # inform booking is successful, equivalent to `OfferBooked` above
|
27 |
+
acts.add(
|
28 |
+
"NoBook"
|
29 |
+
) # inform booking fails, might because of no availability, usually come together act `request`
|
30 |
+
acts.add("bye")
|
31 |
+
acts.add("greet")
|
32 |
+
acts.add("reqmore")
|
33 |
+
acts.add("welcome")
|
34 |
+
acts.add("thank")
|
35 |
+
return acts
|
36 |
+
|
37 |
+
|
38 |
+
def get_act_natural_language(act):
|
39 |
+
if act in ["bye", "greet", "reqmore", "welcome", "thank"]:
|
40 |
+
return act
|
41 |
+
|
42 |
+
assert act[0].isupper()
|
43 |
+
tokens = re.findall("[A-Z][^A-Z]*", act) # e.g., `FindEvents` -> `Find Events`
|
44 |
+
tokens = list(map(str.lower, tokens)) # lower case, -> `find events`
|
45 |
+
act_nl = " ".join(tokens)
|
46 |
+
return act_nl
|
47 |
+
|
48 |
+
|
49 |
+
def convert_act_into_sgd(act, SPECIAL_TOKENS):
|
50 |
+
"""
|
51 |
+
convert multiwoz acts (w/o domain info) into sgd acts ensure that acts with same concept use one name
|
52 |
+
e.g., Book (OfferBooked) -> NOTIFY_SUCCESS, NoBook -> NOTIFY_FAILURE
|
53 |
+
"""
|
54 |
+
if act == "NoOffer":
|
55 |
+
act = "NOTIFY_FAILURE"
|
56 |
+
|
57 |
+
elif act == "Recommend":
|
58 |
+
act = "OFFER"
|
59 |
+
|
60 |
+
# technically, `OfferBook` is equivalent to (`act=OFFER_INTENT, slot=intent, value=ReserveRestaurant`)
|
61 |
+
# on system side in sgd since (1) the conversion is not trivial (completely different representations)
|
62 |
+
# and (2) multiwoz has no slot called `intent`
|
63 |
+
# one cannot simply convert `OfferBook` to `OFFER_INTENT`
|
64 |
+
# we thus keep the act as is
|
65 |
+
# note that there is no slot `intent` and value conveying intents in multiwoz
|
66 |
+
elif act == "OfferBook":
|
67 |
+
act = "Offer_Book"
|
68 |
+
|
69 |
+
elif act == "OfferBooked":
|
70 |
+
act = "NOTIFY_SUCCESS"
|
71 |
+
|
72 |
+
elif act == "Book": # same as `OfferBooked`
|
73 |
+
act = "NOTIFY_SUCCESS"
|
74 |
+
|
75 |
+
elif act == "NoBook":
|
76 |
+
act = "NOTIFY_FAILURE"
|
77 |
+
|
78 |
+
elif act == "bye":
|
79 |
+
act = "GOODBYE"
|
80 |
+
|
81 |
+
elif act == "reqmore":
|
82 |
+
act = "REQ_MORE"
|
83 |
+
|
84 |
+
elif act == "thank":
|
85 |
+
act = "THANK_YOU"
|
86 |
+
# elif act == "greet":
|
87 |
+
# elif act == "welcome":
|
88 |
+
act = act.upper() # align with sgd acts, e.g., `Inform` -> `INFORM`
|
89 |
+
|
90 |
+
# check if valid
|
91 |
+
assert "_{}_".format(act) in SPECIAL_TOKENS["additional_special_tokens"]
|
92 |
+
return act
|
93 |
+
|
94 |
+
|
95 |
+
def load_schema(schema_file):
|
96 |
+
def _update(key, value, mapping):
|
97 |
+
if key in mapping:
|
98 |
+
assert (
|
99 |
+
value == mapping[key]
|
100 |
+
) # ensure service meta is the same between data splits
|
101 |
+
else:
|
102 |
+
mapping[key] = value
|
103 |
+
|
104 |
+
def _restructure_service_meta(service_meta, attribute):
|
105 |
+
"""convert slot/intent metadata list into dict(slot/intent=metadata)"""
|
106 |
+
assert attribute in ["slots", "intents"]
|
107 |
+
mapping = {}
|
108 |
+
for value in service_meta[attribute]:
|
109 |
+
key = value["name"]
|
110 |
+
if attribute == "slots": # domain-slot in multiwoz
|
111 |
+
assert "-" in key
|
112 |
+
_, key = key.split("-") # domain, slot
|
113 |
+
key = normalise_slot(key)
|
114 |
+
else: # intent
|
115 |
+
key = normalise_intent(key)
|
116 |
+
mapping[key] = value
|
117 |
+
service_meta[attribute] = mapping
|
118 |
+
|
119 |
+
with open(schema_file) as f:
|
120 |
+
data = json.load(f)
|
121 |
+
|
122 |
+
SERVICE2META = {}
|
123 |
+
SLOTS, INTENTS = set(), set()
|
124 |
+
for service_meta in data:
|
125 |
+
service = service_meta["service_name"]
|
126 |
+
_restructure_service_meta(service_meta, "slots")
|
127 |
+
_restructure_service_meta(service_meta, "intents")
|
128 |
+
_update(service, service_meta, SERVICE2META)
|
129 |
+
|
130 |
+
# collect domain-independent slots
|
131 |
+
for slot in service_meta["slots"]:
|
132 |
+
SLOTS.add(slot)
|
133 |
+
|
134 |
+
for intent in service_meta["intents"]:
|
135 |
+
INTENTS.add(intent)
|
136 |
+
|
137 |
+
print("Load schema, intents: {}, slots: {}".format(len(INTENTS), len(SLOTS)))
|
138 |
+
return SERVICE2META, INTENTS, SLOTS
|
139 |
+
|
140 |
+
|
141 |
+
def normalise_intent(intent):
|
142 |
+
"""convert intent into natural language, e.g., find_hotel -> find hotel"""
|
143 |
+
if intent == "police":
|
144 |
+
intent = "find_police"
|
145 |
+
if intent == "book_taxi":
|
146 |
+
intent = "find_taxi"
|
147 |
+
assert "_" in intent
|
148 |
+
return " ".join(intent.split("_"))
|
149 |
+
|
150 |
+
|
151 |
+
def normalise_slot(slot):
|
152 |
+
if slot == "pricerange":
|
153 |
+
return "price range"
|
154 |
+
|
155 |
+
elif slot == "bookday":
|
156 |
+
return "book day"
|
157 |
+
|
158 |
+
elif slot == "bookpeople":
|
159 |
+
return "book people"
|
160 |
+
|
161 |
+
elif slot == "booktime":
|
162 |
+
return "book time"
|
163 |
+
|
164 |
+
elif slot == "bookstay":
|
165 |
+
return "book stay"
|
166 |
+
|
167 |
+
elif slot == "ref":
|
168 |
+
return "reference"
|
169 |
+
|
170 |
+
elif slot == "arriveby":
|
171 |
+
return "arrive by"
|
172 |
+
|
173 |
+
elif slot == "leaveat":
|
174 |
+
return "leave at"
|
175 |
+
|
176 |
+
elif slot == "trainid":
|
177 |
+
return "train id"
|
178 |
+
|
179 |
+
elif slot == "openhours":
|
180 |
+
return "open hours"
|
181 |
+
|
182 |
+
elif slot == "entrancefee":
|
183 |
+
return "entrance fee"
|
184 |
+
|
185 |
+
elif slot in ["none", "?"]:
|
186 |
+
return "Empty"
|
187 |
+
|
188 |
+
else:
|
189 |
+
return slot
|
190 |
+
|
191 |
+
|
192 |
+
def normalise_value(value):
|
193 |
+
# deal with binary and empty values
|
194 |
+
if value == "yes":
|
195 |
+
return "True"
|
196 |
+
|
197 |
+
elif value == "no":
|
198 |
+
return "False"
|
199 |
+
|
200 |
+
elif value in ["none", "?"]:
|
201 |
+
return "Empty"
|
202 |
+
|
203 |
+
else:
|
204 |
+
return value
|