Spaces: kevinwang676 (Runtime error)

Commit 9cabf4f
committed by kevinwang676
Parent(s): 38e25d3

Delete load_model.py

load_model.py (DELETED, +0 -936)
@@ -1,936 +0,0 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import ast
import collections
import contextlib
import inspect
import logging
import os
import re
import time
import traceback
from collections import OrderedDict
from pathlib import Path
from typing import Any, Dict, Optional, Union

import numpy as np
import torch
from fairseq.data import data_utils
from fairseq.dataclass.configs import CheckpointConfig
from fairseq.dataclass.utils import (
    convert_namespace_to_omegaconf,
    overwrite_args_by_name,
)
from fairseq.distributed.fully_sharded_data_parallel import FSDP, has_FSDP
from fairseq.file_io import PathManager
from fairseq.models import FairseqDecoder, FairseqEncoder
from omegaconf import DictConfig, OmegaConf, open_dict

logger = logging.getLogger(__name__)


def save_checkpoint(cfg: CheckpointConfig, trainer, epoch_itr, val_loss):
    from fairseq import meters

    # only one worker should attempt to create the required dir
    if trainer.data_parallel_rank == 0:
        os.makedirs(cfg.save_dir, exist_ok=True)

    prev_best = getattr(save_checkpoint, "best", val_loss)
    if val_loss is not None:
        best_function = max if cfg.maximize_best_checkpoint_metric else min
        save_checkpoint.best = best_function(val_loss, prev_best)

    if cfg.no_save:
        return None

    trainer.consolidate_optimizer()  # TODO(SS): do we need this if no_save_optimizer_state

    if not trainer.should_save_checkpoint_on_current_rank:
        if trainer.always_call_state_dict_during_save_checkpoint:
            trainer.state_dict()
        return None

    write_timer = meters.StopwatchMeter()
    write_timer.start()

    epoch = epoch_itr.epoch
    end_of_epoch = epoch_itr.end_of_epoch()
    updates = trainer.get_num_updates()

    logger.info(f"Preparing to save checkpoint for epoch {epoch} @ {updates} updates")

    def is_better(a, b):
        return a >= b if cfg.maximize_best_checkpoint_metric else a <= b

    suffix = trainer.checkpoint_suffix
    checkpoint_conds = collections.OrderedDict()
    checkpoint_conds["checkpoint{}{}.pt".format(epoch, suffix)] = (
        end_of_epoch and not cfg.no_epoch_checkpoints and epoch % cfg.save_interval == 0
    )
    checkpoint_conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = (
        not end_of_epoch
        and cfg.save_interval_updates > 0
        and updates % cfg.save_interval_updates == 0
    )
    checkpoint_conds["checkpoint_best{}.pt".format(suffix)] = val_loss is not None and (
        not hasattr(save_checkpoint, "best")
        or is_better(val_loss, save_checkpoint.best)
    )
    if val_loss is not None and cfg.keep_best_checkpoints > 0:
        worst_best = getattr(save_checkpoint, "best", None)
        chkpts = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if len(chkpts) > 0:
            p = chkpts[-1] if cfg.maximize_best_checkpoint_metric else chkpts[0]
            worst_best = float(p.rsplit("_")[-1].replace("{}.pt".format(suffix), ""))
        # add random digits to resolve ties
        with data_utils.numpy_seed(epoch, updates, val_loss):
            rand_sfx = np.random.randint(0, cfg.keep_best_checkpoints)

        checkpoint_conds[
            "checkpoint.best_{}_{:.3f}{}{}.pt".format(
                cfg.best_checkpoint_metric, val_loss, rand_sfx, suffix
            )
        ] = worst_best is None or is_better(val_loss, worst_best)
    checkpoint_conds[
        "checkpoint_last{}.pt".format(suffix)
    ] = not cfg.no_last_checkpoints

    extra_state = {
        "train_iterator": epoch_itr.state_dict(),
        "val_loss": val_loss,
    }

    # Going forward, different tasks could expose an API like this to dump all
    # the checkpoint worthy attributes in a dictionary which then will be
    # merged with the parent dictionary to create the "extra_state". This
    # allows for an extensible yet simple design to checkpoint task level
    # attributes
    if hasattr(trainer.task, "get_checkpoint_dict"):
        extra_state = {**extra_state, **trainer.task.get_checkpoint_dict()}
        logger.info(f"State of {trainer.task.__class__.__name__} is ready to be persisted with the checkpoint")

    if hasattr(save_checkpoint, "best"):
        extra_state.update({"best": save_checkpoint.best})

    checkpoints = [
        os.path.join(cfg.save_dir, fn) for fn, cond in checkpoint_conds.items() if cond
    ]
    saved_cp = None
    if len(checkpoints) > 0 and trainer.should_save_checkpoint_on_current_rank:
        saved_cp = trainer.save_checkpoint(checkpoints[0], extra_state)
        for cp in checkpoints[1:]:
            if cfg.write_checkpoints_asynchronously:
                # TODO[ioPath]: Need to implement a delayed asynchronous
                # file copying/moving feature.
                logger.warning(
                    f"ioPath is not copying {checkpoints[0]} to {cp} "
                    "since async write mode is on."
                )
            else:
                assert PathManager.copy(
                    checkpoints[0], cp, overwrite=True
                ), f"Failed to copy {checkpoints[0]} to {cp}"

        write_timer.stop()
        logger.info(
            "Saved checkpoint {} (epoch {} @ {} updates, score {}) (writing took {} seconds)".format(
                checkpoints[0], epoch, updates, val_loss, write_timer.sum
            )
        )

    if (
        not end_of_epoch
        and cfg.keep_interval_updates > 0
        and trainer.should_save_checkpoint_on_current_rank
    ):
        # remove old checkpoints; checkpoints are sorted in descending order
        if cfg.keep_interval_updates_pattern == -1:
            checkpoints = checkpoint_paths(
                cfg.save_dir, pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix)
            )
        else:
            checkpoints = checkpoint_paths(
                cfg.save_dir,
                pattern=r"checkpoint_\d+_(\d+){}\.pt".format(suffix),
                keep_match=True,
            )
            checkpoints = [
                x[0]
                for x in checkpoints
                if x[1] % cfg.keep_interval_updates_pattern != 0
            ]

        for old_chk in checkpoints[cfg.keep_interval_updates :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_last_epochs > 0 and trainer.should_save_checkpoint_on_current_rank:
        # remove old epoch checkpoints; checkpoints are sorted in descending order
        checkpoints = checkpoint_paths(
            cfg.save_dir, pattern=r"checkpoint(\d+){}\.pt".format(suffix)
        )
        for old_chk in checkpoints[cfg.keep_last_epochs :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    if cfg.keep_best_checkpoints > 0 and trainer.should_save_checkpoint_on_current_rank:
        # only keep the best N checkpoints according to validation metric
        checkpoints = checkpoint_paths(
            cfg.save_dir,
            pattern=r"checkpoint\.best_{}_(\d+\.?\d*){}\.pt".format(
                cfg.best_checkpoint_metric, suffix
            ),
        )
        if not cfg.maximize_best_checkpoint_metric:
            checkpoints = checkpoints[::-1]
        for old_chk in checkpoints[cfg.keep_best_checkpoints :]:
            if os.path.lexists(old_chk):
                os.remove(old_chk)
            elif PathManager.exists(old_chk):
                PathManager.rm(old_chk)

    return saved_cp
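
# --- Editor's example (not part of the deleted file): a minimal sketch of the
# filename -> condition pattern that save_checkpoint builds above. The epoch /
# update counts are made up; in the real function they come from the trainer and cfg.
def _example_checkpoint_conds(epoch=3, updates=12000, suffix=""):
    conds = collections.OrderedDict()
    conds["checkpoint{}{}.pt".format(epoch, suffix)] = True          # end of epoch
    conds["checkpoint_{}_{}{}.pt".format(epoch, updates, suffix)] = False
    conds["checkpoint_last{}.pt".format(suffix)] = True
    # only names whose condition is True get written (the first) or copied (the rest)
    return [name for name, cond in conds.items() if cond]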


def load_checkpoint(cfg: CheckpointConfig, trainer, **passthrough_args):
    """
    Load a checkpoint and restore the training iterator.

    *passthrough_args* will be passed through to
    ``trainer.get_train_iterator``.
    """

    reset_optimizer = cfg.reset_optimizer
    reset_lr_scheduler = cfg.reset_lr_scheduler
    optimizer_overrides = ast.literal_eval(cfg.optimizer_overrides)
    reset_meters = cfg.reset_meters
    reset_dataloader = cfg.reset_dataloader

    if cfg.finetune_from_model is not None and (
        reset_optimizer or reset_lr_scheduler or reset_meters or reset_dataloader
    ):
        raise ValueError(
            "--finetune-from-model can not be set together with either --reset-optimizer"
            " or reset_lr_scheduler or reset_meters or reset_dataloader"
        )

    suffix = trainer.checkpoint_suffix
    if (
        cfg.restore_file == "checkpoint_last.pt"
    ):  # default value of restore_file is 'checkpoint_last.pt'
        checkpoint_path = os.path.join(
            cfg.save_dir, "checkpoint_last{}.pt".format(suffix)
        )
        first_launch = not PathManager.exists(checkpoint_path)
        if first_launch and getattr(cfg, "continue_once", None) is not None:
            checkpoint_path = cfg.continue_once
        elif cfg.finetune_from_model is not None and first_launch:
            # if there is no last checkpoint to restore, start the finetune from pretrained model
            # else just use usual logic to load checkpoint, e.g. restart from last checkpoint and etc.
            if PathManager.exists(cfg.finetune_from_model):
                checkpoint_path = cfg.finetune_from_model
                reset_optimizer = True
                reset_lr_scheduler = True
                reset_meters = True
                reset_dataloader = True
                logger.info(
                    f"loading pretrained model from {checkpoint_path}: "
                    "optimizer, lr scheduler, meters, dataloader will be reset"
                )
            else:
                raise ValueError(
                    f"--finetune-from-model {cfg.finetune_from_model} does not exist"
                )
    elif suffix is not None:
        checkpoint_path = cfg.restore_file.replace(".pt", suffix + ".pt")
    else:
        checkpoint_path = cfg.restore_file

    if cfg.restore_file != "checkpoint_last.pt" and cfg.finetune_from_model:
        raise ValueError(
            "--finetune-from-model and --restore-file (non-default value) "
            "can not be specified together: " + str(cfg)
        )

    extra_state = trainer.load_checkpoint(
        checkpoint_path,
        reset_optimizer,
        reset_lr_scheduler,
        optimizer_overrides,
        reset_meters=reset_meters,
    )

    if (
        extra_state is not None
        and "best" in extra_state
        and not reset_optimizer
        and not reset_meters
    ):
        save_checkpoint.best = extra_state["best"]

    if extra_state is not None and not reset_dataloader:
        # restore iterator from checkpoint
        itr_state = extra_state["train_iterator"]
        epoch_itr = trainer.get_train_iterator(
            epoch=itr_state["epoch"], load_dataset=True, **passthrough_args
        )
        epoch_itr.load_state_dict(itr_state)

        # Preload the checkpoint for the task
        task_cp_dict = extra_state.get(trainer.task.__class__.__name__, {})
        if task_cp_dict and hasattr(trainer.task, "set_checkpoint_dict"):
            trainer.task.set_checkpoint_dict(task_cp_dict)
    else:
        epoch_itr = trainer.get_train_iterator(
            epoch=1, load_dataset=True, **passthrough_args
        )

    trainer.lr_step(epoch_itr.epoch)

    return extra_state, epoch_itr


def load_checkpoint_to_cpu(path, arg_overrides=None, load_on_all_ranks=False):
    """Loads a checkpoint to CPU (with upgrading for backward compatibility).

    If doing single-GPU training or if the checkpoint is only being loaded by at
    most one process on each node (current default behavior is for only rank 0
    to read the checkpoint from disk), load_on_all_ranks should be False to
    avoid errors from torch.distributed not having been initialized or
    torch.distributed.barrier() hanging.

    If all processes on each node may be loading the checkpoint
    simultaneously, load_on_all_ranks should be set to True to avoid I/O
    conflicts.

    There's currently no support for > 1 but < all processes loading the
    checkpoint on each node.
    """
    local_path = PathManager.get_local_path(path)
    # The locally cached file returned by get_local_path() may be stale for
    # remote files that are periodically updated/overwritten (ex:
    # checkpoint_last.pt) - so we remove the local copy, sync across processes
    # (if needed), and then download a fresh copy.
    if local_path != path and PathManager.path_requires_pathmanager(path):
        try:
            os.remove(local_path)
        except FileNotFoundError:
            # With potentially multiple processes removing the same file, the
            # file being missing is benign (missing_ok isn't available until
            # Python 3.8).
            pass
        if load_on_all_ranks:
            torch.distributed.barrier()
        local_path = PathManager.get_local_path(path)

    with open(local_path, "rb") as f:
        state = torch.load(f, map_location=torch.device("cpu"))

    if "args" in state and state["args"] is not None and arg_overrides is not None:
        args = state["args"]
        for arg_name, arg_val in arg_overrides.items():
            setattr(args, arg_name, arg_val)

    if "cfg" in state and state["cfg"] is not None:

        # hack to be able to set Namespace in dict config. this should be removed when we update to newer
        # omegaconf version that supports object flags, or when we migrate all existing models
        from omegaconf import __version__ as oc_version
        from omegaconf import _utils

        if oc_version < "2.2":
            old_primitive = _utils.is_primitive_type
            _utils.is_primitive_type = lambda _: True

            state["cfg"] = OmegaConf.create(state["cfg"])

            _utils.is_primitive_type = old_primitive
            OmegaConf.set_struct(state["cfg"], True)
        else:
            state["cfg"] = OmegaConf.create(state["cfg"], flags={"allow_objects": True})

        if arg_overrides is not None:
            overwrite_args_by_name(state["cfg"], arg_overrides)

    state = _upgrade_state_dict(state)
    return state
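
# --- Editor's example (not part of the deleted file): a hedged sketch of
# inspecting a checkpoint with load_checkpoint_to_cpu. The path is hypothetical;
# "cfg" and "model" are the keys the function above upgrades and returns.
def _example_inspect_checkpoint(path="/tmp/checkpoint_last.pt"):
    state = load_checkpoint_to_cpu(path)
    print(type(state["cfg"]))               # omegaconf DictConfig after upgrade
    print(list(state["model"].keys())[:5])  # first few parameter names
    return state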


def load_model_ensemble(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    """Loads an ensemble of models.

    Args:
        filenames (List[str]): checkpoint files to load
        arg_overrides (Dict[str,Any], optional): override model args that
            were used during model training
        task (fairseq.tasks.FairseqTask, optional): task to use for loading
    """
    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble, args, _task = load_model_ensemble_and_task(
        filenames,
        arg_overrides,
        task,
        strict,
        suffix,
        num_shards,
        state,
    )
    return ensemble, args
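
# --- Editor's example (not part of the deleted file): loading a single-file
# "ensemble" for inference. The checkpoint path is a hypothetical stand-in.
def _example_load_ensemble(path="/tmp/checkpoint_best.pt"):
    models, cfg = load_model_ensemble([path])
    for model in models:
        model.eval()  # disable dropout etc. before evaluation
    return models, cfg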


def get_maybe_sharded_checkpoint_filename(
    filename: str, suffix: str, shard_idx: int, num_shards: int
) -> str:
    orig_filename = filename
    filename = filename.replace(".pt", suffix + ".pt")
    fsdp_filename = filename[:-3] + f"-shard{shard_idx}.pt"
    model_parallel_filename = orig_filename[:-3] + f"_part{shard_idx}.pt"
    if PathManager.exists(fsdp_filename):
        return fsdp_filename
    elif num_shards > 1:
        return model_parallel_filename
    else:
        return filename


def load_model_ensemble_and_task(
    filenames,
    arg_overrides: Optional[Dict[str, Any]] = None,
    task=None,
    strict=True,
    suffix="",
    num_shards=1,
    state=None,
):
    assert state is None or len(filenames) == 1

    from fairseq import tasks

    assert not (
        strict and num_shards > 1
    ), "Cannot load state dict with strict=True and checkpoint shards > 1"
    ensemble = []
    cfg = None
    for filename in filenames:
        orig_filename = filename
        model_shard_state = {"shard_weights": [], "shard_metadata": []}
        assert num_shards > 0
        st = time.time()
        for shard_idx in range(num_shards):
            filename = get_maybe_sharded_checkpoint_filename(
                orig_filename, suffix, shard_idx, num_shards
            )

            if not PathManager.exists(filename):
                raise IOError("Model file not found: {}".format(filename))
            if state is None:
                state = load_checkpoint_to_cpu(filename, arg_overrides)
            if "args" in state and state["args"] is not None:
                cfg = convert_namespace_to_omegaconf(state["args"])
            elif "cfg" in state and state["cfg"] is not None:
                cfg = state["cfg"]
            else:
                raise RuntimeError(
                    f"Neither args nor cfg exist in state keys = {state.keys()}"
                )

            if task is None:
                task = tasks.setup_task(cfg.task, from_checkpoint=True)

            if "task_state" in state:
                task.load_state_dict(state["task_state"])

            argspec = inspect.getfullargspec(task.build_model)

            if "fsdp_metadata" in state and num_shards > 1:
                model_shard_state["shard_weights"].append(state["model"])
                model_shard_state["shard_metadata"].append(state["fsdp_metadata"])
                # check FSDP import before the code goes too far
                if not has_FSDP:
                    raise ImportError(
                        "Cannot find FullyShardedDataParallel. "
                        "Please install fairscale with: pip install fairscale"
                    )
                if shard_idx == num_shards - 1:
                    consolidated_model_state = FSDP.consolidate_shard_weights(
                        shard_weights=model_shard_state["shard_weights"],
                        shard_metadata=model_shard_state["shard_metadata"],
                    )
                    if "from_checkpoint" in argspec.args:
                        model = task.build_model(cfg.model, from_checkpoint=True)
                    else:
                        model = task.build_model(cfg.model)
                    if (
                        "optimizer_history" in state
                        and len(state["optimizer_history"]) > 0
                        and "num_updates" in state["optimizer_history"][-1]
                    ):
                        model.set_num_updates(
                            state["optimizer_history"][-1]["num_updates"]
                        )
                    model.load_state_dict(
                        consolidated_model_state, strict=strict, model_cfg=cfg.model
                    )
            else:
                # model parallel checkpoint or unsharded checkpoint
                # support old external tasks

                if "from_checkpoint" in argspec.args:
                    model = task.build_model(cfg.model, from_checkpoint=True)
                else:
                    model = task.build_model(cfg.model)
                if (
                    "optimizer_history" in state
                    and len(state["optimizer_history"]) > 0
                    and "num_updates" in state["optimizer_history"][-1]
                ):
                    model.set_num_updates(state["optimizer_history"][-1]["num_updates"])
                model.load_state_dict(
                    state["model"], strict=strict, model_cfg=cfg.model
                )

            # reset state so it gets loaded for the next model in ensemble
            state = None
            if shard_idx % 10 == 0 and shard_idx > 0:
                elapsed = time.time() - st
                logger.info(
                    f"Loaded {shard_idx} shards in {elapsed:.2f}s, {elapsed / (shard_idx+1):.2f}s/shard"
                )

        # build model for ensemble
        ensemble.append(model)
    return ensemble, cfg, task


def load_model_ensemble_and_task_from_hf_hub(
    model_id,
    cache_dir: Optional[str] = None,
    arg_overrides: Optional[Dict[str, Any]] = None,
    **kwargs: Any,
):
    try:
        from huggingface_hub import snapshot_download
    except ImportError:
        raise ImportError(
            "You need to install huggingface_hub to use `load_from_hf_hub`. "
            "See https://pypi.org/project/huggingface-hub/ for installation."
        )

    library_name = "fairseq"
    cache_dir = cache_dir or (Path.home() / ".cache" / library_name).as_posix()
    cache_dir = snapshot_download(
        model_id, cache_dir=cache_dir, library_name=library_name, **kwargs
    )

    _arg_overrides = arg_overrides or {}
    _arg_overrides["data"] = cache_dir
    return load_model_ensemble_and_task(
        [p.as_posix() for p in Path(cache_dir).glob("*.pt")],
        arg_overrides=_arg_overrides,
    )
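
# --- Editor's example (not part of the deleted file): pulling a fairseq model
# from the Hugging Face Hub with the helper above. The model id is illustrative;
# any Hub repo that ships fairseq *.pt checkpoints should behave the same way.
def _example_load_from_hub(model_id="facebook/fastspeech2-en-ljspeech"):
    models, cfg, task = load_model_ensemble_and_task_from_hf_hub(model_id)
    return models[0], cfg, task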


def checkpoint_paths(path, pattern=r"checkpoint(\d+)\.pt", keep_match=False):
    """Retrieves all checkpoints found in `path` directory.

    Checkpoints are identified by matching filename to the specified pattern. If
    the pattern contains groups, the result will be sorted by the first group in
    descending order.
    """
    pt_regexp = re.compile(pattern)
    files = PathManager.ls(path)

    entries = []
    for i, f in enumerate(files):
        m = pt_regexp.fullmatch(f)
        if m is not None:
            idx = float(m.group(1)) if len(m.groups()) > 0 else i
            entries.append((idx, m.group(0)))
    if keep_match:
        return [(os.path.join(path, x[1]), x[0]) for x in sorted(entries, reverse=True)]
    else:
        return [os.path.join(path, x[1]) for x in sorted(entries, reverse=True)]
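
# --- Editor's example (not part of the deleted file): with the default pattern,
# checkpoint_paths sorts by the captured epoch number, descending, so
# "checkpoint12.pt" is returned before "checkpoint3.pt". The directory is made up.
def _example_checkpoint_paths(save_dir="/tmp/checkpoints"):
    # e.g. ["/tmp/checkpoints/checkpoint12.pt", "/tmp/checkpoints/checkpoint3.pt", ...]
    return checkpoint_paths(save_dir)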


def torch_persistent_save(obj, filename, async_write: bool = False):
    if async_write:
        with PathManager.opena(filename, "wb") as f:
            _torch_persistent_save(obj, f)
    else:
        if PathManager.supports_rename(filename):
            # do atomic save
            with PathManager.open(filename + ".tmp", "wb") as f:
                _torch_persistent_save(obj, f)
            PathManager.rename(filename + ".tmp", filename)
        else:
            # fallback to non-atomic save
            with PathManager.open(filename, "wb") as f:
                _torch_persistent_save(obj, f)


def _torch_persistent_save(obj, f):
    if isinstance(f, str):
        with PathManager.open(f, "wb") as h:
            torch_persistent_save(obj, h)
        return
    for i in range(3):
        try:
            return torch.save(obj, f)
        except Exception:
            if i == 2:
                logger.error(traceback.format_exc())
                raise
            else:
                time.sleep(2.5)
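
# --- Editor's example (not part of the deleted file): persisting an arbitrary
# object with the helper above. When the backend supports rename, the write goes
# to filename + ".tmp" first and is then renamed into place. The path is hypothetical.
def _example_persistent_save(obj, filename="/tmp/extra_state.pt"):
    torch_persistent_save(obj, filename)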


def _upgrade_state_dict(state):
    """Helper for upgrading old model checkpoints."""

    # add optimizer_history
    if "optimizer_history" not in state:
        state["optimizer_history"] = [
            {"criterion_name": "CrossEntropyCriterion", "best_loss": state["best_loss"]}
        ]
        state["last_optimizer_state"] = state["optimizer"]
        del state["optimizer"]
        del state["best_loss"]
    # move extra_state into sub-dictionary
    if "epoch" in state and "extra_state" not in state:
        state["extra_state"] = {
            "epoch": state["epoch"],
            "batch_offset": state["batch_offset"],
            "val_loss": state["val_loss"],
        }
        del state["epoch"]
        del state["batch_offset"]
        del state["val_loss"]
    # reduce optimizer history's memory usage (only keep the last state)
    if "optimizer" in state["optimizer_history"][-1]:
        state["last_optimizer_state"] = state["optimizer_history"][-1]["optimizer"]
        for optim_hist in state["optimizer_history"]:
            del optim_hist["optimizer"]
    # record the optimizer class name
    if "optimizer_name" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["optimizer_name"] = "FairseqNAG"
    # move best_loss into lr_scheduler_state
    if "lr_scheduler_state" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["lr_scheduler_state"] = {
            "best": state["optimizer_history"][-1]["best_loss"]
        }
        del state["optimizer_history"][-1]["best_loss"]
    # keep track of number of updates
    if "num_updates" not in state["optimizer_history"][-1]:
        state["optimizer_history"][-1]["num_updates"] = 0
    # use stateful training data iterator
    if "train_iterator" not in state["extra_state"]:
        state["extra_state"]["train_iterator"] = {
            "epoch": state["extra_state"].get("epoch", 0),
            "iterations_in_epoch": state["extra_state"].get("batch_offset", 0),
        }

    # backward compatibility, cfg updates
    if "args" in state and state["args"] is not None:
        # old model checkpoints may not have separate source/target positions
        if hasattr(state["args"], "max_positions") and not hasattr(
            state["args"], "max_source_positions"
        ):
            state["args"].max_source_positions = state["args"].max_positions
            state["args"].max_target_positions = state["args"].max_positions
        # default to translation task
        if not hasattr(state["args"], "task"):
            state["args"].task = "translation"
        # --raw-text and --lazy-load are deprecated
        if getattr(state["args"], "raw_text", False):
            state["args"].dataset_impl = "raw"
        elif getattr(state["args"], "lazy_load", False):
            state["args"].dataset_impl = "lazy"
        # epochs start at 1
        if state["extra_state"]["train_iterator"] is not None:
            state["extra_state"]["train_iterator"]["epoch"] = max(
                state["extra_state"]["train_iterator"].get("epoch", 1), 1
            )
        # --remove-bpe ==> --postprocess
        if hasattr(state["args"], "remove_bpe"):
            state["args"].post_process = state["args"].remove_bpe
        # --min-lr ==> --stop-min-lr
        if hasattr(state["args"], "min_lr"):
            state["args"].stop_min_lr = state["args"].min_lr
            del state["args"].min_lr
        # binary_cross_entropy / kd_binary_cross_entropy => wav2vec criterion
        if hasattr(state["args"], "criterion") and state["args"].criterion in [
            "binary_cross_entropy",
            "kd_binary_cross_entropy",
        ]:
            state["args"].criterion = "wav2vec"
        # remove log_keys if it's None (criteria will supply a default value of [])
        if hasattr(state["args"], "log_keys") and state["args"].log_keys is None:
            delattr(state["args"], "log_keys")
        # speech_pretraining => audio pretraining
        if (
            hasattr(state["args"], "task")
            and state["args"].task == "speech_pretraining"
        ):
            state["args"].task = "audio_pretraining"
        # audio_cpc => wav2vec
        if hasattr(state["args"], "arch") and state["args"].arch == "audio_cpc":
            state["args"].arch = "wav2vec"
        # convert legacy float learning rate to List[float]
        if hasattr(state["args"], "lr") and isinstance(state["args"].lr, float):
            state["args"].lr = [state["args"].lr]
        # convert task data arg to a string instead of List[string]
        if (
            hasattr(state["args"], "data")
            and isinstance(state["args"].data, list)
            and len(state["args"].data) > 0
        ):
            state["args"].data = state["args"].data[0]

        state["cfg"] = convert_namespace_to_omegaconf(state["args"])

    if "cfg" in state and state["cfg"] is not None:
        cfg = state["cfg"]
        with open_dict(cfg):
            # any upgrades for Hydra-based configs
            if (
                "task" in cfg
                and "eval_wer_config" in cfg.task
                and isinstance(cfg.task.eval_wer_config.print_alignment, bool)
            ):
                cfg.task.eval_wer_config.print_alignment = "hard"
            if "generation" in cfg and isinstance(cfg.generation.print_alignment, bool):
                cfg.generation.print_alignment = (
                    "hard" if cfg.generation.print_alignment else None
                )
            if (
                "model" in cfg
                and "w2v_args" in cfg.model
                and cfg.model.w2v_args is not None
                and (
                    hasattr(cfg.model.w2v_args, "task") or "task" in cfg.model.w2v_args
                )
                and hasattr(cfg.model.w2v_args.task, "eval_wer_config")
                and cfg.model.w2v_args.task.eval_wer_config is not None
                and isinstance(
                    cfg.model.w2v_args.task.eval_wer_config.print_alignment, bool
                )
            ):
                cfg.model.w2v_args.task.eval_wer_config.print_alignment = "hard"

    return state


def prune_state_dict(state_dict, model_cfg: Optional[DictConfig]):
    """Prune the given state_dict if desired for LayerDrop
    (https://arxiv.org/abs/1909.11556).

    Training with LayerDrop allows models to be robust to pruning at inference
    time. This function prunes state_dict to allow smaller models to be loaded
    from a larger model and re-maps the existing state_dict for this to occur.

    It's called by functions that load models from checkpoints and does not
    need to be called directly.
    """
    arch = None
    if model_cfg is not None:
        arch = (
            model_cfg._name
            if isinstance(model_cfg, DictConfig)
            else getattr(model_cfg, "arch", None)
        )

    if not model_cfg or arch is None or arch == "ptt_transformer":
        # args should not be none, but don't crash if it is.
        return state_dict

    encoder_layers_to_keep = getattr(model_cfg, "encoder_layers_to_keep", None)
    decoder_layers_to_keep = getattr(model_cfg, "decoder_layers_to_keep", None)

    if not encoder_layers_to_keep and not decoder_layers_to_keep:
        return state_dict

    # apply pruning
    logger.info(
        "Pruning model to specified layer configuration - this works best if the model was trained with LayerDrop"
    )

    def create_pruning_pass(layers_to_keep, layer_name):
        keep_layers = sorted(
            int(layer_string) for layer_string in layers_to_keep.split(",")
        )
        mapping_dict = {}
        for i in range(len(keep_layers)):
            mapping_dict[str(keep_layers[i])] = str(i)

        regex = re.compile(r"^{layer}.*\.layers\.(\d+)".format(layer=layer_name))
        return {"substitution_regex": regex, "mapping_dict": mapping_dict}

    pruning_passes = []
    if encoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(encoder_layers_to_keep, "encoder"))
    if decoder_layers_to_keep:
        pruning_passes.append(create_pruning_pass(decoder_layers_to_keep, "decoder"))

    new_state_dict = {}
    for layer_name in state_dict.keys():
        match = re.search(r"\.layers\.(\d+)\.", layer_name)
        # if layer has no number in it, it is a supporting layer, such as an
        # embedding
        if not match:
            new_state_dict[layer_name] = state_dict[layer_name]
            continue

        # otherwise, layer should be pruned.
        original_layer_number = match.group(1)
        # figure out which mapping dict to replace from
        for pruning_pass in pruning_passes:
            if original_layer_number in pruning_pass["mapping_dict"] and pruning_pass[
                "substitution_regex"
            ].search(layer_name):
                new_layer_number = pruning_pass["mapping_dict"][original_layer_number]
                substitution_match = pruning_pass["substitution_regex"].search(
                    layer_name
                )
                new_state_key = (
                    layer_name[: substitution_match.start(1)]
                    + new_layer_number
                    + layer_name[substitution_match.end(1) :]
                )
                new_state_dict[new_state_key] = state_dict[layer_name]

    # Since layers are now pruned, *_layers_to_keep are no longer needed.
    # This is more of "It would make it work fix" rather than a proper fix.
    if isinstance(model_cfg, DictConfig):
        context = open_dict(model_cfg)
    else:
        context = contextlib.ExitStack()
    with context:
        if hasattr(model_cfg, "encoder_layers_to_keep"):
            model_cfg.encoder_layers_to_keep = None
        if hasattr(model_cfg, "decoder_layers_to_keep"):
            model_cfg.decoder_layers_to_keep = None

    return new_state_dict
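
# --- Editor's example (not part of the deleted file): how prune_state_dict
# remaps keys. With encoder_layers_to_keep="0,2", layer 2 is renumbered to 1 and
# layer 1 is dropped entirely. The tiny state_dict and Namespace cfg are made up.
def _example_prune_state_dict():
    import argparse

    fake_state = {
        "encoder.embed_tokens.weight": torch.zeros(1),
        "encoder.layers.0.fc1.weight": torch.zeros(1),
        "encoder.layers.1.fc1.weight": torch.zeros(1),
        "encoder.layers.2.fc1.weight": torch.zeros(1),
    }
    fake_cfg = argparse.Namespace(arch="transformer", encoder_layers_to_keep="0,2")
    pruned = prune_state_dict(fake_state, fake_cfg)
    # keys: "encoder.embed_tokens.weight", "encoder.layers.0.fc1.weight",
    #       "encoder.layers.1.fc1.weight" (the old layer 2)
    return pruned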


def load_pretrained_component_from_model(
    component: Union[FairseqEncoder, FairseqDecoder],
    checkpoint: str,
    strict: bool = True,
):
    """
    Load a pretrained FairseqEncoder or FairseqDecoder from checkpoint into the
    provided `component` object. If state_dict fails to load, there may be a
    mismatch in the architecture of the corresponding `component` found in the
    `checkpoint` file.
    """
    if not PathManager.exists(checkpoint):
        raise IOError("Model file not found: {}".format(checkpoint))
    state = load_checkpoint_to_cpu(checkpoint)
    if isinstance(component, FairseqEncoder):
        component_type = "encoder"
    elif isinstance(component, FairseqDecoder):
        component_type = "decoder"
    else:
        raise ValueError(
            "component to load must be either a FairseqEncoder or "
            "FairseqDecoder. Loading other component types are not supported."
        )
    component_state_dict = OrderedDict()
    for key in state["model"].keys():
        if key.startswith(component_type):
            # encoder.input_layers.0.0.weight --> input_layers.0.0.weight
            component_subkey = key[len(component_type) + 1 :]
            component_state_dict[component_subkey] = state["model"][key]
    component.load_state_dict(component_state_dict, strict=strict)
    return component
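
# --- Editor's example (not part of the deleted file): warm-starting only the
# encoder of a freshly built model from an existing checkpoint. `model` is
# assumed to expose a FairseqEncoder at `model.encoder`; the path is hypothetical.
def _example_load_pretrained_encoder(model, checkpoint="/tmp/pretrained.pt"):
    load_pretrained_component_from_model(model.encoder, checkpoint)
    return model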


def verify_checkpoint_directory(save_dir: str) -> None:
    if not os.path.exists(save_dir):
        os.makedirs(save_dir, exist_ok=True)
    temp_file_path = os.path.join(save_dir, "dummy")
    try:
        with open(temp_file_path, "w"):
            pass
    except OSError as e:
        logger.warning(
            "Unable to access checkpoint save directory: {}".format(save_dir)
        )
        raise e
    else:
        os.remove(temp_file_path)


def save_ema_as_checkpoint(src_path, dst_path):
    state = load_ema_from_checkpoint(src_path)
    torch_persistent_save(state, dst_path)


def load_ema_from_checkpoint(fpath):
    """Loads exponential moving averaged (EMA) checkpoint from input and
    returns a model with ema weights.

    Args:
      fpath: A string path of checkpoint to load from.

    Returns:
      A dict of string keys mapping to various values. The 'model' key
      from the returned dict should correspond to an OrderedDict mapping
      string parameter names to torch Tensors.
    """
    params_dict = collections.OrderedDict()
    new_state = None

    with PathManager.open(fpath, "rb") as f:
        new_state = torch.load(
            f,
            map_location=(
                lambda s, _: torch.serialization.default_restore_location(s, "cpu")
            ),
        )

        # EMA model is stored in a separate "extra state"
        model_params = new_state["extra_state"]["ema"]

        for key in list(model_params.keys()):
            p = model_params[key]
            if isinstance(p, torch.HalfTensor):
                p = p.float()
            if key not in params_dict:
                params_dict[key] = p.clone()
                # NOTE: clone() is needed in case of p is a shared parameter
            else:
                raise ValueError("Key {} is repeated in EMA model params.".format(key))

        if len(params_dict) == 0:
            raise ValueError(
                f"Input checkpoint path '{fpath}' does not contain "
                "ema model weights, is this model trained with EMA?"
            )

    new_state["model"] = params_dict
    return new_state
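
# --- Editor's example (not part of the deleted file): extracting the EMA weights
# stored under extra_state["ema"] and persisting them as a standalone checkpoint
# with the two helpers above. Both paths are hypothetical.
def _example_extract_ema(src="/tmp/checkpoint_last.pt", dst="/tmp/checkpoint_ema.pt"):
    ema_state = load_ema_from_checkpoint(src)  # raises if no EMA weights are stored
    save_ema_as_checkpoint(src, dst)
    return ema_state["model"]  # OrderedDict of fp32 parameter tensors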