Skip to content

Commit

Permalink
fix(tools): review fixes and internal improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
nfrasser committed Jan 7, 2025
1 parent 6127303 commit 168f3a3
Show file tree
Hide file tree
Showing 8 changed files with 269 additions and 220 deletions.
6 changes: 3 additions & 3 deletions cryosparc/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,7 +225,7 @@ def _call(self, _method: str, _path: str, _schema, *args, **kwargs):
with ctx as res:
return self._handle_response(_schema, res)
except httpx.HTTPStatusError as err:
raise APIError("received error response", res=err.response)
raise APIError("received error response", res=err.response) from err


class APIClient(APINamespace):
Expand Down Expand Up @@ -404,8 +404,8 @@ def _decode_json_response(value: Any, schema: dict):

# Recursively decode list or tuple
if "type" in schema and schema["type"] == "array":
typ, items_key = (tuple, "prefixItems") if "prefixItems" in schema else (list, "items")
return typ(_decode_json_response(item, schema[items_key]) for item in value)
collection_type, items_key = (tuple, "prefixItems") if "prefixItems" in schema else (list, "items")
return collection_type(_decode_json_response(item, schema[items_key]) for item in value)

# Recursively decode object
if "type" in schema and schema["type"] == "object":
Expand Down
48 changes: 24 additions & 24 deletions cryosparc/api.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -501,30 +501,6 @@ class JobsNamespace(APINamespace):
...
def set_param(self, project_uid: str, job_uid: str, param: str, /, *, value: Any) -> Job: ...
def clear_param(self, project_uid: str, job_uid: str, param: str, /) -> Job: ...
def connect(
self, project_uid: str, job_uid: str, input_name: str, /, *, source_job_uid: str, source_output_name: str
) -> Job: ...
def disconnect_all(self, project_uid: str, job_uid: str, input_name: str, /) -> Job: ...
def disconnect(self, project_uid: str, job_uid: str, input_name: str, connection_index: int, /) -> Job: ...
def find_output_result(
self, project_uid: str, job_uid: str, output_name: str, result_name: str, /
) -> OutputResult: ...
def connect_result(
self,
project_uid: str,
job_uid: str,
input_name: str,
connection_index: int,
result_name: str,
/,
*,
source_job_uid: str,
source_output_name: str,
source_result_name: str,
) -> Job: ...
def disconnect_result(
self, project_uid: str, job_uid: str, input_name: str, connection_index: int, result_name: str, /
) -> Job: ...
def load_input(
self,
project_uid: str,
Expand Down Expand Up @@ -568,6 +544,30 @@ class JobsNamespace(APINamespace):
Save job output dataset. Job must be running or waiting.
"""
...
def connect(
self, project_uid: str, job_uid: str, input_name: str, /, *, source_job_uid: str, source_output_name: str
) -> Job: ...
def disconnect_all(self, project_uid: str, job_uid: str, input_name: str, /) -> Job: ...
def disconnect(self, project_uid: str, job_uid: str, input_name: str, connection_index: int, /) -> Job: ...
def find_output_result(
self, project_uid: str, job_uid: str, output_name: str, result_name: str, /
) -> OutputResult: ...
def connect_result(
self,
project_uid: str,
job_uid: str,
input_name: str,
connection_index: int,
result_name: str,
/,
*,
source_job_uid: str,
source_output_name: str,
source_result_name: str,
) -> Job: ...
def disconnect_result(
self, project_uid: str, job_uid: str, input_name: str, connection_index: int, result_name: str, /
) -> Job: ...
def enqueue(
self,
project_uid: str,
Expand Down
43 changes: 20 additions & 23 deletions cryosparc/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,12 @@
# CryoSPARC should not depend on anything in this file.
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union
from typing import Any, Dict, Generic, Optional, TypeVar, Union

from pydantic import BaseModel

from .models.job_spec import InputSlot, OutputSlot
from .spec import Datafield
from .spec import SlotSpec

# API model
M = TypeVar("M", bound=BaseModel)
Expand All @@ -37,9 +37,17 @@ def model(self) -> M:
assert self._model, "Could not refresh database document"
return self._model

@model.setter
def model(self, model: M):
self._model = model

@model.deleter
def model(self):
self._model = None

@property
def doc(self) -> Dict[str, Any]:
warnings.warn(".doc attribute is deprecated. Use .model attribute instead.", DeprecationWarning)
warnings.warn(".doc attribute is deprecated. Use .model attribute instead.", DeprecationWarning, stacklevel=2)
return self.model.model_dump(by_alias=True)

@abstractmethod
Expand All @@ -48,34 +56,23 @@ def refresh(self):
return self


InputSlotSpec = Union[str, InputSlot, Datafield]
"""
A result slot specification for the slots=... argument when creating inputs.
"""

OutputSlotSpec = Union[str, OutputSlot, Datafield]
"""
A result slot specification for the slots=... argument when creating outputs.
"""

LoadableSlots = Union[Literal["default", "passthrough", "all"], List[str]]
"""Slots groups load for a job input or output."""


def as_input_slot(spec: InputSlotSpec) -> InputSlot:
def as_input_slot(spec: Union[SlotSpec, InputSlot]) -> InputSlot:
if isinstance(spec, str):
spec, required = (spec[1:], False) if spec[0] == "?" else (spec, True)
return InputSlot(name=spec, dtype=spec, required=required)
elif isinstance(spec, dict) and "dtype" in spec and "prefix" in spec:
name, dtype, required = spec["prefix"], spec["dtype"].split(".").pop(), spec.get("required", True)
elif isinstance(spec, dict) and "dtype" in spec:
dtype = spec["dtype"]
name = spec.get("name") or spec.get("prefix") or dtype
required = spec.get("required", True)
return InputSlot(name=name, dtype=dtype, required=required)
return spec


def as_output_slot(spec: OutputSlotSpec) -> OutputSlot:
def as_output_slot(spec: Union[SlotSpec, OutputSlot]) -> OutputSlot:
if isinstance(spec, str):
return OutputSlot(name=spec, dtype=spec)
elif isinstance(spec, dict) and "dtype" in spec and "prefix" in spec:
name, dtype = spec["prefix"], spec["dtype"].split(".").pop()
elif isinstance(spec, dict) and "dtype" in spec:
dtype = spec["dtype"]
name = spec.get("name") or spec.get("prefix") or dtype
return OutputSlot(name=name, dtype=dtype)
return spec
Loading

0 comments on commit 168f3a3

Please sign in to comment.