diff --git a/src/brevitas/nn/mixin/base.py b/src/brevitas/nn/mixin/base.py index 6e48339ab..bad417939 100644 --- a/src/brevitas/nn/mixin/base.py +++ b/src/brevitas/nn/mixin/base.py @@ -79,7 +79,9 @@ def unpack_input(self, inp: Union[Tensor, QuantTensor]) -> Union[Tensor, QuantTe if not torch._C._get_tracing_state(): if isinstance(inp, QuantTensor): inp = inp.set(value=inp.value.rename(None)) - elif isinstance(inp, Tensor): + else: + if isinstance(inp, Tensor) is False: + inp = torch.tensor(inp) inp = inp.rename(None) return inp