klasocki committed on
Commit
b1106e6
•
1 Parent(s): 35c0239

Refactor and add more tests

Browse files
Files changed (4) hide show
  1. app.py +5 -4
  2. src/baseline.py +25 -16
  3. tests/test_baseline.py +9 -2
  4. tests/test_integration.py +2 -2
app.py CHANGED
@@ -16,14 +16,15 @@ def root():
16
 
17
  @app.route('/baseline/fix-commas/', methods=['POST'])
18
  def fix_commas_with_baseline():
 
19
  data = request.get_json()
20
- if 's' in data:
21
- return make_response(jsonify({'s': fix_commas(app.baseline_pipeline, data['s'])}), 200)
22
  else:
23
- return make_response("Parameter 's' missing", 400)
24
 
25
 
26
  if __name__ == '__main__':
27
  logger.info("Loading the baseline model.")
28
  app.baseline_pipeline = create_baseline_pipeline()
29
- app.run(debug=True)
 
16
 
17
  @app.route('/baseline/fix-commas/', methods=['POST'])
18
  def fix_commas_with_baseline():
19
+ json_field_name = 's'
20
  data = request.get_json()
21
+ if json_field_name in data:
22
+ return make_response(jsonify({json_field_name: fix_commas(app.baseline_pipeline, data['s'])}), 200)
23
  else:
24
+ return make_response(f"Parameter '{json_field_name}' missing", 400)
25
 
26
 
27
  if __name__ == '__main__':
28
  logger.info("Loading the baseline model.")
29
  app.baseline_pipeline = create_baseline_pipeline()
30
+ app.run(debug=True) # TODO get this from config or env variable
src/baseline.py CHANGED
@@ -1,12 +1,19 @@
1
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
2
 
3
 
4
- def create_baseline_pipeline() -> NerPipeline:
5
- tokenizer = AutoTokenizer.from_pretrained("oliverguhr/fullstop-punctuation-multilang-large")
6
- model = AutoModelForTokenClassification.from_pretrained("oliverguhr/fullstop-punctuation-multilang-large")
7
  return pipeline('ner', model=model, tokenizer=tokenizer)
8
 
9
 
 
 
 
 
 
 
 
10
  def _remove_punctuation(s: str) -> str:
11
  to_remove = ".,?-:"
12
  for char in to_remove:
@@ -14,23 +21,25 @@ def _remove_punctuation(s: str) -> str:
14
  return s
15
 
16
 
17
- def _convert_pipeline_json_to_string(pipeline_json: list[dict], original_s: str) -> str:
18
- # TODO is it ok to remove redundant spaces, or should we keep input data as is and only touch commas?
19
- # TODO don't accept tokens with commas inside words
20
- result = original_s.replace(',', '') # We will fix the commas, but keep everything else intact
21
  current_offset = 0
 
22
  for i in range(1, len(pipeline_json)):
23
- current_word = pipeline_json[i - 1]['word'].replace('▁', '')
24
- current_offset = result.find(current_word, current_offset) + len(current_word)
25
- # Only insert commas for the final token of a word
26
- if pipeline_json[i - 1]['entity'] == ',' and pipeline_json[i]['word'].startswith('▁'):
27
  result = result[:current_offset] + ',' + result[current_offset:]
28
  current_offset += 1
29
  return result
30
 
31
 
32
- def fix_commas(ner_pipeline: NerPipeline, s: str) -> str:
33
- return _convert_pipeline_json_to_string(
34
- ner_pipeline(_remove_punctuation(s)),
35
- s
36
- )
 
 
 
 
 
 
1
  from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
2
 
3
 
4
def create_baseline_pipeline(model_name="oliverguhr/fullstop-punctuation-multilang-large") -> NerPipeline:
    """Create the 'ner' pipeline used as the baseline comma fixer.

    Args:
        model_name: Name passed to ``from_pretrained`` for both the tokenizer
            and the token-classification model.

    Returns:
        The pipeline built from the loaded model and tokenizer.
    """
    loaded_tokenizer = AutoTokenizer.from_pretrained(model_name)
    loaded_model = AutoModelForTokenClassification.from_pretrained(model_name)
    return pipeline(
        'ner',
        model=loaded_model,
        tokenizer=loaded_tokenizer,
    )
8
 
9
 
10
def fix_commas(ner_pipeline: NerPipeline, s: str) -> str:
    """Return ``s`` with its commas replaced by the ones the pipeline predicts.

    Args:
        ner_pipeline: Pipeline called on the punctuation-stripped text.
        s: The original string whose commas should be fixed.

    Returns:
        The string produced by ``_fix_commas_based_on_pipeline_output`` from
        the pipeline output and the original ``s``.
    """
    # The pipeline is called on text without punctuation; the original string
    # is passed along separately so the result is rebuilt from it.
    stripped_text = _remove_punctuation(s)
    token_predictions = ner_pipeline(stripped_text)
    return _fix_commas_based_on_pipeline_output(token_predictions, s)
15
+
16
+
17
  def _remove_punctuation(s: str) -> str:
18
  to_remove = ".,?-:"
19
  for char in to_remove:
 
21
  return s
22
 
23
 
