File size: 581 Bytes
d6682b6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
import numpy as np
import torch
from typing import Union, List
class linear:
    """Linear interpolation between two values.

    ``execute`` computes ``t * v1 + (1.0 - t) * v0``. Each of ``t``, ``v0``
    and ``v1`` may be given bare or wrapped in a list/tuple, in which case
    only its first element is used.
    """

    def __init__(self):
        pass

    def execute(
        self,
        t: Union[float, List[float]],
        v0: Union[List[torch.Tensor], torch.Tensor],
        v1: Union[List[torch.Tensor], torch.Tensor],
        DOT_THRESHOLD: float = 0.9995,
        eps: float = 1e-8,
        densities=None,
    ):
        """Return the linear blend of ``v0`` and ``v1`` at weight ``t``.

        Args:
            t: Blend weight applied to ``v1``; ``v0`` is weighted by
                ``1.0 - t``. If a list/tuple, only ``t[0]`` is used.
            v0: Value weighted by ``1.0 - t``. If a list/tuple, only
                ``v0[0]`` is used.
            v1: Value weighted by ``t``. If a list/tuple, only ``v1[0]``
                is used.
            DOT_THRESHOLD: Not used by this method.
            eps: Not used by this method.
            densities: Not used by this method.
            (NOTE(review): the three unused parameters are presumably kept
            so the signature matches sibling interpolation classes —
            confirm against callers.)

        Returns:
            ``t * v1 + (1.0 - t) * v0``.

        Raises:
            IndexError: If any list/tuple argument is empty.
        """
        # isinstance (rather than an exact ``type(x) is list`` check) also
        # unwraps list subclasses and tuples, which would otherwise reach
        # the arithmetic below and fail with a TypeError.
        if isinstance(v0, (list, tuple)):
            v0 = v0[0]
        if isinstance(t, (list, tuple)):
            t = t[0]
        if isinstance(v1, (list, tuple)):
            v1 = v1[0]
        return t * v1 + (1.0 - t) * v0