"""Utilities for building, running and logging XNLI prompting experiments."""

import csv
import json
import multiprocessing as mp
import os
from typing import Any, Dict, List, NewType, Optional, Union

import numpy as np
import yaml
from datasets import Dataset, DatasetDict, load_dataset
from easygoogletranslate import EasyGoogleTranslate
from langchain.prompts import FewShotPromptTemplate, PromptTemplate
from tqdm import tqdm
from yaml.loader import SafeLoader

# Maps the language names used in configs to the codes passed both to the
# translator and to `load_dataset` as the dataset configuration name.
LANGUAGE_TO_SUFFIX = {
    "chinese_simplified": "zh-CN",
    "french": "fr",
    "portuguese": "pt",
    "english": "en",
    "arabic": "ar",
    "hindi": "hi",
    "indonesian": "id",
    "amharic": "am",
    "bengali": "bn",
    "burmese": "my",
    "swahili": "sw",
    "bulgarian": "bg",
    "thai": "th",
    "urdu": "ur",
    "turkish": "tr",
    "spanish": "es",
    # The original literal listed "chinese" twice ("zh-CN", then "zh"); the
    # later entry always won, so only that effective value is kept.
    "chinese": "zh",
    "greek": "el",
    "german": "de",
}

# Integer class ids of the NLI datasets mapped to their textual tags.
NUMBER_TO_TAG = {0: "entailment", 1: "neutral", 2: "contradiction"}

PARAMS = NewType("PARAMS", Dict[str, Any])


def read_parameters(args_path) -> PARAMS:
    """Load the YAML file at ``args_path`` with the safe loader and return it."""
    with open(args_path) as f:
        args = yaml.load(f, Loader=SafeLoader)
    return args


def get_key(key_path):
    """Return the first line of the file at ``key_path`` (without the newline)."""
    with open(key_path) as f:
        key = f.read().split("\n")[0]
    return key


def _translate_example(
    example: Dict[str, str], src_language: str, target_language: str
) -> Optional[Dict[str, str]]:
    """Translate the premise and hypothesis of one example.

    Args:
        example: Mapping with at least ``"premise"`` and ``"hypothesis"``.
        src_language: Key of ``LANGUAGE_TO_SUFFIX`` for the source language.
        target_language: Key of ``LANGUAGE_TO_SUFFIX`` for the target language.

    Returns:
        A new dict with the translated ``"premise"`` and ``"hypothesis"`` and
        an empty ``"label"``, or ``None`` when the translation raised (the
        exception is printed). Callers must handle the ``None`` case.

    Raises:
        KeyError: If either language is missing from ``LANGUAGE_TO_SUFFIX``.
    """
    translator = EasyGoogleTranslate(
        source_language=LANGUAGE_TO_SUFFIX[src_language],
        target_language=LANGUAGE_TO_SUFFIX[target_language],
        timeout=30,
    )
    try:
        return {
            "premise": translator.translate(example["premise"]),
            "hypothesis": translator.translate(example["hypothesis"]),
            "label": "",
        }
    except Exception as e:
        # Best-effort: report the failure and signal it with None.
        print(e)
        return None


def choose_few_shot_examples(
    train_dataset: Dataset,
    few_shot_size: int,
    context: List[str],
    selection_criteria: str,
    lang: str,
) -> List[Dict[str, Union[str, int]]]:
    """Select few-shot examples and translate them to the context languages.

    Args:
        train_dataset: Dataset the examples are drawn from.
        few_shot_size: Number of examples to draw.
        context: One language name per in-context example; example ``i`` is
            translated to ``context[i]`` unless that equals ``lang``.
        selection_criteria: ``"first_k"`` takes the first ``few_shot_size``
            rows; ``"random"`` samples indices with replacement, so the same
            row may be picked more than once.
        lang: Language of ``train_dataset``.

    Returns:
        One example per entry of ``context``, each with ``"premise"``,
        ``"hypothesis"`` and the textual ``"label"`` tag.

    Raises:
        ValueError: If ``selection_criteria`` is unknown or ``context`` has
            more entries than ``few_shot_size``.
        RuntimeError: If translating an example fails.
    """
    if selection_criteria == "first_k":
        example_idxs = list(range(few_shot_size))
    elif selection_criteria == "random":
        example_idxs = (
            np.random.choice(len(train_dataset), size=few_shot_size, replace=True)
            .astype(int)
            .tolist()
        )
    else:
        raise ValueError(
            f"Unknown selection_criteria {selection_criteria!r}; "
            "expected 'first_k' or 'random'"
        )

    if len(context) > len(example_idxs):
        raise ValueError(
            f"context lists {len(context)} languages but only "
            f"{len(example_idxs)} few-shot examples were requested"
        )

    ic_examples = [
        {
            "premise": row["premise"],
            "hypothesis": row["hypothesis"],
            "label": NUMBER_TO_TAG[row["label"]],
        }
        for row in (train_dataset[idx] for idx in example_idxs)
    ]

    selected_examples = []
    for example, ic_language in zip(ic_examples, context):
        if ic_language == lang:
            selected_examples.append(example)
            continue
        translated = _translate_example(
            example=example,
            src_language=lang,
            target_language=ic_language,
        )
        if translated is None:
            raise RuntimeError(
                f"Could not translate few-shot example from {lang} to {ic_language}"
            )
        # _translate_example blanks the label; a demonstration without its
        # answer is useless, so restore the original tag.
        translated["label"] = example["label"]
        selected_examples.append(translated)
    return selected_examples


