Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -14,7 +14,7 @@ nltk.download('brown')
|
|
| 14 |
nltk.download('universal_tagset')
|
| 15 |
|
| 16 |
class CRF_POS_Tagger:
|
| 17 |
-
def __init__(self):
|
| 18 |
self.corpus = nltk.corpus.brown.tagged_sents(tagset='universal')
|
| 19 |
self.corpus = [[(word.lower(), tag) for word, tag in sentence] for sentence in self.corpus]
|
| 20 |
self.actual_tag = []
|
|
@@ -42,7 +42,7 @@ class CRF_POS_Tagger:
|
|
| 42 |
self.X_test = self.X[self.split:]
|
| 43 |
self.y_test = self.y[self.split:]
|
| 44 |
self.crf_model = sklearn_crfsuite.CRF(algorithm='lbfgs', c1=0.1, c2=0.1, max_iterations=100, all_possible_transitions=True)
|
| 45 |
-
|
| 46 |
|
| 47 |
def word_splitter(self, word):
|
| 48 |
prefix = ""
|
|
@@ -163,6 +163,7 @@ class CRF_POS_Tagger:
|
|
| 163 |
return metrics.flat_accuracy_score(y_test, y_pred)
|
| 164 |
|
| 165 |
def cross_validation(self):
|
|
|
|
| 166 |
data = list(zip(self.X, self.y))
|
| 167 |
accuracies = []
|
| 168 |
for i in range(5):
|
|
@@ -170,8 +171,8 @@ class CRF_POS_Tagger:
|
|
| 170 |
n2 = int((i + 1) / 5.0 * len(data))
|
| 171 |
test_data = data[n1:n2]
|
| 172 |
train_data = data[:n1] + data[n2:]
|
| 173 |
-
|
| 174 |
-
acc =
|
| 175 |
accuracies.append(acc)
|
| 176 |
return accuracies, sum(accuracies) / 5.0
|
| 177 |
|
|
@@ -227,7 +228,7 @@ class CRF_POS_Tagger:
|
|
| 227 |
output = "".join(f"{sentence[i]}[{predicted_tags[0][i]}] " for i in range(len(sentence)))
|
| 228 |
return output
|
| 229 |
|
| 230 |
-
tagger = CRF_POS_Tagger()
|
| 231 |
|
| 232 |
accuracies, avg_accuracy = tagger.cross_validation()
|
| 233 |
print(f"Cross-Validation Accuracies: {accuracies}")
|
|
|
|
| 14 |
nltk.download('universal_tagset')
|
| 15 |
|
| 16 |
class CRF_POS_Tagger:
|
| 17 |
+
def __init__(self, train=False):
|
| 18 |
self.corpus = nltk.corpus.brown.tagged_sents(tagset='universal')
|
| 19 |
self.corpus = [[(word.lower(), tag) for word, tag in sentence] for sentence in self.corpus]
|
| 20 |
self.actual_tag = []
|
|
|
|
| 42 |
self.X_test = self.X[self.split:]
|
| 43 |
self.y_test = self.y[self.split:]
|
| 44 |
self.crf_model = sklearn_crfsuite.CRF(algorithm='lbfgs', c1=0.1, c2=0.1, max_iterations=100, all_possible_transitions=True)
|
| 45 |
+
self.train() if train
|
| 46 |
|
| 47 |
def word_splitter(self, word):
|
| 48 |
prefix = ""
|
|
|
|
| 163 |
return metrics.flat_accuracy_score(y_test, y_pred)
|
| 164 |
|
| 165 |
def cross_validation(self):
|
| 166 |
+
validator = CRF_POS_Tagger()
|
| 167 |
data = list(zip(self.X, self.y))
|
| 168 |
accuracies = []
|
| 169 |
for i in range(5):
|
|
|
|
| 171 |
n2 = int((i + 1) / 5.0 * len(data))
|
| 172 |
test_data = data[n1:n2]
|
| 173 |
train_data = data[:n1] + data[n2:]
|
| 174 |
+
validator.train(train_data)
|
| 175 |
+
acc = validator.accuracy(test_data)
|
| 176 |
accuracies.append(acc)
|
| 177 |
return accuracies, sum(accuracies) / 5.0
|
| 178 |
|
|
|
|
| 228 |
output = "".join(f"{sentence[i]}[{predicted_tags[0][i]}] " for i in range(len(sentence)))
|
| 229 |
return output
|
| 230 |
|
| 231 |
+
tagger = CRF_POS_Tagger(True)
|
| 232 |
|
| 233 |
accuracies, avg_accuracy = tagger.cross_validation()
|
| 234 |
print(f"Cross-Validation Accuracies: {accuracies}")
|