import argparse |
import collections |
import contextlib |
import copy |
import importlib |
import logging |
import os |
import sys |
import warnings |
from itertools import accumulate |
from typing import TYPE_CHECKING, Callable, Dict, List, Optional |
import torch |
import torch.nn.functional as F |
from torch import Tensor |
from fairseq.modules.multihead_attention import MultiheadAttention |
# Optional fused L2-norm kernel: remember whether `amp_C` could be imported so
# callers can fall back to plain torch.norm when it is missing.
try:
    from amp_C import multi_tensor_l2norm

    multi_tensor_l2norm_available = True
except ImportError:
    multi_tensor_l2norm_available = False

# Optional XLA support: `xm` is None when torch_xla cannot be imported.
try:
    import torch_xla.core.xla_model as xm
except ImportError:
    xm = None

# Module-level logger named after this module.
logger = logging.getLogger(__name__)
class FileContentsAction(argparse.Action):
    """argparse action that stores a file's contents, or the raw value itself.

    If the supplied value names an existing file (according to PathManager),
    the stripped file contents are stored on the namespace; otherwise the
    value is stored unchanged.
    """

    def __init__(self, option_strings, dest, nargs=None, **kwargs):
        # A custom nargs is explicitly rejected for this action.
        if nargs is not None:
            raise ValueError("nargs not allowed")
        super().__init__(option_strings, dest, **kwargs)

    def __call__(self, parser, namespace, values, option_string=None):
        # Imported lazily, inside the call, as in the rest of this module.
        from fairseq.file_io import PathManager

        argument = values
        if PathManager.isfile(values):
            with PathManager.open(values) as f:
                argument = f.read().strip()
        setattr(namespace, self.dest, argument)
def split_paths(paths: str, separator=os.pathsep) -> List[str]:
    """Split a joined string of paths into a list.

    Strings containing "://" are split on MANIFOLD_PATH_SEP; all other
    strings are split on `separator` (the OS path separator by default).
    """
    if "://" in paths:
        return paths.split(MANIFOLD_PATH_SEP)
    return paths.split(separator)
def load_ensemble_for_inference(filenames, task, model_arg_overrides=None):
    """Deprecated wrapper around checkpoint_utils.load_model_ensemble.

    Emits a deprecation warning and forwards `filenames`, `task` and
    `model_arg_overrides` (as `arg_overrides`) to the replacement function.
    """
    from fairseq import checkpoint_utils

    deprecation_warning(
        "utils.load_ensemble_for_inference is deprecated. "
        "Please use checkpoint_utils.load_model_ensemble instead."
    )
    ensemble = checkpoint_utils.load_model_ensemble(
        filenames, arg_overrides=model_arg_overrides, task=task
    )
    return ensemble
def apply_to_sample(f, sample):
    """Apply `f` to every tensor found inside a nested `sample` structure.

    Dicts, OrderedDicts, lists, tuples and sets are rebuilt with their
    elements processed recursively; any other non-tensor value is returned
    unchanged. A sample with a length of zero yields an empty dict.
    """
    if hasattr(sample, "__len__") and len(sample) == 0:
        return {}

    def _apply(node):
        if torch.is_tensor(node):
            return f(node)
        # OrderedDict is checked before dict because it is a dict subclass.
        if isinstance(node, collections.OrderedDict):
            mapped = collections.OrderedDict(
                (key, _apply(value)) for key, value in node.items()
            )
            # Carry over any attributes stored on the original OrderedDict.
            mapped.__dict__ = node.__dict__
            return mapped
        if isinstance(node, dict):
            return {key: _apply(value) for key, value in node.items()}
        if isinstance(node, list):
            return [_apply(child) for child in node]
        if isinstance(node, tuple):
            return tuple(_apply(child) for child in node)
        if isinstance(node, set):
            return {_apply(child) for child in node}
        return node

    return _apply(sample)
def move_to_cuda(sample, device=None):
    """Move every tensor inside `sample` to `device`.

    Args:
        sample: nested structure handled by apply_to_sample.
        device: target device; when None, the current CUDA device is used.

    Returns:
        The structure produced by apply_to_sample with all tensors moved.
    """
    # Compare against None explicitly: `device or ...` would treat the valid
    # device index 0 as "not given" and silently use the current device.
    if device is None:
        device = torch.cuda.current_device()

    def _move_to_cuda(tensor):
        # non_blocking=True requests an asynchronous copy where possible.
        return tensor.to(device=device, non_blocking=True)

    return apply_to_sample(_move_to_cuda, sample)
def move_to_cpu(sample):
    """Move every tensor inside `sample` to the CPU.

    bfloat16 and float16 tensors are converted to float32 before the move.
    """
    half_dtypes = {torch.bfloat16, torch.float16}

    def _move_to_cpu(tensor):
        if tensor.dtype in half_dtypes:
            return tensor.to(dtype=torch.float32).cpu()
        return tensor.cpu()

    return apply_to_sample(_move_to_cpu, sample)
