Hex-LLM supports disaggregated serving as an experimental feature
PiperOrigin-RevId: 707243638
vertex-mg-bot authored and copybara-github committed Dec 18, 2024
1 parent dae9e79 commit b479ba3
Showing 8 changed files with 36 additions and 1 deletion.

File 1 of 8:
@@ -323,6 +323,7 @@
 " tensor_parallel_size: int = 1,\n",
 " machine_type: str = \"ct5lp-hightpu-1t\",\n",
 " tpu_topology: str = \"1x1\",\n",
+" disagg_topology: str = None,\n",
 " hbm_utilization_factor: float = 0.6,\n",
 " max_running_seqs: int = 256,\n",
 " max_model_len: int = 4096,\n",
@@ -364,6 +365,8 @@
 " f\"--max_running_seqs={max_running_seqs}\",\n",
 " f\"--max_model_len={max_model_len}\",\n",
 " ]\n",
+" if disagg_topology:\n",
+" hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
 "\n",
 " env_vars = {\n",
 " \"MODEL_ID\": base_model_id,\n",
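
The same three-line change recurs in each notebook in this commit: the deployment helper gains a disagg_topology parameter and forwards it to the server as --disagg_topo. De-escaped from the notebook JSON, the change reads as the sketch below; the enclosing helper's name and the rest of its signature and body are not shown in this commit view, so deploy_model and the elided flags are placeholders.

# A sketch of the repeated change, de-escaped from the notebook JSON.
# `deploy_model` is a placeholder name; the commit shows only part of the
# real helper's signature and argument list.
def deploy_model(
    tensor_parallel_size: int = 1,
    machine_type: str = "ct5lp-hightpu-1t",
    tpu_topology: str = "1x1",
    disagg_topology: str = None,  # new in this commit
    hbm_utilization_factor: float = 0.6,
    max_running_seqs: int = 256,
    max_model_len: int = 4096,
):
    hexllm_args = [
        # ... earlier server flags elided in this commit view ...
        f"--max_running_seqs={max_running_seqs}",
        f"--max_model_len={max_model_len}",
    ]
    # New: forward the disaggregated topology only when set, so callers
    # that leave it as None get the existing (aggregated) behavior.
    if disagg_topology:
        hexllm_args.append(f"--disagg_topo={disagg_topology}")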

File 2 of 8:
@@ -290,6 +290,7 @@
 " tensor_parallel_size: int = 1,\n",
 " machine_type: str = \"ct5lp-hightpu-1t\",\n",
 " tpu_topology: str = \"1x1\",\n",
+" disagg_topology: str = None,\n",
 " hbm_utilization_factor: float = 0.6,\n",
 " max_running_seqs: int = 256,\n",
 " max_model_len: int = 4096,\n",
@@ -331,6 +332,8 @@
 " f\"--max_running_seqs={max_running_seqs}\",\n",
 " f\"--max_model_len={max_model_len}\",\n",
 " ]\n",
+" if disagg_topology:\n",
+" hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
 "\n",
 " env_vars = {\n",
 " \"MODEL_ID\": base_model_id,\n",

File 3 of 8:
@@ -312,6 +312,7 @@
 " tensor_parallel_size: int = 1,\n",
 " machine_type: str = \"ct5lp-hightpu-1t\",\n",
 " tpu_topology: str = \"1x1\",\n",
+" disagg_topology: str = None,\n",
 " hbm_utilization_factor: float = 0.6,\n",
 " max_running_seqs: int = 256,\n",
 " max_model_len: int = 4096,\n",
@@ -353,6 +354,8 @@
 " f\"--max_running_seqs={max_running_seqs}\",\n",
 " f\"--max_model_len={max_model_len}\",\n",
 " ]\n",
+" if disagg_topology:\n",
+" hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
 "\n",
 " env_vars = {\n",
 " \"MODEL_ID\": base_model_id,\n",

File 4 of 8:
@@ -1197,6 +1197,7 @@
 " tensor_parallel_size: int = 1,\n",
 " machine_type: str = \"ct5lp-hightpu-1t\",\n",
 " tpu_topology: str = \"1x1\",\n",
+" disagg_topology: str = None,\n",
 " hbm_utilization_factor: float = 0.6,\n",
 " max_running_seqs: int = 256,\n",
 " max_model_len: int = 4096,\n",
@@ -1238,6 +1239,8 @@
 " f\"--max_running_seqs={max_running_seqs}\",\n",
 " f\"--max_model_len={max_model_len}\",\n",
 " ]\n",
+" if disagg_topology:\n",
+" hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
 "\n",
 " env_vars = {\n",
 " \"MODEL_ID\": base_model_id,\n",

File 5 of 8:
@@ -637,6 +637,7 @@
 " tensor_parallel_size: int = 1,\n",
 " machine_type: str = \"ct5lp-hightpu-1t\",\n",
 " tpu_topology: str = \"1x1\",\n",
+" disagg_topology: str = None,\n",
 " hbm_utilization_factor: float = 0.6,\n",
 " max_running_seqs: int = 256,\n",
 " max_model_len: int = 4096,\n",
@@ -678,6 +679,8 @@
 " f\"--max_running_seqs={max_running_seqs}\",\n",
 " f\"--max_model_len={max_model_len}\",\n",
 " ]\n",
+" if disagg_topology:\n",
+" hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
 "\n",
 " env_vars = {\n",
 " \"MODEL_ID\": base_model_id,\n",

File 6 of 8:
@@ -259,12 +259,20 @@
 "\n",
 "# @markdown Find Vertex AI prediction TPUv5e machine types in\n",
 "# @markdown https://cloud.google.com/vertex-ai/docs/predictions/use-tpu#deploy_a_model.\n",
+"# @markdown The 8B model variant requires 4 TPU v5e cores on a single host, and the 70B model variant requires 16 TPU v5e cores in a 4x4 multi-host topology.\n",
+"# @markdown Choose `ct5lp-hightpu-4t` for both the 8B and 70B model variants. The multi-host topology will be automatically set based on the model size.\n",
+"# @markdown Choose `ct5lp-hightpu-8t` for the 8B variant when you want to use the experimental disaggregated serving topology.\n",
 "\n",
 "# Sets ct5lp-hightpu-4t (4 TPU chips) to deploy models.\n",
-"machine_type = \"ct5lp-hightpu-4t\"\n",
+"machine_type = \"ct5lp-hightpu-4t\" # @param [\"ct5lp-hightpu-4t\", \"ct5lp-hightpu-8t\"]\n",
 "# Note: 1 TPU V5 chip has only one core.\n",
 "tpu_type = \"TPU_V5e\"\n",
 "\n",
+"# @markdown Set the disaggregated topology to balance TTFT and TPOT.\n",
+"# @markdown This is an **experimental** feature and is only supported for single-host deployments.\n",
+"# @markdown To enable the feature, set this parameter to a string of the form `\"num_prefill_workers,num_decode_workers\"`, such as `\"3,1\"`.\n",
+"disagg_topo = None # @param\n",
+"\n",
 "if \"8B\" in MODEL_ID:\n",
 " tpu_count = 4\n",
 " tpu_topo = \"1x4\"\n",
@@ -286,6 +294,8 @@
 "# Server parameters.\n",
 "tensor_parallel_size = tpu_count\n",
 "\n",
+"# @markdown Set the server parameters.\n",
+"\n",
 "# Fraction of HBM memory allocated for KV cache after model loading. A larger value improves throughput but gives higher risk of TPU out-of-memory errors with long prompts.\n",
 "hbm_utilization_factor = 0.8 # @param\n",
 "# Maximum number of running sequences in a continuous batch.\n",
@@ -307,6 +317,7 @@
 " tensor_parallel_size: int = 1,\n",
 " machine_type: str = \"ct5lp-hightpu-1t\",\n",
 " tpu_topology: str = \"1x1\",\n",
+" disagg_topology: str = None,\n",
 " hbm_utilization_factor: float = 0.6,\n",
 " max_running_seqs: int = 256,\n",
 " max_model_len: int = 4096,\n",
@@ -348,6 +359,8 @@
 " f\"--max_running_seqs={max_running_seqs}\",\n",
 " f\"--max_model_len={max_model_len}\",\n",
 " ]\n",
+" if disagg_topology:\n",
+" hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
 "\n",
 " env_vars = {\n",
 " \"MODEL_ID\": base_model_id,\n",
@@ -396,6 +409,7 @@
 " tensor_parallel_size=tensor_parallel_size,\n",
 " machine_type=machine_type,\n",
 " tpu_topology=tpu_topo,\n",
+" disagg_topology=disagg_topo,\n",
 " hbm_utilization_factor=hbm_utilization_factor,\n",
 " max_running_seqs=max_running_seqs,\n",
 " max_model_len=max_model_len,\n",
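
A usage sketch of the new knob, with illustrative values: "3,1" follows the "num_prefill_workers,num_decode_workers" form described above, and ct5lp-hightpu-8t is the single-host machine type the notebook offers for this feature.

# Illustrative configuration for the experimental disaggregated topology.
machine_type = "ct5lp-hightpu-8t"  # single host with 8 TPU v5e chips
disagg_topo = "3,1"                # 3 prefill workers, 1 decode worker

# With disagg_topo set, the deployment helper appends --disagg_topo=3,1
# to the Hex-LLM server arguments; leaving it as None keeps the default
# aggregated serving path.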

File 7 of 8:
@@ -307,6 +307,7 @@
 " tensor_parallel_size: int = 1,\n",
 " machine_type: str = \"ct5lp-hightpu-1t\",\n",
 " tpu_topology: str = \"1x1\",\n",
+" disagg_topology: str = None,\n",
 " hbm_utilization_factor: float = 0.6,\n",
 " max_running_seqs: int = 256,\n",
 " max_model_len: int = 4096,\n",
@@ -348,6 +349,8 @@
 " f\"--max_running_seqs={max_running_seqs}\",\n",
 " f\"--max_model_len={max_model_len}\",\n",
 " ]\n",
+" if disagg_topology:\n",
+" hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
 "\n",
 " env_vars = {\n",
 " \"MODEL_ID\": base_model_id,\n",

File 8 of 8:
@@ -565,6 +565,7 @@
 " tensor_parallel_size: int = 1,\n",
 " machine_type: str = \"ct5lp-hightpu-1t\",\n",
 " tpu_topology: str = \"1x1\",\n",
+" disagg_topology: str = None,\n",
 " hbm_utilization_factor: float = 0.6,\n",
 " max_running_seqs: int = 256,\n",
 " max_model_len: int = 4096,\n",
@@ -606,6 +607,8 @@
 " f\"--max_running_seqs={max_running_seqs}\",\n",
 " f\"--max_model_len={max_model_len}\",\n",
 " ]\n",
+" if disagg_topology:\n",
+" hexllm_args.append(f\"--disagg_topo={disagg_topology}\")\n",
 "\n",
 " env_vars = {\n",
 " \"MODEL_ID\": base_model_id,\n",