def load_xnli_dataset(
    dataset_name: str,
    lang: str,
    split: str,
    limit: int = 200,
) -> Union[Dataset, DatasetDict]:
    """Load the first ``limit`` rows of one split of an NLI dataset.

    Args:
        dataset_name: ``"indicxnli"`` loads ``Divyanshu/indicxnli``; any other
            value loads ``xnli``.
        lang: Key of ``LANGUAGE_TO_SUFFIX``; its code is used as the dataset
            configuration name.
        split: Name of the split to index out of the loaded dataset.
        limit: Maximum number of leading rows to keep. Defaults to 200.

    Returns:
        The selected rows of the requested split.

    Raises:
        KeyError: If ``lang`` is missing from ``LANGUAGE_TO_SUFFIX``.
    """
    if dataset_name == "indicxnli":
        # PJ: To add except hindi
        dataset = load_dataset("Divyanshu/indicxnli", LANGUAGE_TO_SUFFIX[lang])[split]
    else:
        dataset = load_dataset("xnli", LANGUAGE_TO_SUFFIX[lang])[split]
    # Clamp so a limit larger than the split does not request missing rows.
    return dataset.select(np.arange(min(limit, len(dataset))))


def _construct_prompt_from_examples(
    instruction: str, test_example: dict, ic_examples: List[dict], zero_shot: bool
):
    """Format a prompt from already-selected in-context examples.

    This was originally a first module-level ``construct_prompt`` that the
    later definition of the same name shadowed, which made it unreachable.
    It is kept under a private name so ``process_test_example`` can call it.

    Args:
        instruction: Task instruction placed at the top of the prompt.
        test_example: Mapping with ``"premise"``, ``"hypothesis"``, ``"label"``.
        ic_examples: In-context examples; ignored when ``zero_shot`` is true.
        zero_shot: Whether to build the prompt without in-context examples.

    Returns:
        Tuple ``(prompt_text, test_example["label"])``.
    """
    example_prompt = PromptTemplate(
        input_variables=["premise", "hypothesis", "label"],
        # NOTE(review): "Label{label}" has no separator before the value;
        # kept as-is because it is part of the prompt text.
        template="Premise: {premise}\n Hypothesis: {hypothesis} \n Label{label}",
    )
    # NOTE(review): the literal " + " inside this template looks like a
    # leftover from a string concatenation; kept as-is for the same reason.
    zero_shot_template = (
        f"{instruction}" + "\n hypothesis: {hypothesis} + \n Premise: {premise}"
    )

    if zero_shot:
        prompt = PromptTemplate(
            input_variables=["hypothesis", "premise"], template=zero_shot_template
        )
    else:
        prompt = FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix="Premise: {premise} \n Hypothesis: {hypothesis}",
            input_variables=["hypothesis", "premise"],
        )

    return (
        prompt.format(
            hypothesis=test_example["hypothesis"], premise=test_example["premise"]
        ),
        test_example["label"],
    )


def dump_metrics(
    lang: str,
    config: Dict[str, str],
    r1: float,
    r2: float,
    rL: float,
    metric_logger_path: str,
):
    """Append one row of metrics to a CSV file, writing a header if it is new.

    Args:
        lang: Value written to the "Language" column.
        config: Provides ``"prefix"``, ``"input"``, ``"output"`` and
            ``"context"`` (only its first element is written).
        r1: Value written to the "R1" column.
        r2: Value written to the "R2" column.
        rL: Value written to the "RL" column.
        metric_logger_path: Path of the CSV file to append to.
    """
    # Must be checked before opening: append mode creates the file.
    file_exists = os.path.exists(metric_logger_path)

    # Explicit UTF-8 so non-ASCII config values do not depend on the locale.
    with open(metric_logger_path, "a", newline="", encoding="utf-8") as f:
        csvwriter = csv.writer(f, delimiter=",")

        if not file_exists:
            header = [
                "Language",
                "Prefix",
                "Input",
                "Context",
                "Output",
                "R1",
                "R2",
                "RL",
            ]
            csvwriter.writerow(header)

        csvwriter.writerow(
            [
                lang,
                config["prefix"],
                config["input"],
                config["context"][0],
                config["output"],
                r1,
                r2,
                rL,
            ]
        )


