Skip to content

Commit

Permalink
Review
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 20, 2023
1 parent 0260988 commit f822e45
Showing 1 changed file with 7 additions and 2 deletions.
9 changes: 7 additions & 2 deletions src/brevitas/quant_tensor/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,14 @@ def contiguous(self):
def int(self, float_datatype=False):
if self.is_valid:
# After rounding, cast to the original dtype of the scale factor
int_value = round_ste(self._pre_round_int_value).type(self.scale.dtype)
int_value = round_ste(self._pre_round_int_value)
if float_datatype:
return int_value
# values in 8bit and lower can be represented exactly with float16 and bfloat16
# otherwise (e.g. Int16 bias), we upscale to float32
if self.bit_width <= 8.:
return int_value.type(self.scale.dtype)
else:
return int_value.type(torch.float32)
else:
if self.bit_width <= 8. and self.signed_t.item():
return int_value.to(torch.int8)
Expand Down

0 comments on commit f822e45

Please sign in to comment.