File size: 3,866 Bytes
b8a6dde
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from typing import List, Dict, Any

from pathlib import Path

from utils import get_full_file_path

# SENTENCE_STOPPERS = {'!', '"', '#', '$', '%', '&', "'", '(', ')', '*', '+', ',', '-', '.', '/', ':', ';', '<', '=', '>', '?', '@', '[', '\\', ']', '^', '_', '`', '{', '|', '}', '~'}
# VIETNAMESE_SPECIAL_CHARACTERS = {'à', 'á', 'ả', 'ã', 'ạ', 'â', 'ầ', 'ấ', 'ẩ', 'ẫ', 'ậ', 'ă', 'ằ', 'ắ', 'ẳ', 'ẵ', 'ặ', 'è', 'é', 'ẻ', 'ẽ', 'ẹ', 'ê', 'ề', 'ế', 'ể', 'ễ', 'ệ', 'ì', 'í', 'ỉ', 'ĩ', 'ị', 'ò', 'ó', 'ỏ', 'õ', 'ọ', 'ô', 'ồ', 'ố', 'ổ', 'ỗ', 'ộ', 'ơ', 'ờ', 'ớ', 'ở', 'ỡ', 'ợ', 'ù', 'ú', 'ủ', 'ũ', 'ụ', 'ư', 'ừ', 'ứ', 'ử', 'ữ', 'ự', 'ỳ', 'ý', 'ỷ', 'ỹ', 'ỵ'}

# def is_Vietnamese_character(char):
#     return char.isalpha() or char in VIETNAMESE_SPECIAL_CHARACTERS

# def categorize_word(word: str) -> str:
#     """
#     Categorize a word into one of 3 types:
#     - "vi": likely Vietnamese.
#     - "lo": likely Lao.
#     - "num": a number
#     """
#     if any(char.isdigit() for char in word):
#         return "num"
    
#     for stopper in SENTENCE_STOPPERS:
#         if word.endswith(stopper):
#             word = word[:-1]
#         if len(word) == 0:
#             break
    
#     if len(word) > 0 and any(not is_Vietnamese_character(char) for char in word):
#         return "lo"
#     else:
#         return "vi"
# 
# def open_dataset(
#     dataset_filename: str, 
#     src_lang: str = "lo", 
#     tgt_lang: str = "vi"
# ) -> List[Dict[str, Dict[str,str]]]:
#     ds = []
#     file_path = get_full_file_path(dataset_filename)
#     with open(file_path, 'r', encoding='utf-8') as file:
#         lines = file.readlines()

#     for index, line in enumerate(lines):
#         line = line.split(sep=None)

#         lo_positions = [i for i, word in enumerate(line) if categorize_word(word) == "lo"]
#         if len(lo_positions) == 0:
#             # print(line)
#             continue

#         split_index = max(lo_positions)
#         assert split_index is not None, f"Dataset error on line {index+1}."

#         src_text = ' '.join(line[:split_index+1])
#         tgt_text = line[split_index+1:]
        
#         if index <= 5:
#             print(src_text, tgt_text, sep="\n", end="\n-------")

#         # TODO: post-process the tgt_text to split all numbers into single digits.
#         ds.append({'translation':{src_lang:src_text, tgt_lang:tgt_text}})
#     return ds

# open_dataset('datasets/dev_clean.dat')

def load_local_dataset(
    dataset_filename: str, 
    src_lang: str = "lo", 
    tgt_lang: str = "vi"
) -> List[Dict[str, Dict[str,str]]]:
    """Load a tab-separated parallel corpus into translation records.

    Each non-blank line must contain a source sentence and a target sentence
    separated by a tab. Only the first tab splits the line; any later tabs
    remain part of the target text. The line terminator is removed, so it
    does not end up inside the target sentence.

    Args:
        dataset_filename: Name resolved to a real path via ``get_full_file_path``.
        src_lang: Key under which the source sentence is stored.
        tgt_lang: Key under which the target sentence is stored.

    Returns:
        One ``{'translation': {src_lang: src_text, tgt_lang: tgt_text}}`` dict
        per non-blank line, in file order.

    Raises:
        ValueError: If a non-blank line contains no tab separator.
    """
    ds = []
    file_path = get_full_file_path(dataset_filename)
    with open(file_path, 'r', encoding='utf-8') as file:
        for line_number, raw_line in enumerate(file, start=1):
            # Remove the line terminator; otherwise every target text ends in "\n".
            line = raw_line.rstrip('\r\n')
            if not line.strip():
                # Tolerate blank lines, e.g. an extra empty line at end of file.
                continue
            src_text, separator, tgt_text = line.partition('\t')
            if not separator:
                raise ValueError(
                    f"{file_path}:{line_number}: expected a tab-separated "
                    f"source/target pair, but no tab was found."
                )
            ds.append({'translation': {src_lang: src_text, tgt_lang: tgt_text}})
    return ds

def load_local_bleu_dataset(
    src_dataset_filename: str, 
    tgt_dataset_filename: str,
    src_lang: str = "lo", 
    tgt_lang: str = "vi"
) -> List[Dict[str, Dict[str,str]]]:
    """Pair two line-aligned monolingual files into translation records.

    Line ``i`` of the source file is paired with line ``i`` of the target
    file. Every line is kept, including blank ones, so the alignment between
    the two files is preserved. Line terminators are removed from each
    sentence.

    Args:
        src_dataset_filename: Source-language file, resolved via ``get_full_file_path``.
        tgt_dataset_filename: Target-language file, resolved via ``get_full_file_path``.
        src_lang: Key under which the source sentence is stored.
        tgt_lang: Key under which the target sentence is stored.

    Returns:
        One ``{'translation': {src_lang: src_text, tgt_lang: tgt_text}}`` dict
        per line pair, in file order.

    Raises:
        ValueError: If the two files contain a different number of lines.
    """
    def load_local_monolanguage_dataset(dataset_filename: str) -> List[str]:
        # Read one sentence per line, without its trailing line terminator.
        file_path = get_full_file_path(dataset_filename)
        with open(file_path, 'r', encoding='utf-8') as file:
            return [line.rstrip('\r\n') for line in file]

    src_texts = load_local_monolanguage_dataset(src_dataset_filename)
    tgt_texts = load_local_monolanguage_dataset(tgt_dataset_filename)

    # Use an explicit check rather than `assert`, which `python -O` strips out.
    if len(src_texts) != len(tgt_texts):
        raise ValueError(
            f"Line count mismatch: {src_dataset_filename!r} has {len(src_texts)} "
            f"lines but {tgt_dataset_filename!r} has {len(tgt_texts)} lines."
        )

    return [
        {'translation': {src_lang: src_text, tgt_lang: tgt_text}}
        for src_text, tgt_text in zip(src_texts, tgt_texts)
    ]