Skip to content

Commit

Permalink
feat: count difference of uploaded sessions (#454)
Browse files Browse the repository at this point in the history
  • Loading branch information
giovanni-guidini authored May 28, 2024
1 parent 8390733 commit d6168c2
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 3 deletions.
64 changes: 62 additions & 2 deletions services/comparison/__init__.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,22 @@
import asyncio
import logging
from dataclasses import dataclass
from typing import List, Optional
from typing import Dict, List, Optional

from shared.reports.changes import get_changes_using_rust, run_comparison_using_rust
from shared.reports.types import Change
from shared.torngit.exceptions import (
TorngitClientGeneralError,
)
from shared.utils.sessions import SessionType

from database.enums import CompareCommitState, TestResultsProcessingError
from database.models import CompareCommit
from helpers.metrics import metrics
from services.archive import ArchiveService
from services.comparison.changes import get_changes
from services.comparison.overlays import get_overlay
from services.comparison.types import Comparison, FullCommit
from services.comparison.types import Comparison, FullCommit, ReportUploadedCount
from services.repository import get_repo_provider_service

log = logging.getLogger(__name__)
Expand Down Expand Up @@ -46,6 +47,7 @@ class ComparisonProxy(object):
Attributes:
comparison (Comparison): The original comparison we want to wrap and proxy
context (ComparisonContext | None): Other information not coverage-related that may affect notifications
"""

def __init__(
Expand All @@ -66,6 +68,7 @@ def __init__(
self._archive_service = None
self._overlays = {}
self.context = context or ComparisonContext()
self._cached_reports_uploaded_per_flag: List[ReportUploadedCount] | None = None

def get_archive_service(self):
if self._archive_service is None:
Expand Down Expand Up @@ -280,6 +283,63 @@ async def get_impacted_files(self):
files_in_diff,
)

def get_reports_uploaded_count_per_flag(self) -> List[ReportUploadedCount]:
"""This function counts how many reports (by flag) the BASE and HEAD commit have."""
if self._cached_reports_uploaded_per_flag:
# Reports may have many sessions, so it's useful to memoize this function
return self._cached_reports_uploaded_per_flag
if not self.has_head_report() or not self.has_project_coverage_base_report():
log.warning(
"Can't calculate diff in uploads. Missing some report",
extra=dict(
has_head_report=self.has_head_report(),
has_project_base_report=self.has_project_coverage_base_report(),
),
)
return []
per_flag_dict: Dict[str, ReportUploadedCount] = dict()
base_report = self.comparison.project_coverage_base.report
head_report = self.comparison.head.report
ops = [(base_report, "base_count"), (head_report, "head_count")]
for curr_report, curr_counter in ops:
for session in curr_report.sessions:
# We ignore carryforward sessions
# Because not all commits would upload all flags (potentially)
# But they are still carried forward
if session.session_type != SessionType.carriedforward:
if session.flags == []:
session.flags = [""]
for flag in session.flags:
dict_value = per_flag_dict.get(flag)
if dict_value is None:
dict_value = ReportUploadedCount(
flag=flag, base_count=0, head_count=0
)
dict_value[curr_counter] += 1
per_flag_dict[flag] = dict_value
self._cached_reports_uploaded_per_flag = list(per_flag_dict.values())
return self._cached_reports_uploaded_per_flag

def get_reports_uploaded_count_per_flag_diff(self) -> List[ReportUploadedCount]:
"""
Returns the difference, per flag, or reports uploaded in BASE and HEAD
❗️ For a difference to be considered there must be at least 1 "uploaded" upload in both
BASE and HEAD (that is, if all reports for a flag are "carryforward" it's not considered a diff)
"""
reports_per_flag = self.get_reports_uploaded_count_per_flag()

def is_valid_diff(obj: ReportUploadedCount):
return (
obj["base_count"] > 0
and obj["head_count"] > 0
and obj["base_count"] != obj["head_count"]
)

per_flag_diff = list(filter(is_valid_diff, reports_per_flag))
self._cached_reports_uploaded_per_flag = per_flag_diff
return per_flag_diff


class FilteredComparison(object):
def __init__(self, real_comparison: ComparisonProxy, *, flags, path_patterns):
Expand Down
117 changes: 117 additions & 0 deletions services/comparison/tests/unit/test_reports_uploaded_count_diff.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
from unittest.mock import MagicMock

import pytest
from shared.reports.resources import Report
from shared.utils.sessions import Session, SessionType

from services.comparison import ComparisonProxy
from services.comparison.types import Comparison, FullCommit, ReportUploadedCount


@pytest.mark.parametrize(
"head_sessions, base_sessions, expected_count, expected_diff",
[
(
[
Session(
flags=["unit", "local"], session_type=SessionType.carriedforward
),
Session(flags=["integration"], session_type=SessionType.uploaded),
Session(flags=["unit"], session_type=SessionType.uploaded),
Session(flags=["unit"], session_type=SessionType.uploaded),
Session(flags=["integration"], session_type=SessionType.uploaded),
Session(flags=[], session_type=SessionType.uploaded),
],
[
Session(
flags=["unit", "local"], session_type=SessionType.carriedforward
),
Session(flags=["integration"], session_type=SessionType.carriedforward),
Session(flags=["unit"], session_type=SessionType.uploaded),
Session(flags=["unit"], session_type=SessionType.uploaded),
],
[
ReportUploadedCount(flag="unit", base_count=2, head_count=2),
ReportUploadedCount(flag="integration", base_count=0, head_count=2),
ReportUploadedCount(flag="", base_count=0, head_count=1),
],
[],
),
(
[
Session(
flags=["unit", "local"], session_type=SessionType.carriedforward
),
Session(flags=["integration"], session_type=SessionType.uploaded),
Session(flags=["unit"], session_type=SessionType.uploaded),
Session(flags=["unit"], session_type=SessionType.uploaded),
Session(flags=["integration"], session_type=SessionType.uploaded),
Session(flags=[""], session_type=SessionType.uploaded),
],
[
Session(flags=["unit", "local"], session_type=SessionType.uploaded),
Session(flags=["integration"], session_type=SessionType.uploaded),
Session(flags=["unit"], session_type=SessionType.uploaded),
Session(flags=["unit"], session_type=SessionType.uploaded),
Session(flags=["obscure_flag"], session_type=SessionType.uploaded),
],
[
ReportUploadedCount(flag="unit", base_count=3, head_count=2),
ReportUploadedCount(flag="local", base_count=1, head_count=0),
ReportUploadedCount(flag="integration", base_count=1, head_count=2),
ReportUploadedCount(flag="obscure_flag", base_count=1, head_count=0),
ReportUploadedCount(flag="", base_count=0, head_count=1),
],
[
ReportUploadedCount(flag="unit", base_count=3, head_count=2),
ReportUploadedCount(flag="integration", base_count=1, head_count=2),
],
),
],
ids=["flag_counts_no_diff", "flag_count_yes_diff"],
)
def test_get_reports_uploaded_count_per_flag(
head_sessions, base_sessions, expected_count, expected_diff
):
head_report = Report()
head_report.sessions = head_sessions
base_report = Report()
base_report.sessions = base_sessions
comparison_proxy = ComparisonProxy(
comparison=Comparison(
head=FullCommit(report=head_report, commit=None),
project_coverage_base=FullCommit(report=base_report, commit=None),
patch_coverage_base_commitid=None,
enriched_pull=None,
)
)
# Python Dicts preserve order, so we can actually test this equality
# See more https://stackoverflow.com/a/39537308
assert comparison_proxy.get_reports_uploaded_count_per_flag() == expected_count
assert comparison_proxy.get_reports_uploaded_count_per_flag_diff() == expected_diff


def test_get_reports_uploaded_count_per_flag_cached():
comparison_proxy = ComparisonProxy(comparison=MagicMock(name="fake_comparison"))
comparison_proxy._cached_reports_uploaded_per_flag = (
"object_that_doesnt_have_this_shape"
)
assert (
comparison_proxy.get_reports_uploaded_count_per_flag()
== "object_that_doesnt_have_this_shape"
)


def test_get_reports_uploaded_count_per_flag_diff_missing_report():
head_report = None
base_report = Report()
base_report.sessions = None
comparison_proxy = ComparisonProxy(
comparison=Comparison(
head=FullCommit(report=head_report, commit=None),
project_coverage_base=FullCommit(report=base_report, commit=None),
patch_coverage_base_commitid=None,
enriched_pull=None,
)
)
assert comparison_proxy.get_reports_uploaded_count_per_flag_diff() == []
8 changes: 7 additions & 1 deletion services/comparison/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from dataclasses import dataclass
from typing import Optional
from typing import Optional, TypedDict

from shared.reports.resources import Report
from shared.yaml import UserYaml
Expand All @@ -14,6 +14,12 @@ class FullCommit(object):
report: Report


class ReportUploadedCount(TypedDict):
flag: str = ""
base_count: int = 0
head_count: int = 0


@dataclass
class Comparison(object):
head: FullCommit
Expand Down

0 comments on commit d6168c2

Please sign in to comment.