class Gramformer:
    """Grammar error correction, candidate ranking and edit extraction.

    An instance bundles three components, all created in ``__init__``:

    * ``self.annotator`` -- an ``errant`` annotator used to align an original
      sentence with its correction and classify the differences;
    * ``self.scorer`` -- a GPT-2 based ``lm_scorer`` scorer used to rank
      candidate corrections;
    * ``self.correction_tokenizer`` / ``self.correction_model`` -- the
      seq2seq correction model (only set when ``models == 1``).
    """

    def __init__(self, models: int = 1, use_gpu: bool = False):
        """Load the annotator, the scorer and (for ``models == 1``) the model.

        Args:
            models: ``1`` loads the correction model. ``2`` only prints a
                "to be implemented" notice. Any other value loads no model.
                In the latter two cases ``self.model_loaded`` stays ``False``.
            use_gpu: When true the device string is ``"cuda:0"``, otherwise
                ``"cpu"``.
        """
        # The third-party dependencies are imported here rather than at
        # module level, so merely importing this module does not need them.
        from transformers import AutoTokenizer
        from transformers import AutoModelForSeq2SeqLM
        from lm_scorer.models.auto import AutoLMScorer as LMScorer
        import errant
        import en_core_web_sm

        nlp = en_core_web_sm.load()
        self.annotator = errant.load('en', nlp)

        device = "cuda:0" if use_gpu else "cpu"
        batch_size = 1
        self.scorer = LMScorer.from_pretrained("gpt2", device=device, batch_size=batch_size)
        self.device = device
        correction_model_tag = "prithivida/grammar_error_correcter_v1"
        self.model_loaded = False

        if models == 1:
            self.correction_tokenizer = AutoTokenizer.from_pretrained(correction_model_tag)
            self.correction_model = AutoModelForSeq2SeqLM.from_pretrained(correction_model_tag)
            self.correction_model = self.correction_model.to(device)
            self.model_loaded = True
            print("[Gramformer] Grammar error correct/highlight model loaded..")
        elif models == 2:
            # TODO
            print("TO BE IMPLEMENTED!!!")

    def correct(self, input_sentence: str, max_candidates: int = 1):
        """Generate corrected candidates for ``input_sentence`` and rank them.

        Args:
            input_sentence: The sentence to correct; ``"gec: "`` is prepended
                before it is tokenized.
            max_candidates: Passed to ``generate`` as ``num_return_sequences``.
                Identical decoded candidates are collapsed, so fewer results
                may be returned.

        Returns:
            A list of ``(candidate, score)`` tuples sorted by descending
            score, where the scores are whatever
            ``self.scorer.sentence_score(..., log=True)`` returns
            (presumably log-probabilities -- confirm against lm_scorer).
            Returns ``None`` after printing a notice when no correction model
            was loaded.
        """
        if not self.model_loaded:
            print("Model is not loaded")
            return None

        correction_prefix = "gec: "
        input_sentence = correction_prefix + input_sentence
        input_ids = self.correction_tokenizer.encode(input_sentence, return_tensors='pt')
        input_ids = input_ids.to(self.device)

        # Generation is invoked with do_sample=True plus top_k/top_p limits.
        preds = self.correction_model.generate(
            input_ids,
            do_sample=True,
            max_length=128,
            top_k=50,
            top_p=0.95,
            early_stopping=True,
            num_return_sequences=max_candidates)

        # A set drops duplicate candidates before they are scored.
        corrected = list({
            self.correction_tokenizer.decode(pred, skip_special_tokens=True).strip()
            for pred in preds
        })

        scores = self.scorer.sentence_score(corrected, log=True)
        return sorted(zip(corrected, scores), key=lambda pair: pair[1], reverse=True)

    def highlight(self, orig: str, cor: str) -> str:
        """Rebuild ``orig`` token by token, rewriting the tokens touched by edits.

        ``orig`` is split on whitespace. For every edit between ``orig`` and
        ``cor`` the token at the edit's start index is replaced by the edit's
        (possibly multi-token) original text, and the remaining tokens of the
        edit span are dropped. The tokens are finally re-joined with single
        spaces.

        NOTE(review): every wrapper string below is empty, so the edited token
        text is emitted without any decoration. This looks like markup that
        was lost at some point -- confirm what the wrappers should contain.

        NOTE(review): edit indices come from ``self.annotator`` while
        ``orig_tokens`` comes from ``str.split``; if the two tokenizations
        disagree an index can still be out of range -- verify with callers.

        Args:
            orig: The original sentence.
            cor: The corrected sentence.

        Returns:
            The re-joined token string.
        """
        edits = self._get_edits(orig, cor)
        orig_tokens = orig.split()
        # A set, so each surplus index is deleted exactly once even if two
        # edits report it; deleting the same index twice would remove an
        # unrelated token.
        ignore_indexes = set()

        for edit in edits:
            edit_type, edit_str_start, edit_spos, edit_epos, edit_str_end = edit[:5]

            # An edit covering several original tokens keeps only its first
            # slot; the other positions are removed once all edits are applied.
            ignore_indexes.update(range(edit_spos + 1, edit_epos))

            if edit_str_start == "":
                # Nothing in the original to rewrite (an insertion), so the
                # edit is attached to a neighbouring token instead.
                if edit_spos - 1 >= 0:
                    edit_spos -= 1
                elif edit_spos + 1 < len(orig_tokens):
                    edit_spos += 1
                elif not orig_tokens:
                    # Empty original: there is no token to attach to.
                    continue
                # Otherwise the sentence has a single token and the edit is
                # attached to it; indexing the "next" token would raise
                # IndexError here.
                new_edit_str = orig_tokens[edit_spos]
                if edit_type == "PUNCT":
                    st = "" + new_edit_str + ""
                else:
                    st = "" + new_edit_str + ""
                orig_tokens[edit_spos] = st
            elif edit_str_end == "":
                # Original text with no counterpart in the correction.
                st = "" + edit_str_start + ""
                orig_tokens[edit_spos] = st
            else:
                # Original text replaced by different corrected text.
                st = "" + edit_str_start + ""
                orig_tokens[edit_spos] = st

        # Delete from the highest index down so earlier deletions do not
        # shift the positions that are still pending.
        for i in sorted(ignore_indexes, reverse=True):
            del orig_tokens[i]

        return " ".join(orig_tokens)

    def detect(self, input_sentence):
        """Placeholder; not implemented yet and always returns ``None``."""
        # TO BE IMPLEMENTED
        pass

    def _get_edits(self, orig, cor):
        """Return the classified edits that turn ``orig`` into ``cor``.

        Both sentences are parsed, aligned and merged by ``self.annotator``;
        each merged edit is then classified.

        Returns:
            A list (empty when there are no edits) of tuples
            ``(type, o_str, o_start, o_end, c_str, c_start, c_end)``, where
            ``type`` is the edit's type with its first two characters removed
            (presumably the operation prefix such as ``"R:"`` -- confirm
            against the annotator's documentation).
        """
        orig = self.annotator.parse(orig)
        cor = self.annotator.parse(cor)
        alignment = self.annotator.align(orig, cor)
        edits = self.annotator.merge(alignment)

        return [
            (e.type[2:], e.o_str, e.o_start, e.o_end, e.c_str, e.c_start, e.c_end)
            for e in map(self.annotator.classify, edits)
        ]

    def get_edits(self, orig, cor):
        """Public wrapper around ``_get_edits``; same arguments and result."""
        return self._get_edits(orig, cor)