Sync with main repo
- app.py +6 -6
- models/tap_vit_l_548184.pkl +3 -0
- tokenize_anything/__init__.py +1 -1
- tokenize_anything/engine/__init__.py +29 -0
- tokenize_anything/engine/build.py +25 -0
- tokenize_anything/engine/lr_scheduler.py +76 -0
- tokenize_anything/{test_engine.py → engine/test_engine.py} +1 -1
- tokenize_anything/engine/utils.py +153 -0
- tokenize_anything/layers/__init__.py +26 -0
- tokenize_anything/layers/drop.py +39 -0
- tokenize_anything/layers/loss.py +82 -0
- tokenize_anything/layers/utils.py +64 -0
- tokenize_anything/modeling/concept_projector.py +2 -2
- tokenize_anything/modeling/image_decoder.py +4 -7
- tokenize_anything/modeling/image_encoder.py +6 -3
- tokenize_anything/modeling/image_tokenizer.py +8 -9
- tokenize_anything/modeling/text_decoder.py +5 -2
- tokenize_anything/models/__init__.py +18 -0
- tokenize_anything/{build_model.py → models/easy_build.py} +3 -3
- tokenize_anything/prompters/__init__.py +18 -0
- tokenize_anything/prompters/visual_prompter.py +106 -0
- tokenize_anything/utils/logging.py +129 -0
- tokenize_anything/utils/profiler/__init__.py +20 -0
- tokenize_anything/utils/profiler/stats.py +42 -0
- tokenize_anything/utils/{timer.py → profiler/timer.py} +12 -1
- tokenize_anything/utils/registry.py +54 -0
- tokenize_anything/utils/tensorboard.py +68 -0
app.py
CHANGED
@@ -23,18 +23,18 @@ import time
|
|
23 |
import numpy as np
|
24 |
import torch
|
25 |
|
26 |
-
from tokenize_anything import
|
27 |
from tokenize_anything.utils.image import im_rescale
|
28 |
from tokenize_anything.utils.image import im_vstack
|
29 |
|
30 |
|
31 |
def parse_args():
|
32 |
"""Parse arguments."""
|
33 |
-
parser = argparse.ArgumentParser(description="Launch gradio
|
34 |
parser.add_argument("--model-type", type=str, default="tap_vit_l")
|
35 |
-
parser.add_argument("--checkpoint", type=str, default="models/
|
36 |
parser.add_argument("--concept", type=str, default="concepts/merged_2560.pkl")
|
37 |
-
parser.add_argument("--device", nargs="+", type=int, default=[0], help="Index of devices
|
38 |
return parser.parse_args()
|
39 |
|
40 |
|
@@ -94,7 +94,7 @@ class Predictor(object):
|
|
94 |
# Generate captions.
|
95 |
sem_tokens = outputs["sem_tokens"][mask_index].unsqueeze_(1)
|
96 |
captions = self.model.generate_text(sem_tokens).reshape(batch_shape)
|
97 |
-
#
|
98 |
results = []
|
99 |
for i in range(batch_shape[0]):
|
100 |
pred_h, pred_w = im_info[i, :2].astype("int")
|
@@ -227,7 +227,7 @@ if __name__ == "__main__":
|
|
227 |
args = parse_args()
|
228 |
queues = [mp.Queue(1024) for _ in range(len(args.device) + 1)]
|
229 |
commands = [
|
230 |
-
|
231 |
queues[i],
|
232 |
queues[-1],
|
233 |
kwargs={
|
|
|
23 |
import numpy as np
|
24 |
import torch
|
25 |
|
26 |
+
from tokenize_anything import engine
|
27 |
from tokenize_anything.utils.image import im_rescale
|
28 |
from tokenize_anything.utils.image import im_vstack
|
29 |
|
30 |
|
31 |
def parse_args():
|
32 |
"""Parse arguments."""
|
33 |
+
parser = argparse.ArgumentParser(description="Launch gradio application")
|
34 |
parser.add_argument("--model-type", type=str, default="tap_vit_l")
|
35 |
+
parser.add_argument("--checkpoint", type=str, default="models/tap_vit_l_548184.pkl")
|
36 |
parser.add_argument("--concept", type=str, default="concepts/merged_2560.pkl")
|
37 |
+
parser.add_argument("--device", nargs="+", type=int, default=[0], help="Index of devices")
|
38 |
return parser.parse_args()
|
39 |
|
40 |
|
|
|
94 |
# Generate captions.
|
95 |
sem_tokens = outputs["sem_tokens"][mask_index].unsqueeze_(1)
|
96 |
captions = self.model.generate_text(sem_tokens).reshape(batch_shape)
|
97 |
+
# Postprocess results.
|
98 |
results = []
|
99 |
for i in range(batch_shape[0]):
|
100 |
pred_h, pred_w = im_info[i, :2].astype("int")
|
|
|
227 |
args = parse_args()
|
228 |
queues = [mp.Queue(1024) for _ in range(len(args.device) + 1)]
|
229 |
commands = [
|
230 |
+
engine.InferenceCommand(
|
231 |
queues[i],
|
232 |
queues[-1],
|
233 |
kwargs={
|
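Note: with the updated defaults above, `python app.py --device 0` is enough to launch the demo, assuming models/tap_vit_l_548184.pkl (added in this commit) and concepts/merged_2560.pkl have been downloaded to the listed paths.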
models/tap_vit_l_548184.pkl
ADDED
@@ -0,0 +1,3 @@
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e1d3a11c572af8cb6bce8016d3a6c6948bba4959ea43811f0e984b9eafeee413
|
3 |
+
size 811637521
|
tokenize_anything/__init__.py
CHANGED
@@ -15,5 +15,5 @@
|
|
15 |
# ------------------------------------------------------------------------
|
16 |
"""Tokenize Anything via Prompting."""
|
17 |
|
18 |
-
from tokenize_anything.
|
19 |
from tokenize_anything.version import __version__
|
|
|
15 |
# ------------------------------------------------------------------------
|
16 |
"""Tokenize Anything via Prompting."""
|
17 |
|
18 |
+
from tokenize_anything.models import model_registry
|
19 |
from tokenize_anything.version import __version__
|
tokenize_anything/engine/__init__.py
ADDED
@@ -0,0 +1,29 @@
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
"""Engine components."""
|
17 |
+
|
18 |
+
from tokenize_anything.engine.build import build_tensorboard
|
19 |
+
from tokenize_anything.engine.test_engine import InferenceCommand
|
20 |
+
from tokenize_anything.engine.utils import apply_ddp_group
|
21 |
+
from tokenize_anything.engine.utils import count_params
|
22 |
+
from tokenize_anything.engine.utils import create_ddp_group
|
23 |
+
from tokenize_anything.engine.utils import freeze_module
|
24 |
+
from tokenize_anything.engine.utils import get_ddp_group
|
25 |
+
from tokenize_anything.engine.utils import get_ddp_rank
|
26 |
+
from tokenize_anything.engine.utils import get_device
|
27 |
+
from tokenize_anything.engine.utils import get_param_groups
|
28 |
+
from tokenize_anything.engine.utils import load_weights
|
29 |
+
from tokenize_anything.engine.utils import manual_seed
|
tokenize_anything/engine/build.py
ADDED
@@ -0,0 +1,25 @@
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
"""Build for engine."""
|
17 |
+
|
18 |
+
|
19 |
+
def build_tensorboard(log_dir):
|
20 |
+
"""Build the tensorboard."""
|
21 |
+
from tokenize_anything.utils.tensorboard import TensorBoard
|
22 |
+
|
23 |
+
if TensorBoard.is_available():
|
24 |
+
return TensorBoard(log_dir)
|
25 |
+
return None
|
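A minimal sketch of the helper added above; it relies only on the is_available check shown in this file (the writer class itself lives in utils/tensorboard.py, which is not reproduced here), with a hypothetical log directory:

    from tokenize_anything.engine import build_tensorboard

    board = build_tensorboard("logs/tap")   # returns a TensorBoard writer, or None if tensorboard is unavailable
    if board is None:
        print("tensorboard is not installed; metrics will not be written")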
tokenize_anything/engine/lr_scheduler.py
ADDED
@@ -0,0 +1,76 @@
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
"""Learning rate schedulers."""
|
17 |
+
|
18 |
+
import math
|
19 |
+
|
20 |
+
|
21 |
+
class ConstantLR(object):
|
22 |
+
"""Constant LR scheduler."""
|
23 |
+
|
24 |
+
def __init__(self, **kwargs):
|
25 |
+
self._lr_max = kwargs.pop("lr_max")
|
26 |
+
self._lr_min = kwargs.pop("lr_min", 0)
|
27 |
+
self._warmup_steps = kwargs.pop("warmup_steps", 0)
|
28 |
+
self._warmup_factor = kwargs.pop("warmup_factor", 0)
|
29 |
+
if kwargs:
|
30 |
+
raise ValueError("Unexpected arguments: " + ",".join(v for v in kwargs))
|
31 |
+
self._step_count = 0
|
32 |
+
self._last_decay = 1.0
|
33 |
+
|
34 |
+
def step(self):
|
35 |
+
self._step_count += 1
|
36 |
+
|
37 |
+
def get_lr(self):
|
38 |
+
if self._step_count < self._warmup_steps:
|
39 |
+
alpha = (self._step_count + 1.0) / self._warmup_steps
|
40 |
+
return self._lr_max * (alpha + (1.0 - alpha) * self._warmup_factor)
|
41 |
+
return self._lr_min + (self._lr_max - self._lr_min) * self.get_decay()
|
42 |
+
|
43 |
+
def get_decay(self):
|
44 |
+
return self._last_decay
|
45 |
+
|
46 |
+
|
47 |
+
class CosineLR(ConstantLR):
|
48 |
+
"""LR scheduler with cosine decay."""
|
49 |
+
|
50 |
+
def __init__(self, lr_max, max_steps, lr_min=0, decay_step=1, **kwargs):
|
51 |
+
super(CosineLR, self).__init__(lr_max=lr_max, lr_min=lr_min, **kwargs)
|
52 |
+
self._decay_step = decay_step
|
53 |
+
self._max_steps = max_steps
|
54 |
+
|
55 |
+
def get_decay(self):
|
56 |
+
t = self._step_count - self._warmup_steps
|
57 |
+
t_max = self._max_steps - self._warmup_steps
|
58 |
+
if t > 0 and t % self._decay_step == 0:
|
59 |
+
self._last_decay = 0.5 * (1.0 + math.cos(math.pi * t / t_max))
|
60 |
+
return self._last_decay
|
61 |
+
|
62 |
+
|
63 |
+
class LinearLR(ConstantLR):
|
64 |
+
"""LR scheduler with linear decay."""
|
65 |
+
|
66 |
+
def __init__(self, lr_max, max_steps, lr_min=0, decay_step=1, **kwargs):
|
67 |
+
super(LinearLR, self).__init__(lr_max=lr_max, lr_min=lr_min, **kwargs)
|
68 |
+
self._decay_step = decay_step
|
69 |
+
self._max_steps = max_steps
|
70 |
+
|
71 |
+
def get_decay(self):
|
72 |
+
t = self._step_count - self._warmup_steps
|
73 |
+
t_max = self._max_steps - self._warmup_steps
|
74 |
+
if t > 0 and t % self._decay_step == 0:
|
75 |
+
self._last_decay = 1.0 - float(t) / t_max
|
76 |
+
return self._last_decay
|
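A minimal training-loop sketch for the CosineLR scheduler above; the optimizer and all numeric values are illustrative, not part of this commit:

    import torch
    from tokenize_anything.engine.lr_scheduler import CosineLR

    param = torch.nn.Parameter(torch.zeros(4))
    optimizer = torch.optim.SGD([param], lr=1.0)
    scheduler = CosineLR(lr_max=1e-3, lr_min=1e-5, max_steps=1000,
                         warmup_steps=50, warmup_factor=0.01)
    for step in range(1000):
        for group in optimizer.param_groups:
            group["lr"] = scheduler.get_lr()   # linear warmup for 50 steps, then cosine decay to lr_min
        optimizer.step()
        scheduler.step()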
tokenize_anything/{test_engine.py → engine/test_engine.py}
RENAMED
@@ -17,7 +17,7 @@
|
|
17 |
|
18 |
import time
|
19 |
|
20 |
-
from tokenize_anything.
|
21 |
|
22 |
|
23 |
class InferenceCommand(object):
|
|
|
17 |
|
18 |
import time
|
19 |
|
20 |
+
from tokenize_anything.models.easy_build import model_registry
|
21 |
|
22 |
|
23 |
class InferenceCommand(object):
|
tokenize_anything/engine/utils.py
ADDED
@@ -0,0 +1,153 @@
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
"""Engine utilities."""
|
17 |
+
|
18 |
+
import collections
|
19 |
+
import functools
|
20 |
+
import pickle
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import numpy as np
|
24 |
+
|
25 |
+
from tokenize_anything.utils import logging
|
26 |
+
|
27 |
+
GLOBAL_DDP_GROUP = None
|
28 |
+
|
29 |
+
|
30 |
+
def count_params(module, trainable=True, unit="M"):
|
31 |
+
"""Return the number of parameters."""
|
32 |
+
counts = [v.size().numel() for v in module.parameters() if v.requires_grad or (not trainable)]
|
33 |
+
return sum(counts) / {"M": 1e6, "B": 1e9}[unit]
|
34 |
+
|
35 |
+
|
36 |
+
def freeze_module(module):
|
37 |
+
"""Freeze parameters of given module."""
|
38 |
+
module.eval()
|
39 |
+
for param in module.parameters():
|
40 |
+
param.requires_grad = False
|
41 |
+
|
42 |
+
|
43 |
+
def get_device(index):
|
44 |
+
"""Create the available device object."""
|
45 |
+
if torch.cuda.is_available():
|
46 |
+
return torch.device("cuda", index)
|
47 |
+
for device_type in ("mps",):
|
48 |
+
try:
|
49 |
+
if getattr(torch.backends, device_type).is_available():
|
50 |
+
return torch.device(device_type, index)
|
51 |
+
except AttributeError:
|
52 |
+
pass
|
53 |
+
return torch.device("cpu")
|
54 |
+
|
55 |
+
|
56 |
+
def get_param_groups(module, layer_lr_decay=1.0):
|
57 |
+
"""Separate parameters into groups."""
|
58 |
+
memo, groups, inner = {}, collections.OrderedDict(), module
|
59 |
+
if isinstance(module, torch.nn.parallel.DistributedDataParallel):
|
60 |
+
inner = module.module
|
61 |
+
lr_scale_getter = None
|
62 |
+
if layer_lr_decay < 1.0 and hasattr(inner.image_encoder, "get_lr_scale"):
|
63 |
+
lr_scale_getter = functools.partial(inner.image_encoder.get_lr_scale, decay=layer_lr_decay)
|
64 |
+
for name, param in module.named_parameters():
|
65 |
+
if not param.requires_grad:
|
66 |
+
continue
|
67 |
+
attrs = collections.OrderedDict()
|
68 |
+
if lr_scale_getter:
|
69 |
+
attrs["lr_scale"] = lr_scale_getter(name)
|
70 |
+
memo[name] = param.shape
|
71 |
+
no_weight_decay = not (name.endswith("weight") and param.dim() > 1)
|
72 |
+
no_weight_decay = getattr(param, "no_weight_decay", no_weight_decay)
|
73 |
+
if no_weight_decay:
|
74 |
+
attrs["weight_decay"] = 0
|
75 |
+
group_name = "/".join(["%s:%s" % (v[0], v[1]) for v in list(attrs.items())])
|
76 |
+
if group_name not in groups:
|
77 |
+
groups[group_name] = {"params": []}
|
78 |
+
groups[group_name].update(attrs)
|
79 |
+
groups[group_name]["params"].append(param)
|
80 |
+
return list(groups.values())
|
81 |
+
|
82 |
+
|
83 |
+
def load_weights(module, weights_file, prefix_removed="", strict=True):
|
84 |
+
"""Load a weights file."""
|
85 |
+
if not weights_file:
|
86 |
+
return
|
87 |
+
if weights_file.endswith(".pkl"):
|
88 |
+
with open(weights_file, "rb") as f:
|
89 |
+
state_dict = pickle.load(f)
|
90 |
+
for k, v in state_dict.items():
|
91 |
+
state_dict[k] = torch.from_numpy(v) if isinstance(v, np.ndarray) else v
|
92 |
+
else:
|
93 |
+
state_dict = torch.load(weights_file)
|
94 |
+
if prefix_removed:
|
95 |
+
new_state_dict = type(state_dict)()
|
96 |
+
for k in list(state_dict.keys()):
|
97 |
+
new_state_dict[k.replace(prefix_removed, "")] = state_dict.pop(k)
|
98 |
+
state_dict = new_state_dict
|
99 |
+
module.load_state_dict(state_dict, strict=strict)
|
100 |
+
|
101 |
+
|
102 |
+
def manual_seed(seed, device_and_seed=None):
|
103 |
+
"""Set the cpu and device random seed."""
|
104 |
+
torch.manual_seed(seed)
|
105 |
+
if device_and_seed is not None:
|
106 |
+
device_index, device_seed = device_and_seed
|
107 |
+
device_type = get_device(device_index).type
|
108 |
+
np.random.seed(device_seed)
|
109 |
+
if device_type in ("cuda", "mps"):
|
110 |
+
getattr(torch, device_type).manual_seed(device_seed)
|
111 |
+
|
112 |
+
|
113 |
+
def synchronize_device(device):
|
114 |
+
"""Synchronize the computation of device."""
|
115 |
+
if device.type in ("cuda", "mps"):
|
116 |
+
getattr(torch, device.type).synchronize(device)
|
117 |
+
|
118 |
+
|
119 |
+
def create_ddp_group(cfg, ranks=None, devices=None, num_nodes=1):
|
120 |
+
"""Create group for data parallelism."""
|
121 |
+
if not torch.distributed.is_initialized():
|
122 |
+
torch.distributed.init_process_group(backend="nccl")
|
123 |
+
world_rank = torch.distributed.get_rank()
|
124 |
+
ranks = ranks if ranks else [i for i in range(cfg.NUM_GPUS)]
|
125 |
+
logging.set_root(world_rank == ranks[0])
|
126 |
+
devices_per_node = len(ranks) // num_nodes
|
127 |
+
devices = devices if devices else [i % devices_per_node for i in range(len(ranks))]
|
128 |
+
cfg.GPU_ID = devices[world_rank]
|
129 |
+
torch.cuda.set_device(cfg.GPU_ID)
|
130 |
+
global GLOBAL_DDP_GROUP
|
131 |
+
GLOBAL_DDP_GROUP = torch.distributed.new_group(ranks)
|
132 |
+
return GLOBAL_DDP_GROUP
|
133 |
+
|
134 |
+
|
135 |
+
def get_ddp_group():
|
136 |
+
"""Return the process group for data parallelism."""
|
137 |
+
return GLOBAL_DDP_GROUP
|
138 |
+
|
139 |
+
|
140 |
+
def get_ddp_rank():
|
141 |
+
"""Return the rank in the data parallelism group."""
|
142 |
+
ddp_group = get_ddp_group()
|
143 |
+
if ddp_group is None:
|
144 |
+
return 0
|
145 |
+
return torch.distributed.get_rank(ddp_group)
|
146 |
+
|
147 |
+
|
148 |
+
def apply_ddp_group(module):
|
149 |
+
"""Apply data parallelism group for given module."""
|
150 |
+
ddp_group = get_ddp_group()
|
151 |
+
if ddp_group is None:
|
152 |
+
return module
|
153 |
+
return torch.nn.parallel.DistributedDataParallel(module, process_group=ddp_group)
|
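A small sketch of the device/seed/parameter helpers exported above. The Linear module is a stand-in for illustration; load_weights would normally take the checkpoint shipped with this commit:

    import torch
    from tokenize_anything import engine

    device = engine.get_device(0)                      # prefers cuda, then mps, else cpu
    engine.manual_seed(1337, device_and_seed=(0, 1337))
    module = torch.nn.Linear(16, 16).to(device)
    print("params (M): %.6f" % engine.count_params(module))
    engine.freeze_module(module)                       # eval() + requires_grad=False on all parameters
    # engine.load_weights(module, "models/tap_vit_l_548184.pkl", strict=False)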
tokenize_anything/layers/__init__.py
ADDED
@@ -0,0 +1,26 @@
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
"""Layers."""
|
17 |
+
|
18 |
+
from tokenize_anything.layers.drop import DropPath
|
19 |
+
from tokenize_anything.layers.loss import BinaryDiceLoss
|
20 |
+
from tokenize_anything.layers.loss import BinaryFocalLoss
|
21 |
+
from tokenize_anything.layers.loss import CrossEntropyLoss
|
22 |
+
from tokenize_anything.layers.utils import init_cross_conv
|
23 |
+
from tokenize_anything.layers.utils import resize_pos_embed
|
24 |
+
from tokenize_anything.layers.utils import set_dropout
|
25 |
+
from tokenize_anything.layers.utils import set_drop_path
|
26 |
+
from tokenize_anything.layers.utils import set_sync_batch_norm
|
tokenize_anything/layers/drop.py
ADDED
@@ -0,0 +1,39 @@
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
"""Drop regularization layers."""
|
17 |
+
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
|
21 |
+
class DropPath(nn.Module):
|
22 |
+
"""Set examples to zero randomly."""
|
23 |
+
|
24 |
+
def __init__(self, p=0.1, inplace=False):
|
25 |
+
super(DropPath, self).__init__()
|
26 |
+
self.p = p
|
27 |
+
self.inplace = inplace
|
28 |
+
|
29 |
+
def forward(self, input):
|
30 |
+
if not self.training or self.p <= 0:
|
31 |
+
return input
|
32 |
+
keep_p = 1 - self.p
|
33 |
+
shape = (input.shape[0],) + (1,) * (input.dim() - 1)
|
34 |
+
scale = input.new_empty(shape).bernoulli_(keep_p).div_(keep_p)
|
35 |
+
return input.mul_(scale) if self.inplace else input.mul(scale)
|
36 |
+
|
37 |
+
def extra_repr(self):
|
38 |
+
inplace_str = ", inplace" if self.inplace else ""
|
39 |
+
return "p={}{}".format(self.p, inplace_str)
|
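A small illustration of the stochastic-depth layer above (shapes are illustrative):

    import torch
    from tokenize_anything.layers import DropPath

    drop = DropPath(p=0.1).train()      # only active in training mode
    x = torch.randn(4, 196, 768)        # (batch, tokens, dim)
    y = drop(x)                         # each sample is zeroed with prob 0.1, survivors rescaled by 1/0.9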
tokenize_anything/layers/loss.py
ADDED
@@ -0,0 +1,82 @@
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
"""Loss layers."""
|
17 |
+
|
18 |
+
from torch import nn
|
19 |
+
|
20 |
+
|
21 |
+
def reduce_loss(loss, reduction="mean"):
|
22 |
+
"""Reduce the loss."""
|
23 |
+
if reduction == "mean" or reduction == "sum":
|
24 |
+
return getattr(loss, reduction)()
|
25 |
+
if reduction == "batch_mean":
|
26 |
+
return loss.sum().mul_(1.0 / loss.size(0))
|
27 |
+
return loss
|
28 |
+
|
29 |
+
|
30 |
+
class BinaryFocalLoss(nn.Module):
|
31 |
+
"""Binary focal loss."""
|
32 |
+
|
33 |
+
def __init__(self, alpha=0.25, reduction="none"):
|
34 |
+
super(BinaryFocalLoss, self).__init__()
|
35 |
+
self.alpha = alpha
|
36 |
+
self.reduction = reduction
|
37 |
+
|
38 |
+
def forward(self, input, target):
|
39 |
+
alpha, p = self.alpha, input.sigmoid()
|
40 |
+
neg_alpha, neg_target = 1.0 - alpha, 1.0 - target
|
41 |
+
alpha_weight = target.mul(alpha).add_(neg_target.mul(neg_alpha))
|
42 |
+
focal_weight = (1.0 - p).mul_(target).add_(p.mul(neg_target)).square()
|
43 |
+
loss = nn.functional.binary_cross_entropy_with_logits(input, target, reduction="none")
|
44 |
+
return reduce_loss(loss * focal_weight.mul_(alpha_weight), self.reduction)
|
45 |
+
|
46 |
+
|
47 |
+
class BinaryDiceLoss(nn.Module):
|
48 |
+
"""Binary dice loss."""
|
49 |
+
|
50 |
+
def __init__(self, eps=1.0, reduction="none"):
|
51 |
+
super(BinaryDiceLoss, self).__init__()
|
52 |
+
self.eps = eps
|
53 |
+
self.reduction = reduction
|
54 |
+
|
55 |
+
def forward(self, input, target):
|
56 |
+
input = input.sigmoid()
|
57 |
+
num = input.mul(target).sum(-1).mul_(2).add_(self.eps)
|
58 |
+
den = input.add(target).sum(-1).add_(self.eps)
|
59 |
+
return reduce_loss(1.0 - num / den, self.reduction)
|
60 |
+
|
61 |
+
|
62 |
+
class CrossEntropyLoss(nn.Module):
|
63 |
+
"""Cross entropy loss with label smoothing."""
|
64 |
+
|
65 |
+
def __init__(self, epsilon=0, reduction="none"):
|
66 |
+
super(CrossEntropyLoss, self).__init__()
|
67 |
+
self.epsilon = epsilon
|
68 |
+
self.reduction = reduction
|
69 |
+
|
70 |
+
def forward_dense(self, input, target):
|
71 |
+
dim, target = input.shape[-1], target.squeeze_()
|
72 |
+
x = nn.functional.log_softmax(input, dim=-1)
|
73 |
+
y = nn.functional.one_hot(target, dim).float()
|
74 |
+
x = x.permute([0, x.dim() - 1] + list(range(x.dim()))[1:-1]) if x.dim() > 2 else x
|
75 |
+
y = y.permute([0, y.dim() - 1] + list(range(y.dim()))[1:-1]) if y.dim() > 2 else y
|
76 |
+
loss = nn.functional.cross_entropy(x, y, reduction="none", label_smoothing=self.epsilon)
|
77 |
+
return reduce_loss(loss, self.reduction)
|
78 |
+
|
79 |
+
def forward(self, input, target):
|
80 |
+
if self.epsilon > 0:
|
81 |
+
return self.forward_dense(input, target)
|
82 |
+
return nn.functional.cross_entropy(input, target, reduction=self.reduction)
|
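A minimal sketch combining the mask losses above; the tensors and the loss weighting are illustrative, not taken from this commit:

    import torch
    from tokenize_anything.layers import BinaryDiceLoss, BinaryFocalLoss

    mask_logits = torch.randn(2, 3, 4096)                        # (batch, masks, pixels)
    mask_target = torch.randint(0, 2, (2, 3, 4096)).float()
    focal = BinaryFocalLoss(alpha=0.25, reduction="mean")(mask_logits, mask_target)
    dice = BinaryDiceLoss(eps=1.0, reduction="mean")(mask_logits, mask_target)
    loss = focal * 20.0 + dice                                   # illustrative weighting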
tokenize_anything/layers/utils.py
ADDED
@@ -0,0 +1,64 @@
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
"""Layer utilities."""
|
17 |
+
|
18 |
+
import cv2
|
19 |
+
import numpy as np
|
20 |
+
import torch
|
21 |
+
|
22 |
+
|
23 |
+
def init_cross_conv(blocks):
|
24 |
+
"""Initialize convolutional cross attention."""
|
25 |
+
for m in blocks.modules():
|
26 |
+
if isinstance(m, torch.nn.Conv2d):
|
27 |
+
torch.nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
|
28 |
+
for blk in blocks:
|
29 |
+
torch.nn.init.constant_(blk.norm3.weight, 0)
|
30 |
+
|
31 |
+
|
32 |
+
def set_dropout(module, dropout):
|
33 |
+
"""Initialize dropout."""
|
34 |
+
for m in [m for m in module.modules() if isinstance(m, torch.nn.Dropout)]:
|
35 |
+
m.p = dropout
|
36 |
+
|
37 |
+
|
38 |
+
def set_drop_path(blocks, drop_path):
|
39 |
+
"""Initialize drop path."""
|
40 |
+
if not isinstance(blocks, torch.nn.ModuleList):
|
41 |
+
blocks = getattr(blocks, "blocks", getattr(blocks, "layers", None))
|
42 |
+
for i, blk in enumerate(blocks):
|
43 |
+
for m in [m for m in blk.modules() if type(m).__name__ == "DropPath"]:
|
44 |
+
m.p = i * drop_path / (len(blocks) - 1)
|
45 |
+
|
46 |
+
|
47 |
+
def set_sync_batch_norm(module, ddp_group):
|
48 |
+
"""Set data parallelism group for sync batch norm."""
|
49 |
+
for m in module.modules():
|
50 |
+
if isinstance(m, torch.nn.SyncBatchNorm):
|
51 |
+
m.process_group = ddp_group
|
52 |
+
|
53 |
+
|
54 |
+
def resize_pos_embed(weight, out_len):
|
55 |
+
"""Resize position embedding weights."""
|
56 |
+
out_h = out_w = int(out_len**0.5)
|
57 |
+
h = w = int(weight.shape[0] ** 0.5)
|
58 |
+
weight = weight.reshape((h, w, weight.shape[1]))
|
59 |
+
out_weight = [
|
60 |
+
cv2.resize(x, (out_w, out_h), interpolation=cv2.INTER_CUBIC)
|
61 |
+
for x in np.split(weight.astype("float32", copy=False), 4, axis=-1)
|
62 |
+
]
|
63 |
+
out_weight = np.concatenate(out_weight, axis=-1)
|
64 |
+
return out_weight.reshape((-1, weight.shape[-1])).astype(weight.dtype, copy=False)
|
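A short sketch of the drop-path setter above, applied to a toy block list for illustration:

    import torch
    from tokenize_anything import layers

    blocks = torch.nn.ModuleList(layers.DropPath(p=0.1) for _ in range(4))   # toy blocks
    layers.set_drop_path(blocks, drop_path=0.3)
    print([round(m.p, 3) for m in blocks])    # rates scale linearly per block: [0.0, 0.1, 0.2, 0.3]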
tokenize_anything/modeling/concept_projector.py
CHANGED
@@ -51,11 +51,11 @@ class ConceptProjector(nn.Module):
|
|
51 |
proj = proj.to(device=embeds.device)
|
52 |
return embeds, proj
|
53 |
|
54 |
-
def encode_src(self, src_embeds):
|
55 |
"""Encode source visual embedding via concept projection."""
|
56 |
src_embeds, self.src_weights = self.maybe_convert(src_embeds, self.src_weights)
|
57 |
logits = nn.functional.normalize(src_embeds, dim=-1) @ self.src_weights
|
58 |
-
return nn.functional.log_softmax(logits, dim=-1)
|
59 |
|
60 |
def encode_tgt(self, tgt_embeds):
|
61 |
"""Encode target visual embedding via concept projection."""
|
|
|
51 |
proj = proj.to(device=embeds.device)
|
52 |
return embeds, proj
|
53 |
|
54 |
+
def encode_src(self, src_embeds, logpi=True):
|
55 |
"""Encode source visual embedding via concept projection."""
|
56 |
src_embeds, self.src_weights = self.maybe_convert(src_embeds, self.src_weights)
|
57 |
logits = nn.functional.normalize(src_embeds, dim=-1) @ self.src_weights
|
58 |
+
return nn.functional.log_softmax(logits, dim=-1) if logpi else logits
|
59 |
|
60 |
def encode_tgt(self, tgt_embeds):
|
61 |
"""Encode target visual embedding via concept projection."""
|
tokenize_anything/modeling/image_decoder.py
CHANGED
@@ -76,7 +76,6 @@ class Block(nn.Module):
|
|
76 |
num_heads=8,
|
77 |
attn_ratio=0.5,
|
78 |
mlp_dim=2048,
|
79 |
-
dropout=0.1,
|
80 |
activation_type="ReLU",
|
81 |
skip_first_query_pos=False,
|
82 |
):
|
@@ -89,7 +88,7 @@ class Block(nn.Module):
|
|
89 |
self.norm3 = nn.LayerNorm(dim)
|
90 |
self.cross_attn_image_to_token = Attention(dim, num_heads, attn_ratio)
|
91 |
self.norm4 = nn.LayerNorm(dim)
|
92 |
-
self.dropout = nn.Dropout(
|
93 |
self.skip_first_query_pos = skip_first_query_pos
|
94 |
|
95 |
def forward(self, query, key, query_pos, key_pos):
|
@@ -115,7 +114,6 @@ class Transformer(nn.Module):
|
|
115 |
num_heads=8,
|
116 |
attn_ratio=0.5,
|
117 |
mlp_dim=2048,
|
118 |
-
dropout=0.1,
|
119 |
activation_type="ReLU",
|
120 |
depth=2,
|
121 |
):
|
@@ -126,7 +124,6 @@ class Transformer(nn.Module):
|
|
126 |
num_heads,
|
127 |
attn_ratio=attn_ratio,
|
128 |
mlp_dim=mlp_dim,
|
129 |
-
dropout=dropout,
|
130 |
activation_type=activation_type,
|
131 |
skip_first_query_pos=i == 0,
|
132 |
)
|
@@ -134,7 +131,7 @@ class Transformer(nn.Module):
|
|
134 |
)
|
135 |
self.final_attn_token_to_image = Attention(embed_dim, num_heads, attn_ratio)
|
136 |
self.norm = nn.LayerNorm(embed_dim)
|
137 |
-
self.dropout = nn.Dropout(
|
138 |
|
139 |
def forward(self, query, key, query_pos, key_pos):
|
140 |
for blk in self.blocks:
|
@@ -202,7 +199,7 @@ class ImageDecoder(nn.Module):
|
|
202 |
query, key = self.transformer(query, key, query, inputs["img_pos"])
|
203 |
# Upscale key.
|
204 |
key = key.transpose(1, 2).view((-1, self.embed_dim) + img_embed_size)
|
205 |
-
|
206 |
# Unpack query.
|
207 |
tokens = query[:, :num_tokens].unbind(dim=1)
|
208 |
iou_tokens = tokens[num_tokens - self.num_mask_tokens - 1]
|
@@ -210,7 +207,7 @@ class ImageDecoder(nn.Module):
|
|
210 |
sem_tokens = tokens[: self.num_mask_tokens]
|
211 |
# Predict.
|
212 |
mask_pred = [f(x) for f, x in zip(self.mask_pred, mask_tokens)]
|
213 |
-
mask_pred = torch.stack(mask_pred, dim=1) @
|
214 |
mask_pred_size = list(4 * embed_size for embed_size in img_embed_size)
|
215 |
mask_pred = mask_pred.view([-1, self.num_mask_tokens] + mask_pred_size)
|
216 |
outputs = {"iou_pred": self.iou_pred(iou_tokens), "mask_pred": mask_pred}
|
|
|
76 |
num_heads=8,
|
77 |
attn_ratio=0.5,
|
78 |
mlp_dim=2048,
|
|
|
79 |
activation_type="ReLU",
|
80 |
skip_first_query_pos=False,
|
81 |
):
|
|
|
88 |
self.norm3 = nn.LayerNorm(dim)
|
89 |
self.cross_attn_image_to_token = Attention(dim, num_heads, attn_ratio)
|
90 |
self.norm4 = nn.LayerNorm(dim)
|
91 |
+
self.dropout = nn.Dropout(0.1, inplace=True)
|
92 |
self.skip_first_query_pos = skip_first_query_pos
|
93 |
|
94 |
def forward(self, query, key, query_pos, key_pos):
|
|
|
114 |
num_heads=8,
|
115 |
attn_ratio=0.5,
|
116 |
mlp_dim=2048,
|
|
|
117 |
activation_type="ReLU",
|
118 |
depth=2,
|
119 |
):
|
|
|
124 |
num_heads,
|
125 |
attn_ratio=attn_ratio,
|
126 |
mlp_dim=mlp_dim,
|
|
|
127 |
activation_type=activation_type,
|
128 |
skip_first_query_pos=i == 0,
|
129 |
)
|
|
|
131 |
)
|
132 |
self.final_attn_token_to_image = Attention(embed_dim, num_heads, attn_ratio)
|
133 |
self.norm = nn.LayerNorm(embed_dim)
|
134 |
+
self.dropout = nn.Dropout(0.1, inplace=True)
|
135 |
|
136 |
def forward(self, query, key, query_pos, key_pos):
|
137 |
for blk in self.blocks:
|
|
|
199 |
query, key = self.transformer(query, key, query, inputs["img_pos"])
|
200 |
# Upscale key.
|
201 |
key = key.transpose(1, 2).view((-1, self.embed_dim) + img_embed_size)
|
202 |
+
mask_embeds = self.output_conv(key).flatten(2)
|
203 |
# Unpack query.
|
204 |
tokens = query[:, :num_tokens].unbind(dim=1)
|
205 |
iou_tokens = tokens[num_tokens - self.num_mask_tokens - 1]
|
|
|
207 |
sem_tokens = tokens[: self.num_mask_tokens]
|
208 |
# Predict.
|
209 |
mask_pred = [f(x) for f, x in zip(self.mask_pred, mask_tokens)]
|
210 |
+
mask_pred = torch.stack(mask_pred, dim=1) @ mask_embeds
|
211 |
mask_pred_size = list(4 * embed_size for embed_size in img_embed_size)
|
212 |
mask_pred = mask_pred.view([-1, self.num_mask_tokens] + mask_pred_size)
|
213 |
outputs = {"iou_pred": self.iou_pred(iou_tokens), "mask_pred": mask_pred}
|
tokenize_anything/modeling/image_encoder.py
CHANGED
@@ -17,6 +17,8 @@
|
|
17 |
import torch
|
18 |
from torch import nn
|
19 |
|
|
|
|
|
20 |
|
21 |
def space_to_depth(input, block_size):
|
22 |
"""Rearrange blocks of spatial data into depth."""
|
@@ -84,10 +86,11 @@ class Block(nn.Module):
|
|
84 |
self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
|
85 |
self.norm2 = nn.LayerNorm(dim)
|
86 |
self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
|
|
|
87 |
|
88 |
def forward(self, x):
|
89 |
-
x = self.attn(self.norm1(x)).add_(x)
|
90 |
-
return self.mlp(self.norm2(x)).add_(x)
|
91 |
|
92 |
|
93 |
class Bottleneck(nn.Module):
|
@@ -245,7 +248,7 @@ class ImageEncoderViT(nn.Module):
|
|
245 |
if i in self.cross_indices or i == len(self.blocks) - 1:
|
246 |
x = self.norm(x) if i == len(self.blocks) - 1 else x
|
247 |
x = depth_to_space(x.reshape(wmsa_shape), self.window_size)
|
248 |
-
x = x.permute(0, 3, 1, 2)
|
249 |
if i in self.cross_indices:
|
250 |
x = self.cross_conv[self.cross_indices.index(i)](x)
|
251 |
if i in self.cross_indices and i < len(self.blocks) - 1:
|
|
|
17 |
import torch
|
18 |
from torch import nn
|
19 |
|
20 |
+
from tokenize_anything import layers
|
21 |
+
|
22 |
|
23 |
def space_to_depth(input, block_size):
|
24 |
"""Rearrange blocks of spatial data into depth."""
|
|
|
86 |
self.attn = Attention(dim, num_heads, qkv_bias=qkv_bias)
|
87 |
self.norm2 = nn.LayerNorm(dim)
|
88 |
self.mlp = MLP(dim, mlp_ratio=mlp_ratio)
|
89 |
+
self.drop_path = layers.DropPath(0.1, inplace=True)
|
90 |
|
91 |
def forward(self, x):
|
92 |
+
x = self.drop_path(self.attn(self.norm1(x))).add_(x)
|
93 |
+
return self.drop_path(self.mlp(self.norm2(x))).add_(x)
|
94 |
|
95 |
|
96 |
class Bottleneck(nn.Module):
|
|
|
248 |
if i in self.cross_indices or i == len(self.blocks) - 1:
|
249 |
x = self.norm(x) if i == len(self.blocks) - 1 else x
|
250 |
x = depth_to_space(x.reshape(wmsa_shape), self.window_size)
|
251 |
+
x = x.permute(0, 3, 1, 2).contiguous()
|
252 |
if i in self.cross_indices:
|
253 |
x = self.cross_conv[self.cross_indices.index(i)](x)
|
254 |
if i in self.cross_indices and i < len(self.blocks) - 1:
|
tokenize_anything/modeling/image_tokenizer.py
CHANGED
@@ -45,13 +45,15 @@ class ImageTokenizer(nn.Module):
|
|
45 |
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean))
|
46 |
self.register_buffer("pixel_rsig", torch.Tensor(pixel_std).reciprocal_())
|
47 |
|
48 |
-
def get_inputs(self, inputs):
|
49 |
"""Return the model inputs.
|
50 |
|
51 |
Parameters
|
52 |
----------
|
53 |
inputs : dict
|
54 |
The initial inputs.
|
|
|
|
|
55 |
|
56 |
Returns
|
57 |
-------
|
@@ -59,13 +61,10 @@ class ImageTokenizer(nn.Module):
|
|
59 |
The model inputs.
|
60 |
|
61 |
"""
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
inputs["img"] = inputs["img"].to(dtype=self.pixel_mean.dtype)
|
67 |
-
inputs["img"] = inputs["img"].sub(self.pixel_mean).mul_(self.pixel_rsig)
|
68 |
-
inputs["img"] = inputs["img"].permute(0, 3, 1, 2)
|
69 |
return inputs
|
70 |
|
71 |
def get_features(self, inputs):
|
@@ -177,7 +176,7 @@ class ImageTokenizer(nn.Module):
|
|
177 |
An array of generated texts.
|
178 |
|
179 |
"""
|
180 |
-
max_gen_len = max_gen_len or self.text_decoder.
|
181 |
prompts = self.text_decoder.get_prompts(visual_tokens)
|
182 |
out_shape = (prompts.size(0), self.text_decoder.max_text_len)
|
183 |
tokens = np.full(out_shape, self.text_tokenizer.pad_id, "int64")
|
|
|
45 |
self.register_buffer("pixel_mean", torch.Tensor(pixel_mean))
|
46 |
self.register_buffer("pixel_rsig", torch.Tensor(pixel_std).reciprocal_())
|
47 |
|
48 |
+
def get_inputs(self, inputs, dtype=None):
|
49 |
"""Return the model inputs.
|
50 |
|
51 |
Parameters
|
52 |
----------
|
53 |
inputs : dict
|
54 |
The initial inputs.
|
55 |
+
dtype : torch.dtype, optional
|
56 |
+
The optional input dtype.
|
57 |
|
58 |
Returns
|
59 |
-------
|
|
|
61 |
The model inputs.
|
62 |
|
63 |
"""
|
64 |
+
img_dtype, img_device = self.pixel_mean.dtype, self.pixel_mean.device
|
65 |
+
inputs["img"] = torch.as_tensor(inputs["img"], dtype=img_dtype, device=img_device)
|
66 |
+
inputs["img"] = inputs["img"].sub(self.pixel_mean).mul_(self.pixel_rsig).permute(0, 3, 1, 2)
|
67 |
+
inputs["img"] = inputs["img"].to(dtype=dtype) if dtype else inputs["img"]
|
|
|
|
|
|
|
68 |
return inputs
|
69 |
|
70 |
def get_features(self, inputs):
|
|
|
176 |
An array of generated texts.
|
177 |
|
178 |
"""
|
179 |
+
max_gen_len = max_gen_len or self.text_decoder.max_text_len
|
180 |
prompts = self.text_decoder.get_prompts(visual_tokens)
|
181 |
out_shape = (prompts.size(0), self.text_decoder.max_text_len)
|
182 |
tokens = np.full(out_shape, self.text_tokenizer.pad_id, "int64")
|
tokenize_anything/modeling/text_decoder.py
CHANGED
@@ -79,6 +79,7 @@ class TransformerCache(nn.Module):
|
|
79 |
cache_v = self.cache_dict.get(f"{id(mixer)}_v", None)
|
80 |
flash_args = {"softmax_scale": mixer.scale, "causal": True}
|
81 |
if cache_k is None or cache_v is None:
|
|
|
82 |
return flash_attn_func(q, k, v, **flash_args)
|
83 |
flash_args["cache_seqlens"] = self.cache_dict["seq_lens"][: q.shape[0]]
|
84 |
return flash_attn_with_kvcache(q, cache_k, cache_v, k, v, **flash_args)
|
@@ -94,6 +95,7 @@ class Attention(nn.Module):
|
|
94 |
self.head_dim = dim // num_heads
|
95 |
self.num_heads = num_heads
|
96 |
self.scale = self.head_dim**-0.5
|
|
|
97 |
self.cache = nn.Module()
|
98 |
|
99 |
def forward(self, x):
|
@@ -126,10 +128,11 @@ class Block(nn.Module):
|
|
126 |
self.mlp = MLP(dim, mlp_dim, bias=bias)
|
127 |
self.norm1 = nn.LayerNorm(dim)
|
128 |
self.norm2 = nn.LayerNorm(dim)
|
|
|
129 |
|
130 |
def forward(self, x):
|
131 |
-
x = self.attn(self.norm1(x)).add_(x)
|
132 |
-
return self.mlp(self.norm2(x)).add_(x)
|
133 |
|
134 |
|
135 |
class Transformer(nn.Module):
|
|
|
79 |
cache_v = self.cache_dict.get(f"{id(mixer)}_v", None)
|
80 |
flash_args = {"softmax_scale": mixer.scale, "causal": True}
|
81 |
if cache_k is None or cache_v is None:
|
82 |
+
flash_args["dropout_p"] = mixer.dropout.p if mixer.training else 0
|
83 |
return flash_attn_func(q, k, v, **flash_args)
|
84 |
flash_args["cache_seqlens"] = self.cache_dict["seq_lens"][: q.shape[0]]
|
85 |
return flash_attn_with_kvcache(q, cache_k, cache_v, k, v, **flash_args)
|
|
|
95 |
self.head_dim = dim // num_heads
|
96 |
self.num_heads = num_heads
|
97 |
self.scale = self.head_dim**-0.5
|
98 |
+
self.dropout = nn.Dropout(0.1, inplace=False)
|
99 |
self.cache = nn.Module()
|
100 |
|
101 |
def forward(self, x):
|
|
|
128 |
self.mlp = MLP(dim, mlp_dim, bias=bias)
|
129 |
self.norm1 = nn.LayerNorm(dim)
|
130 |
self.norm2 = nn.LayerNorm(dim)
|
131 |
+
self.dropout = nn.Dropout(0.1, inplace=True)
|
132 |
|
133 |
def forward(self, x):
|
134 |
+
x = self.dropout(self.attn(self.norm1(x))).add_(x)
|
135 |
+
return self.dropout(self.mlp(self.norm2(x))).add_(x)
|
136 |
|
137 |
|
138 |
class Transformer(nn.Module):
|
tokenize_anything/models/__init__.py
ADDED
@@ -0,0 +1,18 @@
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
"""Models."""
|
17 |
+
|
18 |
+
from tokenize_anything.models.easy_build import model_registry
|
tokenize_anything/{build_model.py → models/easy_build.py}
RENAMED
@@ -13,7 +13,7 @@
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
# ------------------------------------------------------------------------
|
16 |
-
"""
|
17 |
|
18 |
from functools import partial
|
19 |
import pickle
|
@@ -40,7 +40,7 @@ def get_device(device_index):
|
|
40 |
def load_weights(module, weights_file, strict=True):
|
41 |
"""Load a weights file."""
|
42 |
if not weights_file:
|
43 |
-
return
|
44 |
if weights_file.endswith(".pkl"):
|
45 |
with open(weights_file, "rb") as f:
|
46 |
state_dict = pickle.load(f)
|
@@ -48,7 +48,7 @@ def load_weights(module, weights_file, strict=True):
|
|
48 |
state_dict[k] = torch.from_numpy(v) if isinstance(v, np.ndarray) else v
|
49 |
else:
|
50 |
state_dict = torch.load(weights_file)
|
51 |
-
|
52 |
|
53 |
|
54 |
def vit_encoder(depth, embed_dim, num_heads, out_dim, image_size):
|
|
|
13 |
# See the License for the specific language governing permissions and
|
14 |
# limitations under the License.
|
15 |
# ------------------------------------------------------------------------
|
16 |
+
"""Easy model builder."""
|
17 |
|
18 |
from functools import partial
|
19 |
import pickle
|
|
|
40 |
def load_weights(module, weights_file, strict=True):
|
41 |
"""Load a weights file."""
|
42 |
if not weights_file:
|
43 |
+
return
|
44 |
if weights_file.endswith(".pkl"):
|
45 |
with open(weights_file, "rb") as f:
|
46 |
state_dict = pickle.load(f)
|
|
|
48 |
state_dict[k] = torch.from_numpy(v) if isinstance(v, np.ndarray) else v
|
49 |
else:
|
50 |
state_dict = torch.load(weights_file)
|
51 |
+
module.load_state_dict(state_dict, strict=strict)
|
52 |
|
53 |
|
54 |
def vit_encoder(depth, embed_dim, num_heads, out_dim, image_size):
|
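With the rename above, the models package becomes the build entry point. A build-and-preprocess sketch; the checkpoint keyword follows the repository README rather than this diff, and the input array is illustrative:

    import numpy as np
    import torch
    from tokenize_anything import model_registry

    model = model_registry["tap_vit_l"](checkpoint="models/tap_vit_l_548184.pkl")
    img = np.zeros((1, 1024, 1024, 3), dtype="uint8")                 # NHWC batch
    inputs = model.get_inputs({"img": img}, dtype=torch.float16)      # new optional dtype argument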
tokenize_anything/prompters/__init__.py
ADDED
@@ -0,0 +1,18 @@
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
"""Prompters."""
|
17 |
+
|
18 |
+
from tokenize_anything.prompters.visual_prompter import VisualPrompter
|
tokenize_anything/prompters/visual_prompter.py
ADDED
@@ -0,0 +1,106 @@
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
"""Generate visual prompts."""
|
17 |
+
|
18 |
+
import collections
|
19 |
+
|
20 |
+
import numpy as np
|
21 |
+
import numpy.random as npr
|
22 |
+
|
23 |
+
|
24 |
+
class VisualPrompter(object):
|
25 |
+
"""Generate visual prompts."""
|
26 |
+
|
27 |
+
def __init__(self, image_size=1024, max_points=9, num_experts=4, padding_index=4):
|
28 |
+
super(VisualPrompter, self).__init__()
|
29 |
+
self.num_stages = 2
|
30 |
+
self.max_points = max_points
|
31 |
+
self.point_weight = [1000] + [0] * (num_experts - 1)
|
32 |
+
self.image_size = image_size if isinstance(image_size, (tuple, list)) else [image_size] * 2
|
33 |
+
self.padding_index = padding_index
|
34 |
+
self.coord_count = collections.defaultdict(int)
|
35 |
+
self.coords = self.labels = self.boxes_turn = None
|
36 |
+
self.stage_count = 0
|
37 |
+
self.box_prob = 0.5
|
38 |
+
|
39 |
+
@property
|
40 |
+
def is_last_stage(self):
|
41 |
+
return self.stage_count == self.num_stages - 1
|
42 |
+
|
43 |
+
def add_point(self, index, gt_masks, error_masks=None, num=1):
|
44 |
+
def sample(mask):
|
45 |
+
ys, xs = np.nonzero(mask)
|
46 |
+
if ys.shape[0] > 0:
|
47 |
+
idx = npr.choice(ys.shape[0], size=(num,), replace=num > ys.shape[0])
|
48 |
+
return xs[idx], ys[idx]
|
49 |
+
return [-0.5] * num, [-0.5] * num
|
50 |
+
|
51 |
+
labels = [self.padding_index] * num
|
52 |
+
if error_masks is not None: # FP or FN point.
|
53 |
+
xs, ys = sample(error_masks[index])
|
54 |
+
labels = gt_masks[index, ys, xs] if ys[0] >= 0 else labels
|
55 |
+
if labels[0] == self.padding_index: # GT point.
|
56 |
+
xs, ys = sample(gt_masks[index])
|
57 |
+
labels = [1] * num if ys[0] >= 0 else labels
|
58 |
+
xs = (np.array(xs, "float32") + 0.5) * (self.image_size[1] / gt_masks.shape[2]) - 0.5
|
59 |
+
ys = (np.array(ys, "float32") + 0.5) * (self.image_size[0] / gt_masks.shape[1]) - 0.5
|
60 |
+
slice_index = slice(self.coord_count[index], self.coord_count[index] + num)
|
61 |
+
self.coords[index, slice_index] = np.vstack([xs, ys]).T
|
62 |
+
self.labels[index, slice_index] = labels
|
63 |
+
self.coord_count[index] += num
|
64 |
+
|
65 |
+
def add_box(self, index, gt_boxes):
|
66 |
+
x1, y1, x2, y2 = gt_boxes[index, :4]
|
67 |
+
dx1, dx2 = np.clip(npr.normal(0.0, 0.1 * (x2 - x1), (2,)), -20, 20)
|
68 |
+
dy1, dy2 = np.clip(npr.normal(0.0, 0.1 * (y2 - y1), (2,)), -20, 20)
|
69 |
+
x1, y1 = x1 + np.minimum(dx1, 0), y1 + np.minimum(dy1, 0)
|
70 |
+
x2, y2 = x2 + np.maximum(dx2, 0), y2 + np.maximum(dy2, 0)
|
71 |
+
self.coords[index, self.coord_count[index]] = (x1, y1)
|
72 |
+
self.coords[index, self.coord_count[index] + 1] = (x2, y2)
|
73 |
+
self.labels[index, self.coord_count[index]] = 2
|
74 |
+
self.labels[index, self.coord_count[index] + 1] = 3
|
75 |
+
self.coord_count[index] += 2
|
76 |
+
|
77 |
+
def reset(self, num):
|
78 |
+
self.stage_count = 0
|
79 |
+
self.coord_count.clear()
|
80 |
+
self.coords = np.full((num, self.max_points + 1, 2), -0.5, "float32")
|
81 |
+
self.labels = np.full((num, self.max_points + 1), self.padding_index, "int64")
|
82 |
+
self.boxes_turn = npr.rand(num) < self.box_prob
|
83 |
+
|
84 |
+
def get_prompts(self, gt_boxes, gt_masks=None, masks=None):
|
85 |
+
num = gt_boxes.shape[0]
|
86 |
+
if self.stage_count == 0:
|
87 |
+
self.reset(num)
|
88 |
+
coords = labels = error_masks = None
|
89 |
+
if masks is not None:
|
90 |
+
masks = masks.reshape(gt_masks.shape)
|
91 |
+
error_masks = (masks | gt_masks) ^ (masks & gt_masks)
|
92 |
+
num_points = 1
|
93 |
+
if self.stage_count > 0:
|
94 |
+
num_points = npr.randint(1, self.max_points + 1 - self.stage_count)
|
95 |
+
if self.stage_count == 0 and self.box_prob == 0:
|
96 |
+
num_points = npr.randint(2, self.max_points + 1)
|
97 |
+
for index in range(num):
|
98 |
+
is_box = self.stage_count == 0 and self.boxes_turn[index]
|
99 |
+
if gt_masks is None or is_box:
|
100 |
+
self.add_box(index, gt_boxes)
|
101 |
+
else:
|
102 |
+
self.add_point(index, gt_masks, error_masks, num_points)
|
103 |
+
coords = self.coords[:, : 1 + self.stage_count + num_points]
|
104 |
+
labels = self.labels[:, : 1 + self.stage_count + num_points]
|
105 |
+
scores = (self.boxes_turn[:, None] - 0.5) * self.point_weight
|
106 |
+
return {"points": (coords, labels), "point_score": scores}
|
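A self-contained sketch of the prompter above, sampling one training prompt for a single ground-truth box/mask pair (all values illustrative):

    import numpy as np
    from tokenize_anything.prompters import VisualPrompter

    prompter = VisualPrompter(image_size=1024, max_points=9)
    gt_boxes = np.array([[100.0, 120.0, 400.0, 360.0]], dtype="float32")
    gt_masks = np.zeros((1, 256, 256), dtype="uint8")
    gt_masks[0, 30:90, 25:100] = 1
    prompts = prompter.get_prompts(gt_boxes, gt_masks)
    coords, labels = prompts["points"]      # padded point/box coordinates and their type labels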
tokenize_anything/utils/logging.py
ADDED
@@ -0,0 +1,129 @@
1 |
+
# ------------------------------------------------------------------------
|
2 |
+
# Copyright (c) 2023-present, BAAI. All Rights Reserved.
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
# ------------------------------------------------------------------------
|
16 |
+
"""Logging utilities."""
|
17 |
+
|
18 |
+
import inspect
|
19 |
+
import logging as _logging
|
20 |
+
import os
|
21 |
+
import sys as _sys
|
22 |
+
import threading
|
23 |
+
|
24 |
+
|
25 |
+
_logger = None
|
26 |
+
_logger_lock = threading.Lock()
|
27 |
+
|
28 |
+
|
29 |
+
def get_logger():
|
30 |
+
global _logger
|
31 |
+
# Use double-checked locking to avoid taking lock unnecessarily.
|
32 |
+
if _logger:
|
33 |
+
return _logger
|
34 |
+
_logger_lock.acquire()
|
35 |
+
try:
|
36 |
+
if _logger:
|
37 |
+
return _logger
|
38 |
+
logger = _logging.getLogger("tokenize-anything")
|
39 |
+
logger.setLevel("INFO")
|
40 |
+
logger.propagate = False
|
41 |
+
logger._is_root = True
|
42 |
+
if True:
|
43 |
+
# Determine whether we are in an interactive environment.
|
44 |
+
_interactive = False
|
45 |
+
try:
|
46 |
+
# This is only defined in interactive shells.
|
47 |
+
if _sys.ps1:
|
48 |
+
_interactive = True
|
49 |
+
except AttributeError:
|
50 |
+
# Even now, we may be in an interactive shell with `python -i`.
|
51 |
+
_interactive = _sys.flags.interactive
|
52 |
+
# If we are in an interactive environment (like Jupyter), set loglevel
|
53 |
+
# to INFO and pipe the output to stdout.
|
54 |
+
if _interactive:
|
55 |
+
logger.setLevel("INFO")
|
56 |
+
_logging_target = _sys.stdout
|
57 |
+
else:
|
58 |
+
_logging_target = _sys.stderr
|
59 |
+
# Add the output handler.
|
60 |
+
_handler = _logging.StreamHandler(_logging_target)
|
61 |
+
_handler.setFormatter(_logging.Formatter("%(levelname)s %(message)s"))
|
62 |
+
logger.addHandler(_handler)
|
63 |
+
_logger = logger
|
64 |
+
return _logger
|
65 |
+
finally:
|
66 |
+
_logger_lock.release()
|
67 |
+
|
68 |
+
|
69 |
+
def _detailed_msg(msg):
|
70 |
+
file, lineno = inspect.stack()[:3][2][1:3]
|
71 |
+
return "{}:{}] {}".format(os.path.split(file)[-1], lineno, msg)
|
72 |
+
|
73 |
+
|
74 |
+
def log(level, msg, *args, **kwargs):
|
75 |
+
get_logger().log(level, _detailed_msg(msg), *args, **kwargs)
|
76 |
+
|
77 |
+
|
78 |
+
def debug(msg, *args, **kwargs):
|
79 |
+
if is_root():
|
80 |
+
get_logger().debug(_detailed_msg(msg), *args, **kwargs)
|
81 |
+
|
82 |
+
|
83 |
+
def error(msg, *args, **kwargs):
|
84 |
+
get_logger().error(_detailed_msg(msg), *args, **kwargs)
|
85 |
+
assert 0
|
86 |
+
|
87 |
+
|
88 |
+
def fatal(msg, *args, **kwargs):
|
89 |
+
get_logger().fatal(_detailed_msg(msg), *args, **kwargs)
|
90 |
+
assert 0
|
91 |
+
|
92 |
+
|
93 |
+
def info(msg, *args, **kwargs):
|
94 |
+
if is_root():
|
95 |
+
get_logger().info(_detailed_msg(msg), *args, **kwargs)
|
96 |
+
|
97 |
+
|
98 |
+
def warning(msg, *args, **kwargs):
|
99 |
+
if is_root():
|
100 |
+
get_logger().warning(_detailed_msg(msg), *args, **kwargs)
|
101 |
+
|
102 |
+
|
103 |
+
def get_verbosity():
|
104 |
+
"""Return how much logging output will be produced."""
|
105 |
+
return get_logger().getEffectiveLevel()
|
106 |
+
|
107 |
+
|
108 |
+
def set_verbosity(v):
|
109 |
+
"""Set the threshold for what messages will be logged."""
|
110 |
+
get_logger().setLevel(v)
|
111 |
+
|
112 |
+
|
113 |
+
def set_formatter(fmt=None, datefmt=None):
|
114 |
+
"""Set the formatter."""
|
115 |
+
handler = _logging.StreamHandler(_sys.stderr)
|
116 |
+
handler.setFormatter(_logging.Formatter(fmt, datefmt))
|
117 |
+
logger = get_logger()
|
118 |
+
logger.removeHandler(logger.handlers[0])
|
119 |
+
logger.addHandler(handler)
|
120 |
+
|
121 |
+
|
122 |
+
def set_root(is_root=True):
|
123 |
+
"""Set logger to the root."""
|
124 |
+
get_logger()._is_root = is_root
|
125 |
+
|
126 |
+
|
127 |
+
def is_root():
|
128 |
+
"""Return logger is the root."""
|
129 |
+
return get_logger()._is_root
|
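A short sketch of the rank-aware logger above; set_root(False) is what create_ddp_group (engine/utils.py) applies on non-root processes:

    from tokenize_anything.utils import logging

    logging.set_formatter("%(asctime)s %(levelname)s %(message)s", "%H:%M:%S")
    logging.info("building model ...")       # printed: the logger is root by default
    logging.set_root(False)
    logging.info("suppressed on non-root ranks")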
tokenize_anything/utils/profiler/__init__.py
ADDED
@@ -0,0 +1,20 @@
+# ------------------------------------------------------------------------
+# Copyright (c) 2023-present, BAAI. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ------------------------------------------------------------------------
+"""Profiler utilities."""
+
+from tokenize_anything.utils.profiler.stats import SmoothedValue
+from tokenize_anything.utils.profiler.timer import Timer
+from tokenize_anything.utils.profiler.timer import get_progress
tokenize_anything/utils/profiler/stats.py
ADDED
@@ -0,0 +1,42 @@
+# ------------------------------------------------------------------------
+# Copyright (c) 2023-present, BAAI. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ------------------------------------------------------------------------
+"""Trackable statistics."""
+
+import collections
+import numpy as np
+
+
+class SmoothedValue(object):
+    """Track values and provide a smoothed report."""
+
+    def __init__(self, window_size=None):
+        self.deque = collections.deque(maxlen=window_size)
+        self.total = 0.0
+        self.count = 0
+
+    def update(self, value):
+        self.deque.append(value)
+        self.count += 1
+        self.total += value
+
+    def mean(self):
+        return np.mean(self.deque)
+
+    def median(self):
+        return np.median(self.deque)
+
+    def average(self):
+        return self.total / self.count
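A small usage sketch of the tracker above (not part of the commit; the values are illustrative). It is importable from the new profiler package per the __init__.py added earlier:

from tokenize_anything.utils.profiler import SmoothedValue

loss = SmoothedValue(window_size=20)      # smooth over the last 20 updates
for value in (0.9, 0.7, 0.8, 0.6):        # e.g. per-iteration loss values
    loss.update(value)
print(loss.mean(), loss.median())         # statistics over the current window
print(loss.average())                     # running mean over all updates
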
tokenize_anything/utils/{timer.py → profiler/timer.py}
RENAMED
@@ -9,13 +9,14 @@
 #
 # Unless required by applicable law or agreed to in writing, software
 # distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ------------------------------------------------------------------------
 """Timing functions."""
 
 import contextlib
+import datetime
 import time
 
 
@@ -49,3 +50,13 @@ class Timer(object):
     def toc(self, n=1, average=True):
         self.diff = time.time() - self.start_time
         return self.add_diff(self.diff, n, average)
+
+
+def get_progress(timer, step, max_steps):
+    """Return the progress information."""
+    eta_seconds = timer.average_time * (max_steps - step)
+    eta = str(datetime.timedelta(seconds=int(eta_seconds)))
+    progress = (step + 1.0) / max_steps
+    return "< PROGRESS: {:.2%} | SPEED: {:.3f}s / iter | ETA: {} >".format(
+        progress, timer.average_time, eta
+    )
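A minimal usage sketch of the new get_progress helper (not part of the commit). The tic() method and the average_time attribute come from the unchanged part of timer.py and are assumed here; the sleep call stands in for a training iteration:

import time
from tokenize_anything.utils.profiler import Timer, get_progress

max_steps, log_every = 100, 10
timer = Timer()
for step in range(max_steps):
    timer.tic()
    time.sleep(0.01)                      # stand-in for one training iteration
    timer.toc()                           # records diff and updates average_time
    if (step + 1) % log_every == 0:
        print(get_progress(timer, step, max_steps))
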
tokenize_anything/utils/registry.py
ADDED
@@ -0,0 +1,54 @@
+# ------------------------------------------------------------------------
+# Copyright (c) 2023-present, BAAI. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ------------------------------------------------------------------------
+"""Registry utilities."""
+
+import collections
+import functools
+
+
+class Registry(object):
+    """Registry class."""
+
+    def __init__(self, name):
+        self.name = name
+        self.registry = collections.OrderedDict()
+
+    def has(self, key):
+        return key in self.registry
+
+    def register(self, name, func=None, **kwargs):
+        def decorated(inner_function):
+            for key in name if isinstance(name, (tuple, list)) else [name]:
+                self.registry[key] = functools.partial(inner_function, **kwargs)
+            return inner_function
+
+        if func is not None:
+            return decorated(func)
+        return decorated
+
+    def get(self, name, default=None):
+        if name is None:
+            return None
+        if not self.has(name):
+            if default is not None:
+                return default
+            raise KeyError("`%s` is not registered in <%s>." % (name, self.name))
+        return self.registry[name]
+
+    def try_get(self, name):
+        if self.has(name):
+            return self.get(name)
+        return None
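A minimal usage sketch of the Registry class added above (not part of the commit; the registry name and the registered builder are illustrative). Keyword arguments passed to register() are bound to the callable via functools.partial:

from tokenize_anything.utils.registry import Registry

OPTIMIZERS = Registry("optimizers")

@OPTIMIZERS.register("sgd", momentum=0.9)       # "sgd" becomes the lookup key
def build_sgd(lr, momentum):
    return {"type": "sgd", "lr": lr, "momentum": momentum}

builder = OPTIMIZERS.get("sgd")                 # a partial with momentum already bound
print(builder(lr=0.01))                         # {'type': 'sgd', 'lr': 0.01, 'momentum': 0.9}
print(OPTIMIZERS.try_get("adam"))               # None: "adam" was never registered
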
tokenize_anything/utils/tensorboard.py
ADDED
@@ -0,0 +1,68 @@
+# ------------------------------------------------------------------------
+# Copyright (c) 2023-present, BAAI. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ------------------------------------------------------------------------
+"""Tensorboard application."""
+
+import time
+
+import numpy as np
+
+try:
+    import tensorflow as tf
+except ImportError:
+    tf = None
+
+
+class TensorBoard(object):
+    """TensorBoard application."""
+
+    def __init__(self, log_dir=None):
+        """Create a summary writer logging to log_dir."""
+        if tf is None:
+            raise ImportError("Failed to import ``tensorflow`` package.")
+        tf.config.set_visible_devices([], "GPU")
+        if log_dir is None:
+            log_dir = "./logs/" + time.strftime("%Y%m%d_%H%M%S", time.localtime(time.time()))
+        self.writer = tf.summary.create_file_writer(log_dir)
+
+    @staticmethod
+    def is_available():
+        """Return whether TensorBoard is available."""
+        return tf is not None
+
+    def close(self):
+        """Close board and apply all cached summaries."""
+        self.writer.close()
+
+    def histogram_summary(self, tag, values, step, buckets=10):
+        """Write a histogram of values."""
+        with self.writer.as_default():
+            tf.summary.histogram(tag, values, step, buckets=buckets)
+
+    def image_summary(self, tag, images, step, order="BGR"):
+        """Write a list of images."""
+        if isinstance(images, (tuple, list)):
+            images = np.stack(images)
+        if len(images.shape) != 4:
+            raise ValueError("Images cannot be packed into (N, H, W, C).")
+        if order == "BGR":
+            images = images[:, :, :, ::-1]
+        with self.writer.as_default():
+            tf.summary.image(tag, images, step, max_outputs=images.shape[0])
+
+    def scalar_summary(self, tag, value, step):
+        """Write a scalar."""
+        with self.writer.as_default():
+            tf.summary.scalar(tag, value, step)
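A minimal usage sketch of the TensorBoard wrapper added above (not part of the commit; the log directory is illustrative, and the optional tensorflow dependency must be installed for is_available() to return True):

from tokenize_anything.utils.tensorboard import TensorBoard

if TensorBoard.is_available():
    board = TensorBoard(log_dir="./logs/demo")
    for step in range(10):
        board.scalar_summary("train/loss", 1.0 / (step + 1), step)
    board.close()                          # flush all cached summaries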