From 5adcccdc399a8d23ba7b1b1beb6dfdbf158c60db Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 30 Jun 2026 13:20:34 -0400 Subject: [PATCH 1/6] Add a test to reproduce the problem. --- .../apache_beam/transforms/userstate_test.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py index 6a8efd1a536f..f34222df2690 100644 --- a/sdks/python/apache_beam/transforms/userstate_test.py +++ b/sdks/python/apache_beam/transforms/userstate_test.py @@ -719,6 +719,69 @@ def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)): actual_values = (values | beam.ParDo(SetStatefulDoFn())) assert_that(actual_values, equal_to([1, 3, 6, 10, 10])) + # Mock random to always return 1.0 to force compaction on every add. + @mock.patch('apache_beam.runners.worker.bundle_processor.random.random', lambda: 1.0) + def test_stateful_set_state_compaction_race_portably(self): + from apache_beam.runners.worker.sdk_worker import GrpcStateHandler + from apache_beam.runners.worker.sdk_worker import _Future + import time + import threading + + old_request = GrpcStateHandler._request + + def delayed_request(self, request): + # Delay append and clear requests to simulate slow state updates. + if request.HasField('append') or request.HasField('clear'): + future = _Future() + instruction_id = getattr(self._context, 'process_instruction_id', None) + def run(): + time.sleep(0.5) + self._context.process_instruction_id = instruction_id + underlying_future = old_request(self, request) + underlying_future.wait() + future.set(underlying_future.get()) + t = threading.Thread(target=run, daemon=True) + t.start() + return future + else: + return old_request(self, request) + + GrpcStateHandler._request = delayed_request + + + class SetStatefulDoFn(beam.DoFn): + + SET_STATE = SetStateSpec('buffer', VarIntCoder()) + + def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)): + _, value = element + aggregated_value = 0 + set_state.add(value) + for saved_value in set_state.read(): + aggregated_value += saved_value + yield aggregated_value + + try: + options = PipelineOptions([ + '--max_cache_memory_usage_mb=100', + '--environment_type=LOOPBACK', + ]) + with TestPipeline(runner='PrismRunner', options=options) as p: + test_stream = ( + TestStream(coder=beam.coders.TupleCoder((beam.coders.StrUtf8Coder(), beam.coders.VarIntCoder()))) + .advance_watermark_to(10) + .add_elements([('key', 1)]) + .advance_watermark_to(20) + .add_elements([('key', 2)]) + .advance_watermark_to(30) + ) + actual_values = (p | test_stream | beam.ParDo(SetStatefulDoFn())) + assert_that(actual_values, equal_to([1, 3])) + + + finally: + GrpcStateHandler._request = old_request + def test_stateful_set_state_clean_portably(self): class SetStateClearingStatefulDoFn(beam.DoFn): From cb7ed5b22f27a91b8e98e1e85d4cc0f216a990bc Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 30 Jun 2026 14:23:41 -0400 Subject: [PATCH 2/6] Fix race condition in SetState compaction by awaiting outstanding state requests --- .../runners/worker/bundle_processor.py | 37 +++++++++++++------ 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 1a852aa19d98..f3ca4b645189 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -641,6 +641,8 @@ def __init__( self._value_coder = value_coder self._cleared = False self._added_elements: set[Any] = set() + # Track outstanding async state requests to await them at commit time. + self._futures = [] def _compact_data(self, rewrite=True): accumulator = set( @@ -650,9 +652,12 @@ def _compact_data(self, rewrite=True): self._added_elements)) if rewrite and accumulator: - self._state_handler.clear(self._state_key) - self._state_handler.extend( - self._state_key, self._value_coder.get_impl(), accumulator) + # Compaction writes are asynchronous; queue them so they are not lost. + self._futures.append( + self._state_handler.clear(self._state_key)) + self._futures.append( + self._state_handler.extend( + self._state_key, self._value_coder.get_impl(), accumulator)) # Since everthing is already committed so we can safely reinitialize # added_elements here. @@ -666,7 +671,8 @@ def read(self) -> set[Any]: def add(self, value: Any) -> None: if self._cleared: # This is a good time explicitly clear. - self._state_handler.clear(self._state_key) + self._futures.append( + self._state_handler.clear(self._state_key)) self._cleared = False self._added_elements.add(value) @@ -678,15 +684,24 @@ def clear(self) -> None: self._added_elements = set() def commit(self) -> None: - to_await = None + to_await = [] if self._cleared: - to_await = self._state_handler.clear(self._state_key) + to_await.append(self._state_handler.clear(self._state_key)) + self._cleared = False if self._added_elements: - to_await = self._state_handler.extend( - self._state_key, self._value_coder.get_impl(), self._added_elements) - if to_await: - # To commit, we need to wait on the last state request future to complete. - to_await.get() + to_await.append( + self._state_handler.extend( + self._state_key, self._value_coder.get_impl(), + self._added_elements)) + self._added_elements = set() + + # Block on all outstanding async state requests to ensure data is committed. + all_futures = self._futures + to_await + self._futures = [] + + for f in all_futures: + if f: + f.get() class RangeSet: From 667d63966bfa199915127b775c009a4d1ee4baa1 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 30 Jun 2026 14:33:10 -0400 Subject: [PATCH 3/6] Reformat --- .../runners/worker/bundle_processor.py | 9 ++++---- .../apache_beam/transforms/userstate_test.py | 22 ++++++++++--------- 2 files changed, 16 insertions(+), 15 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index f3ca4b645189..9930e7bcd29d 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -653,8 +653,7 @@ def _compact_data(self, rewrite=True): if rewrite and accumulator: # Compaction writes are asynchronous; queue them so they are not lost. - self._futures.append( - self._state_handler.clear(self._state_key)) + self._futures.append(self._state_handler.clear(self._state_key)) self._futures.append( self._state_handler.extend( self._state_key, self._value_coder.get_impl(), accumulator)) @@ -671,8 +670,7 @@ def read(self) -> set[Any]: def add(self, value: Any) -> None: if self._cleared: # This is a good time explicitly clear. - self._futures.append( - self._state_handler.clear(self._state_key)) + self._futures.append(self._state_handler.clear(self._state_key)) self._cleared = False self._added_elements.add(value) @@ -691,7 +689,8 @@ def commit(self) -> None: if self._added_elements: to_await.append( self._state_handler.extend( - self._state_key, self._value_coder.get_impl(), + self._state_key, + self._value_coder.get_impl(), self._added_elements)) self._added_elements = set() diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py index f34222df2690..d39524d0b2f0 100644 --- a/sdks/python/apache_beam/transforms/userstate_test.py +++ b/sdks/python/apache_beam/transforms/userstate_test.py @@ -720,7 +720,8 @@ def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)): assert_that(actual_values, equal_to([1, 3, 6, 10, 10])) # Mock random to always return 1.0 to force compaction on every add. - @mock.patch('apache_beam.runners.worker.bundle_processor.random.random', lambda: 1.0) + @mock.patch( + 'apache_beam.runners.worker.bundle_processor.random.random', lambda: 1.0) def test_stateful_set_state_compaction_race_portably(self): from apache_beam.runners.worker.sdk_worker import GrpcStateHandler from apache_beam.runners.worker.sdk_worker import _Future @@ -734,12 +735,14 @@ def delayed_request(self, request): if request.HasField('append') or request.HasField('clear'): future = _Future() instruction_id = getattr(self._context, 'process_instruction_id', None) + def run(): time.sleep(0.5) self._context.process_instruction_id = instruction_id underlying_future = old_request(self, request) underlying_future.wait() future.set(underlying_future.get()) + t = threading.Thread(target=run, daemon=True) t.start() return future @@ -748,7 +751,6 @@ def run(): GrpcStateHandler._request = delayed_request - class SetStatefulDoFn(beam.DoFn): SET_STATE = SetStateSpec('buffer', VarIntCoder()) @@ -768,17 +770,17 @@ def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)): ]) with TestPipeline(runner='PrismRunner', options=options) as p: test_stream = ( - TestStream(coder=beam.coders.TupleCoder((beam.coders.StrUtf8Coder(), beam.coders.VarIntCoder()))) - .advance_watermark_to(10) - .add_elements([('key', 1)]) - .advance_watermark_to(20) - .add_elements([('key', 2)]) - .advance_watermark_to(30) - ) + TestStream( + coder=beam.coders.TupleCoder(( + beam.coders.StrUtf8Coder(), beam.coders.VarIntCoder() + ))).advance_watermark_to(10).add_elements([ + ('key', 1) + ]).advance_watermark_to(20).add_elements([ + ('key', 2) + ]).advance_watermark_to(30)) actual_values = (p | test_stream | beam.ParDo(SetStatefulDoFn())) assert_that(actual_values, equal_to([1, 3])) - finally: GrpcStateHandler._request = old_request From cca9ca175d8e252b8866a34ca384ea5a9cd0ab43 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 30 Jun 2026 14:53:00 -0400 Subject: [PATCH 4/6] Simply logic --- .../runners/worker/bundle_processor.py | 9 ++++--- .../apache_beam/transforms/userstate_test.py | 27 +++++++++++-------- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/sdks/python/apache_beam/runners/worker/bundle_processor.py b/sdks/python/apache_beam/runners/worker/bundle_processor.py index 9930e7bcd29d..851efc81221d 100644 --- a/sdks/python/apache_beam/runners/worker/bundle_processor.py +++ b/sdks/python/apache_beam/runners/worker/bundle_processor.py @@ -682,12 +682,11 @@ def clear(self) -> None: self._added_elements = set() def commit(self) -> None: - to_await = [] if self._cleared: - to_await.append(self._state_handler.clear(self._state_key)) + self._futures.append(self._state_handler.clear(self._state_key)) self._cleared = False if self._added_elements: - to_await.append( + self._futures.append( self._state_handler.extend( self._state_key, self._value_coder.get_impl(), @@ -695,7 +694,9 @@ def commit(self) -> None: self._added_elements = set() # Block on all outstanding async state requests to ensure data is committed. - all_futures = self._futures + to_await + # We must swap and clear self._futures before awaiting them. Awaiting a future + # yields control, during which new futures could be appended to self._futures. + all_futures = self._futures self._futures = [] for f in all_futures: diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py index d39524d0b2f0..0031ac24028e 100644 --- a/sdks/python/apache_beam/transforms/userstate_test.py +++ b/sdks/python/apache_beam/transforms/userstate_test.py @@ -727,24 +727,29 @@ def test_stateful_set_state_compaction_race_portably(self): from apache_beam.runners.worker.sdk_worker import _Future import time import threading + import queue old_request = GrpcStateHandler._request + request_queue = queue.Queue() + + def worker(): + while True: + handler, request, future, instruction_id = request_queue.get() + time.sleep(0.1) # Simulate latency for each request sequentially. + handler._context.process_instruction_id = instruction_id + underlying_future = old_request(handler, request) + underlying_future.wait() + future.set(underlying_future.get()) + request_queue.task_done() + + t = threading.Thread(target=worker, daemon=True) + t.start() def delayed_request(self, request): - # Delay append and clear requests to simulate slow state updates. if request.HasField('append') or request.HasField('clear'): future = _Future() instruction_id = getattr(self._context, 'process_instruction_id', None) - - def run(): - time.sleep(0.5) - self._context.process_instruction_id = instruction_id - underlying_future = old_request(self, request) - underlying_future.wait() - future.set(underlying_future.get()) - - t = threading.Thread(target=run, daemon=True) - t.start() + request_queue.put((self, request, future, instruction_id)) return future else: return old_request(self, request) From d7d367826bdef60f0bb2f63268c414e54ac9f7c8 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 30 Jun 2026 14:58:13 -0400 Subject: [PATCH 5/6] Fix lints --- sdks/python/apache_beam/transforms/userstate_test.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py index 0031ac24028e..2dc19c238ccc 100644 --- a/sdks/python/apache_beam/transforms/userstate_test.py +++ b/sdks/python/apache_beam/transforms/userstate_test.py @@ -18,6 +18,9 @@ """Unit tests for the Beam State and Timer API interfaces.""" # pytype: skip-file +import queue +import threading +import time import unittest from typing import Any @@ -34,6 +37,8 @@ from apache_beam.portability.api import beam_runner_api_pb2 from apache_beam.runners import pipeline_context from apache_beam.runners.common import DoFnSignature +from apache_beam.runners.worker.sdk_worker import GrpcStateHandler +from apache_beam.runners.worker.sdk_worker import _Future from apache_beam.testing.test_pipeline import TestPipeline from apache_beam.testing.test_stream import TestStream from apache_beam.testing.util import assert_that @@ -723,12 +728,6 @@ def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)): @mock.patch( 'apache_beam.runners.worker.bundle_processor.random.random', lambda: 1.0) def test_stateful_set_state_compaction_race_portably(self): - from apache_beam.runners.worker.sdk_worker import GrpcStateHandler - from apache_beam.runners.worker.sdk_worker import _Future - import time - import threading - import queue - old_request = GrpcStateHandler._request request_queue = queue.Queue() From 012fdd2d05d44989da4dd13f929c5c0e79ce4b69 Mon Sep 17 00:00:00 2001 From: Shunping Huang Date: Tue, 30 Jun 2026 15:38:12 -0400 Subject: [PATCH 6/6] Do not force to use prism for the new test. --- sdks/python/apache_beam/transforms/userstate_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdks/python/apache_beam/transforms/userstate_test.py b/sdks/python/apache_beam/transforms/userstate_test.py index 2dc19c238ccc..45dba5c9e9eb 100644 --- a/sdks/python/apache_beam/transforms/userstate_test.py +++ b/sdks/python/apache_beam/transforms/userstate_test.py @@ -772,7 +772,7 @@ def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)): '--max_cache_memory_usage_mb=100', '--environment_type=LOOPBACK', ]) - with TestPipeline(runner='PrismRunner', options=options) as p: + with TestPipeline(options=options) as p: test_stream = ( TestStream( coder=beam.coders.TupleCoder((