diff --git a/avalanche/training/plugins/ewc.py b/avalanche/training/plugins/ewc.py index 4a469ed8b..30c5cf841 100644 --- a/avalanche/training/plugins/ewc.py +++ b/avalanche/training/plugins/ewc.py @@ -89,7 +89,7 @@ def before_backward(self, strategy, **kwargs): # dynamic models may add new units # new units are ignored by the regularization n_units = saved_param.shape[0] - cur_param = saved_param[:n_units] + cur_param = cur_param[:n_units] penalty += (imp * (cur_param - saved_param).pow(2)).sum() elif self.mode == "online": prev_exp = exp_counter - 1 @@ -101,7 +101,7 @@ def before_backward(self, strategy, **kwargs): # dynamic models may add new units # new units are ignored by the regularization n_units = saved_param.shape[0] - cur_param = saved_param[:n_units] + cur_param = cur_param[:n_units] penalty += (imp * (cur_param - saved_param).pow(2)).sum() else: raise ValueError("Wrong EWC mode.")