Spaces:
Build error
Build error
Merge branch 'master' into main
Browse files- .gitignore +3 -0
- app.py +35 -0
- pulsar_clip.py +222 -0
- requirements.txt +6 -0
- utils.py +71 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
.idea/
|
2 |
+
**/__pycache__/
|
3 |
+
flagged/
|
app.py
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pulsar_clip import PulsarCLIP, CONFIG_SPEC
|
2 |
+
from datetime import datetime
|
3 |
+
import gradio as gr
|
4 |
+
|
5 |
+
|
6 |
+
def generate(*args):
|
7 |
+
pc = PulsarCLIP(dict([(k, t(v) if not isinstance(t, (tuple, list)) else v)
|
8 |
+
for v, (k, v0, t) in zip(args, CONFIG_SPEC)]))
|
9 |
+
frames = []
|
10 |
+
for image in pc.generate():
|
11 |
+
frames.append(image)
|
12 |
+
from tqdm.auto import tqdm
|
13 |
+
from subprocess import Popen, PIPE
|
14 |
+
fps = 30
|
15 |
+
video_path = f"{datetime.strftime(datetime.now())}.mp4"
|
16 |
+
if frames:
|
17 |
+
p = Popen((f"ffmpeg -y -f image2pipe -vcodec png -r {fps} -i - -vcodec libx264 -r {fps} "
|
18 |
+
f"-pix_fmt yuv420p -crf 17 -preset fast ").split() + [str(video_path)], stdin=PIPE)
|
19 |
+
for im in tqdm(frames):
|
20 |
+
im.save(p.stdin, "PNG")
|
21 |
+
p.stdin.close()
|
22 |
+
p.wait()
|
23 |
+
return video_path
|
24 |
+
|
25 |
+
|
26 |
+
def main():
|
27 |
+
gr.Interface(inputs=[
|
28 |
+
(gr.inputs.Number(label=k, default=v0) if t in (float, int) else
|
29 |
+
gr.inputs.Checkbox(label=k, default=v0) if t == bool else gr.inputs.Textbox(label=k, default=v0) if t == str
|
30 |
+
else gr.inputs.Dropdown(label=k, default=v0, choices=t) if isinstance(t, (tuple, list)) else 1/0)
|
31 |
+
for k, v0, t in CONFIG_SPEC], outputs=gr.outputs.Video(), fn=generate).launch()
|
32 |
+
|
33 |
+
|
34 |
+
if __name__ == '__main__':
|
35 |
+
main()
|
pulsar_clip.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import set_seed
|
2 |
+
from tqdm.auto import trange
|
3 |
+
from PIL import Image
|
4 |
+
import numpy as np
|
5 |
+
import random
|
6 |
+
import utils
|
7 |
+
import torch
|
8 |
+
|
9 |
+
|
10 |
+
CONFIG_SPEC = [
|
11 |
+
("text", "A cloud at dawn", str),
|
12 |
+
("iterations", 5000, int),
|
13 |
+
("turns", 4, int),
|
14 |
+
("showoff", 5000, int),
|
15 |
+
("seed", 12, int),
|
16 |
+
("focal_length", 0.1, float),
|
17 |
+
("plane_width", 0.1, float),
|
18 |
+
("shade_strength", 0.25, float),
|
19 |
+
("gamma", 0.5, float),
|
20 |
+
("max_depth", 7, float),
|
21 |
+
("lr", 0.5, float),
|
22 |
+
("offset", 5, float),
|
23 |
+
("offset_random", 0.75, float),
|
24 |
+
("xyz_random", 0.25, float),
|
25 |
+
("altitude_range", 0.3, float),
|
26 |
+
("augments", 4, int),
|
27 |
+
("show_every", 50, int),
|
28 |
+
("epochs", 1, int),
|
29 |
+
("w", 224, int),
|
30 |
+
("h", 224, int),
|
31 |
+
("num_objects", 256, int),
|
32 |
+
#@markdown CLIP loss type, might improve the results
|
33 |
+
("loss_type", "spherical", ("spherical", "cosine")),
|
34 |
+
#@markdown CLIP loss weight
|
35 |
+
("clip_weight", 1.0, float), #@param {type: "number"}
|
36 |
+
#@markdown Number of dimensions. 0 is for point clouds (default), 1 will make
|
37 |
+
#@markdown strokes, 2 will make planes, 3 produces little cubes
|
38 |
+
("ndim", 0, (0, 1, 2, 3)), #@param {type: "integer"}
|
39 |
+
|
40 |
+
#@markdown Opacity scale:
|
41 |
+
("min_opacity", 1e-4, float), #@param {type: "number"}
|
42 |
+
("max_opacity", 1.0, float), #@param {type: "number"}
|
43 |
+
("log_opacity", False, bool), #@param {type: "boolean"}
|
44 |
+
|
45 |
+
("min_radius", 0.030, float),
|
46 |
+
("max_radius", 0.070, float),
|
47 |
+
("log_radius", False, bool),
|
48 |
+
|
49 |
+
# TODO dynamically decide bezier_res
|
50 |
+
#@markdown Bezier resolution: how many points a line/plane/cube will have. Not applicable to points
|
51 |
+
("bezier_res", 8, int), #@param {type: "integer"}
|
52 |
+
#@markdown Maximum scale of parameters: position, velocity, acceleration
|
53 |
+
("pos_scale", 0.4, float), #@param {type: "number"}
|
54 |
+
("vel_scale", 0.15, float), #@param {type: "number"}
|
55 |
+
("acc_scale", 0.15, float), #@param {type: "number"}
|
56 |
+
|
57 |
+
#@markdown Scale of each individual 3D object. Master control for velocity and acceleration scale.
|
58 |
+
("scale", 1, float), #@param {type: "number"}
|
59 |
+
]
|
60 |
+
|
61 |
+
|
62 |
+
# TODO: one day separate the config into multiple parts and split this megaobject into multiple objects
|
63 |
+
class PulsarCLIP(object):
|
64 |
+
def __init__(self, args):
|
65 |
+
args = DotDict(**args)
|
66 |
+
set_seed(args.seed)
|
67 |
+
self.args = args
|
68 |
+
self.device = args.get("device", "cuda" if torch.cuda.is_available() else "cpu")
|
69 |
+
# Defer the import so that we can import `pulsar_clip` and then install `pytorch3d`
|
70 |
+
import pytorch3d.renderer.points.pulsar as ps
|
71 |
+
self.ndim = int(self.args.ndim)
|
72 |
+
self.renderer = ps.Renderer(self.args.w, self.args.h,
|
73 |
+
self.args.num_objects * (self.args.bezier_res ** self.ndim)).to(self.device)
|
74 |
+
self.bezier_pos = torch.nn.Parameter(torch.randn((args.num_objects, 4)).to(self.device))
|
75 |
+
self.bezier_vel = torch.nn.Parameter(torch.randn((args.num_objects, 3 * self.ndim)).to(self.device))
|
76 |
+
self.bezier_acc = torch.nn.Parameter(torch.randn((args.num_objects, 3 * self.ndim)).to(self.device))
|
77 |
+
self.bezier_col = torch.nn.Parameter(torch.randn((args.num_objects, 4 * (1 + self.ndim))).to(self.device))
|
78 |
+
self.optimizer = torch.optim.Adam([dict(params=[self.bezier_col], lr=5e-1 * args.lr),
|
79 |
+
dict(params=[self.bezier_pos], lr=1e-1 * args.lr),
|
80 |
+
dict(params=[self.bezier_vel, self.bezier_acc], lr=5e-2 * args.lr),
|
81 |
+
])
|
82 |
+
self.model_clip, self.preprocess_clip = utils.load_clip()
|
83 |
+
self.model_clip.visual.requires_grad_(False)
|
84 |
+
self.scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(self.optimizer,
|
85 |
+
int(self.args.iterations
|
86 |
+
/ self.args.augments
|
87 |
+
/ self.args.epochs))
|
88 |
+
import clip
|
89 |
+
self.txt_emb = self.model_clip.encode_text(clip.tokenize([self.args.text]).to(self.device))[0].detach()
|
90 |
+
self.txt_emb = torch.nn.functional.normalize(self.txt_emb, dim=-1)
|
91 |
+
|
92 |
+
def get_points(self):
|
93 |
+
if self.ndim > 0:
|
94 |
+
bezier_ts = torch.stack(torch.meshgrid(
|
95 |
+
(torch.linspace(0, 1, self.args.bezier_res, device=self.device),) * self.ndim), dim=0
|
96 |
+
).unsqueeze(1).repeat((1, self.args.num_objects) + (1,) * self.ndim).unsqueeze(-1)
|
97 |
+
|
98 |
+
def interpolate_3D(pos, vel=0.0, acc=0.0, pos_scale=None, vel_scale=None, acc_scale=None, scale=None):
|
99 |
+
pos_scale = self.args.pos_scale if pos_scale is None else pos_scale
|
100 |
+
vel_scale = self.args.vel_scale if vel_scale is None else vel_scale
|
101 |
+
acc_scale = self.args.acc_scale if acc_scale is None else acc_scale
|
102 |
+
scale = self.args.scale if scale is None else scale
|
103 |
+
if self.ndim == 0:
|
104 |
+
return pos * pos_scale
|
105 |
+
result = 0.0
|
106 |
+
s = pos.shape[-1]
|
107 |
+
assert s * self.ndim == vel.shape[-1] == acc.shape[-1]
|
108 |
+
# O(dim) sequential lol
|
109 |
+
for d, bezier_t in zip(range(self.ndim), bezier_ts): # TODO replace with fused dimension operation
|
110 |
+
result = (result
|
111 |
+
+ torch.tanh(vel[..., d * s:(d + 1) * s]).view(
|
112 |
+
(-1,) + (1,) * self.ndim + (s,)) * vel_scale * bezier_t
|
113 |
+
+ torch.tanh(acc[..., d * s:(d + 1) * s]).view(
|
114 |
+
(-1,) + (1,) * self.ndim + (s,)) * acc_scale * bezier_t.pow(2))
|
115 |
+
result = (result * scale
|
116 |
+
+ torch.tanh(pos[..., :s]).view((-1,) + (1,) * self.ndim + (s,)) * pos_scale).view(-1, s)
|
117 |
+
return result
|
118 |
+
|
119 |
+
vert_pos = interpolate_3D(self.bezier_pos[..., :3], self.bezier_vel, self.bezier_acc)
|
120 |
+
vert_col = interpolate_3D(self.bezier_col[..., :4],
|
121 |
+
self.bezier_col[..., 4:4 + 4 * self.ndim],
|
122 |
+
self.bezier_col[..., -4 * self.ndim:])
|
123 |
+
|
124 |
+
to_bezier = lambda x: x.view((-1,) + (1,) * self.ndim + (x.shape[-1],)).repeat(
|
125 |
+
(1,) + (self.args.bezier_res,) * self.ndim + (1,)).reshape(-1, x.shape[-1])
|
126 |
+
rescale = lambda x, a, b, is_log=False: (torch.exp(x
|
127 |
+
* np.log(b / a)
|
128 |
+
+ np.log(a))) if is_log else x * (b - a) + a
|
129 |
+
return (
|
130 |
+
vert_pos,
|
131 |
+
torch.sigmoid(vert_col[..., :3]),
|
132 |
+
rescale(
|
133 |
+
torch.sigmoid(to_bezier(self.bezier_pos[..., -1:])[..., 0]),
|
134 |
+
self.args.min_radius, self.args.max_radius, is_log=self.args.log_radius
|
135 |
+
),
|
136 |
+
rescale(torch.sigmoid(vert_col[..., -1]),
|
137 |
+
self.args.min_opacity, self.args.max_opacity, is_log=self.args.log_opacity))
|
138 |
+
|
139 |
+
def camera(self, angle, altitude=0.0, offset=None, use_random=True, offset_random=None,
|
140 |
+
xyz_random=None, focal_length=None, plane_width=None):
|
141 |
+
if offset is None:
|
142 |
+
offset = self.args.offset
|
143 |
+
if xyz_random is None:
|
144 |
+
xyz_random = self.args.xyz_random
|
145 |
+
if focal_length is None:
|
146 |
+
focal_length = self.args.focal_length
|
147 |
+
if plane_width is None:
|
148 |
+
plane_width = self.args.plane_width
|
149 |
+
if offset_random is None:
|
150 |
+
offset_random = self.args.offset_random
|
151 |
+
device = self.device
|
152 |
+
offset = offset + np.random.normal() * offset_random * int(use_random)
|
153 |
+
position = torch.tensor([0, 0, -offset], dtype=torch.float)
|
154 |
+
position = utils.rotate_axis(position, altitude, 0)
|
155 |
+
position = utils.rotate_axis(position, angle, 1)
|
156 |
+
position = position + torch.randn(3) * xyz_random * int(use_random)
|
157 |
+
return torch.tensor([position[0], position[1], position[2],
|
158 |
+
altitude, angle, 0,
|
159 |
+
focal_length, plane_width], dtype=torch.float, device=device)
|
160 |
+
|
161 |
+
|
162 |
+
def render(self, cam_params=None):
|
163 |
+
if cam_params is None:
|
164 |
+
cam_params = self.camera(0, 0)
|
165 |
+
vert_pos, vert_col, radius, opacity = self.get_points()
|
166 |
+
|
167 |
+
rgb = self.renderer(vert_pos, vert_col, radius, cam_params,
|
168 |
+
self.args.gamma, self.args.max_depth, opacity=opacity)
|
169 |
+
opacity = self.renderer(vert_pos, vert_col * 0, radius, cam_params,
|
170 |
+
self.args.gamma, self.args.max_depth, opacity=opacity)
|
171 |
+
return rgb, opacity
|
172 |
+
|
173 |
+
def random_view_render(self):
|
174 |
+
angle = random.uniform(0, np.pi * 2)
|
175 |
+
altitude = random.uniform(-self.args.altitude_range / 2, self.args.altitude_range / 2)
|
176 |
+
cam_params = self.camera(angle, altitude)
|
177 |
+
result, alpha = self.render(cam_params)
|
178 |
+
back = torch.zeros_like(result)
|
179 |
+
s = back.shape
|
180 |
+
for j in range(s[-1]):
|
181 |
+
n = random.choice([7, 14, 28])
|
182 |
+
back[..., j] = utils.rand_perlin_2d_octaves(s[:-1], (n, n)).clip(-0.5, 0.5) + 0.5
|
183 |
+
result = result * (1 - alpha) + back * alpha
|
184 |
+
return result
|
185 |
+
|
186 |
+
|
187 |
+
def generate(self):
|
188 |
+
self.optimizer.zero_grad()
|
189 |
+
try:
|
190 |
+
for i in trange(self.args.iterations + self.args.showoff):
|
191 |
+
if i < self.args.iterations:
|
192 |
+
result = self.random_view_render()
|
193 |
+
img_emb = self.model_clip.encode_image(
|
194 |
+
self.preprocess_clip(result.permute(2, 0, 1)).unsqueeze(0).clamp(0., 1.))
|
195 |
+
img_emb = torch.nn.functional.normalize(img_emb, dim=-1)
|
196 |
+
if self.args.loss_type == "spherical":
|
197 |
+
clip_loss = (img_emb - self.txt_emb).norm(dim=-1).div(2).arcsin().pow(2).mul(2).mean()
|
198 |
+
elif self.args.loss_type == "cosine":
|
199 |
+
clip_loss = (1 - img_emb @ self.txt_emb.T).mean()
|
200 |
+
else:
|
201 |
+
raise NotImplementedError(f"CLIP loss type not supported: {self.args.loss_type}")
|
202 |
+
loss = clip_loss * self.args.clip_weight + (0 and ...) # TODO add more loss types
|
203 |
+
loss.backward()
|
204 |
+
if i % self.args.augments == self.args.augments - 1:
|
205 |
+
self.optimizer.step()
|
206 |
+
self.optimizer.zero_grad()
|
207 |
+
try:
|
208 |
+
self.scheduler.step()
|
209 |
+
except AttributeError:
|
210 |
+
pass
|
211 |
+
if i % self.args.show_every == 0:
|
212 |
+
cam_params = self.camera(i / self.args.iterations * np.pi * 2 * self.args.turns, use_random=False)
|
213 |
+
img_show, _ = self.render(cam_params)
|
214 |
+
img = Image.fromarray((img_show.cpu().detach().numpy() * 255).astype(np.uint8))
|
215 |
+
yield img
|
216 |
+
except KeyboardInterrupt:
|
217 |
+
pass
|
218 |
+
|
219 |
+
|
220 |
+
class DotDict(dict):
|
221 |
+
def __getattr__(self, item):
|
222 |
+
return self.__getitem__(item)
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
pytorch3d==0.6.2
|
2 |
+
transformers==4.10.3
|
3 |
+
torch==1.11.0+cu113
|
4 |
+
torchvision==0.12.0+cu113
|
5 |
+
clip
|
6 |
+
gradio
|
utils.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
import torch
|
3 |
+
import math
|
4 |
+
|
5 |
+
|
6 |
+
def rotate_axis(x, add_angle=0, axis=1): # TODO Replace with a rotation matrix # But this is more fun
|
7 |
+
axes = list(range(3))
|
8 |
+
axes.remove(axis)
|
9 |
+
ax1, ax2 = axes
|
10 |
+
angle = torch.atan2(x[..., ax1], x[..., ax2])
|
11 |
+
if isinstance(add_angle, torch.Tensor):
|
12 |
+
while add_angle.ndim < angle.ndim:
|
13 |
+
add_angle = add_angle.unsqueeze(-1)
|
14 |
+
angle = angle + add_angle
|
15 |
+
dist = x.norm(dim=-1)
|
16 |
+
t = []
|
17 |
+
_, t = zip(*sorted([
|
18 |
+
(axis, x[..., axis]),
|
19 |
+
(ax1, torch.sin(angle) * dist),
|
20 |
+
(ax2, torch.cos(angle) * dist),
|
21 |
+
]))
|
22 |
+
return torch.stack(t, dim=-1)
|
23 |
+
|
24 |
+
|
25 |
+
noise_level = 0.5
|
26 |
+
|
27 |
+
|
28 |
+
# stolen from https://gist.github.com/ac1b097753f217c5c11bc2ff396e0a57
|
29 |
+
# ported from https://github.com/pvigier/perlin-numpy/blob/master/perlin2d.py
|
30 |
+
def rand_perlin_2d(shape, res, fade=lambda t: 6 * t ** 5 - 15 * t ** 4 + 10 * t ** 3):
|
31 |
+
delta = (res[0] / shape[0], res[1] / shape[1])
|
32 |
+
d = (shape[0] // res[0], shape[1] // res[1])
|
33 |
+
|
34 |
+
grid = torch.stack(torch.meshgrid(torch.arange(0, res[0], delta[0]), torch.arange(0, res[1], delta[1])), dim=-1) % 1
|
35 |
+
angles = 2 * math.pi * torch.rand(res[0] + 1, res[1] + 1)
|
36 |
+
gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)
|
37 |
+
|
38 |
+
tile_grads = lambda slice1, slice2: gradients[slice1[0]:slice1[1], slice2[0]:slice2[1]].repeat_interleave(d[0],
|
39 |
+
0).repeat_interleave(
|
40 |
+
d[1], 1)
|
41 |
+
dot = lambda grad, shift: (
|
42 |
+
torch.stack((grid[:shape[0], :shape[1], 0] + shift[0], grid[:shape[0], :shape[1], 1] + shift[1]),
|
43 |
+
dim=-1) * grad[:shape[0], :shape[1]]).sum(dim=-1)
|
44 |
+
|
45 |
+
n00 = dot(tile_grads([0, -1], [0, -1]), [0, 0])
|
46 |
+
n10 = dot(tile_grads([1, None], [0, -1]), [-1, 0])
|
47 |
+
n01 = dot(tile_grads([0, -1], [1, None]), [0, -1])
|
48 |
+
n11 = dot(tile_grads([1, None], [1, None]), [-1, -1])
|
49 |
+
t = fade(grid[:shape[0], :shape[1]])
|
50 |
+
return math.sqrt(2) * torch.lerp(torch.lerp(n00, n10, t[..., 0]), torch.lerp(n01, n11, t[..., 0]), t[..., 1])
|
51 |
+
|
52 |
+
|
53 |
+
def rand_perlin_2d_octaves(shape, res, octaves=1, persistence=0.5):
|
54 |
+
noise = torch.zeros(shape)
|
55 |
+
frequency = 1
|
56 |
+
amplitude = 1
|
57 |
+
for _ in range(octaves):
|
58 |
+
noise += amplitude * rand_perlin_2d(shape, (frequency * res[0], frequency * res[1]))
|
59 |
+
frequency *= 2
|
60 |
+
amplitude *= persistence
|
61 |
+
noise *= random.random() - noise_level # haha
|
62 |
+
noise += random.random() - noise_level # haha x2
|
63 |
+
return noise
|
64 |
+
|
65 |
+
|
66 |
+
def load_clip(model_name="ViT-B/16", device="cuda:0" if torch.cuda.is_available() else "cpu"):
|
67 |
+
import clip
|
68 |
+
model, preprocess = clip.load(model_name, device=device, jit=False)
|
69 |
+
if len(preprocess.transforms) > 4:
|
70 |
+
preprocess.transforms = preprocess.transforms[-1:]
|
71 |
+
return model, preprocess
|