diff --git a/bigquery/google/cloud/bigquery/_pandas_helpers.py b/bigquery/google/cloud/bigquery/_pandas_helpers.py
index 5e73c9f58e22..bfbaf92bbe38 100644
--- a/bigquery/google/cloud/bigquery/_pandas_helpers.py
+++ b/bigquery/google/cloud/bigquery/_pandas_helpers.py
@@ -210,7 +210,7 @@ def list_columns_and_indexes(dataframe):
     """Return all index and column names with dtypes.
 
     Returns:
-        Sequence[Tuple[dtype, str]]:
+        Sequence[Tuple[str, dtype]]:
             Returns a sorted list of indexes and column names with
             corresponding dtypes. If an index is missing a name or has the
             same name as a column, the index is omitted.
diff --git a/bigquery/google/cloud/bigquery/client.py b/bigquery/google/cloud/bigquery/client.py
index cc53ffa22985..c33e119cbc74 100644
--- a/bigquery/google/cloud/bigquery/client.py
+++ b/bigquery/google/cloud/bigquery/client.py
@@ -1547,6 +1547,27 @@ def load_table_from_dataframe(
         if location is None:
             location = self.location
 
+        # If table schema is not provided, we try to fetch the existing table
+        # schema, and check if dataframe schema is compatible with it - except
+        # for WRITE_TRUNCATE jobs, the existing schema does not matter then.
+        if (
+            not job_config.schema
+            and job_config.write_disposition != job.WriteDisposition.WRITE_TRUNCATE
+        ):
+            try:
+                table = self.get_table(destination)
+            except google.api_core.exceptions.NotFound:
+                table = None
+            else:
+                columns_and_indexes = frozenset(
+                    name
+                    for name, _ in _pandas_helpers.list_columns_and_indexes(dataframe)
+                )
+                # schema fields not present in the dataframe are not needed
+                job_config.schema = [
+                    field for field in table.schema if field.name in columns_and_indexes
+                ]
+
         job_config.schema = _pandas_helpers.dataframe_to_bq_schema(
             dataframe, job_config.schema
         )
diff --git a/bigquery/tests/unit/test_client.py b/bigquery/tests/unit/test_client.py
index f31d8587322b..da3cee11e5d0 100644
--- a/bigquery/tests/unit/test_client.py
+++ b/bigquery/tests/unit/test_client.py
@@ -20,6 +20,7 @@
 import gzip
 import io
 import json
+import operator
 import unittest
 import warnings
 
@@ -5279,15 +5280,23 @@ def test_load_table_from_file_bad_mode(self):
     def test_load_table_from_dataframe(self):
         from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
         from google.cloud.bigquery import job
+        from google.cloud.bigquery.schema import SchemaField
 
         client = self._make_client()
         records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
         dataframe = pandas.DataFrame(records)
 
+        get_table_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.get_table",
+            autospec=True,
+            return_value=mock.Mock(
+                schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
+            ),
+        )
         load_patch = mock.patch(
             "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
         )
-        with load_patch as load_table_from_file:
+        with load_patch as load_table_from_file, get_table_patch:
             client.load_table_from_dataframe(dataframe, self.TABLE_REF)
 
         load_table_from_file.assert_called_once_with(
@@ -5314,15 +5323,23 @@ def test_load_table_from_dataframe_w_client_location(self):
         from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
         from google.cloud.bigquery import job
+        from google.cloud.bigquery.schema import SchemaField
 
         client = self._make_client(location=self.LOCATION)
         records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
         dataframe = pandas.DataFrame(records)
 
+        get_table_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.get_table",
+            autospec=True,
+            return_value=mock.Mock(
+                schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
+            ),
+        )
         load_patch = mock.patch(
             "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
         )
-        with load_patch as load_table_from_file:
+        with load_patch as load_table_from_file, get_table_patch:
             client.load_table_from_dataframe(dataframe, self.TABLE_REF)
 
         load_table_from_file.assert_called_once_with(
@@ -5349,20 +5366,33 @@ def test_load_table_from_dataframe_w_custom_job_config(self):
         from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
         from google.cloud.bigquery import job
+        from google.cloud.bigquery.schema import SchemaField
 
         client = self._make_client()
         records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
         dataframe = pandas.DataFrame(records)
-        job_config = job.LoadJobConfig()
+        job_config = job.LoadJobConfig(
+            write_disposition=job.WriteDisposition.WRITE_TRUNCATE
+        )
 
+        get_table_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.get_table",
+            autospec=True,
+            return_value=mock.Mock(
+                schema=[SchemaField("id", "INTEGER"), SchemaField("age", "INTEGER")]
+            ),
+        )
         load_patch = mock.patch(
             "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
         )
-        with load_patch as load_table_from_file:
+        with load_patch as load_table_from_file, get_table_patch as get_table:
             client.load_table_from_dataframe(
                 dataframe, self.TABLE_REF, job_config=job_config, location=self.LOCATION
             )
 
+        # no need to fetch and inspect table schema for WRITE_TRUNCATE jobs
+        assert not get_table.called
+
         load_table_from_file.assert_called_once_with(
             client,
             mock.ANY,
@@ -5378,6 +5408,7 @@ def test_load_table_from_dataframe_w_custom_job_config(self):
         sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
         assert sent_config.source_format == job.SourceFormat.PARQUET
+        assert sent_config.write_disposition == job.WriteDisposition.WRITE_TRUNCATE
 
     @unittest.skipIf(pandas is None, "Requires `pandas`")
     @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
@@ -5421,7 +5452,12 @@ def test_load_table_from_dataframe_w_automatic_schema(self):
         load_patch = mock.patch(
             "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
         )
-        with load_patch as load_table_from_file:
+        get_table_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.get_table",
+            autospec=True,
+            side_effect=google.api_core.exceptions.NotFound("Table not found"),
+        )
+        with load_patch as load_table_from_file, get_table_patch:
             client.load_table_from_dataframe(
                 dataframe, self.TABLE_REF, location=self.LOCATION
             )
@@ -5449,6 +5485,100 @@ def test_load_table_from_dataframe_w_automatic_schema(self):
             SchemaField("ts_col", "TIMESTAMP"),
         )
 
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_load_table_from_dataframe_w_index_and_auto_schema(self):
+        from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
+        from google.cloud.bigquery import job
+        from google.cloud.bigquery.schema import SchemaField
+
+        client = self._make_client()
+        df_data = collections.OrderedDict(
+            [("int_col", [10, 20, 30]), ("float_col", [1.0, 2.0, 3.0])]
+        )
+        dataframe = pandas.DataFrame(
+            df_data,
+            index=pandas.Index(name="unique_name", data=["one", "two", "three"]),
+        )
+
+        load_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
+        )
+
+        get_table_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.get_table",
+            autospec=True,
+            return_value=mock.Mock(
+                schema=[
+                    SchemaField("int_col", "INTEGER"),
+                    SchemaField("float_col", "FLOAT"),
+                    SchemaField("unique_name", "STRING"),
+                ]
+            ),
+        )
+        with load_patch as load_table_from_file, get_table_patch:
+            client.load_table_from_dataframe(
+                dataframe, self.TABLE_REF, location=self.LOCATION
+            )
+
+        load_table_from_file.assert_called_once_with(
+            client,
+            mock.ANY,
+            self.TABLE_REF,
+            num_retries=_DEFAULT_NUM_RETRIES,
+            rewind=True,
+            job_id=mock.ANY,
+            job_id_prefix=None,
+            location=self.LOCATION,
+            project=None,
+            job_config=mock.ANY,
+        )
+
+        sent_config = load_table_from_file.mock_calls[0][2]["job_config"]
+        assert sent_config.source_format == job.SourceFormat.PARQUET
+
+        sent_schema = sorted(sent_config.schema, key=operator.attrgetter("name"))
+        expected_sent_schema = [
+            SchemaField("float_col", "FLOAT"),
+            SchemaField("int_col", "INTEGER"),
+            SchemaField("unique_name", "STRING"),
+        ]
+        assert sent_schema == expected_sent_schema
+
+    @unittest.skipIf(pandas is None, "Requires `pandas`")
+    @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
+    def test_load_table_from_dataframe_unknown_table(self):
+        from google.cloud.bigquery.client import _DEFAULT_NUM_RETRIES
+
+        client = self._make_client()
+        records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
+        dataframe = pandas.DataFrame(records)
+
+        get_table_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.get_table",
+            autospec=True,
+            side_effect=google.api_core.exceptions.NotFound("Table not found"),
+        )
+        load_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
+        )
+        with load_patch as load_table_from_file, get_table_patch:
+            # there should be no error
+            client.load_table_from_dataframe(dataframe, self.TABLE_REF)
+
+        load_table_from_file.assert_called_once_with(
+            client,
+            mock.ANY,
+            self.TABLE_REF,
+            num_retries=_DEFAULT_NUM_RETRIES,
+            rewind=True,
+            job_id=mock.ANY,
+            job_id_prefix=None,
+            location=None,
+            project=None,
+            job_config=mock.ANY,
+        )
+
     @unittest.skipIf(pandas is None, "Requires `pandas`")
     @unittest.skipIf(pyarrow is None, "Requires `pyarrow`")
     def test_load_table_from_dataframe_struct_fields_error(self):
@@ -5741,6 +5871,11 @@ def test_load_table_from_dataframe_wo_pyarrow_custom_compression(self):
         records = [{"id": 1, "age": 100}, {"id": 2, "age": 60}]
         dataframe = pandas.DataFrame(records)
 
+        get_table_patch = mock.patch(
+            "google.cloud.bigquery.client.Client.get_table",
+            autospec=True,
+            side_effect=google.api_core.exceptions.NotFound("Table not found"),
+        )
         load_patch = mock.patch(
             "google.cloud.bigquery.client.Client.load_table_from_file", autospec=True
         )
@@ -5749,7 +5884,7 @@ def test_load_table_from_dataframe_wo_pyarrow_custom_compression(self):
             dataframe, "to_parquet", wraps=dataframe.to_parquet
         )
 
-        with load_patch, pyarrow_patch, to_parquet_patch as to_parquet_spy:
+        with load_patch, get_table_patch, pyarrow_patch, to_parquet_patch as to_parquet_spy:
             client.load_table_from_dataframe(
                 dataframe,
                 self.TABLE_REF,