Steven C committed on
Commit
37a77f2
1 Parent(s): f24f2e7

Format tokenizer_base

Browse files
Files changed (1) hide show
  1. tokenizer_base.py +41 -21
tokenizer_base.py CHANGED
@@ -1,4 +1,3 @@
1
- import re
2
  from abc import ABC, abstractmethod
3
  from itertools import groupby
4
  from typing import List, Optional, Tuple
@@ -13,10 +12,9 @@ class CharsetAdapter:
13
 
14
  def __init__(self, target_charset) -> None:
15
  super().__init__()
16
- self.charset = target_charset ###
17
  self.lowercase_only = target_charset == target_charset.lower()
18
  self.uppercase_only = target_charset == target_charset.upper()
19
- # self.unsupported = f'[^{re.escape(target_charset)}]'
20
 
21
  def __call__(self, label):
22
  if self.lowercase_only:
@@ -28,8 +26,10 @@ class CharsetAdapter:
28
 
29
  class BaseTokenizer(ABC):
30
 
31
- def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None:
32
- self._itos = specials_first + tuple(charset+'[UNK]') + specials_last
 
 
33
  self._stoi = {s: i for i, s in enumerate(self._itos)}
34
 
35
  def __len__(self):
@@ -40,10 +40,12 @@ class BaseTokenizer(ABC):
40
 
41
  def _ids2tok(self, token_ids: List[int], join: bool = True) -> str:
42
  tokens = [self._itos[i] for i in token_ids]
43
- return ''.join(tokens) if join else tokens
44
 
45
  @abstractmethod
46
- def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
 
 
47
  """Encode a batch of labels to a representation suitable for the model.
48
 
49
  Args:
@@ -60,7 +62,9 @@ class BaseTokenizer(ABC):
60
  """Internal method which performs the necessary filtering prior to decoding."""
61
  raise NotImplementedError
62
 
63
- def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]:
 
 
64
  """Decode a batch of token distributions.
65
 
66
  Args:
@@ -84,19 +88,29 @@ class BaseTokenizer(ABC):
84
 
85
 
86
  class Tokenizer(BaseTokenizer):
87
- BOS = '[B]'
88
- EOS = '[E]'
89
- PAD = '[P]'
90
 
91
  def __init__(self, charset: str) -> None:
92
  specials_first = (self.EOS,)
93
  specials_last = (self.BOS, self.PAD)
94
  super().__init__(charset, specials_first, specials_last)
95
- self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last]
96
-
97
- def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
98
- batch = [torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device)
99
- for y in labels]
 
 
 
 
 
 
 
 
 
 
100
  return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)
101
 
102
  def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
@@ -107,21 +121,27 @@ class Tokenizer(BaseTokenizer):
107
  eos_idx = len(ids) # Nothing to truncate.
108
  # Truncate after EOS
109
  ids = ids[:eos_idx]
110
- probs = probs[:eos_idx + 1] # but include prob. for EOS (if it exists)
 
111
  return probs, ids
112
 
113
 
114
  class CTCTokenizer(BaseTokenizer):
115
- BLANK = '[B]'
116
 
117
  def __init__(self, charset: str) -> None:
118
  # BLANK uses index == 0 by default
119
  super().__init__(charset, specials_first=(self.BLANK,))
120
  self.blank_id = self._stoi[self.BLANK]
121
 
122
- def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor:
 
 
123
  # We use a padded representation since we don't want to use CUDNN's CTC implementation
124
- batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels]
 
 
 
125
  return pad_sequence(batch, batch_first=True, padding_value=self.blank_id)
126
 
127
  def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
@@ -129,4 +149,4 @@ class CTCTokenizer(BaseTokenizer):
129
  ids = list(zip(*groupby(ids.tolist())))[0] # Remove duplicate tokens
130
  ids = [x for x in ids if x != self.blank_id] # Remove BLANKs
131
  # `probs` is just pass-through since all positions are considered part of the path
132
- return probs, ids
 
 
1
  from abc import ABC, abstractmethod
