Spaces:
Runtime error
Runtime error
| import sys | |
| import traceback | |
| import pandas as pd | |
| # from tqdm import tqdm | |
| from UBAR_code.interaction import UBAR_interact | |
| from user_model_code.interaction import multiwoz_interact | |
| from UBAR_code.interaction.UBAR_interact import bcolors | |
| # from tqdm import tqdm | |
| from scripts.UBAR_code.interaction import UBAR_interact | |
| from scripts.user_model_code.interaction import multiwoz_interact | |
| from scripts.UBAR_code.interaction.UBAR_interact import bcolors | |
def instantiate_agents(
    ubar_checkpoint_path="cambridge-masters-project/epoch50_trloss0.59_gpt2",
    user_model_checkpoint_path="cambridge-masters-project/MultiWOZ-full_checkpoint_step340k",
    ubar_config_path="cambridge-masters-project/scripts/UBAR_code/interaction/config.yaml",
    user_model_config_path="cambridge-masters-project/scripts/user_model_code/interaction/config.yaml",
):
    """Build the UBAR system model and the neural user-simulator agent.

    The defaults are the paths this function previously hard-coded (relative to
    the current working directory), so calling it with no arguments behaves
    exactly as before; pass explicit paths to use other checkpoints/configs.

    Args:
        ubar_checkpoint_path: Checkpoint handed to ``UBAR_interact.UbarSystemModel``.
        user_model_checkpoint_path: Checkpoint handed to ``multiwoz_interact.NeuralAgent``.
        ubar_config_path: Config YAML handed to ``UBAR_interact.UbarSystemModel``.
        user_model_config_path: Config YAML handed to ``multiwoz_interact.NeuralAgent``.

    Returns:
        Tuple ``(sys_model, user_model)``.
    """
    sys_model = UBAR_interact.UbarSystemModel("UBAR_sys_model", ubar_checkpoint_path, ubar_config_path)
    user_model = multiwoz_interact.NeuralAgent("user", user_model_checkpoint_path, user_model_config_path)
    return sys_model, user_model
def read_multiwoz_data(raw_mwoz_20_path="cambridge-masters-project/data/raw/UBAR/multi-woz/data.json"):
    """Read the raw MultiWOZ 2.0 data from its ``data.json`` file.

    Args:
        raw_mwoz_20_path: Path to the JSON file. Defaults to the location this
            function previously hard-coded, so existing no-argument calls are
            unaffected.

    Returns:
        The ``pandas.DataFrame`` produced by ``pd.read_json`` on that file.
    """
    return pd.read_json(raw_mwoz_20_path)
def load_test_val_lists(
    val_list_file="cambridge-masters-project/data/raw/UBAR/multi-woz/valListFile.json",
    test_list_file="cambridge-masters-project/data/raw/UBAR/multi-woz/testListFile.json",
):
    """Load the validation and test dialogue IDs.

    Assumes each file lists one dialogue ID per line (e.g. ``SNG0073.json``)
    despite the ``.json`` extension — NOTE(review): confirm against the data
    files. Surrounding whitespace is stripped and blank lines are skipped.

    Args:
        val_list_file: Path to the validation-split ID list.
        test_list_file: Path to the test-split ID list.

    Returns:
        ``(val_list, test_list)``: lists of dialogue IDs in file order. A
        2-tuple is returned because ``main`` unpacks the result into two names.
    """

    def _read_ids(path):
        # One ID per line; ignore blank/whitespace-only lines.
        with open(path, "r", encoding="utf-8") as f:
            return [line.strip() for line in f if line.strip()]

    return _read_ids(val_list_file), _read_ids(test_list_file)
def _write_success_log(path, successful_dialogues, total_dialogues_generated):
    """Overwrite ``path`` with a ``"<successes> / <total>"`` progress line."""
    with open(path, "w") as f:
        f.write(str(successful_dialogues) + " / " + str(total_dialogues_generated))


