import streamlit as st
import torch
import matplotlib.pyplot as plt
from transformers import BertForSequenceClassification, BertTokenizer

# Load the fine-tuned Dutch news classification model and its tokenizer
model = BertForSequenceClassification.from_pretrained("RuudVelo/dutch_news_clf_bert_finetuned")
tokenizer = BertTokenizer.from_pretrained("RuudVelo/dutch_news_clf_bert_finetuned")

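# Optional (not part of the original app): a minimal sketch of how the model load
# above could be cached with st.cache_resource (available in recent Streamlit
# versions), so the weights are not reloaded on every rerun. To use it, replace the
# two from_pretrained calls above with `model, tokenizer = load_model_and_tokenizer()`.
@st.cache_resource
def load_model_and_tokenizer(name="RuudVelo/dutch_news_clf_bert_finetuned"):
    cached_model = BertForSequenceClassification.from_pretrained(name)
    cached_tokenizer = BertTokenizer.from_pretrained(name)
    return cached_model, cached_tokenizer
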
# Title
st.title("Dutch news article classification")

text = st.text_area('Please type or paste the text of a Dutch news article')

if st.button('Submit'):
    with st.spinner('Generating a response...'):
        # Tokenize the article text and run it through the classifier
        encoding = tokenizer(text, return_tensors="pt")
        outputs = model(**encoding)
        predictions = outputs.logits.argmax(-1)  # index of the most likely category
        probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Plot the probability per news category as a horizontal bar chart
        fig = plt.figure()
        ax = fig.add_axes([0, 0, 1, 1])
        labels_plot = ['Binnenland', 'Buitenland', 'Cultuur & Media', 'Economie',
                       'Koningshuis', 'Opmerkelijk', 'Politiek', 'Regionaal nieuws', 'Tech']
        probs_plot = probabilities[0].cpu().detach().numpy()

        ax.barh(labels_plot, probs_plot)
        ax.set_title("Predicted article category probability")
        ax.set_xlabel("Probability")
        ax.set_ylabel("Predicted category")
        st.pyplot(fig)
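
# Optional (not part of the original app): a small helper sketch for surfacing the
# top prediction as text next to the chart, e.g. by calling
# `st.write(top_category(probs_plot, labels_plot))` inside the Submit block above.
# It assumes the order of labels_plot matches the model's output indices, which is
# the same assumption the bar chart already makes.
def top_category(probs, labels):
    # Pick the highest-probability class and format it with its probability
    best = int(probs.argmax())
    return f"Most likely category: {labels[best]} ({probs[best]:.1%})"
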
st.write("The model for this app has been trained using data from Dutch news articles published by NOS. More information regarding the dataset can be found at https://www.kaggle.com/maxscheijen/dutch-news-articles")
st.write('\n')
st.write('The model performance details can be found at https://huggingface.co/RuudVelo/dutch_news_classifier_bert_finetuned')