Refactor and add more tests
- app.py +5 -4
- src/baseline.py +25 -16
- tests/test_baseline.py +9 -2
- tests/test_integration.py +2 -2
app.py
CHANGED
@@ -16,14 +16,15 @@ def root():
 
 @app.route('/baseline/fix-commas/', methods=['POST'])
 def fix_commas_with_baseline():
+    json_field_name = 's'
     data = request.get_json()
-    if 's' in data:
-        return make_response(jsonify({'s': fix_commas(app.baseline_pipeline, data['s'])}), 200)
+    if json_field_name in data:
+        return make_response(jsonify({json_field_name: fix_commas(app.baseline_pipeline, data['s'])}), 200)
     else:
-        return make_response("Parameter 's' missing", 400)
+        return make_response(f"Parameter '{json_field_name}' missing", 400)
 
 
 if __name__ == '__main__':
     logger.info("Loading the baseline model.")
     app.baseline_pipeline = create_baseline_pipeline()
-    app.run(debug=True)
+    app.run(debug=True)  # TODO get this from config or env variable
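For context, the endpoint can be exercised with a plain HTTP POST. The snippet below is a minimal sketch, assuming the app is running locally on Flask's default port 5000; the host, port, and example sentence are illustrative and not part of the diff.

import requests

# Hypothetical local call against the refactored endpoint.
response = requests.post(
    'http://127.0.0.1:5000/baseline/fix-commas/',
    json={'s': 'A complex clause however it misses a comma'},
)
print(response.status_code)  # 200 if the 's' field was sent, 400 otherwise
print(response.json())       # {'s': <input text with commas re-inserted>}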
src/baseline.py
CHANGED
@@ -1,12 +1,19 @@
 from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
 
 
-def create_baseline_pipeline() -> NerPipeline:
-    tokenizer = AutoTokenizer.from_pretrained("oliverguhr/fullstop-punctuation-multilang-large")
-    model = AutoModelForTokenClassification.from_pretrained("oliverguhr/fullstop-punctuation-multilang-large")
+def create_baseline_pipeline(model_name="oliverguhr/fullstop-punctuation-multilang-large") -> NerPipeline:
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+    model = AutoModelForTokenClassification.from_pretrained(model_name)
     return pipeline('ner', model=model, tokenizer=tokenizer)
 
 
+def fix_commas(ner_pipeline: NerPipeline, s: str) -> str:
+    return _fix_commas_based_on_pipeline_output(
+        ner_pipeline(_remove_punctuation(s)),
+        s
+    )
+
+
 def _remove_punctuation(s: str) -> str:
     to_remove = ".,?-:"
     for char in to_remove:
@@ -14,23 +21,25 @@ def _remove_punctuation(s: str) -> str:
     return s
 
 
-def _fix_commas_based_on_pipeline_output(pipeline_json, original_s) -> str:
-
-    # TODO don't accept tokens with commas inside words
-    result = original_s.replace(',', '')  # We will fix the commas, but keep everything else intact
+def _fix_commas_based_on_pipeline_output(pipeline_json: list[dict], original_s: str) -> str:
+    result = original_s.replace(',', '')  # We will fix the commas, but keep everything else intact
     current_offset = 0
+
     for i in range(1, len(pipeline_json)):
-        current_word = pipeline_json[i - 1]['word'].replace('β', '')
-        current_offset = result.find(current_word, current_offset) + len(current_word)
-        # Only insert commas for the final token of a word
-        if pipeline_json[i - 1]['entity'] == ',' and pipeline_json[i]['word'].startswith('β'):
+        current_offset = _find_current_token(current_offset, i, pipeline_json, result)
+        if _should_insert_comma(i, pipeline_json):
             result = result[:current_offset] + ',' + result[current_offset:]
             current_offset += 1
     return result
 
 
-def fix_commas(ner_pipeline: NerPipeline, s: str) -> str:
-    return _fix_commas_based_on_pipeline_output(
-        ner_pipeline(_remove_punctuation(s)),
-        s
-    )
+def _should_insert_comma(i, pipeline_json, new_word_indicator='β') -> bool:
+    # Only insert commas for the final token of a word
+    return pipeline_json[i - 1]['entity'] == ',' and pipeline_json[i]['word'].startswith(new_word_indicator)
+
+
+def _find_current_token(current_offset, i, pipeline_json, result, new_word_indicator='β') -> int:
+    current_word = pipeline_json[i - 1]['word'].replace(new_word_indicator, '')
+    # Find the current word in the result string, starting looking at current offset
+    current_offset = result.find(current_word, current_offset) + len(current_word)
+    return current_offset
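To make the new helpers easier to follow: they operate on the list of token dictionaries returned by the Hugging Face 'ner' pipeline, reading only the 'word' and 'entity' keys. The sketch below is illustrative; the token values, the '0' no-punctuation label, and the 'β' word-start marker mirror what the diff assumes, while real pipeline output also carries score and offset fields and depends on the tokenizer.

# Hypothetical pipeline output for the text 'I am here' (values invented for illustration).
pipeline_json = [
    {'word': 'βI', 'entity': '0'},
    {'word': 'βam', 'entity': ','},    # a comma is predicted after 'am'
    {'word': 'βhere', 'entity': '0'},
]

# _should_insert_comma(2, pipeline_json) is True: the previous token's entity is ','
# and 'βhere' starts a new word, so _fix_commas_based_on_pipeline_output would turn
# the original string 'I am here' into 'I am, here'.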
tests/test_baseline.py
CHANGED
@@ -11,7 +11,8 @@ def baseline_pipeline():
     "test_input",
     ['',
      'Hello world.',
-     'This test string should not have any commas inside it.']
+     'This test string should not have any commas inside it.',
+     'aAaalLL the.. weird?~! punctuation.should also . be kept-as is! Only fixing-commas.']
 )
 def test_fix_commas_leaves_correct_strings_unchanged(baseline_pipeline, test_input):
     result = fix_commas(baseline_pipeline, s=test_input)
@@ -23,7 +24,13 @@ def test_fix_commas_leaves_correct_strings_unchanged(baseline_pipeline, test_inp
     [
         ['I, am.', 'I am.'],
         ['A complex clause however it misses a comma something else and a dot...?',
-         'A complex clause, however, it misses a comma, something else and a dot...?']]
+         'A complex clause, however, it misses a comma, something else and a dot...?'],
+        ['a pen an apple, \tand a pineapple!',
+         'a pen, an apple \tand a pineapple!'],
+        ['Even newlines\ntabs\tand others get preserved.',
+         'Even newlines,\ntabs\tand others get preserved.'],
+        ['I had no Creativity left, therefore, I come here, and write useless examples, for this test.',
+         'I had no Creativity left therefore, I come here and write useless examples for this test.']]
 )
 def test_fix_commas_fixes_incorrect_commas(baseline_pipeline, test_input, expected):
     result = fix_commas(baseline_pipeline, s=test_input)
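The hunk headers show these tests depend on a baseline_pipeline fixture defined above the changed region; its body is not part of this diff. Below is a minimal sketch of such a fixture, assuming it simply wraps create_baseline_pipeline; the real fixture, its scope, and the import path may differ.

import pytest

from src.baseline import create_baseline_pipeline


@pytest.fixture(scope='module')
def baseline_pipeline():
    # Load the punctuation model once per test module; creating it is expensive.
    return create_baseline_pipeline()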
tests/test_integration.py
CHANGED
@@ -29,7 +29,7 @@ def test_fix_commas_fails_on_wrong_parameters(client):
      'Hello world.',
      'This test string should not have any commas inside it.']
 )
-def test_fix_commas_plain_string_unchanged(client, test_input: str):
+def test_fix_commas_correct_string_unchanged(client, test_input: str):
     response = client.post('/baseline/fix-commas/', json={'s': test_input})
 
     assert response.status_code == 200
@@ -40,7 +40,7 @@ def test_fix_commas_plain_string_unchanged(client, test_input: str):
     "test_input, expected",
     [['I am, here.', 'I am here.'],
      ['books pens and pencils',
-      'books, pens and pencils
+      'books, pens and pencils']]
 )
 def test_fix_commas_fixes_wrong_commas(client, test_input: str, expected: str):
     response = client.post('/baseline/fix-commas/', json={'s': test_input})
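These integration tests rely on a client fixture that is likewise defined outside the diff. Below is a minimal sketch of how such a fixture is commonly built with Flask's test_client; the module paths and the way the pipeline is attached are assumptions, and the project's actual conftest may differ.

import pytest

from app import app
from src.baseline import create_baseline_pipeline


@pytest.fixture()
def client():
    # The endpoint reads app.baseline_pipeline, so attach it before issuing requests.
    app.baseline_pipeline = create_baseline_pipeline()
    with app.test_client() as test_client:
        yield test_client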