24
def _fix_commas_based_on_pipeline_output(pipeline_json: list[dict], original_s: str) -> str:
    """Rebuild ``original_s`` with commas placed where the pipeline output says.

    Args:
        pipeline_json: Token entries produced by the pipeline; each is read
            through its ``'word'`` and ``'entity'`` keys by the helpers.
        original_s: The original input string.

    Returns:
        ``original_s`` with its existing commas removed and new commas
        inserted; every other character is kept as is.
    """
    # Drop the existing commas only; they are re-inserted below.
    text = original_s.replace(',', '')
    cursor = 0

    for token_index in range(1, len(pipeline_json)):
        # Advance the cursor to just past the previous token in the text.
        cursor = _find_current_token(cursor, token_index, pipeline_json, text)
        if not _should_insert_comma(token_index, pipeline_json):
            continue
        text = text[:cursor] + ',' + text[cursor:]
        # Step over the comma that was just inserted.
        cursor += 1
    return text
34
 
35
 
36
+ def _should_insert_comma(i, pipeline_json, new_word_indicator='▁') -> bool:
37
+ # Only insert commas for the final token of a word
38
+ return pipeline_json[i - 1]['entity'] == ',' and pipeline_json[i]['word'].startswith(new_word_indicator)
39
+
40
+
41
+ def _find_current_token(current_offset, i, pipeline_json, result, new_word_indicator='▁') -> int:
42
+ current_word = pipeline_json[i - 1]['word'].replace(new_word_indicator, '')
43
+ # Find the current word in the result string, starting looking at current offset
44
+ current_offset = result.find(current_word, current_offset) + len(current_word)
45
+ return current_offset
tests/test_baseline.py CHANGED
@@ -11,7 +11,8 @@ def baseline_pipeline():
11
  "test_input",
12
  ['',
13
  'Hello world.',
14
- 'This test string should not have any commas inside it.']
 
15
  )
16
  def test_fix_commas_leaves_correct_strings_unchanged(baseline_pipeline, test_input):
17
  result = fix_commas(baseline_pipeline, s=test_input)
@@ -23,7 +24,13 @@ def test_fix_commas_leaves_correct_strings_unchanged(baseline_pipeline, test_inp
23
  [
24
  ['I, am.', 'I am.'],
25
  ['A complex clause however it misses a comma something else and a dot...?',
26
- 'A complex clause, however, it misses a comma, something else and a dot...?']]
 
 
 
 
 
 
27
  )
28
  def test_fix_commas_fixes_incorrect_commas(baseline_pipeline, test_input, expected):
29
  result = fix_commas(baseline_pipeline, s=test_input)
 
11
  "test_input",
12
  ['',
13
  'Hello world.',
14
+ 'This test string should not have any commas inside it.',
15
+ 'aAaalLL the.. weird?~! punctuation.should also . be kept-as is! Only fixing-commas.']
16
  )
17
  def test_fix_commas_leaves_correct_strings_unchanged(baseline_pipeline, test_input):
18
  result = fix_commas(baseline_pipeline, s=test_input)
 
24
  [
25
  ['I, am.', 'I am.'],
26
  ['A complex clause however it misses a comma something else and a dot...?',
27
+ 'A complex clause, however, it misses a comma, something else and a dot...?'],
28
+ ['a pen an apple, \tand a pineapple!',
29
+ 'a pen, an apple \tand a pineapple!'],
30
+ ['Even newlines\ntabs\tand others get preserved.',
31
+ 'Even newlines,\ntabs\tand others get preserved.'],
32
+ ['I had no Creativity left, therefore, I come here, and write useless examples, for this test.',
33
+ 'I had no Creativity left therefore, I come here and write useless examples for this test.']]
34
  )
35
  def test_fix_commas_fixes_incorrect_commas(baseline_pipeline, test_input, expected):
36
  result = fix_commas(baseline_pipeline, s=test_input)
tests/test_integration.py CHANGED
@@ -29,7 +29,7 @@ def test_fix_commas_fails_on_wrong_parameters(client):
29
  'Hello world.',
30
  'This test string should not have any commas inside it.']
31
  )
32
- def test_fix_commas_plain_string_unchanged(client, test_input: str):
33
  response = client.post('/baseline/fix-commas/', json={'s': test_input})
34
 
35
  assert response.status_code == 200
@@ -40,7 +40,7 @@ def test_fix_commas_plain_string_unchanged(client, test_input: str):
40
  "test_input, expected",
41
  [['I am, here.', 'I am here.'],
42
  ['books pens and pencils',
43
- 'books, pens and pencils.']]
44
  )
45
  def test_fix_commas_fixes_wrong_commas(client, test_input: str, expected: str):
46
  response = client.post('/baseline/fix-commas/', json={'s': test_input})
 
29
  'Hello world.',
30
  'This test string should not have any commas inside it.']
31
  )
32
+ def test_fix_commas_correct_string_unchanged(client, test_input: str):
33
  response = client.post('/baseline/fix-commas/', json={'s': test_input})
34
 
35
  assert response.status_code == 200
 
40
  "test_input, expected",
41
  [['I am, here.', 'I am here.'],
42
  ['books pens and pencils',
43
+ 'books, pens and pencils']]
44
  )
45
  def test_fix_commas_fixes_wrong_commas(client, test_input: str, expected: str):
46
  response = client.post('/baseline/fix-commas/', json={'s': test_input})