diff --git a/airflow/api_connexion/endpoints/xcom_endpoint.py b/airflow/api_connexion/endpoints/xcom_endpoint.py
index 2592611abaa45..cd317ad819979 100644
--- a/airflow/api_connexion/endpoints/xcom_endpoint.py
+++ b/airflow/api_connexion/endpoints/xcom_endpoint.py
@@ -14,10 +14,16 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
+from flask import request
+from sqlalchemy import and_, func
+from sqlalchemy.orm.session import Session
 
-from sqlalchemy import and_
-
-from airflow.api_connexion.schemas.xcom_schema import xcom_collection_item_schema
+from airflow.api_connexion import parameters
+from airflow.api_connexion.exceptions import NotFound
+from airflow.api_connexion.schemas.xcom_schema import (
+    XComCollection, XComCollectionItemSchema, XComCollectionSchema, xcom_collection_item_schema,
+    xcom_collection_schema,
+)
 from airflow.models import DagRun as DR, XCom
 from airflow.utils.session import provide_session
 
@@ -29,15 +35,44 @@ def delete_xcom_entry():
     raise NotImplementedError("Not implemented yet.")
 
 
-def get_xcom_entries():
+@provide_session
+def get_xcom_entries(
+    dag_id: str,
+    dag_run_id: str,
+    task_id: str,
+    session: Session
+) -> XComCollectionSchema:
     """
     Get all XCom values
     """
-    raise NotImplementedError("Not implemented yet.")
+    offset = request.args.get(parameters.page_offset, 0)
+    limit = min(int(request.args.get(parameters.page_limit, 100)), 100)
+    query = session.query(XCom)
+    if dag_id != '~':
+        query = query.filter(XCom.dag_id == dag_id)
+        query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.execution_date == DR.execution_date))
+    else:
+        query.join(DR, XCom.execution_date == DR.execution_date)
+    if task_id != '~':
+        query = query.filter(XCom.task_id == task_id)
+    if dag_run_id != '~':
+        query = query.filter(DR.run_id == dag_run_id)
+    query = query.order_by(
+        XCom.execution_date, XCom.task_id, XCom.dag_id, XCom.key
+    )
+    total_entries = session.query(func.count(XCom.key)).scalar()
+    query = query.offset(offset).limit(limit)
+    return xcom_collection_schema.dump(XComCollection(xcom_entries=query.all(), total_entries=total_entries))
 
 
 @provide_session
-def get_xcom_entry(dag_id, task_id, dag_run_id, xcom_key, session):
+def get_xcom_entry(
+    dag_id: str,
+    task_id: str,
+    dag_run_id: str,
+    xcom_key: str,
+    session: Session
+) -> XComCollectionItemSchema:
     """
     Get an XCom entry
     """
@@ -48,10 +83,10 @@ def get_xcom_entry(dag_id, task_id, dag_run_id, xcom_key, session):
     query = query.join(DR, and_(XCom.dag_id == DR.dag_id, XCom.execution_date == DR.execution_date))
     query = query.filter(DR.run_id == dag_run_id)
 
-    q_object = query.one_or_none()
-    if not q_object:
-        raise Exception("Object Not found")
-    return xcom_collection_item_schema.dump(q_object)
+    query_object = query.one_or_none()
+    if not query_object:
+        raise NotFound("XCom entry not found")
+    return xcom_collection_item_schema.dump(query_object)
 
 
 def patch_xcom_entry():
diff --git a/airflow/api_connexion/openapi/v1.yaml b/airflow/api_connexion/openapi/v1.yaml
index 68f696570a854..dae24cf1c4971 100644
--- a/airflow/api_connexion/openapi/v1.yaml
+++ b/airflow/api_connexion/openapi/v1.yaml
@@ -2193,17 +2193,6 @@ components:
       required: true
       description: The Variable Key.
 
-    ExecutionDate:
-      in: path
-      name: execution_date
-      schema:
-        type: string
-        format: date-time
-      required: true
-      description: The date-time notation as defined by
-        [RFC 3339, section 5.6](https://tools.ietf.org/html/rfc3339#section-5.6),
-        E.G. `2017-07-21T17:32:28Z`
-
     # Logs
     FullContent:
       in: query
diff --git a/airflow/api_connexion/schemas/xcom_schema.py b/airflow/api_connexion/schemas/xcom_schema.py
index 58d6ec391ba8a..5adc36da34da6 100644
--- a/airflow/api_connexion/schemas/xcom_schema.py
+++ b/airflow/api_connexion/schemas/xcom_schema.py
@@ -14,54 +14,50 @@
 # KIND, either express or implied. See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from marshmallow import post_dump
+from typing import List, NamedTuple
+
+from marshmallow import Schema, fields
 from marshmallow_sqlalchemy import SQLAlchemySchema, auto_field
 
 from airflow.models import XCom
 
 
 class XComCollectionItemSchema(SQLAlchemySchema):
+    """
+    Schema for a xcom item
+    """
 
     class Meta:
         """ Meta """
         model = XCom
 
-    COLLECTION_NAME = 'xcom_entries'
-    FIELDS_FROM_NONE_TO_EMPTY_STRING = ['key', 'task_id', 'dag_id']
-
     key = auto_field()
     timestamp = auto_field()
     execution_date = auto_field()
     task_id = auto_field()
     dag_id = auto_field()
 
-    @post_dump(pass_many=True)
-    def wrap_with_envelope(self, data, many, **kwargs):
-        """
-        :param data: Deserialized data
-        :param many: Collection or an item
-        """
-        if many:
-            data = self._process_list_data(data)
-            return {self.COLLECTION_NAME: data, 'total_entries': len(data)}
-        data = self._process_data(data)
-        return data
 
-    def _process_list_data(self, data):
-        return [self._process_data(x) for x in data]
+class XComSchema(XComCollectionItemSchema):
+    """
+    XCom schema
+    """
 
-    def _process_data(self, data):
-        for key in self.FIELDS_FROM_NONE_TO_EMPTY_STRING:
-            if not data[key]:
-                data.update({key: ''})
-        return data
+    value = auto_field()
 
 
-class XComSchema(XComCollectionItemSchema):
+class XComCollection(NamedTuple):
+    """ List of XComs with meta"""
+    xcom_entries: List[XCom]
+    total_entries: int
 
-    value = auto_field()
+
+class XComCollectionSchema(Schema):
+    """ XCom Collection Schema"""
+    xcom_entries = fields.List(fields.Nested(XComCollectionItemSchema))
+    total_entries = fields.Int()
 
 
-xcom_schema = XComSchema()
-xcom_collection_item_schema = XComCollectionItemSchema()
-xcom_collection_schema = XComCollectionItemSchema(many=True)
+xcom_schema = XComSchema(strict=True)
+xcom_collection_item_schema = XComCollectionItemSchema(strict=True)
+xcom_collection_schema = XComCollectionSchema(strict=True)
diff --git a/airflow/models/dagrun.py b/airflow/models/dagrun.py
index a65aa19ab94e2..0cc03d9855af4 100644
--- a/airflow/models/dagrun.py
+++ b/airflow/models/dagrun.py
@@ -22,7 +22,7 @@
     Boolean, Column, DateTime, Index, Integer, PickleType, String, UniqueConstraint, and_, func, or_,
 )
 from sqlalchemy.ext.declarative import declared_attr
-from sqlalchemy.orm import backref, relationship, synonym
+from sqlalchemy.orm import synonym
 from sqlalchemy.orm.session import Session
 
 from airflow.exceptions import AirflowException
@@ -66,13 +66,6 @@ class DagRun(Base, LoggingMixin):
         UniqueConstraint('dag_id', 'run_id'),
     )
 
-    task_instances = relationship(
-        TI,
-        primaryjoin=and_(TI.dag_id == dag_id, TI.execution_date == execution_date),
-        foreign_keys=(dag_id, execution_date),
-        backref=backref('dag_run', uselist=False),
-    )
-
     def __init__(self, dag_id=None, run_id=None, execution_date=None, start_date=None,
                  external_trigger=None, conf=None, state=None, run_type=None):
         self.dag_id = dag_id
diff --git a/tests/api_connexion/endpoints/test_xcom_endpoint.py b/tests/api_connexion/endpoints/test_xcom_endpoint.py
index fcf7b41a3b9d0..3f36e602d8534 100644
--- a/tests/api_connexion/endpoints/test_xcom_endpoint.py
+++ b/tests/api_connexion/endpoints/test_xcom_endpoint.py
@@ -17,6 +17,7 @@
 import unittest
 
 import pytest
+from parameterized import parameterized
 
 from airflow.models import DagRun as DR, XCom
 from airflow.utils.dates import parse_execution_date
@@ -32,18 +33,29 @@ def setUpClass(cls) -> None:
         cls.app = app.create_app(testing=True)  # type:ignore
 
     def setUp(self) -> None:
+        """
+        Setup For XCom endpoint TC
+        """
         self.client = self.app.test_client()  # type:ignore
         # clear existing xcoms
         with create_session() as session:
             session.query(XCom).delete()
             session.query(DR).delete()
 
+    def tearDown(self) -> None:
+        """
+        Clear Hanging XComs
+        """
+        with create_session() as session:
+            session.query(XCom).delete()
+            session.query(DR).delete()
+
 
 class TestDeleteXComEntry(TestXComEndpoint):
     @pytest.mark.skip(reason="Not implemented yet")
     def test_should_response_200(self):
         response = self.client.delete(
-            "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries/XCOM_KEY"
+            "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T00:00:00Z/xcomEntries/XCOM_KEY"
         )
         assert response.status_code == 204
@@ -52,35 +64,29 @@ class TestGetXComEntry(TestXComEndpoint):
 
     @provide_session
     def test_should_response_200(self, session):
-        # WIP datetime spece
         dag_id = 'test-dag-id'
         task_id = 'test-task-id'
-        execution_date = '2005-04-02T21:37:42+00:00'
+        execution_date = '2005-04-02T00:00:00+00:00'
         xcom_key = 'test-xcom-key'
         execution_date_parsed = parse_execution_date(execution_date)
-        xcom_model = XCom(
-            key=xcom_key,
-            execution_date=execution_date_parsed,
-            task_id=task_id,
-            dag_id=dag_id,
-            timestamp=execution_date_parsed,
-        )
+        xcom_model = XCom(key=xcom_key,
+                          execution_date=execution_date_parsed,
+                          task_id=task_id,
+                          dag_id=dag_id,
+                          timestamp=execution_date_parsed)
         dag_run_id = DR.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
-        dagrun = DR(
-            dag_id=dag_id,
-            run_id=dag_run_id,
-            execution_date=execution_date_parsed,
-            start_date=execution_date_parsed,
-            run_type=DagRunType.MANUAL.value,
-        )
+        dagrun = DR(dag_id=dag_id,
+                    run_id=dag_run_id,
+                    execution_date=execution_date_parsed,
+                    start_date=execution_date_parsed,
+                    run_type=DagRunType.MANUAL.value)
         session.add(xcom_model)
         session.add(dagrun)
         session.commit()
         response = self.client.get(
-            f"api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}"
+            f"/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries/{xcom_key}"
         )
         self.assertEqual(200, response.status_code)
-        print(response.json)
         self.assertEqual(
             response.json,
             {
@@ -94,19 +100,156 @@ def test_should_response_200(self, session):
 
 
 class TestGetXComEntries(TestXComEndpoint):
-    @pytest.mark.skip(reason="Not implemented yet")
-    def test_should_response_200(self):
+    @provide_session
+    def test_should_response_200(self, session):
+        dag_id = 'test-dag-id'
+        task_id = 'test-task-id'
+        execution_date = '2005-04-02T00:00:00+00:00'
+        execution_date_parsed = parse_execution_date(execution_date)
+        xcom_model_1 = XCom(key='test-xcom-key-1',
+                            execution_date=execution_date_parsed,
+                            task_id=task_id,
+                            dag_id=dag_id,
+                            timestamp=execution_date_parsed)
+        xcom_model_2 = XCom(key='test-xcom-key-2',
+                            execution_date=execution_date_parsed,
+                            task_id=task_id,
+                            dag_id=dag_id,
+                            timestamp=execution_date_parsed)
+        dag_run_id = DR.generate_run_id(DagRunType.MANUAL, execution_date_parsed)
+        dagrun = DR(dag_id=dag_id,
+                    run_id=dag_run_id,
+                    execution_date=execution_date_parsed,
+                    start_date=execution_date_parsed,
+                    run_type=DagRunType.MANUAL.value)
+        xcom_models = [xcom_model_1, xcom_model_2]
+        session.add_all(xcom_models)
+        session.add(dagrun)
+        session.commit()
         response = self.client.get(
-            "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries/"
+            f"/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries"
+        )
+        self.assertEqual(200, response.status_code)
+        self.assertEqual(
+            response.json,
+            {
+                'xcom_entries': [
+                    {
+                        'dag_id': dag_id,
+                        'execution_date': execution_date,
+                        'key': 'test-xcom-key-1',
+                        'task_id': task_id,
+                        'timestamp': execution_date
+                    },
+                    {
+                        'dag_id': dag_id,
+                        'execution_date': execution_date,
+                        'key': 'test-xcom-key-2',
+                        'task_id': task_id,
+                        'timestamp': execution_date
+                    }
+                ],
+                'total_entries': 2,
+            }
         )
+
+
+class TestPaginationGetXComEntries(TestXComEndpoint):
+
+    def setUp(self):
+        super().setUp()
+        self.dag_id = 'test-dag-id'
+        self.task_id = 'test-task-id'
+        self.execution_date = '2005-04-02T00:00:00+00:00'
+        self.execution_date_parsed = parse_execution_date(self.execution_date)
+        self.dag_run_id = DR.generate_run_id(DagRunType.MANUAL, self.execution_date_parsed)
+
+    @parameterized.expand(
+        [
+            (
+                "limit=1",
+                ["TEST_XCOM_KEY1"],
+            ),
+            (
+                "limit=2",
+                ["TEST_XCOM_KEY1", "TEST_XCOM_KEY10"],
+            ),
+            (
+                "offset=5",
+                [
+                    "TEST_XCOM_KEY5",
+                    "TEST_XCOM_KEY6",
+                    "TEST_XCOM_KEY7",
+                    "TEST_XCOM_KEY8",
+                    "TEST_XCOM_KEY9",
+                ]
+            ),
+            (
+                "offset=0",
+                [
+                    "TEST_XCOM_KEY1",
+                    "TEST_XCOM_KEY10",
+                    "TEST_XCOM_KEY2",
+                    "TEST_XCOM_KEY3",
+                    "TEST_XCOM_KEY4",
+                    "TEST_XCOM_KEY5",
+                    "TEST_XCOM_KEY6",
+                    "TEST_XCOM_KEY7",
+                    "TEST_XCOM_KEY8",
+                    "TEST_XCOM_KEY9"
+                ]
+            ),
+            (
+                "limit=1&offset=5",
+                ["TEST_XCOM_KEY5"],
+            ),
+            (
+                "limit=1&offset=1",
+                ["TEST_XCOM_KEY10"],
+            ),
+            (
+                "limit=2&offset=2",
+                ["TEST_XCOM_KEY2", "TEST_XCOM_KEY3"],
+            ),
+        ]
+    )
+    @provide_session
+    def test_handle_limit_offset(self, query_params, expected_xcom_ids, session):
+        url = "/api/v1/dags/{dag_id}/dagRuns/{dag_run_id}/taskInstances/{task_id}/xcomEntries?{query_params}"
+        url = url.format(dag_id=self.dag_id,
+                         dag_run_id=self.dag_run_id,
+                         task_id=self.task_id,
+                         query_params=query_params)
+        dagrun = DR(dag_id=self.dag_id,
+                    run_id=self.dag_run_id,
+                    execution_date=self.execution_date_parsed,
+                    start_date=self.execution_date_parsed,
+                    run_type=DagRunType.MANUAL.value)
+        xcom_models = self._create_xcoms(10)
+        session.add_all(xcom_models)
+        session.add(dagrun)
+        session.commit()
+        response = self.client.get(url)
         assert response.status_code == 200
+        self.assertEqual(response.json["total_entries"], 10)
+        conn_ids = [conn["key"] for conn in response.json["xcom_entries"] if conn]
+        self.assertEqual(conn_ids, expected_xcom_ids)
+
+    def _create_xcoms(self, count):
+        return [XCom(
+            key=f'TEST_XCOM_KEY{i}',
+            execution_date=self.execution_date_parsed,
+            task_id=self.task_id,
+            dag_id=self.dag_id,
+            timestamp=self.execution_date_parsed,
+        ) for i in range(1, count + 1)]
 
 
 class TestPatchXComEntry(TestXComEndpoint):
     @pytest.mark.skip(reason="Not implemented yet")
     def test_should_response_200(self):
         response = self.client.patch(
-            "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries"
+            "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T00:00:00Z/xcomEntries"
         )
         assert response.status_code == 200
@@ -115,6 +258,6 @@ class TestPostXComEntry(TestXComEndpoint):
     @pytest.mark.skip(reason="Not implemented yet")
     def test_should_response_200(self):
         response = self.client.post(
-            "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T21:37:42Z/xcomEntries/XCOM_KEY"
+            "/dags/TEST_DAG_ID/taskInstances/TEST_TASK_ID/2005-04-02T00:00:00Z/xcomEntries/XCOM_KEY"
         )
         assert response.status_code == 200
diff --git a/tests/api_connexion/schemas/test_xcom_schema.py b/tests/api_connexion/schemas/test_xcom_schema.py
index 966ee46d667d5..d66c8ce58f89b 100644
--- a/tests/api_connexion/schemas/test_xcom_schema.py
+++ b/tests/api_connexion/schemas/test_xcom_schema.py
@@ -19,26 +19,43 @@
 from sqlalchemy import or_
 
 from airflow.api_connexion.schemas.xcom_schema import (
-    xcom_collection_item_schema, xcom_collection_schema, xcom_schema,
+    XComCollection, xcom_collection_item_schema, xcom_collection_schema, xcom_schema,
 )
 from airflow.models import XCom
-from airflow.utils import timezone
+from airflow.utils.dates import parse_execution_date
 from airflow.utils.session import create_session, provide_session
 
 
-class TestXComCollectionItemSchema(unittest.TestCase):
+class TestXComSchemaBase(unittest.TestCase):
 
-    def setUp(self) -> None:
-        self.now = timezone.utcnow()
+    def setUp(self):
+        """
+        Clear Hanging XComs pre test
+        """
+        with create_session() as session:
+            session.query(XCom).delete()
+
+    def tearDown(self) -> None:
+        """
+        Clear Hanging XComs post test
+        """
         with create_session() as session:
             session.query(XCom).delete()
 
+
+class TestXComCollectionItemSchema(TestXComSchemaBase):
+
+    def setUp(self) -> None:
+        super().setUp()
+        self.default_time = '2005-04-02T21:00:00+00:00'
+        self.default_time_parsed = parse_execution_date(self.default_time)
+
     @provide_session
     def test_serialize(self, session):
         xcom_model = XCom(
             key='test_key',
-            timestamp=self.now,
-            execution_date=self.now,
+            timestamp=self.default_time_parsed,
+            execution_date=self.default_time_parsed,
             task_id='test_task_id',
             dag_id='test_dag',
         )
@@ -50,19 +67,18 @@ def test_serialize(self, session):
             deserialized_xcom[0],
             {
                 'key': 'test_key',
-                'timestamp': self.now.isoformat(),
-                'execution_date': self.now.isoformat(),
+                'timestamp': self.default_time,
+                'execution_date': self.default_time,
                 'task_id': 'test_task_id',
                 'dag_id': 'test_dag',
             }
         )
 
-    @provide_session
-    def test_deserialize(self, session):
+    def test_deserialize(self):
         xcom_dump = {
             'key': 'test_key',
-            'timestamp': self.now.isoformat(),
-            'execution_date': self.now.isoformat(),
+            'timestamp': self.default_time,
+            'execution_date': self.default_time,
             'task_id': 'test_task_id',
             'dag_id': 'test_dag',
         }
@@ -71,82 +87,87 @@ def test_deserialize(self, session):
             result[0],
             {
                 'key': 'test_key',
-                'timestamp': self.now,
-                'execution_date': self.now,
+                'timestamp': self.default_time_parsed,
+                'execution_date': self.default_time_parsed,
                 'task_id': 'test_task_id',
                 'dag_id': 'test_dag',
             }
         )
 
 
-class TestXComCollectionSchema(unittest.TestCase):
+class TestXComCollectionSchema(TestXComSchemaBase):
 
     def setUp(self) -> None:
-        self.t1 = timezone.utcnow()
-        self.t2 = timezone.utcnow()
-        with create_session() as session:
-            session.query(XCom).delete()
+        super().setUp()
+        self.default_time_1 = '2005-04-02T21:00:00+00:00'
+        self.default_time_2 = '2005-04-02T21:01:00+00:00'
+        self.time_1 = parse_execution_date(self.default_time_1)
+        self.time_2 = parse_execution_date(self.default_time_2)
 
     @provide_session
     def test_serialize(self, session):
         xcom_model_1 = XCom(
             key='test_key_1',
-            timestamp=self.t1,
-            execution_date=self.t1,
+            timestamp=self.time_1,
+            execution_date=self.time_1,
             task_id='test_task_id_1',
             dag_id='test_dag_1',
         )
         xcom_model_2 = XCom(
             key='test_key_2',
-            timestamp=self.t2,
-            execution_date=self.t2,
+            timestamp=self.time_2,
+            execution_date=self.time_2,
             task_id='test_task_id_2',
             dag_id='test_dag_2',
         )
         xcom_models = [xcom_model_1, xcom_model_2]
         session.add_all(xcom_models)
         session.commit()
-        xcom_models_queried = session.query(XCom).filter(
-            or_(XCom.execution_date == self.t1, XCom.execution_date == self.t2)
-        ).all()
-        deserialized_xcoms = xcom_collection_schema.dump(xcom_models_queried)
+        xcom_models_query = session.query(XCom).filter(
+            or_(XCom.execution_date == self.time_1, XCom.execution_date == self.time_2)
+        )
+        xcom_models_queried = xcom_models_query.all()
+        deserialized_xcoms = xcom_collection_schema.dump(XComCollection(
+            xcom_entries=xcom_models_queried,
+            total_entries=xcom_models_query.count(),
+        ))
         self.assertEqual(
             deserialized_xcoms[0],
             {
                 'xcom_entries': [
                     {
                         'key': 'test_key_1',
-                        'timestamp': self.t1.isoformat(),
-                        'execution_date': self.t1.isoformat(),
+                        'timestamp': self.default_time_1,
+                        'execution_date': self.default_time_1,
                         'task_id': 'test_task_id_1',
                         'dag_id': 'test_dag_1',
                     },
                     {
                         'key': 'test_key_2',
-                        'timestamp': self.t2.isoformat(),
-                        'execution_date': self.t2.isoformat(),
+                        'timestamp': self.default_time_2,
+                        'execution_date': self.default_time_2,
                         'task_id': 'test_task_id_2',
                         'dag_id': 'test_dag_2',
                     }
                 ],
-                'total_entries': 2
+                'total_entries': len(xcom_models),
             }
         )
 
 
-class TestXComSchema(unittest.TestCase):
+class TestXComSchema(TestXComSchemaBase):
 
     def setUp(self) -> None:
-        self.now = timezone.utcnow()
-        with create_session() as session:
-            session.query(XCom).delete()
+        super().setUp()
+        self.default_time = '2005-04-02T21:00:00+00:00'
+        self.default_time_parsed = parse_execution_date(self.default_time)
 
     @provide_session
     def test_serialize(self, session):
         xcom_model = XCom(
             key='test_key',
-            timestamp=self.now,
-            execution_date=self.now,
+            timestamp=self.default_time_parsed,
+            execution_date=self.default_time_parsed,
             task_id='test_task_id',
             dag_id='test_dag',
             value=b'test_binary',
@@ -159,20 +180,19 @@
             deserialized_xcom[0],
             {
                 'key': 'test_key',
-                'timestamp': self.now.isoformat(),
-                'execution_date': self.now.isoformat(),
+                'timestamp': self.default_time,
+                'execution_date': self.default_time,
                 'task_id': 'test_task_id',
                 'dag_id': 'test_dag',
                 'value': 'test_binary',
             }
         )
 
-    @provide_session
-    def test_deserialize(self, session):
+    def test_deserialize(self):
         xcom_dump = {
             'key': 'test_key',
-            'timestamp': self.now.isoformat(),
-            'execution_date': self.now.isoformat(),
+            'timestamp': self.default_time,
+            'execution_date': self.default_time,
             'task_id': 'test_task_id',
             'dag_id': 'test_dag',
             'value': b'test_binary',
@@ -182,8 +202,8 @@ def test_deserialize(self, session):
             result[0],
             {
                 'key': 'test_key',
-                'timestamp': self.now,
-                'execution_date': self.now,
+                'timestamp': self.default_time_parsed,
+                'execution_date': self.default_time_parsed,
                 'task_id': 'test_task_id',
                 'dag_id': 'test_dag',
                 'value': 'test_binary',
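Note (not part of the diff): a minimal sketch of how the list endpoint added above could be exercised once it is served by a webserver. The endpoint path, the limit/offset query parameters, and the xcom_entries/total_entries response shape come from the code and tests in this change; the base URL, the dag/run/task identifiers, and the absence of authentication are assumptions made only for illustration.

# Illustrative sketch only -- assumes a local webserver at http://localhost:8080
# exposing this API with no auth; dag/run/task identifiers are placeholders.
import requests

BASE_URL = "http://localhost:8080/api/v1"
url = (
    f"{BASE_URL}/dags/example-dag-id/dagRuns/example-run-id"
    "/taskInstances/example-task-id/xcomEntries"
)

# limit/offset are the pagination parameters handled in get_xcom_entries().
response = requests.get(url, params={"limit": 2, "offset": 0})
response.raise_for_status()
payload = response.json()

print(payload["total_entries"])        # count of XCom rows reported by the API
for entry in payload["xcom_entries"]:  # each entry follows XComCollectionItemSchema
    print(entry["key"], entry["execution_date"], entry["timestamp"])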