madhavkotecha commited on
Commit
13e1dc8
·
verified ·
1 Parent(s): 751b357

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -13
app.py CHANGED
@@ -43,7 +43,8 @@ class CRF_POS_Tagger:
43
  self.y_test = self.y[self.split:]
44
  self.crf_model = sklearn_crfsuite.CRF(algorithm='lbfgs', c1=0.1, c2=0.1, max_iterations=100, all_possible_transitions=True)
45
  if train:
46
- self.train()
 
47
 
48
  def word_splitter(self, word):
49
  prefix = ""
@@ -149,8 +150,8 @@ class CRF_POS_Tagger:
149
 
150
  return features
151
 
152
- def train(self, data=None):
153
- X_train, y_train = zip(*data) if data else self.X_train, self.y_train
154
  self.crf_model.fit(X_train, y_train)
155
 
156
  def predict(self, X_test):
@@ -164,7 +165,7 @@ class CRF_POS_Tagger:
164
  return metrics.flat_accuracy_score(y_test, y_pred)
165
 
166
  def cross_validation(self):
167
-
168
  data = list(zip(self.X, self.y))
169
  accuracies = []
170
  for i in range(5):
@@ -172,8 +173,8 @@ class CRF_POS_Tagger:
172
  n2 = int((i + 1) / 5.0 * len(data))
173
  test_data = data[n1:n2]
174
  train_data = data[:n1] + data[n2:]
175
- self.train(train_data)
176
- acc = self.accuracy(test_data)
177
  accuracies.append(acc)
178
  return accuracies, sum(accuracies) / 5.0
179
 
@@ -230,15 +231,14 @@ class CRF_POS_Tagger:
230
  return output
231
 
232
 
233
- # validator = CRF_POS_Tagger()
234
- # accuracies, avg_accuracy = validator.cross_validation()
235
- # print(f"Cross-Validation Accuracies: {accuracies}")
236
- # print(f"Average Accuracy: {avg_accuracy}")
237
 
238
- # conf_matrix = tagger.con_matrix()
239
- # print(tagger.per_pos_accuracy(conf_matrix))
240
 
241
- tagger = CRF_POS_Tagger(True)
242
  interface = gr.Interface(fn = tagger.tagging,
243
  inputs = gr.Textbox(
244
  label="Input Sentence",
 
43
  self.y_test = self.y[self.split:]
44
  self.crf_model = sklearn_crfsuite.CRF(algorithm='lbfgs', c1=0.1, c2=0.1, max_iterations=100, all_possible_transitions=True)
45
  if train:
46
+ data = list(zip(self.X_train, self.y_train))
47
+ self.train(data)
48
 
49
  def word_splitter(self, word):
50
  prefix = ""
 
150
 
151
  return features
152
 
153
+ def train(self, data):
154
+ X_train, y_train = zip(*data)
155
  self.crf_model.fit(X_train, y_train)
156
 
157
  def predict(self, X_test):
 
165
  return metrics.flat_accuracy_score(y_test, y_pred)
166
 
167
  def cross_validation(self):
168
+ validator = CRF_POS_Tagger()
169
  data = list(zip(self.X, self.y))
170
  accuracies = []
171
  for i in range(5):
 
173
  n2 = int((i + 1) / 5.0 * len(data))
174
  test_data = data[n1:n2]
175
  train_data = data[:n1] + data[n2:]
176
+ validator.train(train_data)
177
+ acc = validator.accuracy(test_data)
178
  accuracies.append(acc)
179
  return accuracies, sum(accuracies) / 5.0
180
 
 
231
  return output
232
 
233
 
234
+ tagger = CRF_POS_Tagger(True)
235
+ accuracies, avg_accuracy = tagger.cross_validation()
236
+ print(f"Cross-Validation Accuracies: {accuracies}")
237
+ print(f"Average Accuracy: {avg_accuracy}")
238
 
239
+ conf_matrix = tagger.con_matrix()
240
+ print(tagger.per_pos_accuracy(conf_matrix))
241
 
 
242
  interface = gr.Interface(fn = tagger.tagging,
243
  inputs = gr.Textbox(
244
  label="Input Sentence",