Skip to content

Commit

Permalink
WIP: Make ComparisonProxy sync
Browse files Browse the repository at this point in the history
  • Loading branch information
Swatinem committed Sep 20, 2024
1 parent 7e77395 commit 1444f9c
Show file tree
Hide file tree
Showing 30 changed files with 915 additions and 1,276 deletions.
322 changes: 149 additions & 173 deletions services/comparison/__init__.py

Large diffs are not rendered by default.

15 changes: 7 additions & 8 deletions services/comparison/overlays/critical_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import re
from typing import Sequence

from asgiref.sync import async_to_sync
from cc_rustyribs import rustify_diff
from shared.profiling import ProfilingDataFullAnalyzer, ProfilingSummaryDataAnalyzer
from shared.storage.exceptions import FileNotInStorageError
Expand Down Expand Up @@ -76,7 +77,7 @@ def full_analyzer(self):
self._profiling_analyzer = _load_full_profiling_analyzer(self._comparison)
return self._profiling_analyzer

async def _get_critical_files_from_yaml(self, filenames_to_search: Sequence[str]):
def _get_critical_files_from_yaml(self, filenames_to_search: Sequence[str]):
"""
Get list of files in filenames_to_search that match the list of critical_file paths defined by the user in the YAML (under profiling.critical_files_paths)
"""
Expand All @@ -87,7 +88,7 @@ async def _get_critical_files_from_yaml(self, filenames_to_search: Sequence[str]
repo_provider = get_repo_provider_service(
repo, installation_name_to_use=gh_app_installation_name
)
current_yaml = await get_current_yaml(
current_yaml = async_to_sync(get_current_yaml)(
self._comparison.head.commit, repo_provider
)
if not current_yaml.get("profiling") or not current_yaml["profiling"].get(
Expand All @@ -103,9 +104,7 @@ async def _get_critical_files_from_yaml(self, filenames_to_search: Sequence[str]
]
return user_defined_critical_files

async def search_files_for_critical_changes(
self, filenames_to_search: Sequence[str]
):
def search_files_for_critical_changes(self, filenames_to_search: Sequence[str]):
"""
Returns list of files considered critical in filenames_to_search.
Critical files comes from 2 sources:
Expand All @@ -118,15 +117,15 @@ async def search_files_for_critical_changes(
self._critical_path_report.get_critical_files_filenames()
)
critical_files_from_yaml = set(
await self._get_critical_files_from_yaml(filenames_to_search)
self._get_critical_files_from_yaml(filenames_to_search)
)
return list(critical_files_from_profiling | critical_files_from_yaml)

async def find_impacted_endpoints(self):
def find_impacted_endpoints(self):
analyzer = self.full_analyzer
if analyzer is None:
return None
diff = rustify_diff(await self._comparison.get_diff())
diff = rustify_diff(self._comparison.get_diff())
return self.full_analyzer.find_impacted_endpoints(
self._comparison.project_coverage_base.report.rust_report.get_report(),
self._comparison.head.report.rust_report.get_report(),
Expand Down
40 changes: 14 additions & 26 deletions services/comparison/tests/unit/overlay/test_critical_path.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,22 +143,20 @@ def test_load_critical_path_report_yes_commit_no_storage(
assert _load_full_profiling_analyzer(sample_comparison) is None


@pytest.mark.asyncio
async def test_critical_files_from_yaml_no_paths(mocker, sample_comparison):
def test_critical_files_from_yaml_no_paths(mocker, sample_comparison):
sample_comparison.comparison.current_yaml = dict()
mocked_get_yaml = mocker.patch(
"services.comparison.overlays.critical_path.get_current_yaml"
)
overlay = CriticalPathOverlay(sample_comparison, None)
critical_paths_from_yaml = await overlay._get_critical_files_from_yaml(
critical_paths_from_yaml = overlay._get_critical_files_from_yaml(
["batata.txt", "a.py"]
)
assert critical_paths_from_yaml == []
mocked_get_yaml.assert_not_called()


@pytest.mark.asyncio
async def test_critical_files_from_yaml_with_paths(mocker, sample_comparison):
def test_critical_files_from_yaml_with_paths(mocker, sample_comparison):
sample_comparison.comparison.current_yaml = {
"profiling": {
"critical_files_paths": ["src/critical", "important.txt"],
Expand All @@ -168,15 +166,14 @@ async def test_critical_files_from_yaml_with_paths(mocker, sample_comparison):
"services.comparison.overlays.critical_path.get_current_yaml"
)
overlay = CriticalPathOverlay(sample_comparison, None)
critical_paths_from_yaml = await overlay._get_critical_files_from_yaml(
critical_paths_from_yaml = overlay._get_critical_files_from_yaml(
["batata.txt", "src/critical/a.py"]
)
assert critical_paths_from_yaml == ["src/critical/a.py"]
mocked_get_yaml.assert_not_called()


@pytest.mark.asyncio
async def test_critical_files_from_yaml_with_paths_get_yaml_from_provider(
def test_critical_files_from_yaml_with_paths_get_yaml_from_provider(
mocker, sample_comparison
):
mocked_get_yaml = mocker.patch(
Expand All @@ -188,27 +185,20 @@ async def test_critical_files_from_yaml_with_paths_get_yaml_from_provider(
},
)
overlay = CriticalPathOverlay(sample_comparison, None)
critical_paths_from_yaml = await overlay._get_critical_files_from_yaml(
critical_paths_from_yaml = overlay._get_critical_files_from_yaml(
["batata.txt", "src/critical/a.py"]
)
assert critical_paths_from_yaml == ["src/critical/a.py"]
mocked_get_yaml.assert_called()


class TestCriticalPathOverlay(object):
@pytest.mark.asyncio
async def test_search_files_for_critical_changes_none_report(
self, sample_comparison
):
def test_search_files_for_critical_changes_none_report(self, sample_comparison):
sample_comparison.comparison.current_yaml = dict()
a = CriticalPathOverlay(sample_comparison, None)
assert (
await a.search_files_for_critical_changes(["filenames", "to", "search"])
== []
)
assert a.search_files_for_critical_changes(["filenames", "to", "search"]) == []

@pytest.mark.asyncio
async def test_search_files_for_critical_changes_none_report_with_yaml_path(
def test_search_files_for_critical_changes_none_report_with_yaml_path(
self, sample_comparison, mocker
):
sample_comparison.comparison.current_yaml = {
Expand All @@ -217,18 +207,16 @@ async def test_search_files_for_critical_changes_none_report_with_yaml_path(
}
}
a = CriticalPathOverlay(sample_comparison, None)
assert await a.search_files_for_critical_changes(
assert a.search_files_for_critical_changes(
["filenames", "to", "search", "important.txt"]
) == ["important.txt"]

@pytest.mark.asyncio
async def test_find_impacted_endpoints_no_analyzer(self, sample_comparison):
def test_find_impacted_endpoints_no_analyzer(self, sample_comparison):
a = CriticalPathOverlay(sample_comparison, None)
a._profiling_analyzer = None
await a.find_impacted_endpoints() is None
a.find_impacted_endpoints() is None

@pytest.mark.asyncio
async def test_find_impacted_endpoints(
def test_find_impacted_endpoints(
self,
dbsession,
sample_comparison,
Expand Down Expand Up @@ -288,7 +276,7 @@ async def test_find_impacted_endpoints(
a = CriticalPathOverlay(sample_comparison, None)
print(sample_comparison.head.report.files)
print(sample_comparison.head.report.files)
res = await a.find_impacted_endpoints()
res = a.find_impacted_endpoints()
assert res == [
{
"files": [{"filename": "file_1.go", "impacted_base_lines": [5]}],
Expand Down
21 changes: 8 additions & 13 deletions services/comparison/tests/unit/test_behind_by.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
import pytest
from shared.torngit.exceptions import TorngitClientGeneralError

from services.comparison import ComparisonProxy


class TestGetBehindBy(object):
@pytest.mark.asyncio
async def test_get_behind_by(self, mocker, mock_repo_provider):
def test_get_behind_by(self, mocker, mock_repo_provider):
comparison = ComparisonProxy(mocker.MagicMock())
comparison.comparison.enriched_pull.provider_pull = {"base": {"branch": "a"}}
mock_repo_provider.get_branches.return_value = [("a", "1")]
Expand All @@ -18,25 +16,22 @@ async def test_get_behind_by(self, mocker, mock_repo_provider):
"services.comparison.get_repo_provider_service",
return_value=mock_repo_provider,
)
res = await comparison.get_behind_by()
res = comparison.get_behind_by()
assert res == 3

@pytest.mark.asyncio
async def test_get_behind_by_no_base_commit(self, mocker):
def test_get_behind_by_no_base_commit(self, mocker):
comparison = ComparisonProxy(mocker.MagicMock())
del comparison.comparison.project_coverage_base.commit.commitid
res = await comparison.get_behind_by()
res = comparison.get_behind_by()
assert res is None

@pytest.mark.asyncio
async def test_get_behind_by_no_provider_pull(self, mocker):
def test_get_behind_by_no_provider_pull(self, mocker):
comparison = ComparisonProxy(mocker.MagicMock())
comparison.comparison.enriched_pull.provider_pull = None
res = await comparison.get_behind_by()
res = comparison.get_behind_by()
assert res is None

@pytest.mark.asyncio
async def test_get_behind_by_no_matching_branches(self, mocker, mock_repo_provider):
def test_get_behind_by_no_matching_branches(self, mocker, mock_repo_provider):
mock_repo_provider.get_branch.side_effect = TorngitClientGeneralError(
404,
None,
Expand All @@ -47,5 +42,5 @@ async def test_get_behind_by_no_matching_branches(self, mocker, mock_repo_provid
return_value=mock_repo_provider,
)
comparison = ComparisonProxy(mocker.MagicMock())
res = await comparison.get_behind_by()
res = comparison.get_behind_by()
assert res is None
25 changes: 10 additions & 15 deletions services/comparison/tests/unit/test_comparison_proxy.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import pytest
from mock import call, patch

from database.tests.factories import CommitFactory, PullFactory, RepositoryFactory
Expand Down Expand Up @@ -44,12 +43,11 @@ def make_sample_comparison(adjusted_base=False):
class TestComparisonProxy(object):
compare_url = "https://api.github.com/repos/{}/compare/{}...{}"

@pytest.mark.asyncio
@patch("shared.torngit.github.Github.get_compare")
async def test_get_diff_adjusted_base(self, mock_get_compare):
def test_get_diff_adjusted_base(self, mock_get_compare):
comparison = make_sample_comparison(adjusted_base=True)
mock_get_compare.return_value = {"diff": "magic string"}
result = await comparison.get_diff(use_original_base=False)
result = comparison.get_diff(use_original_base=False)

assert result == "magic string"
assert comparison._adjusted_base_diff == "magic string"
Expand All @@ -67,12 +65,11 @@ async def test_get_diff_adjusted_base(self, mock_get_compare):
),
]

@pytest.mark.asyncio
@patch("shared.torngit.github.Github.get_compare")
async def test_get_diff_original_base(self, mock_get_compare):
def test_get_diff_original_base(self, mock_get_compare):
comparison = make_sample_comparison(adjusted_base=True)
mock_get_compare.return_value = {"diff": "magic string"}
result = await comparison.get_diff(use_original_base=True)
result = comparison.get_diff(use_original_base=True)

assert result == "magic string"
assert comparison._original_base_diff == "magic string"
Expand All @@ -90,12 +87,11 @@ async def test_get_diff_original_base(self, mock_get_compare):
),
]

@pytest.mark.asyncio
@patch("shared.torngit.github.Github.get_compare")
async def test_get_diff_bases_match_original_base(self, mock_get_compare):
def test_get_diff_bases_match_original_base(self, mock_get_compare):
comparison = make_sample_comparison(adjusted_base=False)
mock_get_compare.return_value = {"diff": "magic string"}
result = await comparison.get_diff(use_original_base=True)
result = comparison.get_diff(use_original_base=True)

assert result == "magic string"
assert comparison._original_base_diff == "magic string"
Expand All @@ -106,7 +102,7 @@ async def test_get_diff_bases_match_original_base(self, mock_get_compare):

# In this test case, the adjusted and original base commits are the
# same. If we get one, we should set the cache for the other.
adjusted_base_result = await comparison.get_diff(use_original_base=False)
adjusted_base_result = comparison.get_diff(use_original_base=False)
assert comparison._adjusted_base_diff == "magic string"

# Make sure we only called the Git provider API once
Expand All @@ -118,12 +114,11 @@ async def test_get_diff_bases_match_original_base(self, mock_get_compare):
),
]

@pytest.mark.asyncio
@patch("shared.torngit.github.Github.get_compare")
async def test_get_diff_bases_match_adjusted_base(self, mock_get_compare):
def test_get_diff_bases_match_adjusted_base(self, mock_get_compare):
comparison = make_sample_comparison(adjusted_base=False)
mock_get_compare.return_value = {"diff": "magic string"}
result = await comparison.get_diff(use_original_base=False)
result = comparison.get_diff(use_original_base=False)

assert result == "magic string"
assert comparison._adjusted_base_diff == "magic string"
Expand All @@ -134,7 +129,7 @@ async def test_get_diff_bases_match_adjusted_base(self, mock_get_compare):

# In this test case, the adjusted and original base commits are the
# same. If we get one, we should set the cache for the other.
adjusted_base_result = await comparison.get_diff(use_original_base=True)
adjusted_base_result = comparison.get_diff(use_original_base=True)
assert comparison._adjusted_base_diff == "magic string"

# Make sure we only called the Git provider API once
Expand Down
6 changes: 3 additions & 3 deletions services/comparison/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ class FullCommit(object):


class ReportUploadedCount(TypedDict):
flag: str = ""
base_count: int = 0
head_count: int = 0
flag: str
base_count: int
head_count: int


@dataclass
Expand Down
19 changes: 7 additions & 12 deletions services/notification/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,7 +233,7 @@ def get_statuses(self, current_flags: List[str]):
for component_status in self._get_component_statuses(current_flags):
yield component_status

async def notify(self, comparison: ComparisonProxy) -> List[NotificationResult]:
def notify(self, comparison: ComparisonProxy) -> list[NotificationResult]:
if not is_properly_licensed(comparison.head.commit.get_db_session()):
log.warning(
"Not sending notifications because the system is not properly licensed"
Expand All @@ -256,18 +256,13 @@ async def notify(self, comparison: ComparisonProxy) -> List[NotificationResult]:
for notifier in self.get_notifiers_instances()
if notifier.is_enabled()
]
results = []
chunk_size = 3
for i in range(0, len(notification_instances), chunk_size):
notification_instances_chunk = notification_instances[i : i + chunk_size]
task_chunk = [
self.notify_individual_notifier(notifier, comparison)
for notifier in notification_instances_chunk
]
results.extend(await asyncio.gather(*task_chunk))
results = [
self.notify_individual_notifier(notifier, comparison)
for notifier in notification_instances
]
return results

async def notify_individual_notifier(
def notify_individual_notifier(
self, notifier: AbstractBaseNotifier, comparison: ComparisonProxy
) -> NotificationResult:
commit = comparison.head.commit
Expand All @@ -292,7 +287,7 @@ async def notify_individual_notifier(
with metrics.timer(
f"worker.services.notifications.notifiers.{notifier.name}"
) as notify_timer:
res = await notifier.notify(comparison)
res = notifier.notify(comparison)
individual_result["result"] = res

notifier.store_results(comparison, res)
Expand Down
6 changes: 3 additions & 3 deletions services/notification/notifiers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
class NotificationResult(object):
notification_attempted: bool = False
notification_successful: bool = False
explanation: str = None
explanation: str | None = None
data_sent: Mapping[str, Any] | None = None
data_received: Mapping[str, Any] | None = None
github_app_used: int | None = None
Expand Down Expand Up @@ -53,7 +53,7 @@ def __init__(
notifier_yaml_settings: Mapping[str, Any],
notifier_site_settings: Mapping[str, Any],
current_yaml: Mapping[str, Any],
decoration_type: Decoration = None,
decoration_type: Decoration | None = None,
gh_installation_name_to_use: str = GITHUB_APP_INSTALLATION_DEFAULT_NAME,
):
"""
Expand All @@ -78,7 +78,7 @@ def __init__(
def name(self) -> str:
raise NotImplementedError()

async def notify(self, comparison: Comparison, **extra_data) -> NotificationResult:
def notify(self, comparison: Comparison, **extra_data) -> NotificationResult:
raise NotImplementedError()

def is_enabled(self) -> bool:
Expand Down
Loading

0 comments on commit 1444f9c

Please sign in to comment.