Skip to content

Commit

Permalink
fix more tests
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronsteers committed Dec 7, 2024
1 parent be9225d commit d8b40d1
Show file tree
Hide file tree
Showing 13 changed files with 334 additions and 153 deletions.
151 changes: 113 additions & 38 deletions airbyte/_util/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
from __future__ import annotations

import json
from pydoc import cli
import token
from typing import TYPE_CHECKING, Any

import airbyte_api
Expand All @@ -26,6 +28,7 @@
AirbyteMultipleResourcesError,
PyAirbyteInputError,
)
from airbyte.secrets.base import SecretString


if TYPE_CHECKING:
Expand All @@ -46,13 +49,19 @@ def status_ok(status_code: int) -> bool:

def get_airbyte_server_instance(
*,
api_key: str,
api_root: str,
client_id: SecretString,
client_secret: SecretString,
) -> airbyte_api.AirbyteAPI:
"""Get an Airbyte instance."""
return airbyte_api.AirbyteAPI(
security=models.Security(
bearer_auth=api_key,
client_credentials=models.SchemeClientCredentials(
client_id=client_id,
client_secret=client_secret,
token_url=api_root + "/applications/token",
# e.g. https://api.airbyte.com/v1/applications/token
),
),
server_url=api_root,
)
Expand All @@ -65,12 +74,14 @@ def get_workspace(
workspace_id: str,
*,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
) -> models.WorkspaceResponse:
"""Get a connection."""
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
api_root=api_root,
client_id=client_id,
client_secret=client_secret,
)
response = airbyte_instance.workspaces.get_workspace(
api.GetWorkspaceRequest(
Expand All @@ -89,14 +100,47 @@ def get_workspace(
)


# Get bearer token


def get_bearer_token(
workspace_id: str,
*,
api_root: str,
client_id: SecretString,
client_secret: SecretString,
) -> str:
"""Get a bearer token."""
airbyte_instance = get_airbyte_server_instance(
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response = airbyte_instance.oauth.get_bearer_token(
api.GetBearerTokenRequest(
workspace_id=workspace_id,
client_id=client_id,
client_secret=client_secret,
),
)
if status_ok(response.status_code) and response.bearer_token_response:
return response.bearer_token_response.access_token

raise AirbyteError(
message="Could not get bearer token.",
response=response,
)


# List resources


def list_connections(
workspace_id: str,
*,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
name: str | None = None,
name_filter: Callable[[str], bool] | None = None,
) -> list[models.ConnectionResponse]:
Expand All @@ -108,7 +152,8 @@ def list_connections(

_ = workspace_id # Not used (yet)
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response = airbyte_instance.connections.list_connections(
Expand Down Expand Up @@ -136,7 +181,8 @@ def list_sources(
workspace_id: str,
*,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
name: str | None = None,
name_filter: Callable[[str], bool] | None = None,
) -> list[models.SourceResponse]:
Expand All @@ -148,7 +194,8 @@ def list_sources(

_ = workspace_id # Not used (yet)
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response = airbyte_instance.sources.list_sources(
Expand All @@ -172,7 +219,8 @@ def list_destinations(
workspace_id: str,
*,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
name: str | None = None,
name_filter: Callable[[str], bool] | None = None,
) -> list[models.DestinationResponse]:
Expand All @@ -184,7 +232,8 @@ def list_destinations(

_ = workspace_id # Not used (yet)
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response = airbyte_instance.destinations.list_destinations(
Expand Down Expand Up @@ -216,12 +265,14 @@ def get_connection(
connection_id: str,
*,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
) -> models.ConnectionResponse:
"""Get a connection."""
_ = workspace_id # Not used (yet)
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response = airbyte_instance.connections.get_connection(
Expand All @@ -244,7 +295,8 @@ def run_connection(
connection_id: str,
*,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
) -> models.JobResponse:
"""Get a connection.
Expand All @@ -254,7 +306,8 @@ def run_connection(
"""
_ = workspace_id # Not used (yet)
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response = airbyte_instance.jobs.create_job(
Expand Down Expand Up @@ -284,11 +337,13 @@ def get_job_logs(
limit: int = 20,
*,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
) -> list[models.JobResponse]:
"""Get a job's logs."""
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response: api.ListJobsResponse = airbyte_instance.jobs.list_jobs(
Expand All @@ -315,11 +370,13 @@ def get_job_info(
job_id: int,
*,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
) -> models.JobResponse:
"""Get a job."""
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response = airbyte_instance.jobs.get_job(
Expand All @@ -344,13 +401,15 @@ def create_source(
name: str,
*,
workspace_id: str,
config: dict[str, Any],
config: models.SourceConfiguration,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
) -> models.SourceResponse:
"""Get a connection."""
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response: api.CreateSourceResponse = airbyte_instance.sources.create_source(
Expand All @@ -375,11 +434,13 @@ def get_source(
source_id: str,
*,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
) -> models.SourceResponse:
"""Get a connection."""
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response = airbyte_instance.sources.get_source(
Expand All @@ -401,13 +462,15 @@ def delete_source(
source_id: str,
*,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
workspace_id: str | None = None,
) -> None:
"""Delete a source."""
_ = workspace_id # Not used (yet)
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response = airbyte_instance.sources.delete_source(
Expand All @@ -431,13 +494,15 @@ def create_destination(
name: str,
*,
workspace_id: str,
config: dict[str, Any],
config: models.DestinationConfiguration,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
) -> models.DestinationResponse:
"""Get a connection."""
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response: api.CreateDestinationResponse = airbyte_instance.destinations.create_destination(
Expand All @@ -460,11 +525,13 @@ def get_destination(
destination_id: str,
*,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
) -> models.DestinationResponse:
"""Get a connection."""
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response = airbyte_instance.destinations.get_destination(
Expand Down Expand Up @@ -510,13 +577,15 @@ def delete_destination(
destination_id: str,
*,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
workspace_id: str | None = None,
) -> None:
"""Delete a destination."""
_ = workspace_id # Not used (yet)
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response = airbyte_instance.destinations.delete_destination(
Expand All @@ -542,14 +611,16 @@ def create_connection(
source_id: str,
destination_id: str,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
workspace_id: str | None = None,
prefix: str,
selected_stream_names: list[str],
) -> models.ConnectionResponse:
_ = workspace_id # Not used (yet)
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
stream_configurations: list[models.StreamConfiguration] = []
Expand Down Expand Up @@ -587,13 +658,15 @@ def get_connection_by_name(
connection_name: str,
*,
api_root: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
) -> models.ConnectionResponse:
"""Get a connection."""
connections = list_connections(
workspace_id=workspace_id,
api_key=api_key,
api_root=api_root,
client_id=client_id,
client_secret=client_secret,
)
found: list[models.ConnectionResponse] = [
connection for connection in connections if connection.name == connection_name
Expand All @@ -620,11 +693,13 @@ def delete_connection(
connection_id: str,
api_root: str,
workspace_id: str,
api_key: str,
client_id: SecretString,
client_secret: SecretString,
) -> None:
_ = workspace_id # Not used (yet)
airbyte_instance = get_airbyte_server_instance(
api_key=api_key,
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response = airbyte_instance.connections.delete_connection(
Expand Down
Loading

0 comments on commit d8b40d1

Please sign in to comment.