Commit
delete unnecessary tests
garrett361 committed Jun 10, 2024
1 parent 5aef982 commit 7fddafa
Showing 1 changed file with 35 additions and 77 deletions.
112 changes: 35 additions & 77 deletions blog/act-mem-2/test.py
@@ -23,7 +23,13 @@
     torch.empty(1, 1, device="cuda") @ torch.empty(1, 1, device="cuda")
 
 
-ZERO_MEM_ACT_FNS = [nn.ReLU(), nn.Sigmoid(), nn.Tanh(), nn.LeakyReLU(inplace=True), nn.Sigmoid()]
+ZERO_MEM_ACT_FNS = [
+    nn.ReLU(),
+    nn.Sigmoid(),
+    nn.Tanh(),
+    nn.LeakyReLU(inplace=True),
+    nn.Sigmoid(),
+]
 ALL_ACT_FNS = ZERO_MEM_ACT_FNS + [
     nn.ELU(),
     nn.GELU(),
@@ -93,9 +99,13 @@ def test_mlp(
     times as large as the MLP's inputs). The MLP activation memory can be nearly halved by a
     choice of activation function.
     """
-    inputs = torch.randn(batch_size, seq_len, d_model, requires_grad=True, device=device)
+    inputs = torch.randn(
+        batch_size, seq_len, d_model, requires_grad=True, device=device
+    )
     expansion_factor = 4
-    mlp = layers.MLP(d_model=d_model, act_fn=act_fn, dropout_prob=dropout_prob, device=device)
+    mlp = layers.MLP(
+        d_model=d_model, act_fn=act_fn, dropout_prob=dropout_prob, device=device
+    )
     with act_mem.SavedTensorContext(ignored_tensors=mlp.parameters()) as saved:
         _ = mlp(inputs)
 
@@ -104,10 +114,15 @@ def test_mlp(
     second_lin_input_mem = expansion_factor * first_lin_input_mem
     # Only some activations require additional activation memory
     activation_input_mem = 0 if act_fn in ZERO_MEM_ACT_FNS else second_lin_input_mem
-    dropout_act_mem = 0 if not dropout_prob else inputs.numel() * (4 if device == "cpu" else 1)
+    dropout_act_mem = (
+        0 if not dropout_prob else inputs.numel() * (4 if device == "cpu" else 1)
+    )
 
     expected_mem = (
-        first_lin_input_mem + second_lin_input_mem + activation_input_mem + dropout_act_mem
+        first_lin_input_mem
+        + second_lin_input_mem
+        + activation_input_mem
+        + dropout_act_mem
     )
     assert saved.saved_tensor_mem == expected_mem
 
@@ -130,9 +145,13 @@ def test_mlp_amp(
     Similar story with AMP. The only changes come from the modified dtypes and needing to also
     save references to the low-precision weights in the Linear layers.
     """
-    inputs = torch.randn(batch_size, seq_len, d_model, requires_grad=True, device=device)
+    inputs = torch.randn(
+        batch_size, seq_len, d_model, requires_grad=True, device=device
+    )
     expansion_factor = 4
-    mlp = layers.MLP(d_model=d_model, act_fn=act_fn, dropout_prob=dropout_prob, device=device)
+    mlp = layers.MLP(
+        d_model=d_model, act_fn=act_fn, dropout_prob=dropout_prob, device=device
+    )
     dtype = torch.bfloat16
     with torch.autocast(device_type=device, dtype=dtype):
         with act_mem.SavedTensorContext(ignored_tensors=mlp.parameters()) as saved:
@@ -145,7 +164,9 @@ def test_mlp_amp(
     # Only some activations require additional activation memory
     activation_input_mem = 0 if act_fn in ZERO_MEM_ACT_FNS else second_lin_input_mem
     dropout_act_mem = (
-        0 if not dropout_prob else inputs.numel() * (dtype.itemsize if device == "cpu" else 1)
+        0
+        if not dropout_prob
+        else inputs.numel() * (dtype.itemsize if device == "cpu" else 1)
     )
 
     expected_mem = (
@@ -155,7 +176,9 @@ def test_mlp_amp(
         + activation_input_mem
         + dropout_act_mem
     )
-    assert saved.saved_tensor_mem == expected_mem, f"Failed on {act_fn=}, {dropout_prob=}"
+    assert (
+        saved.saved_tensor_mem == expected_mem
+    ), f"Failed on {act_fn=}, {dropout_prob=}"
 
 
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="cuda not available")
@@ -164,7 +187,9 @@ class TestCUDAMemReadings:
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("seq_len", SEQ_LENS)
@pytest.mark.parametrize("act_fn", ALL_ACT_FNS)
def test_mlp(self, d_model: int, batch_size: int, seq_len: int, act_fn: nn.Module) -> None:
def test_mlp(
self, d_model: int, batch_size: int, seq_len: int, act_fn: nn.Module
) -> None:
"""
Track saved tensors and allocated memory and verify they agree.
"""
@@ -181,70 +206,3 @@ def test_mlp(self, d_model: int, batch_size: int, seq_len: int, act_fn: nn.Modul
         # captures inputs and not outputs. Nevertheless, the readings agree because inputs and
         # outputs are tensors of the same size and `dtype`.
         assert mem.delta["current"] == saved.saved_tensor_mem
-
-
-# TODO: @garrett.goon - Delete these
-class TestLayers:
-    @pytest.mark.parametrize("device", DEVICES)
-    @pytest.mark.parametrize("d_model", D_MODELS)
-    @pytest.mark.parametrize("batch_size", BATCH_SIZES)
-    @pytest.mark.parametrize("seq_len", SEQ_LENS)
-    @pytest.mark.parametrize("dropout_prob", (None, 0.5))
-    def test_mlp(
-        self,
-        device: str,
-        d_model: int,
-        batch_size: int,
-        seq_len: int,
-        dropout_prob: Optional[float],
-    ) -> None:
-        mlp = layers.MLP(
-            d_model=d_model, act_fn=nn.ReLU(), device=device, dropout_prob=dropout_prob
-        )
-        inputs = torch.randn(batch_size, seq_len, d_model, device=device)
-        outputs = mlp(inputs)
-        assert outputs.shape == inputs.shape
-
-    @pytest.mark.parametrize("device", DEVICES)
-    @pytest.mark.parametrize("d_model", D_MODELS)
-    @pytest.mark.parametrize("batch_size", BATCH_SIZES)
-    @pytest.mark.parametrize("n_heads", N_HEADS)
-    @pytest.mark.parametrize("seq_len", SEQ_LENS)
-    def test_attention(
-        self,
-        device: str,
-        d_model: int,
-        batch_size: int,
-        seq_len: int,
-        n_heads: int,
-    ) -> None:
-        attn = layers.Attention(d_model=d_model, n_heads=n_heads, device=device)
-        inputs = torch.randn(batch_size, seq_len, d_model, device=device)
-        outputs = attn(inputs)
-        assert outputs.shape == inputs.shape
-
-    @pytest.mark.parametrize("device", DEVICES)
-    @pytest.mark.parametrize("d_model", D_MODELS)
-    @pytest.mark.parametrize("batch_size", BATCH_SIZES)
-    @pytest.mark.parametrize("seq_len", SEQ_LENS)
-    @pytest.mark.parametrize("n_heads", N_HEADS)
-    @pytest.mark.parametrize("dropout_prob", (None, 0.5))
-    def test_block(
-        self,
-        device: str,
-        d_model: int,
-        batch_size: int,
-        seq_len: int,
-        n_heads: int,
-        dropout_prob: Optional[float],
-    ) -> None:
-        block = layers.Block(
-            d_model=d_model,
-            n_heads=n_heads,
-            act_fn=nn.ReLU(),
-            device=device,
-            dropout_prob=dropout_prob,
-        )
-        inputs = torch.randn(batch_size, seq_len, d_model, device=device)
-        outputs = block(inputs)
-        assert outputs.shape == inputs.shape

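For readers skimming the diff, the memory accounting that the surviving tests assert can be restated outside of pytest. The snippet below is a minimal standalone sketch of that accounting, assuming fp32 inputs, the 4x expansion factor hard-coded in the tests, and the tests' convention that a dropout mask costs 4 bytes per element on CPU and 1 byte per element on CUDA; the helper name expected_mlp_act_mem and its defaults are illustrative and not part of the repository.

from typing import Optional

import torch.nn as nn

# Mirrors ZERO_MEM_ACT_FNS from the test file: activations that add no saved
# tensors beyond what the second Linear layer already keeps.
ZERO_MEM_ACT_TYPES = (nn.ReLU, nn.Sigmoid, nn.Tanh, nn.LeakyReLU)


def expected_mlp_act_mem(
    batch_size: int,
    seq_len: int,
    d_model: int,
    act_fn: nn.Module,
    dropout_prob: Optional[float],
    device: str = "cuda",
    bytes_per_elem: int = 4,  # fp32; under bf16 autocast this would be 2
    expansion_factor: int = 4,
) -> int:
    """Rough byte count of the tensors an MLP block saves for backward."""
    numel = batch_size * seq_len * d_model
    # Input to the first Linear layer.
    first_lin_input_mem = numel * bytes_per_elem
    # Input to the second Linear layer is expansion_factor times larger.
    second_lin_input_mem = expansion_factor * first_lin_input_mem
    # Only some activation functions save an extra copy of their input.
    act_input_mem = (
        0 if isinstance(act_fn, ZERO_MEM_ACT_TYPES) else second_lin_input_mem
    )
    # Dropout keeps a mask: 4 bytes/element on CPU, 1 byte/element on CUDA,
    # matching the expectation encoded in the tests.
    dropout_mem = 0 if not dropout_prob else numel * (4 if device == "cpu" else 1)
    return first_lin_input_mem + second_lin_input_mem + act_input_mem + dropout_mem


# Example: a GELU MLP with dropout on a batch of 2 length-128 sequences.
print(expected_mlp_act_mem(2, 128, 1024, nn.GELU(), dropout_prob=0.1))

Under autocast, bytes_per_elem drops to the bfloat16 item size and, as the test_mlp_amp docstring notes, the saved low-precision copies of the Linear weights add a further term.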

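The assertions also rely on act_mem.SavedTensorContext, whose implementation is not shown in this diff. As a rough illustration of what such a context can do, here is a hypothetical stand-in built on torch.autograd.graph.saved_tensors_hooks (PyTorch 1.10+) and Tensor.untyped_storage (PyTorch 2.0+); the class name, the ignored_tensors handling, and the storage-pointer de-duplication are assumptions, not the repository's actual code.

from typing import Dict, Iterable

import torch
from torch.autograd.graph import saved_tensors_hooks


class SavedTensorTracker:
    """Hypothetical stand-in for act_mem.SavedTensorContext: totals the bytes
    of tensors that autograd saves for backward inside the context."""

    def __init__(self, ignored_tensors: Iterable[torch.Tensor] = ()) -> None:
        # Skip parameters (or anything else) the caller does not want counted.
        self._ignored_ptrs = {t.untyped_storage().data_ptr() for t in ignored_tensors}
        self._saved_nbytes: Dict[int, int] = {}

        def pack(t: torch.Tensor) -> torch.Tensor:
            storage = t.untyped_storage()
            ptr = storage.data_ptr()
            if ptr not in self._ignored_ptrs:
                # Key on the storage pointer so views of one buffer count once.
                self._saved_nbytes[ptr] = storage.nbytes()
            return t  # Save the tensor itself, unchanged.

        self._hooks = saved_tensors_hooks(pack, lambda t: t)

    def __enter__(self) -> "SavedTensorTracker":
        self._hooks.__enter__()
        return self

    def __exit__(self, *exc) -> None:
        self._hooks.__exit__(*exc)

    @property
    def saved_tensor_mem(self) -> int:
        return sum(self._saved_nbytes.values())

Used the way the tests above use SavedTensorContext, entering SavedTensorTracker(ignored_tensors=mlp.parameters()) around a forward pass would yield a saved_tensor_mem total to compare against the expected byte count.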