Skip to content

Commit

Permalink
Merge branch 'main' into bugfix/fix-broken-training
Browse files Browse the repository at this point in the history
  • Loading branch information
bhancockio authored Jan 29, 2025
2 parents 4ff2252 + a3ad2c1 commit d89cdb2
Show file tree
Hide file tree
Showing 13 changed files with 626 additions and 92 deletions.
4 changes: 2 additions & 2 deletions docs/concepts/agents.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ Think of an agent as a specialized team member with specific skills, expertise,
| **Max Retry Limit** _(optional)_ | `max_retry_limit` | `int` | Maximum number of retries when an error occurs. Default is 2. |
| **Respect Context Window** _(optional)_ | `respect_context_window` | `bool` | Keep messages under context window size by summarizing. Default is True. |
| **Code Execution Mode** _(optional)_ | `code_execution_mode` | `Literal["safe", "unsafe"]` | Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct). Default is 'safe'. |
| **Embedder Config** _(optional)_ | `embedder_config` | `Optional[Dict[str, Any]]` | Configuration for the embedder used by the agent. |
| **Embedder** _(optional)_ | `embedder` | `Optional[Dict[str, Any]]` | Configuration for the embedder used by the agent. |
| **Knowledge Sources** _(optional)_ | `knowledge_sources` | `Optional[List[BaseKnowledgeSource]]` | Knowledge sources available to the agent. |
| **Use System Prompt** _(optional)_ | `use_system_prompt` | `Optional[bool]` | Whether to use system prompt (for o1 model support). Default is True. |

