Source code for cotk.metric.metric

"""
``cotk.metric`` provides classes and functions for evaluating the results of models.
It provides a fair metric for every model.
"""
from typing import Any, List, Dict
import hashlib

from .._utils.unordered_hash import UnorderedSha256, dumps
from .._utils.metaclass import LoadClassInterface, DocStringInheritor

class MetricBase(LoadClassInterface, metaclass=DocStringInheritor):
    '''Base class for metrics.
    '''

    DATALOADER_ARGUMENTS = \
        """dataloader (:class:`.dataloader.LanguageProcessing`, :class:`.dataloader.Sentence`, :class:`.dataloader.Session`): \
            A language generation dataloader."""
    MULTI_TURN_DATALOADER_ARGUMENTS = \
        """dataloader (:class:`.dataloader.LanguageProcessing`, :class:`.dataloader.Session`): \
            A language generation dataloader."""
    NGRAM_ARGUMENTS = \
        """ngram (int, optional): The order of ngram used to calculate metrics like BLEU and perplexity. Default: ``4``."""
    TOKENIZER_ARGUMENTS = \
        """tokenizer (None, :class:`.dataloader.Tokenizer`, str, optional): Specifies the tokenizer used in \
            the metric. Default: ``None``."""
    IGNORE_SMOOTHING_ERROR_ARGUMENTS = \
        """ignore_smoothing_error (bool, optional): Specifies whether to ignore the smoothing error when calculating \
            BLEU. Default: ``False``."""
    SAMPLE_ARGUMENTS_IN_BLEU = \
        """sample (int, optional): Number of examples sampled from the generated sentences. Default: ``1000``."""
    SAMPLE_ARGUMENTS_IN_NGRAM_PERPLEXITY = \
        SAMPLE_ARGUMENTS_IN_BLEU.replace("Default: ``1000``.", "Default: ``10000``.")
    SEED_ARGUMENTS = \
        """seed (int, optional): Random seed for sampling. Default: ``1229``."""
    REFERENCE_TEST_LIST_ARGUMENTS = \
        """reference_test_list (list): Reference sentences with :ref:`all vocabs <vocabulary_ref>` in test data."""
    REFERENCE_ALLVOCABS_KEY_ARGUMENTS = \
        """reference_allvocabs_key (str, optional): \
            The key of reference sentences. Default: ``ref_allvocabs``."""
    FORWARD_REFERENCE_ALLVOCABS_ARGUMENTS = \
        """* **data[reference_allvocabs_key]** (list, :class:`numpy.ndarray`): \
            A 2-d jagged or padded array of int. Reference sentences with \
            :ref:`allvocabs <vocabulary_ref>` in index form. \
            Contains the start token (e.g. ``<go>``) and the end token (e.g. ``<eos>``). \
            Size: ``[batch_size, ~ref_sentence_length]``, \
            where "~" means different sizes in this dimension are allowed."""
    FORWARD_REFERENCE_ALLVOCABS_ARGUMENTS_WITH_TORCH = \
        FORWARD_REFERENCE_ALLVOCABS_ARGUMENTS.replace("list, :class:`numpy.ndarray`", \
            "list, :class:`numpy.ndarray`, :class:`torch.Tensor`")
    FORWARD_POST_ALLVOCABS_ARGUMENTS = \
        FORWARD_REFERENCE_ALLVOCABS_ARGUMENTS.replace("reference_allvocabs_key", \
            "post_allvocabs_key")
    FORWARD_RESP_ALLVOCABS_ARGUMENTS = \
        FORWARD_REFERENCE_ALLVOCABS_ARGUMENTS.replace("reference_allvocabs_key", \
            "resp_allvocabs_key")
    LABEL_KEY_ARGUMENTS = \
        """label_key (str): \
            The key of reference sentence labels. Default: ``label``."""
    LABEL_ARGUMENTS = \
        """* **data[label_key]** (list or :class:`numpy.ndarray`): \
            A 1-d array of int. \
            Size: ``[batch_size]``, \
            each element refers to the label of one sample."""
    PREDICTION_KEY_ARGUMENTS = \
        """prediction_key (str): \
            The key of reference sentence predictions. Default: ``prediction``."""
    PREDICTION_ARGUMENTS = \
        """* **data[prediction_key]** (list or :class:`numpy.ndarray`): \
            A 1-d array of int. \
            Size: ``[batch_size]``, \
            each element refers to the prediction for one sample."""
    MULTI_TURN_REFERENCE_ALLVOCABS_KEY_ARGUMENTS = \
        """multi_turn_reference_allvocabs_key (str, optional): \
            The key of reference sentences. Default: ``multi_turn_ref_allvocabs``."""
    FORWARD_MULTI_TURN_REFERENCE_ALLVOCABS_ARGUMENTS = \
        """* **data[multi_turn_reference_allvocabs_key]** (list, :class:`numpy.ndarray`): \
            A 3-d jagged or padded array of int. Multi-turn reference sentences with \
            :ref:`all vocabs <vocabulary_ref>`. Contains the start token (e.g. ``<go>``) and \
            the end token (e.g. ``<eos>``). \
            Size: ``[batch_size, ~turn_length, ~sentence_length]``, \
            where "~" means different sizes in this dimension are allowed."""
    FORWARD_MULTI_TURN_REFERENCE_ALLVOCABS_ARGUMENTS_WITH_TORCH = \
        FORWARD_MULTI_TURN_REFERENCE_ALLVOCABS_ARGUMENTS.replace("list, :class:`numpy.ndarray`", \
            "list, :class:`numpy.ndarray`, :class:`torch.Tensor`")
    REFERENCE_LEN_KEY_ARGUMENTS = \
        """reference_len_key (str, optional): \
            The key of the lengths of reference sentences. \
            Default: ``ref_length``."""
    FORWARD_REFERENCE_LEN_ARGUMENTS = \
        """* **data[reference_len_key]** (list, :class:`numpy.ndarray`): \
            Lengths of reference sentences, counting the start token (e.g. ``<go>``) \
            and the end token (e.g. ``<eos>``). Size: ``[batch_size]``."""
    MULTI_TURN_REFERENCE_LEN_KEY_ARGUMENTS = \
        """multi_turn_reference_len_key (str, optional): \
            The key of the lengths of reference sentences. \
            Default: ``multi_turn_ref_length``."""
    FORWARD_MULTI_TURN_REFERENCE_LEN_ARGUMENTS = \
        """* **data[multi_turn_reference_len_key]** (list, :class:`numpy.ndarray`): \
            A 2-d jagged or padded array of int. **If padded, redundant positions must be set to** ``0``. \
            Lengths of multi-turn reference sentences, counting the start token (e.g. ``<go>``) \
            and the end token (e.g. ``<eos>``). Size: ``[batch_size, ~turn_length]``, \
            where "~" means different sizes in this dimension are allowed."""
    GEN_KEY_ARGUMENTS = \
        """gen_key (str, optional): \
            The key of generated sentences. Default: ``gen``."""
    GEN_LOG_PROB_KEY_ARGUMENTS = \
        """gen_log_prob_key (str, optional): The key of predicted **log** probability over words. \
            Default: ``gen_log_prob``."""
    GENERATE_RARE_VOCAB_ARGUMENTS = \
        """generate_rare_vocab (bool, optional): Whether ``gen_log_prob`` contains :ref:`invalid vocab <vocabulary_ref>`. \
            Default: ``False``."""
    FULL_CHECK_ARGUMENTS = \
        """full_check (bool, optional): Whether to perform a full check on ``gen_log_prob`` \
            to make sure the probabilities sum to 1. Otherwise, a random check \
            will be performed for efficiency. \
            If PyTorch is used, a full check is always performed and this argument is ignored. \
            Default: ``False``."""
    FORWARD_GEN_ARGUMENTS = \
        """* **data[gen_key]** (list, :class:`numpy.ndarray`): \
            A 2-d jagged or padded array of int. \
            Sentences generated by the model. Contains the end token (e.g. ``<eos>``), \
            but not the start token (e.g. ``<go>``). \
            Size: ``[batch_size, ~gen_sentence_length]``, \
            where "~" means different sizes in this dimension are allowed."""
    MULTI_TURN_GEN_KEY_ARGUMENTS = \
        """multi_turn_gen_key (str, optional): \
            The key of generated sentences. Default: ``multi_turn_gen``."""
    FORWARD_MULTI_TURN_GEN_ARGUMENTS = \
        """* **data[multi_turn_gen_key]** (list, :class:`numpy.ndarray`): \
            A 3-d jagged or padded array of int. Sentences generated by the model. \
            Contains the end token (e.g. ``<eos>``), but not the start token (e.g. ``<go>``). \
            Size: ``[batch_size, ~max_turn_length, ~gen_sentence_length]``, \
            where "~" means different sizes in this dimension are allowed."""
    MULTI_TURN_LENGTH_KEY_ARGUMENTS = \
        """turn_len_key (str, optional): \
            The key of the lengths of turns. Default: ``turn_length``."""
    FORWARD_MULTI_TURN_LENGTH_ARGUMENTS = \
        """* **data[turn_len_key]** (list, :class:`numpy.ndarray`): \
            Length of turns in each sample. \
            Size: ``[batch_size]``."""
    CPU_COUNT_ARGUMENTS = \
        """cpu_count (int, optional): The number of CPUs used for multiprocessing. Multiprocessing will **NOT** be used \
            when ``cpu_count`` is set to ``1`` or the dataset is small. \
            Default: if ``None``, the environment variable ``CPU_COUNT`` will be used when available, \
            otherwise all available CPUs will be used."""

    def __init__(self, name: str, version: int):
        self.unordered_hash = UnorderedSha256()
        self.ordered_hash = hashlib.sha256()
        self.name = name
        self.version = version
        self._hash_ordered_data((name, version))
        self.closed = False
    def _hash_unordered_list(self, data_list: List[Any]):
        '''Invoked by :meth:`.forward` or :meth:`.close` to hash relevant data
        whose order does not matter when computing a metric.

        Arguments:
            data_list (list): relevant data organized as a list.
        '''
        for item in data_list:
            self.unordered_hash.update_data(dumps(item))

    def _hash_ordered_data(self, data: Any):
        '''Invoked by :meth:`__init__`, :meth:`.forward` or :meth:`.close` to hash
        relevant data whose order matters (e.g. the metric name and version).

        Arguments:
            data (Any): relevant data to be hashed.
        '''
        self.ordered_hash.update(dumps(data))

    def _hashvalue(self):
        '''Invoked by :meth:`.close` to return the recorded hash value.
        '''
        return hashlib.sha256(dumps((self.ordered_hash.hexdigest(),
                                     self.unordered_hash.hexdigest()))).hexdigest()
    def forward(self, data: Dict[Any, Any]):
        '''Process a batch of data.

        Arguments:
            data (dict): A dict containing the data that the metric needs.
        '''
        if self.closed:
            raise ValueError("The metric has been closed.")
        if not isinstance(data, dict):
            raise TypeError("Data must be a dict.")

    def close(self) -> Dict[Any, Any]:
        '''Close the metric and return a dict containing the results.

        Once the metric is closed, any further :meth:`.forward` call raises a
        ValueError, and any further :meth:`.close` call raises a RuntimeError.
        '''
        if not self.closed:
            self.closed = True
            return {}
        else:
            raise RuntimeError("The metric has been closed.")
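
# Illustrative sketch (not part of cotk): a minimal custom metric built on the
# MetricBase protocol above. The class, its key names, and its result keys are
# hypothetical; super().forward() / super().close() perform the open/closed
# checks, and the "{MetricBase.*}" placeholder in the docstring assumes that
# DocStringInheritor substitutes the shared fragments defined above.
class SentenceCountMetric(MetricBase):
    '''A hypothetical metric that counts generated sentences.

    Arguments:
        {MetricBase.GEN_KEY_ARGUMENTS}
    '''
    _name = 'SentenceCountMetric'
    _version = 1

    def __init__(self, gen_key: str = "gen"):
        super().__init__(self._name, self._version)
        self.gen_key = gen_key
        self.count = 0

    def forward(self, data: Dict[Any, Any]):
        '''Process a batch of data and count the generated sentences.'''
        super().forward(data)  # raises if the metric is closed or data is not a dict
        gen = data[self.gen_key]
        # Record each sample in the order-insensitive hash, so the final
        # hash value does not depend on batch order.
        self._hash_unordered_list(gen)
        self.count += len(gen)

    def close(self) -> Dict[Any, Any]:
        '''Return the sentence count and the hash value of the recorded data.'''
        res = super().close()  # marks the metric as closed
        res.update({"sentence_count": self.count,
                    "sentence_count hashvalue": self._hashvalue()})
        return res

# Feeding the same samples in a different order yields the same hash value,
# because per-sample data only goes through the unordered hash:
#
#     >>> m1, m2 = SentenceCountMetric(), SentenceCountMetric()
#     >>> m1.forward({"gen": [[2, 9, 3], [2, 8, 3]]})
#     >>> m2.forward({"gen": [[2, 8, 3], [2, 9, 3]]})
#     >>> m1.close()["sentence_count hashvalue"] == \
#     ...     m2.close()["sentence_count hashvalue"]
#     True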
class MetricChain(MetricBase):
    '''A metric-like class for stacked metrics. You can use this class to
    combine multiple metrics and use them as one.

    Examples:
        >>> metric = MetricChain()
        >>> metric.add_metric(BleuCorpusMetric(dataloader))
        >>> metric.add_metric(SingleTurnDialogRecorder(dataloader))

    Todo:
        Give more examples of combining forward and close
    '''
    _name = 'MetricChain'
    _version = 2

    def __init__(self):
        super().__init__(self._name, self._version)
        self.metric_list = []

    def add_metric(self, metric: "MetricBase"):
        '''Add a metric for processing.

        Arguments:
            metric (:class:`.metric.MetricBase`): a metric instance.
        '''
        if not isinstance(metric, MetricBase):
            raise TypeError("Metric must be a subclass of MetricBase")
        self.metric_list.append(metric)

    def forward(self, data: Dict[Any, Any]):
        '''Process a batch of data.

        Arguments:
            data (dict): A dict containing at least the keys that all the
                metric components need.
        '''
        super().forward(data)
        for metric in self.metric_list:
            metric.forward(data)

    def close(self) -> Dict[Any, Any]:
        r'''Return a dict containing the items that all the metric components return.
        '''
        res = super().close()
        for metric in self.metric_list:
            res.update(metric.close())
        return res
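
# Illustrative usage of MetricChain combining forward and close (a sketch;
# ``dataloader`` is assumed to be a cotk dataloader and ``batches`` an
# iterator over dicts of model outputs):
#
#     >>> metric = MetricChain()
#     >>> metric.add_metric(BleuCorpusMetric(dataloader))
#     >>> metric.add_metric(SingleTurnDialogRecorder(dataloader))
#     >>> for batch in batches:
#     ...     metric.forward(batch)  # each component processes the same batch
#     >>> result = metric.close()    # merged dict of all components' results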