def move_to_tpu(sample):
    """Move every tensor inside `sample` to the device from xm.xla_device()."""
    import torch_xla.core.xla_model as xm

    xla_device = xm.xla_device()

    def _move_to_tpu(tensor):
        return tensor.to(xla_device)

    return apply_to_sample(_move_to_tpu, sample)
def get_incremental_state(
    module: "MultiheadAttention",
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    key: str,
) -> Optional[Dict[str, Optional[Tensor]]]:
    """Return whatever `module.get_incremental_state(incremental_state, key)` returns."""
    state = module.get_incremental_state(incremental_state, key)
    return state
def set_incremental_state(
    module: "MultiheadAttention",
    incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]],
    key: str,
    value: Dict[str, Optional[Tensor]],
) -> Optional[Dict[str, Dict[str, Optional[Tensor]]]]:
    """Store `value` under `key` via `module.set_incremental_state`.

    Nothing is done when `incremental_state` is None. If the module call
    returns a non-None object, that object is returned; otherwise the
    `incremental_state` passed in is returned.
    """
    if incremental_state is None:
        return incremental_state
    updated = module.set_incremental_state(incremental_state, key, value)
    return incremental_state if updated is None else updated
def load_align_dict(replace_unk):
    """Build the alignment dictionary selected by `replace_unk`.

    Returns None when `replace_unk` is None, a dict read from the file at
    `replace_unk` when it is a non-empty string (first whitespace-separated
    column mapped to the second), and an empty dict for any other value.
    """
    if replace_unk is None:
        return None
    if not (isinstance(replace_unk, str) and len(replace_unk) > 0):
        # No file given: an empty dict is returned instead of a loaded one.
        return {}

    align_dict = {}
    with open(replace_unk, "r") as f:
        for line in f:
            fields = line.split()
            align_dict[fields[0]] = fields[1]
    return align_dict
def print_embed_overlap(embed_dict, vocab_dict):
    """Log how many keys of `embed_dict` also appear in `vocab_dict.symbols`."""
    shared = set(embed_dict.keys()).intersection(set(vocab_dict.symbols))
    logger.info(
        "found {}/{} types in embedding file".format(len(shared), len(vocab_dict))
    )
def parse_embedding(embed_path):
    """Parse an embedding text file into a dict of word -> embedding tensor.

    The first line of the file is always skipped (it is expected to hold the
    vocabulary size and dimension). Every following line holds a word and its
    weights separated by single spaces.

    Example:
        2 5
        the -0.0230 -0.0264 0.0287 0.0171 0.1403
        at -0.0395 -0.1286 0.0275 0.0254 -0.0932
    """
    embed_dict = {}
    with open(embed_path) as f_embed:
        next(f_embed)  # skip the header line
        for line in f_embed:
            word, *weights = line.rstrip().split(" ")
            embed_dict[word] = torch.Tensor([float(w) for w in weights])
    return embed_dict
def load_embedding(embed_dict, vocab, embedding):
    """Copy vectors from `embed_dict` into the rows of `embedding`.

    For each index of `vocab` whose token has an entry in `embed_dict`, the
    matching row of `embedding.weight.data` is overwritten. Returns
    `embedding`.
    """
    # Index-based access is kept on purpose: only len(vocab) and vocab[idx]
    # are required from `vocab`.
    for idx in range(len(vocab)):
        token = vocab[idx]
        if token not in embed_dict:
            continue
        embedding.weight.data[idx] = embed_dict[token]
    return embedding
def replace_unk(hypo_str, src_str, alignment, align_dict, unk):
    """Replace `unk` tokens in `hypo_str` using their aligned source tokens.

    Each hypothesis token equal to `unk` is replaced by the source token at
    `alignment[i]`, mapped through `align_dict` when it has an entry for it.
    Returns the tokens joined by single spaces.
    """
    from fairseq import tokenizer

    hypo_tokens = tokenizer.tokenize_line(hypo_str)
    # "<eos>" is appended so an alignment may point one past the last source token.
    src_tokens = tokenizer.tokenize_line(src_str) + ["<eos>"]
    for pos, token in enumerate(hypo_tokens):
        if token != unk:
            continue
        aligned_src = src_tokens[alignment[pos]]
        hypo_tokens[pos] = align_dict.get(aligned_src, aligned_src)
    return " ".join(hypo_tokens)
def post_process_prediction(
    hypo_tokens,
    src_str,
    alignment,
    align_dict,
    tgt_dict,
    remove_bpe=None,
    extra_symbols_to_ignore=None,
):
    """Turn hypothesis tokens into a string, optionally replacing unknowns.

    The string comes from `tgt_dict.string`. When `align_dict` is given,
    unknown tokens are replaced via replace_unk. When `align_dict` or
    `remove_bpe` is given, the tokens are re-encoded from the final string
    with `tgt_dict.encode_line(..., add_if_not_exist=True)`.

    Returns:
        Tuple of (hypo_tokens, hypo_str, alignment).
    """
    use_alignment = align_dict is not None
    hypo_str = tgt_dict.string(
        hypo_tokens, remove_bpe, extra_symbols_to_ignore=extra_symbols_to_ignore
    )
    if use_alignment:
        hypo_str = replace_unk(
            hypo_str, src_str, alignment, align_dict, tgt_dict.unk_string()
        )
    if use_alignment or remove_bpe is not None:
        # The string may have changed, so the token sequence is rebuilt from it.
        hypo_tokens = tgt_dict.encode_line(hypo_str, add_if_not_exist=True)
    return hypo_tokens, hypo_str, alignment
