File size: 1,847 Bytes
422dfc7
176fefc
422dfc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8067faa
35494e1
422dfc7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import streamlit as st
from annotated_text import annotated_text
import transformers

# Hex colour associated with each entity-group label. parse_text() looks the
# label up here and hands the colour to annotated_text() together with the
# matched span, so every label the model can emit should have an entry.
ENTITY_TO_COLOR = {
    "DepositProduct": "#edff87",
    "Product": "#d586ff",
    "ProductProblemInfo": "#9886ff",
    "ServiceInformation": "#ff9886",
    "ServiceClosest": "#ff86b0",
    "Location": "#d461be",
    "ServiceNumber": "#f9cde4",
    "Brand": "#ffd4a4",
    "Campaign": "#bcffd8",
    "ProductSelector": "#fb5d4e",
    "SpecialCampaign": "#f56286",
}

# ``st.cache`` is deprecated in newer Streamlit releases in favour of
# ``st.cache_resource`` for global, unserializable objects such as models.
# Prefer the new decorator when it exists and fall back to the legacy one so
# the app keeps working on older Streamlit versions.
if hasattr(st, "cache_resource"):
    _cache_pipeline = st.cache_resource(show_spinner=False)
else:
    _cache_pipeline = st.cache(allow_output_mutation=True, show_spinner=False)


@_cache_pipeline
def get_pipe():
    """Build the token-classification pipeline used by the app.

    The result is cached by Streamlit through the decorator above, so the
    model and tokenizer are not re-created on every script run.

    Returns:
        A ``transformers`` "token-classification" pipeline built from the
        "pnr-svc/distilbert-turkish-ner" model and tokenizer, configured with
        ``aggregation_strategy="simple"``.
    """
    model_name = "pnr-svc/distilbert-turkish-ner"
    model = transformers.AutoModelForTokenClassification.from_pretrained(model_name)
    tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
    return transformers.pipeline(
        "token-classification",
        model=model,
        tokenizer=tokenizer,
        aggregation_strategy="simple",
    )

# Neutral highlight used for entity groups that have no entry in
# ENTITY_TO_COLOR, so an unexpected label does not crash the page.
_FALLBACK_ENTITY_COLOR = "#dddddd"


def parse_text(text, prediction):
    """Split *text* into plain and annotated segments for ``annotated_text``.

    Args:
        text: The string that was passed to the pipeline.
        prediction: Iterable of entity dicts, each providing the "start",
            "end" and "entity_group" keys. Entities are processed in the
            order given; they are assumed to be sorted by position and
            non-overlapping (TODO confirm against the pipeline output).

    Returns:
        A list that alternates plain ``str`` segments with
        ``(span, label, color)`` tuples, ending with whatever text follows
        the last entity. With no entities the list is ``[text]``.
    """
    segments = []
    cursor = 0
    for entity in prediction:
        begin = entity["start"]
        end = entity["end"]
        label = entity["entity_group"]

        # Plain text between the previous entity and this one.
        segments.append(text[cursor:begin])

        # Slice the original text rather than using entity["word"]: "word" is
        # decoded from tokens and may not match the input character for
        # character, whereas the offsets always point into *text*.
        color = ENTITY_TO_COLOR.get(label, _FALLBACK_ENTITY_COLOR)
        segments.append((text[begin:end], label, color))

        cursor = end

    # Remainder after the final entity (the whole text if there were none).
    segments.append(text[cursor:])
    return segments

# --- Page setup -------------------------------------------------------------
st.set_page_config(page_title="NER ARÇELİK")
st.title("NER ARÇELİK")
st.write("Type text into the text box and then press 'Predict' to get the named entities.")

default_text = "tekirdağ çerkezköy arçelik yetkili servis no paylaş"

# --- Input widgets ----------------------------------------------------------
text = st.text_area('Enter text here:', value=default_text)
# The button's return value is deliberately not used: the prediction below is
# computed on every script run that has non-empty text, and clicking the
# button is simply a way for the user to trigger such a run.
st.button('Predict')

# --- Model ------------------------------------------------------------------
with st.spinner("Loading model..."):
    pipe = get_pipe()

# --- Prediction and rendering -----------------------------------------------
# Previously written as `(submit and non_empty) or non_empty`, which reduces
# to `non_empty`; the simplified test keeps the exact same behaviour.
if text.strip():
    prediction = pipe(text)
    parsed_text = parse_text(text, prediction)

    st.header("Prediction:")
    annotated_text(*parsed_text)

    st.header('Raw values:')
    st.json(prediction)