"""
``cotk.metric`` provides classes and functions for evaluating the results of models.
It provides fair metrics, so that results of different models are comparable.
"""
from typing import Any, List, Dict
import hashlib
from .._utils.unordered_hash import UnorderedSha256, dumps
from .._utils.metaclass import LoadClassInterface, DocStringInheritor
class MetricBase(LoadClassInterface, metaclass=DocStringInheritor):
'''Base class for metrics.
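Examples:
A minimal sketch of a custom metric follows; ``AverageLengthMetric`` is a
hypothetical example for illustration, not a class provided by cotk.
>>> class AverageLengthMetric(MetricBase):
...     def __init__(self, gen_key="gen"):
...         super().__init__("AverageLengthMetric", 1)
...         self.gen_key = gen_key
...         self.lengths = []
...     def forward(self, data):
...         super().forward(data)
...         gen = data[self.gen_key]
...         self._hash_unordered_list(gen)
...         self.lengths.extend(len(sent) for sent in gen)
...     def close(self):
...         res = super().close()
...         res["avg-length"] = sum(self.lengths) / max(len(self.lengths), 1)
...         res["avg-length hashvalue"] = self._hashvalue()
...         return res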
'''
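# The string constants below are reusable docstring fragments describing arguments
# shared by many metrics. Subclass docstrings reference them, and the DocStringInheritor
# metaclass is assumed to substitute them in, keeping argument descriptions consistent.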
DATALOADER_ARGUMENTS = \
"""dataloader (:class:`.dataloader.LanguageProcessing`, :class:`.dataloader.Sentence`, :class:`.dataloader.Session`): \
A language generation dataloader."""
MULTI_TURN_DATALOADER_ARGUMENTS = \
"""dataloader (:class:`.dataloader.LanguageProcessing`, :class:`.dataloader.Session`): \
A language generation dataloader."""
NGRAM_ARGUMENTS = \
"""ngram (int, optional): The order of ngram to calculate metrics like BLEU and Perplexity. Default: ``4``."""
TOKENIZER_ARGUMENTS = \
"""tokenizer (None, :class:`.dataloader.Tokenizer`, str, optional): Specifies the tokenizer used in \
the metric. Default: ``None``."""
IGNORE_SMOOTHING_ERROR_ARGUMENTS = \
"""ignore_smoothing_error (bool, optional): Specifies whether to ignore the smoothing error when calculating \
BLEU. Default: ``False``."""
SAMPLE_ARGUMENTS_IN_BLEU = \
"""sample (int, optional): Number of examples sampled from the generated sentences. Default: ``1000``."""
SAMPLE_ARGUMENTS_IN_NGRAM_PERPLEXITY = \
SAMPLE_ARGUMENTS_IN_BLEU.replace("Default: ``1000``.", "Default: ``10000``.")
SEED_ARGUMENTS = \
"""seed (int, optional): Random seed for sampling. Default: ``1229``."""
REFERENCE_TEST_LIST_ARGUMENTS = \
"""reference_test_list (list): Reference sentences with :ref:`all vocabs <vocabulary_ref>` in test data."""
REFERENCE_ALLVOCABS_KEY_ARGUMENTS = \
"""reference_allvocabs_key (str, optional): \
The key of reference sentences. Default: ``ref_allvocabs``."""
FORWARD_REFERENCE_ALLVOCABS_ARGUMENTS = \
"""* **data[reference_allvocabs_key]** (list, :class:`numpy.ndarray`): \
A 2-d jagged or padded array of int. Reference sentences with \
:ref:`allvocabs <vocabulary_ref>` in index form. \
Contains start token (eg: ``<go>``) and end token (eg: ``<eos>``). \
Size: ``[batch_size, ~ref_sentence_length]``, \
where "~" means different sizes in this dimension is allowed."""
FORWARD_REFERENCE_ALLVOCABS_ARGUMENTS_WITH_TORCH = \
FORWARD_REFERENCE_ALLVOCABS_ARGUMENTS.replace("list, :class:`numpy.ndarray`", \
"list, :class:`numpy.ndarray`, :class:`torch.Tensor`")
FORWARD_POST_ALLVOCABS_ARGUMENTS = \
FORWARD_REFERENCE_ALLVOCABS_ARGUMENTS.replace("reference_allvocabs_key", \
"post_allvocabs_key")
FORWARD_RESP_ALLVOCABS_ARGUMENTS = \
FORWARD_REFERENCE_ALLVOCABS_ARGUMENTS.replace("reference_allvocabs_key", \
"resp_allvocabs_key")
LABEL_KEY_ARGUMENTS = \
"""label_key (str): \
The key of reference sentence labels. Default: ``label``."""
LABEL_ARGUMENTS = """* **data[label_key]** (list or :class:`numpy.ndarray`): \
A 1-d array of int. \
Size: ``[batch_size]``, \
each element is the label of one sample."""
PREDICTION_KEY_ARGUMENTS = \
"""prediction_key (str): \
The key of model predictions. Default: ``prediction``."""
PREDICTION_ARGUMENTS = """* **data[prediction_key]** (list or :class:`numpy.ndarray`): \
A 1-d array of int. \
Size: ``[batch_size]``, \
each element is the prediction for one sample."""
MULTI_TURN_REFERENCE_ALLVOCABS_KEY_ARGUMENTS = \
"""multi_turn_reference_allvocabs_key (str, optional): \
The key of reference sentences. Default: ``multi_turn_ref_allvocabs``."""
FORWARD_MULTI_TURN_REFERENCE_ALLVOCABS_ARGUMENTS = \
"""* **data[multi_turn_reference_allvocabs_key]** (list, :class:`numpy.ndarray`): \
A 3-d jagged or padded array of int. Multi-turn reference sentences with \
:ref:`all vocabs <vocabulary_ref>`. Contains start token (eg: ``<go>``) and \
end token (eg: ``<eos>``). Size: ``[batch_size, ~turn_length, ~sentence_length]``, \
where "~" means different sizes in this dimension is allowed."""
FORWARD_MULTI_TURN_REFERENCE_ALLVOCABS_ARGUMENTS_WITH_TORCH = \
FORWARD_MULTI_TURN_REFERENCE_ALLVOCABS_ARGUMENTS.replace("list, :class:`numpy.ndarray`", \
"list, :class:`numpy.ndarray`, :class:`torch.Tensor`")
REFERENCE_LEN_KEY_ARGUMENTS = \
"""reference_len_key (str, optional): \
The key of lengths of reference sentences. \
Default: ``ref_length``."""
FORWARD_REFERENCE_LEN_ARGUMENTS = \
"""* **data[reference_len_key]** (list, :class:`numpy.ndarray`): \
Length of reference sentences. Contains start token (eg: ``<go>``) \
and end token (eg: ``<eos>``). Size: ``[batch_size]``."""
MULTI_TURN_REFERENCE_LEN_KEY_ARGUMENTS = \
"""multi_turn_reference_len_key (str, optional): \
The key of lengths of reference sentences. \
Default: ``multi_turn_ref_length``."""
FORWARD_MULTI_TURN_REFERENCE_LEN_ARGUMENTS = \
"""* **data[multi_turn_reference_len_key]** (list, :class:`numpy.ndarray`): \
A 2-d jagged or padded array of int. **If padded, redundant positions must be set to** ``0``. \
Length of multi-turn reference sentences. Contains start token (eg: ``<go>``) \
and end token (eg: ``<eos>``). Size: ``[batch_size, ~turn_length]``, \
where "~" means different sizes in this dimension are allowed."""
GEN_KEY_ARGUMENTS = \
"""gen_key (str, optional): \
The key of generated sentences. Default: ``gen``."""
GEN_LOG_PROB_KEY_ARGUMENTS = \
"""gen_log_prob_key (str, optional): The key of predicted **log** probability over words. \
Default: ``gen_log_prob``."""
GENERATE_RARE_VOCAB_ARGUMENTS = \
"""generate_rare_vocab (bool, optional): Whether ``gen_log_prob`` contains :ref:`invalid vocab <vocabulary_ref>`. \
Default: ``False``."""
FULL_CHECK_ARGUMENTS = \
"""full_check (bool, optional): Whether to perform a full check on ``gen_log_prob`` to make sure the sum
of probability is 1. Otherwise, a random check will be performed for efficiency.
If PyTorch is used, a full check is always performed and this argument will be ignored.
Default: ``False``."""
FORWARD_GEN_ARGUMENTS = \
"""* **data[gen_key]** (list, :class:`numpy.ndarray`): \
A 2-d jagged or padded array of int. \
Sentences generated by model. Contains end token (eg: ``<eos>``), \
but without start token (eg: ``<go>``). \
Size: ``[batch_size, ~gen_sentence_length]``, \
where "~" means different sizes in this dimension is allowed."""
MULTI_TURN_GEN_KEY_ARGUMENTS = \
"""multi_turn_gen_key (str, optional): \
The key of generated sentences. Default: ``multi_turn_gen``."""
FORWARD_MULTI_TURN_GEN_ARGUMENTS = \
"""* **data[gen_key]** (list, :class:`numpy.ndarray`): \
A 3-d jagged or padded array of int. Sentences generated by model. \
Contains end token (eg: ``<eos>``), but without start token (eg: ``<go>``). \
Size: ``[batch_size, ~max_turn_length, ~gen_sentence_length]``, \
where "~" means different sizes in this dimension is allowed."""
MULTI_TURN_LENGTH_KEY_ARGUMENTS = \
"""turn_length (str, optional): \
The key of length of turns. Default: ``turn_length``."""
FORWARD_MULTI_TURN_LENGTH_ARGUMENTS = \
"""* **data[turn_len_key]** (list, :class:`numpy.ndarray`): \
Number of turns in each sample. \
Size: ``[batch_size]``."""
CPU_COUNT_ARGUMENTS = \
"""cpu_count (int, optional): Number of used cpu for multiprocessing. Multiprocessing will **NOT** be used \
when ``cpu_count`` is set to ``1`` or the dataset is small. Default: If ``None``, \
the environment variable ``CPU_COUNT`` will be used when available, \
or all available cpu will be used otherwise."""
def __init__(self, name: str, version: int):
# ``unordered_hash`` records per-sample data whose order does not matter,
# while ``ordered_hash`` records order-sensitive data such as the metric name and version.
self.unordered_hash = UnorderedSha256()
self.ordered_hash = hashlib.sha256()
self.name = name
self.version = version
self._hash_ordered_data((name, version))
self.closed = False
def _hash_unordered_list(self, data_list: List[Any]):
'''Invoked by :meth:`.forward` or :meth:`.close` to hash relevant data when computing a metric.
Arguments:
data_list (list): Relevant data organized as a list.
'''
for item in data_list:
self.unordered_hash.update_data(dumps(item))
def _hash_ordered_data(self, data: Any):
'''Invoked to hash order-sensitive data, e.g. the metric name and version.
'''
self.ordered_hash.update(dumps(data))
def _hashvalue(self):
'''Invoked by :meth:`.close` to return the recorded hash value, which combines
the ordered and unordered digests so that the same data and settings produce the same hash
regardless of batch order.
'''
return hashlib.sha256(dumps((self.ordered_hash.hexdigest(), self.unordered_hash.hexdigest()))).hexdigest()
def forward(self, data: Dict[Any, Any]):
'''Process a batch of data.
Arguments:
data (dict): A dict containing the data that the metric needs.
'''
if self.closed:
raise ValueError("The metric has been closed.")
if not isinstance(data, dict):
raise TypeError("Data must be a dict.")
def close(self) -> Dict[Any, Any]:
'''
Close the metric and return a dict containing the results. Once the metric is closed,
calling :meth:`.forward` raises a ``ValueError`` and calling :meth:`.close` again raises a ``RuntimeError``.
'''
if not self.closed:
self.closed = True
return {}
else:
raise RuntimeError("The metric has been closed.")
class MetricChain(MetricBase):
'''A metric-like class for stacked metrics. You can use this class
to combine multiple metrics and operate on them as one.
Examples:
>>> metric = MetricChain()
>>> metric.add_metric(BleuCorpusMetric(dataloader))
>>> metric.add_metric(SingleTurnDialogRecorder(dataloader))
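The stacked metric is then fed batch by batch and closed once. A minimal sketch,
where ``batches`` and ``model`` are hypothetical placeholders, each ``data`` batch is
assumed to already contain the reference keys (e.g. ``ref_allvocabs``), and the
generated sentences are stored under the default key ``gen``:
>>> for data in batches:
...     data["gen"] = model.generate(data)
...     metric.forward(data)
>>> result = metric.close()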
Todo: Give more examples of combining forward and close.
'''
_name = 'MetricChain'
_version = 2
def __init__(self):
super().__init__(self._name, self._version)
self.metric_list = []
def add_metric(self, metric: "MetricBase"):
'''Add a metric to the chain.
Arguments:
metric (:class:`.metric.MetricBase`): A metric instance.
'''
if not isinstance(metric, MetricBase):
raise TypeError("Metric must be a subclass of MetricBase")
self.metric_list.append(metric)
def forward(self, data: Dict[Any, Any]):
'''Process a batch of data.
Arguments:
data (dict): A dict containing at least the keys that all the
metric components need.
'''
super().forward(data)
for metric in self.metric_list:
metric.forward(data)
def close(self) -> Dict[Any, Any]:
r'''Return a dict containing the items that all the metric components return.
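For example, with the chain built in the class docstring above, the result is assumed
to contain each component's values and hash values (exact keys depend on the metrics added):
>>> metric.close()  # doctest: +SKIP
{'bleu': ..., 'bleu hashvalue': ..., ...}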
'''
res = super().close()
for metric in self.metric_list:
res.update(metric.close())
return res