homemade_lo_vi / load_dataset.py
moiduy04's picture
Upload 12 files
b8a6dde
raw
history blame
3.87 kB
from typing import List, Dict, Any
from pathlib import Path
from utils import get_full_file_path
# SENTENCE_STOPPERS = {'!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~'}
# VIETNAMESE_SPECIAL_CHARACTERS = {'à', 'á', 'ả', 'ã', 'ạ', 'â', 'ầ', 'ấ', 'ẩ', 'ẫ', 'ậ', 'ă', 'ằ', 'ắ', 'ẳ', 'ẵ', 'ặ', 'è', 'é', 'ẻ', 'ẽ', 'ẹ', 'ê', 'ề', 'ế', 'ể', 'ễ', 'ệ', 'ì', 'í', 'ỉ', 'ĩ', 'ị', 'ò', 'ó', 'ỏ', 'õ', 'ọ', 'ô', 'ồ', 'ố', 'ổ', 'ỗ', 'ộ', 'ơ', 'ờ', 'ớ', 'ở', 'ỡ', 'ợ', 'ù', 'ú', 'ủ', 'ũ', 'ụ', 'ư', 'ừ', 'ứ', 'ử', 'ữ', 'ự', 'ỳ', 'ý', 'ỷ', 'ỹ', 'ỵ'}
# def is_Vietnamese_character(char):
# return char.isalpha() or char in VIETNAMESE_SPECIAL_CHARACTERS
# def categorize_word(word: str) -> str:
# """
# Categorize a word into 3 types:
# - "vi": likely Vietnamese.
# - "lo": likely Lao.
# - "num": a number
# """
# if any(char.isdigit() for char in word):
# return "num"
# for stopper in SENTENCE_STOPPERS:
# if word.endswith(stopper):
# word = word[:-1]
# if len(word) == 0:
# break
# if len(word) > 0 and any(not is_Vietnamese_character(char) for char in word):
# return "lo"
# else:
# return "vi"
#
# def open_dataset(
# dataset_filename: str,
# src_lang: str = "lo",
# tgt_lang: str = "vi"
# ) -> List[Dict[str, Dict[str,str]]]:
# ds = []
# file_path = get_full_file_path(dataset_filename)
# with open(file_path, 'r', encoding='utf-8') as file:
# lines = file.readlines()
# for index, line in enumerate(lines):
# line = line.split(sep=None)
# lo_positions = [i for i, word in enumerate(line) if categorize_word(word) == "lo"]
# if len(lo_positions) == 0:
# # print(line)
# continue
# split_index = max(lo_positions)
# assert split_index is not None, f"Dataset error on line {index+1}."
# src_text = ' '.join(line[:split_index+1])
# tgt_text = line[split_index+1:]
# if index <= 5:
# print(src_text, tgt_text, sep="\n", end="\n-------")
# # TODO: post-process the tgt_text to split all numbers into single digits.
# ds.append({'translation':{src_lang:src_text, tgt_lang:tgt_text}})
# return ds
# open_dataset('datasets/dev_clean.dat')
def load_local_dataset(
    dataset_filename: str,
    src_lang: str = "lo",
    tgt_lang: str = "vi"
) -> List[Dict[str, Dict[str, str]]]:
    """Load a tab-separated parallel corpus from a local file.

    Each non-blank line must hold a source sentence and a target sentence
    separated by a tab. Only the first tab splits the line, so any further
    tabs remain part of the target text. Blank (whitespace-only) lines are
    skipped.

    Args:
        dataset_filename: Dataset path; it is passed through
            ``get_full_file_path`` before the file is opened.
        src_lang: Key under which the source sentence is stored.
        tgt_lang: Key under which the target sentence is stored.

    Returns:
        One ``{'translation': {src_lang: ..., tgt_lang: ...}}`` dict per
        non-blank line, in file order. Neither text carries the line's
        trailing newline.

    Raises:
        ValueError: If a non-blank line contains no tab separator.
    """
    ds = []
    file_path = get_full_file_path(dataset_filename)
    with open(file_path, 'r', encoding='utf-8') as file:
        # Iterate the file directly instead of materialising every line first.
        for line_number, line in enumerate(file, start=1):
            # Drop the line terminator so it does not leak into the target text.
            line = line.rstrip("\n")
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing empty line at EOF).
                continue
            src_text, separator, tgt_text = line.partition("\t")
            if not separator:
                raise ValueError(
                    f"{file_path}: line {line_number} has no tab separating "
                    f"the source and target text"
                )
            ds.append({'translation': {src_lang: src_text, tgt_lang: tgt_text}})
    return ds
def load_local_bleu_dataset(
    src_dataset_filename: str,
    tgt_dataset_filename: str,
    src_lang: str = "lo",
    tgt_lang: str = "vi"
) -> List[Dict[str, Dict[str, str]]]:
    """Build a parallel dataset from two line-aligned monolingual files.

    Line ``i`` of the source file is paired with line ``i`` of the target
    file. Blank lines are kept (as empty strings) because dropping them
    would shift the alignment between the two files.

    Args:
        src_dataset_filename: Path of the source-language file; it is passed
            through ``get_full_file_path`` before the file is opened.
        tgt_dataset_filename: Path of the target-language file, resolved the
            same way.
        src_lang: Key under which each source sentence is stored.
        tgt_lang: Key under which each target sentence is stored.

    Returns:
        One ``{'translation': {src_lang: ..., tgt_lang: ...}}`` dict per
        line pair, in file order. Texts do not carry the trailing newline.

    Raises:
        AssertionError: If the two files do not have the same number of
            lines.
    """

    def load_local_monolanguage_dataset(dataset_filename: str) -> List[str]:
        """Return the file's lines, in order, without their line terminators."""
        file_path = get_full_file_path(dataset_filename)
        with open(file_path, 'r', encoding='utf-8') as file:
            return [line.rstrip("\n") for line in file]

    src_texts = load_local_monolanguage_dataset(src_dataset_filename)
    tgt_texts = load_local_monolanguage_dataset(tgt_dataset_filename)

    # Raised explicitly (rather than via `assert`) so the check still runs
    # under `python -O`; AssertionError is kept so callers see the same
    # exception type as before.
    if len(src_texts) != len(tgt_texts):
        raise AssertionError(
            f"Line count mismatch: {src_dataset_filename} has "
            f"{len(src_texts)} lines but {tgt_dataset_filename} has "
            f"{len(tgt_texts)}"
        )

    return [
        {'translation': {src_lang: src_text, tgt_lang: tgt_text}}
        for src_text, tgt_text in zip(src_texts, tgt_texts)
    ]