def dump_predictions(idx, response, label, response_logger_file):
    """Append one prediction as a JSON line to ``response_logger_file``.

    Args:
        idx: Stored under ``"q_idx"``.
        response: Stored under ``"prediction"``.
        label: Stored under ``"label"``.
        response_logger_file: Path of the file to append to.
    """
    obj = {"q_idx": idx, "prediction": response, "label": label}
    # ensure_ascii=False emits raw non-ASCII text, so the file encoding is
    # pinned to UTF-8 instead of relying on the platform default.
    with open(response_logger_file, "a", encoding="utf-8") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")


def compute_rouge(scorer, pred, label):
    """Return the ``rouge1``, ``rouge2`` and ``rougeL`` entries of a score.

    Args:
        scorer: Object exposing ``score(a, b)`` that returns a mapping with
            ``"rouge1"``, ``"rouge2"`` and ``"rougeL"`` keys.
        pred: Predicted text, passed as the first argument to ``score``.
        label: Reference text, passed as the second argument to ``score``.
    """
    # NOTE(review): if `scorer` is a rouge_score RougeScorer, its signature is
    # score(target, prediction), which would swap precision and recall here.
    # Confirm against the caller before relying on those fields.
    score = scorer.score(pred, label)
    return score["rouge1"], score["rouge2"], score["rougeL"]


def _translate_instruction(basic_instruction: str, target_language: str) -> str:
    """Translate ``basic_instruction`` from English to ``target_language``."""
    translator = EasyGoogleTranslate(
        source_language="en",
        target_language=LANGUAGE_TO_SUFFIX[target_language],
        timeout=10,
    )
    return translator.translate(basic_instruction)


def _translate_prediction_to_output_language(
    prediction: str, prediction_language: str, output_language: str
) -> str:
    """Translate ``prediction`` from ``prediction_language`` to ``output_language``."""
    translator = EasyGoogleTranslate(
        source_language=LANGUAGE_TO_SUFFIX[prediction_language],
        target_language=LANGUAGE_TO_SUFFIX[output_language],
        timeout=10,
    )
    return translator.translate(prediction)


def create_instruction(lang: str):
    """Return the NLI task instruction in ``lang``.

    The English text is returned unchanged for ``"english"``; for any other
    language it is translated with ``_translate_instruction``.
    """
    # Whitespace reproduces the source literal exactly (leading space,
    # trailing space and newline).
    basic_instruction = (
        " You are an NLP assistant whose purpose is to solve Natural Language"
        " Inference (NLI) problems. NLI is the task of determining the"
        " inference relation between two texts: entailment, contradiction,"
        " or neutral. Your answer should be one word of the following -"
        " entailment, contradiction, or neutral. Pay attention: The output"
        " should be only one word!!!! \n"
    )
    if lang == "english":
        return basic_instruction
    return _translate_instruction(basic_instruction, target_language=lang)


def run_one_configuration(params: Optional[PARAMS] = None, zero: bool = False):
    """Run one experiment configuration over the test split.

    Args:
        params: Experiment parameters; when falsy they are read from
            ``../../parameters.yaml``. Keys read here: ``selected_language``,
            ``config``, ``dataset_name`` and ``limit``.
        zero: When true, the output file name uses the ``_zero`` suffix.
    """
    if not params:
        params = read_parameters("../../parameters.yaml")

    lang = params["selected_language"]
    config = params["config"]
    zero_shot = len(config["context"]) == 0

    # An empty context has no first element to name the file after, so it
    # uses the zero-shot suffix as well instead of raising IndexError.
    if zero or zero_shot:
        config_header = f"{config['input']}_{config['prefix']}_zero"
    else:
        config_header = f"{config['input']}_{config['prefix']}_{config['context'][0]}"

    test_data = load_xnli_dataset(
        dataset_name=params["dataset_name"],
        lang=lang,
        split="test",
        limit=params["limit"],
    )

    # The context manager terminates the workers even if submission fails;
    # close() + join() inside it still waits for every queued task.
    with mp.Pool(processes=3) as pool:
        # tqdm tracks task submission, not task completion.
        for idx, test_example in tqdm(enumerate(test_data), total=len(test_data)):
            pool.apply_async(
                process_test_example,
                args=(
                    test_data,
                    config_header,
                    idx,
                    test_example,
                    config,
                    zero_shot,
                    lang,
                    params,
                ),
            )
        pool.close()
        pool.join()


