import csv
import json
import logging
import multiprocessing as mp
import os
import re
import string
import sys
import unicodedata
from typing import Any, Dict, List, NewType, Optional, Union

import numpy as np
import yaml
from datasets import Dataset, load_dataset
from easygoogletranslate import EasyGoogleTranslate
from evaluate import load
from langchain.prompts import FewShotPromptTemplate, PromptTemplate
from tqdm import tqdm
from yaml.loader import SafeLoader

XQUAD_LANG2CODES = {
    "bengali": "bn",
    "korean": "ko",
    "swahili": "sw",
    "english": "en",
    "indonesian": "id",
    "arabic": "ar",
    "finnish": "fi",
    "telugu": "te",
    "russian": "ru",
    "german": "de",
    "greek": "el",
    "hindi": "hi",
    "vietnamese": "vi",
    "romanian": "ro",
}

INDICQA_LANG2CODES = {
    "indicqa": "as",
    "bengali": "bn",
    "gujarati": "gu",
    "hindi": "hi",
    "kannada": "kn",
    "malayalam": "ml",
    "marathi": "mr",
    "odia": "or",
    "punjabi": "pa",
    "tamil": "ta",
    "telugu": "te",
    "assamese": "as",
}

PUNCT = {
    chr(i)
    for i in range(sys.maxunicode)
    if unicodedata.category(chr(i)).startswith("P")
}.union(string.punctuation)

WHITESPACE_LANGS = ["en", "es", "hi", "vi", "de", "ar"]
MIXED_SEGMENTATION_LANGS = ["zh"]

TYDIQA_LANG2CODES = {
    "bengali": "bn",
    "korean": "ko",
    "swahili": "sw",
    "english": "en",
    "indonesian": "id",
    "arabic": "ar",
    "finnish": "fi",
    "telugu": "te",
    "russian": "ru",
    "assamese": "as",
    "persian": "fa",
}

# Use getLogger so the logger is registered with the logging hierarchy;
# instantiating logging.Logger directly bypasses any logging configuration.
logger = logging.getLogger("Xlsum_task")

LANGUAGE_TO_SUFFIX = {
    "chinese_simplified": "zh-CN",
    "french": "fr",
    "portuguese": "pt",
    "english": "en",
    "arabic": "ar",
    "hindi": "hi",
    "indonesian": "id",
    "amharic": "am",
    "bengali": "bn",
    "burmese": "my",
    "uzbek": "uz",
    "nepali": "ne",
    "japanese": "ja",
    "spanish": "es",
    "turkish": "tr",
    "persian": "fa",
    "azerbaijani": "az",
    "korean": "ko",
    "hebrew": "he",
    "telugu": "te",
    "german": "de",
    "greek": "el",
    "tamil": "ta",
    "assamese": "as",
    "vietnamese": "vi",
    "russian": "ru",
    "romanian": "ro",
    "malayalam": "ml",
    "swahili": "sw",
    "bulgarian": "bg",
    "thai": "th",
    "urdu": "ur",
    "italian": "it",
    "polish": "pl",
    "dutch": "nl",
    "swedish": "sv",
    "danish": "da",
    "norwegian": "no",
    "finnish": "fi",
    "hungarian": "hu",
    "czech": "cs",
    "slovak": "sk",
    "ukrainian": "uk",
}

PARAMS = NewType("PARAMS", Dict[str, Any])


def read_parameters(args_path) -> PARAMS:
    """Load a YAML parameters file into a dict."""
    with open(args_path) as f:
        return yaml.load(f, Loader=SafeLoader)


def load_qa_dataset(dataset_name, lang, split, translate_test=False, limit=5):
    """Load up to `limit` examples of a QA dataset for the given language and split."""
    if dataset_name == "indicqa":
        if split != "train":
            dataset = load_dataset(
                "ai4bharat/IndicQA", f"indicqa.{INDICQA_LANG2CODES[lang]}"
            )[split]
        else:
            dataset = load_dataset("squad_v2")[split]
    elif dataset_name == "xquad":
        if split != "train":
            dataset = load_dataset("xquad", f"xquad.{XQUAD_LANG2CODES[lang]}")[
                "validation"
            ]
        else:
            dataset = load_dataset("squad")[split]
    elif dataset_name == "tydiqa":
        dataset = load_dataset("tydiqa", "secondary_task")[split]
        dataset = dataset.map(
            lambda example: {"lang": TYDIQA_LANG2CODES[example["id"].split("-")[0]]}
        )
        dataset = dataset.filter(lambda example: example["lang"] == lang)
    elif dataset_name == "mlqa":
        if split == "train":
            print("No training data for MLQA, switching to validation!")
            split = "validation"
        if translate_test:
            dataset_name = f"mlqa-translate-test.{lang}"
        else:
            dataset_name = f"mlqa.{lang}.{lang}"
        dataset = load_dataset("mlqa", dataset_name)[split]
    else:
        raise NotImplementedError()
    return dataset.select(np.arange(limit))
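
# Illustrative usage sketch (not part of the original pipeline): load a small
# Hindi XQuAD sample. Wrapped in a function so importing this module does not
# trigger a dataset download; calling it requires network access.
def _example_load_xquad_sample():
    sample = load_qa_dataset("xquad", lang="hindi", split="validation", limit=5)
    for example in sample:
        # XQuAD examples are expected to carry id/context/question/answers fields.
        print(example["question"])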
def construct_prompt(
    instruction: str,
    test_example: dict,
    ic_examples: List[dict],
    zero_shot: bool,
    lang: str,
    config: Any,
):
    example_prompt = PromptTemplate(
        input_variables=["context", "question", "answers"],
        template="Context: {context} \n Question: {question} \n "
        "Answers: {answers}",
    )

    zero_shot_template = (
        f"""{instruction}""" + " \n Context: {context} \n Question: {question} "
    )

    prompt = (
        FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix="Context: {context} \n Question: {question} \n Answers: ?",
            input_variables=["question", "context"],
        )
        if not zero_shot
        else PromptTemplate(
            input_variables=["question", "context"], template=zero_shot_template
        )
    )

    label = test_example["answers"]
    if config["input"] != lang:
        test_example = _translate_example(
            example=test_example, src_language=lang, target_language=config["input"]
        )

    return (
        prompt.format(
            question=test_example["question"], context=test_example["context"]
        ),
        label,
    )


def dump_metrics(
    lang: str, config: Dict[str, str], f1: float, em: float, metric_logger_path: str
):
    # Check whether the metric logger file exists so the header is written once.
    file_exists = os.path.exists(metric_logger_path)

    # Open the CSV file in append mode.
    with open(metric_logger_path, "a", newline="") as f:
        csvwriter = csv.writer(f, delimiter=",")
        # Write the header row if the file is newly created.
        if not file_exists:
            header = ["Language", "Prefix", "Input", "Context", "Output", "F1", "EM"]
            csvwriter.writerow(header)
        csvwriter.writerow(
            [
                lang,
                config["prefix"],
                config["input"],
                config["context"][0],
                config["output"],
                f1,
                em,
            ]
        )


def dump_predictions(idx, response, label, response_logger_file):
    obj = {"q_idx": idx, "prediction": response, "label": label}
    with open(response_logger_file, "a", encoding="utf-8") as f:
        # One JSON object per line, so the file can be read back as JSON Lines.
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")


def _translate_instruction(basic_instruction: str, target_language: str) -> str:
    translator = EasyGoogleTranslate(
        source_language="en",
        target_language=LANGUAGE_TO_SUFFIX[target_language],
        timeout=50,
    )
    return translator.translate(basic_instruction)


def _translate_prediction_to_output_language(
    prediction: str, prediction_language: str, output_language: str
) -> str:
    translator = EasyGoogleTranslate(
        source_language=LANGUAGE_TO_SUFFIX[prediction_language],
        target_language=LANGUAGE_TO_SUFFIX[output_language],
        timeout=10,
    )
    return translator.translate(prediction)
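
# Sketch (illustrative, not part of the original pipeline): read back a file
# written by dump_predictions above, assuming one JSON object per line.
def _read_predictions(path: str) -> List[Dict[str, Any]]:
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]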
def create_instruction(
    lang: str, instruction_language: Optional[str] = None, expected_output: str = ""
):
    """Build the QA instruction, translating it unless the instruction language
    is English. `instruction_language` defaults to `lang`, matching the call
    sites that pass only `lang` and `expected_output`."""
    if instruction_language is None:
        instruction_language = lang
    basic_instruction = (
        "Answer the question below, based only on the given context. "
        "Follow these instructions: \n "
        "1. The answer should include only words from the given context \n "
        "2. The answer must include up to 5 words \n "
        "3. The answer should be as short as possible \n "
        f"4. The answer must be in {expected_output} only, not another language!"
    )
    return (
        basic_instruction
        if instruction_language == "english"
        else _translate_instruction(basic_instruction, target_language=lang)
    )


def _translate_example(
    example: Dict[str, str], src_language: str, target_language: str
):
    translator = EasyGoogleTranslate(
        source_language=LANGUAGE_TO_SUFFIX[src_language],
        target_language=LANGUAGE_TO_SUFFIX[target_language],
        timeout=30,
    )
    try:
        # The context is translated in 2000-character chunks because the
        # translator rejects very long inputs.
        return {
            "question": translator.translate(example["question"]),
            "context": translator.translate(example["context"][:2000])
            + translator.translate(example["context"][2000:4000])
            + translator.translate(example["context"][4000:6000]),
            "answers": "",
        }
    except Exception as e:
        # Re-raise instead of silently returning None, which would otherwise
        # crash later when a caller indexes into the translated example.
        raise RuntimeError(
            f"Translation failed for {src_language} -> {target_language}"
        ) from e


def choose_few_shot_examples(
    train_dataset: Dataset,
    few_shot_size: int,
    context: List[str],
    selection_criteria: str,
    lang: str,
) -> List[Dict[str, Union[str, int]]]:
    """Select few-shot examples from a training dataset.

    Args:
        train_dataset (Dataset): Training dataset.
        few_shot_size (int): Number of few-shot examples.
        context (List[str]): Language of each in-context example.
        selection_criteria (str): How to select few-shot examples.
            Choices: [random, first_k].
        lang (str): Language of the test example.

    Returns:
        List[Dict[str, Union[str, int]]]: Selected examples.
    """
    selected_examples = []

    example_idxs = []
    if selection_criteria == "first_k":
        example_idxs = list(range(few_shot_size))
    elif selection_criteria == "random":
        example_idxs = (
            np.random.choice(len(train_dataset), size=few_shot_size, replace=True)
            .astype(int)
            .tolist()
        )

    ic_examples = [
        {
            "question": train_dataset[idx]["question"],
            "context": train_dataset[idx]["context"],
            "answers": train_dataset[idx]["answers"]["text"],
        }
        for idx in example_idxs
    ]

    # Translate each in-context example into its configured language unless it
    # is already in the test language.
    for idx, ic_language in enumerate(context):
        if ic_language == lang:
            selected_examples.append(ic_examples[idx])
        else:
            selected_examples.append(
                _translate_example(
                    example=ic_examples[idx],
                    src_language=lang,
                    target_language=ic_language,
                )
            )

    return selected_examples


def normalize_answer(s):
    """Lowercase text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(PUNCT)  # Unicode punctuation plus string.punctuation
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
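
# Sketch (illustrative, not part of the original pipeline): token-level F1 in
# the style of the official SQuAD metric, built on normalize_answer above to
# make the normalization's role concrete. The canonical scores come from the
# `evaluate` library's "squad" metric loaded elsewhere in this file.
def _sketch_token_f1(prediction: str, ground_truth: str) -> float:
    from collections import Counter

    pred_tokens = normalize_answer(prediction).split()
    gold_tokens = normalize_answer(ground_truth).split()
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)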
def process_test_example(
    test_data, config_header, idx, test_example, config, zero_shot, lang, params
):
    try:
        instruction = create_instruction(
            lang=config["prefix"], expected_output=config["output"]
        )
        text_example = {
            "question": test_example["question"],
            "context": test_example["context"],
            "answers": test_example["answers"]["text"],
        }

        ic_examples = []
        if not zero_shot:
            ic_examples = choose_few_shot_examples(
                train_dataset=test_data,
                few_shot_size=len(config["context"]),
                context=config["context"],
                selection_criteria="random",
                lang=params["selected_language"],
            )

        prompt, label = construct_prompt(
            instruction=instruction,
            test_example=text_example,
            ic_examples=ic_examples,
            zero_shot=zero_shot,
            lang=lang,
            config=config,
        )

        print(len(prompt))
        # `get_prediction` is assumed to be defined or imported elsewhere in
        # the project; the IDs below point at a hosted model endpoint.
        pred = get_prediction(
            prompt=prompt, endpoint_id=7327255438662041600, project_id=16514800572
        )
        # pred = mixtral_completion(prompt)
        print(pred)

        logger.info("Saving prediction to persistent volume")
        os.makedirs(
            f"{params['response_logger_root']}/{params['model']}/{lang}",
            exist_ok=True,
        )
        dump_predictions(
            idx=idx,
            response=pred,
            label=label,
            response_logger_file=f"{params['response_logger_root']}/{params['model']}/{lang}/{config_header}.csv",
        )
    except Exception as e:
        print(f"Error processing example {idx}: {e}")


def run_one_configuration(params: Optional[PARAMS] = None):
    if not params:
        params = read_parameters("../../parameters.yaml")

    lang = params["selected_language"]
    config = params["config"]
    zero_shot = len(config["context"]) == 0

    config_header = (
        f"{config['input']}_{config['prefix']}_"
        f"{config['context'][0]}_{config['output']}"
    )

    # Placeholder aggregates: this function only dumps raw predictions, so the
    # F1/EM written below stay zero unless filled in by a later scoring pass.
    avg_f1 = 0
    avg_em = 0

    test_data = load_qa_dataset(
        dataset_name=params["dataset_name"],
        lang=lang,
        split="validation" if params["dataset_name"] == "xquad" else "test",
        limit=params["limit"],
    )

    for idx, test_example in tqdm(enumerate(test_data), total=len(test_data)):
        try:
            instruction = create_instruction(
                lang=config["prefix"], expected_output=config["output"]
            )
            text_example = {
                "question": test_example["question"],
                "context": test_example["context"],
                "answers": test_example["answers"]["text"],
            }

            ic_examples = []
            if not zero_shot:
                ic_examples = choose_few_shot_examples(
                    train_dataset=test_data,
                    few_shot_size=len(config["context"]),
                    context=config["context"],
                    selection_criteria="random",
                    lang=params["selected_language"],
                )

            prompt, label = construct_prompt(
                instruction=instruction,
                test_example=text_example,
                ic_examples=ic_examples,
                zero_shot=zero_shot,
                lang=lang,
                config=config,
            )

            # `mt0_completion` is assumed to be defined or imported elsewhere.
            pred = mt0_completion(prompt=prompt)
            print(pred)

            logger.info("Saving prediction to persistent volume")
            os.makedirs(
                f"{params['response_logger_root']}/{params['model']}/{lang}",
                exist_ok=True,
            )
            dump_predictions(
                idx=idx,
                response=pred,
                label=label,
                response_logger_file=(
                    f"{params['response_logger_root']}/{params['model']}"
                    f"/{lang}/{config_header}.csv"
                ),
            )
        except Exception as e:
            print(f"Found an exception {e}, continuing to the next example")
            continue

    os.makedirs(f"{params['metrics_root']}/{params['model']}", exist_ok=True)
    dump_metrics(
        lang,
        config,
        avg_f1,
        avg_em,
        f"{params['metrics_root']}/{params['model']}/{lang}.csv",
    )


def run_one_configuration_parallel(
    params: Optional[PARAMS] = None, zero: bool = False
):
    if not params:
        params = read_parameters("../../parameters.yaml")

    lang = params["selected_language"]
    config = params["config"]
    zero_shot = len(config["context"]) == 0

    if not zero:
        config_header = (
            f"{config['input']}_{config['prefix']}_"
            f"{config['context'][0]}_{config['output']}"
        )
    else:
        config_header = f"{config['input']}_{config['prefix']}_zero_{config['output']}"

    test_data = load_qa_dataset(
        dataset_name=params["dataset_name"],
        lang=lang,
        split="validation" if params["dataset_name"] == "xquad" else "test",
        limit=params["limit"],
    )

    # Initialize the multiprocessing pool with one worker per available CPU core.
    num_processes = mp.cpu_count()
    pool = mp.Pool(processes=num_processes)

    # Iterate over test_data using tqdm for progress tracking.
    for idx, test_example in tqdm(enumerate(test_data), total=len(test_data)):
        # Process each test example asynchronously.
        pool.apply_async(
            process_test_example,
            args=(
                test_data,
                config_header,
                idx,
                test_example,
                config,
                zero_shot,
                lang,
                params,
            ),
        )

    # Close the pool and wait for all workers to finish.
    pool.close()
    pool.join()
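
# Note (illustrative): apply_async above is fire-and-forget, so worker
# exceptions surface only as prints inside process_test_example. If results
# need to be awaited or errors captured at the call site, the stdlib supports
# callbacks, e.g.:
#
#     async_result = pool.apply_async(
#         process_test_example, args=(...), error_callback=logger.error
#     )
#     async_result.wait()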
def construct_prompt_standalone(
    instruction: str,
    test_example: dict,
    zero_shot: bool,
    num_examples: int,
    lang: str,
    config: Dict[str, str],
    dataset_name: str = "xquad",
):
    """Standalone variant of construct_prompt that loads its own few-shot
    examples; it carries a distinct name so it does not shadow the version
    defined above."""
    if not instruction:
        instruction = create_instruction(lang, config["prefix"], config["output"])

    example_prompt = PromptTemplate(
        input_variables=["context", "question", "answers"],
        template="Context: {context} \n Question: {question} \n "
        "Answers: {answers}",
    )

    zero_shot_template = (
        f"""{instruction}""" + " \n Context: {context} \n Question: {question} "
    )

    if not zero_shot:
        try:
            test_data = load_qa_dataset(
                dataset_name=dataset_name, lang=lang, split="test", limit=100
            )
        except Exception as e:
            raise KeyError(f"{lang} is not supported in {dataset_name}") from e

    ic_examples = []
    if not zero_shot:
        ic_examples = choose_few_shot_examples(
            train_dataset=test_data,
            few_shot_size=num_examples,
            context=[config["context"]] * num_examples,
            selection_criteria="random",
            lang=lang,
        )

    prompt = (
        FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix="Context: {context} \n Question: {question} \n Answers: ?",
            input_variables=["question", "context"],
        )
        if not zero_shot
        else PromptTemplate(
            input_variables=["question", "context"], template=zero_shot_template
        )
    )

    print("lang", lang)
    print(config["input"], lang)
    if config["input"] != lang:
        test_example = _translate_example(
            example=test_example, src_language=lang, target_language=config["input"]
        )

    return prompt.format(
        question=test_example["question"], context=test_example["context"]
    )
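
# Minimal entry-point sketch (an assumption, not part of the original file):
# it presumes a parameters.yaml matching read_parameters' expectations at the
# default relative path, and that the completion backend (mt0_completion or
# get_prediction) is defined or imported before running.
if __name__ == "__main__":
    run_one_configuration(read_parameters("../../parameters.yaml"))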