# Spaces: Running on Zero
from .transport import Transport, ModelType, WeightType, PathType, SNRType, Sampler
def create_transport(
    path_type='Linear',
    prediction="velocity",
    loss_weight=None,
    train_eps=None,
    sample_eps=None,
    snr_type="uniform",
):
    """Create a ``Transport`` object from string-valued configuration options.

    **Note**: the model prediction defaults to velocity.

    Args:
        path_type: interpolation path name, one of ``"Linear"`` (default),
            ``"GVP"`` or ``"VP"``.
        prediction: what the model predicts: ``"noise"``, ``"score"``; any
            other value selects velocity prediction (the default).
        loss_weight: ``"velocity"`` or ``"likelihood"`` to weight the loss by
            the corresponding weight; any other value (including ``None``)
            means no loss weighting.
        train_eps: small epsilon for avoiding instability during training.
        sample_eps: small epsilon for avoiding instability during sampling.
        snr_type: ``"uniform"`` (default) or ``"lognorm"``.

    Epsilon handling:
        - VP path: a ``None`` ``train_eps`` defaults to 1e-5 and a ``None``
          ``sample_eps`` defaults to 1e-3.
        - GVP / Linear path with a non-velocity prediction: a ``None``
          ``train_eps`` or ``sample_eps`` defaults to 1e-3.
        - GVP / Linear path with velocity prediction: both epsilons are
          forced to 0, overriding any values passed in.

    Returns:
        The configured ``Transport`` instance.

    Raises:
        ValueError: if ``snr_type`` is neither ``"lognorm"`` nor ``"uniform"``.
        KeyError: if ``path_type`` is not one of ``"Linear"``, ``"GVP"``,
            ``"VP"``.
    """
    if prediction == "noise":
        model_type = ModelType.NOISE
    elif prediction == "score":
        model_type = ModelType.SCORE
    else:
        model_type = ModelType.VELOCITY

    if loss_weight == "velocity":
        loss_type = WeightType.VELOCITY
    elif loss_weight == "likelihood":
        loss_type = WeightType.LIKELIHOOD
    else:
        loss_type = WeightType.NONE

    if snr_type == "lognorm":
        snr_type = SNRType.LOGNORM
    elif snr_type == "uniform":
        snr_type = SNRType.UNIFORM
    else:
        raise ValueError(f"Invalid snr type {snr_type}")

    path_choice = {
        "Linear": PathType.LINEAR,
        "GVP": PathType.GVP,
        "VP": PathType.VP,
    }
    # Unknown names raise KeyError (kept for backward compatibility).
    path_type = path_choice[path_type]

    if path_type == PathType.VP:
        train_eps = 1e-5 if train_eps is None else train_eps
        # Default sample_eps based on *its own* value; testing train_eps here
        # (which was just made non-None) would never apply the default.
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    elif path_type in (PathType.GVP, PathType.LINEAR) and model_type != ModelType.VELOCITY:
        train_eps = 1e-3 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    else:  # velocity & [GVP, LINEAR] is stable everywhere
        train_eps = 0
        sample_eps = 0

    # create flow state
    state = Transport(
        model_type=model_type,
        path_type=path_type,
        loss_type=loss_type,
        train_eps=train_eps,
        sample_eps=sample_eps,
        snr_type=snr_type,
    )
    return state