diff --git a/sdks/python/apache_beam/yaml/yaml_enrichment_test.py b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py index 9a5d848486ca..ff828325393f 100644 --- a/sdks/python/apache_beam/yaml/yaml_enrichment_test.py +++ b/sdks/python/apache_beam/yaml/yaml_enrichment_test.py @@ -24,6 +24,7 @@ from apache_beam import Row from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.yaml import yaml_enrichment from apache_beam.yaml import yaml_provider from apache_beam.yaml.yaml_transform import YamlTransform @@ -72,6 +73,41 @@ def test_enrichment_with_bigquery(self): ''') assert_that(result, equal_to(input_data)) + def test_enrichment_transform_direct_calls(self): + pcoll = mock.MagicMock() + with mock.patch('apache_beam.yaml.options.YamlOptions.check_enabled'): + with mock.patch('apache_beam.yaml.yaml_enrichment.Enrichment', None): + with self.assertRaises(ValueError): + yaml_enrichment.enrichment_transform('BigQuery', {}).expand(pcoll) + + with mock.patch('apache_beam.yaml.yaml_enrichment.Enrichment', + mock.MagicMock()): + with mock.patch( + 'apache_beam.yaml.yaml_enrichment.FeastFeatureStoreEnrichmentHandler', + None): + with self.assertRaises(ValueError): + yaml_enrichment.enrichment_transform('FeastFeatureStore', + {}).expand(pcoll) + + with self.assertRaises(ValueError): + yaml_enrichment.enrichment_transform('UnknownHandler', + {}).expand(pcoll) + + mock_bq_handler = mock.MagicMock() + mock_enrichment = mock.MagicMock() + with mock.patch( + 'apache_beam.yaml.yaml_enrichment.BigQueryEnrichmentHandler', + mock_bq_handler): + with mock.patch('apache_beam.yaml.yaml_enrichment.Enrichment', + mock_enrichment): + yaml_enrichment.enrichment_transform( + 'BigQuery', { + 'col': 'val' + }, timeout=10).expand(pcoll) + mock_bq_handler.assert_called_once_with(col='val') + mock_enrichment.assert_called_once_with( + source_handler=mock_bq_handler.return_value, timeout=10) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) diff --git a/sdks/python/apache_beam/yaml/yaml_io_test.py b/sdks/python/apache_beam/yaml/yaml_io_test.py index 250a54689f5a..30c956b0e37f 100644 --- a/sdks/python/apache_beam/yaml/yaml_io_test.py +++ b/sdks/python/apache_beam/yaml/yaml_io_test.py @@ -34,6 +34,7 @@ from apache_beam.testing.util import equal_to from apache_beam.typehints import schemas as schema_utils from apache_beam.utils.timestamp import Timestamp +from apache_beam.yaml import yaml_io from apache_beam.yaml.yaml_transform import YamlTransform try: @@ -763,6 +764,104 @@ def expand(self, pcoll): last_updated_in_seconds=None) ])) + def test_bigquery_coverage(self): + with mock.patch('apache_beam.yaml.yaml_io.ReadFromBigQuery'): + _ = yaml_io.read_from_bigquery(query='SELECT 1') + _ = yaml_io.read_from_bigquery(table='project:dataset.table') + + class MockWriteToBQ(beam.PTransform): + Method = mock.MagicMock() + + def __init__(self, *args, method=None, **kwargs): + super().__init__() + self._method = method + + def expand(self, pcoll): + res = mock.MagicMock() + res._method = self._method or self.Method.FILE_LOADS + return res + + p1 = beam.Pipeline() + pcoll1 = p1 | 'C1' >> beam.Create([beam.Row(a=1)]) + p2 = beam.Pipeline() + pcoll2 = p2 | 'C2' >> beam.Create([beam.Row(a=1)]) + with mock.patch('apache_beam.yaml.yaml_io.WriteToBigQuery', MockWriteToBQ): + _ = yaml_io.write_to_bigquery('p:d.t').expand(pcoll1) + _ = yaml_io.write_to_bigquery( + 'p:d.t', error_handling={ + 'output': 'err' + }).expand(pcoll2) + + def test_pubsub_coverage(self): + # Both topic and subscription are specified (only one is allowed) + with self.assertRaises(TypeError): + _ = beam.Pipeline() | 'Read1' >> yaml_io.read_from_pubsub( + topic='t', subscription='s', format='RAW') + # Neither topic nor subscription is specified (one is required) + with self.assertRaises(TypeError): + _ = beam.Pipeline() | 'Read2' >> yaml_io.read_from_pubsub(format='RAW') + # RAW format does not take a schema + with self.assertRaises(ValueError): + _ = beam.Pipeline() | 'Read3' >> yaml_io.read_from_pubsub( + topic='t', format='RAW', schema='s') + # STRING format does not take a schema + with self.assertRaises(ValueError): + _ = beam.Pipeline() | 'Read4' >> yaml_io.read_from_pubsub( + topic='t', format='STRING', schema='s') + # Format is unknown + with self.assertRaises(ValueError): + _ = beam.Pipeline() | 'Read5' >> yaml_io.read_from_pubsub( + topic='t', format='UNKNOWN') + + # Attributes is not a list of strings + pcoll = beam.Pipeline() | 'CreatePcoll' >> beam.Create([beam.Row(a=1)]) + with self.assertRaises(ValueError): + _ = pcoll | yaml_io.write_to_pubsub( + topic='t', format='RAW', attributes='missing_attr') + + class MockReadFromPubSub(beam.PTransform): + def __init__(self, *args, **kwargs): + super().__init__() + + def expand(self, pcoll): + return pcoll | beam.Create([beam.Row(payload=b'data')]) + + with mock.patch('apache_beam.io.ReadFromPubSub', MockReadFromPubSub): + _ = beam.Pipeline() | 'ReadRaw' >> yaml_io.read_from_pubsub( + topic='projects/p/topics/t', format='RAW') + _ = beam.Pipeline() | 'ReadStr' >> yaml_io.read_from_pubsub( + topic='projects/p/topics/t', format='STRING') + + def test_tfrecord_coverage(self): + class MockReadFromTFRecord(beam.PTransform): + def __init__(self, *args, **kwargs): + super().__init__() + + def expand(self, pcoll): + return pcoll | beam.Create([b'record_bytes']) + + class MockWriteToTFRecord(beam.PTransform): + def __init__(self, *args, **kwargs): + super().__init__() + + def expand(self, pcoll): + return mock.MagicMock() + + with mock.patch('apache_beam.yaml.yaml_io.ReadFromTFRecord', + MockReadFromTFRecord): + _ = beam.Pipeline( + ) | 'ReadTFR' >> yaml_io.read_from_tfrecord('file_pattern*') + + p = beam.Pipeline() + pcoll1 = p | 'C1' >> beam.Create([beam.Row(a=b'1')]) + with mock.patch('apache_beam.yaml.yaml_io.WriteToTFRecord', + MockWriteToTFRecord): + _ = pcoll1 | 'W1' >> yaml_io.write_to_tfrecord('prefix') + + with self.assertRaises(ValueError): + pcoll2 = p | 'C2' >> beam.Create([beam.Row(a=1, b=2)]) + _ = pcoll2 | 'W2' >> yaml_io.write_to_tfrecord('prefix2') + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) diff --git a/sdks/python/apache_beam/yaml/yaml_mapping_test.py b/sdks/python/apache_beam/yaml/yaml_mapping_test.py index 169c86d7b87b..20204b4b3fdf 100644 --- a/sdks/python/apache_beam/yaml/yaml_mapping_test.py +++ b/sdks/python/apache_beam/yaml/yaml_mapping_test.py @@ -15,9 +15,12 @@ # limitations under the License. # +import datetime import logging import typing import unittest +from decimal import Decimal +from unittest import mock import numpy as np @@ -35,6 +38,11 @@ except ImportError: jsonschema = None +try: + import quickjs +except ImportError: + quickjs = None + DATA = [ beam.Row(label='11a', conductor=11, rank=0), beam.Row(label='37a', conductor=37, rank=1), @@ -559,6 +567,245 @@ def test_extract_windowing_info_iterable(self): ])) +class YamlMappingJsExpressionTest(unittest.TestCase): + def test_javascript_expression_no_quickjs(self): + options = beam.options.pipeline_options.PipelineOptions( + yaml_experimental_features=['javascript']) + with mock.patch('apache_beam.yaml.yaml_mapping.quickjs', None): + with self.assertRaises(Exception): + with beam.Pipeline(options=options) as p: + _ = ( + p + | beam.Create([beam.Row(x=1)]) + | YamlTransform( + ''' + type: MapToFields + config: + language: javascript + fields: + y: x + ''')) + + def test_javascript_path_missing(self): + options = beam.options.pipeline_options.PipelineOptions( + yaml_experimental_features=['javascript']) + with mock.patch('apache_beam.yaml.yaml_mapping.quickjs', mock.MagicMock()): + with self.assertRaises(Exception): + with beam.Pipeline(options=options) as p: + _ = ( + p + | beam.Create([beam.Row(x=1)]) + | YamlTransform( + ''' + type: MapToFields + config: + language: javascript + fields: + y: + path: "bad.txt" + ''')) + + @unittest.skipIf(quickjs is None, "quickjs not installed") + def test_javascript_expression_execution(self): + options = beam.options.pipeline_options.PipelineOptions( + yaml_experimental_features=['javascript']) + with beam.Pipeline(options=options) as p: + elements = p | beam.Create([beam.Row(x=1, s="foo")]) + result = elements | YamlTransform( + ''' + type: MapToFields + config: + language: javascript + fields: + y: + expression: "x + 1" + t: + expression: "s + '_bar'" + ''') + assert_that(result, equal_to([beam.Row(y=2, t="foo_bar")])) + + @unittest.skipIf(quickjs is None, "quickjs not installed") + def test_javascript_callable_execution(self): + options = beam.options.pipeline_options.PipelineOptions( + yaml_experimental_features=['javascript']) + with beam.Pipeline(options=options) as p: + elements = p | beam.Create([beam.Row(x=1)]) + result = elements | YamlTransform( + ''' + type: MapToFields + config: + language: javascript + fields: + y: + callable: "function(row) { return row.x + 2; }" + ''') + assert_that(result, equal_to([beam.Row(y=3)])) + + @unittest.skipIf(quickjs is None, "quickjs not installed") + def test_javascript_path_execution(self): + options = beam.options.pipeline_options.PipelineOptions( + yaml_experimental_features=['javascript']) + with mock.patch( + 'apache_beam.io.filesystems.FileSystems.open', + mock.mock_open(read_data=b'function fn(row) { return row.x + 3; }')): + with beam.Pipeline(options=options) as p: + elements = p | beam.Create([beam.Row(x=1)]) + result = elements | YamlTransform( + ''' + type: MapToFields + config: + language: javascript + fields: + y: + path: "test.js" + name: "fn" + ''') + assert_that(result, equal_to([beam.Row(y=4)])) + + def test_js_value_to_js_dict(self): + val = beam.Row( + a=b'bytes', + b=datetime.datetime(2026, 1, 1), + c=Decimal('1.23'), + d=[1, 2], + e=beam.Row(f=3)) + res = yaml_mapping.py_value_to_js_dict(val) + self.assertEqual(res['a'], 'bytes') + self.assertEqual( + res['b'], { + '__date__': True, 'value': '2026-01-01T00:00:00' + }) + self.assertEqual(res['c'], 1.23) + self.assertEqual(res['d'], [1, 2]) + self.assertEqual(res['e'], {'f': 3}) + + +class YamlMappingErrorHandlingTest(unittest.TestCase): + def test_strip_error_metadata(self): + with beam.Pipeline() as p: + pcoll1 = p | 'C1' >> beam.Create([(1, 2)]) + res1 = pcoll1 | 'Strip1' >> YamlTransform( + ''' + type: StripErrorMetadata + ''') + assert_that(res1, equal_to([1]), label='res1') + + pcoll2 = p | 'C2' >> beam.Create([beam.Row(element=123, error='err')]) + res2 = pcoll2 | 'Strip2' >> YamlTransform( + ''' + type: StripErrorMetadata + ''') + assert_that(res2, equal_to([123]), label='res2') + + def test_strip_error_metadata_invalid(self): + with self.assertRaises(Exception): + with beam.Pipeline() as p: + pcoll3 = p | 'C3' >> beam.Create([beam.Row(a=1, b=2)]) + _ = pcoll3 | YamlTransform( + ''' + type: StripErrorMetadata + ''') + + def test_validate_with_exception_handling(self): + v = yaml_mapping.Validate({}) + v.with_exception_handling(output='err') + + +class YamlMappingSqlTransformTest(unittest.TestCase): + def test_sql_map_to_fields(self): + queries_received = [] + + def mock_sql_transform(query): + queries_received.append(query) + return beam.Map(lambda x: x) + + from apache_beam.yaml.yaml_provider import InlineProvider + mock_provider = InlineProvider({'Sql': mock_sql_transform}) + + with mock.patch( + 'apache_beam.yaml.yaml_provider.SqlBackedProvider.sql_provider', + lambda self: mock_provider): + with beam.Pipeline() as p: + pcoll = p | beam.Create([beam.Row(a=1)]) + _ = pcoll | YamlTransform( + ''' + type: MapToFields + config: + language: sql + fields: + x: a + y: a+1 + ''') + + self.assertEqual(len(queries_received), 1) + self.assertEqual( + queries_received[0], "SELECT (a) AS `x`, (a+1) AS `y` FROM PCOLLECTION") + + def test_sql_map_to_fields_invalid(self): + from apache_beam.yaml.yaml_provider import InlineProvider + mock_provider = InlineProvider({'Sql': lambda query: beam.Map(lambda x: x)}) + + with mock.patch( + 'apache_beam.yaml.yaml_provider.SqlBackedProvider.sql_provider', + lambda self: mock_provider): + with self.assertRaises(Exception): + with beam.Pipeline() as p: + pcoll = p | beam.Create([beam.Row(a=1)]) + _ = pcoll | YamlTransform( + ''' + type: MapToFields + config: + language: sql + fields: + x: + bad: 123 + ''') + + +class YamlMappingValidatorTest(unittest.TestCase): + def test_validator(self): + from apache_beam import schema_pb2 + t_bool = schema_pb2.FieldType(atomic_type=schema_pb2.BOOLEAN) + self.assertTrue(yaml_mapping._validator(t_bool)(True)) + + t_int = schema_pb2.FieldType(atomic_type=schema_pb2.INT64) + self.assertTrue(yaml_mapping._validator(t_int)(10)) + + t_double = schema_pb2.FieldType(atomic_type=schema_pb2.DOUBLE) + self.assertTrue(yaml_mapping._validator(t_double)(1.23)) + + t_str = schema_pb2.FieldType(atomic_type=schema_pb2.STRING) + self.assertTrue(yaml_mapping._validator(t_str)('s')) + + t_bytes = schema_pb2.FieldType(atomic_type=schema_pb2.BYTES) + self.assertTrue(yaml_mapping._validator(t_bytes)(b'b')) + + with self.assertRaises(ValueError): + yaml_mapping._validator( + schema_pb2.FieldType(atomic_type=schema_pb2.UNSPECIFIED)) + + t_arr = schema_pb2.FieldType( + array_type=schema_pb2.ArrayType(element_type=t_int)) + self.assertTrue(yaml_mapping._validator(t_arr)([1, 2])) + + t_iter = schema_pb2.FieldType( + iterable_type=schema_pb2.IterableType(element_type=t_str)) + self.assertTrue(yaml_mapping._validator(t_iter)(['a', 'b'])) + + t_map = schema_pb2.FieldType( + map_type=schema_pb2.MapType(key_type=t_str, value_type=t_int)) + self.assertTrue(yaml_mapping._validator(t_map)({'a': 1})) + + t_row = schema_pb2.FieldType( + row_type=schema_pb2.RowType( + schema=schema_pb2.Schema( + fields=[schema_pb2.Field(name='a', type=t_int)]))) + self.assertTrue(yaml_mapping._validator(t_row)(beam.Row(a=1))) + + with self.assertRaises(ValueError): + yaml_mapping._validator(schema_pb2.FieldType()) + + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) unittest.main() diff --git a/sdks/python/apache_beam/yaml/yaml_ml_test.py b/sdks/python/apache_beam/yaml/yaml_ml_test.py index d8b1bdbae1b2..c3e414fc7054 100644 --- a/sdks/python/apache_beam/yaml/yaml_ml_test.py +++ b/sdks/python/apache_beam/yaml/yaml_ml_test.py @@ -18,17 +18,25 @@ import logging import tempfile import unittest +from typing import Any +from unittest import mock import apache_beam as beam from apache_beam.testing.util import assert_that from apache_beam.testing.util import equal_to +from apache_beam.yaml import yaml_ml from apache_beam.yaml.yaml_transform import YamlTransform try: # pylint: disable=wrong-import-order, wrong-import-position, unused-import from apache_beam.ml.transforms import tft except ImportError: - raise unittest.SkipTest('tensorflow_transform is not installed.') + tft = None + +try: + import sentence_transformers +except ImportError: + sentence_transformers = None TRAIN_DATA = [ beam.Row(num=0, text='And God said, Let there be light,'), @@ -42,6 +50,7 @@ class MLTransformTest(unittest.TestCase): + @unittest.skipIf(tft is None, 'tensorflow_transform is not installed.') def test_ml_transform(self): ml_opts = beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle', yaml_experimental_features=['ML']) @@ -86,6 +95,7 @@ def test_ml_transform(self): equal_to([5]), label='CheckVocab') + @unittest.skipIf(tft is None, 'tensorflow_transform is not installed.') def test_ml_transform_read_with_map_to_fields(self): ml_opts = beam.options.pipeline_options.PipelineOptions( pickle_library='cloudpickle', yaml_experimental_features=['ML']) @@ -133,6 +143,8 @@ def check_row(row): assert_that(result | beam.Map(check_row), equal_to([0.75])) + @unittest.skipIf( + sentence_transformers is None, 'sentence_transformers is not installed.') def test_sentence_transformer_embedding(self): SENTENCE_EMBEDDING_DIMENSION = 384 DATA = [{ @@ -165,6 +177,8 @@ def test_sentence_transformer_embedding(self): assert_that( actual_output, equal_to([SENTENCE_EMBEDDING_DIMENSION] * len(DATA))) + @unittest.skipIf( + sentence_transformers is None, 'sentence_transformers is not installed.') def test_sentence_transformer_embedding_with_beam_rows(self): SENTENCE_EMBEDDING_DIMENSION = 384 DATA = [ @@ -195,6 +209,8 @@ def test_sentence_transformer_embedding_with_beam_rows(self): assert_that( actual_output, equal_to([SENTENCE_EMBEDDING_DIMENSION] * len(DATA))) + @unittest.skipIf( + sentence_transformers is None, 'sentence_transformers is not installed.') def test_ml_transform_outputs_schema(self): SENTENCE_EMBEDDING_DIMENSION = 384 ml_opts = beam.options.pipeline_options.PipelineOptions( @@ -235,6 +251,154 @@ def check_row(row): assert_that(result | beam.Map(check_row), equal_to([1, 2, 3])) + def test_model_handler_provider(self): + provider = yaml_ml.ModelHandlerProvider( + "handler", preprocess={'callable': 'lambda x: x'}) + self.assertEqual(provider.underlying_handler(), "handler") + self.assertEqual(provider.inference_output_type(), Any) + self.assertEqual( + provider._preprocess_fn_internal()(beam.Row(a=1)), + (beam.Row(a=1), beam.Row(a=1))) + self.assertEqual(provider._postprocess_fn_internal()(('orig', [1]))[1], [1]) + # Verify error handling, defaults, and config parsing for ModelHandlerProvider. + with self.assertRaises(ValueError): + provider.default_preprocess_fn() + with self.assertRaises(NotImplementedError): + provider.validate({}) + with self.assertRaises(ValueError): + provider.parse_processing_transform({ + 'callable': 'f', 'path': 'p', 'name': 'n' + }, + 'preprocess') + with self.assertRaises(ValueError): + provider.parse_processing_transform({'callable': None}, 'preprocess') + with self.assertRaises(ValueError): + provider.parse_processing_transform('not_dict', 'preprocess') + with mock.patch('apache_beam.io.filesystems.FileSystems.open', + mock.mock_open(read_data=b'def fn(x):\n return x')): + _ = provider.parse_processing_transform({ + 'path': 'f.py', 'name': 'fn' + }, + 'preprocess') + with self.assertRaises(ValueError): + yaml_ml.ModelHandlerProvider.create_handler({ + 'type': 'Nonexistent', 'config': {} + }) + + @yaml_ml.ModelHandlerProvider.register_handler_type('DummyMLHandler') + def _create(**config): + return mock.MagicMock() + + _ = yaml_ml.ModelHandlerProvider.create_handler({ + 'type': 'DummyMLHandler', 'config': {} + }) + + def test_vertex_ai_provider(self): + mock_vertex = mock.MagicMock() + with mock.patch.dict( + 'sys.modules', + {'apache_beam.ml.inference.vertex_ai_inference': mock_vertex}): + mock_vertex.VertexAIModelHandlerJSON = mock.MagicMock() + p = yaml_ml.VertexAIModelHandlerJSONProvider( + 123, + project='p', + location='l', + preprocess={'callable': 'lambda x: x'}) + p.validate({}) + self.assertIsNotNone(p.inference_output_type()) + + def test_huggingface_provider(self): + mock_hf = mock.MagicMock() + with mock.patch.dict( + 'sys.modules', + {'apache_beam.ml.inference.huggingface_inference': mock_hf}): + mock_hf.HuggingFacePipelineModelHandler = mock.MagicMock() + p = yaml_ml.HuggingFacePipelineModelHandlerProvider( + task='t', + preprocess={'callable': 'lambda x: x'}, + inference_fn={'callable': 'lambda x: x'}) + p.validate({'task': 't'}) + with self.assertRaises(ValueError): + p.validate({}) + self.assertEqual(p.inference_output_type(), Any) + + def test_run_inference_coverage(self): + p = beam.Pipeline() + pcoll = p | 'CreatePcoll' >> beam.Create([beam.Row(a=1)]) + with mock.patch('apache_beam.yaml.options.YamlOptions.check_enabled'): + with self.assertRaises(ValueError): + yaml_ml.run_inference("not_dict").expand(pcoll) + with self.assertRaises(ValueError): + yaml_ml.run_inference({ + 'type': 't', 'config': {}, 'extra': 1 + }).expand(pcoll) + with self.assertRaises(ValueError): + yaml_ml.run_inference({'type': 't'}).expand(pcoll) + with self.assertRaises(NotImplementedError): + yaml_ml.run_inference({ + 'type': 'UnknownType', 'config': {} + }).expand(pcoll) + + mock_provider = mock.MagicMock() + mock_provider.underlying_handler.return_value = mock.MagicMock() + mock_provider.inference_output_type.return_value = Any + mock_provider._preprocess_fn_internal.return_value = lambda x: x + mock_provider._postprocess_fn_internal.return_value = lambda x: x + with mock.patch.dict(yaml_ml.ModelHandlerProvider.handler_types, + {'DummyType': mock.MagicMock()}): + with mock.patch.object(yaml_ml.ModelHandlerProvider, + 'create_handler', + return_value=mock_provider): + with mock.patch('apache_beam.yaml.yaml_ml.RunInference', + return_value=beam.Map(lambda x: (x, 'res'))): + with beam.Pipeline( + options=beam.options.pipeline_options.PipelineOptions( + yaml_experimental_features=['ML'])) as p: + result = (p | 'CreateInput' >> beam.Create([beam.Row(a=1)]) + ) | YamlTransform( + ''' + type: RunInference + config: + model_handler: + type: DummyType + config: {} + ''') + assert_that(result, equal_to([beam.Row(a=1, inference='res')])) + + @unittest.skipIf( + sentence_transformers is None, 'sentence_transformers is not installed.') + def test_ml_transform_and_config(self): + with self.assertRaises(ValueError): + yaml_ml._config_to_obj({}) + with self.assertRaises(ValueError): + yaml_ml._config_to_obj({'type': 't'}) + with self.assertRaises(ValueError): + yaml_ml._config_to_obj({'type': 'unknown', 'config': {}}) + + pcoll = beam.Pipeline() | 'CreateML' >> beam.Create([beam.Row(a=1)]) + with mock.patch('apache_beam.yaml.yaml_ml.MLTransform', None): + with self.assertRaises(ValueError): + yaml_ml.ml_transform().expand(pcoll) + + mock_mlt = mock.MagicMock(return_value=beam.Map(lambda x: x)) + with mock.patch('apache_beam.yaml.yaml_ml.MLTransform', mock_mlt): + with mock.patch('apache_beam.yaml.yaml_ml._config_to_obj', + return_value=mock.MagicMock()): + with beam.Pipeline( + options=beam.options.pipeline_options.PipelineOptions( + yaml_experimental_features=['ML'])) as p: + result = (p | 'CreateMLInput' >> beam.Create([beam.Row(a=1)]) + ) | YamlTransform( + ''' + type: MLTransform + config: + transforms: + - type: SentenceEmbeddings + config: + columns: [a] + ''') + assert_that(result, equal_to([beam.Row(a=1)])) + if __name__ == '__main__': logging.getLogger().setLevel(logging.INFO) diff --git a/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py b/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py index e1e3ee847d96..bfdab70057c9 100644 --- a/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py +++ b/sdks/python/apache_beam/yaml/yaml_provider_unit_test.py @@ -361,11 +361,6 @@ def test_empty_base(self): yaml_provider._join_url_or_filepath(None, 'a/b.yaml'), 'a/b.yaml') -if __name__ == '__main__': - logging.getLogger().setLevel(logging.INFO) - unittest.main() - - class YamlProvidersCreateTest(unittest.TestCase): def test_create_mixed_types(self): with beam.Pipeline() as p: @@ -377,3 +372,249 @@ def test_create_mixed_types(self): [('a', None), ('element', 1)], [('a', 2), ('element', None)], ])) + + +class YamlProvidersFlattenTest(unittest.TestCase): + def test_flatten_schema_merging(self): + with beam.Pipeline() as p: + pcoll1 = p | 'C1' >> beam.Create([beam.Row(a=1)]) + pcoll2 = p | 'C2' >> beam.Create([beam.Row(b=2)]) + res = { + 'first': pcoll1, 'second': pcoll2 + } | yaml_provider.YamlProviders.Flatten() + assert_that( + res | beam.Map(lambda x: sorted(x._asdict().items())), + equal_to([ + [('a', 1), ('b', None)], + [('a', None), ('b', 2)], + ])) + + def test_flatten_single_pcoll(self): + with beam.Pipeline() as p: + pcoll = p | beam.Create([beam.Row(a=1)]) + res = pcoll | yaml_provider.YamlProviders.Flatten() + assert_that(res, equal_to([beam.Row(a=1)])) + + def test_flatten_empty(self): + with beam.Pipeline() as p: + res = p | yaml_provider.YamlProviders.Flatten() + assert_that(res, equal_to([])) + + +class YamlProviderBaseAndHelpersTest(unittest.TestCase): + def test_not_available_with_reason(self): + n = yaml_provider.NotAvailableWithReason("reason") + self.assertEqual(n.reason, "reason") + self.assertFalse(bool(n)) + + def test_provider_defaults(self): + class MinimalProvider(yaml_provider.Provider): + def available(self): + return super().available() + + def cache_artifacts(self): + return super().cache_artifacts() + + def provided_transforms(self): + return super().provided_transforms() + + def create_transform(self, typ, args, spec): + return super().create_transform(typ, args, spec) + + def _with_extra_dependencies(self, deps): + return super()._with_extra_dependencies(deps) + + p = MinimalProvider() + with self.assertRaises(NotImplementedError): + p.available() + with self.assertRaises(NotImplementedError): + p.cache_artifacts() + with self.assertRaises(NotImplementedError): + p.provided_transforms() + with self.assertRaises(NotImplementedError): + p.create_transform("typ", {}, {}) + with self.assertRaises(ValueError): + p._with_extra_dependencies(["dep"]) + + self.assertIsNone(p.config_schema("typ")) + self.assertIsNone(p.description("typ")) + self.assertFalse(p.requires_inputs("ReadFromSource", {})) + self.assertTrue(p.requires_inputs("typ", {})) + self.assertIsNone(yaml_provider.InlineProvider({}).cache_artifacts()) + self.assertIsNone( + yaml_provider.RemoteProvider({}, 'localhost:1234').cache_artifacts()) + + def test_as_provider_list(self): + p1 = yaml_provider.as_provider("t", lambda: None) + self.assertIsInstance(p1, yaml_provider.InlineProvider) + self.assertIs(yaml_provider.as_provider("t", p1), p1) + + lst1 = yaml_provider.as_provider_list("t", lambda: None) + self.assertEqual(len(lst1), 1) + lst2 = yaml_provider.as_provider_list("t", [p1]) + self.assertEqual(lst2, [p1]) + + def test_merge_providers(self): + p1 = yaml_provider.as_provider("t1", lambda: None) + p2 = yaml_provider.as_provider("t2", lambda: None) + res = yaml_provider.merge_providers(p1, [p2]) + self.assertIn("t1", res) + self.assertIn("t2", res) + + def test_inline_provider_namespace(self): + _ = yaml_provider.standard_inline_providers.Create + with self.assertRaises(ValueError): + _ = yaml_provider.standard_inline_providers.NonexistentTransform + + def test_renaming_provider(self): + p = yaml_provider.InlineProvider({'T': lambda **kwargs: None}) + rp = yaml_provider.RenamingProvider( + transforms={'MyT': 'T'}, + provider_base_path=None, + mappings={'MyT': { + 'new_arg': 'x' + }}, + underlying_provider=p, + defaults={'MyT': { + 'def': 1 + }}) + self.assertTrue(rp.available()) + self.assertEqual(list(rp.provided_transforms()), ['MyT']) + self.assertEqual(rp.underlying_provider(), p) + self.assertIsNotNone(rp.config_schema('MyT')) + self.assertIsNone(rp.description('MyT')) + self.assertTrue(rp.requires_inputs('MyT', {})) + _ = rp.create_transform('MyT', {'new_arg': 123}, lambda t, pcoll: None) + + # Verify that mappings must be a dictionary. + with self.assertRaises(ValueError): + yaml_provider.RenamingProvider({'MyT': 'T'}, None, 'not_dict', p) + + # Verify that string-based delegation mappings must point to a defined key. + with self.assertRaises(ValueError): + yaml_provider.RenamingProvider({'MyT': 'T'}, None, {'MyT': 'UnknownT'}, p) + + # Verify that every renamed transform must have an entry in the mappings dictionary. + with self.assertRaises(ValueError): + yaml_provider.RenamingProvider({'Missing': 'T'}, None, {}, p) + + +class WindowIntoTransformTest(unittest.TestCase): + def test_window_into_types(self): + from apache_beam.transforms import window + + p = yaml_provider.YamlProviders.WindowInto._parse_window_spec( + {'type': 'global'}) + self.assertIsInstance(p.windowing.windowfn, window.GlobalWindows) + + p = yaml_provider.YamlProviders.WindowInto._parse_window_spec({ + 'type': 'fixed', 'size': '10s' + }) + self.assertIsInstance(p.windowing.windowfn, window.FixedWindows) + + p = yaml_provider.YamlProviders.WindowInto._parse_window_spec({ + 'type': 'sliding', 'size': '10s', 'period': '5s' + }) + self.assertIsInstance(p.windowing.windowfn, window.SlidingWindows) + + p = yaml_provider.YamlProviders.WindowInto._parse_window_spec({ + 'type': 'sessions', 'gap': '10s' + }) + self.assertIsInstance(p.windowing.windowfn, window.Sessions) + + with self.assertRaises(ValueError): + yaml_provider.YamlProviders.WindowInto._parse_window_spec( + {'type': 'unknown'}) + + with self.assertRaises(ValueError): + yaml_provider.YamlProviders.WindowInto._parse_window_spec({ + 'type': 'global', 'extra': 123 + }) + + +class LogForTestingTest(unittest.TestCase): + def test_log_for_testing(self): + with beam.Pipeline() as p: + pcoll = p | beam.Create([1]) + output_pcoll = pcoll | yaml_provider.YamlProviders.log_for_testing( + level='INFO', prefix='test:') + + # Extract the DoFn from the transform and run it directly in-process + transform = output_pcoll.producer.transform + dofn = transform.fn + + with self.assertLogs(level='INFO') as log: + outputs = list(dofn.process(beam.Row(a=b'bytes_val', b=[1, 2]))) + + self.assertEqual(outputs, [beam.Row(a=b'bytes_val', b=[1, 2])]) + self.assertIn( + 'INFO:root:test:{"a": "b\'bytes_val\'", "b": [1, 2]}', log.output) + + def test_log_for_testing_unknown_level(self): + with beam.Pipeline() as p: + pcoll = p | beam.Create([1]) + with self.assertRaises(ValueError): + _ = pcoll | yaml_provider.YamlProviders.log_for_testing(level='UNKNOWN') + + +class PypiExpansionServiceTest(unittest.TestCase): + @mock.patch('apache_beam.utils.subprocess_server.SubprocessServer') + @mock.patch('subprocess.run') + @mock.patch('os.path.exists') + def test_pypi_expansion_service_venv_creation( + self, mock_exists, mock_run, mock_server): + mock_exists.return_value = False + + mock_clone = mock.MagicMock() + mock_clone_module = mock.MagicMock() + mock_clone_module.clone_virtualenv = mock_clone + + with mock.patch.dict('sys.modules', {'clonevirtualenv': mock_clone_module}): + with mock.patch('builtins.open', mock.mock_open()): + with yaml_provider.PypiExpansionService(['pkg1', 'pkg2']) as service: + pass + + self.assertTrue(mock_clone.called) + self.assertTrue(mock_run.called) + + @mock.patch('apache_beam.utils.subprocess_server.SubprocessServer') + @mock.patch('os.environ.get') + @mock.patch('os.path.exists') + def test_pypi_expansion_service_dev_cloning( + self, mock_exists, mock_env_get, mock_server): + mock_exists.return_value = False + mock_env_get.return_value = 'false' + + mock_clone = mock.MagicMock() + mock_clone_module = mock.MagicMock() + mock_clone_module.clone_virtualenv = mock_clone + + with mock.patch.dict('sys.modules', {'clonevirtualenv': mock_clone_module}): + with mock.patch('builtins.open', mock.mock_open()): + with mock.patch('subprocess.run') as mock_run: + with mock.patch('apache_beam.yaml.yaml_provider.beam_version', + '2.50.0.dev'): + with mock.patch('os.path.dirname') as mock_dirname: + mock_dirname.return_value = '/path/to' + with yaml_provider.PypiExpansionService(['pkg1']) as service: + pass + + mock_clone.assert_called_once_with('/path/to', mock.ANY) + + +class ReshuffleTest(unittest.TestCase): + def test_reshuffle(self): + with beam.Pipeline() as p: + res = p | beam.Create( + [1, 2, 3]) | yaml_provider.YamlProviders.Reshuffle(num_buckets=1) + assert_that(res, equal_to([1, 2, 3])) + + def test_python_provider_with_packages(self): + p = yaml_provider.python( + urns={}, provider_base_path='/tmp', packages=['pkg']) + self.assertIsInstance(p, yaml_provider.ExternalPythonProvider) + + +if __name__ == '__main__': + logging.getLogger().setLevel(logging.INFO) + unittest.main()