diff --git a/src/gluonnlp/vocab/vocab.py b/src/gluonnlp/vocab/vocab.py index 5483b3b678..01fee4a5ab 100644 --- a/src/gluonnlp/vocab/vocab.py +++ b/src/gluonnlp/vocab/vocab.py @@ -30,7 +30,7 @@ from .. import _constants as C from .. import embedding as emb -from ..data.utils import Counter, DefaultLookupDict, count_tokens +from ..data.utils import Counter, count_tokens UNK_IDX = 0 _DEPR_PAD = object() @@ -219,10 +219,7 @@ def __init__(self, counter: Optional[Counter] = None, max_size: Optional[int] = # Set up idx_to_token and token_to_idx based on presence of unknown token self._unknown_token = unknown_token self._idx_to_token = [unknown_token] if unknown_token else [] - if unknown_token: - self._token_to_idx = DefaultLookupDict(UNK_IDX) - else: - self._token_to_idx = {} + self._token_to_idx = dict() # Handle special tokens special_tokens = [] @@ -267,10 +264,6 @@ def __init__(self, counter: Optional[Counter] = None, max_size: Optional[int] = if token_to_idx: self._sort_index_according_to_user_specification(token_to_idx) - if unknown_token: - self._token_to_idx._default = \ - self._token_to_idx[unknown_token] # pytype: disable=not-writable - def _index_counter_keys(self, counter, unknown_token, special_tokens, max_size, min_freq): @@ -395,9 +388,17 @@ def __getitem__(self, tokens): """ if not isinstance(tokens, (list, tuple)): - return self._token_to_idx[tokens] + if self._unknown_token: + unknown_token_idx = self._token_to_idx[self._unknown_token] + return self._token_to_idx.get(tokens, unknown_token_idx) + else: + return self._token_to_idx[tokens] else: - return [self._token_to_idx[token] for token in tokens] + if self._unknown_token: + unknown_token_idx = self._token_to_idx[self._unknown_token] + return [self._token_to_idx.get(token, unknown_token_idx) for token in tokens] + else: + return [self._token_to_idx[token] for token in tokens] def __len__(self): return len(self._idx_to_token)