diff --git a/camel/retrievers/__init__.py b/camel/retrievers/__init__.py index a8257cfe85..f0fa0f330c 100644 --- a/camel/retrievers/__init__.py +++ b/camel/retrievers/__init__.py @@ -11,11 +11,13 @@ # See the License for the specific language governing permissions and # limitations under the License. # ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +# ruff: noqa: I001 from .auto_retriever import AutoRetriever from .base import BaseRetriever from .bm25_retriever import BM25Retriever from .cohere_rerank_retriever import CohereRerankRetriever from .vector_retriever import VectorRetriever +from .hybrid_retrival import HybridRetriever __all__ = [ 'BaseRetriever', @@ -23,4 +25,5 @@ 'AutoRetriever', 'BM25Retriever', 'CohereRerankRetriever', + 'HybridRetriever', ] diff --git a/camel/retrievers/hybrid_retrival.py b/camel/retrievers/hybrid_retrival.py new file mode 100644 index 0000000000..7787f33e64 --- /dev/null +++ b/camel/retrievers/hybrid_retrival.py @@ -0,0 +1,237 @@ +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +from typing import Any, Collection, Dict, List, Optional, Sequence, Union + +import numpy as np + +from camel.embeddings import BaseEmbedding +from camel.retrievers import BaseRetriever, BM25Retriever, VectorRetriever +from camel.storages import BaseVectorStorage + + +class HybridRetriever(BaseRetriever): + def __init__( + self, + embedding_model: Optional[BaseEmbedding] = None, + vector_storage: Optional[BaseVectorStorage] = None, + ) -> None: + r"""Initializes the HybridRetriever with optional embedding model and + vector storage. + + Args: + embedding_model (Optional[BaseEmbedding]): An optional embedding + model used by the VectorRetriever. Defaults to None. + vector_storage (Optional[BaseVectorStorage]): An optional vector + storage used by the VectorRetriever. Defaults to None. + """ + self.vr = VectorRetriever(embedding_model, vector_storage) + self.bm25 = BM25Retriever() + + def process(self, content_input_path: str) -> None: + r"""Processes the content input path for both vector and BM25 + retrievers. + + Args: + content_input_path (str): File path or URL of the content to be + processed. + + Raises: + ValueError: If the content_input_path is empty. + """ + if not content_input_path: + raise ValueError("content_input_path cannot be empty.") + + self.content_input_path = content_input_path + self.vr.process(content=self.content_input_path) + self.bm25.process(content_input_path=self.content_input_path) + + def _sort_rrf_scores( + self, + vector_retriever_results: List[Dict[str, Any]], + bm25_retriever_results: List[Dict[str, Any]], + top_k: int, + vector_weight: float, + bm25_weight: float, + rank_smoothing_factor: float, + ) -> List[Dict[str, Union[str, float]]]: + r"""Sorts and combines results from vector and BM25 retrievers using + Reciprocal Rank Fusion (RRF). + + Args: + vector_retriever_results: A list of dictionaries containing the + results from the vector retriever, where each dictionary + contains a 'text' entry. + bm25_retriever_results: A list of dictionaries containing the + results from the BM25 retriever, where each dictionary + contains a 'text' entry. + top_k: The number of top results to return after sorting by RRF + score. + vector_weight: The weight to assign to the vector retriever + results in the RRF calculation. + bm25_weight: The weight to assign to the BM25 retriever results in + the RRF calculation. + rank_smoothing_factor: A hyperparameter for the RRF calculation + that helps smooth the rank positions. + + Returns: + List[Dict[str, Union[str, float]]]: A list of dictionaries + representing the sorted results. Each dictionary contains the + 'text'from the retrieved items and their corresponding 'rrf_score'. + + Raises: + ValueError: If any of the input weights are negative. + + References: + https://medium.com/@devalshah1619/mathematical-intuition-behind-reciprocal-rank-fusion-rrf-explained-in-2-mins-002df0cc5e2a + https://colab.research.google.com/drive/1iwVJrN96fiyycxN1pBqWlEr_4EPiGdGy#scrollTo=0qh83qGV2dY8 + """ + text_to_id = {} + id_to_info = {} + current_id = 1 + + # Iterate over vector_retriever_results + for rank, result in enumerate(vector_retriever_results, start=1): + text = result.get('text', None) # type: ignore[attr-defined] + if text is None: + raise KeyError("Each result must contain a 'text' key") + + if text not in text_to_id: + text_to_id[text] = current_id + id_to_info[current_id] = {'text': text, 'vector_rank': rank} + current_id += 1 + else: + id_to_info[text_to_id[text]]['vector_rank'] = rank + + # Iterate over bm25_retriever_results + for rank, result in enumerate(bm25_retriever_results, start=1): + text = result['text'] + if text not in text_to_id: + text_to_id[text] = current_id + id_to_info[current_id] = {'text': text, 'bm25_rank': rank} + current_id += 1 + else: + id_to_info[text_to_id[text]].setdefault('bm25_rank', rank) + + vector_ranks = np.array( + [ + info.get('vector_rank', float('inf')) + for info in id_to_info.values() + ] + ) + bm25_ranks = np.array( + [ + info.get('bm25_rank', float('inf')) + for info in id_to_info.values() + ] + ) + + # Calculate RRF scores + vector_rrf_scores = vector_weight / ( + rank_smoothing_factor + vector_ranks + ) + bm25_rrf_scores = bm25_weight / (rank_smoothing_factor + bm25_ranks) + rrf_scores = vector_rrf_scores + bm25_rrf_scores + + for idx, (_, info) in enumerate(id_to_info.items()): + info['rrf_score'] = rrf_scores[idx] + sorted_results = sorted( + id_to_info.values(), key=lambda x: x['rrf_score'], reverse=True + ) + return sorted_results[:top_k] + + def query( + self, + query: str, + top_k: int = 20, + vector_weight: float = 0.8, + bm25_weight: float = 0.2, + rank_smoothing_factor: int = 60, + vector_retriever_top_k: int = 50, + vector_retriever_similarity_threshold: float = 0.5, + bm25_retriever_top_k: int = 50, + return_detailed_info: bool = False, + ) -> Union[ + dict[str, Sequence[Collection[str]]], + dict[str, Sequence[Union[str, float]]], + ]: + r"""Executes a hybrid retrieval query using both vector and BM25 + retrievers. + + Args: + query (str): The search query. + top_k (int): Number of top results to return (default 20). + vector_weight (float): Weight for vector retriever results in RRF. + bm25_weight (float): Weight for BM25 retriever results in RRF. + rank_smoothing_factor (int): RRF hyperparameter for rank smoothing. + vector_retriever_top_k (int): Top results from vector retriever. + vector_retriever_similarity_threshold (float): Similarity + threshold for vector retriever. + bm25_retriever_top_k (int): Top results from BM25 retriever. + return_detailed_info (bool): Return detailed info if True. + + Returns: + Union[ + dict[str, Sequence[Collection[str]]], + dict[str, Sequence[Union[str, float]]] + ]: By default, returns only the text information. If + `return_detailed_info` is `True`, return detailed information + including rrf scores. + """ + if top_k > max(vector_retriever_top_k, bm25_retriever_top_k): + raise ValueError( + "top_k needs to be less than or equal to the " + "maximum value among vector_retriever_top_k and " + "bm25_retriever_top_k." + ) + if vector_weight < 0 or bm25_weight < 0: + raise ValueError( + "Neither `vector_weight` nor `bm25_weight` can be negative." + ) + + vr_raw_results: List[Dict[str, Any]] = self.vr.query( + query=query, + top_k=vector_retriever_top_k, + similarity_threshold=vector_retriever_similarity_threshold, + ) + # if the number of results is less than top_k, return all results + with_score = [ + info for info in vr_raw_results if 'similarity score' in info + ] + vector_retriever_results = sorted( + with_score, key=lambda x: x['similarity score'], reverse=True + ) + + bm25_retriever_results = self.bm25.query( + query=query, + top_k=bm25_retriever_top_k, + ) + + all_retrieved_info = self._sort_rrf_scores( + vector_retriever_results, + bm25_retriever_results, + top_k, + vector_weight, + bm25_weight, + rank_smoothing_factor, + ) + + retrieved_info = { + "Original Query": query, + "Retrieved Context": ( + all_retrieved_info + if return_detailed_info + else [item['text'] for item in all_retrieved_info] # type: ignore[misc] + ), + } + return retrieved_info diff --git a/examples/rag/single_agent_with_hybrid_rag.py b/examples/rag/single_agent_with_hybrid_rag.py new file mode 100644 index 0000000000..672fdb0a5d --- /dev/null +++ b/examples/rag/single_agent_with_hybrid_rag.py @@ -0,0 +1,58 @@ +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +from camel.agents import ChatAgent +from camel.retrievers import HybridRetriever + + +def single_agent(query: str) -> str: + # Set agent role + assistant_sys_msg = """You are a helpful assistant to answer question, + I will give you the Original Query and Retrieved Context, + answer the Original Query based on the Retrieved Context, + if you can't answer the question just say I don't know.""" + + hybrid_retriever = HybridRetriever() + hybrid_retriever.process( + content_input_path="https://en.wikipedia.org/wiki/King_Abdullah_University_of_Science_and_Technology" + ) + + retrieved_info = hybrid_retriever.query( + query=query, + top_k=5, + vector_retriever_top_k=10, + bm25_retriever_top_k=10, + ) + + # Pass the retrieved infomation to agent + user_msg = str(retrieved_info) + agent = ChatAgent(assistant_sys_msg) + + # Get response + assistant_response = agent.step(user_msg) + return assistant_response.msg.content + + +print(single_agent("What is it like to be a visiting student at KAUST?")) +''' +=============================================================================== +Being a visiting student at KAUST involves participating in the Visiting +Student Program (VS), which is designed for 3rd or 4th year undergraduate +or master's students. This program allows students to work directly with KAUST +faculty members for a duration that can range from a few days to several +months. Accepted students typically receive a monthly stipend, and their +accommodation, health insurance, and travel costs are covered. This support +makes the experience financially manageable and allows students to focus on +their research and learning during their time at KAUST. +=============================================================================== +''' diff --git a/test/retrievers/test_hybrid_retriever.py b/test/retrievers/test_hybrid_retriever.py new file mode 100644 index 0000000000..483c62a24d --- /dev/null +++ b/test/retrievers/test_hybrid_retriever.py @@ -0,0 +1,62 @@ +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. ========= +import os +from unittest.mock import MagicMock + +import pytest + +from camel.loaders.unstructured_io import UnstructuredIO +from camel.retrievers.hybrid_retrival import HybridRetriever + + +@pytest.fixture +def mock_hybrid_retriever(monkeypatch): + mock_element = UnstructuredIO.create_element_from_text(text='mock element') + monkeypatch.setattr(os.path, 'exists', MagicMock(return_value=True)) + monkeypatch.setattr( + UnstructuredIO, + 'parse_file_or_url', + MagicMock(return_value=[mock_element]), + ) + # Create an instance of HybridRetriever + hybrid_retriever = HybridRetriever() + hybrid_retriever.process(content_input_path="https://mock.com") + # Mock bm25_retriever.query + hybrid_retriever.bm25.query = MagicMock( + return_value=[ + {'text': 'Document 2', 'similarity score': 0.7}, + {'text': 'Document 3', 'similarity score': 0.6}, + ] + ) + # Mock auto_retriever.run_vector_retriever + hybrid_retriever.vr.query = MagicMock( + return_value=[ + {'text': 'Document 1', 'similarity score': 0.9}, + {'text': 'Document 2', 'similarity score': 0.8}, + ] + ) + return hybrid_retriever + + +def test_sort_rrf_scores_integration(mock_hybrid_retriever): + # Call the method that integrates _sort_rrf_scores + # This is a hypothetical method that would use _sort_rrf_scores internally + results = mock_hybrid_retriever.query( + query="some query about document 1", top_k=2, return_detailed_info=True + ) + + # Verify the results after _sort_rrf_scores + assert len(results['Retrieved Context']) == 2 + assert results['Retrieved Context'][0]['text'] == 'Document 2' + assert results['Retrieved Context'][1]['text'] == 'Document 1'