File size: 1,365 Bytes
5231633 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 |
class EpsilonTarget:
    """Parameterisation in which the network predicts the added noise.

    Every method receives the schedule coefficients ``a`` and ``b`` and
    treats the noised sample as ``a * x0 + b * noise``. ``logSNR`` is part
    of the shared call signature but is not used here.
    """

    def __call__(self, x0, epsilon, logSNR, a, b):
        """Return the regression target, which is simply the noise."""
        return epsilon

    def x0(self, noised, pred, logSNR, a, b):
        """Recover the clean sample from a noise prediction (``a`` must be non-zero)."""
        signal_part = noised - pred * b
        return signal_part / a

    def epsilon(self, noised, pred, logSNR, a, b):
        """The prediction already is the noise estimate; hand it back unchanged."""
        return pred

    def noise_givenx0_noised(self, x0, noised, logSNR, a, b):
        """Solve ``noised = a * x0 + b * noise`` for ``noise`` (``b`` must be non-zero)."""
        noise_part = noised - a * x0
        return noise_part / b

    def xt(self, x0, noise, logSNR, a, b):
        """Build the noised sample ``a * x0 + b * noise``."""
        scaled_signal = x0 * a
        scaled_noise = noise * b
        return scaled_signal + scaled_noise
class X0Target:
    """Parameterisation in which the network predicts the clean sample ``x0``.

    The noised sample is taken to be ``a * x0 + b * epsilon``. ``logSNR``
    belongs to the shared signature but is ignored.
    """

    def __call__(self, x0, epsilon, logSNR, a, b):
        """Return the regression target, which is the clean sample."""
        return x0

    def x0(self, noised, pred, logSNR, a, b):
        """The prediction is already the clean-sample estimate; pass it through."""
        return pred

    def epsilon(self, noised, pred, logSNR, a, b):
        """Infer the noise implied by a clean-sample prediction.

        Solves ``noised = a * pred + b * epsilon`` for ``epsilon``; ``b``
        must be non-zero.
        """
        residual = noised - pred * a
        return residual / b
class VTarget:
    """"v" parameterisation: the target is ``a * epsilon - b * x0``.

    The inversions assume ``noised = a * x0 + b * epsilon`` and normalise by
    ``a**2 + b**2``, so they stay exact even when that sum is not 1.
    ``logSNR`` is accepted for signature compatibility but unused.
    """

    def __call__(self, x0, epsilon, logSNR, a, b):
        """Return the v target ``a * epsilon - b * x0``."""
        return a * epsilon - b * x0

    def x0(self, noised, pred, logSNR, a, b):
        """Recover ``x0 = (a * noised - b * v) / (a**2 + b**2)``."""
        norm = a**2 + b**2
        coef_noised = a / norm
        coef_pred = b / norm
        return coef_noised * noised - coef_pred * pred

    def epsilon(self, noised, pred, logSNR, a, b):
        """Recover ``epsilon = (b * noised + a * v) / (a**2 + b**2)``."""
        norm = a**2 + b**2
        coef_noised = b / norm
        coef_pred = a / norm
        return coef_noised * noised + coef_pred * pred
class RectifiedFlowsTarget:
    """Rectified-flow parameterisation: the network predicts ``epsilon - x0``.

    The inversions assume the forward process ``noised = a * x0 + b * epsilon``.
    Substituting ``epsilon = x0 + v`` gives ``noised = (a + b) * x0 + b * v``,
    hence::

        x0      = (noised - b * v) / (a + b)
        epsilon = (noised + a * v) / (a + b)

    For the usual rectified-flow schedule (``a = 1 - t``, ``b = t``) the
    divisor is 1 and these reduce to ``noised - b * v`` and
    ``noised + a * v``. Dividing explicitly keeps the inversion exact for any
    other schedule (e.g. a variance-preserving one with ``a**2 + b**2 == 1``),
    where omitting it silently yields wrong estimates. ``a + b`` must be
    non-zero. ``logSNR`` is accepted for signature compatibility but unused.
    """

    def __call__(self, x0, epsilon, logSNR, a, b):
        """Return the regression target ``epsilon - x0``."""
        return epsilon - x0

    def x0(self, noised, pred, logSNR, a, b):
        """Recover the clean sample from a velocity prediction ``pred``."""
        # Divide by (a + b) rather than assuming the schedule sums to 1.
        return (noised - pred * b) / (a + b)

    def epsilon(self, noised, pred, logSNR, a, b):
        """Recover the noise from a velocity prediction ``pred``."""
        # Same normalisation as ``x0`` so that epsilon - x0 == pred holds.
        return (noised + pred * a) / (a + b)
|