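# app.py — Streamlit demo that sorts Finnish contract-specification lines into
# categories using a TurkuNLP BERT sequence classifier.
# Provenance (copied from a web file view): repository "test", author Nuthanon,
# commit a57d36b "Update app.py", original size 1.78 kB.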
import streamlit as st
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import torch.nn.functional as F
# Category labels, listed in the order of the classifier's output indices.
# "Urakka ei sisällä" uses the correct Finnish negative verb form
# (previously misspelled "Urakka ei sisältää").
categories = [
    "Urakka sisältää",
    "Urakka ei sisällä",
    "Tilaajan velvoitteet",
    "Käytäntöjen tarkennukset",
    "Hintojen tarkennukset",
    "Muu",
]

# Load the tokenizer and model
model_name = "TurkuNLP/bert-base-finnish-cased-v1"


@st.cache_resource
def _load_model(name, num_labels):
    """Load the tokenizer and sequence classifier for *name*, cached across reruns.

    Streamlit re-executes the whole script on every widget interaction; without
    caching, the BERT weights would be reloaded from disk each time.

    Args:
        name: Hugging Face model identifier or local path.
        num_labels: Number of output classes for the classification head.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = BertTokenizer.from_pretrained(name)
    # NOTE(review): this checkpoint appears to be a base (not fine-tuned) BERT;
    # if so, the classification head is freshly initialised and predictions are
    # arbitrary until the model is fine-tuned on labelled data — confirm.
    clf = BertForSequenceClassification.from_pretrained(name, num_labels=num_labels)
    clf.eval()  # inference only: make sure dropout is disabled
    return tok, clf


# num_labels is derived from the label list so the two can never drift apart.
tokenizer, model = _load_model(model_name, len(categories))
# Function to classify lines
def classify_lines(text):
lines = text.split("\n")
categorized_lines = {category: [] for category in categories}
for line in lines:
if line.strip(): # Skip empty lines
inputs = tokenizer(line, return_tensors="pt", padding=True, truncation=True, max_length=512)
outputs = model(**inputs)
probs = F.softmax(outputs.logits, dim=1)
predicted_category = torch.argmax(probs, dim=1).item()
categorized_lines[categories[predicted_category]].append(line)
return categorized_lines
st.title("Finnish Contract Specifications Categorizer with TurkuNLP BERT")
st.write("Enter the contract specifications in Finnish:")
# Text area for large text input
contract_text = st.text_area("Contract Specifications (Finnish):", height=300)
if st.button("Classify"):
if contract_text:
categories = classify_lines(contract_text)
st.write("Classified Contract Specifications:")
for category, lines in categories.items():
st.write(f"### {category}")
for line in lines:
st.write(f"- {line}")
else:
st.write("Please enter the contract specifications.")