def make_positions(tensor, padding_idx: int, onnx_trace: bool = False):
    """Replace non-padding symbols with their position numbers.

    Position numbers begin at padding_idx+1; padding symbols map to
    padding_idx. Positions are counted along dim 1. `onnx_trace` is accepted
    but not used by this implementation.
    """
    non_pad = tensor.ne(padding_idx).int()
    running_count = torch.cumsum(non_pad, dim=1).type_as(non_pad)
    # Multiplying by the mask zeroes the count at padding positions.
    return (running_count * non_pad).long() + padding_idx
def strip_pad(tensor, pad):
    """Return the elements of `tensor` that differ from `pad`."""
    keep = tensor.ne(pad)
    return tensor[keep]
def buffered_arange(max):
    """Return a view of the first `max` entries of a cached arange tensor.

    The cache lives on the function object as `buffered_arange.buf` and is
    grown (and refilled) whenever `max` exceeds its current size.
    """
    try:
        cache = buffered_arange.buf
    except AttributeError:
        # First call: create the empty cache on the function object.
        cache = buffered_arange.buf = torch.LongTensor()
    if max > cache.numel():
        cache.resize_(max)
        torch.arange(max, out=cache)
    return cache[:max]
def convert_padding_direction(
    src_tokens, padding_idx, right_to_left: bool = False, left_to_right: bool = False
):
    """Rotate each row of `src_tokens` so its padding moves to the other side.

    Exactly one of `right_to_left` / `left_to_right` must be true. The input
    is returned as-is when it contains no padding, or when the column checked
    for that direction (first for left_to_right, last for right_to_left)
    holds no padding.
    """
    assert right_to_left ^ left_to_right
    pad_mask = src_tokens.eq(padding_idx)
    nothing_to_do = (
        not pad_mask.any()
        or (left_to_right and not pad_mask[:, 0].any())
        or (right_to_left and not pad_mask[:, -1].any())
    )
    if nothing_to_do:
        return src_tokens

    max_len = src_tokens.size(1)
    buffered = torch.empty(0).long()
    if max_len > 0:
        torch.arange(max_len, out=buffered)
    positions = buffered.type_as(src_tokens).expand_as(src_tokens)
    num_pads = pad_mask.long().sum(dim=1, keepdim=True)
    # Shift every row by its own padding count, wrapping around max_len.
    shift = -num_pads if right_to_left else num_pads
    index = torch.remainder(positions + shift, max_len)
    return src_tokens.gather(1, index)
def item(tensor):
    """Extract a scalar-like value from `tensor`.

    XLA tensors are returned detached rather than converted. Otherwise
    `.item()` is used when available, then `[0]` when the object supports
    indexing; anything else is returned unchanged.
    """
    on_xla = torch.is_tensor(tensor) and tensor.device.type == "xla"
    if on_xla:
        return tensor.detach()
    if hasattr(tensor, "item"):
        return tensor.item()
    if hasattr(tensor, "__getitem__"):
        return tensor[0]
    return tensor
def multi_tensor_total_norm(grads, chunk_size=2048 * 32) -> torch.Tensor:
    """Compute the overall L2 norm of a collection of gradient tensors.

    Gradients are grouped by device. CUDA groups are reduced with
    multi_tensor_l2norm; gradients on any other device are reduced one by
    one with torch.norm in float32. The result is the norm of the stacked
    partial norms.
    """
    per_device_grads = {}
    for grad in grads:
        per_device_grads.setdefault(grad.device, []).append(grad)

    norms = []
    for device, device_grads in per_device_grads.items():
        if device.type == "cuda":
            has_inf = torch.zeros((1, 1), dtype=torch.int, device=device)
            with torch.cuda.device(device):
                norm = multi_tensor_l2norm(
                    chunk_size, has_inf, [device_grads], False
                )
            # Move the partial norm to the current device so it can be stacked.
            norms.append(norm[0].to(torch.cuda.current_device()))
        else:
            norms.extend(
                [torch.norm(g, p=2, dtype=torch.float32) for g in device_grads]
            )
    return torch.norm(torch.stack(norms))
