Source code for cotk.file_utils.resource_processor

'''Processors for resources, applied after download and before reading.
'''
import os
import zipfile
import shutil
import json
import hashlib
from itertools import chain
from .._utils.metaclass import LoadClassInterface

def unzip_file(src_path, dst_dir):
	'''Unzip the zip file at src_path into dst_dir.
	'''
	if zipfile.is_zipfile(src_path):
		with zipfile.ZipFile(src_path, 'r') as zip_obj:
			zip_obj.extractall(dst_dir)
	else:
		raise ValueError('{} is not a zip file'.format(src_path))
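
# A minimal sketch (hypothetical helper and paths, not part of the original
# module) showing ``unzip_file``: a valid zip archive is extracted into
# ``dst_dir``; anything else raises ``ValueError``.
def _demo_unzip_file():
	import tempfile
	with tempfile.TemporaryDirectory() as tmp:
		src = os.path.join(tmp, 'demo.zip')
		with zipfile.ZipFile(src, 'w') as zip_obj:
			zip_obj.writestr('data.txt', 'hello')
		unzip_file(src, src + '_unzip')
		assert os.path.isfile(os.path.join(src + '_unzip', 'data.txt'))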

class ResourceProcessor(LoadClassInterface):
	'''Base class for processor.
	'''
	def __init__(self, cache_dir, config_dir):
		self.cache_dir = cache_dir
		self.config_dir = config_dir

class DefaultResourceProcessor(ResourceProcessor):
	'''Processor for default resource: do nothing.
	'''
	def preprocess(self, local_path):
		'''Preprocess after download and before save.
		'''
		return local_path

	def postprocess(self, local_path):
		'''Postprocess before read.
		'''
		return local_path

class ZipResourceProcessor(ResourceProcessor):
	'''Processor for zip resource: extract the zip archive.
	'''
	def preprocess(self, local_path):
		'''Preprocess after download and before save.
		'''
		if os.path.isdir(local_path):
			return local_path
		dst_dir = local_path + '_unzip'
		unzip_file(local_path, dst_dir)
		return dst_dir

	def postprocess(self, local_path):
		'''Postprocess before read.
		'''
		return local_path

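# A minimal sketch (hypothetical helper and paths, not part of the original
# module) of the ``ZipResourceProcessor`` flow: ``preprocess`` extracts
# ``res.zip`` into ``res.zip_unzip`` and returns that directory, while
# ``postprocess`` is a no-op.
def _demo_zip_processor():
	import tempfile
	with tempfile.TemporaryDirectory() as tmp:
		src = os.path.join(tmp, 'res.zip')
		with zipfile.ZipFile(src, 'w') as zip_obj:
			zip_obj.writestr('train.txt', 'a line\n')
		processor = ZipResourceProcessor(tmp, tmp)
		dst = processor.preprocess(src)
		assert dst == src + '_unzip'
		assert processor.postprocess(dst) == dst
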
class BaseResourceProcessor(ResourceProcessor):
	'''Basic processor for MSCOCO, OpenSubtitles, Ubuntu...'''
	def basepreprocess(self, local_path, name):
		'''Preprocess after download and before save.
		'''
		if os.path.isdir(local_path):
			return local_path
		dst_dir = local_path + '_unzip'
		unzip_file(local_path, dst_dir)
		return os.path.join(dst_dir, name)

	def postprocess(self, local_path):
		'''Postprocess before read.
		'''
		return local_path

	def get_temp_dir(self, filepath: str):
		'''Get a temp directory, in which some temporary files may be saved.
		The temp directory is a subdirectory of ``self.cache_dir`` and is named
		after the hash value of the argument ``filepath``, so the same
		``filepath`` always has the same corresponding temp directory.
		'''
		abs_path = os.path.abspath(filepath)
		hash_value = hashlib.sha256(abs_path.encode('utf-8')).hexdigest()
		return os.path.join(self.cache_dir, hash_value + '_temp')

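# A minimal sketch (hypothetical, posix-style paths) of ``get_temp_dir``
# above: the same absolute filepath always maps to the same
# ``<sha256>_temp`` subdirectory of ``cache_dir``. Note the directory is
# only named here, not created; callers invoke ``os.makedirs`` themselves.
def _demo_get_temp_dir():
	proc = BaseResourceProcessor('/tmp/cotk_cache', None)
	first = proc.get_temp_dir('/data/mscoco')
	second = proc.get_temp_dir('/data/mscoco')
	assert first == second and first.startswith('/tmp/cotk_cache')
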
# TODO: merge the following processors to remove duplicated code
class MSCOCOResourceProcessor(BaseResourceProcessor):
	'''Processor for MSCOCO dataset.
	'''
	def preprocess(self, local_path):
		'''Preprocess after download and before save.
		'''
		return self.basepreprocess(local_path, 'mscoco')

	def postprocess(self, local_path):
		'''Postprocess before read: copy ``mscoco_{key}.txt`` to ``{key}.txt``
		in a temp directory.
		'''
		local_path = super().postprocess(local_path)
		new_local_path = self.get_temp_dir(local_path)
		os.makedirs(new_local_path, exist_ok=True)
		for key in ['train', 'dev', 'test']:
			local_file = os.path.join(local_path, 'mscoco_%s.txt' % key)
			new_local_file = os.path.join(new_local_path, '%s.txt' % key)
			if os.path.isfile(local_file):
				shutil.copy(local_file, new_local_file)
		return new_local_path

class OpenSubtitlesResourceProcessor(BaseResourceProcessor):
	'''Processor for OpenSubtitles dataset.
	'''
	def preprocess(self, local_path):
		'''Preprocess after download and before save.
		'''
		return self.basepreprocess(local_path, 'opensubtitles')

	def postprocess(self, local_path):
		'''Postprocess before read: write each post and its response on
		consecutive lines of ``{key}.txt`` in a temp directory.
		'''
		local_path = super().postprocess(local_path)
		new_local_path = self.get_temp_dir(local_path)
		os.makedirs(new_local_path, exist_ok=True)
		for key in ['train', 'test', 'dev']:
			post_path = os.path.join(local_path, 'opensub_pair_%s.post' % key)
			response_path = os.path.join(local_path, 'opensub_pair_%s.response' % key)
			if not os.path.isfile(post_path) or not os.path.isfile(response_path):
				continue
			with open(post_path, 'r', encoding='utf-8') as posts, \
					open(response_path, 'r', encoding='utf-8') as responses, \
					open(os.path.join(new_local_path, '%s.txt' % key), 'w', encoding='utf-8') as out:
				for post, resp in zip(posts, responses):
					out.write(post if post[-1] == '\n' else (post + '\n'))
					out.write(resp if resp[-1] == '\n' else (resp + '\n'))
		return new_local_path

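# A minimal sketch (hypothetical data) of the layout written above: posts and
# responses are interleaved line by line, so ``{key}.txt`` reads post_1,
# response_1, post_2, response_2, ...
def _demo_opensubtitles_layout():
	posts = ['how are you ?\n', 'see you later\n']
	responses = ['fine , thanks\n', 'bye\n']
	lines = []
	for post, resp in zip(posts, responses):
		lines.append(post)
		lines.append(resp)
	assert lines == ['how are you ?\n', 'fine , thanks\n', 'see you later\n', 'bye\n']
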
class UbuntuResourceProcessor(BaseResourceProcessor):
	'''Processor for UbuntuCorpus dataset.
	'''
	def preprocess(self, local_path):
		'''Preprocess after download and before save.
		'''
		return self.basepreprocess(local_path, 'ubuntu_dataset')

	def postprocess(self, local_path):
		'''Postprocess before read: convert each csv row into a session with
		one turn per line, separating sessions by a blank line.
		'''
		import csv
		local_path = super().postprocess(local_path)
		new_local_path = self.get_temp_dir(local_path)
		os.makedirs(new_local_path, exist_ok=True)
		for key in ['train', 'dev', 'test']:
			local_file = os.path.join(local_path, 'ubuntu_corpus_%s.csv' % key)
			if not os.path.isfile(local_file):
				continue
			new_local_file = os.path.join(new_local_path, '%s.txt' % key)
			with open(local_file, 'r', encoding='utf-8') as f:
				reader = csv.reader(f)
				head = next(reader)
				if head[2] == 'Label':
					# keep only positive (Label == 1.0) context-utterance pairs
					raw_data = [d[0] + d[1] for d in reader if d[2] == '1.0']
				else:
					raw_data = [d[0] + d[1] for d in reader]
			with open(new_local_file, 'w', encoding='utf-8') as f:
				for session in raw_data:
					for sent in session.strip().replace('__eou__', '').split('__eot__'):
						f.write(sent)
						f.write('\n')
					f.write('\n')
		return new_local_path

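# A minimal sketch (hypothetical row) of the row-to-session conversion above:
# the context and utterance columns are concatenated, ``__eou__`` markers are
# removed (leaving double spaces behind), and ``__eot__`` splits the session
# into turns.
def _demo_ubuntu_row():
	row = ['hi __eou__ __eot__ hello __eou__ ', 'any news ? __eou__']
	session = row[0] + row[1]
	sents = session.strip().replace('__eou__', '').split('__eot__')
	assert [s.strip() for s in sents] == ['hi', 'hello  any news ?']
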
class SwitchboardCorpusResourceProcessor(BaseResourceProcessor):
	'''Processor for SwitchboardCorpus dataset.
	'''
	def preprocess(self, local_path):
		'''Preprocess after download and before save.
		'''
		return self.basepreprocess(local_path, 'switchboard_corpus')

	def postprocess(self, local_path):
		'''Postprocess before read: write each session (and, for
		``multi_ref``, its candidate responses) as a block of lines.
		'''
		local_path = super().postprocess(local_path)
		new_local_path = self.get_temp_dir(local_path)
		os.makedirs(new_local_path, exist_ok=True)
		for key in ['train', 'test', 'dev', 'multi_ref']:
			filepath = os.path.join(local_path, 'switchboard_corpus_%s.jsonl' % key)
			new_filepath = os.path.join(new_local_path, '%s.txt' % key)
			res = self._read_file(filepath, key == 'multi_ref')
			with open(new_filepath, 'w', encoding='utf-8') as fout:
				if key != 'multi_ref':
					dataset = res  # sessions
				else:
					sessions, responses = res
					# [session1, response1, session2, response2, ...]
					# A response is like a session: both contain several sentences.
					dataset = chain.from_iterable(zip(sessions, responses))
				for sess in dataset:
					assert sess
					for line in sess:
						if line[-1] != '\n':
							line += '\n'
						fout.write(line)
					fout.write('\n')
		return new_local_path

	def _read_file(self, filepath, read_multi_ref=False):
		'''
		Arguments:
			filepath (str): Name of the file to read from.
			read_multi_ref (bool):
				If False, add turn ``<d>`` ahead of each session.
				If True, add turn ``<d>`` at the end of each session and
				read candidate ``responses``.
		'''
		sessions = []
		if read_multi_ref:
			responses = []
		with open(filepath, "r", encoding='utf-8') as data_file:
			for line in data_file:
				line = json.loads(line)
				prefix_utts = [['X', '<d>']] + line['utts']
				# pylint: disable=cell-var-from-loop
				suffix_utts = list(map(lambda utt: utt[1][1].strip() + ' ' \
						if prefix_utts[utt[0]][0] == utt[1][0] \
						else '<eos> ' + utt[1][1].strip() + ' ', enumerate(line['utts'])))
				utts = ('<d> ' + "".join(suffix_utts).strip()).split("<eos>")
				sess = utts[1:] + ['<d>'] if read_multi_ref else utts
				sessions.append(sess)
				if read_multi_ref:
					responses.append([resp for _, resp in line['responses']])
		if read_multi_ref:
			return sessions, responses
		return sessions

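# A minimal sketch (hypothetical helper and data, not part of the original
# module) of the turn-merging logic in ``_read_file`` above: consecutive
# utterances by the same speaker are merged, ``<eos>`` separates speaker
# turns, and ``<d>`` marks the start of the dialog. The list comprehension
# mirrors the lambda used above.
def _demo_switchboard_merge():
	utts = [['A', 'hi'], ['A', 'how are you'], ['B', 'fine']]
	prefix_utts = [['X', '<d>']] + utts
	suffix_utts = [utt[1].strip() + ' ' if prefix_utts[i][0] == utt[0]
			else '<eos> ' + utt[1].strip() + ' ' for i, utt in enumerate(utts)]
	merged = ('<d> ' + ''.join(suffix_utts).strip()).split('<eos>')
	assert merged == ['<d> ', ' hi how are you ', ' fine']
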
class SSTResourceProcessor(BaseResourceProcessor):
	'''Processor for SST dataset.
	'''
	def preprocess(self, local_path):
		'''Preprocess after download and before save.
		'''
		return self.basepreprocess(local_path, 'trees')

	def _parseline(self, line):
		# The character at index 1 of a PTB-style tree "(3 (2 It) ...)" is
		# the root sentiment label; the leaf words precede each ')'.
		label = int(line[1])
		line = line.split(')')
		sent = [x.split(' ')[-1].lower() for x in line if x != '']
		return label, ' '.join(sent)

	def _postprocess(self, src, dest, key):
		with open(os.path.join(src, key + '.txt'), 'r', encoding='utf-8') as fin, \
				open(os.path.join(dest, key + '.txt'), 'w', encoding='utf-8') as fout:
			for label, sent in map(self._parseline, fin):
				# ``sent`` keeps the trailing '\n' of the input line, so the
				# sentence and its label end up on consecutive lines.
				fout.write(sent)
				fout.write(str(label) + '\n')

	def postprocess(self, local_path):
		'''Postprocess before read: flatten each parse tree into a sentence
		line followed by its sentiment label.
		'''
		local_path = super().postprocess(local_path)
		new_local_path = self.get_temp_dir(local_path)
		os.makedirs(new_local_path, exist_ok=True)
		for key in ['train', 'test', 'dev']:
			if not os.path.isfile(os.path.join(local_path, key + '.txt')):
				raise FileNotFoundError("there is no %s in %s" % (key + '.txt', local_path))
			self._postprocess(local_path, new_local_path, key)
		return new_local_path

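# A minimal sketch (hypothetical helper and data) of ``_parseline`` above:
# the character at index 1 is the root sentiment label of the PTB-style
# tree, and the leaf words are recovered from the tokens preceding each ')'.
def _demo_sst_parseline():
	proc = SSTResourceProcessor(None, None)
	label, sent = proc._parseline('(3 (2 It) (3 (2 works) (4 great)))')
	assert (label, sent) == (3, 'it works great')
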
class GloveResourceProcessor(ResourceProcessor):
	'''Base class for all dimension versions of the GloVe word vectors.
	'''
	def __init__(self, cache_dir=None, config_dir=None):
		super().__init__(cache_dir, config_dir)
		self.other_gloves = []

	def basepreprocess(self, local_path, name):
		'''Preprocess after download and before save.
		'''
		dst_dir = local_path + '_unzip'
		unzip_file(local_path, dst_dir)
		filenames = os.listdir(dst_dir)
		for filename in filenames:
			if os.path.isdir(os.path.join(dst_dir, filename)):
				continue
			dim = filename.split('.')[-2]
			if dim != name and self.cache_dir is not None and self.config_dir is not None:
				# remember other dimensions found in the archive; they are
				# imported as local resources in ``basepostprocess``
				self.other_gloves.append(["resources://Glove%s" % dim, local_path])
				continue
			sub_dir = os.path.join(dst_dir, dim)
			os.makedirs(sub_dir, exist_ok=True)
			with open(os.path.join(dst_dir, filename)) as f, \
					open(os.path.join(sub_dir, 'wordvec.txt'), 'w') as g:
				for line in f:
					word, wordvec = line.strip().split(" ", 1)
					g.write(word + "\n")
					g.write(wordvec + "\n")
		return dst_dir

	def basepostprocess(self, local_path, name):
		'''Postprocess before read.
		'''
		from .file_utils import import_local_resources
		for glove in self.other_gloves:
			import_local_resources(glove[0], glove[1], self.cache_dir, \
				self.config_dir, ignore_exist_error=True)
		self.other_gloves = []
		return os.path.join(local_path, name)

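# A minimal sketch (hypothetical data) of the line conversion in
# ``basepreprocess`` above: a GloVe line "word v1 v2 ..." is split once on
# the first space, and the word and its vector are written on separate
# lines of ``wordvec.txt``.
def _demo_glove_line():
	line = 'the 0.418 0.24968 -0.41242'
	word, wordvec = line.strip().split(' ', 1)
	assert word == 'the' and wordvec == '0.418 0.24968 -0.41242'
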
class Glove50dResourceProcessor(GloveResourceProcessor):
	'''Processor for GloVe 50d word vectors.
	'''
	def preprocess(self, local_path):
		'''Preprocess after download and before save.
		'''
		return self.basepreprocess(local_path, '50d')

	def postprocess(self, local_path):
		'''Postprocess before read.
		'''
		return self.basepostprocess(local_path, '50d')

class Glove100dResourceProcessor(GloveResourceProcessor):
	'''Processor for GloVe 100d word vectors.
	'''
	def preprocess(self, local_path):
		'''Preprocess after download and before save.
		'''
		return self.basepreprocess(local_path, '100d')

	def postprocess(self, local_path):
		'''Postprocess before read.
		'''
		return self.basepostprocess(local_path, '100d')

class Glove200dResourceProcessor(GloveResourceProcessor):
	'''Processor for GloVe 200d word vectors.
	'''
	def preprocess(self, local_path):
		'''Preprocess after download and before save.
		'''
		return self.basepreprocess(local_path, '200d')

	def postprocess(self, local_path):
		'''Postprocess before read.
		'''
		return self.basepostprocess(local_path, '200d')

class Glove300dResourceProcessor(GloveResourceProcessor):
	'''Processor for GloVe 300d word vectors.
	'''
	def preprocess(self, local_path):
		'''Preprocess after download and before save.
		'''
		return self.basepreprocess(local_path, '300d')

	def postprocess(self, local_path):
		'''Postprocess before read.
		'''
		return self.basepostprocess(local_path, '300d')