diff --git a/sdks/python/apache_beam/io/gcp/bigquery.py b/sdks/python/apache_beam/io/gcp/bigquery.py index a2d17f12569e..ba299d7c553c 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery.py +++ b/sdks/python/apache_beam/io/gcp/bigquery.py @@ -2937,6 +2937,13 @@ class ReadFromBigQuery(PTransform): PCollection with a schema and yielding Beam Rows via the option `BEAM_ROW`. For more information on schemas, see https://beam.apache.org/documentation/programming-guide/#what-is-a-schema) + query_output_schema: Required when output_type is 'BEAM_ROW' and a query + is specified. A BigQuery schema describing the query result columns, + since the schema cannot be auto-derived from an existing table when + using a query. Accepts the same formats as WriteToBigQuery's schema + parameter: a dict like + ``{'fields': [{'name': 'col', 'type': 'STRING', 'mode': 'NULLABLE'}]}``, + a JSON string, or a TableSchema object. """ class Method(object): EXPORT = 'EXPORT' # This is currently the default. @@ -2951,11 +2958,13 @@ def __init__( use_native_datetime=False, output_type=None, timeout=None, + query_output_schema=None, *args, **kwargs): self.method = method or ReadFromBigQuery.Method.EXPORT self.use_native_datetime = use_native_datetime self.output_type = output_type + self.query_output_schema = query_output_schema self._args = args self._kwargs = kwargs if timeout is not None: @@ -2979,9 +2988,15 @@ def __init__( if self.output_type == 'BEAM_ROW' and self._kwargs.get('query', None) is not None: - raise ValueError( - "Both a query and an output type of 'BEAM_ROW' were specified. " - "'BEAM_ROW' is not currently supported with queries.") + if self.query_output_schema is None: + raise ValueError( + "Both a query and an output type of 'BEAM_ROW' were specified " + "without a query_output_schema. When using a query, you must " + "provide query_output_schema so the output schema can be " + "determined without reading an existing table. The schema should " + "be a BigQuery schema dict, e.g. " + "{'fields': [{'name': 'col', 'type': 'STRING', 'mode': 'NULLABLE'}" + ", ...]}, or a TableSchema object.") self.gcs_location = gcs_location self.bigquery_dataset_labels = { @@ -3004,6 +3019,9 @@ def _expand_output_type(self, output_pcollection): if self.output_type == 'PYTHON_DICT' or self.output_type is None: return output_pcollection elif self.output_type == 'BEAM_ROW': + if self._kwargs.get('query', None) is not None: + return output_pcollection | bigquery_schema_tools.convert_to_usertype( + self.query_output_schema, self._kwargs.get('selected_fields', None)) table_details = bigquery_tools.parse_table_reference( table=self._kwargs.get("table", None), dataset=self._kwargs.get("dataset", None), diff --git a/sdks/python/apache_beam/io/gcp/bigquery_test.py b/sdks/python/apache_beam/io/gcp/bigquery_test.py index 234c99847a44..47e0af378e78 100644 --- a/sdks/python/apache_beam/io/gcp/bigquery_test.py +++ b/sdks/python/apache_beam/io/gcp/bigquery_test.py @@ -777,6 +777,55 @@ def test_read_all_lineage(self): 'bigquery:project2.dataset2.table2' ])) + def test_query_with_beam_row_requires_schema(self): + with self.assertRaisesRegex(ValueError, 'query_output_schema'): + ReadFromBigQuery( + query='SELECT id, name FROM dataset.table', output_type='BEAM_ROW') + + def test_query_with_beam_row_and_schema_accepted(self): + schema = { + 'fields': [ + { + 'name': 'id', 'type': 'INTEGER', 'mode': 'NULLABLE' + }, + { + 'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE' + }, + ] + } + transform = ReadFromBigQuery( + query='SELECT id, name FROM dataset.table', + output_type='BEAM_ROW', + query_output_schema=schema) + self.assertEqual(transform.query_output_schema, schema) + + def test_expand_output_type_uses_query_schema(self): + schema = { + 'fields': [ + { + 'name': 'id', 'type': 'INTEGER', 'mode': 'NULLABLE' + }, + { + 'name': 'name', 'type': 'STRING', 'mode': 'NULLABLE' + }, + ] + } + transform = ReadFromBigQuery( + query='SELECT id, name FROM dataset.table', + output_type='BEAM_ROW', + query_output_schema=schema) + + with mock.patch.object(bigquery_tools.BigQueryWrapper, + 'get_table') as mock_get_table, \ + mock.patch('apache_beam.io.gcp.bigquery.bigquery_schema_tools' + '.convert_to_usertype') as mock_convert: + mock_convert.return_value = beam.Map(lambda x: x) + fake_pcoll = mock.MagicMock() + transform._expand_output_type(fake_pcoll) + + mock_get_table.assert_not_called() + mock_convert.assert_called_once_with(schema, None) + @unittest.skipIf(HttpError is None, 'GCP dependencies are not installed') class TestBigQuerySink(unittest.TestCase): diff --git a/sdks/python/apache_beam/yaml/yaml_io.py b/sdks/python/apache_beam/yaml/yaml_io.py index 77cbc41def32..59540e2d9f8f 100644 --- a/sdks/python/apache_beam/yaml/yaml_io.py +++ b/sdks/python/apache_beam/yaml/yaml_io.py @@ -101,7 +101,8 @@ def read_from_bigquery( table: Optional[str] = None, query: Optional[str] = None, row_restriction: Optional[str] = None, - fields: Optional[Iterable[str]] = None): + fields: Optional[Iterable[str]] = None, + schema: Optional[Any] = None): """Reads data from BigQuery. Exactly one of table or query must be set. @@ -119,18 +120,27 @@ def read_from_bigquery( specified field is a nested field, all the sub-fields in the field will be selected. The output field order is unrelated to the order of fields given here. + schema (dict): Required when query is set. A BigQuery schema describing + the query result columns, e.g. + ``{'fields': [{'name': 'col', 'type': 'STRING', 'mode': 'NULLABLE'}]}``. + Not applicable when reading from a table (schema is auto-derived). """ if query is None: assert table is not None else: assert table is None and row_restriction is None and fields is None + if schema is None: + raise ValueError( + "When using 'query' in ReadFromBigQuery YAML transform, " + "'schema' is required to define the output row structure.") return ReadFromBigQuery( query=query, table=table, row_restriction=row_restriction, selected_fields=fields, method='DIRECT_READ', - output_type='BEAM_ROW') + output_type='BEAM_ROW', + query_output_schema=schema) def write_to_bigquery( diff --git a/sdks/python/apache_beam/yaml/yaml_io_test.py b/sdks/python/apache_beam/yaml/yaml_io_test.py index 250a54689f5a..c3df0328f22b 100644 --- a/sdks/python/apache_beam/yaml/yaml_io_test.py +++ b/sdks/python/apache_beam/yaml/yaml_io_test.py @@ -764,6 +764,49 @@ def expand(self, pcoll): ])) +class ReadFromBigQueryTest(unittest.TestCase): + def test_query_without_schema_raises(self): + from apache_beam.yaml.yaml_io import read_from_bigquery + with self.assertRaisesRegex(ValueError, 'schema'): + read_from_bigquery(query='SELECT id FROM dataset.table') + + def test_table_without_schema_ok(self): + import unittest.mock as mock + + from apache_beam.yaml.yaml_io import read_from_bigquery + with mock.patch('apache_beam.yaml.yaml_io.ReadFromBigQuery') as mock_rfbq: + mock_rfbq.return_value = mock.MagicMock() + read_from_bigquery(table='project:dataset.table') + mock_rfbq.assert_called_once() + call_kwargs = mock_rfbq.call_args[1] + self.assertIsNone(call_kwargs.get('query_output_schema')) + + def test_query_with_schema_passes_through(self): + import unittest.mock as mock + + from apache_beam.yaml.yaml_io import read_from_bigquery + schema = { + 'fields': [ + { + 'name': 'id', 'type': 'INTEGER', 'mode': 'NULLABLE' + }, + ] + } + with mock.patch('apache_beam.yaml.yaml_io.ReadFromBigQuery') as mock_rfbq: + mock_rfbq.return_value = mock.MagicMock() + read_from_bigquery(query='SELECT id FROM dataset.table', schema=schema) + call_kwargs = mock_rfbq.call_args[1] + self.assertEqual(call_kwargs['query_output_schema'], schema) + + def test_query_and_table_both_raises(self): + from apache_beam.yaml.yaml_io import read_from_bigquery + with self.assertRaises(AssertionError): + read_from_bigquery( + table='project:dataset.table', + query='SELECT id FROM dataset.table', + schema={'fields': []}) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main()