@torch.no_grad()
def clip_grad_norm_(params, max_norm, aggregate_norm_fn=None) -> torch.Tensor:
    """Clip gradients in place by their total L2 norm and return that norm.

    Args:
        params: a tensor or an iterable of tensors carrying `.grad`.
        max_norm: clipping threshold; clipping is skipped when it is <= 0.
        aggregate_norm_fn: optional callable applied to the total norm
            before clipping.

    Parameters having an `expert` attribute are excluded from the norm
    computation but are still scaled by the clipping coefficient. A zero
    tensor is returned when no non-expert gradient exists.
    """

    def grad_exists(p):
        return p is not None and getattr(p, "grad", None) is not None

    if isinstance(params, torch.Tensor):
        params = [params]
    params = list(params)

    grads, expert_grads = [], []
    for p in params:
        if not grad_exists(p):
            continue
        bucket = expert_grads if hasattr(p, "expert") else grads
        bucket.append(p.grad.detach())

    if len(grads) == 0:
        if len(params) > 0:
            return params[0].new_tensor(0.0)
        return torch.tensor(0.0)

    if len(grads) == 1:
        total_norm = torch.norm(grads[0], p=2, dtype=torch.float32)
    elif multi_tensor_l2norm_available:
        total_norm = multi_tensor_total_norm(grads)
    else:
        if torch.cuda.is_available():
            warnings.warn(
                "amp_C fused kernels unavailable, disabling multi_tensor_l2norm; "
                "you may get better performance by installing NVIDIA's apex library"
            )
            device = torch.cuda.current_device()
        elif grads[0].device.type == "xla":
            device = grads[0].device
        else:
            device = torch.device("cpu")
        # Per-gradient norms are gathered on one device before stacking.
        partial_norms = [
            torch.norm(g, p=2, dtype=torch.float32).to(device) for g in grads
        ]
        total_norm = torch.norm(torch.stack(partial_norms))

    if aggregate_norm_fn is not None:
        total_norm = aggregate_norm_fn(total_norm)

    if max_norm > 0:
        max_norm = float(max_norm)
        # The coefficient is capped at 1 so gradients are never scaled up.
        clip_coef = (max_norm / (total_norm + 1e-6)).clamp_(max=1)
        for g in grads + expert_grads:
            g.mul_(clip_coef)
    return total_norm
def fill_with_neg_inf(t):
    """FP16-compatible function that fills a tensor with -inf.

    The fill is done on `t.float()` and the result is cast back to the type
    of `t`.
    """
    as_float = t.float()
    as_float.fill_(float("-inf"))
    return as_float.type_as(t)
def _match_types(arg1, arg2): |
"""Convert the numerical argument to the same type as the other argument""" |
def upgrade(arg_number, arg_structure): |
if isinstance(arg_structure, tuple): |
return tuple([arg_number] * len(arg_structure)) |
elif isinstance(arg_structure, dict): |
arg = copy.deepcopy(arg_structure) |
for k in arg: |
arg[k] = upgrade(arg_number, arg_structure[k]) |
return arg |
else: |
return arg_number |
if isinstance(arg1, float) or isinstance(arg1, int): |
return upgrade(arg1, arg2), arg2 |
elif isinstance(arg2, float) or isinstance(arg2, int): |
return arg1, upgrade(arg2, arg1) |
return arg1, arg2 |
def resolve_max_positions(*args):
    """Resolve max position constraints from multiple sources.

    None arguments are ignored. The remaining ones are combined pairwise
    after _match_types: numbers by min, dicts key by key (min on shared
    keys), and any other sequence element-wise with a None-tolerant min.
    """

    def map_value_update(d1, d2):
        merged = copy.deepcopy(d1)
        for key, value in d2.items():
            merged[key] = min(d1[key], value) if key in merged else value
        return merged

    def nullsafe_min(l):
        # Minimum over the non-None entries; None when there are none.
        present = [entry for entry in l if entry is not None]
        return min(present) if present else None

    max_positions = None
    for arg in args:
        if max_positions is None:
            max_positions = arg
            continue
        if arg is None:
            continue
        max_positions, arg = _match_types(max_positions, arg)
        if isinstance(arg, (float, int)):
            max_positions = min(max_positions, arg)
        elif isinstance(arg, dict):
            max_positions = map_value_update(max_positions, arg)
        else:
            max_positions = tuple(map(nullsafe_min, zip(max_positions, arg)))
    return max_positions
def import_user_module(args):
    """Import the package named by `args.user_dir`, at most once per path.

    Nothing happens when `args` has no `user_dir` or it is None. If the
    path cannot be found as given, it is retried relative to this file's
    directory and then relative to its parent. After importing, the
    package's `tasks` and `models` sub-directories (when present) are
    registered via fairseq's import_tasks / import_models.

    Raises:
        FileNotFoundError: if the directory cannot be located.
        ImportError: if a different module with the same name is already
            loaded.
    """
    if getattr(args, "user_dir", None) is None:
        return

    module_path = os.path.abspath(args.user_dir)
    if not (
        os.path.exists(module_path) or os.path.isfile(os.path.dirname(module_path))
    ):
        here = os.path.dirname(__file__)
        candidates = (
            os.path.join(here, args.user_dir),
            os.path.join(here, "..", args.user_dir),
        )
        for candidate in candidates:
            if os.path.exists(candidate):
                module_path = candidate
                break
        else:
            raise FileNotFoundError(module_path)

    # The memo set lives on the function object so each path is handled once.
    import_user_module.memo = getattr(import_user_module, "memo", set())
    if module_path in import_user_module.memo:
        return
    import_user_module.memo.add(module_path)

    module_parent, module_name = os.path.split(module_path)
    if module_name in sys.modules:
        if module_path in sys.modules[module_name].__path__:
            logger.info(f"--user-dir={module_path} has already been imported.")
            return
        raise ImportError(
            "Failed to import --user-dir={} because the corresponding module name "
            "({}) is not globally unique. Please rename the directory to "
            "something unique and try again.".format(module_path, module_name)
        )

    sys.path.insert(0, module_parent)
    importlib.import_module(module_name)

    tasks_path = os.path.join(module_path, "tasks")
    if os.path.exists(tasks_path):
        from fairseq.tasks import import_tasks

        import_tasks(tasks_path, f"{module_name}.tasks")

    models_path = os.path.join(module_path, "models")
    if os.path.exists(models_path):
        from fairseq.models import import_models

        import_models(models_path, f"{module_name}.models")
