|
import streamlit as st |
|
from transformers import BertTokenizer, BertForSequenceClassification |
|
import torch |
|
import torch.nn.functional as F |
|
|
|
|
|
model_name = "TurkuNLP/bert-base-finnish-cased-v1"


@st.cache_resource
def _load_model(name):
    """Load the BERT tokenizer and sequence-classification model for *name*.

    Streamlit re-executes the whole script on every widget interaction;
    ``st.cache_resource`` keeps a single tokenizer/model pair alive across
    reruns instead of reloading the checkpoint each time.

    Returns:
        (tokenizer, model) tuple.
    """
    tok = BertTokenizer.from_pretrained(name)
    # num_labels must equal the number of entries in ``categories``
    # (label id i is mapped to categories[i]).
    # NOTE(review): this checkpoint appears to be a base language model, so the
    # 6-way classification head is freshly (randomly) initialised and its
    # predictions are not meaningful until a fine-tuned checkpoint is used —
    # confirm.
    mdl = BertForSequenceClassification.from_pretrained(name, num_labels=6)
    mdl.eval()  # inference only: make sure dropout is disabled
    return tok, mdl


tokenizer, model = _load_model(model_name)
|
|
|
|
|
# Category labels shown to the user. The list index is the model's label id,
# so the order must match the classifier head (len must equal num_labels).
categories = [
    "Urakka sisältää",           # the contract includes
    "Urakka ei sisällä",         # the contract does not include
    "Tilaajan velvoitteet",      # client's obligations
    "Käytäntöjen tarkennukset",  # clarifications of practices
    "Hintojen tarkennukset",     # clarifications of prices
    "Muu",                       # other
]
|
|
|
|
|
def classify_lines(text, batch_size=32):
    """Assign each non-blank line of *text* to one of ``categories``.

    Every non-blank line is tokenized (truncated to 512 tokens) and scored by
    the module-level ``model``. The highest-scoring label id selects the entry
    of ``categories`` that the line is filed under.

    Args:
        text: Multi-line input. Blank and whitespace-only lines are skipped.
        batch_size: Number of lines sent to the model per forward pass.

    Returns:
        A dict mapping every category name (in ``categories`` order) to the
        list of lines predicted for it, in input order.

    Raises:
        ValueError: If ``batch_size`` is less than 1.
    """
    if batch_size < 1:
        raise ValueError(f"batch_size must be >= 1, got {batch_size}")

    categorized_lines = {category: [] for category in categories}
    # splitlines() also handles "\r\n" input, where split("\n") would leave a
    # trailing "\r" on every line.
    lines = [line for line in text.splitlines() if line.strip()]

    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        for start in range(0, len(lines), batch_size):
            batch = lines[start:start + batch_size]
            inputs = tokenizer(
                batch,
                return_tensors="pt",
                padding=True,  # pad to the longest line in the batch
                truncation=True,
                max_length=512,
            )
            logits = model(**inputs).logits
            # softmax is monotonic, so argmax over raw logits picks the same
            # class without normalising.
            label_ids = logits.argmax(dim=-1).tolist()
            for line, label_id in zip(batch, label_ids):
                categorized_lines[categories[label_id]].append(line)

    return categorized_lines
|
|
|
# Page layout. Streamlit re-executes this script on each widget interaction,
# so everything below runs again on every edit or click.
st.title("Finnish Contract Specifications Categorizer with TurkuNLP BERT")

st.write("Enter the contract specifications in Finnish:")

contract_text = st.text_area("Contract Specifications (Finnish):", height=300)

if st.button("Classify"):
    # Whitespace-only input has no classifiable lines; treat it as empty.
    if contract_text.strip():
        # Use a distinct name: rebinding the module-level ``categories`` list
        # here would shadow the label list that classify_lines() indexes into.
        categorized = classify_lines(contract_text)

        st.write("Classified Contract Specifications:")

        for category, lines in categorized.items():
            st.write(f"### {category}")
            for line in lines:
                st.write(f"- {line}")
    else:
        st.write("Please enter the contract specifications.")
|
|