Source code for cotk._utils.imports

r"""
``cotk._utils.imports`` provides classes that fake a uninstalled module.
"""
import importlib
import sys

[docs]class LazyModule(object): r'''Lazy loading modules. Arguments: module_name (str): The path of import. For example: ``cotk``. global_dict (str): Override the global dictionary when the module is loaded. Examples: >>> torch = LazyModule("torch", globals()) >>> print(torch) <cotk._utils.imports.LazyModule object at 0x000001BE147682E8> >>> torch.Tensor = LazyObject("torch.Tensor") >>> print(torch.Tensor) <cotk._utils.imports.LazyObject object at 0x000001BE1339CE80> >>> print(torch.LongTensor) <class 'torch.LongTensor'> >>> print(torch.Tensor) <class 'torch.Tensor'> >>> print(torch) <module 'torch'> ''' def __init__(self, module_name, global_dict): super().__init__() self._module_name = module_name self._global_dict = global_dict def _try_load(self): if super().__getattribute__("_module_name") in sys.modules: return sys.modules[super().__getattribute__("_module_name")] else: return None def _load(self): module = importlib.import_module(super().__getattribute__("_module_name")) global_dict = super().__getattribute__("_global_dict") global_dict[super().__getattribute__("_module_name")] = module return module def __getattribute__(self, key): loaded = super().__getattribute__("_try_load")() if loaded is not None: return getattr(loaded, key) try: return super().__getattribute__(key) except AttributeError: pass if key == "__bases__": return tuple() else: return getattr(super().__getattribute__("_load")(), key)
# def __call__(self, *args, **kwargs): # return self.__load()(*args, **kwargs)
[docs]class LazyObject(object): r'''Lazy loading objects. Arguments: object_name (str): The path of import. For example: ``cotk.dataloader.MSCOCO``. Examples: >>> dataloader = LazyObject("cotk.dataloader") >>> print(dataloader) <cotk._utils.imports.LazyObject object at 0x000001BE1339CE48> >>> print(dataloader.MSCOCO) <class 'cotk.dataloader.language_generation.MSCOCO'> ''' def __init__(self, object_name): super().__init__() self._object_name = object_name self._module_name = object_name.split('.')[0] def _try_load_module(self): if super().__getattribute__("_module_name") in sys.modules: return sys.modules[super().__getattribute__("_module_name")] else: return None def _load_object(self): mod = importlib.import_module(super().__getattribute__("_module_name")) arr = super().__getattribute__("_object_name").split('.') obj = getattr(mod, arr[1]) for i in range(2, len(arr)): try: obj = getattr(obj, arr[i]) except AttributeError: raise AttributeError("No attribute %s in %s." % (arr[i], ".".join(arr[:i]))) return obj def _try_getattribute(self, key): loaded = super().__getattribute__("_try_load_module")() if loaded is not None: return getattr(super().__getattribute__("_load_object")(), key) try: return super().__getattribute__(key) except AttributeError: pass return None def __getattribute__(self, key): loaded = super().__getattribute__("_try_load_module")() if loaded is not None: return getattr(super().__getattribute__("_load_object")(), key) try: return super().__getattribute__(key) except AttributeError: pass if key == "__bases__": return tuple() else: return getattr(super().__getattribute__("_load_object")(), key) def __call__(self, *args, **kwargs): return self._load_object()(*args, **kwargs) @staticmethod def peek(obj, key): if isinstance(obj, LazyObject): loaded = obj._try_load_module() if loaded is not None: return getattr(obj, key) else: return obj._try_getattribute(key) else: return None
[docs]class DummyObject(dict): r'''Dummy Object for uninstalled modules. [deprecated] Examples: >>> try: ... import torch ... except ImportError as err: ... torch = DummyObject(err) ... >>> torch.Tensor ModuleNotFoundError: No module named 'torch' >>> torch.Tensor = DummyObject(torch.err) >>> torch.Tensor >>> torch.Tensor() ModuleNotFoundError: No module named 'torch' ''' def __init__(self, err): self.err = err super().__init__() def __getattr__(self, key): if key in self: return self[key] elif key == "__bases__": return tuple() else: raise self["err"] def __setattr__(self, key, value): self[key] = value def __delattr__(self, key): del self[key] def __call__(self, *args, **kwargs): raise self['err']