# Plonk / models / schedulers.py
# nicolas-dufour's picture
# squash: merge all unpushed commits
# c4c7cee
import torch
class SigmoidScheduler:
    """Sigmoid-shaped schedule over ``t``.

    The curve is ``(v_end - sigmoid(x)) / (v_end - v_start)`` with
    ``x = (t * (end - start) + start) / tau``, so it evaluates to 1 at
    ``t = 0`` and to 0 at ``t = 1`` before clamping. The returned value is
    clamped to ``[clip_min, 1.0]``.

    Args:
        start: Sigmoid argument (before dividing by ``tau``) at ``t = 0``.
        end: Sigmoid argument (before dividing by ``tau``) at ``t = 1``.
        tau: Divisor applied to the sigmoid argument.
        clip_min: Lower bound applied to the schedule value.
    """

    def __init__(self, start=-3, end=3, tau=1, clip_min=1e-9):
        self.start, self.end = start, end
        self.tau, self.clip_min = tau, clip_min
        # Sigmoid evaluated at both window edges; these rescale the curve.
        edge_lo, edge_hi = (
            torch.sigmoid(torch.tensor(edge / self.tau))
            for edge in (self.start, self.end)
        )
        self.v_start = edge_lo
        self.v_end = edge_hi

    def _scaled_time(self, t):
        """Map ``t`` to the argument fed to the sigmoid."""
        return (t * (self.end - self.start) + self.start) / self.tau

    def __call__(self, t):
        """Return the schedule value at ``t``, clamped to ``[clip_min, 1.0]``."""
        squashed = torch.sigmoid(self._scaled_time(t))
        rescaled = (-squashed + self.v_end) / (self.v_end - self.v_start)
        return torch.clamp(rescaled, min=self.clip_min, max=1.0)

    def derivative(self, t):
        """Return d/dt of the unclamped schedule at ``t``."""
        squashed = torch.sigmoid(self._scaled_time(t))
        # Chain rule: sigmoid'(x) = sigmoid(x) * (1 - sigmoid(x)),
        # dx/dt = (end - start) / tau, and the curve carries a leading minus.
        neg_span = -(self.end - self.start)
        numerator = neg_span * squashed * (1 - squashed)
        return numerator / (self.tau * (self.v_end - self.v_start))

    def alpha(self, t):
        """Return ``-derivative(t) / (1e-6 + schedule(t))``."""
        slope = self.derivative(t)
        level = self.__call__(t)
        # The 1e-6 offset keeps the ratio finite where the schedule is ~0.
        return -slope / (1e-6 + level)
class LinearScheduler:
    """Straight-line schedule going from ``start`` at ``t = 0`` to ``end`` at ``t = 1``.

    The returned value is clamped to ``[clip_min, 1.0]``.

    Args:
        start: Schedule value at ``t = 0``.
        end: Schedule value at ``t = 1`` (before clamping).
        clip_min: Lower bound applied to the schedule value.
    """

    def __init__(self, start=1, end=0, clip_min=1e-9):
        self.start, self.end, self.clip_min = start, end, clip_min

    def __call__(self, t):
        """Return the schedule value at ``t``, clamped to ``[clip_min, 1.0]``."""
        slope = self.end - self.start
        line = slope * t + self.start
        return torch.clamp(line, min=self.clip_min, max=1.0)

    def derivative(self, t):
        """Return the constant slope ``end - start`` on ``t``'s device.

        The result is a 0-dim tensor rather than one shaped like ``t``; its
        dtype follows the Python type of ``end - start`` (an integer tensor
        with the default arguments).
        """
        slope = torch.tensor(self.end - self.start)
        return slope.to(t.device)

    def alpha(self, t):
        """Return ``-derivative(t) / (1e-6 + schedule(t))``."""
        slope = self.derivative(t)
        level = self.__call__(t)
        # The 1e-6 offset keeps the ratio finite where the schedule is ~0.
        return -slope / (1e-6 + level)
