diff --git a/src/brevitas/graph/base.py b/src/brevitas/graph/base.py index 13cd19d2e..d1631f34e 100644 --- a/src/brevitas/graph/base.py +++ b/src/brevitas/graph/base.py @@ -18,7 +18,6 @@ from brevitas.fx import immutable_dict from brevitas.fx import Node from brevitas.graph.utils import * -from brevitas.nn import ScaledDotProductAttention from brevitas.utils.python_utils import islambda from brevitas.utils.rotation_utils import RotationWeightParametrization @@ -122,7 +121,6 @@ def _map_origin_vars(self, vars: dict): def _module_attributes(self, module): attrs = vars(module) - # workaround since bias doesn't show up on vars of Linear if hasattr(module, 'bias'): attrs['bias'] = module.bias