Skip to content

Commit

Permalink
WIP: Make ComparisonProxy sync
Browse files Browse the repository at this point in the history
  • Loading branch information
Swatinem committed Sep 20, 2024
1 parent 7e77395 commit 0098f86
Show file tree
Hide file tree
Showing 20 changed files with 486 additions and 612 deletions.
322 changes: 149 additions & 173 deletions services/comparison/__init__.py

Large diffs are not rendered by default.

4 changes: 2 additions & 2 deletions services/comparison/overlays/critical_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,11 +122,11 @@ async def search_files_for_critical_changes(
)
return list(critical_files_from_profiling | critical_files_from_yaml)

async def find_impacted_endpoints(self):
def find_impacted_endpoints(self):
analyzer = self.full_analyzer
if analyzer is None:
return None
diff = rustify_diff(await self._comparison.get_diff())
diff = rustify_diff(self._comparison.get_diff())
return self.full_analyzer.find_impacted_endpoints(
self._comparison.project_coverage_base.report.rust_report.get_report(),
self._comparison.head.report.rust_report.get_report(),
Expand Down
25 changes: 10 additions & 15 deletions services/comparison/tests/unit/test_comparison_proxy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
from mock import call, patch

from database.tests.factories import CommitFactory, PullFactory, RepositoryFactory
Expand Down Expand Up @@ -44,12 +43,11 @@ def make_sample_comparison(adjusted_base=False):
class TestComparisonProxy(object):
compare_url = "https://api.github.com/repos/{}/compare/{}...{}"

@pytest.mark.asyncio
@patch("shared.torngit.github.Github.get_compare")
async def test_get_diff_adjusted_base(self, mock_get_compare):
def test_get_diff_adjusted_base(self, mock_get_compare):
comparison = make_sample_comparison(adjusted_base=True)
mock_get_compare.return_value = {"diff": "magic string"}
result = await comparison.get_diff(use_original_base=False)
result = comparison.get_diff(use_original_base=False)

assert result == "magic string"
assert comparison._adjusted_base_diff == "magic string"
Expand All @@ -67,12 +65,11 @@ async def test_get_diff_adjusted_base(self, mock_get_compare):
),
]

@pytest.mark.asyncio
@patch("shared.torngit.github.Github.get_compare")
async def test_get_diff_original_base(self, mock_get_compare):
def test_get_diff_original_base(self, mock_get_compare):
comparison = make_sample_comparison(adjusted_base=True)
mock_get_compare.return_value = {"diff": "magic string"}
result = await comparison.get_diff(use_original_base=True)
result = comparison.get_diff(use_original_base=True)

assert result == "magic string"
assert comparison._original_base_diff == "magic string"
Expand All @@ -90,12 +87,11 @@ async def test_get_diff_original_base(self, mock_get_compare):
),
]

@pytest.mark.asyncio
@patch("shared.torngit.github.Github.get_compare")
async def test_get_diff_bases_match_original_base(self, mock_get_compare):
def test_get_diff_bases_match_original_base(self, mock_get_compare):
comparison = make_sample_comparison(adjusted_base=False)
mock_get_compare.return_value = {"diff": "magic string"}
result = await comparison.get_diff(use_original_base=True)
result = comparison.get_diff(use_original_base=True)

assert result == "magic string"
assert comparison._original_base_diff == "magic string"
Expand All @@ -106,7 +102,7 @@ async def test_get_diff_bases_match_original_base(self, mock_get_compare):

# In this test case, the adjusted and original base commits are the
# same. If we get one, we should set the cache for the other.
adjusted_base_result = await comparison.get_diff(use_original_base=False)
adjusted_base_result = comparison.get_diff(use_original_base=False)
assert comparison._adjusted_base_diff == "magic string"

# Make sure we only called the Git provider API once
Expand All @@ -118,12 +114,11 @@ async def test_get_diff_bases_match_original_base(self, mock_get_compare):
),
]

@pytest.mark.asyncio
@patch("shared.torngit.github.Github.get_compare")
async def test_get_diff_bases_match_adjusted_base(self, mock_get_compare):
def test_get_diff_bases_match_adjusted_base(self, mock_get_compare):
comparison = make_sample_comparison(adjusted_base=False)
mock_get_compare.return_value = {"diff": "magic string"}
result = await comparison.get_diff(use_original_base=False)
result = comparison.get_diff(use_original_base=False)

assert result == "magic string"
assert comparison._adjusted_base_diff == "magic string"
Expand All @@ -134,7 +129,7 @@ async def test_get_diff_bases_match_adjusted_base(self, mock_get_compare):

# In this test case, the adjusted and original base commits are the
# same. If we get one, we should set the cache for the other.
adjusted_base_result = await comparison.get_diff(use_original_base=True)
adjusted_base_result = comparison.get_diff(use_original_base=True)
assert comparison._adjusted_base_diff == "magic string"

