Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sort studies by name and add basic study intersection check [prod]. (#1023) #1025

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 120 additions & 0 deletions seed/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import sys
import time
import proto.variations_seed_pb2 as variations_seed_pb2
import collections

SEED_BIN_PATH = "./seed.bin"
SERIALNUMBER_PATH = "./serialnumber"
Expand All @@ -22,6 +23,73 @@ def load(seed_json_path):
return seed_data


def version_to_int_array(version_str):
version_list = []
if version_str is None:
return version_list

parts = version_str.split('.')
for part in parts:
if part == '*':
version_list.append(part)
break
version_list.append(int(part))

return version_list

def compare_versions(version1, version2):
min_len = None
if not version1:
min_len = 0
elif version1[-1] == '*':
version1 = version1[:-1]
min_len = len(version1)

if not version2:
min_len = 0
elif version2[-1] == '*':
version2 = version2[:-1]
if min_len is not None:
min_len = min(min_len, len(version2))
else:
min_len = len(version2)

if min_len is not None:
version1 = version1[:min_len]
version2 = version2[:min_len]

if version1 > version2:
return 1
elif version1 < version2:
return -1
else:
return 0

def test_version_comparison():
# //base/version_unittest.cc VersionTest.CompareToWildcardString
test_cases = [
["1.0", "1.*", 0],
["1.0", "0.*", 1],
["1.0", "2.*", -1],
["1.2.3", "1.2.3.*", 0],
["10.0", "1.0.*", 1],
["1.0", "3.0.*", -1],
["1.4", "1.3.0.*", 1],
["1.3.9", "1.3.*", 0],
["1.4.1", "1.3.*", 1],
["1.3", "1.4.5.*", -1],
["1.5", "1.4.5.*", 1],
["1.3.9", "1.3.*", 0],
["1.2.0.0.0.0", "1.2.*", 0],
[None, None, 0],
[None, "1", 0],
["1", None, 0],
]
for test_case in test_cases:
version1 = version_to_int_array(test_case[0])
version2 = version_to_int_array(test_case[1])
assert compare_versions(version1, version2) == test_case[2]

def validate(seed):
for study in seed['studies']:
total_proba = 0
Expand All @@ -40,6 +108,57 @@ def validate(seed):
print("platform not in ", PLATFORMS)
return False

feature_names_to_studies = collections.defaultdict(list)
for study in seed['studies']:
used_feature_names = set()
for experiment in study['experiments']:
feature_association = experiment.get('feature_association')
if feature_association:
for enable_feature in feature_association.get('enable_feature', []):
used_feature_names.add(enable_feature)
for disable_feature in feature_association.get('disable_feature', []):
used_feature_names.add(disable_feature)

for used_feature_names in used_feature_names:
feature_names_to_studies[used_feature_names].append(study)

def get_study_platforms(study):
return set(study.get('filter', {}).get('platform', []))

def get_study_channels(study):
return set(study.get('filter', {}).get('channel', []))

def get_study_version_range(study):
return [
version_to_int_array(study.get('filter', {}).get('min_version')),
version_to_int_array(study.get('filter', {}).get('max_version')),
]

def is_filter_set_intersect(a, b):
return not a or not b or a.intersection(b)

def is_version_range_intersect(range1, range2):
return compare_versions(range1[1], range2[0]) >= 0 and compare_versions(range2[1], range1[0]) >= 0

test_version_comparison()
for studies in feature_names_to_studies.values():
for i, study1 in enumerate(studies):
study1_platform = get_study_platforms(study1)
study1_channel = get_study_channels(study1)
study1_version_range = get_study_version_range(study1)
for j in range(i + 1, len(studies)):
study2 = studies[j]
study2_platform = get_study_platforms(study2)
study2_channel = get_study_channels(study2)
study2_version_range = get_study_version_range(study2)
# Check if the studies overlap in platform
if is_filter_set_intersect(study1_platform, study2_platform):
# Check if the studies overlap in channel
if is_filter_set_intersect(study1_channel, study2_channel):
# Check if the studies overlap in version
if is_version_range_intersect(study1_version_range, study2_version_range):
raise ValueError(f"Studies overlap:\n{json.dumps(study1, indent=2)}\n\n{json.dumps(study2, indent=2)}")

return True


Expand Down Expand Up @@ -207,6 +326,7 @@ def serialize_and_save_variations_seed_message(seed_data, path):
if __name__ == "__main__":
print("Load seed.json")
seed_data = load(sys.argv[1])
seed_data['studies'].sort(key=lambda study: study['name'])

print("Validate seed data")
if validate(seed_data):
Expand Down
Loading