Source code for cotk.metric.bleu

r"""
Classes and functions for evaluating the BLEU of models' generated results.
"""
from typing import Union, List, Any, Optional, Iterable, Dict
import random
import os
import multiprocessing
from multiprocessing import Pool
import numpy as np
import tqdm
from nltk.translate.bleu_score import corpus_bleu, sentence_bleu, SmoothingFunction

from .metric import MetricBase
from ..dataloader.tokenizer import Tokenizer, SimpleTokenizer
from .._utils import replace_unk

if False: # for type check # pylint: disable=using-constant-test
	from ..dataloader.dataloader import LanguageProcessing
	# Sentence and Session are referenced in annotations below; assumed module path:
	from ..dataloader.field import Sentence, Session

def _sentence_bleu(ele):
	'''Auxiliary function for computing sentence bleu:

	Arguments:
		ele (tuple): A tuple ``(reference sentences, a hypothesis sentence, weights)``.

	Returns:

		* float: **sentence-bleu** value.
	'''

	return sentence_bleu(ele[0], ele[1], weights=ele[2], smoothing_function=SmoothingFunction().method1)
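
# A minimal sketch of how the helper above is consumed (hypothetical names
# `refs`, `hyps`, `weights`; mirrors the task tuples built in the close()
# methods below):
#
#     weights = np.ones(4) / 4
#     tasks = ((refs[:i] + refs[i+1:], hyps[i], weights) for i in range(len(hyps)))
#     scores = list(map(_sentence_bleu, tasks))  # or Pool.imap_unordered for large samples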

class BleuCorpusMetric(MetricBase):
	'''Metric for calculating BLEU.

	Arguments:
		{MetricBase.DATALOADER_ARGUMENTS}
		{MetricBase.NGRAM_ARGUMENTS}
		{MetricBase.TOKENIZER_ARGUMENTS}
		reference_num (int, None, optional): The number of references used to calculate BLEU. If ``None``, \
			the number of references is uncertain and will be determined by the argument of \
			:meth:`.forward`. Default: ``1``.
		{MetricBase.IGNORE_SMOOTHING_ERROR_ARGUMENTS}
		{MetricBase.REFERENCE_ALLVOCABS_KEY_ARGUMENTS}
		{MetricBase.GEN_KEY_ARGUMENTS}
		reference_str_key (str, optional): The key of reference sentences in the string form. Default: ``ref_str``.

	Here is an example:

		>>> dl = cotk.dataloader.UbuntuCorpus('resources://Ubuntu_small')
		>>> reference_allvocabs_key = "ref_allvocabs"
		>>> gen_key = "gen"
		>>> metric = cotk.metric.BleuCorpusMetric(dl,
		...     reference_allvocabs_key=reference_allvocabs_key,
		...     gen_key=gen_key)
		>>> data = {
		...     reference_allvocabs_key: [[2, 10, 64, 851, 3], [2, 10, 48, 851, 3]],
		...     # reference_allvocabs_key: [["<go>", "I", "like", "python", "<eos>"], ["<go>", "I", "use", "python", "<eos>"]],
		...
		...     gen_key: [[10, 1028, 479, 285, 220, 3], [851, 17, 2451, 3]]
		...     # gen_key: [["I", "love", "java", "very", "much", "<eos>"], ["python", "is", "excellent", "<eos>"]],
		... }
		>>> metric.forward(data)
		>>> metric.close()
		{'bleu': 0.08582363099612991, 'bleu hashvalue': '70e019630fef24d9477034a3d941a5349fcbff5a3dc6978a13ea3d85290114fb'}
	'''

	_name = 'BleuCorpusMetric'
	_version = 2

	def __init__(self, dataloader: Union["LanguageProcessing", "Sentence", "Session"], ngram: int = 4, *, \
			tokenizer: Union[None, Tokenizer, str] = None, reference_num: Optional[int] = 1, \
			ignore_smoothing_error: bool = False, reference_allvocabs_key: str = "ref_allvocabs", \
			gen_key: str = "gen", reference_str_key: str = "ref_str"):
		super().__init__(self._name, self._version)
		self.dataloader = dataloader
		self.ngram = ngram
		self.tokenizer = tokenizer
		self.reference_num = reference_num
		self.ignore_smoothing_error = ignore_smoothing_error
		self.reference_allvocabs_key = reference_allvocabs_key
		self.reference_str_key = reference_str_key
		self.gen_key = gen_key
		self.hyps: List[Any] = []
		self.refs: List[List[Any]] = []
	def forward(self, data: Dict[str, Any]):
		'''Processing a batch of data.

		Arguments:
			data (dict): A dict at least contains the following keys:

				{MetricBase.FORWARD_REFERENCE_ALLVOCABS_ARGUMENTS}
				{MetricBase.FORWARD_GEN_ARGUMENTS}

				Here is an example for data:

					>>> # all_vocab_list = ["<pad>", "<unk>", "<go>", "<eos>", "I", "have",
					>>> #   "been", "to", "China"]
					>>> data = {
					...     reference_allvocabs_key: [[2,4,3], [2,5,6,3]],
					...     gen_key: [[4,5,3], [6,7,8,3]]
					... }
		'''
		super().forward(data)
		if self.tokenizer is None:
			self._direct_forward(data)
		else:
			self._re_tokenize_forward(data)
	def _direct_forward(self, data: Dict[str, Any]):
		gen = data[self.gen_key]
		resp = data[self.reference_allvocabs_key]

		if not isinstance(gen, (np.ndarray, list)):
			raise TypeError("Unknown type for gen.")
		if not isinstance(resp, (np.ndarray, list)):
			raise TypeError("Unknown type for resp.")
		if len(resp) != len(gen):
			raise ValueError("Batch num is not matched.")

		relevant_data = []
		for gen_sen, resp_sen in zip(gen, resp):
			hyp = self.dataloader.convert_ids_to_tokens(gen_sen, remove_special=True, trim=True)
			if self.reference_num == 1:
				refs = [self.dataloader.convert_ids_to_tokens(resp_sen, remove_special=True, trim=True)]
			else:
				if self.reference_num is not None and len(resp_sen) != self.reference_num:
					raise RuntimeError("Require %d references but get %d" % (self.reference_num, len(resp_sen)))
				refs = [self.dataloader.convert_ids_to_tokens(resp_single_sen, remove_special=True, trim=True) \
					for resp_single_sen in resp_sen]
			self.hyps.append(hyp)
			self.refs.append(refs)
			relevant_data.append(refs)
		self._hash_unordered_list(relevant_data)

	def _re_tokenize_forward(self, data: Dict[str, Any]):
		gen = data[self.gen_key]
		resp = data.get(self.reference_allvocabs_key, None)
		resp_str = data.get(self.reference_str_key, None)

		if not isinstance(gen, (np.ndarray, list)):
			raise TypeError("Unknown type for gen.")
		# TODO: add type checks for resp and resp_str

		relevant_data = []
		for i, gen_sen in enumerate(gen):
			hyp = self.dataloader.convert_ids_to_sentence(gen_sen, remove_special=True, trim=True)
			if resp_str:
				if self.reference_num == 1:
					refs = [resp_str[i]]
				else:
					if self.reference_num is not None and len(resp_str[i]) != self.reference_num:
						raise RuntimeError("Require %d references but get %d" % (self.reference_num, len(resp_str[i])))
					refs = resp_str[i]
			else:
				if self.reference_num == 1:
					refs = [self.dataloader.convert_ids_to_sentence(resp[i], remove_special=True, trim=True)]
				else:
					if self.reference_num is not None and len(resp[i]) != self.reference_num:
						raise RuntimeError("Require %d references but get %d" % (self.reference_num, len(resp[i])))
					refs = [self.dataloader.convert_ids_to_sentence(resp_single_sen, remove_special=True, trim=True) \
						for resp_single_sen in resp[i]]
			self.hyps.append(hyp)
			self.refs.append(refs)
			relevant_data.append(refs)
		self._hash_unordered_list(relevant_data)
	def close(self) -> Dict[str, Any]:
		'''Return a dict which contains

		* **bleu**: bleu value.
		* **bleu hashvalue**: hash value for bleu metric, same hash value stands
		  for same evaluation settings.
		'''
		result = super().close()
		if (not self.hyps) or (not self.refs):
			raise RuntimeError("The metric has not been forwarded data correctly.")

		if self.tokenizer:
			self._do_tokenize()

		if "unk" in self.dataloader.get_special_tokens_mapping():
			self.hyps = replace_unk(self.hyps, self.dataloader.get_special_tokens_mapping()["unk"])
		try:
			weights = np.ones(self.ngram) / self.ngram
			result.update({"bleu": \
				corpus_bleu(self.refs, self.hyps, weights=weights, smoothing_function=SmoothingFunction().method3), \
				"bleu hashvalue": self._hashvalue()})
		except ZeroDivisionError as _:
			if not self.ignore_smoothing_error:
				raise ZeroDivisionError("Bleu smoothing divided by zero. This is a known bug of corpus_bleu, \
					usually caused when there is only one sample and the sample length is 1.") from None
			result.update({"bleu": 0, \
				"bleu hashvalue": self._hashvalue()})
		return result
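
	# The ZeroDivisionError branch above comes from corpus_bleu's smoothing;
	# passing ignore_smoothing_error=True at construction reports bleu=0 instead
	# of raising. A minimal sketch (assuming a cotk dataloader `dl`):
	#
	#     metric = BleuCorpusMetric(dl, ignore_smoothing_error=True)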
	def _do_tokenize(self):
		tokenizer: Tokenizer
		if isinstance(self.tokenizer, str):
			tokenizer = SimpleTokenizer(self.tokenizer)
		elif isinstance(self.tokenizer, Tokenizer):
			tokenizer = self.tokenizer
		else:
			raise TypeError("Unknown type of tokenizer")

		self.refs = tokenizer.tokenize_sessions(self.refs)
		self.hyps = tokenizer.tokenize_sentences(self.hyps)
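
# A hedged sketch of the multi-reference case (the ids below are illustrative
# only; `dl` is assumed to be any cotk dataloader): with reference_num=2, each
# batch entry under `ref_allvocabs` must hold exactly two reference sentences.
#
#     metric = BleuCorpusMetric(dl, reference_num=2)
#     metric.forward({
#         "ref_allvocabs": [[[2, 10, 64, 851, 3], [2, 10, 48, 851, 3]]],  # 2 refs for 1 hypothesis
#         "gen": [[10, 1028, 479, 285, 220, 3]],
#     })
#     print(metric.close()["bleu"])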
class SelfBleuCorpusMetric(MetricBase):
	r'''Metric for calculating Self-BLEU.

	Arguments:
		{MetricBase.DATALOADER_ARGUMENTS}
		{MetricBase.NGRAM_ARGUMENTS}
		{MetricBase.TOKENIZER_ARGUMENTS}
		{MetricBase.GEN_KEY_ARGUMENTS}
		{MetricBase.SAMPLE_ARGUMENTS_IN_BLEU}
		{MetricBase.SEED_ARGUMENTS}
		{MetricBase.CPU_COUNT_ARGUMENTS}

	Warning:
		The calculation of ``hashvalue`` considers the actual sample size of hypotheses,
		which will be smaller than ``sample`` if fewer hypotheses than ``sample`` are provided.

	Here is an example:

		>>> dl = cotk.dataloader.UbuntuCorpus('resources://Ubuntu_small')
		>>> gen_key = 'gen'
		>>> metric = cotk.metric.SelfBleuCorpusMetric(dl, gen_key=gen_key)
		>>> data = {
		...     gen_key: [[10, 64, 851, 3], [10, 48, 851, 3]],
		...     # gen_key: [["I", "like", "python", "<eos>"], ["I", "use", "python", "<eos>"]],
		... }
		>>> metric.forward(data)
		>>> metric.close()
		{'self-bleu': 0.13512001548070346, 'self-bleu hashvalue': '53cf55829c1b080c86c392c846a5d39a54340c70d838ec953f952aa6731118fb'}
	'''

	_name = 'SelfBleuCorpusMetric'
	_version = 2

	def __init__(self, dataloader: Union["LanguageProcessing", "Sentence", "Session"], ngram: int = 4, *, \
			tokenizer: Union[None, Tokenizer, str] = None, \
			gen_key: str = "gen", \
			sample: int = 1000, \
			seed: int = 1229, \
			cpu_count: Optional[int] = None):
		super().__init__(self._name, self._version)
		self.dataloader = dataloader
		self.ngram = ngram
		self.tokenizer = tokenizer
		self.gen_key = gen_key
		self.sample = sample
		self.hyps: List[Any] = []
		self.seed = seed
		if cpu_count is not None:
			self.cpu_count = cpu_count
		elif "CPU_COUNT" in os.environ and os.environ["CPU_COUNT"] is not None:
			self.cpu_count = int(os.environ["CPU_COUNT"])
		else:
			self.cpu_count = multiprocessing.cpu_count()
	def forward(self, data: Dict[str, Any]):
		'''Processing a batch of data.

		Arguments:
			data (dict): A dict at least contains the following keys:

				{MetricBase.FORWARD_GEN_ARGUMENTS}

				Here is an example for data:

					>>> # all_vocab_list = ["<pad>", "<unk>", "<go>", "<eos>", "I", "have",
					>>> #   "been", "to", "China"]
					>>> data = {
					...     gen_key: [[4,5,3], [6,7,8,3]]
					... }
		'''
		super().forward(data)
		gen = data[self.gen_key]
		if not isinstance(gen, (np.ndarray, list)):
			raise TypeError("Unknown type for gen.")

		self.hyps.extend(gen)
	def close(self) -> Dict[str, Any]:
		'''Return a dict which contains

		* **self-bleu**: self-bleu value.
		* **self-bleu hashvalue**: hash value for self-bleu metric, same hash value stands
		  for same evaluation settings.
		'''
		res = super().close()
		if not self.hyps:
			raise RuntimeError("The metric has not been forwarded data correctly.")
		if len(self.hyps) == 1:
			raise RuntimeError("Self-Bleu can't be computed because there is only 1 generated sentence.")
		if self.sample <= 1:
			raise RuntimeError("`sample` should be greater than 1, but it is `{}`.".format(self.sample))

		if self.sample > len(self.hyps):
			self.sample = len(self.hyps)

		rng_state = random.getstate()
		random.seed(self.seed)
		random.shuffle(self.hyps)
		random.setstate(rng_state)

		ref = self.hyps[:self.sample]

		if self.tokenizer:
			tokenizer: Tokenizer
			if isinstance(self.tokenizer, str):
				tokenizer = SimpleTokenizer(self.tokenizer)
			else:
				tokenizer = self.tokenizer
			ref = [self.dataloader.convert_ids_to_sentence(ids, remove_special=True, trim=True) for ids in ref]
			ref = tokenizer.tokenize_sentences(ref)
		else:
			ref = [self.dataloader.convert_ids_to_tokens(ids, remove_special=True, trim=True) for ids in ref]

		if "unk" in self.dataloader.get_special_tokens_mapping():
			_ref = replace_unk(ref, self.dataloader.get_special_tokens_mapping()["unk"])
		else:
			_ref = ref

		bleu_irl = []
		weights = np.ones(self.ngram) / self.ngram
		tasks = ((ref[:i] + ref[i+1:self.sample], _ref[i], weights) for i in range(self.sample))

		pool: Optional[Any]
		values: Iterable[Any]
		if self.sample >= 1000 and self.cpu_count > 1:
			# use multiprocessing
			pool = Pool(self.cpu_count)
			values = pool.imap_unordered(_sentence_bleu, tasks, chunksize=20)
		else:
			pool = None
			values = map(_sentence_bleu, tasks)
		if self.sample >= 1000:
			# use tqdm
			values = tqdm.tqdm(values, total=self.sample)
		for ans in values:
			bleu_irl.append(ans)
		if pool is not None:
			pool.close()
			pool.join()

		self._hash_ordered_data((self.seed, self.sample))
		res.update({"self-bleu": 1.0 * sum(bleu_irl) / len(bleu_irl), \
			"self-bleu hashvalue": self._hashvalue()})
		return res
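
# Usage sketch for Self-BLEU (hypothetical setup): each sampled hypothesis is
# scored against all other sampled hypotheses, so a lower value indicates more
# diverse generations. The worker count can be limited via the constructor or
# the CPU_COUNT environment variable read in __init__ above.
#
#     metric = SelfBleuCorpusMetric(dl, sample=500, cpu_count=4)
#     for batch in generated_batches:      # each batch: {"gen": [...id lists...]}
#         metric.forward(batch)
#     print(metric.close()["self-bleu"])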
class FwBwBleuCorpusMetric(MetricBase):
	r'''Metric for calculating FwBw-BLEU.

	Arguments:
		{MetricBase.DATALOADER_ARGUMENTS}
		{MetricBase.REFERENCE_TEST_LIST_ARGUMENTS}
		{MetricBase.NGRAM_ARGUMENTS}
		{MetricBase.TOKENIZER_ARGUMENTS}
		{MetricBase.GEN_KEY_ARGUMENTS}
		{MetricBase.SAMPLE_ARGUMENTS_IN_BLEU}
		{MetricBase.SEED_ARGUMENTS}
		{MetricBase.CPU_COUNT_ARGUMENTS}

	Warning:
		The calculation of ``hashvalue`` considers the actual sample size of hypotheses and references.
		Therefore ``hashvalue`` may vary if the size of hypotheses or references is smaller than ``sample``.

	Here is an example:

		>>> dl = cotk.dataloader.UbuntuCorpus('resources://Ubuntu_small')
		>>> gen_key = 'gen'
		>>> metric = cotk.metric.FwBwBleuCorpusMetric(dl,
		...     reference_test_list=dl.get_all_batch('test')['session'][0].tolist(),
		...     gen_key=gen_key)
		>>> data = {
		...     gen_key: [[10, 64, 851, 3], [10, 48, 851, 3]],
		...     # gen_key: [["I", "like", "python", "<eos>"], ["I", "use", "python", "<eos>"]],
		... }
		>>> metric.forward(data)
		>>> metric.close()
		{'fw-bleu': 0.007688528488990184, 'bw-bleu': 0.0012482612634667945, 'fw-bw-bleu': 0.002147816509441494, 'fw-bw-bleu hashvalue': '0e3f58a90225af615ff780f04c91613759e04a3c7b4329670b1d03b679adf8cd'}
	'''

	_name = 'FwBwBleuCorpusMetric'
	_version = 2

	def __init__(self, dataloader: Union["LanguageProcessing", "Sentence", "Session"], \
			reference_test_list: List[Any], ngram: int = 4, *, \
			tokenizer: Union[None, Tokenizer, str] = None, \
			gen_key: str = "gen", \
			sample: int = 1000, \
			seed: int = 1229, \
			cpu_count: Optional[int] = None):
		super().__init__(self._name, self._version)
		self.dataloader = dataloader
		self.tokenizer = tokenizer
		self.reference_test_list = reference_test_list
		self.gen_key = gen_key
		self.sample = sample
		self.seed = seed
		self.ngram = ngram
		if cpu_count is not None:
			self.cpu_count = cpu_count
		elif "CPU_COUNT" in os.environ and os.environ["CPU_COUNT"] is not None:
			self.cpu_count = int(os.environ["CPU_COUNT"])
		else:
			self.cpu_count = multiprocessing.cpu_count()
		self.hyps: List[Any] = []
	def forward(self, data: Dict[str, Any]):
		'''Processing a batch of data.

		Arguments:
			data (dict): A dict at least contains the following keys:

				{MetricBase.FORWARD_GEN_ARGUMENTS}

				Here is an example for data:

					>>> # all_vocab_list = ["<pad>", "<unk>", "<go>", "<eos>", "I", "have",
					>>> #   "been", "to", "China"]
					>>> data = {
					...     gen_key: [[4,5,3], [6,7,8,3]]
					... }
		'''
		super().forward(data)
		gen = data[self.gen_key]
		if not isinstance(gen, (np.ndarray, list)):
			raise TypeError("Unknown type for gen.")

		for gen_sen in gen:
			self.hyps.append(list(self.dataloader.trim_in_ids(gen_sen)))
	def close(self) -> Dict[str, Any]:
		'''Return a dict which contains

		* **fw-bleu**: fw bleu value.
		* **bw-bleu**: bw bleu value.
		* **fw-bw-bleu**: harmonic mean of fw/bw bleu value.
		* **fw-bw-bleu hashvalue**: hash value for fwbwbleu metric, same hash value stands
		  for same evaluation settings.
		'''
		res = super().close()
		if not self.hyps:
			raise RuntimeError("The metric has not been forwarded data correctly.")
		if not self.reference_test_list:
			raise RuntimeError("Reference cannot be empty.")

		sample_hyps_num = self.sample if self.sample < len(self.hyps) else len(self.hyps)
		sample_refs_num = self.sample if self.sample < len(self.reference_test_list) else len(self.reference_test_list)

		if sample_hyps_num <= 1:
			raise RuntimeError("`sample_hyps_num` should be greater than 1, but it is `{}`.".format(sample_hyps_num))
		if sample_refs_num <= 1:
			raise RuntimeError("`sample_refs_num` should be greater than 1, but it is `{}`.".format(sample_refs_num))

		sample_hyps = self.hyps[:sample_hyps_num]
		sample_refs = self.reference_test_list[:sample_refs_num]

		refs: List[Any]
		hyps: List[Any]
		if self.tokenizer:
			tokenizer: Tokenizer
			if isinstance(self.tokenizer, str):
				tokenizer = SimpleTokenizer(self.tokenizer)
			else:
				tokenizer = self.tokenizer

			if isinstance(sample_refs[0], list):
				ref_sents = [self.dataloader.convert_ids_to_sentence(ids, remove_special=True, trim=True) \
					for ids in sample_refs]
			else:
				ref_sents = sample_refs
			refs = tokenizer.tokenize_sentences(ref_sents)

			hyp_sents = [self.dataloader.convert_ids_to_sentence(ids, remove_special=True, trim=True) \
				for ids in sample_hyps]
			hyps = tokenizer.tokenize_sentences(hyp_sents)
		else:
			refs = [self.dataloader.convert_ids_to_tokens(ids, remove_special=True, trim=True) \
				for ids in sample_refs]
			hyps = [self.dataloader.convert_ids_to_tokens(ids, remove_special=True, trim=True) \
				for ids in sample_hyps]

		rng_state = random.getstate()
		random.seed(self.seed)
		random.shuffle(hyps)
		random.shuffle(refs)
		random.setstate(rng_state)

		if "unk" in self.dataloader.get_special_tokens_mapping():
			refs = replace_unk(refs, self.dataloader.get_special_tokens_mapping()["unk"])

		bleu_irl_fw, bleu_irl_bw = [], []
		weights = np.ones(self.ngram) / self.ngram

		tasks = ((refs, hyps[i], weights) for i in range(sample_hyps_num))
		pool: Optional[Any]
		values: Iterable[Any]
		if sample_hyps_num >= 1000 and self.cpu_count > 1:
			pool = Pool(self.cpu_count)
			values = pool.imap_unordered(_sentence_bleu, tasks, chunksize=20)
		else:
			pool = None
			values = map(_sentence_bleu, tasks)
		if sample_hyps_num >= 1000:
			values = tqdm.tqdm(values, total=sample_hyps_num)
		for ans in values:
			bleu_irl_fw.append(ans)
		if pool is not None:
			pool.close()
			pool.join()

		tasks = ((hyps, refs[i], weights) for i in range(sample_refs_num))
		if sample_refs_num >= 1000 and self.cpu_count > 1:
			pool = Pool(self.cpu_count)
			values = pool.imap_unordered(_sentence_bleu, tasks, chunksize=20)
		else:
			pool = None
			values = map(_sentence_bleu, tasks)
		if sample_refs_num >= 1000:
			values = tqdm.tqdm(values, total=sample_refs_num)
		for ans in values:
			bleu_irl_bw.append(ans)
		if pool is not None:
			pool.close()
			pool.join()

		fw_bleu = 1.0 * sum(bleu_irl_fw) / len(bleu_irl_fw)
		bw_bleu = 1.0 * sum(bleu_irl_bw) / len(bleu_irl_bw)
		if fw_bleu + bw_bleu > 0:
			fw_bw_bleu = 2.0 * bw_bleu * fw_bleu / (fw_bleu + bw_bleu)
		else:
			fw_bw_bleu = 0

		res.update({"fw-bleu": fw_bleu, \
			"bw-bleu": bw_bleu, \
			"fw-bw-bleu": fw_bw_bleu})

		self._hash_unordered_list(refs)
		self._hash_ordered_data((self.ngram, self.seed, sample_hyps_num, sample_refs_num))
		res.update({"fw-bw-bleu hashvalue": self._hashvalue()})
		return res
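
# fw-bw-bleu above is the harmonic mean of the two directions:
#
#     fw_bw_bleu = 2 * fw_bleu * bw_bleu / (fw_bleu + bw_bleu)
#
# e.g. the docstring example's fw_bleu ~= 0.00769 and bw_bleu ~= 0.00125 give
# 2 * 0.00769 * 0.00125 / (0.00769 + 0.00125) ~= 0.00215, matching its output.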
class MultiTurnBleuCorpusMetric(MetricBase):
	'''Metric for calculating multi-turn BLEU.

	Arguments:
		{MetricBase.DATALOADER_ARGUMENTS}
		{MetricBase.IGNORE_SMOOTHING_ERROR_ARGUMENTS}
		{MetricBase.MULTI_TURN_REFERENCE_ALLVOCABS_KEY_ARGUMENTS}
		{MetricBase.MULTI_TURN_GEN_KEY_ARGUMENTS}
		{MetricBase.MULTI_TURN_LENGTH_KEY_ARGUMENTS}

	Here is an example:

		>>> dl = cotk.dataloader.UbuntuCorpus('resources://Ubuntu_small')
		>>> multi_turn_reference_allvocabs_key = "reference_allvocabs"
		>>> multi_turn_gen_key = "multi_turn_gen"
		>>> turn_len_key = "turn_length"
		>>> metric = cotk.metric.MultiTurnBleuCorpusMetric(dl,
		...     multi_turn_reference_allvocabs_key=multi_turn_reference_allvocabs_key,
		...     multi_turn_gen_key=multi_turn_gen_key,
		...     turn_len_key=turn_len_key)
		>>> data = {
		...     multi_turn_reference_allvocabs_key: [[[2, 10, 64, 851, 3], [2, 10, 64, 479, 3]], [[2, 10, 64, 279, 1460, 3]]],
		...     # multi_turn_reference_allvocabs_key = [[["<go>", "I", "like", "python", "<eos>"], ["<go>", "I", "like", "java", "<eos>"]],
		...     #   [["<go>", "I", "like", "machine", "learning", "<eos>"]]]
		...
		...     turn_len_key: [2, 1],
		...     # turn_len_key: [len(multi_turn_reference_allvocabs_key[0]), len(multi_turn_reference_allvocabs_key[1])]
		...
		...     multi_turn_gen_key: [[[851, 17, 2451, 3], [2019, 17, 393, 3]], [[10, 64, 34058, 805, 2601, 3]]]
		...     # multi_turn_gen_key = [[["python", "is", "excellent", "<eos>"], ["PHP", "is", "best", "<eos>"]],
		...     #   [["I", "like", "natural", "language", "processing", "<eos>"]]]
		... }
		>>> metric.forward(data)
		>>> metric.close()
		{'bleu': 0.12081744577265555, 'bleu hashvalue': 'c65b44c454dee5a8d393901644c7f1acfdb847bae3ab03823cb5b9f643958960'}
	'''

	_name = 'MultiTurnBleuCorpusMetric'
	_version = 2

	def __init__(self, dataloader: Union["LanguageProcessing", "Sentence", "Session"], \
			ignore_smoothing_error: bool = False, \
			multi_turn_reference_allvocabs_key: str = "reference_allvocabs", \
			multi_turn_gen_key: str = "multi_turn_gen", \
			turn_len_key: str = "turn_length" \
			):
		super().__init__(self._name, self._version)
		self.dataloader = dataloader
		self.ignore_smoothing_error = ignore_smoothing_error
		self.multi_turn_reference_allvocabs_key = multi_turn_reference_allvocabs_key
		self.turn_len_key = turn_len_key
		self.multi_turn_gen_key = multi_turn_gen_key
		self.refs = []
		self.hyps = []
	def forward(self, data: Dict[str, Any]):
		'''Processing a batch of data.

		Arguments:
			data (dict): A dict at least contains the following keys:

				{MetricBase.FORWARD_MULTI_TURN_REFERENCE_ALLVOCABS_ARGUMENTS}
				{MetricBase.FORWARD_MULTI_TURN_GEN_ARGUMENTS}
				{MetricBase.FORWARD_MULTI_TURN_LENGTH_ARGUMENTS}

				Here is an example for data:

					>>> # all_vocab_list = ["<pad>", "<unk>", "<go>", "<eos>", "I", "have",
					>>> #   "been", "to", "China"]
					>>> data = {
					...     multi_turn_reference_allvocabs_key: [[[2,4,3], [2,5,6,3]], [[2,7,6,8,3]]],
					...     turn_len_key: [2, 1],
					...     gen_key: [[[6,7,8,3], [4,5,3]], [[7,3]]]
					... }
		'''
		super().forward(data)
		reference_allvocabs = data[self.multi_turn_reference_allvocabs_key]
		length = data[self.turn_len_key]
		gen = data[self.multi_turn_gen_key]

		if not isinstance(reference_allvocabs, (np.ndarray, list)):
			raise TypeError("Unknown type for reference_allvocabs.")
		if not isinstance(length, (np.ndarray, list)):
			raise TypeError("Unknown type for length")
		if not isinstance(gen, (np.ndarray, list)):
			raise TypeError("Unknown type for gen")

		if len(length) != len(reference_allvocabs) or len(length) != len(gen):
			raise ValueError("Batch num is not matched.")

		for i, turn_length in enumerate(length):
			gen_session = gen[i]
			ref_session = reference_allvocabs[i]
			for j in range(turn_length):
				self.hyps.append(list(self.dataloader.trim_in_ids(gen_session[j])))
				self.refs.append([list(self.dataloader.trim_in_ids(ref_session[j])[1:])])
	def close(self) -> Dict[str, Any]:
		'''Return a dict which contains

		* **bleu**: bleu value.
		* **bleu hashvalue**: hash value for bleu metric, same hash value stands
		  for same evaluation settings.
		'''
		result = super().close()
		if (not self.hyps) or (not self.refs):
			raise RuntimeError("The metric has not been forwarded data correctly.")

		self.hyps = replace_unk(self.hyps, self.dataloader.unk_id)

		self._hash_unordered_list(self.refs)

		try:
			result.update({"bleu": \
				corpus_bleu(self.refs, self.hyps, smoothing_function=SmoothingFunction().method3), \
				"bleu hashvalue": self._hashvalue()})
		except ZeroDivisionError as _:
			if not self.ignore_smoothing_error:
				raise ZeroDivisionError("Bleu smoothing divided by zero. This is a known bug of corpus_bleu, \
					usually caused when there is only one sample and the sample length is 1.") from None
			result.update({"bleu": 0, \
				"bleu hashvalue": self._hashvalue()})
		return result
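
# forward() above flattens sessions turn by turn: a batch of two sessions with
# turn lengths [2, 1] contributes three (hypothesis, reference) pairs to the
# corpus-level BLEU. A minimal sketch with hypothetical id lists:
#
#     metric = MultiTurnBleuCorpusMetric(dl)
#     metric.forward({
#         "reference_allvocabs": [[[2, 4, 3], [2, 5, 6, 3]], [[2, 7, 6, 8, 3]]],
#         "turn_length": [2, 1],
#         "multi_turn_gen": [[[4, 5, 3], [6, 7, 8, 3]], [[7, 3]]],
#     })
#     result = metric.close()   # {'bleu': ..., 'bleu hashvalue': ...}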