class EpsilonTarget:
    """Parameterization in which the model's prediction is the noise itself.

    The conversions below use the forward relation implemented by ``xt``:
    ``noised = a * x0 + b * epsilon``.  ``logSNR`` is accepted by every
    method but is not read by any of them.
    """

    def __call__(self, x0, epsilon, logSNR, a, b):
        """Return the training target, which is ``epsilon`` unchanged."""
        return epsilon

    def x0(self, noised, pred, logSNR, a, b):
        """Recover the clean sample from ``noised`` and a noise prediction.

        Solves ``noised = a * x0 + b * pred`` for ``x0``; there is no
        guard against ``a == 0``.
        """
        signal_part = noised - pred * b
        return signal_part / a

    def epsilon(self, noised, pred, logSNR, a, b):
        """Return the noise estimate, which is the prediction itself."""
        return pred

    def noise_givenx0_noised(self, x0, noised, logSNR, a, b):
        """Return the noise that maps ``x0`` to ``noised``.

        Solves ``noised = a * x0 + b * noise`` for ``noise``; there is no
        guard against ``b == 0``.
        """
        noise_part = noised - a * x0
        return noise_part / b

    def xt(self, x0, noise, logSNR, a, b):
        """Return the noised sample ``x0 * a + noise * b``."""
        return x0 * a + noise * b
class X0Target:
    """Parameterization in which the model's prediction is the clean sample.

    ``logSNR`` is accepted by every method but is not read by any of them.
    """

    def __call__(self, x0, epsilon, logSNR, a, b):
        """Return the training target, which is ``x0`` unchanged."""
        return x0

    def x0(self, noised, pred, logSNR, a, b):
        """Return the clean-sample estimate, which is the prediction itself."""
        return pred

    def epsilon(self, noised, pred, logSNR, a, b):
        """Recover the noise from ``noised`` and a clean-sample prediction.

        Solves ``noised = a * pred + b * epsilon`` for ``epsilon``; there
        is no guard against ``b == 0``.
        """
        noise_part = noised - pred * a
        return noise_part / b
class VTarget:
    """Parameterization whose target is ``v = a * epsilon - b * x0``.

    With ``noised = a * x0 + b * epsilon`` the two conversions below are
    exact inverses for any ``a`` and ``b`` with ``a**2 + b**2 != 0``.
    ``logSNR`` is accepted by every method but is not read by any of them.
    """

    def __call__(self, x0, epsilon, logSNR, a, b):
        """Return the training target ``a * epsilon - b * x0``."""
        return a * epsilon - b * x0

    def x0(self, noised, pred, logSNR, a, b):
        """Recover the clean sample from ``noised`` and a ``v`` prediction.

        Uses ``a * noised - b * v == (a**2 + b**2) * x0``.
        """
        norm = a**2 + b**2
        noised_weight = a / norm
        pred_weight = b / norm
        return noised_weight * noised - pred_weight * pred

    def epsilon(self, noised, pred, logSNR, a, b):
        """Recover the noise from ``noised`` and a ``v`` prediction.

        Uses ``b * noised + a * v == (a**2 + b**2) * epsilon``.
        """
        norm = a**2 + b**2
        noised_weight = b / norm
        pred_weight = a / norm
        return noised_weight * noised + pred_weight * pred
class RectifiedFlowsTarget:
    """Parameterization whose target is the difference ``epsilon - x0``.

    NOTE(review): the two conversions below recover ``x0`` and ``epsilon``
    exactly from ``noised = a * x0 + b * epsilon`` only when ``a + b == 1``;
    confirm that the schedule in use satisfies this.
    ``logSNR`` is accepted by every method but is not read by any of them.
    """

    def __call__(self, x0, epsilon, logSNR, a, b):
        """Return the training target ``epsilon - x0``."""
        return epsilon - x0

    def x0(self, noised, pred, logSNR, a, b):
        """Return the clean-sample estimate ``noised - pred * b``."""
        noise_step = pred * b
        return noised - noise_step

    def epsilon(self, noised, pred, logSNR, a, b):
        """Return the noise estimate ``noised + pred * a``."""
        signal_step = pred * a
        return noised + signal_step