# File: random-files / test-tok-sec.py (Hugging Face Hub)
# Author: reach-vb (HF staff)
# Commit: 98b0855 "Create test-tok-sec.py (#5)" (verified)
# Size: 1.69 kB; scan result: no virus
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm
# --- Benchmark configuration ------------------------------------------------
DEVICE = "cuda:1"  # CUDA device holding the model and the tokenized prompt
NUM_RUNS = 10  # number of timed generate() calls in measure_generate
MAX_NEW_TOKENS = 1000  # tokens requested from every generate() call
TEXT_INPUT = "def sieve_of_eratosthenes():"  # prompt given to the model

# --- Model, tokenizer and prompt --------------------------------------------
repo_id = "gg-hf/gemma-2-2b-it"
# Alternative loading path kept for reference (device_map="auto", bfloat16):
# model = AutoModelForCausalLM.from_pretrained(repo_id, device_map="auto", torch_dtype=torch.bfloat16)
model = AutoModelForCausalLM.from_pretrained(repo_id).to(DEVICE)
assistant_model = None  # defined but not passed to generate() in this script
tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
model_inputs = tokenizer(TEXT_INPUT, return_tensors="pt").to(DEVICE)

# Sampling settings shared by the warmup and the timed runs.
generate_kwargs = dict(
    max_new_tokens=MAX_NEW_TOKENS,
    do_sample=True,
    temperature=0.2,
    # Generated token ids are never -1, so EOS never triggers and every call
    # produces the full `max_new_tokens`.
    eos_token_id=-1,
)

# --- Warmup -----------------------------------------------------------------
# Two untimed generate() calls so one-off start-up costs stay out of the timing.
print("Warming up...")
for _ in range(2):
    gen_out = model.generate(**model_inputs, **generate_kwargs)
print("Done!")
# Measure generation throughput and peak GPU memory.
def measure_generate(model, model_inputs, generate_kwargs, num_runs=None, device=None):
    """Time repeated ``model.generate`` calls and print peak memory and throughput.

    Runs ``model.generate(**model_inputs, **generate_kwargs)`` ``num_runs``
    times back to back and times the whole loop with CUDA events.

    Args:
        model: Object exposing ``generate`` (in this script, a transformers
            causal LM already placed on ``device``).
        model_inputs: Tokenizer output already on ``device``; must contain
            ``"input_ids"``.
        generate_kwargs: Extra keyword arguments forwarded to ``generate``.
        num_runs: Number of timed calls; defaults to the module-level ``NUM_RUNS``.
        device: CUDA device the work runs on; defaults to the module-level
            ``DEVICE``. Timing events, synchronization and memory stats are
            all scoped to this device.

    Raises:
        ValueError: If ``num_runs`` is smaller than 1.
    """
    # Resolve defaults lazily so the module-level constants are read at call time.
    if num_runs is None:
        num_runs = NUM_RUNS
    if device is None:
        device = DEVICE
    if num_runs < 1:
        raise ValueError(f"num_runs must be >= 1, got {num_runs}")

    prompt_len = model_inputs["input_ids"].shape[-1]
    generated_tokens = 0

    # Without this, Event.record() and synchronize() target the *current*
    # device (cuda:0 by default), not the GPU that actually runs generate().
    with torch.cuda.device(device):
        start_event = torch.cuda.Event(enable_timing=True)
        end_event = torch.cuda.Event(enable_timing=True)
        torch.cuda.reset_peak_memory_stats(device)
        torch.cuda.empty_cache()
        torch.cuda.synchronize(device)

        start_event.record()
        for _ in tqdm(range(num_runs)):
            gen_out = model.generate(**model_inputs, **generate_kwargs)
            # With return_dict_in_generate=True the token ids live in `.sequences`.
            sequences = getattr(gen_out, "sequences", gen_out)
            # Count what was really produced (batch rows x new tokens) rather than
            # assuming MAX_NEW_TOKENS. Assumes a decoder-only model whose output
            # includes the prompt; shape reads are host-side and do not sync.
            generated_tokens += sequences.shape[0] * (sequences.shape[-1] - prompt_len)
        end_event.record()
        torch.cuda.synchronize(device)

        elapsed_s = start_event.elapsed_time(end_event) * 1.0e-3  # ms -> s

    max_memory = torch.cuda.max_memory_allocated(device)
    print("Max memory (MB): ", max_memory * 1e-6)
    print("Throughput (tokens/sec): ", generated_tokens / elapsed_s)
# Run the timed benchmark with the objects prepared above.
measure_generate(
    model=model,
    model_inputs=model_inputs,
    generate_kwargs=generate_kwargs,
)