toandev committed on
Commit
5d724fc
·
1 Parent(s): d1ab771
Files changed (1) hide show
  1. utils/tokenizer_base.py +44 -30
utils/tokenizer_base.py CHANGED
@@ -1,4 +1,3 @@
1
- import re
2
  from abc import ABC, abstractmethod
3
  from itertools import groupby
4
  from typing import List, Optional, Tuple
@@ -13,10 +12,9 @@ class CharsetAdapter:
13
 
14
  def __init__(self, target_charset) -> None:
15
  super().__init__()
16
- self.charset = target_charset ###
17
  self.lowercase_only = target_charset == target_charset.lower()
18
  self.uppercase_only = target_charset == target_charset.upper()
19
- # self.unsupported = f'[^{re.escape(target_charset)}]'
20
 
21
  def __call__(self, label):
22
  if self.lowercase_only:
@@ -28,8 +26,10 @@ class CharsetAdapter:
28
 
29
  class BaseTokenizer(ABC):
30
 
31
- def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
32
- self._itos = specials_first + tuple(charset+'[UNK]') + specials_last
 
 
33
  self._stoi = {s: i for i, s in enumerate(self._itos)}
34
 
35
  def __len__(self):
@@ -40,10 +40,12 @@ class BaseTokenizer(ABC):
40
 
41
  def _ids2tok(self, token_ids: List[int], join: bool = True) -> str:
42
  tokens = [self._itos[i] for i in token_ids]
43
- return ''.join(tokens) if join else tokens
44
 
45
  @abstractmethod
46
- def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
 
 
47
  """Encode a batch of labels to a representation suitable for the model.
48
 
49
  Args:
@@ -60,7 +62,9 @@ class BaseTokenizer(ABC):
60
  """Internal method which performs the necessary filtering prior to decoding."""
61
  raise NotImplementedError
62
 
63
- def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
 
 
64
  """Decode a batch of token distributions.
65
 
66
  Args:
@@ -74,7 +78,7 @@ class BaseTokenizer(ABC):
74
  batch_tokens = []
75
  batch_probs = []
76
  for dist in token_dists:
77
- probs, ids = dist.max(-1) # greedy selection
78
  if not raw:
79
  probs, ids = self._filter(probs, ids)
80
  tokens = self._ids2tok(ids, not raw)
@@ -84,19 +88,29 @@ class BaseTokenizer(ABC):
84
 
85
 
86
  class Tokenizer(BaseTokenizer):
87
- BOS = '[B]'
88
- EOS = '[E]'
89
- PAD = '[P]'
90
 
91
  def __init__(self, charset: str) -> None:
92
  specials_first = (self.EOS,)
93
  specials_last = (self.BOS, self.PAD)
94
  super().__init__(charset, specials_first, specials_last)
95
- self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last]
96
-
97
- def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
98
- batch = [torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device)
99
- for y in labels]
 
 
 
 
 
 
 
 
 
 
100
  return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)
101
 
102
  def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
@@ -104,29 +118,29 @@ class Tokenizer(BaseTokenizer):
104
  try:
105
  eos_idx = ids.index(self.eos_id)
106
  except ValueError:
107
- eos_idx = len(ids) # Nothing to truncate.
108
- # Truncate after EOS
109
  ids = ids[:eos_idx]
110
- probs = probs[:eos_idx + 1] # but include prob. for EOS (if it exists)
111
  return probs, ids
112
 
113
 
114
  class CTCTokenizer(BaseTokenizer):
115
- BLANK = '[B]'
116
 
117
  def __init__(self, charset: str) -> None:
118
- # BLANK uses index == 0 by default
119
  super().__init__(charset, specials_first=(self.BLANK,))
120
  self.blank_id = self._stoi[self.BLANK]
121
 
122
- def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
123
- # We use a padded representation since we don't want to use CUDNN's CTC implementation
124
- batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels]
 
 
 
 
125
  return pad_sequence(batch, batch_first=True, padding_value=self.blank_id)
126
 
127
  def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
128
- # Best path decoding:
129
- ids = list(zip(*groupby(ids.tolist())))[0] # Remove duplicate tokens
130
- ids = [x for x in ids if x != self.blank_id] # Remove BLANKs
131
- # `probs` is just pass-through since all positions are considered part of the path
132
- return probs, ids
 
 
from abc import ABC, abstractmethod
from itertools import groupby
from typing import List, Optional, Tuple, Union
 
12
 
13
  def __init__(self, target_charset) -> None:
14
  super().__init__()
