Source code for cotk.file_utils.file_utils

"""
Utilities for working with the local dataset cache.
This file is adapted from the AllenNLP library at https://github.com/allenai/allennlp
Copyright by the AllenNLP authors.
"""

import os
import json
import tempfile
import shutil
import hashlib
import logging
import sys
from pathlib import Path
from urllib.parse import urlparse
import requests

from tqdm import tqdm
from checksumdir import dirhash

from .resource_processor import ResourceProcessor

LOGGER = logging.getLogger(__name__)
LOGGER.setLevel(level=logging.INFO)
FORMAT = logging.Formatter("%(levelname)s: %(message)s")
SH = logging.StreamHandler(stream=sys.stdout)
SH.setFormatter(FORMAT)
LOGGER.addHandler(SH)
CACHE_DIR = os.path.join(str(Path.home()), '.cotk_cache')
CONFIG_DIR = os.path.join(os.path.dirname(__file__), '../resource_config')

def _url_to_filename(url):
	r'''Convert the url to sha256 as filename
	'''
	url_bytes = url.encode('utf-8')
	url_hash = hashlib.sha256(url_bytes)
	filename = url_hash.hexdigest()

	return filename


def _get_config(res_name, config_dir=None):
	'''Get config(dict) by the name of resource'''
	config_dir = config_dir or CONFIG_DIR
	config_path = os.path.join(config_dir, res_name + '.json')
	if not os.path.exists(config_path):
		raise FileNotFoundError("file {} not found".format(config_path))
	with open(config_path, 'r', encoding='utf-8') as config_file:
		config = json.load(config_file)

	return config


def _http_get(url, temp_file):
	'''Pull a file directly from http'''
	req = requests.get(url, stream=True)
	content_length = req.headers.get('Content-Length')
	total = int(content_length) if content_length is not None else None
	progress = tqdm(unit="B", total=total)
	for chunk in req.iter_content(chunk_size=65536):
		if chunk:
			progress.update(len(chunk))
			temp_file.write(chunk)
	progress.close()


def _get_file_sha256(file_path):
	'''Get sha256 of given file'''
	hash_sha256 = hashlib.sha256()
	with open(file_path, "rb") as fin:
		for chunk in iter(lambda: fin.read(4096), b""):
			hash_sha256.update(chunk)
	return hash_sha256.hexdigest()


def _get_hashtag(file_path):
	'''Get sha256 of given directory or file'''
	if os.path.isdir(file_path):
		return dirhash(file_path, 'sha256')
	else:
		return _get_file_sha256(file_path)

def _parse_file_id(file_id):
	'''
	file_id contains one essential part and two optional parts

	file_id: name[@source][#processor]

	examples:

	.. code-block:: text

		file_id=https://XXX/            name=https://XXX/   source=None      processor=Default
		file_id=MSCOCO@tsinghua#Glove   name=MSCOCO         source=tsinghua  processor=GloveProcessor
	'''
	# There may be # in name, so we process file_id reversely.
	# TODO: what if there is # in name?
	source, processor = None, None
	name = file_id[::-1]
	if '#' in file_id:
		processor, name = name.split('#', 1)
		processor = processor[::-1]
	if '@' in file_id:
		source, name = name.split('@', 1)
		source = source[::-1]
	name = name[::-1]
	return name, source, processor

