---
license: apache-2.0
datasets:
- lambada
language:
- en
library_name: transformers
pipeline_tag: text-generation
tags:
- text-generation-inference
- causal-lm
- int8
- PyTorch
- PostTrainingStatic
- Intel® Neural Compressor
- neural-compressor
---
# INT8 GPT-J 6B
## Model Description
GPT-J 6B is a transformer model trained using Ben Wang's [Mesh Transformer JAX](https://github.com/kingoflolz/mesh-transformer-jax/). "GPT-J" refers to the class of model, while "6B" represents the number of trainable parameters.
This INT8 PyTorch model was generated by [intel-extension-for-transformers](https://github.com/intel/intel-extension-for-transformers) using static post-training quantization, with the following package versions:

| Package | Version |
|----------------------|------------|
| intel-extension-for-transformers | commit a4aba8ddb07c9b744b6ac106502ec059e0c47960 |
| neural-compressor | 2.4.1 |
| torch | 2.1.0+cpu |
| intel-extension-for-pytorch | 2.1.0 |
| transformers | 4.32.0 |
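If loading fails, comparing your environment against the versions above is a reasonable first check. The sketch below is a minimal, hypothetical helper (the names `EXPECTED_VERSIONS`, `installed_versions`, and `find_mismatches` are illustrative, not part of any library) that uses only the standard-library `importlib.metadata`. It skips intel-extension-for-transformers because that package is pinned to a git commit rather than a release version.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions from the table above (intel-extension-for-transformers is pinned
# to a git commit, so it is deliberately not checked here).
EXPECTED_VERSIONS = {
    "neural-compressor": "2.4.1",
    "torch": "2.1.0+cpu",
    "intel-extension-for-pytorch": "2.1.0",
    "transformers": "4.32.0",
}


def installed_versions(names):
    """Look up installed distribution versions; missing packages map to None."""
    result = {}
    for name in names:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


def find_mismatches(expected, installed):
    """Return {package: (expected, installed_or_None)} for every mismatch."""
    return {
        name: (want, installed.get(name))
        for name, want in expected.items()
        if installed.get(name) != want
    }


if __name__ == "__main__":
    mismatches = find_mismatches(EXPECTED_VERSIONS, installed_versions(EXPECTED_VERSIONS))
    for name, (want, have) in mismatches.items():
        print(f"{name}: expected {want}, found {have or 'not installed'}")
    if not mismatches:
        print("Environment matches the tested package versions.")
```

Other versions may still work; a mismatch only points to where to look first.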
## Usage
Currently, the only supported workflow is to download the model repository first and then load it from the local copy: the model files are fetched from the Hugging Face Hub and stored on your machine.
- Clone this model repository
```bash
# Make sure you have git-lfs installed (https://git-lfs.com)
git lfs install
git clone https://huggingface.co/Intel/gpt-j-6B-pytorch-int8-static
```
- Load the INT8 model
```python
from intel_extension_for_transformers.llm.evaluation.models import TSModelCausalLMForITREX

# Local directory created by `git clone` (adjust if you cloned elsewhere)
model_path = "./gpt-j-6B-pytorch-int8-static"

user_model = TSModelCausalLMForITREX.from_pretrained(
    model_path,
    file_name="best_model.pt",  # saved INT8 model file inside the repository
    trust_remote_code=False,  # default is False
)
```
## Evaluation results
Accuracy of the FP32 and optimized INT8 GPT-J 6B models on the `lambada_openai` task, measured with lm_eval (lm-evaluation-harness):

| Dtype | Dataset        | Accuracy |
|-------|----------------|----------|
| FP32  | Lambada_openai | 0.6831   |
| INT8  | Lambada_openai | 0.6835   |
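To put the two numbers in perspective, the arithmetic below expresses the INT8 result relative to the FP32 baseline. It is a small illustrative calculation on the reported accuracies only; `relative_change` is a hypothetical helper, not part of lm_eval.

```python
# Lambada_openai accuracies reported in the table above.
FP32_ACC = 0.6831
INT8_ACC = 0.6835


def relative_change(baseline, candidate):
    """Relative change of `candidate` versus `baseline`, as a fraction."""
    return (candidate - baseline) / baseline


ratio = INT8_ACC / FP32_ACC
print(f"INT8/FP32 accuracy ratio: {ratio:.4%}")
print(f"Relative change: {relative_change(FP32_ACC, INT8_ACC):+.4%}")
```

The INT8 model keeps roughly 100.06% of the FP32 accuracy (an absolute difference of 0.0004), i.e. quantization caused no measurable accuracy loss on this task.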