From 22026901c573ec831d4778bc766bd357d67a3358 Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Thu, 15 Jun 2023 18:53:37 +0530 Subject: [PATCH 1/5] Refactor Sqlalchemy queries to 2.0 style --- airflow/utils/db.py | 38 ++++++++++++++++++++------------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index a76f0d4f675d1..6a1f1034e942f 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -92,7 +92,7 @@ def _format_airflow_moved_table_name(source_table, version, category): @provide_session def merge_conn(conn, session: Session = NEW_SESSION): """Add new Connection.""" - if not session.query(conn.__class__).filter_by(conn_id=conn.conn_id).first(): + if not session.scalar(select(conn.__class__).filter_by(conn_id=conn.conn_id).limit(1)): session.add(conn) session.commit() @@ -959,7 +959,9 @@ def check_conn_id_duplicates(session: Session) -> Iterable[str]: dups = [] try: - dups = session.query(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1).all() + dups = session.execute( + select(Connection.conn_id).group_by(Connection.conn_id).having(func.count() > 1) + ).all() except (exc.OperationalError, exc.ProgrammingError): # fallback if tables hasn't been created yet session.rollback() @@ -984,12 +986,11 @@ def check_username_duplicates(session: Session) -> Iterable[str]: for model in [User, RegisterUser]: dups = [] try: - dups = ( - session.query(model.username) # type: ignore[attr-defined] + dups = session.execute( + select(model.username) # type: ignore[attr-defined] .group_by(model.username) # type: ignore[attr-defined] .having(func.count() > 1) - .all() - ) + ).all() except (exc.OperationalError, exc.ProgrammingError): # fallback if tables hasn't been created yet session.rollback() @@ -1058,13 +1059,13 @@ def check_task_fail_for_duplicates(session): """ minimal_table_obj = table(table_name, *[column(x) for x in uniqueness]) try: subquery = ( - session.query(minimal_table_obj, func.count().label("dupe_count")) + select(minimal_table_obj, func.count().label("dupe_count")) .group_by(*[text(x) for x in uniqueness]) .having(func.count() > text("1")) .subquery() ) - dupe_count = session.query(func.sum(subquery.c.dupe_count)).scalar() + dupe_count = session.scalar(select(func.sum(subquery.c.dupe_count))) if not dupe_count: # there are no duplicates; nothing to do.
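The two conversions above are the workhorse patterns of this whole series: a 1.x `session.query(...).first()` existence check becomes `session.scalar(select(...).limit(1))`, and a 1.x `Query.all()` over aggregate columns becomes `session.execute(select(...)).all()`. A minimal self-contained sketch of both patterns, assuming SQLAlchemy 1.4.24 or newer and a toy `Conn` model (illustrative only, not the real Airflow schema):

    from sqlalchemy import Column, Integer, String, create_engine, func, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Conn(Base):
        __tablename__ = "conn"
        id = Column(Integer, primary_key=True)
        conn_id = Column(String(50))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([Conn(conn_id="a"), Conn(conn_id="a"), Conn(conn_id="b")])
        session.commit()

        # 1.x: session.query(Conn).filter_by(conn_id="a").first()
        first_a = session.scalar(select(Conn).filter_by(conn_id="a").limit(1))

        # 1.x: session.query(Conn.conn_id).group_by(...).having(...).all()
        dups = session.execute(
            select(Conn.conn_id).group_by(Conn.conn_id).having(func.count() > 1)
        ).all()
        print(first_a.conn_id, [row.conn_id for row in dups])  # prints: a ['a']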
return @@ -1101,7 +1102,7 @@ def check_conn_type_null(session: Session) -> Iterable[str]: n_nulls = [] try: - n_nulls = session.query(Connection.conn_id).filter(Connection.conn_type.is_(None)).all() + n_nulls = session.execute(select(Connection.conn_id).filter(Connection.conn_type.is_(None))).all() except (exc.OperationalError, exc.ProgrammingError, exc.InternalError): # fallback if tables hasn't been created yet session.rollback() @@ -1143,7 +1144,7 @@ def check_run_id_null(session: Session) -> Iterable[str]: dagrun_table.c.run_id.is_(None), dagrun_table.c.execution_date.is_(None), ) - invalid_dagrun_count = session.query(func.count(dagrun_table.c.id)).filter(invalid_dagrun_filter).scalar() + invalid_dagrun_count = session.scalar(select(func.count(dagrun_table.c.id)).filter(invalid_dagrun_filter)) if invalid_dagrun_count > 0: dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, "2.2", "dangling") if dagrun_dangling_table_name in inspect(session.get_bind()).get_table_names(): @@ -1240,7 +1241,7 @@ def _move_dangling_data_to_new_table( pk_cols = source_table.primary_key.columns delete = source_table.delete().where( - tuple_(*pk_cols).in_(session.query(*target_table.primary_key.columns).subquery()) + tuple_(*pk_cols).in_(select(*target_table.primary_key.columns).subquery()) ) else: delete = source_table.delete().where( @@ -1262,10 +1263,11 @@ def _dangling_against_dag_run(session, source_table, dag_run): source_table.c.dag_id == dag_run.c.dag_id, source_table.c.execution_date == dag_run.c.execution_date, ) + return ( - session.query(*[c.label(c.name) for c in source_table.c]) + select(*[c.label(c.name) for c in source_table.c]) .join(dag_run, source_to_dag_run_join_cond, isouter=True) - .filter(dag_run.c.dag_id.is_(None)) + .where(dag_run.c.dag_id.is_(None)) ) @@ -1304,10 +1306,10 @@ def _dangling_against_task_instance(session, source_table, dag_run, task_instanc ) return ( - session.query(*[c.label(c.name) for c in source_table.c]) + select(*[c.label(c.name) for c in source_table.c]) .join(dag_run, dr_join_cond, isouter=True) .join(task_instance, ti_join_cond, isouter=True) - .filter(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None))) + .where(or_(task_instance.c.dag_id.is_(None), dag_run.c.dag_id.is_(None))) ) @@ -1331,9 +1333,9 @@ def _move_duplicate_data_to_new_table( """ bind = session.get_bind() dialect_name = bind.dialect.name + query = ( - session.query(source_table) - .with_entities(*[getattr(source_table.c, x.name).label(str(x.name)) for x in source_table.columns]) + session.select(*[getattr(source_table.c, x.name).label(str(x.name)) for x in source_table.columns]) .select_from(source_table) .join(subquery, and_(*[getattr(source_table.c, x) == getattr(subquery.c, x) for x in uniqueness])) ) From b9a67d1779e731dbc6fefdf9fb02f0af8aceb8d0 Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Wed, 28 Jun 2023 17:56:15 +0530 Subject: [PATCH 2/5] Include changes to views --- airflow/www/views.py | 23 +++++++++++------------ 1 file changed, 11 insertions(+), 12 deletions(-) diff --git a/airflow/www/views.py b/airflow/www/views.py index 026f8b533499d..2969ed7a8af76 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -69,7 +69,7 @@ from pendulum.parsing.exceptions import ParserError from pygments import highlight, lexers from pygments.formatters import HtmlFormatter -from sqlalchemy import Date, and_, case, desc, func, inspect, or_, union_all +from sqlalchemy import Date, and_, case, desc, func, inspect, or_, select, union_all from
sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session, joinedload from wtforms import SelectField, validators @@ -3245,9 +3245,9 @@ def duration(self, dag_id: str, session: Session = NEW_SESSION): else: min_date = timezone.utc_epoch() ti_fails = ( - session.query(TaskFail) + select(TaskFail) .join(TaskFail.dag_run) - .filter( + .where( TaskFail.dag_id == dag.dag_id, DagRun.execution_date >= min_date, DagRun.execution_date <= base_date, @@ -3596,8 +3596,8 @@ def gantt(self, dag_id: str, session: Session = NEW_SESSION): form.execution_date.choices = dt_nr_dr_data["dr_choices"] tis = ( - session.query(TaskInstance) - .filter( + select(TaskInstance) + .where( TaskInstance.dag_id == dag_id, TaskInstance.run_id == dag_run_id, TaskInstance.start_date.is_not(None), @@ -3606,7 +3606,7 @@ def gantt(self, dag_id: str, session: Session = NEW_SESSION): .order_by(TaskInstance.start_date) ) - ti_fails = session.query(TaskFail).filter_by(run_id=dag_run_id, dag_id=dag_id) + ti_fails = select(TaskFail).filter_by(run_id=dag_run_id, dag_id=dag_id) if dag.partial: ti_fails = ti_fails.filter(TaskFail.task_id.in_([t.task_id for t in dag.tasks])) @@ -3715,12 +3715,13 @@ def extra_links(self, *, session: Session = NEW_SESSION): if link_name is None: return {"url": None, "error": "Link name not passed"}, 400 - ti = ( - session.query(TaskInstance) + ti = session.scalar( + select(TaskInstance) .filter_by(dag_id=dag_id, task_id=task_id, execution_date=dttm, map_index=map_index) .options(joinedload(TaskInstance.dag_run)) - .first() + .limit(1) ) + if not ti: return {"url": None, "error": "Task Instances not found"}, 404 try: @@ -3828,9 +3829,7 @@ def grid_data(self): base_date = dag.get_latest_execution_date() or timezone.utcnow() with create_session() as session: - query = session.query(DagRun).filter( - DagRun.dag_id == dag.dag_id, DagRun.execution_date <= base_date - ) + query = select(DagRun).where(DagRun.dag_id == dag.dag_id, DagRun.execution_date <= base_date) run_type = request.args.get("run_type") if run_type: From 4bc06cbe2ffd42393cd11fd2de027d51c371b49a Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Fri, 30 Jun 2023 20:17:44 +0530 Subject: [PATCH 3/5] More SqlAlchemy 2.0 changes --- airflow/www/utils.py | 11 +- airflow/www/views.py | 319 ++++++++++++++++++++++--------------------- 2 files changed, 170 insertions(+), 160 deletions(-) diff --git a/airflow/www/utils.py b/airflow/www/utils.py index 25fc1a28f98f9..46256ee3359e0 100644 --- a/airflow/www/utils.py +++ b/airflow/www/utils.py @@ -38,6 +38,7 @@ from pygments.formatters import HtmlFormatter from sqlalchemy import delete, func, types from sqlalchemy.ext.associationproxy import AssociationProxy +from sqlalchemy.sql import Select from airflow.exceptions import RemovedInAirflow3Warning from airflow.models import errors @@ -52,7 +53,6 @@ from airflow.www.widgets import AirflowDateTimePickerWidget if TYPE_CHECKING: - from sqlalchemy.orm.query import Query from sqlalchemy.orm.session import Session from sqlalchemy.sql.operators import ColumnOperators @@ -517,18 +517,21 @@ def _get_run_ordering_expr(name: str) -> ColumnOperators: return expr.desc() -def sorted_dag_runs(query: Query, *, ordering: Sequence[str], limit: int) -> Sequence[DagRun]: +def sorted_dag_runs( + query: Select, *, ordering: Sequence[str], limit: int, session: Session +) -> Sequence[DagRun]: """Produce DAG runs sorted by specified columns. - :param query: An ORM query object against *DagRun*. + :param query: An ORM select object against *DagRun*. 
:param ordering: Column names to sort the runs. should generally come from a timetable's ``run_ordering``. :param limit: Number of runs to limit to. + :param session: SQLAlchemy ORM session object :return: A list of DagRun objects ordered by the specified columns. The list contains only the *last* objects, but in *ascending* order. """ ordering_exprs = (_get_run_ordering_expr(name) for name in ordering) - runs = query.order_by(*ordering_exprs, DagRun.id.desc()).limit(limit).all() + runs = session.scalars(query.order_by(*ordering_exprs, DagRun.id.desc()).limit(limit)).all() runs.reverse() return runs diff --git a/airflow/www/views.py b/airflow/www/views.py index 2969ed7a8af76..6762a27339388 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -95,7 +95,7 @@ from airflow.jobs.job import Job from airflow.jobs.scheduler_job_runner import SchedulerJobRunner from airflow.jobs.triggerer_job_runner import TriggererJobRunner -from airflow.models import Connection, DagModel, DagTag, Log, SlaMiss, TaskFail, XCom, errors +from airflow.models import Connection, DagModel, DagTag, Log, SlaMiss, TaskFail, Trigger, XCom, errors from airflow.models.abstractoperator import AbstractOperator from airflow.models.dag import DAG, get_dataset_triggered_next_run_info from airflow.models.dagcode import DagCode @@ -291,8 +291,8 @@ def dag_to_grid(dag: DagModel, dag_runs: Sequence[DagRun], session: Session): Create a nested dict representation of the DAG's TaskGroup and its children used to construct the Graph and Grid views. """ - query = ( - session.query( + query = session.execute( + select( TaskInstance.task_id, TaskInstance.run_id, TaskInstance.state, @@ -303,7 +303,7 @@ def dag_to_grid(dag: DagModel, dag_runs: Sequence[DagRun], session: Session): func.max(TaskInstance.end_date).label("end_date"), ) .join(TaskInstance.task_instance_note, isouter=True) - .filter( + .where( TaskInstance.dag_id == dag.dag_id, TaskInstance.run_id.in_([dag_run.run_id for dag_run in dag_runs]), ) @@ -426,11 +426,9 @@ def get_summary(dag_run: DagRun): } def get_mapped_group_summaries(): - mapped_ti_query = ( - session.query( - TaskInstance.task_id, TaskInstance.state, TaskInstance.run_id, TaskInstance.map_index - ) - .filter( + mapped_ti_query = session.execute( + select(TaskInstance.task_id, TaskInstance.state, TaskInstance.run_id, TaskInstance.map_index) + .where( TaskInstance.dag_id == dag.dag_id, TaskInstance.task_id.in_(child["id"] for child in children), TaskInstance.run_id.in_(r.run_id for r in dag_runs), @@ -738,21 +736,20 @@ def index(self): with create_session() as session: # read orm_dags from the db - dags_query = session.query(DagModel).filter(~DagModel.is_subdag, DagModel.is_active) + dags_query = select(DagModel).where(~DagModel.is_subdag, DagModel.is_active) if arg_search_query: escaped_arg_search_query = arg_search_query.replace("_", r"\_") - dags_query = dags_query.filter( + dags_query = dags_query.where( DagModel.dag_id.ilike("%" + escaped_arg_search_query + "%", escape="\\") | DagModel.owners.ilike("%" + escaped_arg_search_query + "%", escape="\\") ) if arg_tags_filter: - dags_query = dags_query.filter(DagModel.tags.any(DagTag.name.in_(arg_tags_filter))) - - dags_query = dags_query.filter(DagModel.dag_id.in_(filter_dag_ids)) + dags_query = dags_query.where(DagModel.tags.any(DagTag.name.in_(arg_tags_filter))) - filtered_dag_count = dags_query.count() + dags_query = dags_query.where(DagModel.dag_id.in_(filter_dag_ids)) + filtered_dag_count = session.scalar(select(func.count()).select_from(dags_query.subquery())) if
filtered_dag_count == 0 and len(arg_tags_filter): flash( "No matching DAG tags found.", @@ -762,28 +759,28 @@ def index(self): return redirect(url_for("Airflow.index")) all_dags = dags_query - active_dags = dags_query.filter(~DagModel.is_paused) - paused_dags = dags_query.filter(DagModel.is_paused) + active_dags = dags_query.where(~DagModel.is_paused) + paused_dags = dags_query.where(DagModel.is_paused) # find DAGs which have a RUNNING DagRun - running_dags = dags_query.join(DagRun, DagModel.dag_id == DagRun.dag_id).filter( + running_dags = dags_query.join(DagRun, DagModel.dag_id == DagRun.dag_id).where( DagRun.state == State.RUNNING ) # find DAGs for which the latest DagRun is FAILED subq_all = ( - session.query(DagRun.dag_id, func.max(DagRun.start_date).label("start_date")) + select(DagRun.dag_id, func.max(DagRun.start_date).label("start_date")) .group_by(DagRun.dag_id) .subquery() ) subq_failed = ( - session.query(DagRun.dag_id, func.max(DagRun.start_date).label("start_date")) - .filter(DagRun.state == State.FAILED) + select(DagRun.dag_id, func.max(DagRun.start_date).label("start_date")) + .where(DagRun.state == State.FAILED) .group_by(DagRun.dag_id) .subquery() ) subq_join = ( - session.query(subq_all.c.dag_id, subq_all.c.start_date) + select(subq_all.c.dag_id, subq_all.c.start_date) .join( subq_failed, and_( @@ -796,16 +793,18 @@ def index(self): failed_dags = dags_query.join(subq_join, DagModel.dag_id == subq_join.c.dag_id) is_paused_count = dict( - all_dags.with_entities(DagModel.is_paused, func.count(DagModel.dag_id)).group_by( - DagModel.is_paused - ) + session.execute( + all_dags.with_only_columns( + DagModel.is_paused, func.count(DagModel.dag_id) + ).group_by(DagModel.is_paused) + ).all() ) status_count_active = is_paused_count.get(False, 0) status_count_paused = is_paused_count.get(True, 0) - status_count_running = running_dags.count() - status_count_failed = failed_dags.count() + status_count_running = session.scalar(select(func.count()).select_from(running_dags.subquery())) + status_count_failed = session.scalar(select(func.count()).select_from(failed_dags.subquery())) all_dags_count = status_count_active + status_count_paused if arg_status_filter == "active": @@ -826,7 +825,7 @@ def index(self): if arg_sorting_key == "last_dagrun": dag_run_subquery = ( - session.query( + select( DagRun.dag_id, sqla.func.max(DagRun.execution_date).label("max_execution_date"), ) @@ -854,7 +853,13 @@ def index(self): else: current_dags = current_dags.order_by(null_case, sort_column) - dags = current_dags.options(joinedload(DagModel.tags)).offset(start).limit(dags_per_page).all() + dags = ( + session.scalars( + current_dags.options(joinedload(DagModel.tags)).offset(start).limit(dags_per_page) + ) + .unique() + .all() + ) user_permissions = g.user.perms can_create_dag_run = ( permissions.ACTION_CAN_CREATE, @@ -874,7 +879,7 @@ def index(self): dag.can_trigger = dag.can_edit and can_create_dag_run dag.can_delete = get_airflow_app().appbuilder.sm.can_delete_dag(dag.dag_id, g.user) - dagtags = session.query(func.distinct(DagTag.name)).order_by(DagTag.name).all() + dagtags = session.execute(select(func.distinct(DagTag.name)).order_by(DagTag.name)).all() tags = [ {"name": name, "selected": bool(arg_tags_filter and name in arg_tags_filter)} for name, in dagtags ] owner_links_dict = DagOwnerAttributes.get_all(session) - import_errors = session.query(errors.ImportError).order_by(errors.ImportError.id) + import_errors = select(errors.ImportError).order_by(errors.ImportError.id) if
(permissions.ACTION_CAN_READ, permissions.RESOURCE_DAG) not in user_permissions: # if the user doesn't have access to all DAGs, only display errors from visible DAGs import_errors = import_errors.join( DagModel, DagModel.fileloc == errors.ImportError.filename - ).filter(DagModel.dag_id.in_(filter_dag_ids)) + ).where(DagModel.dag_id.in_(filter_dag_ids)) + import_errors = session.scalars(import_errors) for import_error in import_errors: flash( f"Broken DAG: [{import_error.filename}] {import_error.stacktrace}", @@ -933,10 +939,13 @@ def _iter_parsed_moved_data_table_names(): permissions.RESOURCE_ADMIN_MENU, ) in user_permissions and conf.getboolean("webserver", "warn_deployment_exposure"): robots_file_access_count = ( - session.query(Log) - .filter(Log.event == "robots") - .filter(Log.dttm > (utcnow() - datetime.timedelta(days=7))) - .count() + select(Log) + .where(Log.event == "robots") + .where(Log.dttm > (utcnow() - datetime.timedelta(days=7))) + # .count() + ) + robots_file_access_count = session.scalar( + select(func.count()).select_from(robots_file_access_count.subquery()) ) if robots_file_access_count > 0: flash( @@ -1038,9 +1047,11 @@ def next_run_datasets_summary(self, session: Session = NEW_SESSION): dataset_triggered_dag_ids = [ dag.dag_id for dag in ( - session.query(DagModel.dag_id) - .filter(DagModel.dag_id.in_(filter_dag_ids)) - .filter(DagModel.schedule_interval == "Dataset") + session.execute( + select(DagModel.dag_id) + .where(DagModel.dag_id.in_(filter_dag_ids)) + .where(DagModel.schedule_interval == "Dataset") + ) ) ] @@ -1071,10 +1082,10 @@ def dag_stats(self, session: Session = NEW_SESSION): if not filter_dag_ids: return flask.json.jsonify({}) - dag_state_stats = ( - session.query(DagRun.dag_id, DagRun.state, sqla.func.count(DagRun.state)) + dag_state_stats = session.execute( + select(DagRun.dag_id, DagRun.state, sqla.func.count(DagRun.state)) .group_by(DagRun.dag_id, DagRun.state) - .filter(DagRun.dag_id.in_(filter_dag_ids)) + .where(DagRun.dag_id.in_(filter_dag_ids)) ) dag_state_data = {(dag_id, state): count for dag_id, state, count in dag_state_stats} @@ -1112,17 +1123,17 @@ def task_stats(self, session: Session = NEW_SESSION): filter_dag_ids = allowed_dag_ids running_dag_run_query_result = ( - session.query(DagRun.dag_id, DagRun.run_id) + select(DagRun.dag_id, DagRun.run_id) .join(DagModel, DagModel.dag_id == DagRun.dag_id) - .filter(DagRun.state == State.RUNNING, DagModel.is_active) + .where(DagRun.state == State.RUNNING, DagModel.is_active) ) - running_dag_run_query_result = running_dag_run_query_result.filter(DagRun.dag_id.in_(filter_dag_ids)) + running_dag_run_query_result = running_dag_run_query_result.where(DagRun.dag_id.in_(filter_dag_ids)) running_dag_run_query_result = running_dag_run_query_result.subquery("running_dag_run") # Select all task_instances from active dag_runs.
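A recurring wrinkle in the hunks above: `Query.count()` has no direct equivalent on a `Select`, and a raw `Select` is not a FROM clause, so counting the rows of an arbitrary statement means wrapping it with `.subquery()` and selecting `func.count()` from that (the `select_from(...subquery())` calls above). A minimal sketch under the same assumptions (toy `Dag` model, SQLAlchemy 1.4.24 or newer):

    from sqlalchemy import Boolean, Column, Integer, create_engine, func, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Dag(Base):
        __tablename__ = "dag"
        id = Column(Integer, primary_key=True)
        is_paused = Column(Boolean, default=False)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([Dag(), Dag(is_paused=True), Dag()])
        session.commit()

        stmt = select(Dag).where(~Dag.is_paused)
        # 1.x: session.query(Dag).filter(~Dag.is_paused).count()
        n = session.scalar(select(func.count()).select_from(stmt.subquery()))
        print(n)  # prints: 2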
- running_task_instance_query_result = session.query( + running_task_instance_query_result = select( TaskInstance.dag_id.label("dag_id"), TaskInstance.state.label("state"), sqla.literal(True).label("is_dag_running"), @@ -1136,19 +1147,19 @@ def task_stats(self, session: Session = NEW_SESSION): if conf.getboolean("webserver", "SHOW_RECENT_STATS_FOR_COMPLETED_RUNS", fallback=True): last_dag_run = ( - session.query(DagRun.dag_id, sqla.func.max(DagRun.execution_date).label("execution_date")) + select(DagRun.dag_id, sqla.func.max(DagRun.execution_date).label("execution_date")) .join(DagModel, DagModel.dag_id == DagRun.dag_id) - .filter(DagRun.state != State.RUNNING, DagModel.is_active) + .where(DagRun.state != State.RUNNING, DagModel.is_active) .group_by(DagRun.dag_id) ) - last_dag_run = last_dag_run.filter(DagRun.dag_id.in_(filter_dag_ids)) + last_dag_run = last_dag_run.where(DagRun.dag_id.in_(filter_dag_ids)) last_dag_run = last_dag_run.subquery("last_dag_run") # Select all task_instances from active dag_runs. # If no dag_run is active, return task instances from most recent dag_run. last_task_instance_query_result = ( - session.query( + select( TaskInstance.dag_id.label("dag_id"), TaskInstance.state.label("state"), sqla.literal(False).label("is_dag_running"), @@ -1169,8 +1180,8 @@ def task_stats(self, session: Session = NEW_SESSION): else: final_task_instance_query_result = running_task_instance_query_result.subquery("final_ti") - qry = ( - session.query( + qry = session.execute( + select( final_task_instance_query_result.c.dag_id, final_task_instance_query_result.c.state, final_task_instance_query_result.c.is_dag_running, @@ -1186,7 +1197,6 @@ def task_stats(self, session: Session = NEW_SESSION): final_task_instance_query_result.c.is_dag_running.desc(), ) ) - data = get_task_stats_from_query(qry) payload: dict[str, list[dict[str, Any]]] = collections.defaultdict(list) for dag_id in filter_dag_ids: @@ -1219,29 +1229,31 @@ def last_dagruns(self, session: Session = NEW_SESSION): return flask.json.jsonify({}) last_runs_subquery = ( - session.query( + select( DagRun.dag_id, sqla.func.max(DagRun.execution_date).label("max_execution_date"), ) .group_by(DagRun.dag_id) - .filter(DagRun.dag_id.in_(filter_dag_ids)) # Only include accessible/selected DAGs. + .where(DagRun.dag_id.in_(filter_dag_ids)) # Only include accessible/selected DAGs. 
.subquery("last_runs") ) - query = session.query( - DagRun.dag_id, - DagRun.start_date, - DagRun.end_date, - DagRun.state, - DagRun.execution_date, - DagRun.data_interval_start, - DagRun.data_interval_end, - ).join( - last_runs_subquery, - and_( - last_runs_subquery.c.dag_id == DagRun.dag_id, - last_runs_subquery.c.max_execution_date == DagRun.execution_date, - ), + query = session.execute( + select( + DagRun.dag_id, + DagRun.start_date, + DagRun.end_date, + DagRun.state, + DagRun.execution_date, + DagRun.data_interval_start, + DagRun.data_interval_end, + ).join( + last_runs_subquery, + and_( + last_runs_subquery.c.dag_id == DagRun.dag_id, + last_runs_subquery.c.max_execution_date == DagRun.execution_date, + ), + ) ) resp = { @@ -1340,19 +1352,18 @@ def dag_details(self, dag_id, session: Session = NEW_SESSION): title = "DAG Details" root = request.args.get("root", "") - states = ( - session.query(TaskInstance.state, sqla.func.count(TaskInstance.dag_id)) - .filter(TaskInstance.dag_id == dag_id) + states = session.execute( + select(TaskInstance.state, sqla.func.count(TaskInstance.dag_id)) + .where(TaskInstance.dag_id == dag_id) .group_by(TaskInstance.state) - .all() - ) + ).all() active_runs = models.DagRun.find(dag_id=dag_id, state=DagRunState.RUNNING, external_trigger=False) - tags = session.query(models.DagTag).filter(models.DagTag.dag_id == dag_id).all() + tags = session.scalars(select(models.DagTag).where(models.DagTag.dag_id == dag_id)).all() # TODO: convert this to a relationship - owner_links = session.query(DagOwnerAttributes).filter_by(dag_id=dag_id).all() + owner_links = session.scalars(select(DagOwnerAttributes).filter_by(dag_id=dag_id)).all() attrs_to_avoid = [ "schedule_datasets", @@ -1617,18 +1628,17 @@ def get_logs_with_metadata(self, session: Session = NEW_SESSION): "metadata": {"end_of_log": True}, } - ti = ( - session.query(models.TaskInstance) - .filter( + ti = session.scalar( + select(models.TaskInstance) + .where( TaskInstance.task_id == task_id, TaskInstance.dag_id == dag_id, TaskInstance.execution_date == execution_date, TaskInstance.map_index == map_index, ) .join(TaskInstance.dag_run) - .options(joinedload("trigger")) - .options(joinedload("trigger.triggerer_job")) - .first() + .options(joinedload(TaskInstance.trigger).joinedload(Trigger.triggerer_job)) + .limit(1) ) if ti is None: @@ -1686,10 +1696,10 @@ def log(self, session: Session = NEW_SESSION): form = DateTimeForm(data={"execution_date": dttm}) dag_model = DagModel.get_dagmodel(dag_id) - ti = ( - session.query(models.TaskInstance) + ti = session.scalar( + select(models.TaskInstance) .filter_by(dag_id=dag_id, task_id=task_id, execution_date=dttm, map_index=map_index) - .first() + .limit(1) ) num_logs = 0 @@ -1730,10 +1740,10 @@ def redirect_to_external_log(self, session: Session = NEW_SESSION): map_index = request.args.get("map_index", -1, type=int) try_number = request.args.get("try_number", 1) - ti = ( - session.query(models.TaskInstance) + ti = session.scalar( + select(models.TaskInstance) .filter_by(dag_id=dag_id, task_id=task_id, execution_date=dttm, map_index=map_index) - .first() + .limit(1) ) if not ti: @@ -1775,8 +1785,8 @@ def task(self, session: Session = NEW_SESSION): task = copy.copy(dag.get_task(task_id)) task.resolve_template_files() - ti: TaskInstance | None = ( - session.query(TaskInstance) + ti: TaskInstance | None = session.scalar( + select(TaskInstance) .options( # HACK: Eager-load relationships.
This is needed because # multiple properties mis-use provide_session() that destroys @@ -1785,7 +1795,6 @@ def task(self, session: Session = NEW_SESSION): joinedload(TaskInstance.trigger, innerjoin=False), ) .filter_by(execution_date=dttm, dag_id=dag_id, task_id=task_id, map_index=map_index) - .one_or_none() ) if ti is None: ti_attrs: list[tuple[str, Any]] | None = None @@ -1908,7 +1917,7 @@ def xcom(self, session: Session = NEW_SESSION): flash(f"Task [{dag_id}.{task_id}] doesn't seem to exist at the moment", "error") return redirect(url_for("Airflow.index")) - xcom_query = session.query(XCom.key, XCom.value).filter( + xcom_query = session.query(XCom.key, XCom.value).where( XCom.dag_id == dag_id, XCom.task_id == task_id, XCom.execution_date == dttm, @@ -1982,7 +1991,7 @@ def trigger(self, dag_id: str, session: Session = NEW_SESSION): request_execution_date = request.values.get("execution_date", default=timezone.utcnow().isoformat()) is_dag_run_conf_overrides_params = conf.getboolean("core", "dag_run_conf_overrides_params") dag = get_airflow_app().dag_bag.get_dag(dag_id) - dag_orm: DagModel = session.query(DagModel).filter(DagModel.dag_id == dag_id).first() + dag_orm: DagModel = session.query(DagModel).where(DagModel.dag_id == dag_id).first() # Prepare form fields with param struct details to render a proper form with schema information form_fields = {} @@ -2024,7 +2033,7 @@ def trigger(self, dag_id: str, session: Session = NEW_SESSION): session.query( DagRun.conf, func.max(DagRun.run_id).label("run_id"), func.max(DagRun.execution_date) ) - .filter( + .where( DagRun.dag_id == dag_id, DagRun.run_type == DagRunType.MANUAL, DagRun.conf.isnot(None), @@ -2297,15 +2306,15 @@ def clear(self, *, session: Session = NEW_SESSION): # Lock the related dag runs to prevent from possible dead lock. 
# https://github.com/apache/airflow/pull/26658 - dag_runs_query = session.query(DagRun.id).filter(DagRun.dag_id == dag_id).with_for_update() + dag_runs_query = session.query(DagRun.id).where(DagRun.dag_id == dag_id).with_for_update() if start_date is None and end_date is None: - dag_runs_query = dag_runs_query.filter(DagRun.execution_date == start_date) + dag_runs_query = dag_runs_query.where(DagRun.execution_date == start_date) else: if start_date is not None: - dag_runs_query = dag_runs_query.filter(DagRun.execution_date >= start_date) + dag_runs_query = dag_runs_query.where(DagRun.execution_date >= start_date) if end_date is not None: - dag_runs_query = dag_runs_query.filter(DagRun.execution_date <= end_date) + dag_runs_query = dag_runs_query.where(DagRun.execution_date <= end_date) locked_dag_run_ids = dag_runs_query.all() elif task_id: @@ -2396,8 +2405,8 @@ def blocked(self, session: Session = NEW_SESSION): dags = ( session.query(DagRun.dag_id, sqla.func.count(DagRun.id)) - .filter(DagRun.state == DagRunState.RUNNING) - .filter(DagRun.dag_id.in_(filter_dag_ids)) + .where(DagRun.state == DagRunState.RUNNING) + .where(DagRun.dag_id.in_(filter_dag_ids)) .group_by(DagRun.dag_id) ) @@ -2478,7 +2487,7 @@ def _mark_dagrun_state_as_queued( # Identify tasks that will be queued up to run when confirmed all_task_ids = [task.task_id for task in dag.tasks] - existing_tis = session.query(TaskInstance.task_id).filter( + existing_tis = session.query(TaskInstance.task_id).where( TaskInstance.dag_id == dag.dag_id, TaskInstance.run_id == dag_run_id, ) @@ -2970,7 +2979,7 @@ def _convert_to_date(session, column): func.max(DagRun.data_interval_end).label("data_interval_end"), func.count("*").label("count"), ) - .filter(DagRun.dag_id == dag.dag_id) + .where(DagRun.dag_id == dag.dag_id) .group_by(_convert_to_date(session, DagRun.execution_date), DagRun.state) .order_by(_convert_to_date(session, DagRun.execution_date).asc()) .all() @@ -3254,7 +3263,8 @@ def duration(self, dag_id: str, session: Session = NEW_SESSION): ) ) if dag.partial: - ti_fails = ti_fails.filter(TaskFail.task_id.in_([t.task_id for t in dag.tasks])) + ti_fails = ti_fails.where(TaskFail.task_id.in_([t.task_id for t in dag.tasks])) + ti_fails = session.scalars(ti_fails) fails_totals: dict[tuple[str, str, str], int] = defaultdict(int) for failed_task_instance in ti_fails: dict_key = ( @@ -3595,7 +3605,7 @@ def gantt(self, dag_id: str, session: Session = NEW_SESSION): form = DateTimeWithNumRunsWithDagRunsForm(data=dt_nr_dr_data) form.execution_date.choices = dt_nr_dr_data["dr_choices"] - tis = ( + tis = session.scalars( select(TaskInstance) .where( TaskInstance.dag_id == dag_id, @@ -3608,8 +3618,8 @@ def gantt(self, dag_id: str, session: Session = NEW_SESSION): ti_fails = select(TaskFail).filter_by(run_id=dag_run_id, dag_id=dag_id) if dag.partial: - ti_fails = ti_fails.filter(TaskFail.task_id.in_([t.task_id for t in dag.tasks])) - + ti_fails = ti_fails.where(TaskFail.task_id.in_([t.task_id for t in dag.tasks])) + ti_fails = session.scalars(ti_fails) tasks = [] for ti in tis: if not dag.has_task(ti.task_id): @@ -3831,23 +3841,23 @@ def grid_data(self): with create_session() as session: query = select(DagRun).where(DagRun.dag_id == dag.dag_id, DagRun.execution_date <= base_date) - run_type = request.args.get("run_type") - if run_type: - query = query.filter(DagRun.run_type == run_type) + run_type = request.args.get("run_type") + if run_type: + query = query.where(DagRun.run_type == run_type) - run_state = request.args.get("run_state") - if 
run_state: - query = query.filter(DagRun.state == run_state) + run_state = request.args.get("run_state") + if run_state: + query = query.where(DagRun.state == run_state) - dag_runs = wwwutils.sorted_dag_runs(query, ordering=dag.timetable.run_ordering, limit=num_runs) - encoded_runs = [ - wwwutils.encode_dag_run(dr, json_encoder=utils_json.WebEncoder) for dr in dag_runs - ] - data = { - "groups": dag_to_grid(dag, dag_runs, session), - "dag_runs": encoded_runs, - "ordering": dag.timetable.run_ordering, - } + dag_runs = wwwutils.sorted_dag_runs( + query, ordering=dag.timetable.run_ordering, limit=num_runs, session=session + ) + encoded_runs = [wwwutils.encode_dag_run(dr, json_encoder=utils_json.WebEncoder) for dr in dag_runs] + data = { + "groups": dag_to_grid(dag, dag_runs, session), + "dag_runs": encoded_runs, + "ordering": dag.timetable.run_ordering, + } # avoid spaces to reduce payload size return ( htmlsafe_json_dumps(data, separators=(",", ":"), dumps=flask.json.dumps), @@ -3866,37 +3876,34 @@ def historical_metrics_data(self): end_date = _safe_parse_datetime(request.args.get("end_date")) with create_session() as session: # DagRuns - dag_runs_type = ( - session.query(DagRun.run_type, func.count(DagRun.run_id)) - .filter( + dag_runs_type = session.execute( + select(DagRun.run_type, func.count(DagRun.run_id)) + .where( DagRun.start_date >= start_date, or_(DagRun.end_date.is_(None), DagRun.end_date <= end_date), ) .group_by(DagRun.run_type) - .all() - ) + ).all() - dag_run_states = ( - session.query(DagRun.state, func.count(DagRun.run_id)) - .filter( + dag_run_states = session.execute( + select(DagRun.state, func.count(DagRun.run_id)) + .where( DagRun.start_date >= start_date, or_(DagRun.end_date.is_(None), DagRun.end_date <= end_date), ) .group_by(DagRun.state) - .all() - ) + ).all() # TaskInstances - task_instance_states = ( - session.query(TaskInstance.state, func.count(TaskInstance.run_id)) + task_instance_states = session.execute( + select(TaskInstance.state, func.count(TaskInstance.run_id)) .join(TaskInstance.dag_run) - .filter( + .where( DagRun.start_date >= start_date, or_(DagRun.end_date.is_(None), DagRun.end_date <= end_date), ) .group_by(TaskInstance.state) - .all() - ) + ).all() data = { "dag_run_types": { @@ -3949,7 +3956,7 @@ def next_run_datasets(self, dag_id): DatasetEvent.dataset_id == DatasetModel.id, isouter=True, ) - .filter(DagScheduleDatasetReference.dag_id == dag_id, ~DatasetModel.is_orphaned) + .where(DagScheduleDatasetReference.dag_id == dag_id, ~DatasetModel.is_orphaned) .group_by(DatasetModel.id, DatasetModel.uri) .order_by(DatasetModel.uri) ] @@ -4079,8 +4086,8 @@ def datasets_summary(self): if updated_before: filters.append(DatasetEvent.timestamp <= updated_before) - query = query.filter(*filters).offset(offset).limit(limit) - count_query = count_query.filter(*filters) + query = query.where(*filters).offset(offset).limit(limit) + count_query = count_query.where(*filters) datasets = [dict(dataset) for dataset in query] data = {"datasets": datasets, "total_entries": count_query.scalar()} @@ -4129,13 +4136,13 @@ def audit_log(self, dag_id: str, session: Session = NEW_SESSION): included_events_raw = conf.get("webserver", "audit_view_included_events", fallback=None) excluded_events_raw = conf.get("webserver", "audit_view_excluded_events", fallback=None) - query = session.query(Log).filter(Log.dag_id == dag_id) + query = session.query(Log).where(Log.dag_id == dag_id) if included_events_raw: included_events = {event.strip() for event in included_events_raw.split(",")} 
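The remaining conversions keep alternating between `session.execute()` and `session.scalars()`, and the choice matters: `execute()` yields tuple-like `Row` objects, which is what the multi-column aggregates above (run type/state counts) want, while `scalars()` unwraps the first element of each row, which is what whole-entity queries want. A small sketch of the difference, same assumptions as the earlier examples:

    from sqlalchemy import Column, Integer, String, create_engine, func, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Run(Base):
        __tablename__ = "run"
        id = Column(Integer, primary_key=True)
        state = Column(String(20))

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([Run(state="success"), Run(state="failed"), Run(state="success")])
        session.commit()

        # execute(): Row tuples, e.g. the state/count aggregations above.
        counts = dict(
            session.execute(select(Run.state, func.count(Run.id)).group_by(Run.state)).all()
        )
        # scalars(): ORM entities, the equivalent of session.query(Run).all().
        runs = session.scalars(select(Run)).all()
        print(counts, len(runs))  # e.g. {'failed': 1, 'success': 2} 3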
- query = query.filter(Log.event.in_(included_events)) + query = query.where(Log.event.in_(included_events)) elif excluded_events_raw: excluded_events = {event.strip() for event in excluded_events_raw.split(",")} - query = query.filter(Log.event.notin_(excluded_events)) + query = query.where(Log.event.notin_(excluded_events)) current_page = request.args.get("page", default=0, type=int) arg_sorting_key = request.args.get("sorting_key", "dttm") @@ -4275,7 +4282,7 @@ def apply(self, query, func): if get_airflow_app().appbuilder.sm.has_all_dags_access(g.user): return query filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) - return query.filter(self.model.dag_id.in_(filter_dag_ids)) + return query.where(self.model.dag_id.in_(filter_dag_ids)) class AirflowModelView(ModelView): @@ -4720,7 +4727,7 @@ def action_mulduplicate(self, connections, session: Session = NEW_SESSION): potential_connection_ids = [f"{base_conn_id}_copy{i}" for i in range(1, 11)] - query = session.query(Connection.conn_id).filter(Connection.conn_id.in_(potential_connection_ids)) + query = session.query(Connection.conn_id).where(Connection.conn_id.in_(potential_connection_ids)) found_conn_id_set = {conn_id for conn_id, in query} @@ -5392,7 +5399,7 @@ def _set_dag_runs_to_active_state(self, drs: list[DagRun], state: str, session: """This routine only supports Running and Queued state.""" try: count = 0 - for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for dagrun in drs)): + for dr in session.query(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs)): count += 1 if state == State.RUNNING: dr.start_date = timezone.utcnow() @@ -5418,7 +5425,7 @@ def action_set_failed(self, drs: list[DagRun], session: Session = NEW_SESSION): try: count = 0 altered_tis = [] - for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for dagrun in drs)): + for dr in session.query(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs)): count += 1 altered_tis += set_dag_run_state_to_failed( dag=get_airflow_app().dag_bag.get_dag(dr.dag_id), @@ -5446,7 +5453,7 @@ def action_set_success(self, drs: list[DagRun], session: Session = NEW_SESSION): try: count = 0 altered_tis = [] - for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for dagrun in drs)): + for dr in session.query(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs)): count += 1 altered_tis += set_dag_run_state_to_success( dag=get_airflow_app().dag_bag.get_dag(dr.dag_id), @@ -5470,7 +5477,7 @@ def action_clear(self, drs: list[DagRun], session: Session = NEW_SESSION): count = 0 cleared_ti_count = 0 dag_to_tis: dict[DAG, list[TaskInstance]] = {} - for dr in session.query(DagRun).filter(DagRun.id.in_(dagrun.id for dagrun in drs)): + for dr in session.query(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs)): count += 1 dag = get_airflow_app().dag_bag.get_dag(dr.dag_id) tis_to_clear = dag_to_tis.setdefault(dag, []) @@ -5889,7 +5896,7 @@ def autocomplete(self, session: Session = NEW_SESSION): dag_ids_query = session.query( sqla.literal("dag").label("type"), DagModel.dag_id.label("name"), - ).filter(~DagModel.is_subdag, DagModel.is_active, DagModel.dag_id.ilike(f"%{query}%")) + ).where(~DagModel.is_subdag, DagModel.is_active, DagModel.dag_id.ilike(f"%{query}%")) owners_query = ( session.query( @@ -5897,22 +5904,22 @@ def autocomplete(self, session: Session = NEW_SESSION): DagModel.owners.label("name"), ) .distinct() - .filter(~DagModel.is_subdag, DagModel.is_active, DagModel.owners.ilike(f"%{query}%")) + 
.where(~DagModel.is_subdag, DagModel.is_active, DagModel.owners.ilike(f"%{query}%")) ) # Hide DAGs if not showing status: "all" status = flask_session.get(FILTER_STATUS_COOKIE) if status == "active": - dag_ids_query = dag_ids_query.filter(~DagModel.is_paused) - owners_query = owners_query.filter(~DagModel.is_paused) + dag_ids_query = dag_ids_query.where(~DagModel.is_paused) + owners_query = owners_query.where(~DagModel.is_paused) elif status == "paused": - dag_ids_query = dag_ids_query.filter(DagModel.is_paused) - owners_query = owners_query.filter(DagModel.is_paused) + dag_ids_query = dag_ids_query.where(DagModel.is_paused) + owners_query = owners_query.where(DagModel.is_paused) filter_dag_ids = get_airflow_app().appbuilder.sm.get_accessible_dag_ids(g.user) - dag_ids_query = dag_ids_query.filter(DagModel.dag_id.in_(filter_dag_ids)) - owners_query = owners_query.filter(DagModel.dag_id.in_(filter_dag_ids)) + dag_ids_query = dag_ids_query.where(DagModel.dag_id.in_(filter_dag_ids)) + owners_query = owners_query.where(DagModel.dag_id.in_(filter_dag_ids)) payload = [row._asdict() for row in dag_ids_query.union(owners_query).order_by("name").limit(10)] return flask.json.jsonify(payload) From 57c61b801d46600962840fd6385154450590adaa Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Sun, 2 Jul 2023 18:19:06 +0530 Subject: [PATCH 4/5] Replace remaining session.query code --- airflow/www/views.py | 139 +++++++++++++++++++++++-------------------- 1 file changed, 75 insertions(+), 64 deletions(-) diff --git a/airflow/www/views.py b/airflow/www/views.py index 6762a27339388..e285418028c66 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -231,16 +231,15 @@ def get_date_time_num_runs_dag_runs_form_data(www_request, session, dag): # loaded and the actual requested run would be excluded by the limit(). Once # the user has changed base date to be anything else we want to use that instead. 
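The first hunk of this patch, just below, rewrites the recent-runs lookup with the same idiom `sorted_dag_runs()` uses in PATCH 3: order descending, `limit(n)`, fetch, then reverse the list in Python so it reads oldest to newest. A standalone sketch (toy model, illustrative names, SQLAlchemy 1.4.24 or newer):

    from sqlalchemy import Column, Integer, create_engine, desc, select
    from sqlalchemy.orm import Session, declarative_base

    Base = declarative_base()

    class Run(Base):
        __tablename__ = "run"
        id = Column(Integer, primary_key=True)

    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)

    with Session(engine) as session:
        session.add_all([Run() for _ in range(5)])
        session.commit()

        # Newest three rows, flipped back to ascending order for display.
        runs = session.scalars(select(Run).order_by(desc(Run.id)).limit(3)).all()
        runs.reverse()
        print([r.id for r in runs])  # prints: [3, 4, 5]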
query_date = base_date - if date_time < base_date and date_time + datetime.timedelta(seconds=1) >= base_date: + if date_time < base_date <= date_time + datetime.timedelta(seconds=1): query_date = date_time - drs = ( - session.query(DagRun) - .filter(DagRun.dag_id == dag.dag_id, DagRun.execution_date <= query_date) + drs = session.scalars( + select(DagRun) + .where(DagRun.dag_id == dag.dag_id, DagRun.execution_date <= query_date) .order_by(desc(DagRun.execution_date)) .limit(num_runs) - .all() - ) + ).all() dr_choices = [] dr_state = None for dr in drs: @@ -1911,17 +1910,19 @@ def xcom(self, session: Session = NEW_SESSION): form = DateTimeForm(data={"execution_date": dttm}) root = request.args.get("root", "") dag = DagModel.get_dagmodel(dag_id) - ti = session.query(TaskInstance).filter_by(dag_id=dag_id, task_id=task_id).first() + ti = session.scalar(select(TaskInstance).filter_by(dag_id=dag_id, task_id=task_id).limit(1)) if not ti: flash(f"Task [{dag_id}.{task_id}] doesn't seem to exist at the moment", "error") return redirect(url_for("Airflow.index")) - xcom_query = session.query(XCom.key, XCom.value).where( - XCom.dag_id == dag_id, - XCom.task_id == task_id, - XCom.execution_date == dttm, - XCom.map_index == map_index, + xcom_query = session.execute( + select(XCom.key, XCom.value).where( + XCom.dag_id == dag_id, + XCom.task_id == task_id, + XCom.execution_date == dttm, + XCom.map_index == map_index, + ) ) attributes = [(k, v) for k, v in xcom_query if not k.startswith("_")] @@ -1991,7 +1992,7 @@ def trigger(self, dag_id: str, session: Session = NEW_SESSION): request_execution_date = request.values.get("execution_date", default=timezone.utcnow().isoformat()) is_dag_run_conf_overrides_params = conf.getboolean("core", "dag_run_conf_overrides_params") dag = get_airflow_app().dag_bag.get_dag(dag_id) - dag_orm: DagModel = session.query(DagModel).where(DagModel.dag_id == dag_id).first() + dag_orm: DagModel = session.scalar(select(DagModel).where(DagModel.dag_id == dag_id).limit(1)) # Prepare form fields with param struct details to render a proper form with schema information form_fields = {} @@ -2029,10 +2030,8 @@ def trigger(self, dag_id: str, session: Session = NEW_SESSION): flash(f"Cannot create dagruns because the dag {dag_id} has import errors", "error") return redirect(origin) - recent_runs = ( - session.query( - DagRun.conf, func.max(DagRun.run_id).label("run_id"), func.max(DagRun.execution_date) - ) + recent_runs = session.execute( + select(DagRun.conf, func.max(DagRun.run_id).label("run_id"), func.max(DagRun.execution_date)) .where( DagRun.dag_id == dag_id, DagRun.run_type == DagRunType.MANUAL, @@ -2306,7 +2305,9 @@ def clear(self, *, session: Session = NEW_SESSION): # Lock the related dag runs to prevent from possible dead lock. 
# https://github.com/apache/airflow/pull/26658 - dag_runs_query = session.query(DagRun.id).where(DagRun.dag_id == dag_id).with_for_update() + dag_runs_query = ( + select(DagRun.id).where(DagRun.dag_id == dag_id).with_for_update() + ) if start_date is None and end_date is None: dag_runs_query = dag_runs_query.where(DagRun.execution_date == start_date) else: @@ -2403,8 +2404,8 @@ def blocked(self, session: Session = NEW_SESSION): if not filter_dag_ids: return flask.json.jsonify([]) - dags = ( - session.query(DagRun.dag_id, sqla.func.count(DagRun.id)) + dags = session.execute( + select(DagRun.dag_id, sqla.func.count(DagRun.id)) .where(DagRun.state == DagRunState.RUNNING) .where(DagRun.dag_id.in_(filter_dag_ids)) .group_by(DagRun.dag_id) ) @@ -2487,9 +2488,11 @@ def _mark_dagrun_state_as_queued( # Identify tasks that will be queued up to run when confirmed all_task_ids = [task.task_id for task in dag.tasks] - existing_tis = session.query(TaskInstance.task_id).where( - TaskInstance.dag_id == dag.dag_id, - TaskInstance.run_id == dag_run_id, + existing_tis = session.execute( + select(TaskInstance.task_id).where( + TaskInstance.dag_id == dag.dag_id, + TaskInstance.run_id == dag_run_id, + ) ) completed_tis_ids = [task_id for task_id, in existing_tis] @@ -2971,9 +2974,9 @@ def _convert_to_date(session, column): if root: dag = dag.partial_subset(task_ids_or_regex=root, include_downstream=False, include_upstream=True) - dag_states = ( - session.query( - (_convert_to_date(session, DagRun.execution_date)).label("date"), + dag_states = session.execute( + select( + _convert_to_date(session, DagRun.execution_date).label("date"), DagRun.state, func.max(DagRun.data_interval_start).label("data_interval_start"), func.max(DagRun.data_interval_end).label("data_interval_end"), @@ -2982,8 +2985,7 @@ def _convert_to_date(session, column): .where(DagRun.dag_id == dag.dag_id) .group_by(_convert_to_date(session, DagRun.execution_date), DagRun.state) .order_by(_convert_to_date(session, DagRun.execution_date).asc()) - .all() - ) + ).all() data_dag_states = [ { @@ -3937,28 +3939,32 @@ def next_run_datasets(self, dag_id): with create_session() as session: data = [ dict(info) - for info in session.query( - DatasetModel.id, - DatasetModel.uri, - func.max(DatasetEvent.timestamp).label("lastUpdate"), - ) - .join(DagScheduleDatasetReference, DagScheduleDatasetReference.dataset_id == DatasetModel.id) - .join( - DatasetDagRunQueue, - and_( - DatasetDagRunQueue.dataset_id == DatasetModel.id, - DatasetDagRunQueue.target_dag_id == DagScheduleDatasetReference.dag_id, - ), - isouter=True, - ) - .join( - DatasetEvent, - DatasetEvent.dataset_id == DatasetModel.id, - isouter=True, + for info in session.execute( + select( + DatasetModel.id, + DatasetModel.uri, + func.max(DatasetEvent.timestamp).label("lastUpdate"), + ) + .join( + DagScheduleDatasetReference, DagScheduleDatasetReference.dataset_id == DatasetModel.id + ) + .join( + DatasetDagRunQueue, + and_( + DatasetDagRunQueue.dataset_id == DatasetModel.id, + DatasetDagRunQueue.target_dag_id == DagScheduleDatasetReference.dag_id, + ), + isouter=True, + ) + .join( + DatasetEvent, + DatasetEvent.dataset_id == DatasetModel.id, + isouter=True, + ) + .where(DagScheduleDatasetReference.dag_id == dag_id, ~DatasetModel.is_orphaned) + .group_by(DatasetModel.id, DatasetModel.uri) + .order_by(DatasetModel.uri) ) - .where(DagScheduleDatasetReference.dag_id == dag_id, ~DatasetModel.is_orphaned) - .group_by(DatasetModel.id, DatasetModel.uri) - .order_by(DatasetModel.uri) ] return (
htmlsafe_json_dumps(data, separators=(",", ":"), dumps=flask.json.dumps), @@ -4056,12 +4062,12 @@ def datasets_summary(self): if session.bind.dialect.name == "postgresql": order_by = (order_by[0].nulls_first(), *order_by[1:]) - count_query = session.query(func.count(DatasetModel.id)) + count_query = select(func.count(DatasetModel.id)) has_event_filters = bool(updated_before or updated_after) query = ( - session.query( + select( DatasetModel.id, DatasetModel.uri, func.max(DatasetEvent.timestamp).label("last_dataset_update"), @@ -4089,8 +4095,9 @@ def datasets_summary(self): query = query.where(*filters).offset(offset).limit(limit) count_query = count_query.where(*filters) + query = session.execute(query) datasets = [dict(dataset) for dataset in query] - data = {"datasets": datasets, "total_entries": count_query.scalar()} + data = {"datasets": datasets, "total_entries": session.scalar(count_query)} return ( htmlsafe_json_dumps(data, separators=(",", ":"), cls=utils_json.WebEncoder), @@ -4136,7 +4143,7 @@ def audit_log(self, dag_id: str, session: Session = NEW_SESSION): included_events_raw = conf.get("webserver", "audit_view_included_events", fallback=None) excluded_events_raw = conf.get("webserver", "audit_view_excluded_events", fallback=None) - query = session.query(Log).where(Log.dag_id == dag_id) + query = select(Log).where(Log.dag_id == dag_id) if included_events_raw: included_events = {event.strip() for event in included_events_raw.split(",")} query = query.where(Log.event.in_(included_events)) @@ -4149,7 +4156,7 @@ def audit_log(self, dag_id: str, session: Session = NEW_SESSION): arg_sorting_direction = request.args.get("sorting_direction", default="desc") logs_per_page = PAGE_SIZE - audit_logs_count = query.count() + audit_logs_count = session.scalar(select(func.count()).select_from(query.subquery())) num_of_pages = int(math.ceil(audit_logs_count / float(logs_per_page))) start = current_page * logs_per_page @@ -4161,7 +4168,7 @@ def audit_log(self, dag_id: str, session: Session = NEW_SESSION): sort_column = sort_column.desc() query = query.order_by(sort_column) - dag_audit_logs = query.offset(start).limit(logs_per_page).all() + dag_audit_logs = session.scalars(query.offset(start).limit(logs_per_page)).all() return self.render_template( "airflow/dag_audit_log.html", dag=dag, @@ -4727,7 +4734,9 @@ def action_mulduplicate(self, connections, session: Session = NEW_SESSION): potential_connection_ids = [f"{base_conn_id}_copy{i}" for i in range(1, 11)] - query = session.query(Connection.conn_id).where(Connection.conn_id.in_(potential_connection_ids)) + query = session.execute( + select(Connection.conn_id).where(Connection.conn_id.in_(potential_connection_ids)) + ) found_conn_id_set = {conn_id for conn_id, in query} @@ -5399,7 +5408,7 @@ def _set_dag_runs_to_active_state(self, drs: list[DagRun], state: str, session: """This routine only supports Running and Queued state.""" try: count = 0 - for dr in session.query(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs)): + for dr in session.scalars(select(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs))): count += 1 if state == State.RUNNING: dr.start_date = timezone.utcnow() @@ -5425,7 +5434,7 @@ def action_set_failed(self, drs: list[DagRun], session: Session = NEW_SESSION): try: count = 0 altered_tis = [] - for dr in session.query(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs)): + for dr in session.scalars(select(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs))): count += 1 altered_tis += set_dag_run_state_to_failed(
dag=get_airflow_app().dag_bag.get_dag(dr.dag_id), @@ -5453,7 +5462,7 @@ def action_set_success(self, drs: list[DagRun], session: Session = NEW_SESSION): try: count = 0 altered_tis = [] - for dr in session.query(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs)): + for dr in session.scalars(select(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs))): count += 1 altered_tis += set_dag_run_state_to_success( dag=get_airflow_app().dag_bag.get_dag(dr.dag_id), @@ -5477,7 +5486,7 @@ def action_clear(self, drs: list[DagRun], session: Session = NEW_SESSION): count = 0 cleared_ti_count = 0 dag_to_tis: dict[DAG, list[TaskInstance]] = {} - for dr in session.query(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs)): + for dr in session.scalars(select(DagRun).where(DagRun.id.in_(dagrun.id for dagrun in drs))): count += 1 dag = get_airflow_app().dag_bag.get_dag(dr.dag_id) tis_to_clear = dag_to_tis.setdefault(dag, []) @@ -5893,13 +5902,13 @@ def autocomplete(self, session: Session = NEW_SESSION): return flask.json.jsonify([]) # Provide suggestions of dag_ids and owners - dag_ids_query = session.query( + dag_ids_query = select( sqla.literal("dag").label("type"), DagModel.dag_id.label("name"), ).where(~DagModel.is_subdag, DagModel.is_active, DagModel.dag_id.ilike(f"%{query}%")) owners_query = ( - session.query( + select( sqla.literal("owner").label("type"), DagModel.owners.label("name"), ) @@ -5920,8 +5929,10 @@ def autocomplete(self, session: Session = NEW_SESSION): dag_ids_query = dag_ids_query.where(DagModel.dag_id.in_(filter_dag_ids)) owners_query = owners_query.where(DagModel.dag_id.in_(filter_dag_ids)) - - payload = [row._asdict() for row in dag_ids_query.union(owners_query).order_by("name").limit(10)] + payload = [ + row._asdict() + for row in session.execute(dag_ids_query.union(owners_query).order_by("name").limit(10)) + ] return flask.json.jsonify(payload) From c56bafc4f4ef7f68819b4901081cb2ad1112e461 Mon Sep 17 00:00:00 2001 From: Phani Kumar Date: Mon, 3 Jul 2023 15:52:17 +0530 Subject: [PATCH 5/5] Apply review suggestions --- airflow/utils/db.py | 6 +++--- airflow/www/views.py | 5 ++--- 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/airflow/utils/db.py b/airflow/utils/db.py index 6a1f1034e942f..46ddcbfb3453f 100644 --- a/airflow/utils/db.py +++ b/airflow/utils/db.py @@ -1102,7 +1102,7 @@ def check_conn_type_null(session: Session) -> Iterable[str]: n_nulls = [] try: - n_nulls = session.execute(select(Connection.conn_id).filter(Connection.conn_type.is_(None))).all() + n_nulls = session.scalars(select(Connection.conn_id).where(Connection.conn_type.is_(None))).all() except (exc.OperationalError, exc.ProgrammingError, exc.InternalError): # fallback if tables hasn't been created yet session.rollback() @@ -1144,7 +1144,7 @@ def check_run_id_null(session: Session) -> Iterable[str]: dagrun_table.c.run_id.is_(None), dagrun_table.c.execution_date.is_(None), ) - invalid_dagrun_count = session.scalar(select(func.count(dagrun_table.c.id)).filter(invalid_dagrun_filter)) + invalid_dagrun_count = session.scalar(select(func.count(dagrun_table.c.id)).where(invalid_dagrun_filter)) if invalid_dagrun_count > 0: dagrun_dangling_table_name = _format_airflow_moved_table_name(dagrun_table.name, "2.2", "dangling") if dagrun_dangling_table_name in inspect(session.get_bind()).get_table_names(): @@ -1335,7 +1335,7 @@ def _move_duplicate_data_to_new_table( dialect_name = bind.dialect.name query = ( - session.select(*[getattr(source_table.c, x.name).label(str(x.name)) for x in 
source_table.columns]) + select(*[getattr(source_table.c, x.name).label(str(x.name)) for x in source_table.columns]) .select_from(source_table) .join(subquery, and_(*[getattr(source_table.c, x) == getattr(subquery.c, x) for x in uniqueness])) ) diff --git a/airflow/www/views.py b/airflow/www/views.py index e285418028c66..d791eefc37c76 100644 --- a/airflow/www/views.py +++ b/airflow/www/views.py @@ -941,7 +941,6 @@ def _iter_parsed_moved_data_table_names(): select(Log) .where(Log.event == "robots") .where(Log.dttm > (utcnow() - datetime.timedelta(days=7))) - # .count() ) robots_file_access_count = session.scalar( select(func.count()).select_from(robots_file_access_count.subquery()) @@ -4734,11 +4733,11 @@ def action_mulduplicate(self, connections, session: Session = NEW_SESSION): potential_connection_ids = [f"{base_conn_id}_copy{i}" for i in range(1, 11)] - query = session.execute( + query = session.scalars( select(Connection.conn_id).where(Connection.conn_id.in_(potential_connection_ids)) ) - found_conn_id_set = {conn_id for conn_id, in query} + found_conn_id_set = {conn_id for conn_id in query} possible_conn_id_iter = ( connection_id