Skip to content

Commit

Permalink
multiple fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronsteers committed Dec 9, 2024
1 parent d8b40d1 commit 7c9a6ac
Show file tree
Hide file tree
Showing 28 changed files with 783 additions and 599 deletions.
84 changes: 46 additions & 38 deletions airbyte/_util/api_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
from __future__ import annotations

import json
from pydoc import cli
import token
from typing import TYPE_CHECKING, Any

import airbyte_api
Expand All @@ -28,18 +26,19 @@
AirbyteMultipleResourcesError,
PyAirbyteInputError,
)
from airbyte.secrets.base import SecretString


if TYPE_CHECKING:
from collections.abc import Callable

from airbyte.secrets.base import SecretString


JOB_WAIT_INTERVAL_SECS = 2.0
JOB_WAIT_TIMEOUT_SECS_DEFAULT = 60 * 60 # 1 hour
CLOUD_API_ROOT = "https://api.airbyte.com/v1"

# Helper functions
SourceConfiguration = Any


def status_ok(status_code: int) -> bool:
Expand Down Expand Up @@ -100,80 +99,89 @@ def get_workspace(
)


# Get bearer token
# List resources


def list_connections(
    workspace_id: str,
    *,
    api_root: str,
    client_id: SecretString,
    client_secret: SecretString,
    name: str | None = None,
    name_filter: Callable[[str], bool] | None = None,
) -> list[models.ConnectionResponse]:
    """List connections in a workspace, optionally filtered by name.

    Args:
        workspace_id: The workspace to list connections from.
        api_root: Base URL of the Airbyte API.
        client_id: OAuth client ID used to authenticate.
        client_secret: OAuth client secret used to authenticate.
        name: If provided, return only connections with this exact name.
        name_filter: If provided, return only connections whose name satisfies
            this predicate. Mutually exclusive with `name`.

    Raises:
        PyAirbyteInputError: If both `name` and `name_filter` are provided.
        AirbyteError: If the API call does not return an OK status.
    """
    if name and name_filter:
        raise PyAirbyteInputError(message="You can provide name or name_filter, but not both.")

    # Collapse the two filter inputs into a single predicate; no filter means match-all.
    name_filter = (lambda n: n == name) if name else name_filter or (lambda _: True)

    airbyte_instance = get_airbyte_server_instance(
        client_id=client_id,
        client_secret=client_secret,
        api_root=api_root,
    )
    response = airbyte_instance.connections.list_connections(
        api.ListConnectionsRequest(
            workspace_ids=[workspace_id],
        ),
    )
    # Fix: raise on ANY non-OK status. The previous `and response.connections_response`
    # clause suppressed the error when the payload was missing, letting the assert
    # below fail with a less helpful message instead.
    if not status_ok(response.status_code):
        raise AirbyteError(
            context={
                "workspace_id": workspace_id,
                "response": response,
            }
        )
    assert response.connections_response is not None
    return [
        connection
        for connection in response.connections_response.data
        if name_filter(connection.name)
    ]


def list_workspaces(
    workspace_id: str,
    *,
    api_root: str,
    client_id: SecretString,
    client_secret: SecretString,
    name: str | None = None,
    name_filter: Callable[[str], bool] | None = None,
) -> list[models.WorkspaceResponse]:
    """List workspaces, optionally filtered by name.

    Args:
        workspace_id: The workspace ID used to scope the listing request.
        api_root: Base URL of the Airbyte API.
        client_id: OAuth client ID used to authenticate.
        client_secret: OAuth client secret used to authenticate.
        name: If provided, return only workspaces with this exact name.
        name_filter: If provided, return only workspaces whose name satisfies
            this predicate. Mutually exclusive with `name`.

    Raises:
        PyAirbyteInputError: If both `name` and `name_filter` are provided.
        AirbyteError: If the API call does not return an OK status.
    """
    if name and name_filter:
        raise PyAirbyteInputError(message="You can provide name or name_filter, but not both.")

    # Collapse the two filter inputs into a single predicate; no filter means match-all.
    name_filter = (lambda n: n == name) if name else name_filter or (lambda _: True)

    airbyte_instance: airbyte_api.AirbyteAPI = get_airbyte_server_instance(
        client_id=client_id,
        client_secret=client_secret,
        api_root=api_root,
    )

    response: api.ListWorkspacesResponse = airbyte_instance.workspaces.list_workspaces(
        api.ListWorkspacesRequest(
            workspace_ids=[workspace_id],
        ),
    )

    # Fix: raise on ANY non-OK status. The previous `and response.workspaces_response`
    # clause suppressed the error when the payload was missing, letting the assert
    # below fail with a less helpful message instead.
    if not status_ok(response.status_code):
        raise AirbyteError(
            context={
                "workspace_id": workspace_id,
                "response": response,
            }
        )
    assert response.workspaces_response is not None
    return [
        workspace
        for workspace in response.workspaces_response.data
        if name_filter(workspace.name)
    ]


