From fc1fa92dcc612d3473594313c4428208c6c71a33 Mon Sep 17 00:00:00 2001 From: Vaclav Petras Date: Tue, 12 Mar 2024 11:35:13 -0400 Subject: [PATCH] Provide overloads for Raster operators where ints and floats are mixed (#215) Provides a separate set of operator overloads for integral Rasters when the other operand is a floating point number. We don't support pow and sqrt for ints with -Wfloat-conversion. Additionally, it fixes missing use of cost reference for a read-only parameter. --- include/pops/raster.hpp | 77 +++++++++++++++++++++++++++++++++++++---- tests/test_raster.cpp | 17 +++++++-- 2 files changed, 86 insertions(+), 8 deletions(-) diff --git a/include/pops/raster.hpp b/include/pops/raster.hpp index 03ff60d8..54ccdc2c 100644 --- a/include/pops/raster.hpp +++ b/include/pops/raster.hpp @@ -275,7 +275,11 @@ class Raster } template - Raster& operator+=(OtherNumber value) + typename std::enable_if< + !(std::is_floating_point::value + && std::is_integral::value), + Raster&>::type + operator+=(OtherNumber value) { std::for_each( data_, data_ + (cols_ * rows_), [&value](Number& a) { a += value; }); @@ -283,7 +287,11 @@ class Raster } template - Raster& operator-=(OtherNumber value) + typename std::enable_if< + !(std::is_floating_point::value + && std::is_integral::value), + Raster&>::type + operator-=(OtherNumber value) { std::for_each( data_, data_ + (cols_ * rows_), [&value](Number& a) { a -= value; }); @@ -291,7 +299,11 @@ class Raster } template - Raster& operator*=(OtherNumber value) + typename std::enable_if< + !(std::is_floating_point::value + && std::is_integral::value), + Raster&>::type + operator*=(OtherNumber value) { std::for_each( data_, data_ + (cols_ * rows_), [&value](Number& a) { a *= value; }); @@ -299,13 +311,65 @@ class Raster } template - Raster& operator/=(OtherNumber value) + typename std::enable_if< + !(std::is_floating_point::value + && std::is_integral::value), + Raster&>::type + operator/=(OtherNumber value) { std::for_each( data_, data_ + (cols_ * rows_), [&value](Number& a) { a /= value; }); return *this; } + template + typename std::enable_if< + std::is_floating_point::value && std::is_integral::value, + Raster&>::type + operator+=(OtherNumber value) + { + std::for_each(data_, data_ + (cols_ * rows_), [&value](Number& a) { + a += static_cast(std::floor(value)); + }); + return *this; + } + + template + typename std::enable_if< + std::is_floating_point::value && std::is_integral::value, + Raster&>::type + operator-=(OtherNumber value) + { + std::for_each(data_, data_ + (cols_ * rows_), [&value](Number& a) { + a -= static_cast(std::floor(value)); + }); + return *this; + } + + template + typename std::enable_if< + std::is_floating_point::value && std::is_integral::value, + Raster&>::type + operator*=(OtherNumber value) + { + std::for_each(data_, data_ + (cols_ * rows_), [&value](Number& a) { + a *= static_cast(std::floor(value)); + }); + return *this; + } + + template + typename std::enable_if< + std::is_floating_point::value && std::is_integral::value, + Raster&>::type + operator/=(OtherNumber value) + { + std::for_each(data_, data_ + (cols_ * rows_), [&value](Number& a) { + a /= static_cast(std::floor(value)); + }); + return *this; + } + template typename std::enable_if< std::is_floating_point::value @@ -496,12 +560,13 @@ class Raster return out; } - friend inline Raster pow(Raster image, double value) + friend inline Raster pow(const Raster& image, double value) { image.for_each([value](Number& a) { a = std::pow(a, value); }); return image; } - friend inline Raster sqrt(Raster image) + + friend inline Raster sqrt(const Raster& image) { image.for_each([](Number& a) { a = std::sqrt(a); }); return image; diff --git a/tests/test_raster.cpp b/tests/test_raster.cpp index 8d78e3b1..8f1d7da0 100644 --- a/tests/test_raster.cpp +++ b/tests/test_raster.cpp @@ -126,10 +126,22 @@ static int test_multiply_in_place_operator() } } +static void test_pow() +{ + Raster a = {{4, 5}, {2, 3}}; + Raster b = {{16, 25}, {4, 9}}; + auto c = pow(a, 2); + std::cout << "pow function: "; + if (b == c) + std::cout << "OK" << std::endl; + else + std::cout << "\n" << a << "!=\n" << b << std::endl; +} + static void test_sqrt() { - Raster a = {{16, 25}, {4, 9}}; - Raster b = {{4, 5}, {2, 3}}; + Raster a = {{16, 25}, {4, 9}}; + Raster b = {{4, 5}, {2, 3}}; auto c = sqrt(a); std::cout << "sqrt function: "; if (b == c) @@ -469,6 +481,7 @@ int main() test_plus_operator(); test_multiply_in_place_operator(); + test_pow(); test_sqrt(); // all doubles, no problem