class CosineScheduler:
    """Cosine-power schedule over ``t``.

    The curve is ``(cos(x) ** (2 * tau) - v_end) / (v_start - v_end)`` with
    ``x = (t * (end - start) + start) * pi / 2``, where ``v_start`` and
    ``v_end`` are ``cos(. * pi / 2) ** (2 * tau)`` evaluated at ``start`` and
    ``end``. With the default arguments this reduces, up to floating-point
    rounding, to ``cos(pi * t / 2) ** 2``. The returned value is clamped to
    ``[clip_min, 1.0]``.

    Args:
        start: Offset of the cosine argument at ``t = 0`` (in units of pi/2).
        end: Offset of the cosine argument at ``t = 1`` (in units of pi/2).
        tau: Half of the exponent applied to the cosine.
        clip_min: Lower bound applied to the schedule value.
    """

    def __init__(
        self,
        start: float = 1,
        end: float = 0,
        tau: float = 1.0,
        clip_min: float = 1e-9,
    ):
        self.start = start
        self.end = end
        self.tau = tau
        self.clip_min = clip_min
        # Cosine-power values at both edges; used to rescale the curve.
        self.v_start = torch.cos(torch.tensor(self.start) * torch.pi / 2) ** (
            2 * self.tau
        )
        self.v_end = torch.cos(torch.tensor(self.end) * torch.pi / 2) ** (2 * self.tau)

    def _angle(self, t: "float | torch.Tensor") -> torch.Tensor:
        """Map ``t`` to the cosine argument, coercing plain numbers to tensors."""
        # torch.cos / torch.sin only accept tensors; as_tensor leaves an
        # existing tensor untouched, so tensor callers see no change.
        t = torch.as_tensor(t)
        return (t * (self.end - self.start) + self.start) * torch.pi / 2

    def __call__(self, t: "float | torch.Tensor") -> torch.Tensor:
        """Return the schedule value at ``t``, clamped to ``[clip_min, 1.0]``.

        Args:
            t: Time as a tensor or a plain number.

        Returns:
            A tensor holding the clamped schedule value(s).
        """
        output = (torch.cos(self._angle(t)) ** (2 * self.tau) - self.v_end) / (
            self.v_start - self.v_end
        )
        return torch.clamp(output, min=self.clip_min, max=1.0)

    def derivative(self, t: "float | torch.Tensor") -> torch.Tensor:
        """Return d/dt of the unclamped schedule at ``t``.

        Args:
            t: Time as a tensor or a plain number.

        Returns:
            A tensor holding the slope of the unclamped curve.
        """
        x = self._angle(t)
        exponent = 2 * self.tau
        dx_dt = (self.end - self.start) * torch.pi / 2
        # Power rule: d/dx cos(x)**p = -p * cos(x)**(p - 1) * sin(x).
        # The cosine appears exactly once, raised to (p - 1); any additional
        # cosine factor would no longer match the slope of __call__.
        d_power = -exponent * torch.cos(x) ** (exponent - 1) * torch.sin(x)
        return d_power * dx_dt / (self.v_start - self.v_end)
class CosineSchedulerSimple:
    """Squared-cosine schedule ``cos(((t + ns) / (1 + ds)) * pi / 2) ** 2``.

    No clamping is applied to the returned value.

    Args:
        ns: Offset added to ``t`` before scaling.
        ds: Offset added to 1 in the divisor that rescales ``t + ns``.
    """

    def __init__(self, ns: float = 0.0002, ds: float = 0.00025):
        self.ns = ns
        self.ds = ds

    def _angle(self, t: "float | torch.Tensor") -> torch.Tensor:
        """Map ``t`` to the cosine argument, coercing plain numbers to tensors."""
        # torch.cos / torch.sin only accept tensors; as_tensor leaves an
        # existing tensor untouched, so tensor callers see no change.
        t = torch.as_tensor(t)
        return ((t + self.ns) / (1 + self.ds)) * torch.pi / 2

    def __call__(self, t: "float | torch.Tensor") -> torch.Tensor:
        """Return the schedule value at ``t``.

        Args:
            t: Time as a tensor or a plain number.

        Returns:
            A tensor holding the schedule value(s).
        """
        return torch.cos(self._angle(t)) ** 2

    def derivative(self, t: "float | torch.Tensor") -> torch.Tensor:
        """Return d/dt of the schedule at ``t``.

        Args:
            t: Time as a tensor or a plain number.

        Returns:
            A tensor holding the slope of the schedule.
        """
        x = self._angle(t)
        # d/dt cos(x)**2 = -2 * cos(x) * sin(x) * dx/dt, dx/dt = (pi/2) / (1 + ds).
        return -torch.pi * torch.cos(x) * torch.sin(x) / (1 + self.ds)