import chainer
import h5py
import logging
import numpy as np
import os
import random
import six

from chainer.training import extension
from tqdm import tqdm


def load_dataset(path, label_dict, outdir=None):
    """Load and save HDF5 that contains a dataset and stats for LM

    Args:
        path (str): The path of an input text dataset file
        label_dict (dict[str, int]):
            dictionary that maps token label string to its ID number
        outdir (str): The path of an output dir

    Returns:
        tuple[list[np.ndarray], int, int]: Tuple of
            token IDs in np.int32 converted by `read_tokens`,
            the number of tokens by `count_tokens`,
            and the number of OOVs by `count_tokens`
    """
    if outdir is not None:
        os.makedirs(outdir, exist_ok=True)
        filename = outdir + "/" + os.path.basename(path) + ".h5"
        if os.path.exists(filename):
            logging.info(f"loading binary dataset: {filename}")
            with h5py.File(filename, "r") as f:
                return f["data"][:], f["n_tokens"][()], f["n_oovs"][()]
    else:
        logging.info("skip dump/load HDF5 because the output dir is not specified")
    logging.info(f"reading text dataset: {path}")
    ret = read_tokens(path, label_dict)
    n_tokens, n_oovs = count_tokens(ret, label_dict["<unk>"])
    if outdir is not None:
        logging.info(f"saving binary dataset: {filename}")
        with h5py.File(filename, "w") as f:
            # token ID sequences have different lengths, so store them with
            # h5py's variable-length (vlen) special dtype
            data = f.create_dataset(
                "data", (len(ret),), dtype=h5py.special_dtype(vlen=np.int32)
            )
            data[:] = ret
            f["n_tokens"] = n_tokens
            f["n_oovs"] = n_oovs
    return ret, n_tokens, n_oovs
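# Usage sketch (illustrative, not part of the original module): the file path,
# output directory and vocabulary below are hypothetical.
#
#   label_dict = {"<unk>": 0, "hello": 1, "world": 2}
#   data, n_tokens, n_oovs = load_dataset("train.txt", label_dict, outdir="dump/lm")
#
# When outdir is given, a cached copy is written to dump/lm/train.txt.h5 and
# reused on the next call.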
def read_tokens(filename, label_dict):
    """Read tokens as a sequence of sentences

    :param str filename: The name of the input file
    :param dict label_dict: dictionary that maps token label string to its ID number
    :return: list of ID sequences
    :rtype: list
    """
    data = []
    unk = label_dict["<unk>"]
    with open(filename, "r", encoding="utf-8") as f:
        for ln in tqdm(f):
            data.append(
                np.array(
                    [label_dict.get(label, unk) for label in ln.split()],
                    dtype=np.int32,
                )
            )
    return data
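# Example (illustrative): with label_dict = {"<unk>": 0, "a": 1, "b": 2},
# a text file containing the single line "a b c" yields
# [np.array([1, 2, 0], dtype=np.int32)], since the unseen token "c" is mapped
# to the <unk> id.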
def count_tokens(data, unk_id=None):
    """Count tokens and OOVs in token ID sequences.

    Args:
        data (list[np.ndarray]): list of token ID sequences
        unk_id (int): ID of unknown token

    Returns:
        tuple: tuple of the number of token occurrences and the number of OOV tokens

    """
    n_tokens = 0
    n_oovs = 0
    for sentence in data:
        n_tokens += len(sentence)
        if unk_id is not None:
            n_oovs += np.count_nonzero(sentence == unk_id)
    return n_tokens, n_oovs
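# Example (illustrative), continuing the toy data above:
#
#   count_tokens([np.array([1, 2, 0], dtype=np.int32)], unk_id=0)  # -> (3, 1)
#
# i.e. three tokens in total, one of which is an OOV (<unk>).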
def compute_perplexity(result):
    """Compute and add the perplexity to the LogReport

    :param dict result: The current observations
    """
    # perplexity is the exponential of the average loss per token
    result["perplexity"] = np.exp(result["main/loss"] / result["main/count"])
    if "validation/main/loss" in result:
        result["val_perplexity"] = np.exp(result["validation/main/loss"])
class ParallelSentenceIterator(chainer.dataset.Iterator):
    """Dataset iterator to create a batch of sentences.

    Each batch element is a pair of token ID sequences in which the target is
    shifted by one token relative to the input, e.g. '<sos> w1 w2 w3' and
    'w1 w2 w3 <eos>'. Batches are built from sentences sorted by length
    (longest first) and can then be randomly shuffled.
    """

    def __init__(
        self, dataset, batch_size, max_length=0, sos=0, eos=0, repeat=True, shuffle=True
    ):
        self.dataset = dataset
        self.batch_size = batch_size
        # Number of completed sweeps over the dataset. It is incremented once
        # every sentence has been visited after the last increment.
        self.epoch = 0
        # True if the epoch was incremented by the last call of __next__.
        self.is_new_epoch = False
        self.repeat = repeat
        length = len(dataset)
        self.batch_indices = []
        # make mini-batches from sentences sorted by length (longest first)
        if batch_size > 1:
            indices = sorted(range(len(dataset)), key=lambda i: -len(dataset[i]))
            bs = 0
            while bs < length:
                be = min(bs + batch_size, length)
                # the batch size is reduced automatically if the sentence
                # length exceeds max_length
                if max_length > 0:
                    sent_length = len(dataset[indices[bs]])
                    be = min(
                        be, bs + max(batch_size // (sent_length // max_length + 1), 1)
                    )
                self.batch_indices.append(np.array(indices[bs:be]))
                bs = be
            if shuffle:
                # shuffle the order of mini-batches
                random.shuffle(self.batch_indices)
        else:
            # batch_size == 1: one sentence per mini-batch
            self.batch_indices = [np.array([i]) for i in six.moves.range(length)]

        # NOTE: this is a count of __next__ calls, not of parameter updates
        self.iteration = 0
        self.sos = sos
        self.eos = eos
        # use -1 instead of None internally so the value can be serialized
        self._previous_epoch_detail = -1.0

    def __next__(self):
        # Returns a list representing a mini-batch. Each item is a pair of
        # token ID sequences like '<sos> w1 w2 w3' and 'w1 w2 w3 <eos>'.
        n_batches = len(self.batch_indices)
        if not self.repeat and self.iteration >= n_batches:
            # If not self.repeat, this iterator stops at the end of the first
            # epoch (i.e., when every sentence has been visited once).
            raise StopIteration

        batch = []
        for idx in self.batch_indices[self.iteration % n_batches]:
            batch.append(
                (
                    np.append([self.sos], self.dataset[idx]),
                    np.append(self.dataset[idx], [self.eos]),
                )
            )

        self._previous_epoch_detail = self.epoch_detail
        self.iteration += 1

        epoch = self.iteration // n_batches
        self.is_new_epoch = self.epoch < epoch
        if self.is_new_epoch:
            self.epoch = epoch

        return batch

    def start_shuffle(self):
        random.shuffle(self.batch_indices)

    @property
    def epoch_detail(self):
        # floating-point number of epochs elapsed
        return self.iteration / len(self.batch_indices)

    @property
    def previous_epoch_detail(self):
        if self._previous_epoch_detail < 0:
            return None
        return self._previous_epoch_detail

    def serialize(self, serializer):
        # serialize/deserialize the iterator state so it can be restored on resume
        self.iteration = serializer("iteration", self.iteration)
        self.epoch = serializer("epoch", self.epoch)
        try:
            self._previous_epoch_detail = serializer(
                "previous_epoch_detail", self._previous_epoch_detail
            )
        except KeyError:
            # guess previous_epoch_detail for older snapshots that did not
            # store it; equivalent to epoch + (position_in_epoch - 1) / n_batches
            self._previous_epoch_detail = (self.iteration - 1) / len(
                self.batch_indices
            )
            if self.epoch_detail > 0:
                self._previous_epoch_detail = max(self._previous_epoch_detail, 0.0)
            else:
                self._previous_epoch_detail = -1.0
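# Usage sketch (illustrative): building batches from the output of read_tokens,
# with hypothetical <sos>/<eos> ids.
#
#   train = read_tokens("train.txt", label_dict)
#   it = ParallelSentenceIterator(train, batch_size=32, sos=sos_id, eos=eos_id)
#   batch = next(it)  # list of (input, target) pairs shifted by one token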
class MakeSymlinkToBestModel(extension.Extension):
    """Extension that makes a symbolic link to the best model

    :param str key: Key of the reported value to monitor
    :param str prefix: Prefix of model files and link target
    :param str suffix: Suffix of link target
    """

    def __init__(self, key, prefix="model", suffix="best"):
        super(MakeSymlinkToBestModel, self).__init__()
        self.best_model = -1
        self.min_loss = 0.0
        self.key = key
        self.prefix = prefix
        self.suffix = suffix

    def __call__(self, trainer):
        observation = trainer.observation
        if self.key in observation:
            loss = observation[self.key]
            if self.best_model == -1 or loss < self.min_loss:
                self.min_loss = loss
                self.best_model = trainer.updater.epoch
                src = "%s.%d" % (self.prefix, self.best_model)
                dest = os.path.join(trainer.out, "%s.%s" % (self.prefix, self.suffix))
                # replace any existing link so it always points to the best model
                if os.path.lexists(dest):
                    os.remove(dest)
                os.symlink(src, dest)
                logging.info("best model is " + src)

    def serialize(self, serializer):
        if isinstance(serializer, chainer.serializer.Serializer):
            # save the current state
            serializer("_best_model", self.best_model)
            serializer("_min_loss", self.min_loss)
            serializer("_key", self.key)
            serializer("_prefix", self.prefix)
            serializer("_suffix", self.suffix)
        else:
            # load a saved state (deserializer), falling back to the defaults
            self.best_model = serializer("_best_model", -1)
            self.min_loss = serializer("_min_loss", 0.0)
            self.key = serializer("_key", "")
            self.prefix = serializer("_prefix", "model")
            self.suffix = serializer("_suffix", "best")
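# Usage sketch (illustrative): link "<out>/model.best" to the snapshot of the
# epoch with the lowest validation loss, assuming snapshots are saved as
# "model.<epoch>" in the trainer output directory.
#
#   trainer.extend(
#       MakeSymlinkToBestModel("validation/main/loss"), trigger=(1, "epoch")
#   )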
def make_lexical_tree(word_dict, subword_dict, word_unk):
    """Make a lexical tree to compute word-level probabilities"""
    # each node is [dict of successor nodes, word id, (low, high) word-id range]
    root = [{}, -1, None]
    for w, wid in word_dict.items():
        if wid > 0 and wid != word_unk:  # skip the id-0 token and the unknown word
            if any(c not in subword_dict for c in w):  # skip words with unknown subwords
                continue
            succ = root[0]  # start from the successors of the root node
            for i, c in enumerate(w):
                cid = subword_dict[c]
                if cid not in succ:  # create a new node for an unseen subword
                    succ[cid] = [{}, -1, (wid - 1, wid)]
                else:  # widen the word-id range covered by this node
                    prev = succ[cid][2]
                    succ[cid][2] = (min(prev[0], wid - 1), max(prev[1], wid))
                if i == len(w) - 1:  # word end: record the word id
                    succ[cid][1] = wid
                succ = succ[cid][0]  # descend to the successors of this node
    return root
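# Example (illustrative): with word_dict = {"<blank>": 0, "<unk>": 1, "ab": 2,
# "ac": 3}, subword_dict = {"a": 0, "b": 1, "c": 2} and word_unk = 1, the tree
# has a single child under the root for subword "a" covering the word-id range
# (1, 3), which in turn has children for "b" (word id 2) and "c" (word id 3).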