Spaces:
Running
on
Zero
Running
on
Zero
File size: 2,267 Bytes
11e6f7b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 |
from .transport import Transport, ModelType, WeightType, PathType, Sampler, SNRType
def create_transport(
    path_type='Linear',
    prediction="velocity",
    loss_weight=None,
    train_eps=None,
    sample_eps=None,
    snr_type='uniform',
):
    """Build a Transport object from string-valued configuration options.

    **Note**: model prediction defaults to velocity.

    Args:
        path_type: type of path to use; one of "Linear", "GVP" or "VP".
            Defaults to "Linear".
        prediction: what the model predicts; "noise" or "score" select the
            corresponding model type, any other value selects velocity.
        loss_weight: "velocity" or "likelihood" select the corresponding
            loss weighting, any other value (including None) selects no
            weighting.
        train_eps: small epsilon for avoiding instability during training.
            When None, a default is chosen from the path/prediction
            combination (see below).
        sample_eps: small epsilon for avoiding instability during sampling.
            When None, a default is chosen from the path/prediction
            combination (see below).
        snr_type: "uniform" or "lognorm".

    Epsilon defaults:
        - VP path: train_eps=1e-5, sample_eps=1e-3.
        - GVP/Linear path with a non-velocity prediction: both 1e-3.
        - GVP/Linear path with velocity prediction: both are forced to 0,
          overriding any value passed by the caller.

    Returns:
        The configured Transport instance.

    Raises:
        ValueError: if snr_type is not "uniform" or "lognorm".
        KeyError: if path_type is not one of "Linear", "GVP", "VP".
    """
    if prediction == "noise":
        model_type = ModelType.NOISE
    elif prediction == "score":
        model_type = ModelType.SCORE
    else:
        model_type = ModelType.VELOCITY

    if loss_weight == "velocity":
        loss_type = WeightType.VELOCITY
    elif loss_weight == "likelihood":
        loss_type = WeightType.LIKELIHOOD
    else:
        loss_type = WeightType.NONE

    if snr_type == "lognorm":
        snr_type = SNRType.LOGNORM
    elif snr_type == "uniform":
        snr_type = SNRType.UNIFORM
    else:
        raise ValueError(f"Invalid snr type {snr_type}")

    path_choice = {
        "Linear": PathType.LINEAR,
        "GVP": PathType.GVP,
        "VP": PathType.VP,
    }
    path_type = path_choice[path_type]

    # Each epsilon falls back to its own default independently. The previous
    # code tested `train_eps is None` when defaulting sample_eps, which was
    # always False (train_eps had just been filled in), so an omitted
    # sample_eps was forwarded as None.
    if path_type in (PathType.VP,):
        train_eps = 1e-5 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    elif path_type in (PathType.GVP, PathType.LINEAR) and model_type != ModelType.VELOCITY:
        train_eps = 1e-3 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    else:  # velocity & [GVP, LINEAR] is stable everywhere
        train_eps = 0
        sample_eps = 0

    # create flow state
    state = Transport(
        model_type=model_type,
        path_type=path_type,
        loss_type=loss_type,
        train_eps=train_eps,
        sample_eps=sample_eps,
        snr_type=snr_type,
    )
    return state