import csv
import json
import logging
import multiprocessing as mp
import os
import re
import string
import sys
import unicodedata
from typing import Any, Dict, List, NewType, Optional, Union

import numpy as np
import yaml
from datasets import Dataset, load_dataset
from easygoogletranslate import EasyGoogleTranslate
from evaluate import load
from langchain.prompts import FewShotPromptTemplate, PromptTemplate
from tqdm import tqdm
from yaml.loader import SafeLoader
XQUAD_LANG2CODES = {
    "bengali": "bn",
    "korean": "ko",
    "swahili": "sw",
    "english": "en",
    "indonesian": "id",
    "arabic": "ar",
    "finnish": "fi",
    "telugu": "te",
    "russian": "ru",
    "german": "de",
    "greek": "el",
    "hindi": "hi",
    "vietnamese": "vi",
    "romanian": "ro",
}

INDICQA_LANG2CODES = {
    "indicqa": "as",
    "bengali": "bn",
    "gujarati": "gu",
    "hindi": "hi",
    "kannada": "kn",
    "malayalam": "ml",
    "marathi": "mr",
    "odia": "or",
    "punjabi": "pa",
    "tamil": "ta",
    "telugu": "te",
    "assamese": "as",
}

# All Unicode punctuation categories plus ASCII punctuation.
PUNCT = {
    chr(i)
    for i in range(sys.maxunicode)
    if unicodedata.category(chr(i)).startswith("P")
}.union(string.punctuation)

WHITESPACE_LANGS = ["en", "es", "hi", "vi", "de", "ar"]
MIXED_SEGMENTATION_LANGS = ["zh"]

TYDIQA_LANG2CODES = {
    "bengali": "bn",
    "korean": "ko",
    "swahili": "sw",
    "english": "en",
    "indonesian": "id",
    "arabic": "ar",
    "finnish": "fi",
    "telugu": "te",
    "russian": "ru",
    "assamese": "as",
    "persian": "fa",
}
logger = logging.getLogger("Xlsum_task")
LANGUAGE_TO_SUFFIX = {
    "chinese_simplified": "zh-CN",
    "french": "fr",
    "portuguese": "pt",
    "english": "en",
    "arabic": "ar",
    "hindi": "hi",
    "indonesian": "id",
    "amharic": "am",
    "bengali": "bn",
    "burmese": "my",
    "uzbek": "uz",
    "nepali": "ne",
    "japanese": "ja",
    "spanish": "es",
    "turkish": "tr",
    "persian": "fa",
    "azerbaijani": "az",
    "korean": "ko",
    "hebrew": "he",
    "telugu": "te",
    "german": "de",
    "greek": "el",
    "tamil": "ta",
    "assamese": "as",
    "vietnamese": "vi",
    "russian": "ru",
    "romanian": "ro",
    "malayalam": "ml",
    "swahili": "sw",
    "bulgarian": "bg",
    "thai": "th",
    "urdu": "ur",
    "italian": "it",
    "polish": "pl",
    "dutch": "nl",
    "swedish": "sv",
    "danish": "da",
    "norwegian": "no",
    "finnish": "fi",
    "hungarian": "hu",
    "czech": "cs",
    "slovak": "sk",
    "ukrainian": "uk",
}

PARAMS = NewType("PARAMS", Dict[str, Any])
def read_parameters(args_path) -> PARAMS:
    with open(args_path) as f:
        args = yaml.load(f, Loader=SafeLoader)
    return args
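# Sketch of the expected parameters.yaml (hedged: the schema is not shipped
# with this file; the keys below are exactly those accessed in this module,
# the values are illustrative):
#
#   selected_language: hindi
#   dataset_name: xquad
#   limit: 5
#   metric: squad
#   model: mt0
#   response_logger_root: ./responses
#   metrics_root: ./metrics
#   config:
#     prefix: english      # language of the instruction
#     input: hindi         # language of the test question/context
#     context: [hindi]     # language of each in-context example ([] = zero-shot)
#     output: hindi        # language the answer must be written in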
def load_qa_dataset(dataset_name, lang, split, translate_test=False, limit=5):
    if dataset_name == "indicqa":
        if split != "train":
            dataset = load_dataset(
                "ai4bharat/IndicQA", f"indicqa.{INDICQA_LANG2CODES[lang]}"
            )[split]
        else:
            dataset = load_dataset("squad_v2")[split]
    elif dataset_name == "xquad":
        if split != "train":
            dataset = load_dataset("xquad", f"xquad.{XQUAD_LANG2CODES[lang]}")[
                "validation"
            ]
        else:
            dataset = load_dataset("squad")[split]
    elif dataset_name == "tydiqa":
        dataset = load_dataset("tydiqa", "secondary_task")[split]
        dataset = dataset.map(
            lambda example: {"lang": TYDIQA_LANG2CODES[example["id"].split("-")[0]]}
        )
        dataset = dataset.filter(lambda example: example["lang"] == lang)
    elif dataset_name == "mlqa":
        if split == "train":
            print("No training data for MLQA, switching to validation!")
            split = "validation"
        if translate_test:
            dataset_name = f"mlqa-translate-test.{lang}"
        else:
            dataset_name = f"mlqa.{lang}.{lang}"
        dataset = load_dataset("mlqa", dataset_name)[split]
    else:
        raise NotImplementedError()
    # Clamp so a large limit cannot index past the end of the split.
    return dataset.select(np.arange(min(limit, len(dataset))))
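# Example (hedged): five Hindi XQuAD examples. XQuAD ships only a validation
# split, so any non-"train" split resolves to it here.
#   xquad_hi = load_qa_dataset("xquad", lang="hindi", split="validation")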
# Internal prompt builder used by the runners below (the public
# construct_prompt at the end of this module has a different signature).
def _construct_prompt(
    instruction: str,
    test_example: dict,
    ic_examples: List[dict],
    zero_shot: bool,
    lang: str,
    config: Any,
):
    example_prompt = PromptTemplate(
        input_variables=["context", "question", "answers"],
        template="Context: {context} \n Question: {question} \n Answers: {answers}",
    )
    zero_shot_template = (
        f"{instruction}" + " \n <Context>: {context} \n <Question>: {question} "
    )
    prompt = (
        FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix="<Context>: {context} \n <Question>: {question} \n Answers: ?",
            input_variables=["question", "context"],
        )
        if not zero_shot
        else PromptTemplate(
            input_variables=["question", "context"], template=zero_shot_template
        )
    )
    # Capture the gold answers before translation replaces them.
    label = test_example["answers"]
    if config["input"] != lang:
        test_example = _translate_example(
            example=test_example, src_language=lang, target_language=config["input"]
        )
    return (
        prompt.format(
            question=test_example["question"], context=test_example["context"]
        ),
        label,
    )
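# Sketch of the prompts _construct_prompt renders (hedged; whitespace follows
# the templates above):
#   zero-shot:  "{instruction} \n <Context>: ... \n <Question>: ... "
#   few-shot:   instruction, then one "Context/Question/Answers" block per
#               in-context example, then the "<Context>/<Question>/Answers: ?"
#               suffix for the test example.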
def dump_metrics(
    lang: str, config: Dict[str, str], f1: float, em: float, metric_logger_path: str
):
    # Check whether the metric logger file exists before opening in append mode.
    file_exists = os.path.exists(metric_logger_path)
    with open(metric_logger_path, "a", newline="") as f:
        csvwriter = csv.writer(f, delimiter=",")
        # Write the header row only when the file is newly created.
        if not file_exists:
            header = ["Language", "Prefix", "Input", "Context", "Output", "F1", "Em"]
            csvwriter.writerow(header)
        csvwriter.writerow(
            [
                lang,
                config["prefix"],
                config["input"],
                # An empty context list means zero-shot; mirror the "zero"
                # marker used for config headers elsewhere in this module.
                config["context"][0] if config["context"] else "zero",
                config["output"],
                f1,
                em,
            ]
        )
def dump_predictions(idx, response, label, response_logger_file):
    obj = {"q_idx": idx, "prediction": response, "label": label}
    with open(response_logger_file, "a") as f:
        f.write(json.dumps(obj, ensure_ascii=False) + "\n")
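# Each appended line is a standalone JSON object (the callers pass a .csv
# suffix for these files, but the content is JSON Lines), e.g.:
#   {"q_idx": 0, "prediction": "नई दिल्ली", "label": ["नई दिल्ली"]}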
def _translate_instruction(basic_instruction: str, target_language: str) -> str:
    translator = EasyGoogleTranslate(
        source_language="en",
        target_language=LANGUAGE_TO_SUFFIX[target_language],
        timeout=50,
    )
    return translator.translate(basic_instruction)


def _translate_prediction_to_output_language(
    prediction: str, prediction_language: str, output_language: str
) -> str:
    translator = EasyGoogleTranslate(
        source_language=LANGUAGE_TO_SUFFIX[prediction_language],
        target_language=LANGUAGE_TO_SUFFIX[output_language],
        timeout=10,
    )
    return translator.translate(prediction)
def create_instruction(lang: str, instruction_language: str, expected_output: str):
    basic_instruction = (
        "Answer the <Question> below based only on the given <Context>. Follow these instructions: \n "
        "1. The answer should include only words from the given context \n "
        "2. The answer must include at most 5 words \n "
        "3. The answer should be as short as possible \n "
        f"4. The answer must be in {expected_output} only, not another language!"
    )
    return (
        basic_instruction
        if instruction_language == "english"
        else _translate_instruction(basic_instruction, target_language=lang)
    )
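# Example (hedged): an English instruction that demands Hindi-only answers.
#   create_instruction(
#       lang="hindi", instruction_language="english", expected_output="hindi"
#   )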
def _translate_example(
    example: Dict[str, str], src_language: str, target_language: str
):
    translator = EasyGoogleTranslate(
        source_language=LANGUAGE_TO_SUFFIX[src_language],
        target_language=LANGUAGE_TO_SUFFIX[target_language],
        timeout=30,
    )
    try:
        # The context is translated in 2000-character chunks, presumably to
        # stay within the translation API's input-length limit.
        return {
            "question": translator.translate(example["question"]),
            "context": translator.translate(example["context"][:2000])
            + translator.translate(example["context"][2000:4000])
            + translator.translate(example["context"][4000:6000]),
            "answers": "",
        }
    except Exception as e:
        # On translation failure, log and fall back to the untranslated
        # example so downstream prompt construction cannot crash on None.
        logger.warning("Translation failed (%s); using untranslated example", e)
        return example
def choose_few_shot_examples(
    train_dataset: Dataset,
    few_shot_size: int,
    context: List[str],
    selection_criteria: str,
    lang: str,
) -> List[Dict[str, Union[str, int]]]:
    """Select few-shot examples from a training dataset.

    Args:
        train_dataset (Dataset): Training dataset.
        few_shot_size (int): Number of few-shot examples.
        context (List[str]): Language of each in-context example.
        selection_criteria (str): How to select few-shot examples.
            Choices: [random, first_k].
        lang (str): Language of the dataset.

    Returns:
        List[Dict[str, Union[str, int]]]: Selected examples.
    """
    selected_examples = []
    example_idxs = []
    if selection_criteria == "first_k":
        example_idxs = list(range(few_shot_size))
    elif selection_criteria == "random":
        # Note: sampling is with replacement, so an example may repeat.
        example_idxs = (
            np.random.choice(len(train_dataset), size=few_shot_size, replace=True)
            .astype(int)
            .tolist()
        )
    ic_examples = [
        {
            "question": train_dataset[idx]["question"],
            "context": train_dataset[idx]["context"],
            "answers": train_dataset[idx]["answers"]["text"],
        }
        for idx in example_idxs
    ]
    # Translate each in-context example into its requested language.
    for idx, ic_language in enumerate(context):
        if ic_language == lang:
            selected_examples.append(ic_examples[idx])
        else:
            selected_examples.append(
                _translate_example(
                    example=ic_examples[idx],
                    src_language=lang,
                    target_language=ic_language,
                )
            )
    return selected_examples
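# Example (hedged; `train_data` is any SQuAD-style Dataset): two random
# English examples, with the second translated into Hindi.
#   ic = choose_few_shot_examples(
#       train_dataset=train_data,
#       few_shot_size=2,
#       context=["english", "hindi"],
#       selection_criteria="random",
#       lang="english",
#   )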
def normalize_answer(s):
    """Lower text and remove punctuation, articles and extra whitespace."""

    def remove_articles(text):
        return re.sub(r"\b(a|an|the)\b", " ", text)

    def white_space_fix(text):
        return " ".join(text.split())

    def remove_punc(text):
        exclude = set(PUNCT)
        return "".join(ch for ch in text if ch not in exclude)

    def lower(text):
        return text.lower()

    return white_space_fix(remove_articles(remove_punc(lower(s))))
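# Example:
#   normalize_answer("The  quick, brown fox!")  # -> "quick brown fox"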
def process_test_example(
    test_data, config_header, idx, test_example, config, zero_shot, lang, params
):
    try:
        instruction = create_instruction(
            lang=config["prefix"],
            instruction_language=config["prefix"],
            expected_output=config["output"],
        )
        text_example = {
            "question": test_example["question"],
            "context": test_example["context"],
            "answers": test_example["answers"]["text"],
        }
        ic_examples = []
        if not zero_shot:
            ic_examples = choose_few_shot_examples(
                train_dataset=test_data,
                few_shot_size=len(config["context"]),
                context=config["context"],
                selection_criteria="random",
                lang=params["selected_language"],
            )
        prompt, label = _construct_prompt(
            instruction=instruction,
            test_example=text_example,
            ic_examples=ic_examples,
            zero_shot=zero_shot,
            lang=lang,
            config=config,
        )
        print(len(prompt))
        # get_prediction is expected to be provided elsewhere in the project
        # (it is not defined in this file).
        pred = get_prediction(
            prompt=prompt, endpoint_id=7327255438662041600, project_id=16514800572
        )
        # pred = mixtral_completion(prompt)
        print(pred)
        logger.info("Saving prediction to persistent volume")
        os.makedirs(
            f"{params['response_logger_root']}/{params['model']}/{lang}", exist_ok=True
        )
        dump_predictions(
            idx=idx,
            response=pred,
            label=label,
            response_logger_file=f"{params['response_logger_root']}/{params['model']}/{lang}/{config_header}.csv",
        )
    except Exception as e:
        print(f"Error processing example {idx}: {e}")
def run_one_configuration(params: Optional[PARAMS] = None):
    if not params:
        params = read_parameters("../../parameters.yaml")
    lang = params["selected_language"]
    config = params["config"]
    zero_shot = len(config["context"]) == 0
    rouge1, rouge2, rougeL, normalized_ic_examples, batched_predictions = (
        [],
        [],
        [],
        [],
        [],
    )
    # An empty context list means zero-shot; avoid indexing into it.
    if not zero_shot:
        config_header = f"{config['input']}_{config['prefix']}_{config['context'][0]}_{config['output']}"
    else:
        config_header = f"{config['input']}_{config['prefix']}_zero_{config['output']}"
    dataset_name = params["dataset_name"]
    squad_metric = load("squad")
    metric = params["metric"]
    f1_sum = 0
    em_sum = 0
    avg_em = 0
    avg_f1 = 0
    preds = []
    labels = []
    f1s, ems = [], []
    test_data = load_qa_dataset(
        dataset_name=params["dataset_name"],
        lang=lang,
        split="validation" if params["dataset_name"] == "xquad" else "test",
        limit=params["limit"],
    )
    for idx, test_example in tqdm(enumerate(test_data), total=len(test_data)):
        try:
            instruction = create_instruction(
                lang=config["prefix"],
                instruction_language=config["prefix"],
                expected_output=config["output"],
            )
            text_example = {
                "question": test_example["question"],
                "context": test_example["context"],
                "answers": test_example["answers"]["text"],
            }
            ic_examples = []
            if not zero_shot:
                ic_examples = choose_few_shot_examples(
                    train_dataset=test_data,
                    few_shot_size=len(config["context"]),
                    context=config["context"],
                    selection_criteria="random",
                    lang=params["selected_language"],
                )
            prompt, label = _construct_prompt(
                instruction=instruction,
                test_example=text_example,
                ic_examples=ic_examples,
                zero_shot=zero_shot,
                lang=lang,
                config=config,
            )
            # mt0_completion is expected to be provided elsewhere in the
            # project (it is not defined in this file).
            pred = mt0_completion(prompt=prompt)
            print(pred)
            # Accumulate per-example SQuAD F1/EM so the averages dumped at
            # the end are meaningful.
            result = squad_metric.compute(
                predictions=[{"id": str(idx), "prediction_text": pred}],
                references=[{"id": str(idx), "answers": test_example["answers"]}],
            )
            f1s.append(result["f1"])
            ems.append(result["exact_match"])
            f1_sum += result["f1"]
            em_sum += result["exact_match"]
            preds.append(pred)
            labels.append(label)
            logger.info("Saving prediction to persistent volume")
            os.makedirs(
                f"{params['response_logger_root']}/{params['model']}/{lang}",
                exist_ok=True,
            )
            dump_predictions(
                idx=idx,
                response=pred,
                label=label,
                response_logger_file=f"{params['response_logger_root']}/{params['model']}/{lang}/{config_header}.csv",
            )
        except Exception as e:
            print(f"Found an exception {e}, continue to the next example")
            continue
    if f1s:
        avg_f1 = f1_sum / len(f1s)
        avg_em = em_sum / len(ems)
    os.makedirs(f"{params['metrics_root']}/{params['model']}", exist_ok=True)
    dump_metrics(
        lang,
        config,
        avg_f1,
        avg_em,
        f"{params['metrics_root']}/{params['model']}/{lang}.csv",
    )
def run_one_configuration_parallel(params: Optional[PARAMS] = None, zero: bool = False):
    if not params:
        params = read_parameters("../../parameters.yaml")
    lang = params["selected_language"]
    config = params["config"]
    zero_shot = len(config["context"]) == 0
    rouge1, rouge2, rougeL, normalized_ic_examples, batched_predictions = (
        [],
        [],
        [],
        [],
        [],
    )
    if not zero:
        config_header = f"{config['input']}_{config['prefix']}_{config['context'][0]}_{config['output']}"
    else:
        config_header = f"{config['input']}_{config['prefix']}_zero_{config['output']}"
    test_data = load_qa_dataset(
        dataset_name=params["dataset_name"],
        lang=lang,
        split="validation" if params["dataset_name"] == "xquad" else "test",
        limit=params["limit"],
    )
    # Initialize a multiprocessing pool sized to the available CPU cores.
    num_processes = mp.cpu_count()
    pool = mp.Pool(processes=num_processes)
    # Iterate over test_data using tqdm for progress tracking.
    for idx, test_example in tqdm(enumerate(test_data), total=len(test_data)):
        # Process each test example asynchronously.
        pool.apply_async(
            process_test_example,
            args=(
                test_data,
                config_header,
                idx,
                test_example,
                config,
                zero_shot,
                lang,
                params,
            ),
        )
    # Close the pool and wait for all workers to finish.
    pool.close()
    pool.join()
def construct_prompt(
    instruction: str,
    test_example: dict,
    zero_shot: bool,
    num_examples: int,
    lang: str,
    config: Dict[str, str],
    dataset_name: str = "xquad",
):
    # Public prompt builder that fetches its own in-context examples; the
    # runners above use the internal _construct_prompt helper instead.
    if not instruction:
        instruction = create_instruction(
            lang=lang,
            instruction_language=config["prefix"],
            expected_output=config["output"],
        )
    example_prompt = PromptTemplate(
        input_variables=["context", "question", "answers"],
        template="Context: {context} \n Question: {question} \n Answers: {answers}",
    )
    zero_shot_template = (
        f"{instruction}" + " \n <Context>: {context} \n <Question>: {question} "
    )
    ic_examples = []
    if not zero_shot:
        try:
            test_data = load_qa_dataset(
                dataset_name=dataset_name, lang=lang, split="test", limit=100
            )
        except Exception as e:
            raise KeyError(f"{lang} is not supported in {dataset_name}") from e
        ic_examples = choose_few_shot_examples(
            train_dataset=test_data,
            few_shot_size=num_examples,
            context=[config["context"]] * num_examples,
            selection_criteria="random",
            lang=lang,
        )
    prompt = (
        FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix="<Context>: {context} \n <Question>: {question} \n Answers: ?",
            input_variables=["question", "context"],
        )
        if not zero_shot
        else PromptTemplate(
            input_variables=["question", "context"], template=zero_shot_template
        )
    )
    print("lang", lang)
    print(config["input"], lang)
    if config["input"] != lang:
        test_example = _translate_example(
            example=test_example, src_language=lang, target_language=config["input"]
        )
    return prompt.format(
        question=test_example["question"], context=test_example["context"]
    )
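# Minimal driver sketch, assuming a parameters.yaml matching the schema
# sketched near read_parameters exists at this relative path:
if __name__ == "__main__":
    params = read_parameters("../../parameters.yaml")
    run_one_configuration(params)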