vives committed
Commit 2349e64 · 1 Parent(s): d56e301

Update app.py

Files changed (1)
  1. app.py +105 -1
app.py CHANGED
@@ -2,6 +2,11 @@ from transformers import AutoModelForMaskedLM
  from transformers import AutoTokenizer
  import spacy
  import pytextrank
+ from nlp_entities import *
+ import torch
+ import streamlit as st
+ from sklearn.metrics.pairwise import cosine_similarity
+ from collections import defaultdict

  model_checkpoint = "vives/distilbert-base-uncased-finetuned-cvent-2019_2022"
  model = AutoModelForMaskedLM.from_pretrained(model_checkpoint, output_hidden_states=True)
@@ -13,4 +18,103 @@ POS = ["NOUN", "PROPN", "VERB"]

  nlp = spacy.load("en_core_web_sm")
  nlp.add_pipe("textrank", last=True, config={"pos_kept": POS, "token_lookback": 3})
- all_stopwords = nlp.Defaults.stop_words
+ all_stopwords = nlp.Defaults.stop_words
+
+ #streamlit stuff
+ tags = st.text_input("Input tags separated by commas")
+ text = st.text_input("Input text to classify")
+ #Methods for tag processing
+ def pool_embeddings(out, tok):
+     embeddings = out["hidden_states"][-1]
+     attention_mask = tok['attention_mask']
+     mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
+     masked_embeddings = embeddings * mask
+     summed = torch.sum(masked_embeddings, 1)
+     summed_mask = torch.clamp(mask.sum(1), min=1e-9)
+     mean_pooled = summed / summed_mask
+     return mean_pooled
+ import pandas as pd
+
+ def get_transcript(file):
+     data = pd.io.json.read_json(file)
+     transcript = data['results'].values[1][0]['transcript']
+     transcript = transcript.lower()
+     return transcript
+ #
+ """preprocess tags"""
+ if tags:
+     tags = [x.lower().strip() for x in tags.split(",")]
+     tags_tokens = concat_tokens(tags)
+     tags_tokens.pop("KPS")
+     with torch.no_grad():
+         outputs_tags = model(**tags_tokens)
+     pools_tags = pool_embeddings(outputs_tags, tags_tokens).detach().numpy()
+     token_dict = {}
+     for tag, embedding in zip(tags, pools_tags):
+         token_dict[tag] = embedding
+
+ """Code related with processing text, extracting KPs, and doing distance to tag"""
+ def concat_tokens(sentences):
+     tokens = {'input_ids': [], 'attention_mask': [], 'KPS': {}}
+     for sentence, values in sentences.items():
+         weight = values['weight']
+         # encode each sentence and append to dictionary
+         new_tokens = tokenizer.encode_plus(sentence, max_length=64,
+                                            truncation=True, padding='max_length',
+                                            return_tensors='pt')
+         tokens['input_ids'].append(new_tokens['input_ids'][0])
+         tokens['attention_mask'].append(new_tokens['attention_mask'][0])
+         tokens['KPS'][sentence] = weight
+     # reformat list of tensors into single tensor
+     tokens['input_ids'] = torch.stack(tokens['input_ids'])
+     tokens['attention_mask'] = torch.stack(tokens['attention_mask'])
+     return tokens
+
+ def calculate_weighted_embed_dist(out, tokens, weight, text, kp_dict, idx, exclude_text=False, exclude_words=False):
+     sim_dict = {}
+     pools = pool_embeddings_count(out, tokens, idx).detach().numpy()
+     for key in kp_dict.keys():
+         if exclude_text and text in key:
+             continue
+         if exclude_words and True in [x in key for x in text.split(" ")]:
+             continue
+
+         sim_dict[key] = cosine_similarity(
+             pools,
+             [kp_dict[key]]
+         )[0][0] * weight
+     return sim_dict
+ def pool_embeddings_count(out, tok, idx):
+     embeddings = out["hidden_states"][-1][idx:idx+1,:,:]
+     attention_mask = tok['attention_mask'][idx]
+     mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
+     masked_embeddings = embeddings * mask
+     summed = torch.sum(masked_embeddings, 1)
+     summed_mask = torch.clamp(mask.sum(1), min=1e-9)
+     mean_pooled = summed / summed_mask
+     return mean_pooled
+ import pandas as pd
+ def extract_tokens(text, top_kp=30):
+     kps = return_ners_and_kp([text], ret_ne=True)['KP']
+     #only process the top_kp tokens
+     kps = sorted(kps.items(), key=lambda x: x[1]['weight'], reverse=True)[:top_kp]
+     kps = {x: y for x, y in kps}
+     return concat_tokens(kps)
+
+ """Process text and classify it"""
+ if text and tags:
+     text = text.lower()
+     t1_tokens = extract_tokens(text)
+     t1_kps = t1_tokens.pop("KPS")
+     with torch.no_grad():
+         outputs = model(**t1_tokens)
+     tag_distance = None
+     for i, kp in enumerate(t1_kps):
+         if tag_distance is None:
+             tag_distance = calculate_weighted_embed_dist(outputs, t1_tokens, t1_kps[kp], kp, token_dict, i, exclude_text=False, exclude_words=False)
+         else:
+             curr = calculate_weighted_embed_dist(outputs, t1_tokens, t1_kps[kp], kp, token_dict, i, exclude_text=False, exclude_words=False)
+             tag_distance = {x: tag_distance[x] + curr[x] for x in tag_distance.keys()}
+     tag_distance = sorted(tag_distance.items(), key=lambda x: x[1], reverse=True)
+     tag_distance = {x: y for x, y in tag_distance}
+     st.json(tag_distance)
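The added code boils down to two steps: attention-masked mean pooling of the last hidden state (pool_embeddings / pool_embeddings_count), then a TextRank-weighted sum of cosine similarities between each keyphrase embedding and each tag embedding (calculate_weighted_embed_dist, accumulated in the if text and tags block). The sketch below illustrates that pooling-plus-scoring idea in isolation; it is not part of the commit, and the random tensors, toy weights, and helper names (masked_mean_pool, score_against_tags) stand in for the real model outputs and app variables.

import torch
from sklearn.metrics.pairwise import cosine_similarity

def masked_mean_pool(hidden_states, attention_mask):
    # hidden_states: (batch, seq_len, dim); attention_mask: (batch, seq_len)
    mask = attention_mask.unsqueeze(-1).expand(hidden_states.size()).float()
    summed = torch.sum(hidden_states * mask, dim=1)
    counts = torch.clamp(mask.sum(dim=1), min=1e-9)
    return summed / counts  # (batch, dim), one vector per input sequence

def score_against_tags(kp_vecs, kp_weights, tag_vecs):
    # Weighted sum of cosine similarities between keyphrases and tags,
    # mirroring how the app accumulates calculate_weighted_embed_dist results.
    scores = {tag: 0.0 for tag in tag_vecs}
    for vec, weight in zip(kp_vecs, kp_weights):
        for tag, tag_vec in tag_vecs.items():
            scores[tag] += cosine_similarity([vec], [tag_vec])[0][0] * weight
    return dict(sorted(scores.items(), key=lambda x: x[1], reverse=True))

# Toy stand-ins for model(**tokens)["hidden_states"][-1] and TextRank weights.
torch.manual_seed(0)
hidden = torch.randn(3, 8, 16)             # 3 keyphrases, seq_len 8, dim 16
mask = torch.ones(3, 8, dtype=torch.long)  # pretend there is no padding
kp_vecs = masked_mean_pool(hidden, mask).numpy()
kp_weights = [0.9, 0.5, 0.2]               # hypothetical TextRank weights
tag_vecs = {"marketing": torch.randn(16).numpy(), "finance": torch.randn(16).numpy()}
print(score_against_tags(kp_vecs, kp_weights, tag_vecs))

In the Space itself this logic runs behind the two st.text_input boxes and the result is rendered with st.json; run locally, the script would typically be launched with streamlit run app.py.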