Text Generation
Transformers
PyTorch
mpt
Composer
MosaicML
llm-foundry
custom_code
text-generation-inference
abhi-mosaic commited on
Commit
0fc7cfd
1 Parent(s): ea98a59
Files changed (1) hide show
  1. README.md +22 -19
README.md CHANGED
@@ -49,34 +49,37 @@ model = transformers.AutoModelForCausalLM.from_pretrained(
49
  )
50
  ```
51
 
52
- To use the optimized [triton implementation](https://github.com/openai/triton) of FlashAttention, you can load the model with `attn_impl='triton'` and move the model to `bfloat16`:
53
  ```python
54
- config = transformers.AutoConfig.from_pretrained(
55
- 'mosaicml/mpt-7b-storywriter',
56
- trust_remote_code=True
57
- )
 
 
58
  config.attn_config['attn_impl'] = 'triton'
 
59
 
60
  model = transformers.AutoModelForCausalLM.from_pretrained(
61
- 'mosaicml/mpt-7b-storywriter',
62
  config=config,
63
- torch_dtype=torch.bfloat16,
64
  trust_remote_code=True
65
  )
66
- model.to(device='cuda:0')
67
  ```
68
 
69
- Although the model was trained with a sequence length of 2048 and finetuned with a sequence length of 65536,
70
  ALiBi enables users to increase the maximum sequence length during finetuning and/or inference. For example:
71
-
72
  ```python
73
- config = transformers.AutoConfig.from_pretrained(
74
- 'mosaicml/mpt-7b-storywriter',
75
- trust_remote_code=True
76
- )
77
- config.update({"max_seq_len": 83968})
 
 
78
  model = transformers.AutoModelForCausalLM.from_pretrained(
79
- 'mosaicml/mpt-7b-storywriter',
80
  config=config,
81
  trust_remote_code=True
82
  )
@@ -155,8 +158,8 @@ The data was tokenized using the [EleutherAI/gpt-neox-20b](https://huggingface.c
155
 
156
  ### Training Configuration
157
 
158
- This model was trained on 8 A100-80GBs for about 2 days using the [MosaicML Platform](https://www.mosaicml.com/platform).
159
- The model was trained with sharded data parallelism using [FSDP](https://pytorch.org/docs/stable/fsdp.html) and used the [LION](https://arxiv.org/abs/2302.06675) optimizer.
160
 
161
  ## Limitations and Biases
162
 
@@ -193,4 +196,4 @@ Please cite this model using the following format:
193
  note = {Accessed: 2023-03-28}, % change this date
194
  urldate = {2023-03-28} % change this date
195
  }
196
- ```
 
49
  )
50
  ```
51
 
52
+ To use the optimized [triton implementation](https://github.com/openai/triton) of FlashAttention, you can load the model on GPU (`cuda:0`) with `attn_impl='triton'` and with `bfloat16` precision:
53
  ```python
54
+ import torch
55
+ import transformers
56
+
57
+ name = 'mosaicml/mpt-7b-storywriter'
58
+
59
+ config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
60
  config.attn_config['attn_impl'] = 'triton'
61
+ config.init_device = 'cuda:0' # For fast initialization directly on GPU!
62
 
63
  model = transformers.AutoModelForCausalLM.from_pretrained(
64
+ name,
65
  config=config,
66
+ torch_dtype=torch.bfloat16, # Load model weights in bfloat16
67
  trust_remote_code=True
68
  )
 
69
  ```
70
 
71
+ Although the model was trained with a sequence length of 2048 and finetuned with a sequence length of 65536,
72
  ALiBi enables users to increase the maximum sequence length during finetuning and/or inference. For example:
 
73
  ```python
74
+ import transformers
75
+
76
+ name = 'mosaicml/mpt-7b'
77
+
78
+ config = transformers.AutoConfig.from_pretrained(name, trust_remote_code=True)
79
+ config.max_seq_len = 83968 # (input + output) tokens can now be up to 83968
80
+
81
  model = transformers.AutoModelForCausalLM.from_pretrained(
82
+ name,
83
  config=config,
84
  trust_remote_code=True
85
  )
 
158
 
159
  ### Training Configuration
160
 
161
+ This model was trained on 8 A100-80GBs for about 2 days using the [MosaicML Platform](https://www.mosaicml.com/platform).
162
+ The model was trained with sharded data parallelism using [FSDP](https://pytorch.org/docs/stable/fsdp.html) and used the [LION](https://arxiv.org/abs/2302.06675) optimizer.
163
 
164
  ## Limitations and Biases
165
 
 
196
  note = {Accessed: 2023-03-28}, % change this date
197
  urldate = {2023-03-28} % change this date
198
  }
199
+ ```