from transformers import AutoTokenizer, AutoModelForTokenClassification, pipeline, NerPipeline
import re
from commafixer.src.comma_fixer_interface import CommaFixerInterface
class BaselineCommaFixer(CommaFixerInterface):
    """
    Wraps the oliverguhr/fullstop-punctuation-multilang-large punctuation restoration model.

    The underlying model restores full punctuation; this adapter strips the input's
    punctuation, runs the model, and then applies only the comma-related predictions,
    so every non-comma character of the original string is left intact.
    """
    def __init__(self, device=-1):
        # Token-classification pipeline; device=-1 keeps inference on the CPU.
        self._ner = _create_baseline_pipeline(device=device)
    def fix_commas(self, s: str) -> str:
        """
        Fix the commas of a single string using the baseline model.

        Requests are handled one string at a time; batching the pipeline calls is a
        possible future optimisation.
        :param s: A string with commas to fix, without length restrictions.
        Example: comma_fixer.fix_commas("One two thre, and four!")
        :return: A string with commas fixed, example: "One, two, thre and four!"
        """
        stripped, removed_indices = _remove_punctuation(s)
        model_output = self._ner(stripped)
        return _fix_commas_based_on_pipeline_output(model_output, s, removed_indices)
def _create_baseline_pipeline(model_name="oliverguhr/fullstop-punctuation-multilang-large", device=-1) -> NerPipeline:
    """
    Build the huggingface token-classification pipeline around the baseline model.

    Calling this is also a convenient way to pre-download the model and the tokenizer.
    :param model_name: Name of the baseline model on the huggingface hub.
    :param device: Device for the pipeline; -1 (default) means CPU, a non-negative
        value selects the GPU with that id.
    :return: A token classification pipeline.
    """
    # Tokenizer is fetched first, then the model weights, matching the download order
    # users see when pre-caching.
    baseline_tokenizer = AutoTokenizer.from_pretrained(model_name)
    baseline_model = AutoModelForTokenClassification.from_pretrained(model_name)
    return pipeline(task='ner', model=baseline_model, tokenizer=baseline_tokenizer, device=device)
def _remove_punctuation(s: str) -> tuple[str, list[int]]:
"""
Removes the punctuation (".,?-:") from the input text, since the baseline model has been trained on data without
punctuation. It also keeps track of the indices where we remove it, so that we can restore the original later.
Commas are the exception, since we remove them, but restore with the model.
Hence we do not keep track of removed comma indices.
:param s: For instance, "A short-string: with punctuation, removed.
:return: A tuple of a string, for instance:
"A shortstring with punctuation removed"; and a list of indices where punctuation has been removed, in ascending
order
"""
to_remove_regex = r"[\.\?\-:]"
# We're not counting commas, since we will remove them later anyway. Only counting removals that will be restored
# in the final resulting string.
punctuation_indices = [m.start() for m in re.finditer(to_remove_regex, s)]
s = re.sub(to_remove_regex, '', s)
s = s.replace(',', '')
return s, punctuation_indices
def _fix_commas_based_on_pipeline_output(pipeline_json: list[dict], original_s: str, punctuation_indices: list[int]) -> str:
    """
    Convert the token-classification pipeline output back into a comma-fixed string.

    The result contains every original character intact except the commas, which are
    taken from the model's predictions instead of from the input.
    :param pipeline_json: Token classification pipeline output.
    Contains five fields per token.
    'entity' is the punctuation that should follow this token.
    'word' is the token text together with preceding space if any.
    'end' is the end index in the original string (with punctuation removed in our case!!)
    Example: [{'entity': ':',
               'score': 0.90034866,
               'index': 1,
               'word': '▁Exam',
               'start': 0,
               'end': 4},
              {'entity': ':',
               'score': 0.9157294,
               'index': 2,
               'word': 'ple',
               'start': 4,
               'end': 7}]
    :param original_s: The original string, before removing punctuation.
    :param punctuation_indices: Indices of the removed non-comma punctuation, used to
        keep the offsets into the original string aligned with the model's offsets.
    :return: A string with commas fixed and the original punctuation restored.
    """
    # Commas will be re-inserted from the model output; everything else stays intact.
    result = original_s.replace(',', '')
    shift = 0  # commas inserted so far + punctuation removed before the current offset
    next_removed = 0  # cursor into punctuation_indices
    for i in range(1, len(pipeline_json)):
        offset = pipeline_json[i - 1]['end'] + shift
        # Every removed punctuation mark sitting before the offset pushes the
        # insertion point one character to the right in the original string.
        while next_removed < len(punctuation_indices) and punctuation_indices[next_removed] < offset:
            next_removed += 1
            shift += 1
            offset += 1
        # Insert a comma only after the final token of a word: the previous token is
        # tagged ',' and the current token opens a new word ('▁' prefix).
        if pipeline_json[i - 1]['entity'] == ',' and pipeline_json[i]['word'].startswith('▁'):
            result = result[:offset] + ',' + result[offset:]
            shift += 1
    return result
def _update_offset_by_the_removed_punctuation(
commas_inserted_and_punctuation_removed, current_offset, punctuation_indices, removed_punctuation_index
):
# increase the counters for every punctuation removed from the original string before the curent offset
while (removed_punctuation_index < len(punctuation_indices) and
punctuation_indices[removed_punctuation_index] < current_offset):
commas_inserted_and_punctuation_removed += 1
removed_punctuation_index += 1
current_offset += 1
return commas_inserted_and_punctuation_removed, current_offset, removed_punctuation_index
def _should_insert_comma(i, pipeline_json, new_word_indicator='▁') -> bool:
# Only insert commas for the final token of a word, that is, if next word starts with a space.
return pipeline_json[i - 1]['entity'] == ',' and pipeline_json[i]['word'].startswith(new_word_indicator)
if __name__ == "__main__":
    # Constructing the fixer downloads and caches the model and tokenizer, so running
    # this module once pre-warms the cache before serving traffic.
    BaselineCommaFixer()