File size: 6,055 Bytes
c968fc3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 |
# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import torch as th
def get_generator(generator, num_samples=0, seed=0):
    """Return the random-number generator selected by ``generator``.

    Args:
        generator: Name of the generator kind; one of ``"dummy"``,
            ``"determ"`` or ``"determ-indiv"``.
        num_samples: Total number of samples, forwarded to the deterministic
            generators. Not used for ``"dummy"``.
        seed: Seed forwarded to the deterministic generators. Not used for
            ``"dummy"``.

    Returns:
        A ``DummyGenerator``, ``DeterministicGenerator`` or
        ``DeterministicIndividualGenerator`` instance.

    Raises:
        NotImplementedError: If ``generator`` is not one of the known names.
    """
    if generator == "dummy":
        return DummyGenerator()
    if generator == "determ":
        return DeterministicGenerator(num_samples, seed)
    if generator == "determ-indiv":
        return DeterministicIndividualGenerator(num_samples, seed)
    # Name the offending value so a typo in a config is easy to diagnose.
    raise NotImplementedError(
        f"Unknown generator {generator!r}; "
        "expected 'dummy', 'determ' or 'determ-indiv'"
    )
class DummyGenerator:
    """Pass-through generator: every call is forwarded to the matching torch function."""

    def randn(self, *positional, **named):
        """Forward all arguments to ``th.randn`` and return its result."""
        sample = th.randn(*positional, **named)
        return sample

    def randint(self, *positional, **named):
        """Forward all arguments to ``th.randint`` and return its result."""
        sample = th.randint(*positional, **named)
        return sample

    def randn_like(self, *positional, **named):
        """Forward all arguments to ``th.randn_like`` and return its result."""
        sample = th.randn_like(*positional, **named)
        return sample
