diff --git a/examples/servers/src/common/progress_demo.rs b/examples/servers/src/common/progress_demo.rs index 341b3e70a..55de24b79 100644 --- a/examples/servers/src/common/progress_demo.rs +++ b/examples/servers/src/common/progress_demo.rs @@ -11,7 +11,7 @@ use rmcp::{ }; use serde_json::json; use tokio_stream::StreamExt; -use tracing::debug; +use tracing::{debug, info}; // a Stream data source that generates data in chunks #[derive(Clone)] @@ -30,7 +30,7 @@ impl StreamDataSource { } } pub fn from_text(text: &str) -> Self { - Self::new(text.as_bytes().to_vec(), 1) + Self::new(text.as_bytes().to_vec(), 5) } } @@ -61,7 +61,7 @@ impl ProgressDemo { #[allow(dead_code)] pub fn new() -> Self { Self { - data_source: StreamDataSource::from_text("Hello, world!"), + data_source: StreamDataSource::from_text("1111122222333334444455555"), } } #[tool(description = "Process data stream with progress updates")] @@ -70,6 +70,24 @@ impl ProgressDemo { ctx: RequestContext, ) -> Result { let mut counter = 0; + info!( + "Processing stream with progress token {:?}", + ctx.meta.get_key_value("progressToken") + ); + let Some((_, progress_token)) = ctx.meta.get_key_value("progressToken") else { + return Err(McpError::internal_error( + format!("Progress token not found in request context"), + None, + )); + }; + + let Ok(progress_token) = serde_json::from_value::(progress_token.clone()) + else { + return Err(McpError::internal_error( + format!("Progress token must be a string or number"), + None, + )); + }; let mut data_source = self.data_source.clone(); loop { @@ -83,9 +101,9 @@ impl ProgressDemo { counter += 1; // create progress notification param let progress_param = ProgressNotificationParam { - progress_token: ProgressToken(NumberOrString::Number(counter)), + progress_token: ProgressToken(progress_token.clone()), progress: counter as f64, - total: None, + total: Some(5.0), message: Some(chunk_str.to_string()), }; @@ -104,6 +122,7 @@ impl ProgressDemo { )); } } + tokio::time::sleep(std::time::Duration::from_secs(1)).await; } Ok(CallToolResult::success(vec![Content::text(format!( diff --git a/examples/servers/src/progress_demo.rs b/examples/servers/src/progress_demo.rs index e9e147ea1..98e4f398e 100644 --- a/examples/servers/src/progress_demo.rs +++ b/examples/servers/src/progress_demo.rs @@ -10,6 +10,7 @@ use rmcp::{ mod common; use common::progress_demo::ProgressDemo; +use tracing_subscriber::EnvFilter; const HTTP_BIND_ADDRESS: &str = "127.0.0.1:8001"; @@ -20,6 +21,10 @@ async fn main() -> anyhow::Result<()> { .nth(1) .unwrap_or_else(|| env::var("TRANSPORT_MODE").unwrap_or_else(|_| "stdio".to_string())); + tracing_subscriber::fmt() + .with_env_filter(EnvFilter::from_default_env()) + .init(); + match transport_mode.as_str() { "stdio" => run_stdio().await, "http" | "streamhttp" => run_streamable_http().await,