SunderAli17 commited on
Commit
c77eb3b
1 Parent(s): 826d7cd

Create openai.py

Browse files
Files changed (1) hide show
  1. evaclip/openai.py +141 -0
evaclip/openai.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ OpenAI pretrained model functions
2
+ Adapted from https://github.com/openai/CLIP. Originally MIT License, Copyright (c) 2021 OpenAI.
3
+ """
4
+
5
+ import os
6
+ import warnings
7
+ from typing import List, Optional, Union
8
+
9
+ import torch
10
+
11
+ from .model import build_model_from_openai_state_dict, convert_weights_to_lp, get_cast_dtype
12
+ from .pretrained import get_pretrained_url, list_pretrained_models_by_tag, download_pretrained_from_url
13
+
14
+ __all__ = ["list_openai_models", "load_openai_model"]
15
+
16
+
17
+ def list_openai_models() -> List[str]:
18
+ """Returns the names of available CLIP models"""
19
+ return list_pretrained_models_by_tag('openai')
20
+
21
+
22
+ def load_openai_model(
23
+ name: str,
24
+ precision: Optional[str] = None,
25
+ device: Optional[Union[str, torch.device]] = None,
26
+ jit: bool = True,
27
+ cache_dir: Optional[str] = None,
28
+ ):
29
+ """Load a CLIP model
30
+ Parameters
31
+ ----------
32
+ name : str
33
+ A model name listed by `clip.available_models()`, or the path to a model checkpoint containing the state_dict
34
+ precision: str
35
+ Model precision, if None defaults to 'fp32' if device == 'cpu' else 'fp16'.
36
+ device : Union[str, torch.device]
37
+ The device to put the loaded model
38
+ jit : bool
39
+ Whether to load the optimized JIT model (default) or more hackable non-JIT model.
40
+ cache_dir : Optional[str]
41
+ The directory to cache the downloaded model weights
42
+ Returns
43
+ -------
44
+ model : torch.nn.Module
45
+ The CLIP model
46
+ preprocess : Callable[[PIL.Image], torch.Tensor]
47
+ A torchvision transform that converts a PIL image into a tensor that the returned model can take as its input
48
+ """
49
+ if device is None:
50
+ device = "cuda" if torch.cuda.is_available() else "cpu"
51
+ if precision is None:
52
+ precision = 'fp32' if device == 'cpu' else 'fp16'
53
+
54
+ if get_pretrained_url(name, 'openai'):
55
+ model_path = download_pretrained_from_url(get_pretrained_url(name, 'openai'), cache_dir=cache_dir)
56
+ elif os.path.isfile(name):
57
+ model_path = name
58
+ else:
59
+ raise RuntimeError(f"Model {name} not found; available models = {list_openai_models()}")
60
+
61
+ try:
62
+ # loading JIT archive
63
+ model = torch.jit.load(model_path, map_location=device if jit else "cpu").eval()
64
+ state_dict = None
65
+ except RuntimeError:
66
+ # loading saved state dict
67
+ if jit:
68
+ warnings.warn(f"File {model_path} is not a JIT archive. Loading as a state dict instead")
69
+ jit = False
70
+ state_dict = torch.load(model_path, map_location="cpu")
71
+
72
+ if not jit:
73
+ # Build a non-jit model from the OpenAI jitted model state dict
74
+ cast_dtype = get_cast_dtype(precision)
75
+ try:
76
+ model = build_model_from_openai_state_dict(state_dict or model.state_dict(), cast_dtype=cast_dtype)
77
+ except KeyError:
78
+ sd = {k[7:]: v for k, v in state_dict["state_dict"].items()}
79
+ model = build_model_from_openai_state_dict(sd, cast_dtype=cast_dtype)
80
+
81
+ # model from OpenAI state dict is in manually cast fp16 mode, must be converted for AMP/fp32/bf16 use
82
+ model = model.to(device)
83
+ if precision.startswith('amp') or precision == 'fp32':
84
+ model.float()
85
+ elif precision == 'bf16':
86
+ convert_weights_to_lp(model, dtype=torch.bfloat16)
87
+
88
+ return model
89
+
90
+ # patch the device names
91
+ device_holder = torch.jit.trace(lambda: torch.ones([]).to(torch.device(device)), example_inputs=[])
92
+ device_node = [n for n in device_holder.graph.findAllNodes("prim::Constant") if "Device" in repr(n)][-1]
93
+
94
+ def patch_device(module):
95
+ try:
96
+ graphs = [module.graph] if hasattr(module, "graph") else []
97
+ except RuntimeError:
98
+ graphs = []
99
+
100
+ if hasattr(module, "forward1"):
101
+ graphs.append(module.forward1.graph)
102
+
103
+ for graph in graphs:
104
+ for node in graph.findAllNodes("prim::Constant"):
105
+ if "value" in node.attributeNames() and str(node["value"]).startswith("cuda"):
106
+ node.copyAttributes(device_node)
107
+
108
+ model.apply(patch_device)
109
+ patch_device(model.encode_image)
110
+ patch_device(model.encode_text)
111
+
112
+ # patch dtype to float32 (typically for CPU)
113
+ if precision == 'fp32':
114
+ float_holder = torch.jit.trace(lambda: torch.ones([]).float(), example_inputs=[])
115
+ float_input = list(float_holder.graph.findNode("aten::to").inputs())[1]
116
+ float_node = float_input.node()
117
+
118
+ def patch_float(module):
119
+ try:
120
+ graphs = [module.graph] if hasattr(module, "graph") else []
121
+ except RuntimeError:
122
+ graphs = []
123
+
124
+ if hasattr(module, "forward1"):
125
+ graphs.append(module.forward1.graph)
126
+
127
+ for graph in graphs:
128
+ for node in graph.findAllNodes("aten::to"):
129
+ inputs = list(node.inputs())
130
+ for i in [1, 2]: # dtype can be the second or third argument to aten::to()
131
+ if inputs[i].node()["value"] == 5:
132
+ inputs[i].node().copyAttributes(float_node)
133
+
134
+ model.apply(patch_float)
135
+ patch_float(model.encode_image)
136
+ patch_float(model.encode_text)
137
+ model.float()
138
+
139
+ # ensure image_size attr available at consistent location for both jit and non-jit
140
+ model.visual.image_size = model.input_resolution.item()
141
+ return model