def process_test_example(
    test_data, config_header, idx, test_example, config, zero_shot, lang, params
):
    """Build the prompt for one test example, query the model and log the result.

    Args:
        test_data: Dataset the few-shot examples are sampled from.
        config_header: Base name (without extension) of the output file.
        idx: Index of the example, stored with the prediction.
        test_example: Row with ``"premise"``, ``"hypothesis"`` and ``"label"``.
        config: Provides ``"prefix"`` (instruction language) and ``"context"``
            (list of languages for the in-context examples).
        zero_shot: Whether to skip in-context examples.
        lang: Language used in the output directory path.
        params: Provides ``selected_language``, ``response_logger_root`` and
            ``model``.
    """
    try:
        instruction = create_instruction(lang=config["prefix"])
        text_example = {
            "premise": test_example["premise"],
            "hypothesis": test_example["hypothesis"],
            "label": test_example["label"],
        }

        ic_examples = []
        if not zero_shot:
            # NOTE(review): examples are sampled from the same data being
            # evaluated, so the test example itself can be drawn as context.
            ic_examples = choose_few_shot_examples(
                train_dataset=test_data,
                few_shot_size=len(config["context"]),
                context=config["context"],
                selection_criteria="random",
                lang=params["selected_language"],
            )

        # Must be the examples-based builder: the public construct_prompt has
        # a different signature and returns only the prompt text.
        prompt, label = _construct_prompt_from_examples(
            instruction=instruction,
            test_example=text_example,
            ic_examples=ic_examples,
            zero_shot=zero_shot,
        )

        # NOTE(review): get_prediction is neither defined nor imported in the
        # visible part of this file; confirm where it comes from.
        pred = get_prediction(
            prompt=prompt, endpoint_id=7327255438662041600, project_id=16514800572
        )
        print(pred)

        output_dir = f"{params['response_logger_root']}/{params['model']}/{lang}"
        os.makedirs(output_dir, exist_ok=True)
        dump_predictions(
            idx=idx,
            response=pred,
            label=label,
            response_logger_file=f"{output_dir}/{config_header}.csv",
        )
    except Exception as e:
        # Worker boundary: results of apply_async are never collected, so an
        # uncaught exception would vanish silently. Report it instead.
        print(f"Error processing example {idx}: {e}")


def construct_prompt(
    instruction: str,
    test_example: dict,
    zero_shot: bool,
    num_examples: int,
    lang: str,
    config: Dict[str, str],
    dataset_name: str = "xnli",
):
    """Build the prompt text for one test example.

    Args:
        instruction: Task instruction; when falsy it is created for ``lang``.
        test_example: Mapping with ``"premise"`` and ``"hypothesis"``.
        zero_shot: Whether to build the prompt without in-context examples.
        num_examples: Number of in-context examples to sample.
        lang: Language of the dataset and of ``test_example``.
        config: Provides ``"context"`` (language of every in-context example)
            and ``"input"`` (language the test example is presented in).
        dataset_name: Dataset the in-context examples are loaded from.

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: If ``lang`` is not supported when loading the dataset.
        RuntimeError: If translating the test example fails.
    """
    if not instruction:
        print(lang)
        instruction = create_instruction(lang)

    example_prompt = PromptTemplate(
        input_variables=["premise", "hypothesis", "label"],
        template="Premise {premise}\n Hypothesis {hypothesis} \n{label}",
    )
    # NOTE(review): the literal " + " inside this template looks like a
    # leftover from a string concatenation; kept as-is since it is prompt text.
    zero_shot_template = (
        f"{instruction}" + "\n Hypothesis: {hypothesis} + \n Premise: {premise}"
    )

    ic_examples = []
    if not zero_shot:
        try:
            test_data = load_xnli_dataset(dataset_name, lang, split="test", limit=100)
        except KeyError as e:
            raise KeyError(
                f"{lang} is not supported in {dataset_name} dataset, choose supported language in few-shot"
            ) from e
        ic_examples = choose_few_shot_examples(
            train_dataset=test_data,
            few_shot_size=num_examples,
            context=[config["context"]] * num_examples,
            selection_criteria="random",
            lang=lang,
        )

    if zero_shot:
        prompt = PromptTemplate(
            input_variables=["hypothesis", "premise"], template=zero_shot_template
        )
    else:
        prompt = FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix="{premise} \n{hypothesis}",
            input_variables=["hypothesis", "premise"],
        )

    print("lang", lang)
    print(config["input"], lang)
    if config["input"] != lang:
        translated = _translate_example(
            example=test_example, src_language=lang, target_language=config["input"]
        )
        if translated is None:
            raise RuntimeError(
                f"Could not translate test example from {lang} to {config['input']}"
            )
        test_example = translated

    return prompt.format(
        hypothesis=test_example["hypothesis"], premise=test_example["premise"]
    )