Skip to content

Commit

Permalink
Make ctmc_on_tree take a CharacterData object.
Browse files Browse the repository at this point in the history
  • Loading branch information
bredelings committed Oct 17, 2023
1 parent 281a701 commit f91f6c3
Show file tree
Hide file tree
Showing 9 changed files with 121 additions and 64 deletions.
4 changes: 1 addition & 3 deletions haskell/Bio/Sequence.hs
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,7 @@ reorder_sequences names sequences | length names /= length sequences = error "S
| otherwise = [ sequences_map Map.! name | name <- names ]
where sequences_map = Map.fromList [ (fst sequence, sequence) | sequence <- sequences ]

sequence_length a sequence = vector_size $ sequence_to_indices a sequence

get_sequence_lengths a sequences = Map.fromList [ (fst sequence, sequence_length a sequence) | sequence <- sequences]
get_sequence_lengths a sequenceData = Map.fromList [ (label, vector_size isequence) | (label, isequence) <- getSequences $ sequenceData]

foreign import bpcall "Likelihood:" bitmask_from_sequence :: EVector Int -> CBitVector
foreign import bpcall "Likelihood:" strip_gaps :: EVector Int -> EVector Int
Expand Down
35 changes: 11 additions & 24 deletions haskell/Probability/Distribution/OnTree.hs
Original file line number Diff line number Diff line change
Expand Up @@ -104,13 +104,12 @@ For sampling from phyloCTMCFixedA, we might want two versions:
transition_ps_map smodel_on_tree = IntMap.fromSet (list_to_vector . branch_transition_p smodel_on_tree) edges where
edges = getEdgesSet $ get_tree' smodel_on_tree

annotated_subst_like_on_tree tree alignment smodel sequences' = do
annotated_subst_like_on_tree tree alignment smodel sequenceData = do
let subst_root = modifiable (head $ internal_nodes tree ++ leaf_nodes tree)

let n_nodes = numNodes tree
as = pairwise_alignments alignment
node_sequences = fromMaybe (error "No Label") <$> labelToNodeMap tree sequences
Unaligned (CharacterData _ sequences) = mkUnalignedCharacterData alphabet sequences'
node_sequences = fromMaybe (error "No Label") <$> labelToNodeMap tree (getSequences sequenceData)
alphabet = getAlphabet smodel
smap = stateLetters smodel
smodel_on_tree = SingleBranchLengthModel tree smodel
Expand Down Expand Up @@ -157,7 +156,7 @@ data CTMCOnTree t s = CTMCOnTree t (AlignmentOnTree t) s


instance Dist (CTMCOnTree t s) where
type Result (CTMCOnTree t s) = [Sequence]
type Result (CTMCOnTree t s) = UnalignedCharacterData
dist_name _ = "ctmc_on_tree"

-- TODO: make this work on forests! -
Expand Down Expand Up @@ -201,9 +200,9 @@ instance (IsTree t, HasRoot (Rooted t), HasLabels t, HasBranchLengths (Rooted t)

stateSequences <- sampleComponentStates (makeRooted tree) alignment smodel

let sequenceForNode label stateSequence = (label, sequenceToText alphabet . statesToLetters smap $ extractStates stateSequence)
let sequenceForNode label stateSequence = (label, statesToLetters smap $ extractStates stateSequence)

return $ getLabelled tree sequenceForNode stateSequences
return $ Unaligned $ CharacterData alphabet $ getLabelled tree sequenceForNode stateSequences


----------------------------------------
Expand All @@ -213,29 +212,17 @@ ok, so how do we pass IntMaps to C++ functions?
well, we could turn each IntMap into an EIntMap
for alignments, we could also use an ordering of the sequences to ensure that the leaves are written first.
-}
annotated_subst_likelihood_fixed_A tree smodel sequences = do
annotated_subst_likelihood_fixed_A tree smodel sequenceData = do
let subst_root = modifiable (head $ internal_nodes tree ++ leaf_nodes tree)

let sequence_data = mkAlignedCharacterData alphabet sequences
(isequences, column_counts, mapping) = compress_alignment $ getSequences sequence_data
-- stop going through Alignment
let (isequences, column_counts, mapping) = compress_alignment $ getSequences sequenceData

node_isequences = fromMaybe (error "No label") <$> labelToNodeMap tree isequences
node_seqs_bits = (\seq -> (strip_gaps seq, bitmask_from_sequence seq)) <$> node_isequences
node_sequences = fst <$> node_seqs_bits

node_sequences0 :: IntMap (Maybe (EVector Int))
node_sequences0 = labelToNodeMap tree $ getSequences sequence_data

