Spaces:
Runtime error
Runtime error
import gradio as gr | |
import torch | |
from transformers import BertForSequenceClassification, BertTokenizer | |
# Tokenizer: fetched by name from the Hugging Face Hub.
token_model = "indolem/indobertweet-base-uncased"
tokenizer = BertTokenizer.from_pretrained(token_model)

# Location passed to from_pretrained() for the fine-tuned classifier.
# Per the original author's note it must contain config.json and
# pytorch_model.bin; if the weights file has another name, rename it or
# change how the model is loaded.
model_directory = "pretrained_arief.model"

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Load the sequence-classification model, place it on the chosen device and
# switch it to evaluation mode.
model = BertForSequenceClassification.from_pretrained(model_directory)
model.to(device)
model.eval()
def classify_transaction(notes):
    """Classify a transaction note and report the predicted class index.

    Args:
        notes: Free-text transaction notes. ``None`` and blank or
            whitespace-only text are treated as "nothing to classify".

    Returns:
        ``"Predicted Category: <index>"`` where ``<index>`` is the argmax of
        the model's logits as a plain integer, or an empty string when there
        is no text to classify.
    """
    # The interface is live, so this runs on every change of the textbox,
    # including when it has just been cleared; skip inference in that case.
    if notes is None or not notes.strip():
        return ""

    # Tokenize the single input. No padding is requested: one sequence needs
    # none, and padding to 256 tokens only made every forward pass slower.
    encoded = tokenizer(
        notes,
        add_special_tokens=True,
        max_length=256,
        truncation=True,
        return_token_type_ids=False,
        return_attention_mask=True,
        return_tensors='pt',
    )

    # Move tensors to the same device as the model.
    input_ids = encoded['input_ids'].to(device)
    attention_mask = encoded['attention_mask'].to(device)

    # The model is put in evaluation mode once at load time, so it is not
    # toggled again here. Gradients are not needed for inference.
    with torch.no_grad():
        outputs = model(input_ids, attention_mask=attention_mask)

    # The first output holds the logits. Softmax is monotonic, so taking the
    # argmax of the raw logits selects the same class without the extra op.
    logits = outputs[0]

    # .item() yields a Python int so the message reads "3" rather than the
    # NumPy array repr "[3]".
    predicted_class = int(torch.argmax(logits, dim=1).item())

    return f"Predicted Category: {predicted_class}"
# Build the Gradio UI: a multi-line textbox for the notes and a text field
# for the classification result.
notes_box = gr.Textbox(
    label="Transaction Notes",
    placeholder="Enter Transaction Notes Here",
    lines=3,
)
result_box = gr.Text(label="Classification Result")

iface = gr.Interface(
    fn=classify_transaction,
    inputs=notes_box,
    outputs=result_box,
    title="Transaction Category Classifier",
    description="Enter transaction notes to get the predicted category.",
    # Live mode: the output is refreshed whenever the input changes.
    live=True,
)

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()