from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm
import os

os.environ["TOKENIZERS_PARALLELISM"] = "false"  # To prevent long warnings :)
torch.set_float32_matmul_precision('high')

# Other configuration options
DEVICE = "cuda:1"
NUM_RUNS = 10
MAX_NEW_TOKENS = 256
TEXT_INPUT = "def sieve_of_eratosthenes():"

# Load the model and prepare generate args
repo_id = "gg-hf/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(repo_id).to(DEVICE)

model.generation_config.cache_implementation = "static"
model.forward = torch.compile(model.forward, mode="reduce-overhead", fullgraph=True)

assistant_model = None

tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
model_inputs = tokenizer(TEXT_INPUT, return_tensors="pt").to(DEVICE)
generate_kwargs = {
    "max_new_tokens": MAX_NEW_TOKENS,
    "do_sample": True,
    "temperature": 0.2,
    "eos_token_id": -1  # forces the generation of `max_new_tokens`
}

# Warmup
print("Warming up...")
for _ in range(2):
    gen_out = model.generate(**model_inputs, **generate_kwargs)
print("Done!")

# Measure OR Stream
def measure_generate(model, model_inputs, generate_kwargs):
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.reset_peak_memory_stats(DEVICE)
    torch.cuda.empty_cache()
    torch.cuda.synchronize()

    start_event.record()
    for _ in tqdm(range(NUM_RUNS)):
        gen_out = model.generate(**model_inputs, **generate_kwargs)
    end_event.record()
    torch.cuda.synchronize()

    max_memory = torch.cuda.max_memory_allocated(DEVICE)
    print("Max memory (MB): ", max_memory * 1e-6)
    print("Throughput (tokens/sec): ", (NUM_RUNS * MAX_NEW_TOKENS) / (start_event.elapsed_time(end_event) * 1.0e-3))

measure_generate(model, model_inputs, generate_kwargs)
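
# Optional: a minimal sketch of the "Stream" alternative mentioned above. It prints
# tokens to stdout as they are generated instead of measuring throughput; call it in
# place of `measure_generate` to inspect the output interactively. `TextStreamer` is
# part of `transformers`; `skip_prompt=True` is assumed here to hide the echoed prompt.
# Timings from a streamed run are not comparable to the `measure_generate` numbers.
from transformers import TextStreamer

def stream_generate(model, model_inputs, generate_kwargs):
    streamer = TextStreamer(tokenizer, skip_prompt=True)
    model.generate(**model_inputs, streamer=streamer, **generate_kwargs)

# stream_generate(model, model_inputs, generate_kwargs)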