|
|
|
|
|
|
|
|
|
|
|
import logging |
|
from typing import List |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
def uniform(dataset_sizes: List[int]):
    """Return one weight of ``1.0`` per entry in ``dataset_sizes``.

    The size values themselves are ignored; only the number of entries
    determines the length of the returned list.
    """
    n_datasets = len(dataset_sizes)
    return [1.0 for _ in range(n_datasets)]
|
|
|
|
|
def temperature_sampling(dataset_sizes, temp):
    """Return temperature-scaled weights for the given dataset sizes.

    Each weight is the dataset's share of the summed sizes raised to the
    power ``1 / temp``. The weights are returned as computed; they are not
    re-normalized to sum to one.

    Raises:
        ZeroDivisionError: if the sizes are non-empty and sum to zero, or
            if ``temp`` is zero while there is at least one size.
    """
    denominator = sum(dataset_sizes)
    weights = []
    for n_examples in dataset_sizes:
        share = n_examples / denominator
        # The exponent is evaluated per item, so an empty input never
        # touches ``temp``.
        weights.append(share ** (1.0 / temp))
    return weights
|
|
|
|
|
def make_temperature_sampling(temp=1.0):
    """Build a sampling function bound to a fixed temperature.

    The returned callable takes ``dataset_sizes`` and forwards them,
    together with the captured ``temp``, to ``temperature_sampling``.
    """

    def sampling_func(dataset_sizes):
        """Return temperature-scaled weights using the captured ``temp``."""
        return temperature_sampling(dataset_sizes, temp=temp)

    return sampling_func
|
|
|
|
|
def make_ratio_sampling(ratios):
    """Build a sampling function that always yields the given ``ratios``.

    The returned callable accepts ``dataset_sizes`` for signature
    compatibility with the other sampling functions but does not use it.
    """

    def sampling_func(dataset_sizes):
        """Return the captured ``ratios`` object unchanged."""
        # ``dataset_sizes`` is intentionally unused: the ratios are fixed.
        return ratios

    return sampling_func
|
|
|
|
|
class SamplingMethod:
    """Choose a dataset sampling function from parsed command-line args.

    Holds the ``args`` and ``task`` objects it was built with and maps
    ``args.sampling_method`` to one of the module's sampling functions.
    """

    @staticmethod
    def add_arguments(parser):
        """Register the sampling-related options on ``parser``."""
        parser.add_argument(
            "--sampling-method",
            type=str,
            default="concat",
            choices=[
                "uniform",
                "temperature",
                "concat",
                "RoundRobin",
            ],
            help="The method to sample data per language pairs",
        )
        parser.add_argument(
            "--sampling-temperature",
            type=float,
            default=1.5,
            help="only work with --sampling-method temperature",
        )

    @staticmethod
    def build_sampler(args, task):
        """Create a ``SamplingMethod`` from ``args`` and ``task``."""
        return SamplingMethod(args, task)

    def __init__(self, args, task):
        """Store ``args`` and ``task`` for later use."""
        self.task = task
        self.args = args

    def is_adaptive(self):
        """Return ``False``: this sampler is not adaptive.

        NOTE(review): presumably a hook that subclasses override — confirm.
        """
        return False

    def sampling_method_selector(self):
        """Return the sampling function matching ``args.sampling_method``.

        Returns:
            ``uniform`` for ``"uniform"``; a temperature sampling function
            built from ``args.sampling_temperature`` for ``"temperature"``
            or whenever ``is_adaptive()`` is true; ``None`` otherwise.
        """
        args = self.args
        logger.info(f"selected sampler: {args.sampling_method}")

        if args.sampling_method == "uniform":
            return uniform

        wants_temperature = (
            args.sampling_method == "temperature" or self.is_adaptive()
        )
        if not wants_temperature:
            # No dedicated sampling function for the remaining methods.
            # NOTE(review): presumably the caller treats None as plain
            # concatenation of the datasets — confirm against the caller.
            return None

        return make_temperature_sampling(float(args.sampling_temperature))
|
|