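# Standalone inference script for Llama 3.2 (1B or 3B): downloads the weights
# and tokenizer from the Hugging Face Hub on first run, then generates text
# for a fixed prompt using the local model.py and tokenizer.py modules.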
import os
import time
import urllib.request

import torch

from model import Llama3Model, generate, text_to_token_ids, token_ids_to_text
from tokenizer import Llama3Tokenizer, ChatFormat, clean_text

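# Which checkpoint to run; "instruct" in the file name selects the
# instruction-tuned variant and switches the prompt and output handling below,
# while "1B"/"3B" selects the matching model configuration.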
MODEL_FILE = "llama3.2-1B-instruct.pth"

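# Context window to configure the model with; smaller values reduce the memory footprint.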
MODEL_CONTEXT_LENGTH = 8192

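# Instruction-tuned models get a question; base models get a plain continuation prompt.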
| | if "instruct" in MODEL_FILE: |
| | PROMPT = "What do llamas eat?" |
| | else: |
| | PROMPT = "Llamas eat" |
| |
|
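# TEMPERATURE = 0.0 together with TOP_K = 1 yields greedy (deterministic) decoding.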
MAX_NEW_TOKENS = 150
TEMPERATURE = 0.0
TOP_K = 1


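# Download the model weights on first run.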
| | url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{MODEL_FILE}" |
| |
|
if not os.path.exists(MODEL_FILE):
    print(f"Downloading {MODEL_FILE}...")
    urllib.request.urlretrieve(url, MODEL_FILE)
    print(f"Downloaded to {MODEL_FILE}")

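# Pick the model configuration that matches the checkpoint file name.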
| | if "1B" in MODEL_FILE: |
| | from model import LLAMA32_CONFIG_1B as LLAMA32_CONFIG |
| | elif "3B" in MODEL_FILE: |
| | from model import LLAMA32_CONFIG_3B as LLAMA32_CONFIG |
| | else: |
| | raise ValueError("Incorrect model file name") |
| |
|
| | LLAMA32_CONFIG["context_length"] = MODEL_CONTEXT_LENGTH |
| |
|
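# Instantiate the model and load the pretrained weights;
# weights_only=True avoids unpickling arbitrary objects.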
model = Llama3Model(LLAMA32_CONFIG)
model.load_state_dict(torch.load(MODEL_FILE, weights_only=True))

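# Prefer CUDA, then Apple Silicon (MPS), then fall back to the CPU.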
device = (
    torch.device("cuda") if torch.cuda.is_available() else
    torch.device("mps") if torch.backends.mps.is_available() else
    torch.device("cpu")
)
model.to(device)


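# Download the tokenizer on first run.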
TOKENIZER_FILE = "tokenizer.model"

url = f"https://huggingface.co/rasbt/llama-3.2-from-scratch/resolve/main/{TOKENIZER_FILE}"

if not os.path.exists(TOKENIZER_FILE):
    urllib.request.urlretrieve(url, TOKENIZER_FILE)
    print(f"Downloaded to {TOKENIZER_FILE}")

tokenizer = Llama3Tokenizer(TOKENIZER_FILE)

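# Instruct checkpoints expect prompts wrapped in the Llama 3 chat format.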
| | if "instruct" in MODEL_FILE: |
| | tokenizer = ChatFormat(tokenizer) |
| |
|
| | |
| | |
| | |
| |
|
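# Fix the random seed for reproducibility (relevant when TEMPERATURE > 0).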
torch.manual_seed(123)

start = time.time()

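# Generate up to MAX_NEW_TOKENS new tokens for the prompt.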
token_ids = generate(
    model=model,
    idx=text_to_token_ids(PROMPT, tokenizer).to(device),
    max_new_tokens=MAX_NEW_TOKENS,
    context_size=LLAMA32_CONFIG["context_length"],
    top_k=TOP_K,
    temperature=TEMPERATURE
)

| | print(f"Time: {time.time() - start:.2f} sec") |
| |
|
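# Report peak GPU memory usage (only meaningful on CUDA devices).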
if torch.cuda.is_available():
    max_mem_bytes = torch.cuda.max_memory_allocated()
    max_mem_gb = max_mem_bytes / (1024 ** 3)
    print(f"Max memory allocated: {max_mem_gb:.2f} GB")

output_text = token_ids_to_text(token_ids, tokenizer)

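# For instruct models, remove chat-format special tokens from the decoded output.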
| | if "instruct" in MODEL_FILE: |
| | output_text = clean_text(output_text) |
| |
|
| | print("\n\nOutput text:\n\n", output_text) |
| |
|