Skip to content

Commit

Permalink
ContractorNetworks in Codac2 version
Browse files Browse the repository at this point in the history
  • Loading branch information
SimonRohou committed Dec 2, 2023
1 parent e77c0ac commit 47bbc52
Show file tree
Hide file tree
Showing 15 changed files with 313 additions and 41 deletions.
3 changes: 2 additions & 1 deletion python/src/core/2/cn/codac2_py_ContractorNetwork.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ void export_ContractorNetwork_codac2(py::module& m)
"n"_a)

.def("contract", &ContractorNetwork::contract,
"todo")
"todo",
"verbose"_a=true)
;
}
15 changes: 15 additions & 0 deletions src/core/2/cn/codac2_Contractor.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,28 @@ namespace codac2
return std::make_shared<ContractorNode<std::remove_reference<decltype(*this)>::type,T...>>(*this, a...); \
} \

#define make_available_in_cn__templated(template_args) \
template<typename... T> \
std::shared_ptr<ContractorNodeBase> operator()(T&... a) \
{ \
return std::make_shared<ContractorNode<template_args,T...>>(*this, a...); \
} \

class Contractor
{
public:

virtual ~Contractor() = default;
};

template<int N>
class ContractorOnBox : public Contractor
{
public:

virtual void contract(IntervalVector_<N>& x) = 0;
};

class Contractor1 : public Contractor
{
public:
Expand Down
39 changes: 31 additions & 8 deletions src/core/2/cn/codac2_ContractorNetwork.cpp
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
#include <cassert>
#include <time.h>
#include "codac2_ContractorNetwork.h"

using namespace std;
Expand All @@ -19,19 +20,27 @@ namespace codac2

void ContractorNetwork::add_ctc_to_stack(const shared_ptr<ContractorNodeBase>& ctc)
{
assert(!ctc->to_be_called());
_stack.push_back(ctc);
ctc->set_call_flag(true);
if(find(_stack.begin(),_stack.end(),ctc) == _stack.end())
_stack.push_back(ctc);
}

void ContractorNetwork::disable_auto_fixpoint(bool disable)
{
_auto_fixpoint = !disable;
}

void ContractorNetwork::contract()
void ContractorNetwork::contract(bool verbose)
{
do
if(verbose)
{
std::cout << "Contractor network (" << _v_ctc.size()
<< " contractors, " << _v_domains.size() << " domains)" << std::endl;
std::cout << "Computing, " << _stack.size() << " contractors currently in stack" << std::endl;
}

clock_t t_start = clock();

while(!_stack.empty())
{
shared_ptr<ContractorNodeBase> current_ctc = _stack.front();
_stack.pop_front();
Expand All @@ -43,10 +52,24 @@ namespace codac2
{
auto p_c = ci.lock();
if(!_auto_fixpoint && current_ctc.get() == p_c.get()) continue;
if(!p_c->to_be_called())
add_ctc_to_stack(p_c);
add_ctc_to_stack(p_c);
}
}

if(verbose)
std::cout << " Constraint propagation time: " << (double)(clock() - t_start)/CLOCKS_PER_SEC << "s" << std::endl;
}

void ContractorNetwork::reset_all_vars()
{
for(auto& d : _v_domains)
if(d->is_var())
d->reset();
}

} while(!_stack.empty());
void ContractorNetwork::trigger_all_contractors()
{
for(auto& c : _v_ctc)
add_ctc_to_stack(c);
}
}
21 changes: 19 additions & 2 deletions src/core/2/cn/codac2_ContractorNetwork.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
#include <memory>
#include "codac2_DomainNode.h"
#include "codac2_ContractorNode.h"
#include "codac2_Var.h"

