---
license: apache-2.0
datasets:
- lambada
language:
- en
library_name: transformers
pipeline_tag: text-generation
tags:
- text-generation-inference
- causal-lm
- int8
- PyTorch
- PostTrainingStatic
- Intel® Neural Compressor
- neural-compressor
---
# INT8 GPT-J 6B

## Model Description
GPT-J 6B is a transformer model trained using Ben Wang's [Mesh Transformer JAX](https://github.com/kingoflolz/mesh-transformer-jax/). "GPT-J" refers to the class of model, while "6B" represents the number of trainable parameters.

This INT8 PyTorch model was generated by [intel-extension-for-transformers](https://github.com/intel/intel-extension-for-transformers) with the following package versions:

| Package | Version |
|---------|---------|
| intel-extension-for-transformers | a4aba8ddb07c9b744b6ac106502ec059e0c47960 (git commit) |
| neural-compressor | 2.4.1 |
| torch | 2.1.0+cpu |
| intel-extension-for-pytorch | 2.1.0 |
| transformers | 4.32.0 |
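Because the INT8 model was produced with specific package versions, it can help to confirm your environment matches before loading it. The following is a minimal sketch using only the Python standard library (`importlib.metadata`). The helper name `check_versions` is hypothetical, and intel-extension-for-transformers is skipped because it is pinned to a git commit rather than a release version.

```python
from importlib.metadata import PackageNotFoundError, version

# Release versions from the package table above. intel-extension-for-transformers
# is pinned to a git commit, so it is not compared here.
EXPECTED = {
    "neural-compressor": "2.4.1",
    "torch": "2.1.0+cpu",
    "intel-extension-for-pytorch": "2.1.0",
    "transformers": "4.32.0",
}


def check_versions(expected=EXPECTED):
    """Hypothetical helper: return {package: (expected, installed_or_None)} for each mismatch."""
    mismatches = {}
    for pkg, want in expected.items():
        try:
            have = version(pkg)  # installed distribution version string
        except PackageNotFoundError:
            have = None  # package is not installed
        if have != want:
            mismatches[pkg] = (want, have)
    return mismatches


if __name__ == "__main__":
    for pkg, (want, have) in check_versions().items():
        print(f"{pkg}: expected {want}, found {have or 'not installed'}")
```

An empty result means every listed package matches exactly; newer versions may still work, but they are not the ones this model was produced with.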

## Usage
Currently, only one workflow is supported: download the model repository to your local machine, then load the model from that local copy.
- Clone this model repository:
```bash
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone https://huggingface.co/Intel/gpt-j-6B-pytorch-int8-static
```
- Load the INT8 model:
```python
from intel_extension_for_transformers.llm.evaluation.models import TSModelCausalLMForITREX

# Local path of the repository cloned in the previous step
model_dir = "./gpt-j-6B-pytorch-int8-static"

user_model = TSModelCausalLMForITREX.from_pretrained(
    model_dir,
    file_name="best_model.pt",  # INT8 model file stored in the repository
    trust_remote_code=False,    # Default is False
)
```

## Evaluation results
Accuracy of the FP32 and INT8 GPT-J 6B models on the lambada_openai task, measured with lm_eval.

| Dtype | Dataset | Accuracy |
|-------|---------|----------|
| FP32 | lambada_openai | 0.6831 |
| INT8 | lambada_openai | 0.6835 |
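To put the two rows of the results table in perspective, the short calculation below computes the absolute and relative accuracy difference between the INT8 and FP32 models. It only uses the two numbers reported above; no other data is assumed.

```python
# Accuracy values from the evaluation table above.
fp32_acc = 0.6831
int8_acc = 0.6835

abs_delta = int8_acc - fp32_acc  # difference in accuracy points
rel_delta = abs_delta / fp32_acc  # difference relative to the FP32 baseline

print(f"absolute: {abs_delta:+.4f}")  # +0.0004
print(f"relative: {rel_delta:+.3%}")  # +0.059%
```

In other words, on this task the INT8 model scores within about 0.06% of the FP32 baseline.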