class DeterministicGenerator:
    """
    RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines
    Uses a single rng: each call draws randomness for all num_samples samples and keeps only the rows
    that belong to the current indices.
    """

    def __init__(self, num_samples, seed=0):
        """Create the CPU (and, when available, CUDA) generators and seed them.

        Args:
            num_samples: Total number of samples; used as the leading
                dimension of every global draw.
            seed: Seed applied to all underlying generators.
        """
        print("Warning: Distributed not initialised, using single rank")
        self.rank = 0
        self.world_size = 1
        self.num_samples = num_samples
        self.done_samples = 0
        self.seed = seed
        self.rng_cpu = th.Generator()
        if th.cuda.is_available():
            # Build the CUDA generator directly through torch so this class
            # does not depend on any helper outside this module.
            self.rng_cuda = th.Generator(device="cuda")
        self.set_seed(seed)

    def get_global_size_and_indices(self, size):
        """Compute the full draw shape and the rows to keep for this call.

        Args:
            size: Requested shape; ``size[0]`` is the batch size.

        Returns:
            A pair ``(global_size, indices)`` where ``global_size`` is
            ``(num_samples, *size[1:])`` and ``indices`` is a 1-D tensor of
            ``size[0]`` row indices starting at ``done_samples + rank`` with
            stride ``world_size``.
        """
        global_size = (self.num_samples, *size[1:])
        indices = th.arange(
            self.done_samples + self.rank,
            self.done_samples + self.world_size * int(size[0]),
            self.world_size,
        )
        # Indices running past the last sample are pinned to the last valid row.
        indices = th.clamp(indices, 0, self.num_samples - 1)
        assert (
            len(indices) == size[0]
        ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
        return global_size, indices

    def get_generator(self, device):
        """Return the CPU generator for a CPU ``device``, else the CUDA generator."""
        return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda

    def randn(self, *size, dtype=th.float, device="cpu"):
        """Draw normal noise for all samples and return the current rows.

        Args:
            *size: Requested shape; ``size[0]`` is the batch size.
            dtype: Dtype of the returned tensor.
            device: Device on which the noise is drawn.

        Returns:
            Tensor whose leading dimension is ``size[0]``, selected from a
            ``(num_samples, *size[1:])`` draw.
        """
        global_size, indices = self.get_global_size_and_indices(size)
        generator = self.get_generator(device)
        full_draw = th.randn(
            *global_size, generator=generator, dtype=dtype, device=device
        )
        return full_draw[indices]

    def randint(self, low, high, size, dtype=th.long, device="cpu"):
        """Draw integers via ``th.randint(low, high, ...)`` for all samples and return the current rows.

        Args:
            low: Lower bound passed to ``th.randint``.
            high: Upper bound passed to ``th.randint``.
            size: Requested shape; ``size[0]`` is the batch size.
            dtype: Dtype of the returned tensor.
            device: Device on which the integers are drawn.

        Returns:
            Tensor whose leading dimension is ``size[0]``, selected from a
            ``(num_samples, *size[1:])`` draw.
        """
        global_size, indices = self.get_global_size_and_indices(size)
        generator = self.get_generator(device)
        full_draw = th.randint(
            low, high, generator=generator, size=global_size, dtype=dtype, device=device
        )
        return full_draw[indices]

    def randn_like(self, tensor):
        """Return ``randn`` noise with the size, dtype and device of ``tensor``."""
        size, dtype, device = tensor.size(), tensor.dtype, tensor.device
        return self.randn(*size, dtype=dtype, device=device)

    def set_done_samples(self, done_samples):
        """Record how many samples are finished and reseed with the stored seed.

        Reseeding makes the next global draw start from the same generator
        state, so the rows selected by the new offset come from the same
        random stream as earlier calls.
        """
        self.done_samples = done_samples
        self.set_seed(self.seed)

    def get_seed(self):
        """Return the seed stored at construction (``set_seed`` does not update it)."""
        return self.seed

    def set_seed(self, seed):
        """Seed the CPU generator and, when CUDA is available, the CUDA one.

        Note: this does not change ``self.seed``.
        """
        self.rng_cpu.manual_seed(seed)
        if th.cuda.is_available():
            self.rng_cuda.manual_seed(seed)
class DeterministicIndividualGenerator:
    """
    RNG to deterministically sample num_samples samples that does not depend on batch_size or mpi_machines
    Uses a separate rng for each sample to reduce memory usage
    """

    def __init__(self, num_samples, seed=0):
        """Create one CPU (and, when available, one CUDA) generator per sample.

        Args:
            num_samples: Total number of samples; one generator is created
                for each.
            seed: Base seed from which each per-sample seed is derived.
        """
        print("Warning: Distributed not initialised, using single rank")
        self.rank = 0
        self.world_size = 1
        self.num_samples = num_samples
        self.done_samples = 0
        self.seed = seed
        self.rng_cpu = [th.Generator() for _ in range(num_samples)]
        if th.cuda.is_available():
            # Build the CUDA generators directly through torch so this class
            # does not depend on any helper outside this module.
            self.rng_cuda = [th.Generator(device="cuda") for _ in range(num_samples)]
        self.set_seed(seed)

    def get_size_and_indices(self, size):
        """Compute the per-sample draw shape and the sample indices for this call.

        Args:
            size: Requested shape; ``size[0]`` is the batch size.

        Returns:
            A pair ``(per_sample_size, indices)`` where ``per_sample_size`` is
            ``(1, *size[1:])`` and ``indices`` is a 1-D tensor of ``size[0]``
            sample indices starting at ``done_samples + rank`` with stride
            ``world_size``.
        """
        indices = th.arange(
            self.done_samples + self.rank,
            self.done_samples + self.world_size * int(size[0]),
            self.world_size,
        )
        # Indices running past the last sample are pinned to the last valid one.
        indices = th.clamp(indices, 0, self.num_samples - 1)
        assert (
            len(indices) == size[0]
        ), f"rank={self.rank}, ws={self.world_size}, l={len(indices)}, bs={size[0]}"
        return (1, *size[1:]), indices

    def get_generator(self, device):
        """Return the list of CPU generators for a CPU ``device``, else the CUDA list."""
        return self.rng_cpu if th.device(device).type == "cpu" else self.rng_cuda

    def randn(self, *size, dtype=th.float, device="cpu"):
        """Draw normal noise one sample at a time, each from its own generator.

        Args:
            *size: Requested shape; ``size[0]`` is the batch size.
            dtype: Dtype of the returned tensor.
            device: Device on which the noise is drawn.

        Returns:
            Tensor formed by concatenating one ``(1, *size[1:])`` draw per
            selected sample index along dimension 0.
        """
        sample_size, indices = self.get_size_and_indices(size)
        generators = self.get_generator(device)
        draws = [
            th.randn(*sample_size, generator=generators[i], dtype=dtype, device=device)
            for i in indices.tolist()
        ]
        return th.cat(draws, dim=0)

    def randint(self, low, high, size, dtype=th.long, device="cpu"):
        """Draw integers one sample at a time, each from its own generator.

        Args:
            low: Lower bound passed to ``th.randint``.
            high: Upper bound passed to ``th.randint``.
            size: Requested shape; ``size[0]`` is the batch size.
            dtype: Dtype of the returned tensor.
            device: Device on which the integers are drawn.

        Returns:
            Tensor formed by concatenating one ``(1, *size[1:])`` draw per
            selected sample index along dimension 0.
        """
        sample_size, indices = self.get_size_and_indices(size)
        generators = self.get_generator(device)
        draws = [
            th.randint(
                low,
                high,
                generator=generators[i],
                size=sample_size,
                dtype=dtype,
                device=device,
            )
            for i in indices.tolist()
        ]
        return th.cat(draws, dim=0)

    def randn_like(self, tensor):
        """Return ``randn`` noise with the size, dtype and device of ``tensor``."""
        size, dtype, device = tensor.size(), tensor.dtype, tensor.device
        return self.randn(*size, dtype=dtype, device=device)

    def set_done_samples(self, done_samples):
        """Record how many samples are finished; generators are not reseeded."""
        self.done_samples = done_samples

    def get_seed(self):
        """Return the seed stored at construction (``set_seed`` does not update it)."""
        return self.seed

    def set_seed(self, seed):
        """Seed generator ``i`` with ``i + num_samples * seed``.

        Applies to the CPU generators and, when CUDA is available, to the
        CUDA generators. Note: this does not change ``self.seed``.
        """
        for i, rng_cpu in enumerate(self.rng_cpu):
            rng_cpu.manual_seed(i + self.num_samples * seed)
        if th.cuda.is_available():
            for i, rng_cuda in enumerate(self.rng_cuda):
                rng_cuda.manual_seed(i + self.num_samples * seed)
|