Aryan-401 committed on
Commit
45d2cb3
1 Parent(s): 57f8e94

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +14 -12
README.md CHANGED
@@ -18,28 +18,30 @@ This model was trained using AutoTrain. For more information, please visit [Auto
18
 
19
  # Usage
20
 
21
- ```python
 
 
22
 
 
23
  from transformers import AutoModelForCausalLM, AutoTokenizer
 
24
 
25
- model_path = "Aryan-401/phi-3-mini-4k-instruct-finetune-guanaco"
 
26
 
27
- tokenizer = AutoTokenizer.from_pretrained(model_path)
28
- model = AutoModelForCausalLM.from_pretrained(
29
- model_path,
30
- device_map="auto",
31
- torch_dtype='auto'
32
- ).eval()
33
 
34
- # Prompt content: "hi"
35
  messages = [
36
- {"role": "user", "content": "hi"}
37
  ]
 
 
 
38
 
39
  input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, add_generation_prompt=True, return_tensors='pt')
40
- output_ids = model.generate(input_ids.to('cuda'))
41
  response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
42
 
43
- # Model response: "Hello! How can I assist you today?"
44
  print(response)
45
  ```
 
18
 
19
  # Usage
20
 
21
+ ```bash
22
+ pip install peft
23
+ ```
24
 
25
+ ```python
26
  from transformers import AutoModelForCausalLM, AutoTokenizer
27
+ from peft import AutoPeftModelForCausalLM, PeftConfig
28
 
29
+ model_id = "Aryan-401/phi-3-mini-4k-instruct-finetune-guanaco"
30
+ peft_model=AutoPeftModelForCausalLM.from_pretrained(model_id)
31
 
32
+ model = peft_model.merge_and_unload()
33
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
 
 
 
 
34
 
 
35
  messages = [
36
+ {"role": "user", "content": "What is the Value of Pi?"}
37
  ]
38
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
39
+
40
+ model = model.to(device).eval()
41
 
42
  input_ids = tokenizer.apply_chat_template(conversation=messages, tokenize=True, add_generation_prompt=True, return_tensors='pt')
43
+ output_ids = model.generate(input_ids.to(device), max_length= 1000)
44
  response = tokenizer.decode(output_ids[0][input_ids.shape[1]:], skip_special_tokens=True)
45
 
 
46
  print(response)
47
  ```