bigquery_metadata_extractor.py
# Copyright Contributors to the Amundsen project.
# SPDX-License-Identifier: Apache-2.0

import logging
from typing import (
    Any, Dict, List, Set, cast,
)

from pyhocon import ConfigTree

from databuilder.extractor.base_bigquery_extractor import BaseBigQueryExtractor, DatasetRef
from databuilder.models.table_metadata import ColumnMetadata, TableMetadata

LOGGER = logging.getLogger(__name__)


class BigQueryMetadataExtractor(BaseBigQueryExtractor):

    """ A metadata extractor for BigQuery tables, taking the schema metadata
    from the Google Cloud BigQuery APIs. This extractor goes through all visible
    datasets in the project identified by project_id and iterates over all tables
    it finds. A separate service account is configurable through the key_path
    parameter, which should point to a valid JSON key file for that account.

    This extractor supports nested columns, which are delimited by a dot (.) in the
    column name.
    """

    def init(self, conf: ConfigTree) -> None:
        # init() is the databuilder Extractor setup hook (distinct from
        # __init__); the framework calls it with the scoped configuration.
        BaseBigQueryExtractor.init(self, conf)
        self.iter = iter(self._iterate_over_tables())

    def _retrieve_tables(self, dataset: DatasetRef) -> Any:
        grouped_tables: Set[str] = set([])

        for page in self._page_table_list_results(dataset):
            if 'tables' not in page:
                continue

            for table in page['tables']:
                tableRef = table['tableReference']
                table_id = tableRef['tableId']

                # BigQuery tables that have 8 digits as last characters are
                # considered date range tables and are grouped together in the UI.
                # ( e.g. ga_sessions_20190101, ga_sessions_20190102, etc. )
                if self._is_sharded_table(table_id):
                    # If the last eight characters are digits, we assume the table is of a table date range type
                    # and then we only need one schema definition
                    table_prefix = table_id[:-BigQueryMetadataExtractor.DATE_LENGTH]
                    if table_prefix in grouped_tables:
                        # If one table in the date range is processed, then ignore other ones
                        # (it adds too much metadata)
                        continue

                    table_id = table_prefix
                    grouped_tables.add(table_prefix)

                table = self.bigquery_service.tables().get(
                    projectId=tableRef['projectId'],
                    datasetId=tableRef['datasetId'],
                    tableId=tableRef['tableId']).execute(num_retries=BigQueryMetadataExtractor.NUM_RETRIES)

                # BigQuery tables also have interesting metadata about partitioning,
                # data location (EU/US), mod/create time, etc... Extract that some other time?
                cols: List[ColumnMetadata] = []
                # Not all tables have schemas
                if 'schema' in table:
                    schema = table['schema']
                    if 'fields' in schema:
                        total_cols = 0
                        for column in schema['fields']:
                            # TRICKY: this mutates :cols:
                            total_cols = self._iterate_over_cols('', column, cols, total_cols + 1)

                table_meta = TableMetadata(
                    database='bigquery',
                    cluster=tableRef['projectId'],
                    schema=tableRef['datasetId'],
                    name=table_id,
                    description=table.get('description', ''),
                    columns=cols,
                    is_view=table['type'] == 'VIEW')

                yield table_meta
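
    # Illustrative note (not in the original source): with DATE_LENGTH == 8,
    # the sharded tables ga_sessions_20190101 and ga_sessions_20190102 from the
    # comment above both collapse to the prefix 'ga_sessions_', so only the
    # first shard encountered has its schema fetched and emitted.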

    def _iterate_over_cols(self,
                           parent: str,
                           column: Dict[str, str],
                           cols: List[ColumnMetadata],
                           total_cols: int) -> int:
        if len(parent) > 0:
            col_name = f'{parent}.{column["name"]}'
        else:
            col_name = column['name']

        if column['type'] == 'RECORD':
            col = ColumnMetadata(
                name=col_name,
                description=column.get('description', ''),
                col_type=column['type'],
                sort_order=total_cols)
            cols.append(col)
            total_cols += 1
            for field in column['fields']:
                # TODO field is actually a TableFieldSchema, per
                # https://cloud.google.com/bigquery/docs/reference/rest/v2/tables#TableFieldSchema
                # however it's typed as str, which is incorrect. Work-around by casting.
                field_casted = cast(Dict[str, str], field)
                total_cols = self._iterate_over_cols(col_name, field_casted, cols, total_cols)
            return total_cols
        else:
            col = ColumnMetadata(
                name=col_name,
                description=column.get('description', ''),
                col_type=column['type'],
                sort_order=total_cols)
            cols.append(col)
            return total_cols + 1
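
    # Worked example (illustrative, not in the original source): given a
    # RECORD column 'address' with child fields 'city' and 'zip', and a
    # starting sort order n, the recursion above emits three ColumnMetadata
    # entries:
    #   address       (col_type RECORD, sort_order n)
    #   address.city  (sort_order n + 1)
    #   address.zip   (sort_order n + 2)
    # Child names are joined to the parent with a dot, matching the docstring.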

    def get_scope(self) -> str:
        return 'extractor.bigquery_table_metadata'
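

# A minimal usage sketch (illustrative only, not part of the original module).
# The 'project_id' and 'key_path' keys correspond to the parameters described
# in the class docstring; Scoped.get_scoped_conf and extract() come from the
# databuilder Extractor interface, and the project id and key path values here
# are placeholders.
if __name__ == '__main__':
    from pyhocon import ConfigFactory

    from databuilder import Scoped

    conf = ConfigFactory.from_dict({
        'extractor.bigquery_table_metadata.project_id': 'your-gcp-project-id',
        'extractor.bigquery_table_metadata.key_path': '/path/to/service_account.json',
    })

    extractor = BigQueryMetadataExtractor()
    extractor.init(Scoped.get_scoped_conf(conf, extractor.get_scope()))

    # Drain the extractor: each record is a TableMetadata instance.
    record = extractor.extract()
    while record:
        print(record)
        record = extractor.extract()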