jacklangerman commited on
Commit
e99b13a
1 Parent(s): f5a979c

Create hoho.py

Browse files
Files changed (1) hide show
  1. hoho.py +261 -0
hoho.py ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import shutil
4
+ from pathlib import Path
5
+ from typing import Dict
6
+
7
+ from PIL import ImageFile
8
+ ImageFile.LOAD_TRUNCATED_IMAGES = True
9
+
10
+ LOCAL_DATADIR = None
11
+
12
+ def setup(local_dir='./data/usm-training-data/data'):
13
+
14
+ # If we are in the test environment, we need to link the data directory to the correct location
15
+ tmp_datadir = Path('/tmp/data/data')
16
+ local_test_datadir = Path('./data/usm-test-data-x/data')
17
+ local_val_datadir = Path(local_dir)
18
+
19
+ os.system('pwd')
20
+ os.system('ls -lahtr .')
21
+
22
+ if tmp_datadir.exists() and not local_test_datadir.exists():
23
+ global LOCAL_DATADIR
24
+ LOCAL_DATADIR = local_test_datadir
25
+ # shutil.move(datadir, './usm-test-data-x/data')
26
+ print(f"Linking {tmp_datadir} to {LOCAL_DATADIR} (we are in the test environment)")
27
+ LOCAL_DATADIR.parent.mkdir(parents=True, exist_ok=True)
28
+ LOCAL_DATADIR.symlink_to(tmp_datadir)
29
+ else:
30
+ LOCAL_DATADIR = local_val_datadir
31
+ print(f"Using {LOCAL_DATADIR} as the data directory (we are running locally)")
32
+
33
+ # os.system("ls -lahtr")
34
+
35
+ assert LOCAL_DATADIR.exists(), f"Data directory {LOCAL_DATADIR} does not exist"
36
+ return LOCAL_DATADIR
37
+
38
+
39
+
40
+
41
+ import importlib
42
+ from pathlib import Path
43
+ import subprocess
44
+
45
+ def download_package(package_name, path_to_save='packages'):
46
+ """
47
+ Downloads a package using pip and saves it to a specified directory.
48
+
49
+ Parameters:
50
+ package_name (str): The name of the package to download.
51
+ path_to_save (str): The path to the directory where the package will be saved.
52
+ """
53
+ try:
54
+ # pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
55
+ subprocess.check_call([subprocess.sys.executable, "-m", "pip", "download", package_name,
56
+ "-d", str(Path(path_to_save)/package_name), # Download the package to the specified directory
57
+ "--platform", "manylinux1_x86_64", # Specify the platform
58
+ "--python-version", "38", # Specify the Python version
59
+ "--only-binary=:all:"]) # Download only binary packages
60
+ print(f'Package "{package_name}" downloaded successfully')
61
+ except subprocess.CalledProcessError as e:
62
+ print(f'Failed to downloaded package "{package_name}". Error: {e}')
63
+
64
+
65
+ def install_package_from_local_file(package_name, folder='packages'):
66
+ """
67
+ Installs a package from a local .whl file or a directory containing .whl files using pip.
68
+
69
+ Parameters:
70
+ path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
71
+ """
72
+ try:
73
+ pth = str(Path(folder) / package_name)
74
+ subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
75
+ "--no-index", # Do not use package index
76
+ "--find-links", pth, # Look for packages in the specified directory or at the file
77
+ package_name]) # Specify the package to install
78
+ print(f"Package installed successfully from {pth}")
79
+ except subprocess.CalledProcessError as e:
80
+ print(f"Failed to install package from {pth}. Error: {e}")
81
+
82
+
83
+ def importt(module_name, as_name=None):
84
+ """
85
+ Imports a module and returns it.
86
+
87
+ Parameters:
88
+ module_name (str): The name of the module to import.
89
+ as_name (str): The name to use for the imported module. If None, the original module name will be used.
90
+
91
+ Returns:
92
+ The imported module.
93
+ """
94
+ for _ in range(2):
95
+ try:
96
+ if as_name is None:
97
+ print(f'imported {module_name}')
98
+ return importlib.import_module(module_name)
99
+ else:
100
+ print(f'imported {module_name} as {as_name}')
101
+ return importlib.import_module(module_name, as_name)
102
+ except ModuleNotFoundError as e:
103
+ install_package_from_local_file(module_name)
104
+ print(f"Failed to import module {module_name}. Error: {e}")
105
+
106
+
107
+ def prepare_submission():
108
+ # Download packages from requirements.txt
109
+ if Path('requirements.txt').exists():
110
+ print('downloading packages from requirements.txt')
111
+ Path('packages').mkdir(exist_ok=True)
112
+ with open('requirements.txt') as f:
113
+ packages = f.readlines()
114
+ for p in packages:
115
+ download_package(p.strip())
116
+
117
+
118
+ print('all packages downloaded. Don\'t foget to include the packages in the submission by adding them with git lfs.')
119
+
120
+
121
+ def Rt_to_eye_target(im, K, R, t):
122
+ height = im.height
123
+ focal_length = K[0,0]
124
+ fov = 2.0 * np.arctan2((0.5 * height), focal_length) / (np.pi / 180.0)
125
+
126
+ x_axis, y_axis, z_axis = R
127
+
128
+ eye = -(R.T @ t).squeeze()
129
+ z_axis = z_axis.squeeze()
130
+ target = eye + z_axis
131
+ up = -y_axis
132
+
133
+ return eye, target, up, fov
134
+
135
+
136
+ ########## general utilities ##########
137
+ import contextlib
138
+ import tempfile
139
+ from pathlib import Path
140
+
141
+ @contextlib.contextmanager
142
+ def working_directory(path):
143
+ """Changes working directory and returns to previous on exit."""
144
+ prev_cwd = Path.cwd()
145
+ os.chdir(path)
146
+ try:
147
+ yield
148
+ finally:
149
+ os.chdir(prev_cwd)
150
+
151
+ @contextlib.contextmanager
152
+ def temp_working_directory():
153
+ with tempfile.TemporaryDirectory(dir='.') as D:
154
+ with working_directory(D):
155
+ yield
156
+
157
+
158
+ ############# Dataset #############
159
+ def proc(row, split='train'):
160
+ # column_names_train = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'mesh', 'wireframe']
161
+ # column_names_test = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'wireframe']
162
+ # cols = column_names_train if split == 'train' else column_names_test
163
+ out = {}
164
+ for k, v in row.items():
165
+ colname = k.split('.')[0]
166
+ if colname in {'ade20k', 'depthcm', 'gestalt'}:
167
+ if colname in out:
168
+ out[colname].append(v)
169
+ else:
170
+ out[colname] = [v]
171
+ elif colname in {'wireframe', 'mesh'}:
172
+ # out.update({a: b.tolist() for a,b in v.items()})
173
+ out.update({a: b for a,b in v.items()})
174
+ elif colname in 'kr':
175
+ out[colname.upper()] = v
176
+ else:
177
+ out[colname] = v
178
+
179
+ return Sample(out)
180
+
181
+
182
+ class Sample(Dict):
183
+ def __repr__(self):
184
+ return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k,v in self.items()})
185
+
186
+
187
+
188
+ def get_params():
189
+ exmaple_param_dict = {
190
+ "competition_id": "usm3d/S23DR",
191
+ "competition_type": "script",
192
+ "metric": "custom",
193
+ "token": "hf_**********************************",
194
+ "team_id": "local-test-team_id",
195
+ "submission_id": "local-test-submission_id",
196
+ "submission_id_col": "__key__",
197
+ "submission_cols": [
198
+ "__key__",
199
+ "wf_edges",
200
+ "wf_vertices",
201
+ "edge_semantics"
202
+ ],
203
+ "submission_rows": 180,
204
+ "output_path": ".",
205
+ "submission_repo": "<THE HF MODEL ID of THIS REPO",
206
+ "time_limit": 7200,
207
+ "dataset": "usm3d/usm-test-data-x",
208
+ "submission_filenames": [
209
+ "submission.parquet"
210
+ ]
211
+ }
212
+
213
+ param_path = Path('params.json')
214
+
215
+ if not param_path.exists():
216
+ print('params.json not found (this means we probably aren\'t in the test env). Using example params.')
217
+ params = exmaple_param_dict
218
+ else:
219
+ print('found params.json (this means we are probably in the test env). Using params from file.')
220
+ with param_path.open() as f:
221
+ params = json.load(f)
222
+ print(params)
223
+ return params
224
+
225
+
226
+
227
+ import webdataset as wds
228
+ import numpy as np
229
+
230
+ def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset'):
231
+ if LOCAL_DATADIR is None:
232
+ raise ValueError('LOCAL_DATADIR is not set. Please run setup() first.')
233
+
234
+ local_dir = Path(LOCAL_DATADIR)
235
+ if split != 'all':
236
+ local_dir = local_dir / split
237
+
238
+ paths = [str(p) for p in local_dir.rglob('*.tar.gz')]
239
+
240
+ dataset = wds.WebDataset(paths)
241
+ if decode is not None:
242
+ dataset = dataset.decode(decode)
243
+ else:
244
+ dataset = dataset.decode()
245
+
246
+ dataset = dataset.map(proc)
247
+
248
+ if dataset_type == 'webdataset':
249
+ return dataset
250
+
251
+ if dataset_type == 'hf':
252
+ import datasets
253
+ from datasets import Features, Value, Sequence, Image, Array2D
254
+
255
+ if split == 'train':
256
+ return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
257
+ elif split == 'val':
258
+ return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
259
+
260
+
261
+