TomatoCocotree
上传
6a62ffb
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import logging
from typing import List
logger = logging.getLogger(__name__)
def uniform(dataset_sizes: List[int]):
return [1.0] * len(dataset_sizes)
def temperature_sampling(dataset_sizes, temp):
total_size = sum(dataset_sizes)
return [(size / total_size) ** (1.0 / temp) for size in dataset_sizes]
def make_temperature_sampling(temp=1.0):
def sampling_func(dataset_sizes):
return temperature_sampling(dataset_sizes, temp)
return sampling_func
def make_ratio_sampling(ratios):
def sampling_func(dataset_sizes):
return ratios
return sampling_func
class SamplingMethod:
@staticmethod
def add_arguments(parser):
parser.add_argument(
"--sampling-method",
choices=[
"uniform",
"temperature",
"concat",
"RoundRobin",
],
type=str,
default="concat",
help="The method to sample data per language pairs",
)
parser.add_argument(
"--sampling-temperature",
default=1.5,
type=float,
help="only work with --sampling-method temperature",
)
@staticmethod
def build_sampler(args, task):
return SamplingMethod(args, task)
def __init__(self, args, task):
self.args = args
self.task = task
def is_adaptive(self):
return False
def sampling_method_selector(self):
args = self.args
logger.info(f"selected sampler: {args.sampling_method}")
if args.sampling_method == "uniform":
return uniform
elif args.sampling_method == "temperature" or self.is_adaptive():
return make_temperature_sampling(float(args.sampling_temperature))
else:
# default to concating all data set together
return None