# qwen-2-coreml-ane / coreml_example.py
# Committed by seba: "python sample generation script" (commit edfe3bc, verified, 3.04 kB)
import time
import numpy as np
from argparse import ArgumentParser
from transformers import AutoTokenizer
parser = ArgumentParser(description="Generate text from a Qwen2 Core ML model.")
parser.add_argument(
    "--model_path",
    "--model-path",
    required=True,
    help="path to the .mlpackage (or compiled model) to load",
)
parser.add_argument("--prompt", "-p", required=True, help="text to continue")
parser.add_argument(
    "--max-tokens",
    "--max_tokens",
    type=int,
    default=100,
    help="maximum number of decode steps after the prompt",
)
parser.add_argument(
    "--min_p",
    "--min-p",
    type=float,
    default=0.3,
    help="drop tokens whose probability is below min_p times the top token's",
)
parser.add_argument("--temp", type=float, default=1.0, help="sampling temperature")
parser.add_argument(
    "--tokenizer",
    default="Qwen/Qwen2-0.5B",
    help="Hugging Face tokenizer id or local path",
)
args = parser.parse_args()
# Parse the command line before loading the tokenizer so that --help and
# usage errors do not have to wait for from_pretrained (which may download).
tokenizer = AutoTokenizer.from_pretrained(args.tokenizer)
import coremltools as ct

print("Loading model...")
# Choose the loader from the path: MLModel for a ".mlpackage" (a trailing
# slash is tolerated), CompiledMLModel for any other path -- presumably an
# already-compiled model directory; confirm.
if args.model_path.rstrip("/").endswith(".mlpackage"):
    model_cls = ct.models.MLModel
else:
    model_cls = ct.models.CompiledMLModel

# The same model file is opened once per exported function, both on CPU + ANE.
# "length_1" drives the single-token decode steps and "length_64" the prompt pass.
mf_model_1, mf_model_64 = (
    model_cls(
        args.model_path,
        compute_units=ct.ComputeUnit.CPU_AND_NE,
        function_name=function_name,
    )
    for function_name in ("length_1", "length_64")
)
def min_p_sample(logits, min_p, temp):
    """Sample one token id per row using temperature scaling and min-p filtering.

    The logits are divided by ``temp`` and converted to weights relative to
    the most likely token, which gets weight 1. Tokens with a weight below
    ``min_p`` are discarded, i.e. tokens whose probability is below ``min_p``
    times the top token's probability. One token is then drawn in proportion
    to the remaining weights.

    Args:
        logits: Array whose axis 1 indexes the vocabulary. Every other axis
            is treated as an independent row. It is cast to float32 and is
            not modified.
        min_p: Relative probability cut-off. Values above 1 act as 1, so the
            top token always survives.
        temp: Sampling temperature. ``temp <= 0`` means greedy decoding: the
            argmax is returned.

    Returns:
        ``np.int32`` array of token ids with the shape of ``logits`` minus
        axis 1.
    """
    # Work in float32: a float16 running sum over a large vocabulary loses
    # too much precision to serve as a sampling CDF.
    logits = np.asarray(logits, dtype=np.float32)
    if temp <= 0:
        return np.argmax(logits, axis=1).astype(np.int32)
    # Scale the logits themselves by the temperature, then shift each row so
    # its top token has weight exp(0) == 1. The shift makes ``min_p`` a
    # threshold relative to the top token and keeps np.exp from overflowing.
    scaled = logits / temp
    weights = np.exp(scaled - np.max(scaled, axis=1, keepdims=True))
    weights[weights < min(min_p, 1.0)] = 0
    # Inverse-CDF sampling on the unnormalised weights: draw u in [0, total)
    # and take the first index whose running sum exceeds it.
    cdf = np.cumsum(weights, axis=1)
    draw = np.random.uniform(high=cdf[:, -1:])
    return np.argmax(cdf > draw, axis=1).astype(np.int32)
# Fixed sequence length of the "length_64" function used for the prompt pass.
PREFILL_LENGTH = 64


def _print_new_text(token_ids, printed_chars, final=False):
    """Print the part of ``tokenizer.decode(token_ids)`` not printed yet.

    Decoding and printing one token at a time can garble characters whose
    UTF-8 bytes are split across several tokens (they decode to U+FFFD).
    Decoding the whole sequence and printing only the new suffix avoids that.

    Args:
        token_ids: All token ids to show so far.
        printed_chars: Number of characters of the decoded text already printed.
        final: Print whatever remains, even if the text still ends in an
            incomplete character.

    Returns:
        The updated number of printed characters.
    """
    text = tokenizer.decode(token_ids)
    if text.endswith("\ufffd") and not final:
        # The newest token stops mid-character; wait for the rest of its bytes.
        return printed_chars
    print(text[printed_chars:], end="", flush=True)
    return len(text)


length = len(tokenizer(args.prompt)["input_ids"])
if not 0 < length <= PREFILL_LENGTH:
    # padding="max_length" does not truncate a longer prompt, and an empty
    # prompt would make ``length - 1`` below select a padding position.
    raise SystemExit(
        f"The prompt must be 1 to {PREFILL_LENGTH} tokens long, got {length}."
    )
input_ids = tokenizer(
    args.prompt,
    return_tensors="np",
    padding="max_length",
    max_length=PREFILL_LENGTH,
)["input_ids"].astype(np.int32)
print("Prompt:", args.prompt)
state = mf_model_64.make_state()
start = time.time()
pred = mf_model_64.predict(
    {"input_ids": input_ids, "query_pos1": np.array([0], dtype=np.int32)}, state
)
prompt_time = time.time() - start
# Take the logits at the last prompt token on the last axis; the list index
# keeps that axis. NOTE(review): this assumes right-side padding and that the
# last logits axis is the sequence axis -- confirm against the export.
logits = pred["logits"][..., [length - 1]]
input_ids = min_p_sample(logits, args.min_p, args.temp)

eos_id = tokenizer.eos_token_id
new_ids = np.ravel(input_ids[0]).tolist()
finished = eos_id in new_ids
generated = [] if finished else new_ids
print("Generated:")
printed_chars = _print_new_text(generated, 0)
decode_steps = 0
start = time.time()
for i in range(args.max_tokens):
    if finished:
        break
    # Feed the previously sampled token at absolute position length + i,
    # passing the same ``state`` that the prompt pass used.
    pred = mf_model_1.predict(
        {"input_ids": input_ids, "query_pos1": np.array([i + length], dtype=np.int32)},
        state,
    )
    decode_steps += 1
    input_ids = min_p_sample(pred["logits"], args.min_p, args.temp)
    new_ids = np.ravel(input_ids[0]).tolist()
    if eos_id in new_ids:
        # End of text: stop instead of generating past it.
        finished = True
        break
    generated.extend(new_ids)
    printed_chars = _print_new_text(generated, printed_chars)
_print_new_text(generated, printed_chars, final=True)
print("", "=" * 10)
generation_time = time.time() - start
print(
    "Prompt:",
    length / prompt_time,
    "tokens-per-sec",
    f"({PREFILL_LENGTH / prompt_time} considering the processed padding)",
)
# Rate over the decode steps actually run, which is fewer than --max-tokens
# when EOS ends generation early.
print("Generation:", decode_steps / generation_time, "tokens-per-sec")