wip
cli99 committed Mar 6, 2024
1 parent 6bacd4f commit 8135689
Showing 7 changed files with 75 additions and 57 deletions.
13 changes: 12 additions & 1 deletion .pre-commit-config.yaml
@@ -26,7 +26,18 @@ repos:
     rev: v0.32.0
     hooks:
       - id: yapf
-        additional_dependencies: [toml]
+        name: yapf
+        description: "A formatter for Python files."
+        entry: yapf
+        args: [-i, -vv, -p]  # inplace
+        language: python
+        types: [python]
+        additional_dependencies:
+          - "toml"
+  - repo: https://github.com/pycqa/isort
+    hooks:
+      - id: isort
+    rev: 5.12.0
   - repo: https://github.com/codespell-project/codespell
     rev: v2.1.0
     hooks:
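The reworked hooks pin yapf 0.32 for formatting and add isort 5.12 for import ordering; the import-style churn in the files below is these two tools being applied. As a rough illustration, both expose programmatic entry points, shown here as a sketch (the hooks themselves run the CLIs, and the sample source string is made up):

import isort
from yapf.yapflib.yapf_api import FormatCode

# A made-up snippet in the old trailing-comma import style.
source = (
    "from llm_analysis.config import (\n"
    "    ParallelismConfig,\n"
    "    get_gpu_config_by_name,\n"
    ")\n"
    "import csv\n"
)

sorted_source = isort.code(source)              # regroup and sort imports
formatted, changed = FormatCode(sorted_source)  # then apply yapf's style
print(formatted)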
12 changes: 5 additions & 7 deletions examples/llama2/run_infer_cursor.py
@@ -1,12 +1,10 @@
-from llm_analysis.config import (
-    ParallelismConfig,
-    get_dtype_config_by_name,
-    get_gpu_config_by_name,
-    get_model_config_by_name,
-)
-from llm_analysis.analysis import LLMAnalysis
 import csv

+from llm_analysis.analysis import LLMAnalysis
+from llm_analysis.config import (ParallelismConfig, get_dtype_config_by_name,
+                                 get_gpu_config_by_name,
+                                 get_model_config_by_name)
+
 gpu_name = "a100-sxm-80gb"
 dtype_name = "w16a16e16"
 model_name = "upstage/Llama-2-70b-instruct-v2"
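For reference, a minimal sketch of how this example script typically wires the names above into an analysis; the ParallelismConfig sizes are assumed here, and the constructor call is abbreviated rather than the script's full argument list:

model_config = get_model_config_by_name(model_name)
gpu_config = get_gpu_config_by_name(gpu_name)
dtype_config = get_dtype_config_by_name(dtype_name)
parallelism_config = ParallelismConfig(tp_size=8, pp_size=1)  # assumed sizes

# LLMAnalysis combines the four configs; the inference/training
# estimators hang off the resulting object.
analysis = LLMAnalysis(model_config, gpu_config, dtype_config,
                       parallelism_config)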
68 changes: 45 additions & 23 deletions llm_analysis/analysis.py
@@ -22,15 +22,10 @@

 import fire

-from llm_analysis.config import (
-    DtypeConfig,
-    GPUConfig,
-    ModelConfig,
-    ParallelismConfig,
-    get_dtype_config_by_name,
-    get_gpu_config_by_name,
-    get_model_config_by_name,
-)
+from llm_analysis.config import (DtypeConfig, GPUConfig, ModelConfig,
+                                 ParallelismConfig, get_dtype_config_by_name,
+                                 get_gpu_config_by_name,
+                                 get_model_config_by_name)
 from llm_analysis.constant import *
 from llm_analysis.logger import logger
 from llm_analysis.utils import _latency_to_string, _num_to_string, within_range
@@ -361,9 +356,11 @@ def get_num_active_params_total(self) -> int:
                 self.get_num_params_last_layernorm())

     def get_weight_memory_per_layer(
-            self,
-            ds_zero: DSZeRO = DSZeRO.NONE,
-            return_breakdown: bool = False) -> Union[float, tuple]:
+        self,
+        is_sharded: bool = False,
+        ds_zero: DSZeRO = DSZeRO.NONE,
+        return_breakdown: bool = False,
+    ) -> Union[float, tuple]:
         """Get the memory (in bytes) required to store the weights of a transformer
         layer, given the number of parameters in a transformer layer, the data type used
         for the weights, the tensor parallelism size, and the DeepSpeed ZeRO stage. With
@@ -375,7 +372,7 @@ def get_weight_memory_per_layer(
         Returns:
             Union[float, tuple]: the memory (in bytes) required to store the weights of a transformer layer, or a tuple of its breakdown
         """
-        if ds_zero == DSZeRO.STAGE_3:
+        if is_sharded and ds_zero == DSZeRO.STAGE_3:
             sharded_dp_size = self.parallelism_config.dp_size
             mlp_sharded_dp_size = self.parallelism_config.dp_size / self.parallelism_config.ep_size
         else:
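Here is_sharded gates whether the per-layer weight bytes are divided by the data-parallel degree at all. A back-of-envelope sketch of that arithmetic with made-up sizes (the real method pulls the parameter count and dtype bytes from the configs, and shards the attention and MLP portions separately):

# Hypothetical numbers, only to illustrate the sharding arithmetic.
num_params_per_layer = 8.0e8   # parameters in one transformer layer
dtype_bytes = 2                # w16: two bytes per weight
tp_size, dp_size, ep_size = 2, 8, 1

def weight_memory_per_layer(is_sharded: bool, zero3: bool) -> float:
    memory = num_params_per_layer / tp_size * dtype_bytes  # TP always splits
    if is_sharded and zero3:
        memory /= dp_size / ep_size  # ZeRO-3 shards across (dp/ep) ranks
    return memory

print(weight_memory_per_layer(is_sharded=True, zero3=True))   # 1.0e8 bytes
print(weight_memory_per_layer(is_sharded=False, zero3=True))  # 8.0e8 bytes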
@@ -530,6 +527,7 @@ def get_memory_optimizer_state_and_gradient_last_layernorm(
     def get_memory_embedding(
         self,
         ds_zero: DSZeRO = DSZeRO.NONE,
+        is_sharded: bool = True,
     ) -> float:
         """Get the memory (in bytes) required to store the embedding layer, given the
         number of parameters in the embedding layer, the data type (defaults to FP32)
@@ -545,6 +543,8 @@ def get_memory_embedding(
         dtype_bytes = self.dtype_config.embedding_bits / BITS_PER_BYTE
         memory_embedding = (self.get_num_params_embedding() /
                             self.parallelism_config.tp_size) * dtype_bytes
+        if not is_sharded:
+            return memory_embedding
         if ds_zero == DSZeRO.STAGE_3:
             memory_embedding /= self.parallelism_config.dp_size

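The embedding follows the same pattern: tensor parallelism always divides it, while ZeRO-3 only divides it in the sharded view, so is_sharded=False returns the fully gathered footprint. A numeric sketch with assumed Llama-2-70B-style dimensions (the method itself reads them from ModelConfig and DtypeConfig):

# Assumed dimensions, for illustration only.
vocab_size, hidden_dim = 32000, 8192
embedding_bytes = 2
tp_size, dp_size = 2, 8

memory_embedding = vocab_size * hidden_dim / tp_size * embedding_bytes
sharded = memory_embedding / dp_size  # ZeRO-3, is_sharded=True
gathered = memory_embedding           # is_sharded=False
print(f"{sharded / 2**20:.1f} MiB sharded vs {gathered / 2**20:.1f} MiB gathered")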
@@ -1521,7 +1521,6 @@ def output_summary_dict(
         log_str = self.get_readable_summary_dict(summary_dict)
         file_name = self.get_configs_desc(
         ) + output_file_suffix + "-summary-readable.txt"
-        file_name = output_file_suffix + "-summary-readable.txt"
         with open(os.path.join(output_dir, file_name), "w") as f:
             f.write(log_str)
         logger.info(
@@ -2018,17 +2017,20 @@ def training(
                 "num_layers not be divisible by pp_size, taking the floor")

         weight_memory_embedding_per_gpu = self.get_memory_embedding(ds_zero)
+        unsharded_weight_memory_embedding = self.get_memory_embedding(
+            ds_zero, is_sharded=False)

         weight_memory_layers_per_gpu, weight_memory_attn_per_gpu, weight_memory_mlp_per_gpu, weight_memory_layernorm_per_gpu = [
-            x * num_layers_per_gpu
-            for x in self.get_weight_memory_per_layer(ds_zero,
-                                                      return_breakdown=True)
+            x * num_layers_per_gpu for x in self.get_weight_memory_per_layer(
+                is_sharded=True, ds_zero=ds_zero, return_breakdown=True)
         ]
         weight_memory_last_layernorm = self.get_weight_memory_last_layernorm(
             ds_zero)
         weight_memory_per_gpu = (weight_memory_embedding_per_gpu +
                                  weight_memory_layers_per_gpu +
                                  weight_memory_last_layernorm)
+        unsharded_weight_memory_per_layer, unsharded_weight_memory_attn_per_layer, unsharded_weight_memory_mlp_per_layer, unsharded_weight_memory_layernorm = self.get_weight_memory_per_layer(
+            is_sharded=False, ds_zero=ds_zero, return_breakdown=True)

         optimizer_state_memory_per_layer, gradient_memory_per_layer = self.get_memory_optimizer_state_and_gradient_per_layer(
             master_weights_dtype_bytes, other_op_bytes, ds_zero)
@@ -2042,8 +2044,9 @@ def training(
         optimizer_state_memory_per_gpu = optimizer_state_memory_per_layer * num_layers_per_gpu + optimizer_state_memory_embedding + optimizer_state_memory_last_layernorm
         gradient_memory_per_gpu = gradient_memory_per_layer * num_layers_per_gpu + gradient_memory_embedding + gradient_memory_last_layernorm

-        self.weight_grad_op_state_memory_per_gpu = weight_memory_per_gpu + gradient_memory_per_gpu + optimizer_state_memory_per_gpu
-
+        self.weight_grad_op_state_memory_per_gpu = (
+            weight_memory_per_gpu + optimizer_state_memory_per_gpu +
+            gradient_memory_per_gpu)
         memory_left = (self.gpu_config.mem_per_GPU_in_GB * 1024**3 -
                        self.weight_grad_op_state_memory_per_gpu)

@@ -2349,12 +2352,22 @@ def training(
             weight_memory_mlp_per_gpu,
             "weight_memory_layernorm_per_gpu":
             weight_memory_layernorm_per_gpu,
+            "unsharded_weight_memory_embedding":
+            unsharded_weight_memory_embedding,
+            "unsharded_weight_memory_per_layer":
+            unsharded_weight_memory_per_layer,
+            "unsharded_weight_memory_attn_per_layer":
+            unsharded_weight_memory_attn_per_layer,
+            "unsharded_weight_memory_mlp_per_layer":
+            unsharded_weight_memory_mlp_per_layer,
+            "unsharded_weight_memory_layernorm":
+            unsharded_weight_memory_layernorm,
             "gradient_memory_per_gpu":
             gradient_memory_per_gpu,
             "optimizer_state_memory_per_gpu":
             optimizer_state_memory_per_gpu,
-            "(weight+op_state+grad)_memory_per_gpu":
-            self.weight_grad_op_state_memory_per_gpu,
+            "(weight+op_state)_memory_per_gpu":
+            optimizer_state_memory_per_gpu + weight_memory_per_gpu,
             "activation_memory_batch_size_1":
             activation_memory_batch_size_1,
             "activation_memory_per_gpu":
@@ -2367,9 +2380,18 @@ def training(
             activation_memory_layernorm_per_gpu,
             "activation_memory_embedding_output_per_gpu":
             activation_memory_embedding_output_per_gpu,
-            "(weight+op_state+grad+act)_memory_per_gpu":
-            self.weight_grad_op_state_memory_per_gpu +
+            "(weight+op_state+act)_memory_per_gpu":
+            optimizer_state_memory_per_gpu + weight_memory_per_gpu +
             activation_memory_per_gpu,
+            "end_fwd_memory_per_gpu":
+            optimizer_state_memory_per_gpu + weight_memory_per_gpu +
+            activation_memory_per_gpu + unsharded_weight_memory_embedding +
+            unsharded_weight_memory_per_layer,
+            "end_last3_bwd_memory_per_gpu":
+            optimizer_state_memory_per_gpu + weight_memory_per_gpu +
+            activation_memory_per_gpu + 3 *
+            (unsharded_weight_memory_per_layer +
+             unsharded_weight_memory_mlp_per_layer),
             "memory_left_per_gpu":
             memory_left,
             "latency_per_micro_batch":
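The two new summary entries are rough transient peaks under sharded training: at the end of the forward pass, the gathered (unsharded) embedding plus one gathered layer still sit next to the steady-state memory, and near the end of the backward pass up to three layers' gathered weights can coexist. A sketch of that bookkeeping with placeholder byte counts (the analysis derives the real values from the configs):

# Placeholder byte counts, for illustration only.
optimizer_state_memory_per_gpu = 40e9
weight_memory_per_gpu = 20e9           # sharded, steady-state view
activation_memory_per_gpu = 10e9
unsharded_weight_memory_embedding = 0.5e9
unsharded_weight_memory_per_layer = 1.6e9
unsharded_weight_memory_mlp_per_layer = 1.1e9

base = (optimizer_state_memory_per_gpu + weight_memory_per_gpu +
        activation_memory_per_gpu)
end_fwd = (base + unsharded_weight_memory_embedding +
           unsharded_weight_memory_per_layer)
end_last3_bwd = base + 3 * (unsharded_weight_memory_per_layer +
                            unsharded_weight_memory_mlp_per_layer)
print(f"end_fwd={end_fwd / 1e9:.1f} GB, "
      f"end_last3_bwd={end_last3_bwd / 1e9:.1f} GB")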
7 changes: 2 additions & 5 deletions llm_analysis/config.py
@@ -21,11 +21,8 @@

 import fire

-from llm_analysis.constant import (
-    DTYPE_CONFIG_DIR_NAME,
-    GPU_CONFIG_DIR_NAME,
-    MODEL_CONFIG_DIR_NAME,
-)
+from llm_analysis.constant import (DTYPE_CONFIG_DIR_NAME, GPU_CONFIG_DIR_NAME,
+                                   MODEL_CONFIG_DIR_NAME)
 from llm_analysis.logger import logger

 try:
12 changes: 4 additions & 8 deletions tests/test_config.py
@@ -12,14 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from llm_analysis.config import (
-    ModelConfig,
-    GPUConfig,
-    DtypeConfig,
-    get_dtype_config_by_name,
-    get_gpu_config_by_name,
-    get_model_config_by_name,
-)
+from llm_analysis.config import (DtypeConfig, GPUConfig, ModelConfig,
+                                 get_dtype_config_by_name,
+                                 get_gpu_config_by_name,
+                                 get_model_config_by_name)


 def test_get_model_config_by_name():
11 changes: 4 additions & 7 deletions tests/test_inference.py
@@ -12,14 +12,11 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from llm_analysis.utils import within_range
 from llm_analysis.analysis import LLMAnalysis
-from llm_analysis.config import (
-    ParallelismConfig,
-    get_dtype_config_by_name,
-    get_gpu_config_by_name,
-    get_model_config_by_name,
-)
+from llm_analysis.config import (ParallelismConfig, get_dtype_config_by_name,
+                                 get_gpu_config_by_name,
+                                 get_model_config_by_name)
+from llm_analysis.utils import within_range

 TOLERANCE = 0.1

9 changes: 3 additions & 6 deletions tests/test_training.py
@@ -13,12 +13,9 @@
 # limitations under the License.

 from llm_analysis.analysis import ActivationRecomputation, DSZeRO, LLMAnalysis
-from llm_analysis.config import (
-    ParallelismConfig,
-    get_dtype_config_by_name,
-    get_gpu_config_by_name,
-    get_model_config_by_name,
-)
+from llm_analysis.config import (ParallelismConfig, get_dtype_config_by_name,
+                                 get_gpu_config_by_name,
+                                 get_model_config_by_name)
 from llm_analysis.utils import _latency_to_string, _num_to_string, within_range

 TOLERANCE = 0.05