def _simulate_dialogue(
    sys_model, user_model, goal, dialogue_idx, dialogue_filename, ground_truth_log=None, max_turns=50
):
    """Run one dialogue between the user model and a system, printing each turn.

    Args:
        sys_model: System agent; used only when ``ground_truth_log`` is None.
        user_model: User-simulator agent, re-initialised with ``goal``.
        goal: Goal passed to ``user_model.init_session``.
        dialogue_idx: Dialogue number, written into each output row.
        dialogue_filename: Dialogue ID, written into each output row.
        ground_truth_log: If not None, the dialogue's ``log`` entries from
            data.json; system turns are then read from it (odd indices,
            ``["text"]`` field) instead of generated by ``sys_model``.
        max_turns: Maximum number of user/system exchanges.

    Returns:
        ``(terminated, user_lines)``. ``terminated`` is True when the user model
        reported termination. ``user_lines`` holds one newline-terminated,
        tab-separated row per user utterance:
        dialogue #, dialogue ID, raw-log turn index, and the utterance with
        newlines flattened.
    """
    sys_model.init_session()
    user_model.init_session(ini_goal=goal)
    sys_response = ""
    user_lines = []
    for turn_idx in range(max_turns):
        # One turn = one user utterance AND one system response; in the raw log
        # they sit at indices 2 * turn_idx and 2 * turn_idx + 1.
        usr_raw_idx = turn_idx * 2
        sys_raw_idx = usr_raw_idx + 1

        user_utterance = user_model.response(sys_response)
        print(bcolors.OKBLUE + "User: " + bcolors.ENDC + user_utterance)
        # Flatten newlines only for the output row; the system model still gets
        # the raw utterance, so writing to a file no longer changes model input.
        flat_utterance = user_utterance.replace("\n", " ")
        user_lines.append("\t".join([str(dialogue_idx), dialogue_filename, str(usr_raw_idx), flat_utterance]) + "\n")

        if user_model.is_terminated():
            print(bcolors.OKCYAN + "Dialogue terminated successfully!" + bcolors.ENDC)
            print(bcolors.OKCYAN + "---" * 30 + bcolors.ENDC + "\n")
            return True, user_lines

        if ground_truth_log is not None:
            # Ran out of ground-truth system turns before the user finished.
            if len(ground_truth_log) <= sys_raw_idx:
                print(bcolors.RED + "Dialogue terminated unsuccessfully!" + bcolors.ENDC)
                print(bcolors.RED + "---" * 30 + bcolors.ENDC + "\n")
                return False, user_lines
            sys_response = ground_truth_log[sys_raw_idx]["text"]
        else:
            sys_response = sys_model.response(user_utterance, turn_idx)
        # Slicing (not [0]) so an empty response does not raise IndexError.
        capitalised_sys_response = sys_response[:1].upper() + sys_response[1:]
        print(bcolors.GREEN + "System: " + bcolors.ENDC + capitalised_sys_response)
    return False, user_lines


