Skip to content

Commit

Permalink
new try
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 5, 2024
1 parent 1a2f5f2 commit 9e80ba0
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 5 deletions.
4 changes: 1 addition & 3 deletions src/brevitas/core/stats/stats_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,7 @@ def forward(self, x: Optional[torch.Tensor] = None) -> torch.Tensor:
stats_input = self.first_tracked_param(None)
for extra_tracked_param in self.extra_tracked_params_list:
stats_input = extra_tracked_param(stats_input)
elif x is not None:
stats_input = self.first_tracked_param(x)
else:
raise RuntimeError("An input is needed to compute the statistics")
stats_input = self.first_tracked_param(x)
out = self.stats(stats_input)
return out
7 changes: 5 additions & 2 deletions src/brevitas/core/stats/view_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,11 @@ def __init__(self, view_shape_impl: Module) -> None:
self.view_shape_impl = view_shape_impl

@brevitas.jit.script_method
def forward(self, x: Tensor) -> Tensor:
return self.view_shape_impl(x)
def forward(self, x: Optional[Tensor]) -> Tensor:
if x is not None:
return self.view_shape_impl(x)
else:
raise RuntimeError("Input cannot be None")


class _ViewCatParameterWrapper(brevitas.jit.ScriptModule):
Expand Down

0 comments on commit 9e80ba0

Please sign in to comment.