Skip to content

Commit

Permalink
Keep mapping of connections. Fix unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: lorenabalan <[email protected]>
  • Loading branch information
lorenabalan committed Jan 27, 2022
1 parent 03c8c2d commit 2af5c5f
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 118 deletions.
104 changes: 50 additions & 54 deletions kedro/extras/datasets/pandas/sql_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,60 +207,53 @@ def __init__(
self._load_args["table_name"] = table_name
self._save_args["name"] = table_name

self._load_args["con"] = self._save_args["con"] = credentials["con"]
self.create_connection(self._load_args["con"])
self._connection_str = credentials["con"]
self.create_connection(self._connection_str)

@classmethod
def create_connection(cls, con):
"""Create singleton connection to be used
across all instances of `SQLTableDataSet`.
def create_connection(cls, connection_str: str) -> None:
"""Given a connection string, create singleton connection
to be used across all instances of `SQLTableDataSet` that
need to connect to the same source.
"""
if hasattr(cls, "engine"):
if connection_str in getattr(cls, "engines", {}):
return

engine = create_engine(con)
cls.engine = engine
engines = cls.engines if hasattr(cls, "engines") else {} # type:ignore

try:
engine = create_engine(connection_str)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc

engines[connection_str] = engine
cls.engines = engines # type: ignore

def _describe(self) -> Dict[str, Any]:
load_args = self._load_args.copy()
save_args = self._save_args.copy()
load_args = copy.deepcopy(self._load_args)
save_args = copy.deepcopy(self._save_args)
del load_args["table_name"]
del load_args["con"]
del save_args["name"]
del save_args["con"]
return dict(
table_name=self._load_args["table_name"],
load_args=load_args,
save_args=save_args,
)

def _load(self) -> pd.DataFrame:
load_args = copy.deepcopy(self._load_args)
load_args["con"] = self.engine # type: ignore

try:
return pd.read_sql_table(**load_args)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
engine = self.engines.get(self._connection_str) # type:ignore
return pd.read_sql_table(con=engine, **self._load_args)

def _save(self, data: pd.DataFrame) -> None:
save_args = copy.deepcopy(self._save_args)
save_args["con"] = self.engine # type: ignore

try:
data.to_sql(**save_args)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
engine = self.engines.get(self._connection_str) # type: ignore
data.to_sql(con=engine, **self._save_args)

def _exists(self) -> bool:
eng = self.engine # type: ignore
eng = self.engines[self._connection_str] # type: ignore
schema = self._load_args.get("schema", None)
exists = self._load_args["table_name"] in eng.table_names(schema)
# eng.dispose()
return exists


Expand Down Expand Up @@ -392,45 +385,48 @@ def __init__( # pylint: disable=too-many-arguments
self._protocol = protocol
self._fs = fsspec.filesystem(self._protocol, **_fs_credentials, **_fs_args)
self._filepath = path
self._load_args["con"] = credentials["con"]
self.create_connection(self._load_args["con"])
self._connection_str = credentials["con"]
self.create_connection(self._connection_str)

@classmethod
def create_connection(cls, con):
"""Create singleton connection to be used
across all instances of `SQLQueryDataSet`.
def create_connection(cls, connection_str: str) -> None:
"""Given a connection string, create singleton connection
to be used across all instances of `SQLQueryDataSet` that
need to connect to the same source.
"""
if hasattr(cls, "engine"):
if connection_str in getattr(cls, "engines", {}):
return

engine = create_engine(con)
cls.engine = engine
engines = cls.engines if hasattr(cls, "engines") else {} # type:ignore

try:
engine = create_engine(connection_str)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc

engines[connection_str] = engine
cls.engines = engines # type: ignore

def _describe(self) -> Dict[str, Any]:
load_args = copy.deepcopy(self._load_args)
desc = {}
desc["sql"] = str(load_args.pop("sql", None))
desc["filepath"] = str(self._filepath)
del load_args["con"]
desc["load_args"] = str(load_args)

return desc
return dict(
sql=str(load_args.pop("sql", None)),
filepath=str(self._filepath),
load_args=str(load_args),
)

def _load(self) -> pd.DataFrame:
load_args = copy.deepcopy(self._load_args)
load_args["con"] = self.engine # type: ignore
engine = self.engines[self._connection_str] # type: ignore

if self._filepath:
load_path = get_filepath_str(PurePosixPath(self._filepath), self._protocol)
with self._fs.open(load_path, mode="r") as fs_file:
load_args["sql"] = fs_file.read()

try:
return pd.read_sql_query(**load_args)
except ImportError as import_error:
raise _get_missing_module_error(import_error) from import_error
except NoSuchModuleError as exc:
raise _get_sql_alchemy_missing_error() from exc
return pd.read_sql_query(con=engine, **load_args)

def _save(self, data: pd.DataFrame) -> None:
raise DataSetError("`save` is not supported on SQLQueryDataSet")
Loading

0 comments on commit 2af5c5f

Please sign in to comment.