Expand All @@ -193,12 +201,12 @@ def list_sources(
name_filter = (lambda n: n == name) if name else name_filter or (lambda _: True)

_ = workspace_id # Not used (yet)
airbyte_instance = get_airbyte_server_instance(
airbyte_instance: airbyte_api.AirbyteAPI = get_airbyte_server_instance(
client_id=client_id,
client_secret=client_secret,
api_root=api_root,
)
response = airbyte_instance.sources.list_sources(
response: api.ListSourcesResponse = airbyte_instance.sources.list_sources(
api.ListSourcesRequest(
workspace_ids=[workspace_id],
),
Expand Down Expand Up @@ -401,7 +409,7 @@ def create_source(
name: str,
*,
workspace_id: str,
config: models.SourceConfiguration,
config: models.SourceConfiguration | dict[str, Any],
api_root: str,
client_id: SecretString,
client_secret: SecretString,
Expand All @@ -416,7 +424,7 @@ def create_source(
models.SourceCreateRequest(
name=name,
workspace_id=workspace_id,
configuration=config,
configuration=config, # type: ignore [attr-type] # Speakeasy API wants a dataclass, not a dict
definition_id=None, # Not used alternative to config.sourceType.
secret_id=None, # For OAuth, not yet supported
),
Expand Down Expand Up @@ -494,7 +502,7 @@ def create_destination(
name: str,
*,
workspace_id: str,
config: models.DestinationConfiguration,
config: models.DestinationConfiguration | dict[str, Any],
api_root: str,
client_id: SecretString,
client_secret: SecretString,
Expand All @@ -509,7 +517,7 @@ def create_destination(
models.DestinationCreateRequest(
name=name,
workspace_id=workspace_id,
configuration=config,
configuration=config, # type: ignore [attr-type] # Speakeasy API wants a dataclass, not a dict
),
)
if status_ok(response.status_code) and response.destination_response:
Expand Down
2 changes: 1 addition & 1 deletion airbyte/_util/text_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,5 +18,5 @@ def generate_random_suffix() -> str:
which will be monotonically sortable. It is not guaranteed to be unique but
is sufficient for small-scale and medium-scale use cases.
"""
ulid_str = generate_ulid()
ulid_str = generate_ulid().lower()
return ulid_str[:6] + ulid_str[-3:]
32 changes: 27 additions & 5 deletions airbyte/caches/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from __future__ import annotations

from pathlib import Path
from typing import IO, TYPE_CHECKING, Any, final
from typing import IO, TYPE_CHECKING, Any, ClassVar, final

import pandas as pd
import pyarrow as pa
Expand All @@ -30,7 +30,6 @@

from airbyte._message_iterators import AirbyteMessageIterator
from airbyte.caches._state_backend_base import StateBackendBase
from airbyte.datasets._base import DatasetBase
from airbyte.progress import ProgressTracker
from airbyte.shared.sql_processor import SqlProcessorBase
from airbyte.shared.state_providers import StateProviderBase
Expand All @@ -56,14 +55,24 @@ class CacheBase(SqlConfig, AirbyteWriterInterface):
"""Whether to clean up the cache after use."""

_name: str = PrivateAttr()
_paired_destination_name: str

_sql_processor_class: type[SqlProcessorBase] = PrivateAttr()
_sql_processor_class: ClassVar[type[SqlProcessorBase]]
_read_processor: SqlProcessorBase = PrivateAttr()

_catalog_backend: CatalogBackendBase = PrivateAttr()
_state_backend: StateBackendBase = PrivateAttr()

paired_destination_name: ClassVar[str | None] = None
paired_destination_config_class: ClassVar[type | None] = None

@property
def paired_destination_config(self) -> Any | dict[str, Any]:  # noqa: ANN401 # Allow Any return type
    """Return the destination configuration equivalent to this cache.

    Base implementation always raises; cache subclasses that declare a
    paired destination are expected to override this property.

    Raises:
        NotImplementedError: Always, for this base class.
    """
    raise NotImplementedError(
        f"The type '{type(self).__name__}' does not define an equivalent destination "
        "configuration."
    )

def __init__(self, **data: Any) -> None: # noqa: ANN401
"""Initialize the cache and backends."""
super().__init__(**data)
Expand Down Expand Up @@ -228,6 +237,19 @@ def streams(self) -> dict[str, CachedDataset]:

return result

@final
def __len__(self) -> int:
    """Return the number of streams registered in the cache's catalog backend."""
    return len(self._catalog_backend.stream_names)

@final
def __bool__(self) -> bool:
    """Always return True.

    Without this override, truthiness would fall back to `__len__`, so a
    cache with zero streams would evaluate as falsy (None-like).
    """
    return True

def get_state_provider(
self,
source_name: str,
Expand Down Expand Up @@ -271,7 +293,7 @@ def register_source(
incoming_stream_names=stream_names,
)

def __getitem__(self, stream: str) -> DatasetBase:
def __getitem__(self, stream: str) -> CachedDataset:
    """Return a dataset by stream name.

    Lookup is delegated to the `streams` mapping, so a missing stream
    name raises KeyError.
    """
    return self.streams[stream]

Expand Down
22 changes: 18 additions & 4 deletions airbyte/caches/bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,22 +17,36 @@

from __future__ import annotations

from typing import NoReturn
from typing import TYPE_CHECKING, ClassVar, NoReturn

from pydantic import PrivateAttr
from airbyte_api.models import DestinationBigquery

from airbyte._processors.sql.bigquery import BigQueryConfig, BigQuerySqlProcessor
from airbyte.caches.base import (
CacheBase,
)
from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE
from airbyte.destinations._translate_cache_to_dest import (
bigquery_cache_to_destination_configuration,
)


if TYPE_CHECKING:
from airbyte.shared.sql_processor import SqlProcessorBase


class BigQueryCache(BigQueryConfig, CacheBase):
"""The BigQuery cache implementation."""

_sql_processor_class: type[BigQuerySqlProcessor] = PrivateAttr(default=BigQuerySqlProcessor)
_paired_destination_name: str = "destination-bigquery"
_sql_processor_class: ClassVar[type[SqlProcessorBase]] = BigQuerySqlProcessor

paired_destination_name: ClassVar[str | None] = "destination-bigquery"
paired_destination_config_class: ClassVar[type | None] = DestinationBigquery

@property
def paired_destination_config(self) -> DestinationBigquery:
    """Return the `DestinationBigquery` configuration equivalent to this cache.

    Built by translating this cache's settings via
    `bigquery_cache_to_destination_configuration`.
    """
    return bigquery_cache_to_destination_configuration(cache=self)

def get_arrow_dataset(
self,
Expand Down
19 changes: 16 additions & 3 deletions airbyte/caches/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,18 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, ClassVar

from airbyte_api.models import DestinationDuckdb
from duckdb_engine import DuckDBEngineWarning
from pydantic import PrivateAttr

from airbyte._processors.sql.duckdb import DuckDBConfig, DuckDBSqlProcessor
from airbyte.caches.base import CacheBase
from airbyte.destinations._translate_cache_to_dest import duckdb_cache_to_destination_configuration


if TYPE_CHECKING:
from airbyte.shared.sql_processor import SqlProcessorBase


# Suppress warnings from DuckDB about reflection on indices.
Expand All @@ -37,8 +43,15 @@
class DuckDBCache(DuckDBConfig, CacheBase):
"""A DuckDB cache."""

_sql_processor_class: type[DuckDBSqlProcessor] = PrivateAttr(default=DuckDBSqlProcessor)
_paired_destination_name: str = "destination-duckdb"
_sql_processor_class: ClassVar[type[SqlProcessorBase]] = DuckDBSqlProcessor

paired_destination_name: ClassVar[str | None] = "destination-duckdb"
paired_destination_config_class: ClassVar[type | None] = DestinationDuckdb

@property
def paired_destination_config(self) -> DestinationDuckdb:
    """Return the `DestinationDuckdb` configuration equivalent to this cache.

    Built by translating this cache's settings via
    `duckdb_cache_to_destination_configuration`.
    """
    return duckdb_cache_to_destination_configuration(cache=self)


# Expose the Cache class and also the Config class.
Expand Down
Loading

0 comments on commit 7c9a6ac

Please sign in to comment.