-- (compressed_node_sequences, column_counts', mapping') = compress_sequences node_sequence0
-- OK, so how are we going to do this?
-- * turn the sequences into an EIntMap (EVector Int)
-- * run the pattern compression on the EIntMap (EVector Int)
-- * return the compressed EIntMap (EVector Int), an EVector Int for the orig->compress mapping, and an EVector Int for the column counts).
-- * convert the EIntMap (EVector Int) back to an IntMap EVector Int for the compressed sequences.

-- OK, so now we need to get the ancestral sequences as an IntMap (EVector Int)
-- So... actually it looks like we already HAVE a minimally_connect_leaf_characters function!!!
node_sequences0 = labelToNodeMap tree $ getSequences sequenceData

n_nodes = numNodes tree
alphabet = getAlphabet smodel
Expand All @@ -260,8 +247,8 @@ annotated_subst_likelihood_fixed_A tree smodel sequences = do

-- This also needs the map from columns to compressed columns:
ancestral_sequences = case n_nodes of
1 -> Text.concat [fastaSeq s | s <- sequences]
2 -> Text.concat [fastaSeq s | s <- sequences]
1 -> Text.concat [fastaSeq $ (label, sequenceToText alphabet s) | (label,s) <- getSequences sequenceData ]
2 -> Text.concat [fastaSeq $ (label, sequenceToText alphabet s) | (label,s) <- getSequences sequenceData ]
_ -> let ancestralComponentStateSequences :: IntMap VectorPairIntInt
ancestralComponentStateSequences = sample_ancestral_sequences_SEV
tree
Expand Down Expand Up @@ -295,7 +282,7 @@ annotated_subst_likelihood_fixed_A tree smodel sequences = do
data CTMCOnTreeFixedA t s = CTMCOnTreeFixedA t s

instance Dist (CTMCOnTreeFixedA t s) where
type Result (CTMCOnTreeFixedA t s) = [Sequence]
type Result (CTMCOnTreeFixedA t s) = AlignedCharacterData
dist_name _ = "ctmc_on_tree_fixed_A"

