madhavkotecha committed on
Commit
66e18d4
·
verified ·
1 Parent(s): b4825e9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -14,7 +14,7 @@ nltk.download('brown')
14
  nltk.download('universal_tagset')
15
 
16
  class CRF_POS_Tagger:
17
- def __init__(self):
18
  self.corpus = nltk.corpus.brown.tagged_sents(tagset='universal')
19
  self.corpus = [[(word.lower(), tag) for word, tag in sentence] for sentence in self.corpus]
20
  self.actual_tag = []
@@ -42,7 +42,7 @@ class CRF_POS_Tagger:
42
  self.X_test = self.X[self.split:]
43
  self.y_test = self.y[self.split:]
44
  self.crf_model = sklearn_crfsuite.CRF(algorithm='lbfgs', c1=0.1, c2=0.1, max_iterations=100, all_possible_transitions=True)
45
- # self.train()
46
 
47
  def word_splitter(self, word):
48
  prefix = ""
@@ -163,6 +163,7 @@ class CRF_POS_Tagger:
163
  return metrics.flat_accuracy_score(y_test, y_pred)
164
 
165
  def cross_validation(self):
 
166
  data = list(zip(self.X, self.y))
167
  accuracies = []
168
  for i in range(5):
@@ -170,8 +171,8 @@ class CRF_POS_Tagger:
170
  n2 = int((i + 1) / 5.0 * len(data))
171
  test_data = data[n1:n2]
172
  train_data = data[:n1] + data[n2:]
173
- self.train(train_data)
174
- acc = self.accuracy(test_data)
175
  accuracies.append(acc)
176
  return accuracies, sum(accuracies) / 5.0
177
 
@@ -227,7 +228,7 @@ class CRF_POS_Tagger:
227
  output = "".join(f"{sentence[i]}[{predicted_tags[0][i]}] " for i in range(len(sentence)))
228
  return output
229
 
230
- tagger = CRF_POS_Tagger()
231
 
232
  accuracies, avg_accuracy = tagger.cross_validation()
233
  print(f"Cross-Validation Accuracies: {accuracies}")
 
14
  nltk.download('universal_tagset')
15
 
16
  class CRF_POS_Tagger:
17
+ def __init__(self, train=False):
18
  self.corpus = nltk.corpus.brown.tagged_sents(tagset='universal')
19
  self.corpus = [[(word.lower(), tag) for word, tag in sentence] for sentence in self.corpus]
20
  self.actual_tag = []
 
42
  self.X_test = self.X[self.split:]
43
  self.y_test = self.y[self.split:]
44
  self.crf_model = sklearn_crfsuite.CRF(algorithm='lbfgs', c1=0.1, c2=0.1, max_iterations=100, all_possible_transitions=True)
45
+ self.train() if train
46
 
47
  def word_splitter(self, word):
48
  prefix = ""
 
163
  return metrics.flat_accuracy_score(y_test, y_pred)
164
 
165
  def cross_validation(self):
166
+ validator = CRF_POS_Tagger()
167
  data = list(zip(self.X, self.y))
168
  accuracies = []
169
  for i in range(5):
 
171
  n2 = int((i + 1) / 5.0 * len(data))
172
  test_data = data[n1:n2]
173
  train_data = data[:n1] + data[n2:]
174
+ validator.train(train_data)
175
+ acc = validator.accuracy(test_data)
176
  accuracies.append(acc)
177
  return accuracies, sum(accuracies) / 5.0
178
 
 
228
  output = "".join(f"{sentence[i]}[{predicted_tags[0][i]}] " for i in range(len(sentence)))
229
  return output
230
 
231
+ tagger = CRF_POS_Tagger(True)
232
 
233
  accuracies, avg_accuracy = tagger.cross_validation()
234
  print(f"Cross-Validation Accuracies: {accuracies}")