diff --git a/src/brevitas/graph/equalize.py b/src/brevitas/graph/equalize.py index f3a87d258..8713d8d25 100644 --- a/src/brevitas/graph/equalize.py +++ b/src/brevitas/graph/equalize.py @@ -647,7 +647,7 @@ def _no_equalize(): module=module, tensor_name="bias", transform_module=ScaleWeightParametrization( - scaling_factor=partial_scale.view_as(module.bias),))) + scaling_factor=partial_scale.view_as(module.bias),is_sink=False))) src_broadcast_size = [1] * module.weight.ndim src_broadcast_size[axis] = module.weight.size(axis) if fuse_scaling: