diff --git a/java-storage/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicUnbufferedReadableByteChannel.java b/java-storage/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicUnbufferedReadableByteChannel.java index cef751213c6d..f17bd942fd31 100644 --- a/java-storage/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicUnbufferedReadableByteChannel.java +++ b/java-storage/google-cloud-storage/src/main/java/com/google/cloud/storage/GapicUnbufferedReadableByteChannel.java @@ -49,6 +49,7 @@ import java.nio.channels.ScatteringByteChannel; import java.util.List; import java.util.Locale; +import java.util.OptionalLong; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.CancellationException; import java.util.concurrent.ExecutionException; @@ -91,7 +92,15 @@ final class GapicUnbufferedReadableByteChannel this.result = result; this.read = read; this.req = req; - this.hasher = hasher; + this.hasher = + (req.getReadOffset() == 0 && !(hasher instanceof Hasher.NoOpHasher)) + ? new CumulativeHasher( + hasher, + 0, + req.getReadLimit() <= 0 + ? OptionalLong.empty() + : OptionalLong.of(req.getReadLimit())) + : hasher; this.fetchOffset = new AtomicLong(req.getReadOffset()); this.blobOffset = req.getReadOffset(); this.retrier = retrier; @@ -154,7 +163,7 @@ public long read(ByteBuffer[] dsts, int offset, int length) throws IOException { if (take instanceof IOException) { IOException ioe = (IOException) take; if (alg.shouldRetry(ioe, null)) { - readObjectObserver = null; + cancelAndDrainCurrentObserver(); continue; } else { ioe.addSuppressed(new AsyncStorageTaskException()); @@ -165,7 +174,7 @@ public long read(ByteBuffer[] dsts, int offset, int length) throws IOException { Throwable throwable = (Throwable) take; BaseServiceException coalesce = StorageException.coalesce(throwable); if (alg.shouldRetry(coalesce, null)) { - readObjectObserver = null; + cancelAndDrainCurrentObserver(); continue; } else { close(); @@ -174,6 +183,7 @@ public long read(ByteBuffer[] dsts, int offset, int length) throws IOException { } if (take == EOF_MARKER) { complete = true; + validateCumulativeChecksum(); break; } @@ -240,7 +250,9 @@ private void drainQueue() throws IOException { while (queue.nonEmpty()) { try { java.lang.Object queueValue = queue.poll(); - if (queueValue instanceof ReadObjectResponse) { + if (queueValue instanceof java.io.Closeable) { + ((java.io.Closeable) queueValue).close(); + } else if (queueValue instanceof ReadObjectResponse) { ReadObjectResponse resp = (ReadObjectResponse) queueValue; ResponseContentLifecycleHandle handle = read.getResponseContentLifecycleManager().get(resp); @@ -273,6 +285,19 @@ private void drainQueue() throws IOException { } } + private void cancelAndDrainCurrentObserver() { + if (readObjectObserver != null) { + readObjectObserver.cancel(); + try { + drainQueue(); + } catch (IOException e) { + // drainQueue() in this context can be ignored because we are resetting the + // stream. + } + readObjectObserver = null; + } + } + ApiFuture getResult() { return result; } @@ -311,14 +336,27 @@ private IOException createError(String message) throws IOException { return new IOException(message, cause); } + private void validateCumulativeChecksum() throws IOException { + if (hasher instanceof CumulativeHasher) { + CumulativeHasher cumulativeHasher = (CumulativeHasher) hasher; + try { + cumulativeHasher.validateCumulativeChecksum(metadata); + } catch (UncheckedCumulativeChecksumMismatchException exception) { + throw new IOException(StorageException.coalesce(exception)); + } + } + } + private final class ReadObjectObserver extends StateCheckingResponseObserver { private final SettableApiFuture open = SettableApiFuture.create(); private final SettableApiFuture cancellation = SettableApiFuture.create(); private volatile StreamController controller; + private volatile boolean cancelled = false; void cancel() { + cancelled = true; controller.cancel(); } @@ -331,10 +369,13 @@ protected void onStartImpl(StreamController controller) { @Override protected void onResponseImpl(ReadObjectResponse response) { - controller.request(1); - open.set(null); try (ResponseContentLifecycleHandle handle = read.getResponseContentLifecycleManager().get(response)) { + if (cancelled) { + return; + } + controller.request(1); + open.set(null); ChecksummedData checksummedData = response.getChecksummedData(); ByteString content = checksummedData.getContent(); int contentSize = content.size(); @@ -348,6 +389,8 @@ protected void onResponseImpl(ReadObjectResponse response) { queue.offer(e); return; } + } else if (hasher instanceof CumulativeHasher) { + hasher.validateUnchecked(null, content); } if (response.hasMetadata()) { Object respMetadata = response.getMetadata(); @@ -380,6 +423,12 @@ protected void onResponseImpl(ReadObjectResponse response) { @Override protected void onErrorImpl(Throwable t) { + if (t instanceof CancellationException) { + cancellation.set(t); + } + if (cancelled) { + return; + } if (t instanceof OutOfRangeException) { try { queue.offer(EOF_MARKER); @@ -389,17 +438,15 @@ protected void onErrorImpl(Throwable t) { throw Code.ABORTED.toStatus().withCause(e).asRuntimeException(); } } - if (t instanceof CancellationException) { - cancellation.set(t); - } if (!open.isDone()) { open.setException(t); - } - try { - queue.offer(t); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw Code.ABORTED.toStatus().withCause(e).asRuntimeException(); + } else { + try { + queue.offer(t); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw Code.ABORTED.toStatus().withCause(e).asRuntimeException(); + } } } diff --git a/java-storage/google-cloud-storage/src/test/java/com/google/cloud/storage/GapicUnbufferedReadableByteChannelTest.java b/java-storage/google-cloud-storage/src/test/java/com/google/cloud/storage/GapicUnbufferedReadableByteChannelTest.java index 27d96ef6f06f..c4b19e717c88 100644 --- a/java-storage/google-cloud-storage/src/test/java/com/google/cloud/storage/GapicUnbufferedReadableByteChannelTest.java +++ b/java-storage/google-cloud-storage/src/test/java/com/google/cloud/storage/GapicUnbufferedReadableByteChannelTest.java @@ -18,6 +18,7 @@ import static com.google.cloud.storage.TestUtils.xxd; import static com.google.common.truth.Truth.assertThat; +import static org.junit.Assert.assertThrows; import com.google.api.core.SettableApiFuture; import com.google.api.gax.rpc.ApiCallContext; @@ -26,6 +27,8 @@ import com.google.cloud.storage.GrpcUtils.ZeroCopyServerStreamingCallable; import com.google.cloud.storage.Retrying.Retrier; import com.google.cloud.storage.it.ChecksummedTestContent; +import com.google.storage.v2.Object; +import com.google.storage.v2.ObjectChecksums; import com.google.storage.v2.ReadObjectRequest; import com.google.storage.v2.ReadObjectResponse; import java.io.IOException; @@ -75,4 +78,464 @@ public void call( assertThat(close.get()).isTrue(); } } + + @Test + public void validateCumulativeChecksum_success() throws IOException { + ChecksummedTestContent testContent = + ChecksummedTestContent.of(DataGenerator.base64Characters().genBytes(10)); + + Object metadata = + Object.newBuilder() + .setSize(testContent.length()) + .setChecksums(ObjectChecksums.newBuilder().setCrc32C(testContent.getCrc32c()).build()) + .build(); + + ResponseContentLifecycleManager manager = + ResponseContentLifecycleManager.noop(); + + try (GapicUnbufferedReadableByteChannel c = + new GapicUnbufferedReadableByteChannel( + SettableApiFuture.create(), + new ZeroCopyServerStreamingCallable<>( + new ServerStreamingCallable() { + @Override + public void call( + ReadObjectRequest request, + ResponseObserver respond, + ApiCallContext context) { + respond.onStart(TestUtils.nullStreamController()); + respond.onResponse( + ReadObjectResponse.newBuilder() + .setChecksummedData(testContent.asChecksummedData()) + .setMetadata(metadata) + .build()); + respond.onComplete(); + } + }, + manager), + ReadObjectRequest.getDefaultInstance(), + Hasher.defaultHasher(), + Retrier.attemptOnce(), + Retrying.neverRetry())) { + + ByteBuffer buffer = ByteBuffer.allocate(15); + int read = (int) c.read(new ByteBuffer[] {buffer}, 0, 1); + assertThat(read).isEqualTo(testContent.length()); + assertThat(xxd(buffer)).isEqualTo(xxd(testContent.getBytes())); + } + } + + @Test + public void validateCumulativeChecksum_failure() throws IOException { + ChecksummedTestContent testContent = + ChecksummedTestContent.of(DataGenerator.base64Characters().genBytes(10)); + + Object metadata = + Object.newBuilder() + .setSize(testContent.length()) + .setChecksums( + ObjectChecksums.newBuilder().setCrc32C(testContent.getCrc32c() + 1).build()) + .build(); + + ResponseContentLifecycleManager manager = + ResponseContentLifecycleManager.noop(); + + try (GapicUnbufferedReadableByteChannel c = + new GapicUnbufferedReadableByteChannel( + SettableApiFuture.create(), + new ZeroCopyServerStreamingCallable<>( + new ServerStreamingCallable() { + @Override + public void call( + ReadObjectRequest request, + ResponseObserver respond, + ApiCallContext context) { + respond.onStart(TestUtils.nullStreamController()); + respond.onResponse( + ReadObjectResponse.newBuilder() + .setChecksummedData(testContent.asChecksummedData()) + .setMetadata(metadata) + .build()); + respond.onComplete(); + } + }, + manager), + ReadObjectRequest.getDefaultInstance(), + Hasher.defaultHasher(), + Retrier.attemptOnce(), + Retrying.neverRetry())) { + + ByteBuffer buffer = ByteBuffer.allocate(15); + IOException exception = assertThrows(IOException.class, () -> c.read(buffer)); + assertThat(exception.getCause()).isInstanceOf(StorageException.class); + assertThat(exception.getCause().getCause()) + .isInstanceOf(UncheckedCumulativeChecksumMismatchException.class); + } + } + + @Test + public void validateCumulativeChecksum_disabled_noFailureOnMismatch() throws IOException { + ChecksummedTestContent testContent = + ChecksummedTestContent.of(DataGenerator.base64Characters().genBytes(10)); + + Object metadata = + Object.newBuilder() + .setSize(testContent.length()) + .setChecksums( + ObjectChecksums.newBuilder().setCrc32C(testContent.getCrc32c() + 1).build()) + .build(); + + ResponseContentLifecycleManager manager = + ResponseContentLifecycleManager.noop(); + + try (GapicUnbufferedReadableByteChannel c = + new GapicUnbufferedReadableByteChannel( + SettableApiFuture.create(), + new ZeroCopyServerStreamingCallable<>( + new ServerStreamingCallable() { + @Override + public void call( + ReadObjectRequest request, + ResponseObserver respond, + ApiCallContext context) { + respond.onStart(TestUtils.nullStreamController()); + respond.onResponse( + ReadObjectResponse.newBuilder() + .setChecksummedData(testContent.asChecksummedData()) + .setMetadata(metadata) + .build()); + respond.onComplete(); + } + }, + manager), + ReadObjectRequest.getDefaultInstance(), + Hasher.noop(), + Retrier.attemptOnce(), + Retrying.neverRetry())) { + + ByteBuffer buffer = ByteBuffer.allocate(15); + int read = (int) c.read(new ByteBuffer[] {buffer}, 0, 1); + assertThat(read).isEqualTo(testContent.length()); + assertThat(xxd(buffer)).isEqualTo(xxd(testContent.getBytes())); + } + } + + @Test + public void validateCumulativeChecksum_skippedForRangedRead() throws IOException { + ChecksummedTestContent testContent = + ChecksummedTestContent.of(DataGenerator.base64Characters().genBytes(10)); + + Object metadata = + Object.newBuilder() + .setSize(testContent.length()) + .setChecksums( + ObjectChecksums.newBuilder().setCrc32C(testContent.getCrc32c() + 1).build()) + .build(); + + ResponseContentLifecycleManager manager = + ResponseContentLifecycleManager.noop(); + + ReadObjectRequest req = ReadObjectRequest.newBuilder().setReadLimit(5).build(); + + try (GapicUnbufferedReadableByteChannel c = + new GapicUnbufferedReadableByteChannel( + SettableApiFuture.create(), + new ZeroCopyServerStreamingCallable<>( + new ServerStreamingCallable() { + @Override + public void call( + ReadObjectRequest request, + ResponseObserver respond, + ApiCallContext context) { + respond.onStart(TestUtils.nullStreamController()); + respond.onResponse( + ReadObjectResponse.newBuilder() + .setChecksummedData(testContent.slice(0, 5).asChecksummedData()) + .setMetadata(metadata) + .build()); + respond.onComplete(); + } + }, + manager), + req, + Hasher.defaultHasher(), + Retrier.attemptOnce(), + Retrying.neverRetry())) { + + ByteBuffer buffer = ByteBuffer.allocate(15); + int read = (int) c.read(buffer); + assertThat(read).isEqualTo(5); + assertThat(xxd(buffer)).isEqualTo(xxd(testContent.slice(0, 5).getBytes())); + } + } + + @Test + public void validateCumulativeChecksum_multipleChunks_success() throws IOException { + ChecksummedTestContent chunk1 = ChecksummedTestContent.of("abcde".getBytes()); + ChecksummedTestContent chunk2 = ChecksummedTestContent.of("fghij".getBytes()); + ChecksummedTestContent chunk3 = ChecksummedTestContent.of("klmno".getBytes()); + byte[] fullBytes = "abcdefghijklmno".getBytes(); + ChecksummedTestContent fullContent = ChecksummedTestContent.of(fullBytes); + + Object metadata = + Object.newBuilder() + .setSize(fullContent.length()) + .setChecksums(ObjectChecksums.newBuilder().setCrc32C(fullContent.getCrc32c()).build()) + .build(); + + ResponseContentLifecycleManager manager = + ResponseContentLifecycleManager.noop(); + + try (GapicUnbufferedReadableByteChannel c = + new GapicUnbufferedReadableByteChannel( + SettableApiFuture.create(), + new ZeroCopyServerStreamingCallable<>( + new ServerStreamingCallable() { + @Override + public void call( + ReadObjectRequest request, + ResponseObserver respond, + ApiCallContext context) { + respond.onStart(TestUtils.nullStreamController()); + new Thread( + () -> { + try { + respond.onResponse( + ReadObjectResponse.newBuilder() + .setChecksummedData(chunk1.asChecksummedData()) + .build()); + respond.onResponse( + ReadObjectResponse.newBuilder() + .setChecksummedData(chunk2.asChecksummedData()) + .build()); + respond.onResponse( + ReadObjectResponse.newBuilder() + .setChecksummedData(chunk3.asChecksummedData()) + .setMetadata(metadata) + .build()); + respond.onComplete(); + } catch (Throwable t) { + respond.onError(t); + } + }) + .start(); + } + }, + manager), + ReadObjectRequest.getDefaultInstance(), + Hasher.defaultHasher(), + Retrier.attemptOnce(), + Retrying.neverRetry())) { + + ByteBuffer buffer = ByteBuffer.allocate(20); + int read = (int) c.read(new ByteBuffer[] {buffer}, 0, 1); + assertThat(read).isEqualTo(15); + assertThat(xxd(buffer)).isEqualTo(xxd(fullContent.getBytes())); + } + } + + @Test + public void validateCumulativeChecksum_multipleChunks_failure() throws IOException { + ChecksummedTestContent chunk1 = ChecksummedTestContent.of("abcde".getBytes()); + ChecksummedTestContent chunk2 = ChecksummedTestContent.of("fghij".getBytes()); + ChecksummedTestContent chunk3 = ChecksummedTestContent.of("klmno".getBytes()); + byte[] fullBytes = "abcdefghijklmno".getBytes(); + ChecksummedTestContent fullContent = ChecksummedTestContent.of(fullBytes); + + Object metadata = + Object.newBuilder() + .setSize(fullContent.length()) + .setChecksums( + ObjectChecksums.newBuilder().setCrc32C(fullContent.getCrc32c() + 1).build()) + .build(); + + ResponseContentLifecycleManager manager = + ResponseContentLifecycleManager.noop(); + + try (GapicUnbufferedReadableByteChannel c = + new GapicUnbufferedReadableByteChannel( + SettableApiFuture.create(), + new ZeroCopyServerStreamingCallable<>( + new ServerStreamingCallable() { + @Override + public void call( + ReadObjectRequest request, + ResponseObserver respond, + ApiCallContext context) { + respond.onStart(TestUtils.nullStreamController()); + new Thread( + () -> { + try { + respond.onResponse( + ReadObjectResponse.newBuilder() + .setChecksummedData(chunk1.asChecksummedData()) + .build()); + respond.onResponse( + ReadObjectResponse.newBuilder() + .setChecksummedData(chunk2.asChecksummedData()) + .build()); + respond.onResponse( + ReadObjectResponse.newBuilder() + .setChecksummedData(chunk3.asChecksummedData()) + .setMetadata(metadata) + .build()); + respond.onComplete(); + } catch (Throwable t) { + respond.onError(t); + } + }) + .start(); + } + }, + manager), + ReadObjectRequest.getDefaultInstance(), + Hasher.defaultHasher(), + Retrier.attemptOnce(), + Retrying.neverRetry())) { + + ByteBuffer buffer = ByteBuffer.allocate(20); + IOException exception = + assertThrows( + IOException.class, + () -> { + c.read(new ByteBuffer[] {buffer}, 0, 1); + }); + assertThat(exception.getCause()).isInstanceOf(StorageException.class); + assertThat(exception.getCause().getCause()) + .isInstanceOf(UncheckedCumulativeChecksumMismatchException.class); + } + } + + @Test + public void validateCumulativeChecksum_metadataMissingCrc32c_skipped() throws IOException { + ChecksummedTestContent testContent = + ChecksummedTestContent.of(DataGenerator.base64Characters().genBytes(10)); + + Object metadata = + Object.newBuilder() + .setSize(testContent.length()) + .setChecksums(ObjectChecksums.newBuilder().build()) + .build(); + + ResponseContentLifecycleManager manager = + ResponseContentLifecycleManager.noop(); + + try (GapicUnbufferedReadableByteChannel c = + new GapicUnbufferedReadableByteChannel( + SettableApiFuture.create(), + new ZeroCopyServerStreamingCallable<>( + new ServerStreamingCallable() { + @Override + public void call( + ReadObjectRequest request, + ResponseObserver respond, + ApiCallContext context) { + respond.onStart(TestUtils.nullStreamController()); + respond.onResponse( + ReadObjectResponse.newBuilder() + .setChecksummedData(testContent.asChecksummedData()) + .setMetadata(metadata) + .build()); + respond.onComplete(); + } + }, + manager), + ReadObjectRequest.getDefaultInstance(), + Hasher.defaultHasher(), + Retrier.attemptOnce(), + Retrying.neverRetry())) { + + ByteBuffer buffer = ByteBuffer.allocate(15); + int read = (int) c.read(buffer); + assertThat(read).isEqualTo(10); + } + } + + @Test + public void validateCumulativeChecksum_nonZeroOffset_skipped() throws IOException { + ChecksummedTestContent testContent = + ChecksummedTestContent.of(DataGenerator.base64Characters().genBytes(10)); + + Object metadata = + Object.newBuilder() + .setSize(testContent.length()) + .setChecksums( + ObjectChecksums.newBuilder().setCrc32C(testContent.getCrc32c() + 1).build()) + .build(); + + ResponseContentLifecycleManager manager = + ResponseContentLifecycleManager.noop(); + + ReadObjectRequest req = ReadObjectRequest.newBuilder().setReadOffset(5).build(); + + try (GapicUnbufferedReadableByteChannel c = + new GapicUnbufferedReadableByteChannel( + SettableApiFuture.create(), + new ZeroCopyServerStreamingCallable<>( + new ServerStreamingCallable() { + @Override + public void call( + ReadObjectRequest request, + ResponseObserver respond, + ApiCallContext context) { + respond.onStart(TestUtils.nullStreamController()); + respond.onResponse( + ReadObjectResponse.newBuilder() + .setChecksummedData(testContent.slice(5, 5).asChecksummedData()) + .setMetadata(metadata) + .build()); + respond.onComplete(); + } + }, + manager), + req, + Hasher.defaultHasher(), + Retrier.attemptOnce(), + Retrying.neverRetry())) { + + ByteBuffer buffer = ByteBuffer.allocate(15); + int read = (int) c.read(buffer); + assertThat(read).isEqualTo(5); + } + } + + @Test + public void validateCumulativeChecksum_zeroByteObject_success() throws IOException { + Object metadata = + Object.newBuilder() + .setSize(0) + .setChecksums(ObjectChecksums.newBuilder().setCrc32C(0).build()) + .build(); + + ResponseContentLifecycleManager manager = + ResponseContentLifecycleManager.noop(); + + try (GapicUnbufferedReadableByteChannel c = + new GapicUnbufferedReadableByteChannel( + SettableApiFuture.create(), + new ZeroCopyServerStreamingCallable<>( + new ServerStreamingCallable() { + @Override + public void call( + ReadObjectRequest request, + ResponseObserver respond, + ApiCallContext context) { + respond.onStart(TestUtils.nullStreamController()); + respond.onResponse( + ReadObjectResponse.newBuilder().setMetadata(metadata).build()); + respond.onComplete(); + } + }, + manager), + ReadObjectRequest.getDefaultInstance(), + Hasher.defaultHasher(), + Retrier.attemptOnce(), + Retrying.neverRetry())) { + + ByteBuffer buffer = ByteBuffer.allocate(15); + int read = (int) c.read(buffer); + assertThat(read).isEqualTo(0); + assertThat(c.read(buffer)).isEqualTo(-1); + } + } }