def softmax(x, dim: int, onnx_trace: bool = False):
    """Softmax over `dim`, computed in float32.

    With `onnx_trace` the input is cast via `x.float()` first; otherwise the
    cast is delegated to F.softmax through `dtype=torch.float32`.
    """
    if not onnx_trace:
        return F.softmax(x, dim=dim, dtype=torch.float32)
    return F.softmax(x.float(), dim=dim)
def log_softmax(x, dim: int, onnx_trace: bool = False):
    """Log-softmax over `dim`, computed in float32.

    With `onnx_trace` the input is cast via `x.float()` first; otherwise the
    cast is delegated to F.log_softmax through `dtype=torch.float32`.
    """
    if not onnx_trace:
        return F.log_softmax(x, dim=dim, dtype=torch.float32)
    return F.log_softmax(x.float(), dim=dim)
def get_perplexity(loss, round=2, base=2):
    """Return `base ** loss` rounded with safe_round to `round` digits.

    Returns 0.0 when `loss` is None and float("inf") when the
    exponentiation or rounding raises OverflowError.
    """
    from fairseq.logging.meters import safe_round

    if loss is None:
        return 0.0
    try:
        perplexity = safe_round(base**loss, round)
    except OverflowError:
        perplexity = float("inf")
    return perplexity
def deprecation_warning(message, stacklevel=3):
    """Emit `message` through warnings.warn with the default warning category.

    No explicit category is passed, so the warning is a UserWarning rather
    than a DeprecationWarning.
    """
    warnings.warn(message, stacklevel=stacklevel)
def relu_squared(x: torch.Tensor):
    """Apply ReLU to `x` and square the result element-wise."""
    rectified = F.relu(x)
    return rectified.pow(2)
def get_activation_fn(activation: str) -> Callable:
    """Returns the activation function corresponding to `activation`.

    Args:
        activation: one of "relu", "relu_squared", "gelu", "gelu_fast"
            (deprecated alias of "gelu_accurate"), "gelu_accurate", "tanh",
            "linear" or "swish".

    Returns:
        A callable that maps a tensor to a tensor.

    Raises:
        RuntimeError: if `activation` is not one of the supported names.
    """
    if activation == "relu":
        return F.relu
    elif activation == "relu_squared":
        return relu_squared
    elif activation == "gelu":
        # fairseq.modules is imported only in the branches that need it.
        from fairseq.modules import gelu

        return gelu
    elif activation in ("gelu_fast", "gelu_accurate"):
        from fairseq.modules import gelu_accurate

        if activation == "gelu_fast":
            deprecation_warning(
                "--activation-fn=gelu_fast has been renamed to gelu_accurate"
            )
        return gelu_accurate
    elif activation == "tanh":
        return torch.tanh
    elif activation == "linear":
        return lambda x: x
    elif activation == "swish":
        # Return the functional form: `torch.nn.SiLU` itself is a module class,
        # so calling it on a tensor would construct a module instead of
        # applying the activation.
        return F.silu
    else:
        raise RuntimeError("--activation-fn {} not supported".format(activation))
def get_available_activation_fns() -> List:
    """Return the list of advertised activation function names."""
    # NOTE(review): get_activation_fn also accepts "relu_squared" and "swish",
    # which are not listed here; confirm whether that omission is intended.
    names = [
        "relu",
        "gelu",
        "gelu_fast",
        "gelu_accurate",
        "tanh",
        "linear",
    ]
    return names
@contextlib.contextmanager
def model_eval(model):
    """Context manager that puts `model` in eval mode for the `with` body.

    The training flag read on entry is restored on exit, including when the
    body raises, so an exception cannot leave the model stuck in eval mode.
    """
    is_training = model.training
    model.eval()
    try:
        yield
    finally:
        model.train(is_training)
def has_parameters(module):
    """Return True if `module.parameters()` yields at least one parameter."""
    missing = object()
    first = next(module.parameters(), missing)
    return first is not missing
def get_rng_state():
    """Collect the current RNG states into a dict.

    Always contains "torch_rng_state"; "xla_rng_state" is added when
    torch_xla was imported and "cuda_rng_state" when CUDA is available.
    """
    state = dict(torch_rng_state=torch.get_rng_state())
    if xm is not None:
        state["xla_rng_state"] = xm.get_rng_state()
    if torch.cuda.is_available():
        state["cuda_rng_state"] = torch.cuda.get_rng_state()
    return state
