edugp commited on
Commit
0def03f
·
1 Parent(s): ab846df

Replicate default cc_net preprocessing at inference time on KenlmModel.get_perplexity

Browse files
Files changed (1) hide show
  1. perplexity_lenses/perplexity.py +88 -1
perplexity_lenses/perplexity.py CHANGED
@@ -1,10 +1,53 @@
1
  import os
 
 
2
  import urllib.request
 
3
 
4
  import kenlm
5
 
6
 
7
  class KenlmModel:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
  def __init__(self, language):
9
  download_kenlm_model(language)
10
  try:
@@ -19,7 +62,9 @@ class KenlmModel:
19
  def from_pretrained(cls, language: str):
20
  return cls(language)
21
 
22
- def get_perplexity(self, doc: str):
 
 
23
  doc_log_score, doc_length = 0, 0
24
  for line in doc.split("\n"):
25
  log_score = self.model.score(line)
@@ -28,6 +73,48 @@ class KenlmModel:
28
  doc_length += length
29
  return 10.0 ** (-doc_log_score / doc_length)
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  def download_kenlm_model(language: str):
33
  root_url = "http://dl.fbaipublicfiles.com/cc_net/lm"
 
1
  import os
2
+ import re
3
+ import unicodedata
4
  import urllib.request
5
+ from typing import Dict
6
 
7
  import kenlm
8
 
9
 
10
  class KenlmModel:
11
+ digit_re: re.Pattern = re.compile(r"\d")
12
+ unicode_punct: Dict[str, str] = {
13
+ ",": ",",
14
+ "。": ".",
15
+ "、": ",",
16
+ "„": '"',
17
+ "”": '"',
18
+ "“": '"',
19
+ "«": '"',
20
+ "»": '"',
21
+ "1": '"',
22
+ "」": '"',
23
+ "「": '"',
24
+ "《": '"',
25
+ "》": '"',
26
+ "´": "'",
27
+ "∶": ":",
28
+ ":": ":",
29
+ "?": "?",
30
+ "!": "!",
31
+ "(": "(",
32
+ ")": ")",
33
+ ";": ";",
34
+ "–": "-",
35
+ "—": " - ",
36
+ ".": ". ",
37
+ "~": "~",
38
+ "’": "'",
39
+ "…": "...",
40
+ "━": "-",
41
+ "〈": "<",
42
+ "〉": ">",
43
+ "【": "[",
44
+ "】": "]",
45
+ "%": "%",
46
+ "►": "-",
47
+ }
48
+ unicode_punct_re = re.compile(f"[{''.join(unicode_punct.keys())}]")
49
+ non_printing_chars_re = re.compile(f"[{''.join(map(chr, list(range(0,32)) + list(range(127,160))))}]")
50
+
51
  def __init__(self, language):
52
  download_kenlm_model(language)
53
  try:
 
62
  def from_pretrained(cls, language: str):
63
  return cls(language)
64
 
65
+ def get_perplexity(self, doc: str, normalize_cc_net: bool = True):
66
+ if normalize_cc_net:
67
+ doc = self.normalize(doc)
68
  doc_log_score, doc_length = 0, 0
69
  for line in doc.split("\n"):
70
  log_score = self.model.score(line)
 
73
  doc_length += length
74
  return 10.0 ** (-doc_log_score / doc_length)
75
 
76
+ def normalize(
77
+ self,
78
+ line: str,
79
+ accent: bool = True,
80
+ case: bool = True,
81
+ numbers: bool = True,
82
+ punct: int = 1,
83
+ ) -> str:
84
+ line = line.strip()
85
+ if not line:
86
+ return line
87
+ if case:
88
+ line = line.lower()
89
+ if accent:
90
+ line = self.strip_accents(line)
91
+ if numbers:
92
+ line = self.digit_re.sub("0", line)
93
+ if punct == 1:
94
+ line = self.replace_unicode_punct(line)
95
+ elif punct == 2:
96
+ line = self.remove_unicode_punct(line)
97
+ line = self.remove_non_printing_char(line)
98
+ return line
99
+
100
+ def strip_accents(self, line: str) -> str:
101
+ """Strips accents from a piece of text."""
102
+ nfd = unicodedata.normalize("NFD", line)
103
+ output = [c for c in nfd if unicodedata.category(c) != "Mn"]
104
+ if len(output) == line:
105
+ return line
106
+ return "".join(output)
107
+
108
+ def replace_unicode_punct(self, text: str) -> str:
109
+ return "".join((self.unicode_punct.get(c, c) for c in text))
110
+
111
+ def remove_unicode_punct(self, text: str) -> str:
112
+ """More aggressive version of replace_unicode_punct but also faster."""
113
+ return self.unicode_punct_re.sub("", text)
114
+
115
+ def remove_non_printing_char(self, text: str) -> str:
116
+ return self.non_printing_chars_re.sub("", text)
117
+
118
 
119
  def download_kenlm_model(language: str):
120
  root_url = "http://dl.fbaipublicfiles.com/cc_net/lm"