Source code for cotk.metric.recorder

"""
Containing some recorders.
"""
from typing import List, Dict, Any, Union
import numpy as np
from .metric import MetricBase

[docs]class SingleTurnDialogRecorder(MetricBase): '''A metric-like class for recording generated sentences and references. Arguments: {MetricBase.DATALOADER_ARGUMENTS} post_allvocabs_key (str, optional): The key of dialog posts with :ref:`allvocabs <vocabulary_ref>`. Default: ``post_allvocabs``. resp_allvocabs_key (str, optional): The key of dialog responses with :ref:`allvocabs <vocabulary_ref>`. Default: ``resp_allvocabs``. {MetricBase.GEN_KEY_ARGUMENTS} Here is an example: >>> post_allvocabs_key = "post_allvocabs" >>> resp_allvocabs_key = "resp_allvocabs" >>> gen_key = "gen" >>> dl = cotk.dataloader.UbuntuCorpus('resources://Ubuntu_small') >>> metric = cotk.metric.SingleTurnDialogRecorder(dl, ... post_allvocabs_key=post_allvocabs_key, ... resp_allvocabs_key=resp_allvocabs_key, ... gen_key=gen_key) >>> data = { ... post_allvocabs_key: [[2, 10, 64, 851, 3], [2, 10, 48, 851, 3]], ... # post_allvocabs_key: [["<go>", "I", "like", "python", "<eos>"], ["<go>", "I", "use", "python", "<eos>"]], ... ... resp_allvocabs_key: [[2, 10, 1214, 479, 3], [2, 851, 17, 2451, 3]], ... # resp_allvocabs_key: [["<go>", "I", "prefe", "java", "<eos>"], ["<go>", "python", "is", "excellent", "<eos>"]], ... ... gen_key: [[10, 64, 2019, 3], [851, 17, 4124, 3]], ... # gen_key: [["I", "like", "PHP", "<eos>"], ["python", "is", "powerful", "<eos>"]] ... } >>> metric.forward(data) >>> metric.close() {'post': [['I', 'like', 'python'], ['I', 'use', 'python']], 'resp': [['I', 'prefer', 'java'], ['python', 'is', 'excellent']], 'gen': [['I', 'like', 'PHP'], ['python', 'is', 'powerful']]} ''' _name = 'SingleTurnDialogRecorder' _version = 2 def __init__(self, dataloader: Union["LanguageProcessing", "Sentence", "Session"], \ post_allvocabs_key: str = "post_allvocabs", \ resp_allvocabs_key: str = "resp_allvocabs", gen_key: str = "gen"): super().__init__(self._name, self._version) self.dataloader = dataloader self.post_allvocabs_key = post_allvocabs_key self.resp_allvocabs_key = resp_allvocabs_key self.gen_key = gen_key self.post_list = [] self.resp_list = [] self.gen_list = []
[docs] def forward(self, data: Dict[str, Any]): '''Processing a batch of data. Arguments: data (dict): A dict at least contains the following keys: {MetricBase.FORWARD_POST_ALLVOCABS_ARGUMENTS} {MetricBase.FORWARD_RESP_ALLVOCABS_ARGUMENTS} {MetricBase.FORWARD_GEN_ARGUMENTS} Here is an example for data: >>> # all_vocab_list = ["<pad>", "<unk>", "<go>", "<eos>", "I", "have", >>> # "been", "to", "China"] >>> data = { ... post_allvocabs_key: [[2,4,3], [2,5,6,3]], ... resp_allvocabs_key: [[2,5,4,3], [2,6,3]], ... gen_key: [[6,7,8,3], [4,5,3]] ... } ''' super().forward(data) post_allvocabs = data[self.post_allvocabs_key] resp_allvocabs = data[self.resp_allvocabs_key] gen = data[self.gen_key] if not isinstance(post_allvocabs, (np.ndarray, list)): raise TypeError("Unknown type for post_allvocabs.") if not isinstance(resp_allvocabs, (np.ndarray, list)): raise TypeError("Unknown type for resp_allvocabs") if not isinstance(gen, (np.ndarray, list)): raise TypeError("Unknown type for gen") if len(post_allvocabs) != len(resp_allvocabs) or len(resp_allvocabs) != len(gen): raise ValueError("Batch num is not matched.") for i, post_sen in enumerate(post_allvocabs): self.post_list.append(self.dataloader.convert_ids_to_tokens(post_sen[1:])) self.resp_list.append(self.dataloader.convert_ids_to_tokens(resp_allvocabs[i][1:])) self.gen_list.append(self.dataloader.convert_ids_to_tokens(gen[i]))
[docs] def close(self) -> Dict[str, Any]: '''Return a dict which contains * **post**: a list of post sentences. A jagged 2-d array of int. Size:``[batch_size, ~sent_length]``, where "~" means different sizes in this dimension is allowed. * **resp**: a list of response sentences. A jagged 2-d array of int. Size:``[batch_size, ~sent_length]``, where "~" means different sizes in this dimension is allowed. * **gen**: A list of generated sentences. A jagged 2-d array of int. Size:``[batch_size, ~sent_length]``, where "~" means different sizes in this dimension is allowed. ''' res = super().close() res.update({"post": self.post_list, "resp": self.resp_list, "gen": self.gen_list}) return res
[docs]class MultiTurnDialogRecorder(MetricBase): '''A metric-like class for recording generated sentences and references. Arguments: {MetricBase.MULTI_TURN_DATALOADER_ARGUMENTS} multi_turn_reference_allvocabs_key (str, optional): The key of dialog references with \ :ref:`allvocabs <vocabulary_ref>`. Default: ``multi_turn_ref_allvocabs``. {MetricBase.MULTI_TURN_GEN_KEY_ARGUMENTS} {MetricBase.MULTI_TURN_LENGTH_KEY_ARGUMENTS} Here is an example: >>> multi_turn_reference_allvocabs_key = "multi_turn_ref_allvocabs" >>> multi_turn_gen_key = "multi_turn_gen" >>> turn_len_key = "turn_length" >>> dl = cotk.dataloader.UbuntuCorpus('resources://Ubuntu_small') >>> metric = cotk.metric.MultiTurnDialogRecorder(dl, ... multi_turn_reference_allvocabs_key=multi_turn_reference_allvocabs_key, ... multi_turn_gen_key=multi_turn_gen_key, ... turn_len_key=turn_len_key) >>> data = { ... multi_turn_reference_allvocabs_key: [[[2, 10, 64, 851, 3], [2, 10, 64, 479, 3]], [[2, 10, 64, 279, 1460, 3]]], ... # multi_turn_reference_allvocabs_key = [[["<go>", "I", "like", "python", "<eos>"], ["<go>", "I", "like", "java", "<eos>"]], ... # [["<go>", "I", "like", "machine", "learning", "<eos>"]]] ... ... turn_len_key: [2, 1], ... # turn_len_key: [len(multi_turn_reference_allvocabs_key[0]), len(multi_turn_reference_allvocabs_key[1])] ... ... multi_turn_gen_key: [[[851, 17, 2451, 3], [2019, 17, 393, 3]], [[10, 64, 34058, 805, 2601, 3]]] ... # multi_turn_gen_key = [[["python", "is", "excellent", "<eos>"], ["PHP", "is", "best", "<eos>"]], ... # [["I", "like", "natural", "language", "processing", "<eos>"]]] ... } >>> metric.forward(data) >>> metric.close() {'reference': [[['I', 'like', 'python'], ['I', 'like', 'java']], [['I', 'like', 'machine', 'learning']]], 'gen': [[['python', 'is', 'excellent'], ['PHP', 'is', 'best']], [['I', 'like', 'natural', 'language', 'processing']]]} ''' _name = 'MultiTurnDialogRecorder' _version = 2 def __init__(self, dataloader: Union["LanguageProcessing", "Session"], multi_turn_reference_allvocabs_key: str = "multi_turn_ref_allvocabs", \ multi_turn_gen_key: str = "multi_turn_gen", \ turn_len_key: str = "turn_length"): super().__init__(self._name, self._version) self.dataloader = dataloader self.multi_turn_reference_allvocabs_key = multi_turn_reference_allvocabs_key self.multi_turn_gen_key = multi_turn_gen_key self.turn_len_key = turn_len_key self.context_list = [] self.reference_list = [] self.gen_list = []
[docs] def forward(self, data: Dict[str, Any]): '''Processing a batch of data. Arguments: data (dict): A dict at least contains the following keys: {MetricBase.FORWARD_MULTI_TURN_REFERENCE_ALLVOCABS_ARGUMENTS} {MetricBase.FORWARD_MULTI_TURN_GEN_ARGUMENTS} {MetricBase.FORWARD_MULTI_TURN_LENGTH_ARGUMENTS} Here is an example for data: >>> # all_vocab_list = ["<pad>", "<unk>", "<go>", "<eos>", "I", "have", >>> # "been", "to", "China"] >>> data = { ... multi_turn_context_allvocabs_key: [[[2,4,3], [2,5,6,3]], [[2,7,6,8,3]]], ... multi_turn_reference_allvocabs_key: [[[2,6,7,3], [2,5,3]], [[2,7,6,8,3]]], ... multi_turn_gen_key: [[[6,7,8,3], [4,5,3]], [[7,3]]], ... turn_len_key: [2,1] ... } ''' super().forward(data) reference_allvocabs = data[self.multi_turn_reference_allvocabs_key] gen = data[self.multi_turn_gen_key] turn_length = data[self.turn_len_key] if not isinstance(reference_allvocabs, (np.ndarray, list)): raise TypeError("Unknown type for reference_allvocabs") if not isinstance(gen, (np.ndarray, list)): raise TypeError("Unknown type for gen") if not isinstance(turn_length, (np.ndarray, list)): raise TypeError("Unknown type for turn_length") if len(turn_length) != len(reference_allvocabs) or \ len(turn_length) != len(gen): raise ValueError("Batch num is not matched.") for i, _ in enumerate(reference_allvocabs): self.reference_list.append(self.dataloader.convert_multi_turn_ids_to_tokens( \ reference_allvocabs[i], remove_special=True)) self.gen_list.append(self.dataloader.convert_multi_turn_ids_to_tokens( \ gen[i], remove_special=False)) if len(self.reference_list[-1]) != len(self.gen_list[-1]): raise ValueError("Reference turn num %d != gen turn num %d." % \ (len(self.reference_list[-1]), len(self.gen_list[-1])))
[docs] def close(self) -> Dict[str, Any]: '''Return a dict which contains * **reference**: a list of response sentences. A jagged 3-d array of int. Size:``[batch_size, ~turn_length, ~sent_length]``, where "~" means different sizes in this dimension is allowed. * **gen**: a list of generated sentences. A jagged 3-d array of int. Size:``[batch_size, ~turn_length, ~sent_length]``, where "~" means different sizes in this dimension is allowed. ''' res = super().close() res.update({"reference": self.reference_list, "gen": self.gen_list}) return res
[docs]class LanguageGenerationRecorder(MetricBase): '''A metric-like class for recorder generated sentences. Arguments: {MetricBase.DATALOADER_ARGUMENTS} {MetricBase.GEN_KEY_ARGUMENTS} Here is an example: >>> gen_key = "gen_key" >>> dl = cotk.dataloader.UbuntuCorpus('resources://Ubuntu_small') >>> metric = cotk.metric.LanguageGenerationRecorder(dl, gen_key=gen_key) >>> data = { ... gen_key: [[2, 10, 64, 851, 3], [2, 10, 48, 851, 3]], ... # gen_key: [["<go>", "I", "like", "python", "<eos>"], ["<go>", "I", "use", "python", "<eos>"]], ... } >>> metric.forward(data) >>> metric.close() {'gen': [['<go>', 'I', 'like', 'python'], ['<go>', 'I', 'use', 'python']]} ''' _name = 'LanguageGenerationRecorder' _version = 2 def __init__(self, dataloader: Union["LanguageProcessing", "Sentence", "Session"], gen_key: str = "gen"): super().__init__(self._name, self._version) self.dataloader = dataloader self.gen_key = gen_key self.gen_list = []
[docs] def forward(self, data: Dict[str, Any]): '''Processing a batch of data. Arguments: data (dict): A dict at least contains the following keys: {MetricBase.FORWARD_GEN_ARGUMENTS} Here is an example for data: >>> # all_vocab_list = ["<pad>", "<unk>", "<go>", "<eos>", "I", "have", >>> # "been", "to", "China"] >>> data = { ... gen_key: [[6,7,8,3], [4,5,3]] ... } ''' super().forward(data) gen = data[self.gen_key] if not isinstance(gen, (np.ndarray, list)): raise TypeError("Unknown type for gen") for sen in gen: self.gen_list.append(self.dataloader.convert_ids_to_tokens(sen))
[docs] def close(self) -> Dict[str, Any]: '''Return a dict which contains * **gen**: a list of generated sentences. A jagged 2-d array of int. Size:``[batch_size, ~sent_length]``, where "~" means different sizes in this dimension is allowed. ''' res = super().close() res.update({"gen": self.gen_list}) return res