def set_rng_state(state):
    """Restore RNG states from a dict shaped like get_rng_state()'s result.

    The XLA and CUDA entries are only read when torch_xla was imported or
    CUDA is available, respectively.
    """
    torch.set_rng_state(state["torch_rng_state"])
    use_xla = xm is not None
    if use_xla:
        xm.set_rng_state(state["xla_rng_state"])
    use_cuda = torch.cuda.is_available()
    if use_cuda:
        torch.cuda.set_rng_state(state["cuda_rng_state"])
class set_torch_seed:
    """Context manager that seeds the RNGs and restores the old state on exit.

    The previous state is captured and the new seed applied in the
    constructor, not in __enter__; __exit__ restores the captured state.
    """

    def __init__(self, seed):
        assert isinstance(seed, int)
        # Snapshot first so the pre-seed state can be restored later.
        self.rng_state = get_rng_state()

        torch.manual_seed(seed)
        if xm is not None:
            xm.set_rng_state(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)

    def __enter__(self):
        return self

    def __exit__(self, *exc):
        set_rng_state(self.rng_state)
def parse_alignment(line):
    """
    Parses a single line from the alignment file.

    Args:
        line (str): String containing the alignment of the format:
            <src_idx_1>-<tgt_idx_1> <src_idx_2>-<tgt_idx_2> ..
            <src_idx_m>-<tgt_idx_m>. All indices are 0 indexed.

    Returns:
        torch.IntTensor: packed alignments of shape (2 * m).
    """
    pairs = line.strip().split()
    packed = torch.IntTensor(2 * len(pairs))
    for pair_no, pair in enumerate(pairs):
        src_idx, tgt_idx = pair.split("-")
        offset = 2 * pair_no
        packed[offset] = int(src_idx)
        packed[offset + 1] = int(tgt_idx)
    return packed
def get_token_to_word_mapping(tokens, exclude_list):
    """Map each token position to a running count of non-excluded tokens.

    Position i maps to the number of tokens up to and including i that are
    not in `exclude_list`.
    """
    counts = accumulate(int(token not in exclude_list) for token in tokens)
    return dict(enumerate(counts))
def extract_hard_alignment(attn, src_sent, tgt_sent, pad, eos):
    """Pick, for every valid target token, the source position with max attention.

    Target positions equal to `pad` or `eos` are skipped, and source
    positions equal to `pad` or `eos` are masked out before the argmax.
    Returns a list of (src_word_idx, tgt_word_idx) pairs derived from
    get_token_to_word_mapping, each shifted down by one.
    """
    tgt_keep = (tgt_sent != pad) & (tgt_sent != eos)
    tgt_valid = tgt_keep.nonzero(as_tuple=False).squeeze(dim=-1)
    src_drop = (src_sent == pad) | (src_sent == eos)
    src_invalid = src_drop.nonzero(as_tuple=False).squeeze(dim=-1)
    src_token_to_word = get_token_to_word_mapping(src_sent, [eos, pad])
    tgt_token_to_word = get_token_to_word_mapping(tgt_sent, [eos, pad])

    if len(tgt_valid) == 0 or len(src_invalid) >= len(src_sent):
        return []

    attn_valid = attn[tgt_valid]
    # Masked source columns can never win the argmax below.
    attn_valid[:, src_invalid] = float("-inf")
    _, src_indices = attn_valid.max(dim=1)
    return [
        (
            src_token_to_word[src_idx.item()] - 1,
            tgt_token_to_word[tgt_idx.item()] - 1,
        )
        for tgt_idx, src_idx in zip(tgt_valid, src_indices)
    ]
def extract_soft_alignment(attn, src_sent, tgt_sent, pad, eos):
    """Return attention values for non-pad positions as formatted strings.

    The result has one row per non-pad target token; each row holds the
    attention over the non-pad source tokens formatted with six decimals.
    An empty list is returned when either side has no non-pad token. `eos`
    is accepted but not used here.
    """
    tgt_valid = (tgt_sent != pad).nonzero(as_tuple=False)
    src_valid = (src_sent != pad).nonzero(as_tuple=False).squeeze(dim=-1)
    if len(tgt_valid) == 0 or len(src_valid) == 0:
        return []
    attn_valid = attn[tgt_valid, src_valid]
    rows = []
    for src_probs in attn_valid:
        rows.append(["{:.6f}".format(p) for p in src_probs.tolist()])
    return rows
def new_arange(x, *size):
    """
    Return a Tensor of `size` filled with a range function on the device of x.
    If size is empty, the size of x is used instead.
    """
    shape = size if len(size) > 0 else x.size()
    base = torch.arange(shape[-1], device=x.device)
    return base.expand(*shape).contiguous()
def get_tpu_device():
    """Return the device reported by the module-level `xm.xla_device()`."""
    device = xm.xla_device()
    return device
def tpu_data_loader(itr):
    """Wrap `itr` in an XLA ParallelLoader behind a CountingIterator.

    The returned iterator starts counting at `itr.n` (0 if absent) and
    reports `len(itr)` as its total.
    """
    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.parallel_loader as pl

    from fairseq.data import iterators

    xm.rendezvous("tpu_data_loader")
    xm.mark_step()
    device = xm.xla_device()
    per_device = pl.ParallelLoader(itr, [device]).per_device_loader(device)
    return iterators.CountingIterator(
        per_device,
        start=getattr(itr, "n", 0),
        total=len(itr),
    )
