diff --git a/CHANGELOG.md b/CHANGELOG.md index f0e411d..caa8476 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,3 +5,4 @@ 0.1.0 - Transport WebSocket générique 0.1.1 - Intégration Tauri minimale du WsClient 0.2.0 - Couche JSON-RPC WS Solana +0.3.0 - Registre subscriptions / notifications diff --git a/Cargo.toml b/Cargo.toml index 9449e77..b99e676 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ members = [ ] [workspace.package] -version = "0.2.0" +version = "0.3.0" edition = "2024" license = "MIT" repository = "https://git.sasedev.com/Sasedev/khadhroony-bobobot" diff --git a/kb_app/package.json b/kb_app/package.json index 7107937..b01bb77 100644 --- a/kb_app/package.json +++ b/kb_app/package.json @@ -1,7 +1,7 @@ { "name": "kb-app", "private": true, - "version": "0.2.0", + "version": "0.3.0", "type": "module", "scripts": { "dev": "vite", diff --git a/kb_app/src/lib.rs b/kb_app/src/lib.rs index 3f83883..7f47038 100644 --- a/kb_app/src/lib.rs +++ b/kb_app/src/lib.rs @@ -330,9 +330,7 @@ fn kb_emit_app_log(app_handle: &tauri::AppHandle, message: &str) { } } -fn kb_format_ws_event( - event: &kb_lib::WsEvent, -) -> std::string::String { +fn kb_format_ws_event(event: &kb_lib::WsEvent) -> std::string::String { match event { kb_lib::WsEvent::Connected { endpoint_name, @@ -349,31 +347,26 @@ fn kb_format_ws_event( kb_lib::WsEvent::JsonRpcMessage { endpoint_name, message, - } => { - match message { - kb_lib::KbJsonRpcWsIncomingMessage::SuccessResponse(response) => { - format!( - "[ws:{endpoint_name}] json-rpc success id={} result={}", - response.id, - response.result - ) - }, - kb_lib::KbJsonRpcWsIncomingMessage::ErrorResponse(response) => { - format!( - "[ws:{endpoint_name}] json-rpc error id={} code={} message={}", - response.id, - response.error.code, - response.error.message - ) - }, - kb_lib::KbJsonRpcWsIncomingMessage::Notification(notification) => { - format!( - "[ws:{endpoint_name}] json-rpc notification method={} subscription={} result={}", - notification.method, - notification.params.subscription, - notification.params.result - ) - }, + } => match message { + kb_lib::KbJsonRpcWsIncomingMessage::SuccessResponse(response) => { + format!( + "[ws:{endpoint_name}] json-rpc success id={} result={}", + response.id, response.result + ) + } + kb_lib::KbJsonRpcWsIncomingMessage::ErrorResponse(response) => { + format!( + "[ws:{endpoint_name}] json-rpc error id={} code={} message={}", + response.id, response.error.code, response.error.message + ) + } + kb_lib::KbJsonRpcWsIncomingMessage::Notification(notification) => { + format!( + "[ws:{endpoint_name}] json-rpc notification method={} subscription={} result={}", + notification.method, + notification.params.subscription, + notification.params.result + ) } }, kb_lib::WsEvent::JsonRpcParseError { @@ -383,31 +376,75 @@ fn kb_format_ws_event( } => { format!( "[ws:{endpoint_name}] json-rpc parse error: {} | raw={}", - error, - text + error, text ) - }, + } + kb_lib::WsEvent::SubscriptionRegistered { + endpoint_name, + subscription, + } => { + format!( + "[ws:{endpoint_name}] subscription registered subscribe_method={} unsubscribe_method={} notification_method={} request_id={} subscription_id={}", + subscription.subscribe_method, + subscription.unsubscribe_method, + subscription.notification_method, + subscription.request_id, + subscription.subscription_id + ) + } + kb_lib::WsEvent::SubscriptionNotification { + endpoint_name, + subscription, + notification, + method_matches_registry, + } => { + format!( + "[ws:{endpoint_name}] tracked notification method={} expected_method={} matches_registry={} subscription_id={} result={}", + notification.method, + subscription.notification_method, + method_matches_registry, + subscription.subscription_id, + notification.params.result + ) + } + kb_lib::WsEvent::JsonRpcNotificationWithoutSubscription { + endpoint_name, + notification, + } => { + format!( + "[ws:{endpoint_name}] untracked notification method={} subscription={} result={}", + notification.method, notification.params.subscription, notification.params.result + ) + } + kb_lib::WsEvent::SubscriptionUnregistered { + endpoint_name, + subscription_id, + unsubscribe_method, + was_active, + } => { + format!( + "[ws:{endpoint_name}] subscription unregistered subscription_id={} unsubscribe_method={} was_active={}", + subscription_id, unsubscribe_method, was_active + ) + } kb_lib::WsEvent::BinaryMessage { endpoint_name, data, } => { - format!( - "[ws:{endpoint_name}] binary message ({} bytes)", - data.len() - ) - }, + format!("[ws:{endpoint_name}] binary message ({} bytes)", data.len()) + } kb_lib::WsEvent::Ping { endpoint_name, data, } => { format!("[ws:{endpoint_name}] ping ({} bytes)", data.len()) - }, + } kb_lib::WsEvent::Pong { endpoint_name, data, } => { format!("[ws:{endpoint_name}] pong ({} bytes)", data.len()) - }, + } kb_lib::WsEvent::CloseReceived { endpoint_name, code, @@ -415,21 +452,18 @@ fn kb_format_ws_event( } => { format!( "[ws:{endpoint_name}] close received code={:?} reason={:?}", - code, - reason + code, reason ) - }, - kb_lib::WsEvent::Disconnected { - endpoint_name, - } => { + } + kb_lib::WsEvent::Disconnected { endpoint_name } => { format!("[ws:{endpoint_name}] disconnected") - }, + } kb_lib::WsEvent::Error { endpoint_name, error, } => { format!("[ws:{endpoint_name}] error: {error}") - }, + } } } diff --git a/kb_app/tauri.conf.json b/kb_app/tauri.conf.json index 3ad7109..10fafe7 100644 --- a/kb_app/tauri.conf.json +++ b/kb_app/tauri.conf.json @@ -1,7 +1,7 @@ { "$schema": "https://schema.tauri.app/config/2", "productName": "kb-bapp", - "version": "0.2.0", + "version": "0.3.0", "identifier": "com.sasedev.kb-app", "build": { "beforeDevCommand": "npm run dev", diff --git a/kb_lib/src/lib.rs b/kb_lib/src/lib.rs index 6b3a341..ef5a69b 100644 --- a/kb_lib/src/lib.rs +++ b/kb_lib/src/lib.rs @@ -43,3 +43,4 @@ pub use crate::types::KbConnectionState; pub use crate::ws_client::WsClient; pub use crate::ws_client::WsEvent; pub use crate::ws_client::WsOutgoingMessage; +pub use crate::ws_client::WsSubscriptionInfo; diff --git a/kb_lib/src/ws_client.rs b/kb_lib/src/ws_client.rs index ec29435..ee5922d 100644 --- a/kb_lib/src/ws_client.rs +++ b/kb_lib/src/ws_client.rs @@ -2,9 +2,12 @@ //! Generic asynchronous WebSocket transport client. //! -//! Version `0.2.x` keeps the transport layer introduced in `0.1.x` and adds -//! generic JSON-RPC 2.0 request helpers plus incoming JSON-RPC parsing for -//! text messages received from the server. +//! Version `0.3.x` keeps the transport and JSON-RPC helpers introduced earlier +//! and adds: +//! - a registry of pending JSON-RPC requests +//! - a registry of active subscriptions +//! - automatic routing of notifications to known subscriptions +//! - automatic unsubscribe attempts before disconnect use futures_util::SinkExt; use futures_util::StreamExt; @@ -24,6 +27,23 @@ pub enum WsOutgoingMessage { Close, } +/// Active subscription metadata tracked by the client runtime. +#[derive(Clone, Debug, PartialEq)] +pub struct WsSubscriptionInfo { + /// Local request identifier that created the subscription. + pub request_id: u64, + /// Remote subscription identifier returned by the server. + pub subscription_id: u64, + /// Subscribe method name. + pub subscribe_method: std::string::String, + /// Unsubscribe method name paired with the subscription. + pub unsubscribe_method: std::string::String, + /// Expected notification method name. + pub notification_method: std::string::String, + /// Original subscribe request parameters. + pub params: std::vec::Vec, +} + /// Incoming WebSocket transport event emitted by [`crate::WsClient`]. #[derive(Clone, Debug, PartialEq)] pub enum WsEvent { @@ -57,6 +77,42 @@ pub enum WsEvent { /// Parse error. error: crate::KbError, }, + /// A subscribe response created a tracked active subscription. + SubscriptionRegistered { + /// Stable endpoint name from configuration. + endpoint_name: std::string::String, + /// Registered subscription metadata. + subscription: WsSubscriptionInfo, + }, + /// A notification was matched to a tracked active subscription. + SubscriptionNotification { + /// Stable endpoint name from configuration. + endpoint_name: std::string::String, + /// Matched subscription metadata. + subscription: WsSubscriptionInfo, + /// Received notification payload. + notification: crate::KbJsonRpcWsNotification, + /// Indicates whether the notification method matches the expected one. + method_matches_registry: bool, + }, + /// A notification could not be matched to any tracked active subscription. + JsonRpcNotificationWithoutSubscription { + /// Stable endpoint name from configuration. + endpoint_name: std::string::String, + /// Received notification payload. + notification: crate::KbJsonRpcWsNotification, + }, + /// An unsubscribe response removed a tracked subscription. + SubscriptionUnregistered { + /// Stable endpoint name from configuration. + endpoint_name: std::string::String, + /// Removed subscription identifier. + subscription_id: u64, + /// Unsubscribe method used by the request. + unsubscribe_method: std::string::String, + /// Indicates whether the subscription was active before removal. + was_active: bool, + }, /// Binary message received. BinaryMessage { /// Stable endpoint name from configuration. @@ -126,6 +182,42 @@ impl WsClientRuntime { } } +#[derive(Debug)] +struct WsClientRegistry { + pending_requests: std::collections::BTreeMap, + active_subscriptions: std::collections::BTreeMap, +} + +impl WsClientRegistry { + fn new() -> Self { + Self { + pending_requests: std::collections::BTreeMap::new(), + active_subscriptions: std::collections::BTreeMap::new(), + } + } +} + +#[derive(Clone, Debug, PartialEq)] +struct WsPendingJsonRpcRequest { + request_id: u64, + method: std::string::String, + kind: WsPendingJsonRpcRequestKind, +} + +#[derive(Clone, Debug, PartialEq)] +enum WsPendingJsonRpcRequestKind { + Generic, + Subscribe { + notification_method: std::string::String, + unsubscribe_method: std::string::String, + params: std::vec::Vec, + }, + Unsubscribe { + subscription_id: u64, + unsubscribe_method: std::string::String, + }, +} + #[derive(Clone, Debug)] enum WsWriteCommand { Send(WsOutgoingMessage), @@ -140,6 +232,7 @@ pub struct WsClient { state: std::sync::Arc>, event_tx: tokio::sync::broadcast::Sender, runtime: std::sync::Arc>, + registry: std::sync::Arc>, } impl WsClient { @@ -165,6 +258,7 @@ impl WsClient { )), event_tx, runtime: std::sync::Arc::new(tokio::sync::Mutex::new(WsClientRuntime::new())), + registry: std::sync::Arc::new(tokio::sync::Mutex::new(WsClientRegistry::new())), }) } @@ -200,6 +294,28 @@ impl WsClient { *state_guard } + /// Returns the number of tracked pending JSON-RPC requests. + pub async fn pending_request_count(&self) -> usize { + let registry_guard = self.registry.lock().await; + registry_guard.pending_requests.len() + } + + /// Returns the number of tracked active subscriptions. + pub async fn active_subscription_count(&self) -> usize { + let registry_guard = self.registry.lock().await; + registry_guard.active_subscriptions.len() + } + + /// Returns a snapshot of the tracked active subscriptions. + pub async fn active_subscriptions(&self) -> std::vec::Vec { + let registry_guard = self.registry.lock().await; + registry_guard + .active_subscriptions + .values() + .cloned() + .collect() + } + /// Connects the client to its remote WebSocket endpoint. pub async fn connect(&self) -> Result<(), crate::KbError> { if !self.endpoint.enabled { @@ -243,6 +359,7 @@ impl WsClient { return Err(error); } }; + let (ws_stream, _response) = match connect_result { Ok(parts) => parts, Err(error) => { @@ -393,7 +510,7 @@ impl WsClient { self.send_text(text).await } - /// Sends a prebuilt JSON-RPC request object. + /// Sends a prebuilt JSON-RPC request object and tracks it when the request id is numeric. pub async fn send_json_rpc_request_object( &self, request: &crate::KbJsonRpcWsRequest, @@ -403,7 +520,24 @@ impl WsClient { Ok(value) => value, Err(error) => return Err(error), }; - self.send_json_value(&value).await + let tracked_request = kb_build_pending_json_rpc_request(request); + if let Some(tracked_request) = &tracked_request { + let mut registry_guard = self.registry.lock().await; + registry_guard + .pending_requests + .insert(tracked_request.request_id, tracked_request.clone()); + } + let send_result = self.send_json_value(&value).await; + if let Err(error) = send_result { + if let Some(tracked_request) = tracked_request { + let mut registry_guard = self.registry.lock().await; + registry_guard + .pending_requests + .remove(&tracked_request.request_id); + } + return Err(error); + } + Ok(()) } /// Builds and sends a JSON-RPC request with a generated numeric identifier. @@ -430,8 +564,9 @@ impl WsClient { /// Disconnects the client from its remote endpoint. /// - /// This method initiates a close handshake, signals shutdown, waits for the - /// transport tasks to complete, and aborts them if the timeout is exceeded. + /// Before closing the transport, this method attempts to unsubscribe all + /// currently tracked active subscriptions and waits up to + /// `unsubscribe_timeout_ms` for their removal. pub async fn disconnect(&self) -> Result<(), crate::KbError> { let current_state = self.connection_state().await; if current_state == crate::KbConnectionState::Disconnected { @@ -451,6 +586,14 @@ impl WsClient { endpoint_name = %self.endpoint.name, "disconnecting websocket client" ); + let auto_unsubscribe_result = self.unsubscribe_all_active_subscriptions().await; + if let Err(error) = auto_unsubscribe_result { + tracing::warn!( + endpoint_name = %self.endpoint.name, + "automatic unsubscribe phase failed before disconnect: {}", + error + ); + } let ( generation, writer_tx_option, @@ -544,6 +687,51 @@ impl WsClient { } } + async fn unsubscribe_all_active_subscriptions(&self) -> Result { + let subscriptions = self.active_subscriptions().await; + if subscriptions.is_empty() { + return Ok(0); + } + tracing::info!( + endpoint_name = %self.endpoint.name, + subscription_count = subscriptions.len(), + "sending automatic unsubscribe requests before disconnect" + ); + for subscription in &subscriptions { + let unsubscribe_params = vec![serde_json::Value::from(subscription.subscription_id)]; + let send_result = self + .send_json_rpc_request(subscription.unsubscribe_method.clone(), unsubscribe_params) + .await; + if let Err(error) = send_result { + tracing::warn!( + endpoint_name = %self.endpoint.name, + subscription_id = subscription.subscription_id, + unsubscribe_method = %subscription.unsubscribe_method, + "cannot send automatic unsubscribe request: {}", + error + ); + } + } + let started_at = std::time::Instant::now(); + let wait_timeout = std::time::Duration::from_millis(self.endpoint.unsubscribe_timeout_ms); + loop { + let active_count = self.active_subscription_count().await; + if active_count == 0 { + break; + } + if started_at.elapsed() >= wait_timeout { + tracing::warn!( + endpoint_name = %self.endpoint.name, + remaining_active_subscriptions = active_count, + "automatic unsubscribe wait timeout reached" + ); + break; + } + tokio::time::sleep(std::time::Duration::from_millis(25)).await; + } + Ok(subscriptions.len()) + } + async fn run_supervisor( &self, generation: u64, @@ -627,8 +815,9 @@ impl WsClient { Ok(parsed_message) => { self.emit_event(WsEvent::JsonRpcMessage { endpoint_name: self.endpoint.name.clone(), - message: parsed_message, + message: parsed_message.clone(), }); + self.handle_incoming_json_rpc_message(&parsed_message).await; } Err(error) => { self.emit_event(WsEvent::JsonRpcParseError { @@ -686,6 +875,114 @@ impl WsClient { } } + async fn handle_incoming_json_rpc_message(&self, message: &crate::KbJsonRpcWsIncomingMessage) { + match message { + crate::KbJsonRpcWsIncomingMessage::SuccessResponse(response) => { + let request_id_option = kb_json_value_to_u64(&response.id); + let request_id = match request_id_option { + Some(request_id) => request_id, + None => return, + }; + let pending_request_option = { + let mut registry_guard = self.registry.lock().await; + registry_guard.pending_requests.remove(&request_id) + }; + let pending_request = match pending_request_option { + Some(pending_request) => pending_request, + None => return, + }; + match pending_request.kind { + WsPendingJsonRpcRequestKind::Generic => {} + WsPendingJsonRpcRequestKind::Subscribe { + notification_method, + unsubscribe_method, + params, + } => { + let subscription_id_option = response.result.as_u64(); + let subscription_id = match subscription_id_option { + Some(subscription_id) => subscription_id, + None => return, + }; + let subscription = WsSubscriptionInfo { + request_id, + subscription_id, + subscribe_method: pending_request.method, + unsubscribe_method, + notification_method, + params, + }; + { + let mut registry_guard = self.registry.lock().await; + registry_guard + .active_subscriptions + .insert(subscription_id, subscription.clone()); + } + self.emit_event(WsEvent::SubscriptionRegistered { + endpoint_name: self.endpoint.name.clone(), + subscription, + }); + } + WsPendingJsonRpcRequestKind::Unsubscribe { + subscription_id, + unsubscribe_method, + } => { + let result_bool = response.result.as_bool(); + if result_bool != Some(true) { + return; + } + let removed_subscription_option = { + let mut registry_guard = self.registry.lock().await; + registry_guard.active_subscriptions.remove(&subscription_id) + }; + self.emit_event(WsEvent::SubscriptionUnregistered { + endpoint_name: self.endpoint.name.clone(), + subscription_id, + unsubscribe_method, + was_active: removed_subscription_option.is_some(), + }); + } + } + } + crate::KbJsonRpcWsIncomingMessage::ErrorResponse(response) => { + let request_id_option = kb_json_value_to_u64(&response.id); + let request_id = match request_id_option { + Some(request_id) => request_id, + None => return, + }; + let mut registry_guard = self.registry.lock().await; + registry_guard.pending_requests.remove(&request_id); + } + crate::KbJsonRpcWsIncomingMessage::Notification(notification) => { + let subscription_id = notification.params.subscription; + let matched_subscription_option = { + let registry_guard = self.registry.lock().await; + registry_guard + .active_subscriptions + .get(&subscription_id) + .cloned() + }; + match matched_subscription_option { + Some(subscription) => { + let method_matches_registry = + subscription.notification_method == notification.method; + self.emit_event(WsEvent::SubscriptionNotification { + endpoint_name: self.endpoint.name.clone(), + subscription, + notification: notification.clone(), + method_matches_registry, + }); + } + None => { + self.emit_event(WsEvent::JsonRpcNotificationWithoutSubscription { + endpoint_name: self.endpoint.name.clone(), + notification: notification.clone(), + }); + } + } + } + } + } + async fn run_write_loop( &self, mut write_half: futures_util::stream::SplitSink< @@ -697,31 +994,28 @@ impl WsClient { ) where S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + std::marker::Send + 'static, { - let close_sent = false; loop { tokio::select! { shutdown_result = shutdown_rx.changed() => { match shutdown_result { Ok(()) => { if *shutdown_rx.borrow() { - if !close_sent { - let close_result = write_half.send( - tokio_tungstenite::tungstenite::Message::Close(None) - ).await; - match close_result { - Ok(()) => {}, - Err(error) => { - if !kb_is_normal_close_error(&error) { - self.emit_event(WsEvent::Error { - endpoint_name: self.endpoint.name.clone(), - error: crate::KbError::Ws(format!( - "write close error on endpoint '{}': {error}", - self.endpoint.name - )), - }); - } - }, - } + let close_result = write_half.send( + tokio_tungstenite::tungstenite::Message::Close(None) + ).await; + match close_result { + Ok(()) => {}, + Err(error) => { + if !kb_is_normal_close_error(&error) { + self.emit_event(WsEvent::Error { + endpoint_name: self.endpoint.name.clone(), + error: crate::KbError::Ws(format!( + "write close error on endpoint '{}': {error}", + self.endpoint.name + )), + }); + } + }, } let flush_result = write_half.flush().await; if let Err(error) = flush_result { @@ -846,6 +1140,11 @@ impl WsClient { if !should_clear { return false; } + { + let mut registry_guard = self.registry.lock().await; + registry_guard.pending_requests.clear(); + registry_guard.active_subscriptions.clear(); + } let mut state_guard = self.state.write().await; if *state_guard == crate::KbConnectionState::Disconnected { return false; @@ -887,6 +1186,81 @@ fn kb_is_normal_close_error(error: &tokio_tungstenite::tungstenite::Error) -> bo } } +fn kb_json_value_to_u64(value: &serde_json::Value) -> std::option::Option { + value.as_u64() +} + +fn kb_is_subscribe_method(method: &str) -> bool { + method.ends_with("Subscribe") +} + +fn kb_is_unsubscribe_method(method: &str) -> bool { + method.ends_with("Unsubscribe") +} + +fn kb_infer_unsubscribe_method_from_subscribe( + subscribe_method: &str, +) -> std::option::Option { + if !kb_is_subscribe_method(subscribe_method) { + return None; + } + let base = subscribe_method.trim_end_matches("Subscribe"); + Some(format!("{base}Unsubscribe")) +} + +fn kb_infer_notification_method_from_subscribe( + subscribe_method: &str, +) -> std::option::Option { + if !kb_is_subscribe_method(subscribe_method) { + return None; + } + let base = subscribe_method.trim_end_matches("Subscribe"); + Some(format!("{base}Notification")) +} + +fn kb_build_pending_json_rpc_request( + request: &crate::KbJsonRpcWsRequest, +) -> std::option::Option { + let request_id_option = kb_json_value_to_u64(&request.id); + let request_id = request_id_option?; + + if kb_is_subscribe_method(&request.method) { + let notification_method_option = + kb_infer_notification_method_from_subscribe(&request.method); + let unsubscribe_method_option = kb_infer_unsubscribe_method_from_subscribe(&request.method); + let notification_method = notification_method_option?; + let unsubscribe_method = unsubscribe_method_option?; + return Some(WsPendingJsonRpcRequest { + request_id, + method: request.method.clone(), + kind: WsPendingJsonRpcRequestKind::Subscribe { + notification_method, + unsubscribe_method, + params: request.params.clone(), + }, + }); + } + if kb_is_unsubscribe_method(&request.method) { + let first_param_option = request.params.first(); + let first_param = first_param_option?; + let subscription_id_option = first_param.as_u64(); + let subscription_id = subscription_id_option?; + return Some(WsPendingJsonRpcRequest { + request_id, + method: request.method.clone(), + kind: WsPendingJsonRpcRequestKind::Unsubscribe { + subscription_id, + unsubscribe_method: request.method.clone(), + }, + }); + } + Some(WsPendingJsonRpcRequest { + request_id, + method: request.method.clone(), + kind: WsPendingJsonRpcRequestKind::Generic, + }) +} + #[cfg(test)] mod tests { use futures_util::SinkExt; @@ -896,10 +1270,13 @@ mod tests { struct TestWsServer { url: std::string::String, shutdown_tx: std::option::Option>, + observed_methods: std::sync::Arc>>, } impl TestWsServer { async fn spawn_echo_server() -> Self { + let observed_methods = + std::sync::Arc::new(tokio::sync::Mutex::new(std::vec::Vec::new())); let bind_result = tokio::net::TcpListener::bind("127.0.0.1:0").await; let listener = bind_result.expect("listener bind must succeed"); let local_addr = listener.local_addr().expect("local addr must be available"); @@ -962,14 +1339,18 @@ mod tests { Self { url: format!("ws://{}", local_addr), shutdown_tx: Some(shutdown_tx), + observed_methods, } } async fn spawn_json_rpc_server() -> Self { + let observed_methods = + std::sync::Arc::new(tokio::sync::Mutex::new(std::vec::Vec::new())); let bind_result = tokio::net::TcpListener::bind("127.0.0.1:0").await; let listener = bind_result.expect("listener bind must succeed"); let local_addr = listener.local_addr().expect("local addr must be available"); let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + let observed_methods_for_server = observed_methods.clone(); tokio::spawn(async move { loop { tokio::select! { @@ -978,6 +1359,7 @@ mod tests { }, accept_result = listener.accept() => { let (stream, _peer_addr) = accept_result.expect("accept must succeed"); + let observed_methods_for_connection = observed_methods_for_server.clone(); tokio::spawn(async move { let accept_ws_result = tokio_tungstenite::accept_async(stream).await; let mut ws_stream = accept_ws_result.expect("websocket accept must succeed"); @@ -994,8 +1376,13 @@ mod tests { tokio_tungstenite::tungstenite::Message::Text(text) => { let value: serde_json::Value = serde_json::from_str(text.as_ref()) .expect("request json must parse"); - let method = value["method"].as_str().expect("method must be a string"); + let method = value["method"].as_str().expect("method must be a string").to_string(); let id = value["id"].clone(); + { + let mut observed_methods_guard = + observed_methods_for_connection.lock().await; + observed_methods_guard.push(method.clone()); + } if method == "slotSubscribe" { let response = serde_json::json!({ "jsonrpc": "2.0", @@ -1020,6 +1407,15 @@ mod tests { ws_stream.send( tokio_tungstenite::tungstenite::Message::Text(notification.to_string().into()) ).await.expect("notification send must succeed"); + } else if method == "slotUnsubscribe" { + let response = serde_json::json!({ + "jsonrpc": "2.0", + "result": true, + "id": id + }); + ws_stream.send( + tokio_tungstenite::tungstenite::Message::Text(response.to_string().into()) + ).await.expect("unsubscribe response send must succeed"); } else { let response = serde_json::json!({ "jsonrpc": "2.0", @@ -1058,9 +1454,14 @@ mod tests { Self { url: format!("ws://{}", local_addr), shutdown_tx: Some(shutdown_tx), + observed_methods, } } + async fn observed_methods_snapshot(&self) -> std::vec::Vec { + let observed_methods_guard = self.observed_methods.lock().await; + observed_methods_guard.clone() + } async fn shutdown(mut self) { if let Some(shutdown_tx) = self.shutdown_tx.take() { let _ = shutdown_tx.send(()); @@ -1211,7 +1612,7 @@ mod tests { } #[tokio::test] - async fn send_json_rpc_request_emits_success_response_and_notification() { + async fn subscribe_registers_subscription_and_routes_notification() { let server = TestWsServer::spawn_json_rpc_server().await; let endpoint = make_ws_endpoint(server.url.clone()); let client = crate::WsClient::new(endpoint).expect("client creation must succeed"); @@ -1223,54 +1624,79 @@ mod tests { .await .expect("json-rpc send must succeed"); assert_eq!(request_id, 1); - let mut success_seen = false; - let mut notification_seen = false; - for _ in 0..6 { + + let mut subscription_registered_seen = false; + let mut subscription_notification_seen = false; + + for _ in 0..8 { let event = recv_event(&mut receiver).await; match event { - crate::WsEvent::JsonRpcMessage { + crate::WsEvent::SubscriptionRegistered { endpoint_name, - message, + subscription, } => { assert_eq!(endpoint_name, "test_ws"); - match message { - crate::KbJsonRpcWsIncomingMessage::SuccessResponse(response) => { - assert_eq!(response.id, serde_json::Value::from(1u64)); - assert_eq!(response.result, serde_json::Value::from(77u64)); - success_seen = true; - } - crate::KbJsonRpcWsIncomingMessage::Notification(notification) => { - assert_eq!(notification.method, "slotNotification"); - assert_eq!(notification.params.subscription, 77); - assert_eq!( - notification.params.result["slot"], - serde_json::Value::from(12u64) - ); - notification_seen = true; - } - crate::KbJsonRpcWsIncomingMessage::ErrorResponse(other) => { - panic!("unexpected error response: {other:?}"); - } - } + assert_eq!(subscription.request_id, 1); + assert_eq!(subscription.subscription_id, 77); + assert_eq!(subscription.subscribe_method, "slotSubscribe"); + assert_eq!(subscription.unsubscribe_method, "slotUnsubscribe"); + assert_eq!(subscription.notification_method, "slotNotification"); + subscription_registered_seen = true; + } + crate::WsEvent::SubscriptionNotification { + endpoint_name, + subscription, + notification, + method_matches_registry, + } => { + assert_eq!(endpoint_name, "test_ws"); + assert_eq!(subscription.subscription_id, 77); + assert!(method_matches_registry); + assert_eq!(notification.method, "slotNotification"); + assert_eq!(notification.params.subscription, 77); + assert_eq!( + notification.params.result["slot"], + serde_json::Value::from(12u64) + ); + subscription_notification_seen = true; } crate::WsEvent::TextMessage { .. } => {} + crate::WsEvent::JsonRpcMessage { .. } => {} other => { panic!("unexpected event: {other:?}"); } } - - if success_seen && notification_seen { + if subscription_registered_seen && subscription_notification_seen { break; } } - assert!(success_seen, "json-rpc success response must be observed"); - assert!(notification_seen, "json-rpc notification must be observed"); + assert!( + subscription_registered_seen, + "subscription must be registered" + ); + assert!( + subscription_notification_seen, + "subscription notification must be routed" + ); + assert_eq!(client.active_subscription_count().await, 1); + assert_eq!(client.pending_request_count().await, 0); client.disconnect().await.expect("disconnect must succeed"); + let observed_methods = server.observed_methods_snapshot().await; + assert!( + observed_methods + .iter() + .any(|method| method == "slotSubscribe") + ); + assert!( + observed_methods + .iter() + .any(|method| method == "slotUnsubscribe") + ); server.shutdown().await; } #[tokio::test] - async fn send_unknown_json_rpc_method_emits_error_response() { + async fn unknown_json_rpc_method_emits_error_and_clears_pending_request() { let server = TestWsServer::spawn_json_rpc_server().await; let endpoint = make_ws_endpoint(server.url.clone()); let client = crate::WsClient::new(endpoint).expect("client creation must succeed"); @@ -1283,7 +1709,7 @@ mod tests { .expect("json-rpc send must succeed"); assert_eq!(request_id, 1); let mut error_seen = false; - for _ in 0..4 { + for _ in 0..6 { let event = recv_event(&mut receiver).await; match event { crate::WsEvent::JsonRpcMessage { message, .. } => match message { @@ -1308,6 +1734,7 @@ mod tests { } } assert!(error_seen, "json-rpc error response must be observed"); + assert_eq!(client.pending_request_count().await, 0); client.disconnect().await.expect("disconnect must succeed"); server.shutdown().await; }