Add t-SNE embeddings and clusters
Adds rules, scripts, and config to produce a joint t-SNE embedding per
build from all gene segments and to find clusters in the resulting
embedding. When the user defines the `embedding` key in their build
config, the workflow calculates pairwise distances per gene segment,
runs t-SNE on those distances, finds clusters in the embedding with
HDBSCAN, and exports the embedding coordinates and cluster labels to
the Auspice JSON.
huddlej committed Aug 28, 2024
1 parent 0984ad8 commit 4f31eb8
Showing 8 changed files with 231 additions and 6 deletions.
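
The feature is opt-in: the new rules read an optional `embedding` key from the build config, fall back to a t-SNE perplexity of 200 when it is not set, and the export step only picks up the cluster node data when the key is present. A minimal sketch of that lookup logic, assuming a plain Python dict in place of Snakemake's config object:

# Minimal sketch of the opt-in config lookup used by the new rules
# (simplified from the Snakefile diff below; `config` is a plain dict
# standing in for Snakemake's config object).
config = {"embedding": {"perplexity": 200}}  # omit "embedding" entirely to skip the new rules

# t-SNE perplexity passed to pathogen-embed, defaulting to 200 when the
# key or sub-key is missing.
perplexity = config.get("embedding", {}).get("perplexity", 200)

# The cluster node data JSON is added to the Auspice export inputs only
# when the user has defined the `embedding` key.
embedding_enabled = bool(config.get("embedding"))
print(embedding_enabled, perplexity)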
103 changes: 103 additions & 0 deletions Snakefile
@@ -490,6 +490,105 @@ rule cleavage_site:
            --cleavage_site_sequence {output.cleavage_site_sequences}
        """

rule get_strains_in_alignment:
    input:
        alignment = "results/{subtype}/{segment}/{time}/aligned.fasta",
    output:
        alignment_strains = "results/{subtype}/{segment}/{time}/aligned_strains.txt",
    shell:
        """
        seqkit fx2tab -n -i {input.alignment} | sort -k 1,1 > {output.alignment_strains}
        """

rule get_shared_strains_in_alignments:
    input:
        alignment_strains = expand("results/{{subtype}}/{segment}/{{time}}/aligned_strains.txt", segment=config["segments"]),
    output:
        shared_strains = "results/{subtype}/all/{time}/shared_strains_in_alignment.txt",
    shell:
        """
        python3 scripts/intersect_items.py \
            --items {input.alignment_strains:q} \
            --output {output.shared_strains}
        """

rule select_shared_strains_from_alignment_and_sort:
    input:
        shared_strains = "results/{subtype}/all/{time}/shared_strains_in_alignment.txt",
        alignment = "results/{subtype}/{segment}/{time}/aligned.fasta",
    output:
        alignment = "results/{subtype}/{segment}/{time}/aligned.sorted.fasta",
    shell:
        """
        seqkit grep -f {input.shared_strains} {input.alignment} \
            | seqkit sort -n > {output.alignment}
        """

rule calculate_pairwise_distances:
    input:
        alignment = "results/{subtype}/{segment}/{time}/aligned.sorted.fasta",
    output:
        distances = "results/{subtype}/{segment}/{time}/distances.csv",
    benchmark:
        "benchmarks/calculate_pairwise_distances_{subtype}_{segment}_{time}.txt"
    shell:
        """
        pathogen-distance \
            --alignment {input.alignment} \
            --output {output.distances}
        """

rule embed_with_tsne:
    input:
        alignments = expand("results/{{subtype}}/{segment}/{{time}}/aligned.sorted.fasta", segment=config["segments"]),
        distances = expand("results/{{subtype}}/{segment}/{{time}}/distances.csv", segment=config["segments"]),
    output:
        embedding = "results/{subtype}/all/{time}/embed_tsne.csv",
    params:
        perplexity=config.get("embedding", {}).get("perplexity", 200),
    benchmark:
        "benchmarks/embed_with_tsne_{subtype}_{time}.txt"
    shell:
        """
        pathogen-embed \
            --alignment {input.alignments} \
            --distance-matrix {input.distances} \
            --output-dataframe {output.embedding} \
            t-sne \
            --perplexity {params.perplexity}
        """

rule cluster_tsne_embedding:
    input:
        embedding = "results/{subtype}/all/{time}/embed_tsne.csv",
    output:
        clusters = "results/{subtype}/all/{time}/cluster_embed_tsne.csv",
    params:
        label_attribute="tsne_cluster",
        distance_threshold=1.0,
    benchmark:
        "benchmarks/cluster_tsne_embedding_{subtype}_{time}.txt"
    shell:
        """
        pathogen-cluster \
            --embedding {input.embedding} \
            --label-attribute {params.label_attribute:q} \
            --distance-threshold {params.distance_threshold} \
            --output-dataframe {output.clusters}
        """

rule convert_embedding_clusters_to_node_data:
    input:
        clusters = "results/{subtype}/all/{time}/cluster_embed_tsne.csv",
    output:
        node_data = "results/{subtype}/all/{time}/cluster_embed_tsne.json",
    shell:
        """
        python3 scripts/table_to_node_data.py \
            --table {input.clusters} \
            --output {output.node_data}
        """

def export_node_data_files(wildcards):
    nd = [
        rules.refine.output.node_data,
@@ -502,6 +601,10 @@ def export_node_data_files(wildcards):

    if wildcards.subtype=="h5n1-cattle-outbreak" and wildcards.segment!='genome':
        nd.append(rules.prune_tree.output.node_data)

    if config.get("embedding"):
        nd.append(rules.convert_embedding_clusters_to_node_data.output.node_data)

    return nd


5 changes: 4 additions & 1 deletion config/gisaid.yaml
@@ -32,7 +32,7 @@ local_ingest: false

#### Parameters which control large overarching aspects of the build
target_sequences_per_tree: 3000
-same_strains_per_segment: false
+same_strains_per_segment: true


#### Config files ####
@@ -159,6 +159,9 @@ traits:
  confidence:
    FALLBACK: true

embedding:
  perplexity: 200

export:
  title:
    FALLBACK: false # use the title in the auspice JSON
20 changes: 18 additions & 2 deletions config/h5n1/auspice_config_h5n1.json
@@ -39,7 +39,7 @@
"key": "division",
"title": "Admin Division",
"type": "categorical"
},
},
{
"key": "host",
"title": "Host",
@@ -89,6 +89,21 @@
"key": "submitting_lab",
"title": "Submitting Lab",
"type": "categorical"
},
{
"key": "tsne_x",
"title": "t-SNE 1",
"type": "continuous"
},
{
"key": "tsne_y",
"title": "t-SNE 2",
"type": "continuous"
},
{
"key": "tsne_cluster",
"title": "t-SNE cluster",
"type": "categorical"
}
],
"geo_resolutions": [
@@ -111,6 +126,7 @@
"gisaid_clade",
"authors",
"originating_lab",
"submitting_lab"
"submitting_lab",
"tsne_cluster"
]
}
18 changes: 17 additions & 1 deletion config/h5nx/auspice_config_h5nx.json
@@ -89,6 +89,21 @@
"key": "submitting_lab",
"title": "Submitting Lab",
"type": "categorical"
},
{
"key": "tsne_x",
"title": "t-SNE 1",
"type": "continuous"
},
{
"key": "tsne_y",
"title": "t-SNE 2",
"type": "continuous"
},
{
"key": "tsne_cluster",
"title": "t-SNE cluster",
"type": "categorical"
}
],
"geo_resolutions": [
@@ -111,6 +126,7 @@
"gisaid_clade",
"authors",
"originating_lab",
"submitting_lab"
"submitting_lab",
"tsne_cluster"
]
}
18 changes: 17 additions & 1 deletion config/h7n9/auspice_config_h7n9.json
@@ -54,6 +54,21 @@
"key": "submitting_lab",
"title": "Submitting Lab",
"type": "categorical"
},
{
"key": "tsne_x",
"title": "t-SNE 1",
"type": "continuous"
},
{
"key": "tsne_y",
"title": "t-SNE 2",
"type": "continuous"
},
{
"key": "tsne_cluster",
"title": "t-SNE cluster",
"type": "categorical"
}
],
"geo_resolutions": [
@@ -70,6 +85,7 @@
"country",
"division",
"originating_lab",
"submitting_lab"
"submitting_lab",
"tsne_cluster"
]
}
18 changes: 17 additions & 1 deletion config/h9n2/auspice_config_h9n2.json
@@ -54,6 +54,21 @@
"key": "submitting_lab",
"title": "Submitting Lab",
"type": "categorical"
},
{
"key": "tsne_x",
"title": "t-SNE 1",
"type": "continuous"
},
{
"key": "tsne_y",
"title": "t-SNE 2",
"type": "continuous"
},
{
"key": "tsne_cluster",
"title": "t-SNE cluster",
"type": "categorical"
}
],
"geo_resolutions": [
@@ -69,6 +84,7 @@
"region",
"country",
"originating_lab",
"submitting_lab"
"submitting_lab",
"tsne_cluster"
]
}
23 changes: 23 additions & 0 deletions scripts/intersect_items.py
@@ -0,0 +1,23 @@
#!/usr/bin/env python3
import argparse


if __name__ == '__main__':
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("--items", nargs="+", required=True, help="one or more files containing a list of items")
    parser.add_argument("--output", required=True, help="list of items shared by all input files (the intersection)")

    args = parser.parse_args()

    with open(args.items[0], "r", encoding="utf-8") as fh:
        shared_items = {line.strip() for line in fh}

    for item_file in args.items[1:]:
        with open(item_file, "r", encoding="utf-8") as fh:
            items = {line.strip() for line in fh}

        shared_items = shared_items & items

    with open(args.output, "w", encoding="utf-8") as oh:
        for item in sorted(shared_items):
            print(item, file=oh)
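
A hypothetical illustration of the set logic in this script: only strains present in every per-segment list survive, and they are written out sorted.

# Hypothetical example of the intersection intersect_items.py performs
# (strain names are made up).
ha_strains = {"A/duck/Example/1/2023", "A/goose/Example/2/2024"}
na_strains = {"A/duck/Example/1/2023"}
shared = sorted(ha_strains & na_strains)
assert shared == ["A/duck/Example/1/2023"]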
32 changes: 32 additions & 0 deletions scripts/table_to_node_data.py
@@ -0,0 +1,32 @@
"""Create Augur-compatible node data JSON from a pandas data frame.
"""
import argparse
import pandas as pd
from augur.utils import write_json


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--table", help="table to convert to a node data JSON")
parser.add_argument("--index-column", default="strain", help="name of the column to use as an index")
parser.add_argument("--delimiter", default=",", help="separator between columns in the given table")
parser.add_argument("--node-name", default="nodes", help="name of the node data attribute in the JSON output")
parser.add_argument("--output", help="node data JSON file")

args = parser.parse_args()

if args.output is not None:
table = pd.read_csv(
args.table,
sep=args.delimiter,
index_col=args.index_column,
dtype=str,
)

# # Convert columns that aren't strain names or labels to floats.
# for column in table.columns:
# if column != "strain" and not "label" in column:
# table[column] = table[column].astype(float)

table_dict = table.transpose().to_dict()
write_json({args.node_name: table_dict}, args.output)
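
For reference, the node data this script writes for the t-SNE cluster table above has roughly the following shape; the strain name and values below are hypothetical, and the values stay strings because the table is read with dtype=str.

# Hypothetical shape of the node data written for cluster_embed_tsne.csv;
# keys under the node-name attribute ("nodes" by default) are strain names.
example_node_data = {
    "nodes": {
        "A/chicken/Example/1/2024": {
            "tsne_x": "12.3",
            "tsne_y": "-4.5",
            "tsne_cluster": "0",
        },
    },
}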
