Source code for cotk.dataloader.sentence_classification

"""Dataloader for sentence classification"""
import warnings
from collections import OrderedDict

from .field import Sentence, SentenceGPT2, SentenceBERT
from .dataloader import LanguageProcessing
from .context import FieldContext, VocabContext
from .vocab import GeneralVocab, PretrainedVocab
from .tokenizer import PretrainedTokenizer
from .field import Sentence

if False: # for type check # pylint: disable=using-constant-test
	from ..metric import MetricChain #pylint: disable=unused-import

# pylint: disable=W0223
class SentenceClassification(LanguageProcessing):
	r"""Base class for sentence classification datasets. This is an abstract class.

	Arguments:
		file_id (str): A str indicating the source of the dataset. It is
			forwarded to :class:`LanguageProcessing`.
		tokenizer: The tokenizer (or tokenizer name) forwarded to the fields
			through :class:`FieldContext`. When `pretrained` is ``'gpt2'`` or
			``'bert'``, it must be a :class:`PretrainedTokenizer`.
			Default: ``None``.
		max_sent_length (int): Forwarded to the fields through
			:class:`FieldContext`. Default: ``None``.
		convert_to_lower_letter (bool): Forwarded to the fields through
			:class:`FieldContext`. Default: ``None``.
		min_frequent_vocab_times (int): Forwarded through :class:`VocabContext`.
			Not used when `pretrained` is set. Default: ``None``.
		min_rare_vocab_times (int): Forwarded through :class:`VocabContext`.
			Not used when `pretrained` is set. Default: ``None``.
		fields: Field specification forwarded to :class:`LanguageProcessing`.
			Default: ``OrderedDict([('sent', 'SentenceDefault'), ('label', 'DenseLabel')])``;
			when `pretrained` is set, the name of the matching pretrained
			sentence class replaces ``'SentenceDefault'``.
		pretrained (str): ``None``, ``'gpt2'`` or ``'bert'``. Default: ``None``.

	Raises:
		ValueError: If `pretrained` is not ``None``, ``'gpt2'`` or ``'bert'``,
			or if `pretrained` is set but `tokenizer` is not a
			:class:`PretrainedTokenizer`.

	Notes:
		A :class:`Sentence` field must be set as default field. When invoking
		:meth:`__init__` of :class:`SentenceClassification`, the default field,
		which may be reset in subclass, is set as self.fields['train']['sent'].
	"""

	_version = 2

	def __init__(self, file_id: str, tokenizer=None, max_sent_length=None,
			convert_to_lower_letter=None, min_frequent_vocab_times=None,
			min_rare_vocab_times=None, fields=None, pretrained=None):
		self._pretrained = pretrained
		if pretrained is None:
			if fields is None:
				fields = OrderedDict([('sent', 'SentenceDefault'), ('label', 'DenseLabel')])
			with FieldContext.set_parameters(tokenizer=tokenizer,
					max_sent_length=max_sent_length,
					convert_to_lower_letter=convert_to_lower_letter):
				with VocabContext.set_parameters(min_rare_vocab_times=min_rare_vocab_times,
						min_frequent_vocab_times=min_frequent_vocab_times):
					super().__init__(file_id, fields)
		elif pretrained in ('gpt2', 'bert'):
			if fields is None:
				fields = OrderedDict([
					('sent', Sentence.get_pretrained_class(pretrained).__name__),
					('label', 'DenseLabel'),
				])
			if not isinstance(tokenizer, PretrainedTokenizer):
				raise ValueError("tokenize should be loaded first if you want a %s dataloader" % (pretrained))
			# The vocabulary comes from the pretrained tokenizer, so VocabContext
			# thresholds are not applied in this branch.
			vocab = PretrainedVocab(tokenizer.tokenizer)
			with FieldContext.set_parameters(tokenizer=tokenizer,
					vocab=vocab,
					max_sent_length=max_sent_length,
					convert_to_lower_letter=convert_to_lower_letter):
				super().__init__(file_id, fields)
		else:
			raise ValueError("No pretrained name %s" % pretrained)
		self.set_default_field('train', 'sent')

		if pretrained in ('gpt2', 'bert'):
			# Check whether SentenceGPT2 or SentenceBERT is used: warn about every
			# Sentence field that is not the pretrained variant for `pretrained`.
			expected_class = Sentence.get_pretrained_class(pretrained)
			for set_fields in self.fields.values():
				for field in set_fields.values():
					if isinstance(field, Sentence) and not isinstance(field, expected_class):
						warnings.warn("If you want to use a %s sentence_classification, you'd better use %s instead of %s." % (pretrained, expected_class.__name__, type(field).__name__))

	def get_batch(self, set_name, indexes):
		'''Get a batch of specified `indexes`.

		Arguments:
			set_name (str): The name of the set to read from, e.g. ``'train'``.
			indexes (list): a list of specified indexes

		Returns:
			(dict): A dict at least contains:

			* sent_length(:class:`numpy.array`): A 1-d array, the length of sentence in each batch.
			  Size: `[batch_size]`
			* sent(:class:`numpy.array`): A 2-d padding array containing id of words.
			  Only provide valid words. `unk_id` will be used if a word is not valid.
			  Size: `[batch_size, max(sent_length)]`
			* label(:class:`numpy.array`): A 1-d array, the label of sentence in each batch.
			* sent_allvocabs(:class:`numpy.array`): A 2-d padding array containing id of words.
			  Provide both valid and invalid words.
			  Size: `[batch_size, max(sent_length)]`

		Examples:
			>>> # all_vocab_list = ["<pad>", "<unk>", "<go>", "<eos>", "how", "are", "you",
			>>> #	"hello", "i", "am", "fine"]
			>>> # vocab_size = 9
			>>> # vocab_list = ["<pad>", "<unk>", "<go>", "<eos>", "how", "are", "you", "hello", "i"]
			>>> dataloader.get_batch('train', [0, 1, 2])
			{
				"sent": numpy.array([
					[2, 4, 5, 6, 3, 0],   # first sentence: <go> how are you <eos> <pad>
					[2, 7, 3, 0, 0, 0],   # second sentence:  <go> hello <eos> <pad> <pad> <pad>
					[2, 7, 8, 1, 1, 3]    # third sentence: <go> hello i <unk> <unk> <eos>
				]),
				"label": numpy.array([1, 2, 0]), # label of sentences
				"sent_length": numpy.array([5, 3, 6]), # length of sentences
				"sent_allvocabs": numpy.array([
					[2, 4, 5, 6, 3, 0],   # first sentence: <go> how are you <eos> <pad>
					[2, 7, 3, 0, 0, 0],   # second sentence:  <go> hello <eos> <pad> <pad> <pad>
					[2, 7, 8, 9, 10, 3]   # third sentence: <go> hello i am fine <eos>
				]),
			}
		'''
		return super().get_batch(set_name, indexes)

	def get_metric(self, prediction_key="prediction"):
		'''Get metrics for accuracy. In other words, this function
		provides metrics for sentence classification task.

		It contains:

		* :class:`.metric.AccuracyMetric`

		Arguments:
			prediction_key (str): The key of prediction over sentences.
				Refer to :class:`.metric.AccuracyMetric`. Default: ``prediction``.

		Returns:
			A :class:`.metric.MetricChain` object.
		'''
		# Imported lazily, as in the original module layout, to avoid loading
		# the metric package at dataloader import time.
		from ..metric import MetricChain, AccuracyMetric
		metric = MetricChain()
		metric.add_metric(AccuracyMetric(self,
			label_key='label',
			prediction_key=prediction_key))
		return metric