def _get_resource(file_id, cache_dir=None, config_dir=None):
	'''Get the resource with the given name.
	If not cached, download it using the URL stored in config file.
	If cached, check the hashtag.
	'''
	cache_dir = cache_dir or CACHE_DIR
	config_dir = config_dir or CONFIG_DIR
	os.makedirs(cache_dir, exist_ok=True)

	res_name, src_name, res_type = _parse_file_id(file_id)
	config = _get_config(res_name, config_dir)

	src_name = src_name or 'default'
	res_type = res_type or config.get('type', 'Default')

	resource_processor_class = ResourceProcessor.load_class(res_type + 'ResourceProcessor')
	if resource_processor_class is None:
		raise RuntimeError("No resources type named %sResourcePreprocessor" % res_type)
	resource_processor = resource_processor_class(cache_dir, config_dir)

	if src_name not in config['link']:
		raise ValueError("source {} wrong".format(src_name))
	url = config['link'][src_name]
	cache_path = os.path.join(cache_dir, _url_to_filename(res_name))
	meta_path = os.path.join(cache_dir, _url_to_filename(res_name) + '.json')

	LOGGER.info('downloading resources')
	LOGGER.info('name: %s', res_name)
	LOGGER.info('source: %s', src_name)
	LOGGER.info('url: %s', url)
	LOGGER.info('processor: %s', res_type)

	if not os.path.exists(meta_path):
		with tempfile.NamedTemporaryFile()  as temp_file:
			_http_get(url, temp_file)
			temp_file.flush() # flush to avoid truncation
			temp_file.seek(0) # shutil.copyfileobj() starts at the current position
			with open(cache_path, 'wb') as cache_file:
				shutil.copyfileobj(temp_file, cache_file)

		# filename hash for search, content hash for validation
		content_hash = _get_file_sha256(cache_path)

		cache_path = resource_processor.preprocess(cache_path)

		if content_hash == config['hashtag']:
			meta = {'hashtag': content_hash, 'local_path': cache_path}
			with open(meta_path, 'w', encoding='utf-8') as meta_file:
				json.dump(meta, meta_file)
		else:
			print("bad hashtag {}, correct is {}".format(content_hash, config['hashtag']))
			raise ValueError("bad hashtag of {}".format(res_name))

	else:
		with open(meta_path, 'r', encoding='utf-8') as meta_file:
			meta = json.load(meta_file)
			cache_path = meta['local_path']
			content_hash = meta['hashtag']

		LOGGER.info('{} exists in cache'.format(res_name))
		if content_hash != config['hashtag']:
			raise ValueError("bad hashtag of {}, name conflication or mismatched content. \
							meta path {}. cache path {}".format(res_name, meta_path, cache_path))

	cache_path = resource_processor.postprocess(cache_path)
	LOGGER.info('resource cached at %s', cache_path)
	return cache_path

def _download_data(url, cache_dir=None, config_dir=None):
	r'''If not cached, download the resource using url.
	'''
	cache_dir = cache_dir or CACHE_DIR
	config_dir = config_dir or CONFIG_DIR
	os.makedirs(cache_dir, exist_ok=True)

	url, _, res_type = _parse_file_id(url)
	res_type = res_type or 'Default'
	LOGGER.info('url: %s', url)
	LOGGER.info('processor type: %s', res_type)

	resource_processor = ResourceProcessor.load_class(res_type + 'ResourceProcessor')(cache_dir=cache_dir,config_dir=config_dir)
	cache_path = os.path.join(cache_dir, _url_to_filename(url))
	meta_path = os.path.join(cache_dir, _url_to_filename(url) + '.json')

	if not os.path.exists(meta_path):
		with tempfile.NamedTemporaryFile() as temp_file:
			_http_get(url, temp_file)
			temp_file.flush() # flush to avoid truncation
			temp_file.seek(0) # shutil.copyfileobj() starts at the current position

			with open(cache_path, 'wb') as cache_file:
				shutil.copyfileobj(temp_file, cache_file)
			# filename hash for search, content hash for validation
			content_hash = _get_file_sha256(cache_path)

			cache_path = resource_processor.preprocess(cache_path)

			meta = {'local_path': cache_path, 'hashtag': content_hash}
			with open(meta_path, 'w', encoding='utf-8') as meta_file:
				json.dump(meta, meta_file)
	else:
		with open(meta_path, 'r', encoding='utf-8') as meta_file:
			meta = json.load(meta_file)
			cache_path = meta['local_path']
	cache_path = resource_processor.postprocess(cache_path)
	LOGGER.info('resource cached at %s', cache_path)
	return cache_path