def is_xla_tensor(tensor):
    """Return True if `tensor` is a torch tensor whose device type is "xla"."""
    if not torch.is_tensor(tensor):
        return False
    return tensor.device.type == "xla"
def index_put(tensor, indices, value):
    """Write `value` into `tensor` at `indices` and return the result.

    Non-XLA tensors are updated in place with `tensor[indices] = value`.
    For XLA tensors a new tensor is built arithmetically instead: `indices`
    is expanded to the shape of `tensor` and used to blend `tensor` with
    `value`.
    """
    if not is_xla_tensor(tensor):
        tensor[indices] = value
        return tensor

    # Append trailing dims until `indices` has as many dims as `tensor`.
    for _ in range(indices.dim(), tensor.dim()):
        indices = indices.unsqueeze(-1)
    if indices.size(-1) < tensor.size(-1):
        indices = indices.expand_as(tensor)
    return torch.mul(tensor, ~indices) + torch.mul(value, indices)
def xla_device_to_cpu(dat):
    """Return `dat` passed through torch_xla's `xm._maybe_convert_to_cpu`."""
    import torch_xla.core.xla_model as xm

    converted = xm._maybe_convert_to_cpu(dat)
    return converted
class CudaEnvironment:
    """Name, compute capability and total memory of the current CUDA device."""

    def __init__(self):
        cur_device = torch.cuda.current_device()
        prop = torch.cuda.get_device_properties(f"cuda:{cur_device}")
        self.name = prop.name
        self.major = prop.major
        self.minor = prop.minor
        # total_memory is divided by 1024 three times to obtain the GB figure.
        self.total_memory_in_GB = prop.total_memory / 1024 / 1024 / 1024

    @staticmethod
    def pretty_print_cuda_env_list(cuda_env_list):
        """
        Given a list of CudaEnvironments, pretty print them
        """
        title = "CUDA enviroments for all {} workers".format(len(cuda_env_list))
        stars = "*" * (40 - len(title) // 2)
        banner = stars + title + stars
        logger.info(banner)
        for rank, env in enumerate(cuda_env_list):
            pieces = [
                "rank {:3d}: ".format(rank),
                "capabilities = {:2d}.{:<2d} ; ".format(env.major, env.minor),
                "total memory = {:.3f} GB ; ".format(env.total_memory_in_GB),
                "name = {:40s}".format(env.name),
            ]
            logger.info("".join(pieces))
        logger.info(banner)
def csv_str_list(x):
    """Split the string `x` on commas and return the pieces as a list."""
    parts = x.split(",")
    return parts
def eval_str_list(x, type=float):
    """Convert `x` into a list whose items are cast with `type`.

    None is returned unchanged. A string is first evaluated with eval();
    if the resulting value cannot be mapped over (TypeError), it is cast
    as a single item and wrapped in a list.
    """
    if x is None:
        return None
    if isinstance(x, str):
        # NOTE: eval() executes arbitrary code; only pass trusted strings.
        x = eval(x)
    try:
        return [type(entry) for entry in x]
    except TypeError:
        return [type(x)]
def eval_str_dict(x, type=dict):
    """Return `x`, evaluating it with eval() first when it is a string.

    None and non-string values are returned unchanged. `type` is accepted
    but not used.
    """
    if x is None:
        return None
    # NOTE: eval() executes arbitrary code; only pass trusted strings.
    return eval(x) if isinstance(x, str) else x
def eval_bool(x, default=False):
    """Evaluate `x` with eval() and return its truth value.

    `default` is returned when `x` is None or when evaluation raises
    TypeError (for example when `x` is not a string).
    """
    if x is None:
        return default
    try:
        # NOTE: eval() executes arbitrary code; only pass trusted strings.
        result = bool(eval(x))
    except TypeError:
        result = default
    return result
def reset_logging():
    """Reset the root logger to a single stdout handler.

    All existing handlers are removed from the root logger, its level is set
    from the LOGLEVEL environment variable (default "INFO"), and one
    StreamHandler writing to sys.stdout is installed.
    """
    root = logging.getLogger()
    # Iterate over a copy: removing handlers while looping over root.handlers
    # itself skips every other handler and leaves stale ones attached.
    for handler in list(root.handlers):
        root.removeHandler(handler)
    root.setLevel(os.environ.get("LOGLEVEL", "INFO").upper())
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(
        logging.Formatter(
            fmt="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
    )
    root.addHandler(handler)
def safe_getattr(obj, k, default=None):
    """Look up `k` on `obj`, falling back to `default`.

    For OmegaConf configs, returns obj[k] if the key exists and its value is
    not None, otherwise `default`. For any other object this is a plain
    `getattr(obj, k, default)`.
    """
    from omegaconf import OmegaConf

    if not OmegaConf.is_config(obj):
        return getattr(obj, k, default)
    if k in obj and obj[k] is not None:
        return obj[k]
    return default
def safe_hasattr(obj, k):
    """Returns True if the given key exists and is not None."""
    value = getattr(obj, k, None)
    return value is not None
def hotreload_function(name=None):
    """
    Decorator factory that enables hot-reload of a function for debugging.

    It lets you debug a function without reloading heavy models, datasets and
    preprocessing, which makes the edit/retry cycle faster. If you want to
    change model or dataset loading, consider relaunching your code instead.
    -----------------------------------
    When pausing is enabled, the decorated function `func` is run and then:
        if func ran successfully:
            It pauses, lets the user edit code, and prompts the user to:
                Press enter to re-run the function with updated code
                Type "done" to finish the function and return its output
                Type "disable" to stop pausing this function and let code continue without pause
                Ctrl + C to terminate
        if func raised an error:
            it prompts the user to
                1. Edit code, and press enter to retry
                2. Ctrl + C to terminate
                3. Type "raise" to raise that exception
    * Requirements:
        0. Fairseq was installed with `pip install --editable .`
        1. pip install jurigged[develoop]
        2. Run on only 1 GPU (no distributed)
    * How to use:
        1. in python, import and decorate the top-level function to be re-run after code edits:
            ```python
            from fairseq.utils import hotreload_function
            ....
            @hotreload_function("train_step")
            def train_step(self, sample ....):
                ....
            ....
            ```
        2. in bash run scripts:
            ```bash
            watch_dir=<home>/fairseq-py/fairseq/tasks # directory to watch for file changes
            export CUDA_VISIBLE_DEVICES=0 # single-gpu
            HOTRELOAD_PAUSE=1 CUDA_LAUNCH_BLOCKING=1 python -m jurigged -w ${watch_dir} --poll 2 -v train.py ......
            ```
    * NOTE:
        1. -w ${watch_dir} specifies all the files to be watched for changes;
            once functions, classes, ... are changed, all instances in the process get updated (hot-reload)
    * Limitation:
        * Currently distributed debugging is not working
        * Need to launch train.py locally (cannot submit jobs)

    Args:
        name: label used in log messages and prompts; defaults to the
            decorated function's __name__.

    Returns:
        A decorator that wraps the given callable.

    Raises:
        ImportError: if jurigged is not installed.
    """
    # jurigged is only checked for availability here; it is not called below.
    try:
        import jurigged
    except ImportError as e:
        logger.warning("Please install jurigged: pip install jurigged[develoop]")
        raise e
    from fairseq.distributed import utils as distributed_utils
    import traceback

    def hotreload_decorator(func):
        assert callable(func), f"not callable: {func}"
        jname = name or func.__name__
        logger.info(f"jurigged-hotreload:Apply jurigged on {jname}:{func.__name__}")
        # Both flags are read once, at decoration time. bool() of a non-empty
        # string is True, so any non-empty value (even "0") enables the flag.
        HOTRELOAD_PAUSE = bool(os.environ.get("HOTRELOAD_PAUSE", 0))
        cublk = bool(os.environ.get("CUDA_LAUNCH_BLOCKING", 0))
        prefix = f"HOTRELOAD:{jname}:[cublk={cublk}]"
        # Mutable holder so func_wrapper can switch pausing off for later calls.
        hot_reload_state = {"disable": False}

        def func_wrapper(*args, **kwargs):
            # Fast path: pausing not requested, or disabled by the user earlier.
            if not HOTRELOAD_PAUSE or hot_reload_state["disable"]:
                return func(*args, **kwargs)
            world_size = distributed_utils.get_global_world_size()
            assert (
                world_size <= 1
            ), f"HOTRELOAD_PAUSE:{jname} currently cannot do distributed training"
            # Re-run func until the user answers "done" or "disable" after a
            # successful run; every other answer triggers another run.
            success = False
            while not success:
                try:
                    output = func(*args, **kwargs)
                    end_action = input(
                        f"{prefix}: PAUSE, you may edit code now. Enter to re-run, ctrl+C to terminate, "
                        f'type "done" to continue (function still being watched), or type "disable" to stop pausing this function :'
                    )
                    if end_action.strip().lower() in ["disable", "done"]:
                        success = True
                    else:
                        logger.warning(
                            f"{prefix}: action={end_action} function will re-run now."
                        )
                except Exception as e:
                    # func (or the prompt) failed: show the traceback and let the
                    # user either retry or re-raise the exception.
                    action = input(
                        f"{prefix}:ERROR: \n{traceback.format_exc()}\n"
                        f'Edit code to try again: enter to continue, ctrl+C to terminate, or type "raise" to raise the exception: '
                    )
                    if action.strip().lower() == "raise":
                        raise e
            # The loop only ends after a successful run, so end_action and
            # output are both bound here.
            if end_action.strip().lower() == "disable":
                logger.warning(
                    f"{prefix}: Stop pausing {jname}. The function is still being watched and newly editted code will take effect "
                    f"if the {jname} is called again later."
                    f' "unset HOTRELOAD_PAUSE" before relaunch to disable hotreload and'
                    f" remove @hotreload_function decorator in the code."
                )
                hot_reload_state["disable"] = True
            return output

        return func_wrapper

    return hotreload_decorator