Skip to content

Commit

Permalink
database: fix ui issues
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Dec 4, 2023
1 parent 8768b50 commit 008f519
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 18 deletions.
5 changes: 3 additions & 2 deletions dvc/commands/imp_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def run(self):
dbt_config = merge(_get_dbt_config(self.repo.config), cli_dbt_config)

project_dir = self.repo.root_dir
if not conn_config or not dbt_config:
if not (conn_config or dbt_config):
if is_dbt_project(project_dir):
ui.write("Using", DBT_PROJECT_FILE, "for testing", styled=True)
else:
Expand Down Expand Up @@ -151,10 +151,11 @@ def add_parser(subparsers, parent_parser):

import_parser.set_defaults(func=CmdImportDb)

TEST_DB_HELP = "Test the database connection"
test_db_parser = subparsers.add_parser(
"test-db",
parents=[parent_parser],
description=append_doc_link(IMPORT_HELP, "test-db"),
description=append_doc_link(TEST_DB_HELP, "test-db"),
add_help=False,
)
test_db_parser.add_argument("--conn", dest="connection")
Expand Down
18 changes: 12 additions & 6 deletions dvc/database/serializer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,21 +6,26 @@
from agate import Table


def noop(_):
pass


@dataclass
class PandasSQLSerializer:
sql: "Union[sa.TextClause, str]"
con: "sa.Connection"
chunksize: int = 10_000

def to_csv(self, file: str) -> None:
def to_csv(self, file: str, progress=noop) -> None:
import pandas as pd

with open(file, mode="wb") as f:
idfs = pd.read_sql_query(self.sql, self.con, chunksize=self.chunksize)
for i, df in enumerate(idfs):
df.to_csv(f, header=i == 0, index=False)
progress(len(df))

def to_json(self, file: str) -> None:
def to_json(self, file: str, progress=noop) -> None: # noqa: ARG002
import pandas as pd

df = pd.read_sql_query(self.sql, self.con)
Expand All @@ -31,18 +36,19 @@ def to_json(self, file: str) -> None:
class AgateSerializer:
table: "Table"

def to_csv(self, file: str) -> None:
def to_csv(self, file: str, progress=noop) -> None: # noqa: ARG002
return self.table.to_csv(file)

def to_json(self, file: str) -> None:
def to_json(self, file: str, progress=noop) -> None: # noqa: ARG002
return self.table.to_json(file)


def export(
serializer: Union[PandasSQLSerializer, AgateSerializer],
file: str,
format: str = "csv", # noqa: A002
progress=noop,
) -> None:
if format == "csv":
return serializer.to_csv(file)
return serializer.to_json(file)
return serializer.to_csv(file, progress=progress)
return serializer.to_json(file, progress=progress)
40 changes: 30 additions & 10 deletions dvc/dependency/db.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
import os
from contextlib import contextmanager, nullcontext
from typing import TYPE_CHECKING, Any, Dict, Iterator, Optional, Union
from typing import TYPE_CHECKING, Any, Callable, Dict, Iterator, Optional, Union

from funcy import compact
from funcy import compact, log_durations

from dvc.exceptions import DvcException
from dvc.log import logger
Expand Down Expand Up @@ -53,6 +53,19 @@ def chdir(path):
os.chdir(wdir)


@contextmanager
def download_progress(to: "Output") -> Iterator[Callable[[int], Any]]:
from dvc.ui import ui
from dvc.ui._rich_progress import DbDownloadProgress

with log_durations(logger.debug, f"Saving to {to}"), DbDownloadProgress(
console=ui.error_console,
) as progress:
task = progress.add_task("Saving", total=None, output=to)
yield lambda n: progress.advance(task, advance=n)
progress.update(task, description="Saved", total=0)


class AbstractDependency(Dependency):
"""Dependency without workspace/fs/fs_path"""

Expand Down Expand Up @@ -102,7 +115,13 @@ def dumpd(self, **kwargs):
def update(self, rev=None):
"""nothing to update."""

def download(self, to, jobs=None, file_format=None, **kwargs): # noqa: ARG002
def download(
self,
to: "Output",
jobs: Optional[int] = None, # noqa: ARG002
file_format: Optional[str] = None,
**kwargs: Any,
) -> None:
from dvc.database import export, get_adapter

db_info = self.info.get(PARAM_DB, {})
Expand All @@ -129,10 +148,14 @@ def download(self, to, jobs=None, file_format=None, **kwargs): # noqa: ARG002
db.test_connection(onerror=status.stop)

file_format = file_format or db_info.get(PARAM_FILE_FORMAT, "csv")
assert file_format
with log_status("Executing query") as status, db.query(query) as serializer:
status.stop()
logger.debug("using serializer: %s", serializer)
with log_status(f"Saving to {to}", status=status):
return export(serializer, to.fs_path, format=file_format)
with download_progress(to) as progress:
return export(
serializer, to.fs_path, format=file_format, progress=progress
)


class DbtDependency(AbstractDependency):
Expand Down Expand Up @@ -246,8 +269,6 @@ def download(
jobs: Optional[int] = None, # noqa: ARG002
file_format: Optional[str] = None,
) -> None:
from dvc.ui import ui

from .repo import RepoDependency

project_dir = self.info.get(PARAM_DB, {}).get(self.PARAM_PROJECT_DIR, "")
Expand All @@ -263,7 +284,6 @@ def download(
project_path = os.path.join(wdir, project_dir) if project_dir else root
with ctx, chdir(project_path):
self._download_db(to, file_format=file_format)
ui.write(f"Saved file to {to}", styled=True)

def _download_db(
self,
Expand All @@ -290,8 +310,8 @@ def _download_db(
model, version=version, profile=profile, target=target
)
# NOTE: we keep everything in memory, and then export it out later.
with log_status(f"Saving to {to}"):
export(serializer, to.fs_path, format=file_format)
with download_progress(to) as progress:
export(serializer, to.fs_path, format=file_format, progress=progress)


DB_SCHEMA = {
Expand Down
22 changes: 22 additions & 0 deletions dvc/ui/_rich_progress.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
DownloadColumn,
MofNCompleteColumn,
Progress,
SpinnerColumn,
TextColumn,
TimeElapsedColumn,
TimeRemainingColumn,
Expand Down Expand Up @@ -52,3 +53,24 @@ def get_renderables(self):
yield self.make_tasks_table(summary_tasks)
self.columns = self.TRANSFER_COLS
yield self.make_tasks_table(other_tasks)


class DbDownloadProgress(RichProgress):
PROGRESS_COLS = (
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
TextColumn("[progress.download]{task.completed:,} rows"),
TextColumn("to [repr.filename]{task.fields[output]}"),
)
STATUS_COLS = (
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
TextColumn("to [repr.filename]{task.fields[output]}"),
)

def get_renderables(self):
if self.tasks:
(task, *_) = self.tasks
cols = self.PROGRESS_COLS if task.completed else self.STATUS_COLS
self.columns = cols[1:] if task.finished else cols
yield self.make_tasks_table(self.tasks)

0 comments on commit 008f519

Please sign in to comment.