From 3c3c337edd6a6905c958206dc8f9fe4303c856eb Mon Sep 17 00:00:00 2001 From: Jarek Potiuk Date: Tue, 4 Jul 2023 14:58:29 +0200 Subject: [PATCH] Revert "Refactor Sqlalchemy queries to 2.0 style (Part 3) (#32177)" (#32343) This reverts commit 1065687ec6df2b9b3557e38a67e71f835796427f. --- airflow/utils/db.py | 38 ++-- airflow/www/utils.py | 11 +- airflow/www/views.py | 460 +++++++++++++++++++++---------------------- 3 files changed, 244 insertions(+), 265 deletions(-) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 46ddcbfb3453f..a76f0d4f675d1 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -92,7 +92,7 @@ def _format_airflow_moved_table_name(source_table, version, category): @provide_session def merge_conn(conn, session: Session = NEW_SESSION): """Add new Connection.""" - if not session.scalar(select(conn.__class__).filter_by(conn_id=conn.conn_id).limit(1)): + if not session.query(conn.__class__).filter_by(conn_id=conn.conn_id).first(): session.add(conn) session.commit() @@ -959,9 +959,7 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]: dups = [] try: - dups = session.execute( - select(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1) - ).all() + dups = session.query(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1).all() except (exc.OperationalError, exc.ProgrammingError): # fallback if tables hasn't been created yet session.rollback() @@ -986,11 +984,12 @@ def check_username_duplicates(session: Session) -> Iterable[str]: for model in [User, RegisterUser]: dups = [] try: - dups = session.execute( - select(model.username) # type: ignore[attr-defined] + dups = ( + session.query(model.username) # type: ignore[attr-defined] .group_by(model.username) # type: ignore[attr-defined] .having(func.count() > 1) - ).all() + .all() + ) except (exc.OperationalError, exc.ProgrammingError): # fallback if tables hasn't been created yet session.rollback() @@ -1059,13 +1058,13 @@ def check_task_fail_for_duplicates(session): """ minimal_table_obj = table(table_name, *[column(x) for x in uniqueness]) try: - subquery = session.execute( - select(minimal_table_obj, func.count().label("dupe_count")) + subquery = ( + session.query(minimal_table_obj, func.count().label("dupe_count")) .group_by(*[text(x) for x in uniqueness]) .having(func.count() > text("1")) .subquery() ) - dupe_count = session.scalar(select(func.sum(subquery.c.dupe_count))) + dupe_count = session.query(func.sum(subquery.c.dupe_count)).scalar() if not dupe_count: # there are no duplicates; nothing to do. 
return @@ -1102,7 +1101,7 @@ def check_conn_type_null(session: Session) -> Iterable[str]: n_nulls = [] try: - n_nulls = session.scalars(select(Connection.conn_id).where(Connection.conn_type.is_(None))).all() + n_nulls = session.query(Connection.conn_id).filter(Connection.conn_type.is_(None)).all() except (exc.OperationalError, exc.ProgrammingError, exc.InternalError): # fallback if tables hasn't been created yet session.rollback() @@ -1144,7 +1143,7 @@ def check_run_id_null(session: Session) -> Iterable[str]: dagrun_table.c.run_id.is_(None), dagrun_table.c.execution_date.is_(None), ) - invalid_dagrun_count = session.scalar(select(func.count(dagrun_table.c.id)).where(invalid_dagrun_filter)) + invalid_dagrun_count = session.query(func.count(dagrun_table.c.id)).filter(invalid_dagrun_filter).scalar() if invalid_dagrun_count > 0: dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, "2.2", "dangling") if dagrun_dangling_table_name in inspect(session.get_bind()).get_table_names(): @@ -1241,7 +1240,7 @@ def _move_dangling_data_to_new_table( pk_cols = source_table.primary_key.columns delete = source_table.delete().where( - tuple_(*pk_cols).in_(session.select(*target_table.primary_key.columns).subquery()) + tuple_(*pk_cols).in_(session.query(*target_table.primary_key.columns).subquery()) ) else: delete = source_table.delete().where( @@ -1263,11 +1262,10 @@ def _dangling_against_dag_run(session, source_table, dag_run): source_table.c.dag_id == dag_run.c.dag_id, source_table.c.execution_date == dag_run.c.execution_date, ) - return ( - select(*[c.label(c.name) for c in source_table.c]) + session.query(*[c.label(c.name) for c in source_table.c]) .join(dag_run, source_to_dag_run_join_cond, isouter=True) - .where(dag_run.c.dag_id.is_(None)) + .filter(dag_run.c.dag_id.is_(None)) ) @@ -1306,10 +1304,10 @@ def _dangling_against_task_instance(session, source_table, dag_run, task_instanc ) return ( - select(*[c.label(c.name) for c in source_table.c]) + session.query(*[c.label(c.name) for c in source_table.c]) .join(dag_run, dr_join_cond, isouter=True) .join(task_instance, ti_join_cond, isouter=True) - .where(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None))) + .filter(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None))) ) @@ -1333,9 +1331,9 @@ def _move_duplicate_data_to_new_table( """ bind = session.get_bind() dialect_name = bind.dialect.name - query = ( - select(*[getattr(source_table.c, x.name).label(str(x.name)) for x in source_table.columns]) + session.query(source_table) + .with_entities(*[getattr(source_table.c, x.name).label(str(x.name)) for x in source_table.columns]) .select_from(source_table) .join(subquery, and_(*[getattr(source_table.c, x) == getattr(subquery.c, x) for x in uniqueness])) ) diff --git a/airflow/www/utils.py b/airflow/www/utils.py index b31f9326d988c..76914dd9cdcd4 100644 --- a/airflow/www/utils.py +++ b/airflow/www/utils.py @@ -39,7 +39,6 @@ from pygments.lexer import Lexer from sqlalchemy import delete, func, types from sqlalchemy.ext.associationproxy import AssociationProxy -from sqlalchemy.sql import Select from airflow.exceptions import RemovedInAirflow3Warning from airflow.models import errors @@ -54,6 +53,7 @@ from airflow.www.widgets import AirflowDateTimePickerWidget if TYPE_CHECKING: + from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session from sqlalchemy.sql.operators import ColumnOperators @@ -518,21 +518,18 @@ def _get_run_ordering_expr(name: str) -> ColumnOperators: return expr.desc() 
-def sorted_dag_runs( - query: Select, *, ordering: Sequence[str], limit: int, session: Session -) -> Sequence[DagRun]: +def sorted_dag_runs(query: Query, *, ordering: Sequence[str], limit: int) -> Sequence[DagRun]: """Produce DAG runs sorted by specified columns. - :param query: An ORM select object against *DagRun*. + :param query: An ORM query object against *DagRun*. :param ordering: Column names to sort the runs. should generally come from a timetable's ``run_ordering``. :param limit: Number of runs to limit to. - :param session: SQLAlchemy ORM session object :return: A list of DagRun objects ordered by the specified columns. The list contains only the *last* objects, but in *ascending* order. """ ordering_exprs = (_get_run_ordering_expr(name) for name in ordering) - runs = session.scalars(query.order_by(*ordering_exprs, DagRun.id.desc()).limit(limit)).all() + runs = query.order_by(*ordering_exprs, DagRun.id.desc()).limit(limit).all() runs.reverse() return runs diff --git a/airflow/www/views.py b/airflow/www/views.py index 90d98a424c4a8..3f0965da38898 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -69,7 +69,7 @@ from pendulum.parsing.exceptions import ParserError from pygments import highlight, lexers from pygments.formatters import HtmlFormatter -from sqlalchemy import Date, and_, case, desc, func, inspect, or_, select, union_all +from sqlalchemy import Date, and_, case, desc, func, inspect, or_, union_all from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session, joinedload from wtforms import SelectField, validators @@ -95,7 +95,7 @@ from airflow.jobs.job import Job from airflow.jobs.scheduler_job_runner import SchedulerJobRunner from airflow.jobs.triggerer_job_runner import TriggererJobRunner -from airflow.models import Connection, DagModel, DagTag, Log, SlaMiss, TaskFail, Trigger, XCom, errors +from airflow.models import Connection, DagModel, DagTag, Log, SlaMiss, TaskFail, XCom, errors from airflow.models.abstractoperator import AbstractOperator from airflow.models.dag import DAG, get_dataset_triggered_next_run_info from airflow.models.dagcode import DagCode @@ -231,15 +231,16 @@ def get_date_time_num_runs_dag_runs_form_data(www_request, session, dag): # loaded and the actual requested run would be excluded by the limit(). Once # the user has changed base date to be anything else we want to use that instead. query_date = base_date - if date_time < base_date <= date_time + datetime.timedelta(seconds=1): + if date_time < base_date and date_time + datetime.timedelta(seconds=1) >= base_date: query_date = date_time - drs = session.scalars( - select(DagRun) - .where(DagRun.dag_id == dag.dag_id, DagRun.execution_date <= query_date) + drs = ( + session.query(DagRun) + .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date <= query_date) .order_by(desc(DagRun.execution_date)) .limit(num_runs) - ).all() + .all() + ) dr_choices = [] dr_state = None for dr in drs: @@ -290,8 +291,8 @@ def dag_to_grid(dag: DagModel, dag_runs: Sequence[DagRun], session: Session): Create a nested dict representation of the DAG's TaskGroup and its children used to construct the Graph and Grid views. 
""" - query = session.execute( - select( + query = ( + session.query( TaskInstance.task_id, TaskInstance.run_id, TaskInstance.state, @@ -302,7 +303,7 @@ def dag_to_grid(dag: DagModel, dag_runs: Sequence[DagRun], session: Session): func.max(TaskInstance.end_date).label("end_date"), ) .join(TaskInstance.task_instance_note, isouter=True) - .where( + .filter( TaskInstance.dag_id == dag.dag_id, TaskInstance.run_id.in_([dag_run.run_id for dag_run in dag_runs]), ) @@ -425,9 +426,11 @@ def get_summary(dag_run: DagRun): } def get_mapped_group_summaries(): - mapped_ti_query = session.execute( - select(TaskInstance.task_id, TaskInstance.state, TaskInstance.run_id, TaskInstance.map_index) - .where( + mapped_ti_query = ( + session.query( + TaskInstance.task_id, TaskInstance.state, TaskInstance.run_id, TaskInstance.map_index + ) + .filter( TaskInstance.dag_id == dag.dag_id, TaskInstance.task_id.in_(child["id"] for child in children), TaskInstance.run_id.in_(r.run_id for r in dag_runs), @@ -735,20 +738,21 @@ def index(self): with create_session() as session: # read orm_dags from the db - dags_query = select(DagModel).where(~DagModel.is_subdag, DagModel.is_active) + dags_query = session.query(DagModel).filter(~DagModel.is_subdag, DagModel.is_active) if arg_search_query: escaped_arg_search_query = arg_search_query.replace("_", r"\_") - dags_query = dags_query.where( + dags_query = dags_query.filter( DagModel.dag_id.ilike("%" + escaped_arg_search_query + "%", escape="\\") | DagModel.owners.ilike("%" + escaped_arg_search_query + "%", escape="\\") ) if arg_tags_filter: - dags_query = dags_query.where(DagModel.tags.any(DagTag.name.in_(arg_tags_filter))) + dags_query = dags_query.filter(DagModel.tags.any(DagTag.name.in_(arg_tags_filter))) - dags_query = dags_query.where(DagModel.dag_id.in_(filter_dag_ids)) - filtered_dag_count = session.scalar(select(func.count()).select_from(dags_query)) + dags_query = dags_query.filter(DagModel.dag_id.in_(filter_dag_ids)) + + filtered_dag_count = dags_query.count() if filtered_dag_count == 0 and len(arg_tags_filter): flash( "No matching DAG tags found.", @@ -758,28 +762,28 @@ def index(self): return redirect(url_for("Airflow.index")) all_dags = dags_query - active_dags = dags_query.where(~DagModel.is_paused) - paused_dags = dags_query.where(DagModel.is_paused) + active_dags = dags_query.filter(~DagModel.is_paused) + paused_dags = dags_query.filter(DagModel.is_paused) # find DAGs which have a RUNNING DagRun - running_dags = dags_query.join(DagRun, DagModel.dag_id == DagRun.dag_id).where( + running_dags = dags_query.join(DagRun, DagModel.dag_id == DagRun.dag_id).filter( DagRun.state == State.RUNNING ) # find DAGs for which the latest DagRun is FAILED subq_all = ( - select(DagRun.dag_id, func.max(DagRun.start_date).label("start_date")) + session.query(DagRun.dag_id, func.max(DagRun.start_date).label("start_date")) .group_by(DagRun.dag_id) .subquery() ) subq_failed = ( - select(DagRun.dag_id, func.max(DagRun.start_date).label("start_date")) - .where(DagRun.state == State.FAILED) + session.query(DagRun.dag_id, func.max(DagRun.start_date).label("start_date")) + .filter(DagRun.state == State.FAILED) .group_by(DagRun.dag_id) .subquery() ) subq_join = ( - select(subq_all.c.dag_id, subq_all.c.start_date) + session.query(subq_all.c.dag_id, subq_all.c.start_date) .join( subq_failed, and_( @@ -792,18 +796,16 @@ def index(self): failed_dags = dags_query.join(subq_join, DagModel.dag_id == subq_join.c.dag_id) is_paused_count = dict( - session.execute( - select(DagModel.is_paused, 
func.count(DagModel.dag_id)) - .group_by(DagModel.is_paused) - .select_from(all_dags) - ).all() + all_dags.with_entities(DagModel.is_paused, func.count(DagModel.dag_id)).group_by( + DagModel.is_paused + ) ) status_count_active = is_paused_count.get(False, 0) status_count_paused = is_paused_count.get(True, 0) - status_count_running = session.scalar(select(func.count()).select_from(running_dags)) - status_count_failed = session.scalar(select(func.count()).select_from(failed_dags)) + status_count_running = running_dags.count() + status_count_failed = failed_dags.count() all_dags_count = status_count_active + status_count_paused if arg_status_filter == "active": @@ -824,7 +826,7 @@ def index(self): if arg_sorting_key == "last_dagrun": dag_run_subquery = ( - select( + session.query( DagRun.dag_id, sqla.func.max(DagRun.execution_date).label("max_execution_date"), ) @@ -852,13 +854,7 @@ def index(self): else: current_dags = current_dags.order_by(null_case, sort_column) - dags = ( - session.scalars( - current_dags.options(joinedload(DagModel.tags)).offset(start).limit(dags_per_page) - ) - .unique() - .all() - ) + dags = current_dags.options(joinedload(DagModel.tags)).offset(start).limit(dags_per_page).all() user_permissions = g.user.perms can_create_dag_run = ( permissions.ACTION_CAN_CREATE, @@ -878,7 +874,7 @@ def index(self): dag.can_trigger = dag.can_edit and can_create_dag_run dag.can_delete = get_airflow_app().appbuilder.sm.can_delete_dag(dag.dag_id, g.user) - dagtags = session.execute(select(func.distinct(DagTag.name)).order_by(DagTag.name)).all() + dagtags = session.query(func.distinct(DagTag.name)).order_by(DagTag.name).all() tags = [ {"name": name, "selected": bool(arg_tags_filter and name in arg_tags_filter)} for name, in dagtags @@ -886,15 +882,14 @@ def index(self): owner_links_dict = DagOwnerAttributes.get_all(session) - import_errors = select(errors.ImportError).order_by(errors.ImportError.id) + import_errors = session.query(errors.ImportError).order_by(errors.ImportError.id) if (permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG) not in user_permissions: # if the user doesn't have access to all DAGs, only display errors from visible DAGs import_errors = import_errors.join( DagModel, DagModel.fileloc == errors.ImportError.filename - ).where(DagModel.dag_id.in_(filter_dag_ids)) + ).filter(DagModel.dag_id.in_(filter_dag_ids)) - import_errors = session.scalars(import_errors) for import_error in import_errors: flash( f"Broken DAG: [{import_error.filename}] {import_error.stacktrace}", @@ -938,12 +933,10 @@ def _iter_parsed_moved_data_table_names(): permissions.RESOURCE_ADMIN_MENU, ) in user_permissions and conf.getboolean("webserver", "warn_deployment_exposure"): robots_file_access_count = ( - select(Log) - .where(Log.event == "robots") - .where(Log.dttm > (utcnow() - datetime.timedelta(days=7))) - ) - robots_file_access_count = session.scalar( - select(func.count()).select_from(robots_file_access_count) + session.query(Log) + .filter(Log.event == "robots") + .filter(Log.dttm > (utcnow() - datetime.timedelta(days=7))) + .count() ) if robots_file_access_count > 0: flash( @@ -1045,11 +1038,9 @@ def next_run_datasets_summary(self, session: Session = NEW_SESSION): dataset_triggered_dag_ids = [ dag.dag_id for dag in ( - session.scalars( - select(DagModel.dag_id) - .where(DagModel.dag_id.in_(filter_dag_ids)) - .where(DagModel.schedule_interval == "Dataset") - ) + session.query(DagModel.dag_id) + .filter(DagModel.dag_id.in_(filter_dag_ids)) + .filter(DagModel.schedule_interval == "Dataset") ) 
] @@ -1080,10 +1071,10 @@ def dag_stats(self, session: Session = NEW_SESSION): if not filter_dag_ids: return flask.json.jsonify({}) - dag_state_stats = session.execute( - select(DagRun.dag_id, DagRun.state, sqla.func.count(DagRun.state)) + dag_state_stats = ( + session.query(DagRun.dag_id, DagRun.state, sqla.func.count(DagRun.state)) .group_by(DagRun.dag_id, DagRun.state) - .where(DagRun.dag_id.in_(filter_dag_ids)) + .filter(DagRun.dag_id.in_(filter_dag_ids)) ) dag_state_data = {(dag_id, state): count for dag_id, state, count in dag_state_stats} @@ -1121,17 +1112,17 @@ def task_stats(self, session: Session = NEW_SESSION): filter_dag_ids = allowed_dag_ids running_dag_run_query_result = ( - select(DagRun.dag_id, DagRun.run_id) + session.query(DagRun.dag_id, DagRun.run_id) .join(DagModel, DagModel.dag_id == DagRun.dag_id) - .where(DagRun.state == State.RUNNING, DagModel.is_active) + .filter(DagRun.state == State.RUNNING, DagModel.is_active) ) - running_dag_run_query_result = running_dag_run_query_result.where(DagRun.dag_id.in_(filter_dag_ids)) + running_dag_run_query_result = running_dag_run_query_result.filter(DagRun.dag_id.in_(filter_dag_ids)) running_dag_run_query_result = running_dag_run_query_result.subquery("running_dag_run") # Select all task_instances from active dag_runs. - running_task_instance_query_result = select( + running_task_instance_query_result = session.query( TaskInstance.dag_id.label("dag_id"), TaskInstance.state.label("state"), sqla.literal(True).label("is_dag_running"), @@ -1145,19 +1136,19 @@ def task_stats(self, session: Session = NEW_SESSION): if conf.getboolean("webserver", "SHOW_RECENT_STATS_FOR_COMPLETED_RUNS", fallback=True): last_dag_run = ( - select(DagRun.dag_id, sqla.func.max(DagRun.execution_date).label("execution_date")) + session.query(DagRun.dag_id, sqla.func.max(DagRun.execution_date).label("execution_date")) .join(DagModel, DagModel.dag_id == DagRun.dag_id) - .where(DagRun.state != State.RUNNING, DagModel.is_active) + .filter(DagRun.state != State.RUNNING, DagModel.is_active) .group_by(DagRun.dag_id) ) - last_dag_run = last_dag_run.where(DagRun.dag_id.in_(filter_dag_ids)) + last_dag_run = last_dag_run.filter(DagRun.dag_id.in_(filter_dag_ids)) last_dag_run = last_dag_run.subquery("last_dag_run") # Select all task_instances from active dag_runs. # If no dag_run is active, return task instances from most recent dag_run. 
last_task_instance_query_result = ( - select( + session.query( TaskInstance.dag_id.label("dag_id"), TaskInstance.state.label("state"), sqla.literal(False).label("is_dag_running"), @@ -1178,8 +1169,8 @@ def task_stats(self, session: Session = NEW_SESSION): else: final_task_instance_query_result = running_task_instance_query_result.subquery("final_ti") - qry = session.execute( - select( + qry = ( + session.query( final_task_instance_query_result.c.dag_id, final_task_instance_query_result.c.state, final_task_instance_query_result.c.is_dag_running, @@ -1195,6 +1186,7 @@ def task_stats(self, session: Session = NEW_SESSION): final_task_instance_query_result.c.is_dag_running.desc(), ) ) + data = get_task_stats_from_query(qry) payload: dict[str, list[dict[str, Any]]] = collections.defaultdict(list) for dag_id in filter_dag_ids: @@ -1227,31 +1219,29 @@ def last_dagruns(self, session: Session = NEW_SESSION): return flask.json.jsonify({}) last_runs_subquery = ( - select( + session.query( DagRun.dag_id, sqla.func.max(DagRun.execution_date).label("max_execution_date"), ) .group_by(DagRun.dag_id) - .where(DagRun.dag_id.in_(filter_dag_ids)) # Only include accessible/selected DAGs. + .filter(DagRun.dag_id.in_(filter_dag_ids)) # Only include accessible/selected DAGs. .subquery("last_runs") ) - query = session.execute( - select( - DagRun.dag_id, - DagRun.start_date, - DagRun.end_date, - DagRun.state, - DagRun.execution_date, - DagRun.data_interval_start, - DagRun.data_interval_end, - ).join( - last_runs_subquery, - and_( - last_runs_subquery.c.dag_id == DagRun.dag_id, - last_runs_subquery.c.max_execution_date == DagRun.execution_date, - ), - ) + query = session.query( + DagRun.dag_id, + DagRun.start_date, + DagRun.end_date, + DagRun.state, + DagRun.execution_date, + DagRun.data_interval_start, + DagRun.data_interval_end, + ).join( + last_runs_subquery, + and_( + last_runs_subquery.c.dag_id == DagRun.dag_id, + last_runs_subquery.c.max_execution_date == DagRun.execution_date, + ), ) resp = { @@ -1350,18 +1340,19 @@ def dag_details(self, dag_id, session: Session = NEW_SESSION): title = "DAG Details" root = request.args.get("root", "") - states = session.execute( - select(TaskInstance.state, sqla.func.count(TaskInstance.dag_id)) - .where(TaskInstance.dag_id == dag_id) + states = ( + session.query(TaskInstance.state, sqla.func.count(TaskInstance.dag_id)) + .filter(TaskInstance.dag_id == dag_id) .group_by(TaskInstance.state) - ).all() + .all() + ) active_runs = models.DagRun.find(dag_id=dag_id, state=DagRunState.RUNNING, external_trigger=False) - tags = session.scalars(select(models.DagTag).where(models.DagTag.dag_id == dag_id)).all() + tags = session.query(models.DagTag).filter(models.DagTag.dag_id == dag_id).all() # TODO: convert this to a relationship - owner_links = session.execute(select(DagOwnerAttributes).filter_by(dag_id=dag_id)).all() + owner_links = session.query(DagOwnerAttributes).filter_by(dag_id=dag_id).all() attrs_to_avoid = [ "schedule_datasets", @@ -1622,17 +1613,18 @@ def get_logs_with_metadata(self, session: Session = NEW_SESSION): "metadata": {"end_of_log": True}, } - ti = session.scalar( - select(models.TaskInstance) - .where( + ti = ( + session.query(models.TaskInstance) + .filter( TaskInstance.task_id == task_id, TaskInstance.dag_id == dag_id, TaskInstance.execution_date == execution_date, TaskInstance.map_index == map_index, ) .join(TaskInstance.dag_run) - .options(joinedload(TaskInstance.trigger).joinedload(Trigger.triggerer_job)) - .limit(1) + .options(joinedload("trigger")) + 
.options(joinedload("trigger.triggerer_job")) + .first() ) if ti is None: @@ -1690,10 +1682,10 @@ def log(self, session: Session = NEW_SESSION): form = DateTimeForm(data={"execution_date": dttm}) dag_model = DagModel.get_dagmodel(dag_id) - ti = session.scalar( - select(models.TaskInstance) + ti = ( + session.query(models.TaskInstance) .filter_by(dag_id=dag_id, task_id=task_id, execution_date=dttm, map_index=map_index) - .limit(1) + .first() ) num_logs = 0 @@ -1734,10 +1726,10 @@ def redirect_to_external_log(self, session: Session = NEW_SESSION): map_index = request.args.get("map_index", -1, type=int) try_number = request.args.get("try_number", 1) - ti = session.scalar( - select(models.TaskInstance) + ti = ( + session.query(models.TaskInstance) .filter_by(dag_id=dag_id, task_id=task_id, execution_date=dttm, map_index=map_index) - .limit(1) + .first() ) if not ti: @@ -1779,8 +1771,8 @@ def task(self, session: Session = NEW_SESSION): task = copy.copy(dag.get_task(task_id)) task.resolve_template_files() - ti: TaskInstance | None = session.scalar( - select(TaskInstance) + ti: TaskInstance | None = ( + session.query(TaskInstance) .options( # HACK: Eager-load relationships. This is needed because # multiple properties mis-use provide_session() that destroys @@ -1789,6 +1781,7 @@ def task(self, session: Session = NEW_SESSION): joinedload(TaskInstance.trigger, innerjoin=False), ) .filter_by(execution_date=dttm, dag_id=dag_id, task_id=task_id, map_index=map_index) + .one_or_none() ) if ti is None: ti_attrs: list[tuple[str, Any]] | None = None @@ -1905,19 +1898,17 @@ def xcom(self, session: Session = NEW_SESSION): form = DateTimeForm(data={"execution_date": dttm}) root = request.args.get("root", "") dag = DagModel.get_dagmodel(dag_id) - ti = session.scalar(select(TaskInstance).filter_by(dag_id=dag_id, task_id=task_id).limit(1)) + ti = session.query(TaskInstance).filter_by(dag_id=dag_id, task_id=task_id).first() if not ti: flash(f"Task [{dag_id}.{task_id}] doesn't seem to exist at the moment", "error") return redirect(url_for("Airflow.index")) - xcom_query = session.execute( - select(XCom.key, XCom.value).where( - XCom.dag_id == dag_id, - XCom.task_id == task_id, - XCom.execution_date == dttm, - XCom.map_index == map_index, - ) + xcom_query = session.query(XCom.key, XCom.value).filter( + XCom.dag_id == dag_id, + XCom.task_id == task_id, + XCom.execution_date == dttm, + XCom.map_index == map_index, ) attributes = [(k, v) for k, v in xcom_query if not k.startswith("_")] @@ -1987,7 +1978,7 @@ def trigger(self, dag_id: str, session: Session = NEW_SESSION): request_execution_date = request.values.get("execution_date", default=timezone.utcnow().isoformat()) is_dag_run_conf_overrides_params = conf.getboolean("core", "dag_run_conf_overrides_params") dag = get_airflow_app().dag_bag.get_dag(dag_id) - dag_orm: DagModel = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id).limit(1)) + dag_orm: DagModel = session.query(DagModel).filter(DagModel.dag_id == dag_id).first() # Prepare form fields with param struct details to render a proper form with schema information form_fields = {} @@ -2025,9 +2016,11 @@ def trigger(self, dag_id: str, session: Session = NEW_SESSION): flash(f"Cannot create dagruns because the dag {dag_id} has import errors", "error") return redirect(origin) - recent_runs = session.execute( - select(DagRun.conf, func.max(DagRun.run_id).label("run_id"), func.max(DagRun.execution_date)) - .where( + recent_runs = ( + session.query( + DagRun.conf, func.max(DagRun.run_id).label("run_id"), 
func.max(DagRun.execution_date) + ) + .filter( DagRun.dag_id == dag_id, DagRun.run_type == DagRunType.MANUAL, DagRun.conf.isnot(None), @@ -2300,17 +2293,15 @@ def clear(self, *, session: Session = NEW_SESSION): # Lock the related dag runs to prevent from possible dead lock. # https://github.com/apache/airflow/pull/26658 - dag_runs_query = session.scalars( - select(DagRun.id).where(DagRun.dag_id == dag_id).with_for_update() - ) + dag_runs_query = session.query(DagRun.id).filter(DagRun.dag_id == dag_id).with_for_update() if start_date is None and end_date is None: - dag_runs_query = dag_runs_query.where(DagRun.execution_date == start_date) + dag_runs_query = dag_runs_query.filter(DagRun.execution_date == start_date) else: if start_date is not None: - dag_runs_query = dag_runs_query.where(DagRun.execution_date >= start_date) + dag_runs_query = dag_runs_query.filter(DagRun.execution_date >= start_date) if end_date is not None: - dag_runs_query = dag_runs_query.where(DagRun.execution_date <= end_date) + dag_runs_query = dag_runs_query.filter(DagRun.execution_date <= end_date) locked_dag_run_ids = dag_runs_query.all() elif task_id: @@ -2399,10 +2390,10 @@ def blocked(self, session: Session = NEW_SESSION): if not filter_dag_ids: return flask.json.jsonify([]) - dags = session.execute( - select(DagRun.dag_id, sqla.func.count(DagRun.id)) - .where(DagRun.state == DagRunState.RUNNING) - .where(DagRun.dag_id.in_(filter_dag_ids)) + dags = ( + session.query(DagRun.dag_id, sqla.func.count(DagRun.id)) + .filter(DagRun.state == DagRunState.RUNNING) + .filter(DagRun.dag_id.in_(filter_dag_ids)) .group_by(DagRun.dag_id) ) @@ -2483,11 +2474,9 @@ def _mark_dagrun_state_as_queued( # Identify tasks that will be queued up to run when confirmed all_task_ids = [task.task_id for task in dag.tasks] - existing_tis = session.execute( - select(TaskInstance.task_id).where( - TaskInstance.dag_id == dag.dag_id, - TaskInstance.run_id == dag_run_id, - ) + existing_tis = session.query(TaskInstance.task_id).filter( + TaskInstance.dag_id == dag.dag_id, + TaskInstance.run_id == dag_run_id, ) completed_tis_ids = [task_id for task_id, in existing_tis] @@ -2969,18 +2958,19 @@ def _convert_to_date(session, column): if root: dag = dag.partial_subset(task_ids_or_regex=root, include_downstream=False, include_upstream=True) - dag_states = session.execute( - select( - _convert_to_date(session, DagRun.execution_date).label("date"), + dag_states = ( + session.query( + (_convert_to_date(session, DagRun.execution_date)).label("date"), DagRun.state, func.max(DagRun.data_interval_start).label("data_interval_start"), func.max(DagRun.data_interval_end).label("data_interval_end"), func.count("*").label("count"), ) - .where(DagRun.dag_id == dag.dag_id) + .filter(DagRun.dag_id == dag.dag_id) .group_by(_convert_to_date(session, DagRun.execution_date), DagRun.state) .order_by(_convert_to_date(session, DagRun.execution_date).asc()) - ).all() + .all() + ) data_dag_states = [ { @@ -3251,17 +3241,16 @@ def duration(self, dag_id: str, session: Session = NEW_SESSION): else: min_date = timezone.utc_epoch() ti_fails = ( - select(TaskFail) + session.query(TaskFail) .join(TaskFail.dag_run) - .where( + .filter( TaskFail.dag_id == dag.dag_id, DagRun.execution_date >= min_date, DagRun.execution_date <= base_date, ) ) if dag.partial: - ti_fails = ti_fails.where(TaskFail.task_id.in_([t.task_id for t in dag.tasks])) - ti_fails = session.scalars(ti_fails) + ti_fails = ti_fails.filter(TaskFail.task_id.in_([t.task_id for t in dag.tasks])) fails_totals: dict[tuple[str, 
str, str], int] = defaultdict(int) for failed_task_instance in ti_fails: dict_key = ( @@ -3602,9 +3591,9 @@ def gantt(self, dag_id: str, session: Session = NEW_SESSION): form = DateTimeWithNumRunsWithDagRunsForm(data=dt_nr_dr_data) form.execution_date.choices = dt_nr_dr_data["dr_choices"] - tis = session.scalars( - select(TaskInstance) - .where( + tis = ( + session.query(TaskInstance) + .filter( TaskInstance.dag_id == dag_id, TaskInstance.run_id == dag_run_id, TaskInstance.start_date.is_not(None), @@ -3613,10 +3602,10 @@ def gantt(self, dag_id: str, session: Session = NEW_SESSION): .order_by(TaskInstance.start_date) ) - ti_fails = select(TaskFail).filter_by(run_id=dag_run_id, dag_id=dag_id) + ti_fails = session.query(TaskFail).filter_by(run_id=dag_run_id, dag_id=dag_id) if dag.partial: - ti_fails = ti_fails.where(TaskFail.task_id.in_([t.task_id for t in dag.tasks])) - ti_fails = session.scalars(ti_fails) + ti_fails = ti_fails.filter(TaskFail.task_id.in_([t.task_id for t in dag.tasks])) + tasks = [] for ti in tis: if not dag.has_task(ti.task_id): @@ -3722,13 +3711,12 @@ def extra_links(self, *, session: Session = NEW_SESSION): if link_name is None: return {"url": None, "error": "Link name not passed"}, 400 - ti = session.scalar( - select(TaskInstance) + ti = ( + session.query(TaskInstance) .filter_by(dag_id=dag_id, task_id=task_id, execution_date=dttm, map_index=map_index) .options(joinedload(TaskInstance.dag_run)) - .limit(1) + .first() ) - if not ti: return {"url": None, "error": "Task Instances not found"}, 404 try: @@ -3836,25 +3824,27 @@ def grid_data(self): base_date = dag.get_latest_execution_date() or timezone.utcnow() with create_session() as session: - query = select(DagRun).where(DagRun.dag_id == dag.dag_id, DagRun.execution_date <= base_date) + query = session.query(DagRun).filter( + DagRun.dag_id == dag.dag_id, DagRun.execution_date <= base_date + ) - run_type = request.args.get("run_type") - if run_type: - query = query.where(DagRun.run_type == run_type) + run_type = request.args.get("run_type") + if run_type: + query = query.filter(DagRun.run_type == run_type) - run_state = request.args.get("run_state") - if run_state: - query = query.where(DagRun.state == run_state) + run_state = request.args.get("run_state") + if run_state: + query = query.filter(DagRun.state == run_state) - dag_runs = wwwutils.sorted_dag_runs( - query, ordering=dag.timetable.run_ordering, limit=num_runs, session=session - ) - encoded_runs = [wwwutils.encode_dag_run(dr, json_encoder=utils_json.WebEncoder) for dr in dag_runs] - data = { - "groups": dag_to_grid(dag, dag_runs, session), - "dag_runs": encoded_runs, - "ordering": dag.timetable.run_ordering, - } + dag_runs = wwwutils.sorted_dag_runs(query, ordering=dag.timetable.run_ordering, limit=num_runs) + encoded_runs = [ + wwwutils.encode_dag_run(dr, json_encoder=utils_json.WebEncoder) for dr in dag_runs + ] + data = { + "groups": dag_to_grid(dag, dag_runs, session), + "dag_runs": encoded_runs, + "ordering": dag.timetable.run_ordering, + } # avoid spaces to reduce payload size return ( htmlsafe_json_dumps(data, separators=(",", ":"), dumps=flask.json.dumps), @@ -3873,34 +3863,37 @@ def historical_metrics_data(self): end_date = _safe_parse_datetime(request.args.get("end_date")) with create_session() as session: # DagRuns - dag_runs_type = session.execute( - select(DagRun.run_type, func.count(DagRun.run_id)) - .where( + dag_runs_type = ( + session.query(DagRun.run_type, func.count(DagRun.run_id)) + .filter( DagRun.start_date >= start_date, 
or_(DagRun.end_date.is_(None), DagRun.end_date <= end_date), ) .group_by(DagRun.run_type) - ).all() + .all() + ) - dag_run_states = session.execute( - select(DagRun.state, func.count(DagRun.run_id)) - .where( + dag_run_states = ( + session.query(DagRun.state, func.count(DagRun.run_id)) + .filter( DagRun.start_date >= start_date, or_(DagRun.end_date.is_(None), DagRun.end_date <= end_date), ) .group_by(DagRun.state) - ).all() + .all() + ) # TaskInstances - task_instance_states = session.execute( - select(TaskInstance.state, func.count(TaskInstance.run_id)) + task_instance_states = ( + session.query(TaskInstance.state, func.count(TaskInstance.run_id)) .join(TaskInstance.dag_run) - .where( + .filter( DagRun.start_date >= start_date, or_(DagRun.end_date.is_(None), DagRun.end_date <= end_date), ) .group_by(TaskInstance.state) - ).all() + .all() + ) data = { "dag_run_types": { @@ -3934,32 +3927,28 @@ def next_run_datasets(self, dag_id): with create_session() as session: data = [ dict(info) - for info in session.execute( - select( - DatasetModel.id, - DatasetModel.uri, - func.max(DatasetEvent.timestamp).label("lastUpdate"), - ) - .join( - DagScheduleDatasetReference, DagScheduleDatasetReference.dataset_id == DatasetModel.id - ) - .join( - DatasetDagRunQueue, - and_( - DatasetDagRunQueue.dataset_id == DatasetModel.id, - DatasetDagRunQueue.target_dag_id == DagScheduleDatasetReference.dag_id, - ), - isouter=True, - ) - .join( - DatasetEvent, - DatasetEvent.dataset_id == DatasetModel.id, - isouter=True, - ) - .where(DagScheduleDatasetReference.dag_id == dag_id, ~DatasetModel.is_orphaned) - .group_by(DatasetModel.id, DatasetModel.uri) - .order_by(DatasetModel.uri) + for info in session.query( + DatasetModel.id, + DatasetModel.uri, + func.max(DatasetEvent.timestamp).label("lastUpdate"), ) + .join(DagScheduleDatasetReference, DagScheduleDatasetReference.dataset_id == DatasetModel.id) + .join( + DatasetDagRunQueue, + and_( + DatasetDagRunQueue.dataset_id == DatasetModel.id, + DatasetDagRunQueue.target_dag_id == DagScheduleDatasetReference.dag_id, + ), + isouter=True, + ) + .join( + DatasetEvent, + DatasetEvent.dataset_id == DatasetModel.id, + isouter=True, + ) + .filter(DagScheduleDatasetReference.dag_id == dag_id, ~DatasetModel.is_orphaned) + .group_by(DatasetModel.id, DatasetModel.uri) + .order_by(DatasetModel.uri) ] return ( htmlsafe_json_dumps(data, separators=(",", ":"), dumps=flask.json.dumps), @@ -4057,12 +4046,12 @@ def datasets_summary(self): if session.bind.dialect.name == "postgresql": order_by = (order_by[0].nulls_first(), *order_by[1:]) - count_query = select(func.count(DatasetModel.id)) + count_query = session.query(func.count(DatasetModel.id)) has_event_filters = bool(updated_before or updated_after) query = ( - select( + session.query( DatasetModel.id, DatasetModel.uri, func.max(DatasetEvent.timestamp).label("last_dataset_update"), @@ -4087,12 +4076,11 @@ def datasets_summary(self): if updated_before: filters.append(DatasetEvent.timestamp <= updated_before) - query = query.where(*filters).offset(offset).limit(limit) - count_query = count_query.where(*filters) + query = query.filter(*filters).offset(offset).limit(limit) + count_query = count_query.filter(*filters) - query = session.execute(query) datasets = [dict(dataset) for dataset in query] - data = {"datasets": datasets, "total_entries": session.scalar(count_query)} + data = {"datasets": datasets, "total_entries": count_query.scalar()} return ( htmlsafe_json_dumps(data, separators=(",", ":"), cls=utils_json.WebEncoder), @@ -4138,20 
+4126,20 @@ def audit_log(self, dag_id: str, session: Session = NEW_SESSION): included_events_raw = conf.get("webserver", "audit_view_included_events", fallback=None) excluded_events_raw = conf.get("webserver", "audit_view_excluded_events", fallback=None) - query = select(Log).where(Log.dag_id == dag_id) + query = session.query(Log).filter(Log.dag_id == dag_id) if included_events_raw: included_events = {event.strip() for event in included_events_raw.split(",")} - query = query.where(Log.event.in_(included_events)) + query = query.filter(Log.event.in_(included_events)) elif excluded_events_raw: excluded_events = {event.strip() for event in excluded_events_raw.split(",")} - query = query.where(Log.event.notin_(excluded_events)) + query = query.filter(Log.event.notin_(excluded_events)) current_page = request.args.get("page", default=0, type=int) arg_sorting_key = request.args.get("sorting_key", "dttm") arg_sorting_direction = request.args.get("sorting_direction", default="desc") logs_per_page = PAGE_SIZE - audit_logs_count = session.scalar(select(func.count()).select_from(query)) + audit_logs_count = query.count() num_of_pages = int(math.ceil(audit_logs_count / float(logs_per_page))) start = current_page * logs_per_page @@ -4163,7 +4151,7 @@ def audit_log(self, dag_id: str, session: Session = NEW_SESSION): sort_column = sort_column.desc() query = query.order_by(sort_column) - dag_audit_logs = session.scalars(query.offset(start).limit(logs_per_page)).all() + dag_audit_logs = query.offset(start).limit(logs_per_page).all() return self.render_template( "airflow/dag_audit_log.html", dag=dag, @@ -4284,7 +4272,7 @@ def apply(self, query, func): if get_airflow_app().appbuilder.sm.has_all_dags_access(g.user): return query filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) - return query.where(self.model.dag_id.in_(filter_dag_ids)) + return query.filter(self.model.dag_id.in_(filter_dag_ids)) class AirflowModelView(ModelView): @@ -4729,11 +4717,9 @@ def action_mulduplicate(self, connections, session: Session = NEW_SESSION): potential_connection_ids = [f"{base_conn_id}_copy{i}" for i in range(1, 11)] - query = session.scalars( - select(Connection.conn_id).where(Connection.conn_id.in_(potential_connection_ids)) - ) + query = session.query(Connection.conn_id).filter(Connection.conn_id.in_(potential_connection_ids)) - found_conn_id_set = {conn_id for conn_id in query} + found_conn_id_set = {conn_id for conn_id, in query} possible_conn_id_iter = ( connection_id @@ -5403,7 +5389,7 @@ def _set_dag_runs_to_active_state(self, drs: list[DagRun], state: str, session: """This routine only supports Running and Queued state.""" try: count = 0 - for dr in session.scalars(select(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs))): + for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for dagrun in drs)): count += 1 if state == State.RUNNING: dr.start_date = timezone.utcnow() @@ -5429,7 +5415,7 @@ def action_set_failed(self, drs: list[DagRun], session: Session = NEW_SESSION): try: count = 0 altered_tis = [] - for dr in session.scalars(select(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs))): + for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for dagrun in drs)): count += 1 altered_tis += set_dag_run_state_to_failed( dag=get_airflow_app().dag_bag.get_dag(dr.dag_id), @@ -5457,7 +5443,7 @@ def action_set_success(self, drs: list[DagRun], session: Session = NEW_SESSION): try: count = 0 altered_tis = [] - for dr in 
session.scalars(select(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs))): + for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for dagrun in drs)): count += 1 altered_tis += set_dag_run_state_to_success( dag=get_airflow_app().dag_bag.get_dag(dr.dag_id), @@ -5481,7 +5467,7 @@ def action_clear(self, drs: list[DagRun], session: Session = NEW_SESSION): count = 0 cleared_ti_count = 0 dag_to_tis: dict[DAG, list[TaskInstance]] = {} - for dr in session.scalars(select(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs))): + for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for dagrun in drs)): count += 1 dag = get_airflow_app().dag_bag.get_dag(dr.dag_id) tis_to_clear = dag_to_tis.setdefault(dag, []) @@ -5897,37 +5883,35 @@ def autocomplete(self, session: Session = NEW_SESSION): return flask.json.jsonify([]) # Provide suggestions of dag_ids and owners - dag_ids_query = select( + dag_ids_query = session.query( sqla.literal("dag").label("type"), DagModel.dag_id.label("name"), - ).where(~DagModel.is_subdag, DagModel.is_active, DagModel.dag_id.ilike(f"%{query}%")) + ).filter(~DagModel.is_subdag, DagModel.is_active, DagModel.dag_id.ilike(f"%{query}%")) owners_query = ( - select( + session.query( sqla.literal("owner").label("type"), DagModel.owners.label("name"), ) .distinct() - .where(~DagModel.is_subdag, DagModel.is_active, DagModel.owners.ilike(f"%{query}%")) + .filter(~DagModel.is_subdag, DagModel.is_active, DagModel.owners.ilike(f"%{query}%")) ) # Hide DAGs if not showing status: "all" status = flask_session.get(FILTER_STATUS_COOKIE) if status == "active": - dag_ids_query = dag_ids_query.where(~DagModel.is_paused) - owners_query = owners_query.where(~DagModel.is_paused) + dag_ids_query = dag_ids_query.filter(~DagModel.is_paused) + owners_query = owners_query.filter(~DagModel.is_paused) elif status == "paused": - dag_ids_query = dag_ids_query.where(DagModel.is_paused) - owners_query = owners_query.where(DagModel.is_paused) + dag_ids_query = dag_ids_query.filter(DagModel.is_paused) + owners_query = owners_query.filter(DagModel.is_paused) filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) - dag_ids_query = dag_ids_query.where(DagModel.dag_id.in_(filter_dag_ids)) - owners_query = owners_query.where(DagModel.dag_id.in_(filter_dag_ids)) - payload = [ - row._asdict() - for row in session.execute(dag_ids_query.union(owners_query).order_by("name").limit(10)) - ] + dag_ids_query = dag_ids_query.filter(DagModel.dag_id.in_(filter_dag_ids)) + owners_query = owners_query.filter(DagModel.dag_id.in_(filter_dag_ids)) + + payload = [row._asdict() for row in dag_ids_query.union(owners_query).order_by("name").limit(10)] return flask.json.jsonify(payload)
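
For context, a minimal sketch (illustrative only, not part of this patch; the toy Conn model and in-memory SQLite engine are assumptions, not taken from Airflow) contrasting the legacy 1.x Query API this revert restores with the 2.0-style select() constructs it removes:

# Illustrative sketch only -- not part of the patch above.
from sqlalchemy import Column, Integer, String, create_engine, func, select
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()


class Conn(Base):  # hypothetical stand-in for airflow.models.Connection
    __tablename__ = "conn"
    id = Column(Integer, primary_key=True)
    conn_id = Column(String, nullable=False)


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add_all([Conn(conn_id="a"), Conn(conn_id="a"), Conn(conn_id="b")])
    session.commit()

    # Legacy 1.x Query style (what this revert restores):
    first_a = session.query(Conn).filter_by(conn_id="a").first()
    dups_1x = (
        session.query(Conn.conn_id)
        .group_by(Conn.conn_id)
        .having(func.count() > 1)
        .all()
    )

    # 2.0-style select() constructs (what this patch reverts away from):
    first_a_20 = session.scalar(select(Conn).filter_by(conn_id="a").limit(1))
    dups_20 = session.execute(
        select(Conn.conn_id).group_by(Conn.conn_id).having(func.count() > 1)
    ).all()

    # Both styles return equivalent results for these queries.
    assert first_a.conn_id == first_a_20.conn_id
    assert [row[0] for row in dups_1x] == [row[0] for row in dups_20]

Both forms emit equivalent SQL; the revert changes API style (Query.filter/first/count versus select().where() executed through Session.execute/scalars), not query semantics.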