jester6136 commited on
Commit
7ced1c2
·
verified ·
1 Parent(s): 7464b54

Upload llama3-awq-onnx-qa.py

Browse files
Files changed (1) hide show
  1. llama3-awq-onnx-qa.py +93 -0
llama3-awq-onnx-qa.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import onnxruntime_genai as og
2
+ import argparse
3
+ import time
4
+
5
+ def main(args):
6
+ if args.verbose: print("Loading model...")
7
+ if args.timings:
8
+ started_timestamp = 0
9
+ first_token_timestamp = 0
10
+
11
+ model = og.Model(f'{args.model}')
12
+ if args.verbose: print("Model loaded")
13
+ tokenizer = og.Tokenizer(model)
14
+ tokenizer_stream = tokenizer.create_stream()
15
+ if args.verbose: print("Tokenizer created")
16
+ if args.verbose: print()
17
+ search_options = {name:getattr(args, name) for name in ['do_sample', 'max_length', 'min_length', 'top_p', 'top_k', 'temperature', 'repetition_penalty'] if name in args}
18
+
19
+ # Set the max length to something sensible by default, unless it is specified by the user,
20
+ # since otherwise it will be set to the entire context length
21
+ if 'max_length' not in search_options:
22
+ search_options['max_length'] = 2048
23
+
24
+ chat_template = '<|start_header_id|>user<|end_header_id|>\n{input}<|eot_id|><|start_header_id|>assistant<|end_header_id|>'
25
+
26
+ # Keep asking for input prompts in a loop
27
+ while True:
28
+ text = input("Input: ")
29
+ if not text:
30
+ print("Error, input cannot be empty")
31
+ continue
32
+
33
+ if args.timings: started_timestamp = time.time()
34
+
35
+ # If there is a chat template, use it
36
+ prompt = f'{chat_template.format(input=text)}'
37
+
38
+ input_tokens = tokenizer.encode(prompt)
39
+
40
+ params = og.GeneratorParams(model)
41
+ params.set_search_options(**search_options)
42
+ params.input_ids = input_tokens
43
+ generator = og.Generator(model, params)
44
+ if args.verbose: print("Generator created")
45
+
46
+ if args.verbose: print("Running generation loop ...")
47
+ if args.timings:
48
+ first = True
49
+ new_tokens = []
50
+
51
+ print()
52
+ print("Output: ", end='', flush=True)
53
+
54
+ try:
55
+ while not generator.is_done():
56
+ generator.compute_logits()
57
+ generator.generate_next_token()
58
+ if args.timings:
59
+ if first:
60
+ first_token_timestamp = time.time()
61
+ first = False
62
+
63
+ new_token = generator.get_next_tokens()[0]
64
+ print(tokenizer_stream.decode(new_token), end='', flush=True)
65
+ if args.timings: new_tokens.append(new_token)
66
+ except KeyboardInterrupt:
67
+ print(" --control+c pressed, aborting generation--")
68
+ print()
69
+ print()
70
+
71
+ # Delete the generator to free the captured graph for the next generator, if graph capture is enabled
72
+ del generator
73
+
74
+ if args.timings:
75
+ prompt_time = first_token_timestamp - started_timestamp
76
+ run_time = time.time() - first_token_timestamp
77
+ print(f"Prompt length: {len(input_tokens)}, New tokens: {len(new_tokens)}, Time to first: {(prompt_time):.2f}s, Prompt tokens per second: {len(input_tokens)/prompt_time:.2f} tps, New tokens per second: {len(new_tokens)/run_time:.2f} tps")
78
+
79
+
80
+ if __name__ == "__main__":
81
+ parser = argparse.ArgumentParser(argument_default=argparse.SUPPRESS, description="End-to-end AI Question/Answer example for gen-ai")
82
+ parser.add_argument('-m', '--model', type=str, required=True, help='Onnx model folder path (must contain config.json and model.onnx)')
83
+ parser.add_argument('-i', '--min_length', type=int, help='Min number of tokens to generate including the prompt')
84
+ parser.add_argument('-l', '--max_length', type=int, help='Max number of tokens to generate including the prompt')
85
+ parser.add_argument('-ds', '--do_sample', action='store_true', default=False, help='Do random sampling. When false, greedy or beam search are used to generate the output. Defaults to false')
86
+ parser.add_argument('-p', '--top_p', type=float, help='Top p probability to sample with')
87
+ parser.add_argument('-k', '--top_k', type=int, help='Top k tokens to sample from')
88
+ parser.add_argument('-t', '--temperature', type=float, help='Temperature to sample with')
89
+ parser.add_argument('-r', '--repetition_penalty', type=float, help='Repetition penalty to sample with')
90
+ parser.add_argument('-v', '--verbose', action='store_true', default=False, help='Print verbose output and timing information. Defaults to false')
91
+ parser.add_argument('-g', '--timings', action='store_true', default=False, help='Print timing information for each generation step. Defaults to false')
92
+ args = parser.parse_args()
93
+ main(args)