Skip to content

Commit

Permalink
added more flexible type-branching metaprogram for assign() promotion…
Browse files Browse the repository at this point in the history
… in agrad matrix
  • Loading branch information
Bob Carpenter committed Sep 11, 2012
1 parent 2842b0d commit 2d5ac18
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 22 deletions.
50 changes: 28 additions & 22 deletions src/stan/agrad/matrix.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2367,33 +2367,39 @@ namespace stan {
assign_to_var(x(m,n),y(m,n));
}

inline void assign_to_nonvar(double& var, int val) {
var = val;
}
template <typename T1, typename T2>
inline void assign_to_nonvar(T1& var, const T2& val) {
throw std::domain_error("illegal assignment with mismatched LHS and RHS types");
}
template <typename T>
inline void assign_to_nonvar(T& var, const T& val) {
var = val;
}


template <typename LHS, typename RHS>
inline void assign(LHS& var,
const RHS& val) {
using stan::is_constant_struct;
if (is_constant_struct<RHS>::value
&& !is_constant_struct<LHS>::value)
assign_to_var(var,val);
else
assign_to_nonvar(var,val);
}

struct needs_promotion {
enum { value = ( is_constant_struct<RHS>::value
&& !is_constant_struct<LHS>::value) };
};

template <bool PromoteRHS, typename LHS, typename RHS>
struct assigner {
static inline void assign(LHS& var, const RHS& val) {
throw std::domain_error("should not call base class of assigner");
}
};

template <typename LHS, typename RHS>
struct assigner<false,LHS,RHS> {
static inline void assign(LHS& var, const RHS& val) {
var = val;
}
};

template <typename LHS, typename RHS>
struct assigner<true,LHS,RHS> {
static inline void assign(LHS& var, const RHS& val) {
assign_to_var(var,val);
}
};


template <typename LHS, typename RHS>
inline void assign(LHS& var, const RHS& val) {
assigner<needs_promotion<LHS,RHS>::value, LHS, RHS>::assign(var,val);
}



Expand Down
19 changes: 19 additions & 0 deletions src/test/agrad/matrix_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2958,6 +2958,7 @@ TEST(AgradMatrix, initializeVariable) {
}

TEST(AgradMatrix, assign) {
using stan::agrad::assign;
using stan::agrad::var;
using std::vector;
using Eigen::Matrix;
Expand All @@ -2967,10 +2968,28 @@ TEST(AgradMatrix, assign) {
assign(x,2.0);
EXPECT_FLOAT_EQ(2.0,x.val());

assign(x,2);
EXPECT_FLOAT_EQ(2.0,x.val());

var y(3.0);
assign(x,y);
EXPECT_FLOAT_EQ(3.0,x.val());

double xd;
assign(xd,2.0);
EXPECT_FLOAT_EQ(2.0,xd);

assign(xd,2);
EXPECT_FLOAT_EQ(2.0,xd);

int iii;
assign(iii,2);
EXPECT_EQ(2,iii);

unsigned int j = 12;
assign(iii,j);
EXPECT_EQ(12,j);

vector<double> y_dbl(2);
y_dbl[0] = 2.0;
y_dbl[1] = 3.0;
Expand Down

0 comments on commit 2d5ac18

Please sign in to comment.