Spaces:
Running
on
Zero
Running
on
Zero
from abc import ABC, abstractmethod | |
import numpy as np | |
import torch as th | |
import torch.distributed as dist | |
def create_named_schedule_sampler(name, diffusion):
    """
    Create a ScheduleSampler from a library of pre-defined samplers.

    :param name: the name of the sampler. Only "uniform" is recognized.
    :param diffusion: the diffusion object to sample for, or the number of
                      diffusion steps directly. When the argument has a
                      ``num_timesteps`` attribute, that value is used.
    :return: a ScheduleSampler instance.
    :raises NotImplementedError: if ``name`` is not a known sampler.
    """
    if name == "uniform":
        # UniformSampler is constructed from a step count, not from a
        # diffusion object. Unwrap ``diffusion.num_timesteps`` when present;
        # a plain integer falls through getattr's default unchanged.
        num_timesteps = getattr(diffusion, "num_timesteps", diffusion)
        return UniformSampler(num_timesteps)
    raise NotImplementedError(f"unknown schedule sampler: {name}")
class ScheduleSampler(ABC):
    """
    A distribution over timesteps in the diffusion process, intended to reduce
    variance of the objective.

    By default, samplers perform unbiased importance sampling, in which the
    objective's mean is unchanged.
    However, subclasses may override sample() to change how the resampled
    terms are reweighted, allowing for actual changes in the objective.

    Subclasses must implement weights().
    """

    @abstractmethod
    def weights(self):
        """
        Get a numpy array of weights, one per diffusion step.

        The weights needn't be normalized, but must be positive.
        """
        raise NotImplementedError

    def sample(self, batch_size, device):
        """
        Importance-sample timesteps for a batch.

        :param batch_size: the number of timesteps.
        :param device: the torch device to save to.
        :return: a tuple (timesteps, weights):
                 - timesteps: a tensor of timestep indices (int64).
                 - weights: a tensor of weights (float32) to scale the
                   resulting losses.
        """
        w = self.weights()
        # Normalize the raw weights into a probability distribution.
        p = w / np.sum(w)
        indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
        indices = th.from_numpy(indices_np).long().to(device)
        # Importance weight 1 / (N * p_t): the expectation of the weighted
        # loss under p equals the plain mean over all N steps, so the
        # objective stays unbiased.
        weights_np = 1 / (len(p) * p[indices_np])
        weights = th.from_numpy(weights_np).float().to(device)
        return indices, weights
class UniformSampler(ScheduleSampler):
    """
    Sampler that draws every diffusion step with equal probability.

    :param num_timesteps: how many diffusion steps to sample from.
    """

    def __init__(self, num_timesteps):
        # Constant weight vector; ScheduleSampler.sample() normalizes it, so
        # each step is drawn with probability 1 / num_timesteps.
        self._weights = np.ones((num_timesteps,))

    def weights(self):
        """Return the unnormalized all-ones weight vector."""
        return self._weights