diff --git a/CHANGES.md b/CHANGES.md index 654626c4c..3cba4b99e 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -1,4 +1,6 @@ ## Current +* core.match_filter.template + - new quick_group_templates function for 50x quicker template grouping. * core.match_filter - 30x speedup in handling detections (50x speedup in selecting detections, 4x speedup in adding prepick time) diff --git a/eqcorrscan/core/match_filter/template.py b/eqcorrscan/core/match_filter/template.py index 8b792de90..ba8955c3b 100644 --- a/eqcorrscan/core/match_filter/template.py +++ b/eqcorrscan/core/match_filter/template.py @@ -105,6 +105,15 @@ def __init__(self, name=None, st=None, lowcut=None, highcut=None, author=getpass.getuser()))) self.event = event + @property + def _processing_parameters(self): + """ + Internal function / attribute to return all processing parameters for + quick grouping of templates as tuple. + """ + return (self.lowcut, self.highcut, self.samp_rate, self.filt_order, + self.process_length) + def __repr__(self): """ Print the template. @@ -284,12 +293,9 @@ def same_processing(self, other): >>> template_a.same_processing(template_b) False """ - for key in self.__dict__.keys(): - if key in ['name', 'st', 'prepick', 'event', 'template_info']: - continue - if not self.__dict__[key] == other.__dict__[key]: - return False - return True + if self._processing_parameters == other._processing_parameters: + return True + return False def write(self, filename, format='tar'): """ @@ -707,6 +713,35 @@ def group_templates(templates): return template_groups +def quick_group_templates(templates): + """ + Group templates into sets of similarly processed templates. + + :type templates: List of Tribe of Templates + :return: List of Lists of Templates. + """ + # Get the template's processing parameters + processing_tuples = [template._processing_parameters + for template in templates] + # Get list of unique parameter-tuples. Sort it so that the order in which + # the groups are processed is consistent across different runs. + uniq_processing_parameters = sorted(list(set(processing_tuples))) + # sort templates into groups + template_groups = [] + for parameter_combination in uniq_processing_parameters: + # find indices of tuples in list with same parameters + template_indices_for_group = [ + j for j, param_tuple in enumerate(processing_tuples) + if param_tuple == parameter_combination] + + new_group = list() + for template_index in template_indices_for_group: + # use indices to sort templates into groups + new_group.append(templates[int(template_index)]) + template_groups.append(new_group) + return template_groups + + if __name__ == "__main__": import doctest diff --git a/eqcorrscan/core/match_filter/tribe.py b/eqcorrscan/core/match_filter/tribe.py index f03bf77a4..22329c484 100644 --- a/eqcorrscan/core/match_filter/tribe.py +++ b/eqcorrscan/core/match_filter/tribe.py @@ -24,7 +24,8 @@ from obspy import Catalog, Stream, read, read_events from obspy.core.event import Comment, CreationInfo -from eqcorrscan.core.match_filter.template import Template, group_templates +from eqcorrscan.core.match_filter.template import ( + Template, quick_group_templates) from eqcorrscan.core.match_filter.party import Party from eqcorrscan.core.match_filter.helpers import ( _safemembers, _par_read, get_waveform_client) @@ -584,7 +585,8 @@ def detect(self, stream, threshold, threshold_type, trig_int, plot=False, length is the number of channels within this template. """ party = Party() - template_groups = group_templates(self.templates) + # template_groups = group_templates(self.templates) + template_groups = quick_group_templates(self.templates) if len(template_groups) > 1 and pre_processed: raise NotImplementedError( "Inconsistent template processing and pre-processed data - " diff --git a/eqcorrscan/tests/match_filter_test.py b/eqcorrscan/tests/match_filter_test.py index 95a8f8a08..dd642dbbb 100644 --- a/eqcorrscan/tests/match_filter_test.py +++ b/eqcorrscan/tests/match_filter_test.py @@ -21,6 +21,8 @@ from eqcorrscan.core.match_filter.matched_filter import ( match_filter, MatchFilterError) from eqcorrscan.core.match_filter.helpers import get_waveform_client +from eqcorrscan.core.match_filter.template import quick_group_templates + from eqcorrscan.utils import pre_processing, catalog_utils from eqcorrscan.utils.correlate import fftw_normxcorr, numpy_normxcorr from eqcorrscan.utils.catalog_utils import filter_picks @@ -1342,6 +1344,22 @@ def test_template_io(self): if os.path.isfile('test_template.tgz'): os.remove('test_template.tgz') + def test_template_grouping(self): + # PR #524 + # Test that this works directly on the tribe - it should + tribe_len = len(self.tribe) + groups = quick_group_templates(self.tribe) + self.assertEqual(len(groups), 1) + # Add one copy of a template with a different processing length + t2 = self.tribe[0].copy() + t2.process_length -= 100 + templates = [t2] + templates.extend(self.tribe.templates) + # Quick check that we haven't changed the tribe + self.assertEqual(len(self.tribe), tribe_len) + groups2 = quick_group_templates(templates) + self.assertEqual(len(groups2), 2) + def test_party_io(self): """Test reading and writing party objects.""" if os.path.isfile('test_party_out.tgz'):