kiankaydee commited on
Commit
13c7b4e
·
1 Parent(s): e5e2b12

some baseline code for the repo

Browse files
Files changed (3) hide show
  1. .vscode/settings.json +5 -0
  2. requirements.txt +212 -0
  3. vit_encoder.py +60 -0
.vscode/settings.json ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ {
2
+ "flake8.args": [
3
+ "--max-line-length=120"
4
+ ]
5
+ }
requirements.txt ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.24.1
2
+ aiohttp==3.8.6
3
+ aiohttp-retry==2.8.3
4
+ aiosignal==1.3.1
5
+ albumentations==1.3.0
6
+ amqp==5.1.1
7
+ analytics-python==1.4.post1
8
+ antlr4-python3-runtime==4.9.3
9
+ appdirs==1.4.4
10
+ argcomplete==3.1.2
11
+ asciitree==0.3.3
12
+ async-generator==1.10
13
+ async-timeout==4.0.3
14
+ asyncssh==2.14.0
15
+ atpublic==4.0
16
+ attrs==23.1.0
17
+ azure-core==1.29.4
18
+ azure-storage-blob==12.18.3
19
+ backoff==1.10.0
20
+ bcrypt==4.0.1
21
+ billiard==4.1.0
22
+ boto3==1.28.64
23
+ botocore==1.31.64
24
+ build==1.0.3
25
+ cachetools==5.3.1
26
+ celery==5.3.4
27
+ certifi==2023.7.22
28
+ cffi==1.16.0
29
+ charset-normalizer==3.3.0
30
+ click==8.1.7
31
+ click-didyoumean==0.3.0
32
+ click-plugins==1.1.1
33
+ click-repl==0.3.0
34
+ cloudpickle==2.2.1
35
+ colorama==0.4.6
36
+ configobj==5.0.8
37
+ contourpy==1.1.1
38
+ crc32c==2.3.post0
39
+ cryptography==41.0.4
40
+ cycler==0.12.1
41
+ Cython==0.29.36
42
+ databricks-cli==0.18.0
43
+ db-dtypes==1.1.1
44
+ decorator==5.1.1
45
+ determined==0.23.3
46
+ dictdiffer==0.9.0
47
+ diskcache==5.6.3
48
+ distro==1.8.0
49
+ docker==6.1.3
50
+ dpath==2.1.6
51
+ dulwich==0.21.6
52
+ dvc==3.26.2
53
+ dvc-data==2.18.1
54
+ dvc-gs==2.22.1
55
+ dvc-http==2.30.2
56
+ dvc-objects==1.0.1
57
+ dvc-render==0.6.0
58
+ dvc-studio-client==0.15.0
59
+ dvc-task==0.3.0
60
+ entrypoints==0.4
61
+ fasteners==0.19
62
+ filelock==3.12.4
63
+ flatten-dict==0.4.2
64
+ flufl.lock==7.1.1
65
+ fonttools==4.43.1
66
+ frozenlist==1.4.0
67
+ fsspec==2023.9.2
68
+ funcy==2.0
69
+ gcsfs==2023.9.2
70
+ gitdb==4.0.10
71
+ GitPython==3.1.38
72
+ google-api-core==2.12.0
73
+ google-api-python-client==2.103.0
74
+ google-auth==2.23.3
75
+ google-auth-httplib2==0.1.1
76
+ google-auth-oauthlib==1.0.0
77
+ google-cloud==0.34.0
78
+ google-cloud-bigquery==3.12.0
79
+ google-cloud-core==2.3.3
80
+ google-cloud-storage==2.12.0
81
+ google-crc32c==1.5.0
82
+ google-resumable-media==2.6.0
83
+ googleapis-common-protos==1.61.0
84
+ grandalf==0.8
85
+ grpcio==1.59.0
86
+ grpcio-status==1.48.2
87
+ gto==1.4.0
88
+ httplib2==0.22.0
89
+ huggingface-hub==0.18.0
90
+ hydra-core==1.3.2
91
+ idna==3.4
92
+ importlib-metadata==6.8.0
93
+ isodate==0.6.1
94
+ iterative-telemetry==0.0.8
95
+ Jinja2==3.1.2
96
+ jmespath==1.0.1
97
+ joblib==1.3.2
98
+ kiwisolver==1.4.5
99
+ kombu==5.3.2
100
+ lightning-fabric==2.1.0
101
+ lightning-utilities==0.9.0
102
+ lmdb==1.4.1
103
+ lomond==0.3.3
104
+ mahotas==1.4.13
105
+ markdown-it-py==3.0.0
106
+ MarkupSafe==2.1.3
107
+ matplotlib==3.8.0
108
+ mdurl==0.1.2
109
+ mlflow-skinny==2.6.0
110
+ monotonic==1.6
111
+ mpmath==0.19
112
+ multidict==6.0.4
113
+ networkx==3.1
114
+ numcodecs==0.12.0
115
+ numpy==1.26.1
116
+ nvidia-cublas-cu12==12.1.3.1
117
+ nvidia-cuda-cupti-cu12==12.1.105
118
+ nvidia-cuda-nvrtc-cu12==12.1.105
119
+ nvidia-cuda-runtime-cu12==12.1.105
120
+ nvidia-cudnn-cu12==8.9.2.26
121
+ nvidia-cufft-cu12==11.0.2.54
122
+ nvidia-curand-cu12==10.3.2.106
123
+ nvidia-cusolver-cu12==11.4.5.107
124
+ nvidia-cusparse-cu12==12.1.0.106
125
+ nvidia-nccl-cu12==2.18.1
126
+ nvidia-nvjitlink-cu12==12.2.140
127
+ nvidia-nvtx-cu12==12.1.105
128
+ oauthlib==3.2.2
129
+ omegaconf==2.3.0
130
+ orjson==3.9.9
131
+ packaging==21.3
132
+ pandas==2.1.1
133
+ paramiko==3.3.1
134
+ pathspec==0.11.2
135
+ peft==0.5.0
136
+ Pillow==10.1.0
137
+ pip-tools==7.3.0
138
+ platformdirs==3.11.0
139
+ prompt-toolkit==3.0.39
140
+ proto-plus==1.22.3
141
+ protobuf==3.20.3
142
+ psutil==5.9.6
143
+ pyarrow==13.0.0
144
+ pyasn1==0.5.0
145
+ pyasn1-modules==0.3.0
146
+ pycparser==2.21
147
+ pydantic==1.10.13
148
+ pydot==1.4.2
149
+ pygit2==1.13.1
150
+ Pygments==2.16.1
151
+ pygtrie==2.5.0
152
+ PyJWT==2.8.0
153
+ PyNaCl==1.5.0
154
+ pyOpenSSL==23.2.0
155
+ pyparsing==3.0.9
156
+ pyproject_hooks==1.0.0
157
+ python-dateutil==2.8.2
158
+ pytorch-lightning==2.1.0
159
+ pytz==2023.3.post1
160
+ PyYAML==6.0.1
161
+ pyzmq==25.1.1
162
+ regex==2023.10.3
163
+ requests==2.31.0
164
+ requests-oauthlib==1.3.1
165
+ rich==13.6.0
166
+ rsa==4.9
167
+ ruamel.yaml==0.17.35
168
+ ruamel.yaml.clib==0.2.8
169
+ s3transfer==0.7.0
170
+ safetensors==0.4.0
171
+ scikit-learn==1.3.1
172
+ scipy==1.11.3
173
+ scmrepo==1.4.0
174
+ semver==3.0.2
175
+ shortuuid==1.0.11
176
+ shtab==1.6.4
177
+ six==1.16.0
178
+ smmap==5.0.1
179
+ sqlparse==0.4.4
180
+ sqltrie==0.8.0
181
+ sympy==1.12
182
+ tabulate==0.9.0
183
+ tensorboardX==2.6.2.2
184
+ termcolor==2.3.0
185
+ threadpoolctl==3.2.0
186
+ timm==0.9.7
187
+ tokenizers==0.15.0
188
+ tomli==2.0.1
189
+ tomlkit==0.12.1
190
+ toolz==0.12.0
191
+ torch==2.1.0+cu121
192
+ torchmetrics==1.2.0
193
+ torchvision==0.16.0+cu121
194
+ tqdm==4.66.1
195
+ transformers==4.35.2
196
+ triton==2.1.0
197
+ typer==0.9.0
198
+ typing_extensions==4.8.0
199
+ tzdata==2023.3
200
+ uritemplate==4.1.1
201
+ urllib3==1.26.17
202
+ vine==5.0.0
203
+ voluptuous==0.13.1
204
+ wcwidth==0.2.8
205
+ websocket-client==1.6.4
206
+ websockets==11.0.3
207
+ xformers==0.0.22.post7
208
+ yarl==1.9.2
209
+ yogadl==0.1.4
210
+ zarr==2.16.1
211
+ zc.lockfile==3.0.post1
212
+ zipp==3.17.0
vit_encoder.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Dict
2
+
3
+ import timm.models.vision_transformer as vit
4
+ import torch
5
+
6
+
7
+ def build_imagenet_baselines() -> Dict[str, torch.jit.ScriptModule]:
8
+ """This returns the prepped imagenet encoders from timm, not bad for microscopy data."""
9
+ vit_backbones = [
10
+ _make_vit(vit.vit_small_patch16_384),
11
+ _make_vit(vit.vit_base_patch16_384),
12
+ _make_vit(vit.vit_base_patch8_224),
13
+ _make_vit(vit.vit_large_patch16_384),
14
+ ]
15
+ model_names = [
16
+ "vit_small_patch16_384",
17
+ "vit_base_patch16_384",
18
+ "vit_base_patch8_224",
19
+ "vit_large_patch16_384",
20
+ ]
21
+ imagenet_encoders = list(map(_make_torchscripted_encoder, vit_backbones))
22
+ return {name: model for name, model in zip(model_names, imagenet_encoders)}
23
+
24
+
25
+ def _make_torchscripted_encoder(vit_backbone) -> torch.jit.ScriptModule:
26
+ dummy_input = torch.testing.make_tensor(
27
+ (2, 6, 256, 256),
28
+ low=0,
29
+ high=255,
30
+ dtype=torch.uint8,
31
+ device=torch.device("cpu"),
32
+ )
33
+ encoder = torch.nn.Sequential(
34
+ Normalizer(),
35
+ torch.nn.LazyInstanceNorm2d(
36
+ affine=False, track_running_stats=False
37
+ ), # this module performs self-standardization, very important
38
+ vit_backbone,
39
+ ).to(device="cpu")
40
+ _ = encoder(dummy_input) # get those lazy modules built
41
+ return torch.jit.freeze(torch.jit.script(encoder.eval()))
42
+
43
+
44
+ def _make_vit(constructor):
45
+ return constructor(
46
+ pretrained=True, # download imagenet weights
47
+ img_size=256, # 256x256 crops
48
+ in_chans=6, # we expect 6-channel microscopy images
49
+ num_classes=0,
50
+ fc_norm=None,
51
+ class_token=True,
52
+ global_pool="avg", # minimal perf diff btwn "cls" and "avg"
53
+ )
54
+
55
+
56
+ class Normalizer(torch.nn.Module):
57
+ def forward(self, pixels: torch.Tensor) -> torch.Tensor:
58
+ pixels = pixels.float()
59
+ pixels /= 255.0
60
+ return pixels