Skip to content

Commit

Permalink
Cleanup ops/transformer/inference tests (#6830)
Browse files Browse the repository at this point in the history
  • Loading branch information
loadams authored Jan 3, 2025
1 parent 456c9ac commit a8ede3a
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 6 deletions.
2 changes: 0 additions & 2 deletions tests/unit/ops/transformer/inference/test_bias_add.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)

torch_minor_version = None


def run_bias_add_reference(activations, bias):
return activations + bias
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/ops/transformer/inference/test_bias_gelu.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,8 @@
from deepspeed.ops.op_builder import InferenceBuilder
from deepspeed.ops.transformer import DeepSpeedInferenceConfig
from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp
from deepspeed.utils.torch import required_torch_version
from .inference_test_utils import allclose, get_dtypes
from packaging import version as pkg_version

if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)
Expand All @@ -34,7 +34,7 @@ def run_bias_gelu_ds(activations, bias):
@pytest.mark.parametrize("channels", [512, 1232, 4096])
@pytest.mark.parametrize("dtype", get_dtypes())
def test_bias_gelu(batch, sequence, channels, dtype):
if pkg_version.parse(torch.__version__) < pkg_version.parse("1.12"):
if not required_torch_version(min_version=1.12):
pytest.skip("gelu implementation matches only after torch 1.12")

activations_ds = torch.randn((batch, sequence, channels), dtype=dtype, device=get_accelerator().device_name())
Expand Down
2 changes: 0 additions & 2 deletions tests/unit/ops/transformer/inference/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,6 @@
if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]:
pytest.skip("Inference ops are not available on this system", allow_module_level=True)

inference_module = None


def allclose(x, y):
assert x.dtype == y.dtype
Expand Down

0 comments on commit a8ede3a

Please sign in to comment.