Skip to content

Commit

Permalink
Fix providers compact tests
Browse files Browse the repository at this point in the history
- Copy old test case that use read or _read methods
- Add mark_test_for_stream_based_read_log_method and
  mark_test_for_old_read_log_method to skip corresponding CI tests
  • Loading branch information
jason810496 committed Jan 8, 2025
1 parent bf542f4 commit b0d2a5d
Show file tree
Hide file tree
Showing 9 changed files with 659 additions and 3 deletions.
36 changes: 36 additions & 0 deletions providers/tests/amazon/aws/log/test_cloudwatch_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
from airflow.utils.timezone import datetime

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.file_task_handler import (
mark_test_for_old_read_log_method,
mark_test_for_stream_based_read_log_method,
)
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS


Expand Down Expand Up @@ -130,6 +134,7 @@ def test_event_to_str(self):
]
)

@mark_test_for_old_read_log_method
def test_read(self):
# Confirmed via AWS Support call:
# CloudWatch events must be ordered chronologically otherwise
Expand All @@ -147,6 +152,37 @@ def test_read(self):
],
)

msg_template = "*** Reading remote log from Cloudwatch log_group: {} log_stream: {}.\n{}\n"
events = "\n".join(
[
f"[{get_time_str(current_time-2000)}] First",
f"[{get_time_str(current_time-1000)}] Second",
f"[{get_time_str(current_time)}] Third",
]
)
assert self.cloudwatch_task_handler.read(self.ti) == (
[[("", msg_template.format(self.remote_log_group, self.remote_log_stream, events))]],
[{"end_of_log": True}],
)

@mark_test_for_stream_based_read_log_method
def test_stream_based_read(self):
# Confirmed via AWS Support call:
# CloudWatch events must be ordered chronologically otherwise
# boto3 put_log_event API throws InvalidParameterException
# (moto does not throw this exception)
current_time = int(time.time()) * 1000
generate_log_events(
self.conn,
self.remote_log_group,
self.remote_log_stream,
[
{"timestamp": current_time - 2000, "message": "First"},
{"timestamp": current_time - 1000, "message": "Second"},
{"timestamp": current_time, "message": "Third"},
],
)

msg_template = "*** Reading remote log from Cloudwatch log_group: {} log_stream: {}.\n{}"
events = "\n".join(
[
Expand Down
32 changes: 30 additions & 2 deletions providers/tests/amazon/aws/log/test_s3_task_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,10 @@
from airflow.utils.timezone import datetime

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.file_task_handler import (
mark_test_for_old_read_log_method,
mark_test_for_stream_based_read_log_method,
)
from tests_common.test_utils.version_compat import AIRFLOW_V_3_0_PLUS


Expand Down Expand Up @@ -126,24 +130,48 @@ def test_set_context_not_raw(self):
mock_open.assert_called_once_with(os.path.join(self.local_log_location, "1.log"), "w")
mock_open().write.assert_not_called()

@mark_test_for_old_read_log_method
def test_read(self):
self.conn.put_object(Bucket="bucket", Key=self.remote_log_key, Body=b"Log line\n")
ti = copy.copy(self.ti)
ti.state = TaskInstanceState.SUCCESS
log, metadata = self.s3_task_handler.read(ti)
actual = log[0][0][-1]
assert "*** Found logs in s3:\n*** * s3://bucket/remote/log/location/1.log\n" in actual
assert actual.endswith("Log line")
assert metadata == [{"end_of_log": True, "log_pos": 8}]

@mark_test_for_stream_based_read_log_method
def test_stream_based_read(self):
self.conn.put_object(Bucket="bucket", Key=self.remote_log_key, Body=b"Log line\n")
ti = copy.copy(self.ti)
ti.state = TaskInstanceState.SUCCESS
read_result = self.s3_task_handler.read(ti)
print("read_result", read_result)
_, log_streams, metadata_array = read_result
log_str = "".join(line for line in log_streams[0])
assert "*** Found logs in s3:\n*** * s3://bucket/remote/log/location/1.log\n" in log_str
assert log_str.endswith("Log line\n")
assert metadata_array == [{"end_of_log": True, "log_pos": 9}]

@mark_test_for_old_read_log_method
def test_read_when_s3_log_missing(self):
ti = copy.copy(self.ti)
ti.state = TaskInstanceState.SUCCESS
self.s3_task_handler._read_from_logs_server = mock.Mock(return_value=([], []))
log, metadata = self.s3_task_handler.read(ti)
assert len(log) == 1
assert len(log) == len(metadata)
actual = log[0][0][-1]
expected = "*** No logs found on s3 for ti=<TaskInstance: dag_for_testing_s3_task_handler.task_for_testing_s3_log_handler test [success]>\n"
assert expected in actual
assert metadata[0] == {"end_of_log": True, "log_pos": 0}

@mark_test_for_stream_based_read_log_method
def test_stream_based_read_when_s3_log_missing(self):
ti = copy.copy(self.ti)
ti.state = TaskInstanceState.SUCCESS
self.s3_task_handler._read_from_logs_server = mock.Mock(return_value=([], [], 0))
read_result = self.s3_task_handler.read(ti)
print("read_result", read_result)
_, log_streams, metadata_array = read_result
assert len(log_streams) == 1
assert len(log_streams) == len(metadata_array)
Expand Down
32 changes: 31 additions & 1 deletion providers/tests/celery/log_handlers/test_log_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,11 @@
from airflow.utils.types import DagRunType

from tests_common.test_utils.config import conf_vars
from tests_common.test_utils.file_task_handler import log_str_to_parsed_log_stream
from tests_common.test_utils.file_task_handler import (
log_str_to_parsed_log_stream,
mark_test_for_old_read_log_method,
mark_test_for_stream_based_read_log_method,
)

pytestmark = pytest.mark.db_test

Expand All @@ -61,7 +65,33 @@ def setup_method(self):
def teardown_method(self):
self.clean_up()

@mark_test_for_old_read_log_method
def test__read_for_celery_executor_fallbacks_to_worker(self, create_task_instance):
"""Test for executors which do not have `get_task_log` method, it fallbacks to reading
log from worker"""
executor_name = "CeleryExecutor"
ti = create_task_instance(
dag_id="dag_for_testing_celery_executor_log_read",
task_id="task_for_testing_celery_executor_log_read",
run_type=DagRunType.SCHEDULED,
logical_date=DEFAULT_DATE,
)
ti.state = TaskInstanceState.RUNNING
ti.try_number = 1
with conf_vars({("core", "executor"): executor_name}):
reload(executor_loader)
fth = FileTaskHandler("")

fth._read_from_logs_server = mock.Mock()
fth._read_from_logs_server.return_value = ["this message"], ["this\nlog\ncontent"]
actual = fth._read(ti=ti, try_number=1)
fth._read_from_logs_server.assert_called_once()
assert "*** this message\n" in actual[0]
assert actual[0].endswith("this\nlog\ncontent")
assert actual[1] == {"end_of_log": False, "log_pos": 16}

@mark_test_for_stream_based_read_log_method
def test_stream_based__read_for_celery_executor_fallbacks_to_worker(self, create_task_instance):
"""Test for executors which do not have `get_task_log` method, it fallbacks to reading
log from worker"""
executor_name = "CeleryExecutor"
Expand Down
Loading

0 comments on commit b0d2a5d

Please sign in to comment.