# selective_pre_translation/tasks/summarization.py
# Author: Anonymous
# Commit d27fe32: "format and clean code"
from typing import Dict, List, Union
import numpy as np
from datasets import Dataset, load_dataset
from easygoogletranslate import EasyGoogleTranslate
from langchain.prompts import FewShotPromptTemplate, PromptTemplate
# Language name (as used by the rest of this module) -> language code that is
# handed to EasyGoogleTranslate as its source/target language.
LANGUAGE_TO_SUFFIX = dict(
    chinese_simplified="zh-CN",
    french="fr",
    portuguese="pt",
    english="en",
    arabic="ar",
    hindi="hi",
    indonesian="id",
    amharic="am",
    bengali="bn",
    burmese="my",
    uzbek="uz",
    nepali="ne",
    japanese="ja",
    spanish="es",
    turkish="tr",
    persian="fa",
    azerbaijani="az",
    korean="ko",
    hebrew="he",
    telugu="te",
    german="de",
    greek="el",
    tamil="ta",
    assamese="as",
    vietnamese="vi",
    russian="ru",
    romanian="ro",
    malayalam="ml",
    swahili="sw",
    bulgarian="bg",
    thai="th",
    urdu="ur",
    italian="it",
    polish="pl",
    dutch="nl",
    swedish="sv",
    danish="da",
    norwegian="no",
    finnish="fi",
    hungarian="hu",
    czech="cs",
    slovak="sk",
    ukrainian="uk",
)
def choose_few_shot_examples(
    train_dataset: Dataset,
    few_shot_size: int,
    context: List[str],
    selection_criteria: str,
    lang: str,
) -> List[Dict[str, Union[str, int]]]:
    """Pick in-context examples and translate each into its context language.

    Args:
        train_dataset: Indexable collection whose rows expose ``"text"`` and
            ``"summary"`` fields.
        few_shot_size: Number of candidate rows to draw.
        context: Language name for each in-context example; entry ``i`` is
            the language the ``i``-th drawn example is rendered in.
        selection_criteria: ``"first_k"`` takes rows ``0 .. few_shot_size-1``;
            ``"random"`` samples row indices with ``np.random.choice`` using
            ``replace=True``.
        lang: Language of ``train_dataset``. Examples whose context language
            equals ``lang`` are used as-is; all others go through
            ``_translate_example``.

    Returns:
        One ``{"text": ..., "summary": ...}`` dict per entry of ``context``.

    Raises:
        ValueError: If ``selection_criteria`` is not recognised, or if
            ``context`` has more entries than ``few_shot_size``.
    """
    if selection_criteria == "first_k":
        example_idxs = list(range(few_shot_size))
    elif selection_criteria == "random":
        # NOTE(review): replace=True means the same row can be drawn more than
        # once, duplicating a demonstration in the prompt — confirm intended.
        example_idxs = (
            np.random.choice(len(train_dataset), size=few_shot_size, replace=True)
            .astype(int)
            .tolist()
        )
    else:
        # Previously an unknown value silently produced no indices and then
        # failed later with an unrelated IndexError.
        raise ValueError(
            f"Unknown selection_criteria {selection_criteria!r}; "
            "expected 'first_k' or 'random'"
        )

    # Fail before doing any (network-bound) translation work.
    if len(context) > len(example_idxs):
        raise ValueError(
            f"context has {len(context)} entries but only "
            f"{len(example_idxs)} examples were selected"
        )

    ic_examples = [
        {"text": train_dataset[idx]["text"], "summary": train_dataset[idx]["summary"]}
        for idx in example_idxs
    ]

    selected_examples = []
    for ic_example, ic_language in zip(ic_examples, context):
        if ic_language == lang:
            selected_examples.append(ic_example)
        else:
            selected_examples.append(
                _translate_example(
                    example=ic_example,
                    src_language=lang,
                    target_language=ic_language,
                )
            )
    return selected_examples
def _translate_instruction(basic_instruction: str, target_language: str) -> str:
    """Machine-translate an English instruction into ``target_language``.

    ``target_language`` must be a key of ``LANGUAGE_TO_SUFFIX`` (an unknown
    name raises ``KeyError``). A new ``EasyGoogleTranslate`` client with
    ``timeout=50`` is created on every call.
    """
    target_code = LANGUAGE_TO_SUFFIX[target_language]
    client = EasyGoogleTranslate(
        source_language="en", target_language=target_code, timeout=50
    )
    translated = client.translate(basic_instruction)
    return translated
