STEM-AI-mtl commited on
Commit
019dbde
1 Parent(s): 095ce9f

Upload chat-GPTQ.py

Browse files
Files changed (1) hide show
  1. chat-GPTQ.py +53 -0
chat-GPTQ.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from peft import PeftModel, PeftConfig
3
+ from transformers import AutoModelForCausalLM, AutoTokenizer
4
+ import warnings
5
+ import os
6
+
7
+ # Suppress INFO and WARNING messages from TensorFlow
8
+ os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
9
+ warnings.filterwarnings("ignore", category=UserWarning, module='transformers.generation.utils')
10
+
11
+ def load_model_and_tokenizer():
12
+ base_model = "TheBloke/phi-2-GPTQ"
13
+ peft_model_id = "STEM-AI-mtl/phi-2-electrical-engineering"
14
+ config = PeftConfig.from_pretrained(peft_model_id, trust_remote_code=True)
15
+ model = AutoModelForCausalLM.from_pretrained(base_model, device_map="cuda:0",return_dict=True, trust_remote_code=True)
16
+
17
+ model = model.to('cuda')
18
+
19
+ tokenizer = AutoTokenizer.from_pretrained(base_model)
20
+ model = PeftModel.from_pretrained(model, peft_model_id, trust_remote_code=True)
21
+
22
+ model = model.to('cuda')
23
+
24
+ return model, tokenizer
25
+
26
+ def generate(instruction, model, tokenizer):
27
+ inputs = tokenizer(instruction, return_tensors="pt", return_attention_mask=False)
28
+ inputs = inputs.to('cuda')
29
+ outputs = model.generate(
30
+ **inputs,
31
+ max_length=350,
32
+ do_sample=True,
33
+ temperature=0.7,
34
+ top_k=50,
35
+ top_p=0.9,
36
+ repetition_penalty=1,
37
+ )
38
+ text = tokenizer.batch_decode(outputs)[0]
39
+ return text
40
+
41
+
42
+ if __name__ == '__main__':
43
+ model, tokenizer = load_model_and_tokenizer()
44
+ while True:
45
+ instruction = input("Enter your instruction: ")
46
+ if not instruction:
47
+ continue
48
+ if instruction.lower() in ["exit", "quit", "exit()", "quit()"]:
49
+ print("Exiting...")
50
+ break
51
+
52
+ answer = generate(instruction, model, tokenizer)
53
+ print(f'Answer: {answer}')