alistairmcleay commited on
Commit
b16a132
·
1 Parent(s): 6aeedda

Added dialogue system code

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +167 -4
  2. scripts/UBAR_code/__init__.py +0 -0
  3. scripts/UBAR_code/data_analysis.py +170 -0
  4. scripts/UBAR_code/interaction/UBAR_interact.py +457 -0
  5. scripts/UBAR_code/interaction/__init__.py +0 -0
  6. scripts/UBAR_code/interaction/config.yaml +23 -0
  7. scripts/UBAR_code/preprocess.py +576 -0
  8. scripts/UBAR_code/preprocess2.1.py +585 -0
  9. scripts/UBAR_code/train_ubar.py +697 -0
  10. scripts/agent_agent.yaml +0 -0
  11. scripts/crazyneuraluser.egg-info/PKG-INFO +171 -0
  12. scripts/crazyneuraluser.egg-info/SOURCES.txt +0 -0
  13. scripts/crazyneuraluser.egg-info/dependency_links.txt +1 -0
  14. scripts/crazyneuraluser.egg-info/not-zip-safe +1 -0
  15. scripts/crazyneuraluser.egg-info/requires.txt +15 -0
  16. scripts/crazyneuraluser.egg-info/top_level.txt +1 -0
  17. scripts/simulate_interaction.py +171 -0
  18. scripts/template_train_model.py +45 -0
  19. scripts/user_model_code/__init__.py +0 -0
  20. scripts/user_model_code/decode.sh +37 -0
  21. scripts/user_model_code/interaction/__init__.py +0 -0
  22. scripts/user_model_code/interaction/config.yaml +12 -0
  23. scripts/user_model_code/interaction/multiwoz_interact.py +1034 -0
  24. scripts/user_model_code/interaction/schema.json +712 -0
  25. scripts/user_model_code/interaction/utils.py +308 -0
  26. scripts/user_model_code/main_user_model.py +347 -0
  27. scripts/user_model_code/preprocess_multiwoz.py +528 -0
  28. scripts/user_model_code/preprocess_sgd.py +431 -0
  29. scripts/user_model_code/train.sh +51 -0
  30. src/crazyneuraluser.egg-info/PKG-INFO +173 -0
  31. src/crazyneuraluser.egg-info/SOURCES.txt +76 -0
  32. src/crazyneuraluser.egg-info/dependency_links.txt +1 -0
  33. src/crazyneuraluser.egg-info/not-zip-safe +1 -0
  34. src/crazyneuraluser.egg-info/requires.txt +15 -0
  35. src/crazyneuraluser.egg-info/top_level.txt +1 -0
  36. src/crazyneuraluser/UBAR_code/__init__.py +16 -0
  37. src/crazyneuraluser/UBAR_code/clean_dataset.py +334 -0
  38. src/crazyneuraluser/UBAR_code/config.py +164 -0
  39. src/crazyneuraluser/UBAR_code/config21.py +169 -0
  40. src/crazyneuraluser/UBAR_code/db_ops.py +314 -0
  41. src/crazyneuraluser/UBAR_code/eval.py +932 -0
  42. src/crazyneuraluser/UBAR_code/ontology.py +328 -0
  43. src/crazyneuraluser/UBAR_code/reader.py +1262 -0
  44. src/crazyneuraluser/UBAR_code/utils.py +292 -0
  45. src/crazyneuraluser/user_model_code/analysis_multiwoz.py +119 -0
  46. src/crazyneuraluser/user_model_code/analysis_sgd.py +483 -0
  47. src/crazyneuraluser/user_model_code/argument.py +153 -0
  48. src/crazyneuraluser/user_model_code/dataset.py +297 -0
  49. src/crazyneuraluser/user_model_code/utils_generation.py +210 -0
  50. src/crazyneuraluser/user_model_code/utils_multiwoz.py +204 -0
app.py CHANGED
@@ -1,7 +1,170 @@
1
  import gradio as gr
 
 
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
 
 
5
 
6
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- iface.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
+ import scripts.simulate_interaction as si
3
+ import sys
4
+ import traceback
5
+ import pandas as pd
6
 
7
+ # from tqdm import tqdm
8
+ from scripts.UBAR_code.interaction import UBAR_interact
9
+ from scripts.user_model_code.interaction import multiwoz_interact
10
+ from scripts.UBAR_code.interaction.UBAR_interact import bcolors
11
 
12
+
13
+ def instantiate_agents():
14
+
15
+ UBAR_checkpoint_path = "models/UBAR/experiments/distilgpt-2_sd11_lr0.0001_bs16_ga2/epoch50_trloss0.59_gpt2"
16
+ user_model_checkpoint_path = "models/user_model/MultiWOZ-full_checkpoint_step340k"
17
+
18
+ sys_model = UBAR_interact.UbarSystemModel(
19
+ "UBAR_sys_model", UBAR_checkpoint_path, "scripts/UBAR_code/interaction/config.yaml"
20
+ )
21
+
22
+ user_model = multiwoz_interact.NeuralAgent(
23
+ "user", user_model_checkpoint_path, "scripts/user_model_code/interaction/config.yaml"
24
+ )
25
+
26
+ return sys_model, user_model
27
+
28
+
29
+ def read_multiwoz_data():
30
+ """
31
+ Read the multiwoz 2.0 raw data from the .json file
32
+ """
33
+ raw_mwoz_20_path = "data/raw/UBAR/multi-woz/data.json"
34
+ df_raw_mwoz = pd.read_json(raw_mwoz_20_path)
35
+ return df_raw_mwoz
36
+
37
+
38
+ def load_test_val_lists():
39
+ val_list_file = "data/raw/UBAR/multi-woz/valListFile.json"
40
+ test_list_file = "data/raw/UBAR/multi-woz/testListFile.json"
41
+
42
+ with open(val_list_file, "r") as f:
43
+ val_list = f.readlines()
44
+ val_list = [x.strip() for x in val_list]
45
+
46
+ with open(test_list_file, "r") as f:
47
+ test_list = f.readlines()
48
+ test_list = [x.strip() for x in test_list]
49
+
50
+ return val_list, test_list
51
+
52
+
53
+ def main(
54
+ write_to_file=False, ground_truth_system_responses=False, train_only=True, n_dialogues="all", log_successes=False
55
+ ):
56
+ sys_model, user_model = instantiate_agents()
57
+
58
+ # TODO: move hardcoded vars into config file
59
+ raw_mwoz_20_path = "data/raw/UBAR/multi-woz/data.json"
60
+ user_utterances_out_path = "data/preprocessed/UBAR/user_utterances_from_simulator.txt"
61
+ logging_successes_path = "data/preprocessed/UBAR/logging_successes"
62
+ sys_model.print_intermediary_info = False
63
+ user_model.print_intermediary_info = False
64
+
65
+ df_raw_mwoz = pd.read_json(raw_mwoz_20_path)
66
+ if n_dialogues == "all":
67
+ n_dialogues = len(df_raw_mwoz.columns)
68
+
69
+ curr_dialogue_user_utterances_formatted = []
70
+
71
+ print("Loading goals...")
72
+ goals = multiwoz_interact.read_multiWOZ_20_goals(raw_mwoz_20_path, n_dialogues)
73
+
74
+ # Write column headers
75
+ if write_to_file:
76
+ with open(user_utterances_out_path, "w") as f:
77
+ f.write("Dialogue #\tDialogue ID\tTurn #\tSystem Response\n")
78
+
79
+ print("Loading data...")
80
+ df_mwoz_data = read_multiwoz_data()
81
+ val_list, test_list = load_test_val_lists()
82
+
83
+ successful_dialogues = 0
84
+ total_dialogues_generated = 0 # train dialogues only
85
+ for dialogue_idx, (goal, dialogue_filename) in enumerate(zip(goals, df_mwoz_data.columns)):
86
+ if log_successes:
87
+ # log successful_dialogues to logging_successes_path every 100 dialogues
88
+ if dialogue_idx % 100 == 0:
89
+ with open(logging_successes_path, "w") as f:
90
+ f.write(str(successful_dialogues) + " / " + str(total_dialogues_generated))
91
+
92
+ curr_dialogue_user_utterances_formatted = []
93
+ if train_only:
94
+ if dialogue_filename in val_list or dialogue_filename in test_list:
95
+ continue
96
+
97
+ total_dialogues_generated += 1
98
+ print("Dialogue: {}".format(dialogue_filename))
99
+
100
+ # There are occasionally exceptions thrown from one of the agents, usually the user
101
+ # In this case we simply continue to the next dialogue
102
+ try:
103
+ # Reset state after each dialogue
104
+ sys_model.init_session()
105
+ user_model.init_session(ini_goal=goal)
106
+ sys_response = ""
107
+
108
+ for turn_idx in range(50):
109
+ # Turn idx in this case represents the turn as one user utterance AND one system response
110
+ usr_response_raw_data_idx = turn_idx * 2
111
+ sys_response_raw_data_idx = turn_idx * 2 + 1
112
+
113
+ user_utterance = user_model.response(sys_response)
114
+ print(bcolors.OKBLUE + "User: " + bcolors.ENDC + user_utterance)
115
+
116
+ if write_to_file:
117
+ user_utterance = user_utterance.replace("\n", " ")
118
+ curr_dialogue_user_utterances_formatted.append(
119
+ str(dialogue_idx)
120
+ + "\t"
121
+ + dialogue_filename
122
+ + "\t"
123
+ + str(usr_response_raw_data_idx)
124
+ + "\t"
125
+ + user_utterance
126
+ + "\n"
127
+ )
128
+
129
+ if user_model.is_terminated():
130
+ successful_dialogues += 1
131
+ print(bcolors.OKCYAN + "Dialogue terminated successfully!" + bcolors.ENDC)
132
+ print(bcolors.OKCYAN + "---" * 30 + bcolors.ENDC + "\n")
133
+ if write_to_file:
134
+ # Write whole dialogue to file
135
+ with open(user_utterances_out_path, "a") as f:
136
+ for line in curr_dialogue_user_utterances_formatted:
137
+ f.write(line)
138
+ break
139
+
140
+ # Next turn materials
141
+ if ground_truth_system_responses:
142
+ # If we are at the end of the ground truth dialogues
143
+ if len(df_mwoz_data.iloc[:, dialogue_idx].log) <= sys_response_raw_data_idx:
144
+ print(bcolors.RED + "Dialogue terminated unsuccessfully!" + bcolors.ENDC)
145
+ print(bcolors.RED + "---" * 30 + bcolors.ENDC + "\n")
146
+ break
147
+ sys_response = df_mwoz_data.iloc[:, dialogue_idx].log[sys_response_raw_data_idx]["text"]
148
+ else:
149
+ sys_response = sys_model.response(user_utterance, turn_idx)
150
+ capitalised_sys_response = sys_response[0].upper() + sys_response[1:]
151
+ print(bcolors.GREEN + "System: " + bcolors.ENDC + capitalised_sys_response)
152
+
153
+ except Exception:
154
+ print(bcolors.RED + "*" * 30 + bcolors.ENDC)
155
+ print(bcolors.RED + "Error in dialogue {}".format(dialogue_filename) + bcolors.ENDC)
156
+ print(bcolors.RED + "*" * 30 + bcolors.ENDC)
157
+ traceback.print_exc()
158
+ continue
159
+
160
+ print("Successful dialogues: {}".format(successful_dialogues))
161
+ print("Total dialogues: {}".format(n_dialogues))
162
+ print("% Successful Dialopues: {}".format(successful_dialogues / n_dialogues))
163
+
164
+
165
+ def test():
166
+ return "SUCCESS"
167
+
168
+
169
+ iface = gr.Interface(fn=test, outputs="text")
170
+ iface.launch()
scripts/UBAR_code/__init__.py ADDED
File without changes
scripts/UBAR_code/data_analysis.py ADDED
@@ -0,0 +1,170 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import os
4
+ import re
5
+ import zipfile
6
+ from collections import OrderedDict
7
+
8
+ from crazyneuraluser.UBAR_code.ontology import all_domains
9
+
10
+ # 2.0
11
+ data_path = "data/preprocessed/UBAR/gen_usr_utt_experiment_data.json"
12
+ save_path = "data/interim/gen_usr_utts/multi-woz-analysis/"
13
+ save_path_exp = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/"
14
+ # 2.1
15
+ # data_path = 'data/raw/UBAR/MultiWOZ_2.1/'
16
+ # save_path = 'data/interim/multi-woz-2.1-analysis/'
17
+ # save_path_exp = 'data/preprocessed/multi-woz-2.1-processed/'
18
+ data_file = "data.json"
19
+ domains = all_domains
20
+ # all_domains = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital']
21
+
22
+
23
+ def analysis():
24
+ compressed_raw_data = {}
25
+ goal_of_dials = {}
26
+ req_slots = {}
27
+ info_slots = {}
28
+ dom_count = {}
29
+ dom_fnlist = {}
30
+ all_domain_specific_slots = set()
31
+ for domain in domains:
32
+ req_slots[domain] = []
33
+ info_slots[domain] = []
34
+
35
+ # archive = zipfile.ZipFile(data_path + data_file + ".zip", "r")
36
+ # data = archive.open(data_file, "r").read().decode("utf-8").lower()
37
+ data = open(data_path, "r").read().lower()
38
+ ref_nos = list(set(re.findall(r"\"reference\"\: \"(\w+)\"", data)))
39
+ data = json.loads(data)
40
+
41
+ for fn, dial in data.items():
42
+ goals = dial["goal"]
43
+ logs = dial["log"]
44
+
45
+ # get compressed_raw_data and goal_of_dials
46
+ compressed_raw_data[fn] = {"goal": {}, "log": []}
47
+ goal_of_dials[fn] = {}
48
+ for dom, goal in goals.items(): # get goal of domains that are in demmand
49
+ if dom != "topic" and dom != "message" and goal:
50
+ compressed_raw_data[fn]["goal"][dom] = goal
51
+ goal_of_dials[fn][dom] = goal
52
+
53
+ for turn in logs:
54
+ if not turn["metadata"]: # user's turn
55
+ compressed_raw_data[fn]["log"].append({"text": turn["text"]})
56
+ else: # system's turn
57
+ meta = turn["metadata"]
58
+ turn_dict = {"text": turn["text"], "metadata": {}}
59
+ for (
60
+ dom,
61
+ book_semi,
62
+ ) in meta.items(): # for every domain, sys updates "book" and "semi"
63
+ book, semi = book_semi["book"], book_semi["semi"]
64
+ record = False
65
+ for (
66
+ slot,
67
+ value,
68
+ ) in book.items(): # record indicates non-empty-book domain
69
+ if value not in ["", []]:
70
+ record = True
71
+ if record:
72
+ turn_dict["metadata"][dom] = {}
73
+ turn_dict["metadata"][dom]["book"] = book # add that domain's book
74
+ record = False
75
+ for (
76
+ slot,
77
+ value,
78
+ ) in semi.items(): # here record indicates non-empty-semi domain
79
+ if value not in ["", []]:
80
+ record = True
81
+ break
82
+ if record:
83
+ for s, v in copy.deepcopy(semi).items():
84
+ if v == "not mentioned":
85
+ del semi[s]
86
+ if not turn_dict["metadata"].get(dom):
87
+ turn_dict["metadata"][dom] = {}
88
+ turn_dict["metadata"][dom]["semi"] = semi # add that domain's semi
89
+ compressed_raw_data[fn]["log"].append(turn_dict) # add to log the compressed turn_dict
90
+
91
+ # get domain statistics
92
+ dial_type = (
93
+ "multi" if "mul" in fn or "MUL" in fn else "single"
94
+ ) # determine the dialog's type: sinle or multi
95
+ if fn in ["pmul2756.json", "pmul4958.json", "pmul3599.json"]:
96
+ dial_type = "single"
97
+ dial_domains = [dom for dom in domains if goals[dom]] # domains that are in demmand
98
+ dom_str = ""
99
+ for dom in dial_domains:
100
+ if not dom_count.get(dom + "_" + dial_type): # count each domain type, with single or multi considered
101
+ dom_count[dom + "_" + dial_type] = 1
102
+ else:
103
+ dom_count[dom + "_" + dial_type] += 1
104
+ if not dom_fnlist.get(dom + "_" + dial_type): # keep track the file number of each domain type
105
+ dom_fnlist[dom + "_" + dial_type] = [fn]
106
+ else:
107
+ dom_fnlist[dom + "_" + dial_type].append(fn)
108
+ dom_str += "%s_" % dom
109
+ dom_str = dom_str[:-1] # substract the last char in dom_str
110
+ if dial_type == "multi": # count multi-domains
111
+ if not dom_count.get(dom_str):
112
+ dom_count[dom_str] = 1
113
+ else:
114
+ dom_count[dom_str] += 1
115
+ if not dom_fnlist.get(dom_str):
116
+ dom_fnlist[dom_str] = [fn]
117
+ else:
118
+ dom_fnlist[dom_str].append(fn)
119
+ ######
120
+
121
+ # get informable and requestable slots statistics
122
+ for domain in domains:
123
+ info_ss = goals[domain].get("info", {})
124
+ book_ss = goals[domain].get("book", {})
125
+ req_ss = goals[domain].get("reqt", {})
126
+ for info_s in info_ss:
127
+ all_domain_specific_slots.add(domain + "-" + info_s)
128
+ if info_s not in info_slots[domain]:
129
+ info_slots[domain] += [info_s]
130
+ for book_s in book_ss:
131
+ if "book_" + book_s not in info_slots[domain] and book_s not in [
132
+ "invalid",
133
+ "pre_invalid",
134
+ ]:
135
+ all_domain_specific_slots.add(domain + "-" + book_s)
136
+ info_slots[domain] += ["book_" + book_s]
137
+ for req_s in req_ss:
138
+ if req_s not in req_slots[domain]:
139
+ req_slots[domain] += [req_s]
140
+
141
+ # result statistics
142
+ if not os.path.exists(save_path):
143
+ os.mkdir(save_path)
144
+ if not os.path.exists(save_path_exp):
145
+ os.mkdir(save_path_exp)
146
+ with open(save_path + "req_slots.json", "w") as sf:
147
+ json.dump(req_slots, sf, indent=2)
148
+ with open(save_path + "info_slots.json", "w") as sf:
149
+ json.dump(info_slots, sf, indent=2)
150
+ with open(save_path + "all_domain_specific_info_slots.json", "w") as sf:
151
+ json.dump(list(all_domain_specific_slots), sf, indent=2)
152
+ print("slot num:", len(list(all_domain_specific_slots)))
153
+ with open(save_path + "goal_of_each_dials.json", "w") as sf:
154
+ json.dump(goal_of_dials, sf, indent=2)
155
+ with open(save_path + "compressed_data.json", "w") as sf:
156
+ json.dump(compressed_raw_data, sf, indent=2)
157
+ with open(save_path + "domain_count.json", "w") as sf:
158
+ single_count = [d for d in dom_count.items() if "single" in d[0]]
159
+ multi_count = [d for d in dom_count.items() if "multi" in d[0]]
160
+ other_count = [d for d in dom_count.items() if "multi" not in d[0] and "single" not in d[0]]
161
+ dom_count_od = OrderedDict(single_count + multi_count + other_count)
162
+ json.dump(dom_count_od, sf, indent=2)
163
+ with open(save_path_exp + "reference_no.json", "w") as sf:
164
+ json.dump(ref_nos, sf, indent=2)
165
+ with open(save_path_exp + "domain_files.json", "w") as sf:
166
+ json.dump(dom_fnlist, sf, indent=2)
167
+
168
+
169
+ if __name__ == "__main__":
170
+ analysis()
scripts/UBAR_code/interaction/UBAR_interact.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch
3
+ import random
4
+ import string
5
+
6
+ # import bcolors
7
+ from omegaconf import OmegaConf
8
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
9
+
10
+ from crazyneuraluser.UBAR_code.config import global_config as cfg
11
+ from crazyneuraluser.UBAR_code.reader import MultiWozReader
12
+ from crazyneuraluser.UBAR_code.db_ops import MultiWozDB
13
+
14
+ from typing import List
15
+
16
+
17
+ class bcolors:
18
+ HEADER = "\033[95m"
19
+ OKBLUE = "\033[94m"
20
+ OKCYAN = "\033[96m"
21
+ GREEN = "\033[92m"
22
+ YELLOW = "\033[93m"
23
+ RED = "\033[91m"
24
+ ENDC = "\033[0m"
25
+ BOLD = "\033[1m"
26
+ UNDERLINE = "\033[4m"
27
+
28
+
29
+ class UbarSystemModel: # may inherit convlab or not, just like andy's
30
+ def __init__(self, name: str, checkpoint_path: str, model_config_path: str):
31
+
32
+ self.tokenizer = GPT2Tokenizer.from_pretrained(checkpoint_path)
33
+ self.model = GPT2LMHeadModel.from_pretrained(checkpoint_path)
34
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
35
+ self.name = name
36
+ self.turn_domain = ["general"] # returns a list of one string that is the domain e.g. 'taxi'
37
+ # (this is because of the way the db_ops.py deals with the domain. It should really be a string.)
38
+
39
+ self.ubar_status = {"dialogue_terminate": False}
40
+
41
+ self.print_intermediary_info = False
42
+
43
+ self.config = OmegaConf.load(model_config_path)
44
+ self.previous_turn = {"user": [], "bspn": [], "aspn": [], "db": []}
45
+
46
+ # NB: best to use corpus goals to guide interactions - baselines/simulate_agent.py allows that.
47
+
48
+ # initialize multiwoz reader and db_ops
49
+ self.reader = MultiWozReader(self.tokenizer)
50
+ self.db = MultiWozDB(self.config.dbs_path)
51
+
52
+ def lexicalize_sys_response(self, sys_response, domain_hits, decoded_belief_state_subseq) -> str:
53
+ lexicalized_sys_response = ""
54
+
55
+ # Track entities already filled e.g. if there are 3 restaurants track which have already been added to a slot
56
+ max_idx_of_added_entities = -1
57
+
58
+ # Fill slots with values from the DB (lexicalization)
59
+ for token in sys_response.split():
60
+ token = token.strip(" .,;:")
61
+ if token.startswith("["): # It is a slot to be filled
62
+
63
+ # Note in hotel there is specific price data too but to simplify things
64
+ # we just use the price range (e.g. moderate)
65
+ # TODO: there are different uses of price in different databases ('price' vs 'pricerange':
66
+ # need to deal with this appropriately below)
67
+ slots_to_db_keys_map = {
68
+ "[value_price]": "price",
69
+ "[value_pricerange]": "pricerange",
70
+ "[value_food]": "food",
71
+ "[value_area]": "area",
72
+ "[value_type]": "type",
73
+ "[value_phone]": "phone",
74
+ "[value_address]": "address",
75
+ "[value_leave]": "leave",
76
+ "[value_postcode]": "postcode",
77
+ "[value_id]": "id",
78
+ "[value_arrive]": "arrive",
79
+ "[value_stars]": "stars",
80
+ "[value_day]": "day",
81
+ "[value_destination]": "destination",
82
+ "[value_car]": "taxi_types",
83
+ "[value_departure]": "departure",
84
+ "[value_people]": "people",
85
+ "[value_stay]": "stay",
86
+ "[value_department]": "department",
87
+ "[value_time]": "time",
88
+ "[value_name]": "name",
89
+ "[value_reference]": "reference",
90
+ }
91
+ # Hospital domain is a strange outlier data structure
92
+ if self.turn_domain == ["hospital"] and token == "[value_address]":
93
+ token = "1 Addenbrooks Street"
94
+ elif self.turn_domain == ["hospital"] and token == "[value_postcode]":
95
+ token = "CB11QD"
96
+
97
+ # So does taxi
98
+ elif self.turn_domain == ["taxi"] and token == "[value_phone]" and domain_hits != []:
99
+ token = domain_hits[0]["taxi_phone"]
100
+
101
+ # Deal with value_name differently because there can be multiple
102
+ elif token == "[value_name]" and domain_hits != []:
103
+ token = domain_hits[max_idx_of_added_entities + 1]["name"]
104
+ max_idx_of_added_entities += 1
105
+
106
+ # This slot tells the user how many db hits there were matching their constraints
107
+ elif token == "[value_choice]" and domain_hits != []:
108
+ token = len(domain_hits)
109
+
110
+ # Randomly generate the reference
111
+ elif token == "[value_reference]" and domain_hits != []:
112
+ token = "".join(random.choices(string.ascii_uppercase, k=10))
113
+
114
+ else:
115
+ # First check can we fill the token from the db results
116
+ db_success = False
117
+ if domain_hits != []:
118
+ for slot, db_key in slots_to_db_keys_map.items():
119
+ if token == slot and db_key in domain_hits[0]:
120
+ token = domain_hits[0][db_key]
121
+ db_success = True
122
+
123
+ # If we cannot, then try to fill it from the belief state by looking for a match
124
+ # in the belief state and then if there is a match adding the next token.
125
+ # This is not perfect as some are more than one word but its probably good enough.
126
+ if not db_success:
127
+ decoded_belief_states = decoded_belief_state_subseq.split()
128
+ for idx, belief_state_slot in enumerate(decoded_belief_states):
129
+ if token in slots_to_db_keys_map.keys():
130
+ if slots_to_db_keys_map[token] == belief_state_slot:
131
+ token == decoded_belief_states[idx + 1]
132
+
133
+ # Otherwise just leave the slot as it is as we have failed to fill it
134
+
135
+ lexicalized_sys_response += str(token)
136
+ lexicalized_sys_response += " "
137
+
138
+ return lexicalized_sys_response
139
+
140
+ def set_turn_domain(self, belief_span_ids_subseq, sys_act_span_ids_subseq=None) -> None:
141
+ """
142
+ IMPORTANT: use_system_act is not None when actually querying the DB to
143
+ lexicalise the system response. When it is None the Belief state NOT the system act is used to determine
144
+ the domain. In self.response() the DB is queried twice. The first time is using the Belief state as the system
145
+ act has not yet been generated, and it is only used to find out if there are matches in the DB for the current
146
+ domain + constraints. Then, after the system act is generated, we call the DB to actually get the results to
147
+ lexicalise the system response. It is much more important that the domain is correct for the second call, and
148
+ the system act is much more accurate at determining the domain.
149
+ """
150
+
151
+ if sys_act_span_ids_subseq is None:
152
+ decoded_belief_state_subseq = self.tokenizer.decode(belief_span_ids_subseq[1:-1])
153
+ decoded_prev_belief_state_subseq = self.tokenizer.decode(self.previous_turn["bspn"][1:-1])
154
+
155
+ # If it is the first turn and the belief state is empty then set the domain to general
156
+ if self.previous_turn["bspn"] == [] and len(belief_span_ids_subseq) == 2:
157
+ self.turn_domain = ["general"]
158
+ return
159
+
160
+ # If the belief state doesn't change then keep the same domain
161
+ if belief_span_ids_subseq == self.previous_turn["bspn"]:
162
+ return
163
+
164
+ # The domain has changed, get the new one (from the right)
165
+ else:
166
+ # remove substring from string
167
+ if decoded_prev_belief_state_subseq in decoded_belief_state_subseq:
168
+ decoded_new_tokens = decoded_belief_state_subseq.replace("decoded_prev_belief_state_subseq", "")
169
+ most_recent_domain_in_belief_state = [
170
+ [token.strip("[]") for token in decoded_new_tokens.split() if token.startswith("[")][-1]
171
+ ]
172
+ self.turn_domain = most_recent_domain_in_belief_state
173
+ else:
174
+ # Sometimes the previous belief state is not in the current belief state as
175
+ # the output changes very slightly (say by one word) - in this case just keep the same domain
176
+ # TODO: Could probably handle this better.
177
+ if self.print_intermediary_info:
178
+ print(
179
+ bcolors.YELLOW
180
+ + "!Previous belief state not in current belief state! Details below:"
181
+ + bcolors.ENDC
182
+ )
183
+ print("Previous Belief State: " + decoded_prev_belief_state_subseq)
184
+ print("Current Belief State: " + decoded_belief_state_subseq)
185
+
186
+ else:
187
+ decoded_sys_act_subseq = self.tokenizer.decode(sys_act_span_ids_subseq[1:-1])
188
+
189
+ most_recent_domain_in_sys_act = [
190
+ [token.strip("[]") for token in decoded_sys_act_subseq.split() if token.startswith("[")][0]
191
+ ]
192
+ self.turn_domain = most_recent_domain_in_sys_act
193
+
194
+ def get_domain_hits(self, decoded_belief_state_subseq) -> dict:
195
+ # Get hits from db based on belief state, unless its a general turn (no hits then)
196
+ constraint_dict = self.reader.bspan_to_constraint_dict(decoded_belief_state_subseq)
197
+ query_turn_domain = self.turn_domain[0] # db.queryJsons needs a string not a list (single domain)
198
+ # If the constraint dict doesn't contain any constraints for the current domain then pass an empty dict
199
+ if query_turn_domain in constraint_dict:
200
+ domain_hits = self.db.queryJsons(query_turn_domain, constraint_dict[query_turn_domain])
201
+ else:
202
+ domain_hits = self.db.queryJsons(query_turn_domain, {})
203
+
204
+ return domain_hits
205
+
206
+ def print_turn_intermediate_info(self, generated_subseq_ids_map) -> None:
207
+ print(bcolors.OKCYAN + "Turn domain: " + bcolors.ENDC + "[" + str(self.turn_domain[0]) + "]")
208
+
209
+ belief_state = self.tokenizer.decode(generated_subseq_ids_map["bspn"])
210
+ print(bcolors.OKCYAN + "Belief state: " + bcolors.ENDC + belief_state)
211
+
212
+ db_output = self.tokenizer.decode(generated_subseq_ids_map["db"])
213
+ print(bcolors.OKCYAN + "DB Output: " + bcolors.ENDC + db_output)
214
+
215
+ sys_act = self.tokenizer.decode(generated_subseq_ids_map["aspn"])
216
+ print(bcolors.OKCYAN + "System Act: " + bcolors.ENDC + sys_act)
217
+
218
+ def _init_ubar_status(self) -> dict:
219
+ return {"dialogue_terminate": False}
220
+
221
+ def init_session(self):
222
+ self.ubar_status = self._init_ubar_status()
223
+ self.previous_turn = {"user": [], "bspn": [], "aspn": [], "db": []}
224
+ self.turn_domain = ["general"]
225
+
226
+ def is_terminated(self) -> bool:
227
+ """This should tell an external client whether the user model considers they have completed the task."""
228
+ # return False
229
+ return self.ubar_status["dialogue_terminate"]
230
+
231
+ def _activate_dialogue_terminate(self) -> None:
232
+ """Turn on the ubar status about dialogue termination"""
233
+ self.ubar_status["dialogue_terminate"] = True
234
+
235
+ def add_torch_input_eval(self, inputs):
236
+ # inputs: context
237
+ inputs["context_tensor"] = torch.tensor([inputs["context"]]).to(self.device)
238
+ return inputs
239
+
240
+ def prepare_input_for_model(self, user_utterance: str, turn_id: int) -> torch.Tensor:
241
+ # TODO: CONVERT DIALOGUE HISTORY TO TOKEN IDS
242
+
243
+ tokenised_user_utterance = self.tokenizer.encode("<sos_u> " + user_utterance + " <eos_u>")
244
+ # In this application turn always only contains ["user"], not ["bspn", "aspn", "db"] etc.
245
+ turn = {"user": tokenised_user_utterance}
246
+
247
+ first_turn = turn_id == 0
248
+ inputs = self.reader.convert_turn_eval(turn, self.previous_turn, first_turn)
249
+ inputs = self.add_torch_input_eval(inputs)
250
+
251
+ return inputs
252
+
253
+ def decode_generated_bspn(self, generated) -> List[int]:
254
+ eos_b_id = self.tokenizer.encode(["<eos_b>"])[0]
255
+ if eos_b_id in generated:
256
+ eos_b_idx = generated.index(eos_b_id)
257
+ else:
258
+ eos_b_idx = len(generated) - 1
259
+ return generated[: eos_b_idx + 1]
260
+
261
+ def decode_grenerated_act_resp(self, generated) -> dict:
262
+ """
263
+ decode generated
264
+ return decoded['resp'] ('bspn', 'aspn')
265
+ """
266
+ decoded = {}
267
+ eos_a_id = self.tokenizer.encode(["<eos_a>"])[0]
268
+ eos_r_id = self.tokenizer.encode(["<eos_r>"])[0]
269
+ # eos_b_id = self.tokenizer.encode(["<eos_b>"])[0]
270
+
271
+ # eos_r may not exists if gpt2 generated repetitive words.
272
+ if eos_r_id in generated:
273
+ eos_r_idx = generated.index(eos_r_id)
274
+ else:
275
+ eos_r_idx = len(generated) - 1
276
+
277
+ if cfg.use_true_curr_aspn: # only predict resp
278
+ decoded["resp"] = generated[: eos_r_idx + 1]
279
+ else: # predicted aspn, resp
280
+ eos_a_idx = generated.index(eos_a_id)
281
+ decoded["aspn"] = generated[: eos_a_idx + 1]
282
+ decoded["resp"] = generated[eos_a_idx + 1 : eos_r_idx + 1]
283
+ return decoded
284
+
285
+ def generate_ids_subseq_map(self, inputs):
286
+
287
+ context_input_subseq = inputs["context"]
288
+ # decoded_context_input_subseq = self.tokenizer.decode(context_input_subseq)
289
+ # Check if model has put duplicate tags in the context and if so remove one of the duplicates
290
+ # Yes this is kind of hacky, but UBAR seems to learn to duplicate certain tags - I don't know why
291
+ # Also instead of decoding and encoding here tags could be checked with their ids - but time is short...
292
+ # cleaned_decoded_list = []
293
+ # prev_token = ""
294
+ # for token in decoded_context_input_subseq.split():
295
+ # if token.startswith("<") and token.endswith(">"): # It is a tag
296
+ # if token == prev_token: # It is a duplicate tag
297
+ # continue
298
+ # cleaned_decoded_list.append(token)
299
+ # prev_token = token
300
+ # decoded_context_input_subseq = " ".join(cleaned_decoded_list)
301
+ # context_input_subseq = self.tokenizer.encode(decoded_context_input_subseq)
302
+
303
+ context_input_subeq_tensor = inputs["context_tensor"]
304
+
305
+ # TODO: FIND OUT BY COMPARING WITH MODEL.VALIDATE() how to calculate context_length
306
+ context_length = len(context_input_subseq)
307
+
308
+ belief_state_ids = self.model.generate(
309
+ input_ids=context_input_subeq_tensor,
310
+ max_length=context_length + 60,
311
+ temperature=0.7,
312
+ top_p=1,
313
+ num_beams=1,
314
+ pad_token_id=self.tokenizer.eos_token_id,
315
+ eos_token_id=self.tokenizer.encode(["<eos_b>"])[0],
316
+ )
317
+ gen_belief_state_token_ids = belief_state_ids[0].cpu().numpy().tolist() # type: list[int]
318
+ belief_span_ids_subseq = self.decode_generated_bspn(
319
+ gen_belief_state_token_ids[context_length - 1 :]
320
+ ) # type: list[int]
321
+
322
+ self.set_turn_domain(belief_span_ids_subseq)
323
+
324
+ db_result = self.reader.bspan_to_DBpointer(
325
+ self.tokenizer.decode(belief_span_ids_subseq), self.turn_domain
326
+ ) # type: str
327
+ db_ids_subseq = self.tokenizer.convert_tokens_to_ids(
328
+ self.tokenizer.tokenize("<sos_db> " + db_result + " <eos_db>")
329
+ ) + self.tokenizer.encode(["<sos_a>"])
330
+
331
+ # TODO: context_input_subseq is already a tensor but the other two subseqs aren't - why?
332
+ act_response_gen_input_subseq = context_input_subseq + belief_span_ids_subseq + db_ids_subseq
333
+ act_response_gen_input_subseq_tensor = torch.tensor([act_response_gen_input_subseq]).to(self.device)
334
+ context_length = len(act_response_gen_input_subseq)
335
+
336
+ outputs_db = self.model.generate(
337
+ input_ids=act_response_gen_input_subseq_tensor,
338
+ max_length=context_length + 80,
339
+ temperature=0.7,
340
+ top_p=1,
341
+ num_beams=1,
342
+ pad_token_id=self.tokenizer.eos_token_id,
343
+ eos_token_id=self.tokenizer.encode(["<eos_r>"])[0],
344
+ )
345
+ generated_act_resp_token_ids = outputs_db[0].cpu().numpy().tolist() # type: list[int]
346
+ generated_act_resp_token_ids = generated_act_resp_token_ids[context_length - 1 :]
347
+
348
+ try:
349
+ generated_subseq_ids_map = self.decode_grenerated_act_resp(generated_act_resp_token_ids)
350
+ # TODO: IF YOU WANT Option b) then you just read the ['resp'] key and convert to string using huggingface;
351
+ # that would be sys_response; Obviously, this applies to Option a as well
352
+ generated_subseq_ids_map["bspn"] = belief_span_ids_subseq
353
+ # TODO: Option a) STORE THESE MAPPINGS IN SELF.CONTEXT IF YOU WANT TO HAVE
354
+ # {U_1, BS_1, DB_1, A_1, R_1, U_2, BS_2... history}
355
+
356
+ generated_subseq_ids_map["db"] = db_ids_subseq
357
+ generated_subseq_ids_map["labels"] = context_input_subseq
358
+
359
+ except ValueError:
360
+ generated_subseq_ids_map = {"resp": [], "bspn": [], "aspn": [], "db": [], "labels": []}
361
+
362
+ # IMPORTANT: this is how all of the previous state is updated (appended) after each turn
363
+ # Update self.previous_turn to track state to be fed into GPT2
364
+ for k, v in generated_subseq_ids_map.items():
365
+ self.previous_turn[k] = v
366
+
367
+ if self.print_intermediary_info:
368
+ self.print_turn_intermediate_info(generated_subseq_ids_map)
369
+
370
+ return generated_subseq_ids_map
371
+
372
+ def response(self, usr_utterance: str, turn_id: int) -> str:
373
+
374
+ if usr_utterance == "Goodbye":
375
+ self._activate_dialogue_terminate()
376
+ return "Session Terminated by User"
377
+
378
+ inputs = self.prepare_input_for_model(usr_utterance, turn_id)
379
+
380
+ generated_subseq_ids_map = self.generate_ids_subseq_map(inputs)
381
+ belief_span_ids_subseq = generated_subseq_ids_map["bspn"]
382
+
383
+ sys_response = self.tokenizer.decode(generated_subseq_ids_map["resp"][1:-1])
384
+
385
+ prev_turn_domain = self.turn_domain
386
+ sys_act_span_ids_subseq = generated_subseq_ids_map["aspn"]
387
+ self.set_turn_domain(belief_span_ids_subseq, sys_act_span_ids_subseq)
388
+
389
+ if self.turn_domain != ["general"]:
390
+ # If the domain changes when reading the system response, then we need to re-do the generation process
391
+ # for both the belief state and the system action and response. We do this because self.get_domain_hits()
392
+ # will break if the domain is different when querying the DB for the second time here than when it was
393
+ # originally queried above, due to the constraint dict it uses that is generated from the belief state
394
+ # How can the belief state domain and the system act domain be different? Bunch of things, for example:
395
+ # When asking for the police the belief state may be empty (so 'general' domain)
396
+ # but then the system action will have [police].
397
+ if prev_turn_domain != self.turn_domain:
398
+ if self.print_intermediary_info:
399
+ print(
400
+ bcolors.RED
401
+ + "Domain changed from {} to {}".format(prev_turn_domain, self.turn_domain)
402
+ + bcolors.RED
403
+ )
404
+ generated_subseq_ids_map = self.generate_ids_subseq_map(inputs)
405
+ sys_response = self.tokenizer.decode(generated_subseq_ids_map["resp"][1:-1])
406
+
407
+ decoded_belief_state_subseq = self.tokenizer.decode(belief_span_ids_subseq)
408
+ domain_hits = self.get_domain_hits(decoded_belief_state_subseq)
409
+ # print(bcolors.UNDERLINE + "Domain hits: \n" + bcolors.ENDC, domain_hits) # for debugging
410
+
411
+ sys_response = self.lexicalize_sys_response(sys_response, domain_hits, decoded_belief_state_subseq)
412
+
413
+ return sys_response
414
+
415
+
416
+ def interact(checkpoint_path):
417
+ sys_model = UbarSystemModel("UBAR_sys_model", checkpoint_path, "scripts/UBAR_code/interaction/config.yaml")
418
+ # TODO: Fix this hardcoded variable (should be in config)
419
+ sys_model.print_intermediary_info = True
420
+
421
+ for dial_id in range(1, 11):
422
+ print(f"In dialogue {dial_id}")
423
+
424
+ # Reset state after each dialog
425
+ sys_model.init_session()
426
+
427
+ user_utt = input(bcolors.GREEN + "Enter user response here: " + bcolors.ENDC)
428
+
429
+ for turn_id in range(100):
430
+ try:
431
+ sys_response = sys_model.response(user_utt, turn_id)
432
+ # There are a lot of edge case bugs that are possible that could break the current turn. If so, continue
433
+ # to ensure a large run across the dataset isn't ruined by a single bad turn.
434
+ except Exception() as e:
435
+ print(bcolors.RED + "Exception: {}".format(e) + bcolors.ENDC)
436
+ continue
437
+
438
+ if sys_model.is_terminated():
439
+ print(bcolors.RED + sys_response + bcolors.ENDC)
440
+ print(bcolors.RED + "---" * 30 + bcolors.ENDC)
441
+ break
442
+
443
+ print(bcolors.YELLOW + "System: " + bcolors.ENDC + sys_response)
444
+ print("---" * 30)
445
+
446
+ # next turn materials
447
+ user_utt = input(bcolors.GREEN + "Enter user response here: " + bcolors.ENDC)
448
+
449
+
450
+ if __name__ == "__main__":
451
+ if len(sys.argv) == 1:
452
+ print("Wrong argument!")
453
+ print("Usage: python UBAR_interact.py checkpoint_path")
454
+ sys.exit(1)
455
+
456
+ checkpoint_path = sys.argv[1]
457
+ interact(checkpoint_path)
scripts/UBAR_code/interaction/__init__.py ADDED
File without changes
scripts/UBAR_code/interaction/config.yaml ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ path: "./models/UBAR/experiments/distilgpt-2_sd11_lr0.0001_bs16_ga2/epoch50_trloss0.59_gpt2"
3
+ goal_update:
4
+ finish_inform: "loose" # loose or strict
5
+
6
+ schema_path: "scripts/user_model_code/interaction/schema.json"
7
+
8
+ decode:
9
+ dec_max_len: 1024
10
+ num_beams: 1
11
+ temperature: 1.0
12
+ do_sample: False
13
+
14
+ use_all_previous_context: False
15
+
16
+ dbs_path:
17
+ "attraction": "data/preprocessed/UBAR/db_processed/attraction_db_processed.json"
18
+ "hospital": "data/preprocessed/UBAR/db_processed/hospital_db_processed.json"
19
+ "hotel": "data/preprocessed/UBAR/db_processed/hotel_db_processed.json"
20
+ "police": "data/preprocessed/UBAR/db_processed/police_db_processed.json"
21
+ "restaurant": "data/preprocessed/UBAR/db_processed/restaurant_db_processed.json"
22
+ "taxi": "data/preprocessed/UBAR/db_processed/taxi_db_processed.json"
23
+ "train": "data/preprocessed/UBAR/db_processed/train_db_processed.json"
scripts/UBAR_code/preprocess.py ADDED
@@ -0,0 +1,576 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import os
4
+ import re
5
+ import zipfile
6
+ from collections import OrderedDict
7
+
8
+ import spacy
9
+ from tqdm import tqdm
10
+
11
+ from crazyneuraluser.UBAR_code import ontology, utils
12
+ from crazyneuraluser.UBAR_code.clean_dataset import clean_slot_values, clean_text
13
+ from crazyneuraluser.UBAR_code.config import global_config as cfg
14
+ from crazyneuraluser.UBAR_code.db_ops import MultiWozDB
15
+
16
+
17
+ # value_set.json, all the domain[slot] values in datasets
18
+ def get_db_values(value_set_path):
19
+ processed = {}
20
+ bspn_word = []
21
+ nlp = spacy.load("en_core_web_sm")
22
+
23
+ with open(value_set_path, "r") as f: # read value set file in lower
24
+ value_set = json.loads(f.read().lower())
25
+
26
+ with open("data/raw/UBAR/db/ontology.json", "r") as f: # read ontology in lower, all the domain-slot values
27
+ otlg = json.loads(f.read().lower())
28
+
29
+ for (
30
+ domain,
31
+ slots,
32
+ ) in value_set.items(): # add all informable slots to bspn_word, create lists holder for values
33
+ processed[domain] = {}
34
+ bspn_word.append("[" + domain + "]")
35
+ for slot, values in slots.items():
36
+ s_p = ontology.normlize_slot_names.get(slot, slot)
37
+ if s_p in ontology.informable_slots[domain]:
38
+ bspn_word.append(s_p)
39
+ processed[domain][s_p] = []
40
+
41
+ for (
42
+ domain,
43
+ slots,
44
+ ) in value_set.items(): # add all words of values of informable slots to bspn_word
45
+ for slot, values in slots.items():
46
+ s_p = ontology.normlize_slot_names.get(slot, slot)
47
+ if s_p in ontology.informable_slots[domain]:
48
+ for v in values:
49
+ _, v_p = clean_slot_values(domain, slot, v)
50
+ v_p = " ".join([token.text for token in nlp(v_p)]).strip()
51
+ processed[domain][s_p].append(v_p)
52
+ for x in v_p.split():
53
+ if x not in bspn_word:
54
+ bspn_word.append(x)
55
+
56
+ for domain_slot, values in otlg.items(): # split domain-slots to domains and slots
57
+ domain, slot = domain_slot.split("-")
58
+ if domain == "bus":
59
+ domain = "taxi"
60
+ if slot == "price range":
61
+ slot = "pricerange"
62
+ if slot == "book stay":
63
+ slot = "stay"
64
+ if slot == "book day":
65
+ slot = "day"
66
+ if slot == "book people":
67
+ slot = "people"
68
+ if slot == "book time":
69
+ slot = "time"
70
+ if slot == "arrive by":
71
+ slot = "arrive"
72
+ if slot == "leave at":
73
+ slot = "leave"
74
+ if slot == "leaveat":
75
+ slot = "leave"
76
+ # add all slots and words of values if not already in processed and bspn_word
77
+ if slot not in processed[domain]:
78
+ processed[domain][slot] = []
79
+ bspn_word.append(slot)
80
+ for v in values:
81
+ _, v_p = clean_slot_values(domain, slot, v)
82
+ v_p = " ".join([token.text for token in nlp(v_p)]).strip()
83
+ if v_p not in processed[domain][slot]:
84
+ processed[domain][slot].append(v_p)
85
+ for x in v_p.split():
86
+ if x not in bspn_word:
87
+ bspn_word.append(x)
88
+
89
+ with open(value_set_path.replace(".json", "_processed.json"), "w") as f:
90
+ json.dump(processed, f, indent=2) # save processed.json
91
+ with open("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/bspn_word_collection.json", "w") as f:
92
+ json.dump(bspn_word, f, indent=2) # save bspn_word
93
+
94
+ print("DB value set processed! ")
95
+
96
+
97
+ def preprocess_db(db_paths): # apply clean_slot_values to all dbs
98
+ dbs = {}
99
+ nlp = spacy.load("en_core_web_sm")
100
+ for domain in ontology.all_domains:
101
+ with open(db_paths[domain], "r") as f: # for every db_domain, read json file
102
+ dbs[domain] = json.loads(f.read().lower())
103
+ # entry has information about slots of said domain
104
+ for idx, entry in enumerate(dbs[domain]):
105
+ new_entry = copy.deepcopy(entry)
106
+ for key, value in entry.items(): # key = slot
107
+ if type(value) is not str:
108
+ continue
109
+ del new_entry[key]
110
+ key, value = clean_slot_values(domain, key, value)
111
+ tokenize_and_back = " ".join([token.text for token in nlp(value)]).strip()
112
+ new_entry[key] = tokenize_and_back
113
+ dbs[domain][idx] = new_entry
114
+ with open(db_paths[domain].replace(".json", "_processed.json"), "w") as f:
115
+ json.dump(dbs[domain], f, indent=2)
116
+ print("[%s] DB processed! " % domain)
117
+
118
+
119
+ class DataPreprocessor(object):
120
+ def __init__(self):
121
+ self.nlp = spacy.load("en_core_web_sm")
122
+ self.db = MultiWozDB(cfg.dbs) # load all processed dbs
123
+ data_path = "data/preprocessed/UBAR/gen_usr_utt_experiment_data_with_span_full.json"
124
+ # archive = zipfile.ZipFile(data_path + ".zip", "r")
125
+ # self.convlab_data = json.loads(archive.open(data_path.split("/")[-1], "r").read().lower())
126
+ self.convlab_data = json.loads(open(data_path, "r").read().lower())
127
+ self.delex_sg_valdict_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/delex_single_valdict.json"
128
+ self.delex_mt_valdict_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/delex_multi_valdict.json"
129
+ self.ambiguous_val_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/ambiguous_values.json"
130
+ self.delex_refs_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/reference_no.json"
131
+ self.delex_refs = json.loads(open(self.delex_refs_path, "r").read())
132
+ if not os.path.exists(self.delex_sg_valdict_path):
133
+ (
134
+ self.delex_sg_valdict,
135
+ self.delex_mt_valdict,
136
+ self.ambiguous_vals,
137
+ ) = self.get_delex_valdict()
138
+ else:
139
+ self.delex_sg_valdict = json.loads(open(self.delex_sg_valdict_path, "r").read())
140
+ self.delex_mt_valdict = json.loads(open(self.delex_mt_valdict_path, "r").read())
141
+ self.ambiguous_vals = json.loads(open(self.ambiguous_val_path, "r").read())
142
+
143
+ self.vocab = utils.Vocab(cfg.vocab_size)
144
+
145
+ def delex_by_annotation(self, dial_turn):
146
+ u = dial_turn["text"].split()
147
+ span = dial_turn["span_info"]
148
+ for s in span:
149
+ slot = s[1]
150
+ if slot == "open":
151
+ continue
152
+ if ontology.da_abbr_to_slot_name.get(slot):
153
+ slot = ontology.da_abbr_to_slot_name[slot]
154
+ for idx in range(s[3], s[4] + 1):
155
+ u[idx] = ""
156
+ try:
157
+ u[s[3]] = "[value_" + slot + "]"
158
+ except Exception:
159
+ u[5] = "[value_" + slot + "]"
160
+ u_delex = " ".join([t for t in u if t != ""])
161
+ u_delex = u_delex.replace("[value_address] , [value_address] , [value_address]", "[value_address]")
162
+ u_delex = u_delex.replace("[value_address] , [value_address]", "[value_address]")
163
+ u_delex = u_delex.replace("[value_name] [value_name]", "[value_name]")
164
+ u_delex = u_delex.replace("[value_name]([value_phone] )", "[value_name] ( [value_phone] )")
165
+ return u_delex
166
+
167
+ def delex_by_valdict(self, text):
168
+ text = clean_text(text)
169
+
170
+ text = re.sub(r"\d{5}\s?\d{5,7}", "[value_phone]", text)
171
+ text = re.sub(r"\d[\s-]stars?", "[value_stars]", text)
172
+ text = re.sub(r"\$\d+|\$?\d+.?(\d+)?\s(pounds?|gbps?)", "[value_price]", text)
173
+ text = re.sub(r"tr[\d]{4}", "[value_id]", text)
174
+ text = re.sub(
175
+ r"([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})",
176
+ "[value_postcode]",
177
+ text,
178
+ )
179
+
180
+ for value, slot in self.delex_mt_valdict.items():
181
+ text = text.replace(value, "[value_%s]" % slot)
182
+
183
+ for value, slot in self.delex_sg_valdict.items():
184
+ tokens = text.split()
185
+ for idx, tk in enumerate(tokens):
186
+ if tk == value:
187
+ tokens[idx] = "[value_%s]" % slot
188
+ text = " ".join(tokens)
189
+
190
+ for ambg_ent in self.ambiguous_vals:
191
+ # ely is a place, but appears in words like moderately
192
+ start_idx = text.find(" " + ambg_ent)
193
+ if start_idx == -1:
194
+ continue
195
+ front_words = text[:start_idx].split()
196
+ ent_type = "time" if ":" in ambg_ent else "place"
197
+
198
+ for fw in front_words[::-1]:
199
+ if fw in [
200
+ "arrive",
201
+ "arrives",
202
+ "arrived",
203
+ "arriving",
204
+ "arrival",
205
+ "destination",
206
+ "there",
207
+ "reach",
208
+ "to",
209
+ "by",
210
+ "before",
211
+ ]:
212
+ slot = "[value_arrive]" if ent_type == "time" else "[value_destination]"
213
+ text = re.sub(" " + ambg_ent, " " + slot, text)
214
+ elif fw in [
215
+ "leave",
216
+ "leaves",
217
+ "leaving",
218
+ "depart",
219
+ "departs",
220
+ "departing",
221
+ "departure",
222
+ "from",
223
+ "after",
224
+ "pulls",
225
+ ]:
226
+ slot = "[value_leave]" if ent_type == "time" else "[value_departure]"
227
+ text = re.sub(" " + ambg_ent, " " + slot, text)
228
+
229
+ text = text.replace("[value_car] [value_car]", "[value_car]")
230
+ return text
231
+
232
+ def get_delex_valdict(
233
+ self,
234
+ ):
235
+ skip_entry_type = {
236
+ "taxi": ["taxi_phone"],
237
+ "police": ["id"],
238
+ "hospital": ["id"],
239
+ "hotel": [
240
+ "id",
241
+ "location",
242
+ "internet",
243
+ "parking",
244
+ "takesbookings",
245
+ "stars",
246
+ "price",
247
+ "n",
248
+ "postcode",
249
+ "phone",
250
+ ],
251
+ "attraction": [
252
+ "id",
253
+ "location",
254
+ "pricerange",
255
+ "price",
256
+ "openhours",
257
+ "postcode",
258
+ "phone",
259
+ ],
260
+ "train": ["price", "id"],
261
+ "restaurant": [
262
+ "id",
263
+ "location",
264
+ "introduction",
265
+ "signature",
266
+ "type",
267
+ "postcode",
268
+ "phone",
269
+ ],
270
+ }
271
+ entity_value_to_slot = {}
272
+ ambiguous_entities = []
273
+ for domain, db_data in self.db.dbs.items():
274
+ print("Processing entity values in [%s]" % domain)
275
+ if domain != "taxi":
276
+ for db_entry in db_data:
277
+ for slot, value in db_entry.items():
278
+ if slot not in skip_entry_type[domain]:
279
+ if type(value) is not str:
280
+ raise TypeError("value '%s' in domain '%s' should be rechecked" % (slot, domain))
281
+ else:
282
+ slot, value = clean_slot_values(domain, slot, value)
283
+ value = " ".join([token.text for token in self.nlp(value)]).strip()
284
+ if value in entity_value_to_slot and entity_value_to_slot[value] != slot:
285
+ # print(value, ": ",entity_value_to_slot[value], slot)
286
+ ambiguous_entities.append(value)
287
+ entity_value_to_slot[value] = slot
288
+ else: # taxi db specific
289
+ db_entry = db_data[0]
290
+ for slot, ent_list in db_entry.items():
291
+ if slot not in skip_entry_type[domain]:
292
+ for ent in ent_list:
293
+ entity_value_to_slot[ent] = "car"
294
+ ambiguous_entities = set(ambiguous_entities)
295
+ ambiguous_entities.remove("cambridge")
296
+ ambiguous_entities = list(ambiguous_entities)
297
+ for amb_ent in ambiguous_entities: # departure or destination? arrive time or leave time?
298
+ entity_value_to_slot.pop(amb_ent)
299
+ entity_value_to_slot["parkside"] = "address"
300
+ entity_value_to_slot["parkside, cambridge"] = "address"
301
+ entity_value_to_slot["cambridge belfry"] = "name"
302
+ entity_value_to_slot["hills road"] = "address"
303
+ entity_value_to_slot["hills rd"] = "address"
304
+ entity_value_to_slot["Parkside Police Station"] = "name"
305
+
306
+ single_token_values = {}
307
+ multi_token_values = {}
308
+ for val, slt in entity_value_to_slot.items():
309
+ if val in ["cambridge"]:
310
+ continue
311
+ if len(val.split()) > 1:
312
+ multi_token_values[val] = slt
313
+ else:
314
+ single_token_values[val] = slt
315
+
316
+ with open(self.delex_sg_valdict_path, "w") as f:
317
+ single_token_values = OrderedDict(
318
+ sorted(single_token_values.items(), key=lambda kv: len(kv[0]), reverse=True)
319
+ )
320
+ json.dump(single_token_values, f, indent=2)
321
+ print("single delex value dict saved!")
322
+ with open(self.delex_mt_valdict_path, "w") as f:
323
+ multi_token_values = OrderedDict(
324
+ sorted(multi_token_values.items(), key=lambda kv: len(kv[0]), reverse=True)
325
+ )
326
+ json.dump(multi_token_values, f, indent=2)
327
+ print("multi delex value dict saved!")
328
+ with open(self.ambiguous_val_path, "w") as f:
329
+ json.dump(ambiguous_entities, f, indent=2)
330
+ print("ambiguous value dict saved!")
331
+
332
+ return single_token_values, multi_token_values, ambiguous_entities
333
+
334
+ def preprocess_main(self, save_path=None, is_test=False):
335
+ """ """
336
+ data = {}
337
+ count = 0
338
+ self.unique_da = {}
339
+ ordered_sysact_dict = {}
340
+ for fn, raw_dial in tqdm(list(self.convlab_data.items())):
341
+ count += 1
342
+ # if count == 100:
343
+ # break
344
+
345
+ compressed_goal = {} # for every dialog, keep track the goal, domains, requests
346
+ dial_domains, dial_reqs = [], []
347
+ for dom, g in raw_dial["goal"].items():
348
+ if dom != "topic" and dom != "message" and g:
349
+ if g.get("reqt"): # request info. eg. postcode/address/phone
350
+ # normalize request slots
351
+ for i, req_slot in enumerate(g["reqt"]):
352
+ if ontology.normlize_slot_names.get(req_slot):
353
+ g["reqt"][i] = ontology.normlize_slot_names[req_slot]
354
+ dial_reqs.append(g["reqt"][i])
355
+ compressed_goal[dom] = g
356
+ if dom in ontology.all_domains:
357
+ dial_domains.append(dom)
358
+
359
+ dial_reqs = list(set(dial_reqs))
360
+
361
+ dial = {"goal": compressed_goal, "log": []}
362
+ single_turn = {}
363
+ constraint_dict = OrderedDict()
364
+ prev_constraint_dict = {}
365
+ prev_turn_domain = ["general"]
366
+ ordered_sysact_dict[fn] = {}
367
+
368
+ for turn_num, dial_turn in enumerate(raw_dial["log"]):
369
+ # for user turn, have text
370
+ # sys turn: text, belief states(metadata), dialog_act, span_info
371
+ dial_state = dial_turn["metadata"]
372
+ if not dial_state: # user
373
+ # delexicalize user utterance, either by annotation or by val_dict
374
+ u = " ".join(clean_text(dial_turn["text"]).split())
375
+
376
+ # NOTE: Commenting out delexicalisation because it is not used and
377
+ # breaks when I use generated user dialogues for some reason
378
+
379
+ # if dial_turn["span_info"]:
380
+ # u_delex = clean_text(self.delex_by_annotation(dial_turn))
381
+ # else:
382
+ # u_delex = self.delex_by_valdict(dial_turn["text"])
383
+
384
+ single_turn["user"] = u
385
+ # single_turn["user_delex"] = u_delex
386
+
387
+ else: # system
388
+ # delexicalize system response, either by annotation or by val_dict
389
+ if dial_turn["span_info"]:
390
+ s_delex = clean_text(self.delex_by_annotation(dial_turn))
391
+ else:
392
+ if not dial_turn["text"]:
393
+ print(fn)
394
+ s_delex = self.delex_by_valdict(dial_turn["text"])
395
+ single_turn["resp"] = s_delex
396
+
397
+ # get belief state, semi=informable/book=requestable, put into constraint_dict
398
+ for domain in dial_domains:
399
+ if not constraint_dict.get(domain):
400
+ constraint_dict[domain] = OrderedDict()
401
+ info_sv = dial_state[domain]["semi"]
402
+ for s, v in info_sv.items():
403
+ s, v = clean_slot_values(domain, s, v)
404
+ if len(v.split()) > 1:
405
+ v = " ".join([token.text for token in self.nlp(v)]).strip()
406
+ if v != "":
407
+ constraint_dict[domain][s] = v
408
+ book_sv = dial_state[domain]["book"]
409
+ for s, v in book_sv.items():
410
+ if s == "booked":
411
+ continue
412
+ s, v = clean_slot_values(domain, s, v)
413
+ if len(v.split()) > 1:
414
+ v = " ".join([token.text for token in self.nlp(v)]).strip()
415
+ if v != "":
416
+ constraint_dict[domain][s] = v
417
+
418
+ constraints = [] # list in format of [domain] slot value
419
+ cons_delex = []
420
+ turn_dom_bs = []
421
+ for domain, info_slots in constraint_dict.items():
422
+ if info_slots:
423
+ constraints.append("[" + domain + "]")
424
+ cons_delex.append("[" + domain + "]")
425
+ for slot, value in info_slots.items():
426
+ constraints.append(slot)
427
+ constraints.extend(value.split())
428
+ cons_delex.append(slot)
429
+ if domain not in prev_constraint_dict:
430
+ turn_dom_bs.append(domain)
431
+ elif prev_constraint_dict[domain] != constraint_dict[domain]:
432
+ turn_dom_bs.append(domain)
433
+
434
+ sys_act_dict = {}
435
+ turn_dom_da = set()
436
+ for act in dial_turn["dialog_act"]:
437
+ d, a = act.split("-") # split domain-act
438
+ turn_dom_da.add(d)
439
+ turn_dom_da = list(turn_dom_da)
440
+ if len(turn_dom_da) != 1 and "general" in turn_dom_da:
441
+ turn_dom_da.remove("general")
442
+ if len(turn_dom_da) != 1 and "booking" in turn_dom_da:
443
+ turn_dom_da.remove("booking")
444
+
445
+ # get turn domain
446
+ turn_domain = turn_dom_bs
447
+ for dom in turn_dom_da:
448
+ if dom != "booking" and dom not in turn_domain:
449
+ turn_domain.append(dom)
450
+ if not turn_domain:
451
+ turn_domain = prev_turn_domain
452
+ if len(turn_domain) == 2 and "general" in turn_domain:
453
+ turn_domain.remove("general")
454
+ if len(turn_domain) == 2:
455
+ if len(prev_turn_domain) == 1 and prev_turn_domain[0] == turn_domain[1]:
456
+ turn_domain = turn_domain[::-1]
457
+
458
+ # get system action
459
+ for dom in turn_domain:
460
+ sys_act_dict[dom] = {}
461
+ add_to_last_collect = []
462
+ booking_act_map = {"inform": "offerbook", "book": "offerbooked"}
463
+ for act, params in dial_turn["dialog_act"].items():
464
+ if act == "general-greet":
465
+ continue
466
+ d, a = act.split("-")
467
+ if d == "general" and d not in sys_act_dict:
468
+ sys_act_dict[d] = {}
469
+ if d == "booking":
470
+ d = turn_domain[0]
471
+ a = booking_act_map.get(a, a)
472
+ add_p = []
473
+ for param in params:
474
+ p = param[0]
475
+ if p == "none":
476
+ continue
477
+ elif ontology.da_abbr_to_slot_name.get(p):
478
+ p = ontology.da_abbr_to_slot_name[p]
479
+ if p not in add_p:
480
+ add_p.append(p)
481
+ add_to_last = True if a in ["request", "reqmore", "bye", "offerbook"] else False
482
+ if add_to_last:
483
+ add_to_last_collect.append((d, a, add_p))
484
+ else:
485
+ sys_act_dict[d][a] = add_p
486
+ for d, a, add_p in add_to_last_collect:
487
+ sys_act_dict[d][a] = add_p
488
+
489
+ for d in copy.copy(sys_act_dict):
490
+ acts = sys_act_dict[d]
491
+ if not acts:
492
+ del sys_act_dict[d]
493
+ if "inform" in acts and "offerbooked" in acts:
494
+ for s in sys_act_dict[d]["inform"]:
495
+ sys_act_dict[d]["offerbooked"].append(s)
496
+ del sys_act_dict[d]["inform"]
497
+
498
+ ordered_sysact_dict[fn][len(dial["log"])] = sys_act_dict
499
+
500
+ sys_act = []
501
+ if "general-greet" in dial_turn["dialog_act"]:
502
+ sys_act.extend(["[general]", "[greet]"])
503
+ for d, acts in sys_act_dict.items():
504
+ sys_act += ["[" + d + "]"]
505
+ for a, slots in acts.items():
506
+ self.unique_da[d + "-" + a] = 1
507
+ sys_act += ["[" + a + "]"]
508
+ sys_act += slots
509
+
510
+ # get db pointers
511
+ matnums = self.db.get_match_num(constraint_dict)
512
+ match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1]
513
+ match = matnums[match_dom]
514
+ dbvec = self.db.addDBPointer(match_dom, match)
515
+ bkvec = self.db.addBookingPointer(dial_turn["dialog_act"])
516
+
517
+ # 4 database pointer for domains, 2 for booking
518
+ single_turn["pointer"] = ",".join([str(d) for d in dbvec + bkvec])
519
+ single_turn["match"] = str(match)
520
+ single_turn["constraint"] = " ".join(constraints)
521
+ single_turn["cons_delex"] = " ".join(cons_delex)
522
+ single_turn["sys_act"] = " ".join(sys_act)
523
+ single_turn["turn_num"] = len(dial["log"])
524
+ single_turn["turn_domain"] = " ".join(["[" + d + "]" for d in turn_domain])
525
+
526
+ prev_turn_domain = copy.deepcopy(turn_domain)
527
+ prev_constraint_dict = copy.deepcopy(constraint_dict)
528
+
529
+ if "user" in single_turn:
530
+ dial["log"].append(single_turn)
531
+ for t in single_turn["user"].split() + single_turn["resp"].split() + constraints + sys_act:
532
+ self.vocab.add_word(t)
533
+
534
+ # NOTE: Commenting out delexicalisation because it is not used and
535
+ # breaks when I use generated user dialogues for some reason
536
+
537
+ # for t in single_turn["user_delex"].split():
538
+ # if "[" in t and "]" in t and not t.startswith("[") and not t.endswith("]"):
539
+ # single_turn["user_delex"].replace(t, t[t.index("[") : t.index("]") + 1])
540
+ # elif not self.vocab.has_word(t):
541
+ # self.vocab.add_word(t)
542
+
543
+ single_turn = {}
544
+
545
+ data[fn] = dial
546
+ # pprint(dial)
547
+ # if count == 20:
548
+ # break
549
+ self.vocab.construct()
550
+ self.vocab.save_vocab("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/vocab")
551
+ with open("data/interim/gen_usr_utts/multi-woz-analysis/dialog_acts.json", "w") as f:
552
+ json.dump(ordered_sysact_dict, f, indent=2)
553
+ with open("data/interim/gen_usr_utts/multi-woz-analysis/dialog_act_type.json", "w") as f:
554
+ json.dump(self.unique_da, f, indent=2)
555
+ return data
556
+
557
+
558
+ if __name__ == "__main__":
559
+ db_paths = {
560
+ "attraction": "data/raw/UBAR/db/attraction_db.json",
561
+ "hospital": "data/raw/UBAR/db/hospital_db.json",
562
+ "hotel": "data/raw/UBAR/db/hotel_db.json",
563
+ "police": "data/raw/UBAR/db/police_db.json",
564
+ "restaurant": "data/raw/UBAR/db/restaurant_db.json",
565
+ "taxi": "data/raw/UBAR/db/taxi_db.json",
566
+ "train": "data/raw/UBAR/db/train_db.json",
567
+ }
568
+ get_db_values("data/raw/UBAR/db/value_set.json")
569
+ preprocess_db(db_paths)
570
+ dh = DataPreprocessor()
571
+ data = dh.preprocess_main()
572
+ if not os.path.exists("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed"):
573
+ os.mkdir("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed")
574
+
575
+ with open("data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/data_for_ubar.json", "w") as f:
576
+ json.dump(data, f, indent=2)
scripts/UBAR_code/preprocess2.1.py ADDED
@@ -0,0 +1,585 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import json
3
+ import os
4
+ import re
5
+ import zipfile
6
+ from collections import OrderedDict
7
+
8
+ import spacy
9
+ from tqdm import tqdm
10
+
11
+ from crazyneuraluser.UBAR_code import ontology, utils
12
+ from crazyneuraluser.UBAR_code.clean_dataset import clean_slot_values, clean_text
13
+ from crazyneuraluser.UBAR_code.config import global_config as cfg
14
+ from crazyneuraluser.UBAR_code.db_ops import MultiWozDB
15
+
16
+
17
+ def get_db_values(
18
+ value_set_path,
19
+ ): # value_set.json, all the domain[slot] values in datasets
20
+ processed = {}
21
+ bspn_word = []
22
+ nlp = spacy.load("en_core_web_sm")
23
+
24
+ with open(value_set_path, "r") as f: # read value set file in lower
25
+ value_set = json.loads(f.read().lower())
26
+
27
+ with open("db/ontology.json", "r") as f: # read ontology in lower, all the domain-slot values
28
+ otlg = json.loads(f.read().lower())
29
+
30
+ for (
31
+ domain,
32
+ slots,
33
+ ) in value_set.items(): # add all informable slots to bspn_word, create lists holder for values
34
+ processed[domain] = {}
35
+ bspn_word.append("[" + domain + "]")
36
+ for slot, values in slots.items():
37
+ s_p = ontology.normlize_slot_names.get(slot, slot)
38
+ if s_p in ontology.informable_slots[domain]:
39
+ bspn_word.append(s_p)
40
+ processed[domain][s_p] = []
41
+
42
+ for (
43
+ domain,
44
+ slots,
45
+ ) in value_set.items(): # add all words of values of informable slots to bspn_word
46
+ for slot, values in slots.items():
47
+ s_p = ontology.normlize_slot_names.get(slot, slot)
48
+ if s_p in ontology.informable_slots[domain]:
49
+ for v in values:
50
+ _, v_p = clean_slot_values(domain, slot, v)
51
+ v_p = " ".join([token.text for token in nlp(v_p)]).strip()
52
+ processed[domain][s_p].append(v_p)
53
+ for x in v_p.split():
54
+ if x not in bspn_word:
55
+ bspn_word.append(x)
56
+
57
+ for domain_slot, values in otlg.items(): # split domain-slots to domains and slots
58
+ domain, slot = domain_slot.split("-")
59
+ if domain == "bus":
60
+ domain = "taxi"
61
+ if slot == "price range":
62
+ slot = "pricerange"
63
+ if slot == "book stay":
64
+ slot = "stay"
65
+ if slot == "book day":
66
+ slot = "day"
67
+ if slot == "book people":
68
+ slot = "people"
69
+ if slot == "book time":
70
+ slot = "time"
71
+ if slot == "arrive by":
72
+ slot = "arrive"
73
+ if slot == "leave at":
74
+ slot = "leave"
75
+ if slot == "leaveat":
76
+ slot = "leave"
77
+ if slot not in processed[domain]: # add all slots and words of values if not already in processed and bspn_word
78
+ processed[domain][slot] = []
79
+ bspn_word.append(slot)
80
+ for v in values:
81
+ _, v_p = clean_slot_values(domain, slot, v)
82
+ v_p = " ".join([token.text for token in nlp(v_p)]).strip()
83
+ if v_p not in processed[domain][slot]:
84
+ processed[domain][slot].append(v_p)
85
+ for x in v_p.split():
86
+ if x not in bspn_word:
87
+ bspn_word.append(x)
88
+
89
+ with open(value_set_path.replace(".json", "_processed.json"), "w") as f:
90
+ json.dump(processed, f, indent=2) # save processed.json
91
+ with open("data/preprocessed/UBAR/multi-woz-processed/bspn_word_collection.json", "w") as f:
92
+ json.dump(bspn_word, f, indent=2) # save bspn_word
93
+
94
+ print("DB value set processed! ")
95
+
96
+
97
+ def preprocess_db(db_paths): # apply clean_slot_values to all dbs
98
+ dbs = {}
99
+ nlp = spacy.load("en_core_web_sm")
100
+ for domain in ontology.all_domains:
101
+ with open(db_paths[domain], "r") as f: # for every db_domain, read json file
102
+ dbs[domain] = json.loads(f.read().lower())
103
+ for idx, entry in enumerate(dbs[domain]): # entry has information about slots of said domain
104
+ new_entry = copy.deepcopy(entry)
105
+ for key, value in entry.items(): # key = slot
106
+ if type(value) is not str:
107
+ continue
108
+ del new_entry[key]
109
+ key, value = clean_slot_values(domain, key, value)
110
+ tokenize_and_back = " ".join([token.text for token in nlp(value)]).strip()
111
+ new_entry[key] = tokenize_and_back
112
+ dbs[domain][idx] = new_entry
113
+ with open(db_paths[domain].replace(".json", "_processed.json"), "w") as f:
114
+ json.dump(dbs[domain], f, indent=2)
115
+ print("[%s] DB processed! " % domain)
116
+
117
+
118
+ # 2.1
119
+ class DataPreprocessor(object):
120
+ def __init__(self):
121
+ self.nlp = spacy.load("en_core_web_sm")
122
+ self.db = MultiWozDB(cfg.dbs) # load all processed dbs
123
+ # data_path = 'data/multi-woz/annotated_user_da_with_span_full.json'
124
+ data_path = "data/raw/UBAR/MultiWOZ_2.1/data.json"
125
+ archive = zipfile.ZipFile(data_path + ".zip", "r")
126
+ self.convlab_data = json.loads(archive.open(data_path.split("/")[-1], "r").read().lower())
127
+ # self.delex_sg_valdict_path = 'data/multi-woz-processed/delex_single_valdict.json'
128
+ # self.delex_mt_valdict_path = 'data/multi-woz-processed/delex_multi_valdict.json'
129
+ # self.ambiguous_val_path = 'data/multi-woz-processed/ambiguous_values.json'
130
+ # self.delex_refs_path = 'data/multi-woz-processed/reference_no.json'
131
+ self.delex_sg_valdict_path = "data/preprocessed/UBAR/multi-woz-2.1-processed/delex_single_valdict.json"
132
+ self.delex_mt_valdict_path = "data/preprocessed/UBAR/multi-woz-2.1-processed/delex_multi_valdict.json"
133
+ self.ambiguous_val_path = "data/preprocessed/UBAR/multi-woz-2.1-processed/ambiguous_values.json"
134
+ self.delex_refs_path = "data/preprocessed/UBAR/multi-woz-2.1-processed/reference_no.json"
135
+ self.delex_refs = json.loads(open(self.delex_refs_path, "r").read())
136
+ if not os.path.exists(self.delex_sg_valdict_path):
137
+ (
138
+ self.delex_sg_valdict,
139
+ self.delex_mt_valdict,
140
+ self.ambiguous_vals,
141
+ ) = self.get_delex_valdict()
142
+ else:
143
+ self.delex_sg_valdict = json.loads(open(self.delex_sg_valdict_path, "r").read())
144
+ self.delex_mt_valdict = json.loads(open(self.delex_mt_valdict_path, "r").read())
145
+ self.ambiguous_vals = json.loads(open(self.ambiguous_val_path, "r").read())
146
+
147
+ self.vocab = utils.Vocab(cfg.vocab_size)
148
+
149
+ def delex_by_annotation(self, dial_turn):
150
+ # add by yyy in 13:48 0803
151
+ u = dial_turn["text"].split()
152
+ # u = my_clean_text(dial_turn['text']).split()
153
+ ##
154
+ span = dial_turn["span_info"]
155
+ for s in span:
156
+ slot = s[1]
157
+ if slot == "open":
158
+ continue
159
+ if ontology.da_abbr_to_slot_name.get(slot):
160
+ slot = ontology.da_abbr_to_slot_name[slot]
161
+ for idx in range(s[3], s[4] + 1):
162
+ u[idx] = ""
163
+ try:
164
+ u[s[3]] = "[value_" + slot + "]"
165
+ except Exception:
166
+ u[5] = "[value_" + slot + "]"
167
+ u_delex = " ".join([t for t in u if t != ""])
168
+ u_delex = u_delex.replace("[value_address] , [value_address] , [value_address]", "[value_address]")
169
+ u_delex = u_delex.replace("[value_address] , [value_address]", "[value_address]")
170
+ u_delex = u_delex.replace("[value_name] [value_name]", "[value_name]")
171
+ u_delex = u_delex.replace("[value_name]([value_phone] )", "[value_name] ( [value_phone] )")
172
+ return u_delex
173
+
174
+ def delex_by_valdict(self, text):
175
+ text = clean_text(text)
176
+
177
+ text = re.sub(r"\d{5}\s?\d{5,7}", "[value_phone]", text)
178
+ text = re.sub(r"\d[\s-]stars?", "[value_stars]", text)
179
+ text = re.sub(r"\$\d+|\$?\d+.?(\d+)?\s(pounds?|gbps?)", "[value_price]", text)
180
+ text = re.sub(r"tr[\d]{4}", "[value_id]", text)
181
+ text = re.sub(
182
+ r"([a-z]{1}[\. ]?[a-z]{1}[\. ]?\d{1,2}[, ]+\d{1}[\. ]?[a-z]{1}[\. ]?[a-z]{1}|[a-z]{2}\d{2}[a-z]{2})",
183
+ "[value_postcode]",
184
+ text,
185
+ )
186
+
187
+ for value, slot in self.delex_mt_valdict.items():
188
+ text = text.replace(value, "[value_%s]" % slot)
189
+
190
+ for value, slot in self.delex_sg_valdict.items():
191
+ tokens = text.split()
192
+ for idx, tk in enumerate(tokens):
193
+ if tk == value:
194
+ tokens[idx] = "[value_%s]" % slot
195
+ text = " ".join(tokens)
196
+
197
+ for ambg_ent in self.ambiguous_vals:
198
+ start_idx = text.find(" " + ambg_ent) # ely is a place, but appears in words like moderately
199
+ if start_idx == -1:
200
+ continue
201
+ front_words = text[:start_idx].split()
202
+ ent_type = "time" if ":" in ambg_ent else "place"
203
+
204
+ for fw in front_words[::-1]:
205
+ if fw in [
206
+ "arrive",
207
+ "arrives",
208
+ "arrived",
209
+ "arriving",
210
+ "arrival",
211
+ "destination",
212
+ "there",
213
+ "reach",
214
+ "to",
215
+ "by",
216
+ "before",
217
+ ]:
218
+ slot = "[value_arrive]" if ent_type == "time" else "[value_destination]"
219
+ text = re.sub(" " + ambg_ent, " " + slot, text)
220
+ elif fw in [
221
+ "leave",
222
+ "leaves",
223
+ "leaving",
224
+ "depart",
225
+ "departs",
226
+ "departing",
227
+ "departure",
228
+ "from",
229
+ "after",
230
+ "pulls",
231
+ ]:
232
+ slot = "[value_leave]" if ent_type == "time" else "[value_departure]"
233
+ text = re.sub(" " + ambg_ent, " " + slot, text)
234
+
235
+ text = text.replace("[value_car] [value_car]", "[value_car]")
236
+ return text
237
+
238
+ def get_delex_valdict(
239
+ self,
240
+ ):
241
+ skip_entry_type = {
242
+ "taxi": ["taxi_phone"],
243
+ "police": ["id"],
244
+ "hospital": ["id"],
245
+ "hotel": [
246
+ "id",
247
+ "location",
248
+ "internet",
249
+ "parking",
250
+ "takesbookings",
251
+ "stars",
252
+ "price",
253
+ "n",
254
+ "postcode",
255
+ "phone",
256
+ ],
257
+ "attraction": [
258
+ "id",
259
+ "location",
260
+ "pricerange",
261
+ "price",
262
+ "openhours",
263
+ "postcode",
264
+ "phone",
265
+ ],
266
+ "train": ["price", "id"],
267
+ "restaurant": [
268
+ "id",
269
+ "location",
270
+ "introduction",
271
+ "signature",
272
+ "type",
273
+ "postcode",
274
+ "phone",
275
+ ],
276
+ }
277
+ entity_value_to_slot = {}
278
+ ambiguous_entities = []
279
+ for domain, db_data in self.db.dbs.items():
280
+ print("Processing entity values in [%s]" % domain)
281
+ if domain != "taxi":
282
+ for db_entry in db_data:
283
+ for slot, value in db_entry.items():
284
+ if slot not in skip_entry_type[domain]:
285
+ if type(value) is not str:
286
+ raise TypeError("value '%s' in domain '%s' should be rechecked" % (slot, domain))
287
+ else:
288
+ slot, value = clean_slot_values(domain, slot, value)
289
+ value = " ".join([token.text for token in self.nlp(value)]).strip()
290
+ if value in entity_value_to_slot and entity_value_to_slot[value] != slot:
291
+ # print(value, ": ",entity_value_to_slot[value], slot)
292
+ ambiguous_entities.append(value)
293
+ entity_value_to_slot[value] = slot
294
+ else: # taxi db specific
295
+ db_entry = db_data[0]
296
+ for slot, ent_list in db_entry.items():
297
+ if slot not in skip_entry_type[domain]:
298
+ for ent in ent_list:
299
+ entity_value_to_slot[ent] = "car"
300
+ ambiguous_entities = set(ambiguous_entities)
301
+ ambiguous_entities.remove("cambridge")
302
+ ambiguous_entities = list(ambiguous_entities)
303
+ for amb_ent in ambiguous_entities: # departure or destination? arrive time or leave time?
304
+ entity_value_to_slot.pop(amb_ent)
305
+ entity_value_to_slot["parkside"] = "address"
306
+ entity_value_to_slot["parkside, cambridge"] = "address"
307
+ entity_value_to_slot["cambridge belfry"] = "name"
308
+ entity_value_to_slot["hills road"] = "address"
309
+ entity_value_to_slot["hills rd"] = "address"
310
+ entity_value_to_slot["Parkside Police Station"] = "name"
311
+
312
+ single_token_values = {}
313
+ multi_token_values = {}
314
+ for val, slt in entity_value_to_slot.items():
315
+ if val in ["cambridge"]:
316
+ continue
317
+ if len(val.split()) > 1:
318
+ multi_token_values[val] = slt
319
+ else:
320
+ single_token_values[val] = slt
321
+
322
+ with open(self.delex_sg_valdict_path, "w") as f:
323
+ single_token_values = OrderedDict(
324
+ sorted(single_token_values.items(), key=lambda kv: len(kv[0]), reverse=True)
325
+ )
326
+ json.dump(single_token_values, f, indent=2)
327
+ print("single delex value dict saved!")
328
+ with open(self.delex_mt_valdict_path, "w") as f:
329
+ multi_token_values = OrderedDict(
330
+ sorted(multi_token_values.items(), key=lambda kv: len(kv[0]), reverse=True)
331
+ )
332
+ json.dump(multi_token_values, f, indent=2)
333
+ print("multi delex value dict saved!")
334
+ with open(self.ambiguous_val_path, "w") as f:
335
+ json.dump(ambiguous_entities, f, indent=2)
336
+ print("ambiguous value dict saved!")
337
+
338
+ return single_token_values, multi_token_values, ambiguous_entities
339
+
340
+ def preprocess_main(self, save_path=None, is_test=False):
341
+ """ """
342
+ data = {}
343
+ count = 0
344
+ self.unique_da = {}
345
+ ordered_sysact_dict = {}
346
+ # yyy
347
+ for fn, raw_dial in tqdm(list(self.convlab_data.items())):
348
+ if fn in [
349
+ "pmul4707.json",
350
+ "pmul2245.json",
351
+ "pmul4776.json",
352
+ "pmul3872.json",
353
+ "pmul4859.json",
354
+ ]:
355
+ continue
356
+ count += 1
357
+ # if count == 100:
358
+ # break
359
+
360
+ compressed_goal = {} # for every dialog, keep track the goal, domains, requests
361
+ dial_domains, dial_reqs = [], []
362
+ for dom, g in raw_dial["goal"].items():
363
+ if dom != "topic" and dom != "message" and g:
364
+ if g.get("reqt"): # request info. eg. postcode/address/phone
365
+ for i, req_slot in enumerate(g["reqt"]): # normalize request slots
366
+ if ontology.normlize_slot_names.get(req_slot):
367
+ g["reqt"][i] = ontology.normlize_slot_names[req_slot]
368
+ dial_reqs.append(g["reqt"][i])
369
+ compressed_goal[dom] = g
370
+ if dom in ontology.all_domains:
371
+ dial_domains.append(dom)
372
+
373
+ dial_reqs = list(set(dial_reqs))
374
+
375
+ dial = {"goal": compressed_goal, "log": []}
376
+ single_turn = {}
377
+ constraint_dict = OrderedDict()
378
+ prev_constraint_dict = {}
379
+ prev_turn_domain = ["general"]
380
+ ordered_sysact_dict[fn] = {}
381
+
382
+ for turn_num, dial_turn in enumerate(raw_dial["log"]):
383
+ # for user turn, have text
384
+ # sys turn: text, belief states(metadata), dialog_act, span_info
385
+ dial_state = dial_turn["metadata"]
386
+ dial_turn["text"] = " ".join([t.text for t in self.nlp(dial_turn["text"])])
387
+ if not dial_state: # user
388
+ # delexicalize user utterance, either by annotation or by val_dict
389
+ u = " ".join(clean_text(dial_turn["text"]).split())
390
+ if "span_info" in dial_turn and dial_turn["span_info"]:
391
+ u_delex = clean_text(self.delex_by_annotation(dial_turn))
392
+ else:
393
+ u_delex = self.delex_by_valdict(dial_turn["text"])
394
+
395
+ single_turn["user"] = u
396
+ single_turn["user_delex"] = u_delex
397
+
398
+ else: # system
399
+ # delexicalize system response, either by annotation or by val_dict
400
+ if "span_info" in dial_turn and dial_turn["span_info"]:
401
+ s_delex = clean_text(self.delex_by_annotation(dial_turn))
402
+ else:
403
+ if not dial_turn["text"]:
404
+ print(fn)
405
+ s_delex = self.delex_by_valdict(dial_turn["text"])
406
+ single_turn["resp"] = s_delex
407
+ single_turn["nodelx_resp"] = " ".join(clean_text(dial_turn["text"]).split())
408
+
409
+ # get belief state, semi=informable/book=requestable, put into constraint_dict
410
+ for domain in dial_domains:
411
+ if not constraint_dict.get(domain):
412
+ constraint_dict[domain] = OrderedDict()
413
+ info_sv = dial_state[domain]["semi"]
414
+ for s, v in info_sv.items():
415
+ s, v = clean_slot_values(domain, s, v)
416
+ if len(v.split()) > 1:
417
+ v = " ".join([token.text for token in self.nlp(v)]).strip()
418
+ if v != "":
419
+ constraint_dict[domain][s] = v
420
+ book_sv = dial_state[domain]["book"]
421
+ for s, v in book_sv.items():
422
+ if s == "booked":
423
+ continue
424
+ s, v = clean_slot_values(domain, s, v)
425
+ if len(v.split()) > 1:
426
+ v = " ".join([token.text for token in self.nlp(v)]).strip()
427
+ if v != "":
428
+ constraint_dict[domain][s] = v
429
+
430
+ constraints = [] # list in format of [domain] slot value
431
+ cons_delex = []
432
+ turn_dom_bs = []
433
+ for domain, info_slots in constraint_dict.items():
434
+ if info_slots:
435
+ constraints.append("[" + domain + "]")
436
+ cons_delex.append("[" + domain + "]")
437
+ for slot, value in info_slots.items():
438
+ constraints.append(slot)
439
+ constraints.extend(value.split())
440
+ cons_delex.append(slot)
441
+ if domain not in prev_constraint_dict:
442
+ turn_dom_bs.append(domain)
443
+ elif prev_constraint_dict[domain] != constraint_dict[domain]:
444
+ turn_dom_bs.append(domain)
445
+
446
+ sys_act_dict = {}
447
+ turn_dom_da = set()
448
+ for act in dial_turn["dialog_act"]:
449
+ d, a = act.split("-") # split domain-act
450
+ turn_dom_da.add(d)
451
+ turn_dom_da = list(turn_dom_da)
452
+ if len(turn_dom_da) != 1 and "general" in turn_dom_da:
453
+ turn_dom_da.remove("general")
454
+ if len(turn_dom_da) != 1 and "booking" in turn_dom_da:
455
+ turn_dom_da.remove("booking")
456
+
457
+ # get turn domain
458
+ turn_domain = turn_dom_bs
459
+ for dom in turn_dom_da:
460
+ if dom != "booking" and dom not in turn_domain:
461
+ turn_domain.append(dom)
462
+ if not turn_domain:
463
+ turn_domain = prev_turn_domain
464
+ if len(turn_domain) == 2 and "general" in turn_domain:
465
+ turn_domain.remove("general")
466
+ if len(turn_domain) == 2:
467
+ if len(prev_turn_domain) == 1 and prev_turn_domain[0] == turn_domain[1]:
468
+ turn_domain = turn_domain[::-1]
469
+
470
+ # get system action
471
+ for dom in turn_domain:
472
+ sys_act_dict[dom] = {}
473
+ add_to_last_collect = []
474
+ booking_act_map = {"inform": "offerbook", "book": "offerbooked"}
475
+ for act, params in dial_turn["dialog_act"].items():
476
+ if act == "general-greet":
477
+ continue
478
+ d, a = act.split("-")
479
+ if d == "general" and d not in sys_act_dict:
480
+ sys_act_dict[d] = {}
481
+ if d == "booking":
482
+ d = turn_domain[0]
483
+ a = booking_act_map.get(a, a)
484
+ add_p = []
485
+ for param in params:
486
+ p = param[0]
487
+ if p == "none":
488
+ continue
489
+ elif ontology.da_abbr_to_slot_name.get(p):
490
+ p = ontology.da_abbr_to_slot_name[p]
491
+ if p not in add_p:
492
+ add_p.append(p)
493
+ add_to_last = True if a in ["request", "reqmore", "bye", "offerbook"] else False
494
+ if add_to_last:
495
+ add_to_last_collect.append((d, a, add_p))
496
+ else:
497
+ sys_act_dict[d][a] = add_p
498
+ for d, a, add_p in add_to_last_collect:
499
+ sys_act_dict[d][a] = add_p
500
+
501
+ for d in copy.copy(sys_act_dict):
502
+ acts = sys_act_dict[d]
503
+ if not acts:
504
+ del sys_act_dict[d]
505
+ if "inform" in acts and "offerbooked" in acts:
506
+ for s in sys_act_dict[d]["inform"]:
507
+ sys_act_dict[d]["offerbooked"].append(s)
508
+ del sys_act_dict[d]["inform"]
509
+
510
+ ordered_sysact_dict[fn][len(dial["log"])] = sys_act_dict
511
+
512
+ sys_act = []
513
+ if "general-greet" in dial_turn["dialog_act"]:
514
+ sys_act.extend(["[general]", "[greet]"])
515
+ for d, acts in sys_act_dict.items():
516
+ sys_act += ["[" + d + "]"]
517
+ for a, slots in acts.items():
518
+ self.unique_da[d + "-" + a] = 1
519
+ sys_act += ["[" + a + "]"]
520
+ sys_act += slots
521
+
522
+ # get db pointers
523
+ matnums = self.db.get_match_num(constraint_dict)
524
+ match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1]
525
+ match = matnums[match_dom]
526
+ dbvec = self.db.addDBPointer(match_dom, match)
527
+ bkvec = self.db.addBookingPointer(dial_turn["dialog_act"])
528
+
529
+ single_turn["pointer"] = ",".join(
530
+ [str(d) for d in dbvec + bkvec]
531
+ ) # 4 database pointer for domains, 2 for booking
532
+ single_turn["match"] = str(match)
533
+ single_turn["constraint"] = " ".join(constraints)
534
+ single_turn["cons_delex"] = " ".join(cons_delex)
535
+ single_turn["sys_act"] = " ".join(sys_act)
536
+ single_turn["turn_num"] = len(dial["log"])
537
+ single_turn["turn_domain"] = " ".join(["[" + d + "]" for d in turn_domain])
538
+
539
+ prev_turn_domain = copy.deepcopy(turn_domain)
540
+ prev_constraint_dict = copy.deepcopy(constraint_dict)
541
+
542
+ if "user" in single_turn:
543
+ dial["log"].append(single_turn)
544
+ for t in single_turn["user"].split() + single_turn["resp"].split() + constraints + sys_act:
545
+ self.vocab.add_word(t)
546
+ for t in single_turn["user_delex"].split():
547
+ if "[" in t and "]" in t and not t.startswith("[") and not t.endswith("]"):
548
+ single_turn["user_delex"].replace(t, t[t.index("[") : t.index("]") + 1])
549
+ elif not self.vocab.has_word(t):
550
+ self.vocab.add_word(t)
551
+
552
+ single_turn = {}
553
+
554
+ data[fn] = dial
555
+ # pprint(dial)
556
+ # if count == 20:
557
+ # break
558
+ self.vocab.construct()
559
+ self.vocab.save_vocab("data/preprocessed/UBAR/multi-woz-2.1-processed/vocab")
560
+ with open("data/interim/multi-woz-2.1-analysis/dialog_acts.json", "w") as f:
561
+ json.dump(ordered_sysact_dict, f, indent=2)
562
+ with open("data/interim/multi-woz-2.1-analysis/dialog_act_type.json", "w") as f:
563
+ json.dump(self.unique_da, f, indent=2)
564
+ return data
565
+
566
+
567
+ if __name__ == "__main__":
568
+ db_paths = {
569
+ "attraction": "db/raw/attraction_db.json",
570
+ "hospital": "db/raw/hospital_db.json",
571
+ "hotel": "db/raw/hotel_db.json",
572
+ "police": "db/raw/police_db.json",
573
+ "restaurant": "db/raw/restaurant_db.json",
574
+ "taxi": "db/raw/taxi_db.json",
575
+ "train": "db/raw/train_db.json",
576
+ }
577
+ # get_db_values('db/value_set.json') #
578
+ # preprocess_db(db_paths)
579
+ if not os.path.exists("data/preprocessed/UBAR/multi-woz-2.1-processed"):
580
+ os.mkdir("data/preprocessed/UBAR/multi-woz-2.1-processed")
581
+ dh = DataPreprocessor()
582
+ data = dh.preprocess_main()
583
+
584
+ with open("data/preprocessed/UBAR/multi-woz-2.1-processed/data_for_ubar.json", "w") as f:
585
+ json.dump(data, f, indent=2)
scripts/UBAR_code/train_ubar.py ADDED
@@ -0,0 +1,697 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import json
3
+ import logging
4
+ import os
5
+ import random
6
+ import time
7
+ import warnings
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn as nn
12
+ from torch.utils.tensorboard import SummaryWriter
13
+ from tqdm import tqdm
14
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
15
+ from transformers.optimization import AdamW, get_linear_schedule_with_warmup
16
+
17
+ import wandb
18
+ from crazyneuraluser.UBAR_code.config import global_config as cfg
19
+ from crazyneuraluser.UBAR_code.eval import MultiWozEvaluator
20
+ from crazyneuraluser.UBAR_code.reader import MultiWozReader
21
+
22
+ # from config21 import global_config as cfg # global, already initialized
23
+
24
+
25
+ warnings.filterwarnings("ignore")
26
+
27
+
28
+ class Model(object):
29
+ def __init__(self, device):
30
+ self.device = device
31
+ # initialize tokenizer
32
+ self.tokenizer = GPT2Tokenizer.from_pretrained(cfg.gpt_path)
33
+ # cfg.tokenizer = tokenizer
34
+
35
+ # initialize multiwoz reader
36
+ self.reader = MultiWozReader(self.tokenizer)
37
+
38
+ # create model: gpt2
39
+ self.model = GPT2LMHeadModel.from_pretrained(cfg.gpt_path)
40
+ if cfg.mode == "train":
41
+ self.model.resize_token_embeddings(len(self.tokenizer))
42
+ self.model.to(self.device) # single gpu
43
+
44
+ #
45
+ self.evaluator = MultiWozEvaluator(self.reader)
46
+ if cfg.save_log and cfg.mode == "train":
47
+ self.tb_writer = SummaryWriter(log_dir="./log")
48
+ else:
49
+ self.tb_writer = None
50
+
51
+ def get_optimizers(self):
52
+ """
53
+ Setup the optimizer and the learning rate scheduler.
54
+
55
+ from transformers.Trainer
56
+
57
+ parameters from cfg: lr (1e-3); warmup_steps
58
+ """
59
+ # Prepare optimizer and schedule (linear warmup and decay)
60
+ no_decay = ["bias", "LayerNorm.weight"]
61
+ optimizer_grouped_parameters = [
62
+ {
63
+ "params": [p for n, p in self.model.named_parameters() if not any(nd in n for nd in no_decay)],
64
+ "weight_decay": cfg.weight_decay,
65
+ },
66
+ {
67
+ "params": [p for n, p in self.model.named_parameters() if any(nd in n for nd in no_decay)],
68
+ "weight_decay": 0.0,
69
+ },
70
+ ]
71
+ optimizer = AdamW(optimizer_grouped_parameters, lr=cfg.lr)
72
+ num_training_steps = (
73
+ self.reader.set_stats["train"]["num_dials"]
74
+ * cfg.epoch_num
75
+ // (cfg.gradient_accumulation_steps * cfg.batch_size)
76
+ )
77
+ num_warmup_steps = cfg.warmup_steps if cfg.warmup_steps >= 0 else int(num_training_steps * 0.2)
78
+ scheduler = get_linear_schedule_with_warmup(
79
+ optimizer,
80
+ num_warmup_steps=num_warmup_steps,
81
+ num_training_steps=num_training_steps,
82
+ )
83
+ return optimizer, scheduler
84
+
85
+ def log_first_inputs(self, inputs):
86
+ tokenizer = self.tokenizer
87
+ logging.info("**** Input Examples: ****")
88
+ for context in inputs["contexts"][:4]:
89
+ # ubar = tokenizer.convert_ids_to_tokens(context)
90
+ # ubar = tokenizer.convert_tokens_to_string(context)
91
+ # ubar = " ".join(ubar)
92
+ ubar = tokenizer.decode(context)
93
+ logging.info(ubar)
94
+
95
+ def add_torch_input(self, inputs):
96
+ # to tensor and to device
97
+ contexts_tensor = torch.from_numpy(inputs["contexts_np"]).long()
98
+ contexts_tensor = contexts_tensor.to(self.device)
99
+ inputs["contexts_tensor"] = contexts_tensor
100
+ return inputs
101
+
102
+ def add_torch_input_eval(self, inputs):
103
+ # inputs: context
104
+ inputs["context_tensor"] = torch.tensor([inputs["context"]]).to(self.device)
105
+ return inputs
106
+
107
+ def calculate_loss_and_accuracy(self, outputs, labels):
108
+ # GPT2-chicahat/train.py
109
+ lm_logits = outputs[0]
110
+
111
+ shift_logits = lm_logits[..., :-1, :].contiguous()
112
+ shift_labels = labels[..., 1:].contiguous()
113
+
114
+ pad_id = cfg.pad_id
115
+ loss_fct = nn.CrossEntropyLoss(ignore_index=pad_id, reduction="sum")
116
+ loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
117
+
118
+ # avg loss
119
+ not_ignore = shift_labels.ne(pad_id)
120
+ num_targets = not_ignore.long().sum().item()
121
+
122
+ loss /= num_targets
123
+ return loss
124
+
125
+ def train(self):
126
+ """
127
+ UBARU
128
+ """
129
+
130
+ wandb.init(
131
+ # Set the project where this run will be logged
132
+ project="E2E User Simulator (Alistair)",
133
+ entity="byrne-lab",
134
+ # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
135
+ name=cfg.wandb_train_run_name,
136
+ # Track hyperparameters and run metadata
137
+ config={
138
+ "dataset": cfg.data_path,
139
+ "gpt_path": cfg.gpt_path,
140
+ "learning_rate": cfg.lr,
141
+ "warmup_steps": cfg.warmup_steps,
142
+ "gradient_accumulation_steps": cfg.gradient_accumulation_steps,
143
+ "batch_size": cfg.batch_size,
144
+ "epochs": cfg.epoch_num,
145
+ },
146
+ )
147
+
148
+ all_batches = self.reader.get_batches("train")
149
+ # compute num_training_steps in get_batches()
150
+ optimizer, scheduler = self.get_optimizers()
151
+
152
+ # log info
153
+ set_stats = self.reader.set_stats["train"]
154
+ logging.info("***** Running training *****")
155
+ logging.info(
156
+ " Num Training steps(one turn in a batch of dialogs) per epoch = %d",
157
+ set_stats["num_training_steps_per_epoch"],
158
+ )
159
+ logging.info(" Num Turns = %d", set_stats["num_turns"])
160
+ logging.info(" Num Dialogs = %d", set_stats["num_dials"])
161
+ logging.info(" Num Epochs = %d", cfg.epoch_num)
162
+ logging.info(" Batch size = %d", cfg.batch_size)
163
+ logging.info(" Gradient Accumulation steps = %d", cfg.gradient_accumulation_steps)
164
+ logging.info(
165
+ " Total optimization steps = %d",
166
+ set_stats["num_dials"] * cfg.epoch_num // (cfg.gradient_accumulation_steps * cfg.batch_size),
167
+ )
168
+
169
+ # tb writer
170
+ if self.tb_writer is not None:
171
+ self.tb_writer.add_text("cfg", json.dumps(cfg.__dict__, indent=2))
172
+ # self.tb_writer.add_hparams(self.args.to_sanitized_dict(), metric_dict={})
173
+
174
+ log_inputs = 2
175
+ global_step = 0
176
+ # sw = time.time()
177
+
178
+ for epoch in range(cfg.epoch_num):
179
+ epoch_step = 0
180
+ tr_loss = 0.0
181
+ logging_loss = 0.0
182
+ btm = time.time()
183
+ oom_time = 0
184
+ self.model.zero_grad()
185
+
186
+ data_iterator = self.reader.get_nontranspose_data_iterator(all_batches)
187
+
188
+ for batch_idx, dial_batch in enumerate(data_iterator):
189
+ inputs = self.reader.convert_batch_session(dial_batch)
190
+ try: # avoid OOM
191
+ self.model.train()
192
+ if log_inputs > 0: # log inputs for the very first two turns
193
+ self.log_first_inputs(inputs)
194
+ log_inputs -= 1
195
+
196
+ # to tensor
197
+ inputs = self.add_torch_input(inputs)
198
+ # loss
199
+ outputs = self.model(inputs["contexts_tensor"])
200
+ # outputs = self.model(inputs['contexts_tensor']) # debugging with GPT2Model
201
+ loss = self.calculate_loss_and_accuracy(outputs, labels=inputs["contexts_tensor"])
202
+ loss.backward()
203
+ tr_loss += loss.item()
204
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 5.0)
205
+ epoch_step += 1
206
+
207
+ # step, wrt gradient_accumulation_steps, clip grad norm
208
+ if (epoch_step + 1) % cfg.gradient_accumulation_steps == 0 or (
209
+ # end of an epoch
210
+ (epoch_step + 1)
211
+ == set_stats["num_training_steps_per_epoch"]
212
+ ):
213
+ optimizer.step()
214
+ scheduler.step()
215
+ optimizer.zero_grad()
216
+ # global_step: actual step the optimizer took
217
+ global_step += 1
218
+
219
+ logs = {} # for tb writer
220
+ # logging: loss, lr... after certain amount of steps
221
+ if cfg.report_interval > 0 and global_step % cfg.report_interval == 0:
222
+ loss_scalar = (tr_loss - logging_loss) / cfg.report_interval
223
+ logging_loss = tr_loss
224
+ logs["loss"] = loss_scalar
225
+ logging.info(
226
+ "Global step: {}, epoch step: {}, interval loss: {:.4f}".format(
227
+ global_step, epoch_step, loss_scalar
228
+ )
229
+ )
230
+
231
+ # validate
232
+ # add to tensorboard...
233
+ if cfg.evaluate_during_training and loss_scalar < 10:
234
+ results = self.validate(epoch)
235
+ for k, v in results.items():
236
+ eval_key = "eval_{}".format(k)
237
+ logs[eval_key] = v
238
+
239
+ if self.tb_writer:
240
+ for k, v in logs.items():
241
+ self.tb_writer.add_scalar(k, v, global_step)
242
+ # save model...
243
+
244
+ except RuntimeError as exception:
245
+ if "out of memory" in str(exception):
246
+ max_length = max(inputs["lengths"])
247
+ oom_time += 1
248
+ logging.info(
249
+ "WARNING: ran out of memory,times: {}, batch size: {}, max_len: {}".format(
250
+ oom_time, cfg.batch_size, max_length
251
+ )
252
+ )
253
+ if hasattr(torch.cuda, "empty_cache"):
254
+ torch.cuda.empty_cache()
255
+ else:
256
+ logging.info(str(exception))
257
+ raise exception
258
+ logging.info("Train epoch time: {:.2f} min, epoch loss: {:.4f}".format((time.time() - btm) / 60, tr_loss))
259
+ # save model after every epoch
260
+ # if epoch > 10 or tr_loss/epoch_step < 1:
261
+ self.save_model(epoch, tr_loss / epoch_step)
262
+
263
+ wandb.log({"epoch loss": tr_loss})
264
+
265
+ # Mark the run as finished on wandb
266
+ wandb.finish()
267
+
268
+ def save_model(self, epoch, loss):
269
+ save_path = os.path.join(cfg.exp_path, "epoch{}_trloss{:.2f}_gpt2".format(epoch + 1, loss))
270
+ if not os.path.exists(save_path):
271
+ os.mkdir(save_path)
272
+ logging.info("Saving model checkpoint to %s", save_path)
273
+ # save gpt2
274
+ self.model.save_pretrained(save_path)
275
+ # save tokenizer
276
+ self.tokenizer.save_pretrained(save_path)
277
+ # save cfg
278
+
279
+ def validate(self, data="dev", do_test=False, epoch=0):
280
+
281
+ if cfg.mode != "train":
282
+ wandb.init(
283
+ # Set the project where this run will be logged
284
+ project="E2E User Simulator (Alistair)",
285
+ entity="byrne-lab",
286
+ # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
287
+ name=cfg.wandb_eval_run_name,
288
+ # Track hyperparameters and run metadata
289
+ config={
290
+ "eval_load_path": cfg.eval_load_path,
291
+ "dataset": cfg.data_path,
292
+ "gpt_path": cfg.gpt_path,
293
+ "learning_rate": cfg.lr,
294
+ "warmup_steps": cfg.warmup_steps,
295
+ "gradient_accumulation_steps": cfg.gradient_accumulation_steps,
296
+ "batch_size": cfg.batch_size,
297
+ "epochs": cfg.epoch_num,
298
+ "data": data,
299
+ },
300
+ )
301
+
302
+ test_data_at = wandb.Artifact(str(wandb.run.id + str(epoch)), type="predictions")
303
+
304
+ # Create your W&B Table
305
+ column_names = [
306
+ "dialog",
307
+ "turn_num",
308
+ "turn_domain",
309
+ "pointer",
310
+ "user",
311
+ "usdx",
312
+ "resp",
313
+ "bspn",
314
+ "bsdx",
315
+ "aspn",
316
+ "dspn",
317
+ "db",
318
+ "resp_gen",
319
+ "bspn_gen",
320
+ "aspn_gen",
321
+ "dspn_gen",
322
+ ]
323
+ val_table = wandb.Table(columns=column_names)
324
+
325
+ # predict one dialog/ one turn at a time
326
+ self.model.eval()
327
+
328
+ # all_batches = self.reader.get_batches('dev')
329
+ # data_iterator = self.reader.get_data_iterator(all_batches)
330
+ eval_data = self.reader.get_eval_data(data)
331
+
332
+ set_stats = self.reader.set_stats[data]
333
+ logging.info("***** Running Evaluation *****")
334
+ logging.info(" Num Turns = %d", set_stats["num_turns"])
335
+ # logging.info(" Num Dialogs = %d", set_stats['num_dials'])
336
+
337
+ # valid_losses = []
338
+ btm = time.time()
339
+ result_collection = {}
340
+ with torch.no_grad():
341
+ # Adding this index to allow for quick testing of evaluation
342
+ dialogues_to_run = 1
343
+ for dial_idx, dialog in tqdm(enumerate(eval_data)):
344
+ if dialogues_to_run == 0:
345
+ break
346
+ dialogues_to_run -= 1
347
+
348
+ pv_turn = {}
349
+ for turn_idx, turn in enumerate(dialog):
350
+ first_turn = turn_idx == 0
351
+ inputs = self.reader.convert_turn_eval(turn, pv_turn, first_turn)
352
+ inputs = self.add_torch_input_eval(inputs)
353
+
354
+ # fail to generate new tokens, if max_length not set
355
+ context_length = len(inputs["context"])
356
+ if cfg.use_true_curr_bspn: # generate act, response
357
+ max_len = 60
358
+ if not cfg.use_true_curr_aspn:
359
+ max_len = 80
360
+
361
+ outputs = self.model.generate(
362
+ input_ids=inputs["context_tensor"],
363
+ max_length=context_length + max_len,
364
+ temperature=0.7, # top_p=0.9, num_beams=4,
365
+ pad_token_id=self.tokenizer.eos_token_id,
366
+ eos_token_id=self.tokenizer.encode(["<eos_r>"])[0],
367
+ )
368
+ # no_repeat_ngram_size=4
369
+ # turn['generated'] = self.tokenizer.decode(outputs[0])
370
+
371
+ # resp_gen, need to trim previous context
372
+ generated = outputs[0].cpu().numpy().tolist()
373
+ generated = generated[context_length - 1 :]
374
+
375
+ try:
376
+ decoded = self.decode_generated_act_resp(generated)
377
+ except ValueError as exception:
378
+ logging.info(str(exception))
379
+ logging.info(self.tokenizer.decode(generated))
380
+ decoded = {"resp": [], "bspn": [], "aspn": []}
381
+
382
+ else: # predict bspn, access db, then generate act and resp
383
+ outputs = self.model.generate(
384
+ input_ids=inputs["context_tensor"],
385
+ max_length=context_length + 60,
386
+ temperature=0.7, # top_p=0.9, num_beams=4,
387
+ pad_token_id=self.tokenizer.eos_token_id,
388
+ eos_token_id=self.tokenizer.encode(["<eos_b>"])[0],
389
+ )
390
+ generated_bs = outputs[0].cpu().numpy().tolist()
391
+ # generated_bs = generated_bs[context_length-1:]
392
+ bspn_gen = self.decode_generated_bspn(generated_bs[context_length - 1 :])
393
+ # check DB result
394
+ if cfg.use_true_db_pointer:
395
+ # db_result = self.reader.bspan_to_DBpointer(
396
+ # self.tokenizer.decode(turn['bspn']), turn['turn_domain'])
397
+ db = turn["db"]
398
+ else:
399
+ db_result = self.reader.bspan_to_DBpointer(
400
+ self.tokenizer.decode(bspn_gen), turn["turn_domain"]
401
+ )
402
+ db = self.tokenizer.convert_tokens_to_ids(
403
+ self.tokenizer.tokenize("<sos_db> " + db_result + " <eos_db>")
404
+ ) + self.tokenizer.encode(["<sos_a>"])
405
+ inputs["context_tensor_db"] = torch.tensor([inputs["context"][:-1] + bspn_gen + db]).to(
406
+ self.device
407
+ )
408
+ context_length = len(inputs["context_tensor_db"][0])
409
+ outputs_db = self.model.generate(
410
+ input_ids=inputs["context_tensor_db"],
411
+ max_length=context_length + 80,
412
+ temperature=0.7, # top_p=0.9, num_beams=4,
413
+ pad_token_id=self.tokenizer.eos_token_id,
414
+ eos_token_id=self.tokenizer.encode(["<eos_r>"])[0],
415
+ )
416
+ generated_ar = outputs_db[0].cpu().numpy().tolist()
417
+ generated_ar = generated_ar[context_length - 1 :]
418
+ try:
419
+ decoded = self.decode_generated_act_resp(generated_ar)
420
+ decoded["bspn"] = bspn_gen
421
+ except ValueError:
422
+ # NOTE: the below logging is commented out because when running evaluation
423
+ # on early checkpoints of gpt2, the generated response is almost always
424
+ # missing <eos_b> and it kills the GPU due to constant decoding (plus it swamps the logs)
425
+
426
+ # logging.info(str(exception))
427
+ # logging.info(self.tokenizer.decode(generated_ar))
428
+ decoded = {"resp": [], "bspn": [], "aspn": []}
429
+
430
+ turn["resp_gen"] = decoded["resp"]
431
+ turn["bspn_gen"] = turn["bspn"] if cfg.use_true_curr_bspn else decoded["bspn"]
432
+ turn["aspn_gen"] = turn["aspn"] if cfg.use_true_curr_aspn else decoded["aspn"]
433
+ turn["dspn_gen"] = turn["dspn"]
434
+
435
+ # check DB results
436
+ # db_result = self.reader.bspan_to_DBpointer(self.tokenizer.decode(turn['bspn']),
437
+ # turn['turn_domain'])
438
+ # if db_result[0] == 1: # no match
439
+ # print('gt:', self.tokenizer.decode(turn['aspn']), '
440
+ # |gen:', self.tokenizer.decode(decoded['aspn']))
441
+ # print('gen_resp: ', self.tokenizer.decode(decoded['resp']))
442
+ # print('gt_resp: ', self.tokenizer.decode(turn['resp']), '\n')
443
+
444
+ # all true previous context
445
+ pv_turn["labels"] = inputs["labels"]
446
+ pv_turn["resp"] = turn["resp"] if cfg.use_true_prev_resp else decoded["resp"]
447
+ pv_turn["bspn"] = turn["bspn"] if cfg.use_true_prev_bspn else decoded["bspn"]
448
+ pv_turn["db"] = turn["db"] if cfg.use_true_curr_bspn else db
449
+ pv_turn["aspn"] = turn["aspn"] if cfg.use_true_prev_aspn else decoded["aspn"]
450
+
451
+ turn_result = self.reader.inverse_transpose_turn(dialog)
452
+ result_collection.update(turn_result)
453
+
454
+ for dialog, turns in turn_result.items():
455
+ for turn in turns:
456
+ curr_turn_plain = [
457
+ dialog,
458
+ turn["turn_num"],
459
+ turn["turn_domain"],
460
+ turn["pointer"],
461
+ ]
462
+ curr_turn_tokenised = [
463
+ self.tokenizer.decode(turn[key])
464
+ for key in turn.keys()
465
+ if key != "pointer" and key != "turn_domain" and key != "turn_num"
466
+ ]
467
+ curr_turn_data = curr_turn_plain + curr_turn_tokenised
468
+ val_table.add_data(*curr_turn_data)
469
+
470
+ logging.info("inference time: {:.2f} min".format((time.time() - btm) / 60))
471
+ # score
472
+ btm = time.time()
473
+ results, _ = self.reader.wrap_result_lm(result_collection)
474
+ bleu, success, match = self.evaluator.validation_metric(results)
475
+ logging.info("Scoring time: {:.2f} min".format((time.time() - btm) / 60))
476
+ score = 0.5 * (success + match) + bleu
477
+ # valid_loss = 130 - score
478
+ logging.info(
479
+ "validation [CTR] match: %2.2f success: %2.2f bleu: %2.2f score: %.2f" % (match, success, bleu, score)
480
+ )
481
+ eval_results = {}
482
+ eval_results["bleu"] = bleu
483
+ eval_results["success"] = success
484
+ eval_results["match"] = match
485
+ eval_results["score"] = score
486
+ eval_results["result"] = "validation [CTR] match: %2.2f success: %2.2f bleu: %2.2f score: %.2f" % (
487
+ match,
488
+ success,
489
+ bleu,
490
+ score,
491
+ )
492
+
493
+ wandb.log(
494
+ {
495
+ "bleu": eval_results["bleu"],
496
+ "success": eval_results["success"],
497
+ "match": eval_results["match"],
498
+ "score": eval_results["score"],
499
+ }
500
+ )
501
+
502
+ model_setting, epoch_setting = (
503
+ cfg.eval_load_path.split("/")[1],
504
+ cfg.eval_load_path.split("/")[2],
505
+ )
506
+ eval_on = "-".join(cfg.exp_domains)
507
+ if data == "test":
508
+ eval_on += "_test"
509
+ if not os.path.exists(cfg.log_path):
510
+ os.mkdir(cfg.log_path)
511
+ log_file_name = os.path.join(cfg.log_path, model_setting + "-" + eval_on + ".json")
512
+ if os.path.exists(log_file_name):
513
+ eval_to_json = json.load(open(log_file_name, "r"))
514
+ eval_to_json[epoch_setting] = eval_results
515
+ json.dump(eval_to_json, open(log_file_name, "w"), indent=2)
516
+ else:
517
+ eval_to_json = {}
518
+ eval_to_json[epoch_setting] = eval_results
519
+ json.dump(eval_to_json, open(log_file_name, "w"), indent=2)
520
+ logging.info("update eval results to {}".format(log_file_name))
521
+
522
+ # log predictions table to wandb, giving it a name
523
+ test_data_at.add(val_table, "predictions")
524
+ wandb.run.log_artifact(test_data_at)
525
+
526
+ if cfg.mode != "train":
527
+ # Mark the run as finished on wandb
528
+ wandb.finish()
529
+
530
+ return eval_results
531
+
532
+ def decode_generated_act_resp(self, generated):
533
+ """
534
+ decode generated
535
+ return decoded['resp'] ('bspn', 'aspn')
536
+ """
537
+ decoded = {}
538
+ eos_a_id = self.tokenizer.encode(["<eos_a>"])[0]
539
+ eos_r_id = self.tokenizer.encode(["<eos_r>"])[0]
540
+ # eos_b_id = self.tokenizer.encode(["<eos_b>"])[0]
541
+
542
+ # eos_r may not exists if gpt2 generated repetitive words.
543
+ if eos_r_id in generated:
544
+ eos_r_idx = generated.index(eos_r_id)
545
+ else:
546
+ eos_r_idx = len(generated) - 1
547
+ # NOTE: the below logging is commented out because when running evaluation
548
+ # on early checkpoints of gpt2, the generated response is almost always missing
549
+ # <eos_r> and it kills the GPU due to constant decoding (plus it swamps the logs)
550
+
551
+ # logging.info('eos_r not in generated: ' +
552
+ # self.tokenizer.decode(generated))
553
+
554
+ if cfg.use_true_curr_aspn: # only predict resp
555
+ decoded["resp"] = generated[: eos_r_idx + 1]
556
+ else: # predicted aspn, resp
557
+ eos_a_idx = generated.index(eos_a_id)
558
+ decoded["aspn"] = generated[: eos_a_idx + 1]
559
+ decoded["resp"] = generated[eos_a_idx + 1 : eos_r_idx + 1]
560
+ # if cfg.use_true_curr_bspn:
561
+
562
+ # else: # predict bspn aspn resp
563
+ # eos_b_idx = generated.index(eos_b_id)
564
+ # eos_a_idx = generated.index(eos_a_id)
565
+ # decoded['bspn'] = generated[: eos_b_idx+1]
566
+ # decoded['aspn'] = generated[eos_b_idx+1: eos_a_idx+1]
567
+ # decoded['resp'] = generated[eos_a_idx+1: eos_r_idx+1]
568
+ return decoded
569
+
570
+ def decode_generated_bspn(self, generated):
571
+ eos_b_id = self.tokenizer.encode(["<eos_b>"])[0]
572
+ if eos_b_id in generated:
573
+ eos_b_idx = generated.index(eos_b_id)
574
+ else:
575
+ eos_b_idx = len(generated) - 1
576
+ return generated[: eos_b_idx + 1]
577
+
578
+
579
+ def parse_arg_cfg(args):
580
+ # add args to cfg
581
+ if args.cfg:
582
+ for pair in args.cfg:
583
+ k, v = tuple(pair.split("="))
584
+ dtype = type(getattr(cfg, k))
585
+ if dtype == type(None):
586
+ raise ValueError()
587
+ if dtype is bool:
588
+ v = False if v == "False" else True
589
+ elif dtype is list:
590
+ v = v.split(",")
591
+ if k == "cuda_device":
592
+ v = [int(no) for no in v]
593
+ else:
594
+ v = dtype(v)
595
+ setattr(cfg, k, v)
596
+ return
597
+
598
+
599
+ def main():
600
+ if not os.path.exists("./models/UBAR/experiments"):
601
+ os.mkdir("./models/UBAR/experiments")
602
+
603
+ if not os.path.exists("./models/UBAR/experiments_21"):
604
+ os.mkdir("./models/UBAR/experiments_21")
605
+
606
+ parser = argparse.ArgumentParser()
607
+ parser.add_argument("-mode")
608
+ parser.add_argument("-cfg", nargs="*")
609
+ args = parser.parse_args()
610
+
611
+ cfg.mode = args.mode
612
+ if args.mode == "test" or args.mode == "adjust":
613
+ parse_arg_cfg(args)
614
+ # cfg.model_path = cfg.eval_load_path
615
+ cfg.gpt_path = cfg.eval_load_path
616
+ else: # train
617
+
618
+ parse_arg_cfg(args)
619
+ if cfg.exp_path in ["", "to be generated"]:
620
+ # log file path, control the factors: seed, learning_rate, batch_size,
621
+ # early_stop_count, weight decay... cfg.exp_path = 'experiments/
622
+ # {}_{}_sd{}_lr{}_bs{}_sp{}_dc{}/'.format('-'.join(cfg.exp_domains),
623
+ # cfg.exp_no, cfg.seed, cfg.lr, cfg.batch_size,
624
+ # cfg.early_stop_count, cfg.weight_decay_count)
625
+
626
+ experiments_path = (
627
+ "./models/UBAR/experiments" if "all" in cfg.exp_domains else "./models/experiments_Xdomain"
628
+ )
629
+ cfg.exp_path = os.path.join(
630
+ experiments_path,
631
+ "{}_{}_sd{}_lr{}_bs{}_ga{}".format(
632
+ "-".join(cfg.exp_domains),
633
+ cfg.exp_no,
634
+ cfg.seed,
635
+ cfg.lr,
636
+ cfg.batch_size,
637
+ cfg.gradient_accumulation_steps,
638
+ ),
639
+ )
640
+ logging.info("save path:", cfg.exp_path)
641
+ if cfg.save_log:
642
+ if not os.path.exists(cfg.exp_path):
643
+ os.mkdir(cfg.exp_path)
644
+
645
+ # to gpt later
646
+ cfg.model_path = os.path.join(cfg.exp_path, "model.pkl")
647
+ cfg.result_path = os.path.join(cfg.exp_path, "result.csv")
648
+ cfg.vocab_path_eval = os.path.join(cfg.exp_path, "vocab")
649
+ cfg.eval_load_path = cfg.exp_path
650
+
651
+ cfg._init_logging_handler(args.mode)
652
+ if cfg.cuda:
653
+ if len(cfg.cuda_device) == 1:
654
+ cfg.multi_gpu = False
655
+ # torch.cuda.set_device(cfg.cuda_device[0])
656
+ device = torch.device("cuda:{}".format(cfg.cuda_device[0]))
657
+ else:
658
+ pass # multi-gpu
659
+ else:
660
+ device = torch.device("cpu")
661
+ # logging.info('Device: {}'.format(torch.cuda.current_device()))
662
+
663
+ # fix random seed
664
+ torch.manual_seed(cfg.seed)
665
+ torch.cuda.manual_seed(cfg.seed)
666
+ random.seed(cfg.seed)
667
+ np.random.seed(cfg.seed)
668
+
669
+ # initialize model
670
+ m = Model(device)
671
+
672
+ if args.mode == "train": # train
673
+ if cfg.save_log: # save cfg details.
674
+ pass
675
+ m.train()
676
+ else: # test
677
+ logging.info(
678
+ "Generate setting: \n\t use true_prev_bspn={} \n\t use true_prev_aspn={} \n\t use true_db_pointer={} \
679
+ \n\t use true_prev_resp={} \n\t use true_curr_bspn={} \n\t use true_curr_aspn={} \
680
+ \n\t use_all_previous_context={}".format(
681
+ cfg.use_true_prev_bspn,
682
+ cfg.use_true_prev_aspn,
683
+ cfg.use_true_db_pointer,
684
+ cfg.use_true_prev_resp,
685
+ cfg.use_true_curr_bspn,
686
+ cfg.use_true_curr_aspn,
687
+ cfg.use_all_previous_context,
688
+ )
689
+ )
690
+
691
+ logging.info("Running eval on test")
692
+ m.validate(cfg.eval_set)
693
+ logging.info("Evaluation finished")
694
+
695
+
696
+ if __name__ == "__main__":
697
+ main()
scripts/agent_agent.yaml ADDED
File without changes
scripts/crazyneuraluser.egg-info/PKG-INFO ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: crazyneuraluser
3
+ Version: 0.0.post1.dev55+g3c295fb.d20220606
4
+ Summary: Add a short description here!
5
+ Home-page: https://github.com/pyscaffold/pyscaffold/
6
+ Author: Extended by Alistair McLeay, original code by Alexandru Coca
7
+ Author-email: am@alistairmcleay.com and alexcoca23@yahoo.co.uk
8
+ License: MIT
9
+ Project-URL: Documentation, https://pyscaffold.org/
10
+ Platform: any
11
+ Classifier: Development Status :: 4 - Beta
12
+ Classifier: Programming Language :: Python
13
+ Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
14
+ Provides-Extra: testing
15
+ License-File: LICENSE.txt
16
+ License-File: AUTHORS.md
17
+
18
+ # Cambridge Masters Project
19
+ Joint Learning of Practical Dialogue Systems and User Simulators
20
+
21
+ ## Environment setup
22
+
23
+ 1. Create an environment `crazyneuraluser` with the help of [conda]
24
+ ```
25
+ conda env create -f environment.yml
26
+ ```
27
+ 2. Activate the new environment with:
28
+ ```
29
+ conda activate crazyneuraluser
30
+ ```
31
+ 3. Install a version of `pytorch` compatible with your hardware (see the [pytorch website](https://pytorch.org/get-started/previous-versions/)). E.g.:
32
+ ```
33
+ pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
34
+ ```
35
+
36
+ 4. Install `spacy` and download the tokenization tool in spacy:
37
+ ```
38
+ pip install spacy'
39
+ python -m spacy download en_core_web_sm
40
+ ```
41
+
42
+ ### Generating dialogues through agent-agent interaction
43
+
44
+ To generate dialogues, first change working directory to the `baselines` directory. Run the command
45
+ ```
46
+ python baselines_setup.py
47
+ ```
48
+ to prepare `convlab2` for running the baselines.
49
+
50
+ #### Generating dialogues conditioned on randomly sampled goals
51
+
52
+ Select one of the available configurations in the `configs` directory and run the command
53
+ ```
54
+ python simulate_agent_interaction.py --config /rel/path/to/chosen/config
55
+ ```
56
+ to generate dialogues conditioned on randomly sampled goals according to the `convlab2` goal model. The dialogues will be be saved automatically in the `models` directory, under a directory whose name depends on the configuration run. The `models` directory is located in the parent directory of the `baselines` directory. The `metadata.json` file saved with the dialogues contains information about the data generation process.
57
+
58
+ #### Generating dialogues conditioned on `MultiWOZ2.1` goals
59
+
60
+ To generate the entire corpus, simply pass the `--goals-path /path/to/multiwoz2.1/data.json/file` flag to `simulate_agent_interaction.py`. To generate the `test/val` split additionally pass the `--filter-path /path/to/multiwoz2.1/test-or-valListFile` argument to `simulate_agent_interaction.py`. You can use the `generate_multiwoz21_train_id_file` function in `baselines/utils.py` to generate `trainListFile` which can then be passed via the `--filter-path` argument to the dialogue generation script in order to generate dialogues conditioned on the `MultiWOZ2.1` training goals.
61
+
62
+ ### Converting the generated dialogues to SGD-like format
63
+
64
+ The `create_data_from_multiwoz.py` script can be used to convert the generated dialogues to SGD format, necessary for evaluation. It is based on the script provided by Google for DSTC8, but with additional functionality such as:
65
+
66
+ - conversion of slot names as annotated in the MultiWOZ 2.1 dialogue acts to different slot names, specified through the `--slots_convention` argument. Options are `multiwoz22` to convert the slots to the same slots as defined in the MultiWOZ 2.2 dataset whreas the `multiwoz_goals` converts the slot names to the names used in the dialogue goal and state tracking annotations.
67
+
68
+ - addition of system and user `nlu` fields for every turn
69
+
70
+ - option to perform cleaning operations on the goals to ensure a standard format is received by the evaluator.
71
+
72
+ The conversion is done according to the `schema.json` file in the `baselines` directory, which is the same as used by `DSTC8` conversion except for the addition of the `police` domain. Type ``python create_data_from_multiwoz.py --helpfull`` to see a full list of flags and usage.
73
+
74
+ ## Installation
75
+
76
+ The recommended way to use this repository is to develop the core code under `src/crazyneuraluser`. The experiments/exporatory analysis making use of the core package code should be placed outside the library and imported. See more guidance under the [Project Organisation](#project-organization) section below.
77
+
78
+ To create an environment for the package, make sure you have deactivated all `conda` environments. Then:
79
+
80
+ 1. Create an environment `crazyneuraluser` with the help of [conda]:
81
+ ```
82
+ conda env create -f environment.yml
83
+ ```
84
+ 2. Add the developer dependencies to this environment with the help of [conda]:
85
+ ```
86
+ conda env update -f dev_environment.yml
87
+ ```
88
+
89
+ Optional and needed only once after `git clone`:
90
+
91
+ 3. install several [pre-commit] git hooks with:
92
+ ```bash
93
+ pre-commit install
94
+ # You _are encouraged_ to run `pre-commit autoupdate`
95
+ ```
96
+ and checkout the configuration under `.pre-commit-config.yaml`.
97
+ The `-n, --no-verify` flag of `git commit` can be used to deactivate pre-commit hooks temporarily.
98
+
99
+ 4. install [nbstripout] git hooks to remove the output cells of committed notebooks with:
100
+ ```bash
101
+ nbstripout --install --attributes notebooks/.gitattributes
102
+ ```
103
+ This is useful to avoid large diffs due to plots in your notebooks.
104
+ A simple `nbstripout --uninstall` will revert these changes.
105
+
106
+ Then take a look into the `scripts` and `notebooks` folders.
107
+
108
+ ## Dependency Management & Reproducibility
109
+
110
+ 1. Always keep your abstract (unpinned) dependencies updated in `environment.yml` and eventually
111
+ in `setup.cfg` if you want to ship and install your package via `pip` later on.
112
+ 2. Create concrete dependencies as `environment.lock.yml` for the exact reproduction of your
113
+ environment with:
114
+ ```bash
115
+ conda env export -n crazyneuraluser -f environment.lock.yml
116
+ ```
117
+ For multi-OS development, consider using `--no-builds` during the export.
118
+ 3. Update your current environment with respect to a new `environment.lock.yml` using:
119
+ ```bash
120
+ conda env update -f environment.lock.yml --prune
121
+ ```
122
+ ## Project Organization
123
+
124
+ ```
125
+ ├── AUTHORS.md <- List of developers and maintainers.
126
+ ├── CHANGELOG.md <- Changelog to keep track of new features and fixes.
127
+ ├── LICENSE.txt <- License as chosen on the command-line.
128
+ ├── README.md <- The top-level README for developers.
129
+ ├── configs <- Directory for configurations of model & application.
130
+ ├── data
131
+ │ ├── external <- Data from third party sources.
132
+ │ ├── interim <- Intermediate data that has been transformed.
133
+ │ ├── processed <- The final, canonical data sets for modeling.
134
+ │ └── raw <- The original, immutable data dump.
135
+ ├── docs <- Directory for Sphinx documentation in rst or md.
136
+ ├── environment.yml <- The conda environment file for reproducibility.
137
+ ├── models <- Trained and serialized models, model predictions,
138
+ │ or model summaries.
139
+ ├── notebooks <- Jupyter notebooks. Naming convention is a number (for
140
+ │ ordering), the creator's initials and a description,
141
+ │ e.g. `1.0-fw-initial-data-exploration`.
142
+ ├── pyproject.toml <- Build system configuration. Do not change!
143
+ ├── references <- Data dictionaries, manuals, and all other materials.
144
+ ├── reports <- Generated analysis as HTML, PDF, LaTeX, etc.
145
+ │ └── figures <- Generated plots and figures for reports.
146
+ ├── scripts <- Analysis and production scripts which import the
147
+ │ actual Python package, e.g. train_model.py.
148
+ ├── setup.cfg <- Declarative configuration of your project.
149
+ ├── setup.py <- Use `pip install -e .` to install for development or
150
+ | or create a distribution with `tox -e build`.
151
+ ├── src
152
+ │ └── crazyneuraluser <- Actual Python package where the main functionality goes.
153
+ ├── tests <- Unit tests which can be run with `py.test`.
154
+ ├── .coveragerc <- Configuration for coverage reports of unit tests.
155
+ ├── .isort.cfg <- Configuration for git hook that sorts imports.
156
+ └── .pre-commit-config.yaml <- Configuration of pre-commit git hooks.
157
+ ```
158
+
159
+ <!-- pyscaffold-notes -->
160
+
161
+ ## Note
162
+
163
+ This project has been set up using [PyScaffold] 4.0.1 and the [dsproject extension] 0.6.1.
164
+
165
+ [conda]: https://docs.conda.io/
166
+ [pre-commit]: https://pre-commit.com/
167
+ [Jupyter]: https://jupyter.org/
168
+ [nbstripout]: https://github.com/kynan/nbstripout
169
+ [Google style]: http://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
170
+ [PyScaffold]: https://pyscaffold.org/
171
+ [dsproject extension]: https://github.com/pyscaffold/pyscaffoldext-dsproject
scripts/crazyneuraluser.egg-info/SOURCES.txt ADDED
File without changes
scripts/crazyneuraluser.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
scripts/crazyneuraluser.egg-info/not-zip-safe ADDED
@@ -0,0 +1 @@
 
 
1
+
scripts/crazyneuraluser.egg-info/requires.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers==4.18.0
2
+ tqdm==4.64.0
3
+ wandb==0.12.16
4
+ nltk==3.7
5
+ sklearn==0.0
6
+ tensorboard==2.9.0
7
+ spacy==3.3.0
8
+
9
+ [:python_version < "3.8"]
10
+ importlib-metadata
11
+
12
+ [testing]
13
+ setuptools
14
+ pytest
15
+ pytest-cov
scripts/crazyneuraluser.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ crazyneuraluser
scripts/simulate_interaction.py ADDED
@@ -0,0 +1,171 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import traceback
3
+ import pandas as pd
4
+
5
+ # from tqdm import tqdm
6
+ from UBAR_code.interaction import UBAR_interact
7
+ from user_model_code.interaction import multiwoz_interact
8
+ from UBAR_code.interaction.UBAR_interact import bcolors
9
+
10
+
11
+ def instantiate_agents():
12
+
13
+ UBAR_checkpoint_path = "models/UBAR/experiments/distilgpt-2_sd11_lr0.0001_bs16_ga2/epoch50_trloss0.59_gpt2"
14
+ user_model_checkpoint_path = "models/user_model/MultiWOZ-full_checkpoint_step340k"
15
+
16
+ sys_model = UBAR_interact.UbarSystemModel(
17
+ "UBAR_sys_model", UBAR_checkpoint_path, "scripts/UBAR_code/interaction/config.yaml"
18
+ )
19
+
20
+ user_model = multiwoz_interact.NeuralAgent(
21
+ "user", user_model_checkpoint_path, "scripts/user_model_code/interaction/config.yaml"
22
+ )
23
+
24
+ return sys_model, user_model
25
+
26
+
27
+ def read_multiwoz_data():
28
+ """
29
+ Read the multiwoz 2.0 raw data from the .json file
30
+ """
31
+ raw_mwoz_20_path = "data/raw/UBAR/multi-woz/data.json"
32
+ df_raw_mwoz = pd.read_json(raw_mwoz_20_path)
33
+ return df_raw_mwoz
34
+
35
+
36
+ def load_test_val_lists():
37
+ val_list_file = "data/raw/UBAR/multi-woz/valListFile.json"
38
+ test_list_file = "data/raw/UBAR/multi-woz/testListFile.json"
39
+
40
+ with open(val_list_file, "r") as f:
41
+ val_list = f.readlines()
42
+ val_list = [x.strip() for x in val_list]
43
+
44
+ with open(test_list_file, "r") as f:
45
+ test_list = f.readlines()
46
+ test_list = [x.strip() for x in test_list]
47
+
48
+ return val_list, test_list
49
+
50
+
51
+ def main(
52
+ write_to_file=False, ground_truth_system_responses=False, train_only=True, n_dialogues="all", log_successes=False
53
+ ):
54
+ sys_model, user_model = instantiate_agents()
55
+
56
+ # TODO: move hardcoded vars into config file
57
+ raw_mwoz_20_path = "data/raw/UBAR/multi-woz/data.json"
58
+ user_utterances_out_path = "data/preprocessed/UBAR/user_utterances_from_simulator.txt"
59
+ logging_successes_path = "data/preprocessed/UBAR/logging_successes"
60
+ sys_model.print_intermediary_info = False
61
+ user_model.print_intermediary_info = False
62
+
63
+ df_raw_mwoz = pd.read_json(raw_mwoz_20_path)
64
+ if n_dialogues == "all":
65
+ n_dialogues = len(df_raw_mwoz.columns)
66
+
67
+ curr_dialogue_user_utterances_formatted = []
68
+
69
+ print("Loading goals...")
70
+ goals = multiwoz_interact.read_multiWOZ_20_goals(raw_mwoz_20_path, n_dialogues)
71
+
72
+ # Write column headers
73
+ if write_to_file:
74
+ with open(user_utterances_out_path, "w") as f:
75
+ f.write("Dialogue #\tDialogue ID\tTurn #\tSystem Response\n")
76
+
77
+ print("Loading data...")
78
+ df_mwoz_data = read_multiwoz_data()
79
+ val_list, test_list = load_test_val_lists()
80
+
81
+ successful_dialogues = 0
82
+ total_dialogues_generated = 0 # train dialogues only
83
+ for dialogue_idx, (goal, dialogue_filename) in enumerate(zip(goals, df_mwoz_data.columns)):
84
+ if log_successes:
85
+ # log successful_dialogues to logging_successes_path every 100 dialogues
86
+ if dialogue_idx % 100 == 0:
87
+ with open(logging_successes_path, "w") as f:
88
+ f.write(str(successful_dialogues) + " / " + str(total_dialogues_generated))
89
+
90
+ curr_dialogue_user_utterances_formatted = []
91
+ if train_only:
92
+ if dialogue_filename in val_list or dialogue_filename in test_list:
93
+ continue
94
+
95
+ total_dialogues_generated += 1
96
+ print("Dialogue: {}".format(dialogue_filename))
97
+
98
+ # There are occasionally exceptions thrown from one of the agents, usually the user
99
+ # In this case we simply continue to the next dialogue
100
+ try:
101
+ # Reset state after each dialogue
102
+ sys_model.init_session()
103
+ user_model.init_session(ini_goal=goal)
104
+ sys_response = ""
105
+
106
+ for turn_idx in range(50):
107
+ # Turn idx in this case represents the turn as one user utterance AND one system response
108
+ usr_response_raw_data_idx = turn_idx * 2
109
+ sys_response_raw_data_idx = turn_idx * 2 + 1
110
+
111
+ user_utterance = user_model.response(sys_response)
112
+ print(bcolors.OKBLUE + "User: " + bcolors.ENDC + user_utterance)
113
+
114
+ if write_to_file:
115
+ user_utterance = user_utterance.replace("\n", " ")
116
+ curr_dialogue_user_utterances_formatted.append(
117
+ str(dialogue_idx)
118
+ + "\t"
119
+ + dialogue_filename
120
+ + "\t"
121
+ + str(usr_response_raw_data_idx)
122
+ + "\t"
123
+ + user_utterance
124
+ + "\n"
125
+ )
126
+
127
+ if user_model.is_terminated():
128
+ successful_dialogues += 1
129
+ print(bcolors.OKCYAN + "Dialogue terminated successfully!" + bcolors.ENDC)
130
+ print(bcolors.OKCYAN + "---" * 30 + bcolors.ENDC + "\n")
131
+ if write_to_file:
132
+ # Write whole dialogue to file
133
+ with open(user_utterances_out_path, "a") as f:
134
+ for line in curr_dialogue_user_utterances_formatted:
135
+ f.write(line)
136
+ break
137
+
138
+ # Next turn materials
139
+ if ground_truth_system_responses:
140
+ # If we are at the end of the ground truth dialogues
141
+ if len(df_mwoz_data.iloc[:, dialogue_idx].log) <= sys_response_raw_data_idx:
142
+ print(bcolors.RED + "Dialogue terminated unsuccessfully!" + bcolors.ENDC)
143
+ print(bcolors.RED + "---" * 30 + bcolors.ENDC + "\n")
144
+ break
145
+ sys_response = df_mwoz_data.iloc[:, dialogue_idx].log[sys_response_raw_data_idx]["text"]
146
+ else:
147
+ sys_response = sys_model.response(user_utterance, turn_idx)
148
+ capitalised_sys_response = sys_response[0].upper() + sys_response[1:]
149
+ print(bcolors.GREEN + "System: " + bcolors.ENDC + capitalised_sys_response)
150
+
151
+ except Exception:
152
+ print(bcolors.RED + "*" * 30 + bcolors.ENDC)
153
+ print(bcolors.RED + "Error in dialogue {}".format(dialogue_filename) + bcolors.ENDC)
154
+ print(bcolors.RED + "*" * 30 + bcolors.ENDC)
155
+ traceback.print_exc()
156
+ continue
157
+
158
+ print("Successful dialogues: {}".format(successful_dialogues))
159
+ print("Total dialogues: {}".format(n_dialogues))
160
+ print("% Successful Dialopues: {}".format(successful_dialogues / n_dialogues))
161
+
162
+
163
+ if __name__ == "__main__":
164
+ # TODO: move parameters to config file
165
+ # Fix the hacky mess below
166
+ ground_truth_system_responses = sys.argv[1]
167
+ if ground_truth_system_responses == "False":
168
+ ground_truth_system_responses = False
169
+ else:
170
+ ground_truth_system_responses = True
171
+ main(write_to_file=False, ground_truth_system_responses=ground_truth_system_responses)
scripts/template_train_model.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ import logging
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ import click
7
+ from IPython.core import ultratb
8
+
9
+ import crazyneuraluser
10
+
11
+ # fallback to debugger on error
12
+ sys.excepthook = ultratb.FormattedTB(mode="Verbose", color_scheme="Linux", call_pdb=1)
13
+ # turn UserWarning messages to errors to find the actual cause
14
+ # import warnings
15
+ # warnings.simplefilter("error")
16
+
17
+ _logger = logging.getLogger(__name__)
18
+
19
+
20
+ @click.command()
21
+ @click.option(
22
+ "-c",
23
+ "--config",
24
+ "cfg_path",
25
+ required=True,
26
+ type=click.Path(exists=True),
27
+ help="path to config file",
28
+ )
29
+ @click.option("--quiet", "log_level", flag_value=logging.WARNING, default=True)
30
+ @click.option("-v", "--verbose", "log_level", flag_value=logging.INFO)
31
+ @click.option("-vv", "--very-verbose", "log_level", flag_value=logging.DEBUG)
32
+ @click.version_option(crazyneuraluser.__version__)
33
+ def main(cfg_path: Path, log_level: int):
34
+ logging.basicConfig(
35
+ stream=sys.stdout,
36
+ level=log_level,
37
+ datefmt="%Y-%m-%d %H:%M",
38
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
39
+ )
40
+ # YOUR CODE GOES HERE! Keep the main functionality in src/crazyneuraluser
41
+ # est = crazyneuraluser.models.Estimator()
42
+
43
+
44
+ if __name__ == "__main__":
45
+ main()
scripts/user_model_code/__init__.py ADDED
File without changes
scripts/user_model_code/decode.sh ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ experiment=$1
2
+ checkpoint=$2
3
+
4
+ if [[ "$experiment" == "SGD" ]]; then
5
+ echo "Conduct experiment with SGD dataset"
6
+ job_name='SGD-full'
7
+ data_list="sgd" # 165k training examples
8
+ eval_interval=50000 # evaluation interval
9
+
10
+ elif [[ "$experiment" == "MultiWOZ" ]]; then
11
+ echo "Conduct experiment with MulwiWOZ dataset"
12
+ job_name='MultiWOZ-full'
13
+ data_list="multiwoz" # 56k training examples
14
+ eval_interval=20000
15
+
16
+ elif [[ "$experiment" == "Joint" ]]; then
17
+ echo "Conduct experiment with SGD + MulwiWOZ dataset"
18
+ job_name='Joint-full'
19
+ data_list="sgd multiwoz" # 221k training examples
20
+ eval_interval=70000
21
+
22
+ else
23
+ echo "Unrecognised argument"
24
+ exit
25
+ fi
26
+
27
+ mkdir -p log decode
28
+ decode_file='decode/'$job_name'.json'
29
+ eye_browse_output=true # set to false for storing generation results in file
30
+
31
+ python main.py --mode='testing' \
32
+ --model_name=$job_name \
33
+ --checkpoint=$checkpoint \
34
+ --decode_file=$decode_file \
35
+ --data_dir="processed_data" \
36
+ --data_list=$data_list \
37
+ --eye_browse_output=$eye_browse_output
scripts/user_model_code/interaction/__init__.py ADDED
File without changes
scripts/user_model_code/interaction/config.yaml ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model:
2
+ path: "./models/user_model/MultiWOZ-full_checkpoint_step340k"
3
+ goal_update:
4
+ finish_inform: "loose" # loose or strict
5
+
6
+ schema_path: "scripts/user_model_code/interaction/schema.json"
7
+
8
+ decode:
9
+ dec_max_len: 1024
10
+ num_beams: 1
11
+ temperature: 0.7
12
+ do_sample: False
scripts/user_model_code/interaction/multiwoz_interact.py ADDED
@@ -0,0 +1,1034 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import random
3
+ import re
4
+ import sys
5
+ import traceback
6
+ from typing import List
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+ import transformers
12
+ from omegaconf import OmegaConf
13
+ from transformers import GPT2LMHeadModel, GPT2Tokenizer
14
+ from .utils import add_str, bcolors, find_segment, load_schema, segment_gen, wrap_element
15
+
16
+
17
+ class DummyPolicy:
18
+ def init_session(self, ini_goal): # noqa
19
+ self.goal = ini_goal # noqa
20
+
21
+ def get_goal(self) -> dict:
22
+ """Returns current user goal.
23
+
24
+ Notes
25
+ -----
26
+ ``hasattr`` user works around the fact that ``convlab2`` initialises the dialogue session
27
+ before we can explicitly pass the goal to the user model.
28
+ """
29
+ if hasattr(self.goal, "domain_goals"):
30
+ return self.goal.domain_goals
31
+ # return {}
32
+ return self.goal # for consistency
33
+
34
+
35
+ def generation_func(model, input_ids, eos_id, dec_max_len):
36
+ """Generation method using greedy search for Transformer v2.x"""
37
+
38
+ def _extend_mask(mask):
39
+ mask = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1)
40
+ return mask
41
+
42
+ # input_ids, attention_mask, token_type_ids = batch['input_ids'], batch['attention_mask'], batch['token_type_ids']
43
+ batch_size = input_ids.size(0)
44
+ attention_mask = torch.ones_like(input_ids)
45
+ past = None
46
+ finish_sent = [False for _ in range(batch_size)]
47
+ for i in range(dec_max_len):
48
+ logits, past = model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=None).values()
49
+
50
+ # logits: (B, T, V), T=1 when past is passed
51
+ next_token_logits = logits[:, -1, :]
52
+ next_token = torch.argmax(next_token_logits, dim=-1)
53
+ input_ids = torch.cat([input_ids, next_token.unsqueeze(-1)], dim=-1)
54
+ attention_mask = _extend_mask(attention_mask)
55
+
56
+ for bs_idx, token_id in enumerate(next_token):
57
+ if finish_sent[bs_idx] is False and token_id.item() == eos_id: # first produce <eos>
58
+ finish_sent[bs_idx] = True
59
+ if sum(finish_sent) == batch_size:
60
+ break
61
+ return input_ids
62
+
63
+
64
+ class NeuralAgent: # crazyusermodel
65
+ def __init__(self, name: str, model_path: str, model_config_path: str):
66
+ """User Simulator
67
+ Description
68
+ ---------
69
+ A user model that is able to chat with the task-oriented dialogue system in an end-to-end manner
70
+
71
+ Parameters
72
+ ----------
73
+ name
74
+ Should indicate the role played by the agent. It should be always user
75
+ """
76
+
77
+ if name != "user":
78
+ raise ValueError(f"Expected name 'user' but got {name} instead.")
79
+
80
+ # load necessities
81
+ self.set_device()
82
+ self.config = OmegaConf.load(model_config_path)
83
+
84
+ self.print_intermediary_info = False
85
+
86
+ # get schema, which is dependent to dataset, only for providing task description here
87
+ self.service2meta, self.schema_intents, self.schema_slots = load_schema(self.config["schema_path"])
88
+ # self.load_checkpoint_and_tokenizer(self.config["model"]["path"])
89
+ self.load_checkpoint_and_tokenizer(model_path)
90
+ self.load_materials()
91
+
92
+ self.context = []
93
+ self.current_goal = {}
94
+ self.behaviour_params = {}
95
+ self.input_action = [] # type: list[list[str]]
96
+ self.output_action = [] # type: list[list[str]]
97
+
98
+ # for compatibility with convlab2 evaluator
99
+ self.policy = DummyPolicy()
100
+
101
+ """ for reproduction """
102
+ seed = 1130
103
+ random.seed(seed)
104
+ np.random.seed(seed)
105
+ torch.manual_seed(seed)
106
+ torch.cuda.manual_seed(seed)
107
+ torch.cuda.manual_seed_all(seed)
108
+ torch.backends.cudnn.deterministic = True
109
+ torch.backends.cudnn.enabled = False
110
+ torch.backends.cudnn.benchmark = False
111
+
112
+ def load_checkpoint_and_tokenizer(self, checkpoint_path: str) -> None:
113
+ """Load model checkpoint with the model tokenizer, only for GPT2 for now"""
114
+ print("Load model, tokenizer from {}".format(checkpoint_path))
115
+ self.tokenizer = GPT2Tokenizer.from_pretrained(checkpoint_path)
116
+ self.model = GPT2LMHeadModel.from_pretrained(checkpoint_path)
117
+ self.model.to(self.device)
118
+
119
+ def set_device(self) -> None:
120
+ """Set device to GPU/CPU"""
121
+ self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
122
+
123
+ def load_materials(self):
124
+ """Load useful materials used in generation"""
125
+ # model attributes
126
+ """
127
+ finish_inform
128
+ how strict to finish an informable slot in goal: "strict" or "loose"
129
+ if "strict": the attribute are finished (removed from goal) only if both slot and value are produced in act
130
+ if "loose": the attribute are finished if the slot is produced
131
+ """
132
+ self.finish_inform = (
133
+ self.config.model.goal_update.finish_inform
134
+ ) # controls how strict to eliminate informed slots in goal
135
+ assert self.finish_inform in ["strict", "loose"]
136
+
137
+ # constants
138
+ self.bos_id, _, self.pad_id, self.sep_id = self.tokenizer.convert_tokens_to_ids(
139
+ ["<BOS>", "<EOS>", "<PAD>", "<SEP>"]
140
+ )
141
+ self.bin_flags = {"true": "_True_", "false": "_False_"}
142
+
143
+ self.supported_services = ["train", "attraction", "hotel", "restaurant", "taxi", "police", "hospital"]
144
+
145
+ self.slot_types = {"search": "search", "book": "book"}
146
+ # important to change the corresponding act str name when using different tokenization methods,
147
+ # as they are used to control the user behaviours
148
+ self.const_act_str = {
149
+ "inform": "inform",
150
+ "recommend": "recommend",
151
+ "request": "request",
152
+ "fail_search": "no offer",
153
+ "fail_book": "no book",
154
+ }
155
+
156
+ def prepare_input_ids(self, data: dict, start_token: str) -> str:
157
+ assert start_token in ["<SYS_ACT/>", "<USR_ACT/>"]
158
+ input_seq = ""
159
+ for key in [
160
+ "CTX",
161
+ "SYS_UTT",
162
+ "SYS_ACT",
163
+ "SNT",
164
+ "RA",
165
+ "GC",
166
+ "GOAL",
167
+ ]: # fixed order, consistent between training and inference
168
+ if key not in data:
169
+ continue
170
+ wrap = wrap_element(key, data[key])
171
+ input_seq = add_str(input_seq, wrap)
172
+
173
+ input_seq = add_str(input_seq, start_token)
174
+ if transformers.__version__.startswith("2."): # compatible with transformers v2.x used in convlab2
175
+ input_ids = self.tokenizer.encode(input_seq)
176
+ else:
177
+ input_ids = self.tokenizer(input_seq)["input_ids"] # convert to ids
178
+ input_ids = torch.tensor([input_ids]).long().to(self.device)
179
+ return input_ids
180
+
181
+ def update_internal_data(self, data: dict) -> None:
182
+ """Maintain context and user act in the format of generation string for the next turn generation"""
183
+ # update context
184
+ sys_utt_wrap = wrap_element("SYS", data["SYS_UTT"]) # e.g., <SYS/> Which area would you prefer? </SYS>
185
+ usr_utt_wrap = wrap_element("USR", data["USR_UTT"]) # e.g., <USR/> I want to be in the centre. </USR>
186
+ self._context_str = add_str(self._context_str, sys_utt_wrap)
187
+ self._context_str = add_str(self._context_str, usr_utt_wrap)
188
+
189
+ # update prev usr act
190
+ self._prev_usr_act = data["USR_ACT"] # e.g., <ACT/> inform </ACT> <SLOT/> area </SLOT> <VALUE/> centre </VALUE>
191
+
192
+ def run_inference_once(self, input_ids: torch.tensor, eos_id: int) -> List:
193
+ if transformers.__version__.startswith("2."): # compatible with transformers v2.x used in convlab2
194
+ output = generation_func(self.model, input_ids, eos_id, self.config.decode.dec_max_len)
195
+ else:
196
+ output = self.model.generate(
197
+ input_ids,
198
+ max_length=self.config.decode.dec_max_len,
199
+ do_sample=self.config.decode.do_sample,
200
+ early_stopping=True,
201
+ temperature=self.config.decode.temperature,
202
+ use_cache=True,
203
+ num_beams=self.config.decode.num_beams,
204
+ bos_token_id=self.bos_id,
205
+ eos_token_id=eos_id,
206
+ pad_token_id=self.pad_id,
207
+ )
208
+ return output
209
+
210
+ def generate_whole_sequence(self, sys_utt: str) -> tuple:
211
+ # first forward pass: generate NLU output and three special flags #####
212
+ data = {"CTX": self._context_str, "SYS_UTT": sys_utt}
213
+ start_token, end_token = "<SYS_ACT/>", "</GC>"
214
+ input_ids = self.prepare_input_ids(data, start_token)
215
+ eos_id = self.tokenizer.convert_tokens_to_ids(end_token)
216
+ output = self.run_inference_once(input_ids, eos_id)
217
+ generation = self.tokenizer.decode(output[0]) # decode back to str, including the fed context
218
+
219
+ # parse first pass prediction
220
+ for key in ["SYS_ACT", "SNT", "GC", "RA"]:
221
+ value = find_segment(generation, key)
222
+ data[key] = value
223
+
224
+ # update dynamic goal
225
+ if self.print_intermediary_info:
226
+ print("SYS ACT ->", data["SYS_ACT"])
227
+ goal = self.prepare_turn_goal(self._prev_usr_act, data["SYS_ACT"], data["SNT"], data["GC"], data["RA"])
228
+ data["GOAL"] = goal
229
+
230
+ # second forward pass: generate dialogue act and NLG output #####
231
+ start_token, end_token = "<USR_ACT/>", "<EOS>"
232
+ input_ids = self.prepare_input_ids(data, start_token)
233
+ eos_id = self.tokenizer.convert_tokens_to_ids(end_token)
234
+ output = self.run_inference_once(input_ids, eos_id)
235
+ generation = self.tokenizer.decode(output[0]) # decode back to str, including the fed context
236
+
237
+ # parse second pass prediction
238
+ for key in ["USR_ACT", "USR_UTT"]:
239
+ value = find_segment(generation, key)
240
+ data[key] = value
241
+ return data, generation
242
+
243
+ def _format_complete_goal(self, input_goal: dict) -> dict:
244
+ """Format the internal goal representation given a goal
245
+
246
+ :param input_goal: a goal that the user has in mind
247
+ either from the corpus or sampled randomly in a valid way (e.g., correct slot names)
248
+ :returns: complete_goal: an internal representation of the given goal, a dict with the keys "intents",
249
+ "constraints"
250
+ intents: list[str], list of intents in the dialogue, aka scenario
251
+ constraints: dict, intent as key, in the following format
252
+ dict(intent: intent_constraints)
253
+ intent_constraints: {"informable": dict(slot: value_list), "requestable": slot_set}
254
+ each slot has a value list in case of failure of searching
255
+ """
256
+ # TODO: make the order of services more flexible (how does convlab2 decide the service order?)
257
+ constraints = dict()
258
+ intents = []
259
+ self.n_max_value = {
260
+ self.slot_types["book"]: 0,
261
+ self.slot_types["search"]: 0,
262
+ } # record the max length of value list of a slot
263
+
264
+ for service in input_goal["ordered_services"]:
265
+ if service not in self.supported_services:
266
+ continue
267
+
268
+ # record intent list (scenario), order matters
269
+ intent = self._map_service_to_intent(service)
270
+ assert intent not in intents and intent not in constraints
271
+ intents.append(intent)
272
+ constraints[intent] = {"informable": dict(), "requestable": set()}
273
+
274
+ # collect informable slots
275
+ assert "info" in input_goal[service] # info has to exist
276
+ for key in ["fail_info", "info", "fail_book", "book"]: # order matters
277
+ # assert key in input_goal[service]
278
+ if key not in input_goal[service]:
279
+ continue
280
+ for slot, value in input_goal[service][key].items():
281
+ self._add_info(constraints[intent]["informable"], slot, value)
282
+
283
+ # collect requestable slots
284
+ key = "reqt"
285
+ # assert key in input_goal[service]
286
+ # for slot in input_goal[service][key]:
287
+ if key in input_goal[service]:
288
+ for slot in input_goal[service][key].keys():
289
+ self._add_reqt(constraints[intent]["requestable"], slot)
290
+
291
+ # order intents by the order they are dealt with in the data so
292
+ # if using ground truth system responses the right order of the intents
293
+ # is preserved
294
+
295
+ complete_goal = {"intents": intents, "constraints": constraints}
296
+ return complete_goal
297
+
298
+ def _init_user_status(self) -> dict:
299
+ """Initialise user status with intent and constraint
300
+ intent_idx: int, the index of current intent
301
+ constraint_idx: dict, intent as key, value is the constraint index used to record which value is used
302
+ in the slot value list
303
+ :return:
304
+ """
305
+ intent_idx = 0 # -1
306
+ # constraint_idx = {intent: 0 for intent in self.complete_goal["intents"]}
307
+ constraint_idx = {
308
+ intent: {self.slot_types["search"]: 0, self.slot_types["book"]: 0}
309
+ for intent in self.complete_goal["intents"]
310
+ }
311
+ # TODO: entity provide records, one of the criteria to move to the next intents
312
+ entity_provided = {intent: False for intent in self.complete_goal["intents"]}
313
+ return {
314
+ "intent_idx": intent_idx,
315
+ "constraint_idx": constraint_idx,
316
+ "dialogue_terminate": False,
317
+ "entity_provided": entity_provided,
318
+ }
319
+
320
+ def _get_scenario_str(self) -> None:
321
+ """Get a scenario str from a intent list
322
+
323
+ Description
324
+ convert a list of intents, aka scenario, into string with special marks
325
+ the scenario is determined at the start of dialogue and static during interaction
326
+ """
327
+ intents = self.complete_goal["intents"]
328
+ _str = [wrap_element("INTENT", intent) for intent in intents]
329
+ _str = " ".join(_str)
330
+ self.scenario_str = wrap_element("SCENARIO", _str)
331
+
332
+ def _prepare_current_constraints(
333
+ self,
334
+ involved_intents: List[str],
335
+ involved_slot_types: List[str],
336
+ if_reset_reqt: bool,
337
+ ) -> None:
338
+ """Prepare the current constraints, copied the specified content from the complete goal
339
+
340
+ the current constraints is used as condition in the model generation
341
+ its content comes from the "constraints" in "complete goal",
342
+ but the current constraints only allows one value for a slot at a time
343
+ the value is chosen from the value list by the "constraint_idx" in user status
344
+
345
+ :param involved_intents: list[str], intent list
346
+ :return:
347
+ current_constraints: dict, similar format as constraints in the complete goal,
348
+ but a slot has only one value, e.g.,
349
+ dict(intent: intent_constraints)
350
+ intent_constraints: {"informable": dict(slot: value), "requestable": slot_set}
351
+ """
352
+ # iterate the involved intents
353
+ for intent in involved_intents:
354
+ constraints = {"informable": dict(), "requestable": set()}
355
+ # informable slots value pairs
356
+ for slot, value_list in self.complete_goal["constraints"][intent]["informable"].items():
357
+ slot_type = self._get_slot_type(slot)
358
+ if slot_type not in involved_slot_types:
359
+ continue
360
+ value_idx = self.user_status["constraint_idx"][intent][slot_type]
361
+ if value_idx < len(value_list):
362
+ value = value_list[value_idx]
363
+ constraints["informable"][slot] = value
364
+
365
+ # requestable
366
+ if if_reset_reqt:
367
+ constraints["requestable"] = copy.deepcopy(self.complete_goal["constraints"][intent]["requestable"])
368
+ else:
369
+ constraints["requestable"] = copy.deepcopy(self.current_constraints[intent]["requestable"])
370
+
371
+ # overwrite intent constraints
372
+ self.current_constraints[intent] = constraints
373
+
374
+ @staticmethod
375
+ def _map_intent_to_service(intent: str) -> str:
376
+ # TODO: make it not dataset dependent?
377
+ """map an intent into a service, multiwoz only"""
378
+ return intent.split()[1]
379
+
380
+ @staticmethod
381
+ def _map_service_to_intent(service: str) -> str:
382
+ # TODO: make it not dataset dependent?
383
+ """map a service into an intent, multiwoz only"""
384
+ return f"find {service}"
385
+
386
+ def _get_slot_type(self, slot: str) -> str:
387
+ """return search or book type of a slot"""
388
+ slot_type = "book" if "book" in slot else "search"
389
+ assert slot_type in self.slot_types.keys()
390
+ return slot_type
391
+
392
+ def _get_goal_str(self, intent: str) -> str:
393
+ """prepare the proper goal sequence, same as used in training"""
394
+ goal_str = ""
395
+ # dialogue scenario
396
+ goal_str = add_str(goal_str, self.scenario_str)
397
+
398
+ # current task
399
+ goal_str = add_str(goal_str, wrap_element("TASK", intent))
400
+
401
+ # task description
402
+ service = self._map_intent_to_service(intent)
403
+ description = self.service2meta[service]["intents"][intent]["description"]
404
+ goal_str = add_str(goal_str, wrap_element("DESC", description))
405
+
406
+ # intent_constraints = self.dynamic_constraints[intent]
407
+ intent_constraints = self.current_constraints[intent]
408
+ # informable slots
409
+ info_str = ""
410
+ # for slot, value in intent_constraints["informable"].items():
411
+ for slot in sorted(intent_constraints["informable"].keys()): # sort by slot
412
+ value = intent_constraints["informable"][slot]
413
+ info_str = add_str(info_str, wrap_element("SLOT", slot))
414
+ info_str = add_str(info_str, wrap_element("VALUE", value))
415
+ goal_str = add_str(goal_str, wrap_element("INFORM", info_str))
416
+
417
+ # requestable slots
418
+ req_str = ""
419
+ for slot in sorted(list(intent_constraints["requestable"])):
420
+ req_str = add_str(req_str, wrap_element("SLOT", slot))
421
+ goal_str = add_str(goal_str, wrap_element("REQUEST", req_str))
422
+ return goal_str.strip()
423
+
424
+ def _start_new_intent(self, SNT_flag: str) -> bool:
425
+ """decide whether to start a new intent"""
426
+ # SNT (start new task) is predicted as on
427
+ assert SNT_flag in list(self.bin_flags.values())
428
+ # intent = self.intents[self.intent_idx]
429
+ intent = self.complete_goal["intents"][self.user_status["intent_idx"]]
430
+
431
+ # TODO: need at least an entity provided (not really sure...
432
+ # if not self.intent_entity_provided[intent]: # no entities provided in the intent yet
433
+ # return False
434
+
435
+ # TODO: think about the priority of SNT prediction. It's should be less prioritised than
436
+ # the number of left constraints.
437
+ # if SNT_flag == self.bin_flags["true"]: # model prediction in first turn is true
438
+ # return True
439
+
440
+ # current intent has empty constraints
441
+ if (
442
+ len(self.current_constraints[intent]["informable"]) == 0
443
+ and len(self.current_constraints[intent]["requestable"]) == 0
444
+ ):
445
+ return True
446
+ return False
447
+
448
+ def _check_entity_provided(self, sys_act, intent):
449
+ # TODO:
450
+ """Check if an entity provided in system response (act)"""
451
+ assert intent in [
452
+ "find restaurant",
453
+ "find hotel",
454
+ "find attraction",
455
+ "find train",
456
+ "find taxi",
457
+ "find police",
458
+ "find hospital",
459
+ ]
460
+ if intent in ["find restaurant", "find hotel", "find attraction"]:
461
+ if "<SLOT/> name </SLOT>" in sys_act:
462
+ self.intent_entity_provided[intent] = True
463
+ elif intent == "find train":
464
+ if "<SLOT/> train id </SLOT>" in sys_act:
465
+ self.intent_entity_provided[intent] = True
466
+ else: # taxi
467
+ if "<SLOT/> type </SLOT>" in sys_act:
468
+ self.intent_entity_provided[intent] = True
469
+
470
+ def _activate_dialogue_terminate(self) -> None:
471
+ """Turn on the user status about dialogue termination"""
472
+ self.user_status["dialogue_terminate"] = True
473
+
474
+ def prepare_turn_goal(self, prev_usr_act: str, sys_act: str, SNT_flag: str, GC_flag: str, RA_flag: str) -> str:
475
+ """prepare the goal sequence for the current turn"""
476
+ # TODO: more detailed instruction here
477
+ # TODO: Deal with empty intents (and figure out why they happen)
478
+ intent = self.complete_goal["intents"][self.user_status["intent_idx"]]
479
+
480
+ # TODO: check if at least one entity is provided in system act
481
+ # First thing to do, check if the system provides an entity
482
+ # self._check_entity_provided(sys_act, intent)
483
+
484
+ # update goal first then check if moves to next intent (task)
485
+ self._update_current_constraints(intent, "usr", prev_usr_act, sys_act)
486
+ self._update_current_constraints(
487
+ intent, "sys", prev_usr_act, sys_act
488
+ ) # impact of sys_act overwrites that of usr_act
489
+
490
+ # check if new intent starts
491
+ if self._start_new_intent(SNT_flag):
492
+ self.user_status["intent_idx"] += 1
493
+ if self.user_status["intent_idx"] < len(self.complete_goal["intents"]):
494
+ intent = self.complete_goal["intents"][self.user_status["intent_idx"]]
495
+ else:
496
+ self._activate_dialogue_terminate()
497
+ # TODO: request alternative by setting <RA> for sgd
498
+ # TODO: sample new goal if goal change <GC> for sgd
499
+ goal_str = self._get_goal_str(intent)
500
+
501
+ # print("***** user status *****\n->", self.user_status, "\n")
502
+ # print("***** current intent *****\n->", intent, "\n") # BACK
503
+ # print("***** current intent constraint *****\n->", self.current_constraints[intent], "\n")
504
+ # print("***** corresponding goal str *****\n->", goal_str, "\n")
505
+ # print("***** current entities provided (empty) *****\n->", self.intent_entity_provided, "\n")
506
+ return goal_str
507
+
508
+ def _use_next_constraints(self, intent: str, slot_type: str) -> None:
509
+ """move the constraint pointer to the next"""
510
+ # Another problem is that how to decide which slot type (search or book) to add when failure?
511
+ # one solution is that dont use act mapping to keep NoOffer and NoBook separate, if so, try use nl on act
512
+ self.user_status["constraint_idx"][intent][slot_type] += 1
513
+ if self.user_status["constraint_idx"][intent][slot_type] >= self.n_max_value[slot_type]:
514
+ # TODO: ask Alex, usually how to deal with this warning case? And make it as warning rather than just print
515
+ print(
516
+ f"Failure times on {slot_type} is more than the given value candidates, \
517
+ no new value to choose as alternative"
518
+ )
519
+ print("A valid goal should not enter here!")
520
+ self.user_status["constraint_idx"][intent][slot_type] = (
521
+ self.n_max_value[slot_type] - 1
522
+ ) # let user use last values as they are supposed to be fine
523
+
524
+ def _update_current_constraints(self, intent: str, spk: str, usr_act: str, sys_act: str) -> None:
525
+ # TODO: complete instruction here
526
+ """Update current constraints used for generation based on either previous usr or sys act
527
+
528
+ :param act:
529
+ :param spk:
530
+ :param intent:
531
+ :return:
532
+ """
533
+ assert spk in ["usr", "sys"]
534
+ # act_dict = parse_act(act)
535
+ intent_constraints = self.current_constraints[intent]
536
+
537
+ if spk == "sys":
538
+ act_dict = self.parse_act(sys_act, self.print_intermediary_info)
539
+ # When the system provides information (in the act_dict) then remove it from the user's
540
+ # requestable constraints as the user has been provided with the info!
541
+ # NB: This was added by Alistair after the original code was shared by Andy,
542
+ # as it seems the original implementation missed this critical step.
543
+ if self.const_act_str["inform"] in act_dict:
544
+ for slot in act_dict[self.const_act_str["inform"]]:
545
+ if slot in intent_constraints["requestable"]:
546
+ intent_constraints["requestable"].remove(slot)
547
+ elif self.const_act_str["recommend"] in act_dict:
548
+ for slot in act_dict[self.const_act_str["recommend"]]:
549
+ if slot in intent_constraints["requestable"]:
550
+ intent_constraints["requestable"].remove(slot)
551
+
552
+ # when the system informs failure (search or book), use next set of constraints given in goal #####
553
+ # if "_NOTIFY_FAILURE_" in act_dict:
554
+ if self.const_act_str["fail_search"] in act_dict:
555
+ slot_type = self.slot_types["search"]
556
+ self._use_next_constraints(intent, slot_type)
557
+ keep_slot_types = [
558
+ self.slot_types["search"],
559
+ self.slot_types["book"],
560
+ ] # still in search phase, book slots should be kept
561
+ self._prepare_current_constraints(
562
+ [intent], keep_slot_types, if_reset_reqt=False
563
+ ) # only change constraints for this intent
564
+
565
+ elif self.const_act_str["fail_book"] in act_dict:
566
+ slot_type = self.slot_types["book"]
567
+ self._use_next_constraints(intent, slot_type)
568
+ keep_slot_types = [self.slot_types["book"]] # already found entities, no need to keep search slots
569
+ self._prepare_current_constraints([intent], keep_slot_types, if_reset_reqt=False)
570
+
571
+ # when the system request #
572
+ elif self.const_act_str["request"] in act_dict:
573
+ requested_slots = act_dict[self.const_act_str["request"]]
574
+ for slot in requested_slots.keys():
575
+ # requested slot in current constraint, do nothing
576
+ if slot in intent_constraints["informable"].keys():
577
+ continue
578
+
579
+ # slots that are beyond the current goal enter the following section
580
+ # case 1: requested slot in the complete goal,
581
+ # this should be entered if the system requests the informed slots
582
+ # if slot in self.complete_constraints[intent]["informable"].keys():
583
+ # value = self.complete_constraints[intent]["informable"][slot]
584
+ if (
585
+ slot in self.complete_goal["constraints"][intent]["informable"].keys()
586
+ ): # dict of slot to value_list
587
+ slot_type = self._get_slot_type(slot)
588
+ value_idx = self.user_status["constraint_idx"][intent][slot_type]
589
+ value = self.complete_goal["constraints"][intent]["informable"][slot][value_idx]
590
+
591
+ # case 2: requested slot not in the complete goal, set the value to "dontcare"
592
+ # can sample a new value here for more interesting interactions
593
+ else:
594
+ value = "dontcare" # "no preference" # TODO: play around to see nlg output
595
+ intent_constraints["informable"][slot] = value
596
+
597
+ else: # usr
598
+ act_dict = self.parse_act(usr_act, self.print_intermediary_info)
599
+ # remove informed slot/value pair, if informed #
600
+ if self.const_act_str["inform"] in act_dict:
601
+ for slot, value_list in act_dict[self.const_act_str["inform"]].items():
602
+ # value = value_list[0]
603
+ for value in value_list: # possible to have multi-value slots in user act in corpus
604
+ if self.finish_inform == "loose" and slot in intent_constraints["informable"]:
605
+ del intent_constraints["informable"][slot]
606
+ if (
607
+ self.finish_inform == "strict"
608
+ and slot in intent_constraints["informable"]
609
+ and value == intent_constraints["informable"][slot]
610
+ ):
611
+ del intent_constraints["informable"][slot]
612
+
613
+ # remove requested slot, if requested
614
+ if self.const_act_str["request"] in act_dict:
615
+ sys_act_dict = self.parse_act(sys_act, self.print_intermediary_info) # auxiliary check
616
+ for slot in act_dict[self.const_act_str["request"]].keys():
617
+ # if slot in intent_constraints["requestable"]: # one choice
618
+ if self.const_act_str["inform"] in sys_act_dict:
619
+ if (
620
+ slot in intent_constraints["requestable"]
621
+ and slot in sys_act_dict[self.const_act_str["inform"]].keys()
622
+ ): # another choice, more strict
623
+ intent_constraints["requestable"].remove(slot)
624
+
625
+ def _add_info(self, slot_to_value_list, slot, value) -> None:
626
+ # print(slot)
627
+ # assert slot in self.schema_slots # SLOT_FORMAT
628
+ # constraints[intent]["informable"][slot] = value
629
+ if slot not in slot_to_value_list:
630
+ slot_to_value_list[slot] = []
631
+ # assert value not in slot_to_value_list[slot]
632
+ if value not in slot_to_value_list[slot]:
633
+ slot_to_value_list[slot].append(value)
634
+ slot_type = self._get_slot_type(slot)
635
+ if len(slot_to_value_list) > self.n_max_value[slot_type]:
636
+ self.n_max_value[slot_type] = len(slot_to_value_list)
637
+
638
+ def _add_reqt(self, slot_set, slot) -> None:
639
+ # assert slot in self.schema_slots # SLOT_FORMAT
640
+ # constraints[intent]["requestable"].add(slot)
641
+ slot_set.add(slot)
642
+
643
+ def _validate_input_goal(self):
644
+ """validate the input goal"""
645
+ # TODO: finish the method
646
+ # assert all([intent in self.schema_intents for intent in intents]) # ensure intents are in schema
647
+ pass
648
+
649
+ @staticmethod
650
+ def parse_act(act_seq: str, print_intermediary_info: bool) -> dict:
651
+ """parse usr/sys act string into dict(act: {slot=value_list}) (slots in act_request have '_Empty_' value)"""
652
+ act_dict = {}
653
+ assert isinstance(act_seq, str)
654
+ act_seq = act_seq.split("<ACT/>")
655
+ for act_seg in act_seq:
656
+ if act_seg == "":
657
+ continue
658
+
659
+ act_seg = act_seg.strip() # remove space at the start/end
660
+ act_seg = act_seg.split()
661
+ # get act in special token format #
662
+ # act = act_seg[0] # e.g., _INFORM_, _REQUEST_
663
+ # assert act[0] == "_" and act[-1] == "_"
664
+ # act_seg = " ".join(act_seg[2:]) # discard first two tokens, "_ACT_ </ACT>"
665
+
666
+ # get act in natural language format #
667
+ end_idx = act_seg.index("</ACT>")
668
+ act = " ".join(act_seg[:end_idx])
669
+ act_seg = " ".join(act_seg[end_idx + 1 :]) # act arguments (slot/value pairs)
670
+ # print(f"act: {act}\n", act_seg, "\n")
671
+
672
+ assert act not in act_dict
673
+ act_dict[act] = {}
674
+
675
+ # Sometimes the model bugs out and puts <ACT/> or </ACT> where there should be </VALUE> or <VALUE/>
676
+ if "ACT" in act_seg:
677
+ continue
678
+
679
+ for sv_seg in act_seg.split("</VALUE>"):
680
+ if sv_seg == "":
681
+ continue
682
+
683
+ try:
684
+ sv_seg = sv_seg.replace("<SLOT/>", "")
685
+ sv_seg = sv_seg.strip() # remove spaces at begin and end
686
+ # print("|{}|".format(sv_seg))
687
+ slot, value = sv_seg.split("</SLOT> <VALUE/>")
688
+ slot, value = slot.strip(), value.strip()
689
+ # print("act: |{}|, slot: |{}|, value: |{}|".format(act, slot, value))
690
+ # one slot one value
691
+ # act_dict[act][slot] = value
692
+ # one slot, multi-value is possible by system
693
+ if slot not in act_dict[act]:
694
+ act_dict[act][slot] = []
695
+ if value not in act_dict[act][slot]:
696
+ act_dict[act][slot].append(value)
697
+
698
+ except Exception:
699
+ if print_intermediary_info:
700
+ print(
701
+ bcolors.YELLOW
702
+ + "!The User Agent got messed up the intermediate syntax! Exception:"
703
+ + bcolors.ENDC
704
+ )
705
+ traceback.print_exc()
706
+ continue
707
+
708
+ # print(act_dict)
709
+ return act_dict
710
+
711
+ def convert_into_system_act_format(self):
712
+ # TODO
713
+ pass
714
+
715
+ # below methods need be implemented for convlab-2 to work #
716
+ def init_session(self, **kwargs):
717
+ """Use this method to reset the agent state after each dialogue, if necessary.
718
+ This gets called before each dialogue.
719
+
720
+ Examples
721
+ --------
722
+ In `simulate_corpus_interaction.py` you will see that this is used, for example, to pass
723
+ the dialogue to the corpus agent so it knows what to talk about.
724
+
725
+ An example here would be to reset the dialogue context.
726
+ """
727
+ # dialogue goal in MultiWOZ2.1-like format
728
+ self.current_goal = kwargs.get("ini_goal", {})
729
+ self.policy.init_session(ini_goal=self.current_goal)
730
+ self.current_goal = self.policy.get_goal()
731
+ # TODO: ANYTHING ELSE THAT NEEDS TO HAPPEN BEFORE EACH DIALOGUE?
732
+ self.context = []
733
+ self.input_action = []
734
+ self.output_action = []
735
+
736
+ # init internal data
737
+ self._context_str = "" # context string with special tags used in generation
738
+ self._prev_usr_act = "" # user act string used in generation
739
+
740
+ # goal process
741
+ self.complete_goal = self._format_complete_goal(self.current_goal)
742
+ self.user_status = self._init_user_status()
743
+ self._get_scenario_str()
744
+ self.current_constraints = {} # init
745
+ self._prepare_current_constraints(
746
+ self.complete_goal["intents"],
747
+ list(self.slot_types.keys()),
748
+ if_reset_reqt=True,
749
+ )
750
+
751
+ # print("input goal:\n", self.current_goal, "\n")
752
+ # print("complete goal:\n", self.complete_goal, "\n")
753
+ # print("current constraints:\n", self.current_constraints, "\n")
754
+ # sys.exit(1)
755
+
756
+ def response(self, sys_utterance: str) -> str:
757
+ """Generate natural language response given the system response.
758
+
759
+ Parameters
760
+ ---------
761
+ sys_utterance
762
+ Last system utterance. For first turn, sys_utterance is the empty string.
763
+
764
+ Returns
765
+ -------
766
+ response
767
+ A natural language response.
768
+
769
+ """
770
+
771
+ # TODO: MAKE APPROPRIATE USE OF THE HISTORY, BEHAVIOUR_PARAMS, CURRENT_GOAL, UPDATE_GOAL TO GENERATE A RESPONSE
772
+ # TODO: DON'T FORGET TO UPDATE INPUT AND OUTPUT ACTIONS STATES
773
+ # response = "I want Italian."
774
+ gen_parse, gen_str = self.generate_whole_sequence(sys_utterance)
775
+ self.update_internal_data(gen_parse) # prepare for next turn
776
+ if self.print_intermediary_info:
777
+ segment_gen(gen_str, "example dialogue") # crazyusermodel
778
+ # TODO: update lists of context, da_in, da_out here
779
+ return gen_parse["USR_UTT"]
780
+
781
+ def get_in_da(self) -> List[List[str]]:
782
+ """Used by clients to retrieve the user model NLU.
783
+
784
+ Returns
785
+ -------
786
+ NLU output, assumed to be a list of lists, each formatted as::
787
+
788
+ [[intention, domain, slot, value], ...]
789
+
790
+ Here ``intention`` refers to a dialogue act and the ``intention``, ``domain`` and ``slot`` strings should
791
+ follow the same convention as the corpus dialogue act annotations (i.e., capitalised, and using the correct
792
+ set of slot names).
793
+ """
794
+ return self.input_action
795
+
796
+ def get_out_da(self) -> List[List[str]]:
797
+ """Used by clients to retrieve the user model policy output.
798
+
799
+ Returns
800
+ -------
801
+ Policy output, following the same convention as the NLU output.
802
+ """
803
+ return self.output_action
804
+
805
+ def get_reward(self) -> float:
806
+ """Dummy method, used for API consistency."""
807
+ return -1
808
+
809
+ def is_terminated(self) -> bool:
810
+ """This should tell an external client whether the user model considers they have completed the task."""
811
+ # return False
812
+ return self.user_status["dialogue_terminate"]
813
+
814
+
815
+ def parse_complete_gen(gen):
816
+ """parse the complete generation output, return predictions of system act, user act and user utterance"""
817
+ output = {}
818
+ for key in ["SYS_ACT", "SNT", "GC", "RA", "USR_ACT", "USR_UTT"]:
819
+ value = find_segment(gen, key)
820
+ output[key] = value
821
+ # print("***** complete generation output *****\n->", gen, "\n") # BACK
822
+ # print("***** parse output *****\n->", output, "\n")
823
+ return output
824
+
825
+
826
+ def generate_example_goal() -> dict:
827
+ """create an example goal for testing"""
828
+ # {service: service_meta},
829
+ # service_mate: {"info": {slot: value}, "fail_info": {slot: value},
830
+ # "book": {slot}: value, "fail_book": {slot: value}, "reqt": set(slot)}
831
+ goal = {}
832
+ services = ["restaurant", "hotel"]
833
+ # services = ["train", "attraction"]
834
+ # services = ["restaurant"]
835
+
836
+ # # restaurant
837
+ service = services[0]
838
+ goal[service] = {}
839
+ goal[service]["fail_info"] = {"food": "eastern european", "area": "south", "price range": "expensive"}
840
+ goal[service]["info"] = {"food": "chinese", "area": "south", "price range": "cheap"}
841
+ goal[service]["fail_book"] = {}
842
+ goal[service]["book"] = {"book day": "monday", "book people": "8", "book time": "13:15"}
843
+ goal[service]["reqt"] = {"address": "?"}
844
+
845
+ # hotel
846
+ service = services[1]
847
+ goal[service] = {}
848
+ goal[service]["fail_info"] = {"stars": "3", "price range": "cheap", "area": "centre", "internet": "_True_"}
849
+ goal[service]["info"] = {"stars": "5", "price range": "expensive", "area": "centre", "internet": "_True_"}
850
+ goal[service]["fail_book"] = {"book day": "sunday", "book stay": 3, "book people": 2}
851
+ goal[service]["book"] = {"book day": "monday", "book stay": 1, "book people": 2}
852
+ goal[service]["reqt"] = {"phone": "?", "postcode": "?"}
853
+
854
+ # # train
855
+ # service = services[1]
856
+ # goal[service] = {}
857
+ # goal[service]["info"] = {
858
+ # "destination": "ely",
859
+ # "day": "monday",
860
+ # "arrive by": "19:00",
861
+ # "departure": "cambridge",
862
+ # "book people": "8"
863
+ # }
864
+ # goal[service]["reqt"] = {"duration": "?", "leave at": "?", "train id": "?"}
865
+
866
+ # # attraction
867
+ # service = services[1]
868
+ # goal[service] = {}
869
+ # goal[service]["info"] = {
870
+ # "type": "college",
871
+ # "area": "west"
872
+ # }
873
+ # goal[service]["reqt"] = {"phone": "?", "postcode": "?"}
874
+
875
+ # taxi
876
+ # service = services[0]
877
+ # goal[service] = {}
878
+ # goal[service]["info"] = {
879
+ # "arrive by": "17:30",
880
+ # "departure": "city stop restaurant",
881
+ # "destination": "the cambridge punter"
882
+ # }
883
+ # goal[service]["reqt"] = {"phone": "?", "type": "?"}
884
+ # more services...
885
+ return goal
886
+
887
+
888
+ def set_sorted_services_for_current_goal(goal, goal_idx, df_raw_mwoz):
889
+ # Get the list of services in the goal as they appear in the data so they can be processed correctly
890
+
891
+ current_dialogue_services = []
892
+ for service_name in goal:
893
+ current_dialogue_services.append(service_name)
894
+
895
+ message = df_raw_mwoz.iloc[:, goal_idx].goal["message"]
896
+
897
+ ordered_current_dialogue_services = []
898
+
899
+ for instruction in message:
900
+ instruction_split = re.split(" |<|>", instruction)
901
+ for word in instruction_split:
902
+ if word in current_dialogue_services:
903
+ ordered_current_dialogue_services.append(word)
904
+ current_dialogue_services.remove(word)
905
+
906
+ # Make sure any words not mentioned in the message (e.g. it happens for 'police' in the second goal) are n#ot missed
907
+ for word in current_dialogue_services:
908
+ if word not in ordered_current_dialogue_services:
909
+ ordered_current_dialogue_services.append(word)
910
+
911
+ return ordered_current_dialogue_services
912
+
913
+
914
+ def read_multiWOZ_20_goals(file_path, n_goals):
915
+ df_raw_mwoz = pd.read_json(file_path)
916
+
917
+ goals = []
918
+ for i in range(n_goals):
919
+ parsed_goal = {}
920
+ goal = df_raw_mwoz.iloc[:, i].goal
921
+
922
+ # Determine relevant keys
923
+ for _ in goal.keys():
924
+ relevant_goals = {k: v for k, v in goal.items() if v != {} and k != "topic" and k != "message"}
925
+ services = [key for key in relevant_goals.keys()]
926
+ for service in services:
927
+ parsed_goal[service] = relevant_goals[service]
928
+
929
+ ordered_services = set_sorted_services_for_current_goal(parsed_goal, i, df_raw_mwoz)
930
+ parsed_goal["ordered_services"] = ordered_services
931
+
932
+ # Update the format of those relevant keys to match the format of this code
933
+ for service in parsed_goal.keys():
934
+ if service == "ordered_services":
935
+ continue
936
+
937
+ for service_key, service_value in parsed_goal[service].items():
938
+
939
+ # Handle 'reqt' key which is a list. Convert it to a dict. (and do the same for similar keys).
940
+ if type(parsed_goal[service][service_key]) is list and parsed_goal[service][service_key] != []:
941
+ replacement_dict = {}
942
+ for item in parsed_goal[service][service_key]:
943
+ replacement_dict[item] = "?"
944
+ parsed_goal[service][service_key] = replacement_dict
945
+
946
+ # Handle 'hotel' key which has a string value
947
+ # with the name of hotel - or other similar situations
948
+ elif type(parsed_goal[service][service_key]) is str:
949
+ continue
950
+
951
+ # Make sure the dictionary we are adding is not empty
952
+ if not parsed_goal[service][service_key]:
953
+ continue
954
+
955
+ # Remove any attributes that are "invalid" or "preinvalid"
956
+ # Also check if 'arriveBy' or 'leaveAt' or 'pricerange' is inside the attributes of the service_key
957
+ # If so reformat it according to the code in this file
958
+
959
+ list_of_attribute_keys = [k for k in parsed_goal[service][service_key].keys()]
960
+ for k in list_of_attribute_keys:
961
+ if k == "invalid" or k == "pre_invalid":
962
+ parsed_goal[service][service_key].pop(k)
963
+ if k == "arriveBy":
964
+ parsed_goal[service][service_key]["arrive by"] = parsed_goal[service][service_key].pop(k)
965
+ elif k == "leaveAt":
966
+ parsed_goal[service][service_key]["leave at"] = parsed_goal[service][service_key].pop(k)
967
+ elif k == "pricerange":
968
+ parsed_goal[service][service_key]["price range"] = parsed_goal[service][service_key].pop(k)
969
+ elif k == "car type":
970
+ parsed_goal[service][service_key]["type"] = parsed_goal[service][service_key].pop(k)
971
+ elif k == "trainID":
972
+ parsed_goal[service][service_key]["train id"] = parsed_goal[service][service_key].pop(k)
973
+
974
+ # Check if "book" is in the service info dict ("book" or "fail_book") then prepend
975
+ # 'book' to the keys inside the service (the attributes of the service)
976
+ if "book" in service_key:
977
+ list_of_attribute_keys = [k for k in parsed_goal[service][service_key].keys()]
978
+ for k in list_of_attribute_keys:
979
+ parsed_goal[service][service_key]["book {}".format(k)] = parsed_goal[service][service_key].pop(
980
+ k
981
+ )
982
+
983
+ # If True or False is in service.values, convert to "_True_" or "_False_"
984
+ for k, v in parsed_goal[service][service_key].items():
985
+ if v is True:
986
+ parsed_goal[service][service_key][k] = "_True_"
987
+ elif v is False:
988
+ parsed_goal[service][service_key][k] = "_False_"
989
+
990
+ goals.append(parsed_goal)
991
+
992
+ return goals
993
+
994
+
995
+ def interact(checkpoint_path):
996
+ user_model = NeuralAgent("user", checkpoint_path, "scripts/user_model_code/interaction/config.yaml")
997
+
998
+ # TODO: fix the hardcoded variables here
999
+ file_path = "data/raw/UBAR/multi-woz/data.json"
1000
+ user_model.print_intermediary_info = True
1001
+ n_goals = 50
1002
+
1003
+ for dialogue_number, goal in enumerate(read_multiWOZ_20_goals(file_path, n_goals)):
1004
+ try:
1005
+ # goal = generate_example_goal()
1006
+ user_model.init_session(ini_goal=goal)
1007
+ sys_utt = ""
1008
+
1009
+ for turn_id in range(100):
1010
+ user_model.response(sys_utt)
1011
+
1012
+ if user_model.is_terminated():
1013
+ print("Dialogue terminates!")
1014
+ break
1015
+
1016
+ # next turn materials
1017
+ sys_utt = input("Enter system response here: ")
1018
+ if sys_utt == "Goodbye":
1019
+ break
1020
+
1021
+ except Exception:
1022
+ print("Error in dialogue {}".format(dialogue_number))
1023
+ traceback.print_exc()
1024
+ continue
1025
+
1026
+
1027
+ if __name__ == "__main__":
1028
+ if len(sys.argv) == 1:
1029
+ print("Wrong argument!")
1030
+ print("Usage: python multiwoz_interact.py checkpoint_path")
1031
+ sys.exit(1)
1032
+
1033
+ checkpoint_path = sys.argv[1]
1034
+ interact(checkpoint_path)
scripts/user_model_code/interaction/schema.json ADDED
@@ -0,0 +1,712 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "service_name": "hotel",
4
+ "slots": [
5
+ {
6
+ "name": "hotel-pricerange",
7
+ "description": "price budget of the hotel",
8
+ "possible_values": [
9
+ "expensive",
10
+ "cheap",
11
+ "moderate"
12
+ ],
13
+ "is_categorical": true
14
+ },
15
+ {
16
+ "name": "hotel-type",
17
+ "description": "what is the type of the hotel",
18
+ "possible_values": [
19
+ "guesthouse",
20
+ "hotel"
21
+ ],
22
+ "is_categorical": true
23
+ },
24
+ {
25
+ "name": "hotel-parking",
26
+ "description": "whether the hotel has parking",
27
+ "possible_values": [
28
+ "free",
29
+ "no",
30
+ "yes"
31
+ ],
32
+ "is_categorical": true
33
+ },
34
+ {
35
+ "name": "hotel-bookday",
36
+ "description": "day of the hotel booking",
37
+ "possible_values": [
38
+ "monday",
39
+ "tuesday",
40
+ "wednesday",
41
+ "thursday",
42
+ "friday",
43
+ "saturday",
44
+ "sunday"
45
+ ],
46
+ "is_categorical": true
47
+ },
48
+ {
49
+ "name": "hotel-bookpeople",
50
+ "description": "number of people for the hotel booking",
51
+ "possible_values": [
52
+ "1",
53
+ "2",
54
+ "3",
55
+ "4",
56
+ "5",
57
+ "6",
58
+ "7",
59
+ "8"
60
+ ],
61
+ "is_categorical": true
62
+ },
63
+ {
64
+ "name": "hotel-bookstay",
65
+ "description": "length of stay at the hotel",
66
+ "possible_values": [
67
+ "1",
68
+ "2",
69
+ "3",
70
+ "4",
71
+ "5",
72
+ "6",
73
+ "7",
74
+ "8"
75
+ ],
76
+ "is_categorical": true
77
+ },
78
+ {
79
+ "name": "hotel-stars",
80
+ "description": "star rating of the hotel",
81
+ "possible_values": [
82
+ "0",
83
+ "1",
84
+ "2",
85
+ "3",
86
+ "4",
87
+ "5"
88
+ ],
89
+ "is_categorical": true
90
+ },
91
+ {
92
+ "name": "hotel-internet",
93
+ "description": "whether the hotel has internet",
94
+ "possible_values": [
95
+ "free",
96
+ "no",
97
+ "yes"
98
+ ],
99
+ "is_categorical": true
100
+ },
101
+ {
102
+ "name": "hotel-name",
103
+ "description": "name of the hotel",
104
+ "possible_values": [],
105
+ "is_categorical": false
106
+ },
107
+ {
108
+ "name": "hotel-area",
109
+ "description": "area or place of the hotel",
110
+ "possible_values": [
111
+ "centre",
112
+ "east",
113
+ "north",
114
+ "south",
115
+ "west"
116
+ ],
117
+ "is_categorical": true
118
+ },
119
+ {
120
+ "name": "hotel-address",
121
+ "description": "address of the hotel",
122
+ "is_categorical": false
123
+ },
124
+ {
125
+ "name": "hotel-phone",
126
+ "description": "phone number of the hotel",
127
+ "is_categorical": false
128
+ },
129
+ {
130
+ "name": "hotel-postcode",
131
+ "description": "postal code of the hotel",
132
+ "is_categorical": false
133
+ },
134
+ {
135
+ "name": "hotel-ref",
136
+ "description": "reference number of the hotel booking",
137
+ "is_categorical": false
138
+ }
139
+ ],
140
+ "description": "hotel reservations and vacation stays",
141
+ "intents": [
142
+ {
143
+ "name": "find_hotel",
144
+ "description": "search for a hotel to stay in",
145
+ "is_transactional": false,
146
+ "required_slots": [],
147
+ "optional_slots": {
148
+ "hotel-pricerange": "dontcare",
149
+ "hotel-type": "dontcare",
150
+ "hotel-parking": "dontcare",
151
+ "hotel-bookday": "dontcare",
152
+ "hotel-bookpeople": "dontcare",
153
+ "hotel-bookstay": "dontcare",
154
+ "hotel-stars": "dontcare",
155
+ "hotel-internet": "dontcare",
156
+ "hotel-name": "dontcare",
157
+ "hotel-area": "dontcare"
158
+ }
159
+ },
160
+ {
161
+ "name": "book_hotel",
162
+ "description": "book a hotel to stay in",
163
+ "is_transactional": true,
164
+ "required_slots": [],
165
+ "optional_slots": {
166
+ "hotel-pricerange": "dontcare",
167
+ "hotel-type": "dontcare",
168
+ "hotel-parking": "dontcare",
169
+ "hotel-bookday": "dontcare",
170
+ "hotel-bookpeople": "dontcare",
171
+ "hotel-bookstay": "dontcare",
172
+ "hotel-stars": "dontcare",
173
+ "hotel-internet": "dontcare",
174
+ "hotel-name": "dontcare",
175
+ "hotel-area": "dontcare"
176
+ }
177
+ }
178
+ ]
179
+ },
180
+ {
181
+ "service_name": "train",
182
+ "slots": [
183
+ {
184
+ "name": "train-arriveby",
185
+ "description": "arrival time of the train",
186
+ "possible_values": [],
187
+ "is_categorical": false
188
+ },
189
+ {
190
+ "name": "train-departure",
191
+ "description": "departure location of the train",
192
+ "possible_values": [
193
+ "birmingham new street",
194
+ "bishops stortford",
195
+ "broxbourne",
196
+ "cambridge",
197
+ "ely",
198
+ "kings lynn",
199
+ "leicester",
200
+ "london kings cross",
201
+ "london liverpool street",
202
+ "norwich",
203
+ "peterborough",
204
+ "stansted airport",
205
+ "stevenage"
206
+ ],
207
+ "is_categorical": true
208
+ },
209
+ {
210
+ "name": "train-day",
211
+ "description": "day of the train",
212
+ "possible_values": [
213
+ "monday",
214
+ "tuesday",
215
+ "wednesday",
216
+ "thursday",
217
+ "friday",
218
+ "saturday",
219
+ "sunday"
220
+ ],
221
+ "is_categorical": true
222
+ },
223
+ {
224
+ "name": "train-bookpeople",
225
+ "description": "how many train tickets you need",
226
+ "possible_values": [
227
+ "0",
228
+ "1",
229
+ "2",
230
+ "3",
231
+ "4",
232
+ "5",
233
+ "6",
234
+ "7",
235
+ "8",
236
+ "9",
237
+ "10",
238
+ "15"
239
+ ],
240
+ "is_categorical": true
241
+ },
242
+ {
243
+ "name": "train-leaveat",
244
+ "description": "leaving time for the train",
245
+ "possible_values": [],
246
+ "is_categorical": false
247
+ },
248
+ {
249
+ "name": "train-destination",
250
+ "description": "destination of the train",
251
+ "possible_values": [
252
+ "birmingham new street",
253
+ "bishops stortford",
254
+ "broxbourne",
255
+ "cambridge",
256
+ "ely",
257
+ "kings lynn",
258
+ "leicester",
259
+ "london kings cross",
260
+ "london liverpool street",
261
+ "norwich",
262
+ "peterborough",
263
+ "stansted airport",
264
+ "stevenage"
265
+ ],
266
+ "is_categorical": true
267
+ },
268
+ {
269
+ "name": "train-trainid",
270
+ "description": "id of the train",
271
+ "is_categorical": false
272
+ },
273
+ {
274
+ "name": "train-ref",
275
+ "description": "reference number of the train booking",
276
+ "is_categorical": false
277
+ },
278
+ {
279
+ "name": "train-price",
280
+ "description": "price of the train",
281
+ "is_categorical": false
282
+ },
283
+ {
284
+ "name": "train-duration",
285
+ "description": "duration of the travel",
286
+ "is_categorical": false
287
+ }
288
+ ],
289
+ "description": "find trains that take you to places",
290
+ "intents": [
291
+ {
292
+ "name": "find_train",
293
+ "description": "search for trains that take you places",
294
+ "is_transactional": false,
295
+ "required_slots": [],
296
+ "optional_slots": {
297
+ "train-destination": "dontcare",
298
+ "train-arriveby": "dontcare",
299
+ "train-departure": "dontcare",
300
+ "train-day": "dontcare",
301
+ "train-bookpeople": "dontcare",
302
+ "train-leaveat": "dontcare"
303
+ }
304
+ },
305
+ {
306
+ "name": "book_train",
307
+ "description": "book train tickets",
308
+ "is_transactional": true,
309
+ "required_slots": [],
310
+ "optional_slots": {
311
+ "train-destination": "dontcare",
312
+ "train-arriveby": "dontcare",
313
+ "train-departure": "dontcare",
314
+ "train-day": "dontcare",
315
+ "train-bookpeople": "dontcare",
316
+ "train-leaveat": "dontcare"
317
+ }
318
+ }
319
+ ]
320
+ },
321
+ {
322
+ "service_name": "attraction",
323
+ "slots": [
324
+ {
325
+ "name": "attraction-area",
326
+ "description": "area to search for attractions",
327
+ "possible_values": [
328
+ "centre",
329
+ "east",
330
+ "north",
331
+ "south",
332
+ "west"
333
+ ],
334
+ "is_categorical": true
335
+ },
336
+ {
337
+ "name": "attraction-name",
338
+ "description": "name of the attraction",
339
+ "possible_values": [],
340
+ "is_categorical": false
341
+ },
342
+ {
343
+ "name": "attraction-type",
344
+ "description": "type of the attraction",
345
+ "possible_values": [
346
+ "architecture",
347
+ "boat",
348
+ "cinema",
349
+ "college",
350
+ "concerthall",
351
+ "entertainment",
352
+ "museum",
353
+ "multiple sports",
354
+ "nightclub",
355
+ "park",
356
+ "swimmingpool",
357
+ "theatre"
358
+ ],
359
+ "is_categorical": true
360
+ },
361
+ {
362
+ "name": "attraction-entrancefee",
363
+ "description": "how much is the entrance fee",
364
+ "is_categorical": false
365
+ },
366
+ {
367
+ "name": "attraction-openhours",
368
+ "description": "open hours of the attraction",
369
+ "is_categorical": false
370
+ },
371
+ {
372
+ "name": "attraction-address",
373
+ "description": "address of the attraction",
374
+ "is_categorical": false
375
+ },
376
+ {
377
+ "name": "attraction-phone",
378
+ "description": "phone number of the attraction",
379
+ "is_categorical": false
380
+ },
381
+ {
382
+ "name": "attraction-postcode",
383
+ "description": "postal code of the attraction",
384
+ "is_categorical": false
385
+ }
386
+ ],
387
+ "description": "find touristy stuff to do around you",
388
+ "intents": [
389
+ {
390
+ "name": "find_attraction",
391
+ "description": "search for places to see for leisure",
392
+ "is_transactional": false,
393
+ "required_slots": [],
394
+ "optional_slots": {
395
+ "attraction-area": "dontcare",
396
+ "attraction-name": "dontcare",
397
+ "attraction-type": "dontcare"
398
+ }
399
+ }
400
+ ]
401
+ },
402
+ {
403
+ "service_name": "restaurant",
404
+ "slots": [
405
+ {
406
+ "name": "restaurant-pricerange",
407
+ "description": "price budget for the restaurant",
408
+ "possible_values": [
409
+ "cheap",
410
+ "expensive",
411
+ "moderate"
412
+ ],
413
+ "is_categorical": true
414
+ },
415
+ {
416
+ "name": "restaurant-area",
417
+ "description": "area or place of the restaurant",
418
+ "possible_values": [
419
+ "centre",
420
+ "east",
421
+ "north",
422
+ "south",
423
+ "west"
424
+ ],
425
+ "is_categorical": true
426
+ },
427
+ {
428
+ "name": "restaurant-food",
429
+ "description": "the cuisine of the restaurant you are looking for",
430
+ "is_categorical": false
431
+ },
432
+ {
433
+ "name": "restaurant-name",
434
+ "description": "name of the restaurant",
435
+ "possible_values": [],
436
+ "is_categorical": false
437
+ },
438
+ {
439
+ "name": "restaurant-bookday",
440
+ "description": "day of the restaurant booking",
441
+ "possible_values": [
442
+ "monday",
443
+ "tuesday",
444
+ "wednesday",
445
+ "thursday",
446
+ "friday",
447
+ "saturday",
448
+ "sunday"
449
+ ],
450
+ "is_categorical": true
451
+ },
452
+ {
453
+ "name": "restaurant-bookpeople",
454
+ "description": "how many people for the restaurant reservation",
455
+ "possible_values": [
456
+ "1",
457
+ "2",
458
+ "3",
459
+ "4",
460
+ "5",
461
+ "6",
462
+ "7",
463
+ "8"
464
+ ],
465
+ "is_categorical": true
466
+ },
467
+ {
468
+ "name": "restaurant-booktime",
469
+ "description": "time of the restaurant booking",
470
+ "possible_values": [],
471
+ "is_categorical": false
472
+ },
473
+ {
474
+ "name": "restaurant-address",
475
+ "description": "address of the restaurant",
476
+ "is_categorical": false
477
+ },
478
+ {
479
+ "name": "restaurant-phone",
480
+ "description": "phone number of the restaurant",
481
+ "is_categorical": false
482
+ },
483
+ {
484
+ "name": "restaurant-postcode",
485
+ "description": "postal code of the restaurant",
486
+ "is_categorical": false
487
+ },
488
+ {
489
+ "name": "restaurant-ref",
490
+ "description": "reference number of the restaurant booking",
491
+ "is_categorical": false
492
+ }
493
+ ],
494
+ "description": "find places to dine and whet your appetite",
495
+ "intents": [
496
+ {
497
+ "name": "find_restaurant",
498
+ "description": "search for places to wine and dine",
499
+ "is_transactional": false,
500
+ "required_slots": [],
501
+ "optional_slots": {
502
+ "restaurant-pricerange": "dontcare",
503
+ "restaurant-area": "dontcare",
504
+ "restaurant-food": "dontcare",
505
+ "restaurant-name": "dontcare",
506
+ "restaurant-bookday": "dontcare",
507
+ "restaurant-bookpeople": "dontcare",
508
+ "restaurant-booktime": "dontcare"
509
+ }
510
+ },
511
+ {
512
+ "name": "book_restaurant",
513
+ "description": "book a table at a restaurant",
514
+ "is_transactional": true,
515
+ "required_slots": [],
516
+ "optional_slots": {
517
+ "restaurant-pricerange": "dontcare",
518
+ "restaurant-area": "dontcare",
519
+ "restaurant-food": "dontcare",
520
+ "restaurant-name": "dontcare",
521
+ "restaurant-bookday": "dontcare",
522
+ "restaurant-bookpeople": "dontcare",
523
+ "restaurant-booktime": "dontcare"
524
+ }
525
+ }
526
+ ]
527
+ },
528
+ {
529
+ "service_name": "hospital",
530
+ "slots": [
531
+ {
532
+ "name": "hospital-department",
533
+ "description": "type of medical care",
534
+ "possible_values": [],
535
+ "is_categorical": false
536
+ },
537
+ {
538
+ "name": "hospital-address",
539
+ "description": "address of the hospital",
540
+ "is_categorical": false
541
+ },
542
+ {
543
+ "name": "hospital-phone",
544
+ "description": "phone number of the hospital",
545
+ "is_categorical": false
546
+ },
547
+ {
548
+ "name": "hospital-postcode",
549
+ "description": "postal code of the hospital",
550
+ "is_categorical": false
551
+ }
552
+ ],
553
+ "description": "making you feel better when you are ill",
554
+ "intents": [
555
+ {
556
+ "name": "find_hospital",
557
+ "description": "search for a medical facility or a doctor",
558
+ "is_transactional": false,
559
+ "required_slots": [],
560
+ "optional_slots": {
561
+ "hospital-department": "dontcare"
562
+ }
563
+ }
564
+ ]
565
+ },
566
+ {
567
+ "service_name": "taxi",
568
+ "slots": [
569
+ {
570
+ "name": "taxi-leaveat",
571
+ "description": "leaving time of taxi",
572
+ "possible_values": [],
573
+ "is_categorical": false
574
+ },
575
+ {
576
+ "name": "taxi-destination",
577
+ "description": "destination of taxi",
578
+ "possible_values": [],
579
+ "is_categorical": false
580
+ },
581
+ {
582
+ "name": "taxi-departure",
583
+ "description": "departure location of taxi",
584
+ "possible_values": [],
585
+ "is_categorical": false
586
+ },
587
+ {
588
+ "name": "taxi-arriveby",
589
+ "description": "arrival time of taxi",
590
+ "possible_values": [],
591
+ "is_categorical": false
592
+ },
593
+ {
594
+ "name": "taxi-type",
595
+ "description": "car type of the taxi",
596
+ "is_categorical": false
597
+ },
598
+ {
599
+ "name": "taxi-phone",
600
+ "description": "phone number of the taxi",
601
+ "is_categorical": false
602
+ }
603
+ ],
604
+ "description": "rent cheap cabs to avoid traffic",
605
+ "intents": [
606
+ {
607
+ "name": "book_taxi",
608
+ "description": "book taxis to travel between places",
609
+ "is_transactional": true,
610
+ "required_slots": [],
611
+ "optional_slots": {
612
+ "taxi-leaveat": "dontcare",
613
+ "taxi-destination": "dontcare",
614
+ "taxi-departure": "dontcare",
615
+ "taxi-arriveby": "dontcare"
616
+ }
617
+ }
618
+ ]
619
+ },
620
+ {
621
+ "service_name": "bus",
622
+ "slots": [
623
+ {
624
+ "name": "bus-departure",
625
+ "description": "departure location of bus",
626
+ "possible_values": [
627
+ "cambridge"
628
+ ],
629
+ "is_categorical": false
630
+ },
631
+ {
632
+ "name": "bus-destination",
633
+ "description": "destination of bus",
634
+ "possible_values": [
635
+ "london kings cross",
636
+ "bishops stortford",
637
+ "cambridge",
638
+ "kohinoor"
639
+ ],
640
+ "is_categorical": false
641
+ },
642
+ {
643
+ "name": "bus-leaveat",
644
+ "description": "leaving time of bus",
645
+ "is_categorical": false
646
+ },
647
+ {
648
+ "name": "bus-day",
649
+ "description": "day to use the bus tickets",
650
+ "possible_values": [
651
+ "wednesday"
652
+ ],
653
+ "is_categorical": true
654
+ }
655
+ ],
656
+ "description": "bus service for traveling",
657
+ "intents": [
658
+ {
659
+ "name": "find_bus",
660
+ "description": "search for a bus",
661
+ "is_transactional": false,
662
+ "required_slots": [],
663
+ "optional_slots": {
664
+ "bus-departure": "dontcare",
665
+ "bus-destination": "dontcare",
666
+ "bus-day": "dontcare",
667
+ "bus-leaveat": "dontcare"
668
+ }
669
+ }
670
+ ]
671
+ },
672
+ {
673
+ "service_name": "police",
674
+ "slots": [
675
+ {
676
+ "name": "police-address",
677
+ "description": "address of the police station",
678
+ "is_categorical": false
679
+ },
680
+ {
681
+ "name": "police-phone",
682
+ "description": "phone number of the police station",
683
+ "is_categorical": false
684
+ },
685
+ {
686
+ "name": "police-postcode",
687
+ "description": "postal code of the police station",
688
+ "is_categorical": false
689
+ },
690
+ {
691
+ "name": "police-name",
692
+ "description": "name of the police station",
693
+ "possible_values": [
694
+ "parkside police station"
695
+ ],
696
+ "is_categorical": true
697
+ }
698
+ ],
699
+ "description": "police station",
700
+ "intents": [
701
+ {
702
+ "name": "police",
703
+ "description": "search for police station",
704
+ "is_transactional": false,
705
+ "required_slots": [],
706
+ "optional_slots": {
707
+ "police-name": "dontcare"
708
+ }
709
+ }
710
+ ]
711
+ }
712
+ ]
scripts/user_model_code/interaction/utils.py ADDED
@@ -0,0 +1,308 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+
4
+
5
+ def segment_gen(gen, dial_id):
6
+ def _color(_segment):
7
+ if tag == "CTX":
8
+ _segment = _segment.replace(" </USR>", f"{bcolors.ENDC}")
9
+ _segment = _segment.replace(" </SYS>", f"{bcolors.ENDC}")
10
+ _segment = _segment.replace("<USR/> ", f"USR: {bcolors.OKCYAN}")
11
+ _segment = _segment.replace("<SYS/> ", f"SYS: {bcolors.OKBLUE}")
12
+ if tag == "SYS_UTT":
13
+ _segment = f"{bcolors.OKBLUE}" + _segment + f"{bcolors.ENDC}"
14
+ if tag == "USR_UTT":
15
+ _segment = f"{bcolors.OKCYAN}" + _segment + f"{bcolors.ENDC}"
16
+ if tag in ["SYS_ACT", "USR_ACT", "GOAL"]:
17
+ _segment = _segment.replace("<ACT/> ", f"{bcolors.RED}")
18
+ _segment = _segment.replace(" </ACT>", f"{bcolors.ENDC}")
19
+ _segment = _segment.replace("<SLOT/> ", f"{bcolors.YELLOW}")
20
+ _segment = _segment.replace(" </SLOT>", f"{bcolors.ENDC}")
21
+ _segment = _segment.replace("<VALUE/> ", f"{bcolors.GREEN}")
22
+ _segment = _segment.replace(" </VALUE>", f"{bcolors.ENDC}")
23
+ if tag == "GOAL":
24
+ _segment = _segment.replace(
25
+ "<SCENARIO/>", f"<SCENARIO/>{bcolors.UNDERLINE}"
26
+ )
27
+ _segment = _segment.replace("</SCENARIO>", f"{bcolors.ENDC}</SCENARIO>")
28
+ _segment = _segment.replace("<TASK/>", f"<TASK/>{bcolors.UNDERLINE}")
29
+ _segment = _segment.replace("</TASK>", f"{bcolors.ENDC}</TASK>")
30
+ # if tag in ["SNT", "GC"]:
31
+ # segment = segment.replace("<{}/> ".format(tag), "<{}/> *".format(tag))
32
+ # segment = segment.replace(" </{}>".format(tag), "* <{}/>".format(tag))
33
+ return _segment
34
+
35
+ assert isinstance(gen, str)
36
+ # gen = gen.split()
37
+ # print(gen)
38
+ print("*** Dial_id: {} ***".format(dial_id))
39
+ for tag in [
40
+ "CTX",
41
+ "SYS_UTT",
42
+ "SYS_ACT",
43
+ "GOAL",
44
+ "SNT",
45
+ "RA",
46
+ "GC",
47
+ "USR_ACT",
48
+ "USR_UTT",
49
+ ]:
50
+ segment = find_segment(gen, tag)
51
+ if segment is not None:
52
+ print('{} -> "{}"'.format(tag, _color(segment)))
53
+ else:
54
+ print("Fail to find the segment...")
55
+ print("GEN:", gen)
56
+ print("---" * 30)
57
+
58
+
59
+ # input("press...")
60
+
61
+
62
+ def get_original_act_set():
63
+ # full act vocab:
64
+ # https://github.com/ConvLab/ConvLab/blob/master/data/multiwoz/annotation/Multiwoz%20data%20analysis.md#dialog-act
65
+ acts = set()
66
+ acts.add("Inform")
67
+ acts.add("Request")
68
+ acts.add(
69
+ "NoOffer"
70
+ ) # equivalent to the concept of `no matching`, `cannot find` in database
71
+ acts.add("Recommend")
72
+ acts.add("Select")
73
+ acts.add(
74
+ "OfferBook"
75
+ ) # only for `train` domain, ask if book is needed, equivalent to `Booking-Inform` with [[none, none]]
76
+ # args in restaurant/hotel domain
77
+ acts.add(
78
+ "OfferBooked"
79
+ ) # only for `train` domain, inform booking is complete, with corresponding info (such as ref number)
80
+ acts.add("Book") # inform booking is successful, equivalent to `OfferBooked` above
81
+ acts.add(
82
+ "NoBook"
83
+ ) # inform booking fails, might because of no availability, usually come together act `request`
84
+ acts.add("bye")
85
+ acts.add("greet")
86
+ acts.add("reqmore")
87
+ acts.add("welcome")
88
+ acts.add("thank")
89
+ return acts
90
+
91
+
92
+ def get_act_natural_language(act):
93
+ if act in ["bye", "greet", "reqmore", "welcome", "thank"]:
94
+ return act
95
+
96
+ assert act[0].isupper()
97
+ tokens = re.findall("[A-Z][^A-Z]*", act) # e.g., `FindEvents` -> `Find Events`
98
+ tokens = list(map(str.lower, tokens)) # lower case, -> `find events`
99
+ act_nl = " ".join(tokens)
100
+ return act_nl
101
+
102
+
103
+ def convert_act_into_sgd(act, SPECIAL_TOKENS):
104
+ # TODO: check inference result to see if mapping on NoOffer, OfferBook and NoBook are fine
105
+ """
106
+ convert multiwoz acts (w/o domain info) into sgd acts ensure that acts with same concept use one name
107
+ e.g., Book (OfferBooked) -> NOTIFY_SUCCESS, NoBook -> NOTIFY_FAILURE
108
+ """
109
+ if act == "NoOffer":
110
+ act = "NOTIFY_FAILURE"
111
+
112
+ elif act == "Recommend":
113
+ act = "OFFER"
114
+
115
+ # technically, `OfferBook` is equivalent to (`act=OFFER_INTENT, slot=intent, value=ReserveRestaurant`)
116
+ # on system side in sgd since (1) the conversion is not trivial (completely different representations)
117
+ # and (2) multiwoz has no slot called `intent` one cannot simply convert `OfferBook` to `OFFER_INTENT`
118
+ # we thus keep the act as is
119
+ # note that there is no slot `intent` and value conveying intents in multiwoz
120
+ elif act == "OfferBook":
121
+ act = "Offer_Book"
122
+
123
+ elif act == "OfferBooked":
124
+ act = "NOTIFY_SUCCESS"
125
+
126
+ elif act == "Book": # same as `OfferBooked`
127
+ act = "NOTIFY_SUCCESS"
128
+
129
+ elif act == "NoBook":
130
+ act = "NOTIFY_FAILURE"
131
+
132
+ elif act == "bye":
133
+ act = "GOODBYE"
134
+
135
+ elif act == "reqmore":
136
+ act = "REQ_MORE"
137
+
138
+ elif act == "thank":
139
+ act = "THANK_YOU"
140
+ # elif act == "greet":
141
+ # elif act == "welcome":
142
+ act = act.upper() # align with sgd acts, e.g., `Inform` -> `INFORM`
143
+
144
+ # check if valid
145
+ assert "_{}_".format(act) in SPECIAL_TOKENS["additional_special_tokens"]
146
+ return act
147
+
148
+
149
+ def load_schema(schema_file):
150
+ def _update(key, value, mapping):
151
+ if key in mapping:
152
+ assert (
153
+ value == mapping[key]
154
+ ) # ensure service meta is the same between data splits
155
+ else:
156
+ mapping[key] = value
157
+
158
+ def _restructure_service_meta(service_meta, attribute):
159
+ """ "convert slot/intent metadata list into dict(slot/intent=metadata)"""
160
+ assert attribute in ["slots", "intents"]
161
+ mapping = {}
162
+ for value in service_meta[attribute]:
163
+ key = value["name"]
164
+ if attribute == "slots": # domain-slot in multiwoz
165
+ assert "-" in key
166
+ _, key = key.split("-") # domain, slot
167
+ key = normalise_slot(key)
168
+ else: # intent
169
+ key = normalise_intent(key)
170
+ mapping[key] = value
171
+ service_meta[attribute] = mapping
172
+
173
+ with open(schema_file) as f:
174
+ data = json.load(f)
175
+
176
+ SERVICE2META = {}
177
+ SLOTS, INTENTS = set(), set()
178
+ for service_meta in data:
179
+ service = service_meta["service_name"]
180
+ _restructure_service_meta(service_meta, "slots")
181
+ _restructure_service_meta(service_meta, "intents")
182
+ _update(service, service_meta, SERVICE2META)
183
+
184
+ # collect domain-independent slots
185
+ # for domain_slot in service_meta["slots"]:
186
+ # assert "-" in domain_slot
187
+ # domain, slot = domain_slot.split("-")
188
+ # slot = normalise_slot(slot)
189
+ # SLOTS.add(slot)
190
+ for slot in service_meta["slots"]:
191
+ SLOTS.add(slot)
192
+
193
+ for intent in service_meta["intents"]:
194
+ # intent = normalise_intent(intent)
195
+ INTENTS.add(intent)
196
+
197
+ print("Load schema, intents: {}, slots: {}".format(len(INTENTS), len(SLOTS)))
198
+ return SERVICE2META, INTENTS, SLOTS
199
+
200
+
201
+ def normalise_intent(intent):
202
+ """convert intent into natural language, e.g., find_hotel -> find hotel"""
203
+ if intent == "police":
204
+ intent = "find_police"
205
+ if intent == "book_taxi":
206
+ intent = "find_taxi"
207
+ assert "_" in intent
208
+ return " ".join(intent.split("_"))
209
+
210
+
211
+ def normalise_slot(slot):
212
+ if slot == "pricerange":
213
+ return "price range"
214
+
215
+ elif slot == "bookday":
216
+ return "book day"
217
+
218
+ elif slot == "bookpeople":
219
+ return "book people"
220
+
221
+ elif slot == "booktime":
222
+ return "book time"
223
+
224
+ elif slot == "bookstay":
225
+ return "book stay"
226
+
227
+ elif slot == "ref":
228
+ return "reference"
229
+
230
+ elif slot == "arriveby":
231
+ return "arrive by"
232
+
233
+ elif slot == "leaveat":
234
+ return "leave at"
235
+
236
+ elif slot == "trainid":
237
+ return "train id"
238
+
239
+ elif slot == "openhours":
240
+ return "open hours"
241
+
242
+ elif slot == "entrancefee":
243
+ return "entrance fee"
244
+
245
+ elif slot in ["none", "?"]:
246
+ # return "_Empty_" # special token mark will be added during sequence linearlisation
247
+ return "Empty"
248
+
249
+ else:
250
+ return slot
251
+
252
+
253
+ def normalise_value(value):
254
+ # deal with binary and empty values
255
+ if value == "yes":
256
+ # return "_True_"
257
+ return "True"
258
+
259
+ elif value == "no":
260
+ # return "_False_"
261
+ return "False"
262
+
263
+ elif value in ["none", "?"]:
264
+ # return "_Empty_"
265
+ return "Empty"
266
+
267
+ # if value == "swimmingpool": # for simplicity, dont split
268
+ # return "swimming pool"
269
+
270
+ else:
271
+ return value
272
+
273
+
274
+ def wrap_element(content_type, content):
275
+ """
276
+ wrap elements such as slot, value, e.g., <SLOT/> slot </SLOT>
277
+ """
278
+ assert "/" not in content_type
279
+ return "<{}/> {} </{}>".format(content_type, content, content_type)
280
+
281
+
282
+ def add_str(str1, str2):
283
+ return str1 + " " + str2
284
+
285
+
286
+ def find_segment(gen, tag):
287
+ assert isinstance(gen, str)
288
+ gen = gen.split()
289
+ try:
290
+ start = gen.index("<{}/>".format(tag)) + 1
291
+ end = gen.index("</{}>".format(tag))
292
+ segment = " ".join(gen[start:end])
293
+ except Exception:
294
+ print("Missing {} tag in generated sequence".format(tag))
295
+ segment = None
296
+ return segment
297
+
298
+
299
+ class bcolors:
300
+ HEADER = "\033[95m"
301
+ OKBLUE = "\033[94m"
302
+ OKCYAN = "\033[96m"
303
+ GREEN = "\033[92m"
304
+ YELLOW = "\033[93m"
305
+ RED = "\033[91m"
306
+ ENDC = "\033[0m"
307
+ BOLD = "\033[1m"
308
+ UNDERLINE = "\033[4m"
scripts/user_model_code/main_user_model.py ADDED
@@ -0,0 +1,347 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ import sys
4
+ import time
5
+
6
+ import numpy as np
7
+ import torch
8
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
9
+ from tqdm import tqdm
10
+ from transformers import (
11
+ AdamW,
12
+ GPT2Config,
13
+ GPT2LMHeadModel,
14
+ GPT2Tokenizer,
15
+ get_linear_schedule_with_warmup,
16
+ )
17
+
18
+ import wandb
19
+ from crazyneuraluser.user_model_code.argument import get_args
20
+
21
+ # from interact import interact
22
+ from crazyneuraluser.user_model_code.dataset import SGD_Dataset
23
+ from crazyneuraluser.user_model_code.utils_generation import decode_e2e
24
+ from crazyneuraluser.user_model_code.utils_sgd import get_special_tokens
25
+
26
+
27
+ def print_loss(epoch, data_type, LOSS, t0):
28
+ print(
29
+ "Epoch: {} | {} loss: {:.3f} | time: {:.1f}".format(
30
+ epoch, data_type, LOSS, time.time() - t0
31
+ )
32
+ )
33
+
34
+
35
+ def print_score(epoch, data_type, res, t0):
36
+ print(
37
+ "Epoch: {} | {}: joint_acc: {:.2f}%, slot_acc: {:.2f}% | time: {:.1f}".format(
38
+ epoch,
39
+ data_type,
40
+ res["avg_joint_acc"],
41
+ res["avg_slot_acc"],
42
+ time.time() - t0,
43
+ )
44
+ )
45
+
46
+
47
+ def run_one_epoch(data_type, dataloader, trainer, epoch, run_type, collector=None):
48
+ t0 = time.time()
49
+ assert data_type in ["dev", "test"]
50
+ assert run_type in ["teacher_force", "generation"]
51
+ model, optimizer, scheduler, tokenizer = trainer
52
+
53
+ LOSS = 0
54
+ # result = {"slot_acc": [], "joint_acc": []}
55
+ # mention_match = 0
56
+ # coref_lines = []
57
+ iterator = enumerate(
58
+ tqdm(
59
+ dataloader,
60
+ desc="Epoch {} {}".format(epoch, run_type),
61
+ disable=args.disable_display,
62
+ )
63
+ )
64
+ for step, batch in iterator:
65
+ if run_type == "teacher_force":
66
+ loss, logits, _ = model(
67
+ input_ids=batch["input_ids"],
68
+ attention_mask=batch["attention_mask"],
69
+ token_type_ids=batch["token_type_ids"],
70
+ labels=batch["label_ids"],
71
+ ).values()
72
+ LOSS += loss
73
+ else:
74
+ decode_e2e(args, batch, model, tokenizer, collector=collector)
75
+
76
+ # print log
77
+ if run_type == "teacher_force":
78
+ LOSS /= step + 1
79
+ print_loss(epoch, data_type, LOSS, t0)
80
+ return LOSS
81
+ else: # generation
82
+ # TODO: add evaluation code here
83
+ return None
84
+
85
+
86
+ def set_dataloader(args, tokenizer, data_type, run_type, data_size=-1):
87
+ dataset = SGD_Dataset(
88
+ args, tokenizer, data_type, run_type == "generation", data_size
89
+ )
90
+ # sys.exit(1)
91
+ if data_type == "train":
92
+ sampler = RandomSampler(
93
+ dataset
94
+ ) # if args.local_rank == -1 else DistributedSampler(train_dataset)
95
+ else:
96
+ sampler = SequentialSampler(dataset)
97
+
98
+ dataloader = DataLoader(
99
+ dataset,
100
+ sampler=sampler,
101
+ batch_size=args.train_batch_size
102
+ if data_type == "train"
103
+ else args.eval_batch_size,
104
+ collate_fn=dataset.collate_fn,
105
+ )
106
+ return dataloader
107
+
108
+
109
+ def train(args, tokenizer, model):
110
+
111
+ wandb.init(
112
+ # Set the project where this run will be logged
113
+ project="E2E User Simulator (Alistair)",
114
+ entity="byrne-lab",
115
+ # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
116
+ name=args.wandb_train_run_name,
117
+ # Track hyperparameters and run metadata
118
+ config={
119
+ "data_dir": args.data_dir,
120
+ "model_name": args.model_name,
121
+ "learning_rate": args.learning_rate,
122
+ "gradient_accumulation_steps": args.gradient_accumulation_steps,
123
+ "train_batch_size": args.train_batch_size,
124
+ "eval_batch_size": args.eval_batch_size,
125
+ },
126
+ )
127
+
128
+ # load data
129
+ train_dataloader = set_dataloader(
130
+ args, tokenizer, "train", "teacher_force", data_size=args.train_size
131
+ )
132
+ dev_dataloader = set_dataloader(
133
+ args, tokenizer, "dev", "teacher_force", data_size=args.eval_size
134
+ )
135
+
136
+ optimizer = AdamW(model.parameters(), lr=args.learning_rate, eps=args.adam_epsilon)
137
+ if args.use_scheduler:
138
+ t_total = (
139
+ len(train_dataloader) // args.gradient_accumulation_steps * args.max_epoch
140
+ )
141
+ scheduler = get_linear_schedule_with_warmup(
142
+ optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total
143
+ )
144
+ else:
145
+ scheduler = None
146
+ trainer = (model, optimizer, scheduler, tokenizer)
147
+
148
+ print("Do evaluation before training!")
149
+ model.eval()
150
+ with torch.no_grad():
151
+ _ = run_one_epoch("dev", dev_dataloader, trainer, -1, "teacher_force")
152
+
153
+ print("Start training!\n{}".format("***" * 30))
154
+ eval_step = args.eval_interval // args.train_batch_size
155
+ best_score = -100
156
+ global_step = 0
157
+ no_improve_count = 0
158
+ for epoch in range(args.max_epoch):
159
+ # initialize for each epoch training
160
+ t0 = time.time()
161
+ model.train()
162
+ model.zero_grad()
163
+ LOSS = 0
164
+ iterator = enumerate(
165
+ tqdm(
166
+ train_dataloader,
167
+ desc="Epoch {}".format(epoch),
168
+ disable=args.disable_display,
169
+ )
170
+ )
171
+ for local_step, batch in iterator:
172
+ loss, logits, _ = model(
173
+ input_ids=batch["input_ids"],
174
+ attention_mask=batch["attention_mask"],
175
+ token_type_ids=batch["token_type_ids"],
176
+ labels=batch["label_ids"],
177
+ ).values()
178
+ LOSS += loss
179
+ global_step += 1
180
+
181
+ wandb.log({"loss": loss})
182
+
183
+ # update model
184
+ if loss != 0:
185
+ loss = loss / args.gradient_accumulation_steps
186
+ loss.backward()
187
+
188
+ if global_step % args.gradient_accumulation_steps == 0:
189
+ # norm = torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
190
+ optimizer.step()
191
+ if args.use_scheduler:
192
+ scheduler.step()
193
+ optimizer.zero_grad()
194
+
195
+ # evaluate model
196
+ if global_step % eval_step == 0:
197
+ model.eval()
198
+ with torch.no_grad():
199
+ loss = run_one_epoch(
200
+ "dev", dev_dataloader, trainer, epoch, "teacher_force"
201
+ )
202
+ score = -loss # dev loss as criterion for early training
203
+ wandb.log({"dev_loss": loss})
204
+ model.train()
205
+
206
+ save_checkpoint(
207
+ args, tokenizer, model, global_step * args.train_batch_size
208
+ )
209
+ if score > best_score:
210
+ best_score = score
211
+ print("Best score: {:.2f}".format(best_score))
212
+ no_improve_count = 0
213
+ else:
214
+ no_improve_count += 1
215
+
216
+ # early stop
217
+ if no_improve_count == args.no_improve_max:
218
+ print("Early stop!")
219
+ return
220
+
221
+ LOSS /= local_step + 1
222
+ print_loss(epoch, "train", LOSS, t0)
223
+ print("***" * 30)
224
+
225
+ wandb.log({"epoch": epoch, "epoch_loss": LOSS})
226
+
227
+ # Mark the run as finished on wandb
228
+ wandb.finish()
229
+
230
+
231
+ def test(args, tokenizer, model):
232
+ # load data
233
+ test_gen_dataloader = set_dataloader(args, tokenizer, "test", "generation")
234
+
235
+ trainer = (model, None, None, tokenizer)
236
+ model.eval()
237
+ collector = {"decode-dev": {}, "decode-test": {}}
238
+ with torch.no_grad():
239
+ # # evaluate on dev
240
+ # _ = run_one_epoch('dev', dev_dataloader, trainer, 'Eval', 'teacher_force')
241
+
242
+ # # generate on dev
243
+ # res_dev = run_one_epoch('dev', dev_gen_dataloader, trainer, 'Dev', 'generation',
244
+ # collector=collector['decode-dev'])
245
+ # collector['result-dev'] = res_dev
246
+ # print_qr_result(res_dev['qr'], 'dev')
247
+
248
+ # generate on test
249
+ res_test = run_one_epoch(
250
+ "test",
251
+ test_gen_dataloader,
252
+ trainer,
253
+ "Test",
254
+ "generation",
255
+ collector=collector["decode-test"],
256
+ )
257
+ collector["result-test"] = res_test
258
+
259
+ out_file = args.decode_file
260
+ with open(out_file, "w") as f:
261
+ json.dump(collector, f, indent=4, sort_keys=True)
262
+ print("Decode file is saved at {}".format(out_file))
263
+ print("Done decoding!")
264
+
265
+
266
+ def save_checkpoint(args, tokenizer, model, step):
267
+ save_path = args.checkpoint + "_step" + str(step)
268
+ print("Save model in {}!".format(save_path))
269
+ tokenizer.save_pretrained(save_path)
270
+ model.save_pretrained(save_path)
271
+
272
+
273
+ def load_checkpoint(args):
274
+ save_path = args.checkpoint # + '_step' + str(args.step)
275
+ print("Load model, tokenizer from {}".format(save_path))
276
+ tokenizer = GPT2Tokenizer.from_pretrained(save_path)
277
+ model = GPT2LMHeadModel.from_pretrained(save_path)
278
+ model.to(args.device)
279
+ return tokenizer, model
280
+
281
+
282
+ def load_pretrained_model(args):
283
+ save_path = args.pre_checkpoint
284
+ print("Load model, tokenizer from {}".format(save_path))
285
+ tokenizer = GPT2Tokenizer.from_pretrained(save_path)
286
+ model = GPT2LMHeadModel.from_pretrained(save_path)
287
+ model.to(args.device)
288
+ return tokenizer, model
289
+
290
+
291
+ def set_model(args, SPECIAL_TOKENS):
292
+ """initiate config, tokenizer and model"""
293
+ # add special tokens into tokenizer
294
+ config = GPT2Config.from_pretrained(args.model_name_or_path)
295
+ tokenizer = GPT2Tokenizer.from_pretrained(args.model_name_or_path)
296
+ tokenizer.add_special_tokens(SPECIAL_TOKENS)
297
+ model = GPT2LMHeadModel.from_pretrained(
298
+ args.model_name_or_path, config=config
299
+ ) # GPT2LMHeadModel
300
+ model.resize_token_embeddings(len(tokenizer))
301
+ model.to(args.device)
302
+ print("Done setting model")
303
+ return config, tokenizer, model
304
+
305
+
306
+ def set_seed(args):
307
+ """for reproduction"""
308
+ random.seed(args.seed)
309
+ np.random.seed(args.seed)
310
+ torch.manual_seed(args.seed)
311
+ torch.cuda.manual_seed(args.seed)
312
+ torch.cuda.manual_seed_all(args.seed)
313
+ torch.backends.cudnn.deterministic = True
314
+ torch.backends.cudnn.enabled = False
315
+ torch.backends.cudnn.benchmark = False
316
+
317
+
318
+ if __name__ == "__main__":
319
+ # Load arguments
320
+ args = get_args()
321
+
322
+ # Set seed, device
323
+ set_seed(args)
324
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
325
+ args.device = device
326
+
327
+ # Load special tokens
328
+ SPECIAL_TOKENS = get_special_tokens()
329
+
330
+ if args.mode == "training":
331
+ config, tokenizer, model = set_model(args, SPECIAL_TOKENS)
332
+ train(args, tokenizer, model)
333
+
334
+ elif args.mode == "finetune":
335
+ tokenizer, model = load_pretrained_model(args)
336
+ train(args, tokenizer, model)
337
+
338
+ elif args.mode == "testing":
339
+ tokenizer, model = load_checkpoint(args)
340
+ test(args, tokenizer, model)
341
+
342
+ # elif args.mode == 'interact':
343
+ # tokenizer, model = load_checkpoint(args)
344
+ # interact(args, tokenizer, model)
345
+
346
+ else:
347
+ sys.exit(1)
scripts/user_model_code/preprocess_multiwoz.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import sys
4
+
5
+ from tqdm import tqdm
6
+
7
+ from crazyneuraluser.user_model_code.analysis_multiwoz import DATA_SPLIT, collect_data
8
+ from crazyneuraluser.user_model_code.utils_multiwoz import (
9
+ get_act_natural_language,
10
+ get_original_act_set,
11
+ load_schema,
12
+ normalise_intent,
13
+ normalise_slot,
14
+ normalise_value,
15
+ )
16
+ from crazyneuraluser.user_model_code.utils_sgd import (
17
+ add_str,
18
+ compare_slot_values_in_state,
19
+ conv_special_token,
20
+ dict2list,
21
+ get_special_tokens,
22
+ wrap_element,
23
+ )
24
+
25
+ """ pre-process script for MultiWOZ v2.2 """
26
+
27
+
28
+ class DialMetaData:
29
+ def __init__(self, dial_id, dial_meta, dial_act, unify_act):
30
+ self.dial_id = dial_id
31
+ self.unify_act = unify_act
32
+ self.turn_meta_list, self.scenario = self.parse(
33
+ dial_meta, dial_act
34
+ ) # None for system turn
35
+ self.linearise_turns()
36
+
37
+ def parse(self, dial_meta, dial_act):
38
+ global n, act_intent, non_intent
39
+ assert len(dial_meta["turns"]) == len(dial_act)
40
+
41
+ turn_meta_list = []
42
+ scenario = []
43
+ sys_turn = None # dummy sys turn for first usr turn
44
+ prev_intent = ""
45
+ prev_usr_turn, prev_usr_turn_meta = (
46
+ None,
47
+ None,
48
+ ) # dummpy for tracing goal change at first turn
49
+ for turn_id, turn in enumerate(dial_meta["turns"]):
50
+ assert turn_id == int(turn["turn_id"])
51
+
52
+ if turn["speaker"] == "SYSTEM":
53
+ sys_turn = turn
54
+ turn_meta_list.append(None)
55
+ continue
56
+
57
+ # init turn meta
58
+ turn_meta = TurnMetaData(
59
+ prev_intent, sys_turn, turn, self.dial_id, self.unify_act
60
+ )
61
+
62
+ # get goal change label
63
+ turn_meta.get_goal_change_label(prev_usr_turn, prev_usr_turn_meta)
64
+
65
+ # update previous goal
66
+ for prev_turn_meta in reversed(turn_meta_list):
67
+ if prev_turn_meta is None:
68
+ continue
69
+ prev_turn_meta.accumulate_constraints(turn_meta) # TODO: check goal
70
+
71
+ # record task (intent) in scenario
72
+ if turn_meta.usr_intent not in scenario:
73
+ scenario.append(turn_meta.usr_intent)
74
+
75
+ turn_meta_list.append(turn_meta)
76
+ prev_intent = turn_meta.usr_intent
77
+ prev_usr_turn, prev_usr_turn_meta = turn, turn_meta
78
+ assert len(turn_meta_list) == len(dial_meta["turns"])
79
+ return turn_meta_list, scenario
80
+
81
+ def linearise_turns(self):
82
+ # linearise necessary meterials
83
+ for turn_meta in self.turn_meta_list:
84
+ if turn_meta is None:
85
+ continue
86
+ turn_meta._linearise(self.scenario, SERVICE2META)
87
+
88
+
89
+ class TurnMetaData:
90
+ def __init__(self, prev_intent, sys_turn, usr_turn, dial_id, unify_act):
91
+ self.dial_id = dial_id
92
+ self.unify_act = unify_act
93
+ self.original_act_set = get_original_act_set() # act set w/o domain information
94
+ self.sys_turn, self.usr_turn = sys_turn, usr_turn
95
+
96
+ # turn id
97
+ self.sys_turn_id, self.usr_turn_id = self._get_turn_id(sys_turn, usr_turn)
98
+
99
+ # intent
100
+ self.usr_intent = normalise_intent(self._get_intent(usr_turn, prev_intent))
101
+ if remove_book_intent:
102
+ self.usr_intent = self.usr_intent.replace("book", "find")
103
+ assert self.usr_intent in INTENTS # or self.usr_intent == "temp temp"
104
+ self.service = self.usr_intent.split()[1]
105
+
106
+ # utterances
107
+ self.utt = {}
108
+ self.utt["sys"], self.utt["usr"] = self._get_utt(sys_turn), self._get_utt(
109
+ usr_turn
110
+ )
111
+
112
+ # act
113
+ self.act2sv = {}
114
+ self.act2sv["sys"], _ = self._parse_action(self.sys_turn_id, self.sys_turn)
115
+ self.act2sv["usr"], self.usr_constraints = self._parse_action(
116
+ self.usr_turn_id, self.usr_turn
117
+ )
118
+
119
+ # task boundary
120
+ self._get_new_task_label(prev_intent)
121
+
122
+ # req_alts
123
+ self._get_req_alts_label()
124
+
125
+ def _get_turn_id(self, sys_turn, usr_turn):
126
+ usr_turn_id = int(usr_turn["turn_id"]) # 0, 2, 4 ...
127
+ sys_turn_id = int(sys_turn["turn_id"]) if sys_turn is not None else -1
128
+ assert sys_turn_id == (usr_turn_id - 1)
129
+ return sys_turn_id, usr_turn_id
130
+
131
+ def _get_utt(self, turn):
132
+ if turn is None:
133
+ return ""
134
+ return turn["utterance"]
135
+
136
+ def accumulate_constraints(self, new_turn_meta):
137
+ """
138
+ Add slot, slot-value pairs from a given following turn
139
+ This function forms the user goal by accumulating constraints backward
140
+ """
141
+ # only accumulate constraints with the same task/intent
142
+ if new_turn_meta.usr_intent != self.usr_intent:
143
+ return
144
+
145
+ if (
146
+ new_turn_meta.goal_change
147
+ ): # if goal changes at a new turn, these constraints should not be put in previous turns
148
+ return
149
+
150
+ # only accumulate constraints without goal change
151
+ # if the value of a slot is changed (goal change) in a new turn,
152
+ # this slot-value pair is not part of initial goal and should not be added into the goal of previous turns
153
+ new_constraints = new_turn_meta.usr_constraints
154
+ self.usr_constraints["requestable"] = self.usr_constraints["requestable"].union(
155
+ new_constraints["requestable"]
156
+ )
157
+ for slot, value_list in new_constraints["informable"].items():
158
+ if slot not in self.usr_constraints["informable"]:
159
+ self.usr_constraints["informable"][slot] = value_list
160
+
161
+ def get_goal_change_label(self, prev_usr_turn, prev_turn_meta):
162
+ """check if goal changed (value of slot changes) between two turn states"""
163
+ # first usr turn
164
+ if prev_usr_turn is None:
165
+ assert self.usr_turn_id == 0
166
+ self.goal_change = False
167
+ return
168
+
169
+ # last usr turn
170
+ if "GOODBYE" in self.act2sv["usr"] or "THANK_YOU" in self.act2sv["usr"]:
171
+ self.goal_change = False
172
+ return
173
+
174
+ assert self.usr_turn_id != 0
175
+ assert prev_usr_turn["speaker"] == "USER"
176
+
177
+ # new task
178
+ if self.usr_intent != prev_turn_meta.usr_intent:
179
+ self.goal_change = False
180
+ return
181
+
182
+ # compare two states to obtain goal change flag
183
+ curr_state, prev_state = None, None
184
+ for frame in self.usr_turn["frames"]:
185
+ if frame["service"] == self.service:
186
+ curr_state = frame["state"]["slot_values"]
187
+
188
+ for frame in prev_usr_turn["frames"]:
189
+ if frame["service"] == prev_turn_meta.service:
190
+ prev_state = frame["state"]["slot_values"]
191
+
192
+ # check if slot value has changed at current turn (new slot is not counted)
193
+ assert curr_state is not None and prev_state is not None
194
+ self.goal_change = compare_slot_values_in_state(curr_state, prev_state)
195
+
196
+ def _get_domain_from_act(self, dialogue_act):
197
+ """
198
+ parse the raw dialouge act annotation to get domain info
199
+ number of doamin can be more than 1 for multi-domain turns
200
+ """
201
+ domains = set()
202
+ book_flag = False
203
+ for dact, sv_pairs in dialogue_act.items():
204
+ assert "-" in dact
205
+ domain, _ = dact.split("-")
206
+ if domain not in ["Booking", "general"]:
207
+ domains.add(domain)
208
+ for slot, value in sv_pairs:
209
+ if "book" in slot: # e.g., bookday
210
+ book_flag = True
211
+ return domains, book_flag
212
+
213
+ def _get_intent(self, usr_turn, prev_intent):
214
+ intents = []
215
+ for frame in usr_turn["frames"]:
216
+ # service = frame["service"]
217
+ intent = frame["state"]["active_intent"]
218
+ if intent != "NONE":
219
+ intents.append(intent)
220
+
221
+ if len(intents) == 1:
222
+ intent = intents[0]
223
+ if intent == "find_taxi":
224
+ intent = "book_taxi"
225
+ return intent # tackle 51.5k out of 71.5k user turns
226
+
227
+ # if above fails (e.g., due to wrong label), leverage usr act to help determine main intent/service
228
+ # possible domains in da: {'Hospital', 'Taxi', 'Train', 'Police', 'Restaurant', 'Booking', 'general',
229
+ # 'Attraction', 'Hotel'}
230
+ usr_act = data_act[self.dial_id][str(self.usr_turn_id)]["dialog_act"]
231
+ domains, book_flag = self._get_domain_from_act(usr_act)
232
+ if len(domains) == 1:
233
+ domain = list(domains)[0].lower()
234
+ if book_flag and domain in ["restaurant", "hotel", "train"]:
235
+ intent = "book_{}".format(domain)
236
+ elif domain == "taxi":
237
+ intent = "book_{}".format(domain)
238
+ else:
239
+ intent = "find_{}".format(domain)
240
+ return intent # tackle 58.1k out of 71.5k user turns
241
+
242
+ if "Taxi" in domains:
243
+ return "book_taxi" # tackle 58.8k out of 71.5k user turns
244
+
245
+ if (
246
+ self.usr_turn_id == 0
247
+ ): # wrong label at first turn, no previous intent to use, only 136 user turns here
248
+ utt = usr_turn["utterance"]
249
+ if (
250
+ "restaurant" in utt
251
+ or "Restaurant" in utt
252
+ or "eat" in utt
253
+ or "din" in utt
254
+ ):
255
+ return "find_restaurant"
256
+ elif (
257
+ "hotel" in utt
258
+ or "room" in utt
259
+ or "house" in utt
260
+ or "stay" in utt
261
+ or "live" in utt
262
+ ):
263
+ return "find_hotel"
264
+ else:
265
+ return "find_attraction" # tackle 58.9k out of 71.5k user turns
266
+
267
+ else: # not first turn, leverage sys act to help decide intent
268
+ sys_act = data_act[self.dial_id][str(self.sys_turn_id)]["dialog_act"]
269
+ sys_domains, _ = self._get_domain_from_act(sys_act)
270
+ if len(sys_domains) == 1:
271
+ domain = list(sys_domains)[0].lower()
272
+ if book_flag and domain in ["restaurant", "hotel", "train"]:
273
+ intent = "book_{}".format(domain)
274
+ elif domain == "taxi":
275
+ intent = "book_{}".format(domain)
276
+ else:
277
+ intent = "find_{}".format(domain)
278
+ return intent # tackle 67.3k out of 71.5k user turns
279
+
280
+ # two cases left enter here
281
+ # 1. turns with only general act, e.g., bye
282
+ # 2. turns have multiple intents (very few)
283
+ # both will be handled using previous intent
284
+ assert prev_intent != ""
285
+ intent = "_".join(
286
+ prev_intent.split()
287
+ ) # as prev_intent has been normalised already
288
+ return intent
289
+
290
+ def _parse_action(self, turn_id, turn):
291
+ """parse the `dialog_act` field in `dialog_acts.json`
292
+
293
+ Returns:
294
+ act2sv: act to slot value pairs, {act=sv}; sv: slot to value list, {slot=[v1, v2]}
295
+ """
296
+ act2sv = dict()
297
+ constraints = {"informable": dict(), "requestable": set()}
298
+ if turn is None:
299
+ return None, constraints
300
+
301
+ # get da from data_act
302
+ dialogue_act = data_act[self.dial_id][str(turn_id)]["dialog_act"]
303
+ # domains = set()
304
+ for dact, svs in dialogue_act.items():
305
+ assert "-" in dact
306
+ if self.unify_act: # will use only act part without domain info
307
+ domain, act = dact.split(
308
+ "-"
309
+ ) # split `domain-act`, e.g., `hotel-inform` -> hotel, inform
310
+ else: # keep original mwoz act
311
+ act = dact # use act with domain info
312
+
313
+ if self.unify_act:
314
+ # unify act: `Booking-Inform` with no args is equivalent to `OfferBook` in train domain
315
+ if dact == "Booking-Inform" and svs == [["none", "none"]]:
316
+ act = "OfferBook"
317
+
318
+ # deal with act
319
+ if self.unify_act:
320
+ assert act in self.original_act_set
321
+ if turn["speaker"] == "USER":
322
+ assert act in ["Inform", "Request", "bye", "thank", "greet"]
323
+ act = get_act_natural_language(act)
324
+
325
+ if act not in act2sv:
326
+ act2sv[act] = dict()
327
+
328
+ # iterate slot value pairs
329
+ for slot, value in svs:
330
+ slot = normalise_slot(slot)
331
+ value = normalise_value(value)
332
+
333
+ # act to slot value pairs
334
+ # NOTE: same slot might appear more than once per turn, e.g., when the system informs two hotels with
335
+ # their addresses so a value list is stored for each slot
336
+ if slot not in act2sv[act]:
337
+ act2sv[act][slot] = []
338
+ act2sv[act][slot].append(value)
339
+
340
+ # collect constraints
341
+ if act in ["REQUEST", "Request", "request"]:
342
+ constraints["requestable"].add(slot)
343
+ else:
344
+ if slot != "Empty":
345
+ if (
346
+ slot not in constraints["informable"]
347
+ ): # NOTE: same reason as act, value list per slot
348
+ constraints["informable"][slot] = []
349
+ constraints["informable"][slot].append(value)
350
+ return act2sv, constraints
351
+
352
+ def _linearise(self, scenario, service2meta):
353
+ self.linear_act = {}
354
+ self.linear_act["sys"] = self._linearise_act(self.act2sv["sys"])
355
+ self.linear_act["usr"] = self._linearise_act(self.act2sv["usr"])
356
+ self.linear_goal = self._linearise_goal(
357
+ self.usr_constraints, scenario, service2meta
358
+ )
359
+
360
+ def _linearise_goal(self, constraints, scenario, service2meta):
361
+ """
362
+ linearise goal representation which consists of several parts:
363
+ scenario, task (intent), task description, constraints with informable and requestable
364
+ e.g., <SCENARIO/> task1 task2 .. </SCENARIO>
365
+ <TASK/> current task </TASK> <DESC/> task description </DESC>
366
+ <INFORM/> <SLOT/> slot1 </SLOT> <VALUE> value1 </VALUE> .. </INFORM>
367
+ <REQUEST/> <SLOT> slot1 </SLOT> <SLOT> slot2 </SLOT> .. </REQUEST>
368
+ """
369
+ res = ""
370
+ # scenario
371
+ assert isinstance(scenario, list) and len(scenario) > 0
372
+ scenario = " ".join(
373
+ [wrap_element("INTENT", intent) for intent in scenario]
374
+ ) # treat intent as nl
375
+ scenario_wrap = wrap_element("SCENARIO", scenario)
376
+ res = add_str(res, scenario_wrap)
377
+
378
+ # task name
379
+ intent = self.usr_intent
380
+ assert intent in scenario
381
+ intent_wrap = wrap_element("TASK", intent)
382
+ res = add_str(res, intent_wrap)
383
+
384
+ # task description
385
+ description = service2meta[self.service]["intents"][intent]["description"]
386
+ description_warp = wrap_element("DESC", description)
387
+ res = add_str(res, description_warp)
388
+
389
+ # informable
390
+ informable = dict2list(
391
+ constraints["informable"]
392
+ ) # sorted sv pair list [slot=value]
393
+ res = add_str(res, "<INFORM/>")
394
+ for sv_pair in informable:
395
+ slot, value = sv_pair.split("=")
396
+ if value in ["True", "False", "Empty"]:
397
+ value = conv_special_token(value, SPECIAL_TOKENS)
398
+ if slot in ["Empty"]:
399
+ slot = conv_special_token(slot, SPECIAL_TOKENS)
400
+ # slot
401
+ slot_wrap = wrap_element("SLOT", slot)
402
+ res = add_str(res, slot_wrap)
403
+ # value
404
+ value_wrap = wrap_element("VALUE", value)
405
+ res = add_str(res, value_wrap)
406
+ res = add_str(res, "</INFORM>")
407
+
408
+ # requestable
409
+ requestable = sorted(
410
+ list(constraints["requestable"])
411
+ ) # sorted slot list [slot]
412
+ res = add_str(res, "<REQUEST/>")
413
+ for slot in requestable:
414
+ slot_wrap = wrap_element("SLOT", slot)
415
+ res = add_str(res, slot_wrap)
416
+ res = add_str(res, "</REQUEST>")
417
+ return res[1:] # remove first space
418
+
419
+ def _linearise_act(self, act2sv):
420
+ """
421
+ NOTE: 1) split slot/value if "_"; 2) special tokens of acts; 3) empty slot or empty value
422
+ NOTE: filer too many values (e.g., 10 movie names) but make sure the one the user chose is present
423
+
424
+ Return: ordered (slots sorted within act, acts sorted) linearised act sequence,
425
+ e.g., <ACT/> <INFORM> </ACT> <SLOT/> area </SLOT> <VALUE/> Cambridge </VALUE> ...
426
+ e.g., <ACT/> <REQUEST> </ACT> <SLOT/> _Empty_ </SLOT> <VALUE/> _Empty_ </VALUE>
427
+ """
428
+ res = ""
429
+ if act2sv is None:
430
+ return res
431
+
432
+ for act in sorted(act2sv.keys()): # sort act
433
+ sv = act2sv[act] # dict{slot: value_list}
434
+ act_wrap = wrap_element("ACT", act)
435
+ res = add_str(res, act_wrap)
436
+
437
+ sorted_sv = dict2list(
438
+ sv
439
+ ) # sorted sv list, [s1=v1, s2=v2], note slot can repeat
440
+ for sv_pair in sorted_sv:
441
+ slot, value = sv_pair.split("=")
442
+ if value in ["True", "False", "Empty"]:
443
+ value = conv_special_token(value, SPECIAL_TOKENS)
444
+ if slot in ["Empty"]:
445
+ slot = conv_special_token(slot, SPECIAL_TOKENS)
446
+
447
+ # slot
448
+ slot_wrap = wrap_element("SLOT", slot)
449
+ res = add_str(res, slot_wrap)
450
+
451
+ # value
452
+ value_wrap = wrap_element("VALUE", value)
453
+ res = add_str(res, value_wrap)
454
+
455
+ return res[1:] # remove first space
456
+
457
+ def _get_new_task_label(self, prev_intent):
458
+ """
459
+ get a binary label indicating if a turn starts a new task (intent) in dialogue
460
+ """
461
+ assert prev_intent != "NONE" and self.usr_intent != "NONE"
462
+ if self.usr_intent != prev_intent:
463
+ self.start_new_task = True
464
+ else:
465
+ self.start_new_task = False
466
+
467
+ def _get_req_alts_label(self):
468
+ self.req_alts = False # no request alternative in mwoz
469
+
470
+
471
+ def collect_examples(dial_id, dial_meta, examples):
472
+ num = 0
473
+ examples[dial_id] = {}
474
+ for turn_meta in dial_meta.turn_meta_list:
475
+ if turn_meta is None: # sys turn
476
+ continue
477
+
478
+ example_id = "{}-{}".format(dial_id, num)
479
+ example = {
480
+ "utterances": turn_meta.utt,
481
+ "actions": turn_meta.linear_act,
482
+ "goal": turn_meta.linear_goal,
483
+ "service": turn_meta.service,
484
+ "intent": turn_meta.usr_intent,
485
+ "goal_change": turn_meta.goal_change,
486
+ "start_new_task": turn_meta.start_new_task,
487
+ "req_alts": turn_meta.req_alts,
488
+ }
489
+ examples[dial_id][example_id] = example
490
+ num += 1
491
+
492
+
493
+ def prepare_data_seq(unify_act, out_data_path):
494
+ for split in DATA_SPLIT:
495
+ examples = {}
496
+ for dial_num, dial_id in enumerate(tqdm(sorted(data[split].keys()))):
497
+ dial = data[split][dial_id]
498
+ dial_act = data_act[dial_id]
499
+
500
+ dial_meta = DialMetaData(dial_id, dial, dial_act, unify_act)
501
+ collect_examples(dial_id, dial_meta, examples)
502
+
503
+ with open("{}/{}.json".format(out_data_path, split), "w") as f:
504
+ json.dump(examples, f, sort_keys=True, indent=4)
505
+ print("Done process {} {} dialogues".format(split, len(examples)))
506
+
507
+
508
+ if __name__ == "__main__":
509
+ if len(sys.argv) == 1:
510
+ print("Wrong argument!")
511
+ print("usage: python utils/preprocess_multiwoz.py multiwoz2.2-data-path")
512
+ sys.exit(1)
513
+
514
+ # Set data path
515
+ data_path = sys.argv[1]
516
+ out_data_path = "./data/preprocessed/user_model"
517
+ os.makedirs(out_data_path, exist_ok=True)
518
+
519
+ # Control flags
520
+ unify_act = True
521
+ remove_book_intent = True
522
+
523
+ # Load data and material as global var
524
+ SERVICE2META, INTENTS, SLOTS = load_schema(os.path.join(data_path, "schema.json"))
525
+ SPECIAL_TOKENS = get_special_tokens()
526
+ data, data_act = collect_data(data_path, remove_dial_switch=False)
527
+
528
+ prepare_data_seq(unify_act, out_data_path)
scripts/user_model_code/preprocess_sgd.py ADDED
@@ -0,0 +1,431 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import sys
4
+
5
+ from tqdm import tqdm
6
+
7
+ from crazyneuraluser.user_model_code.analysis_sgd import DATA_SPLIT, collect_data
8
+ from crazyneuraluser.user_model_code.utils_sgd import (
9
+ add_str,
10
+ compare_slot_values_in_state,
11
+ dict2list,
12
+ get_special_tokens,
13
+ get_turn_intent,
14
+ load_schema,
15
+ split_intent,
16
+ wrap_element,
17
+ )
18
+
19
+ """pre-processing script for SGD
20
+
21
+ The annotations for a turn are grouped into frames, where each frame corresponds to a single service
22
+ The values of "slot_values" in user "state" is a list, where spoken variations are considered, e.g., tomorrow, 8/2
23
+ """
24
+
25
+
26
+ class DialMetaData:
27
+ def __init__(self, dial_id, dial):
28
+ self.dial_id = dial_id
29
+ self.turn_meta_list, self.scenario = self.parse(dial) # None for system turn
30
+ self.linearise_turns()
31
+
32
+ def parse(self, dial):
33
+ turn_meta_list = []
34
+ scenario = []
35
+ sys_turn = None # dummy sys turn for first usr turn
36
+ prev_intent = ""
37
+ prev_usr_turn, prev_usr_turn_meta = (
38
+ None,
39
+ None,
40
+ ) # dummpy for tracing goal change at first turn
41
+ for turn_id, turn in enumerate(dial["turns"]):
42
+ if turn["speaker"] == "SYSTEM":
43
+ sys_turn = turn
44
+ turn_meta_list.append(None)
45
+ continue
46
+
47
+ # init turn meta
48
+ turn_meta = TurnMetaData(prev_intent, sys_turn, turn, self.dial_id)
49
+
50
+ # get goal change label
51
+ turn_meta.get_goal_change_label(prev_usr_turn, prev_usr_turn_meta)
52
+
53
+ # update previous goal
54
+ for prev_turn_meta in reversed(turn_meta_list):
55
+ if prev_turn_meta is None:
56
+ continue
57
+ prev_turn_meta.accumulate_constraints(turn_meta)
58
+
59
+ # record task (intent) in scenario
60
+ prev_intent = turn_meta.usr_intent
61
+ if turn_meta.usr_intent not in scenario:
62
+ scenario.append(turn_meta.usr_intent)
63
+
64
+ turn_meta_list.append(turn_meta)
65
+ prev_usr_turn, prev_usr_turn_meta = turn, turn_meta
66
+
67
+ assert len(turn_meta_list) == len(dial["turns"])
68
+ return turn_meta_list, scenario
69
+
70
+ def linearise_turns(self):
71
+ # linearise necessary meterials
72
+ for turn_meta in self.turn_meta_list:
73
+ if turn_meta is None:
74
+ continue
75
+ turn_meta._linearise(self.scenario)
76
+
77
+
78
+ class TurnMetaData:
79
+ def __init__(self, prev_intent, sys_turn, usr_turn, dial_id):
80
+ self.dial_id = dial_id
81
+ self.sys_turn, self.usr_turn = sys_turn, usr_turn
82
+ self.empty_token = "_Empty_"
83
+ assert self.empty_token in SPECIAL_TOKENS["additional_special_tokens"]
84
+
85
+ # intent
86
+ self.usr_intent, self.service = self._get_intent(usr_turn, prev_intent)
87
+
88
+ # utterances
89
+ self.utt = {}
90
+ self.utt["sys"], self.utt["usr"] = self._get_utt(sys_turn), self._get_utt(
91
+ usr_turn
92
+ )
93
+
94
+ # action
95
+ self.act2sv = {}
96
+ self.act2sv["sys"], _ = self._parse_action(sys_turn)
97
+ self.act2sv["usr"], self.usr_constraints = self._parse_action(usr_turn)
98
+
99
+ # task boundary
100
+ self._get_new_task_label(prev_intent)
101
+
102
+ # req_alts
103
+ self._get_req_alts_label(self.act2sv["usr"])
104
+
105
+ def _get_intent(self, turn, prev_intent):
106
+ """manually set the `NONE` intent to the intent of previous turn"""
107
+ active_intent, service = get_turn_intent(
108
+ turn
109
+ ) # intent annotation (migt be `NONE`)
110
+ if active_intent == "NONE":
111
+ active_intent = prev_intent
112
+ return active_intent, service
113
+
114
+ def _get_utt(self, turn):
115
+ if turn is None:
116
+ return ""
117
+ return turn["utterance"]
118
+
119
+ def _parse_action(self, turn):
120
+ """
121
+ parse action annotation to collect turn level information
122
+ 1) act to slot-value pairs, dict{act: {slot: value}}
123
+ 2) turn level constraints, dict{'informable': dict{slot: value}, 'requestable': set(slot)}
124
+ """
125
+ # get mapping from act to slot-value pairs
126
+ act2sv = {}
127
+ info_req = {"informable": dict(), "requestable": set()} # constraints
128
+
129
+ if turn is None:
130
+ return None, info_req
131
+
132
+ for frame in turn["frames"]:
133
+ for action in frame["actions"]:
134
+ act, slot, values = action["act"], action["slot"], action["values"]
135
+
136
+ # deal with empty slot or value
137
+ if turn["speaker"] == "USER":
138
+ assert len(values) in [0, 1]
139
+ if slot == "":
140
+ slot = self.empty_token
141
+ value = values[0] if len(values) > 0 else self.empty_token
142
+
143
+ # act to slot-value pairs
144
+ if act not in act2sv:
145
+ act2sv[act] = {}
146
+ assert slot not in act2sv[act]
147
+ act2sv[act][slot] = value
148
+
149
+ # collect constraints
150
+ if slot in [
151
+ "",
152
+ self.empty_token,
153
+ ]: # only act but no constraints, e.g., AFFIRM, NEGATE
154
+ continue
155
+
156
+ # turn level informalable and requestable info
157
+ if act == "REQUEST":
158
+ assert slot != ""
159
+ info_req["requestable"].add(slot)
160
+ else:
161
+ if turn["speaker"] == "USER":
162
+ assert act in [
163
+ "INFORM_INTENT",
164
+ "INFORM",
165
+ "SELECT",
166
+ ] # not apply to system side
167
+ if (
168
+ act != "SELECT"
169
+ ): # result offered by system is part of initial user goal
170
+ assert slot not in info_req["informable"]
171
+ info_req["informable"][slot] = value
172
+ return act2sv, info_req
173
+
174
+ def accumulate_constraints(self, new_turn_meta):
175
+ """
176
+ Add slot, slot-value pairs from a given following turn
177
+ This function is used to form user goal by accumulating constraints backward
178
+ """
179
+ # only accumulate constraints with the same task/intent
180
+ if new_turn_meta.usr_intent != self.usr_intent:
181
+ return
182
+
183
+ if (
184
+ new_turn_meta.goal_change
185
+ ): # if goal changes at a new turn, these constraints should not be put in previous turns
186
+ return
187
+
188
+ # only accumulate constraints without goal change
189
+ # if the value of a slot is changed (goal change) in a new turn,
190
+ # this slot-value pair is not part of initial goal and should not be added into the goal of previous turns
191
+ new_constraints = new_turn_meta.usr_constraints
192
+ self.usr_constraints["requestable"] = self.usr_constraints["requestable"].union(
193
+ new_constraints["requestable"]
194
+ )
195
+ for slot, value in new_constraints["informable"].items():
196
+ if slot not in self.usr_constraints["informable"]:
197
+ self.usr_constraints["informable"][slot] = value
198
+
199
+ def _get_new_task_label(self, prev_intent):
200
+ """get a binary label indicating if a turn starts a new task (intent) in dialogue"""
201
+ assert prev_intent != "NONE" and self.usr_intent != "NONE"
202
+ if self.usr_intent != prev_intent:
203
+ self.start_new_task = True
204
+ else:
205
+ self.start_new_task = False
206
+
207
+ def _get_req_alts_label(self, act2sv):
208
+ """get a binary label indicating if usr requests alternatives"""
209
+ if "REQUEST_ALTS" in act2sv:
210
+ self.req_alts = True
211
+ else:
212
+ self.req_alts = False
213
+
214
+ def get_goal_change_label(self, prev_usr_turn, prev_turn_meta):
215
+ """check if goal changed (value of slot changes) between two turn states"""
216
+ if prev_usr_turn is None: # first usr turn
217
+ self.goal_change = False
218
+ return
219
+
220
+ if (
221
+ len(self.usr_turn["frames"]) == 1
222
+ and self.usr_turn["frames"][0]["state"]["active_intent"] == "NONE"
223
+ ): # `NONE` intent
224
+ self.goal_change = False
225
+ return
226
+
227
+ if self.usr_intent != prev_turn_meta.usr_intent: # new task
228
+ self.goal_change = False
229
+ return
230
+
231
+ assert prev_usr_turn["speaker"] == "USER"
232
+ prev_state_sv, curr_state_sv = None, None
233
+ for frame in prev_usr_turn["frames"]:
234
+ if frame["state"]["active_intent"] == self.usr_intent:
235
+ prev_state_sv = frame["state"]["slot_values"]
236
+
237
+ # fix some weird cases (count very few, around 30 turns)
238
+ if prev_state_sv is None:
239
+ assert (
240
+ len(prev_usr_turn["frames"]) == 1
241
+ and prev_usr_turn["frames"][0]["state"]["active_intent"] == "NONE"
242
+ )
243
+ prev_state_sv = prev_usr_turn["frames"][0]["state"]["slot_values"]
244
+
245
+ for frame in self.usr_turn["frames"]:
246
+ if frame["state"]["active_intent"] == self.usr_intent:
247
+ curr_state_sv = frame["state"]["slot_values"]
248
+
249
+ assert prev_state_sv is not None and curr_state_sv is not None
250
+ self.goal_change = compare_slot_values_in_state(
251
+ prev_state_sv, curr_state_sv
252
+ ) # True if goal changes
253
+
254
+ def _linearise(self, scenario):
255
+ self.linear_act = {}
256
+ self.linear_act["sys"] = self._linearise_act(self.act2sv["sys"])
257
+ self.linear_act["usr"] = self._linearise_act(self.act2sv["usr"])
258
+ self.linear_goal = self._linearise_goal(self.usr_constraints, scenario)
259
+
260
+ def _linearise_act(self, act2sv):
261
+ """
262
+ NOTE: 1) split slot/value if "_"; 2) special tokens of acts; 3) empty slot or empty value
263
+ NOTE: filer too many values (e.g., 10 movie names) but make sure the one the user chose is present
264
+
265
+ Return: ordered (slots sorted within act, acts sorted) linearised act sequence,
266
+ e.g., <ACT/> <INFORM> </ACT> <SLOT/> area </SLOT> <VALUE/> Cambridge </VALUE> ...
267
+ e.g., <ACT/> <REQUEST> </ACT> <SLOT/> _Empty_ </SLOT> <VALUE/> _Empty_ </VALUE>
268
+ """
269
+ res = ""
270
+ if act2sv is None:
271
+ return res
272
+
273
+ for act in sorted(act2sv.keys()): # sort act
274
+ sv = act2sv[act] # dict{slot: value}
275
+
276
+ act = "_{}_".format(act) # act is special token
277
+ assert act in SPECIAL_TOKENS["additional_special_tokens"]
278
+ act_wrap = wrap_element("ACT", act)
279
+ res = add_str(res, act_wrap)
280
+
281
+ sorted_sv = dict2list(sv) # sorted sv list, [slot=value]
282
+ for sv_pair in sorted_sv:
283
+ slot, value = sv_pair.split("=")
284
+ slot, value = self._basic_normalise_slot(
285
+ slot
286
+ ), self._basic_normalise_value(value, slot)
287
+
288
+ # slot
289
+ slot_wrap = wrap_element("SLOT", slot)
290
+ res = add_str(res, slot_wrap)
291
+
292
+ # value
293
+ value_wrap = wrap_element("VALUE", value)
294
+ res = add_str(res, value_wrap)
295
+ return res[1:] # remove first space
296
+
297
+ def _basic_normalise_value(self, value, slot):
298
+ # intent value
299
+ if slot == "intent":
300
+ value = split_intent(value)
301
+ return value
302
+
303
+ # special token value
304
+ if value in ["True", "False"]: # Empty is already in the form of "_Empty_"
305
+ value = "_{}_".format(value)
306
+ assert value in SPECIAL_TOKENS["additional_special_tokens"]
307
+ return value
308
+ return value
309
+
310
+ def _basic_normalise_slot(self, slot):
311
+ if slot not in SPECIAL_TOKENS["additional_special_tokens"]:
312
+ slot = slot.replace(
313
+ "_", " "
314
+ ) # e.g., `date_of_journey` -> `date of journey`
315
+ return slot
316
+
317
+ def _linearise_goal(self, constraints, scenario):
318
+ """
319
+ linearise goal representation which consists of several parts:
320
+ scenario, task (intent), task description, constraints with informable and requestable
321
+ e.g., <SCENARIO/> task1 task2 .. </SCENARIO>
322
+ <TASK/> current task </TASK> <DESC/> task description </DESC>
323
+ <INFORM/> <SLOT/> slot1 </SLOT> <VALUE> value1 </VALUE> .. </INFORM>
324
+ <REQUEST/> <SLOT> slot1 </SLOT> <SLOT> slot2 </SLOT> .. </REQUEST>
325
+ """
326
+ res = ""
327
+ # scenario
328
+ assert isinstance(scenario, list) and len(scenario) > 0
329
+ scenario = " ".join(
330
+ [wrap_element("INTENT", split_intent(intent)) for intent in scenario]
331
+ )
332
+ scenario_wrap = wrap_element("SCENARIO", scenario)
333
+ res = add_str(res, scenario_wrap)
334
+
335
+ # task name
336
+ intent = split_intent(self.usr_intent)
337
+ assert intent in scenario
338
+ intent_wrap = wrap_element("TASK", intent)
339
+ res = add_str(res, intent_wrap)
340
+
341
+ # task description
342
+ description = SERVICE2META[self.service]["intents"][self.usr_intent][
343
+ "description"
344
+ ]
345
+ description_warp = wrap_element("DESC", description)
346
+ res = add_str(res, description_warp)
347
+
348
+ # informable
349
+ informable = dict2list(
350
+ constraints["informable"]
351
+ ) # sorted sv pair list [slot=value]
352
+ res = add_str(res, "<INFORM/>")
353
+ for sv_pair in informable:
354
+ slot, value = sv_pair.split("=")
355
+ slot, value = self._basic_normalise_slot(slot), self._basic_normalise_value(
356
+ value, slot
357
+ )
358
+ # slot
359
+ slot_wrap = wrap_element("SLOT", slot)
360
+ res = add_str(res, slot_wrap)
361
+ # value
362
+ value_wrap = wrap_element("VALUE", value)
363
+ res = add_str(res, value_wrap)
364
+ res = add_str(res, "</INFORM>")
365
+
366
+ # requestable
367
+ requestable = sorted(
368
+ list(constraints["requestable"])
369
+ ) # sorted slot list [slot]
370
+ res = add_str(res, "<REQUEST/>")
371
+ for slot in requestable:
372
+ slot = self._basic_normalise_slot(slot)
373
+ slot_wrap = wrap_element("SLOT", slot)
374
+ res = add_str(res, slot_wrap)
375
+ res = add_str(res, "</REQUEST>")
376
+ return res[1:] # remove first space
377
+
378
+
379
+ def collect_examples(dial_id, dial_meta, examples):
380
+ num = 0
381
+ examples[dial_id] = {}
382
+ for turn_meta in dial_meta.turn_meta_list:
383
+ if turn_meta is None: # sys turn
384
+ continue
385
+
386
+ example_id = "{}-{}".format(dial_id, num)
387
+ example = {
388
+ "utterances": turn_meta.utt,
389
+ "actions": turn_meta.linear_act,
390
+ "goal": turn_meta.linear_goal,
391
+ "service": turn_meta.service,
392
+ "intent": turn_meta.usr_intent,
393
+ "goal_change": turn_meta.goal_change,
394
+ "start_new_task": turn_meta.start_new_task,
395
+ "req_alts": turn_meta.req_alts,
396
+ }
397
+ examples[dial_id][example_id] = example
398
+ num += 1
399
+
400
+
401
+ def prepare_data_seq(data, out_data_path):
402
+ for split in DATA_SPLIT:
403
+ examples = {}
404
+ for dial_num, dial_id in enumerate(tqdm(sorted(data[split].keys()))):
405
+ dial = data[split][dial_id]
406
+ dial_meta = DialMetaData(dial_id, dial)
407
+ collect_examples(dial_id, dial_meta, examples)
408
+
409
+ with open("{}/{}.json".format(out_data_path, split), "w") as f:
410
+ json.dump(examples, f, sort_keys=True, indent=4)
411
+ print("Done process {} {} dialogues".format(split, len(examples)))
412
+
413
+
414
+ if __name__ == "__main__":
415
+ if len(sys.argv) == 1:
416
+ print("wrong arguments!")
417
+ print("usage: python utils/preprocess_sgd.py sgd-data-path")
418
+ sys.exit(1)
419
+
420
+ # Set data path
421
+ data_path = sys.argv[1]
422
+ out_data_path = "./processed_data/sgd/"
423
+ os.makedirs(out_data_path, exist_ok=True)
424
+
425
+ # Load data and material as global var
426
+ SERVICE2META, INTENTS, SLOTS = load_schema(data_path)
427
+ SPECIAL_TOKENS = get_special_tokens()
428
+ data = collect_data(data_path, remove_dial_switch=True)
429
+
430
+ # Process data
431
+ prepare_data_seq(data, out_data_path)
scripts/user_model_code/train.sh ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ experiment=$1
2
+
3
+ # common setup
4
+ wandb_train_run_name="Full-user-model-training"
5
+ bs=16 # batch size for training
6
+ grad_step=2 # accumulated gradient steps
7
+ max_epoch=8 # max epoch for training
8
+ data_dir="./data/preprocessed/user_model"
9
+ train_size=-1 # number of examples used for training, -1 means all
10
+ eval_size=-1 # number of examples ued for evaluation, -1 means all
11
+
12
+
13
+
14
+ if [[ "$experiment" == "SGD" ]]; then
15
+ echo "Conduct experiment with SGD dataset"
16
+ job_name='SGD-full'
17
+ data_list="sgd" # 165k training examples
18
+ eval_interval=50000 # evaluation interval
19
+
20
+ elif [[ "$experiment" == "MultiWOZ" ]]; then
21
+ echo "Conduct experiment with MulwiWOZ dataset"
22
+ job_name='MultiWOZ-full'
23
+ data_list="multiwoz" # 56k training examples
24
+ eval_interval=20000
25
+
26
+ elif [[ "$experiment" == "Joint" ]]; then
27
+ echo "Conduct experiment with SGD + MulwiWOZ dataset"
28
+ job_name='Joint-full'
29
+ data_list="sgd multiwoz" # 221k training examples
30
+ eval_interval=70000
31
+
32
+ else
33
+ echo "Unrecognised argument"
34
+ exit
35
+ fi
36
+
37
+ mkdir -p checkpoint log
38
+ checkpoint='checkpoint/'$job_name
39
+ log='log/'$job_name'.log'
40
+ python ./scripts/user_model_code/main_user_model.py --mode='training' \
41
+ --wandb_train_run_name=$wandb_train_run_name \
42
+ --model_name=$job_name \
43
+ --checkpoint=$checkpoint \
44
+ --data_dir=$data_dir \
45
+ --data_list $data_list \
46
+ --train_size=$train_size \
47
+ --eval_size=$eval_size \
48
+ --eval_interval=$eval_interval \
49
+ --gradient_accumulation_steps=$grad_step \
50
+ --train_batch_size=$bs \
51
+ --max_epoch=$max_epoch
src/crazyneuraluser.egg-info/PKG-INFO ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Metadata-Version: 2.1
2
+ Name: crazyneuraluser
3
+ Version: 0.0.post1.dev47+g049b138.d20220509
4
+ Summary: Add a short description here!
5
+ Home-page: https://github.com/pyscaffold/pyscaffold/
6
+ Author: Extended by Alistair McLeay, original code by Alexandru Coca
7
+ Author-email: am@alistairmcleay.com and alexcoca23@yahoo.co.uk
8
+ License: MIT
9
+ Project-URL: Documentation, https://pyscaffold.org/
10
+ Platform: any
11
+ Classifier: Development Status :: 4 - Beta
12
+ Classifier: Programming Language :: Python
13
+ Description-Content-Type: text/markdown; charset=UTF-8; variant=GFM
14
+ Provides-Extra: testing
15
+ License-File: LICENSE.txt
16
+ License-File: AUTHORS.md
17
+
18
+ # Cambridge Masters Project
19
+ Joint Learning of Practical Dialogue Systems and User Simulators
20
+
21
+ ## Environment setup
22
+
23
+ 1. Create an environment `crazyneuraluser` with the help of [conda]
24
+ ```
25
+ conda env create -f environment.yml
26
+ ```
27
+ 2. Activate the new environment with:
28
+ ```
29
+ conda activate crazyneuraluser
30
+ ```
31
+ 3. Install a version of `pytorch` compatible with your hardware (see the [pytorch website](https://pytorch.org/get-started/previous-versions/)). E.g.:
32
+ ```
33
+ pip install torch --extra-index-url https://download.pytorch.org/whl/cu113
34
+ ```
35
+
36
+ 4. Install `spacy` and download the tokenization tool in spacy:
37
+ ```
38
+ pip install spacy'
39
+ python -m spacy download en_core_web_sm
40
+ ```
41
+
42
+ ### Generating dialogues through agent-agent interaction
43
+
44
+ To generate dialogues, first change working directory to the `baselines` directory. Run the command
45
+ ```
46
+ python baselines_setup.py
47
+ ```
48
+ to prepare `convlab2` for running the baselines.
49
+
50
+ #### Generating dialogues conditioned on randomly sampled goals
51
+
52
+ Select one of the available configurations in the `configs` directory and run the command
53
+ ```
54
+ python simulate_agent_interaction.py --config /rel/path/to/chosen/config
55
+ ```
56
+ to generate dialogues conditioned on randomly sampled goals according to the `convlab2` goal model. The dialogues will be be saved automatically in the `models` directory, under a directory whose name depends on the configuration run. The `models` directory is located in the parent directory of the `baselines` directory. The `metadata.json` file saved with the dialogues contains information about the data generation process.
57
+
58
+ #### Generating dialogues conditioned on `MultiWOZ2.1` goals
59
+
60
+ To generate the entire corpus, simply pass the `--goals-path /path/to/multiwoz2.1/data.json/file` flag to `simulate_agent_interaction.py`. To generate the `test/val` split additionally pass the `--filter-path /path/to/multiwoz2.1/test-or-valListFile` argument to `simulate_agent_interaction.py`. You can use the `generate_multiwoz21_train_id_file` function in `baselines/utils.py` to generate `trainListFile` which can then be passed via the `--filter-path` argument to the dialogue generation script in order to generate dialogues conditioned on the `MultiWOZ2.1` training goals.
61
+
62
+ ### Converting the generated dialogues to SGD-like format
63
+
64
+ The `create_data_from_multiwoz.py` script can be used to convert the generated dialogues to SGD format, necessary for evaluation. It is based on the script provided by Google for DSTC8, but with additional functionality such as:
65
+
66
+ - conversion of slot names as annotated in the MultiWOZ 2.1 dialogue acts to different slot names, specified through the `--slots_convention` argument. Options are `multiwoz22` to convert the slots to the same slots as defined in the MultiWOZ 2.2 dataset whreas the `multiwoz_goals` converts the slot names to the names used in the dialogue goal and state tracking annotations.
67
+
68
+ - addition of system and user `nlu` fields for every turn
69
+
70
+ - option to perform cleaning operations on the goals to ensure a standard format is received by the evaluator.
71
+
72
+ The conversion is done according to the `schema.json` file in the `baselines` directory, which is the same as used by `DSTC8` conversion except for the addition of the `police` domain. Type ``python create_data_from_multiwoz.py --helpfull`` to see a full list of flags and usage.
73
+
74
+ ## Installation
75
+
76
+ The recommended way to use this repository is to develop the core code under `src/crazyneuraluser`. The experiments/exporatory analysis making use of the core package code should be placed outside the library and imported. See more guidance under the [Project Organisation](#project-organization) section below.
77
+
78
+ To create an environment for the package, make sure you have deactivated all `conda` environments. Then:
79
+
80
+ 1. Create an environment `crazyneuraluser` with the help of [conda]:
81
+ ```
82
+ conda env create -f environment.yml
83
+ ```
84
+ 2. Add the developer dependencies to this environment with the help of [conda]:
85
+ ```
86
+ conda env update -f dev_environment.yml
87
+ ```
88
+
89
+ Optional and needed only once after `git clone`:
90
+
91
+ 3. install several [pre-commit] git hooks with:
92
+ ```bash
93
+ pre-commit install
94
+ # You _are encouraged_ to run `pre-commit autoupdate`
95
+ ```
96
+ and checkout the configuration under `.pre-commit-config.yaml`.
97
+ The `-n, --no-verify` flag of `git commit` can be used to deactivate pre-commit hooks temporarily.
98
+
99
+ 4. install [nbstripout] git hooks to remove the output cells of committed notebooks with:
100
+ ```bash
101
+ nbstripout --install --attributes notebooks/.gitattributes
102
+ ```
103
+ This is useful to avoid large diffs due to plots in your notebooks.
104
+ A simple `nbstripout --uninstall` will revert these changes.
105
+
106
+ Then take a look into the `scripts` and `notebooks` folders.
107
+
108
+ ## Dependency Management & Reproducibility
109
+
110
+ 1. Always keep your abstract (unpinned) dependencies updated in `environment.yml` and eventually
111
+ in `setup.cfg` if you want to ship and install your package via `pip` later on.
112
+ 2. Create concrete dependencies as `environment.lock.yml` for the exact reproduction of your
113
+ environment with:
114
+ ```bash
115
+ conda env export -n crazyneuraluser -f environment.lock.yml
116
+ ```
117
+ For multi-OS development, consider using `--no-builds` during the export.
118
+ 3. Update your current environment with respect to a new `environment.lock.yml` using:
119
+ ```bash
120
+ conda env update -f environment.lock.yml --prune
121
+ ```
122
+ ## Project Organization
123
+
124
+ ```
125
+ ├── AUTHORS.md <- List of developers and maintainers.
126
+ ├── CHANGELOG.md <- Changelog to keep track of new features and fixes.
127
+ ├── LICENSE.txt <- License as chosen on the command-line.
128
+ ├── README.md <- The top-level README for developers.
129
+ ├── configs <- Directory for configurations of model & application.
130
+ ├── data
131
+ │ ├── external <- Data from third party sources.
132
+ │ ├── interim <- Intermediate data that has been transformed.
133
+ │ ├── processed <- The final, canonical data sets for modeling.
134
+ │ └── raw <- The original, immutable data dump.
135
+ ├── docs <- Directory for Sphinx documentation in rst or md.
136
+ ├── environment.yml <- The conda environment file for reproducibility.
137
+ ├── models <- Trained and serialized models, model predictions,
138
+ │ or model summaries.
139
+ ├── notebooks <- Jupyter notebooks. Naming convention is a number (for
140
+ │ ordering), the creator's initials and a description,
141
+ │ e.g. `1.0-fw-initial-data-exploration`.
142
+ ├── pyproject.toml <- Build system configuration. Do not change!
143
+ ├── references <- Data dictionaries, manuals, and all other materials.
144
+ ├── reports <- Generated analysis as HTML, PDF, LaTeX, etc.
145
+ │ └── figures <- Generated plots and figures for reports.
146
+ ├── scripts <- Analysis and production scripts which import the
147
+ │ actual Python package, e.g. train_model.py.
148
+ ├── setup.cfg <- Declarative configuration of your project.
149
+ ├── setup.py <- Use `pip install -e .` to install for development or
150
+ | or create a distribution with `tox -e build`.
151
+ ├── src
152
+ │ └── crazyneuraluser <- Actual Python package where the main functionality goes.
153
+ ├── tests <- Unit tests which can be run with `py.test`.
154
+ ├── .coveragerc <- Configuration for coverage reports of unit tests.
155
+ ├── .isort.cfg <- Configuration for git hook that sorts imports.
156
+ └── .pre-commit-config.yaml <- Configuration of pre-commit git hooks.
157
+ ```
158
+
159
+ <!-- pyscaffold-notes -->
160
+
161
+ ## Note
162
+
163
+ This project has been set up using [PyScaffold] 4.0.1 and the [dsproject extension] 0.6.1.
164
+
165
+ [conda]: https://docs.conda.io/
166
+ [pre-commit]: https://pre-commit.com/
167
+ [Jupyter]: https://jupyter.org/
168
+ [nbstripout]: https://github.com/kynan/nbstripout
169
+ [Google style]: http://google.github.io/styleguide/pyguide.html#38-comments-and-docstrings
170
+ [PyScaffold]: https://pyscaffold.org/
171
+ [dsproject extension]: https://github.com/pyscaffold/pyscaffoldext-dsproject
172
+
173
+
src/crazyneuraluser.egg-info/SOURCES.txt ADDED
@@ -0,0 +1,76 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .coveragerc
2
+ .gitignore
3
+ .isort.cfg
4
+ .pre-commit-config.yaml
5
+ .readthedocs.yml
6
+ AUTHORS.md
7
+ CHANGELOG.md
8
+ CONVLAB_README.md
9
+ LICENSE.txt
10
+ README.md
11
+ baselines_environment.lock.yml
12
+ baselines_environment.yml
13
+ dev_environment.yml
14
+ environment.yml
15
+ pyproject.toml
16
+ setup.cfg
17
+ setup.py
18
+ tox.ini
19
+ baselines/__init__.py
20
+ baselines/_preprocess_raw_canonical_map.py
21
+ baselines/baseline_setup.py
22
+ baselines/canonical_map.json
23
+ baselines/canonical_map.py
24
+ baselines/correct_categorical_state_values.tsv
25
+ baselines/create_data_from_multiwoz.py
26
+ baselines/create_dbleu_reference_map.py
27
+ baselines/goal_new_values.json
28
+ baselines/sanity_checks.py
29
+ baselines/schema.json
30
+ baselines/simulate_agent_interaction.py
31
+ baselines/simulate_corpus_interaction.py
32
+ baselines/system_models.py
33
+ baselines/user_models.py
34
+ baselines/utils.py
35
+ baselines/configs/agent_agent.yaml
36
+ configs/.gitignore
37
+ data/.gitignore
38
+ data/external/.gitignore
39
+ data/interim/.gitignore
40
+ data/preprocessed/.gitignore
41
+ data/raw/.gitignore
42
+ docs/Makefile
43
+ docs/authors.md
44
+ docs/changelog.md
45
+ docs/conf.py
46
+ docs/index.md
47
+ docs/license.rst
48
+ docs/readme.md
49
+ docs/requirements.txt
50
+ docs/_static/.gitignore
51
+ models/.gitignore
52
+ notebooks/1.0-ac-goals_consistency_check.ipynb
53
+ notebooks/template.ipynb
54
+ references/.gitignore
55
+ reports/figures/.gitignore
56
+ scripts/data_analysis.py
57
+ scripts/preprocess.py
58
+ scripts/preprocess2.1.py
59
+ scripts/template_train_model.py
60
+ scripts/train_ubar.py
61
+ src/crazyneuraluser/__init__.py
62
+ src/crazyneuraluser/clean_dataset.py
63
+ src/crazyneuraluser/config.py
64
+ src/crazyneuraluser/config21.py
65
+ src/crazyneuraluser/db_ops.py
66
+ src/crazyneuraluser/eval.py
67
+ src/crazyneuraluser/ontology.py
68
+ src/crazyneuraluser/reader.py
69
+ src/crazyneuraluser/utils.py
70
+ src/crazyneuraluser.egg-info/PKG-INFO
71
+ src/crazyneuraluser.egg-info/SOURCES.txt
72
+ src/crazyneuraluser.egg-info/dependency_links.txt
73
+ src/crazyneuraluser.egg-info/not-zip-safe
74
+ src/crazyneuraluser.egg-info/requires.txt
75
+ src/crazyneuraluser.egg-info/top_level.txt
76
+ tests/conftest.py
src/crazyneuraluser.egg-info/dependency_links.txt ADDED
@@ -0,0 +1 @@
 
 
1
+
src/crazyneuraluser.egg-info/not-zip-safe ADDED
@@ -0,0 +1 @@
 
 
1
+
src/crazyneuraluser.egg-info/requires.txt ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ transformers==4.18.0
2
+ tqdm==4.64.0
3
+ wandb==0.12.16
4
+ nltk==3.7
5
+ sklearn==0.0
6
+ tensorboard==2.9.0
7
+ spacy==3.3.0
8
+
9
+ [:python_version < "3.8"]
10
+ importlib-metadata
11
+
12
+ [testing]
13
+ setuptools
14
+ pytest
15
+ pytest-cov
src/crazyneuraluser.egg-info/top_level.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ crazyneuraluser
src/crazyneuraluser/UBAR_code/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+
3
+ if sys.version_info[:2] >= (3, 8):
4
+ # TODO: Import directly (no need for conditional) when `python_requires = >= 3.8`
5
+ from importlib.metadata import PackageNotFoundError, version # pragma: no cover
6
+ else:
7
+ from importlib_metadata import PackageNotFoundError, version # pragma: no cover
8
+
9
+ try:
10
+ # Change here if project is renamed and does not equal the package name
11
+ dist_name = __name__
12
+ __version__ = version(dist_name)
13
+ except PackageNotFoundError: # pragma: no cover
14
+ __version__ = "unknown"
15
+ finally:
16
+ del version, PackageNotFoundError
src/crazyneuraluser/UBAR_code/clean_dataset.py ADDED
@@ -0,0 +1,334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import re
3
+
4
+ from crazyneuraluser.UBAR_code import ontology
5
+
6
+
7
+ def my_clean_text(text):
8
+ text = re.sub(r"([a-zT]+)\.([a-z])", r"\1 . \2", text) # 'abc.xyz' -> 'abc . xyz'
9
+ text = re.sub(r"(\w+)\.\.? ", r"\1 . ", text) # if 'abc. ' -> 'abc . '
10
+ return text
11
+
12
+
13
+ def clean_text(text):
14
+ text = text.strip()
15
+ text = text.lower()
16
+ text = text.replace("’", "'")
17
+ text = text.replace("‘", "'")
18
+ text = text.replace(";", ",")
19
+ text = text.replace('"', " ")
20
+ text = text.replace("/", " and ")
21
+ text = text.replace("don't", "do n't")
22
+ text = clean_time(text)
23
+ baddata = {
24
+ r"c\.b (\d), (\d) ([a-z])\.([a-z])": r"cb\1\2\3\4",
25
+ "c.b. 1 7 d.y": "cb17dy",
26
+ "c.b.1 7 d.y": "cb17dy",
27
+ "c.b 25, 9 a.q": "cb259aq",
28
+ "isc.b 25, 9 a.q": "is cb259aq",
29
+ "c.b2, 1 u.f": "cb21uf",
30
+ "c.b 1,2 q.a": "cb12qa",
31
+ "0-122-336-5664": "01223365664",
32
+ "postcodecb21rs": "postcode cb21rs",
33
+ r"i\.d": "id",
34
+ " i d ": "id",
35
+ "Telephone:01223358966": "Telephone: 01223358966",
36
+ "depature": "departure",
37
+ "depearting": "departing",
38
+ "-type": " type",
39
+ r"b[\s]?&[\s]?b": "bed and breakfast",
40
+ "b and b": "bed and breakfast",
41
+ r"guesthouse[s]?": "guest house",
42
+ r"swimmingpool[s]?": "swimming pool",
43
+ "wo n't": "will not",
44
+ " 'd ": " would ",
45
+ " 'm ": " am ",
46
+ " 're' ": " are ",
47
+ " 'll' ": " will ",
48
+ " 've ": " have ",
49
+ r"^\'": "",
50
+ r"\'$": "",
51
+ }
52
+ for tmpl, good in baddata.items():
53
+ text = re.sub(tmpl, good, text)
54
+
55
+ text = re.sub(r"([a-zT]+)\.([a-z])", r"\1 . \2", text) # 'abc.xyz' -> 'abc . xyz'
56
+ text = re.sub(r"(\w+)\.\.? ", r"\1 . ", text) # if 'abc. ' -> 'abc . '
57
+
58
+ with open("data/raw/UBAR/multi-woz/mapping.pair", "r") as fin:
59
+ for line in fin.readlines():
60
+ fromx, tox = line.replace("\n", "").split("\t")
61
+ text = " " + text + " "
62
+ text = text.replace(" " + fromx + " ", " " + tox + " ")[1:-1]
63
+
64
+ return text
65
+
66
+
67
+ def clean_time(utter):
68
+ utter = re.sub(
69
+ r"(\d+) ([ap]\.?m)", lambda x: x.group(1) + x.group(2), utter
70
+ ) # 9 am -> 9am
71
+ utter = re.sub(r"((?<!\d)\d:\d+)(am)?", r"0\1", utter)
72
+ utter = re.sub(r"((?<!\d)\d)am", r"0\1:00", utter)
73
+ utter = re.sub(r"((?<!\d)\d)pm", lambda x: str(int(x.group(1)) + 12) + ":00", utter)
74
+ utter = re.sub(
75
+ r"(\d+)(:\d+)pm", lambda x: str(int(x.group(1)) + 12) + x.group(2), utter
76
+ )
77
+ utter = re.sub(r"(\d+)a\.?m", r"\1", utter)
78
+ return utter
79
+
80
+
81
+ def clean_slot_values(domain, slot, value):
82
+ value = clean_text(value)
83
+ if not value:
84
+ value = ""
85
+ elif value == "not mentioned":
86
+ value = ""
87
+ # value = 'not mentioned' # if in DST setting
88
+ elif domain == "attraction":
89
+ if slot == "name":
90
+ if value == "t":
91
+ value = ""
92
+ if value == "trinity":
93
+ value = "trinity college"
94
+ elif slot == "area":
95
+ if value in ["town centre", "cent", "center", "ce"]:
96
+ value = "centre"
97
+ elif value in ["ely", "in town", "museum", "norwich", "same area as hotel"]:
98
+ value = ""
99
+ elif value in ["we"]:
100
+ value = "west"
101
+ elif slot == "type":
102
+ if value in ["m", "mus", "musuem"]:
103
+ value = "museum"
104
+ elif value in ["art", "architectural"]:
105
+ value = "architecture"
106
+ elif value in ["churches"]:
107
+ value = "church"
108
+ elif value in ["coll"]:
109
+ value = "college"
110
+ elif value in ["concert", "concerthall"]:
111
+ value = "concert hall"
112
+ elif value in ["night club"]:
113
+ value = "nightclub"
114
+ elif value in ["mutiple sports", "mutliple sports", "sports", "galleria"]:
115
+ value = "multiple sports"
116
+ elif value in ["ol", "science", "gastropub", "la raza"]:
117
+ value = ""
118
+ elif value in ["swimmingpool", "pool"]:
119
+ value = "swimming pool"
120
+ elif value in ["fun"]:
121
+ value = "entertainment"
122
+
123
+ elif domain == "hotel":
124
+ if slot == "area":
125
+ if value in ["cen", "centre of town", "near city center", "center"]:
126
+ value = "centre"
127
+ elif value in ["east area", "east side"]:
128
+ value = "east"
129
+ elif value in ["in the north", "north part of town"]:
130
+ value = "north"
131
+ elif value in ["we"]:
132
+ value = "west"
133
+ elif slot == "day":
134
+ if value == "monda":
135
+ value = "monday"
136
+ elif value == "t":
137
+ value = "tuesday"
138
+ elif slot == "name":
139
+ if value == "uni":
140
+ value = "university arms hotel"
141
+ elif value == "university arms":
142
+ value = "university arms hotel"
143
+ elif value == "acron":
144
+ value = "acorn guest house"
145
+ elif value == "ashley":
146
+ value = "ashley hotel"
147
+ elif value == "arbury lodge guesthouse":
148
+ value = "arbury lodge guest house"
149
+ elif value == "la":
150
+ value = "la margherit"
151
+ elif value == "no":
152
+ value = ""
153
+ elif slot == "internet":
154
+ if value == "does not":
155
+ value = "no"
156
+ elif value in ["y", "free", "free internet"]:
157
+ value = "yes"
158
+ elif value in ["4"]:
159
+ value = ""
160
+ elif slot == "parking":
161
+ if value == "n":
162
+ value = "no"
163
+ elif value in ["free parking"]:
164
+ value = "yes"
165
+ elif value in ["y"]:
166
+ value = "yes"
167
+ elif slot in ["pricerange", "price range"]:
168
+ slot = "pricerange"
169
+ if value == "moderately":
170
+ value = "moderate"
171
+ elif value in ["any"]:
172
+ value = "do n't care"
173
+ elif value in ["any"]:
174
+ value = "do n't care"
175
+ elif value in ["inexpensive"]:
176
+ value = "cheap"
177
+ elif value in ["2", "4"]:
178
+ value = ""
179
+ elif slot == "stars":
180
+ if value == "two":
181
+ value = "2"
182
+ elif value == "three":
183
+ value = "3"
184
+ elif value in ["4-star", "4 stars", "4 star", "four star", "four stars"]:
185
+ value = "4"
186
+ elif slot == "type":
187
+ if value == "0 star rarting":
188
+ value = ""
189
+ elif value == "guesthouse":
190
+ value = "guest house"
191
+ elif value not in ["hotel", "guest house", "do n't care"]:
192
+ value = ""
193
+ elif domain == "restaurant":
194
+ if slot == "area":
195
+ if value in [
196
+ "center",
197
+ "scentre",
198
+ "center of town",
199
+ "city center",
200
+ "cb30aq",
201
+ "town center",
202
+ "centre of cambridge",
203
+ "city centre",
204
+ ]:
205
+ value = "centre"
206
+ elif value == "west part of town":
207
+ value = "west"
208
+ elif value == "n":
209
+ value = "north"
210
+ elif value in ["the south"]:
211
+ value = "south"
212
+ elif value not in [
213
+ "centre",
214
+ "south",
215
+ "do n't care",
216
+ "west",
217
+ "east",
218
+ "north",
219
+ ]:
220
+ value = ""
221
+ elif slot == "day":
222
+ if value == "monda":
223
+ value = "monday"
224
+ elif value == "t":
225
+ value = "tuesday"
226
+ elif slot in ["pricerange", "price range"]:
227
+ slot = "pricerange"
228
+ if value in ["moderately", "mode", "mo"]:
229
+ value = "moderate"
230
+ elif value in ["not"]:
231
+ value = ""
232
+ elif value in ["inexpensive", "ch"]:
233
+ value = "cheap"
234
+ elif slot == "food":
235
+ if value == "barbecue":
236
+ value = "barbeque"
237
+ elif slot == "pricerange":
238
+ if value == "moderately":
239
+ value = "moderate"
240
+ elif slot == "time":
241
+ if value == "9:00":
242
+ value = "09:00"
243
+ elif value == "9:45":
244
+ value = "09:45"
245
+ elif value == "1330":
246
+ value = "13:30"
247
+ elif value == "1430":
248
+ value = "14:30"
249
+ elif value == "9:15":
250
+ value = "09:15"
251
+ elif value == "9:30":
252
+ value = "09:30"
253
+ elif value == "1830":
254
+ value = "18:30"
255
+ elif value == "9":
256
+ value = "09:00"
257
+ elif value == "2:00":
258
+ value = "14:00"
259
+ elif value == "1:00":
260
+ value = "13:00"
261
+ elif value == "3:00":
262
+ value = "15:00"
263
+ elif domain == "taxi":
264
+ if slot in ["arriveBy", "arrive by"]:
265
+ slot = "arriveby"
266
+ if value == "1530":
267
+ value = "15:30"
268
+ elif value == "15 minutes":
269
+ value = ""
270
+ elif slot in ["leaveAt", "leave at"]:
271
+ slot = "leaveat"
272
+ if value == "1:00":
273
+ value = "01:00"
274
+ elif value == "21:4":
275
+ value = "21:04"
276
+ elif value == "4:15":
277
+ value = "04:15"
278
+ elif value == "5:45":
279
+ value = "05:45"
280
+ elif value == "0700":
281
+ value = "07:00"
282
+ elif value == "4:45":
283
+ value = "04:45"
284
+ elif value == "8:30":
285
+ value = "08:30"
286
+ elif value == "9:30":
287
+ value = "09:30"
288
+ value = value.replace(".", ":")
289
+
290
+ elif domain == "train":
291
+ if slot in ["arriveBy", "arrive by"]:
292
+ slot = "arriveby"
293
+ if value == "1":
294
+ value = "01:00"
295
+ elif value in ["does not care", "doesnt care", "doesn't care"]:
296
+ value = "do n't care"
297
+ elif value == "8:30":
298
+ value = "08:30"
299
+ elif value == "not 15:45":
300
+ value = ""
301
+ value = value.replace(".", ":")
302
+ elif slot == "day":
303
+ if value == "doesnt care" or value == "doesn't care":
304
+ value = "do n't care"
305
+ elif slot in ["leaveAt", "leave at"]:
306
+ slot = "leaveat"
307
+ if value == "2:30":
308
+ value = "02:30"
309
+ elif value == "7:54":
310
+ value = "07:54"
311
+ elif value == "after 5:45 pm":
312
+ value = "17:45"
313
+ elif value in ["early evening", "friday", "sunday", "tuesday", "afternoon"]:
314
+ value = ""
315
+ elif value == "12":
316
+ value = "12:00"
317
+ elif value == "1030":
318
+ value = "10:30"
319
+ elif value == "1700":
320
+ value = "17:00"
321
+ elif value in [
322
+ "does not care",
323
+ "doesnt care",
324
+ "do nt care",
325
+ "doesn't care",
326
+ ]:
327
+ value = "do n't care"
328
+
329
+ value = value.replace(".", ":")
330
+ if value in ["dont care", "don't care", "do nt care", "doesn't care"]:
331
+ value = "do n't care"
332
+ if ontology.normlize_slot_names.get(slot):
333
+ slot = ontology.normlize_slot_names[slot]
334
+ return slot, value
src/crazyneuraluser/UBAR_code/config.py ADDED
@@ -0,0 +1,164 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import time
4
+
5
+
6
+ class _Config:
7
+ def __init__(self):
8
+ self._multiwoz_ubar_init()
9
+
10
+ def _multiwoz_ubar_init(self):
11
+ self.gpt_path = "distilgpt2"
12
+
13
+ self.vocab_path_train = "./data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/vocab"
14
+ self.vocab_path_eval = None
15
+ self.data_path = "./data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/"
16
+ self.data_file = "data_for_ubar.json"
17
+ self.dev_list = "data/raw/UBAR/multi-woz/valListFile.json"
18
+ self.test_list = "data/raw/UBAR/multi-woz/testListFile.json"
19
+ self.dbs = {
20
+ "attraction": "data/preprocessed_gen_usr_utts/UBAR/db_processed/attraction_db_processed.json",
21
+ "hospital": "data/preprocessed_gen_usr_utts/UBAR/db_processed/hospital_db_processed.json",
22
+ "hotel": "data/preprocessed_gen_usr_utts/UBAR/db_processed/hotel_db_processed.json",
23
+ "police": "data/preprocessed_gen_usr_utts/UBAR/db_processed/police_db_processed.json",
24
+ "restaurant": "data/preprocessed_gen_usr_utts/UBAR/db_processed/restaurant_db_processed.json",
25
+ "taxi": "data/preprocessed_gen_usr_utts/UBAR/db_processed/taxi_db_processed.json",
26
+ "train": "data/preprocessed_gen_usr_utts/UBAR/db_processed/train_db_processed.json",
27
+ }
28
+ self.glove_path = "./data/glove/glove.6B.50d.txt"
29
+ self.domain_file_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/domain_files.json"
30
+ self.slot_value_set_path = "data/preprocessed_gen_usr_utts/UBAR/db_processed/value_set_processed.json"
31
+ self.multi_acts_path = "data/preprocessed_gen_usr_utts/UBAR/multi-woz-processed/multi_act_mapping_train.json"
32
+ self.exp_path = "to be generated"
33
+ self.log_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
34
+
35
+ # experiment settings
36
+ self.mode = "unknown"
37
+ self.cuda = False
38
+ self.cuda_device = [0]
39
+ self.exp_no = ""
40
+ self.seed = 11
41
+ self.save_log = True # tensorboard
42
+ self.evaluate_during_training = False # evaluate during training
43
+ self.report_interval = 200 # 485 for bs 128
44
+ self.max_nl_length = 60
45
+ self.max_span_length = 30
46
+ self.truncated = False
47
+
48
+ # training settings
49
+ self.lr = 1e-4
50
+ self.warmup_steps = -1
51
+ self.weight_decay = 0.0
52
+ self.gradient_accumulation_steps = 16
53
+ self.batch_size = 2
54
+
55
+ self.label_smoothing = 0.0
56
+ self.lr_decay = 0.5
57
+ self.epoch_num = 40
58
+ self.early_stop_count = 5
59
+ self.weight_decay_count = 3
60
+ self.teacher_force = 100
61
+ self.multi_acts_training = False
62
+ self.multi_act_sampling_num = 1
63
+ self.valid_loss = "score"
64
+
65
+ self.wandb_train_run_name = "Train with generated usr utterances"
66
+
67
+ # evaluation settings
68
+ self.eval_load_path = "models/UBAR/experiments/distilgpt-2_sd11_lr0.0001_bs16_ga2/epoch53_trloss0.59_gpt2"
69
+ self.model_output = "model_output_e2e_FFFT_fix_bs.json"
70
+ self.eval_per_domain = False
71
+ self.eval_set = "test" # test, dev
72
+
73
+ self.wandb_eval_run_name = "US Generated usr utterances evaluation"
74
+
75
+ # my setting
76
+ self.use_true_prev_bspn = False
77
+ self.use_true_prev_aspn = False
78
+ self.use_true_db_pointer = False
79
+ self.use_true_prev_resp = False
80
+
81
+ self.use_true_curr_bspn = False
82
+ self.use_true_curr_aspn = False
83
+ self.use_all_previous_context = True
84
+
85
+ self.exp_domains = ["all"] # hotel,train, attraction, restaurant, taxi
86
+ self.log_path = "logs_test"
87
+ self.low_resource = False
88
+ ###
89
+
90
+ # dst setting
91
+ self.fix_bs = True
92
+ self.use_nodelex_resp = True
93
+ self.max_context_length = 900
94
+ ##
95
+
96
+ # model settings
97
+ self.vocab_size = 3000
98
+ self.embed_size = 50
99
+ self.hidden_size = 100
100
+ self.pointer_dim = 6 # fixed
101
+ self.enc_layer_num = 1
102
+ self.dec_layer_num = 1
103
+ self.dropout = 0
104
+ self.layer_norm = False
105
+ self.skip_connect = False
106
+ self.encoder_share = False
107
+ self.attn_param_share = False
108
+ self.copy_param_share = False
109
+ self.enable_aspn = True
110
+ self.use_pvaspn = False
111
+ self.enable_bspn = True
112
+ self.bspn_mode = "bspn" # 'bspn' or 'bsdx'
113
+ self.enable_dspn = False # removed
114
+ self.enable_dst = False
115
+
116
+ self.use_true_bspn_for_ctr_eval = True
117
+ self.use_true_domain_for_ctr_eval = True
118
+ self.limit_bspn_vocab = False
119
+ self.limit_aspn_vocab = False
120
+ self.same_eval_as_cambridge = True
121
+ self.same_eval_act_f1_as_hdsa = False
122
+ self.aspn_decode_mode = "greedy" # beam, greedy, nucleur_sampling, topk_sampling
123
+ self.beam_width = 5
124
+ self.nbest = 5
125
+ self.beam_diverse_param = 0.2
126
+ self.act_selection_scheme = "high_test_act_f1"
127
+ self.topk_num = 1
128
+ self.nucleur_p = 0.0
129
+ self.record_mode = False
130
+
131
+ def __str__(self):
132
+ s = ""
133
+ for k, v in self.__dict__.items():
134
+ s += "{} : {}\n".format(k, v)
135
+ return s
136
+
137
+ def _init_logging_handler(self, mode):
138
+ stderr_handler = logging.StreamHandler()
139
+ if not os.path.exists("./log"):
140
+ os.mkdir("./log")
141
+ if self.save_log and self.mode == "train":
142
+ file_handler = logging.FileHandler(
143
+ "./log/log_{}_{}_{}_{}_sd{}.txt".format(
144
+ self.log_time,
145
+ mode,
146
+ "-".join(self.exp_domains),
147
+ self.exp_no,
148
+ self.seed,
149
+ )
150
+ )
151
+ logging.basicConfig(handlers=[stderr_handler, file_handler])
152
+ elif self.mode == "test":
153
+ eval_log_path = os.path.join(self.eval_load_path, "eval_log.json")
154
+ # if os.path.exists(eval_log_path):
155
+ # os.remove(eval_log_path)
156
+ file_handler = logging.FileHandler(eval_log_path)
157
+ logging.basicConfig(handlers=[stderr_handler, file_handler])
158
+ else:
159
+ logging.basicConfig(handlers=[stderr_handler])
160
+ logger = logging.getLogger()
161
+ logger.setLevel(logging.INFO)
162
+
163
+
164
+ global_config = _Config()
src/crazyneuraluser/UBAR_code/config21.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import time
4
+
5
+
6
+ class _Config:
7
+ def __init__(self):
8
+ self._multiwoz_ubar_init()
9
+
10
+ def _multiwoz_ubar_init(self):
11
+ self.gpt_path = "/data/yangyy/BERT-models/huggingface/distilgpt2/"
12
+
13
+ self.vocab_path_train = "./data/multi-woz-2.1-processed/vocab"
14
+ self.vocab_path_eval = None
15
+ self.data_path = "./data/multi-woz-2.1-processed/"
16
+ self.data_file = "data_for_ubar.json"
17
+ self.dev_list = "data/multi-woz/valListFile.json"
18
+ self.test_list = "data/multi-woz/testListFile.json"
19
+ self.dbs = {
20
+ "attraction": "data/preprocessed/UBAR/db_processed/attraction_db_processed.json",
21
+ "hospital": "data/preprocessed/UBAR/db_processed/hospital_db_processed.json",
22
+ "hotel": "data/preprocessed/UBAR/db_processed/hotel_db_processed.json",
23
+ "police": "data/preprocessed/UBAR/db_processed/police_db_processed.json",
24
+ "restaurant": "data/preprocessed/UBAR/db_processed/restaurant_db_processed.json",
25
+ "taxi": "data/preprocessed/UBAR/db_processed/taxi_db_processed.json",
26
+ "train": "data/preprocessed/UBAR/db_processed/train_db_processed.json",
27
+ }
28
+ self.glove_path = "./data/glove/glove.6B.50d.txt"
29
+ self.domain_file_path = (
30
+ "data/preprocessed/UBAR/multi-woz-2.1-processed/domain_files.json"
31
+ )
32
+ self.slot_value_set_path = (
33
+ "data/preprocessed/UBAR/db_processed/value_set_processed.json"
34
+ )
35
+ self.multi_acts_path = "data/preprocessed/UBAR/multi-woz-2.1-processed/multi_act_mapping_train.json"
36
+ self.exp_path = "to be generated"
37
+ self.log_time = time.strftime("%Y-%m-%d-%H-%M-%S", time.localtime())
38
+
39
+ # experiment settings
40
+ self.mode = "unknown"
41
+ self.cuda = True
42
+ self.cuda_device = [1]
43
+ self.exp_no = ""
44
+ self.seed = 11
45
+ self.exp_domains = ["all"]
46
+ self.save_log = True # tensorboard
47
+ self.evaluate_during_training = False # evaluate during training
48
+ self.report_interval = 200 # 485 for bs 128
49
+ self.max_nl_length = 60
50
+ self.max_span_length = 30
51
+ self.truncated = False
52
+
53
+ # model settings
54
+ self.vocab_size = 3000
55
+ self.embed_size = 50
56
+ self.hidden_size = 100
57
+ self.pointer_dim = 6 # fixed
58
+ self.enc_layer_num = 1
59
+ self.dec_layer_num = 1
60
+ self.dropout = 0
61
+ self.layer_norm = False
62
+ self.skip_connect = False
63
+ self.encoder_share = False
64
+ self.attn_param_share = False
65
+ self.copy_param_share = False
66
+ self.enable_aspn = True
67
+ self.use_pvaspn = False
68
+ self.enable_bspn = True
69
+ self.bspn_mode = "bsdx" # 'bspn' or 'bsdx'
70
+ self.enable_dspn = False # removed
71
+ self.enable_dst = False
72
+
73
+ # training settings
74
+ self.lr = 5e-4
75
+ self.warmup_steps = 2000 # gpt tbd
76
+ self.weight_decay = 0.0 # gpt tbd
77
+ self.gradient_accumulation_steps = 16
78
+ self.batch_size = 2
79
+
80
+ self.label_smoothing = 0.0
81
+ self.lr_decay = 0.5
82
+ self.epoch_num = 60
83
+ self.early_stop_count = 5
84
+ self.weight_decay_count = 3
85
+ self.teacher_force = 100
86
+ self.multi_acts_training = False
87
+ self.multi_act_sampling_num = 1
88
+ self.valid_loss = "score"
89
+
90
+ self.wandb_train_run_name = "Name to be added"
91
+
92
+ # evaluation settings
93
+ self.eval_load_path = "models/UBAR/experiments/all_0729_sd11_lr0.0001_bs2_ga16/epoch43_trloss0.56_gpt2"
94
+ self.model_output = "model_output_e2e_FFFT_fix_bs.json"
95
+ self.eval_per_domain = False
96
+ self.eval_set = "test" # test, dev
97
+
98
+ self.wandb_eval_run_name = "Name to be added"
99
+
100
+ # generation setting
101
+ self.use_true_prev_bspn = True
102
+ self.use_true_prev_aspn = True
103
+ self.use_true_db_pointer = False
104
+ self.use_true_prev_resp = True
105
+
106
+ self.use_true_curr_bspn = True
107
+ self.use_true_curr_aspn = False
108
+ self.use_all_previous_context = True
109
+
110
+ self.exp_domains = ["all"] # hotel,train, attraction, restaurant, taxi
111
+ self.log_path = "logs2.1"
112
+ self.low_resource = False
113
+
114
+ # dst setting
115
+ self.fix_bs = True
116
+ self.use_nodelex_resp = True
117
+ self.max_context_length = 900
118
+
119
+ self.use_true_bspn_for_ctr_eval = True
120
+ self.use_true_domain_for_ctr_eval = True
121
+ self.limit_bspn_vocab = False
122
+ self.limit_aspn_vocab = False
123
+ self.same_eval_as_cambridge = True
124
+ self.same_eval_act_f1_as_hdsa = False
125
+ self.aspn_decode_mode = (
126
+ "greedy" # beam, greedy, nucleur_sampling, topk_sampling
127
+ )
128
+ self.beam_width = 5
129
+ self.nbest = 5
130
+ self.beam_diverse_param = 0.2
131
+ self.act_selection_scheme = "high_test_act_f1"
132
+ self.topk_num = 1
133
+ self.nucleur_p = 0.0
134
+ self.record_mode = False
135
+
136
+ def __str__(self):
137
+ s = ""
138
+ for k, v in self.__dict__.items():
139
+ s += "{} : {}\n".format(k, v)
140
+ return s
141
+
142
+ def _init_logging_handler(self, mode):
143
+ stderr_handler = logging.StreamHandler()
144
+ if not os.path.exists("./log"):
145
+ os.mkdir("./log")
146
+ if self.save_log and self.mode == "train":
147
+ file_handler = logging.FileHandler(
148
+ "./log/log_{}_{}_{}_{}_sd{}.txt".format(
149
+ self.log_time,
150
+ mode,
151
+ "-".join(self.exp_domains),
152
+ self.exp_no,
153
+ self.seed,
154
+ )
155
+ )
156
+ logging.basicConfig(handlers=[stderr_handler, file_handler])
157
+ elif self.mode == "test":
158
+ eval_log_path = os.path.join(self.eval_load_path, "eval_log.json")
159
+ # if os.path.exists(eval_log_path):
160
+ # os.remove(eval_log_path)
161
+ file_handler = logging.FileHandler(eval_log_path)
162
+ logging.basicConfig(handlers=[stderr_handler, file_handler])
163
+ else:
164
+ logging.basicConfig(handlers=[stderr_handler])
165
+ logger = logging.getLogger()
166
+ logger.setLevel(logging.INFO)
167
+
168
+
169
+ global_config = _Config()
src/crazyneuraluser/UBAR_code/db_ops.py ADDED
@@ -0,0 +1,314 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import random
3
+ import sqlite3
4
+ import string
5
+
6
+ from crazyneuraluser.UBAR_code.ontology import all_domains, db_domains
7
+
8
+
9
+ class MultiWozDB(object):
10
+ def __init__(self, db_paths):
11
+ self.dbs = {}
12
+ self.sql_dbs = {}
13
+ for domain in all_domains:
14
+ with open(db_paths[domain], "r") as f:
15
+ self.dbs[domain] = json.loads(f.read().lower())
16
+
17
+ def oneHotVector(self, domain, num):
18
+ """Return number of available entities for particular domain."""
19
+ vector = [0, 0, 0, 0]
20
+ if num == "":
21
+ return vector
22
+ if domain != "train":
23
+ if num == 0:
24
+ vector = [1, 0, 0, 0]
25
+ elif num == 1:
26
+ vector = [0, 1, 0, 0]
27
+ elif num <= 3:
28
+ vector = [0, 0, 1, 0]
29
+ else:
30
+ vector = [0, 0, 0, 1]
31
+ else:
32
+ if num == 0:
33
+ vector = [1, 0, 0, 0]
34
+ elif num <= 5:
35
+ vector = [0, 1, 0, 0]
36
+ elif num <= 10:
37
+ vector = [0, 0, 1, 0]
38
+ else:
39
+ vector = [0, 0, 0, 1]
40
+ return vector
41
+
42
+ def addBookingPointer(self, turn_da):
43
+ """Add information about availability of the booking option."""
44
+ # Booking pointer
45
+ # Do not consider booking two things in a single turn.
46
+ vector = [0, 0]
47
+ if turn_da.get("booking-nobook"):
48
+ vector = [1, 0]
49
+ if turn_da.get("booking-book") or turn_da.get("train-offerbooked"):
50
+ vector = [0, 1]
51
+ return vector
52
+
53
+ def addDBPointer(self, domain, match_num, return_num=False):
54
+ """Create database pointer for all related domains."""
55
+ # if turn_domains is None:
56
+ # turn_domains = db_domains
57
+ if domain in db_domains:
58
+ vector = self.oneHotVector(domain, match_num)
59
+ else:
60
+ vector = [0, 0, 0, 0]
61
+ return vector
62
+
63
+ def addDBIndicator(self, domain, match_num, return_num=False):
64
+ """Create database indicator for all related domains."""
65
+ # if turn_domains is None:
66
+ # turn_domains = db_domains
67
+ if domain in db_domains:
68
+ vector = self.oneHotVector(domain, match_num)
69
+ else:
70
+ vector = [0, 0, 0, 0]
71
+
72
+ # '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]'
73
+ if vector == [0, 0, 0, 0]:
74
+ indicator = "[db_nores]"
75
+ else:
76
+ indicator = "[db_%s]" % vector.index(1)
77
+ return indicator
78
+
79
+ def get_match_num(self, constraints, return_entry=False):
80
+ """Create database pointer for all related domains."""
81
+ match = {"general": ""}
82
+ entry = {}
83
+ # if turn_domains is None:
84
+ # turn_domains = db_domains
85
+ for domain in all_domains:
86
+ match[domain] = ""
87
+ if domain in db_domains and constraints.get(domain):
88
+ matched_ents = self.queryJsons(domain, constraints[domain])
89
+ match[domain] = len(matched_ents)
90
+ if return_entry:
91
+ entry[domain] = matched_ents
92
+ if return_entry:
93
+ return entry
94
+ return match
95
+
96
+ def pointerBack(self, vector, domain):
97
+ # multi domain implementation
98
+ # domnum = cfg.domain_num
99
+ if domain.endswith("]"):
100
+ domain = domain[1:-1]
101
+ if domain != "train":
102
+ nummap = {0: "0", 1: "1", 2: "2-3", 3: ">3"}
103
+ else:
104
+ nummap = {0: "0", 1: "1-5", 2: "6-10", 3: ">10"}
105
+ if vector[:4] == [0, 0, 0, 0]:
106
+ report = ""
107
+ else:
108
+ num = vector.index(1)
109
+ report = domain + ": " + nummap[num] + "; "
110
+
111
+ if vector[-2] == 0 and vector[-1] == 1:
112
+ report += "booking: ok"
113
+ if vector[-2] == 1 and vector[-1] == 0:
114
+ report += "booking: unable"
115
+
116
+ return report
117
+
118
+ def queryJsons(self, domain, constraints, exactly_match=True, return_name=False):
119
+ """Returns the list of entities for a given domain
120
+ based on the annotation of the belief state
121
+ constraints: dict e.g. {'pricerange': 'cheap', 'area': 'west'}
122
+ """
123
+ # query the db
124
+ if domain == "taxi":
125
+ return [
126
+ {
127
+ "taxi_colors": random.choice(self.dbs[domain][0]["taxi_colors"]),
128
+ "taxi_types": random.choice(self.dbs[domain][0]["taxi_types"]),
129
+ "taxi_phone": "".join(random.choices(string.digits, k=10)),
130
+ }
131
+ ]
132
+ if domain == "police":
133
+ return self.dbs["police"]
134
+ if domain == "hospital":
135
+ if constraints.get("department"):
136
+ for entry in self.dbs["hospital"]:
137
+ if entry.get("department") == constraints.get("department"):
138
+ return [entry]
139
+ else:
140
+ # Instead of returning an empty list which breaks lexicalisation, when is no department constraint,
141
+ # return the first entry from the hospital db so the user still gets hospital information.
142
+ return [self.dbs["hospital"][0]]
143
+
144
+ valid_cons = False
145
+ for v in constraints.values():
146
+ if v not in ["not mentioned", ""]:
147
+ valid_cons = True
148
+ if not valid_cons:
149
+ return []
150
+
151
+ match_result = []
152
+
153
+ if "name" in constraints:
154
+ for db_ent in self.dbs[domain]:
155
+ if "name" in db_ent:
156
+ cons = constraints["name"]
157
+ dbn = db_ent["name"]
158
+ if cons == dbn:
159
+ db_ent = db_ent if not return_name else db_ent["name"]
160
+ match_result.append(db_ent)
161
+ return match_result
162
+
163
+ for db_ent in self.dbs[domain]:
164
+ match = True
165
+ for s, v in constraints.items():
166
+ if s == "name":
167
+ continue
168
+ if (
169
+ s in ["people", "stay"]
170
+ or (domain == "hotel" and s == "day")
171
+ or (domain == "restaurant" and s in ["day", "time"])
172
+ ):
173
+ continue
174
+
175
+ skip_case = {
176
+ "don't care": 1,
177
+ "do n't care": 1,
178
+ "dont care": 1,
179
+ "not mentioned": 1,
180
+ "dontcare": 1,
181
+ "": 1,
182
+ }
183
+ if skip_case.get(v):
184
+ continue
185
+
186
+ if s not in db_ent:
187
+ # logging.warning('Searching warning: slot %s not in %s db'%(s, domain))
188
+ match = False
189
+ break
190
+
191
+ # v = 'guesthouse' if v == 'guest house' else v
192
+ # v = 'swimmingpool' if v == 'swimming pool' else v
193
+ v = "yes" if v == "free" else v
194
+
195
+ if s in ["arrive", "leave"]:
196
+ try:
197
+ h, m = v.split(":") # raise error if time value is not xx:xx format
198
+ v = int(h) * 60 + int(m)
199
+ except Exception:
200
+ match = False
201
+ break
202
+ time = int(db_ent[s].split(":")[0]) * 60 + int(db_ent[s].split(":")[1])
203
+ if s == "arrive" and v > time:
204
+ match = False
205
+ if s == "leave" and v < time:
206
+ match = False
207
+ else:
208
+ if exactly_match and v != db_ent[s]:
209
+ match = False
210
+ break
211
+ elif v not in db_ent[s]:
212
+ match = False
213
+ break
214
+
215
+ if match:
216
+ match_result.append(db_ent)
217
+
218
+ if not return_name:
219
+ return match_result
220
+ else:
221
+ if domain == "train":
222
+ match_result = [e["id"] for e in match_result]
223
+ else:
224
+ match_result = [e["name"] for e in match_result]
225
+ return match_result
226
+
227
+ def querySQL(self, domain, constraints):
228
+ if not self.sql_dbs:
229
+ for dom in db_domains:
230
+ db = "db/{}-dbase.db".format(dom)
231
+ conn = sqlite3.connect(db)
232
+ c = conn.cursor()
233
+ self.sql_dbs[dom] = c
234
+
235
+ sql_query = "select * from {}".format(domain)
236
+
237
+ flag = True
238
+ for key, val in constraints.items():
239
+ if (
240
+ val == ""
241
+ or val == "dontcare"
242
+ or val == "not mentioned"
243
+ or val == "don't care"
244
+ or val == "dont care"
245
+ or val == "do n't care"
246
+ ):
247
+ pass
248
+ else:
249
+ if flag:
250
+ sql_query += " where "
251
+ val2 = val.replace("'", "''")
252
+ # val2 = normalize(val2)
253
+ if key == "leaveAt":
254
+ sql_query += r" " + key + " > " + r"'" + val2 + r"'"
255
+ elif key == "arriveBy":
256
+ sql_query += r" " + key + " < " + r"'" + val2 + r"'"
257
+ else:
258
+ sql_query += r" " + key + "=" + r"'" + val2 + r"'"
259
+ flag = False
260
+ else:
261
+ val2 = val.replace("'", "''")
262
+ # val2 = normalize(val2)
263
+ if key == "leaveAt":
264
+ sql_query += r" and " + key + " > " + r"'" + val2 + r"'"
265
+ elif key == "arriveBy":
266
+ sql_query += r" and " + key + " < " + r"'" + val2 + r"'"
267
+ else:
268
+ sql_query += r" and " + key + "=" + r"'" + val2 + r"'"
269
+
270
+ try: # "select * from attraction where name = 'queens college'"
271
+ print(sql_query)
272
+ return self.sql_dbs[domain].execute(sql_query).fetchall()
273
+ except Exception:
274
+ return [] # TODO test it
275
+
276
+
277
+ if __name__ == "__main__":
278
+ dbPATHs = {
279
+ "attraction": "db/attraction_db_processed.json",
280
+ "hospital": "db/hospital_db_processed.json",
281
+ "hotel": "db/hotel_db_processed.json",
282
+ "police": "db/police_db_processed.json",
283
+ "restaurant": "db/restaurant_db_processed.json",
284
+ "taxi": "db/taxi_db_processed.json",
285
+ "train": "db/train_db_processed.json",
286
+ }
287
+ db = MultiWozDB(dbPATHs)
288
+ while True:
289
+ constraints = {}
290
+ inp = input("input belief state in fomat: domain-slot1=value1;slot2=value2...\n")
291
+ domain, cons = inp.split("-")
292
+ for sv in cons.split(";"):
293
+ s, v = sv.split("=")
294
+ constraints[s] = v
295
+ # res = db.querySQL(domain, constraints)
296
+ res = db.queryJsons(domain, constraints, return_name=True)
297
+ report = []
298
+ reidx = {
299
+ "hotel": 8,
300
+ "restaurant": 6,
301
+ "attraction": 5,
302
+ "train": 1,
303
+ }
304
+ # for ent in res:
305
+ # if reidx.get(domain):
306
+ # report.append(ent[reidx[domain]])
307
+ # for ent in res:
308
+ # if 'name' in ent:
309
+ # report.append(ent['name'])
310
+ # if 'trainid' in ent:
311
+ # report.append(ent['trainid'])
312
+ print(constraints)
313
+ print(res)
314
+ print("count:", len(res), "\nnames:", report)
src/crazyneuraluser/UBAR_code/eval.py ADDED
@@ -0,0 +1,932 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import logging
3
+ import math
4
+ from collections import Counter, OrderedDict
5
+
6
+ from nltk.util import ngrams
7
+
8
+ from crazyneuraluser.UBAR_code import ontology
9
+ from crazyneuraluser.UBAR_code.clean_dataset import clean_slot_values
10
+ from crazyneuraluser.UBAR_code.config import global_config as cfg
11
+
12
+
13
+ class BLEUScorer(object):
14
+ # BLEU score calculator via GentScorer interface
15
+ # it calculates the BLEU-4 by taking the entire corpus in
16
+ # Calulate based multiple candidates against multiple references
17
+ def __init__(self):
18
+ pass
19
+
20
+ def score(self, parallel_corpus):
21
+
22
+ # containers
23
+ count = [0, 0, 0, 0]
24
+ clip_count = [0, 0, 0, 0]
25
+ r = 0
26
+ c = 0
27
+ weights = [0.25, 0.25, 0.25, 0.25]
28
+
29
+ # accumulate ngram statistics
30
+ for hyps, refs in parallel_corpus:
31
+ hyps = [hyp.split() for hyp in hyps]
32
+ refs = [ref.split() for ref in refs]
33
+ for hyp in hyps:
34
+
35
+ for i in range(4):
36
+ # accumulate ngram counts
37
+ hypcnts = Counter(ngrams(hyp, i + 1))
38
+ cnt = sum(hypcnts.values())
39
+ count[i] += cnt
40
+
41
+ # compute clipped counts
42
+ max_counts = {}
43
+ for ref in refs:
44
+ refcnts = Counter(ngrams(ref, i + 1))
45
+ for ng in hypcnts:
46
+ max_counts[ng] = max(max_counts.get(ng, 0), refcnts[ng])
47
+ clipcnt = dict(
48
+ (ng, min(count, max_counts[ng]))
49
+ for ng, count in hypcnts.items()
50
+ )
51
+ clip_count[i] += sum(clipcnt.values())
52
+
53
+ # accumulate r & c
54
+ bestmatch = [1000, 1000]
55
+ for ref in refs:
56
+ if bestmatch[0] == 0:
57
+ break
58
+ diff = abs(len(ref) - len(hyp))
59
+ if diff < bestmatch[0]:
60
+ bestmatch[0] = diff
61
+ bestmatch[1] = len(ref)
62
+ r += bestmatch[1]
63
+ c += len(hyp)
64
+
65
+ # computing bleu score
66
+ p0 = 1e-7
67
+ bp = 1 if c > r else math.exp(1 - float(r) / float(c))
68
+ p_ns = [float(clip_count[i]) / float(count[i] + p0) + p0 for i in range(4)]
69
+ s = math.fsum(w * math.log(p_n) for w, p_n in zip(weights, p_ns) if p_n)
70
+ bleu = bp * math.exp(s)
71
+ return bleu * 100
72
+
73
+
74
+ class MultiWozEvaluator(object):
75
+ def __init__(self, reader):
76
+ self.reader = reader
77
+ self.domains = ontology.all_domains
78
+ self.domain_files = self.reader.domain_files
79
+ self.all_data = self.reader.data
80
+ self.test_data = self.reader.test
81
+
82
+ self.bleu_scorer = BLEUScorer()
83
+
84
+ self.all_info_slot = []
85
+ for d, s_list in ontology.informable_slots.items():
86
+ for s in s_list:
87
+ self.all_info_slot.append(d + "-" + s)
88
+
89
+ # only evaluate these slots for dialog success
90
+ self.requestables = ["phone", "address", "postcode", "reference", "id"]
91
+
92
+ def pack_dial(self, data):
93
+ dials = {}
94
+ for turn in data:
95
+ dial_id = turn["dial_id"]
96
+ if dial_id not in dials:
97
+ dials[dial_id] = []
98
+ dials[dial_id].append(turn)
99
+ return dials
100
+
101
+ def run_metrics(self, data):
102
+ if "all" in cfg.exp_domains:
103
+ metric_results = []
104
+ metric_result = self._get_metric_results(data)
105
+ metric_results.append(metric_result)
106
+
107
+ if cfg.eval_per_domain:
108
+ # all domain experiments, sub domain evaluation
109
+ domains = [d + "_single" for d in ontology.all_domains]
110
+ domains = domains + [
111
+ "restaurant_train",
112
+ "restaurant_hotel",
113
+ "restaurant_attraction",
114
+ "hotel_train",
115
+ "hotel_attraction",
116
+ "attraction_train",
117
+ "restaurant_hotel_taxi",
118
+ "restaurant_attraction_taxi",
119
+ "hotel_attraction_taxi",
120
+ ]
121
+ for domain in domains:
122
+ file_list = self.domain_files.get(domain, [])
123
+ if not file_list:
124
+ print("No sub domain [%s]" % domain)
125
+ metric_result = self._get_metric_results(data, domain, file_list)
126
+ if metric_result:
127
+ metric_results.append(metric_result)
128
+
129
+ else:
130
+ # sub domain experiments
131
+ metric_results = []
132
+ for domain, file_list in self.domain_files.items():
133
+ if domain not in cfg.exp_domains:
134
+ continue
135
+ metric_result = self._get_metric_results(data, domain, file_list)
136
+ if metric_result:
137
+ metric_results.append(metric_result)
138
+
139
+ return metric_results
140
+
141
+ def validation_metric(self, data):
142
+ bleu = self.bleu_metric(data)
143
+ # accu_single_dom, accu_multi_dom, multi_dom_num = self.domain_eval(data)
144
+ success, match, req_offer_counts, dial_num = self.context_to_response_eval(
145
+ data, same_eval_as_cambridge=cfg.same_eval_as_cambridge
146
+ )
147
+ return bleu, success, match
148
+
149
+ def _get_metric_results(self, data, domain="all", file_list=None):
150
+ metric_result = {"domain": domain}
151
+ bleu = self.bleu_metric(data, file_list)
152
+ if cfg.bspn_mode == "bspn" or cfg.enable_dst:
153
+ (
154
+ jg,
155
+ slot_f1,
156
+ slot_acc,
157
+ slot_cnt,
158
+ slot_corr,
159
+ ) = self.dialog_state_tracking_eval(data, file_list)
160
+ jg_nn, sf1_nn, sac_nn, _, _ = self.dialog_state_tracking_eval(
161
+ data, file_list, no_name=True, no_book=False
162
+ )
163
+ jg_nb, sf1_nb, sac_nb, _, _ = self.dialog_state_tracking_eval(
164
+ data, file_list, no_name=False, no_book=True
165
+ )
166
+ jg_nnnb, sf1_nnnb, sac_nnnb, _, _ = self.dialog_state_tracking_eval(
167
+ data, file_list, no_name=True, no_book=True
168
+ )
169
+ metric_result.update(
170
+ {"joint_goal": jg, "slot_acc": slot_acc, "slot_f1": slot_f1}
171
+ )
172
+ if cfg.bspn_mode == "bsdx":
173
+ (
174
+ jg_,
175
+ slot_f1_,
176
+ slot_acc_,
177
+ slot_cnt,
178
+ slot_corr,
179
+ ) = self.dialog_state_tracking_eval(data, file_list, bspn_mode="bsdx")
180
+ jg_nn_, sf1_nn_, sac_nn_, _, _ = self.dialog_state_tracking_eval(
181
+ data, file_list, bspn_mode="bsdx", no_name=True, no_book=False
182
+ )
183
+ metric_result.update(
184
+ {
185
+ "joint_goal_delex": jg_,
186
+ "slot_acc_delex": slot_acc_,
187
+ "slot_f1_delex": slot_f1_,
188
+ }
189
+ )
190
+
191
+ info_slots_acc = {}
192
+ for slot in slot_cnt:
193
+ correct = slot_corr.get(slot, 0)
194
+ info_slots_acc[slot] = correct / slot_cnt[slot] * 100
195
+ info_slots_acc = OrderedDict(sorted(info_slots_acc.items(), key=lambda x: x[1]))
196
+
197
+ act_f1 = self.aspn_eval(data, file_list)
198
+ avg_act_num, avg_diverse_score = self.multi_act_eval(data, file_list)
199
+ accu_single_dom, accu_multi_dom, multi_dom_num = self.domain_eval(
200
+ data, file_list
201
+ )
202
+
203
+ success, match, req_offer_counts, dial_num = self.context_to_response_eval(
204
+ data, file_list, same_eval_as_cambridge=cfg.same_eval_as_cambridge
205
+ )
206
+ req_slots_acc = {}
207
+ for req in self.requestables:
208
+ acc = req_offer_counts[req + "_offer"] / (
209
+ req_offer_counts[req + "_total"] + 1e-10
210
+ )
211
+ req_slots_acc[req] = acc * 100
212
+ req_slots_acc = OrderedDict(sorted(req_slots_acc.items(), key=lambda x: x[1]))
213
+
214
+ if dial_num:
215
+ metric_result.update(
216
+ {
217
+ "act_f1": act_f1,
218
+ "success": success,
219
+ "match": match,
220
+ "bleu": bleu,
221
+ "req_slots_acc": req_slots_acc,
222
+ "info_slots_acc": info_slots_acc,
223
+ "dial_num": dial_num,
224
+ "accu_single_dom": accu_single_dom,
225
+ "accu_multi_dom": accu_multi_dom,
226
+ "avg_act_num": avg_act_num,
227
+ "avg_diverse_score": avg_diverse_score,
228
+ }
229
+ )
230
+ if domain == "all":
231
+ logging.info(
232
+ "-------------------------- All DOMAINS --------------------------"
233
+ )
234
+ else:
235
+ logging.info(
236
+ "-------------------------- %s (# %d) -------------------------- "
237
+ % (domain.upper(), dial_num)
238
+ )
239
+ if cfg.bspn_mode == "bspn" or cfg.enable_dst:
240
+ logging.info(
241
+ "[DST] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f act f1: %2.1f"
242
+ % (jg, slot_acc, slot_f1, act_f1)
243
+ )
244
+ logging.info(
245
+ "[DST] [not eval name slots] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f"
246
+ % (jg_nn, sac_nn, sf1_nn)
247
+ )
248
+ logging.info(
249
+ "[DST] [not eval book slots] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f"
250
+ % (jg_nb, sac_nb, sf1_nb)
251
+ )
252
+ logging.info(
253
+ "[DST] [not eval name & book slots] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f"
254
+ % (jg_nnnb, sac_nnnb, sf1_nnnb)
255
+ )
256
+ if cfg.bspn_mode == "bsdx":
257
+ logging.info(
258
+ "[BDX] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f act f1: %2.1f"
259
+ % (jg_, slot_acc_, slot_f1_, act_f1)
260
+ )
261
+ logging.info(
262
+ "[BDX] [not eval name slots] joint goal:%2.1f slot acc: %2.1f slot f1: %2.1f"
263
+ % (jg_nn_, sac_nn_, sf1_nn_)
264
+ )
265
+ logging.info(
266
+ "[CTR] match: %2.1f success: %2.1f bleu: %2.1f"
267
+ % (match, success, bleu)
268
+ )
269
+ logging.info(
270
+ "[CTR] "
271
+ + "; ".join(
272
+ ["%s: %2.1f" % (req, acc) for req, acc in req_slots_acc.items()]
273
+ )
274
+ )
275
+ logging.info(
276
+ "[DOM] accuracy: single %2.1f / multi: %2.1f (%d)"
277
+ % (accu_single_dom, accu_multi_dom, multi_dom_num)
278
+ )
279
+ if self.reader.multi_acts_record is not None:
280
+ logging.info(
281
+ "[MA] avg acts num %2.1f avg slots num: %2.1f "
282
+ % (avg_act_num, avg_diverse_score)
283
+ )
284
+ return metric_result
285
+ else:
286
+ return None
287
+
288
+ def bleu_metric(self, data, eval_dial_list=None):
289
+ gen, truth = [], []
290
+ for row in data:
291
+ if eval_dial_list and row["dial_id"] + ".json" not in eval_dial_list:
292
+ continue
293
+ gen.append(row["resp_gen"])
294
+ truth.append(row["resp"])
295
+ wrap_generated = [[_] for _ in gen]
296
+ wrap_truth = [[_] for _ in truth]
297
+ if gen and truth:
298
+ sc = self.bleu_scorer.score(zip(wrap_generated, wrap_truth))
299
+ else:
300
+ sc = 0.0
301
+ return sc
302
+
303
+ def value_similar(self, a, b):
304
+ return True if a == b else False
305
+
306
+ # the value equal condition used in "Sequicity" is too loose
307
+ if (
308
+ a in b
309
+ or b in a
310
+ or a.split()[0] == b.split()[0]
311
+ or a.split()[-1] == b.split()[-1]
312
+ ):
313
+ return True
314
+ return False
315
+
316
+ def _bspn_to_dict(self, bspn, no_name=False, no_book=False, bspn_mode="bspn"):
317
+ constraint_dict = self.reader.bspan_to_constraint_dict(
318
+ bspn, bspn_mode=bspn_mode
319
+ )
320
+ constraint_dict_flat = {}
321
+ for domain, cons in constraint_dict.items():
322
+ for s, v in cons.items():
323
+ key = domain + "-" + s
324
+ if no_name and s == "name":
325
+ continue
326
+ if no_book:
327
+ if s in ["people", "stay"] or key in [
328
+ "hotel-day",
329
+ "restaurant-day",
330
+ "restaurant-time",
331
+ ]:
332
+ continue
333
+ constraint_dict_flat[key] = v
334
+ return constraint_dict_flat
335
+
336
+ def _constraint_compare(
337
+ self, truth_cons, gen_cons, slot_appear_num=None, slot_correct_num=None
338
+ ):
339
+ tp, fp, fn = 0, 0, 0
340
+ false_slot = []
341
+ for slot in gen_cons:
342
+ v_gen = gen_cons[slot]
343
+ if slot in truth_cons and self.value_similar(
344
+ v_gen, truth_cons[slot]
345
+ ): # v_truth = truth_cons[slot]
346
+ tp += 1
347
+ if slot_correct_num is not None:
348
+ slot_correct_num[slot] = (
349
+ 1
350
+ if not slot_correct_num.get(slot)
351
+ else slot_correct_num.get(slot) + 1
352
+ )
353
+ else:
354
+ fp += 1
355
+ false_slot.append(slot)
356
+ for slot in truth_cons:
357
+ v_truth = truth_cons[slot]
358
+ if slot_appear_num is not None:
359
+ slot_appear_num[slot] = (
360
+ 1
361
+ if not slot_appear_num.get(slot)
362
+ else slot_appear_num.get(slot) + 1
363
+ )
364
+ if slot not in gen_cons or not self.value_similar(v_truth, gen_cons[slot]):
365
+ fn += 1
366
+ false_slot.append(slot)
367
+ acc = len(self.all_info_slot) - fp - fn
368
+ return tp, fp, fn, acc, list(set(false_slot))
369
+
370
+ def domain_eval(self, data, eval_dial_list=None):
371
+ dials = self.pack_dial(data)
372
+ corr_single, total_single, corr_multi, total_multi = 0, 0, 0, 0
373
+
374
+ dial_num = 0
375
+ for dial_id in dials:
376
+ if eval_dial_list and dial_id + ".json" not in eval_dial_list:
377
+ continue
378
+ dial_num += 1
379
+ dial = dials[dial_id]
380
+ wrong_pred = []
381
+
382
+ prev_constraint_dict = {}
383
+ prev_turn_domain = ["general"]
384
+
385
+ for turn_num, turn in enumerate(dial):
386
+ if turn_num == 0:
387
+ continue
388
+ true_domains = self.reader.dspan_to_domain(turn["dspn"])
389
+ if cfg.enable_dspn:
390
+ pred_domains = self.reader.dspan_to_domain(turn["dspn_gen"])
391
+ else:
392
+ turn_dom_bs = []
393
+ if (
394
+ cfg.enable_bspn
395
+ and not cfg.use_true_bspn_for_ctr_eval
396
+ and (cfg.bspn_mode == "bspn" or cfg.enable_dst)
397
+ ):
398
+ constraint_dict = self.reader.bspan_to_constraint_dict(
399
+ turn["bspn_gen"]
400
+ )
401
+ else:
402
+ constraint_dict = self.reader.bspan_to_constraint_dict(
403
+ turn["bspn"]
404
+ )
405
+ for domain in constraint_dict:
406
+ if domain not in prev_constraint_dict:
407
+ turn_dom_bs.append(domain)
408
+ elif prev_constraint_dict[domain] != constraint_dict[domain]:
409
+ turn_dom_bs.append(domain)
410
+ aspn = "aspn" if not cfg.enable_aspn else "aspn_gen"
411
+ turn_dom_da = []
412
+ for a in turn[aspn].split():
413
+ if a[1:-1] in ontology.all_domains + ["general"]:
414
+ turn_dom_da.append(a[1:-1])
415
+
416
+ # get turn domain
417
+ turn_domain = turn_dom_bs
418
+ for dom in turn_dom_da:
419
+ if dom != "booking" and dom not in turn_domain:
420
+ turn_domain.append(dom)
421
+ if not turn_domain:
422
+ turn_domain = prev_turn_domain
423
+ if len(turn_domain) == 2 and "general" in turn_domain:
424
+ turn_domain.remove("general")
425
+ if len(turn_domain) == 2:
426
+ if (
427
+ len(prev_turn_domain) == 1
428
+ and prev_turn_domain[0] == turn_domain[1]
429
+ ):
430
+ turn_domain = turn_domain[::-1]
431
+ prev_turn_domain = copy.deepcopy(turn_domain)
432
+ prev_constraint_dict = copy.deepcopy(constraint_dict)
433
+
434
+ turn["dspn_gen"] = " ".join(["[" + d + "]" for d in turn_domain])
435
+ pred_domains = {}
436
+ for d in turn_domain:
437
+ pred_domains["[" + d + "]"] = 1
438
+
439
+ if len(true_domains) == 1:
440
+ total_single += 1
441
+ if pred_domains == true_domains:
442
+ corr_single += 1
443
+ else:
444
+ wrong_pred.append(str(turn["turn_num"]))
445
+ turn["wrong_domain"] = "x"
446
+ else:
447
+ total_multi += 1
448
+ if pred_domains == true_domains:
449
+ corr_multi += 1
450
+ else:
451
+ wrong_pred.append(str(turn["turn_num"]))
452
+ turn["wrong_domain"] = "x"
453
+
454
+ # dialog inform metric record
455
+ dial[0]["wrong_domain"] = " ".join(wrong_pred)
456
+ accu_single = corr_single / (total_single + 1e-10)
457
+ accu_multi = corr_multi / (total_multi + 1e-10)
458
+ return accu_single * 100, accu_multi * 100, total_multi
459
+
460
+ def dialog_state_tracking_eval(
461
+ self, data, eval_dial_list=None, bspn_mode="bspn", no_name=False, no_book=False
462
+ ):
463
+ dials = self.pack_dial(data)
464
+ total_turn, joint_match, total_tp, total_fp, total_fn, total_acc = (
465
+ 0,
466
+ 0,
467
+ 0,
468
+ 0,
469
+ 0,
470
+ 0,
471
+ )
472
+ slot_appear_num, slot_correct_num = {}, {}
473
+ dial_num = 0
474
+ for dial_id in dials:
475
+ if eval_dial_list and dial_id + ".json" not in eval_dial_list:
476
+ continue
477
+ dial_num += 1
478
+ dial = dials[dial_id]
479
+ missed_jg_turn_id = []
480
+ for turn_num, turn in enumerate(dial):
481
+ if turn_num == 0:
482
+ continue
483
+ gen_cons = self._bspn_to_dict(
484
+ turn[bspn_mode + "_gen"],
485
+ no_name=no_name,
486
+ no_book=no_book,
487
+ bspn_mode=bspn_mode,
488
+ )
489
+ truth_cons = self._bspn_to_dict(
490
+ turn[bspn_mode],
491
+ no_name=no_name,
492
+ no_book=no_book,
493
+ bspn_mode=bspn_mode,
494
+ )
495
+
496
+ if truth_cons == gen_cons:
497
+ joint_match += 1
498
+ else:
499
+ missed_jg_turn_id.append(str(turn["turn_num"]))
500
+
501
+ if eval_dial_list is None:
502
+ tp, fp, fn, acc, false_slots = self._constraint_compare(
503
+ truth_cons, gen_cons, slot_appear_num, slot_correct_num
504
+ )
505
+ else:
506
+ tp, fp, fn, acc, false_slots = self._constraint_compare(
507
+ truth_cons,
508
+ gen_cons,
509
+ )
510
+
511
+ total_tp += tp
512
+ total_fp += fp
513
+ total_fn += fn
514
+ total_acc += acc
515
+ total_turn += 1
516
+ if not no_name and not no_book:
517
+ turn["wrong_inform"] = "; ".join(
518
+ false_slots
519
+ ) # turn inform metric record
520
+
521
+ # dialog inform metric record
522
+ if not no_name and not no_book:
523
+ dial[0]["wrong_inform"] = " ".join(missed_jg_turn_id)
524
+
525
+ precision = total_tp / (total_tp + total_fp + 1e-10)
526
+ recall = total_tp / (total_tp + total_fn + 1e-10)
527
+ f1 = 2 * precision * recall / (precision + recall + 1e-10) * 100
528
+ accuracy = total_acc / (total_turn * len(self.all_info_slot) + 1e-10) * 100
529
+ joint_goal = joint_match / (total_turn + 1e-10) * 100
530
+
531
+ return joint_goal, f1, accuracy, slot_appear_num, slot_correct_num
532
+
533
+ def aspn_eval(self, data, eval_dial_list=None):
534
+ def _get_tp_fp_fn(label_list, pred_list):
535
+ tp = len([t for t in pred_list if t in label_list])
536
+ fp = max(0, len(pred_list) - tp)
537
+ fn = max(0, len(label_list) - tp)
538
+ return tp, fp, fn
539
+
540
+ dials = self.pack_dial(data)
541
+ total_tp, total_fp, total_fn = 0, 0, 0
542
+
543
+ dial_num = 0
544
+ for dial_id in dials:
545
+ if eval_dial_list and dial_id + ".json" not in eval_dial_list:
546
+ continue
547
+ dial_num += 1
548
+ dial = dials[dial_id]
549
+ wrong_act = []
550
+ for turn_num, turn in enumerate(dial):
551
+ if turn_num == 0:
552
+ continue
553
+ if cfg.same_eval_act_f1_as_hdsa:
554
+ pred_acts, true_acts = {}, {}
555
+ for t in turn["aspn_gen"]:
556
+ pred_acts[t] = 1
557
+ for t in turn["aspn"]:
558
+ true_acts[t] = 1
559
+ tp, fp, fn = _get_tp_fp_fn(true_acts, pred_acts)
560
+ else:
561
+ pred_acts = self.reader.aspan_to_act_list(turn["aspn_gen"])
562
+ true_acts = self.reader.aspan_to_act_list(turn["aspn"])
563
+ tp, fp, fn = _get_tp_fp_fn(true_acts, pred_acts)
564
+ if fp + fn != 0:
565
+ wrong_act.append(str(turn["turn_num"]))
566
+ turn["wrong_act"] = "x"
567
+
568
+ total_tp += tp
569
+ total_fp += fp
570
+ total_fn += fn
571
+
572
+ dial[0]["wrong_act"] = " ".join(wrong_act)
573
+ precision = total_tp / (total_tp + total_fp + 1e-10)
574
+ recall = total_tp / (total_tp + total_fn + 1e-10)
575
+ f1 = 2 * precision * recall / (precision + recall + 1e-10)
576
+
577
+ return f1 * 100
578
+
579
+ def multi_act_eval(self, data, eval_dial_list=None):
580
+
581
+ dials = self.pack_dial(data)
582
+ total_act_num, total_slot_num = 0, 0
583
+
584
+ dial_num = 0
585
+ turn_count = 0
586
+ for dial_id in dials:
587
+ if eval_dial_list and dial_id + ".json" not in eval_dial_list:
588
+ continue
589
+ dial_num += 1
590
+ dial = dials[dial_id]
591
+ for turn_num, turn in enumerate(dial):
592
+ if turn_num == 0:
593
+ continue
594
+ target = (
595
+ turn["multi_act_gen"]
596
+ if self.reader.multi_acts_record is not None
597
+ else turn["aspn_gen"]
598
+ )
599
+
600
+ # diversity
601
+ act_collect, slot_collect = {}, {}
602
+ act_type_collect = {}
603
+ slot_score = 0
604
+ for act_str in target.split(" | "):
605
+ pred_acts = self.reader.aspan_to_act_list(act_str)
606
+ act_type = ""
607
+ for act in pred_acts:
608
+ d, a, s = act.split("-")
609
+ if d + "-" + a not in act_collect:
610
+ act_collect[d + "-" + a] = {s: 1}
611
+ slot_score += 1
612
+ act_type += d + "-" + a + ";"
613
+ elif s not in act_collect:
614
+ act_collect[d + "-" + a][s] = 1
615
+ slot_score += 1
616
+ slot_collect[s] = 1
617
+ act_type_collect[act_type] = 1
618
+ total_act_num += len(act_collect)
619
+ total_slot_num += len(slot_collect)
620
+ turn_count += 1
621
+
622
+ total_act_num = total_act_num / (float(turn_count) + 1e-10)
623
+ total_slot_num = total_slot_num / (float(turn_count) + 1e-10)
624
+ return total_act_num, total_slot_num
625
+
626
+ def context_to_response_eval(
627
+ self, data, eval_dial_list=None, same_eval_as_cambridge=False
628
+ ):
629
+ dials = self.pack_dial(data)
630
+ counts = {}
631
+ for req in self.requestables:
632
+ counts[req + "_total"] = 0
633
+ counts[req + "_offer"] = 0
634
+
635
+ dial_num, successes, matches = 0, 0, 0
636
+
637
+ for dial_id in dials:
638
+ if eval_dial_list and dial_id + ".json" not in eval_dial_list:
639
+ continue
640
+ dial = dials[dial_id]
641
+ reqs = {}
642
+ goal = {}
643
+ if ".json" not in dial_id and ".json" in list(self.all_data.keys())[0]:
644
+ dial_id = dial_id + ".json"
645
+ for domain in ontology.all_domains:
646
+ if self.all_data[dial_id]["goal"].get(domain):
647
+ true_goal = self.all_data[dial_id]["goal"]
648
+ goal = self._parseGoal(goal, true_goal, domain)
649
+ # print(goal)
650
+ for domain in goal.keys():
651
+ reqs[domain] = goal[domain]["requestable"]
652
+
653
+ # print('\n',dial_id)
654
+ success, match, stats, counts = self._evaluateGeneratedDialogue(
655
+ dial, goal, reqs, counts, same_eval_as_cambridge=same_eval_as_cambridge
656
+ )
657
+
658
+ successes += success
659
+ matches += match
660
+ dial_num += 1
661
+
662
+ # for domain in gen_stats.keys():
663
+ # gen_stats[domain][0] += stats[domain][0]
664
+ # gen_stats[domain][1] += stats[domain][1]
665
+ # gen_stats[domain][2] += stats[domain][2]
666
+
667
+ # if 'SNG' in filename:
668
+ # for domain in gen_stats.keys():
669
+ # sng_gen_stats[domain][0] += stats[domain][0]
670
+ # sng_gen_stats[domain][1] += stats[domain][1]
671
+ # sng_gen_stats[domain][2] += stats[domain][2]
672
+
673
+ # self.logger.info(report)
674
+ succ_rate = successes / (float(dial_num) + 1e-10) * 100
675
+ match_rate = matches / (float(dial_num) + 1e-10) * 100
676
+ return succ_rate, match_rate, counts, dial_num
677
+
678
+ def _evaluateGeneratedDialogue(
679
+ self,
680
+ dialog,
681
+ goal,
682
+ real_requestables,
683
+ counts,
684
+ soft_acc=False,
685
+ same_eval_as_cambridge=False,
686
+ ):
687
+ """Evaluates the dialogue created by the model.
688
+ First we load the user goal of the dialogue, then for each turn
689
+ generated by the system we look for key-words.
690
+ For the Inform rate we look whether the entity was proposed.
691
+ For the Success rate we look for requestables slots"""
692
+ # for computing corpus success 'id'
693
+ requestables = self.requestables
694
+
695
+ # CHECK IF MATCH HAPPENED
696
+ provided_requestables = {}
697
+ venue_offered = {}
698
+ domains_in_goal = []
699
+ bspans = {}
700
+
701
+ for domain in goal.keys():
702
+ venue_offered[domain] = []
703
+ provided_requestables[domain] = []
704
+ domains_in_goal.append(domain)
705
+
706
+ for t, turn in enumerate(dialog):
707
+ if t == 0:
708
+ continue
709
+ sent_t = turn["resp_gen"]
710
+ # sent_t = turn['resp']
711
+ for domain in goal.keys():
712
+ # for computing success
713
+ if same_eval_as_cambridge:
714
+ # [restaurant_name], [hotel_name] instead of [value_name]
715
+ if cfg.use_true_domain_for_ctr_eval:
716
+ dom_pred = [d[1:-1] for d in turn["dspn"].split()]
717
+ else:
718
+ dom_pred = [d[1:-1] for d in turn["dspn_gen"].split()]
719
+ # else:
720
+ # raise NotImplementedError('Just use true domain label')
721
+ if domain not in dom_pred: # fail
722
+ continue
723
+ if "[value_name]" in sent_t or "[value_id]" in sent_t:
724
+ if domain in ["restaurant", "hotel", "attraction", "train"]:
725
+ # HERE YOU CAN PUT YOUR BELIEF STATE ESTIMATION
726
+ if (
727
+ not cfg.use_true_curr_bspn
728
+ and not cfg.use_true_bspn_for_ctr_eval
729
+ ):
730
+ bspn = turn["bspn_gen"]
731
+ else:
732
+ bspn = turn["bspn"]
733
+ # bspn = turn['bspn']
734
+
735
+ constraint_dict = self.reader.bspan_to_constraint_dict(bspn)
736
+ if constraint_dict.get(domain):
737
+ venues = self.reader.db.queryJsons(
738
+ domain, constraint_dict[domain], return_name=True
739
+ )
740
+ else:
741
+ venues = []
742
+
743
+ # if venue has changed
744
+ if len(venue_offered[domain]) == 0 and venues:
745
+ # venue_offered[domain] = random.sample(venues, 1)
746
+ venue_offered[domain] = venues
747
+ bspans[domain] = constraint_dict[domain]
748
+ else:
749
+ # flag = False
750
+ # for ven in venues:
751
+ # if venue_offered[domain][0] == ven:
752
+ # flag = True
753
+ # break
754
+ # if not flag and venues:
755
+ flag = False
756
+ for ven in venues:
757
+ if ven not in venue_offered[domain]:
758
+ # if ven not in venue_offered[domain]:
759
+ flag = True
760
+ break
761
+ # if flag and venues:
762
+ if (
763
+ flag and venues
764
+ ): # sometimes there are no results so sample won't work
765
+ # print venues
766
+ # venue_offered[domain] = random.sample(venues, 1)
767
+ venue_offered[domain] = venues
768
+ bspans[domain] = constraint_dict[domain]
769
+ else: # not limited so we can provide one
770
+ venue_offered[domain] = "[value_name]"
771
+
772
+ # ATTENTION: assumption here - we didn't provide phone or address twice! etc
773
+ for requestable in requestables:
774
+ if requestable == "reference":
775
+ if "[value_reference]" in sent_t:
776
+ if (
777
+ "booked" in turn["pointer"] or "ok" in turn["pointer"]
778
+ ): # if pointer was allowing for that?
779
+ provided_requestables[domain].append("reference")
780
+ # provided_requestables[domain].append('reference')
781
+ else:
782
+ if "[value_" + requestable + "]" in sent_t:
783
+ provided_requestables[domain].append(requestable)
784
+
785
+ # if name was given in the task
786
+ for domain in goal.keys():
787
+ # if name was provided for the user, the match is being done automatically
788
+ if "name" in goal[domain]["informable"]:
789
+ venue_offered[domain] = "[value_name]"
790
+
791
+ # special domains - entity does not need to be provided
792
+ if domain in ["taxi", "police", "hospital"]:
793
+ venue_offered[domain] = "[value_name]"
794
+
795
+ if domain == "train":
796
+ if (
797
+ not venue_offered[domain]
798
+ and "id" not in goal[domain]["requestable"]
799
+ ):
800
+ venue_offered[domain] = "[value_name]"
801
+
802
+ """
803
+ Given all inform and requestable slots
804
+ we go through each domain from the user goal
805
+ and check whether right entity was provided and
806
+ all requestable slots were given to the user.
807
+ The dialogue is successful if that's the case for all domains.
808
+ """
809
+ # HARD EVAL
810
+ stats = {
811
+ "restaurant": [0, 0, 0],
812
+ "hotel": [0, 0, 0],
813
+ "attraction": [0, 0, 0],
814
+ "train": [0, 0, 0],
815
+ "taxi": [0, 0, 0],
816
+ "hospital": [0, 0, 0],
817
+ "police": [0, 0, 0],
818
+ }
819
+
820
+ match = 0
821
+ success = 0
822
+ # MATCH
823
+ for domain in goal.keys():
824
+ match_stat = 0
825
+ if domain in ["restaurant", "hotel", "attraction", "train"]:
826
+ goal_venues = self.reader.db.queryJsons(
827
+ domain, goal[domain]["informable"], return_name=True
828
+ )
829
+ if (
830
+ type(venue_offered[domain]) is str
831
+ and "_name" in venue_offered[domain]
832
+ ):
833
+ match += 1
834
+ match_stat = 1
835
+ elif (
836
+ len(venue_offered[domain]) > 0
837
+ and len(set(venue_offered[domain]) & set(goal_venues)) > 0
838
+ ):
839
+ match += 1
840
+ match_stat = 1
841
+ else:
842
+ if "_name]" in venue_offered[domain]:
843
+ match += 1
844
+ match_stat = 1
845
+
846
+ stats[domain][0] = match_stat
847
+ stats[domain][2] = 1
848
+
849
+ if soft_acc:
850
+ match = float(match) / len(goal.keys())
851
+ else:
852
+ if match == len(goal.keys()):
853
+ match = 1.0
854
+ else:
855
+ match = 0.0
856
+
857
+ for domain in domains_in_goal:
858
+ for request in real_requestables[domain]:
859
+ counts[request + "_total"] += 1
860
+ if request in provided_requestables[domain]:
861
+ counts[request + "_offer"] += 1
862
+
863
+ # SUCCESS
864
+ if match == 1.0:
865
+ for domain in domains_in_goal:
866
+ success_stat = 0
867
+ domain_success = 0
868
+ if len(real_requestables[domain]) == 0:
869
+ success += 1
870
+ success_stat = 1
871
+ stats[domain][1] = success_stat
872
+ continue
873
+ # if values in sentences are super set of requestables
874
+ # for request in set(provided_requestables[domain]):
875
+ # if request in real_requestables[domain]:
876
+ # domain_success += 1
877
+ for request in real_requestables[domain]:
878
+ if request in provided_requestables[domain]:
879
+ domain_success += 1
880
+
881
+ # if domain_success >= len(real_requestables[domain]):
882
+ if domain_success == len(real_requestables[domain]):
883
+ success += 1
884
+ success_stat = 1
885
+
886
+ stats[domain][1] = success_stat
887
+
888
+ # final eval
889
+ if soft_acc:
890
+ success = float(success) / len(real_requestables)
891
+ else:
892
+ if success >= len(real_requestables):
893
+ success = 1
894
+ else:
895
+ success = 0
896
+
897
+ return success, match, stats, counts
898
+
899
+ def _parseGoal(self, goal, true_goal, domain):
900
+ """Parses user goal into dictionary format."""
901
+ goal[domain] = {}
902
+ goal[domain] = {"informable": {}, "requestable": [], "booking": []}
903
+ if "info" in true_goal[domain]:
904
+ if domain == "train":
905
+ # we consider dialogues only where train had to be booked!
906
+ if "book" in true_goal[domain]:
907
+ goal[domain]["requestable"].append("reference")
908
+ if "reqt" in true_goal[domain]:
909
+ if "id" in true_goal[domain]["reqt"]:
910
+ goal[domain]["requestable"].append("id")
911
+ else:
912
+ if "reqt" in true_goal[domain]:
913
+ for s in true_goal[domain]["reqt"]: # addtional requests:
914
+ if s in ["phone", "address", "postcode", "reference", "id"]:
915
+ # ones that can be easily delexicalized
916
+ goal[domain]["requestable"].append(s)
917
+ if "book" in true_goal[domain]:
918
+ goal[domain]["requestable"].append("reference")
919
+
920
+ for s, v in true_goal[domain]["info"].items():
921
+ s_, v_ = clean_slot_values(domain, s, v)
922
+ if len(v_.split()) > 1:
923
+ v_ = " ".join([token.text for token in self.reader.nlp(v_)]).strip()
924
+ goal[domain]["informable"][s_] = v_
925
+
926
+ if "book" in true_goal[domain]:
927
+ goal[domain]["booking"] = true_goal[domain]["book"]
928
+ return goal
929
+
930
+
931
+ if __name__ == "__main__":
932
+ pass
src/crazyneuraluser/UBAR_code/ontology.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ all_domains = [
2
+ "restaurant",
3
+ "hotel",
4
+ "attraction",
5
+ "train",
6
+ "taxi",
7
+ "police",
8
+ "hospital",
9
+ ]
10
+ db_domains = ["restaurant", "hotel", "attraction", "train"]
11
+
12
+ # original slot names in goals (including booking slots)
13
+ # requestable_slots_in_goals = {
14
+ # "taxi": ["car type", "phone"],
15
+ # "police": ["postcode", "address", "phone"],
16
+ # "hospital": ["address", "phone", "postcode"],
17
+ # "hotel": ["address", "postcode", "internet", "phone", "parking",
18
+ # "type", "pricerange", "stars", "area", "reference"],
19
+ # "attraction": ["entrance fee", "type", "address", "postcode", "phone", "area", "reference"],
20
+ # "train": ["duration", "leaveat", "price", "arriveby", "id", "reference"],
21
+ # "restaurant": ["phone", "postcode", "address", "pricerange", "food", "area", "reference"]
22
+ # }
23
+
24
+ # informable_slots_in_goals = {
25
+ # "taxi": ["leaveat", "destination", "departure", "arriveby"],
26
+ # "police": [],
27
+ # "hospital": ["department"],
28
+ # "hotel": ["type", "parking", "pricerange", "internet", "stay", "day", "people", "area", "stars", "name"],
29
+ # "attraction": ["area", "type", "name"],
30
+ # "train": ["destination", "day", "arriveby", "departure", "people", "leaveat"],
31
+ # "restaurant": ["food", "pricerange", "area", "name", "time", "day", "people"]
32
+ # }
33
+
34
+ normlize_slot_names = {
35
+ "car type": "car",
36
+ "entrance fee": "price",
37
+ "duration": "time",
38
+ "leaveat": "leave",
39
+ "arriveby": "arrive",
40
+ "trainid": "id",
41
+ }
42
+
43
+ requestable_slots = {
44
+ "taxi": ["car", "phone"],
45
+ "police": ["postcode", "address", "phone"],
46
+ "hospital": ["address", "phone", "postcode"],
47
+ "hotel": [
48
+ "address",
49
+ "postcode",
50
+ "internet",
51
+ "phone",
52
+ "parking",
53
+ "type",
54
+ "pricerange",
55
+ "stars",
56
+ "area",
57
+ "reference",
58
+ ],
59
+ "attraction": [
60
+ "price",
61
+ "type",
62
+ "address",
63
+ "postcode",
64
+ "phone",
65
+ "area",
66
+ "reference",
67
+ ],
68
+ "train": ["time", "leave", "price", "arrive", "id", "reference"],
69
+ "restaurant": [
70
+ "phone",
71
+ "postcode",
72
+ "address",
73
+ "pricerange",
74
+ "food",
75
+ "area",
76
+ "reference",
77
+ ],
78
+ }
79
+ all_reqslot = [
80
+ "car",
81
+ "address",
82
+ "postcode",
83
+ "phone",
84
+ "internet",
85
+ "parking",
86
+ "type",
87
+ "pricerange",
88
+ "food",
89
+ "stars",
90
+ "area",
91
+ "reference",
92
+ "time",
93
+ "leave",
94
+ "price",
95
+ "arrive",
96
+ "id",
97
+ ]
98
+ # count: 17
99
+
100
+ informable_slots = {
101
+ "taxi": ["leave", "destination", "departure", "arrive"],
102
+ "police": [],
103
+ "hospital": ["department"],
104
+ "hotel": [
105
+ "type",
106
+ "parking",
107
+ "pricerange",
108
+ "internet",
109
+ "stay",
110
+ "day",
111
+ "people",
112
+ "area",
113
+ "stars",
114
+ "name",
115
+ ],
116
+ "attraction": ["area", "type", "name"],
117
+ "train": ["destination", "day", "arrive", "departure", "people", "leave"],
118
+ "restaurant": ["food", "pricerange", "area", "name", "time", "day", "people"],
119
+ }
120
+ all_infslot = [
121
+ "type",
122
+ "parking",
123
+ "pricerange",
124
+ "internet",
125
+ "stay",
126
+ "day",
127
+ "people",
128
+ "area",
129
+ "stars",
130
+ "name",
131
+ "leave",
132
+ "destination",
133
+ "departure",
134
+ "arrive",
135
+ "department",
136
+ "food",
137
+ "time",
138
+ ]
139
+ # count: 17
140
+
141
+ all_slots = all_reqslot + [
142
+ "stay",
143
+ "day",
144
+ "people",
145
+ "name",
146
+ "destination",
147
+ "departure",
148
+ "department",
149
+ ]
150
+ get_slot = {}
151
+ for s in all_slots:
152
+ get_slot[s] = 1
153
+ # count: 24
154
+
155
+
156
+ # mapping slots in dialogue act to original goal slot names
157
+ da_abbr_to_slot_name = {
158
+ "addr": "address",
159
+ "fee": "price",
160
+ "post": "postcode",
161
+ "ref": "reference",
162
+ "ticket": "price",
163
+ "depart": "departure",
164
+ "dest": "destination",
165
+ }
166
+
167
+ # slot merging: not used currently
168
+ # slot_name_to_value_token = {
169
+ # 'entrance fee': 'price',
170
+ # 'pricerange': 'price',
171
+ # 'arrive': 'time',
172
+ # 'leave': 'time',
173
+ # 'departure': 'name',
174
+ # 'destination': 'name',
175
+ # 'stay': 'count',
176
+ # 'people': 'count',
177
+ # 'stars': 'count',
178
+ # }
179
+ # dialog_act_dom = ['restaurant', 'hotel', 'attraction', 'train', 'taxi', 'police', 'hospital', 'general', 'booking']
180
+ dialog_acts = {
181
+ "restaurant": [
182
+ "inform",
183
+ "request",
184
+ "nooffer",
185
+ "recommend",
186
+ "select",
187
+ "offerbook",
188
+ "offerbooked",
189
+ "nobook",
190
+ ],
191
+ "hotel": [
192
+ "inform",
193
+ "request",
194
+ "nooffer",
195
+ "recommend",
196
+ "select",
197
+ "offerbook",
198
+ "offerbooked",
199
+ "nobook",
200
+ ],
201
+ "attraction": ["inform", "request", "nooffer", "recommend", "select"],
202
+ "train": ["inform", "request", "nooffer", "offerbook", "offerbooked", "select"],
203
+ "taxi": ["inform", "request"],
204
+ "police": ["inform", "request"],
205
+ "hospital": ["inform", "request"],
206
+ # 'booking': ['book', 'inform', 'nobook', 'request'],
207
+ "general": ["bye", "greet", "reqmore", "welcome"],
208
+ }
209
+ all_acts = []
210
+ for acts in dialog_acts.values():
211
+ for act in acts:
212
+ if act not in all_acts:
213
+ all_acts.append(act)
214
+ # print(all_acts)
215
+
216
+ dialog_act_params = {
217
+ "inform": all_slots + ["choice", "open"],
218
+ "request": all_infslot + ["choice", "price"],
219
+ "nooffer": all_slots + ["choice"],
220
+ "recommend": all_reqslot + ["choice", "open"],
221
+ "select": all_slots + ["choice"],
222
+ # 'book': ['time', 'people', 'stay', 'reference', 'day', 'name', 'choice'],
223
+ "nobook": ["time", "people", "stay", "reference", "day", "name", "choice"],
224
+ "offerbook": all_slots + ["choice"],
225
+ "offerbooked": all_slots + ["choice"],
226
+ "reqmore": [],
227
+ "welcome": [],
228
+ "bye": [],
229
+ "greet": [],
230
+ }
231
+
232
+ # dialog_acts = ['inform', 'request', 'nooffer', 'recommend', 'select', 'book', 'nobook', 'offerbook', 'offerbooked',
233
+ # 'reqmore', 'welcome', 'bye', 'greet'] # thank
234
+ dialog_act_all_slots = all_slots + ["choice", "open"]
235
+ # act_span_vocab = ['['+i+']' for i in dialog_act_dom] + ['['+i+']' for i in dialog_acts] + all_slots
236
+
237
+ # value_token_in_resp = ['address', 'name', 'phone', 'postcode', 'area', 'food', 'pricerange', 'id',
238
+ # 'department', 'place', 'day', 'count', 'car']
239
+ # count: 12
240
+
241
+
242
+ # special slot tokens in belief span
243
+ # no need of this, just covert slot to [slot] e.g. pricerange -> [pricerange]
244
+ slot_name_to_slot_token = {}
245
+
246
+
247
+ # special slot tokens in responses
248
+ # not use at the momoent
249
+ slot_name_to_value_token = {
250
+ # 'entrance fee': '[value_price]',
251
+ # 'pricerange': '[value_price]',
252
+ # 'arriveby': '[value_time]',
253
+ # 'leaveat': '[value_time]',
254
+ # 'departure': '[value_place]',
255
+ # 'destination': '[value_place]',
256
+ # 'stay': 'count',
257
+ # 'people': 'count'
258
+ }
259
+
260
+
261
+ db_tokens = [
262
+ "<sos_db>",
263
+ "<eos_db>",
264
+ "[db_nores]",
265
+ "[db_0]",
266
+ "[db_1]",
267
+ "[db_2]",
268
+ "[db_3]",
269
+ ]
270
+
271
+ special_tokens = [
272
+ "<pad>",
273
+ "<go_r>",
274
+ "<unk>",
275
+ "<go_b>",
276
+ "<go_a>",
277
+ "<eos_u>",
278
+ "<eos_r>",
279
+ "<eos_b>",
280
+ "<eos_a>",
281
+ "<go_d>",
282
+ "<eos_d>",
283
+ "<sos_u>",
284
+ "<sos_r>",
285
+ "<sos_b>",
286
+ "<sos_a>",
287
+ "<sos_d>",
288
+ ] + db_tokens
289
+
290
+ eos_tokens = {
291
+ "user": "<eos_u>",
292
+ "user_delex": "<eos_u>",
293
+ "resp": "<eos_r>",
294
+ "resp_gen": "<eos_r>",
295
+ "pv_resp": "<eos_r>",
296
+ "bspn": "<eos_b>",
297
+ "bspn_gen": "<eos_b>",
298
+ "pv_bspn": "<eos_b>",
299
+ "bsdx": "<eos_b>",
300
+ "bsdx_gen": "<eos_b>",
301
+ "pv_bsdx": "<eos_b>",
302
+ "aspn": "<eos_a>",
303
+ "aspn_gen": "<eos_a>",
304
+ "pv_aspn": "<eos_a>",
305
+ "dspn": "<eos_d>",
306
+ "dspn_gen": "<eos_d>",
307
+ "pv_dspn": "<eos_d>",
308
+ }
309
+
310
+ sos_tokens = {
311
+ "user": "<sos_u>",
312
+ "user_delex": "<sos_u>",
313
+ "resp": "<sos_r>",
314
+ "resp_gen": "<sos_r>",
315
+ "pv_resp": "<sos_r>",
316
+ "bspn": "<sos_b>",
317
+ "bspn_gen": "<sos_b>",
318
+ "pv_bspn": "<sos_b>",
319
+ "bsdx": "<sos_b>",
320
+ "bsdx_gen": "<sos_b>",
321
+ "pv_bsdx": "<sos_b>",
322
+ "aspn": "<sos_a>",
323
+ "aspn_gen": "<sos_a>",
324
+ "pv_aspn": "<sos_a>",
325
+ "dspn": "<sos_d>",
326
+ "dspn_gen": "<sos_d>",
327
+ "pv_dspn": "<sos_d>",
328
+ }
src/crazyneuraluser/UBAR_code/reader.py ADDED
@@ -0,0 +1,1262 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import json
3
+ import logging
4
+ import os
5
+ import random
6
+ from collections import OrderedDict
7
+ from copy import deepcopy
8
+
9
+ import numpy as np
10
+ import spacy
11
+ from transformers import GPT2Tokenizer
12
+
13
+ from crazyneuraluser.UBAR_code import ontology, utils
14
+ from crazyneuraluser.UBAR_code.config import global_config as cfg
15
+ from crazyneuraluser.UBAR_code.db_ops import MultiWozDB
16
+
17
+ # from config21 import global_config as cfg
18
+
19
+
20
+ class _ReaderBase(object):
21
+ def __init__(self):
22
+ self.train, self.dev, self.test = [], [], []
23
+ self.vocab = None
24
+ self.db = None
25
+ self.set_stats = {}
26
+
27
+ def _bucket_by_turn(self, encoded_data):
28
+ turn_bucket = {}
29
+ for dial in encoded_data:
30
+ turn_len = len(dial)
31
+ if turn_len not in turn_bucket:
32
+ turn_bucket[turn_len] = []
33
+ turn_bucket[turn_len].append(dial)
34
+ del_l = []
35
+ for k in turn_bucket:
36
+ if k >= 5:
37
+ del_l.append(k)
38
+ logging.debug("bucket %d instance %d" % (k, len(turn_bucket[k])))
39
+ # for k in del_l:
40
+ # turn_bucket.pop(k)
41
+ return OrderedDict(sorted(turn_bucket.items(), key=lambda i: i[0]))
42
+
43
+ def _construct_mini_batch(self, data):
44
+ all_batches = []
45
+ batch = []
46
+ for dial in data:
47
+ batch.append(dial)
48
+ if len(batch) == cfg.batch_size:
49
+ # print('batch size: %d, batch num +1'%(len(batch)))
50
+ all_batches.append(batch)
51
+ batch = []
52
+ # if remainder > 1/2 batch_size, just put them in the previous batch, otherwise form a new batch
53
+ # print('last batch size: %d, batch num +1'%(len(batch)))
54
+ if (len(batch) % len(cfg.cuda_device)) != 0:
55
+ batch = batch[: -(len(batch) % len(cfg.cuda_device))]
56
+ if len(batch) > 0.5 * cfg.batch_size:
57
+ all_batches.append(batch)
58
+ elif len(all_batches):
59
+ all_batches[-1].extend(batch)
60
+ else:
61
+ all_batches.append(batch)
62
+ return all_batches
63
+
64
+ def transpose_batch(self, batch):
65
+ dial_batch = []
66
+ turn_num = len(batch[0])
67
+ for turn in range(turn_num):
68
+ turn_l = {}
69
+ for dial in batch:
70
+ this_turn = dial[turn]
71
+ for k in this_turn:
72
+ if k not in turn_l:
73
+ turn_l[k] = []
74
+ turn_l[k].append(this_turn[k])
75
+ dial_batch.append(turn_l)
76
+ return dial_batch
77
+
78
+ def inverse_transpose_turn(self, turn_list):
79
+ """
80
+ eval, one dialog at a time
81
+ """
82
+ dialogs = {}
83
+ turn_num = len(turn_list)
84
+ dial_id = turn_list[0]["dial_id"]
85
+ dialogs[dial_id] = []
86
+ for turn_idx in range(turn_num):
87
+ dial_turn = {}
88
+ turn = turn_list[turn_idx]
89
+ for key, value in turn.items():
90
+ if key == "dial_id":
91
+ continue
92
+ if key == "pointer" and self.db is not None:
93
+ turn_domain = turn["turn_domain"][-1]
94
+ value = self.db.pointerBack(value, turn_domain)
95
+ dial_turn[key] = value
96
+ dialogs[dial_id].append(dial_turn)
97
+ return dialogs
98
+
99
+ def inverse_transpose_batch(self, turn_batch_list):
100
+ """
101
+ :param turn_batch_list: list of transpose dial batch
102
+ """
103
+ dialogs = {}
104
+ total_turn_num = len(turn_batch_list)
105
+ # initialize
106
+ for idx_in_batch, dial_id in enumerate(turn_batch_list[0]["dial_id"]):
107
+ dialogs[dial_id] = []
108
+ for turn_n in range(total_turn_num):
109
+ dial_turn = {}
110
+ turn_batch = turn_batch_list[turn_n]
111
+ for key, v_list in turn_batch.items():
112
+ if key == "dial_id":
113
+ continue
114
+ value = v_list[idx_in_batch]
115
+ if key == "pointer" and self.db is not None:
116
+ turn_domain = turn_batch["turn_domain"][idx_in_batch][-1]
117
+ value = self.db.pointerBack(value, turn_domain)
118
+ dial_turn[key] = value
119
+ dialogs[dial_id].append(dial_turn)
120
+ return dialogs
121
+
122
+ def get_eval_data(self, set_name="dev"):
123
+ name_to_set = {"train": self.train, "test": self.test, "dev": self.dev}
124
+ dial = name_to_set[set_name]
125
+
126
+ if set_name not in self.set_stats:
127
+ self.set_stats[set_name] = {}
128
+ num_turns = 0
129
+ num_dials = len(dial)
130
+ for d in dial:
131
+ num_turns += len(d)
132
+
133
+ self.set_stats[set_name]["num_turns"] = num_turns
134
+ self.set_stats[set_name]["num_dials"] = num_dials
135
+
136
+ return dial
137
+
138
+ def get_batches(self, set_name):
139
+ """
140
+ compute dataset stats.
141
+ """
142
+ global dia_count
143
+ log_str = ""
144
+ name_to_set = {"train": self.train, "test": self.test, "dev": self.dev}
145
+ dial = name_to_set[set_name]
146
+ if cfg.low_resource and set_name == "train":
147
+ # dial = random.sample(dial, int(len(dial)*0.01))
148
+ dial = random.sample(dial, 100)
149
+ logging.info("Low Resource setting, finetuning size: {}".format(len(dial)))
150
+ turn_bucket = self._bucket_by_turn(dial)
151
+ # self._shuffle_turn_bucket(turn_bucket)
152
+ all_batches = []
153
+
154
+ if set_name not in self.set_stats:
155
+ self.set_stats[set_name] = {}
156
+ num_training_steps = 0
157
+ num_turns = 0
158
+ num_dials = 0
159
+
160
+ for k in turn_bucket:
161
+ if set_name != "test" and k == 1 or k >= 17:
162
+ continue
163
+ batches = self._construct_mini_batch(turn_bucket[k])
164
+ log_str += "turn num:%d, dial num: %d, batch num: %d last batch len: %d\n" % (
165
+ k,
166
+ len(turn_bucket[k]),
167
+ len(batches),
168
+ len(batches[-1]),
169
+ )
170
+ # print("turn num:%d, dial num:v%d, batch num: %d, "%(k, len(turn_bucket[k]), len(batches)))
171
+ num_training_steps += k * len(batches)
172
+ num_turns += k * len(turn_bucket[k])
173
+ num_dials += len(turn_bucket[k])
174
+ all_batches += batches
175
+ log_str += "total batch num: %d\n" % len(all_batches)
176
+ # print('total batch num: %d'%len(all_batches))
177
+ # print('dialog count: %d'%dia_count)
178
+ # return all_batches
179
+
180
+ # log stats
181
+ # logging.info(log_str)
182
+ # cfg.num_training_steps = num_training_steps * cfg.epoch_num
183
+ self.set_stats[set_name]["num_training_steps_per_epoch"] = num_training_steps
184
+ self.set_stats[set_name]["num_turns"] = num_turns
185
+ self.set_stats[set_name]["num_dials"] = num_dials
186
+
187
+ if set_name == "train":
188
+ random.shuffle(all_batches)
189
+ return all_batches
190
+
191
+ def get_nontranspose_data_iterator(self, all_batches):
192
+ for i, batch in enumerate(all_batches):
193
+ yield batch
194
+
195
+ def get_data_iterator(self, all_batches):
196
+ for i, batch in enumerate(all_batches):
197
+ yield self.transpose_batch(batch)
198
+
199
+ def save_result(self, write_mode, results, field, write_title=False):
200
+ with open(cfg.result_path, write_mode) as rf:
201
+ if write_title:
202
+ rf.write(write_title + "\n")
203
+ writer = csv.DictWriter(rf, fieldnames=field)
204
+ writer.writeheader()
205
+ writer.writerows(results)
206
+ return None
207
+
208
+ def save_result_report(self, results):
209
+ # if 'joint_goal' in results[0]:
210
+ # with open(cfg.result_path[:-4] + '_report_dst.txt', 'w') as rf:
211
+ # rf.write('joint goal\tslot_acc\tslot_f1\tact_f1\n')
212
+ # for res in results:
213
+ # a,b,c,d = res['joint_goal'], res['slot_acc'], res['slot_f1'], res['act_f1']
214
+ # rf.write('%2.1f\t%2.1f\t%2.1f\t%2.1f\n'%(a,b,c,d))
215
+ # elif 'joint_goal_delex' in results[0]:
216
+ # with open(cfg.result_path[:-4] + '_report_bsdx.txt', 'w') as rf:
217
+ # rf.write('joint goal\tslot_acc\tslot_f1\tact_f1\n')
218
+ # for res in results:
219
+ # a,b,c,d = res['joint_goal_delex'], res['slot_acc_delex'], res['slot_f1_delex'], res['act_f1']
220
+ # rf.write('%2.1f\t%2.1f\t%2.1f\t%2.1f\n'%(a,b,c,d))
221
+ ctr_save_path = cfg.result_path[:-4] + "_report_ctr%s.csv" % cfg.seed
222
+ write_title = False if os.path.exists(ctr_save_path) else True
223
+ if cfg.aspn_decode_mode == "greedy":
224
+ setting = ""
225
+ elif cfg.aspn_decode_mode == "beam":
226
+ setting = "width=%s" % str(cfg.beam_width)
227
+ if cfg.beam_diverse_param > 0:
228
+ setting += ", penalty=%s" % str(cfg.beam_diverse_param)
229
+ elif cfg.aspn_decode_mode == "topk_sampling":
230
+ setting = "topk=%s" % str(cfg.topk_num)
231
+ elif cfg.aspn_decode_mode == "nucleur_sampling":
232
+ setting = "p=%s" % str(cfg.nucleur_p)
233
+ res = {
234
+ "exp": cfg.eval_load_path,
235
+ "true_bspn": cfg.use_true_curr_bspn,
236
+ "true_aspn": cfg.use_true_curr_aspn,
237
+ "decode": cfg.aspn_decode_mode,
238
+ "param": setting,
239
+ "nbest": cfg.nbest,
240
+ "selection_sheme": cfg.act_selection_scheme,
241
+ "match": results[0]["match"],
242
+ "success": results[0]["success"],
243
+ "bleu": results[0]["bleu"],
244
+ "act_f1": results[0]["act_f1"],
245
+ "avg_act_num": results[0]["avg_act_num"],
246
+ "avg_diverse": results[0]["avg_diverse_score"],
247
+ }
248
+ with open(ctr_save_path, "a") as rf:
249
+ writer = csv.DictWriter(rf, fieldnames=list(res.keys()))
250
+ if write_title:
251
+ writer.writeheader()
252
+ writer.writerows([res])
253
+
254
+
255
+ class MultiWozReader(_ReaderBase):
256
+ def __init__(self, tokenizer):
257
+ super().__init__()
258
+ self.nlp = spacy.load("en_core_web_sm")
259
+
260
+ self.db = MultiWozDB(cfg.dbs)
261
+ self.vocab_size = self._build_vocab()
262
+
263
+ # self.tokenizer = GPT2Tokenizer.from_pretrained(cfg.gpt_path) # add special tokens later
264
+ self.tokenizer = tokenizer
265
+ if cfg.mode == "train":
266
+ self.add_sepcial_tokens()
267
+
268
+ self.domain_files = json.loads(open(cfg.domain_file_path, "r").read())
269
+ self.slot_value_set = json.loads(open(cfg.slot_value_set_path, "r").read())
270
+ if cfg.multi_acts_training:
271
+ self.multi_acts = json.loads(open(cfg.multi_acts_path, "r").read())
272
+
273
+ test_list = [test_list.strip().lower() for test_list in open(cfg.test_list, "r").readlines()]
274
+ dev_list = [dev_list.strip().lower() for dev_list in open(cfg.dev_list, "r").readlines()]
275
+ self.dev_files, self.test_files = {}, {}
276
+ for fn in test_list:
277
+ self.test_files[fn.replace(".json", "")] = 1
278
+ for fn in dev_list:
279
+ self.dev_files[fn.replace(".json", "")] = 1
280
+
281
+ # for domain expanse aka. Cross domain
282
+ self.exp_files = {}
283
+ # if 'all' not in cfg.exp_domains:
284
+ # for domain in cfg.exp_domains:
285
+ # fn_list = self.domain_files.get(domain)
286
+ # if not fn_list:
287
+ # raise ValueError(
288
+ # '[%s] is an invalid experiment setting' % domain)
289
+ # for fn in fn_list:
290
+ # self.exp_files[fn.replace('.json', '')] = 1
291
+ all_domains_list = list(self.domain_files.keys())
292
+ if "all" not in cfg.exp_domains:
293
+ domains = self.get_exp_domains(cfg.exp_domains, all_domains_list)
294
+ logging.info(domains)
295
+ for domain in domains:
296
+ fn_list = self.domain_files.get(domain)
297
+ if not fn_list:
298
+ raise ValueError("[%s] is an invalid experiment setting" % domain)
299
+ for fn in fn_list:
300
+ self.exp_files[fn.replace(".json", "")] = 1
301
+ #
302
+
303
+ self._load_data()
304
+
305
+ if cfg.limit_bspn_vocab:
306
+ self.bspn_masks = self._construct_bspn_constraint()
307
+ if cfg.limit_aspn_vocab:
308
+ self.aspn_masks = self._construct_aspn_constraint()
309
+
310
+ self.multi_acts_record = None
311
+
312
+ def get_exp_domains(self, exp_domains, all_domains_list):
313
+ if "hotel" in exp_domains:
314
+ if "except" in exp_domains:
315
+ # ['except', 'hotel']
316
+ domains = [d for d in all_domains_list if "hotel" not in d and "multi" not in d]
317
+ else:
318
+ # ['hotel']
319
+ domains = ["hotel_single", "hotel_multi"]
320
+ if "train" in exp_domains:
321
+ if "except" in exp_domains:
322
+ # ['except', 'train']
323
+ domains = [d for d in all_domains_list if "train" not in d and "multi" not in d]
324
+ else:
325
+ # ['train']
326
+ domains = ["train_single", "train_multi"]
327
+ if "attraction" in exp_domains:
328
+ if "except" in exp_domains:
329
+ # ['except', 'attraction']
330
+ domains = [d for d in all_domains_list if "attraction" not in d and "multi" not in d]
331
+ else:
332
+ # ['attraction']
333
+ domains = ["attraction_single", "attraction_multi"]
334
+ if "restaurant" in exp_domains:
335
+ if "except" in exp_domains:
336
+ # ['except', 'restaurant']
337
+ domains = [d for d in all_domains_list if "restaurant" not in d and "multi" not in d]
338
+ else:
339
+ # ['restaurant']
340
+ domains = ["restaurant_single", "restaurant_multi"]
341
+ if "taxi" in exp_domains:
342
+ if "except" in exp_domains:
343
+ # ['except', 'taxi']
344
+ domains = [d for d in all_domains_list if "taxi" not in d and "multi" not in d]
345
+ else:
346
+ # ['taxi']
347
+ domains = ["taxi_single", "taxi_multi"]
348
+ return domains
349
+
350
+ def add_sepcial_tokens(self):
351
+ """
352
+ add special tokens to gpt tokenizer
353
+ serves a similar role of Vocab.construt()
354
+ make a dict of special tokens
355
+ """
356
+ special_tokens = []
357
+ for word in ontology.all_domains + ["general"]:
358
+ word = "[" + word + "]"
359
+ special_tokens.append(word)
360
+ for word in ontology.all_acts:
361
+ word = "[" + word + "]"
362
+ special_tokens.append(word)
363
+ # for word in ontology.all_slots:
364
+ # to be determine whether slot should be [slot]
365
+ # if slot, tokenizer having trouble decoding.
366
+ # special_tokens.append(word)
367
+ for word in self.vocab._word2idx.keys():
368
+ if word.startswith("[value_") and word.endswith("]"):
369
+ special_tokens.append(word)
370
+ special_tokens.extend(ontology.special_tokens)
371
+
372
+ special_tokens_dict = {"additional_special_tokens": special_tokens}
373
+ self.tokenizer.add_special_tokens(special_tokens_dict)
374
+ logging.info("Added special tokens to gpt tokenizer.")
375
+
376
+ cfg.pad_id = self.tokenizer.encode("<pad>")[0]
377
+
378
+ def _build_vocab(self):
379
+ self.vocab = utils.Vocab(cfg.vocab_size)
380
+ vp = cfg.vocab_path_train if cfg.mode == "train" or cfg.vocab_path_eval is None else cfg.vocab_path_eval
381
+ # vp = cfg.vocab_path+'.json.freq.json'
382
+ self.vocab.load_vocab(vp)
383
+ return self.vocab.vocab_size
384
+
385
+ def _construct_bspn_constraint(self):
386
+ bspn_masks = {}
387
+ valid_domains = [
388
+ "restaurant",
389
+ "hotel",
390
+ "attraction",
391
+ "train",
392
+ "taxi",
393
+ "hospital",
394
+ ]
395
+ all_dom_codes = [self.vocab.encode("[" + d + "]") for d in valid_domains]
396
+ all_slot_codes = [self.vocab.encode(s) for s in ontology.all_slots]
397
+ bspn_masks[self.vocab.encode("<go_b>")] = all_dom_codes + [
398
+ self.vocab.encode("<eos_b>"),
399
+ 0,
400
+ ]
401
+ bspn_masks[self.vocab.encode("<eos_b>")] = [self.vocab.encode("<pad>")]
402
+ bspn_masks[self.vocab.encode("<pad>")] = [self.vocab.encode("<pad>")]
403
+ for domain, slot_values in self.slot_value_set.items():
404
+ if domain == "police":
405
+ continue
406
+ dom_code = self.vocab.encode("[" + domain + "]")
407
+ bspn_masks[dom_code] = []
408
+ for slot, values in slot_values.items():
409
+ slot_code = self.vocab.encode(slot)
410
+ if slot_code not in bspn_masks:
411
+ bspn_masks[slot_code] = []
412
+ if slot_code not in bspn_masks[dom_code]:
413
+ bspn_masks[dom_code].append(slot_code)
414
+ for value in values:
415
+ for idx, v in enumerate(value.split()):
416
+ if not self.vocab.has_word(v):
417
+ continue
418
+ v_code = self.vocab.encode(v)
419
+ if v_code not in bspn_masks:
420
+ # print(self.vocab._word2idx)
421
+ bspn_masks[v_code] = []
422
+ if idx == 0 and v_code not in bspn_masks[slot_code]:
423
+ bspn_masks[slot_code].append(v_code)
424
+ if idx == (len(value.split()) - 1):
425
+ for w in all_dom_codes + all_slot_codes:
426
+ if self.vocab.encode("<eos_b>") not in bspn_masks[v_code]:
427
+ bspn_masks[v_code].append(self.vocab.encode("<eos_b>"))
428
+ if w not in bspn_masks[v_code]:
429
+ bspn_masks[v_code].append(w)
430
+ break
431
+ if not self.vocab.has_word(value.split()[idx + 1]):
432
+ continue
433
+ next_v_code = self.vocab.encode(value.split()[idx + 1])
434
+ if next_v_code not in bspn_masks[v_code]:
435
+ bspn_masks[v_code].append(next_v_code)
436
+ bspn_masks[self.vocab.encode("<unk>")] = list(bspn_masks.keys())
437
+
438
+ with open("data/processed/multi-woz-processed/bspn_masks.txt", "w") as f:
439
+ for i, j in bspn_masks.items():
440
+ f.write(self.vocab.decode(i) + ": " + " ".join([self.vocab.decode(int(m)) for m in j]) + "\n")
441
+ return bspn_masks
442
+
443
+ def _construct_aspn_constraint(self):
444
+ aspn_masks = {}
445
+ aspn_masks = {}
446
+ all_dom_codes = [self.vocab.encode("[" + d + "]") for d in ontology.dialog_acts.keys()]
447
+ all_act_codes = [self.vocab.encode("[" + a + "]") for a in ontology.dialog_act_params]
448
+ all_slot_codes = [self.vocab.encode(s) for s in ontology.dialog_act_all_slots]
449
+ aspn_masks[self.vocab.encode("<go_a>")] = all_dom_codes + [
450
+ self.vocab.encode("<eos_a>"),
451
+ 0,
452
+ ]
453
+ aspn_masks[self.vocab.encode("<eos_a>")] = [self.vocab.encode("<pad>")]
454
+ aspn_masks[self.vocab.encode("<pad>")] = [self.vocab.encode("<pad>")]
455
+ # for d in all_dom_codes:
456
+ # aspn_masks[d] = all_act_codes
457
+ for a in all_act_codes:
458
+ aspn_masks[a] = all_dom_codes + all_slot_codes + [self.vocab.encode("<eos_a>")]
459
+ for domain, acts in ontology.dialog_acts.items():
460
+ dom_code = self.vocab.encode("[" + domain + "]")
461
+ aspn_masks[dom_code] = []
462
+ for a in acts:
463
+ act_code = self.vocab.encode("[" + a + "]")
464
+ if act_code not in aspn_masks[dom_code]:
465
+ aspn_masks[dom_code].append(act_code)
466
+ # for a, slots in ontology.dialog_act_params.items():
467
+ # act_code = self.vocab.encode('['+a+']')
468
+ # slot_codes = [self.vocab.encode(s) for s in slots]
469
+ # aspn_masks[act_code] = all_dom_codes + slot_codes + [self.vocab.encode('<eos_a>')]
470
+ for s in all_slot_codes:
471
+ aspn_masks[s] = all_dom_codes + all_slot_codes + [self.vocab.encode("<eos_a>")]
472
+ aspn_masks[self.vocab.encode("<unk>")] = list(aspn_masks.keys())
473
+
474
+ with open("processed/multi-woz-processed/aspn_masks.txt", "w") as f:
475
+ for i, j in aspn_masks.items():
476
+ f.write(self.vocab.decode(i) + ": " + " ".join([self.vocab.decode(int(m)) for m in j]) + "\n")
477
+ return aspn_masks
478
+
479
+ def _load_data(self, save_temp=True):
480
+ """
481
+ load processed data and encode, or load already encoded data
482
+ """
483
+ if save_temp: # save encoded data
484
+ if "all" in cfg.exp_domains:
485
+ encoded_file = os.path.join(cfg.data_path, "new_db_se_blank_encoded.data.json")
486
+ # encoded: no sos, se_encoded: sos and eos
487
+ # db: add db results every turn
488
+ else:
489
+ xdomain_dir = "./models/UBAR/experiments_Xdomain/data"
490
+ if not os.path.exists(xdomain_dir):
491
+ os.makedirs(xdomain_dir)
492
+ encoded_file = os.path.join(
493
+ xdomain_dir,
494
+ "{}-encoded.data.json".format("-".join(cfg.exp_domains)),
495
+ )
496
+
497
+ if os.path.exists(encoded_file):
498
+ logging.info("Reading encoded data from {}".format(encoded_file))
499
+ self.data = json.loads(open(cfg.data_path + cfg.data_file, "r", encoding="utf-8").read().lower())
500
+ encoded_data = json.loads(open(encoded_file, "r", encoding="utf-8").read())
501
+ self.train = encoded_data["train"]
502
+ self.dev = encoded_data["dev"]
503
+ self.test = encoded_data["test"]
504
+ else:
505
+ logging.info("Encoding data now and save the encoded data in {}".format(encoded_file))
506
+ # not exists, encode data and save
507
+ self.data = json.loads(open(cfg.data_path + cfg.data_file, "r", encoding="utf-8").read().lower())
508
+ self.train, self.dev, self.test = [], [], []
509
+ for fn, dial in self.data.items():
510
+ if ".json" in fn:
511
+ fn = fn.replace(".json", "")
512
+ if "all" in cfg.exp_domains or self.exp_files.get(fn):
513
+ if self.dev_files.get(fn):
514
+ self.dev.append(self._get_encoded_data(fn, dial))
515
+ elif self.test_files.get(fn):
516
+ self.test.append(self._get_encoded_data(fn, dial))
517
+ else:
518
+ self.train.append(self._get_encoded_data(fn, dial))
519
+
520
+ # save encoded data
521
+ encoded_data = {"train": self.train, "dev": self.dev, "test": self.test}
522
+ json.dump(encoded_data, open(encoded_file, "w"), indent=2)
523
+
524
+ else: # directly read processed data and encode
525
+ self.data = json.loads(open(cfg.data_path + cfg.data_file, "r", encoding="utf-8").read().lower())
526
+ self.train, self.dev, self.test = [], [], []
527
+ for fn, dial in self.data.items():
528
+ if ".json" in fn:
529
+ fn = fn.replace(".json", "")
530
+ if "all" in cfg.exp_domains or self.exp_files.get(fn):
531
+ if self.dev_files.get(fn):
532
+ self.dev.append(self._get_encoded_data(fn, dial))
533
+ elif self.test_files.get(fn):
534
+ self.test.append(self._get_encoded_data(fn, dial))
535
+ else:
536
+ self.train.append(self._get_encoded_data(fn, dial))
537
+ # if save_temp:
538
+ # json.dump(self.test, open(
539
+ # 'data/multi-woz-analysis/test.encoded.json', 'w'), indent=2)
540
+ # self.vocab.save_vocab('data/multi-woz-analysis/vocab_temp')
541
+
542
+ random.shuffle(self.train)
543
+ # random.shuffle(self.dev)
544
+ # random.shuffle(self.test)
545
+ logging.info("train size:{}, dev size:{}, test size:{}".format(len(self.train), len(self.dev), len(self.test)))
546
+
547
+ def _get_encoded_data(self, fn, dial):
548
+ encoded_dial = []
549
+ for idx, t in enumerate(dial["log"]): # tokenize to list of ids
550
+ enc = {}
551
+ enc["dial_id"] = fn
552
+
553
+ # enc['user'] = self.vocab.sentence_encode(t['user'].split() + ['<eos_u>'])
554
+ # enc['usdx'] = self.vocab.sentence_encode(t['user_delex'].split() + ['<eos_u>'])
555
+ # enc['resp'] = self.vocab.sentence_encode(t['resp'].split() + ['<eos_r>'])
556
+ # enc['bspn'] = self.vocab.sentence_encode(t['constraint'].split() + ['<eos_b>'])
557
+ # enc['bsdx'] = self.vocab.sentence_encode(t['cons_delex'].split() + ['<eos_b>'])
558
+ # enc['aspn'] = self.vocab.sentence_encode(t['sys_act'].split() + ['<eos_a>'])
559
+ # enc['dspn'] = self.vocab.sentence_encode(t['turn_domain'].split() + ['<eos_d>'])
560
+
561
+ # use gpt tokenizer directly tokenize word list, prone to encode unknown words to |endoftext|
562
+ # enc['user'] = self.tokenizer.encode(
563
+ # t['user'].split() + ['<eos_u>'])
564
+ # enc['usdx'] = self.tokenizer.encode(
565
+ # t['user_delex'].split() + ['<eos_u>'])
566
+ # enc['resp'] = self.tokenizer.encode(
567
+ # t['resp'].split() + ['<eos_r>'])
568
+ # enc['bspn'] = self.tokenizer.encode(
569
+ # t['constraint'].split() + ['<eos_b>'])
570
+ # enc['bsdx'] = self.tokenizer.encode(
571
+ # t['cons_delex'].split() + ['<eos_b>'])
572
+ # enc['aspn'] = self.tokenizer.encode(
573
+ # t['sys_act'].split() + ['<eos_a>'])
574
+ # enc['dspn'] = self.tokenizer.encode(
575
+ # t['turn_domain'].split() + ['<eos_d>'])
576
+
577
+ # gpt use bpe to encode strings, very very slow. ~9min
578
+ # in tokenization_utils.encode I find encode can pad_to_max_length, and reutrn tensor
579
+ enc["user"] = self.tokenizer.convert_tokens_to_ids(
580
+ self.tokenizer.tokenize("<sos_u> " + t["user"] + " <eos_u>")
581
+ )
582
+ enc["usdx"] = self.tokenizer.convert_tokens_to_ids(
583
+ self.tokenizer.tokenize("<sos_u> " + t["user"] + " <eos_u>")
584
+ )
585
+ enc["resp"] = self.tokenizer.convert_tokens_to_ids(
586
+ self.tokenizer.tokenize("<sos_r> " + t["resp"] + " <eos_r>")
587
+ )
588
+ enc["bspn"] = self.tokenizer.convert_tokens_to_ids(
589
+ self.tokenizer.tokenize("<sos_b> " + t["constraint"] + " <eos_b>")
590
+ )
591
+ enc["bsdx"] = self.tokenizer.convert_tokens_to_ids(
592
+ self.tokenizer.tokenize("<sos_b> " + t["cons_delex"] + " <eos_b>")
593
+ )
594
+ enc["aspn"] = self.tokenizer.convert_tokens_to_ids(
595
+ self.tokenizer.tokenize("<sos_a> " + t["sys_act"] + " <eos_a>")
596
+ )
597
+ enc["dspn"] = self.tokenizer.convert_tokens_to_ids(
598
+ self.tokenizer.tokenize("<sos_d> " + t["turn_domain"] + " <eos_d>")
599
+ )
600
+
601
+ enc["pointer"] = [int(i) for i in t["pointer"].split(",")]
602
+ enc["turn_domain"] = t["turn_domain"].split()
603
+ enc["turn_num"] = t["turn_num"]
604
+ if cfg.multi_acts_training:
605
+ enc["aspn_aug"] = []
606
+ if fn in self.multi_acts:
607
+ turn_ma = self.multi_acts[fn].get(str(idx), {})
608
+ for act_type, act_spans in turn_ma.items():
609
+ enc["aspn_aug"].append([self.tokenizer.encode(a.split() + ["<eos_a>"]) for a in act_spans])
610
+
611
+ # add db results to enc, at every turn
612
+ db_pointer = self.bspan_to_DBpointer(t["constraint"], t["turn_domain"].split())
613
+ # db_tokens = ['<sos_db>', '<eos_db>', '[db_nores]', '[db_0]', '[db_1]', '[db_2]', '[db_3]']
614
+ enc["db"] = self.tokenizer.convert_tokens_to_ids(
615
+ self.tokenizer.tokenize("<sos_db> " + db_pointer + " <eos_db>")
616
+ )
617
+
618
+ encoded_dial.append(enc)
619
+ return encoded_dial
620
+
621
+ def bspan_to_constraint_dict(self, bspan, bspn_mode="bspn"):
622
+ bspan = bspan.split() if isinstance(bspan, str) else bspan
623
+ constraint_dict = {}
624
+ domain = None
625
+ conslen = len(bspan)
626
+ for idx, cons in enumerate(bspan):
627
+ cons = self.vocab.decode(cons) if type(cons) is not str else cons
628
+ if cons == "<eos_b>":
629
+ break
630
+ if "[" in cons:
631
+ if cons[1:-1] not in ontology.all_domains:
632
+ continue
633
+ domain = cons[1:-1]
634
+ elif cons in ontology.get_slot:
635
+ if domain is None:
636
+ continue
637
+ if cons == "people":
638
+ # handle confusion of value name "people's portraits..." and slot people
639
+ try:
640
+ ns = bspan[idx + 1]
641
+ ns = self.vocab.decode(ns) if type(ns) is not str else ns
642
+ if ns == "'s":
643
+ continue
644
+ except Exception:
645
+ continue
646
+ if not constraint_dict.get(domain):
647
+ constraint_dict[domain] = {}
648
+ if bspn_mode == "bsdx":
649
+ constraint_dict[domain][cons] = 1
650
+ continue
651
+ vidx = idx + 1
652
+ if vidx == conslen:
653
+ break
654
+ vt_collect = []
655
+ vt = bspan[vidx]
656
+ vt = self.vocab.decode(vt) if type(vt) is not str else vt
657
+ while vidx < conslen and vt != "<eos_b>" and "[" not in vt and vt not in ontology.get_slot:
658
+ vt_collect.append(vt)
659
+ vidx += 1
660
+ if vidx == conslen:
661
+ break
662
+ vt = bspan[vidx]
663
+ vt = self.vocab.decode(vt) if type(vt) is not str else vt
664
+ if vt_collect:
665
+ constraint_dict[domain][cons] = " ".join(vt_collect)
666
+
667
+ return constraint_dict
668
+
669
+ def bspan_to_DBpointer(self, bspan, turn_domain):
670
+ constraint_dict = self.bspan_to_constraint_dict(bspan)
671
+ # print(constraint_dict)
672
+ matnums = self.db.get_match_num(constraint_dict)
673
+ match_dom = turn_domain[0] if len(turn_domain) == 1 else turn_domain[1]
674
+ match_dom = match_dom[1:-1] if match_dom.startswith("[") else match_dom
675
+ match = matnums[match_dom]
676
+ # vector = self.db.addDBPointer(match_dom, match)
677
+ vector = self.db.addDBIndicator(match_dom, match)
678
+ return vector
679
+
680
+ def aspan_to_act_list(self, aspan):
681
+ aspan = aspan.split() if isinstance(aspan, str) else aspan
682
+ acts = []
683
+ domain = None
684
+ conslen = len(aspan)
685
+ for idx, cons in enumerate(aspan):
686
+ cons = self.vocab.decode(cons) if type(cons) is not str else cons
687
+ if cons == "<eos_a>":
688
+ break
689
+ if "[" in cons and cons[1:-1] in ontology.dialog_acts:
690
+ domain = cons[1:-1]
691
+
692
+ elif "[" in cons and cons[1:-1] in ontology.dialog_act_params:
693
+ if domain is None:
694
+ continue
695
+ vidx = idx + 1
696
+ if vidx == conslen:
697
+ acts.append(domain + "-" + cons[1:-1] + "-none")
698
+ break
699
+ vt = aspan[vidx]
700
+ vt = self.vocab.decode(vt) if type(vt) is not str else vt
701
+ no_param_act = True
702
+ while vidx < conslen and vt != "<eos_a>" and "[" not in vt:
703
+ no_param_act = False
704
+ acts.append(domain + "-" + cons[1:-1] + "-" + vt)
705
+ vidx += 1
706
+ if vidx == conslen:
707
+ break
708
+ vt = aspan[vidx]
709
+ vt = self.vocab.decode(vt) if type(vt) is not str else vt
710
+ if no_param_act:
711
+ acts.append(domain + "-" + cons[1:-1] + "-none")
712
+
713
+ return acts
714
+
715
+ def dspan_to_domain(self, dspan):
716
+ domains = {}
717
+ dspan = dspan.split() if isinstance(dspan, str) else dspan
718
+ for d in dspan:
719
+ dom = self.vocab.decode(d) if type(d) is not str else d
720
+ if dom != "<eos_d>":
721
+ domains[dom] = 1
722
+ else:
723
+ break
724
+ return domains
725
+
726
+ def convert_turn_eval(self, turn, pv_turn, first_turn=False):
727
+ """
728
+ input: [all previous ubar, U_t, B_t, A_t] predict R_t
729
+ firts turn: [U_t, B_t, A_t] predict R_t
730
+
731
+ regarding the context, all previous ubar is too slow, try the previous ubar
732
+ """
733
+ inputs = {}
734
+
735
+ context_list = []
736
+ # predict_list = []
737
+ prompt = ""
738
+ if cfg.use_true_curr_bspn:
739
+ if cfg.use_true_curr_aspn: # only predict resp
740
+ context_list = ["user", "bspn", "db", "aspn"]
741
+ # context_list = ['user','aspn'] # predict resp based on current aspn and bspn
742
+ # predict_list = ['resp']
743
+ prompt = "<sos_r>"
744
+ else: # predicted aspn
745
+ context_list = ["user", "bspn", "db"]
746
+ # predict_list = ['aspn', 'resp']
747
+ prompt = "<sos_a>"
748
+ else: # predict bspn aspn resp. db are not predicted. this part tbd.
749
+ context_list = ["user"]
750
+ # predict_list = ['bspn', 'db','aspn', 'resp']
751
+ prompt = "<sos_b>"
752
+
753
+ if first_turn:
754
+ context = []
755
+ for c in context_list:
756
+ context += turn[c]
757
+
758
+ inputs["context"] = context + self.tokenizer.encode([prompt])
759
+ inputs["labels"] = context
760
+ # e43 with BABAU
761
+ # inputs['labels'] = []
762
+
763
+ else:
764
+ context = []
765
+ for c in context_list:
766
+ context += turn[c]
767
+
768
+ pv_context = pv_turn["labels"] + pv_turn["bspn"] + pv_turn["db"] + pv_turn["aspn"] + pv_turn["resp"]
769
+ # e43 with BABAU
770
+ # pv_context = pv_turn['labels'] + pv_turn['bspn'] + pv_turn['db'] + pv_turn['aspn']
771
+
772
+ # prompt response, add sos_r
773
+ inputs["context"] = pv_context + context + self.tokenizer.encode([prompt])
774
+ # context just the current turn
775
+ # inputs['context'] = context + self.tokenizer.encode([prompt])
776
+ # context just the current action
777
+
778
+ if cfg.use_all_previous_context:
779
+ inputs["labels"] = pv_context + context # use all previous ubar history
780
+ else:
781
+ inputs["labels"] = context # use previous trun
782
+
783
+ if len(inputs["context"]) > 900:
784
+ print("len exceeds 900")
785
+ inputs["context"] = inputs["context"][-900:]
786
+
787
+ return inputs
788
+
789
+ def convert_batch_session(self, dial_batch):
790
+ """
791
+ convert the whole session for training
792
+ concat [U_0, B_0, A_0, R_0, ... , U_n, B_n, A_n, R_n]
793
+
794
+ try: [user, bspn, aspn, resp]
795
+ or
796
+ try: [user, bspn, db, aspn, resp]
797
+ """
798
+ inputs = {}
799
+ contexts = []
800
+ cell_list = ["user", "bspn", "db", "aspn", "resp"]
801
+ for idx, dial in enumerate(dial_batch):
802
+ context = []
803
+ for turn_num, turn in enumerate(dial):
804
+ for cell in cell_list:
805
+ context.extend(turn[cell])
806
+ contexts.append(context)
807
+
808
+ inputs["contexts"] = contexts
809
+ inputs["contexts_np"], inputs["lengths"] = utils.padSeqs_gpt(inputs["contexts"], cfg.pad_id)
810
+ return inputs
811
+
812
+ def convert_batch_gpt(self, turn_batch, pv_batch, first_turn=False):
813
+ """
814
+ convert the current and the last turn
815
+ concat [U_{t-1}, B_{t-1}, A_{t-1}, R_{t-1}, U_t, B_t, A_t, R_t]
816
+ firts turn: [U_t, B_t, A_t, R_t]
817
+ try: [usdx, bspn, aspn, resp]
818
+
819
+ """
820
+ inputs = {}
821
+ if first_turn:
822
+ contexts = []
823
+ batch_zipped = zip(
824
+ turn_batch["usdx"],
825
+ turn_batch["bspn"],
826
+ turn_batch["aspn"],
827
+ turn_batch["resp"],
828
+ )
829
+ for u, b, a, r in batch_zipped:
830
+ context = u + b + a + r
831
+ contexts.append(context)
832
+ inputs["contexts"] = contexts
833
+ # padSeqs to make [UBAR] the same length
834
+ inputs["contexts_np"], inputs["lengths"] = utils.padSeqs_gpt(inputs["contexts"], cfg.pad_id)
835
+ else:
836
+ contexts = []
837
+ batch_zipped = zip(
838
+ pv_batch["pv_usdx"],
839
+ pv_batch["pv_bspn"],
840
+ pv_batch["pv_aspn"],
841
+ pv_batch["pv_resp"],
842
+ turn_batch["usdx"],
843
+ turn_batch["bspn"],
844
+ turn_batch["aspn"],
845
+ turn_batch["resp"],
846
+ )
847
+ for pu, pb, pa, pr, u, b, a, r in batch_zipped:
848
+ context = pu + pb + pa + pr + u + b + a + r
849
+ contexts.append(context)
850
+ inputs["contexts"] = contexts
851
+ contexts_np, lengths = utils.padSeqs_gpt(inputs["contexts"], cfg.pad_id)
852
+ inputs["contexts_np"] = contexts_np
853
+ inputs["lengths"] = lengths
854
+ return inputs
855
+
856
+ def convert_batch(self, py_batch, py_prev, first_turn=False):
857
+ inputs = {}
858
+ if first_turn:
859
+ for item, py_list in py_prev.items():
860
+ batch_size = len(py_batch["user"])
861
+ inputs[item + "_np"] = np.array([[1]] * batch_size)
862
+ inputs[item + "_unk_np"] = np.array([[1]] * batch_size)
863
+ else:
864
+ for item, py_list in py_prev.items():
865
+ if py_list is None:
866
+ continue
867
+ if not cfg.enable_aspn and "aspn" in item:
868
+ continue
869
+ if not cfg.enable_bspn and "bspn" in item:
870
+ continue
871
+ if not cfg.enable_dspn and "dspn" in item:
872
+ continue
873
+ prev_np = utils.padSeqs(py_list, truncated=cfg.truncated, trunc_method="pre")
874
+ inputs[item + "_np"] = prev_np
875
+ if item in ["pv_resp", "pv_bspn"]:
876
+ inputs[item + "_unk_np"] = deepcopy(inputs[item + "_np"])
877
+ # <unk>, restrict vocab size to 3k, map ids>3k to <unk>
878
+ inputs[item + "_unk_np"][inputs[item + "_unk_np"] >= self.vocab_size] = 2
879
+ else:
880
+ inputs[item + "_unk_np"] = inputs[item + "_np"]
881
+
882
+ for item in ["user", "usdx", "resp", "bspn", "aspn", "bsdx", "dspn"]:
883
+ if not cfg.enable_aspn and item == "aspn":
884
+ continue
885
+ if not cfg.enable_bspn and item == "bspn":
886
+ continue
887
+
888
+ if not cfg.enable_dspn and item == "dspn":
889
+ continue
890
+ py_list = py_batch[item]
891
+ trunc_method = "post" if item == "resp" else "pre"
892
+ # max_length = cfg.max_nl_length if item in ['user', 'usdx', 'resp'] else cfg.max_span_length
893
+ inputs[item + "_np"] = utils.padSeqs(py_list, truncated=cfg.truncated, trunc_method=trunc_method)
894
+ if item in ["user", "usdx", "resp", "bspn"]:
895
+ inputs[item + "_unk_np"] = deepcopy(inputs[item + "_np"])
896
+ inputs[item + "_unk_np"][inputs[item + "_unk_np"] >= self.vocab_size] = 2 # <unk>
897
+ else:
898
+ inputs[item + "_unk_np"] = inputs[item + "_np"]
899
+
900
+ if cfg.multi_acts_training and cfg.mode == "train":
901
+ inputs["aspn_bidx"], multi_aspn = [], []
902
+ for bidx, aspn_type_list in enumerate(py_batch["aspn_aug"]):
903
+ if aspn_type_list:
904
+ for aspn_list in aspn_type_list:
905
+ random.shuffle(aspn_list)
906
+ # choose one random act span in each act type
907
+ aspn = aspn_list[0]
908
+ multi_aspn.append(aspn)
909
+ inputs["aspn_bidx"].append(bidx)
910
+ if cfg.multi_act_sampling_num > 1:
911
+ for i in range(cfg.multi_act_sampling_num):
912
+ if len(aspn_list) >= i + 2:
913
+ # choose one random act span in each act type
914
+ aspn = aspn_list[i + 1]
915
+ multi_aspn.append(aspn)
916
+ inputs["aspn_bidx"].append(bidx)
917
+
918
+ if multi_aspn:
919
+ inputs["aspn_aug_np"] = utils.padSeqs(multi_aspn, truncated=cfg.truncated, trunc_method="pre")
920
+ # [all available aspn num in the batch, T]
921
+ inputs["aspn_aug_unk_np"] = inputs["aspn_aug_np"]
922
+
923
+ inputs["db_np"] = np.array(py_batch["pointer"])
924
+ inputs["turn_domain"] = py_batch["turn_domain"]
925
+
926
+ return inputs
927
+
928
+ def wrap_result_lm(self, result_dict, eos_syntax=None):
929
+ results = []
930
+ eos_syntax = ontology.eos_tokens if not eos_syntax else eos_syntax
931
+ sos_syntax = ontology.sos_tokens
932
+ # ground truth bs, as, ds.. generate response
933
+ field = [
934
+ "dial_id",
935
+ "turn_num",
936
+ "user",
937
+ "bspn_gen",
938
+ "bsdx",
939
+ "resp_gen",
940
+ "resp",
941
+ "aspn_gen",
942
+ "aspn",
943
+ "dspn_gen",
944
+ "dspn",
945
+ "bspn",
946
+ "pointer",
947
+ ]
948
+
949
+ for dial_id, turns in result_dict.items():
950
+ entry = {"dial_id": dial_id, "trun_num": len(turns)}
951
+ for f in field[2:]:
952
+ entry[f] = "" # ???
953
+ results.append(entry)
954
+ for turn_idx, turn in enumerate(turns):
955
+ entry = {"dial_id": dial_id}
956
+ for key in field:
957
+ if key in ["dial_id"]:
958
+ continue
959
+ v = turn.get(key, "")
960
+ if key == "turn_domain":
961
+ v = " ".join(v)
962
+
963
+ if key in eos_syntax and v != "":
964
+ # remove eos tokens
965
+ v = self.tokenizer.decode(v)
966
+ v = v.split()
967
+ # remove eos/sos in span
968
+ if eos_syntax[key] in v:
969
+ v.remove(eos_syntax[key])
970
+ if sos_syntax[key] in v:
971
+ v.remove(sos_syntax[key])
972
+ # if key != 'resp_gen':
973
+ # # remove eos/sos in span
974
+ # if eos_syntax[key] in v:
975
+ # v.remove(eos_syntax[key])
976
+ # if sos_syntax[key] in v:
977
+ # v.remove(sos_syntax[key])
978
+ # else: # 'resp_gen'
979
+ # sos_index = 0
980
+ # eos_index = -1
981
+ # if sos_syntax[key] in v:
982
+ # sos_index = v.index(sos_syntax[key])
983
+ # if eos_syntax[key] in v:
984
+ # eos_index = v.index(eos_syntax[key])
985
+ # else:
986
+ # pass # take too long
987
+ # # no <eos_r> found, stop at any eos_tokens
988
+ # # for i in range(sos_index+1, len(v)):
989
+ # # if v[i] in sos_syntax.values() or v[i] in eos_syntax.values():
990
+ # # eos_index = i
991
+ # v = v[sos_index+1: eos_index]
992
+
993
+ # v = self.tokenizer.convert_tokens_to_string(v)
994
+ v = " ".join(v)
995
+ else:
996
+ pass # v = v
997
+ entry[key] = v
998
+
999
+ results.append(entry)
1000
+
1001
+ return results, field
1002
+
1003
+ def wrap_result(self, result_dict, eos_syntax=None):
1004
+ decode_fn = self.vocab.sentence_decode
1005
+ results = []
1006
+ eos_syntax = ontology.eos_tokens if not eos_syntax else eos_syntax
1007
+
1008
+ if cfg.bspn_mode == "bspn":
1009
+ field = [
1010
+ "dial_id",
1011
+ "turn_num",
1012
+ "user",
1013
+ "bspn_gen",
1014
+ "bspn",
1015
+ "resp_gen",
1016
+ "resp",
1017
+ "aspn_gen",
1018
+ "aspn",
1019
+ "dspn_gen",
1020
+ "dspn",
1021
+ "pointer",
1022
+ ]
1023
+ elif not cfg.enable_dst: # this
1024
+ field = [
1025
+ "dial_id",
1026
+ "turn_num",
1027
+ "user",
1028
+ "bsdx_gen",
1029
+ "bsdx",
1030
+ "resp_gen",
1031
+ "resp",
1032
+ "aspn_gen",
1033
+ "aspn",
1034
+ "dspn_gen",
1035
+ "dspn",
1036
+ "bspn",
1037
+ "pointer",
1038
+ ]
1039
+ else:
1040
+ field = [
1041
+ "dial_id",
1042
+ "turn_num",
1043
+ "user",
1044
+ "bsdx_gen",
1045
+ "bsdx",
1046
+ "resp_gen",
1047
+ "resp",
1048
+ "aspn_gen",
1049
+ "aspn",
1050
+ "dspn_gen",
1051
+ "dspn",
1052
+ "bspn_gen",
1053
+ "bspn",
1054
+ "pointer",
1055
+ ]
1056
+ if self.multi_acts_record is not None:
1057
+ field.insert(7, "multi_act_gen")
1058
+
1059
+ for dial_id, turns in result_dict.items():
1060
+ entry = {"dial_id": dial_id, "turn_num": len(turns)}
1061
+ for prop in field[2:]:
1062
+ entry[prop] = ""
1063
+ results.append(entry)
1064
+ for turn_no, turn in enumerate(turns):
1065
+ entry = {"dial_id": dial_id}
1066
+ for key in field:
1067
+ if key in ["dial_id"]:
1068
+ continue
1069
+ v = turn.get(key, "")
1070
+ if key == "turn_domain":
1071
+ v = " ".join(v)
1072
+ entry[key] = decode_fn(v, eos=eos_syntax[key]) if key in eos_syntax and v != "" else v
1073
+ results.append(entry)
1074
+ return results, field
1075
+
1076
+ def restore(self, resp, domain, constraint_dict, mat_ents):
1077
+ restored = resp
1078
+
1079
+ restored = restored.replace("[value_reference]", "53022")
1080
+ restored = restored.replace("[value_car]", "BMW")
1081
+
1082
+ # restored.replace('[value_phone]', '830-430-6666')
1083
+ for d in domain:
1084
+ constraint = constraint_dict.get(d, None)
1085
+ if constraint:
1086
+ if "stay" in constraint:
1087
+ restored = restored.replace("[value_stay]", constraint["stay"])
1088
+ if "day" in constraint:
1089
+ restored = restored.replace("[value_day]", constraint["day"])
1090
+ if "people" in constraint:
1091
+ restored = restored.replace("[value_people]", constraint["people"])
1092
+ if "time" in constraint:
1093
+ restored = restored.replace("[value_time]", constraint["time"])
1094
+ if "type" in constraint:
1095
+ restored = restored.replace("[value_type]", constraint["type"])
1096
+ if d in mat_ents and len(mat_ents[d]) == 0:
1097
+ for s in constraint:
1098
+ if s == "pricerange" and d in ["hotel", "restaurant"] and "price]" in restored:
1099
+ restored = restored.replace("[value_price]", constraint["pricerange"])
1100
+ if s + "]" in restored:
1101
+ restored = restored.replace("[value_%s]" % s, constraint[s])
1102
+
1103
+ if "[value_choice" in restored and mat_ents.get(d):
1104
+ restored = restored.replace("[value_choice]", str(len(mat_ents[d])))
1105
+ if "[value_choice" in restored:
1106
+ restored = restored.replace("[value_choice]", "3")
1107
+
1108
+ # restored.replace('[value_car]', 'BMW')
1109
+
1110
+ try:
1111
+ ent = mat_ents.get(domain[-1], [])
1112
+ if ent:
1113
+ ent = ent[0]
1114
+
1115
+ for t in restored.split():
1116
+ if "[value" in t:
1117
+ slot = t[7:-1]
1118
+ if ent.get(slot):
1119
+ if domain[-1] == "hotel" and slot == "price":
1120
+ slot = "pricerange"
1121
+ restored = restored.replace(t, ent[slot])
1122
+ elif slot == "price":
1123
+ if ent.get("pricerange"):
1124
+ restored = restored.replace(t, ent["pricerange"])
1125
+ else:
1126
+ print(restored, domain)
1127
+ except Exception:
1128
+ print(resp)
1129
+ print(restored)
1130
+ quit()
1131
+
1132
+ restored = restored.replace("[value_phone]", "62781111")
1133
+ restored = restored.replace("[value_postcode]", "CG9566")
1134
+ restored = restored.replace("[value_address]", "Parkside, Cambridge")
1135
+
1136
+ # if '[value_' in restored:
1137
+
1138
+ # print(domain)
1139
+ # # print(mat_ents)
1140
+ # print(resp)
1141
+ # print(restored)
1142
+ return restored
1143
+
1144
+ def record_utterance(self, result_dict):
1145
+ decode_fn = self.vocab.sentence_decode
1146
+
1147
+ ordered_dial = {}
1148
+ for dial_id, turns in result_dict.items():
1149
+ diverse = 0
1150
+ turn_count = 0
1151
+ for turn_no, turn in enumerate(turns):
1152
+ act_collect = {}
1153
+ act_type_collect = {}
1154
+ slot_score = 0
1155
+ for i in range(cfg.nbest):
1156
+ aspn = decode_fn(turn["multi_act"][i], eos=ontology.eos_tokens["aspn"])
1157
+ pred_acts = self.aspan_to_act_list(" ".join(aspn))
1158
+ act_type = ""
1159
+ for act in pred_acts:
1160
+ d, a, s = act.split("-")
1161
+ if d + "-" + a not in act_collect:
1162
+ act_collect[d + "-" + a] = {s: 1}
1163
+ slot_score += 1
1164
+ act_type += d + "-" + a + ";"
1165
+ elif s not in act_collect:
1166
+ act_collect[d + "-" + a][s] = 1
1167
+ slot_score += 1
1168
+ act_type_collect[act_type] = 1
1169
+ turn_count += 1
1170
+ diverse += len(act_collect) * 3 + slot_score
1171
+ ordered_dial[dial_id] = diverse / turn_count
1172
+
1173
+ ordered_dial = sorted(ordered_dial.keys(), key=lambda x: -ordered_dial[x])
1174
+
1175
+ dialog_record = {}
1176
+
1177
+ with open(cfg.eval_load_path + "/dialogue_record.csv", "w") as rf:
1178
+ writer = csv.writer(rf)
1179
+
1180
+ for dial_id in ordered_dial:
1181
+ dialog_record[dial_id] = []
1182
+ turns = result_dict[dial_id]
1183
+ writer.writerow([dial_id])
1184
+ for turn_no, turn in enumerate(turns):
1185
+ user = decode_fn(turn["user"], eos=ontology.eos_tokens["user"])
1186
+ bspn = decode_fn(turn["bspn"], eos=ontology.eos_tokens["bspn"])
1187
+ aspn = decode_fn(turn["aspn"], eos=ontology.eos_tokens["aspn"])
1188
+ resp = decode_fn(turn["resp"], eos=ontology.eos_tokens["resp"])
1189
+ constraint_dict = self.bspan_to_constraint_dict(bspn)
1190
+ # print(constraint_dict)
1191
+ mat_ents = self.db.get_match_num(constraint_dict, True)
1192
+ domain = [i[1:-1] for i in self.dspan_to_domain(turn["dspn"]).keys()]
1193
+ restored = self.restore(resp, domain, constraint_dict, mat_ents)
1194
+ writer.writerow([turn_no, user, turn["pointer"], domain, restored, resp])
1195
+ turn_record = {
1196
+ "user": user,
1197
+ "bspn": bspn,
1198
+ "aspn": aspn,
1199
+ "dom": domain,
1200
+ "resp": resp,
1201
+ "resp_res": restored,
1202
+ }
1203
+
1204
+ resp_col = []
1205
+ aspn_col = []
1206
+ resp_restore_col = []
1207
+ for i in range(cfg.nbest):
1208
+ aspn = decode_fn(turn["multi_act"][i], eos=ontology.eos_tokens["aspn"])
1209
+ resp = decode_fn(turn["multi_resp"][i], eos=ontology.eos_tokens["resp"])
1210
+
1211
+ restored = self.restore(resp, domain, constraint_dict, mat_ents)
1212
+ resp_col.append(resp)
1213
+ resp_restore_col.append(restored)
1214
+ aspn_col.append(aspn)
1215
+
1216
+ zipped = list(zip(resp_restore_col, resp_col, aspn_col))
1217
+ zipped.sort(key=lambda s: len(s[0]))
1218
+ resp_restore_col = list(list(zip(*zipped))[0])
1219
+ aspn_col = list(list(zip(*zipped))[2])
1220
+ resp_col = list(list(zip(*zipped))[1])
1221
+ turn_record["aspn_col"] = aspn_col
1222
+ turn_record["resp_col"] = resp_col
1223
+ turn_record["resp_res_col"] = resp_restore_col
1224
+ for i in range(cfg.nbest):
1225
+ # aspn = decode_fn(turn['multi_act'][i], eos=ontology.eos_tokens['aspn'])
1226
+ resp = resp_col[i]
1227
+ aspn = aspn_col[i]
1228
+ resp_restore = resp_restore_col[i]
1229
+
1230
+ writer.writerow(["", resp_restore, resp, aspn])
1231
+
1232
+ dialog_record[dial_id].append(turn_record)
1233
+
1234
+ # json.dump(dialog_record, open(cfg.eval_load_path + '/resultdict.json','w'))
1235
+
1236
+
1237
+ if __name__ == "__main__":
1238
+ reader = MultiWozReader(GPT2Tokenizer)
1239
+ # for aspan in ["[general] [bye] [welcome] <eos_a>","[train] [inform] trainid destination \
1240
+ # arrive leave [offerbook] [general] [reqmore] <eos_a>",]:
1241
+ # act = reader.aspan_to_constraint_dict(aspan.split())
1242
+ # print('!!!')
1243
+ # print(act)
1244
+
1245
+ for bspan in [
1246
+ "[taxi] destination golden house departure broughton house gallery arrive 19:30 [attraction]"
1247
+ + " type museum name whipple museum of the history of science people 5 day monday",
1248
+ "[taxi] destination golden house departure broughton house gallery arrive 19:30 [attraction]"
1249
+ + " type museum name whipple museum of the history of science people 5 day monday <eos_b>",
1250
+ ]:
1251
+ encoded = reader.vocab.sentence_encode(bspan.split())
1252
+ print(encoded)
1253
+ cons = reader.bspan_to_constraint_dict(encoded, bspn_mode="bspn")
1254
+ print(cons)
1255
+ for bspan in [
1256
+ "[taxi] destination departure leave [hotel] name [attraction] name people day",
1257
+ "[taxi] destination departure leave [hotel] name [attraction] name people day <eos_b>",
1258
+ ]:
1259
+ encoded = reader.vocab.sentence_encode(bspan.split())
1260
+ print(encoded)
1261
+ cons = reader.bspan_to_constraint_dict(encoded, bspn_mode="bsdx")
1262
+ print(cons)
src/crazyneuraluser/UBAR_code/utils.py ADDED
@@ -0,0 +1,292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import logging
3
+ from collections import OrderedDict
4
+
5
+ import numpy as np
6
+
7
+ from crazyneuraluser.UBAR_code import ontology
8
+
9
+
10
+ def py2np(list):
11
+ return np.array(list)
12
+
13
+
14
+ def write_dict(fn, dic):
15
+ with open(fn, "w") as f:
16
+ json.dump(dic, f, indent=2)
17
+
18
+
19
+ def f1_score(label_list, pred_list):
20
+ tp = len([t for t in pred_list if t in label_list])
21
+ fp = max(0, len(pred_list) - tp)
22
+ fn = max(0, len(label_list) - tp)
23
+ precision = tp / (tp + fp + 1e-10)
24
+ recall = tp / (tp + fn + 1e-10)
25
+ f1 = 2 * precision * recall / (precision + recall + 1e-10)
26
+ return f1
27
+
28
+
29
+ class Vocab(object):
30
+ def __init__(self, vocab_size=0):
31
+ self.vocab_size = vocab_size
32
+ self.vocab_size_oov = 0 # get after construction
33
+ self._idx2word = {} # word + oov
34
+ self._word2idx = {} # word
35
+ self._freq_dict = {} # word + oov
36
+ for w in [
37
+ "<pad>",
38
+ "<go_r>",
39
+ "<unk>",
40
+ "<go_b>",
41
+ "<go_a>",
42
+ "<eos_u>",
43
+ "<eos_r>",
44
+ "<eos_b>",
45
+ "<eos_a>",
46
+ "<go_d>",
47
+ "<eos_d>",
48
+ ]:
49
+ self._absolute_add_word(w)
50
+
51
+ def _absolute_add_word(self, w):
52
+ idx = len(self._idx2word)
53
+ self._idx2word[idx] = w
54
+ self._word2idx[w] = idx
55
+
56
+ def add_word(self, word):
57
+ if word not in self._freq_dict:
58
+ self._freq_dict[word] = 0
59
+ self._freq_dict[word] += 1
60
+
61
+ def has_word(self, word):
62
+ return self._freq_dict.get(word)
63
+
64
+ def _add_to_vocab(self, word):
65
+ if word not in self._word2idx:
66
+ idx = len(self._idx2word)
67
+ self._idx2word[idx] = word
68
+ self._word2idx[word] = idx
69
+
70
+ def construct(self):
71
+ decoded = sorted(self._freq_dict.keys(), key=lambda x: -self._freq_dict[x])
72
+ print("Vocabulary size including oov: %d" % (len(decoded) + len(self._idx2word)))
73
+ if len(decoded) + len(self._idx2word) < self.vocab_size:
74
+ logging.warning(
75
+ "actual label set smaller than that configured: {}/{}".format(
76
+ len(decoded) + len(self._idx2word), self.vocab_size
77
+ )
78
+ )
79
+ for word in ontology.all_domains + ["general"]:
80
+ word = "[" + word + "]"
81
+ self._add_to_vocab(word)
82
+ for word in ontology.all_acts:
83
+ word = "[" + word + "]"
84
+ self._add_to_vocab(word)
85
+ for word in ontology.all_slots:
86
+ self._add_to_vocab(word)
87
+ for word in decoded:
88
+ if word.startswith("[value_") and word.endswith("]"):
89
+ self._add_to_vocab(word)
90
+ for word in decoded:
91
+ self._add_to_vocab(word)
92
+ self.vocab_size_oov = len(self._idx2word)
93
+
94
+ def load_vocab(self, vocab_path):
95
+ self._freq_dict = json.loads(open(vocab_path + ".freq.json", "r").read())
96
+ self._word2idx = json.loads(open(vocab_path + ".word2idx.json", "r").read())
97
+ self._idx2word = {}
98
+ for w, idx in self._word2idx.items():
99
+ self._idx2word[idx] = w
100
+ self.vocab_size_oov = len(self._idx2word)
101
+ print('vocab file loaded from "' + vocab_path + '"')
102
+ print("Vocabulary size including oov: %d" % (self.vocab_size_oov))
103
+
104
+ def save_vocab(self, vocab_path):
105
+ _freq_dict = OrderedDict(sorted(self._freq_dict.items(), key=lambda kv: kv[1], reverse=True))
106
+
107
+ write_dict(vocab_path + ".word2idx.json", self._word2idx)
108
+ write_dict(vocab_path + ".freq.json", _freq_dict)
109
+
110
+ def encode(self, word, include_oov=True):
111
+ if include_oov:
112
+ if self._word2idx.get(word, None) is None:
113
+ raise ValueError("Unknown word: %s. Vocabulary should include oovs here." % word)
114
+ return self._word2idx[word]
115
+ else:
116
+ word = "<unk>" if word not in self._word2idx else word
117
+ return self._word2idx[word]
118
+
119
+ def sentence_encode(self, word_list):
120
+ return [self.encode(_) for _ in word_list]
121
+
122
+ def oov_idx_map(self, idx):
123
+ return 2 if idx > self.vocab_size else idx
124
+
125
+ def sentence_oov_map(self, index_list):
126
+ return [self.oov_idx_map(_) for _ in index_list]
127
+
128
+ def decode(self, idx, indicate_oov=False):
129
+ if not self._idx2word.get(idx):
130
+ raise ValueError("Error idx: %d. Vocabulary should include oovs here." % idx)
131
+ if not indicate_oov or idx < self.vocab_size:
132
+ return self._idx2word[idx]
133
+ else:
134
+ return self._idx2word[idx] + "(o)"
135
+
136
+ def sentence_decode(self, index_list, eos=None, indicate_oov=False):
137
+ decoded = [self.decode(_, indicate_oov) for _ in index_list]
138
+ if not eos or eos not in decoded:
139
+ return " ".join(decoded)
140
+ else:
141
+ idx = decoded.index(eos)
142
+ return " ".join(decoded[:idx])
143
+
144
+ def nl_decode(self, decoded, eos=None):
145
+ return [self.sentence_decode(_, eos) + "\n" for _ in decoded]
146
+
147
+
148
+ def padSeqs_gpt(sequences, pad_id, maxlen=None):
149
+ lengths = []
150
+ for x in sequences:
151
+ lengths.append(len(x))
152
+
153
+ num_samples = len(sequences)
154
+ seq_mexlen = np.max(lengths)
155
+
156
+ # maxlen = 1024
157
+ if seq_mexlen > 1024: # gpt2.n_ctx
158
+ # print('maxlen exceeds 1024')
159
+ maxlen = 1024
160
+ else:
161
+ maxlen = seq_mexlen
162
+
163
+ # tokenizer.encode('<|endoftext|>') = ['50256']
164
+ # All labels set to ``-100`` are ignored (masked), the loss is only
165
+ # computed for labels in ``[0, ..., config.vocab_size]`` (from modeling_gpt2.GPT2LMHeadModel)
166
+
167
+ x = np.ones((num_samples, maxlen)) * pad_id
168
+ for idx, s in enumerate(sequences):
169
+ if not len(s):
170
+ print("empty list was found in padSeqs")
171
+ # trunc method = 'pre'
172
+ trunc = s[-maxlen:]
173
+ trunc = np.asarray(trunc)
174
+
175
+ # pad method = 'post'
176
+ x[idx, : len(trunc)] = trunc
177
+
178
+ return x, lengths
179
+
180
+
181
+ def padSeqs(
182
+ sequences,
183
+ maxlen=None,
184
+ truncated=False,
185
+ pad_method="post",
186
+ trunc_method="pre",
187
+ dtype="int32",
188
+ value=0.0,
189
+ ):
190
+ if not hasattr(sequences, "__len__"):
191
+ raise ValueError("`sequences` must be iterable.")
192
+ lengths = []
193
+ for x in sequences:
194
+ if not hasattr(x, "__len__"):
195
+ raise ValueError("`sequences` must be a list of iterables. " "Found non-iterable: " + str(x))
196
+ lengths.append(len(x))
197
+
198
+ num_samples = len(sequences)
199
+ seq_maxlen = np.max(lengths)
200
+
201
+ if maxlen is not None and truncated:
202
+ maxlen = min(seq_maxlen, maxlen)
203
+ else:
204
+ maxlen = seq_maxlen
205
+ # take the sample shape from the first non empty sequence
206
+ # checking for consistency in the main loop below.
207
+ sample_shape = tuple()
208
+ for s in sequences:
209
+ if len(s) > 0:
210
+ sample_shape = np.asarray(s).shape[1:]
211
+ break
212
+
213
+ x = (np.ones((num_samples, maxlen) + sample_shape) * value).astype(dtype)
214
+ for idx, s in enumerate(sequences):
215
+ if not len(s):
216
+ print("empty list/array was found")
217
+ continue # empty list/array was found
218
+ if trunc_method == "pre":
219
+ trunc = s[-maxlen:]
220
+ elif trunc_method == "post":
221
+ trunc = s[:maxlen]
222
+ else:
223
+ raise ValueError('Truncating type "%s" not understood' % trunc_method)
224
+
225
+ # check `trunc` has expected shape
226
+ trunc = np.asarray(trunc, dtype=dtype)
227
+ if trunc.shape[1:] != sample_shape:
228
+ raise ValueError(
229
+ "Shape of sample %s of sequence at position %s is different from expected shape %s"
230
+ % (trunc.shape[1:], idx, sample_shape)
231
+ )
232
+
233
+ if pad_method == "post":
234
+ x[idx, : len(trunc)] = trunc
235
+ elif pad_method == "pre":
236
+ x[idx, -len(trunc) :] = trunc
237
+ else:
238
+ raise ValueError('Padding type "%s" not understood' % pad_method)
239
+ return x
240
+
241
+
242
+ def get_glove_matrix(glove_path, vocab, initial_embedding_np):
243
+ """
244
+ return a glove embedding matrix
245
+ :param self:
246
+ :param glove_file:
247
+ :param initial_embedding_np:
248
+ :return: np array of [V,E]
249
+ """
250
+ ef = open(glove_path, "r", encoding="UTF-8")
251
+ cnt = 0
252
+ vec_array = initial_embedding_np
253
+ old_avg = np.average(vec_array)
254
+ old_std = np.std(vec_array)
255
+ vec_array = vec_array.astype(np.float32)
256
+ new_avg, new_std = 0, 0
257
+
258
+ for line in ef.readlines():
259
+ line = line.strip().split(" ")
260
+ word, vec = line[0], line[1:]
261
+ vec = np.array(vec, np.float32)
262
+ if not vocab.has_word(word):
263
+ continue
264
+ word_idx = vocab.encode(word)
265
+ if word_idx < vocab.vocab_size:
266
+ cnt += 1
267
+ vec_array[word_idx] = vec
268
+ new_avg += np.average(vec)
269
+ new_std += np.std(vec)
270
+ new_avg /= cnt
271
+ new_std /= cnt
272
+ ef.close()
273
+ logging.info(
274
+ "%d known embedding. old mean: %f new mean %f, old std %f new std %f"
275
+ % (cnt, old_avg, new_avg, old_std, new_std)
276
+ )
277
+ return vec_array
278
+
279
+
280
+ def position_encoding_init(self, n_position, d_pos_vec):
281
+ position_enc = np.array(
282
+ [
283
+ [pos / np.power(10000, 2 * (j // 2) / d_pos_vec) for j in range(d_pos_vec)]
284
+ if pos != 0
285
+ else np.zeros(d_pos_vec)
286
+ for pos in range(n_position)
287
+ ]
288
+ )
289
+
290
+ position_enc[1:, 0::2] = np.sin(position_enc[1:, 0::2]) # dim 2i
291
+ position_enc[1:, 1::2] = np.cos(position_enc[1:, 1::2]) # dim 2i+1
292
+ return position_enc
src/crazyneuraluser/user_model_code/analysis_multiwoz.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ DATA_SPLIT = ["train", "dev", "test"]
5
+
6
+
7
+ def _check_n_turns(data, data_act):
8
+ for split in DATA_SPLIT:
9
+ for dial_id, meta in data[split].items():
10
+ n_in_meta = len(meta["turns"])
11
+
12
+ assert dial_id in data_act
13
+ n_in_act = len(data_act[dial_id])
14
+ assert n_in_meta == n_in_act
15
+
16
+
17
+ def collect_data(data_path, remove_dial_switch=False):
18
+ # load act
19
+ act_file = os.path.join(data_path, "dialog_acts.json")
20
+ with open(act_file) as f:
21
+ data_act = json.load(f)
22
+ print("Load {} dialogues in act file".format(len(data_act)))
23
+
24
+ # load data
25
+ data = {}
26
+ for split in DATA_SPLIT:
27
+ data[split] = iter_data_folder(data_path, split, remove_dial_switch, data_act)
28
+
29
+ _check_n_turns(data, data_act)
30
+ return data, data_act
31
+
32
+
33
+ def remove_dial(dial_id, dial, dial_act):
34
+ # check services
35
+ services = dial["services"]
36
+ if "police" in services or "bus" in services or "hospital" in services:
37
+ return True
38
+
39
+ # check act
40
+ domains = set()
41
+ for turn_id, turn_act in dial_act.items():
42
+ dialogue_act = turn_act["dialog_act"]
43
+ for dact in dialogue_act:
44
+ assert "-" in dact
45
+ domain, act = dact.split("-")
46
+ domains.add(domain)
47
+ if "Police" in domains or "Bus" in domains or "Hospital" in domains:
48
+ return True
49
+ return False
50
+
51
+
52
+ def iter_data_folder(data_path, split, remove_dial_switch, data_act):
53
+ """Iterate data folder"""
54
+ split_dir = os.path.join(data_path, split)
55
+ data_split = {}
56
+ remove_dial_ids = []
57
+ total_dial_ids = []
58
+ for f in os.listdir(split_dir):
59
+ if not f.startswith("dialogues"): # skip schema.json
60
+ continue
61
+ file_path = os.path.join(data_path, split, f)
62
+ iter_file(
63
+ file_path,
64
+ data_split,
65
+ remove_dial_ids,
66
+ total_dial_ids,
67
+ remove_dial_switch,
68
+ data_act,
69
+ )
70
+ print(
71
+ "Done collecting {} | total {} dialogues | load {} dialogues | remove {} dialogues".format(
72
+ split, len(total_dial_ids), len(data_split), len(remove_dial_ids)
73
+ )
74
+ )
75
+ return data_split
76
+
77
+
78
+ def iter_file(
79
+ file_path, data_split, remove_dial_ids, total_dial_ids, remove_dial_switch, data_act
80
+ ):
81
+ with open(file_path) as f:
82
+ data_in = json.load(f) # list of dialouges in a json file
83
+
84
+ for dial in data_in:
85
+ dial_id = dial["dialogue_id"]
86
+ total_dial_ids.append(dial_id)
87
+ dial_act = data_act[dial_id]
88
+
89
+ if remove_dial_switch and remove_dial(dial_id, dial, dial_act):
90
+ remove_dial_ids.append(dial_id)
91
+ else:
92
+ data_split[dial_id] = dial
93
+
94
+
95
+ def show_dial(dial_id, data, data_act):
96
+ def simple_linearise_act(dialouge_act):
97
+ linear_act = ""
98
+ for domain_act, slot_value_list in dialouge_act.items():
99
+ linear_act += domain_act + " "
100
+ for slot_value in slot_value_list:
101
+ slot, value = slot_value[0], slot_value[1]
102
+ linear_act += slot + " "
103
+ linear_act += value + " "
104
+ return linear_act
105
+
106
+ split = None
107
+ for data_split in DATA_SPLIT:
108
+ if dial_id in data[data_split]:
109
+ split = data_split
110
+ break
111
+
112
+ print("dial_id: {}".format(dial_id))
113
+ for turn_id, turn in enumerate(data[split][dial_id]["turns"]):
114
+ dialouge_act = data_act[dial_id][str(turn_id)]["dialog_act"]
115
+ linear_act = simple_linearise_act(dialouge_act)
116
+ print("-----" * 15)
117
+ print("turn_id: {}, spk: {}".format(turn_id, turn["speaker"]))
118
+ print("act: |{}|".format(linear_act))
119
+ print("utt: |{}|".format(turn["utterance"]))
src/crazyneuraluser/user_model_code/analysis_sgd.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ from utils_sgd import (
5
+ bcolors,
6
+ compare_slot_values_in_state,
7
+ dict2str,
8
+ get_turn_act,
9
+ list2str,
10
+ )
11
+
12
+ """ This file contains some utilities for analysis and parsing SGD """
13
+
14
+ DATA_SPLIT = ["train", "dev", "test"]
15
+
16
+
17
+ def collect_data(data_path, remove_dial_switch=False):
18
+ data = {}
19
+ for split in DATA_SPLIT:
20
+ data[split] = iter_data_folder(data_path, split, remove_dial_switch)
21
+ return data
22
+
23
+
24
+ def _remove_dial(dial_id, dial):
25
+ # remove_flag = False
26
+ # removes service `Homes_2` in test set as the slot `intent` is the same name as the user intent,
27
+ # which causes problem in goal preparation
28
+ if "Homes_2" in dial["services"]:
29
+ return True
30
+ return False
31
+
32
+
33
+ def iter_data_folder(data_path, split, remove_dial_switch):
34
+ """Iterate data split folder"""
35
+ split_dir = os.path.join(data_path, split)
36
+ data_split = {}
37
+ remove_dial_ids = []
38
+ total_dial_ids = []
39
+ for f in os.listdir(split_dir):
40
+ if not f.startswith("dialogues"): # skip schema.json
41
+ continue
42
+ file_path = os.path.join(data_path, split, f)
43
+ iter_file(
44
+ file_path, data_split, remove_dial_ids, total_dial_ids, remove_dial_switch
45
+ )
46
+ print(
47
+ "Done collecting {} | total {} dialogues | load {} dialogues | remove {} dialogues".format(
48
+ split, len(total_dial_ids), len(data_split), len(remove_dial_ids)
49
+ )
50
+ )
51
+ return data_split
52
+
53
+
54
+ def iter_file(
55
+ file_path, data_split, remove_dial_ids, total_dial_ids, remove_dial_switch
56
+ ):
57
+ """Iterate data file"""
58
+ with open(file_path) as f:
59
+ data_in = json.load(f) # list of dialouges in a json file
60
+
61
+ for dial in data_in:
62
+ dial_id = dial["dialogue_id"]
63
+ total_dial_ids.append(dial_id)
64
+
65
+ if remove_dial_switch and _remove_dial(dial_id, dial):
66
+ remove_dial_ids.append(dial_id)
67
+ else:
68
+ data_split[dial_id] = dial
69
+
70
+
71
+ def check_multiple_services_per_turn(data):
72
+ for split in DATA_SPLIT:
73
+ for dial_id in sorted(data[split].keys()):
74
+ dial = data[split][dial_id]
75
+ for turn_id, turn in enumerate(dial["turns"]):
76
+ frames = turn["frames"]
77
+ if len(frames) > 1:
78
+ print(split, dial_id, turn_id, turn["utterance"])
79
+
80
+
81
+ def show_actions(actions):
82
+ for action_id, action in enumerate(actions):
83
+ act, slot, values = action["act"], action["slot"], action["values"]
84
+ print(
85
+ f"====> ACTION | Act {action_id}: {bcolors.RED}{act}{bcolors.ENDC}, \
86
+ slot: {bcolors.YELLOW}{slot}{bcolors.ENDC}, values: {bcolors.GREEN}{values}{bcolors.ENDC}"
87
+ )
88
+
89
+
90
+ def show_user_state(frame):
91
+ state = frame["state"]
92
+ active_intent = state["active_intent"]
93
+ req_slots = list2str(state["requested_slots"])
94
+ slot2value = dict2str(state["slot_values"], colored=True)
95
+ print(
96
+ "====> STATE | intent: {}, req_slots: {}, slot2value: {}".format(
97
+ active_intent, req_slots, slot2value
98
+ )
99
+ )
100
+
101
+
102
+ def show_service_call(frame):
103
+ if "service_call" not in frame:
104
+ return
105
+ # system calls api
106
+ service_call, service_results = frame["service_call"], frame["service_results"]
107
+ print(
108
+ "====> API call | method: {}, args: {}, results: {}".format(
109
+ service_call["method"],
110
+ dict2str(service_call["parameters"]),
111
+ len(service_results),
112
+ )
113
+ )
114
+
115
+
116
+ def show_frame(spk, frame_id, frame):
117
+ service = frame["service"]
118
+ print("==> Frame_id: {}, service: {}".format(frame_id, service))
119
+
120
+ # actions (include all slots)
121
+ show_actions(frame["actions"])
122
+
123
+ # slots (only provide non-categorical slots with word span boundaries)
124
+ if spk == "USER":
125
+ show_user_state(frame)
126
+ else: # system
127
+ show_service_call(frame)
128
+
129
+
130
+ def show_turn(turn_id, turn):
131
+ if turn is None:
132
+ return
133
+
134
+ frames = turn["frames"]
135
+ spk = turn["speaker"]
136
+ utt = turn["utterance"]
137
+ assert spk in ["USER", "SYSTEM"]
138
+ print(f"{spk}: {bcolors.UNDERLINE}{utt}{bcolors.ENDC}")
139
+ for frame_id, frame in enumerate(frames):
140
+ show_frame(spk, frame_id, frame)
141
+ print("------" * 15)
142
+
143
+
144
+ def show_dial_info(dial_id, dial):
145
+ print("\n")
146
+ print("******" * 15)
147
+ print("Dialogue={} | Service={}".format(dial_id, list2str(dial["services"])))
148
+ print("******" * 15)
149
+
150
+
151
+ def show_dial(dial_id, dial):
152
+ show_dial_info(dial_id, dial)
153
+ for turn_id, turn in enumerate(dial["turns"]):
154
+ show_turn(turn_id, turn)
155
+
156
+
157
+ def show_data(data):
158
+ for split in DATA_SPLIT:
159
+ for dial_id in sorted(data[split].keys()):
160
+ dial = data[split][dial_id]
161
+ show_dial(dial_id, dial)
162
+ input("press...")
163
+
164
+
165
+ def identify_scenarios(data):
166
+ """
167
+ According to dataset paper, a scenario is a sequence of intents, seeded at the start of a conversation
168
+ to the user agent
169
+ """
170
+ # TODO: deal with NONE intent, check the # of intent seq conbinations
171
+ for split in DATA_SPLIT:
172
+ scenario2dialogues = {}
173
+ n_scenario_max, n_scenario_min = 0, 100
174
+ for dial_id in sorted(data[split].keys()):
175
+ dial = data[split][dial_id]
176
+ scenario = []
177
+ for turn in dial["turns"]:
178
+ if turn["speaker"] == "SYSTEM":
179
+ continue
180
+ # USER turn
181
+ # it's fine to consider only first frame (service) if the turn is at the bounrary between two services
182
+ frame = turn["frames"][0]
183
+ intent = frame["state"]["active_intent"]
184
+ if intent == "NONE":
185
+ continue
186
+ if len(scenario) == 0 or intent != scenario[-1]:
187
+ scenario.append(intent)
188
+
189
+ # update count
190
+ if len(scenario) > n_scenario_max:
191
+ n_scenario_max = len(scenario)
192
+ if len(scenario) < n_scenario_min:
193
+ n_scenario_min = len(scenario)
194
+
195
+ scenario = list2str(scenario)
196
+ if scenario not in scenario2dialogues:
197
+ scenario2dialogues[scenario] = []
198
+ scenario2dialogues[scenario].append(dial_id)
199
+
200
+ # done iter over split
201
+ print(
202
+ "Summary: split={}, unique_scenario={}, max_intent={}, min_intent={}".format(
203
+ split, len(scenario2dialogues), n_scenario_max, n_scenario_min
204
+ )
205
+ )
206
+
207
+
208
+ def _check_request_alts_type(prev_turn, sys_turn, curr_turn, curr_acts):
209
+ """
210
+ check which of the following happens when request_alts
211
+ 1. randomly change goal (state changes)
212
+ 2. request_alts as system provides venue with missing slot-value (usr provides new info)
213
+ 3. simply dislike the provided venue, change venue without new slot-value (same info)
214
+
215
+ Input:
216
+ prev_turn: previous user turn
217
+ curr_turn: current user turn
218
+ """
219
+
220
+ def _get_intent2state(turn):
221
+ intent2state = {}
222
+ for frame in turn["frames"]:
223
+ state = frame["state"]
224
+ intent = state["active_intent"]
225
+ intent2state[intent] = state
226
+ return intent2state
227
+
228
+ assert "REQUEST_ALTS" in curr_acts
229
+ if len(curr_acts) == 1: # case 3
230
+ # return "_dislike_"
231
+ if "OFFER" in get_turn_act(sys_turn):
232
+ return "_dislike_offer_"
233
+ else:
234
+ return "_dislike_info_"
235
+ elif (
236
+ "INFORM" in curr_acts and len(set(curr_acts)) == 2
237
+ ): # only inform and request_alts
238
+ assert len(curr_turn["frames"]) == 1
239
+ curr_slot_values = curr_turn["frames"][0]["state"]["slot_values"]
240
+ curr_intent = curr_turn["frames"][0]["state"]["active_intent"]
241
+
242
+ if len(prev_turn["frames"]) == 1:
243
+ prev_slot_values = prev_turn["frames"][0]["state"]["slot_values"]
244
+ else: # need to get the state with the same intent
245
+ intent2state = _get_intent2state(prev_turn)
246
+ prev_slot_values = intent2state[curr_intent]["slot_values"]
247
+
248
+ state_diff = compare_slot_values_in_state(prev_slot_values, curr_slot_values)
249
+ if state_diff: # case 1
250
+ return "_random_"
251
+ else: # case 2
252
+ return "_miss_"
253
+ else:
254
+ return "_unknown_"
255
+
256
+
257
+ def stats_request_alts_type(data):
258
+ for split in DATA_SPLIT:
259
+ stats = {
260
+ "_random_": 0,
261
+ "_miss_": 0,
262
+ "_dislike_offer_": 0,
263
+ "_dislike_info_": 0,
264
+ "_unknown_": 0,
265
+ }
266
+ n_all_usr_turn, n_request_alts = 0, 0
267
+
268
+ for dial_id in sorted(data[split].keys()):
269
+ dial = data[split][dial_id]
270
+ for turn_id, turn in enumerate(dial["turns"]):
271
+ prev_turn = turn
272
+ if turn["speaker"] == "SYSTEM":
273
+ sys_turn = turn
274
+ continue
275
+ acts = get_turn_act(turn)
276
+ if "REQUEST_ALTS" in acts:
277
+ n_request_alts += 1
278
+ type_result = _check_request_alts_type(
279
+ prev_turn, sys_turn, turn, acts
280
+ )
281
+ stats[type_result] += 1
282
+ if type_result == "_random_":
283
+ print("CASE {}".format(type_result))
284
+ show_turn(0, prev_turn)
285
+ show_turn(0, sys_turn)
286
+ show_turn(0, turn)
287
+ input("press...")
288
+ n_all_usr_turn += 1
289
+ prev_turn = turn
290
+
291
+ print("REQUEST_ALTS type statistics")
292
+ for k, v in stats.items():
293
+ print("{} => {}".format(k, v))
294
+ print(
295
+ "request_alts turns: {}, all usr turns: {}, dialogues: {}".format(
296
+ n_request_alts, n_all_usr_turn, len(data[split])
297
+ )
298
+ )
299
+
300
+
301
+ def show_utt_by_act(data):
302
+ target_act = "OFFER"
303
+ for split in DATA_SPLIT:
304
+ for dial_id in sorted(data[split].keys()):
305
+ dial = data[split][dial_id]
306
+ match_flag = False
307
+ for turn_id, turn in enumerate(dial["turns"]):
308
+ acts = get_turn_act(turn)
309
+ if target_act in acts:
310
+ match_flag = True
311
+ if match_flag:
312
+ show_dial(dial_id, dial)
313
+ input("press...")
314
+
315
+
316
+ def show_state_with_value_change(data):
317
+ for split in DATA_SPLIT:
318
+ for dial_id in sorted(data[split].keys()):
319
+ dial = data[split][dial_id]
320
+ intent2slot_values = {}
321
+ for turn_id, turn in enumerate(dial["turns"]):
322
+ utt, spk = turn["utterance"], turn["speaker"]
323
+ if spk != "USER":
324
+ prev_system_turn = turn
325
+ continue
326
+ for frame in turn["frames"]:
327
+ state = frame["state"]
328
+ active_intent = state["active_intent"]
329
+ slot_values = state["slot_values"]
330
+ if active_intent in intent2slot_values:
331
+ state_diff = compare_slot_values_in_state(
332
+ intent2slot_values[active_intent], slot_values
333
+ )
334
+ if state_diff:
335
+ print(
336
+ "Dial: {}, state change: {}".format(dial_id, state_diff)
337
+ )
338
+ print(
339
+ "==> Prev SYS: {}".format(prev_system_turn["utterance"])
340
+ )
341
+ for sys_frame in prev_system_turn["frames"]:
342
+ show_actions(sys_frame["actions"])
343
+ print("==> Curr USR: {}".format(utt))
344
+ show_actions(frame["actions"])
345
+ print(
346
+ "recorded state => intent: {}, slot2value: {}".format(
347
+ active_intent,
348
+ dict2str(intent2slot_values[active_intent]),
349
+ )
350
+ )
351
+ print(
352
+ "current state => intent: {}, slot2value: {}".format(
353
+ active_intent, dict2str(slot_values)
354
+ )
355
+ )
356
+ input("press...")
357
+ intent2slot_values[
358
+ active_intent
359
+ ] = slot_values # overlap with new state, no matter values changed or not
360
+
361
+
362
+ def check_state_with_value_change(data, display=False):
363
+ for split in DATA_SPLIT:
364
+ n_diff = {"NOTIFY_FAILURE": 0, "NEGATE": 0, "REQUEST_ALTS": 0, "RANDOM": 0}
365
+ for dial_id in sorted(data[split].keys()):
366
+ dial = data[split][dial_id]
367
+ intent2slot_values = {}
368
+ diff_flag = False
369
+ for turn_id, turn in enumerate(dial["turns"]):
370
+ if diff_flag:
371
+ break
372
+ utt, spk = turn["utterance"], turn["speaker"]
373
+ if spk != "USER":
374
+ prev_system_turn = turn
375
+ continue
376
+ for frame in turn["frames"]:
377
+ state = frame["state"]
378
+ active_intent = state["active_intent"]
379
+ slot_values = state["slot_values"]
380
+ if active_intent in intent2slot_values:
381
+ state_diff = compare_slot_values_in_state(
382
+ intent2slot_values[active_intent], slot_values
383
+ )
384
+ if state_diff:
385
+ usr_acts = get_turn_act(turn)
386
+ if "NOTIFY_FAILURE" in get_turn_act(prev_system_turn):
387
+ if display:
388
+ print("FAILURE", dial_id, utt)
389
+ n_diff["NOTIFY_FAILURE"] += 1
390
+ elif "NEGATE" in usr_acts:
391
+ if display:
392
+ print("NEGATE", dial_id, utt)
393
+ n_diff["NEGATE"] += 1
394
+ elif "REQUEST_ALTS" in usr_acts:
395
+ if display:
396
+ print("REQUEST_ALTS", dial_id, utt)
397
+ n_diff["REQUEST_ALTS"] += 1
398
+ else:
399
+ if display:
400
+ print("RANDOM", dial_id, utt)
401
+ n_diff["RANDOM"] += 1
402
+ if display:
403
+ input("press...")
404
+ # n_diff += 1
405
+ diff_flag = True
406
+ intent2slot_values[
407
+ active_intent
408
+ ] = slot_values # overlap with new state, no matter values changed or not
409
+ n = (
410
+ n_diff["NOTIFY_FAILURE"]
411
+ + n_diff["NEGATE"]
412
+ + n_diff["REQUEST_ALTS"]
413
+ + n_diff["RANDOM"]
414
+ )
415
+ print(
416
+ "{} => total dials: {}, change goal dials: {} (total: {})".format(
417
+ split, len(data[split]), dict2str(n_diff), n
418
+ )
419
+ )
420
+
421
+
422
+ def stats_after_system(data):
423
+ """
424
+ check the possible user behavior right after system offers/notify_failure
425
+ """
426
+ n = 0
427
+ stats = {
428
+ "SELECT": 0,
429
+ "REQUEST_ALTS": 0,
430
+ "REQUEST": 0,
431
+ "AFFIRM": 0,
432
+ "unknown": 0,
433
+ } # if system offers
434
+ # stats = {"INFORM": 0, "AFFIRM": 0, "NEGATE": 0, "unknown": 0} # if system notify_failure
435
+ for split in DATA_SPLIT:
436
+ for dial_id in sorted(data[split].keys()):
437
+ dial = data[split][dial_id]
438
+ for turn_id, turn in enumerate(dial["turns"]):
439
+ if turn_id == 0:
440
+ prev_turn = turn
441
+ continue
442
+ if turn["speaker"] == "SYSTEM":
443
+ sys_turn = turn
444
+ continue
445
+
446
+ if "OFFER" in get_turn_act(sys_turn):
447
+ # if "OFFER" in get_turn_act(sys_turn) and "NOTIFY_FAILURE" in get_turn_act(sys_turn):
448
+ # if "NOTIFY_FAILURE" in get_turn_act(sys_turn):
449
+ n += 1
450
+ acts = get_turn_act(turn)
451
+ # OFFER
452
+ if "SELECT" in acts:
453
+ stats["SELECT"] += 1
454
+ elif "REQUEST_ALTS" in acts:
455
+ stats["REQUEST_ALTS"] += 1
456
+ elif "REQUEST" in acts:
457
+ stats["REQUEST"] += 1
458
+ elif (
459
+ "AFFIRM" in acts
460
+ ): # cases fall into here are SYS_ACT: ["OFFER", "NOTIFY_FAILURE"], and USR_ACT: ["AFFIRM"],
461
+ # e.g., accept new proposal
462
+ show_turn(0, prev_turn)
463
+ show_turn(0, sys_turn)
464
+ show_turn(0, turn)
465
+ input("press...")
466
+ stats["AFFIRM"] += 1
467
+ else:
468
+ stats["unknown"] += 1
469
+
470
+ # NOTIFY_FAILURE
471
+ # if "INFORM" in acts:
472
+ # stats["INFORM"] += 1
473
+ # elif "AFFIRM" in acts:
474
+ # stats["AFFIRM"] += 1
475
+ # elif "NEGATE" in acts:
476
+ # stats["NEGATE"] += 1
477
+ # else:
478
+ # stats["unknown"] += 1
479
+
480
+ prev_turn = turn
481
+ for k, v in stats.items():
482
+ print("{} -> {}".format(k, v))
483
+ print("Total offer turns: {}".format(n))
src/crazyneuraluser/user_model_code/argument.py ADDED
@@ -0,0 +1,153 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+
4
+ def str2bool(v):
5
+ if v.lower() in ("yes", "true", "t", "y", "1"):
6
+ return True
7
+ elif v.lower() in ("no", "false", "f", "n", "0"):
8
+ return False
9
+ else:
10
+ raise argparse.ArgumentTypeError("Boolean value expected.")
11
+
12
+
13
+ def verify_args(args):
14
+ # datasets
15
+ assert isinstance(args.data_list, list) and len(args.data_list) > 0
16
+ for data_name in args.data_list:
17
+ assert data_name in ["sgd", "multiwoz"]
18
+
19
+ # mode
20
+ assert args.mode in ["training", "finetune", "testing", "interact"]
21
+ if args.mode == "finetune":
22
+ assert args.pre_checkpoint != ""
23
+
24
+
25
+ def get_args():
26
+ parser = argparse.ArgumentParser(description="")
27
+ # logging
28
+ parser.add_argument("--wandb_train_run_name", type=str, default="Default name")
29
+ # data
30
+ parser.add_argument(
31
+ "--data_dir",
32
+ type=str,
33
+ default="proc_data",
34
+ help="Directory of processed datasets",
35
+ )
36
+ parser.add_argument(
37
+ "--data_list",
38
+ type=str,
39
+ nargs="+",
40
+ default="",
41
+ help="Datasets involved, split by space, e.g., `sgd multiwoz`",
42
+ )
43
+
44
+ # design control
45
+ parser.add_argument(
46
+ "--use_ra_flag",
47
+ type=str2bool,
48
+ default=True,
49
+ help="Whether to use `request_alternatives` flag",
50
+ )
51
+
52
+ # training
53
+ parser.add_argument("--mode", type=str, required=True, help="")
54
+ parser.add_argument("--seed", type=int, default=1122)
55
+ parser.add_argument(
56
+ "--model_name", type=str, required=True, help="Unique name, e.g., job id"
57
+ )
58
+ parser.add_argument("--model_name_or_path", type=str, default="gpt2")
59
+ parser.add_argument(
60
+ "--train_batch_size", type=int, default=4, help="Batch size of training per gpu"
61
+ )
62
+ parser.add_argument(
63
+ "--eval_batch_size",
64
+ type=int,
65
+ default=1,
66
+ help="Batch size of evaluation per gpu",
67
+ ) # TODO: make decoding parallel
68
+ parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
69
+ parser.add_argument("--learning_rate", type=float, default=6.25e-5) # tune
70
+ parser.add_argument("--adam_epsilon", type=float, default=1e-12)
71
+ parser.add_argument("--max_grad_norm", type=float, default=1.0)
72
+ parser.add_argument("--max_epoch", type=int, default=20)
73
+ parser.add_argument(
74
+ "--fp16", type=str2bool, default=False, help="Whether to use float16"
75
+ )
76
+ parser.add_argument(
77
+ "--use_scheduler",
78
+ type=str2bool,
79
+ default=True,
80
+ help="Whether to use lr scheduler",
81
+ )
82
+ parser.add_argument("--warmup_steps", type=int, default=0)
83
+ parser.add_argument(
84
+ "--checkpoint",
85
+ type=str,
86
+ default="",
87
+ required=True,
88
+ help="Path of your trained model",
89
+ )
90
+ parser.add_argument(
91
+ "--pre_checkpoint",
92
+ type=str,
93
+ default="",
94
+ help="Path of the pretrained model used for finetuning",
95
+ )
96
+ parser.add_argument(
97
+ "--train_size",
98
+ type=int,
99
+ default=-1,
100
+ help="How many examples used for training. -1 means all data",
101
+ )
102
+ parser.add_argument(
103
+ "--eval_size",
104
+ type=int,
105
+ default=-1,
106
+ help="How many examples used for evaluation. -1 means all data",
107
+ )
108
+ parser.add_argument(
109
+ "--eval_interval",
110
+ type=int,
111
+ default=1000,
112
+ help="During training, how frequent to evaluate the model in terms of training examples",
113
+ )
114
+ parser.add_argument(
115
+ "--no_improve_max",
116
+ type=int,
117
+ default=100,
118
+ help="The max tolerance for model not improving",
119
+ )
120
+ parser.add_argument("--eps", type=float, default=1e-12)
121
+ parser.add_argument(
122
+ "--disable_display", type=str2bool, default=False, help="display progress bar"
123
+ )
124
+
125
+ # decoding
126
+ # parser.add_argument('--step', type=int, default=-1) # load model trained at which specific step
127
+ parser.add_argument(
128
+ "--dec_max_len", type=int, default=2000
129
+ ) # we use early stop to stop generation when hits <EOS>
130
+ parser.add_argument("--num_beams", type=int, default=1)
131
+ parser.add_argument("--temperature", type=float, default=1.0)
132
+ # parser.add_argument('--top_k', type=int, default=0)
133
+ # parser.add_argument('--top_p', type=int, default=0)
134
+ parser.add_argument("--decode_file", type=str, default="")
135
+ parser.add_argument(
136
+ "--eye_browse_output",
137
+ type=str2bool,
138
+ default=False,
139
+ help="Whether to eye browse decoded results",
140
+ )
141
+
142
+ # ddp
143
+ parser.add_argument(
144
+ "--local_rank",
145
+ type=int,
146
+ default=-1,
147
+ help="Local rank for distributed training (-1: not distributed)",
148
+ )
149
+
150
+ args = parser.parse_args()
151
+ verify_args(args)
152
+ print(args)
153
+ return args
src/crazyneuraluser/user_model_code/dataset.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+
4
+ import torch
5
+ from tqdm import tqdm
6
+
7
+ from crazyneuraluser.user_model_code.utils_sgd import (
8
+ add_str,
9
+ get_special_tokens,
10
+ wrap_element,
11
+ )
12
+
13
+
14
+ class SGD_Dataset(torch.utils.data.Dataset):
15
+ def __init__(self, args, tokenizer, data_split, generation, data_size):
16
+ assert data_split in ["train", "dev", "test", "demo"]
17
+ self.args = args
18
+ self.data_size = data_size
19
+ self.tokenizer = tokenizer
20
+ self.data_split = data_split
21
+ self.generation = generation
22
+ self.n_trimmed = 0
23
+
24
+ self.SPECIAL_TOKENS = get_special_tokens()
25
+ self._get_special_token_ids()
26
+
27
+ # create examples
28
+ self.examples = []
29
+ for data_name in args.data_list:
30
+ examples = self._create_examples(data_name, data_split)
31
+ self.examples += examples
32
+ print("Total ({}) -> {} examples".format(data_split, len(self.examples)))
33
+
34
+ def _get_special_token_ids(self):
35
+ self.bos_id = self.tokenizer.convert_tokens_to_ids(
36
+ self.SPECIAL_TOKENS["bos_token"]
37
+ )
38
+ self.eos_id = self.tokenizer.convert_tokens_to_ids(
39
+ self.SPECIAL_TOKENS["eos_token"]
40
+ )
41
+ self.pad_id = self.tokenizer.convert_tokens_to_ids(
42
+ self.SPECIAL_TOKENS["pad_token"]
43
+ )
44
+ self.sep_id = self.tokenizer.convert_tokens_to_ids(
45
+ self.SPECIAL_TOKENS["sep_token"]
46
+ )
47
+ # print('SPECIAL TOKEN MAPPING:')
48
+ # print('bos:{} | eos:{} | pad:{} | sep:{}'.format(self.bos_id, self.eos_id, self.pad_id, self.sep_id))
49
+
50
+ self.add_special_token_ids = {}
51
+ for token in self.SPECIAL_TOKENS["additional_special_tokens"]:
52
+ self.add_special_token_ids[token] = self.tokenizer.convert_tokens_to_ids(
53
+ token
54
+ )
55
+
56
+ self.true_token, self.false_token = "_True_", "_False_"
57
+ assert self.true_token in self.SPECIAL_TOKENS["additional_special_tokens"]
58
+ assert self.false_token in self.SPECIAL_TOKENS["additional_special_tokens"]
59
+ """
60
+ if using BPE (default method, simply call tokenizer(natural sentence)), no need unk_token
61
+ if using convert_tokens_to_ids, check which is correct way to handle oov:
62
+ a) simply use <endoftext> as unk_token (default setup) or
63
+ b) add unk_token into special tokens
64
+ """
65
+
66
+ def _create_examples(self, data_name, data_split):
67
+ data_file = os.path.join(
68
+ self.args.data_dir, data_name, "{}.json".format(data_split)
69
+ )
70
+ with open(data_file) as f:
71
+ data = json.load(f)
72
+
73
+ examples = []
74
+ for dial_id in tqdm(sorted(data.keys())):
75
+ if self.data_size != -1 and len(examples) >= self.data_size:
76
+ break
77
+ dial_meta = data[dial_id]
78
+ context = ""
79
+ for i in range(100):
80
+ example_id = "{}-{}".format(dial_id, i)
81
+ self.example_id = example_id
82
+ if example_id not in dial_meta:
83
+ break
84
+
85
+ # testing #
86
+ # # SGD
87
+ # if data_split == "test" and dial_id not in ["10_00056", "10_00075"]: # seen, movie domain
88
+ # if data_split == "test" and dial_id not in ["16_00040"]: # seen
89
+ # if data_split == "test" and dial_id not in ["8_00066", "16_00095", "8_00065"]: # unseen
90
+ # if data_split == "test" and dial_id not in ["9_00121", "9_00122"]:
91
+ # # req_alts cases w/i, w/o inform
92
+ # continue
93
+ # # mwoz
94
+ # if data_split == "test" and dial_id not in ["MUL0071.json"]:
95
+ # # test predictions in no offer & no book
96
+ # continue
97
+
98
+ # turn info
99
+ goal = dial_meta[example_id]["goal"]
100
+ # service = dial_meta[example_id]["service"]
101
+ # intent = dial_meta[example_id]["intent"]
102
+
103
+ # utterances
104
+ usr_utt = dial_meta[example_id]["utterances"]["usr"]
105
+ sys_utt = dial_meta[example_id]["utterances"]["sys"]
106
+
107
+ # actions
108
+ usr_act = dial_meta[example_id]["actions"]["usr"]
109
+ sys_act = dial_meta[example_id]["actions"]["sys"]
110
+
111
+ # binary flags
112
+ snt = dial_meta[example_id]["start_new_task"]
113
+ gc = dial_meta[example_id]["goal_change"]
114
+ ra = dial_meta[example_id]["req_alts"]
115
+
116
+ # get input ids
117
+ (
118
+ input_seq,
119
+ input_ids,
120
+ label_ids,
121
+ valid_example,
122
+ ) = self._prepare_input_ids(
123
+ goal, context, usr_utt, usr_act, sys_utt, sys_act, snt, gc, ra
124
+ )
125
+
126
+ if valid_example:
127
+ assert len(input_ids) < 1024
128
+ dial_meta[example_id]["context"] = context
129
+ examples.append(
130
+ {
131
+ "input_ids": input_ids, # list of ids
132
+ "label_ids": label_ids, # list of ids
133
+ "metadata": dial_meta[example_id],
134
+ "example_id": self.example_id,
135
+ "data_name": data_name,
136
+ }
137
+ )
138
+
139
+ # collect context
140
+ sys_utt_wrap = wrap_element("SYS", sys_utt)
141
+ usr_utt_wrap = wrap_element("USR", usr_utt)
142
+ context = add_str(context, sys_utt_wrap)
143
+ context = add_str(context, usr_utt_wrap)
144
+
145
+ print(
146
+ "Data Stat: {} ({}) -> {} examples ({} examples are trimmed)".format(
147
+ data_name, self.data_split, len(examples), self.n_trimmed
148
+ )
149
+ )
150
+ return examples
151
+
152
+ def _prepare_input_ids(
153
+ self, goal, context, usr_utt, usr_act, sys_utt, sys_act, snt, gc, ra
154
+ ):
155
+ """
156
+ prepare input sequence ids to GPT2
157
+ template: <CTX> <SYS_UTT> <SYS_ACT> <SNT> <RA> <GC> <GOAL> <USR_ACT> <USR_UTT>
158
+ """
159
+ goal_wrap = wrap_element("GOAL", goal)
160
+ context_wrap = wrap_element("CTX", context)
161
+ usr_utt_wrap = wrap_element("USR_UTT", usr_utt)
162
+ usr_act_wrap = wrap_element("USR_ACT", usr_act)
163
+ sys_utt_wrap = wrap_element("SYS_UTT", sys_utt)
164
+ sys_act_wrap = wrap_element("SYS_ACT", sys_act)
165
+
166
+ snt = self.true_token if snt else self.false_token # `Start New Task` flag
167
+ snt_wrap = wrap_element("SNT", snt)
168
+ gc = self.true_token if gc else self.false_token # `Goal Change` flag
169
+ gc_wrap = wrap_element("GC", gc)
170
+ ra = self.true_token if ra else self.false_token # `Request Alternatives` flag
171
+ ra_wrap = wrap_element("RA", ra)
172
+ if self.args.use_ra_flag:
173
+ flags_wrap = snt_wrap + " " + ra_wrap + " " + gc_wrap
174
+ else:
175
+ flags_wrap = snt_wrap + " " + gc_wrap
176
+
177
+ if not self.generation: # supervised
178
+ input_seq = (
179
+ context_wrap
180
+ + " "
181
+ + sys_utt_wrap
182
+ + " "
183
+ + sys_act_wrap
184
+ + " "
185
+ + flags_wrap
186
+ + " "
187
+ + goal_wrap
188
+ + " "
189
+ + usr_act_wrap
190
+ + " "
191
+ + usr_utt_wrap
192
+ + " "
193
+ + self.SPECIAL_TOKENS["eos_token"]
194
+ )
195
+ input_ids = self.tokenizer(input_seq)["input_ids"] # convert to ids
196
+ label_ids = self._get_labels(input_ids)
197
+ else: # generation
198
+ input_seq = (
199
+ context_wrap
200
+ + " "
201
+ + sys_utt_wrap
202
+ + " "
203
+ + sys_act_wrap
204
+ + " "
205
+ + flags_wrap
206
+ + " "
207
+ + goal_wrap
208
+ + " "
209
+ + "<USR_ACT/>"
210
+ ) # + " " + usr_act_wrap + " " + usr_utt_wrap
211
+ input_ids = self.tokenizer(input_seq)["input_ids"] # convert to ids
212
+ label_ids = None
213
+
214
+ valid_example = True
215
+ if len(input_ids) > 1023:
216
+ print("{}: {}".format(self.n_trimmed, self.example_id))
217
+ self.n_trimmed += 1
218
+ valid_example = False
219
+
220
+ return input_seq, input_ids, label_ids, valid_example
221
+
222
+ def _get_labels(self, input_ids):
223
+ for special_token in ["<SYS_ACT/>", "</GC>", "<USR_ACT/>"]:
224
+ special_token_id = self.add_special_token_ids[special_token]
225
+ assert input_ids.count(special_token_id) == 1
226
+
227
+ label_ids = [-100] * len(input_ids)
228
+
229
+ # sys act signal interval
230
+ start_position = input_ids.index(self.add_special_token_ids["<SYS_ACT/>"])
231
+ end_position = input_ids.index(self.add_special_token_ids["</GC>"]) + 1
232
+ label_ids[start_position:end_position] = input_ids[start_position:end_position]
233
+
234
+ # usr act and utt singal interval
235
+ start_position = input_ids.index(self.add_special_token_ids["<USR_ACT/>"])
236
+ assert self.eos_id == input_ids[-1]
237
+ label_ids[start_position:] = input_ids[start_position:]
238
+ assert len(label_ids) == len(input_ids)
239
+ return label_ids
240
+
241
+ def _pad(self, sentences, pad_id):
242
+ max_len = max((map(len, sentences)))
243
+ attention_mask = []
244
+ sentences_pad = []
245
+ for sent in sentences:
246
+ pad_len = max_len - len(sent)
247
+ sentences_pad.append(sent + [pad_id] * pad_len)
248
+ attention_mask.append([1] * len(sent) + [0] * pad_len)
249
+ return sentences_pad, attention_mask
250
+
251
+ def __len__(self): # required
252
+ return len(self.examples)
253
+
254
+ def __getitem__(self, index): # required
255
+ """
256
+ index will be ramdomly sampled by the fed sampler, we dont need to worry about index
257
+ """
258
+ return self.examples[index]
259
+
260
+ def collate_fn(self, batch): # optional but useful
261
+ """
262
+ when collate_fn is given to the torch dataloader, we can do further actions to the batch, e.g.,
263
+ tensor can be formed here a batch is formed as a list where each element is a defined data returned
264
+ by __getitem__, andy
265
+ """
266
+ input_ids = [example["input_ids"] for example in batch]
267
+ input_ids, attention_mask = self._pad(input_ids, self.pad_id)
268
+ input_ids, attention_mask = torch.tensor(input_ids).long().to(
269
+ self.args.device
270
+ ), torch.tensor(attention_mask).long().to(self.args.device)
271
+
272
+ if not self.generation:
273
+ label_ids = [example["label_ids"] for example in batch]
274
+ label_ids, _ = self._pad(label_ids, -100)
275
+ label_ids = torch.tensor(label_ids).long().to(self.args.device)
276
+ else:
277
+ label_ids = None
278
+ token_type_ids = None
279
+
280
+ # store info for scoring
281
+ metadata = [ex["metadata"] for ex in batch]
282
+ example_id = [ex["example_id"] for ex in batch]
283
+ data_name = [ex["data_name"] for ex in batch]
284
+
285
+ return {
286
+ "input_ids": input_ids,
287
+ "attention_mask": attention_mask,
288
+ "token_type_ids": token_type_ids,
289
+ "label_ids": label_ids,
290
+ "metadata": metadata,
291
+ "example_id": example_id,
292
+ "data_name": data_name,
293
+ }
294
+
295
+
296
+ if __name__ == "__main__":
297
+ pass
src/crazyneuraluser/user_model_code/utils_generation.py ADDED
@@ -0,0 +1,210 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ from crazyneuraluser.user_model_code.utils_sgd import add_str, bcolors, wrap_element
4
+
5
+
6
+ def find_segment(gen, tag):
7
+ assert isinstance(gen, str)
8
+ gen = gen.split()
9
+ try:
10
+ start = gen.index("<{}/>".format(tag)) + 1
11
+ end = gen.index("</{}>".format(tag))
12
+ segment = " ".join(gen[start:end])
13
+ except Exception:
14
+ print("Missing {} tag in generated sequence".format(tag))
15
+ segment = None
16
+ return segment
17
+
18
+
19
+ def segment_gen(gen, dial_id):
20
+ def _color(_segment):
21
+ if tag == "CTX":
22
+ _segment = _segment.replace(" </USR>", f"{bcolors.ENDC}")
23
+ _segment = _segment.replace(" </SYS>", f"{bcolors.ENDC}")
24
+ _segment = _segment.replace("<USR/> ", f"USR: {bcolors.OKCYAN}")
25
+ _segment = _segment.replace("<SYS/> ", f"SYS: {bcolors.OKBLUE}")
26
+ if tag == "SYS_UTT":
27
+ _segment = f"{bcolors.OKBLUE}" + _segment + f"{bcolors.ENDC}"
28
+ if tag == "USR_UTT":
29
+ _segment = f"{bcolors.OKCYAN}" + _segment + f"{bcolors.ENDC}"
30
+ if tag in ["SYS_ACT", "USR_ACT", "GOAL"]:
31
+ _segment = _segment.replace("<ACT/> ", f"{bcolors.RED}")
32
+ _segment = _segment.replace(" </ACT>", f"{bcolors.ENDC}")
33
+ _segment = _segment.replace("<SLOT/> ", f"{bcolors.YELLOW}")
34
+ _segment = _segment.replace(" </SLOT>", f"{bcolors.ENDC}")
35
+ _segment = _segment.replace("<VALUE/> ", f"{bcolors.GREEN}")
36
+ _segment = _segment.replace(" </VALUE>", f"{bcolors.ENDC}")
37
+ if tag == "GOAL":
38
+ _segment = _segment.replace(
39
+ "<SCENARIO/>", f"<SCENARIO/>{bcolors.UNDERLINE}"
40
+ )
41
+ _segment = _segment.replace("</SCENARIO>", f"{bcolors.ENDC}</SCENARIO>")
42
+ _segment = _segment.replace("<TASK/>", f"<TASK/>{bcolors.UNDERLINE}")
43
+ _segment = _segment.replace("</TASK>", f"{bcolors.ENDC}</TASK>")
44
+ # if tag in ["SNT", "GC"]:
45
+ # segment = segment.replace("<{}/> ".format(tag), "<{}/> *".format(tag))
46
+ # segment = segment.replace(" </{}>".format(tag), "* <{}/>".format(tag))
47
+ return _segment
48
+
49
+ assert isinstance(gen, str)
50
+ print("*** Dial_id: {} ***".format(dial_id))
51
+ for tag in [
52
+ "CTX",
53
+ "SYS_UTT",
54
+ "SYS_ACT",
55
+ "GOAL",
56
+ "SNT",
57
+ "RA",
58
+ "GC",
59
+ "USR_ACT",
60
+ "USR_UTT",
61
+ ]:
62
+ segment = find_segment(gen, tag)
63
+ if segment is not None:
64
+ print('{} -> "{}"'.format(tag, _color(segment)))
65
+ else:
66
+ print("Fail to find the segment...")
67
+ print("GEN:", gen)
68
+ print("---" * 30)
69
+ input("press any key to continue...")
70
+
71
+
72
+ def save_gen(gen, dial_id, container):
73
+ output = {"raw_generation": gen}
74
+ parsed_generation = {}
75
+
76
+ assert isinstance(gen, str)
77
+ for tag in [
78
+ "CTX",
79
+ "SYS_UTT",
80
+ "SYS_ACT",
81
+ "GOAL",
82
+ "SNT",
83
+ "RA",
84
+ "GC",
85
+ "USR_ACT",
86
+ "USR_UTT",
87
+ ]:
88
+ segment = find_segment(gen, tag)
89
+ if segment is not None:
90
+ parsed_generation[tag] = segment
91
+ else:
92
+ print("Fail to parse generation on example {}".format(dial_id))
93
+ parsed_generation[tag] = None
94
+
95
+ output["parsed_generation"] = parsed_generation
96
+ container[dial_id] = output
97
+
98
+
99
+ # def decode(args, batch, model, tokenizer):
100
+ # input_ids = batch['input_ids']
101
+ # batch_size, ctx_len = input_ids.size()
102
+ # assert batch_size == 1
103
+ # bos_id, eos_id, pad_id, sep_id = tokenizer.convert_tokens_to_ids(['<BOS>', '<EOS>', '<PAD>', '<SEP>'])
104
+ #
105
+ # # output size: (B, T)
106
+ # output = model.generate(input_ids, max_length=(ctx_len+args.dec_max_len), do_sample=False,
107
+ # temperature=args.temperature, use_cache=True, num_beams=args.num_beams, bos_token_id=bos_id,
108
+ # eos_token_id=eos_id, pad_token_id=pad_id, early_stopping=True)
109
+ #
110
+ # gen = tokenizer.decode(output[0]) # include context fed into model
111
+ # segment_gen(gen, batch["example_id"][0])
112
+ # return [gen]
113
+
114
+
115
+ def prepare_input_ids(
116
+ args: object, tokenizer: object, data: object, start_token: object
117
+ ) -> object:
118
+ assert start_token in ["<SYS_ACT/>", "<USR_ACT/>"]
119
+ input_seq = ""
120
+ for key in [
121
+ "CTX",
122
+ "SYS_UTT",
123
+ "SYS_ACT",
124
+ "SNT",
125
+ "RA",
126
+ "GC",
127
+ "GOAL",
128
+ ]: # fixed order, consistent between training and inference
129
+ if key not in data:
130
+ continue
131
+ wrap = wrap_element(key, data[key])
132
+ input_seq = add_str(input_seq, wrap)
133
+
134
+ input_seq = add_str(input_seq, start_token)
135
+
136
+ input_ids = tokenizer(input_seq)["input_ids"] # convert to ids
137
+ input_ids = torch.tensor([input_ids]).long().to(args.device)
138
+ return input_ids
139
+
140
+
141
+ def decode_e2e(
142
+ args, batch, model, tokenizer, user_goal=None, prev_usr_act=None, collector=None
143
+ ):
144
+ """decode with predicted sys act, goal can be random or from the corpus"""
145
+ assert len(batch["metadata"]) == 1
146
+ context = batch["metadata"][0]["context"]
147
+ sys_utt = batch["metadata"][0]["utterances"]["sys"]
148
+ bos_id, _, pad_id, sep_id = tokenizer.convert_tokens_to_ids(
149
+ ["<BOS>", "<EOS>", "<PAD>", "<SEP>"]
150
+ )
151
+
152
+ # first forward pass
153
+ data = {"CTX": context, "SYS_UTT": sys_utt}
154
+ start_token, end_token = "<SYS_ACT/>", "</GC>"
155
+ input_ids = prepare_input_ids(args, tokenizer, data, start_token)
156
+ eos_id = tokenizer.convert_tokens_to_ids(end_token)
157
+ output = model.generate(
158
+ input_ids,
159
+ max_length=args.dec_max_len,
160
+ do_sample=False,
161
+ temperature=args.temperature,
162
+ use_cache=True,
163
+ num_beams=args.num_beams,
164
+ bos_token_id=bos_id,
165
+ eos_token_id=eos_id,
166
+ pad_token_id=pad_id,
167
+ early_stopping=True,
168
+ )
169
+ gen = tokenizer.decode(output[0]) # include context fed into model
170
+
171
+ # parse the first pass prediction
172
+ for key in ["SYS_ACT", "SNT", "GC", "RA"]:
173
+ value = find_segment(gen, key)
174
+ data[key] = value
175
+ # print("***** First run generation *****")
176
+ # print("SYS_ACT -> {}".format(data["SYS_ACT"]))
177
+ # print("FLAGS -> SNT: {}, GC: {}, RA: {} *****".format(data["SNT"], data["GC"], data["RA"]))
178
+ # print("********************************")
179
+
180
+ # prepare goal
181
+ if user_goal is None: # use ground truth goal from corpus
182
+ data["GOAL"] = batch["metadata"][0]["goal"]
183
+ else:
184
+ goal = user_goal.prepare_turn_goal(
185
+ prev_usr_act, data["SYS_ACT"], data["SNT"], data["GC"], data["RA"]
186
+ )
187
+ data["GOAL"] = goal
188
+
189
+ # second forward pass
190
+ start_token, end_token = "<USR_ACT/>", "<EOS>"
191
+ input_ids = prepare_input_ids(args, tokenizer, data, start_token)
192
+ eos_id = tokenizer.convert_tokens_to_ids(end_token)
193
+ output = model.generate(
194
+ input_ids,
195
+ max_length=args.dec_max_len,
196
+ do_sample=False,
197
+ temperature=args.temperature,
198
+ use_cache=True,
199
+ num_beams=args.num_beams,
200
+ bos_token_id=bos_id,
201
+ eos_token_id=eos_id,
202
+ pad_token_id=pad_id,
203
+ early_stopping=True,
204
+ )
205
+ gen = tokenizer.decode(output[0]) # include context fed into model
206
+ if args.eye_browse_output:
207
+ segment_gen(gen, batch["example_id"][0])
208
+ else:
209
+ save_gen(gen, batch["example_id"][0], collector)
210
+ return [gen]
src/crazyneuraluser/user_model_code/utils_multiwoz.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import re
3
+
4
+
5
+ def get_original_act_set():
6
+ # NOTE:
7
+ # act `Book` and `NoBook` belong to `Booking` domain by ontology,
8
+ # they contain information about either `restaurant` or `hotel` domain
9
+ # full act vocab: https://github.com/ConvLab/ConvLab/blob/master/data/multiwoz/ \
10
+ # annotation/Multiwoz%20data%20analysis.md#dialog-act
11
+ acts = set()
12
+ acts.add("Inform")
13
+ acts.add("Request")
14
+ acts.add(
15
+ "NoOffer"
16
+ ) # equivalent to the concept of `no matching`, `cannot find` in database
17
+ acts.add("Recommend")
18
+ acts.add("Select")
19
+ acts.add(
20
+ "OfferBook"
21
+ ) # only for `train` domain, ask if book is needed, equivalent to `Booking-Inform`
22
+ # with [[none, none]] args in restaurant/hotel domain
23
+ acts.add(
24
+ "OfferBooked"
25
+ ) # only for `train` domain, inform booking is complete, with corresponding info (such as ref number)
26
+ acts.add("Book") # inform booking is successful, equivalent to `OfferBooked` above
27
+ acts.add(
28
+ "NoBook"
29
+ ) # inform booking fails, might because of no availability, usually come together act `request`
30
+ acts.add("bye")
31
+ acts.add("greet")
32
+ acts.add("reqmore")
33
+ acts.add("welcome")
34
+ acts.add("thank")
35
+ return acts
36
+
37
+
38
+ def get_act_natural_language(act):
39
+ if act in ["bye", "greet", "reqmore", "welcome", "thank"]:
40
+ return act
41
+
42
+ assert act[0].isupper()
43
+ tokens = re.findall("[A-Z][^A-Z]*", act) # e.g., `FindEvents` -> `Find Events`
44
+ tokens = list(map(str.lower, tokens)) # lower case, -> `find events`
45
+ act_nl = " ".join(tokens)
46
+ return act_nl
47
+
48
+
49
+ def convert_act_into_sgd(act, SPECIAL_TOKENS):
50
+ """
51
+ convert multiwoz acts (w/o domain info) into sgd acts ensure that acts with same concept use one name
52
+ e.g., Book (OfferBooked) -> NOTIFY_SUCCESS, NoBook -> NOTIFY_FAILURE
53
+ """
54
+ if act == "NoOffer":
55
+ act = "NOTIFY_FAILURE"
56
+
57
+ elif act == "Recommend":
58
+ act = "OFFER"
59
+
60
+ # technically, `OfferBook` is equivalent to (`act=OFFER_INTENT, slot=intent, value=ReserveRestaurant`)
61
+ # on system side in sgd since (1) the conversion is not trivial (completely different representations)
62
+ # and (2) multiwoz has no slot called `intent`
63
+ # one cannot simply convert `OfferBook` to `OFFER_INTENT`
64
+ # we thus keep the act as is
65
+ # note that there is no slot `intent` and value conveying intents in multiwoz
66
+ elif act == "OfferBook":
67
+ act = "Offer_Book"
68
+
69
+ elif act == "OfferBooked":
70
+ act = "NOTIFY_SUCCESS"
71
+
72
+ elif act == "Book": # same as `OfferBooked`
73
+ act = "NOTIFY_SUCCESS"
74
+
75
+ elif act == "NoBook":
76
+ act = "NOTIFY_FAILURE"
77
+
78
+ elif act == "bye":
79
+ act = "GOODBYE"
80
+
81
+ elif act == "reqmore":
82
+ act = "REQ_MORE"
83
+
84
+ elif act == "thank":
85
+ act = "THANK_YOU"
86
+ # elif act == "greet":
87
+ # elif act == "welcome":
88
+ act = act.upper() # align with sgd acts, e.g., `Inform` -> `INFORM`
89
+
90
+ # check if valid
91
+ assert "_{}_".format(act) in SPECIAL_TOKENS["additional_special_tokens"]
92
+ return act
93
+
94
+
95
+ def load_schema(schema_file):
96
+ def _update(key, value, mapping):
97
+ if key in mapping:
98
+ assert (
99
+ value == mapping[key]
100
+ ) # ensure service meta is the same between data splits
101
+ else:
102
+ mapping[key] = value
103
+
104
+ def _restructure_service_meta(service_meta, attribute):
105
+ """convert slot/intent metadata list into dict(slot/intent=metadata)"""
106
+ assert attribute in ["slots", "intents"]
107
+ mapping = {}
108
+ for value in service_meta[attribute]:
109
+ key = value["name"]
110
+ if attribute == "slots": # domain-slot in multiwoz
111
+ assert "-" in key
112
+ _, key = key.split("-") # domain, slot
113
+ key = normalise_slot(key)
114
+ else: # intent
115
+ key = normalise_intent(key)
116
+ mapping[key] = value
117
+ service_meta[attribute] = mapping
118
+
119
+ with open(schema_file) as f:
120
+ data = json.load(f)
121
+
122
+ SERVICE2META = {}
123
+ SLOTS, INTENTS = set(), set()
124
+ for service_meta in data:
125
+ service = service_meta["service_name"]
126
+ _restructure_service_meta(service_meta, "slots")
127
+ _restructure_service_meta(service_meta, "intents")
128
+ _update(service, service_meta, SERVICE2META)
129
+
130
+ # collect domain-independent slots
131
+ for slot in service_meta["slots"]:
132
+ SLOTS.add(slot)
133
+
134
+ for intent in service_meta["intents"]:
135
+ INTENTS.add(intent)
136
+
137
+ print("Load schema, intents: {}, slots: {}".format(len(INTENTS), len(SLOTS)))
138
+ return SERVICE2META, INTENTS, SLOTS
139
+
140
+
141
+ def normalise_intent(intent):
142
+ """convert intent into natural language, e.g., find_hotel -> find hotel"""
143
+ if intent == "police":
144
+ intent = "find_police"
145
+ if intent == "book_taxi":
146
+ intent = "find_taxi"
147
+ assert "_" in intent
148
+ return " ".join(intent.split("_"))
149
+
150
+
151
+ def normalise_slot(slot):
152
+ if slot == "pricerange":
153
+ return "price range"
154
+
155
+ elif slot == "bookday":
156
+ return "book day"
157
+
158
+ elif slot == "bookpeople":
159
+ return "book people"
160
+
161
+ elif slot == "booktime":
162
+ return "book time"
163
+
164
+ elif slot == "bookstay":
165
+ return "book stay"
166
+
167
+ elif slot == "ref":
168
+ return "reference"
169
+
170
+ elif slot == "arriveby":
171
+ return "arrive by"
172
+
173
+ elif slot == "leaveat":
174
+ return "leave at"
175
+
176
+ elif slot == "trainid":
177
+ return "train id"
178
+
179
+ elif slot == "openhours":
180
+ return "open hours"
181
+
182
+ elif slot == "entrancefee":
183
+ return "entrance fee"
184
+
185
+ elif slot in ["none", "?"]:
186
+ return "Empty"
187
+
188
+ else:
189
+ return slot
190
+
191
+
192
+ def normalise_value(value):
193
+ # deal with binary and empty values
194
+ if value == "yes":
195
+ return "True"
196
+
197
+ elif value == "no":
198
+ return "False"
199
+
200
+ elif value in ["none", "?"]:
201
+ return "Empty"
202
+
203
+ else:
204
+ return value