Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 10 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,12 +28,21 @@ sqlparser = { version = "0.61", features = ["visitor"] }
# around this same code.
rustls-pki-types = { version = "1", features = ["std"] }
rand = "0.10"
# Used to build typed values for the binary result-wire-format path
# (`types::encode_cell`). Versions are unified with the impls pgwire's
# default features pull into `postgres-types` (chrono 0.4, rust_decimal 1.x
# with `db-postgres`), so the `ToSql`/`ToSqlText` impls apply.
chrono = "0.4"
rust_decimal = { version = "1", features = ["db-postgres"] }

[build-dependencies]
chrono = "0.4"

[dev-dependencies]
tokio-postgres = "0.7"
# with-chrono-0_4 lets the test client decode binary timestamps, exercising
# the binary result path for the customer's timestamp columns end-to-end.
tokio-postgres = { version = "0.7", features = ["with-chrono-0_4"] }
chrono = "0.4"
tempfile = "3"
rcgen = "0.14"

Expand Down
7 changes: 7 additions & 0 deletions src/query_extended.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,7 @@ impl ExtendedQueryHandler for GatewayExtendedQueryHandler {
&conn_state.trino_client,
&conn_state.config,
Some(&conn_state.active_query_id),
Some(&portal.result_column_format),
)
.await?;
let response = responses
Expand Down Expand Up @@ -190,11 +191,16 @@ impl ExtendedQueryHandler for GatewayExtendedQueryHandler {
// row-returning query. The result Stream is dropped here; Trino's
// server-side query state is freed via its own TTL (see TODO in
// `do_describe_portal` below for promoter cancellation).
//
// The result format isn't known at Describe-Statement time (no Bind
// yet), so the RowDescription is built as text; the actual per-column
// format is applied later in do_query against the bound portal.
let responses = process_query(
query,
&conn_state.trino_client,
&conn_state.config,
Some(&conn_state.active_query_id),
None,
)
.await?;
let response = responses
Expand Down Expand Up @@ -234,6 +240,7 @@ impl ExtendedQueryHandler for GatewayExtendedQueryHandler {
&conn_state.trino_client,
&conn_state.config,
Some(&conn_state.active_query_id),
Some(&portal.result_column_format),
)
.await?;
let response = responses
Expand Down
22 changes: 19 additions & 3 deletions src/query_pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// SPDX-License-Identifier: OSL-3.0
use std::sync::Arc;

use pgwire::api::portal::Format;
use pgwire::api::results::{QueryResponse, Response, Tag};
use pgwire::error::{PgWireError, PgWireResult};
use sqlparser::dialect::PostgreSqlDialect;
Expand Down Expand Up @@ -38,23 +39,37 @@ use crate::trino_stream::execute_trino_query;
/// returns; cancelling them after later statements have submitted is not
/// supported, but no documented client (Power BI, pgjdbc) exercises this
/// path.
///
/// `result_format` carries the per-column wire format the client bound for
/// results (extended-query path); `None` means all-text. `None` is passed by
/// the simple-query protocol, which never negotiates binary, and by
/// Describe-Statement, where no Bind has supplied a format yet. It is forwarded
/// to `execute_trino_query` so the result schema and DataRow encoding honor it.
pub(crate) async fn process_query(
query: &str,
trino_client: &Arc<TrinoClient>,
config: &Arc<Config>,
active_query_id: Option<&ActiveQueryId>,
result_format: Option<&Format>,
) -> PgWireResult<Vec<Response>> {
tracing::trace!(query, "Pipeline: enter");

let pieces = split_statements(query);
if pieces.len() <= 1 {
return process_single_statement(query, trino_client, config, active_query_id).await;
return process_single_statement(
query,
trino_client,
config,
active_query_id,
result_format,
)
.await;
}

tracing::trace!(count = pieces.len(), "Pipeline: multi-statement input");
let mut out = Vec::with_capacity(pieces.len());
for stmt in &pieces {
match process_single_statement(stmt, trino_client, config, active_query_id).await {
match process_single_statement(stmt, trino_client, config, active_query_id, result_format)
.await
{
Ok(mut responses) => out.append(&mut responses),
// User-visible errors (e.g. a Trino syntax error on statement N
// of a batch) are converted to a Response::Error so that the
Expand Down Expand Up @@ -91,6 +106,7 @@ async fn process_single_statement(
trino_client: &Arc<TrinoClient>,
config: &Arc<Config>,
active_query_id: Option<&ActiveQueryId>,
result_format: Option<&Format>,
) -> PgWireResult<Vec<Response>> {
// The query is parsed up to three times: once here (for routing
// checks), once by the multi-statement splitter in the public
Expand Down Expand Up @@ -136,7 +152,7 @@ async fn process_single_statement(
tracing::debug!(original = query, rewritten = %rewritten, "Rewritten query");

let (schema, row_stream) =
execute_trino_query(trino_client, rewritten, active_query_id).await?;
execute_trino_query(trino_client, rewritten, active_query_id, result_format).await?;

if schema.is_empty() {
tracing::trace!("Pipeline: Trino returned no schema — treating as DDL/DML");
Expand Down
3 changes: 3 additions & 0 deletions src/query_simple.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@ impl SimpleQueryHandler for GatewayQueryHandler {
.get::<ConnectionState>()
.ok_or_else(|| PgWireError::ApiError("Connection state not found".into()))?;

// The simple-query protocol always uses text wire format; it never
// negotiates per-column binary results.
let result = process_query(
query,
&conn_state.trino_client,
&conn_state.config,
Some(&conn_state.active_query_id),
None,
)
.await;
match &result {
Expand Down
41 changes: 31 additions & 10 deletions src/trino_stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use std::sync::Arc;

use async_stream::stream;
use futures::Stream;
use pgwire::api::portal::Format;
use pgwire::api::results::{DataRowEncoder, FieldFormat, FieldInfo};
use pgwire::error::{PgWireError, PgWireResult};
use pgwire::messages::data::DataRow;
Expand All @@ -12,7 +13,7 @@ use trino_rust_client::models::{Column, QueryResultData};
use trino_rust_client::{Client, Row};

use crate::session::ActiveQueryId;
use crate::types::{encode_value, trino_type_to_pg};
use crate::types::{encode_cell, trino_type_to_pg};

#[derive(Clone)]
pub(crate) struct TrinoColumn {
Expand All @@ -29,17 +30,29 @@ impl From<&Column> for TrinoColumn {
}
}

pub(crate) fn build_pg_schema(columns: &[TrinoColumn]) -> Arc<Vec<FieldInfo>> {
/// Build the PG `RowDescription` schema for a Trino result set.
///
/// `result_format` is the per-column format the client bound for results;
/// `None` means every column is advertised as text (the simple-query protocol,
/// which never negotiates binary, and Describe-Statement, which runs before
/// any Bind). Each column's `FieldFormat` is taken from that request, so the
/// schema drives both the advertised RowDescription and the per-cell encoding
/// in `encode_row` — keeping them in lock-step.
pub(crate) fn build_pg_schema(
columns: &[TrinoColumn],
result_format: Option<&Format>,
) -> Arc<Vec<FieldInfo>> {
Arc::new(
columns
.iter()
.map(|col| {
.enumerate()
.map(|(idx, col)| {
let format = result_format.map_or(FieldFormat::Text, |f| f.format_for(idx));
FieldInfo::new(
col.name.clone(),
None,
None,
trino_type_to_pg(&col.trino_type),
FieldFormat::Text,
format,
)
})
.collect(),
Expand All @@ -52,8 +65,15 @@ pub(crate) fn encode_row(
schema: &Arc<Vec<FieldInfo>>,
) -> PgWireResult<DataRow> {
let mut encoder = DataRowEncoder::new(schema.clone());
for (value, col) in values.iter().zip(columns.iter()) {
encoder.encode_field(&encode_value(value, &col.trino_type))?;
for (idx, (value, col)) in values.iter().zip(columns.iter()).enumerate() {
let field = &schema[idx];
encode_cell(
&mut encoder,
value,
field.datatype(),
&col.trino_type,
field.format(),
)?;
}
Ok(encoder.take_row())
}
Expand Down Expand Up @@ -126,6 +146,7 @@ pub async fn execute_trino_query(
client: &Arc<Client>,
sql: String,
active_query_id: Option<&ActiveQueryId>,
result_format: Option<&Format>,
) -> Result<
(
Arc<Vec<FieldInfo>>,
Expand Down Expand Up @@ -200,7 +221,7 @@ pub async fn execute_trino_query(

// Empty schema means DDL/DML; the caller returns Response::Execution
// instead of Response::Query.
let schema = build_pg_schema(&trino_columns);
let schema = build_pg_schema(&trino_columns, result_format);

let stream_client = Arc::clone(client);
let stream_columns = trino_columns.clone();
Expand Down Expand Up @@ -298,7 +319,7 @@ mod tests {
trino_type: "varchar".to_owned(),
},
];
let schema = build_pg_schema(&columns);
let schema = build_pg_schema(&columns, None);
assert_eq!(schema.len(), 2);
assert_eq!(schema[0].name(), "id");
assert_eq!(*schema[0].datatype(), Type::INT4);
Expand All @@ -318,7 +339,7 @@ mod tests {
trino_type: "varchar".to_owned(),
},
];
let schema = build_pg_schema(&columns);
let schema = build_pg_schema(&columns, None);
let values = vec![json!(42), json!("alice")];

let row = encode_row(&values, &columns, &schema);
Expand All @@ -331,7 +352,7 @@ mod tests {
name: "val".to_owned(),
trino_type: "varchar".to_owned(),
}];
let schema = build_pg_schema(&columns);
let schema = build_pg_schema(&columns, None);
let values = vec![Value::Null];

let row = encode_row(&values, &columns, &schema);
Expand Down
Loading
Loading