diff --git a/scripts/tornado_detection/train_tornado_tf.py b/scripts/tornado_detection/train_tornado_tf.py index 57ea33c..7382a20 100644 --- a/scripts/tornado_detection/train_tornado_tf.py +++ b/scripts/tornado_detection/train_tornado_tf.py @@ -122,6 +122,7 @@ def main(config): for x,y,w in ds_train: shp=get_shape(x) c_shp=x['coordinates'].shape + break in_shapes = (None,None,shp[-1]) c_shapes = (None,None,c_shp[-1])