def _translate_example(
    example: Dict[str, str], src_language: str, target_language: str
) -> Dict[str, str]:
    """Translate one ``{"text", "summary"}`` example between two languages.

    Args:
        example: Mapping with a ``"text"`` key and, optionally, ``"summary"``.
        src_language: ``LANGUAGE_TO_SUFFIX`` key of the example's language.
        target_language: ``LANGUAGE_TO_SUFFIX`` key to translate into.

    Returns:
        A new dict with the translated ``"text"`` and ``"summary"``. The
        summary is translated only when present and non-empty; otherwise it
        is ``""``. (Previously the summary was always blanked, leaving
        translated few-shot demonstrations without an answer.)

    Raises:
        KeyError: If a language name is missing from ``LANGUAGE_TO_SUFFIX``
            or ``example`` has no ``"text"``.
        RuntimeError: If the translation call fails; the original exception
            is chained. Previously the error was printed and ``None`` was
            returned, which crashed callers later with a less useful error.
    """
    translator = EasyGoogleTranslate(
        source_language=LANGUAGE_TO_SUFFIX[src_language],
        target_language=LANGUAGE_TO_SUFFIX[target_language],
        timeout=30,
    )
    text = example["text"]
    summary = example.get("summary", "")
    try:
        translated_text = translator.translate(text)
        translated_summary = translator.translate(summary) if summary else ""
    except Exception as err:  # translator's exception types are not documented here
        raise RuntimeError(
            f"Failed to translate example from {src_language} to {target_language}"
        ) from err
    return {"text": translated_text, "summary": translated_summary}
def create_instruction(lang: str, instruction_language: str, expected_output: str):
    """Build the summarization instruction, translating it when requested.

    Args:
        lang: ``LANGUAGE_TO_SUFFIX`` key used as the translation target when
            ``instruction_language`` is not ``"english"``.
        instruction_language: ``"english"`` returns the English instruction
            unchanged; any other value triggers translation into ``lang``.
        expected_output: Inserted verbatim after "The output should be in".

    Returns:
        The instruction string (English or machine-translated).
    """
    basic_instruction = (
        f"Write a summary of the given <Text> \n The output should be in {expected_output} "
        f"\n The output must be up to 2 sentences maximum!!!"
    )
    if instruction_language == "english":
        return basic_instruction
    # NOTE(review): the target is `lang`, not `instruction_language` — confirm
    # that non-English instructions are always meant to be in the data language.
    return _translate_instruction(basic_instruction, target_language=lang)
def load_xlsum_data(lang, split, limit=5):
    """Load (a prefix of) one split of the XL-Sum dataset.

    Args:
        lang: XL-Sum configuration name passed to ``load_dataset``.
        split: Split name to index into, e.g. ``"test"``.
        limit: Maximum number of leading rows to keep. It is clamped to the
            split size, because ``Dataset.select`` raises on out-of-range
            indices. ``None`` keeps the whole split.

    Returns:
        A ``datasets.Dataset`` with at most ``limit`` rows.
    """
    dataset = load_dataset("csebuetnlp/xlsum", lang)[split]
    if limit is None:
        return dataset
    return dataset.select(range(min(limit, len(dataset))))
def construct_prompt(
    instruction: str,
    test_example: dict,
    zero_shot: bool,
    dataset: str,
    num_examples: int,
    lang: str,
    config: Dict[str, str],
):
    """Render the summarization prompt for a single test example.

    Args:
        instruction: Instruction text. If falsy, one is generated via
            ``create_instruction(lang, config["prefix"], config["output"])``.
        test_example: Row with at least a ``"text"`` field.
        zero_shot: When true, no in-context examples are included.
        dataset: Unused here; presumably kept for a shared task interface.
        num_examples: Number of in-context examples in the few-shot setting.
        lang: XL-Sum language of ``test_example``.
        config: Reads ``"prefix"`` and ``"output"`` (only when ``instruction``
            is empty), ``"context"`` (few-shot only; the language of the
            demonstrations), and ``"input"`` (language the test text is
            translated into if it differs from ``lang``).

    Returns:
        The fully formatted prompt string.

    Raises:
        KeyError: In few-shot mode, if XL-Sum data for ``lang`` cannot be loaded.
    """
    if not instruction:
        instruction = create_instruction(lang, config["prefix"], config["output"])

    if zero_shot:
        prompt = PromptTemplate(
            input_variables=["text"], template=instruction + "\n Input: {text} "
        )
    else:
        # NOTE(review): demonstrations are drawn from the *test* split, so the
        # current test example may appear as its own demonstration — confirm.
        try:
            test_data = load_xlsum_data(lang=lang, split="test", limit=100)
        except Exception as err:
            raise KeyError(
                f"{lang} is not supported in XlSum dataset, choose supported language in few-shot"
            ) from err
        ic_examples = choose_few_shot_examples(
            train_dataset=test_data,
            few_shot_size=num_examples,
            context=[config["context"]] * num_examples,
            selection_criteria="random",
            lang=lang,
        )
        example_prompt = PromptTemplate(
            input_variables=["summary", "text"],
            template="Text: {text}\nSummary: {summary}",
        )
        prompt = FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix="<Text>: {text}",
            input_variables=["text"],
        )

    # The test text is shown to the model in config["input"]'s language.
    if config["input"] != lang:
        test_example = _translate_example(
            example=test_example, src_language=lang, target_language=config["input"]
        )
    return prompt.format(text=test_example["text"])