Corey Morris commited on
Commit
9549fcc
1 Parent(s): 85667d0

Added test for removal of undesired columns. fixed code error in column removal

Browse files
result_data_processor.py CHANGED
@@ -83,8 +83,7 @@ class ResultDataProcessor:
83
  data = data[cols]
84
 
85
  # Drop specific columns
86
- data.drop(columns=['all', 'truthfulqa:mc|0'])
87
-
88
 
89
  # Add parameter count column using extract_parameters function
90
  data['Parameters'] = data.index.to_series().apply(self._extract_parameters)
 
83
  data = data[cols]
84
 
85
  # Drop specific columns
86
+ data = data.drop(columns=['all', 'truthfulqa:mc|0'])
 
87
 
88
  # Add parameter count column using extract_parameters function
89
  data['Parameters'] = data.index.to_series().apply(self._extract_parameters)
test_data_processing.py CHANGED
@@ -18,12 +18,22 @@ class TestResultDataProcessor(unittest.TestCase):
18
  self.assertIn('Parameters', data.columns)
19
  self.assertIn('MMLU_average', data.columns)
20
  # check number of columns
21
- self.assertEqual(len(data.columns), 63)
22
 
23
  # check that the number of rows is correct
24
  def test_rows(self):
25
  data = self.processor.data
26
  self.assertEqual(len(data), 992)
 
 
 
 
 
 
 
 
 
 
27
 
28
  if __name__ == '__main__':
29
  unittest.main()
 
18
  self.assertIn('Parameters', data.columns)
19
  self.assertIn('MMLU_average', data.columns)
20
  # check number of columns
21
+ self.assertEqual(len(data.columns), 61)
22
 
23
  # check that the number of rows is correct
24
  def test_rows(self):
25
  data = self.processor.data
26
  self.assertEqual(len(data), 992)
27
+
28
+ # # check that mc1 column exists
29
+ # def test_mc1(self):
30
+ # data = self.processor.data
31
+ # self.assertIn('mc1', data.columns)
32
+
33
+ # test that a column that contains truthfulqa:mc does not exist
34
+ def test_truthfulqa_mc(self):
35
+ data = self.processor.data
36
+ self.assertNotIn('truthfulqa:mc', data.columns)
37
 
38
  if __name__ == '__main__':
39
  unittest.main()