From 2e81cf899a6f615c9010140a45a66a24c51d0c65 Mon Sep 17 00:00:00 2001 From: goodov <5928869+goodov@users.noreply.github.com> Date: Fri, 26 Apr 2024 17:43:04 +0700 Subject: [PATCH] Sort studies by name and add basic study intersection check. (#1023) * Sort studies by name and add basic study intersection validation. * Fix wildcard version comparison, add tests. * Make sure to properly cover empty channel/platform lists (although this is not allowed). --- seed/serialize.py | 120 ++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 120 insertions(+) diff --git a/seed/serialize.py b/seed/serialize.py index 09c7528c..21b80edf 100644 --- a/seed/serialize.py +++ b/seed/serialize.py @@ -7,6 +7,7 @@ import sys import time import proto.variations_seed_pb2 as variations_seed_pb2 +import collections SEED_BIN_PATH = "./seed.bin" SERIALNUMBER_PATH = "./serialnumber" @@ -22,6 +23,73 @@ def load(seed_json_path): return seed_data +def version_to_int_array(version_str): + version_list = [] + if version_str is None: + return version_list + + parts = version_str.split('.') + for part in parts: + if part == '*': + version_list.append(part) + break + version_list.append(int(part)) + + return version_list + +def compare_versions(version1, version2): + min_len = None + if not version1: + min_len = 0 + elif version1[-1] == '*': + version1 = version1[:-1] + min_len = len(version1) + + if not version2: + min_len = 0 + elif version2[-1] == '*': + version2 = version2[:-1] + if min_len is not None: + min_len = min(min_len, len(version2)) + else: + min_len = len(version2) + + if min_len is not None: + version1 = version1[:min_len] + version2 = version2[:min_len] + + if version1 > version2: + return 1 + elif version1 < version2: + return -1 + else: + return 0 + +def test_version_comparison(): + # //base/version_unittest.cc VersionTest.CompareToWildcardString + test_cases = [ + ["1.0", "1.*", 0], + ["1.0", "0.*", 1], + ["1.0", "2.*", -1], + ["1.2.3", "1.2.3.*", 0], + ["10.0", "1.0.*", 1], + ["1.0", "3.0.*", -1], + ["1.4", "1.3.0.*", 1], + ["1.3.9", "1.3.*", 0], + ["1.4.1", "1.3.*", 1], + ["1.3", "1.4.5.*", -1], + ["1.5", "1.4.5.*", 1], + ["1.3.9", "1.3.*", 0], + ["1.2.0.0.0.0", "1.2.*", 0], + [None, None, 0], + [None, "1", 0], + ["1", None, 0], + ] + for test_case in test_cases: + version1 = version_to_int_array(test_case[0]) + version2 = version_to_int_array(test_case[1]) + assert compare_versions(version1, version2) == test_case[2] + def validate(seed): for study in seed['studies']: total_proba = 0 @@ -40,6 +108,57 @@ def validate(seed): print("platform not in ", PLATFORMS) return False + feature_names_to_studies = collections.defaultdict(list) + for study in seed['studies']: + used_feature_names = set() + for experiment in study['experiments']: + feature_association = experiment.get('feature_association') + if feature_association: + for enable_feature in feature_association.get('enable_feature', []): + used_feature_names.add(enable_feature) + for disable_feature in feature_association.get('disable_feature', []): + used_feature_names.add(disable_feature) + + for used_feature_names in used_feature_names: + feature_names_to_studies[used_feature_names].append(study) + + def get_study_platforms(study): + return set(study.get('filter', {}).get('platform', [])) + + def get_study_channels(study): + return set(study.get('filter', {}).get('channel', [])) + + def get_study_version_range(study): + return [ + version_to_int_array(study.get('filter', {}).get('min_version')), + version_to_int_array(study.get('filter', {}).get('max_version')), + ] + + def is_filter_set_intersect(a, b): + return not a or not b or a.intersection(b) + + def is_version_range_intersect(range1, range2): + return compare_versions(range1[1], range2[0]) >= 0 and compare_versions(range2[1], range1[0]) >= 0 + + test_version_comparison() + for studies in feature_names_to_studies.values(): + for i, study1 in enumerate(studies): + study1_platform = get_study_platforms(study1) + study1_channel = get_study_channels(study1) + study1_version_range = get_study_version_range(study1) + for j in range(i + 1, len(studies)): + study2 = studies[j] + study2_platform = get_study_platforms(study2) + study2_channel = get_study_channels(study2) + study2_version_range = get_study_version_range(study2) + # Check if the studies overlap in platform + if is_filter_set_intersect(study1_platform, study2_platform): + # Check if the studies overlap in channel + if is_filter_set_intersect(study1_channel, study2_channel): + # Check if the studies overlap in version + if is_version_range_intersect(study1_version_range, study2_version_range): + raise ValueError(f"Studies overlap:\n{json.dumps(study1, indent=2)}\n\n{json.dumps(study2, indent=2)}") + return True @@ -207,6 +326,7 @@ def serialize_and_save_variations_seed_message(seed_data, path): if __name__ == "__main__": print("Load seed.json") seed_data = load(sys.argv[1]) + seed_data['studies'].sort(key=lambda study: study['name']) print("Validate seed data") if validate(seed_data):