llm_test / src /brain.py
fschwartzer's picture
Update src/brain.py
745a539 verified
raw
history blame
408 Bytes
import os

from transformers import AutoTokenizer, AutoModelForCausalLM
# Hugging Face Hub checkpoint used for generation.
model_name = "google/gemma-2b"

# Hugging Face access token, read from the HF_TOKEN environment variable
# (the checkpoint may require authentication). When the variable is unset this
# is None, and transformers falls back to any token saved by
# `huggingface-cli login`.
access_token = os.environ.get("HF_TOKEN")

# Loaded once at import time and shared by generate_answers().
tokenizer = AutoTokenizer.from_pretrained(model_name, token=access_token)
model = AutoModelForCausalLM.from_pretrained(model_name, token=access_token)
def generate_answers(
    query: str,
    *,
    max_new_tokens: int = 256,
    skip_special_tokens: bool = True,
) -> str:
    """Generate a text continuation of *query* with the module-level model.

    Args:
        query: Prompt text passed to the tokenizer.
        max_new_tokens: Upper bound on the number of tokens generated after
            the prompt. Without an explicit limit, ``generate`` falls back to
            the generation config's ``max_length`` (20 tokens including the
            prompt, unless the checkpoint overrides it), which truncates
            answers and leaves no room for new tokens on long prompts.
        skip_special_tokens: Drop special tokens (e.g. BOS/EOS markers) from
            the decoded text.

    Returns:
        The decoded first generated sequence. For a causal LM this is the
        prompt followed by the generated continuation.
    """
    # The tokenizer returns a BatchEncoding (input_ids + attention_mask), not
    # bare ids. Move it to the model's device so this still works if the model
    # is placed on an accelerator.
    inputs = tokenizer(query, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output[0], skip_special_tokens=skip_special_tokens)