-
-
Notifications
You must be signed in to change notification settings - Fork 373
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #2434 from stan-dev/feature/2432-standalone-gq-ser…
…vice Feature/2432 standalone gq service
- Loading branch information
Showing
18 changed files
with
359,195 additions
and
29 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,92 @@ | ||
#ifndef STAN_SERVICES_SAMPLE_STANDALONE_GQS_HPP | ||
#define STAN_SERVICES_SAMPLE_STANDALONE_GQS_HPP | ||
|
||
#include <stan/callbacks/interrupt.hpp> | ||
#include <stan/callbacks/logger.hpp> | ||
#include <stan/callbacks/writer.hpp> | ||
#include <stan/services/error_codes.hpp> | ||
#include <stan/services/util/create_rng.hpp> | ||
#include <stan/services/util/gq_writer.hpp> | ||
#include <string> | ||
#include <vector> | ||
|
||
namespace stan { | ||
namespace services { | ||
|
||
/** | ||
* Return the number of constrained parameters for the specified | ||
* model. | ||
* | ||
* @tparam Model type of model | ||
* @param[in] model model to query | ||
* @return number of constrained parameters for the model | ||
*/ | ||
template <class Model> | ||
int num_constrained_params(const Model& model) { | ||
std::vector<std::string> param_names; | ||
static const bool include_tparams = false; | ||
static const bool include_gqs = false; | ||
model.constrained_param_names(param_names, include_tparams, | ||
include_gqs); | ||
return param_names.size(); | ||
} | ||
|
||
/** | ||
* Given a set of draws from a fitted model, generate corresponding | ||
* quantities of interest. Data written to callback writer. | ||
* Return code indicates success or type of error. | ||
* | ||
* @tparam Model model class | ||
* @param[in] model instantiated model | ||
* @param[in] draws sequence of draws of unconstrained parameters | ||
* @param[in] seed seed to use for randomization | ||
* @param[in, out] interrupt called every iteration | ||
* @param[in, out] logger logger to which to write warning and error messages | ||
* @param[in, out] sample_writer writer to which draws are written | ||
* @return error code | ||
*/ | ||
template <class Model> | ||
int standalone_generate(const Model& model, | ||
const std::vector<std::vector<double> >& draws, | ||
unsigned int seed, | ||
callbacks::interrupt& interrupt, | ||
callbacks::logger& logger, | ||
callbacks::writer& sample_writer) { | ||
if (draws.empty()) { | ||
logger.error("Empty set of draws from fitted model."); | ||
return error_codes::DATAERR; | ||
} | ||
|
||
int num_params = num_constrained_params(model); | ||
std::vector<std::string> gq_names; | ||
static const bool include_tparams = false; | ||
static const bool include_gqs = true; | ||
model.constrained_param_names(gq_names, include_tparams, include_gqs); | ||
if (!(static_cast<size_t>(num_params) < gq_names.size())) { | ||
logger.error("Model doesn't generate any quantities of interest."); | ||
return error_codes::CONFIG; | ||
} | ||
|
||
util::gq_writer writer(sample_writer, logger, num_params); | ||
boost::ecuyer1988 rng = util::create_rng(seed, 1); | ||
writer.write_gq_names(model); | ||
|
||
std::stringstream msg; | ||
for (const std::vector<double>& draw : draws) { | ||
if (static_cast<size_t>(num_params) != draw.size()) { | ||
msg << "Wrong number of params in draws from fitted model. "; | ||
msg << "Expecting " << num_params << " columns, "; | ||
msg << "found " << draws[0].size() << " columns."; | ||
std::string msgstr = msg.str(); | ||
logger.error(msgstr); | ||
return error_codes::DATAERR; | ||
} | ||
interrupt(); // call out to interrupt and fail | ||
writer.write_gq_values(model, rng, draw); | ||
} | ||
return error_codes::OK; | ||
} | ||
|
||
} | ||
} | ||
#endif |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
#ifndef STAN_SERVICES_UTIL_GQ_WRITER_HPP | ||
#define STAN_SERVICES_UTIL_GQ_WRITER_HPP | ||
|
||
#include <stan/callbacks/logger.hpp> | ||
#include <stan/callbacks/writer.hpp> | ||
#include <stan/mcmc/base_mcmc.hpp> | ||
#include <stan/mcmc/sample.hpp> | ||
#include <stan/model/prob_grad.hpp> | ||
#include <sstream> | ||
#include <iomanip> | ||
#include <string> | ||
#include <vector> | ||
|
||
namespace stan { | ||
namespace services { | ||
namespace util { | ||
|
||
/** | ||
* gq_writer writes out | ||
* | ||
* @tparam Model Model class | ||
*/ | ||
class gq_writer { | ||
private: | ||
callbacks::writer& sample_writer_; | ||
callbacks::logger& logger_; | ||
int num_constrained_params_; | ||
|
||
public: | ||
/** | ||
* Constructor. | ||
* | ||
* @param[in,out] sample_writer samples are "written" to this stream | ||
* @param[in,out] logger messages are written through the logger | ||
* @param[in] num_constrained_params offset into write_array gqs | ||
*/ | ||
gq_writer(callbacks::writer& sample_writer, callbacks::logger& logger, | ||
int num_constrained_params): sample_writer_(sample_writer), | ||
logger_(logger), | ||
num_constrained_params_(num_constrained_params) { } | ||
|
||
/** | ||
* Write names of variables declared in the generated quantities block | ||
* to stream `sample_writer_`. | ||
* | ||
* @tparam M model class | ||
*/ | ||
template <class Model> | ||
void write_gq_names(const Model& model) { | ||
static const bool include_tparams = false; | ||
static const bool include_gqs = true; | ||
std::vector<std::string> names; | ||
model.constrained_param_names(names, include_tparams, include_gqs); | ||
std::vector<std::string> gq_names(names.begin() | ||
+ num_constrained_params_, | ||
names.end()); | ||
sample_writer_(gq_names); | ||
} | ||
|
||
/** | ||
* Calls model's `write_array` method and writes values of | ||
* variables defined in the generated quantities block | ||
* to stream `sample_writer_`. | ||
* | ||
* @tparam M model class | ||
* @tparam RNG pseudo random number generator class | ||
* @param[in] model instantiated model | ||
* @param[in] rng instantiated RNG | ||
* @param[in] draw sequence unconstrained parameters values. | ||
*/ | ||
template <class Model, class RNG> | ||
void write_gq_values(const Model& model, | ||
RNG& rng, | ||
const std::vector<double>& draw) { | ||
std::vector<double> values; | ||
std::vector<int> params_i; // unused - no discrete params | ||
std::stringstream ss; | ||
try { | ||
model.write_array(rng, | ||
const_cast<std::vector<double>&>(draw), | ||
params_i, | ||
values, | ||
false, | ||
true, | ||
&ss); | ||
} catch (const std::exception& e) { | ||
if (ss.str().length() > 0) | ||
logger_.info(ss); | ||
logger_.info(e.what()); | ||
return; | ||
} | ||
if (ss.str().length() > 0) | ||
logger_.info(ss); | ||
|
||
std::vector<double> gq_values(values.begin() | ||
+ num_constrained_params_, | ||
values.end()); | ||
sample_writer_(gq_values); | ||
} | ||
}; | ||
|
||
} | ||
} | ||
} | ||
#endif |
Oops, something went wrong.