# Make sure we only called the Git provider API once
Expand Down
6 changes: 3 additions & 3 deletions services/comparison/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ class FullCommit(object):


class ReportUploadedCount(TypedDict):
flag: str = ""
base_count: int = 0
head_count: int = 0
flag: str
base_count: int
head_count: int


@dataclass
Expand Down
6 changes: 3 additions & 3 deletions services/notification/notifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class NotificationResult(object):
notification_attempted: bool = False
notification_successful: bool = False
explanation: str = None
explanation: str | None = None
data_sent: Mapping[str, Any] | None = None
data_received: Mapping[str, Any] | None = None
github_app_used: int | None = None
Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(
notifier_yaml_settings: Mapping[str, Any],
notifier_site_settings: Mapping[str, Any],
current_yaml: Mapping[str, Any],
decoration_type: Decoration = None,
decoration_type: Decoration | None = None,
gh_installation_name_to_use: str = GITHUB_APP_INSTALLATION_DEFAULT_NAME,
):
"""
Expand All @@ -78,7 +78,7 @@ def __init__(
def name(self) -> str:
raise NotImplementedError()

async def notify(self, comparison: Comparison, **extra_data) -> NotificationResult:
def notify(self, comparison: Comparison, **extra_data) -> NotificationResult:
raise NotImplementedError()

def is_enabled(self) -> bool:
Expand Down
58 changes: 24 additions & 34 deletions services/notification/notifiers/checks/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from contextlib import nullcontext
from typing import Dict

from asgiref.sync import async_to_sync
from shared.torngit.exceptions import TorngitClientError, TorngitError

from helpers.metrics import metrics
from services.notification.notifiers.base import Comparison, NotificationResult
from services.notification.notifiers.status.base import StatusNotifier
from services.urls import (
Expand Down Expand Up @@ -64,14 +64,14 @@ def paginate_annotations(self, annotations):
for i in range(0, len(annotations), self.ANNOTATIONS_PER_REQUEST):
yield annotations[i : i + self.ANNOTATIONS_PER_REQUEST]

async def build_payload(self, comparison) -> Dict[str, str]:
def build_payload(self, comparison) -> Dict[str, str]:
raise NotImplementedError()

def get_status_external_name(self) -> str:
status_piece = f"/{self.title}" if self.title != "default" else ""
return f"codecov/{self.context}{status_piece}"

async def notify(self, comparison: Comparison):
def notify(self, comparison: Comparison):
if comparison.pull is None or ():
log.debug(
"Falling back to commit_status: Not a pull request",
Expand Down Expand Up @@ -152,7 +152,7 @@ async def notify(self, comparison: Comparison):
)
)
if not comparison.has_head_report():
payload = await self.build_payload(comparison)
payload = self.build_payload(comparison)
elif (
flag_coverage_not_uploaded_behavior == "exclude"
and not self.flag_coverage_was_uploaded(comparison)
Expand All @@ -171,7 +171,7 @@ async def notify(self, comparison: Comparison):
filtered_comparison = comparison.get_filtered_comparison(
**self.get_notifier_filters()
)
payload = await self.build_payload(filtered_comparison)
payload = self.build_payload(filtered_comparison)
payload["state"] = "success"
payload["output"]["summary"] = (
payload.get("output", {}).get("summary", "")
Expand All @@ -181,12 +181,12 @@ async def notify(self, comparison: Comparison):
filtered_comparison = comparison.get_filtered_comparison(
**self.get_notifier_filters()
)
payload = await self.build_payload(filtered_comparison)
payload = self.build_payload(filtered_comparison)
if comparison.pull:
payload["url"] = get_pull_url(comparison.pull)
else:
payload["url"] = get_commit_url(comparison.head.commit)
return await self.maybe_send_notification(comparison, payload)
return self.maybe_send_notification(comparison, payload)
except TorngitClientError as e:
if e.code == 403:
raise e
Expand Down Expand Up @@ -316,7 +316,6 @@ def get_lines_to_annotate(self, comparison, files_with_change):
previous_line = line
return line_headers

@metrics.timer("worker.services.notifications.notifiers.checks.create_annotations")
def create_annotations(self, comparison, diff):
files_with_change = [
{"type": _diff["type"], "path": path, "segments": _diff["segments"]}
Expand All @@ -343,7 +342,7 @@ def create_annotations(self, comparison, diff):
annotations.append(annotation)
return annotations

async def send_notification(self, comparison: Comparison, payload):
def send_notification(self, comparison: Comparison, payload):
title = self.get_status_external_name()
head = comparison.head.commit
repository_service = self.repository_service(head)
Expand Down Expand Up @@ -382,12 +381,9 @@ async def send_notification(self, comparison: Comparison, payload):
)

# We need to first create the check run, get that id and update the status
with metrics.timer(
"worker.services.notifications.notifiers.checks.create_check_run"
):
check_id = await repository_service.create_check_run(
check_name=title, head_sha=head.commitid
)
check_id = async_to_sync(repository_service.create_check_run)(
check_name=title, head_sha=head.commitid
)

if len(output.get("annotations", [])) > self.ANNOTATIONS_PER_REQUEST:
annotation_pages = list(
Expand All @@ -401,27 +397,21 @@ async def send_notification(self, comparison: Comparison, payload):
),
)
for annotation_page in annotation_pages:
with metrics.timer(
"worker.services.notifications.notifiers.checks.update_check_run"
):
await repository_service.update_check_run(
check_id,
state,
output={
"title": output.get("title"),
"summary": output.get("summary"),
"annotations": annotation_page,
},
url=payload.get("url"),
)
async_to_sync(repository_service.update_check_run)(
check_id,
state,
output={
"title": output.get("title"),
"summary": output.get("summary"),
"annotations": annotation_page,
},
url=payload.get("url"),
)

else:
with metrics.timer(
"worker.services.notifications.notifiers.checks.update_check_run"
):
await repository_service.update_check_run(
check_id, state, output=output, url=payload.get("url")
)
async_to_sync(repository_service.update_check_run)(
check_id, state, output=output, url=payload.get("url")
)

return NotificationResult(
notification_attempted=True,
Expand Down
74 changes: 35 additions & 39 deletions services/notification/notifiers/checks/patch.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from database.enums import Notification
from helpers.metrics import metrics
from services.notification.notifiers.base import Comparison
from services.notification.notifiers.checks.base import ChecksNotifier
from services.notification.notifiers.mixins.status import StatusPatchMixin
Expand All @@ -13,7 +12,7 @@ class PatchChecksNotifier(StatusPatchMixin, ChecksNotifier):
def notification_type(self) -> Notification:
return Notification.checks_patch

async def build_payload(self, comparison: Comparison):
def build_payload(self, comparison: Comparison):
"""
This method build the paylod of the patch github checks.
Expand All @@ -29,52 +28,49 @@ async def build_payload(self, comparison: Comparison):
"summary": message,
},
}
with metrics.timer(
"worker.services.notifications.notifiers.checks.patch.build_payload"
):
state, message = await self.get_patch_status(comparison)
codecov_link = self.get_codecov_pr_link(comparison)

