From cb2780aa298f8e45b4f156063d0d89901eee259f Mon Sep 17 00:00:00 2001
From: Vertex MG Team
Date: Tue, 17 Dec 2024 14:05:25 -0800
Subject: [PATCH] Hex-LLM supports disaggregated serving as an experimental
 feature

PiperOrigin-RevId: 707243638
---
 ...l_garden_codegemma_deployment_on_vertex.ipynb |  3 +++
 ...odel_garden_gemma2_deployment_on_vertex.ipynb |  3 +++
 ...model_garden_gemma_deployment_on_vertex.ipynb |  3 +++
 ...model_garden_gemma_finetuning_on_vertex.ipynb |  3 +++
 .../model_garden_phi3_deployment.ipynb           |  3 +++
 ...odel_garden_pytorch_llama3_1_deployment.ipynb | 16 +++++++++++++++-
 ...odel_garden_pytorch_llama3_2_deployment.ipynb |  3 +++
 .../model_garden_pytorch_qwen2_deployment.ipynb  |  3 +++
 8 files changed, 36 insertions(+), 1 deletion(-)

diff --git a/notebooks/community/model_garden/model_garden_codegemma_deployment_on_vertex.ipynb b/notebooks/community/model_garden/model_garden_codegemma_deployment_on_vertex.ipynb
index 7eb3ddb05..6304295cb 100644
--- a/notebooks/community/model_garden/model_garden_codegemma_deployment_on_vertex.ipynb
+++ b/notebooks/community/model_garden/model_garden_codegemma_deployment_on_vertex.ipynb
@@ -323,6 +323,7 @@
     "    tensor_parallel_size: int = 1,\n",
     "    machine_type: str = \"ct5lp-hightpu-1t\",\n",
     "    tpu_topology: str = \"1x1\",\n",
+    "    disagg_topology: str = None,\n",
     "    hbm_utilization_factor: float = 0.6,\n",
     "    max_running_seqs: int = 256,\n",
     "    max_model_len: int = 4096,\n",
@@ -364,6 +365,8 @@
     "        f\"--max_running_seqs={max_running_seqs}\",\n",
     "        f\"--max_model_len={max_model_len}\",\n",
     "    ]\n",
+    "    if disagg_topology:\n",
+    "        hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
     "\n",
     "    env_vars = {\n",
     "        \"MODEL_ID\": base_model_id,\n",
diff --git a/notebooks/community/model_garden/model_garden_gemma2_deployment_on_vertex.ipynb b/notebooks/community/model_garden/model_garden_gemma2_deployment_on_vertex.ipynb
index 4e6a9cb0b..c0640f7af 100644
--- a/notebooks/community/model_garden/model_garden_gemma2_deployment_on_vertex.ipynb
+++ b/notebooks/community/model_garden/model_garden_gemma2_deployment_on_vertex.ipynb
@@ -290,6 +290,7 @@
     "    tensor_parallel_size: int = 1,\n",
     "    machine_type: str = \"ct5lp-hightpu-1t\",\n",
     "    tpu_topology: str = \"1x1\",\n",
+    "    disagg_topology: str = None,\n",
     "    hbm_utilization_factor: float = 0.6,\n",
     "    max_running_seqs: int = 256,\n",
     "    max_model_len: int = 4096,\n",
@@ -331,6 +332,8 @@
     "        f\"--max_running_seqs={max_running_seqs}\",\n",
     "        f\"--max_model_len={max_model_len}\",\n",
     "    ]\n",
+    "    if disagg_topology:\n",
+    "        hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
     "\n",
     "    env_vars = {\n",
     "        \"MODEL_ID\": base_model_id,\n",
diff --git a/notebooks/community/model_garden/model_garden_gemma_deployment_on_vertex.ipynb b/notebooks/community/model_garden/model_garden_gemma_deployment_on_vertex.ipynb
index 83cd855fa..256825e6c 100644
--- a/notebooks/community/model_garden/model_garden_gemma_deployment_on_vertex.ipynb
+++ b/notebooks/community/model_garden/model_garden_gemma_deployment_on_vertex.ipynb
@@ -312,6 +312,7 @@
     "    tensor_parallel_size: int = 1,\n",
     "    machine_type: str = \"ct5lp-hightpu-1t\",\n",
     "    tpu_topology: str = \"1x1\",\n",
+    "    disagg_topology: str = None,\n",
     "    hbm_utilization_factor: float = 0.6,\n",
     "    max_running_seqs: int = 256,\n",
     "    max_model_len: int = 4096,\n",
@@ -353,6 +354,8 @@
     "        f\"--max_running_seqs={max_running_seqs}\",\n",
     "        f\"--max_model_len={max_model_len}\",\n",
     "    ]\n",
+    "    if disagg_topology:\n",
+    "        hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
     "\n",
     "    env_vars = {\n",
     "        \"MODEL_ID\": base_model_id,\n",
diff --git a/notebooks/community/model_garden/model_garden_gemma_finetuning_on_vertex.ipynb b/notebooks/community/model_garden/model_garden_gemma_finetuning_on_vertex.ipynb
index 2f4fda0be..190ae8561 100644
--- a/notebooks/community/model_garden/model_garden_gemma_finetuning_on_vertex.ipynb
+++ b/notebooks/community/model_garden/model_garden_gemma_finetuning_on_vertex.ipynb
@@ -1197,6 +1197,7 @@
     "    tensor_parallel_size: int = 1,\n",
     "    machine_type: str = \"ct5lp-hightpu-1t\",\n",
     "    tpu_topology: str = \"1x1\",\n",
+    "    disagg_topology: str = None,\n",
     "    hbm_utilization_factor: float = 0.6,\n",
     "    max_running_seqs: int = 256,\n",
     "    max_model_len: int = 4096,\n",
@@ -1238,6 +1239,8 @@
     "        f\"--max_running_seqs={max_running_seqs}\",\n",
     "        f\"--max_model_len={max_model_len}\",\n",
     "    ]\n",
+    "    if disagg_topology:\n",
+    "        hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
     "\n",
     "    env_vars = {\n",
     "        \"MODEL_ID\": base_model_id,\n",
diff --git a/notebooks/community/model_garden/model_garden_phi3_deployment.ipynb b/notebooks/community/model_garden/model_garden_phi3_deployment.ipynb
index ba38dd6f6..32fcbca4c 100644
--- a/notebooks/community/model_garden/model_garden_phi3_deployment.ipynb
+++ b/notebooks/community/model_garden/model_garden_phi3_deployment.ipynb
@@ -637,6 +637,7 @@
     "    tensor_parallel_size: int = 1,\n",
     "    machine_type: str = \"ct5lp-hightpu-1t\",\n",
     "    tpu_topology: str = \"1x1\",\n",
+    "    disagg_topology: str = None,\n",
     "    hbm_utilization_factor: float = 0.6,\n",
     "    max_running_seqs: int = 256,\n",
     "    max_model_len: int = 4096,\n",
@@ -678,6 +679,8 @@
     "        f\"--max_running_seqs={max_running_seqs}\",\n",
     "        f\"--max_model_len={max_model_len}\",\n",
     "    ]\n",
+    "    if disagg_topology:\n",
+    "        hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
     "\n",
     "    env_vars = {\n",
     "        \"MODEL_ID\": base_model_id,\n",
diff --git a/notebooks/community/model_garden/model_garden_pytorch_llama3_1_deployment.ipynb b/notebooks/community/model_garden/model_garden_pytorch_llama3_1_deployment.ipynb
index 3579f1323..ef616a346 100644
--- a/notebooks/community/model_garden/model_garden_pytorch_llama3_1_deployment.ipynb
+++ b/notebooks/community/model_garden/model_garden_pytorch_llama3_1_deployment.ipynb
@@ -259,12 +259,20 @@
     "\n",
     "# @markdown Find Vertex AI prediction TPUv5e machine types in\n",
     "# @markdown https://cloud.google.com/vertex-ai/docs/predictions/use-tpu#deploy_a_model.\n",
+    "# @markdown The 8B model variant requires 4 TPU v5e cores on a single host, and the 70B model variant requires 16 TPU v5e cores in a 4x4 multi-host topology.\n",
+    "# @markdown Choose `ct5lp-hightpu-4t` for both the 8B and 70B model variants. The multi-host topology will be set automatically based on the model size.\n",
+    "# @markdown Choose `ct5lp-hightpu-8t` for the 8B variant when you want to use the experimental disaggregated serving topology.\n",
     "\n",
     "# Sets ct5lp-hightpu-4t (4 TPU chips) to deploy models.\n",
-    "machine_type = \"ct5lp-hightpu-4t\"\n",
+    "machine_type = \"ct5lp-hightpu-4t\"  # @param [\"ct5lp-hightpu-4t\", \"ct5lp-hightpu-8t\"]\n",
     "# Note: 1 TPU V5 chip has only one core.\n",
     "tpu_type = \"TPU_V5e\"\n",
     "\n",
+    "# @markdown Set the disaggregated topology to balance time to first token (TTFT) and time per output token (TPOT).\n",
+    "# @markdown This is an **experimental** feature and is only supported for single-host deployments.\n",
+    "# @markdown To enable the feature, set this parameter to a string of the form `\"num_prefill_workers,num_decode_workers\"`, such as `\"3,1\"`.\n",
+    "disagg_topo = None  # @param\n",
+    "\n",
     "if \"8B\" in MODEL_ID:\n",
     "    tpu_count = 4\n",
     "    tpu_topo = \"1x4\"\n",
@@ -286,6 +294,8 @@
     "# Server parameters.\n",
     "tensor_parallel_size = tpu_count\n",
     "\n",
+    "# @markdown Set the server parameters.\n",
+    "\n",
     "# Fraction of HBM memory allocated for KV cache after model loading. A larger value improves throughput but gives higher risk of TPU out-of-memory errors with long prompts.\n",
     "hbm_utilization_factor = 0.8  # @param\n",
     "# Maximum number of running sequences in a continuous batch.\n",
@@ -307,6 +317,7 @@
     "    tensor_parallel_size: int = 1,\n",
     "    machine_type: str = \"ct5lp-hightpu-1t\",\n",
     "    tpu_topology: str = \"1x1\",\n",
+    "    disagg_topology: str = None,\n",
     "    hbm_utilization_factor: float = 0.6,\n",
     "    max_running_seqs: int = 256,\n",
     "    max_model_len: int = 4096,\n",
@@ -348,6 +359,8 @@
     "        f\"--max_running_seqs={max_running_seqs}\",\n",
     "        f\"--max_model_len={max_model_len}\",\n",
     "    ]\n",
+    "    if disagg_topology:\n",
+    "        hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
     "\n",
     "    env_vars = {\n",
     "        \"MODEL_ID\": base_model_id,\n",
@@ -396,6 +409,7 @@
     "    tensor_parallel_size=tensor_parallel_size,\n",
     "    machine_type=machine_type,\n",
     "    tpu_topology=tpu_topo,\n",
+    "    disagg_topology=disagg_topo,\n",
     "    hbm_utilization_factor=hbm_utilization_factor,\n",
     "    max_running_seqs=max_running_seqs,\n",
     "    max_model_len=max_model_len,\n",
diff --git a/notebooks/community/model_garden/model_garden_pytorch_llama3_2_deployment.ipynb b/notebooks/community/model_garden/model_garden_pytorch_llama3_2_deployment.ipynb
index 68d7de3d5..fffed472d 100644
--- a/notebooks/community/model_garden/model_garden_pytorch_llama3_2_deployment.ipynb
+++ b/notebooks/community/model_garden/model_garden_pytorch_llama3_2_deployment.ipynb
@@ -314,6 +314,7 @@
     "    tensor_parallel_size: int = 1,\n",
     "    machine_type: str = \"ct5lp-hightpu-1t\",\n",
     "    tpu_topology: str = \"1x1\",\n",
+    "    disagg_topology: str = None,\n",
     "    hbm_utilization_factor: float = 0.6,\n",
     "    max_running_seqs: int = 256,\n",
     "    max_model_len: int = 4096,\n",
@@ -355,6 +356,8 @@
     "        f\"--max_running_seqs={max_running_seqs}\",\n",
     "        f\"--max_model_len={max_model_len}\",\n",
     "    ]\n",
+    "    if disagg_topology:\n",
+    "        hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
     "\n",
     "    env_vars = {\n",
     "        \"MODEL_ID\": base_model_id,\n",
diff --git a/notebooks/community/model_garden/model_garden_pytorch_qwen2_deployment.ipynb b/notebooks/community/model_garden/model_garden_pytorch_qwen2_deployment.ipynb
index bf2c4aabb..3985eb1c0 100644
--- a/notebooks/community/model_garden/model_garden_pytorch_qwen2_deployment.ipynb
+++ b/notebooks/community/model_garden/model_garden_pytorch_qwen2_deployment.ipynb
@@ -565,6 +565,7 @@
     "    tensor_parallel_size: int = 1,\n",
     "    machine_type: str = \"ct5lp-hightpu-1t\",\n",
     "    tpu_topology: str = \"1x1\",\n",
+    "    disagg_topology: str = None,\n",
     "    hbm_utilization_factor: float = 0.6,\n",
     "    max_running_seqs: int = 256,\n",
     "    max_model_len: int = 4096,\n",
@@ -606,6 +607,8 @@
     "        f\"--max_running_seqs={max_running_seqs}\",\n",
     "        f\"--max_model_len={max_model_len}\",\n",
     "    ]\n",
+    "    if disagg_topology:\n",
+    "        hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
     "\n",
     "    env_vars = {\n",
     "        \"MODEL_ID\": base_model_id,\n",
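
Each of the eight notebooks receives the same change: the deployment helper gains a `disagg_topology` parameter and, when it is set, forwards it to the Hex-LLM server as `--disagg_topo`. The sketch below shows that plumbing in isolation, together with an input check a caller might add. It is a minimal illustration, not notebook code: `build_hexllm_args` and its validation are hypothetical, and only the `--disagg_topo` flag and its `"num_prefill_workers,num_decode_workers"` value format come from the patch.

# Minimal sketch, assuming a hypothetical helper name. Only the
# `--disagg_topo` flag and its "num_prefill_workers,num_decode_workers"
# value format are taken from the patch; the validation is an assumption.
from typing import List, Optional


def build_hexllm_args(
    max_running_seqs: int = 256,
    max_model_len: int = 4096,
    disagg_topology: Optional[str] = None,  # e.g. "3,1" = 3 prefill workers, 1 decode worker
) -> List[str]:
    """Builds Hex-LLM server arguments, mirroring the patch's optional flag."""
    args = [
        f"--max_running_seqs={max_running_seqs}",
        f"--max_model_len={max_model_len}",
    ]
    if disagg_topology:
        # Assumed sanity check: two comma-separated positive integers.
        parts = disagg_topology.split(",")
        if len(parts) != 2 or not all(p.strip().isdigit() and int(p) > 0 for p in parts):
            raise ValueError(
                'disagg_topology must be "num_prefill_workers,num_decode_workers", e.g. "3,1"'
            )
        args.append(f"--disagg_topo={disagg_topology}")
    return args


# Disaggregated serving on: the flag is appended after the base arguments.
print(build_hexllm_args(disagg_topology="3,1"))
# ['--max_running_seqs=256', '--max_model_len=4096', '--disagg_topo=3,1']

Leaving the parameter unset (the notebooks' default, `disagg_topo = None`) skips the flag entirely, so existing aggregated deployments behave exactly as before.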