Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 26 additions & 11 deletions sdks/python/apache_beam/runners/worker/bundle_processor.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,6 +641,8 @@ def __init__(
self._value_coder = value_coder
self._cleared = False
self._added_elements: set[Any] = set()
# Track outstanding async state requests to await them at commit time.
self._futures = []

def _compact_data(self, rewrite=True):
accumulator = set(
Expand All @@ -650,9 +652,11 @@ def _compact_data(self, rewrite=True):
self._added_elements))

if rewrite and accumulator:
self._state_handler.clear(self._state_key)
self._state_handler.extend(
self._state_key, self._value_coder.get_impl(), accumulator)
# Compaction writes are asynchronous; queue them so they are not lost.
self._futures.append(self._state_handler.clear(self._state_key))
self._futures.append(
self._state_handler.extend(
self._state_key, self._value_coder.get_impl(), accumulator))

# Since everthing is already committed so we can safely reinitialize
# added_elements here.
Expand All @@ -666,7 +670,7 @@ def read(self) -> set[Any]:
def add(self, value: Any) -> None:
if self._cleared:
# This is a good time explicitly clear.
self._state_handler.clear(self._state_key)
self._futures.append(self._state_handler.clear(self._state_key))
self._cleared = False

self._added_elements.add(value)
Expand All @@ -678,15 +682,26 @@ def clear(self) -> None:
self._added_elements = set()

def commit(self) -> None:
to_await = None
if self._cleared:
to_await = self._state_handler.clear(self._state_key)
self._futures.append(self._state_handler.clear(self._state_key))
self._cleared = False
if self._added_elements:
to_await = self._state_handler.extend(
self._state_key, self._value_coder.get_impl(), self._added_elements)
if to_await:
# To commit, we need to wait on the last state request future to complete.
to_await.get()
self._futures.append(
self._state_handler.extend(
self._state_key,
self._value_coder.get_impl(),
self._added_elements))
self._added_elements = set()

# Block on all outstanding async state requests to ensure data is committed.
# We must swap and clear self._futures before awaiting them. Awaiting a future
# yields control, during which new futures could be appended to self._futures.
all_futures = self._futures
self._futures = []

for f in all_futures:
if f:
f.get()
Comment thread
shunping marked this conversation as resolved.


class RangeSet:
Expand Down
69 changes: 69 additions & 0 deletions sdks/python/apache_beam/transforms/userstate_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@
"""Unit tests for the Beam State and Timer API interfaces."""
# pytype: skip-file

import queue
import threading
import time
import unittest
from typing import Any

Expand All @@ -34,6 +37,8 @@
from apache_beam.portability.api import beam_runner_api_pb2
from apache_beam.runners import pipeline_context
from apache_beam.runners.common import DoFnSignature
from apache_beam.runners.worker.sdk_worker import GrpcStateHandler
from apache_beam.runners.worker.sdk_worker import _Future
from apache_beam.testing.test_pipeline import TestPipeline
from apache_beam.testing.test_stream import TestStream
from apache_beam.testing.util import assert_that
Expand Down Expand Up @@ -719,6 +724,70 @@ def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)):
actual_values = (values | beam.ParDo(SetStatefulDoFn()))
assert_that(actual_values, equal_to([1, 3, 6, 10, 10]))

# Mock random to always return 1.0 to force compaction on every add.
@mock.patch(
'apache_beam.runners.worker.bundle_processor.random.random', lambda: 1.0)
def test_stateful_set_state_compaction_race_portably(self):
old_request = GrpcStateHandler._request
request_queue = queue.Queue()

def worker():
while True:
handler, request, future, instruction_id = request_queue.get()
time.sleep(0.1) # Simulate latency for each request sequentially.
handler._context.process_instruction_id = instruction_id
underlying_future = old_request(handler, request)
underlying_future.wait()
future.set(underlying_future.get())
request_queue.task_done()

t = threading.Thread(target=worker, daemon=True)
t.start()

def delayed_request(self, request):
if request.HasField('append') or request.HasField('clear'):
future = _Future()
instruction_id = getattr(self._context, 'process_instruction_id', None)
request_queue.put((self, request, future, instruction_id))
return future
else:
return old_request(self, request)

GrpcStateHandler._request = delayed_request

class SetStatefulDoFn(beam.DoFn):

SET_STATE = SetStateSpec('buffer', VarIntCoder())

def process(self, element, set_state=beam.DoFn.StateParam(SET_STATE)):
_, value = element
aggregated_value = 0
set_state.add(value)
for saved_value in set_state.read():
aggregated_value += saved_value
yield aggregated_value

try:
options = PipelineOptions([
'--max_cache_memory_usage_mb=100',
'--environment_type=LOOPBACK',
])
with TestPipeline(options=options) as p:
test_stream = (
TestStream(
coder=beam.coders.TupleCoder((
beam.coders.StrUtf8Coder(), beam.coders.VarIntCoder()
))).advance_watermark_to(10).add_elements([
('key', 1)
]).advance_watermark_to(20).add_elements([
('key', 2)
]).advance_watermark_to(30))
actual_values = (p | test_stream | beam.ParDo(SetStatefulDoFn()))
assert_that(actual_values, equal_to([1, 3]))

finally:
GrpcStateHandler._request = old_request

def test_stateful_set_state_clean_portably(self):
class SetStateClearingStatefulDoFn(beam.DoFn):

Expand Down
Loading