Skip to content

Commit

Permalink
changing the way means are set, removing excess code
Browse files Browse the repository at this point in the history
  • Loading branch information
cmaceves committed Mar 29, 2024
1 parent c8f464b commit cf33fb4
Show file tree
Hide file tree
Showing 5 changed files with 71 additions and 193 deletions.
127 changes: 69 additions & 58 deletions src/gmm.cpp
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#include "./include/armadillo"
#include "gmm.h"
#include "call_consensus_clustering.h"
#include "population_estimate.h"
#include <fstream>
#include <cmath>
#include <algorithm>
Expand All @@ -11,7 +10,6 @@ double calculate_distance(double point, double mean) {
// Euclidean distance for a single dimension
return std::abs(point - mean);
}

int find_closest_mean_index(double data_point, const std::vector<double>& means) {
// Find the index of the closest mean to the data point
int closest_index = 0;
Expand All @@ -24,7 +22,6 @@ int find_closest_mean_index(double data_point, const std::vector<double>& means)
closest_index = i;
}
}

return closest_index;
}

Expand All @@ -43,25 +40,28 @@ std::vector<std::vector<double>> cluster_data(const std::vector<double>& data, c

std::vector<double> determine_clusters(std::vector<variant> variants, uint32_t n){
std::vector<std::vector<double>> clusters(n);
std::vector<double> means;

for(uint32_t i=0; i < variants.size(); i++){
if(!variants[i].amplicon_flux && !variants[i].depth_flag && !variants[i].outside_freq_range && !variants[i].qual_flag && !variants[i].del_flag && !variants[i].amplicon_masked && !variants[i].primer_masked){
if(variants[i].vague_assignment){
//means.push_back(variants[i].freq);
//std::cerr << "cluster vague assignment " << variants[i].freq << " " << variants[i].position << std::endl;
continue;
}
if(variants[i].low_prob_flag){
//means.push_back(variants[i].freq);
//std::cerr << "cluster low prob " << variants[i].freq << " " << variants[i].position << std::endl;
continue;
}
if(variants[i].cluster_outlier){
//means.push_back(variants[i].freq);
//std::cerr << "cluster outlier " << variants[i].freq << " " << variants[i].position << std::endl;
continue;
}
clusters[variants[i].cluster_assigned].push_back(variants[i].freq);
}
}
std::vector<double> means;
for(uint32_t i=0; i < clusters.size(); i++){
double sum = std::accumulate(clusters[i].begin(), clusters[i].end(), 0.0);
double mean = sum / clusters[i].size();
Expand Down Expand Up @@ -300,18 +300,6 @@ double calculate_cluster_bounds(std::vector<variant> variants, uint32_t n){
return(min_freq);
}

uint32_t count_useful_variants(std::vector<variant> variants){
uint32_t count = 0;
//determine the number of variants useful for modeling
for(uint32_t i=0; i< variants.size(); i++){
if(!variants[i].amplicon_flux && !variants[i].depth_flag && !variants[i].outside_freq_range && !variants[i].qual_flag && !variants[i].del_flag && !variants[i].amplicon_masked && !variants[i].primer_masked){
count += 1;
}
}
return(count);
}


void generate_permutations(const std::vector<uint32_t>& elements, int n, int target, std::vector<std::vector<uint32_t>> &other_tmp){
std::vector<uint32_t> subset(elements);
n = std::min(n, static_cast<int>(elements.size()));
Expand Down Expand Up @@ -450,7 +438,7 @@ void assign_variants_simple(std::vector<variant> &variants, std::vector<std::vec
//all locations in the prob matrix for this position
for(uint32_t k = 0; k < variants.size(); k++){
if(variants[k].amplicon_flux || variants[k].depth_flag || variants[k].outside_freq_range || variants[k].qual_flag || variants[k].del_flag || variants[k].amplicon_masked || variants[k].primer_masked) continue;

if(variants[k].position == unique_pos[i]){
//std::cerr << variants[k].position << " " << variants[k].freq << std::endl;
pos_idxs.push_back(j);
Expand All @@ -472,7 +460,7 @@ void assign_variants_simple(std::vector<variant> &variants, std::vector<std::vec
uint32_t k = 0;
for(uint32_t z =0; z < variants.size(); z++){
if(variants[z].amplicon_flux || variants[z].depth_flag || variants[z].outside_freq_range || variants[z].qual_flag || variants[z].del_flag || variants[z].amplicon_masked || variants[z].primer_masked) continue;
//this pos was flagged as poorly assigned
//this pos was flagged as poorly assigned
if(tmp != assignment_flagged.end() && k == pos_idxs[j]){
//technically this could use work as it's repetitive
//std::cerr << variants[z].position << " " << variants[z].freq << " " << assigned[j] << std::endl;
Expand Down Expand Up @@ -513,13 +501,6 @@ void solve_solution_sets(std::vector<double> means, uint32_t n){
std::vector<std::vector<double>> combos;
double error = 0.05;
go(0, n, means, combination, combos, error);
for(uint32_t i=0; i < combos.size(); i++){
for(uint32_t j=0; j < combos[i].size(); j++){
//std::cerr << combos[i][j] << " ";
}
//std::cerr << std::endl;
}

}

void determine_low_prob_positions(std::vector<variant> &variants){
Expand Down Expand Up @@ -655,7 +636,6 @@ std::vector<uint32_t> find_deletion_positions(std::string filename, uint32_t dep
if(1/freq * depth < depth_cutoff || freq < lower_bound || freq > upper_bound) continue;
if (is_substring(nuc, "+") || is_substring(nuc, "-")) {
deletion_positions.push_back(pos);
//std::cerr << "del pos depth " << depth << " freq " << freq << std::endl;
}
}
return(deletion_positions);
Expand Down Expand Up @@ -759,13 +739,12 @@ int gmm_model(std::string prefix, std::vector<uint32_t> populations_iterate, std
int retval = 0;
float lower_bound = 0.01;
float upper_bound = 0.99;
uint32_t depth_cutoff = 50;
uint32_t depth_cutoff = 10;
float quality_threshold = 20;
uint32_t round_val = 4;
std::vector<variant> variants;
std::vector<uint32_t> deletion_positions = find_deletion_positions(prefix, depth_cutoff, lower_bound, upper_bound, round_val);
std::vector<uint32_t> low_quality_positions = find_low_quality_positions(prefix, depth_cutoff, lower_bound, upper_bound, quality_threshold, round_val);

parse_internal_variants(prefix, variants, depth_cutoff, lower_bound, upper_bound, deletion_positions, low_quality_positions, round_val);
std::string filename = prefix + ".txt";
//this whole thing needs to be reconfigured
Expand All @@ -779,9 +758,11 @@ int gmm_model(std::string prefix, std::vector<uint32_t> populations_iterate, std
}
//initialize armadillo dataset and populate with frequency data
arma::mat data(1, useful_var, arma::fill::zeros);

//(rows, cols) where each column is a sample
uint32_t count=0;
for(uint32_t i = 0; i < variants.size(); i++){

//check if variant should be filtered for first pass model
if(!variants[i].amplicon_flux && !variants[i].depth_flag && !variants[i].outside_freq_range && !variants[i].qual_flag && !variants[i].del_flag && !variants[i].amplicon_masked && !variants[i].primer_masked){
double tmp = static_cast<double>(variants[i].freq);
Expand All @@ -792,42 +773,60 @@ int gmm_model(std::string prefix, std::vector<uint32_t> populations_iterate, std
std::vector<std::vector<double>> solutions; //straight from the model
std::vector<double> all_aic; //aic for each population
std::vector<std::vector<variant>> all_variants;

//try various clusters
std::cerr << "num useful vars " << useful_var << std::endl;
for(auto n : populations_iterate){
variants.clear();
parse_internal_variants(prefix, variants, depth_cutoff, lower_bound, upper_bound, deletion_positions, low_quality_positions, round_val);
if(n > (useful_var/10)) continue; //this is because it's recommended to have 10 points per gaussian
if(((float)n > (float)(useful_var/5)) && (n > 2)) continue; //this is because it's recommended to have 10 points per gaussian
arma::gmm_diag model;
arma::mat cov (1, n, arma::fill::zeros);
bool status = model.learn(data, n, arma::eucl_dist, arma::random_spread, 15, 15, 1e-12, false);
std::cerr << model.dcovs << std::endl;
if(status == false){
std::cerr << "gmm model failed" << std::endl;
continue;
}
//get the means of the gaussians
std::vector<double> means;

for(auto x : model.means){
std::cerr << x << std::endl;
means.push_back((double) x);
}

auto min_iterator = std::min_element(means.begin(), means.end());
uint32_t min_index = std::distance(means.begin(), min_iterator);
auto max_iterator = std::max_element(means.begin(), means.end());
uint32_t max_index = std::distance(means.begin(), max_iterator);

arma::mat mean_fill (1, n, arma::fill::zeros);
for(uint32_t l=0; l < n; l++){
if(l == min_index){
mean_fill.col(l) = 0.02;
} else if(l == max_index){
mean_fill.col(l) = 0.98;
} else{
mean_fill.col(l) = means[l];
}
}
model.set_means(mean_fill);
means.clear();
for(auto x : model.means){
std::cerr << x << std::endl;
means.push_back((double) x);
}
for(uint32_t l=0; l<n;l++){

for(uint32_t l=0; l < n;l++){
if(means[l] >= 0.95 || means[l] <= 0.05){
cov.col(l) = 0.01;
cov.col(l) = 0.005;
} else {
cov.col(l) = 0.001;
cov.col(l) = 0.005;
}
}
double heft = (double)(1/(double)n);
arma::mat hefts (1, n, arma::fill::zeros);
for(uint32_t l=0; l<n;l++){
hefts.col(l) = heft;
}

//model.set_hefts(hefts);
}
std::cerr << model.means << std::endl;
model.set_dcovs(cov);
std::cerr << "hefts " << model.hefts << std::endl;
std::cerr << model.dcovs << std::endl;

//get the probability of each frequency being assigned to each gaussian
std::vector<std::vector<double>> prob_matrix;
for(uint32_t i=0; i < n; i++){
Expand Down Expand Up @@ -855,33 +854,25 @@ int gmm_model(std::string prefix, std::vector<uint32_t> populations_iterate, std
double prob_sum = 0;
determine_low_prob_positions(variants);
all_variants.push_back(variants);
std::cerr << "useful variants " << useful_var << std::endl;

for(uint32_t i=0; i < variants.size(); i++){
if(!variants[i].amplicon_flux && !variants[i].depth_flag && !variants[i].outside_freq_range && !variants[i].qual_flag && !variants[i].del_flag && !variants[i].amplicon_masked && !variants[i].primer_masked){
//log likelihood for each point
double prob = variants[i].probabilities[variants[i].cluster_assigned];
//std::cerr << variants[i].freq << " " << variants[i].position << " " << variants[i].nuc << std::endl;
/*if(variants[i].freq < 0.35 && variants[i].freq > 0.15){
std::cerr << variants[i].freq << " " << variants[i].position << " " << prob << std::endl;
for(auto x : variants[i].probabilities){
std::cerr << x << " ";
}
std::cerr << "\n";
/*if(variants[i].freq > 0.10 && variants[i].freq < 0.90){
std::cerr << variants[i].cluster_assigned << " " << variants[i].freq << " " << variants[i].position << " " << prob << std::endl;
}*/
prob_sum += prob;
}
}
solutions.push_back(means);
means.clear();
double aic = (2 * (double)n) - (2 * prob_sum / useful_var);
double aic = (2 * (double)n) - (2 * prob_sum / (double)useful_var);
std::cerr << "aic " << aic << "\n" << std::endl;
all_aic.push_back(aic);
std::cerr << "avg aic " << aic << std::endl;
std::cerr << "\n";
//draw the actual threshold
//double threshold = calculate_cluster_bounds(variants, n);
//model.save("my_model.gmm");
}

double smallest_value = std::numeric_limits<double>::max();
size_t index = 0;
for (size_t i = 0; i < all_aic.size(); ++i) {
Expand All @@ -892,6 +883,13 @@ int gmm_model(std::string prefix, std::vector<uint32_t> populations_iterate, std
}
std::vector<variant> used_variants = all_variants[index];
std::vector<double> means = solutions[index];
double counter = 0;
for(uint32_t i=0; i < used_variants.size(); i++){
if(!used_variants[i].amplicon_flux && !used_variants[i].depth_flag && !used_variants[i].outside_freq_range && !used_variants[i].qual_flag && !used_variants[i].del_flag && !used_variants[i].amplicon_masked && !used_variants[i].primer_masked){
counter += 1;
}
}

//std::vector<double> means = determine_clusters(used_variants, populations_iterate[index]);
for(auto x : means){
std::cerr << x << std::endl;
Expand All @@ -906,9 +904,22 @@ int gmm_model(std::string prefix, std::vector<uint32_t> populations_iterate, std
means_string += std::to_string(means[j]);
}
means_string += "]";
file << "means\n";
file << means_string;

std::vector<double> new_means = determine_clusters(used_variants, means.size());
std::string means_recalc = "[";
for(uint32_t j=0; j < new_means.size(); j++){
if(j != 0) means_recalc += ",";
means_recalc += std::to_string(new_means[j]);
}
means_recalc += "]";


file << "means\trecalculated_means\n";
file << means_string << "\t";
file << means_recalc << "\n";
file.close();


exit(1);

//here we now define the two criteria in which we eliminate things as being contaminated
Expand Down
23 changes: 2 additions & 21 deletions src/ivar.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
#include "trim_primer_quality.h"
#include "gmm.h"
#include "saga.h"
#include "population_estimate.h"

const std::string VERSION = "1.4.2";

Expand Down Expand Up @@ -49,9 +48,6 @@ struct args_t {
bool keep_for_reanalysis; // -k
//contam params
std::string variants; // -s
float evol_rate; // -e
std::string sample_date; // -d
std::string ref_date; // -r
} g_args;

void print_usage() {
Expand All @@ -77,8 +73,7 @@ void print_usage() {

void print_contam_usage() {
std::cout
<< "Usage: ivar contam -e [<evol_rate>] -d <sample_date> [-s <variants>] [-r "
"<ref_date>]\n\n"
<< "Usage: ivar contam \n\n"
"Input Options Description\n"
" -i BAM file, with aligned reads, to "
"trim primers and quality. If not specified will use standard in\n"
Expand Down Expand Up @@ -337,7 +332,6 @@ int main(int argc, char *argv[]) {
g_args.bed = "";
g_args.primer_pair_file = "";
g_args.primer_offset = 0;
g_args.evol_rate = 0.0;
opt = getopt(argc, argv, contam_opt_str);
while (opt != -1) {
switch (opt) {
Expand All @@ -359,15 +353,6 @@ int main(int argc, char *argv[]) {
case 's':
g_args.variants = optarg;
break;
case 'e':
g_args.evol_rate = std::stof(optarg);
break;
case 'd':
g_args.sample_date = optarg;
break;
case 'r':
g_args.ref_date = optarg;
break;
case 'h':
case '?':
print_trim_usage();
Expand All @@ -376,11 +361,7 @@ int main(int argc, char *argv[]) {
}
opt = getopt(argc, argv, contam_opt_str);
}
if (!(g_args.evol_rate == 0.0) && !g_args.variants.empty() && !g_args.sample_date.empty() && !g_args.ref_date.empty()) {
res = estimate_populations(g_args.variants, g_args.evol_rate, g_args.sample_date, g_args.ref_date);
//print_contam_usage();
//return -1;
} else if (!g_args.variants.empty() && !g_args.prefix.empty()) {
if (!g_args.variants.empty() && !g_args.prefix.empty()) {
std::vector<uint32_t> populations_iterate;
for(uint32_t i= 2; i <= 6; i++){
populations_iterate.push_back(i);
Expand Down
Loading

0 comments on commit cf33fb4

Please sign in to comment.