import logging
from typing import Dict, List, Union

import numpy as np
from datasets import Dataset, load_dataset
from easygoogletranslate import EasyGoogleTranslate
from langchain.prompts import FewShotPromptTemplate, PromptTemplate

logger = logging.getLogger(__name__)

# Language names accepted by this module mapped to the language codes passed
# to EasyGoogleTranslate. The same names are also used as the XL-Sum config
# name in `load_xlsum_data`.
LANGUAGE_TO_SUFFIX = {
    "chinese_simplified": "zh-CN",
    "french": "fr",
    "portuguese": "pt",
    "english": "en",
    "arabic": "ar",
    "hindi": "hi",
    "indonesian": "id",
    "amharic": "am",
    "bengali": "bn",
    "burmese": "my",
    "uzbek": "uz",
    "nepali": "ne",
    "japanese": "ja",
    "spanish": "es",
    "turkish": "tr",
    "persian": "fa",
    "azerbaijani": "az",
    "korean": "ko",
    "hebrew": "he",
    "telugu": "te",
    "german": "de",
    "greek": "el",
    "tamil": "ta",
    "assamese": "as",
    "vietnamese": "vi",
    "russian": "ru",
    "romanian": "ro",
    "malayalam": "ml",
    "swahili": "sw",
    "bulgarian": "bg",
    "thai": "th",
    "urdu": "ur",
    "italian": "it",
    "polish": "pl",
    "dutch": "nl",
    "swedish": "sv",
    "danish": "da",
    "norwegian": "no",
    "finnish": "fi",
    "hungarian": "hu",
    "czech": "cs",
    "slovak": "sk",
    "ukrainian": "uk",
}


def choose_few_shot_examples(
    train_dataset: Dataset,
    few_shot_size: int,
    context: List[str],
    selection_criteria: str,
    lang: str,
) -> List[Dict[str, Union[str, int]]]:
    """Pick in-context examples from *train_dataset*, translating where needed.

    Args:
        train_dataset: Dataset whose rows expose "text" and "summary" fields.
        few_shot_size: Number of rows to draw from *train_dataset*.
        context: Target language for each example, one entry per example.
            Examples whose target equals *lang* are used unchanged; the others
            are machine-translated (text and summary) from *lang*.
        selection_criteria: "first_k" takes the first *few_shot_size* rows;
            "random" samples row indices uniformly *with replacement*, so the
            same row may appear more than once.
        lang: Language of *train_dataset*, a key of LANGUAGE_TO_SUFFIX.

    Returns:
        One ``{"text": ..., "summary": ...}`` dict per entry of *context*.

    Raises:
        ValueError: If *selection_criteria* is unknown, or *context* asks for
            more examples than *few_shot_size* provides.
    """
    if selection_criteria == "first_k":
        example_idxs = list(range(few_shot_size))
    elif selection_criteria == "random":
        example_idxs = (
            np.random.choice(len(train_dataset), size=few_shot_size, replace=True)
            .astype(int)
            .tolist()
        )
    else:
        raise ValueError(
            f"Unknown selection_criteria {selection_criteria!r}; "
            "expected 'first_k' or 'random'"
        )

    if len(context) > few_shot_size:
        raise ValueError(
            f"context has {len(context)} entries but only {few_shot_size} "
            "examples were selected"
        )

    ic_examples = []
    for idx in example_idxs:
        row = train_dataset[idx]
        ic_examples.append({"text": row["text"], "summary": row["summary"]})

    selected_examples = []
    for example, ic_language in zip(ic_examples, context):
        if ic_language == lang:
            selected_examples.append(example)
        else:
            # Translate the summary too: an in-context example with an empty
            # answer gives the model nothing to imitate.
            selected_examples.append(
                _translate_example(
                    example=example,
                    src_language=lang,
                    target_language=ic_language,
                    include_summary=True,
                )
            )
    return selected_examples


def _translate_instruction(basic_instruction: str, target_language: str) -> str:
    """Translate an English instruction into *target_language*.

    *target_language* must be a key of LANGUAGE_TO_SUFFIX (KeyError otherwise).
    """
    translator = EasyGoogleTranslate(
        source_language="en",
        target_language=LANGUAGE_TO_SUFFIX[target_language],
        timeout=50,
    )
    return translator.translate(basic_instruction)


