File size: 581 Bytes
d6682b6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
import numpy as np
import torch
from typing import Union, List

class linear:
    """Linear interpolation (lerp) between two tensors.

    ``execute`` returns ``t * v1 + (1 - t) * v0``, so ``t == 0`` yields
    ``v0`` and ``t == 1`` yields ``v1``.
    """

    def __init__(self):
        pass

    @staticmethod
    def _unwrap(value):
        """Return ``value[0]`` if ``value`` is a list/tuple, else ``value`` itself."""
        # Callers may hand over single-element sequences; only the first
        # entry is used. An empty sequence raises IndexError.
        if isinstance(value, (list, tuple)):
            return value[0]
        return value

    def execute(
        self,
        t: Union[float, List[float]],
        v0: Union[List[torch.Tensor], torch.Tensor],
        v1: Union[List[torch.Tensor], torch.Tensor],
        DOT_THRESHOLD: float = 0.9995,
        eps: float = 1e-8,
        densities = None,
    ) -> torch.Tensor:
        """Linearly interpolate from ``v0`` to ``v1`` by factor ``t``.

        Args:
            t: Interpolation factor, or a list/tuple whose first element is
                used.
            v0: Start tensor, or a list/tuple whose first element is used.
            v1: End tensor, or a list/tuple whose first element is used.
            DOT_THRESHOLD: Unused here. It is accepted so this class shares a
                call signature with other interpolation methods (presumably
                e.g. a slerp implementation; verify against callers).
            eps: Unused here; accepted for signature compatibility.
            densities: Unused here; accepted for signature compatibility.

        Returns:
            ``t * v1 + (1.0 - t) * v0``.

        Raises:
            IndexError: If any argument is an empty list/tuple.
        """
        t = self._unwrap(t)
        v0 = self._unwrap(v0)
        v1 = self._unwrap(v1)
        return t * v1 + (1.0 - t) * v0