'''Processors for resources, applied after download and before read.
'''
import os
import zipfile
import shutil
import json
import hashlib
from itertools import chain

from .._utils.metaclass import LoadClassInterface


def unzip_file(src_path, dst_dir):
    '''Unzip the zip file at ``src_path`` into ``dst_dir``.
    '''
    if zipfile.is_zipfile(src_path):
        with zipfile.ZipFile(src_path, 'r') as zip_obj:
            zip_obj.extractall(dst_dir)
    else:
        raise ValueError('{} is not a zip file'.format(src_path))


class ResourceProcessor(LoadClassInterface):
    '''Base class for resource processors.
    '''
    def __init__(self, cache_dir, config_dir):
        self.cache_dir = cache_dir
        self.config_dir = config_dir


class DefaultResourceProcessor(ResourceProcessor):
    '''Processor for default resource: do nothing.
    '''
    def preprocess(self, local_path):
        '''Preprocess after download and before save.
        '''
        return local_path

    def postprocess(self, local_path):
        '''Postprocess before read.
        '''
        return local_path
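
# A minimal usage sketch (hedged: the paths and the calling convention are
# illustrative; downloading itself is handled elsewhere in the library):
#
#     processor = DefaultResourceProcessor('./cache', './config')
#     saved_path = processor.preprocess('./cache/raw_download')  # after download
#     read_path = processor.postprocess(saved_path)              # before read
#
# For the default processor, both calls return the path unchanged.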


class ZipResourceProcessor(ResourceProcessor):
    '''Processor for zip resources: extract the archive.
    '''
    def preprocess(self, local_path):
        '''Preprocess after download and before save.
        '''
        if os.path.isdir(local_path):
            return local_path
        dst_dir = local_path + '_unzip'
        unzip_file(local_path, dst_dir)
        return dst_dir

    def postprocess(self, local_path):
        '''Postprocess before read.
        '''
        return local_path
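
# Sketch of the resulting layout (file names hypothetical):
#
#     ./cache/foo.zip           # downloaded archive
#     ./cache/foo.zip_unzip/    # extracted directory returned by preprocess
#
# If ``local_path`` is already a directory, extraction is assumed to have
# happened and the path is returned unchanged.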


class BaseResourceProcessor(ResourceProcessor):
    """Basic processor for MSCOCO, OpenSubtitles, Ubuntu..."""

    def basepreprocess(self, local_path, name):
        '''Preprocess after download and before save.
        '''
        if os.path.isdir(local_path):
            return local_path
        dst_dir = local_path + '_unzip'
        unzip_file(local_path, dst_dir)
        return os.path.join(dst_dir, name)

    def postprocess(self, local_path):
        '''Postprocess before read.
        '''
        return local_path

    def get_temp_dir(self, filepath: str):
        """Get a temp directory, in which temporary files may be saved.
        The temp directory is a subdirectory of ``self.cache_dir`` named after the
        hash value of ``filepath``, so the same ``filepath`` always maps to the
        same temp directory.
        """
        abs_path = os.path.abspath(filepath)
        hash_value = hashlib.sha256(abs_path.encode('utf-8')).hexdigest()
        return os.path.join(self.cache_dir, hash_value + '_temp')
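
# A hedged sketch of the temp-dir mapping (paths hypothetical, hash abbreviated):
#
#     p = BaseResourceProcessor('./cache', './config')
#     p.get_temp_dir('/data/mscoco')  # -> './cache/<sha256 of abs path>_temp'
#     p.get_temp_dir('/data/mscoco')  # same input, same directory
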
# TODO: merge the following processors to remove duplicated code


class MSCOCOResourceProcessor(BaseResourceProcessor):
    '''Processor for the MSCOCO dataset.
    '''
    def preprocess(self, local_path):
        '''Preprocess after download and before save.
        '''
        return self.basepreprocess(local_path, 'mscoco')

    def postprocess(self, local_path):
        '''Postprocess before read: copy each split to a temp directory.
        '''
        local_path = super().postprocess(local_path)
        new_local_path = self.get_temp_dir(local_path)
        os.makedirs(new_local_path, exist_ok=True)
        for key in ['train', 'dev', 'test']:
            local_file = os.path.join(local_path, 'mscoco_%s.txt' % key)
            new_local_file = os.path.join(new_local_path, '%s.txt' % key)
            if os.path.isfile(local_file):
                shutil.copy(local_file, new_local_file)
        return new_local_path


class OpenSubtitlesResourceProcessor(BaseResourceProcessor):
    '''Processor for the OpenSubtitles dataset.
    '''
    def preprocess(self, local_path):
        '''Preprocess after download and before save.
        '''
        return self.basepreprocess(local_path, 'opensubtitles')

    def postprocess(self, local_path):
        '''Postprocess before read: merge each post/response file pair into one
        file, with a post and its response on alternating lines.
        '''
        local_path = super().postprocess(local_path)
        new_local_path = self.get_temp_dir(local_path)
        os.makedirs(new_local_path, exist_ok=True)
        for key in ['train', 'test', 'dev']:
            post_path = os.path.join(local_path, 'opensub_pair_%s.post' % key)
            response_path = os.path.join(local_path, 'opensub_pair_%s.response' % key)
            if not os.path.isfile(post_path) or not os.path.isfile(response_path):
                continue
            with open(post_path, 'r', encoding='utf-8') as posts:
                with open(response_path, 'r', encoding='utf-8') as responses:
                    with open(os.path.join(new_local_path, '%s.txt' % key), 'w', encoding='utf-8') as out:
                        for post, resp in zip(posts, responses):
                            out.write(post if post[-1] == '\n' else (post + '\n'))
                            out.write(resp if resp[-1] == '\n' else (resp + '\n'))
        return new_local_path
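
# Illustrative output format (contents hypothetical): posts and responses are
# interleaved line by line, e.g. train.txt becomes
#
#     how are you ?    <- line 1 of opensub_pair_train.post
#     fine , thanks .  <- line 1 of opensub_pair_train.response
#     ...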


class UbuntuResourceProcessor(BaseResourceProcessor):
    '''Processor for the UbuntuCorpus dataset.
    '''
    def preprocess(self, local_path):
        '''Preprocess after download and before save.
        '''
        return self.basepreprocess(local_path, 'ubuntu_dataset')

    def postprocess(self, local_path):
        '''Postprocess before read: convert each csv file into a plain-text file
        with one sentence per line and a blank line between sessions.
        '''
        import csv
        local_path = super().postprocess(local_path)
        new_local_path = self.get_temp_dir(local_path)
        os.makedirs(new_local_path, exist_ok=True)
        for key in ['train', 'dev', 'test']:
            local_file = os.path.join(local_path, 'ubuntu_corpus_%s.csv' % key)
            if not os.path.isfile(local_file):
                continue
            new_local_file = os.path.join(new_local_path, '%s.txt' % key)
            with open(local_file, 'r', encoding='utf-8') as f:
                reader = csv.reader(f)
                head = next(reader)
                if head[2] == 'Label':
                    # labelled split: keep only positive (context, response) pairs
                    raw_data = [d[0] + d[1] for d in reader if d[2] == '1.0']
                else:
                    raw_data = [d[0] + d[1] for d in reader]
            with open(new_local_file, 'w', encoding='utf-8') as f:
                for session in raw_data:
                    for sent in session.strip().replace('__eou__', '').split('__eot__'):
                        f.write(sent)
                        f.write('\n')
                    f.write('\n')
        return new_local_path
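
# Illustrative output (contents hypothetical): each retained csv row becomes a
# block in <key>.txt with one sentence per line ('__eou__' markers stripped,
# '__eot__' splitting turns), and blocks are separated by blank lines.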


class SwitchboardCorpusResourceProcessor(BaseResourceProcessor):
    '''Processor for the SwitchboardCorpus dataset.
    '''
    def preprocess(self, local_path):
        '''Preprocess after download and before save.
        '''
        return self.basepreprocess(local_path, 'switchboard_corpus')

    def postprocess(self, local_path):
        '''Postprocess before read: dump each session (and, for multi_ref, its
        candidate responses) as a block of lines separated by blank lines.
        '''
        local_path = super().postprocess(local_path)
        new_local_path = self.get_temp_dir(local_path)
        os.makedirs(new_local_path, exist_ok=True)
        for key in ['train', 'test', 'dev', 'multi_ref']:
            filepath = os.path.join(local_path, 'switchboard_corpus_%s.jsonl' % key)
            new_filepath = os.path.join(new_local_path, '%s.txt' % key)
            res = self._read_file(filepath, key == 'multi_ref')
            with open(new_filepath, 'w', encoding='utf-8') as fout:
                if key != 'multi_ref':
                    dataset = res  # sessions
                else:
                    sessions, responses = res
                    # [session1, response1, session2, response2, ...]
                    dataset = chain.from_iterable(zip(sessions, responses))
                # A response is shaped like a session: both are lists of sentences.
                for sess in dataset:
                    assert sess
                    for line in sess:
                        if line[-1] != '\n':
                            line += '\n'
                        fout.write(line)
                    fout.write('\n')
        return new_local_path

    def _read_file(self, filepath, read_multi_ref=False):
        """
        Arguments:
            filepath (str): name of the file to read from.
            read_multi_ref (bool):
                If False, add the turn ``<d>`` ahead of each session.
                If True, add the turn ``<d>`` at the end of each session and also
                read the candidate ``responses``.
        """
        sessions = []
        if read_multi_ref:
            responses = []
        with open(filepath, "r", encoding='utf-8') as data_file:
            for line in data_file:
                line = json.loads(line)
                prefix_utts = [['X', '<d>']] + line['utts']
                # pylint: disable=cell-var-from-loop
                suffix_utts = list(map(lambda utt: utt[1][1].strip() + ' ' \
                    if prefix_utts[utt[0]][0] == utt[1][0] \
                    else '<eos> ' + utt[1][1].strip() + ' ', enumerate(line['utts'])))
                utts = ('<d> ' + "".join(suffix_utts).strip()).split("<eos>")
                sess = utts[1:] + ['<d>'] if read_multi_ref else utts
                sessions.append(sess)
                if read_multi_ref:
                    responses.append([resp for _, resp in line['responses']])
        if read_multi_ref:
            return sessions, responses
        else:
            return sessions
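
    # Worked example (hedged; the utterance data is hypothetical): for a jsonl
    # line with utts [['A', 'hi'], ['A', 'how are you'], ['B', 'fine']],
    # consecutive turns of one speaker are merged and '<eos>' marks speaker
    # changes, so read_multi_ref=False yields the session
    #     ['<d> ', ' hi how are you ', ' fine']
    # while read_multi_ref=True yields
    #     [' hi how are you ', ' fine', '<d>']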


class SSTResourceProcessor(BaseResourceProcessor):
    '''Processor for the SST dataset.
    '''
    def preprocess(self, local_path):
        '''Preprocess after download and before save.
        '''
        return self.basepreprocess(local_path, 'trees')

    def _parseline(self, line):
        # The root label is the digit right after the first '('; splitting on
        # ')' leaves each token as the last space-separated field of a segment.
        label = int(line[1])
        line = line.split(')')
        sent = [x.split(' ')[-1].lower() for x in line if x != '']
        return label, ' '.join(sent)
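
    # Worked example (hedged; the input line is hypothetical): the tree line
    #     '(3 (2 It) (3 (2 is) (3 good)))\n'
    # yields label 3 and sentence 'it is good \n'. The trailing '\n' survives
    # from the input line, which puts the label on its own line in
    # _postprocess below.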

    def _postprocess(self, src, dest, key):
        with open(os.path.join(src, key + '.txt'), 'r', encoding='utf-8') as fin, \
                open(os.path.join(dest, key + '.txt'), 'w', encoding='utf-8') as fout:
            for label, sent in map(self._parseline, fin):
                # ``sent`` keeps the newline from the input line, so the label
                # is written on the following line.
                fout.write(sent)
                fout.write(str(label) + '\n')

    def postprocess(self, local_path):
        '''Postprocess before read.
        '''
        local_path = super().postprocess(local_path)
        new_local_path = self.get_temp_dir(local_path)
        os.makedirs(new_local_path, exist_ok=True)
        for key in ['train', 'test', 'dev']:
            if not os.path.isfile(os.path.join(local_path, key + '.txt')):
                raise FileNotFoundError("there is no %s in %s" % (key + '.txt', local_path))
            self._postprocess(local_path, new_local_path, key)
        return new_local_path


class GloveResourceProcessor(ResourceProcessor):
    '''Base class for all dimension versions of the GloVe word vectors.
    '''
    def __init__(self, cache_dir=None, config_dir=None):
        super().__init__(cache_dir, config_dir)
        self.other_gloves = []

    def basepreprocess(self, local_path, name):
        '''Preprocess after download and before save.
        '''
        dst_dir = local_path + '_unzip'
        unzip_file(local_path, dst_dir)
        filenames = os.listdir(dst_dir)
        for filename in filenames:
            if os.path.isdir(os.path.join(dst_dir, filename)):
                continue
            # e.g. 'glove.6B.50d.txt' -> dim == '50d'
            dim = filename.split('.')[-2]
            if dim != name and self.cache_dir is not None and self.config_dir is not None:
                # Remember the other dimensions so they can be registered as
                # local resources in basepostprocess.
                self.other_gloves.append(["resources://Glove%s" % dim, local_path])
                continue
            sub_dir = os.path.join(dst_dir, dim)
            os.makedirs(sub_dir, exist_ok=True)
            with open(os.path.join(dst_dir, filename)) as f:
                with open(os.path.join(sub_dir, 'wordvec.txt'), 'w') as g:
                    for line in f:
                        word, wordvec = line.strip().split(" ", 1)
                        g.write(word + "\n")
                        g.write(wordvec + "\n")
        return dst_dir
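
    # Illustrative layout after preprocessing glove.6B.zip (names hypothetical):
    #
    #     <local_path>_unzip/glove.6B.50d.txt   # original file, left in place
    #     <local_path>_unzip/50d/wordvec.txt    # words and vectors on alternating lines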

    def basepostprocess(self, local_path, name):
        '''Postprocess before read.
        '''
        from .file_utils import import_local_resources
        for glove in self.other_gloves:
            import_local_resources(glove[0], glove[1], self.cache_dir, self.config_dir, ignore_exist_error=True)
        self.other_gloves = []
        return os.path.join(local_path, name)


class Glove50dResourceProcessor(GloveResourceProcessor):
    '''Processor for the GloVe 50d word vectors.
    '''
    def preprocess(self, local_path):
        '''Preprocess after download and before save.
        '''
        return self.basepreprocess(local_path, '50d')

    def postprocess(self, local_path):
        '''Postprocess before read.
        '''
        return self.basepostprocess(local_path, '50d')


class Glove100dResourceProcessor(GloveResourceProcessor):
    '''Processor for the GloVe 100d word vectors.
    '''
    def preprocess(self, local_path):
        '''Preprocess after download and before save.
        '''
        return self.basepreprocess(local_path, '100d')

    def postprocess(self, local_path):
        '''Postprocess before read.
        '''
        return self.basepostprocess(local_path, '100d')


class Glove200dResourceProcessor(GloveResourceProcessor):
    '''Processor for the GloVe 200d word vectors.
    '''
    def preprocess(self, local_path):
        '''Preprocess after download and before save.
        '''
        return self.basepreprocess(local_path, '200d')

    def postprocess(self, local_path):
        '''Postprocess before read.
        '''
        return self.basepostprocess(local_path, '200d')


class Glove300dResourceProcessor(GloveResourceProcessor):
    '''Processor for the GloVe 300d word vectors.
    '''
    def preprocess(self, local_path):
        '''Preprocess after download and before save.
        '''
        return self.basepreprocess(local_path, '300d')

    def postprocess(self, local_path):
        '''Postprocess before read.
        '''
        return self.basepostprocess(local_path, '300d')
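
# End-to-end usage sketch (hedged: paths are hypothetical and downloading is
# handled elsewhere in the library):
#
#     p = Glove50dResourceProcessor('./cache', './config')
#     unzip_dir = p.preprocess('./cache/glove.6B.zip')  # extract, split by dim
#     wordvec_dir = p.postprocess(unzip_dir)            # -> <unzip_dir>/50d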