import os
from transformers import AutoConfig, AutoModel
from accelerate import Accelerator, init_empty_weights
from accelerate.utils import BnbQuantizationConfig, load_and_quantize_model

# Make sure transformers works offline
os.environ["TRANSFORMERS_OFFLINE"] = "1"

# 1. Initialize an empty model on the meta device: build it from the config
#    alone inside init_empty_weights(), so no fp32 weights are materialized
config = AutoConfig.from_pretrained("./models/all-MiniLM-L6-v2")
with init_empty_weights():
    empty_model = AutoModel.from_config(config)

# 2. Path to the full-precision weights to quantize (assumed here to live in
#    the same ./models folder)
weights_location = "./models/all-MiniLM-L6-v2-unquantized/pytorch_model.bin"

# 3. Set the quantization configuration (8-bit here); llm_int8_threshold is
#    the outlier threshold for LLM.int8() mixed-precision decomposition, and
#    6.0 is the bitsandbytes default
bnb_quantization_config = BnbQuantizationConfig(load_in_8bit=True, llm_int8_threshold=6.0)
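
# An alternative sketched from the accelerate docs (not used below): 4-bit
# quantization uses the same config object, e.g.
#   BnbQuantizationConfig(load_in_4bit=True,
#                         bnb_4bit_compute_dtype=torch.bfloat16,
#                         bnb_4bit_quant_type="nf4")
# which would additionally require `import torch` for the compute dtype.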

# 4. Load the weights into the empty model, quantizing them to int8 on the
#    fly; device_map="auto" dispatches the layers onto the available device(s)
quantized_model = load_and_quantize_model(
    empty_model,
    weights_location=weights_location,
    bnb_quantization_config=bnb_quantization_config,
    device_map="auto",
)
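
# Optional smoke test (a sketch; it assumes the tokenizer files are present in
# the same local model folder): run one forward pass through the quantized model.
from transformers import AutoTokenizer
tokenizer = AutoTokenizer.from_pretrained("./models/all-MiniLM-L6-v2")
inputs = tokenizer("A quick smoke test sentence.", return_tensors="pt")
outputs = quantized_model(**inputs)
print(outputs.last_hidden_state.shape)  # hidden size is 384 for all-MiniLM-L6-v2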

# 5. Save the quantized model
accelerator = Accelerator()
new_weights_location = "./models/all-MiniLM-L6-v2-unquantized-q8"
accelerator.save_model(quantized_model, new_weights_location)
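
# 6. (Optional check, not part of the original recipe) Compare memory
#    footprints; get_memory_footprint() is a standard transformers
#    PreTrainedModel method. The fp32 model is loaded here only to measure it.
full_model = AutoModel.from_pretrained("./models/all-MiniLM-L6-v2")
print(f"fp32 footprint: {full_model.get_memory_footprint() / 1e6:.1f} MB")
print(f"int8 footprint: {quantized_model.get_memory_footprint() / 1e6:.1f} MB")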