Skip to content

Commit

Permalink
chore: proper Connection registration
Browse files Browse the repository at this point in the history
  • Loading branch information
phil65 committed Jan 31, 2025
1 parent 47e3610 commit 68fab96
Show file tree
Hide file tree
Showing 2 changed files with 75 additions and 53 deletions.
50 changes: 31 additions & 19 deletions src/llmling_agent/delegation/team.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,37 +92,49 @@ def __prompt__(self) -> str:

async def run_iter(
self,
*prompts: AnyPromptType | PIL.Image.Image | os.PathLike[str],
*prompts: AnyPromptType,
**kwargs: Any,
) -> AsyncIterator[ChatMessage[Any]]:
"""Yield messages as they arrive from parallel execution."""
# Create queue for collecting results
queue: asyncio.Queue[ChatMessage[Any]] = asyncio.Queue()
errors: dict[str, Exception] = {}

final_prompt = list(prompts)
combined_prompt = "\n".join([await to_prompt(p) for p in final_prompt])
all_nodes = list(await self.pick_agents(combined_prompt))
queue: asyncio.Queue[ChatMessage[Any] | None] = asyncio.Queue()
failures: dict[str, Exception] = {}

async def _run(node: MessageNode[TDeps, Any]) -> None:
try:
message = await node.run(*prompts, **kwargs)
await queue.put(message)
except Exception as e: # noqa: BLE001
errors[node.name] = e
except Exception as e:
logger.exception("Error executing node %s", node.name)
failures[node.name] = e
# Put None to maintain queue count
await queue.put(None)

# Get nodes to run
combined_prompt = "\n".join([await to_prompt(p) for p in prompts])
all_nodes = list(await self.pick_agents(combined_prompt))

# Start all agents
tasks = [asyncio.create_task(_run(n), name=f"run_{n.name}") for n in all_nodes]
for _ in all_nodes:
yield await queue.get()

# Wait for all tasks to complete (for error handling)
await asyncio.gather(*tasks, return_exceptions=True)

if errors:
# Maybe raise an exception with all errors?
first_error = next(iter(errors.values()))
raise first_error
try:
# Yield messages as they arrive
for _ in all_nodes:
if msg := await queue.get():
yield msg

# If any failures occurred, raise error with details
if failures:
error_details = "\n".join(
f"- {name}: {error}" for name, error in failures.items()
)
error_msg = f"Some nodes failed to execute:\n{error_details}"
raise RuntimeError(error_msg)

finally:
# Clean up any remaining tasks
for task in tasks:
if not task.done():
task.cancel()

async def _run(
self,
Expand Down
78 changes: 44 additions & 34 deletions src/llmling_agent/messaging/connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,50 @@ def create_connection(
exit_condition: When to exit application
name: Optional name for cross-referencing connections
"""
if isinstance(target, Sequence):
# Multiple targets -> TeamTalk
talks: list[Talk[Any]] = [
Talk(

def register_talk(talk: Talk[Any]) -> None:
"""Helper to register a single talk connection."""
self._connections.append(talk)
self.connection_added.emit(talk)

if source.context and (pool := source.context.pool):
# Always use Talk's name for registration
if name:
pool.connection_registry.register(name, talk)
else:
pool.connection_registry.register_auto(talk)

match target:
case Sequence():
# Create individual talks
talks: list[Talk[Any]] = []
for t in target:
talk = Talk(
source=source,
targets=[t],
connection_type=connection_type,
priority=priority,
delay=delay,
queued=queued,
queue_strategy=queue_strategy,
transform=transform,
filter_condition=filter_condition,
stop_condition=stop_condition,
exit_condition=exit_condition,
)
register_talk(talk)
talks.append(talk)

# Return TeamTalk as convenience wrapper (but don't register it)
return TeamTalk(talks)

case _:
# Single target case
talk = Talk(
source=source,
targets=[t],
name=name,
targets=[target], # type: ignore
connection_type=connection_type,
name=name,
priority=priority,
delay=delay,
queued=queued,
Expand All @@ -145,34 +181,8 @@ def create_connection(
stop_condition=stop_condition,
exit_condition=exit_condition,
)
for t in target
]
for talk in talks:
self._connections.append(talk)
self.connection_added.emit(talk)
return TeamTalk(talks)
# Single target -> Talk
talk = Talk(
source=source,
targets=[target],
name=name,
connection_type=connection_type,
priority=priority,
delay=delay,
queued=queued,
queue_strategy=queue_strategy,
transform=transform,
filter_condition=filter_condition,
stop_condition=stop_condition,
exit_condition=exit_condition,
)
self._connections.append(talk)
self.connection_added.emit(talk)

if name and source.context and (pool := source.context.pool):
pool.connection_registry.register(name, talk)

return talk
register_talk(talk)
return talk

async def trigger_all(self) -> dict[AgentName, list[ChatMessage[Any]]]:
"""Trigger all queued connections."""
Expand Down

0 comments on commit 68fab96

Please sign in to comment.