diff --git a/phi/field/_field.py b/phi/field/_field.py index fa3822d29..c956324f9 100644 --- a/phi/field/_field.py +++ b/phi/field/_field.py @@ -206,7 +206,7 @@ def shape(self) -> Shape: if self.is_grid and '~vector' in self.values.shape: return batch(self.geometry) & self.resolution & non_dual(self.values).without(self.resolution) & self.geometry.shape['vector'] set_shape = self.geometry.sets[self.sampled_at] - return batch(self.geometry) & (channel(self.geometry) - 'vector') & set_shape & self.values + return batch(self.geometry) & (channel(self.geometry) - 'vector') & set_shape & self.values.shape @property def resolution(self): diff --git a/phi/vis/_vis.py b/phi/vis/_vis.py index e91165be4..641d99baa 100644 --- a/phi/vis/_vis.py +++ b/phi/vis/_vis.py @@ -437,11 +437,12 @@ def layout_color(content: Dict[Tuple[int, int], List[Field]], indices: Dict[Tupl if (color[idx] != None).all: # user-specified color result_pos.append(color[idx]) else: - cmap = requires_color_map(f) + cmap: bool = requires_color_map(f) channels = channel(f).without('vector') channel_colors = counter + math.range_tensor(channels) - result_pos.append(math.where(cmap, wrap('cmap'), channel_colors)) - counter += channels.volume * math.any(~cmap, shape) + result_pos.append(wrap('cmap') if cmap else channel_colors) + if not cmap: + counter += channels.volume return result