diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 1a852aa19d98..851efc81221d 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -641,6 +641,8 @@ def __init__( self._value_coder = value_coder self._cleared = False self._added_elements: set[Any] = set() + # Track outstanding async state requests to await them at commit time. + self._futures = [] def _compact_data(self, rewrite=True): accumulator = set( @@ -650,9 +652,11 @@ def _compact_data(self, rewrite=True): self._added_elements)) if rewrite and accumulator: - self._state_handler.clear(self._state_key) - self._state_handler.extend( - self._state_key, self._value_coder.get_impl(), accumulator) + # Compaction writes are asynchronous; queue them so they are not lost. + self._futures.append(self._state_handler.clear(self._state_key)) + self._futures.append( + self._state_handler.extend( + self._state_key, self._value_coder.get_impl(), accumulator)) # Since everthing is already committed so we can safely reinitialize # added_elements here. @@ -666,7 +670,7 @@ def read(self) -> set[Any]: def add(self, value: Any) -> None: if self._cleared: # This is a good time explicitly clear. - self._state_handler.clear(self._state_key) + self._futures.append(self._state_handler.clear(self._state_key)) self._cleared = False self._added_elements.add(value) @@ -678,15 +682,26 @@ def clear(self) -> None: self._added_elements = set() def commit(self) -> None: - to_await = None if self._cleared: - to_await = self._state_handler.clear(self._state_key) + self._futures.append(self._state_handler.clear(self._state_key)) + self._cleared = False if self._added_elements: - to_await = self._state_handler.extend( - self._state_key, self._value_coder.get_impl(), self._added_elements) - if to_await: - # To commit, we need to wait on the last state request future to complete. - to_await.get() + self._futures.append( + self._state_handler.extend( + self._state_key, + self._value_coder.get_impl(), + self._added_elements)) + self._added_elements = set() + + # Block on all outstanding async state requests to ensure data is committed. + # We must swap and clear self._futures before awaiting them. Awaiting a future + # yields control, during which new futures could be appended to self._futures. + all_futures = self._futures + self._futures = [] + + for f in all_futures: + if f: + f.get() class RangeSet: diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py index 6a8efd1a536f..45dba5c9e9eb 100644 --- a/sdks/python/apache_beam/transforms/userstate_test.py +++ b/sdks/python/apache_beam/transforms/userstate_test.py @@ -18,6 +18,9 @@ """Unit tests for the Beam State and Timer API interfaces.""" # pytype: skip-file +import queue +import threading +import time import unittest from typing import Any @@ -34,6 +37,8 @@ from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.runners import pipeline_context from apache_beam.runners.common import DoFnSignature +from apache_beam.runners.worker.sdk_worker import GrpcStateHandler +from apache_beam.runners.worker.sdk_worker import _Future from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.test_stream import TestStream from apache_beam.testing.util import assert_that @@ -719,6 +724,70 @@ def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)): actual_values = (values | beam.ParDo(SetStatefulDoFn())) assert_that(actual_values, equal_to([1, 3, 6, 10, 10])) + # Mock random to always return 1.0 to force compaction on every add. + @mock.patch( + 'apache_beam.runners.worker.bundle_processor.random.random', lambda: 1.0) + def test_stateful_set_state_compaction_race_portably(self): + old_request = GrpcStateHandler._request + request_queue = queue.Queue() + + def worker(): + while True: + handler, request, future, instruction_id = request_queue.get() + time.sleep(0.1) # Simulate latency for each request sequentially. + handler._context.process_instruction_id = instruction_id + underlying_future = old_request(handler, request) + underlying_future.wait() + future.set(underlying_future.get()) + request_queue.task_done() + + t = threading.Thread(target=worker, daemon=True) + t.start() + + def delayed_request(self, request): + if request.HasField('append') or request.HasField('clear'): + future = _Future() + instruction_id = getattr(self._context, 'process_instruction_id', None) + request_queue.put((self, request, future, instruction_id)) + return future + else: + return old_request(self, request) + + GrpcStateHandler._request = delayed_request + + class SetStatefulDoFn(beam.DoFn): + + SET_STATE = SetStateSpec('buffer', VarIntCoder()) + + def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)): + _, value = element + aggregated_value = 0 + set_state.add(value) + for saved_value in set_state.read(): + aggregated_value += saved_value + yield aggregated_value + + try: + options = PipelineOptions([ + '--max_cache_memory_usage_mb=100', + '--environment_type=LOOPBACK', + ]) + with TestPipeline(options=options) as p: + test_stream = ( + TestStream( + coder=beam.coders.TupleCoder(( + beam.coders.StrUtf8Coder(), beam.coders.VarIntCoder() + ))).advance_watermark_to(10).add_elements([ + ('key', 1) + ]).advance_watermark_to(20).add_elements([ + ('key', 2) + ]).advance_watermark_to(30)) + actual_values = (p | test_stream | beam.ParDo(SetStatefulDoFn())) + assert_that(actual_values, equal_to([1, 3])) + + finally: + GrpcStateHandler._request = old_request + def test_stateful_set_state_clean_portably(self): class SetStateClearingStatefulDoFn(beam.DoFn):