Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
adrinjalali committed Jan 8, 2025
0 parents commit 4994f78
Show file tree
Hide file tree
Showing 8 changed files with 2,628 additions and 0 deletions.
49 changes: 49 additions & 0 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
name: Publish to (Test)PyPI
on:
  workflow_dispatch:
    inputs:
      version:
        description: 'Version to upload to pypi'
        required: true
      pypi_repo:
        description: 'Repo to upload to ("testpypi" or "pypi")'
        default: 'testpypi'
        required: true

jobs:
  publish:

    runs-on: ubuntu-latest
    # Specifying a GitHub environment is optional, but strongly encouraged
    environment: publish-pypi
    permissions:
      # IMPORTANT: this permission is mandatory for trusted publishing
      id-token: write

    steps:
      # Check out the tag/ref given as the "version" input so the built
      # distribution matches the requested release.
      - uses: actions/checkout@v4
        with:
          ref: ${{ github.event.inputs.version }}

      - uses: actions/setup-python@v5
        with:
          python-version: '3.x'

      - name: Install dependencies
        run: |
          python -m pip install -U pip
          python -m pip install -U wheel twine build

      - name: Generate distribution archives
        run: |
          python -m build

      # Exactly one of the two publish steps runs, selected by the
      # "pypi_repo" input. Both use PyPI trusted publishing (OIDC),
      # hence the id-token permission above and no API token secrets.
      - name: Publish package to TestPyPI
        if: ${{ github.event.inputs.pypi_repo == 'testpypi' }}
        uses: pypa/gh-action-pypi-publish@release/v1
        with:
          repository-url: https://test.pypi.org/legacy/

      - name: Publish package to PyPI
        if: ${{ github.event.inputs.pypi_repo == 'pypi' }}
        uses: pypa/gh-action-pypi-publish@release/v1
37 changes: 37 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Python bytecode
__pycache__/
*.py[cod]
*$py.class

# Distribution/packaging
dist/
build/
*.egg-info/

# Virtual environments
venv/
env/
.env/
.venv/

# IDE settings
.idea/
.vscode/
*.swp
*.swo

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
coverage.xml
*.cover

# Jupyter Notebook
.ipynb_checkpoints

# Local development settings
.env
.env.local
8 changes: 8 additions & 0 deletions joblib_modal/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from joblib import register_parallel_backend
from joblib_modal._modal import ModalBackend

# Make the backend available as joblib.parallel_backend("modal") /
# Parallel(backend="modal") as soon as this package is imported.
register_parallel_backend("modal", ModalBackend)

__version__ = "0.1.0"

__all__ = ["ModalBackend", "__version__"]
83 changes: 83 additions & 0 deletions joblib_modal/_modal.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
import uuid
from joblib._parallel_backends import (
FallbackToBackend,
SequentialBackend,
ThreadingBackend,
)
from joblib import register_parallel_backend
from joblib._utils import _TracebackCapturingWrapper
import modal


def executor(func, *args, **kwargs):
    """Call *func* with the supplied positional and keyword arguments.

    This module-level wrapper is what gets wrapped as the remote Modal
    function; the real callable and its arguments travel to it as payload.
    """
    result = func(*args, **kwargs)
    return result


class ModalBackend(ThreadingBackend):
    """Joblib backend that runs each task as a remote Modal function call.

    Locally this behaves like the threading backend: a thread pool issues
    the calls, but each call is `modal_executor.remote(...)`, so the work
    itself runs on Modal, not in the local threads.

    Parameters
    ----------
    name : str, optional
        Name for the Modal app; a unique ``modal-joblib-<uuid>`` name is
        generated when omitted.
    image : modal.Image, optional
        Image to run tasks in; defaults to ``debian_slim`` with joblib
        installed.
    modal_output : bool, default=False
        When True, enable Modal's output/log streaming for the app run.
    """

    uses_threads = True
    supports_sharedmem = False

    def __init__(self, *args, name=None, image=None, modal_output=False, **kwargs):
        super().__init__(*args, **kwargs)
        self.name = name
        self.image = image
        self.modal_output = modal_output

    def configure(self, n_jobs=1, parallel=None, **_):
        """Set up the Modal app and return the number of local workers."""
        n_jobs = self.effective_n_jobs(n_jobs)

        if n_jobs == 1:
            # Avoid unnecessary overhead and use sequential backend instead.
            raise FallbackToBackend(SequentialBackend(nesting_level=self.nesting_level))

        self.parallel = parallel
        self._n_jobs = n_jobs

        if self.image is None:
            image = modal.Image.debian_slim().pip_install("joblib")
        else:
            image = self.image

        name = self.name or f"modal-joblib-{uuid.uuid4()}"
        self.modal_app = modal.App(name, image=image)

        # Register the generic call-through `executor` as a Modal function.
        self.modal_executor = self.modal_app.function()(executor)

        # Enter the (optional) output context first, then the app-run
        # context; `terminate` unwinds them in reverse order.
        if self.modal_output:
            self.output_ctx = modal.enable_output()
            self.output_ctx.__enter__()
        self.run_ctx = self.modal_app.run()
        self.run_ctx.__enter__()

        return n_jobs

    def effective_n_jobs(self, n_jobs):
        """Determine the number of jobs which are going to run in parallel"""
        if n_jobs == 0:
            raise ValueError("n_jobs == 0 in Parallel has no meaning")
        if n_jobs < 0:
            # Negative n_jobs means "as many as possible"; cap the local
            # dispatch threads at 1000 since the real work happens remotely.
            return 1000
        return n_jobs

    def apply_async(self, func, callback=None):
        """Schedule a func to be run"""
        # Here, we need a wrapper to avoid crashes on KeyboardInterruptErrors.
        # We also call the callback on error, to make sure the pool does not
        # wait on crashed jobs.
        return self._get_pool().apply_async(
            _TracebackCapturingWrapper(self.modal_executor.remote),
            (),
            {"func": func},
            callback=callback,
            error_callback=callback,
        )

    def terminate(self):
        """Tear down the Modal app run and release backend resources."""
        # Exit context managers in reverse order of entry: the app-run
        # context was entered last in `configure`, so it must close first,
        # before the output context that encloses it.
        self.run_ctx.__exit__(None, None, None)
        if self.modal_output:
            self.output_ctx.__exit__(None, None, None)
        super().terminate()


# Register the backend on import so `joblib.parallel_backend("modal")`
# works when this module is imported directly.
register_parallel_backend("modal", ModalBackend)
Loading

0 comments on commit 4994f78

Please sign in to comment.