Skip to content

Commit

Permalink
Branches send_all_records dag to determine http or ftp transmission.
Browse files Browse the repository at this point in the history
  • Loading branch information
shelleydoljack committed Feb 14, 2025
1 parent f91089a commit d672c6e
Show file tree
Hide file tree
Showing 2 changed files with 81 additions and 8 deletions.
57 changes: 49 additions & 8 deletions libsys_airflow/dags/data_exports/full_dump_transmission.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
import logging
from datetime import datetime, timedelta

from airflow.decorators import dag
from airflow.decorators import dag, task
from airflow.models.param import Param
from airflow.models.connection import Connection
from airflow.models import Variable
from airflow.operators.empty import EmptyOperator

from libsys_airflow.plugins.data_exports.transmission_tasks import (
gather_files_task,
retry_failed_files_task,
transmit_data_http_task,
transmit_data_ftp_task,
)

from libsys_airflow.plugins.data_exports.email import (
Expand All @@ -28,6 +30,32 @@
}


@task(multiple_outputs=True)
def retrieve_params(**kwargs):
    """
    Expose the "vendor" Param as the connection id used by later tasks.

    Reads ``params`` from the keyword arguments handed to the task (an
    empty dict is used when ``params`` is absent) and returns a single-key
    dict so callers can reference the ``conn_id`` entry on its own.

    Raises:
        KeyError: if ``params`` has no "vendor" entry.
    """
    dag_params = kwargs.get("params", {})
    return {"conn_id": dag_params["vendor"]}


@task.branch()
def http_or_ftp_path(**kwargs):
    """
    Pick the transmission task to follow based on the connection's type.

    Expects a ``connection`` keyword argument holding the connection id,
    which is resolved with ``Connection.get_connection_from_secrets``.

    Returns:
        "transmit_data_http_task" when the connection's ``conn_type`` is
        "http"; "transmit_data_ftp_task" for any other connection type.
    """
    conn_id = kwargs.get("connection")
    logger.info(f"Send all records to vendor {conn_id}")
    conn_type = Connection.get_connection_from_secrets(conn_id).conn_type
    logger.info(f"Transmit data via {conn_type}")
    # Every connection type other than "http" is routed to the ftp task.
    if conn_type != "http":
        return "transmit_data_ftp_task"
    return "transmit_data_http_task"


@dag(
default_args=default_args,
schedule=None,
Expand All @@ -39,7 +67,7 @@
"pod",
type="string",
description="Send all records to this vendor.",
enum=["pod", "sharevde"],
enum=["pod", "sharevde", "backstage"],
),
"bucket": Param(
Variable.get("FOLIO_AWS_BUCKET", "folio-data-export-prod"), type="string"
Expand All @@ -53,6 +81,11 @@ def send_all_records():

gather_files = gather_files_task(vendor="full-dump")

vars = retrieve_params()

choose_branch = http_or_ftp_path(connection=vars["conn_id"])

# http branch
transmit_data = transmit_data_http_task(
gather_files,
files_params="upload[files][]",
Expand All @@ -69,13 +102,21 @@ def send_all_records():

email_failures = failed_transmission_email(retry_transmission["failures"])

# ftp branch
transmit_data_ftp = transmit_data_ftp_task(vars["conn_id"], gather_files)
retry_files_ftp = retry_failed_files_task(
vendor="full-dump", files=transmit_data_ftp["failures"]
)
retry_transmit_data_ftp = transmit_data_ftp_task(vars["conn_id"], retry_files_ftp)
email_failures_ftp = failed_transmission_email(retry_transmit_data_ftp["failures"])

start >> gather_files >> vars >> choose_branch >> [transmit_data, transmit_data_ftp]
transmit_data >> retry_files >> retry_transmission >> email_failures >> end
(
start
>> gather_files
>> transmit_data
>> retry_files
>> retry_transmission
>> email_failures
transmit_data_ftp
>> retry_files_ftp
>> retry_transmit_data_ftp
>> email_failures_ftp
>> end
)

Expand Down
32 changes: 32 additions & 0 deletions tests/data_exports/test_transmission_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
set_holdings_oclc_task,
)

from libsys_airflow.dags.data_exports.full_dump_transmission import (
http_or_ftp_path,
)


@pytest.fixture(params=["pod", "gobi", "backstage"])
def mock_vendor_marc_files(tmp_path, request):
Expand Down Expand Up @@ -90,6 +94,12 @@ def mock_marc_files(mock_file_system):
return {"file_list": marc_files, "s3": False}


@pytest.fixture(params=["http-example.com", "ftp-example.com"])
def mock_full_dump_params(request):
    """Return a params dict whose ``conn_id`` is the current fixture param."""
    return {"conn_id": request.param}


@pytest.fixture
def mock_httpx_connection():
return Connection(
Expand Down Expand Up @@ -268,6 +278,28 @@ def test_transmit_data_task(
assert "Setting URL params to" not in caplog.text


@pytest.mark.parametrize("mock_full_dump_params", ["http-example.com"], indirect=True)
def test_full_dump_http(mocker, mock_httpx_connection, mock_full_dump_params, caplog):
    """The branch picks the http task when the lookup returns ``mock_httpx_connection``."""
    # Make the connection lookup hand back the mocked connection.
    mocker.patch(
        "libsys_airflow.plugins.data_exports.transmission_tasks.Connection.get_connection_from_secrets",
        return_value=mock_httpx_connection,
    )
    vendor_conn_id = mock_full_dump_params["conn_id"]
    # Call the task's ``.function`` directly with the connection id.
    chosen_task = http_or_ftp_path.function(connection=vendor_conn_id)
    assert chosen_task == "transmit_data_http_task"
    assert "Transmit data via http" in caplog.text


@pytest.mark.parametrize("mock_full_dump_params", ["ftp-example.com"], indirect=True)
def test_full_dump_ftp(mocker, mock_ftphook_connection, mock_full_dump_params, caplog):
    """The branch picks the ftp task when the lookup returns ``mock_ftphook_connection``."""
    # Make the connection lookup hand back the mocked connection.
    mocker.patch(
        "libsys_airflow.plugins.data_exports.transmission_tasks.Connection.get_connection_from_secrets",
        return_value=mock_ftphook_connection,
    )
    vendor_conn_id = mock_full_dump_params["conn_id"]
    # Call the task's ``.function`` directly with the connection id.
    chosen_task = http_or_ftp_path.function(connection=vendor_conn_id)
    assert chosen_task == "transmit_data_ftp_task"
    assert "Transmit data via ftp" in caplog.text


def test_transmit_data_from_s3_task(
mocker, mock_httpx_connection, mock_httpx_success, mock_marc_files, caplog
):
Expand Down

0 comments on commit d672c6e

Please sign in to comment.