chore: Convert typings to mypy (#311)
Signed-off-by: Tao Feng <[email protected]>
dorianj authored Aug 18, 2020
1 parent 8eb7708 commit 1d3a274
Showing 171 changed files with 1,626 additions and 2,260 deletions.
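
Every file follows the same mechanical pattern: PEP 484 comment-style type annotations become inline annotations that mypy checks natively. A minimal sketch of the pattern (illustrative code, not lines from this diff; `ConfigTree` is the pyhocon type the project annotates against):

```python
from pyhocon import ConfigTree


class Example:
    def init_old(self, conf):
        # type: (ConfigTree) -> None
        # Before: the comment-style annotation this commit removes everywhere.
        self.conf = conf

    def init(self, conf: ConfigTree) -> None:
        # After: an inline annotation, checkable by `mypy .` with no comment parsing.
        self.conf = conf
```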
6 changes: 5 additions & 1 deletion databuilder/Makefile
@@ -10,6 +10,10 @@ test_unit:
lint:
flake8 .

.PHONY: mypy
mypy:
mypy .

.PHONY: test
test: test_unit lint
test: test_unit lint mypy
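
For a sense of what the new target enforces, here is a hypothetical snippet (not from this diff) that `mypy .` would reject once annotations are inline:

```python
def get_scope() -> str:
    # mypy reports: Incompatible return value type (got "int", expected "str")
    return 42
```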

3 changes: 1 addition & 2 deletions databuilder/README.md
@@ -646,8 +646,7 @@ job.launch()
The `RedashDashboardExtractor` extracts raw queries from each dashboard. You may optionally use these queries to parse out relations to tables in Amundsen. A table parser can be provided in the configuration for the `RedashDashboardExtractor`, as seen above. This function should have type signature `(RedashVisualizationWidget) -> Iterator[TableRelationData]`. For example:

```python
def parse_tables(viz_widget):
# type: (RedashVisualizationWidget) -> Iterator[TableRelationData]
def parse_tables(viz_widget: RedashVisualizationWidget) -> Iterator[TableRelationData]:
# Each viz_widget corresponds to one query.
# viz_widget.data_source_id is the ID of the target DB in Redash.
# viz_widget.raw_query is the raw query (e.g., SQL).
12 changes: 4 additions & 8 deletions databuilder/databuilder/__init__.py
@@ -29,8 +29,7 @@ class Scoped(object, metaclass=abc.ABCMeta):
"""

@abc.abstractmethod
def init(self, conf):
# type: (ConfigTree) -> None
def init(self, conf: ConfigTree) -> None:
"""
All scoped instances are expected to be lazily initialized, meaning that
__init__ should not perform heavy operations such as service calls.
@@ -46,26 +45,23 @@ def init(self, conf):
pass

@abc.abstractmethod
def get_scope(self):
# type: () -> str
def get_scope(self) -> str:
"""
A scope for the config. Typesafe config supports nested configs;
the scope string is used to peel off the nested config.
:return:
"""
return ''

def close(self):
# type: () -> None
def close(self) -> None:
"""
Anything that needs to be cleaned up after the use of the instance.
:return: None
"""
pass

@classmethod
def get_scoped_conf(cls, conf, scope):
# type: (ConfigTree, str) -> ConfigTree
def get_scoped_conf(cls, conf: ConfigTree, scope: str) -> ConfigTree:
"""
Convenience method that returns a scoped config.
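
As a usage sketch, `get_scoped_conf` peels one scope level off a nested pyhocon config; the key and value below are illustrative, but this mirrors how job configs are scoped throughout databuilder:

```python
from pyhocon import ConfigFactory

from databuilder import Scoped

conf = ConfigFactory.from_dict(
    {'extractor.athena_metadata.catalog_source': 'my_catalog'})
scoped = Scoped.get_scoped_conf(conf, 'extractor.athena_metadata')
print(scoped.get_string('catalog_source'))  # -> my_catalog
```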
9 changes: 3 additions & 6 deletions databuilder/databuilder/callback/call_back.py
@@ -16,25 +16,23 @@ class Callback(object, metaclass=abc.ABCMeta):
"""

@abc.abstractmethod
def on_success(self):
# type: () -> None
def on_success(self) -> None:
"""
A callback method that is called when the operation succeeds
:return: None
"""
pass

@abc.abstractmethod
def on_failure(self):
# type: () -> None
def on_failure(self) -> None:
"""
A callback method that is called when the operation fails
:return: None
"""
pass


def notify_callbacks(callbacks, is_success):
def notify_callbacks(callbacks: List[Callback], is_success: bool) -> None:
"""
A utility method that notifies callbacks. If any callback fails, it still goes through all the callbacks,
and raises the last exception it encountered.
@@ -43,7 +41,6 @@ def notify_callbacks(callbacks, is_success):
:param is_success:
:return:
"""
# type: (List[Callback], bool) -> None

if not callbacks:
LOGGER.info('No callbacks to notify')
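
A minimal sketch of the contract defined in this file: implement both abstract methods, then fan out through `notify_callbacks` (the logging bodies are illustrative):

```python
import logging

from databuilder.callback.call_back import Callback, notify_callbacks


class LoggingCallback(Callback):
    """Illustrative callback that records the outcome of an operation."""

    def on_success(self) -> None:
        logging.info('operation succeeded')

    def on_failure(self) -> None:
        logging.warning('operation failed')


# Runs every callback even if one raises, then re-raises the last exception.
notify_callbacks([LoggingCallback()], is_success=True)
```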
20 changes: 7 additions & 13 deletions databuilder/databuilder/extractor/athena_metadata_extractor.py
@@ -43,8 +43,7 @@ class AthenaMetadataExtractor(Extractor):
{WHERE_CLAUSE_SUFFIX_KEY: ' ', CATALOG_KEY: DEFAULT_CLUSTER_NAME}
)

def init(self, conf):
# type: (ConfigTree) -> None
def init(self, conf: ConfigTree) -> None:
conf = conf.with_fallback(AthenaMetadataExtractor.DEFAULT_CONFIG)
self._cluster = '{}'.format(conf.get_string(AthenaMetadataExtractor.CATALOG_KEY))

@@ -60,23 +59,20 @@ def init(self, conf):
.with_fallback(ConfigFactory.from_dict({SQLAlchemyExtractor.EXTRACT_SQL: self.sql_stmt}))

self._alchemy_extractor.init(sql_alch_conf)
self._extract_iter = None # type: Union[None, Iterator]
self._extract_iter: Union[None, Iterator] = None

def extract(self):
# type: () -> Union[TableMetadata, None]
def extract(self) -> Union[TableMetadata, None]:
if not self._extract_iter:
self._extract_iter = self._get_extract_iter()
try:
return next(self._extract_iter)
except StopIteration:
return None

def get_scope(self):
# type: () -> str
def get_scope(self) -> str:
return 'extractor.athena_metadata'

def _get_extract_iter(self):
# type: () -> Iterator[TableMetadata]
def _get_extract_iter(self) -> Iterator[TableMetadata]:
"""
Using itertools.groupby over the raw-level iterator, groups rows by table and yields TableMetadata
:return:
@@ -97,8 +93,7 @@ def _get_extract_iter(self):
'',
columns)

def _get_raw_extract_iter(self):
# type: () -> Iterator[Dict[str, Any]]
def _get_raw_extract_iter(self) -> Iterator[Dict[str, Any]]:
"""
Provides an iterator of result rows from the SQLAlchemy extractor
:return:
@@ -108,8 +103,7 @@ def _get_raw_extract_iter(self):
yield row
row = self._alchemy_extractor.extract()

def _get_table_key(self, row):
# type: (Dict[str, Any]) -> Union[TableKey, None]
def _get_table_key(self, row: Dict[str, Any]) -> Union[TableKey, None]:
"""
Table key consists of schema and table name
:param row:
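
Callers typically drain `extract()` until it returns None. A small self-contained sketch of that driver loop, with `process` standing in for whatever consumes the records:

```python
from typing import Any, Callable


def drain(extractor: Any, process: Callable[[Any], None]) -> None:
    """Drive an extractor until extract() signals exhaustion with None."""
    record = extractor.extract()
    while record is not None:
        process(record)
        record = extractor.extract()
```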
40 changes: 20 additions & 20 deletions databuilder/databuilder/extractor/base_bigquery_extractor.py
@@ -10,7 +10,7 @@
from googleapiclient.discovery import build
import httplib2
from pyhocon import ConfigTree # noqa: F401
from typing import List, Any # noqa: F401
from typing import Any, Dict, Iterator, List # noqa: F401

from databuilder.extractor.base_extractor import Extractor

@@ -28,13 +28,12 @@ class BaseBigQueryExtractor(Extractor):
CRED_KEY = 'project_cred'
PAGE_SIZE_KEY = 'page_size'
FILTER_KEY = 'filter'
_DEFAULT_SCOPES = ['https://www.googleapis.com/auth/bigquery.readonly', ]
_DEFAULT_SCOPES = ['https://www.googleapis.com/auth/bigquery.readonly']
DEFAULT_PAGE_SIZE = 300
NUM_RETRIES = 3
DATE_LENGTH = 8

def init(self, conf):
# type: (ConfigTree) -> None
def init(self, conf: ConfigTree) -> None:
# should use key_path, or cred_key if the former doesn't exist
self.key_path = conf.get_string(BaseBigQueryExtractor.KEY_PATH_KEY, None)
self.cred_key = conf.get_string(BaseBigQueryExtractor.CRED_KEY, None)
@@ -55,33 +54,37 @@ def init(self, conf):
google.oauth2.service_account.Credentials.from_service_account_info(
service_account_info, scopes=self._DEFAULT_SCOPES))
else:
credentials, _ = google.auth.default(scopes=self._DEFAULT_SCOPES)
# FIXME: mypy can't find this attribute
google_auth: Any = getattr(google, 'auth')
credentials, _ = google_auth.default(scopes=self._DEFAULT_SCOPES)

http = httplib2.Http()
authed_http = google_auth_httplib2.AuthorizedHttp(credentials, http=http)
self.bigquery_service = build('bigquery', 'v2', http=authed_http, cache_discovery=False)
self.logging_service = build('logging', 'v2', http=authed_http, cache_discovery=False)
self.iter = iter(self._iterate_over_tables())
self.iter: Iterator[Any] = iter([])

def extract(self):
# type: () -> Any
def extract(self) -> Any:
try:
return next(self.iter)
except StopIteration:
return None

def _is_sharded_table(self, table_id):
def _is_sharded_table(self, table_id: str) -> bool:
suffix = table_id[-BaseBigQueryExtractor.DATE_LENGTH:]
return suffix.isdigit()

def _iterate_over_tables(self):
# type: () -> Any
def _iterate_over_tables(self) -> Any:
for dataset in self._retrieve_datasets():
for entry in self._retrieve_tables(dataset):
yield(entry)
yield entry

def _retrieve_datasets(self):
# type: () -> List[DatasetRef]
# TRICKY: this function has different return types between different subclasses,
# so type as Any. Should probably refactor to remove this unclear sharing.
def _retrieve_tables(self, dataset: DatasetRef) -> Any:
pass

def _retrieve_datasets(self) -> List[DatasetRef]:
datasets = []
for page in self._page_dataset_list_results():
if 'datasets' not in page:
@@ -94,8 +97,7 @@ def _retrieve_datasets(self):

return datasets

def _page_dataset_list_results(self):
# type: () -> Any
def _page_dataset_list_results(self) -> Iterator[Any]:
response = self.bigquery_service.datasets().list(
projectId=self.project_id,
all=False, # Do not return hidden datasets
@@ -116,8 +118,7 @@ def _page_dataset_list_results(self):
else:
response = None

def _page_table_list_results(self, dataset):
# type: (DatasetRef) -> Any
def _page_table_list_results(self, dataset: DatasetRef) -> Iterator[Dict[str, Any]]:
response = self.bigquery_service.tables().list(
projectId=dataset.projectId,
datasetId=dataset.datasetId,
@@ -137,6 +138,5 @@ def _page_table_list_results(self, dataset):
else:
response = None

def get_scope(self):
# type: () -> str
def get_scope(self) -> str:
return 'extractor.bigquery_table_metadata'
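
Both `_page_*` helpers follow the standard BigQuery REST pagination loop: execute a request, yield the page, and repeat while a `nextPageToken` is present. A generic, self-contained sketch (the `fetch_page` callable is hypothetical):

```python
from typing import Any, Callable, Dict, Iterator, Optional


def page_results(
        fetch_page: Callable[[Optional[str]], Dict[str, Any]]) -> Iterator[Dict[str, Any]]:
    """Yield result pages until the API stops returning a nextPageToken."""
    page = fetch_page(None)
    while True:
        yield page
        token = page.get('nextPageToken')
        if not token:
            break
        page = fetch_page(token)
```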
9 changes: 3 additions & 6 deletions databuilder/databuilder/extractor/base_extractor.py
@@ -15,18 +15,15 @@ class Extractor(Scoped):
"""

@abc.abstractmethod
def init(self, conf):
# type: (ConfigTree) -> None
def init(self, conf: ConfigTree) -> None:
pass

@abc.abstractmethod
def extract(self):
# type: () -> Any
def extract(self) -> Any:
"""
:return: Provides a record or None if no more to extract
"""
return None

def get_scope(self):
# type: () -> str
def get_scope(self) -> str:
return 'extractor'
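
To make the Extractor contract concrete, a minimal illustrative subclass (not part of this diff):

```python
from typing import Any, Iterator

from pyhocon import ConfigTree

from databuilder.extractor.base_extractor import Extractor


class StaticExtractor(Extractor):
    """Illustrative extractor that replays a fixed list of records."""

    def init(self, conf: ConfigTree) -> None:
        self._iter: Iterator[str] = iter(['record_1', 'record_2'])

    def extract(self) -> Any:
        # Per the contract: return a record, or None once exhausted.
        return next(self._iter, None)

    def get_scope(self) -> str:
        return 'extractor.static'
```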
36 changes: 19 additions & 17 deletions databuilder/databuilder/extractor/bigquery_metadata_extractor.py
@@ -2,18 +2,14 @@
# SPDX-License-Identifier: Apache-2.0

import logging
from collections import namedtuple

from pyhocon import ConfigTree # noqa: F401
from typing import List, Any # noqa: F401
from typing import cast, Any, Dict, List, Set # noqa: F401

from databuilder.extractor.base_bigquery_extractor import BaseBigQueryExtractor
from databuilder.extractor.base_bigquery_extractor import BaseBigQueryExtractor, DatasetRef
from databuilder.models.table_metadata import TableMetadata, ColumnMetadata


DatasetRef = namedtuple('DatasetRef', ['datasetId', 'projectId'])
TableKey = namedtuple('TableKey', ['schema', 'table_name'])

LOGGER = logging.getLogger(__name__)


@@ -29,13 +25,12 @@ class BigQueryMetadataExtractor(BaseBigQueryExtractor):
column name.
"""

def init(self, conf):
# type: (ConfigTree) -> None
def init(self, conf: ConfigTree) -> None:
BaseBigQueryExtractor.init(self, conf)
self.grouped_tables = set([])
self.grouped_tables: Set[str] = set([])
self.iter = iter(self._iterate_over_tables())

def _retrieve_tables(self, dataset):
# type: () -> Any
def _retrieve_tables(self, dataset: DatasetRef) -> Any:
for page in self._page_table_list_results(dataset):
if 'tables' not in page:
continue
@@ -66,13 +61,14 @@ def _retrieve_tables(self, dataset):

# BigQuery tables also have interesting metadata about partitioning
# data location (EU/US), mod/create time, etc... Extract that some other time?
cols = []
cols: List[ColumnMetadata] = []
# Not all tables have schemas
if 'schema' in table:
schema = table['schema']
if 'fields' in schema:
total_cols = 0
for column in schema['fields']:
# TRICKY: this mutates :cols:
total_cols = self._iterate_over_cols('', column, cols, total_cols + 1)

table_meta = TableMetadata(
@@ -86,8 +82,11 @@ def _iterate_over_cols(self, parent, column, cols, total_cols):

yield(table_meta)

def _iterate_over_cols(self, parent, column, cols, total_cols):
# type: (str, str, List[ColumnMetadata()], int) -> int
def _iterate_over_cols(self,
parent: str,
column: Dict[str, str],
cols: List[ColumnMetadata],
total_cols: int) -> int:
if len(parent) > 0:
col_name = '{parent}.{field}'.format(parent=parent, field=column['name'])
else:
@@ -102,7 +101,11 @@ def _iterate_over_cols(self, parent, column, cols, total_cols):
cols.append(col)
total_cols += 1
for field in column['fields']:
total_cols = self._iterate_over_cols(col_name, field, cols, total_cols)
# TODO field is actually a TableFieldSchema, per
# https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#TableFieldSchema
# however it's typed as str, which is incorrect. Work-around by casting.
field_casted = cast(Dict[str, str], field)
total_cols = self._iterate_over_cols(col_name, field_casted, cols, total_cols)
return total_cols
else:
col = ColumnMetadata(
Expand All @@ -113,6 +116,5 @@ def _iterate_over_cols(self, parent, column, cols, total_cols):
cols.append(col)
return total_cols + 1

def get_scope(self):
# type: () -> str
def get_scope(self) -> str:
return 'extractor.bigquery_table_metadata'
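
The recursion in `_iterate_over_cols` flattens nested RECORD fields into dotted column names. A standalone sketch of just that naming scheme, without the ColumnMetadata bookkeeping:

```python
from typing import Any, Dict, Iterator


def flatten_field_names(parent: str, field: Dict[str, Any]) -> Iterator[str]:
    """Yield dotted column names the way _iterate_over_cols builds them."""
    name = '{}.{}'.format(parent, field['name']) if parent else field['name']
    yield name
    if field.get('type') == 'RECORD':
        for child in field.get('fields', []):
            yield from flatten_field_names(name, child)


schema_field = {'name': 'address', 'type': 'RECORD',
                'fields': [{'name': 'city', 'type': 'STRING'}]}
print(list(flatten_field_names('', schema_field)))  # ['address', 'address.city']
```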