class SST(SentenceClassification):
	'''A dataloader for preprocessed SST (Stanford Sentiment Treebank) dataset.

	Arguments:
		file_id (str): a str indicates the source of SST dataset.
		min_frequent_vocab_times (int): A cut-off threshold of valid tokens. All tokens appear
			not less than `min_frequent_vocab_times` in **training set** will be marked as frequent words.
			Default: 10.
		max_sent_length (int): All sentences longer than `max_sent_length` will be shortened
			to first `max_sent_length` tokens. Default: 50.
		min_rare_vocab_times (int): A cut-off threshold of invalid tokens. All tokens appear
			not less than `min_rare_vocab_times` in the **whole dataset** (except valid words) will be
			marked as rare words. Otherwise, they are unknown words, both in training or
			testing stages. Default: 0 (No unknown words).
		tokenizer: The tokenizer (or tokenizer name) used to split sentences. Default: ``'space'``.
			When `pretrained` is set, a :class:`PretrainedTokenizer` must be passed instead,
			otherwise a ``ValueError`` is raised.
		pretrained (str): ``None``, ``'gpt2'`` or ``'bert'``. Default: ``None``.

	Sentences are never converted to lower case (``convert_to_lower_letter=False``).

	Refer to :class:`.SentenceClassification` for attributes and methods.

	References:
		[1] https://nlp.stanford.edu/sentiment/

		[2] Socher R, Perelygin A, Wu J, et al. Recursive Deep Models for Semantic
		Compositionality Over a Sentiment Treebank. EMNLP 2013.
	'''

	def __init__(self, file_id, min_frequent_vocab_times=10,
			max_sent_length=50, min_rare_vocab_times=0, tokenizer='space', pretrained=None):
		super().__init__(file_id,
			tokenizer=tokenizer,
			max_sent_length=max_sent_length,
			convert_to_lower_letter=False,
			min_frequent_vocab_times=min_frequent_vocab_times,
			min_rare_vocab_times=min_rare_vocab_times,
			pretrained=pretrained)