test3 / app.py
bistdude's picture
Create app.py
fbd5a07 verified
raw
history blame contribute delete
717 Bytes
import torch
from transformers import AutoModelForCausalLM, AutoModelForSequenceClassification, AutoTokenizer
torch_device = "cuda" if torch.cuda.is_available() else "cpu"

# DistilRoBERTa fine-tuned for financial-news sentiment. It is a
# sequence-classification checkpoint, so it must be loaded with
# AutoModelForSequenceClassification. Loading it as a causal LM would drop the
# trained classification head and attach an untrained language-model head.
model_name = "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis"


def logits_to_labels(logits, id2label):
    """Map sequence-classification logits to label names.

    Args:
        logits: Tensor of shape (batch, num_labels), e.g. the ``.logits`` of a
            sequence-classification model output.
        id2label: Mapping from class index to label name, typically
            ``model.config.id2label``.

    Returns:
        A list with one label name per row of ``logits``.
    """
    # Reduce over the class dimension (last axis). The original reduced over
    # axis 1, which for causal-LM logits (batch, seq, vocab) is the sequence
    # axis and yields meaningless ids.
    class_ids = logits.argmax(dim=-1)
    return [id2label[int(class_id)] for class_id in class_ids]


def classify(text, model, tokenizer, device=torch_device):
    """Return the predicted label name for a single piece of text.

    Args:
        text: Input string to classify.
        model: A transformers sequence-classification model already on ``device``.
        tokenizer: The tokenizer matching ``model``.
        device: Device the input tensors are moved to.
    """
    inputs = tokenizer(text, return_tensors="pt").to(device)
    # Pure inference: skip autograd bookkeeping.
    with torch.inference_mode():
        logits = model(**inputs).logits
    return logits_to_labels(logits, model.config.id2label)[0]


def main():
    """Load the sentiment classifier and print the label for a sample sentence."""
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # No pad_token_id override: that was a GPT-2 leftover, and RoBERTa
    # tokenizers define their own pad token.
    model = AutoModelForSequenceClassification.from_pretrained(model_name).to(
        torch_device
    )
    print(classify("bad boy you ", model, tokenizer))


if __name__ == "__main__":
    main()