2
  from itertools import groupby
3
  from typing import List, Optional, Tuple
 
12
 
13
  def __init__(self, target_charset) -> None:
14
  super().__init__()
15
+ self.charset = target_charset
16
  self.lowercase_only = target_charset == target_charset.lower()
17
  self.uppercase_only = target_charset == target_charset.upper()
 
18
 
19
  def __call__(self, label):
20
  if self.lowercase_only:
 
26
 
27
  class BaseTokenizer(ABC):
28
 
29
+ def __init__(
30
+ self, charset: str, specials_first: tuple = (), specials_last: tuple = ()
31
+ ) -> None:
32
+ self._itos = specials_first + tuple(charset + "[UNK]") + specials_last
33
  self._stoi = {s: i for i, s in enumerate(self._itos)}
34
 
35
  def __len__(self):
 
40
 
41
  def _ids2tok(self, token_ids: List[int], join: bool = True) -> str:
42
  tokens = [self._itos[i] for i in token_ids]
43
+ return "".join(tokens) if join else tokens
44
 
45
  @abstractmethod
46
+ def encode(
47
+ self, labels: List[str], device: Optional[torch.device] = None
48
+ ) -> Tensor:
49
  """Encode a batch of labels to a representation suitable for the model.
50
 
51
  Args:
 
62
  """Internal method which performs the necessary filtering prior to decoding."""
63
  raise NotImplementedError
64
 
65
+ def decode(
66
+ self, token_dists: Tensor, raw: bool = False
67
+ ) -> Tuple[List[str], List[Tensor]]:
68
  """Decode a batch of token distributions.
69
 
70
  Args:
 
88
 
89
 
90
  class Tokenizer(BaseTokenizer):
91
+ BOS = "[B]"
92
+ EOS = "[E]"
93
+ PAD = "[P]"
94
 
95
  def __init__(self, charset: str) -> None:
96
  specials_first = (self.EOS,)
97
  specials_last = (self.BOS, self.PAD)
98
  super().__init__(charset, specials_first, specials_last)
99
+ self.eos_id, self.bos_id, self.pad_id = [
100
+ self._stoi[s] for s in specials_first + specials_last
101
+ ]
102
+
103
+ def encode(
104
+ self, labels: List[str], device: Optional[torch.device] = None
105
+ ) -> Tensor:
106
+ batch = [
107
+ torch.as_tensor(
108
+ [self.bos_id] + self._tok2ids(y) + [self.eos_id],
109
+ dtype=torch.long,
110
+ device=device,
111
+ )
112
+ for y in labels
113
+ ]
114
  return pad_sequence(batch, batch_first=True, padding_value=self.pad_id)
115
 
116
  def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
 
121
  eos_idx = len(ids) # Nothing to truncate.
122
  # Truncate after EOS
123
  ids = ids[:eos_idx]
124
+ # but include prob. for EOS (if it exists)
125
+ probs = probs[: eos_idx + 1]
126
  return probs, ids
127
 
128
 
129
  class CTCTokenizer(BaseTokenizer):
130
+ BLANK = "[B]"
131
 
132
  def __init__(self, charset: str) -> None:
133
  # BLANK uses index == 0 by default
134
  super().__init__(charset, specials_first=(self.BLANK,))
135
  self.blank_id = self._stoi[self.BLANK]
136
 
137
+ def encode(
138
+ self, labels: List[str], device: Optional[torch.device] = None
139
+ ) -> Tensor:
140
  # We use a padded representation since we don't want to use CUDNN's CTC implementation
141
+ batch = [
142
+ torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device)
143
+ for y in labels
144
+ ]
145
  return pad_sequence(batch, batch_first=True, padding_value=self.blank_id)
146
 
147
  def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]:
 
149
  ids = list(zip(*groupby(ids.tolist())))[0] # Remove duplicate tokens
150
  ids = [x for x in ids if x != self.blank_id] # Remove BLANKs
151
  # `probs` is just pass-through since all positions are considered part of the path
152
+ return probs, ids