title = message
state, message = self.get_patch_status(comparison)
codecov_link = self.get_codecov_pr_link(comparison)

should_use_upgrade = self.should_use_upgrade_decoration()
if should_use_upgrade:
message = self.get_upgrade_message(comparison)
title = "Codecov Report"
title = message

checks_yaml_field = read_yaml_field(self.current_yaml, ("github_checks",))
should_use_upgrade = self.should_use_upgrade_decoration()
if should_use_upgrade:
message = self.get_upgrade_message(comparison)
title = "Codecov Report"

should_annotate = (
checks_yaml_field.get("annotations", False)
if checks_yaml_field is not None
else True
)
checks_yaml_field = read_yaml_field(self.current_yaml, ("github_checks",))

flags = self.notifier_yaml_settings.get("flags")
paths = self.notifier_yaml_settings.get("paths")
if (
flags is not None
or paths is not None
or should_use_upgrade
or should_annotate is False
):
return {
"state": state,
"output": {
"title": f"{title}",
"summary": "\n\n".join([codecov_link, message]),
},
}
diff = await comparison.get_diff(use_original_base=True)
# TODO: Look into why the apply diff in get_patch_status is not saving state at this point
comparison.head.report.apply_diff(diff)
annotations = self.create_annotations(comparison, diff)
should_annotate = (
checks_yaml_field.get("annotations", False)
if checks_yaml_field is not None
else True
)

flags = self.notifier_yaml_settings.get("flags")
paths = self.notifier_yaml_settings.get("paths")
if (
flags is not None
or paths is not None
or should_use_upgrade
or should_annotate is False
):
return {
"state": state,
"output": {
"title": f"{title}",
"summary": "\n\n".join([codecov_link, message]),
"annotations": annotations,
},
}
diff = comparison.get_diff(use_original_base=True)
# TODO: Look into why the apply diff in get_patch_status is not saving state at this point
comparison.head.report.apply_diff(diff)
annotations = self.create_annotations(comparison, diff)

return {
"state": state,
"output": {
"title": f"{title}",
"summary": "\n\n".join([codecov_link, message]),
"annotations": annotations,
},
}
Loading

0 comments on commit 0098f86

Please sign in to comment.