diff --git a/README.md b/README.md
index 9c505f6b..43c6dd1f 100644
--- a/README.md
+++ b/README.md
@@ -116,6 +116,9 @@ optional arguments:
   --label-fields LABEL_FIELDS
                         Comma separated string of fields to add to tree report labels.
+  --date-fields DATE_FIELDS
+                        Comma separated string of date information to include.
+                        Default = "sample_date"
   --node-summary COLUMN
                         Choose which metadata column to summarise collapsed nodes by.
   --search-field SEARCH_FIELD
                         Option to search COG database for a different id type.
diff --git a/civet/command.py b/civet/command.py
index 4a8f63fd..7405ad7e 100644
--- a/civet/command.py
+++ b/civet/command.py
@@ -45,6 +45,7 @@ def main(sysargs = sys.argv[1:]):
     parser.add_argument('--fields', action="store",help="Comma separated string of fields to display in the trees in the report. Default: country")
     parser.add_argument('--display', action="store", help="Comma separated string of fields to display as coloured dots rather than text in report trees. Optionally add colour scheme eg adm1=viridis", dest="display")
     parser.add_argument('--label-fields', action="store", help="Comma separated string of fields to add to tree report labels.", dest="label_fields")
+    parser.add_argument("--date-fields", action="store", help="Comma separated string of metadata headers containing date information.", dest="date_fields")
     parser.add_argument("--node-summary", action="store", help="Column to summarise collapsed nodes by. Default = Global lineage", dest="node_summary")
     parser.add_argument('--search-field', action="store",help="Option to search COG database for a different id type. Default: COG-UK ID", dest="search_field",default="central_sample_id")
     parser.add_argument('--distance', action="store",help="Extraction from large tree radius. Default: 2", dest="distance",default=2)
@@ -68,9 +69,8 @@ def main(sysargs = sys.argv[1:]):
     parser.add_argument('--date-window',action="store",default=7, type=int, dest="date_window",help="Define the window +- either side of cluster sample collection date-range. Default is 7 days.")
     parser.add_argument("-v","--version", action='version', version=f"civet {__version__}")
-    parser.add_argument("--map-sequences", action="store_true", dest="map_sequences", help="Map the coordinate points of sequences, coloured by a triat.")
-    parser.add_argument("--x-col", required=False, dest="x_col", help="column containing x coordinate for mapping sequences")
-    parser.add_argument("--y-col", required=False, dest="y_col", help="column containing y coordinate for mapping sequences")
+    parser.add_argument("--map-sequences", action="store_true", dest="map_sequences", help="Map the coordinate points of sequences, coloured by a trait.")
+    parser.add_argument("--map-inputs", required=False, dest="map_inputs", help="columns containing EITHER x and y coordinates as a comma separated string OR outer postcodes for mapping sequences")
     parser.add_argument("--input-crs", required=False, dest="input_crs", help="Coordinate reference system of sequence coordinates")
     parser.add_argument("--mapping-trait", required=False, dest="mapping_trait", help="Column to colour mapped sequences by")
@@ -91,6 +91,9 @@ def main(sysargs = sys.argv[1:]):
                     'gnuplot', 'gnuplot2', 'CMRmap', 'cubehelix', 'brg',
                     'gist_rainbow', 'rainbow', 'jet', 'nipy_spectral', 'gist_ncar']
 
+    full_metadata_headers = ["central_sample_id", "biosample_source_id","sequence_name","secondary_identifier","sample_date","epi_week","country","adm1","adm2","outer_postcode","is_surveillance","is_community","is_hcw","is_travel_history","travel_history","lineage","lineage_support","uk_lineage","acc_lineage","del_lineage","phylotype"]
+
+
     # Exit with help menu if no args supplied
     if len(sysargs)<1:
         parser.print_help()
@@ -178,6 +181,7 @@ def main(sysargs = sys.argv[1:]):
     labels = []
     queries = []
     graphics_list = []
+    date_lst = []
 
     with open(query, newline="") as f:
         reader = csv.DictReader(f)
@@ -192,10 +196,10 @@ def main(sysargs = sys.argv[1:]):
         else:
             desired_fields = args.fields.split(",")
             for field in desired_fields:
-                if field in reader.fieldnames:
+                if field in reader.fieldnames or field in full_metadata_headers:
                     fields.append(field)
                 else:
-                    sys.stderr.write(f"Error: {field} field not found in metadata file")
+                    sys.stderr.write(f"Error: {field} field not found in metadata file or full metadata file")
                     sys.exit(-1)
@@ -204,10 +208,21 @@ def main(sysargs = sys.argv[1:]):
         else:
             label_fields = args.label_fields.split(",")
             for label_f in label_fields:
-                if label_f in reader.fieldnames:
+                if label_f in reader.fieldnames or label_f in full_metadata_headers:
                     labels.append(label_f)
                 else:
-                    sys.stderr.write(f"Error: {label_f} field not found in metadata file")
+                    sys.stderr.write(f"Error: {label_f} field not found in metadata file or full metadata file")
+                    sys.exit(-1)
+
+        if not args.date_fields:
+            date_lst.append("NONE")
+        else:
+            date_fields = args.date_fields.split(",")
+            for date_f in date_fields:
+                if date_f in reader.fieldnames or date_f in full_metadata_headers:
+                    date_lst.append(date_f)
+                else:
+                    sys.stderr.write(f"Error: {date_f} field not found in query metadata file or full metadata file")
                     sys.exit(-1)
 
     if args.display:
@@ -251,6 +266,7 @@ def main(sysargs = sys.argv[1:]):
             "query":query,
             "fields":",".join(fields),
             "label_fields":",".join(labels),
+            "date_fields":",".join(date_lst),
             "outdir":outdir,
             "tempdir":tempdir,
             "trim_start":265, # where to pad to using datafunk
@@ -269,26 +285,34 @@ def main(sysargs = sys.argv[1:]):
         config["add_boxplots"]= True
     else:
         config["add_boxplots"]= False
-
+
     if args.map_sequences:
         config["map_sequences"] = True
-        if not args.x_col or not args.y_col:
-            sys.stderr.write('Error: coordinates not supplied for mapping sequences. Please provide --x-col and --y-col')
-            sys.exit(-1)
-        elif not args.input_crs:
-            sys.stderr.write('Error: input coordinate system not provided for mapping. Please provide --input-crs eg EPSG:3395')
+        if not args.map_inputs:
+            sys.stderr.write('Error: coordinates or outer postcode not supplied for mapping sequences. Please provide either x and y columns as a comma separated string, or column header containing outer postcode.')
             sys.exit(-1)
+        if len(args.map_inputs.split(",")) == 2:
+            if not args.input_crs:
+                sys.stderr.write('Error: input coordinate system not provided for mapping. Please provide --input-crs eg EPSG:3395')
+                sys.exit(-1)
+            else:
+                input_crs = args.input_crs
         else:
-            config["x_col"] = args.x_col
-            config["y_col"] = args.y_col
-            config["input_crs"] = args.input_crs
+            input_crs = "EPSG:4326"
+
+        config["map_cols"] = args.map_inputs
+        config["input_crs"] = input_crs
+        relevant_cols = []
 
         with open(query, newline="") as f:
             reader = csv.DictReader(f)
            column_names = reader.fieldnames
 
-            relevant_cols = [args.x_col, args.y_col, args.mapping_trait]
+            map_cols = args.map_inputs.split(",")
+            for i in map_cols:
+                relevant_cols.append(i)
+            relevant_cols.append(args.mapping_trait)
+
             for map_arg in relevant_cols:
-
                 if map_arg and map_arg not in reader.fieldnames:
                     sys.stderr.write(f"Error: {map_arg} field not found in metadata file")
                     sys.exit(-1)
@@ -300,8 +324,7 @@ def main(sysargs = sys.argv[1:]):
 
     else:
         config["map_sequences"] = False
-        config["x_col"] = False
-        config["y_col"] = False
+        config["map_cols"] = False
         config["input_crs"] = False
         config["mapping_trait"] = False
@@ -516,6 +539,7 @@ def main(sysargs = sys.argv[1:]):
     map_input_3 = pkg_resources.resource_filename('civet', 'data/mapping_files/NI_counties.geojson')
     map_input_4 = pkg_resources.resource_filename('civet', 'data/mapping_files/Mainland_HBs_gapclosed_mapshaped_d3.json')
     map_input_5 = pkg_resources.resource_filename('civet', 'data/mapping_files/urban_areas_UK.geojson')
+    map_input_6 = pkg_resources.resource_filename('civet', 'data/mapping_files/UK_outPC_coords.csv')
     spatial_translations_1 = pkg_resources.resource_filename('civet', 'data/mapping_files/HB_Translation.pkl')
     spatial_translations_2 = pkg_resources.resource_filename('civet', 'data/mapping_files/adm2_regions_to_coords.csv')
@@ -555,6 +579,7 @@ def main(sysargs = sys.argv[1:]):
     config["ni_map"] = map_input_3
     config["uk_map_d3"] = map_input_4
     config["urban_centres"] = map_input_5
+    config["pc_file"] = map_input_6
     config["HB_translations"] = spatial_translations_1
     config["PC_translations"] = spatial_translations_2
diff --git a/civet/scripts/COG_template.pmd b/civet/scripts/COG_template.pmd
index e75267b1..bd8237ef 100644
--- a/civet/scripts/COG_template.pmd
+++ b/civet/scripts/COG_template.pmd
@@ -193,7 +193,7 @@ else:
 
 Their location is shown on the following map, with dots sized by the number of sequences from each location.
```python, name="make_map", echo=False, include=False, results='tex' -map = mapping.run_map_functions(query_dict, clean_locs_file, mapping_json_files) +map = mapping.map_adm2(query_dict, clean_locs_file, mapping_json_files) ``` ```python, name="show_map", echo=False, results='raw' diff --git a/civet/scripts/assess_input_file.smk b/civet/scripts/assess_input_file.smk index c99a985d..c9911aa3 100644 --- a/civet/scripts/assess_input_file.smk +++ b/civet/scripts/assess_input_file.smk @@ -285,16 +285,17 @@ rule find_snps: tempdir= config["tempdir"], path = workflow.current_basedir, threshold = config["threshold"], - + fasta = config["fasta"], tree_dir = os.path.join(config["outdir"],"local_trees"), cores = workflow.cores, force = config["force"], quiet_mode = config["quiet_mode"] + output: - genome_graph = os.path.join(config["outdir"],"figures","genome_graph.png"), - report = os.path.join(config["outdir"],"snp_reports","snp_reports.txt") + genome_graphs = os.path.join(config["outdir"],"snp_reports","tree_subtree_1.snps.txt"), #this obviously isn't ideal because it's not flexible to name stem changes + reports = os.path.join(config["outdir"],"figures","genome_graph_tree_subtree_1.png") run: local_trees = [] for r,d,f in os.walk(params.tree_dir): @@ -410,9 +411,10 @@ rule make_report: uk_map = config["uk_map"], channels_map = config["channels_map"], ni_map = config["ni_map"], + pc_file = config["pc_file"], urban_centres = config["urban_centres"], - genome_graph = rules.find_snps.output.genome_graph, - snp_report = rules.find_snps.output.report, + genome_graph = rules.find_snps.output.genome_graphs, #do these two arguments need to be here? + snp_report = rules.find_snps.output.reports, central = os.path.join(config["outdir"], 'figures', "central_map_ukLin.png"), neighboring = os.path.join(config["outdir"], 'figures', "neighboring_map_ukLin.png"), region = os.path.join(config["outdir"], 'figures', "region_map_ukLin.png") @@ -421,6 +423,7 @@ rule make_report: outdir = config["rel_outdir"], fields = config["fields"], label_fields = config["label_fields"], + date_fields = config["date_fields"], node_summary = config["node_summary"], sc_source = config["sequencing_centre"], sc = config["sequencing_centre_file"], @@ -430,8 +433,7 @@ rule make_report: figdir = os.path.join(config["outdir"],"figures"), failure = config["qc_fail_report"], map_sequences = config["map_sequences"], - x_col = config["x_col"], - y_col = config["y_col"], + map_cols = config["map_cols"], input_crs = config["input_crs"], mapping_trait = config["mapping_trait"], add_boxplots = config["add_boxplots"], @@ -473,6 +475,7 @@ rule make_report: "-f {params.fields:q} " "--graphic_dict {params.graphic_dict:q} " "--label-fields {params.label_fields:q} " + "--date-fields {params.date_fields:q} " "--node-summary {params.node_summary} " "--figdir {params.rel_figdir:q} " "{params.sc_flag} " @@ -485,12 +488,12 @@ rule make_report: "--uk-map {input.uk_map:q} " "--channels-map {input.channels_map:q} " "--ni-map {input.ni_map:q} " + "--pc-file {input.pc_file:q} " "--outfile {output.outfile:q} " "--outdir {params.outdir:q} " "--map-sequences {params.map_sequences} " "--snp-report {input.snp_report:q} " - "--x-col {params.x_col} " - "--y-col {params.y_col} " + "--map-cols {params.map_cols} " "--input-crs {params.input_crs} " "--mapping-trait {params.mapping_trait} " "--urban-centres {input.urban_centres} " diff --git a/civet/scripts/civet_template.pmd b/civet/scripts/civet_template.pmd index 15818cbc..496e8e43 100644 --- 
a/civet/scripts/civet_template.pmd +++ b/civet/scripts/civet_template.pmd @@ -25,6 +25,7 @@ import os import data_parsing as dp import make_tree_figures as tree_viz import mapping as mapping +import time_functions as time_func import matplotlib.font_manager as font_manager import matplotlib as mpl from collections import defaultdict @@ -66,6 +67,7 @@ label_fields_input = "" ##CHANGE graphic_dict_input = "" ##CHANGE summary_dir = "" ##CHANGE node_summary_option = "" ##CHANGE +date_fields_input = ##CHANGE tree_name_stem = "" ##CHANGE @@ -89,7 +91,11 @@ if label_fields_input != "NONE": for i in options: label_fields.append(i) - +date_fields = [] +if date_fields_input != "NONE": + options = date_fields_input.split(",") + for i in options: + date_fields.append(i) snp_report = "" ##CHANGE @@ -97,15 +103,15 @@ local_lineages = "" ##CHANGE local_lin_maps = "" ##CHANGE local_lin_tables = "" ##CHANGE -map_sequences = "" ##CHANGE #going to be true or false -x_col = "" ##CHANGE -y_col = "" ##CHANGE +map_sequences = "" ##CHANGE +map_inputs = "" ##CHANGE mapping_trait = "" ##CHANGE input_crs = "" ##CHANGE uk_map = "" ##CHANGE channels_map = "" ##CHANGE ni_map = "" ##CHANGE urban_centres = "" ##CHANGE +pc_file = "" ##CHANGE mapping_json_files = [uk_map, channels_map, ni_map] @@ -124,7 +130,7 @@ pyplot.rcParams.update({'figure.max_open_warning': 0}) mpl.rcParams['font.weight']=50 mpl.rcParams['axes.labelweight']=50 -inputs = [desired_fields, label_fields, graphic_dict, node_summary_option, map_sequences, x_col, y_col, mapping_trait] +inputs = [desired_fields, label_fields, graphic_dict, node_summary_option, map_sequences, mapping_trait, map_inputs] ``` @@ -132,20 +138,27 @@ inputs = [desired_fields, label_fields, graphic_dict, node_summary_option, map_s ```python, name="parse metadata", echo=False, results='raw' -present_in_tree, tip_to_tree = dp.parse_tree_tips(tree_dir) +present_in_tree, tip_to_tree, tree_list = dp.parse_tree_tips(tree_dir) -query_dict, query_id_dict, present_lins, tree_to_tip = dp.parse_filtered_metadata(filtered_cog_metadata, tip_to_tree) #Just the lines with their queries plus the closest match in COG +query_dict, query_id_dict, present_lins, tree_to_tip = dp.parse_filtered_metadata(filtered_cog_metadata, tip_to_tree, label_fields, desired_fields) #Just the lines with their queries plus the closest match in COG if input_csv != '': adm2_to_adm1 = dp.prepping_adm2_adm1_data(full_metadata_file) - query_dict = dp.parse_input_csv(input_csv, query_id_dict, desired_fields, label_fields, adm2_to_adm1, False) #Any query information they have provided + query_dict = dp.parse_input_csv(input_csv, query_id_dict, desired_fields, label_fields, date_fields, adm2_to_adm1, False) #Any query information they have provided + +full_tax_dict = dp.parse_full_metadata(query_dict, label_fields, desired_fields, full_metadata_file, present_lins, present_in_tree, node_summary_option, date_fields) -full_tax_dict = dp.parse_full_metadata(query_dict, full_metadata_file, present_lins, present_in_tree, node_summary_option) + +time_outputs = time_func.summarise_dates(query_dict) +if type(time_outputs) != bool: + overall_dates, max_overall_date, min_overall_date, max_string, min_string = time_outputs + dates_present = True +else: + dates_present = False ``` ```python, name="QC fails", echo=False, results="raw" count_fails = 0 - if QC_fail_file != "": with open(QC_fail_file) as f: next(f) @@ -175,13 +188,21 @@ for tax in query_dict.values(): try: dp.analyse_inputs(inputs) except: - print("inputs failed to print.") + 
print("Failed to print inputs.") + print("\n") print(str(number_seqs) + " queries (" + str(cog_number) + " matched to COG-UK database).") print(str(not_in_cog_number) + " additional sequences provided.") +if dates_present: + print("Time fields provided: " + ",".join(date_fields)) + print("Earliest date: " + min_string) + print("Latest date: " + max_string) +else: + print("No time information provided") + ``` @@ -204,10 +225,9 @@ elif not_in_cog_number == 0 and cog_number != 0: print(df_cog.to_markdown()) ``` - ```python, name="make_trees", echo=False, include=False, figure=False -too_tall_trees, overall_tree_number, colour_dict_dict, overall_df_dict = tree_viz.make_all_of_the_trees(tree_dir, tree_name_stem, full_tax_dict, query_dict, desired_fields, label_fields, graphic_dict) +too_tall_trees, overall_tree_number, colour_dict_dict, overall_df_dict, tree_order = tree_viz.make_all_of_the_trees(tree_dir, tree_name_stem, full_tax_dict, query_dict, desired_fields, label_fields, graphic_dict) ``` ```python, name="make_legend", echo=False, include=False, results='tex' for trait, colour_dict in colour_dict_dict.items(): @@ -216,6 +236,18 @@ for trait, colour_dict in colour_dict_dict.items(): if number_of_options > 15: print("WARNING: There are more than 15 options to colour by for " + trait + ", which will make it difficult to see the differences between colours. Consider adding the trait to the taxon labels on the tree by using the flag _--label-fields_ when calling CIVET.") ``` +```python, name="time_plot", echo=False, results='raw', include=False +if dates_present: + count = 0 + tree_to_time_series = {} + for tree in tree_order: + count += 1 + lookup = f"{tree_name_stem}_{tree}" + tips = tree_to_tip[lookup] + if len(tips) > 1: + time_func.plot_time_series(tips, query_dict, max_overall_date, min_overall_date, date_fields, label_fields) + tree_to_time_series[lookup] = count +``` ```python, name="show_trees", echo=False, results='raw' for i in range(1,overall_tree_number+1): @@ -234,6 +266,10 @@ for i in range(1,overall_tree_number+1): print(f'drawing') print("\n") + if dates_present and lookup in tree_to_time_series.keys(): + print("![](" + figdir + "/" + name_stem + "_time_plot_" + str(tree_to_time_series[lookup]) + ".png)") + + print(f"![]({figdir}/genome_graph_{lookup}.png)") ``` @@ -243,9 +279,6 @@ if too_tall_trees != []: print("Tree" + str(tree) + " is too large to be rendered here.") ``` - - - ```python, name="tree_background", echo=False, include=False, results='raw' if add_boxplots != "": print("""### Tree background\n\nThe following plots describe the data in the collapsed nodes in more detail.\nIf more than one country was present, the bar chart describes the number of sequences present in each country. \nWhere there were 10 countries or more, the largest 10 have been taken. 
\nIf a UK sequence is present in the collapsed node, it is always shown in the plot.\n\n""") @@ -256,16 +289,10 @@ if add_boxplots != "": ``` -### SNPs found in sequences of interest - -```python, name="genome graph", echo=False, results='raw' -print(f"![]({figdir}/genome_graph.png)") -``` - ```python, name="map_sequences", echo=False, results='raw', include=False if map_sequences != "False": print("## Plotting sequences") - adm2_in_map, adm2_percentages = mapping.map_traits(input_csv, input_crs, mapping_trait, x_col, y_col, mapping_json_files, urban_centres) + adm2_in_map, adm2_percentages = mapping.map_sequences_using_coordinates(input_csv, mapping_json_files, urban_centres, pc_file, mapping_trait, map_inputs, input_crs) print("There are sequences from " + str(len(adm2_in_map)) + " admin2 regions") @@ -283,9 +310,6 @@ if map_sequences != "False": print("![](" + figdir + "/" + name_stem + "_map_sequences_1.png)") ``` - - - ```python, name='Regional-scale', echo=False, results='raw' if local_lineages != '': print("## Regional-scale background UK lineage mapping") diff --git a/civet/scripts/data_parsing.py b/civet/scripts/data_parsing.py index c3d8af6d..97b459d4 100644 --- a/civet/scripts/data_parsing.py +++ b/civet/scripts/data_parsing.py @@ -12,11 +12,13 @@ class taxon(): - def __init__(self, name, global_lin, uk_lin, phylotype): + def __init__(self, name, global_lin, uk_lin, phylotype, label_fields, tree_fields): self.name = name self.sample_date = "NA" + + self.date_dict = {} if global_lin == "": self.global_lin = "NA" @@ -36,6 +38,12 @@ def __init__(self, name, global_lin, uk_lin, phylotype): self.in_cog = False self.attribute_dict = {} self.attribute_dict["adm1"] = "NA" + + for i in label_fields: + self.attribute_dict[i] = "NA" + for i in tree_fields: + self.attribute_dict[i] = "NA" + self.tree = "NA" self.closest_distance = "NA" @@ -62,7 +70,7 @@ def __init__(self, name, taxa): def analyse_inputs(inputs): - desired_fields, label_fields, graphic_dict, node_summary_option, map_sequences, x_col, y_col, mapping_trait = inputs + desired_fields, label_fields, graphic_dict, node_summary_option, map_sequences, mapping_trait, map_inputs = inputs print("Showing " + ",".join(desired_fields) + " on the tree.") print(",".join(list(graphic_dict.keys())) + " fields are displayed graphically using " + ",".join(list(graphic_dict.values())) + " colour schemes respectively.") @@ -73,10 +81,17 @@ def analyse_inputs(inputs): print("Summarising nodes by " + node_summary_option) if map_sequences != "False": - if mapping_trait != "False": - print("Mapping sequences using columns " + x_col + " " + y_col + " for x values and y values respectively, and colouring by " + mapping_trait) + map_args = map_inputs.split(",") + if len(map_args) == 2: + if mapping_trait != "False": + print("Mapping sequences using columns " + map_args[0] + " " + map_args[1] + " for x values and y values respectively, and colouring by " + mapping_trait) + else: + print("Mapping sequences using columns " + map_args[0] + " " + map_args[1] + " for x values and y values respectively.") else: - print("Mapping sequences using columns " + x_col + " " + y_col + " for x values and y values respectively.") + if mapping_trait != "False": + print("Mapping sequences using columns " + map_args[0] + " for outer postcodes, and colouring by " + mapping_trait) + else: + print("Mapping sequences using columns " + map_args[0] + " for outer postocdes.") @@ -107,7 +122,7 @@ def prepping_adm2_adm1_data(full_metadata): return adm2_adm1 -def 
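The one-versus-two-column convention for `--map-inputs` that `analyse_inputs` branches on above can be exercised in isolation. A small standalone sketch with hypothetical column names (not part of the patch):

```python
def describe_map_inputs(map_inputs, mapping_trait="False"):
    # Two comma separated columns -> x/y coordinate mode; one column -> outer postcode mode.
    map_args = map_inputs.split(",")
    if len(map_args) == 2:
        mode = "x/y columns " + map_args[0] + " and " + map_args[1]
    else:
        mode = "outer postcode column " + map_args[0]
    if mapping_trait != "False":
        mode += ", coloured by " + mapping_trait
    return mode

print(describe_map_inputs("longitude,latitude", "adm1"))  # coordinate mode
print(describe_map_inputs("outer_postcode"))              # postcode mode
```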
-def parse_filtered_metadata(metadata_file, tip_to_tree):
+def parse_filtered_metadata(metadata_file, tip_to_tree, label_fields, tree_fields):
 
     query_dict = {}
     query_id_dict = {}
@@ -137,7 +152,7 @@ def parse_filtered_metadata(metadata_file, tip_to_tree):
 
             sample_date = sequence["sample_date"]
 
-            new_taxon = taxon(query_name, glob_lin, uk_lineage, phylotype)
+            new_taxon = taxon(query_name, glob_lin, uk_lineage, phylotype, label_fields, tree_fields)
 
             new_taxon.query_id = query_id
@@ -170,7 +185,14 @@ def parse_filtered_metadata(metadata_file, tip_to_tree):
 
     return query_dict, query_id_dict, present_lins, tree_to_tip
 
-def parse_input_csv(input_csv, query_id_dict, desired_fields, label_fields, adm2_adm1_dict, cog_report):
+def convert_date(date_string):
+    bits = date_string.split("-")
+    date_dt = dt.date(int(bits[0]),int(bits[1]), int(bits[2]))
+
+    return date_dt
+
+
+def parse_input_csv(input_csv, query_id_dict, desired_fields, label_fields, date_fields, adm2_adm1_dict, cog_report):
 
     new_query_dict = {}
     contract_dict = {"SCT":"Scotland", "WLS": "Wales", "ENG":"England", "NIR": "Northern_Ireland"}
     cleaning = {"SCOTLAND":"Scotland", "WALES":"Wales", "ENGLAND":"England", "NORTHERN_IRELAND": "Northern_Ireland", "NORTHERN IRELAND": "Northern_Ireland"}
@@ -190,34 +212,45 @@ def parse_input_csv(input_csv, query_id_dict, desired_fields, label_fields, adm2
 
         if name in query_id_dict.keys():
             taxon = query_id_dict[name]
 
+            for field in date_fields:
+                if field in reader.fieldnames:
+                    if sequence[field] != "":
+                        date_dt = convert_date(sequence[field])
+                        taxon.date_dict[field] = date_dt
+
+            #keep this separate to above, because sample date is specifically needed
             if "sample_date" in col_names: #if it's not in COG but date is provided (if it's in COG, it will already have been assigned a sample date.)
                 if sequence["sample_date"] != "":
                     taxon.sample_date = sequence["sample_date"]
 
             for col in col_names: #Add other metadata fields provided
-                if col != "name" and (col in desired_fields or col in label_fields) and col != "adm1":
+                if col in label_fields:
                     if sequence[col] == "":
-                        taxon.attribute_dict[col] = "NA"
+                        taxon.attribute_dict[col] = "NA"
                     else:
                         taxon.attribute_dict[col] = sequence[col]
-
-                if col == "adm1":
-                    if "UK" in sequence[col]:
-                        adm1_prep = sequence[col].split("-")[1]
-                        adm1 = contract_dict[adm1_prep]
-                    else:
-                        if sequence[col].upper() in cleaning.keys():
-                            adm1 = cleaning[sequence[col].upper()]
+                else:
+                    if col != "name" and col in desired_fields and col != "adm1":
+                        if sequence[col] != "":
+                            taxon.attribute_dict[col] = sequence[col]
+
+                    if col == "adm1":
+                        if "UK" in sequence[col]:
+                            adm1_prep = sequence[col].split("-")[1]
+                            adm1 = contract_dict[adm1_prep]
                         else:
-                            adm1 = sequence[col]
+                            if sequence[col].upper() in cleaning.keys():
+                                adm1 = cleaning[sequence[col].upper()]
+                            else:
+                                adm1 = sequence[col]
 
-                    taxon.attribute_dict["adm1"] = adm1
-
-                if col == "adm2" and "adm1" not in col_names: #or sequence["adm1"] == ""):
-                    if sequence[col] in adm2_adm1_dict.keys():
-                        adm1 = adm2_adm1_dict[sequence[col]]
                         taxon.attribute_dict["adm1"] = adm1
 
+                    if col == "adm2" and "adm1" not in col_names: #or sequence["adm1"] == ""):
+                        if sequence[col] in adm2_adm1_dict.keys():
+                            adm1 = adm2_adm1_dict[sequence[col]]
+                            taxon.attribute_dict["adm1"] = adm1
+
             if cog_report:
                 taxon.attribute_dict["adm2"]= sequence["adm2"]
                 if "collection_date" in reader.fieldnames:
@@ -232,9 +265,6 @@ def parse_input_csv(input_csv, query_id_dict, desired_fields, label_fields, adm2
 
             new_query_dict[taxon.name] = taxon
 
-
-    # else:
-    #     print(name + " is in the input file but not the processed file. This suggests that it is not in COG and a sequence has also not been provided.")
 
     return new_query_dict
@@ -243,11 +273,13 @@ def parse_tree_tips(tree_dir):
 
     tips = []
     tip_to_tree = {}
+    tree_list = []
 
     for fn in os.listdir(tree_dir):
         if fn.endswith("tree"):
             tree_name = fn.split(".")[0]
             tree = bt.loadNewick(tree_dir + "/" + fn, absoluteTime=False)
+            tree_list.append(tree_name)
             for k in tree.Objects:
                 if k.branchType == 'leaf' and "inserted" not in k.name:
                     tips.append(k.name)
@@ -260,12 +292,17 @@ def parse_tree_tips(tree_dir):
                 tip_list = tip_string.split(",")
                 tips.extend(tip_list)
 
-    return tips, tip_to_tree
+    return tips, tip_to_tree, tree_list
 
-def parse_full_metadata(query_dict, full_metadata, present_lins, present_in_tree, node_summary_option):
+def parse_full_metadata(query_dict, label_fields, tree_fields,full_metadata, present_lins, present_in_tree, node_summary_option, date_fields):
 
     full_tax_dict = query_dict.copy()
 
+    with open(full_metadata, 'r') as f:
+        reader = csv.DictReader(f)
+        col_name_prep = next(reader)
+        col_names = list(col_name_prep.keys())
+
     with open(full_metadata, 'r') as f:
         reader = csv.DictReader(f)
         in_data = [r for r in reader]
@@ -284,7 +321,7 @@ def parse_full_metadata(query_dict, full_metadata, present_lins, present_in_tree
 
             if (uk_lin in present_lins or seq_name in present_in_tree) and seq_name not in query_dict.keys():
-                new_taxon = taxon(seq_name, glob_lin, uk_lin, phylotype)
+                new_taxon = taxon(seq_name, glob_lin, uk_lin, phylotype, label_fields, tree_fields)
 
                 if date == "":
                     date = "NA"
@@ -304,9 +341,26 @@ def parse_full_metadata(query_dict, full_metadata, present_lins, present_in_tree
                 tax_object = query_dict[seq_name]
                 if tax_object.sample_date == "NA" and date != "":
                     tax_object.sample_date = date
+                    tax_object.all_dates.append(convert_date(date))
 
                 if "adm2" not in tax_object.attribute_dict.keys() and adm2 != "":
                     tax_object.attribute_dict["adm2"] = adm2
-
+
+                for field in date_fields:
+                    if field in reader.fieldnames:
+                        if sequence[field] != "" and field not in tax_object.date_dict.keys():
+                            date_dt = convert_date(sequence[field])
+                            tax_object.date_dict[field] = date_dt
+
+                for field in label_fields:
+                    if field in col_names:
+                        if tax_object.attribute_dict[field] == "NA" and sequence[field] != "":
+                            tax_object.attribute_dict[field] = sequence[field]
+
+                for field in tree_fields:
+                    if field in col_names:
+                        if tax_object.attribute_dict[field] == "NA" and sequence[field] != "":
+                            tax_object.attribute_dict[field] = sequence[field]
+
                 full_tax_dict[seq_name] = tax_object
 
     return full_tax_dict
@@ -364,7 +418,7 @@ def make_initial_table(query_dict, desired_fields, label_fields, cog_report):
 
         if label_fields != []:
             for i in label_fields:
-                if i not in desired_fields:
+                if i not in desired_fields and i != "sample_date" and i != "name":
                     df_dict[i].append(query.attribute_dict[i])
 
         if cog_report:
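To make the date handling above concrete: `convert_date` expects ISO-style `YYYY-MM-DD` strings, and each taxon accumulates one `datetime.date` per date column in its `date_dict`. A self-contained sketch (field names are illustrative; civet validates them against the query or full metadata headers):

```python
import datetime as dt

def convert_date(date_string):
    # Same logic as data_parsing.convert_date.
    bits = date_string.split("-")
    return dt.date(int(bits[0]), int(bits[1]), int(bits[2]))

row = {"sample_date": "2020-06-05", "date_2": "2020-06-10", "date_3": ""}
date_dict = {field: convert_date(value) for field, value in row.items() if value != ""}

print(min(date_dict.values()))  # 2020-06-05
print(max(date_dict.values()))  # 2020-06-10
```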
fw.write("name\tnum_snps\tambiguous_snps\n") snp_dict = collections.defaultdict(list) - with open(args.report, newline="") as f: + with open(report, newline="") as f: reader = csv.DictReader(f, delimiter="\t") for row in reader: snps = row["snps"].split(";") @@ -59,4 +57,6 @@ def find_snps(): if __name__ == '__main__': - find_snps() \ No newline at end of file + args = parse_args() + + find_snps(args.input, args.output, args.report) \ No newline at end of file diff --git a/civet/scripts/find_snps.py b/civet/scripts/find_snps.py index faef69e7..9a5a5c2c 100644 --- a/civet/scripts/find_snps.py +++ b/civet/scripts/find_snps.py @@ -30,8 +30,9 @@ def find_snps(): non_amb = ["A","T","G","C"] snp_dict = collections.defaultdict(list) + with open(args.output, "w") as fw: - + fw.write("name\ttree\tnum_snps\tsnps\n") for query_seq in input_seqs: snps =[] diff --git a/civet/scripts/find_snps.smk b/civet/scripts/find_snps.smk index 80e5ab0d..48a33fb6 100644 --- a/civet/scripts/find_snps.smk +++ b/civet/scripts/find_snps.smk @@ -3,13 +3,13 @@ from Bio import SeqIO import csv import collections + config["tree_stems"] = config["catchment_str"].split(",") rule all: input: - os.path.join(config["outdir"],"snp_reports","snp_reports.txt"), - os.path.join(config["outdir"],"figures","genome_graph.png") - + expand(os.path.join(config["outdir"],"snp_reports","{tree}.snps.txt"),tree=config["tree_stems"]), + expand(os.path.join(config["outdir"],"figures","genome_graph_{tree}.png"), tree=config["tree_stems"]) rule extract_taxa: input: @@ -53,44 +53,46 @@ rule assess_snps: params: tree = "{tree}" output: - snp_report = os.path.join(config["tempdir"], "snp_reports","{tree}.snps.txt") + snp_report = os.path.join(config["outdir"], "snp_reports","{tree}.snps.txt") shell: """ find_snps.py --input {input.aln:q} --output {output.snp_report:q} --tree {params.tree} """ -rule gather_snp_reports: - input: - snps = expand(os.path.join(config["tempdir"],"snp_reports","{tree}.snps.txt"), tree=config["tree_stems"]) - output: - report = os.path.join(config["outdir"],"snp_reports","snp_reports.txt") - run: - with open(output.report, "w") as fw: - fw.write("name\ttree\tnum_snps\tsnps\n") - for report in input.snps: - fn = os.path.basename(report) - with open(report, "r") as f: - for l in f: - l = l.rstrip("\n") - fw.write(l + '\n') +# rule gather_snp_reports: +# input: +# snps = expand(os.path.join(config["tempdir"],"snp_reports","{tree}.snps.txt"), tree=config["tree_stems"]) +# output: +# report = os.path.join(config["outdir"],"snp_reports","snp_reports.txt") +# run: +# with open(output.report, "w") as fw: +# #fw.write("name\ttree\tnum_snps\tsnps\n") +# for report in input.snps: +# fn = os.path.basename(report) +# with open(report, "r") as f: +# for l in f: +# l = l.rstrip("\n") +# fw.write(l + '\n') -rule gather_snp_seqs: - input: - expand(os.path.join(config["tempdir"], "seqs_for_snps","{tree}.fasta"), tree=config["tree_stems"]) - output: - seqs = os.path.join(config["tempdir"], "seqs_for_snps","for_ambiguities.fasta") - run: - with open(output.seqs,"w") as fw: - for seq_file in input: - for record in SeqIO.parse(seq_file,"fasta"): - fw.write(f">{record.description}\n{record.seq}\n") +# rule gather_snp_seqs: +# input: +# expand(os.path.join(config["tempdir"], "seqs_for_snps","{tree}.fasta"), tree=config["tree_stems"]) +# output: +# seqs = os.path.join(config["tempdir"], "seqs_for_snps","for_ambiguities.fasta") +# run: +# with open(output.seqs,"w") as fw: +# for seq_file in input: +# for record in SeqIO.parse(seq_file,"fasta"): +# 
fw.write(f">{record.description}\n{record.seq}\n") rule ambiguities_at_snp_sites: input: - seqs = rules.gather_snp_seqs.output.seqs, - report = rules.gather_snp_reports.output.report + #seqs = rules.gather_snp_seqs.output.seqs, + #report = rules.gather_snp_reports.output.report + seqs = os.path.join(config["tempdir"], "seqs_for_snps", "{tree}.fasta"), + report = os.path.join(config["outdir"],"snp_reports", "{tree}.snps.txt") output: - snp_report = os.path.join(config["outdir"], "snp_reports","ambiguities.snps.txt") + snp_report = os.path.join(config["outdir"], "snp_reports", "ambiguities_{tree}.snps.txt") shell: """ find_ambiguities.py --input {input.seqs:q} --output {output.snp_report:q} --report {input.report:q} @@ -99,11 +101,13 @@ rule ambiguities_at_snp_sites: rule make_snp_figure: input: - rules.gather_snp_reports.output.report, - rules.ambiguities_at_snp_sites.output.snp_report + #rules.gather_snp_reports.output.report, #change this one to snps input in gather_snp_reports + #snps = expand(os.path.join(config["tempdir"],"snp_reports","{tree}.snps.txt"), tree=config["tree_stems"]), + ambs = os.path.join(config["outdir"], "snp_reports", "ambiguities_{tree}.snps.txt"), + snps = os.path.join(config["outdir"],"snp_reports","{tree}.snps.txt") output: - os.path.join(config["outdir"],"figures","genome_graph.png") + os.path.join(config["outdir"],"figures","genome_graph_{tree}.png") shell: """ - make_genome_graph.py --input {input[0]} --ambiguities {input[1]} --output {output[0]} + make_genome_graph.py --input {input.snps:q} --ambiguities {input.ambs:q} --output {output[0]} """ \ No newline at end of file diff --git a/civet/scripts/make_genome_graph.py b/civet/scripts/make_genome_graph.py index 3c961e75..841a8043 100644 --- a/civet/scripts/make_genome_graph.py +++ b/civet/scripts/make_genome_graph.py @@ -95,8 +95,6 @@ def make_graph(): # snp position labels - - for sequence in snp_dict[snp]: # sequence variant text name,ref,var,y_pos = sequence diff --git a/civet/scripts/make_report.py b/civet/scripts/make_report.py index 4accc91d..381f9c29 100644 --- a/civet/scripts/make_report.py +++ b/civet/scripts/make_report.py @@ -5,9 +5,10 @@ import shutil import sys + thisdir = os.path.abspath(os.path.dirname(__file__)) -def make_report(cog_metadata, input_csv, filtered_cog_metadata, outfile, outdir, treedir, figdir, snp_report, colour_fields, label_fields, node_summary, report_template, failed_seqs, seq_centre, clean_locs, uk_map, channels_map, ni_map, local_lineages, local_lin_maps, local_lin_tables,map_sequences,x_col,y_col, input_crs,mapping_trait,urban_centres,add_boxplots, graphic_dict): +def make_report(cog_metadata, input_csv, filtered_cog_metadata, outfile, outdir, treedir, figdir, snp_report, colour_fields, label_fields, node_summary, report_template, failed_seqs, seq_centre, clean_locs, uk_map, channels_map, ni_map, pc_file, local_lineages, local_lin_maps, local_lin_tables,map_sequences,map_inputs, input_crs,mapping_trait,urban_centres,add_boxplots, graphic_dict, date_fields): name_stem = ".".join(outfile.split(".")[:-1]) @@ -43,6 +44,7 @@ def make_report(cog_metadata, input_csv, filtered_cog_metadata, outfile, outdir, "input_directory": f'input_directory = "{treedir}"\n', "desired_fields_input": f'desired_fields_input = "{colour_fields}"\n', "label_fields_input": f'label_fields_input = "{label_fields}"\n', + "date_fields_input":f'date_fields_input = "{date_fields}"\n', "node_summary_option": f'node_summary_option = "{node_summary}"\n', "figdir": f'figdir = "{figdir}"\n', "tree_dir": 
f'tree_dir = "{treedir}"\n', @@ -55,12 +57,12 @@ def make_report(cog_metadata, input_csv, filtered_cog_metadata, outfile, outdir, "uk_map": f'uk_map = "{uk_map}"\n', "channels_map": f'channels_map = "{channels_map}"\n', "ni_map": f'ni_map = "{ni_map}"\n', + "pc_file": f'pc_file = "{pc_file}"\n', "local_lineages":f'local_lineages = "{local_lineages}"\n', "local_lin_maps" : f'local_lin_maps = "{local_lin_maps}"\n', "local_lin_tables" : f'local_lin_tables = "{local_lin_tables}"\n', "map_sequences":f'map_sequences = "{map_sequences}"\n', - "x_col":f'x_col = "{x_col}"\n', - "y_col":f'y_col = "{y_col}"\n', + "map_inputs":f'map_inputs = "{map_inputs}"\n', "mapping_trait":f'mapping_trait = "{mapping_trait}"\n', "input_crs":f'input_crs = "{input_crs}"\n', "urban_centres":f'urban_centres = "{urban_centres}"\n', @@ -89,6 +91,7 @@ def main(): parser.add_argument("-f", "--fields",default="", help="desired fields to colour trees by in report. Default=UK country",dest="colour_fields") parser.add_argument("-l", "--label-fields", default="", help="fields to add into labels in report trees. Default is adm2 and date", dest='label_fields') parser.add_argument("-gd", "--graphic_dict", default="", help="fields to colour by rather than display text. Add colour scheme optionally", dest="graphic_dict") + parser.add_argument("--date-fields", help="column headers containing date information as a a comma separated string.", dest="date_fields") parser.add_argument("--node-summary", action="store", help="field to summarise collapsed nodes by. Default=lineage", dest="node_summary") parser.add_argument("-sc", "--sequencing-centre",default="", help="Sequencing centre", dest="sc") @@ -108,12 +111,12 @@ def main(): parser.add_argument("--uk-map", required=True, help="shape file for uk counties", dest="uk_map") parser.add_argument("--channels-map", required=True, help="shape file for channel islands", dest="channels_map") parser.add_argument("--ni-map", required=True, help="shape file for northern irish counties", dest="ni_map") + parser.add_argument("--pc-file", required=True, help="file containing outer postcode to centroid mapping", dest="pc_file") parser.add_argument("--snp-report", required=True, help="snp report", dest="snp_report") parser.add_argument("--add-boxplots", action="store_true",dest="add_boxplots",default=False) parser.add_argument("--map-sequences", required=True, help="Bool for whether mapping of sequences by trait is required", dest="map_sequences") - parser.add_argument("--x-col", default="", help="column name in input csv which contains x coords for mapping", dest="x_col") - parser.add_argument("--y-col", default="", help="column name in input csv which contains y coords for mapping", dest="y_col") + parser.add_argument("--map-cols", default="", help="either column names in input csv which contains x coords and y coords for mapping as a comma separated string OR column name containing outer postcode", dest="map_inputs") parser.add_argument("--input-crs", default="", help="coordinate reference system that x and y inputs are in", dest="input_crs") parser.add_argument("--mapping-trait", default="", help="trait to map sequences by", dest="mapping_trait") parser.add_argument("--urban-centres", default="", help="geojson for plotting urban centres", dest="urban_centres") @@ -124,7 +127,7 @@ def main(): args = parser.parse_args() - make_report(args.cog_metadata, args.input_csv, args.filtered_cog_metadata, args.outfile, args.outdir, args.treedir, args.figdir, args.snp_report, args.colour_fields, 
args.label_fields, args.node_summary, args.report_template, args.failed_seqs, args.sc, args.clean_locs, args.uk_map, args.channels_map, args.ni_map, args.local_lineages, args.local_lin_maps, args.local_lin_tables,args.map_sequences, args.x_col, args.y_col, args.input_crs, args.mapping_trait, args.urban_centres,args.add_boxplots, args.graphic_dict) + make_report(args.cog_metadata, args.input_csv, args.filtered_cog_metadata, args.outfile, args.outdir, args.treedir, args.figdir, args.snp_report, args.colour_fields, args.label_fields, args.node_summary, args.report_template, args.failed_seqs, args.sc, args.clean_locs, args.uk_map, args.channels_map, args.ni_map, args.pc_file, args.local_lineages, args.local_lin_maps, args.local_lin_tables,args.map_sequences, args.map_inputs, args.input_crs, args.mapping_trait, args.urban_centres,args.add_boxplots, args.graphic_dict, args.date_fields) if __name__ == "__main__": diff --git a/civet/scripts/make_tree_figures.py b/civet/scripts/make_tree_figures.py index 7abc52bb..7b39be31 100644 --- a/civet/scripts/make_tree_figures.py +++ b/civet/scripts/make_tree_figures.py @@ -69,10 +69,11 @@ def display_name(tree, tree_name, tree_dir, full_taxon_dict, query_dict, custom_ for k in tree.Objects: if k.branchType == 'leaf': name = k.name - + display_name = "" if "inserted" in name: - collapsed_node_info = summarise_collapsed_node_for_label(tree_dir, name, tree_name, full_taxon_dict) + collapsed_node_info, number_nodes = summarise_collapsed_node_for_label(tree_dir, name, tree_name, full_taxon_dict) k.traits["display"] = collapsed_node_info + k.node_number = number_nodes else: if name in full_taxon_dict: taxon_obj = full_taxon_dict[name] @@ -84,19 +85,30 @@ def display_name(tree, tree_name, tree_dir, full_taxon_dict, query_dict, custom_ adm2 = taxon_obj.attribute_dict["adm2"] k.traits["display"] = f"{name}|{adm2}|{date}" - if name in query_dict.keys(): - if len(custom_tip_fields) > 0: + count = 0 + if len(custom_tip_fields) > 0: + if name in query_dict.keys(): for label_element in custom_tip_fields: - k.traits["display"] = k.traits["display"] + "|" + taxon_obj.attribute_dict[label_element] + if count == 0: + display_name = taxon_obj.attribute_dict[label_element] + else: + display_name = display_name + "|" + taxon_obj.attribute_dict[label_element] + count += 1 + + k.traits["display"] = display_name + + k.node_number = 1 else: if name.startswith("subtree"): number = name.split("_")[-1] - display = f"Tree {number}" - k.traits["display"] = display + display_name = f"Tree {number}" + k.traits["display"] = display_name + k.node_number = 1 else: k.traits["display"] = name + "|" + "not in dict" + k.node_number = 1 def find_colour_dict(query_dict, trait, colour_scheme): @@ -165,11 +177,10 @@ def make_scaled_tree(My_Tree, tree_name, tree_dir, num_tips, colour_dict_dict, d space_offset = tallest_height/10 absolute_x_axis_size = tallest_height+space_offset+space_offset + tallest_height #changed from /3 - tipsize = 40 c_func=lambda k: 'dimgrey' ## colour of branches l_func=lambda k: 'lightgrey' ## colour of dotted lines - s_func = lambda k: tipsize*5 if k.name in query_dict.keys() else tipsize + s_func = lambda k: tipsize*5 if k.name in query_dict.keys() else (0 if k.node_number > 1 else tipsize) z_func=lambda k: 100 b_func=lambda k: 2.0 #branch width so_func=lambda k: tipsize*5 if k.name in query_dict.keys() else 0 @@ -189,7 +200,6 @@ def make_scaled_tree(My_Tree, tree_name, tree_dir, num_tips, colour_dict_dict, d trait = next(key_iterator) #so always have the first trait as 
the first colour dot first_trait = trait - colour_dict = colour_dict_dict[trait] cn_func = lambda k: colour_dict[query_dict[k.name].attribute_dict[trait]] if k.name in query_dict.keys() else 'dimgrey' co_func=lambda k: colour_dict[query_dict[k.name].attribute_dict[trait]] if k.name in query_dict.keys() else 'dimgrey' @@ -220,7 +230,6 @@ def make_scaled_tree(My_Tree, tree_name, tree_dir, num_tips, colour_dict_dict, d My_Tree.plotPoints(ax, x_attr=x_attr, colour_function=cn_func,y_attr=y_attr, size_function=s_func, outline_colour=outline_colour_func) My_Tree.plotPoints(ax, x_attr=x_attr, colour_function=co_func, y_attr=y_attr, size_function=so_func, outline_colour=outline_colour_func) - blob_dict = {} for k in My_Tree.Objects: @@ -230,6 +239,10 @@ def make_scaled_tree(My_Tree, tree_name, tree_dir, num_tips, colour_dict_dict, d x=x_attr(k) y=y_attr(k) + if k.node_number > 1: + new_dot_size = tipsize*(1+math.log(k.node_number)) + ax.scatter(x, y, s=new_dot_size, marker="s", zorder=3, color="dimgrey") + height = My_Tree.treeHeight+offset text_start = tallest_height+space_offset+space_offset @@ -348,14 +361,14 @@ def make_all_of_the_trees(input_dir, tree_name_stem, taxon_dict, query_dict, des overall_tree_count = 0 - lst = sort_trees_index(input_dir) + tree_order = sort_trees_index(input_dir) for trait, colour_scheme in graphic_dict.items(): colour_dict = find_colour_dict(query_dict, trait, colour_scheme) colour_dict_dict[trait] = colour_dict - for fn in lst: + for fn in tree_order: lineage = fn treename = f"{tree_name_stem}_{fn}" treefile = f"{tree_name_stem}_{fn}.tree" @@ -403,7 +416,7 @@ def make_all_of_the_trees(input_dir, tree_name_stem, taxon_dict, query_dict, des too_tall_trees.append(lineage) continue - return too_tall_trees, overall_tree_count, colour_dict_dict, overall_df_dict + return too_tall_trees, overall_tree_count, colour_dict_dict, overall_df_dict, tree_order def summarise_collapsed_node_for_label(tree_dir, focal_node, focal_tree, full_tax_dict): @@ -456,7 +469,7 @@ def summarise_collapsed_node_for_label(tree_dir, focal_node, focal_tree, full_ta info = pretty_node_name + ": " + number_nodes + " in " + pretty_summary - return info + return info, len(member_list) def summarise_node_table(tree_dir, focal_tree, full_tax_dict): @@ -738,9 +751,6 @@ def describe_tree_background(full_tax_dict, tree_name_stem, tree_dir): plt.xticks(size=5, rotation=90) plt.yticks(size=5) - - - plt.title(pretty_focal + ": " + nde, size=5) diff --git a/civet/scripts/mapping.py b/civet/scripts/mapping.py index 71f2ab28..99a9d2e8 100644 --- a/civet/scripts/mapping.py +++ b/civet/scripts/mapping.py @@ -161,7 +161,7 @@ def make_map(centroid_geo, all_uk): ax.axis("off") -def run_map_functions(tax_dict, clean_locs_file, mapping_json_files): #So this takes adm2s and plots them onto the whole UK +def map_adm2(tax_dict, clean_locs_file, mapping_json_files): #So this takes adm2s and plots them onto the whole UK adm2s, metadata_multi_loc, straight_map = prep_data(tax_dict, clean_locs_file) @@ -175,12 +175,7 @@ def run_map_functions(tax_dict, clean_locs_file, mapping_json_files): #So this t make_map(centroid_geo, all_uk) -def map_traits(input_csv, input_crs, colour_map_trait, x_col, y_col, mapping_json_files, urban_centres): - - all_uk = generate_all_uk_dataframe(mapping_json_files) - all_uk = all_uk.to_crs("EPSG:3395") - - urban = geopandas.read_file(urban_centres) +def get_coords_from_file(input_csv, input_crs, colour_map_trait, x_col, y_col): ##READ IN TRAITS## @@ -200,15 +195,67 @@ def map_traits(input_csv, 
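Collapsed nodes are now drawn as a single square marker whose size grows with the log of the number of collapsed tips, so large collapsed clades stay legible next to single tips. The scaling used in `make_scaled_tree` (with `tipsize = 40`, as in the script) behaves like this:

```python
import math

tipsize = 40
for node_number in [1, 2, 10, 100, 1000]:
    # Mirrors: new_dot_size = tipsize*(1+math.log(k.node_number))
    print(node_number, "tips -> marker size", round(tipsize * (1 + math.log(node_number))))
```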
diff --git a/civet/scripts/mapping.py b/civet/scripts/mapping.py
index 71f2ab28..99a9d2e8 100644
--- a/civet/scripts/mapping.py
+++ b/civet/scripts/mapping.py
@@ -161,7 +161,7 @@ def make_map(centroid_geo, all_uk):
     ax.axis("off")
 
-def run_map_functions(tax_dict, clean_locs_file, mapping_json_files): #So this takes adm2s and plots them onto the whole UK
+def map_adm2(tax_dict, clean_locs_file, mapping_json_files): #So this takes adm2s and plots them onto the whole UK
 
     adm2s, metadata_multi_loc, straight_map = prep_data(tax_dict, clean_locs_file)
@@ -175,12 +175,7 @@ def run_map_functions(tax_dict, clean_locs_file, mapping_json_files): #So this t
 
     make_map(centroid_geo, all_uk)
 
-def map_traits(input_csv, input_crs, colour_map_trait, x_col, y_col, mapping_json_files, urban_centres):
-
-    all_uk = generate_all_uk_dataframe(mapping_json_files)
-    all_uk = all_uk.to_crs("EPSG:3395")
-
-    urban = geopandas.read_file(urban_centres)
+def get_coords_from_file(input_csv, input_crs, colour_map_trait, x_col, y_col):
 
     ##READ IN TRAITS##
@@ -200,15 +195,67 @@ def map_traits(input_csv, input_crs, colour_map_trait, x_col, y_col, mapping_jso
         trait = seq[colour_map_trait]
 
         if x != "" and y != "":
-
-            #name_to_coords[name] = (((float(x)/200)-2.2,(float(y)/200)+55)) #just for now, will change to just x and y in a bit
-            name_to_ccords[name] = (float(x),float(y))
+            #If we have the actual coordinates
+            name_to_coords[name] = (float(x),float(y))
 
             if colour_map_trait != "False":
                 name_to_trait[name] = trait
 
+    return name_to_coords, name_to_trait
+
+def generate_coords_from_outer_postcode(pc_file, input_csv, postcode_col, colour_map_trait):
+
+    pc_to_coords = {}
+    name_to_coords = {}
+    name_to_trait = {}
+
+    with open(pc_file) as f:
+        reader = csv.DictReader(f)
+        data = [r for r in reader]
+
+    for line in data:
+
+        pc = line["outcode"]
+        x = float(line["longitude"])
+        y = float(line["latitude"])
+
+        pc_to_coords[pc] = ((x,y))
+
+
+    with open(input_csv) as f:
+        reader = csv.DictReader(f)
+        data = [r for r in reader]
+
+    for seq in data:
+        name = seq["name"]
+        outer_postcode = seq[postcode_col]
+
+        if colour_map_trait != "False":
+            trait = seq[colour_map_trait]
+
+        if outer_postcode != "":
+            if outer_postcode in pc_to_coords.keys():
+                name_to_coords[name] = pc_to_coords[outer_postcode]
+
+                if colour_map_trait != "False":
+                    name_to_trait[name] = trait
+
+            else:
+                pass
+
+
+    return name_to_coords, name_to_trait
+
+
+def plot_coordinates(mapping_json_files, urban_centres, name_to_coords, name_to_trait, input_crs, colour_map_trait):
+
     ##MAKE DATAFRAME##
 
+    all_uk = generate_all_uk_dataframe(mapping_json_files)
+    all_uk = all_uk.to_crs("EPSG:3395")
+
+    urban = geopandas.read_file(urban_centres)
+
     df_dict = defaultdict(list)
 
     for name, point in name_to_coords.items():
@@ -217,7 +264,7 @@ def map_traits(input_csv, input_crs, colour_map_trait, x_col, y_col, mapping_jso
         df_dict[colour_map_trait].append(name_to_trait[name])
 
     crs = {'init':input_crs}
-
+
     df = geopandas.GeoDataFrame(df_dict, crs=crs)
     df_final = df.to_crs(all_uk.crs)
@@ -290,11 +337,19 @@ def map_traits(input_csv, input_crs, colour_map_trait, x_col, y_col, mapping_jso
 
     ax.axis("off")
 
     return adm2_counter, adm2_percentages
-
-
+def map_sequences_using_coordinates(input_csv, mapping_json_files, urban_centres, pc_file,colour_map_trait, map_inputs, input_crs):
 
-
+    cols = map_inputs.split(",")
+    if len(cols) == 2:
+        x_col = cols[0]
+        y_col = cols[1]
+        name_to_coords, name_to_trait = get_coords_from_file(input_csv, input_crs, colour_map_trait, x_col, y_col)
+    elif len(cols) == 1:
+        postcode_col = cols[0]
+        name_to_coords, name_to_trait = generate_coords_from_outer_postcode(pc_file, input_csv, postcode_col, colour_map_trait)
+
+    adm2_counter, adm2_percentages = plot_coordinates(mapping_json_files, urban_centres, name_to_coords, name_to_trait, input_crs, colour_map_trait)
+
+    return adm2_counter, adm2_percentages
\ No newline at end of file
diff --git a/civet/scripts/time_functions.py b/civet/scripts/time_functions.py
new file mode 100644
index 00000000..4915da6d
--- /dev/null
+++ b/civet/scripts/time_functions.py
@@ -0,0 +1,155 @@
+import datetime as dt
+import matplotlib.pyplot as plt
+from matplotlib import cm
+import numpy as np
+import matplotlib.ticker as plticker
+import math
+
+
+def summarise_dates(query_dict):
+
+    overall_dates = []
+
+    for tax in query_dict.values():
+        overall_dates.extend(list(tax.date_dict.values()))
+
+    if overall_dates != []:
+
+        min_overall_date = min(overall_dates)
+        max_overall_date = max(overall_dates)
+
+        min_string = min_overall_date.strftime("%Y-%m-%d")
+        max_string = max_overall_date.strftime("%Y-%m-%d")
+
+    else:
+        return False
+
+    return overall_dates, max_overall_date, min_overall_date, max_string, min_string
+
+def display_name(tax, custom_tip_fields):
+
+    name = tax.name
+
+    date = tax.sample_date
+    display_name = f"{name}|{date}"
+
+    if "adm2" in tax.attribute_dict.keys():
+        adm2 = tax.attribute_dict["adm2"]
+        display_name = f"{name}|{adm2}|{date}"
+
+    count = 0
+    if len(custom_tip_fields) > 0:
+        for label_element in custom_tip_fields:
+            if count == 0:
+                display_name = tax.attribute_dict[label_element]
+            else:
+                display_name = display_name + "|" + tax.attribute_dict[label_element]
+            count += 1
+
+    return display_name
+
+def find_colour_dict(date_fields):
+
+    colour_dict = {}
+    count = 0
+
+    cmap = cm.get_cmap("Paired")
+
+    colors = cmap(np.linspace(0, 1, len(date_fields)))
+
+    for option in sorted(date_fields):
+        colour_dict[option] = colors[count]
+        count += 1
+
+    return colour_dict
+
+def plot_time_series(tips, query_dict, overall_max_date, overall_min_date, date_fields, custom_tip_fields):
+
+    colour_dict = find_colour_dict(date_fields)
+
+    time_len = (overall_max_date - overall_min_date).days
+
+    height = math.sqrt(len(tips))*2 + 1
+
+    if time_len > 20:
+        tick_loc_base = float(math.ceil(time_len/5))
+    else:
+        tick_loc_base = 1.0
+
+    loc = plticker.MultipleLocator(base=tick_loc_base) #Sets a tick on each integer multiple of a base within the view interval
+
+    fig, ax1 = plt.subplots(1,1, figsize=(20,height))
+    ax2 = ax1.twinx()
+
+    if time_len > 10:
+        offset = dt.timedelta(time_len/10)
+    else:
+        offset = dt.timedelta(time_len/3)
+
+
+    count = 1
+
+    for tax in tips:
+        if tax.date_dict != {} and tax in query_dict.values():
+
+            first_date_type = min(tax.date_dict.keys(), key=lambda k: tax.date_dict[k])
+            last_date_type = max(tax.date_dict.keys(), key=lambda k: tax.date_dict[k])
+
+            first_date = tax.date_dict[first_date_type]
+            last_date = tax.date_dict[last_date_type]
+
+            other_dates = {}
+            for date_type, date in tax.date_dict.items():
+                if date != first_date and date != last_date:
+                    other_dates[date_type] = date
+
+            label = display_name(tax, custom_tip_fields)
+
+            x = [first_date, last_date]
+            y = [count, count]
+
+            ax1.scatter(first_date, count, color=colour_dict[first_date_type], s=200, zorder=2, label=first_date_type)
+            ax1.scatter(last_date, count, color=colour_dict[last_date_type], s=200, zorder=2, label=last_date_type)
+
+            for date_option, date in other_dates.items():
+                ax1.scatter(date, count, color=colour_dict[date_option], s=200, zorder=2, label=date_option)
+
+            if first_date != last_date:
+                ax1.plot(x,y, color="dimgrey",zorder=1)
+
+            # ax2.plot([last_date, overall_max_date+offset],y,ls='dotted',lw=1,color="dimgrey")
+            if last_date != overall_max_date:
+                ax1.plot([last_date, overall_max_date+offset],y,ls='--',lw=1,color="dimgrey", zorder=1)
+
+            ax2.text(overall_max_date+offset, count, label,size=15)
+
+            count += 1
+
+    ylim = ax1.get_ylim()
+    ax2.set_ylim(ylim)
+
+    ax1.spines['top'].set_visible(False) ## make axes invisible
+    ax1.spines['right'].set_visible(False)
+    ax1.spines['left'].set_visible(False)
+    ax1.spines['bottom'].set_visible(False)
+    ax1.set_yticks([])
+    ax2.spines['top'].set_visible(False) ## make axes invisible
+    ax2.spines['right'].set_visible(False)
+    ax2.spines['left'].set_visible(False)
+    ax2.set_yticks([])
+
+    ax1.tick_params(labelsize=20, rotation=90)
+    ax1.xaxis.set_major_locator(loc)
+
+    handles, labels = ax1.get_legend_handles_labels()
+    unique = [(h, l) for i, (h, l) in enumerate(zip(handles, labels)) if l not in labels[:i]]
+    ax1.legend(*zip(*unique))
+
+    fig.tight_layout()
+
+
+
+
+
+
diff --git a/civet/tests/test.csv b/civet/tests/test.csv
index 2e34d65f..3c883142 100644
--- a/civet/tests/test.csv
+++ b/civet/tests/test.csv
@@ -1,12 +1,12 @@
-name,HCW_status,care_home,sample_date,adm2,test,test2
-EDB129_closest,HCW,A,2020-06-05,Edinburgh,yes,yes
-EDB129_closestb,Patient,A,2020-02-30,Edinburgh,yes,yes
-20144000304_closest,Patient,B,2020-02-31,Edinburgh,no,yes
-This_seq_is_too_short,HCW,A,2020-02-32,Edinburgh,no,yes
-EDB3588,HCW,C,2020-02-33,Edinburgh,no,yes
-EDB2533,Patient,D,2020-02-34,Edinburgh,no,NA
-PHEC-1A65C,HCW,A,2020-02-35,Edinburgh,yes,no
-PHEC-1AD2A,NA,NA,2020-02-36,Edinburgh,no,yes
-PHEC-1A917,Patient,D,2020-02-37,Edinburgh,NA,NA
-This_seq_has_lots_of_Ns,HCW,A,2020-02-38,Edinburgh,NA,NA
-This_seq_is_literally_just_N,HCW,D,2020-02-39,Edinburgh,NA,NA
+name,HCW_status,care_home,sample_date,date_2,date_3,adm2,test,test2
+EDB129_closest,HCW,A,2020-06-05,2020-06-10,2020-06-17,Edinburgh,yes,yes
+EDB129_closestb,Patient,A,2020-03-30,2020-04-04,2020-04-11,Edinburgh,yes,yes
+20144000304_closest,Patient,B,2020-02-28,2020-03-04,2020-03-11,Edinburgh,no,yes
+This_seq_is_too_short,HCW,A,2020-02-27,2020-03-03,2020-03-10,Edinburgh,no,yes
+EDB3588,HCW,C,2020-04-20,2020-04-25,2020-05-02,Edinburgh,no,yes
+EDB2533,Patient,D,2020-04-29,2020-05-04,2020-05-11,Edinburgh,no,NA
+PHEC-1A65C,HCW,A,2020-05-04,2020-05-09,2020-05-16,Edinburgh,yes,no
+PHEC-1AD2A,NA,NA,2020-03-02,2020-03-07,2020-03-14,Edinburgh,no,yes
+PHEC-1A917,Patient,D,2020-02-27,2020-03-03,2020-03-10,Edinburgh,NA,NA
+This_seq_has_lots_of_Ns,HCW,A,2020-05-15,2020-05-20,2020-05-27,Edinburgh,NA,NA
+This_seq_is_literally_just_N,HCW,D,2020-04-30,2020-05-05,2020-05-12,Edinburgh,NA,NA
\ No newline at end of file
diff --git a/setup.py b/setup.py
index 68682320..4410d45e 100644
--- a/setup.py
+++ b/setup.py
@@ -34,7 +34,8 @@ def run(self):
             "civet/scripts/find_snps.smk",
             "civet/scripts/find_snps.py",
             "civet/scripts/find_ambiguities.py",
-            "civet/scripts/make_genome_graph.py"],
+            "civet/scripts/make_genome_graph.py",
+            "civet/scripts/time_functions.py"],
     package_data={"civet":["data/*","data/headers/*","data/mapping_files/*","data/vega_templates/*"]},
     install_requires=[
             "biopython>=1.70",
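A usage sketch for the new `time_functions.summarise_dates`: it returns `False` when no query carries any dates, otherwise the pooled dates plus their extremes. The stub objects below stand in for `data_parsing.taxon`, since only a populated `date_dict` is read here, and the import assumes `civet/scripts` is on `sys.path`:

```python
import datetime as dt
from types import SimpleNamespace

import time_functions as time_func  # assumes civet/scripts is importable

# Stand-ins for data_parsing.taxon: summarise_dates only reads date_dict.
query_dict = {
    "EDB129_closestb": SimpleNamespace(date_dict={"sample_date": dt.date(2020, 3, 30),
                                                  "date_2": dt.date(2020, 4, 4)}),
    "EDB129_closest": SimpleNamespace(date_dict={"sample_date": dt.date(2020, 6, 5)}),
}

time_outputs = time_func.summarise_dates(query_dict)
if type(time_outputs) != bool:
    overall_dates, max_date, min_date, max_string, min_string = time_outputs
    print(min_string, "to", max_string)  # 2020-03-30 to 2020-06-05
```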