def _translate_example(
    example: Dict[str, str],
    src_language: str,
    target_language: str,
    include_summary: bool = False,
) -> Dict[str, str]:
    """Translate an example's "text" (and optionally "summary") field.

    Args:
        example: Mapping with a "text" key, plus "summary" when
            *include_summary* is true.
        src_language: Language of *example*, a key of LANGUAGE_TO_SUFFIX.
        target_language: Language to translate into, a key of
            LANGUAGE_TO_SUFFIX.
        include_summary: When true, the "summary" field is translated as well;
            otherwise the returned "summary" is the empty string.

    Returns:
        A new ``{"text": ..., "summary": ...}`` dict.

    Raises:
        Whatever the translator raises; the error is logged first instead of
        being swallowed (swallowing it used to hand ``None`` to callers).
    """
    translator = EasyGoogleTranslate(
        source_language=LANGUAGE_TO_SUFFIX[src_language],
        target_language=LANGUAGE_TO_SUFFIX[target_language],
        timeout=30,
    )
    try:
        text = translator.translate(example["text"])
        summary = translator.translate(example["summary"]) if include_summary else ""
    except Exception:
        logger.exception(
            "Failed to translate example from %s to %s", src_language, target_language
        )
        raise
    return {"text": text, "summary": summary}


def create_instruction(lang: str, instruction_language: str, expected_output: str) -> str:
    """Build the summarisation instruction, translated unless it is English.

    Args:
        lang: Language the instruction is translated into when
            *instruction_language* is not "english".
        instruction_language: "english" keeps the English instruction; any
            other value triggers translation.
        expected_output: Interpolated into the instruction to describe the
            required output.
    """
    basic_instruction = (
        f"Write a summary of the given text \n The output should be in {expected_output} "
        "\n The output must be up to 2 sentences maximum!!!"
    )
    if instruction_language == "english":
        return basic_instruction
    # NOTE(review): translation targets *lang*, not *instruction_language*,
    # for every non-"english" value — confirm this is intended.
    return _translate_instruction(basic_instruction, target_language=lang)


def load_xlsum_data(lang: str, split: str, limit: int = 5) -> Dataset:
    """Load at most *limit* leading rows of *split* of the XL-Sum *lang* config."""
    dataset = load_dataset("csebuetnlp/xlsum", lang)[split]
    # Clamp so splits smaller than *limit* don't make `select` fail.
    return dataset.select(range(min(limit, len(dataset))))


def construct_prompt(
    instruction: str,
    test_example: dict,
    zero_shot: bool,
    dataset: str,
    num_examples: int,
    lang: str,
    config: Dict[str, str],
) -> str:
    """Render the final summarisation prompt for *test_example*.

    Args:
        instruction: Instruction text; when empty, one is generated with
            ``create_instruction(lang, config["prefix"], config["output"])``.
        test_example: Mapping with a "text" key, written in *lang*.
        zero_shot: When true, no in-context examples are included.
        dataset: Unused here; kept for interface compatibility.
        num_examples: Number of in-context examples (few-shot mode only).
        lang: Language of *test_example*, a key of LANGUAGE_TO_SUFFIX.
        config: Reads "prefix", "output", "context" (language of the
            in-context examples) and "input" (language the test text is
            translated into before formatting).

    Returns:
        The formatted prompt string.

    Raises:
        KeyError: In few-shot mode, if the XL-Sum data for *lang* cannot be
            loaded.
    """
    if not instruction:
        instruction = create_instruction(lang, config["prefix"], config["output"])

    if zero_shot:
        prompt = PromptTemplate(
            input_variables=["text"],
            template=f"{instruction}" + "\n Input: {text} ",
        )
    else:
        try:
            # NOTE(review): in-context examples are drawn from the *test*
            # split, which may overlap with evaluated examples — confirm.
            test_data = load_xlsum_data(lang=lang, split="test", limit=100)
        except Exception as err:
            raise KeyError(
                f"{lang} is not supported in XlSum dataset, choose supported language in few-shot"
            ) from err

        ic_examples = choose_few_shot_examples(
            train_dataset=test_data,
            few_shot_size=num_examples,
            context=[config["context"]] * num_examples,
            selection_criteria="random",
            lang=lang,
        )
        example_prompt = PromptTemplate(
            input_variables=["summary", "text"],
            template="Text: {text}\nSummary: {summary}",
        )
        prompt = FewShotPromptTemplate(
            examples=ic_examples,
            prefix=instruction,
            example_prompt=example_prompt,
            suffix=": {text}",
            input_variables=["text"],
        )

    logger.debug("lang=%s input=%s", lang, config["input"])
    if config["input"] != lang:
        test_example = _translate_example(
            example=test_example, src_language=lang, target_language=config["input"]
        )
    logger.debug("prompt template: %s", prompt)
    return prompt.format(text=test_example["text"])