diff --git a/tzrec/datasets/data_parser.py b/tzrec/datasets/data_parser.py index 760b357..76afeac 100644 --- a/tzrec/datasets/data_parser.py +++ b/tzrec/datasets/data_parser.py @@ -453,14 +453,9 @@ def _to_sparse_features( seq_length = length key_length = input_data[f"{key}.key_lengths"] # TODO: remove to_float when segment_reduce support int values - try: - length = torch.segment_reduce( - key_length.float(), "sum", lengths=seq_length - ).to(length.dtype) - except Exception: - import pdb - - pdb.set_trace() + length = torch.segment_reduce( + key_length.float(), "sum", lengths=seq_length + ).to(length.dtype) mulval_keys.append(key) mulval_seq_lengths.append(seq_length) mulval_key_lengths.append(key_length)