namespace codac2
{

class ContractorNetwork
{
public:
Expand All @@ -18,7 +18,24 @@ namespace codac2
void add(const std::shared_ptr<ContractorNodeBase>& ctc);
void add_ctc_to_stack(const std::shared_ptr<ContractorNodeBase>& ctc);
void disable_auto_fixpoint(bool disable = true);
void contract();
void contract(bool verbose = true);

void reset_all_vars();

template<typename T>
void reset_var(const Var<T> *ref, const T& x)
{
for(auto& d : _v_domains)
if(d->raw_ptr() == ref)
{
static_cast<DomainNode<T>&>(*d).get() = x;
return;
}

assert(false && "unable to find variable");
}

void trigger_all_contractors();

//protected:

Expand Down
10 changes: 0 additions & 10 deletions src/core/2/cn/codac2_ContractorNode.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,6 @@ using namespace std;

namespace codac2
{
void ContractorNodeBase::set_call_flag(bool flag)
{
_to_be_called = flag;
}

bool ContractorNodeBase::to_be_called() const
{
return _to_be_called;
}

ostream& operator<<(ostream& os, const ContractorNodeBase& d)
{
os << "Contractor: " << d.contractor_class_name() << ", dom=" << d.nb_args() << endl;
Expand Down
17 changes: 4 additions & 13 deletions src/core/2/cn/codac2_ContractorNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,13 @@ namespace codac2
public:

virtual ~ContractorNodeBase() = default;
virtual constexpr size_t nb_args() const = 0;
virtual size_t nb_args() const = 0;
virtual std::list<std::shared_ptr<DomainNodeBase>> call_contract() = 0;
virtual Contractor* raw_ptr() const = 0;
virtual void associate_domains(std::vector<std::shared_ptr<DomainNodeBase>>& cn_domains) = 0;
virtual std::string contractor_class_name() const = 0;
void set_call_flag(bool flag = true);
bool to_be_called() const;

friend std::ostream& operator<<(std::ostream& os, const ContractorNodeBase& d);

protected:

bool _to_be_called = false; // this is redundant with the presence of the node
// in the CN stack, but avoids to search the node in the stack
};

template<typename C,typename... T>
Expand All @@ -51,7 +44,7 @@ namespace codac2

}

constexpr size_t nb_args() const
size_t nb_args() const
{
return std::tuple_size<decltype(_x)>::value;
}
Expand All @@ -61,15 +54,14 @@ namespace codac2
return &(_ctc.get());
}

void add_domain_if_contracted(const std::shared_ptr<DomainNodeBase>& d, std::list<std::shared_ptr<DomainNodeBase>>& l)
void process_domain_after_ctc(const std::shared_ptr<DomainNodeBase>& d, std::list<std::shared_ptr<DomainNodeBase>>& l)
{
if(d->has_been_contrated())
l.push_back(d);
}

std::list<std::shared_ptr<DomainNodeBase>> call_contract()
{
_to_be_called = false;
std::list<std::shared_ptr<DomainNodeBase>> contracted_doms;

std::apply(
Expand All @@ -81,7 +73,7 @@ namespace codac2
std::apply(
[this,&contracted_doms](auto &&... args)
{
(add_domain_if_contracted(args,contracted_doms),...);
(process_domain_after_ctc(args,contracted_doms),...);
}, _x);

return contracted_doms;
Expand All @@ -96,7 +88,6 @@ namespace codac2
d = std::dynamic_pointer_cast<D>(cn_xi);
return;
}

cn_domains.push_back(d);
}

Expand Down
18 changes: 17 additions & 1 deletion src/core/2/cn/codac2_DomainNode.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include <memory>
#include <functional>
#include "codac2_Domain.h"
#include "codac2_Var.h"

namespace codac2
{
Expand All @@ -23,6 +24,8 @@ namespace codac2
virtual const std::vector<std::weak_ptr<ContractorNodeBase>>& contractors() const = 0;
virtual void add_ctc(const std::shared_ptr<ContractorNodeBase>& ctc) = 0;
virtual std::string domain_class_name() const = 0;
virtual void reset() = 0;
virtual bool is_var() const = 0;

friend std::ostream& operator<<(std::ostream& os, const DomainNodeBase& d);
};
Expand All @@ -34,6 +37,8 @@ namespace codac2

public:

constexpr static bool _is_var = std::is_base_of<VarBase,T>::value;

std::reference_wrapper<T> make_ref(T_& x)
{
if constexpr (std::is_const<T_>::value)
Expand All @@ -45,7 +50,7 @@ namespace codac2
return std::ref(x);
}

DomainNode(T_& x)
explicit DomainNode(T_& x)
: _x(make_ref(x)), _volume(x.dom_volume())
{
if constexpr (std::is_const<T_>::value)
Expand Down Expand Up @@ -92,6 +97,17 @@ namespace codac2
return typeid(T).name();
}

void reset()
{
if constexpr(_is_var)
_x.get().reset();
}

constexpr bool is_var() const
{
return _is_var;
}

protected:

std::shared_ptr<T> _local_dom = nullptr;
Expand Down
21 changes: 21 additions & 0 deletions src/core/2/cn/codac2_Var.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ namespace codac2
public:

virtual ~VarBase() = default;
VarBase& operator=(const VarBase& x) = delete;
};

template<typename T>
Expand All @@ -25,6 +26,26 @@ namespace codac2
{

}

Var(const T& initial_value)
: _initial_value(initial_value)
{

}

void reset()
{
this->T::operator=(_initial_value);
}

void reset(const T& x)
{
this->T::operator=(x);
}

protected:

T _initial_value;
};
}

Expand Down
49 changes: 49 additions & 0 deletions src/core/2/contractors/codac2_CtcCN.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/**
* \file
*
* ----------------------------------------------------------------------------
* \date 2023
* \author Simon Rohou
* \copyright Copyright 2023 Codac Team
* \license This program is distributed under the terms of
* the GNU Lesser General Public License (LGPL).
*/

#ifndef __CODAC2_CTCCN_H__
#define __CODAC2_CTCCN_H__

#include "codac2_Contractor.h"
#include "codac2_ContractorNetwork.h"

namespace codac2
{
template<int N>
class CtcCN : public ContractorOnBox<N>
{
public:

CtcCN(ContractorNetwork& cn, Var<IntervalVector_<N>>& var)
: _cn(cn), _ref_var(&var)
{

}

void contract(IntervalVector_<N>& x)
{
_cn.reset_all_vars();
_cn.reset_var(_ref_var,x);
_cn.trigger_all_contractors();
_cn.contract(false);
x = *_ref_var;
}

make_available_in_cn__templated(CtcCN<N>)

protected:

ContractorNetwork& _cn;
const Var<IntervalVector_<N>>* _ref_var;
};
}

#endif
Loading

0 comments on commit 47bbc52

Please sign in to comment.