Skip to content

Commit

Permalink
Provide overloads for Raster operators where ints and floats are mixed (
Browse files Browse the repository at this point in the history
#215)

Provides a separate set of operator overloads for integral Rasters when the other operand is a floating point number.

We don't support pow and sqrt for ints with -Wfloat-conversion.

Additionally, it fixes missing use of cost reference for a read-only parameter.
  • Loading branch information
wenzeslaus authored Mar 12, 2024
1 parent 2056c47 commit fc1fa92
Show file tree
Hide file tree
Showing 2 changed files with 86 additions and 8 deletions.
77 changes: 71 additions & 6 deletions include/pops/raster.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -275,37 +275,101 @@ class Raster
}

template<typename OtherNumber>
Raster& operator+=(OtherNumber value)
typename std::enable_if<
!(std::is_floating_point<OtherNumber>::value
&& std::is_integral<Number>::value),
Raster&>::type
operator+=(OtherNumber value)
{
std::for_each(
data_, data_ + (cols_ * rows_), [&value](Number& a) { a += value; });
return *this;
}

template<typename OtherNumber>
Raster& operator-=(OtherNumber value)
typename std::enable_if<
!(std::is_floating_point<OtherNumber>::value
&& std::is_integral<Number>::value),
Raster&>::type
operator-=(OtherNumber value)
{
std::for_each(
data_, data_ + (cols_ * rows_), [&value](Number& a) { a -= value; });
return *this;
}

template<typename OtherNumber>
Raster& operator*=(OtherNumber value)
typename std::enable_if<
!(std::is_floating_point<OtherNumber>::value
&& std::is_integral<Number>::value),
Raster&>::type
operator*=(OtherNumber value)
{
std::for_each(
data_, data_ + (cols_ * rows_), [&value](Number& a) { a *= value; });
return *this;
}

template<typename OtherNumber>
Raster& operator/=(OtherNumber value)
typename std::enable_if<
!(std::is_floating_point<OtherNumber>::value
&& std::is_integral<Number>::value),
Raster&>::type
operator/=(OtherNumber value)
{
std::for_each(
data_, data_ + (cols_ * rows_), [&value](Number& a) { a /= value; });
return *this;
}

template<typename OtherNumber>
typename std::enable_if<
std::is_floating_point<OtherNumber>::value && std::is_integral<Number>::value,
Raster&>::type
operator+=(OtherNumber value)
{
std::for_each(data_, data_ + (cols_ * rows_), [&value](Number& a) {
a += static_cast<int>(std::floor(value));
});
return *this;
}

template<typename OtherNumber>
typename std::enable_if<
std::is_floating_point<OtherNumber>::value && std::is_integral<Number>::value,
Raster&>::type
operator-=(OtherNumber value)
{
std::for_each(data_, data_ + (cols_ * rows_), [&value](Number& a) {
a -= static_cast<int>(std::floor(value));
});
return *this;
}

template<typename OtherNumber>
typename std::enable_if<
std::is_floating_point<OtherNumber>::value && std::is_integral<Number>::value,
Raster&>::type
operator*=(OtherNumber value)
{
std::for_each(data_, data_ + (cols_ * rows_), [&value](Number& a) {
a *= static_cast<int>(std::floor(value));
});
return *this;
}

template<typename OtherNumber>
typename std::enable_if<
std::is_floating_point<OtherNumber>::value && std::is_integral<Number>::value,
Raster&>::type
operator/=(OtherNumber value)
{
std::for_each(data_, data_ + (cols_ * rows_), [&value](Number& a) {
a /= static_cast<int>(std::floor(value));
});
return *this;
}

template<typename OtherNumber>
typename std::enable_if<
std::is_floating_point<Number>::value
Expand Down Expand Up @@ -496,12 +560,13 @@ class Raster
return out;
}

friend inline Raster pow(Raster image, double value)
friend inline Raster pow(const Raster& image, double value)
{
image.for_each([value](Number& a) { a = std::pow(a, value); });
return image;
}
friend inline Raster sqrt(Raster image)

friend inline Raster sqrt(const Raster& image)
{
image.for_each([](Number& a) { a = std::sqrt(a); });
return image;
Expand Down
17 changes: 15 additions & 2 deletions tests/test_raster.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,22 @@ static int test_multiply_in_place_operator()
}
}

static void test_pow()
{
Raster<double> a = {{4, 5}, {2, 3}};
Raster<double> b = {{16, 25}, {4, 9}};
auto c = pow(a, 2);
std::cout << "pow function: ";
if (b == c)
std::cout << "OK" << std::endl;
else
std::cout << "\n" << a << "!=\n" << b << std::endl;
}

static void test_sqrt()
{
Raster<int> a = {{16, 25}, {4, 9}};
Raster<int> b = {{4, 5}, {2, 3}};
Raster<double> a = {{16, 25}, {4, 9}};
Raster<double> b = {{4, 5}, {2, 3}};
auto c = sqrt(a);
std::cout << "sqrt function: ";
if (b == c)
Expand Down Expand Up @@ -469,6 +481,7 @@ int main()
test_plus_operator();
test_multiply_in_place_operator();

test_pow();
test_sqrt();

// all doubles, no problem
Expand Down

0 comments on commit fc1fa92

Please sign in to comment.