
"""
A module for multi turn dialog.
"""
import warnings
from collections import OrderedDict

from .._utils.metaclass import copy_func
from .dataloader import LanguageProcessing
from .field import Session, Field
from .tokenizer import PretrainedTokenizer
from .vocab import PretrainedVocab
from .context import FieldContext, VocabContext
from ..wordvector import Glove

if False:  # for type check # pylint: disable=using-constant-test
	from ..metric import MetricChain  # pylint: disable=unused-import


# pylint: disable=W0223
class MultiTurnDialog(LanguageProcessing):
	r"""Base class for multi-turn dialog datasets. This is an abstract class.

	Arguments:

	Attributes:{ATTRIBUTES}

	Notes:
		A :class:`Session` field must be set as the default field. When invoking
		:meth:`__init__` of :class:`MultiTurnDialog`, the default field, which may be
		reset in a subclass, is set to ``self.fields['train']['session']``.
	"""

	_version = 2

	# TODO: fill ATTRIBUTES
	ATTRIBUTES = ''
	# ATTRIBUTES = LanguageProcessing.ATTRIBUTES
	# ARGUMENTS = LanguageProcessing.ARGUMENTS

	GET_BATCH_RETURNS_DICT = r'''
		* turn_length (:class:`numpy.ndarray`): A 1-d array, the number of turns in each session.
		  Size: ``[batch_size]``
		* sent_length (:class:`numpy.ndarray`): A 2-d non-padded array, the length of each sentence in turns.
		  The size of the second dimension varies between sessions.
		  Length of outer list: ``[batch_size]``
		* sent (:class:`numpy.ndarray`): A 3-d padded array containing words in index form.
		  Only provides valid words; ``unk_id`` is used for words that are not valid.
		  Size: ``[batch_size, max(turn_length[i]), max(sent_length)]``
		* sent_allvocabs (:class:`numpy.ndarray`): A 3-d padded array containing words in index form.
		  Provides both valid and invalid words.
		  Size: ``[batch_size, max(turn_length[i]), max(sent_length)]``
	'''

	GET_BATCH_EXAMPLES_PART = r'''
		>>> # all_vocab_list = ["<pad>", "<unk>", "<go>", "<eos>", "how", "are", "you",
		>>> #   "hello", "i", "am", "fine"]
		>>> # vocab_size = 9
		>>> # vocab_list = ["<pad>", "<unk>", "<go>", "<eos>", "how", "are", "you", "hello", "i"]
		>>> dataloader.get_batch('train', [0, 1])
		{
			"sent_allvocabs": numpy.array([
				[[2, 7, 3, 0, 0, 0],    # 1st sentence in 1st session: <go> hello <eos> <pad> <pad> <pad>
				 [2, 7, 3, 0, 0, 0],    # 2nd sentence in 1st session: <go> hello <eos> <pad> <pad> <pad>
				 [2, 4, 5, 6, 3, 0],    # 3rd sentence in 1st session: <go> how are you <eos> <pad>
				 [2, 8, 9, 10, 3, 0]],  # 4th sentence in 1st session: <go> i am fine <eos> <pad>
				[[2, 7, 4, 5, 6, 3],    # 1st sentence in 2nd session: <go> hello how are you <eos>
				 [2, 8, 9, 10, 3, 0],   # 2nd sentence in 2nd session: <go> i am fine <eos> <pad>
				 [0, 0, 0, 0, 0, 0],
				 [0, 0, 0, 0, 0, 0]]
			]),
			"sent": numpy.array([
				[[2, 7, 3, 0, 0, 0],    # 1st sentence in 1st session: <go> hello <eos> <pad> <pad> <pad>
				 [2, 7, 3, 0, 0, 0],    # 2nd sentence in 1st session: <go> hello <eos> <pad> <pad> <pad>
				 [2, 4, 5, 6, 3, 0],    # 3rd sentence in 1st session: <go> how are you <eos> <pad>
				 [2, 8, 1, 1, 3, 0]],   # 4th sentence in 1st session: <go> i <unk> <unk> <eos> <pad>
				[[2, 7, 4, 5, 6, 3],    # 1st sentence in 2nd session: <go> hello how are you <eos>
				 [2, 8, 1, 1, 3, 0],    # 2nd sentence in 2nd session: <go> i <unk> <unk> <eos> <pad>
				 [0, 0, 0, 0, 0, 0],
				 [0, 0, 0, 0, 0, 0]]
			]),
			"turn_length": numpy.array([4, 2]),  # the number of turns in each session
			"sent_length": numpy.array([numpy.array([3, 3, 5, 5]), numpy.array([6, 5])]),  # length of sentences'''

	def __init__(self, file_id: str, tokenizer=None, \
			max_sent_length=None, \
			max_turn_length=None, \
			convert_to_lower_letter=None, \
			min_frequent_vocab_times=None, \
			min_rare_vocab_times=None, \
			fields=None, \
			pretrained=None):
		self._pretrained = pretrained

		if pretrained is None:
			if fields is None:
				fields = OrderedDict([('session', 'SessionDefault')])
			with FieldContext.set_parameters(tokenizer=tokenizer, \
					max_sent_length=max_sent_length, \
					convert_to_lower_letter=convert_to_lower_letter, \
					max_turn_length=max_turn_length):
				with VocabContext.set_parameters(min_rare_vocab_times=min_rare_vocab_times, \
						min_frequent_vocab_times=min_frequent_vocab_times):
					super().__init__(file_id, fields)
		elif pretrained in ('gpt2', 'bert'):
			if fields is None:
				fields = OrderedDict([('session', Session.get_pretrained_class(pretrained).__name__)])
			if not isinstance(tokenizer, PretrainedTokenizer):
				raise ValueError("The tokenizer should be loaded first if you want a %s dataloader" % pretrained)
			vocab = PretrainedVocab(tokenizer.tokenizer)
			with FieldContext.set_parameters(tokenizer=tokenizer, vocab=vocab, \
					max_sent_length=max_sent_length, \
					max_turn_length=max_turn_length, \
					convert_to_lower_letter=convert_to_lower_letter):
				super().__init__(file_id, fields)
		else:
			raise ValueError("No pretrained name %s" % pretrained)

		self.set_default_field('train', 'session')

		if pretrained in ('gpt2', 'bert'):
			# Check whether SessionGPT2 or SessionBERT is used.
			for set_name, set_fields in self.fields.items():
				for field_name, field in set_fields.items():
					if isinstance(field, Session) and not isinstance(field, Session.get_pretrained_class(pretrained)):
						warnings.warn("If you want to use a %s multi_turn_dialog, you'd better use %s instead of %s." % \
							(pretrained, Session.get_pretrained_class(pretrained).__name__, type(field).__name__))

	_SESSION_MORE_DOCSTRING = '''It calls the identical method of the :class:`Session` instance ``session``, \
returned by :meth:`.get_default_field()`.'''
	multi_turn_trim_in_ids = copy_func(LanguageProcessing.get_default_field, Session, 'multi_turn_trim_in_ids')
	convert_multi_turn_tokens_to_ids = copy_func(LanguageProcessing.get_default_field, Session, 'convert_multi_turn_tokens_to_ids')
	convert_multi_turn_ids_to_tokens = copy_func(LanguageProcessing.get_default_field, Session, 'convert_multi_turn_ids_to_tokens')
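	# A minimal usage sketch of the ``pretrained`` branch above (hedged: the
	# ``transformers`` package, its ``GPT2Tokenizer``, and the import location of
	# ``PretrainedTokenizer`` are assumptions, not guaranteed by this module; the
	# preset resource is downloaded on first use):
	#
	# >>> from transformers import GPT2Tokenizer
	# >>> from cotk.dataloader import PretrainedTokenizer
	# >>> tokenizer = PretrainedTokenizer(GPT2Tokenizer.from_pretrained('gpt2'))
	# >>> loader = UbuntuCorpus("resources://Ubuntu", tokenizer=tokenizer, pretrained='gpt2')
	# >>> # The default field becomes SessionGPT2 and the vocabulary is built from
	# >>> # the pretrained tokenizer via PretrainedVocab.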
	def get_teacher_forcing_metric(self, multi_turn_gen_log_prob_key="multi_turn_gen_log_prob") -> "MetricChain":
		'''Get metric for teacher-forcing.

		It contains:

		* :class:`.metric.MultiTurnPerplexityMetric`

		Arguments:
			multi_turn_gen_log_prob_key (str): The key of predicted log probability over words.
				Refer to :class:`.metric.MultiTurnPerplexityMetric`.
				Default: ``multi_turn_gen_log_prob``.

		Returns:
			A :class:`.metric.MetricChain` object.
		'''
		from ..metric import MetricChain, MultiTurnPerplexityMetric
		metric = MetricChain()
		metric.add_metric(MultiTurnPerplexityMetric(self, \
			multi_turn_gen_log_prob_key=multi_turn_gen_log_prob_key, \
			multi_turn_reference_len_key="sent_length", \
			multi_turn_reference_allvocabs_key="sent_allvocabs"))
		return metric
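	# A minimal sketch of how the teacher-forcing chain is typically driven
	# (hedged: ``model`` and its ``log_prob`` method are hypothetical placeholders;
	# ``get_batches`` is assumed from the LanguageProcessing base class):
	#
	# >>> metric = dataloader.get_teacher_forcing_metric()
	# >>> for batch in dataloader.get_batches("test", batch_size=32):
	# ...     # log probability over words, shaped like
	# ...     # [batch_size, max(turn_length), max(sent_length), vocab_size]
	# ...     batch["multi_turn_gen_log_prob"] = model.log_prob(batch)
	# ...     metric.forward(batch)
	# >>> metric.close()  # a dict containing the perplexity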
	def get_inference_metric(self, multi_turn_gen_key="multi_turn_gen") -> "MetricChain":
		'''Get metric for inference.

		It contains:

		* :class:`.metric.MultiTurnBleuCorpusMetric`
		* :class:`.metric.MultiTurnDialogRecorder`

		Arguments:
			multi_turn_gen_key (str): The key of generated sentences in index form.
				Refer to :class:`.metric.MultiTurnBleuCorpusMetric` or
				:class:`.metric.MultiTurnDialogRecorder`.
				Default: ``multi_turn_gen``.

		Returns:
			A :class:`.metric.MetricChain` object.
		'''
		from ..metric import MetricChain, MultiTurnBleuCorpusMetric, MultiTurnDialogRecorder
		metric = MetricChain()
		metric.add_metric(MultiTurnBleuCorpusMetric(self, multi_turn_gen_key=multi_turn_gen_key, \
			multi_turn_reference_allvocabs_key="sent_allvocabs", turn_len_key="turn_length"))
		metric.add_metric(MultiTurnDialogRecorder(self, multi_turn_gen_key=multi_turn_gen_key, \
			multi_turn_reference_allvocabs_key="sent_allvocabs", turn_len_key="turn_length"))
		return metric
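	# A minimal sketch of the inference chain above (hedged: ``model.generate`` is
	# a hypothetical placeholder; the key follows this method's default):
	#
	# >>> metric = dataloader.get_inference_metric()
	# >>> for batch in dataloader.get_batches("test", batch_size=32):
	# ...     # a 3-d list of generated token ids: [batch_size, ~turn_length, ~sent_length]
	# ...     batch["multi_turn_gen"] = model.generate(batch)
	# ...     metric.forward(batch)
	# >>> metric.close()  # a dict containing BLEU and the recorded dialogs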
# TODO: doc
class UbuntuCorpus(MultiTurnDialog):
	'''A dataloader for the Ubuntu dataset.

	Arguments:
		file_id (str): A string indicating the source of the UbuntuCorpus dataset.
			Default: ``resources://Ubuntu``. A preset dataset is downloaded and cached.
		{ARGUMENTS}

	Refer to :class:`.MultiTurnDialog` for attributes and methods.

	References:
		[1] https://github.com/rkadlec/ubuntu-ranking-dataset-creator

		[2] Lowe R, Pow N, Serban I, et al. The Ubuntu Dialogue Corpus: A Large Dataset
		for Research in Unstructured Multi-Turn Dialogue Systems. SIGDIAL 2015.
	'''

	ARGUMENTS_FORMATTER = r'''
		min_frequent_vocab_times (int): A cut-off threshold of frequent tokens. All tokens appearing
			at least ``min_frequent_vocab_times`` times in the **training set** will be marked as
			frequent words. Default: ``{default_min_frequent_vocab_times}``.
		max_sent_length (int): All sentences longer than ``max_sent_length`` will be shortened
			to their first ``max_sent_length`` tokens. Default: ``{default_max_sent_length}``.
		max_turn_length (int): All sessions longer than ``max_turn_length`` will be shortened
			to their first ``max_turn_length`` sentences. Default: ``{default_max_turn_length}``.
		min_rare_vocab_times (int): A cut-off threshold of rare tokens. All tokens (except frequent
			words) appearing at least ``min_rare_vocab_times`` times in the **whole dataset** will be
			marked as rare words. Otherwise, they are unknown words, in both training and testing
			stages. Default: ``{default_min_rare_vocab_times}`` (no unknown words).'''

	ARGUMENTS = ARGUMENTS_FORMATTER.format( \
		default_min_frequent_vocab_times=10, \
		default_max_sent_length=50, \
		default_max_turn_length=20, \
		default_min_rare_vocab_times=0)

	def __init__(self, file_id="resources://Ubuntu", min_frequent_vocab_times=10, \
			max_sent_length=50, max_turn_length=20, min_rare_vocab_times=0, \
			tokenizer='nltk', pretrained=None):
		super().__init__(file_id, tokenizer=tokenizer, max_sent_length=max_sent_length, \
			max_turn_length=max_turn_length, convert_to_lower_letter=True, \
			min_frequent_vocab_times=min_frequent_vocab_times, \
			min_rare_vocab_times=min_rare_vocab_times, pretrained=pretrained)
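# A minimal sketch of loading the preset Ubuntu corpus (hedged: the resource is
# downloaded and cached on first use, so this needs network access once):
#
# >>> loader = UbuntuCorpus()  # equivalent to UbuntuCorpus("resources://Ubuntu")
# >>> batch = loader.get_batch("train", [0, 1])
# >>> batch["sent"].shape        # (2, max(turn_length), max(sent_length))
# >>> batch["turn_length"]       # the number of turns in each of the two sessions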
class SwitchboardCorpus(MultiTurnDialog):
	'''A dataloader for the Switchboard dataset.

	In this dataset, all sessions start with a ``<d>`` representing an empty context.

	Arguments:
		file_id (str): A string indicating the source of the SwitchboardCorpus dataset.
			Default: ``resources://SwitchboardCorpus``. A preset dataset is downloaded and cached.

	Refer to :class:`.MultiTurnDialog` for attributes and methods.

	References:
		[1] https://catalog.ldc.upenn.edu/LDC97S62

		[2] Godfrey J J and Holliman E. Switchboard-1 Release 2. Linguistic Data Consortium,
		Philadelphia 1997.
	'''

	ARGUMENTS = UbuntuCorpus.ARGUMENTS_FORMATTER.format( \
		default_min_frequent_vocab_times=5, \
		default_max_sent_length=50, \
		default_max_turn_length=1000, \
		default_min_rare_vocab_times=0)

	def __init__(self, file_id="resources://SwitchboardCorpus", min_frequent_vocab_times=5, \
			max_sent_length=50, max_turn_length=1000, min_rare_vocab_times=0, \
			tokenizer='nltk', pretrained=None):
		if pretrained is None:
			fields = {
				**{k: OrderedDict([('session', 'SessionDefault')]) for k in ['train', 'dev', 'test']},
				'multi_ref': OrderedDict([('session', 'SessionDefault'), \
					('candidate', 'SentenceCandidateDefault')])
			}
		elif pretrained in ('gpt2', 'bert'):
			fields = {
				**{k: OrderedDict([('session', Session.get_pretrained_class(pretrained).__name__)]) \
					for k in ['train', 'dev', 'test']},
				'multi_ref': OrderedDict([('session', Session.get_pretrained_class(pretrained).__name__), \
					('candidate', Session.get_candidate_pretrained_class(pretrained).__name__)])
			}
		else:
			raise ValueError("No pretrained name %s" % pretrained)

		with FieldContext.set_parameters( \
				vocab_from_mappings={**Field.DEFAULT_VOCAB_FROM_MAPPINGS, 'multi_ref': 'test'}):
			super().__init__(file_id, tokenizer=tokenizer, max_sent_length=max_sent_length, \
				max_turn_length=max_turn_length, convert_to_lower_letter=False, \
				min_frequent_vocab_times=min_frequent_vocab_times, \
				min_rare_vocab_times=min_rare_vocab_times, fields=fields, pretrained=pretrained)

	def get_batch(self, set_name, indexes):
		# '''{LanguageProcessing.GET_BATCH_DOC_WITHOUT_RETURNS}
		'''
		Returns:
			(dict): A dict containing what is returned by :meth:`MultiTurnDialog.get_batch`.
			{MultiTurnDialog.GET_BATCH_RETURNS_DICT}

			It additionally contains:

			* candidate_allvocabs (list): A 3-d list, multiple responses for a session.
			  Size: ``[batch_size, ~reference_num, ~sent_length]``, where "~" means
			  different sizes in this dimension are allowed.

			See the example below.

		Examples:
			{MultiTurnDialog.GET_BATCH_EXAMPLES_PART}
			"candidate_allvocabs": [
				[[2, 7, 3],          # two responses to 1st session: <go> hello <eos>
				 [2, 6, 5, 10, 3]],  #                               <go> you are fine <eos>
				[[2, 6, 5, 10, 3]]]  # one response to 2nd session:  <go> you are fine <eos>
			}
		'''
		return super().get_batch(set_name, indexes)

	def get_multi_ref_metric(self, generated_num_per_context=20, word2vec=None, \
			multiple_gen_key="multiple_gen_key") -> "MetricChain":
		'''Get metrics for multiple references.

		It contains:

		* :class:`.metric.BleuPrecisionRecallMetric`
		* :class:`.metric.EmbSimilarityPrecisionRecallMetric`

		Arguments:
			generated_num_per_context (int): The number of sentences generated per context.
			word2vec (dict): Maps words to word embeddings for embedding similarity.
				Default: if ``None``, use the Glove word embeddings from ``resources://Glove300d``.
			multiple_gen_key (str): The key of multiple generated sentences.
				Default: ``multiple_gen_key``.

		Returns:
			A :class:`.metric.MetricChain` object.
		'''
		from ..metric import MetricChain, BleuPrecisionRecallMetric, EmbSimilarityPrecisionRecallMetric
		metric = MetricChain()
		if word2vec is None:
			glove = Glove("resources://Glove300d")
			word2vec = glove.load_dict(self.frequent_vocab_list)
		for ngram in range(1, 5):
			metric.add_metric(BleuPrecisionRecallMetric(self, ngram, generated_num_per_context, \
				multiple_gen_key=multiple_gen_key))
		metric.add_metric(EmbSimilarityPrecisionRecallMetric(self, word2vec, \
			'avg', generated_num_per_context, multiple_gen_key=multiple_gen_key))
		metric.add_metric(EmbSimilarityPrecisionRecallMetric(self, word2vec, \
			'extrema', generated_num_per_context, multiple_gen_key=multiple_gen_key))
		return metric
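# A minimal sketch of multi-reference evaluation on the ``multi_ref`` set
# (hedged: ``model.sample`` is a hypothetical placeholder; ``get_batches`` is
# assumed from the base class; the Glove resource is downloaded on first use
# when ``word2vec`` is None):
#
# >>> loader = SwitchboardCorpus()
# >>> metric = loader.get_multi_ref_metric(generated_num_per_context=20)
# >>> for batch in loader.get_batches("multi_ref", batch_size=4):
# ...     # [batch_size, 20, ~sent_length]: 20 generated responses per context
# ...     batch["multiple_gen_key"] = model.sample(batch, num=20)
# ...     metric.forward(batch)
# >>> metric.close()  # BLEU and embedding-similarity precision/recall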