jacklangerman
commited on
Commit
•
e99b13a
1
Parent(s):
f5a979c
Create hoho.py
Browse files
hoho.py
ADDED
@@ -0,0 +1,261 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import shutil
|
4 |
+
from pathlib import Path
|
5 |
+
from typing import Dict
|
6 |
+
|
7 |
+
from PIL import ImageFile
|
8 |
+
ImageFile.LOAD_TRUNCATED_IMAGES = True
|
9 |
+
|
10 |
+
LOCAL_DATADIR = None
|
11 |
+
|
12 |
+
def setup(local_dir='./data/usm-training-data/data'):
|
13 |
+
|
14 |
+
# If we are in the test environment, we need to link the data directory to the correct location
|
15 |
+
tmp_datadir = Path('/tmp/data/data')
|
16 |
+
local_test_datadir = Path('./data/usm-test-data-x/data')
|
17 |
+
local_val_datadir = Path(local_dir)
|
18 |
+
|
19 |
+
os.system('pwd')
|
20 |
+
os.system('ls -lahtr .')
|
21 |
+
|
22 |
+
if tmp_datadir.exists() and not local_test_datadir.exists():
|
23 |
+
global LOCAL_DATADIR
|
24 |
+
LOCAL_DATADIR = local_test_datadir
|
25 |
+
# shutil.move(datadir, './usm-test-data-x/data')
|
26 |
+
print(f"Linking {tmp_datadir} to {LOCAL_DATADIR} (we are in the test environment)")
|
27 |
+
LOCAL_DATADIR.parent.mkdir(parents=True, exist_ok=True)
|
28 |
+
LOCAL_DATADIR.symlink_to(tmp_datadir)
|
29 |
+
else:
|
30 |
+
LOCAL_DATADIR = local_val_datadir
|
31 |
+
print(f"Using {LOCAL_DATADIR} as the data directory (we are running locally)")
|
32 |
+
|
33 |
+
# os.system("ls -lahtr")
|
34 |
+
|
35 |
+
assert LOCAL_DATADIR.exists(), f"Data directory {LOCAL_DATADIR} does not exist"
|
36 |
+
return LOCAL_DATADIR
|
37 |
+
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
import importlib
|
42 |
+
from pathlib import Path
|
43 |
+
import subprocess
|
44 |
+
|
45 |
+
def download_package(package_name, path_to_save='packages'):
|
46 |
+
"""
|
47 |
+
Downloads a package using pip and saves it to a specified directory.
|
48 |
+
|
49 |
+
Parameters:
|
50 |
+
package_name (str): The name of the package to download.
|
51 |
+
path_to_save (str): The path to the directory where the package will be saved.
|
52 |
+
"""
|
53 |
+
try:
|
54 |
+
# pip download webdataset -d packages/webdataset --platform manylinux1_x86_64 --python-version 38 --only-binary=:all:
|
55 |
+
subprocess.check_call([subprocess.sys.executable, "-m", "pip", "download", package_name,
|
56 |
+
"-d", str(Path(path_to_save)/package_name), # Download the package to the specified directory
|
57 |
+
"--platform", "manylinux1_x86_64", # Specify the platform
|
58 |
+
"--python-version", "38", # Specify the Python version
|
59 |
+
"--only-binary=:all:"]) # Download only binary packages
|
60 |
+
print(f'Package "{package_name}" downloaded successfully')
|
61 |
+
except subprocess.CalledProcessError as e:
|
62 |
+
print(f'Failed to downloaded package "{package_name}". Error: {e}')
|
63 |
+
|
64 |
+
|
65 |
+
def install_package_from_local_file(package_name, folder='packages'):
|
66 |
+
"""
|
67 |
+
Installs a package from a local .whl file or a directory containing .whl files using pip.
|
68 |
+
|
69 |
+
Parameters:
|
70 |
+
path_to_file_or_directory (str): The path to the .whl file or the directory containing .whl files.
|
71 |
+
"""
|
72 |
+
try:
|
73 |
+
pth = str(Path(folder) / package_name)
|
74 |
+
subprocess.check_call([subprocess.sys.executable, "-m", "pip", "install",
|
75 |
+
"--no-index", # Do not use package index
|
76 |
+
"--find-links", pth, # Look for packages in the specified directory or at the file
|
77 |
+
package_name]) # Specify the package to install
|
78 |
+
print(f"Package installed successfully from {pth}")
|
79 |
+
except subprocess.CalledProcessError as e:
|
80 |
+
print(f"Failed to install package from {pth}. Error: {e}")
|
81 |
+
|
82 |
+
|
83 |
+
def importt(module_name, as_name=None):
|
84 |
+
"""
|
85 |
+
Imports a module and returns it.
|
86 |
+
|
87 |
+
Parameters:
|
88 |
+
module_name (str): The name of the module to import.
|
89 |
+
as_name (str): The name to use for the imported module. If None, the original module name will be used.
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
The imported module.
|
93 |
+
"""
|
94 |
+
for _ in range(2):
|
95 |
+
try:
|
96 |
+
if as_name is None:
|
97 |
+
print(f'imported {module_name}')
|
98 |
+
return importlib.import_module(module_name)
|
99 |
+
else:
|
100 |
+
print(f'imported {module_name} as {as_name}')
|
101 |
+
return importlib.import_module(module_name, as_name)
|
102 |
+
except ModuleNotFoundError as e:
|
103 |
+
install_package_from_local_file(module_name)
|
104 |
+
print(f"Failed to import module {module_name}. Error: {e}")
|
105 |
+
|
106 |
+
|
107 |
+
def prepare_submission():
|
108 |
+
# Download packages from requirements.txt
|
109 |
+
if Path('requirements.txt').exists():
|
110 |
+
print('downloading packages from requirements.txt')
|
111 |
+
Path('packages').mkdir(exist_ok=True)
|
112 |
+
with open('requirements.txt') as f:
|
113 |
+
packages = f.readlines()
|
114 |
+
for p in packages:
|
115 |
+
download_package(p.strip())
|
116 |
+
|
117 |
+
|
118 |
+
print('all packages downloaded. Don\'t foget to include the packages in the submission by adding them with git lfs.')
|
119 |
+
|
120 |
+
|
121 |
+
def Rt_to_eye_target(im, K, R, t):
|
122 |
+
height = im.height
|
123 |
+
focal_length = K[0,0]
|
124 |
+
fov = 2.0 * np.arctan2((0.5 * height), focal_length) / (np.pi / 180.0)
|
125 |
+
|
126 |
+
x_axis, y_axis, z_axis = R
|
127 |
+
|
128 |
+
eye = -(R.T @ t).squeeze()
|
129 |
+
z_axis = z_axis.squeeze()
|
130 |
+
target = eye + z_axis
|
131 |
+
up = -y_axis
|
132 |
+
|
133 |
+
return eye, target, up, fov
|
134 |
+
|
135 |
+
|
136 |
+
########## general utilities ##########
|
137 |
+
import contextlib
|
138 |
+
import tempfile
|
139 |
+
from pathlib import Path
|
140 |
+
|
141 |
+
@contextlib.contextmanager
|
142 |
+
def working_directory(path):
|
143 |
+
"""Changes working directory and returns to previous on exit."""
|
144 |
+
prev_cwd = Path.cwd()
|
145 |
+
os.chdir(path)
|
146 |
+
try:
|
147 |
+
yield
|
148 |
+
finally:
|
149 |
+
os.chdir(prev_cwd)
|
150 |
+
|
151 |
+
@contextlib.contextmanager
|
152 |
+
def temp_working_directory():
|
153 |
+
with tempfile.TemporaryDirectory(dir='.') as D:
|
154 |
+
with working_directory(D):
|
155 |
+
yield
|
156 |
+
|
157 |
+
|
158 |
+
############# Dataset #############
|
159 |
+
def proc(row, split='train'):
|
160 |
+
# column_names_train = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'mesh', 'wireframe']
|
161 |
+
# column_names_test = ['ade20k', 'depthcm', 'gestalt', 'colmap', 'KRt', 'wireframe']
|
162 |
+
# cols = column_names_train if split == 'train' else column_names_test
|
163 |
+
out = {}
|
164 |
+
for k, v in row.items():
|
165 |
+
colname = k.split('.')[0]
|
166 |
+
if colname in {'ade20k', 'depthcm', 'gestalt'}:
|
167 |
+
if colname in out:
|
168 |
+
out[colname].append(v)
|
169 |
+
else:
|
170 |
+
out[colname] = [v]
|
171 |
+
elif colname in {'wireframe', 'mesh'}:
|
172 |
+
# out.update({a: b.tolist() for a,b in v.items()})
|
173 |
+
out.update({a: b for a,b in v.items()})
|
174 |
+
elif colname in 'kr':
|
175 |
+
out[colname.upper()] = v
|
176 |
+
else:
|
177 |
+
out[colname] = v
|
178 |
+
|
179 |
+
return Sample(out)
|
180 |
+
|
181 |
+
|
182 |
+
class Sample(Dict):
|
183 |
+
def __repr__(self):
|
184 |
+
return str({k: v.shape if hasattr(v, 'shape') else [type(v[0])] if isinstance(v, list) else type(v) for k,v in self.items()})
|
185 |
+
|
186 |
+
|
187 |
+
|
188 |
+
def get_params():
|
189 |
+
exmaple_param_dict = {
|
190 |
+
"competition_id": "usm3d/S23DR",
|
191 |
+
"competition_type": "script",
|
192 |
+
"metric": "custom",
|
193 |
+
"token": "hf_**********************************",
|
194 |
+
"team_id": "local-test-team_id",
|
195 |
+
"submission_id": "local-test-submission_id",
|
196 |
+
"submission_id_col": "__key__",
|
197 |
+
"submission_cols": [
|
198 |
+
"__key__",
|
199 |
+
"wf_edges",
|
200 |
+
"wf_vertices",
|
201 |
+
"edge_semantics"
|
202 |
+
],
|
203 |
+
"submission_rows": 180,
|
204 |
+
"output_path": ".",
|
205 |
+
"submission_repo": "<THE HF MODEL ID of THIS REPO",
|
206 |
+
"time_limit": 7200,
|
207 |
+
"dataset": "usm3d/usm-test-data-x",
|
208 |
+
"submission_filenames": [
|
209 |
+
"submission.parquet"
|
210 |
+
]
|
211 |
+
}
|
212 |
+
|
213 |
+
param_path = Path('params.json')
|
214 |
+
|
215 |
+
if not param_path.exists():
|
216 |
+
print('params.json not found (this means we probably aren\'t in the test env). Using example params.')
|
217 |
+
params = exmaple_param_dict
|
218 |
+
else:
|
219 |
+
print('found params.json (this means we are probably in the test env). Using params from file.')
|
220 |
+
with param_path.open() as f:
|
221 |
+
params = json.load(f)
|
222 |
+
print(params)
|
223 |
+
return params
|
224 |
+
|
225 |
+
|
226 |
+
|
227 |
+
import webdataset as wds
|
228 |
+
import numpy as np
|
229 |
+
|
230 |
+
def get_dataset(decode='pil', proc=proc, split='train', dataset_type='webdataset'):
|
231 |
+
if LOCAL_DATADIR is None:
|
232 |
+
raise ValueError('LOCAL_DATADIR is not set. Please run setup() first.')
|
233 |
+
|
234 |
+
local_dir = Path(LOCAL_DATADIR)
|
235 |
+
if split != 'all':
|
236 |
+
local_dir = local_dir / split
|
237 |
+
|
238 |
+
paths = [str(p) for p in local_dir.rglob('*.tar.gz')]
|
239 |
+
|
240 |
+
dataset = wds.WebDataset(paths)
|
241 |
+
if decode is not None:
|
242 |
+
dataset = dataset.decode(decode)
|
243 |
+
else:
|
244 |
+
dataset = dataset.decode()
|
245 |
+
|
246 |
+
dataset = dataset.map(proc)
|
247 |
+
|
248 |
+
if dataset_type == 'webdataset':
|
249 |
+
return dataset
|
250 |
+
|
251 |
+
if dataset_type == 'hf':
|
252 |
+
import datasets
|
253 |
+
from datasets import Features, Value, Sequence, Image, Array2D
|
254 |
+
|
255 |
+
if split == 'train':
|
256 |
+
return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
|
257 |
+
elif split == 'val':
|
258 |
+
return datasets.IterableDataset.from_generator(lambda: dataset.iterator())
|
259 |
+
|
260 |
+
|
261 |
+
|