Source code for cotk.dataloader.vocab

'''A module for vocab'''
from typing import Optional, List, Dict, Any
from collections import Counter, OrderedDict
from itertools import chain
import logging
import hashlib

from .._utils.typehint import OrderedDictType
from .._utils.metaclass import DocStringInheritor, LoadClassInterface
from .._utils.unordered_hash import dumps
from .context import VocabContext
from .tokenizer import PretrainedTokenizer


class Vocab(LoadClassInterface, metaclass=DocStringInheritor):
    '''A class for storing vocabulary.
    This is an abstract base class.
    It often works as a part of :class:`Field` or is shared between :class:`Field`.
    See :ref:`introduction of vocabulary<vocabulary_ref>` for more information.

    Arguments:
        This class does not contain any arguments for initialization.
    '''

    NOT_SPECIFIED_DOCS = r'''
        If any argument is not specified,
        the value will be first retrieved from :class:`VocabContext`.
        If still ``None``, a default value will be used.
    '''

    def __init__(self):
        if self.__class__.__name__ == "Vocab":
            raise NotImplementedError("This class is an abstract class, use GeneralVocab instead.")
        self._setting_hash: Optional[str] = None
    def add_tokens(self, tokens: List[str], vocab_from: str) -> None:
        '''Add tokens for this vocabulary instance. The tokens will be used for building the vocabulary list.
        Must be called before :meth:`.build_vocab`.

        Arguments:
            tokens (List[str]): A list of tokens to add to the vocabulary.
            vocab_from (str): One of ``train``, ``test``, ``extra``.

                * ``train``: The tokens are from the training data. Frequent vocabs are selected from tokens of this type.
                * ``test``: The tokens are from the validation data or test data. Rare vocabs are selected from tokens of this type.
                * ``extra``: The tokens are from extra data. Tokens of this type will not be selected as frequent or rare vocabs.
        '''
        raise NotImplementedError
    def build_vocab(self):
        '''Build the vocabulary list according to the tokens from :meth:`.add_tokens`.
        '''
        raise NotImplementedError
    _VOCAB_MORE_DOCSTRING = ""

    CONVERT_TOKENS_TO_IDS_ARG = """
            tokens (List[str]): List of tokens.
            only_frequent_word (bool, optional): Use ``unk`` for rare tokens. Default: ``False``.
    """
    def convert_tokens_to_ids(self, tokens: List[str], only_frequent_word=False) -> List[int]:
        '''Convert a list of tokens to a list of ids. {_VOCAB_MORE_DOCSTRING}

        Arguments:{CONVERT_TOKENS_TO_IDS_ARG}
        '''
        raise NotImplementedError
    def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
        '''Convert a list of ids to a list of tokens. {_VOCAB_MORE_DOCSTRING}

        Arguments:
            ids (List[int]): List of ids.
        '''
        raise NotImplementedError
    @property
    def frequent_vocab_size(self):
        '''int: The number of **frequent** words. {_VOCAB_MORE_DOCSTRING}
        '''
        raise NotImplementedError

    @property
    def all_vocab_size(self):
        '''int: The number of frequent words and rare words. {_VOCAB_MORE_DOCSTRING}
        '''
        raise NotImplementedError

    @property
    def frequent_vocab_list(self):
        '''list: The list of frequent words. {_VOCAB_MORE_DOCSTRING}
        '''
        raise NotImplementedError

    @property
    def all_vocab_list(self):
        '''list: The list of frequent words and rare words.
        Frequent words are always in the front of the list. {_VOCAB_MORE_DOCSTRING}
        '''
        raise NotImplementedError

    SPECIAL_TOKEN_DOCS = '''Special tokens mapping is an ordered dict \
mapping the general name of a special token to its string. \
The key must be one of the following: \
``pad``, ``unk``, ``go``, ``eos``, ``sep``, ``cls``, ``mask``. \
The value can be an arbitrary string, e.g., ``"<pad>"``, ``"<unk>"``.'''
    def get_special_tokens_mapping(self) -> OrderedDictType[str, str]:
        '''Get special tokens mapping. {SPECIAL_TOKEN_DOCS} {_VOCAB_MORE_DOCSTRING}
        '''
        raise NotImplementedError
    def get_special_tokens_id(self, name: str) -> int:
        '''Get the id of a special token by its general name.
        Raise ``KeyError`` if there is no such token in this instance. {_VOCAB_MORE_DOCSTRING}

        Arguments:
            name (str): The general name, must be one of the following:
                ``pad``, ``unk``, ``go``, ``eos``, ``sep``, ``cls``, ``mask``.
        '''
        raise NotImplementedError
    @property
    def pad_id(self) -> int:
        '''int: The id of pad token. Raise ``KeyError`` if no pad token in this instance. {_VOCAB_MORE_DOCSTRING}
        '''
        return self.get_special_tokens_id("pad")

    @property
    def unk_id(self) -> int:
        '''int: The id of unk token. Raise ``KeyError`` if no unk token in this instance. {_VOCAB_MORE_DOCSTRING}
        '''
        return self.get_special_tokens_id("unk")

    @property
    def go_id(self) -> int:
        '''int: The id of go token. Raise ``KeyError`` if no go token in this instance. {_VOCAB_MORE_DOCSTRING}
        '''
        return self.get_special_tokens_id("go")

    @property
    def eos_id(self) -> int:
        '''int: The id of eos token. Raise ``KeyError`` if no eos token in this instance. {_VOCAB_MORE_DOCSTRING}
        '''
        return self.get_special_tokens_id("eos")
    def get_setting_hash(self) -> str:
        '''Get the setting hash of the vocabulary instance.
        See :ref:`here <dataloader_hash_ref>` for the explanation of ``setting hash``.
        '''
        assert self._setting_hash is not None
        return self._setting_hash
    def get_vocab_hash(self) -> str:
        '''Get the vocab hash of the vocabulary instance.
        See :ref:`here <dataloader_hash_ref>` for the explanation of ``vocab hash``.
        '''
        raise NotImplementedError


class GeneralVocab(Vocab):
    '''Bases: :class:`.dataloader.Vocab`

    A vocabulary class for general use.

    This class always has the following 4 special tokens: ``pad``, ``unk``, ``go``, ``eos``.

    {NOT_SPECIFIED_DOCS}

    Arguments:
        {MIN_FREQUENT_VOCAB_TIMES_DOCS} {MIN_FREQUENT_VOCAB_TIMES_DEFAULT}
        {MIN_RARE_VOCAB_TIMES_DOCS} {MIN_RARE_VOCAB_TIMES_DEFAULT}
        {SPECIAL_TOKEN_DOCS} {SPECIAL_TOKEN_DEFAULT}
        special_appeared_in_data (bool, optional): Whether the strings of special tokens
            appear in the data. Default: If not specified, it will be ``False``.
    '''

    MIN_FREQUENT_VOCAB_TIMES_DOCS = r"""
        min_frequent_vocab_times (int, optional): Tokens from the training data
            that appear no less than ``min_frequent_vocab_times`` will be regarded as frequent words."""
    MIN_FREQUENT_VOCAB_TIMES_DEFAULT = r"""Default: ``0``"""

    MIN_RARE_VOCAB_TIMES_DOCS = r"""
        min_rare_vocab_times (int, optional): Tokens from the training or test data
            that appear no less than ``min_rare_vocab_times`` will be regarded as rare words
            (frequent words excluded)."""
    MIN_RARE_VOCAB_TIMES_DEFAULT = r"""Default: ``0``"""

    SPECIAL_TOKEN_DOCS = r"""
        special_tokens_mapping (OrderedDict, optional): {Vocab.SPECIAL_TOKEN_DOCS}
            It must at least contain ``pad``, ``unk``, ``go``, ``eos``.
            The values of the special tokens must be distinct."""
    SPECIAL_TOKEN_DEFAULT = r"""Default: If ``None``, it will be
            ``OrderedDict([("pad", "<pad>"), ("unk", "<unk>"), ("go", "<go>"), ("eos", "<eos>")])``."""

    def __init__(self, min_frequent_vocab_times: Optional[int] = None,
            min_rare_vocab_times: Optional[int] = None,
            special_tokens_mapping: Optional[OrderedDictType[str, str]] = None,
            special_appeared_in_data: Optional[bool] = None):
        super().__init__()
        with VocabContext.set_parameters(
                min_frequent_vocab_times=min_frequent_vocab_times,
                min_rare_vocab_times=min_rare_vocab_times,
                special_tokens_mapping=special_tokens_mapping,
                special_appeared_in_data=special_appeared_in_data):
            self.min_frequent_vocab_times: int = VocabContext.get("min_frequent_vocab_times", 0)
            self.min_rare_vocab_times: int = VocabContext.get("min_rare_vocab_times", 0)
            filled_special_tokens: Optional[OrderedDictType[str, str]] = VocabContext.get("special_tokens_mapping", None)
            self.special_appeared_in_data: bool = VocabContext.get("special_appeared_in_data", False)

        self.special_tokens_mapping = filled_special_tokens or OrderedDict(
            [("pad", "<pad>"), ("unk", "<unk>"), ("go", "<go>"), ("eos", "<eos>")]
        )
        if {"pad", "unk", "go", "eos"}.difference(set(self.special_tokens_mapping.keys())):
            raise ValueError("Special tokens should at least contain 4 tokens: pad, unk, go, eos.")
        if set(self.special_tokens_mapping.keys()).difference({"pad", "unk", "go", "eos", "sep", "cls", "mask"}):
            raise ValueError("Special tokens should not contain keys other than pad, unk, go, eos, sep, cls, mask.")
        if len(set(self.special_tokens_mapping.values())) != len(set(self.special_tokens_mapping.keys())):
            raise ValueError("The values of the special tokens must be distinct.")

        self.mode = "init"
        self.train_tokens: Optional[List[str]] = []
        self.test_tokens: Optional[List[str]] = []

        self._all_vocab_list: Optional[List[str]] = None
        self.word2id: Optional[Dict[str, int]] = None
        self._frequent_vocab_size: int = 0

        self._setting_hash = hashlib.sha256(dumps([
            "Vocab",
            "configs",
            self.min_frequent_vocab_times,
            self.min_rare_vocab_times,
            self.special_tokens_mapping,
            self.special_appeared_in_data
        ])).hexdigest()
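    # A short sketch of the argument precedence (hypothetical values): an explicit
    # argument beats :class:`VocabContext`, which beats the built-in default.
    #
    #     >>> with VocabContext.set_parameters(min_frequent_vocab_times=5):
    #     ...     GeneralVocab().min_frequent_vocab_times            # from the context
    #     5
    #     >>> GeneralVocab(min_frequent_vocab_times=3).min_frequent_vocab_times
    #     3
    #     >>> GeneralVocab().min_frequent_vocab_times                # built-in default
    #     0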
    @staticmethod
    def from_predefined(vocab_list: List[str],
            frequent_vocab_size: int,
            special_tokens_mapping: Optional[OrderedDictType[str, str]] = None) -> "GeneralVocab":
        '''Return a :class:`GeneralVocab` instance, whose vocabulary comes from a predefined list.
        See :meth:`.from_predefined_vocab` if you want to use the vocabulary
        from an existing :class:`GeneralVocab` instance.

        Arguments:
            vocab_list (List[str]): A list of all the vocabulary.
            frequent_vocab_size (int): The number of frequent words.
                The frequent words must be in the front of ``vocab_list``.
            {SPECIAL_TOKEN_DOCS} Special tokens MUST be in the front of ``vocab_list`` (order sensitive).
            {SPECIAL_TOKEN_DEFAULT}
        '''
        vocab = GeneralVocab(special_tokens_mapping=special_tokens_mapping)
        special_values = list(vocab.get_special_tokens_mapping().values())
        if vocab_list[:len(special_values)] != special_values:
            raise ValueError("Special tokens should be in the front of the vocab_list, "
                "where special tokens are %s, but the first tokens of vocab_list are %s."
                % (repr(special_values), repr(vocab_list[:len(special_values)])))
        if len(set(vocab_list)) != len(vocab_list):
            raise ValueError("vocab_list should not contain a single token multiple times.")

        # pylint: disable=protected-access
        vocab.mode = "finish"
        vocab._all_vocab_list = vocab_list
        vocab._frequent_vocab_size = frequent_vocab_size
        vocab.word2id = {w: i for i, w in enumerate(vocab.all_vocab_list)}
        vocab.train_tokens = None
        vocab.test_tokens = None
        vocab._setting_hash = hashlib.sha256(dumps([
            "Vocab",
            "predefined",
            vocab.all_vocab_list,
            vocab._frequent_vocab_size,
            len(vocab.special_tokens_mapping)
        ])).hexdigest()
        return vocab
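    # A usage sketch (hypothetical vocab list; the default special tokens must lead
    # the list, frequent words come right after them):
    #
    #     >>> vocab = GeneralVocab.from_predefined(
    #     ...     ["<pad>", "<unk>", "<go>", "<eos>", "hello", "world", "rareword"],
    #     ...     frequent_vocab_size=6)
    #     >>> vocab.frequent_vocab_list
    #     ['<pad>', '<unk>', '<go>', '<eos>', 'hello', 'world']
    #     >>> vocab.convert_tokens_to_ids(["rareword"], only_frequent_word=True)
    #     [1]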
    @staticmethod
    def from_predefined_vocab(vocab: "GeneralVocab") -> "GeneralVocab":
        '''Return a new :class:`GeneralVocab` instance from ``vocab``.
        The new instance has the same vocabulary list as the old one.

        Arguments:
            vocab (:class:`GeneralVocab`): The old instance.
        '''
        if not isinstance(vocab, GeneralVocab):
            raise TypeError("vocab must be an instance of GeneralVocab class.")
        vocab_list = vocab.all_vocab_list
        frequent_vocab_size = vocab.frequent_vocab_size
        special_token_mappings = vocab.get_special_tokens_mapping()
        return GeneralVocab.from_predefined(vocab_list, frequent_vocab_size, special_token_mappings)
    @staticmethod
    def from_frequent_word(frequent_vocab_list: List[str],
            special_tokens_mapping: Optional[OrderedDictType[str, str]] = None) -> "GeneralVocab":
        '''Return a :class:`GeneralVocab` instance, whose vocabulary comes from a predefined
        frequent word list, while its rare word list can be built later.
        See :meth:`.from_frequent_word_of_vocab` if you want to use the frequent vocabulary
        from an existing :class:`GeneralVocab` instance.

        Arguments:
            frequent_vocab_list (List[str]): A list of frequent vocabulary.
            {SPECIAL_TOKEN_DOCS} Special tokens MUST be in the front of ``frequent_vocab_list`` (order sensitive).
            {SPECIAL_TOKEN_DEFAULT}
        '''
        vocab = GeneralVocab(special_tokens_mapping=special_tokens_mapping)
        special_values = list(vocab.get_special_tokens_mapping().values())
        if frequent_vocab_list[:len(special_values)] != special_values:
            raise ValueError("Special tokens should be in the front of the frequent_vocab_list, "
                "where special tokens are %s, but the first tokens of frequent_vocab_list are %s."
                % (repr(special_values), repr(frequent_vocab_list[:len(special_values)])))

        # pylint: disable=protected-access
        vocab.mode = "frequent_specified"
        vocab._all_vocab_list = frequent_vocab_list
        vocab._setting_hash = hashlib.sha256(dumps([
            "Vocab",
            "frequent",
            frequent_vocab_list,
            special_tokens_mapping
        ])).hexdigest()
        return vocab
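    # A usage sketch (hypothetical tokens): the frequent list is fixed up front,
    # while rare words are still collected from added tokens at build time:
    #
    #     >>> vocab = GeneralVocab.from_frequent_word(
    #     ...     ["<pad>", "<unk>", "<go>", "<eos>", "hello"])
    #     >>> vocab.add_tokens(["hello", "rareword"], "test")
    #     >>> vocab.build_vocab()
    #     >>> vocab.all_vocab_list
    #     ['<pad>', '<unk>', '<go>', '<eos>', 'hello', 'rareword']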
    @staticmethod
    def from_frequent_word_of_vocab(vocab: "GeneralVocab") -> "GeneralVocab":
        '''Return a :class:`GeneralVocab` instance, which has the same frequent vocabulary list
        as the old one. The rare word list can be built later.

        Arguments:
            vocab (:class:`GeneralVocab`): The old instance providing frequent words.
        '''
        if not isinstance(vocab, GeneralVocab):
            raise TypeError("vocab must be an instance of GeneralVocab class.")
        frequent_vocab_list = vocab.all_vocab_list[:vocab.frequent_vocab_size]
        special_token_mappings = vocab.get_special_tokens_mapping()
        # build from the frequent part only; the rare word list can be added later
        return GeneralVocab.from_frequent_word(frequent_vocab_list, special_token_mappings)
    def add_tokens(self, tokens: List[str], vocab_from: str) -> None:
        if self.train_tokens is None or self.test_tokens is None:
            return  # the vocabulary has been built; extra tokens are ignored
        if vocab_from == "train":
            self.train_tokens.extend(tokens)
        elif vocab_from == "test":
            self.test_tokens.extend(tokens)
        elif vocab_from == "extra":
            pass
        else:
            raise ValueError("Unknown vocab_from: %s, only supports train, test or extra." % vocab_from)

    def build_vocab(self) -> None:
        if self.mode == "finish":
            return  # the vocabulary has been built; do not build again
        if self.train_tokens is None or self.test_tokens is None:
            raise RuntimeError("Train tokens or test tokens should not be None.")

        if not self.special_appeared_in_data:
            all_token_set = set(chain(self.train_tokens, self.test_tokens))
            for special_token in self.special_tokens_mapping.values():
                if special_token in all_token_set:
                    raise RuntimeError("Dataset file contains special token %s. If it is desired, "
                        "try to set 'special_appeared_in_data' to True in Vocab or Dataloader." % special_token)

        exclude_set = set(self.special_tokens_mapping.values())
        if self.mode != "frequent_specified":
            assert self._all_vocab_list is None
            # sort by frequency (descending), then alphabetically for ties
            vocab = sorted(Counter(self.train_tokens).most_common(),
                key=lambda pair: (-pair[1], pair[0]))
            frequent_vocab = [x[0] for x in vocab if x[1] >= self.min_frequent_vocab_times and x[0] not in exclude_set]
        else:
            assert self._all_vocab_list is not None
            # the list stored by from_frequent_word begins with the special tokens;
            # strip them here so they are not duplicated when the full list is assembled below
            frequent_vocab = self._all_vocab_list[len(self.special_tokens_mapping):]

        exclude_set.update(frequent_vocab)
        vocab = sorted(Counter(chain(self.train_tokens, self.test_tokens)).most_common(),
            key=lambda pair: (-pair[1], pair[0]))
        rare_vocab = [x[0] for x in vocab if x[1] >= self.min_rare_vocab_times
            and x[0] not in exclude_set]

        self._all_vocab_list = list(self.special_tokens_mapping.values()) + frequent_vocab + rare_vocab
        self._frequent_vocab_size = len(self.special_tokens_mapping) + len(frequent_vocab)
        logging.info("frequent vocab list length = %d", self._frequent_vocab_size)
        logging.info("frequent + rare vocab list length = %d", len(self._all_vocab_list))

        self.word2id = {w: i for i, w in enumerate(self._all_vocab_list)}

        self.train_tokens = None
        self.test_tokens = None
        self.mode = "finish"

    def get_special_tokens_id(self, name) -> int:
        try:
            return self.word2id[self.special_tokens_mapping[name]]  # type: ignore
        except KeyError:
            raise KeyError("No such special token in this class.")

    def convert_tokens_to_ids(self, tokens: List[str], only_frequent_word=False) -> List[int]:
        if self.word2id is None:
            raise RuntimeError("You have to run build_vocab first.")
        ids = [self.word2id.get(token, self.unk_id) for token in tokens]
        if only_frequent_word:
            ids = [self.unk_id if i >= self._frequent_vocab_size else i for i in ids]
        return ids

    def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
        if self._all_vocab_list is None:
            raise RuntimeError("You have to run build_vocab first.")
        return [self._all_vocab_list[i] for i in ids]

    def get_vocab_hash(self) -> str:
        return hashlib.sha256(dumps([
            self._all_vocab_list,
            self._frequent_vocab_size,
            len(self.special_tokens_mapping)
        ])).hexdigest()

    @property
    def frequent_vocab_size(self):
        return self._frequent_vocab_size

    @property
    def all_vocab_size(self):
        return len(self._all_vocab_list)  # type: ignore

    @property
    def frequent_vocab_list(self):
        return self._all_vocab_list[:self._frequent_vocab_size]  # type: ignore

    @property
    def all_vocab_list(self):
        return self._all_vocab_list[:]  # type: ignore

    def get_special_tokens_mapping(self):
        return self.special_tokens_mapping
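
# A minimal end-to-end sketch of GeneralVocab (hypothetical tokens; the default
# special tokens <pad>, <unk>, <go>, <eos> occupy ids 0-3):
#
#     >>> vocab = GeneralVocab(min_frequent_vocab_times=2)
#     >>> vocab.add_tokens(["hello", "hello", "world"], "train")
#     >>> vocab.add_tokens(["world", "rareword"], "test")
#     >>> vocab.build_vocab()
#     >>> vocab.frequent_vocab_list
#     ['<pad>', '<unk>', '<go>', '<eos>', 'hello']
#     >>> vocab.convert_tokens_to_ids(["hello", "world"])
#     [4, 5]
#     >>> vocab.convert_tokens_to_ids(["hello", "world"], only_frequent_word=True)
#     [4, 1]
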
class PretrainedVocab(Vocab):
    '''Bases: :class:`.dataloader.Vocab`

    Use the vocabulary from a pretrained tokenizer in the ``transformers`` package.
    This class is usually used for pretrained models, and it does **NOT** have rare words.

    Unlike :class:`GeneralVocab`, this class does not always have
    ``pad``, ``unk``, ``go``, ``eos``. Some special tokens may refer to the same token.

    Arguments:
        tokenizer (``transformers.PreTrainedTokenizer``): A pretrained tokenizer from the transformers package.
    '''

    def __init__(self, tokenizer: Any):
        super().__init__()
        self.tokenizer = PretrainedTokenizer(tokenizer)
        self._inner_tokenizer = tokenizer
        self._setting_hash = hashlib.sha256(dumps(["pretrained", self.tokenizer.get_setting_hash()])).hexdigest()

    def add_tokens(self, tokens: List[str], vocab_from: str) -> None:
        pass

    def build_vocab(self) -> None:
        pass

    def get_vocab_hash(self):
        return self._setting_hash  # the vocab hash is represented by the tokenizer

    def convert_tokens_to_ids(self, tokens: List[str], only_frequent_word=False) -> List[int]:
        # a pretrained vocab has no rare words, so only_frequent_word has no effect
        return self._inner_tokenizer.convert_tokens_to_ids(tokens)

    def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
        return self._inner_tokenizer.convert_ids_to_tokens(ids)

    @property
    def frequent_vocab_size(self):
        return self._inner_tokenizer.vocab_size

    @property
    def all_vocab_size(self):
        return self._inner_tokenizer.vocab_size

    @property
    def frequent_vocab_list(self):
        return self.convert_ids_to_tokens(list(range(self.frequent_vocab_size)))

    @property
    def all_vocab_list(self):
        return self.frequent_vocab_list

    def get_special_tokens_mapping(self):
        # translate the names used by transformers to the general names used by cotk
        old_key = ["pad_token", "unk_token", "bos_token", "eos_token", "sep_token", "cls_token", "mask_token"]
        new_key = ["pad", "unk", "go", "eos", "sep", "cls", "mask"]
        res = OrderedDict()
        for key, value in self._inner_tokenizer.special_tokens_map.items():
            if key in old_key:
                idx = old_key.index(key)
                res[new_key[idx]] = value
        return res

    def get_special_tokens_id(self, name):
        try:
            return self.convert_tokens_to_ids([self.get_special_tokens_mapping()[name]])[0]
        except KeyError:
            raise KeyError("No such special token in this class.")
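
# A usage sketch, assuming the ``transformers`` package is installed
# (``GPT2Tokenizer`` is just one possible choice of tokenizer):
#
#     >>> from transformers import GPT2Tokenizer
#     >>> tok = GPT2Tokenizer.from_pretrained("gpt2")
#     >>> vocab = PretrainedVocab(tok)
#     >>> vocab.all_vocab_size == vocab.frequent_vocab_size    # no rare words
#     True
#     >>> vocab.get_special_tokens_mapping()   # several names may share one token,
#     ...                                      # e.g. go/eos/unk all '<|endoftext|>' for GPT-2
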
class SimpleVocab(Vocab):
    """Bases: :class:`.dataloader.Vocab`

    A very simple vocabulary class. No rare vocabs or special tokens. Used by :class:`SparseLabel`.

    Arguments:
        This class does not contain any arguments for initialization.
    """

    def __init__(self):
        super().__init__()
        self._setting_hash = hashlib.sha256(
            dumps([self.__class__.__name__, "configs"])
        ).hexdigest()
        self._token_counter: Optional[Counter] = Counter()
        self._all_vocab_list: Optional[List[str]] = None
        self.word2id: Optional[Dict[str, int]] = None
        self.mode = "init"

    def add_tokens(self, tokens: List[str], vocab_from: str) -> None:
        if self.mode == "init":
            for token, num in Counter(tokens).items():
                self._token_counter[token] += num
    add_tokens.__doc__ = Vocab.add_tokens.__doc__ + r"""
        Notes:
            Since frequency is not important in this class, the argument ``vocab_from`` has no effect.
    """

    def build_vocab(self):
        if self.mode == "finish":
            return
        # sort by frequency (descending), then alphabetically for ties
        vocabs = sorted(
            self._token_counter.items(),
            key=lambda item: (-item[1], item[0])
        )
        self._all_vocab_list = [item[0] for item in vocabs]
        self.word2id = {w: i for i, w in enumerate(self._all_vocab_list)}
        self.mode = "finish"
        self._token_counter = None

    def convert_tokens_to_ids(self, tokens: List[str], only_frequent_word=False) -> List[int]:
        if self.word2id is None:
            raise RuntimeError("You have to run build_vocab first.")
        return [self.word2id[token] for token in tokens]

    def convert_ids_to_tokens(self, ids: List[int]) -> List[str]:
        if self._all_vocab_list is None:
            raise RuntimeError("You have to run build_vocab first.")
        return [self._all_vocab_list[i] for i in ids]

    @property
    def frequent_vocab_size(self):
        return len(self._all_vocab_list)

    @property
    def all_vocab_size(self):
        return len(self._all_vocab_list)

    @property
    def frequent_vocab_list(self):
        return self._all_vocab_list

    @property
    def all_vocab_list(self):
        return self._all_vocab_list

    def get_special_tokens_mapping(self) -> OrderedDictType[str, str]:
        return OrderedDict()

    def get_special_tokens_id(self, name: str) -> int:
        raise NotImplementedError("SimpleVocab does not use any special tokens.")

    def get_vocab_hash(self) -> str:
        return hashlib.sha256(
            dumps([self._all_vocab_list])
        ).hexdigest()
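
# A usage sketch (hypothetical labels), e.g. for sparse label fields:
#
#     >>> vocab = SimpleVocab()
#     >>> vocab.add_tokens(["positive", "negative", "positive"], "train")
#     >>> vocab.build_vocab()
#     >>> vocab.all_vocab_list          # sorted by frequency, then alphabetically
#     ['positive', 'negative']
#     >>> vocab.convert_tokens_to_ids(["negative"])
#     [1]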