madhavkotecha committed
Update app.py
app.py
CHANGED
@@ -15,18 +15,20 @@ nltk.download('universal_tagset')
 
 class CRF_POS_Tagger:
     def __init__(self, train=False):
+        print("Loading Data...")
         self.corpus = nltk.corpus.brown.tagged_sents(tagset='universal')
+        print("Data Loaded...")
         self.corpus = [[(word.lower(), tag) for word, tag in sentence] for sentence in self.corpus]
         self.actual_tag = []
         self.predicted_tag = []
         self.prefixes = [
             "a", "anti", "auto", "bi", "co", "dis", "en", "em", "ex", "in", "im",
             "inter", "mis", "non", "over", "pre", "re", "sub", "trans", "un", "under"
         ]
 
         self.suffixes = [
             "able", "ible", "al", "ance", "ence", "dom", "er", "or", "ful", "hood",
             "ic", "ing", "ion", "tion", "ity", "ty", "ive", "less", "ly", "ment",
             "ness", "ous", "ship", "y", "es", "s"
         ]
 
@@ -35,16 +37,17 @@ class CRF_POS_Tagger:
 
         self.X = [[self.word_features(sentence, i) for i in range(len(sentence))] for sentence in self.corpus]
         self.y = [[postag for _, postag in sentence] for sentence in self.corpus]
 
         self.split = int(0.8 * len(self.X))
         self.X_train = self.X[:self.split]
         self.y_train = self.y[:self.split]
         self.X_test = self.X[self.split:]
         self.y_test = self.y[self.split:]
+        print("Data Loaded...")
         self.crf_model = sklearn_crfsuite.CRF(algorithm='lbfgs', c1=0.1, c2=0.1, max_iterations=100, all_possible_transitions=True)
+        print("Model Created...")
         if train:
-
-            self.train(data)
+            self.train()
 
     def word_splitter(self, word):
         prefix = ""
@@ -62,7 +65,7 @@ class CRF_POS_Tagger:
                 stem = stem[: -len(suffix)]
 
         return prefix, stem, suffix
 
     # Define a function to extract features for each word in a sentence
     def word_features(self, sentence, i):
         word = sentence[i][0]
@@ -79,7 +82,7 @@ class CRF_POS_Tagger:
             # 'is_capitalized': word[0].upper() == word[0],
             'is_all_caps': word.isupper(), #word is in uppercase
             'is_all_lower': word.islower(), #word is in lowercase
 
             'prefix-1': word[0],
             'prefix-2': word[:2],
             'prefix-3': word[:3],
@@ -97,31 +100,31 @@ class CRF_POS_Tagger:
             'prefix-de': word[:3] == 'de', #if word starts with de
             'prefix-in': word[:3] == 'in', #if word starts with in
             'prefix-en': word[:3] == 'en', #if word starts with en
 
             'suffix-ed': word[-2:] == 'ed', #if word ends with ed
             'suffix-ing': word[-3:] == 'ing', #if word ends with ing
             'suffix-es': word[-2:] == 'es', #if word ends with es
             'suffix-ly': word[-2:] == 'ly', #if word ends with ly
             'suffix-ment': word[-4:] == 'ment', #if word ends with ment
             'suffix-er': word[-2:] == 'er', #if word ends with er
             'suffix-ive': word[-3:] == 'ive',
             'suffix-ous': word[-3:] == 'ous',
             'suffix-ness': word[-4:] == 'ness',
             'ends_with_s': word[-1] == 's',
             'ends_with_es': word[-2:] == 'es',
 
             'has_hyphen': '-' in word, #if word has hypen
             'is_numeric': word.isdigit(), #if word is in numeric
             'capitals_inside': word[1:].lower() != word[1:],
             'is_title_case': word.istitle(), #if first letter is in uppercase
 
         }
 
         if i > 0:
             # prev_word, prev_postag = sentence[i-1]
             prev_word = sentence[i-1][0]
             prev_prefix, prev_stem, prev_suffix = self.word_splitter(prev_word)
 
             features.update({
                 'prev_word': prev_word,
                 # 'prev_postag': prev_postag,
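
Note: the feature dictionary built above can be inspected directly. A minimal sketch, not part of this commit (the toy sentence and variable names are illustrative; constructing the tagger first loads the Brown corpus and precomputes features for every sentence, so it is slow):

    tagger = CRF_POS_Tagger()
    sample = [["the"], ["dog"], ["barked"]]   # word_features only reads sentence[i][0]
    feats = tagger.word_features(sample, 1)   # features for "dog", which has both a previous and a next word
    print(feats["prefix-2"], feats["suffix-ing"], feats["prev_word"], feats["next:is_title_case"])
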
@@ -131,7 +134,7 @@ class CRF_POS_Tagger:
                 'prev:is_all_caps': prev_word.isupper(),
                 'prev:is_all_lower': prev_word.islower(),
                 'prev:is_numeric': prev_word.isdigit(),
                 'prev:is_title_case': prev_word.istitle(),
             })
 
         if i < len(sentence)-1:
@@ -145,28 +148,38 @@ class CRF_POS_Tagger:
                 'next:is_all_caps': next_word.isupper(),
                 'next:is_all_lower': next_word.islower(),
                 'next:is_numeric': next_word.isdigit(),
                 'next:is_title_case': next_word.istitle(),
             })
 
         return features
 
-    def train(self, data):
-
+    def train(self, data=None):
+        if data:
+            X_train, y_train = zip(*data)
+        else:
+            X_train, y_train = self.X_train, self.y_train
+
+        print("Training CRF Model...", len(self.X_train), len(self.y_train))
+
+        # Ensure X_train is a list of lists of dictionaries
+        X_train = [list(map(dict, x)) for x in X_train]
         self.crf_model.fit(X_train, y_train)
 
     def predict(self, X_test):
         return self.crf_model.predict(X_test)
 
     def accuracy(self, test_data):
         X_test, y_test = zip(*test_data)
         y_pred = self.predict(X_test)
         self.actual_tag.extend([item for sublist in y_test for item in sublist])
         self.predicted_tag.extend([item for sublist in y_pred for item in sublist])
+        print(len(self.actual_tag), len(self.predicted_tag))
         return metrics.flat_accuracy_score(y_test, y_pred)
 
     def cross_validation(self):
         validator = CRF_POS_Tagger()
         data = list(zip(self.X, self.y))
+        print("Cross-Validation...")
         accuracies = []
         for i in range(5):
             n1 = int(i / 5.0 * len(data))
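
Note: the reworked train() now has two call paths: with no argument it falls back to the 80/20 split prepared in __init__, and with a list of (features, tags) pairs it trains on just that fold, which is what cross_validation() passes in. A minimal sketch, not part of the commit (the fold slice below is illustrative):

    tagger = CRF_POS_Tagger()
    tagger.train()                                    # uses self.X_train / self.y_train
    fold = list(zip(tagger.X[:100], tagger.y[:100]))  # one (features, tags) pair per sentence
    tagger.train(fold)                                # unzipped back into X_train, y_train before fitting
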
@@ -176,12 +189,15 @@ class CRF_POS_Tagger:
             validator.train(train_data)
             acc = validator.accuracy(test_data)
             accuracies.append(acc)
+            self.actual_tag = validator.actual_tag
+            self.predicted_tag = validator.predicted_tag
         return accuracies, sum(accuracies) / 5.0
 
     def con_matrix(self):
         self.labels = np.unique(self.actual_tag)
+        print(self.labels, self.actual_tag, self.predicted_tag)
         conf_matrix = confusion_matrix(self.actual_tag, self.predicted_tag, labels=self.labels)
 
         plt.figure(figsize=(10, 7))
         sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', xticklabels=self.labels, yticklabels=self.labels)
         plt.xlabel('Predicted Tags')
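
Note: copying validator.actual_tag / validator.predicted_tag back onto self is what lets con_matrix() run after cross-validation: the folds are evaluated on an inner CRF_POS_Tagger, so without the copy the outer object would have no tags to build the confusion matrix from. The resulting flow, mirroring the new module-level code further down in this diff (variable names illustrative):

    cv = CRF_POS_Tagger()
    accs, avg = cv.cross_validation()   # inner validator's tag lists are copied back onto cv
    cm = cv.con_matrix()                # confusion matrix over those accumulated tags
    cv.per_pos_accuracy(cm)
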
@@ -189,9 +205,9 @@ class CRF_POS_Tagger:
         plt.title('Confusion Matrix Heatmap')
         plt.savefig("Confusion_matrix.png")
         plt.show()
 
         return conf_matrix
 
     def per_pos_accuracy(self, conf_matrix):
         print("Per Tag Precision, Recall, and F-Score:")
         per_tag_metrics = {}
@@ -220,7 +236,7 @@ class CRF_POS_Tagger:
 
             print(f"{tag}: Precision = {precision:.2f}, Recall = {recall:.2f}, f1-Score = {f1_score:.2f}, "
                   f"f05-Score = {f0_5_score:.2f}, f2-Score = {f2_score:.2f}")
 
     def tagging(self, input):
         sentence = (re.sub(r'(\S)([.,;:!?])', r'\1 \2', input.strip())).split()
         sentence_list = [[word.lower()] for word in sentence]
@@ -231,23 +247,24 @@ class CRF_POS_Tagger:
         return output
 
 
-
-accuracies, avg_accuracy =
+validate = CRF_POS_Tagger()
+accuracies, avg_accuracy = validate.cross_validation()
 print(f"Cross-Validation Accuracies: {accuracies}")
 print(f"Average Accuracy: {avg_accuracy}")
 
-conf_matrix =
-print(
+conf_matrix = validate.con_matrix()
+print(validate.per_pos_accuracy(conf_matrix))
 
-
-
-
-
-
+tagger = CRF_POS_Tagger(True)
+interface = gr.Interface(fn = tagger.tagging,
+                         inputs = gr.Textbox(
+                             label="Input Sentence",
+                             placeholder="Enter your sentence here...",
+                         ),
                          outputs = gr.Textbox(
-
-
-
+                             label="Tagged Output",
+                             placeholder="Tagged sentence appears here...",
+                         ),
                          title = "Conditional Random Field POS Tagger",
                          description = "CS626 Assignment 1B (Autumn 2024)",
                          theme=gr.themes.Soft())
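
Note: the new module-level block constructs the Gradio interface, but the call that actually starts the app is outside the changed hunks; presumably the file ends with the standard launch call, roughly:

    interface.launch()   # standard Gradio entry point; not shown in this diff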