Skip to content

Commit

Permalink
Add eps to histogram()
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Dec 29, 2024
1 parent 1a1e224 commit 6d88f3b
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 4 deletions.
6 changes: 4 additions & 2 deletions phiml/math/_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -3036,7 +3036,7 @@ def ravel_index(index: Tensor, resolution: Shape, dim=channel, mode='undefined')



def histogram(values: Tensor, bins: Shape or Tensor = spatial(bins=30), weights=1, same_bins: DimFilter = None):
def histogram(values: Tensor, bins: Shape or Tensor = spatial(bins=30), weights=1, same_bins: DimFilter = None, eps=1e-5):
"""
Compute a histogram of a distribution of values.
Expand All @@ -3062,7 +3062,9 @@ def histogram(values: Tensor, bins: Shape or Tensor = spatial(bins=30), weights=
weights = wrap(weights)
if isinstance(bins, SHAPE_TYPES):
def equal_bins(v):
return linspace(finite_min(v, shape), finite_max(v, shape), bins.with_size(bins.size + 1))
lo, up = finite_min(v, shape), finite_max(v, shape)
margin = eps * (up - lo)
return linspace(lo, up+margin, bins.with_size(bins.size + 1))
bins = broadcast_op(equal_bins, [values], iter_dims=(batch(values) & batch(weights)).without(same_bins))
assert isinstance(bins, Tensor), f"bins must be a Tensor but got {type(bins)}"
assert non_batch(bins).rank == 1, f"bins must contain exactly one spatial or instance dimension listing the bin edges but got shape {bins.shape}"
Expand Down
4 changes: 2 additions & 2 deletions tests/commit/math/test__ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,9 +628,9 @@ def test_scatter_any_all_1d(self):
def test_histogram_1d(self):
for backend in BACKENDS:
with backend:
data = vec(instance('losses'), 0, .1, .1, .2, .1, .2, .3, .5)
data = vec(instance('losses'), 0, .11, .11, .21, .11, .21, .31, .51)
hist, bin_edges, bin_center = math.histogram(data, instance(loss=10))
assert_close(hist, [1, 0, 3, 0, 2, 0, 1, 0, 0, 1])
assert_close(hist, [1, 0, 3, 0, 2, 0, 1, 0, 0, 1], msg=backend.name)

def test_sin(self):
for backend in BACKENDS:
Expand Down

0 comments on commit 6d88f3b

Please sign in to comment.