From db6da2b5bb9c09d6f3c0fc6bbbbb62d863a1684e Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 27 Oct 2019 12:40:43 +0100 Subject: [PATCH 01/77] changed mtx output for features and small printing bug --- cite_seq_count/__main__.py | 13 +++++++------ cite_seq_count/io.py | 2 +- cite_seq_count/preprocessing.py | 14 +++++++++++--- cite_seq_count/processing.py | 4 ++-- 4 files changed, 21 insertions(+), 12 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 519c391..e949d52 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -247,7 +247,7 @@ def main(): # Load TAGs/ABs. ab_map = preprocessing.parse_tags_csv(args.tags) - ab_map = preprocessing.check_tags(ab_map, args.max_error) + ordered_tags_map = preprocessing.check_tags(ab_map, args.max_error) # Identify input file(s) read1_paths, read2_paths = preprocessing.get_read_paths(args.read1_path, args.read2_path) @@ -280,13 +280,14 @@ def main(): reads_per_cell = Counter() merged_no_match = Counter() number_of_samples = len(read1_paths) - n_reads = 0 + #Print a statement if multiple files are run. if number_of_samples != 1: print('Detected {} files to run on.'.format(number_of_samples)) for read1_path, read2_path in zip(read1_paths, read2_paths): + n_reads = 0 if args.first_n: n_lines = (args.first_n*4)/number_of_samples else: @@ -366,10 +367,10 @@ def main(): else: # Explicitly save the counter to that tag final_results[cell_barcode][tag] = _final_results[cell_barcode][tag] - ordered_tags_map = OrderedDict() - for i,tag in enumerate(ab_map.values()): - ordered_tags_map[tag] = i - ordered_tags_map['unmapped'] = i + 1 + # ordered_tags_map = OrderedDict() + # for i,tag in enumerate(ab_map.values()): + # ordered_tags_map[tag] = i + # ordered_tags_map['unmapped'] = i + 1 # Correct cell barcodes if(len(umis_per_cell) <= args.expected_cells): diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 2dc04f0..56fd90c 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -25,7 +25,7 @@ def write_to_files(sparse_matrix, top_cells, ordered_tags_map, data_type, outfol barcode_file.write('{}\n'.format(barcode).encode()) with gzip.open(os.path.join(prefix,'features.tsv.gz'), 'wb') as feature_file: for feature in ordered_tags_map: - feature_file.write('{}\n'.format(feature).encode()) + feature_file.write('{}\t{}\n'.format(ordered_tags_map[feature]['sequence'], feature).encode()) with open(os.path.join(prefix,'matrix.mtx'),'rb') as mtx_in: with gzip.open(os.path.join(prefix,'matrix.mtx') + '.gz','wb') as mtx_gz: shutil.copyfileobj(mtx_in, mtx_gz) diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 4ca9f8c..6225531 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -153,14 +153,21 @@ def check_tags(tags, maximum_distance): """ ordered_tags = OrderedDict() - for tag in sorted(tags, key=len, reverse=True): - ordered_tags[tag] = tags[tag] + '-' + tag + for i,tag_seq in enumerate(sorted(tags, key=len, reverse=True)): + ordered_tags[tags[tag_seq]] = {} + ordered_tags[tags[tag_seq]]['id'] = i + ordered_tags[tags[tag_seq]]['sequence'] = tag_seq + ordered_tags['unmapped'] = {} + ordered_tags['unmapped']['id'] = i + 1 + ordered_tags['unmapped']['sequence'] = 'UNKNOWN' # If only one TAG is provided, then no distances to compare. if (len(tags) == 1): + ordered_tags['unmapped'] = {} + ordered_tags['unmapped']['id'] = 2 return(ordered_tags) offending_pairs = [] - for a, b in combinations(ordered_tags.keys(), 2): + for a, b in combinations(tags.keys(), 2): distance = Levenshtein.distance(a, b) if (distance <= (maximum_distance - 1)): offending_pairs.append([a, b, distance]) @@ -183,6 +190,7 @@ def check_tags(tags, maximum_distance): .format(tag1=pair[0], tag2=pair[1], distance=pair[2]) ) sys.exit('Exiting the application.\n') + return(ordered_tags) diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index fe656ef..af48215 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -434,7 +434,7 @@ def generate_sparse_matrices(final_results, ordered_tags_map, top_cells): for i,cell_barcode in enumerate(top_cells): for j,TAG in enumerate(final_results[cell_barcode]): if final_results[cell_barcode][TAG]: - umi_results_matrix[ordered_tags_map[TAG],i] = len(final_results[cell_barcode][TAG]) - read_results_matrix[ordered_tags_map[TAG],i] = sum(final_results[cell_barcode][TAG].values()) + umi_results_matrix[ordered_tags_map[TAG]['id'],i] = len(final_results[cell_barcode][TAG]) + read_results_matrix[ordered_tags_map[TAG]['id'],i] = sum(final_results[cell_barcode][TAG].values()) return(umi_results_matrix, read_results_matrix) From 480ac207ecffced907c34995848bf1cc670edad5 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Wed, 30 Oct 2019 09:25:10 +0100 Subject: [PATCH 02/77] Changed some verbose output --- cite_seq_count/__main__.py | 55 ++++++++++++++++++-------------------- 1 file changed, 26 insertions(+), 29 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index e949d52..4c95d66 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -280,7 +280,7 @@ def main(): reads_per_cell = Counter() merged_no_match = Counter() number_of_samples = len(read1_paths) - + total_reads = 0 #Print a statement if multiple files are run. if number_of_samples != 1: @@ -293,6 +293,7 @@ def main(): else: n_lines = preprocessing.get_n_lines(read1_path) n_reads += int(n_lines/4) + total_reads += n_reads n_threads = args.n_threads print('Started mapping') print('Processing {:,} reads'.format(n_reads)) @@ -313,7 +314,6 @@ def main(): start_trim=args.start_trim, maximum_distance=args.max_error, sliding_window=args.sliding_window) - print('Mapping done') _umis_per_cell = Counter() _reads_per_cell = Counter() for cell_barcode, counts in _final_results.items(): @@ -344,7 +344,6 @@ def main(): error_callback=sys.stderr) p.close() p.join() - print('Mapping done') print('Merging results') ( @@ -355,6 +354,7 @@ def main(): ) = processing.merge_results(parallel_results=parallel_results) del(parallel_results) + print('Mapping done, merging the different lanes') # Update the overall counts dicts umis_per_cell.update(_umis_per_cell) reads_per_cell.update(_reads_per_cell) @@ -367,10 +367,6 @@ def main(): else: # Explicitly save the counter to that tag final_results[cell_barcode][tag] = _final_results[cell_barcode][tag] - # ordered_tags_map = OrderedDict() - # for i,tag in enumerate(ab_map.values()): - # ordered_tags_map[tag] = i - # ordered_tags_map['unmapped'] = i + 1 # Correct cell barcodes if(len(umis_per_cell) <= args.expected_cells): @@ -421,8 +417,8 @@ def main(): top_cells_tuple = umis_per_cell.most_common(args.expected_cells) top_cells = set([pair[0] for pair in top_cells_tuple]) + #UMI correction - if args.no_umi_correction: #Don't correct umis_corrected = 0 @@ -439,27 +435,28 @@ def main(): top_cells=top_cells, max_umis=20000) - #Remove aberrant cells from the top cells - for cell_barcode in aberrant_cells: - top_cells.remove(cell_barcode) + if len(aberrant_cells) > 0: + #Remove aberrant cells from the top cells + for cell_barcode in aberrant_cells: + top_cells.remove(cell_barcode) - #Create sparse aberrant cells matrix - ( - umi_aberrant_matrix, - read_aberrant_matrix - ) = processing.generate_sparse_matrices( - final_results=final_results, - ordered_tags_map=ordered_tags_map, - top_cells=aberrant_cells) - - #Write uncorrected cells to dense output - io.write_dense( - sparse_matrix=umi_aberrant_matrix, - index=list(ordered_tags_map.keys()), - columns=aberrant_cells, - outfolder=os.path.join(args.outfolder,'uncorrected_cells'), - filename='dense_umis.tsv') - + #Create sparse aberrant cells matrix + ( + umi_aberrant_matrix, + read_aberrant_matrix + ) = processing.generate_sparse_matrices( + final_results=final_results, + ordered_tags_map=ordered_tags_map, + top_cells=aberrant_cells) + + #Write uncorrected cells to dense output + io.write_dense( + sparse_matrix=umi_aberrant_matrix, + index=list(ordered_tags_map.keys()), + columns=aberrant_cells, + outfolder=os.path.join(args.outfolder,'uncorrected_cells'), + filename='dense_umis.tsv') + #Create sparse matrices for results ( umi_results_matrix, @@ -494,7 +491,7 @@ def main(): #Create report and write it to disk create_report( - n_reads=n_reads, + n_reads=total_reads, reads_per_cell=reads_per_cell, no_match=merged_no_match, version=version, From 2fb744b735d47fbcca757f0de21e478e425ee034 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Tue, 31 Mar 2020 08:44:11 +0200 Subject: [PATCH 03/77] deleted an enumerate --- .gitignore | 3 ++- cite_seq_count/processing.py | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.gitignore b/.gitignore index 132e6dd..a6f833c 100644 --- a/.gitignore +++ b/.gitignore @@ -2,4 +2,5 @@ dist/ build/ CITE_seq_Count.egg-info/ __pycache__ -*.pyc \ No newline at end of file +*.pyc +.vscode diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index af48215..f5636b3 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -394,7 +394,7 @@ def find_true_to_false_map(barcode_tree, cell_barcodes, whitelist, collapsing_th true_to_false (defaultdict(list)): Contains the mapping between the fake and real barcodes. The key is the real one. """ true_to_false = defaultdict(list) - for i, cell_barcode in enumerate(cell_barcodes): + for cell_barcode in cell_barcodes: if cell_barcode in whitelist: # if the barcode is already whitelisted, no need to add continue From 27b6fc66d99534820b6abfa8d6ee872f264d96d5 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sat, 4 Apr 2020 19:09:32 +0200 Subject: [PATCH 04/77] added named_tuple ref --- cite_seq_count/__main__.py | 9 +++++---- cite_seq_count/preprocessing.py | 12 ++++++++++++ 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 4c95d66..ca69aa3 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -14,6 +14,7 @@ from collections import OrderedDict from collections import Counter from collections import defaultdict +from collections import namedtuple from multiprocess import cpu_count from multiprocess import Pool @@ -248,7 +249,7 @@ def main(): # Load TAGs/ABs. ab_map = preprocessing.parse_tags_csv(args.tags) ordered_tags_map = preprocessing.check_tags(ab_map, args.max_error) - + named_tuples_tags_map = preprocessing.convert_to_named_tuple(ordered_tags=ordered_tags_map) # Identify input file(s) read1_paths, read2_paths = preprocessing.get_read_paths(args.read1_path, args.read2_path) @@ -388,7 +389,7 @@ def main(): umis_per_cell=umis_per_cell, expected_cells=args.expected_cells, collapsing_threshold=args.bc_threshold, - ab_map=ordered_tags_map) + ab_map=named_tuples_tags_map) else: ( final_results, @@ -398,7 +399,7 @@ def main(): umis_per_cell=umis_per_cell, whitelist=whitelist, collapsing_threshold=args.bc_threshold, - ab_map=ordered_tags_map) + ab_map=named_tuples_tags_map) # If given, use whitelist for top cells if whitelist: @@ -409,7 +410,7 @@ def main(): continue else: final_results[missing_cell] = dict() - for TAG in ordered_tags_map: + for TAG in named_tuples_tags_map: final_results[missing_cell][TAG] = Counter() top_cells.add(missing_cell) else: diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 6225531..e13b0be 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -6,6 +6,7 @@ from math import floor from collections import OrderedDict +from collections import namedtuple from itertools import combinations from itertools import islice @@ -193,6 +194,17 @@ def check_tags(tags, maximum_distance): return(ordered_tags) +def sanitize_name(string): + return(string.replace('-', '_')) + +def convert_to_named_tuple(ordered_tags): + #all_tags = namedtuple('all_tags', [sanitize_name(tag) for tag in ordered_tags.keys()]) + tag = namedtuple('tag', ['name','sequence']) + tag_list = [] + for index, tag_name in enumerate(ordered_tags): + tag_list.append(tag(name=tag_name, sequence=ordered_tags[tag_name]['sequence'])) + #all_tags[index+1]=ordered_tags[tag_name]['sequence'] + return(tag_list) def get_read_length(filename): """Check wether SEQUENCE lengths are consistent in a FASTQ file and return From 2c13410ba0d8448b17626a8cec8469cc711ae20c Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Wed, 15 Apr 2020 15:43:30 +0200 Subject: [PATCH 05/77] parallel umis --- cite_seq_count/__main__.py | 251 +++++++++++++++++++------------- cite_seq_count/preprocessing.py | 18 ++- cite_seq_count/processing.py | 69 ++++----- 3 files changed, 202 insertions(+), 136 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index ca69aa3..c564d30 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -8,7 +8,9 @@ import datetime import pkg_resources import logging +import gzip +from itertools import islice from argparse import ArgumentParser from argparse import RawTextHelpFormatter from collections import OrderedDict @@ -16,8 +18,9 @@ from collections import defaultdict from collections import namedtuple -from multiprocess import cpu_count -from multiprocess import Pool +from multiprocess import cpu_count, Pool, Queue, JoinableQueue, Process + + from cite_seq_count import preprocessing from cite_seq_count import processing @@ -127,13 +130,22 @@ def get_args(): help=("Allow for a sliding window when aligning.") ) - + # Parallel group. + parallel = parser.add_argument_group( + 'Parallelization options', + description=("Options for performance on parallelization") + ) # Remaining arguments. - parser.add_argument('-T', '--threads', required=False, type=int, + parallel.add_argument('-T', '--threads', required=False, type=int, dest='n_threads', default=cpu_count(), help="How many threads are to be used for running the program") + parallel.add_argument('-C', '--chunk_size', required=False, type=int, + dest='chunk_size', default=1000000, + help="How many reads shuold be sent to a child process at a time") + + parser.add_argument('-n', '--first_n', required=False, type=int, - dest='first_n', default=None, + dest='first_n', default=float('inf'), help="Select N reads to run on instead of all.") parser.add_argument('-o', '--output', required=False, type=str, default='Results', dest='outfolder', help="Results will be written to this folder") @@ -248,7 +260,7 @@ def main(): # Load TAGs/ABs. ab_map = preprocessing.parse_tags_csv(args.tags) - ordered_tags_map = preprocessing.check_tags(ab_map, args.max_error) + ordered_tags_map, longest_tag_len = preprocessing.check_tags(ab_map, args.max_error) named_tuples_tags_map = preprocessing.convert_to_named_tuple(ordered_tags=ordered_tags_map) # Identify input file(s) read1_paths, read2_paths = preprocessing.get_read_paths(args.read1_path, args.read2_path) @@ -257,7 +269,10 @@ def main(): # one of the inputs is not valid. read1_lengths = [] read2_lengths = [] + total_reads = 0 for read1_path, read2_path in zip(read1_paths, read2_paths): + n_lines = preprocessing.get_n_lines(read1_path) + total_reads += n_lines/4 # Get reads length. So far, there is no validation for Read2. read1_lengths.append(preprocessing.get_read_length(read1_path)) read2_lengths.append(preprocessing.get_read_length(read2_path)) @@ -271,108 +286,111 @@ def main(): args.cb_first, args.cb_last, args.umi_first, args.umi_last) + # Ensure all files have the same input length - #if len(set(read1_lengths)) != 1: - # sys.exit('Input barcode fastqs (read1) do not all have same length.\nExiting') + if len(set(read1_lengths)) != 1: + sys.exit('Input barcode fastqs (read1) do not all have same length.\nExiting') + if len(set(read2_lengths)) != 1: + sys.exit('Input barcode fastqs (read2) do not all have same length.\nExiting') + # Define R2_lenght to reduce amount of data to transfer to childrens + if args.sliding_window: + R2_max_length = read2_lengths[1] + else: + R2_max_length = longest_tag_len # Initialize the counts dicts that will be generated from each input fastq pair final_results = defaultdict(lambda: defaultdict(Counter)) umis_per_cell = Counter() reads_per_cell = Counter() merged_no_match = Counter() number_of_samples = len(read1_paths) - total_reads = 0 #Print a statement if multiple files are run. if number_of_samples != 1: print('Detected {} files to run on.'.format(number_of_samples)) + + + input_queue = [] + #output_queue = Queue() + + #read_struct = namedtuple('read_struct', ['r1', 'r2']) + mapping_input = namedtuple('mapping_input', ['filename', 'tags', 'barcode_slice', 'umi_slice', 'debug', 'maximum_distance', 'sliding_window']) + + print('Writing chunks to disk') + reads_count = 0 + read_list = [] + num_chunks = 0 + chunk_size = round(total_reads/args.n_threads) + 1 for read1_path, read2_path in zip(read1_paths, read2_paths): - n_reads = 0 - if args.first_n: - n_lines = (args.first_n*4)/number_of_samples - else: - n_lines = preprocessing.get_n_lines(read1_path) - n_reads += int(n_lines/4) - total_reads += n_reads - n_threads = args.n_threads - print('Started mapping') - print('Processing {:,} reads'.format(n_reads)) - #Run with one process - if n_threads <= 1 or n_reads < 1000001: - print('CITE-seq-Count is running with one core.') - ( - _final_results, - _merged_no_match) = processing.map_reads( - read1_path=read1_path, - read2_path=read2_path, - tags=ab_map, - barcode_slice=barcode_slice, - umi_slice=umi_slice, - indexes=[0,n_reads], - whitelist=whitelist, - debug=args.debug, - start_trim=args.start_trim, - maximum_distance=args.max_error, - sliding_window=args.sliding_window) - _umis_per_cell = Counter() - _reads_per_cell = Counter() - for cell_barcode, counts in _final_results.items(): - _umis_per_cell[cell_barcode] = sum([len(counts[UMI]) for UMI in counts]) - _reads_per_cell[cell_barcode] = sum([sum(counts[UMI].values()) for UMI in counts]) - else: - # Run with multiple processes - print('CITE-seq-Count is running with {} cores.'.format(n_threads)) - p = Pool(processes=n_threads) - chunk_indexes = preprocessing.chunk_reads(n_reads, n_threads) - parallel_results = [] - - for indexes in chunk_indexes: - p.apply_async(processing.map_reads, - args=( - read1_path, - read2_path, - ab_map, - barcode_slice, - umi_slice, - indexes, - whitelist, - args.debug, - args.start_trim, - args.max_error, - args.sliding_window), - callback=parallel_results.append, - error_callback=sys.stderr) - p.close() - p.join() - print('Merging results') + print('Reading reads from files: {}, {}'.format(read1_path, read2_path)) + with gzip.open(read1_path, 'rt') as textfile1, \ + gzip.open(read2_path, 'rt') as textfile2: + + # Read all 2nd lines from 4 line chunks. If first_n not None read only 4 times the given amount. + secondlines = islice(zip(textfile1, textfile2), 1, None, 4) + temp_filename = 'temp_{}'.format(num_chunks) + chunked_file_object = open(temp_filename, 'w') + for read1, read2 in secondlines: + read1 = read1.strip()[0:args.umi_last] + read2 = read2.strip()[args.start_trim:R2_max_length] + chunked_file_object.write('{},{}\n'.format(read1, read2)) + reads_count += 1 + if reads_count % chunk_size == 0: + input_queue.append(mapping_input( + filename=temp_filename, + tags=named_tuples_tags_map, + barcode_slice=barcode_slice, + umi_slice=umi_slice, + debug=args.debug, + maximum_distance=args.max_error, + sliding_window=args.sliding_window)) + num_chunks +=1 + chunked_file_object.close() + temp_filename = 'temp_{}'.format(num_chunks) + chunked_file_object = open(temp_filename, 'w') + if reads_count >= args.first_n: + break + + input_queue.append(mapping_input( + filename=temp_filename, + tags=named_tuples_tags_map, + barcode_slice=barcode_slice, + umi_slice=umi_slice, + debug=args.debug, + maximum_distance=args.max_error, + sliding_window=args.sliding_window)) + chunked_file_object.close() + + print('Started mapping') + parallel_results = [] + pool = Pool(processes=args.n_threads) + errors = [] + mapping = pool.map_async(processing.map_reads, input_queue, callback=parallel_results.append, error_callback=errors.append) + mapping.wait() + pool.close() + pool.join() + if len(errors) != 0: + for error in errors: + print(error) + + + + print('Merging results') + ( + final_results, + umis_per_cell, + reads_per_cell, + merged_no_match + ) = processing.merge_results(parallel_results=parallel_results[0]) + + del(parallel_results) - ( - _final_results, - _umis_per_cell, - _reads_per_cell, - _merged_no_match - ) = processing.merge_results(parallel_results=parallel_results) - del(parallel_results) - - print('Mapping done, merging the different lanes') - # Update the overall counts dicts - umis_per_cell.update(_umis_per_cell) - reads_per_cell.update(_reads_per_cell) - merged_no_match.update(_merged_no_match) - for cell_barcode in _final_results: - for tag in _final_results[cell_barcode]: - if tag in final_results[cell_barcode]: - # Counter + Counter = Counter - final_results[cell_barcode][tag] += _final_results[cell_barcode][tag] - else: - # Explicitly save the counter to that tag - final_results[cell_barcode][tag] = _final_results[cell_barcode][tag] # Correct cell barcodes if(len(umis_per_cell) <= args.expected_cells): print("Number of expected cells, {}, is higher " \ - "than number of cells found {}.\nNot performing" \ + "than number of cells found {}.\nNot performing " \ "cell barcode correction" \ "".format(args.expected_cells, len(umis_per_cell))) bcs_corrected = 0 @@ -426,15 +444,52 @@ def main(): aberrant_cells = set() else: #Correct UMIS - ( - final_results, - umis_corrected, - aberrant_cells - ) = processing.correct_umis( - final_results=final_results, - collapsing_threshold=args.umi_threshold, - top_cells=top_cells, - max_umis=20000) + input_queue = [] + + umi_correction_input = namedtuple('umi_correction_input', ['cells','collapsing_threshold','max_umis']) + cells = {} + n_cells = 0 + num_chunks = 0 + + cell_batch_size = round(len(top_cells)/args.n_threads)+1 + for cell in top_cells: + cells[cell] = final_results[cell] + n_cells += 1 + if n_cells % cell_batch_size == 0: + input_queue.append(umi_correction_input( + cells=cells, + collapsing_threshold=args.umi_threshold, + max_umis=20000)) + cells = {} + num_chunks += 1 + input_queue.append(umi_correction_input( + cells=cells, + collapsing_threshold=args.umi_threshold, + max_umis=20000)) + + pool = Pool(processes=args.n_threads) + errors = [] + parallel_results = [] + correct_umis = pool.map_async(processing.correct_umis, input_queue, callback=parallel_results.append, error_callback=errors.append) + + correct_umis.wait() + pool.close() + pool.join() + + if len(errors) != 0: + for error in errors: + print(error) + + + final_results = {} + umis_corrected = 0 + aberrant_cells = set() + + for chunk in parallel_results[0]: + (temp_results, temp_umis, temp_aberrant_cells) = chunk + final_results.update(temp_results) + umis_corrected += temp_umis + aberrant_cells.update(temp_aberrant_cells) if len(aberrant_cells) > 0: #Remove aberrant cells from the top cells diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index e13b0be..5bc778b 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -149,15 +149,20 @@ def check_tags(tags, maximum_distance): between two TAGs. Returns: - collections.OrderedDict: An ordered dictionary containing the TAGs and + OrderedDict: An ordered dictionary containing the TAGs and their names in descendent order based on the length of the TAGs. + int: the length of the longest TAG """ ordered_tags = OrderedDict() + longest_tag_len = 0 for i,tag_seq in enumerate(sorted(tags, key=len, reverse=True)): ordered_tags[tags[tag_seq]] = {} ordered_tags[tags[tag_seq]]['id'] = i ordered_tags[tags[tag_seq]]['sequence'] = tag_seq + if len(tag_seq) > longest_tag_len: + longest_tag_len = len(tag_seq) + ordered_tags['unmapped'] = {} ordered_tags['unmapped']['id'] = i + 1 ordered_tags['unmapped']['sequence'] = 'UNKNOWN' @@ -165,7 +170,7 @@ def check_tags(tags, maximum_distance): if (len(tags) == 1): ordered_tags['unmapped'] = {} ordered_tags['unmapped']['id'] = 2 - return(ordered_tags) + return(ordered_tags, longest_tag_len) offending_pairs = [] for a, b in combinations(tags.keys(), 2): @@ -192,17 +197,17 @@ def check_tags(tags, maximum_distance): ) sys.exit('Exiting the application.\n') - return(ordered_tags) + return(ordered_tags, longest_tag_len) def sanitize_name(string): return(string.replace('-', '_')) def convert_to_named_tuple(ordered_tags): #all_tags = namedtuple('all_tags', [sanitize_name(tag) for tag in ordered_tags.keys()]) - tag = namedtuple('tag', ['name','sequence']) + tag = namedtuple('tag', ['safe_name','name','sequence', 'id']) tag_list = [] for index, tag_name in enumerate(ordered_tags): - tag_list.append(tag(name=tag_name, sequence=ordered_tags[tag_name]['sequence'])) + tag_list.append(tag(safe_name=sanitize_name(tag_name), name=tag_name, sequence=ordered_tags[tag_name]['sequence'], id=(index))) #all_tags[index+1]=ordered_tags[tag_name]['sequence'] return(tag_list) @@ -231,6 +236,9 @@ def get_read_length(filename): return(read_length) +def get_chunk_strategy(read1_paths, read2_paths, chunk_size): + pass + def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last): """Check Read1 length against CELL and UMI barcodes length. diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index f5636b3..425a4ca 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -5,9 +5,11 @@ import Levenshtein import regex import pybktree +import csv from collections import Counter from collections import defaultdict +from collections import namedtuple from multiprocess import Pool from itertools import islice @@ -40,14 +42,14 @@ def find_best_match(TAG_seq, tags, maximum_distance): """ best_match = 'unmapped' best_score = maximum_distance - for tag, name in tags.items(): - score = Levenshtein.hamming(tag, TAG_seq[:len(tag)]) + for tag in tags: + score = Levenshtein.hamming(tag.sequence, TAG_seq[:len(tag.sequence)]) if score == 0: #Best possible match - return(name) + return(tag.name) elif score <= best_score: best_score = score - best_match = name + best_match = tag.name return(best_match) return(best_match) @@ -85,9 +87,7 @@ def find_best_match_shift(TAG_seq, tags, maximum_distance): return(best_match) -def map_reads(read1_path, read2_path, tags, barcode_slice, - umi_slice, indexes, whitelist, debug, - start_trim, maximum_distance, sliding_window): +def map_reads(mapping_input): """Read through R1/R2 files and generate a islice starting at a specific index. It reads both Read1 and Read2 files, creating a dict based on cell barcode. @@ -102,7 +102,6 @@ def map_reads(read1_path, read2_path, tags, barcode_slice, umi_slice (slice): A slice for extracting the UMI portion from the sequence. indexes (list): Pair of first and last index for islice - whitelist (set): The set of white-listed barcodes. debug (bool): Print debug messages. Default is False. start_trim (int): Number of bases to trim at the start. maximum_distance (int): Maximum distance given by the user. @@ -113,20 +112,19 @@ def map_reads(read1_path, read2_path, tags, barcode_slice, no_match (Counter): A counter with unmapped sequences. """ # Initiate values + (filename, tags, barcode_slice, umi_slice, debug, maximum_distance, sliding_window) = mapping_input + print('Started mapping in child process {}'.format(os.getpid())) results = {} no_match = Counter() n = 1 t = time.time() - with gzip.open(read1_path, 'rt') as textfile1, \ - gzip.open(read2_path, 'rt') as textfile2: - - # Read all 2nd lines from 4 line chunks. If first_n not None read only 4 times the given amount. - secondlines = islice(zip(textfile1, textfile2), indexes[0]*4+1, indexes[1]*4+1, 4) - for read1, read2 in secondlines: - read1 = read1.strip() - read2 = read2.strip() - - # Progress info + + # Progress info + with open(filename, 'r') as input_file: + reads = csv.reader(input_file) + for read in reads: + read1 = read[0] + read2 = read[1] if n % 1000000 == 0: print("Processed 1,000,000 reads in {}. Total " "reads: {:,} in child {}".format( @@ -141,21 +139,19 @@ def map_reads(read1_path, read2_path, tags, barcode_slice, cell_barcode = read1[barcode_slice] # This change in bytes is required by umi_tools for umi correction UMI = bytes(read1[umi_slice], 'ascii') - # Trim potential starting sequences - TAG_seq = read2[start_trim:] if cell_barcode not in results: results[cell_barcode] = defaultdict(Counter) if(sliding_window): - best_match = find_best_match_shift(TAG_seq, tags, maximum_distance) + best_match = find_best_match_shift(read2, tags, maximum_distance) else: - best_match = find_best_match(TAG_seq, tags, maximum_distance) + best_match = find_best_match(read2, tags, maximum_distance) results[cell_barcode][best_match][UMI] += 1 if(best_match == 'unmapped'): - no_match[TAG_seq] += 1 + no_match[read2] += 1 if debug: print( @@ -163,14 +159,15 @@ def map_reads(read1_path, read2_path, tags, barcode_slice, "cell_barcode:{1}\tUMI:{2}\tTAG_seq:{3}\n" "line length:{4}\tcell barcode length:{5}\tUMI length:{6}\tTAG sequence length:{7}\n" "Best match is: {8}" - .format(read1 + read2, cell_barcode, UMI, TAG_seq, - len(read1 + read2), len(cell_barcode), len(UMI), len(TAG_seq), best_match + .format(read1 + read2, cell_barcode, UMI, read2, + len(read1 + read2), len(cell_barcode), len(UMI), len(read2), best_match ) ) sys.stdout.flush() n += 1 - print("Mapping done for process {}. Processed {:,} reads".format(os.getpid(), n - 1)) - sys.stdout.flush() + print("Mapping done for process {}. Processed {:,} reads".format(os.getpid(), n - 1)) + sys.stdout.flush() + return(results, no_match) @@ -207,7 +204,7 @@ def merge_results(parallel_results): return(merged_results, umis_per_cell, reads_per_cell, merged_no_match) -def correct_umis(final_results, collapsing_threshold, top_cells, max_umis): +def correct_umis(umi_correction_input): """ Corrects umi barcodes within same cell/tag groups. @@ -222,10 +219,15 @@ def correct_umis(final_results, collapsing_threshold, top_cells, max_umis): corrected_umis (int): How many umis have been corrected. aberrant_umi_count_cells (set): Set of uncorrected cells. """ - print('Correcting umis') + + + + (final_results, collapsing_threshold, max_umis) = umi_correction_input + print('Started umi correction in child process {} working on {} cells'.format(os.getpid(), len(final_results))) corrected_umis = 0 - aberrant_umi_count_cells = set() - for cell_barcode in top_cells: + aberrant_cells = set() + cells = final_results.keys() + for cell_barcode in cells: for TAG in final_results[cell_barcode]: n_umis = len(final_results[cell_barcode][TAG]) if n_umis > 1 and n_umis <= max_umis: @@ -237,8 +239,9 @@ def correct_umis(final_results, collapsing_threshold, top_cells, max_umis): final_results[cell_barcode][TAG] = new_res corrected_umis += temp_corrected_umis elif n_umis > max_umis: - aberrant_umi_count_cells.add(cell_barcode) - return(final_results, corrected_umis, aberrant_umi_count_cells) + aberrant_cells.add(cell_barcode) + print('Finished correcting umis in child {}'.format(os.getpid())) + return(final_results, corrected_umis, aberrant_cells) def update_umi_counts(UMIclusters, cell_tag_counts): From 70e21704dcfaecfb71d2fc26b6056abc8dc693c1 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 3 May 2020 16:39:46 +0200 Subject: [PATCH 06/77] Fixed tests --- cite_seq_count/__main__.py | 105 ++++++++++----- cite_seq_count/preprocessing.py | 41 +----- cite_seq_count/processing.py | 12 +- tests/test_data/fastq/test_csv.csv | 200 +++++++++++++++++++++++++++++ tests/test_io.py | 8 +- tests/test_preprocessing.py | 24 ++-- tests/test_processing.py | 75 +++++------ 7 files changed, 335 insertions(+), 130 deletions(-) create mode 100644 tests/test_data/fastq/test_csv.csv diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index c564d30..ced0980 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -11,17 +11,11 @@ import gzip from itertools import islice -from argparse import ArgumentParser -from argparse import RawTextHelpFormatter -from collections import OrderedDict -from collections import Counter -from collections import defaultdict -from collections import namedtuple +from argparse import ArgumentParser, ArgumentTypeError, RawTextHelpFormatter +from collections import OrderedDict, Counter, defaultdict, namedtuple from multiprocess import cpu_count, Pool, Queue, JoinableQueue, Process - - from cite_seq_count import preprocessing from cite_seq_count import processing from cite_seq_count import io @@ -29,6 +23,17 @@ version = pkg_resources.require("cite_seq_count")[0].version +def chunk_size_limit(arg): + """Validates chunk_size limits""" + try: + f = int(arg) + except ValueError: + raise ArgumentTypeError("Must be am int") + if f < 1 or f > 2147483647: + raise ArgumentTypeError("Argument must be < " + str(2147483647) + "and > " + str(1)) + return f + + def get_args(): """ Get args. @@ -88,6 +93,7 @@ def get_args(): barcodes.add_argument('--bc_collapsing_dist', dest='bc_threshold', required=False, type=int, default=1, help="threshold for cellular barcode collapsing.") + # Cells group cells = parser.add_argument_group( 'Cells', description=("Expected number of cells and potential whitelist") @@ -139,11 +145,15 @@ def get_args(): parallel.add_argument('-T', '--threads', required=False, type=int, dest='n_threads', default=cpu_count(), help="How many threads are to be used for running the program") - parallel.add_argument('-C', '--chunk_size', required=False, type=int, - dest='chunk_size', default=1000000, - help="How many reads shuold be sent to a child process at a time") - + parallel.add_argument('-C', '--chunk_size', required=False, type=chunk_size_limit, + dest='chunk_size', + help="How many reads should be sent to a child process at a time") + parallel.add_argument('--temp_path', required=False, type=str, + dest='temp_path', default="", + help="Temp folder for chunk creation specification. Useful when using a cluster with a scratch folder") + + # Global group parser.add_argument('-n', '--first_n', required=False, type=int, dest='first_n', default=float('inf'), help="Select N reads to run on instead of all.") @@ -165,12 +175,11 @@ def get_args(): return parser -def create_report(n_reads, reads_per_cell, no_match, version, start_time, ordered_tags_map, umis_corrected, bcs_corrected, bad_cells, args): +def create_report(total_reads, reads_per_cell, no_match, version, start_time, ordered_tags_map, umis_corrected, bcs_corrected, bad_cells, R1_too_short, R2_too_short, args): """ Creates a report with details about the run in a yaml format. - Args: - n_reads (int): Number of reads that have been processed. + total_reads (int): Number of reads that have been processed. reads_matrix (scipy.sparse.dok_matrix): A sparse matrix continining read counts. no_match (Counter): Counter of unmapped tags. version (string): CITE-seq-Count package version. @@ -179,9 +188,11 @@ def create_report(n_reads, reads_per_cell, no_match, version, start_time, ordere """ total_unmapped = sum(no_match.values()) - total_mapped = sum(reads_per_cell.values()) - total_unmapped - mapped_perc = round((total_mapped/n_reads)*100) - unmapped_perc = round((total_unmapped/n_reads)*100) + total_mapped = total_reads - total_unmapped + total_too_short = total_reads - total_unmapped - total_mapped + too_short_perc = round((total_too_short/total_reads)*100) + mapped_perc = round((total_mapped/total_reads)*100) + unmapped_perc = round((total_unmapped/total_reads)*100) with open(os.path.join(args.outfolder, 'run_report.yaml'), 'w') as report_file: report_file.write( @@ -191,6 +202,9 @@ def create_report(n_reads, reads_per_cell, no_match, version, start_time, ordere Reads processed: {} Percentage mapped: {} Percentage unmapped: {} +Percentage too short: {} +\tR1_too_short: {} +\tR2_too_short: {} Uncorrected cells: {} Correction: \tCell barcodes collapsing threshold: {} @@ -213,9 +227,12 @@ def create_report(n_reads, reads_per_cell, no_match, version, start_time, ordere datetime.datetime.today().strftime('%Y-%m-%d'), secondsToText.secondsToText(time.time()-start_time), version, - n_reads, + int(total_reads), mapped_perc, unmapped_perc, + too_short_perc, + R1_too_short, + R2_too_short, len(bad_cells), args.bc_threshold, bcs_corrected, @@ -249,6 +266,8 @@ def main(): # Parse arguments. args = parser.parse_args() + temp_path = os.path.abspath(args.temp_path) + assert os.access(temp_path, os.W_OK) if args.whitelist: print('Loading whitelist') (whitelist, args.bc_threshold) = preprocessing.parse_whitelist_csv( @@ -280,7 +299,7 @@ def main(): ( barcode_slice, umi_slice, - barcode_umi_length + _ ) = preprocessing.check_barcodes_lengths( read1_lengths[-1], args.cb_first, @@ -315,13 +334,18 @@ def main(): #output_queue = Queue() #read_struct = namedtuple('read_struct', ['r1', 'r2']) - mapping_input = namedtuple('mapping_input', ['filename', 'tags', 'barcode_slice', 'umi_slice', 'debug', 'maximum_distance', 'sliding_window']) + mapping_input = namedtuple('mapping_input', ['filename', 'tags', 'debug', 'maximum_distance', 'sliding_window']) print('Writing chunks to disk') reads_count = 0 - read_list = [] num_chunks = 0 - chunk_size = round(total_reads/args.n_threads) + 1 + if args.chunk_size: + chunk_size = args.chunk_size + else: + chunk_size = round(total_reads/args.n_threads) + 1 + temp_files = [] + R1_too_short = 0 + R2_too_short = 0 for read1_path, read2_path in zip(read1_paths, read2_paths): print('Reading reads from files: {}, {}'.format(read1_path, read2_path)) with gzip.open(read1_path, 'rt') as textfile1, \ @@ -329,19 +353,29 @@ def main(): # Read all 2nd lines from 4 line chunks. If first_n not None read only 4 times the given amount. secondlines = islice(zip(textfile1, textfile2), 1, None, 4) - temp_filename = 'temp_{}'.format(num_chunks) + temp_filename = os.path.join(temp_path, 'temp_{}'.format(num_chunks)) chunked_file_object = open(temp_filename, 'w') + temp_files.append(os.path.abspath(temp_filename)) for read1, read2 in secondlines: - read1 = read1.strip()[0:args.umi_last] - read2 = read2.strip()[args.start_trim:R2_max_length] - chunked_file_object.write('{},{}\n'.format(read1, read2)) + + read1 = read1.strip() + if len(read1) < args.umi_last: + R1_too_short += 1 + # The entire read is skipped + continue + read1_sliced = read1[0:args.umi_last] + if len(read2) < R2_max_length: + R2_too_short += 1 + # The entire read is skipped + continue + + read2_sliced = read2[args.start_trim:R2_max_length] + chunked_file_object.write('{},{},{}\n'.format(read1_sliced[barcode_slice], read1_sliced[umi_slice], read2_sliced)) reads_count += 1 if reads_count % chunk_size == 0: input_queue.append(mapping_input( filename=temp_filename, tags=named_tuples_tags_map, - barcode_slice=barcode_slice, - umi_slice=umi_slice, debug=args.debug, maximum_distance=args.max_error, sliding_window=args.sliding_window)) @@ -349,14 +383,13 @@ def main(): chunked_file_object.close() temp_filename = 'temp_{}'.format(num_chunks) chunked_file_object = open(temp_filename, 'w') + temp_files.append(os.path.abspath(temp_filename)) if reads_count >= args.first_n: break input_queue.append(mapping_input( filename=temp_filename, tags=named_tuples_tags_map, - barcode_slice=barcode_slice, - umi_slice=umi_slice, debug=args.debug, maximum_distance=args.max_error, sliding_window=args.sliding_window)) @@ -385,6 +418,10 @@ def main(): ) = processing.merge_results(parallel_results=parallel_results[0]) del(parallel_results) + + # Delete temp_files + for file_path in temp_files: + os.remove(file_path) # Correct cell barcodes @@ -499,7 +536,7 @@ def main(): #Create sparse aberrant cells matrix ( umi_aberrant_matrix, - read_aberrant_matrix + _ ) = processing.generate_sparse_matrices( final_results=final_results, ordered_tags_map=ordered_tags_map, @@ -547,7 +584,7 @@ def main(): #Create report and write it to disk create_report( - n_reads=total_reads, + total_reads=total_reads, reads_per_cell=reads_per_cell, no_match=merged_no_match, version=version, @@ -556,6 +593,8 @@ def main(): umis_corrected=umis_corrected, bcs_corrected=bcs_corrected, bad_cells=aberrant_cells, + R1_too_short=R1_too_short, + R2_too_short=R2_too_short, args=args) #Write dense matrix to disk if requested diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 5bc778b..aebfd90 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -9,45 +9,6 @@ from collections import namedtuple from itertools import combinations from itertools import islice - -def get_indexes(start_index, chunk_size, nth): - """ - Creates indexes from a reference index, a chunk size an nth number - - Args: - start_index (int): first position - chunk_size (int): Chunk size - nth (int): The nth number - - Returns: - list: First and last position of indexes - """ - start_index = nth * chunk_size - stop_index = chunk_size + nth * chunk_size - return([start_index,stop_index]) - - -def chunk_reads(n_reads, n): - """ - Creates a list of indexes for the islice iterator from the map_reads function. - - Args: - n_reads (int): Number of reads to split - n (int): How many buckets for the split. - Returns: - indexes (list(list)): Each entry contains the first and the last index for a read. - """ - indexes=list() - if n_reads % n == 0: - chunk_size = int(n_reads/n) - rest = 0 - else: - chunk_size = floor(n_reads/n) - rest = n_reads - (n*chunk_size) - for i in range(0,n): - indexes.append(get_indexes(i, chunk_size, i)) - indexes[-1][1] += rest - return(indexes) def parse_whitelist_csv(filename, barcode_length, collapsing_threshold): @@ -156,7 +117,7 @@ def check_tags(tags, maximum_distance): """ ordered_tags = OrderedDict() longest_tag_len = 0 - for i,tag_seq in enumerate(sorted(tags, key=len, reverse=True)): + for i, tag_seq in enumerate(sorted(tags, key=len, reverse=True)): ordered_tags[tags[tag_seq]] = {} ordered_tags[tags[tag_seq]]['id'] = i ordered_tags[tags[tag_seq]]['sequence'] = tag_seq diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 425a4ca..5acc734 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -112,7 +112,7 @@ def map_reads(mapping_input): no_match (Counter): A counter with unmapped sequences. """ # Initiate values - (filename, tags, barcode_slice, umi_slice, debug, maximum_distance, sliding_window) = mapping_input + (filename, tags, debug, maximum_distance, sliding_window) = mapping_input print('Started mapping in child process {}'.format(os.getpid())) results = {} no_match = Counter() @@ -123,8 +123,10 @@ def map_reads(mapping_input): with open(filename, 'r') as input_file: reads = csv.reader(input_file) for read in reads: - read1 = read[0] - read2 = read[1] + cell_barcode = read[0] + # This change in bytes is required by umi_tools for umi correction + UMI = bytes(read[1], 'ascii') + read2 = read[2] if n % 1000000 == 0: print("Processed 1,000,000 reads in {}. Total " "reads: {:,} in child {}".format( @@ -135,10 +137,6 @@ def map_reads(mapping_input): sys.stdout.flush() t = time.time() - # Get cell and umi barcodes. - cell_barcode = read1[barcode_slice] - # This change in bytes is required by umi_tools for umi correction - UMI = bytes(read1[umi_slice], 'ascii') if cell_barcode not in results: results[cell_barcode] = defaultdict(Counter) diff --git a/tests/test_data/fastq/test_csv.csv b/tests/test_data/fastq/test_csv.csv new file mode 100644 index 0000000..d4d8295 --- /dev/null +++ b/tests/test_data/fastq/test_csv.csv @@ -0,0 +1,200 @@ +TAGAGGGAAGTCAAGC,CNGAGTCTCN,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CACNTAAATC,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,TCTCCGAAGC,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,TTGACTTATC,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,TTTCCTTCCG,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,NAGAACACCA,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CTCATCTTAT,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,TCGACAAGCT,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,AAANACACTC,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,AAANACACTC,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,GNCNCTGATA,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,TGGTGTGNGC,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,AGANTCTCCC,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,GCGCGGAGAA,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,TGGTGTGNGC,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,TCTGGTAGTT,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,TAGGAGCAGN,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,ATCACGGATC,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CCGGGGACTA,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,TTACTAGTAA,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,AGAACGTCGC,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,ANGAGNAAGT,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,ATAACTAGAA,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,TTGACTTATC,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,ACACTGCTAT,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,GCGCGGAGAA,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,CGTGATGAGC,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,TCACCCNCGG,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,ATCCGTACTT,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CCGGGGACTA,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,NTTCATGTTG,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,ATCGGGAGNC,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CGTCNGTTGC,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,TAGTCAGAAT,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,ATCGGGAGNC,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,ACACTGCTAT,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,TGCTCAATAG,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,GATCGTACAA,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,AGACCTNTGG,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,TGACCTAAGC,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,GATCGTACAA,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,TGGTGTGNGC,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,TCGACACCAC,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,CATCNAGTGN,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,TTAAAACCNA,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,CCATACGNNA,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,ATTGTTCGGA,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CACNCTAGGG,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,AGGAGCNCCC,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,GCGCGGAGAA,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,TTCCGTNCAA,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,GNCNCTGATA,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CAAGGGAACG,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,AGAANGCCNA,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,TCTGGTAGTT,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,ACAGAGTAAN,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CACNCTAGGG,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CTACACGTGA,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,CGTCNGTTGC,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,CGTCGTTATA,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,TTCNGTCACC,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,TTGGNGTACA,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,TCGACAAGCT,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,CGNAATTTGA,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,GCGCGGAGAA,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,CTACGCCGCC,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,TAGGAGCAGN,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,AGCANTGTAG,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,AGAANGCCNA,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,TGTCTGCACG,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,CNGAGTCTCN,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,ATAACTAGAA,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CCATGTGNGT,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,CCATGTGNGT,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,CGTAGGCATT,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CCGGGGACTA,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,ANTTCTCTCA,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,ACAGAGTAAN,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,TGCTCAATAG,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,AGACTTAGGG,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,TAAAGGCTTG,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CTTGAGAGGG,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,TTCCGTNCAA,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,CATCNAGTGN,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,GAACACTGAG,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,ACGCGGAGTT,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,ANGAGNAAGT,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,ANTTCTCTCA,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,GCTGTGTTAG,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,TTCNGTCACC,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CCATACGNNA,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CGTCNGTTGC,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CCATGTGNGT,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,GCCCGCTCAC,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,AAANACACTC,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,TAAAGGCTTG,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,GATTTCACCG,CGTCGTAGCTGATCG +TAGAGGGAAGTCAAGC,CCATACGNNA,CGTACGTAGCCTAGC +TAGAGGGAAGTCAAGC,AGCANTGTAG,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,AACTCCCACG,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,GATCGAACGG,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,GATCGAACGG,CGTCGTAGCTGATCG +TACATATTCTTTACTG,GANCGGGACA,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,NGCTGGCACG,CGTACGTAGCCTAGC +TACATATTCTTTACTG,GANCGGGACA,CGTACGTAGCCTAGC +TACATATTCTTTACTG,ATTCATTGTA,CGTACGTAGCCTAGC +TACATATTCTTTACTG,TAATCATACC,CGTACGTAGCCTAGC +TACATATTCTTTACTG,GGTCTAAGAG,CGTCGTAGCTGATCG +TACATATTCTTTACTG,GGATNTTGTA,CGTCGTAGCTGATCG +TACATATTCTTTACTG,AATATGANTG,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,ATTAAGCCNG,CGTACGTAGCCTAGC +TACATATTCTTTACTG,TGAGGGTAGA,CGTCGTAGCTGATCG +TACATATTCTTTACTG,CTCTCGCTTT,CGTACGTAGCCTAGC +TACATATTCTTTACTG,GGATNTTGTA,CGTCGTAGCTGATCG +TACATATTCTTTACTG,AGGTTTACTG,CGTCGTAGCTGATCG +TACATATTCTTTACTG,GAGGCGTGTC,CGTCGTAGCTGATCG +TACATATTCTTTACTG,TGCTGAATAA,CGTCGTAGCTGATCG +TACATATTCTTTACTG,AAGGCACTTT,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,CTTTCAAGTN,CGTACGTAGCCTAGC +TACATATTCTTTACTG,GAGGCGTGTC,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,TGTNAATCCA,CGTCGTAGCTGATCG +TACATATTCTTTACTG,GCCAAGTACA,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,GGATNTTGTA,CGTCGTAGCTGATCG +TACATATTCTTTACTG,CCGNTGTGGC,CGTACGTAGCCTAGC +TACATATTCTTTACTG,GATCGAACGG,CGTCGTAGCTGATCG +TACATATTCTTTACTG,TCGCGATGNT,CGTCGTAGCTGATCG +TACATATTCTTTACTG,ACCGTGAGGC,CGTACGTAGCCTAGC +TACATATTCTTTACTG,GGTCGCAGTN,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,CCAGACTTGA,CGTCGTAGCTGATCG +TACATATTCTTTACTG,GACTTTTCCT,CGTCGTAGCTGATCG +TACATATTCTTTACTG,CTTCCATGCC,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,AGCAACCCGA,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,AGCAACCCGA,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,GACGGGGTCT,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,TACGAAGAAT,CGTACGTAGCCTAGC +TACATATTCTTTACTG,CGAGGTGCGN,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,TNCATCGGAT,CGTCGTAGCTGATCG +TACATATTCTTTACTG,AAGGCACTTT,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,ATTAAGCCNG,CGTCGTAGCTGATCG +TACATATTCTTTACTG,GCTAACCCGN,CGTCGTAGCTGATCG +TACATATTCTTTACTG,AGGTTTACTG,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,ANAGGANAAC,CGTACGTAGCCTAGC +TACATATTCTTTACTG,AGGTTTACTG,CGTCGTAGCTGATCG +TACATATTCTTTACTG,CTNATCGGTC,CGTACGTAGCCTAGC +TACATATTCTTTACTG,AGCAACCCGA,CGTACGTAGCCTAGC +TACATATTCTTTACTG,ACTGGTCGCT,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,GCNGTCGCTA,CGTCGTAGCTGATCG +TACATATTCTTTACTG,TCGCGATGNT,CGTCGTAGCTGATCG +TACATATTCTTTACTG,TAATCATACC,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,GACTTTTCCT,CGTACGTAGCCTAGC +TACATATTCTTTACTG,CCCGAATGAA,CGTCGTAGCTGATCG +TACATATTCTTTACTG,GCTTCTACCN,CGTACGTAGCCTAGC +TACATATTCTTTACTG,ACTGGTCGCT,CGTACGTAGCCTAGC +TACATATTCTTTACTG,AGGTCGCTAC,CGTACGTAGCCTAGC +TACATATTCTTTACTG,AGCGCCNTGG,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,CCAGCGCCCG,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,GAGATCCGAG,CGTACGTAGCCTAGC +TACATATTCTTTACTG,TAGCCCCCCC,CGTACGTAGCCTAGC +TACATATTCTTTACTG,ATTCATTGTA,CGTACGTAGCCTAGC +TACATATTCTTTACTG,ATCGGGCGCC,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,CGAGGTGCGN,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,AGTAANGCAA,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,NGCTGGCACG,CGTCGTAGCTGATCG +TACATATTCTTTACTG,CCGNTGTGGC,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,ATCGGGCGCC,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,ATTAAGCCNG,CGTCGTAGCTGATCG +TACATATTCTTTACTG,GGNCGCACCC,CGTACGTAGCCTAGC +TACATATTCTTTACTG,AANCACANGT,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,AANTAAGCAT,CGTCGTAGCTGATCG +TACATATTCTTTACTG,ANAGGANAAC,CGTACGTAGCCTAGC +TACATATTCTTTACTG,GGNCGCACCC,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,CAATTCCGGC,CGTACGTAGCCTAGC +TACATATTCTTTACTG,AGGTTTACTG,CGTCGTAGCTGATCG +TACATATTCTTTACTG,CCGNTGTGGC,CGTACGTAGCCTAGC +TACATATTCTTTACTG,ACGCTATGTA,CGTCGTAGCTGATCG +TACATATTCTTTACTG,CTCCTGTGGC,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,GTTGTTTATT,CGTCGTAGCTGATCG +TACATATTCTTTACTG,CGAAGAGAAC,CGTCGTAGCTGATCG +TACATATTCTTTACTG,CGAAGAGAAC,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,GCGGCCATTC,CGTACGTAGCCTAGC +TACATATTCTTTACTG,AGTAANGCAA,CGTCGTAGCTGATCG +TACATATTCTTTACTG,CGAAGAGAAC,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,GTCAACCGGG,CGTACGTAGCCTAGC +TACATATTCTTTACTG,CTCAATACTA,CGTACGTAGCCTAGC +TACATATTCTTTACTG,ATTAAGCCNG,CGTACGTAGCCTAGC +TACATATTCTTTACTG,TACTGTGCTA,CGTACGTAGCCTAGC +TACATATTCTTTACTG,ANGCACTCGA,CGTCGTAGCTGATCG +TACATATTCTTTACTG,NGCTGGCACG,CGTCGTAGCTGATCG +TACATATTCTTTACTG,TAGTATGGAA,CGTCGTAGCTGATCG +TACATATTCTTTACTG,TCGCGATGNT,CGTCGTAGCTGATCG +TACATATTCTTTACTG,ANAGGANAAC,CGTCGTAGCTGATCG +TACATATTCTTTACTG,ACAGTAAATG,CGTACGTAGCCTAGC +TACATATTCTTTACTG,GGTCTAAGAG,CGTACGTAGCCTAGC +TACATATTCTTTACTG,CGAGGTGCGN,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,ACTGGTCGCT,CGTCGTAGCTGATCG +TACATATTCTTTACTG,GTTGTTTATT,CGTAGCTCGAAAAAA +TACATATTCTTTACTG,AAGGCACTTT,CGTACGTAGCCTAGC +TACATATTCTTTACTG,TGACATCAAC,CGTACGTAGCCTAGC +TACATATTCTTTACTG,TGCAGAAANG,CGTCGTAGCTGATCG +TACATATTCTTTACTG,CTTCAANTGA,CGTAGCTCGAAAAAA diff --git a/tests/test_io.py b/tests/test_io.py index d0be8ab..f71e027 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -12,10 +12,10 @@ def data(): pytest.sparse_matrix = test_matrix pytest.top_cells = set(['ACTGTTTTATTGGCCT','TTCATAAGGTAGGGAT']) pytest.ordered_tags_map = OrderedDict({ - 'test3-CGTCGTAGCTGATCGTAGCTGAC':0, - 'test2-CGTACGTAGCCTAGC':1, - 'test1-CGTAGCTCG': 3, - 'unmapped': 4 + 'test3':{'id':0, 'sequence': 'CGTA'}, + 'test2':{'id':1, 'sequence': 'CGTA'}, + 'test1': {'id':3, 'sequence': 'CGTA'}, + 'unmapped': {'id':4, 'sequence': 'CGTA'} }) pytest.data_type = 'umi' pytest.outfolder = 'tests/test_data/' diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 6465fb9..43401d8 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -36,15 +36,16 @@ def data(): 'ACTGTCTAACGGGTCAGTGC':'CITE_LEN_20_2', 'TATCACATCGGTGGATCCAT':'CITE_LEN_20_3'} pytest.correct_ordered_tags = OrderedDict({ - 'TGTGACGTATTGCTAGCTAG':'CITE_LEN_20_1-TGTGACGTATTGCTAGCTAG', - 'ACTGTCTAACGGGTCAGTGC':'CITE_LEN_20_2-ACTGTCTAACGGGTCAGTGC', - 'TATCACATCGGTGGATCCAT':'CITE_LEN_20_3-TATCACATCGGTGGATCCAT', - 'TCGATAATGCGAGTACAA':'CITE_LEN_18_1-TCGATAATGCGAGTACAA', - 'GAGGCTGAGCTAGCTAGT':'CITE_LEN_18_2-GAGGCTGAGCTAGCTAGT', - 'GGCTGATGCTGACTGCTA':'CITE_LEN_18_3-GGCTGATGCTGACTGCTA', - 'AGGACCATCCAA':'CITE_LEN_12_1-AGGACCATCCAA', - 'ACATGTTACCGT':'CITE_LEN_12_2-ACATGTTACCGT', - 'AGCTTACTATCC':'CITE_LEN_12_3-AGCTTACTATCC'}) + 'CITE_LEN_20_1':{'id':0,'sequence':'TGTGACGTATTGCTAGCTAG'}, + 'CITE_LEN_20_2':{'id':1,'sequence':'ACTGTCTAACGGGTCAGTGC'}, + 'CITE_LEN_20_3':{'id':2,'sequence':'TATCACATCGGTGGATCCAT'}, + 'CITE_LEN_18_1':{'id':3,'sequence':'TCGATAATGCGAGTACAA'}, + 'CITE_LEN_18_2':{'id':4,'sequence':'GAGGCTGAGCTAGCTAGT'}, + 'CITE_LEN_18_3':{'id':5,'sequence':'GGCTGATGCTGACTGCTA'}, + 'CITE_LEN_12_1':{'id':6,'sequence':'AGGACCATCCAA'}, + 'CITE_LEN_12_2':{'id':7,'sequence':'ACATGTTACCGT'}, + 'CITE_LEN_12_3':{'id':8,'sequence':'AGCTTACTATCC'}, + 'unmapped':{'id':9, 'sequence': 'UNKNOWN'}}) pytest.barcode_slice = slice(0, 16) pytest.umi_slice = slice(16, 26) pytest.barcode_umi_length = 26 @@ -59,7 +60,10 @@ def test_parse_tags_csv(data): @pytest.mark.dependency(depends=['test_parse_tags_csv']) def test_check_tags(data): - assert preprocessing.check_tags(pytest.correct_tags, 5) == pytest.correct_ordered_tags + tags = preprocessing.check_tags(pytest.correct_tags, 5)[0] + for name in tags.keys(): + assert tags[name] == pytest.correct_ordered_tags[name] + @pytest.mark.dependency(depends=['test_check_tags']) def test_check_distance_too_big_between_tags(data): diff --git a/tests/test_processing.py b/tests/test_processing.py index 63eea35..039bd9e 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -55,15 +55,16 @@ def data(): # Test file paths pytest.correct_R1_path = 'tests/test_data/fastq/correct_R1.fastq.gz' pytest.correct_R2_path = 'tests/test_data/fastq/correct_R2.fastq.gz' + pytest.file_path = 'tests/test_data/fastq/test_csv.csv' pytest.chunk_size = 800 pytest.tags = OrderedDict({ - 'CGTACGTAGCCTAGC': 'test2-CGTACGTAGCCTAGC', - 'CGTAGCTCG': 'test1-CGTAGCTCG' + 'test2':{'id':0,'sequence':'CGTACGTAGCCTAGC'}, + 'test1':{'id':1,'sequence':'CGTAGCTCG'}, + 'unmapped':{'id':2,'sequence':'UNKNOWN'}, }) pytest.barcode_slice = slice(0, 16) pytest.umi_slice = slice(16, 26) - pytest.indexes = [0,800] pytest.correct_whitelist = set(['ACTGTTTTATTGGCCT','TTCATAAGGTAGGGAT']) pytest.legacy = False pytest.debug = False @@ -71,18 +72,18 @@ def data(): pytest.maximum_distance = 5 pytest.results = { 'ACTGTTTTATTGGCCT': - {'test1-CGTAGCTCG': + {'test1': Counter({b'CATTAGTGGT': 3, b'CATTAGTGGG': 2, b'CATTCGTGGT': 1})}, 'TTCATAAGGTAGGGAT': - {'test2-CGTACGTAGCCTAGC': + {'test2': Counter({b'TAGCTTAGTA': 3, b'TAGCTTAGTC': 2, b'GCGATGCATA': 1})} } pytest.corrected_results = { 'ACTGTTTTATTGGCCT': - {'test1-CGTAGCTCG': + {'test1': Counter({b'CATTAGTGGT': 6})}, 'TTCATAAGGTAGGGAT': - {'test2-CGTACGTAGCCTAGC': + {'test2': Counter({b'TAGCTTAGTA': 5, b'GCGATGCATA': 1})} } pytest.umis_per_cell = Counter({ @@ -100,45 +101,53 @@ def data(): pytest.max_umis = 20000 pytest.sequence_pool = [] - pytest.tags_complete = preprocessing.check_tags(preprocessing.parse_tags_csv('tests/test_data/tags/correct.csv'), 5) + pytest.tags_complete_dict = preprocessing.check_tags(preprocessing.parse_tags_csv('tests/test_data/tags/correct.csv'), 5)[0] + pytest.tags_complete_tuple = preprocessing.convert_to_named_tuple(pytest.tags_complete_dict) + pytest.tags_short_tuple = preprocessing.convert_to_named_tuple(pytest.tags) @pytest.mark.dependency() def test_find_best_match_with_1_distance(data): distance = 1 - for tag,name in pytest.tags_complete.items(): + for name, tag in pytest.tags_complete_dict.items(): counts = Counter() - for seq in extend_seq_pool(tag, distance): - counts[processing.find_best_match(seq, pytest.tags_complete, distance)] += 1 + if name == 'unmapped': + continue + for seq in extend_seq_pool(tag['sequence'], distance): + counts[processing.find_best_match(seq, pytest.tags_complete_tuple, distance)] += 1 assert counts[name] == 4 @pytest.mark.dependency() def test_find_best_match_with_2_distance(data): distance = 2 - for tag,name in pytest.tags_complete.items(): + for name, tag in pytest.tags_complete_dict.items(): counts = Counter() - - for seq in extend_seq_pool(tag, distance): - counts[processing.find_best_match(seq, pytest.tags_complete, distance)] += 1 + if name == 'unmapped': + continue + for seq in extend_seq_pool(tag['sequence'], distance): + counts[processing.find_best_match(seq, pytest.tags_complete_tuple, distance)] += 1 assert counts[name] == 4 @pytest.mark.dependency() def test_find_best_match_with_3_distance(data): distance = 3 - for tag,name in pytest.tags_complete.items(): + for name, tag in pytest.tags_complete_dict.items(): counts = Counter() - - for seq in extend_seq_pool(tag, distance): - counts[processing.find_best_match(seq, pytest.tags_complete, distance)] += 1 + if name == 'unmapped': + continue + for seq in extend_seq_pool(tag['sequence'], distance): + counts[processing.find_best_match(seq, pytest.tags_complete_tuple, distance)] += 1 assert counts[name] == 4 @pytest.mark.dependency() def test_find_best_match_with_3_distance_reverse(data): distance = 3 - for tag,name in sorted(pytest.tags_complete.items()): + for name, tag in sorted(pytest.tags_complete_dict.items()): counts = Counter() - for seq in extend_seq_pool(tag, distance): - counts[processing.find_best_match(seq, pytest.tags_complete, distance)] += 1 + if name == 'unmapped': + continue + for seq in extend_seq_pool(tag['sequence'], distance): + counts[processing.find_best_match(seq, pytest.tags_complete_tuple, distance)] += 1 assert counts[name] == 4 @pytest.mark.dependency(depends=[ @@ -147,24 +156,18 @@ def test_find_best_match_with_3_distance_reverse(data): 'test_find_best_match_with_3_distance', 'test_find_best_match_with_3_distance_reverse',]) def test_classify_reads_multi_process(data): - (results, no_match) = processing.map_reads( - pytest.correct_R1_path, - pytest.correct_R2_path, - pytest.tags, - pytest.barcode_slice, - pytest.umi_slice, - pytest.indexes, - pytest.correct_whitelist, + (results, no_match) = processing.map_reads(( + pytest.file_path, + pytest.tags_short_tuple, pytest.debug, - pytest.start_trim, pytest.maximum_distance, - pytest.sliding_window) + pytest.sliding_window)) assert len(results) == 2 @pytest.mark.dependency(depends=['test_classify_reads_multi_process']) def test_correct_umis(data): - temp = processing.correct_umis(pytest.results, 2, pytest.corrected_results.keys(), pytest.max_umis) + temp = processing.correct_umis((pytest.results, 2, pytest.max_umis)) results = temp[0] n_corrected = temp[1] for cell_barcode in results.keys(): @@ -182,11 +185,11 @@ def test_correct_cells(data): @pytest.mark.dependency(depends=['test_correct_umis']) def test_generate_sparse_matrices(data): (umi_results_matrix, read_results_matrix) = processing.generate_sparse_matrices( - pytest.corrected_results, pytest.ordered_tags_map, + pytest.corrected_results, pytest.tags, set(['ACTGTTTTATTGGCCT','TTCATAAGGTAGGGAT']) ) - assert umi_results_matrix.shape == (4,2) - assert read_results_matrix.shape == (4,2) + assert umi_results_matrix.shape == (3,2) + assert read_results_matrix.shape == (3,2) read_results_matrix = read_results_matrix.tocsr() total_reads = 0 for i in range(read_results_matrix.shape[0]): From a6f595c89cf9388d169d4716d7d73d7c1970b2cf Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 3 May 2020 16:49:47 +0200 Subject: [PATCH 07/77] some updates to CHANGELOG --- CHANGELOG.md | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c234933..8cbff47 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,19 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](http://keepachangelog.com/) and this project adheres to [Semantic Versioning](http://semver.org/). +## [1.5.0] - XXXX +### Added + - `CITE-se-Count` is now Compatible with trimmed data. There is a new `too_short` category in the `run_report.yaml` + that will let you know how much you lost due to reads being too short. Will not work well with the `--sliding_window` option + because any read that is trimmed even by one base will be discarded. + - UMI correction is now also parallelized and will use the threads proposed. +### Changed + - The `features.csv` now has different columns for the tag name and the tag sequence. This keeps the relevant information + in the output files as well as simplifies reading the mtx format when processing the data. + - The mapping step has been changed. It will first write the reads to files and then read in the chunks. + This should solve the io bottleneck from before. + + ## [1.4.3] - 05.10.2019 ### Added - Support for multiple files as input. This allows you to not merge different lanes before From 5e09e19dc12e9b2cf758100cd9c4508f8b9e4e30 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 3 May 2020 16:53:54 +0200 Subject: [PATCH 08/77] more changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8cbff47..cc752cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -15,6 +15,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). in the output files as well as simplifies reading the mtx format when processing the data. - The mapping step has been changed. It will first write the reads to files and then read in the chunks. This should solve the io bottleneck from before. + - There are new options now for parallel computing. `--chunk_size` Determines how many reads will be read per chunk. ## [1.4.3] - 05.10.2019 From ebdda53bd1de08820fd85a282287500aa4c6de88 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 3 May 2020 18:58:46 +0200 Subject: [PATCH 09/77] fixed slidin_window --- CHANGELOG.md | 4 +-- cite_seq_count/__main__.py | 26 ++++++++----------- cite_seq_count/processing.py | 49 ++++++++++++++++++------------------ 3 files changed, 37 insertions(+), 42 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index cc752cf..906c5a0 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,8 +7,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [1.5.0] - XXXX ### Added - `CITE-se-Count` is now Compatible with trimmed data. There is a new `too_short` category in the `run_report.yaml` - that will let you know how much you lost due to reads being too short. Will not work well with the `--sliding_window` option - because any read that is trimmed even by one base will be discarded. + that will let you know how much you lost due to reads being too short. - UMI correction is now also parallelized and will use the threads proposed. ### Changed - The `features.csv` now has different columns for the tag name and the tag sequence. This keeps the relevant information @@ -16,6 +15,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - The mapping step has been changed. It will first write the reads to files and then read in the chunks. This should solve the io bottleneck from before. - There are new options now for parallel computing. `--chunk_size` Determines how many reads will be read per chunk. + - `--sliding-window` now only checks for exact matches. ## [1.4.3] - 05.10.2019 diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index ced0980..e5364ab 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -307,14 +307,14 @@ def main(): args.umi_first, args.umi_last) # Ensure all files have the same input length - if len(set(read1_lengths)) != 1: - sys.exit('Input barcode fastqs (read1) do not all have same length.\nExiting') - if len(set(read2_lengths)) != 1: - sys.exit('Input barcode fastqs (read2) do not all have same length.\nExiting') + #if len(set(read1_lengths)) != 1: + #sys.exit('Input barcode fastqs (read1) do not all have same length.\nExiting') + #if len(set(read2_lengths)) != 1: + #sys.exit('Input barcode fastqs (read2) do not all have same length.\nExiting') # Define R2_lenght to reduce amount of data to transfer to childrens if args.sliding_window: - R2_max_length = read2_lengths[1] + R2_max_length = read2_lengths[0] else: R2_max_length = longest_tag_len # Initialize the counts dicts that will be generated from each input fastq pair @@ -327,13 +327,7 @@ def main(): #Print a statement if multiple files are run. if number_of_samples != 1: print('Detected {} files to run on.'.format(number_of_samples)) - - - input_queue = [] - #output_queue = Queue() - - #read_struct = namedtuple('read_struct', ['r1', 'r2']) mapping_input = namedtuple('mapping_input', ['filename', 'tags', 'debug', 'maximum_distance', 'sliding_window']) print('Writing chunks to disk') @@ -350,8 +344,6 @@ def main(): print('Reading reads from files: {}, {}'.format(read1_path, read2_path)) with gzip.open(read1_path, 'rt') as textfile1, \ gzip.open(read2_path, 'rt') as textfile2: - - # Read all 2nd lines from 4 line chunks. If first_n not None read only 4 times the given amount. secondlines = islice(zip(textfile1, textfile2), 1, None, 4) temp_filename = os.path.join(temp_path, 'temp_{}'.format(num_chunks)) chunked_file_object = open(temp_filename, 'w') @@ -369,7 +361,7 @@ def main(): # The entire read is skipped continue - read2_sliced = read2[args.start_trim:R2_max_length] + read2_sliced = read2[args.start_trim:(R2_max_length + args.start_trim)] chunked_file_object.write('{},{},{}\n'.format(read1_sliced[barcode_slice], read1_sliced[umi_slice], read2_sliced)) reads_count += 1 if reads_count % chunk_size == 0: @@ -385,6 +377,7 @@ def main(): chunked_file_object = open(temp_filename, 'w') temp_files.append(os.path.abspath(temp_filename)) if reads_count >= args.first_n: + total_reads = args.first_n break input_queue.append(mapping_input( @@ -406,8 +399,6 @@ def main(): if len(errors) != 0: for error in errors: print(error) - - print('Merging results') ( @@ -419,6 +410,9 @@ def main(): del(parallel_results) + # Check if 99% of the reads are unmapped. + processing.check_unmapped(no_match=merged_no_match, total_reads=total_reads, start_trim=args.start_trim) + # Delete temp_files for file_path in temp_files: os.remove(file_path) diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 5acc734..b7bb7e5 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -33,7 +33,7 @@ def find_best_match(TAG_seq, tags, maximum_distance): If no matches found returns 'unmapped'. We add 1 Args: - TAG_seq (string): Sequence from R1 already start trimmed + TAG_seq (string): Sequence from R2 already start trimmed tags (dict): A dictionary with the TAGs as keys and TAG Names as values. maximum_distance (int): Maximum distance given by the user. @@ -54,16 +54,17 @@ def find_best_match(TAG_seq, tags, maximum_distance): return(best_match) -def find_best_match_shift(TAG_seq, tags, maximum_distance): +def find_best_match_shift(TAG_seq, tags): """ Find the best match from the list of tags with sliding window. + Only works with exact match. Compares the Levenshtein distance between tags and the trimmed sequences. The tag and the sequence must have the same length. If no matches found returns 'unmapped'. We add 1 Args: - TAG_seq (string): Sequence from R1 already start trimmed + TAG_seq (string): Sequence from R2 already start trimmed tags (dict): A dictionary with the TAGs as keys and TAG Names as values. maximum_distance (int): Maximum distance given by the user. @@ -71,19 +72,9 @@ def find_best_match_shift(TAG_seq, tags, maximum_distance): best_match (string): The TAG name that will be used for counting. """ best_match = 'unmapped' - best_score = maximum_distance - shifts = range(0,len(TAG_seq) - len(max(tags,key=len))) - - for shift in shifts: - for tag, name in tags.items(): - score = Levenshtein.hamming(tag, TAG_seq[shift:len(tag)+shift]) - if score == 0: - #Best possible match - return(name) - elif score <= best_score: - best_score = score - best_match = name - return(best_match) + for tag in tags: + if tag.sequence in TAG_seq: + return(tag.name) return(best_match) @@ -142,7 +133,7 @@ def map_reads(mapping_input): results[cell_barcode] = defaultdict(Counter) if(sliding_window): - best_match = find_best_match_shift(read2, tags, maximum_distance) + best_match = find_best_match_shift(read2, tags) else: best_match = find_best_match(read2, tags, maximum_distance) @@ -153,12 +144,17 @@ def map_reads(mapping_input): if debug: print( - "\nline:{0}\n" - "cell_barcode:{1}\tUMI:{2}\tTAG_seq:{3}\n" - "line length:{4}\tcell barcode length:{5}\tUMI length:{6}\tTAG sequence length:{7}\n" - "Best match is: {8}" - .format(read1 + read2, cell_barcode, UMI, read2, - len(read1 + read2), len(cell_barcode), len(UMI), len(read2), best_match + "cell_barcode:{0}\tUMI:{1}\tTAG_seq:{2}\n" + "cell barcode length:{3}\tUMI length:{4}\tTAG sequence length:{5}\n" + "Best match is: {6}\n" + .format( + cell_barcode, + UMI, + read2, + len(cell_barcode), + len(UMI), + len(read2), + best_match ) ) sys.stdout.flush() @@ -202,6 +198,11 @@ def merge_results(parallel_results): return(merged_results, umis_per_cell, reads_per_cell, merged_no_match) +def check_unmapped(no_match, total_reads, start_trim): + """Check if the number of unmapped is higher than 99%""" + if sum(no_match.values())/total_reads > float(0.99): + exit("""More than 99 percent of your data is unmapped.\nPlease check that your --start_trim {} parameter is correct and that your tags file is properly formatted""".format(start_trim)) + def correct_umis(umi_correction_input): """ Corrects umi barcodes within same cell/tag groups. @@ -319,7 +320,7 @@ def correct_cells(final_results, reads_per_cell, umis_per_cell, collapsing_thres corrected_umis (int): How many umis have been corrected. """ print('Looking for a whitelist') - cell_whitelist, true_to_false = whitelist_methods.getCellWhitelist( + _, true_to_false = whitelist_methods.getCellWhitelist( cell_barcode_counts=reads_per_cell, expect_cells=expected_cells, cell_number=expected_cells, From e3ba958777d2d6941a297feeb4ab4dbc600a7dda Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 3 May 2020 19:17:45 +0200 Subject: [PATCH 10/77] CHANGELOG update --- CHANGELOG.md | 7 +++-- cite_seq_count/__main__.py | 62 ++++++++++++++++++++------------------ setup.py | 3 +- 3 files changed, 40 insertions(+), 32 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 906c5a0..abf0c9d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -7,15 +7,18 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [1.5.0] - XXXX ### Added - `CITE-se-Count` is now Compatible with trimmed data. There is a new `too_short` category in the `run_report.yaml` - that will let you know how much you lost due to reads being too short. + that will let you know how much you lost due to reads being too short. #123 - UMI correction is now also parallelized and will use the threads proposed. + - Added a check at the end of the mapping. If more than 99% of the reads are unmapped, CITE-seq-Count will exit. + ### Changed - The `features.csv` now has different columns for the tag name and the tag sequence. This keeps the relevant information in the output files as well as simplifies reading the mtx format when processing the data. - The mapping step has been changed. It will first write the reads to files and then read in the chunks. This should solve the io bottleneck from before. - - There are new options now for parallel computing. `--chunk_size` Determines how many reads will be read per chunk. + - There are new options now for parallel computing. `--chunk_size` Determines how many reads will be read per chunk. #99 - `--sliding-window` now only checks for exact matches. + - Added cython dependency based on issue #117 ## [1.4.3] - 05.10.2019 diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index e5364ab..172b9cb 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -419,36 +419,40 @@ def main(): # Correct cell barcodes - if(len(umis_per_cell) <= args.expected_cells): - print("Number of expected cells, {}, is higher " \ - "than number of cells found {}.\nNot performing " \ - "cell barcode correction" \ - "".format(args.expected_cells, len(umis_per_cell))) - bcs_corrected = 0 - else: - print('Correcting cell barcodes') - if not whitelist: - ( - final_results, - umis_per_cell, - bcs_corrected - ) = processing.correct_cells( - final_results=final_results, - reads_per_cell=reads_per_cell, - umis_per_cell=umis_per_cell, - expected_cells=args.expected_cells, - collapsing_threshold=args.bc_threshold, - ab_map=named_tuples_tags_map) + if args.bc_threshold != 0: + if(len(umis_per_cell) <= args.expected_cells): + print("Number of expected cells, {}, is higher " \ + "than number of cells found {}.\nNot performing " \ + "cell barcode correction" \ + "".format(args.expected_cells, len(umis_per_cell))) + bcs_corrected = 0 else: - ( - final_results, - umis_per_cell, - bcs_corrected) = processing.correct_cells_whitelist( - final_results=final_results, - umis_per_cell=umis_per_cell, - whitelist=whitelist, - collapsing_threshold=args.bc_threshold, - ab_map=named_tuples_tags_map) + print('Correcting cell barcodes') + if not whitelist: + ( + final_results, + umis_per_cell, + bcs_corrected + ) = processing.correct_cells( + final_results=final_results, + reads_per_cell=reads_per_cell, + umis_per_cell=umis_per_cell, + expected_cells=args.expected_cells, + collapsing_threshold=args.bc_threshold, + ab_map=named_tuples_tags_map) + else: + ( + final_results, + umis_per_cell, + bcs_corrected) = processing.correct_cells_whitelist( + final_results=final_results, + umis_per_cell=umis_per_cell, + whitelist=whitelist, + collapsing_threshold=args.bc_threshold, + ab_map=named_tuples_tags_map) + else: + print('Skipping cell barcode correction') + bcs_corrected = 0 # If given, use whitelist for top cells if whitelist: diff --git a/setup.py b/setup.py index a32de53..c276fe4 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,8 @@ 'pytest==4.1.0', 'pytest-dependency==0.4.0', 'pandas>=0.23.4', - 'pybktree==1.1' + 'pybktree==1.1', + 'cython>=0.29.17' ], python_requires='>=3.6' ) From b41424211239c9f73c4fb735a855a1c3a30e4c86 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Mon, 4 May 2020 09:27:01 +0200 Subject: [PATCH 11/77] got rid of second length check --- cite_seq_count/preprocessing.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index aebfd90..87c2946 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -185,15 +185,15 @@ def get_read_length(filename): """ with gzip.open(filename, 'r') as fastq_file: secondlines = islice(fastq_file, 1, 1000, 4) - temp_length = len(next(secondlines).rstrip()) + #temp_length = len(next(secondlines).rstrip()) for sequence in secondlines: read_length = len(sequence.rstrip()) - if (temp_length != read_length): - sys.exit( - '[ERROR] Sequence length in {} is not consistent. Please, trim all ' - 'sequences at the same length.\n' - 'Exiting the application.\n'.format(filename) - ) + # if (temp_length != read_length): + # sys.exit( + # '[ERROR] Sequence length in {} is not consistent. Please, trim all ' + # 'sequences at the same length.\n' + # 'Exiting the application.\n'.format(filename) + # ) return(read_length) From 795a1a494413dc24084b08c5807b6e05731aff31 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 5 Jul 2020 19:08:37 +0200 Subject: [PATCH 12/77] integrated a pull from db for chemistry def --- cite_seq_count/__main__.py | 87 +++++++++++++++++++++++--------------- cite_seq_count/database.py | 42 ++++++++++++++++++ 2 files changed, 96 insertions(+), 33 deletions(-) create mode 100644 cite_seq_count/database.py diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 172b9cb..b16c34e 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -18,6 +18,7 @@ from cite_seq_count import preprocessing from cite_seq_count import processing +from cite_seq_count import database from cite_seq_count import io from cite_seq_count import secondsToText @@ -38,12 +39,13 @@ def get_args(): """ Get args. """ + parser = ArgumentParser( prog='CITE-seq-Count', formatter_class=RawTextHelpFormatter, description=("This script counts matching antibody tags from paired fastq " "files. Version {}".format(version)), ) - + # REQUIRED INPUTS group. inputs = parser.add_argument_group('Inputs', description="Required input files.") @@ -61,30 +63,31 @@ def get_args(): "\tATGCGA,First_tag_name\n" "\tGTCATG,Second_tag_name") ) - - # BARCODES group. + # BARCODES group. barcodes = parser.add_argument_group( 'Barcodes', description=("Positions of the cellular barcodes and UMI. If your " - "cellular barcodes and UMI\n are positioned as follows:\n" - "\tBarcodes from 1 to 16 and UMI from 17 to 26\n" - "then this is the input you need:\n" - "\t-cbf 1 -cbl 16 -umif 17 -umil 26") + "cellular barcodes and UMI\n are positioned as follows:\n" + "\tBarcodes from 1 to 16 and UMI from 17 to 26\n" + "then this is the input you need:\n" + "\t-cbf 1 -cbl 16 -umif 17 -umil 26") ) - barcodes.add_argument('-cbf', '--cell_barcode_first_base', dest='cb_first', - required=True, type=int, - help=("Postion of the first base of your cell " - "barcodes.")) - barcodes.add_argument('-cbl', '--cell_barcode_last_base', dest='cb_last', - required=True, type=int, - help=("Postion of the last base of your cell " - "barcodes.")) - barcodes.add_argument('-umif', '--umi_first_base', dest='umi_first', - required=True, type=int, - help="Postion of the first base of your UMI.") - barcodes.add_argument('-umil', '--umi_last_base', dest='umi_last', - required=True, type=int, - help="Postion of the last base of your UMI.") + barcodes.add_argument('--chemistry', type=str, required=False, default=False) + if '--chemistry' not in sys.argv: + barcodes.add_argument('-cbf', '--cell_barcode_first_base', dest='cb_first', + required=True, type=int, + help=("Postion of the first base of your cell " + "barcodes.")) + barcodes.add_argument('-cbl', '--cell_barcode_last_base', dest='cb_last', + required=True, type=int, + help=("Postion of the last base of your cell " + "barcodes.")) + barcodes.add_argument('-umif', '--umi_first_base', dest='umi_first', + required=True, type=int, + help="Postion of the first base of your UMI.") + barcodes.add_argument('-umil', '--umi_last_base', dest='umi_last', + required=True, type=int, + help="Postion of the last base of your UMI.") barcodes.add_argument('--umi_collapsing_dist', dest='umi_threshold', required=False, type=int, default=2, help="threshold for umi collapsing.") @@ -112,6 +115,15 @@ def get_args(): "Or 10X-style:\n" "\tATGCTAGTGCTA-1\n\tGCTAGTCAGGAT-1\n\tCGACTGCTAACG-1\n") ) + if '--chemistry' not in sys.argv: + cells.add_argument('--translation', required=False, type=str, + help="A csv file containing the mapping between two sets of cell barcode list.\n" + "A required header such as the reference is named whitelist. Example:\n\n" + "\twhitelist,feature_barcoding_map\n" + "\tAAACCCAAGAAACACT,AAACCCATCAAACACT\n" + "\tAAACCCAAGAAACCAT,AAACCCATCAAACCAT\n" + "\nThe output matrix will possess both cell barcode IDs" + ) # FILTERS group. filters = parser.add_argument_group( @@ -175,7 +187,7 @@ def get_args(): return parser -def create_report(total_reads, reads_per_cell, no_match, version, start_time, ordered_tags_map, umis_corrected, bcs_corrected, bad_cells, R1_too_short, R2_too_short, args): +def create_report(total_reads, reads_per_cell, no_match, version, start_time, ordered_tags_map, umis_corrected, bcs_corrected, bad_cells, R1_too_short, R2_too_short, args, chemistry_def): """ Creates a report with details about the run in a yaml format. Args: @@ -243,10 +255,10 @@ def create_report(total_reads, reads_per_cell, no_match, version, start_time, or args.cb_first, args.cb_last, args.umi_first, - args.umi_last, + chemistry_def.umi_end, args.expected_cells, args.max_error, - args.start_trim)) + chemistry_def.R2_start_trim)) def main(): #Create logger and stream handler @@ -277,6 +289,12 @@ def main(): else: whitelist = False + # Get chemistry defs + if args.chemistry: + + chemistry_def = database.get_chemistry_definition(args.chemistry) + else: + chemistry_def = database.create_chemistry_definition(args) # Load TAGs/ABs. ab_map = preprocessing.parse_tags_csv(args.tags) ordered_tags_map, longest_tag_len = preprocessing.check_tags(ab_map, args.max_error) @@ -289,6 +307,7 @@ def main(): read1_lengths = [] read2_lengths = [] total_reads = 0 + for read1_path, read2_path in zip(read1_paths, read2_paths): n_lines = preprocessing.get_n_lines(read1_path) total_reads += n_lines/4 @@ -302,9 +321,10 @@ def main(): _ ) = preprocessing.check_barcodes_lengths( read1_lengths[-1], - args.cb_first, - args.cb_last, - args.umi_first, args.umi_last) + chemistry_def.barcode_start, + chemistry_def.barcode_end, + chemistry_def.umi_start, + chemistry_def.umi_end) # Ensure all files have the same input length #if len(set(read1_lengths)) != 1: @@ -351,17 +371,17 @@ def main(): for read1, read2 in secondlines: read1 = read1.strip() - if len(read1) < args.umi_last: + if len(read1) < chemistry_def.umi_end: R1_too_short += 1 # The entire read is skipped continue - read1_sliced = read1[0:args.umi_last] + read1_sliced = read1[0:chemistry_def.umi_end] if len(read2) < R2_max_length: R2_too_short += 1 # The entire read is skipped continue - read2_sliced = read2[args.start_trim:(R2_max_length + args.start_trim)] + read2_sliced = read2[chemistry_def.R2_start_trim:(R2_max_length + chemistry_def.R2_start_trim)] chunked_file_object.write('{},{},{}\n'.format(read1_sliced[barcode_slice], read1_sliced[umi_slice], read2_sliced)) reads_count += 1 if reads_count % chunk_size == 0: @@ -411,7 +431,7 @@ def main(): del(parallel_results) # Check if 99% of the reads are unmapped. - processing.check_unmapped(no_match=merged_no_match, total_reads=total_reads, start_trim=args.start_trim) + processing.check_unmapped(no_match=merged_no_match, total_reads=total_reads, start_trim=chemistry_def.R2_start_trim) # Delete temp_files for file_path in temp_files: @@ -486,7 +506,7 @@ def main(): n_cells = 0 num_chunks = 0 - cell_batch_size = round(len(top_cells)/args.n_threads)+1 + cell_batch_size = round(len(top_cells)/args.n_threads) + 1 for cell in top_cells: cells[cell] = final_results[cell] n_cells += 1 @@ -593,7 +613,8 @@ def main(): bad_cells=aberrant_cells, R1_too_short=R1_too_short, R2_too_short=R2_too_short, - args=args) + args=args, + chemistry_def=chemistry_def) #Write dense matrix to disk if requested if args.dense: diff --git a/cite_seq_count/database.py b/cite_seq_count/database.py new file mode 100644 index 0000000..b373a6c --- /dev/null +++ b/cite_seq_count/database.py @@ -0,0 +1,42 @@ +"""This module is holding code for all the remote fetching on the chemistries database.""" +import requests +import sys + +from collections import namedtuple + + +CHEMISTRY_DEFINITIONS = "https://raw.githubusercontent.com/Hoohm/scg_lib_structs/10xv3_totalseq_b/chemistries/definitions.json" + +CHEMISTRY_DEFINITION = namedtuple('chemistry_def', ['barcode_start','barcode_end','umi_start','umi_end', 'R2_trim_start']) + +def get_chemistry_definition(chemistry_short_name, url=CHEMISTRY_DEFINITIONS): + """ + Fetches chemistry definitions from a remote definitions.json and returns the json. + """ + print('Loading remote file from: {}'.format(url)) + with requests.get(url) as r: + r.raise_for_status() + chemistry_defs = r.json().get(chemistry_short_name, False) + if not chemistry_defs: + sys.exit('Could not find the chemistry: {}. Please check that it does exist at: {}\nExiting'.format(chemistry_short_name, url)) + + chemistry_def = CHEMISTRY_DEFINITION( + barcode_start=chemistry_defs['barcode_structure_indexes']['cell_barcode']['R1']['start'], + barcode_end=chemistry_defs['barcode_structure_indexes']['cell_barcode']['R1']['stop'], + umi_start=chemistry_defs['barcode_structure_indexes']['umi_barcode']['R1']['start'], + umi_end=chemistry_defs['barcode_structure_indexes']['umi_barcode']['R1']['stop'], + R2_trim_start=chemistry_defs['sequence_structure_indexes']['R2']['start'] - 1 + ) + return(chemistry_def) + +def create_chemistry_definition(args): + """ + """ + chemistry_def = CHEMISTRY_DEFINITION( + barcode_start=args.cbf, + barcode_end=args.cbl, + umi_start=args.umif, + umi_end=args.umil, + R2_trim_start=args.start_trim + ) + return(chemistry_def) \ No newline at end of file From 9862a82ce39ca0fb533c2b815fcc8d952176affd Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 2 Aug 2020 17:54:53 +0200 Subject: [PATCH 13/77] added remote downloading of definitions --- CHANGELOG.md | 3 +- cite_seq_count/__main__.py | 827 ++++++++++++++++++++------------ cite_seq_count/database.py | 90 +++- cite_seq_count/preprocessing.py | 188 +++++--- cite_seq_count/processing.py | 231 +++++---- setup.py | 28 +- 6 files changed, 851 insertions(+), 516 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index abf0c9d..98c955d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,10 +6,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ## [1.5.0] - XXXX ### Added - - `CITE-se-Count` is now Compatible with trimmed data. There is a new `too_short` category in the `run_report.yaml` + - `CITE-seq-Count` is now Compatible with trimmed data. There is a new `too_short` category in the `run_report.yaml` that will let you know how much you lost due to reads being too short. #123 - UMI correction is now also parallelized and will use the threads proposed. - Added a check at the end of the mapping. If more than 99% of the reads are unmapped, CITE-seq-Count will exit. + - (BETA) New functionnality that will fetch the chemistry definition from a remote repo to simplify usage and reduce human errors. ### Changed - The `features.csv` now has different columns for the tag name and the tag sequence. This keeps the relevant information diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index b16c34e..a4bf9af 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -9,12 +9,14 @@ import pkg_resources import logging import gzip +import requests from itertools import islice from argparse import ArgumentParser, ArgumentTypeError, RawTextHelpFormatter from collections import OrderedDict, Counter, defaultdict, namedtuple from multiprocess import cpu_count, Pool, Queue, JoinableQueue, Process +from tempfile import NamedTemporaryFile from cite_seq_count import preprocessing from cite_seq_count import processing @@ -24,14 +26,18 @@ version = pkg_resources.require("cite_seq_count")[0].version + def chunk_size_limit(arg): """Validates chunk_size limits""" + max_size = 2147483647 try: f = int(arg) - except ValueError: + except ValueError: raise ArgumentTypeError("Must be am int") - if f < 1 or f > 2147483647: - raise ArgumentTypeError("Argument must be < " + str(2147483647) + "and > " + str(1)) + if f < 1 or f > max_size: + raise ArgumentTypeError( + "Argument must be < " + str(max_size) + "and > " + str(1) + ) return f @@ -39,155 +45,300 @@ def get_args(): """ Get args. """ - + parser = ArgumentParser( - prog='CITE-seq-Count', formatter_class=RawTextHelpFormatter, - description=("This script counts matching antibody tags from paired fastq " - "files. Version {}".format(version)), + prog="CITE-seq-Count", + formatter_class=RawTextHelpFormatter, + description=( + "This script counts matching antibody tags from paired fastq " + "files. Version {}".format(version) + ), ) - + # REQUIRED INPUTS group. - inputs = parser.add_argument_group('Inputs', - description="Required input files.") - inputs.add_argument('-R1', '--read1', dest='read1_path', required=True, - help=("The path of Read1 in gz format, or a comma-separated list of paths to all Read1 files in" - " gz format (E.g. A1.fq.gz,B1.fq,gz,...")) - inputs.add_argument('-R2', '--read2', dest='read2_path', required=True, - help=("The path of Read2 in gz format, or a comma-separated list of paths to all Read2 files in" - " gz format (E.g. A2.fq.gz,B2.fq,gz,...")) + inputs = parser.add_argument_group("Inputs", description="Required input files.") inputs.add_argument( - '-t', '--tags', dest='tags', required=True, - help=("The path to the csv file containing the antibody\n" - "barcodes as well as their respective names.\n\n" - "Example of an antibody barcode file structure:\n\n" - "\tATGCGA,First_tag_name\n" - "\tGTCATG,Second_tag_name") - ) - # BARCODES group. + "-R1", + "--read1", + dest="read1_path", + required=True, + help=( + "The path of Read1 in gz format, or a comma-separated list of paths to all Read1 files in" + " gz format (E.g. A1.fq.gz,B1.fq,gz,..." + ), + ) + inputs.add_argument( + "-R2", + "--read2", + dest="read2_path", + required=True, + help=( + "The path of Read2 in gz format, or a comma-separated list of paths to all Read2 files in" + " gz format (E.g. A2.fq.gz,B2.fq,gz,..." + ), + ) + inputs.add_argument( + "-t", + "--tags", + dest="tags", + required=True, + help=( + "The path to the csv file containing the antibody\n" + "barcodes as well as their respective names.\n\n" + "Example of an antibody barcode file structure:\n\n" + "\tATGCGA,First_tag_name\n" + "\tGTCATG,Second_tag_name" + ), + ) + # BARCODES group. barcodes = parser.add_argument_group( - 'Barcodes', - description=("Positions of the cellular barcodes and UMI. If your " - "cellular barcodes and UMI\n are positioned as follows:\n" - "\tBarcodes from 1 to 16 and UMI from 17 to 26\n" - "then this is the input you need:\n" - "\t-cbf 1 -cbl 16 -umif 17 -umil 26") - ) - barcodes.add_argument('--chemistry', type=str, required=False, default=False) - if '--chemistry' not in sys.argv: - barcodes.add_argument('-cbf', '--cell_barcode_first_base', dest='cb_first', - required=True, type=int, - help=("Postion of the first base of your cell " - "barcodes.")) - barcodes.add_argument('-cbl', '--cell_barcode_last_base', dest='cb_last', - required=True, type=int, - help=("Postion of the last base of your cell " - "barcodes.")) - barcodes.add_argument('-umif', '--umi_first_base', dest='umi_first', - required=True, type=int, - help="Postion of the first base of your UMI.") - barcodes.add_argument('-umil', '--umi_last_base', dest='umi_last', - required=True, type=int, - help="Postion of the last base of your UMI.") - barcodes.add_argument('--umi_collapsing_dist', dest='umi_threshold', - required=False, type=int, default=2, - help="threshold for umi collapsing.") - barcodes.add_argument('--no_umi_correction', required=False, action='store_true', default=False, - dest='no_umi_correction', help="Deactivate UMI collapsing") - barcodes.add_argument('--bc_collapsing_dist', dest='bc_threshold', - required=False, type=int, default=1, - help="threshold for cellular barcode collapsing.") + "Barcodes", + description=( + "Positions of the cellular barcodes and UMI. If your " + "cellular barcodes and UMI\n are positioned as follows:\n" + "\tBarcodes from 1 to 16 and UMI from 17 to 26\n" + "then this is the input you need:\n" + "\t-cbf 1 -cbl 16 -umif 17 -umil 26" + ), + ) + barcodes.add_argument("--chemistry", type=str, required=False, default=False) + if "--chemistry" not in sys.argv: + barcodes.add_argument( + "-cbf", + "--cell_barcode_first_base", + dest="cb_first", + required=True, + type=int, + help=("Postion of the first base of your cell " "barcodes."), + ) + barcodes.add_argument( + "-cbl", + "--cell_barcode_last_base", + dest="cb_last", + required=True, + type=int, + help=("Postion of the last base of your cell " "barcodes."), + ) + barcodes.add_argument( + "-umif", + "--umi_first_base", + dest="umi_first", + required=True, + type=int, + help="Postion of the first base of your UMI.", + ) + barcodes.add_argument( + "-umil", + "--umi_last_base", + dest="umi_last", + required=True, + type=int, + help="Postion of the last base of your UMI.", + ) + barcodes.add_argument( + "--umi_collapsing_dist", + dest="umi_threshold", + required=False, + type=int, + default=2, + help="threshold for umi collapsing.", + ) + barcodes.add_argument( + "--no_umi_correction", + required=False, + action="store_true", + default=False, + dest="no_umi_correction", + help="Deactivate UMI collapsing", + ) + barcodes.add_argument( + "--bc_collapsing_dist", + dest="bc_threshold", + required=False, + type=int, + default=1, + help="threshold for cellular barcode collapsing.", + ) # Cells group cells = parser.add_argument_group( - 'Cells', - description=("Expected number of cells and potential whitelist") + "Cells", description=("Expected number of cells and potential whitelist") ) cells.add_argument( - '-cells', '--expected_cells', dest='expected_cells', required=True, type=int, - help=("Number of expected cells from your run."), default=0 + "-cells", + "--expected_cells", + dest="expected_cells", + required=True, + type=int, + help=("Number of expected cells from your run."), + default=0, ) - cells.add_argument( - '-wl', '--whitelist', dest='whitelist', required=False, type=str, - help=("A csv file containning a whitelist of barcodes produced" - " by the mRNA data.\n\n" - "\tExample:\n" - "\tATGCTAGTGCTA\n\tGCTAGTCAGGAT\n\tCGACTGCTAACG\n\n" - "Or 10X-style:\n" - "\tATGCTAGTGCTA-1\n\tGCTAGTCAGGAT-1\n\tCGACTGCTAACG-1\n") - ) - if '--chemistry' not in sys.argv: - cells.add_argument('--translation', required=False, type=str, + if "--chemistry" not in sys.argv: + cells.add_argument( + "-wl", + "--whitelist", + dest="whitelist", + required=False, + type=str, + help=( + "A csv file containning a whitelist of barcodes produced" + " by the mRNA data.\n\n" + "\tExample:\n" + "\tATGCTAGTGCTA\n\tGCTAGTCAGGAT\n\tCGACTGCTAACG\n\n" + "Or 10X-style:\n" + "\tATGCTAGTGCTA-1\n\tGCTAGTCAGGAT-1\n\tCGACTGCTAACG-1\n" + ), + ) + + cells.add_argument( + "--translation", + required=False, + type=str, help="A csv file containing the mapping between two sets of cell barcode list.\n" - "A required header such as the reference is named whitelist. Example:\n\n" - "\twhitelist,feature_barcoding_map\n" - "\tAAACCCAAGAAACACT,AAACCCATCAAACACT\n" - "\tAAACCCAAGAAACCAT,AAACCCATCAAACCAT\n" - "\nThe output matrix will possess both cell barcode IDs" - ) + "A required header such as the reference is named whitelist. Example:\n\n" + "\twhitelist,trasnlation\n" + "\tAAACCCAAGAAACACT,AAACCCATCAAACACT\n" + "\tAAACCCAAGAAACCAT,AAACCCATCAAACCAT\n" + "\nThe output matrix will possess both cell barcode IDs", + ) # FILTERS group. filters = parser.add_argument_group( - 'TAG filters', - description=("Filtering and trimming for read2.") + "TAG filters", description=("Filtering and trimming for read2.") ) filters.add_argument( - '--max-errors', dest='max_error', - required=False, type=int, default=2, - help=("Maximum Levenshtein distance allowed for antibody barcodes.") + "--max-errors", + dest="max_error", + required=False, + type=int, + default=2, + help=("Maximum Levenshtein distance allowed for antibody barcodes."), ) - + filters.add_argument( - '-trim', '--start-trim', dest='start_trim', - required=False, type=int, default=0, - help=("Number of bases to discard from read2.") + "-trim", + "--start-trim", + dest="start_trim", + required=False, + type=int, + default=0, + help=("Number of bases to discard from read2."), ) - + filters.add_argument( - '--sliding-window', dest='sliding_window', - required=False, default=False, action='store_true', - help=("Allow for a sliding window when aligning.") + "--sliding-window", + dest="sliding_window", + required=False, + default=False, + action="store_true", + help=("Allow for a sliding window when aligning."), ) - + # Parallel group. parallel = parser.add_argument_group( - 'Parallelization options', - description=("Options for performance on parallelization") + "Parallelization options", + description=("Options for performance on parallelization"), ) # Remaining arguments. - parallel.add_argument('-T', '--threads', required=False, type=int, - dest='n_threads', default=cpu_count(), - help="How many threads are to be used for running the program") - parallel.add_argument('-C', '--chunk_size', required=False, type=chunk_size_limit, - dest='chunk_size', - help="How many reads should be sent to a child process at a time") - parallel.add_argument('--temp_path', required=False, type=str, - dest='temp_path', default="", - help="Temp folder for chunk creation specification. Useful when using a cluster with a scratch folder") - + parallel.add_argument( + "-T", + "--threads", + required=False, + type=int, + dest="n_threads", + default=cpu_count(), + help="How many threads are to be used for running the program", + ) + parallel.add_argument( + "-C", + "--chunk_size", + required=False, + type=chunk_size_limit, + dest="chunk_size", + help="How many reads should be sent to a child process at a time", + ) + parallel.add_argument( + "--temp_path", + required=False, + type=str, + dest="temp_path", + default="", + help="Temp folder for chunk creation specification. Useful when using a cluster with a scratch folder", + ) # Global group - parser.add_argument('-n', '--first_n', required=False, type=int, - dest='first_n', default=float('inf'), - help="Select N reads to run on instead of all.") - parser.add_argument('-o', '--output', required=False, type=str, default='Results', - dest='outfolder', help="Results will be written to this folder") - parser.add_argument('--dense', required=False, action='store_true', default=False, - dest='dense', help="Add a dense output to the results folder") - parser.add_argument('-u', '--unmapped-tags', required=False, type=str, - dest='unmapped_file', default='unmapped.csv', - help="Write table of unknown TAGs to file.") - parser.add_argument('-ut', '--unknown-top-tags', required=False, - dest='unknowns_top', type=int, default=100, - help="Top n unmapped TAGs.") - parser.add_argument('--debug', action='store_true', - help="Print extra information for debugging.") - parser.add_argument('--version', action='version', version='CITE-seq-Count v{}'.format(version), - help="Print version number.") + parser.add_argument( + "-n", + "--first_n", + required=False, + type=int, + dest="first_n", + default=float("inf"), + help="Select N reads to run on instead of all.", + ) + parser.add_argument( + "-o", + "--output", + required=False, + type=str, + default="Results", + dest="outfolder", + help="Results will be written to this folder", + ) + parser.add_argument( + "--dense", + required=False, + action="store_true", + default=False, + dest="dense", + help="Add a dense output to the results folder", + ) + parser.add_argument( + "-u", + "--unmapped-tags", + required=False, + type=str, + dest="unmapped_file", + default="unmapped.csv", + help="Write table of unknown TAGs to file.", + ) + parser.add_argument( + "-ut", + "--unknown-top-tags", + required=False, + dest="unknowns_top", + type=int, + default=100, + help="Top n unmapped TAGs.", + ) + parser.add_argument( + "--debug", action="store_true", help="Print extra information for debugging." + ) + parser.add_argument( + "--version", + action="version", + version="CITE-seq-Count v{}".format(version), + help="Print version number.", + ) # Finally! Too many options XD return parser -def create_report(total_reads, reads_per_cell, no_match, version, start_time, ordered_tags_map, umis_corrected, bcs_corrected, bad_cells, R1_too_short, R2_too_short, args, chemistry_def): +def create_report( + total_reads, + reads_per_cell, + no_match, + version, + start_time, + ordered_tags_map, + umis_corrected, + bcs_corrected, + bad_cells, + R1_too_short, + R2_too_short, + args, + chemistry_def, +): """ Creates a report with details about the run in a yaml format. Args: @@ -202,13 +353,13 @@ def create_report(total_reads, reads_per_cell, no_match, version, start_time, or total_unmapped = sum(no_match.values()) total_mapped = total_reads - total_unmapped total_too_short = total_reads - total_unmapped - total_mapped - too_short_perc = round((total_too_short/total_reads)*100) - mapped_perc = round((total_mapped/total_reads)*100) - unmapped_perc = round((total_unmapped/total_reads)*100) - - with open(os.path.join(args.outfolder, 'run_report.yaml'), 'w') as report_file: + too_short_perc = round((total_too_short / total_reads) * 100) + mapped_perc = round((total_mapped / total_reads) * 100) + unmapped_perc = round((total_unmapped / total_reads) * 100) + + with open(os.path.join(args.outfolder, "run_report.yaml"), "w") as report_file: report_file.write( -"""Date: {} + """Date: {} Running time: {} CITE-seq-Count Version: {} Reads processed: {} @@ -236,37 +387,42 @@ def create_report(total_reads, reads_per_cell, no_match, version, start_time, or \tTags max errors: {} \tStart trim: {} """.format( - datetime.datetime.today().strftime('%Y-%m-%d'), - secondsToText.secondsToText(time.time()-start_time), - version, - int(total_reads), - mapped_perc, - unmapped_perc, - too_short_perc, - R1_too_short, - R2_too_short, - len(bad_cells), - args.bc_threshold, - bcs_corrected, - args.umi_threshold, - umis_corrected, - args.read1_path, - args.read2_path, - args.cb_first, - args.cb_last, - args.umi_first, - chemistry_def.umi_end, - args.expected_cells, - args.max_error, - chemistry_def.R2_start_trim)) + datetime.datetime.today().strftime("%Y-%m-%d"), + secondsToText.secondsToText(time.time() - start_time), + version, + int(total_reads), + mapped_perc, + unmapped_perc, + too_short_perc, + R1_too_short, + R2_too_short, + len(bad_cells), + args.bc_threshold, + bcs_corrected, + args.umi_threshold, + umis_corrected, + args.read1_path, + args.read2_path, + args.cb_first, + args.cb_last, + args.umi_first, + chemistry_def.umi_barcode_end, + args.expected_cells, + args.max_error, + chemistry_def.R2_trim_start, + ) + ) + def main(): - #Create logger and stream handler - logger = logging.getLogger('cite_seq_count') + # Create logger and stream handler + logger = logging.getLogger("cite_seq_count") logger.setLevel(logging.CRITICAL) ch = logging.StreamHandler() ch.setLevel(logging.CRITICAL) - formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s') + formatter = logging.Formatter( + "%(asctime)s - %(name)s - %(levelname)s - %(message)s" + ) ch.setFormatter(formatter) logger.addHandler(ch) @@ -280,57 +436,71 @@ def main(): args = parser.parse_args() temp_path = os.path.abspath(args.temp_path) assert os.access(temp_path, os.W_OK) - if args.whitelist: - print('Loading whitelist') - (whitelist, args.bc_threshold) = preprocessing.parse_whitelist_csv( - filename=args.whitelist, - barcode_length=args.cb_last - args.cb_first + 1, - collapsing_threshold=args.bc_threshold) - else: - whitelist = False # Get chemistry defs if args.chemistry: - chemistry_def = database.get_chemistry_definition(args.chemistry) + with requests.get(chemistry_def.whitelist_path) as r: + r.raise_for_status() + ###### TODO deal with the download of the whitelist from remote + with NamedTemporaryFile() as temp_local_whitelist: + temp_local_whitelist.write(r.content) + (whitelist, args.bc_threshold) = preprocessing.parse_whitelist_csv( + filename=temp_local_whitelist, + barcode_length=chemistry_def.cell_barcode_end + - chemistry_def.cell_barcode_start + + 1, + collapsing_threshold=args.bc_threshold, + ) else: chemistry_def = database.create_chemistry_definition(args) + if args.whitelist: + print("Loading whitelist") + (whitelist, args.bc_threshold) = preprocessing.parse_whitelist_csv( + filename=args.whitelist, + barcode_length=args.cb_last - args.cb_first + 1, + collapsing_threshold=args.bc_threshold, + ) + else: + whitelist = False + # Load TAGs/ABs. ab_map = preprocessing.parse_tags_csv(args.tags) ordered_tags_map, longest_tag_len = preprocessing.check_tags(ab_map, args.max_error) - named_tuples_tags_map = preprocessing.convert_to_named_tuple(ordered_tags=ordered_tags_map) + named_tuples_tags_map = preprocessing.convert_to_named_tuple( + ordered_tags=ordered_tags_map + ) # Identify input file(s) - read1_paths, read2_paths = preprocessing.get_read_paths(args.read1_path, args.read2_path) + read1_paths, read2_paths = preprocessing.get_read_paths( + args.read1_path, args.read2_path + ) # preprocessing and processing occur in separate loops so the program can crash earlier if # one of the inputs is not valid. read1_lengths = [] read2_lengths = [] total_reads = 0 - + for read1_path, read2_path in zip(read1_paths, read2_paths): n_lines = preprocessing.get_n_lines(read1_path) - total_reads += n_lines/4 + total_reads += n_lines / 4 # Get reads length. So far, there is no validation for Read2. read1_lengths.append(preprocessing.get_read_length(read1_path)) read2_lengths.append(preprocessing.get_read_length(read2_path)) # Check Read1 length against CELL and UMI barcodes length. - ( - barcode_slice, - umi_slice, - _ - ) = preprocessing.check_barcodes_lengths( - read1_lengths[-1], - chemistry_def.barcode_start, - chemistry_def.barcode_end, - chemistry_def.umi_start, - chemistry_def.umi_end) - + (barcode_slice, umi_slice, _) = preprocessing.check_barcodes_lengths( + read1_lengths[-1], + chemistry_def.cell_barcode_start, + chemistry_def.cell_barcode_end, + chemistry_def.umi_barcode_start, + chemistry_def.umi_barcode_end, + ) + # Ensure all files have the same input length - #if len(set(read1_lengths)) != 1: - #sys.exit('Input barcode fastqs (read1) do not all have same length.\nExiting') - #if len(set(read2_lengths)) != 1: - #sys.exit('Input barcode fastqs (read2) do not all have same length.\nExiting') + # if len(set(read1_lengths)) != 1: + # sys.exit('Input barcode fastqs (read1) do not all have same length.\nExiting') + # if len(set(read2_lengths)) != 1: + # sys.exit('Input barcode fastqs (read2) do not all have same length.\nExiting') # Define R2_lenght to reduce amount of data to transfer to childrens if args.sliding_window: @@ -343,76 +513,101 @@ def main(): reads_per_cell = Counter() merged_no_match = Counter() number_of_samples = len(read1_paths) - - #Print a statement if multiple files are run. + + # Print a statement if multiple files are run. if number_of_samples != 1: - print('Detected {} files to run on.'.format(number_of_samples)) + print("Detected {} files to run on.".format(number_of_samples)) input_queue = [] - mapping_input = namedtuple('mapping_input', ['filename', 'tags', 'debug', 'maximum_distance', 'sliding_window']) + mapping_input = namedtuple( + "mapping_input", + ["filename", "tags", "debug", "maximum_distance", "sliding_window"], + ) - print('Writing chunks to disk') + print("Writing chunks to disk") reads_count = 0 num_chunks = 0 if args.chunk_size: chunk_size = args.chunk_size else: - chunk_size = round(total_reads/args.n_threads) + 1 + chunk_size = round(total_reads / args.n_threads) + 1 temp_files = [] R1_too_short = 0 R2_too_short = 0 for read1_path, read2_path in zip(read1_paths, read2_paths): - print('Reading reads from files: {}, {}'.format(read1_path, read2_path)) - with gzip.open(read1_path, 'rt') as textfile1, \ - gzip.open(read2_path, 'rt') as textfile2: + print("Reading reads from files: {}, {}".format(read1_path, read2_path)) + with gzip.open(read1_path, "rt") as textfile1, gzip.open( + read2_path, "rt" + ) as textfile2: secondlines = islice(zip(textfile1, textfile2), 1, None, 4) - temp_filename = os.path.join(temp_path, 'temp_{}'.format(num_chunks)) - chunked_file_object = open(temp_filename, 'w') + temp_filename = os.path.join(temp_path, "temp_{}".format(num_chunks)) + chunked_file_object = open(temp_filename, "w") temp_files.append(os.path.abspath(temp_filename)) for read1, read2 in secondlines: - + read1 = read1.strip() - if len(read1) < chemistry_def.umi_end: + if len(read1) < chemistry_def.umi_barcode_end: R1_too_short += 1 # The entire read is skipped continue - read1_sliced = read1[0:chemistry_def.umi_end] + read1_sliced = read1[0 : chemistry_def.umi_barcode_end] if len(read2) < R2_max_length: R2_too_short += 1 # The entire read is skipped continue - - read2_sliced = read2[chemistry_def.R2_start_trim:(R2_max_length + chemistry_def.R2_start_trim)] - chunked_file_object.write('{},{},{}\n'.format(read1_sliced[barcode_slice], read1_sliced[umi_slice], read2_sliced)) + + read2_sliced = read2[ + chemistry_def.R2_trim_start : ( + R2_max_length + chemistry_def.R2_trim_start + ) + ] + chunked_file_object.write( + "{},{},{}\n".format( + read1_sliced[barcode_slice], + read1_sliced[umi_slice], + read2_sliced, + ) + ) reads_count += 1 if reads_count % chunk_size == 0: - input_queue.append(mapping_input( - filename=temp_filename, - tags=named_tuples_tags_map, - debug=args.debug, - maximum_distance=args.max_error, - sliding_window=args.sliding_window)) - num_chunks +=1 + input_queue.append( + mapping_input( + filename=temp_filename, + tags=named_tuples_tags_map, + debug=args.debug, + maximum_distance=args.max_error, + sliding_window=args.sliding_window, + ) + ) + num_chunks += 1 chunked_file_object.close() - temp_filename = 'temp_{}'.format(num_chunks) - chunked_file_object = open(temp_filename, 'w') + temp_filename = "temp_{}".format(num_chunks) + chunked_file_object = open(temp_filename, "w") temp_files.append(os.path.abspath(temp_filename)) if reads_count >= args.first_n: total_reads = args.first_n break - - input_queue.append(mapping_input( - filename=temp_filename, - tags=named_tuples_tags_map, - debug=args.debug, - maximum_distance=args.max_error, - sliding_window=args.sliding_window)) + + input_queue.append( + mapping_input( + filename=temp_filename, + tags=named_tuples_tags_map, + debug=args.debug, + maximum_distance=args.max_error, + sliding_window=args.sliding_window, + ) + ) chunked_file_object.close() - - print('Started mapping') + + print("Started mapping") parallel_results = [] pool = Pool(processes=args.n_threads) errors = [] - mapping = pool.map_async(processing.map_reads, input_queue, callback=parallel_results.append, error_callback=errors.append) + mapping = pool.map_async( + processing.map_reads, + input_queue, + callback=parallel_results.append, + error_callback=errors.append, + ) mapping.wait() pool.close() pool.join() @@ -420,58 +615,68 @@ def main(): for error in errors: print(error) - print('Merging results') + print("Merging results") ( final_results, umis_per_cell, reads_per_cell, - merged_no_match + merged_no_match, ) = processing.merge_results(parallel_results=parallel_results[0]) - - del(parallel_results) - - # Check if 99% of the reads are unmapped. - processing.check_unmapped(no_match=merged_no_match, total_reads=total_reads, start_trim=chemistry_def.R2_start_trim) + del parallel_results + + # Check if 99% of the reads are unmapped. + processing.check_unmapped( + no_match=merged_no_match, + total_reads=total_reads, + start_trim=chemistry_def.R2_trim_start, + ) # Delete temp_files for file_path in temp_files: - os.remove(file_path) - + if os.path.exists(file_path): + os.remove(file_path) + else: + print("Could not find file: {}".format(file_path)) # Correct cell barcodes if args.bc_threshold != 0: - if(len(umis_per_cell) <= args.expected_cells): - print("Number of expected cells, {}, is higher " \ - "than number of cells found {}.\nNot performing " \ - "cell barcode correction" \ - "".format(args.expected_cells, len(umis_per_cell))) + if len(umis_per_cell) <= args.expected_cells: + print( + "Number of expected cells, {}, is higher " + "than number of cells found {}.\nNot performing " + "cell barcode correction" + "".format(args.expected_cells, len(umis_per_cell)) + ) bcs_corrected = 0 else: - print('Correcting cell barcodes') + print("Correcting cell barcodes") if not whitelist: ( final_results, umis_per_cell, - bcs_corrected + bcs_corrected, ) = processing.correct_cells( - final_results=final_results, - reads_per_cell=reads_per_cell, - umis_per_cell=umis_per_cell, - expected_cells=args.expected_cells, - collapsing_threshold=args.bc_threshold, - ab_map=named_tuples_tags_map) + final_results=final_results, + reads_per_cell=reads_per_cell, + umis_per_cell=umis_per_cell, + expected_cells=args.expected_cells, + collapsing_threshold=args.bc_threshold, + ab_map=named_tuples_tags_map, + ) else: ( final_results, umis_per_cell, - bcs_corrected) = processing.correct_cells_whitelist( - final_results=final_results, - umis_per_cell=umis_per_cell, - whitelist=whitelist, - collapsing_threshold=args.bc_threshold, - ab_map=named_tuples_tags_map) + bcs_corrected, + ) = processing.correct_cells_whitelist( + final_results=final_results, + umis_per_cell=umis_per_cell, + whitelist=whitelist, + collapsing_threshold=args.bc_threshold, + ab_map=named_tuples_tags_map, + ) else: - print('Skipping cell barcode correction') + print("Skipping cell barcode correction") bcs_corrected = 0 # If given, use whitelist for top cells @@ -491,55 +696,64 @@ def main(): top_cells_tuple = umis_per_cell.most_common(args.expected_cells) top_cells = set([pair[0] for pair in top_cells_tuple]) - - #UMI correction + # UMI correction if args.no_umi_correction: - #Don't correct + # Don't correct umis_corrected = 0 aberrant_cells = set() else: - #Correct UMIS + # Correct UMIS input_queue = [] - - umi_correction_input = namedtuple('umi_correction_input', ['cells','collapsing_threshold','max_umis']) + + umi_correction_input = namedtuple( + "umi_correction_input", ["cells", "collapsing_threshold", "max_umis"] + ) cells = {} n_cells = 0 num_chunks = 0 - cell_batch_size = round(len(top_cells)/args.n_threads) + 1 + cell_batch_size = round(len(top_cells) / args.n_threads) + 1 for cell in top_cells: cells[cell] = final_results[cell] n_cells += 1 if n_cells % cell_batch_size == 0: - input_queue.append(umi_correction_input( - cells=cells, - collapsing_threshold=args.umi_threshold, - max_umis=20000)) + input_queue.append( + umi_correction_input( + cells=cells, + collapsing_threshold=args.umi_threshold, + max_umis=20000, + ) + ) cells = {} num_chunks += 1 - input_queue.append(umi_correction_input( - cells=cells, - collapsing_threshold=args.umi_threshold, - max_umis=20000)) - + input_queue.append( + umi_correction_input( + cells=cells, collapsing_threshold=args.umi_threshold, max_umis=20000 + ) + ) + pool = Pool(processes=args.n_threads) errors = [] parallel_results = [] - correct_umis = pool.map_async(processing.correct_umis, input_queue, callback=parallel_results.append, error_callback=errors.append) - + correct_umis = pool.map_async( + processing.correct_umis, + input_queue, + callback=parallel_results.append, + error_callback=errors.append, + ) + correct_umis.wait() pool.close() pool.join() - + if len(errors) != 0: for error in errors: print(error) - - + final_results = {} umis_corrected = 0 aberrant_cells = set() - + for chunk in parallel_results[0]: (temp_results, temp_umis, temp_aberrant_cells) = chunk final_results.update(temp_results) @@ -547,60 +761,60 @@ def main(): aberrant_cells.update(temp_aberrant_cells) if len(aberrant_cells) > 0: - #Remove aberrant cells from the top cells + # Remove aberrant cells from the top cells for cell_barcode in aberrant_cells: top_cells.remove(cell_barcode) - #Create sparse aberrant cells matrix - ( - umi_aberrant_matrix, - _ - ) = processing.generate_sparse_matrices( + # Create sparse aberrant cells matrix + (umi_aberrant_matrix, _) = processing.generate_sparse_matrices( final_results=final_results, ordered_tags_map=ordered_tags_map, - top_cells=aberrant_cells) - - #Write uncorrected cells to dense output + top_cells=aberrant_cells, + ) + + # Write uncorrected cells to dense output io.write_dense( - sparse_matrix=umi_aberrant_matrix, - index=list(ordered_tags_map.keys()), - columns=aberrant_cells, - outfolder=os.path.join(args.outfolder,'uncorrected_cells'), - filename='dense_umis.tsv') - - #Create sparse matrices for results - ( - umi_results_matrix, - read_results_matrix - ) = processing.generate_sparse_matrices( + sparse_matrix=umi_aberrant_matrix, + index=list(ordered_tags_map.keys()), + columns=aberrant_cells, + outfolder=os.path.join(args.outfolder, "uncorrected_cells"), + filename="dense_umis.tsv", + ) + + # Create sparse matrices for results + (umi_results_matrix, read_results_matrix) = processing.generate_sparse_matrices( final_results=final_results, ordered_tags_map=ordered_tags_map, - top_cells=top_cells) - + top_cells=top_cells, + ) + # Write umis to file io.write_to_files( sparse_matrix=umi_results_matrix, top_cells=top_cells, ordered_tags_map=ordered_tags_map, - data_type='umi', - outfolder=args.outfolder) - + data_type="umi", + outfolder=args.outfolder, + ) + # Write reads to file io.write_to_files( sparse_matrix=read_results_matrix, top_cells=top_cells, ordered_tags_map=ordered_tags_map, - data_type='read', - outfolder=args.outfolder) - - #Write unmapped sequences + data_type="read", + outfolder=args.outfolder, + ) + + # Write unmapped sequences io.write_unmapped( merged_no_match=merged_no_match, top_unknowns=args.unknowns_top, outfolder=args.outfolder, - filename=args.unmapped_file) - - #Create report and write it to disk + filename=args.unmapped_file, + ) + + # Create report and write it to disk create_report( total_reads=total_reads, reads_per_cell=reads_per_cell, @@ -614,17 +828,20 @@ def main(): R1_too_short=R1_too_short, R2_too_short=R2_too_short, args=args, - chemistry_def=chemistry_def) - - #Write dense matrix to disk if requested + chemistry_def=chemistry_def, + ) + + # Write dense matrix to disk if requested if args.dense: - print('Writing dense format output') + print("Writing dense format output") io.write_dense( sparse_matrix=umi_results_matrix, index=list(ordered_tags_map.keys()), columns=top_cells, outfolder=args.outfolder, - filename='dense_umis.tsv') + filename="dense_umis.tsv", + ) + -if __name__ == '__main__': +if __name__ == "__main__": main() diff --git a/cite_seq_count/database.py b/cite_seq_count/database.py index b373a6c..f66b5bd 100644 --- a/cite_seq_count/database.py +++ b/cite_seq_count/database.py @@ -1,42 +1,88 @@ """This module is holding code for all the remote fetching on the chemistries database.""" import requests import sys +import os from collections import namedtuple +from tempfile import TemporaryFile +from dataclasses import dataclass +GLOBAL_LINK_RAW = "https://raw.githubusercontent.com/Hoohm/scg_lib_structs/10xv3_totalseq_b/chemistries/" +GLOBAL_LINK_GITHUB = "https://github.com/Hoohm/scg_lib_structs/raw/10xv3_totalseq_b/" +GLOBAL_LINK_GITHUB_IO = "https://teichlab.github.io/scg_lib_structs" +CHEMISTRY_DEFINITIONS = os.path.join(GLOBAL_LINK_RAW, "definitions.json") -CHEMISTRY_DEFINITIONS = "https://raw.githubusercontent.com/Hoohm/scg_lib_structs/10xv3_totalseq_b/chemistries/definitions.json" -CHEMISTRY_DEFINITION = namedtuple('chemistry_def', ['barcode_start','barcode_end','umi_start','umi_end', 'R2_trim_start']) +@dataclass +class Chemistry: + name: str + cell_barcode_start: int + cell_barcode_end: int + umi_barcode_start: int + umi_barcode_end: int + R2_trim_start: int + whitelist_path: str + mapping_required: bool + + +def list_chemistries(url=CHEMISTRY_DEFINITIONS): + print("Loading remote file from: {}".format(url)) + with requests.get(url) as r: + r.raise_for_status() + all_chemistry_defs = r.json() + print( + "Here are all the possible chemistries available at {}".format( + GLOBAL_LINK_GITHUB_IO + ) + ) + for chemistry in all_chemistry_defs: + print( + "\n-- {}\n shortname: {}\n Protocol link: {}\n\n".format( + all_chemistry_defs[chemistry]["Description"], + chemistry, + os.path.join( + GLOBAL_LINK_GITHUB_IO, all_chemistry_defs[chemistry]["html"] + ), + ) + ) + def get_chemistry_definition(chemistry_short_name, url=CHEMISTRY_DEFINITIONS): """ Fetches chemistry definitions from a remote definitions.json and returns the json. """ - print('Loading remote file from: {}'.format(url)) + print("Loading remote file from: {}".format(url)) with requests.get(url) as r: r.raise_for_status() chemistry_defs = r.json().get(chemistry_short_name, False) if not chemistry_defs: - sys.exit('Could not find the chemistry: {}. Please check that it does exist at: {}\nExiting'.format(chemistry_short_name, url)) - - chemistry_def = CHEMISTRY_DEFINITION( - barcode_start=chemistry_defs['barcode_structure_indexes']['cell_barcode']['R1']['start'], - barcode_end=chemistry_defs['barcode_structure_indexes']['cell_barcode']['R1']['stop'], - umi_start=chemistry_defs['barcode_structure_indexes']['umi_barcode']['R1']['start'], - umi_end=chemistry_defs['barcode_structure_indexes']['umi_barcode']['R1']['stop'], - R2_trim_start=chemistry_defs['sequence_structure_indexes']['R2']['start'] - 1 + sys.exit( + "Could not find the chemistry: {}. Please check that it does exist at: {}\nExiting".format( + chemistry_short_name, url + ) + ) + chemistry_def = Chemistry( + name=chemistry_short_name, + cell_barcode_start=chemistry_defs["barcode_structure_indexes"]["cell_barcode"][ + "R1" + ]["start"], + cell_barcode_end=chemistry_defs["barcode_structure_indexes"]["cell_barcode"][ + "R1" + ]["stop"], + umi_barcode_start=chemistry_defs["barcode_structure_indexes"]["umi_barcode"][ + "R1" + ]["start"], + umi_barcode_end=chemistry_defs["barcode_structure_indexes"]["umi_barcode"][ + "R1" + ]["stop"], + R2_trim_start=chemistry_defs["sequence_structure_indexes"]["R2"]["start"] - 1, + whitelist_path=os.path.join( + GLOBAL_LINK_GITHUB, "chemistries", chemistry_defs["whitelist"]["path"] + ), + mapping_required=chemistry_defs["whitelist"]["mapping"], ) - return(chemistry_def) + return chemistry_def + def create_chemistry_definition(args): - """ - """ - chemistry_def = CHEMISTRY_DEFINITION( - barcode_start=args.cbf, - barcode_end=args.cbl, - umi_start=args.umif, - umi_end=args.umil, - R2_trim_start=args.start_trim - ) - return(chemistry_def) \ No newline at end of file + return chemistry_def diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 87c2946..ba7626d 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -3,13 +3,14 @@ import sys import regex import Levenshtein +import requests from math import floor from collections import OrderedDict from collections import namedtuple from itertools import combinations -from itertools import islice - +from itertools import islice + def parse_whitelist_csv(filename, barcode_length, collapsing_threshold): """Reads white-listed barcodes from a CSV file. @@ -28,18 +29,37 @@ def parse_whitelist_csv(filename, barcode_length, collapsing_threshold): """ STRIP_CHARS = '"0123456789- \t\n' - with open(filename, mode='r') as csv_file: - csv_reader = csv.reader(csv_file) - cell_pattern = regex.compile(r'[ATGC]{{{}}}'.format(barcode_length)) - whitelist = [row[0].strip(STRIP_CHARS) for row in csv_reader - if (len(row[0].strip(STRIP_CHARS)) == barcode_length)] + cell_pattern = regex.compile(r"[ATGC]{{{}}}".format(barcode_length)) + if filename.endswith("gz"): + with gzip.open(filename.name, mode="r") as csv_file: + csv_reader = csv.reader(csv_file) + whitelist = [ + row[0].strip(STRIP_CHARS) + for row in csv_reader + if (len(row[0].strip(STRIP_CHARS)) == barcode_length) + ] + else: + with open(filename, mode="r") as csv_file: + csv_reader = csv.reader(csv_file) + whitelist = [ + row[0].strip(STRIP_CHARS) + for row in csv_reader + if (len(row[0].strip(STRIP_CHARS)) == barcode_length) + ] + for cell_barcode in whitelist: if not cell_pattern.match(cell_barcode): - sys.exit('This barcode {} is not only composed of ATGC bases.'.format(cell_barcode)) - #collapsing_threshold=test_cell_distances(whitelist, collapsing_threshold) + sys.exit( + "This barcode {} is not only composed of ATGC bases.".format( + cell_barcode + ) + ) + # collapsing_threshold=test_cell_distances(whitelist, collapsing_threshold) if len(whitelist) == 0: - sys.exit('Please check cell barcode indexes -cbs, -cbl because none of the given whitelist is valid.') - return(set(whitelist), collapsing_threshold) + sys.exit( + "Please check cell barcode indexes -cbs, -cbl because none of the given whitelist is valid." + ) + return (set(whitelist), collapsing_threshold) def test_cell_distances(whitelist, collapsing_threshold): @@ -57,17 +77,21 @@ def test_cell_distances(whitelist, collapsing_threshold): """ ok = False while not ok: - print('Testing cell barcode collapsing threshold of {}'.format(collapsing_threshold)) + print( + "Testing cell barcode collapsing threshold of {}".format( + collapsing_threshold + ) + ) all_comb = combinations(whitelist, 2) for comb in all_comb: if Levenshtein.hamming(comb[0], comb[1]) <= collapsing_threshold: collapsing_threshold -= 1 - print('Value is too high, reducing it by 1') + print("Value is too high, reducing it by 1") break else: ok = True - print('Using {} for cell barcode collapsing threshold'.format(collapsing_threshold)) - return(collapsing_threshold) + print("Using {} for cell barcode collapsing threshold".format(collapsing_threshold)) + return collapsing_threshold def parse_tags_csv(filename): @@ -86,12 +110,12 @@ def parse_tags_csv(filename): dict: A dictionary containing the TAGs and their names. """ - with open(filename, mode='r') as csv_file: + with open(filename, mode="r") as csv_file: csv_reader = csv.reader(csv_file) tags = {} for row in csv_reader: tags[row[0].strip()] = row[1].strip() - return(tags) + return tags def check_tags(tags, maximum_distance): @@ -119,58 +143,73 @@ def check_tags(tags, maximum_distance): longest_tag_len = 0 for i, tag_seq in enumerate(sorted(tags, key=len, reverse=True)): ordered_tags[tags[tag_seq]] = {} - ordered_tags[tags[tag_seq]]['id'] = i - ordered_tags[tags[tag_seq]]['sequence'] = tag_seq + ordered_tags[tags[tag_seq]]["id"] = i + ordered_tags[tags[tag_seq]]["sequence"] = tag_seq if len(tag_seq) > longest_tag_len: longest_tag_len = len(tag_seq) - - ordered_tags['unmapped'] = {} - ordered_tags['unmapped']['id'] = i + 1 - ordered_tags['unmapped']['sequence'] = 'UNKNOWN' + + ordered_tags["unmapped"] = {} + ordered_tags["unmapped"]["id"] = i + 1 + ordered_tags["unmapped"]["sequence"] = "UNKNOWN" # If only one TAG is provided, then no distances to compare. - if (len(tags) == 1): - ordered_tags['unmapped'] = {} - ordered_tags['unmapped']['id'] = 2 - return(ordered_tags, longest_tag_len) - + if len(tags) == 1: + ordered_tags["unmapped"] = {} + ordered_tags["unmapped"]["id"] = 2 + return (ordered_tags, longest_tag_len) + offending_pairs = [] for a, b in combinations(tags.keys(), 2): distance = Levenshtein.distance(a, b) - if (distance <= (maximum_distance - 1)): + if distance <= (maximum_distance - 1): offending_pairs.append([a, b, distance]) - DNA_pattern = regex.compile('^[ATGC]*$') + DNA_pattern = regex.compile("^[ATGC]*$") for tag in tags: if not DNA_pattern.match(tag): - print('This tag {} is not only composed of ATGC bases.\nPlease check your tags file'.format(tag)) - sys.exit('Exiting the application.\n') + print( + "This tag {} is not only composed of ATGC bases.\nPlease check your tags file".format( + tag + ) + ) + sys.exit("Exiting the application.\n") # If offending pairs are found, print them all. if offending_pairs: print( - '[ERROR] Minimum Levenshtein distance of TAGs barcode is less ' - 'than given threshold.\n' - 'Please use a smaller distance.\n\n' - 'Offending case(s):\n' + "[ERROR] Minimum Levenshtein distance of TAGs barcode is less " + "than given threshold.\n" + "Please use a smaller distance.\n\n" + "Offending case(s):\n" ) for pair in offending_pairs: print( - '\t{tag1}\n\t{tag2}\n\tDistance = {distance}\n' - .format(tag1=pair[0], tag2=pair[1], distance=pair[2]) + "\t{tag1}\n\t{tag2}\n\tDistance = {distance}\n".format( + tag1=pair[0], tag2=pair[1], distance=pair[2] + ) ) - sys.exit('Exiting the application.\n') + sys.exit("Exiting the application.\n") + + return (ordered_tags, longest_tag_len) - return(ordered_tags, longest_tag_len) def sanitize_name(string): - return(string.replace('-', '_')) + return string.replace("-", "_") + def convert_to_named_tuple(ordered_tags): - #all_tags = namedtuple('all_tags', [sanitize_name(tag) for tag in ordered_tags.keys()]) - tag = namedtuple('tag', ['safe_name','name','sequence', 'id']) + # all_tags = namedtuple('all_tags', [sanitize_name(tag) for tag in ordered_tags.keys()]) + tag = namedtuple("tag", ["safe_name", "name", "sequence", "id"]) tag_list = [] for index, tag_name in enumerate(ordered_tags): - tag_list.append(tag(safe_name=sanitize_name(tag_name), name=tag_name, sequence=ordered_tags[tag_name]['sequence'], id=(index))) - #all_tags[index+1]=ordered_tags[tag_name]['sequence'] - return(tag_list) + tag_list.append( + tag( + safe_name=sanitize_name(tag_name), + name=tag_name, + sequence=ordered_tags[tag_name]["sequence"], + id=(index), + ) + ) + # all_tags[index+1]=ordered_tags[tag_name]['sequence'] + return tag_list + def get_read_length(filename): """Check wether SEQUENCE lengths are consistent in a FASTQ file and return @@ -183,9 +222,9 @@ def get_read_length(filename): int: The file's SEQUENCE length. """ - with gzip.open(filename, 'r') as fastq_file: + with gzip.open(filename, "r") as fastq_file: secondlines = islice(fastq_file, 1, 1000, 4) - #temp_length = len(next(secondlines).rstrip()) + # temp_length = len(next(secondlines).rstrip()) for sequence in secondlines: read_length = len(sequence.rstrip()) # if (temp_length != read_length): @@ -194,12 +233,13 @@ def get_read_length(filename): # 'sequences at the same length.\n' # 'Exiting the application.\n'.format(filename) # ) - return(read_length) + return read_length def get_chunk_strategy(read1_paths, read2_paths, chunk_size): pass + def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last): """Check Read1 length against CELL and UMI barcodes length. @@ -224,19 +264,20 @@ def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last) if barcode_umi_length > read1_length: sys.exit( - '[ERROR] Read1 length is shorter than the option you are using for ' - 'Cell and UMI barcodes length. Please, check your options and rerun.\n\n' - 'Exiting the application.\n' + "[ERROR] Read1 length is shorter than the option you are using for " + "Cell and UMI barcodes length. Please, check your options and rerun.\n\n" + "Exiting the application.\n" ) elif barcode_umi_length < read1_length: print( - '[WARNING] Read1 length is {}bp but you are using {}bp for Cell ' - 'and UMI barcodes combined.\nThis might lead to wrong cell ' - 'attribution and skewed umi counts.\n' - .format(read1_length, barcode_umi_length) + "[WARNING] Read1 length is {}bp but you are using {}bp for Cell " + "and UMI barcodes combined.\nThis might lead to wrong cell " + "attribution and skewed umi counts.\n".format( + read1_length, barcode_umi_length + ) ) - - return(barcode_slice, umi_slice, barcode_umi_length) + + return (barcode_slice, umi_slice, barcode_umi_length) def blocks(files, size=65536): @@ -253,7 +294,8 @@ def blocks(files, size=65536): """ while True: b = files.read(size) - if not b: break + if not b: + break yield b @@ -269,13 +311,15 @@ def get_n_lines(file_path): Returns: n_lines (int): Number of lines in the file """ - print('Counting number of reads') - with gzip.open(file_path, "rt",encoding="utf-8",errors='ignore') as f: + print("Counting number of reads") + with gzip.open(file_path, "rt", encoding="utf-8", errors="ignore") as f: n_lines = sum(bl.count("\n") for bl in blocks(f)) - if n_lines %4 !=0: - sys.exit('{}\'s number of lines is not a multiple of 4. The file ' - 'might be corrupted.\n Exiting'.format(file_path)) - return(n_lines) + if n_lines % 4 != 0: + sys.exit( + "{}'s number of lines is not a multiple of 4. The file " + "might be corrupted.\n Exiting".format(file_path) + ) + return n_lines def get_read_paths(read1_path, read2_path): @@ -290,9 +334,11 @@ def get_read_paths(read1_path, read2_path): _read1_path (list(string)): list of paths to read1.fq _read2_path (list(string)): list of paths to read2.fq """ - _read1_path = read1_path.split(',') - _read2_path = read2_path.split(',') - if len(read1_path) != len(read2_path): - sys.exit('Unequal number of read1 ({}) and read2({}) files provided' - '\n Exiting'.format(len(read1_path),len(read2_path))) - return(_read1_path, _read2_path) + _read1_path = read1_path.split(",") + _read2_path = read2_path.split(",") + if len(_read1_path) != len(_read2_path): + sys.exit( + "Unequal number of read1 ({}) and read2({}) files provided" + "\n Exiting".format(len(read1_path), len(read2_path)) + ) + return (_read1_path, _read2_path) diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index b7bb7e5..f5cef29 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -40,18 +40,18 @@ def find_best_match(TAG_seq, tags, maximum_distance): Returns: best_match (string): The TAG name that will be used for counting. """ - best_match = 'unmapped' + best_match = "unmapped" best_score = maximum_distance for tag in tags: - score = Levenshtein.hamming(tag.sequence, TAG_seq[:len(tag.sequence)]) + score = Levenshtein.hamming(tag.sequence, TAG_seq[: len(tag.sequence)]) if score == 0: - #Best possible match - return(tag.name) + # Best possible match + return tag.name elif score <= best_score: best_score = score best_match = tag.name - return(best_match) - return(best_match) + return best_match + return best_match def find_best_match_shift(TAG_seq, tags): @@ -71,11 +71,11 @@ def find_best_match_shift(TAG_seq, tags): Returns: best_match (string): The TAG name that will be used for counting. """ - best_match = 'unmapped' + best_match = "unmapped" for tag in tags: if tag.sequence in TAG_seq: - return(tag.name) - return(best_match) + return tag.name + return best_match def map_reads(mapping_input): @@ -104,65 +104,67 @@ def map_reads(mapping_input): """ # Initiate values (filename, tags, debug, maximum_distance, sliding_window) = mapping_input - print('Started mapping in child process {}'.format(os.getpid())) + print("Started mapping in child process {}".format(os.getpid())) results = {} no_match = Counter() n = 1 t = time.time() - + # Progress info - with open(filename, 'r') as input_file: + with open(filename, "r") as input_file: reads = csv.reader(input_file) for read in reads: cell_barcode = read[0] # This change in bytes is required by umi_tools for umi correction - UMI = bytes(read[1], 'ascii') + UMI = bytes(read[1], "ascii") read2 = read[2] if n % 1000000 == 0: - print("Processed 1,000,000 reads in {}. Total " + print( + "Processed 1,000,000 reads in {}. Total " "reads: {:,} in child {}".format( - secondsToText.secondsToText(time.time() - t), - n, - os.getpid()) + secondsToText.secondsToText(time.time() - t), n, os.getpid() ) + ) sys.stdout.flush() t = time.time() - if cell_barcode not in results: results[cell_barcode] = defaultdict(Counter) - - if(sliding_window): + + if sliding_window: best_match = find_best_match_shift(read2, tags) else: best_match = find_best_match(read2, tags, maximum_distance) - + results[cell_barcode][best_match][UMI] += 1 - - if(best_match == 'unmapped'): - no_match[read2] += 1 - + + if best_match == "unmapped": + no_match[read2] += 1 + if debug: print( "cell_barcode:{0}\tUMI:{1}\tTAG_seq:{2}\n" "cell barcode length:{3}\tUMI length:{4}\tTAG sequence length:{5}\n" - "Best match is: {6}\n" - .format( + "Best match is: {6}\n".format( cell_barcode, UMI, read2, len(cell_barcode), len(UMI), len(read2), - best_match + best_match, ) ) sys.stdout.flush() n += 1 - print("Mapping done for process {}. Processed {:,} reads".format(os.getpid(), n - 1)) + print( + "Mapping done for process {}. Processed {:,} reads".format( + os.getpid(), n - 1 + ) + ) sys.stdout.flush() - - return(results, no_match) + + return (results, no_match) def merge_results(parallel_results): @@ -191,18 +193,25 @@ def merge_results(parallel_results): # Test the counter. Returns false if empty if mapped[cell_barcode][TAG]: for UMI in mapped[cell_barcode][TAG]: - merged_results[cell_barcode][TAG][UMI] += mapped[cell_barcode][TAG][UMI] + merged_results[cell_barcode][TAG][UMI] += mapped[cell_barcode][ + TAG + ][UMI] umis_per_cell[cell_barcode] += len(mapped[cell_barcode][TAG]) reads_per_cell[cell_barcode] += mapped[cell_barcode][TAG][UMI] merged_no_match.update(unmapped) - return(merged_results, umis_per_cell, reads_per_cell, merged_no_match) + return (merged_results, umis_per_cell, reads_per_cell, merged_no_match) def check_unmapped(no_match, total_reads, start_trim): """Check if the number of unmapped is higher than 99%""" - if sum(no_match.values())/total_reads > float(0.99): - exit("""More than 99 percent of your data is unmapped.\nPlease check that your --start_trim {} parameter is correct and that your tags file is properly formatted""".format(start_trim)) - + if sum(no_match.values()) / total_reads > float(0.99): + exit( + """More than 99 percent of your data is unmapped.\nPlease check that your --start_trim {} parameter is correct and that your tags file is properly formatted""".format( + start_trim + ) + ) + + def correct_umis(umi_correction_input): """ Corrects umi barcodes within same cell/tag groups. @@ -219,10 +228,12 @@ def correct_umis(umi_correction_input): aberrant_umi_count_cells (set): Set of uncorrected cells. """ - - (final_results, collapsing_threshold, max_umis) = umi_correction_input - print('Started umi correction in child process {} working on {} cells'.format(os.getpid(), len(final_results))) + print( + "Started umi correction in child process {} working on {} cells".format( + os.getpid(), len(final_results) + ) + ) corrected_umis = 0 aberrant_cells = set() cells = final_results.keys() @@ -232,15 +243,17 @@ def correct_umis(umi_correction_input): if n_umis > 1 and n_umis <= max_umis: umi_clusters = network.UMIClusterer() UMIclusters = umi_clusters( - final_results[cell_barcode][TAG], - collapsing_threshold) - (new_res, temp_corrected_umis) = update_umi_counts(UMIclusters, final_results[cell_barcode].pop(TAG)) + final_results[cell_barcode][TAG], collapsing_threshold + ) + (new_res, temp_corrected_umis) = update_umi_counts( + UMIclusters, final_results[cell_barcode].pop(TAG) + ) final_results[cell_barcode][TAG] = new_res corrected_umis += temp_corrected_umis elif n_umis > max_umis: aberrant_cells.add(cell_barcode) - print('Finished correcting umis in child {}'.format(os.getpid())) - return(final_results, corrected_umis, aberrant_cells) + print("Finished correcting umis in child {}".format(os.getpid())) + return (final_results, corrected_umis, aberrant_cells) def update_umi_counts(UMIclusters, cell_tag_counts): @@ -256,14 +269,16 @@ def update_umi_counts(UMIclusters, cell_tag_counts): temp_corrected_umis (int): Number of corrected umis """ temp_corrected_umis = 0 - for umi_cluster in UMIclusters: # This is a list with the first element the dominant barcode - if(len(umi_cluster) > 1): # This means we got a correction + for ( + umi_cluster + ) in UMIclusters: # This is a list with the first element the dominant barcode + if len(umi_cluster) > 1: # This means we got a correction major_umi = umi_cluster[0] for minor_umi in umi_cluster[1:]: temp_corrected_umis += 1 temp = cell_tag_counts.pop(minor_umi) cell_tag_counts[major_umi] += temp - return(cell_tag_counts, temp_corrected_umis) + return (cell_tag_counts, temp_corrected_umis) def collapse_cells(true_to_false, umis_per_cell, final_results, ab_map): @@ -281,7 +296,7 @@ def collapse_cells(true_to_false, umis_per_cell, final_results, ab_map): final_results (dict): Same as input but with corrected cell barcodes. corrected_barcodes (int): How many cell barcodes have been corrected. """ - print('Collapsing cell barcodes') + print("Collapsing cell barcodes") corrected_barcodes = 0 for real_barcode in true_to_false: # If the cell barcode is not in the results @@ -295,15 +310,22 @@ def collapse_cells(true_to_false, umis_per_cell, final_results, ab_map): for TAG in temp.keys(): final_results[real_barcode][TAG].update(temp[TAG]) temp_umi_counts = umis_per_cell.pop(fake_barcode) - #temp_read_counts = reads_per_cell.pop(fake_barcode) - + # temp_read_counts = reads_per_cell.pop(fake_barcode) + umis_per_cell[real_barcode] += temp_umi_counts - #reads_per_cell[real_barcode] += temp_read_counts + # reads_per_cell[real_barcode] += temp_read_counts - return(umis_per_cell, final_results, corrected_barcodes) + return (umis_per_cell, final_results, corrected_barcodes) -def correct_cells(final_results, reads_per_cell, umis_per_cell, collapsing_threshold, expected_cells, ab_map): +def correct_cells( + final_results, + reads_per_cell, + umis_per_cell, + collapsing_threshold, + expected_cells, + ab_map, +): """ Corrects cell barcodes. @@ -319,27 +341,27 @@ def correct_cells(final_results, reads_per_cell, umis_per_cell, collapsing_thres umis_per_cell (Counter): Counter of umis per cell after cell barcode correction corrected_umis (int): How many umis have been corrected. """ - print('Looking for a whitelist') + print("Looking for a whitelist") _, true_to_false = whitelist_methods.getCellWhitelist( cell_barcode_counts=reads_per_cell, expect_cells=expected_cells, cell_number=expected_cells, error_correct_threshold=collapsing_threshold, - plotfile_prefix=False) - - ( - umis_per_cell, - final_results, - corrected_barcodes - ) = collapse_cells( - true_to_false=true_to_false, - umis_per_cell=umis_per_cell, - final_results=final_results, - ab_map=ab_map) - return(final_results, umis_per_cell, corrected_barcodes) - - -def correct_cells_whitelist(final_results, umis_per_cell, whitelist, collapsing_threshold, ab_map): + plotfile_prefix=False, + ) + + (umis_per_cell, final_results, corrected_barcodes) = collapse_cells( + true_to_false=true_to_false, + umis_per_cell=umis_per_cell, + final_results=final_results, + ab_map=ab_map, + ) + return (final_results, umis_per_cell, corrected_barcodes) + + +def correct_cells_whitelist( + final_results, umis_per_cell, whitelist, collapsing_threshold, ab_map +): """ Corrects cell barcodes. @@ -357,32 +379,28 @@ def correct_cells_whitelist(final_results, umis_per_cell, whitelist, collapsing_ corrected_barcodes (int): How many umis have been corrected. """ barcode_tree = pybktree.BKTree(Levenshtein.hamming, whitelist) - print('Generated barcode tree from whitelist') + print("Generated barcode tree from whitelist") cell_barcodes = list(final_results.keys()) n_barcodes = len(cell_barcodes) - print('Finding reference candidates') - print('Processing {:,} cell barcodes'.format(n_barcodes)) + print("Finding reference candidates") + print("Processing {:,} cell barcodes".format(n_barcodes)) - #Run with one process + # Run with one process true_to_false = find_true_to_false_map( - barcode_tree=barcode_tree, - cell_barcodes=cell_barcodes, - whitelist=whitelist, - collapsing_threshold=collapsing_threshold) - ( - umis_per_cell, - final_results, - corrected_barcodes) = collapse_cells( - true_to_false, - umis_per_cell, - final_results, - ab_map) - return(final_results, umis_per_cell, corrected_barcodes) - - - - -def find_true_to_false_map(barcode_tree, cell_barcodes, whitelist, collapsing_threshold): + barcode_tree=barcode_tree, + cell_barcodes=cell_barcodes, + whitelist=whitelist, + collapsing_threshold=collapsing_threshold, + ) + (umis_per_cell, final_results, corrected_barcodes) = collapse_cells( + true_to_false, umis_per_cell, final_results, ab_map + ) + return (final_results, umis_per_cell, corrected_barcodes) + + +def find_true_to_false_map( + barcode_tree, cell_barcodes, whitelist, collapsing_threshold +): """ Creates a mapping between "fake" cell barcodes and their original true barcode. @@ -401,7 +419,11 @@ def find_true_to_false_map(barcode_tree, cell_barcodes, whitelist, collapsing_th # if the barcode is already whitelisted, no need to add continue # get all members of whitelist that are at distance of collapsing_threshold - candidates = [white_cell for d, white_cell in barcode_tree.find(cell_barcode, collapsing_threshold) if d > 0] + candidates = [ + white_cell + for d, white_cell in barcode_tree.find(cell_barcode, collapsing_threshold) + if d > 0 + ] if len(candidates) == 1: white_cell_str = candidates[0] true_to_false[white_cell_str].append(cell_barcode) @@ -414,8 +436,7 @@ def find_true_to_false_map(barcode_tree, cell_barcodes, whitelist, collapsing_th # more than on whitelisted candidate: # we drop it as its not uniquely assignable continue - return(true_to_false) - + return true_to_false def generate_sparse_matrices(final_results, ordered_tags_map, top_cells): @@ -431,12 +452,20 @@ def generate_sparse_matrices(final_results, ordered_tags_map, top_cells): read_results_matrix (scipy.sparse.dok_matrix): Read counts """ - umi_results_matrix = sparse.dok_matrix((len(ordered_tags_map) ,len(top_cells)), dtype=int32) - read_results_matrix = sparse.dok_matrix((len(ordered_tags_map) ,len(top_cells)), dtype=int32) - for i,cell_barcode in enumerate(top_cells): - for j,TAG in enumerate(final_results[cell_barcode]): + umi_results_matrix = sparse.dok_matrix( + (len(ordered_tags_map), len(top_cells)), dtype=int32 + ) + read_results_matrix = sparse.dok_matrix( + (len(ordered_tags_map), len(top_cells)), dtype=int32 + ) + for i, cell_barcode in enumerate(top_cells): + for j, TAG in enumerate(final_results[cell_barcode]): if final_results[cell_barcode][TAG]: - umi_results_matrix[ordered_tags_map[TAG]['id'],i] = len(final_results[cell_barcode][TAG]) - read_results_matrix[ordered_tags_map[TAG]['id'],i] = sum(final_results[cell_barcode][TAG].values()) - return(umi_results_matrix, read_results_matrix) + umi_results_matrix[ordered_tags_map[TAG]["id"], i] = len( + final_results[cell_barcode][TAG] + ) + read_results_matrix[ordered_tags_map[TAG]["id"], i] = sum( + final_results[cell_barcode][TAG].values() + ) + return (umi_results_matrix, read_results_matrix) diff --git a/setup.py b/setup.py index c276fe4..eb083e0 100644 --- a/setup.py +++ b/setup.py @@ -11,26 +11,22 @@ description="A python package to map reads from CITE-seq or hashing data for single cell experiments", url="https://github.com/Hoohm/CITE-seq-Count/", packages=setuptools.find_packages(), - entry_points={ - 'console_scripts': [ - 'CITE-seq-Count = cite_seq_count.__main__:main' - ] - }, + entry_points={"console_scripts": ["CITE-seq-Count = cite_seq_count.__main__:main"]}, classifiers=( "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", "Operating System :: OS Independent", ), install_requires=[ - 'python-levenshtein>=0.12.0', - 'scipy>=1.1.0', - 'multiprocess>=0.70.6.1', - 'umi_tools==1.0.0', - 'pytest==4.1.0', - 'pytest-dependency==0.4.0', - 'pandas>=0.23.4', - 'pybktree==1.1', - 'cython>=0.29.17' - ], - python_requires='>=3.6' + "python-levenshtein>=0.12.0", + "scipy>=1.1.0", + "multiprocess>=0.70.6.1", + "umi_tools==1.0.0", + "pytest==4.1.0", + "pytest-dependency==0.4.0", + "pandas>=0.23.4", + "pybktree==1.1", + "cython>=0.29.17", + ], + python_requires=">=3.7", ) From e5dc23d12c1556c10b60eaa672dcf8000bb96c98 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 2 Aug 2020 22:34:09 +0200 Subject: [PATCH 14/77] a lot of code refactoring --- cite_seq_count/__main__.py | 434 +------------------------------- cite_seq_count/argsparser.py | 310 +++++++++++++++++++++++ cite_seq_count/chemistry.py | 149 +++++++++++ cite_seq_count/database.py | 25 +- cite_seq_count/io.py | 125 +++++++-- cite_seq_count/preprocessing.py | 26 +- cite_seq_count/processing.py | 7 +- 7 files changed, 618 insertions(+), 458 deletions(-) create mode 100644 cite_seq_count/argsparser.py create mode 100644 cite_seq_count/chemistry.py diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index a4bf9af..6746e00 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -3,415 +3,25 @@ Author: Patrick Roelli """ import sys -import time import os -import datetime -import pkg_resources import logging import gzip import requests +import time from itertools import islice -from argparse import ArgumentParser, ArgumentTypeError, RawTextHelpFormatter from collections import OrderedDict, Counter, defaultdict, namedtuple -from multiprocess import cpu_count, Pool, Queue, JoinableQueue, Process -from tempfile import NamedTemporaryFile + +# pylint: disable=no-name-in-module +from multiprocess import Pool, Queue, JoinableQueue, Process from cite_seq_count import preprocessing from cite_seq_count import processing -from cite_seq_count import database +from cite_seq_count import chemistry from cite_seq_count import io from cite_seq_count import secondsToText - -version = pkg_resources.require("cite_seq_count")[0].version - - -def chunk_size_limit(arg): - """Validates chunk_size limits""" - max_size = 2147483647 - try: - f = int(arg) - except ValueError: - raise ArgumentTypeError("Must be am int") - if f < 1 or f > max_size: - raise ArgumentTypeError( - "Argument must be < " + str(max_size) + "and > " + str(1) - ) - return f - - -def get_args(): - """ - Get args. - """ - - parser = ArgumentParser( - prog="CITE-seq-Count", - formatter_class=RawTextHelpFormatter, - description=( - "This script counts matching antibody tags from paired fastq " - "files. Version {}".format(version) - ), - ) - - # REQUIRED INPUTS group. - inputs = parser.add_argument_group("Inputs", description="Required input files.") - inputs.add_argument( - "-R1", - "--read1", - dest="read1_path", - required=True, - help=( - "The path of Read1 in gz format, or a comma-separated list of paths to all Read1 files in" - " gz format (E.g. A1.fq.gz,B1.fq,gz,..." - ), - ) - inputs.add_argument( - "-R2", - "--read2", - dest="read2_path", - required=True, - help=( - "The path of Read2 in gz format, or a comma-separated list of paths to all Read2 files in" - " gz format (E.g. A2.fq.gz,B2.fq,gz,..." - ), - ) - inputs.add_argument( - "-t", - "--tags", - dest="tags", - required=True, - help=( - "The path to the csv file containing the antibody\n" - "barcodes as well as their respective names.\n\n" - "Example of an antibody barcode file structure:\n\n" - "\tATGCGA,First_tag_name\n" - "\tGTCATG,Second_tag_name" - ), - ) - # BARCODES group. - barcodes = parser.add_argument_group( - "Barcodes", - description=( - "Positions of the cellular barcodes and UMI. If your " - "cellular barcodes and UMI\n are positioned as follows:\n" - "\tBarcodes from 1 to 16 and UMI from 17 to 26\n" - "then this is the input you need:\n" - "\t-cbf 1 -cbl 16 -umif 17 -umil 26" - ), - ) - barcodes.add_argument("--chemistry", type=str, required=False, default=False) - if "--chemistry" not in sys.argv: - barcodes.add_argument( - "-cbf", - "--cell_barcode_first_base", - dest="cb_first", - required=True, - type=int, - help=("Postion of the first base of your cell " "barcodes."), - ) - barcodes.add_argument( - "-cbl", - "--cell_barcode_last_base", - dest="cb_last", - required=True, - type=int, - help=("Postion of the last base of your cell " "barcodes."), - ) - barcodes.add_argument( - "-umif", - "--umi_first_base", - dest="umi_first", - required=True, - type=int, - help="Postion of the first base of your UMI.", - ) - barcodes.add_argument( - "-umil", - "--umi_last_base", - dest="umi_last", - required=True, - type=int, - help="Postion of the last base of your UMI.", - ) - barcodes.add_argument( - "--umi_collapsing_dist", - dest="umi_threshold", - required=False, - type=int, - default=2, - help="threshold for umi collapsing.", - ) - barcodes.add_argument( - "--no_umi_correction", - required=False, - action="store_true", - default=False, - dest="no_umi_correction", - help="Deactivate UMI collapsing", - ) - barcodes.add_argument( - "--bc_collapsing_dist", - dest="bc_threshold", - required=False, - type=int, - default=1, - help="threshold for cellular barcode collapsing.", - ) - # Cells group - cells = parser.add_argument_group( - "Cells", description=("Expected number of cells and potential whitelist") - ) - - cells.add_argument( - "-cells", - "--expected_cells", - dest="expected_cells", - required=True, - type=int, - help=("Number of expected cells from your run."), - default=0, - ) - if "--chemistry" not in sys.argv: - cells.add_argument( - "-wl", - "--whitelist", - dest="whitelist", - required=False, - type=str, - help=( - "A csv file containning a whitelist of barcodes produced" - " by the mRNA data.\n\n" - "\tExample:\n" - "\tATGCTAGTGCTA\n\tGCTAGTCAGGAT\n\tCGACTGCTAACG\n\n" - "Or 10X-style:\n" - "\tATGCTAGTGCTA-1\n\tGCTAGTCAGGAT-1\n\tCGACTGCTAACG-1\n" - ), - ) - - cells.add_argument( - "--translation", - required=False, - type=str, - help="A csv file containing the mapping between two sets of cell barcode list.\n" - "A required header such as the reference is named whitelist. Example:\n\n" - "\twhitelist,trasnlation\n" - "\tAAACCCAAGAAACACT,AAACCCATCAAACACT\n" - "\tAAACCCAAGAAACCAT,AAACCCATCAAACCAT\n" - "\nThe output matrix will possess both cell barcode IDs", - ) - - # FILTERS group. - filters = parser.add_argument_group( - "TAG filters", description=("Filtering and trimming for read2.") - ) - filters.add_argument( - "--max-errors", - dest="max_error", - required=False, - type=int, - default=2, - help=("Maximum Levenshtein distance allowed for antibody barcodes."), - ) - - filters.add_argument( - "-trim", - "--start-trim", - dest="start_trim", - required=False, - type=int, - default=0, - help=("Number of bases to discard from read2."), - ) - - filters.add_argument( - "--sliding-window", - dest="sliding_window", - required=False, - default=False, - action="store_true", - help=("Allow for a sliding window when aligning."), - ) - - # Parallel group. - parallel = parser.add_argument_group( - "Parallelization options", - description=("Options for performance on parallelization"), - ) - # Remaining arguments. - parallel.add_argument( - "-T", - "--threads", - required=False, - type=int, - dest="n_threads", - default=cpu_count(), - help="How many threads are to be used for running the program", - ) - parallel.add_argument( - "-C", - "--chunk_size", - required=False, - type=chunk_size_limit, - dest="chunk_size", - help="How many reads should be sent to a child process at a time", - ) - parallel.add_argument( - "--temp_path", - required=False, - type=str, - dest="temp_path", - default="", - help="Temp folder for chunk creation specification. Useful when using a cluster with a scratch folder", - ) - - # Global group - parser.add_argument( - "-n", - "--first_n", - required=False, - type=int, - dest="first_n", - default=float("inf"), - help="Select N reads to run on instead of all.", - ) - parser.add_argument( - "-o", - "--output", - required=False, - type=str, - default="Results", - dest="outfolder", - help="Results will be written to this folder", - ) - parser.add_argument( - "--dense", - required=False, - action="store_true", - default=False, - dest="dense", - help="Add a dense output to the results folder", - ) - parser.add_argument( - "-u", - "--unmapped-tags", - required=False, - type=str, - dest="unmapped_file", - default="unmapped.csv", - help="Write table of unknown TAGs to file.", - ) - parser.add_argument( - "-ut", - "--unknown-top-tags", - required=False, - dest="unknowns_top", - type=int, - default=100, - help="Top n unmapped TAGs.", - ) - parser.add_argument( - "--debug", action="store_true", help="Print extra information for debugging." - ) - parser.add_argument( - "--version", - action="version", - version="CITE-seq-Count v{}".format(version), - help="Print version number.", - ) - # Finally! Too many options XD - return parser - - -def create_report( - total_reads, - reads_per_cell, - no_match, - version, - start_time, - ordered_tags_map, - umis_corrected, - bcs_corrected, - bad_cells, - R1_too_short, - R2_too_short, - args, - chemistry_def, -): - """ - Creates a report with details about the run in a yaml format. - Args: - total_reads (int): Number of reads that have been processed. - reads_matrix (scipy.sparse.dok_matrix): A sparse matrix continining read counts. - no_match (Counter): Counter of unmapped tags. - version (string): CITE-seq-Count package version. - start_time (time): Start time of the run. - args (arg_parse): Arguments provided by the user. - - """ - total_unmapped = sum(no_match.values()) - total_mapped = total_reads - total_unmapped - total_too_short = total_reads - total_unmapped - total_mapped - too_short_perc = round((total_too_short / total_reads) * 100) - mapped_perc = round((total_mapped / total_reads) * 100) - unmapped_perc = round((total_unmapped / total_reads) * 100) - - with open(os.path.join(args.outfolder, "run_report.yaml"), "w") as report_file: - report_file.write( - """Date: {} -Running time: {} -CITE-seq-Count Version: {} -Reads processed: {} -Percentage mapped: {} -Percentage unmapped: {} -Percentage too short: {} -\tR1_too_short: {} -\tR2_too_short: {} -Uncorrected cells: {} -Correction: -\tCell barcodes collapsing threshold: {} -\tCell barcodes corrected: {} -\tUMI collapsing threshold: {} -\tUMIs corrected: {} -Run parameters: -\tRead1_paths: {} -\tRead2_paths: {} -\tCell barcode: -\t\tFirst position: {} -\t\tLast position: {} -\tUMI barcode: -\t\tFirst position: {} -\t\tLast position: {} -\tExpected cells: {} -\tTags max errors: {} -\tStart trim: {} -""".format( - datetime.datetime.today().strftime("%Y-%m-%d"), - secondsToText.secondsToText(time.time() - start_time), - version, - int(total_reads), - mapped_perc, - unmapped_perc, - too_short_perc, - R1_too_short, - R2_too_short, - len(bad_cells), - args.bc_threshold, - bcs_corrected, - args.umi_threshold, - umis_corrected, - args.read1_path, - args.read2_path, - args.cb_first, - args.cb_last, - args.umi_first, - chemistry_def.umi_barcode_end, - args.expected_cells, - args.max_error, - chemistry_def.R2_trim_start, - ) - ) +from cite_seq_count import argsparser def main(): @@ -427,7 +37,7 @@ def main(): logger.addHandler(ch) start_time = time.time() - parser = get_args() + parser = argsparser.get_args() if not sys.argv[1:]: parser.print_help(file=sys.stderr) sys.exit(2) @@ -438,31 +48,7 @@ def main(): assert os.access(temp_path, os.W_OK) # Get chemistry defs - if args.chemistry: - chemistry_def = database.get_chemistry_definition(args.chemistry) - with requests.get(chemistry_def.whitelist_path) as r: - r.raise_for_status() - ###### TODO deal with the download of the whitelist from remote - with NamedTemporaryFile() as temp_local_whitelist: - temp_local_whitelist.write(r.content) - (whitelist, args.bc_threshold) = preprocessing.parse_whitelist_csv( - filename=temp_local_whitelist, - barcode_length=chemistry_def.cell_barcode_end - - chemistry_def.cell_barcode_start - + 1, - collapsing_threshold=args.bc_threshold, - ) - else: - chemistry_def = database.create_chemistry_definition(args) - if args.whitelist: - print("Loading whitelist") - (whitelist, args.bc_threshold) = preprocessing.parse_whitelist_csv( - filename=args.whitelist, - barcode_length=args.cb_last - args.cb_first + 1, - collapsing_threshold=args.bc_threshold, - ) - else: - whitelist = False + (whitelist, chemistry_def) = chemistry.setup_chemistry(args) # Load TAGs/ABs. ab_map = preprocessing.parse_tags_csv(args.tags) @@ -815,11 +401,11 @@ def main(): ) # Create report and write it to disk - create_report( + io.create_report( total_reads=total_reads, reads_per_cell=reads_per_cell, no_match=merged_no_match, - version=version, + version=argsparser.get_package_version(), start_time=start_time, ordered_tags_map=ordered_tags_map, umis_corrected=umis_corrected, diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py new file mode 100644 index 0000000..1694bd9 --- /dev/null +++ b/cite_seq_count/argsparser.py @@ -0,0 +1,310 @@ +import pkg_resources +import sys + + +from argparse import ArgumentParser, ArgumentTypeError, RawTextHelpFormatter + +# pylint: disable=no-name-in-module +from multiprocess import cpu_count + + +def get_package_version(): + version = pkg_resources.require("cite_seq_count")[0].version + return version + + +def chunk_size_limit(arg): + """Validates chunk_size limits""" + max_size = 2147483647 + try: + f = int(arg) + except ValueError: + raise ArgumentTypeError("Must be an int") + if f < 1 or f > max_size: + raise ArgumentTypeError( + "Argument must be < " + str(max_size) + "and > " + str(1) + ) + return f + + +def get_args(): + """ + Get args. + """ + + parser = ArgumentParser( + prog="CITE-seq-Count", + formatter_class=RawTextHelpFormatter, + description=( + "This script counts matching antibody tags from paired fastq " + "files. Version {}".format(get_package_version()) + ), + ) + + # REQUIRED INPUTS group. + inputs = parser.add_argument_group("Inputs", description="Required input files.") + inputs.add_argument( + "-R1", + "--read1", + dest="read1_path", + required=True, + help=( + "The path of Read1 in gz format, or a comma-separated list of paths to all Read1 files in" + " gz format (E.g. A1.fq.gz,B1.fq,gz,..." + ), + ) + inputs.add_argument( + "-R2", + "--read2", + dest="read2_path", + required=True, + help=( + "The path of Read2 in gz format, or a comma-separated list of paths to all Read2 files in" + " gz format (E.g. A2.fq.gz,B2.fq,gz,..." + ), + ) + inputs.add_argument( + "-t", + "--tags", + dest="tags", + required=True, + help=( + "The path to the csv file containing the antibody\n" + "barcodes as well as their respective names.\n\n" + "Example of an antibody barcode file structure:\n\n" + "\tATGCGA,First_tag_name\n" + "\tGTCATG,Second_tag_name" + ), + ) + # BARCODES group. + barcodes = parser.add_argument_group( + "Barcodes", + description=( + "Positions of the cellular barcodes and UMI. If your " + "cellular barcodes and UMI\n are positioned as follows:\n" + "\tBarcodes from 1 to 16 and UMI from 17 to 26\n" + "then this is the input you need:\n" + "\t-cbf 1 -cbl 16 -umif 17 -umil 26" + ), + ) + barcodes.add_argument("--chemistry", type=str, required=False, default=False) + if "--chemistry" not in sys.argv: + barcodes.add_argument( + "-cbf", + "--cell_barcode_first_base", + dest="cb_first", + required=True, + type=int, + help=("Postion of the first base of your cell " "barcodes."), + ) + barcodes.add_argument( + "-cbl", + "--cell_barcode_last_base", + dest="cb_last", + required=True, + type=int, + help=("Postion of the last base of your cell " "barcodes."), + ) + barcodes.add_argument( + "-umif", + "--umi_first_base", + dest="umi_first", + required=True, + type=int, + help="Postion of the first base of your UMI.", + ) + barcodes.add_argument( + "-umil", + "--umi_last_base", + dest="umi_last", + required=True, + type=int, + help="Postion of the last base of your UMI.", + ) + barcodes.add_argument( + "--umi_collapsing_dist", + dest="umi_threshold", + required=False, + type=int, + default=2, + help="threshold for umi collapsing.", + ) + barcodes.add_argument( + "--no_umi_correction", + required=False, + action="store_true", + default=False, + dest="no_umi_correction", + help="Deactivate UMI collapsing", + ) + barcodes.add_argument( + "--bc_collapsing_dist", + dest="bc_threshold", + required=False, + type=int, + default=1, + help="threshold for cellular barcode collapsing.", + ) + # Cells group + cells = parser.add_argument_group( + "Cells", description=("Expected number of cells and potential whitelist") + ) + + cells.add_argument( + "-cells", + "--expected_cells", + dest="expected_cells", + required=True, + type=int, + help=("Number of expected cells from your run."), + default=0, + ) + if "--chemistry" not in sys.argv: + cells.add_argument( + "-wl", + "--whitelist", + dest="whitelist", + required=False, + type=str, + help=( + "A csv file containning a whitelist of barcodes produced" + " by the mRNA data.\n\n" + "\tExample:\n" + "\tATGCTAGTGCTA\n\tGCTAGTCAGGAT\n\tCGACTGCTAACG\n\n" + "Or 10X-style:\n" + "\tATGCTAGTGCTA-1\n\tGCTAGTCAGGAT-1\n\tCGACTGCTAACG-1\n" + ), + ) + + cells.add_argument( + "--translation", + required=False, + type=str, + help="A csv file containing the mapping between two sets of cell barcode list.\n" + "A required header such as the reference is named whitelist. Example:\n\n" + "\twhitelist,trasnlation\n" + "\tAAACCCAAGAAACACT,AAACCCATCAAACACT\n" + "\tAAACCCAAGAAACCAT,AAACCCATCAAACCAT\n" + "\nThe output matrix will possess both cell barcode IDs", + ) + + # FILTERS group. + filters = parser.add_argument_group( + "TAG filters", description=("Filtering and trimming for read2.") + ) + filters.add_argument( + "--max-errors", + dest="max_error", + required=False, + type=int, + default=2, + help=("Maximum Levenshtein distance allowed for antibody barcodes."), + ) + + filters.add_argument( + "-trim", + "--start-trim", + dest="start_trim", + required=False, + type=int, + default=0, + help=("Number of bases to discard from read2."), + ) + + filters.add_argument( + "--sliding-window", + dest="sliding_window", + required=False, + default=False, + action="store_true", + help=("Allow for a sliding window when aligning."), + ) + + # Parallel group. + parallel = parser.add_argument_group( + "Parallelization options", + description=("Options for performance on parallelization"), + ) + # Remaining arguments. + parallel.add_argument( + "-T", + "--threads", + required=False, + type=int, + dest="n_threads", + default=cpu_count(), + help="How many threads are to be used for running the program", + ) + parallel.add_argument( + "-C", + "--chunk_size", + required=False, + type=chunk_size_limit, + dest="chunk_size", + help="How many reads should be sent to a child process at a time", + ) + parallel.add_argument( + "--temp_path", + required=False, + type=str, + dest="temp_path", + default="", + help="Temp folder for chunk creation specification. Useful when using a cluster with a scratch folder", + ) + + # Global group + parser.add_argument( + "-n", + "--first_n", + required=False, + type=int, + dest="first_n", + default=float("inf"), + help="Select N reads to run on instead of all.", + ) + parser.add_argument( + "-o", + "--output", + required=False, + type=str, + default="Results", + dest="outfolder", + help="Results will be written to this folder", + ) + parser.add_argument( + "--dense", + required=False, + action="store_true", + default=False, + dest="dense", + help="Add a dense output to the results folder", + ) + parser.add_argument( + "-u", + "--unmapped-tags", + required=False, + type=str, + dest="unmapped_file", + default="unmapped.csv", + help="Write table of unknown TAGs to file.", + ) + parser.add_argument( + "-ut", + "--unknown-top-tags", + required=False, + dest="unknowns_top", + type=int, + default=100, + help="Top n unmapped TAGs.", + ) + parser.add_argument( + "--debug", action="store_true", help="Print extra information for debugging." + ) + parser.add_argument( + "--version", + action="version", + version="CITE-seq-Count v{}".format(get_package_version()), + help="Print version number.", + ) + # Finally! Too many options XD + return parser diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py new file mode 100644 index 0000000..a97f2fb --- /dev/null +++ b/cite_seq_count/chemistry.py @@ -0,0 +1,149 @@ +"""This module is holding code for all the remote fetching on the chemistries database.""" +import requests +import sys +import os +import gzip +import io +import csv + +from collections import namedtuple + +from dataclasses import dataclass + +from cite_seq_count import preprocessing + +GLOBAL_LINK_RAW = "https://raw.githubusercontent.com/Hoohm/scg_lib_structs/10xv3_totalseq_b/chemistries/" +GLOBAL_LINK_GITHUB = "https://github.com/Hoohm/scg_lib_structs/raw/10xv3_totalseq_b/" +GLOBAL_LINK_GITHUB_IO = "https://teichlab.github.io/scg_lib_structs" +CHEMISTRY_DEFINITIONS = os.path.join(GLOBAL_LINK_RAW, "definitions.json") + + +@dataclass +class Chemistry: + name: str + cell_barcode_start: int + cell_barcode_end: int + umi_barcode_start: int + umi_barcode_end: int + R2_trim_start: int + whitelist_path: str + mapping_required: bool + + +def list_chemistries(url=CHEMISTRY_DEFINITIONS): + print("Loading remote file from: {}".format(url)) + with requests.get(url) as r: + r.raise_for_status() + all_chemistry_defs = r.json() + print( + "Here are all the possible chemistries available at {}".format( + GLOBAL_LINK_GITHUB_IO + ) + ) + for chemistry in all_chemistry_defs: + print( + "\n-- {}\n shortname: {}\n Protocol link: {}\n\n".format( + all_chemistry_defs[chemistry]["Description"], + chemistry, + os.path.join( + GLOBAL_LINK_GITHUB_IO, all_chemistry_defs[chemistry]["html"] + ), + ) + ) + + +def get_chemistry_definition(chemistry_short_name, url=CHEMISTRY_DEFINITIONS): + """ + Fetches chemistry definitions from a remote definitions.json and returns the json. + """ + print("Loading remote file from: {}".format(url)) + with requests.get(url) as r: + r.raise_for_status() + chemistry_defs = r.json().get(chemistry_short_name, False) + if not chemistry_defs: + sys.exit( + "Could not find the chemistry: {}. Please check that it does exist at: {}\nExiting".format( + chemistry_short_name, url + ) + ) + chemistry_def = Chemistry( + name=chemistry_short_name, + cell_barcode_start=chemistry_defs["barcode_structure_indexes"]["cell_barcode"][ + "R1" + ]["start"], + cell_barcode_end=chemistry_defs["barcode_structure_indexes"]["cell_barcode"][ + "R1" + ]["stop"], + umi_barcode_start=chemistry_defs["barcode_structure_indexes"]["umi_barcode"][ + "R1" + ]["start"], + umi_barcode_end=chemistry_defs["barcode_structure_indexes"]["umi_barcode"][ + "R1" + ]["stop"], + R2_trim_start=chemistry_defs["sequence_structure_indexes"]["R2"]["start"] - 1, + whitelist_path=os.path.join( + GLOBAL_LINK_GITHUB, "chemistries", chemistry_defs["whitelist"]["path"] + ), + mapping_required=chemistry_defs["whitelist"]["mapping"], + ) + return chemistry_def + + +def get_csv_reader(file): + if file.startswith("http://") or file.startswith("https://"): + response = requests.get(file) + response.raise_for_status() + + if file.endswith(".gz"): + content = response.content + text = gzip.decompress(content).decode("utf-8") + else: + text = response.text + reader = csv.reader(io.StringIO(text)) + + elif file.endswith(".gz"): + f = gzip.open(file, mode="rt") + reader = csv.reader(f) + else: + f = open(file, encoding="UTF-8") + reader = csv.reader(f) + + return reader + + +def create_chemistry_definition(args): + chemistry_def = Chemistry( + name="custom", + cell_barcode_start=args.cb_first, + cell_barcode_end=args.cb_last, + umi_barcode_start=args.umi_first, + umi_barcode_end=args.umi_last, + R2_trim_start=args.start_trim, + whitelist_path=args.whitelist, + mapping_required=args.translation, + ) + return chemistry_def + + +def setup_chemistry(args): + if args.chemistry: + chemistry_def = get_chemistry_definition(args.chemistry) + (whitelist, args.bc_threshold) = preprocessing.parse_whitelist_csv( + csv_reader=get_csv_reader(chemistry_def.whitelist_path), + barcode_length=chemistry_def.cell_barcode_end + - chemistry_def.cell_barcode_start + + 1, + collapsing_threshold=args.bc_threshold, + ) + else: + chemistry_def = create_chemistry_definition(args) + if args.whitelist: + print("Loading whitelist") + (whitelist, args.bc_threshold) = preprocessing.parse_whitelist_csv( + csv_reader=get_csv_reader(args.whitelist), + barcode_length=args.cb_last - args.cb_first + 1, + collapsing_threshold=args.bc_threshold, + ) + else: + whitelist = False + return (whitelist, chemistry_def) diff --git a/cite_seq_count/database.py b/cite_seq_count/database.py index f66b5bd..9870266 100644 --- a/cite_seq_count/database.py +++ b/cite_seq_count/database.py @@ -2,9 +2,11 @@ import requests import sys import os +import gzip +import io from collections import namedtuple -from tempfile import TemporaryFile + from dataclasses import dataclass GLOBAL_LINK_RAW = "https://raw.githubusercontent.com/Hoohm/scg_lib_structs/10xv3_totalseq_b/chemistries/" @@ -84,5 +86,26 @@ def get_chemistry_definition(chemistry_short_name, url=CHEMISTRY_DEFINITIONS): return chemistry_def +def getstream(file): + if file.startswith("http://") or file.startswith("https://"): + response = requests.get(file) + response.raise_for_status() + + if file.endswith(".gz"): + content = response.content + text = gzip.decompress(content).decode("utf-8") + else: + text = response.text + f = io.StringIO(text) + return text + + elif file.endswith(".gz"): + f = gzip.open(file, mode="rt") + + else: + f = open(file, encoding="UTF-8") + return f + + def create_chemistry_definition(args): return chemistry_def diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 56fd90c..31f2a48 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -1,10 +1,13 @@ import os import gzip import shutil +import time +import datetime import pandas as pd from scipy import io +from cite_seq_count import secondsToText def write_to_files(sparse_matrix, top_cells, ordered_tags_map, data_type, outfolder): @@ -17,19 +20,23 @@ def write_to_files(sparse_matrix, top_cells, ordered_tags_map, data_type, outfol data_type (string): A string definning if the data is umi or read based. outfolder (string): Path to the output folder. """ - prefix = os.path.join(outfolder,data_type + '_count') + prefix = os.path.join(outfolder, data_type + "_count") os.makedirs(prefix, exist_ok=True) - io.mmwrite(os.path.join(prefix,'matrix.mtx'),sparse_matrix) - with gzip.open(os.path.join(prefix,'barcodes.tsv.gz'), 'wb') as barcode_file: + io.mmwrite(os.path.join(prefix, "matrix.mtx"), sparse_matrix) + with gzip.open(os.path.join(prefix, "barcodes.tsv.gz"), "wb") as barcode_file: for barcode in top_cells: - barcode_file.write('{}\n'.format(barcode).encode()) - with gzip.open(os.path.join(prefix,'features.tsv.gz'), 'wb') as feature_file: + barcode_file.write("{}\n".format(barcode).encode()) + with gzip.open(os.path.join(prefix, "features.tsv.gz"), "wb") as feature_file: for feature in ordered_tags_map: - feature_file.write('{}\t{}\n'.format(ordered_tags_map[feature]['sequence'], feature).encode()) - with open(os.path.join(prefix,'matrix.mtx'),'rb') as mtx_in: - with gzip.open(os.path.join(prefix,'matrix.mtx') + '.gz','wb') as mtx_gz: + feature_file.write( + "{}\t{}\n".format( + ordered_tags_map[feature]["sequence"], feature + ).encode() + ) + with open(os.path.join(prefix, "matrix.mtx"), "rb") as mtx_in: + with gzip.open(os.path.join(prefix, "matrix.mtx") + ".gz", "wb") as mtx_gz: shutil.copyfileobj(mtx_in, mtx_gz) - os.remove(os.path.join(prefix,'matrix.mtx')) + os.remove(os.path.join(prefix, "matrix.mtx")) def write_dense(sparse_matrix, index, columns, outfolder, filename): @@ -46,7 +53,7 @@ def write_dense(sparse_matrix, index, columns, outfolder, filename): prefix = os.path.join(outfolder) os.makedirs(prefix, exist_ok=True) pandas_dense = pd.DataFrame(sparse_matrix.todense(), columns=columns, index=index) - pandas_dense.to_csv(os.path.join(outfolder,filename), sep='\t') + pandas_dense.to_csv(os.path.join(outfolder, filename), sep="\t") def write_unmapped(merged_no_match, top_unknowns, outfolder, filename): @@ -59,10 +66,100 @@ def write_unmapped(merged_no_match, top_unknowns, outfolder, filename): outfolder (string): Path of the output folder filename (string): Name of the output file """ - + top_unmapped = merged_no_match.most_common(top_unknowns) - with open(os.path.join(outfolder, filename),'w') as unknown_file: - unknown_file.write('tag,count\n') + with open(os.path.join(outfolder, filename), "w") as unknown_file: + unknown_file.write("tag,count\n") for element in top_unmapped: - unknown_file.write('{},{}\n'.format(element[0],element[1])) + unknown_file.write("{},{}\n".format(element[0], element[1])) + + +def create_report( + total_reads, + reads_per_cell, + no_match, + version, + start_time, + ordered_tags_map, + umis_corrected, + bcs_corrected, + bad_cells, + R1_too_short, + R2_too_short, + args, + chemistry_def, +): + """ + Creates a report with details about the run in a yaml format. + Args: + total_reads (int): Number of reads that have been processed. + reads_matrix (scipy.sparse.dok_matrix): A sparse matrix continining read counts. + no_match (Counter): Counter of unmapped tags. + version (string): CITE-seq-Count package version. + start_time (time): Start time of the run. + args (arg_parse): Arguments provided by the user. + + """ + total_unmapped = sum(no_match.values()) + total_mapped = total_reads - total_unmapped + total_too_short = total_reads - total_unmapped - total_mapped + too_short_perc = round((total_too_short / total_reads) * 100) + mapped_perc = round((total_mapped / total_reads) * 100) + unmapped_perc = round((total_unmapped / total_reads) * 100) + + with open(os.path.join(args.outfolder, "run_report.yaml"), "w") as report_file: + report_file.write( + """Date: {} +Running time: {} +CITE-seq-Count Version: {} +Reads processed: {} +Percentage mapped: {} +Percentage unmapped: {} +Percentage too short: {} +\tR1_too_short: {} +\tR2_too_short: {} +Uncorrected cells: {} +Correction: +\tCell barcodes collapsing threshold: {} +\tCell barcodes corrected: {} +\tUMI collapsing threshold: {} +\tUMIs corrected: {} +Run parameters: +\tRead1_paths: {} +\tRead2_paths: {} +\tCell barcode: +\t\tFirst position: {} +\t\tLast position: {} +\tUMI barcode: +\t\tFirst position: {} +\t\tLast position: {} +\tExpected cells: {} +\tTags max errors: {} +\tStart trim: {} +""".format( + datetime.datetime.today().strftime("%Y-%m-%d"), + secondsToText.secondsToText(time.time() - start_time), + version, + int(total_reads), + mapped_perc, + unmapped_perc, + too_short_perc, + R1_too_short, + R2_too_short, + len(bad_cells), + args.bc_threshold, + bcs_corrected, + args.umi_threshold, + umis_corrected, + args.read1_path, + args.read2_path, + args.cb_first, + args.cb_last, + args.umi_first, + chemistry_def.umi_barcode_end, + args.expected_cells, + args.max_error, + chemistry_def.R2_trim_start, + ) + ) diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index ba7626d..b4a45a9 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -12,7 +12,7 @@ from itertools import islice -def parse_whitelist_csv(filename, barcode_length, collapsing_threshold): +def parse_whitelist_csv(csv_reader, barcode_length, collapsing_threshold): """Reads white-listed barcodes from a CSV file. The function accepts plain barcodes or even 10X style barcodes with the @@ -30,22 +30,12 @@ def parse_whitelist_csv(filename, barcode_length, collapsing_threshold): """ STRIP_CHARS = '"0123456789- \t\n' cell_pattern = regex.compile(r"[ATGC]{{{}}}".format(barcode_length)) - if filename.endswith("gz"): - with gzip.open(filename.name, mode="r") as csv_file: - csv_reader = csv.reader(csv_file) - whitelist = [ - row[0].strip(STRIP_CHARS) - for row in csv_reader - if (len(row[0].strip(STRIP_CHARS)) == barcode_length) - ] - else: - with open(filename, mode="r") as csv_file: - csv_reader = csv.reader(csv_file) - whitelist = [ - row[0].strip(STRIP_CHARS) - for row in csv_reader - if (len(row[0].strip(STRIP_CHARS)) == barcode_length) - ] + + whitelist = [ + row[0].strip(STRIP_CHARS) + for row in csv_reader + if (len(row[0].strip(STRIP_CHARS)) == barcode_length) + ] for cell_barcode in whitelist: if not cell_pattern.match(cell_barcode): @@ -84,6 +74,7 @@ def test_cell_distances(whitelist, collapsing_threshold): ) all_comb = combinations(whitelist, 2) for comb in all_comb: + # pylint: disable=no-member if Levenshtein.hamming(comb[0], comb[1]) <= collapsing_threshold: collapsing_threshold -= 1 print("Value is too high, reducing it by 1") @@ -159,6 +150,7 @@ def check_tags(tags, maximum_distance): offending_pairs = [] for a, b in combinations(tags.keys(), 2): + # pylint: disable=no-member distance = Levenshtein.distance(a, b) if distance <= (maximum_distance - 1): offending_pairs.append([a, b, distance]) diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index f5cef29..663ee51 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -10,6 +10,8 @@ from collections import Counter from collections import defaultdict from collections import namedtuple + +# pylint: disable=no-name-in-module from multiprocess import Pool from itertools import islice @@ -43,6 +45,7 @@ def find_best_match(TAG_seq, tags, maximum_distance): best_match = "unmapped" best_score = maximum_distance for tag in tags: + # pylint: disable=no-member score = Levenshtein.hamming(tag.sequence, TAG_seq[: len(tag.sequence)]) if score == 0: # Best possible match @@ -378,6 +381,7 @@ def correct_cells_whitelist( umis_per_cell (Counter): Updated UMI counts after correction. corrected_barcodes (int): How many umis have been corrected. """ + # pylint: disable=no-member barcode_tree = pybktree.BKTree(Levenshtein.hamming, whitelist) print("Generated barcode tree from whitelist") cell_barcodes = list(final_results.keys()) @@ -459,7 +463,7 @@ def generate_sparse_matrices(final_results, ordered_tags_map, top_cells): (len(ordered_tags_map), len(top_cells)), dtype=int32 ) for i, cell_barcode in enumerate(top_cells): - for j, TAG in enumerate(final_results[cell_barcode]): + for TAG in final_results[cell_barcode]: if final_results[cell_barcode][TAG]: umi_results_matrix[ordered_tags_map[TAG]["id"], i] = len( final_results[cell_barcode][TAG] @@ -468,4 +472,3 @@ def generate_sparse_matrices(final_results, ordered_tags_map, top_cells): final_results[cell_barcode][TAG].values() ) return (umi_results_matrix, read_results_matrix) - From f4cd58cadf5fbccee1503d788692697db208a7f5 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 23 Aug 2020 15:08:09 +0200 Subject: [PATCH 15/77] refactoring, moved chunking to io --- cite_seq_count/__main__.py | 115 +++++++------------------------- cite_seq_count/argsparser.py | 2 + cite_seq_count/database.py | 111 ------------------------------ cite_seq_count/io.py | 112 +++++++++++++++++++++++++++++++ cite_seq_count/preprocessing.py | 4 -- 5 files changed, 139 insertions(+), 205 deletions(-) delete mode 100644 cite_seq_count/database.py diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 6746e00..ec3baf8 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -9,8 +9,6 @@ import requests import time -from itertools import islice - from collections import OrderedDict, Counter, defaultdict, namedtuple # pylint: disable=no-name-in-module @@ -44,8 +42,7 @@ def main(): # Parse arguments. args = parser.parse_args() - temp_path = os.path.abspath(args.temp_path) - assert os.access(temp_path, os.W_OK) + assert os.access(args.temp_path, os.W_OK) # Get chemistry defs (whitelist, chemistry_def) = chemistry.setup_chemistry(args) @@ -74,7 +71,7 @@ def main(): read1_lengths.append(preprocessing.get_read_length(read1_path)) read2_lengths.append(preprocessing.get_read_length(read2_path)) # Check Read1 length against CELL and UMI barcodes length. - (barcode_slice, umi_slice, _) = preprocessing.check_barcodes_lengths( + preprocessing.check_barcodes_lengths( read1_lengths[-1], chemistry_def.cell_barcode_start, chemistry_def.cell_barcode_end, @@ -89,100 +86,37 @@ def main(): # sys.exit('Input barcode fastqs (read2) do not all have same length.\nExiting') # Define R2_lenght to reduce amount of data to transfer to childrens + number_of_samples = len(read1_paths) + + # Print a statement if multiple files are run. + if number_of_samples != 1: + print("Detected {} files to run on.".format(number_of_samples)) + if args.sliding_window: R2_max_length = read2_lengths[0] else: R2_max_length = longest_tag_len + + ( + input_queue, + temp_files, + R1_too_short, + R2_too_short, + total_reads, + ) = io.write_chunks_to_disk( + args=args, + read1_paths=read1_paths, + read2_paths=read2_paths, + R2_max_length=R2_max_length, + total_reads=total_reads, + chemistry_def=chemistry_def, + named_tuples_tags_map=named_tuples_tags_map, + ) # Initialize the counts dicts that will be generated from each input fastq pair final_results = defaultdict(lambda: defaultdict(Counter)) umis_per_cell = Counter() reads_per_cell = Counter() merged_no_match = Counter() - number_of_samples = len(read1_paths) - - # Print a statement if multiple files are run. - if number_of_samples != 1: - print("Detected {} files to run on.".format(number_of_samples)) - input_queue = [] - mapping_input = namedtuple( - "mapping_input", - ["filename", "tags", "debug", "maximum_distance", "sliding_window"], - ) - - print("Writing chunks to disk") - reads_count = 0 - num_chunks = 0 - if args.chunk_size: - chunk_size = args.chunk_size - else: - chunk_size = round(total_reads / args.n_threads) + 1 - temp_files = [] - R1_too_short = 0 - R2_too_short = 0 - for read1_path, read2_path in zip(read1_paths, read2_paths): - print("Reading reads from files: {}, {}".format(read1_path, read2_path)) - with gzip.open(read1_path, "rt") as textfile1, gzip.open( - read2_path, "rt" - ) as textfile2: - secondlines = islice(zip(textfile1, textfile2), 1, None, 4) - temp_filename = os.path.join(temp_path, "temp_{}".format(num_chunks)) - chunked_file_object = open(temp_filename, "w") - temp_files.append(os.path.abspath(temp_filename)) - for read1, read2 in secondlines: - - read1 = read1.strip() - if len(read1) < chemistry_def.umi_barcode_end: - R1_too_short += 1 - # The entire read is skipped - continue - read1_sliced = read1[0 : chemistry_def.umi_barcode_end] - if len(read2) < R2_max_length: - R2_too_short += 1 - # The entire read is skipped - continue - - read2_sliced = read2[ - chemistry_def.R2_trim_start : ( - R2_max_length + chemistry_def.R2_trim_start - ) - ] - chunked_file_object.write( - "{},{},{}\n".format( - read1_sliced[barcode_slice], - read1_sliced[umi_slice], - read2_sliced, - ) - ) - reads_count += 1 - if reads_count % chunk_size == 0: - input_queue.append( - mapping_input( - filename=temp_filename, - tags=named_tuples_tags_map, - debug=args.debug, - maximum_distance=args.max_error, - sliding_window=args.sliding_window, - ) - ) - num_chunks += 1 - chunked_file_object.close() - temp_filename = "temp_{}".format(num_chunks) - chunked_file_object = open(temp_filename, "w") - temp_files.append(os.path.abspath(temp_filename)) - if reads_count >= args.first_n: - total_reads = args.first_n - break - - input_queue.append( - mapping_input( - filename=temp_filename, - tags=named_tuples_tags_map, - debug=args.debug, - maximum_distance=args.max_error, - sliding_window=args.sliding_window, - ) - ) - chunked_file_object.close() print("Started mapping") parallel_results = [] @@ -218,6 +152,7 @@ def main(): start_trim=chemistry_def.R2_trim_start, ) # Delete temp_files + exit() for file_path in temp_files: if os.path.exists(file_path): os.remove(file_path) diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 1694bd9..e035f78 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -24,6 +24,8 @@ def chunk_size_limit(arg): raise ArgumentTypeError( "Argument must be < " + str(max_size) + "and > " + str(1) ) + else: + return False return f diff --git a/cite_seq_count/database.py b/cite_seq_count/database.py deleted file mode 100644 index 9870266..0000000 --- a/cite_seq_count/database.py +++ /dev/null @@ -1,111 +0,0 @@ -"""This module is holding code for all the remote fetching on the chemistries database.""" -import requests -import sys -import os -import gzip -import io - -from collections import namedtuple - -from dataclasses import dataclass - -GLOBAL_LINK_RAW = "https://raw.githubusercontent.com/Hoohm/scg_lib_structs/10xv3_totalseq_b/chemistries/" -GLOBAL_LINK_GITHUB = "https://github.com/Hoohm/scg_lib_structs/raw/10xv3_totalseq_b/" -GLOBAL_LINK_GITHUB_IO = "https://teichlab.github.io/scg_lib_structs" -CHEMISTRY_DEFINITIONS = os.path.join(GLOBAL_LINK_RAW, "definitions.json") - - -@dataclass -class Chemistry: - name: str - cell_barcode_start: int - cell_barcode_end: int - umi_barcode_start: int - umi_barcode_end: int - R2_trim_start: int - whitelist_path: str - mapping_required: bool - - -def list_chemistries(url=CHEMISTRY_DEFINITIONS): - print("Loading remote file from: {}".format(url)) - with requests.get(url) as r: - r.raise_for_status() - all_chemistry_defs = r.json() - print( - "Here are all the possible chemistries available at {}".format( - GLOBAL_LINK_GITHUB_IO - ) - ) - for chemistry in all_chemistry_defs: - print( - "\n-- {}\n shortname: {}\n Protocol link: {}\n\n".format( - all_chemistry_defs[chemistry]["Description"], - chemistry, - os.path.join( - GLOBAL_LINK_GITHUB_IO, all_chemistry_defs[chemistry]["html"] - ), - ) - ) - - -def get_chemistry_definition(chemistry_short_name, url=CHEMISTRY_DEFINITIONS): - """ - Fetches chemistry definitions from a remote definitions.json and returns the json. - """ - print("Loading remote file from: {}".format(url)) - with requests.get(url) as r: - r.raise_for_status() - chemistry_defs = r.json().get(chemistry_short_name, False) - if not chemistry_defs: - sys.exit( - "Could not find the chemistry: {}. Please check that it does exist at: {}\nExiting".format( - chemistry_short_name, url - ) - ) - chemistry_def = Chemistry( - name=chemistry_short_name, - cell_barcode_start=chemistry_defs["barcode_structure_indexes"]["cell_barcode"][ - "R1" - ]["start"], - cell_barcode_end=chemistry_defs["barcode_structure_indexes"]["cell_barcode"][ - "R1" - ]["stop"], - umi_barcode_start=chemistry_defs["barcode_structure_indexes"]["umi_barcode"][ - "R1" - ]["start"], - umi_barcode_end=chemistry_defs["barcode_structure_indexes"]["umi_barcode"][ - "R1" - ]["stop"], - R2_trim_start=chemistry_defs["sequence_structure_indexes"]["R2"]["start"] - 1, - whitelist_path=os.path.join( - GLOBAL_LINK_GITHUB, "chemistries", chemistry_defs["whitelist"]["path"] - ), - mapping_required=chemistry_defs["whitelist"]["mapping"], - ) - return chemistry_def - - -def getstream(file): - if file.startswith("http://") or file.startswith("https://"): - response = requests.get(file) - response.raise_for_status() - - if file.endswith(".gz"): - content = response.content - text = gzip.decompress(content).decode("utf-8") - else: - text = response.text - f = io.StringIO(text) - return text - - elif file.endswith(".gz"): - f = gzip.open(file, mode="rt") - - else: - f = open(file, encoding="UTF-8") - return f - - -def create_chemistry_definition(args): - return chemistry_def diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 31f2a48..fc57dba 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -4,6 +4,9 @@ import time import datetime +from collections import namedtuple +from itertools import islice + import pandas as pd from scipy import io @@ -163,3 +166,112 @@ def create_report( chemistry_def.R2_trim_start, ) ) + + +def write_chunks_to_disk( + args, + read1_paths, + read2_paths, + R2_max_length, + total_reads, + chemistry_def, + named_tuples_tags_map, +): + """ + """ + mapping_input = namedtuple( + "mapping_input", + ["filename", "tags", "debug", "maximum_distance", "sliding_window"], + ) + + print("Writing chunks to disk") + + num_chunk = 0 + if not args.chunk_size: + args.chunk_size = round(total_reads / args.n_threads) + 1 + temp_path = os.path.abspath(args.temp_path) + input_queue = [] + temp_files = [] + R1_too_short = 0 + R2_too_short = 0 + total_reads_written = 0 + + barcode_slice = slice( + chemistry_def.cell_barcode_start - 1, chemistry_def.cell_barcode_end + ) + umi_slice = slice( + chemistry_def.umi_barcode_start - 1, chemistry_def.umi_barcode_end + ) + + for read1_path, read2_path in zip(read1_paths, read2_paths): + print("Reading reads from files: {}, {}".format(read1_path, read2_path)) + with gzip.open(read1_path, "rt") as textfile1, gzip.open( + read2_path, "rt" + ) as textfile2: + secondlines = islice(zip(textfile1, textfile2), 1, None, 4) + temp_filename = os.path.join(temp_path, "temp_{}".format(num_chunk)) + chunked_file_object = open(temp_filename, "w") + temp_files.append(os.path.abspath(temp_filename)) + reads_written = 0 + for read1, read2 in secondlines: + + read1 = read1.strip() + if len(read1) < chemistry_def.umi_barcode_end: + R1_too_short += 1 + # The entire read is skipped + continue + if len(read2) < R2_max_length: + R2_too_short += 1 + # The entire read is skipped + continue + + read1_sliced = read1[ + chemistry_def.cell_barcode_start - 1 : chemistry_def.umi_barcode_end + ] + + read2_sliced = read2[ + chemistry_def.R2_trim_start : ( + R2_max_length + chemistry_def.R2_trim_start + ) + ] + chunked_file_object.write( + "{},{},{}\n".format( + read1_sliced[barcode_slice], + read1_sliced[umi_slice], + read2_sliced, + ) + ) + + reads_written += 1 + total_reads_written += 1 + if reads_written % args.chunk_size == 0: + input_queue.append( + mapping_input( + filename=temp_filename, + tags=named_tuples_tags_map, + debug=args.debug, + maximum_distance=args.max_error, + sliding_window=args.sliding_window, + ) + ) + num_chunk += 1 + chunked_file_object.close() + temp_filename = "temp_{}".format(num_chunk) + chunked_file_object = open(temp_filename, "w") + temp_files.append(os.path.abspath(temp_filename)) + reads_written = 0 + if total_reads_written >= args.first_n: + total_reads = total_reads_written + break + + input_queue.append( + mapping_input( + filename=temp_filename, + tags=named_tuples_tags_map, + debug=args.debug, + maximum_distance=args.max_error, + sliding_window=args.sliding_window, + ) + ) + chunked_file_object.close() + return input_queue, temp_files, R1_too_short, R2_too_short, total_reads diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index b4a45a9..cd36194 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -251,8 +251,6 @@ def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last) barcode_length = cb_last - cb_first + 1 umi_length = umi_last - umi_first + 1 barcode_umi_length = barcode_length + umi_length - barcode_slice = slice(cb_first - 1, cb_last) - umi_slice = slice(umi_first - 1, umi_last) if barcode_umi_length > read1_length: sys.exit( @@ -269,8 +267,6 @@ def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last) ) ) - return (barcode_slice, umi_slice, barcode_umi_length) - def blocks(files, size=65536): """ From e5bbe59dc37706397460a4b308b1f8167f73a66e Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sat, 5 Sep 2020 17:45:17 +0200 Subject: [PATCH 16/77] some more changes --- cite_seq_count/argsparser.py | 2 +- cite_seq_count/chemistry.py | 6 ++++++ cite_seq_count/preprocessing.py | 6 ------ setup.py | 2 +- 4 files changed, 8 insertions(+), 8 deletions(-) diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index e035f78..60bcf0d 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -250,7 +250,7 @@ def get_args(): required=False, type=str, dest="temp_path", - default="", + default=".", help="Temp folder for chunk creation specification. Useful when using a cluster with a scratch folder", ) diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index a97f2fb..445a5da 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -31,6 +31,12 @@ class Chemistry: def list_chemistries(url=CHEMISTRY_DEFINITIONS): + """ + List all the available chemistries in the database + Args: + url (str): The url to the database file + + """ print("Loading remote file from: {}".format(url)) with requests.get(url) as r: r.raise_for_status() diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index cd36194..0fb3d75 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -241,12 +241,6 @@ def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last) cb_last (int): Barcode last base position for Read1. umi_first (int): UMI first base position for Read1. umi_last (int): UMI last base position for Read1. - - Returns: - slice: A `slice` object to extract the Barcode from the sequence string. - slice: A `slice` object to extract the UMI from the sequence string. - int: The Barcode + UMI length. - """ barcode_length = cb_last - cb_first + 1 umi_length = umi_last - umi_first + 1 diff --git a/setup.py b/setup.py index eb083e0..0e60ceb 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ "scipy>=1.1.0", "multiprocess>=0.70.6.1", "umi_tools==1.0.0", - "pytest==4.1.0", + "pytest==6.0.1", "pytest-dependency==0.4.0", "pandas>=0.23.4", "pybktree==1.1", From 7c3fbc605efaa5238254cec7f78bba59250c0f2e Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 6 Sep 2020 21:39:27 +0200 Subject: [PATCH 17/77] fixed chunking --- cite_seq_count/__main__.py | 93 ++++++++++++++++++--------------- cite_seq_count/argsparser.py | 12 +---- cite_seq_count/chemistry.py | 91 +++++++++++++++++--------------- cite_seq_count/io.py | 81 ++++++++++++++-------------- cite_seq_count/preprocessing.py | 14 ++--- cite_seq_count/processing.py | 75 ++++++++++++++------------ setup.py | 1 + 7 files changed, 194 insertions(+), 173 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index ec3baf8..cb672c1 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -79,18 +79,18 @@ def main(): chemistry_def.umi_barcode_end, ) - # Ensure all files have the same input length - # if len(set(read1_lengths)) != 1: - # sys.exit('Input barcode fastqs (read1) do not all have same length.\nExiting') - # if len(set(read2_lengths)) != 1: - # sys.exit('Input barcode fastqs (read2) do not all have same length.\nExiting') + # Get all reads or only top N? + if args.first_n < float("inf"): + n_reads = args.first_n + else: + n_reads = total_reads # Define R2_lenght to reduce amount of data to transfer to childrens number_of_samples = len(read1_paths) # Print a statement if multiple files are run. if number_of_samples != 1: - print("Detected {} files to run on.".format(number_of_samples)) + print("Detected {} pairs of files to run on.".format(number_of_samples)) if args.sliding_window: R2_max_length = read2_lengths[0] @@ -108,7 +108,7 @@ def main(): read1_paths=read1_paths, read2_paths=read2_paths, R2_max_length=R2_max_length, - total_reads=total_reads, + n_reads_to_chunk=n_reads, chemistry_def=chemistry_def, named_tuples_tags_map=named_tuples_tags_map, ) @@ -148,16 +148,18 @@ def main(): # Check if 99% of the reads are unmapped. processing.check_unmapped( no_match=merged_no_match, + too_short=R1_too_short + R2_too_short, total_reads=total_reads, start_trim=chemistry_def.R2_trim_start, ) # Delete temp_files - exit() + # exit() for file_path in temp_files: - if os.path.exists(file_path): - os.remove(file_path) - else: - print("Could not find file: {}".format(file_path)) + os.remove(file_path) + + # Select top cells + top_cells_tuple = umis_per_cell.most_common(args.expected_cells * 10) + top_cells = set([pair[0] for pair in top_cells_tuple]) # Correct cell barcodes if args.bc_threshold != 0: @@ -192,6 +194,7 @@ def main(): ) = processing.correct_cells_whitelist( final_results=final_results, umis_per_cell=umis_per_cell, + top_cells=top_cells, whitelist=whitelist, collapsing_threshold=args.bc_threshold, ab_map=named_tuples_tags_map, @@ -201,28 +204,37 @@ def main(): bcs_corrected = 0 # If given, use whitelist for top cells - if whitelist: - top_cells = whitelist - # Add potential missing cell barcodes. - for missing_cell in whitelist: - if missing_cell in final_results: - continue - else: - final_results[missing_cell] = dict() - for TAG in named_tuples_tags_map: - final_results[missing_cell][TAG] = Counter() - top_cells.add(missing_cell) - else: - # Select top cells based on total umis per cell - top_cells_tuple = umis_per_cell.most_common(args.expected_cells) - top_cells = set([pair[0] for pair in top_cells_tuple]) + # if whitelist: + # top_cells = whitelist + # # Add potential missing cell barcodes. + # for missing_cell in whitelist: + # if missing_cell in final_results: + # continue + # else: + # final_results[missing_cell] = dict() + # for TAG in named_tuples_tags_map: + # final_results[missing_cell][TAG.safe_name] = Counter() + # top_cells.add(missing_cell) + # else: + # Select top cells based on total umis per cell + + # Create sparse matrices for reads results + read_results_matrix = processing.generate_sparse_matrices( + final_results=final_results, + ordered_tags_map=ordered_tags_map, + top_cells=top_cells, + ) + # Write reads to file + io.write_to_files( + sparse_matrix=read_results_matrix, + top_cells=top_cells, + ordered_tags_map=ordered_tags_map, + data_type="read", + outfolder=args.outfolder, + ) # UMI correction - if args.no_umi_correction: - # Don't correct - umis_corrected = 0 - aberrant_cells = set() - else: + if args.umi_threshold != 0: # Correct UMIS input_queue = [] @@ -280,6 +292,10 @@ def main(): final_results.update(temp_results) umis_corrected += temp_umis aberrant_cells.update(temp_aberrant_cells) + else: + # Don't correct + umis_corrected = 0 + aberrant_cells = set() if len(aberrant_cells) > 0: # Remove aberrant cells from the top cells @@ -302,11 +318,11 @@ def main(): filename="dense_umis.tsv", ) - # Create sparse matrices for results - (umi_results_matrix, read_results_matrix) = processing.generate_sparse_matrices( + umi_results_matrix = processing.generate_sparse_matrices( final_results=final_results, ordered_tags_map=ordered_tags_map, top_cells=top_cells, + umi_counts=True, ) # Write umis to file @@ -318,15 +334,6 @@ def main(): outfolder=args.outfolder, ) - # Write reads to file - io.write_to_files( - sparse_matrix=read_results_matrix, - top_cells=top_cells, - ordered_tags_map=ordered_tags_map, - data_type="read", - outfolder=args.outfolder, - ) - # Write unmapped sequences io.write_unmapped( merged_no_match=merged_no_match, diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 60bcf0d..47a986e 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -128,17 +128,9 @@ def get_args(): dest="umi_threshold", required=False, type=int, - default=2, + default=1, help="threshold for umi collapsing.", ) - barcodes.add_argument( - "--no_umi_correction", - required=False, - action="store_true", - default=False, - dest="no_umi_correction", - help="Deactivate UMI collapsing", - ) barcodes.add_argument( "--bc_collapsing_dist", dest="bc_threshold", @@ -184,7 +176,7 @@ def get_args(): type=str, help="A csv file containing the mapping between two sets of cell barcode list.\n" "A required header such as the reference is named whitelist. Example:\n\n" - "\twhitelist,trasnlation\n" + "\twhitelist,translation\n" "\tAAACCCAAGAAACACT,AAACCCATCAAACACT\n" "\tAAACCCAAGAAACCAT,AAACCCATCAAACCAT\n" "\nThe output matrix will possess both cell barcode IDs", diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index 445a5da..3df59ff 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -4,7 +4,9 @@ import os import gzip import io +import pooch import csv +import json from collections import namedtuple @@ -15,7 +17,8 @@ GLOBAL_LINK_RAW = "https://raw.githubusercontent.com/Hoohm/scg_lib_structs/10xv3_totalseq_b/chemistries/" GLOBAL_LINK_GITHUB = "https://github.com/Hoohm/scg_lib_structs/raw/10xv3_totalseq_b/" GLOBAL_LINK_GITHUB_IO = "https://teichlab.github.io/scg_lib_structs" -CHEMISTRY_DEFINITIONS = os.path.join(GLOBAL_LINK_RAW, "definitions.json") +# "https://github.com/Hoohm/scg_lib_structs/raw/10xv3_totalseq_b/chemistries/whitelists/3M-february-2018.csv.gz" +# CHEMISTRY_DEFINITIONS = os.path.join(GLOBAL_LINK_RAW, "definitions.json") @dataclass @@ -30,17 +33,40 @@ class Chemistry: mapping_required: bool -def list_chemistries(url=CHEMISTRY_DEFINITIONS): +DEFINITIONS_DB = pooch.create( + path=pooch.os_cache("cite_seq_count"), + base_url=GLOBAL_LINK_RAW, + version="0.1.0", + env="MYPACKAGE_DATA_DIR", + # The cache file registry. A dictionary with all files managed by this + # pooch. Keys are the file names (relative to *base_url*) and values + # are their respective SHA256 hashes. Files will be downloaded + # automatically when needed (see fetch_gravity_data). + registry={ + "definitions.json": "4f2e1cee60446062f0c805b0b51ce12318cac04c64f1c7825037e2c73ff9955e" + }, +) + + +def fetch_definitions(): + """ + Load some sample gravity data to use in your docs. + """ + fname = DEFINITIONS_DB.fetch("definitions.json") + with open(fname, "r") as json_file: + data = json_file.read() + json_data = json.loads(data) + return json_data + + +def list_chemistries(chemistry_defs): """ List all the available chemistries in the database Args: url (str): The url to the database file """ - print("Loading remote file from: {}".format(url)) - with requests.get(url) as r: - r.raise_for_status() - all_chemistry_defs = r.json() + all_chemistry_defs = fetch_definitions() print( "Here are all the possible chemistries available at {}".format( GLOBAL_LINK_GITHUB_IO @@ -58,20 +84,23 @@ def list_chemistries(url=CHEMISTRY_DEFINITIONS): ) -def get_chemistry_definition(chemistry_short_name, url=CHEMISTRY_DEFINITIONS): +def get_chemistry_definition(chemistry_short_name): """ Fetches chemistry definitions from a remote definitions.json and returns the json. """ - print("Loading remote file from: {}".format(url)) - with requests.get(url) as r: - r.raise_for_status() - chemistry_defs = r.json().get(chemistry_short_name, False) - if not chemistry_defs: - sys.exit( - "Could not find the chemistry: {}. Please check that it does exist at: {}\nExiting".format( - chemistry_short_name, url - ) + chemistry_defs = fetch_definitions()[chemistry_short_name] + + if chemistry_defs["whitelist"]["path"] not in DEFINITIONS_DB.registry: + path = pooch.retrieve( + url=os.path.join( + GLOBAL_LINK_GITHUB, "chemistries", chemistry_defs["whitelist"]["path"] + ), + known_hash=None, + fname=chemistry_defs["whitelist"]["path"], + path=DEFINITIONS_DB.abspath, ) + else: + path = DEFINITIONS_DB.registry[chemistry_defs["whitelist"]["path"]] chemistry_def = Chemistry( name=chemistry_short_name, cell_barcode_start=chemistry_defs["barcode_structure_indexes"]["cell_barcode"][ @@ -87,36 +116,12 @@ def get_chemistry_definition(chemistry_short_name, url=CHEMISTRY_DEFINITIONS): "R1" ]["stop"], R2_trim_start=chemistry_defs["sequence_structure_indexes"]["R2"]["start"] - 1, - whitelist_path=os.path.join( - GLOBAL_LINK_GITHUB, "chemistries", chemistry_defs["whitelist"]["path"] - ), + whitelist_path=path, mapping_required=chemistry_defs["whitelist"]["mapping"], ) return chemistry_def -def get_csv_reader(file): - if file.startswith("http://") or file.startswith("https://"): - response = requests.get(file) - response.raise_for_status() - - if file.endswith(".gz"): - content = response.content - text = gzip.decompress(content).decode("utf-8") - else: - text = response.text - reader = csv.reader(io.StringIO(text)) - - elif file.endswith(".gz"): - f = gzip.open(file, mode="rt") - reader = csv.reader(f) - else: - f = open(file, encoding="UTF-8") - reader = csv.reader(f) - - return reader - - def create_chemistry_definition(args): chemistry_def = Chemistry( name="custom", @@ -135,7 +140,7 @@ def setup_chemistry(args): if args.chemistry: chemistry_def = get_chemistry_definition(args.chemistry) (whitelist, args.bc_threshold) = preprocessing.parse_whitelist_csv( - csv_reader=get_csv_reader(chemistry_def.whitelist_path), + filename=chemistry_def.whitelist_path, barcode_length=chemistry_def.cell_barcode_end - chemistry_def.cell_barcode_start + 1, @@ -146,7 +151,7 @@ def setup_chemistry(args): if args.whitelist: print("Loading whitelist") (whitelist, args.bc_threshold) = preprocessing.parse_whitelist_csv( - csv_reader=get_csv_reader(args.whitelist), + filename=args.whitelist, barcode_length=args.cb_last - args.cb_first + 1, collapsing_threshold=args.bc_threshold, ) diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index fc57dba..c8dae1a 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -105,8 +105,9 @@ def create_report( """ total_unmapped = sum(no_match.values()) - total_mapped = total_reads - total_unmapped - total_too_short = total_reads - total_unmapped - total_mapped + total_too_short = R1_too_short + R2_too_short + total_mapped = total_reads - total_unmapped - total_too_short + too_short_perc = round((total_too_short / total_reads) * 100) mapped_perc = round((total_mapped / total_reads) * 100) unmapped_perc = round((total_unmapped / total_reads) * 100) @@ -120,26 +121,26 @@ def create_report( Percentage mapped: {} Percentage unmapped: {} Percentage too short: {} -\tR1_too_short: {} -\tR2_too_short: {} + R1_too_short: {} + R2_too_short: {} Uncorrected cells: {} Correction: -\tCell barcodes collapsing threshold: {} -\tCell barcodes corrected: {} -\tUMI collapsing threshold: {} -\tUMIs corrected: {} + Cell barcodes collapsing threshold: {} + Cell barcodes corrected: {} + UMI collapsing threshold: {} + UMIs corrected: {} Run parameters: -\tRead1_paths: {} -\tRead2_paths: {} -\tCell barcode: -\t\tFirst position: {} -\t\tLast position: {} -\tUMI barcode: -\t\tFirst position: {} -\t\tLast position: {} -\tExpected cells: {} -\tTags max errors: {} -\tStart trim: {} + Read1_paths: {} + Read2_paths: {} + Cell barcode: + First position: {} + Last position: {} + UMI barcode: + First position: {} + Last position: {} + Expected cells: {} + Tags max errors: {} + Start trim: {} """.format( datetime.datetime.today().strftime("%Y-%m-%d"), secondsToText.secondsToText(time.time() - start_time), @@ -157,9 +158,9 @@ def create_report( umis_corrected, args.read1_path, args.read2_path, - args.cb_first, - args.cb_last, - args.umi_first, + chemistry_def.cell_barcode_start, + chemistry_def.cell_barcode_end, + chemistry_def.umi_barcode_start, chemistry_def.umi_barcode_end, args.expected_cells, args.max_error, @@ -173,7 +174,7 @@ def write_chunks_to_disk( read1_paths, read2_paths, R2_max_length, - total_reads, + n_reads_to_chunk, chemistry_def, named_tuples_tags_map, ): @@ -188,13 +189,16 @@ def write_chunks_to_disk( num_chunk = 0 if not args.chunk_size: - args.chunk_size = round(total_reads / args.n_threads) + 1 + chunk_size = round(n_reads_to_chunk / args.n_threads) + else: + chunk_size = args.chunk_size temp_path = os.path.abspath(args.temp_path) input_queue = [] temp_files = [] R1_too_short = 0 R2_too_short = 0 total_reads_written = 0 + enough_reads = False barcode_slice = slice( chemistry_def.cell_barcode_start - 1, chemistry_def.cell_barcode_end @@ -204,11 +208,14 @@ def write_chunks_to_disk( ) for read1_path, read2_path in zip(read1_paths, read2_paths): + if enough_reads: + break print("Reading reads from files: {}, {}".format(read1_path, read2_path)) with gzip.open(read1_path, "rt") as textfile1, gzip.open( read2_path, "rt" ) as textfile2: secondlines = islice(zip(textfile1, textfile2), 1, None, 4) + temp_filename = os.path.join(temp_path, "temp_{}".format(num_chunk)) chunked_file_object = open(temp_filename, "w") temp_files.append(os.path.abspath(temp_filename)) @@ -244,7 +251,9 @@ def write_chunks_to_disk( reads_written += 1 total_reads_written += 1 - if reads_written % args.chunk_size == 0: + if reads_written % chunk_size == 0 and reads_written != 0: + # We have enough reads in this chunk, open a new one + chunked_file_object.close() input_queue.append( mapping_input( filename=temp_filename, @@ -254,24 +263,18 @@ def write_chunks_to_disk( sliding_window=args.sliding_window, ) ) + if total_reads_written == n_reads_to_chunk: + enough_reads = True + chunked_file_object.close() + break num_chunk += 1 - chunked_file_object.close() temp_filename = "temp_{}".format(num_chunk) chunked_file_object = open(temp_filename, "w") temp_files.append(os.path.abspath(temp_filename)) reads_written = 0 - if total_reads_written >= args.first_n: - total_reads = total_reads_written + if total_reads_written == n_reads_to_chunk: + enough_reads = True + chunked_file_object.close() break - input_queue.append( - mapping_input( - filename=temp_filename, - tags=named_tuples_tags_map, - debug=args.debug, - maximum_distance=args.max_error, - sliding_window=args.sliding_window, - ) - ) - chunked_file_object.close() - return input_queue, temp_files, R1_too_short, R2_too_short, total_reads + return input_queue, temp_files, R1_too_short, R2_too_short, total_reads_written diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 0fb3d75..3f5772d 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -12,7 +12,7 @@ from itertools import islice -def parse_whitelist_csv(csv_reader, barcode_length, collapsing_threshold): +def parse_whitelist_csv(filename, barcode_length, collapsing_threshold): """Reads white-listed barcodes from a CSV file. The function accepts plain barcodes or even 10X style barcodes with the @@ -31,6 +31,12 @@ def parse_whitelist_csv(csv_reader, barcode_length, collapsing_threshold): STRIP_CHARS = '"0123456789- \t\n' cell_pattern = regex.compile(r"[ATGC]{{{}}}".format(barcode_length)) + if filename.endswith(".gz"): + f = gzip.open(filename, mode="rt") + csv_reader = csv.reader(f) + else: + f = open(filename, encoding="UTF-8") + csv_reader = csv.reader(f) whitelist = [ row[0].strip(STRIP_CHARS) for row in csv_reader @@ -228,10 +234,6 @@ def get_read_length(filename): return read_length -def get_chunk_strategy(read1_paths, read2_paths, chunk_size): - pass - - def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last): """Check Read1 length against CELL and UMI barcodes length. @@ -293,7 +295,7 @@ def get_n_lines(file_path): Returns: n_lines (int): Number of lines in the file """ - print("Counting number of reads") + print("Counting number of reads in file {}".format(file_path)) with gzip.open(file_path, "rt", encoding="utf-8", errors="ignore") as f: n_lines = sum(bl.count("\n") for bl in blocks(f)) if n_lines % 4 != 0: diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 663ee51..8543a2d 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -77,7 +77,7 @@ def find_best_match_shift(TAG_seq, tags): best_match = "unmapped" for tag in tags: if tag.sequence in TAG_seq: - return tag.name + return tag.secure_name return best_match @@ -112,7 +112,6 @@ def map_reads(mapping_input): no_match = Counter() n = 1 t = time.time() - # Progress info with open(filename, "r") as input_file: reads = csv.reader(input_file) @@ -205,11 +204,11 @@ def merge_results(parallel_results): return (merged_results, umis_per_cell, reads_per_cell, merged_no_match) -def check_unmapped(no_match, total_reads, start_trim): +def check_unmapped(no_match, too_short, total_reads, start_trim): """Check if the number of unmapped is higher than 99%""" - if sum(no_match.values()) / total_reads > float(0.99): + if (sum(no_match.values()) + too_short) / total_reads > float(0.99): exit( - """More than 99 percent of your data is unmapped.\nPlease check that your --start_trim {} parameter is correct and that your tags file is properly formatted""".format( + """More than 99% of your data is unmapped.\nPlease check that your --start_trim {} parameter is correct and that your tags file is properly formatted""".format( start_trim ) ) @@ -242,6 +241,9 @@ def correct_umis(umi_correction_input): cells = final_results.keys() for cell_barcode in cells: for TAG in final_results[cell_barcode]: + if TAG == "unmapped": + final_results[cell_barcode][TAG].pop() + n_umis = len(final_results[cell_barcode][TAG]) if n_umis > 1 and n_umis <= max_umis: umi_clusters = network.UMIClusterer() @@ -303,17 +305,22 @@ def collapse_cells(true_to_false, umis_per_cell, final_results, ab_map): corrected_barcodes = 0 for real_barcode in true_to_false: # If the cell barcode is not in the results + # add it in. if real_barcode not in final_results: final_results[real_barcode] = defaultdict() for TAG in ab_map: - final_results[real_barcode][TAG] = Counter() - for fake_barcode in true_to_false[real_barcode]: - temp = final_results.pop(fake_barcode) - corrected_barcodes += 1 + final_results[real_barcode][TAG.safe_name] = Counter() + for wrong_barcode in true_to_false[real_barcode]: + temp = final_results.pop(wrong_barcode) + for TAG in temp.keys(): - final_results[real_barcode][TAG].update(temp[TAG]) - temp_umi_counts = umis_per_cell.pop(fake_barcode) - # temp_read_counts = reads_per_cell.pop(fake_barcode) + if TAG in final_results[real_barcode]: + final_results[real_barcode][TAG].update(temp[TAG]) + else: + final_results[real_barcode][TAG] = temp[TAG] + corrected_barcodes += 1 + temp_umi_counts = umis_per_cell.pop(wrong_barcode) + # temp_read_counts = reads_per_cell.pop(wrong_barcode) umis_per_cell[real_barcode] += temp_umi_counts # reads_per_cell[real_barcode] += temp_read_counts @@ -363,7 +370,7 @@ def correct_cells( def correct_cells_whitelist( - final_results, umis_per_cell, whitelist, collapsing_threshold, ab_map + final_results, umis_per_cell, whitelist, top_cells, collapsing_threshold, ab_map ): """ Corrects cell barcodes. @@ -384,18 +391,18 @@ def correct_cells_whitelist( # pylint: disable=no-member barcode_tree = pybktree.BKTree(Levenshtein.hamming, whitelist) print("Generated barcode tree from whitelist") - cell_barcodes = list(final_results.keys()) - n_barcodes = len(cell_barcodes) + print("Finding reference candidates") - print("Processing {:,} cell barcodes".format(n_barcodes)) + print("Processing {:,} cell barcodes".format(len(top_cells))) # Run with one process true_to_false = find_true_to_false_map( barcode_tree=barcode_tree, - cell_barcodes=cell_barcodes, + cell_barcodes=top_cells, whitelist=whitelist, collapsing_threshold=collapsing_threshold, ) + print("Collapsing wrong barcodes with original barcodes") (umis_per_cell, final_results, corrected_barcodes) = collapse_cells( true_to_false, umis_per_cell, final_results, ab_map ) @@ -443,7 +450,9 @@ def find_true_to_false_map( return true_to_false -def generate_sparse_matrices(final_results, ordered_tags_map, top_cells): +def generate_sparse_matrices( + final_results, ordered_tags_map, top_cells, umi_counts=False +): """ Create two sparse matrices with umi and read counts. @@ -452,23 +461,25 @@ def generate_sparse_matrices(final_results, ordered_tags_map, top_cells): ordered_tags_map (dict): Tags in order with indexes as values. Returns: - umi_results_matrix (scipy.sparse.dok_matrix): UMI counts - read_results_matrix (scipy.sparse.dok_matrix): Read counts + results_matrix (scipy.sparse.dok_matrix): UMI counts + """ - umi_results_matrix = sparse.dok_matrix( - (len(ordered_tags_map), len(top_cells)), dtype=int32 - ) - read_results_matrix = sparse.dok_matrix( - (len(ordered_tags_map), len(top_cells)), dtype=int32 + print(ordered_tags_map) + results_matrix = sparse.dok_matrix( + (len(ordered_tags_map) + 1, len(top_cells)), dtype=int32 ) for i, cell_barcode in enumerate(top_cells): for TAG in final_results[cell_barcode]: if final_results[cell_barcode][TAG]: - umi_results_matrix[ordered_tags_map[TAG]["id"], i] = len( - final_results[cell_barcode][TAG] - ) - read_results_matrix[ordered_tags_map[TAG]["id"], i] = sum( - final_results[cell_barcode][TAG].values() - ) - return (umi_results_matrix, read_results_matrix) + if umi_counts: + if TAG == "unmapped": + continue + results_matrix[ordered_tags_map[TAG]["id"], i] = len( + final_results[cell_barcode][TAG] + ) + else: + results_matrix[ordered_tags_map[TAG]["id"], i] = sum( + final_results[cell_barcode][TAG].values() + ) + return results_matrix diff --git a/setup.py b/setup.py index 0e60ceb..5481d95 100644 --- a/setup.py +++ b/setup.py @@ -27,6 +27,7 @@ "pandas>=0.23.4", "pybktree==1.1", "cython>=0.29.17", + "pooch==1.1.1", ], python_requires=">=3.7", ) From e8c0ff515dea0ed02c48c4b8ecb67e12ad16b0eb Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 20 Sep 2020 17:27:44 +0200 Subject: [PATCH 18/77] fixed sprase output --- cite_seq_count/__main__.py | 96 +++++++++++++++++---------------- cite_seq_count/io.py | 21 ++++---- cite_seq_count/preprocessing.py | 15 ++++-- cite_seq_count/processing.py | 54 ++++++++++--------- setup.py | 2 +- tests/test_io.py | 42 ++++++++------- 6 files changed, 124 insertions(+), 106 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index cb672c1..659c992 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -49,9 +49,11 @@ def main(): # Load TAGs/ABs. ab_map = preprocessing.parse_tags_csv(args.tags) - ordered_tags_map, longest_tag_len = preprocessing.check_tags(ab_map, args.max_error) + named_tuples_tags_map, longest_tag_len = preprocessing.check_tags( + ab_map, args.max_error + ) named_tuples_tags_map = preprocessing.convert_to_named_tuple( - ordered_tags=ordered_tags_map + ordered_tags=named_tuples_tags_map ) # Identify input file(s) read1_paths, read2_paths = preprocessing.get_read_paths( @@ -157,10 +159,6 @@ def main(): for file_path in temp_files: os.remove(file_path) - # Select top cells - top_cells_tuple = umis_per_cell.most_common(args.expected_cells * 10) - top_cells = set([pair[0] for pair in top_cells_tuple]) - # Correct cell barcodes if args.bc_threshold != 0: if len(umis_per_cell) <= args.expected_cells: @@ -194,7 +192,6 @@ def main(): ) = processing.correct_cells_whitelist( final_results=final_results, umis_per_cell=umis_per_cell, - top_cells=top_cells, whitelist=whitelist, collapsing_threshold=args.bc_threshold, ab_map=named_tuples_tags_map, @@ -204,31 +201,37 @@ def main(): bcs_corrected = 0 # If given, use whitelist for top cells - # if whitelist: - # top_cells = whitelist - # # Add potential missing cell barcodes. - # for missing_cell in whitelist: - # if missing_cell in final_results: - # continue - # else: - # final_results[missing_cell] = dict() - # for TAG in named_tuples_tags_map: - # final_results[missing_cell][TAG.safe_name] = Counter() - # top_cells.add(missing_cell) - # else: - # Select top cells based on total umis per cell + top_cells_tuple = umis_per_cell.most_common(args.expected_cells * 10) + if whitelist: + # Add potential missing cell barcodes. + # for missing_cell in whitelist: + # if missing_cell in final_results: + # continue + # else: + # final_results[missing_cell] = dict() + # for TAG in named_tuples_tags_map: + # final_results[missing_cell][TAG.safe_name] = Counter() + # filtered_cells.add(missing_cell) + top_cells = set([pair[0] for pair in top_cells_tuple]) + filtered_cells = set() + for cell in top_cells: + if cell in whitelist: + filtered_cells.add(cell) + else: + # Select top cells based on total umis per cell + filtered_cells = set([pair[0] for pair in top_cells_tuple]) # Create sparse matrices for reads results read_results_matrix = processing.generate_sparse_matrices( final_results=final_results, - ordered_tags_map=ordered_tags_map, - top_cells=top_cells, + ordered_tags=named_tuples_tags_map, + filtered_cells=filtered_cells, ) # Write reads to file io.write_to_files( sparse_matrix=read_results_matrix, - top_cells=top_cells, - ordered_tags_map=ordered_tags_map, + filtered_cells=filtered_cells, + ordered_tags=named_tuples_tags_map, data_type="read", outfolder=args.outfolder, ) @@ -244,9 +247,9 @@ def main(): cells = {} n_cells = 0 num_chunks = 0 - - cell_batch_size = round(len(top_cells) / args.n_threads) + 1 - for cell in top_cells: + print("preparing UMI correction jobs") + cell_batch_size = round(len(filtered_cells) / args.n_threads) + 1 + for cell in filtered_cells: cells[cell] = final_results[cell] n_cells += 1 if n_cells % cell_batch_size == 0: @@ -300,47 +303,48 @@ def main(): if len(aberrant_cells) > 0: # Remove aberrant cells from the top cells for cell_barcode in aberrant_cells: - top_cells.remove(cell_barcode) + filtered_cells.remove(cell_barcode) # Create sparse aberrant cells matrix - (umi_aberrant_matrix, _) = processing.generate_sparse_matrices( + umi_aberrant_matrix = processing.generate_sparse_matrices( final_results=final_results, - ordered_tags_map=ordered_tags_map, - top_cells=aberrant_cells, + ordered_tags=named_tuples_tags_map, + filtered_cells=aberrant_cells, ) # Write uncorrected cells to dense output io.write_dense( sparse_matrix=umi_aberrant_matrix, - index=list(ordered_tags_map.keys()), + ordered_tags=named_tuples_tags_map, columns=aberrant_cells, outfolder=os.path.join(args.outfolder, "uncorrected_cells"), filename="dense_umis.tsv", ) - + named_tuples_tags_map.pop() umi_results_matrix = processing.generate_sparse_matrices( final_results=final_results, - ordered_tags_map=ordered_tags_map, - top_cells=top_cells, + ordered_tags=named_tuples_tags_map, + filtered_cells=filtered_cells, umi_counts=True, ) # Write umis to file io.write_to_files( sparse_matrix=umi_results_matrix, - top_cells=top_cells, - ordered_tags_map=ordered_tags_map, + filtered_cells=filtered_cells, + ordered_tags=named_tuples_tags_map, data_type="umi", outfolder=args.outfolder, ) # Write unmapped sequences - io.write_unmapped( - merged_no_match=merged_no_match, - top_unknowns=args.unknowns_top, - outfolder=args.outfolder, - filename=args.unmapped_file, - ) + if len(merged_no_match) > 0: + io.write_unmapped( + merged_no_match=merged_no_match, + top_unknowns=args.unknowns_top, + outfolder=args.outfolder, + filename=args.unmapped_file, + ) # Create report and write it to disk io.create_report( @@ -349,7 +353,7 @@ def main(): no_match=merged_no_match, version=argsparser.get_package_version(), start_time=start_time, - ordered_tags_map=ordered_tags_map, + ordered_tags=named_tuples_tags_map, umis_corrected=umis_corrected, bcs_corrected=bcs_corrected, bad_cells=aberrant_cells, @@ -364,8 +368,8 @@ def main(): print("Writing dense format output") io.write_dense( sparse_matrix=umi_results_matrix, - index=list(ordered_tags_map.keys()), - columns=top_cells, + ordered_tags=named_tuples_tags_map, + columns=filtered_cells, outfolder=args.outfolder, filename="dense_umis.tsv", ) diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index c8dae1a..266d5a0 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -13,13 +13,13 @@ from cite_seq_count import secondsToText -def write_to_files(sparse_matrix, top_cells, ordered_tags_map, data_type, outfolder): +def write_to_files(sparse_matrix, filtered_cells, ordered_tags, data_type, outfolder): """Write the umi and read sparse matrices to file in gzipped mtx format. Args: sparse_matrix (dok_matrix): Results in a sparse matrix. - top_cells (set): Set of cells that are selected for output. - ordered_tags_map (dict): Tags in order with indexes as values. + filtered_cells (set): Set of cells that are selected for output. + ordered_tags (dict): Tags in order with indexes as values. data_type (string): A string definning if the data is umi or read based. outfolder (string): Path to the output folder. """ @@ -27,14 +27,12 @@ def write_to_files(sparse_matrix, top_cells, ordered_tags_map, data_type, outfol os.makedirs(prefix, exist_ok=True) io.mmwrite(os.path.join(prefix, "matrix.mtx"), sparse_matrix) with gzip.open(os.path.join(prefix, "barcodes.tsv.gz"), "wb") as barcode_file: - for barcode in top_cells: + for barcode in filtered_cells: barcode_file.write("{}\n".format(barcode).encode()) with gzip.open(os.path.join(prefix, "features.tsv.gz"), "wb") as feature_file: - for feature in ordered_tags_map: + for feature in ordered_tags: feature_file.write( - "{}\t{}\n".format( - ordered_tags_map[feature]["sequence"], feature - ).encode() + "{}\t{}\n".format(feature.sequence, feature.name).encode() ) with open(os.path.join(prefix, "matrix.mtx"), "rb") as mtx_in: with gzip.open(os.path.join(prefix, "matrix.mtx") + ".gz", "wb") as mtx_gz: @@ -42,7 +40,7 @@ def write_to_files(sparse_matrix, top_cells, ordered_tags_map, data_type, outfol os.remove(os.path.join(prefix, "matrix.mtx")) -def write_dense(sparse_matrix, index, columns, outfolder, filename): +def write_dense(sparse_matrix, ordered_tags, columns, outfolder, filename): """ Writes a dense matrix in a csv format @@ -55,6 +53,9 @@ def write_dense(sparse_matrix, index, columns, outfolder, filename): """ prefix = os.path.join(outfolder) os.makedirs(prefix, exist_ok=True) + index = [] + for tag in ordered_tags: + index.append(tag.name) pandas_dense = pd.DataFrame(sparse_matrix.todense(), columns=columns, index=index) pandas_dense.to_csv(os.path.join(outfolder, filename), sep="\t") @@ -84,7 +85,7 @@ def create_report( no_match, version, start_time, - ordered_tags_map, + ordered_tags, umis_corrected, bcs_corrected, bad_cells, diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 3f5772d..1bd1907 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -139,19 +139,24 @@ def check_tags(tags, maximum_distance): ordered_tags = OrderedDict() longest_tag_len = 0 for i, tag_seq in enumerate(sorted(tags, key=len, reverse=True)): - ordered_tags[tags[tag_seq]] = {} - ordered_tags[tags[tag_seq]]["id"] = i - ordered_tags[tags[tag_seq]]["sequence"] = tag_seq + safe_name = sanitize_name(tags[tag_seq]) + ordered_tags[safe_name] = {} + ordered_tags[safe_name]["id"] = i + ordered_tags[safe_name]["sequence"] = tag_seq + ordered_tags[safe_name]["feature_name"] = tags[tag_seq] if len(tag_seq) > longest_tag_len: longest_tag_len = len(tag_seq) ordered_tags["unmapped"] = {} ordered_tags["unmapped"]["id"] = i + 1 ordered_tags["unmapped"]["sequence"] = "UNKNOWN" + ordered_tags["unmapped"]["feature_name"] = "unmapped" # If only one TAG is provided, then no distances to compare. if len(tags) == 1: ordered_tags["unmapped"] = {} ordered_tags["unmapped"]["id"] = 2 + ordered_tags["unmapped"]["sequence"] = "UNKNOWN" + ordered_tags["unmapped"]["feature_name"] = "unmapped" return (ordered_tags, longest_tag_len) offending_pairs = [] @@ -199,8 +204,8 @@ def convert_to_named_tuple(ordered_tags): for index, tag_name in enumerate(ordered_tags): tag_list.append( tag( - safe_name=sanitize_name(tag_name), - name=tag_name, + safe_name=tag_name, + name=ordered_tags[tag_name]["feature_name"], sequence=ordered_tags[tag_name]["sequence"], id=(index), ) diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 8543a2d..57849bc 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -42,17 +42,17 @@ def find_best_match(TAG_seq, tags, maximum_distance): Returns: best_match (string): The TAG name that will be used for counting. """ - best_match = "unmapped" + best_match = len(tags) - 1 best_score = maximum_distance for tag in tags: # pylint: disable=no-member score = Levenshtein.hamming(tag.sequence, TAG_seq[: len(tag.sequence)]) if score == 0: # Best possible match - return tag.name + return tag.id elif score <= best_score: best_score = score - best_match = tag.name + best_match = tag.id return best_match return best_match @@ -221,7 +221,7 @@ def correct_umis(umi_correction_input): Args: final_results (dict): Dict of dict of Counters with mapping results. collapsing_threshold (int): Max distance between umis. - top_cells (set): Set of cells to go through. + filtered_cells (set): Set of cells to go through. max_umis (int): Maximum UMIs to consider for one cluster. Returns: @@ -309,7 +309,7 @@ def collapse_cells(true_to_false, umis_per_cell, final_results, ab_map): if real_barcode not in final_results: final_results[real_barcode] = defaultdict() for TAG in ab_map: - final_results[real_barcode][TAG.safe_name] = Counter() + final_results[real_barcode][TAG.id] = Counter() for wrong_barcode in true_to_false[real_barcode]: temp = final_results.pop(wrong_barcode) @@ -370,7 +370,7 @@ def correct_cells( def correct_cells_whitelist( - final_results, umis_per_cell, whitelist, top_cells, collapsing_threshold, ab_map + final_results, umis_per_cell, whitelist, collapsing_threshold, ab_map ): """ Corrects cell barcodes. @@ -391,14 +391,14 @@ def correct_cells_whitelist( # pylint: disable=no-member barcode_tree = pybktree.BKTree(Levenshtein.hamming, whitelist) print("Generated barcode tree from whitelist") - + barcodes = set(final_results.keys()) print("Finding reference candidates") - print("Processing {:,} cell barcodes".format(len(top_cells))) + print("Processing {:,} cell barcodes".format(len(barcodes))) # Run with one process true_to_false = find_true_to_false_map( barcode_tree=barcode_tree, - cell_barcodes=top_cells, + cell_barcodes=barcodes, whitelist=whitelist, collapsing_threshold=collapsing_threshold, ) @@ -451,35 +451,37 @@ def find_true_to_false_map( def generate_sparse_matrices( - final_results, ordered_tags_map, top_cells, umi_counts=False + final_results, ordered_tags, filtered_cells, umi_counts=False ): """ Create two sparse matrices with umi and read counts. Args: final_results (dict): Results in a dict of dicts of Counters. - ordered_tags_map (dict): Tags in order with indexes as values. + ordered_tags (list): Ordered tags in a list of tuples. Returns: results_matrix (scipy.sparse.dok_matrix): UMI counts """ - print(ordered_tags_map) + unmapped_id = len(ordered_tags) results_matrix = sparse.dok_matrix( - (len(ordered_tags_map) + 1, len(top_cells)), dtype=int32 + (len(ordered_tags), len(filtered_cells)), dtype=int32 ) - for i, cell_barcode in enumerate(top_cells): - for TAG in final_results[cell_barcode]: - if final_results[cell_barcode][TAG]: - if umi_counts: - if TAG == "unmapped": - continue - results_matrix[ordered_tags_map[TAG]["id"], i] = len( - final_results[cell_barcode][TAG] - ) - else: - results_matrix[ordered_tags_map[TAG]["id"], i] = sum( - final_results[cell_barcode][TAG].values() - ) + # print(ordered_tags) + for i, cell_barcode in enumerate(filtered_cells): + if cell_barcode not in final_results: + continue + for TAG_id in final_results[cell_barcode]: + # if TAG_id in final_results[cell_barcode]: + if umi_counts: + + if TAG_id == unmapped_id: + continue + results_matrix[TAG_id, i] = len(final_results[cell_barcode][TAG_id]) + else: + results_matrix[TAG_id, i] = sum( + final_results[cell_barcode][TAG_id].values() + ) return results_matrix diff --git a/setup.py b/setup.py index 5481d95..95f6539 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setuptools.setup( name="CITE-seq-Count", - version="1.4.3", + version="1.5.0", author="Roelli Patrick", author_email="patrick.roelli@gmail.com", description="A python package to map reads from CITE-seq or hashing data for single cell experiments", diff --git a/tests/test_io.py b/tests/test_io.py index f71e027..48b1960 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -2,32 +2,38 @@ from cite_seq_count import io - @pytest.fixture def data(): from collections import OrderedDict from scipy import sparse - test_matrix = sparse.dok_matrix((4,2)) - test_matrix[1,1] = 1 + + test_matrix = sparse.dok_matrix((4, 2)) + test_matrix[1, 1] = 1 pytest.sparse_matrix = test_matrix - pytest.top_cells = set(['ACTGTTTTATTGGCCT','TTCATAAGGTAGGGAT']) - pytest.ordered_tags_map = OrderedDict({ - 'test3':{'id':0, 'sequence': 'CGTA'}, - 'test2':{'id':1, 'sequence': 'CGTA'}, - 'test1': {'id':3, 'sequence': 'CGTA'}, - 'unmapped': {'id':4, 'sequence': 'CGTA'} - }) - pytest.data_type = 'umi' - pytest.outfolder = 'tests/test_data/' + pytest.filtered_cells = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) + pytest.ordered_tags_map = OrderedDict( + { + "test3": {"id": 0, "sequence": "CGTA"}, + "test2": {"id": 1, "sequence": "CGTA"}, + "test1": {"id": 3, "sequence": "CGTA"}, + "unmapped": {"id": 4, "sequence": "CGTA"}, + } + ) + pytest.data_type = "umi" + pytest.outfolder = "tests/test_data/" + def test_write_to_files(data, tmpdir): import gzip import scipy - io.write_to_files(pytest.sparse_matrix, - pytest.top_cells, + + io.write_to_files( + pytest.sparse_matrix, + pytest.filtered_cells, pytest.ordered_tags_map, pytest.data_type, - tmpdir) - file = tmpdir.join('umi_count/matrix.mtx.gz') - with gzip.open(file, 'rb') as mtx_file: - assert isinstance(scipy.io.mmread(mtx_file) ,scipy.sparse.coo.coo_matrix) + tmpdir, + ) + file = tmpdir.join("umi_count/matrix.mtx.gz") + with gzip.open(file, "rb") as mtx_file: + assert isinstance(scipy.io.mmread(mtx_file), scipy.sparse.coo.coo_matrix) From 343f8a236e13db84b6e67ac93335862a0a83dfde Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 20 Sep 2020 17:48:48 +0200 Subject: [PATCH 19/77] fixed debugging --- CHANGELOG.md | 8 +++++++- cite_seq_count/__main__.py | 39 ++++++++++++++++++------------------ cite_seq_count/io.py | 10 +++++---- cite_seq_count/processing.py | 2 +- 4 files changed, 34 insertions(+), 25 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 98c955d..d9eeaa7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,9 +8,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/). ### Added - `CITE-seq-Count` is now Compatible with trimmed data. There is a new `too_short` category in the `run_report.yaml` that will let you know how much you lost due to reads being too short. #123 - - UMI correction is now also parallelized and will use the threads proposed. + - UMI correction is now also parallelized and will use the threads given. - Added a check at the end of the mapping. If more than 99% of the reads are unmapped, CITE-seq-Count will exit. - (BETA) New functionnality that will fetch the chemistry definition from a remote repo to simplify usage and reduce human errors. + - New step just after mapping that will exit if more than 99% of the reads are unmapped. ### Changed - The `features.csv` now has different columns for the tag name and the tag sequence. This keeps the relevant information @@ -20,6 +21,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - There are new options now for parallel computing. `--chunk_size` Determines how many reads will be read per chunk. #99 - `--sliding-window` now only checks for exact matches. - Added cython dependency based on issue #117 + - The main results dict will now use an `int` as keys reducing memory footprint. + - Fixed the issue #92 with using `--bc_collapsing_dist 0`. + + ### Removed + - Unmmapped reads are not umi corrected anymore reducing running time and memory usage. ## [1.4.3] - 05.10.2019 diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 659c992..5c905c6 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -49,12 +49,8 @@ def main(): # Load TAGs/ABs. ab_map = preprocessing.parse_tags_csv(args.tags) - named_tuples_tags_map, longest_tag_len = preprocessing.check_tags( - ab_map, args.max_error - ) - named_tuples_tags_map = preprocessing.convert_to_named_tuple( - ordered_tags=named_tuples_tags_map - ) + ordered_tags, longest_tag_len = preprocessing.check_tags(ab_map, args.max_error) + ordered_tags = preprocessing.convert_to_named_tuple(ordered_tags=ordered_tags) # Identify input file(s) read1_paths, read2_paths = preprocessing.get_read_paths( args.read1_path, args.read2_path @@ -96,8 +92,10 @@ def main(): if args.sliding_window: R2_max_length = read2_lengths[0] + maximum_distance = 0 else: R2_max_length = longest_tag_len + maximum_distance = args.max_error ( input_queue, @@ -112,7 +110,8 @@ def main(): R2_max_length=R2_max_length, n_reads_to_chunk=n_reads, chemistry_def=chemistry_def, - named_tuples_tags_map=named_tuples_tags_map, + ordered_tags=ordered_tags, + maximum_distance=maximum_distance, ) # Initialize the counts dicts that will be generated from each input fastq pair final_results = defaultdict(lambda: defaultdict(Counter)) @@ -182,7 +181,7 @@ def main(): umis_per_cell=umis_per_cell, expected_cells=args.expected_cells, collapsing_threshold=args.bc_threshold, - ab_map=named_tuples_tags_map, + ab_map=ordered_tags, ) else: ( @@ -194,7 +193,7 @@ def main(): umis_per_cell=umis_per_cell, whitelist=whitelist, collapsing_threshold=args.bc_threshold, - ab_map=named_tuples_tags_map, + ab_map=ordered_tags, ) else: print("Skipping cell barcode correction") @@ -209,7 +208,7 @@ def main(): # continue # else: # final_results[missing_cell] = dict() - # for TAG in named_tuples_tags_map: + # for TAG in ordered_tags: # final_results[missing_cell][TAG.safe_name] = Counter() # filtered_cells.add(missing_cell) top_cells = set([pair[0] for pair in top_cells_tuple]) @@ -224,14 +223,14 @@ def main(): # Create sparse matrices for reads results read_results_matrix = processing.generate_sparse_matrices( final_results=final_results, - ordered_tags=named_tuples_tags_map, + ordered_tags=ordered_tags, filtered_cells=filtered_cells, ) # Write reads to file io.write_to_files( sparse_matrix=read_results_matrix, filtered_cells=filtered_cells, - ordered_tags=named_tuples_tags_map, + ordered_tags=ordered_tags, data_type="read", outfolder=args.outfolder, ) @@ -308,22 +307,23 @@ def main(): # Create sparse aberrant cells matrix umi_aberrant_matrix = processing.generate_sparse_matrices( final_results=final_results, - ordered_tags=named_tuples_tags_map, + ordered_tags=ordered_tags, filtered_cells=aberrant_cells, ) # Write uncorrected cells to dense output io.write_dense( sparse_matrix=umi_aberrant_matrix, - ordered_tags=named_tuples_tags_map, + ordered_tags=ordered_tags, columns=aberrant_cells, outfolder=os.path.join(args.outfolder, "uncorrected_cells"), filename="dense_umis.tsv", ) - named_tuples_tags_map.pop() + # delete the last element (unmapped) + ordered_tags.pop() umi_results_matrix = processing.generate_sparse_matrices( final_results=final_results, - ordered_tags=named_tuples_tags_map, + ordered_tags=ordered_tags, filtered_cells=filtered_cells, umi_counts=True, ) @@ -332,7 +332,7 @@ def main(): io.write_to_files( sparse_matrix=umi_results_matrix, filtered_cells=filtered_cells, - ordered_tags=named_tuples_tags_map, + ordered_tags=ordered_tags, data_type="umi", outfolder=args.outfolder, ) @@ -353,7 +353,7 @@ def main(): no_match=merged_no_match, version=argsparser.get_package_version(), start_time=start_time, - ordered_tags=named_tuples_tags_map, + ordered_tags=ordered_tags, umis_corrected=umis_corrected, bcs_corrected=bcs_corrected, bad_cells=aberrant_cells, @@ -361,6 +361,7 @@ def main(): R2_too_short=R2_too_short, args=args, chemistry_def=chemistry_def, + maximum_distance=maximum_distance, ) # Write dense matrix to disk if requested @@ -368,7 +369,7 @@ def main(): print("Writing dense format output") io.write_dense( sparse_matrix=umi_results_matrix, - ordered_tags=named_tuples_tags_map, + ordered_tags=ordered_tags, columns=filtered_cells, outfolder=args.outfolder, filename="dense_umis.tsv", diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 266d5a0..4b54681 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -93,6 +93,7 @@ def create_report( R2_too_short, args, chemistry_def, + maximum_distance, ): """ Creates a report with details about the run in a yaml format. @@ -164,7 +165,7 @@ def create_report( chemistry_def.umi_barcode_start, chemistry_def.umi_barcode_end, args.expected_cells, - args.max_error, + maximum_distance, chemistry_def.R2_trim_start, ) ) @@ -177,7 +178,8 @@ def write_chunks_to_disk( R2_max_length, n_reads_to_chunk, chemistry_def, - named_tuples_tags_map, + ordered_tags, + maximum_distance, ): """ """ @@ -258,9 +260,9 @@ def write_chunks_to_disk( input_queue.append( mapping_input( filename=temp_filename, - tags=named_tuples_tags_map, + tags=ordered_tags, debug=args.debug, - maximum_distance=args.max_error, + maximum_distance=maximum_distance, sliding_window=args.sliding_window, ) ) diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 57849bc..4b6d04b 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -154,7 +154,7 @@ def map_reads(mapping_input): len(cell_barcode), len(UMI), len(read2), - best_match, + tags[best_match].id, ) ) sys.stdout.flush() From f0598b513708deaa544e2eb2eaf26665171d95ab Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Thu, 10 Dec 2020 08:08:52 +0100 Subject: [PATCH 20/77] correction in README --- CHANGELOG.md | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index d9eeaa7..da5eb7a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,6 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - UMI correction is now also parallelized and will use the threads given. - Added a check at the end of the mapping. If more than 99% of the reads are unmapped, CITE-seq-Count will exit. - (BETA) New functionnality that will fetch the chemistry definition from a remote repo to simplify usage and reduce human errors. - - New step just after mapping that will exit if more than 99% of the reads are unmapped. ### Changed - The `features.csv` now has different columns for the tag name and the tag sequence. This keeps the relevant information @@ -23,9 +22,12 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Added cython dependency based on issue #117 - The main results dict will now use an `int` as keys reducing memory footprint. - Fixed the issue #92 with using `--bc_collapsing_dist 0`. + - Fixed issue #122 and now properly checks number of files. + - Fixed the error in the documentation pointed by issue #132. + - The report is now a proper yaml file. Issue #133 ### Removed - - Unmmapped reads are not umi corrected anymore reducing running time and memory usage. + - Unmmapped reads are not umi corrected anymore reducing run time and memory usage. ## [1.4.3] - 05.10.2019 From 729db14f90f48ad07482f06f265d60565e7a1d7a Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Thu, 10 Dec 2020 08:48:42 +0100 Subject: [PATCH 21/77] other merge conflicts --- cite_seq_count/preprocessing.py | 114 ++++++++------------------------ 1 file changed, 28 insertions(+), 86 deletions(-) diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 548d050..d1142d9 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -8,54 +8,9 @@ from collections import OrderedDict from collections import namedtuple from itertools import combinations -<<<<<<< HEAD -from itertools import islice - -======= from itertools import islice -def get_indexes(start_index, chunk_size, nth): - """ - Creates indexes from a reference index, a chunk size an nth number - - Args: - start_index (int): first position - chunk_size (int): Chunk size - nth (int): The nth number - - Returns: - list: First and last position of indexes - """ - start_index = nth * chunk_size - stop_index = chunk_size + nth * chunk_size - return [start_index, stop_index] - - -def chunk_reads(n_reads, n): - """ - Creates a list of indexes for the islice iterator from the map_reads function. - - Args: - n_reads (int): Number of reads to split - n (int): How many buckets for the split. - Returns: - indexes (list(list)): Each entry contains the first and the last index for a read. - """ - indexes = list() - if n_reads % n == 0: - chunk_size = int(n_reads / n) - rest = 0 - else: - chunk_size = floor(n_reads / n) - rest = n_reads - (n * chunk_size) - for i in range(0, n): - indexes.append(get_indexes(i, chunk_size, i)) - indexes[-1][1] += rest - return indexes - ->>>>>>> 1.4.4 - def parse_whitelist_csv(filename, barcode_length, collapsing_threshold): """Reads white-listed barcodes from a CSV file. @@ -174,32 +129,23 @@ def check_tags(tags, maximum_distance): """ ordered_tags = OrderedDict() -<<<<<<< HEAD longest_tag_len = 0 for i, tag_seq in enumerate(sorted(tags, key=len, reverse=True)): ordered_tags[tags[tag_seq]] = {} - ordered_tags[tags[tag_seq]]['id'] = i - ordered_tags[tags[tag_seq]]['sequence'] = tag_seq + ordered_tags[tags[tag_seq]]["id"] = i + ordered_tags[tags[tag_seq]]["sequence"] = tag_seq if len(tag_seq) > longest_tag_len: longest_tag_len = len(tag_seq) - - ordered_tags['unmapped'] = {} - ordered_tags['unmapped']['id'] = i + 1 - ordered_tags['unmapped']['sequence'] = 'UNKNOWN' - # If only one TAG is provided, then no distances to compare. - if (len(tags) == 1): - ordered_tags['unmapped'] = {} - ordered_tags['unmapped']['id'] = 2 - return(ordered_tags, longest_tag_len) - -======= - for tag in sorted(tags, key=len, reverse=True): - ordered_tags[tag] = tags[tag] + "-" + tag + + ordered_tags["unmapped"] = {} + ordered_tags["unmapped"]["id"] = i + 1 + ordered_tags["unmapped"]["sequence"] = "UNKNOWN" # If only one TAG is provided, then no distances to compare. if len(tags) == 1: - return ordered_tags + ordered_tags["unmapped"] = {} + ordered_tags["unmapped"]["id"] = 2 + return (ordered_tags, longest_tag_len) ->>>>>>> 1.4.4 offending_pairs = [] for a, b in combinations(tags.keys(), 2): distance = Levenshtein.distance(a, b) @@ -228,26 +174,31 @@ def check_tags(tags, maximum_distance): tag1=pair[0], tag2=pair[1], distance=pair[2] ) ) -<<<<<<< HEAD - sys.exit('Exiting the application.\n') -======= sys.exit("Exiting the application.\n") - return ordered_tags ->>>>>>> 1.4.4 - return(ordered_tags, longest_tag_len) + return (ordered_tags, longest_tag_len) + def sanitize_name(string): - return(string.replace('-', '_')) + return string.replace("-", "_") + def convert_to_named_tuple(ordered_tags): - #all_tags = namedtuple('all_tags', [sanitize_name(tag) for tag in ordered_tags.keys()]) - tag = namedtuple('tag', ['safe_name','name','sequence', 'id']) + # all_tags = namedtuple('all_tags', [sanitize_name(tag) for tag in ordered_tags.keys()]) + tag = namedtuple("tag", ["safe_name", "name", "sequence", "id"]) tag_list = [] for index, tag_name in enumerate(ordered_tags): - tag_list.append(tag(safe_name=sanitize_name(tag_name), name=tag_name, sequence=ordered_tags[tag_name]['sequence'], id=(index))) - #all_tags[index+1]=ordered_tags[tag_name]['sequence'] - return(tag_list) + tag_list.append( + tag( + safe_name=sanitize_name(tag_name), + name=tag_name, + sequence=ordered_tags[tag_name]["sequence"], + id=(index), + ) + ) + # all_tags[index+1]=ordered_tags[tag_name]['sequence'] + return tag_list + def get_read_length(filename): """Check wether SEQUENCE lengths are consistent in a FASTQ file and return @@ -262,31 +213,22 @@ def get_read_length(filename): """ with gzip.open(filename, "r") as fastq_file: secondlines = islice(fastq_file, 1, 1000, 4) - #temp_length = len(next(secondlines).rstrip()) + # temp_length = len(next(secondlines).rstrip()) for sequence in secondlines: read_length = len(sequence.rstrip()) -<<<<<<< HEAD # if (temp_length != read_length): # sys.exit( # '[ERROR] Sequence length in {} is not consistent. Please, trim all ' # 'sequences at the same length.\n' # 'Exiting the application.\n'.format(filename) # ) - return(read_length) -======= - if temp_length != read_length: - sys.exit( - "[ERROR] Sequence length in {} is not consistent. Please, trim all " - "sequences at the same length.\n" - "Exiting the application.\n".format(filename) - ) return read_length ->>>>>>> 1.4.4 def get_chunk_strategy(read1_paths, read2_paths, chunk_size): pass + def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last): """Check Read1 length against CELL and UMI barcodes length. From 82764946139feba5e6b089ea16e72bd423ff20b0 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Thu, 10 Dec 2020 20:56:59 +0100 Subject: [PATCH 22/77] resolved more conflicts --- cite_seq_count/preprocessing.py | 43 --------------------------------- 1 file changed, 43 deletions(-) diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 256f1c4..230c7e6 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -29,7 +29,6 @@ def parse_whitelist_csv(filename, barcode_length, collapsing_threshold): """ STRIP_CHARS = '"0123456789- \t\n' -<<<<<<< HEAD cell_pattern = regex.compile(r"[ATGC]{{{}}}".format(barcode_length)) if filename.endswith(".gz"): @@ -44,16 +43,6 @@ def parse_whitelist_csv(filename, barcode_length, collapsing_threshold): if (len(row[0].strip(STRIP_CHARS)) == barcode_length) ] -======= - with open(filename, mode="r") as csv_file: - csv_reader = csv.reader(csv_file) - cell_pattern = regex.compile(r"[ATGC]{{{}}}".format(barcode_length)) - whitelist = [ - row[0].strip(STRIP_CHARS) - for row in csv_reader - if (len(row[0].strip(STRIP_CHARS)) == barcode_length) - ] ->>>>>>> develop for cell_barcode in whitelist: if not cell_pattern.match(cell_barcode): sys.exit( @@ -150,36 +139,24 @@ def check_tags(tags, maximum_distance): ordered_tags = OrderedDict() longest_tag_len = 0 for i, tag_seq in enumerate(sorted(tags, key=len, reverse=True)): -<<<<<<< HEAD safe_name = sanitize_name(tags[tag_seq]) ordered_tags[safe_name] = {} ordered_tags[safe_name]["id"] = i ordered_tags[safe_name]["sequence"] = tag_seq ordered_tags[safe_name]["feature_name"] = tags[tag_seq] -======= - ordered_tags[tags[tag_seq]] = {} - ordered_tags[tags[tag_seq]]["id"] = i - ordered_tags[tags[tag_seq]]["sequence"] = tag_seq ->>>>>>> develop if len(tag_seq) > longest_tag_len: longest_tag_len = len(tag_seq) ordered_tags["unmapped"] = {} ordered_tags["unmapped"]["id"] = i + 1 ordered_tags["unmapped"]["sequence"] = "UNKNOWN" -<<<<<<< HEAD ordered_tags["unmapped"]["feature_name"] = "unmapped" -======= ->>>>>>> develop # If only one TAG is provided, then no distances to compare. if len(tags) == 1: ordered_tags["unmapped"] = {} ordered_tags["unmapped"]["id"] = 2 -<<<<<<< HEAD ordered_tags["unmapped"]["sequence"] = "UNKNOWN" ordered_tags["unmapped"]["feature_name"] = "unmapped" -======= ->>>>>>> develop return (ordered_tags, longest_tag_len) offending_pairs = [] @@ -227,13 +204,8 @@ def convert_to_named_tuple(ordered_tags): for index, tag_name in enumerate(ordered_tags): tag_list.append( tag( -<<<<<<< HEAD safe_name=tag_name, name=ordered_tags[tag_name]["feature_name"], -======= - safe_name=sanitize_name(tag_name), - name=tag_name, ->>>>>>> develop sequence=ordered_tags[tag_name]["sequence"], id=(index), ) @@ -300,11 +272,6 @@ def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last) read1_length, barcode_umi_length ) ) -<<<<<<< HEAD -======= - - return (barcode_slice, umi_slice, barcode_umi_length) ->>>>>>> develop def blocks(files, size=65536): @@ -338,11 +305,7 @@ def get_n_lines(file_path): Returns: n_lines (int): Number of lines in the file """ -<<<<<<< HEAD print("Counting number of reads in file {}".format(file_path)) -======= - print("Counting number of reads") ->>>>>>> develop with gzip.open(file_path, "rt", encoding="utf-8", errors="ignore") as f: n_lines = sum(bl.count("\n") for bl in blocks(f)) if n_lines % 4 != 0: @@ -370,13 +333,7 @@ def get_read_paths(read1_path, read2_path): if len(_read1_path) != len(_read2_path): sys.exit( "Unequal number of read1 ({}) and read2({}) files provided" -<<<<<<< HEAD - "\n Exiting".format(len(read1_path), len(read2_path)) - ) - return (_read1_path, _read2_path) -======= "\n Exiting".format(len(_read1_path), len(_read2_path)) ) return (_read1_path, _read2_path) ->>>>>>> develop From 1b604071782a025005bd3bc0eaf6b26fdd671ee0 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Mon, 14 Dec 2020 21:55:49 +0100 Subject: [PATCH 23/77] rewrote all preprocssing tests and got rid of step for tags --- CHANGELOG.md | 4 + cite_seq_count/__main__.py | 2 +- cite_seq_count/argsparser.py | 6 +- cite_seq_count/chemistry.py | 6 +- cite_seq_count/preprocessing.py | 126 ++++++----------- cite_seq_count/processing.py | 4 - docs/docs/Running-the-script.md | 11 +- tests/test_data/tags/fail/false_sequence.csv | 10 ++ tests/test_data/tags/fail/missing_entry.csv | 10 ++ tests/test_data/tags/fail/wrong_header.csv | 10 ++ tests/test_data/tags/{ => pass}/correct.csv | 1 + tests/test_data/tags/pass/correct_2.csv | 10 ++ tests/test_preprocessing.py | 137 ++++++++++++------- 13 files changed, 186 insertions(+), 151 deletions(-) create mode 100644 tests/test_data/tags/fail/false_sequence.csv create mode 100644 tests/test_data/tags/fail/missing_entry.csv create mode 100644 tests/test_data/tags/fail/wrong_header.csv rename tests/test_data/tags/{ => pass}/correct.csv (92%) create mode 100644 tests/test_data/tags/pass/correct_2.csv diff --git a/CHANGELOG.md b/CHANGELOG.md index da5eb7a..c62318e 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -25,6 +25,10 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Fixed issue #122 and now properly checks number of files. - Fixed the error in the documentation pointed by issue #132. - The report is now a proper yaml file. Issue #133 + - Removed distance checking on the whitelist because too slow for long whitelists. + - Tags csv file now requires a header with at least "sequence" and "feature_name". + - Updated tags file parsing to make it more reliable. + - Added new tests to help out contributions. ### Removed - Unmmapped reads are not umi corrected anymore reducing run time and memory usage. diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 5c905c6..d7cdc0b 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -50,7 +50,7 @@ def main(): # Load TAGs/ABs. ab_map = preprocessing.parse_tags_csv(args.tags) ordered_tags, longest_tag_len = preprocessing.check_tags(ab_map, args.max_error) - ordered_tags = preprocessing.convert_to_named_tuple(ordered_tags=ordered_tags) + # ordered_tags = preprocessing.convert_to_named_tuple(ordered_tags=ordered_tags) # Identify input file(s) read1_paths, read2_paths = preprocessing.get_read_paths( args.read1_path, args.read2_path diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 47a986e..034d9e9 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -73,9 +73,11 @@ def get_args(): help=( "The path to the csv file containing the antibody\n" "barcodes as well as their respective names.\n\n" + "Requires feature_name and sequence in the header\n\n" "Example of an antibody barcode file structure:\n\n" - "\tATGCGA,First_tag_name\n" - "\tGTCATG,Second_tag_name" + "\tfeature_name,sequence\n" + "\tFirst_tag_name,ATGCGA\n" + "\tSecond_tag_name,GTCATG" ), ) # BARCODES group. diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index 3df59ff..f279e49 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -139,21 +139,19 @@ def create_chemistry_definition(args): def setup_chemistry(args): if args.chemistry: chemistry_def = get_chemistry_definition(args.chemistry) - (whitelist, args.bc_threshold) = preprocessing.parse_whitelist_csv( + whitelist = preprocessing.parse_whitelist_csv( filename=chemistry_def.whitelist_path, barcode_length=chemistry_def.cell_barcode_end - chemistry_def.cell_barcode_start + 1, - collapsing_threshold=args.bc_threshold, ) else: chemistry_def = create_chemistry_definition(args) if args.whitelist: print("Loading whitelist") - (whitelist, args.bc_threshold) = preprocessing.parse_whitelist_csv( + whitelist = preprocessing.parse_whitelist_csv( filename=args.whitelist, barcode_length=args.cb_last - args.cb_first + 1, - collapsing_threshold=args.bc_threshold, ) else: whitelist = False diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 230c7e6..72b1c46 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -12,7 +12,7 @@ from itertools import islice -def parse_whitelist_csv(filename, barcode_length, collapsing_threshold): +def parse_whitelist_csv(filename, barcode_length): """Reads white-listed barcodes from a CSV file. The function accepts plain barcodes or even 10X style barcodes with the @@ -21,11 +21,9 @@ def parse_whitelist_csv(filename, barcode_length, collapsing_threshold): Args: filename (str): Whitelist barcode file. barcode_length (int): Length of the expected barcodes. - collapsing_threshold (int): Maximum distance to collapse cell barcodes. Returns: set: The set of white-listed barcodes. - int: Collasping threshold """ STRIP_CHARS = '"0123456789- \t\n' @@ -50,52 +48,20 @@ def parse_whitelist_csv(filename, barcode_length, collapsing_threshold): cell_barcode ) ) - # collapsing_threshold=test_cell_distances(whitelist, collapsing_threshold) if len(whitelist) == 0: sys.exit( "Please check cell barcode indexes -cbs, -cbl because none of the given whitelist is valid." ) - return (set(whitelist), collapsing_threshold) - - -def test_cell_distances(whitelist, collapsing_threshold): - """Tests cell barcode distances to validate provided cell barcode collapsing threshold - - Function needs the given whitelist as well as the threshold. - If the value is too high, it will rerun until an acceptable value is found. - - Args: - whitelist (set): Whitelist barcode set - collapsing_threshold (int): Value of threshold - - Returns: - collapsing_threshold (int): Valid threshold - """ - ok = False - while not ok: - print( - "Testing cell barcode collapsing threshold of {}".format( - collapsing_threshold - ) - ) - all_comb = combinations(whitelist, 2) - for comb in all_comb: - # pylint: disable=no-member - if Levenshtein.hamming(comb[0], comb[1]) <= collapsing_threshold: - collapsing_threshold -= 1 - print("Value is too high, reducing it by 1") - break - else: - ok = True - print("Using {} for cell barcode collapsing threshold".format(collapsing_threshold)) - return collapsing_threshold + return set(whitelist) def parse_tags_csv(filename): - """Reads the TAGs from a CSV file. + """Reads the TAGs from a CSV file. Checks if sequences are made of ATGC - The expected file format (no header) is: TAG,TAG_NAME. + The expected file format has a header with "sequence" and "feature_name". + Order doesn't matter. e.g. file content + sequence,feature_name GTCAACTCTTTAGCG,Hashtag_1 TGATGGCCTATTGGG,Hashtag_2 TTCCGCCTCTCTTTG,Hashtag_3 @@ -104,30 +70,46 @@ def parse_tags_csv(filename): filename (str): TAGs file. Returns: - dict: A dictionary containing the TAGs and their names. + dict: A dictionary containing using sequences as keys and names as values. """ + REQUIRED_HEADER = ["sequence", "feature_name"] + atgc_test = regex.compile("^[ATGC]{1,}$") with open(filename, mode="r") as csv_file: csv_reader = csv.reader(csv_file) tags = {} - for row in csv_reader: - tags[row[0].strip()] = row[1].strip() + header = next(csv_reader) + set_dif = set(REQUIRED_HEADER) - set(header) + if len(set_dif) != 0: + raise SystemExit( + "The header is missing {}. Exiting".format(",".join(list(set_dif))) + ) + sequence_id = header.index("sequence") + feature_id = header.index("feature_name") + for i, row in enumerate(csv_reader): + sequence = row[sequence_id].strip() + + if not regex.match(atgc_test, sequence): + raise SystemExit( + "Sequence {} on line {} is not only composed of ATGC. Exiting".format( + sequence, i + ) + ) + tags[sequence] = row[feature_id].strip() return tags def check_tags(tags, maximum_distance): """Evaluates the distance between the TAGs based on the `maximum distance` argument provided. - - Additionally, it adds the barcode to the name of the TAG circumventing the - need of having to share the mapping of the antibody and the barcode. The output will have the keys sorted by TAG length (longer first). This way, longer barcodes will be evaluated first. + Adds unmapped category as well. Args: - tags (dict): A dictionary with the TAGs + TAG Names. - maximum_distance (int): The maximum Levenshtein distance allowed + tags (dict): A dictionary with TAG sequences as keys and names as values. + maximum_distance (int): The minimum Levenshtein distance allowed between two TAGs. Returns: @@ -136,44 +118,30 @@ def check_tags(tags, maximum_distance): int: the length of the longest TAG """ - ordered_tags = OrderedDict() + tag = namedtuple("tag", ["name", "sequence", "id"]) longest_tag_len = 0 + seq_list = [] + tag_list = [] for i, tag_seq in enumerate(sorted(tags, key=len, reverse=True)): safe_name = sanitize_name(tags[tag_seq]) - ordered_tags[safe_name] = {} - ordered_tags[safe_name]["id"] = i - ordered_tags[safe_name]["sequence"] = tag_seq - ordered_tags[safe_name]["feature_name"] = tags[tag_seq] + + # for index, tag_name in enumerate(ordered_tags): + tag_list.append(tag(name=safe_name, sequence=tag_seq, id=i,)) if len(tag_seq) > longest_tag_len: longest_tag_len = len(tag_seq) - - ordered_tags["unmapped"] = {} - ordered_tags["unmapped"]["id"] = i + 1 - ordered_tags["unmapped"]["sequence"] = "UNKNOWN" - ordered_tags["unmapped"]["feature_name"] = "unmapped" + seq_list.append(tag_seq) + tag_list.append(tag(name="unmapped", sequence="UNKNOWN", id=i + 1,)) # If only one TAG is provided, then no distances to compare. if len(tags) == 1: - ordered_tags["unmapped"] = {} - ordered_tags["unmapped"]["id"] = 2 - ordered_tags["unmapped"]["sequence"] = "UNKNOWN" - ordered_tags["unmapped"]["feature_name"] = "unmapped" - return (ordered_tags, longest_tag_len) + return (tag_list, longest_tag_len) + # Check if the distance is big enoughbetween tags offending_pairs = [] - for a, b in combinations(tags.keys(), 2): + for a, b in combinations(seq_list, 2): # pylint: disable=no-member distance = Levenshtein.distance(a, b) if distance <= (maximum_distance - 1): offending_pairs.append([a, b, distance]) - DNA_pattern = regex.compile("^[ATGC]*$") - for tag in tags: - if not DNA_pattern.match(tag): - print( - "This tag {} is not only composed of ATGC bases.\nPlease check your tags file".format( - tag - ) - ) - sys.exit("Exiting the application.\n") # If offending pairs are found, print them all. if offending_pairs: print( @@ -190,7 +158,7 @@ def check_tags(tags, maximum_distance): ) sys.exit("Exiting the application.\n") - return (ordered_tags, longest_tag_len) + return (tag_list, longest_tag_len) def sanitize_name(string): @@ -199,12 +167,11 @@ def sanitize_name(string): def convert_to_named_tuple(ordered_tags): # all_tags = namedtuple('all_tags', [sanitize_name(tag) for tag in ordered_tags.keys()]) - tag = namedtuple("tag", ["safe_name", "name", "sequence", "id"]) + tag = namedtuple("tag", ["name", "sequence", "id"]) tag_list = [] for index, tag_name in enumerate(ordered_tags): tag_list.append( tag( - safe_name=tag_name, name=ordered_tags[tag_name]["feature_name"], sequence=ordered_tags[tag_name]["sequence"], id=(index), @@ -237,11 +204,6 @@ def get_read_length(filename): # 'Exiting the application.\n'.format(filename) # ) return read_length -<<<<<<< HEAD -======= - ->>>>>>> develop - def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last): @@ -259,7 +221,7 @@ def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last) barcode_umi_length = barcode_length + umi_length if barcode_umi_length > read1_length: - sys.exit( + raise SystemExit( "[ERROR] Read1 length is shorter than the option you are using for " "Cell and UMI barcodes length. Please, check your options and rerun.\n\n" "Exiting the application.\n" diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 1b3ae17..e62f385 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -112,10 +112,6 @@ def map_reads(mapping_input): no_match = Counter() n = 1 t = time.time() -<<<<<<< HEAD -======= - ->>>>>>> develop # Progress info with open(filename, "r") as input_file: reads = csv.reader(input_file) diff --git a/docs/docs/Running-the-script.md b/docs/docs/Running-the-script.md index 61596d7..032e7de 100644 --- a/docs/docs/Running-the-script.md +++ b/docs/docs/Running-the-script.md @@ -39,11 +39,12 @@ You can run tags of different length together. ### Antibody barcodes structure: -``` -ATGCGA,First_tag_name -GTCATG,Second_tag_name -GCTAGTCGTACGA,Third_tag_name -GCTAGGTGTCGTA,Forth_tag_name +``` +feature_name,sequence +First_tag_name,ATGCGA +Second_tag_name,GTCATG +Third_tag_name,GCTAGTCGTACGA +Forth_tag_name,GCTAGGTGTCGTA ``` *IMPORTANT*: You need to provide only the variable region of the TAG in the tags.csv. Please refer to the following examples. diff --git a/tests/test_data/tags/fail/false_sequence.csv b/tests/test_data/tags/fail/false_sequence.csv new file mode 100644 index 0000000..65907e5 --- /dev/null +++ b/tests/test_data/tags/fail/false_sequence.csv @@ -0,0 +1,10 @@ +sequence,feature_name +AGGACCATCCAA,CITE_LEN_12_1 +ACATGTTACCGT,CITE_LEN_12_2 +AGCTTACTATCC,CITE_LEN_12_3 +TCGATAATGCGAGTACAA,CITE_LEN_18_1 +GAGGCTGAGCTAGCTAGT,CITE_LEN_18_2 +GGCTGATGCTGACTGCTA,CITE_LEN_18_3 +TGTGACGTATTGCTAGCTAG,CITE_LEN_20_1 +ACTGTCTNACGGGTCAGTGC,CITE_LEN_20_2 +TATCACATCGGTGGATCCAT,CITE_LEN_20_3 \ No newline at end of file diff --git a/tests/test_data/tags/fail/missing_entry.csv b/tests/test_data/tags/fail/missing_entry.csv new file mode 100644 index 0000000..9815324 --- /dev/null +++ b/tests/test_data/tags/fail/missing_entry.csv @@ -0,0 +1,10 @@ +sequence,feature_name +AGGACCATCCAA,CITE_LEN_12_1 +ACATGTTACCGT,CITE_LEN_12_2 +AGCTTACTATCC +TCGATAATGCGAGTACAA,CITE_LEN_18_1 +GAGGCTGAGCTAGCTAGT,CITE_LEN_18_2 +GGCTGATGCTGACTGCTA,CITE_LEN_18_3 +TGTGACGTATTGCTAGCTAG,CITE_LEN_20_1 +ACTGTCTAACGGGTCAGTGC,CITE_LEN_20_2 +TATCACATCGGTGGATCCAT,CITE_LEN_20_3 \ No newline at end of file diff --git a/tests/test_data/tags/fail/wrong_header.csv b/tests/test_data/tags/fail/wrong_header.csv new file mode 100644 index 0000000..b167686 --- /dev/null +++ b/tests/test_data/tags/fail/wrong_header.csv @@ -0,0 +1,10 @@ +sequences,feature_name +AGGACCATCCAA,CITE_LEN_12_1 +ACATGTTACCGT,CITE_LEN_12_2 +AGCTTACTATCC,CITE_LEN_12_3 +TCGATAATGCGAGTACAA,CITE_LEN_18_1 +GAGGCTGAGCTAGCTAGT,CITE_LEN_18_2 +GGCTGATGCTGACTGCTA,CITE_LEN_18_3 +TGTGACGTATTGCTAGCTAG,CITE_LEN_20_1 +ACTGTCTAACGGGTCAGTGC,CITE_LEN_20_2 +TATCACATCGGTGGATCCAT,CITE_LEN_20_3 \ No newline at end of file diff --git a/tests/test_data/tags/correct.csv b/tests/test_data/tags/pass/correct.csv similarity index 92% rename from tests/test_data/tags/correct.csv rename to tests/test_data/tags/pass/correct.csv index 5d49bb0..66b0a4d 100644 --- a/tests/test_data/tags/correct.csv +++ b/tests/test_data/tags/pass/correct.csv @@ -1,3 +1,4 @@ +sequence,feature_name AGGACCATCCAA,CITE_LEN_12_1 ACATGTTACCGT,CITE_LEN_12_2 AGCTTACTATCC,CITE_LEN_12_3 diff --git a/tests/test_data/tags/pass/correct_2.csv b/tests/test_data/tags/pass/correct_2.csv new file mode 100644 index 0000000..1bb1dce --- /dev/null +++ b/tests/test_data/tags/pass/correct_2.csv @@ -0,0 +1,10 @@ +feature_name,sequence +CITE_LEN_12_1,AGGACCATCCAA +CITE_LEN_12_2,ACATGTTACCGT +CITE_LEN_12_3,AGCTTACTATCC +CITE_LEN_18_1,TCGATAATGCGAGTACAA +CITE_LEN_18_2,GAGGCTGAGCTAGCTAGT +CITE_LEN_18_3,GGCTGATGCTGACTGCTA +CITE_LEN_20_1,TGTGACGTATTGCTAGCTAG +CITE_LEN_20_2,ACTGTCTAACGGGTCAGTGC +CITE_LEN_20_3,TATCACATCGGTGGATCCAT \ No newline at end of file diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 43401d8..910a680 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -1,93 +1,124 @@ import pytest import io from cite_seq_count import preprocessing +import glob +from collections import namedtuple @pytest.fixture def data(): from collections import OrderedDict from itertools import islice - + + pytest.passing_csv = "tests/test_data/tags/pass/*.csv" + pytest.failing_csv = "tests/test_data/tags/fail/*.csv" + # Test file paths - pytest.correct_whitelist_path = 'tests/test_data/whitelists/correct.csv' - pytest.correct_tags_path = 'tests/test_data/tags/correct.csv' - pytest.correct_R1_path = 'tests/test_data/fastq/correct_R1.fastq.gz' - pytest.correct_R2_path = 'tests/test_data/fastq/correct_R2.fastq.gz' - pytest.corrupt_R1_path = 'tests/test_data/fastq/corrupted_R1.fastq.gz' - pytest.corrupt_R2_path = 'tests/test_data/fastq/corrupted_R2.fastq.gz' + pytest.correct_whitelist_path = "tests/test_data/whitelists/correct.csv" + pytest.correct_tags_path = "tests/test_data/tags/pass/correct.csv" + pytest.correct_R1_path = "tests/test_data/fastq/correct_R1.fastq.gz" + pytest.correct_R2_path = "tests/test_data/fastq/correct_R2.fastq.gz" + pytest.corrupt_R1_path = "tests/test_data/fastq/corrupted_R1.fastq.gz" + pytest.corrupt_R2_path = "tests/test_data/fastq/corrupted_R2.fastq.gz" - pytest.correct_R1_multipath = 'path/to/R1_1.fastq.gz,path/to/R1_2.fastq.gz' - pytest.correct_R2_multipath = 'path/to/R2_1.fastq.gz,path/to/R2_2.fastq.gz' - pytest.incorrect_R2_multipath = 'path/to/R2_1.fastq.gz,path/to/R2_2.fastq.gz,path/to/R2_3.fastq.gz' + pytest.correct_R1_multipath = "path/to/R1_1.fastq.gz,path/to/R1_2.fastq.gz" + pytest.correct_R2_multipath = "path/to/R2_1.fastq.gz,path/to/R2_2.fastq.gz" + pytest.incorrect_R2_multipath = ( + "path/to/R2_1.fastq.gz,path/to/R2_2.fastq.gz,path/to/R2_3.fastq.gz" + ) - pytest.correct_multipath_result = (['path/to/R1_1.fastq.gz', 'path/to/R1_2.fastq.gz'], - ['path/to/R2_1.fastq.gz', 'path/to/R2_2.fastq.gz']) + pytest.correct_multipath_result = ( + ["path/to/R1_1.fastq.gz", "path/to/R1_2.fastq.gz"], + ["path/to/R2_1.fastq.gz", "path/to/R2_2.fastq.gz"], + ) # Create some variables to compare to - pytest.correct_whitelist = set(['ACTGTTTTATTGGCCT','TTCATAAGGTAGGGAT']) + pytest.correct_whitelist = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) pytest.correct_tags = { - 'AGGACCATCCAA':'CITE_LEN_12_1', - 'ACATGTTACCGT':'CITE_LEN_12_2', - 'AGCTTACTATCC':'CITE_LEN_12_3', - 'TCGATAATGCGAGTACAA':'CITE_LEN_18_1', - 'GAGGCTGAGCTAGCTAGT':'CITE_LEN_18_2', - 'GGCTGATGCTGACTGCTA':'CITE_LEN_18_3', - 'TGTGACGTATTGCTAGCTAG':'CITE_LEN_20_1', - 'ACTGTCTAACGGGTCAGTGC':'CITE_LEN_20_2', - 'TATCACATCGGTGGATCCAT':'CITE_LEN_20_3'} - pytest.correct_ordered_tags = OrderedDict({ - 'CITE_LEN_20_1':{'id':0,'sequence':'TGTGACGTATTGCTAGCTAG'}, - 'CITE_LEN_20_2':{'id':1,'sequence':'ACTGTCTAACGGGTCAGTGC'}, - 'CITE_LEN_20_3':{'id':2,'sequence':'TATCACATCGGTGGATCCAT'}, - 'CITE_LEN_18_1':{'id':3,'sequence':'TCGATAATGCGAGTACAA'}, - 'CITE_LEN_18_2':{'id':4,'sequence':'GAGGCTGAGCTAGCTAGT'}, - 'CITE_LEN_18_3':{'id':5,'sequence':'GGCTGATGCTGACTGCTA'}, - 'CITE_LEN_12_1':{'id':6,'sequence':'AGGACCATCCAA'}, - 'CITE_LEN_12_2':{'id':7,'sequence':'ACATGTTACCGT'}, - 'CITE_LEN_12_3':{'id':8,'sequence':'AGCTTACTATCC'}, - 'unmapped':{'id':9, 'sequence': 'UNKNOWN'}}) + "AGGACCATCCAA": "CITE_LEN_12_1", + "ACATGTTACCGT": "CITE_LEN_12_2", + "AGCTTACTATCC": "CITE_LEN_12_3", + "TCGATAATGCGAGTACAA": "CITE_LEN_18_1", + "GAGGCTGAGCTAGCTAGT": "CITE_LEN_18_2", + "GGCTGATGCTGACTGCTA": "CITE_LEN_18_3", + "TGTGACGTATTGCTAGCTAG": "CITE_LEN_20_1", + "ACTGTCTAACGGGTCAGTGC": "CITE_LEN_20_2", + "TATCACATCGGTGGATCCAT": "CITE_LEN_20_3", + } + tag = namedtuple("tag", ["name", "sequence", "id"]) + pytest.correct_tags_tuple = [ + tag(name="CITE_LEN_20_1", sequence="TGTGACGTATTGCTAGCTAG", id=0), + tag(name="CITE_LEN_20_2", sequence="ACTGTCTAACGGGTCAGTGC", id=1), + tag(name="CITE_LEN_20_3", sequence="TATCACATCGGTGGATCCAT", id=2), + tag(name="CITE_LEN_18_1", sequence="TCGATAATGCGAGTACAA", id=3), + tag(name="CITE_LEN_18_2", sequence="GAGGCTGAGCTAGCTAGT", id=4), + tag(name="CITE_LEN_18_3", sequence="GGCTGATGCTGACTGCTA", id=5), + tag(name="CITE_LEN_12_1", sequence="AGGACCATCCAA", id=6), + tag(name="CITE_LEN_12_2", sequence="ACATGTTACCGT", id=7), + tag(name="CITE_LEN_12_3", sequence="AGCTTACTATCC", id=8), + tag(name="unmapped", sequence="UNKNOWN", id=9), + ] pytest.barcode_slice = slice(0, 16) pytest.umi_slice = slice(16, 26) pytest.barcode_umi_length = 26 + +def test_csv_parser(data): + passing_files = glob.glob(pytest.passing_csv) + for file_path in passing_files: + preprocessing.parse_tags_csv(file_path) + with pytest.raises(SystemExit): + failing_files = glob.glob(pytest.failing_csv) + for file_path in failing_files: + print(file_path) + preprocessing.parse_tags_csv(file_path) + + @pytest.mark.dependency() def test_parse_whitelist_csv(data): - assert preprocessing.parse_whitelist_csv(pytest.correct_whitelist_path, 16, 1) == (pytest.correct_whitelist,1) + assert preprocessing.parse_whitelist_csv(pytest.correct_whitelist_path, 16) in ( + pytest.correct_whitelist, + 1, + ) + @pytest.mark.dependency() def test_parse_tags_csv(data): - assert preprocessing.parse_tags_csv(pytest.correct_tags_path) == pytest.correct_tags - -@pytest.mark.dependency(depends=['test_parse_tags_csv']) -def test_check_tags(data): tags = preprocessing.check_tags(pytest.correct_tags, 5)[0] - for name in tags.keys(): - assert tags[name] == pytest.correct_ordered_tags[name] - + for i, tag in enumerate(tags): + assert tag == pytest.correct_tags_tuple[i] + -@pytest.mark.dependency(depends=['test_check_tags']) +@pytest.mark.dependency(depends=["test_parse_tags_csv"]) def test_check_distance_too_big_between_tags(data): with pytest.raises(SystemExit): preprocessing.check_tags(pytest.correct_tags, 8) -@pytest.mark.dependency(depends=['test_parse_whitelist_csv']) -def test_check_barcodes_lengths(data): - assert preprocessing.check_barcodes_lengths(26, 1, 16, 17, 26) == (pytest.barcode_slice, pytest.umi_slice, pytest.barcode_umi_length) @pytest.mark.dependency() def test_get_n_lines(data): - assert preprocessing.get_n_lines(pytest.correct_R1_path) == (200 * 4) + assert preprocessing.get_n_lines(pytest.correct_R1_path) == (200 * 4) -@pytest.mark.dependency(depends=['test_get_n_lines']) + +@pytest.mark.dependency(depends=["test_get_n_lines"]) def test_get_n_lines_not_multiple_of_4(data): - with pytest.raises(SystemExit): - preprocessing.get_n_lines(pytest.corrupt_R1_path) + with pytest.raises(SystemExit): + preprocessing.get_n_lines(pytest.corrupt_R1_path) + @pytest.mark.dependency() def test_corrrect_multipath(data): - assert preprocessing.get_read_paths(pytest.correct_R1_multipath, pytest.correct_R2_multipath) == pytest.correct_multipath_result + assert ( + preprocessing.get_read_paths( + pytest.correct_R1_multipath, pytest.correct_R2_multipath + ) + == pytest.correct_multipath_result + ) + -@pytest.mark.dependency(depends=['test_get_n_lines']) +@pytest.mark.dependency(depends=["test_get_n_lines"]) def test_incorrrect_multipath(data): - with pytest.raises(SystemExit): - preprocessing.get_read_paths(pytest.correct_R1_multipath, pytest.incorrect_R2_multipath) + with pytest.raises(SystemExit): + preprocessing.get_read_paths( + pytest.correct_R1_multipath, pytest.incorrect_R2_multipath + ) From 49cfdbabd0f601e46a7c2517333d88a196006fde Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Mon, 14 Dec 2020 23:25:08 +0100 Subject: [PATCH 24/77] docstring updates --- cite_seq_count/__main__.py | 10 +- cite_seq_count/io.py | 26 ++-- cite_seq_count/preprocessing.py | 56 ++++----- cite_seq_count/processing.py | 30 ++--- tests/test_processing.py | 215 +++++++++++++++++--------------- 5 files changed, 175 insertions(+), 162 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index d7cdc0b..99b23f3 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -62,12 +62,12 @@ def main(): read2_lengths = [] total_reads = 0 - for read1_path, read2_path in zip(read1_paths, read2_paths): + for read1_path in read1_paths: n_lines = preprocessing.get_n_lines(read1_path) total_reads += n_lines / 4 # Get reads length. So far, there is no validation for Read2. read1_lengths.append(preprocessing.get_read_length(read1_path)) - read2_lengths.append(preprocessing.get_read_length(read2_path)) + # read2_lengths.append(preprocessing.get_read_length(read2_path)) # Check Read1 length against CELL and UMI barcodes length. preprocessing.check_barcodes_lengths( read1_lengths[-1], @@ -91,10 +91,10 @@ def main(): print("Detected {} pairs of files to run on.".format(number_of_samples)) if args.sliding_window: - R2_max_length = read2_lengths[0] + R2_min_length = read2_lengths[0] maximum_distance = 0 else: - R2_max_length = longest_tag_len + R2_min_length = longest_tag_len maximum_distance = args.max_error ( @@ -107,7 +107,7 @@ def main(): args=args, read1_paths=read1_paths, read2_paths=read2_paths, - R2_max_length=R2_max_length, + R2_min_length=R2_min_length, n_reads_to_chunk=n_reads, chemistry_def=chemistry_def, ordered_tags=ordered_tags, diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 4b54681..d8c45cf 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -175,13 +175,25 @@ def write_chunks_to_disk( args, read1_paths, read2_paths, - R2_max_length, - n_reads_to_chunk, + R2_min_length, + n_reads_per_chunk, chemistry_def, ordered_tags, maximum_distance, ): """ + Writes chunked files of reads to disk and prepares parallel + processing queue parameters. + + Args: + args(argparse): All parsed arguments. + read1_paths (list): List of R1 fastq.gz paths. + read2_paths (list): List of R2 fastq.gz paths. + R2_min_length (int): Minimum length of read2 sequences. + n_reads_per_chunk (int): How many reads per chunk. + chemistry_def (namedtuple): Hols all the information about the chemistry definition. + ordered_tags (list): List of namedtuple tags. + maximum_distance (int): Maximum hamming distance for mapping. """ mapping_input = namedtuple( "mapping_input", @@ -192,7 +204,7 @@ def write_chunks_to_disk( num_chunk = 0 if not args.chunk_size: - chunk_size = round(n_reads_to_chunk / args.n_threads) + chunk_size = round(n_reads_per_chunk / args.n_threads) else: chunk_size = args.chunk_size temp_path = os.path.abspath(args.temp_path) @@ -230,7 +242,7 @@ def write_chunks_to_disk( R1_too_short += 1 # The entire read is skipped continue - if len(read2) < R2_max_length: + if len(read2) < R2_min_length: R2_too_short += 1 # The entire read is skipped continue @@ -241,7 +253,7 @@ def write_chunks_to_disk( read2_sliced = read2[ chemistry_def.R2_trim_start : ( - R2_max_length + chemistry_def.R2_trim_start + R2_min_length + chemistry_def.R2_trim_start ) ] chunked_file_object.write( @@ -266,7 +278,7 @@ def write_chunks_to_disk( sliding_window=args.sliding_window, ) ) - if total_reads_written == n_reads_to_chunk: + if total_reads_written == n_reads_per_chunk: enough_reads = True chunked_file_object.close() break @@ -275,7 +287,7 @@ def write_chunks_to_disk( chunked_file_object = open(temp_filename, "w") temp_files.append(os.path.abspath(temp_filename)) reads_written = 0 - if total_reads_written == n_reads_to_chunk: + if total_reads_written == n_reads_per_chunk: enough_reads = True chunked_file_object.close() break diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 72b1c46..6096d2c 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -49,14 +49,13 @@ def parse_whitelist_csv(filename, barcode_length): ) ) if len(whitelist) == 0: - sys.exit( - "Please check cell barcode indexes -cbs, -cbl because none of the given whitelist is valid." - ) + sys.exit("Whitelist is empty.") return set(whitelist) def parse_tags_csv(filename): - """Reads the TAGs from a CSV file. Checks if sequences are made of ATGC + """Reads the TAGs from a CSV file. Checks that the header contains + necessary strings and if sequences are made of ATGC The expected file format has a header with "sequence" and "feature_name". Order doesn't matter. @@ -67,10 +66,10 @@ def parse_tags_csv(filename): TTCCGCCTCTCTTTG,Hashtag_3 Args: - filename (str): TAGs file. + filename (str): TAGs file path. Returns: - dict: A dictionary containing using sequences as keys and names as values. + dict: A dictionary using sequences as keys and feature names as values. """ REQUIRED_HEADER = ["sequence", "feature_name"] @@ -113,8 +112,7 @@ def check_tags(tags, maximum_distance): between two TAGs. Returns: - OrderedDict: An ordered dictionary containing the TAGs and - their names in descendent order based on the length of the TAGs. + list: An ordered list of namedtuples int: the length of the longest TAG """ @@ -162,27 +160,21 @@ def check_tags(tags, maximum_distance): def sanitize_name(string): - return string.replace("-", "_") - + """ + Transforms special characters that are not compatible with namedtuples -def convert_to_named_tuple(ordered_tags): - # all_tags = namedtuple('all_tags', [sanitize_name(tag) for tag in ordered_tags.keys()]) - tag = namedtuple("tag", ["name", "sequence", "id"]) - tag_list = [] - for index, tag_name in enumerate(ordered_tags): - tag_list.append( - tag( - name=ordered_tags[tag_name]["feature_name"], - sequence=ordered_tags[tag_name]["sequence"], - id=(index), - ) - ) - # all_tags[index+1]=ordered_tags[tag_name]['sequence'] - return tag_list + Args: + string(str): a string from a feature name + + Returns: + str: modified string + """ + return string.replace("-", "_") def get_read_length(filename): - """Check wether SEQUENCE lengths are consistent in a FASTQ file and return + """Check wether SEQUENCE lengths are consistent in + the first 1000 reads from a FASTQ file and return the length. Args: @@ -194,15 +186,15 @@ def get_read_length(filename): """ with gzip.open(filename, "r") as fastq_file: secondlines = islice(fastq_file, 1, 1000, 4) - # temp_length = len(next(secondlines).rstrip()) + temp_length = len(next(secondlines).rstrip()) for sequence in secondlines: read_length = len(sequence.rstrip()) - # if (temp_length != read_length): - # sys.exit( - # '[ERROR] Sequence length in {} is not consistent. Please, trim all ' - # 'sequences at the same length.\n' - # 'Exiting the application.\n'.format(filename) - # ) + if temp_length != read_length: + sys.exit( + "[ERROR] Sequence length in {} is not consistent. Please, trim all " + "sequences at the same length.\n" + "Exiting the application.\n".format(filename) + ) return read_length diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index e62f385..503eea5 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -61,15 +61,12 @@ def find_best_match_shift(TAG_seq, tags): """ Find the best match from the list of tags with sliding window. Only works with exact match. - - Compares the Levenshtein distance between tags and the trimmed sequences. - The tag and the sequence must have the same length. + Just checks if the string is in the sequence. If no matches found returns 'unmapped'. - We add 1 + Args: TAG_seq (string): Sequence from R2 already start trimmed tags (dict): A dictionary with the TAGs as keys and TAG Names as values. - maximum_distance (int): Maximum distance given by the user. Returns: best_match (string): The TAG name that will be used for counting. @@ -77,7 +74,7 @@ def find_best_match_shift(TAG_seq, tags): best_match = "unmapped" for tag in tags: if tag.sequence in TAG_seq: - return tag.secure_name + return tag.name return best_match @@ -87,19 +84,12 @@ def map_reads(mapping_input): It reads both Read1 and Read2 files, creating a dict based on cell barcode. Args: - read1_path (string): Path to R1.fastq.gz - read2_path (string): Path to R2.fastq.gz - chunk_size (int): The number of lines to process - tags (dict): A dictionary with the TAGs + TAG Names. - barcode_slice (slice): A slice for extracting the Barcode portion from the - sequence. - umi_slice (slice): A slice for extracting the UMI portion from the - sequence. - indexes (list): Pair of first and last index for islice - debug (bool): Print debug messages. Default is False. - start_trim (int): Number of bases to trim at the start. - maximum_distance (int): Maximum distance given by the user. - sliding_window (bool): A bool enabling a sliding window search + mapping_input (namedtuple): List of paramters to run in parallel. + filename (str): Path to the chunk file + tags (list): List of named tuples tags + debug (bool): Should debug information be shown or not + maximum_distance (int): Maximum distance given by the user + sliding_window (bool): A bool enabling a sliding window search Returns: results (dict): A dict of dict of Counters with the mapping results. @@ -207,7 +197,7 @@ def merge_results(parallel_results): def check_unmapped(no_match, too_short, total_reads, start_trim): """Check if the number of unmapped is higher than 99%""" if (sum(no_match.values()) + too_short) / total_reads > float(0.99): - exit( + sys.exit( """More than 99% of your data is unmapped.\nPlease check that your --start_trim {} parameter is correct and that your tags file is properly formatted""".format( start_trim ) diff --git a/tests/test_processing.py b/tests/test_processing.py index 039bd9e..e24e43e 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -8,41 +8,45 @@ def complete_poly_A(seq, final_length=40): poly_A_len = final_length - len(seq) - return(seq + 'A' * poly_A_len) + return seq + "A" * poly_A_len + def get_sequences(ref_path): sequences = [] - with open(ref_path, 'r') as adt_ref: + with open(ref_path, "r") as adt_ref: lines = adt_ref.readlines() - entries = int(len(lines)/2) + entries = int(len(lines) / 2) for i in range(0, entries, 2): - sequences.append(complete_poly_A(lines[i+1].strip())) - return(sequences) - + sequences.append(complete_poly_A(lines[i + 1].strip())) + return sequences + + def extend_seq_pool(ref_seq, distance): extended_pool = [complete_poly_A(ref_seq)] - extended_pool.append(modify(ref_seq, distance, modification_type='mutate')) - extended_pool.append(modify(ref_seq, distance, modification_type='mutate')) - extended_pool.append(modify(ref_seq, distance, modification_type='mutate')) - return(extended_pool) + extended_pool.append(modify(ref_seq, distance, modification_type="mutate")) + extended_pool.append(modify(ref_seq, distance, modification_type="mutate")) + extended_pool.append(modify(ref_seq, distance, modification_type="mutate")) + return extended_pool + def modify(seq, n, modification_type): - bases=list('ATGCN') + bases = list("ATGCN") positions = list(range(len(seq))) seq = list(seq) for i in range(n): - if modification_type == 'mutate': + if modification_type == "mutate": position = random.choice(positions) positions.remove(position) temp_bases = copy.copy(bases) del temp_bases[bases.index(seq[position])] seq[position] = random.choice(temp_bases) - elif modification_type == 'delete': - del seq[random.randint(0,len(seq)-2)] - elif modification_type == 'add': - position = random.randint(0,len(seq)-1) + elif modification_type == "delete": + del seq[random.randint(0, len(seq) - 2)] + elif modification_type == "add": + position = random.randint(0, len(seq) - 1) seq.insert(position, random.choice(bases)) - return(complete_poly_A(''.join(seq))) + return complete_poly_A("".join(seq)) + @pytest.fixture def data(): @@ -50,50 +54,43 @@ def data(): from collections import defaultdict from collections import OrderedDict from collections import Counter - + from itertools import islice + # Test file paths - pytest.correct_R1_path = 'tests/test_data/fastq/correct_R1.fastq.gz' - pytest.correct_R2_path = 'tests/test_data/fastq/correct_R2.fastq.gz' - pytest.file_path = 'tests/test_data/fastq/test_csv.csv' - + pytest.correct_R1_path = "tests/test_data/fastq/correct_R1.fastq.gz" + pytest.correct_R2_path = "tests/test_data/fastq/correct_R2.fastq.gz" + pytest.file_path = "tests/test_data/fastq/test_csv.csv" + pytest.chunk_size = 800 - pytest.tags = OrderedDict({ - 'test2':{'id':0,'sequence':'CGTACGTAGCCTAGC'}, - 'test1':{'id':1,'sequence':'CGTAGCTCG'}, - 'unmapped':{'id':2,'sequence':'UNKNOWN'}, - }) + pytest.tags = OrderedDict( + { + "test2": {"id": 0, "sequence": "CGTACGTAGCCTAGC"}, + "test1": {"id": 1, "sequence": "CGTAGCTCG"}, + "unmapped": {"id": 2, "sequence": "UNKNOWN"}, + } + ) pytest.barcode_slice = slice(0, 16) pytest.umi_slice = slice(16, 26) - pytest.correct_whitelist = set(['ACTGTTTTATTGGCCT','TTCATAAGGTAGGGAT']) + pytest.correct_whitelist = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) pytest.legacy = False pytest.debug = False pytest.start_trim = 0 pytest.maximum_distance = 5 pytest.results = { - 'ACTGTTTTATTGGCCT': - {'test1': - Counter({b'CATTAGTGGT': 3, b'CATTAGTGGG': 2, b'CATTCGTGGT': 1})}, - 'TTCATAAGGTAGGGAT': - {'test2': - Counter({b'TAGCTTAGTA': 3, b'TAGCTTAGTC': 2, b'GCGATGCATA': 1})} + "ACTGTTTTATTGGCCT": { + "test1": Counter({b"CATTAGTGGT": 3, b"CATTAGTGGG": 2, b"CATTCGTGGT": 1}) + }, + "TTCATAAGGTAGGGAT": { + "test2": Counter({b"TAGCTTAGTA": 3, b"TAGCTTAGTC": 2, b"GCGATGCATA": 1}) + }, } pytest.corrected_results = { - 'ACTGTTTTATTGGCCT': - {'test1': - Counter({b'CATTAGTGGT': 6})}, - 'TTCATAAGGTAGGGAT': - {'test2': - Counter({b'TAGCTTAGTA': 5, b'GCGATGCATA': 1})} + "ACTGTTTTATTGGCCT": {"test1": Counter({b"CATTAGTGGT": 6})}, + "TTCATAAGGTAGGGAT": {"test2": Counter({b"TAGCTTAGTA": 5, b"GCGATGCATA": 1})}, } - pytest.umis_per_cell = Counter({ - 'ACTGTTTTATTGGCCT': 1, - 'TTCATAAGGTAGGGAT': 2 - }) - pytest.reads_per_cell = Counter({ - 'ACTGTTTTATTGGCCT': 3, - 'TTCATAAGGTAGGGAT': 6 - }) + pytest.umis_per_cell = Counter({"ACTGTTTTATTGGCCT": 1, "TTCATAAGGTAGGGAT": 2}) + pytest.reads_per_cell = Counter({"ACTGTTTTATTGGCCT": 3, "TTCATAAGGTAGGGAT": 6}) pytest.expected_cells = 2 pytest.no_match = Counter() pytest.collapsing_threshold = 1 @@ -101,98 +98,120 @@ def data(): pytest.max_umis = 20000 pytest.sequence_pool = [] - pytest.tags_complete_dict = preprocessing.check_tags(preprocessing.parse_tags_csv('tests/test_data/tags/correct.csv'), 5)[0] - pytest.tags_complete_tuple = preprocessing.convert_to_named_tuple(pytest.tags_complete_dict) - pytest.tags_short_tuple = preprocessing.convert_to_named_tuple(pytest.tags) - + pytest.tags_tuple = preprocessing.check_tags( + preprocessing.parse_tags_csv("tests/test_data/tags/pass/correct.csv"), 5 + )[0] + @pytest.mark.dependency() def test_find_best_match_with_1_distance(data): distance = 1 - for name, tag in pytest.tags_complete_dict.items(): + for tag in pytest.tags_tuple: counts = Counter() - if name == 'unmapped': + if tag.name == "unmapped": continue - for seq in extend_seq_pool(tag['sequence'], distance): - counts[processing.find_best_match(seq, pytest.tags_complete_tuple, distance)] += 1 - assert counts[name] == 4 + for seq in extend_seq_pool(tag.sequence, distance): + counts[processing.find_best_match(seq, pytest.tags_tuple, distance)] += 1 + assert counts[tag.id] == 4 + @pytest.mark.dependency() def test_find_best_match_with_2_distance(data): distance = 2 - for name, tag in pytest.tags_complete_dict.items(): + for tag in pytest.tags_tuple: counts = Counter() - if name == 'unmapped': + if tag.name == "unmapped": continue - for seq in extend_seq_pool(tag['sequence'], distance): - counts[processing.find_best_match(seq, pytest.tags_complete_tuple, distance)] += 1 - assert counts[name] == 4 + for seq in extend_seq_pool(tag.sequence, distance): + counts[processing.find_best_match(seq, pytest.tags_tuple, distance)] += 1 + assert counts[tag.id] == 4 + @pytest.mark.dependency() def test_find_best_match_with_3_distance(data): distance = 3 - for name, tag in pytest.tags_complete_dict.items(): + for tag in pytest.tags_tuple: counts = Counter() - if name == 'unmapped': + if tag.name == "unmapped": continue - for seq in extend_seq_pool(tag['sequence'], distance): - counts[processing.find_best_match(seq, pytest.tags_complete_tuple, distance)] += 1 - assert counts[name] == 4 + for seq in extend_seq_pool(tag.sequence, distance): + counts[processing.find_best_match(seq, pytest.tags_tuple, distance)] += 1 + assert counts[tag.id] == 4 + @pytest.mark.dependency() def test_find_best_match_with_3_distance_reverse(data): distance = 3 - for name, tag in sorted(pytest.tags_complete_dict.items()): + for tag in pytest.tags_tuple: counts = Counter() - if name == 'unmapped': - continue - for seq in extend_seq_pool(tag['sequence'], distance): - counts[processing.find_best_match(seq, pytest.tags_complete_tuple, distance)] += 1 - assert counts[name] == 4 - -@pytest.mark.dependency(depends=[ - 'test_find_best_match_with_1_distance', - 'test_find_best_match_with_2_distance', - 'test_find_best_match_with_3_distance', - 'test_find_best_match_with_3_distance_reverse',]) + if tag.name == "unmapped": + continue + for seq in extend_seq_pool(tag.sequence, distance): + counts[processing.find_best_match(seq, pytest.tags_tuple, distance)] += 1 + assert counts[tag.id] == 4 + + +@pytest.mark.dependency( + depends=[ + "test_find_best_match_with_1_distance", + "test_find_best_match_with_2_distance", + "test_find_best_match_with_3_distance", + "test_find_best_match_with_3_distance_reverse", + ] +) def test_classify_reads_multi_process(data): - (results, no_match) = processing.map_reads(( - pytest.file_path, - pytest.tags_short_tuple, - pytest.debug, - pytest.maximum_distance, - pytest.sliding_window)) + (results, _) = processing.map_reads( + ( + pytest.file_path, + pytest.tags_tuple, + pytest.debug, + pytest.maximum_distance, + pytest.sliding_window, + ) + ) assert len(results) == 2 -@pytest.mark.dependency(depends=['test_classify_reads_multi_process']) +@pytest.mark.dependency(depends=["test_classify_reads_multi_process"]) def test_correct_umis(data): temp = processing.correct_umis((pytest.results, 2, pytest.max_umis)) results = temp[0] n_corrected = temp[1] for cell_barcode in results.keys(): for TAG in results[cell_barcode]: - assert len(results[cell_barcode][TAG]) == len(pytest.corrected_results[cell_barcode][TAG]) - assert sum(results[cell_barcode][TAG].values()) == sum(pytest.corrected_results[cell_barcode][TAG].values()) + assert len(results[cell_barcode][TAG]) == len( + pytest.corrected_results[cell_barcode][TAG] + ) + assert sum(results[cell_barcode][TAG].values()) == sum( + pytest.corrected_results[cell_barcode][TAG].values() + ) assert n_corrected == 3 -@pytest.mark.dependency(depends=['test_correct_umis']) +@pytest.mark.dependency(depends=["test_correct_umis"]) def test_correct_cells(data): - processing.correct_cells(pytest.corrected_results, pytest.reads_per_cell, pytest.umis_per_cell, pytest.expected_cells, pytest.collapsing_threshold, pytest.tags) + processing.correct_cells( + pytest.corrected_results, + pytest.reads_per_cell, + pytest.umis_per_cell, + pytest.expected_cells, + pytest.collapsing_threshold, + pytest.tags, + ) -@pytest.mark.dependency(depends=['test_correct_umis']) +@pytest.mark.dependency(depends=["test_correct_umis"]) def test_generate_sparse_matrices(data): (umi_results_matrix, read_results_matrix) = processing.generate_sparse_matrices( - pytest.corrected_results, pytest.tags, - set(['ACTGTTTTATTGGCCT','TTCATAAGGTAGGGAT']) - ) - assert umi_results_matrix.shape == (3,2) - assert read_results_matrix.shape == (3,2) + pytest.corrected_results, + pytest.tags, + set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]), + ) + assert umi_results_matrix.shape == (3, 2) + assert read_results_matrix.shape == (3, 2) read_results_matrix = read_results_matrix.tocsr() total_reads = 0 for i in range(read_results_matrix.shape[0]): for j in range(read_results_matrix.shape[1]): - total_reads += read_results_matrix[i,j] - assert total_reads == 12 \ No newline at end of file + total_reads += read_results_matrix[i, j] + assert total_reads == 12 From b6a68611d7edabff073730d1a4c2be8a8eac9118 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Tue, 15 Dec 2020 09:25:52 +0100 Subject: [PATCH 25/77] fixed the test file --- tests/test_data/fastq/test_csv.csv | 400 ++++++++++++++--------------- tests/test_processing.py | 24 +- 2 files changed, 214 insertions(+), 210 deletions(-) diff --git a/tests/test_data/fastq/test_csv.csv b/tests/test_data/fastq/test_csv.csv index d4d8295..dc77a28 100644 --- a/tests/test_data/fastq/test_csv.csv +++ b/tests/test_data/fastq/test_csv.csv @@ -1,200 +1,200 @@ -TAGAGGGAAGTCAAGC,CNGAGTCTCN,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,CACNTAAATC,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,TCTCCGAAGC,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,TTGACTTATC,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,TTTCCTTCCG,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,NAGAACACCA,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,CTCATCTTAT,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,TCGACAAGCT,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,AAANACACTC,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,AAANACACTC,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,GNCNCTGATA,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,TGGTGTGNGC,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,AGANTCTCCC,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,GCGCGGAGAA,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,TGGTGTGNGC,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,TCTGGTAGTT,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,TAGGAGCAGN,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,ATCACGGATC,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,CCGGGGACTA,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,TTACTAGTAA,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,AGAACGTCGC,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,ANGAGNAAGT,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,ATAACTAGAA,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,TTGACTTATC,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,ACACTGCTAT,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,GCGCGGAGAA,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,CGTGATGAGC,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,TCACCCNCGG,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,ATCCGTACTT,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,CCGGGGACTA,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,NTTCATGTTG,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,ATCGGGAGNC,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,CGTCNGTTGC,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,TAGTCAGAAT,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,ATCGGGAGNC,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,ACACTGCTAT,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,TGCTCAATAG,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,GATCGTACAA,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,AGACCTNTGG,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,TGACCTAAGC,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,GATCGTACAA,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,TGGTGTGNGC,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,TCGACACCAC,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,CATCNAGTGN,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,TTAAAACCNA,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,CCATACGNNA,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,ATTGTTCGGA,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,CACNCTAGGG,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,AGGAGCNCCC,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,GCGCGGAGAA,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,TTCCGTNCAA,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,GNCNCTGATA,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,CAAGGGAACG,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,AGAANGCCNA,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,TCTGGTAGTT,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,ACAGAGTAAN,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,CACNCTAGGG,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,CTACACGTGA,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,CGTCNGTTGC,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,CGTCGTTATA,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,TTCNGTCACC,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,TTGGNGTACA,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,TCGACAAGCT,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,CGNAATTTGA,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,GCGCGGAGAA,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,CTACGCCGCC,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,TAGGAGCAGN,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,AGCANTGTAG,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,AGAANGCCNA,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,TGTCTGCACG,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,CNGAGTCTCN,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,ATAACTAGAA,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,CCATGTGNGT,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,CCATGTGNGT,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,CGTAGGCATT,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,CCGGGGACTA,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,ANTTCTCTCA,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,ACAGAGTAAN,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,TGCTCAATAG,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,AGACTTAGGG,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,TAAAGGCTTG,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,CTTGAGAGGG,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,TTCCGTNCAA,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,CATCNAGTGN,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,GAACACTGAG,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,ACGCGGAGTT,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,ANGAGNAAGT,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,ANTTCTCTCA,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,GCTGTGTTAG,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,TTCNGTCACC,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,CCATACGNNA,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,CGTCNGTTGC,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,CCATGTGNGT,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,GCCCGCTCAC,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,AAANACACTC,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,TAAAGGCTTG,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,GATTTCACCG,CGTCGTAGCTGATCG -TAGAGGGAAGTCAAGC,CCATACGNNA,CGTACGTAGCCTAGC -TAGAGGGAAGTCAAGC,AGCANTGTAG,CGTAGCTCGAAAAAA -TAGAGGGAAGTCAAGC,AACTCCCACG,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,GATCGAACGG,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,GATCGAACGG,CGTCGTAGCTGATCG -TACATATTCTTTACTG,GANCGGGACA,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,NGCTGGCACG,CGTACGTAGCCTAGC -TACATATTCTTTACTG,GANCGGGACA,CGTACGTAGCCTAGC -TACATATTCTTTACTG,ATTCATTGTA,CGTACGTAGCCTAGC -TACATATTCTTTACTG,TAATCATACC,CGTACGTAGCCTAGC -TACATATTCTTTACTG,GGTCTAAGAG,CGTCGTAGCTGATCG -TACATATTCTTTACTG,GGATNTTGTA,CGTCGTAGCTGATCG -TACATATTCTTTACTG,AATATGANTG,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,ATTAAGCCNG,CGTACGTAGCCTAGC -TACATATTCTTTACTG,TGAGGGTAGA,CGTCGTAGCTGATCG -TACATATTCTTTACTG,CTCTCGCTTT,CGTACGTAGCCTAGC -TACATATTCTTTACTG,GGATNTTGTA,CGTCGTAGCTGATCG -TACATATTCTTTACTG,AGGTTTACTG,CGTCGTAGCTGATCG -TACATATTCTTTACTG,GAGGCGTGTC,CGTCGTAGCTGATCG -TACATATTCTTTACTG,TGCTGAATAA,CGTCGTAGCTGATCG -TACATATTCTTTACTG,AAGGCACTTT,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,CTTTCAAGTN,CGTACGTAGCCTAGC -TACATATTCTTTACTG,GAGGCGTGTC,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,TGTNAATCCA,CGTCGTAGCTGATCG -TACATATTCTTTACTG,GCCAAGTACA,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,GGATNTTGTA,CGTCGTAGCTGATCG -TACATATTCTTTACTG,CCGNTGTGGC,CGTACGTAGCCTAGC -TACATATTCTTTACTG,GATCGAACGG,CGTCGTAGCTGATCG -TACATATTCTTTACTG,TCGCGATGNT,CGTCGTAGCTGATCG -TACATATTCTTTACTG,ACCGTGAGGC,CGTACGTAGCCTAGC -TACATATTCTTTACTG,GGTCGCAGTN,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,CCAGACTTGA,CGTCGTAGCTGATCG -TACATATTCTTTACTG,GACTTTTCCT,CGTCGTAGCTGATCG -TACATATTCTTTACTG,CTTCCATGCC,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,AGCAACCCGA,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,AGCAACCCGA,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,GACGGGGTCT,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,TACGAAGAAT,CGTACGTAGCCTAGC -TACATATTCTTTACTG,CGAGGTGCGN,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,TNCATCGGAT,CGTCGTAGCTGATCG -TACATATTCTTTACTG,AAGGCACTTT,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,ATTAAGCCNG,CGTCGTAGCTGATCG -TACATATTCTTTACTG,GCTAACCCGN,CGTCGTAGCTGATCG -TACATATTCTTTACTG,AGGTTTACTG,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,ANAGGANAAC,CGTACGTAGCCTAGC -TACATATTCTTTACTG,AGGTTTACTG,CGTCGTAGCTGATCG -TACATATTCTTTACTG,CTNATCGGTC,CGTACGTAGCCTAGC -TACATATTCTTTACTG,AGCAACCCGA,CGTACGTAGCCTAGC -TACATATTCTTTACTG,ACTGGTCGCT,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,GCNGTCGCTA,CGTCGTAGCTGATCG -TACATATTCTTTACTG,TCGCGATGNT,CGTCGTAGCTGATCG -TACATATTCTTTACTG,TAATCATACC,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,GACTTTTCCT,CGTACGTAGCCTAGC -TACATATTCTTTACTG,CCCGAATGAA,CGTCGTAGCTGATCG -TACATATTCTTTACTG,GCTTCTACCN,CGTACGTAGCCTAGC -TACATATTCTTTACTG,ACTGGTCGCT,CGTACGTAGCCTAGC -TACATATTCTTTACTG,AGGTCGCTAC,CGTACGTAGCCTAGC -TACATATTCTTTACTG,AGCGCCNTGG,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,CCAGCGCCCG,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,GAGATCCGAG,CGTACGTAGCCTAGC -TACATATTCTTTACTG,TAGCCCCCCC,CGTACGTAGCCTAGC -TACATATTCTTTACTG,ATTCATTGTA,CGTACGTAGCCTAGC -TACATATTCTTTACTG,ATCGGGCGCC,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,CGAGGTGCGN,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,AGTAANGCAA,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,NGCTGGCACG,CGTCGTAGCTGATCG -TACATATTCTTTACTG,CCGNTGTGGC,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,ATCGGGCGCC,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,ATTAAGCCNG,CGTCGTAGCTGATCG -TACATATTCTTTACTG,GGNCGCACCC,CGTACGTAGCCTAGC -TACATATTCTTTACTG,AANCACANGT,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,AANTAAGCAT,CGTCGTAGCTGATCG -TACATATTCTTTACTG,ANAGGANAAC,CGTACGTAGCCTAGC -TACATATTCTTTACTG,GGNCGCACCC,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,CAATTCCGGC,CGTACGTAGCCTAGC -TACATATTCTTTACTG,AGGTTTACTG,CGTCGTAGCTGATCG -TACATATTCTTTACTG,CCGNTGTGGC,CGTACGTAGCCTAGC -TACATATTCTTTACTG,ACGCTATGTA,CGTCGTAGCTGATCG -TACATATTCTTTACTG,CTCCTGTGGC,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,GTTGTTTATT,CGTCGTAGCTGATCG -TACATATTCTTTACTG,CGAAGAGAAC,CGTCGTAGCTGATCG -TACATATTCTTTACTG,CGAAGAGAAC,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,GCGGCCATTC,CGTACGTAGCCTAGC -TACATATTCTTTACTG,AGTAANGCAA,CGTCGTAGCTGATCG -TACATATTCTTTACTG,CGAAGAGAAC,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,GTCAACCGGG,CGTACGTAGCCTAGC -TACATATTCTTTACTG,CTCAATACTA,CGTACGTAGCCTAGC -TACATATTCTTTACTG,ATTAAGCCNG,CGTACGTAGCCTAGC -TACATATTCTTTACTG,TACTGTGCTA,CGTACGTAGCCTAGC -TACATATTCTTTACTG,ANGCACTCGA,CGTCGTAGCTGATCG -TACATATTCTTTACTG,NGCTGGCACG,CGTCGTAGCTGATCG -TACATATTCTTTACTG,TAGTATGGAA,CGTCGTAGCTGATCG -TACATATTCTTTACTG,TCGCGATGNT,CGTCGTAGCTGATCG -TACATATTCTTTACTG,ANAGGANAAC,CGTCGTAGCTGATCG -TACATATTCTTTACTG,ACAGTAAATG,CGTACGTAGCCTAGC -TACATATTCTTTACTG,GGTCTAAGAG,CGTACGTAGCCTAGC -TACATATTCTTTACTG,CGAGGTGCGN,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,ACTGGTCGCT,CGTCGTAGCTGATCG -TACATATTCTTTACTG,GTTGTTTATT,CGTAGCTCGAAAAAA -TACATATTCTTTACTG,AAGGCACTTT,CGTACGTAGCCTAGC -TACATATTCTTTACTG,TGACATCAAC,CGTACGTAGCCTAGC -TACATATTCTTTACTG,TGCAGAAANG,CGTCGTAGCTGATCG -TACATATTCTTTACTG,CTTCAANTGA,CGTAGCTCGAAAAAA +TAGAGGGAAGTCAAGC,CNGAGTCTCN,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,CACNTAAATC,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,TCTCCGAAGC,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,TTGACTTATC,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,TTTCCTTCCG,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,NAGAACACCA,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,CTCATCTTAT,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,TCGACAAGCT,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,AAANACACTC,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,AAANACACTC,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,GNCNCTGATA,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,TGGTGTGNGC,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,AGANTCTCCC,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,GCGCGGAGAA,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,TGGTGTGNGC,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,TCTGGTAGTT,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,TAGGAGCAGN,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,ATCACGGATC,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,CCGGGGACTA,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,TTACTAGTAA,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,AGAACGTCGC,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,ANGAGNAAGT,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,ATAACTAGAA,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,TTGACTTATC,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,ACACTGCTAT,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,GCGCGGAGAA,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,CGTGATGAGC,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,TCACCCNCGG,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,ATCCGTACTT,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,CCGGGGACTA,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,NTTCATGTTG,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,ATCGGGAGNC,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,CGTCNGTTGC,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,TAGTCAGAAT,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,ATCGGGAGNC,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,ACACTGCTAT,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,TGCTCAATAG,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,GATCGTACAA,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,AGACCTNTGG,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,TGACCTAAGC,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,GATCGTACAA,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,TGGTGTGNGC,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,TCGACACCAC,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,CATCNAGTGN,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,TTAAAACCNA,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,CCATACGNNA,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,ATTGTTCGGA,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,CACNCTAGGG,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,AGGAGCNCCC,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,GCGCGGAGAA,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,TTCCGTNCAA,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,GNCNCTGATA,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,CAAGGGAACG,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,AGAANGCCNA,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,TCTGGTAGTT,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,ACAGAGTAAN,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,CACNCTAGGG,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,CTACACGTGA,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,CGTCNGTTGC,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,CGTCGTTATA,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,TTCNGTCACC,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,TTGGNGTACA,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,TCGACAAGCT,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,CGNAATTTGA,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,GCGCGGAGAA,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,CTACGCCGCC,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,TAGGAGCAGN,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,AGCANTGTAG,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,AGAANGCCNA,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,TGTCTGCACG,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,CNGAGTCTCN,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,ATAACTAGAA,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,CCATGTGNGT,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,CCATGTGNGT,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,CGTAGGCATT,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,CCGGGGACTA,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,ANTTCTCTCA,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,ACAGAGTAAN,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,TGCTCAATAG,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,AGACTTAGGG,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,TAAAGGCTTG,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,CTTGAGAGGG,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,TTCCGTNCAA,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,CATCNAGTGN,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,GAACACTGAG,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,ACGCGGAGTT,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,ANGAGNAAGT,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,ANTTCTCTCA,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,GCTGTGTTAG,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,TTCNGTCACC,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,CCATACGNNA,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,CGTCNGTTGC,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,CCATGTGNGT,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,GCCCGCTCAC,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,AAANACACTC,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,TAAAGGCTTG,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,GATTTCACCG,CGTCGTAGCTGATCGAAAAAAAAA +TAGAGGGAAGTCAAGC,CCATACGNNA,CGTACGTAGCCTAGCAAAAAAAAA +TAGAGGGAAGTCAAGC,AGCANTGTAG,CGTAGCTCGAAAAAAAAAAAAAAA +TAGAGGGAAGTCAAGC,AACTCCCACG,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,GATCGAACGG,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,GATCGAACGG,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,GANCGGGACA,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,NGCTGGCACG,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,GANCGGGACA,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,ATTCATTGTA,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,TAATCATACC,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,GGTCTAAGAG,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,GGATNTTGTA,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,AATATGANTG,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,ATTAAGCCNG,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,TGAGGGTAGA,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,CTCTCGCTTT,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,GGATNTTGTA,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,AGGTTTACTG,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,GAGGCGTGTC,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,TGCTGAATAA,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,AAGGCACTTT,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,CTTTCAAGTN,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,GAGGCGTGTC,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,TGTNAATCCA,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,GCCAAGTACA,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,GGATNTTGTA,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,CCGNTGTGGC,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,GATCGAACGG,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,TCGCGATGNT,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,ACCGTGAGGC,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,GGTCGCAGTN,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,CCAGACTTGA,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,GACTTTTCCT,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,CTTCCATGCC,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,AGCAACCCGA,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,AGCAACCCGA,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,GACGGGGTCT,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,TACGAAGAAT,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,CGAGGTGCGN,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,TNCATCGGAT,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,AAGGCACTTT,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,ATTAAGCCNG,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,GCTAACCCGN,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,AGGTTTACTG,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,ANAGGANAAC,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,AGGTTTACTG,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,CTNATCGGTC,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,AGCAACCCGA,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,ACTGGTCGCT,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,GCNGTCGCTA,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,TCGCGATGNT,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,TAATCATACC,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,GACTTTTCCT,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,CCCGAATGAA,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,GCTTCTACCN,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,ACTGGTCGCT,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,AGGTCGCTAC,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,AGCGCCNTGG,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,CCAGCGCCCG,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,GAGATCCGAG,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,TAGCCCCCCC,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,ATTCATTGTA,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,ATCGGGCGCC,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,CGAGGTGCGN,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,AGTAANGCAA,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,NGCTGGCACG,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,CCGNTGTGGC,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,ATCGGGCGCC,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,ATTAAGCCNG,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,GGNCGCACCC,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,AANCACANGT,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,AANTAAGCAT,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,ANAGGANAAC,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,GGNCGCACCC,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,CAATTCCGGC,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,AGGTTTACTG,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,CCGNTGTGGC,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,ACGCTATGTA,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,CTCCTGTGGC,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,GTTGTTTATT,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,CGAAGAGAAC,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,CGAAGAGAAC,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,GCGGCCATTC,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,AGTAANGCAA,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,CGAAGAGAAC,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,GTCAACCGGG,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,CTCAATACTA,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,ATTAAGCCNG,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,TACTGTGCTA,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,ANGCACTCGA,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,NGCTGGCACG,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,TAGTATGGAA,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,TCGCGATGNT,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,ANAGGANAAC,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,ACAGTAAATG,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,GGTCTAAGAG,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,CGAGGTGCGN,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,ACTGGTCGCT,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,GTTGTTTATT,CGTAGCTCGAAAAAAAAAAAAAAA +TACATATTCTTTACTG,AAGGCACTTT,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,TGACATCAAC,CGTACGTAGCCTAGCAAAAAAAAA +TACATATTCTTTACTG,TGCAGAAANG,CGTCGTAGCTGATCGAAAAAAAAA +TACATATTCTTTACTG,CTTCAANTGA,CGTAGCTCGAAAAAAAAAAAAAAA diff --git a/tests/test_processing.py b/tests/test_processing.py index e24e43e..ff59c8d 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -1,7 +1,7 @@ import pytest import random import copy -from collections import Counter +from collections import Counter, namedtuple from cite_seq_count import processing from cite_seq_count import preprocessing @@ -101,6 +101,17 @@ def data(): pytest.tags_tuple = preprocessing.check_tags( preprocessing.parse_tags_csv("tests/test_data/tags/pass/correct.csv"), 5 )[0] + pytest.mapping_input = namedtuple( + "mapping_input", + ["filename", "tags", "debug", "maximum_distance", "sliding_window"], + ) + pytest.mappint_input_test = pytest.mapping_input( + filename=pytest.file_path, + tags=pytest.tags_tuple, + debug=pytest.debug, + maximum_distance=pytest.maximum_distance, + sliding_window=pytest.sliding_window, + ) @pytest.mark.dependency() @@ -160,15 +171,8 @@ def test_find_best_match_with_3_distance_reverse(data): ] ) def test_classify_reads_multi_process(data): - (results, _) = processing.map_reads( - ( - pytest.file_path, - pytest.tags_tuple, - pytest.debug, - pytest.maximum_distance, - pytest.sliding_window, - ) - ) + (results, _) = processing.map_reads(pytest.mappint_input_test) + print(results) assert len(results) == 2 From 507afeb6e2fc6d49035b5bd778fe66c42ce9bfff Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Thu, 24 Dec 2020 17:01:33 +0100 Subject: [PATCH 26/77] fixed procssing and io tests --- tests/test_io.py | 17 +++++++------- tests/test_processing.py | 49 ++++++++++++++++++++++++++-------------- 2 files changed, 41 insertions(+), 25 deletions(-) diff --git a/tests/test_io.py b/tests/test_io.py index 48b1960..d1fa97d 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,5 +1,6 @@ import pytest from cite_seq_count import io +from collections import namedtuple @pytest.fixture @@ -11,14 +12,14 @@ def data(): test_matrix[1, 1] = 1 pytest.sparse_matrix = test_matrix pytest.filtered_cells = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) - pytest.ordered_tags_map = OrderedDict( - { - "test3": {"id": 0, "sequence": "CGTA"}, - "test2": {"id": 1, "sequence": "CGTA"}, - "test1": {"id": 3, "sequence": "CGTA"}, - "unmapped": {"id": 4, "sequence": "CGTA"}, - } - ) + tag = namedtuple("tag", ["name", "sequence", "id"]) + pytest.ordered_tags_map = [ + tag(name="test1", sequence="CGTA", id=0), + tag(name="test2", sequence="CGTA", id=1), + tag(name="test3", sequence="CGTA", id=2), + tag(name="unmapped", sequence="UNKNOWN", id=3), + ] + pytest.data_type = "umi" pytest.outfolder = "tests/test_data/" diff --git a/tests/test_processing.py b/tests/test_processing.py index ff59c8d..0053c03 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -63,13 +63,13 @@ def data(): pytest.file_path = "tests/test_data/fastq/test_csv.csv" pytest.chunk_size = 800 - pytest.tags = OrderedDict( - { - "test2": {"id": 0, "sequence": "CGTACGTAGCCTAGC"}, - "test1": {"id": 1, "sequence": "CGTAGCTCG"}, - "unmapped": {"id": 2, "sequence": "UNKNOWN"}, - } - ) + tag = namedtuple("tag", ["name", "sequence", "id"]) + pytest.tags = [ + tag(name="test1", sequence="CGTACGTAGCCTAGC", id=0), + tag(name="test2", sequence="CGTAGCTCG", id=1), + tag(name="unmapped", sequence="UNKNOWN", id=3), + ] + pytest.barcode_slice = slice(0, 16) pytest.umi_slice = slice(16, 26) pytest.correct_whitelist = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) @@ -79,15 +79,15 @@ def data(): pytest.maximum_distance = 5 pytest.results = { "ACTGTTTTATTGGCCT": { - "test1": Counter({b"CATTAGTGGT": 3, b"CATTAGTGGG": 2, b"CATTCGTGGT": 1}) + 0: Counter({b"CATTAGTGGT": 3, b"CATTAGTGGG": 2, b"CATTCGTGGT": 1}) }, "TTCATAAGGTAGGGAT": { - "test2": Counter({b"TAGCTTAGTA": 3, b"TAGCTTAGTC": 2, b"GCGATGCATA": 1}) + 1: Counter({b"TAGCTTAGTA": 3, b"TAGCTTAGTC": 2, b"GCGATGCATA": 1}) }, } pytest.corrected_results = { - "ACTGTTTTATTGGCCT": {"test1": Counter({b"CATTAGTGGT": 6})}, - "TTCATAAGGTAGGGAT": {"test2": Counter({b"TAGCTTAGTA": 5, b"GCGATGCATA": 1})}, + "ACTGTTTTATTGGCCT": {0: Counter({b"CATTAGTGGT": 6})}, + "TTCATAAGGTAGGGAT": {1: Counter({b"TAGCTTAGTA": 5, b"GCGATGCATA": 1})}, } pytest.umis_per_cell = Counter({"ACTGTTTTATTGGCCT": 1, "TTCATAAGGTAGGGAT": 2}) pytest.reads_per_cell = Counter({"ACTGTTTTATTGGCCT": 3, "TTCATAAGGTAGGGAT": 6}) @@ -205,17 +205,32 @@ def test_correct_cells(data): @pytest.mark.dependency(depends=["test_correct_umis"]) -def test_generate_sparse_matrices(data): - (umi_results_matrix, read_results_matrix) = processing.generate_sparse_matrices( +def test_generate_sparse_umi_matrices(data): + umi_results_matrix = processing.generate_sparse_matrices( pytest.corrected_results, pytest.tags, set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]), + umi_counts=True, ) assert umi_results_matrix.shape == (3, 2) + total_umis = 0 + for i in range(umi_results_matrix.shape[0]): + for j in range(umi_results_matrix.shape[1]): + total_umis += umi_results_matrix[i, j] + assert total_umis == 3 + + +@pytest.mark.dependency(depends=["test_correct_umis"]) +def test_generate_sparse_read_matrices(data): + read_results_matrix = processing.generate_sparse_matrices( + pytest.corrected_results, + pytest.tags, + set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]), + umi_counts=False, + ) assert read_results_matrix.shape == (3, 2) - read_results_matrix = read_results_matrix.tocsr() - total_reads = 0 + total_umis = 0 for i in range(read_results_matrix.shape[0]): for j in range(read_results_matrix.shape[1]): - total_reads += read_results_matrix[i, j] - assert total_reads == 12 + total_umis += read_results_matrix[i, j] + assert total_umis == 12 From e507c78f96e90498e7894395f8ed4eed728915f9 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Fri, 25 Dec 2020 11:27:12 +0100 Subject: [PATCH 27/77] fixed mapping --- cite_seq_count/__main__.py | 21 ++++++++++----------- cite_seq_count/processing.py | 15 ++++++++++----- setup.py | 2 +- 3 files changed, 21 insertions(+), 17 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 99b23f3..8c42c9a 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -108,7 +108,7 @@ def main(): read1_paths=read1_paths, read2_paths=read2_paths, R2_min_length=R2_min_length, - n_reads_to_chunk=n_reads, + n_reads_per_chunk=n_reads, chemistry_def=chemistry_def, ordered_tags=ordered_tags, maximum_distance=maximum_distance, @@ -310,17 +310,16 @@ def main(): ordered_tags=ordered_tags, filtered_cells=aberrant_cells, ) - - # Write uncorrected cells to dense output - io.write_dense( - sparse_matrix=umi_aberrant_matrix, - ordered_tags=ordered_tags, - columns=aberrant_cells, - outfolder=os.path.join(args.outfolder, "uncorrected_cells"), - filename="dense_umis.tsv", - ) + if len(umi_aberrant_matrix) > 0: + # Write uncorrected cells to dense output + io.write_dense( + sparse_matrix=umi_aberrant_matrix, + ordered_tags=ordered_tags, + columns=aberrant_cells, + outfolder=os.path.join(args.outfolder, "uncorrected_cells"), + filename="dense_umis.tsv", + ) # delete the last element (unmapped) - ordered_tags.pop() umi_results_matrix = processing.generate_sparse_matrices( final_results=final_results, ordered_tags=ordered_tags, diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 503eea5..723a536 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -42,7 +42,7 @@ def find_best_match(TAG_seq, tags, maximum_distance): Returns: best_match (string): The TAG name that will be used for counting. """ - best_match = len(tags) - 1 + best_match = len(tags) best_score = maximum_distance for tag in tags: # pylint: disable=no-member @@ -102,6 +102,8 @@ def map_reads(mapping_input): no_match = Counter() n = 1 t = time.time() + unmapped_id = len(tags) - 1 + del tags[-1] # Progress info with open(filename, "r") as input_file: reads = csv.reader(input_file) @@ -130,7 +132,7 @@ def map_reads(mapping_input): results[cell_barcode][best_match][UMI] += 1 - if best_match == "unmapped": + if best_match == unmapped_id: no_match[read2] += 1 if debug: @@ -144,7 +146,7 @@ def map_reads(mapping_input): len(cell_barcode), len(UMI), len(read2), - tags[best_match].id, + tags[best_match].name, ) ) sys.stdout.flush() @@ -455,13 +457,16 @@ def generate_sparse_matrices( """ - unmapped_id = len(ordered_tags) + unmapped_id = len(ordered_tags) - 1 + if umi_counts: + del ordered_tags[-1] results_matrix = sparse.dok_matrix( (len(ordered_tags), len(filtered_cells)), dtype=int32 ) # print(ordered_tags) + for i, cell_barcode in enumerate(filtered_cells): - if cell_barcode not in final_results: + if cell_barcode not in final_results.keys(): continue for TAG_id in final_results[cell_barcode]: # if TAG_id in final_results[cell_barcode]: diff --git a/setup.py b/setup.py index 0eebf17..02bc519 100644 --- a/setup.py +++ b/setup.py @@ -23,7 +23,7 @@ "scipy>=1.1.0", "multiprocess>=0.70.6.1", "umi_tools==1.0.0", - "pytest==4.1.0", + "pytest>=6.0.0", "pytest-dependency==0.4.0", "pandas>=0.23.4", "pybktree==1.1", From f309737a42e9ed491eddc00fadc8600ca6d9ce35 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Fri, 25 Dec 2020 11:41:21 +0100 Subject: [PATCH 28/77] updated changelog --- CHANGELOG.md | 1 + 1 file changed, 1 insertion(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index c62318e..aac0a3a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -29,6 +29,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Tags csv file now requires a header with at least "sequence" and "feature_name". - Updated tags file parsing to make it more reliable. - Added new tests to help out contributions. + - If no clustered cells found, the dense output matrix for that will not be written. ### Removed - Unmmapped reads are not umi corrected anymore reducing run time and memory usage. From 47cf4e46342c288f6051c12f1e962ce2edb44a72 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Mon, 28 Dec 2020 10:13:13 +0100 Subject: [PATCH 29/77] changed all whitelist to reference list --- CHANGELOG.md | 8 ++-- cite_seq_count/__main__.py | 28 ++++++----- cite_seq_count/argsparser.py | 12 ++--- cite_seq_count/chemistry.py | 34 +++++++------- cite_seq_count/io.py | 14 +++++- cite_seq_count/preprocessing.py | 47 ++++++++++++++----- cite_seq_count/processing.py | 43 +++++++++-------- .../correct.csv | 1 + tests/test_io.py | 4 +- tests/test_preprocessing.py | 13 +++-- tests/test_processing.py | 11 ++--- 11 files changed, 126 insertions(+), 89 deletions(-) rename tests/test_data/{whitelists => reference_lists}/correct.csv (76%) diff --git a/CHANGELOG.md b/CHANGELOG.md index aac0a3a..8434d65 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,14 +13,14 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - (BETA) New functionnality that will fetch the chemistry definition from a remote repo to simplify usage and reduce human errors. ### Changed - - The `features.csv` now has different columns for the tag name and the tag sequence. This keeps the relevant information + - The `features.tsv` now has different columns for the tag name and the tag sequence. This keeps the relevant information in the output files as well as simplifies reading the mtx format when processing the data. - - The mapping step has been changed. It will first write the reads to files and then read in the chunks. + - The mapping step has been changed. It will first write chunks of reads to files and then read in the chunks in each child process. This should solve the io bottleneck from before. - There are new options now for parallel computing. `--chunk_size` Determines how many reads will be read per chunk. #99 - `--sliding-window` now only checks for exact matches. - Added cython dependency based on issue #117 - - The main results dict will now use an `int` as keys reducing memory footprint. + - The main results dict now uses an `int` as keys reducing memory footprint. - Fixed the issue #92 with using `--bc_collapsing_dist 0`. - Fixed issue #122 and now properly checks number of files. - Fixed the error in the documentation pointed by issue #132. @@ -29,7 +29,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Tags csv file now requires a header with at least "sequence" and "feature_name". - Updated tags file parsing to make it more reliable. - Added new tests to help out contributions. - - If no clustered cells found, the dense output matrix for that will not be written. + - If no clustered cells found, the dense output matrix will not be written. ### Removed - Unmmapped reads are not umi corrected anymore reducing run time and memory usage. diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 8c42c9a..6daaefd 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -45,7 +45,7 @@ def main(): assert os.access(args.temp_path, os.W_OK) # Get chemistry defs - (whitelist, chemistry_def) = chemistry.setup_chemistry(args) + (reference_dict, chemistry_def) = chemistry.setup_chemistry(args) # Load TAGs/ABs. ab_map = preprocessing.parse_tags_csv(args.tags) @@ -130,6 +130,7 @@ def main(): error_callback=errors.append, ) mapping.wait() + pool.close() pool.join() if len(errors) != 0: @@ -170,7 +171,7 @@ def main(): bcs_corrected = 0 else: print("Correcting cell barcodes") - if not whitelist: + if not reference_dict: ( final_results, umis_per_cell, @@ -188,10 +189,10 @@ def main(): final_results, umis_per_cell, bcs_corrected, - ) = processing.correct_cells_whitelist( + ) = processing.correct_cells_reference_list( final_results=final_results, umis_per_cell=umis_per_cell, - whitelist=whitelist, + reference_list=set(reference_dict.keys()), collapsing_threshold=args.bc_threshold, ab_map=ordered_tags, ) @@ -199,11 +200,11 @@ def main(): print("Skipping cell barcode correction") bcs_corrected = 0 - # If given, use whitelist for top cells + # If given, use reference_list for top cells top_cells_tuple = umis_per_cell.most_common(args.expected_cells * 10) - if whitelist: + if reference_dict: # Add potential missing cell barcodes. - # for missing_cell in whitelist: + # for missing_cell in reference_list: # if missing_cell in final_results: # continue # else: @@ -211,14 +212,15 @@ def main(): # for TAG in ordered_tags: # final_results[missing_cell][TAG.safe_name] = Counter() # filtered_cells.add(missing_cell) - top_cells = set([pair[0] for pair in top_cells_tuple]) - filtered_cells = set() + top_cells = [pair[0] for pair in top_cells_tuple] + filtered_cells = [] for cell in top_cells: - if cell in whitelist: - filtered_cells.add(cell) + # pylint: disable=no-member + if cell in reference_dict.keys(): + filtered_cells.append(cell) else: # Select top cells based on total umis per cell - filtered_cells = set([pair[0] for pair in top_cells_tuple]) + filtered_cells = [pair[0] for pair in top_cells_tuple] # Create sparse matrices for reads results read_results_matrix = processing.generate_sparse_matrices( @@ -233,6 +235,7 @@ def main(): ordered_tags=ordered_tags, data_type="read", outfolder=args.outfolder, + reference_dict=reference_dict, ) # UMI correction @@ -334,6 +337,7 @@ def main(): ordered_tags=ordered_tags, data_type="umi", outfolder=args.outfolder, + reference_dict=reference_dict, ) # Write unmapped sequences diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 034d9e9..6dcd12b 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -143,7 +143,7 @@ def get_args(): ) # Cells group cells = parser.add_argument_group( - "Cells", description=("Expected number of cells and potential whitelist") + "Cells", description=("Expected number of cells and potential reference_list") ) cells.add_argument( @@ -158,12 +158,12 @@ def get_args(): if "--chemistry" not in sys.argv: cells.add_argument( "-wl", - "--whitelist", - dest="whitelist", + "--reference_list", + dest="reference_list", required=False, type=str, help=( - "A csv file containning a whitelist of barcodes produced" + "A csv file containning a reference_list of barcodes produced" " by the mRNA data.\n\n" "\tExample:\n" "\tATGCTAGTGCTA\n\tGCTAGTCAGGAT\n\tCGACTGCTAACG\n\n" @@ -177,8 +177,8 @@ def get_args(): required=False, type=str, help="A csv file containing the mapping between two sets of cell barcode list.\n" - "A required header such as the reference is named whitelist. Example:\n\n" - "\twhitelist,translation\n" + "A required header such as the reference is named reference_list. Example:\n\n" + "\treference_list,translation\n" "\tAAACCCAAGAAACACT,AAACCCATCAAACACT\n" "\tAAACCCAAGAAACCAT,AAACCCATCAAACCAT\n" "\nThe output matrix will possess both cell barcode IDs", diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index f279e49..d3c051a 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -29,7 +29,7 @@ class Chemistry: umi_barcode_start: int umi_barcode_end: int R2_trim_start: int - whitelist_path: str + reference_list_path: str mapping_required: bool @@ -90,17 +90,19 @@ def get_chemistry_definition(chemistry_short_name): """ chemistry_defs = fetch_definitions()[chemistry_short_name] - if chemistry_defs["whitelist"]["path"] not in DEFINITIONS_DB.registry: + if chemistry_defs["reference_list"]["path"] not in DEFINITIONS_DB.registry: path = pooch.retrieve( url=os.path.join( - GLOBAL_LINK_GITHUB, "chemistries", chemistry_defs["whitelist"]["path"] + GLOBAL_LINK_GITHUB, + "chemistries", + chemistry_defs["reference_list"]["path"], ), known_hash=None, - fname=chemistry_defs["whitelist"]["path"], + fname=chemistry_defs["reference_list"]["path"], path=DEFINITIONS_DB.abspath, ) else: - path = DEFINITIONS_DB.registry[chemistry_defs["whitelist"]["path"]] + path = DEFINITIONS_DB.registry[chemistry_defs["reference_list"]["path"]] chemistry_def = Chemistry( name=chemistry_short_name, cell_barcode_start=chemistry_defs["barcode_structure_indexes"]["cell_barcode"][ @@ -116,8 +118,8 @@ def get_chemistry_definition(chemistry_short_name): "R1" ]["stop"], R2_trim_start=chemistry_defs["sequence_structure_indexes"]["R2"]["start"] - 1, - whitelist_path=path, - mapping_required=chemistry_defs["whitelist"]["mapping"], + reference_list_path=path, + mapping_required=chemistry_defs["reference_list"]["mapping"], ) return chemistry_def @@ -130,7 +132,7 @@ def create_chemistry_definition(args): umi_barcode_start=args.umi_first, umi_barcode_end=args.umi_last, R2_trim_start=args.start_trim, - whitelist_path=args.whitelist, + reference_list_path=args.reference_list, mapping_required=args.translation, ) return chemistry_def @@ -139,20 +141,20 @@ def create_chemistry_definition(args): def setup_chemistry(args): if args.chemistry: chemistry_def = get_chemistry_definition(args.chemistry) - whitelist = preprocessing.parse_whitelist_csv( - filename=chemistry_def.whitelist_path, + reference_dict = preprocessing.parse_reference_list_csv( + filename=chemistry_def.reference_list_path, barcode_length=chemistry_def.cell_barcode_end - chemistry_def.cell_barcode_start + 1, ) else: chemistry_def = create_chemistry_definition(args) - if args.whitelist: - print("Loading whitelist") - whitelist = preprocessing.parse_whitelist_csv( - filename=args.whitelist, + if args.reference_list: + print("Loading reference_list") + reference_dict = preprocessing.parse_reference_list_csv( + filename=args.reference_list, barcode_length=args.cb_last - args.cb_first + 1, ) else: - whitelist = False - return (whitelist, chemistry_def) + reference_dict = False + return (reference_dict, chemistry_def) diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index d8c45cf..435fe76 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -10,10 +10,13 @@ import pandas as pd from scipy import io +import numpy as np from cite_seq_count import secondsToText -def write_to_files(sparse_matrix, filtered_cells, ordered_tags, data_type, outfolder): +def write_to_files( + sparse_matrix, filtered_cells, ordered_tags, data_type, outfolder, reference_dict +): """Write the umi and read sparse matrices to file in gzipped mtx format. Args: @@ -28,12 +31,19 @@ def write_to_files(sparse_matrix, filtered_cells, ordered_tags, data_type, outfo io.mmwrite(os.path.join(prefix, "matrix.mtx"), sparse_matrix) with gzip.open(os.path.join(prefix, "barcodes.tsv.gz"), "wb") as barcode_file: for barcode in filtered_cells: - barcode_file.write("{}\n".format(barcode).encode()) + if reference_dict[barcode] != 0: + barcode_file.write( + "{}\t{}\n".format(barcode, reference_dict[barcode]).encode(), + ) + else: + barcode_file.write("{}\n".format(barcode).encode()) with gzip.open(os.path.join(prefix, "features.tsv.gz"), "wb") as feature_file: for feature in ordered_tags: feature_file.write( "{}\t{}\n".format(feature.sequence, feature.name).encode() ) + if data_type == "read": + feature_file.write("{}\t{}\n".format("UNKNOWN", "unmapped").encode()) with open(os.path.join(prefix, "matrix.mtx"), "rb") as mtx_in: with gzip.open(os.path.join(prefix, "matrix.mtx") + ".gz", "wb") as mtx_gz: shutil.copyfileobj(mtx_in, mtx_gz) diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 6096d2c..5c097ff 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -12,14 +12,14 @@ from itertools import islice -def parse_whitelist_csv(filename, barcode_length): +def parse_reference_list_csv(filename, barcode_length): """Reads white-listed barcodes from a CSV file. The function accepts plain barcodes or even 10X style barcodes with the `-1` at the end of each barcode. Args: - filename (str): Whitelist barcode file. + filename (str): reference_list barcode file. barcode_length (int): Length of the expected barcodes. Returns: @@ -27,6 +27,9 @@ def parse_whitelist_csv(filename, barcode_length): """ STRIP_CHARS = '"0123456789- \t\n' + REQUIRED_HEADER = ["reference"] + # OPTIONAL_HEADER = ["translation"] + cell_pattern = regex.compile(r"[ATGC]{{{}}}".format(barcode_length)) if filename.endswith(".gz"): @@ -35,22 +38,42 @@ def parse_whitelist_csv(filename, barcode_length): else: f = open(filename, encoding="UTF-8") csv_reader = csv.reader(f) - whitelist = [ - row[0].strip(STRIP_CHARS) - for row in csv_reader - if (len(row[0].strip(STRIP_CHARS)) == barcode_length) - ] - for cell_barcode in whitelist: + header = next(csv_reader) + set_dif = set(REQUIRED_HEADER) - set(header) + if len(set_dif) != 0: + raise SystemExit( + "The header is missing {}. Exiting".format(",".join(list(set_dif))) + ) + + reference_id = header.index("reference") + reference_dict = {} + if "translation" in header: + + translation_id = header.index("translation") + for row in csv_reader: + ref_barcode = row[reference_id].strip(STRIP_CHARS) + tra_barcode = row[translation_id].strip(STRIP_CHARS) + if ( + len(ref_barcode) == barcode_length + and len(tra_barcode) == barcode_length + ): + reference_dict[ref_barcode] = tra_barcode + else: + for row in csv_reader: + ref_barcode = row[reference_id].strip(STRIP_CHARS) + if len(ref_barcode) == barcode_length: + reference_dict[ref_barcode] = 0 + for cell_barcode in reference_dict.keys(): if not cell_pattern.match(cell_barcode): sys.exit( "This barcode {} is not only composed of ATGC bases.".format( cell_barcode ) ) - if len(whitelist) == 0: - sys.exit("Whitelist is empty.") - return set(whitelist) + if len(reference_dict) == 0: + sys.exit("reference_dict is empty.") + return reference_dict def parse_tags_csv(filename): @@ -128,7 +151,7 @@ def check_tags(tags, maximum_distance): if len(tag_seq) > longest_tag_len: longest_tag_len = len(tag_seq) seq_list.append(tag_seq) - tag_list.append(tag(name="unmapped", sequence="UNKNOWN", id=i + 1,)) + # tag_list.append(tag(name="unmapped", sequence="UNKNOWN", id=i + 1,)) # If only one TAG is provided, then no distances to compare. if len(tags) == 1: return (tag_list, longest_tag_len) diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 723a536..1d17b32 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -103,7 +103,6 @@ def map_reads(mapping_input): n = 1 t = time.time() unmapped_id = len(tags) - 1 - del tags[-1] # Progress info with open(filename, "r") as input_file: reads = csv.reader(input_file) @@ -343,7 +342,7 @@ def correct_cells( umis_per_cell (Counter): Counter of umis per cell after cell barcode correction corrected_umis (int): How many umis have been corrected. """ - print("Looking for a whitelist") + print("Looking for a reference_list") _, true_to_false = whitelist_methods.getCellWhitelist( cell_barcode_counts=reads_per_cell, expect_cells=expected_cells, @@ -361,16 +360,16 @@ def correct_cells( return (final_results, umis_per_cell, corrected_barcodes) -def correct_cells_whitelist( - final_results, umis_per_cell, whitelist, collapsing_threshold, ab_map +def correct_cells_reference_list( + final_results, umis_per_cell, reference_list, collapsing_threshold, ab_map ): """ - Corrects cell barcodes. + Corrects cell barcodes based on a given reference_list. Args: final_results (dict): Dict of dict of Counters with mapping results. umis_per_cell (Counter): Counter of UMIs per cell. - whitelist (set): The whitelist reference given by the user. + reference_list (set): The reference_list reference given by the user. collapsing_threshold (int): Max distance between umis. ab_map (OrederedDict): Tags in an ordered dict. @@ -380,18 +379,18 @@ def correct_cells_whitelist( umis_per_cell (Counter): Updated UMI counts after correction. corrected_barcodes (int): How many umis have been corrected. """ + print("Generating barcode tree from reference list") # pylint: disable=no-member - barcode_tree = pybktree.BKTree(Levenshtein.hamming, whitelist) - print("Generated barcode tree from whitelist") + barcode_tree = pybktree.BKTree(Levenshtein.hamming, reference_list) barcodes = set(final_results.keys()) - print("Finding reference candidates") + print("Selecting reference candidates") print("Processing {:,} cell barcodes".format(len(barcodes))) # Run with one process true_to_false = find_true_to_false_map( barcode_tree=barcode_tree, cell_barcodes=barcodes, - whitelist=whitelist, + reference_list=reference_list, collapsing_threshold=collapsing_threshold, ) print("Collapsing wrong barcodes with original barcodes") @@ -402,7 +401,7 @@ def correct_cells_whitelist( def find_true_to_false_map( - barcode_tree, cell_barcodes, whitelist, collapsing_threshold + barcode_tree, cell_barcodes, reference_list, collapsing_threshold ): """ Creates a mapping between "fake" cell barcodes and their original true barcode. @@ -410,7 +409,7 @@ def find_true_to_false_map( Args: barcode_tree (BKTree): BKTree of all original cell barcodes. cell_barcodes (List): Cell barcodes to go through. - whitelist (Set): Set of the whitelist, the "true" cell barcodes. + reference_list (Set): Set of the reference_list, the "true" cell barcodes. collasping_threshold (int): How many mistakes to correct. Return: @@ -418,10 +417,10 @@ def find_true_to_false_map( """ true_to_false = defaultdict(list) for cell_barcode in cell_barcodes: - if cell_barcode in whitelist: - # if the barcode is already whitelisted, no need to add + if cell_barcode in reference_list: + # if the barcode is already reference_listed, no need to add continue - # get all members of whitelist that are at distance of collapsing_threshold + # get all members of reference_list that are at distance of collapsing_threshold candidates = [ white_cell for d, white_cell in barcode_tree.find(cell_barcode, collapsing_threshold) @@ -431,12 +430,12 @@ def find_true_to_false_map( white_cell_str = candidates[0] true_to_false[white_cell_str].append(cell_barcode) elif len(candidates) == 0: - # the cell doesnt match to any whitelisted barcode, + # the cell doesnt match to any reference_listed barcode, # hence we have to drop it # (as it cannot be asscociated with any frequent barcode) continue else: - # more than on whitelisted candidate: + # more than on reference_listed candidate: # we drop it as its not uniquely assignable continue return true_to_false @@ -457,12 +456,12 @@ def generate_sparse_matrices( """ - unmapped_id = len(ordered_tags) - 1 + unmapped_id = len(ordered_tags) if umi_counts: - del ordered_tags[-1] - results_matrix = sparse.dok_matrix( - (len(ordered_tags), len(filtered_cells)), dtype=int32 - ) + n_features = len(ordered_tags) + else: + n_features = len(ordered_tags) + 1 + results_matrix = sparse.dok_matrix((n_features, len(filtered_cells)), dtype=int32) # print(ordered_tags) for i, cell_barcode in enumerate(filtered_cells): diff --git a/tests/test_data/whitelists/correct.csv b/tests/test_data/reference_lists/correct.csv similarity index 76% rename from tests/test_data/whitelists/correct.csv rename to tests/test_data/reference_lists/correct.csv index 9de44b6..0b5a0fd 100644 --- a/tests/test_data/whitelists/correct.csv +++ b/tests/test_data/reference_lists/correct.csv @@ -1,2 +1,3 @@ +reference ACTGTTTTATTGGCCT TTCATAAGGTAGGGAT \ No newline at end of file diff --git a/tests/test_io.py b/tests/test_io.py index d1fa97d..7245518 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -11,7 +11,7 @@ def data(): test_matrix = sparse.dok_matrix((4, 2)) test_matrix[1, 1] = 1 pytest.sparse_matrix = test_matrix - pytest.filtered_cells = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) + pytest.filtered_cells = ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"] tag = namedtuple("tag", ["name", "sequence", "id"]) pytest.ordered_tags_map = [ tag(name="test1", sequence="CGTA", id=0), @@ -28,12 +28,14 @@ def test_write_to_files(data, tmpdir): import gzip import scipy + reference_dict = {"ACTGTTTTATTGGCCT": 0, "TTCATAAGGTAGGGAT": 0} io.write_to_files( pytest.sparse_matrix, pytest.filtered_cells, pytest.ordered_tags_map, pytest.data_type, tmpdir, + reference_dict=reference_dict, ) file = tmpdir.join("umi_count/matrix.mtx.gz") with gzip.open(file, "rb") as mtx_file: diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 910a680..dcd2a97 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -14,7 +14,7 @@ def data(): pytest.failing_csv = "tests/test_data/tags/fail/*.csv" # Test file paths - pytest.correct_whitelist_path = "tests/test_data/whitelists/correct.csv" + pytest.correct_reference_list_path = "tests/test_data/reference_lists/correct.csv" pytest.correct_tags_path = "tests/test_data/tags/pass/correct.csv" pytest.correct_R1_path = "tests/test_data/fastq/correct_R1.fastq.gz" pytest.correct_R2_path = "tests/test_data/fastq/correct_R2.fastq.gz" @@ -33,7 +33,7 @@ def data(): ) # Create some variables to compare to - pytest.correct_whitelist = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) + pytest.correct_reference_list = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) pytest.correct_tags = { "AGGACCATCCAA": "CITE_LEN_12_1", "ACATGTTACCGT": "CITE_LEN_12_2", @@ -75,11 +75,10 @@ def test_csv_parser(data): @pytest.mark.dependency() -def test_parse_whitelist_csv(data): - assert preprocessing.parse_whitelist_csv(pytest.correct_whitelist_path, 16) in ( - pytest.correct_whitelist, - 1, - ) +def test_parse_reference_list_csv(data): + assert preprocessing.parse_reference_list_csv( + pytest.correct_reference_list_path, 16 + ).keys() in (pytest.correct_reference_list, 1,) @pytest.mark.dependency() diff --git a/tests/test_processing.py b/tests/test_processing.py index 0053c03..3c97288 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -67,12 +67,11 @@ def data(): pytest.tags = [ tag(name="test1", sequence="CGTACGTAGCCTAGC", id=0), tag(name="test2", sequence="CGTAGCTCG", id=1), - tag(name="unmapped", sequence="UNKNOWN", id=3), ] pytest.barcode_slice = slice(0, 16) pytest.umi_slice = slice(16, 26) - pytest.correct_whitelist = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) + pytest.correct_reference_list = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) pytest.legacy = False pytest.debug = False pytest.start_trim = 0 @@ -143,8 +142,6 @@ def test_find_best_match_with_3_distance(data): distance = 3 for tag in pytest.tags_tuple: counts = Counter() - if tag.name == "unmapped": - continue for seq in extend_seq_pool(tag.sequence, distance): counts[processing.find_best_match(seq, pytest.tags_tuple, distance)] += 1 assert counts[tag.id] == 4 @@ -209,10 +206,10 @@ def test_generate_sparse_umi_matrices(data): umi_results_matrix = processing.generate_sparse_matrices( pytest.corrected_results, pytest.tags, - set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]), + ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], umi_counts=True, ) - assert umi_results_matrix.shape == (3, 2) + assert umi_results_matrix.shape == (2, 2) total_umis = 0 for i in range(umi_results_matrix.shape[0]): for j in range(umi_results_matrix.shape[1]): @@ -225,7 +222,7 @@ def test_generate_sparse_read_matrices(data): read_results_matrix = processing.generate_sparse_matrices( pytest.corrected_results, pytest.tags, - set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]), + ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], umi_counts=False, ) assert read_results_matrix.shape == (3, 2) From 1f8f0b6d56bbe4759377cca85e05ffce4fb0c190 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Mon, 28 Dec 2020 10:43:20 +0100 Subject: [PATCH 30/77] some more tests for reference_lists --- .../reference_lists/fail/bad_barcode.csv | 3 +++ .../fail/duplicated_barcodes.csv | 3 +++ .../reference_lists/fail/missing_header.csv | 2 ++ .../fail/wrong_barcode_length.csv | 3 +++ .../{correct.csv => pass/simple_ref.csv} | 0 .../reference_lists/pass/translation.csv | 3 +++ tests/test_preprocessing.py | 20 ++++++++++++------- tests/test_processing.py | 2 +- 8 files changed, 28 insertions(+), 8 deletions(-) create mode 100644 tests/test_data/reference_lists/fail/bad_barcode.csv create mode 100644 tests/test_data/reference_lists/fail/duplicated_barcodes.csv create mode 100644 tests/test_data/reference_lists/fail/missing_header.csv create mode 100644 tests/test_data/reference_lists/fail/wrong_barcode_length.csv rename tests/test_data/reference_lists/{correct.csv => pass/simple_ref.csv} (100%) create mode 100644 tests/test_data/reference_lists/pass/translation.csv diff --git a/tests/test_data/reference_lists/fail/bad_barcode.csv b/tests/test_data/reference_lists/fail/bad_barcode.csv new file mode 100644 index 0000000..d2636f3 --- /dev/null +++ b/tests/test_data/reference_lists/fail/bad_barcode.csv @@ -0,0 +1,3 @@ +reference +ACTGTTTTXTTGGCCT +TTCATAAGGTAGGGAT \ No newline at end of file diff --git a/tests/test_data/reference_lists/fail/duplicated_barcodes.csv b/tests/test_data/reference_lists/fail/duplicated_barcodes.csv new file mode 100644 index 0000000..c52c85b --- /dev/null +++ b/tests/test_data/reference_lists/fail/duplicated_barcodes.csv @@ -0,0 +1,3 @@ +reference +TTCATAAGGTAGGGAT +TTCATAAGGTAGGGAT \ No newline at end of file diff --git a/tests/test_data/reference_lists/fail/missing_header.csv b/tests/test_data/reference_lists/fail/missing_header.csv new file mode 100644 index 0000000..9de44b6 --- /dev/null +++ b/tests/test_data/reference_lists/fail/missing_header.csv @@ -0,0 +1,2 @@ +ACTGTTTTATTGGCCT +TTCATAAGGTAGGGAT \ No newline at end of file diff --git a/tests/test_data/reference_lists/fail/wrong_barcode_length.csv b/tests/test_data/reference_lists/fail/wrong_barcode_length.csv new file mode 100644 index 0000000..7332848 --- /dev/null +++ b/tests/test_data/reference_lists/fail/wrong_barcode_length.csv @@ -0,0 +1,3 @@ +reference +ACTGTTTTATTGGC +TTCATAAGGTAGGG \ No newline at end of file diff --git a/tests/test_data/reference_lists/correct.csv b/tests/test_data/reference_lists/pass/simple_ref.csv similarity index 100% rename from tests/test_data/reference_lists/correct.csv rename to tests/test_data/reference_lists/pass/simple_ref.csv diff --git a/tests/test_data/reference_lists/pass/translation.csv b/tests/test_data/reference_lists/pass/translation.csv new file mode 100644 index 0000000..9c151a9 --- /dev/null +++ b/tests/test_data/reference_lists/pass/translation.csv @@ -0,0 +1,3 @@ +reference,translation +ACTGTTTTATTGGCCT,ACTGTTTTATTGGCCT +TTCATAAGGTAGGGAT,TTCATCCTTTAGGGAT \ No newline at end of file diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index dcd2a97..aa3dcf4 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -13,8 +13,9 @@ def data(): pytest.passing_csv = "tests/test_data/tags/pass/*.csv" pytest.failing_csv = "tests/test_data/tags/fail/*.csv" - # Test file paths - pytest.correct_reference_list_path = "tests/test_data/reference_lists/correct.csv" + pytest.passing_reference_list_csv = "tests/test_data/reference_lists/pass/*.csv" + pytest.failing_reference_list_csv = "tests/test_data/reference_lists/fail/*.csv" + pytest.correct_tags_path = "tests/test_data/tags/pass/correct.csv" pytest.correct_R1_path = "tests/test_data/fastq/correct_R1.fastq.gz" pytest.correct_R2_path = "tests/test_data/fastq/correct_R2.fastq.gz" @@ -56,7 +57,6 @@ def data(): tag(name="CITE_LEN_12_1", sequence="AGGACCATCCAA", id=6), tag(name="CITE_LEN_12_2", sequence="ACATGTTACCGT", id=7), tag(name="CITE_LEN_12_3", sequence="AGCTTACTATCC", id=8), - tag(name="unmapped", sequence="UNKNOWN", id=9), ] pytest.barcode_slice = slice(0, 16) pytest.umi_slice = slice(16, 26) @@ -70,15 +70,21 @@ def test_csv_parser(data): with pytest.raises(SystemExit): failing_files = glob.glob(pytest.failing_csv) for file_path in failing_files: - print(file_path) preprocessing.parse_tags_csv(file_path) @pytest.mark.dependency() def test_parse_reference_list_csv(data): - assert preprocessing.parse_reference_list_csv( - pytest.correct_reference_list_path, 16 - ).keys() in (pytest.correct_reference_list, 1,) + passing_files = glob.glob(pytest.passing_reference_list_csv) + for file_path in passing_files: + assert preprocessing.parse_reference_list_csv(file_path, 16).keys() in ( + pytest.correct_reference_list, + 1, + ) + with pytest.raises(SystemExit): + failing_files = glob.glob(pytest.failing_reference_list_csv) + for file_path in failing_files: + preprocessing.parse_reference_list_csv(file_path, 16) @pytest.mark.dependency() diff --git a/tests/test_processing.py b/tests/test_processing.py index 3c97288..459f6bf 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -33,7 +33,7 @@ def modify(seq, n, modification_type): bases = list("ATGCN") positions = list(range(len(seq))) seq = list(seq) - for i in range(n): + for _ in range(n): if modification_type == "mutate": position = random.choice(positions) positions.remove(position) From 45185c7bad6abb4eced7b6ec63707065ee026c94 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Mon, 28 Dec 2020 12:06:32 +0100 Subject: [PATCH 31/77] some reformatting --- CHANGELOG.md | 2 + cite_seq_count/__main__.py | 85 +++++---------------------------- cite_seq_count/argsparser.py | 17 ++----- cite_seq_count/chemistry.py | 3 -- cite_seq_count/preprocessing.py | 54 +++++++++++++++++++++ cite_seq_count/processing.py | 46 ++++++++++++++++++ docs/docs/Guidelines.md | 10 +++- docs/docs/Installation.md | 2 +- docs/docs/Reading-the-output.md | 15 ++++-- docs/docs/Running-the-script.md | 33 ++++++++----- docs/docs/index.md | 2 +- 11 files changed, 163 insertions(+), 106 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8434d65..e1e0ae8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,6 +30,8 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Updated tags file parsing to make it more reliable. - Added new tests to help out contributions. - If no clustered cells found, the dense output matrix will not be written. + - Barcode whitelists are now called reference lists. + - The reference list file now requires a header `reference`. There is now an optional column called `translation`. This is specific to chemistries such as 10xV3 that use different barcodes for mRNA and Antibody tag capture sequences. See more details in the documentation. ### Removed - Unmmapped reads are not umi corrected anymore reducing run time and memory usage. diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 6daaefd..c3b3c05 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -50,53 +50,20 @@ def main(): # Load TAGs/ABs. ab_map = preprocessing.parse_tags_csv(args.tags) ordered_tags, longest_tag_len = preprocessing.check_tags(ab_map, args.max_error) - # ordered_tags = preprocessing.convert_to_named_tuple(ordered_tags=ordered_tags) + # Identify input file(s) read1_paths, read2_paths = preprocessing.get_read_paths( args.read1_path, args.read2_path ) + # Checks before chunking. + (n_reads, R2_min_length, maximum_distance) = preprocessing.pre_run_checks( + read1_paths=read1_paths, + chemistry_def=chemistry_def, + longest_tag_len=longest_tag_len, + args=args, + ) - # preprocessing and processing occur in separate loops so the program can crash earlier if - # one of the inputs is not valid. - read1_lengths = [] - read2_lengths = [] - total_reads = 0 - - for read1_path in read1_paths: - n_lines = preprocessing.get_n_lines(read1_path) - total_reads += n_lines / 4 - # Get reads length. So far, there is no validation for Read2. - read1_lengths.append(preprocessing.get_read_length(read1_path)) - # read2_lengths.append(preprocessing.get_read_length(read2_path)) - # Check Read1 length against CELL and UMI barcodes length. - preprocessing.check_barcodes_lengths( - read1_lengths[-1], - chemistry_def.cell_barcode_start, - chemistry_def.cell_barcode_end, - chemistry_def.umi_barcode_start, - chemistry_def.umi_barcode_end, - ) - - # Get all reads or only top N? - if args.first_n < float("inf"): - n_reads = args.first_n - else: - n_reads = total_reads - - # Define R2_lenght to reduce amount of data to transfer to childrens - number_of_samples = len(read1_paths) - - # Print a statement if multiple files are run. - if number_of_samples != 1: - print("Detected {} pairs of files to run on.".format(number_of_samples)) - - if args.sliding_window: - R2_min_length = read2_lengths[0] - maximum_distance = 0 - else: - R2_min_length = longest_tag_len - maximum_distance = args.max_error - + # Chunk the data to disk before mapping ( input_queue, temp_files, @@ -113,39 +80,13 @@ def main(): ordered_tags=ordered_tags, maximum_distance=maximum_distance, ) - # Initialize the counts dicts that will be generated from each input fastq pair - final_results = defaultdict(lambda: defaultdict(Counter)) - umis_per_cell = Counter() - reads_per_cell = Counter() - merged_no_match = Counter() - - print("Started mapping") - parallel_results = [] - pool = Pool(processes=args.n_threads) - errors = [] - mapping = pool.map_async( - processing.map_reads, - input_queue, - callback=parallel_results.append, - error_callback=errors.append, - ) - mapping.wait() - - pool.close() - pool.join() - if len(errors) != 0: - for error in errors: - print(error) - - print("Merging results") + # Map the data ( final_results, umis_per_cell, reads_per_cell, merged_no_match, - ) = processing.merge_results(parallel_results=parallel_results[0]) - - del parallel_results + ) = processing.map_data(input_queue=input_queue, args=args) # Check if 99% of the reads are unmapped. processing.check_unmapped( @@ -154,8 +95,8 @@ def main(): total_reads=total_reads, start_trim=chemistry_def.R2_trim_start, ) - # Delete temp_files - # exit() + + # Remove temp chunks for file_path in temp_files: os.remove(file_path) diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 6dcd12b..9e79faa 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -157,33 +157,22 @@ def get_args(): ) if "--chemistry" not in sys.argv: cells.add_argument( - "-wl", + "-rl", "--reference_list", dest="reference_list", required=False, type=str, help=( - "A csv file containning a reference_list of barcodes produced" + "A csv file containning a reference list of barcodes produced" " by the mRNA data.\n\n" "\tExample:\n" + "reference\n" "\tATGCTAGTGCTA\n\tGCTAGTCAGGAT\n\tCGACTGCTAACG\n\n" "Or 10X-style:\n" "\tATGCTAGTGCTA-1\n\tGCTAGTCAGGAT-1\n\tCGACTGCTAACG-1\n" ), ) - cells.add_argument( - "--translation", - required=False, - type=str, - help="A csv file containing the mapping between two sets of cell barcode list.\n" - "A required header such as the reference is named reference_list. Example:\n\n" - "\treference_list,translation\n" - "\tAAACCCAAGAAACACT,AAACCCATCAAACACT\n" - "\tAAACCCAAGAAACCAT,AAACCCATCAAACCAT\n" - "\nThe output matrix will possess both cell barcode IDs", - ) - # FILTERS group. filters = parser.add_argument_group( "TAG filters", description=("Filtering and trimming for read2.") diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index d3c051a..ada37ec 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -30,7 +30,6 @@ class Chemistry: umi_barcode_end: int R2_trim_start: int reference_list_path: str - mapping_required: bool DEFINITIONS_DB = pooch.create( @@ -119,7 +118,6 @@ def get_chemistry_definition(chemistry_short_name): ]["stop"], R2_trim_start=chemistry_defs["sequence_structure_indexes"]["R2"]["start"] - 1, reference_list_path=path, - mapping_required=chemistry_defs["reference_list"]["mapping"], ) return chemistry_def @@ -133,7 +131,6 @@ def create_chemistry_definition(args): umi_barcode_end=args.umi_last, R2_trim_start=args.start_trim, reference_list_path=args.reference_list, - mapping_required=args.translation, ) return chemistry_def diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 5c097ff..e213b47 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -314,3 +314,57 @@ def get_read_paths(read1_path, read2_path): ) return (_read1_path, _read2_path) + +def pre_run_checks(read1_paths, chemistry_def, longest_tag_len, args): + """ Checks that the chemistry is properly set and defines how many reads to process + + Args: + read1_paths (list): List of paths + chemistry_def (Chemistry): Chemistry definition + longest_tag_len (int): Longest tag sequence + args (argparse): List of arguments + + Returns: + n_reads (int): Number of reads to run on + R2_min_length (int): Min R2 length to check if reads are too short + maximum_distance (int): Maximum error rate allowed for mapping tags + + """ + read1_lengths = [] + read2_lengths = [] + total_reads = 0 + + for read1_path in read1_paths: + n_lines = get_n_lines(read1_path) + total_reads += n_lines / 4 + # Get reads length. So far, there is no validation for Read2. + read1_lengths.append(get_read_length(read1_path)) + + # Check Read1 length against CELL and UMI barcodes length. + check_barcodes_lengths( + read1_lengths[-1], + chemistry_def.cell_barcode_start, + chemistry_def.cell_barcode_end, + chemistry_def.umi_barcode_start, + chemistry_def.umi_barcode_end, + ) + + # Get all reads or only top N? + if args.first_n < float("inf"): + n_reads = args.first_n + else: + n_reads = total_reads + + number_of_samples = len(read1_paths) + + # Print a statement if multiple files are run. + if number_of_samples != 1: + print("Detected {} pairs of files to run on.".format(number_of_samples)) + + if args.sliding_window: + R2_min_length = read2_lengths[0] + maximum_distance = 0 + else: + R2_min_length = longest_tag_len + maximum_distance = args.max_error + return n_reads, R2_min_length, maximum_distance diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 1d17b32..aa461d4 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -478,3 +478,49 @@ def generate_sparse_matrices( final_results[cell_barcode][TAG_id].values() ) return results_matrix + + +def map_data(input_queue, args): + """ + Maps the data given an input_queue + + Args: + input_queue (list): List of parameters to run in parallel + args (argparse): List of arguments + + Returns: + final_results (dict): final dictionnary with results + umis_per_cell (Counter): Counter of UMIs per cell + reads_per_cell (Counter): Counter of reads per cell + merged_no_match (Counter): Counter of unmapped reads + """ + # Initialize the counts dicts that will be generated from each input fastq pair + final_results = defaultdict(lambda: defaultdict(Counter)) + umis_per_cell = Counter() + reads_per_cell = Counter() + merged_no_match = Counter() + + print("Started mapping") + parallel_results = [] + pool = Pool(processes=args.n_threads) + errors = [] + mapping = pool.map_async( + map_reads, + input_queue, + callback=parallel_results.append, + error_callback=errors.append, + ) + mapping.wait() + + pool.close() + pool.join() + if len(errors) != 0: + for error in errors: + print(error) + + print("Merging results") + (final_results, umis_per_cell, reads_per_cell, merged_no_match,) = merge_results( + parallel_results=parallel_results[0] + ) + + return final_results, umis_per_cell, reads_per_cell, merged_no_match diff --git a/docs/docs/Guidelines.md b/docs/docs/Guidelines.md index f400b83..b3ebd45 100644 --- a/docs/docs/Guidelines.md +++ b/docs/docs/Guidelines.md @@ -8,5 +8,13 @@ Guidelines for typical chemistries 10x genomics V3 chemistry for feature barcoding is using a mapping between RNA cell barcodes and Protein cell barcodes. You can find this mapping [here](https://github.com/10XGenomics/cellranger/blob/master/lib/python/cellranger/barcodes/translation/3M-february-2018.txt.gz) -Since the list is composed of ~7M cells instead of the ~3M described in the technologie, using the whitelist of V3 as an input is unwise. +## POST 1.5.0 instructions + +Since version 1.5.0, this is taken care of by CSC if you provide the translation column in the reference `--reference_file` file as described in the documentation. +* The dense output will have the translated barcodes in the header. +* the MTX output will have two columns. The firts column is the original barcode found in the fastqs provided and the second column will be the translated barcode given by the reference list csv. + +## PRE 1.5.0 instructions + +Since the list is composed of ~7M cells instead of the ~3M described in the technologie, using the reference_list of V3 as an input is unwise. I suggest running CITE-seq-Count with using the `-cells` argument instead. \ No newline at end of file diff --git a/docs/docs/Installation.md b/docs/docs/Installation.md index 714ee30..883b860 100644 --- a/docs/docs/Installation.md +++ b/docs/docs/Installation.md @@ -4,7 +4,7 @@ Installation with pip `CITE-seq-Count` is stored on pypi. You can install it using the following command: ``` -pip install CITE-seq-Count==1.4.3 +pip install CITE-seq-Count==1.5.0 ``` diff --git a/docs/docs/Reading-the-output.md b/docs/docs/Reading-the-output.md index 0d8bad2..bc6d1d2 100644 --- a/docs/docs/Reading-the-output.md +++ b/docs/docs/Reading-the-output.md @@ -25,7 +25,7 @@ File descriptions ------------------- * `features.tsv.gz` contains the feature names, in this context our tags. -* `barcodes.tsv.gz` contains the cell barcodes. +* `barcodes.tsv.gz` contains the cell barcodes. If running with a translation, first column is the translated barcode, second column is the original barcode found in the data. * `matrix.mtx.gz` contains the actual values. read_count and umi_count contain respectively the read counts and the collapsed umi counts. For analysis you should use the umi data. The read_count can be used to check if you have an overamplification or oversequencing issue with your protocol. * `unmapped.csv` contains the top N tags that haven't been mapped. @@ -39,6 +39,9 @@ CITE-seq-Count Version: 1.4.3 Reads processed: 1000000 Percentage mapped: 33 Percentage unmapped: 67 +Percentage too short: 0 + R1_too_short: 0 + R2_too_short: 0 Uncorrected cells: 0 Correction: Cell barcodes collapsing threshold: 1 @@ -65,10 +68,14 @@ Packages to read MTX I recommend using `Seurat` and their `Read10x` function to read the results. + With Seurat V3: `Read10x('OUTFOLDER/umi_count/', gene.column=1)` +Version 1.5.0 of CSC came with some breaking changes. Older versions would use this command: + +`Read10x('OUTFOLDER/umi_count/', gene.column=1)` With Matrix: @@ -82,7 +89,8 @@ mat <- readMM(file = matrix.path) feature.names = read.delim(features.path, header = FALSE, stringsAsFactors = FALSE) barcode.names = read.delim(barcode.path, header = FALSE, stringsAsFactors = FALSE) colnames(mat) = barcode.names$V1 -rownames(mat) = feature.names$V1 +rownames(mat) = feature.names$V2 +# rownames(mat) = feature.names$V1 if you are using an older version than 1.5.0 ``` **Python** @@ -101,5 +109,6 @@ data = data.T features = pd.read_csv(os.path.join(path, 'umi_count/features.tsv.gz'), header=None) barcodes = pd.read_csv(os.path.join(path, 'umi_count/barcodes.tsv.gz'), header=None) data.var_names = features[0] -data.obs_names = barcodes[0] +data.obs_names = barcodes[1] +#data.obs_names = barcodes[0] if you are using an older version than 1.5.0 ``` diff --git a/docs/docs/Running-the-script.md b/docs/docs/Running-the-script.md index 032e7de..71ae13e 100644 --- a/docs/docs/Running-the-script.md +++ b/docs/docs/Running-the-script.md @@ -22,18 +22,20 @@ You can find a description of each option bellow. ### INPUT -* [Required] Read1 fastq file location in fastq.gz format. Read 1 typically contains Cell barcode and UMI. +* [Required] Read1 fastq file location in fastq.gz format. Read 1 typically contains Cell barcode and UMI. You can provide multiple lanes of the same run separated by a `,`. `-R1 READ1_PATH.fastq.gz, --read1 READ1_PATH.fastq.gz` +`-R1 READ1_PATH_L001.fastq.gz,READ1_PATH_L002.fastq.gz --read1 READ1_PATH.fastq.gz,READ1_PATH_L002.fastq.gz` -* [Required] Read2 fastq file location in fastq.gz. Read 2 typically contains the Antibody barcode. +* [Required] Read2 fastq file location in fastq.gz. Read 2 typically contains the Antibody barcode. You can provide multiple lanes of the same run separated by a `,`. `-R2 READ2_PATH.fastq.gz, --read2 READ2_PATH.fastq.gz` +`-R1 READ2_PATH_L001.fastq.gz,READ2_PATH_L002.fastq.gz --read1 READ2_PATH.fastq.gz,READ2_PATH_L002.fastq.gz` * [Required] The path to the csv file containing the antibody barcodes as well as their respective names. -You can run tags of different length together. +You can run tags of different length together. The headers `sequence` and `feature_name` are required in no particular order. `-t tags.csv, --tags tags.csv` @@ -61,7 +63,7 @@ GCTGTCAGCATAC C AAAAAAAAAA GCTGTCAGCATAC G AAAAAAAAAA ``` The tags.csv should only contain the part before the `T` -``` +``` GCTAGTCGTACGA,tag1 GCTGTCAGCATAC,tag2 ``` @@ -113,6 +115,7 @@ Barcodes from 1 to 16 and UMI from 17 to 26, then this is the input you need: `-cbf 1 -cbl 16 -umif 17 -umil 26` +If you have doubts about those parameters, you can check [this great ressource](https://teichlab.github.io/scg_lib_structs/) for help. * [Optional] How many errors are allowed between two cell barcodes to collapse them onto one cell. @@ -132,29 +135,37 @@ Barcodes from 1 to 16 and UMI from 17 to 26, then this is the input you need: You have to choose either the number of cells you expect or give it a list of cell barcodes to retrieve. * [Required] How many cells you expect in your run. -* [Optional] If a whitelist is provided. +* [Optional] If a reference_list is provided. `-cells EXPECTED_CELLS, --expected_cells EXPECTED_CELLS` -* [Optional] Whitelist of cell barcodes provided as a csv file. CITE-seq-Count will search for those barcodes in the data and correct other barcodes based on this list. Will force the output to provide all the barcodes from the whitelist. Please see the [guidelines](Guidelines.md) for information regarding specific chemistries. +* [Optional] reference list of cell barcodes provided as a csv file. CITE-seq-Count will search for those barcodes in the data and correct other barcodes based on this list. Please see the [guidelines](Guidelines.md) for information regarding specific chemistries. -`-wl WHITELIST, --whitelist WHITELIST` +`-rl reference_list, --reference_list reference_list` -Example: +Example simple reference: ``` +reference ATGCTAGTGCTA GCTAGTCAGGAT CGACTGCTAACG ``` +Example translated reference: +``` +reference,translation +ATGCTAGTGCTA,GCTGACTGATGC +GCTAGTCAGGAT,GCTGACTTATCG +CGACTGCTAACG,GGCTTAGCATAG +``` ### FILTERING Filtering for structure of the antibody barcode as well as maximum errors. -* [OPTIONAL] Maximum Levenshtein distance allowed. This allows to catch antibody barcodes that might have `--max-error` errors compared to the real barcodes. (was `-hd` in previous versions) +* [OPTIONAL] Maximum hamming distance allowed. This allows to catch antibody barcodes that might have `--max-error` errors compared to the real barcodes. (was `-hd` in previous versions) `--max-error MAX_ERROR`, default `2` @@ -174,7 +185,7 @@ There is a sanity check when for the `MAX_ERROR` value chosen to be sure you are `-trim N_BASES, --start-trim N_BASES`, default `0` -* [OPTIONAL] Activate sliding window alignement. Use this when you have a protocol that has a variable sequence before the inserted TAG. +* [OPTIONAL] Activate sliding window alignement. Use this when you have a protocol that has a variable sequence before the inserted TAG. This disables error correction on the TAGS. Only exact matches will be outputed. `--sliding-window`, default `False` @@ -204,7 +215,7 @@ TTCAATTTC ATGCTAGCTAAAAAAAAAAAA ### OPTIONAL -* [Optional] Select first N reads to run on instead of all. This is usefull when trying to test out your parameters before running the whole dataset. +* [Optional] Select first N reads to run on instead of all. This is usefull when trying to test out your parameters before running the whole dataset. `-n FIRST_N, --first_n FIRST_N` diff --git a/docs/docs/index.md b/docs/docs/index.md index 2376046..66e941a 100644 --- a/docs/docs/index.md +++ b/docs/docs/index.md @@ -1,7 +1,7 @@ # Welcome to CITE-seq-Count's documentation -IMPORTANT NEWS +IMPORTANT ------------------------ For users who have processed data using CITE-seq-Count version `1.3.4` please rerun any data using version 1.4.0 or higher. From 197f4adf254086a0398b0236971c010688ed6365 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Mon, 28 Dec 2020 13:42:24 +0100 Subject: [PATCH 32/77] fixed when failing to find knee estimate --- CHANGELOG.md | 1 + cite_seq_count/__main__.py | 129 +++++++------------------------- cite_seq_count/io.py | 9 ++- cite_seq_count/processing.py | 129 +++++++++++++++++++++++++++++--- docs/docs/Running-the-script.md | 7 +- setup.py | 2 +- 6 files changed, 157 insertions(+), 120 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index e1e0ae8..f1d2784 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -32,6 +32,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - If no clustered cells found, the dense output matrix will not be written. - Barcode whitelists are now called reference lists. - The reference list file now requires a header `reference`. There is now an optional column called `translation`. This is specific to chemistries such as 10xV3 that use different barcodes for mRNA and Antibody tag capture sequences. See more details in the documentation. + - Bumped UMI_tools to 1.1.1 ### Removed - Unmmapped reads are not umi corrected anymore reducing run time and memory usage. diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index c3b3c05..f7528eb 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -102,41 +102,18 @@ def main(): # Correct cell barcodes if args.bc_threshold != 0: - if len(umis_per_cell) <= args.expected_cells: - print( - "Number of expected cells, {}, is higher " - "than number of cells found {}.\nNot performing " - "cell barcode correction" - "".format(args.expected_cells, len(umis_per_cell)) - ) - bcs_corrected = 0 - else: - print("Correcting cell barcodes") - if not reference_dict: - ( - final_results, - umis_per_cell, - bcs_corrected, - ) = processing.correct_cells( - final_results=final_results, - reads_per_cell=reads_per_cell, - umis_per_cell=umis_per_cell, - expected_cells=args.expected_cells, - collapsing_threshold=args.bc_threshold, - ab_map=ordered_tags, - ) - else: - ( - final_results, - umis_per_cell, - bcs_corrected, - ) = processing.correct_cells_reference_list( - final_results=final_results, - umis_per_cell=umis_per_cell, - reference_list=set(reference_dict.keys()), - collapsing_threshold=args.bc_threshold, - ab_map=ordered_tags, - ) + ( + final_results, + umis_per_cell, + bcs_corrected, + ) = processing.run_cell_barcode_correction( + final_results=final_results, + umis_per_cell=umis_per_cell, + reads_per_cell=reads_per_cell, + reference_dict=reference_dict, + ordered_tags=ordered_tags, + args=args, + ) else: print("Skipping cell barcode correction") bcs_corrected = 0 @@ -180,68 +157,19 @@ def main(): ) # UMI correction + if args.umi_threshold != 0: # Correct UMIS - input_queue = [] - - umi_correction_input = namedtuple( - "umi_correction_input", ["cells", "collapsing_threshold", "max_umis"] - ) - cells = {} - n_cells = 0 - num_chunks = 0 - print("preparing UMI correction jobs") - cell_batch_size = round(len(filtered_cells) / args.n_threads) + 1 - for cell in filtered_cells: - cells[cell] = final_results[cell] - n_cells += 1 - if n_cells % cell_batch_size == 0: - input_queue.append( - umi_correction_input( - cells=cells, - collapsing_threshold=args.umi_threshold, - max_umis=20000, - ) - ) - cells = {} - num_chunks += 1 - input_queue.append( - umi_correction_input( - cells=cells, collapsing_threshold=args.umi_threshold, max_umis=20000 - ) - ) - - pool = Pool(processes=args.n_threads) - errors = [] - parallel_results = [] - correct_umis = pool.map_async( - processing.correct_umis, - input_queue, - callback=parallel_results.append, - error_callback=errors.append, + (final_results, umis_corrected, aberrant_cells) = processing.run_umi_correction( + final_results=final_results, + filtered_cells=filtered_cells, + unmapped_id=len(ordered_tags), + args=args, ) - - correct_umis.wait() - pool.close() - pool.join() - - if len(errors) != 0: - for error in errors: - print(error) - - final_results = {} - umis_corrected = 0 - aberrant_cells = set() - - for chunk in parallel_results[0]: - (temp_results, temp_umis, temp_aberrant_cells) = chunk - final_results.update(temp_results) - umis_corrected += temp_umis - aberrant_cells.update(temp_aberrant_cells) else: # Don't correct umis_corrected = 0 - aberrant_cells = set() + aberrant_cells = [] if len(aberrant_cells) > 0: # Remove aberrant cells from the top cells @@ -254,16 +182,15 @@ def main(): ordered_tags=ordered_tags, filtered_cells=aberrant_cells, ) - if len(umi_aberrant_matrix) > 0: - # Write uncorrected cells to dense output - io.write_dense( - sparse_matrix=umi_aberrant_matrix, - ordered_tags=ordered_tags, - columns=aberrant_cells, - outfolder=os.path.join(args.outfolder, "uncorrected_cells"), - filename="dense_umis.tsv", - ) - # delete the last element (unmapped) + # Write uncorrected cells to dense output + io.write_dense( + sparse_matrix=umi_aberrant_matrix, + ordered_tags=ordered_tags, + columns=aberrant_cells, + outfolder=os.path.join(args.outfolder, "uncorrected_cells"), + filename="dense_umis.tsv", + ) + # Generate the UMI count matrix umi_results_matrix = processing.generate_sparse_matrices( final_results=final_results, ordered_tags=ordered_tags, diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 435fe76..1b8060e 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -31,10 +31,11 @@ def write_to_files( io.mmwrite(os.path.join(prefix, "matrix.mtx"), sparse_matrix) with gzip.open(os.path.join(prefix, "barcodes.tsv.gz"), "wb") as barcode_file: for barcode in filtered_cells: - if reference_dict[barcode] != 0: - barcode_file.write( - "{}\t{}\n".format(barcode, reference_dict[barcode]).encode(), - ) + if reference_dict: + if reference_dict[barcode] != 0: + barcode_file.write( + "{}\t{}\n".format(barcode, reference_dict[barcode]).encode(), + ) else: barcode_file.write("{}\n".format(barcode).encode()) with gzip.open(os.path.join(prefix, "features.tsv.gz"), "wb") as feature_file: diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index aa461d4..a21340e 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -192,7 +192,7 @@ def merge_results(parallel_results): umis_per_cell[cell_barcode] += len(mapped[cell_barcode][TAG]) reads_per_cell[cell_barcode] += mapped[cell_barcode][TAG][UMI] merged_no_match.update(unmapped) - return (merged_results, umis_per_cell, reads_per_cell, merged_no_match) + return merged_results, umis_per_cell, reads_per_cell, merged_no_match def check_unmapped(no_match, too_short, total_reads, start_trim): @@ -205,7 +205,7 @@ def check_unmapped(no_match, too_short, total_reads, start_trim): ) -def correct_umis(umi_correction_input): +def correct_umis_in_cells(umi_correction_input): """ Corrects umi barcodes within same cell/tag groups. @@ -221,7 +221,7 @@ def correct_umis(umi_correction_input): aberrant_umi_count_cells (set): Set of uncorrected cells. """ - (final_results, collapsing_threshold, max_umis) = umi_correction_input + (final_results, collapsing_threshold, max_umis, unmapped_id) = umi_correction_input print( "Started umi correction in child process {} working on {} cells".format( os.getpid(), len(final_results) @@ -232,8 +232,8 @@ def correct_umis(umi_correction_input): cells = final_results.keys() for cell_barcode in cells: for TAG in final_results[cell_barcode]: - if TAG == "unmapped": - final_results[cell_barcode][TAG].pop() + if TAG == unmapped_id: + final_results[cell_barcode].pop(unmapped_id) n_umis = len(final_results[cell_barcode][TAG]) if n_umis > 1 and n_umis <= max_umis: @@ -319,7 +319,7 @@ def collapse_cells(true_to_false, umis_per_cell, final_results, ab_map): return (umis_per_cell, final_results, corrected_barcodes) -def correct_cells( +def correct_cells_no_reference_list( final_results, reads_per_cell, umis_per_cell, @@ -342,15 +342,19 @@ def correct_cells( umis_per_cell (Counter): Counter of umis per cell after cell barcode correction corrected_umis (int): How many umis have been corrected. """ - print("Looking for a reference_list") + print("Looking for a reference list") _, true_to_false = whitelist_methods.getCellWhitelist( + knee_method="density", cell_barcode_counts=reads_per_cell, expect_cells=expected_cells, cell_number=expected_cells, error_correct_threshold=collapsing_threshold, plotfile_prefix=False, ) - + if true_to_false is None: + print("Failed to find a good reference list. Will not correct cell barcodes") + corrected_barcodes = 0 + return (final_results, umis_per_cell, corrected_barcodes) (umis_per_cell, final_results, corrected_barcodes) = collapse_cells( true_to_false=true_to_false, umis_per_cell=umis_per_cell, @@ -409,7 +413,7 @@ def find_true_to_false_map( Args: barcode_tree (BKTree): BKTree of all original cell barcodes. cell_barcodes (List): Cell barcodes to go through. - reference_list (Set): Set of the reference_list, the "true" cell barcodes. + reference_list (dict): Dict of the reference_list, the "true" cell barcodes. collasping_threshold (int): How many mistakes to correct. Return: @@ -524,3 +528,110 @@ def map_data(input_queue, args): ) return final_results, umis_per_cell, reads_per_cell, merged_no_match + + +def run_umi_correction(final_results, filtered_cells, unmapped_id, args): + input_queue = [] + umi_correction_input = namedtuple( + "umi_correction_input", + ["cells", "collapsing_threshold", "max_umis", "unmapped_id"], + ) + cells = {} + n_cells = 0 + num_chunks = 0 + print("preparing UMI correction jobs") + cell_batch_size = round(len(filtered_cells) / args.n_threads) + 1 + for cell in filtered_cells: + cells[cell] = final_results[cell] + n_cells += 1 + if n_cells % cell_batch_size == 0: + input_queue.append( + umi_correction_input( + cells=cells, + collapsing_threshold=args.umi_threshold, + max_umis=20000, + unmapped_id=unmapped_id, + ) + ) + cells = {} + num_chunks += 1 + input_queue.append( + umi_correction_input( + cells=cells, + collapsing_threshold=args.umi_threshold, + max_umis=20000, + unmapped_id=unmapped_id, + ) + ) + + pool = Pool(processes=args.n_threads) + errors = [] + parallel_results = [] + correct_umis = pool.map_async( + correct_umis_in_cells, + input_queue, + callback=parallel_results.append, + error_callback=errors.append, + ) + + correct_umis.wait() + pool.close() + pool.join() + + if len(errors) != 0: + for error in errors: + print("There was an error {}", error) + + final_results = {} + umis_corrected = 0 + aberrant_cells = set() + for chunk in parallel_results[0]: + (temp_results, temp_umis, temp_aberrant_cells) = chunk + final_results.update(temp_results) + umis_corrected += temp_umis + aberrant_cells.update(temp_aberrant_cells) + + return final_results, umis_corrected, aberrant_cells + + +def run_cell_barcode_correction( + final_results, umis_per_cell, reads_per_cell, reference_dict, ordered_tags, args +): + if len(umis_per_cell) <= args.expected_cells: + print( + "Number of expected cells, {}, is higher " + "than number of cells found {}.\nNot performing " + "cell barcode correction" + "".format(args.expected_cells, len(umis_per_cell)) + ) + bcs_corrected = 0 + else: + print("Correcting cell barcodes") + if not reference_dict: + print("Reference list not given") + ( + final_results, + umis_per_cell, + bcs_corrected, + ) = correct_cells_no_reference_list( + final_results=final_results, + reads_per_cell=reads_per_cell, + umis_per_cell=umis_per_cell, + expected_cells=args.expected_cells, + collapsing_threshold=args.bc_threshold, + ab_map=ordered_tags, + ) + else: + print("Reference list given") + ( + final_results, + umis_per_cell, + bcs_corrected, + ) = correct_cells_reference_list( + final_results=final_results, + umis_per_cell=umis_per_cell, + reference_list=set(reference_dict.keys()), + collapsing_threshold=args.bc_threshold, + ab_map=ordered_tags, + ) + return final_results, umis_per_cell, bcs_corrected diff --git a/docs/docs/Running-the-script.md b/docs/docs/Running-the-script.md index 71ae13e..009e170 100644 --- a/docs/docs/Running-the-script.md +++ b/docs/docs/Running-the-script.md @@ -117,18 +117,15 @@ Barcodes from 1 to 16 and UMI from 17 to 26, then this is the input you need: If you have doubts about those parameters, you can check [this great ressource](https://teichlab.github.io/scg_lib_structs/) for help. -* [Optional] How many errors are allowed between two cell barcodes to collapse them onto one cell. +* [Optional] How many errors are allowed between two cell barcodes to collapse them onto one cell. If set to 0, deactivates correction. `--bc_collapsing_dist N_ERRORS`, default `1` -* [Optional] How many errors are allowed between two umi within the same cell and TAG to collapse. +* [Optional] How many errors are allowed between two umi within the same cell and TAG to collapse. If set to 0, deactivates correction. `--umi_collapsing_dist N_ERRORS`, default `2` -* [Optional] Deactivate UMI correction. - -`--no_umi_correction` ### Cells diff --git a/setup.py b/setup.py index 02bc519..6e16228 100644 --- a/setup.py +++ b/setup.py @@ -22,7 +22,7 @@ "python-levenshtein>=0.12.0", "scipy>=1.1.0", "multiprocess>=0.70.6.1", - "umi_tools==1.0.0", + "umi_tools==1.1.1", "pytest>=6.0.0", "pytest-dependency==0.4.0", "pandas>=0.23.4", From 34766857064396936ae7f3ad81db8156c9accf5c Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Mon, 28 Dec 2020 13:45:37 +0100 Subject: [PATCH 33/77] fixed test callings --- tests/test_processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/test_processing.py b/tests/test_processing.py index 459f6bf..756925c 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -175,7 +175,7 @@ def test_classify_reads_multi_process(data): @pytest.mark.dependency(depends=["test_classify_reads_multi_process"]) def test_correct_umis(data): - temp = processing.correct_umis((pytest.results, 2, pytest.max_umis)) + temp = processing.correct_umis_in_cells((pytest.results, 2, pytest.max_umis, 2)) results = temp[0] n_corrected = temp[1] for cell_barcode in results.keys(): @@ -191,7 +191,7 @@ def test_correct_umis(data): @pytest.mark.dependency(depends=["test_correct_umis"]) def test_correct_cells(data): - processing.correct_cells( + processing.correct_cells_no_reference_list( pytest.corrected_results, pytest.reads_per_cell, pytest.umis_per_cell, From 78fc6232b9b027ad6855c175d4ef7a28e62cad6b Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Tue, 29 Dec 2020 23:15:38 +0100 Subject: [PATCH 34/77] fixed wrong chunking --- CHANGELOG.md | 1 + cite_seq_count/argsparser.py | 14 ++++++++++---- cite_seq_count/io.py | 9 +++++---- setup.py | 2 +- 4 files changed, 17 insertions(+), 9 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index f1d2784..99e74fc 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -33,6 +33,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Barcode whitelists are now called reference lists. - The reference list file now requires a header `reference`. There is now an optional column called `translation`. This is specific to chemistries such as 10xV3 that use different barcodes for mRNA and Antibody tag capture sequences. See more details in the documentation. - Bumped UMI_tools to 1.1.1 + - Changed `-cells` paramter to `-n_cells` for more explicit argument. ### Removed - Unmmapped reads are not umi corrected anymore reducing run time and memory usage. diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 9e79faa..c0b482c 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -91,7 +91,13 @@ def get_args(): "\t-cbf 1 -cbl 16 -umif 17 -umil 26" ), ) - barcodes.add_argument("--chemistry", type=str, required=False, default=False) + barcodes.add_argument( + "--chemistry", + type=str, + required=False, + default=False, + help=("Option replacing cell/UMI barcodes indexes and reference list."), + ) if "--chemistry" not in sys.argv: barcodes.add_argument( "-cbf", @@ -99,7 +105,7 @@ def get_args(): dest="cb_first", required=True, type=int, - help=("Postion of the first base of your cell " "barcodes."), + help=("Postion of the first base of your cell barcodes."), ) barcodes.add_argument( "-cbl", @@ -107,7 +113,7 @@ def get_args(): dest="cb_last", required=True, type=int, - help=("Postion of the last base of your cell " "barcodes."), + help=("Postion of the last base of your cell barcodes."), ) barcodes.add_argument( "-umif", @@ -147,7 +153,7 @@ def get_args(): ) cells.add_argument( - "-cells", + "-n_cells", "--expected_cells", dest="expected_cells", required=True, diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 1b8060e..92b7a3a 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -233,6 +233,11 @@ def write_chunks_to_disk( chemistry_def.umi_barcode_start - 1, chemistry_def.umi_barcode_end ) + temp_filename = os.path.join(temp_path, "temp_{}".format(num_chunk)) + chunked_file_object = open(temp_filename, "w") + temp_files.append(os.path.abspath(temp_filename)) + reads_written = 0 + for read1_path, read2_path in zip(read1_paths, read2_paths): if enough_reads: break @@ -242,10 +247,6 @@ def write_chunks_to_disk( ) as textfile2: secondlines = islice(zip(textfile1, textfile2), 1, None, 4) - temp_filename = os.path.join(temp_path, "temp_{}".format(num_chunk)) - chunked_file_object = open(temp_filename, "w") - temp_files.append(os.path.abspath(temp_filename)) - reads_written = 0 for read1, read2 in secondlines: read1 = read1.strip() diff --git a/setup.py b/setup.py index 6e16228..fa65dec 100644 --- a/setup.py +++ b/setup.py @@ -29,5 +29,5 @@ "pybktree==1.1", "cython>=0.29.17", ], - python_requires=">=3.6", + python_requires=">=3.8", ) From f2ba84a4557dc0bc3045846444ed7ea0d26f76a1 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Wed, 30 Dec 2020 12:55:16 +0100 Subject: [PATCH 35/77] fixed unmapped not working and reduced cell barcode to correct for --- cite_seq_count/__main__.py | 4 +++- cite_seq_count/preprocessing.py | 6 ++++++ cite_seq_count/processing.py | 18 ++++++++---------- 3 files changed, 17 insertions(+), 11 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index f7528eb..3955fb8 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -86,7 +86,9 @@ def main(): umis_per_cell, reads_per_cell, merged_no_match, - ) = processing.map_data(input_queue=input_queue, args=args) + ) = processing.map_data( + input_queue=input_queue, unmapped_id=len(ordered_tags), args=args + ) # Check if 99% of the reads are unmapped. processing.check_unmapped( diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index e213b47..bab885e 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -28,6 +28,7 @@ def parse_reference_list_csv(filename, barcode_length): """ STRIP_CHARS = '"0123456789- \t\n' REQUIRED_HEADER = ["reference"] + has_translation = False # OPTIONAL_HEADER = ["translation"] cell_pattern = regex.compile(r"[ATGC]{{{}}}".format(barcode_length)) @@ -49,6 +50,7 @@ def parse_reference_list_csv(filename, barcode_length): reference_id = header.index("reference") reference_dict = {} if "translation" in header: + has_translation = True translation_id = header.index("translation") for row in csv_reader: @@ -73,6 +75,10 @@ def parse_reference_list_csv(filename, barcode_length): ) if len(reference_dict) == 0: sys.exit("reference_dict is empty.") + if has_translation: + print( + "Your reference list provides a translation name. This will be the default for count matrices." + ) return reference_dict diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index a21340e..a01d5d9 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -102,7 +102,7 @@ def map_reads(mapping_input): no_match = Counter() n = 1 t = time.time() - unmapped_id = len(tags) - 1 + unmapped_id = len(tags) # Progress info with open(filename, "r") as input_file: reads = csv.reader(input_file) @@ -160,7 +160,7 @@ def map_reads(mapping_input): return (results, no_match) -def merge_results(parallel_results): +def merge_results(parallel_results, unmapped_id): """Merge chunked results from parallel processing. Args: @@ -183,6 +183,8 @@ def merge_results(parallel_results): if cell_barcode not in merged_results: merged_results[cell_barcode] = defaultdict(Counter) for TAG in mapped[cell_barcode]: + if TAG == unmapped_id: + continue # Test the counter. Returns false if empty if mapped[cell_barcode][TAG]: for UMI in mapped[cell_barcode][TAG]: @@ -386,7 +388,7 @@ def correct_cells_reference_list( print("Generating barcode tree from reference list") # pylint: disable=no-member barcode_tree = pybktree.BKTree(Levenshtein.hamming, reference_list) - barcodes = set(final_results.keys()) + barcodes = set(umis_per_cell) print("Selecting reference candidates") print("Processing {:,} cell barcodes".format(len(barcodes))) @@ -433,15 +435,11 @@ def find_true_to_false_map( if len(candidates) == 1: white_cell_str = candidates[0] true_to_false[white_cell_str].append(cell_barcode) - elif len(candidates) == 0: + else: # the cell doesnt match to any reference_listed barcode, # hence we have to drop it # (as it cannot be asscociated with any frequent barcode) continue - else: - # more than on reference_listed candidate: - # we drop it as its not uniquely assignable - continue return true_to_false @@ -484,7 +482,7 @@ def generate_sparse_matrices( return results_matrix -def map_data(input_queue, args): +def map_data(input_queue, unmapped_id, args): """ Maps the data given an input_queue @@ -524,7 +522,7 @@ def map_data(input_queue, args): print("Merging results") (final_results, umis_per_cell, reads_per_cell, merged_no_match,) = merge_results( - parallel_results=parallel_results[0] + parallel_results=parallel_results[0], unmapped_id=unmapped_id ) return final_results, umis_per_cell, reads_per_cell, merged_no_match From cf72ec255476d075c84fd4e9f1cd15c10869c447 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Thu, 31 Dec 2020 11:08:59 +0100 Subject: [PATCH 36/77] first column in barcodes.tsv is now the translated barcode --- cite_seq_count/io.py | 2 +- cite_seq_count/processing.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 92b7a3a..e1f7e42 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -34,7 +34,7 @@ def write_to_files( if reference_dict: if reference_dict[barcode] != 0: barcode_file.write( - "{}\t{}\n".format(barcode, reference_dict[barcode]).encode(), + "{}\t{}\n".format(reference_dict[barcode], barcode).encode(), ) else: barcode_file.write("{}\n".format(barcode).encode()) diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index a01d5d9..28d8c20 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -183,6 +183,7 @@ def merge_results(parallel_results, unmapped_id): if cell_barcode not in merged_results: merged_results[cell_barcode] = defaultdict(Counter) for TAG in mapped[cell_barcode]: + # We don't want to capture unmapped data in the umi counts if TAG == unmapped_id: continue # Test the counter. Returns false if empty From 2cd77f867559160f111a2fa36516b77b11a04215 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Thu, 31 Dec 2020 12:53:58 +0100 Subject: [PATCH 37/77] documentation updates --- CHANGELOG.md | 10 ++++++---- cite_seq_count/__main__.py | 26 +++++++++++++++----------- cite_seq_count/processing.py | 16 ++++++++-------- docs/docs/Guidelines.md | 3 ++- 4 files changed, 31 insertions(+), 24 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 99e74fc..52f0fec 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,17 +9,17 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - `CITE-seq-Count` is now Compatible with trimmed data. There is a new `too_short` category in the `run_report.yaml` that will let you know how much you lost due to reads being too short. #123 - UMI correction is now also parallelized and will use the threads given. - - Added a check at the end of the mapping. If more than 99% of the reads are unmapped, CITE-seq-Count will exit. + - Added a check at the end of the mapping. If more than 99% of the reads are unmapped, CITE-seq-Count will exit. #62 - (BETA) New functionnality that will fetch the chemistry definition from a remote repo to simplify usage and reduce human errors. + - Added cython dependency based on issue #117 ### Changed - The `features.tsv` now has different columns for the tag name and the tag sequence. This keeps the relevant information in the output files as well as simplifies reading the mtx format when processing the data. - The mapping step has been changed. It will first write chunks of reads to files and then read in the chunks in each child process. This should solve the io bottleneck from before. - - There are new options now for parallel computing. `--chunk_size` Determines how many reads will be read per chunk. #99 + - There are new options now for parallel computing. `--chunk_size` Determines how many reads will be read per chunk. This should fix issues like #99. - `--sliding-window` now only checks for exact matches. - - Added cython dependency based on issue #117 - The main results dict now uses an `int` as keys reducing memory footprint. - Fixed the issue #92 with using `--bc_collapsing_dist 0`. - Fixed issue #122 and now properly checks number of files. @@ -31,9 +31,11 @@ and this project adheres to [Semantic Versioning](http://semver.org/). - Added new tests to help out contributions. - If no clustered cells found, the dense output matrix will not be written. - Barcode whitelists are now called reference lists. - - The reference list file now requires a header `reference`. There is now an optional column called `translation`. This is specific to chemistries such as 10xV3 that use different barcodes for mRNA and Antibody tag capture sequences. See more details in the documentation. + - The reference list file now requires a header `reference`. There is now an optional column called `translation`. This is specific to chemistries such as 10xV3 that use different barcodes for mRNA and Antibody tag capture sequences. See more details in the documentation. #139 and #141 - Bumped UMI_tools to 1.1.1 - Changed `-cells` paramter to `-n_cells` for more explicit argument. + - Cell barcode correction when reference list is provided is now discarding cells with only unmapped reads reducing run time. + - Aberrant cells wording replaced by clustered cells to be more specific. ### Removed - Unmmapped reads are not umi corrected anymore reducing run time and memory usage. diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 3955fb8..fd97844 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -162,7 +162,11 @@ def main(): if args.umi_threshold != 0: # Correct UMIS - (final_results, umis_corrected, aberrant_cells) = processing.run_umi_correction( + ( + final_results, + umis_corrected, + clustered_cells, + ) = processing.run_umi_correction( final_results=final_results, filtered_cells=filtered_cells, unmapped_id=len(ordered_tags), @@ -171,24 +175,24 @@ def main(): else: # Don't correct umis_corrected = 0 - aberrant_cells = [] + clustered_cells = [] - if len(aberrant_cells) > 0: - # Remove aberrant cells from the top cells - for cell_barcode in aberrant_cells: + if len(clustered_cells) > 0: + # Remove clustered cells from the top cells + for cell_barcode in clustered_cells: filtered_cells.remove(cell_barcode) - # Create sparse aberrant cells matrix - umi_aberrant_matrix = processing.generate_sparse_matrices( + # Create sparse clustered cells matrix + umi_clustered_matrix = processing.generate_sparse_matrices( final_results=final_results, ordered_tags=ordered_tags, - filtered_cells=aberrant_cells, + filtered_cells=clustered_cells, ) # Write uncorrected cells to dense output io.write_dense( - sparse_matrix=umi_aberrant_matrix, + sparse_matrix=umi_clustered_matrix, ordered_tags=ordered_tags, - columns=aberrant_cells, + columns=clustered_cells, outfolder=os.path.join(args.outfolder, "uncorrected_cells"), filename="dense_umis.tsv", ) @@ -229,7 +233,7 @@ def main(): ordered_tags=ordered_tags, umis_corrected=umis_corrected, bcs_corrected=bcs_corrected, - bad_cells=aberrant_cells, + bad_cells=clustered_cells, R1_too_short=R1_too_short, R2_too_short=R2_too_short, args=args, diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 28d8c20..c181dd1 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -221,7 +221,7 @@ def correct_umis_in_cells(umi_correction_input): Returns: final_results (dict): Same as input but with corrected umis. corrected_umis (int): How many umis have been corrected. - aberrant_umi_count_cells (set): Set of uncorrected cells. + clustered_umi_count_cells (set): Set of uncorrected cells. """ (final_results, collapsing_threshold, max_umis, unmapped_id) = umi_correction_input @@ -231,7 +231,7 @@ def correct_umis_in_cells(umi_correction_input): ) ) corrected_umis = 0 - aberrant_cells = set() + clustered_cells = set() cells = final_results.keys() for cell_barcode in cells: for TAG in final_results[cell_barcode]: @@ -250,9 +250,9 @@ def correct_umis_in_cells(umi_correction_input): final_results[cell_barcode][TAG] = new_res corrected_umis += temp_corrected_umis elif n_umis > max_umis: - aberrant_cells.add(cell_barcode) + clustered_cells.add(cell_barcode) print("Finished correcting umis in child {}".format(os.getpid())) - return (final_results, corrected_umis, aberrant_cells) + return (final_results, corrected_umis, clustered_cells) def update_umi_counts(UMIclusters, cell_tag_counts): @@ -583,14 +583,14 @@ def run_umi_correction(final_results, filtered_cells, unmapped_id, args): final_results = {} umis_corrected = 0 - aberrant_cells = set() + clustered_cells = set() for chunk in parallel_results[0]: - (temp_results, temp_umis, temp_aberrant_cells) = chunk + (temp_results, temp_umis, temp_clustered_cells) = chunk final_results.update(temp_results) umis_corrected += temp_umis - aberrant_cells.update(temp_aberrant_cells) + clustered_cells.update(temp_clustered_cells) - return final_results, umis_corrected, aberrant_cells + return final_results, umis_corrected, clustered_cells def run_cell_barcode_correction( diff --git a/docs/docs/Guidelines.md b/docs/docs/Guidelines.md index b3ebd45..7512682 100644 --- a/docs/docs/Guidelines.md +++ b/docs/docs/Guidelines.md @@ -12,7 +12,8 @@ You can find this mapping [here](https://github.com/10XGenomics/cellranger/blob/ Since version 1.5.0, this is taken care of by CSC if you provide the translation column in the reference `--reference_file` file as described in the documentation. * The dense output will have the translated barcodes in the header. -* the MTX output will have two columns. The firts column is the original barcode found in the fastqs provided and the second column will be the translated barcode given by the reference list csv. +* the MTX output will have two columns. The first column is the translated barcode given by the reference list csv and the second column is the original barcode found in the data. +* I recommend using the MTX format because it contains both cell barcodes. ## PRE 1.5.0 instructions From 8b647ed7b5fdfdfa29f22a465827f2dd85bd8526 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Fri, 1 Jan 2021 14:56:22 +0100 Subject: [PATCH 38/77] cleaned up imports and added single thread fix --- cite_seq_count/__main__.py | 9 +--- cite_seq_count/argsparser.py | 22 ++++++-- cite_seq_count/io.py | 23 +++++++- cite_seq_count/preprocessing.py | 3 -- cite_seq_count/processing.py | 95 ++++++++++++++++++--------------- 5 files changed, 92 insertions(+), 60 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index fd97844..4ea4721 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -5,15 +5,8 @@ import sys import os import logging -import gzip -import requests import time -from collections import OrderedDict, Counter, defaultdict, namedtuple - -# pylint: disable=no-name-in-module -from multiprocess import Pool, Queue, JoinableQueue, Process - from cite_seq_count import preprocessing from cite_seq_count import processing from cite_seq_count import chemistry @@ -121,7 +114,7 @@ def main(): bcs_corrected = 0 # If given, use reference_list for top cells - top_cells_tuple = umis_per_cell.most_common(args.expected_cells * 10) + top_cells_tuple = umis_per_cell.most_common(args.expected_cells) if reference_dict: # Add potential missing cell barcodes. # for missing_cell in reference_list: diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index c0b482c..314418f 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -19,14 +19,28 @@ def chunk_size_limit(arg): try: f = int(arg) except ValueError: - raise ArgumentTypeError("Must be an int") + raise ArgumentTypeError("Chunk size must be an int") if f < 1 or f > max_size: raise ArgumentTypeError( "Argument must be < " + str(max_size) + "and > " + str(1) ) else: - return False - return f + return f + + +def thread_default(): + """ + Set number of threads default. + + """ + max_cpu = cpu_count() + + if max_cpu > 4: + return 4 + elif max_cpu == 4: + return 3 + else: + return 1 def get_args(): @@ -223,7 +237,7 @@ def get_args(): required=False, type=int, dest="n_threads", - default=cpu_count(), + default=thread_default(), help="How many threads are to be used for running the program", ) parallel.add_argument( diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index e1f7e42..50c8c2c 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -10,7 +10,6 @@ import pandas as pd from scipy import io -import numpy as np from cite_seq_count import secondsToText @@ -223,6 +222,7 @@ def write_chunks_to_disk( temp_files = [] R1_too_short = 0 R2_too_short = 0 + total_reads = 0 total_reads_written = 0 enough_reads = False @@ -239,6 +239,7 @@ def write_chunks_to_disk( reads_written = 0 for read1_path, read2_path in zip(read1_paths, read2_paths): + if enough_reads: break print("Reading reads from files: {}, {}".format(read1_path, read2_path)) @@ -248,6 +249,7 @@ def write_chunks_to_disk( secondlines = islice(zip(textfile1, textfile2), 1, None, 4) for read1, read2 in secondlines: + total_reads += 1 read1 = read1.strip() if len(read1) < chemistry_def.umi_barcode_end: @@ -303,5 +305,22 @@ def write_chunks_to_disk( enough_reads = True chunked_file_object.close() break + if not enough_reads: + chunked_file_object.close() + input_queue.append( + mapping_input( + filename=temp_filename, + tags=ordered_tags, + debug=args.debug, + maximum_distance=maximum_distance, + sliding_window=args.sliding_window, + ) + ) + return ( + input_queue, + temp_files, + R1_too_short, + R2_too_short, + total_reads, + ) - return input_queue, temp_files, R1_too_short, R2_too_short, total_reads_written diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index bab885e..b2ab3cf 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -3,10 +3,7 @@ import sys import regex import Levenshtein -import requests -from math import floor -from collections import OrderedDict from collections import namedtuple from itertools import combinations from itertools import islice diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index c181dd1..1770957 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -1,9 +1,7 @@ import time -import gzip import sys import os import Levenshtein -import regex import pybktree import csv @@ -18,12 +16,10 @@ from numpy import int32 from scipy import sparse from umi_tools import network -from umi_tools import umi_methods import umi_tools.whitelist_methods as whitelist_methods from cite_seq_count import secondsToText -from cite_seq_count import preprocessing def find_best_match(TAG_seq, tags, maximum_distance): @@ -101,9 +97,10 @@ def map_reads(mapping_input): results = {} no_match = Counter() n = 1 - t = time.time() + unmapped_id = len(tags) # Progress info + t = time.time() with open(filename, "r") as input_file: reads = csv.reader(input_file) for read in reads: @@ -200,7 +197,8 @@ def merge_results(parallel_results, unmapped_id): def check_unmapped(no_match, too_short, total_reads, start_trim): """Check if the number of unmapped is higher than 99%""" - if (sum(no_match.values()) + too_short) / total_reads > float(0.99): + sum_unmapped = sum(no_match.values()) + too_short + if sum_unmapped / total_reads > float(0.99): sys.exit( """More than 99% of your data is unmapped.\nPlease check that your --start_trim {} parameter is correct and that your tags file is properly formatted""".format( start_trim @@ -505,23 +503,28 @@ def map_data(input_queue, unmapped_id, args): print("Started mapping") parallel_results = [] - pool = Pool(processes=args.n_threads) - errors = [] - mapping = pool.map_async( - map_reads, - input_queue, - callback=parallel_results.append, - error_callback=errors.append, - ) - mapping.wait() - pool.close() - pool.join() - if len(errors) != 0: - for error in errors: - print(error) + if args.n_threads == 1: + mapped_reads = map_reads(input_queue[0]) + parallel_results.append([mapped_reads]) + else: + pool = Pool(processes=args.n_threads) + errors = [] + mapping = pool.map_async( + map_reads, + input_queue, + callback=parallel_results.append, + error_callback=errors.append, + ) + mapping.wait() - print("Merging results") + pool.close() + pool.join() + if len(errors) != 0: + for error in errors: + print(error) + + print("Merging results") (final_results, umis_per_cell, reads_per_cell, merged_no_match,) = merge_results( parallel_results=parallel_results[0], unmapped_id=unmapped_id ) @@ -535,52 +538,58 @@ def run_umi_correction(final_results, filtered_cells, unmapped_id, args): "umi_correction_input", ["cells", "collapsing_threshold", "max_umis", "unmapped_id"], ) - cells = {} + cells_results = {} n_cells = 0 num_chunks = 0 + print("preparing UMI correction jobs") cell_batch_size = round(len(filtered_cells) / args.n_threads) + 1 for cell in filtered_cells: - cells[cell] = final_results[cell] + cells_results[cell] = final_results.pop(cell) n_cells += 1 if n_cells % cell_batch_size == 0: input_queue.append( umi_correction_input( - cells=cells, + cells=cells_results, collapsing_threshold=args.umi_threshold, max_umis=20000, unmapped_id=unmapped_id, ) ) - cells = {} + cells_results = {} num_chunks += 1 + + del final_results + input_queue.append( umi_correction_input( - cells=cells, + cells=cells_results, collapsing_threshold=args.umi_threshold, max_umis=20000, unmapped_id=unmapped_id, ) ) + if args.n_threads == 1: + pool = Pool(processes=args.n_threads) + errors = [] + parallel_results = [] + correct_umis = pool.map_async( + correct_umis_in_cells, + input_queue, + callback=parallel_results.append, + error_callback=errors.append, + ) - pool = Pool(processes=args.n_threads) - errors = [] - parallel_results = [] - correct_umis = pool.map_async( - correct_umis_in_cells, - input_queue, - callback=parallel_results.append, - error_callback=errors.append, - ) - - correct_umis.wait() - pool.close() - pool.join() - - if len(errors) != 0: - for error in errors: - print("There was an error {}", error) + correct_umis.wait() + pool.close() + pool.join() + if len(errors) != 0: + for error in errors: + print("There was an error {}", error) + else: + single_thread_result = correct_umis_in_cells(input_queue[0]) + parallel_results.append([single_thread_result]) final_results = {} umis_corrected = 0 clustered_cells = set() From b3a3e70908973d17c95a638e973a23ede0b29080 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Tue, 19 Jan 2021 14:10:18 +0100 Subject: [PATCH 39/77] single thread fix --- cite_seq_count/processing.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 1770957..36d89b6 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -569,10 +569,10 @@ def run_umi_correction(final_results, filtered_cells, unmapped_id, args): unmapped_id=unmapped_id, ) ) - if args.n_threads == 1: + parallel_results = [] + if args.n_threads != 1: pool = Pool(processes=args.n_threads) errors = [] - parallel_results = [] correct_umis = pool.map_async( correct_umis_in_cells, input_queue, From df45e4776089ca9643137caaa9540d2a7e80d144 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sat, 23 Jan 2021 14:08:02 +0100 Subject: [PATCH 40/77] fixed MTX without translation barcodes --- cite_seq_count/__main__.py | 1 + cite_seq_count/io.py | 2 ++ tests/test_io.py | 72 +++++++++++++++++++++++++++++++++++--- 3 files changed, 70 insertions(+), 5 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 4ea4721..23a5d8b 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -116,6 +116,7 @@ def main(): # If given, use reference_list for top cells top_cells_tuple = umis_per_cell.most_common(args.expected_cells) if reference_dict: + # Add potential missing cell barcodes. # for missing_cell in reference_list: # if missing_cell in final_results: diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 50c8c2c..611babb 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -35,6 +35,8 @@ def write_to_files( barcode_file.write( "{}\t{}\n".format(reference_dict[barcode], barcode).encode(), ) + else: + barcode_file.write("{}\n".format(barcode).encode()) else: barcode_file.write("{}\n".format(barcode).encode()) with gzip.open(os.path.join(prefix, "features.tsv.gz"), "wb") as feature_file: diff --git a/tests/test_io.py b/tests/test_io.py index 7245518..a19ad23 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,7 +1,20 @@ import pytest +import os +import gzip from cite_seq_count import io from collections import namedtuple +# copied from https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file +import hashlib + + +def md5(fname): + hash_md5 = hashlib.md5() + with gzip.open(fname, "rb") as f: + string = f.read() + hash_md5.update(string) + return hash_md5.hexdigest() + @pytest.fixture def data(): @@ -21,22 +34,71 @@ def data(): ] pytest.data_type = "umi" - pytest.outfolder = "tests/test_data/" -def test_write_to_files(data, tmpdir): +def test_write_to_files_wo_translation(data, tmpdir): import gzip import scipy reference_dict = {"ACTGTTTTATTGGCCT": 0, "TTCATAAGGTAGGGAT": 0} + output_path = os.path.join(tmpdir, "without_translation") + + mtx_path = os.path.join(output_path, "umi_count", "matrix.mtx.gz") + features_path = os.path.join(output_path, "umi_count", "features.tsv.gz") + barcodes_path = os.path.join(output_path, "umi_count", "barcodes.tsv.gz") + md5_sums = { + barcodes_path: "b7af6a32e83963606f181509a571966f", + features_path: "e889e780dbce481287c993dd043714c8", + mtx_path: "0312f3a2bfe57222ebe94051ba07786e", + } + + io.write_to_files( + pytest.sparse_matrix, + pytest.filtered_cells, + pytest.ordered_tags_map, + pytest.data_type, + output_path, + reference_dict=reference_dict, + ) + file_path = os.path.join(tmpdir, "without_translation", "umi_count/matrix.mtx.gz") + with gzip.open(file_path, "rb") as mtx_file: + assert isinstance(scipy.io.mmread(mtx_file), scipy.sparse.coo.coo_matrix) + assert md5_sums[barcodes_path] == md5(barcodes_path) + assert md5_sums[features_path] == md5(features_path) + assert md5_sums[mtx_path] == md5(mtx_path) + + +def test_write_to_files_with_translation(data, tmpdir): + import gzip + import scipy + + reference_dict = { + "ACTGTTTTATTGGCCT": "GGCTTCGATACTAGAT", + "TTCATAAGGTAGGGAT": "GATCGGATAGCTAATA", + } + output_path = os.path.join(tmpdir, "with_translation") + + mtx_path = os.path.join(output_path, "umi_count", "matrix.mtx.gz") + features_path = os.path.join(output_path, "umi_count", "features.tsv.gz") + barcodes_path = os.path.join(output_path, "umi_count", "barcodes.tsv.gz") + + md5_sums = { + barcodes_path: "fce83378b4dd548882fb9271bdd5b4f1", + features_path: "e889e780dbce481287c993dd043714c8", + mtx_path: "0312f3a2bfe57222ebe94051ba07786e", + } + io.write_to_files( pytest.sparse_matrix, pytest.filtered_cells, pytest.ordered_tags_map, pytest.data_type, - tmpdir, + output_path, reference_dict=reference_dict, ) - file = tmpdir.join("umi_count/matrix.mtx.gz") - with gzip.open(file, "rb") as mtx_file: + file_path = os.path.join(tmpdir, "with_translation", "umi_count/matrix.mtx.gz") + with gzip.open(file_path, "rb") as mtx_file: assert isinstance(scipy.io.mmread(mtx_file), scipy.sparse.coo.coo_matrix) + assert md5_sums[barcodes_path] == md5(barcodes_path) + assert md5_sums[features_path] == md5(features_path) + assert md5_sums[mtx_path] == md5(mtx_path) From 14e8c75bfb627b6ea766a1ff81a9d063cbbd6dc7 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 31 Jan 2021 10:47:30 +0100 Subject: [PATCH 41/77] formatting --- tests/test_io.py | 32 ++++++++++++++++++++++++++------ 1 file changed, 26 insertions(+), 6 deletions(-) diff --git a/tests/test_io.py b/tests/test_io.py index a19ad23..b306c07 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,6 +1,7 @@ import pytest import os import gzip +import scipy from cite_seq_count import io from collections import namedtuple @@ -37,9 +38,6 @@ def data(): def test_write_to_files_wo_translation(data, tmpdir): - import gzip - import scipy - reference_dict = {"ACTGTTTTATTGGCCT": 0, "TTCATAAGGTAGGGAT": 0} output_path = os.path.join(tmpdir, "without_translation") @@ -69,9 +67,6 @@ def test_write_to_files_wo_translation(data, tmpdir): def test_write_to_files_with_translation(data, tmpdir): - import gzip - import scipy - reference_dict = { "ACTGTTTTATTGGCCT": "GGCTTCGATACTAGAT", "TTCATAAGGTAGGGAT": "GATCGGATAGCTAATA", @@ -102,3 +97,28 @@ def test_write_to_files_with_translation(data, tmpdir): assert md5_sums[barcodes_path] == md5(barcodes_path) assert md5_sums[features_path] == md5(features_path) assert md5_sums[mtx_path] == md5(mtx_path) + + +def test_write_to_dense_wo_translation(data, tmpdir): + reference_dict = {"ACTGTTTTATTGGCCT": 0, "TTCATAAGGTAGGGAT": 0} + output_path = os.path.join(tmpdir, "without_translation") + csv_name = "dense_umis.tsv" + csv_path = os.path.join(output_path, csv_name) + + md5_sums = { + csv_path: "b7af6a32e83963606f181509a571966f", + } + + io.write_dense( + sparse_matrix=pytest.sparse_matrix, + ordered_tags=pytest.ordered_tags_map, + columns=pytest.filtered_cells, + outfolder=output_path, + filename=csv_name, + ) + file_path = os.path.join(tmpdir, "without_translation", csv_name) + assert False + + +def test_write_to_dense_with_translation(data, tmpdir): + assert False From 0f8a8abd3a5c974ae012af3c758a91029b66bc89 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 21 Mar 2021 15:34:33 +0100 Subject: [PATCH 42/77] Added tempfile naming for chunks --- cite_seq_count/argsparser.py | 4 ++-- cite_seq_count/io.py | 21 +++++++++++++-------- 2 files changed, 15 insertions(+), 10 deletions(-) diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 314418f..ab5fbda 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -1,6 +1,6 @@ import pkg_resources import sys - +import tempfile from argparse import ArgumentParser, ArgumentTypeError, RawTextHelpFormatter @@ -253,7 +253,7 @@ def get_args(): required=False, type=str, dest="temp_path", - default=".", + default=tempfile.gettempdir(), help="Temp folder for chunk creation specification. Useful when using a cluster with a scratch folder", ) diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 611babb..73125bd 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -3,6 +3,7 @@ import shutil import time import datetime +import tempfile from collections import namedtuple from itertools import islice @@ -235,9 +236,11 @@ def write_chunks_to_disk( chemistry_def.umi_barcode_start - 1, chemistry_def.umi_barcode_end ) - temp_filename = os.path.join(temp_path, "temp_{}".format(num_chunk)) - chunked_file_object = open(temp_filename, "w") - temp_files.append(os.path.abspath(temp_filename)) + chunked_file_object = tempfile.NamedTemporaryFile( + "w", dir=temp_path, suffix="_csc", delete=False + ) + # chunked_file_object = open(temp_file, "w") + temp_files.append(chunked_file_object.name) reads_written = 0 for read1_path, read2_path in zip(read1_paths, read2_paths): @@ -287,7 +290,7 @@ def write_chunks_to_disk( chunked_file_object.close() input_queue.append( mapping_input( - filename=temp_filename, + filename=chunked_file_object.name, tags=ordered_tags, debug=args.debug, maximum_distance=maximum_distance, @@ -299,9 +302,11 @@ def write_chunks_to_disk( chunked_file_object.close() break num_chunk += 1 - temp_filename = "temp_{}".format(num_chunk) - chunked_file_object = open(temp_filename, "w") - temp_files.append(os.path.abspath(temp_filename)) + chunked_file_object = tempfile.NamedTemporaryFile( + "w", dir=temp_path, suffix="_csc", delete=False + ) + # chunked_file_object = open(temp_file, "w") + temp_files.append(chunked_file_object.name) reads_written = 0 if total_reads_written == n_reads_per_chunk: enough_reads = True @@ -311,7 +316,7 @@ def write_chunks_to_disk( chunked_file_object.close() input_queue.append( mapping_input( - filename=temp_filename, + filename=chunked_file_object.name, tags=ordered_tags, debug=args.debug, maximum_distance=maximum_distance, From d753ea4c7a3569012cdd000d837750c0c88f4692 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Thu, 27 May 2021 00:01:42 +0200 Subject: [PATCH 43/77] fixed the io tests --- cite_seq_count/__main__.py | 32 ++---- cite_seq_count/argsparser.py | 55 ++++++---- cite_seq_count/chemistry.py | 6 +- cite_seq_count/io.py | 2 +- cite_seq_count/preprocessing.py | 103 +++++++++++++++--- cite_seq_count/processing.py | 102 ++++++++++++----- .../filtered_lists/fail/different_length.csv | 2 + .../filtered_lists/fail/with_header.csv | 3 + .../filtered_lists/fail/wrong_barcode.csv | 2 + .../filtered_lists/pass/normal_ref.csv | 2 + tests/test_io.py | 26 +++-- tests/test_preprocessing.py | 22 +++- 12 files changed, 249 insertions(+), 108 deletions(-) create mode 100644 tests/test_data/filtered_lists/fail/different_length.csv create mode 100644 tests/test_data/filtered_lists/fail/with_header.csv create mode 100644 tests/test_data/filtered_lists/fail/wrong_barcode.csv create mode 100644 tests/test_data/filtered_lists/pass/normal_ref.csv diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 23a5d8b..e232521 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -14,6 +14,8 @@ from cite_seq_count import secondsToText from cite_seq_count import argsparser +from collections import Counter + def main(): # Create logger and stream handler @@ -40,6 +42,8 @@ def main(): # Get chemistry defs (reference_dict, chemistry_def) = chemistry.setup_chemistry(args) + # Check if we have a filtered list provided + # Load TAGs/ABs. ab_map = preprocessing.parse_tags_csv(args.tags) ordered_tags, longest_tag_len = preprocessing.check_tags(ab_map, args.max_error) @@ -94,9 +98,9 @@ def main(): # Remove temp chunks for file_path in temp_files: os.remove(file_path) - + cell_barcode_correction = preprocessing.determine_cell_correction_mode(args) # Correct cell barcodes - if args.bc_threshold != 0: + if cell_barcode_correction != "no": ( final_results, umis_per_cell, @@ -107,35 +111,13 @@ def main(): reads_per_cell=reads_per_cell, reference_dict=reference_dict, ordered_tags=ordered_tags, + cell_barcode_correction=cell_barcode_correction, args=args, ) else: print("Skipping cell barcode correction") bcs_corrected = 0 - # If given, use reference_list for top cells - top_cells_tuple = umis_per_cell.most_common(args.expected_cells) - if reference_dict: - - # Add potential missing cell barcodes. - # for missing_cell in reference_list: - # if missing_cell in final_results: - # continue - # else: - # final_results[missing_cell] = dict() - # for TAG in ordered_tags: - # final_results[missing_cell][TAG.safe_name] = Counter() - # filtered_cells.add(missing_cell) - top_cells = [pair[0] for pair in top_cells_tuple] - filtered_cells = [] - for cell in top_cells: - # pylint: disable=no-member - if cell in reference_dict.keys(): - filtered_cells.append(cell) - else: - # Select top cells based on total umis per cell - filtered_cells = [pair[0] for pair in top_cells_tuple] - # Create sparse matrices for reads results read_results_matrix = processing.generate_sparse_matrices( final_results=final_results, diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index ab5fbda..65e6e27 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -52,7 +52,7 @@ def get_args(): prog="CITE-seq-Count", formatter_class=RawTextHelpFormatter, description=( - "This script counts matching antibody tags from paired fastq " + "This package counts matching antibody tags from paired fastq " "files. Version {}".format(get_package_version()) ), ) @@ -135,7 +135,7 @@ def get_args(): dest="umi_first", required=True, type=int, - help="Postion of the first base of your UMI.", + help=("Postion of the first base of your UMI."), ) barcodes.add_argument( "-umil", @@ -143,7 +143,7 @@ def get_args(): dest="umi_last", required=True, type=int, - help="Postion of the last base of your UMI.", + help=("Postion of the last base of your UMI."), ) barcodes.add_argument( "--umi_collapsing_dist", @@ -151,7 +151,7 @@ def get_args(): required=False, type=int, default=1, - help="threshold for umi collapsing.", + help=("threshold for umi collapsing."), ) barcodes.add_argument( "--bc_collapsing_dist", @@ -159,32 +159,39 @@ def get_args(): required=False, type=int, default=1, - help="threshold for cellular barcode collapsing.", - ) - # Cells group - cells = parser.add_argument_group( - "Cells", description=("Expected number of cells and potential reference_list") + help=("threshold for cellular barcode collapsing."), ) + # Cell filtering group. We ask for either number of expected cells or a pre-filtered list of cells. + + cells_filtering = parser.add_mutually_exclusive_group(required=True) - cells.add_argument( + cells_filtering.add_argument( "-n_cells", "--expected_cells", dest="expected_cells", - required=True, type=int, help=("Number of expected cells from your run."), default=0, ) + cells_filtering.add_argument( + "-fl", + "--filtered_cells", + dest="filtered_cells", + type=str, + help=("A specific list of cells to look for."), + default=False, + ) + if "--chemistry" not in sys.argv: - cells.add_argument( + barcodes.add_argument( "-rl", "--reference_list", dest="reference_list", required=False, type=str, + default=False, help=( - "A csv file containning a reference list of barcodes produced" - " by the mRNA data.\n\n" + "A csv file containning a reference list of all potential barcodes\n\n" "\tExample:\n" "reference\n" "\tATGCTAGTGCTA\n\tGCTAGTCAGGAT\n\tCGACTGCTAACG\n\n" @@ -238,7 +245,7 @@ def get_args(): type=int, dest="n_threads", default=thread_default(), - help="How many threads are to be used for running the program", + help=("How many threads are to be used for running the program"), ) parallel.add_argument( "-C", @@ -246,7 +253,7 @@ def get_args(): required=False, type=chunk_size_limit, dest="chunk_size", - help="How many reads should be sent to a child process at a time", + help=("How many reads should be sent to a child process at a time"), ) parallel.add_argument( "--temp_path", @@ -254,7 +261,9 @@ def get_args(): type=str, dest="temp_path", default=tempfile.gettempdir(), - help="Temp folder for chunk creation specification. Useful when using a cluster with a scratch folder", + help=( + "Temp folder for chunk creation specification. Useful when using a cluster with a scratch folder" + ), ) # Global group @@ -265,7 +274,7 @@ def get_args(): type=int, dest="first_n", default=float("inf"), - help="Select N reads to run on instead of all.", + help=("Select N reads to run on instead of all."), ) parser.add_argument( "-o", @@ -274,7 +283,7 @@ def get_args(): type=str, default="Results", dest="outfolder", - help="Results will be written to this folder", + help=("Results will be written to this folder"), ) parser.add_argument( "--dense", @@ -282,7 +291,7 @@ def get_args(): action="store_true", default=False, dest="dense", - help="Add a dense output to the results folder", + help=("Add a dense output to the results folder"), ) parser.add_argument( "-u", @@ -291,7 +300,7 @@ def get_args(): type=str, dest="unmapped_file", default="unmapped.csv", - help="Write table of unknown TAGs to file.", + help=("Write table of unknown TAGs to file."), ) parser.add_argument( "-ut", @@ -300,10 +309,10 @@ def get_args(): dest="unknowns_top", type=int, default=100, - help="Top n unmapped TAGs.", + help=("Top n unmapped TAGs."), ) parser.add_argument( - "--debug", action="store_true", help="Print extra information for debugging." + "--debug", action="store_true", help=("Print extra information for debugging.") ) parser.add_argument( "--version", diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index ada37ec..f91ca9e 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -138,19 +138,21 @@ def create_chemistry_definition(args): def setup_chemistry(args): if args.chemistry: chemistry_def = get_chemistry_definition(args.chemistry) - reference_dict = preprocessing.parse_reference_list_csv( + reference_dict = preprocessing.parse_cell_list_csv( filename=chemistry_def.reference_list_path, barcode_length=chemistry_def.cell_barcode_end - chemistry_def.cell_barcode_start + 1, + file_type="reference", ) else: chemistry_def = create_chemistry_definition(args) if args.reference_list: print("Loading reference_list") - reference_dict = preprocessing.parse_reference_list_csv( + reference_dict = preprocessing.parse_cell_list_csv( filename=args.reference_list, barcode_length=args.cb_last - args.cb_first + 1, + file_type="reference", ) else: reference_dict = False diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 73125bd..0c188cd 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -28,7 +28,7 @@ def write_to_files( """ prefix = os.path.join(outfolder, data_type + "_count") os.makedirs(prefix, exist_ok=True) - io.mmwrite(os.path.join(prefix, "matrix.mtx"), sparse_matrix) + io.mmwrite(os.path.join(prefix, "matrix.mtx"), a=sparse_matrix, field="integer") with gzip.open(os.path.join(prefix, "barcodes.tsv.gz"), "wb") as barcode_file: for barcode in filtered_cells: if reference_dict: diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index b2ab3cf..389ebff 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -8,8 +8,61 @@ from itertools import combinations from itertools import islice +from pandas import read_csv -def parse_reference_list_csv(filename, barcode_length): + +def get_csv_reader_from_path(filename): + """ + Returns a csv_reader object for a file weather it's a flat file or compressed. + + Args: + filename: str + + Returns: + csv_reader: The csv_reader for the file + """ + if filename.endswith(".gz"): + f = gzip.open(filename, mode="rt") + csv_reader = csv.reader(f) + else: + f = open(filename, encoding="UTF-8") + csv_reader = csv.reader(f) + return csv_reader + + +def parse_filtered_list_csv(filename, barcode_length): + """ + Reads in a one column, no header list of barcodes and returns a set. + + Args: + filename(str): file path + barcode_length(int): Barcode expected length + + Returns: + set: A set of barcodes + """ + STRIP_CHARS = '"0123456789- \t\n' + barcodes_pd = read_csv(filename) + + barcodes = set(barcodes_pd.iloc[:, 0]) + + out_set = set() + barcode_pattern = regex.compile(r"^[ATGC]{{{}}}".format(barcode_length)) + for barcode in barcodes: + check_barcode = barcode.strip(STRIP_CHARS) + if barcode_pattern.match(check_barcode): + out_set.add(check_barcode) + else: + sys.exit( + "This barcode {} is not only composed of ATGC bases.".format( + check_barcode + ) + ) + + return out_set + + +def parse_cell_list_csv(filename, barcode_length, file_type): """Reads white-listed barcodes from a CSV file. The function accepts plain barcodes or even 10X style barcodes with the @@ -24,19 +77,16 @@ def parse_reference_list_csv(filename, barcode_length): """ STRIP_CHARS = '"0123456789- \t\n' - REQUIRED_HEADER = ["reference"] - has_translation = False - # OPTIONAL_HEADER = ["translation"] + if file_type == "reference": + REQUIRED_HEADER = ["reference"] + elif file_type == "filtered": + REQUIRED_HEADER = ["filtered_list"] - cell_pattern = regex.compile(r"[ATGC]{{{}}}".format(barcode_length)) - - if filename.endswith(".gz"): - f = gzip.open(filename, mode="rt") - csv_reader = csv.reader(f) - else: - f = open(filename, encoding="UTF-8") - csv_reader = csv.reader(f) + has_translation = False + # OPTIONAL_HEADER = ["translation", "filtered_list"] + cell_pattern = regex.compile(r"^[ATGC]{{{}}}".format(barcode_length)) + csv_reader = get_csv_reader_from_path(filename=filename) header = next(csv_reader) set_dif = set(REQUIRED_HEADER) - set(header) if len(set_dif) != 0: @@ -44,9 +94,9 @@ def parse_reference_list_csv(filename, barcode_length): "The header is missing {}. Exiting".format(",".join(list(set_dif))) ) - reference_id = header.index("reference") + reference_id = header.index(REQUIRED_HEADER[0]) reference_dict = {} - if "translation" in header: + if "translation" in header and REQUIRED_HEADER[0] == "reference": has_translation = True translation_id = header.index("translation") @@ -63,6 +113,7 @@ def parse_reference_list_csv(filename, barcode_length): ref_barcode = row[reference_id].strip(STRIP_CHARS) if len(ref_barcode) == barcode_length: reference_dict[ref_barcode] = 0 + for cell_barcode in reference_dict.keys(): if not cell_pattern.match(cell_barcode): sys.exit( @@ -74,7 +125,7 @@ def parse_reference_list_csv(filename, barcode_length): sys.exit("reference_dict is empty.") if has_translation: print( - "Your reference list provides a translation name. This will be the default for count matrices." + "Your reference list provides a translation name. This will be the default for the count matrices." ) return reference_dict @@ -224,6 +275,28 @@ def get_read_length(filename): return read_length +def determine_cell_correction_mode(args, chemistry): + """ + Determines what mode to use for cell barcode correction. + Args: + args(argparse): All arguments + + Returns: + str: type of correction + """ + if args.bc_threshold != 0: + if args.filtered_list: + filtered_set = parse_filtered_list_csv( + args.filtered_list, + (chemistry.cell_barcode_stop - chemistry.cell_barcode_start), + ) + return filtered_set + else: + return False + else: + return False + + def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last): """Check Read1 length against CELL and UMI barcodes length. diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 36d89b6..8a6ffab 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -20,6 +20,7 @@ from cite_seq_count import secondsToText +from cite_seq_count.preprocessing import parse_cell_list_csv def find_best_match(TAG_seq, tags, maximum_distance): @@ -603,19 +604,24 @@ def run_umi_correction(final_results, filtered_cells, unmapped_id, args): def run_cell_barcode_correction( - final_results, umis_per_cell, reads_per_cell, reference_dict, ordered_tags, args + final_results, + umis_per_cell, + reads_per_cell, + reference_dict, + ordered_tags, + cell_barcode_correction, + args, ): - if len(umis_per_cell) <= args.expected_cells: - print( - "Number of expected cells, {}, is higher " - "than number of cells found {}.\nNot performing " - "cell barcode correction" - "".format(args.expected_cells, len(umis_per_cell)) - ) - bcs_corrected = 0 - else: - print("Correcting cell barcodes") - if not reference_dict: + if cell_barcode_correction == "top": + if len(umis_per_cell) <= args.expected_cells: + print( + "Number of expected cells, {}, is higher " + "than number of cells found {}.\nNot performing " + "cell barcode correction" + "".format(args.expected_cells, len(umis_per_cell)) + ) + bcs_corrected = 0 + else: print("Reference list not given") ( final_results, @@ -629,17 +635,63 @@ def run_cell_barcode_correction( collapsing_threshold=args.bc_threshold, ab_map=ordered_tags, ) - else: - print("Reference list given") - ( - final_results, - umis_per_cell, - bcs_corrected, - ) = correct_cells_reference_list( - final_results=final_results, - umis_per_cell=umis_per_cell, - reference_list=set(reference_dict.keys()), - collapsing_threshold=args.bc_threshold, - ab_map=ordered_tags, - ) + elif cell_barcode_correction == "list": + (final_results, umis_per_cell, bcs_corrected,) = correct_cells_reference_list( + final_results=final_results, + umis_per_cell=umis_per_cell, + reference_list=set(reference_dict.keys()), + collapsing_threshold=args.bc_threshold, + ab_map=ordered_tags, + ) return final_results, umis_per_cell, bcs_corrected + + +def choose_filtered_cells( + given_filtered_cells, + expected_cells, + chemistry_def, + final_results, + ordered_tags, + umis_per_cell, + translation_dict, +): + """ + Returns a list of barcodes that will be in the output + and helps decide based on the inputs. + + Args: + given_filtered_cells (bool or str): False if not given, else string + expected_cells (int): Number of expected cells + chemistry_def (Chemistry): Defines the details of the chemistry + final_results (dict): All results + ordered_tags (named_tuple): Holds tags info + umis_per_cell (Counter): Holds number of UMIs per barcode + + Returns: + set: filtered cell set + """ + # If given, use filtered_list for top cells + if given_filtered_cells: + filtered_cells = set( + parse_cell_list_csv( + filename=given_filtered_cells, + barcode_length=chemistry_def.cell_barcode_end + - chemistry_def.cell_barcode_start + + 1, + file_type="filtered", + ).keys() + ) + # Add potential missing cell barcodes. + for missing_cell in filtered_cells: + if missing_cell in final_results: + continue + else: + final_results[missing_cell] = dict() + for TAG in ordered_tags: + final_results[missing_cell][TAG.safe_name] = Counter() + filtered_cells.add(missing_cell) + else: + top_cells_tuple = umis_per_cell.most_common(expected_cells) + # Select top cells based on total umis per cell + filtered_cells = [pair[0] for pair in top_cells_tuple] + diff --git a/tests/test_data/filtered_lists/fail/different_length.csv b/tests/test_data/filtered_lists/fail/different_length.csv new file mode 100644 index 0000000..0ee1bd3 --- /dev/null +++ b/tests/test_data/filtered_lists/fail/different_length.csv @@ -0,0 +1,2 @@ +GCATCGTCAGTCGTCG +GCGCTAGCTA \ No newline at end of file diff --git a/tests/test_data/filtered_lists/fail/with_header.csv b/tests/test_data/filtered_lists/fail/with_header.csv new file mode 100644 index 0000000..cbd7763 --- /dev/null +++ b/tests/test_data/filtered_lists/fail/with_header.csv @@ -0,0 +1,3 @@ +barcodes +CGATGCACTAGCTAGT +GCTGCTACGCTATCGT \ No newline at end of file diff --git a/tests/test_data/filtered_lists/fail/wrong_barcode.csv b/tests/test_data/filtered_lists/fail/wrong_barcode.csv new file mode 100644 index 0000000..773fb4e --- /dev/null +++ b/tests/test_data/filtered_lists/fail/wrong_barcode.csv @@ -0,0 +1,2 @@ +GCTAGCTGCATGTGCT +GCTAGCTAGCTCTAAX \ No newline at end of file diff --git a/tests/test_data/filtered_lists/pass/normal_ref.csv b/tests/test_data/filtered_lists/pass/normal_ref.csv new file mode 100644 index 0000000..1f204ee --- /dev/null +++ b/tests/test_data/filtered_lists/pass/normal_ref.csv @@ -0,0 +1,2 @@ +GCTAGCTAGCTAGCTG +TTCATAAGGTAGGGAT \ No newline at end of file diff --git a/tests/test_io.py b/tests/test_io.py index b306c07..4bcc82f 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -4,6 +4,7 @@ import scipy from cite_seq_count import io from collections import namedtuple +import numpy as np # copied from https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file import hashlib @@ -11,9 +12,14 @@ def md5(fname): hash_md5 = hashlib.md5() - with gzip.open(fname, "rb") as f: - string = f.read() - hash_md5.update(string) + if fname.endswith("gz"): + with gzip.open(fname, "rb") as f: + string = f.read() + hash_md5.update(string) + else: + with open(fname, "r") as f: + string = f.read() + hash_md5.update(string.encode()) return hash_md5.hexdigest() @@ -22,7 +28,7 @@ def data(): from collections import OrderedDict from scipy import sparse - test_matrix = sparse.dok_matrix((4, 2)) + test_matrix = sparse.dok_matrix((4, 2), dtype=np.int32) test_matrix[1, 1] = 1 pytest.sparse_matrix = test_matrix pytest.filtered_cells = ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"] @@ -47,7 +53,7 @@ def test_write_to_files_wo_translation(data, tmpdir): md5_sums = { barcodes_path: "b7af6a32e83963606f181509a571966f", features_path: "e889e780dbce481287c993dd043714c8", - mtx_path: "0312f3a2bfe57222ebe94051ba07786e", + mtx_path: "3ea98c44d88a947215bace0c72ac1303", } io.write_to_files( @@ -80,7 +86,7 @@ def test_write_to_files_with_translation(data, tmpdir): md5_sums = { barcodes_path: "fce83378b4dd548882fb9271bdd5b4f1", features_path: "e889e780dbce481287c993dd043714c8", - mtx_path: "0312f3a2bfe57222ebe94051ba07786e", + mtx_path: "3ea98c44d88a947215bace0c72ac1303", } io.write_to_files( @@ -106,7 +112,7 @@ def test_write_to_dense_wo_translation(data, tmpdir): csv_path = os.path.join(output_path, csv_name) md5_sums = { - csv_path: "b7af6a32e83963606f181509a571966f", + csv_path: "fef502237900ec386d100169fa1fab7c", } io.write_dense( @@ -117,8 +123,4 @@ def test_write_to_dense_wo_translation(data, tmpdir): filename=csv_name, ) file_path = os.path.join(tmpdir, "without_translation", csv_name) - assert False - - -def test_write_to_dense_with_translation(data, tmpdir): - assert False + assert md5_sums[csv_path] == md5(file_path) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index aa3dcf4..79375e5 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -2,13 +2,12 @@ import io from cite_seq_count import preprocessing import glob -from collections import namedtuple +from collections import namedtuple, OrderedDict +from itertools import islice @pytest.fixture def data(): - from collections import OrderedDict - from itertools import islice pytest.passing_csv = "tests/test_data/tags/pass/*.csv" pytest.failing_csv = "tests/test_data/tags/fail/*.csv" @@ -16,6 +15,9 @@ def data(): pytest.passing_reference_list_csv = "tests/test_data/reference_lists/pass/*.csv" pytest.failing_reference_list_csv = "tests/test_data/reference_lists/fail/*.csv" + pytest.passing_filtered_list_csv = "tests/test_data/filtered_lists/pass/*.csv" + pytest.failing_filtered_list_csv = "tests/test_data/filtered_lists/fail/*.csv" + pytest.correct_tags_path = "tests/test_data/tags/pass/correct.csv" pytest.correct_R1_path = "tests/test_data/fastq/correct_R1.fastq.gz" pytest.correct_R2_path = "tests/test_data/fastq/correct_R2.fastq.gz" @@ -73,18 +75,28 @@ def test_csv_parser(data): preprocessing.parse_tags_csv(file_path) +def test_filtered_list_parser(data): + passing_files = glob.glob(pytest.passing_filtered_list_csv) + for file_path in passing_files: + preprocessing.parse_filtered_list_csv(file_path, barcode_length=16) + with pytest.raises(SystemExit): + failing_files = glob.glob(pytest.failing_filtered_list_csv) + for file_path in failing_files: + preprocessing.parse_filtered_list_csv(file_path, barcode_length=16) + + @pytest.mark.dependency() def test_parse_reference_list_csv(data): passing_files = glob.glob(pytest.passing_reference_list_csv) for file_path in passing_files: - assert preprocessing.parse_reference_list_csv(file_path, 16).keys() in ( + assert preprocessing.parse_cell_list_csv(file_path, 16, "reference").keys() in ( pytest.correct_reference_list, 1, ) with pytest.raises(SystemExit): failing_files = glob.glob(pytest.failing_reference_list_csv) for file_path in failing_files: - preprocessing.parse_reference_list_csv(file_path, 16) + preprocessing.parse_cell_list_csv(file_path, 16, "reference") @pytest.mark.dependency() From 68b2826c389caedbbf228d5e36f2f1272c72cc42 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sat, 29 May 2021 11:38:18 +0200 Subject: [PATCH 44/77] Moved some functions to io --- cite_seq_count/__main__.py | 2 +- cite_seq_count/io.py | 85 +++++++++++++++++++++ cite_seq_count/preprocessing.py | 127 +++++++++++--------------------- tests/test_io.py | 40 ++++++++++ tests/test_preprocessing.py | 43 ----------- 5 files changed, 168 insertions(+), 129 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index e232521..a3c5c89 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -100,7 +100,7 @@ def main(): os.remove(file_path) cell_barcode_correction = preprocessing.determine_cell_correction_mode(args) # Correct cell barcodes - if cell_barcode_correction != "no": + if cell_barcode_correction: ( final_results, umis_per_cell, diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 0c188cd..7bc544a 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -1,4 +1,6 @@ import os +import csv +import sys import gzip import shutil import time @@ -14,6 +16,89 @@ from cite_seq_count import secondsToText +def blocks(files, size=65536): + """ + A fast way of counting the lines of a large file. + Ref: + https://stackoverflow.com/a/9631635/9178565 + + Args: + files (io.handler): A file handler + size (int): Block size + Returns: + A generator + """ + while True: + b = files.read(size) + if not b: + break + yield b + + +def get_n_lines(file_path): + """ + Determines how many lines have to be processed + depending on options and number of available lines. + Checks that the number of lines is a multiple of 4. + + Args: + file_path (string): Path to a fastq.gz file + + Returns: + n_lines (int): Number of lines in the file + """ + print("Counting number of reads in file {}".format(file_path)) + with gzip.open(file_path, "rt", encoding="utf-8", errors="ignore") as f: + n_lines = sum(bl.count("\n") for bl in blocks(f)) + if n_lines % 4 != 0: + sys.exit( + "{}'s number of lines is not a multiple of 4. The file " + "might be corrupted.\n Exiting".format(file_path) + ) + return n_lines + + +def get_read_paths(read1_path, read2_path): + """ + Splits up 2 comma-separated strings of input files into list of files + to process. Ensures both lists are equal in length. + + Args: + read1_path (string): Comma-separated paths to read1.fq + read2_path (string): Comma-separated paths to read2.fq + Returns: + _read1_path (list(string)): list of paths to read1.fq + _read2_path (list(string)): list of paths to read2.fq + """ + _read1_path = read1_path.split(",") + _read2_path = read2_path.split(",") + if len(_read1_path) != len(_read2_path): + sys.exit( + "Unequal number of read1 ({}) and read2({}) files provided" + "\n Exiting".format(len(_read1_path), len(_read2_path)) + ) + return (_read1_path, _read2_path) + + +def get_csv_reader_from_path(filename): + """ + Returns a csv_reader object for a file weather it's a flat file or compressed. + + Args: + filename: str + + Returns: + csv_reader: The csv_reader for the file + """ + if filename.endswith(".gz"): + f = gzip.open(filename, mode="rt") + csv_reader = csv.reader(f) + else: + f = open(filename, encoding="UTF-8") + csv_reader = csv.reader(f) + return csv_reader + + def write_to_files( sparse_matrix, filtered_cells, ordered_tags, data_type, outfolder, reference_dict ): diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 389ebff..d2d3458 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -3,31 +3,15 @@ import sys import regex import Levenshtein +import umi_tools.whitelist_methods as whitelist_methods +from cite_seq_count.io import get_csv_reader_from_path, get_n_lines from collections import namedtuple from itertools import combinations from itertools import islice -from pandas import read_csv - - -def get_csv_reader_from_path(filename): - """ - Returns a csv_reader object for a file weather it's a flat file or compressed. - Args: - filename: str - - Returns: - csv_reader: The csv_reader for the file - """ - if filename.endswith(".gz"): - f = gzip.open(filename, mode="rt") - csv_reader = csv.reader(f) - else: - f = open(filename, encoding="UTF-8") - csv_reader = csv.reader(f) - return csv_reader +from pandas import read_csv def parse_filtered_list_csv(filename, barcode_length): @@ -275,7 +259,23 @@ def get_read_length(filename): return read_length -def determine_cell_correction_mode(args, chemistry): +def translate_barcodes(cell_set, reference_dict): + """Translate a list of barcode using a mapping reference + Args: + cell_set (set): A set of barcodes + reference_dict (dict): A dict providing a simple key value translation + + Returns: + translated_barcodes (set): A set of translated barcodes + """ + + translated_barcodes = set() + for cell in cell_set: + translate_barcodes.add(reference_dict[cell]) + return translated_barcodes + + +def get_filtered_list(args, chemistry, reference_dict, reads_per_cell): """ Determines what mode to use for cell barcode correction. Args: @@ -285,14 +285,35 @@ def determine_cell_correction_mode(args, chemistry): str: type of correction """ if args.bc_threshold != 0: + # Are we provided with a filtered list? if args.filtered_list: filtered_set = parse_filtered_list_csv( args.filtered_list, (chemistry.cell_barcode_stop - chemistry.cell_barcode_start), ) + # Do we need to translate the list? + if args.reference_dict: + # get the translation + filtered_set = translate_barcodes( + cell_set=filtered_set, reference_dict=reference_dict + ) return filtered_set + # We try and rely on the top number of cells now else: - return False + print("Looking for a reference list") + _, true_to_false = whitelist_methods.getCellWhitelist( + knee_method="density", + cell_barcode_counts=reads_per_cell, + expect_cells=args.expected_cells, + cell_number=args.expected_cells, + error_correct_threshold=args.bc_threshold, + plotfile_prefix=False, + ) + if true_to_false is None: + print( + "Failed to find a good reference list. Will not correct cell barcodes" + ) + return False else: return False @@ -327,70 +348,6 @@ def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last) ) -def blocks(files, size=65536): - """ - A fast way of counting the lines of a large file. - Ref: - https://stackoverflow.com/a/9631635/9178565 - - Args: - files (io.handler): A file handler - size (int): Block size - Returns: - A generator - """ - while True: - b = files.read(size) - if not b: - break - yield b - - -def get_n_lines(file_path): - """ - Determines how many lines have to be processed - depending on options and number of available lines. - Checks that the number of lines is a multiple of 4. - - Args: - file_path (string): Path to a fastq.gz file - - Returns: - n_lines (int): Number of lines in the file - """ - print("Counting number of reads in file {}".format(file_path)) - with gzip.open(file_path, "rt", encoding="utf-8", errors="ignore") as f: - n_lines = sum(bl.count("\n") for bl in blocks(f)) - if n_lines % 4 != 0: - sys.exit( - "{}'s number of lines is not a multiple of 4. The file " - "might be corrupted.\n Exiting".format(file_path) - ) - return n_lines - - -def get_read_paths(read1_path, read2_path): - """ - Splits up 2 comma-separated strings of input files into list of files - to process. Ensures both lists are equal in length. - - Args: - read1_path (string): Comma-separated paths to read1.fq - read2_path (string): Comma-separated paths to read2.fq - Returns: - _read1_path (list(string)): list of paths to read1.fq - _read2_path (list(string)): list of paths to read2.fq - """ - _read1_path = read1_path.split(",") - _read2_path = read2_path.split(",") - if len(_read1_path) != len(_read2_path): - sys.exit( - "Unequal number of read1 ({}) and read2({}) files provided" - "\n Exiting".format(len(_read1_path), len(_read2_path)) - ) - return (_read1_path, _read2_path) - - def pre_run_checks(read1_paths, chemistry_def, longest_tag_len, args): """ Checks that the chemistry is properly set and defines how many reads to process diff --git a/tests/test_io.py b/tests/test_io.py index 4bcc82f..fbdafaa 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -28,6 +28,21 @@ def data(): from collections import OrderedDict from scipy import sparse + pytest.correct_R1_path = "tests/test_data/fastq/correct_R1.fastq.gz" + pytest.correct_R2_path = "tests/test_data/fastq/correct_R2.fastq.gz" + pytest.corrupt_R1_path = "tests/test_data/fastq/corrupted_R1.fastq.gz" + pytest.corrupt_R2_path = "tests/test_data/fastq/corrupted_R2.fastq.gz" + + pytest.correct_R1_multipath = "path/to/R1_1.fastq.gz,path/to/R1_2.fastq.gz" + pytest.correct_R2_multipath = "path/to/R2_1.fastq.gz,path/to/R2_2.fastq.gz" + pytest.incorrect_R2_multipath = ( + "path/to/R2_1.fastq.gz,path/to/R2_2.fastq.gz,path/to/R2_3.fastq.gz" + ) + + pytest.correct_multipath_result = ( + ["path/to/R1_1.fastq.gz", "path/to/R1_2.fastq.gz"], + ["path/to/R2_1.fastq.gz", "path/to/R2_2.fastq.gz"], + ) test_matrix = sparse.dok_matrix((4, 2), dtype=np.int32) test_matrix[1, 1] = 1 pytest.sparse_matrix = test_matrix @@ -124,3 +139,28 @@ def test_write_to_dense_wo_translation(data, tmpdir): ) file_path = os.path.join(tmpdir, "without_translation", csv_name) assert md5_sums[csv_path] == md5(file_path) + + +@pytest.mark.dependency() +def test_get_n_lines(data): + assert io.get_n_lines(pytest.correct_R1_path) == (200 * 4) + + +@pytest.mark.dependency() +def test_corrrect_multipath(data): + assert ( + io.get_read_paths(pytest.correct_R1_multipath, pytest.correct_R2_multipath) + == pytest.correct_multipath_result + ) + + +@pytest.mark.dependency(depends=["test_get_n_lines"]) +def test_incorrrect_multipath(data): + with pytest.raises(SystemExit): + io.get_read_paths(pytest.correct_R1_multipath, pytest.incorrect_R2_multipath) + + +@pytest.mark.dependency(depends=["test_get_n_lines"]) +def test_get_n_lines_not_multiple_of_4(data): + with pytest.raises(SystemExit): + io.get_n_lines(pytest.corrupt_R1_path) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 79375e5..3dd5b21 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -19,21 +19,6 @@ def data(): pytest.failing_filtered_list_csv = "tests/test_data/filtered_lists/fail/*.csv" pytest.correct_tags_path = "tests/test_data/tags/pass/correct.csv" - pytest.correct_R1_path = "tests/test_data/fastq/correct_R1.fastq.gz" - pytest.correct_R2_path = "tests/test_data/fastq/correct_R2.fastq.gz" - pytest.corrupt_R1_path = "tests/test_data/fastq/corrupted_R1.fastq.gz" - pytest.corrupt_R2_path = "tests/test_data/fastq/corrupted_R2.fastq.gz" - - pytest.correct_R1_multipath = "path/to/R1_1.fastq.gz,path/to/R1_2.fastq.gz" - pytest.correct_R2_multipath = "path/to/R2_1.fastq.gz,path/to/R2_2.fastq.gz" - pytest.incorrect_R2_multipath = ( - "path/to/R2_1.fastq.gz,path/to/R2_2.fastq.gz,path/to/R2_3.fastq.gz" - ) - - pytest.correct_multipath_result = ( - ["path/to/R1_1.fastq.gz", "path/to/R1_2.fastq.gz"], - ["path/to/R2_1.fastq.gz", "path/to/R2_2.fastq.gz"], - ) # Create some variables to compare to pytest.correct_reference_list = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) @@ -111,31 +96,3 @@ def test_check_distance_too_big_between_tags(data): with pytest.raises(SystemExit): preprocessing.check_tags(pytest.correct_tags, 8) - -@pytest.mark.dependency() -def test_get_n_lines(data): - assert preprocessing.get_n_lines(pytest.correct_R1_path) == (200 * 4) - - -@pytest.mark.dependency(depends=["test_get_n_lines"]) -def test_get_n_lines_not_multiple_of_4(data): - with pytest.raises(SystemExit): - preprocessing.get_n_lines(pytest.corrupt_R1_path) - - -@pytest.mark.dependency() -def test_corrrect_multipath(data): - assert ( - preprocessing.get_read_paths( - pytest.correct_R1_multipath, pytest.correct_R2_multipath - ) - == pytest.correct_multipath_result - ) - - -@pytest.mark.dependency(depends=["test_get_n_lines"]) -def test_incorrrect_multipath(data): - with pytest.raises(SystemExit): - preprocessing.get_read_paths( - pytest.correct_R1_multipath, pytest.incorrect_R2_multipath - ) From 2d37cbc414ac425e6a972272157865f9fb41ff97 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sat, 29 May 2021 16:01:44 +0200 Subject: [PATCH 45/77] some code and testing refactoring --- cite_seq_count/mapping.py | 202 +++++++++++++++++++++++++++++++++++ cite_seq_count/processing.py | 193 +-------------------------------- tests/test_mapping.py | 168 +++++++++++++++++++++++++++++ tests/test_processing.py | 144 +------------------------ 4 files changed, 373 insertions(+), 334 deletions(-) create mode 100644 cite_seq_count/mapping.py create mode 100644 tests/test_mapping.py diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py new file mode 100644 index 0000000..f92d02f --- /dev/null +++ b/cite_seq_count/mapping.py @@ -0,0 +1,202 @@ +import time +import csv + +import sys +import os +import Levenshtein + +from collections import Counter +from collections import defaultdict +from collections import namedtuple + +# pylint: disable=no-name-in-module +from multiprocess import Pool + +from cite_seq_count.processing import merge_results +from cite_seq_count import secondsToText + + +def map_data(input_queue, unmapped_id, args): + """ + Maps the data given an input_queue + + Args: + input_queue (list): List of parameters to run in parallel + args (argparse): List of arguments + + Returns: + final_results (dict): final dictionnary with results + umis_per_cell (Counter): Counter of UMIs per cell + reads_per_cell (Counter): Counter of reads per cell + merged_no_match (Counter): Counter of unmapped reads + """ + # Initialize the counts dicts that will be generated from each input fastq pair + final_results = defaultdict(lambda: defaultdict(Counter)) + umis_per_cell = Counter() + reads_per_cell = Counter() + merged_no_match = Counter() + + print("Started mapping") + parallel_results = [] + + if args.n_threads == 1: + mapped_reads = map_reads(input_queue[0]) + parallel_results.append([mapped_reads]) + else: + pool = Pool(processes=args.n_threads) + errors = [] + mapping = pool.map_async( + map_reads, + input_queue, + callback=parallel_results.append, + error_callback=errors.append, + ) + mapping.wait() + + pool.close() + pool.join() + if len(errors) != 0: + for error in errors: + print(error) + + print("Merging results") + (final_results, umis_per_cell, reads_per_cell, merged_no_match,) = merge_results( + parallel_results=parallel_results[0], unmapped_id=unmapped_id + ) + + return final_results, umis_per_cell, reads_per_cell, merged_no_match + + +def find_best_match(TAG_seq, tags, maximum_distance): + """ + Find the best match from the list of tags. + + Compares the Levenshtein distance between tags and the trimmed sequences. + The tag and the sequence must have the same length. + If no matches found returns 'unmapped'. + We add 1 + Args: + TAG_seq (string): Sequence from R2 already start trimmed + tags (dict): A dictionary with the TAGs as keys and TAG Names as values. + maximum_distance (int): Maximum distance given by the user. + + Returns: + best_match (string): The TAG name that will be used for counting. + """ + best_match = len(tags) + best_score = maximum_distance + for tag in tags: + # pylint: disable=no-member + score = Levenshtein.hamming(tag.sequence, TAG_seq[: len(tag.sequence)]) + if score == 0: + # Best possible match + return tag.id + elif score <= best_score: + best_score = score + best_match = tag.id + return best_match + return best_match + + +def find_best_match_shift(TAG_seq, tags): + """ + Find the best match from the list of tags with sliding window. + Only works with exact match. + Just checks if the string is in the sequence. + If no matches found returns 'unmapped'. + + Args: + TAG_seq (string): Sequence from R2 already start trimmed + tags (dict): A dictionary with the TAGs as keys and TAG Names as values. + + Returns: + best_match (string): The TAG name that will be used for counting. + """ + best_match = "unmapped" + for tag in tags: + if tag.sequence in TAG_seq: + return tag.name + return best_match + + +def map_reads(mapping_input): + """Read through R1/R2 files and generate. + + It reads both Read1 and Read2 files, creating a dict based on cell barcode. + + Args: + mapping_input (namedtuple): List of paramters to run in parallel. + filename (str): Path to the chunk file + tags (list): List of named tuples tags + debug (bool): Should debug information be shown or not + maximum_distance (int): Maximum distance given by the user + sliding_window (bool): A bool enabling a sliding window search + + Returns: + results (dict): A dict of dict of Counters with the mapping results. + no_match (Counter): A counter with unmapped sequences. + """ + # Initiate values + (filename, tags, debug, maximum_distance, sliding_window) = mapping_input + print("Started mapping in child process {}".format(os.getpid())) + results = {} + no_match = Counter() + n = 1 + + unmapped_id = len(tags) + # Progress info + t = time.time() + with open(filename, "r") as input_file: + reads = csv.reader(input_file) + for read in reads: + cell_barcode = read[0] + # This change in bytes is required by umi_tools for umi correction + UMI = bytes(read[1], "ascii") + read2 = read[2] + if n % 1000000 == 0: + print( + "Processed 1,000,000 reads in {}. Total " + "reads: {:,} in child {}".format( + secondsToText.secondsToText(time.time() - t), n, os.getpid() + ) + ) + sys.stdout.flush() + t = time.time() + + if cell_barcode not in results: + results[cell_barcode] = defaultdict(Counter) + + if sliding_window: + best_match = find_best_match_shift(read2, tags) + else: + best_match = find_best_match(read2, tags, maximum_distance) + + results[cell_barcode][best_match][UMI] += 1 + + if best_match == unmapped_id: + no_match[read2] += 1 + + if debug: + print( + "cell_barcode:{0}\tUMI:{1}\tTAG_seq:{2}\n" + "cell barcode length:{3}\tUMI length:{4}\tTAG sequence length:{5}\n" + "Best match is: {6}\n".format( + cell_barcode, + UMI, + read2, + len(cell_barcode), + len(UMI), + len(read2), + tags[best_match].name, + ) + ) + sys.stdout.flush() + n += 1 + print( + "Mapping done for process {}. Processed {:,} reads".format( + os.getpid(), n - 1 + ) + ) + sys.stdout.flush() + + return (results, no_match) diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 8a6ffab..753a811 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -1,9 +1,8 @@ -import time import sys import os import Levenshtein import pybktree -import csv + from collections import Counter from collections import defaultdict @@ -12,152 +11,15 @@ # pylint: disable=no-name-in-module from multiprocess import Pool -from itertools import islice + from numpy import int32 from scipy import sparse from umi_tools import network import umi_tools.whitelist_methods as whitelist_methods - -from cite_seq_count import secondsToText from cite_seq_count.preprocessing import parse_cell_list_csv -def find_best_match(TAG_seq, tags, maximum_distance): - """ - Find the best match from the list of tags. - - Compares the Levenshtein distance between tags and the trimmed sequences. - The tag and the sequence must have the same length. - If no matches found returns 'unmapped'. - We add 1 - Args: - TAG_seq (string): Sequence from R2 already start trimmed - tags (dict): A dictionary with the TAGs as keys and TAG Names as values. - maximum_distance (int): Maximum distance given by the user. - - Returns: - best_match (string): The TAG name that will be used for counting. - """ - best_match = len(tags) - best_score = maximum_distance - for tag in tags: - # pylint: disable=no-member - score = Levenshtein.hamming(tag.sequence, TAG_seq[: len(tag.sequence)]) - if score == 0: - # Best possible match - return tag.id - elif score <= best_score: - best_score = score - best_match = tag.id - return best_match - return best_match - - -def find_best_match_shift(TAG_seq, tags): - """ - Find the best match from the list of tags with sliding window. - Only works with exact match. - Just checks if the string is in the sequence. - If no matches found returns 'unmapped'. - - Args: - TAG_seq (string): Sequence from R2 already start trimmed - tags (dict): A dictionary with the TAGs as keys and TAG Names as values. - - Returns: - best_match (string): The TAG name that will be used for counting. - """ - best_match = "unmapped" - for tag in tags: - if tag.sequence in TAG_seq: - return tag.name - return best_match - - -def map_reads(mapping_input): - """Read through R1/R2 files and generate a islice starting at a specific index. - - It reads both Read1 and Read2 files, creating a dict based on cell barcode. - - Args: - mapping_input (namedtuple): List of paramters to run in parallel. - filename (str): Path to the chunk file - tags (list): List of named tuples tags - debug (bool): Should debug information be shown or not - maximum_distance (int): Maximum distance given by the user - sliding_window (bool): A bool enabling a sliding window search - - Returns: - results (dict): A dict of dict of Counters with the mapping results. - no_match (Counter): A counter with unmapped sequences. - """ - # Initiate values - (filename, tags, debug, maximum_distance, sliding_window) = mapping_input - print("Started mapping in child process {}".format(os.getpid())) - results = {} - no_match = Counter() - n = 1 - - unmapped_id = len(tags) - # Progress info - t = time.time() - with open(filename, "r") as input_file: - reads = csv.reader(input_file) - for read in reads: - cell_barcode = read[0] - # This change in bytes is required by umi_tools for umi correction - UMI = bytes(read[1], "ascii") - read2 = read[2] - if n % 1000000 == 0: - print( - "Processed 1,000,000 reads in {}. Total " - "reads: {:,} in child {}".format( - secondsToText.secondsToText(time.time() - t), n, os.getpid() - ) - ) - sys.stdout.flush() - t = time.time() - - if cell_barcode not in results: - results[cell_barcode] = defaultdict(Counter) - - if sliding_window: - best_match = find_best_match_shift(read2, tags) - else: - best_match = find_best_match(read2, tags, maximum_distance) - - results[cell_barcode][best_match][UMI] += 1 - - if best_match == unmapped_id: - no_match[read2] += 1 - - if debug: - print( - "cell_barcode:{0}\tUMI:{1}\tTAG_seq:{2}\n" - "cell barcode length:{3}\tUMI length:{4}\tTAG sequence length:{5}\n" - "Best match is: {6}\n".format( - cell_barcode, - UMI, - read2, - len(cell_barcode), - len(UMI), - len(read2), - tags[best_match].name, - ) - ) - sys.stdout.flush() - n += 1 - print( - "Mapping done for process {}. Processed {:,} reads".format( - os.getpid(), n - 1 - ) - ) - sys.stdout.flush() - - return (results, no_match) - - def merge_results(parallel_results, unmapped_id): """Merge chunked results from parallel processing. @@ -482,57 +344,6 @@ def generate_sparse_matrices( return results_matrix -def map_data(input_queue, unmapped_id, args): - """ - Maps the data given an input_queue - - Args: - input_queue (list): List of parameters to run in parallel - args (argparse): List of arguments - - Returns: - final_results (dict): final dictionnary with results - umis_per_cell (Counter): Counter of UMIs per cell - reads_per_cell (Counter): Counter of reads per cell - merged_no_match (Counter): Counter of unmapped reads - """ - # Initialize the counts dicts that will be generated from each input fastq pair - final_results = defaultdict(lambda: defaultdict(Counter)) - umis_per_cell = Counter() - reads_per_cell = Counter() - merged_no_match = Counter() - - print("Started mapping") - parallel_results = [] - - if args.n_threads == 1: - mapped_reads = map_reads(input_queue[0]) - parallel_results.append([mapped_reads]) - else: - pool = Pool(processes=args.n_threads) - errors = [] - mapping = pool.map_async( - map_reads, - input_queue, - callback=parallel_results.append, - error_callback=errors.append, - ) - mapping.wait() - - pool.close() - pool.join() - if len(errors) != 0: - for error in errors: - print(error) - - print("Merging results") - (final_results, umis_per_cell, reads_per_cell, merged_no_match,) = merge_results( - parallel_results=parallel_results[0], unmapped_id=unmapped_id - ) - - return final_results, umis_per_cell, reads_per_cell, merged_no_match - - def run_umi_correction(final_results, filtered_cells, unmapped_id, args): input_queue = [] umi_correction_input = namedtuple( diff --git a/tests/test_mapping.py b/tests/test_mapping.py new file mode 100644 index 0000000..5edd623 --- /dev/null +++ b/tests/test_mapping.py @@ -0,0 +1,168 @@ +import pytest +import random +import copy +from collections import Counter, namedtuple +from cite_seq_count import mapping +from cite_seq_count import preprocessing + + +def complete_poly_A(seq, final_length=40): + poly_A_len = final_length - len(seq) + return seq + "A" * poly_A_len + + +def get_sequences(ref_path): + sequences = [] + with open(ref_path, "r") as adt_ref: + lines = adt_ref.readlines() + entries = int(len(lines) / 2) + for i in range(0, entries, 2): + sequences.append(complete_poly_A(lines[i + 1].strip())) + return sequences + + +def extend_seq_pool(ref_seq, distance): + extended_pool = [complete_poly_A(ref_seq)] + extended_pool.append(modify(ref_seq, distance, modification_type="mutate")) + extended_pool.append(modify(ref_seq, distance, modification_type="mutate")) + extended_pool.append(modify(ref_seq, distance, modification_type="mutate")) + return extended_pool + + +def modify(seq, n, modification_type): + bases = list("ATGCN") + positions = list(range(len(seq))) + seq = list(seq) + for _ in range(n): + if modification_type == "mutate": + position = random.choice(positions) + positions.remove(position) + temp_bases = copy.copy(bases) + del temp_bases[bases.index(seq[position])] + seq[position] = random.choice(temp_bases) + elif modification_type == "delete": + del seq[random.randint(0, len(seq) - 2)] + elif modification_type == "add": + position = random.randint(0, len(seq) - 1) + seq.insert(position, random.choice(bases)) + return complete_poly_A("".join(seq)) + + +@pytest.fixture +def data(): + from collections import Counter + + # Test file paths + pytest.correct_R1_path = "tests/test_data/fastq/correct_R1.fastq.gz" + pytest.correct_R2_path = "tests/test_data/fastq/correct_R2.fastq.gz" + pytest.file_path = "tests/test_data/fastq/test_csv.csv" + + pytest.chunk_size = 800 + tag = namedtuple("tag", ["name", "sequence", "id"]) + pytest.tags = [ + tag(name="test1", sequence="CGTACGTAGCCTAGC", id=0), + tag(name="test2", sequence="CGTAGCTCG", id=1), + ] + + pytest.barcode_slice = slice(0, 16) + pytest.umi_slice = slice(16, 26) + pytest.correct_reference_list = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) + pytest.legacy = False + pytest.debug = False + pytest.start_trim = 0 + pytest.maximum_distance = 5 + pytest.results = { + "ACTGTTTTATTGGCCT": { + 0: Counter({b"CATTAGTGGT": 3, b"CATTAGTGGG": 2, b"CATTCGTGGT": 1}) + }, + "TTCATAAGGTAGGGAT": { + 1: Counter({b"TAGCTTAGTA": 3, b"TAGCTTAGTC": 2, b"GCGATGCATA": 1}) + }, + } + pytest.corrected_results = { + "ACTGTTTTATTGGCCT": {0: Counter({b"CATTAGTGGT": 6})}, + "TTCATAAGGTAGGGAT": {1: Counter({b"TAGCTTAGTA": 5, b"GCGATGCATA": 1})}, + } + pytest.umis_per_cell = Counter({"ACTGTTTTATTGGCCT": 1, "TTCATAAGGTAGGGAT": 2}) + pytest.reads_per_cell = Counter({"ACTGTTTTATTGGCCT": 3, "TTCATAAGGTAGGGAT": 6}) + pytest.expected_cells = 2 + pytest.no_match = Counter() + pytest.collapsing_threshold = 1 + pytest.sliding_window = False + pytest.max_umis = 20000 + + pytest.sequence_pool = [] + pytest.tags_tuple = preprocessing.check_tags( + preprocessing.parse_tags_csv("tests/test_data/tags/pass/correct.csv"), 5 + )[0] + pytest.mapping_input = namedtuple( + "mapping_input", + ["filename", "tags", "debug", "maximum_distance", "sliding_window"], + ) + pytest.mappint_input_test = pytest.mapping_input( + filename=pytest.file_path, + tags=pytest.tags_tuple, + debug=pytest.debug, + maximum_distance=pytest.maximum_distance, + sliding_window=pytest.sliding_window, + ) + + +@pytest.mark.dependency() +def test_find_best_match_with_1_distance(data): + distance = 1 + for tag in pytest.tags_tuple: + counts = Counter() + if tag.name == "unmapped": + continue + for seq in extend_seq_pool(tag.sequence, distance): + counts[mapping.find_best_match(seq, pytest.tags_tuple, distance)] += 1 + assert counts[tag.id] == 4 + + +@pytest.mark.dependency() +def test_find_best_match_with_2_distance(data): + distance = 2 + for tag in pytest.tags_tuple: + counts = Counter() + if tag.name == "unmapped": + continue + for seq in extend_seq_pool(tag.sequence, distance): + counts[mapping.find_best_match(seq, pytest.tags_tuple, distance)] += 1 + assert counts[tag.id] == 4 + + +@pytest.mark.dependency() +def test_find_best_match_with_3_distance(data): + distance = 3 + for tag in pytest.tags_tuple: + counts = Counter() + for seq in extend_seq_pool(tag.sequence, distance): + counts[mapping.find_best_match(seq, pytest.tags_tuple, distance)] += 1 + assert counts[tag.id] == 4 + + +@pytest.mark.dependency() +def test_find_best_match_with_3_distance_reverse(data): + distance = 3 + for tag in pytest.tags_tuple: + counts = Counter() + if tag.name == "unmapped": + continue + for seq in extend_seq_pool(tag.sequence, distance): + counts[mapping.find_best_match(seq, pytest.tags_tuple, distance)] += 1 + assert counts[tag.id] == 4 + + +@pytest.mark.dependency( + depends=[ + "test_find_best_match_with_1_distance", + "test_find_best_match_with_2_distance", + "test_find_best_match_with_3_distance", + "test_find_best_match_with_3_distance_reverse", + ] +) +def test_classify_reads_multi_process(data): + (results, _) = mapping.map_reads(pytest.mappint_input_test) + print(results) + assert len(results) == 2 diff --git a/tests/test_processing.py b/tests/test_processing.py index 756925c..8b129b8 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -1,81 +1,17 @@ import pytest -import random -import copy -from collections import Counter, namedtuple +from collections import namedtuple from cite_seq_count import processing -from cite_seq_count import preprocessing - - -def complete_poly_A(seq, final_length=40): - poly_A_len = final_length - len(seq) - return seq + "A" * poly_A_len - - -def get_sequences(ref_path): - sequences = [] - with open(ref_path, "r") as adt_ref: - lines = adt_ref.readlines() - entries = int(len(lines) / 2) - for i in range(0, entries, 2): - sequences.append(complete_poly_A(lines[i + 1].strip())) - return sequences - - -def extend_seq_pool(ref_seq, distance): - extended_pool = [complete_poly_A(ref_seq)] - extended_pool.append(modify(ref_seq, distance, modification_type="mutate")) - extended_pool.append(modify(ref_seq, distance, modification_type="mutate")) - extended_pool.append(modify(ref_seq, distance, modification_type="mutate")) - return extended_pool - - -def modify(seq, n, modification_type): - bases = list("ATGCN") - positions = list(range(len(seq))) - seq = list(seq) - for _ in range(n): - if modification_type == "mutate": - position = random.choice(positions) - positions.remove(position) - temp_bases = copy.copy(bases) - del temp_bases[bases.index(seq[position])] - seq[position] = random.choice(temp_bases) - elif modification_type == "delete": - del seq[random.randint(0, len(seq) - 2)] - elif modification_type == "add": - position = random.randint(0, len(seq) - 1) - seq.insert(position, random.choice(bases)) - return complete_poly_A("".join(seq)) @pytest.fixture def data(): - import json - from collections import defaultdict - from collections import OrderedDict from collections import Counter - from itertools import islice - - # Test file paths - pytest.correct_R1_path = "tests/test_data/fastq/correct_R1.fastq.gz" - pytest.correct_R2_path = "tests/test_data/fastq/correct_R2.fastq.gz" - pytest.file_path = "tests/test_data/fastq/test_csv.csv" - - pytest.chunk_size = 800 tag = namedtuple("tag", ["name", "sequence", "id"]) pytest.tags = [ tag(name="test1", sequence="CGTACGTAGCCTAGC", id=0), tag(name="test2", sequence="CGTAGCTCG", id=1), ] - - pytest.barcode_slice = slice(0, 16) - pytest.umi_slice = slice(16, 26) - pytest.correct_reference_list = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) - pytest.legacy = False - pytest.debug = False - pytest.start_trim = 0 - pytest.maximum_distance = 5 pytest.results = { "ACTGTTTTATTGGCCT": { 0: Counter({b"CATTAGTGGT": 3, b"CATTAGTGGG": 2, b"CATTCGTGGT": 1}) @@ -91,89 +27,11 @@ def data(): pytest.umis_per_cell = Counter({"ACTGTTTTATTGGCCT": 1, "TTCATAAGGTAGGGAT": 2}) pytest.reads_per_cell = Counter({"ACTGTTTTATTGGCCT": 3, "TTCATAAGGTAGGGAT": 6}) pytest.expected_cells = 2 - pytest.no_match = Counter() pytest.collapsing_threshold = 1 - pytest.sliding_window = False pytest.max_umis = 20000 - pytest.sequence_pool = [] - pytest.tags_tuple = preprocessing.check_tags( - preprocessing.parse_tags_csv("tests/test_data/tags/pass/correct.csv"), 5 - )[0] - pytest.mapping_input = namedtuple( - "mapping_input", - ["filename", "tags", "debug", "maximum_distance", "sliding_window"], - ) - pytest.mappint_input_test = pytest.mapping_input( - filename=pytest.file_path, - tags=pytest.tags_tuple, - debug=pytest.debug, - maximum_distance=pytest.maximum_distance, - sliding_window=pytest.sliding_window, - ) - @pytest.mark.dependency() -def test_find_best_match_with_1_distance(data): - distance = 1 - for tag in pytest.tags_tuple: - counts = Counter() - if tag.name == "unmapped": - continue - for seq in extend_seq_pool(tag.sequence, distance): - counts[processing.find_best_match(seq, pytest.tags_tuple, distance)] += 1 - assert counts[tag.id] == 4 - - -@pytest.mark.dependency() -def test_find_best_match_with_2_distance(data): - distance = 2 - for tag in pytest.tags_tuple: - counts = Counter() - if tag.name == "unmapped": - continue - for seq in extend_seq_pool(tag.sequence, distance): - counts[processing.find_best_match(seq, pytest.tags_tuple, distance)] += 1 - assert counts[tag.id] == 4 - - -@pytest.mark.dependency() -def test_find_best_match_with_3_distance(data): - distance = 3 - for tag in pytest.tags_tuple: - counts = Counter() - for seq in extend_seq_pool(tag.sequence, distance): - counts[processing.find_best_match(seq, pytest.tags_tuple, distance)] += 1 - assert counts[tag.id] == 4 - - -@pytest.mark.dependency() -def test_find_best_match_with_3_distance_reverse(data): - distance = 3 - for tag in pytest.tags_tuple: - counts = Counter() - if tag.name == "unmapped": - continue - for seq in extend_seq_pool(tag.sequence, distance): - counts[processing.find_best_match(seq, pytest.tags_tuple, distance)] += 1 - assert counts[tag.id] == 4 - - -@pytest.mark.dependency( - depends=[ - "test_find_best_match_with_1_distance", - "test_find_best_match_with_2_distance", - "test_find_best_match_with_3_distance", - "test_find_best_match_with_3_distance_reverse", - ] -) -def test_classify_reads_multi_process(data): - (results, _) = processing.map_reads(pytest.mappint_input_test) - print(results) - assert len(results) == 2 - - -@pytest.mark.dependency(depends=["test_classify_reads_multi_process"]) def test_correct_umis(data): temp = processing.correct_umis_in_cells((pytest.results, 2, pytest.max_umis, 2)) results = temp[0] From a867d448a2a0bd5195dd8ae96a4f233ba76b1b8b Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sat, 29 May 2021 16:05:09 +0200 Subject: [PATCH 46/77] cleaning up test_mapping.py --- tests/test_mapping.py | 27 ++------------------------- 1 file changed, 2 insertions(+), 25 deletions(-) diff --git a/tests/test_mapping.py b/tests/test_mapping.py index 5edd623..3d9cbb2 100644 --- a/tests/test_mapping.py +++ b/tests/test_mapping.py @@ -52,24 +52,11 @@ def modify(seq, n, modification_type): def data(): from collections import Counter - # Test file paths - pytest.correct_R1_path = "tests/test_data/fastq/correct_R1.fastq.gz" - pytest.correct_R2_path = "tests/test_data/fastq/correct_R2.fastq.gz" pytest.file_path = "tests/test_data/fastq/test_csv.csv" - - pytest.chunk_size = 800 - tag = namedtuple("tag", ["name", "sequence", "id"]) - pytest.tags = [ - tag(name="test1", sequence="CGTACGTAGCCTAGC", id=0), - tag(name="test2", sequence="CGTAGCTCG", id=1), - ] - + pytest.debug = False pytest.barcode_slice = slice(0, 16) pytest.umi_slice = slice(16, 26) pytest.correct_reference_list = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) - pytest.legacy = False - pytest.debug = False - pytest.start_trim = 0 pytest.maximum_distance = 5 pytest.results = { "ACTGTTTTATTGGCCT": { @@ -79,18 +66,8 @@ def data(): 1: Counter({b"TAGCTTAGTA": 3, b"TAGCTTAGTC": 2, b"GCGATGCATA": 1}) }, } - pytest.corrected_results = { - "ACTGTTTTATTGGCCT": {0: Counter({b"CATTAGTGGT": 6})}, - "TTCATAAGGTAGGGAT": {1: Counter({b"TAGCTTAGTA": 5, b"GCGATGCATA": 1})}, - } - pytest.umis_per_cell = Counter({"ACTGTTTTATTGGCCT": 1, "TTCATAAGGTAGGGAT": 2}) - pytest.reads_per_cell = Counter({"ACTGTTTTATTGGCCT": 3, "TTCATAAGGTAGGGAT": 6}) - pytest.expected_cells = 2 - pytest.no_match = Counter() - pytest.collapsing_threshold = 1 - pytest.sliding_window = False - pytest.max_umis = 20000 + pytest.sliding_window = False pytest.sequence_pool = [] pytest.tags_tuple = preprocessing.check_tags( preprocessing.parse_tags_csv("tests/test_data/tags/pass/correct.csv"), 5 From 5708028e7519274d93484bf67e8267c31e6b0516 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 6 Jun 2021 15:49:18 +0200 Subject: [PATCH 47/77] changed reference to translate --- cite_seq_count/__main__.py | 37 +-- cite_seq_count/argsparser.py | 17 +- cite_seq_count/chemistry.py | 34 +-- cite_seq_count/io.py | 17 +- cite_seq_count/mapping.py | 11 + cite_seq_count/preprocessing.py | 109 ++++----- cite_seq_count/processing.py | 419 ++++++++++++++++---------------- 7 files changed, 316 insertions(+), 328 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index a3c5c89..a31d334 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -8,6 +8,7 @@ import time from cite_seq_count import preprocessing +from cite_seq_count import mapping from cite_seq_count import processing from cite_seq_count import chemistry from cite_seq_count import io @@ -40,7 +41,7 @@ def main(): assert os.access(args.temp_path, os.W_OK) # Get chemistry defs - (reference_dict, chemistry_def) = chemistry.setup_chemistry(args) + (translation_dict, chemistry_def) = chemistry.setup_chemistry(args) # Check if we have a filtered list provided @@ -49,8 +50,11 @@ def main(): ordered_tags, longest_tag_len = preprocessing.check_tags(ab_map, args.max_error) # Identify input file(s) - read1_paths, read2_paths = preprocessing.get_read_paths( - args.read1_path, args.read2_path + read1_paths, read2_paths = io.get_read_paths(args.read1_path, args.read2_path) + + # Check filtered input list + filtered_cells = preprocessing.get_filtered_list( + args=args, chemistry=chemistry_def, translation_dict=translation_dict ) # Checks before chunking. (n_reads, R2_min_length, maximum_distance) = preprocessing.pre_run_checks( @@ -78,17 +82,12 @@ def main(): maximum_distance=maximum_distance, ) # Map the data - ( - final_results, - umis_per_cell, - reads_per_cell, - merged_no_match, - ) = processing.map_data( + (final_results, umis_per_cell, reads_per_cell, merged_no_match,) = mapping.map_data( input_queue=input_queue, unmapped_id=len(ordered_tags), args=args ) # Check if 99% of the reads are unmapped. - processing.check_unmapped( + mapping.check_unmapped( no_match=merged_no_match, too_short=R1_too_short + R2_too_short, total_reads=total_reads, @@ -98,9 +97,15 @@ def main(): # Remove temp chunks for file_path in temp_files: os.remove(file_path) - cell_barcode_correction = preprocessing.determine_cell_correction_mode(args) + + filtered_cells = processing.check_filtered_cells( + filtered_cells=filtered_cells, + expected_cells=args.expected_cells, + umis_per_cell=umis_per_cell, + ) + # Correct cell barcodes - if cell_barcode_correction: + if args.bc_threshold > 0: ( final_results, umis_per_cell, @@ -108,10 +113,8 @@ def main(): ) = processing.run_cell_barcode_correction( final_results=final_results, umis_per_cell=umis_per_cell, - reads_per_cell=reads_per_cell, - reference_dict=reference_dict, ordered_tags=ordered_tags, - cell_barcode_correction=cell_barcode_correction, + filtered_set=filtered_cells, args=args, ) else: @@ -131,7 +134,7 @@ def main(): ordered_tags=ordered_tags, data_type="read", outfolder=args.outfolder, - reference_dict=reference_dict, + translation_dict=translation_dict, ) # UMI correction @@ -187,7 +190,7 @@ def main(): ordered_tags=ordered_tags, data_type="umi", outfolder=args.outfolder, - reference_dict=reference_dict, + translation_dict=translation_dict, ) # Write unmapped sequences diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 65e6e27..39d6f51 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -110,7 +110,7 @@ def get_args(): type=str, required=False, default=False, - help=("Option replacing cell/UMI barcodes indexes and reference list."), + help=("Option replacing cell/UMI barcodes indexes and translation list."), ) if "--chemistry" not in sys.argv: barcodes.add_argument( @@ -175,6 +175,7 @@ def get_args(): ) cells_filtering.add_argument( "-fl", + "-wl", "--filtered_cells", dest="filtered_cells", type=str, @@ -184,19 +185,17 @@ def get_args(): if "--chemistry" not in sys.argv: barcodes.add_argument( - "-rl", - "--reference_list", - dest="reference_list", + "-tl", + "--translation_list", + dest="translation_list", required=False, type=str, default=False, help=( - "A csv file containning a reference list of all potential barcodes\n\n" + "A csv file containning a translation list of all potential barcodes\n\n" "\tExample:\n" - "reference\n" - "\tATGCTAGTGCTA\n\tGCTAGTCAGGAT\n\tCGACTGCTAACG\n\n" - "Or 10X-style:\n" - "\tATGCTAGTGCTA-1\n\tGCTAGTCAGGAT-1\n\tCGACTGCTAACG-1\n" + "whitelist,translation\n" + "\tAAACCCAAGAAACACT,AAACCCATCAAACACT\n\AAACCCAAGAAACCAT,AAACCCATCAAACCAT\n\AAACCCAAGAAACCCA,AAACCCATCAAACCCA\n\n" ), ) diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index f91ca9e..8acf502 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -29,7 +29,7 @@ class Chemistry: umi_barcode_start: int umi_barcode_end: int R2_trim_start: int - reference_list_path: str + translation_list_path: str DEFINITIONS_DB = pooch.create( @@ -89,19 +89,19 @@ def get_chemistry_definition(chemistry_short_name): """ chemistry_defs = fetch_definitions()[chemistry_short_name] - if chemistry_defs["reference_list"]["path"] not in DEFINITIONS_DB.registry: + if chemistry_defs["translation_list"]["path"] not in DEFINITIONS_DB.registry: path = pooch.retrieve( url=os.path.join( GLOBAL_LINK_GITHUB, "chemistries", - chemistry_defs["reference_list"]["path"], + chemistry_defs["translation_list"]["path"], ), known_hash=None, - fname=chemistry_defs["reference_list"]["path"], + fname=chemistry_defs["translation_list"]["path"], path=DEFINITIONS_DB.abspath, ) else: - path = DEFINITIONS_DB.registry[chemistry_defs["reference_list"]["path"]] + path = DEFINITIONS_DB.registry[chemistry_defs["translation_list"]["path"]] chemistry_def = Chemistry( name=chemistry_short_name, cell_barcode_start=chemistry_defs["barcode_structure_indexes"]["cell_barcode"][ @@ -117,7 +117,7 @@ def get_chemistry_definition(chemistry_short_name): "R1" ]["stop"], R2_trim_start=chemistry_defs["sequence_structure_indexes"]["R2"]["start"] - 1, - reference_list_path=path, + translation_list_path=path, ) return chemistry_def @@ -130,7 +130,7 @@ def create_chemistry_definition(args): umi_barcode_start=args.umi_first, umi_barcode_end=args.umi_last, R2_trim_start=args.start_trim, - reference_list_path=args.reference_list, + translation_list_path=args.translation_list, ) return chemistry_def @@ -138,22 +138,22 @@ def create_chemistry_definition(args): def setup_chemistry(args): if args.chemistry: chemistry_def = get_chemistry_definition(args.chemistry) - reference_dict = preprocessing.parse_cell_list_csv( - filename=chemistry_def.reference_list_path, + translation_dict = preprocessing.parse_cell_list_csv( + filename=chemistry_def.translation_list_path, barcode_length=chemistry_def.cell_barcode_end - chemistry_def.cell_barcode_start + 1, - file_type="reference", + file_type="translation", ) else: chemistry_def = create_chemistry_definition(args) - if args.reference_list: - print("Loading reference_list") - reference_dict = preprocessing.parse_cell_list_csv( - filename=args.reference_list, + if args.translation_list: + print("Loading translation_list") + translation_dict = preprocessing.parse_cell_list_csv( + filename=args.translation_list, barcode_length=args.cb_last - args.cb_first + 1, - file_type="reference", + file_type="translation", ) else: - reference_dict = False - return (reference_dict, chemistry_def) + translation_dict = False + return (translation_dict, chemistry_def) diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 7bc544a..a94ad57 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -100,7 +100,7 @@ def get_csv_reader_from_path(filename): def write_to_files( - sparse_matrix, filtered_cells, ordered_tags, data_type, outfolder, reference_dict + sparse_matrix, filtered_cells, ordered_tags, data_type, outfolder, translation_dict ): """Write the umi and read sparse matrices to file in gzipped mtx format. @@ -111,18 +111,19 @@ def write_to_files( data_type (string): A string definning if the data is umi or read based. outfolder (string): Path to the output folder. """ + original_barcode = list(translation_dict.keys()) + translated_barcode = list(translation_dict.values()) prefix = os.path.join(outfolder, data_type + "_count") os.makedirs(prefix, exist_ok=True) io.mmwrite(os.path.join(prefix, "matrix.mtx"), a=sparse_matrix, field="integer") with gzip.open(os.path.join(prefix, "barcodes.tsv.gz"), "wb") as barcode_file: for barcode in filtered_cells: - if reference_dict: - if reference_dict[barcode] != 0: - barcode_file.write( - "{}\t{}\n".format(reference_dict[barcode], barcode).encode(), - ) - else: - barcode_file.write("{}\n".format(barcode).encode()) + if translation_dict: + barcode_file.write( + "{}\t{}\n".format( + original_barcode[translated_barcode.index(barcode)], barcode + ).encode(), + ) else: barcode_file.write("{}\n".format(barcode).encode()) with gzip.open(os.path.join(prefix, "features.tsv.gz"), "wb") as feature_file: diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py index f92d02f..910174b 100644 --- a/cite_seq_count/mapping.py +++ b/cite_seq_count/mapping.py @@ -200,3 +200,14 @@ def map_reads(mapping_input): sys.stdout.flush() return (results, no_match) + + +def check_unmapped(no_match, too_short, total_reads, start_trim): + """Check if the number of unmapped is higher than 99%""" + sum_unmapped = sum(no_match.values()) + too_short + if sum_unmapped / total_reads > float(0.99): + sys.exit( + """More than 99% of your data is unmapped.\nPlease check that your --start_trim {} parameter is correct and that your tags file is properly formatted""".format( + start_trim + ) + ) diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index d2d3458..97b8f1c 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -53,7 +53,7 @@ def parse_cell_list_csv(filename, barcode_length, file_type): `-1` at the end of each barcode. Args: - filename (str): reference_list barcode file. + filename (str): translation_list barcode file. barcode_length (int): Length of the expected barcodes. Returns: @@ -61,8 +61,8 @@ def parse_cell_list_csv(filename, barcode_length, file_type): """ STRIP_CHARS = '"0123456789- \t\n' - if file_type == "reference": - REQUIRED_HEADER = ["reference"] + if file_type == "translation": + REQUIRED_HEADER = ["translation"] elif file_type == "filtered": REQUIRED_HEADER = ["filtered_list"] @@ -78,40 +78,40 @@ def parse_cell_list_csv(filename, barcode_length, file_type): "The header is missing {}. Exiting".format(",".join(list(set_dif))) ) - reference_id = header.index(REQUIRED_HEADER[0]) - reference_dict = {} - if "translation" in header and REQUIRED_HEADER[0] == "reference": + translation_id = header.index(REQUIRED_HEADER[0]) + translation_dict = {} + if "translation" in header and REQUIRED_HEADER[0] == "translation": has_translation = True translation_id = header.index("translation") for row in csv_reader: - ref_barcode = row[reference_id].strip(STRIP_CHARS) + ref_barcode = row[translation_id].strip(STRIP_CHARS) tra_barcode = row[translation_id].strip(STRIP_CHARS) if ( len(ref_barcode) == barcode_length and len(tra_barcode) == barcode_length ): - reference_dict[ref_barcode] = tra_barcode + translation_dict[ref_barcode] = tra_barcode else: for row in csv_reader: - ref_barcode = row[reference_id].strip(STRIP_CHARS) + ref_barcode = row[translation_id].strip(STRIP_CHARS) if len(ref_barcode) == barcode_length: - reference_dict[ref_barcode] = 0 + translation_dict[ref_barcode] = 0 - for cell_barcode in reference_dict.keys(): + for cell_barcode in translation_dict.keys(): if not cell_pattern.match(cell_barcode): sys.exit( "This barcode {} is not only composed of ATGC bases.".format( cell_barcode ) ) - if len(reference_dict) == 0: - sys.exit("reference_dict is empty.") + if len(translation_dict) == 0: + sys.exit("translation_dict is empty.") if has_translation: print( - "Your reference list provides a translation name. This will be the default for the count matrices." + "Your translation list provides a translation name. This will be the default for the count matrices." ) - return reference_dict + return translation_dict def parse_tags_csv(filename): @@ -259,11 +259,11 @@ def get_read_length(filename): return read_length -def translate_barcodes(cell_set, reference_dict): - """Translate a list of barcode using a mapping reference +def translate_barcodes(cell_set, translation_dict): + """Translate a list of barcode using a mapping translation Args: cell_set (set): A set of barcodes - reference_dict (dict): A dict providing a simple key value translation + translation_dict (dict): A dict providing a simple key value translation Returns: translated_barcodes (set): A set of translated barcodes @@ -271,53 +271,10 @@ def translate_barcodes(cell_set, reference_dict): translated_barcodes = set() for cell in cell_set: - translate_barcodes.add(reference_dict[cell]) + translate_barcodes.add(translation_dict[cell]) return translated_barcodes -def get_filtered_list(args, chemistry, reference_dict, reads_per_cell): - """ - Determines what mode to use for cell barcode correction. - Args: - args(argparse): All arguments - - Returns: - str: type of correction - """ - if args.bc_threshold != 0: - # Are we provided with a filtered list? - if args.filtered_list: - filtered_set = parse_filtered_list_csv( - args.filtered_list, - (chemistry.cell_barcode_stop - chemistry.cell_barcode_start), - ) - # Do we need to translate the list? - if args.reference_dict: - # get the translation - filtered_set = translate_barcodes( - cell_set=filtered_set, reference_dict=reference_dict - ) - return filtered_set - # We try and rely on the top number of cells now - else: - print("Looking for a reference list") - _, true_to_false = whitelist_methods.getCellWhitelist( - knee_method="density", - cell_barcode_counts=reads_per_cell, - expect_cells=args.expected_cells, - cell_number=args.expected_cells, - error_correct_threshold=args.bc_threshold, - plotfile_prefix=False, - ) - if true_to_false is None: - print( - "Failed to find a good reference list. Will not correct cell barcodes" - ) - return False - else: - return False - - def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last): """Check Read1 length against CELL and UMI barcodes length. @@ -401,3 +358,31 @@ def pre_run_checks(read1_paths, chemistry_def, longest_tag_len, args): R2_min_length = longest_tag_len maximum_distance = args.max_error return n_reads, R2_min_length, maximum_distance + + +def get_filtered_list(args, chemistry, translation_dict): + """ + Determines what mode to use for cell barcode correction. + Args: + args(argparse): All arguments + + Returns: + set if we have a filtered list + None if we want correction and we have not a list + False if we deactivation filtering + """ + if args.filtered_cells: + filtered_set = parse_filtered_list_csv( + args.filtered_cells, + (chemistry.cell_barcode_stop - chemistry.cell_barcode_start), + ) + # Do we need to translate the list? + if args.translation_dict: + # get the translation + translated_set = translate_barcodes( + cell_set=filtered_set, translation_dict=translation_dict + ) + return translated_set + return filtered_set + else: + return None diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 753a811..8ab5050 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -17,8 +17,6 @@ from umi_tools import network import umi_tools.whitelist_methods as whitelist_methods -from cite_seq_count.preprocessing import parse_cell_list_csv - def merge_results(parallel_results, unmapped_id): """Merge chunked results from parallel processing. @@ -58,87 +56,7 @@ def merge_results(parallel_results, unmapped_id): return merged_results, umis_per_cell, reads_per_cell, merged_no_match -def check_unmapped(no_match, too_short, total_reads, start_trim): - """Check if the number of unmapped is higher than 99%""" - sum_unmapped = sum(no_match.values()) + too_short - if sum_unmapped / total_reads > float(0.99): - sys.exit( - """More than 99% of your data is unmapped.\nPlease check that your --start_trim {} parameter is correct and that your tags file is properly formatted""".format( - start_trim - ) - ) - - -def correct_umis_in_cells(umi_correction_input): - """ - Corrects umi barcodes within same cell/tag groups. - - Args: - final_results (dict): Dict of dict of Counters with mapping results. - collapsing_threshold (int): Max distance between umis. - filtered_cells (set): Set of cells to go through. - max_umis (int): Maximum UMIs to consider for one cluster. - - Returns: - final_results (dict): Same as input but with corrected umis. - corrected_umis (int): How many umis have been corrected. - clustered_umi_count_cells (set): Set of uncorrected cells. - """ - - (final_results, collapsing_threshold, max_umis, unmapped_id) = umi_correction_input - print( - "Started umi correction in child process {} working on {} cells".format( - os.getpid(), len(final_results) - ) - ) - corrected_umis = 0 - clustered_cells = set() - cells = final_results.keys() - for cell_barcode in cells: - for TAG in final_results[cell_barcode]: - if TAG == unmapped_id: - final_results[cell_barcode].pop(unmapped_id) - - n_umis = len(final_results[cell_barcode][TAG]) - if n_umis > 1 and n_umis <= max_umis: - umi_clusters = network.UMIClusterer() - UMIclusters = umi_clusters( - final_results[cell_barcode][TAG], collapsing_threshold - ) - (new_res, temp_corrected_umis) = update_umi_counts( - UMIclusters, final_results[cell_barcode][TAG] - ) - final_results[cell_barcode][TAG] = new_res - corrected_umis += temp_corrected_umis - elif n_umis > max_umis: - clustered_cells.add(cell_barcode) - print("Finished correcting umis in child {}".format(os.getpid())) - return (final_results, corrected_umis, clustered_cells) - - -def update_umi_counts(UMIclusters, cell_tag_counts): - """ - Update a dict object with umis corrected. - - Args: - UMIclusters (list): List of lists with corrected umis - cell_tag_counts (Counter): Counter of umis - - Returns: - cell_tag_counts (Counter): Updated Counter of umis - temp_corrected_umis (int): Number of corrected umis - """ - temp_corrected_umis = 0 - for ( - umi_cluster - ) in UMIclusters: # This is a list with the first element the dominant barcode - if len(umi_cluster) > 1: # This means we got a correction - major_umi = umi_cluster[0] - for minor_umi in umi_cluster[1:]: - temp_corrected_umis += 1 - temp = cell_tag_counts.pop(minor_umi) - cell_tag_counts[major_umi] += temp - return (cell_tag_counts, temp_corrected_umis) +# Unit Barcode correction def collapse_cells(true_to_false, umis_per_cell, final_results, ab_map): @@ -146,7 +64,7 @@ def collapse_cells(true_to_false, umis_per_cell, final_results, ab_map): Collapses cell barcodes based on the mapping true_to_false Args: - true_to_false (dict): Mapping between the reference and the "mutated" barcodes. + true_to_false (dict): Mapping between the translation and the "mutated" barcodes. umis_per_cell (Counter): Counter of number of umis per cell. final_results (dict): Dict of dict of Counters with mapping results. ab_map (dict): Dict of the TAGS. @@ -183,7 +101,7 @@ def collapse_cells(true_to_false, umis_per_cell, final_results, ab_map): return (umis_per_cell, final_results, corrected_barcodes) -def correct_cells_no_reference_list( +def correct_cells_no_translation_list( final_results, reads_per_cell, umis_per_cell, @@ -192,7 +110,7 @@ def correct_cells_no_reference_list( ab_map, ): """ - Corrects cell barcodes. + Corrects cell barcodes without a translation. Args: final_results (dict): Dict of dict of Counters with mapping results. @@ -206,7 +124,7 @@ def correct_cells_no_reference_list( umis_per_cell (Counter): Counter of umis per cell after cell barcode correction corrected_umis (int): How many umis have been corrected. """ - print("Looking for a reference list") + print("Looking for a translation list") _, true_to_false = whitelist_methods.getCellWhitelist( knee_method="density", cell_barcode_counts=reads_per_cell, @@ -216,7 +134,7 @@ def correct_cells_no_reference_list( plotfile_prefix=False, ) if true_to_false is None: - print("Failed to find a good reference list. Will not correct cell barcodes") + print("Failed to find a good translation list. Will not correct cell barcodes") corrected_barcodes = 0 return (final_results, umis_per_cell, corrected_barcodes) (umis_per_cell, final_results, corrected_barcodes) = collapse_cells( @@ -228,16 +146,16 @@ def correct_cells_no_reference_list( return (final_results, umis_per_cell, corrected_barcodes) -def correct_cells_reference_list( - final_results, umis_per_cell, reference_list, collapsing_threshold, ab_map +def correct_cells_translation_list( + final_results, umis_per_cell, translation_list, collapsing_threshold, ab_map ): """ - Corrects cell barcodes based on a given reference_list. + Corrects cell barcodes based on a given translation_list. Args: final_results (dict): Dict of dict of Counters with mapping results. umis_per_cell (Counter): Counter of UMIs per cell. - reference_list (set): The reference_list reference given by the user. + translation_list (set): The translation_list translation given by the user. collapsing_threshold (int): Max distance between umis. ab_map (OrederedDict): Tags in an ordered dict. @@ -247,18 +165,18 @@ def correct_cells_reference_list( umis_per_cell (Counter): Updated UMI counts after correction. corrected_barcodes (int): How many umis have been corrected. """ - print("Generating barcode tree from reference list") + print("Generating barcode tree from translation list") # pylint: disable=no-member - barcode_tree = pybktree.BKTree(Levenshtein.hamming, reference_list) + barcode_tree = pybktree.BKTree(Levenshtein.hamming, translation_list) barcodes = set(umis_per_cell) - print("Selecting reference candidates") + print("Selecting translation candidates") print("Processing {:,} cell barcodes".format(len(barcodes))) # Run with one process true_to_false = find_true_to_false_map( barcode_tree=barcode_tree, cell_barcodes=barcodes, - reference_list=reference_list, + translation_list=translation_list, collapsing_threshold=collapsing_threshold, ) print("Collapsing wrong barcodes with original barcodes") @@ -269,7 +187,7 @@ def correct_cells_reference_list( def find_true_to_false_map( - barcode_tree, cell_barcodes, reference_list, collapsing_threshold + barcode_tree, cell_barcodes, translation_list, collapsing_threshold ): """ Creates a mapping between "fake" cell barcodes and their original true barcode. @@ -277,7 +195,7 @@ def find_true_to_false_map( Args: barcode_tree (BKTree): BKTree of all original cell barcodes. cell_barcodes (List): Cell barcodes to go through. - reference_list (dict): Dict of the reference_list, the "true" cell barcodes. + translation_list (dict): Dict of the translation_list, the "true" cell barcodes. collasping_threshold (int): How many mistakes to correct. Return: @@ -285,10 +203,10 @@ def find_true_to_false_map( """ true_to_false = defaultdict(list) for cell_barcode in cell_barcodes: - if cell_barcode in reference_list: - # if the barcode is already reference_listed, no need to add + if cell_barcode in translation_list: + # if the barcode is already translation_listed, no need to add continue - # get all members of reference_list that are at distance of collapsing_threshold + # get all members of translation_list that are at distance of collapsing_threshold candidates = [ white_cell for d, white_cell in barcode_tree.find(cell_barcode, collapsing_threshold) @@ -298,50 +216,175 @@ def find_true_to_false_map( white_cell_str = candidates[0] true_to_false[white_cell_str].append(cell_barcode) else: - # the cell doesnt match to any reference_listed barcode, + # the cell doesnt match to any translation_listed barcode, # hence we have to drop it # (as it cannot be asscociated with any frequent barcode) continue return true_to_false -def generate_sparse_matrices( - final_results, ordered_tags, filtered_cells, umi_counts=False +def run_cell_barcode_correction( + final_results, umis_per_cell, ordered_tags, filtered_set, args, ): - """ - Create two sparse matrices with umi and read counts. + if args.expected_cells > len(filtered_set): + print( + "Number of expected cells, {}, is higher " + "than number of cells found {}.\nNot performing " + "cell barcode correction" + "".format(args.expected_cells, len(umis_per_cell)) + ) + bcs_corrected = 0 + return final_results, umis_per_cell, bcs_corrected - Args: - final_results (dict): Results in a dict of dicts of Counters. - ordered_tags (list): Ordered tags in a list of tuples. + elif type(filtered_set) == set: + (final_results, umis_per_cell, bcs_corrected,) = correct_cells_translation_list( + final_results=final_results, + umis_per_cell=umis_per_cell, + translation_list=filtered_set, + collapsing_threshold=args.bc_threshold, + ab_map=ordered_tags, + ) + for missing_cell in filtered_set: + if missing_cell in final_results: + continue + else: + final_results[missing_cell] = dict() + for TAG in ordered_tags: + final_results[missing_cell][TAG.safe_name] = Counter() + return final_results, umis_per_cell, bcs_corrected - Returns: - results_matrix (scipy.sparse.dok_matrix): UMI counts + +def check_filtered_cells(filtered_cells, expected_cells, umis_per_cell): + if filtered_cells is None: + top_cells_tuple = umis_per_cell.most_common(expected_cells) + # Select top cells based on total umis per cell + filtered_cells = set([pair[0] for pair in top_cells_tuple]) + return filtered_cells + + +# def choose_filtered_cells( +# given_filtered_cells, +# expected_cells, +# chemistry_def, +# final_results, +# ordered_tags, +# umis_per_cell, +# translation_dict, +# ): +# """ +# Returns a list of barcodes that will be in the output +# and helps decide based on the inputs. + +# Args: +# given_filtered_cells (bool or str): False if not given, else string +# expected_cells (int): Number of expected cells +# chemistry_def (Chemistry): Defines the details of the chemistry +# final_results (dict): All results +# ordered_tags (named_tuple): Holds tags info +# umis_per_cell (Counter): Holds number of UMIs per barcode + +# Returns: +# set: filtered cell set +# """ +# # If given, use filtered_list for top cells +# if given_filtered_cells: +# filtered_cells = set( +# parse_cell_list_csv( +# filename=given_filtered_cells, +# barcode_length=chemistry_def.cell_barcode_end +# - chemistry_def.cell_barcode_start +# + 1, +# file_type="filtered", +# ).keys() +# ) +# # Add potential missing cell barcodes. +# for missing_cell in filtered_cells: +# if missing_cell in final_results: +# continue +# else: +# final_results[missing_cell] = dict() +# for TAG in ordered_tags: +# final_results[missing_cell][TAG.safe_name] = Counter() +# filtered_cells.add(missing_cell) +# else: +# top_cells_tuple = umis_per_cell.most_common(expected_cells) +# # Select top cells based on total umis per cell +# filtered_cells = [pair[0] for pair in top_cells_tuple] + + +# UMI correction section +def correct_umis_in_cells(umi_correction_input): + """ + Corrects umi barcodes within same cell/tag groups. + + Args: + final_results (dict): Dict of dict of Counters with mapping results. + collapsing_threshold (int): Max distance between umis. + filtered_cells (set): Set of cells to go through. + max_umis (int): Maximum UMIs to consider for one cluster. + + Returns: + final_results (dict): Same as input but with corrected umis. + corrected_umis (int): How many umis have been corrected. + clustered_umi_count_cells (set): Set of uncorrected cells. """ - unmapped_id = len(ordered_tags) - if umi_counts: - n_features = len(ordered_tags) - else: - n_features = len(ordered_tags) + 1 - results_matrix = sparse.dok_matrix((n_features, len(filtered_cells)), dtype=int32) - # print(ordered_tags) - for i, cell_barcode in enumerate(filtered_cells): - if cell_barcode not in final_results.keys(): - continue - for TAG_id in final_results[cell_barcode]: - # if TAG_id in final_results[cell_barcode]: - if umi_counts: - if TAG_id == unmapped_id: - continue - results_matrix[TAG_id, i] = len(final_results[cell_barcode][TAG_id]) - else: - results_matrix[TAG_id, i] = sum( - final_results[cell_barcode][TAG_id].values() + (final_results, collapsing_threshold, max_umis, unmapped_id) = umi_correction_input + print( + "Started umi correction in child process {} working on {} cells".format( + os.getpid(), len(final_results) + ) + ) + corrected_umis = 0 + clustered_cells = set() + cells = final_results.keys() + for cell_barcode in cells: + for TAG in final_results[cell_barcode]: + if TAG == unmapped_id: + final_results[cell_barcode].pop(unmapped_id) + + n_umis = len(final_results[cell_barcode][TAG]) + if n_umis > 1 and n_umis <= max_umis: + umi_clusters = network.UMIClusterer() + UMIclusters = umi_clusters( + final_results[cell_barcode][TAG], collapsing_threshold ) - return results_matrix + (new_res, temp_corrected_umis) = update_umi_counts( + UMIclusters, final_results[cell_barcode][TAG] + ) + final_results[cell_barcode][TAG] = new_res + corrected_umis += temp_corrected_umis + elif n_umis > max_umis: + clustered_cells.add(cell_barcode) + print("Finished correcting umis in child {}".format(os.getpid())) + return (final_results, corrected_umis, clustered_cells) + + +def update_umi_counts(UMIclusters, cell_tag_counts): + """ + Update a dict object with umis corrected. + + Args: + UMIclusters (list): List of lists with corrected umis + cell_tag_counts (Counter): Counter of umis + + Returns: + cell_tag_counts (Counter): Updated Counter of umis + temp_corrected_umis (int): Number of corrected umis + """ + temp_corrected_umis = 0 + for ( + umi_cluster + ) in UMIclusters: # This is a list with the first element the dominant barcode + if len(umi_cluster) > 1: # This means we got a correction + major_umi = umi_cluster[0] + for minor_umi in umi_cluster[1:]: + temp_corrected_umis += 1 + temp = cell_tag_counts.pop(minor_umi) + cell_tag_counts[major_umi] += temp + return (cell_tag_counts, temp_corrected_umis) def run_umi_correction(final_results, filtered_cells, unmapped_id, args): @@ -414,95 +457,41 @@ def run_umi_correction(final_results, filtered_cells, unmapped_id, args): return final_results, umis_corrected, clustered_cells -def run_cell_barcode_correction( - final_results, - umis_per_cell, - reads_per_cell, - reference_dict, - ordered_tags, - cell_barcode_correction, - args, -): - if cell_barcode_correction == "top": - if len(umis_per_cell) <= args.expected_cells: - print( - "Number of expected cells, {}, is higher " - "than number of cells found {}.\nNot performing " - "cell barcode correction" - "".format(args.expected_cells, len(umis_per_cell)) - ) - bcs_corrected = 0 - else: - print("Reference list not given") - ( - final_results, - umis_per_cell, - bcs_corrected, - ) = correct_cells_no_reference_list( - final_results=final_results, - reads_per_cell=reads_per_cell, - umis_per_cell=umis_per_cell, - expected_cells=args.expected_cells, - collapsing_threshold=args.bc_threshold, - ab_map=ordered_tags, - ) - elif cell_barcode_correction == "list": - (final_results, umis_per_cell, bcs_corrected,) = correct_cells_reference_list( - final_results=final_results, - umis_per_cell=umis_per_cell, - reference_list=set(reference_dict.keys()), - collapsing_threshold=args.bc_threshold, - ab_map=ordered_tags, - ) - return final_results, umis_per_cell, bcs_corrected - - -def choose_filtered_cells( - given_filtered_cells, - expected_cells, - chemistry_def, - final_results, - ordered_tags, - umis_per_cell, - translation_dict, +def generate_sparse_matrices( + final_results, ordered_tags, filtered_cells, umi_counts=False ): """ - Returns a list of barcodes that will be in the output - and helps decide based on the inputs. + Create two sparse matrices with umi and read counts. Args: - given_filtered_cells (bool or str): False if not given, else string - expected_cells (int): Number of expected cells - chemistry_def (Chemistry): Defines the details of the chemistry - final_results (dict): All results - ordered_tags (named_tuple): Holds tags info - umis_per_cell (Counter): Holds number of UMIs per barcode - + final_results (dict): Results in a dict of dicts of Counters. + ordered_tags (list): Ordered tags in a list of tuples. + Returns: - set: filtered cell set + results_matrix (scipy.sparse.dok_matrix): UMI counts + + """ - # If given, use filtered_list for top cells - if given_filtered_cells: - filtered_cells = set( - parse_cell_list_csv( - filename=given_filtered_cells, - barcode_length=chemistry_def.cell_barcode_end - - chemistry_def.cell_barcode_start - + 1, - file_type="filtered", - ).keys() - ) - # Add potential missing cell barcodes. - for missing_cell in filtered_cells: - if missing_cell in final_results: - continue - else: - final_results[missing_cell] = dict() - for TAG in ordered_tags: - final_results[missing_cell][TAG.safe_name] = Counter() - filtered_cells.add(missing_cell) + unmapped_id = len(ordered_tags) + if umi_counts: + n_features = len(ordered_tags) else: - top_cells_tuple = umis_per_cell.most_common(expected_cells) - # Select top cells based on total umis per cell - filtered_cells = [pair[0] for pair in top_cells_tuple] + n_features = len(ordered_tags) + 1 + results_matrix = sparse.dok_matrix((n_features, len(filtered_cells)), dtype=int32) + # print(ordered_tags) + + for i, cell_barcode in enumerate(filtered_cells): + if cell_barcode not in final_results.keys(): + continue + for TAG_id in final_results[cell_barcode]: + # if TAG_id in final_results[cell_barcode]: + if umi_counts: + if TAG_id == unmapped_id: + continue + results_matrix[TAG_id, i] = len(final_results[cell_barcode][TAG_id]) + else: + results_matrix[TAG_id, i] = sum( + final_results[cell_barcode][TAG_id].values() + ) + return results_matrix From 5c57d12ae59d7304110175535313045afbebda9a Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Fri, 9 Jul 2021 18:36:40 +0200 Subject: [PATCH 48/77] refactored the filtering and translation --- cite_seq_count/__main__.py | 4 ++-- cite_seq_count/chemistry.py | 2 +- cite_seq_count/io.py | 5 ++-- cite_seq_count/preprocessing.py | 19 ++++++++------- cite_seq_count/processing.py | 42 ++++++++++++++++----------------- 5 files changed, 37 insertions(+), 35 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index a31d334..c147f77 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -53,6 +53,7 @@ def main(): read1_paths, read2_paths = io.get_read_paths(args.read1_path, args.read2_path) # Check filtered input list + # If a translation is given, will return the translated version filtered_cells = preprocessing.get_filtered_list( args=args, chemistry=chemistry_def, translation_dict=translation_dict ) @@ -97,7 +98,7 @@ def main(): # Remove temp chunks for file_path in temp_files: os.remove(file_path) - + # Check that we have a filtered cell list to work on filtered_cells = processing.check_filtered_cells( filtered_cells=filtered_cells, expected_cells=args.expected_cells, @@ -138,7 +139,6 @@ def main(): ) # UMI correction - if args.umi_threshold != 0: # Correct UMIS ( diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index 8acf502..cccacb9 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -58,7 +58,7 @@ def fetch_definitions(): return json_data -def list_chemistries(chemistry_defs): +def list_chemistries(all_chemistry_defs): """ List all the available chemistries in the database Args: diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index a94ad57..436521b 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -111,8 +111,9 @@ def write_to_files( data_type (string): A string definning if the data is umi or read based. outfolder (string): Path to the output folder. """ - original_barcode = list(translation_dict.keys()) - translated_barcode = list(translation_dict.values()) + if translation_dict: + original_barcode = list(translation_dict.keys()) + translated_barcode = list(translation_dict.values()) prefix = os.path.join(outfolder, data_type + "_count") os.makedirs(prefix, exist_ok=True) io.mmwrite(os.path.join(prefix, "matrix.mtx"), a=sparse_matrix, field="integer") diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 97b8f1c..f987588 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -62,7 +62,7 @@ def parse_cell_list_csv(filename, barcode_length, file_type): """ STRIP_CHARS = '"0123456789- \t\n' if file_type == "translation": - REQUIRED_HEADER = ["translation"] + REQUIRED_HEADER = ["reference", "translation"] elif file_type == "filtered": REQUIRED_HEADER = ["filtered_list"] @@ -78,23 +78,24 @@ def parse_cell_list_csv(filename, barcode_length, file_type): "The header is missing {}. Exiting".format(",".join(list(set_dif))) ) - translation_id = header.index(REQUIRED_HEADER[0]) + # translation_id = header.index(REQUIRED_HEADER[0]) translation_dict = {} - if "translation" in header and REQUIRED_HEADER[0] == "translation": + if "translation" in header: has_translation = True translation_id = header.index("translation") + reference_id = header.index("reference") for row in csv_reader: - ref_barcode = row[translation_id].strip(STRIP_CHARS) + ref_barcode = row[reference_id].strip(STRIP_CHARS) tra_barcode = row[translation_id].strip(STRIP_CHARS) if ( len(ref_barcode) == barcode_length and len(tra_barcode) == barcode_length ): - translation_dict[ref_barcode] = tra_barcode + translation_dict[tra_barcode] = ref_barcode else: for row in csv_reader: - ref_barcode = row[translation_id].strip(STRIP_CHARS) + ref_barcode = row[0].strip(STRIP_CHARS) if len(ref_barcode) == barcode_length: translation_dict[ref_barcode] = 0 @@ -271,7 +272,7 @@ def translate_barcodes(cell_set, translation_dict): translated_barcodes = set() for cell in cell_set: - translate_barcodes.add(translation_dict[cell]) + translated_barcodes.add(translation_dict[cell]) return translated_barcodes @@ -374,10 +375,10 @@ def get_filtered_list(args, chemistry, translation_dict): if args.filtered_cells: filtered_set = parse_filtered_list_csv( args.filtered_cells, - (chemistry.cell_barcode_stop - chemistry.cell_barcode_start), + (chemistry.cell_barcode_end - chemistry.cell_barcode_start), ) # Do we need to translate the list? - if args.translation_dict: + if args.translation_list: # get the translation translated_set = translate_barcodes( cell_set=filtered_set, translation_dict=translation_dict diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 8ab5050..143f00f 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -124,7 +124,7 @@ def correct_cells_no_translation_list( umis_per_cell (Counter): Counter of umis per cell after cell barcode correction corrected_umis (int): How many umis have been corrected. """ - print("Looking for a translation list") + print("Looking for a reference list") _, true_to_false = whitelist_methods.getCellWhitelist( knee_method="density", cell_barcode_counts=reads_per_cell, @@ -134,7 +134,7 @@ def correct_cells_no_translation_list( plotfile_prefix=False, ) if true_to_false is None: - print("Failed to find a good translation list. Will not correct cell barcodes") + print("Failed to find a good reference list. Will not correct cell barcodes") corrected_barcodes = 0 return (final_results, umis_per_cell, corrected_barcodes) (umis_per_cell, final_results, corrected_barcodes) = collapse_cells( @@ -146,8 +146,8 @@ def correct_cells_no_translation_list( return (final_results, umis_per_cell, corrected_barcodes) -def correct_cells_translation_list( - final_results, umis_per_cell, translation_list, collapsing_threshold, ab_map +def correct_cells_filtered_set( + final_results, umis_per_cell, filtered_set, collapsing_threshold, ab_map ): """ Corrects cell barcodes based on a given translation_list. @@ -165,18 +165,18 @@ def correct_cells_translation_list( umis_per_cell (Counter): Updated UMI counts after correction. corrected_barcodes (int): How many umis have been corrected. """ - print("Generating barcode tree from translation list") + print("Generating barcode tree from reference list") # pylint: disable=no-member - barcode_tree = pybktree.BKTree(Levenshtein.hamming, translation_list) + barcode_tree = pybktree.BKTree(Levenshtein.hamming, filtered_set) barcodes = set(umis_per_cell) - print("Selecting translation candidates") + print("Selecting reference candidates") print("Processing {:,} cell barcodes".format(len(barcodes))) # Run with one process true_to_false = find_true_to_false_map( barcode_tree=barcode_tree, cell_barcodes=barcodes, - translation_list=translation_list, + filtered_set=filtered_set, collapsing_threshold=collapsing_threshold, ) print("Collapsing wrong barcodes with original barcodes") @@ -187,7 +187,7 @@ def correct_cells_translation_list( def find_true_to_false_map( - barcode_tree, cell_barcodes, translation_list, collapsing_threshold + barcode_tree, cell_barcodes, filtered_set, collapsing_threshold ): """ Creates a mapping between "fake" cell barcodes and their original true barcode. @@ -195,7 +195,7 @@ def find_true_to_false_map( Args: barcode_tree (BKTree): BKTree of all original cell barcodes. cell_barcodes (List): Cell barcodes to go through. - translation_list (dict): Dict of the translation_list, the "true" cell barcodes. + filtered_set (dict): Dict of the filtered_set, the "true" cell barcodes. collasping_threshold (int): How many mistakes to correct. Return: @@ -203,10 +203,10 @@ def find_true_to_false_map( """ true_to_false = defaultdict(list) for cell_barcode in cell_barcodes: - if cell_barcode in translation_list: - # if the barcode is already translation_listed, no need to add + if cell_barcode in filtered_set: + # if the barcode is already filtered_set, no need to add continue - # get all members of translation_list that are at distance of collapsing_threshold + # get all members of filtered_set that are at distance of collapsing_threshold candidates = [ white_cell for d, white_cell in barcode_tree.find(cell_barcode, collapsing_threshold) @@ -216,7 +216,7 @@ def find_true_to_false_map( white_cell_str = candidates[0] true_to_false[white_cell_str].append(cell_barcode) else: - # the cell doesnt match to any translation_listed barcode, + # the cell doesnt match to any filtered_set barcode, # hence we have to drop it # (as it cannot be asscociated with any frequent barcode) continue @@ -237,10 +237,10 @@ def run_cell_barcode_correction( return final_results, umis_per_cell, bcs_corrected elif type(filtered_set) == set: - (final_results, umis_per_cell, bcs_corrected,) = correct_cells_translation_list( + (final_results, umis_per_cell, bcs_corrected,) = correct_cells_filtered_set( final_results=final_results, umis_per_cell=umis_per_cell, - translation_list=filtered_set, + filtered_set=filtered_set, collapsing_threshold=args.bc_threshold, ab_map=ordered_tags, ) @@ -250,7 +250,7 @@ def run_cell_barcode_correction( else: final_results[missing_cell] = dict() for TAG in ordered_tags: - final_results[missing_cell][TAG.safe_name] = Counter() + final_results[missing_cell][TAG.id] = Counter() return final_results, umis_per_cell, bcs_corrected @@ -304,7 +304,7 @@ def check_filtered_cells(filtered_cells, expected_cells, umis_per_cell): # else: # final_results[missing_cell] = dict() # for TAG in ordered_tags: -# final_results[missing_cell][TAG.safe_name] = Counter() +# final_results[missing_cell][TAG.name] = Counter() # filtered_cells.add(missing_cell) # else: # top_cells_tuple = umis_per_cell.most_common(expected_cells) @@ -468,7 +468,7 @@ def generate_sparse_matrices( ordered_tags (list): Ordered tags in a list of tuples. Returns: - results_matrix (scipy.sparse.dok_matrix): UMI counts + results_matrix (scipy.sparse.dok_matrix): UMI or Read counts """ @@ -478,7 +478,6 @@ def generate_sparse_matrices( else: n_features = len(ordered_tags) + 1 results_matrix = sparse.dok_matrix((n_features, len(filtered_cells)), dtype=int32) - # print(ordered_tags) for i, cell_barcode in enumerate(filtered_cells): if cell_barcode not in final_results.keys(): @@ -488,7 +487,8 @@ def generate_sparse_matrices( if umi_counts: if TAG_id == unmapped_id: continue - results_matrix[TAG_id, i] = len(final_results[cell_barcode][TAG_id]) + else: + results_matrix[TAG_id, i] = len(final_results[cell_barcode][TAG_id]) else: results_matrix[TAG_id, i] = sum( final_results[cell_barcode][TAG_id].values() From 757a4a638e36d520c6a0a44e3f4e2726f900dae6 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sat, 10 Jul 2021 17:38:01 +0200 Subject: [PATCH 49/77] fixed translation reading --- cite_seq_count/chemistry.py | 2 -- cite_seq_count/preprocessing.py | 28 +++++++++------------------- tests/test_preprocessing.py | 4 ++-- 3 files changed, 11 insertions(+), 23 deletions(-) diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index cccacb9..171b56d 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -143,7 +143,6 @@ def setup_chemistry(args): barcode_length=chemistry_def.cell_barcode_end - chemistry_def.cell_barcode_start + 1, - file_type="translation", ) else: chemistry_def = create_chemistry_definition(args) @@ -152,7 +151,6 @@ def setup_chemistry(args): translation_dict = preprocessing.parse_cell_list_csv( filename=args.translation_list, barcode_length=args.cb_last - args.cb_first + 1, - file_type="translation", ) else: translation_dict = False diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index f987588..101dde4 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -46,7 +46,7 @@ def parse_filtered_list_csv(filename, barcode_length): return out_set -def parse_cell_list_csv(filename, barcode_length, file_type): +def parse_cell_list_csv(filename, barcode_length): """Reads white-listed barcodes from a CSV file. The function accepts plain barcodes or even 10X style barcodes with the @@ -61,13 +61,7 @@ def parse_cell_list_csv(filename, barcode_length, file_type): """ STRIP_CHARS = '"0123456789- \t\n' - if file_type == "translation": - REQUIRED_HEADER = ["reference", "translation"] - elif file_type == "filtered": - REQUIRED_HEADER = ["filtered_list"] - - has_translation = False - # OPTIONAL_HEADER = ["translation", "filtered_list"] + REQUIRED_HEADER = ["reference", "translation"] cell_pattern = regex.compile(r"^[ATGC]{{{}}}".format(barcode_length)) csv_reader = get_csv_reader_from_path(filename=filename) @@ -81,7 +75,6 @@ def parse_cell_list_csv(filename, barcode_length, file_type): # translation_id = header.index(REQUIRED_HEADER[0]) translation_dict = {} if "translation" in header: - has_translation = True translation_id = header.index("translation") reference_id = header.index("reference") @@ -94,10 +87,7 @@ def parse_cell_list_csv(filename, barcode_length, file_type): ): translation_dict[tra_barcode] = ref_barcode else: - for row in csv_reader: - ref_barcode = row[0].strip(STRIP_CHARS) - if len(ref_barcode) == barcode_length: - translation_dict[ref_barcode] = 0 + sys.exit('The header is missing a the "{}" keyword'.format("translation")) for cell_barcode in translation_dict.keys(): if not cell_pattern.match(cell_barcode): @@ -108,10 +98,9 @@ def parse_cell_list_csv(filename, barcode_length, file_type): ) if len(translation_dict) == 0: sys.exit("translation_dict is empty.") - if has_translation: - print( - "Your translation list provides a translation name. This will be the default for the count matrices." - ) + print( + "Your translation list provides a translation name. This will be the default for the count matrices." + ) return translation_dict @@ -271,8 +260,9 @@ def translate_barcodes(cell_set, translation_dict): """ translated_barcodes = set() - for cell in cell_set: - translated_barcodes.add(translation_dict[cell]) + for translated_barcode in translation_dict.keys(): + if translation_dict[translated_barcode] in cell_set: + translated_barcodes.add(translated_barcode) return translated_barcodes diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 3dd5b21..0c836bb 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -74,14 +74,14 @@ def test_filtered_list_parser(data): def test_parse_reference_list_csv(data): passing_files = glob.glob(pytest.passing_reference_list_csv) for file_path in passing_files: - assert preprocessing.parse_cell_list_csv(file_path, 16, "reference").keys() in ( + assert preprocessing.parse_cell_list_csv(file_path, 16).keys() in ( pytest.correct_reference_list, 1, ) with pytest.raises(SystemExit): failing_files = glob.glob(pytest.failing_reference_list_csv) for file_path in failing_files: - preprocessing.parse_cell_list_csv(file_path, 16, "reference") + preprocessing.parse_cell_list_csv(file_path, 16) @pytest.mark.dependency() From 1d6ebdc3ae73d7a37d880bf7c6db2489e2eaad0a Mon Sep 17 00:00:00 2001 From: Hoohm Date: Sun, 21 Nov 2021 13:49:57 +0100 Subject: [PATCH 50/77] added file checks for sequencing data --- cite_seq_count/argsparser.py | 2 +- cite_seq_count/io.py | 23 ++++-- cite_seq_count/mapping.py | 13 ++-- cite_seq_count/preprocessing.py | 22 +++--- cite_seq_count/processing.py | 19 ++--- cite_seq_count/secondsToText.py | 121 +++++++++++++++++--------------- setup.py | 1 + tests/test_preprocessing.py | 1 - 8 files changed, 115 insertions(+), 87 deletions(-) diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 39d6f51..091a896 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -2,7 +2,7 @@ import sys import tempfile -from argparse import ArgumentParser, ArgumentTypeError, RawTextHelpFormatter +from argparse import ArgumentParser, ArgumentTypeError, RawTextHelpFormatter, FileType # pylint: disable=no-name-in-module from multiprocess import cpu_count diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 436521b..b22dd87 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -23,7 +23,7 @@ def blocks(files, size=65536): https://stackoverflow.com/a/9631635/9178565 Args: - files (io.handler): A file handler + files (io.handler): A file handler size (int): Block size Returns: A generator @@ -77,25 +77,35 @@ def get_read_paths(read1_path, read2_path): "Unequal number of read1 ({}) and read2({}) files provided" "\n Exiting".format(len(_read1_path), len(_read2_path)) ) + all_files = _read1_path + _read2_path + for file_path in all_files: + if os.path.isfile(file_path): + if os.access(file_path, os.R_OK): + continue + else: + sys.exit("{} is not readable. Exiting".format(file_path)) + else: + sys.exit("{} does not exist. Exiting".format(file_path)) + return (_read1_path, _read2_path) -def get_csv_reader_from_path(filename): +def get_csv_reader_from_path(filename, sep="\t"): """ Returns a csv_reader object for a file weather it's a flat file or compressed. Args: filename: str - + Returns: csv_reader: The csv_reader for the file """ if filename.endswith(".gz"): f = gzip.open(filename, mode="rt") - csv_reader = csv.reader(f) + csv_reader = csv.reader(f, delimiter=sep) else: f = open(filename, encoding="UTF-8") - csv_reader = csv.reader(f) + csv_reader = csv.reader(f, delimiter=sep) return csv_reader @@ -143,7 +153,7 @@ def write_to_files( def write_dense(sparse_matrix, ordered_tags, columns, outfolder, filename): """ Writes a dense matrix in a csv format - + Args: sparse_matrix (dok_matrix): Results in a sparse matrix. index (list): List of TAGS @@ -417,4 +427,3 @@ def write_chunks_to_disk( R2_too_short, total_reads, ) - diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py index 910174b..d3ed584 100644 --- a/cite_seq_count/mapping.py +++ b/cite_seq_count/mapping.py @@ -23,7 +23,7 @@ def map_data(input_queue, unmapped_id, args): Args: input_queue (list): List of parameters to run in parallel args (argparse): List of arguments - + Returns: final_results (dict): final dictionnary with results umis_per_cell (Counter): Counter of UMIs per cell @@ -60,9 +60,12 @@ def map_data(input_queue, unmapped_id, args): print(error) print("Merging results") - (final_results, umis_per_cell, reads_per_cell, merged_no_match,) = merge_results( - parallel_results=parallel_results[0], unmapped_id=unmapped_id - ) + ( + final_results, + umis_per_cell, + reads_per_cell, + merged_no_match, + ) = merge_results(parallel_results=parallel_results[0], unmapped_id=unmapped_id) return final_results, umis_per_cell, reads_per_cell, merged_no_match @@ -104,7 +107,7 @@ def find_best_match_shift(TAG_seq, tags): Only works with exact match. Just checks if the string is in the sequence. If no matches found returns 'unmapped'. - + Args: TAG_seq (string): Sequence from R2 already start trimmed tags (dict): A dictionary with the TAGs as keys and TAG Names as values. diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 101dde4..ec5cf82 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -21,7 +21,7 @@ def parse_filtered_list_csv(filename, barcode_length): Args: filename(str): file path barcode_length(int): Barcode expected length - + Returns: set: A set of barcodes """ @@ -152,7 +152,7 @@ def parse_tags_csv(filename): def check_tags(tags, maximum_distance): """Evaluates the distance between the TAGs based on the `maximum distance` argument provided. - + The output will have the keys sorted by TAG length (longer first). This way, longer barcodes will be evaluated first. Adds unmapped category as well. @@ -175,7 +175,13 @@ def check_tags(tags, maximum_distance): safe_name = sanitize_name(tags[tag_seq]) # for index, tag_name in enumerate(ordered_tags): - tag_list.append(tag(name=safe_name, sequence=tag_seq, id=i,)) + tag_list.append( + tag( + name=safe_name, + sequence=tag_seq, + id=i, + ) + ) if len(tag_seq) > longest_tag_len: longest_tag_len = len(tag_seq) seq_list.append(tag_seq) @@ -216,7 +222,7 @@ def sanitize_name(string): Args: string(str): a string from a feature name - + Returns: str: modified string """ @@ -254,7 +260,7 @@ def translate_barcodes(cell_set, translation_dict): Args: cell_set (set): A set of barcodes translation_dict (dict): A dict providing a simple key value translation - + Returns: translated_barcodes (set): A set of translated barcodes """ @@ -297,14 +303,14 @@ def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last) def pre_run_checks(read1_paths, chemistry_def, longest_tag_len, args): - """ Checks that the chemistry is properly set and defines how many reads to process + """Checks that the chemistry is properly set and defines how many reads to process Args: read1_paths (list): List of paths chemistry_def (Chemistry): Chemistry definition longest_tag_len (int): Longest tag sequence args (argparse): List of arguments - + Returns: n_reads (int): Number of reads to run on R2_min_length (int): Min R2 length to check if reads are too short @@ -356,7 +362,7 @@ def get_filtered_list(args, chemistry, translation_dict): Determines what mode to use for cell barcode correction. Args: args(argparse): All arguments - + Returns: set if we have a filtered list None if we want correction and we have not a list diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 143f00f..bfee077 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -111,14 +111,14 @@ def correct_cells_no_translation_list( ): """ Corrects cell barcodes without a translation. - + Args: final_results (dict): Dict of dict of Counters with mapping results. umis_per_cell (Counter): Counter of number of umis per cell. collapsing_threshold (int): Max distance between umis. expected_cells (int): Number of expected cells. ab_map (dict): Dict of the TAGS. - + Returns: final_results (dict): Same as input but with corrected umis. umis_per_cell (Counter): Counter of umis per cell after cell barcode correction @@ -151,7 +151,7 @@ def correct_cells_filtered_set( ): """ Corrects cell barcodes based on a given translation_list. - + Args: final_results (dict): Dict of dict of Counters with mapping results. umis_per_cell (Counter): Counter of UMIs per cell. @@ -159,7 +159,7 @@ def correct_cells_filtered_set( collapsing_threshold (int): Max distance between umis. ab_map (OrederedDict): Tags in an ordered dict. - + Returns: final_results (dict): Same as input but with corrected umis. umis_per_cell (Counter): Updated UMI counts after correction. @@ -224,7 +224,11 @@ def find_true_to_false_map( def run_cell_barcode_correction( - final_results, umis_per_cell, ordered_tags, filtered_set, args, + final_results, + umis_per_cell, + ordered_tags, + filtered_set, + args, ): if args.expected_cells > len(filtered_set): print( @@ -318,13 +322,13 @@ def check_filtered_cells(filtered_cells, expected_cells, umis_per_cell): def correct_umis_in_cells(umi_correction_input): """ Corrects umi barcodes within same cell/tag groups. - + Args: final_results (dict): Dict of dict of Counters with mapping results. collapsing_threshold (int): Max distance between umis. filtered_cells (set): Set of cells to go through. max_umis (int): Maximum UMIs to consider for one cluster. - + Returns: final_results (dict): Same as input but with corrected umis. corrected_umis (int): How many umis have been corrected. @@ -494,4 +498,3 @@ def generate_sparse_matrices( final_results[cell_barcode][TAG_id].values() ) return results_matrix - diff --git a/cite_seq_count/secondsToText.py b/cite_seq_count/secondsToText.py index e0799e1..2dac866 100644 --- a/cite_seq_count/secondsToText.py +++ b/cite_seq_count/secondsToText.py @@ -1,66 +1,73 @@ # Gist found here: https://gist.github.com/Highstaker/280a09591df4a5fb1363b0bbaf858f0d + def pluralizeRussian(number, nom_sing, gen_sing, gen_pl): - """ - Changes the hours, minutes, seconds to plural - """ - s_last_digit = str(number)[-1] + """ + Changes the hours, minutes, seconds to plural + """ + s_last_digit = str(number)[-1] + + if int(str(number)[-2:]) in range(11, 20): + # 11-19 + return gen_pl + elif s_last_digit == "1": + # 1 + return nom_sing + elif int(s_last_digit) in range(2, 5): + # 2,3,4 + return gen_sing + else: + # 5,6,7,8,9,0 + return gen_pl - if int(str(number)[-2:]) in range(11,20): - #11-19 - return gen_pl - elif s_last_digit == '1': - #1 - return nom_sing - elif int(s_last_digit) in range(2,5): - #2,3,4 - return gen_sing - else: - #5,6,7,8,9,0 - return gen_pl def secondsToText(secs, lang="EN"): - """ - Converts datetime to human readable hours, minutes, secondes format. + """ + Converts datetime to human readable hours, minutes, secondes format. + + Args: + secs (float): Secondes + lang (string): Language - Args: - secs (float): Secondes - lang (string): Language - - Returns: - string: Human readable datetime format. - """ - days = secs//86400 - hours = (secs - days*86400)//3600 - minutes = (secs - days*86400 - hours*3600)//60 - seconds = secs - days*86400 - hours*3600 - minutes*60 + Returns: + string: Human readable datetime format. + """ + days = secs // 86400 + hours = (secs - days * 86400) // 3600 + minutes = (secs - days * 86400 - hours * 3600) // 60 + seconds = secs - days * 86400 - hours * 3600 - minutes * 60 - if lang == "ES": - days_text = "día{}".format("s" if days!=1 else "") - hours_text = "hora{}".format("s" if hours!=1 else "") - minutes_text = "minuto{}".format("s" if minutes!=1 else "") - seconds_text = "segundo{}".format("s" if seconds!=1 else "") - elif lang == "DE": - days_text = "Tag{}".format("e" if days!=1 else "") - hours_text = "Stunde{}".format("n" if hours!=1 else "") - minutes_text = "Minute{}".format("n" if minutes!=1 else "") - seconds_text = "Sekunde{}".format("n" if seconds!=1 else "") - elif lang == "RU": - days_text = pluralizeRussian(days, "день", "дня", "дней") - hours_text = pluralizeRussian(hours, "час", "часа", "часов") - minutes_text = pluralizeRussian(minutes, "минута", "минуты", "минут") - seconds_text = pluralizeRussian(seconds, "секунда", "секунды", "секунд") - else: - #Default to English - days_text = "day{}".format("s" if days!=1 else "") - hours_text = "hour{}".format("s" if hours!=1 else "") - minutes_text = "minute{}".format("s" if minutes!=1 else "") - seconds_text = "second{}".format("s" if seconds!=1 else "") + if lang == "ES": + days_text = "día{}".format("s" if days != 1 else "") + hours_text = "hora{}".format("s" if hours != 1 else "") + minutes_text = "minuto{}".format("s" if minutes != 1 else "") + seconds_text = "segundo{}".format("s" if seconds != 1 else "") + elif lang == "DE": + days_text = "Tag{}".format("e" if days != 1 else "") + hours_text = "Stunde{}".format("n" if hours != 1 else "") + minutes_text = "Minute{}".format("n" if minutes != 1 else "") + seconds_text = "Sekunde{}".format("n" if seconds != 1 else "") + elif lang == "RU": + days_text = pluralizeRussian(days, "день", "дня", "дней") + hours_text = pluralizeRussian(hours, "час", "часа", "часов") + minutes_text = pluralizeRussian(minutes, "минута", "минуты", "минут") + seconds_text = pluralizeRussian(seconds, "секунда", "секунды", "секунд") + else: + # Default to English + days_text = "day{}".format("s" if days != 1 else "") + hours_text = "hour{}".format("s" if hours != 1 else "") + minutes_text = "minute{}".format("s" if minutes != 1 else "") + seconds_text = "second{}".format("s" if seconds != 1 else "") - result = ", ".join(filter(lambda x: bool(x),[ - "{0} {1}".format(days, days_text) if days else "", - "{0} {1}".format(hours, hours_text) if hours else "", - "{0} {1}".format(minutes, minutes_text) if minutes else "", - "{0:.4} {1}".format(seconds, seconds_text) if seconds else "" - ])) - return result \ No newline at end of file + result = ", ".join( + filter( + lambda x: bool(x), + [ + "{0} {1}".format(days, days_text) if days else "", + "{0} {1}".format(hours, hours_text) if hours else "", + "{0} {1}".format(minutes, minutes_text) if minutes else "", + "{0:.4} {1}".format(seconds, seconds_text) if seconds else "", + ], + ) + ) + return result diff --git a/setup.py b/setup.py index fa65dec..2405cd4 100644 --- a/setup.py +++ b/setup.py @@ -28,6 +28,7 @@ "pandas>=0.23.4", "pybktree==1.1", "cython>=0.29.17", + "jsonschema==4.2.1", ], python_requires=">=3.8", ) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 0c836bb..916404a 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -95,4 +95,3 @@ def test_parse_tags_csv(data): def test_check_distance_too_big_between_tags(data): with pytest.raises(SystemExit): preprocessing.check_tags(pytest.correct_tags, 8) - From 68c8da7480e70943c5a5d5e4348a467463fdc864 Mon Sep 17 00:00:00 2001 From: Hoohm Date: Sun, 21 Nov 2021 14:26:31 +0100 Subject: [PATCH 51/77] added some checks --- cite_seq_count/__main__.py | 9 ++++--- cite_seq_count/argsparser.py | 12 ++++----- cite_seq_count/preprocessing.py | 46 ++++++++++++++++++--------------- 3 files changed, 37 insertions(+), 30 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index c147f77..1626f53 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -38,13 +38,16 @@ def main(): # Parse arguments. args = parser.parse_args() - assert os.access(args.temp_path, os.W_OK) + # Check a few path before doing anything + if not os.access(args.temp_path, os.W_OK): + sys.exit("Temp folder: {} is not writable. Please check permissions and/or change temp folder.".format(args.temp_path)) + if not os.access(args.outfolder, os.W_OK): + sys.exit("Output folder: {} is not writable. Please check permissions and/or change output folder.".format(args.outfolder)) + # Get chemistry defs (translation_dict, chemistry_def) = chemistry.setup_chemistry(args) - # Check if we have a filtered list provided - # Load TAGs/ABs. ab_map = preprocessing.parse_tags_csv(args.tags) ordered_tags, longest_tag_len = preprocessing.check_tags(ab_map, args.max_error) diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 091a896..bd8fc79 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -13,11 +13,11 @@ def get_package_version(): return version -def chunk_size_limit(arg): +def chunk_size_limit(chunk_size: int) -> int: """Validates chunk_size limits""" max_size = 2147483647 try: - f = int(arg) + f = int(chunk_size) except ValueError: raise ArgumentTypeError("Chunk size must be an int") if f < 1 or f > max_size: @@ -254,7 +254,9 @@ def get_args(): dest="chunk_size", help=("How many reads should be sent to a child process at a time"), ) - parallel.add_argument( + + # Global group + parser.add_argument( "--temp_path", required=False, type=str, @@ -264,8 +266,6 @@ def get_args(): "Temp folder for chunk creation specification. Useful when using a cluster with a scratch folder" ), ) - - # Global group parser.add_argument( "-n", "--first_n", @@ -280,7 +280,7 @@ def get_args(): "--output", required=False, type=str, - default="Results", + default="results", dest="outfolder", help=("Results will be written to this folder"), ) diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index ec5cf82..874b3e7 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -104,7 +104,7 @@ def parse_cell_list_csv(filename, barcode_length): return translation_dict -def parse_tags_csv(filename): +def parse_tags_csv(file_name): """Reads the TAGs from a CSV file. Checks that the header contains necessary strings and if sequences are made of ATGC @@ -117,7 +117,7 @@ def parse_tags_csv(filename): TTCCGCCTCTCTTTG,Hashtag_3 Args: - filename (str): TAGs file path. + file_name (file): TAGs file name. Returns: dict: A dictionary using sequences as keys and feature names as values. @@ -125,27 +125,31 @@ def parse_tags_csv(filename): """ REQUIRED_HEADER = ["sequence", "feature_name"] atgc_test = regex.compile("^[ATGC]{1,}$") - with open(filename, mode="r") as csv_file: - csv_reader = csv.reader(csv_file) - tags = {} - header = next(csv_reader) - set_dif = set(REQUIRED_HEADER) - set(header) - if len(set_dif) != 0: + + try: + with open(file_name) as csvfile: + csv_reader = csv.reader(csvfile) + except Exception as e: + sys.exit(e) + tags = {} + header = next(csv_reader) + set_dif = set(REQUIRED_HEADER) - set(header) + if len(set_dif) != 0: + raise SystemExit( + "The header is missing {}. Exiting".format(",".join(list(set_dif))) + ) + sequence_id = header.index("sequence") + feature_id = header.index("feature_name") + for i, row in enumerate(csv_reader): + sequence = row[sequence_id].strip() + + if not regex.match(atgc_test, sequence): raise SystemExit( - "The header is missing {}. Exiting".format(",".join(list(set_dif))) - ) - sequence_id = header.index("sequence") - feature_id = header.index("feature_name") - for i, row in enumerate(csv_reader): - sequence = row[sequence_id].strip() - - if not regex.match(atgc_test, sequence): - raise SystemExit( - "Sequence {} on line {} is not only composed of ATGC. Exiting".format( - sequence, i - ) + "Sequence {} on line {} is not only composed of ATGC. Exiting".format( + sequence, i ) - tags[sequence] = row[feature_id].strip() + ) + tags[sequence] = row[feature_id].strip() return tags From b65e3686340c6d7988b02298e715cbf3f0fd8102 Mon Sep 17 00:00:00 2001 From: Hoohm Date: Sun, 21 Nov 2021 16:22:14 +0100 Subject: [PATCH 52/77] fixed writing out unknown barcodes --- cite_seq_count/__main__.py | 2 +- cite_seq_count/argsparser.py | 4 +- cite_seq_count/chemistry.py | 4 +- cite_seq_count/io.py | 23 ++++---- cite_seq_count/preprocessing.py | 94 ++++++++++++++------------------- 5 files changed, 60 insertions(+), 67 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 1626f53..f52068a 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -41,7 +41,7 @@ def main(): # Check a few path before doing anything if not os.access(args.temp_path, os.W_OK): sys.exit("Temp folder: {} is not writable. Please check permissions and/or change temp folder.".format(args.temp_path)) - if not os.access(args.outfolder, os.W_OK): + if not os.access(os.path.dirname(os.path.abspath(args.outfolder)), os.W_OK): sys.exit("Output folder: {} is not writable. Please check permissions and/or change output folder.".format(args.outfolder)) diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index bd8fc79..5b5c9c3 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -106,11 +106,11 @@ def get_args(): ), ) barcodes.add_argument( - "--chemistry", + "--chemistry_id", type=str, required=False, default=False, - help=("Option replacing cell/UMI barcodes indexes and translation list."), + help=("[BETA FEATURE] Option replacing cell/UMI barcodes indexes and translation list."), ) if "--chemistry" not in sys.argv: barcodes.add_argument( diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index 171b56d..d074443 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -136,8 +136,8 @@ def create_chemistry_definition(args): def setup_chemistry(args): - if args.chemistry: - chemistry_def = get_chemistry_definition(args.chemistry) + if args.chemistry_id: + chemistry_def = get_chemistry_definition(args.chemistry_id) translation_dict = preprocessing.parse_cell_list_csv( filename=chemistry_def.translation_list_path, barcode_length=chemistry_def.cell_barcode_end diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index b22dd87..39751d9 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -121,20 +121,26 @@ def write_to_files( data_type (string): A string definning if the data is umi or read based. outfolder (string): Path to the output folder. """ - if translation_dict: - original_barcode = list(translation_dict.keys()) - translated_barcode = list(translation_dict.values()) prefix = os.path.join(outfolder, data_type + "_count") + unknown_id = 1 os.makedirs(prefix, exist_ok=True) io.mmwrite(os.path.join(prefix, "matrix.mtx"), a=sparse_matrix, field="integer") with gzip.open(os.path.join(prefix, "barcodes.tsv.gz"), "wb") as barcode_file: for barcode in filtered_cells: if translation_dict: - barcode_file.write( - "{}\t{}\n".format( - original_barcode[translated_barcode.index(barcode)], barcode - ).encode(), - ) + if barcode in translation_dict: + barcode_file.write( + "{}\t{}\n".format( + translation_dict[barcode], barcode + ).encode(), + ) + else: + barcode_file.write( + "{}\t{}\n".format( + "translation_not_found_{}".format(unknown_id), barcode + ).encode(), + ) + unknown_id += 1 else: barcode_file.write("{}\n".format(barcode).encode()) with gzip.open(os.path.join(prefix, "features.tsv.gz"), "wb") as feature_file: @@ -402,7 +408,6 @@ def write_chunks_to_disk( chunked_file_object = tempfile.NamedTemporaryFile( "w", dir=temp_path, suffix="_csc", delete=False ) - # chunked_file_object = open(temp_file, "w") temp_files.append(chunked_file_object.name) reads_written = 0 if total_reads_written == n_reads_per_chunk: diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 874b3e7..033d26c 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -33,13 +33,13 @@ def parse_filtered_list_csv(filename, barcode_length): out_set = set() barcode_pattern = regex.compile(r"^[ATGC]{{{}}}".format(barcode_length)) for barcode in barcodes: - check_barcode = barcode.strip(STRIP_CHARS) - if barcode_pattern.match(check_barcode): - out_set.add(check_barcode) + checked_barcode = barcode.strip(STRIP_CHARS) + if barcode_pattern.match(checked_barcode): + out_set.add(checked_barcode) else: sys.exit( - "This barcode {} is not only composed of ATGC bases.".format( - check_barcode + "Only ATGC barcodes are accepted in the filtered list. Please delete entry {}".format( + checked_barcode ) ) @@ -63,44 +63,30 @@ def parse_cell_list_csv(filename, barcode_length): STRIP_CHARS = '"0123456789- \t\n' REQUIRED_HEADER = ["reference", "translation"] - cell_pattern = regex.compile(r"^[ATGC]{{{}}}".format(barcode_length)) - csv_reader = get_csv_reader_from_path(filename=filename) - header = next(csv_reader) + data = read_csv(filename, dtype={"reference": str, "translation": str}) + if data.shape[1] != 2: + print(data.head()) + sys.exit("Your translation file only holds 1 column or is tab delimited instead of csv.") + barcode_pattern = regex.compile(r"^[ATGC]{{{}}}".format(barcode_length)) + + header = data.columns set_dif = set(REQUIRED_HEADER) - set(header) if len(set_dif) != 0: raise SystemExit( "The header is missing {}. Exiting".format(",".join(list(set_dif))) ) - # translation_id = header.index(REQUIRED_HEADER[0]) - translation_dict = {} - if "translation" in header: - - translation_id = header.index("translation") - reference_id = header.index("reference") - for row in csv_reader: - ref_barcode = row[reference_id].strip(STRIP_CHARS) - tra_barcode = row[translation_id].strip(STRIP_CHARS) - if ( - len(ref_barcode) == barcode_length - and len(tra_barcode) == barcode_length - ): - translation_dict[tra_barcode] = ref_barcode - else: - sys.exit('The header is missing a the "{}" keyword'.format("translation")) + #Prepare and validate data - for cell_barcode in translation_dict.keys(): - if not cell_pattern.match(cell_barcode): - sys.exit( - "This barcode {} is not only composed of ATGC bases.".format( - cell_barcode - ) - ) - if len(translation_dict) == 0: - sys.exit("translation_dict is empty.") - print( - "Your translation list provides a translation name. This will be the default for the count matrices." - ) + data["reference"] = data["reference"].map(lambda x: x.rstrip(STRIP_CHARS)) + data["translation"] = data["translation"].map(lambda x: x.rstrip(STRIP_CHARS)) + + if any(data["reference"].map(lambda x: not barcode_pattern.match(x))): + sys.exit("Barcode(s) in reference column don't match [ATGC] or a length of {}. Please check.".format(barcode_length)) + if any(data["translation"].map(lambda x: not barcode_pattern.match(x))): + sys.exit("Barcode(s) in translation column don't match [ATGC] or a length of {}. Please check.".format(barcode_length)) + + translation_dict = dict(zip(data.translation, data.reference)) return translation_dict @@ -131,25 +117,27 @@ def parse_tags_csv(file_name): csv_reader = csv.reader(csvfile) except Exception as e: sys.exit(e) - tags = {} - header = next(csv_reader) - set_dif = set(REQUIRED_HEADER) - set(header) - if len(set_dif) != 0: - raise SystemExit( - "The header is missing {}. Exiting".format(",".join(list(set_dif))) - ) - sequence_id = header.index("sequence") - feature_id = header.index("feature_name") - for i, row in enumerate(csv_reader): - sequence = row[sequence_id].strip() - - if not regex.match(atgc_test, sequence): + with open(file_name) as csvfile: + csv_reader = csv.reader(csvfile) + tags = {} + header = next(csv_reader) + set_dif = set(REQUIRED_HEADER) - set(header) + if len(set_dif) != 0: raise SystemExit( - "Sequence {} on line {} is not only composed of ATGC. Exiting".format( - sequence, i - ) + "The header is missing {}. Exiting".format(",".join(list(set_dif))) ) - tags[sequence] = row[feature_id].strip() + sequence_id = header.index("sequence") + feature_id = header.index("feature_name") + for i, row in enumerate(csv_reader): + sequence = row[sequence_id].strip() + + if not regex.match(atgc_test, sequence): + raise SystemExit( + "Sequence {} on line {} is not only composed of ATGC. Exiting".format( + sequence, i + ) + ) + tags[sequence] = row[feature_id].strip() return tags From fb0d22602799c439404e441662620d6dafb95cd9 Mon Sep 17 00:00:00 2001 From: Hoohm Date: Sun, 21 Nov 2021 16:41:53 +0100 Subject: [PATCH 53/77] reformatting --- cite_seq_count/__main__.py | 13 ++++++++++--- cite_seq_count/argsparser.py | 6 ++++-- cite_seq_count/io.py | 4 +--- cite_seq_count/preprocessing.py | 30 ++++++++++++++++++++---------- 4 files changed, 35 insertions(+), 18 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index f52068a..803e53a 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -40,10 +40,17 @@ def main(): args = parser.parse_args() # Check a few path before doing anything if not os.access(args.temp_path, os.W_OK): - sys.exit("Temp folder: {} is not writable. Please check permissions and/or change temp folder.".format(args.temp_path)) + sys.exit( + "Temp folder: {} is not writable. Please check permissions and/or change temp folder.".format( + args.temp_path + ) + ) if not os.access(os.path.dirname(os.path.abspath(args.outfolder)), os.W_OK): - sys.exit("Output folder: {} is not writable. Please check permissions and/or change output folder.".format(args.outfolder)) - + sys.exit( + "Output folder: {} is not writable. Please check permissions and/or change output folder.".format( + args.outfolder + ) + ) # Get chemistry defs (translation_dict, chemistry_def) = chemistry.setup_chemistry(args) diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 5b5c9c3..3d408e8 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -110,7 +110,9 @@ def get_args(): type=str, required=False, default=False, - help=("[BETA FEATURE] Option replacing cell/UMI barcodes indexes and translation list."), + help=( + "[BETA FEATURE] Option replacing cell/UMI barcodes indexes and translation list." + ), ) if "--chemistry" not in sys.argv: barcodes.add_argument( @@ -254,7 +256,7 @@ def get_args(): dest="chunk_size", help=("How many reads should be sent to a child process at a time"), ) - + # Global group parser.add_argument( "--temp_path", diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 39751d9..4dd04d2 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -130,9 +130,7 @@ def write_to_files( if translation_dict: if barcode in translation_dict: barcode_file.write( - "{}\t{}\n".format( - translation_dict[barcode], barcode - ).encode(), + "{}\t{}\n".format(translation_dict[barcode], barcode).encode(), ) else: barcode_file.write( diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 033d26c..41c1979 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -66,9 +66,11 @@ def parse_cell_list_csv(filename, barcode_length): data = read_csv(filename, dtype={"reference": str, "translation": str}) if data.shape[1] != 2: print(data.head()) - sys.exit("Your translation file only holds 1 column or is tab delimited instead of csv.") + sys.exit( + "Your translation file only holds 1 column or is tab delimited instead of csv." + ) barcode_pattern = regex.compile(r"^[ATGC]{{{}}}".format(barcode_length)) - + header = data.columns set_dif = set(REQUIRED_HEADER) - set(header) if len(set_dif) != 0: @@ -76,16 +78,24 @@ def parse_cell_list_csv(filename, barcode_length): "The header is missing {}. Exiting".format(",".join(list(set_dif))) ) - #Prepare and validate data + # Prepare and validate data data["reference"] = data["reference"].map(lambda x: x.rstrip(STRIP_CHARS)) data["translation"] = data["translation"].map(lambda x: x.rstrip(STRIP_CHARS)) - + if any(data["reference"].map(lambda x: not barcode_pattern.match(x))): - sys.exit("Barcode(s) in reference column don't match [ATGC] or a length of {}. Please check.".format(barcode_length)) + sys.exit( + "Barcode(s) in reference column don't match [ATGC] or a length of {}. Please check.".format( + barcode_length + ) + ) if any(data["translation"].map(lambda x: not barcode_pattern.match(x))): - sys.exit("Barcode(s) in translation column don't match [ATGC] or a length of {}. Please check.".format(barcode_length)) - + sys.exit( + "Barcode(s) in translation column don't match [ATGC] or a length of {}. Please check.".format( + barcode_length + ) + ) + translation_dict = dict(zip(data.translation, data.reference)) return translation_dict @@ -111,13 +121,13 @@ def parse_tags_csv(file_name): """ REQUIRED_HEADER = ["sequence", "feature_name"] atgc_test = regex.compile("^[ATGC]{1,}$") - + try: - with open(file_name) as csvfile: + with open(file_name) as csvfile: csv_reader = csv.reader(csvfile) except Exception as e: sys.exit(e) - with open(file_name) as csvfile: + with open(file_name) as csvfile: csv_reader = csv.reader(csvfile) tags = {} header = next(csv_reader) From 564a5d05e1c50a2013443ef196df2ac233728fcc Mon Sep 17 00:00:00 2001 From: Hoohm Date: Sun, 21 Nov 2021 16:54:46 +0100 Subject: [PATCH 54/77] fixed tests --- tests/test_data/reference_lists/pass/simple_ref.csv | 3 --- tests/test_io.py | 12 ++++++------ tests/test_preprocessing.py | 4 ++-- tests/test_processing.py | 2 +- 4 files changed, 9 insertions(+), 12 deletions(-) delete mode 100644 tests/test_data/reference_lists/pass/simple_ref.csv diff --git a/tests/test_data/reference_lists/pass/simple_ref.csv b/tests/test_data/reference_lists/pass/simple_ref.csv deleted file mode 100644 index 0b5a0fd..0000000 --- a/tests/test_data/reference_lists/pass/simple_ref.csv +++ /dev/null @@ -1,3 +0,0 @@ -reference -ACTGTTTTATTGGCCT -TTCATAAGGTAGGGAT \ No newline at end of file diff --git a/tests/test_io.py b/tests/test_io.py index fbdafaa..cfdfd45 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -33,15 +33,15 @@ def data(): pytest.corrupt_R1_path = "tests/test_data/fastq/corrupted_R1.fastq.gz" pytest.corrupt_R2_path = "tests/test_data/fastq/corrupted_R2.fastq.gz" - pytest.correct_R1_multipath = "path/to/R1_1.fastq.gz,path/to/R1_2.fastq.gz" - pytest.correct_R2_multipath = "path/to/R2_1.fastq.gz,path/to/R2_2.fastq.gz" + pytest.correct_R1_multipath = "tests/test_data/fastq/correct_R1.fastq.gz,tests/test_data/fastq/correct_R1.fastq.gz" + pytest.correct_R2_multipath = "tests/test_data/fastq/correct_R2.fastq.gz,tests/test_data/fastq/correct_R2.fastq.gz" pytest.incorrect_R2_multipath = ( "path/to/R2_1.fastq.gz,path/to/R2_2.fastq.gz,path/to/R2_3.fastq.gz" ) pytest.correct_multipath_result = ( - ["path/to/R1_1.fastq.gz", "path/to/R1_2.fastq.gz"], - ["path/to/R2_1.fastq.gz", "path/to/R2_2.fastq.gz"], + [pytest.correct_R1_path, pytest.correct_R1_path], + [pytest.correct_R2_path, pytest.correct_R2_path], ) test_matrix = sparse.dok_matrix((4, 2), dtype=np.int32) test_matrix[1, 1] = 1 @@ -77,7 +77,7 @@ def test_write_to_files_wo_translation(data, tmpdir): pytest.ordered_tags_map, pytest.data_type, output_path, - reference_dict=reference_dict, + translation_dict=False, ) file_path = os.path.join(tmpdir, "without_translation", "umi_count/matrix.mtx.gz") with gzip.open(file_path, "rb") as mtx_file: @@ -110,7 +110,7 @@ def test_write_to_files_with_translation(data, tmpdir): pytest.ordered_tags_map, pytest.data_type, output_path, - reference_dict=reference_dict, + translation_dict=reference_dict, ) file_path = os.path.join(tmpdir, "with_translation", "umi_count/matrix.mtx.gz") with gzip.open(file_path, "rb") as mtx_file: diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 916404a..829776d 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -21,7 +21,7 @@ def data(): pytest.correct_tags_path = "tests/test_data/tags/pass/correct.csv" # Create some variables to compare to - pytest.correct_reference_list = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) + pytest.correct_reference_translation_list = set(["ACTGTTTTATTGGCCT","TTCATCCTTTAGGGAT"]) pytest.correct_tags = { "AGGACCATCCAA": "CITE_LEN_12_1", "ACATGTTACCGT": "CITE_LEN_12_2", @@ -75,7 +75,7 @@ def test_parse_reference_list_csv(data): passing_files = glob.glob(pytest.passing_reference_list_csv) for file_path in passing_files: assert preprocessing.parse_cell_list_csv(file_path, 16).keys() in ( - pytest.correct_reference_list, + pytest.correct_reference_translation_list, 1, ) with pytest.raises(SystemExit): diff --git a/tests/test_processing.py b/tests/test_processing.py index 8b129b8..310d493 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -49,7 +49,7 @@ def test_correct_umis(data): @pytest.mark.dependency(depends=["test_correct_umis"]) def test_correct_cells(data): - processing.correct_cells_no_reference_list( + processing.correct_cells_no_translation_list( pytest.corrected_results, pytest.reads_per_cell, pytest.umis_per_cell, From 14a9c616eb0966b620c04fc561c635459dc6d4c4 Mon Sep 17 00:00:00 2001 From: Hoohm Date: Sun, 21 Nov 2021 20:11:36 +0100 Subject: [PATCH 55/77] added pyymal --- cite_seq_count/chemistry.py | 2 +- setup.py | 2 +- tests/test_io.py | 5 ++--- 3 files changed, 4 insertions(+), 5 deletions(-) diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index d074443..02678ba 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -88,7 +88,7 @@ def get_chemistry_definition(chemistry_short_name): Fetches chemistry definitions from a remote definitions.json and returns the json. """ chemistry_defs = fetch_definitions()[chemistry_short_name] - + print(chemistry_defs) if chemistry_defs["translation_list"]["path"] not in DEFINITIONS_DB.registry: path = pooch.retrieve( url=os.path.join( diff --git a/setup.py b/setup.py index 2405cd4..bb49063 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ "pandas>=0.23.4", "pybktree==1.1", "cython>=0.29.17", - "jsonschema==4.2.1", + "pyyaml==6.0" ], python_requires=">=3.8", ) diff --git a/tests/test_io.py b/tests/test_io.py index cfdfd45..d2084fc 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -59,7 +59,6 @@ def data(): def test_write_to_files_wo_translation(data, tmpdir): - reference_dict = {"ACTGTTTTATTGGCCT": 0, "TTCATAAGGTAGGGAT": 0} output_path = os.path.join(tmpdir, "without_translation") mtx_path = os.path.join(output_path, "umi_count", "matrix.mtx.gz") @@ -88,7 +87,7 @@ def test_write_to_files_wo_translation(data, tmpdir): def test_write_to_files_with_translation(data, tmpdir): - reference_dict = { + translation_dict = { "ACTGTTTTATTGGCCT": "GGCTTCGATACTAGAT", "TTCATAAGGTAGGGAT": "GATCGGATAGCTAATA", } @@ -110,7 +109,7 @@ def test_write_to_files_with_translation(data, tmpdir): pytest.ordered_tags_map, pytest.data_type, output_path, - translation_dict=reference_dict, + translation_dict=translation_dict, ) file_path = os.path.join(tmpdir, "with_translation", "umi_count/matrix.mtx.gz") with gzip.open(file_path, "rb") as mtx_file: From edba219370321ffa6d93bc088951fabd6b077952 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sat, 2 Jul 2022 15:47:29 +0300 Subject: [PATCH 56/77] feat: Add json template --- cite_seq_count/__main__.py | 12 +- cite_seq_count/io.py | 191 ++++++++++++++------------- cite_seq_count/templates/report.json | 33 +++++ setup.py | 8 +- 4 files changed, 137 insertions(+), 107 deletions(-) create mode 100644 cite_seq_count/templates/report.json diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 803e53a..dccb56d 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -7,15 +7,7 @@ import logging import time -from cite_seq_count import preprocessing -from cite_seq_count import mapping -from cite_seq_count import processing -from cite_seq_count import chemistry -from cite_seq_count import io -from cite_seq_count import secondsToText -from cite_seq_count import argsparser - -from collections import Counter +from cite_seq_count import preprocessing, argsparser, mapping, processing, chemistry, io def main(): @@ -93,7 +85,7 @@ def main(): maximum_distance=maximum_distance, ) # Map the data - (final_results, umis_per_cell, reads_per_cell, merged_no_match,) = mapping.map_data( + (final_results, umis_per_cell, reads_per_cell, merged_no_match) = mapping.map_data( input_queue=input_queue, unmapped_id=len(ordered_tags), args=args ) diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 4dd04d2..43f668b 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -1,3 +1,4 @@ +"""Handle io operations""" import os import csv import sys @@ -6,15 +7,20 @@ import time import datetime import tempfile +import json from collections import namedtuple from itertools import islice +import pkg_resources +import yaml import pandas as pd from scipy import io from cite_seq_count import secondsToText +JSON_REPORT_PATH = pkg_resources.resource_filename(__name__, "templates/report.json") + def blocks(files, size=65536): """ @@ -29,10 +35,10 @@ def blocks(files, size=65536): A generator """ while True: - b = files.read(size) - if not b: + partial_file = files.read(size) + if not partial_file: break - yield b + yield partial_file def get_n_lines(file_path): @@ -47,13 +53,13 @@ def get_n_lines(file_path): Returns: n_lines (int): Number of lines in the file """ - print("Counting number of reads in file {}".format(file_path)) - with gzip.open(file_path, "rt", encoding="utf-8", errors="ignore") as f: - n_lines = sum(bl.count("\n") for bl in blocks(f)) + print(f"Counting number of reads in file {file_path}") + with gzip.open(file_path, "rt", encoding="utf-8", errors="ignore") as infile: + n_lines = sum(bl.count("\n") for bl in blocks(infile)) if n_lines % 4 != 0: sys.exit( - "{}'s number of lines is not a multiple of 4. The file " - "might be corrupted.\n Exiting".format(file_path) + f"{file_path}'s number of lines is not a multiple of 4. The file " + "might be corrupted.\n Exiting" ) return n_lines @@ -74,8 +80,8 @@ def get_read_paths(read1_path, read2_path): _read2_path = read2_path.split(",") if len(_read1_path) != len(_read2_path): sys.exit( - "Unequal number of read1 ({}) and read2({}) files provided" - "\n Exiting".format(len(_read1_path), len(_read2_path)) + f"Unequal number of read1 ({len(_read1_path)}) and read2({len(_read2_path)}) files provided" + "\n Exiting" ) all_files = _read1_path + _read2_path for file_path in all_files: @@ -83,9 +89,9 @@ def get_read_paths(read1_path, read2_path): if os.access(file_path, os.R_OK): continue else: - sys.exit("{} is not readable. Exiting".format(file_path)) + sys.exit(f"{file_path} is not readable. Exiting") else: - sys.exit("{} does not exist. Exiting".format(file_path)) + sys.exit(f"{file_path} does not exist. Exiting") return (_read1_path, _read2_path) @@ -101,11 +107,11 @@ def get_csv_reader_from_path(filename, sep="\t"): csv_reader: The csv_reader for the file """ if filename.endswith(".gz"): - f = gzip.open(filename, mode="rt") - csv_reader = csv.reader(f, delimiter=sep) + file_handle = gzip.open(filename, mode="rt") + csv_reader = csv.reader(file_handle, delimiter=sep) else: - f = open(filename, encoding="UTF-8") - csv_reader = csv.reader(f, delimiter=sep) + file_handle = open(filename, encoding="UTF-8") + csv_reader = csv.reader(file_handle, delimiter=sep) return csv_reader @@ -130,22 +136,20 @@ def write_to_files( if translation_dict: if barcode in translation_dict: barcode_file.write( - "{}\t{}\n".format(translation_dict[barcode], barcode).encode(), + f"{translation_dict[barcode]}\t{barcode}\n".encode(), ) else: barcode_file.write( "{}\t{}\n".format( - "translation_not_found_{}".format(unknown_id), barcode + f"translation_not_found_{unknown_id}", barcode ).encode(), ) unknown_id += 1 else: - barcode_file.write("{}\n".format(barcode).encode()) + barcode_file.write(f"{barcode}\n".encode()) with gzip.open(os.path.join(prefix, "features.tsv.gz"), "wb") as feature_file: for feature in ordered_tags: - feature_file.write( - "{}\t{}\n".format(feature.sequence, feature.name).encode() - ) + feature_file.write(f"{feature.sequence}\t{feature.name}\n".encode()) if data_type == "read": feature_file.write("{}\t{}\n".format("UNKNOWN", "unmapped").encode()) with open(os.path.join(prefix, "matrix.mtx"), "rb") as mtx_in: @@ -187,24 +191,38 @@ def write_unmapped(merged_no_match, top_unknowns, outfolder, filename): top_unmapped = merged_no_match.most_common(top_unknowns) - with open(os.path.join(outfolder, filename), "w") as unknown_file: + with open(os.path.join(outfolder, filename), "w", encoding="utf-8") as unknown_file: unknown_file.write("tag,count\n") for element in top_unmapped: - unknown_file.write("{},{}\n".format(element[0], element[1])) + unknown_file.write(f"{element[0]},{element[1]}\n") + + +def load_report_template() -> dict: + """Load json template for the report + + Returns: + dict: Dict for the report + """ + with open(file=JSON_REPORT_PATH, encoding="utf-8", mode="r") as json_in: + try: + report_dict = json.load(json_in) + except json.JSONDecodeError: + sys.exit( + f"Json report template at {JSON_REPORT_PATH} is not a valid json file." + ) + return report_dict def create_report( total_reads, - reads_per_cell, no_match, version, start_time, - ordered_tags, umis_corrected, bcs_corrected, bad_cells, - R1_too_short, - R2_too_short, + r1_too_short, + r2_too_short, args, chemistry_def, maximum_distance, @@ -220,76 +238,59 @@ def create_report( args (arg_parse): Arguments provided by the user. """ + total_unmapped = sum(no_match.values()) - total_too_short = R1_too_short + R2_too_short + total_too_short = r1_too_short + r2_too_short total_mapped = total_reads - total_unmapped - total_too_short too_short_perc = round((total_too_short / total_reads) * 100) mapped_perc = round((total_mapped / total_reads) * 100) unmapped_perc = round((total_unmapped / total_reads) * 100) - with open(os.path.join(args.outfolder, "run_report.yaml"), "w") as report_file: - report_file.write( - """Date: {} -Running time: {} -CITE-seq-Count Version: {} -Reads processed: {} -Percentage mapped: {} -Percentage unmapped: {} -Percentage too short: {} - R1_too_short: {} - R2_too_short: {} -Uncorrected cells: {} -Correction: - Cell barcodes collapsing threshold: {} - Cell barcodes corrected: {} - UMI collapsing threshold: {} - UMIs corrected: {} -Run parameters: - Read1_paths: {} - Read2_paths: {} - Cell barcode: - First position: {} - Last position: {} - UMI barcode: - First position: {} - Last position: {} - Expected cells: {} - Tags max errors: {} - Start trim: {} -""".format( - datetime.datetime.today().strftime("%Y-%m-%d"), - secondsToText.secondsToText(time.time() - start_time), - version, - int(total_reads), - mapped_perc, - unmapped_perc, - too_short_perc, - R1_too_short, - R2_too_short, - len(bad_cells), - args.bc_threshold, - bcs_corrected, - args.umi_threshold, - umis_corrected, - args.read1_path, - args.read2_path, - chemistry_def.cell_barcode_start, - chemistry_def.cell_barcode_end, - chemistry_def.umi_barcode_start, - chemistry_def.umi_barcode_end, - args.expected_cells, - maximum_distance, - chemistry_def.R2_trim_start, - ) - ) + report_data = load_report_template() + report_data["Date"] = datetime.datetime.today().strftime("%Y-%m-%d") + report_data["Running time"] = secondsToText.secondsToText(time.time() - start_time) + report_data["CITE-seq-Count Version"] = version + report_data["Reads processed"] = int(total_reads) + report_data["Percentage mapped"] = mapped_perc + report_data["Percentage unmapped"] = unmapped_perc + report_data["Percentage too short"] = too_short_perc + report_data["Percentage too short"]["r1_too_short"] = r1_too_short + report_data["Percentage too short"]["r2_too_short"] = r2_too_short + report_data["Uncorrected cells"] = len(bad_cells) + report_data["Correction"]["Cell barcodes collapsing threshold"] = args.bc_threshold + report_data["Correction"]["Cell barcodes corrected"] = bcs_corrected + report_data["Correction"]["UMI collapsing threshold"] = args.umi_threshold + report_data["Correction"]["UMIs corrected"] = umis_corrected + report_data["Run parameters"]["Read1_paths"] = args.read1_path + report_data["Run parameters"]["Read2_paths"] = args.read2_path + report_data["Run parameters"]["Cell barcode"][ + "First position" + ] = chemistry_def.cell_barcode_start + report_data["Run parameters"]["Cell barcode"][ + "Last position" + ] = chemistry_def.cell_barcode_end + report_data["Run parameters"]["UMI barcode"][ + "First position" + ] = chemistry_def.umi_barcode_start + report_data["Run parameters"]["UMI barcode"][ + "Last position" + ] = chemistry_def.umi_barcode_end + report_data["Expected cells"] = args.expected_cells + report_data["Tags max errors"] = maximum_distance + report_data["Start trim"] = chemistry_def.R2_trim_start + + with open( + os.path.join(args.outfolder, "run_report.yaml"), "w", encoding="utf-8" + ) as report_file: + yaml.dump(report_data, report_file, default_flow_style=False, sort_keys=False) def write_chunks_to_disk( args, read1_paths, read2_paths, - R2_min_length, + r2_min_length, n_reads_per_chunk, chemistry_def, ordered_tags, @@ -303,7 +304,7 @@ def write_chunks_to_disk( args(argparse): All parsed arguments. read1_paths (list): List of R1 fastq.gz paths. read2_paths (list): List of R2 fastq.gz paths. - R2_min_length (int): Minimum length of read2 sequences. + r2_min_length (int): Minimum length of read2 sequences. n_reads_per_chunk (int): How many reads per chunk. chemistry_def (namedtuple): Hols all the information about the chemistry definition. ordered_tags (list): List of namedtuple tags. @@ -324,8 +325,8 @@ def write_chunks_to_disk( temp_path = os.path.abspath(args.temp_path) input_queue = [] temp_files = [] - R1_too_short = 0 - R2_too_short = 0 + r1_too_short = 0 + r2_too_short = 0 total_reads = 0 total_reads_written = 0 enough_reads = False @@ -348,7 +349,7 @@ def write_chunks_to_disk( if enough_reads: break - print("Reading reads from files: {}, {}".format(read1_path, read2_path)) + print(f"Reading reads from files: {read1_path}, {read2_path}") with gzip.open(read1_path, "rt") as textfile1, gzip.open( read2_path, "rt" ) as textfile2: @@ -359,11 +360,11 @@ def write_chunks_to_disk( read1 = read1.strip() if len(read1) < chemistry_def.umi_barcode_end: - R1_too_short += 1 + r1_too_short += 1 # The entire read is skipped continue - if len(read2) < R2_min_length: - R2_too_short += 1 + if len(read2) < r2_min_length: + r2_too_short += 1 # The entire read is skipped continue @@ -373,7 +374,7 @@ def write_chunks_to_disk( read2_sliced = read2[ chemistry_def.R2_trim_start : ( - R2_min_length + chemistry_def.R2_trim_start + r2_min_length + chemistry_def.R2_trim_start ) ] chunked_file_object.write( @@ -426,7 +427,7 @@ def write_chunks_to_disk( return ( input_queue, temp_files, - R1_too_short, - R2_too_short, + r1_too_short, + r2_too_short, total_reads, ) diff --git a/cite_seq_count/templates/report.json b/cite_seq_count/templates/report.json new file mode 100644 index 0000000..616fbc7 --- /dev/null +++ b/cite_seq_count/templates/report.json @@ -0,0 +1,33 @@ +{ + "Date": "", + "Running time": "", + "CITE-seq-Count Version": "", + "Reads processed": "", + "Percentage mapped": "", + "Percentage unmapped": "", + "Percentage too short": "", + "r1_too_short": "", + "r2_too_short": "", + "Uncorrected cells": "", + "Correction": { + "Cell barcodes collapsing threshold": "", + "Cell barcodes corrected": "", + "UMI collapsing threshold": "", + "UMIs corrected": "" + }, + "Run parameters": { + "Read1_paths": "", + "Read2_paths": "" + }, + "Cell barcode": { + "First position": "", + "Last position": "" + }, + "UMI barcode": { + "First position": "", + "Last position": "" + }, + "Expected cells": "", + "Tags max errors": "", + "Start trim": "" +} \ No newline at end of file diff --git a/setup.py b/setup.py index bb49063..6c81cea 100644 --- a/setup.py +++ b/setup.py @@ -1,6 +1,8 @@ +"""Setup.py file +""" import setuptools -with open("README.md", "r") as fh: +with open("README.md", "r", encoding="utf-8") as fh: long_description = fh.read() setuptools.setup( @@ -28,7 +30,9 @@ "pandas>=0.23.4", "pybktree==1.1", "cython>=0.29.17", - "pyyaml==6.0" + "pyyaml==6.0", + "pooch==1.6.0", ], python_requires=">=3.8", + data_files=[("report_template", ["templates/report.json"])], ) From 3e3d52962dee8f4625b5a18fb4bc6feb093fca25 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sat, 2 Jul 2022 16:09:54 +0300 Subject: [PATCH 57/77] Pyupgrade to 3.8 --- cite_seq_count/__main__.py | 28 ++++++------------- cite_seq_count/argsparser.py | 26 ++++++++++++------ cite_seq_count/chemistry.py | 2 +- cite_seq_count/io.py | 2 +- cite_seq_count/mapping.py | 48 +++++++++++++++++---------------- cite_seq_count/preprocessing.py | 6 ++--- cite_seq_count/processing.py | 6 ++--- cite_seq_count/secondsToText.py | 8 +++--- 8 files changed, 63 insertions(+), 63 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index dccb56d..303cfe8 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -4,23 +4,13 @@ """ import sys import os -import logging import time from cite_seq_count import preprocessing, argsparser, mapping, processing, chemistry, io def main(): - # Create logger and stream handler - logger = logging.getLogger("cite_seq_count") - logger.setLevel(logging.CRITICAL) - ch = logging.StreamHandler() - ch.setLevel(logging.CRITICAL) - formatter = logging.Formatter( - "%(asctime)s - %(name)s - %(levelname)s - %(message)s" - ) - ch.setFormatter(formatter) - logger.addHandler(ch) + """Main function""" start_time = time.time() parser = argsparser.get_args() @@ -60,7 +50,7 @@ def main(): args=args, chemistry=chemistry_def, translation_dict=translation_dict ) # Checks before chunking. - (n_reads, R2_min_length, maximum_distance) = preprocessing.pre_run_checks( + (n_reads, r2_min_length, maximum_distance) = preprocessing.pre_run_checks( read1_paths=read1_paths, chemistry_def=chemistry_def, longest_tag_len=longest_tag_len, @@ -71,14 +61,14 @@ def main(): ( input_queue, temp_files, - R1_too_short, - R2_too_short, + r1_too_short, + r2_too_short, total_reads, ) = io.write_chunks_to_disk( args=args, read1_paths=read1_paths, read2_paths=read2_paths, - R2_min_length=R2_min_length, + r2_min_length=r2_min_length, n_reads_per_chunk=n_reads, chemistry_def=chemistry_def, ordered_tags=ordered_tags, @@ -92,7 +82,7 @@ def main(): # Check if 99% of the reads are unmapped. mapping.check_unmapped( no_match=merged_no_match, - too_short=R1_too_short + R2_too_short, + too_short=r1_too_short + r2_too_short, total_reads=total_reads, start_trim=chemistry_def.R2_trim_start, ) @@ -207,16 +197,14 @@ def main(): # Create report and write it to disk io.create_report( total_reads=total_reads, - reads_per_cell=reads_per_cell, no_match=merged_no_match, version=argsparser.get_package_version(), start_time=start_time, - ordered_tags=ordered_tags, umis_corrected=umis_corrected, bcs_corrected=bcs_corrected, bad_cells=clustered_cells, - R1_too_short=R1_too_short, - R2_too_short=R2_too_short, + r1_too_short=r1_too_short, + r2_too_short=r2_too_short, args=args, chemistry_def=chemistry_def, maximum_distance=maximum_distance, diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 3d408e8..3975089 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -1,14 +1,24 @@ -import pkg_resources +"""Functions for argument parsing +""" + import sys import tempfile -from argparse import ArgumentParser, ArgumentTypeError, RawTextHelpFormatter, FileType +from argparse import ArgumentParser, ArgumentTypeError, RawTextHelpFormatter + +import pkg_resources + # pylint: disable=no-name-in-module from multiprocess import cpu_count def get_package_version(): + """Return package version + + Returns: + str: Package version as string + """ version = pkg_resources.require("cite_seq_count")[0].version return version @@ -17,15 +27,15 @@ def chunk_size_limit(chunk_size: int) -> int: """Validates chunk_size limits""" max_size = 2147483647 try: - f = int(chunk_size) + chunk_value = int(chunk_size) except ValueError: - raise ArgumentTypeError("Chunk size must be an int") - if f < 1 or f > max_size: + raise SystemExit("Chunk size must be an int") + if chunk_value < 1 or chunk_value > max_size: raise ArgumentTypeError( "Argument must be < " + str(max_size) + "and > " + str(1) ) else: - return f + return chunk_value def thread_default(): @@ -197,7 +207,7 @@ def get_args(): "A csv file containning a translation list of all potential barcodes\n\n" "\tExample:\n" "whitelist,translation\n" - "\tAAACCCAAGAAACACT,AAACCCATCAAACACT\n\AAACCCAAGAAACCAT,AAACCCATCAAACCAT\n\AAACCCAAGAAACCCA,AAACCCATCAAACCCA\n\n" + "\tAAACCCAAGAAACACT,AAACCCATCAAACACT\n\\AAACCCAAGAAACCAT,AAACCCATCAAACCAT\n\\AAACCCAAGAAACCCA,AAACCCATCAAACCCA\n\n" ), ) @@ -318,7 +328,7 @@ def get_args(): parser.add_argument( "--version", action="version", - version="CITE-seq-Count v{}".format(get_package_version()), + version=f"CITE-seq-Count v{get_package_version()}", help="Print version number.", ) # Finally! Too many options XD diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index 02678ba..f250df2 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -52,7 +52,7 @@ def fetch_definitions(): Load some sample gravity data to use in your docs. """ fname = DEFINITIONS_DB.fetch("definitions.json") - with open(fname, "r") as json_file: + with open(fname) as json_file: data = json_file.read() json_data = json.loads(data) return json_data diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 43f668b..32d5bd8 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -203,7 +203,7 @@ def load_report_template() -> dict: Returns: dict: Dict for the report """ - with open(file=JSON_REPORT_PATH, encoding="utf-8", mode="r") as json_in: + with open(file=JSON_REPORT_PATH, encoding="utf-8") as json_in: try: report_dict = json.load(json_in) except json.JSONDecodeError: diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py index d3ed584..965a3cb 100644 --- a/cite_seq_count/mapping.py +++ b/cite_seq_count/mapping.py @@ -1,13 +1,13 @@ +"""Mapping module. Holds all code related to mapping reads +""" import time import csv - import sys import os -import Levenshtein -from collections import Counter -from collections import defaultdict -from collections import namedtuple +from collections import Counter, defaultdict + +import Levenshtein # pylint: disable=no-name-in-module from multiprocess import Pool @@ -70,7 +70,7 @@ def map_data(input_queue, unmapped_id, args): return final_results, umis_per_cell, reads_per_cell, merged_no_match -def find_best_match(TAG_seq, tags, maximum_distance): +def find_best_match(tag_seq, tags, maximum_distance): """ Find the best match from the list of tags. @@ -79,7 +79,7 @@ def find_best_match(TAG_seq, tags, maximum_distance): If no matches found returns 'unmapped'. We add 1 Args: - TAG_seq (string): Sequence from R2 already start trimmed + tag_seq (string): Sequence from R2 already start trimmed tags (dict): A dictionary with the TAGs as keys and TAG Names as values. maximum_distance (int): Maximum distance given by the user. @@ -90,7 +90,7 @@ def find_best_match(TAG_seq, tags, maximum_distance): best_score = maximum_distance for tag in tags: # pylint: disable=no-member - score = Levenshtein.hamming(tag.sequence, TAG_seq[: len(tag.sequence)]) + score = Levenshtein.hamming(tag.sequence, tag_seq[: len(tag.sequence)]) if score == 0: # Best possible match return tag.id @@ -101,7 +101,7 @@ def find_best_match(TAG_seq, tags, maximum_distance): return best_match -def find_best_match_shift(TAG_seq, tags): +def find_best_match_shift(tag_seq, tags): """ Find the best match from the list of tags with sliding window. Only works with exact match. @@ -109,7 +109,7 @@ def find_best_match_shift(TAG_seq, tags): If no matches found returns 'unmapped'. Args: - TAG_seq (string): Sequence from R2 already start trimmed + tag_seq (string): Sequence from R2 already start trimmed tags (dict): A dictionary with the TAGs as keys and TAG Names as values. Returns: @@ -117,7 +117,7 @@ def find_best_match_shift(TAG_seq, tags): """ best_match = "unmapped" for tag in tags: - if tag.sequence in TAG_seq: + if tag.sequence in tag_seq: return tag.name return best_match @@ -141,30 +141,32 @@ def map_reads(mapping_input): """ # Initiate values (filename, tags, debug, maximum_distance, sliding_window) = mapping_input - print("Started mapping in child process {}".format(os.getpid())) + print(f"Started mapping in child process {os.getpid()}") results = {} no_match = Counter() - n = 1 + n_reads = 1 unmapped_id = len(tags) # Progress info - t = time.time() - with open(filename, "r") as input_file: + current_time = time.time() + with open(filename, encoding="utf-8") as input_file: reads = csv.reader(input_file) for read in reads: cell_barcode = read[0] # This change in bytes is required by umi_tools for umi correction UMI = bytes(read[1], "ascii") read2 = read[2] - if n % 1000000 == 0: + if n_reads % 1000000 == 0: print( "Processed 1,000,000 reads in {}. Total " "reads: {:,} in child {}".format( - secondsToText.secondsToText(time.time() - t), n, os.getpid() + secondsToText.secondsToText(time.time() - current_time), + n_reads, + os.getpid(), ) ) sys.stdout.flush() - t = time.time() + current_time = time.time() if cell_barcode not in results: results[cell_barcode] = defaultdict(Counter) @@ -181,9 +183,9 @@ def map_reads(mapping_input): if debug: print( - "cell_barcode:{0}\tUMI:{1}\tTAG_seq:{2}\n" - "cell barcode length:{3}\tUMI length:{4}\tTAG sequence length:{5}\n" - "Best match is: {6}\n".format( + "cell_barcode:{}\tUMI:{}\ttag_seq:{}\n" + "cell barcode length:{}\tUMI length:{}\tTAG sequence length:{}\n" + "Best match is: {}\n".format( cell_barcode, UMI, read2, @@ -194,10 +196,10 @@ def map_reads(mapping_input): ) ) sys.stdout.flush() - n += 1 + n_reads += 1 print( "Mapping done for process {}. Processed {:,} reads".format( - os.getpid(), n - 1 + os.getpid(), n_reads - 1 ) ) sys.stdout.flush() diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 41c1979..30bd778 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -31,7 +31,7 @@ def parse_filtered_list_csv(filename, barcode_length): barcodes = set(barcodes_pd.iloc[:, 0]) out_set = set() - barcode_pattern = regex.compile(r"^[ATGC]{{{}}}".format(barcode_length)) + barcode_pattern = regex.compile(fr"^[ATGC]{{{barcode_length}}}") for barcode in barcodes: checked_barcode = barcode.strip(STRIP_CHARS) if barcode_pattern.match(checked_barcode): @@ -69,7 +69,7 @@ def parse_cell_list_csv(filename, barcode_length): sys.exit( "Your translation file only holds 1 column or is tab delimited instead of csv." ) - barcode_pattern = regex.compile(r"^[ATGC]{{{}}}".format(barcode_length)) + barcode_pattern = regex.compile(fr"^[ATGC]{{{barcode_length}}}") header = data.columns set_dif = set(REQUIRED_HEADER) - set(header) @@ -348,7 +348,7 @@ def pre_run_checks(read1_paths, chemistry_def, longest_tag_len, args): # Print a statement if multiple files are run. if number_of_samples != 1: - print("Detected {} pairs of files to run on.".format(number_of_samples)) + print(f"Detected {number_of_samples} pairs of files to run on.") if args.sliding_window: R2_min_length = read2_lengths[0] diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index bfee077..1f4a6fe 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -170,7 +170,7 @@ def correct_cells_filtered_set( barcode_tree = pybktree.BKTree(Levenshtein.hamming, filtered_set) barcodes = set(umis_per_cell) print("Selecting reference candidates") - print("Processing {:,} cell barcodes".format(len(barcodes))) + print(f"Processing {len(barcodes):,} cell barcodes") # Run with one process true_to_false = find_true_to_false_map( @@ -262,7 +262,7 @@ def check_filtered_cells(filtered_cells, expected_cells, umis_per_cell): if filtered_cells is None: top_cells_tuple = umis_per_cell.most_common(expected_cells) # Select top cells based on total umis per cell - filtered_cells = set([pair[0] for pair in top_cells_tuple]) + filtered_cells = {pair[0] for pair in top_cells_tuple} return filtered_cells @@ -362,7 +362,7 @@ def correct_umis_in_cells(umi_correction_input): corrected_umis += temp_corrected_umis elif n_umis > max_umis: clustered_cells.add(cell_barcode) - print("Finished correcting umis in child {}".format(os.getpid())) + print(f"Finished correcting umis in child {os.getpid()}") return (final_results, corrected_umis, clustered_cells) diff --git a/cite_seq_count/secondsToText.py b/cite_seq_count/secondsToText.py index 2dac866..c8befc8 100644 --- a/cite_seq_count/secondsToText.py +++ b/cite_seq_count/secondsToText.py @@ -63,10 +63,10 @@ def secondsToText(secs, lang="EN"): filter( lambda x: bool(x), [ - "{0} {1}".format(days, days_text) if days else "", - "{0} {1}".format(hours, hours_text) if hours else "", - "{0} {1}".format(minutes, minutes_text) if minutes else "", - "{0:.4} {1}".format(seconds, seconds_text) if seconds else "", + f"{days} {days_text}" if days else "", + f"{hours} {hours_text}" if hours else "", + f"{minutes} {minutes_text}" if minutes else "", + f"{seconds:.4} {seconds_text}" if seconds else "", ], ) ) From 9e4cfea076b52be2eb89ccaca6a34c784015dc9b Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 3 Jul 2022 16:34:11 +0300 Subject: [PATCH 58/77] code reformatting --- cite_seq_count/__main__.py | 2 +- cite_seq_count/argsparser.py | 4 +- cite_seq_count/chemistry.py | 8 ++-- cite_seq_count/io.py | 61 +++++++++++++++++----------- cite_seq_count/templates/report.json | 48 +++++++++++----------- 5 files changed, 69 insertions(+), 54 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 303cfe8..fddfd54 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -84,7 +84,7 @@ def main(): no_match=merged_no_match, too_short=r1_too_short + r2_too_short, total_reads=total_reads, - start_trim=chemistry_def.R2_trim_start, + start_trim=chemistry_def.r2_trim_start, ) # Remove temp chunks diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 3975089..0b281c1 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -38,7 +38,7 @@ def chunk_size_limit(chunk_size: int) -> int: return chunk_value -def thread_default(): +def thread_default() -> int: """ Set number of threads default. @@ -53,7 +53,7 @@ def thread_default(): return 1 -def get_args(): +def get_args() -> ArgumentParser: """ Get args. """ diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index f250df2..c0d5128 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -28,7 +28,7 @@ class Chemistry: cell_barcode_end: int umi_barcode_start: int umi_barcode_end: int - R2_trim_start: int + r2_trim_start: int translation_list_path: str @@ -52,7 +52,7 @@ def fetch_definitions(): Load some sample gravity data to use in your docs. """ fname = DEFINITIONS_DB.fetch("definitions.json") - with open(fname) as json_file: + with open(fname, encoding="utf-8") as json_file: data = json_file.read() json_data = json.loads(data) return json_data @@ -116,7 +116,7 @@ def get_chemistry_definition(chemistry_short_name): umi_barcode_end=chemistry_defs["barcode_structure_indexes"]["umi_barcode"][ "R1" ]["stop"], - R2_trim_start=chemistry_defs["sequence_structure_indexes"]["R2"]["start"] - 1, + r2_trim_start=chemistry_defs["sequence_structure_indexes"]["R2"]["start"] - 1, translation_list_path=path, ) return chemistry_def @@ -129,7 +129,7 @@ def create_chemistry_definition(args): cell_barcode_end=args.cb_last, umi_barcode_start=args.umi_first, umi_barcode_end=args.umi_last, - R2_trim_start=args.start_trim, + r2_trim_start=args.start_trim, translation_list_path=args.translation_list, ) return chemistry_def diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 32d5bd8..536a001 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -9,9 +9,11 @@ import tempfile import json -from collections import namedtuple +from collections import namedtuple, Counter from itertools import islice +from typing import Tuple +import scipy import pkg_resources import yaml import pandas as pd @@ -22,26 +24,26 @@ JSON_REPORT_PATH = pkg_resources.resource_filename(__name__, "templates/report.json") -def blocks(files, size=65536): +def blocks(file, size: int = 65536): """ A fast way of counting the lines of a large file. Ref: https://stackoverflow.com/a/9631635/9178565 Args: - files (io.handler): A file handler + file (io.handler): A file handler size (int): Block size Returns: A generator """ while True: - partial_file = files.read(size) + partial_file = file.read(size) if not partial_file: break yield partial_file -def get_n_lines(file_path): +def get_n_lines(file_path: str) -> int: """ Determines how many lines have to be processed depending on options and number of available lines. @@ -64,7 +66,7 @@ def get_n_lines(file_path): return n_lines -def get_read_paths(read1_path, read2_path): +def get_read_paths(read1_path: str, read2_path: str) -> Tuple[str, str]: """ Splits up 2 comma-separated strings of input files into list of files to process. Ensures both lists are equal in length. @@ -96,7 +98,7 @@ def get_read_paths(read1_path, read2_path): return (_read1_path, _read2_path) -def get_csv_reader_from_path(filename, sep="\t"): +def get_csv_reader_from_path(filename: str, sep: str = "\t") -> csv.reader: """ Returns a csv_reader object for a file weather it's a flat file or compressed. @@ -116,7 +118,12 @@ def get_csv_reader_from_path(filename, sep="\t"): def write_to_files( - sparse_matrix, filtered_cells, ordered_tags, data_type, outfolder, translation_dict + sparse_matrix: scipy.sparse.base.spmatrix, + filtered_cells: set, + ordered_tags: dict, + data_type: str, + outfolder: str, + translation_dict: dict, ): """Write the umi and read sparse matrices to file in gzipped mtx format. @@ -158,7 +165,13 @@ def write_to_files( os.remove(os.path.join(prefix, "matrix.mtx")) -def write_dense(sparse_matrix, ordered_tags, columns, outfolder, filename): +def write_dense( + sparse_matrix: scipy.sparse.base.spmatrix, + ordered_tags: dict, + columns: set, + outfolder: str, + filename: str, +): """ Writes a dense matrix in a csv format @@ -178,7 +191,9 @@ def write_dense(sparse_matrix, ordered_tags, columns, outfolder, filename): pandas_dense.to_csv(os.path.join(outfolder, filename), sep="\t") -def write_unmapped(merged_no_match, top_unknowns, outfolder, filename): +def write_unmapped( + merged_no_match: Counter, top_unknowns: int, outfolder: str, filename: str +): """ Writes a list of top unmapped sequences @@ -214,18 +229,18 @@ def load_report_template() -> dict: def create_report( - total_reads, - no_match, - version, + total_reads: int, + no_match: Counter, + version: str, start_time, - umis_corrected, - bcs_corrected, + umis_corrected: int, + bcs_corrected: int, bad_cells, - r1_too_short, - r2_too_short, + r1_too_short: int, + r2_too_short: int, args, chemistry_def, - maximum_distance, + maximum_distance: int, ): """ Creates a report with details about the run in a yaml format. @@ -255,8 +270,8 @@ def create_report( report_data["Percentage mapped"] = mapped_perc report_data["Percentage unmapped"] = unmapped_perc report_data["Percentage too short"] = too_short_perc - report_data["Percentage too short"]["r1_too_short"] = r1_too_short - report_data["Percentage too short"]["r2_too_short"] = r2_too_short + report_data["r1_too_short"] = r1_too_short + report_data["r2_too_short"] = r2_too_short report_data["Uncorrected cells"] = len(bad_cells) report_data["Correction"]["Cell barcodes collapsing threshold"] = args.bc_threshold report_data["Correction"]["Cell barcodes corrected"] = bcs_corrected @@ -278,7 +293,7 @@ def create_report( ] = chemistry_def.umi_barcode_end report_data["Expected cells"] = args.expected_cells report_data["Tags max errors"] = maximum_distance - report_data["Start trim"] = chemistry_def.R2_trim_start + report_data["Start trim"] = chemistry_def.r2_trim_start with open( os.path.join(args.outfolder, "run_report.yaml"), "w", encoding="utf-8" @@ -373,8 +388,8 @@ def write_chunks_to_disk( ] read2_sliced = read2[ - chemistry_def.R2_trim_start : ( - r2_min_length + chemistry_def.R2_trim_start + chemistry_def.r2_trim_start : ( + r2_min_length + chemistry_def.r2_trim_start ) ] chunked_file_object.write( diff --git a/cite_seq_count/templates/report.json b/cite_seq_count/templates/report.json index 616fbc7..3824014 100644 --- a/cite_seq_count/templates/report.json +++ b/cite_seq_count/templates/report.json @@ -2,32 +2,32 @@ "Date": "", "Running time": "", "CITE-seq-Count Version": "", - "Reads processed": "", - "Percentage mapped": "", - "Percentage unmapped": "", - "Percentage too short": "", - "r1_too_short": "", - "r2_too_short": "", - "Uncorrected cells": "", + "Reads processed": 0, + "Percentage mapped": 0, + "Percentage unmapped": 0, + "Percentage too short": 0, + "r1_too_short": 0, + "r2_too_short": 0, + "Uncorrected cells": 0, "Correction": { - "Cell barcodes collapsing threshold": "", - "Cell barcodes corrected": "", - "UMI collapsing threshold": "", - "UMIs corrected": "" + "Cell barcodes collapsing threshold": 0, + "Cell barcodes corrected": 0, + "UMI collapsing threshold": 0, + "UMIs corrected": 0 }, "Run parameters": { "Read1_paths": "", - "Read2_paths": "" - }, - "Cell barcode": { - "First position": "", - "Last position": "" - }, - "UMI barcode": { - "First position": "", - "Last position": "" - }, - "Expected cells": "", - "Tags max errors": "", - "Start trim": "" + "Read2_paths": "", + "Cell barcode": { + "First position": 0, + "Last position": 0 + }, + "UMI barcode": { + "First position": 0, + "Last position": 0 + }, + "Expected cells": 0, + "Tags max errors": 0, + "Start trim": 0 + } } \ No newline at end of file From 9195df971d08574a32ab7bf6f04baf13395a0f88 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 16 Oct 2022 12:45:33 +0200 Subject: [PATCH 59/77] Fix: Fix testing for preprocessing --- cite_seq_count/__main__.py | 10 +-- cite_seq_count/io.py | 4 +- cite_seq_count/preprocessing.py | 86 ++++++++----------- setup.py | 1 + .../tags/fail/missing_entry_sequence.csv | 10 +++ tests/test_io.py | 4 +- tests/test_preprocessing.py | 20 +++-- 7 files changed, 70 insertions(+), 65 deletions(-) create mode 100644 tests/test_data/tags/fail/missing_entry_sequence.csv diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index fddfd54..aac2de5 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -23,15 +23,11 @@ def main(): # Check a few path before doing anything if not os.access(args.temp_path, os.W_OK): sys.exit( - "Temp folder: {} is not writable. Please check permissions and/or change temp folder.".format( - args.temp_path - ) + f"Temp folder: {args.temp_path} is not writable. Please check permissions and/or change temp folder." ) if not os.access(os.path.dirname(os.path.abspath(args.outfolder)), os.W_OK): sys.exit( - "Output folder: {} is not writable. Please check permissions and/or change output folder.".format( - args.outfolder - ) + f"Output folder: {args.outfolder} is not writable. Please check permissions and/or change output folder." ) # Get chemistry defs @@ -75,7 +71,7 @@ def main(): maximum_distance=maximum_distance, ) # Map the data - (final_results, umis_per_cell, reads_per_cell, merged_no_match) = mapping.map_data( + (final_results, umis_per_cell, _, merged_no_match) = mapping.map_data( input_queue=input_queue, unmapped_id=len(ordered_tags), args=args ) diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 536a001..ee33118 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -118,7 +118,7 @@ def get_csv_reader_from_path(filename: str, sep: str = "\t") -> csv.reader: def write_to_files( - sparse_matrix: scipy.sparse.base.spmatrix, + sparse_matrix: scipy.sparse.coo_matrix, filtered_cells: set, ordered_tags: dict, data_type: str, @@ -166,7 +166,7 @@ def write_to_files( def write_dense( - sparse_matrix: scipy.sparse.base.spmatrix, + sparse_matrix: scipy.sparse.coo_matrix, ordered_tags: dict, columns: set, outfolder: str, diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 30bd778..668374b 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -1,17 +1,19 @@ +"""Sets of functions to preprocess the data""" + import csv import gzip import sys +from collections import namedtuple +from itertools import combinations, islice import regex import Levenshtein -import umi_tools.whitelist_methods as whitelist_methods +from pandas import read_csv +from cite_seq_count.io import get_n_lines -from cite_seq_count.io import get_csv_reader_from_path, get_n_lines -from collections import namedtuple -from itertools import combinations -from itertools import islice - -from pandas import read_csv +REQUIRED_TAGS_HEADER = ["sequence", "feature_name"] +REQUIRED_TRANSLATION_HEADER = ["reference", "translation"] +STRIP_CHARS = '"0123456789- \t\n' def parse_filtered_list_csv(filename, barcode_length): @@ -25,22 +27,19 @@ def parse_filtered_list_csv(filename, barcode_length): Returns: set: A set of barcodes """ - STRIP_CHARS = '"0123456789- \t\n' barcodes_pd = read_csv(filename) barcodes = set(barcodes_pd.iloc[:, 0]) out_set = set() - barcode_pattern = regex.compile(fr"^[ATGC]{{{barcode_length}}}") + barcode_pattern = regex.compile(rf"^[ATGC]{{{barcode_length}}}") for barcode in barcodes: checked_barcode = barcode.strip(STRIP_CHARS) if barcode_pattern.match(checked_barcode): out_set.add(checked_barcode) else: sys.exit( - "Only ATGC barcodes are accepted in the filtered list. Please delete entry {}".format( - checked_barcode - ) + f"Only ATGC barcodes are accepted in the filtered list. Please delete entry {checked_barcode}" ) return out_set @@ -60,23 +59,19 @@ def parse_cell_list_csv(filename, barcode_length): set: The set of white-listed barcodes. """ - STRIP_CHARS = '"0123456789- \t\n' - REQUIRED_HEADER = ["reference", "translation"] - data = read_csv(filename, dtype={"reference": str, "translation": str}) if data.shape[1] != 2: print(data.head()) sys.exit( "Your translation file only holds 1 column or is tab delimited instead of csv." ) - barcode_pattern = regex.compile(fr"^[ATGC]{{{barcode_length}}}") + barcode_pattern = regex.compile(rf"^[ATGC]{{{barcode_length}}}") header = data.columns - set_dif = set(REQUIRED_HEADER) - set(header) + set_dif = set(REQUIRED_TRANSLATION_HEADER) - set(header) if len(set_dif) != 0: - raise SystemExit( - "The header is missing {}. Exiting".format(",".join(list(set_dif))) - ) + set_diff_string = ",".join(list(set_dif)) + raise SystemExit(f"The header is missing {set_diff_string}. Exiting") # Prepare and validate data @@ -85,15 +80,11 @@ def parse_cell_list_csv(filename, barcode_length): if any(data["reference"].map(lambda x: not barcode_pattern.match(x))): sys.exit( - "Barcode(s) in reference column don't match [ATGC] or a length of {}. Please check.".format( - barcode_length - ) + f"Barcode(s) in reference column don't match [ATGC] or a length of {barcode_length}. Please check." ) if any(data["translation"].map(lambda x: not barcode_pattern.match(x))): sys.exit( - "Barcode(s) in translation column don't match [ATGC] or a length of {}. Please check.".format( - barcode_length - ) + f"Barcode(s) in translation column don't match [ATGC] or a length of {barcode_length}. Please check." ) translation_dict = dict(zip(data.translation, data.reference)) @@ -119,33 +110,34 @@ def parse_tags_csv(file_name): dict: A dictionary using sequences as keys and feature names as values. """ - REQUIRED_HEADER = ["sequence", "feature_name"] atgc_test = regex.compile("^[ATGC]{1,}$") try: - with open(file_name) as csvfile: + with open(file_name, mode="r", encoding="utf-8") as csvfile: csv_reader = csv.reader(csvfile) - except Exception as e: - sys.exit(e) - with open(file_name) as csvfile: + except IOError: + sys.exit(f"Cannot read file {file_name}") + with open(file_name, mode="r", encoding="utf-8") as csvfile: csv_reader = csv.reader(csvfile) tags = {} header = next(csv_reader) - set_dif = set(REQUIRED_HEADER) - set(header) + set_dif = set(REQUIRED_TAGS_HEADER) - set(header) if len(set_dif) != 0: - raise SystemExit( - "The header is missing {}. Exiting".format(",".join(list(set_dif))) - ) + set_diff_string = ",".join(list(set_dif)) + raise SystemExit(f"The header is missing {set_diff_string}. Exiting") sequence_id = header.index("sequence") feature_id = header.index("feature_name") for i, row in enumerate(csv_reader): + # Allow for optional columns + if len(row) < len(REQUIRED_TAGS_HEADER): + raise SystemExit( + f"Row number: {i+1} is incomplete. Please check the csv Tags file." + ) sequence = row[sequence_id].strip() if not regex.match(atgc_test, sequence): raise SystemExit( - "Sequence {} on line {} is not only composed of ATGC. Exiting".format( - sequence, i - ) + f"Sequence {sequence} on line {i} is not only composed of ATGC. Exiting" ) tags[sequence] = row[feature_id].strip() return tags @@ -194,11 +186,11 @@ def check_tags(tags, maximum_distance): # Check if the distance is big enoughbetween tags offending_pairs = [] - for a, b in combinations(seq_list, 2): + for tag_a, tag_b in combinations(seq_list, 2): # pylint: disable=no-member - distance = Levenshtein.distance(a, b) + distance = Levenshtein.distance(tag_a, tag_b) if distance <= (maximum_distance - 1): - offending_pairs.append([a, b, distance]) + offending_pairs.append([tag_a, tag_b, distance]) # If offending pairs are found, print them all. if offending_pairs: print( @@ -208,11 +200,7 @@ def check_tags(tags, maximum_distance): "Offending case(s):\n" ) for pair in offending_pairs: - print( - "\t{tag1}\n\t{tag2}\n\tDistance = {distance}\n".format( - tag1=pair[0], tag2=pair[1], distance=pair[2] - ) - ) + print(f"\t{pair[0]}\n\t{pair[1]}\n\tDistance = {pair[2]}\n") sys.exit("Exiting the application.\n") return (tag_list, longest_tag_len) @@ -351,12 +339,12 @@ def pre_run_checks(read1_paths, chemistry_def, longest_tag_len, args): print(f"Detected {number_of_samples} pairs of files to run on.") if args.sliding_window: - R2_min_length = read2_lengths[0] + r2_min_length = read2_lengths[0] maximum_distance = 0 else: - R2_min_length = longest_tag_len + r2_min_length = longest_tag_len maximum_distance = args.max_error - return n_reads, R2_min_length, maximum_distance + return n_reads, r2_min_length, maximum_distance def get_filtered_list(args, chemistry, translation_dict): diff --git a/setup.py b/setup.py index 6c81cea..8f74576 100644 --- a/setup.py +++ b/setup.py @@ -32,6 +32,7 @@ "cython>=0.29.17", "pyyaml==6.0", "pooch==1.6.0", + "six==1.16.0", ], python_requires=">=3.8", data_files=[("report_template", ["templates/report.json"])], diff --git a/tests/test_data/tags/fail/missing_entry_sequence.csv b/tests/test_data/tags/fail/missing_entry_sequence.csv new file mode 100644 index 0000000..b9ebd06 --- /dev/null +++ b/tests/test_data/tags/fail/missing_entry_sequence.csv @@ -0,0 +1,10 @@ +sequence,feature_name +AGGACCATCCAA,CITE_LEN_12_1 +ACATGTTACCGT,CITE_LEN_12_2 +,CITE_LEN_12_3 +TCGATAATGCGAGTACAA,CITE_LEN_18_1 +GAGGCTGAGCTAGCTAGT,CITE_LEN_18_2 +GGCTGATGCTGACTGCTA,CITE_LEN_18_3 +TGTGACGTATTGCTAGCTAG,CITE_LEN_20_1 +ACTGTCTAACGGGTCAGTGC,CITE_LEN_20_2 +TATCACATCGGTGGATCCAT,CITE_LEN_20_3 \ No newline at end of file diff --git a/tests/test_io.py b/tests/test_io.py index d2084fc..f407e5b 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -80,7 +80,7 @@ def test_write_to_files_wo_translation(data, tmpdir): ) file_path = os.path.join(tmpdir, "without_translation", "umi_count/matrix.mtx.gz") with gzip.open(file_path, "rb") as mtx_file: - assert isinstance(scipy.io.mmread(mtx_file), scipy.sparse.coo.coo_matrix) + assert isinstance(scipy.io.mmread(mtx_file), scipy.sparse.coo_matrix) assert md5_sums[barcodes_path] == md5(barcodes_path) assert md5_sums[features_path] == md5(features_path) assert md5_sums[mtx_path] == md5(mtx_path) @@ -113,7 +113,7 @@ def test_write_to_files_with_translation(data, tmpdir): ) file_path = os.path.join(tmpdir, "with_translation", "umi_count/matrix.mtx.gz") with gzip.open(file_path, "rb") as mtx_file: - assert isinstance(scipy.io.mmread(mtx_file), scipy.sparse.coo.coo_matrix) + assert isinstance(scipy.io.mmread(mtx_file), scipy.sparse.coo_matrix) assert md5_sums[barcodes_path] == md5(barcodes_path) assert md5_sums[features_path] == md5(features_path) assert md5_sums[mtx_path] == md5(mtx_path) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 829776d..13463d6 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -1,13 +1,15 @@ +"""Test function preprocessing of the module""" + +import glob +from collections import namedtuple + import pytest -import io from cite_seq_count import preprocessing -import glob -from collections import namedtuple, OrderedDict -from itertools import islice @pytest.fixture def data(): + """Load up data for testing""" pytest.passing_csv = "tests/test_data/tags/pass/*.csv" pytest.failing_csv = "tests/test_data/tags/fail/*.csv" @@ -21,7 +23,9 @@ def data(): pytest.correct_tags_path = "tests/test_data/tags/pass/correct.csv" # Create some variables to compare to - pytest.correct_reference_translation_list = set(["ACTGTTTTATTGGCCT","TTCATCCTTTAGGGAT"]) + pytest.correct_reference_translation_list = set( + ["ACTGTTTTATTGGCCT", "TTCATCCTTTAGGGAT"] + ) pytest.correct_tags = { "AGGACCATCCAA": "CITE_LEN_12_1", "ACATGTTACCGT": "CITE_LEN_12_2", @@ -51,12 +55,18 @@ def data(): def test_csv_parser(data): + """Test the csv parser + + Args: + data (_type_): _description_ + """ passing_files = glob.glob(pytest.passing_csv) for file_path in passing_files: preprocessing.parse_tags_csv(file_path) with pytest.raises(SystemExit): failing_files = glob.glob(pytest.failing_csv) for file_path in failing_files: + print(file_path) preprocessing.parse_tags_csv(file_path) From 1548e32361f9ea751a868cc76e5747491c9d7a20 Mon Sep 17 00:00:00 2001 From: Patrick Roelli Date: Sun, 30 Oct 2022 13:04:07 +0100 Subject: [PATCH 60/77] Fix: formatting --- cite_seq_count/mapping.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py index 965a3cb..250a7b6 100644 --- a/cite_seq_count/mapping.py +++ b/cite_seq_count/mapping.py @@ -212,7 +212,5 @@ def check_unmapped(no_match, too_short, total_reads, start_trim): sum_unmapped = sum(no_match.values()) + too_short if sum_unmapped / total_reads > float(0.99): sys.exit( - """More than 99% of your data is unmapped.\nPlease check that your --start_trim {} parameter is correct and that your tags file is properly formatted""".format( - start_trim - ) + f"More than 99% of your data is unmapped.\nPlease check that your --start_trim {start_trim} parameter is correct and that your tags file is properly formatted" ) From a65f62b4c68179347184a458fb21c7b0757d08b2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20R=C3=B6lli?= Date: Wed, 6 Sep 2023 10:05:00 +0200 Subject: [PATCH 61/77] Preprocessing: Code refactor with tests --- cite_seq_count/__main__.py | 1 + cite_seq_count/chemistry.py | 6 - cite_seq_count/io.py | 18 +- cite_seq_count/mapping.py | 15 +- cite_seq_count/preprocessing.py | 244 ++++++++---------- setup.py | 4 +- .../reference_lists/pass/translation.csv | 3 +- tests/test_preprocessing.py | 82 +++--- 8 files changed, 179 insertions(+), 194 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index aac2de5..f9cd104 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -86,6 +86,7 @@ def main(): # Remove temp chunks for file_path in temp_files: os.remove(file_path) + # Check that we have a filtered cell list to work on filtered_cells = processing.check_filtered_cells( filtered_cells=filtered_cells, diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index c0d5128..ecfc89f 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -1,14 +1,8 @@ """This module is holding code for all the remote fetching on the chemistries database.""" -import requests -import sys import os -import gzip -import io import pooch -import csv import json -from collections import namedtuple from dataclasses import dataclass diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index ee33118..dcee05a 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -12,6 +12,9 @@ from collections import namedtuple, Counter from itertools import islice from typing import Tuple +from pathlib import Path +from os import access, R_OK + import scipy import pkg_resources @@ -164,6 +167,19 @@ def write_to_files( shutil.copyfileobj(mtx_in, mtx_gz) os.remove(os.path.join(prefix, "matrix.mtx")) +def check_file(file_str:str) -> Path: + """Check that a file exists and is readable + + Args: + file_str (str): Path to the file as a string + + Returns: + Path: Path to the file + """ + file_path = Path(file_str) + if file_path.exists and access(file_path, R_OK): + return file_path + def write_dense( sparse_matrix: scipy.sparse.coo_matrix, @@ -187,7 +203,7 @@ def write_dense( index = [] for tag in ordered_tags: index.append(tag.name) - pandas_dense = pd.DataFrame(sparse_matrix.todense(), columns=columns, index=index) + pandas_dense = pd.DataFrame(sparse_matrix.todense(), columns=list(columns), index=index) pandas_dense.to_csv(os.path.join(outfolder, filename), sep="\t") diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py index 250a7b6..8c9196b 100644 --- a/cite_seq_count/mapping.py +++ b/cite_seq_count/mapping.py @@ -6,7 +6,7 @@ import os from collections import Counter, defaultdict - +import numpy as np import Levenshtein # pylint: disable=no-name-in-module @@ -69,6 +69,10 @@ def map_data(input_queue, unmapped_id, args): return final_results, umis_per_cell, reads_per_cell, merged_no_match +def fast_dict_getter(seq, tags, unmapped_id): + return tags.get(seq, unmapped_id) + + def find_best_match(tag_seq, tags, maximum_distance): """ @@ -145,8 +149,11 @@ def map_reads(mapping_input): results = {} no_match = Counter() n_reads = 1 - - unmapped_id = len(tags) + new_tags = {} + vec_dict_mapper = np.vectorize(fast_dict_getter) + for i in tags: + new_tags[i.sequence] = i.id + unmapped_id = len(tags) # Progress info current_time = time.time() with open(filename, encoding="utf-8") as input_file: @@ -174,7 +181,7 @@ def map_reads(mapping_input): if sliding_window: best_match = find_best_match_shift(read2, tags) else: - best_match = find_best_match(read2, tags, maximum_distance) + best_match = fast_dict_getter(read2, new_tags, len(tags)) results[cell_barcode][best_match][UMI] += 1 diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 668374b..0feec91 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -1,51 +1,26 @@ """Sets of functions to preprocess the data""" -import csv + import gzip import sys -from collections import namedtuple from itertools import combinations, islice -import regex import Levenshtein -from pandas import read_csv -from cite_seq_count.io import get_n_lines +from cite_seq_count.io import get_n_lines, check_file +import polars as pl +from pandas import read_csv + REQUIRED_TAGS_HEADER = ["sequence", "feature_name"] -REQUIRED_TRANSLATION_HEADER = ["reference", "translation"] +REQUIRED_CELLS_REF_HEADER = ["reference"] +OPTIONAL_CELLS_REF_HEADER = ["translation"] +FEATURE_NAME = "feature_name" +SEQUENCE = "sequence" +REQUIRED_TAGS_HEADER = [FEATURE_NAME, SEQUENCE] STRIP_CHARS = '"0123456789- \t\n' -def parse_filtered_list_csv(filename, barcode_length): - """ - Reads in a one column, no header list of barcodes and returns a set. - - Args: - filename(str): file path - barcode_length(int): Barcode expected length - - Returns: - set: A set of barcodes - """ - barcodes_pd = read_csv(filename) - - barcodes = set(barcodes_pd.iloc[:, 0]) - - out_set = set() - barcode_pattern = regex.compile(rf"^[ATGC]{{{barcode_length}}}") - for barcode in barcodes: - checked_barcode = barcode.strip(STRIP_CHARS) - if barcode_pattern.match(checked_barcode): - out_set.add(checked_barcode) - else: - sys.exit( - f"Only ATGC barcodes are accepted in the filtered list. Please delete entry {checked_barcode}" - ) - - return out_set - - -def parse_cell_list_csv(filename, barcode_length): +def parse_cell_list_csv(filename: str, barcode_length: int) -> pl.DataFrame: """Reads white-listed barcodes from a CSV file. The function accepts plain barcodes or even 10X style barcodes with the @@ -59,39 +34,52 @@ def parse_cell_list_csv(filename, barcode_length): set: The set of white-listed barcodes. """ - data = read_csv(filename, dtype={"reference": str, "translation": str}) - if data.shape[1] != 2: - print(data.head()) - sys.exit( - "Your translation file only holds 1 column or is tab delimited instead of csv." - ) - barcode_pattern = regex.compile(rf"^[ATGC]{{{barcode_length}}}") + file_path = check_file(filename) + cells_pl = pl.read_csv(file_path) + barcode_pattern = rf"^[ATGC]{{{barcode_length}}}" - header = data.columns - set_dif = set(REQUIRED_TRANSLATION_HEADER) - set(header) + header = cells_pl.columns + set_dif = set(REQUIRED_CELLS_REF_HEADER) - set(header) if len(set_dif) != 0: set_diff_string = ",".join(list(set_dif)) raise SystemExit(f"The header is missing {set_diff_string}. Exiting") + if OPTIONAL_CELLS_REF_HEADER in header: + with_translation = True + else: + with_translation = False + # Prepare and validate cells_pl + if with_translation: + cells_pl = cells_pl.with_columns( + reference=pl.col("reference").str.strip(STRIP_CHARS), + translation=pl.col("translation").str.strip(STRIP_CHARS), + ) - # Prepare and validate data - - data["reference"] = data["reference"].map(lambda x: x.rstrip(STRIP_CHARS)) - data["translation"] = data["translation"].map(lambda x: x.rstrip(STRIP_CHARS)) - - if any(data["reference"].map(lambda x: not barcode_pattern.match(x))): - sys.exit( - f"Barcode(s) in reference column don't match [ATGC] or a length of {barcode_length}. Please check." + else: + cells_pl = cells_pl.with_columns( + reference=pl.col("reference").str.strip(STRIP_CHARS), ) - if any(data["translation"].map(lambda x: not barcode_pattern.match(x))): - sys.exit( - f"Barcode(s) in translation column don't match [ATGC] or a length of {barcode_length}. Please check." + + check_sequence_pattern( + df=cells_pl, + pattern=barcode_pattern, + column_name="reference", + file_type="Cell reference", + expected_pattern="ATGC", + ) + + if with_translation: + check_sequence_pattern( + df=cells_pl, + pattern=barcode_pattern, + column_name="translation", + file_type="Cell reference", + expected_pattern="ATGC", ) - translation_dict = dict(zip(data.translation, data.reference)) - return translation_dict + return cells_pl -def parse_tags_csv(file_name): +def parse_tags_csv(file_name: str) -> pl.DataFrame: """Reads the TAGs from a CSV file. Checks that the header contains necessary strings and if sequences are made of ATGC @@ -104,46 +92,75 @@ def parse_tags_csv(file_name): TTCCGCCTCTCTTTG,Hashtag_3 Args: - file_name (file): TAGs file name. + file_name (str): file path as a tring Returns: - dict: A dictionary using sequences as keys and feature names as values. - + pl.DataFrame: polars dataframe with the csv content """ - atgc_test = regex.compile("^[ATGC]{1,}$") - - try: - with open(file_name, mode="r", encoding="utf-8") as csvfile: - csv_reader = csv.reader(csvfile) - except IOError: - sys.exit(f"Cannot read file {file_name}") - with open(file_name, mode="r", encoding="utf-8") as csvfile: - csv_reader = csv.reader(csvfile) - tags = {} - header = next(csv_reader) - set_dif = set(REQUIRED_TAGS_HEADER) - set(header) - if len(set_dif) != 0: - set_diff_string = ",".join(list(set_dif)) - raise SystemExit(f"The header is missing {set_diff_string}. Exiting") - sequence_id = header.index("sequence") - feature_id = header.index("feature_name") - for i, row in enumerate(csv_reader): - # Allow for optional columns - if len(row) < len(REQUIRED_TAGS_HEADER): - raise SystemExit( - f"Row number: {i+1} is incomplete. Please check the csv Tags file." - ) - sequence = row[sequence_id].strip() - if not regex.match(atgc_test, sequence): - raise SystemExit( - f"Sequence {sequence} on line {i} is not only composed of ATGC. Exiting" - ) - tags[sequence] = row[feature_id].strip() - return tags + atgc_test = "^[ATGC]{1,}$" + + file_path = check_file(file_str=file_name) + + data_pl = pl.read_csv(file_path) + file_header = set(data_pl.columns) + set_diff = set(REQUIRED_TAGS_HEADER).difference(file_header) + if len(set_diff) != 0: + set_diff_str = " AND ".join(set_diff) + raise SystemExit( + f"The header of the tags file at {file_path} is missing the following header(s) {set_diff_str}" + ) + for column in REQUIRED_TAGS_HEADER: + if not data_pl.filter(pl.col([column]) is None).is_empty(): + raise SystemExit( + f"Column {column} is missing a value. Please fix the CSV file." + ) + check_sequence_pattern( + df=data_pl, + pattern=atgc_test, + column_name=SEQUENCE, + file_type="tags", + expected_pattern="ATGC", + ) + + return data_pl + + +def check_sequence_pattern( + df: pl.DataFrame, + pattern: str, + column_name: str, + file_type: str, + expected_pattern: str, +) -> None: + """Check that a column of a polars df matches a given pattern and exit if not + + Args: + df (pl.DataFrame): Df holding the info to be tested + pattern (str): Regex pattern to be tested + column_name (str): Which column to test + file_type (str): File type for the error raised + expected_pattern (str): Human readable pattern to be raised + + Raises: + SystemExit: Exists if some patterns don't match + """ + regex_test = df.with_columns( + pl.col(column_name).str.contains(pattern).alias("regex") + ) + if not regex_test.select(pl.col("regex").all()).get_column("regex").item(): + sequences = ( + regex_test.filter(pl.col("regex") == False) + .get_column(column_name) + .to_list() + ) + sequences_str = "\n".join(sequences) + raise SystemExit( + f"Some sequences in the {file_type} file is not only composed of the proper pattern {expected_pattern}.\nHere are the sequences{sequences_str}" + ) -def check_tags(tags, maximum_distance): +def check_tags(tags_pl: pl.DataFrame, maximum_distance: int) -> int: """Evaluates the distance between the TAGs based on the `maximum distance` argument provided. @@ -157,36 +174,13 @@ def check_tags(tags, maximum_distance): between two TAGs. Returns: - list: An ordered list of namedtuples int: the length of the longest TAG """ - tag = namedtuple("tag", ["name", "sequence", "id"]) - longest_tag_len = 0 - seq_list = [] - tag_list = [] - for i, tag_seq in enumerate(sorted(tags, key=len, reverse=True)): - safe_name = sanitize_name(tags[tag_seq]) - - # for index, tag_name in enumerate(ordered_tags): - tag_list.append( - tag( - name=safe_name, - sequence=tag_seq, - id=i, - ) - ) - if len(tag_seq) > longest_tag_len: - longest_tag_len = len(tag_seq) - seq_list.append(tag_seq) - # tag_list.append(tag(name="unmapped", sequence="UNKNOWN", id=i + 1,)) - # If only one TAG is provided, then no distances to compare. - if len(tags) == 1: - return (tag_list, longest_tag_len) # Check if the distance is big enoughbetween tags offending_pairs = [] - for tag_a, tag_b in combinations(seq_list, 2): + for tag_a, tag_b in combinations(tags_pl["sequence"], 2): # pylint: disable=no-member distance = Levenshtein.distance(tag_a, tag_b) if distance <= (maximum_distance - 1): @@ -202,21 +196,9 @@ def check_tags(tags, maximum_distance): for pair in offending_pairs: print(f"\t{pair[0]}\n\t{pair[1]}\n\tDistance = {pair[2]}\n") sys.exit("Exiting the application.\n") + longest_tag_len = max(tags_pl["sequence"].str.n_chars()) - return (tag_list, longest_tag_len) - - -def sanitize_name(string): - """ - Transforms special characters that are not compatible with namedtuples - - Args: - string(str): a string from a feature name - - Returns: - str: modified string - """ - return string.replace("-", "_") + return longest_tag_len def get_read_length(filename): diff --git a/setup.py b/setup.py index 8f74576..58f6cec 100644 --- a/setup.py +++ b/setup.py @@ -24,7 +24,7 @@ "python-levenshtein>=0.12.0", "scipy>=1.1.0", "multiprocess>=0.70.6.1", - "umi_tools==1.1.1", + "umi_tools==1.1.4", "pytest>=6.0.0", "pytest-dependency==0.4.0", "pandas>=0.23.4", @@ -35,5 +35,5 @@ "six==1.16.0", ], python_requires=">=3.8", - data_files=[("report_template", ["templates/report.json"])], + package_data={"report_template": ["templates/*.json"]}, ) diff --git a/tests/test_data/reference_lists/pass/translation.csv b/tests/test_data/reference_lists/pass/translation.csv index 9c151a9..e41cdd7 100644 --- a/tests/test_data/reference_lists/pass/translation.csv +++ b/tests/test_data/reference_lists/pass/translation.csv @@ -1,3 +1,2 @@ reference,translation -ACTGTTTTATTGGCCT,ACTGTTTTATTGGCCT -TTCATAAGGTAGGGAT,TTCATCCTTTAGGGAT \ No newline at end of file +ACTGTTTTATTGGCCT,TTCATCCTTTAGGGAT \ No newline at end of file diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 13463d6..3b9f203 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -4,6 +4,8 @@ from collections import namedtuple import pytest +import polars as pl +from polars.testing import assert_frame_equal from cite_seq_count import preprocessing @@ -23,32 +25,35 @@ def data(): pytest.correct_tags_path = "tests/test_data/tags/pass/correct.csv" # Create some variables to compare to - pytest.correct_reference_translation_list = set( - ["ACTGTTTTATTGGCCT", "TTCATCCTTTAGGGAT"] + pytest.correct_reference_translation_list = pl.DataFrame( + {"reference": "ACTGTTTTATTGGCCT", "translation": "TTCATCCTTTAGGGAT"} + ) + pytest.correct_tag_pl = pl.DataFrame( + { + "feature_name": [ + "CITE_LEN_20_1", + "CITE_LEN_20_2", + "CITE_LEN_20_3", + "CITE_LEN_18_1", + "CITE_LEN_18_2", + "CITE_LEN_18_3", + "CITE_LEN_12_1", + "CITE_LEN_12_2", + "CITE_LEN_12_3", + ], + "sequence": [ + "TGTGACGTATTGCTAGCTAG", + "ACTGTCTAACGGGTCAGTGC", + "TATCACATCGGTGGATCCAT", + "TCGATAATGCGAGTACAA", + "GAGGCTGAGCTAGCTAGT", + "GGCTGATGCTGACTGCTA", + "AGGACCATCCAA", + "ACATGTTACCGT", + "AGCTTACTATCC", + ], + } ) - pytest.correct_tags = { - "AGGACCATCCAA": "CITE_LEN_12_1", - "ACATGTTACCGT": "CITE_LEN_12_2", - "AGCTTACTATCC": "CITE_LEN_12_3", - "TCGATAATGCGAGTACAA": "CITE_LEN_18_1", - "GAGGCTGAGCTAGCTAGT": "CITE_LEN_18_2", - "GGCTGATGCTGACTGCTA": "CITE_LEN_18_3", - "TGTGACGTATTGCTAGCTAG": "CITE_LEN_20_1", - "ACTGTCTAACGGGTCAGTGC": "CITE_LEN_20_2", - "TATCACATCGGTGGATCCAT": "CITE_LEN_20_3", - } - tag = namedtuple("tag", ["name", "sequence", "id"]) - pytest.correct_tags_tuple = [ - tag(name="CITE_LEN_20_1", sequence="TGTGACGTATTGCTAGCTAG", id=0), - tag(name="CITE_LEN_20_2", sequence="ACTGTCTAACGGGTCAGTGC", id=1), - tag(name="CITE_LEN_20_3", sequence="TATCACATCGGTGGATCCAT", id=2), - tag(name="CITE_LEN_18_1", sequence="TCGATAATGCGAGTACAA", id=3), - tag(name="CITE_LEN_18_2", sequence="GAGGCTGAGCTAGCTAGT", id=4), - tag(name="CITE_LEN_18_3", sequence="GGCTGATGCTGACTGCTA", id=5), - tag(name="CITE_LEN_12_1", sequence="AGGACCATCCAA", id=6), - tag(name="CITE_LEN_12_2", sequence="ACATGTTACCGT", id=7), - tag(name="CITE_LEN_12_3", sequence="AGCTTACTATCC", id=8), - ] pytest.barcode_slice = slice(0, 16) pytest.umi_slice = slice(16, 26) pytest.barcode_umi_length = 26 @@ -66,27 +71,16 @@ def test_csv_parser(data): with pytest.raises(SystemExit): failing_files = glob.glob(pytest.failing_csv) for file_path in failing_files: - print(file_path) preprocessing.parse_tags_csv(file_path) -def test_filtered_list_parser(data): - passing_files = glob.glob(pytest.passing_filtered_list_csv) - for file_path in passing_files: - preprocessing.parse_filtered_list_csv(file_path, barcode_length=16) - with pytest.raises(SystemExit): - failing_files = glob.glob(pytest.failing_filtered_list_csv) - for file_path in failing_files: - preprocessing.parse_filtered_list_csv(file_path, barcode_length=16) - - @pytest.mark.dependency() def test_parse_reference_list_csv(data): passing_files = glob.glob(pytest.passing_reference_list_csv) for file_path in passing_files: - assert preprocessing.parse_cell_list_csv(file_path, 16).keys() in ( - pytest.correct_reference_translation_list, - 1, + assert_frame_equal( + left=preprocessing.parse_cell_list_csv(file_path, 16), + right=pytest.correct_reference_translation_list, ) with pytest.raises(SystemExit): failing_files = glob.glob(pytest.failing_reference_list_csv) @@ -94,14 +88,6 @@ def test_parse_reference_list_csv(data): preprocessing.parse_cell_list_csv(file_path, 16) -@pytest.mark.dependency() -def test_parse_tags_csv(data): - tags = preprocessing.check_tags(pytest.correct_tags, 5)[0] - for i, tag in enumerate(tags): - assert tag == pytest.correct_tags_tuple[i] - - -@pytest.mark.dependency(depends=["test_parse_tags_csv"]) def test_check_distance_too_big_between_tags(data): with pytest.raises(SystemExit): - preprocessing.check_tags(pytest.correct_tags, 8) + preprocessing.check_tags(pytest.correct_tag_pl, 8) From 6c5f9e2faff06e4123449dd7c02c266b4c385634 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20R=C3=B6lli?= Date: Fri, 20 Oct 2023 23:03:21 +0200 Subject: [PATCH 62/77] Rewrote barcode correction --- cite_seq_count/__main__.py | 94 ++++---- cite_seq_count/argsparser.py | 43 ++-- cite_seq_count/chemistry.py | 43 ++-- cite_seq_count/io.py | 142 ++++++++++-- cite_seq_count/mapping.py | 376 ++++++++++++++------------------ cite_seq_count/preprocessing.py | 216 +++++++++++------- cite_seq_count/processing.py | 348 ++++++----------------------- tests/test_io.py | 8 +- tests/test_mapping.py | 6 +- tests/test_preprocessing.py | 4 +- 10 files changed, 572 insertions(+), 708 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index f9cd104..8a0d1bb 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python3.6 +#!/usr/bin/env python3.11 """ Author: Patrick Roelli """ @@ -23,28 +23,25 @@ def main(): # Check a few path before doing anything if not os.access(args.temp_path, os.W_OK): sys.exit( - f"Temp folder: {args.temp_path} is not writable. Please check permissions and/or change temp folder." + f"Temp folder: {args.temp_path} is not writable." + f"Please check permissions and/or change temp folder." ) if not os.access(os.path.dirname(os.path.abspath(args.outfolder)), os.W_OK): sys.exit( - f"Output folder: {args.outfolder} is not writable. Please check permissions and/or change output folder." + f"Output folder: {args.outfolder} is not writable." + f"Please check permissions and/or change output folder." ) # Get chemistry defs - (translation_dict, chemistry_def) = chemistry.setup_chemistry(args) + (barcode_reference, chemistry_def) = chemistry.setup_chemistry(args) # Load TAGs/ABs. - ab_map = preprocessing.parse_tags_csv(args.tags) - ordered_tags, longest_tag_len = preprocessing.check_tags(ab_map, args.max_error) + parsed_tags = preprocessing.parse_tags_csv(args.tags) + longest_tag_len = preprocessing.check_tags(parsed_tags, args.max_error) # Identify input file(s) read1_paths, read2_paths = io.get_read_paths(args.read1_path, args.read2_path) - # Check filtered input list - # If a translation is given, will return the translated version - filtered_cells = preprocessing.get_filtered_list( - args=args, chemistry=chemistry_def, translation_dict=translation_dict - ) # Checks before chunking. (n_reads, r2_min_length, maximum_distance) = preprocessing.pre_run_checks( read1_paths=read1_paths, @@ -52,60 +49,47 @@ def main(): longest_tag_len=longest_tag_len, args=args, ) - - # Chunk the data to disk before mapping ( - input_queue, - temp_files, + temp_file, r1_too_short, r2_too_short, total_reads, - ) = io.write_chunks_to_disk( + ) = io.write_mapping_input( args=args, read1_paths=read1_paths, read2_paths=read2_paths, r2_min_length=r2_min_length, - n_reads_per_chunk=n_reads, chemistry_def=chemistry_def, - ordered_tags=ordered_tags, - maximum_distance=maximum_distance, - ) - # Map the data - (final_results, umis_per_cell, _, merged_no_match) = mapping.map_data( - input_queue=input_queue, unmapped_id=len(ordered_tags), args=args ) - # Check if 99% of the reads are unmapped. - mapping.check_unmapped( - no_match=merged_no_match, - too_short=r1_too_short + r2_too_short, - total_reads=total_reads, - start_trim=chemistry_def.r2_trim_start, + mapped_reads = mapping.map_reads_hybrid( + mapping_input_file=temp_file, + parsed_tags=parsed_tags, + maximum_distance=maximum_distance, ) + # Remove temp file + os.remove(temp_file) + # Check if 99% of the reads are unmapped. + mapping.check_unmapped(mapped_reads=mapped_reads) - # Remove temp chunks - for file_path in temp_files: - os.remove(file_path) - - # Check that we have a filtered cell list to work on - filtered_cells = processing.check_filtered_cells( - filtered_cells=filtered_cells, - expected_cells=args.expected_cells, - umis_per_cell=umis_per_cell, + # Check filtered input list + # If a translation is given, will return the translated version + barcode_subset, enable_barcode_correction = preprocessing.get_barcode_subset( + args=args, + chemistry=chemistry_def, + barcode_reference=barcode_reference, + mapped_reads=mapped_reads, ) # Correct cell barcodes - if args.bc_threshold > 0: + if args.bc_threshold > 0 and enable_barcode_correction: ( - final_results, - umis_per_cell, + mapped_reads, bcs_corrected, - ) = processing.run_cell_barcode_correction( - final_results=final_results, - umis_per_cell=umis_per_cell, - ordered_tags=ordered_tags, - filtered_set=filtered_cells, - args=args, + ) = processing.correct_barcodes( + mapped_reads=mapped_reads, + barcode_subset=barcode_subset, + collapsing_threshold=args.bc_threshold, ) else: print("Skipping cell barcode correction") @@ -114,14 +98,14 @@ def main(): # Create sparse matrices for reads results read_results_matrix = processing.generate_sparse_matrices( final_results=final_results, - ordered_tags=ordered_tags, + parsed_tags=parsed_tags, filtered_cells=filtered_cells, ) # Write reads to file io.write_to_files( sparse_matrix=read_results_matrix, filtered_cells=filtered_cells, - ordered_tags=ordered_tags, + parsed_tags=parsed_tags, data_type="read", outfolder=args.outfolder, translation_dict=translation_dict, @@ -137,7 +121,7 @@ def main(): ) = processing.run_umi_correction( final_results=final_results, filtered_cells=filtered_cells, - unmapped_id=len(ordered_tags), + unmapped_id=len(parsed_tags), args=args, ) else: @@ -153,13 +137,13 @@ def main(): # Create sparse clustered cells matrix umi_clustered_matrix = processing.generate_sparse_matrices( final_results=final_results, - ordered_tags=ordered_tags, + parsed_tags=parsed_tags, filtered_cells=clustered_cells, ) # Write uncorrected cells to dense output io.write_dense( sparse_matrix=umi_clustered_matrix, - ordered_tags=ordered_tags, + parsed_tags=parsed_tags, columns=clustered_cells, outfolder=os.path.join(args.outfolder, "uncorrected_cells"), filename="dense_umis.tsv", @@ -167,7 +151,7 @@ def main(): # Generate the UMI count matrix umi_results_matrix = processing.generate_sparse_matrices( final_results=final_results, - ordered_tags=ordered_tags, + parsed_tags=parsed_tags, filtered_cells=filtered_cells, umi_counts=True, ) @@ -176,7 +160,7 @@ def main(): io.write_to_files( sparse_matrix=umi_results_matrix, filtered_cells=filtered_cells, - ordered_tags=ordered_tags, + parsed_tags=parsed_tags, data_type="umi", outfolder=args.outfolder, translation_dict=translation_dict, @@ -212,7 +196,7 @@ def main(): print("Writing dense format output") io.write_dense( sparse_matrix=umi_results_matrix, - ordered_tags=ordered_tags, + parsed_tags=parsed_tags, columns=filtered_cells, outfolder=args.outfolder, filename="dense_umis.tsv", diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 0b281c1..8b51455 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -62,8 +62,8 @@ def get_args() -> ArgumentParser: prog="CITE-seq-Count", formatter_class=RawTextHelpFormatter, description=( - "This package counts matching antibody tags from paired fastq " - "files. Version {}".format(get_package_version()) + f"This package counts matching antibody tags from paired fastq " + f"files. Version {get_package_version}" ), ) @@ -175,39 +175,44 @@ def get_args() -> ArgumentParser: ) # Cell filtering group. We ask for either number of expected cells or a pre-filtered list of cells. - cells_filtering = parser.add_mutually_exclusive_group(required=True) + barcodes = parser.add_mutually_exclusive_group(required=True) - cells_filtering.add_argument( - "-n_cells", - "--expected_cells", - dest="expected_cells", + barcodes.add_argument( + "-n_barcodes", + "--expected_barcodes", + dest="expected_barcodes", type=int, - help=("Number of expected cells from your run."), + help=("Number of expected barcodes from your run."), default=0, ) - cells_filtering.add_argument( - "-fl", + barcodes.add_argument( + "-fb", "-wl", - "--filtered_cells", - dest="filtered_cells", + "--filtered_barcodes", + dest="filtered_barcodes", type=str, - help=("A specific list of cells to look for."), + help=( + "A path to a specific list of barcodes to look for." + "\tExample:\n" + "\twhitelist\n" + "\tAAACCCAAGAAACACT\nAAACCCAAGAAACCAT\nAAACCCAAGAAACCCA\n" + ), default=False, ) if "--chemistry" not in sys.argv: barcodes.add_argument( - "-tl", - "--translation_list", - dest="translation_list", + "-br", + "--barcode_reference", + dest="barcode_reference", required=False, type=str, default=False, help=( - "A csv file containning a translation list of all potential barcodes\n\n" + "A csv file containning a barcode reference list of all potential barcodes\n\n" "\tExample:\n" - "whitelist,translation\n" - "\tAAACCCAAGAAACACT,AAACCCATCAAACACT\n\\AAACCCAAGAAACCAT,AAACCCATCAAACCAT\n\\AAACCCAAGAAACCCA,AAACCCATCAAACCCA\n\n" + "reference\n" + "\tAAACCCAAGAAACACT\nAAACCCAAGAAACCAT\nAAACCCAAGAAACCCA\n" ), ) diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index ecfc89f..5ccd09c 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -5,8 +5,9 @@ from dataclasses import dataclass - +from argparse import ArgumentParser from cite_seq_count import preprocessing +import polars as pl GLOBAL_LINK_RAW = "https://raw.githubusercontent.com/Hoohm/scg_lib_structs/10xv3_totalseq_b/chemistries/" GLOBAL_LINK_GITHUB = "https://github.com/Hoohm/scg_lib_structs/raw/10xv3_totalseq_b/" @@ -23,7 +24,8 @@ class Chemistry: umi_barcode_start: int umi_barcode_end: int r2_trim_start: int - translation_list_path: str + barcode_reference_path: str + holds_translation: bool DEFINITIONS_DB = pooch.create( @@ -41,7 +43,7 @@ class Chemistry: ) -def fetch_definitions(): +def fetch_definitions() -> dict: """ Load some sample gravity data to use in your docs. """ @@ -52,7 +54,7 @@ def fetch_definitions(): return json_data -def list_chemistries(all_chemistry_defs): +def list_chemistries(all_chemistry_defs: str) -> None: """ List all the available chemistries in the database Args: @@ -77,25 +79,25 @@ def list_chemistries(all_chemistry_defs): ) -def get_chemistry_definition(chemistry_short_name): +def get_chemistry_definition(chemistry_short_name: str) -> Chemistry: """ Fetches chemistry definitions from a remote definitions.json and returns the json. """ chemistry_defs = fetch_definitions()[chemistry_short_name] print(chemistry_defs) - if chemistry_defs["translation_list"]["path"] not in DEFINITIONS_DB.registry: + if chemistry_defs["barcode_reference"]["path"] not in DEFINITIONS_DB.registry: path = pooch.retrieve( url=os.path.join( GLOBAL_LINK_GITHUB, "chemistries", - chemistry_defs["translation_list"]["path"], + chemistry_defs["barcode_reference"]["path"], ), known_hash=None, - fname=chemistry_defs["translation_list"]["path"], + fname=chemistry_defs["barcode_reference"]["path"], path=DEFINITIONS_DB.abspath, ) else: - path = DEFINITIONS_DB.registry[chemistry_defs["translation_list"]["path"]] + path = DEFINITIONS_DB.registry[chemistry_defs["barcode_reference"]["path"]] chemistry_def = Chemistry( name=chemistry_short_name, cell_barcode_start=chemistry_defs["barcode_structure_indexes"]["cell_barcode"][ @@ -116,7 +118,7 @@ def get_chemistry_definition(chemistry_short_name): return chemistry_def -def create_chemistry_definition(args): +def create_chemistry_definition(args: ArgumentParser) -> Chemistry: chemistry_def = Chemistry( name="custom", cell_barcode_start=args.cb_first, @@ -124,28 +126,29 @@ def create_chemistry_definition(args): umi_barcode_start=args.umi_first, umi_barcode_end=args.umi_last, r2_trim_start=args.start_trim, - translation_list_path=args.translation_list, + barcode_reference_path=args.barcode_reference, + holds_translation=args.has_translation, ) return chemistry_def -def setup_chemistry(args): +def setup_chemistry(args: ArgumentParser) -> tuple[pl.DataFrame | None, Chemistry]: if args.chemistry_id: chemistry_def = get_chemistry_definition(args.chemistry_id) - translation_dict = preprocessing.parse_cell_list_csv( - filename=chemistry_def.translation_list_path, + barcode_reference = preprocessing.parse_barcode_reference( + filename=chemistry_def.barcode_reference_path, barcode_length=chemistry_def.cell_barcode_end - chemistry_def.cell_barcode_start + 1, ) else: chemistry_def = create_chemistry_definition(args) - if args.translation_list: - print("Loading translation_list") - translation_dict = preprocessing.parse_cell_list_csv( - filename=args.translation_list, + if args.barcode_reference: + print("Loading barcode reference") + barcode_reference = preprocessing.parse_barcode_reference( + filename=args.barcode_reference, barcode_length=args.cb_last - args.cb_first + 1, ) else: - translation_dict = False - return (translation_dict, chemistry_def) + barcode_reference = None + return (barcode_reference, chemistry_def) diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index dcee05a..5618806 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -10,6 +10,7 @@ import json from collections import namedtuple, Counter +from argparse import ArgumentParser from itertools import islice from typing import Tuple from pathlib import Path @@ -20,14 +21,15 @@ import pkg_resources import yaml import pandas as pd - +import polars as pl from scipy import io from cite_seq_count import secondsToText +from cite_seq_count.chemistry import Chemistry JSON_REPORT_PATH = pkg_resources.resource_filename(__name__, "templates/report.json") -def blocks(file, size: int = 65536): +def blocks(file: Path, size: int = 65536): """ A fast way of counting the lines of a large file. Ref: @@ -46,7 +48,7 @@ def blocks(file, size: int = 65536): yield partial_file -def get_n_lines(file_path: str) -> int: +def get_n_lines(file_path: Path) -> int: """ Determines how many lines have to be processed depending on options and number of available lines. @@ -123,7 +125,7 @@ def get_csv_reader_from_path(filename: str, sep: str = "\t") -> csv.reader: def write_to_files( sparse_matrix: scipy.sparse.coo_matrix, filtered_cells: set, - ordered_tags: dict, + parsed_tags: dict, data_type: str, outfolder: str, translation_dict: dict, @@ -133,7 +135,7 @@ def write_to_files( Args: sparse_matrix (dok_matrix): Results in a sparse matrix. filtered_cells (set): Set of cells that are selected for output. - ordered_tags (dict): Tags in order with indexes as values. + parsed_tags (dict): Tags in order with indexes as values. data_type (string): A string definning if the data is umi or read based. outfolder (string): Path to the output folder. """ @@ -158,7 +160,7 @@ def write_to_files( else: barcode_file.write(f"{barcode}\n".encode()) with gzip.open(os.path.join(prefix, "features.tsv.gz"), "wb") as feature_file: - for feature in ordered_tags: + for feature in parsed_tags: feature_file.write(f"{feature.sequence}\t{feature.name}\n".encode()) if data_type == "read": feature_file.write("{}\t{}\n".format("UNKNOWN", "unmapped").encode()) @@ -167,8 +169,9 @@ def write_to_files( shutil.copyfileobj(mtx_in, mtx_gz) os.remove(os.path.join(prefix, "matrix.mtx")) -def check_file(file_str:str) -> Path: - """Check that a file exists and is readable + +def check_file(file_str: str) -> Path: + """Check that a file exists and is readable. Args: file_str (str): Path to the file as a string @@ -183,7 +186,7 @@ def check_file(file_str:str) -> Path: def write_dense( sparse_matrix: scipy.sparse.coo_matrix, - ordered_tags: dict, + parsed_tags: dict, columns: set, outfolder: str, filename: str, @@ -201,9 +204,11 @@ def write_dense( prefix = os.path.join(outfolder) os.makedirs(prefix, exist_ok=True) index = [] - for tag in ordered_tags: + for tag in parsed_tags: index.append(tag.name) - pandas_dense = pd.DataFrame(sparse_matrix.todense(), columns=list(columns), index=index) + pandas_dense = pd.DataFrame( + sparse_matrix.todense(), columns=list(columns), index=index + ) pandas_dense.to_csv(os.path.join(outfolder, filename), sep="\t") @@ -254,8 +259,8 @@ def create_report( bad_cells, r1_too_short: int, r2_too_short: int, - args, - chemistry_def, + args: ArgumentParser, + chemistry_def: Chemistry, maximum_distance: int, ): """ @@ -318,13 +323,13 @@ def create_report( def write_chunks_to_disk( - args, - read1_paths, - read2_paths, - r2_min_length, - n_reads_per_chunk, - chemistry_def, - ordered_tags, + args: ArgumentParser, + read1_paths: list[Path], + read2_paths: list[Path], + r2_min_length: int, + n_reads_per_chunk: int, + chemistry_def: Chemistry, + parsed_tags: pl.DataFrame, maximum_distance, ): """ @@ -338,7 +343,7 @@ def write_chunks_to_disk( r2_min_length (int): Minimum length of read2 sequences. n_reads_per_chunk (int): How many reads per chunk. chemistry_def (namedtuple): Hols all the information about the chemistry definition. - ordered_tags (list): List of namedtuple tags. + parsed_tags (list): List of namedtuple tags. maximum_distance (int): Maximum hamming distance for mapping. """ mapping_input = namedtuple( @@ -377,7 +382,6 @@ def write_chunks_to_disk( reads_written = 0 for read1_path, read2_path in zip(read1_paths, read2_paths): - if enough_reads: break print(f"Reading reads from files: {read1_path}, {read2_path}") @@ -424,7 +428,7 @@ def write_chunks_to_disk( input_queue.append( mapping_input( filename=chunked_file_object.name, - tags=ordered_tags, + tags=parsed_tags, debug=args.debug, maximum_distance=maximum_distance, sliding_window=args.sliding_window, @@ -449,7 +453,7 @@ def write_chunks_to_disk( input_queue.append( mapping_input( filename=chunked_file_object.name, - tags=ordered_tags, + tags=parsed_tags, debug=args.debug, maximum_distance=maximum_distance, sliding_window=args.sliding_window, @@ -462,3 +466,93 @@ def write_chunks_to_disk( r2_too_short, total_reads, ) + + +def write_mapping_input( + args: ArgumentParser, + read1_paths: list[Path], + read2_paths: list[Path], + r2_min_length: int, + chemistry_def: Chemistry, +): + """ + Writes chunked files of reads to disk and prepares parallel + processing queue parameters. + + Args: + args(argparse): All parsed arguments. + read1_paths (list): List of R1 fastq.gz paths. + read2_paths (list): List of R2 fastq.gz paths. + r2_min_length (int): Minimum length of read2 sequences. + chemistry_def (namedtuple): Hols all the information about the chemistry definition. + parsed_tags (list): List of namedtuple tags. + maximum_distance (int): Maximum hamming distance for mapping. + """ + print("Writing chunks to disk") + + temp_path = os.path.abspath(args.temp_path) + r1_too_short = 0 + r2_too_short = 0 + total_reads = 0 + total_reads_written = 0 + + barcode_slice = slice( + chemistry_def.cell_barcode_start - 1, chemistry_def.cell_barcode_end + ) + umi_slice = slice( + chemistry_def.umi_barcode_start - 1, chemistry_def.umi_barcode_end + ) + + temp_file = tempfile.NamedTemporaryFile( + "w", dir=temp_path, suffix="_csc", delete=False + ) + reads_written = 0 + + for read1_path, read2_path in zip(read1_paths, read2_paths): + print(f"Reading reads from files: {read1_path}, {read2_path}") + with gzip.open(read1_path, "rt") as textfile1, gzip.open( + read2_path, "rt" + ) as textfile2: + secondlines = islice(zip(textfile1, textfile2), 1, None, 4) + + for read1, read2 in secondlines: + total_reads += 1 + + read1 = read1.strip() + if len(read1) < chemistry_def.umi_barcode_end: + r1_too_short += 1 + # The entire read is skipped + continue + if len(read2) < r2_min_length: + r2_too_short += 1 + # The entire read is skipped + continue + + read1_sliced = read1[ + chemistry_def.cell_barcode_start - 1 : chemistry_def.umi_barcode_end + ] + + read2_sliced = read2[ + chemistry_def.r2_trim_start : ( + r2_min_length + chemistry_def.r2_trim_start + ) + ] + temp_file.write( + "{},{},{}\n".format( + read1_sliced[barcode_slice], + read1_sliced[umi_slice], + read2_sliced, + ) + ) + + reads_written += 1 + total_reads_written += 1 + + temp_file.close() + + return ( + temp_file, + r1_too_short, + r2_too_short, + total_reads, + ) diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py index 8c9196b..b5c562f 100644 --- a/cite_seq_count/mapping.py +++ b/cite_seq_count/mapping.py @@ -1,223 +1,171 @@ """Mapping module. Holds all code related to mapping reads """ -import time -import csv -import sys -import os - -from collections import Counter, defaultdict -import numpy as np -import Levenshtein - -# pylint: disable=no-name-in-module -from multiprocess import Pool - -from cite_seq_count.processing import merge_results -from cite_seq_count import secondsToText - - -def map_data(input_queue, unmapped_id, args): - """ - Maps the data given an input_queue - - Args: - input_queue (list): List of parameters to run in parallel - args (argparse): List of arguments - - Returns: - final_results (dict): final dictionnary with results - umis_per_cell (Counter): Counter of UMIs per cell - reads_per_cell (Counter): Counter of reads per cell - merged_no_match (Counter): Counter of unmapped reads - """ - # Initialize the counts dicts that will be generated from each input fastq pair - final_results = defaultdict(lambda: defaultdict(Counter)) - umis_per_cell = Counter() - reads_per_cell = Counter() - merged_no_match = Counter() - - print("Started mapping") - parallel_results = [] - - if args.n_threads == 1: - mapped_reads = map_reads(input_queue[0]) - parallel_results.append([mapped_reads]) - else: - pool = Pool(processes=args.n_threads) - errors = [] - mapping = pool.map_async( - map_reads, - input_queue, - callback=parallel_results.append, - error_callback=errors.append, +from pathlib import Path +import polars as pl +from rapidfuzz import fuzz, process + +from cite_seq_count.preprocessing import ( + SEQUENCE_COLUMN, + R2_COLUMN, + FEATURE_NAME_COLUMN, + BARCODE_COLUMN, + UMI_COLUMN, + UNMAPPED_NAME, +) + + +def find_best_match_rapid(tag_seq, tags_list, maximum_distance): + choices = tags_list[SEQUENCE_COLUMN].to_list() + features = tags_list[FEATURE_NAME_COLUMN].to_list() + res = process.extractOne(choices=choices, query=tag_seq, scorer=fuzz.QRatio) + min_score = (len(tag_seq) - maximum_distance) / len(tag_seq) * 100 + if res[1] >= min_score: + return features[res[2]] + return UNMAPPED_NAME + + +def map_reads_hybrid( + mapping_input_file: Path, parsed_tags: pl.DataFrame, maximum_distance: int +) -> pl.DataFrame: + input_reads = pl.read_csv( + mapping_input_file, + has_header=False, + new_columns=[BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN], + ) + mapped_reads = input_reads.join( + parsed_tags, left_on=R2_COLUMN, right_on=SEQUENCE_COLUMN, how="left" + ).with_columns( + pl.when(pl.col(FEATURE_NAME_COLUMN).is_null()) + .then( + pl.col(R2_COLUMN) + .map_elements( + lambda x: find_best_match_rapid( + x, tags_list=parsed_tags, maximum_distance=maximum_distance + ) + ) + .alias(FEATURE_NAME_COLUMN) ) - mapping.wait() - - pool.close() - pool.join() - if len(errors) != 0: - for error in errors: - print(error) - - print("Merging results") - ( - final_results, - umis_per_cell, - reads_per_cell, - merged_no_match, - ) = merge_results(parallel_results=parallel_results[0], unmapped_id=unmapped_id) - - return final_results, umis_per_cell, reads_per_cell, merged_no_match - -def fast_dict_getter(seq, tags, unmapped_id): - return tags.get(seq, unmapped_id) - - - -def find_best_match(tag_seq, tags, maximum_distance): - """ - Find the best match from the list of tags. - - Compares the Levenshtein distance between tags and the trimmed sequences. - The tag and the sequence must have the same length. - If no matches found returns 'unmapped'. - We add 1 - Args: - tag_seq (string): Sequence from R2 already start trimmed - tags (dict): A dictionary with the TAGs as keys and TAG Names as values. - maximum_distance (int): Maximum distance given by the user. - - Returns: - best_match (string): The TAG name that will be used for counting. - """ - best_match = len(tags) - best_score = maximum_distance - for tag in tags: - # pylint: disable=no-member - score = Levenshtein.hamming(tag.sequence, tag_seq[: len(tag.sequence)]) - if score == 0: - # Best possible match - return tag.id - elif score <= best_score: - best_score = score - best_match = tag.id - return best_match - return best_match - - -def find_best_match_shift(tag_seq, tags): - """ - Find the best match from the list of tags with sliding window. - Only works with exact match. - Just checks if the string is in the sequence. - If no matches found returns 'unmapped'. - - Args: - tag_seq (string): Sequence from R2 already start trimmed - tags (dict): A dictionary with the TAGs as keys and TAG Names as values. - - Returns: - best_match (string): The TAG name that will be used for counting. - """ - best_match = "unmapped" - for tag in tags: - if tag.sequence in tag_seq: - return tag.name - return best_match - + .otherwise(pl.col(FEATURE_NAME_COLUMN)) + ) + return mapped_reads + + +def map_reads_pl_beta( + parsed_tags: pl.DataFrame, reads_input: pl.DataFrame, max_errors: int +) -> pl.DataFrame: + # First pass with exact matches + first_pass = reads_input.join( + parsed_tags, left_on="r2", right_on="sequence", how="left" + ) + first_pass_mapped = first_pass.filter(~pl.col("feature_name").is_null()) + parsed_tags_extended = parsed_tags.with_columns( + pl.col("sequence") + .str.replace_all("A", 0) + .str.replace_all("T", 1) + .str.replace_all("G", 2) + .str.replace_all("C", 3) + .alias("ref_num") + ).with_columns(pl.col("ref_num").str.parse_int(4)) + # Get unmapped reads + unmapped = first_pass.filter(pl.col("feature_name").is_null()) + unmapped_with_ns = unmapped.filter(pl.col("r2").str.contains("N")) + unmapped_without_ns = ( + unmapped.filter(~pl.col("r2").str.contains("N")) + .with_columns( + pl.col("r2") + .str.replace_all("A", 0) + .str.replace_all("T", 1) + .str.replace_all("G", 2) + .str.replace_all("C", 3) + .alias("tag_num") + ) + .with_columns(pl.col("tag_num").str.parse_int(4)) + ) + # Second pass based on closest + second_pass = ( + unmapped_without_ns.sort("tag_num") + .join_asof( + parsed_tags_extended.sort("ref_num"), left_on="tag_num", right_on="ref_num" + ) + .with_columns(diff=pl.col("ref_num") - pl.col("tag_num")) + .select(["barcode", "umi", "r2", "feature_name_right", "diff"]) + .rename({"feature_name_right": "feature_name"}) + .with_columns( + pl.when(pl.col("diff").abs() <= max_errors) + .then(pl.col("feature_name")) + .otherwise(None) + ) + ).drop("diff") + results = pl.concat([first_pass_mapped, second_pass, unmapped_with_ns]) + return results + + +def map_barcodes_pl( + cell_reference: pl.DataFrame, mapped_reads_input: pl.DataFrame, max_errors: int +) -> pl.DataFrame: + # First pass with exact matches + first_pass = mapped_reads_input.join( + cell_reference.with_row_count("barcode_id"), + left_on="barcode", + right_on="reference", + how="left", + ) + first_pass_mapped = first_pass.filter(~pl.col("barcode_id").is_null()) + ref_barcodes = first_pass_mapped.select("barcode").unique() + ref_cells_extended = ref_barcodes.with_columns( + pl.col("reference") + .str.replace_all("A", 0) + .str.replace_all("T", 1) + .str.replace_all("G", 2) + .str.replace_all("C", 3) + .alias("ref_barcode_num") + ).with_columns(pl.col("ref_barcode_num").str.parse_int(4)) + # Get unmapped reads + unmapped = first_pass.filter(pl.col("barcode_id").is_null()) + unmapped_with_ns = unmapped.filter(pl.col("barcode").str.contains("N")) + unmapped_without_ns = ( + unmapped.filter(~pl.col("barcode").str.contains("N")) + .with_columns( + pl.col("barcode") + .str.replace_all("A", 0) + .str.replace_all("T", 1) + .str.replace_all("G", 2) + .str.replace_all("C", 3) + .alias("barcode_num") + ) + .with_columns(pl.col("barcode_num").str.parse_int(4)) + ) + # Second pass based on closest + second_pass = ( + unmapped_without_ns.sort("barcode_num") + .join_asof( + ref_cells_extended.sort("ref_barcode_num"), + left_on="barcode_num", + right_on="ref_barcode_num", + ) + .with_columns(diff=pl.col("ref_barcode_num") - pl.col("barcode_num")) + .select(["barcode", "umi", "r2", "feature_name_right", "diff"]) + .rename({"feature_name_right": "feature_name"}) + .with_columns( + pl.when(pl.col("diff").abs() <= 1) + .then(pl.col("feature_name")) + .otherwise(None) + ) + ).drop("diff") + results = pl.concat([first_pass_mapped, second_pass, unmapped_with_ns]) + return results -def map_reads(mapping_input): - """Read through R1/R2 files and generate. - It reads both Read1 and Read2 files, creating a dict based on cell barcode. +def check_unmapped(mapped_reads: pl.DataFrame): + """_summary_ Args: - mapping_input (namedtuple): List of paramters to run in parallel. - filename (str): Path to the chunk file - tags (list): List of named tuples tags - debug (bool): Should debug information be shown or not - maximum_distance (int): Maximum distance given by the user - sliding_window (bool): A bool enabling a sliding window search + mapped_reads (pl.DataFrame): _description_ - Returns: - results (dict): A dict of dict of Counters with the mapping results. - no_match (Counter): A counter with unmapped sequences. """ - # Initiate values - (filename, tags, debug, maximum_distance, sliding_window) = mapping_input - print(f"Started mapping in child process {os.getpid()}") - results = {} - no_match = Counter() - n_reads = 1 - new_tags = {} - vec_dict_mapper = np.vectorize(fast_dict_getter) - for i in tags: - new_tags[i.sequence] = i.id - unmapped_id = len(tags) - # Progress info - current_time = time.time() - with open(filename, encoding="utf-8") as input_file: - reads = csv.reader(input_file) - for read in reads: - cell_barcode = read[0] - # This change in bytes is required by umi_tools for umi correction - UMI = bytes(read[1], "ascii") - read2 = read[2] - if n_reads % 1000000 == 0: - print( - "Processed 1,000,000 reads in {}. Total " - "reads: {:,} in child {}".format( - secondsToText.secondsToText(time.time() - current_time), - n_reads, - os.getpid(), - ) - ) - sys.stdout.flush() - current_time = time.time() - - if cell_barcode not in results: - results[cell_barcode] = defaultdict(Counter) - - if sliding_window: - best_match = find_best_match_shift(read2, tags) - else: - best_match = fast_dict_getter(read2, new_tags, len(tags)) - - results[cell_barcode][best_match][UMI] += 1 - - if best_match == unmapped_id: - no_match[read2] += 1 - - if debug: - print( - "cell_barcode:{}\tUMI:{}\ttag_seq:{}\n" - "cell barcode length:{}\tUMI length:{}\tTAG sequence length:{}\n" - "Best match is: {}\n".format( - cell_barcode, - UMI, - read2, - len(cell_barcode), - len(UMI), - len(read2), - tags[best_match].name, - ) - ) - sys.stdout.flush() - n_reads += 1 - print( - "Mapping done for process {}. Processed {:,} reads".format( - os.getpid(), n_reads - 1 - ) - ) - sys.stdout.flush() - - return (results, no_match) - - -def check_unmapped(no_match, too_short, total_reads, start_trim): - """Check if the number of unmapped is higher than 99%""" - sum_unmapped = sum(no_match.values()) + too_short - if sum_unmapped / total_reads > float(0.99): - sys.exit( - f"More than 99% of your data is unmapped.\nPlease check that your --start_trim {start_trim} parameter is correct and that your tags file is properly formatted" - ) + n_reads = mapped_reads.shape[0] + n_unmapped = mapped_reads.filter( + pl.col(FEATURE_NAME_COLUMN) == UNMAPPED_NAME + ).shape[0] + if n_reads / n_unmapped > 0.99: + SystemExit("Number of unmapped reads is more than 99%. Exiting") diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 0feec91..3b55a59 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -4,23 +4,43 @@ import gzip import sys from itertools import combinations, islice +from collections import Counter +from argparse import ArgumentParser +from pathlib import Path import Levenshtein +import polars as pl +import umi_tools.whitelist_methods as whitelist_method +from cite_seq_count.io import get_n_lines, check_file +from cite_seq_count.chemistry import Chemistry +s -from cite_seq_count.io import get_n_lines, check_file -import polars as pl -from pandas import read_csv -REQUIRED_TAGS_HEADER = ["sequence", "feature_name"] +# REQUIRED_TAGS_HEADER = ["sequence", "feature_name"] REQUIRED_CELLS_REF_HEADER = ["reference"] OPTIONAL_CELLS_REF_HEADER = ["translation"] -FEATURE_NAME = "feature_name" -SEQUENCE = "sequence" -REQUIRED_TAGS_HEADER = [FEATURE_NAME, SEQUENCE] +# Polars column names +# Tags input +FEATURE_NAME_COLUMN = "feature_name" +SEQUENCE_COLUMN = "sequence" +REQUIRED_TAGS_HEADER = [FEATURE_NAME_COLUMN, SEQUENCE_COLUMN] +# Reads input +BARCODE_COLUMN = "barcode" +CORRECTED_BARCODE_COLUMN = "corrected_barcode" +UMI_COLUMN = "umi" +R2_COLUMN = "r2" +# Barcode input +REFERENCE_COLUMN = "reference" +TRANSLATION_COLUMN = "translation" +WHITELIST_COLUMN = "whitelist" STRIP_CHARS = '"0123456789- \t\n' +UNMAPPED_NAME = "unmapped" -def parse_cell_list_csv(filename: str, barcode_length: int) -> pl.DataFrame: + +def parse_barcode_reference( + filename: str, barcode_length: int, required_header: str +) -> pl.DataFrame: """Reads white-listed barcodes from a CSV file. The function accepts plain barcodes or even 10X style barcodes with the @@ -35,11 +55,11 @@ def parse_cell_list_csv(filename: str, barcode_length: int) -> pl.DataFrame: """ file_path = check_file(filename) - cells_pl = pl.read_csv(file_path) + barcodes_pl = pl.read_csv(file_path.absolute()) barcode_pattern = rf"^[ATGC]{{{barcode_length}}}" - header = cells_pl.columns - set_dif = set(REQUIRED_CELLS_REF_HEADER) - set(header) + header = barcodes_pl.columns + set_dif = set(required_header) - set(header) if len(set_dif) != 0: set_diff_string = ",".join(list(set_dif)) raise SystemExit(f"The header is missing {set_diff_string}. Exiting") @@ -47,36 +67,36 @@ def parse_cell_list_csv(filename: str, barcode_length: int) -> pl.DataFrame: with_translation = True else: with_translation = False - # Prepare and validate cells_pl + # Prepare and validate barcodes_pl if with_translation: - cells_pl = cells_pl.with_columns( - reference=pl.col("reference").str.strip(STRIP_CHARS), - translation=pl.col("translation").str.strip(STRIP_CHARS), + barcodes_pl = barcodes_pl.with_columns( + reference=pl.col(REFERENCE_COLUMN).str.strip_chars(STRIP_CHARS), + translation=pl.col(TRANSLATION_COLUMN).str.strip_chars(STRIP_CHARS), ) else: - cells_pl = cells_pl.with_columns( - reference=pl.col("reference").str.strip(STRIP_CHARS), + barcodes_pl = barcodes_pl.with_columns( + reference=pl.col(REFERENCE_COLUMN).str.strip_chars(STRIP_CHARS), ) check_sequence_pattern( - df=cells_pl, + df=barcodes_pl, pattern=barcode_pattern, - column_name="reference", - file_type="Cell reference", + column_name=REFERENCE_COLUMN, + file_type="Barcode reference", expected_pattern="ATGC", ) if with_translation: check_sequence_pattern( - df=cells_pl, + df=barcodes_pl, pattern=barcode_pattern, - column_name="translation", - file_type="Cell reference", + column_name=TRANSLATION_COLUMN, + file_type="Barcode reference", expected_pattern="ATGC", ) - return cells_pl + return barcodes_pl def parse_tags_csv(file_name: str) -> pl.DataFrame: @@ -106,9 +126,12 @@ def parse_tags_csv(file_name: str) -> pl.DataFrame: file_header = set(data_pl.columns) set_diff = set(REQUIRED_TAGS_HEADER).difference(file_header) if len(set_diff) != 0: + print(set_diff) + print(len(set_diff)) set_diff_str = " AND ".join(set_diff) raise SystemExit( - f"The header of the tags file at {file_path} is missing the following header(s) {set_diff_str}" + f"The header of the tags file at {file_path}" + f"is missing the following header(s) {set_diff_str}" ) for column in REQUIRED_TAGS_HEADER: if not data_pl.filter(pl.col([column]) is None).is_empty(): @@ -118,7 +141,7 @@ def parse_tags_csv(file_name: str) -> pl.DataFrame: check_sequence_pattern( df=data_pl, pattern=atgc_test, - column_name=SEQUENCE, + column_name=SEQUENCE_COLUMN, file_type="tags", expected_pattern="ATGC", ) @@ -150,13 +173,15 @@ def check_sequence_pattern( ) if not regex_test.select(pl.col("regex").all()).get_column("regex").item(): sequences = ( - regex_test.filter(pl.col("regex") == False) + regex_test.filter(pl.col("regex") is False) .get_column(column_name) .to_list() ) sequences_str = "\n".join(sequences) raise SystemExit( - f"Some sequences in the {file_type} file is not only composed of the proper pattern {expected_pattern}.\nHere are the sequences{sequences_str}" + f"Some sequences in the {file_type} file is not only composed" + f"of the proper pattern {expected_pattern}.\n" + f"Here are the sequences{sequences_str}" ) @@ -177,10 +202,16 @@ def check_tags(tags_pl: pl.DataFrame, maximum_distance: int) -> int: int: the length of the longest TAG """ - - # Check if the distance is big enoughbetween tags + # TODO: Decide to keep or delete. + # # Check that all tags are the same length + # if tags_pl[SEQUENCE_COLUMN].str.lengths().unique().shape[0] != 1: + # SystemExit( + # "Tag sequences have different lengths. Version 2 can only run with one + # length. Please use an older version" + # ) + # Check if the distance is big enough between tags offending_pairs = [] - for tag_a, tag_b in combinations(tags_pl["sequence"], 2): + for tag_a, tag_b in combinations(tags_pl[SEQUENCE_COLUMN], 2): # pylint: disable=no-member distance = Levenshtein.distance(tag_a, tag_b) if distance <= (maximum_distance - 1): @@ -196,12 +227,12 @@ def check_tags(tags_pl: pl.DataFrame, maximum_distance: int) -> int: for pair in offending_pairs: print(f"\t{pair[0]}\n\t{pair[1]}\n\tDistance = {pair[2]}\n") sys.exit("Exiting the application.\n") - longest_tag_len = max(tags_pl["sequence"].str.n_chars()) + longest_tag_len = max(tags_pl[SEQUENCE_COLUMN].str.n_chars()) return longest_tag_len -def get_read_length(filename): +def get_read_length(filename: Path): """Check wether SEQUENCE lengths are consistent in the first 1000 reads from a FASTQ file and return the length. @@ -220,31 +251,16 @@ def get_read_length(filename): read_length = len(sequence.rstrip()) if temp_length != read_length: sys.exit( - "[ERROR] Sequence length in {} is not consistent. Please, trim all " - "sequences at the same length.\n" - "Exiting the application.\n".format(filename) + f"[ERROR] Sequence length in {filename} is not consistent." + f" Please, trim all sequences at the same length.\n" + f"Exiting the application.\n" ) return read_length -def translate_barcodes(cell_set, translation_dict): - """Translate a list of barcode using a mapping translation - Args: - cell_set (set): A set of barcodes - translation_dict (dict): A dict providing a simple key value translation - - Returns: - translated_barcodes (set): A set of translated barcodes - """ - - translated_barcodes = set() - for translated_barcode in translation_dict.keys(): - if translation_dict[translated_barcode] in cell_set: - translated_barcodes.add(translated_barcode) - return translated_barcodes - - -def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last): +def check_barcodes_lengths( + read1_length: int, cb_first: int, cb_last: int, umi_first: int, umi_last: int +): """Check Read1 length against CELL and UMI barcodes length. Args: @@ -266,15 +282,19 @@ def check_barcodes_lengths(read1_length, cb_first, cb_last, umi_first, umi_last) ) elif barcode_umi_length < read1_length: print( - "[WARNING] Read1 length is {}bp but you are using {}bp for Cell " - "and UMI barcodes combined.\nThis might lead to wrong cell " - "attribution and skewed umi counts.\n".format( - read1_length, barcode_umi_length - ) + f"[WARNING] Read1 length is {read1_length}bp" + f"but you are using {barcode_umi_length}bp for Cell " + f"and UMI barcodes combined.\nThis might lead to wrong cell " + f"attribution and skewed umi counts.\n" ) -def pre_run_checks(read1_paths, chemistry_def, longest_tag_len, args): +def pre_run_checks( + read1_paths: list[Path], + chemistry_def: dict, + longest_tag_len: int, + args: ArgumentParser, +): """Checks that the chemistry is properly set and defines how many reads to process Args: @@ -329,29 +349,63 @@ def pre_run_checks(read1_paths, chemistry_def, longest_tag_len, args): return n_reads, r2_min_length, maximum_distance -def get_filtered_list(args, chemistry, translation_dict): +def get_barcode_subset( + args: ArgumentParser, + chemistry: Chemistry, + barcode_reference: pl.DataFrame | None, + mapped_reads: pl.DataFrame, +): """ - Determines what mode to use for cell barcode correction. - Args: - args(argparse): All arguments - - Returns: - set if we have a filtered list - None if we want correction and we have not a list - False if we deactivation filtering + Generate the barcode list used for barcode correction and subsetting """ + enable_barcode_correction = True if args.filtered_cells: - filtered_set = parse_filtered_list_csv( - args.filtered_cells, - (chemistry.cell_barcode_end - chemistry.cell_barcode_start), + barcode_subset = parse_barcode_reference( + filename=args.filtered_cells, + barcode_length=(chemistry.cell_barcode_end - chemistry.cell_barcode_start), + required_header=WHITELIST_COLUMN, ) - # Do we need to translate the list? - if args.translation_list: - # get the translation - translated_set = translate_barcodes( - cell_set=filtered_set, translation_dict=translation_dict - ) - return translated_set - return filtered_set else: - return None + n_barcodes = args.expected_barcodes + if barcode_reference is not None: + barcode_subset = ( + mapped_reads.filter( + pl.col(BARCODE_COLUMN).str.is_in( + barcode_reference[REFERENCE_COLUMN] + ) + ) + .group_by(BARCODE_COLUMN) + .agg(pl.count()) + .sort("count", descending=True) + .head(n_barcodes * 1.2) + .drop("count") + .rename({SEQUENCE_COLUMN: WHITELIST_COLUMN}) + ) + else: + raw_barcodes_dict = ( + mapped_reads.filter( + (~pl.col(BARCODE_COLUMN).str.count_matches("N")) + & (pl.col(FEATURE_NAME_COLUMN) != UNMAPPED_NAME) + ) + .group_by(BARCODE_COLUMN) + .agg(pl.count()) + .sort("count", descending=True) + ).to_dict() + barcode_counter = Counter( + zip(raw_barcodes_dict[BARCODE_COLUMN], raw_barcodes_dict["count"]) + ) + true_barcodes = whitelist_method.getKneeEstimateDensity( + cell_barcode_counts=barcode_counter, expect_cells=n_barcodes + ) + barcode_subset = pl.DataFrame( + true_barcodes, schema={WHITELIST_COLUMN: pl.Utf8} + ) + + if n_barcodes > barcode_subset.shape[0]: + print( + f"Number of expected cells, {n_barcodes}, is higher " + f"than number of cells found {barcode_subset.shape[0]}.\nNot performing " + f"cell barcode correction" + ) + enable_barcode_correction = False + return barcode_subset, enable_barcode_correction diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 1f4a6fe..19038bf 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -1,11 +1,9 @@ -import sys import os import Levenshtein import pybktree +import polars as pl -from collections import Counter -from collections import defaultdict from collections import namedtuple # pylint: disable=no-name-in-module @@ -15,305 +13,83 @@ from numpy import int32 from scipy import sparse from umi_tools import network -import umi_tools.whitelist_methods as whitelist_methods -def merge_results(parallel_results, unmapped_id): - """Merge chunked results from parallel processing. - - Args: - parallel_results (list): List of dict with mapping results. - - Returns: - merged_results (dict): Results combined as a dict of dict of Counters - umis_per_cell (Counter): Total umis per cell as a Counter - reads_per_cell (Counter): Total reads per cell as a Counter - merged_no_match (Counter): Unmapped tags as a Counter - """ - merged_results = {} - merged_no_match = Counter() - umis_per_cell = Counter() - reads_per_cell = Counter() - for chunk in parallel_results: - mapped = chunk[0] - unmapped = chunk[1] - for cell_barcode in mapped: - if cell_barcode not in merged_results: - merged_results[cell_barcode] = defaultdict(Counter) - for TAG in mapped[cell_barcode]: - # We don't want to capture unmapped data in the umi counts - if TAG == unmapped_id: - continue - # Test the counter. Returns false if empty - if mapped[cell_barcode][TAG]: - for UMI in mapped[cell_barcode][TAG]: - merged_results[cell_barcode][TAG][UMI] += mapped[cell_barcode][ - TAG - ][UMI] - umis_per_cell[cell_barcode] += len(mapped[cell_barcode][TAG]) - reads_per_cell[cell_barcode] += mapped[cell_barcode][TAG][UMI] - merged_no_match.update(unmapped) - return merged_results, umis_per_cell, reads_per_cell, merged_no_match - +from cite_seq_count.preprocessing import ( + WHITELIST_COLUMN, + BARCODE_COLUMN, + CORRECTED_BARCODE_COLUMN, +) # Unit Barcode correction -def collapse_cells(true_to_false, umis_per_cell, final_results, ab_map): - """ - Collapses cell barcodes based on the mapping true_to_false +def find_original_barcode(barcode: str, barcode_tree: pybktree.BKTree, distance: int): + """Pare a BKtree to find the original barcode to correct to. Args: - true_to_false (dict): Mapping between the translation and the "mutated" barcodes. - umis_per_cell (Counter): Counter of number of umis per cell. - final_results (dict): Dict of dict of Counters with mapping results. - ab_map (dict): Dict of the TAGS. + barcode (str): barcode to be corrected + barcode_tree (pybktree.BKTree): Barcode whitelist BKTree + distance (int): Hamming distance to search for Returns: - umis_per_cell (Counter): Counter of number of umis per cell. - final_results (dict): Same as input but with corrected cell barcodes. - corrected_barcodes (int): How many cell barcodes have been corrected. + barcode(str): corrected barcode """ - print("Collapsing cell barcodes") - corrected_barcodes = 0 - for real_barcode in true_to_false: - # If the cell barcode is not in the results - # add it in. - if real_barcode not in final_results: - final_results[real_barcode] = defaultdict() - for TAG in ab_map: - final_results[real_barcode][TAG.id] = Counter() - for wrong_barcode in true_to_false[real_barcode]: - temp = final_results.pop(wrong_barcode) - - for TAG in temp.keys(): - if TAG in final_results[real_barcode]: - final_results[real_barcode][TAG].update(temp[TAG]) - else: - final_results[real_barcode][TAG] = temp[TAG] - corrected_barcodes += 1 - temp_umi_counts = umis_per_cell.pop(wrong_barcode) - # temp_read_counts = reads_per_cell.pop(wrong_barcode) - - umis_per_cell[real_barcode] += temp_umi_counts - # reads_per_cell[real_barcode] += temp_read_counts - - return (umis_per_cell, final_results, corrected_barcodes) - - -def correct_cells_no_translation_list( - final_results, - reads_per_cell, - umis_per_cell, - collapsing_threshold, - expected_cells, - ab_map, -): - """ - Corrects cell barcodes without a translation. - - Args: - final_results (dict): Dict of dict of Counters with mapping results. - umis_per_cell (Counter): Counter of number of umis per cell. - collapsing_threshold (int): Max distance between umis. - expected_cells (int): Number of expected cells. - ab_map (dict): Dict of the TAGS. - - Returns: - final_results (dict): Same as input but with corrected umis. - umis_per_cell (Counter): Counter of umis per cell after cell barcode correction - corrected_umis (int): How many umis have been corrected. - """ - print("Looking for a reference list") - _, true_to_false = whitelist_methods.getCellWhitelist( - knee_method="density", - cell_barcode_counts=reads_per_cell, - expect_cells=expected_cells, - cell_number=expected_cells, - error_correct_threshold=collapsing_threshold, - plotfile_prefix=False, - ) - if true_to_false is None: - print("Failed to find a good reference list. Will not correct cell barcodes") - corrected_barcodes = 0 - return (final_results, umis_per_cell, corrected_barcodes) - (umis_per_cell, final_results, corrected_barcodes) = collapse_cells( - true_to_false=true_to_false, - umis_per_cell=umis_per_cell, - final_results=final_results, - ab_map=ab_map, - ) - return (final_results, umis_per_cell, corrected_barcodes) - - -def correct_cells_filtered_set( - final_results, umis_per_cell, filtered_set, collapsing_threshold, ab_map + candidates = [ + white_cell for d, white_cell in barcode_tree.find(barcode, distance) if d > 0 + ] + if len(candidates) == 1: + barcode = candidates[0] + return barcode + + +def correct_barcodes( + mapped_reads: pl.DataFrame, + barcode_subset: pl.DataFrame, + collapsing_threshold: int, ): - """ - Corrects cell barcodes based on a given translation_list. + """Correct barcodes based on a given whitelist Args: - final_results (dict): Dict of dict of Counters with mapping results. - umis_per_cell (Counter): Counter of UMIs per cell. - translation_list (set): The translation_list translation given by the user. - collapsing_threshold (int): Max distance between umis. - ab_map (OrederedDict): Tags in an ordered dict. - + mapped_reads (pl.DataFrame): Mapped reads + barcode_subset (pl.DataFrame): Given whitelist + collapsing_threshold (int): Hamming distance to correct on Returns: - final_results (dict): Same as input but with corrected umis. - umis_per_cell (Counter): Updated UMI counts after correction. - corrected_barcodes (int): How many umis have been corrected. + mapped_reads_barcode_corrected (pl.DataFrame): Barcode corrected mapped reads + n_corrected_barcodes (int): Number of corrected barcodes """ print("Generating barcode tree from reference list") - # pylint: disable=no-member - barcode_tree = pybktree.BKTree(Levenshtein.hamming, filtered_set) - barcodes = set(umis_per_cell) - print("Selecting reference candidates") - print(f"Processing {len(barcodes):,} cell barcodes") - - # Run with one process - true_to_false = find_true_to_false_map( - barcode_tree=barcode_tree, - cell_barcodes=barcodes, - filtered_set=filtered_set, - collapsing_threshold=collapsing_threshold, + barcode_tree = pybktree.BKTree( + Levenshtein.hamming, barcode_subset[WHITELIST_COLUMN].to_list() ) - print("Collapsing wrong barcodes with original barcodes") - (umis_per_cell, final_results, corrected_barcodes) = collapse_cells( - true_to_false, umis_per_cell, final_results, ab_map - ) - return (final_results, umis_per_cell, corrected_barcodes) - - -def find_true_to_false_map( - barcode_tree, cell_barcodes, filtered_set, collapsing_threshold -): - """ - Creates a mapping between "fake" cell barcodes and their original true barcode. - - Args: - barcode_tree (BKTree): BKTree of all original cell barcodes. - cell_barcodes (List): Cell barcodes to go through. - filtered_set (dict): Dict of the filtered_set, the "true" cell barcodes. - collasping_threshold (int): How many mistakes to correct. - - Return: - true_to_false (defaultdict(list)): Contains the mapping between the fake and real barcodes. The key is the real one. - """ - true_to_false = defaultdict(list) - for cell_barcode in cell_barcodes: - if cell_barcode in filtered_set: - # if the barcode is already filtered_set, no need to add - continue - # get all members of filtered_set that are at distance of collapsing_threshold - candidates = [ - white_cell - for d, white_cell in barcode_tree.find(cell_barcode, collapsing_threshold) - if d > 0 - ] - if len(candidates) == 1: - white_cell_str = candidates[0] - true_to_false[white_cell_str].append(cell_barcode) - else: - # the cell doesnt match to any filtered_set barcode, - # hence we have to drop it - # (as it cannot be asscociated with any frequent barcode) - continue - return true_to_false - - -def run_cell_barcode_correction( - final_results, - umis_per_cell, - ordered_tags, - filtered_set, - args, -): - if args.expected_cells > len(filtered_set): - print( - "Number of expected cells, {}, is higher " - "than number of cells found {}.\nNot performing " - "cell barcode correction" - "".format(args.expected_cells, len(umis_per_cell)) + print("Finding original barcodes") + barcode_mapping = ( + mapped_reads.select(pl.col(BARCODE_COLUMN)) + .unique() + .filter(~pl.col(BARCODE_COLUMN).is_in(barcode_subset[WHITELIST_COLUMN])) + .with_columns( + pl.col(BARCODE_COLUMN) + .map_elements( + lambda x: find_original_barcode( + x, barcode_tree=barcode_tree, distance=collapsing_threshold + ) + ) + .alias(CORRECTED_BARCODE_COLUMN) ) - bcs_corrected = 0 - return final_results, umis_per_cell, bcs_corrected - - elif type(filtered_set) == set: - (final_results, umis_per_cell, bcs_corrected,) = correct_cells_filtered_set( - final_results=final_results, - umis_per_cell=umis_per_cell, - filtered_set=filtered_set, - collapsing_threshold=args.bc_threshold, - ab_map=ordered_tags, + .with_columns( + barcode_corrected=pl.col(BARCODE_COLUMN) == pl.col(CORRECTED_BARCODE_COLUMN) ) - for missing_cell in filtered_set: - if missing_cell in final_results: - continue - else: - final_results[missing_cell] = dict() - for TAG in ordered_tags: - final_results[missing_cell][TAG.id] = Counter() - return final_results, umis_per_cell, bcs_corrected - - -def check_filtered_cells(filtered_cells, expected_cells, umis_per_cell): - if filtered_cells is None: - top_cells_tuple = umis_per_cell.most_common(expected_cells) - # Select top cells based on total umis per cell - filtered_cells = {pair[0] for pair in top_cells_tuple} - return filtered_cells - - -# def choose_filtered_cells( -# given_filtered_cells, -# expected_cells, -# chemistry_def, -# final_results, -# ordered_tags, -# umis_per_cell, -# translation_dict, -# ): -# """ -# Returns a list of barcodes that will be in the output -# and helps decide based on the inputs. - -# Args: -# given_filtered_cells (bool or str): False if not given, else string -# expected_cells (int): Number of expected cells -# chemistry_def (Chemistry): Defines the details of the chemistry -# final_results (dict): All results -# ordered_tags (named_tuple): Holds tags info -# umis_per_cell (Counter): Holds number of UMIs per barcode - -# Returns: -# set: filtered cell set -# """ -# # If given, use filtered_list for top cells -# if given_filtered_cells: -# filtered_cells = set( -# parse_cell_list_csv( -# filename=given_filtered_cells, -# barcode_length=chemistry_def.cell_barcode_end -# - chemistry_def.cell_barcode_start -# + 1, -# file_type="filtered", -# ).keys() -# ) -# # Add potential missing cell barcodes. -# for missing_cell in filtered_cells: -# if missing_cell in final_results: -# continue -# else: -# final_results[missing_cell] = dict() -# for TAG in ordered_tags: -# final_results[missing_cell][TAG.name] = Counter() -# filtered_cells.add(missing_cell) -# else: -# top_cells_tuple = umis_per_cell.most_common(expected_cells) -# # Select top cells based on total umis per cell -# filtered_cells = [pair[0] for pair in top_cells_tuple] + .drop(BARCODE_COLUMN) + ) + print("Collapsing wrong barcodes with original barcodes") + mapped_reads_barcode_corrected = mapped_reads.join( + barcode_mapping.drop(BARCODE_COLUMN), + left_on=BARCODE_COLUMN, + right_on=CORRECTED_BARCODE_COLUMN, + ).drop("corrected") + n_corrected_barcodes = barcode_mapping.filter(pl.col("corrected")).shape[0] + return mapped_reads_barcode_corrected, n_corrected_barcodes # UMI correction section @@ -462,25 +238,25 @@ def run_umi_correction(final_results, filtered_cells, unmapped_id, args): def generate_sparse_matrices( - final_results, ordered_tags, filtered_cells, umi_counts=False + final_results, parsed_tags, filtered_cells, umi_counts=False ): """ Create two sparse matrices with umi and read counts. Args: final_results (dict): Results in a dict of dicts of Counters. - ordered_tags (list): Ordered tags in a list of tuples. + parsed_tags (list): Ordered tags in a list of tuples. Returns: results_matrix (scipy.sparse.dok_matrix): UMI or Read counts """ - unmapped_id = len(ordered_tags) + unmapped_id = len(parsed_tags) if umi_counts: - n_features = len(ordered_tags) + n_features = len(parsed_tags) else: - n_features = len(ordered_tags) + 1 + n_features = len(parsed_tags) + 1 results_matrix = sparse.dok_matrix((n_features, len(filtered_cells)), dtype=int32) for i, cell_barcode in enumerate(filtered_cells): diff --git a/tests/test_io.py b/tests/test_io.py index f407e5b..9b4593f 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -48,7 +48,7 @@ def data(): pytest.sparse_matrix = test_matrix pytest.filtered_cells = ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"] tag = namedtuple("tag", ["name", "sequence", "id"]) - pytest.ordered_tags_map = [ + pytest.parsed_tags_map = [ tag(name="test1", sequence="CGTA", id=0), tag(name="test2", sequence="CGTA", id=1), tag(name="test3", sequence="CGTA", id=2), @@ -73,7 +73,7 @@ def test_write_to_files_wo_translation(data, tmpdir): io.write_to_files( pytest.sparse_matrix, pytest.filtered_cells, - pytest.ordered_tags_map, + pytest.parsed_tags_map, pytest.data_type, output_path, translation_dict=False, @@ -106,7 +106,7 @@ def test_write_to_files_with_translation(data, tmpdir): io.write_to_files( pytest.sparse_matrix, pytest.filtered_cells, - pytest.ordered_tags_map, + pytest.parsed_tags_map, pytest.data_type, output_path, translation_dict=translation_dict, @@ -131,7 +131,7 @@ def test_write_to_dense_wo_translation(data, tmpdir): io.write_dense( sparse_matrix=pytest.sparse_matrix, - ordered_tags=pytest.ordered_tags_map, + parsed_tags=pytest.parsed_tags_map, columns=pytest.filtered_cells, outfolder=output_path, filename=csv_name, diff --git a/tests/test_mapping.py b/tests/test_mapping.py index 3d9cbb2..080420a 100644 --- a/tests/test_mapping.py +++ b/tests/test_mapping.py @@ -69,9 +69,9 @@ def data(): pytest.sliding_window = False pytest.sequence_pool = [] - pytest.tags_tuple = preprocessing.check_tags( - preprocessing.parse_tags_csv("tests/test_data/tags/pass/correct.csv"), 5 - )[0] + pytest.tags_tuple = preprocessing.parse_tags_csv( + preprocessing.parse_tags_csv("tests/test_data/tags/pass/correct.csv") + ) pytest.mapping_input = namedtuple( "mapping_input", ["filename", "tags", "debug", "maximum_distance", "sliding_window"], diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 3b9f203..606ba62 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -79,13 +79,13 @@ def test_parse_reference_list_csv(data): passing_files = glob.glob(pytest.passing_reference_list_csv) for file_path in passing_files: assert_frame_equal( - left=preprocessing.parse_cell_list_csv(file_path, 16), + left=preprocessing.parse_barcode_reference(file_path, 16), right=pytest.correct_reference_translation_list, ) with pytest.raises(SystemExit): failing_files = glob.glob(pytest.failing_reference_list_csv) for file_path in failing_files: - preprocessing.parse_cell_list_csv(file_path, 16) + preprocessing.parse_barcode_reference(file_path, 16) def test_check_distance_too_big_between_tags(data): From 44bf8491a8cdab1d0ebee828ea8553b5a6d36cf4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20R=C3=B6lli?= Date: Tue, 31 Oct 2023 21:20:08 +0100 Subject: [PATCH 63/77] feat: rewriting mapping, barcode correction in polars --- cite_seq_count/__main__.py | 33 ++++--- cite_seq_count/chemistry.py | 6 +- cite_seq_count/io.py | 3 +- cite_seq_count/mapping.py | 148 +++++++------------------------- cite_seq_count/preprocessing.py | 49 +++++++---- cite_seq_count/processing.py | 60 ++++++++++++- tests/test_preprocessing.py | 1 - 7 files changed, 143 insertions(+), 157 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 8a0d1bb..0e36e13 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -10,7 +10,7 @@ def main(): - """Main function""" + """Main""" start_time = time.time() parser = argsparser.get_args() @@ -61,35 +61,34 @@ def main(): r2_min_length=r2_min_length, chemistry_def=chemistry_def, ) - - mapped_reads = mapping.map_reads_hybrid( - mapping_input_file=temp_file, - parsed_tags=parsed_tags, - maximum_distance=maximum_distance, + input_df, barcodes_df, r2_df = preprocessing.split_data_input( + mapping_input_path=temp_file ) # Remove temp file os.remove(temp_file) - # Check if 99% of the reads are unmapped. - mapping.check_unmapped(mapped_reads=mapped_reads) + mapped_r2_df = mapping.map_reads_hybrid( + r2_df=r2_df, + parsed_tags=parsed_tags, + maximum_distance=maximum_distance, + ) - # Check filtered input list - # If a translation is given, will return the translated version barcode_subset, enable_barcode_correction = preprocessing.get_barcode_subset( - args=args, + barcode_whitelist=args.filtered_barcodes, + expected_barcodes=args.expected_barcodes, chemistry=chemistry_def, barcode_reference=barcode_reference, - mapped_reads=mapped_reads, + barcodes_df=barcodes_df, ) # Correct cell barcodes if args.bc_threshold > 0 and enable_barcode_correction: ( - mapped_reads, + barcode_corrected_df, bcs_corrected, - ) = processing.correct_barcodes( - mapped_reads=mapped_reads, - barcode_subset=barcode_subset, - collapsing_threshold=args.bc_threshold, + ) = processing.correct_barcodes_pl( + barcodes_df=barcodes_df, + barcode_subset_df=barcode_subset, + hamming_distance=args.bc_threshold, ) else: print("Skipping cell barcode correction") diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index 5ccd09c..e187f7b 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -25,7 +25,6 @@ class Chemistry: umi_barcode_end: int r2_trim_start: int barcode_reference_path: str - holds_translation: bool DEFINITIONS_DB = pooch.create( @@ -113,7 +112,7 @@ def get_chemistry_definition(chemistry_short_name: str) -> Chemistry: "R1" ]["stop"], r2_trim_start=chemistry_defs["sequence_structure_indexes"]["R2"]["start"] - 1, - translation_list_path=path, + barcode_reference_path=path, ) return chemistry_def @@ -127,7 +126,6 @@ def create_chemistry_definition(args: ArgumentParser) -> Chemistry: umi_barcode_end=args.umi_last, r2_trim_start=args.start_trim, barcode_reference_path=args.barcode_reference, - holds_translation=args.has_translation, ) return chemistry_def @@ -140,6 +138,7 @@ def setup_chemistry(args: ArgumentParser) -> tuple[pl.DataFrame | None, Chemistr barcode_length=chemistry_def.cell_barcode_end - chemistry_def.cell_barcode_start + 1, + required_header=["reference"], ) else: chemistry_def = create_chemistry_definition(args) @@ -148,6 +147,7 @@ def setup_chemistry(args: ArgumentParser) -> tuple[pl.DataFrame | None, Chemistr barcode_reference = preprocessing.parse_barcode_reference( filename=args.barcode_reference, barcode_length=args.cb_last - args.cb_first + 1, + required_header=["reference"], ) else: barcode_reference = None diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 5618806..8d46283 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -506,6 +506,7 @@ def write_mapping_input( temp_file = tempfile.NamedTemporaryFile( "w", dir=temp_path, suffix="_csc", delete=False ) + temp_file_path = temp_file.name reads_written = 0 for read1_path, read2_path in zip(read1_paths, read2_paths): @@ -551,7 +552,7 @@ def write_mapping_input( temp_file.close() return ( - temp_file, + temp_file_path, r1_too_short, r2_too_short, total_reads, diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py index b5c562f..36a5025 100644 --- a/cite_seq_count/mapping.py +++ b/cite_seq_count/mapping.py @@ -2,7 +2,7 @@ """ from pathlib import Path import polars as pl -from rapidfuzz import fuzz, process +from rapidfuzz import fuzz, process, distance from cite_seq_count.preprocessing import ( SEQUENCE_COLUMN, @@ -24,15 +24,38 @@ def find_best_match_rapid(tag_seq, tags_list, maximum_distance): return UNMAPPED_NAME +def match_generic_string_dfs( + ref_df: pl.DataFrame, + target_df: pl.DataFrame, + left_on: str, + right_on: str, + hamming_distance: int, +): + corrected_column = "corrected_" + left_on + joined = ( + target_df.sort(left_on) + .join_asof(ref_df.sort(right_on), left_on=left_on, right_on=right_on) + .with_columns( + pl.when(pl.col(left_on) == pl.col(right_on)) + .then(False) + .otherwise(True) + .alias(corrected_column) + ) + .with_columns( + pl.when(pl.col(corrected_column)) + .then(distance.Hamming.distance(s1=pl.col(left_on), s2=pl.col(right_on))) + .otherwise(0) + .alias("hamming_distance") + ) + ) + return joined + + def map_reads_hybrid( - mapping_input_file: Path, parsed_tags: pl.DataFrame, maximum_distance: int + r2_df: pl.DataFrame, parsed_tags: pl.DataFrame, maximum_distance: int ) -> pl.DataFrame: - input_reads = pl.read_csv( - mapping_input_file, - has_header=False, - new_columns=[BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN], - ) - mapped_reads = input_reads.join( + print("Mapping reads") + mapped_r2_df = r2_df.join( parsed_tags, left_on=R2_COLUMN, right_on=SEQUENCE_COLUMN, how="left" ).with_columns( pl.when(pl.col(FEATURE_NAME_COLUMN).is_null()) @@ -47,113 +70,8 @@ def map_reads_hybrid( ) .otherwise(pl.col(FEATURE_NAME_COLUMN)) ) - return mapped_reads - - -def map_reads_pl_beta( - parsed_tags: pl.DataFrame, reads_input: pl.DataFrame, max_errors: int -) -> pl.DataFrame: - # First pass with exact matches - first_pass = reads_input.join( - parsed_tags, left_on="r2", right_on="sequence", how="left" - ) - first_pass_mapped = first_pass.filter(~pl.col("feature_name").is_null()) - parsed_tags_extended = parsed_tags.with_columns( - pl.col("sequence") - .str.replace_all("A", 0) - .str.replace_all("T", 1) - .str.replace_all("G", 2) - .str.replace_all("C", 3) - .alias("ref_num") - ).with_columns(pl.col("ref_num").str.parse_int(4)) - # Get unmapped reads - unmapped = first_pass.filter(pl.col("feature_name").is_null()) - unmapped_with_ns = unmapped.filter(pl.col("r2").str.contains("N")) - unmapped_without_ns = ( - unmapped.filter(~pl.col("r2").str.contains("N")) - .with_columns( - pl.col("r2") - .str.replace_all("A", 0) - .str.replace_all("T", 1) - .str.replace_all("G", 2) - .str.replace_all("C", 3) - .alias("tag_num") - ) - .with_columns(pl.col("tag_num").str.parse_int(4)) - ) - # Second pass based on closest - second_pass = ( - unmapped_without_ns.sort("tag_num") - .join_asof( - parsed_tags_extended.sort("ref_num"), left_on="tag_num", right_on="ref_num" - ) - .with_columns(diff=pl.col("ref_num") - pl.col("tag_num")) - .select(["barcode", "umi", "r2", "feature_name_right", "diff"]) - .rename({"feature_name_right": "feature_name"}) - .with_columns( - pl.when(pl.col("diff").abs() <= max_errors) - .then(pl.col("feature_name")) - .otherwise(None) - ) - ).drop("diff") - results = pl.concat([first_pass_mapped, second_pass, unmapped_with_ns]) - return results - - -def map_barcodes_pl( - cell_reference: pl.DataFrame, mapped_reads_input: pl.DataFrame, max_errors: int -) -> pl.DataFrame: - # First pass with exact matches - first_pass = mapped_reads_input.join( - cell_reference.with_row_count("barcode_id"), - left_on="barcode", - right_on="reference", - how="left", - ) - first_pass_mapped = first_pass.filter(~pl.col("barcode_id").is_null()) - ref_barcodes = first_pass_mapped.select("barcode").unique() - ref_cells_extended = ref_barcodes.with_columns( - pl.col("reference") - .str.replace_all("A", 0) - .str.replace_all("T", 1) - .str.replace_all("G", 2) - .str.replace_all("C", 3) - .alias("ref_barcode_num") - ).with_columns(pl.col("ref_barcode_num").str.parse_int(4)) - # Get unmapped reads - unmapped = first_pass.filter(pl.col("barcode_id").is_null()) - unmapped_with_ns = unmapped.filter(pl.col("barcode").str.contains("N")) - unmapped_without_ns = ( - unmapped.filter(~pl.col("barcode").str.contains("N")) - .with_columns( - pl.col("barcode") - .str.replace_all("A", 0) - .str.replace_all("T", 1) - .str.replace_all("G", 2) - .str.replace_all("C", 3) - .alias("barcode_num") - ) - .with_columns(pl.col("barcode_num").str.parse_int(4)) - ) - # Second pass based on closest - second_pass = ( - unmapped_without_ns.sort("barcode_num") - .join_asof( - ref_cells_extended.sort("ref_barcode_num"), - left_on="barcode_num", - right_on="ref_barcode_num", - ) - .with_columns(diff=pl.col("ref_barcode_num") - pl.col("barcode_num")) - .select(["barcode", "umi", "r2", "feature_name_right", "diff"]) - .rename({"feature_name_right": "feature_name"}) - .with_columns( - pl.when(pl.col("diff").abs() <= 1) - .then(pl.col("feature_name")) - .otherwise(None) - ) - ).drop("diff") - results = pl.concat([first_pass_mapped, second_pass, unmapped_with_ns]) - return results + print("Mapping done") + return mapped_r2_df def check_unmapped(mapped_reads: pl.DataFrame): diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 3b55a59..baff2d1 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -13,8 +13,6 @@ from cite_seq_count.io import get_n_lines, check_file from cite_seq_count.chemistry import Chemistry -s - # REQUIRED_TAGS_HEADER = ["sequence", "feature_name"] REQUIRED_CELLS_REF_HEADER = ["reference"] @@ -349,27 +347,49 @@ def pre_run_checks( return n_reads, r2_min_length, maximum_distance +def split_data_input(mapping_input_path: Path): + input_df = ( + pl.read_csv( + mapping_input_path, + has_header=False, + new_columns=[BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN], + ) + .group_by([BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN]) + .agg(pl.count()) + ) + + barcodes_df = ( + input_df.select([BARCODE_COLUMN, "count"]) + .group_by(BARCODE_COLUMN) + .agg(pl.sum("count")) + ) + r2_df = input_df.select(R2_COLUMN).unique() + + return input_df, barcodes_df, r2_df + + def get_barcode_subset( - args: ArgumentParser, + barcode_whitelist: Path, + expected_barcodes: int, chemistry: Chemistry, barcode_reference: pl.DataFrame | None, - mapped_reads: pl.DataFrame, + barcodes_df: pl.DataFrame, ): """ Generate the barcode list used for barcode correction and subsetting """ enable_barcode_correction = True - if args.filtered_cells: + if barcode_whitelist: barcode_subset = parse_barcode_reference( - filename=args.filtered_cells, + filename=expected_barcodes, barcode_length=(chemistry.cell_barcode_end - chemistry.cell_barcode_start), required_header=WHITELIST_COLUMN, ) else: - n_barcodes = args.expected_barcodes + n_barcodes = barcode_whitelist if barcode_reference is not None: barcode_subset = ( - mapped_reads.filter( + barcodes_df.filter( pl.col(BARCODE_COLUMN).str.is_in( barcode_reference[REFERENCE_COLUMN] ) @@ -383,10 +403,7 @@ def get_barcode_subset( ) else: raw_barcodes_dict = ( - mapped_reads.filter( - (~pl.col(BARCODE_COLUMN).str.count_matches("N")) - & (pl.col(FEATURE_NAME_COLUMN) != UNMAPPED_NAME) - ) + barcodes_df.filter(~pl.col(BARCODE_COLUMN).str.contains("N")) .group_by(BARCODE_COLUMN) .agg(pl.count()) .sort("count", descending=True) @@ -394,12 +411,12 @@ def get_barcode_subset( barcode_counter = Counter( zip(raw_barcodes_dict[BARCODE_COLUMN], raw_barcodes_dict["count"]) ) - true_barcodes = whitelist_method.getKneeEstimateDensity( - cell_barcode_counts=barcode_counter, expect_cells=n_barcodes + true_barcodes = whitelist_method.getKneeEstimateDistance( + cell_barcode_counts=barcode_counter, cell_number=n_barcodes ) barcode_subset = pl.DataFrame( - true_barcodes, schema={WHITELIST_COLUMN: pl.Utf8} - ) + true_barcodes, schema={WHITELIST_COLUMN: pl.Utf8, "counts": pl.UInt32} + ).drop("counts") if n_barcodes > barcode_subset.shape[0]: print( diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 19038bf..5eb34f4 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -3,6 +3,7 @@ import pybktree import polars as pl +from rapidfuzz import distance from collections import namedtuple @@ -19,6 +20,7 @@ WHITELIST_COLUMN, BARCODE_COLUMN, CORRECTED_BARCODE_COLUMN, + R2_COLUMN, ) # Unit Barcode correction @@ -43,15 +45,65 @@ def find_original_barcode(barcode: str, barcode_tree: pybktree.BKTree, distance: return barcode +def merge_results( + mapped_r2_df: pl.DataFrame, + corrected_barcodes_df: pl.DataFrame, + input_df: pl.DataFrame, +): + merged = ( + input_df.join(mapped_r2_df, on=R2_COLUMN, how="inner") + .join(corrected_barcodes_df.drop("count"), on=BARCODE_COLUMN, how="inner") + .drop([R2_COLUMN, BARCODE_COLUMN]) + ) + return merged + + +def correct_barcodes_pl( + barcodes_df: pl.DataFrame, barcode_subset_df: pl.DataFrame, hamming_distance: int +) -> tuple[pl.DataFrame, int]: + print("Correcting barcodes") + joined = ( + barcodes_df.sort(BARCODE_COLUMN) + .join_asof( + barcode_subset_df.sort(WHITELIST_COLUMN), + left_on=BARCODE_COLUMN, + right_on=WHITELIST_COLUMN, + ) + .with_columns( + pl.when(pl.col(BARCODE_COLUMN) == pl.col(WHITELIST_COLUMN)) + .then(False) + .otherwise(True) + .alias(CORRECTED_BARCODE_COLUMN) + ) + .filter(~pl.col(WHITELIST_COLUMN).is_null()) + ) + joined = ( + pl.concat( + [ + joined, + joined.map_rows( + lambda x: distance.Hamming.distance(s1=x[0], s2=x[2]) + ).rename({"map": "hamming_distance"}), + ], + how="horizontal", + ) + .filter(pl.col("hamming_distance") <= hamming_distance) + .drop("hamming_distance") + ) + n_corrected_barcodes = joined.filter(pl.col(CORRECTED_BARCODE_COLUMN)).shape[0] + print("Barcodes corrected") + return joined.drop(CORRECTED_BARCODE_COLUMN), n_corrected_barcodes + + def correct_barcodes( - mapped_reads: pl.DataFrame, + barcodes_df: pl.DataFrame, barcode_subset: pl.DataFrame, collapsing_threshold: int, ): """Correct barcodes based on a given whitelist Args: - mapped_reads (pl.DataFrame): Mapped reads + barcodes_df (pl.DataFrame): All barcodes barcode_subset (pl.DataFrame): Given whitelist collapsing_threshold (int): Hamming distance to correct on @@ -65,7 +117,7 @@ def correct_barcodes( ) print("Finding original barcodes") barcode_mapping = ( - mapped_reads.select(pl.col(BARCODE_COLUMN)) + barcodes_df.select(pl.col(BARCODE_COLUMN)) .unique() .filter(~pl.col(BARCODE_COLUMN).is_in(barcode_subset[WHITELIST_COLUMN])) .with_columns( @@ -83,7 +135,7 @@ def correct_barcodes( .drop(BARCODE_COLUMN) ) print("Collapsing wrong barcodes with original barcodes") - mapped_reads_barcode_corrected = mapped_reads.join( + mapped_reads_barcode_corrected = barcodes_df.join( barcode_mapping.drop(BARCODE_COLUMN), left_on=BARCODE_COLUMN, right_on=CORRECTED_BARCODE_COLUMN, diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 606ba62..0a14fbb 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -1,7 +1,6 @@ """Test function preprocessing of the module""" import glob -from collections import namedtuple import pytest import polars as pl From 5f762241efc5efbfc0d77563e1866e19536b7c8e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20R=C3=B6lli?= Date: Thu, 23 Nov 2023 06:11:38 +0100 Subject: [PATCH 64/77] Fix: python version --- cite_seq_count/processing.py | 69 +++---------------- setup.py | 7 +- .../test_data/matrix/.~lock.test_matrix.csv# | 1 - 3 files changed, 12 insertions(+), 65 deletions(-) delete mode 100644 tests/test_data/matrix/.~lock.test_matrix.csv# diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 5eb34f4..6227055 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -76,16 +76,16 @@ def correct_barcodes_pl( .alias(CORRECTED_BARCODE_COLUMN) ) .filter(~pl.col(WHITELIST_COLUMN).is_null()) - ) - joined = ( - pl.concat( - [ - joined, - joined.map_rows( - lambda x: distance.Hamming.distance(s1=x[0], s2=x[2]) - ).rename({"map": "hamming_distance"}), - ], - how="horizontal", + ).with_columns( + pl.struct( + pl.col(BARCODE_COLUMN), + pl.col(WHITELIST_COLUMN) + .map_elements( + lambda x: distance.Hamming.distance( + x[BARCODE_COLUMN], x[WHITELIST_COLUMN] + ) + ) + .alias("hamming_distance"), ) .filter(pl.col("hamming_distance") <= hamming_distance) .drop("hamming_distance") @@ -95,55 +95,6 @@ def correct_barcodes_pl( return joined.drop(CORRECTED_BARCODE_COLUMN), n_corrected_barcodes -def correct_barcodes( - barcodes_df: pl.DataFrame, - barcode_subset: pl.DataFrame, - collapsing_threshold: int, -): - """Correct barcodes based on a given whitelist - - Args: - barcodes_df (pl.DataFrame): All barcodes - barcode_subset (pl.DataFrame): Given whitelist - collapsing_threshold (int): Hamming distance to correct on - - Returns: - mapped_reads_barcode_corrected (pl.DataFrame): Barcode corrected mapped reads - n_corrected_barcodes (int): Number of corrected barcodes - """ - print("Generating barcode tree from reference list") - barcode_tree = pybktree.BKTree( - Levenshtein.hamming, barcode_subset[WHITELIST_COLUMN].to_list() - ) - print("Finding original barcodes") - barcode_mapping = ( - barcodes_df.select(pl.col(BARCODE_COLUMN)) - .unique() - .filter(~pl.col(BARCODE_COLUMN).is_in(barcode_subset[WHITELIST_COLUMN])) - .with_columns( - pl.col(BARCODE_COLUMN) - .map_elements( - lambda x: find_original_barcode( - x, barcode_tree=barcode_tree, distance=collapsing_threshold - ) - ) - .alias(CORRECTED_BARCODE_COLUMN) - ) - .with_columns( - barcode_corrected=pl.col(BARCODE_COLUMN) == pl.col(CORRECTED_BARCODE_COLUMN) - ) - .drop(BARCODE_COLUMN) - ) - print("Collapsing wrong barcodes with original barcodes") - mapped_reads_barcode_corrected = barcodes_df.join( - barcode_mapping.drop(BARCODE_COLUMN), - left_on=BARCODE_COLUMN, - right_on=CORRECTED_BARCODE_COLUMN, - ).drop("corrected") - n_corrected_barcodes = barcode_mapping.filter(pl.col("corrected")).shape[0] - return mapped_reads_barcode_corrected, n_corrected_barcodes - - # UMI correction section diff --git a/setup.py b/setup.py index 58f6cec..7e9fa7b 100644 --- a/setup.py +++ b/setup.py @@ -21,19 +21,16 @@ "Operating System :: OS Independent", ), install_requires=[ - "python-levenshtein>=0.12.0", "scipy>=1.1.0", - "multiprocess>=0.70.6.1", "umi_tools==1.1.4", "pytest>=6.0.0", "pytest-dependency==0.4.0", - "pandas>=0.23.4", - "pybktree==1.1", "cython>=0.29.17", "pyyaml==6.0", "pooch==1.6.0", "six==1.16.0", + "polars== 0.19.14" ], - python_requires=">=3.8", + python_requires="==3.11.6", package_data={"report_template": ["templates/*.json"]}, ) diff --git a/tests/test_data/matrix/.~lock.test_matrix.csv# b/tests/test_data/matrix/.~lock.test_matrix.csv# deleted file mode 100644 index 6f2f611..0000000 --- a/tests/test_data/matrix/.~lock.test_matrix.csv# +++ /dev/null @@ -1 +0,0 @@ -,proelli,proelli-ThinkPad-T470s,23.01.2019 16:02,file:///home/proelli/.config/libreoffice/4; \ No newline at end of file From 53c5e5b89f0d0e6bf90fb4fcbe467ae3b19aded5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Patrick=20R=C3=B6lli?= Date: Tue, 26 Dec 2023 16:54:06 +0100 Subject: [PATCH 65/77] (feat): Barcode correction using asof_join --- cite_seq_count/__main__.py | 12 ++-- cite_seq_count/chemistry.py | 6 +- cite_seq_count/constants.py | 22 +++++++ cite_seq_count/io.py | 7 +- cite_seq_count/mapping.py | 3 +- cite_seq_count/preprocessing.py | 47 ++++++-------- cite_seq_count/processing.py | 110 +++++++++++++++++++++++--------- setup.py | 2 +- tests/test_io.py | 16 ++--- tests/test_mapping.py | 6 +- tests/test_preprocessing.py | 13 ++-- tests/test_processing.py | 80 +++++++++++++---------- 12 files changed, 204 insertions(+), 120 deletions(-) create mode 100644 cite_seq_count/constants.py diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 0e36e13..5d71d30 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -61,7 +61,7 @@ def main(): r2_min_length=r2_min_length, chemistry_def=chemistry_def, ) - input_df, barcodes_df, r2_df = preprocessing.split_data_input( + main_df, barcodes_df, r2_df = preprocessing.split_data_input( mapping_input_path=temp_file ) # Remove temp file @@ -83,16 +83,20 @@ def main(): # Correct cell barcodes if args.bc_threshold > 0 and enable_barcode_correction: ( - barcode_corrected_df, - bcs_corrected, + barcodes_with_correction_df, + n_bcs_corrected, + mapped_barcodes, ) = processing.correct_barcodes_pl( barcodes_df=barcodes_df, barcode_subset_df=barcode_subset, hamming_distance=args.bc_threshold, ) + main_df = processing.update_main_df( + main_df=main_df, mapped_barcodes=mapped_barcodes + ) else: print("Skipping cell barcode correction") - bcs_corrected = 0 + n_bcs_corrected = 0 # Create sparse matrices for reads results read_results_matrix = processing.generate_sparse_matrices( diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index e187f7b..340bb45 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -6,7 +6,7 @@ from dataclasses import dataclass from argparse import ArgumentParser -from cite_seq_count import preprocessing +from cite_seq_count.preprocessing import parse_barcode_reference import polars as pl GLOBAL_LINK_RAW = "https://raw.githubusercontent.com/Hoohm/scg_lib_structs/10xv3_totalseq_b/chemistries/" @@ -133,7 +133,7 @@ def create_chemistry_definition(args: ArgumentParser) -> Chemistry: def setup_chemistry(args: ArgumentParser) -> tuple[pl.DataFrame | None, Chemistry]: if args.chemistry_id: chemistry_def = get_chemistry_definition(args.chemistry_id) - barcode_reference = preprocessing.parse_barcode_reference( + barcode_reference = parse_barcode_reference( filename=chemistry_def.barcode_reference_path, barcode_length=chemistry_def.cell_barcode_end - chemistry_def.cell_barcode_start @@ -144,7 +144,7 @@ def setup_chemistry(args: ArgumentParser) -> tuple[pl.DataFrame | None, Chemistr chemistry_def = create_chemistry_definition(args) if args.barcode_reference: print("Loading barcode reference") - barcode_reference = preprocessing.parse_barcode_reference( + barcode_reference = parse_barcode_reference( filename=args.barcode_reference, barcode_length=args.cb_last - args.cb_first + 1, required_header=["reference"], diff --git a/cite_seq_count/constants.py b/cite_seq_count/constants.py new file mode 100644 index 0000000..3ce7729 --- /dev/null +++ b/cite_seq_count/constants.py @@ -0,0 +1,22 @@ + + +# REQUIRED_TAGS_HEADER = ["sequence", "feature_name"] +REQUIRED_CELLS_REF_HEADER = ["reference"] +OPTIONAL_CELLS_REF_HEADER = ["translation"] +# Polars column names +# Tags input +FEATURE_NAME_COLUMN = "feature_name" +SEQUENCE_COLUMN = "sequence" +REQUIRED_TAGS_HEADER = [FEATURE_NAME_COLUMN, SEQUENCE_COLUMN] +# Reads input +BARCODE_COLUMN = "barcode" +CORRECTED_BARCODE_COLUMN = "corrected_barcode" +UMI_COLUMN = "umi" +R2_COLUMN = "r2" +# Barcode input +REFERENCE_COLUMN = "reference" +TRANSLATION_COLUMN = "translation" +WHITELIST_COLUMN = "whitelist" +STRIP_CHARS = '"0123456789- \t\n' + +UNMAPPED_NAME = "unmapped" diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 8d46283..609858c 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -24,7 +24,6 @@ import polars as pl from scipy import io from cite_seq_count import secondsToText -from cite_seq_count.chemistry import Chemistry JSON_REPORT_PATH = pkg_resources.resource_filename(__name__, "templates/report.json") @@ -260,7 +259,7 @@ def create_report( r1_too_short: int, r2_too_short: int, args: ArgumentParser, - chemistry_def: Chemistry, + chemistry_def, maximum_distance: int, ): """ @@ -328,7 +327,7 @@ def write_chunks_to_disk( read2_paths: list[Path], r2_min_length: int, n_reads_per_chunk: int, - chemistry_def: Chemistry, + chemistry_def, parsed_tags: pl.DataFrame, maximum_distance, ): @@ -473,7 +472,7 @@ def write_mapping_input( read1_paths: list[Path], read2_paths: list[Path], r2_min_length: int, - chemistry_def: Chemistry, + chemistry_def, ): """ Writes chunked files of reads to disk and prepares parallel diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py index 36a5025..4226268 100644 --- a/cite_seq_count/mapping.py +++ b/cite_seq_count/mapping.py @@ -1,10 +1,9 @@ """Mapping module. Holds all code related to mapping reads """ -from pathlib import Path import polars as pl from rapidfuzz import fuzz, process, distance -from cite_seq_count.preprocessing import ( +from cite_seq_count.constants import ( SEQUENCE_COLUMN, R2_COLUMN, FEATURE_NAME_COLUMN, diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index baff2d1..6ce4346 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -11,29 +11,20 @@ import polars as pl import umi_tools.whitelist_methods as whitelist_method from cite_seq_count.io import get_n_lines, check_file -from cite_seq_count.chemistry import Chemistry - - -# REQUIRED_TAGS_HEADER = ["sequence", "feature_name"] -REQUIRED_CELLS_REF_HEADER = ["reference"] -OPTIONAL_CELLS_REF_HEADER = ["translation"] -# Polars column names -# Tags input -FEATURE_NAME_COLUMN = "feature_name" -SEQUENCE_COLUMN = "sequence" -REQUIRED_TAGS_HEADER = [FEATURE_NAME_COLUMN, SEQUENCE_COLUMN] -# Reads input -BARCODE_COLUMN = "barcode" -CORRECTED_BARCODE_COLUMN = "corrected_barcode" -UMI_COLUMN = "umi" -R2_COLUMN = "r2" -# Barcode input -REFERENCE_COLUMN = "reference" -TRANSLATION_COLUMN = "translation" -WHITELIST_COLUMN = "whitelist" -STRIP_CHARS = '"0123456789- \t\n' - -UNMAPPED_NAME = "unmapped" +from cite_seq_count.constants import ( + SEQUENCE_COLUMN, + R2_COLUMN, + FEATURE_NAME_COLUMN, + BARCODE_COLUMN, + UMI_COLUMN, + UNMAPPED_NAME, + REQUIRED_TAGS_HEADER, + REFERENCE_COLUMN, + TRANSLATION_COLUMN, + OPTIONAL_CELLS_REF_HEADER, + STRIP_CHARS, + WHITELIST_COLUMN +) def parse_barcode_reference( @@ -348,7 +339,7 @@ def pre_run_checks( def split_data_input(mapping_input_path: Path): - input_df = ( + main_df = ( pl.read_csv( mapping_input_path, has_header=False, @@ -359,19 +350,19 @@ def split_data_input(mapping_input_path: Path): ) barcodes_df = ( - input_df.select([BARCODE_COLUMN, "count"]) + main_df.select([BARCODE_COLUMN, "count"]) .group_by(BARCODE_COLUMN) .agg(pl.sum("count")) ) - r2_df = input_df.select(R2_COLUMN).unique() + r2_df = main_df.select(R2_COLUMN).unique() - return input_df, barcodes_df, r2_df + return main_df, barcodes_df, r2_df def get_barcode_subset( barcode_whitelist: Path, expected_barcodes: int, - chemistry: Chemistry, + chemistry, barcode_reference: pl.DataFrame | None, barcodes_df: pl.DataFrame, ): diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 6227055..5c9c70e 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -16,11 +16,12 @@ from umi_tools import network -from cite_seq_count.preprocessing import ( +from cite_seq_count.constants import ( WHITELIST_COLUMN, BARCODE_COLUMN, CORRECTED_BARCODE_COLUMN, R2_COLUMN, + UMI_COLUMN, ) # Unit Barcode correction @@ -59,40 +60,91 @@ def merge_results( def correct_barcodes_pl( - barcodes_df: pl.DataFrame, barcode_subset_df: pl.DataFrame, hamming_distance: int -) -> tuple[pl.DataFrame, int]: + barcodes_df: pl.DataFrame, + barcode_subset_df: pl.DataFrame, + hamming_distance: int, +) -> tuple[pl.DataFrame, int, dict]: + """Corrects barcodes using a whitelist based on join_asof from polars. + Uses both forward and backward strategy to dinf the closest barcode + + Args: + barcodes_df (pl.DataFrame): All barcodes with their respective counts + barcode_subset_df (pl.DataFrame): Barcode reference used to correct + hamming_distance (int): Max hamming distance allowed + mapped_barcodes (dict): Dict of mapped barcodes + + Returns: + tuple[pl.DataFrame, int]: The corrected version of the input barcodes_df, number of corrected barcodes + """ print("Correcting barcodes") - joined = ( - barcodes_df.sort(BARCODE_COLUMN) - .join_asof( - barcode_subset_df.sort(WHITELIST_COLUMN), - left_on=BARCODE_COLUMN, - right_on=WHITELIST_COLUMN, - ) - .with_columns( - pl.when(pl.col(BARCODE_COLUMN) == pl.col(WHITELIST_COLUMN)) - .then(False) - .otherwise(True) - .alias(CORRECTED_BARCODE_COLUMN) - ) - .filter(~pl.col(WHITELIST_COLUMN).is_null()) - ).with_columns( - pl.struct( - pl.col(BARCODE_COLUMN), - pl.col(WHITELIST_COLUMN) - .map_elements( - lambda x: distance.Hamming.distance( - x[BARCODE_COLUMN], x[WHITELIST_COLUMN] + corrected_barcodes_pl = pl.DataFrame( + schema={ + BARCODE_COLUMN: pl.Utf8, + "count": pl.Int64, + WHITELIST_COLUMN: pl.Utf8, + "hamming_distance": pl.UInt8, + } + ) + methods = ["backward", "forward"] + for method in methods: + current_barcodes = ( + barcodes_df.filter( + (~pl.col(BARCODE_COLUMN).is_in(corrected_barcodes_pl[BARCODE_COLUMN])) + & (~pl.col(BARCODE_COLUMN).is_in(barcode_subset_df[WHITELIST_COLUMN])) + ) + .sort(BARCODE_COLUMN) + .join_asof( + barcode_subset_df.sort(WHITELIST_COLUMN), + left_on=BARCODE_COLUMN, + right_on=WHITELIST_COLUMN, + strategy=method, + ) + .filter(~pl.col(WHITELIST_COLUMN).is_null()) + .with_columns( + pl.struct(pl.col(BARCODE_COLUMN), pl.col(WHITELIST_COLUMN)) + .map_elements( + lambda x: distance.Hamming.distance( + x[BARCODE_COLUMN], x[WHITELIST_COLUMN] + ), + return_dtype=pl.UInt8, ) + .alias("hamming_distance") ) - .alias("hamming_distance"), + .filter(pl.col("hamming_distance") <= hamming_distance) ) - .filter(pl.col("hamming_distance") <= hamming_distance) - .drop("hamming_distance") + corrected_barcodes_pl = pl.concat([corrected_barcodes_pl, current_barcodes]) + mapped_barcodes = dict( + corrected_barcodes_pl.select(BARCODE_COLUMN, WHITELIST_COLUMN).iter_rows() + ) + final_corrected = ( + barcodes_df.with_columns( + pl.col(BARCODE_COLUMN).map_dict(mapped_barcodes, default=pl.first()) + ) + .group_by(BARCODE_COLUMN) + .agg(pl.sum("count")) ) - n_corrected_barcodes = joined.filter(pl.col(CORRECTED_BARCODE_COLUMN)).shape[0] print("Barcodes corrected") - return joined.drop(CORRECTED_BARCODE_COLUMN), n_corrected_barcodes + n_corrected_barcodes = corrected_barcodes_pl.shape[0] + + return final_corrected, n_corrected_barcodes, mapped_barcodes + + +def update_main_df(main_df: pl.DataFrame, mapped_barcodes: dict): + """Update the main data df with the corrected barcodes + + Args: + main_df (pl.DataFrame): Data of all reads + mapped_barcodes (dict): Mapped barcodes from correction + + Returns: + pl.DataFrame: Data of all reads with barcodes corrected + """ + main_df = ( + main_df.with_columns(pl.col(BARCODE_COLUMN).map_dict(mapped_barcodes)) + .group_by([BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN]) + .agg(pl.sum("count")) + ) + return main_df # UMI correction section diff --git a/setup.py b/setup.py index 7e9fa7b..a251c0c 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ "pyyaml==6.0", "pooch==1.6.0", "six==1.16.0", - "polars== 0.19.14" + "polars== 0.19.14", ], python_requires="==3.11.6", package_data={"report_template": ["templates/*.json"]}, diff --git a/tests/test_io.py b/tests/test_io.py index 9b4593f..f6625a3 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -2,7 +2,7 @@ import os import gzip import scipy -from cite_seq_count import io +from cite_seq_count.io import get_n_lines, write_to_files, write_dense, get_read_paths from collections import namedtuple import numpy as np @@ -70,7 +70,7 @@ def test_write_to_files_wo_translation(data, tmpdir): mtx_path: "3ea98c44d88a947215bace0c72ac1303", } - io.write_to_files( + write_to_files( pytest.sparse_matrix, pytest.filtered_cells, pytest.parsed_tags_map, @@ -103,7 +103,7 @@ def test_write_to_files_with_translation(data, tmpdir): mtx_path: "3ea98c44d88a947215bace0c72ac1303", } - io.write_to_files( + write_to_files( pytest.sparse_matrix, pytest.filtered_cells, pytest.parsed_tags_map, @@ -129,7 +129,7 @@ def test_write_to_dense_wo_translation(data, tmpdir): csv_path: "fef502237900ec386d100169fa1fab7c", } - io.write_dense( + write_dense( sparse_matrix=pytest.sparse_matrix, parsed_tags=pytest.parsed_tags_map, columns=pytest.filtered_cells, @@ -142,13 +142,13 @@ def test_write_to_dense_wo_translation(data, tmpdir): @pytest.mark.dependency() def test_get_n_lines(data): - assert io.get_n_lines(pytest.correct_R1_path) == (200 * 4) + assert get_n_lines(pytest.correct_R1_path) == (200 * 4) @pytest.mark.dependency() def test_corrrect_multipath(data): assert ( - io.get_read_paths(pytest.correct_R1_multipath, pytest.correct_R2_multipath) + get_read_paths(pytest.correct_R1_multipath, pytest.correct_R2_multipath) == pytest.correct_multipath_result ) @@ -156,10 +156,10 @@ def test_corrrect_multipath(data): @pytest.mark.dependency(depends=["test_get_n_lines"]) def test_incorrrect_multipath(data): with pytest.raises(SystemExit): - io.get_read_paths(pytest.correct_R1_multipath, pytest.incorrect_R2_multipath) + get_read_paths(pytest.correct_R1_multipath, pytest.incorrect_R2_multipath) @pytest.mark.dependency(depends=["test_get_n_lines"]) def test_get_n_lines_not_multiple_of_4(data): with pytest.raises(SystemExit): - io.get_n_lines(pytest.corrupt_R1_path) + get_n_lines(pytest.corrupt_R1_path) diff --git a/tests/test_mapping.py b/tests/test_mapping.py index 080420a..0991a89 100644 --- a/tests/test_mapping.py +++ b/tests/test_mapping.py @@ -3,7 +3,7 @@ import copy from collections import Counter, namedtuple from cite_seq_count import mapping -from cite_seq_count import preprocessing +from cite_seq_count.preprocessing import parse_tags_csv def complete_poly_A(seq, final_length=40): @@ -69,8 +69,8 @@ def data(): pytest.sliding_window = False pytest.sequence_pool = [] - pytest.tags_tuple = preprocessing.parse_tags_csv( - preprocessing.parse_tags_csv("tests/test_data/tags/pass/correct.csv") + pytest.tags_tuple = parse_tags_csv( + parse_tags_csv("tests/test_data/tags/pass/correct.csv") ) pytest.mapping_input = namedtuple( "mapping_input", diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 0a14fbb..13748c6 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -5,7 +5,8 @@ import pytest import polars as pl from polars.testing import assert_frame_equal -from cite_seq_count import preprocessing +from cite_seq_count.preprocessing import parse_barcode_reference, parse_tags_csv, check_tags +from cite_seq_count.constants import REFERENCE_COLUMN @pytest.fixture @@ -66,11 +67,11 @@ def test_csv_parser(data): """ passing_files = glob.glob(pytest.passing_csv) for file_path in passing_files: - preprocessing.parse_tags_csv(file_path) + parse_tags_csv(file_path) with pytest.raises(SystemExit): failing_files = glob.glob(pytest.failing_csv) for file_path in failing_files: - preprocessing.parse_tags_csv(file_path) + parse_tags_csv(file_path) @pytest.mark.dependency() @@ -78,15 +79,15 @@ def test_parse_reference_list_csv(data): passing_files = glob.glob(pytest.passing_reference_list_csv) for file_path in passing_files: assert_frame_equal( - left=preprocessing.parse_barcode_reference(file_path, 16), + left=parse_barcode_reference(file_path, 16, [REFERENCE_COLUMN]), right=pytest.correct_reference_translation_list, ) with pytest.raises(SystemExit): failing_files = glob.glob(pytest.failing_reference_list_csv) for file_path in failing_files: - preprocessing.parse_barcode_reference(file_path, 16) + parse_barcode_reference(file_path, 16, [REFERENCE_COLUMN]) def test_check_distance_too_big_between_tags(data): with pytest.raises(SystemExit): - preprocessing.check_tags(pytest.correct_tag_pl, 8) + check_tags(pytest.correct_tag_pl, 8) diff --git a/tests/test_processing.py b/tests/test_processing.py index 310d493..067e60f 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -1,36 +1,64 @@ import pytest from collections import namedtuple from cite_seq_count import processing +import polars as pl +from polars.testing import assert_frame_equal @pytest.fixture def data(): - from collections import Counter - tag = namedtuple("tag", ["name", "sequence", "id"]) - pytest.tags = [ - tag(name="test1", sequence="CGTACGTAGCCTAGC", id=0), - tag(name="test2", sequence="CGTAGCTCG", id=1), - ] - pytest.results = { - "ACTGTTTTATTGGCCT": { - 0: Counter({b"CATTAGTGGT": 3, b"CATTAGTGGG": 2, b"CATTCGTGGT": 1}) - }, - "TTCATAAGGTAGGGAT": { - 1: Counter({b"TAGCTTAGTA": 3, b"TAGCTTAGTC": 2, b"GCGATGCATA": 1}) - }, - } - pytest.corrected_results = { - "ACTGTTTTATTGGCCT": {0: Counter({b"CATTAGTGGT": 6})}, - "TTCATAAGGTAGGGAT": {1: Counter({b"TAGCTTAGTA": 5, b"GCGATGCATA": 1})}, - } - pytest.umis_per_cell = Counter({"ACTGTTTTATTGGCCT": 1, "TTCATAAGGTAGGGAT": 2}) - pytest.reads_per_cell = Counter({"ACTGTTTTATTGGCCT": 3, "TTCATAAGGTAGGGAT": 6}) + + pytest.barcodes_df = pl.DataFrame( + { + "barcode": [ + "TACATATTCTTTACTG", + "AACATATTCTTTACTG", + "CACATATTCTTTACTG", + "GACATATTCTTTACTG", + "TACATATTCTTTACTA", + "TACATATTCTTTACTC", + "TACATATTCTTTACTT", + "TAGAGGGAGGTCAAGC", + "TAGAGGGACGTCAAGC", + "TAGAGGGATGTCAAGC", + "TAGAGGGAAGTCAAGC", + ], + "count": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + } + ) + + pytest.barcode_subset_df = pl.DataFrame( + {"whitelist": ["TACATATTCTTTACTG", "TAGAGGGAAGTCAAGC"]} + ) + + pytest.corrected_barcodes_df = pl.DataFrame( + { + "barcode": [ + "TAGAGGGAAGTCAAGC", + "TACATATTCTTTACTG", + ], + "count": [4, 7], + } + ) + pytest.expected_cells = 2 pytest.collapsing_threshold = 1 pytest.max_umis = 20000 +@pytest.mark.dependency() +def test_correct_barcodes(data): + corrected_barcodes, _, _ = processing.correct_barcodes_pl( + barcodes_df=pytest.barcodes_df, + barcode_subset_df=pytest.barcode_subset_df, + hamming_distance=1, + ) + assert_frame_equal( + pytest.corrected_barcodes_df, corrected_barcodes, check_row_order=False + ) + + @pytest.mark.dependency() def test_correct_umis(data): temp = processing.correct_umis_in_cells((pytest.results, 2, pytest.max_umis, 2)) @@ -47,18 +75,6 @@ def test_correct_umis(data): assert n_corrected == 3 -@pytest.mark.dependency(depends=["test_correct_umis"]) -def test_correct_cells(data): - processing.correct_cells_no_translation_list( - pytest.corrected_results, - pytest.reads_per_cell, - pytest.umis_per_cell, - pytest.expected_cells, - pytest.collapsing_threshold, - pytest.tags, - ) - - @pytest.mark.dependency(depends=["test_correct_umis"]) def test_generate_sparse_umi_matrices(data): umi_results_matrix = processing.generate_sparse_matrices( From ced3e614aab2178dafe2a0a6dbe783d19f297f8e Mon Sep 17 00:00:00 2001 From: hoohm Date: Thu, 28 Dec 2023 20:15:30 +0100 Subject: [PATCH 66/77] (feat): Mtx writing --- cite_seq_count/__main__.py | 157 ++++++++++----------------- cite_seq_count/constants.py | 14 ++- cite_seq_count/io.py | 63 +++++++++++ cite_seq_count/mapping.py | 45 +++----- cite_seq_count/preprocessing.py | 81 +++++++++----- cite_seq_count/processing.py | 186 ++++---------------------------- tests/test_preprocessing.py | 17 ++- 7 files changed, 237 insertions(+), 326 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 5d71d30..7f22086 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -23,12 +23,12 @@ def main(): # Check a few path before doing anything if not os.access(args.temp_path, os.W_OK): sys.exit( - f"Temp folder: {args.temp_path} is not writable." + f"Temp folder: {args.temp_path} is not writeable." f"Please check permissions and/or change temp folder." ) if not os.access(os.path.dirname(os.path.abspath(args.outfolder)), os.W_OK): sys.exit( - f"Output folder: {args.outfolder} is not writable." + f"Output folder: {args.outfolder} is not writeable." f"Please check permissions and/or change output folder." ) @@ -74,7 +74,7 @@ def main(): barcode_subset, enable_barcode_correction = preprocessing.get_barcode_subset( barcode_whitelist=args.filtered_barcodes, - expected_barcodes=args.expected_barcodes, + n_barcodes=args.expected_barcodes, chemistry=chemistry_def, barcode_reference=barcode_reference, barcodes_df=barcodes_df, @@ -83,7 +83,7 @@ def main(): # Correct cell barcodes if args.bc_threshold > 0 and enable_barcode_correction: ( - barcodes_with_correction_df, + barcodes_df, n_bcs_corrected, mapped_barcodes, ) = processing.correct_barcodes_pl( @@ -97,113 +97,72 @@ def main(): else: print("Skipping cell barcode correction") n_bcs_corrected = 0 - - # Create sparse matrices for reads results - read_results_matrix = processing.generate_sparse_matrices( - final_results=final_results, - parsed_tags=parsed_tags, - filtered_cells=filtered_cells, + read_counts = processing.generate_mtx_counts( + main_df=main_df, + barcode_subset=barcodes_df, + mapped_r2_df=mapped_r2_df, + data_type="read", ) # Write reads to file - io.write_to_files( - sparse_matrix=read_results_matrix, - filtered_cells=filtered_cells, - parsed_tags=parsed_tags, + io.write_data_to_mtx( + main_df=read_counts, + tags_df=parsed_tags, + barcodes_df=barcode_subset, data_type="read", - outfolder=args.outfolder, - translation_dict=translation_dict, + outpath=args.outfolder, ) + # TODO: add clustered cells filter: Max UMIs per cell per feature: 20000 + print("UMI correction not implemented yet") + # Don't correct + umis_corrected = 0 + clustered_cells = [] + # TODO: Write out to mtx and csv clustered cells - # UMI correction - if args.umi_threshold != 0: - # Correct UMIS - ( - final_results, - umis_corrected, - clustered_cells, - ) = processing.run_umi_correction( - final_results=final_results, - filtered_cells=filtered_cells, - unmapped_id=len(parsed_tags), - args=args, - ) - else: - # Don't correct - umis_corrected = 0 - clustered_cells = [] - - if len(clustered_cells) > 0: - # Remove clustered cells from the top cells - for cell_barcode in clustered_cells: - filtered_cells.remove(cell_barcode) - - # Create sparse clustered cells matrix - umi_clustered_matrix = processing.generate_sparse_matrices( - final_results=final_results, - parsed_tags=parsed_tags, - filtered_cells=clustered_cells, - ) - # Write uncorrected cells to dense output - io.write_dense( - sparse_matrix=umi_clustered_matrix, - parsed_tags=parsed_tags, - columns=clustered_cells, - outfolder=os.path.join(args.outfolder, "uncorrected_cells"), - filename="dense_umis.tsv", - ) # Generate the UMI count matrix - umi_results_matrix = processing.generate_sparse_matrices( - final_results=final_results, - parsed_tags=parsed_tags, - filtered_cells=filtered_cells, - umi_counts=True, + umi_counts = processing.generate_mtx_counts( + main_df=main_df, + barcode_subset=barcodes_df, + mapped_r2_df=mapped_r2_df, + data_type="umi", ) # Write umis to file - io.write_to_files( - sparse_matrix=umi_results_matrix, - filtered_cells=filtered_cells, - parsed_tags=parsed_tags, + io.write_data_to_mtx( + main_df=read_counts, + tags_df=parsed_tags, + barcodes_df=barcode_subset, data_type="umi", - outfolder=args.outfolder, - translation_dict=translation_dict, - ) - - # Write unmapped sequences - if len(merged_no_match) > 0: - io.write_unmapped( - merged_no_match=merged_no_match, - top_unknowns=args.unknowns_top, - outfolder=args.outfolder, - filename=args.unmapped_file, - ) - - # Create report and write it to disk - io.create_report( - total_reads=total_reads, - no_match=merged_no_match, - version=argsparser.get_package_version(), - start_time=start_time, - umis_corrected=umis_corrected, - bcs_corrected=bcs_corrected, - bad_cells=clustered_cells, - r1_too_short=r1_too_short, - r2_too_short=r2_too_short, - args=args, - chemistry_def=chemistry_def, - maximum_distance=maximum_distance, + outpath=args.outfolder, ) - # Write dense matrix to disk if requested - if args.dense: - print("Writing dense format output") - io.write_dense( - sparse_matrix=umi_results_matrix, - parsed_tags=parsed_tags, - columns=filtered_cells, - outfolder=args.outfolder, - filename="dense_umis.tsv", - ) + # TODO: Write unmapped sequences + # TODO: rewrite reporting + # # Create report and write it to disk + # io.create_report( + # total_reads=total_reads, + # no_match=merged_no_match, + # version=argsparser.get_package_version(), + # start_time=start_time, + # umis_corrected=umis_corrected, + # bcs_corrected=bcs_corrected, + # bad_cells=clustered_cells, + # r1_too_short=r1_too_short, + # r2_too_short=r2_too_short, + # args=args, + # chemistry_def=chemistry_def, + # maximum_distance=maximum_distance, + # ) + # TODO: Rewrite dense format output + # # Write dense matrix to disk if requested + # if args.dense: + # print("Writing dense format output") + # io.write_dense( + # sparse_matrix=umi_results_matrix, + # parsed_tags=parsed_tags, + # columns=filtered_cells, + # outfolder=args.outfolder, + # filename="dense_umis.tsv", + # ) if __name__ == "__main__": diff --git a/cite_seq_count/constants.py b/cite_seq_count/constants.py index 3ce7729..40fa6c7 100644 --- a/cite_seq_count/constants.py +++ b/cite_seq_count/constants.py @@ -1,5 +1,3 @@ - - # REQUIRED_TAGS_HEADER = ["sequence", "feature_name"] REQUIRED_CELLS_REF_HEADER = ["reference"] OPTIONAL_CELLS_REF_HEADER = ["translation"] @@ -8,15 +6,25 @@ FEATURE_NAME_COLUMN = "feature_name" SEQUENCE_COLUMN = "sequence" REQUIRED_TAGS_HEADER = [FEATURE_NAME_COLUMN, SEQUENCE_COLUMN] +UNMAPPED_NAME = "unmapped" + # Reads input BARCODE_COLUMN = "barcode" CORRECTED_BARCODE_COLUMN = "corrected_barcode" UMI_COLUMN = "umi" R2_COLUMN = "r2" +COUNT_COLUMN = "count" # Barcode input REFERENCE_COLUMN = "reference" TRANSLATION_COLUMN = "translation" WHITELIST_COLUMN = "whitelist" STRIP_CHARS = '"0123456789- \t\n' -UNMAPPED_NAME = "unmapped" +# MTX format +BARCODE_ID_COLUMN = "barcode_id" +FEATURE_ID_COLUMN = "feature_id" +MTX_HEADER = """%%MatrixMarket matrix coordinate integer general\n%\n""" +TEMP_MTX = "temp.mtx" +FEATURES_MTX = "features.tsv.gz" +BARCODE_MTX = "barcodes.tsv.gz" +MATRIX_MTX = "matrix.mtx.gz" diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 609858c..6994539 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -24,10 +24,32 @@ import polars as pl from scipy import io from cite_seq_count import secondsToText +from cite_seq_count.constants import ( + FEATURE_NAME_COLUMN, + BARCODE_COLUMN, + BARCODE_ID_COLUMN, + FEATURE_ID_COLUMN, + COUNT_COLUMN, + MTX_HEADER, + FEATURES_MTX, + BARCODE_MTX, + MATRIX_MTX, + TEMP_MTX, + UNMAPPED_NAME, + SEQUENCE_COLUMN, + WHITELIST_COLUMN, +) JSON_REPORT_PATH = pkg_resources.resource_filename(__name__, "templates/report.json") +def compress_file(in_file: Path, out_file: Path): + with open(in_file, "rb") as f_in: + with gzip.open(out_file, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + os.remove(in_file) + + def blocks(file: Path, size: int = 65536): """ A fast way of counting the lines of a large file. @@ -467,6 +489,47 @@ def write_chunks_to_disk( ) +def write_data_to_mtx( + main_df: pl.DataFrame, + tags_df: pl.DataFrame, + barcodes_df: pl.DataFrame, + data_type: str, + outpath: str, +) -> None: + tags_indexed = pl.concat( + [ + tags_df, + pl.DataFrame( + {FEATURE_NAME_COLUMN: UNMAPPED_NAME, SEQUENCE_COLUMN: "UNKNOWN"} + ), + ] + ).with_row_count(offset=1, name=FEATURE_ID_COLUMN) + barcodes_indexed = barcodes_df.with_row_count(offset=1, name=BARCODE_ID_COLUMN) + mtx_df = ( + tags_indexed.join(main_df, on=FEATURE_NAME_COLUMN) + .join(barcodes_indexed, left_on=BARCODE_COLUMN, right_on=WHITELIST_COLUMN) + .select([FEATURE_ID_COLUMN, BARCODE_ID_COLUMN, COUNT_COLUMN]) + ) + data_path = Path(outpath) / f"{data_type}_count" + mtx_df.write_csv(include_header=False, file=data_path / TEMP_MTX, separator="\t") + # Write out the full MTX matrix + with open(data_path / TEMP_MTX, "r") as mtx_in: + mtx_main = mtx_in.read() + final_mtx = MTX_HEADER + mtx_main + with open(data_path / MATRIX_MTX, "wb") as mtx_out: + mtx_out.write(final_mtx.encode()) + os.remove(data_path / TEMP_MTX) + # Write ouf features and barcodes + tags_indexed.sort(FEATURE_ID_COLUMN).select(FEATURE_NAME_COLUMN).write_csv( + file=data_path / "features.csv", include_header=False + ) + compress_file(data_path / "features.csv", FEATURES_MTX) + barcodes_indexed.sort(BARCODE_ID_COLUMN).select(WHITELIST_COLUMN).write_csv( + file=data_path / "barcodes.csv", include_header=False + ) + compress_file(data_path / "barcodes.csv", BARCODE_MTX) + + def write_mapping_input( args: ArgumentParser, read1_paths: list[Path], diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py index 4226268..3eabe9d 100644 --- a/cite_seq_count/mapping.py +++ b/cite_seq_count/mapping.py @@ -7,13 +7,11 @@ SEQUENCE_COLUMN, R2_COLUMN, FEATURE_NAME_COLUMN, - BARCODE_COLUMN, - UMI_COLUMN, UNMAPPED_NAME, ) -def find_best_match_rapid(tag_seq, tags_list, maximum_distance): +def find_best_match_fast(tag_seq, tags_list, maximum_distance): choices = tags_list[SEQUENCE_COLUMN].to_list() features = tags_list[FEATURE_NAME_COLUMN].to_list() res = process.extractOne(choices=choices, query=tag_seq, scorer=fuzz.QRatio) @@ -23,36 +21,21 @@ def find_best_match_rapid(tag_seq, tags_list, maximum_distance): return UNMAPPED_NAME -def match_generic_string_dfs( - ref_df: pl.DataFrame, - target_df: pl.DataFrame, - left_on: str, - right_on: str, - hamming_distance: int, -): - corrected_column = "corrected_" + left_on - joined = ( - target_df.sort(left_on) - .join_asof(ref_df.sort(right_on), left_on=left_on, right_on=right_on) - .with_columns( - pl.when(pl.col(left_on) == pl.col(right_on)) - .then(False) - .otherwise(True) - .alias(corrected_column) - ) - .with_columns( - pl.when(pl.col(corrected_column)) - .then(distance.Hamming.distance(s1=pl.col(left_on), s2=pl.col(right_on))) - .otherwise(0) - .alias("hamming_distance") - ) - ) - return joined - - def map_reads_hybrid( r2_df: pl.DataFrame, parsed_tags: pl.DataFrame, maximum_distance: int ) -> pl.DataFrame: + """Map sequence data to a tags reference. + Using a hybdrid approach where we first join all the data for the exact matches + then using a hamming distance calculation to find the closest match + + Args: + r2_df (pl.DataFrame): All r2 sequences to map + parsed_tags (pl.DataFrame): tags to map to + maximum_distance (int): max distance allowed for mismatches + + Returns: + pl.DataFrame: Mapped data + """ print("Mapping reads") mapped_r2_df = r2_df.join( parsed_tags, left_on=R2_COLUMN, right_on=SEQUENCE_COLUMN, how="left" @@ -61,7 +44,7 @@ def map_reads_hybrid( .then( pl.col(R2_COLUMN) .map_elements( - lambda x: find_best_match_rapid( + lambda x: find_best_match_fast( x, tags_list=parsed_tags, maximum_distance=maximum_distance ) ) diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 6ce4346..a7f0c18 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -9,21 +9,22 @@ from pathlib import Path import Levenshtein import polars as pl +import numpy as np +import numpy.matlib as npm import umi_tools.whitelist_methods as whitelist_method from cite_seq_count.io import get_n_lines, check_file from cite_seq_count.constants import ( SEQUENCE_COLUMN, R2_COLUMN, - FEATURE_NAME_COLUMN, BARCODE_COLUMN, UMI_COLUMN, - UNMAPPED_NAME, REQUIRED_TAGS_HEADER, REFERENCE_COLUMN, TRANSLATION_COLUMN, OPTIONAL_CELLS_REF_HEADER, STRIP_CHARS, - WHITELIST_COLUMN + WHITELIST_COLUMN, + COUNT_COLUMN, ) @@ -134,7 +135,6 @@ def parse_tags_csv(file_name: str) -> pl.DataFrame: file_type="tags", expected_pattern="ATGC", ) - return data_pl @@ -338,7 +338,23 @@ def pre_run_checks( return n_reads, r2_min_length, maximum_distance -def split_data_input(mapping_input_path: Path): +def split_data_input( + mapping_input_path: Path, +) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: + """Read in all the input data and split it into three dataframes. + + Reduce the size of the data by grouping on barcodes, umis and sequences. + The function splits the data into three main dataframes. + 1. main_df is the dataframe holding the links between the other dfs + 2. barcodes_df holds only the barcode information for barcode correction + 3. r2_df only holds the sequences for mapping + + Args: + mapping_input_path (Path): Path to the csv file containing all the input data + + Returns: + tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: Three dfs described above + """ main_df = ( pl.read_csv( mapping_input_path, @@ -350,9 +366,9 @@ def split_data_input(mapping_input_path: Path): ) barcodes_df = ( - main_df.select([BARCODE_COLUMN, "count"]) + main_df.select([BARCODE_COLUMN, COUNT_COLUMN]) .group_by(BARCODE_COLUMN) - .agg(pl.sum("count")) + .agg(pl.sum(COUNT_COLUMN)) ) r2_df = main_df.select(R2_COLUMN).unique() @@ -361,23 +377,33 @@ def split_data_input(mapping_input_path: Path): def get_barcode_subset( barcode_whitelist: Path, - expected_barcodes: int, + n_barcodes: int, chemistry, barcode_reference: pl.DataFrame | None, barcodes_df: pl.DataFrame, ): - """ - Generate the barcode list used for barcode correction and subsetting + """Generate the barcode list used for barcode correction and subsetting + + Args: + barcode_whitelist (Path): Barcode whitelist df (can hold translation column) + expected_barcodes (int): Number of expected barcodes from user + chemistry (_type_): Chemistry definition + barcode_reference (pl.DataFrame | None): Specific subset given by the user + barcodes_df (pl.DataFrame): Barcodes from the input data + + Returns: + _type_: _description_ """ enable_barcode_correction = True + # Whitelist: True if barcode_whitelist: barcode_subset = parse_barcode_reference( - filename=expected_barcodes, + filename=barcode_whitelist, barcode_length=(chemistry.cell_barcode_end - chemistry.cell_barcode_start), required_header=WHITELIST_COLUMN, ) else: - n_barcodes = barcode_whitelist + # Whitelist: False, Reference: True if barcode_reference is not None: barcode_subset = ( barcodes_df.filter( @@ -393,21 +419,8 @@ def get_barcode_subset( .rename({SEQUENCE_COLUMN: WHITELIST_COLUMN}) ) else: - raw_barcodes_dict = ( - barcodes_df.filter(~pl.col(BARCODE_COLUMN).str.contains("N")) - .group_by(BARCODE_COLUMN) - .agg(pl.count()) - .sort("count", descending=True) - ).to_dict() - barcode_counter = Counter( - zip(raw_barcodes_dict[BARCODE_COLUMN], raw_barcodes_dict["count"]) - ) - true_barcodes = whitelist_method.getKneeEstimateDistance( - cell_barcode_counts=barcode_counter, cell_number=n_barcodes - ) - barcode_subset = pl.DataFrame( - true_barcodes, schema={WHITELIST_COLUMN: pl.Utf8, "counts": pl.UInt32} - ).drop("counts") + # Whitelist: False, Reference: False + barcode_subset = find_knee_estimated_barcodes(barcodes_df=barcodes_df) if n_barcodes > barcode_subset.shape[0]: print( @@ -417,3 +430,17 @@ def get_barcode_subset( ) enable_barcode_correction = False return barcode_subset, enable_barcode_correction + + +def find_knee_estimated_barcodes(barcodes_df): + raw_barcodes_dict = barcodes_df.filter( + ~pl.col(BARCODE_COLUMN).str.contains("N") + ).sort("count", descending=True) + barcode_counter = Counter() + barcode_counts = dict(raw_barcodes_dict.iter_rows()) + barcode_counter.update(barcode_counts) + true_barcodes = whitelist_method.getKneeEstimateDistance( + cell_barcode_counts=barcode_counter + ) + barcode_subset = pl.DataFrame(true_barcodes, schema={WHITELIST_COLUMN: pl.Utf8}) + return barcode_subset diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 5c9c70e..7de0397 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -20,44 +20,12 @@ WHITELIST_COLUMN, BARCODE_COLUMN, CORRECTED_BARCODE_COLUMN, + FEATURE_NAME_COLUMN, R2_COLUMN, UMI_COLUMN, + COUNT_COLUMN, ) -# Unit Barcode correction - - -def find_original_barcode(barcode: str, barcode_tree: pybktree.BKTree, distance: int): - """Pare a BKtree to find the original barcode to correct to. - - Args: - barcode (str): barcode to be corrected - barcode_tree (pybktree.BKTree): Barcode whitelist BKTree - distance (int): Hamming distance to search for - - Returns: - barcode(str): corrected barcode - """ - candidates = [ - white_cell for d, white_cell in barcode_tree.find(barcode, distance) if d > 0 - ] - if len(candidates) == 1: - barcode = candidates[0] - return barcode - - -def merge_results( - mapped_r2_df: pl.DataFrame, - corrected_barcodes_df: pl.DataFrame, - input_df: pl.DataFrame, -): - merged = ( - input_df.join(mapped_r2_df, on=R2_COLUMN, how="inner") - .join(corrected_barcodes_df.drop("count"), on=BARCODE_COLUMN, how="inner") - .drop([R2_COLUMN, BARCODE_COLUMN]) - ) - return merged - def correct_barcodes_pl( barcodes_df: pl.DataFrame, @@ -80,7 +48,7 @@ def correct_barcodes_pl( corrected_barcodes_pl = pl.DataFrame( schema={ BARCODE_COLUMN: pl.Utf8, - "count": pl.Int64, + "count": pl.UInt32, WHITELIST_COLUMN: pl.Utf8, "hamming_distance": pl.UInt8, } @@ -197,135 +165,23 @@ def correct_umis_in_cells(umi_correction_input): return (final_results, corrected_umis, clustered_cells) -def update_umi_counts(UMIclusters, cell_tag_counts): - """ - Update a dict object with umis corrected. - - Args: - UMIclusters (list): List of lists with corrected umis - cell_tag_counts (Counter): Counter of umis - - Returns: - cell_tag_counts (Counter): Updated Counter of umis - temp_corrected_umis (int): Number of corrected umis - """ - temp_corrected_umis = 0 - for ( - umi_cluster - ) in UMIclusters: # This is a list with the first element the dominant barcode - if len(umi_cluster) > 1: # This means we got a correction - major_umi = umi_cluster[0] - for minor_umi in umi_cluster[1:]: - temp_corrected_umis += 1 - temp = cell_tag_counts.pop(minor_umi) - cell_tag_counts[major_umi] += temp - return (cell_tag_counts, temp_corrected_umis) - - -def run_umi_correction(final_results, filtered_cells, unmapped_id, args): - input_queue = [] - umi_correction_input = namedtuple( - "umi_correction_input", - ["cells", "collapsing_threshold", "max_umis", "unmapped_id"], - ) - cells_results = {} - n_cells = 0 - num_chunks = 0 - - print("preparing UMI correction jobs") - cell_batch_size = round(len(filtered_cells) / args.n_threads) + 1 - for cell in filtered_cells: - cells_results[cell] = final_results.pop(cell) - n_cells += 1 - if n_cells % cell_batch_size == 0: - input_queue.append( - umi_correction_input( - cells=cells_results, - collapsing_threshold=args.umi_threshold, - max_umis=20000, - unmapped_id=unmapped_id, - ) - ) - cells_results = {} - num_chunks += 1 - - del final_results - - input_queue.append( - umi_correction_input( - cells=cells_results, - collapsing_threshold=args.umi_threshold, - max_umis=20000, - unmapped_id=unmapped_id, - ) - ) - parallel_results = [] - if args.n_threads != 1: - pool = Pool(processes=args.n_threads) - errors = [] - correct_umis = pool.map_async( - correct_umis_in_cells, - input_queue, - callback=parallel_results.append, - error_callback=errors.append, +def generate_mtx_counts( + main_df: pl.DataFrame, + barcode_subset: pl.DataFrame, + mapped_r2_df: pl.DataFrame, + data_type: str, +) -> pl.DataFrame: + if data_type == "read": + return ( + main_df.join(barcode_subset, on=BARCODE_COLUMN) + .join(mapped_r2_df, on=R2_COLUMN) + .group_by([BARCODE_COLUMN, FEATURE_NAME_COLUMN]) + .agg(pl.sum(COUNT_COLUMN)) ) - - correct_umis.wait() - pool.close() - pool.join() - - if len(errors) != 0: - for error in errors: - print("There was an error {}", error) else: - single_thread_result = correct_umis_in_cells(input_queue[0]) - parallel_results.append([single_thread_result]) - final_results = {} - umis_corrected = 0 - clustered_cells = set() - for chunk in parallel_results[0]: - (temp_results, temp_umis, temp_clustered_cells) = chunk - final_results.update(temp_results) - umis_corrected += temp_umis - clustered_cells.update(temp_clustered_cells) - - return final_results, umis_corrected, clustered_cells - - -def generate_sparse_matrices( - final_results, parsed_tags, filtered_cells, umi_counts=False -): - """ - Create two sparse matrices with umi and read counts. - - Args: - final_results (dict): Results in a dict of dicts of Counters. - parsed_tags (list): Ordered tags in a list of tuples. - - Returns: - results_matrix (scipy.sparse.dok_matrix): UMI or Read counts - - - """ - unmapped_id = len(parsed_tags) - if umi_counts: - n_features = len(parsed_tags) - else: - n_features = len(parsed_tags) + 1 - results_matrix = sparse.dok_matrix((n_features, len(filtered_cells)), dtype=int32) - - for i, cell_barcode in enumerate(filtered_cells): - if cell_barcode not in final_results.keys(): - continue - for TAG_id in final_results[cell_barcode]: - # if TAG_id in final_results[cell_barcode]: - if umi_counts: - if TAG_id == unmapped_id: - continue - else: - results_matrix[TAG_id, i] = len(final_results[cell_barcode][TAG_id]) - else: - results_matrix[TAG_id, i] = sum( - final_results[cell_barcode][TAG_id].values() - ) - return results_matrix + return ( + main_df.join(barcode_subset, on=BARCODE_COLUMN) + .join(mapped_r2_df, on=R2_COLUMN) + .group_by([BARCODE_COLUMN, FEATURE_NAME_COLUMN]) + .agg(pl.count()) + ) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 13748c6..2e1a33f 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -5,7 +5,12 @@ import pytest import polars as pl from polars.testing import assert_frame_equal -from cite_seq_count.preprocessing import parse_barcode_reference, parse_tags_csv, check_tags +from cite_seq_count.preprocessing import ( + parse_barcode_reference, + parse_tags_csv, + check_tags, + find_knee_estimated_barcodes, +) from cite_seq_count.constants import REFERENCE_COLUMN @@ -54,6 +59,12 @@ def data(): ], } ) + pytest.barcodes_df = pl.DataFrame( + { + "barcode": ["ATGCCC", "ATGCTT", "CCGCCC", "ATATCC", "ATATGG"], + "count": [200, 200, 200, 20, 10], + } + ) pytest.barcode_slice = slice(0, 16) pytest.umi_slice = slice(16, 26) pytest.barcode_umi_length = 26 @@ -91,3 +102,7 @@ def test_parse_reference_list_csv(data): def test_check_distance_too_big_between_tags(data): with pytest.raises(SystemExit): check_tags(pytest.correct_tag_pl, 8) + + +def test_find_knee_estimated_barcodes(data): + find_knee_estimated_barcodes(barcodes_df=pytest.barcodes_df) From 292a232af2b111d91ae77386c38ee21ee0eac4b9 Mon Sep 17 00:00:00 2001 From: hoohm Date: Sat, 30 Dec 2023 16:21:56 +0100 Subject: [PATCH 67/77] (test): Tests for IO and preprocessing --- .gitignore | 2 + cite_seq_count/__main__.py | 30 +- cite_seq_count/argsparser.py | 70 ++--- cite_seq_count/chemistry.py | 23 +- cite_seq_count/constants.py | 2 +- cite_seq_count/io.py | 84 ++--- cite_seq_count/preprocessing.py | 140 +++++---- cite_seq_count/processing.py | 39 +-- pytest.ini | 4 + setup.py | 6 +- .../filtered_lists/pass/normal_ref.csv | 2 - .../reference_lists/pass/translation.csv | 4 +- .../fail/different_length.csv | 0 .../fail/with_header.csv | 0 .../fail/wrong_barcode.csv | 0 .../reference_subsets/pass/normal_ref.csv | 3 + .../test_data/whitelists/short_whitelist.csv | 3 + tests/test_io.py | 286 ++++++++++-------- tests/test_preprocessing.py | 249 ++++++++++++--- 19 files changed, 583 insertions(+), 364 deletions(-) create mode 100644 pytest.ini delete mode 100644 tests/test_data/filtered_lists/pass/normal_ref.csv rename tests/test_data/{filtered_lists => reference_subsets}/fail/different_length.csv (100%) rename tests/test_data/{filtered_lists => reference_subsets}/fail/with_header.csv (100%) rename tests/test_data/{filtered_lists => reference_subsets}/fail/wrong_barcode.csv (100%) create mode 100644 tests/test_data/reference_subsets/pass/normal_ref.csv create mode 100644 tests/test_data/whitelists/short_whitelist.csv diff --git a/.gitignore b/.gitignore index a6f833c..f52ee7f 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ CITE_seq_Count.egg-info/ __pycache__ *.pyc .vscode +.ruff_cache +.pytest_cache diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 7f22086..2675d33 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -6,7 +6,17 @@ import os import time -from cite_seq_count import preprocessing, argsparser, mapping, processing, chemistry, io +from scipy import constants + +from cite_seq_count import ( + preprocessing, + argsparser, + mapping, + processing, + chemistry, + io, + constants, +) def main(): @@ -33,7 +43,15 @@ def main(): ) # Get chemistry defs - (barcode_reference, chemistry_def) = chemistry.setup_chemistry(args) + barcode_reference, chemistry_def = chemistry.setup_chemistry(args) + if args.subset_path is not None: + barcode_subset = preprocessing.parse_barcode_file( + filename=args.subset_path, + barcode_length=chemistry_def.barcode_length, + required_header=[constants.REFERENCE_COLUMN], + ) + else: + barcode_subset = None # Load TAGs/ABs. parsed_tags = preprocessing.parse_tags_csv(args.tags) @@ -47,7 +65,7 @@ def main(): read1_paths=read1_paths, chemistry_def=chemistry_def, longest_tag_len=longest_tag_len, - args=args, + arguments=args, ) ( temp_file, @@ -73,7 +91,7 @@ def main(): ) barcode_subset, enable_barcode_correction = preprocessing.get_barcode_subset( - barcode_whitelist=args.filtered_barcodes, + barcode_subset=barcode_subset, n_barcodes=args.expected_barcodes, chemistry=chemistry_def, barcode_reference=barcode_reference, @@ -107,7 +125,7 @@ def main(): io.write_data_to_mtx( main_df=read_counts, tags_df=parsed_tags, - barcodes_df=barcode_subset, + subset_df=barcode_subset, data_type="read", outpath=args.outfolder, ) @@ -130,7 +148,7 @@ def main(): io.write_data_to_mtx( main_df=read_counts, tags_df=parsed_tags, - barcodes_df=barcode_subset, + subset_df=barcode_subset, data_type="umi", outpath=args.outfolder, ) diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 8b51455..404836b 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -23,21 +23,6 @@ def get_package_version(): return version -def chunk_size_limit(chunk_size: int) -> int: - """Validates chunk_size limits""" - max_size = 2147483647 - try: - chunk_value = int(chunk_size) - except ValueError: - raise SystemExit("Chunk size must be an int") - if chunk_value < 1 or chunk_value > max_size: - raise ArgumentTypeError( - "Argument must be < " + str(max_size) + "and > " + str(1) - ) - else: - return chunk_value - - def thread_default() -> int: """ Set number of threads default. @@ -186,25 +171,24 @@ def get_args() -> ArgumentParser: default=0, ) barcodes.add_argument( - "-fb", - "-wl", - "--filtered_barcodes", - dest="filtered_barcodes", + "-sub", + "--subset", + dest="subset_path", type=str, help=( - "A path to a specific list of barcodes to look for." + "A path to a subset list of barcodes to look for." "\tExample:\n" - "\twhitelist\n" + "\treference\n" "\tAAACCCAAGAAACACT\nAAACCCAAGAAACCAT\nAAACCCAAGAAACCCA\n" ), - default=False, + default=None, ) if "--chemistry" not in sys.argv: barcodes.add_argument( - "-br", - "--barcode_reference", - dest="barcode_reference", + "-ref", + "--reference", + dest="reference", required=False, type=str, default=False, @@ -239,14 +223,14 @@ def get_args() -> ArgumentParser: help=("Number of bases to discard from read2."), ) - filters.add_argument( - "--sliding-window", - dest="sliding_window", - required=False, - default=False, - action="store_true", - help=("Allow for a sliding window when aligning."), - ) + # filters.add_argument( + # "--sliding-window", + # dest="sliding_window", + # required=False, + # default=False, + # action="store_true", + # help=("Allow for a sliding window when aligning."), + # ) # Parallel group. parallel = parser.add_argument_group( @@ -263,15 +247,6 @@ def get_args() -> ArgumentParser: default=thread_default(), help=("How many threads are to be used for running the program"), ) - parallel.add_argument( - "-C", - "--chunk_size", - required=False, - type=chunk_size_limit, - dest="chunk_size", - help=("How many reads should be sent to a child process at a time"), - ) - # Global group parser.add_argument( "--temp_path", @@ -301,14 +276,6 @@ def get_args() -> ArgumentParser: dest="outfolder", help=("Results will be written to this folder"), ) - parser.add_argument( - "--dense", - required=False, - action="store_true", - default=False, - dest="dense", - help=("Add a dense output to the results folder"), - ) parser.add_argument( "-u", "--unmapped-tags", @@ -327,9 +294,6 @@ def get_args() -> ArgumentParser: default=100, help=("Top n unmapped TAGs."), ) - parser.add_argument( - "--debug", action="store_true", help=("Print extra information for debugging.") - ) parser.add_argument( "--version", action="version", diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index 340bb45..d38a886 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -5,8 +5,8 @@ from dataclasses import dataclass -from argparse import ArgumentParser -from cite_seq_count.preprocessing import parse_barcode_reference +from argparse import Namespace +from cite_seq_count.preprocessing import parse_barcode_file import polars as pl GLOBAL_LINK_RAW = "https://raw.githubusercontent.com/Hoohm/scg_lib_structs/10xv3_totalseq_b/chemistries/" @@ -26,6 +26,9 @@ class Chemistry: r2_trim_start: int barcode_reference_path: str + def __post_init__(self): + self.barcode_length = self.cell_barcode_end - self.umi_barcode_start + DEFINITIONS_DB = pooch.create( path=pooch.os_cache("cite_seq_count"), @@ -53,7 +56,7 @@ def fetch_definitions() -> dict: return json_data -def list_chemistries(all_chemistry_defs: str) -> None: +def list_chemistries(all_chemistry_defs: dict) -> None: """ List all the available chemistries in the database Args: @@ -117,7 +120,7 @@ def get_chemistry_definition(chemistry_short_name: str) -> Chemistry: return chemistry_def -def create_chemistry_definition(args: ArgumentParser) -> Chemistry: +def create_chemistry_definition(args: Namespace) -> Chemistry: chemistry_def = Chemistry( name="custom", cell_barcode_start=args.cb_first, @@ -125,15 +128,15 @@ def create_chemistry_definition(args: ArgumentParser) -> Chemistry: umi_barcode_start=args.umi_first, umi_barcode_end=args.umi_last, r2_trim_start=args.start_trim, - barcode_reference_path=args.barcode_reference, + barcode_reference_path=args.reference, ) return chemistry_def -def setup_chemistry(args: ArgumentParser) -> tuple[pl.DataFrame | None, Chemistry]: +def setup_chemistry(args: Namespace) -> tuple[pl.DataFrame | None, Chemistry]: if args.chemistry_id: chemistry_def = get_chemistry_definition(args.chemistry_id) - barcode_reference = parse_barcode_reference( + barcode_reference = parse_barcode_file( filename=chemistry_def.barcode_reference_path, barcode_length=chemistry_def.cell_barcode_end - chemistry_def.cell_barcode_start @@ -142,10 +145,10 @@ def setup_chemistry(args: ArgumentParser) -> tuple[pl.DataFrame | None, Chemistr ) else: chemistry_def = create_chemistry_definition(args) - if args.barcode_reference: + if args.reference: print("Loading barcode reference") - barcode_reference = parse_barcode_reference( - filename=args.barcode_reference, + barcode_reference = parse_barcode_file( + filename=args.reference, barcode_length=args.cb_last - args.cb_first + 1, required_header=["reference"], ) diff --git a/cite_seq_count/constants.py b/cite_seq_count/constants.py index 40fa6c7..d71ed31 100644 --- a/cite_seq_count/constants.py +++ b/cite_seq_count/constants.py @@ -17,7 +17,7 @@ # Barcode input REFERENCE_COLUMN = "reference" TRANSLATION_COLUMN = "translation" -WHITELIST_COLUMN = "whitelist" +SUBSET_COLUMN = "subset" STRIP_CHARS = '"0123456789- \t\n' # MTX format diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 6994539..5ffe4d2 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -10,9 +10,9 @@ import json from collections import namedtuple, Counter -from argparse import ArgumentParser +from argparse import Namespace from itertools import islice -from typing import Tuple +from typing import Tuple, TextIO from pathlib import Path from os import access, R_OK @@ -37,7 +37,7 @@ TEMP_MTX, UNMAPPED_NAME, SEQUENCE_COLUMN, - WHITELIST_COLUMN, + SUBSET_COLUMN, ) JSON_REPORT_PATH = pkg_resources.resource_filename(__name__, "templates/report.json") @@ -50,7 +50,7 @@ def compress_file(in_file: Path, out_file: Path): os.remove(in_file) -def blocks(file: Path, size: int = 65536): +def blocks(file: TextIO, size: int = 65536): """ A fast way of counting the lines of a large file. Ref: @@ -92,7 +92,7 @@ def get_n_lines(file_path: Path) -> int: return n_lines -def get_read_paths(read1_path: str, read2_path: str) -> Tuple[str, str]: +def get_read_paths(read1_path: str, read2_path: str) -> Tuple[list[Path], list[Path]]: """ Splits up 2 comma-separated strings of input files into list of files to process. Ensures both lists are equal in length. @@ -104,14 +104,14 @@ def get_read_paths(read1_path: str, read2_path: str) -> Tuple[str, str]: _read1_path (list(string)): list of paths to read1.fq _read2_path (list(string)): list of paths to read2.fq """ - _read1_path = read1_path.split(",") - _read2_path = read2_path.split(",") - if len(_read1_path) != len(_read2_path): + _read1_paths = [Path(path) for path in read1_path.split(",")] + _read2_paths = [Path(path) for path in read2_path.split(",")] + if len(_read1_paths) != len(_read2_paths): sys.exit( - f"Unequal number of read1 ({len(_read1_path)}) and read2({len(_read2_path)}) files provided" + f"Unequal number of read1 ({len(_read1_paths)}) and read2({len(_read2_paths)}) files provided" "\n Exiting" ) - all_files = _read1_path + _read2_path + all_files = _read1_paths + _read2_paths for file_path in all_files: if os.path.isfile(file_path): if os.access(file_path, os.R_OK): @@ -121,10 +121,10 @@ def get_read_paths(read1_path: str, read2_path: str) -> Tuple[str, str]: else: sys.exit(f"{file_path} does not exist. Exiting") - return (_read1_path, _read2_path) + return (_read1_paths, _read2_paths) -def get_csv_reader_from_path(filename: str, sep: str = "\t") -> csv.reader: +def get_csv_reader_from_path(filename: str, sep: str = "\t"): """ Returns a csv_reader object for a file weather it's a flat file or compressed. @@ -203,6 +203,10 @@ def check_file(file_str: str) -> Path: file_path = Path(file_str) if file_path.exists and access(file_path, R_OK): return file_path + else: + raise FileNotFoundError( + f"This file {file_path} does not exist or is not accessible" + ) def write_dense( @@ -280,7 +284,7 @@ def create_report( bad_cells, r1_too_short: int, r2_too_short: int, - args: ArgumentParser, + args: Namespace, chemistry_def, maximum_distance: int, ): @@ -344,7 +348,7 @@ def create_report( def write_chunks_to_disk( - args: ArgumentParser, + args: Namespace, read1_paths: list[Path], read2_paths: list[Path], r2_min_length: int, @@ -488,29 +492,39 @@ def write_chunks_to_disk( total_reads, ) - +def create_mtx_df(main_df:pl.DataFrame, tags_df:pl.DataFrame, subset_df:pl.DataFrame): + tags_indexed = ( + pl.concat( + [ + tags_df, + pl.DataFrame( + {FEATURE_NAME_COLUMN: UNMAPPED_NAME, SEQUENCE_COLUMN: "UNKNOWN"} + ), + ] + ) + .sort(pl.col(FEATURE_NAME_COLUMN)) + .with_row_count(offset=1, name=FEATURE_ID_COLUMN) + ) + barcodes_indexed = subset_df.sort(pl.col(SUBSET_COLUMN)).with_row_count( + offset=1, name=BARCODE_ID_COLUMN + ) + mtx_df = ( + tags_indexed.join(main_df, on=FEATURE_NAME_COLUMN) + .join(barcodes_indexed, left_on=BARCODE_COLUMN, right_on=SUBSET_COLUMN) + .select([FEATURE_ID_COLUMN, BARCODE_ID_COLUMN, COUNT_COLUMN]) + ) + return mtx_df, tags_indexed, barcodes_indexed + def write_data_to_mtx( main_df: pl.DataFrame, tags_df: pl.DataFrame, - barcodes_df: pl.DataFrame, + subset_df: pl.DataFrame, data_type: str, outpath: str, ) -> None: - tags_indexed = pl.concat( - [ - tags_df, - pl.DataFrame( - {FEATURE_NAME_COLUMN: UNMAPPED_NAME, SEQUENCE_COLUMN: "UNKNOWN"} - ), - ] - ).with_row_count(offset=1, name=FEATURE_ID_COLUMN) - barcodes_indexed = barcodes_df.with_row_count(offset=1, name=BARCODE_ID_COLUMN) - mtx_df = ( - tags_indexed.join(main_df, on=FEATURE_NAME_COLUMN) - .join(barcodes_indexed, left_on=BARCODE_COLUMN, right_on=WHITELIST_COLUMN) - .select([FEATURE_ID_COLUMN, BARCODE_ID_COLUMN, COUNT_COLUMN]) - ) + mtx_df, tags_indexed, barcodes_indexed = create_mtx_df(main_df, tags_df, subset_df) data_path = Path(outpath) / f"{data_type}_count" + data_path.mkdir(parents=True, exist_ok=True) mtx_df.write_csv(include_header=False, file=data_path / TEMP_MTX, separator="\t") # Write out the full MTX matrix with open(data_path / TEMP_MTX, "r") as mtx_in: @@ -523,15 +537,15 @@ def write_data_to_mtx( tags_indexed.sort(FEATURE_ID_COLUMN).select(FEATURE_NAME_COLUMN).write_csv( file=data_path / "features.csv", include_header=False ) - compress_file(data_path / "features.csv", FEATURES_MTX) - barcodes_indexed.sort(BARCODE_ID_COLUMN).select(WHITELIST_COLUMN).write_csv( + compress_file(data_path / "features.csv", data_path / FEATURES_MTX) + barcodes_indexed.sort(BARCODE_ID_COLUMN).select(SUBSET_COLUMN).write_csv( file=data_path / "barcodes.csv", include_header=False ) - compress_file(data_path / "barcodes.csv", BARCODE_MTX) + compress_file(data_path / "barcodes.csv", data_path / BARCODE_MTX) def write_mapping_input( - args: ArgumentParser, + args: Namespace, read1_paths: list[Path], read2_paths: list[Path], r2_min_length: int, @@ -568,7 +582,7 @@ def write_mapping_input( temp_file = tempfile.NamedTemporaryFile( "w", dir=temp_path, suffix="_csc", delete=False ) - temp_file_path = temp_file.name + temp_file_path = Path(temp_file.name) reads_written = 0 for read1_path, read2_path in zip(read1_paths, read2_paths): diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index a7f0c18..bc0d219 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -5,12 +5,10 @@ import sys from itertools import combinations, islice from collections import Counter -from argparse import ArgumentParser +from argparse import Namespace from pathlib import Path import Levenshtein import polars as pl -import numpy as np -import numpy.matlib as npm import umi_tools.whitelist_methods as whitelist_method from cite_seq_count.io import get_n_lines, check_file from cite_seq_count.constants import ( @@ -23,15 +21,31 @@ TRANSLATION_COLUMN, OPTIONAL_CELLS_REF_HEADER, STRIP_CHARS, - WHITELIST_COLUMN, + SUBSET_COLUMN, COUNT_COLUMN, ) -def parse_barcode_reference( - filename: str, barcode_length: int, required_header: str +def check_equi_length(df: pl.DataFrame, column_name: str): + """Check that all the barcodes in the specified column of a polars DataFrame are the same length. + + Args: + df (pl.DataFrame): The DataFrame containing the barcodes. + column_name (str): The name of the column containing the barcodes. + + Raises: + ValueError: If the barcodes have different lengths. + + """ + barcode_lengths = df[column_name].str.len_chars().unique() + if len(barcode_lengths) > 1: + raise ValueError(f"Barcodes in {column_name} column have different lengths.") + + +def parse_barcode_file( + filename: str, barcode_length: int, required_header: list ) -> pl.DataFrame: - """Reads white-listed barcodes from a CSV file. + """Reads reference barcodes from a CSV file. The function accepts plain barcodes or even 10X style barcodes with the `-1` at the end of each barcode. @@ -41,14 +55,14 @@ def parse_barcode_reference( barcode_length (int): Length of the expected barcodes. Returns: - set: The set of white-listed barcodes. + set: The set of reference barcodes. """ file_path = check_file(filename) - barcodes_pl = pl.read_csv(file_path.absolute()) - barcode_pattern = rf"^[ATGC]{{{barcode_length}}}" + barcodes_df = pl.read_csv(file_path.absolute()) + barcode_pattern = "^[ATGC]{1,}$" - header = barcodes_pl.columns + header = barcodes_df.columns set_dif = set(required_header) - set(header) if len(set_dif) != 0: set_diff_string = ",".join(list(set_dif)) @@ -57,36 +71,41 @@ def parse_barcode_reference( with_translation = True else: with_translation = False - # Prepare and validate barcodes_pl + # Prepare and validate barcodes_df if with_translation: - barcodes_pl = barcodes_pl.with_columns( + barcodes_df = barcodes_df.with_columns( reference=pl.col(REFERENCE_COLUMN).str.strip_chars(STRIP_CHARS), translation=pl.col(TRANSLATION_COLUMN).str.strip_chars(STRIP_CHARS), ) else: - barcodes_pl = barcodes_pl.with_columns( + barcodes_df = barcodes_df.with_columns( reference=pl.col(REFERENCE_COLUMN).str.strip_chars(STRIP_CHARS), ) check_sequence_pattern( - df=barcodes_pl, + df=barcodes_df, pattern=barcode_pattern, column_name=REFERENCE_COLUMN, file_type="Barcode reference", expected_pattern="ATGC", + filename=filename, ) + check_equi_length(df=barcodes_df, column_name=REFERENCE_COLUMN) + if with_translation: check_sequence_pattern( - df=barcodes_pl, + df=barcodes_df, pattern=barcode_pattern, column_name=TRANSLATION_COLUMN, file_type="Barcode reference", expected_pattern="ATGC", + filename=filename, ) + check_equi_length(df=barcodes_df, column_name=TRANSLATION_COLUMN) - return barcodes_pl + return barcodes_df def parse_tags_csv(file_name: str) -> pl.DataFrame: @@ -134,6 +153,7 @@ def parse_tags_csv(file_name: str) -> pl.DataFrame: column_name=SEQUENCE_COLUMN, file_type="tags", expected_pattern="ATGC", + filename=file_name, ) return data_pl @@ -144,6 +164,7 @@ def check_sequence_pattern( column_name: str, file_type: str, expected_pattern: str, + filename: str, ) -> None: """Check that a column of a polars df matches a given pattern and exit if not @@ -168,9 +189,9 @@ def check_sequence_pattern( ) sequences_str = "\n".join(sequences) raise SystemExit( - f"Some sequences in the {file_type} file is not only composed" - f"of the proper pattern {expected_pattern}.\n" - f"Here are the sequences{sequences_str}" + f"Some sequences in the {file_type} file is not only composed " + f"of {expected_pattern} in the column: {column_name}. " + f"Here are the sequences: {sequences_str}. Filepath: {filename}" ) @@ -221,7 +242,7 @@ def check_tags(tags_pl: pl.DataFrame, maximum_distance: int) -> int: return longest_tag_len -def get_read_length(filename: Path): +def get_read_length(filename: Path) -> int: """Check wether SEQUENCE lengths are consistent in the first 1000 reads from a FASTQ file and return the length. @@ -236,6 +257,7 @@ def get_read_length(filename: Path): with gzip.open(filename, "r") as fastq_file: secondlines = islice(fastq_file, 1, 1000, 4) temp_length = len(next(secondlines).rstrip()) + read_length = 0 for sequence in secondlines: read_length = len(sequence.rstrip()) if temp_length != read_length: @@ -280,9 +302,9 @@ def check_barcodes_lengths( def pre_run_checks( read1_paths: list[Path], - chemistry_def: dict, + chemistry_def, longest_tag_len: int, - args: ArgumentParser, + arguments: Namespace, ): """Checks that the chemistry is properly set and defines how many reads to process @@ -318,8 +340,8 @@ def pre_run_checks( ) # Get all reads or only top N? - if args.first_n < float("inf"): - n_reads = args.first_n + if arguments.first_n < float("inf"): + n_reads = arguments.first_n else: n_reads = total_reads @@ -329,12 +351,12 @@ def pre_run_checks( if number_of_samples != 1: print(f"Detected {number_of_samples} pairs of files to run on.") - if args.sliding_window: - r2_min_length = read2_lengths[0] - maximum_distance = 0 - else: - r2_min_length = longest_tag_len - maximum_distance = args.max_error + # if arguments.sliding_window: + # r2_min_length = read2_lengths[0] + # maximum_distance = 0 + # else: + r2_min_length = longest_tag_len + maximum_distance = arguments.max_error return n_reads, r2_min_length, maximum_distance @@ -376,50 +398,40 @@ def split_data_input( def get_barcode_subset( - barcode_whitelist: Path, + barcode_reference: pl.DataFrame | None, n_barcodes: int, chemistry, - barcode_reference: pl.DataFrame | None, + barcode_subset: pl.DataFrame | None, barcodes_df: pl.DataFrame, -): +) -> tuple[pl.DataFrame, bool]: """Generate the barcode list used for barcode correction and subsetting Args: - barcode_whitelist (Path): Barcode whitelist df (can hold translation column) + barcode_reference (Path): Barcode reference df (can hold translation column) expected_barcodes (int): Number of expected barcodes from user - chemistry (_type_): Chemistry definition - barcode_reference (pl.DataFrame | None): Specific subset given by the user + chemistry (Chemistry): Chemistry definition + barcode_subset (pl.DataFrame | None): Specific subset given by the user barcodes_df (pl.DataFrame): Barcodes from the input data Returns: - _type_: _description_ + tuple[pl.DataFrame, bool]: Barcode subset, enable barcode correction """ enable_barcode_correction = True - # Whitelist: True - if barcode_whitelist: - barcode_subset = parse_barcode_reference( - filename=barcode_whitelist, - barcode_length=(chemistry.cell_barcode_end - chemistry.cell_barcode_start), - required_header=WHITELIST_COLUMN, - ) - else: - # Whitelist: False, Reference: True + # Subset: True + if barcode_subset is None: + # Subset: False, Reference: True if barcode_reference is not None: barcode_subset = ( barcodes_df.filter( - pl.col(BARCODE_COLUMN).str.is_in( - barcode_reference[REFERENCE_COLUMN] - ) + pl.col(BARCODE_COLUMN).is_in(barcode_reference[REFERENCE_COLUMN]) ) - .group_by(BARCODE_COLUMN) - .agg(pl.count()) - .sort("count", descending=True) - .head(n_barcodes * 1.2) - .drop("count") - .rename({SEQUENCE_COLUMN: WHITELIST_COLUMN}) + .sort(COUNT_COLUMN, descending=True) + .head(round(n_barcodes * 1.2)) + .drop(COUNT_COLUMN) + .rename({BARCODE_COLUMN: SUBSET_COLUMN}) ) else: - # Whitelist: False, Reference: False + # Subset: False, Reference: False barcode_subset = find_knee_estimated_barcodes(barcodes_df=barcodes_df) if n_barcodes > barcode_subset.shape[0]: @@ -432,15 +444,23 @@ def get_barcode_subset( return barcode_subset, enable_barcode_correction -def find_knee_estimated_barcodes(barcodes_df): +def find_knee_estimated_barcodes(barcodes_df: pl.DataFrame) -> pl.DataFrame: + """Find the subset of barcodes by the knee method + + Args: + barcodes_df (pl.DataFrame): barcodes to use + + Returns: + pl.DataFrame: Final list of barcodes + """ raw_barcodes_dict = barcodes_df.filter( ~pl.col(BARCODE_COLUMN).str.contains("N") ).sort("count", descending=True) barcode_counter = Counter() - barcode_counts = dict(raw_barcodes_dict.iter_rows()) + barcode_counts = dict(raw_barcodes_dict.iter_rows()) # type: ignore barcode_counter.update(barcode_counts) true_barcodes = whitelist_method.getKneeEstimateDistance( cell_barcode_counts=barcode_counter ) - barcode_subset = pl.DataFrame(true_barcodes, schema={WHITELIST_COLUMN: pl.Utf8}) + barcode_subset = pl.DataFrame(true_barcodes, schema={SUBSET_COLUMN: pl.Utf8}) return barcode_subset diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 7de0397..a1be51e 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -1,25 +1,13 @@ import os -import Levenshtein -import pybktree import polars as pl from rapidfuzz import distance -from collections import namedtuple - -# pylint: disable=no-name-in-module -from multiprocess import Pool - - -from numpy import int32 -from scipy import sparse from umi_tools import network - from cite_seq_count.constants import ( - WHITELIST_COLUMN, + SUBSET_COLUMN, BARCODE_COLUMN, - CORRECTED_BARCODE_COLUMN, FEATURE_NAME_COLUMN, R2_COLUMN, UMI_COLUMN, @@ -32,7 +20,7 @@ def correct_barcodes_pl( barcode_subset_df: pl.DataFrame, hamming_distance: int, ) -> tuple[pl.DataFrame, int, dict]: - """Corrects barcodes using a whitelist based on join_asof from polars. + """Corrects barcodes using a subset based on join_asof from polars. Uses both forward and backward strategy to dinf the closest barcode Args: @@ -49,30 +37,31 @@ def correct_barcodes_pl( schema={ BARCODE_COLUMN: pl.Utf8, "count": pl.UInt32, - WHITELIST_COLUMN: pl.Utf8, + SUBSET_COLUMN: pl.Utf8, "hamming_distance": pl.UInt8, } ) - methods = ["backward", "forward"] + + methods = ["forward", "backward"] for method in methods: current_barcodes = ( barcodes_df.filter( (~pl.col(BARCODE_COLUMN).is_in(corrected_barcodes_pl[BARCODE_COLUMN])) - & (~pl.col(BARCODE_COLUMN).is_in(barcode_subset_df[WHITELIST_COLUMN])) + & (~pl.col(BARCODE_COLUMN).is_in(barcode_subset_df[SUBSET_COLUMN])) ) .sort(BARCODE_COLUMN) .join_asof( - barcode_subset_df.sort(WHITELIST_COLUMN), + barcode_subset_df.sort(SUBSET_COLUMN), left_on=BARCODE_COLUMN, - right_on=WHITELIST_COLUMN, - strategy=method, + right_on=SUBSET_COLUMN, + strategy=method, # type: ignore ) - .filter(~pl.col(WHITELIST_COLUMN).is_null()) + .filter(~pl.col(SUBSET_COLUMN).is_null()) .with_columns( - pl.struct(pl.col(BARCODE_COLUMN), pl.col(WHITELIST_COLUMN)) + pl.struct(pl.col(BARCODE_COLUMN), pl.col(SUBSET_COLUMN)) .map_elements( lambda x: distance.Hamming.distance( - x[BARCODE_COLUMN], x[WHITELIST_COLUMN] + x[BARCODE_COLUMN], x[SUBSET_COLUMN] ), return_dtype=pl.UInt8, ) @@ -82,7 +71,7 @@ def correct_barcodes_pl( ) corrected_barcodes_pl = pl.concat([corrected_barcodes_pl, current_barcodes]) mapped_barcodes = dict( - corrected_barcodes_pl.select(BARCODE_COLUMN, WHITELIST_COLUMN).iter_rows() + corrected_barcodes_pl.select(BARCODE_COLUMN, SUBSET_COLUMN).iter_rows() # type: ignore ) final_corrected = ( barcodes_df.with_columns( @@ -116,8 +105,6 @@ def update_main_df(main_df: pl.DataFrame, mapped_barcodes: dict): # UMI correction section - - def correct_umis_in_cells(umi_correction_input): """ Corrects umi barcodes within same cell/tag groups. diff --git a/pytest.ini b/pytest.ini new file mode 100644 index 0000000..d9c1880 --- /dev/null +++ b/pytest.ini @@ -0,0 +1,4 @@ +[pytest] +filterwarnings = + ignore::DeprecationWarning + ignore::PendingDeprecationWarning \ No newline at end of file diff --git a/setup.py b/setup.py index a251c0c..aac68d7 100644 --- a/setup.py +++ b/setup.py @@ -7,12 +7,12 @@ setuptools.setup( name="CITE-seq-Count", - version="1.5.0", + version="2.0.0", author="Roelli Patrick", author_email="patrick.roelli@gmail.com", description="A python package to map reads from CITE-seq or hashing data for single cell experiments", url="https://github.com/Hoohm/CITE-seq-Count/", - download_url="https://github.com/Hoohm/CITE-seq-Count/archive/1.5.0.tar.gz", + download_url="https://github.com/Hoohm/CITE-seq-Count/archive/2.0.0.tar.gz", packages=setuptools.find_packages(), entry_points={"console_scripts": ["CITE-seq-Count = cite_seq_count.__main__:main"]}, classifiers=( @@ -29,7 +29,7 @@ "pyyaml==6.0", "pooch==1.6.0", "six==1.16.0", - "polars== 0.19.14", + "polars==0.20.3-rc2", ], python_requires="==3.11.6", package_data={"report_template": ["templates/*.json"]}, diff --git a/tests/test_data/filtered_lists/pass/normal_ref.csv b/tests/test_data/filtered_lists/pass/normal_ref.csv deleted file mode 100644 index 1f204ee..0000000 --- a/tests/test_data/filtered_lists/pass/normal_ref.csv +++ /dev/null @@ -1,2 +0,0 @@ -GCTAGCTAGCTAGCTG -TTCATAAGGTAGGGAT \ No newline at end of file diff --git a/tests/test_data/reference_lists/pass/translation.csv b/tests/test_data/reference_lists/pass/translation.csv index e41cdd7..4916451 100644 --- a/tests/test_data/reference_lists/pass/translation.csv +++ b/tests/test_data/reference_lists/pass/translation.csv @@ -1,2 +1,4 @@ reference,translation -ACTGTTTTATTGGCCT,TTCATCCTTTAGGGAT \ No newline at end of file +ATGCCC,GGATCG +ATGCTT,GGATCA +CCGCCC,GATCAA \ No newline at end of file diff --git a/tests/test_data/filtered_lists/fail/different_length.csv b/tests/test_data/reference_subsets/fail/different_length.csv similarity index 100% rename from tests/test_data/filtered_lists/fail/different_length.csv rename to tests/test_data/reference_subsets/fail/different_length.csv diff --git a/tests/test_data/filtered_lists/fail/with_header.csv b/tests/test_data/reference_subsets/fail/with_header.csv similarity index 100% rename from tests/test_data/filtered_lists/fail/with_header.csv rename to tests/test_data/reference_subsets/fail/with_header.csv diff --git a/tests/test_data/filtered_lists/fail/wrong_barcode.csv b/tests/test_data/reference_subsets/fail/wrong_barcode.csv similarity index 100% rename from tests/test_data/filtered_lists/fail/wrong_barcode.csv rename to tests/test_data/reference_subsets/fail/wrong_barcode.csv diff --git a/tests/test_data/reference_subsets/pass/normal_ref.csv b/tests/test_data/reference_subsets/pass/normal_ref.csv new file mode 100644 index 0000000..7f848df --- /dev/null +++ b/tests/test_data/reference_subsets/pass/normal_ref.csv @@ -0,0 +1,3 @@ +reference +ATGCCC +ATGCTT \ No newline at end of file diff --git a/tests/test_data/whitelists/short_whitelist.csv b/tests/test_data/whitelists/short_whitelist.csv new file mode 100644 index 0000000..a578ed3 --- /dev/null +++ b/tests/test_data/whitelists/short_whitelist.csv @@ -0,0 +1,3 @@ +ATGCCC +ATGCTT +CCGCCC \ No newline at end of file diff --git a/tests/test_io.py b/tests/test_io.py index f6625a3..2a3598d 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -2,9 +2,28 @@ import os import gzip import scipy -from cite_seq_count.io import get_n_lines, write_to_files, write_dense, get_read_paths +from pathlib import Path +from polars.testing import assert_frame_equal +from cite_seq_count.constants import ( + BARCODE_COLUMN, + COUNT_COLUMN, + FEATURE_NAME_COLUMN, + SEQUENCE_COLUMN, + SUBSET_COLUMN, + FEATURE_ID_COLUMN, + BARCODE_ID_COLUMN, + UNMAPPED_NAME, +) +from cite_seq_count.io import ( + get_n_lines, + write_to_files, + get_read_paths, + write_data_to_mtx, + create_mtx_df +) from collections import namedtuple import numpy as np +import polars as pl # copied from https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file import hashlib @@ -24,142 +43,163 @@ def md5(fname): @pytest.fixture -def data(): - from collections import OrderedDict - from scipy import sparse - - pytest.correct_R1_path = "tests/test_data/fastq/correct_R1.fastq.gz" - pytest.correct_R2_path = "tests/test_data/fastq/correct_R2.fastq.gz" - pytest.corrupt_R1_path = "tests/test_data/fastq/corrupted_R1.fastq.gz" - pytest.corrupt_R2_path = "tests/test_data/fastq/corrupted_R2.fastq.gz" - - pytest.correct_R1_multipath = "tests/test_data/fastq/correct_R1.fastq.gz,tests/test_data/fastq/correct_R1.fastq.gz" - pytest.correct_R2_multipath = "tests/test_data/fastq/correct_R2.fastq.gz,tests/test_data/fastq/correct_R2.fastq.gz" - pytest.incorrect_R2_multipath = ( - "path/to/R2_1.fastq.gz,path/to/R2_2.fastq.gz,path/to/R2_3.fastq.gz" - ) +def correct_R1(): + return Path("tests/test_data/fastq/correct_R1.fastq.gz") - pytest.correct_multipath_result = ( - [pytest.correct_R1_path, pytest.correct_R1_path], - [pytest.correct_R2_path, pytest.correct_R2_path], - ) - test_matrix = sparse.dok_matrix((4, 2), dtype=np.int32) - test_matrix[1, 1] = 1 - pytest.sparse_matrix = test_matrix - pytest.filtered_cells = ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"] - tag = namedtuple("tag", ["name", "sequence", "id"]) - pytest.parsed_tags_map = [ - tag(name="test1", sequence="CGTA", id=0), - tag(name="test2", sequence="CGTA", id=1), - tag(name="test3", sequence="CGTA", id=2), - tag(name="unmapped", sequence="UNKNOWN", id=3), - ] - - pytest.data_type = "umi" - - -def test_write_to_files_wo_translation(data, tmpdir): - output_path = os.path.join(tmpdir, "without_translation") - - mtx_path = os.path.join(output_path, "umi_count", "matrix.mtx.gz") - features_path = os.path.join(output_path, "umi_count", "features.tsv.gz") - barcodes_path = os.path.join(output_path, "umi_count", "barcodes.tsv.gz") - md5_sums = { - barcodes_path: "b7af6a32e83963606f181509a571966f", - features_path: "e889e780dbce481287c993dd043714c8", - mtx_path: "3ea98c44d88a947215bace0c72ac1303", - } - - write_to_files( - pytest.sparse_matrix, - pytest.filtered_cells, - pytest.parsed_tags_map, - pytest.data_type, - output_path, - translation_dict=False, - ) - file_path = os.path.join(tmpdir, "without_translation", "umi_count/matrix.mtx.gz") - with gzip.open(file_path, "rb") as mtx_file: - assert isinstance(scipy.io.mmread(mtx_file), scipy.sparse.coo_matrix) - assert md5_sums[barcodes_path] == md5(barcodes_path) - assert md5_sums[features_path] == md5(features_path) - assert md5_sums[mtx_path] == md5(mtx_path) - - -def test_write_to_files_with_translation(data, tmpdir): - translation_dict = { - "ACTGTTTTATTGGCCT": "GGCTTCGATACTAGAT", - "TTCATAAGGTAGGGAT": "GATCGGATAGCTAATA", - } - output_path = os.path.join(tmpdir, "with_translation") - - mtx_path = os.path.join(output_path, "umi_count", "matrix.mtx.gz") - features_path = os.path.join(output_path, "umi_count", "features.tsv.gz") - barcodes_path = os.path.join(output_path, "umi_count", "barcodes.tsv.gz") - - md5_sums = { - barcodes_path: "fce83378b4dd548882fb9271bdd5b4f1", - features_path: "e889e780dbce481287c993dd043714c8", - mtx_path: "3ea98c44d88a947215bace0c72ac1303", - } - - write_to_files( - pytest.sparse_matrix, - pytest.filtered_cells, - pytest.parsed_tags_map, - pytest.data_type, - output_path, - translation_dict=translation_dict, - ) - file_path = os.path.join(tmpdir, "with_translation", "umi_count/matrix.mtx.gz") - with gzip.open(file_path, "rb") as mtx_file: - assert isinstance(scipy.io.mmread(mtx_file), scipy.sparse.coo_matrix) - assert md5_sums[barcodes_path] == md5(barcodes_path) - assert md5_sums[features_path] == md5(features_path) - assert md5_sums[mtx_path] == md5(mtx_path) - - -def test_write_to_dense_wo_translation(data, tmpdir): - reference_dict = {"ACTGTTTTATTGGCCT": 0, "TTCATAAGGTAGGGAT": 0} - output_path = os.path.join(tmpdir, "without_translation") - csv_name = "dense_umis.tsv" - csv_path = os.path.join(output_path, csv_name) - - md5_sums = { - csv_path: "fef502237900ec386d100169fa1fab7c", - } - - write_dense( - sparse_matrix=pytest.sparse_matrix, - parsed_tags=pytest.parsed_tags_map, - columns=pytest.filtered_cells, - outfolder=output_path, - filename=csv_name, + +@pytest.fixture +def correct_R2(): + return Path("tests/test_data/fastq/correct_R2.fastq.gz") + + +@pytest.fixture +def corrupt_R1(): + return Path("tests/test_data/fastq/corrupted_R1.fastq.gz") + + +@pytest.fixture +def corrupt_R2(): + return Path("tests/test_data/fastq/corrupted_R2.fastq.gz") + + +@pytest.fixture +def correct_R1_multi(): + return "tests/test_data/fastq/correct_R1.fastq.gz,tests/test_data/fastq/correct_R1.fastq.gz" + + +@pytest.fixture +def correct_R2_multi(): + return "tests/test_data/fastq/correct_R2.fastq.gz,tests/test_data/fastq/correct_R2.fastq.gz" + + +@pytest.fixture +def corrupt_R2_multi(): + return "tests/test_data/fastq/correct_R2.fastq.gz,tests/test_data/fastq/corrupted_R2.fastq.gz" + + +@pytest.fixture +def incorrect_R2_multipath(): + return "path/to/R2_1.fastq.gz,path/to/R2_2.fastq.gz,path/to/R2_3.fastq.gz" + + +@pytest.fixture +def correct_multipath_combined(correct_R1, correct_R2): + return ( + [correct_R1, correct_R1], + [correct_R2, correct_R2], ) - file_path = os.path.join(tmpdir, "without_translation", csv_name) - assert md5_sums[csv_path] == md5(file_path) @pytest.mark.dependency() -def test_get_n_lines(data): - assert get_n_lines(pytest.correct_R1_path) == (200 * 4) +def test_get_n_lines(correct_R1): + assert get_n_lines(correct_R1) == (200 * 4) -@pytest.mark.dependency() -def test_corrrect_multipath(data): +def test_corrrect_multipath( + correct_R1_multi, correct_R2_multi, correct_multipath_combined +): assert ( - get_read_paths(pytest.correct_R1_multipath, pytest.correct_R2_multipath) - == pytest.correct_multipath_result + get_read_paths(correct_R1_multi, correct_R2_multi) == correct_multipath_combined ) @pytest.mark.dependency(depends=["test_get_n_lines"]) -def test_incorrrect_multipath(data): +def test_incorrect_multipath(correct_R1_multi, incorrect_R2_multipath): with pytest.raises(SystemExit): - get_read_paths(pytest.correct_R1_multipath, pytest.incorrect_R2_multipath) + get_read_paths(correct_R1_multi, incorrect_R2_multipath) @pytest.mark.dependency(depends=["test_get_n_lines"]) -def test_get_n_lines_not_multiple_of_4(data): +def test_get_n_lines_not_multiple_of_4(corrupt_R1): with pytest.raises(SystemExit): - get_n_lines(pytest.corrupt_R1_path) + get_n_lines(corrupt_R1) + + +def test_write_data_to_mtx(tmp_path): + # Create test data + main_df = pl.DataFrame( + { + FEATURE_NAME_COLUMN: ["test1", "test2", "test3"], + BARCODE_COLUMN: [ + "ACTGTTTTATTGGCCT", + "TTCATAAGGTAGGGAT", + "ACTGTTTTATTGGCCT", + ], + COUNT_COLUMN: [10, 20, 30], + } + ) + tags_df = pl.DataFrame( + { + FEATURE_NAME_COLUMN: ["test1", "test2", "test3"], + SEQUENCE_COLUMN: ["CGTA", "CGTA", "CGTA"], + } + ) + subset_df = pl.DataFrame({SUBSET_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]}) + data_type = "umi" + outpath = str(tmp_path) + print(outpath) + + # Call the function + write_data_to_mtx( + main_df=main_df, + tags_df=tags_df, + subset_df=subset_df, + data_type=data_type, + outpath=outpath, + ) + + # Check if the output files exist + assert os.path.exists(os.path.join(outpath, "umi_count", "matrix.mtx.gz")) + assert os.path.exists(os.path.join(outpath, "umi_count", "features.tsv.gz")) + assert os.path.exists(os.path.join(outpath, "umi_count", "barcodes.tsv.gz")) + + # TODO: Add assertions to validate the content of the output files +def test_create_mtx_df(): + # Create test data + main_df = pl.DataFrame( + { + FEATURE_NAME_COLUMN: ["test1", "test2", "test3"], + BARCODE_COLUMN: [ + "ACTGTTTTATTGGCCT", + "TTCATAAGGTAGGGAT", + "ACTGTTTTATTGGCCT", + ], + COUNT_COLUMN: [10, 20, 30], + } + ) + tags_df = pl.DataFrame( + { + FEATURE_NAME_COLUMN: ["test1", "test2", "test3"], + SEQUENCE_COLUMN: ["CGTA", "CGTA", "CGTA"], + } + ) + subset_df = pl.DataFrame({SUBSET_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]}) + + # Call the function + mtx_df, tags_indexed, barcodes_indexed = create_mtx_df(main_df, tags_df, subset_df) + + # Check the output + expected_mtx_df = pl.DataFrame( + { + FEATURE_ID_COLUMN: [1, 2, 3], + BARCODE_ID_COLUMN: [1, 2, 1], + COUNT_COLUMN: [10, 20, 30], + }, schema={FEATURE_ID_COLUMN: pl.UInt32, BARCODE_ID_COLUMN: pl.UInt32, COUNT_COLUMN: pl.Int64} + ) + expected_tags_indexed = pl.DataFrame( + { + FEATURE_NAME_COLUMN: ["test1", "test2", "test3", UNMAPPED_NAME], + SEQUENCE_COLUMN: ["CGTA", "CGTA", "CGTA", "UNKNOWN"], + FEATURE_ID_COLUMN: [1, 2, 3, 4], + }, schema={FEATURE_ID_COLUMN: pl.UInt32, SEQUENCE_COLUMN: pl.Utf8, FEATURE_NAME_COLUMN: pl.Utf8} + ) + expected_barcodes_indexed = pl.DataFrame( + { + SUBSET_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], + BARCODE_ID_COLUMN: [1, 2], + }, schema={SUBSET_COLUMN: pl.Utf8, BARCODE_ID_COLUMN: pl.UInt32} + ) + + assert_frame_equal(mtx_df, expected_mtx_df) + assert_frame_equal(tags_indexed, expected_tags_indexed, check_column_order=False) + assert_frame_equal(barcodes_indexed, expected_barcodes_indexed, check_column_order=False) \ No newline at end of file diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 2e1a33f..85498f0 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -6,34 +6,80 @@ import polars as pl from polars.testing import assert_frame_equal from cite_seq_count.preprocessing import ( - parse_barcode_reference, + parse_barcode_file, parse_tags_csv, check_tags, find_knee_estimated_barcodes, + get_barcode_subset, ) -from cite_seq_count.constants import REFERENCE_COLUMN +from cite_seq_count.constants import ( + BARCODE_COLUMN, + COUNT_COLUMN, + REFERENCE_COLUMN, + SUBSET_COLUMN, + TRANSLATION_COLUMN, +) +from cite_seq_count.chemistry import Chemistry + + +@pytest.fixture +def passing_references(): + return "tests/test_data/reference_lists/pass/*.csv" + + +@pytest.fixture +def failing_references(): + return "tests/test_data/reference_lists/fail/*.csv" + + +@pytest.fixture +def passing_subsets(): + return "tests/test_data/reference_subsets/pass/*.csv" @pytest.fixture -def data(): - """Load up data for testing""" +def failing_subsets(): + return "tests/test_data/reference_subsets/fail/*.csv" - pytest.passing_csv = "tests/test_data/tags/pass/*.csv" - pytest.failing_csv = "tests/test_data/tags/fail/*.csv" - pytest.passing_reference_list_csv = "tests/test_data/reference_lists/pass/*.csv" - pytest.failing_reference_list_csv = "tests/test_data/reference_lists/fail/*.csv" +@pytest.fixture +def passing_tags(): + return "tests/test_data/tags/pass/*.csv" + + +@pytest.fixture +def failing_tags(): + return "tests/test_data/tags/fail/*.csv" + + +@pytest.fixture +def correct_reference_df(): + return pl.DataFrame( + { + REFERENCE_COLUMN: ["ATGCCC", "ATGCTT", "CCGCCC"], + TRANSLATION_COLUMN: ["GGATCG", "GGATCA", "GATCAA"], + } + ) + - pytest.passing_filtered_list_csv = "tests/test_data/filtered_lists/pass/*.csv" - pytest.failing_filtered_list_csv = "tests/test_data/filtered_lists/fail/*.csv" +@pytest.fixture +def correct_subset_df(): + return pl.DataFrame({REFERENCE_COLUMN: ["ATGCCC", "ATGCTT"]}) - pytest.correct_tags_path = "tests/test_data/tags/pass/correct.csv" - # Create some variables to compare to - pytest.correct_reference_translation_list = pl.DataFrame( - {"reference": "ACTGTTTTATTGGCCT", "translation": "TTCATCCTTTAGGGAT"} +@pytest.fixture +def barcodes_df(): + return pl.DataFrame( + { + BARCODE_COLUMN: ["ATGCCC", "ATGCTT", "CCGCCC", "ATATCC", "ATATGG"], + COUNT_COLUMN: [200, 200, 200, 20, 10], + } ) - pytest.correct_tag_pl = pl.DataFrame( + + +@pytest.fixture +def correct_tags_df(): + return pl.DataFrame( { "feature_name": [ "CITE_LEN_20_1", @@ -59,50 +105,165 @@ def data(): ], } ) - pytest.barcodes_df = pl.DataFrame( - { - "barcode": ["ATGCCC", "ATGCTT", "CCGCCC", "ATATCC", "ATATGG"], - "count": [200, 200, 200, 20, 10], - } - ) - pytest.barcode_slice = slice(0, 16) - pytest.umi_slice = slice(16, 26) - pytest.barcode_umi_length = 26 -def test_csv_parser(data): - """Test the csv parser +@pytest.fixture +def chemistry_def(): + return Chemistry( + name="test", + cell_barcode_start=1, + cell_barcode_end=6, + umi_barcode_start=7, + umi_barcode_end=12, + r2_trim_start=0, + barcode_reference_path="tests/test_data/reference_lists/pass/translation.csv", + ) - Args: - data (_type_): _description_ - """ - passing_files = glob.glob(pytest.passing_csv) + +def test_passing_parse_reference_list_csv(passing_references, correct_reference_df): + passing_files = glob.glob(passing_references) for file_path in passing_files: - parse_tags_csv(file_path) + assert_frame_equal( + left=parse_barcode_file(file_path, 16, [REFERENCE_COLUMN]), + right=correct_reference_df, + ) + + +def test_failing_parse_reference_list_csv(failing_references): with pytest.raises(SystemExit): - failing_files = glob.glob(pytest.failing_csv) + failing_files = glob.glob(failing_references) for file_path in failing_files: - parse_tags_csv(file_path) + parse_barcode_file(file_path, 16, [REFERENCE_COLUMN]) -@pytest.mark.dependency() -def test_parse_reference_list_csv(data): - passing_files = glob.glob(pytest.passing_reference_list_csv) +def test_parse_subset_list_csv(passing_subsets, failing_subsets, correct_subset_df): + passing_files = glob.glob(passing_subsets) for file_path in passing_files: assert_frame_equal( - left=parse_barcode_reference(file_path, 16, [REFERENCE_COLUMN]), - right=pytest.correct_reference_translation_list, + left=parse_barcode_file(file_path, 16, [REFERENCE_COLUMN]), + right=correct_subset_df, ) with pytest.raises(SystemExit): - failing_files = glob.glob(pytest.failing_reference_list_csv) + failing_files = glob.glob(failing_subsets) for file_path in failing_files: - parse_barcode_reference(file_path, 16, [REFERENCE_COLUMN]) + parse_barcode_file(file_path, 16, [REFERENCE_COLUMN]) -def test_check_distance_too_big_between_tags(data): +def test_check_distance_too_big_between_tags(correct_tags_df): with pytest.raises(SystemExit): - check_tags(pytest.correct_tag_pl, 8) + check_tags(correct_tags_df, 8) + + +# Test if there is no reference and no whitelist +def test_find_knee_estimated_barcodes(barcodes_df): + expected_subset = pl.DataFrame({SUBSET_COLUMN: ["ATGCCC", "ATGCTT", "CCGCCC"]}) + subset = find_knee_estimated_barcodes(barcodes_df=barcodes_df) + assert_frame_equal(subset, expected_subset) -def test_find_knee_estimated_barcodes(data): - find_knee_estimated_barcodes(barcodes_df=pytest.barcodes_df) +@pytest.fixture +def barcode_subset_df(): + return pl.DataFrame({SUBSET_COLUMN: ["ATGCCC", "ATGCTT"]}) + + +def test_get_barcode_subset_with_reference(correct_reference_df, barcodes_df): + expected_subset = pl.DataFrame( + { + SUBSET_COLUMN: ["ATGCCC", "ATGCTT"], + } + ) + expected_enable_correction = True + + subset, enable_correction = get_barcode_subset( + barcode_reference=correct_reference_df, + n_barcodes=2, + chemistry=None, + barcode_subset=None, + barcodes_df=barcodes_df, + ) + + assert_frame_equal(subset, expected_subset) + assert enable_correction == expected_enable_correction + + +def test_get_barcode_subset_without_reference(barcodes_df): + expected_subset = pl.DataFrame({SUBSET_COLUMN: ["ATGCCC", "ATGCTT", "CCGCCC"]}) + expected_enable_correction = True + + subset, enable_correction = get_barcode_subset( + barcode_reference=None, + n_barcodes=3, + chemistry=None, + barcode_subset=None, + barcodes_df=barcodes_df, + ) + + assert_frame_equal(subset, expected_subset) + assert enable_correction == expected_enable_correction + + +def test_get_barcode_subset_with_large_n_barcodes(correct_reference_df, barcodes_df): + expected_subset = pl.DataFrame( + { + SUBSET_COLUMN: ["ATGCCC", "ATGCTT", "CCGCCC"], + } + ) + expected_enable_correction = False + + subset, enable_correction = get_barcode_subset( + barcode_reference=correct_reference_df, + n_barcodes=4, + chemistry=None, + barcode_subset=None, + barcodes_df=barcodes_df, + ) + + assert_frame_equal(subset, expected_subset) + assert enable_correction == expected_enable_correction + + +def test_get_barcode_subset_with_existing_subset( + correct_reference_df, barcode_subset_df, barcodes_df +): + expected_subset = pl.DataFrame({SUBSET_COLUMN: ["ATGCCC", "ATGCTT"]}) + expected_enable_correction = True + + subset, enable_correction = get_barcode_subset( + barcode_reference=correct_reference_df, + n_barcodes=2, + chemistry=None, + barcode_subset=barcode_subset_df, + barcodes_df=barcodes_df, + ) + + assert_frame_equal(subset, expected_subset) + assert enable_correction == expected_enable_correction + + +def test_get_barcode_subset_with_no_subset(correct_reference_df, barcodes_df): + expected_subset = pl.DataFrame( + { + SUBSET_COLUMN: ["ATGCCC", "ATGCTT", "CCGCCC"], + } + ) + expected_enable_correction = True + + subset, enable_correction = get_barcode_subset( + barcode_reference=correct_reference_df, + n_barcodes=3, + chemistry=None, + barcode_subset=None, + barcodes_df=barcodes_df, + ) + + assert_frame_equal(subset, expected_subset) + assert enable_correction == expected_enable_correction + + +def test_parse_tags_csv(correct_tags_df): + file_name = "tests/test_data/tags/pass/correct.csv" + result_df = parse_tags_csv(file_name) + + assert_frame_equal( + result_df, correct_tags_df, check_column_order=False, check_row_order=False + ) From b3401d884e50f0e5decf74705797298cc3343a16 Mon Sep 17 00:00:00 2001 From: hoohm Date: Sat, 30 Dec 2023 17:38:03 +0100 Subject: [PATCH 68/77] (feat): Include yaml report again --- cite_seq_count/__main__.py | 54 +++++++++++++++++------------------- cite_seq_count/io.py | 15 ++++++---- cite_seq_count/mapping.py | 6 ++-- cite_seq_count/processing.py | 23 +++++++++++++++ tests/test_io.py | 25 +++++++++++++---- 5 files changed, 82 insertions(+), 41 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 2675d33..3b65cf3 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -84,12 +84,14 @@ def main(): ) # Remove temp file os.remove(temp_file) - mapped_r2_df = mapping.map_reads_hybrid( + mapped_r2_df, unmapped_r2_df = mapping.map_reads_hybrid( r2_df=r2_df, parsed_tags=parsed_tags, maximum_distance=maximum_distance, ) - + unmapped_df = processing.summarise_unmapped_df( + main_df=main_df, unmapped_r2_df=unmapped_r2_df + ) barcode_subset, enable_barcode_correction = preprocessing.get_barcode_subset( barcode_subset=barcode_subset, n_barcodes=args.expected_barcodes, @@ -149,38 +151,34 @@ def main(): main_df=read_counts, tags_df=parsed_tags, subset_df=barcode_subset, + data_type="read", + outpath=args.outfolder, + ) + io.write_data_to_mtx( + main_df=umi_counts, + tags_df=parsed_tags, + subset_df=barcode_subset, data_type="umi", outpath=args.outfolder, ) # TODO: Write unmapped sequences # TODO: rewrite reporting - # # Create report and write it to disk - # io.create_report( - # total_reads=total_reads, - # no_match=merged_no_match, - # version=argsparser.get_package_version(), - # start_time=start_time, - # umis_corrected=umis_corrected, - # bcs_corrected=bcs_corrected, - # bad_cells=clustered_cells, - # r1_too_short=r1_too_short, - # r2_too_short=r2_too_short, - # args=args, - # chemistry_def=chemistry_def, - # maximum_distance=maximum_distance, - # ) - # TODO: Rewrite dense format output - # # Write dense matrix to disk if requested - # if args.dense: - # print("Writing dense format output") - # io.write_dense( - # sparse_matrix=umi_results_matrix, - # parsed_tags=parsed_tags, - # columns=filtered_cells, - # outfolder=args.outfolder, - # filename="dense_umis.tsv", - # ) + # Create report and write it to disk + io.create_report( + total_reads=total_reads, + unmapped=unmapped_df, + version=argsparser.get_package_version(), + start_time=start_time, + umis_corrected=umis_corrected, + bcs_corrected=n_bcs_corrected, + bad_cells=clustered_cells, + r1_too_short=r1_too_short, + r2_too_short=r2_too_short, + args=args, + chemistry_def=chemistry_def, + maximum_distance=maximum_distance, + ) if __name__ == "__main__": diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 5ffe4d2..0c98c1d 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -276,7 +276,7 @@ def load_report_template() -> dict: def create_report( total_reads: int, - no_match: Counter, + unmapped: pl.DataFrame, version: str, start_time, umis_corrected: int, @@ -299,8 +299,7 @@ def create_report( args (arg_parse): Arguments provided by the user. """ - - total_unmapped = sum(no_match.values()) + total_unmapped = unmapped[COUNT_COLUMN][0] total_too_short = r1_too_short + r2_too_short total_mapped = total_reads - total_unmapped - total_too_short @@ -337,7 +336,7 @@ def create_report( report_data["Run parameters"]["UMI barcode"][ "Last position" ] = chemistry_def.umi_barcode_end - report_data["Expected cells"] = args.expected_cells + report_data["Expected cells"] = args.expected_barcodes report_data["Tags max errors"] = maximum_distance report_data["Start trim"] = chemistry_def.r2_trim_start @@ -492,7 +491,10 @@ def write_chunks_to_disk( total_reads, ) -def create_mtx_df(main_df:pl.DataFrame, tags_df:pl.DataFrame, subset_df:pl.DataFrame): + +def create_mtx_df( + main_df: pl.DataFrame, tags_df: pl.DataFrame, subset_df: pl.DataFrame +): tags_indexed = ( pl.concat( [ @@ -514,7 +516,8 @@ def create_mtx_df(main_df:pl.DataFrame, tags_df:pl.DataFrame, subset_df:pl.DataF .select([FEATURE_ID_COLUMN, BARCODE_ID_COLUMN, COUNT_COLUMN]) ) return mtx_df, tags_indexed, barcodes_indexed - + + def write_data_to_mtx( main_df: pl.DataFrame, tags_df: pl.DataFrame, diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py index 3eabe9d..7815420 100644 --- a/cite_seq_count/mapping.py +++ b/cite_seq_count/mapping.py @@ -4,6 +4,7 @@ from rapidfuzz import fuzz, process, distance from cite_seq_count.constants import ( + COUNT_COLUMN, SEQUENCE_COLUMN, R2_COLUMN, FEATURE_NAME_COLUMN, @@ -23,7 +24,7 @@ def find_best_match_fast(tag_seq, tags_list, maximum_distance): def map_reads_hybrid( r2_df: pl.DataFrame, parsed_tags: pl.DataFrame, maximum_distance: int -) -> pl.DataFrame: +) -> tuple[pl.DataFrame, pl.DataFrame]: """Map sequence data to a tags reference. Using a hybdrid approach where we first join all the data for the exact matches then using a hamming distance calculation to find the closest match @@ -52,8 +53,9 @@ def map_reads_hybrid( ) .otherwise(pl.col(FEATURE_NAME_COLUMN)) ) + unmapped_r2_df = mapped_r2_df.filter(pl.col(FEATURE_NAME_COLUMN) == UNMAPPED_NAME) print("Mapping done") - return mapped_r2_df + return mapped_r2_df, unmapped_r2_df def check_unmapped(mapped_reads: pl.DataFrame): diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index a1be51e..e370d79 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -12,6 +12,7 @@ R2_COLUMN, UMI_COLUMN, COUNT_COLUMN, + UNMAPPED_NAME, ) @@ -86,6 +87,28 @@ def correct_barcodes_pl( return final_corrected, n_corrected_barcodes, mapped_barcodes +def summarise_unmapped_df(main_df: pl.DataFrame, unmapped_r2_df: pl.DataFrame): + """Merge main df and unmapped df to get a summary of the unmapped reads + + Args: + main_df (pl.DataFrame): _description_ + unmapped_r2_df (pl.DataFrame): _description_ + """ + unmapped_r2_df = ( + unmapped_r2_df.filter(pl.col(FEATURE_NAME_COLUMN) == UNMAPPED_NAME) + .join(main_df, on=R2_COLUMN, how="left") + .with_columns( + pl.when(pl.col(FEATURE_NAME_COLUMN).is_null()) + .then(pl.col(R2_COLUMN)) + .otherwise(pl.col(FEATURE_NAME_COLUMN)) + .alias(FEATURE_NAME_COLUMN) + ) + ) + unmapped_df = unmapped_r2_df.group_by(FEATURE_NAME_COLUMN).agg(pl.count()) + + return unmapped_df + + def update_main_df(main_df: pl.DataFrame, mapped_barcodes: dict): """Update the main data df with the corrected barcodes diff --git a/tests/test_io.py b/tests/test_io.py index 2a3598d..6843c08 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -19,7 +19,7 @@ write_to_files, get_read_paths, write_data_to_mtx, - create_mtx_df + create_mtx_df, ) from collections import namedtuple import numpy as np @@ -154,6 +154,8 @@ def test_write_data_to_mtx(tmp_path): assert os.path.exists(os.path.join(outpath, "umi_count", "barcodes.tsv.gz")) # TODO: Add assertions to validate the content of the output files + + def test_create_mtx_df(): # Create test data main_df = pl.DataFrame( @@ -184,22 +186,35 @@ def test_create_mtx_df(): FEATURE_ID_COLUMN: [1, 2, 3], BARCODE_ID_COLUMN: [1, 2, 1], COUNT_COLUMN: [10, 20, 30], - }, schema={FEATURE_ID_COLUMN: pl.UInt32, BARCODE_ID_COLUMN: pl.UInt32, COUNT_COLUMN: pl.Int64} + }, + schema={ + FEATURE_ID_COLUMN: pl.UInt32, + BARCODE_ID_COLUMN: pl.UInt32, + COUNT_COLUMN: pl.Int64, + }, ) expected_tags_indexed = pl.DataFrame( { FEATURE_NAME_COLUMN: ["test1", "test2", "test3", UNMAPPED_NAME], SEQUENCE_COLUMN: ["CGTA", "CGTA", "CGTA", "UNKNOWN"], FEATURE_ID_COLUMN: [1, 2, 3, 4], - }, schema={FEATURE_ID_COLUMN: pl.UInt32, SEQUENCE_COLUMN: pl.Utf8, FEATURE_NAME_COLUMN: pl.Utf8} + }, + schema={ + FEATURE_ID_COLUMN: pl.UInt32, + SEQUENCE_COLUMN: pl.Utf8, + FEATURE_NAME_COLUMN: pl.Utf8, + }, ) expected_barcodes_indexed = pl.DataFrame( { SUBSET_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], BARCODE_ID_COLUMN: [1, 2], - }, schema={SUBSET_COLUMN: pl.Utf8, BARCODE_ID_COLUMN: pl.UInt32} + }, + schema={SUBSET_COLUMN: pl.Utf8, BARCODE_ID_COLUMN: pl.UInt32}, ) assert_frame_equal(mtx_df, expected_mtx_df) assert_frame_equal(tags_indexed, expected_tags_indexed, check_column_order=False) - assert_frame_equal(barcodes_indexed, expected_barcodes_indexed, check_column_order=False) \ No newline at end of file + assert_frame_equal( + barcodes_indexed, expected_barcodes_indexed, check_column_order=False + ) From 1ef8ffdd5c7b9bf1b59d4e9e43734e7651c53eaa Mon Sep 17 00:00:00 2001 From: hoohm Date: Mon, 1 Jan 2024 21:22:17 +0100 Subject: [PATCH 69/77] (feat): Mapping in polars only using polars-distance --- cite_seq_count/__main__.py | 4 +- cite_seq_count/argsparser.py | 5 +- cite_seq_count/chemistry.py | 8 +- cite_seq_count/io.py | 149 +-------------------------------- cite_seq_count/mapping.py | 60 ++++++++++++-- tests/test_mapping.py | 154 ++++++++++++++--------------------- 6 files changed, 119 insertions(+), 261 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 3b65cf3..f046c64 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -84,7 +84,7 @@ def main(): ) # Remove temp file os.remove(temp_file) - mapped_r2_df, unmapped_r2_df = mapping.map_reads_hybrid( + mapped_r2_df, unmapped_r2_df = mapping.map_reads_polars( r2_df=r2_df, parsed_tags=parsed_tags, maximum_distance=maximum_distance, @@ -146,6 +146,8 @@ def main(): data_type="umi", ) + umi_counts.write_parquet(file=args.outfolder + "/umi_counts.parquet") + # Write umis to file io.write_data_to_mtx( main_df=read_counts, diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 404836b..6799cdf 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -4,13 +4,12 @@ import sys import tempfile -from argparse import ArgumentParser, ArgumentTypeError, RawTextHelpFormatter +from argparse import ArgumentParser, RawTextHelpFormatter import pkg_resources -# pylint: disable=no-name-in-module -from multiprocess import cpu_count +from multiprocessing import cpu_count def get_package_version(): diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index d38a886..5d43951 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -27,7 +27,7 @@ class Chemistry: barcode_reference_path: str def __post_init__(self): - self.barcode_length = self.cell_barcode_end - self.umi_barcode_start + self.barcode_length = self.cell_barcode_end - self.umi_barcode_start + 1 DEFINITIONS_DB = pooch.create( @@ -138,9 +138,7 @@ def setup_chemistry(args: Namespace) -> tuple[pl.DataFrame | None, Chemistry]: chemistry_def = get_chemistry_definition(args.chemistry_id) barcode_reference = parse_barcode_file( filename=chemistry_def.barcode_reference_path, - barcode_length=chemistry_def.cell_barcode_end - - chemistry_def.cell_barcode_start - + 1, + barcode_length=chemistry_def.barcode_length, required_header=["reference"], ) else: @@ -149,7 +147,7 @@ def setup_chemistry(args: Namespace) -> tuple[pl.DataFrame | None, Chemistry]: print("Loading barcode reference") barcode_reference = parse_barcode_file( filename=args.reference, - barcode_length=args.cb_last - args.cb_first + 1, + barcode_length=chemistry_def.barcode_length, required_header=["reference"], ) else: diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 0c98c1d..5cfae98 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -9,14 +9,13 @@ import tempfile import json -from collections import namedtuple, Counter from argparse import Namespace from itertools import islice +from collections import Counter from typing import Tuple, TextIO from pathlib import Path from os import access, R_OK - import scipy import pkg_resources import yaml @@ -346,152 +345,6 @@ def create_report( yaml.dump(report_data, report_file, default_flow_style=False, sort_keys=False) -def write_chunks_to_disk( - args: Namespace, - read1_paths: list[Path], - read2_paths: list[Path], - r2_min_length: int, - n_reads_per_chunk: int, - chemistry_def, - parsed_tags: pl.DataFrame, - maximum_distance, -): - """ - Writes chunked files of reads to disk and prepares parallel - processing queue parameters. - - Args: - args(argparse): All parsed arguments. - read1_paths (list): List of R1 fastq.gz paths. - read2_paths (list): List of R2 fastq.gz paths. - r2_min_length (int): Minimum length of read2 sequences. - n_reads_per_chunk (int): How many reads per chunk. - chemistry_def (namedtuple): Hols all the information about the chemistry definition. - parsed_tags (list): List of namedtuple tags. - maximum_distance (int): Maximum hamming distance for mapping. - """ - mapping_input = namedtuple( - "mapping_input", - ["filename", "tags", "debug", "maximum_distance", "sliding_window"], - ) - - print("Writing chunks to disk") - - num_chunk = 0 - if not args.chunk_size: - chunk_size = round(n_reads_per_chunk / args.n_threads) - else: - chunk_size = args.chunk_size - temp_path = os.path.abspath(args.temp_path) - input_queue = [] - temp_files = [] - r1_too_short = 0 - r2_too_short = 0 - total_reads = 0 - total_reads_written = 0 - enough_reads = False - - barcode_slice = slice( - chemistry_def.cell_barcode_start - 1, chemistry_def.cell_barcode_end - ) - umi_slice = slice( - chemistry_def.umi_barcode_start - 1, chemistry_def.umi_barcode_end - ) - - chunked_file_object = tempfile.NamedTemporaryFile( - "w", dir=temp_path, suffix="_csc", delete=False - ) - # chunked_file_object = open(temp_file, "w") - temp_files.append(chunked_file_object.name) - reads_written = 0 - - for read1_path, read2_path in zip(read1_paths, read2_paths): - if enough_reads: - break - print(f"Reading reads from files: {read1_path}, {read2_path}") - with gzip.open(read1_path, "rt") as textfile1, gzip.open( - read2_path, "rt" - ) as textfile2: - secondlines = islice(zip(textfile1, textfile2), 1, None, 4) - - for read1, read2 in secondlines: - total_reads += 1 - - read1 = read1.strip() - if len(read1) < chemistry_def.umi_barcode_end: - r1_too_short += 1 - # The entire read is skipped - continue - if len(read2) < r2_min_length: - r2_too_short += 1 - # The entire read is skipped - continue - - read1_sliced = read1[ - chemistry_def.cell_barcode_start - 1 : chemistry_def.umi_barcode_end - ] - - read2_sliced = read2[ - chemistry_def.r2_trim_start : ( - r2_min_length + chemistry_def.r2_trim_start - ) - ] - chunked_file_object.write( - "{},{},{}\n".format( - read1_sliced[barcode_slice], - read1_sliced[umi_slice], - read2_sliced, - ) - ) - - reads_written += 1 - total_reads_written += 1 - if reads_written % chunk_size == 0 and reads_written != 0: - # We have enough reads in this chunk, open a new one - chunked_file_object.close() - input_queue.append( - mapping_input( - filename=chunked_file_object.name, - tags=parsed_tags, - debug=args.debug, - maximum_distance=maximum_distance, - sliding_window=args.sliding_window, - ) - ) - if total_reads_written == n_reads_per_chunk: - enough_reads = True - chunked_file_object.close() - break - num_chunk += 1 - chunked_file_object = tempfile.NamedTemporaryFile( - "w", dir=temp_path, suffix="_csc", delete=False - ) - temp_files.append(chunked_file_object.name) - reads_written = 0 - if total_reads_written == n_reads_per_chunk: - enough_reads = True - chunked_file_object.close() - break - if not enough_reads: - chunked_file_object.close() - input_queue.append( - mapping_input( - filename=chunked_file_object.name, - tags=parsed_tags, - debug=args.debug, - maximum_distance=maximum_distance, - sliding_window=args.sliding_window, - ) - ) - return ( - input_queue, - temp_files, - r1_too_short, - r2_too_short, - total_reads, - ) - - def create_mtx_df( main_df: pl.DataFrame, tags_df: pl.DataFrame, subset_df: pl.DataFrame ): diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py index 7815420..7c5a66d 100644 --- a/cite_seq_count/mapping.py +++ b/cite_seq_count/mapping.py @@ -1,7 +1,9 @@ """Mapping module. Holds all code related to mapping reads """ +from turtle import right import polars as pl -from rapidfuzz import fuzz, process, distance +from rapidfuzz import fuzz, process +import polars_distance as pld from cite_seq_count.constants import ( COUNT_COLUMN, @@ -12,13 +14,23 @@ ) -def find_best_match_fast(tag_seq, tags_list, maximum_distance): - choices = tags_list[SEQUENCE_COLUMN].to_list() - features = tags_list[FEATURE_NAME_COLUMN].to_list() - res = process.extractOne(choices=choices, query=tag_seq, scorer=fuzz.QRatio) - min_score = (len(tag_seq) - maximum_distance) / len(tag_seq) * 100 - if res[1] >= min_score: - return features[res[2]] +def find_best_match_fast(tag_seq, tags_df, maximum_distance): + min_scores = ( + tags_df.with_columns(pl.col("sequence").str.len_chars().alias("seq_length")) + .with_columns( + ( + (pl.col("seq_length") - maximum_distance) / pl.col("seq_length") * 100 + ).alias("min_score") + )["min_score"] + .to_list() + ) + choices = tags_df[SEQUENCE_COLUMN].to_list() + features = tags_df[FEATURE_NAME_COLUMN].to_list() + _, score, index = process.extractOne( + choices=choices, query=tag_seq, scorer=fuzz.partial_ratio + ) + if score >= min_scores[index]: + return features[index] return UNMAPPED_NAME @@ -46,7 +58,7 @@ def map_reads_hybrid( pl.col(R2_COLUMN) .map_elements( lambda x: find_best_match_fast( - x, tags_list=parsed_tags, maximum_distance=maximum_distance + x, tags_df=parsed_tags, maximum_distance=maximum_distance ) ) .alias(FEATURE_NAME_COLUMN) @@ -58,6 +70,36 @@ def map_reads_hybrid( return mapped_r2_df, unmapped_r2_df +def map_reads_polars( + r2_df: pl.DataFrame, parsed_tags: pl.DataFrame, maximum_distance: int +) -> tuple[pl.DataFrame, pl.DataFrame]: + joined = r2_df.join( + parsed_tags, left_on=R2_COLUMN, right_on=SEQUENCE_COLUMN, how="left" + ) + + simple_join = joined.filter(~pl.col(FEATURE_NAME_COLUMN).is_null()) + hamming_mapped = (joined.filter(pl.col(FEATURE_NAME_COLUMN).is_null()) + .join(parsed_tags, left_on=R2_COLUMN, right_on=SEQUENCE_COLUMN, how="cross") + .drop(FEATURE_NAME_COLUMN) + .with_columns( + pld.col(SEQUENCE_COLUMN) + .dist_str.hamming(pl.col(R2_COLUMN)) + .alias("hamming_dist") + ) + .filter(pl.col("hamming_dist") <= maximum_distance) + .drop([SEQUENCE_COLUMN, "hamming_dist"]) + .rename({"feature_name_right": FEATURE_NAME_COLUMN}) + ) + multi_mapped = hamming_mapped.group_by(R2_COLUMN).agg(pl.count()).filter(pl.col(COUNT_COLUMN)>1) + print(simple_join) + print(hamming_mapped) + mapped = pl.concat([simple_join, hamming_mapped.filter(~pl.col(R2_COLUMN).is_in(multi_mapped[R2_COLUMN]))]) + unmapped = joined.filter(~pl.col(R2_COLUMN).is_in(mapped[R2_COLUMN])).with_columns( + pl.col(FEATURE_NAME_COLUMN).fill_null(UNMAPPED_NAME) + ) + return mapped, unmapped + + def check_unmapped(mapped_reads: pl.DataFrame): """_summary_ diff --git a/tests/test_mapping.py b/tests/test_mapping.py index 0991a89..202d760 100644 --- a/tests/test_mapping.py +++ b/tests/test_mapping.py @@ -1,9 +1,12 @@ +from ast import parse import pytest import random import copy -from collections import Counter, namedtuple from cite_seq_count import mapping +from polars.testing import assert_frame_equal from cite_seq_count.preprocessing import parse_tags_csv +from cite_seq_count.constants import R2_COLUMN, SEQUENCE_COLUMN, FEATURE_NAME_COLUMN +import polars as pl def complete_poly_A(seq, final_length=40): @@ -49,97 +52,58 @@ def modify(seq, n, modification_type): @pytest.fixture -def data(): - from collections import Counter - - pytest.file_path = "tests/test_data/fastq/test_csv.csv" - pytest.debug = False - pytest.barcode_slice = slice(0, 16) - pytest.umi_slice = slice(16, 26) - pytest.correct_reference_list = set(["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]) - pytest.maximum_distance = 5 - pytest.results = { - "ACTGTTTTATTGGCCT": { - 0: Counter({b"CATTAGTGGT": 3, b"CATTAGTGGG": 2, b"CATTCGTGGT": 1}) - }, - "TTCATAAGGTAGGGAT": { - 1: Counter({b"TAGCTTAGTA": 3, b"TAGCTTAGTC": 2, b"GCGATGCATA": 1}) - }, - } - - pytest.sliding_window = False - pytest.sequence_pool = [] - pytest.tags_tuple = parse_tags_csv( - parse_tags_csv("tests/test_data/tags/pass/correct.csv") - ) - pytest.mapping_input = namedtuple( - "mapping_input", - ["filename", "tags", "debug", "maximum_distance", "sliding_window"], - ) - pytest.mappint_input_test = pytest.mapping_input( - filename=pytest.file_path, - tags=pytest.tags_tuple, - debug=pytest.debug, - maximum_distance=pytest.maximum_distance, - sliding_window=pytest.sliding_window, - ) - - -@pytest.mark.dependency() -def test_find_best_match_with_1_distance(data): - distance = 1 - for tag in pytest.tags_tuple: - counts = Counter() - if tag.name == "unmapped": - continue - for seq in extend_seq_pool(tag.sequence, distance): - counts[mapping.find_best_match(seq, pytest.tags_tuple, distance)] += 1 - assert counts[tag.id] == 4 - - -@pytest.mark.dependency() -def test_find_best_match_with_2_distance(data): - distance = 2 - for tag in pytest.tags_tuple: - counts = Counter() - if tag.name == "unmapped": - continue - for seq in extend_seq_pool(tag.sequence, distance): - counts[mapping.find_best_match(seq, pytest.tags_tuple, distance)] += 1 - assert counts[tag.id] == 4 - - -@pytest.mark.dependency() -def test_find_best_match_with_3_distance(data): - distance = 3 - for tag in pytest.tags_tuple: - counts = Counter() - for seq in extend_seq_pool(tag.sequence, distance): - counts[mapping.find_best_match(seq, pytest.tags_tuple, distance)] += 1 - assert counts[tag.id] == 4 - - -@pytest.mark.dependency() -def test_find_best_match_with_3_distance_reverse(data): - distance = 3 - for tag in pytest.tags_tuple: - counts = Counter() - if tag.name == "unmapped": - continue - for seq in extend_seq_pool(tag.sequence, distance): - counts[mapping.find_best_match(seq, pytest.tags_tuple, distance)] += 1 - assert counts[tag.id] == 4 - - -@pytest.mark.dependency( - depends=[ - "test_find_best_match_with_1_distance", - "test_find_best_match_with_2_distance", - "test_find_best_match_with_3_distance", - "test_find_best_match_with_3_distance_reverse", - ] -) -def test_classify_reads_multi_process(data): - (results, _) = mapping.map_reads(pytest.mappint_input_test) - print(results) - assert len(results) == 2 +def small_dataset_path(): + return "tests/test_data/fastq/test_csv.csv" + + +@pytest.fixture +def parsed_tags_df(): + return parse_tags_csv("tests/test_data/tags/pass/correct.csv") + + +@pytest.fixture +def r2_df(): + # Create a sample DataFrame for r2_df + return pl.DataFrame({ + R2_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT", "AGCTAGCTAGCTAGCT"], + }) + +@pytest.fixture +def parsed_tags(): + # Create a sample DataFrame for parsed_tags + return pl.DataFrame({ + SEQUENCE_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], + FEATURE_NAME_COLUMN: ["feature1", "feature2"] + }) + +def test_map_reads_polars_with_dist1(r2_df, parsed_tags): + maximum_distance = 1 + expected_mapped = pl.DataFrame({ + R2_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], + FEATURE_NAME_COLUMN: ["feature1", "feature2"] + }) + expected_unmapped = pl.DataFrame({ + R2_COLUMN: ["AGCTAGCTAGCTAGCT"], + FEATURE_NAME_COLUMN: ["unmapped"] + }) + + mapped, unmapped = mapping.map_reads_polars(r2_df, parsed_tags, maximum_distance) + + assert_frame_equal(mapped, expected_mapped) + assert_frame_equal(unmapped, expected_unmapped) + +def test_map_reads_polars_with_dist2(r2_df, parsed_tags): + maximum_distance = 2 + expected_mapped = pl.DataFrame({ + R2_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], + FEATURE_NAME_COLUMN: ["feature1", "feature2"] + }) + expected_unmapped = pl.DataFrame({ + R2_COLUMN: ["AGCTAGCTAGCTAGCT"], + FEATURE_NAME_COLUMN: ["unmapped"] + }) + + mapped, unmapped = mapping.map_reads_polars(r2_df, parsed_tags, maximum_distance) + + assert_frame_equal(mapped, expected_mapped) + assert_frame_equal(unmapped, expected_unmapped) \ No newline at end of file From 993f9980ea679ce3b9150ec931fa833731edcce4 Mon Sep 17 00:00:00 2001 From: hoohm Date: Wed, 3 Jan 2024 13:33:46 +0100 Subject: [PATCH 70/77] (feat): First attempt at UMI correction --- cite_seq_count/__main__.py | 46 ++++++++++++------------ cite_seq_count/io.py | 31 ++--------------- cite_seq_count/mapping.py | 40 ++++++++++++--------- cite_seq_count/preprocessing.py | 9 +++-- cite_seq_count/processing.py | 62 +++++++++++++++++++++++++++++++-- setup.py | 2 +- tests/test_mapping.py | 57 +++++++++++++++++------------- 7 files changed, 152 insertions(+), 95 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index f046c64..f33c54a 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -18,10 +18,10 @@ constants, ) +import polars as pl def main(): """Main""" - start_time = time.time() parser = argsparser.get_args() if not sys.argv[1:]: @@ -80,7 +80,7 @@ def main(): chemistry_def=chemistry_def, ) main_df, barcodes_df, r2_df = preprocessing.split_data_input( - mapping_input_path=temp_file + mapping_input_path=temp_file, n_reads=n_reads ) # Remove temp file os.remove(temp_file) @@ -115,15 +115,20 @@ def main(): main_df=main_df, mapped_barcodes=mapped_barcodes ) else: - print("Skipping cell barcode correction") n_bcs_corrected = 0 - read_counts = processing.generate_mtx_counts( - main_df=main_df, - barcode_subset=barcodes_df, - mapped_r2_df=mapped_r2_df, - data_type="read", - ) - # Write reads to file + print("UMI correction") + read_counts = processing.find_corrected_umis(main_df=main_df, mapped_r2_df=mapped_r2_df) + # Don't correct + umis_corrected = 0 + clustered_cells = [] + + # read_counts = processing.generate_mtx_counts( + # main_df=main_df, + # barcode_subset=barcodes_df, + # mapped_r2_df=mapped_r2_df, + # data_type="read", + # ) + # # Write reads to file io.write_data_to_mtx( main_df=read_counts, tags_df=parsed_tags, @@ -131,21 +136,18 @@ def main(): data_type="read", outpath=args.outfolder, ) - # TODO: add clustered cells filter: Max UMIs per cell per feature: 20000 - print("UMI correction not implemented yet") - # Don't correct - umis_corrected = 0 - clustered_cells = [] + # TODO: Write out to mtx and csv clustered cells # Generate the UMI count matrix - umi_counts = processing.generate_mtx_counts( - main_df=main_df, - barcode_subset=barcodes_df, - mapped_r2_df=mapped_r2_df, - data_type="umi", - ) - + # umi_counts = processing.generate_mtx_counts( + # main_df=main_df, + # barcode_subset=barcodes_df, + # mapped_r2_df=mapped_r2_df, + # data_type="umi", + # ) + print(read_counts) + umi_counts = read_counts.group_by(["barcode", "feature_name"]).agg(pl.count()) umi_counts.write_parquet(file=args.outfolder + "/umi_counts.parquet") # Write umis to file diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 5cfae98..884ccd7 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -208,34 +208,6 @@ def check_file(file_str: str) -> Path: ) -def write_dense( - sparse_matrix: scipy.sparse.coo_matrix, - parsed_tags: dict, - columns: set, - outfolder: str, - filename: str, -): - """ - Writes a dense matrix in a csv format - - Args: - sparse_matrix (dok_matrix): Results in a sparse matrix. - index (list): List of TAGS - columns (set): List of cells - outfolder (str): Output folder - filename (str): Filename - """ - prefix = os.path.join(outfolder) - os.makedirs(prefix, exist_ok=True) - index = [] - for tag in parsed_tags: - index.append(tag.name) - pandas_dense = pd.DataFrame( - sparse_matrix.todense(), columns=list(columns), index=index - ) - pandas_dense.to_csv(os.path.join(outfolder, filename), sep="\t") - - def write_unmapped( merged_no_match: Counter, top_unknowns: int, outfolder: str, filename: str ): @@ -489,3 +461,6 @@ def write_mapping_input( r2_too_short, total_reads, ) + + + diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py index 7c5a66d..c1f22a4 100644 --- a/cite_seq_count/mapping.py +++ b/cite_seq_count/mapping.py @@ -78,22 +78,30 @@ def map_reads_polars( ) simple_join = joined.filter(~pl.col(FEATURE_NAME_COLUMN).is_null()) - hamming_mapped = (joined.filter(pl.col(FEATURE_NAME_COLUMN).is_null()) - .join(parsed_tags, left_on=R2_COLUMN, right_on=SEQUENCE_COLUMN, how="cross") - .drop(FEATURE_NAME_COLUMN) - .with_columns( - pld.col(SEQUENCE_COLUMN) - .dist_str.hamming(pl.col(R2_COLUMN)) - .alias("hamming_dist") - ) - .filter(pl.col("hamming_dist") <= maximum_distance) - .drop([SEQUENCE_COLUMN, "hamming_dist"]) - .rename({"feature_name_right": FEATURE_NAME_COLUMN}) - ) - multi_mapped = hamming_mapped.group_by(R2_COLUMN).agg(pl.count()).filter(pl.col(COUNT_COLUMN)>1) - print(simple_join) - print(hamming_mapped) - mapped = pl.concat([simple_join, hamming_mapped.filter(~pl.col(R2_COLUMN).is_in(multi_mapped[R2_COLUMN]))]) + hamming_mapped = ( + joined.filter(pl.col(FEATURE_NAME_COLUMN).is_null()) + .join(parsed_tags, left_on=R2_COLUMN, right_on=SEQUENCE_COLUMN, how="cross") + .drop(FEATURE_NAME_COLUMN) + .with_columns( + pld.col(SEQUENCE_COLUMN) + .dist_str.hamming(pl.col(R2_COLUMN)) + .alias("hamming_dist") + ) + .filter(pl.col("hamming_dist") <= maximum_distance) + .drop([SEQUENCE_COLUMN, "hamming_dist"]) + .rename({"feature_name_right": FEATURE_NAME_COLUMN}) + ) + multi_mapped = ( + hamming_mapped.group_by(R2_COLUMN) + .agg(pl.count()) + .filter(pl.col(COUNT_COLUMN) > 1) + ) + mapped = pl.concat( + [ + simple_join, + hamming_mapped.filter(~pl.col(R2_COLUMN).is_in(multi_mapped[R2_COLUMN])), + ] + ) unmapped = joined.filter(~pl.col(R2_COLUMN).is_in(mapped[R2_COLUMN])).with_columns( pl.col(FEATURE_NAME_COLUMN).fill_null(UNMAPPED_NAME) ) diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index bc0d219..fcd4da2 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -269,6 +269,10 @@ def get_read_length(filename: Path) -> int: return read_length +def read_fastq_to_polars(): + pass + + def check_barcodes_lengths( read1_length: int, cb_first: int, cb_last: int, umi_first: int, umi_last: int ): @@ -326,7 +330,7 @@ def pre_run_checks( for read1_path in read1_paths: n_lines = get_n_lines(read1_path) - total_reads += n_lines / 4 + total_reads += round(n_lines / 4) # Get reads length. So far, there is no validation for Read2. read1_lengths.append(get_read_length(read1_path)) @@ -361,7 +365,7 @@ def pre_run_checks( def split_data_input( - mapping_input_path: Path, + mapping_input_path: Path, n_reads: int ) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: """Read in all the input data and split it into three dataframes. @@ -382,6 +386,7 @@ def split_data_input( mapping_input_path, has_header=False, new_columns=[BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN], + n_rows=n_reads ) .group_by([BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN]) .agg(pl.count()) diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index e370d79..5f3372e 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -15,7 +15,6 @@ UNMAPPED_NAME, ) - def correct_barcodes_pl( barcodes_df: pl.DataFrame, barcode_subset_df: pl.DataFrame, @@ -120,13 +119,72 @@ def update_main_df(main_df: pl.DataFrame, mapped_barcodes: dict): pl.DataFrame: Data of all reads with barcodes corrected """ main_df = ( - main_df.with_columns(pl.col(BARCODE_COLUMN).map_dict(mapped_barcodes)) + main_df.with_columns( + pl.col(BARCODE_COLUMN).map_dict(mapped_barcodes, default=pl.first()) + ) .group_by([BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN]) .agg(pl.sum("count")) ) return main_df +def find_corrected_umis( + main_df, mapped_r2_df, umi_distance=1, cluster_method="directional", max_umis=20000 +): + merged = mapped_r2_df.join(main_df, on="r2") + temp = ( + merged.with_columns(umi=pl.col("umi").cast(pl.Binary)) + .group_by(["r2", "feature_name", "barcode"]) + .agg(pl.struct(pl.col("umi"), pl.col("count"))) + .filter((pl.col("umi").list.len() > 1) & (pl.col("umi").list.len() < max_umis)) + ) + mapping = {"r2": [], "feature_name": [], "barcode": [], "orig": [], "replace": []} + for r2, feature_name, barcode, umis in temp.iter_rows(): + corrected_umis = correct_umis( + umis, cluster_method=cluster_method, umi_distance=umi_distance + ) + if len(corrected_umis) == 0: + continue + for umi_set in corrected_umis: + for index, umi in enumerate(umi_set): + if index != 0: + mapping["r2"].append(r2) + mapping["feature_name"].append(feature_name) + mapping["barcode"].append(barcode) + mapping["orig"].append(umi) + mapping["replace"].append(umi_set[0]) + mapping_df = ( + pl.DataFrame(mapping) + .with_columns( + pl.col("orig").cast(pl.String).alias("umi"), + pl.col("replace").cast(pl.String), + ) + .drop("orig") + ) + read_counts = ( + merged.join(mapping_df, on=["r2", "feature_name", "barcode", "umi"], how="left") + .with_columns( + pl.when(pl.col("replace").is_null()) + .then(pl.col("umi")) + .otherwise(pl.col("replace")) + .alias("umi") + ) + .drop("replace") + .group_by(["r2", "barcode", "feature_name", "umi"]) + .agg(pl.sum("count")).drop("r2") + ) + return read_counts + + +def correct_umis(umis_list, cluster_method, umi_distance): + umi_clusterer = network.UMIClusterer(cluster_method=cluster_method) + umis = dict([(i["umi"], i["count"]) for i in umis_list]) + + res = umi_clusterer(umis, umi_distance) + corrected = [corrected_umis for corrected_umis in res if len(corrected_umis) > 1] + return corrected + + # UMI correction section def correct_umis_in_cells(umi_correction_input): """ diff --git a/setup.py b/setup.py index aac68d7..4c36c1f 100644 --- a/setup.py +++ b/setup.py @@ -29,7 +29,7 @@ "pyyaml==6.0", "pooch==1.6.0", "six==1.16.0", - "polars==0.20.3-rc2", + "polars==0.20.3", ], python_requires="==3.11.6", package_data={"report_template": ["templates/*.json"]}, diff --git a/tests/test_mapping.py b/tests/test_mapping.py index 202d760..c369265 100644 --- a/tests/test_mapping.py +++ b/tests/test_mapping.py @@ -64,46 +64,55 @@ def parsed_tags_df(): @pytest.fixture def r2_df(): # Create a sample DataFrame for r2_df - return pl.DataFrame({ - R2_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT", "AGCTAGCTAGCTAGCT"], - }) + return pl.DataFrame( + { + R2_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT", "AGCTAGCTAGCTAGCT"], + } + ) + @pytest.fixture def parsed_tags(): # Create a sample DataFrame for parsed_tags - return pl.DataFrame({ - SEQUENCE_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], - FEATURE_NAME_COLUMN: ["feature1", "feature2"] - }) + return pl.DataFrame( + { + SEQUENCE_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], + FEATURE_NAME_COLUMN: ["feature1", "feature2"], + } + ) + def test_map_reads_polars_with_dist1(r2_df, parsed_tags): maximum_distance = 1 - expected_mapped = pl.DataFrame({ - R2_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], - FEATURE_NAME_COLUMN: ["feature1", "feature2"] - }) - expected_unmapped = pl.DataFrame({ - R2_COLUMN: ["AGCTAGCTAGCTAGCT"], - FEATURE_NAME_COLUMN: ["unmapped"] - }) + expected_mapped = pl.DataFrame( + { + R2_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], + FEATURE_NAME_COLUMN: ["feature1", "feature2"], + } + ) + expected_unmapped = pl.DataFrame( + {R2_COLUMN: ["AGCTAGCTAGCTAGCT"], FEATURE_NAME_COLUMN: ["unmapped"]} + ) mapped, unmapped = mapping.map_reads_polars(r2_df, parsed_tags, maximum_distance) assert_frame_equal(mapped, expected_mapped) assert_frame_equal(unmapped, expected_unmapped) + def test_map_reads_polars_with_dist2(r2_df, parsed_tags): maximum_distance = 2 - expected_mapped = pl.DataFrame({ - R2_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], - FEATURE_NAME_COLUMN: ["feature1", "feature2"] - }) - expected_unmapped = pl.DataFrame({ - R2_COLUMN: ["AGCTAGCTAGCTAGCT"], - FEATURE_NAME_COLUMN: ["unmapped"] - }) + expected_mapped = pl.DataFrame( + { + R2_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], + FEATURE_NAME_COLUMN: ["feature1", "feature2"], + } + ) + expected_unmapped = pl.DataFrame( + {R2_COLUMN: ["AGCTAGCTAGCTAGCT"], FEATURE_NAME_COLUMN: ["unmapped"]} + ) mapped, unmapped = mapping.map_reads_polars(r2_df, parsed_tags, maximum_distance) assert_frame_equal(mapped, expected_mapped) - assert_frame_equal(unmapped, expected_unmapped) \ No newline at end of file + assert_frame_equal(unmapped, expected_unmapped) From 416d47599c740acf4bf1382b718d3ddcddd03934 Mon Sep 17 00:00:00 2001 From: hoohm Date: Thu, 4 Jan 2024 11:50:00 +0100 Subject: [PATCH 71/77] (fix): Mtx writing --- cite_seq_count/__main__.py | 35 ++----- cite_seq_count/argsparser.py | 6 +- cite_seq_count/io.py | 27 +++-- cite_seq_count/preprocessing.py | 2 +- cite_seq_count/processing.py | 176 +++++++++++++++----------------- 5 files changed, 116 insertions(+), 130 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index f33c54a..6f1cba9 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -20,6 +20,7 @@ import polars as pl + def main(): """Main""" start_time = time.time() @@ -117,42 +118,24 @@ def main(): else: n_bcs_corrected = 0 print("UMI correction") - read_counts = processing.find_corrected_umis(main_df=main_df, mapped_r2_df=mapped_r2_df) - # Don't correct - umis_corrected = 0 - clustered_cells = [] - - # read_counts = processing.generate_mtx_counts( - # main_df=main_df, - # barcode_subset=barcodes_df, - # mapped_r2_df=mapped_r2_df, - # data_type="read", - # ) + final_read_counts, umis_corrected, clustered_cells = processing.correct_umis_df( + main_df=main_df, mapped_r2_df=mapped_r2_df + ) + # # Write reads to file io.write_data_to_mtx( - main_df=read_counts, + main_df=final_read_counts, tags_df=parsed_tags, subset_df=barcode_subset, data_type="read", outpath=args.outfolder, ) - - # TODO: Write out to mtx and csv clustered cells - - # Generate the UMI count matrix - # umi_counts = processing.generate_mtx_counts( - # main_df=main_df, - # barcode_subset=barcodes_df, - # mapped_r2_df=mapped_r2_df, - # data_type="umi", - # ) - print(read_counts) - umi_counts = read_counts.group_by(["barcode", "feature_name"]).agg(pl.count()) - umi_counts.write_parquet(file=args.outfolder + "/umi_counts.parquet") + umi_counts = processing.generate_umi_counts(read_counts=final_read_counts) + io.write_out_parquet(df=umi_counts, outpath=args.outfolder, filename="umi_counts") # Write umis to file io.write_data_to_mtx( - main_df=read_counts, + main_df=final_read_counts, tags_df=parsed_tags, subset_df=barcode_subset, data_type="read", diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 6799cdf..f87b121 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -3,7 +3,7 @@ import sys import tempfile - +from pathlib import Path from argparse import ArgumentParser, RawTextHelpFormatter import pkg_resources @@ -270,8 +270,8 @@ def get_args() -> ArgumentParser: "-o", "--output", required=False, - type=str, - default="results", + type=Path, + default=Path("results"), dest="outfolder", help=("Results will be written to this folder"), ) diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 884ccd7..38bae07 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -68,6 +68,18 @@ def blocks(file: TextIO, size: int = 65536): yield partial_file +def write_out_parquet(df: pl.DataFrame, outpath: Path, filename: str): + """ + Writes out a dataframe to parquet format. + + Args: + df (pl.DataFrame): A dataframe to write out + outpath (Path): Path to the output folder + filename (str): Name of the file + """ + df.write_parquet(outpath / f"{filename}.parquet") + + def get_n_lines(file_path: Path) -> int: """ Determines how many lines have to be processed @@ -353,13 +365,17 @@ def write_data_to_mtx( mtx_df, tags_indexed, barcodes_indexed = create_mtx_df(main_df, tags_df, subset_df) data_path = Path(outpath) / f"{data_type}_count" data_path.mkdir(parents=True, exist_ok=True) - mtx_df.write_csv(include_header=False, file=data_path / TEMP_MTX, separator="\t") + mtx_df.write_csv(include_header=False, file=data_path / TEMP_MTX, separator=" ") # Write out the full MTX matrix + first_line_mtx = f"{tags_indexed.shape[0]} {barcodes_indexed.shape[0]} {mtx_df.shape[0]}\n" with open(data_path / TEMP_MTX, "r") as mtx_in: mtx_main = mtx_in.read() - final_mtx = MTX_HEADER + mtx_main - with open(data_path / MATRIX_MTX, "wb") as mtx_out: - mtx_out.write(final_mtx.encode()) + final_mtx = MTX_HEADER + first_line_mtx + mtx_main + with open(data_path / TEMP_MTX, "w") as mtx_out: + mtx_out.write(final_mtx) + with open(data_path / TEMP_MTX, "rb") as mtx_in: + with gzip.open(data_path / MATRIX_MTX, "wb") as mtx_gz: + shutil.copyfileobj(mtx_in, mtx_gz) os.remove(data_path / TEMP_MTX) # Write ouf features and barcodes tags_indexed.sort(FEATURE_ID_COLUMN).select(FEATURE_NAME_COLUMN).write_csv( @@ -461,6 +477,3 @@ def write_mapping_input( r2_too_short, total_reads, ) - - - diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index fcd4da2..61a6424 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -386,7 +386,7 @@ def split_data_input( mapping_input_path, has_header=False, new_columns=[BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN], - n_rows=n_reads + n_rows=n_reads, ) .group_by([BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN]) .agg(pl.count()) diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 5f3372e..1680931 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -15,6 +15,7 @@ UNMAPPED_NAME, ) + def correct_barcodes_pl( barcodes_df: pl.DataFrame, barcode_subset_df: pl.DataFrame, @@ -108,6 +109,21 @@ def summarise_unmapped_df(main_df: pl.DataFrame, unmapped_r2_df: pl.DataFrame): return unmapped_df +def generate_umi_counts(read_counts: pl.DataFrame) -> pl.DataFrame: + """Generate umi counts from read counts + + Args: + read_counts (pl.DataFrame): Read counts + + Returns: + pl.DataFrame: Umi counts + """ + umi_counts = read_counts.group_by([BARCODE_COLUMN, FEATURE_NAME_COLUMN]).agg( + pl.count() + ) + return umi_counts + + def update_main_df(main_df: pl.DataFrame, mapped_barcodes: dict): """Update the main data df with the corrected barcodes @@ -123,61 +139,83 @@ def update_main_df(main_df: pl.DataFrame, mapped_barcodes: dict): pl.col(BARCODE_COLUMN).map_dict(mapped_barcodes, default=pl.first()) ) .group_by([BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN]) - .agg(pl.sum("count")) + .agg(pl.sum(COUNT_COLUMN)) ) return main_df -def find_corrected_umis( +def correct_umis_df( main_df, mapped_r2_df, umi_distance=1, cluster_method="directional", max_umis=20000 ): - merged = mapped_r2_df.join(main_df, on="r2") + merged = mapped_r2_df.join(main_df, on=R2_COLUMN) temp = ( - merged.with_columns(umi=pl.col("umi").cast(pl.Binary)) - .group_by(["r2", "feature_name", "barcode"]) - .agg(pl.struct(pl.col("umi"), pl.col("count"))) - .filter((pl.col("umi").list.len() > 1) & (pl.col("umi").list.len() < max_umis)) + merged.with_columns(umi=pl.col(UMI_COLUMN).cast(pl.Binary)) + .group_by([R2_COLUMN, FEATURE_NAME_COLUMN, BARCODE_COLUMN]) + .agg(pl.struct(pl.col(UMI_COLUMN), pl.col(COUNT_COLUMN))) ) - mapping = {"r2": [], "feature_name": [], "barcode": [], "orig": [], "replace": []} - for r2, feature_name, barcode, umis in temp.iter_rows(): + clustered_cells = ( + temp.filter(pl.col(UMI_COLUMN).list.len() > max_umis) + .select(BARCODE_COLUMN) + .unique() + .get_column(BARCODE_COLUMN) + .to_list() + ) + + mapping_list = [] + umi_clusterer = network.UMIClusterer(cluster_method=cluster_method) + for r2, feature_name, barcode, umis in temp.filter( + (pl.col(UMI_COLUMN).list.len() > 1) & (pl.col(UMI_COLUMN).list.len() < max_umis) + ).iter_rows(): corrected_umis = correct_umis( - umis, cluster_method=cluster_method, umi_distance=umi_distance + umis, + umi_distance=umi_distance, + umi_clusterer=umi_clusterer, ) if len(corrected_umis) == 0: continue for umi_set in corrected_umis: for index, umi in enumerate(umi_set): if index != 0: - mapping["r2"].append(r2) - mapping["feature_name"].append(feature_name) - mapping["barcode"].append(barcode) - mapping["orig"].append(umi) - mapping["replace"].append(umi_set[0]) + mapping_list.append([r2, feature_name, barcode, umi, umi_set[0]]) mapping_df = ( - pl.DataFrame(mapping) + pl.DataFrame( + mapping_list, + schema={ + R2_COLUMN: pl.String, + FEATURE_NAME_COLUMN: pl.String, + BARCODE_COLUMN: pl.String, + "orig": pl.Binary, + "replace": pl.Binary, + }, + ) .with_columns( - pl.col("orig").cast(pl.String).alias("umi"), + pl.col("orig").cast(pl.String).alias(UMI_COLUMN), pl.col("replace").cast(pl.String), ) .drop("orig") ) read_counts = ( - merged.join(mapping_df, on=["r2", "feature_name", "barcode", "umi"], how="left") + merged.join( + mapping_df, + on=[R2_COLUMN, FEATURE_NAME_COLUMN, BARCODE_COLUMN, UMI_COLUMN], + how="left", + ) .with_columns( pl.when(pl.col("replace").is_null()) - .then(pl.col("umi")) + .then(pl.col(UMI_COLUMN)) .otherwise(pl.col("replace")) - .alias("umi") + .alias(UMI_COLUMN) ) .drop("replace") - .group_by(["r2", "barcode", "feature_name", "umi"]) - .agg(pl.sum("count")).drop("r2") + .group_by([R2_COLUMN, BARCODE_COLUMN, FEATURE_NAME_COLUMN, UMI_COLUMN]) + .agg(pl.sum(COUNT_COLUMN)) + .drop(R2_COLUMN) ) - return read_counts + n_corrected_umis = mapping_df.shape[0] + return read_counts, n_corrected_umis, clustered_cells -def correct_umis(umis_list, cluster_method, umi_distance): - umi_clusterer = network.UMIClusterer(cluster_method=cluster_method) +def correct_umis(umis_list, umi_distance, umi_clusterer): umis = dict([(i["umi"], i["count"]) for i in umis_list]) res = umi_clusterer(umis, umi_distance) @@ -185,71 +223,23 @@ def correct_umis(umis_list, cluster_method, umi_distance): return corrected -# UMI correction section -def correct_umis_in_cells(umi_correction_input): - """ - Corrects umi barcodes within same cell/tag groups. - - Args: - final_results (dict): Dict of dict of Counters with mapping results. - collapsing_threshold (int): Max distance between umis. - filtered_cells (set): Set of cells to go through. - max_umis (int): Maximum UMIs to consider for one cluster. - - Returns: - final_results (dict): Same as input but with corrected umis. - corrected_umis (int): How many umis have been corrected. - clustered_umi_count_cells (set): Set of uncorrected cells. - """ - - (final_results, collapsing_threshold, max_umis, unmapped_id) = umi_correction_input - print( - "Started umi correction in child process {} working on {} cells".format( - os.getpid(), len(final_results) - ) - ) - corrected_umis = 0 - clustered_cells = set() - cells = final_results.keys() - for cell_barcode in cells: - for TAG in final_results[cell_barcode]: - if TAG == unmapped_id: - final_results[cell_barcode].pop(unmapped_id) - - n_umis = len(final_results[cell_barcode][TAG]) - if n_umis > 1 and n_umis <= max_umis: - umi_clusters = network.UMIClusterer() - UMIclusters = umi_clusters( - final_results[cell_barcode][TAG], collapsing_threshold - ) - (new_res, temp_corrected_umis) = update_umi_counts( - UMIclusters, final_results[cell_barcode][TAG] - ) - final_results[cell_barcode][TAG] = new_res - corrected_umis += temp_corrected_umis - elif n_umis > max_umis: - clustered_cells.add(cell_barcode) - print(f"Finished correcting umis in child {os.getpid()}") - return (final_results, corrected_umis, clustered_cells) - - -def generate_mtx_counts( - main_df: pl.DataFrame, - barcode_subset: pl.DataFrame, - mapped_r2_df: pl.DataFrame, - data_type: str, -) -> pl.DataFrame: - if data_type == "read": - return ( - main_df.join(barcode_subset, on=BARCODE_COLUMN) - .join(mapped_r2_df, on=R2_COLUMN) - .group_by([BARCODE_COLUMN, FEATURE_NAME_COLUMN]) - .agg(pl.sum(COUNT_COLUMN)) - ) - else: - return ( - main_df.join(barcode_subset, on=BARCODE_COLUMN) - .join(mapped_r2_df, on=R2_COLUMN) - .group_by([BARCODE_COLUMN, FEATURE_NAME_COLUMN]) - .agg(pl.count()) - ) +# def generate_mtx_counts( +# main_df: pl.DataFrame, +# barcode_subset: pl.DataFrame, +# mapped_r2_df: pl.DataFrame, +# data_type: str, +# ) -> pl.DataFrame: +# if data_type == "read": +# return ( +# main_df.join(barcode_subset, on=BARCODE_COLUMN) +# .join(mapped_r2_df, on=R2_COLUMN) +# .group_by([BARCODE_COLUMN, FEATURE_NAME_COLUMN]) +# .agg(pl.sum(COUNT_COLUMN)) +# ) +# else: +# return ( +# main_df.join(barcode_subset, on=BARCODE_COLUMN) +# .join(mapped_r2_df, on=R2_COLUMN) +# .group_by([BARCODE_COLUMN, FEATURE_NAME_COLUMN]) +# .agg(pl.count()) +# ) From 81396b3f6d0102c9e6eccfb3ea8b96a95f8c6c55 Mon Sep 17 00:00:00 2001 From: hoohm Date: Thu, 4 Jan 2024 12:25:43 +0100 Subject: [PATCH 72/77] (fix): duplicated read_counts writing --- cite_seq_count/__main__.py | 12 +++--------- cite_seq_count/io.py | 4 +++- 2 files changed, 6 insertions(+), 10 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 6f1cba9..f55376f 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -124,7 +124,9 @@ def main(): # # Write reads to file io.write_data_to_mtx( - main_df=final_read_counts, + main_df=final_read_counts.group_by( + [constants.BARCODE_COLUMN, constants.FEATURE_NAME_COLUMN] + ).agg(pl.sum(constants.COUNT_COLUMN)), tags_df=parsed_tags, subset_df=barcode_subset, data_type="read", @@ -133,14 +135,6 @@ def main(): umi_counts = processing.generate_umi_counts(read_counts=final_read_counts) io.write_out_parquet(df=umi_counts, outpath=args.outfolder, filename="umi_counts") - # Write umis to file - io.write_data_to_mtx( - main_df=final_read_counts, - tags_df=parsed_tags, - subset_df=barcode_subset, - data_type="read", - outpath=args.outfolder, - ) io.write_data_to_mtx( main_df=umi_counts, tags_df=parsed_tags, diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 38bae07..7b103e9 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -367,7 +367,9 @@ def write_data_to_mtx( data_path.mkdir(parents=True, exist_ok=True) mtx_df.write_csv(include_header=False, file=data_path / TEMP_MTX, separator=" ") # Write out the full MTX matrix - first_line_mtx = f"{tags_indexed.shape[0]} {barcodes_indexed.shape[0]} {mtx_df.shape[0]}\n" + first_line_mtx = ( + f"{tags_indexed.shape[0]} {barcodes_indexed.shape[0]} {mtx_df.shape[0]}\n" + ) with open(data_path / TEMP_MTX, "r") as mtx_in: mtx_main = mtx_in.read() final_mtx = MTX_HEADER + first_line_mtx + mtx_main From ddfe048beb5e20524da88db88c051192182ac3e9 Mon Sep 17 00:00:00 2001 From: hoohm Date: Fri, 5 Jan 2024 10:01:52 +0100 Subject: [PATCH 73/77] (feat): Top unmapped are back --- cite_seq_count/__main__.py | 3 + cite_seq_count/argsparser.py | 9 --- cite_seq_count/io.py | 97 ++++++++++----------------------- cite_seq_count/mapping.py | 17 +++--- cite_seq_count/preprocessing.py | 4 +- cite_seq_count/processing.py | 22 ++++---- setup.py | 1 - 7 files changed, 53 insertions(+), 100 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index f55376f..407aff7 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -93,6 +93,9 @@ def main(): unmapped_df = processing.summarise_unmapped_df( main_df=main_df, unmapped_r2_df=unmapped_r2_df ) + io.write_unmapped( + unmapped_df=unmapped_df, outfolder=args.outfolder, filename="unmapped" + ) barcode_subset, enable_barcode_correction = preprocessing.get_barcode_subset( barcode_subset=barcode_subset, n_barcodes=args.expected_barcodes, diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index f87b121..4c392b5 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -275,15 +275,6 @@ def get_args() -> ArgumentParser: dest="outfolder", help=("Results will be written to this folder"), ) - parser.add_argument( - "-u", - "--unmapped-tags", - required=False, - type=str, - dest="unmapped_file", - default="unmapped.csv", - help=("Write table of unknown TAGs to file."), - ) parser.add_argument( "-ut", "--unknown-top-tags", diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index 7b103e9..c135d2b 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -154,54 +154,6 @@ def get_csv_reader_from_path(filename: str, sep: str = "\t"): return csv_reader -def write_to_files( - sparse_matrix: scipy.sparse.coo_matrix, - filtered_cells: set, - parsed_tags: dict, - data_type: str, - outfolder: str, - translation_dict: dict, -): - """Write the umi and read sparse matrices to file in gzipped mtx format. - - Args: - sparse_matrix (dok_matrix): Results in a sparse matrix. - filtered_cells (set): Set of cells that are selected for output. - parsed_tags (dict): Tags in order with indexes as values. - data_type (string): A string definning if the data is umi or read based. - outfolder (string): Path to the output folder. - """ - prefix = os.path.join(outfolder, data_type + "_count") - unknown_id = 1 - os.makedirs(prefix, exist_ok=True) - io.mmwrite(os.path.join(prefix, "matrix.mtx"), a=sparse_matrix, field="integer") - with gzip.open(os.path.join(prefix, "barcodes.tsv.gz"), "wb") as barcode_file: - for barcode in filtered_cells: - if translation_dict: - if barcode in translation_dict: - barcode_file.write( - f"{translation_dict[barcode]}\t{barcode}\n".encode(), - ) - else: - barcode_file.write( - "{}\t{}\n".format( - f"translation_not_found_{unknown_id}", barcode - ).encode(), - ) - unknown_id += 1 - else: - barcode_file.write(f"{barcode}\n".encode()) - with gzip.open(os.path.join(prefix, "features.tsv.gz"), "wb") as feature_file: - for feature in parsed_tags: - feature_file.write(f"{feature.sequence}\t{feature.name}\n".encode()) - if data_type == "read": - feature_file.write("{}\t{}\n".format("UNKNOWN", "unmapped").encode()) - with open(os.path.join(prefix, "matrix.mtx"), "rb") as mtx_in: - with gzip.open(os.path.join(prefix, "matrix.mtx") + ".gz", "wb") as mtx_gz: - shutil.copyfileobj(mtx_in, mtx_gz) - os.remove(os.path.join(prefix, "matrix.mtx")) - - def check_file(file_str: str) -> Path: """Check that a file exists and is readable. @@ -212,17 +164,15 @@ def check_file(file_str: str) -> Path: Path: Path to the file """ file_path = Path(file_str) - if file_path.exists and access(file_path, R_OK): - return file_path - else: - raise FileNotFoundError( - f"This file {file_path} does not exist or is not accessible" - ) + if not file_path.exists: + raise FileNotFoundError(f"This file {file_path} does not exist") + if not access(file_path, R_OK): + raise FileNotFoundError(f"This file {file_path} is not accessible") + return file_path -def write_unmapped( - merged_no_match: Counter, top_unknowns: int, outfolder: str, filename: str -): + +def write_unmapped(unmapped_df: pl.DataFrame, outfolder: Path, filename: str): """ Writes a list of top unmapped sequences @@ -233,12 +183,7 @@ def write_unmapped( filename (string): Name of the output file """ - top_unmapped = merged_no_match.most_common(top_unknowns) - - with open(os.path.join(outfolder, filename), "w", encoding="utf-8") as unknown_file: - unknown_file.write("tag,count\n") - for element in top_unmapped: - unknown_file.write(f"{element[0]},{element[1]}\n") + unmapped_df.write_csv(file=outfolder / f"{filename}.csv") def load_report_template() -> dict: @@ -331,7 +276,17 @@ def create_report( def create_mtx_df( main_df: pl.DataFrame, tags_df: pl.DataFrame, subset_df: pl.DataFrame -): +) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: + """Create the MTX dataframe, indexed barcodes and indexed features from the different dataframes + + Args: + main_df (pl.DataFrame): Bridge between barcodes and features. Holds UMIs + tags_df (pl.DataFrame): Features and their sequences + subset_df (pl.DataFrame): Subset of barcodes to use + + Returns: + tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: MTX dataframe, indexed barcodes, indexed features + """ tags_indexed = ( pl.concat( [ @@ -362,6 +317,15 @@ def write_data_to_mtx( data_type: str, outpath: str, ) -> None: + """Write out the data to disk in MTX format + + Args: + main_df (pl.DataFrame): Main df + tags_df (pl.DataFrame): Features df + subset_df (pl.DataFrame): Subsetted barcodes df + data_type (str): umi or read + outpath (str): Path to the output folder + """ mtx_df, tags_indexed, barcodes_indexed = create_mtx_df(main_df, tags_df, subset_df) data_path = Path(outpath) / f"{data_type}_count" data_path.mkdir(parents=True, exist_ok=True) @@ -398,8 +362,7 @@ def write_mapping_input( chemistry_def, ): """ - Writes chunked files of reads to disk and prepares parallel - processing queue parameters. + Writes all reads to one CSV to be used. Args: args(argparse): All parsed arguments. @@ -410,7 +373,7 @@ def write_mapping_input( parsed_tags (list): List of namedtuple tags. maximum_distance (int): Maximum hamming distance for mapping. """ - print("Writing chunks to disk") + print("Writing reads to disk") temp_path = os.path.abspath(args.temp_path) r1_too_short = 0 diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py index c1f22a4..2b65c86 100644 --- a/cite_seq_count/mapping.py +++ b/cite_seq_count/mapping.py @@ -1,6 +1,5 @@ """Mapping module. Holds all code related to mapping reads """ -from turtle import right import polars as pl from rapidfuzz import fuzz, process import polars_distance as pld @@ -78,28 +77,30 @@ def map_reads_polars( ) simple_join = joined.filter(~pl.col(FEATURE_NAME_COLUMN).is_null()) - hamming_mapped = ( + levenshtein_mapped = ( joined.filter(pl.col(FEATURE_NAME_COLUMN).is_null()) .join(parsed_tags, left_on=R2_COLUMN, right_on=SEQUENCE_COLUMN, how="cross") .drop(FEATURE_NAME_COLUMN) .with_columns( pld.col(SEQUENCE_COLUMN) - .dist_str.hamming(pl.col(R2_COLUMN)) - .alias("hamming_dist") + .dist_str.levenshtein(pl.col(R2_COLUMN)) + .alias("levenshtein_dist") ) - .filter(pl.col("hamming_dist") <= maximum_distance) - .drop([SEQUENCE_COLUMN, "hamming_dist"]) + .filter(pl.col("levenshtein_dist") <= maximum_distance) + .drop([SEQUENCE_COLUMN, "levenshtein_dist"]) .rename({"feature_name_right": FEATURE_NAME_COLUMN}) ) multi_mapped = ( - hamming_mapped.group_by(R2_COLUMN) + levenshtein_mapped.group_by(R2_COLUMN) .agg(pl.count()) .filter(pl.col(COUNT_COLUMN) > 1) ) mapped = pl.concat( [ simple_join, - hamming_mapped.filter(~pl.col(R2_COLUMN).is_in(multi_mapped[R2_COLUMN])), + levenshtein_mapped.filter( + ~pl.col(R2_COLUMN).is_in(multi_mapped[R2_COLUMN]) + ), ] ) unmapped = joined.filter(~pl.col(R2_COLUMN).is_in(mapped[R2_COLUMN])).with_columns( diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 61a6424..2db6320 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -42,9 +42,7 @@ def check_equi_length(df: pl.DataFrame, column_name: str): raise ValueError(f"Barcodes in {column_name} column have different lengths.") -def parse_barcode_file( - filename: str, barcode_length: int, required_header: list -) -> pl.DataFrame: +def parse_barcode_file(filename: str, required_header: list) -> pl.DataFrame: """Reads reference barcodes from a CSV file. The function accepts plain barcodes or even 10X style barcodes with the diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 1680931..32fb73a 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -1,7 +1,5 @@ -import os - import polars as pl -from rapidfuzz import distance +import polars_distance as pld from umi_tools import network @@ -39,7 +37,7 @@ def correct_barcodes_pl( BARCODE_COLUMN: pl.Utf8, "count": pl.UInt32, SUBSET_COLUMN: pl.Utf8, - "hamming_distance": pl.UInt8, + "hamming_distance": pl.UInt32, } ) @@ -59,13 +57,8 @@ def correct_barcodes_pl( ) .filter(~pl.col(SUBSET_COLUMN).is_null()) .with_columns( - pl.struct(pl.col(BARCODE_COLUMN), pl.col(SUBSET_COLUMN)) - .map_elements( - lambda x: distance.Hamming.distance( - x[BARCODE_COLUMN], x[SUBSET_COLUMN] - ), - return_dtype=pl.UInt8, - ) + pld.col(BARCODE_COLUMN) + .dist_str.hamming(pl.col(SUBSET_COLUMN)) .alias("hamming_distance") ) .filter(pl.col("hamming_distance") <= hamming_distance) @@ -104,7 +97,12 @@ def summarise_unmapped_df(main_df: pl.DataFrame, unmapped_r2_df: pl.DataFrame): .alias(FEATURE_NAME_COLUMN) ) ) - unmapped_df = unmapped_r2_df.group_by(FEATURE_NAME_COLUMN).agg(pl.count()) + unmapped_df = ( + unmapped_r2_df.group_by(R2_COLUMN) + .agg(pl.sum(COUNT_COLUMN)) + .sort(COUNT_COLUMN, descending=True) + .head(1000) + ) return unmapped_df diff --git a/setup.py b/setup.py index 4c36c1f..cf53903 100644 --- a/setup.py +++ b/setup.py @@ -21,7 +21,6 @@ "Operating System :: OS Independent", ), install_requires=[ - "scipy>=1.1.0", "umi_tools==1.1.4", "pytest>=6.0.0", "pytest-dependency==0.4.0", From 5545d93d4930c14412ba66c1104fa5d3d797cc40 Mon Sep 17 00:00:00 2001 From: hoohm Date: Fri, 5 Jan 2024 10:02:56 +0100 Subject: [PATCH 74/77] (chore): Rename pl.Utf8 to pl.String --- cite_seq_count/preprocessing.py | 2 +- cite_seq_count/processing.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 2db6320..0950744 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -465,5 +465,5 @@ def find_knee_estimated_barcodes(barcodes_df: pl.DataFrame) -> pl.DataFrame: true_barcodes = whitelist_method.getKneeEstimateDistance( cell_barcode_counts=barcode_counter ) - barcode_subset = pl.DataFrame(true_barcodes, schema={SUBSET_COLUMN: pl.Utf8}) + barcode_subset = pl.DataFrame(true_barcodes, schema={SUBSET_COLUMN: pl.String}) return barcode_subset diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 32fb73a..fb1a840 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -34,9 +34,9 @@ def correct_barcodes_pl( print("Correcting barcodes") corrected_barcodes_pl = pl.DataFrame( schema={ - BARCODE_COLUMN: pl.Utf8, + BARCODE_COLUMN: pl.String, "count": pl.UInt32, - SUBSET_COLUMN: pl.Utf8, + SUBSET_COLUMN: pl.String, "hamming_distance": pl.UInt32, } ) From c137d5eea55f0d7a77d07ccb131c6b10a098a67f Mon Sep 17 00:00:00 2001 From: hoohm Date: Fri, 5 Jan 2024 16:04:28 +0100 Subject: [PATCH 75/77] (Fix): Barcode correction now iterates until it finds the best mapping using asof_join --- cite_seq_count/__main__.py | 16 ++--- cite_seq_count/io.py | 20 +++--- cite_seq_count/mapping.py | 86 ++++++++++++----------- cite_seq_count/preprocessing.py | 59 +++++++++------- cite_seq_count/processing.py | 104 +++++++++++++++++++--------- tests/test_io.py | 12 ++-- tests/test_mapping.py | 10 +-- tests/test_preprocessing.py | 56 +++++++-------- tests/test_processing.py | 118 ++++++++++++-------------------- 9 files changed, 248 insertions(+), 233 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 407aff7..3e813fa 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -48,7 +48,6 @@ def main(): if args.subset_path is not None: barcode_subset = preprocessing.parse_barcode_file( filename=args.subset_path, - barcode_length=chemistry_def.barcode_length, required_header=[constants.REFERENCE_COLUMN], ) else: @@ -83,8 +82,6 @@ def main(): main_df, barcodes_df, r2_df = preprocessing.split_data_input( mapping_input_path=temp_file, n_reads=n_reads ) - # Remove temp file - os.remove(temp_file) mapped_r2_df, unmapped_r2_df = mapping.map_reads_polars( r2_df=r2_df, parsed_tags=parsed_tags, @@ -111,8 +108,8 @@ def main(): n_bcs_corrected, mapped_barcodes, ) = processing.correct_barcodes_pl( - barcodes_df=barcodes_df, - barcode_subset_df=barcode_subset, + barcodes_df=barcodes_df.collect(), + barcode_subset_df=barcode_subset.collect(), hamming_distance=args.bc_threshold, ) main_df = processing.update_main_df( @@ -122,7 +119,7 @@ def main(): n_bcs_corrected = 0 print("UMI correction") final_read_counts, umis_corrected, clustered_cells = processing.correct_umis_df( - main_df=main_df, mapped_r2_df=mapped_r2_df + main_df=main_df, mapped_r2_df=mapped_r2_df, umi_distance=args.umi_threshold ) # # Write reads to file @@ -149,9 +146,11 @@ def main(): # TODO: Write unmapped sequences # TODO: rewrite reporting # Create report and write it to disk + # Remove temp file + io.create_report( total_reads=total_reads, - unmapped=unmapped_df, + unmapped=unmapped_df.collect(), version=argsparser.get_package_version(), start_time=start_time, umis_corrected=umis_corrected, @@ -163,7 +162,6 @@ def main(): chemistry_def=chemistry_def, maximum_distance=maximum_distance, ) - - + os.remove(temp_file) if __name__ == "__main__": main() diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index c135d2b..a6dccb1 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -68,7 +68,7 @@ def blocks(file: TextIO, size: int = 65536): yield partial_file -def write_out_parquet(df: pl.DataFrame, outpath: Path, filename: str): +def write_out_parquet(df: pl.LazyFrame, outpath: Path, filename: str): """ Writes out a dataframe to parquet format. @@ -77,7 +77,7 @@ def write_out_parquet(df: pl.DataFrame, outpath: Path, filename: str): outpath (Path): Path to the output folder filename (str): Name of the file """ - df.write_parquet(outpath / f"{filename}.parquet") + df.collect().write_parquet(outpath / f"{filename}.parquet") def get_n_lines(file_path: Path) -> int: @@ -172,7 +172,7 @@ def check_file(file_str: str) -> Path: return file_path -def write_unmapped(unmapped_df: pl.DataFrame, outfolder: Path, filename: str): +def write_unmapped(unmapped_df: pl.LazyFrame, outfolder: Path, filename: str): """ Writes a list of top unmapped sequences @@ -183,7 +183,7 @@ def write_unmapped(unmapped_df: pl.DataFrame, outfolder: Path, filename: str): filename (string): Name of the output file """ - unmapped_df.write_csv(file=outfolder / f"{filename}.csv") + unmapped_df.collect().write_csv(file=outfolder / f"{filename}.csv") def load_report_template() -> dict: @@ -275,7 +275,7 @@ def create_report( def create_mtx_df( - main_df: pl.DataFrame, tags_df: pl.DataFrame, subset_df: pl.DataFrame + main_df: pl.LazyFrame, tags_df: pl.DataFrame, subset_df: pl.LazyFrame ) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: """Create the MTX dataframe, indexed barcodes and indexed features from the different dataframes @@ -290,8 +290,8 @@ def create_mtx_df( tags_indexed = ( pl.concat( [ - tags_df, - pl.DataFrame( + tags_df.lazy(), + pl.LazyFrame( {FEATURE_NAME_COLUMN: UNMAPPED_NAME, SEQUENCE_COLUMN: "UNKNOWN"} ), ] @@ -307,13 +307,13 @@ def create_mtx_df( .join(barcodes_indexed, left_on=BARCODE_COLUMN, right_on=SUBSET_COLUMN) .select([FEATURE_ID_COLUMN, BARCODE_ID_COLUMN, COUNT_COLUMN]) ) - return mtx_df, tags_indexed, barcodes_indexed + return mtx_df.collect(), tags_indexed.collect(), barcodes_indexed.collect() def write_data_to_mtx( - main_df: pl.DataFrame, + main_df: pl.LazyFrame, tags_df: pl.DataFrame, - subset_df: pl.DataFrame, + subset_df: pl.LazyFrame, data_type: str, outpath: str, ) -> None: diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py index 2b65c86..1789dfe 100644 --- a/cite_seq_count/mapping.py +++ b/cite_seq_count/mapping.py @@ -33,53 +33,54 @@ def find_best_match_fast(tag_seq, tags_df, maximum_distance): return UNMAPPED_NAME -def map_reads_hybrid( - r2_df: pl.DataFrame, parsed_tags: pl.DataFrame, maximum_distance: int -) -> tuple[pl.DataFrame, pl.DataFrame]: - """Map sequence data to a tags reference. - Using a hybdrid approach where we first join all the data for the exact matches - then using a hamming distance calculation to find the closest match +# def map_reads_hybrid( +# r2_df: pl.DataFrame, parsed_tags: pl.DataFrame, maximum_distance: int +# ) -> tuple[pl.DataFrame, pl.DataFrame]: +# """Map sequence data to a tags reference. +# Using a hybdrid approach where we first join all the data for the exact matches +# then using a hamming distance calculation to find the closest match - Args: - r2_df (pl.DataFrame): All r2 sequences to map - parsed_tags (pl.DataFrame): tags to map to - maximum_distance (int): max distance allowed for mismatches +# Args: +# r2_df (pl.DataFrame): All r2 sequences to map +# parsed_tags (pl.DataFrame): tags to map to +# maximum_distance (int): max distance allowed for mismatches - Returns: - pl.DataFrame: Mapped data - """ - print("Mapping reads") - mapped_r2_df = r2_df.join( - parsed_tags, left_on=R2_COLUMN, right_on=SEQUENCE_COLUMN, how="left" - ).with_columns( - pl.when(pl.col(FEATURE_NAME_COLUMN).is_null()) - .then( - pl.col(R2_COLUMN) - .map_elements( - lambda x: find_best_match_fast( - x, tags_df=parsed_tags, maximum_distance=maximum_distance - ) - ) - .alias(FEATURE_NAME_COLUMN) - ) - .otherwise(pl.col(FEATURE_NAME_COLUMN)) - ) - unmapped_r2_df = mapped_r2_df.filter(pl.col(FEATURE_NAME_COLUMN) == UNMAPPED_NAME) - print("Mapping done") - return mapped_r2_df, unmapped_r2_df +# Returns: +# pl.DataFrame: Mapped data +# """ +# print("Mapping reads") +# mapped_r2_df = r2_df.join( +# parsed_tags, left_on=R2_COLUMN, right_on=SEQUENCE_COLUMN, how="left" +# ).with_columns( +# pl.when(pl.col(FEATURE_NAME_COLUMN).is_null()) +# .then( +# pl.col(R2_COLUMN) +# .map_elements( +# lambda x: find_best_match_fast( +# x, tags_df=parsed_tags, maximum_distance=maximum_distance +# ) +# ) +# .alias(FEATURE_NAME_COLUMN) +# ) +# .otherwise(pl.col(FEATURE_NAME_COLUMN)) +# ) +# unmapped_r2_df = mapped_r2_df.filter(pl.col(FEATURE_NAME_COLUMN) == UNMAPPED_NAME) +# print("Mapping done") +# return mapped_r2_df, unmapped_r2_df def map_reads_polars( - r2_df: pl.DataFrame, parsed_tags: pl.DataFrame, maximum_distance: int -) -> tuple[pl.DataFrame, pl.DataFrame]: + r2_df: pl.LazyFrame, parsed_tags: pl.DataFrame, maximum_distance: int +) -> tuple[pl.LazyFrame, pl.LazyFrame]: joined = r2_df.join( - parsed_tags, left_on=R2_COLUMN, right_on=SEQUENCE_COLUMN, how="left" + parsed_tags.lazy(), left_on=R2_COLUMN, right_on=SEQUENCE_COLUMN, how="left" ) - simple_join = joined.filter(~pl.col(FEATURE_NAME_COLUMN).is_null()) levenshtein_mapped = ( joined.filter(pl.col(FEATURE_NAME_COLUMN).is_null()) - .join(parsed_tags, left_on=R2_COLUMN, right_on=SEQUENCE_COLUMN, how="cross") + .join( + parsed_tags.lazy(), left_on=R2_COLUMN, right_on=SEQUENCE_COLUMN, how="cross" + ) .drop(FEATURE_NAME_COLUMN) .with_columns( pld.col(SEQUENCE_COLUMN) @@ -98,13 +99,16 @@ def map_reads_polars( mapped = pl.concat( [ simple_join, - levenshtein_mapped.filter( - ~pl.col(R2_COLUMN).is_in(multi_mapped[R2_COLUMN]) + levenshtein_mapped.join(multi_mapped, on=R2_COLUMN, how="inner").drop( + COUNT_COLUMN ), ] ) - unmapped = joined.filter(~pl.col(R2_COLUMN).is_in(mapped[R2_COLUMN])).with_columns( - pl.col(FEATURE_NAME_COLUMN).fill_null(UNMAPPED_NAME) + unmapped = ( + joined.join(mapped.drop(FEATURE_NAME_COLUMN), on=R2_COLUMN, how="outer") + .with_columns(pl.col(FEATURE_NAME_COLUMN).fill_null(UNMAPPED_NAME)) + .filter(pl.col("r2_right").is_null()) + .drop("r2_right") ) return mapped, unmapped diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 0950744..8f72099 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -42,7 +42,7 @@ def check_equi_length(df: pl.DataFrame, column_name: str): raise ValueError(f"Barcodes in {column_name} column have different lengths.") -def parse_barcode_file(filename: str, required_header: list) -> pl.DataFrame: +def parse_barcode_file(filename: str, required_header: list) -> pl.LazyFrame: """Reads reference barcodes from a CSV file. The function accepts plain barcodes or even 10X style barcodes with the @@ -57,7 +57,7 @@ def parse_barcode_file(filename: str, required_header: list) -> pl.DataFrame: """ file_path = check_file(filename) - barcodes_df = pl.read_csv(file_path.absolute()) + barcodes_df = pl.scan_csv(file_path.absolute(), has_header=True) barcode_pattern = "^[ATGC]{1,}$" header = barcodes_df.columns @@ -65,8 +65,9 @@ def parse_barcode_file(filename: str, required_header: list) -> pl.DataFrame: if len(set_dif) != 0: set_diff_string = ",".join(list(set_dif)) raise SystemExit(f"The header is missing {set_diff_string}. Exiting") + # TODO: Enable to use translation inputs if OPTIONAL_CELLS_REF_HEADER in header: - with_translation = True + with_translation = False else: with_translation = False # Prepare and validate barcodes_df @@ -90,7 +91,7 @@ def parse_barcode_file(filename: str, required_header: list) -> pl.DataFrame: filename=filename, ) - check_equi_length(df=barcodes_df, column_name=REFERENCE_COLUMN) + check_equi_length(df=barcodes_df.collect(), column_name=REFERENCE_COLUMN) if with_translation: check_sequence_pattern( @@ -101,9 +102,9 @@ def parse_barcode_file(filename: str, required_header: list) -> pl.DataFrame: expected_pattern="ATGC", filename=filename, ) - check_equi_length(df=barcodes_df, column_name=TRANSLATION_COLUMN) + check_equi_length(df=barcodes_df.collect(), column_name=TRANSLATION_COLUMN) - return barcodes_df + return barcodes_df.drop(TRANSLATION_COLUMN) def parse_tags_csv(file_name: str) -> pl.DataFrame: @@ -146,7 +147,7 @@ def parse_tags_csv(file_name: str) -> pl.DataFrame: f"Column {column} is missing a value. Please fix the CSV file." ) check_sequence_pattern( - df=data_pl, + df=data_pl.lazy(), pattern=atgc_test, column_name=SEQUENCE_COLUMN, file_type="tags", @@ -157,7 +158,7 @@ def parse_tags_csv(file_name: str) -> pl.DataFrame: def check_sequence_pattern( - df: pl.DataFrame, + df: pl.LazyFrame, pattern: str, column_name: str, file_type: str, @@ -176,7 +177,7 @@ def check_sequence_pattern( Raises: SystemExit: Exists if some patterns don't match """ - regex_test = df.with_columns( + regex_test = df.collect().with_columns( pl.col(column_name).str.contains(pattern).alias("regex") ) if not regex_test.select(pl.col("regex").all()).get_column("regex").item(): @@ -364,7 +365,7 @@ def pre_run_checks( def split_data_input( mapping_input_path: Path, n_reads: int -) -> tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: +) -> tuple[pl.LazyFrame, pl.LazyFrame, pl.LazyFrame]: """Read in all the input data and split it into three dataframes. Reduce the size of the data by grouping on barcodes, umis and sequences. @@ -380,7 +381,7 @@ def split_data_input( tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: Three dfs described above """ main_df = ( - pl.read_csv( + pl.scan_csv( mapping_input_path, has_header=False, new_columns=[BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN], @@ -401,12 +402,12 @@ def split_data_input( def get_barcode_subset( - barcode_reference: pl.DataFrame | None, + barcode_reference: pl.LazyFrame | None, n_barcodes: int, chemistry, - barcode_subset: pl.DataFrame | None, - barcodes_df: pl.DataFrame, -) -> tuple[pl.DataFrame, bool]: + barcode_subset: pl.LazyFrame | None, + barcodes_df: pl.LazyFrame, +) -> tuple[pl.LazyFrame, bool]: """Generate the barcode list used for barcode correction and subsetting Args: @@ -425,8 +426,11 @@ def get_barcode_subset( # Subset: False, Reference: True if barcode_reference is not None: barcode_subset = ( - barcodes_df.filter( - pl.col(BARCODE_COLUMN).is_in(barcode_reference[REFERENCE_COLUMN]) + barcodes_df.join( + barcode_reference, + left_on=BARCODE_COLUMN, + right_on=REFERENCE_COLUMN, + how="inner", ) .sort(COUNT_COLUMN, descending=True) .head(round(n_barcodes * 1.2)) @@ -436,18 +440,19 @@ def get_barcode_subset( else: # Subset: False, Reference: False barcode_subset = find_knee_estimated_barcodes(barcodes_df=barcodes_df) - - if n_barcodes > barcode_subset.shape[0]: + barcodes_found = barcode_subset.collect().shape[0] + if n_barcodes > barcodes_found: print( f"Number of expected cells, {n_barcodes}, is higher " - f"than number of cells found {barcode_subset.shape[0]}.\nNot performing " + f"than number of cells found {barcodes_found}.\nNot performing " f"cell barcode correction" ) enable_barcode_correction = False + assert barcode_subset.schema == {SUBSET_COLUMN: pl.String} return barcode_subset, enable_barcode_correction -def find_knee_estimated_barcodes(barcodes_df: pl.DataFrame) -> pl.DataFrame: +def find_knee_estimated_barcodes(barcodes_df: pl.LazyFrame) -> pl.LazyFrame: """Find the subset of barcodes by the knee method Args: @@ -456,14 +461,18 @@ def find_knee_estimated_barcodes(barcodes_df: pl.DataFrame) -> pl.DataFrame: Returns: pl.DataFrame: Final list of barcodes """ - raw_barcodes_dict = barcodes_df.filter( - ~pl.col(BARCODE_COLUMN).str.contains("N") - ).sort("count", descending=True) + raw_barcodes_dict = ( + barcodes_df.filter(~pl.col(BARCODE_COLUMN).str.contains("N")) + .sort("count", descending=True) + .collect() + ) barcode_counter = Counter() barcode_counts = dict(raw_barcodes_dict.iter_rows()) # type: ignore barcode_counter.update(barcode_counts) true_barcodes = whitelist_method.getKneeEstimateDistance( cell_barcode_counts=barcode_counter ) - barcode_subset = pl.DataFrame(true_barcodes, schema={SUBSET_COLUMN: pl.String}) + barcode_subset = pl.DataFrame( + true_barcodes, schema={SUBSET_COLUMN: pl.String} + ).lazy() return barcode_subset diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index fb1a840..686084d 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -18,7 +18,7 @@ def correct_barcodes_pl( barcodes_df: pl.DataFrame, barcode_subset_df: pl.DataFrame, hamming_distance: int, -) -> tuple[pl.DataFrame, int, dict]: +) -> tuple[pl.LazyFrame, int, dict]: """Corrects barcodes using a subset based on join_asof from polars. Uses both forward and backward strategy to dinf the closest barcode @@ -37,33 +37,65 @@ def correct_barcodes_pl( BARCODE_COLUMN: pl.String, "count": pl.UInt32, SUBSET_COLUMN: pl.String, - "hamming_distance": pl.UInt32, } ) - - methods = ["forward", "backward"] - for method in methods: - current_barcodes = ( - barcodes_df.filter( - (~pl.col(BARCODE_COLUMN).is_in(corrected_barcodes_pl[BARCODE_COLUMN])) - & (~pl.col(BARCODE_COLUMN).is_in(barcode_subset_df[SUBSET_COLUMN])) - ) - .sort(BARCODE_COLUMN) - .join_asof( - barcode_subset_df.sort(SUBSET_COLUMN), - left_on=BARCODE_COLUMN, - right_on=SUBSET_COLUMN, - strategy=method, # type: ignore + current_barcodes = pl.DataFrame( + schema={ + BARCODE_COLUMN: pl.String, + "count": pl.UInt32, + SUBSET_COLUMN: pl.String, + } + ) + current_barcodes_to_correct = barcodes_df.shape[0] + last_iteration_barcodes_to_correct = 0 + unknown_barcodes = barcodes_df.filter( + ~pl.col(BARCODE_COLUMN).is_in(barcode_subset_df[SUBSET_COLUMN]) + ) + n_iterations = 0 + while ( + current_barcodes_to_correct > 0 + and current_barcodes_to_correct != last_iteration_barcodes_to_correct + ): + methods = ["backward", "forward"] + for method in methods: + current_barcodes = ( + ( + unknown_barcodes.filter( + ( + ~pl.col(BARCODE_COLUMN).is_in( + corrected_barcodes_pl[BARCODE_COLUMN] + ) + ) + ) + .sort(BARCODE_COLUMN) + .join_asof( + barcode_subset_df.sort(SUBSET_COLUMN), + left_on=BARCODE_COLUMN, + right_on=SUBSET_COLUMN, + strategy=method, # type: ignore + ) + ) + .filter(~pl.col(SUBSET_COLUMN).is_null()) + .with_columns( + pld.col(BARCODE_COLUMN) + .dist_str.hamming(pl.col(SUBSET_COLUMN)) + .alias("hamming_distance") + ) + .filter(pl.col("hamming_distance") <= hamming_distance) + .drop("hamming_distance") ) - .filter(~pl.col(SUBSET_COLUMN).is_null()) - .with_columns( - pld.col(BARCODE_COLUMN) - .dist_str.hamming(pl.col(SUBSET_COLUMN)) - .alias("hamming_distance") + corrected_barcodes_pl = pl.concat( + [ + corrected_barcodes_pl, + current_barcodes, + ] ) - .filter(pl.col("hamming_distance") <= hamming_distance) + current_barcodes_to_correct = current_barcodes.shape[0] + barcode_subset_df = barcode_subset_df.filter( + ~pl.col(SUBSET_COLUMN).is_in(corrected_barcodes_pl[SUBSET_COLUMN]) ) - corrected_barcodes_pl = pl.concat([corrected_barcodes_pl, current_barcodes]) + n_iterations += 1 + print(f"Corrected barcodes in {n_iterations} iterations") mapped_barcodes = dict( corrected_barcodes_pl.select(BARCODE_COLUMN, SUBSET_COLUMN).iter_rows() # type: ignore ) @@ -71,16 +103,16 @@ def correct_barcodes_pl( barcodes_df.with_columns( pl.col(BARCODE_COLUMN).map_dict(mapped_barcodes, default=pl.first()) ) - .group_by(BARCODE_COLUMN) - .agg(pl.sum("count")) + # .filter(pl.col(BARCODE_COLUMN).is_in(barcode_subset_df[SUBSET_COLUMN])) + .group_by(BARCODE_COLUMN).agg(pl.sum("count")) ) print("Barcodes corrected") n_corrected_barcodes = corrected_barcodes_pl.shape[0] - return final_corrected, n_corrected_barcodes, mapped_barcodes + return final_corrected.lazy(), n_corrected_barcodes, mapped_barcodes -def summarise_unmapped_df(main_df: pl.DataFrame, unmapped_r2_df: pl.DataFrame): +def summarise_unmapped_df(main_df: pl.LazyFrame, unmapped_r2_df: pl.LazyFrame): """Merge main df and unmapped df to get a summary of the unmapped reads Args: @@ -107,7 +139,7 @@ def summarise_unmapped_df(main_df: pl.DataFrame, unmapped_r2_df: pl.DataFrame): return unmapped_df -def generate_umi_counts(read_counts: pl.DataFrame) -> pl.DataFrame: +def generate_umi_counts(read_counts: pl.LazyFrame) -> pl.LazyFrame: """Generate umi counts from read counts Args: @@ -122,7 +154,7 @@ def generate_umi_counts(read_counts: pl.DataFrame) -> pl.DataFrame: return umi_counts -def update_main_df(main_df: pl.DataFrame, mapped_barcodes: dict): +def update_main_df(main_df: pl.LazyFrame, mapped_barcodes: dict) -> pl.LazyFrame: """Update the main data df with the corrected barcodes Args: @@ -143,14 +175,18 @@ def update_main_df(main_df: pl.DataFrame, mapped_barcodes: dict): def correct_umis_df( - main_df, mapped_r2_df, umi_distance=1, cluster_method="directional", max_umis=20000 -): + main_df: pl.LazyFrame, + mapped_r2_df: pl.LazyFrame, + umi_distance, + cluster_method="directional", + max_umis=20000, +) -> tuple[pl.LazyFrame, int, list]: merged = mapped_r2_df.join(main_df, on=R2_COLUMN) temp = ( merged.with_columns(umi=pl.col(UMI_COLUMN).cast(pl.Binary)) .group_by([R2_COLUMN, FEATURE_NAME_COLUMN, BARCODE_COLUMN]) .agg(pl.struct(pl.col(UMI_COLUMN), pl.col(COUNT_COLUMN))) - ) + ).collect() clustered_cells = ( temp.filter(pl.col(UMI_COLUMN).list.len() > max_umis) .select(BARCODE_COLUMN) @@ -176,7 +212,7 @@ def correct_umis_df( if index != 0: mapping_list.append([r2, feature_name, barcode, umi, umi_set[0]]) mapping_df = ( - pl.DataFrame( + pl.LazyFrame( mapping_list, schema={ R2_COLUMN: pl.String, @@ -209,7 +245,7 @@ def correct_umis_df( .agg(pl.sum(COUNT_COLUMN)) .drop(R2_COLUMN) ) - n_corrected_umis = mapping_df.shape[0] + n_corrected_umis = mapping_df.collect().shape[0] return read_counts, n_corrected_umis, clustered_cells diff --git a/tests/test_io.py b/tests/test_io.py index 6843c08..2327a86 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,7 +1,6 @@ import pytest import os import gzip -import scipy from pathlib import Path from polars.testing import assert_frame_equal from cite_seq_count.constants import ( @@ -16,13 +15,10 @@ ) from cite_seq_count.io import ( get_n_lines, - write_to_files, get_read_paths, write_data_to_mtx, create_mtx_df, ) -from collections import namedtuple -import numpy as np import polars as pl # copied from https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file @@ -117,7 +113,7 @@ def test_get_n_lines_not_multiple_of_4(corrupt_R1): def test_write_data_to_mtx(tmp_path): # Create test data - main_df = pl.DataFrame( + main_df = pl.LazyFrame( { FEATURE_NAME_COLUMN: ["test1", "test2", "test3"], BARCODE_COLUMN: [ @@ -134,7 +130,7 @@ def test_write_data_to_mtx(tmp_path): SEQUENCE_COLUMN: ["CGTA", "CGTA", "CGTA"], } ) - subset_df = pl.DataFrame({SUBSET_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]}) + subset_df = pl.LazyFrame({SUBSET_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]}) data_type = "umi" outpath = str(tmp_path) print(outpath) @@ -158,7 +154,7 @@ def test_write_data_to_mtx(tmp_path): def test_create_mtx_df(): # Create test data - main_df = pl.DataFrame( + main_df = pl.LazyFrame( { FEATURE_NAME_COLUMN: ["test1", "test2", "test3"], BARCODE_COLUMN: [ @@ -175,7 +171,7 @@ def test_create_mtx_df(): SEQUENCE_COLUMN: ["CGTA", "CGTA", "CGTA"], } ) - subset_df = pl.DataFrame({SUBSET_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]}) + subset_df = pl.LazyFrame({SUBSET_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"]}) # Call the function mtx_df, tags_indexed, barcodes_indexed = create_mtx_df(main_df, tags_df, subset_df) diff --git a/tests/test_mapping.py b/tests/test_mapping.py index c369265..a3c4681 100644 --- a/tests/test_mapping.py +++ b/tests/test_mapping.py @@ -64,7 +64,7 @@ def parsed_tags_df(): @pytest.fixture def r2_df(): # Create a sample DataFrame for r2_df - return pl.DataFrame( + return pl.LazyFrame( { R2_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT", "AGCTAGCTAGCTAGCT"], } @@ -84,13 +84,13 @@ def parsed_tags(): def test_map_reads_polars_with_dist1(r2_df, parsed_tags): maximum_distance = 1 - expected_mapped = pl.DataFrame( + expected_mapped = pl.LazyFrame( { R2_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], FEATURE_NAME_COLUMN: ["feature1", "feature2"], } ) - expected_unmapped = pl.DataFrame( + expected_unmapped = pl.LazyFrame( {R2_COLUMN: ["AGCTAGCTAGCTAGCT"], FEATURE_NAME_COLUMN: ["unmapped"]} ) @@ -102,13 +102,13 @@ def test_map_reads_polars_with_dist1(r2_df, parsed_tags): def test_map_reads_polars_with_dist2(r2_df, parsed_tags): maximum_distance = 2 - expected_mapped = pl.DataFrame( + expected_mapped = pl.LazyFrame( { R2_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], FEATURE_NAME_COLUMN: ["feature1", "feature2"], } ) - expected_unmapped = pl.DataFrame( + expected_unmapped = pl.LazyFrame( {R2_COLUMN: ["AGCTAGCTAGCTAGCT"], FEATURE_NAME_COLUMN: ["unmapped"]} ) diff --git a/tests/test_preprocessing.py b/tests/test_preprocessing.py index 85498f0..d423f2b 100644 --- a/tests/test_preprocessing.py +++ b/tests/test_preprocessing.py @@ -53,23 +53,23 @@ def failing_tags(): @pytest.fixture -def correct_reference_df(): - return pl.DataFrame( +def correct_reference_lf(): + return pl.LazyFrame( { REFERENCE_COLUMN: ["ATGCCC", "ATGCTT", "CCGCCC"], - TRANSLATION_COLUMN: ["GGATCG", "GGATCA", "GATCAA"], + # TRANSLATION_COLUMN: ["GGATCG", "GGATCA", "GATCAA"], } ) @pytest.fixture def correct_subset_df(): - return pl.DataFrame({REFERENCE_COLUMN: ["ATGCCC", "ATGCTT"]}) + return pl.LazyFrame({REFERENCE_COLUMN: ["ATGCCC", "ATGCTT"]}) @pytest.fixture def barcodes_df(): - return pl.DataFrame( + return pl.LazyFrame( { BARCODE_COLUMN: ["ATGCCC", "ATGCTT", "CCGCCC", "ATATCC", "ATATGG"], COUNT_COLUMN: [200, 200, 200, 20, 10], @@ -120,12 +120,12 @@ def chemistry_def(): ) -def test_passing_parse_reference_list_csv(passing_references, correct_reference_df): +def test_passing_parse_reference_list_csv(passing_references, correct_reference_lf): passing_files = glob.glob(passing_references) for file_path in passing_files: assert_frame_equal( - left=parse_barcode_file(file_path, 16, [REFERENCE_COLUMN]), - right=correct_reference_df, + left=parse_barcode_file(file_path, [REFERENCE_COLUMN]), + right=correct_reference_lf, ) @@ -133,20 +133,20 @@ def test_failing_parse_reference_list_csv(failing_references): with pytest.raises(SystemExit): failing_files = glob.glob(failing_references) for file_path in failing_files: - parse_barcode_file(file_path, 16, [REFERENCE_COLUMN]) + parse_barcode_file(file_path, [REFERENCE_COLUMN]) def test_parse_subset_list_csv(passing_subsets, failing_subsets, correct_subset_df): passing_files = glob.glob(passing_subsets) for file_path in passing_files: assert_frame_equal( - left=parse_barcode_file(file_path, 16, [REFERENCE_COLUMN]), + left=parse_barcode_file(file_path, [REFERENCE_COLUMN]), right=correct_subset_df, ) with pytest.raises(SystemExit): failing_files = glob.glob(failing_subsets) for file_path in failing_files: - parse_barcode_file(file_path, 16, [REFERENCE_COLUMN]) + parse_barcode_file(file_path, [REFERENCE_COLUMN]) def test_check_distance_too_big_between_tags(correct_tags_df): @@ -156,27 +156,27 @@ def test_check_distance_too_big_between_tags(correct_tags_df): # Test if there is no reference and no whitelist def test_find_knee_estimated_barcodes(barcodes_df): - expected_subset = pl.DataFrame({SUBSET_COLUMN: ["ATGCCC", "ATGCTT", "CCGCCC"]}) + expected_subset = pl.LazyFrame({SUBSET_COLUMN: ["ATGCCC", "ATGCTT", "CCGCCC"]}) subset = find_knee_estimated_barcodes(barcodes_df=barcodes_df) assert_frame_equal(subset, expected_subset) @pytest.fixture def barcode_subset_df(): - return pl.DataFrame({SUBSET_COLUMN: ["ATGCCC", "ATGCTT"]}) + return pl.LazyFrame({SUBSET_COLUMN: ["ATGCCC", "ATGCTT"]}) -def test_get_barcode_subset_with_reference(correct_reference_df, barcodes_df): - expected_subset = pl.DataFrame( +def test_get_barcode_subset_with_reference(correct_reference_lf, barcodes_df): + expected_subset = pl.LazyFrame( { - SUBSET_COLUMN: ["ATGCCC", "ATGCTT"], + SUBSET_COLUMN: ["ATGCCC", "ATGCTT", "CCGCCC"], } ) expected_enable_correction = True subset, enable_correction = get_barcode_subset( - barcode_reference=correct_reference_df, - n_barcodes=2, + barcode_reference=correct_reference_lf, + n_barcodes=3, chemistry=None, barcode_subset=None, barcodes_df=barcodes_df, @@ -187,7 +187,7 @@ def test_get_barcode_subset_with_reference(correct_reference_df, barcodes_df): def test_get_barcode_subset_without_reference(barcodes_df): - expected_subset = pl.DataFrame({SUBSET_COLUMN: ["ATGCCC", "ATGCTT", "CCGCCC"]}) + expected_subset = pl.LazyFrame({SUBSET_COLUMN: ["ATGCCC", "ATGCTT", "CCGCCC"]}) expected_enable_correction = True subset, enable_correction = get_barcode_subset( @@ -202,8 +202,8 @@ def test_get_barcode_subset_without_reference(barcodes_df): assert enable_correction == expected_enable_correction -def test_get_barcode_subset_with_large_n_barcodes(correct_reference_df, barcodes_df): - expected_subset = pl.DataFrame( +def test_get_barcode_subset_with_large_n_barcodes(correct_reference_lf, barcodes_df): + expected_subset = pl.LazyFrame( { SUBSET_COLUMN: ["ATGCCC", "ATGCTT", "CCGCCC"], } @@ -211,7 +211,7 @@ def test_get_barcode_subset_with_large_n_barcodes(correct_reference_df, barcodes expected_enable_correction = False subset, enable_correction = get_barcode_subset( - barcode_reference=correct_reference_df, + barcode_reference=correct_reference_lf, n_barcodes=4, chemistry=None, barcode_subset=None, @@ -223,13 +223,13 @@ def test_get_barcode_subset_with_large_n_barcodes(correct_reference_df, barcodes def test_get_barcode_subset_with_existing_subset( - correct_reference_df, barcode_subset_df, barcodes_df + correct_reference_lf, barcode_subset_df, barcodes_df ): - expected_subset = pl.DataFrame({SUBSET_COLUMN: ["ATGCCC", "ATGCTT"]}) + expected_subset = pl.LazyFrame({SUBSET_COLUMN: ["ATGCCC", "ATGCTT"]}) expected_enable_correction = True subset, enable_correction = get_barcode_subset( - barcode_reference=correct_reference_df, + barcode_reference=correct_reference_lf, n_barcodes=2, chemistry=None, barcode_subset=barcode_subset_df, @@ -240,8 +240,8 @@ def test_get_barcode_subset_with_existing_subset( assert enable_correction == expected_enable_correction -def test_get_barcode_subset_with_no_subset(correct_reference_df, barcodes_df): - expected_subset = pl.DataFrame( +def test_get_barcode_subset_with_no_subset(correct_reference_lf, barcodes_df): + expected_subset = pl.LazyFrame( { SUBSET_COLUMN: ["ATGCCC", "ATGCTT", "CCGCCC"], } @@ -249,7 +249,7 @@ def test_get_barcode_subset_with_no_subset(correct_reference_df, barcodes_df): expected_enable_correction = True subset, enable_correction = get_barcode_subset( - barcode_reference=correct_reference_df, + barcode_reference=correct_reference_lf, n_barcodes=3, chemistry=None, barcode_subset=None, diff --git a/tests/test_processing.py b/tests/test_processing.py index 067e60f..1accd07 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -1,107 +1,79 @@ import pytest -from collections import namedtuple from cite_seq_count import processing import polars as pl from polars.testing import assert_frame_equal @pytest.fixture -def data(): - tag = namedtuple("tag", ["name", "sequence", "id"]) - - pytest.barcodes_df = pl.DataFrame( +def barcodes_df(): + return pl.DataFrame( { "barcode": [ "TACATATTCTTTACTG", "AACATATTCTTTACTG", "CACATATTCTTTACTG", "GACATATTCTTTACTG", - "TACATATTCTTTACTA", - "TACATATTCTTTACTC", - "TACATATTCTTTACTT", + "GCTAGTCGTAGCTAGA", + "GCTAGTCGTAGCTAGT", + "GCTAGTCGTAGCTAGG", + "GCTAGTCGTAGCTAGC", "TAGAGGGAGGTCAAGC", "TAGAGGGACGTCAAGC", "TAGAGGGATGTCAAGC", "TAGAGGGAAGTCAAGC", ], - "count": [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + "count": [5, 1, 1, 1, 5, 1, 1, 1, 5, 1, 1, 1], } ) - pytest.barcode_subset_df = pl.DataFrame( - {"whitelist": ["TACATATTCTTTACTG", "TAGAGGGAAGTCAAGC"]} - ) - pytest.corrected_barcodes_df = pl.DataFrame( +def test_correct_barcodes_pl(barcodes_df): + barcode_subset_df = pl.DataFrame( { - "barcode": [ - "TAGAGGGAAGTCAAGC", + "subset": [ "TACATATTCTTTACTG", + "GCTAGTCGTAGCTAGA", + "TAGAGGGAGGTCAAGC", ], - "count": [4, 7], } ) + hamming_distance = 1 - pytest.expected_cells = 2 - pytest.collapsing_threshold = 1 - pytest.max_umis = 20000 + ( + corrected_barcodes, + n_corrected_barcodes, + mapped_barcodes, + ) = processing.correct_barcodes_pl(barcodes_df, barcode_subset_df, hamming_distance) - -@pytest.mark.dependency() -def test_correct_barcodes(data): - corrected_barcodes, _, _ = processing.correct_barcodes_pl( - barcodes_df=pytest.barcodes_df, - barcode_subset_df=pytest.barcode_subset_df, - hamming_distance=1, + # Assert the corrected barcodes + expected_corrected_barcodes = pl.LazyFrame( + { + "barcode": [ + "TACATATTCTTTACTG", + "GCTAGTCGTAGCTAGA", + "TAGAGGGAGGTCAAGC", + ], + "count": [8, 8, 8], + } ) assert_frame_equal( - pytest.corrected_barcodes_df, corrected_barcodes, check_row_order=False + corrected_barcodes, expected_corrected_barcodes, check_row_order=False ) + # Assert the number of corrected barcodes + expected_n_corrected_barcodes = 9 + assert n_corrected_barcodes == expected_n_corrected_barcodes -@pytest.mark.dependency() -def test_correct_umis(data): - temp = processing.correct_umis_in_cells((pytest.results, 2, pytest.max_umis, 2)) - results = temp[0] - n_corrected = temp[1] - for cell_barcode in results.keys(): - for TAG in results[cell_barcode]: - assert len(results[cell_barcode][TAG]) == len( - pytest.corrected_results[cell_barcode][TAG] - ) - assert sum(results[cell_barcode][TAG].values()) == sum( - pytest.corrected_results[cell_barcode][TAG].values() - ) - assert n_corrected == 3 - - -@pytest.mark.dependency(depends=["test_correct_umis"]) -def test_generate_sparse_umi_matrices(data): - umi_results_matrix = processing.generate_sparse_matrices( - pytest.corrected_results, - pytest.tags, - ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], - umi_counts=True, - ) - assert umi_results_matrix.shape == (2, 2) - total_umis = 0 - for i in range(umi_results_matrix.shape[0]): - for j in range(umi_results_matrix.shape[1]): - total_umis += umi_results_matrix[i, j] - assert total_umis == 3 - - -@pytest.mark.dependency(depends=["test_correct_umis"]) -def test_generate_sparse_read_matrices(data): - read_results_matrix = processing.generate_sparse_matrices( - pytest.corrected_results, - pytest.tags, - ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], - umi_counts=False, - ) - assert read_results_matrix.shape == (3, 2) - total_umis = 0 - for i in range(read_results_matrix.shape[0]): - for j in range(read_results_matrix.shape[1]): - total_umis += read_results_matrix[i, j] - assert total_umis == 12 + # Assert the mapped barcodes + expected_mapped_barcodes = { + "AACATATTCTTTACTG":"TACATATTCTTTACTG", + "CACATATTCTTTACTG":"TACATATTCTTTACTG", + "GACATATTCTTTACTG":"TACATATTCTTTACTG", + "GCTAGTCGTAGCTAGT":"GCTAGTCGTAGCTAGA", + "GCTAGTCGTAGCTAGG":"GCTAGTCGTAGCTAGA", + "GCTAGTCGTAGCTAGC":"GCTAGTCGTAGCTAGA", + "TAGAGGGACGTCAAGC":"TAGAGGGAGGTCAAGC", + "TAGAGGGATGTCAAGC":"TAGAGGGAGGTCAAGC", + "TAGAGGGAAGTCAAGC":"TAGAGGGAGGTCAAGC", + } + assert mapped_barcodes == expected_mapped_barcodes From 67d6c8296d79de65922ca1a946686a7a87b9a04c Mon Sep 17 00:00:00 2001 From: hoohm Date: Sun, 14 Jan 2024 17:40:30 +0100 Subject: [PATCH 76/77] (feat): Read fastq files using polars. --- cite_seq_count/__main__.py | 60 +++++--- cite_seq_count/argsparser.py | 2 +- cite_seq_count/chemistry.py | 7 +- cite_seq_count/io.py | 128 ++++++++++++++++- cite_seq_count/mapping.py | 20 +-- cite_seq_count/preprocessing.py | 10 +- cite_seq_count/processing.py | 191 +++++++++++++++++++------ setup.py | 2 +- tests/test_data/fastq/correct_R1.csv | 201 +++++++++++++++++++++++++++ tests/test_data/fastq/correct_R2.csv | 201 +++++++++++++++++++++++++++ tests/test_io.py | 92 +++++++++++- tests/test_processing.py | 38 +++-- 12 files changed, 842 insertions(+), 110 deletions(-) create mode 100644 tests/test_data/fastq/correct_R1.csv create mode 100644 tests/test_data/fastq/correct_R2.csv diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index 3e813fa..ad732a8 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -61,31 +61,45 @@ def main(): read1_paths, read2_paths = io.get_read_paths(args.read1_path, args.read2_path) # Checks before chunking. - (n_reads, r2_min_length, maximum_distance) = preprocessing.pre_run_checks( - read1_paths=read1_paths, - chemistry_def=chemistry_def, - longest_tag_len=longest_tag_len, - arguments=args, - ) + # (n_reads, r2_min_length, maximum_distance) = preprocessing.pre_run_checks( + # read1_paths=read1_paths, + # chemistry_def=chemistry_def, + # longest_tag_len=longest_tag_len, + # arguments=args, + # ) + # ( + # temp_file, + # r1_too_short, + # r2_too_short, + # total_reads, + # ) = io.write_mapping_input( + # args=args, + # read1_paths=read1_paths, + # read2_paths=read2_paths, + # r2_min_length=r2_min_length, + # chemistry_def=chemistry_def, + # ) + print("Writing mapping input") ( - temp_file, r1_too_short, r2_too_short, - total_reads, - ) = io.write_mapping_input( - args=args, - read1_paths=read1_paths, - read2_paths=read2_paths, - r2_min_length=r2_min_length, + n_reads, + temp_file_path, + ) = io.write_mapping_input_from_fastqs( chemistry_def=chemistry_def, + fastq_paths=list(zip(read1_paths, read2_paths)), + r2_min_length=longest_tag_len, + top_n_reads=args.first_n, + temp_path=args.temp_path, ) + print("Processing") main_df, barcodes_df, r2_df = preprocessing.split_data_input( - mapping_input_path=temp_file, n_reads=n_reads + mapping_input_path=temp_file_path, ) mapped_r2_df, unmapped_r2_df = mapping.map_reads_polars( r2_df=r2_df, parsed_tags=parsed_tags, - maximum_distance=maximum_distance, + maximum_distance=args.max_error, ) unmapped_df = processing.summarise_unmapped_df( main_df=main_df, unmapped_r2_df=unmapped_r2_df @@ -117,9 +131,11 @@ def main(): ) else: n_bcs_corrected = 0 - print("UMI correction") final_read_counts, umis_corrected, clustered_cells = processing.correct_umis_df( - main_df=main_df, mapped_r2_df=mapped_r2_df, umi_distance=args.umi_threshold + main_df=main_df, + mapped_r2_df=mapped_r2_df, + barcode_subset=barcode_subset, + umi_distance=args.umi_threshold, ) # # Write reads to file @@ -147,9 +163,9 @@ def main(): # TODO: rewrite reporting # Create report and write it to disk # Remove temp file - + io.create_report( - total_reads=total_reads, + total_reads=n_reads, unmapped=unmapped_df.collect(), version=argsparser.get_package_version(), start_time=start_time, @@ -160,8 +176,10 @@ def main(): r2_too_short=r2_too_short, args=args, chemistry_def=chemistry_def, - maximum_distance=maximum_distance, + maximum_distance=args.max_error, ) - os.remove(temp_file) + os.remove(temp_file_path) + + if __name__ == "__main__": main() diff --git a/cite_seq_count/argsparser.py b/cite_seq_count/argsparser.py index 4c392b5..236ee21 100644 --- a/cite_seq_count/argsparser.py +++ b/cite_seq_count/argsparser.py @@ -250,7 +250,7 @@ def get_args() -> ArgumentParser: parser.add_argument( "--temp_path", required=False, - type=str, + type=Path, dest="temp_path", default=tempfile.gettempdir(), help=( diff --git a/cite_seq_count/chemistry.py b/cite_seq_count/chemistry.py index 5d43951..eeb1bef 100644 --- a/cite_seq_count/chemistry.py +++ b/cite_seq_count/chemistry.py @@ -27,7 +27,8 @@ class Chemistry: barcode_reference_path: str def __post_init__(self): - self.barcode_length = self.cell_barcode_end - self.umi_barcode_start + 1 + self.barcode_length = self.cell_barcode_end - self.cell_barcode_start + 1 + self.umi_length = self.umi_barcode_end - self.umi_barcode_start + 1 DEFINITIONS_DB = pooch.create( @@ -133,12 +134,11 @@ def create_chemistry_definition(args: Namespace) -> Chemistry: return chemistry_def -def setup_chemistry(args: Namespace) -> tuple[pl.DataFrame | None, Chemistry]: +def setup_chemistry(args: Namespace) -> tuple[pl.LazyFrame | None, Chemistry]: if args.chemistry_id: chemistry_def = get_chemistry_definition(args.chemistry_id) barcode_reference = parse_barcode_file( filename=chemistry_def.barcode_reference_path, - barcode_length=chemistry_def.barcode_length, required_header=["reference"], ) else: @@ -147,7 +147,6 @@ def setup_chemistry(args: Namespace) -> tuple[pl.DataFrame | None, Chemistry]: print("Loading barcode reference") barcode_reference = parse_barcode_file( filename=args.reference, - barcode_length=chemistry_def.barcode_length, required_header=["reference"], ) else: diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index a6dccb1..ded1259 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -1,4 +1,5 @@ """Handle io operations""" +from math import inf import os import csv import sys @@ -11,17 +12,13 @@ from argparse import Namespace from itertools import islice -from collections import Counter from typing import Tuple, TextIO from pathlib import Path from os import access, R_OK -import scipy import pkg_resources import yaml -import pandas as pd import polars as pl -from scipy import io from cite_seq_count import secondsToText from cite_seq_count.constants import ( FEATURE_NAME_COLUMN, @@ -33,7 +30,9 @@ FEATURES_MTX, BARCODE_MTX, MATRIX_MTX, + R2_COLUMN, TEMP_MTX, + UMI_COLUMN, UNMAPPED_NAME, SEQUENCE_COLUMN, SUBSET_COLUMN, @@ -186,6 +185,88 @@ def write_unmapped(unmapped_df: pl.LazyFrame, outfolder: Path, filename: str): unmapped_df.collect().write_csv(file=outfolder / f"{filename}.csv") +def read_R1_polars(R1_path: Path, chemistry_def) -> pl.LazyFrame: + return ( + pl.read_csv(R1_path, has_header=False, separator="\t", new_columns=["sequence"]) + .lazy() + .gather_every(4, offset=1) + .with_columns( + pl.col("sequence") + .str.slice(offset=0, length=chemistry_def.barcode_length) + .alias("barcode"), + pl.col("sequence") + .str.slice( + offset=chemistry_def.cell_barcode_end, length=chemistry_def.umi_length + ) + .alias("umi"), + ) + .drop("sequence") + ) + + +def read_R2_polars(R2_path: Path, r2_min_length: int, chemistry) -> pl.LazyFrame: + return ( + pl.read_csv(R2_path, has_header=False, separator="\t", new_columns=["r2"]) + .lazy() + .gather_every(4, offset=1) + .with_columns( + pl.col("r2").str.slice(offset=chemistry.r2_trim_start, length=r2_min_length) + ) + ) + + +def write_fastq_inputs_as_parquet( + read_paths: list, + temp_path: Path, + chemistry_def, + r2_min_length: int, + top_n_reads: int, +) -> tuple[int, int, int]: + concats = [] + for r1_read_path, r2_read_path in read_paths: + R1_read = read_R1_polars(R1_path=r1_read_path, chemistry_def=chemistry_def) + R2_read = read_R2_polars( + R2_path=r2_read_path, r2_min_length=r2_min_length, chemistry=chemistry_def + ) + if top_n_reads != inf: + data = ( + pl.concat([R1_read, R2_read], how="horizontal") + # Since this is lazy, the head is actually computed ahead of time + .head(4 * round(top_n_reads / len(read_paths))) + .group_by(["barcode", "umi", "r2"]) + .agg(pl.count()) + ) + else: + data = ( + pl.concat([R1_read, R2_read], how="horizontal") + .group_by(["barcode", "umi", "r2"]) + .agg(pl.count()) + ) + + concats.append(data) + + all = pl.concat(concats) + total_reads = all.select(pl.sum(COUNT_COLUMN)).collect().item() + r1_too_short = ( + all.filter( + (pl.col(BARCODE_COLUMN) + pl.col(UMI_COLUMN)).str.len_bytes() + < (chemistry_def.barcode_length + chemistry_def.umi_length) + ) + .select(pl.sum(COUNT_COLUMN)) + .collect() + .item() + ) + r2_too_short = ( + all.filter((pl.col(R2_COLUMN).str.len_bytes() < r2_min_length)) + .select(pl.sum(COUNT_COLUMN)) + .collect() + .item() + ) + + all.collect().write_parquet(file=temp_path) + return r1_too_short, r2_too_short, total_reads + + def load_report_template() -> dict: """Load json template for the report @@ -227,7 +308,7 @@ def create_report( args (arg_parse): Arguments provided by the user. """ - total_unmapped = unmapped[COUNT_COLUMN][0] + total_unmapped = unmapped.sum().get_column(COUNT_COLUMN)[0] total_too_short = r1_too_short + r2_too_short total_mapped = total_reads - total_unmapped - total_too_short @@ -297,9 +378,9 @@ def create_mtx_df( ] ) .sort(pl.col(FEATURE_NAME_COLUMN)) - .with_row_count(offset=1, name=FEATURE_ID_COLUMN) + .with_row_index(offset=1, name=FEATURE_ID_COLUMN) ) - barcodes_indexed = subset_df.sort(pl.col(SUBSET_COLUMN)).with_row_count( + barcodes_indexed = subset_df.sort(pl.col(SUBSET_COLUMN)).with_row_index( offset=1, name=BARCODE_ID_COLUMN ) mtx_df = ( @@ -442,3 +523,36 @@ def write_mapping_input( r2_too_short, total_reads, ) + + +def write_mapping_input_from_fastqs( + fastq_paths: list[tuple[Path, Path]], + r2_min_length: int, + chemistry_def, + top_n_reads: int, + temp_path: str, +) -> tuple[int, int, int, Path]: + """Read fastq inputs using polars read_csv, concatenate R1 and R2 files, summaries and write to parquet + + Args: + fastq_paths (list[list]): List of lists containing the paths to the R1 and R2 fastq files + chemistry (str): The chemistry definition + + Returns: + None + """ + temp_path = os.path.abspath(temp_path) + temp_file = tempfile.NamedTemporaryFile( + "w", dir=temp_path, suffix="_csc.parquet", delete=False + ) + temp_file_path = Path(temp_file.name) + + r1_too_short, r2_too_short, total_reads = write_fastq_inputs_as_parquet( + temp_path=temp_file_path, + read_paths=fastq_paths, + chemistry_def=chemistry_def, + r2_min_length=r2_min_length, + top_n_reads=top_n_reads, + ) + + return r1_too_short, r2_too_short, total_reads, temp_file_path diff --git a/cite_seq_count/mapping.py b/cite_seq_count/mapping.py index 1789dfe..1a0f374 100644 --- a/cite_seq_count/mapping.py +++ b/cite_seq_count/mapping.py @@ -91,19 +91,13 @@ def map_reads_polars( .drop([SEQUENCE_COLUMN, "levenshtein_dist"]) .rename({"feature_name_right": FEATURE_NAME_COLUMN}) ) - multi_mapped = ( - levenshtein_mapped.group_by(R2_COLUMN) - .agg(pl.count()) - .filter(pl.col(COUNT_COLUMN) > 1) - ) - mapped = pl.concat( - [ - simple_join, - levenshtein_mapped.join(multi_mapped, on=R2_COLUMN, how="inner").drop( - COUNT_COLUMN - ), - ] - ) + # TODO: Deal with multimapped reads + # multi_mapped = ( + # levenshtein_mapped.group_by(R2_COLUMN) + # .agg(pl.count()) + # .filter(pl.col(COUNT_COLUMN) > 1) + # ) + mapped = pl.concat([simple_join, levenshtein_mapped]) unmapped = ( joined.join(mapped.drop(FEATURE_NAME_COLUMN), on=R2_COLUMN, how="outer") .with_columns(pl.col(FEATURE_NAME_COLUMN).fill_null(UNMAPPED_NAME)) diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index 8f72099..b90708d 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -10,7 +10,7 @@ import Levenshtein import polars as pl import umi_tools.whitelist_methods as whitelist_method -from cite_seq_count.io import get_n_lines, check_file +from cite_seq_count.io import get_n_lines, check_file, write_fastq_inputs_as_parquet from cite_seq_count.constants import ( SEQUENCE_COLUMN, R2_COLUMN, @@ -364,7 +364,7 @@ def pre_run_checks( def split_data_input( - mapping_input_path: Path, n_reads: int + mapping_input_path: Path, ) -> tuple[pl.LazyFrame, pl.LazyFrame, pl.LazyFrame]: """Read in all the input data and split it into three dataframes. @@ -381,16 +381,12 @@ def split_data_input( tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: Three dfs described above """ main_df = ( - pl.scan_csv( + pl.scan_parquet( mapping_input_path, - has_header=False, - new_columns=[BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN], - n_rows=n_reads, ) .group_by([BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN]) .agg(pl.count()) ) - barcodes_df = ( main_df.select([BARCODE_COLUMN, COUNT_COLUMN]) .group_by(BARCODE_COLUMN) diff --git a/cite_seq_count/processing.py b/cite_seq_count/processing.py index 686084d..189fb6a 100644 --- a/cite_seq_count/processing.py +++ b/cite_seq_count/processing.py @@ -1,3 +1,4 @@ +from turtle import right import polars as pl import polars_distance as pld @@ -79,6 +80,7 @@ def correct_barcodes_pl( .with_columns( pld.col(BARCODE_COLUMN) .dist_str.hamming(pl.col(SUBSET_COLUMN)) + .cast(pl.UInt32) .alias("hamming_distance") ) .filter(pl.col("hamming_distance") <= hamming_distance) @@ -96,6 +98,7 @@ def correct_barcodes_pl( ) n_iterations += 1 print(f"Corrected barcodes in {n_iterations} iterations") + print(f"Number of uncorrected barcodes: {current_barcodes.shape[0]}") mapped_barcodes = dict( corrected_barcodes_pl.select(BARCODE_COLUMN, SUBSET_COLUMN).iter_rows() # type: ignore ) @@ -103,8 +106,8 @@ def correct_barcodes_pl( barcodes_df.with_columns( pl.col(BARCODE_COLUMN).map_dict(mapped_barcodes, default=pl.first()) ) - # .filter(pl.col(BARCODE_COLUMN).is_in(barcode_subset_df[SUBSET_COLUMN])) - .group_by(BARCODE_COLUMN).agg(pl.sum("count")) + .group_by(BARCODE_COLUMN) + .agg(pl.sum("count")) ) print("Barcodes corrected") n_corrected_barcodes = corrected_barcodes_pl.shape[0] @@ -177,27 +180,83 @@ def update_main_df(main_df: pl.LazyFrame, mapped_barcodes: dict) -> pl.LazyFrame def correct_umis_df( main_df: pl.LazyFrame, mapped_r2_df: pl.LazyFrame, - umi_distance, - cluster_method="directional", - max_umis=20000, + umi_distance: int, + barcode_subset: pl.LazyFrame, + cluster_method: str = "directional", + max_umis: int = 20000, ) -> tuple[pl.LazyFrame, int, list]: - merged = mapped_r2_df.join(main_df, on=R2_COLUMN) - temp = ( - merged.with_columns(umi=pl.col(UMI_COLUMN).cast(pl.Binary)) - .group_by([R2_COLUMN, FEATURE_NAME_COLUMN, BARCODE_COLUMN]) - .agg(pl.struct(pl.col(UMI_COLUMN), pl.col(COUNT_COLUMN))) - ).collect() - clustered_cells = ( - temp.filter(pl.col(UMI_COLUMN).list.len() > max_umis) - .select(BARCODE_COLUMN) - .unique() - .get_column(BARCODE_COLUMN) - .to_list() + """Take main df and mapped reads to correct umis. + + Args: + main_df (pl.LazyFrame): Main df mapping r2, barcode and umi + mapped_r2_df (pl.LazyFrame): mapped reads + umi_distance (int): Hamming distance for umi correction + cluster_method (str, optional): Cluster methods used by the umi clusterer. Defaults to "directional". + max_umis (int, optional): Threshold for clustered cells. Defaults to 20000. + + Returns: + tuple[pl.LazyFrame, int, list]: read counts, number of umis corrected, clustered cells lf + """ + merged_lf = mapped_r2_df.join(main_df, on=R2_COLUMN).join( + barcode_subset, left_on=BARCODE_COLUMN, right_on=SUBSET_COLUMN, how="inner" ) + if umi_distance > 0: + print("UMI correction") + temp = ( + merged_lf.with_columns(umi=pl.col(UMI_COLUMN).cast(pl.Binary)) + .drop(R2_COLUMN) + .group_by([FEATURE_NAME_COLUMN, BARCODE_COLUMN]) + .agg(pl.struct(pl.col(UMI_COLUMN), pl.col(COUNT_COLUMN))) + ).collect() + + clustered_cells = ( + temp.filter(pl.col(UMI_COLUMN).list.len() > max_umis) + .select(BARCODE_COLUMN) + .unique() + .get_column(BARCODE_COLUMN) + .to_list() + ) + + umi_mapping_lf = find_umis_to_correct( + temp=temp, + cluster_method=cluster_method, + max_umis=max_umis, + umi_distance=umi_distance, + ) + read_counts = update_umis_and_create_read_counts( + merged_lf=merged_lf, umi_mapping_lf=umi_mapping_lf + ) + + n_corrected_umis = umi_mapping_lf.collect().shape[0] + else: + read_counts = ( + merged_lf.group_by([BARCODE_COLUMN, FEATURE_NAME_COLUMN, UMI_COLUMN]) + .agg(pl.sum(COUNT_COLUMN)) + .drop(R2_COLUMN) + ) + clustered_cells = [] + n_corrected_umis = 0 + + return read_counts, n_corrected_umis, clustered_cells + + +def find_umis_to_correct( + temp: pl.DataFrame, cluster_method: str, max_umis: int, umi_distance: int +) -> pl.LazyFrame: + """Iterate through all umis that might need correcting and return a mapping lf + Args: + temp (pl.DataFrame): Filtered aggregated UMI per barcode per features + cluster_method (str): What cluster method to use + max_umis (int): Threshold for clustered cells + umi_distance (int): Hamming distance for umi correction + + Returns: + pl.LazyFrame: Mapping to correct umis + """ mapping_list = [] umi_clusterer = network.UMIClusterer(cluster_method=cluster_method) - for r2, feature_name, barcode, umis in temp.filter( + for feature_name, barcode, umis in temp.filter( (pl.col(UMI_COLUMN).list.len() > 1) & (pl.col(UMI_COLUMN).list.len() < max_umis) ).iter_rows(): corrected_umis = correct_umis( @@ -210,46 +269,79 @@ def correct_umis_df( for umi_set in corrected_umis: for index, umi in enumerate(umi_set): if index != 0: - mapping_list.append([r2, feature_name, barcode, umi, umi_set[0]]) + mapping_list.append([feature_name, barcode, umi, umi_set[0]]) + return get_umi_mapping_lf(mapping_list) + + +def update_umis_and_create_read_counts( + merged_lf: pl.LazyFrame, umi_mapping_lf: pl.LazyFrame +) -> pl.LazyFrame: + """Update corrected umis and format to a read count table + + Args: + merged_lf (pl.LazyFrame): Main df and R2 df joined + umi_mapping_lf (pl.LazyFrame): UMIs to update + + Returns: + pl.LazyFrame: Final read counts + """ + return ( + merged_lf.join( + umi_mapping_lf, + on=[FEATURE_NAME_COLUMN, BARCODE_COLUMN, UMI_COLUMN], + how="left", + ) + .with_columns( + pl.when(pl.col("replace").is_null()) + .then(pl.col(UMI_COLUMN)) + .otherwise(pl.col("replace")) + .alias(UMI_COLUMN) + ) + .drop("replace") + .group_by([BARCODE_COLUMN, FEATURE_NAME_COLUMN, UMI_COLUMN]) + .agg(pl.sum(COUNT_COLUMN)) + ) + + +def get_umi_mapping_lf(mapping_list: list) -> pl.LazyFrame: + """Convert a list of UMI per barcode per features to correct to a lazy frame + + Args: + mapping_list (list): List of UMIs to correct + + Returns: + pl.LazyFrame: LazyFrame to correct UMIs + """ mapping_df = ( pl.LazyFrame( mapping_list, schema={ - R2_COLUMN: pl.String, FEATURE_NAME_COLUMN: pl.String, BARCODE_COLUMN: pl.String, - "orig": pl.Binary, + UMI_COLUMN: pl.Binary, "replace": pl.Binary, }, ) .with_columns( - pl.col("orig").cast(pl.String).alias(UMI_COLUMN), + pl.col(UMI_COLUMN).cast(pl.String), pl.col("replace").cast(pl.String), ) .drop("orig") ) - read_counts = ( - merged.join( - mapping_df, - on=[R2_COLUMN, FEATURE_NAME_COLUMN, BARCODE_COLUMN, UMI_COLUMN], - how="left", - ) - .with_columns( - pl.when(pl.col("replace").is_null()) - .then(pl.col(UMI_COLUMN)) - .otherwise(pl.col("replace")) - .alias(UMI_COLUMN) - ) - .drop("replace") - .group_by([R2_COLUMN, BARCODE_COLUMN, FEATURE_NAME_COLUMN, UMI_COLUMN]) - .agg(pl.sum(COUNT_COLUMN)) - .drop(R2_COLUMN) - ) - n_corrected_umis = mapping_df.collect().shape[0] - return read_counts, n_corrected_umis, clustered_cells + return mapping_df -def correct_umis(umis_list, umi_distance, umi_clusterer): +def correct_umis(umis_list, umi_distance, umi_clusterer) -> list[list]: + """Find corrected umis from a pl.struct of UMI and counts + + Args: + umis_list (_type_): pl.struct of UMI and their counts + umi_distance (_type_): Hamming distance for a correction + umi_clusterer (_type_): umi_cluster object + + Returns: + list[list]: List of list. First member is the + """ umis = dict([(i["umi"], i["count"]) for i in umis_list]) res = umi_clusterer(umis, umi_distance) @@ -277,3 +369,18 @@ def correct_umis(umis_list, umi_distance, umi_clusterer): # .group_by([BARCODE_COLUMN, FEATURE_NAME_COLUMN]) # .agg(pl.count()) # ) + + +def find_closest_match( + df: pl.LazyFrame, + source_column: str, + target_df: pl.LazyFrame, + Levenshtein_distance=1, +): + return df.join( + target_df, left_on=source_column, right_on=target_df.columns[0], how="cross" + ).with_columns( + pld.col(source_column) + .dist_str.levenshtein(target_df.columns[0]) + .alias("distance") + ) diff --git a/setup.py b/setup.py index cf53903..4127138 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,7 @@ "pyyaml==6.0", "pooch==1.6.0", "six==1.16.0", - "polars==0.20.3", + "polars==0.20.4", ], python_requires="==3.11.6", package_data={"report_template": ["templates/*.json"]}, diff --git a/tests/test_data/fastq/correct_R1.csv b/tests/test_data/fastq/correct_R1.csv new file mode 100644 index 0000000..c1f9c89 --- /dev/null +++ b/tests/test_data/fastq/correct_R1.csv @@ -0,0 +1,201 @@ +barcode,umi +TAGAGGGAAGTCAAGC,CNGAGTCTCN +TAGAGGGAAGTCAAGC,CACNTAAATC +TAGAGGGAAGTCAAGC,TCTCCGAAGC +TAGAGGGAAGTCAAGC,TTGACTTATC +TAGAGGGAAGTCAAGC,TTTCCTTCCG +TAGAGGGAAGTCAAGC,NAGAACACCA +TAGAGGGAAGTCAAGC,CTCATCTTAT +TAGAGGGAAGTCAAGC,TCGACAAGCT +TAGAGGGAAGTCAAGC,AAANACACTC +TAGAGGGAAGTCAAGC,AAANACACTC +TAGAGGGAAGTCAAGC,GNCNCTGATA +TAGAGGGAAGTCAAGC,TGGTGTGNGC +TAGAGGGAAGTCAAGC,AGANTCTCCC +TAGAGGGAAGTCAAGC,GCGCGGAGAA +TAGAGGGAAGTCAAGC,TGGTGTGNGC +TAGAGGGAAGTCAAGC,TCTGGTAGTT +TAGAGGGAAGTCAAGC,TAGGAGCAGN +TAGAGGGAAGTCAAGC,ATCACGGATC +TAGAGGGAAGTCAAGC,CCGGGGACTA +TAGAGGGAAGTCAAGC,TTACTAGTAA +TAGAGGGAAGTCAAGC,AGAACGTCGC +TAGAGGGAAGTCAAGC,ANGAGNAAGT +TAGAGGGAAGTCAAGC,ATAACTAGAA +TAGAGGGAAGTCAAGC,TTGACTTATC +TAGAGGGAAGTCAAGC,ACACTGCTAT +TAGAGGGAAGTCAAGC,GCGCGGAGAA +TAGAGGGAAGTCAAGC,CGTGATGAGC +TAGAGGGAAGTCAAGC,TCACCCNCGG +TAGAGGGAAGTCAAGC,ATCCGTACTT +TAGAGGGAAGTCAAGC,CCGGGGACTA +TAGAGGGAAGTCAAGC,NTTCATGTTG +TAGAGGGAAGTCAAGC,ATCGGGAGNC +TAGAGGGAAGTCAAGC,CGTCNGTTGC +TAGAGGGAAGTCAAGC,TAGTCAGAAT +TAGAGGGAAGTCAAGC,ATCGGGAGNC +TAGAGGGAAGTCAAGC,ACACTGCTAT +TAGAGGGAAGTCAAGC,TGCTCAATAG +TAGAGGGAAGTCAAGC,GATCGTACAA +TAGAGGGAAGTCAAGC,AGACCTNTGG +TAGAGGGAAGTCAAGC,TGACCTAAGC +TAGAGGGAAGTCAAGC,GATCGTACAA +TAGAGGGAAGTCAAGC,TGGTGTGNGC +TAGAGGGAAGTCAAGC,TCGACACCAC +TAGAGGGAAGTCAAGC,CATCNAGTGN +TAGAGGGAAGTCAAGC,TTAAAACCNA +TAGAGGGAAGTCAAGC,CCATACGNNA +TAGAGGGAAGTCAAGC,ATTGTTCGGA +TAGAGGGAAGTCAAGC,CACNCTAGGG +TAGAGGGAAGTCAAGC,AGGAGCNCCC +TAGAGGGAAGTCAAGC,GCGCGGAGAA +TAGAGGGAAGTCAAGC,TTCCGTNCAA +TAGAGGGAAGTCAAGC,GNCNCTGATA +TAGAGGGAAGTCAAGC,CAAGGGAACG +TAGAGGGAAGTCAAGC,AGAANGCCNA +TAGAGGGAAGTCAAGC,TCTGGTAGTT +TAGAGGGAAGTCAAGC,ACAGAGTAAN +TAGAGGGAAGTCAAGC,CACNCTAGGG +TAGAGGGAAGTCAAGC,CTACACGTGA +TAGAGGGAAGTCAAGC,CGTCNGTTGC +TAGAGGGAAGTCAAGC,CGTCGTTATA +TAGAGGGAAGTCAAGC,TTCNGTCACC +TAGAGGGAAGTCAAGC,TTGGNGTACA +TAGAGGGAAGTCAAGC,TCGACAAGCT +TAGAGGGAAGTCAAGC,CGNAATTTGA +TAGAGGGAAGTCAAGC,GCGCGGAGAA +TAGAGGGAAGTCAAGC,CTACGCCGCC +TAGAGGGAAGTCAAGC,TAGGAGCAGN +TAGAGGGAAGTCAAGC,AGCANTGTAG +TAGAGGGAAGTCAAGC,AGAANGCCNA +TAGAGGGAAGTCAAGC,TGTCTGCACG +TAGAGGGAAGTCAAGC,CNGAGTCTCN +TAGAGGGAAGTCAAGC,ATAACTAGAA +TAGAGGGAAGTCAAGC,CCATGTGNGT +TAGAGGGAAGTCAAGC,CCATGTGNGT +TAGAGGGAAGTCAAGC,CGTAGGCATT +TAGAGGGAAGTCAAGC,CCGGGGACTA +TAGAGGGAAGTCAAGC,ANTTCTCTCA +TAGAGGGAAGTCAAGC,ACAGAGTAAN +TAGAGGGAAGTCAAGC,TGCTCAATAG +TAGAGGGAAGTCAAGC,AGACTTAGGG +TAGAGGGAAGTCAAGC,TAAAGGCTTG +TAGAGGGAAGTCAAGC,CTTGAGAGGG +TAGAGGGAAGTCAAGC,TTCCGTNCAA +TAGAGGGAAGTCAAGC,CATCNAGTGN +TAGAGGGAAGTCAAGC,GAACACTGAG +TAGAGGGAAGTCAAGC,ACGCGGAGTT +TAGAGGGAAGTCAAGC,ANGAGNAAGT +TAGAGGGAAGTCAAGC,ANTTCTCTCA +TAGAGGGAAGTCAAGC,GCTGTGTTAG +TAGAGGGAAGTCAAGC,TTCNGTCACC +TAGAGGGAAGTCAAGC,CCATACGNNA +TAGAGGGAAGTCAAGC,CGTCNGTTGC +TAGAGGGAAGTCAAGC,CCATGTGNGT +TAGAGGGAAGTCAAGC,GCCCGCTCAC +TAGAGGGAAGTCAAGC,AAANACACTC +TAGAGGGAAGTCAAGC,TAAAGGCTTG +TAGAGGGAAGTCAAGC,GATTTCACCG +TAGAGGGAAGTCAAGC,CCATACGNNA +TAGAGGGAAGTCAAGC,AGCANTGTAG +TAGAGGGAAGTCAAGC,AACTCCCACG +TACATATTCTTTACTG,GATCGAACGG +TACATATTCTTTACTG,GATCGAACGG +TACATATTCTTTACTG,GANCGGGACA +TACATATTCTTTACTG,NGCTGGCACG +TACATATTCTTTACTG,GANCGGGACA +TACATATTCTTTACTG,ATTCATTGTA +TACATATTCTTTACTG,TAATCATACC +TACATATTCTTTACTG,GGTCTAAGAG +TACATATTCTTTACTG,GGATNTTGTA +TACATATTCTTTACTG,AATATGANTG +TACATATTCTTTACTG,ATTAAGCCNG +TACATATTCTTTACTG,TGAGGGTAGA +TACATATTCTTTACTG,CTCTCGCTTT +TACATATTCTTTACTG,GGATNTTGTA +TACATATTCTTTACTG,AGGTTTACTG +TACATATTCTTTACTG,GAGGCGTGTC +TACATATTCTTTACTG,TGCTGAATAA +TACATATTCTTTACTG,AAGGCACTTT +TACATATTCTTTACTG,CTTTCAAGTN +TACATATTCTTTACTG,GAGGCGTGTC +TACATATTCTTTACTG,TGTNAATCCA +TACATATTCTTTACTG,GCCAAGTACA +TACATATTCTTTACTG,GGATNTTGTA +TACATATTCTTTACTG,CCGNTGTGGC +TACATATTCTTTACTG,GATCGAACGG +TACATATTCTTTACTG,TCGCGATGNT +TACATATTCTTTACTG,ACCGTGAGGC +TACATATTCTTTACTG,GGTCGCAGTN +TACATATTCTTTACTG,CCAGACTTGA +TACATATTCTTTACTG,GACTTTTCCT +TACATATTCTTTACTG,CTTCCATGCC +TACATATTCTTTACTG,AGCAACCCGA +TACATATTCTTTACTG,AGCAACCCGA +TACATATTCTTTACTG,GACGGGGTCT +TACATATTCTTTACTG,TACGAAGAAT +TACATATTCTTTACTG,CGAGGTGCGN +TACATATTCTTTACTG,TNCATCGGAT +TACATATTCTTTACTG,AAGGCACTTT +TACATATTCTTTACTG,ATTAAGCCNG +TACATATTCTTTACTG,GCTAACCCGN +TACATATTCTTTACTG,AGGTTTACTG +TACATATTCTTTACTG,ANAGGANAAC +TACATATTCTTTACTG,AGGTTTACTG +TACATATTCTTTACTG,CTNATCGGTC +TACATATTCTTTACTG,AGCAACCCGA +TACATATTCTTTACTG,ACTGGTCGCT +TACATATTCTTTACTG,GCNGTCGCTA +TACATATTCTTTACTG,TCGCGATGNT +TACATATTCTTTACTG,TAATCATACC +TACATATTCTTTACTG,GACTTTTCCT +TACATATTCTTTACTG,CCCGAATGAA +TACATATTCTTTACTG,GCTTCTACCN +TACATATTCTTTACTG,ACTGGTCGCT +TACATATTCTTTACTG,AGGTCGCTAC +TACATATTCTTTACTG,AGCGCCNTGG +TACATATTCTTTACTG,CCAGCGCCCG +TACATATTCTTTACTG,GAGATCCGAG +TACATATTCTTTACTG,TAGCCCCCCC +TACATATTCTTTACTG,ATTCATTGTA +TACATATTCTTTACTG,ATCGGGCGCC +TACATATTCTTTACTG,CGAGGTGCGN +TACATATTCTTTACTG,AGTAANGCAA +TACATATTCTTTACTG,NGCTGGCACG +TACATATTCTTTACTG,CCGNTGTGGC +TACATATTCTTTACTG,ATCGGGCGCC +TACATATTCTTTACTG,ATTAAGCCNG +TACATATTCTTTACTG,GGNCGCACCC +TACATATTCTTTACTG,AANCACANGT +TACATATTCTTTACTG,AANTAAGCAT +TACATATTCTTTACTG,ANAGGANAAC +TACATATTCTTTACTG,GGNCGCACCC +TACATATTCTTTACTG,CAATTCCGGC +TACATATTCTTTACTG,AGGTTTACTG +TACATATTCTTTACTG,CCGNTGTGGC +TACATATTCTTTACTG,ACGCTATGTA +TACATATTCTTTACTG,CTCCTGTGGC +TACATATTCTTTACTG,GTTGTTTATT +TACATATTCTTTACTG,CGAAGAGAAC +TACATATTCTTTACTG,CGAAGAGAAC +TACATATTCTTTACTG,GCGGCCATTC +TACATATTCTTTACTG,AGTAANGCAA +TACATATTCTTTACTG,CGAAGAGAAC +TACATATTCTTTACTG,GTCAACCGGG +TACATATTCTTTACTG,CTCAATACTA +TACATATTCTTTACTG,ATTAAGCCNG +TACATATTCTTTACTG,TACTGTGCTA +TACATATTCTTTACTG,ANGCACTCGA +TACATATTCTTTACTG,NGCTGGCACG +TACATATTCTTTACTG,TAGTATGGAA +TACATATTCTTTACTG,TCGCGATGNT +TACATATTCTTTACTG,ANAGGANAAC +TACATATTCTTTACTG,ACAGTAAATG +TACATATTCTTTACTG,GGTCTAAGAG +TACATATTCTTTACTG,CGAGGTGCGN +TACATATTCTTTACTG,ACTGGTCGCT +TACATATTCTTTACTG,GTTGTTTATT +TACATATTCTTTACTG,AAGGCACTTT +TACATATTCTTTACTG,TGACATCAAC +TACATATTCTTTACTG,TGCAGAAANG +TACATATTCTTTACTG,CTTCAANTGA diff --git a/tests/test_data/fastq/correct_R2.csv b/tests/test_data/fastq/correct_R2.csv new file mode 100644 index 0000000..be37858 --- /dev/null +++ b/tests/test_data/fastq/correct_R2.csv @@ -0,0 +1,201 @@ +r2 +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTCGTAGCTGATCGTAGCT +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTAGCTCGAAAAAAAAAAA +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA +CGTACGTAGCCTAGCAAAAA +CGTACGTAGCCTAGCAAAAA +CGTCGTAGCTGATCGTAGCT +CGTAGCTCGAAAAAAAAAAA diff --git a/tests/test_io.py b/tests/test_io.py index 2327a86..05696be 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -3,6 +3,7 @@ import gzip from pathlib import Path from polars.testing import assert_frame_equal +from zmq import has from cite_seq_count.constants import ( BARCODE_COLUMN, COUNT_COLUMN, @@ -18,7 +19,13 @@ get_read_paths, write_data_to_mtx, create_mtx_df, + write_mapping_input_from_fastqs, + read_R1_polars, + read_R2_polars, + write_fastq_inputs_as_parquet, ) + +from cite_seq_count.chemistry import Chemistry import polars as pl # copied from https://stackoverflow.com/questions/3431825/generating-an-md5-checksum-of-a-file @@ -38,6 +45,19 @@ def md5(fname): return hash_md5.hexdigest() +@pytest.fixture +def chemistry_def(): + return Chemistry( + name="test", + cell_barcode_start=1, + cell_barcode_end=16, + umi_barcode_start=17, + umi_barcode_end=26, + r2_trim_start=0, + barcode_reference_path="tests/reference_lists/pass/translation.csv", + ) + + @pytest.fixture def correct_R1(): return Path("tests/test_data/fastq/correct_R1.fastq.gz") @@ -197,8 +217,8 @@ def test_create_mtx_df(): }, schema={ FEATURE_ID_COLUMN: pl.UInt32, - SEQUENCE_COLUMN: pl.Utf8, - FEATURE_NAME_COLUMN: pl.Utf8, + SEQUENCE_COLUMN: pl.String, + FEATURE_NAME_COLUMN: pl.String, }, ) expected_barcodes_indexed = pl.DataFrame( @@ -206,7 +226,7 @@ def test_create_mtx_df(): SUBSET_COLUMN: ["ACTGTTTTATTGGCCT", "TTCATAAGGTAGGGAT"], BARCODE_ID_COLUMN: [1, 2], }, - schema={SUBSET_COLUMN: pl.Utf8, BARCODE_ID_COLUMN: pl.UInt32}, + schema={SUBSET_COLUMN: pl.String, BARCODE_ID_COLUMN: pl.UInt32}, ) assert_frame_equal(mtx_df, expected_mtx_df) @@ -214,3 +234,69 @@ def test_create_mtx_df(): assert_frame_equal( barcodes_indexed, expected_barcodes_indexed, check_column_order=False ) + + +@pytest.fixture +def fastq_reads(): + return pl.scan_csv( + "tests/test_data/fastq/test_csv.csv", + has_header=False, + schema={"barcode": pl.String, "umi": pl.String, "r2": pl.String}, + ) + + +@pytest.fixture +def fastq_paths(correct_R1, correct_R2): + return [ + (correct_R1, correct_R2), + ] + + +@pytest.fixture +def correct_R1_csv(): + return pl.scan_csv("tests/test_data/fastq/correct_R1.csv", has_header=True) + + +@pytest.fixture +def correct_R2_csv(): + return pl.scan_csv("tests/test_data/fastq/correct_R2.csv", has_header=True) + + +def test_read_R1_polars(correct_R1, chemistry_def, correct_R1_csv): + # Call the function + result = read_R1_polars(correct_R1, chemistry_def) + # Define the expected output + # Check the output + assert_frame_equal(result, correct_R1_csv) + + +def test_read_R2_polars(correct_R2, correct_R2_csv, chemistry_def): + # Call the function + result = read_R2_polars(correct_R2, r2_min_length=20, chemistry=chemistry_def) + # Define the expected output + + # Check the output + assert_frame_equal(result, correct_R2_csv) + + +def test_write_fastq_inputs_as_parquet(tmp_path, correct_R1, correct_R2, chemistry_def): + # Set up test data + read_paths = [(correct_R1, correct_R2)] + temp_path = tmp_path / "test.parquet" + r2_min_length = 20 + top_n_reads = 100 + + # Call the function + r1_too_short, r2_too_short, total_reads = write_fastq_inputs_as_parquet( + read_paths=read_paths, + temp_path=temp_path, + chemistry_def=chemistry_def, + r2_min_length=r2_min_length, + top_n_reads=top_n_reads, + ) + + # Assert the result + assert r1_too_short == 0 + assert r2_too_short == 0 + assert total_reads == top_n_reads + assert temp_path.exists() diff --git a/tests/test_processing.py b/tests/test_processing.py index 1accd07..02c99e9 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -23,7 +23,8 @@ def barcodes_df(): "TAGAGGGAAGTCAAGC", ], "count": [5, 1, 1, 1, 5, 1, 1, 1, 5, 1, 1, 1], - } + }, + schema={"barcode": pl.String, "count": pl.UInt32}, ) @@ -54,7 +55,8 @@ def test_correct_barcodes_pl(barcodes_df): "TAGAGGGAGGTCAAGC", ], "count": [8, 8, 8], - } + }, + schema={"barcode": pl.String, "count": pl.UInt32}, ) assert_frame_equal( corrected_barcodes, expected_corrected_barcodes, check_row_order=False @@ -66,14 +68,28 @@ def test_correct_barcodes_pl(barcodes_df): # Assert the mapped barcodes expected_mapped_barcodes = { - "AACATATTCTTTACTG":"TACATATTCTTTACTG", - "CACATATTCTTTACTG":"TACATATTCTTTACTG", - "GACATATTCTTTACTG":"TACATATTCTTTACTG", - "GCTAGTCGTAGCTAGT":"GCTAGTCGTAGCTAGA", - "GCTAGTCGTAGCTAGG":"GCTAGTCGTAGCTAGA", - "GCTAGTCGTAGCTAGC":"GCTAGTCGTAGCTAGA", - "TAGAGGGACGTCAAGC":"TAGAGGGAGGTCAAGC", - "TAGAGGGATGTCAAGC":"TAGAGGGAGGTCAAGC", - "TAGAGGGAAGTCAAGC":"TAGAGGGAGGTCAAGC", + "AACATATTCTTTACTG": "TACATATTCTTTACTG", + "CACATATTCTTTACTG": "TACATATTCTTTACTG", + "GACATATTCTTTACTG": "TACATATTCTTTACTG", + "GCTAGTCGTAGCTAGT": "GCTAGTCGTAGCTAGA", + "GCTAGTCGTAGCTAGG": "GCTAGTCGTAGCTAGA", + "GCTAGTCGTAGCTAGC": "GCTAGTCGTAGCTAGA", + "TAGAGGGACGTCAAGC": "TAGAGGGAGGTCAAGC", + "TAGAGGGATGTCAAGC": "TAGAGGGAGGTCAAGC", + "TAGAGGGAAGTCAAGC": "TAGAGGGAGGTCAAGC", } assert mapped_barcodes == expected_mapped_barcodes + + +# def test_find_closest_match(): +# input_lf = pl.LazyFrame( +# {"sequence": ["ATGCTATCAG", "GGCGAGGCT", "GGATTATCGA", "GCTAGCTTAG"]} +# ) +# ref_seq = pl.LazyFrame( +# {"ref": ["ATGCTAACAG", "GCTAGCTAT", "AGGAGATC", "GGATAGCGA", "GATTCGGAG"]} +# ) +# res = processing.find_closest_match( +# df=input_lf, source_column="sequence", target_df=ref_seq +# ) +# print(res.collect()) +# assert_frame_equal(res, pl.LazyFrame()) From c54dd6052ff85ad5e492d1a23eccc9b2e07970f6 Mon Sep 17 00:00:00 2001 From: hoohm Date: Sun, 25 Feb 2024 14:51:30 +0100 Subject: [PATCH 77/77] feat: New fastq reader/writer --- cite_seq_count/__main__.py | 7 +- cite_seq_count/io.py | 114 ++++++++++++++++++++------------ cite_seq_count/preprocessing.py | 18 +++-- setup.py | 3 +- tests/test_io.py | 19 +++++- tests/test_processing.py | 47 ++++++++++++- 6 files changed, 151 insertions(+), 57 deletions(-) diff --git a/cite_seq_count/__main__.py b/cite_seq_count/__main__.py index ad732a8..fbef519 100755 --- a/cite_seq_count/__main__.py +++ b/cite_seq_count/__main__.py @@ -84,7 +84,7 @@ def main(): r1_too_short, r2_too_short, n_reads, - temp_file_path, + temp_file_list, ) = io.write_mapping_input_from_fastqs( chemistry_def=chemistry_def, fastq_paths=list(zip(read1_paths, read2_paths)), @@ -94,7 +94,7 @@ def main(): ) print("Processing") main_df, barcodes_df, r2_df = preprocessing.split_data_input( - mapping_input_path=temp_file_path, + mapping_input_path=temp_file_list, ) mapped_r2_df, unmapped_r2_df = mapping.map_reads_polars( r2_df=r2_df, @@ -178,7 +178,8 @@ def main(): chemistry_def=chemistry_def, maximum_distance=args.max_error, ) - os.remove(temp_file_path) + for temp_file in temp_file_list: + os.remove(temp_file) if __name__ == "__main__": diff --git a/cite_seq_count/io.py b/cite_seq_count/io.py index ded1259..7ede2bf 100644 --- a/cite_seq_count/io.py +++ b/cite_seq_count/io.py @@ -1,5 +1,7 @@ """Handle io operations""" +from concurrent.futures import ThreadPoolExecutor, as_completed from math import inf +from tempfile import NamedTemporaryFile import os import csv import sys @@ -215,40 +217,36 @@ def read_R2_polars(R2_path: Path, r2_min_length: int, chemistry) -> pl.LazyFrame ) -def write_fastq_inputs_as_parquet( - read_paths: list, - temp_path: Path, +def concat_reads( + r1_read_path: Path, + r2_read_path: Path, chemistry_def, - r2_min_length: int, top_n_reads: int, -) -> tuple[int, int, int]: - concats = [] - for r1_read_path, r2_read_path in read_paths: - R1_read = read_R1_polars(R1_path=r1_read_path, chemistry_def=chemistry_def) - R2_read = read_R2_polars( - R2_path=r2_read_path, r2_min_length=r2_min_length, chemistry=chemistry_def + n_samples: int, + r2_min_length: int, + temp_path: Path, +): + R1_read = read_R1_polars(R1_path=r1_read_path, chemistry_def=chemistry_def) + R2_read = read_R2_polars( + R2_path=r2_read_path, r2_min_length=r2_min_length, chemistry=chemistry_def + ) + if top_n_reads != inf: + data = ( + pl.concat([R1_read, R2_read], how="horizontal") + # Since this is lazy, the head is actually computed ahead of time + .head(round(top_n_reads / n_samples)) + .group_by(["barcode", "umi", "r2"]) + .agg(pl.count()) ) - if top_n_reads != inf: - data = ( - pl.concat([R1_read, R2_read], how="horizontal") - # Since this is lazy, the head is actually computed ahead of time - .head(4 * round(top_n_reads / len(read_paths))) - .group_by(["barcode", "umi", "r2"]) - .agg(pl.count()) - ) - else: - data = ( - pl.concat([R1_read, R2_read], how="horizontal") - .group_by(["barcode", "umi", "r2"]) - .agg(pl.count()) - ) - - concats.append(data) - - all = pl.concat(concats) - total_reads = all.select(pl.sum(COUNT_COLUMN)).collect().item() + else: + data = ( + pl.concat([R1_read, R2_read], how="horizontal") + .group_by(["barcode", "umi", "r2"]) + .agg(pl.count()) + ) + total_reads = data.select(pl.sum(COUNT_COLUMN)).collect().item() r1_too_short = ( - all.filter( + data.filter( (pl.col(BARCODE_COLUMN) + pl.col(UMI_COLUMN)).str.len_bytes() < (chemistry_def.barcode_length + chemistry_def.umi_length) ) @@ -257,14 +255,52 @@ def write_fastq_inputs_as_parquet( .item() ) r2_too_short = ( - all.filter((pl.col(R2_COLUMN).str.len_bytes() < r2_min_length)) + data.filter((pl.col(R2_COLUMN).str.len_bytes() < r2_min_length)) .select(pl.sum(COUNT_COLUMN)) .collect() .item() ) + temp_file_path = NamedTemporaryFile(dir=temp_path, delete=False) + data.collect().write_parquet(file=temp_file_path.name) + return temp_file_path.name, r1_too_short, r2_too_short, total_reads + + +def write_fastq_inputs_as_parquet( + read_paths: list, + temp_path: Path, + chemistry_def, + r2_min_length: int, + top_n_reads: int, +) -> tuple[int, int, int, list[Path]]: + n_samples = len(read_paths) + with ThreadPoolExecutor(n_samples) as executor: + executors = [] + for r1_read_path, r2_read_path in read_paths: + executors.append( + executor.submit( + concat_reads, + r1_read_path, + r2_read_path, + chemistry_def, + top_n_reads, + n_samples, + r2_min_length, + temp_path, + ) + ) + results = [future.result() for future in as_completed(executors)] + r1_too_short = 0 + r2_too_short = 0 + total_reads = 0 + temp_path_list = [] + for result in results: + temp_path_list.append(Path(result[0])) + r1_too_short+=result[1] + r2_too_short+=result[2] + total_reads+=result[3] - all.collect().write_parquet(file=temp_path) - return r1_too_short, r2_too_short, total_reads + + return r1_too_short, r2_too_short, total_reads, temp_path_list def load_report_template() -> dict: @@ -531,7 +567,7 @@ def write_mapping_input_from_fastqs( chemistry_def, top_n_reads: int, temp_path: str, -) -> tuple[int, int, int, Path]: +) -> tuple[int, int, int, list[Path]]: """Read fastq inputs using polars read_csv, concatenate R1 and R2 files, summaries and write to parquet Args: @@ -542,17 +578,13 @@ def write_mapping_input_from_fastqs( None """ temp_path = os.path.abspath(temp_path) - temp_file = tempfile.NamedTemporaryFile( - "w", dir=temp_path, suffix="_csc.parquet", delete=False - ) - temp_file_path = Path(temp_file.name) - r1_too_short, r2_too_short, total_reads = write_fastq_inputs_as_parquet( - temp_path=temp_file_path, + r1_too_short, r2_too_short, total_reads, temp_file_list = write_fastq_inputs_as_parquet( + temp_path=Path(temp_path), read_paths=fastq_paths, chemistry_def=chemistry_def, r2_min_length=r2_min_length, top_n_reads=top_n_reads, ) - return r1_too_short, r2_too_short, total_reads, temp_file_path + return r1_too_short, r2_too_short, total_reads, temp_file_list diff --git a/cite_seq_count/preprocessing.py b/cite_seq_count/preprocessing.py index b90708d..dcff67d 100644 --- a/cite_seq_count/preprocessing.py +++ b/cite_seq_count/preprocessing.py @@ -364,7 +364,7 @@ def pre_run_checks( def split_data_input( - mapping_input_path: Path, + mapping_input_path: list[Path], ) -> tuple[pl.LazyFrame, pl.LazyFrame, pl.LazyFrame]: """Read in all the input data and split it into three dataframes. @@ -380,13 +380,17 @@ def split_data_input( Returns: tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]: Three dfs described above """ - main_df = ( - pl.scan_parquet( - mapping_input_path, + main_dfs = [] + for temp_file_path in mapping_input_path: + main_df = ( + pl.scan_parquet( + temp_file_path, + ) + .group_by([BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN]) + .agg(pl.count()) ) - .group_by([BARCODE_COLUMN, UMI_COLUMN, R2_COLUMN]) - .agg(pl.count()) - ) + main_dfs.append(main_df) + main_df = pl.concat(main_dfs) barcodes_df = ( main_df.select([BARCODE_COLUMN, COUNT_COLUMN]) .group_by(BARCODE_COLUMN) diff --git a/setup.py b/setup.py index 4127138..72ed9c5 100644 --- a/setup.py +++ b/setup.py @@ -28,7 +28,8 @@ "pyyaml==6.0", "pooch==1.6.0", "six==1.16.0", - "polars==0.20.4", + "polars==0.20.10", + "polars-distance==0.4.1", ], python_requires="==3.11.6", package_data={"report_template": ["templates/*.json"]}, diff --git a/tests/test_io.py b/tests/test_io.py index 05696be..07db0ca 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -3,7 +3,6 @@ import gzip from pathlib import Path from polars.testing import assert_frame_equal -from zmq import has from cite_seq_count.constants import ( BARCODE_COLUMN, COUNT_COLUMN, @@ -23,6 +22,7 @@ read_R1_polars, read_R2_polars, write_fastq_inputs_as_parquet, + concat_reads, ) from cite_seq_count.chemistry import Chemistry @@ -279,15 +279,27 @@ def test_read_R2_polars(correct_R2, correct_R2_csv, chemistry_def): assert_frame_equal(result, correct_R2_csv) +def test_concat_reads(correct_R1, correct_R2, chemistry_def, tmp_path): + # Call the function to be tested + temp_file_path, r1_too_short, r2_too_short, total_reads = concat_reads( + correct_R1, correct_R2, chemistry_def, 100, 1, 20, tmp_path + ) + os.remove(temp_file_path) + # Assert the expected output + assert r1_too_short == 0 + assert r2_too_short == 0 + assert total_reads == 100 + + def test_write_fastq_inputs_as_parquet(tmp_path, correct_R1, correct_R2, chemistry_def): # Set up test data read_paths = [(correct_R1, correct_R2)] - temp_path = tmp_path / "test.parquet" + temp_path = tmp_path r2_min_length = 20 top_n_reads = 100 # Call the function - r1_too_short, r2_too_short, total_reads = write_fastq_inputs_as_parquet( + r1_too_short, r2_too_short, total_reads, temp_file_list = write_fastq_inputs_as_parquet( read_paths=read_paths, temp_path=temp_path, chemistry_def=chemistry_def, @@ -296,6 +308,7 @@ def test_write_fastq_inputs_as_parquet(tmp_path, correct_R1, correct_R2, chemist ) # Assert the result + print(temp_file_list) assert r1_too_short == 0 assert r2_too_short == 0 assert total_reads == top_n_reads diff --git a/tests/test_processing.py b/tests/test_processing.py index 02c99e9..d7b7617 100644 --- a/tests/test_processing.py +++ b/tests/test_processing.py @@ -1,6 +1,7 @@ import pytest from cite_seq_count import processing import polars as pl +import polars_distance as pld from polars.testing import assert_frame_equal @@ -35,8 +36,8 @@ def test_correct_barcodes_pl(barcodes_df): "TACATATTCTTTACTG", "GCTAGTCGTAGCTAGA", "TAGAGGGAGGTCAAGC", - ], - } + ] + }, schema = {"subset":pl.String} ) hamming_distance = 1 @@ -93,3 +94,45 @@ def test_correct_barcodes_pl(barcodes_df): # ) # print(res.collect()) # assert_frame_equal(res, pl.LazyFrame()) + + +# def test_umi_correction(): +# # Broadest context +# test_data_broad = pl.DataFrame( +# { "barcode": ["cell1", "cell1", "cell1", "cell2", "cell2", "cell2"], +# "feature": ["gene1", "gene1", "gene2", "gene1", "gene2", "gene2"], +# "umi": ["CCCC", "CCCA", "TTTT", "AAAA", "GGGG", "CAAA"], +# "count": [10, 1, 2, 3, 6, 3], +# } +# ) + +# expected_data_broad = pl.DataFrame( +# { "barcode": ["cell1", "cell1", "cell2", "cell2"], +# "feature": ["gene1", "gene2", "gene1", "gene2"], +# "umi": ["CCCC", "TTTT", "AAAA", "GGGG"], +# "count": [11, 2, 6, 6], +# } +# ) +# # One cell, one feature +# test_data = pl.DataFrame( +# { +# "umi": ["CCCC", "CCCA", "TTTT"], +# "count": [10, 1, 2], +# } +# ) +# expected = pl.DataFrame( +# { +# "umi": ["CCCC", "TTTT"], +# "count": [11, 2], +# } +# ) +# # What I've got so far that deals with one case of the smaller context but doesn't keep the "uncorrected" umis +# res = ( +# test_data.join(test_data.select("umi"), on="umi", how="cross") +# .filter(pl.col("umi") != pl.col("umi_right")) +# .with_columns(pld.col("umi").dist_str.hamming("umi_right").alias("hamming")) +# .filter(pl.col("hamming") <= 1) +# .sort("count", descending=True) +# .select(pl.first("umi"), pl.sum("count")) +# ) +# print(res)