diff --git a/dosagelib/loader.py b/dosagelib/loader.py index 2fcc67b96..f04f1615c 100644 --- a/dosagelib/loader.py +++ b/dosagelib/loader.py @@ -5,11 +5,12 @@ Functions to load plugin modules. Example usage: - modules = loader.get_plugin_modules() - plugins = loader.get_plugins(modules, PluginClass) + for module in loader.get_plugin_modules(): + plugins.extend(loader.get_module_plugins(module, PluginClass)) """ import importlib import pkgutil +import sys from .plugins import (__name__ as plugin_package, __path__ as plugin_path) from .output import out @@ -44,16 +45,21 @@ def _get_all_modules_pyinstaller(): return toc -def get_plugins(modules, classobj): - """Find all class objects in all modules. - @param modules: the modules to search - @ptype modules: iterator of modules - @return: found classes - @rytpe: iterator of class objects +def get_plugin_modules_from_dir(path, prefix='user_'): + """Load and import a directory of python files as if they were part of the + "plugins" package. (Mostly "stolen" from + https://docs.python.org/3/library/importlib.html#importing-a-source-file-directly) """ - for module in modules: - for plugin in get_module_plugins(module, classobj): - yield plugin + modules = [] + for f in path.glob('*.py'): + name = plugin_package + "." + prefix + f.stem + # FIXME: Drop str() when this is Python 3.6+ + spec = importlib.util.spec_from_file_location(name, str(f)) + module = importlib.util.module_from_spec(spec) + sys.modules[name] = module + spec.loader.exec_module(module) + modules.append(module) + return modules def get_module_plugins(module, classobj): diff --git a/dosagelib/scraper.py b/dosagelib/scraper.py index b881c9f23..79610046a 100644 --- a/dosagelib/scraper.py +++ b/dosagelib/scraper.py @@ -542,7 +542,7 @@ class Cache: slow. """ def __init__(self): - self.data = None + self.data = [] def find(self, comic, multiple_allowed=False): """Get a list comic scraper objects. @@ -574,12 +574,41 @@ class Cache: def load(self): out.debug("Loading comic modules...") - modules = loader.get_plugin_modules() - plugins = list(loader.get_plugins(modules, Scraper)) - self.data = list([m for x in plugins for m in x.getmodules()]) + modules = 0 + classes = 0 + for module in loader.get_plugin_modules(): + modules += 1 + classes += self.addmodule(module) self.validate() - out.debug("... %d modules loaded from %d classes." % ( - len(self.data), len(plugins))) + out.debug("... %d scrapers loaded from %d classes in %d modules." % ( + len(self.data), classes, modules)) + + def adddir(self, path): + """Add an additional directory with python modules to the scraper list. + These are handled as if the were part of the plugins package. + """ + if not self.data: + self.load() + modules = 0 + classes = 0 + out.debug("Loading user scrapers from '{}'...".format(path)) + for module in loader.get_plugin_modules_from_dir(path): + modules += 1 + classes += self.addmodule(module) + self.validate() + if classes > 0: + out.debug("Added %d user classes from %d modules." % ( + classes, modules)) + + def addmodule(self, module): + """Adds all valid plugin classes from the specified module to the cache. + @return: number of classes added + """ + classes = 0 + for plugin in loader.get_module_plugins(module, Scraper): + classes += 1 + self.data.extend(plugin.getmodules()) + return classes def get(self, include_removed=False): """Find all comic scraper classes in the plugins directory. diff --git a/tests/mocks/plugins/dummy.py b/tests/mocks/plugins/dummy.py new file mode 100644 index 000000000..591344ae9 --- /dev/null +++ b/tests/mocks/plugins/dummy.py @@ -0,0 +1,7 @@ +# SPDX-License-Identifier: MIT +# Copyright (C) 2020 Tobias Gruetzmacher +from ..scraper import _ParserScraper + + +class ADummyTestScraper(_ParserScraper): + url = 'https://dummy.example/' diff --git a/tests/test_scraper.py b/tests/test_scraper.py index 9a479b35b..597d5ceaa 100644 --- a/tests/test_scraper.py +++ b/tests/test_scraper.py @@ -1,7 +1,10 @@ # SPDX-License-Identifier: MIT # Copyright (C) 2013-2014 Bastian Kleineidam # Copyright (C) 2015-2020 Tobias Gruetzmacher +from pathlib import Path + import pytest + from dosagelib.scraper import scrapers @@ -24,3 +27,9 @@ class TestScraper(object): def test_find_scrapers_error(self): with pytest.raises(ValueError, match='empty comic name'): scrapers.find('') + + def test_user_dir(self): + oldlen = len(scrapers.get()) + scrapers.adddir(Path(__file__).parent / 'mocks' / 'plugins') + assert len(scrapers.get()) == oldlen + 1 + assert len(scrapers.find('ADummyTestScraper')) == 1