Bounding box detection
PyTorch
Ontocord.AI committed on
Commit e0103c6 · 1 Parent(s): cae5ca5

Create utils.py

Files changed (1)
utils.py +761 -0
utils.py ADDED
@@ -0,0 +1,761 @@
"""
coding=utf-8
Copyright 2022, Ontocord, LLC
Copyright 2018, Antonio Mendoza Hao Tan, Mohit Bansal, Huggingface team :)
Adapted From Facebook Inc, Detectron2
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import copy
import fnmatch
import json
import os
import pickle as pkl
import shutil
import sys
import tarfile
import tempfile
from collections import OrderedDict
from contextlib import contextmanager
from functools import partial
from hashlib import sha256
from io import BytesIO
from pathlib import Path
from urllib.parse import urlparse
from zipfile import ZipFile, is_zipfile

import numpy as np
from PIL import Image
from tqdm.auto import tqdm

import cv2
import requests
from filelock import FileLock
from yaml import Loader, dump, load
from torch.nn.functional import cosine_similarity
from numpy import asarray

try:
    import torch

    _torch_available = True
except ImportError:
    _torch_available = False


try:
    from torch.hub import _get_torch_home

    torch_cache_home = _get_torch_home()
except ImportError:
    torch_cache_home = os.path.expanduser(
        os.getenv("TORCH_HOME", os.path.join(os.getenv("XDG_CACHE_HOME", "~/.cache"), "torch"))
    )

import re
import numpy as np
import torch
import torch.distributed as dist
import collections
import logging
import sys, os

try:
    sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), os.path.pardir)))
except NameError:  # __file__ is not defined when running inside a notebook
    sys.path.append(os.path.abspath(os.path.join("./", os.path.pardir)))

# detect whether we are running inside a notebook (Colab or Jupyter)
in_notebook = 'google.colab' in sys.modules
if not in_notebook:
    try:
        get_ipython()
        in_notebook = True
    except NameError:
        in_notebook = False
if in_notebook:
    from IPython.display import clear_output, Image, display

import PIL.Image

import random
from transformers import CLIPProcessor, CLIPModel

# for visualizing output in a notebook
def showarray(a, fmt='jpeg'):
    a = np.uint8(np.clip(a, 0, 255))
    f = BytesIO()
    PIL.Image.fromarray(a).save(f, fmt)
    display(Image(data=f.getvalue()))


def decode_image(img, frcnn, image_preprocessor, max_detections=36, annotated_image=False):
    from .visualizing_image import SingleImageViz
    from .frcnn_ids import objids, attrids

    if annotated_image:
        frcnn_visualizer = SingleImageViz(img, id2obj=objids, id2attr=attrids)

    images, sizes, scales_yx = image_preprocessor(img)

    output_dict = frcnn(
        images,
        sizes,
        scales_yx=scales_yx,
        padding='max_detections',
        max_detections=max_detections,
        return_tensors='pt',
    )

    if annotated_image:
        # add boxes and labels to the image
        frcnn_visualizer.draw_boxes(
            output_dict.get("boxes"),
            output_dict.get("obj_ids"),
            output_dict.get("obj_probs"),
            output_dict.get("attr_ids"),
            output_dict.get("attr_probs"),
        )

        a = frcnn_visualizer._get_buffer()
        a = np.uint8(np.clip(a, 0, 255))
        output_dict['annotated_image'] = PIL.Image.fromarray(a)

    return output_dict

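For orientation, here is a minimal sketch of how decode_image is typically driven. It assumes the companion modules from the Hugging Face LXMERT/VisualBERT demo (processing_image.Preprocess, modeling_frcnn.GeneralizedRCNN) and the unc-nlp/frcnn-vg-finetuned checkpoint; those names are assumptions about the surrounding repo, not definitions made in this file:

# Sketch only: Preprocess, GeneralizedRCNN and the checkpoint id are assumed to come from
# the companion demo files; Config, img_tensorize and decode_image are defined in this utils.py.
from processing_image import Preprocess
from modeling_frcnn import GeneralizedRCNN

frcnn_cfg = Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")
frcnn = GeneralizedRCNN.from_pretrained("unc-nlp/frcnn-vg-finetuned", config=frcnn_cfg)
image_preprocessor = Preprocess(frcnn_cfg)

img = img_tensorize("https://example.com/some_image.jpg")  # placeholder URL or a local path
output_dict = decode_image(img, frcnn, image_preprocessor, max_detections=36, annotated_image=True)
boxes = output_dict.get("normalized_boxes")   # region boxes, as returned by the demo FRCNN
features = output_dict.get("roi_features")    # pooled region features for the multimodal model
output_dict["annotated_image"].save("annotated.jpg")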
def get_area(pos):
    """
    Args
        pos: [B, N, 4]
             (x1, x2, y1, y2)
    Return
        area : [B, N]
    """
    # [B, N]
    height = pos[:, :, 3] - pos[:, :, 2]
    width = pos[:, :, 1] - pos[:, :, 0]
    area = height * width
    return area


def get_relative_distance(pos):
    """
    Args
        pos: [B, N, 4]
             (x1, x2, y1, y2)
    Return
        out : [B, N, N, 4]
    """
    # B, N = pos.size()[:-1]

    # [B, N, N, 4]
    relative_distance = pos.unsqueeze(1) - pos.unsqueeze(2)

    return relative_distance

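A quick illustration of the (x1, x2, y1, y2) layout these two helpers expect (note that both x coordinates come before the y coordinates):

# one batch (B=1) with two boxes (N=2) in (x1, x2, y1, y2) order
pos = torch.tensor([[[0.0, 4.0, 0.0, 2.0],
                     [1.0, 3.0, 1.0, 5.0]]])
print(get_area(pos))                      # tensor([[8., 8.]]) -> (4-0)*(2-0) and (3-1)*(5-1)
print(get_relative_distance(pos).shape)   # torch.Size([1, 2, 2, 4])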
class LossMeter(object):
    def __init__(self, maxlen=100):
        """Computes and stores the running average"""
        self.vals = collections.deque([], maxlen=maxlen)

    def __len__(self):
        return len(self.vals)

    def update(self, new_val):
        self.vals.append(new_val)

    @property
    def val(self):
        return sum(self.vals) / len(self.vals)

    def __repr__(self):
        return str(self.val)


def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def load_state_dict(state_dict_path, loc='cpu'):
    state_dict = torch.load(state_dict_path, map_location=loc)
    # Change Multi GPU to single GPU
    original_keys = list(state_dict.keys())
    for key in original_keys:
        if key.startswith("module."):
            new_key = key[len("module."):]
            state_dict[new_key] = state_dict.pop(key)
    return state_dict

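A short, hypothetical snippet showing how these three helpers fit together in a training script (the model and loss values are placeholders):

model = torch.nn.Linear(10, 2)    # placeholder model
print(count_parameters(model))    # 22 trainable parameters (10*2 weights + 2 biases)

meter = LossMeter(maxlen=100)
for step_loss in (0.9, 0.7, 0.5):
    meter.update(step_loss)
print(meter)                      # running average over the stored values (~0.7)

# Re-key a checkpoint saved from a DataParallel model ("module."-prefixed keys):
# state_dict = load_state_dict("checkpoint.pth", loc="cpu")   # hypothetical path
# model.load_state_dict(state_dict)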
def set_global_logging_level(level=logging.ERROR, prefices=[""]):
    """
    Override logging levels of different modules based on their name as a prefix.
    It needs to be invoked after the modules have been loaded so that their loggers have been initialized.
    Args:
        - level: desired level. e.g. logging.INFO. Optional. Default is logging.ERROR
        - prefices: list of one or more str prefices to match (e.g. ["transformers", "torch"]). Optional.
          Default is `[""]` to match all active loggers.
          The match is a case-sensitive `module_name.startswith(prefix)`
    """
    prefix_re = re.compile(fr'^(?:{ "|".join(prefices) })')
    for name in logging.root.manager.loggerDict:
        if re.match(prefix_re, name):
            logging.getLogger(name).setLevel(level)

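For example, to silence everything below ERROR from the transformers and torch loggers once those libraries have been imported:

set_global_logging_level(logging.ERROR, ["transformers", "torch"])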
def get_iou(anchors, gt_boxes):
    """
    anchors: (N, 4) torch floattensor
    gt_boxes: (K, 4) torch floattensor
    overlaps: (N, K) ndarray of overlap between boxes and query_boxes
    """
    N = anchors.size(0)

    if gt_boxes.size() == (4,):
        gt_boxes = gt_boxes.view(1, 4)
    K = gt_boxes.size(0)

    gt_boxes_area = (
        (gt_boxes[:, 2] - gt_boxes[:, 0] + 1) *
        (gt_boxes[:, 3] - gt_boxes[:, 1] + 1)
    ).view(1, K)

    anchors_area = (
        (anchors[:, 2] - anchors[:, 0] + 1) *
        (anchors[:, 3] - anchors[:, 1] + 1)
    ).view(N, 1)

    boxes = anchors.view(N, 1, 4).expand(N, K, 4)
    query_boxes = gt_boxes.view(1, K, 4).expand(N, K, 4)

    iw = (
        torch.min(boxes[:, :, 2], query_boxes[:, :, 2])
        - torch.max(boxes[:, :, 0], query_boxes[:, :, 0])
        + 1
    )
    iw[iw < 0] = 0

    ih = (
        torch.min(boxes[:, :, 3], query_boxes[:, :, 3])
        - torch.max(boxes[:, :, 1], query_boxes[:, :, 1])
        + 1
    )
    ih[ih < 0] = 0

    ua = anchors_area + gt_boxes_area - (iw * ih)
    overlaps = iw * ih / ua

    return overlaps

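A small worked example: two 11x11-pixel boxes in (x1, y1, x2, y2) order that overlap on a 6x6 region (the +1 terms treat coordinates as inclusive pixel indices):

anchors = torch.tensor([[0.0, 0.0, 10.0, 10.0]])    # one 11x11 anchor
gt_boxes = torch.tensor([[5.0, 5.0, 15.0, 15.0]])   # one 11x11 ground-truth box
print(get_iou(anchors, gt_boxes))   # tensor([[0.1748]]) = 36 / (121 + 121 - 36)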
def xywh_to_xyxy(boxes):
    """Convert [x y w h] box format to [x1 y1 x2 y2] format."""
    return np.hstack((boxes[:, 0:2], boxes[:, 0:2] + boxes[:, 2:4] - 1))

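For instance, a 10x20 box anchored at (5, 5):

boxes = np.array([[5, 5, 10, 20]])   # [x, y, w, h]
print(xywh_to_xyxy(boxes))           # [[ 5  5 14 24]] -> [x1, y1, x2, y2] with inclusive corners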
default_cache_path = os.path.join(torch_cache_home, "transformers")

CLOUDFRONT_DISTRIB_PREFIX = "https://cdn.huggingface.co"
S3_BUCKET_PREFIX = "https://s3.amazonaws.com/models.huggingface.co/bert"
PATH = "/".join(str(Path(__file__).resolve()).split("/")[:-1])
CONFIG = os.path.join(PATH, "config.yaml")
ATTRIBUTES = os.path.join(PATH, "attributes.txt")
OBJECTS = os.path.join(PATH, "objects.txt")
PYTORCH_PRETRAINED_BERT_CACHE = os.getenv("PYTORCH_PRETRAINED_BERT_CACHE", default_cache_path)
PYTORCH_TRANSFORMERS_CACHE = os.getenv("PYTORCH_TRANSFORMERS_CACHE", PYTORCH_PRETRAINED_BERT_CACHE)
TRANSFORMERS_CACHE = os.getenv("TRANSFORMERS_CACHE", PYTORCH_TRANSFORMERS_CACHE)
WEIGHTS_NAME = "pytorch_model.bin"
CONFIG_NAME = "config.yaml"

def load_labels(objs=OBJECTS, attrs=ATTRIBUTES):
    vg_classes = []
    with open(objs) as f:
        for object in f.readlines():
            vg_classes.append(object.split(",")[0].lower().strip())

    vg_attrs = []
    with open(attrs) as f:
        for object in f.readlines():
            vg_attrs.append(object.split(",")[0].lower().strip())
    return vg_classes, vg_attrs

def load_checkpoint(ckp):
    r = OrderedDict()
    with open(ckp, "rb") as f:
        ckp = pkl.load(f)["model"]
    for k in copy.deepcopy(list(ckp.keys())):
        v = ckp.pop(k)
        if isinstance(v, np.ndarray):
            v = torch.tensor(v)
        else:
            assert isinstance(v, torch.Tensor), type(v)
        r[k] = v
    return r

class Config:
    _pointer = {}

    def __init__(self, dictionary: dict, name: str = "root", level=0):
        self._name = name
        self._level = level
        d = {}
        for k, v in dictionary.items():
            if v is None:
                raise ValueError()
            k = copy.deepcopy(k)
            v = copy.deepcopy(v)
            if isinstance(v, dict):
                v = Config(v, name=k, level=level + 1)
            d[k] = v
            setattr(self, k, v)

        self._pointer = d

    def __repr__(self):
        return str(list((self._pointer.keys())))

    def __setattr__(self, key, val):
        self.__dict__[key] = val
        self.__dict__[key.upper()] = val
        levels = key.split(".")
        last_level = len(levels) - 1
        pointer = self._pointer
        if len(levels) > 1:
            for i, l in enumerate(levels):
                if hasattr(self, l) and isinstance(getattr(self, l), Config):
                    setattr(getattr(self, l), ".".join(levels[i:]), val)
                if l == last_level:
                    pointer[l] = val
                else:
                    pointer = pointer[l]

    def to_dict(self):
        return self._pointer

    def dump_yaml(self, data, file_name):
        with open(f"{file_name}", "w") as stream:
            dump(data, stream)

    def dump_json(self, data, file_name):
        with open(f"{file_name}", "w") as stream:
            json.dump(data, stream)

    @staticmethod
    def load_yaml(config):
        with open(config) as stream:
            data = load(stream, Loader=Loader)
        return data

    def __str__(self):
        t = " "
        if self._name != "root":
            r = f"{t * (self._level-1)}{self._name}:\n"
        else:
            r = ""
        level = self._level
        for i, (k, v) in enumerate(self._pointer.items()):
            if isinstance(v, Config):
                r += f"{t * (self._level)}{v}\n"
                self._level += 1
            else:
                r += f"{t * (self._level)}{k}: {v} ({type(v).__name__})\n"
            self._level = level
        return r[:-1]

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs):
        config_dict, kwargs = cls.get_config_dict(pretrained_model_name_or_path, **kwargs)
        return cls(config_dict)

    @classmethod
    def get_config_dict(cls, pretrained_model_name_or_path: str, **kwargs):
        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
        resume_download = kwargs.pop("resume_download", False)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", False)

        if os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_NAME)
        elif os.path.isfile(pretrained_model_name_or_path) or is_remote_url(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        else:
            config_file = hf_bucket_url(pretrained_model_name_or_path, filename=CONFIG_NAME, use_cdn=False)

        try:
            # Load from URL or cache if already cached
            resolved_config_file = cached_path(
                config_file,
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
                resume_download=resume_download,
                local_files_only=local_files_only,
            )
            # Load config dict
            if resolved_config_file is None:
                raise EnvironmentError

            config_file = Config.load_yaml(resolved_config_file)

        except EnvironmentError:
            msg = "Can't load config for"
            raise EnvironmentError(msg)

        if resolved_config_file == config_file:
            print("loading configuration file from path")
        else:
            print("loading configuration file cache")

        return Config.load_yaml(resolved_config_file), kwargs

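A minimal sketch of how Config is used: it can be built directly from a nested dict, or populated from a hosted config.yaml via from_pretrained (the model id below is the Faster R-CNN checkpoint used by the companion demo and is an assumption here):

cfg = Config({"model": {"max_detections": 36}, "input": {"format": "RGB"}})
print(cfg.model.max_detections)   # 36; nested dicts become nested Config objects
print(cfg.to_dict().keys())       # dict_keys(['model', 'input'])

# Downloads <model_id>/config.yaml through cached_path() and parses it with load_yaml():
# frcnn_cfg = Config.from_pretrained("unc-nlp/frcnn-vg-finetuned")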
# quick compare tensors
def compare(in_tensor):
    out_tensor = torch.load("dump.pt", map_location=in_tensor.device)
    n1 = in_tensor.numpy()
    n2 = out_tensor.numpy()[0]
    print(n1.shape, n1[0, 0, :5])
    print(n2.shape, n2[0, 0, :5])
    assert np.allclose(
        n1, n2, rtol=0.01, atol=0.1
    ), f"{sum([1 for x in np.isclose(n1, n2, rtol=0.01, atol=0.1).flatten() if x == False])/len(n1.flatten())*100:.4f} % element-wise mismatch"
    raise Exception("tensors are all good")


# Hugging face functions below

def is_remote_url(url_or_filename):
    parsed = urlparse(url_or_filename)
    return parsed.scheme in ("http", "https")


def hf_bucket_url(model_id: str, filename: str, use_cdn=True) -> str:
    endpoint = CLOUDFRONT_DISTRIB_PREFIX if use_cdn else S3_BUCKET_PREFIX
    legacy_format = "/" not in model_id
    if legacy_format:
        return f"{endpoint}/{model_id}-{filename}"
    else:
        return f"{endpoint}/{model_id}/{filename}"

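The two URL layouts this produces (legacy flat names vs. namespaced model ids); the function only builds the string, and these legacy endpoints may no longer be live:

print(hf_bucket_url("bert-base-uncased", "config.yaml"))
# https://cdn.huggingface.co/bert-base-uncased-config.yaml
print(hf_bucket_url("unc-nlp/frcnn-vg-finetuned", "config.yaml", use_cdn=False))
# https://s3.amazonaws.com/models.huggingface.co/bert/unc-nlp/frcnn-vg-finetuned/config.yaml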
def http_get(
    url,
    temp_file,
    proxies=None,
    resume_size=0,
    user_agent=None,
):
    ua = "python/{}".format(sys.version.split()[0])
    if _torch_available:
        ua += "; torch/{}".format(torch.__version__)
    if isinstance(user_agent, dict):
        ua += "; " + "; ".join("{}/{}".format(k, v) for k, v in user_agent.items())
    elif isinstance(user_agent, str):
        ua += "; " + user_agent
    headers = {"user-agent": ua}
    if resume_size > 0:
        headers["Range"] = "bytes=%d-" % (resume_size,)
    response = requests.get(url, stream=True, proxies=proxies, headers=headers)
    if response.status_code == 416:  # Range not satisfiable
        return
    content_length = response.headers.get("Content-Length")
    total = resume_size + int(content_length) if content_length is not None else None
    progress = tqdm(
        unit="B",
        unit_scale=True,
        total=total,
        initial=resume_size,
        desc="Downloading",
    )
    for chunk in response.iter_content(chunk_size=1024):
        if chunk:  # filter out keep-alive new chunks
            progress.update(len(chunk))
            temp_file.write(chunk)
    progress.close()

def get_from_cache(
    url,
    cache_dir=None,
    force_download=False,
    proxies=None,
    etag_timeout=10,
    resume_download=False,
    user_agent=None,
    local_files_only=False,
):
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    os.makedirs(cache_dir, exist_ok=True)

    etag = None
    if not local_files_only:
        try:
            response = requests.head(url, allow_redirects=True, proxies=proxies, timeout=etag_timeout)
            if response.status_code == 200:
                etag = response.headers.get("ETag")
        except (EnvironmentError, requests.exceptions.Timeout):
            # etag is already None
            pass

    filename = url_to_filename(url, etag)

    # get cache path to put the file
    cache_path = os.path.join(cache_dir, filename)

    # etag is None = we don't have a connection, or url doesn't exist, or is otherwise inaccessible.
    # try to get the last downloaded one
    if etag is None:
        if os.path.exists(cache_path):
            return cache_path
        else:
            matching_files = [
                file
                for file in fnmatch.filter(os.listdir(cache_dir), filename + ".*")
                if not file.endswith(".json") and not file.endswith(".lock")
            ]
            if len(matching_files) > 0:
                return os.path.join(cache_dir, matching_files[-1])
            else:
                # If files cannot be found and local_files_only=True,
                # the models might've been found if local_files_only=False
                # Notify the user about that
                if local_files_only:
                    raise ValueError(
                        "Cannot find the requested files in the cached path and outgoing traffic has been"
                        " disabled. To enable model look-ups and downloads online, set 'local_files_only'"
                        " to False."
                    )
                return None

    # From now on, etag is not None.
    if os.path.exists(cache_path) and not force_download:
        return cache_path

    # Prevent parallel downloads of the same file with a lock.
    lock_path = cache_path + ".lock"
    with FileLock(lock_path):

        # If the download just completed while the lock was activated.
        if os.path.exists(cache_path) and not force_download:
            # Even if returning early like here, the lock will be released.
            return cache_path

        if resume_download:
            incomplete_path = cache_path + ".incomplete"

            @contextmanager
            def _resumable_file_manager():
                with open(incomplete_path, "a+b") as f:
                    yield f

            temp_file_manager = _resumable_file_manager
            if os.path.exists(incomplete_path):
                resume_size = os.stat(incomplete_path).st_size
            else:
                resume_size = 0
        else:
            temp_file_manager = partial(tempfile.NamedTemporaryFile, dir=cache_dir, delete=False)
            resume_size = 0

        # Download to temporary file, then copy to cache dir once finished.
        # Otherwise you get corrupt cache entries if the download gets interrupted.
        with temp_file_manager() as temp_file:
            print(
                "%s not found in cache or force_download set to True, downloading to %s"
                % (url, temp_file.name)
            )

            http_get(
                url,
                temp_file,
                proxies=proxies,
                resume_size=resume_size,
                user_agent=user_agent,
            )

        os.replace(temp_file.name, cache_path)

        meta = {"url": url, "etag": etag}
        meta_path = cache_path + ".json"
        with open(meta_path, "w") as meta_file:
            json.dump(meta, meta_file)

    return cache_path

def url_to_filename(url, etag=None):
    url_bytes = url.encode("utf-8")
    url_hash = sha256(url_bytes)
    filename = url_hash.hexdigest()

    if etag:
        etag_bytes = etag.encode("utf-8")
        etag_hash = sha256(etag_bytes)
        filename += "." + etag_hash.hexdigest()

    if url.endswith(".h5"):
        filename += ".h5"

    return filename

def cached_path(
    url_or_filename,
    cache_dir=None,
    force_download=False,
    proxies=None,
    resume_download=False,
    user_agent=None,
    extract_compressed_file=False,
    force_extract=False,
    local_files_only=False,
):
    if cache_dir is None:
        cache_dir = TRANSFORMERS_CACHE
    if isinstance(url_or_filename, Path):
        url_or_filename = str(url_or_filename)
    if isinstance(cache_dir, Path):
        cache_dir = str(cache_dir)

    if is_remote_url(url_or_filename):
        # URL, so get it from the cache (downloading if necessary)
        output_path = get_from_cache(
            url_or_filename,
            cache_dir=cache_dir,
            force_download=force_download,
            proxies=proxies,
            resume_download=resume_download,
            user_agent=user_agent,
            local_files_only=local_files_only,
        )
    elif os.path.exists(url_or_filename):
        # File, and it exists.
        output_path = url_or_filename
    elif urlparse(url_or_filename).scheme == "":
        # File, but it doesn't exist.
        raise EnvironmentError("file {} not found".format(url_or_filename))
    else:
        # Something unknown
        raise ValueError("unable to parse {} as a URL or as a local path".format(url_or_filename))

    if extract_compressed_file:
        if not is_zipfile(output_path) and not tarfile.is_tarfile(output_path):
            return output_path

        # Path where we extract compressed archives
        # We avoid '.' in dir name and add "-extracted" at the end: "./model.zip" => "./model-zip-extracted/"
        output_dir, output_file = os.path.split(output_path)
        output_extract_dir_name = output_file.replace(".", "-") + "-extracted"
        output_path_extracted = os.path.join(output_dir, output_extract_dir_name)

        if os.path.isdir(output_path_extracted) and os.listdir(output_path_extracted) and not force_extract:
            return output_path_extracted

        # Prevent parallel extractions
        lock_path = output_path + ".lock"
        with FileLock(lock_path):
            shutil.rmtree(output_path_extracted, ignore_errors=True)
            os.makedirs(output_path_extracted)
            if is_zipfile(output_path):
                with ZipFile(output_path, "r") as zip_file:
                    zip_file.extractall(output_path_extracted)
                    zip_file.close()
            elif tarfile.is_tarfile(output_path):
                tar_file = tarfile.open(output_path)
                tar_file.extractall(output_path_extracted)
                tar_file.close()
            else:
                raise EnvironmentError("Archive format of {} could not be identified".format(output_path))

        return output_path_extracted

    return output_path

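Putting the caching pieces together; the URL below is a placeholder and the call needs network access on the first run (afterwards the cached copy under TRANSFORMERS_CACHE is reused):

url = "https://cdn.huggingface.co/unc-nlp/frcnn-vg-finetuned/config.yaml"   # placeholder URL
print(url_to_filename(url))   # sha256 of the URL; a ".<sha256 of the etag>" suffix is added once cached
local_file = cached_path(url)
# -> a path such as ~/.cache/torch/transformers/<url-hash>.<etag-hash>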
def get_data(query, delim=","):
    assert isinstance(query, str)
    if os.path.isfile(query):
        with open(query) as f:
            data = eval(f.read())
    else:
        req = requests.get(query)
        try:
            data = req.json()
        except Exception:
            data = req.content.decode()
            assert data is not None, "could not connect"
            try:
                data = eval(data)
            except Exception:
                data = data.split("\n")
        req.close()
    return data

def get_image_from_url(url):
    response = requests.get(url)
    img = np.array(Image.open(BytesIO(response.content)))
    return img

# to load legacy frcnn checkpoint from detectron
def load_frcnn_pkl_from_url(url):
    import wget

    fn = url.split("/")[-1]
    if fn not in os.listdir(os.getcwd()):
        wget.download(url)
    with open(fn, "rb") as stream:
        weights = pkl.load(stream)
    model = weights.pop("model")
    new = {}
    for k, v in model.items():
        new[k] = torch.from_numpy(v)
        if "running_var" in k:
            zero = torch.tensor([0])
            k2 = k.replace("running_var", "num_batches_tracked")
            new[k2] = zero
    return new

def img_tensorize(im, input_format="RGB"):
    assert isinstance(im, str)
    if os.path.isfile(im):
        img = cv2.imread(im)
    else:
        img = get_image_from_url(im)
        assert img is not None, f"could not connect to: {im}"
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    if input_format == "RGB":
        img = img[:, :, ::-1]
    return img

def chunk(images, batch=1):
    return (images[i : i + batch] for i in range(0, len(images), batch))
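chunk is a small generator for mini-batching a list of images, e.g.:

paths = ["img0.jpg", "img1.jpg", "img2.jpg", "img3.jpg", "img4.jpg"]   # hypothetical files
for batch in chunk(paths, batch=2):
    print(batch)   # ['img0.jpg', 'img1.jpg'], then ['img2.jpg', 'img3.jpg'], then ['img4.jpg']
    # images = [img_tensorize(p) for p in batch]   # e.g. tensorize each batch before running the frcnn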