diff --git a/.github/workflows/linter.yml b/.github/workflows/linter.yml index 6e49801..234469d 100644 --- a/.github/workflows/linter.yml +++ b/.github/workflows/linter.yml @@ -6,7 +6,7 @@ on: jobs: lint-python: - name: ruff + name: Python Linting runs-on: ubuntu-latest if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name != github.event.pull_request.base.repo.full_name steps: @@ -15,7 +15,11 @@ jobs: - uses: actions/setup-python@v5 with: python-version: 3.11 - - name: Install Ruff - run: pip install ruff==0.7.1 + - name: Install Dependencies + run: | + pip install ruff==0.7.1 + pip install black==23.9.1 - name: Run Ruff - run: ruff check --output-format=github aide/ \ No newline at end of file + run: ruff check --output-format=github aide/ + - name: Run Black + run: black --check aide/ diff --git a/.gitignore b/.gitignore index 6f1dee3..74062f1 100644 --- a/.gitignore +++ b/.gitignore @@ -164,4 +164,7 @@ workspaces logs .DS_STORE -.trunk \ No newline at end of file +.trunk + +.gradio/ +.ruff_cache/ \ No newline at end of file diff --git a/aide/__init__.py b/aide/__init__.py index e669170..3b1a82b 100644 --- a/aide/__init__.py +++ b/aide/__init__.py @@ -5,13 +5,21 @@ from .journal import Journal from omegaconf import OmegaConf from rich.status import Status -from .utils.config import load_task_desc, prep_agent_workspace, save_run, _load_cfg, prep_cfg +from .utils.config import ( + load_task_desc, + prep_agent_workspace, + save_run, + _load_cfg, + prep_cfg, +) + @dataclass class Solution: code: str valid_metric: float + class Experiment: def __init__(self, data_dir: str, goal: str, eval: str | None = None): @@ -22,7 +30,7 @@ def __init__(self, data_dir: str, goal: str, eval: str | None = None): goal (str): Description of the goal of the task. eval (str | None, optional): Optional description of the preferred way for the agent to evaluate its solutions. """ - + _cfg = _load_cfg(use_cli_args=False) _cfg.data_dir = data_dir _cfg.goal = goal @@ -52,6 +60,3 @@ def run(self, steps: int) -> Solution: best_node = self.journal.get_best_node(only_good=False) return Solution(code=best_node.code, valid_metric=best_node.metric.value) - - - diff --git a/aide/backend/backend_anthropic.py b/aide/backend/backend_anthropic.py index a28cc93..c411300 100644 --- a/aide/backend/backend_anthropic.py +++ b/aide/backend/backend_anthropic.py @@ -21,6 +21,7 @@ def _setup_anthropic_client(): global _client _client = anthropic.Anthropic(max_retries=0) + def query( system_message: str | None, user_message: str | None, diff --git a/aide/backend/backend_openai.py b/aide/backend/backend_openai.py index a69aade..e751e5e 100644 --- a/aide/backend/backend_openai.py +++ b/aide/backend/backend_openai.py @@ -19,11 +19,13 @@ openai.InternalServerError, ) + @once def _setup_openai_client(): global _client _client = openai.OpenAI(max_retries=0) + def query( system_message: str | None, user_message: str | None, diff --git a/aide/backend/utils.py b/aide/backend/utils.py index 914ec90..01499d8 100644 --- a/aide/backend/utils.py +++ b/aide/backend/utils.py @@ -11,7 +11,6 @@ OutputType = str | FunctionCallType - logger = logging.getLogger("aide") @@ -29,6 +28,7 @@ def backoff_create( logger.info(f"Backoff exception: {e}") return False + def opt_messages_to_list( system_message: str | None, user_message: str | None ) -> list[dict[str, str]]: diff --git a/aide/interpreter.py b/aide/interpreter.py index f6a947f..69171f5 100644 --- a/aide/interpreter.py +++ b/aide/interpreter.py @@ -49,7 +49,11 @@ def exception_summary(e, working_dir, exec_file_name, format_tb_ipython): tb_lines = traceback.format_exception(e) # skip parts of stack trace in weflow code tb_str = "".join( - [line for line in tb_lines if "aide/" not in line and "importlib" not in line] + [ + line + for line in tb_lines + if "aide/" not in line and "importlib" not in line + ] ) # tb_str = "".join([l for l in tb_lines]) diff --git a/aide/journal2report.py b/aide/journal2report.py index 14340b1..282cc6a 100644 --- a/aide/journal2report.py +++ b/aide/journal2report.py @@ -2,6 +2,7 @@ from .journal import Journal from .utils.config import StageConfig + def journal2report(journal: Journal, task_desc: dict, rcfg: StageConfig): """ Generate a report from a journal, the report will be in markdown format. diff --git a/aide/run.py b/aide/run.py index 035383e..856ab28 100644 --- a/aide/run.py +++ b/aide/run.py @@ -28,6 +28,7 @@ logger = logging.getLogger("aide") + def journal_to_rich_tree(journal: Journal): best_node = journal.get_best_node() @@ -36,7 +37,7 @@ def append_rec(node: Node, tree): s = "[red]◍ bug" else: style = "bold " if node is best_node else "" - + if node is best_node: s = f"[{style}green]● {node.metric.value:.3f} (best)" else: @@ -51,6 +52,7 @@ def append_rec(node: Node, tree): append_rec(n, tree) return tree + def run(): cfg = load_cfg() logger.info(f'Starting run "{cfg.exp_name}"') @@ -64,6 +66,7 @@ def run(): def cleanup(): if global_step == 0: shutil.rmtree(cfg.workspace_dir) + atexit.register(cleanup) journal = Journal() @@ -101,7 +104,9 @@ def generate_live(): f"Agent workspace directory:\n[yellow]▶ {str(cfg.workspace_dir)}", f"Experiment log directory:\n[yellow]▶ {str(cfg.log_dir)}", ] - left = Group(Panel(Text(task_desc_str.strip()), title="Task description"), prog, status) + left = Group( + Panel(Text(task_desc_str.strip()), title="Task description"), prog, status + ) right = tree wide = Group(*file_paths) @@ -133,10 +138,10 @@ def generate_live(): print("Generating final report from journal...") report = journal2report(journal, task_desc, cfg.report) print(report) - report_file_path = cfg.log_dir / 'report.md' + report_file_path = cfg.log_dir / "report.md" with open(report_file_path, "w") as f: f.write(report) - print('Report written to file:', report_file_path) + print("Report written to file:", report_file_path) if __name__ == "__main__": diff --git a/aide/utils/config.py b/aide/utils/config.py index f10b6d4..06a8b95 100644 --- a/aide/utils/config.py +++ b/aide/utils/config.py @@ -58,6 +58,7 @@ class ExecConfig: agent_file_name: str format_tb_ipython: bool + @dataclass class Config(Hashable): data_dir: Path @@ -91,7 +92,10 @@ def _get_next_logindex(dir: Path) -> int: pass return max_index + 1 -def _load_cfg(path: Path = Path(__file__).parent / "config.yaml", use_cli_args=True) -> Config: + +def _load_cfg( + path: Path = Path(__file__).parent / "config.yaml", use_cli_args=True +) -> Config: cfg = OmegaConf.load(path) if use_cli_args: cfg = OmegaConf.merge(cfg, OmegaConf.from_cli()) @@ -106,9 +110,11 @@ def load_cfg(path: Path = Path(__file__).parent / "config.yaml") -> Config: def prep_cfg(cfg: Config): if cfg.data_dir is None: raise ValueError("`data_dir` must be provided.") - + if cfg.desc_file is None and cfg.goal is None: - raise ValueError("You must provide either a description of the task goal (`goal=...`) or a path to a plaintext file containing the description (`desc_file=...`).") + raise ValueError( + "You must provide either a description of the task goal (`goal=...`) or a path to a plaintext file containing the description (`desc_file=...`)." + ) if cfg.data_dir.startswith("example_tasks/"): cfg.data_dir = Path(__file__).parent.parent / cfg.data_dir @@ -151,13 +157,15 @@ def load_task_desc(cfg: Config): logger.warning( "Ignoring goal and eval args because task description file is provided." ) - + with open(cfg.desc_file) as f: return f.read() - + # or generate it from the goal and eval args if cfg.goal is None: - raise ValueError("`goal` (and optionally `eval`) must be provided if a task description file is not provided.") + raise ValueError( + "`goal` (and optionally `eval`) must be provided if a task description file is not provided." + ) task_desc = {"Task goal": cfg.goal} if cfg.eval is not None: @@ -165,6 +173,7 @@ def load_task_desc(cfg: Config): return task_desc + def prep_agent_workspace(cfg: Config): """Setup the agent's workspace and preprocess data if necessary.""" (cfg.workspace_dir / "input").mkdir(parents=True, exist_ok=True) @@ -177,7 +186,7 @@ def prep_agent_workspace(cfg: Config): def save_run(cfg: Config, journal): cfg.log_dir.mkdir(parents=True, exist_ok=True) - + # save journal serialize.dump_json(journal, cfg.log_dir / "journal.json") # save config diff --git a/aide/utils/tree_export.py b/aide/utils/tree_export.py index e6c8431..918c889 100644 --- a/aide/utils/tree_export.py +++ b/aide/utils/tree_export.py @@ -62,7 +62,7 @@ def cfg_to_tree_struct(cfg, jou: Journal): def generate_html(tree_graph_str: str): template_dir = Path(__file__).parent / "viz_templates" - + with open(template_dir / "template.js") as f: js = f.read() js = js.replace("", tree_graph_str)