Update README.md
Browse files
README.md
CHANGED
@@ -55,11 +55,47 @@ Use the code in [this repository](https://github.com/chtmp223/suri) for training
|
|
55 |
| optim | adamw_torch |
|
56 |
| per_device_train_batch_size | 1 |
|
57 |
|
58 |
-
|
59 |
-
#### 🤗 Software
|
60 |
|
61 |
Training code is adapted from [Alignment Handbook](https://github.com/huggingface/alignment-handbook) and [Trl](https://github.com/huggingface/trl).
|
62 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
63 |
## 📜 Citation
|
64 |
|
65 |
```
|
|
|
55 |
| optim | adamw_torch |
|
56 |
| per_device_train_batch_size | 1 |
|
57 |
|
58 |
+
#### Software
|
|
|
59 |
|
60 |
Training code is adapted from [Alignment Handbook](https://github.com/huggingface/alignment-handbook) and [Trl](https://github.com/huggingface/trl).
|
61 |
|
62 |
+
## 🤗 Inference
|
63 |
+
|
64 |
+
```
|
65 |
+
from transformers import AutoTokenizer, AutoModelForCausalLM
|
66 |
+
from peft import PeftModel, PeftConfig
|
67 |
+
from datasets import load_dataset
|
68 |
+
import torch
|
69 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "False"
|
70 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
71 |
+
torch.cuda.empty_cache()
|
72 |
+
|
73 |
+
model_name = "chtmp223/suri-i-orpo"
|
74 |
+
base_model_name = "mistralai/Mistral-7B-Instruct-v0.2"
|
75 |
+
config = PeftConfig.from_pretrained(model_name)
|
76 |
+
base_model = AutoModelForCausalLM.from_pretrained(base_model_name).to(device)
|
77 |
+
model = PeftModel.from_pretrained(base_model, model_name).to(device)
|
78 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model_name)
|
79 |
+
prompt = [
|
80 |
+
{
|
81 |
+
"role": "user",
|
82 |
+
"content": user_prompt,
|
83 |
+
}
|
84 |
+
]
|
85 |
+
input_context = tokenizer.apply_chat_template(
|
86 |
+
prompt, add_generation_prompt=True, tokenize=False
|
87 |
+
)
|
88 |
+
input_ids = tokenizer.encode(
|
89 |
+
input_context, return_tensors="pt", add_special_tokens=False
|
90 |
+
).to(model.device)
|
91 |
+
output = model.generate(
|
92 |
+
input_ids, max_length=10000, do_sample=True, use_cache=True
|
93 |
+
).cpu()
|
94 |
+
|
95 |
+
print(tokenizer.decode(output[0]))
|
96 |
+
```
|
97 |
+
|
98 |
+
|
99 |
## 📜 Citation
|
100 |
|
101 |
```
|