import torch
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained('juridics/bertimbaulaw-base-portuguese-sts-scale')
model = BertForSequenceClassification.from_pretrained('juridics/bertimbaulaw-base-portuguese-sts-scale')
model.eval()

def generate_answers(query):
    # Tokenize the query, padding/truncating to BERT's 512-token limit.
    inputs = tokenizer(query, return_tensors='pt', padding='max_length', truncation=True, max_length=512)
    # BertForSequenceClassification is an encoder-only classifier: it returns
    # logits, not text, so model.generate() (and sampling options such as
    # temperature, top_p, and no_repeat_ngram_size) is not supported here.
    with torch.no_grad():
        outputs = model(**inputs)
    # Return the raw logits; this STS checkpoint scores sentence similarity,
    # typically over a pair encoded as tokenizer(sent_a, sent_b, ...).
    return outputs.logits
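
# If free-text answers are the actual goal, an encoder-only BERT checkpoint
# cannot produce them; a causal language model is required. Below is a minimal
# sketch under that assumption: 'pierreguillou/gpt2-small-portuguese' is one
# publicly available Portuguese GPT-2 checkpoint, used here purely for
# illustration, and generate_answers_lm is a hypothetical helper.
from transformers import AutoTokenizer, AutoModelForCausalLM

gen_tokenizer = AutoTokenizer.from_pretrained('pierreguillou/gpt2-small-portuguese')
gen_model = AutoModelForCausalLM.from_pretrained('pierreguillou/gpt2-small-portuguese')

def generate_answers_lm(query):
    # No max_length padding here; generation simply continues from the prompt.
    inputs = gen_tokenizer(query, return_tensors='pt')
    generated_ids = gen_model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_length=inputs['input_ids'].shape[1] + 100,  # allow roughly 100 new tokens
        do_sample=True,          # sampling must be enabled for temperature/top_p to apply
        temperature=0.7,         # adjust creativity
        top_p=0.9,               # nucleus sampling
        no_repeat_ngram_size=2,  # avoid unnecessary repetition
    )
    return gen_tokenizer.decode(generated_ids[0], skip_special_tokens=True)

# Example calls (the queries are hypothetical):
# print(generate_answers('O contrato pode ser rescindido sem multa?'))    # similarity logits
# print(generate_answers_lm('Quais são os requisitos da usucapião?'))     # generated text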