15
+ self.charset = target_charset
16
  self.lowercase_only = target_charset == target_charset.lower()
17
  self.uppercase_only = target_charset == target_charset.upper()
 
18
 
19
  def __call__(self, label):
20
  if self.lowercase_only:
 
26
 
27
  class BaseTokenizer(ABC):
28
 
29
+ def __init__(
30
+ self, charset: str, specials_first: tuple = (), specials_last: tuple = ()
31
+ ) -> None:
32
+ self._itos = specials_first + tuple(charset + "[UNK]") + specials_last
33
  self._stoi = {s: i for i, s in enumerate(self._itos)}
34
 
35
  def __len__(self):
 
40
 
41
  def _ids2tok(self, token_ids: List[int], join: bool = True) -> str:
42
  tokens = [self._itos[i] for i in token_ids]
43
+ return "".join(tokens) if join else tokens
44
 
45
  @abstractmethod
46
+ def encode(
47
+ self, labels: List[str], device: Optional[torch.device] = None
48
+ ) -> Tensor:
49
  """Encode a batch of labels to a representation suitable for the model.
50
 
51
  Args:
 
62
  """Internal method which performs the necessary filtering prior to decoding."""
63
  raise NotImplementedError
64
 
65
+ def decode(
66
+ self, token_dists: Tensor, raw: bool = False
67
+ ) -> Tuple[List[str], List[Tensor]]:
68
  """Decode a batch of token distributions.
69
 
70
  Args:
 
78
  batch_tokens = []
79
  batch_probs = []
80
  for dist in token_dists:
81
+ probs, ids = dist.max(-1)
82
  if not raw:
83
  probs, ids = self._filter(probs, ids)
84
  tokens = self._ids2tok(ids, not raw)
 
88
 
89
 
90
  class Tokenizer(BaseTokenizer):
91
+ BOS = "[B]"
92
+ EOS = "[E]"
93
+ PAD = "[P]"
94
 
95
def __init__(self, charset: str) -> None:
    """Build an autoregressive tokenizer over *charset*.

    EOS is placed first (id 0); BOS and PAD are appended after the
    charset, and the resulting ids are cached for fast access.
    """
    head = (self.EOS,)
    tail = (self.BOS, self.PAD)
    super().__init__(charset, head, tail)
    special_ids = [self._stoi[tok] for tok in head + tail]
    self.eos_id, self.bos_id, self.pad_id = special_ids
102
+
103
+ def encode(
104
+ self, labels: List[str], device: Optional[torch.device] = None
105
+ ) -> Tensor:
106
+ batch = [
107
+ torch.as_tensor(
108
+ [self.bos_id] + self._tok2ids(y) + [self.eos_id],
109
+ dtype=torch.long,
110
+ device=device,
111
+ )
112
+ for y in labels
113
+ ]
114
  return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)
115
 
116
  def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
 
118
  try:
119
  eos_idx = ids.index(self.eos_id)
120
  except ValueError:
121
+ eos_idx = len(ids)
 
122
  ids = ids[:eos_idx]
123
+ probs = probs[: eos_idx + 1]
124
  return probs, ids
125
 
126
 
127
class CTCTokenizer(BaseTokenizer):
    """Tokenizer for CTC-trained models (blank-separated alignments)."""

    BLANK = "[B]"

    def __init__(self, charset: str) -> None:
        # BLANK is placed first so it receives index 0, the conventional
        # CTC blank id.
        super().__init__(charset, specials_first=(self.BLANK,))
        self.blank_id = self._stoi[self.BLANK]

    def encode(
        self, labels: List[str], device: Optional[torch.device] = None
    ) -> Tensor:
        """Encode a batch of labels into a blank-padded id tensor.

        A padded (dense) representation is used rather than CUDNN's
        packed CTC layout; padding uses ``blank_id``.
        """
        batch = [
            torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device)
            for y in labels
        ]
        return pad_sequence(batch, batch_first=True, padding_value=self.blank_id)

    def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
        """Best-path (greedy) CTC decode: collapse repeats, drop blanks.

        Fix: the previous `list(zip(*groupby(ids.tolist())))[0]` raised
        IndexError for an empty `ids`; taking the groupby keys directly
        is equivalent for non-empty input and safely yields [] otherwise.
        """
        collapsed = [key for key, _ in groupby(ids.tolist())]
        ids = [tok for tok in collapsed if tok != self.blank_id]
        # `probs` is passed through untouched: every position is part of
        # the best path under CTC.
        return probs, ids