Skip to content

Commit

Permalink
Merge pull request #159 from SimonRohou/codac2_dev
Browse files Browse the repository at this point in the history
[ctc] improved MulOp operators (for r>c linear systems)
  • Loading branch information
SimonRohou authored Dec 12, 2024
2 parents b99a975 + 8a1121a commit 25e664b
Show file tree
Hide file tree
Showing 2 changed files with 87 additions and 2 deletions.
84 changes: 82 additions & 2 deletions src/core/contractors/codac2_directed_ctc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -362,6 +362,48 @@ using namespace codac2;
MulOp::bwd(y, x2, x1);
}

Interval MulOp::fwd(const IntervalRow& x1, const IntervalVector& x2)
{
assert(x1.size() == x2.size());
Interval s(0.);
for(Index i = 0 ; i < x1.size() ; i++)
s += x1[i]*x2[i];
return s;
}

//ScalarOpValue MulOp::fwd(const RowOpValue& x1, const VectorOpValue& x2)
//{
// // RowOpValue not yet defined
//}

void MulOp::bwd(const Interval& y, IntervalRow& x1, IntervalVector& x2)
{
assert(x1.size() == x2.size());

const Index n = x1.size();
vector<Interval> sums(n), prods(n);

// Forward propagation

for(Index i = 0 ; i < n ; i++)
{
prods[i] = x1[i]*x2[i];
sums[i] = prods[i];
if(i > 0) sums[i] += sums[i-1];
}

// Backward propagation

sums[n-1] &= y;

for(Index i = n-1 ; i >= 0 ; i--)
{
if(i > 0) AddOp::bwd(sums[i],sums[i-1],prods[i]);
else prods[0] &= sums[0];
MulOp::bwd(prods[i],x1[i],x2[i]);
}
}

IntervalVector MulOp::fwd(const IntervalMatrix& x1, const IntervalVector& x2)
{
assert(x1.cols() == x2.size());
Expand All @@ -381,20 +423,58 @@ using namespace codac2;
#include "codac2_linear_ctc.h"
#include "codac2_GaussJordan.h"

//#include "codac2_ibex.h"

void MulOp::bwd(const IntervalVector& y, IntervalMatrix& x1, IntervalVector& x2)
{
assert(x1.rows() == y.size());
assert(x1.cols() == x2.size());

/*if(x1.is_squared()) // not working for any x1
/*if(x1.is_squared()) // not working for any squared x1
{
CtcGaussElim ctc_ge;
CtcLinearPrecond ctc_gep(ctc_ge);
IntervalVector y_(y);
ctc_gep.contract(x1,x2,y_);
}*/

if(x1.rows() > x1.cols())
{
#if 0 // IBEX version
ibex::IntervalVector ibex_y(to_ibex(y)), ibex_x2(to_ibex(x2));
ibex::IntervalMatrix ibex_x1(to_ibex(x1));
ibex::bwd_mul(ibex_y, ibex_x1, ibex_x2, 0.05);
x1 &= to_codac(ibex_x1);
x2 &= to_codac(ibex_x2);
#else

Index last_row = 0;
Index i = 0;

do
{
double vol_x2 = x2.volume();
IntervalRow row_i = x1.row(i);
MulOp::bwd(y[i],row_i,x2);

if(row_i.is_empty())
{
x1.set_empty();
return;
}

else
x1.row(i) = row_i;

if(x2.volume()/vol_x2 < 0.98)
last_row = i;
i = (i+1)%y.size();
} while(i != last_row);

#endif
}

else*/
else
{
IntervalMatrix Q = gauss_jordan(x1.mid());
IntervalVector b_tilde = Q*y;
Expand Down
5 changes: 5 additions & 0 deletions src/core/contractors/codac2_directed_ctc.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <string>
#include "codac2_analytic_values.h"
#include "codac2_template_tools.h"
#include "codac2_IntervalRow.h"

namespace codac2
{
Expand Down Expand Up @@ -87,6 +88,10 @@ namespace codac2
static VectorOpValue fwd(const VectorOpValue& x1, const ScalarOpValue& x2);
static void bwd(const IntervalVector& y, IntervalVector& x1, Interval& x2);

static Interval fwd(const IntervalRow& x1, const IntervalVector& x2);
//static ScalarOpValue fwd(const RowOpValue& x1, const VectorOpValue& x2); // RowOpValue not yet defined
static void bwd(const Interval& y, IntervalRow& x1, IntervalVector& x2);

static IntervalVector fwd(const IntervalMatrix& x1, const IntervalVector& x2);
static VectorOpValue fwd(const MatrixOpValue& x1, const VectorOpValue& x2);
static void bwd(const IntervalVector& y, IntervalMatrix& x1, IntervalVector& x2);
Expand Down

0 comments on commit 25e664b

Please sign in to comment.