Commit 2b92db3

format

cli99 committed May 22, 2024
1 parent 65cc34e commit 2b92db3
Showing 3 changed files with 37 additions and 21 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -18,7 +18,7 @@ repos:
   - id: fix-encoding-pragma
     args: [--remove]
 - repo: https://github.com/pycqa/flake8
-  rev: 4.0.1
+  rev: 7.0.0
   hooks:
   - id: flake8
     args: ["--config=.flake8"]
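(A hook-rev bump like this one is the kind of change typically generated with `pre-commit autoupdate` and then verified with `pre-commit run flake8 --all-files`; both are standard pre-commit commands, though the commit itself does not record how the bump was produced.)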
2 changes: 1 addition & 1 deletion llm_analysis/__init__.py
@@ -10,4 +10,4 @@
 # distributed under the License is distributed on an "AS IS" BASIS,
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
-# limitations under the License.
\ No newline at end of file
+# limitations under the License.
54 changes: 35 additions & 19 deletions llm_analysis/analysis.py
@@ -756,7 +756,7 @@ def get_activation_memory_per_layernorm(
             self.parallelism_config.sp_size) * dtype_bytes

     def get_activation_memory_input_embedding(self, batch_size: int,
-                                             seq_len: int) -> float:
+                                              seq_len: int) -> float:
         """Get the memory (in bytes) required to store the activations of the input embedding"""
         return self.model_config.hidden_dim * batch_size * seq_len * self.dtype_config.activation_bits / BITS_PER_BYTE / self.parallelism_config.tp_size
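For intuition, the term above is just hidden_dim x batch_size x seq_len x bytes-per-activation, sharded over tensor parallelism. A minimal standalone restatement of the same arithmetic (all concrete numbers are hypothetical, chosen only for illustration):

    # Standalone sketch of get_activation_memory_input_embedding's formula.
    BITS_PER_BYTE = 8

    def activation_memory_input_embedding(hidden_dim, batch_size, seq_len,
                                          activation_bits, tp_size):
        # bytes to hold the embedding-layer output, split across tp_size GPUs
        return (hidden_dim * batch_size * seq_len * activation_bits
                / BITS_PER_BYTE / tp_size)

    # hypothetical: hidden_dim=4096, batch=4, seq_len=2048, fp16, tp=2
    print(activation_memory_input_embedding(4096, 4, 2048, 16, 2) / 1024**2)
    # -> 32.0 (MiB)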

@@ -1182,7 +1182,8 @@ def get_latency_fwd_per_layernorm(
             float: the latency in seconds for the forward pass of a single layernorm in a transformer layer
         """
         input_numel = seq_len * batch_size * self.model_config.hidden_dim
-        compute_latency = input_numel * 5 / (self.get_TFLOPS_per_gpu() * 10**12)
+        compute_latency = input_numel * 5 / (self.get_TFLOPS_per_gpu() *
+                                             10**12)
         activation_memory = self.get_activation_memory_per_layernorm(
             batch_size,
             seq_len,
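The compute term above charges roughly 5 FLOPs per input element; in the full method this compute bound is weighed against the time to move the layernorm activations through device memory. A roofline-style sketch under that interpretation (the max-of-two-bounds form and all hardware numbers are assumptions for illustration, not values from this repo):

    # Roofline-style layernorm forward-latency estimate (illustrative only).
    def layernorm_fwd_latency(seq_len, batch_size, hidden_dim,
                              tflops_per_gpu, hbm_bw_GBps, bytes_per_elem=2):
        input_numel = seq_len * batch_size * hidden_dim
        compute_latency = input_numel * 5 / (tflops_per_gpu * 10**12)
        # one read + one write of the activations: a simple memory-bound term
        memory_latency = 2 * input_numel * bytes_per_elem / (hbm_bw_GBps * 10**9)
        return max(compute_latency, memory_latency)

    # e.g. A100-like numbers: 312 TFLOPS, ~2039 GB/s HBM -> memory-bound here
    print(layernorm_fwd_latency(2048, 4, 4096, 312, 2039))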
@@ -2069,8 +2070,9 @@ def training(
             3 + int(fwd_prefetch) +
             int(bwd_prefetch)) * (unsharded_weight_memory_per_layer)

-        estimated_prefetch_memory_per_gpu = max(estimated_fwd_prefetch_memory_per_gpu,
-                                                estimated_bwd_prefetch_memory_per_gpu)
+        estimated_prefetch_memory_per_gpu = max(
+            estimated_fwd_prefetch_memory_per_gpu,
+            estimated_bwd_prefetch_memory_per_gpu)

         memory_left = (self.gpu_config.mem_per_GPU_in_GB * 1024**3 -
                        weight_memory_per_gpu - optimizer_state_memory_per_gpu)
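This step reserves headroom for whichever prefetch direction is larger, then measures what remains after the always-resident weights and optimizer state. A toy walkthrough with made-up sizes:

    GiB = 1024**3
    mem_per_gpu = 80 * GiB                      # hypothetical 80 GB device
    weight_memory_per_gpu = 13 * GiB
    optimizer_state_memory_per_gpu = 26 * GiB
    fwd_prefetch, bwd_prefetch = 2 * GiB, 3 * GiB

    estimated_prefetch_memory_per_gpu = max(fwd_prefetch, bwd_prefetch)
    memory_left = (mem_per_gpu - weight_memory_per_gpu
                   - optimizer_state_memory_per_gpu)
    print(memory_left / GiB, estimated_prefetch_memory_per_gpu / GiB)  # 41.0 3.0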
@@ -2079,22 +2081,20 @@ def training(
             f"weight_memory_per_gpu: {_num_to_string(weight_memory_per_gpu)}B (embedding_memory: {_num_to_string(weight_memory_embedding_per_gpu)}B), optimizer_state_memory_per_gpu: {_num_to_string(optimizer_state_memory_per_gpu)}B, gradient_memory_per_gpu: {_num_to_string(gradient_memory_per_gpu)}B, estimated_fwd_prefetch_memory_per_gpu: {_num_to_string(estimated_fwd_prefetch_memory_per_gpu)}B, estimated_bwd_prefetch_memory_per_gpu: {_num_to_string(estimated_bwd_prefetch_memory_per_gpu)}B"
         )

-
         if memory_left < 0:
             logger.warning(
                 "model weight/optimizer state memory usage is too large to fit in GPU memory"
             )

-        if memory_left - max(estimated_prefetch_memory_per_gpu, gradient_memory_per_gpu) < 0:
+        if memory_left - max(estimated_prefetch_memory_per_gpu,
+                             gradient_memory_per_gpu) < 0:
             logger.warning(
                 "model gradient or bwd prefetch memory usage is too large to fit in GPU memory"
             )

         loss_bwd_memory_batch_size_1 = self.get_loss_bwd_memory(1, seq_len)
         if memory_left - loss_bwd_memory_batch_size_1 < 0:
-            logger.warning(
-                "loss_bwd_memory is too large to fit in GPU memory"
-            )
+            logger.warning("loss_bwd_memory is too large to fit in GPU memory")

         # With pipeline parallelism, each stage contains L/p layers so the first stage must store p × L/p = L layers worth of activations regardless of the pipeline parallel size p; activation memory required for the input embeddings, the last layer-norm, and the output layer are ignored here. Refer to https://arxiv.org/abs/2205.05198 for more details.
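The comment's claim is worth spelling out: with pipeline parallelism the first stage keeps activations for up to pp_size in-flight micro-batches, each covering num_layers / pp_size layers, so the product stays at num_layers no matter how many stages there are. A two-line check (the layer count is arbitrary):

    num_layers = 32
    for pp_size in (1, 2, 4, 8):
        # worst case: pp_size micro-batches in flight x (num_layers / pp_size) layers each
        print(pp_size, pp_size * (num_layers / pp_size))  # always 32.0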

@@ -2121,7 +2121,8 @@ def training(
                 layernorm_activation_memory_per_layer_batch_size_1)
         ]

-        activation_memory_input_embedding_batch_size_1 = self.get_activation_memory_input_embedding(1, seq_len)
+        activation_memory_input_embedding_batch_size_1 = self.get_activation_memory_input_embedding(
+            1, seq_len)
         activation_memory_batch_size_1 += activation_memory_input_embedding_batch_size_1
         activation_memory_output_embedding_batch_size_1 = self.get_activation_memory_output_embedding(
             1, seq_len)
@@ -2132,7 +2133,9 @@ def training(
             layernorm_dtype_bytes,
         )

-        if memory_left - max(estimated_prefetch_memory_per_gpu, loss_bwd_memory_batch_size_1) < activation_memory_batch_size_1:
+        if memory_left - max(
+                estimated_prefetch_memory_per_gpu,
+                loss_bwd_memory_batch_size_1) < activation_memory_batch_size_1:
             logger.warning(
                 f"memory_left {_num_to_string(memory_left)} < activation_memory_batch_size_1 {_num_to_string(activation_memory_batch_size_1)}"
             )
@@ -2143,7 +2146,10 @@ def training(

         max_batch_size_per_gpu = int(memory_left //
                                      activation_memory_batch_size_1)
-        while memory_left < max(estimated_prefetch_memory_per_gpu, self.get_loss_bwd_memory(max_batch_size_per_gpu, seq_len)) + activation_memory_batch_size_1 * max_batch_size_per_gpu:
+        while memory_left < max(
+                estimated_prefetch_memory_per_gpu,
+                self.get_loss_bwd_memory(max_batch_size_per_gpu, seq_len)
+        ) + activation_memory_batch_size_1 * max_batch_size_per_gpu:
             max_batch_size_per_gpu -= 1

         logger.info(
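The search above starts from a floor estimate and then decrements, because the loss-backward scratch memory itself grows with batch size, so the floor can overshoot. A standalone sketch of the same search (the memory function and every number are stand-ins, not this repo's):

    GiB = 1024**3
    memory_left = 40 * GiB
    activation_memory_batch_size_1 = 3 * GiB
    prefetch_memory = 2 * GiB

    def loss_bwd_memory(batch_size):
        return batch_size * 0.5 * GiB  # hypothetical: grows linearly with batch

    max_batch_size_per_gpu = int(memory_left // activation_memory_batch_size_1)
    while memory_left < (max(prefetch_memory,
                             loss_bwd_memory(max_batch_size_per_gpu))
                         + activation_memory_batch_size_1 * max_batch_size_per_gpu):
        max_batch_size_per_gpu -= 1
    print(max_batch_size_per_gpu)  # floor estimate 13 walks down to 11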
@@ -2201,7 +2207,8 @@ def training(

         loss_bwd_memory = self.get_loss_bwd_memory(batch_size_per_gpu, seq_len)

-        if memory_left < activation_memory_per_gpu + max(estimated_prefetch_memory_per_gpu, loss_bwd_memory):
+        if memory_left < activation_memory_per_gpu + max(
+                estimated_prefetch_memory_per_gpu, loss_bwd_memory):
             logger.warning(
                 "activation_memory_per_gpu memory or loss_bwd_memory is too large with batch_size_per_gpu ="
                 f" {batch_size_per_gpu} to fit in GPU memory (requiring"
@@ -2210,7 +2217,8 @@ def training(
                 f" {_num_to_string(memory_left)}B, max_batch_size_per_gpu ="
                 f" {max_batch_size_per_gpu})")

-        memory_left = memory_left - activation_memory_per_gpu - max(estimated_prefetch_memory_per_gpu, loss_bwd_memory)
+        memory_left = memory_left - activation_memory_per_gpu - max(
+            estimated_prefetch_memory_per_gpu, loss_bwd_memory)

         num_flops_fwd_total = self.get_num_flops_fwd_total(
             batch_size_per_gpu, seq_len)
@@ -2271,11 +2279,16 @@ def training(
         num_layers_per_gpu = int(self.model_config.num_layers /
                                  self.parallelism_config.pp_size)
         if activation_recomputation == ActivationRecomputation.FULL:
-            latency_recompute = num_layers_per_gpu * (latency_fwd_per_layer_attn_compute + latency_fwd_per_layer_mlp_compute + 2 * latency_fwd_per_layernorm_compute)
+            latency_recompute = num_layers_per_gpu * (
+                latency_fwd_per_layer_attn_compute +
+                latency_fwd_per_layer_mlp_compute +
+                2 * latency_fwd_per_layernorm_compute)
         elif activation_recomputation == ActivationRecomputation.NORM_ATTN_NORM:
-            latency_recompute = num_layers_per_gpu * (latency_fwd_per_layer_attn_compute + 2 * latency_fwd_per_layernorm_compute)
+            latency_recompute = num_layers_per_gpu * (
+                latency_fwd_per_layer_attn_compute +
+                2 * latency_fwd_per_layernorm_compute)
         elif activation_recomputation == ActivationRecomputation.ATTN:
-            latency_recompute = num_layers_per_gpu * latency_fwd_per_layer_attn_compute
+            latency_recompute = num_layers_per_gpu * latency_fwd_per_layer_attn_compute
         elif activation_recomputation == ActivationRecomputation.ATTN_COMPUTE:
             latency_recompute = num_layers_per_gpu * self.get_num_flops_total_attn_compute(
                 batch_size_per_gpu, seq_len) / (
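Each recomputation level pays back the forward compute of whatever it discarded: FULL re-runs attention, MLP, and both layernorms per layer; NORM_ATTN_NORM re-runs attention plus the two norms; ATTN re-runs attention only. A condensed sketch of that dispatch (the enum values and per-op latencies here are stand-ins, not this repo's definitions):

    from enum import Enum

    class ActivationRecomputation(Enum):
        NONE = 0
        ATTN = 1
        NORM_ATTN_NORM = 2
        FULL = 3

    def recompute_latency(level, num_layers_per_gpu, attn, mlp, norm):
        # attn/mlp/norm: per-layer forward latencies in seconds
        if level is ActivationRecomputation.FULL:
            return num_layers_per_gpu * (attn + mlp + 2 * norm)
        if level is ActivationRecomputation.NORM_ATTN_NORM:
            return num_layers_per_gpu * (attn + 2 * norm)
        if level is ActivationRecomputation.ATTN:
            return num_layers_per_gpu * attn
        return 0.0

    # hypothetical per-layer latencies: attn 0.1 ms, mlp 0.2 ms, norm 0.01 ms
    print(recompute_latency(ActivationRecomputation.FULL, 32, 1e-4, 2e-4, 1e-5))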
@@ -2448,14 +2461,17 @@ def training(
             "(weight+op_state+grad)_memory_per_gpu":
             self.weight_grad_op_state_memory_per_gpu,
             "estimated_peak_memory_per_gpu":
-            optimizer_state_memory_per_gpu + weight_memory_per_gpu + max(activation_memory_per_gpu, gradient_memory_per_gpu) + max(estimated_bwd_prefetch_memory_per_gpu, loss_bwd_memory),
+            optimizer_state_memory_per_gpu + weight_memory_per_gpu +
+            max(activation_memory_per_gpu, gradient_memory_per_gpu) +
+            max(estimated_bwd_prefetch_memory_per_gpu, loss_bwd_memory),
             "latency_per_micro_batch":
             latency_per_micro_batch,
             "latency_fwd":
             latency_fwd,
         }
         summary_dict.update(latency_fwd_breakdown)
-        device_tokens_per_sec = round(seq_len * batch_size_per_gpu / latency_per_iter, 2)
+        device_tokens_per_sec = round(
+            seq_len * batch_size_per_gpu / latency_per_iter, 2)
         summary_dict.update({
             "latency_per_iter": latency_per_iter,
             "device_tokens_per_sec": device_tokens_per_sec,
