sebaweis commited on
Commit
6edf637
1 Parent(s): c8b4d6c

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +55 -0
README.md CHANGED
@@ -46,3 +46,58 @@ python finetune.py \
46
  --eval-file code_eval.jsonl --wandb-project jerboa --wandb-log-model \
47
  --wandb-watch gradients --num-epochs 2
48
  ```
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
46
  --eval-file code_eval.jsonl --wandb-project jerboa --wandb-log-model \
47
  --wandb-watch gradients --num-epochs 2
48
  ```
49
+
50
+
51
+ ```Python
52
+ import torch
53
+ from transformers import AutoTokenizer, AutoModelForCausalLM
54
+
55
+
56
+ TOKENIZER_SOURCE = 'tiiuae/falcon-40b'
57
+ BASE_MODEL = 'jinaai/falcon-40b-code-alpaca'
58
+ DEVICE = "cuda"
59
+
60
+ PROMPT = """
61
+ Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
62
+
63
+ ### Instruction:
64
+ Write a for loop in python
65
+
66
+ ### Input:
67
+
68
+ ### Response:
69
+ """
70
+ model = AutoModelForCausalLM.from_pretrained(
71
+ pretrained_model_name_or_path=BASE_MODEL,
72
+ torch_dtype=torch.float16,
73
+ trust_remote_code=True,
74
+ device_map='auto',
75
+ )
76
+
77
+ model.eval()
78
+
79
+ tokenizer = AutoTokenizer.from_pretrained(
80
+ TOKENIZER_SOURCE,
81
+ trust_remote_code=True,
82
+ padding_side='left',
83
+ )
84
+ tokenizer.pad_token = tokenizer.eos_token
85
+
86
+ inputs = tokenizer(PROMPT, return_tensors="pt")
87
+ input_ids = inputs["input_ids"].to(DEVICE)
88
+ input_attention_mask = inputs["attention_mask"].to(DEVICE)
89
+
90
+ with torch.no_grad():
91
+ generation_output = model.generate(
92
+ input_ids=input_ids,
93
+ attention_mask=input_attention_mask,
94
+ return_dict_in_generate=True,
95
+ max_new_tokens=32,
96
+ eos_token_id=tokenizer.eos_token_id,
97
+ )
98
+ generation_output = generation_output.sequences[0]
99
+ output = tokenizer.decode(generation_output, skip_special_tokens=True)
100
+
101
+ print(output)
102
+
103
+ ```