from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from tqdm import tqdm

# Configuration options
DEVICE = "cuda:1"
NUM_RUNS = 10
MAX_NEW_TOKENS = 1000
TEXT_INPUT = "def sieve_of_eratosthenes():"

# Load the model and prepare generate args
repo_id = "gg-hf/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(repo_id).to(DEVICE)
# model = AutoModelForCausalLM.from_pretrained(repo_id, device_map="auto", torch_dtype=torch.bfloat16)
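# (The commented-out line above is an alternative load path: `device_map="auto"` shards the
# model across the available devices and `torch_dtype=torch.bfloat16` loads the weights in
# half precision, roughly halving memory use versus the default fp32 load.)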

assistant_model = None
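# Optional (an assumption, not exercised by this script as written): to benchmark assisted
# generation (speculative decoding), load a smaller draft model that shares this tokenizer,
# e.g.:
# assistant_model = AutoModelForCausalLM.from_pretrained("<draft_repo_id>").to(DEVICE)
# and pass it in the `generate_kwargs` dict below as `"assistant_model": assistant_model`.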
tokenizer = AutoTokenizer.from_pretrained(repo_id, use_fast=True)
model_inputs = tokenizer(TEXT_INPUT, return_tensors="pt").to(DEVICE)

generate_kwargs = {
    "max_new_tokens": MAX_NEW_TOKENS,
    "do_sample": True,
    "temperature": 0.2,
    "eos_token_id": -1  # forces the generation of `max_new_tokens`
}

# Warmup
print("Warming up...")
for _ in range(2):
    gen_out = model.generate(**model_inputs, **generate_kwargs)
print("Done!")


# Measure throughput and memory (a streaming alternative is sketched at the end of the file)
def measure_generate(model, model_inputs, generate_kwargs):
    # Make DEVICE the current CUDA device so the events and synchronization below refer to
    # the GPU that actually runs the model (otherwise they default to cuda:0).
    torch.cuda.set_device(DEVICE)
    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)
    torch.cuda.reset_peak_memory_stats(DEVICE)
    torch.cuda.empty_cache()
    torch.cuda.synchronize(DEVICE)

    start_event.record()
    for _ in tqdm(range(NUM_RUNS)):
        gen_out = model.generate(**model_inputs, **generate_kwargs)
    end_event.record()

    torch.cuda.synchronize(DEVICE)
    max_memory = torch.cuda.max_memory_allocated(DEVICE)
    print("Max memory (MB): ", max_memory * 1e-6)
    # elapsed_time returns milliseconds; convert to seconds for tokens/sec
    print("Throughput (tokens/sec): ", (NUM_RUNS * MAX_NEW_TOKENS) / (start_event.elapsed_time(end_event) * 1.0e-3))

measure_generate(model, model_inputs, generate_kwargs)
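
# Streaming alternative (a sketch, not part of the measured benchmark above): print tokens to
# stdout as they are generated, using transformers' TextStreamer. Uncomment the last line to
# stream instead of (or in addition to) measuring.
def stream_generate(model, model_inputs, generate_kwargs):
    from transformers import TextStreamer

    streamer = TextStreamer(tokenizer, skip_prompt=True)
    model.generate(**model_inputs, **generate_kwargs, streamer=streamer)

# stream_generate(model, model_inputs, generate_kwargs)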