Expand Down Expand Up @@ -152,7 +152,7 @@ agent = Agent(
use_system_prompt=True, # Default: True
tools=[SerperDevTool()], # Optional: List of tools
knowledge_sources=None, # Optional: List of knowledge sources
embedder_config=None, # Optional: Custom embedder configuration
embedder=None, # Optional: Custom embedder configuration
system_template=None, # Optional: Custom system prompt template
prompt_template=None, # Optional: Custom prompt template
response_template=None, # Optional: Custom response template
Expand Down
7 changes: 7 additions & 0 deletions docs/concepts/knowledge.mdx
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,13 @@ agent = Agent(
verbose=True,
allow_delegation=False,
llm=gemini_llm,
embedder={
"provider": "google",
"config": {
"model": "models/text-embedding-004",
"api_key": GEMINI_API_KEY,
}
}
)

task = Task(
Expand Down
19 changes: 7 additions & 12 deletions src/crewai/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ class Agent(BaseAgent):
tools: Tools at agents disposal
step_callback: Callback to be executed after each step of the agent execution.
knowledge_sources: Knowledge sources for the agent.
embedder: Embedder configuration for the agent.
"""

_times_executed: int = PrivateAttr(default=0)
Expand Down Expand Up @@ -122,17 +123,10 @@ class Agent(BaseAgent):
default="safe",
description="Mode for code execution: 'safe' (using Docker) or 'unsafe' (direct execution).",
)
embedder_config: Optional[Dict[str, Any]] = Field(
embedder: Optional[Dict[str, Any]] = Field(
default=None,
description="Embedder configuration for the agent.",
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
default=None,
description="Knowledge sources for the agent.",
)
_knowledge: Optional[Knowledge] = PrivateAttr(
default=None,
)

@model_validator(mode="after")
def post_init_setup(self):
Expand Down Expand Up @@ -163,10 +157,11 @@ def _set_knowledge(self):
if isinstance(self.knowledge_sources, list) and all(
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
):
self._knowledge = Knowledge(
self.knowledge = Knowledge(
sources=self.knowledge_sources,
embedder_config=self.embedder_config,
embedder=self.embedder,
collection_name=knowledge_agent_name,
storage=self.knowledge_storage or None,
)
except (TypeError, ValueError) as e:
raise ValueError(f"Invalid Knowledge Configuration: {str(e)}")
Expand Down Expand Up @@ -225,8 +220,8 @@ def execute_task(
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)

if self._knowledge:
agent_knowledge_snippets = self._knowledge.query([task.prompt()])
if self.knowledge:
agent_knowledge_snippets = self.knowledge.query([task.prompt()])
if agent_knowledge_snippets:
agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
Expand Down
50 changes: 48 additions & 2 deletions src/crewai/agents/agent_builder/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.tools_handler import ToolsHandler
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.tools import BaseTool
from crewai.tools.base_tool import Tool
from crewai.utilities import I18N, Logger, RPMController
Expand Down Expand Up @@ -48,6 +50,8 @@ class BaseAgent(ABC, BaseModel):
cache_handler (InstanceOf[CacheHandler]): An instance of the CacheHandler class.
tools_handler (InstanceOf[ToolsHandler]): An instance of the ToolsHandler class.
max_tokens: Maximum number of tokens for the agent to generate in a response.
knowledge_sources: Knowledge sources for the agent.
knowledge_storage: Custom knowledge storage for the agent.
Methods:
Expand Down Expand Up @@ -130,6 +134,17 @@ class BaseAgent(ABC, BaseModel):
max_tokens: Optional[int] = Field(
default=None, description="Maximum number of tokens for the agent's execution."
)
knowledge: Optional[Knowledge] = Field(
default=None, description="Knowledge for the agent."
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
default=None,
description="Knowledge sources for the agent.",
)
knowledge_storage: Optional[Any] = Field(
default=None,
description="Custom knowledge storage for the agent.",
)

@model_validator(mode="before")
@classmethod
Expand Down Expand Up @@ -256,13 +271,44 @@ def copy(self: T) -> T: # type: ignore # Signature of "copy" incompatible with
"tools_handler",
"cache_handler",
"llm",
"knowledge_sources",
"knowledge_storage",
"knowledge",
}

# Copy llm and clear callbacks
# Copy llm
existing_llm = shallow_copy(self.llm)
copied_knowledge = shallow_copy(self.knowledge)
copied_knowledge_storage = shallow_copy(self.knowledge_storage)
# Properly copy knowledge sources if they exist
existing_knowledge_sources = None
if self.knowledge_sources:
# Create a shared storage instance for all knowledge sources
shared_storage = (
self.knowledge_sources[0].storage if self.knowledge_sources else None
)

existing_knowledge_sources = []
for source in self.knowledge_sources:
copied_source = (
source.model_copy()
if hasattr(source, "model_copy")
else shallow_copy(source)
)
# Ensure all copied sources use the same storage instance
copied_source.storage = shared_storage
existing_knowledge_sources.append(copied_source)

copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None}
copied_agent = type(self)(**copied_data, llm=existing_llm, tools=self.tools)
copied_agent = type(self)(
**copied_data,
llm=existing_llm,
tools=self.tools,
knowledge_sources=existing_knowledge_sources,
knowledge=copied_knowledge,
knowledge_storage=copied_knowledge_storage,
)

return copied_agent

Expand Down
23 changes: 18 additions & 5 deletions src/crewai/crew.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import uuid
import warnings
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union

Expand Down Expand Up @@ -210,8 +211,9 @@ class Crew(BaseModel):
default=None,
description="LLM used to handle chatting with the crew.",
)
_knowledge: Optional[Knowledge] = PrivateAttr(
knowledge: Optional[Knowledge] = Field(
default=None,
description="Knowledge for the crew.",
)

@field_validator("id", mode="before")
Expand Down Expand Up @@ -289,7 +291,7 @@ def create_crew_knowledge(self) -> "Crew":
if isinstance(self.knowledge_sources, list) and all(
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
):
self._knowledge = Knowledge(
self.knowledge = Knowledge(
sources=self.knowledge_sources,
embedder_config=self.embedder,
collection_name="crew",
Expand Down Expand Up @@ -996,8 +998,8 @@ def replay(
return result

def query_knowledge(self, query: List[str]) -> Union[List[Dict[str, Any]], None]:
if self._knowledge:
return self._knowledge.query(query)
if self.knowledge:
return self.knowledge.query(query)
return None

def fetch_inputs(self) -> Set[str]:
Expand Down Expand Up @@ -1041,13 +1043,18 @@ def copy(self):
"_telemetry",
"agents",
"tasks",
"knowledge_sources",
"knowledge",
}

cloned_agents = [agent.copy() for agent in self.agents]

task_mapping = {}

cloned_tasks = []
existing_knowledge_sources = shallow_copy(self.knowledge_sources)
existing_knowledge = shallow_copy(self.knowledge)

for task in self.tasks:
cloned_task = task.copy(cloned_agents, task_mapping)
cloned_tasks.append(cloned_task)
Expand All @@ -1067,7 +1074,13 @@ def copy(self):
copied_data.pop("agents", None)
copied_data.pop("tasks", None)

copied_crew = Crew(**copied_data, agents=cloned_agents, tasks=cloned_tasks)
copied_crew = Crew(
**copied_data,
agents=cloned_agents,
tasks=cloned_tasks,
knowledge_sources=existing_knowledge_sources,
knowledge=existing_knowledge,
)

return copied_crew

Expand Down
25 changes: 13 additions & 12 deletions src/crewai/knowledge/knowledge.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,20 +15,20 @@ class Knowledge(BaseModel):
Args:
sources: List[BaseKnowledgeSource] = Field(default_factory=list)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder_config: Optional[Dict[str, Any]] = None
embedder: Optional[Dict[str, Any]] = None
"""

sources: List[BaseKnowledgeSource] = Field(default_factory=list)
model_config = ConfigDict(arbitrary_types_allowed=True)
storage: Optional[KnowledgeStorage] = Field(default=None)
embedder_config: Optional[Dict[str, Any]] = None
embedder: Optional[Dict[str, Any]] = None
collection_name: Optional[str] = None

def __init__(
self,
collection_name: str,
sources: List[BaseKnowledgeSource],
embedder_config: Optional[Dict[str, Any]] = None,
embedder: Optional[Dict[str, Any]] = None,
storage: Optional[KnowledgeStorage] = None,
**data,
):
Expand All @@ -37,32 +37,33 @@ def __init__(
self.storage = storage
else:
self.storage = KnowledgeStorage(
embedder_config=embedder_config, collection_name=collection_name
embedder=embedder, collection_name=collection_name
)
self.sources = sources
self.storage.initialize_knowledge_storage()
for source in sources:
source.storage = self.storage
source.add()
self._add_sources()

def query(self, query: List[str], limit: int = 3) -> List[Dict[str, Any]]:
"""
Query across all knowledge sources to find the most relevant information.
Returns the top_k most relevant chunks.
Raises:
ValueError: If storage is not initialized.
"""
if self.storage is None:
raise ValueError("Storage is not initialized.")

results = self.storage.search(
query,
limit,
)
return results

def _add_sources(self):
for source in self.sources:
source.storage = self.storage
source.add()
try:
for source in self.sources:
source.storage = self.storage
source.add()
except Exception as e:
raise e
8 changes: 7 additions & 1 deletion src/crewai/knowledge/source/base_file_knowledge_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,13 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
def validate_file_path(cls, v, info):
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if v is None and info.data.get("file_path" if info.field_name == "file_paths" else "file_paths") is None:
if (
v is None
and info.data.get(
"file_path" if info.field_name == "file_paths" else "file_paths"
)
is None
):
raise ValueError("Either file_path or file_paths must be provided")
return v

Expand Down
16 changes: 7 additions & 9 deletions src/crewai/knowledge/storage/knowledge_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,11 @@ class KnowledgeStorage(BaseKnowledgeStorage):

def __init__(
self,
embedder_config: Optional[Dict[str, Any]] = None,
embedder: Optional[Dict[str, Any]] = None,
collection_name: Optional[str] = None,
):
self.collection_name = collection_name
self._set_embedder_config(embedder_config)
self._set_embedder_config(embedder)

def search(
self,
Expand Down Expand Up @@ -99,7 +99,7 @@ def initialize_knowledge_storage(self):
)
if self.app:
self.collection = self.app.get_or_create_collection(
name=collection_name, embedding_function=self.embedder_config
name=collection_name, embedding_function=self.embedder
)
else:
raise Exception("Vector Database Client not initialized")
Expand Down Expand Up @@ -187,17 +187,15 @@ def _create_default_embedding_function(self):
api_key=os.getenv("OPENAI_API_KEY"), model_name="text-embedding-3-small"
)

def _set_embedder_config(
self, embedder_config: Optional[Dict[str, Any]] = None
) -> None:
def _set_embedder_config(self, embedder: Optional[Dict[str, Any]] = None) -> None:
"""Set the embedding configuration for the knowledge storage.
Args:
embedder_config (Optional[Dict[str, Any]]): Configuration dictionary for the embedder.
If None or empty, defaults to the default embedding function.
"""
self.embedder_config = (
EmbeddingConfigurator().configure_embedder(embedder_config)
if embedder_config
self.embedder = (
EmbeddingConfigurator().configure_embedder(embedder)
if embedder
else self._create_default_embedding_function()
)
1 change: 0 additions & 1 deletion src/crewai/utilities/embedding_configurator.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def configure_embedder(
raise Exception(
f"Unsupported embedding provider: {provider}, supported providers: {list(self.embedding_functions.keys())}"
)

return self.embedding_functions[provider](config, model_name)

@staticmethod
Expand Down
Loading

0 comments on commit d89cdb2

Please sign in to comment.