Spaces:
Runtime error
Runtime error
from typing import Dict, List, Union | |
import numpy as np | |
from datasets import Dataset, load_dataset | |
from easygoogletranslate import EasyGoogleTranslate | |
from langchain.prompts import FewShotPromptTemplate, PromptTemplate | |
# Human-readable language names used throughout this module, mapped to the
# language codes handed to EasyGoogleTranslate. Lookups elsewhere in this
# module index it with [], so an unlisted name raises KeyError.
LANGUAGE_TO_SUFFIX = dict(
    chinese_simplified="zh-CN",
    french="fr",
    portuguese="pt",
    english="en",
    arabic="ar",
    hindi="hi",
    indonesian="id",
    amharic="am",
    bengali="bn",
    burmese="my",
    uzbek="uz",
    nepali="ne",
    japanese="ja",
    spanish="es",
    turkish="tr",
    persian="fa",
    azerbaijani="az",
    korean="ko",
    hebrew="he",
    telugu="te",
    german="de",
    greek="el",
    tamil="ta",
    assamese="as",
    vietnamese="vi",
    russian="ru",
    romanian="ro",
    malayalam="ml",
    swahili="sw",
    bulgarian="bg",
    thai="th",
    urdu="ur",
    italian="it",
    polish="pl",
    dutch="nl",
    swedish="sv",
    danish="da",
    norwegian="no",
    finnish="fi",
    hungarian="hu",
    czech="cs",
    slovak="sk",
    ukrainian="uk",
)
def choose_few_shot_examples(
    train_dataset: Dataset,
    few_shot_size: int,
    context: List[str],
    selection_criteria: str,
    lang: str,
) -> List[Dict[str, Union[str, int]]]:
    """Pick in-context examples and put each one into its requested language.

    Args:
        train_dataset: Indexable collection whose rows expose ``"text"`` and
            ``"summary"`` fields.
        few_shot_size: Number of rows to draw from ``train_dataset``.
        context: One language name per example. ``context[i]`` is the language
            the i-th drawn example should appear in. It must not be longer
            than ``few_shot_size``.
        selection_criteria: ``"first_k"`` takes the first ``few_shot_size``
            rows. ``"random"`` samples rows without replacement, and only
            falls back to sampling with replacement when the dataset has
            fewer rows than requested.
        lang: Language name of ``train_dataset``. Examples whose context
            language equals it are used as-is; the others are translated.

    Returns:
        A list of ``{"text": ..., "summary": ...}`` dicts, one per entry of
        ``context``.

    Raises:
        ValueError: If ``selection_criteria`` is not recognised.
    """
    if selection_criteria == "first_k":
        example_idxs = list(range(few_shot_size))
    elif selection_criteria == "random":
        # Sampling with replacement could put the same example into the prompt
        # twice; only allow it when there are not enough distinct rows.
        example_idxs = (
            np.random.choice(
                len(train_dataset),
                size=few_shot_size,
                replace=few_shot_size > len(train_dataset),
            )
            .astype(int)
            .tolist()
        )
    else:
        raise ValueError(
            f"Unknown selection_criteria {selection_criteria!r}; "
            "expected 'first_k' or 'random'"
        )

    ic_examples = []
    for idx in example_idxs:
        row = train_dataset[idx]  # fetch each row once
        ic_examples.append({"text": row["text"], "summary": row["summary"]})

    selected_examples = []
    for idx, ic_language in enumerate(context):
        example = ic_examples[idx]
        if ic_language != lang:
            translated = _translate_example(
                example=example,
                src_language=lang,
                target_language=ic_language,
            )
            # _translate_example may return None when translation fails; keep
            # the untranslated example instead of inserting a None entry.
            if translated is not None:
                example = translated
        selected_examples.append(example)
    return selected_examples
def _translate_instruction(basic_instruction: str, target_language: str) -> str:
    """Machine-translate an English instruction into ``target_language``.

    ``target_language`` is a key of LANGUAGE_TO_SUFFIX. An unknown name raises
    KeyError before the translator is constructed.
    """
    target_code = LANGUAGE_TO_SUFFIX[target_language]
    return EasyGoogleTranslate(
        source_language="en",
        target_language=target_code,
        timeout=50,
    ).translate(basic_instruction)
