|
import argparse |
|
import logging |
|
|
|
import numpy as np |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
|
|
# Configure root logging for the whole script: INFO and above, with a
# month/day/year hour:minute:second timestamp on every record.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Logger for this module; records propagate to the root handler configured above.
logger = logging.getLogger(__name__)
|
|
|
|
|
# Hugging Face Hub checkpoint used for both the tokenizer and the model weights.
model_id = "Norod78/TinyStories-3M-val-Hebrew"

# Load the tokenizer first, then the causal-LM model (tuple elements are
# evaluated left to right), both from the same checkpoint.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_id),
    AutoModelForCausalLM.from_pretrained(model_id),
)
|
|
|
|
|
|
|
|
|
|
|
# Generation settings used by the code below.
prompt_text = "\n"            # prompt fed to the model: a single newline
stop_token = "<|endoftext|>"  # decoded samples are truncated at this marker
new_lines = "\n" * 3          # ...and at the first run of three newlines
seed = 1000                   # seed for the NumPy / PyTorch RNGs
|
|
|
# Pick the compute device: CUDA when available, otherwise the CPU.
# Query availability once instead of comparing the call's result to False.
cuda_available = torch.cuda.is_available()
device = torch.device("cuda" if cuda_available else "cpu")
n_gpu = torch.cuda.device_count() if cuda_available else 0

logger.info("device: %s, n_gpu: %s", device, n_gpu)

# Seed the NumPy and PyTorch RNGs (and every visible CUDA device, if any)
# so that sampling is repeatable across runs.
np.random.seed(seed)
torch.manual_seed(seed)
if n_gpu > 0:
    torch.cuda.manual_seed_all(seed)

# Move the model's parameters to the selected device.
model.to(device)
|
|
|
|
|
def clean_generated_text(text, stop_token, new_lines):
    """Post-process one decoded sample for display.

    Removes ``<|startoftext|>`` markers and turns the `` ; `` separator into a
    newline. It then cuts the text just before the first *stop_token* and,
    after that, just before the first *new_lines* run. A falsy marker (empty
    string or ``None``), or a marker that does not occur, leaves the text
    unchanged.

    Args:
        text: Decoded generation.
        stop_token: Marker at which to truncate; may be empty or ``None``.
        new_lines: Second truncation marker; may be empty or ``None``.

    Returns:
        The cleaned text.
    """
    text = text.replace("<|startoftext|>", "").replace(" ; ", "\n")
    for marker in (stop_token, new_lines):
        if marker:
            # partition() keeps the whole text when the marker is absent.
            # The previous text[:text.find(marker)] sliced with -1 in that
            # case and silently dropped the last character.
            text = text.partition(marker)[0]
    return text


def process_output_sequences(output_sequences):
    """Decode, clean and print every generated sequence.

    Uses the module-level ``tokenizer``, ``stop_token`` and ``new_lines``.

    Args:
        output_sequences: Tensor of generated token ids. Each row is one
            sequence; any extra leading dimensions are flattened into rows.

    Returns:
        list[str]: The cleaned texts, in the order they were printed.
    """
    # Flatten extra leading dims to (num_sequences, seq_len) without mutating
    # the caller's tensor. An in-place squeeze_() would modify the input and
    # would collapse a single sequence into a 1-D tensor of token ids.
    if output_sequences.dim() > 2:
        output_sequences = output_sequences.reshape(-1, output_sequences.shape[-1])

    texts = []
    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
        decoded = tokenizer.decode(
            generated_sequence.tolist(), clean_up_tokenization_spaces=True
        )
        text = clean_generated_text(decoded, stop_token, new_lines)
        print(text)
        texts.append(text)

    # NOTE(review): separator printed once after the whole batch; confirm it
    # was not meant to follow each sequence.
    print("------")
    return texts
|
|
|
|
|
def encode_prompt(text):
    """Tokenize *text* and place the resulting ids on the module-level ``device``.

    Special tokens are added and a PyTorch tensor is requested from the
    tokenizer (presumably shaped ``(1, seq_len)`` — confirm).

    Returns:
        The id tensor, or ``None`` when *text* encodes to zero tokens.
    """
    ids = tokenizer.encode(text, add_special_tokens=True, return_tensors="pt").to(device)
    # An empty encoding is reported as None rather than as a zero-length tensor.
    return None if ids.size(-1) == 0 else ids
|
|
|
# Encode the prompt. encode_prompt returns None when the prompt yields no
# tokens; count that as a zero-length prompt instead of crashing on None.size().
input_ids = encode_prompt(prompt_text)
input_ids_len = 0 if input_ids is None else input_ids.size()[-1]

# Allow up to 192 tokens beyond the prompt, capped at 1023 positions in total.
# The cap is hard-coded and presumably keeps generation within the model's
# context window — confirm.
max_length = min(input_ids_len + 192, 1023)

# Sample five continuations with temperature / top-k / nucleus sampling and a
# strong repetition penalty.
output_sequences = model.generate(
    input_ids=input_ids,
    max_length=max_length,
    temperature=0.98,
    top_k=40,
    top_p=0.92,
    repetition_penalty=2.0,
    do_sample=True,
    num_return_sequences=5,
)

process_output_sequences(output_sequences)
|
|
|
|
|
|