diff --git a/src/gflownet/models/seq_transformer.py b/src/gflownet/models/seq_transformer.py index 84e604d1..54557922 100644 --- a/src/gflownet/models/seq_transformer.py +++ b/src/gflownet/models/seq_transformer.py @@ -65,7 +65,6 @@ def logZ(self, cond_info: Optional[torch.Tensor]): return self._logZ(torch.ones((1, 1), device=self._logZ.weight.device)) return self._logZ(cond_info) - def forward(self, xs: SeqBatch, cond, batched=False): """Returns a GraphActionCategorical and a tensor of state predictions.