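"""Multilingual extractive QA evaluation utilities.

Loads QA benchmarks (XQuAD, IndicQA, TyDiQA, MLQA), builds zero-shot or
few-shot prompts with optional translation of instructions, in-context
examples, and test examples between languages, queries a model, and logs
per-example predictions and per-configuration metrics.
"""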
import csv
import json
import logging
import multiprocessing as mp
import os
import re
import string
import sys
import unicodedata
from typing import Any, Dict, List, NewType, Optional, Union
import numpy as np
import yaml
from datasets import Dataset, load_dataset
from easygoogletranslate import EasyGoogleTranslate
from evaluate import load
from langchain.prompts import FewShotPromptTemplate, PromptTemplate
from tqdm import tqdm
from yaml.loader import SafeLoader
XQUAD_LANG2CODES = {
"bengali": "bn",
"korean": "ko",
"swahili": "sw",
"english": "en",
"indonesian": "id",
"arabic": "ar",
"finnish": "fi",
"telugu": "te",
"russian": "ru",
"german": "de",
"greek": "el",
"hindi": "hi",
"vietnamese": "vi",
"romanian": "ro",
}
INDICQA_LANG2CODES = {
"indicqa": "as",
"bengali": "bn",
"gujarati": "gu",
"hindi": "hi",
"kannada": "kn",
"malayalam": "ml",
"marathi": "mr",
"odia": "or",
"punjabi": "pa",
"tamil": "ta",
"telugu": "te",
"assamese": "as",
}
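# Every Unicode punctuation character plus ASCII punctuation; stripped from
# answers by normalize_answer() below.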
PUNCT = {
chr(i)
for i in range(sys.maxunicode)
if unicodedata.category(chr(i)).startswith("P")
}.union(string.punctuation)
WHITESPACE_LANGS = ["en", "es", "hi", "vi", "de", "ar"]
MIXED_SEGMENTATION_LANGS = ["zh"]
TYDIQA_LANG2CODES = {
"bengali": "bn",
"korean": "ko",
"swahili": "sw",
"english": "en",
"indonesian": "id",
"arabic": "ar",
"finnish": "fi",
"telugu": "te",
"russian": "ru",
"assamese": "as",
"persian": "fa",
}
logger = logging.getLogger("Xlsum_task")
LANGUAGE_TO_SUFFIX = {
"chinese_simplified": "zh-CN",
"french": "fr",
"portuguese": "pt",
"english": "en",
"arabic": "ar",
"hindi": "hi",
"indonesian": "id",
"amharic": "am",
"bengali": "bn",
"burmese": "my",
"uzbek": "uz",
"nepali": "ne",
"japanese": "ja",
"spanish": "es",
"turkish": "tr",
"persian": "fa",
"azerbaijani": "az",
"korean": "ko",
"hebrew": "he",
"telugu": "te",
"german": "de",
"greek": "el",
"tamil": "ta",
"assamese": "as",
"vietnamese": "vi",
"russian": "ru",
"romanian": "ro",
"malayalam": "ml",
"swahili": "sw",
"bulgarian": "bg",
"thai": "th",
"urdu": "ur",
"italian": "it",
"polish": "pl",
"dutch": "nl",
"swedish": "sv",
"danish": "da",
"norwegian": "no",
"finnish": "fi",
"hungarian": "hu",
"czech": "cs",
"slovak": "sk",
"ukrainian": "uk",
}
PARAMS = NewType("PARAMS", Dict[str, Any])
def read_parameters(args_path: str) -> PARAMS:
with open(args_path) as f:
args = yaml.load(f, Loader=SafeLoader)
return args
def load_qa_dataset(dataset_name, lang, split, translate_test=False, limit=5):
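    """Load the requested split of a QA dataset for `lang`, truncated to `limit` examples.

    IndicQA and XQuAD have no training data, so SQuAD v2 / SQuAD are substituted
    for the train split; MLQA has no train split, so validation is used instead.
    """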
if dataset_name == "indicqa":
if split != "train":
dataset = load_dataset(
"ai4bharat/IndicQA", f"indicqa.{INDICQA_LANG2CODES[lang]}"
)[split]
else:
dataset = load_dataset("squad_v2")[split]
elif dataset_name == "xquad":
if split != "train":
dataset = load_dataset("xquad", f"xquad.{XQUAD_LANG2CODES[lang]}")[
"validation"
]
else:
dataset = load_dataset("squad")[split]
elif dataset_name == "tydiqa":
dataset = load_dataset("tydiqa", "secondary_task")[split]
dataset = dataset.map(
lambda example: {"lang": TYDIQA_LANG2CODES[example["id"].split("-")[0]]}
)
dataset = dataset.filter(lambda example: example["lang"] == lang)
elif dataset_name == "mlqa":
if split == "train":
print("No Training Data for MLQA, switching to validation!")
split = "validation"
if translate_test:
dataset_name = f"mlqa-translate-test.{lang}"
else:
dataset_name = f"mlqa.{lang}.{lang}"
dataset = load_dataset("mlqa", dataset_name)[split]
else:
raise NotImplementedError()
    return dataset.select(np.arange(min(limit, len(dataset))))
def construct_prompt(
instruction: str,
test_example: dict,
ic_examples: List[dict],
zero_shot: bool,
lang: str,
config: Any,
):
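    """Build the prompt (few-shot unless `zero_shot` is set) for one test example.

    The test example is translated into config["input"] when it differs from
    `lang`. Returns a (prompt_text, gold_answers) tuple.
    """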
example_prompt = PromptTemplate(
input_variables=["context", "question", "answers"],
template="Context: {context} \n Question: {question} \n "
"Answers: {answers}",
)
    zero_shot_template = (
        instruction + " \n <Context>: {context} \n <Question>: {question} "
    )
prompt = (
FewShotPromptTemplate(
examples=ic_examples,
prefix=instruction,
example_prompt=example_prompt,
suffix="<Context>: {context} \n <Question>: {question} \n Answers: ?",
input_variables=["question", "context"],
)
if not zero_shot
else PromptTemplate(
input_variables=["question", "context"], template=zero_shot_template
)
)
label = test_example["answers"]
if config["input"] != lang:
test_example = _translate_example(
example=test_example, src_language=lang, target_language=config["input"]
)
return (
prompt.format(
question=test_example["question"], context=test_example["context"]
),
label,
)
def dump_metrics(
lang: str, config: Dict[str, str], f1: float, em: float, metric_logger_path: str
):
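    """Append one row of per-configuration F1/EM results to a CSV file."""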
# Check if the metric logger file exists
file_exists = os.path.exists(metric_logger_path)
# Open the CSV file in append mode
with open(metric_logger_path, "a", newline="") as f:
csvwriter = csv.writer(f, delimiter=",")
# Write header row if the file is newly created
if not file_exists:
header = ["Language", "Prefix", "Input", "Context", "Output", "F1", "Em"]
csvwriter.writerow(header)
csvwriter.writerow(
[
lang,
config["prefix"],
config["input"],
config["context"][0],
config["output"],
f1,
em,
]
)
def dump_predictions(idx, response, label, response_logger_file):
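    """Append a single prediction record to the response log as a JSON line."""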
obj = {"q_idx": idx, "prediction": response, "label": label}
with open(response_logger_file, "a") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")
def _translate_instruction(basic_instruction: str, target_language: str) -> str:
translator = EasyGoogleTranslate(
source_language="en",
target_language=LANGUAGE_TO_SUFFIX[target_language],
timeout=50,
)
return translator.translate(basic_instruction)
def _translate_prediction_to_output_language(
prediction: str, prediction_language: str, output_language: str
) -> str:
translator = EasyGoogleTranslate(
source_language=LANGUAGE_TO_SUFFIX[prediction_language],
target_language=LANGUAGE_TO_SUFFIX[output_language],
timeout=10,
)
return translator.translate(prediction)
def create_instruction(
    lang: str, instruction_language: Optional[str] = None, expected_output: str = ""
):
    # Some call sites pass only `lang` (the prefix language) and `expected_output`;
    # in that case the instruction language defaults to `lang`.
    if instruction_language is None:
        instruction_language = lang
    basic_instruction = (
        "Answer the <Question> below based only on the given <Context>. Follow these instructions: \n "
        "1. The answer should include only words from the given context \n "
        "2. The answer must include up to 5 words \n "
        "3. The answer should be as short as possible \n "
        f"4. The answer must be in {expected_output} only, not in another language!"
    )
    return (
        basic_instruction
        if instruction_language == "english"
        else _translate_instruction(basic_instruction, target_language=lang)
    )
def _translate_example(
example: Dict[str, str], src_language: str, target_language: str
):
translator = EasyGoogleTranslate(
source_language=LANGUAGE_TO_SUFFIX[src_language],
target_language=LANGUAGE_TO_SUFFIX[target_language],
timeout=30,
)
    try:
        return {
            "question": translator.translate(example["question"]),
            # Translate the context in 2000-character chunks to stay within the
            # translator's per-request length limit.
            "context": translator.translate(example["context"][:2000])
            + translator.translate(example["context"][2000:4000])
            + translator.translate(example["context"][4000:6000]),
            "answers": "",
        }
    except Exception as e:
        # Returning None here would break the callers; log and re-raise so the
        # per-example exception handlers can skip this example.
        logger.error("Failed to translate example: %s", e)
        raise
def choose_few_shot_examples(
train_dataset: Dataset,
few_shot_size: int,
context: List[str],
selection_criteria: str,
lang: str,
) -> List[Dict[str, Union[str, int]]]:
"""Selects few-shot examples from training datasets
Args:
train_dataset (Dataset): Training Dataset
few_shot_size (int): Number of few-shot examples
selection_criteria (few_shot_selection): How to select few-shot examples. Choices: [random, first_k]
Returns:
List[Dict[str, Union[str, int]]]: Selected examples
"""
selected_examples = []
example_idxs = []
if selection_criteria == "first_k":
example_idxs = list(range(few_shot_size))
elif selection_criteria == "random":
example_idxs = (
np.random.choice(len(train_dataset), size=few_shot_size, replace=True)
.astype(int)
.tolist()
)
ic_examples = [
{
"question": train_dataset[idx]["question"],
"context": train_dataset[idx]["context"],
"answers": train_dataset[idx]["answers"]["text"],
}
for idx in example_idxs
]
    for idx, ic_language in enumerate(context):
        if ic_language == lang:
            selected_examples.append(ic_examples[idx])
        else:
            selected_examples.append(
                _translate_example(
                    example=ic_examples[idx],
                    src_language=lang,
                    target_language=ic_language,
                )
            )
return selected_examples
def normalize_answer(s):
"""Lower text and remove punctuation, articles and extra whitespace."""
def remove_articles(text):
return re.sub(r"\b(a|an|the)\b", " ", text)
def white_space_fix(text):
return " ".join(text.split())
    def remove_punc(text):
        return "".join(ch for ch in text if ch not in PUNCT)
def lower(text):
return text.lower()
return white_space_fix(remove_articles(remove_punc(lower(s))))
def process_test_example(
test_data, config_header, idx, test_example, config, zero_shot, lang, params
):
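    """Prompt the model for a single test example and log the prediction.

    Intended to run inside a multiprocessing worker, so exceptions are caught
    and reported rather than propagated.
    """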
    try:
        # Build the instruction and prompt, query the model, and log the result.
instruction = create_instruction(
lang=config["prefix"], expected_output=config["output"]
)
text_example = {
"question": test_example["question"],
"context": test_example["context"],
"answers": test_example["answers"]["text"],
}
ic_examples = []
if not zero_shot:
ic_examples = choose_few_shot_examples(
train_dataset=test_data,
few_shot_size=len(config["context"]),
context=config["context"],
selection_criteria="random",
lang=params["selected_language"],
)
prompt, label = construct_prompt(
instruction=instruction,
test_example=text_example,
ic_examples=ic_examples,
zero_shot=zero_shot,
lang=lang,
config=config,
)
print(len(prompt))
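        # `get_prediction` is not defined or imported in this file; it is assumed
        # to be the client for the deployed model endpoint referenced by
        # endpoint_id/project_id elsewhere in the repo.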
pred = get_prediction(
prompt=prompt, endpoint_id=7327255438662041600, project_id=16514800572
)
# pred = mixtral_completion(prompt)
print(pred)
logger.info("Saving prediction to persistent volume")
os.makedirs(
f"{params['response_logger_root']}/{params['model']}/{lang}", exist_ok=True
)
dump_predictions(
idx=idx,
response=pred,
label=label,
response_logger_file=f"{params['response_logger_root']}/{params['model']}/{lang}/{config_header}.csv",
)
except Exception as e:
# Handle exceptions here
print(f"Error processing example {idx}: {e}")
def run_one_configuration(params: Optional[PARAMS] = None):
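    """Run one prompt configuration sequentially over the evaluation data,
    logging per-example predictions and a per-configuration metrics row."""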
if not params:
params = read_parameters("../../parameters.yaml")
lang = params["selected_language"]
config = params["config"]
zero_shot = len(config["context"]) == 0
    config_header = (
        f"{config['input']}_{config['prefix']}_"
        f"{config['context'][0] if config['context'] else 'zero'}_{config['output']}"
    )
    # avg_f1 / avg_em are not updated in the loop below, so dump_metrics records
    # them as zeros; scoring is expected to be done over the logged predictions.
    avg_em = 0
    avg_f1 = 0
test_data = load_qa_dataset(
dataset_name=params["dataset_name"],
lang=lang,
split="validation" if params["dataset_name"] == "xquad" else "test",
limit=params["limit"],
)
    for idx, test_example in tqdm(enumerate(test_data), total=len(test_data)):
try:
instruction = create_instruction(
lang=config["prefix"], expected_output=config["output"]
)
text_example = {
"question": test_example["question"],
"context": test_example["context"],
"answers": test_example["answers"]["text"],
}
ic_examples = []
if not zero_shot:
ic_examples = choose_few_shot_examples(
train_dataset=test_data,
few_shot_size=len(config["context"]),
context=config["context"],
selection_criteria="random",
lang=params["selected_language"],
)
prompt, label = construct_prompt(
instruction=instruction,
test_example=text_example,
ic_examples=ic_examples,
zero_shot=zero_shot,
lang=lang,
config=config,
)
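            # `mt0_completion` is not defined or imported in this file; it is
            # assumed to be the completion client for the evaluated model.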
pred = mt0_completion(prompt=prompt)
print(pred)
logger.info("Saving prediction to persistent volume")
            out_dir = f"{params['response_logger_root']}/{params['model']}/{lang}"
            os.makedirs(out_dir, exist_ok=True)
            dump_predictions(
                idx=idx,
                response=pred,
                label=label,
                response_logger_file=f"{out_dir}/{config_header}.csv",
            )
except Exception as e:
print(f"Found an exception {e}, continue to the next example")
continue
os.makedirs(f"{params['metrics_root']}" + f"/{params['model']}", exist_ok=True)
dump_metrics(
lang,
config,
avg_f1,
avg_em,
f"{params['metrics_root']}" + f"/{params['model']}" + f"/{lang}.csv",
)
def run_one_configuration_parallel(params: Optional[PARAMS] = None, zero: bool = False):
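    """Run one prompt configuration over the evaluation data, dispatching each
    example to a multiprocessing worker via process_test_example."""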
if not params:
params = read_parameters("../../parameters.yaml")
lang = params["selected_language"]
config = params["config"]
zero_shot = len(config["context"]) == 0
if not zero:
config_header = f"{config['input']}_{config['prefix']}_{config['context'][0]}_{config['output']}"
else:
config_header = f"{config['input']}_{config['prefix']}_zero_{config['output']}"
test_data = load_qa_dataset(
dataset_name=params["dataset_name"],
lang=lang,
split="validation" if params["dataset_name"] == "xquad" else "test",
limit=params["limit"],
)
    # Initialize the multiprocessing pool
    num_processes = mp.cpu_count()  # use the number of available CPU cores
    pool = mp.Pool(processes=num_processes)
# Iterate over test_data using tqdm for progress tracking
for idx, test_example in tqdm(enumerate(test_data), total=len(test_data)):
# Apply asynchronous processing of each test example
pool.apply_async(
process_test_example,
args=(
test_data,
config_header,
idx,
test_example,
config,
zero_shot,
lang,
params,
),
)
# Close the pool and wait for all processes to finish
pool.close()
pool.join()
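# NOTE: this second definition shadows the construct_prompt defined above at
# import time, so the keyword calls in run_one_configuration and
# process_test_example (which pass ic_examples) no longer match it. Unlike the
# earlier variant, this one selects its own few-shot examples from the test
# split and returns only the prompt text.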
def construct_prompt(
instruction: str,
test_example: dict,
zero_shot: bool,
num_examples: int,
lang: str,
config: Dict[str, str],
dataset_name: str = "xquad",
):
if not instruction:
instruction = create_instruction(lang, config["prefix"], config["output"])
example_prompt = PromptTemplate(
input_variables=["context", "question", "answers"],
template="Context: {context} \n Question: {question} \n " "Answers: {answers}",
)
    zero_shot_template = (
        instruction + " \n <Context>: {context} \n <Question>: {question} "
    )
if not zero_shot:
try:
test_data = load_qa_dataset(
dataset_name=dataset_name, lang=lang, split="test", limit=100
)
        except Exception as e:
            raise KeyError(f"{lang} is not supported in {dataset_name}") from e
ic_examples = []
if not zero_shot:
ic_examples = choose_few_shot_examples(
train_dataset=test_data,
few_shot_size=num_examples,
context=[config["context"]] * num_examples,
selection_criteria="random",
lang=lang,
)
prompt = (
FewShotPromptTemplate(
examples=ic_examples,
prefix=instruction,
example_prompt=example_prompt,
suffix="<Context>: {context} \n <Question>: {question} \n Answers: ?",
input_variables=["question", "context"],
)
if not zero_shot
else PromptTemplate(
input_variables=["question", "context"], template=zero_shot_template
)
)
print("lang", lang)
print(config["input"], lang)
if config["input"] != lang:
test_example = _translate_example(
example=test_example, src_language=lang, target_language=config["input"]
)
return prompt.format(
question=test_example["question"], context=test_example["context"]
)