-- TODO: make this work on forests! -
Expand Down
121 changes: 96 additions & 25 deletions src/models/A-T-prog.cc
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ std::string generate_atmodel_program(const variables_map& args,
set<string> imports;
imports.insert("Bio.Alignment"); // for Alignment.load_alignment
imports.insert("Bio.Alphabet"); // for Bio.Alphabet.dna, etc.
imports.insert("Bio.Sequence"); // for mkAlignedCharacterData, mkUnalignedCharacterData
imports.insert("Effect"); // for getProperties
imports.insert("MCMC"); // for scale_means_only_slice
imports.insert("Probability.Distribution.OnTree"); // for ctmc_on_tree{,fixed_A}
Expand Down Expand Up @@ -228,13 +229,46 @@ std::string generate_atmodel_program(const variables_map& args,
// Therefore, we are constructing a list with values [(prefix1, (Just value1, loggers1)), (prefix2, (Just value2, loggers2))]

// M1. Taxa
// Partitions are classified into n groups.
// Currently n = 2: group 0 holds the unaligned partitions and group 1 the aligned ones.
//   partition_group[i]      -- the group that partition i was assigned to.
//   partition_index[i]      -- the position of partition i within its own group.
//   partition_group_size[g] -- how many partitions have been assigned to group g
//                              (the vector starts out zero-filled).
vector<int> partition_index(n_partitions);
vector<int> partition_group(n_partitions);
vector<int> partition_group_size(2);

for(int i=0;i<n_partitions;i++)
{
// A partition whose i_mapping entry tests true goes into group 0 (unaligned);
// every other partition goes into group 1 (aligned).
int g = (i_mapping[i])?0:1;
partition_group[i] = g;
// The post-increment records the partition's index within the group
// *before* the group's size is bumped, so indices start at 0.
partition_index[i] = partition_group_size[g]++;
}

// Names of the generated-program variables that hold the sequence data for
// each group.  By default both groups use the same name, "sequenceData".
auto unaligned_sequence_data = var("sequenceData");
auto aligned_sequence_data = var("sequenceData");

// Only when BOTH groups are non-empty do the two groups get distinct names,
// so that the unaligned and aligned data can be referred to separately.
if (partition_group_size[0] > 0 and partition_group_size[1] > 0)
{
unaligned_sequence_data = var("unalignedSequenceData");
aligned_sequence_data = var("alignedSequenceData");
}

// Build the expression that refers to the sequence data of partition `p`.
//
// The result is the variable for the partition's group (unaligned for group 0,
// aligned otherwise).  When that group does not contain exactly one partition,
// the variable is additionally subscripted with `!!` by the partition's index
// within the group.
auto getSequenceData = [&](int p) -> expression_ref
{
    assert(n_partitions > 0);

    const int g = partition_group[p];

    expression_ref groupData = (g == 0) ? unaligned_sequence_data : aligned_sequence_data;

    // A group with a single member is bound directly to the data, not to a
    // list, so it is only subscripted when the group size differs from 1.
    if (partition_group_size[g] != 1)
        return {var("!!"), groupData, partition_index[p]};

    return groupData;
};

if (n_partitions > 0)
{
expression_ref sequence_data1 = var("sequenceData");
if (n_partitions > 1)
sequence_data1 = {var("!!"),sequence_data1,0};
program.let(taxon_names_var, {var("map"),var("fst"),sequence_data1});
program.let(taxon_names_var, {var("getTaxa"), getSequenceData(0)});
program.empty_stmt();
}

Expand Down Expand Up @@ -395,9 +429,7 @@ std::string generate_atmodel_program(const variables_map& args,
int smodel_index = *s_mapping[i];
auto imodel_index = i_mapping[i];
expression_ref smodel = smodels[smodel_index];
expression_ref sequence_data_var = var("sequenceData");
if (n_partitions > 1)
sequence_data_var = {var("!!"),sequence_data_var,i};
expression_ref sequence_data_var = getSequenceData(i);

// Model.Partition.1. tree_part<i> = scale_branch_lengths scale tree
var branch_dist_tree("tree" + part_suffix);
Expand Down Expand Up @@ -528,19 +560,29 @@ std::string generate_atmodel_program(const variables_map& args,
for(auto& [i,a,l]: alignments)
alignment_loggers.push_back(l);

var sequences("sequences");
auto model = var("model");
auto sequence_data = var("sequenceData");
auto topology = var("topology");

expression_ref model_fn = model;

// Pass in the sequence data for the two groups.
if (partition_group_size[0] > 0)
model_fn = {model_fn, unaligned_sequence_data};
if (partition_group_size[1] > 0)
model_fn = {model_fn, aligned_sequence_data};

// Pass in the fixed tree or topology
auto tree = var("tree");
var jsonLogger("logParamsJSON");
var tsvLogger("logParamsTSV");
auto treeLogger = var("logTree");
expression_ref model_fn = {model,sequence_data};
var loggers_var("loggers");
auto topology = var("topology");
if (fixed.count("tree"))
model_fn = {model_fn,tree};
model_fn = {model_fn, tree};
else if (fixed.count("topology"))
model_fn = {model_fn, topology};

// Pass in the loggers
var jsonLogger("logParamsJSON");
var tsvLogger("logParamsTSV");
auto treeLogger = var("logTree");
if (not args.count("test"))
{
if (log_formats.count("tsv"))
Expand All @@ -553,6 +595,7 @@ std::string generate_atmodel_program(const variables_map& args,
model_fn = {model_fn, get_list(alignment_loggers)};
}

var loggers_var("loggers");
program.let(loggers_var, get_list(program_loggers));
program.empty_stmt();

Expand Down Expand Up @@ -597,15 +640,26 @@ std::string generate_atmodel_program(const variables_map& args,
if (not args.count("test"))
main.perform(get_list(prog_args), var("getArgs"));

auto unaligned_partitions = unaligned_sequence_data;
auto aligned_partitions = aligned_sequence_data;
if (n_partitions == 1)
{
auto [filename, range] = filename_ranges[0];
expression_ref E = {var("load_sequences"),String(filename.string())};
if (not range.empty())

// Load the sequences
expression_ref E = {var("load_sequences"),String(filename.string())};

// Select range
if (not range.empty())
E = {var("<$>"), {var("select_range"),String(range)}, E};
main.empty_stmt();

main.perform(sequence_data, E);
// Convert to CharacterData
if (i_mapping[0])
E = {var("<$>"),{var("mkUnalignedCharacterData"),alphabet_exps[0]}, E};
else
E = {var("<$>"),{var("mkAlignedCharacterData"),alphabet_exps[0]}, E};

main.perform(var("sequenceData"), E);
}
else
{
Expand All @@ -628,9 +682,6 @@ std::string generate_atmodel_program(const variables_map& args,
main.let(filenames_var,get_list(filenames_));
}

if (index_for_filename.size() == n_partitions and not any_ranges)
main.perform(sequence_data,{var("mapM"), var("load_sequences"), filenames_var});
else
{
// Main.2: Emit let filenames_to_seqs = ...
var filename_to_seqs("seqs");
Expand All @@ -640,22 +691,42 @@ std::string generate_atmodel_program(const variables_map& args,
main.empty_stmt();

// Main.3. Emit let sequence_data<n> =
vector<expression_ref> partition_sequence_data;
vector<var> unaligned_sequence_partitions;
vector<var> aligned_sequence_partitions;
for(int i=0;i<n_partitions;i++)
{
int group = partition_group[i];
string part = std::to_string(i+1);

var partition_sequence_data_var("sequenceData"+part);
if (partition_group_size[group] == 1)
partition_sequence_data_var = (group==0) ? unaligned_sequence_data : aligned_sequence_data;

int index = index_for_filename.at( filename_ranges[i].first );
expression_ref loaded_sequences = {var("!!"),filename_to_seqs,index};
if (not filename_ranges[i].second.empty())
loaded_sequences = {var("select_range"), String(filename_ranges[i].second), loaded_sequences};
if (i_mapping[i])
{
loaded_sequences = {var("mkUnalignedCharacterData"),alphabet_exps[i],loaded_sequences};
unaligned_sequence_partitions.push_back(partition_sequence_data_var);
}
else
{
loaded_sequences = {var("mkAlignedCharacterData"),alphabet_exps[i],loaded_sequences};
aligned_sequence_partitions.push_back(partition_sequence_data_var);
}
main.let(partition_sequence_data_var, loaded_sequences);
partition_sequence_data.push_back(partition_sequence_data_var);
main.empty_stmt();
}

// Main.4. Emit let sequence_data = ...
main.let(sequence_data, get_list(partition_sequence_data));
if (unaligned_sequence_partitions.size() > 1)
main.let(unaligned_partitions, get_list(unaligned_sequence_partitions));

if (aligned_sequence_partitions.size() > 1)
main.let(aligned_partitions, get_list(aligned_sequence_partitions));

main.empty_stmt();
}
}
Expand Down
5 changes: 3 additions & 2 deletions tests/prob_prog/infer_tree/1/Model.hs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ module Model where
import Probability
import Bio.Alignment
import Bio.Alphabet
import Bio.Sequence
import Tree
import Tree.Newick
import SModel
Expand All @@ -13,7 +14,7 @@ import System.Environment -- for getArgs
branch_length_dist topology branch = gamma (1/2) (2/fromIntegral n) where n = numBranches topology

model seq_data = do
let taxa = map fst seq_data
let taxa = getTaxa seq_data
tip_seq_lengths = get_sequence_lengths dna seq_data

-- Tree
Expand Down Expand Up @@ -54,6 +55,6 @@ model seq_data = do
main = do
[filename] <- getArgs

seq_data <- load_sequences filename
seq_data <- mkUnalignedCharacterData dna <$> load_sequences filename

return $ model seq_data
4 changes: 2 additions & 2 deletions tests/prob_prog/infer_tree/2/Model.hs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ branch_length_dist topology b = gamma 0.5 (2 / fromIntegral n) where n = numBran

model seq_data = do

let taxa = map fst seq_data
let taxa = getTaxa seq_data

scale1 <- prior $ gamma 0.5 2

Expand All @@ -39,6 +39,6 @@ model seq_data = do
main = do
[filename] <- getArgs

seq_data <- load_sequences filename
seq_data <- mkAlignedCharacterData dna <$> load_sequences filename

return $ model seq_data
4 changes: 2 additions & 2 deletions tests/prob_prog/infer_tree/3/Model.hs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ branch_length_dist topology b = gamma 0.5 (2.0 / fromIntegral n) where n = numBr

model seq_data = do

let taxa = map fst seq_data
let taxa = getTaxa seq_data

scale <- prior $ gamma 0.5 2.0

Expand All @@ -50,6 +50,6 @@ model seq_data = do
main = do
[filename] <- getArgs

seq_data <- load_sequences filename
seq_data <- mkAlignedCharacterData dna <$> load_sequences filename

return $ model seq_data
4 changes: 2 additions & 2 deletions tests/prob_prog/infer_tree/4/Model.hs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import System.Environment -- for getArgs

model seq_data = do

let taxa = zip [0..] $ map fst seq_data
let taxa = zip [0..] $ getTaxa seq_data

age <- sample $ gamma 0.5 2
tree <- add_labels taxa <$> sample (uniform_time_tree age (length taxa))
Expand All @@ -33,6 +33,6 @@ model seq_data = do
main = do
[filename] <- getArgs

seq_data <- load_sequences filename
seq_data <- mkAlignedCharacterData dna <$> load_sequences filename

return $ model seq_data
4 changes: 2 additions & 2 deletions tests/prob_prog/infer_tree/5/Model.hs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ tree_prior taxa = do


model seq_data = do
let taxa = zip [0..] $ map fst seq_data
let taxa = zip [0..] $ getTaxa seq_data

(tree , tree_loggers) <- tree_prior taxa

Expand All @@ -46,6 +46,6 @@ model seq_data = do
main = do
[filename] <- getArgs

seq_data <- load_sequences filename
seq_data <- mkAlignedCharacterData dna <$> load_sequences filename

return $ model seq_data
Loading

0 comments on commit f91f6c3

Please sign in to comment.