Compatibility with Haystack v2
Co-authored-by: Peter Izsak <[email protected]>
Co-authored-by: Moshe Berchansky <[email protected]>
3 people committed May 22, 2024
1 parent 3959edd commit 8087aea
Showing 140 changed files with 11,969 additions and 10,135 deletions.
134 changes: 134 additions & 0 deletions Demo.md
@@ -0,0 +1,134 @@

# Chat Demo with Chainlit

We provide a conversational demo with a multi-modal agent, built using the [chainlit](https://github.com/Chainlit/chainlit) framework. For more information, please visit the official documentation [here](https://docs.chainlit.io/get-started/overview).

For a simple chat experience, we load an LLM, such as [meta-llama/Meta-Llama-3-8B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct), by specifying its configuration like so:

```yaml
task: text-generation
model: "meta-llama/Meta-Llama-3-8B-Instruct"
do_sample: false
max_new_tokens: 300
```
Then, run the following:
```sh
CONFIG=config/regular_chat.yaml chainlit run fastrag/ui/chainlit_no_rag.py
```

For a chat using a RAG pipeline, specify the tools you wish to use in the following format:

```yaml
chat_model:
  generator_kwargs:
    model: microsoft/Phi-3-mini-128k-instruct
    task: "text-generation"
    generation_kwargs:
      max_new_tokens: 300
      do_sample: false
    huggingface_pipeline_kwargs:
      torch_dtype: torch.bfloat16
      max_new_tokens: 300
      do_sample: false
      trust_remote_code: true
  generator_class: haystack.components.generators.hugging_face_local.HuggingFaceLocalGenerator
tools:
  - type: doc
    query_handler:
      type: "haystack_yaml"
      params:
        pipeline_yaml_path: "config/empty_doc_only_retrieval_pipeline.yaml"
    index_handler:
      type: "haystack_yaml"
      params:
        pipeline_yaml_path: "config/empty_index_pipeline.yaml"
    params:
      name: "docRetriever"
      description: 'useful for when you need to retrieve text to answer questions. Use the following format: {{ "input": [your tool input here ] }}.'
      output_variable: "documents"
```
Then, run the application using the command:
```sh
CONFIG=config/rag_pipeline_chat.yaml chainlit run fastrag/ui/chainlit_pipeline.py
```
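
For orientation, the `chat_model` section above corresponds roughly to constructing a Haystack v2 `HuggingFaceLocalGenerator` in Python. The sketch below is illustrative only; it is not the demo's actual loading code, and the prompt string is made up:

```python
# Illustrative sketch only: roughly what the `chat_model` YAML section maps to.
import torch
from haystack.components.generators import HuggingFaceLocalGenerator

generator = HuggingFaceLocalGenerator(
    model="microsoft/Phi-3-mini-128k-instruct",
    task="text-generation",
    generation_kwargs={"max_new_tokens": 300, "do_sample": False},
    huggingface_pipeline_kwargs={
        "torch_dtype": torch.bfloat16,
        "trust_remote_code": True,
    },
)
generator.warm_up()
print(generator.run(prompt="What is retrieval-augmented generation?")["replies"][0])
```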

## Screenshot

![Screenshot of the Chainlit chat demo](./assets/chainlit_demo_example.png)


# Multi-Modal Conversational Agent with Chainlit

In this demo, we use the [`xtuner/llava-llama-3-8b-v1_1-transformers`](https://huggingface.co/xtuner/llava-llama-3-8b-v1_1-transformers) model as a conversational agent that can decide which retriever to use to respond to the user's query.
To do so, we use dynamic reasoning with [ReAct](https://arxiv.org/abs/2210.03629) prompting, which results in multiple logical turns.
To explore all the steps to build the agent system, you can check out our [Example Notebook](../examples/multi_modal_react_agent.ipynb).
For more information on how to use ReAct, feel free to visit [Haystack's original tutorial](https://haystack.deepset.ai/tutorials/25_customizing_agent), which our demo is based on.
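
For intuition, a single ReAct turn produced by such an agent looks roughly like the following. This trace is purely illustrative: the exact field names and prompt wording depend on the agent's prompt template, and the tool name `docRetriever` is taken from the configuration shown earlier.

```
Thought: The question is about an image in the document store; I should query the retriever.
Action: docRetriever
Action Input: {"input": ["Intel 8742 microcontroller die"]}
Observation: <retrieved documents and images>
Thought: I now have enough context to answer.
Final Answer: The Intel 8742 is an 8-bit microcontroller with a 12 MHz CPU, 128 bytes of RAM and 2 KB of EPROM on a single chip.
```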

To run the demo, simply run:

```sh
CONFIG=config/visual_chat_agent.yaml chainlit run fastrag/ui/chainlit_multi_modal_agent.py
```

## Screenshot

![Screenshot of the Chainlit multi-modal agent demo](./assets/chainlit_agent_example.png)

# Available Chat Templates

## Default Template

```
The following is a conversation between a human and an AI. Do not generate the user response to your output.
{memory}
Human: {query}
AI:
```
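
Each template is a plain format string: at chat time the accumulated conversation history is substituted for `{memory}` and the latest user message for `{query}`. A minimal sketch with made-up values:

```python
# Illustrative sketch of how a chat template's placeholders get filled.
DEFAULT_TEMPLATE = (
    "The following is a conversation between a human and an AI. "
    "Do not generate the user response to your output.\n"
    "{memory}\n"
    "Human: {query}\n"
    "AI:"
)

memory = "Human: What is fastRAG?\nAI: A research framework for efficient RAG pipelines.\n"
prompt = DEFAULT_TEMPLATE.format(memory=memory, query="Does it work with Haystack v2?")
print(prompt)
```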

## Llama 2 Template (Llama2)

```
<s>[INST] <<SYS>>
The following is a conversation between a human and an AI. Do not generate the user response to your output.
<</SYS>>
{memory}{query} [/INST]
```

Notice that here, the user messages will be formatted as:

```
<s>[INST] {USER_QUERY} [/INST]
```

And the model messages will be:

```
{ASSISTANT_RESPONSE} </s>
```

## User-Assistant (UserAssistant)

```
### System:
The following is a conversation between a human and an AI. Do not generate the user response to your output.
{memory}
### User: {query}
### Assistant:
```

## User-Assistant for Llava (UserAssistantLlava)

For the v1.5 Llava models, we define a specific template, as shown in [the Transformers documentation regarding Llava models](https://huggingface.co/docs/transformers/model_doc/llava).

```
{memory}
USER: {query}
ASSISTANT:
```
21 changes: 10 additions & 11 deletions README.md
@@ -1,6 +1,3 @@
> [!IMPORTANT]
> We're migrating `fastRAG` to Haystack v2.0 API and will release a major update soon. Stay tuned!
<div align="center">
<img src="assets/fastrag_header.png" width="300"/>

@@ -10,7 +7,7 @@
<p>Build and explore efficient retrieval-augmented generative models and applications</p>
</h4>

:round_pushpin: <a href="#round_pushpin-installation">Installation</a> • :rocket: <a href="components.md">Components</a> • :books: <a href="examples.md">Examples</a> • :red_car: <a href="getting_started.md">Getting Started</a> • :pill: <a href="demo/README.md">Demos</a> • :pencil2: <a href="scripts/README.md">Scripts</a> • :bar_chart: <a href="benchmarks/README.md">Benchmarks</a>
:round_pushpin: <a href="#round_pushpin-installation">Installation</a> • :rocket: <a href="components.md">Components</a> • :books: <a href="examples.md">Examples</a> • :red_car: <a href="getting_started.md">Getting Started</a> • :pill: <a href="Demo.md">Demos</a> • :pencil2: <a href="scripts/README.md">Scripts</a> • :bar_chart: <a href="benchmarks/README.md">Benchmarks</a>


</div>
@@ -21,11 +18,13 @@ with a comprehensive tool-set for advancing retrieval augmented generation.

Comments, suggestions, issues and pull-requests are welcomed! :heart:

> [!IMPORTANT]
> Now compatible with Haystack v2+. Please report any possible issues you find.
## :mega: Updates

- **2024-04**: [(Extra Demo)](extras/rag_on_client.ipynb) **Chat with your documents on Intel Meteor Lake iGPU**.
- **2023-12**: Gaudi2, ONNX runtime and LlamaCPP support; Optimized Embedding models; Multi-modality and Chat demos; [REPLUG](https://arxiv.org/abs/2301.12652) text generation.
- **2023-06**: ColBERT index modification: adding/removing documents; see [IndexUpdater](https://github.com/stanford-futuredata/ColBERT/blob/main/colbert/index_updater.py).
- **2023-12**: Gaudi2 and ONNX runtime support; Optimized Embedding models; Multi-modality and Chat demos; [REPLUG](https://arxiv.org/abs/2301.12652) text generation.
- **2023-06**: ColBERT index modification: adding/removing documents; see [IndexUpdater](libs/colbert/colbert/index_updater.py).
- **2023-05**: [RAG with LLM and dynamic prompt synthesis example](examples/rag-prompt-hf.ipynb).
- **2023-04**: Qdrant `DocumentStore` support.

@@ -57,6 +56,10 @@ For a brief overview of the various unique components in fastRAG refer to the [C
<td><a href="components.md#fastrag-running-llms-with-onnx-runtime">ONNX Runtime</a></td>
<td><em>Running LLMs with optimized ONNX-runtime</em></td>
</tr>
<tr>
<td><a href="components.md#fastrag-running-quantized-llms-using-openvino">OpenVINO</a></td>
<td><em>Running quantized LLMs using OpenVINO</em></td>
</tr>
<tr>
<td><a href="components.md#fastrag-running-rag-pipelines-with-llms-on-a-llama-cpp-backend">Llama-CPP</a></td>
<td><em>Running RAG Pipelines with LLMs on a Llama CPP backend</em></td>
@@ -117,10 +120,6 @@ pip install .[qdrant] # Support for Qdrant store
pip install .[colbert] # Support for ColBERT+PLAID; requires FAISS
pip install .[faiss-cpu] # CPU-based Faiss library
pip install .[faiss-gpu] # GPU-based Faiss library
pip install .[knowledge_graph] # Libraries for working with spacy and KG

# User interface (for demos)
pip install .[ui]

# Benchmarking
pip install .[benchmark]
Binary file added assets/chainlit_agent_example.png
Binary file added assets/chainlit_demo_example.png
17 changes: 17 additions & 0 deletions assets/multi_modal_files/intel_image_data.json
@@ -0,0 +1,17 @@
[
{
"image_url": "https://upload.wikimedia.org/wikipedia/commons/1/13/2200_Mission_College_Boulevard.jpg",
"title": "Intel Headquarters",
"content": "Headquarters in Santa Clara, California, in 2023"
},
{
"image_url": "https://upload.wikimedia.org/wikipedia/commons/6/64/Intel_8742_153056995.jpg",
"title": "Intel 8742",
"content": "The die from an Intel 8742, an 8-bit microcontroller that includes a CPU running at 12 MHz, 128 bytes of RAM, 2048 bytes of EPROM, and I/O in the same chip"
},
{
"image_url": "https://upload.wikimedia.org/wikipedia/commons/f/fc/Intel_Costa_12_2007_SJO_105b.jpg",
"title": "microprocessor facility",
"content": "Intel microprocessor facility in Costa Rica was responsible in 2006 for 20% of Costa Rican exports and 4.9% of the country's GDP."
}
]
1 change: 1 addition & 0 deletions chainlit.md
@@ -0,0 +1 @@
# Chat with an LLM 📚
66 changes: 52 additions & 14 deletions components.md
@@ -1,7 +1,5 @@
# fast**RAG** Components Overview



## REPLUG
<image align="right" src="assets/replug.png" width="600">

@@ -12,9 +10,9 @@ process a larger number of retrieved documents, without limiting ourselves to th
method works with any LLM, no fine-tuning is needed. See ([Shi et al. 2023](#shiREPLUGRetrievalAugmentedBlackBox2023))
for more details.

We provide implementation for the REPLUG ensembling inference, using the invocation layer
`ReplugHFLocalInvocationLayer`; Our implementation supports most Hugging FAce models with `.generate()` capabilities (such that implement the generation mixin); For a complete example, see [REPLUG Parallel
Reader](examples/replug_parallel_reader.ipynb) notebook.
We provide an implementation of REPLUG ensembling inference using the `ReplugGenerator` generator. Our implementation
supports most Hugging Face models with `.generate()` capabilities (i.e., those that implement the generation mixin). For a
complete example, see the [REPLUG Parallel Reader](examples/replug_parallel_reader.ipynb) notebook.

## ColBERT v2 with PLAID Engine

@@ -82,11 +80,10 @@ We enabled support for running LLMs on Habana Gaudi (DL1) and Habana Gaudi 2 by
See below an example of loading a `GaudiGenerator` with a Habana backend:

```python
from fastrag.prompters.invocation_layers.gaudi_hugging_face_inference import GaudiHFLocalInvocationLayer
from fastrag.generators import GaudiGenerator

PrompterModel = PromptModel(
generator = GaudiGenerator(
model_name_or_path= "meta-llama/Llama-2-7b-chat-hf",
invocation_layer_class=GaudiHFLocalInvocationLayer,
model_kwargs= dict(
max_new_tokens=50,
torch_dtype=torch.bfloat16,
@@ -137,18 +134,60 @@ quantizer.quantize(save_dir=os.path.join(converted_model_path, 'quantized'), qua

### Loading the Quantized Model

Now that our model is quantized, we can load it in our framework, by specifying the ```ORTInvocationLayer``` invocation layer.
Now that our model is quantized, we can load it in our framework by using the ```ORTGenerator``` generator.

```python
PrompterModel = PromptModel(
model_name_or_path= "my/local/path/quantized",
invocation_layer_class=ORTInvocationLayer,
generator = ORTGenerator(
model="my/local/path/quantized",
task="text-generation",
generation_kwargs={
"max_new_tokens": 100,
}
)
```

## fastRAG running quantized LLMs using OpenVINO

We provide a method for running quantized LLMs with [OpenVINO](https://docs.openvino.ai/2024/home.html) and [optimum-intel](https://github.com/huggingface/optimum-intel).
We recommend checking out our [notebook](examples/rag_with_openvino.ipynb) with all the details, including the quantization and pipeline construction.

### Installation

Run the following command to install our dependencies:

```
pip install -e .[openvino]
```

For more information regarding the installation process, we recommend checking out the guides provided by [OpenVINO](https://docs.openvino.ai/2024/home.html) and [optimum-intel](https://github.com/huggingface/optimum-intel).

### LLM Quantization

We can use the [OpenVINO tutorial notebook](https://github.com/openvinotoolkit/openvino_notebooks/blob/main/notebooks/254-llm-chatbot/254-llm-chatbot.ipynb) to quantize an LLM to our liking.
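
As a rough sketch (assuming a recent `optimum-intel`; flag names may vary between versions, and the model id and output directory below are placeholders), weight-only INT8 quantization can look like this:

```python
# Sketch: export a causal LM to OpenVINO IR with INT8 weight compression.
# Assumes optimum-intel with the OpenVINO extras is installed (pip install -e .[openvino]).
from optimum.intel import OVModelForCausalLM
from transformers import AutoTokenizer

model_id = "microsoft/phi-2"
save_dir = "phi-2-ov-int8"  # placeholder output directory

# export=True converts the PyTorch checkpoint to OpenVINO IR;
# load_in_8bit=True applies weight-only INT8 compression during export.
model = OVModelForCausalLM.from_pretrained(model_id, export=True, load_in_8bit=True)
model.save_pretrained(save_dir)
AutoTokenizer.from_pretrained(model_id).save_pretrained(save_dir)
```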

### Loading the Quantized Model

Now that our model is quantized, we can load it in our framework by using the ```OpenVINOGenerator``` component.

```python
from fastrag.generators.openvino import OpenVINOGenerator

openvino_compressed_model_path = "path/to/model"

generator = OpenVINOGenerator(
    model="microsoft/phi-2",
    compressed_model_dir=openvino_compressed_model_path,
    device_openvino="CPU",
    task="text-generation",
    generation_kwargs={
        "max_new_tokens": 100,
    }
)
```
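
Assuming the component follows the standard Haystack v2 generator interface, it can then be warmed up and queried directly, for example:

```python
# Illustrative usage, assuming the standard Haystack v2 generator interface.
generator.warm_up()
result = generator.run(prompt="Summarize what OpenVINO does in one sentence.")
print(result["replies"][0])
```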

## fastRAG Running RAG Pipelines with LLMs on a Llama CPP backend

To run LLM effectively on CPUs, especially on client side machines, we offer a method for running LLMs using the [llama-cpp](https://github.com/ggerganov/llama.cpp).
To run LLMs effectively on CPUs, especially on client-side machines, we offer a method for running LLMs using the [llama-cpp](https://github.com/ggerganov/llama.cpp) library.
We recommend checking out our [tutorial notebook](examples/client_inference_with_Llama_cpp.ipynb) with all the details, including processes such as downloading GGUF models.

### Installation
@@ -176,7 +215,6 @@ PrompterModel = PromptModel(
)
```


## Optimized Embedding Models

Bi-encoder embedders are key components of Retrieval Augmented Generation pipelines, mainly used for indexing documents and for online re-ranking. We provide support for quantized `int8` models that have low latency and high throughput, using the [`optimum-intel`](https://github.com/huggingface/optimum-intel) framework.
1 change: 1 addition & 0 deletions config/data/wikipedia_hf_6M.yaml
@@ -5,3 +5,4 @@ dataset_info:
split: "train"
encoding_method: "wikipedia_hf_multisentence"
batch_size: 500
limit:
8 changes: 8 additions & 0 deletions config/data/wikitext_hf.yaml
@@ -0,0 +1,8 @@
type: fastrag.data_loaders.HFDatasetLoader
dataset_info:
  path: wikitext
  name: wikitext-2-v1
  split: train
encoding_method: text
batch_size: 100
limit:
1 change: 0 additions & 1 deletion config/doc_chat_ort.yaml
Expand Up @@ -6,5 +6,4 @@ chat_model:
session.intra_op_thread_affinities: '3,4;5,6;7,8;9,10;11,12'
intra_op_num_threads: 6
model_name_or_path: '/tmp/facebook_opt-iml-max-1.3b/quantized'
invocation_layer_class: fastrag.prompters.invocation_layers.ort.ORTInvocationLayer
doc_pipeline_file: "config/empty_retrieval_pipeline.yaml"
9 changes: 0 additions & 9 deletions config/embedder/dpr.yaml

This file was deleted.

7 changes: 7 additions & 0 deletions config/embedder/sentence-transformer-docs.yaml
@@ -0,0 +1,7 @@
type: haystack.components.embedders.SentenceTransformersDocumentEmbedder
init_parameters:
  model: "BAAI/llm-embedder"
  meta_fields_to_embed: ['title'] # optional
  prefix: "Represent this document for retrieval: "
  batch_size: 256
  device:
6 changes: 6 additions & 0 deletions config/embedder/sentence-transformer-text.yaml
@@ -0,0 +1,6 @@
type: haystack.components.embedders.SentenceTransformersTextEmbedder
init_parameters:
  model: "BAAI/llm-embedder"
  prefix: "Represent this query for retrieving relevant documents: "
  batch_size: 256
  device:
9 changes: 0 additions & 9 deletions config/embedder/sentence-transformer.yaml

This file was deleted.
