Skip to content

Commit

Permalink
fix: handle watching files tracked by user defined glob; current solu…
Browse files Browse the repository at this point in the history
…tion is bruteforce
  • Loading branch information
sgerloff committed Jan 26, 2024
1 parent 64aa7db commit a32ed23
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 14 deletions.
22 changes: 16 additions & 6 deletions rocket/file_watcher.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,37 @@
import glob
import os
import time

from typing import List
from watchdog.events import FileSystemEventHandler
from watchdog.observers import Observer

from rocket.utils import gather_glob_paths


class FileWatcher:
class _Handler(FileSystemEventHandler):
def __init__(self, watcher_instance):
self.watcher_instance = watcher_instance

def on_modified(self, event):
if event.is_directory:
_current_glob_files = gather_glob_paths(self.watcher_instance.glob_paths)
if event.src_path in _current_glob_files:
self.watcher_instance.modified_files.add(event.src_path)
elif event.is_directory:
return
if os.path.splitext(event.src_path)[1] == ".py":
self.watcher_instance.modified_files.append(event.src_path)
elif os.path.splitext(event.src_path)[1] == ".py":
self.watcher_instance.modified_files.add(event.src_path)

def __init__(self, path_to_watch, callback, recursive=True):
def __init__(self, path_to_watch, callback, recursive=True, glob_paths: List[str] = None):
self.path_to_watch = path_to_watch
self.callback = callback
self.recursive = recursive
self.observer = Observer()
self.modified_files = []
self.modified_files = set()
self.glob_paths = glob_paths
if self.glob_paths is None:
self.glob_paths = []
self.handler = self._Handler(self)

def start(self):
Expand All @@ -33,7 +43,7 @@ def start(self):
while True:
time.sleep(1)
if self.modified_files:
self.callback(self.modified_files)
self.callback(list(self.modified_files))
self.modified_files.clear()
except KeyboardInterrupt:
self.observer.stop()
Expand Down
21 changes: 13 additions & 8 deletions rocket/rocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
extract_python_package_dirs,
extract_python_files_from_folder,
execute_for_each_multithreaded,
gather_glob_paths,
)


Expand Down Expand Up @@ -82,7 +83,13 @@ def launch(
project_name = os.path.abspath(project_location).split("/")[-1]
dbfs_path = f"{dbfs_path}/{project_name}"

self._build_and_deploy(watch=watch, project_location=project_location, dbfs_path=dbfs_path, glob_path=glob_path)
glob_paths = []
if isinstance(glob_path, str):
glob_paths = [os.path.join(project_location, glob_path)]
elif isinstance(glob_path, list):
glob_paths = [os.path.join(project_location, path) for path in glob_path]

self._build_and_deploy(watch=watch, project_location=project_location, dbfs_path=dbfs_path, glob_paths=glob_paths)
if watch:
watcher = FileWatcher(
project_location,
Expand All @@ -91,8 +98,9 @@ def launch(
modified_files=watcher.modified_files,
dbfs_path=dbfs_path,
project_location=project_location,
glob_path=glob_path
glob_paths=glob_path
),
glob_paths=glob_paths,
)
watcher.start()

Expand All @@ -102,7 +110,7 @@ def _build_and_deploy(
project_location: str,
dbfs_path: str,
modified_files: Optional[List[str]] = None,
glob_path: Optional[Union[str, List[str]]] = None
glob_paths: Optional[List[str]] = None
) -> None:
if modified_files:
logger.info(f"Found changes in {modified_files}. Overwriting them.")
Expand Down Expand Up @@ -147,11 +155,8 @@ def _build_and_deploy(
for package_dir in package_dirs:
files.update(extract_python_files_from_folder(package_dir))

if isinstance(glob_path, str):
files.update(glob.glob(os.path.join(project_location, glob_path)))
elif isinstance(glob_path, list):
for path in glob_path:
files.update(glob.glob(os.path.join(project_location, path)))
if glob_paths is not None:
files.update(gather_glob_paths(glob_paths))

project_files = ["setup.py", "pyproject.toml"]
for project_file in project_files:
Expand Down
9 changes: 9 additions & 0 deletions rocket/utils.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import concurrent.futures
import glob
import os
import subprocess

from typing import List, Set
from rocket.logger import logger


Expand Down Expand Up @@ -53,3 +55,10 @@ def extract_python_files_from_folder(path):
py_files.append(os.path.join(root, file))

return py_files


def gather_glob_paths(glob_paths: List[str]) -> Set[str]:
_unique_paths = set()
for glob_path in glob_paths:
_unique_paths.update(glob.glob(glob_path))
return _unique_paths

0 comments on commit a32ed23

Please sign in to comment.