# TinyStories-3M-val-Hebrew / TinyStories-3M-val-Hebrew-inference.py
# Uploaded by Norod78 (commit a5f152a, 2.91 kB).
import argparse
import logging
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
# --- Logging ---------------------------------------------------------------
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
logger = logging.getLogger(__name__)

# --- Model and tokenizer ---------------------------------------------------
# Both are loaded from the same identifier. Disabled alternatives kept for
# reference: a local checkpoint directory, and a TensorFlow checkpoint.
#model_id = "./TinyStories-3M-val-Hebrew"
model_id = "Norod78/TinyStories-3M-val-Hebrew"
tokenizer = AutoTokenizer.from_pretrained(model_id)
#model = AutoModelForCausalLM.from_pretrained("./Hebrew_GPT3_XL", from_tf=True)
model = AutoModelForCausalLM.from_pretrained(model_id)

# --- Prompt and post-processing markers ------------------------------------
# Disabled sample prompts (Hebrew, with English glosses); uncomment one to
# use it instead of the bare-newline prompt below.
#prompt_text = "אתמול, בדרך הביתה, גיליתי ש"  # "Yesterday, on the way home, I discovered that..."
#prompt_text = "פעם, לפני ש"  # "Once, before..."
#prompt_text = "הסוד השמור ביותר של תעשיית היופי"  # "The best-kept secret of the beauty industry"
#prompt_text = "<|startoftext|>"
prompt_text = "\n"
# Decoded text is cut at the first occurrence of each of these markers.
stop_token = "<|endoftext|>"
new_lines = "\n\n\n"

# --- Device selection and seeding ------------------------------------------
seed = 1000
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_gpu = torch.cuda.device_count() if torch.cuda.is_available() else 0
logger.info(f"device: {device}, n_gpu: {n_gpu}")

# Seed NumPy and torch (plus every CUDA device when any is present).
np.random.seed(seed)
torch.manual_seed(seed)
if n_gpu > 0:
    torch.cuda.manual_seed_all(seed)

# Move the model weights to the selected device.
model.to(device)
#model.half()
def process_output_sequences(output_sequences):
    """Decode every generated sequence, trim it, and print it to stdout.

    Each row of ``output_sequences`` is decoded with the module-level
    ``tokenizer``. The literal ``<|startoftext|>`` is removed, `` ; `` is
    turned into a newline, and the text is cut at the first occurrence of
    ``stop_token`` and then of ``new_lines`` (when those are non-empty).

    Args:
        output_sequences: Tensor of generated token ids. It is iterated row
            by row and each row is converted with ``.tolist()``. A tensor
            with more than two dimensions is squeezed in place first.

    Returns:
        None. The results are only printed.
    """
    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")

        # Decode text
        text = tokenizer.decode(
            generated_sequence.tolist(), clean_up_tokenization_spaces=True
        )
        text = text.replace("<|startoftext|>", "").replace(" ; ", "\n")

        # Remove all text after the stop token, then after 3 newlines.
        # partition() returns the whole text when the marker is absent; the
        # previous ``text[: text.find(marker)]`` sliced with -1 in that case
        # and silently dropped the last character.
        for marker in (stop_token, new_lines):
            if marker:
                text = text.partition(marker)[0]

        print(text)

    print("------")
def encode_prompt(text):
    """Tokenize ``text`` and return its token ids on the module-level ``device``.

    Args:
        text: Prompt string to encode with the module-level ``tokenizer``.

    Returns:
        The encoded prompt tensor, or ``None`` when its last dimension is
        empty (the prompt produced no tokens).
    """
    token_ids = tokenizer.encode(
        text, add_special_tokens=True, return_tensors="pt"
    ).to(device)

    # An empty encoding is reported as None rather than as an empty tensor.
    has_tokens = token_ids.size()[-1] != 0
    return token_ids if has_tokens else None
# Encode the prompt. encode_prompt() returns None when the prompt yields no
# tokens, so treat that case as a prompt of length 0 instead of calling
# .size() on None (which raised AttributeError).
input_ids = encode_prompt(prompt_text)
input_ids_len = 0 if input_ids is None else input_ids.size()[-1]

# Budget 192 positions beyond the prompt, with a hard cap of 1023 overall.
max_length = min(input_ids_len + 192, 1023)

# Sample five continuations of the prompt.
output_sequences = model.generate(
    input_ids=input_ids,
    max_length=max_length,
    temperature=0.98,
    top_k=40,
    top_p=0.92,
    repetition_penalty=2.0,
    do_sample=True,
    num_return_sequences=5,
)

# Decode, trim and print every sampled sequence.
process_output_sequences(output_sequences)