OssamaLafhel committed on
Commit
3e86372
1 Parent(s): f30c581

Upload 3 files

Files changed (3)
  1. gpt-j-6b-8-bit.py +265 -0
  2. handler.py +179 -0
  3. requirements.txt +8 -0
gpt-j-6b-8-bit.py ADDED
@@ -0,0 +1,265 @@
+ # -*- coding: utf-8 -*-
+ """
+ finetune-gpt-j-6B-8bit.ipynb
+ https://colab.research.google.com/drive/1ft6wQU0BhqG5PRlwgaZJv2VukKKjU4Es
+
+ ### Fine-tuning 6-billion-parameter GPT-J in Colab with LoRA and 8-bit compression
+
+ This notebook is a proof of concept for fine-tuning
+ [GPT-J-6B](https://huggingface.co/EleutherAI/gpt-j-6B) with limited memory.
+ A detailed explanation of how it works can be found in
+ [this model card](https://huggingface.co/hivemind/gpt-j-6B-8bit).
+ """
+
+ from loguru import logger
+ import transformers
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ from torch.cuda.amp import custom_fwd, custom_bwd
+ from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+ from tqdm.auto import tqdm
+ from datasets import load_dataset
+ from bitsandbytes.optim import Adam8bit
+ import time, os
+
+ # ---------------------> Converting the model to 8 bits <------------------- #
+ """
+ We convert EleutherAI's GPT-J-6B model to 8 bits using Facebook's [bitsandbytes](https://github.com/facebookresearch/bitsandbytes) library.
+ This reduces the model's size from 20 GB down to just 6 GB.
+ Note that we don't convert the linear layer biases to 8 bit, as they take up less than 1% of the model's weights anyway.
+ """
+
+ class FrozenBNBLinear(nn.Module):
+     def __init__(self, weight, absmax, code, bias=None):
+         assert isinstance(bias, nn.Parameter) or bias is None
+         super().__init__()
+         self.out_features, self.in_features = weight.shape
+         self.register_buffer("weight", weight.requires_grad_(False))
+         self.register_buffer("absmax", absmax.requires_grad_(False))
+         self.register_buffer("code", code.requires_grad_(False))
+         self.adapter = None
+         self.bias = bias
+
+     def forward(self, input):
+         output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
+         if self.adapter:
+             output = output + self.adapter(input)
+         return output
+
+     @classmethod
+     def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
+         weights_int8, state = quantize_blockise_lowmemory(linear.weight)
+         return cls(weights_int8, *state, linear.bias)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
+
+
+
+ class DequantizeAndLinear(torch.autograd.Function):
+     @staticmethod
+     @custom_fwd
+     def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
+                 absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
+         weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+         ctx.save_for_backward(input, weights_quantized, absmax, code)
+         ctx._has_bias = bias is not None
+         return F.linear(input, weights_deq, bias)
+
+     @staticmethod
+     @custom_bwd
+     def backward(ctx, grad_output: torch.Tensor):
+         assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
+         input, weights_quantized, absmax, code = ctx.saved_tensors
+         # grad_output: [*batch, out_features]
+         weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+         grad_input = grad_output @ weights_deq
+         grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
+         return grad_input, None, None, None, grad_bias
+
+
+ class FrozenBNBEmbedding(nn.Module):
+     def __init__(self, weight, absmax, code):
+         super().__init__()
+         self.num_embeddings, self.embedding_dim = weight.shape
+         self.register_buffer("weight", weight.requires_grad_(False))
+         self.register_buffer("absmax", absmax.requires_grad_(False))
+         self.register_buffer("code", code.requires_grad_(False))
+         self.adapter = None
+
+     def forward(self, input, **kwargs):
+         with torch.no_grad():
+             # note: both quantized weights and input indices are *not* differentiable
+             weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
+             output = F.embedding(input, weight_deq, **kwargs)
+         if self.adapter:
+             output += self.adapter(input)
+         return output
+
+     @classmethod
+     def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
+         weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
+         return cls(weights_int8, *state)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
+
+ def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
+     assert chunk_size % 4096 == 0
+     code = None
+     chunks = []
+     absmaxes = []
+     flat_tensor = matrix.view(-1)
+     for i in range((matrix.numel() - 1) // chunk_size + 1):
+         input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
+         quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
+         chunks.append(quantized_chunk)
+         absmaxes.append(absmax_chunk)
+
+     matrix_i8 = torch.cat(chunks).reshape_as(matrix)
+     absmax = torch.cat(absmaxes)
+     return matrix_i8, (absmax, code)
+
+
+ def convert_to_int8(model):
+     """Convert linear and embedding modules to 8-bit with optional adapters"""
+     for module in list(model.modules()):
+         for name, child in module.named_children():
+             if isinstance(child, nn.Linear):
+                 print(name, child)
+                 setattr(
+                     module,
+                     name,
+                     FrozenBNBLinear(
+                         weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
+                         absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                         code=torch.zeros(256),
+                         bias=child.bias,
+                     ),
+                 )
+             elif isinstance(child, nn.Embedding):
+                 setattr(
+                     module,
+                     name,
+                     FrozenBNBEmbedding(
+                         weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
+                         absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                         code=torch.zeros(256),
+                     )
+                 )
+
+ class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
+     def __init__(self, config):
+         super().__init__(config)
+
+         convert_to_int8(self.attn)
+         convert_to_int8(self.mlp)
+
+
+ class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
+     def __init__(self, config):
+         super().__init__(config)
+         convert_to_int8(self)
+
+
+ class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
+     def __init__(self, config):
+         super().__init__(config)
+         convert_to_int8(self)
+
+
+ transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock  # monkey-patch GPT-J
+
+ # ---------------------> Loading EleutherAI/gpt-j-6B config and tokenizer <------------------- #
+ config = transformers.GPTJConfig.from_pretrained("EleutherAI/gpt-j-6B")
+ tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
+
+ # ---------------------> Downloading gpt-j-6B-8bit model from huggingface <------------------- #
+ # gpt = GPTJForCausalLM.from_pretrained("hivemind/gpt-j-6B-8bit")
+
+ # ----------------> Saving gpt-j-6B-8bit model to server <-----------------#
+ # save_dir = "./saved_models_gpt-j-6B-8bit/gpt-j-6B"
+ # gpt.save_pretrained(save_dir)
+ # logger.info("Saved model to {}".format(save_dir))
+
+ # ---------------------> Loading saved gpt-j-6B-8bit model <------------------- #
+ gpt = GPTJForCausalLM.from_pretrained("./saved_models_gpt-j-6B-8bit/gpt-j-6B")
+
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
+ gpt.to(device)
+
+ # ---------------------> Text generation example <------------------- #
+ prompt = tokenizer("A cat sat on a mat", return_tensors='pt')
+ prompt = {key: value.to(device) for key, value in prompt.items()}
+ out = gpt.generate(**prompt, min_length=128, max_length=128, do_sample=True)
+ logger.info("Generated text: {}".format(tokenizer.decode(out[0])))
+
+
+ # ---------------------> LoRA fine-tuning example <------------------- #
+
+ def add_adapters(model, adapter_dim=16):
+     assert adapter_dim > 0
+
+     for module in model.modules():
+         if isinstance(module, FrozenBNBLinear):
+             module.adapter = nn.Sequential(
+                 nn.Linear(module.in_features, adapter_dim, bias=False),
+                 nn.Linear(adapter_dim, module.out_features, bias=False),
+             )
+             nn.init.zeros_(module.adapter[1].weight)
+         elif isinstance(module, FrozenBNBEmbedding):
+             module.adapter = nn.Sequential(
+                 nn.Embedding(module.num_embeddings, adapter_dim),
+                 nn.Linear(adapter_dim, module.embedding_dim, bias=False),
+             )
+             nn.init.zeros_(module.adapter[1].weight)
+
+ add_adapters(gpt)
+ gpt.to(device)
+ gpt.gradient_checkpointing_enable()
+
+ # example dataset
+ data_files = {"train": "data.jsonl"}
+ dataset = load_dataset('nomic-ai/gpt4all_prompt_generations_with_p3', data_files=data_files)
+ prompt_response_separator = " response: "
+
+ def concatenate_prompt_response(row):
+     row["text"] = "prompt: " + row["prompt"] + prompt_response_separator + row["response"]
+     return row
+
+ dataset = dataset.map(concatenate_prompt_response, remove_columns=["prompt", "response"])
+
+ # custom dataset
+ # dataset = load_dataset('text', data_files={'train': ['article-1.txt', 'article-2.txt'], 'test': ['article-3.txt', 'article-4.txt']})
+
+ optimizer = Adam8bit(gpt.parameters(), lr=1e-5)
+
+ gpt.train()  # set the model to training mode
+ start = time.time()
+
+ # Training loop
+ with torch.cuda.amp.autocast():
+     for row in tqdm(dataset["train"]):
+         if len(row["text"]) <= 1:
+             continue
+         batch = tokenizer(row["text"], truncation=True, max_length=128, return_tensors='pt')
+         batch = {k: v.to(device) for k, v in batch.items()}
+         out = gpt(**batch)
+         loss = F.cross_entropy(out.logits[:, :-1, :].flatten(0, -2), batch['input_ids'][:, 1:].flatten(),
+                                reduction='mean')
+         print(loss)
+         loss.backward()
+         optimizer.step()
+         optimizer.zero_grad()
+
+ logger.info("Finished fine-tuning in {}".format(time.time() - start))
+
+ # --------------> Saving fine-tuned model <-----------------#
+ try:
+     save_dir = "./finetuned_gpt-j-8_bit"
+     os.makedirs(save_dir, exist_ok=True)
+     gpt.save_pretrained(save_dir)
+ except Exception as e:
+     logger.error("Error saving model: {}".format(e))
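Editor's note: as a quick sanity check after the fine-tuning script above has run, the following minimal sketch (not part of the uploaded file) generates text with the model still in memory. It assumes `gpt`, `tokenizer`, and `device` exist as defined in gpt-j-6b-8-bit.py; the prompt string and sampling settings are illustrative only.

# Illustrative only: quick generation check with the freshly fine-tuned model.
# Assumes `gpt`, `tokenizer`, and `device` from gpt-j-6b-8-bit.py above.
gpt.eval()
with torch.no_grad():
    prompt = tokenizer("prompt: What is 8-bit quantization? response:", return_tensors="pt").to(device)
    out = gpt.generate(**prompt, max_length=128, do_sample=True, temperature=0.8)
print(tokenizer.decode(out[0], skip_special_tokens=True))

The prompt follows the same "prompt: ... response: ..." format that the training loop concatenates, so the sampled continuation should resemble the fine-tuning data.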
handler.py ADDED
@@ -0,0 +1,179 @@
+ import transformers
+ from transformers import pipeline
+ import torch
+ from torch import nn
+ import torch.nn.functional as F
+ from torch.cuda.amp import custom_fwd, custom_bwd
+ from bitsandbytes.functional import quantize_blockwise, dequantize_blockwise
+ from typing import Dict, List, Any
+
+
+ # ---------------------> Converting the model to 8 bits <------------------- #
+
+ class FrozenBNBLinear(nn.Module):
+     def __init__(self, weight, absmax, code, bias=None):
+         assert isinstance(bias, nn.Parameter) or bias is None
+         super().__init__()
+         self.out_features, self.in_features = weight.shape
+         self.register_buffer("weight", weight.requires_grad_(False))
+         self.register_buffer("absmax", absmax.requires_grad_(False))
+         self.register_buffer("code", code.requires_grad_(False))
+         self.adapter = None
+         self.bias = bias
+
+     def forward(self, input):
+         output = DequantizeAndLinear.apply(input, self.weight, self.absmax, self.code, self.bias)
+         if self.adapter:
+             output += self.adapter(input)
+         return output
+
+     @classmethod
+     def from_linear(cls, linear: nn.Linear) -> "FrozenBNBLinear":
+         weights_int8, state = quantize_blockise_lowmemory(linear.weight)
+         return cls(weights_int8, *state, linear.bias)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}({self.in_features}, {self.out_features})"
+
+
+ class DequantizeAndLinear(torch.autograd.Function):
+     @staticmethod
+     @custom_fwd
+     def forward(ctx, input: torch.Tensor, weights_quantized: torch.ByteTensor,
+                 absmax: torch.FloatTensor, code: torch.FloatTensor, bias: torch.FloatTensor):
+         weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+         ctx.save_for_backward(input, weights_quantized, absmax, code)
+         ctx._has_bias = bias is not None
+         return F.linear(input, weights_deq, bias)
+
+     @staticmethod
+     @custom_bwd
+     def backward(ctx, grad_output: torch.Tensor):
+         assert not ctx.needs_input_grad[1] and not ctx.needs_input_grad[2] and not ctx.needs_input_grad[3]
+         input, weights_quantized, absmax, code = ctx.saved_tensors
+         # grad_output: [*batch, out_features]
+         weights_deq = dequantize_blockwise(weights_quantized, absmax=absmax, code=code)
+         grad_input = grad_output @ weights_deq
+         grad_bias = grad_output.flatten(0, -2).sum(dim=0) if ctx._has_bias else None
+         return grad_input, None, None, None, grad_bias
+
+
+ class FrozenBNBEmbedding(nn.Module):
+     def __init__(self, weight, absmax, code):
+         super().__init__()
+         self.num_embeddings, self.embedding_dim = weight.shape
+         self.register_buffer("weight", weight.requires_grad_(False))
+         self.register_buffer("absmax", absmax.requires_grad_(False))
+         self.register_buffer("code", code.requires_grad_(False))
+         self.adapter = None
+
+     def forward(self, input, **kwargs):
+         with torch.no_grad():
+             # note: both quantized weights and input indices are *not* differentiable
+             weight_deq = dequantize_blockwise(self.weight, absmax=self.absmax, code=self.code)
+             output = F.embedding(input, weight_deq, **kwargs)
+         if self.adapter:
+             output += self.adapter(input)
+         return output
+
+     @classmethod
+     def from_embedding(cls, embedding: nn.Embedding) -> "FrozenBNBEmbedding":
+         weights_int8, state = quantize_blockise_lowmemory(embedding.weight)
+         return cls(weights_int8, *state)
+
+     def __repr__(self):
+         return f"{self.__class__.__name__}({self.num_embeddings}, {self.embedding_dim})"
+
+
+ def quantize_blockise_lowmemory(matrix: torch.Tensor, chunk_size: int = 2 ** 20):
+     assert chunk_size % 4096 == 0
+     code = None
+     chunks = []
+     absmaxes = []
+     flat_tensor = matrix.view(-1)
+     for i in range((matrix.numel() - 1) // chunk_size + 1):
+         input_chunk = flat_tensor[i * chunk_size: (i + 1) * chunk_size].clone()
+         quantized_chunk, (absmax_chunk, code) = quantize_blockwise(input_chunk, code=code)
+         chunks.append(quantized_chunk)
+         absmaxes.append(absmax_chunk)
+
+     matrix_i8 = torch.cat(chunks).reshape_as(matrix)
+     absmax = torch.cat(absmaxes)
+     return matrix_i8, (absmax, code)
+
+
+ def convert_to_int8(model):
+     """Convert linear and embedding modules to 8-bit with optional adapters"""
+     for module in list(model.modules()):
+         for name, child in module.named_children():
+             if isinstance(child, nn.Linear):
+                 print(name, child)
+                 setattr(
+                     module,
+                     name,
+                     FrozenBNBLinear(
+                         weight=torch.zeros(child.out_features, child.in_features, dtype=torch.uint8),
+                         absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                         code=torch.zeros(256),
+                         bias=child.bias,
+                     ),
+                 )
+             elif isinstance(child, nn.Embedding):
+                 setattr(
+                     module,
+                     name,
+                     FrozenBNBEmbedding(
+                         weight=torch.zeros(child.num_embeddings, child.embedding_dim, dtype=torch.uint8),
+                         absmax=torch.zeros((child.weight.numel() - 1) // 4096 + 1),
+                         code=torch.zeros(256),
+                     )
+                 )
+
+
+ class GPTJBlock(transformers.models.gptj.modeling_gptj.GPTJBlock):
+     def __init__(self, config):
+         super().__init__(config)
+
+         convert_to_int8(self.attn)
+         convert_to_int8(self.mlp)
+
+
+ class GPTJModel(transformers.models.gptj.modeling_gptj.GPTJModel):
+     def __init__(self, config):
+         super().__init__(config)
+         convert_to_int8(self)
+
+
+ class GPTJForCausalLM(transformers.models.gptj.modeling_gptj.GPTJForCausalLM):
+     def __init__(self, config):
+         super().__init__(config)
+         convert_to_int8(self)
+
+
+ transformers.models.gptj.modeling_gptj.GPTJBlock = GPTJBlock  # monkey-patch GPT-J
+
+
+ # -----------------------------------------> API <---------------------------------------
+
+
+ class EndpointHandler:
+     def __init__(self, path=""):
+         # load the model
+         tokenizer = transformers.AutoTokenizer.from_pretrained("EleutherAI/gpt-j-6B")
+         model = GPTJForCausalLM.from_pretrained(path, low_cpu_mem_usage=True)
+         device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         model.to(device)
+         # create inference pipeline
+         self.pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer, device=device)
+
+     def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
+         inputs = data.pop("inputs", data)
+         parameters = data.pop("parameters", None)
+
+         # pass inputs with all kwargs in data
+         if parameters is not None:
+             prediction = self.pipeline(inputs, **parameters)
+         else:
+             prediction = self.pipeline(inputs)
+         # postprocess the prediction
+         return prediction
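Editor's note: for orientation, here is a hedged local-usage sketch (not part of the commit) of how this handler is invoked. Inference Endpoints construct `EndpointHandler` with the repository path and call it with a dict containing `inputs` and optional `parameters`; the model path and generation parameters below are illustrative examples only.

# Illustrative local test of handler.py; path and parameters are examples only.
from handler import EndpointHandler

handler = EndpointHandler(path="./finetuned_gpt-j-8_bit")
result = handler({
    "inputs": "prompt: What is 8-bit quantization? response:",
    "parameters": {"max_length": 128, "do_sample": True},
})
print(result)  # text-generation pipelines return a list of dicts with a "generated_text" field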
requirements.txt ADDED
@@ -0,0 +1,8 @@
+ fastapi==0.95.0
+ uvicorn==0.21.1
+ transformers==4.27.4
+ torch==2.0.0
+ requests==2.28.2
+ pydantic~=1.10.7
+ loguru==0.5.3
+ bitsandbytes-cuda111