def _load_local_data(local_path, cache_dir=None, config_dir=None):
	'''Import temporary resources from local'''
	cache_dir = cache_dir or CACHE_DIR
	config_dir = config_dir or CONFIG_DIR
	local_path, _, res_type = _parse_file_id(local_path)
	res_type = res_type or 'Default'
	LOGGER.info('local path: %s', local_path)
	LOGGER.info('processor type: %s', res_type)

	resource_processor_class = ResourceProcessor.load_class(res_type + 'ResourceProcessor')
	if resource_processor_class is None:
		raise RuntimeError("No resources type named %sResourcePreprocessor" % res_type)
	resource_processor = resource_processor_class(cache_dir=cache_dir, config_dir=config_dir)

	if local_path.endswith(".zip"):
		local_path = resource_processor.preprocess(local_path)
	return resource_processor.postprocess(local_path)


[docs]def get_resource_file_path(file_id, cache_dir=None, config_dir=None): '''Get file_path of resource of all types ''' cache_dir = cache_dir or CACHE_DIR config_dir = config_dir or CONFIG_DIR if file_id.startswith('resources://'): res_id = file_id[12:] return _get_resource(res_id, cache_dir, config_dir) elif file_id.startswith('http://') or file_id.startswith('https://'): url = file_id return _download_data(url, cache_dir, config_dir) else: local_path = file_id return _load_local_data(local_path, cache_dir, config_dir)
[docs]def import_local_resources(file_id, local_path, cache_dir=None, \ config_dir=None, ignore_exist_error=False): '''Import benchmark from local, if hashtag checked, save to cache.''' cache_dir = cache_dir or CACHE_DIR config_dir = config_dir or CONFIG_DIR os.makedirs(cache_dir, exist_ok=True) if not file_id.startswith('resources://'): raise ValueError("file_id must startswith \'resources://\'") res_name, _, _ = _parse_file_id(file_id[12:]) config = _get_config(res_name, config_dir) meta_path = os.path.join(cache_dir, _url_to_filename(res_name)) + '.json' if os.path.exists(meta_path): if ignore_exist_error: return raise ValueError("resources existed. If you want to delete the existing resources. \ Use `rm %s`." % meta_path) local_hashtag = _get_hashtag(local_path) if local_hashtag == config['hashtag']: cache_path = os.path.join(cache_dir, _url_to_filename(res_name)) with open(cache_path, 'wb') as cache_file: shutil.copyfileobj(open(local_path, 'rb'), cache_file) res_type = config.get('type', 'Default') resource_processor_class = ResourceProcessor.load_class(res_type + 'ResourceProcessor') if resource_processor_class is None: raise RuntimeError("No resources type named %sResourcePreprocessor" % res_type) resource_processor = resource_processor_class(cache_dir, config_dir) cache_path = resource_processor.preprocess(cache_path) meta = {'local_path': cache_path, 'hashtag': local_hashtag} with open(meta_path, 'w', encoding='utf-8') as meta_file: json.dump(meta, meta_file) cache_path = resource_processor.postprocess(cache_path) LOGGER.info('resource cached at %s', cache_path) return cache_path else: raise ValueError("bad hashtag of {}".format(res_name))
[docs]def load_file_from_url(url, force=False, cache_dir=None): '''See cotk.downloader.load_file_from_url. ''' cache_dir = cache_dir or CACHE_DIR parts = urlparse(url) filename = os.path.basename(parts.path) cache_dir = os.path.join(cache_dir, 'files') if os.path.exists(cache_dir) and force: shutil.rmtree(cache_dir) # raise ValueError("model existed. If you want to delete the existing model. \ # Use `rm %s`." % cache_path) os.makedirs(cache_dir, exist_ok=True) cache_path = os.path.join(cache_dir, filename) if os.path.exists(cache_path): return cache_path with tempfile.NamedTemporaryFile() as temp_file: _http_get(url, temp_file) temp_file.flush() # flush to avoid truncation temp_file.seek(0) # shutil.copyfileobj() starts at the current position with open(cache_path, 'wb') as cache_file: shutil.copyfileobj(temp_file, cache_file) LOGGER.info('model cached at %s', cache_path) return cache_path