Update README.md
README.md
@@ -46,3 +46,58 @@ python finetune.py \
 --eval-file code_eval.jsonl --wandb-project jerboa --wandb-log-model \
 --wandb-watch gradients --num-epochs 2
 ```
+
+
+```python
+import torch
+from transformers import AutoTokenizer, AutoModelForCausalLM
+
+
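+# Tokenizer comes from the base Falcon-40B repo; weights from the code-alpaca fine-tune.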
+TOKENIZER_SOURCE = 'tiiuae/falcon-40b'
+BASE_MODEL = 'jinaai/falcon-40b-code-alpaca'
+DEVICE = "cuda"
+
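+# Alpaca-style instruction prompt.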
+PROMPT = """
+Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
+
+### Instruction:
+Write a for loop in python
+
+### Input:
+
+### Response:
+"""
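+# Load the fine-tuned weights in fp16, spread across available devices.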
+model = AutoModelForCausalLM.from_pretrained(
+    pretrained_model_name_or_path=BASE_MODEL,
+    torch_dtype=torch.float16,
+    trust_remote_code=True,
+    device_map='auto',
+)
+
+model.eval()
+
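+# Falcon's tokenizer defines no pad token, so reuse EOS for padding.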
+tokenizer = AutoTokenizer.from_pretrained(
+    TOKENIZER_SOURCE,
+    trust_remote_code=True,
+    padding_side='left',
+)
+tokenizer.pad_token = tokenizer.eos_token
+
+inputs = tokenizer(PROMPT, return_tensors="pt")
+input_ids = inputs["input_ids"].to(DEVICE)
+input_attention_mask = inputs["attention_mask"].to(DEVICE)
+
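+# Greedy decoding (no sampling flags set), capped at 32 new tokens.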
+with torch.no_grad():
+    generation_output = model.generate(
+        input_ids=input_ids,
+        attention_mask=input_attention_mask,
+        return_dict_in_generate=True,
+        max_new_tokens=32,
+        eos_token_id=tokenizer.eos_token_id,
+    )
+generation_output = generation_output.sequences[0]
+output = tokenizer.decode(generation_output, skip_special_tokens=True)
+
+print(output)
+
+```
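
One follow-up worth noting, not part of the committed snippet: `output` contains the prompt as well as the completion, since `generate` returns the prompt tokens followed by the newly generated ones. A minimal sketch for printing only the response, assuming the variables defined in the snippet above are in scope:

```python
# generation_output was reassigned above to the full 1-D token sequence
# (prompt + response); slice off the prompt tokens before decoding.
prompt_length = input_ids.shape[-1]
response_ids = generation_output[prompt_length:]
print(tokenizer.decode(response_ids, skip_special_tokens=True))
```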