diff --git a/src/brevitas/core/function_wrapper/learned_round.py b/src/brevitas/core/function_wrapper/learned_round.py index ffc69b7da..8387c839c 100644 --- a/src/brevitas/core/function_wrapper/learned_round.py +++ b/src/brevitas/core/function_wrapper/learned_round.py @@ -65,8 +65,7 @@ def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: return floor_ste(x) + p -# TODO: Restore JIT compatibility -class LearnedRoundIdentity(torch.nn.Module): +class LearnedRoundIdentity(brevitas.jit.ScriptModule): """ Implementation for LearnedRound learned parameter Adapted from https://arxiv.org/abs/2309.05516 @@ -75,15 +74,14 @@ class LearnedRoundIdentity(torch.nn.Module): def __init__(self) -> None: super(LearnedRoundIdentity, self).__init__() self.tensor_clamp = TensorClampSte() + self.upper_lower_bound = brevitas.jit.Attribute(0.5, float) - @brevitas.jit.ignore def forward(self, p: torch.Tensor) -> torch.Tensor: return self.tensor_clamp( p, - min_val=torch.tensor(-0.5, device=p.device), - max_val=torch.tensor(+0.5, device=p.device)) + min_val=torch.tensor(-self.upper_lower_bound).type_as(p), + max_val=torch.tensor(self.upper_lower_bound).type_as(p)) - @brevitas.jit.ignore def round_forward(self, x: torch.Tensor, p: torch.Tensor) -> torch.Tensor: return round_ste(x + p) diff --git a/tests/brevitas/optim/test_sign_sgd.py b/tests/brevitas/optim/test_sign_sgd.py index 5970ac262..2aca86c82 100644 --- a/tests/brevitas/optim/test_sign_sgd.py +++ b/tests/brevitas/optim/test_sign_sgd.py @@ -79,9 +79,10 @@ class TestOptimSignSGD: @device_dtype_parametrize @pytest_cases.parametrize("lr", [0.1]) - @requires_pt_ge('2.1') # TODO: revisit this + @requires_pt_ge('2.1') # TODO: revisit this def test_sign_sgd_single_update(self, device, dtype, lr): from brevitas.optim.sign_sgd import SignSGD + # Initialize weights and grads weights = Parameter(REFERENCE_WEIGHTS.to(device=device, dtype=dtype)) # Initialize tensors to compute expected result @@ -104,6 +105,7 @@ def 
test_sign_sgd_single_update(self, device, dtype, lr): @requires_pt_ge('2.1') def test_forloop_goes_right_direction(self, device, dtype, optimizer_kwargs, lr_scheduler_args): from brevitas.optim.sign_sgd import SignSGD + # PyTorch version previous to 2.3.1. might not have mv (addmv_impl_cpu) implemented for Half if dtype == torch.float16 and device == "cpu" and torch_version < parse('2.3.1'): pytest.xfail(