":slice",
":slice_buffer",
"//:chttp2_frame",
+ "//:gpr_platform",
"//:ref_counted_ptr",
],
)
///////////////////////////////////////////////////////////////////////////////
// GRPC Header
+namespace {
+ValueOrHttp2Status<uint32_t> ParseGrpcMessageFlags(const uint8_t flags) {
+ switch (flags) {
+ case kGrpcMessageHeaderNoFlags:
+ return 0u;
+ case kGrpcMessageHeaderWriteInternalCompress:
+ return GRPC_WRITE_INTERNAL_COMPRESS;
+ default:
+ LOG(ERROR) << "Invalid gRPC header flags: "
+ << static_cast<uint32_t>(flags);
+ return Http2Status::Http2StreamError(
+ Http2ErrorCode::kInternalError,
+ absl::StrCat("Invalid gRPC header flags: ", flags));
+ }
+}
-GrpcMessageHeader ExtractGrpcHeader(SliceBuffer& payload) {
+uint8_t SerializeGrpcMessageFlags(const uint32_t flags) {
+ return (flags & GRPC_WRITE_INTERNAL_COMPRESS)
+ ? kGrpcMessageHeaderWriteInternalCompress
+ : kGrpcMessageHeaderNoFlags;
+}
+} // namespace
+
+ValueOrHttp2Status<GrpcMessageHeader> ExtractGrpcHeader(SliceBuffer& payload) {
GRPC_CHECK_GE(payload.Length(), kGrpcHeaderSizeInBytes);
uint8_t buffer[kGrpcHeaderSizeInBytes];
payload.CopyFirstNBytesIntoBuffer(kGrpcHeaderSizeInBytes, buffer);
GrpcMessageHeader header;
- header.flags = buffer[0];
+ ValueOrHttp2Status<uint32_t> message_flags = ParseGrpcMessageFlags(buffer[0]);
+ if (!message_flags.IsOk()) {
+ return message_flags.TakeStatus(std::move(message_flags));
+ }
+
+ header.flags = message_flags.value();
header.length = Read4b(buffer + 1);
return header;
}
-void AppendGrpcHeaderToSliceBuffer(SliceBuffer& payload, const uint8_t flags,
+void AppendGrpcHeaderToSliceBuffer(SliceBuffer& payload, const uint32_t flags,
const uint32_t length) {
uint8_t* frame_hdr = payload.AddTiny(kGrpcHeaderSizeInBytes);
- frame_hdr[0] = flags;
+ frame_hdr[0] = SerializeGrpcMessageFlags(flags);
Write4b(length, frame_hdr + 1);
}
// GRPC Header
constexpr uint8_t kGrpcHeaderSizeInBytes = 5;
+constexpr uint8_t kGrpcMessageHeaderNoFlags = 0;
+constexpr uint8_t kGrpcMessageHeaderWriteInternalCompress = 1;
struct GrpcMessageHeader {
- uint8_t flags = 0;
+ uint32_t flags = 0;
uint32_t length = 0;
};
// If the payload SliceBuffer is too small to hold a gRPC header, this function
// will crash. The calling function MUST ensure that the payload SliceBuffer
// has length greater than or equal to the gRPC header.
-GrpcMessageHeader ExtractGrpcHeader(SliceBuffer& payload);
+http2::ValueOrHttp2Status<GrpcMessageHeader> ExtractGrpcHeader(
+ SliceBuffer& payload);
-void AppendGrpcHeaderToSliceBuffer(SliceBuffer& payload, const uint8_t flags,
- const uint32_t length);
+void AppendGrpcHeaderToSliceBuffer(SliceBuffer& payload, uint32_t flags,
+ uint32_t length);
///////////////////////////////////////////////////////////////////////////////
// Validations
#ifndef GRPC_SRC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_MESSAGE_ASSEMBLER_H
#define GRPC_SRC_CORE_EXT_TRANSPORT_CHTTP2_TRANSPORT_MESSAGE_ASSEMBLER_H
+#include <grpc/support/port_platform.h>
+
#include <cstdint>
#include <utility>
#include "src/core/lib/slice/slice_buffer.h"
#include "src/core/util/grpc_check.h"
#include "src/core/util/ref_counted_ptr.h"
+#include "absl/log/log.h"
namespace grpc_core {
namespace http2 {
+#define GRPC_MESSAGE_ASSEMBLER_DEBUG VLOG(2)
+
// TODO(tjagtap) TODO(akshitpatel): [PH2][P3] : Write micro benchmarks for
// assembler and disassembler code
// TODO(tjagtap) : [PH2][P3] : Write a test for this.
return ReturnNullOrError();
}
- GrpcMessageHeader header = ExtractGrpcHeader(message_buffer_);
+ ValueOrHttp2Status<GrpcMessageHeader> header =
+ ExtractGrpcHeader(message_buffer_);
+ if (!header.IsOk()) {
+ return header.TakeStatus(std::move(header));
+ }
+ const uint32_t header_length = header.value().length;
+
if constexpr (sizeof(size_t) == 4) {
- if (GPR_UNLIKELY(header.length > kOneGb)) {
+ if (GPR_UNLIKELY(header_length > kOneGb)) {
return Http2Status::Http2StreamError(
Http2ErrorCode::kInternalError,
"Stream Error: SliceBuffer overflow for 32 bit platforms.");
}
}
- if (GPR_LIKELY(current_len - kGrpcHeaderSizeInBytes >= header.length)) {
+ if (GPR_LIKELY(current_len - kGrpcHeaderSizeInBytes >= header_length)) {
SliceBuffer discard;
message_buffer_.MoveFirstNBytesIntoSliceBuffer(kGrpcHeaderSizeInBytes,
discard);
// bounds.
MessageHandle grpc_message = Arena::MakePooled<Message>();
message_buffer_.MoveFirstNBytesIntoSliceBuffer(
- header.length, *(grpc_message->payload()));
- uint32_t& flag = grpc_message->mutable_flags();
- flag = header.flags;
+ header_length, *(grpc_message->payload()));
+ grpc_message->mutable_flags() = header.value().flags;
return std::move(grpc_message);
}
return ReturnNullOrError();
return Http2Status::Http2StreamError(Http2ErrorCode::kInternalError,
"Incomplete gRPC frame received");
}
- VLOG(2) << "Incomplete gRPC message received. Return nullptr";
+ GRPC_MESSAGE_ASSEMBLER_DEBUG
+ << "Incomplete gRPC message received. Return nullptr";
return ValueOrHttp2Status<MessageHandle>(nullptr);
}
+
bool is_end_stream_ = false;
SliceBuffer message_buffer_;
};
auto EnqueueMessage(MessageHandle&& message) {
GRPC_HTTP2_STREAM_LOG
<< "Http2ClientTransport::Stream::EnqueueMessage stream_id="
- << stream_id << " with payload size = " << message->payload()->Length();
+ << stream_id << " with payload size = " << message->payload()->Length()
+ << " and flags = " << message->flags();
return data_queue->EnqueueMessage(std::move(message));
}
}
TEST(Frame, GrpcHeaderTest) {
- constexpr uint8_t kFlags = 15;
constexpr uint32_t kLength = 1111111;
- SliceBuffer payload;
- EXPECT_EQ(payload.Length(), 0);
-
- AppendGrpcHeaderToSliceBuffer(payload, kFlags, kLength);
- EXPECT_EQ(payload.Length(), kGrpcHeaderSizeInBytes);
-
- GrpcMessageHeader header = ExtractGrpcHeader(payload);
- EXPECT_EQ(payload.Length(), kGrpcHeaderSizeInBytes);
- EXPECT_EQ(header.flags, kFlags);
- EXPECT_EQ(header.length, kLength);
+ auto verify_header = [](const uint32_t flags, const uint32_t expected_flags,
+ const uint32_t length) {
+ SliceBuffer payload;
+ EXPECT_EQ(payload.Length(), 0);
+ AppendGrpcHeaderToSliceBuffer(payload, flags, length);
+ EXPECT_EQ(payload.Length(), kGrpcHeaderSizeInBytes);
+
+ ValueOrHttp2Status<GrpcMessageHeader> header = ExtractGrpcHeader(payload);
+ EXPECT_TRUE(header.IsOk());
+ EXPECT_EQ(payload.Length(), kGrpcHeaderSizeInBytes);
+ EXPECT_EQ(header.value().flags, expected_flags);
+ EXPECT_EQ(header.value().length, length);
+ };
+
+ verify_header(/*flags=*/0, /*expected_flags=*/0, kLength);
+ verify_header(/*flags=*/GRPC_WRITE_INTERNAL_COMPRESS,
+ /*expected_flags=*/GRPC_WRITE_INTERNAL_COMPRESS, kLength);
+ verify_header(/*flags=*/10, /*expected_flags=*/0, kLength);
}
TEST(Frame, ValidateSettingsValuesInvalidInitialWindowSize) {
AppendGrpcHeaderToSliceBuffer(payload, kFlags0, 0);
}
-void AppendHeaderAndMessage(SliceBuffer& payload, absl::string_view str) {
- AppendGrpcHeaderToSliceBuffer(payload, kFlags0, str.size());
+void AppendHeaderAndMessage(SliceBuffer& payload, absl::string_view str,
+ const uint32_t flags = kFlags0) {
+ AppendGrpcHeaderToSliceBuffer(payload, flags, str.size());
payload.Append(Slice::FromCopiedString(str));
}
-void AppendHeaderAndPartialMessage(SliceBuffer& payload, const uint8_t flag,
+void AppendHeaderAndPartialMessage(SliceBuffer& payload, const uint32_t flag,
const uint32_t length,
absl::string_view str) {
AppendGrpcHeaderToSliceBuffer(payload, flag, length);
}
void ExpectMessagePayload(ValueOrHttp2Status<MessageHandle>&& result,
- const size_t expect_length, const uint8_t flags) {
+ const size_t expect_length, const uint32_t flags) {
EXPECT_TRUE(result.IsOk());
MessageHandle message = TakeValue(std::move(result));
EXPECT_EQ(message->payload()->Length(), expect_length);
}
void ExpectMessagePayload(ValueOrHttp2Status<MessageHandle>&& result,
- const size_t expect_length, const uint8_t flags,
+ const size_t expect_length, const uint32_t flags,
absl::string_view str) {
EXPECT_TRUE(result.IsOk());
MessageHandle message = TakeValue(std::move(result));
EXPECT_TRUE(append3.IsOk());
ValueOrHttp2Status<MessageHandle> result3 = assembler.ExtractMessage();
- ExpectMessagePayload(std::move(result3), 2 * kStr1024.size(), kFlags5);
+ ExpectMessagePayload(std::move(result3), 2 * kStr1024.size(), kFlags0);
ValueOrHttp2Status<MessageHandle> result4 = assembler.ExtractMessage();
ExpectMessagePayload(std::move(result4), kStr1024.size(), kFlags0);
EXPECT_FALSE(result1.IsOk());
}
+TEST(GrpcMessageAssemblerTest, ValidateGrpcMessageHeaderFlags) {
+ GrpcMessageAssembler assembler;
+ constexpr absl::string_view kPayload = "Hello!";
+
+ // Default flags
+ {
+ SliceBuffer frame;
+ AppendHeaderAndMessage(frame, /*str=*/kPayload, /*flags=*/kFlags0);
+ Http2Status result = assembler.AppendNewDataFrame(frame, kNotEndStream);
+ EXPECT_TRUE(result.IsOk());
+ ExpectMessagePayload(assembler.ExtractMessage(), kPayload.size(),
+ /*flags=*/kFlags0, kPayload);
+ }
+
+ // GRPC_WRITE_INTERNAL_COMPRESS flag
+ {
+ SliceBuffer frame;
+ // The flag will be converted to kWriteInternalCompress when sending on the
+ // wire and re-converted to GRPC_WRITE_INTERNAL_COMPRESS when received.
+ AppendHeaderAndMessage(frame, /*str=*/kPayload,
+ /*flags=*/
+ GRPC_WRITE_INTERNAL_COMPRESS);
+ Http2Status result = assembler.AppendNewDataFrame(frame, kNotEndStream);
+ EXPECT_TRUE(result.IsOk());
+ ExpectMessagePayload(assembler.ExtractMessage(), kPayload.size(),
+ /*flags=*/GRPC_WRITE_INTERNAL_COMPRESS, kPayload);
+ }
+
+ // Invalid flags
+ {
+ SliceBuffer frame;
+ AppendHeaderAndMessage(frame, /*str=*/kPayload, /*flags=*/kFlags5);
+ Http2Status result = assembler.AppendNewDataFrame(frame, kNotEndStream);
+ EXPECT_TRUE(result.IsOk());
+ ExpectMessagePayload(assembler.ExtractMessage(), kPayload.size(),
+ /*flags=*/kFlags0, kPayload);
+ }
+
+ {
+ SliceBuffer frame;
+ uint8_t* header = frame.AddTiny(kGrpcHeaderSizeInBytes);
+ header[0] = kFlags5;
+ const uint32_t length = kPayload.size();
+ header[1] = static_cast<uint8_t>(length >> 24);
+ header[2] = static_cast<uint8_t>(length >> 16);
+ header[3] = static_cast<uint8_t>(length >> 8);
+ header[4] = static_cast<uint8_t>(length);
+ frame.Append(Slice::FromCopiedString(kPayload));
+ Http2Status result = assembler.AppendNewDataFrame(frame, kNotEndStream);
+ EXPECT_TRUE(result.IsOk());
+
+ ValueOrHttp2Status<MessageHandle> message = assembler.ExtractMessage();
+ EXPECT_FALSE(message.IsOk());
+ }
+}
+
///////////////////////////////////////////////////////////////////////////////
// GrpcMessageDisassembler Tests