def main(
    write_to_file=False, ground_truth_system_responses=False, train_only=True, n_dialogues="all", log_successes=False
):
    """Simulate MultiWOZ dialogues with the user model and report the success rate.

    Args:
        write_to_file: Write a header to the utterances file, then append the
            user utterances of every successfully terminated dialogue.
        ground_truth_system_responses: Use the system turns from data.json
            instead of generating them with the UBAR model.
        train_only: Skip dialogues listed in the validation/test ID lists.
        n_dialogues: ``"all"`` or a count passed to
            ``read_multiWOZ_20_goals``; the loop stops when goals run out.
        log_successes: Periodically (and once at the end) overwrite the
            success-log file with ``"<successes> / <simulated>"``.
    """
    sys_model, user_model = instantiate_agents()

    # TODO: move hardcoded vars into config file
    raw_mwoz_20_path = "cambridge-masters-project/data/raw/UBAR/multi-woz/data.json"
    user_utterances_out_path = "cambridge-masters-project/data/preprocessed/UBAR/user_utterances_from_simulator.txt"
    logging_successes_path = "cambridge-masters-project/data/preprocessed/UBAR/logging_successes"

    sys_model.print_intermediary_info = False
    user_model.print_intermediary_info = False

    print("Loading data...")
    # Parse data.json once. It supplies the dialogue count, the dialogue IDs
    # (columns) and, optionally, the ground-truth system turns.
    df_mwoz_data = pd.read_json(raw_mwoz_20_path)
    if n_dialogues == "all":
        n_dialogues = len(df_mwoz_data.columns)

    print("Loading goals...")
    goals = multiwoz_interact.read_multiWOZ_20_goals(raw_mwoz_20_path, n_dialogues)

    if write_to_file:
        # Write column headers.
        # NOTE(review): the 4th column actually holds user utterances, not system responses.
        with open(user_utterances_out_path, "w") as f:
            f.write("Dialogue #\tDialogue ID\tTurn #\tSystem Response\n")

    val_list, test_list = load_test_val_lists()
    # A set gives O(1) membership checks inside the loop.
    held_out_dialogues = set(val_list) | set(test_list)

    successful_dialogues = 0
    total_dialogues_generated = 0  # dialogues actually simulated (train only when train_only=True)

    for dialogue_idx, (goal, dialogue_filename) in enumerate(zip(goals, df_mwoz_data.columns)):
        if log_successes and dialogue_idx % 100 == 0:
            # Progress checkpoint every 100 dialogues
            _write_success_log(logging_successes_path, successful_dialogues, total_dialogues_generated)

        if train_only and dialogue_filename in held_out_dialogues:
            continue

        total_dialogues_generated += 1
        print("Dialogue: {}".format(dialogue_filename))

        # Agents (usually the user model) occasionally raise. In that case,
        # report the error and move on to the next dialogue.
        try:
            ground_truth_log = df_mwoz_data.iloc[:, dialogue_idx]["log"] if ground_truth_system_responses else None
            terminated, user_lines = _simulate_dialogue(
                sys_model, user_model, goal, dialogue_idx, dialogue_filename, ground_truth_log
            )
            if terminated:
                successful_dialogues += 1
                if write_to_file:
                    # Only successfully terminated dialogues are written out.
                    with open(user_utterances_out_path, "a") as f:
                        f.writelines(user_lines)
        except Exception:
            print(bcolors.RED + "*" * 30 + bcolors.ENDC)
            print(bcolors.RED + "Error in dialogue {}".format(dialogue_filename) + bcolors.ENDC)
            print(bcolors.RED + "*" * 30 + bcolors.ENDC)
            traceback.print_exc()

    if log_successes:
        # The in-loop checkpoint only fires every 100 dialogues; record the final count too.
        _write_success_log(logging_successes_path, successful_dialogues, total_dialogues_generated)

    # Report against dialogues actually simulated, not n_dialogues, which
    # also counts held-out dialogues skipped under train_only.
    print("Successful dialogues: {}".format(successful_dialogues))
    print("Total dialogues: {}".format(total_dialogues_generated))
    if total_dialogues_generated:
        print("% Successful Dialogues: {}".format(successful_dialogues / total_dialogues_generated))
    else:
        print("% Successful Dialogues: n/a (no dialogues simulated)")
if __name__ == "__main__":
    # TODO: move parameters to a config file / argparse.
    # Usage: python <script> [ground_truth_system_responses]
    # An optional first argument enables ground-truth system responses.
    # "false", "0" or "no" (any case) disables them, and any other value enables them.
    # Omitting the argument defaults to False instead of raising IndexError.
    flag = sys.argv[1] if len(sys.argv) > 1 else "False"
    ground_truth_system_responses = flag.strip().lower() not in ("false", "0", "no")
    main(write_to_file=False, ground_truth_system_responses=ground_truth_system_responses)