OssamaLafhel committed on
Commit eb4b3b5
1 Parent(s): 8764899

Update handler.py

Files changed (1)
  1. handler.py +5 -151
handler.py CHANGED
@@ -12,163 +12,17 @@ from loguru import logger
 from typing import Dict, List, Any
 
 
-# ---------------------> Converting the model to 8 bits <------------------- #
-
-class FrozenBNBLinear(nn.Module):
-    def __init__(self, weight, absmax, code, bias=None):
-        assert isinstance(bias, nn.Parameter) or bias is None
-        super().__init__()
-        self.out_features, self.in_features = weight.shape
-        self.register_buffer("weight", weight.requires_grad_(False))
-        self.register_buffer("absmax", absmax.requires_grad_(False))
-        self.register_buffer("code", code.requires_grad_(False))
-        self.adapter = None
-        self.bias = bias
-
-    def forward(self, input):
-        output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
-        if self.adapter:
-            output += self.adapter(input)
-        return output
-
-    @classmethod
-    def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
-        weights_int8, state = quantize_blockise_lowmemory(linear.weight)
-        return cls(weights_int8, *state, linear.bias)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
-
-
-class DequantizeAndLinear(torch.autograd.Function):
-    @staticmethod
-    @custom_fwd
-    def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
-                absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
-        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
-        ctx.save_for_backward(input, weights_quantized, absmax, code)
-        ctx._has_bias = bias is not None
-        return F.linear(input, weights_deq, bias)
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, grad_output: torch.Tensor):
-        assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
-        input, weights_quantized, absmax, code = ctx.saved_tensors
-        # grad_output: [*batch, out_features]
-        weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
-        grad_input = grad_output @ weights_deq
-        grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
-        return grad_input, None, None, None, grad_bias
-
-
-class FrozenBNBEmbedding(nn.Module):
-    def __init__(self, weight, absmax, code):
-        super().__init__()
-        self.num_embeddings, self.embedding_dim = weight.shape
-        self.register_buffer("weight", weight.requires_grad_(False))
-        self.register_buffer("absmax", absmax.requires_grad_(False))
-        self.register_buffer("code", code.requires_grad_(False))
-        self.adapter = None
-
-    def forward(self, input, **kwargs):
-        with torch.no_grad():
-            # note: both quantuized weights and input indices are *not* differentiable
-            weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
-            output = F.embedding(input, weight_deq, **kwargs)
-        if self.adapter:
-            output += self.adapter(input)
-        return output
-
-    @classmethod
-    def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
-        weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
-        return cls(weights_int8, *state)
-
-    def __repr__(self):
-        return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
-
-
-def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
-    assert chunk_size % 4096 == 0
-    code = None
-    chunks = []
-    absmaxes = []
-    flat_tensor = matrix.view(-1)
-    for i in range((matrix.numel() - 1) // chunk_size + 1):
-        input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
-        quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
-        chunks.append(quantized_chunk)
-        absmaxes.append(absmax_chunk)
-
-    matrix_i8 = torch.cat(chunks).reshape_as(matrix)
-    absmax = torch.cat(absmaxes)
-    return matrix_i8, (absmax, code)
-
-
-def convert_to_int8(model):
-    """Convert linear and embedding modules to 8-bit with optional adapters"""
-    for module in list(model.modules()):
-        for name, child in module.named_children():
-            if isinstance(child, nn.Linear):
-                print(name, child)
-                setattr(
-                    module,
-                    name,
-                    FrozenBNBLinear(
-                        weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
-                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
-                        code=torch.zeros(256),
-                        bias=child.bias,
-                    ),
-                )
-            elif isinstance(child, nn.Embedding):
-                setattr(
-                    module,
-                    name,
-                    FrozenBNBEmbedding(
-                        weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
-                        absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
-                        code=torch.zeros(256),
-                    )
-                )
-
-
-class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
-    def __init__(self, config):
-        super().__init__(config)
-
-        convert_to_int8(self.attn)
-        convert_to_int8(self.mlp)
-
-
-class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
-    def __init__(self, config):
-        super().__init__(config)
-        convert_to_int8(self)
-
-
-class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
-    def __init__(self, config):
-        super().__init__(config)
-        convert_to_int8(self)
-
-
-transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock  # monkey-patch GPT-J
-
-
 # -----------------------------------------> API <---------------------------------------
-tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
-model = GPTJForCausalLM.from_pretrained("Kanpredict/gptj-6b-8bits", low_cpu_mem_usage=True)
+name = "Kanpredict/gptj-6b-8bits"
+model = AutoModelForCausalLM.from_pretrained(name, device_map="auto", load_in_8bit=True)
+tokenizer = AutoTokenizer.from_pretrained(name)
 device = 0 if torch.cuda.is_available() else -1
 
 
 class EndpointHandler:
     def __init__(self, path=""):
-        # load the model
-        model.to(device)
         # create inference pipeline
-        self.pipeline = pipeline(model=model, tokenizer=tokenizer, device=device)
+        self.pipeline = pipeline(model=name, model_kwargs={"device_map": "auto", "load_in_8bit": True}, max_new_tokens=max_new_tokens)
 
     def __call__(self, data: Any) -> List[List[Dict[str, float]]]:
         inputs = data.pop("inputs", data)
@@ -184,7 +38,7 @@ class EndpointHandler:
         start = time.time()
         prompt = tokenizer(prompt, return_tensors='pt')
         prompt = {key: value.to(device) for key, value in prompt.items()}
-        out = model.generate(**prompt, min_length=length, max_length=length, temperature=temperature, do_sample=True)
+        out = self.pipeline(**prompt, min_length=length, max_length=length, temperature=temperature, do_sample=True)
         generated_text = tokenizer.decode(out[0])
         logger.info("generated text: ", generated_text)
         logger.info("time taken: %s", time.time() - start)