def _translate_example(
    example: Dict[str, str], src_language: str, target_language: str
) -> Dict[str, str]:
    """Translate a text/summary example from one language to another.

    Args:
        example: Mapping with a ``"text"`` key and optionally a ``"summary"``
            key.
        src_language: Language name (key of LANGUAGE_TO_SUFFIX) ``example`` is
            written in.
        target_language: Language name (key of LANGUAGE_TO_SUFFIX) to
            translate into.

    Returns:
        A new ``{"text": ..., "summary": ...}`` dict. The summary is translated
        too when present and non-empty, so few-shot demonstrations keep their
        reference output; otherwise it is ``""``. If the translator raises, the
        error is printed and the untranslated text and summary are returned, so
        callers that index the result do not crash on ``None``.

    Raises:
        KeyError: If a language name is not in LANGUAGE_TO_SUFFIX, or
            ``example`` has no ``"text"`` key.
    """
    translator = EasyGoogleTranslate(
        source_language=LANGUAGE_TO_SUFFIX[src_language],
        target_language=LANGUAGE_TO_SUFFIX[target_language],
        timeout=30,
    )
    text = example["text"]
    summary = example.get("summary") or ""
    try:
        translated_text = translator.translate(text)
        # Blanking the summary would leave translated few-shot examples with
        # no demonstrated output, so translate it alongside the text.
        translated_summary = translator.translate(summary) if summary else ""
    except Exception as e:  # best-effort: translation is a remote call
        print(e)
        return {"text": text, "summary": summary}
    return {"text": translated_text, "summary": translated_summary}
def create_instruction(lang: str, instruction_language: str, expected_output: str):
    """Build the summarization instruction placed at the top of the prompt.

    The instruction is written in English. It is returned unchanged when
    ``instruction_language`` is ``"english"``. For any other value it is
    machine-translated into ``lang`` (note: the target is ``lang``, not
    ``instruction_language``).

    Args:
        lang: Language name used as the translation target; also printed.
        instruction_language: ``"english"`` keeps the English text; anything
            else triggers translation.
        expected_output: Text inserted verbatim to describe the desired output.
    """
    english_text = (
        f"Write a summary of the given <Text> \n The output should be in {expected_output} "
        f"\n The output must be up to 2 sentences maximum!!!"
    )
    print(lang)
    if instruction_language != "english":
        return _translate_instruction(english_text, target_language=lang)
    return english_text
def load_xlsum_data(lang, split, limit=5):
    """Loads the xlsum dataset.

    Args:
        lang: Configuration name passed to ``load_dataset`` for
            ``csebuetnlp/xlsum``.
        split: Name of the split to return, e.g. ``"test"``.
        limit: Maximum number of leading rows to keep. Clamped to the split
            size, because ``Dataset.select`` rejects out-of-range indices.

    Returns:
        The first ``min(limit, len(split))`` rows of the requested split.
    """
    dataset = load_dataset("csebuetnlp/xlsum", lang)[split]
    return dataset.select(range(min(limit, len(dataset))))
def construct_prompt(
    instruction: str,
    test_example: dict,
    zero_shot: bool,
    dataset: str,
    num_examples: int,
    lang: str,
    config: Dict[str, str],
):
    """Build the formatted summarization prompt for ``test_example``.

    Args:
        instruction: Instruction text. If empty, one is generated with
            ``create_instruction(lang, config["prefix"], config["output"])``.
        test_example: Mapping whose ``"text"`` value is substituted into the
            prompt.
        zero_shot: When true, no in-context examples are added.
        dataset: Not used by this function; kept for interface compatibility.
        num_examples: Number of in-context examples (few-shot only).
        lang: Language name of the XL-Sum data and of ``test_example``.
        config: Keys read are ``"prefix"`` and ``"output"`` (only when
            ``instruction`` is empty), ``"context"`` (few-shot only), and
            ``"input"``.

    Returns:
        The prompt string. ``test_example["text"]`` is first translated into
        ``config["input"]`` when that differs from ``lang``.

    Raises:
        KeyError: In few-shot mode, if the XL-Sum data for ``lang`` cannot be
            loaded. The original error is chained as the cause.
    """
    if not instruction:
        print(lang)
        instruction = create_instruction(lang, config["prefix"], config["output"])

    if zero_shot:
        zero_shot_template = f"{instruction}" + "\n Input: {text} "
        prompt = PromptTemplate(input_variables=["text"], template=zero_shot_template)
    else:
        try:
            test_data = load_xlsum_data(lang=lang, split="test", limit=100)
        except Exception as e:
            raise KeyError(
                f"{lang} is not supported in XlSum dataset, choose supported language in few-shot"
            ) from e
        ic_examples = choose_few_shot_examples(
            train_dataset=test_data,
            few_shot_size=num_examples,
            context=[config["context"]] * num_examples,
            selection_criteria="random",
            lang=lang,
        )
        example_prompt = PromptTemplate(
            input_variables=["summary", "text"],
            template="Text: {text}\nSummary: {summary}",
        )
        prompt = FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix="<Text>: {text}",
            input_variables=["text"],
        )

    print("lang", lang)
    print(config["input"], lang)
    if config["input"] != lang:
        translated = _translate_example(
            example=test_example, src_language=lang, target_language=config["input"]
        )
        # _translate_example may return None when translation fails; fall back
        # to the untranslated input rather than crashing on None["text"].
        if translated is not None:
            test_example = translated
    print("test_example", prompt)
    return prompt.format(text=test_example["text"])