diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index dfe31d8..0ff6199 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -1,6 +1,6 @@ name: Test -on: [push] +on: [push, pull_request] jobs: test: diff --git a/README.md b/README.md index b30e896..6aad20b 100644 --- a/README.md +++ b/README.md @@ -157,6 +157,22 @@ cppimport.settings['force_rebuild'] = True And if this is a common occurence, I would love to hear your use case and why the combination of the checksum, `cfg['dependencies']` and `cfg['sources']` is insufficient! +Note that `force_rebuild` does not work when importing the module concurrently. + +### Can I import my module concurrently? + +It's safe to use `cppimport` to import a module concurrently using multiple threads, processes or even machines! + +Before building a module, `cppimport` obtains a lockfile preventing other processes from building it at the same time - this prevents clashes that can lead to failure. +Other processes will wait up to 10 minutes for the first process to finish building the module and will then load it. If your module does not build within 10 minutes then it will time out. +You can increase the timeout in the settings: + +```python +cppimport.settings['lock_timeout'] = 10*60 # 10 mins +``` + +You should not use `force_rebuild` when importing concurrently. + ### How can I get information about filepaths in the configuration block? The module name is available as the `fullname` variable and the C++ module file is available as `filepath`.
For example, diff --git a/cppimport/__init__.py b/cppimport/__init__.py index aef1141..0690538 100644 --- a/cppimport/__init__.py +++ b/cppimport/__init__.py @@ -8,9 +8,11 @@ from cppimport.find import _check_first_line_contains_cppimport settings = dict( - force_rebuild=False, + force_rebuild=False, # `force_rebuild` with multiple processes is not supported file_exts=[".cpp", ".c"], rtld_flags=ctypes.RTLD_LOCAL, + lock_suffix=".lock", + lock_timeout=10 * 60, remove_strict_prototypes=True, release_mode=os.getenv("CPPIMPORT_RELEASE_MODE", "0").lower() in ("true", "yes", "1"), @@ -57,19 +59,26 @@ def imp_from_filepath(filepath, fullname=None): module : the compiled and loaded Python extension module """ from cppimport.importer import ( + build_safely, is_build_needed, load_module, setup_module_data, - template_and_build, try_load, ) + filepath = os.path.abspath(filepath) if fullname is None: fullname = os.path.splitext(os.path.basename(filepath))[0] module_data = setup_module_data(fullname, filepath) + # The call to try_load is necessary here because there are times when the + # only evidence a rebuild is needed comes from attempting to load an + # existing extension module. For example, if the extension was built on a + # different architecture or with different Python headers and will produce + # an error when loaded, then the load will fail. In that situation, we will + # need to rebuild. if is_build_needed(module_data) or not try_load(module_data): - template_and_build(filepath, module_data) - load_module(module_data) + build_safely(filepath, module_data) + load_module(module_data) return module_data["module"] @@ -108,17 +117,19 @@ def build_filepath(filepath, fullname=None): ext_path : the path to the compiled extension. 
""" from cppimport.importer import ( + build_safely, is_build_needed, + load_module, setup_module_data, - template_and_build, ) + filepath = os.path.abspath(filepath) if fullname is None: fullname = os.path.splitext(os.path.basename(filepath))[0] module_data = setup_module_data(fullname, filepath) if is_build_needed(module_data): - template_and_build(filepath, module_data) - + build_safely(filepath, module_data) + load_module(module_data) # Return the path to the built module return module_data["ext_path"] diff --git a/cppimport/checksum.py b/cppimport/checksum.py index 6e717ac..57b8a07 100644 --- a/cppimport/checksum.py +++ b/cppimport/checksum.py @@ -45,6 +45,9 @@ def _load_checksum_trailer(module_data): except FileNotFoundError: logger.info("Failed to find compiled extension; rebuilding.") return None, None + except OSError: + logger.info("Checksum trailer invalid. Rebuilding.") + return None, None try: deps, old_checksum = json.loads(json_s) @@ -79,7 +82,7 @@ def _save_checksum_trailer(module_data, dep_filepaths, cur_checksum): # legal (see e.g. https://stackoverflow.com/questions/10106447). 
dump = json.dumps([dep_filepaths, cur_checksum]).encode("ascii") dump += _FMT.pack(len(dump), _TAG) - with open(module_data["ext_path"], "ab") as file: + with open(module_data["ext_path"], "ab", buffering=0) as file: file.write(dump) diff --git a/cppimport/importer.py b/cppimport/importer.py index cf04f1a..bd5dcbd 100644 --- a/cppimport/importer.py +++ b/cppimport/importer.py @@ -3,6 +3,10 @@ import os import sys import sysconfig +from contextlib import suppress +from time import sleep, time + +import filelock import cppimport from cppimport.build_module import build_module @@ -12,6 +16,46 @@ logger = logging.getLogger(__name__) +def build_safely(filepath, module_data): + """Protect against race conditions when multiple processes execute + `template_and_build` concurrently.""" + binary_path = module_data["ext_path"] + lock_path = binary_path + cppimport.settings["lock_suffix"] + + def build_completed(): + return os.path.exists(binary_path) and is_checksum_valid(module_data) + + t = time() + + # Race to obtain the lock and build. Other processes can wait + while not build_completed() and time() - t < cppimport.settings["lock_timeout"]: + try: + with filelock.FileLock(lock_path, timeout=1): + if build_completed(): + break + template_and_build(filepath, module_data) + except filelock.Timeout: + logger.debug(f"Could not obtain lock (pid {os.getpid()})") + if cppimport.settings["force_rebuild"]: + raise ValueError( + "force_rebuild must be False to build concurrently. " + "This process failed to claim a filelock indicating that" + " a concurrent build is in progress" + ) + sleep(1) + + if os.path.exists(lock_path): + with suppress(OSError): + os.remove(lock_path) + + if not build_completed(): + raise Exception( + f"Could not compile binary as lock already taken and timed out." + f" Try increasing the timeout setting if " + f"the build time is longer (pid {os.getpid()})."
+ ) + + def template_and_build(filepath, module_data): logger.debug(f"Compiling {filepath}.") run_templating(module_data) @@ -79,6 +123,8 @@ def is_build_needed(module_data): def try_load(module_data): + """Try loading the module to test if it's not corrupt and for the correct + architecture""" try: load_module(module_data) return True @@ -86,4 +132,6 @@ def try_load(module_data): logger.info( f"ImportError during import with matching checksum: {e}. Trying to rebuild." ) + with suppress(OSError): + os.remove(module_data["ext_path"]) return False diff --git a/environment.yml b/environment.yml index ca71dcb..0a90830 100644 --- a/environment.yml +++ b/environment.yml @@ -10,3 +10,4 @@ dependencies: - pytest - pytest-cov - pre-commit + - filelock diff --git a/setup.py b/setup.py index c19d9c4..8ac7601 100644 --- a/setup.py +++ b/setup.py @@ -6,7 +6,7 @@ use_scm_version={"version_scheme": "post-release"}, setup_requires=["setuptools_scm"], packages=["cppimport"], - install_requires=["mako", "pybind11"], + install_requires=["mako", "pybind11", "filelock"], zip_safe=False, name="cppimport", description="Import C++ files directly from Python!", diff --git a/tests/test_cppimport.py b/tests/test_cppimport.py index 8dd992a..a81c8fd 100644 --- a/tests/test_cppimport.py +++ b/tests/test_cppimport.py @@ -2,8 +2,11 @@ import copy import logging import os +import shutil import subprocess import sys +from multiprocessing import Process +from tempfile import TemporaryDirectory import cppimport import cppimport.build_module @@ -40,11 +43,28 @@ def subprocess_check(test_code, returncode=0): stdout=subprocess.PIPE, stderr=subprocess.PIPE, ) - print(p.stdout.decode("utf-8")) - print(p.stderr.decode("utf-8")) + if len(p.stdout) > 0: + print(p.stdout.decode("utf-8")) + if len(p.stderr) > 0: + print(p.stderr.decode("utf-8")) assert p.returncode == returncode +@contextlib.contextmanager +def tmp_dir(files=None): + """Create a temporary directory and copy `files` into it.
`files` can also + include directories.""" + files = files if files else [] + + with TemporaryDirectory() as tmp_path: + for f in files: + if os.path.isdir(f): + shutil.copytree(f, os.path.join(tmp_path, os.path.basename(f))) + else: + shutil.copyfile(f, os.path.join(tmp_path, os.path.basename(f))) + yield tmp_path + + def test_find_module_cpppath(): mymodule_loc = find_module_cpppath("mymodule") mymodule_dir = os.path.dirname(mymodule_loc) @@ -170,3 +190,24 @@ def test_import_hook(): cppimport.force_rebuild(False) hook_test + + +def test_multiple_processes(): + with tmp_dir(["tests/hook_test.cpp"]) as tmp_path: + test_code = f""" +import os; +os.chdir({tmp_path!r}); +import cppimport.import_hook; +import hook_test; + """ + processes = [ + Process(target=subprocess_check, args=(test_code,)) for _ in range(100) + ] + + for p in processes: + p.start() + + for p in processes: + p.join() + + assert all(p.exitcode == 0 for p in processes)