Enable speculative decoding #2777

Merged
18 commits merged on Jan 29, 2025
40 changes: 35 additions & 5 deletions demos/common/export_models/export_model.py
@@ -27,7 +27,7 @@
def add_common_arguments(parser):
parser.add_argument('--model_repository_path', required=False, default='models', help='Where the model should be exported to', dest='model_repository_path')
parser.add_argument('--source_model', required=True, help='HF model name or path to the local folder with PyTorch or OpenVINO model', dest='source_model')
parser.add_argument('--model_name', required=False, default=None, help='Model name that should be used in the deployment. Equal to source_name if HF model name is used', dest='model_name')
parser.add_argument('--model_name', required=False, default=None, help='Model name that should be used in the deployment. Equal to source_model if HF model name is used', dest='model_name')
parser.add_argument('--weight-format', default='int8', help='precision of the exported model', dest='precision')
parser.add_argument('--config_file_path', default='config.json', help='path to the config file', dest='config_file_path')
parser.add_argument('--overwrite_models', default=False, action='store_true', help='Overwrite the model if it already exists in the models repository', dest='overwrite_models')
@@ -44,6 +44,11 @@ def add_common_arguments(parser):
parser_text.add_argument('--max_num_batched_tokens', default=None, help='empty or integer. The maximum number of tokens that can be batched together.', dest='max_num_batched_tokens')
parser_text.add_argument('--max_num_seqs', default=None, help='256 by default. The maximum number of sequences that can be processed together.', dest='max_num_seqs')
parser_text.add_argument('--cache_size', default=10, type=int, help='cache size in GB', dest='cache_size')
parser_text.add_argument('--draft_source_model', required=False, default=None, help='HF model name or path to the local folder with PyTorch or OpenVINO draft model. '
'Using this option will create configuration for speculative decoding', dest='draft_source_model')
parser_text.add_argument('--draft_model_name', required=False, default=None, help='Draft model name that should be used in the deployment. '
'Equal to draft_source_model if HF model name is used. Available only if draft_source_model has been specified.', dest='draft_model_name')

parser_embeddings = subparsers.add_parser('embeddings', help='export model for embeddings endpoint')
add_common_arguments(parser_embeddings)
parser_embeddings.add_argument('--skip_normalize', default=True, action='store_false', help='Skip normalize the embeddings.', dest='normalize')
@@ -148,7 +153,9 @@ def add_common_arguments(parser):
dynamic_split_fuse: false, {% endif %}
max_num_seqs: {{max_num_seqs|default("256", true)}},
device: "{{target_device|default("CPU", true)}}",

{%- if draft_model_dir_name %}
# Speculative decoding configuration
draft_models_path: "./{{draft_model_dir_name}}",{% endif %}
}
}
input_stream_handler {
@@ -265,10 +272,26 @@ def export_text_generation_model(model_repository_path, source_model, model_name
if not os.path.isdir(llm_model_path) or args['overwrite_models']:
optimum_command = "optimum-cli export openvino --model {} --weight-format {} --trust-remote-code {}".format(source_model, precision, llm_model_path)
if os.system(optimum_command):
raise ValueError("Failed to export llm model", source_model)
raise ValueError("Failed to export llm model", source_model)
### Speculative decoding specific
draft_source_model = task_parameters.get("draft_source_model", None)
draft_model_dir_name = None
if draft_source_model:
draft_model_dir_name = draft_source_model.replace("/", "-") # flatten the name so we don't create nested directory structure
draft_llm_model_path = os.path.join(model_repository_path, model_name, draft_model_dir_name)
if os.path.isfile(os.path.join(draft_llm_model_path, 'openvino_model.xml')):
print("OV model is source folder. Skipping conversion.")
else: # assume HF model name or local pytorch model folder
print("Exporting LLM model to ", draft_llm_model_path)
if not os.path.isdir(draft_llm_model_path) or args['overwrite_models']:
optimum_command = "optimum-cli export openvino --model {} --weight-format {} --trust-remote-code {}".format(draft_source_model, precision, draft_llm_model_path)
if os.system(optimum_command):
raise ValueError("Failed to export llm model", source_model)
###
os.makedirs(os.path.join(model_repository_path, model_name), exist_ok=True)
gtemplate = jinja2.Environment(loader=jinja2.BaseLoader).from_string(text_generation_graph_template)
graph_content = gtemplate.render(tokenizer_model="{}_tokenizer_model".format(model_name), embeddings_model="{}_embeddings_model".format(model_name), model_path=model_path, **task_parameters)
graph_content = gtemplate.render(tokenizer_model="{}_tokenizer_model".format(model_name), embeddings_model="{}_embeddings_model".format(model_name),
model_path=model_path, draft_model_dir_name=draft_model_dir_name, **task_parameters)
with open(os.path.join(model_repository_path, model_name, 'graph.pbtxt'), 'w') as f:
f.write(graph_content)
print("Created graph {}".format(os.path.join(model_repository_path, model_name, 'graph.pbtxt')))
@@ -375,8 +398,15 @@ def export_rerank_model(model_repository_path, source_model, model_name, precisi
if args['model_name'] is None and args['source_model'] is None:
raise ValueError("Either model_name or source_model should be provided")

### Speculative decoding specific
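# If only one of draft_source_model / draft_model_name is provided, reuse it for the other.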
if args['draft_source_model'] is None:
args['draft_source_model'] = args['draft_model_name']
if args['draft_model_name'] is None:
args['draft_model_name'] = args['draft_source_model']
###

template_parameters = {k: v for k, v in args.items() if k not in ['model_repository_path', 'source_model', 'model_name', 'precision', 'version', 'config_file_path', 'overwrite_models']}
print("template params:",template_parameters)
print("template params:", template_parameters)

if args['task'] == 'text_generation':
export_text_generation_model(args['model_repository_path'], args['source_model'], args['model_name'], args['precision'], template_parameters, args['config_file_path'])
5 changes: 5 additions & 0 deletions demos/continuous_batching/README.md
@@ -8,6 +8,7 @@ hidden:
ovms_demos_continuous_batching_accuracy
ovms_demos_continuous_batching_rag
ovms_demos_continuous_batching_scaling
ovms_demos_continuous_batching_speculative_decoding
```

This demo shows how to deploy LLM models in the OpenVINO Model Server using continuous batching and paged attention algorithms.
@@ -330,6 +331,10 @@ Check this simple [text generation scaling demo](https://github.com/openvinotool

Check the [guide of using lm-evaluation-harness](https://github.com/openvinotoolkit/model_server/blob/main/demos/continuous_batching/accuracy/README.md)

## Use Speculative Decoding

Check the [guide for speculative decoding](./speculative_decoding/README.md)


## References
- [Chat Completions API](../../docs/model_server_rest_api_chat.md)
183 changes: 183 additions & 0 deletions demos/continuous_batching/speculative_decoding/README.md
@@ -0,0 +1,183 @@
# How to serve LLM Models in Speculative Decoding Pipeline{#ovms_demos_continuous_batching_speculative_decoding}

Following [OpenVINO GenAI docs](https://docs.openvino.ai/2024/learn-openvino/llm_inference_guide/genai-guide.html#efficient-text-generation-via-speculative-decoding):
> Speculative decoding (or assisted-generation) enables faster token generation when an additional smaller draft model is used alongside the main model. This reduces the number of infer requests to the main model, increasing performance.
>
> The draft model predicts the next K tokens one by one in an autoregressive manner. The main model validates these predictions and corrects them if necessary - in case of a discrepancy, the main model prediction is used. Then, the draft model acquires this token and runs prediction of the next K tokens, thus repeating the cycle.

The goal of this sampling method is to reduce latency while keeping the main model accuracy. It gives the biggest gain in low concurrency scenarios.
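
For intuition, the cycle described above can be sketched in a few lines of Python. This is an illustration of the idea only, not the OpenVINO GenAI implementation; `draft_model` and `main_model` are hypothetical callables that return a single next token for a given context:

```python
def speculative_cycle(context, draft_model, main_model, num_assistant_tokens=5):
    # 1. The draft model proposes K tokens autoregressively (cheap).
    draft_tokens = []
    draft_context = list(context)
    for _ in range(num_assistant_tokens):
        token = draft_model(draft_context)
        draft_tokens.append(token)
        draft_context.append(token)

    # 2. The main model validates the proposals; in the real pipeline this
    #    validation happens in a single batched pass of the main model.
    accepted = []
    for token in draft_tokens:
        expected = main_model(context + accepted)
        if expected == token:
            accepted.append(token)      # draft prediction confirmed
        else:
            accepted.append(expected)   # discrepancy: keep the main model token
            break                       # remaining draft tokens are dropped
    return accepted                     # appended to the context; cycle repeats
```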

This demo shows how to use speculative decoding in the model serving scenario, by deploying main and draft models in a speculative decoding pipeline in a manner similar to regular deployments with continuous batching.

Review comment (Collaborator): Add a note that the goal of this algorithm is to reduce the latency while keeping the main model accuracy. It gives the biggest gain in low concurrency requests.

## Prerequisites

**Model preparation**: Python 3.9 or higher with pip and HuggingFace account

**Model Server deployment**: Installed Docker Engine or OVMS binary package according to the [baremetal deployment guide](../../../docs/deploying_server_baremetal.md)

## Model considerations

From the functional perspective, both the main and draft models must use the same tokenizer, so the tokens from the draft model are correctly matched by the main model.

From the performance perspective, benefits from speculative decoding are strictly tied to the pair of models used.
For some models, the performance boost is significant, while for others it is rather negligible. Model sizes and precisions also come into play, so the optimal setup should be found empirically.

In this demo we will use:
- [meta-llama/CodeLlama-7b-hf](https://huggingface.co/meta-llama/CodeLlama-7b-hf) as a main model
- [AMD-Llama-135m](https://huggingface.co/amd/AMD-Llama-135m) as a draft model

both in FP16 precision.
Review comment (Collaborator): why FP16? Can it be loaded on dGPU?

Review reply (Author): No particular reason. There were no tests on GPU, but speculative decoding reuses most of the regular CB pipeline logic, so there should be no issue. Specifying target_device will propagate to the draft model as well.


## Model preparation
Here, the original PyTorch LLM models and the tokenizers will be converted to IR format and optionally quantized.
That ensures faster initialization time, better performance and lower memory consumption.
LLM engine parameters will be defined inside the `graph.pbtxt` file.

Download the export script, install its dependencies, and create a directory for the models:
```console
curl https://raw.githubusercontent.com/openvinotoolkit/model_server/refs/heads/main/demos/common/export_models/export_model.py -o export_model.py
pip3 install -r https://raw.githubusercontent.com/openvinotoolkit/model_server/refs/heads/main/demos/common/export_models/requirements.txt
mkdir models
```

Run the `export_model.py` script to download and quantize the models:

> **Note:** Before downloading the CodeLlama model, access must be requested. Follow the instructions on the [meta-llama/CodeLlama-7b-hf](https://huggingface.co/meta-llama/CodeLlama-7b-hf) model page to request access. When access is granted, create an authentication token on the HuggingFace account page (Settings -> Access Tokens), then authenticate via `huggingface-cli login` and enter the token when prompted.

```console
python export_model.py text_generation --source_model meta-llama/CodeLlama-7b-hf --draft_source_model amd/AMD-Llama-135m --weight-format fp16 --kv_cache_precision u8 --config_file_path models/config.json --model_repository_path models
```

The draft model inherits all scheduler properties from the main model.

You should have a models folder structure like the one below:
```
models
├── config.json
└── meta-llama
└── CodeLlama-7b-hf
├── amd-AMD-Llama-135m
│   ├── config.json
│   ├── generation_config.json
│   ├── openvino_detokenizer.bin
│   ├── openvino_detokenizer.xml
│   ├── openvino_model.bin
│   ├── openvino_model.xml
│   ├── openvino_tokenizer.bin
│   ├── openvino_tokenizer.xml
│   ├── special_tokens_map.json
│   ├── tokenizer_config.json
│   ├── tokenizer.json
│   └── tokenizer.model
├── config.json
├── generation_config.json
├── graph.pbtxt
├── openvino_detokenizer.bin
├── openvino_detokenizer.xml
├── openvino_model.bin
├── openvino_model.xml
├── openvino_tokenizer.bin
├── openvino_tokenizer.xml
├── special_tokens_map.json
├── tokenizer_config.json
├── tokenizer.json
└── tokenizer.model

```

## Server Deployment

:::{dropdown} **Deploying with Docker**
```bash
docker run -d --rm -p 8000:8000 -v $(pwd)/models:/workspace:ro openvino/model_server:latest --rest_port 8000 --config_path /workspace/config.json
```

Running the above command starts the container without accelerator support.
To deploy on devices other than CPU, change the `target_device` parameter in the `export_model.py` call and follow the [AI accelerators guide](../../../docs/accelerators.md) for the additional Docker parameters required.
:::

:::{dropdown} **Deploying on Bare Metal**

Assuming you have unpacked the model server package, make sure to:

- **On Windows**: run `setupvars` script
- **On Linux**: set `LD_LIBRARY_PATH` and `PATH` environment variables

as mentioned in the [deployment guide](../../../docs/deploying_server_baremetal.md), in every new shell that will start OpenVINO Model Server.

Depending on how you prepared the models in the first step of this demo, they are deployed to either CPU or GPU (as defined in `config.json`). If you run on GPU, make sure the appropriate drivers are installed so the device is accessible to the model server.

```bat
ovms --rest_port 8000 --config_path ./models/config.json
```
:::

## Readiness Check

Wait for the model to load. You can check the status with a simple command:
```console
curl http://localhost:8000/v1/config
```
```json
{
"meta-llama/CodeLlama-7b-hf": {
"model_version_status": [
{
"version": "1",
"state": "AVAILABLE",
"status": {
"error_code": "OK",
"error_message": "OK"
}
}
]
}
}
```

## Request Generation

The models used in this demo, `meta-llama/CodeLlama-7b-hf` and `AMD-Llama-135m`, are not chat models, so we will use the `completions` endpoint to interact with the pipeline.

Below is an example of a unary request (you can set the `stream` parameter to `true` to enable a streamed response). Compared to calls to a regular continuous batching model, this request has the additional parameter `num_assistant_tokens`, which specifies how many tokens the draft model should generate before the main model validates them.
Review comment (Collaborator): What are the default values if both params are omitted?

Review reply (Author): There are none. One of those must be provided in the request. I didn't specify any default because it's hard to recommend a single value that would work well for different combinations of main and draft models.



```console
curl http://localhost:8000/v3/completions \
-H "Content-Type: application/json" \
-d '{
"model": "meta-llama/CodeLlama-7b-hf",
"temperature": 0,
"max_tokens":100,
"stream":false,
"prompt": "<s>def quicksort(numbers):",
"num_assistant_tokens": 5
}'| jq .
```
```json
{
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": "\n if len(numbers) <= 1:\n return numbers\n else:\n pivot = numbers[0]\n lesser = [x for x in numbers[1:] if x <= pivot]\n greater = [x for x in numbers[1:] if x > pivot]\n return quicksort(lesser) + [pivot] + quicksort(greater)\n\n\ndef quicksort_recursive(numbers):\n if"
}
],
"created": 1737547359,
"model": "meta-llama/CodeLlama-7b-hf-sd",
"object": "text_completion",
"usage": {
"prompt_tokens": 9,
"completion_tokens": 100,
"total_tokens": 109
}
}

```

A high value of `num_assistant_tokens` pays off when the tokens generated by the draft model mostly match the main model's predictions. If they don't, the speculated tokens are dropped and both models do additional work. For low values the risk is lower, but the potential performance boost is also limited. Usually a value of `5` is a good compromise.

The second speculative decoding specific parameter is `assistant_confidence_threshold`, which determines the confidence level for continuing generation. If the draft model generates a token with confidence below that threshold, it stops generation for the current cycle and the main model starts validation. `assistant_confidence_threshold` is a float in the range (0, 1).

**Note that `num_assistant_tokens` and `assistant_confidence_threshold` are mutually exclusive.**
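
For convenience, the same unary request can be sent from Python. Below is a minimal sketch using the `requests` package; it assumes the server from this demo is listening on `localhost:8000`, and the threshold value in the comment is only an example that should be tuned empirically:

```python
import requests

payload = {
    "model": "meta-llama/CodeLlama-7b-hf",
    "temperature": 0,
    "max_tokens": 100,
    "stream": False,
    "prompt": "<s>def quicksort(numbers):",
    # Use exactly one of the two speculative decoding parameters:
    "num_assistant_tokens": 5,
    # "assistant_confidence_threshold": 0.4,  # example value, tune empirically
}

response = requests.post(
    "http://localhost:8000/v3/completions",
    headers={"Content-Type": "application/json"},
    json=payload,
    timeout=300,
)
response.raise_for_status()
print(response.json()["choices"][0]["text"])
```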
9 changes: 9 additions & 0 deletions docs/model_server_rest_api_chat.md
@@ -93,6 +93,15 @@ curl http://localhost/v3/chat/completions \
| presence_penalty | ✅ | ✅ | ✅ | float (default: `0.0`) | Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. |
| seed | ✅ | ✅ | ✅ | integer (default: `0`) | Random seed to use for the generation. |

#### Speculative decoding specific

Note that the parameters below are valid only for a speculative decoding pipeline. See the [speculative decoding demo](../demos/continuous_batching/speculative_decoding/README.md) for details on how to prepare and serve such a pipeline.

| Param | OpenVINO Model Server | OpenAI /completions API | vLLM Serving Sampling Params | Type | Description |
|-------|----------|----------|----------|---------|-----|
| num_assistant_tokens | ✅ | ❌ | ⚠️ | int | This value defines how many tokens the draft model should generate before the main model validates them. Equivalent to `num_speculative_tokens` in vLLM. Cannot be used with `assistant_confidence_threshold`. |
| assistant_confidence_threshold | ✅ | ❌ | ❌ | float | This parameter determines the confidence level for continuing generation. If the draft model generates a token with confidence below this threshold, it stops generation for the current cycle and the main model starts validation. Cannot be used with `num_assistant_tokens`. |
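
For illustration, such a parameter can also be passed from the OpenAI Python client via `extra_body`, since it is not part of the standard OpenAI schema. This is a sketch only; the base URL and model name are placeholders for a chat-capable model served in a speculative decoding pipeline:

```python
from openai import OpenAI

# Placeholder endpoint and model name; adjust them to your deployment.
client = OpenAI(base_url="http://localhost:8000/v3", api_key="unused")

response = client.chat.completions.create(
    model="my-chat-model",
    messages=[{"role": "user", "content": "Write a quicksort function in Python."}],
    max_tokens=100,
    # Speculative decoding parameters are not part of the OpenAI SDK schema,
    # so they are passed through extra_body.
    extra_body={"num_assistant_tokens": 5},
)
print(response.choices[0].message.content)
```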

#### Unsupported params from OpenAI service:
- logit_bias
- top_logprobs
7 changes: 7 additions & 0 deletions docs/model_server_rest_api_completions.md
@@ -82,6 +82,13 @@ curl http://localhost/v3/completions \
| presence_penalty | ✅ | ✅ | ✅ | float (default: `0.0`) | Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. |
| seed | ✅ | ✅ | ✅ | integer (default: `0`) | Random seed to use for the generation. |

#### Speculative decoding specific

| Param | OpenVINO Model Server | OpenAI /completions API | vLLM Serving Sampling Params | Type | Description |
|-------|----------|----------|----------|---------|-----|
| num_assistant_tokens | ✅ | ❌ | ⚠️ | int | This value defines how many tokens the draft model should generate before the main model validates them. Equivalent to `num_speculative_tokens` in vLLM. Cannot be used with `assistant_confidence_threshold`. |
| assistant_confidence_threshold | ✅ | ❌ | ❌ | float | This parameter determines the confidence level for continuing generation. If the draft model generates a token with confidence below this threshold, it stops generation for the current cycle and the main model starts validation. Cannot be used with `num_assistant_tokens`. |

#### Unsupported params from OpenAI service:
- logit_bias
- suffix