---
library_name: transformers
license: gpl
datasets:
- DenyTranDFW/SEC_10K_FSNoNDS_Zip
language:
- en
base_model:
- openai-community/gpt2
---
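
# Model Card for gpt2-next-tag-prediction

This model is based on `openai-community/gpt2` and was trained on the `DenyTranDFW/SEC_10K_FSNoNDS_Zip` dataset to predict the next tag given a preceding tag (for example, `AssetsCurrent`). Example usage:
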
import torch
from transformers import AutoModelForCausalLM, GPT2Tokenizer

# Load the model and put it in inference mode (disables dropout).
model = AutoModelForCausalLM.from_pretrained('DenyTranDFW/gpt2-next-tag-prediction')
model.eval()

# NOTE: the tokenizer was not uploaded with this model (see the training code:
# https://www.kaggle.com/code/denytran/gpt2-next-tag-prediction-train/settings),
# so use the base GPT-2 tokenizer instead.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')

prompt = "AssetsCurrent"
# Encode the prompt into a batch of one sequence of token ids.
input_ids = tokenizer.encode(prompt, return_tensors='pt')

# Forward pass only; gradients are not needed for inference.
with torch.no_grad():
    outputs = model(input_ids)

# Keep only the logits for the position after the last prompt token.
predictions = outputs.logits[:, -1, :]

# Greedy choice: take the single most likely next token. A GPT-2 BPE token can
# be just a fragment of a tag name, so the decoded text may be a partial tag.
predicted_index = predictions.argmax(dim=-1).item()
predicted_word = tokenizer.decode([predicted_index])

print(f"Prompt: {prompt}")
print(f"Predicted next word: {predicted_word}")

![Example output of the prediction script (predict.png)](https://cdn-uploads.huggingface.co/production/uploads/664fd752780b951aa03abade/q2_BnitOyncxHvs7-fERT.png)