Compare commits

...

2 Commits

Author SHA1 Message Date
bbad47b0e8 0.3.0 2026-04-21 07:22:39 +02:00
21630c1c01 syntaxe correction 2026-04-21 06:04:31 +02:00
8 changed files with 573 additions and 112 deletions

View File

@@ -5,3 +5,4 @@
0.1.0 - Transport WebSocket générique 0.1.0 - Transport WebSocket générique
0.1.1 - Intégration Tauri minimale du WsClient 0.1.1 - Intégration Tauri minimale du WsClient
0.2.0 - Couche JSON-RPC WS Solana 0.2.0 - Couche JSON-RPC WS Solana
0.3.0 - Registre subscriptions / notifications

View File

@@ -8,7 +8,7 @@ members = [
] ]
[workspace.package] [workspace.package]
version = "0.2.0" version = "0.3.0"
edition = "2024" edition = "2024"
license = "MIT" license = "MIT"
repository = "https://git.sasedev.com/Sasedev/khadhroony-bobobot" repository = "https://git.sasedev.com/Sasedev/khadhroony-bobobot"

View File

@@ -1,7 +1,7 @@
{ {
"name": "kb-app", "name": "kb-app",
"private": true, "private": true,
"version": "0.2.0", "version": "0.3.0",
"type": "module", "type": "module",
"scripts": { "scripts": {
"dev": "vite", "dev": "vite",

View File

@@ -330,9 +330,7 @@ fn kb_emit_app_log(app_handle: &tauri::AppHandle, message: &str) {
} }
} }
fn kb_format_ws_event( fn kb_format_ws_event(event: &kb_lib::WsEvent) -> std::string::String {
event: &kb_lib::WsEvent,
) -> std::string::String {
match event { match event {
kb_lib::WsEvent::Connected { kb_lib::WsEvent::Connected {
endpoint_name, endpoint_name,
@@ -349,23 +347,19 @@ fn kb_format_ws_event(
kb_lib::WsEvent::JsonRpcMessage { kb_lib::WsEvent::JsonRpcMessage {
endpoint_name, endpoint_name,
message, message,
} => { } => match message {
match message {
kb_lib::KbJsonRpcWsIncomingMessage::SuccessResponse(response) => { kb_lib::KbJsonRpcWsIncomingMessage::SuccessResponse(response) => {
format!( format!(
"[ws:{endpoint_name}] json-rpc success id={} result={}", "[ws:{endpoint_name}] json-rpc success id={} result={}",
response.id, response.id, response.result
response.result
) )
}, }
kb_lib::KbJsonRpcWsIncomingMessage::ErrorResponse(response) => { kb_lib::KbJsonRpcWsIncomingMessage::ErrorResponse(response) => {
format!( format!(
"[ws:{endpoint_name}] json-rpc error id={} code={} message={}", "[ws:{endpoint_name}] json-rpc error id={} code={} message={}",
response.id, response.id, response.error.code, response.error.message
response.error.code,
response.error.message
) )
}, }
kb_lib::KbJsonRpcWsIncomingMessage::Notification(notification) => { kb_lib::KbJsonRpcWsIncomingMessage::Notification(notification) => {
format!( format!(
"[ws:{endpoint_name}] json-rpc notification method={} subscription={} result={}", "[ws:{endpoint_name}] json-rpc notification method={} subscription={} result={}",
@@ -373,7 +367,6 @@ fn kb_format_ws_event(
notification.params.subscription, notification.params.subscription,
notification.params.result notification.params.result
) )
},
} }
}, },
kb_lib::WsEvent::JsonRpcParseError { kb_lib::WsEvent::JsonRpcParseError {
@@ -383,31 +376,75 @@ fn kb_format_ws_event(
} => { } => {
format!( format!(
"[ws:{endpoint_name}] json-rpc parse error: {} | raw={}", "[ws:{endpoint_name}] json-rpc parse error: {} | raw={}",
error, error, text
text
) )
}, }
kb_lib::WsEvent::SubscriptionRegistered {
endpoint_name,
subscription,
} => {
format!(
"[ws:{endpoint_name}] subscription registered subscribe_method={} unsubscribe_method={} notification_method={} request_id={} subscription_id={}",
subscription.subscribe_method,
subscription.unsubscribe_method,
subscription.notification_method,
subscription.request_id,
subscription.subscription_id
)
}
kb_lib::WsEvent::SubscriptionNotification {
endpoint_name,
subscription,
notification,
method_matches_registry,
} => {
format!(
"[ws:{endpoint_name}] tracked notification method={} expected_method={} matches_registry={} subscription_id={} result={}",
notification.method,
subscription.notification_method,
method_matches_registry,
subscription.subscription_id,
notification.params.result
)
}
kb_lib::WsEvent::JsonRpcNotificationWithoutSubscription {
endpoint_name,
notification,
} => {
format!(
"[ws:{endpoint_name}] untracked notification method={} subscription={} result={}",
notification.method, notification.params.subscription, notification.params.result
)
}
kb_lib::WsEvent::SubscriptionUnregistered {
endpoint_name,
subscription_id,
unsubscribe_method,
was_active,
} => {
format!(
"[ws:{endpoint_name}] subscription unregistered subscription_id={} unsubscribe_method={} was_active={}",
subscription_id, unsubscribe_method, was_active
)
}
kb_lib::WsEvent::BinaryMessage { kb_lib::WsEvent::BinaryMessage {
endpoint_name, endpoint_name,
data, data,
} => { } => {
format!( format!("[ws:{endpoint_name}] binary message ({} bytes)", data.len())
"[ws:{endpoint_name}] binary message ({} bytes)", }
data.len()
)
},
kb_lib::WsEvent::Ping { kb_lib::WsEvent::Ping {
endpoint_name, endpoint_name,
data, data,
} => { } => {
format!("[ws:{endpoint_name}] ping ({} bytes)", data.len()) format!("[ws:{endpoint_name}] ping ({} bytes)", data.len())
}, }
kb_lib::WsEvent::Pong { kb_lib::WsEvent::Pong {
endpoint_name, endpoint_name,
data, data,
} => { } => {
format!("[ws:{endpoint_name}] pong ({} bytes)", data.len()) format!("[ws:{endpoint_name}] pong ({} bytes)", data.len())
}, }
kb_lib::WsEvent::CloseReceived { kb_lib::WsEvent::CloseReceived {
endpoint_name, endpoint_name,
code, code,
@@ -415,21 +452,18 @@ fn kb_format_ws_event(
} => { } => {
format!( format!(
"[ws:{endpoint_name}] close received code={:?} reason={:?}", "[ws:{endpoint_name}] close received code={:?} reason={:?}",
code, code, reason
reason
) )
}, }
kb_lib::WsEvent::Disconnected { kb_lib::WsEvent::Disconnected { endpoint_name } => {
endpoint_name,
} => {
format!("[ws:{endpoint_name}] disconnected") format!("[ws:{endpoint_name}] disconnected")
}, }
kb_lib::WsEvent::Error { kb_lib::WsEvent::Error {
endpoint_name, endpoint_name,
error, error,
} => { } => {
format!("[ws:{endpoint_name}] error: {error}") format!("[ws:{endpoint_name}] error: {error}")
}, }
} }
} }

View File

@@ -1,7 +1,7 @@
{ {
"$schema": "https://schema.tauri.app/config/2", "$schema": "https://schema.tauri.app/config/2",
"productName": "kb-bapp", "productName": "kb-bapp",
"version": "0.2.0", "version": "0.3.0",
"identifier": "com.sasedev.kb-app", "identifier": "com.sasedev.kb-app",
"build": { "build": {
"beforeDevCommand": "npm run dev", "beforeDevCommand": "npm run dev",

View File

@@ -7,7 +7,7 @@ import { resolve } from 'path';
const host = process.env.TAURI_DEV_HOST; const host = process.env.TAURI_DEV_HOST;
// https://vite.dev/config/ // https://vite.dev/config/
export default defineConfig(async () => ({ export default defineConfig(() => ({
envPrefix: ['VITE_', 'TAURI_ENV_*'], envPrefix: ['VITE_', 'TAURI_ENV_*'],
// Vite options tailored for Tauri development and only applied in `tauri dev` or `tauri build` // Vite options tailored for Tauri development and only applied in `tauri dev` or `tauri build`
@@ -57,11 +57,9 @@ export default defineConfig(async () => ({
preprocessorOptions: { preprocessorOptions: {
scss: { scss: {
quietDeps: true, quietDeps: true,
//silenceDeprecations: ["import", "color-functions", "global-builtin"] as const,
silenceDeprecations: ["import", "color-functions", "global-builtin",], silenceDeprecations: ["import", "color-functions", "global-builtin",],
verbose: false, verbose: false,
//api: 'modern', api: 'modern',
//api: 'modern-compiler',
importers: [new NodePackageImporter()], importers: [new NodePackageImporter()],
} }
} }

View File

@@ -43,3 +43,4 @@ pub use crate::types::KbConnectionState;
pub use crate::ws_client::WsClient; pub use crate::ws_client::WsClient;
pub use crate::ws_client::WsEvent; pub use crate::ws_client::WsEvent;
pub use crate::ws_client::WsOutgoingMessage; pub use crate::ws_client::WsOutgoingMessage;
pub use crate::ws_client::WsSubscriptionInfo;

View File

@@ -2,9 +2,12 @@
//! Generic asynchronous WebSocket transport client. //! Generic asynchronous WebSocket transport client.
//! //!
//! Version `0.2.x` keeps the transport layer introduced in `0.1.x` and adds //! Version `0.3.x` keeps the transport and JSON-RPC helpers introduced earlier
//! generic JSON-RPC 2.0 request helpers plus incoming JSON-RPC parsing for //! and adds:
//! text messages received from the server. //! - a registry of pending JSON-RPC requests
//! - a registry of active subscriptions
//! - automatic routing of notifications to known subscriptions
//! - automatic unsubscribe attempts before disconnect
use futures_util::SinkExt; use futures_util::SinkExt;
use futures_util::StreamExt; use futures_util::StreamExt;
@@ -24,6 +27,23 @@ pub enum WsOutgoingMessage {
Close, Close,
} }
/// Active subscription metadata tracked by the client runtime.
#[derive(Clone, Debug, PartialEq)]
pub struct WsSubscriptionInfo {
/// Local request identifier that created the subscription.
pub request_id: u64,
/// Remote subscription identifier returned by the server.
pub subscription_id: u64,
/// Subscribe method name.
pub subscribe_method: std::string::String,
/// Unsubscribe method name paired with the subscription.
pub unsubscribe_method: std::string::String,
/// Expected notification method name.
pub notification_method: std::string::String,
/// Original subscribe request parameters.
pub params: std::vec::Vec<serde_json::Value>,
}
/// Incoming WebSocket transport event emitted by [`crate::WsClient`]. /// Incoming WebSocket transport event emitted by [`crate::WsClient`].
#[derive(Clone, Debug, PartialEq)] #[derive(Clone, Debug, PartialEq)]
pub enum WsEvent { pub enum WsEvent {
@@ -57,6 +77,42 @@ pub enum WsEvent {
/// Parse error. /// Parse error.
error: crate::KbError, error: crate::KbError,
}, },
/// A subscribe response created a tracked active subscription.
SubscriptionRegistered {
/// Stable endpoint name from configuration.
endpoint_name: std::string::String,
/// Registered subscription metadata.
subscription: WsSubscriptionInfo,
},
/// A notification was matched to a tracked active subscription.
SubscriptionNotification {
/// Stable endpoint name from configuration.
endpoint_name: std::string::String,
/// Matched subscription metadata.
subscription: WsSubscriptionInfo,
/// Received notification payload.
notification: crate::KbJsonRpcWsNotification,
/// Indicates whether the notification method matches the expected one.
method_matches_registry: bool,
},
/// A notification could not be matched to any tracked active subscription.
JsonRpcNotificationWithoutSubscription {
/// Stable endpoint name from configuration.
endpoint_name: std::string::String,
/// Received notification payload.
notification: crate::KbJsonRpcWsNotification,
},
/// An unsubscribe response removed a tracked subscription.
SubscriptionUnregistered {
/// Stable endpoint name from configuration.
endpoint_name: std::string::String,
/// Removed subscription identifier.
subscription_id: u64,
/// Unsubscribe method used by the request.
unsubscribe_method: std::string::String,
/// Indicates whether the subscription was active before removal.
was_active: bool,
},
/// Binary message received. /// Binary message received.
BinaryMessage { BinaryMessage {
/// Stable endpoint name from configuration. /// Stable endpoint name from configuration.
@@ -126,6 +182,42 @@ impl WsClientRuntime {
} }
} }
#[derive(Debug)]
struct WsClientRegistry {
pending_requests: std::collections::BTreeMap<u64, WsPendingJsonRpcRequest>,
active_subscriptions: std::collections::BTreeMap<u64, WsSubscriptionInfo>,
}
impl WsClientRegistry {
fn new() -> Self {
Self {
pending_requests: std::collections::BTreeMap::new(),
active_subscriptions: std::collections::BTreeMap::new(),
}
}
}
#[derive(Clone, Debug, PartialEq)]
struct WsPendingJsonRpcRequest {
request_id: u64,
method: std::string::String,
kind: WsPendingJsonRpcRequestKind,
}
#[derive(Clone, Debug, PartialEq)]
enum WsPendingJsonRpcRequestKind {
Generic,
Subscribe {
notification_method: std::string::String,
unsubscribe_method: std::string::String,
params: std::vec::Vec<serde_json::Value>,
},
Unsubscribe {
subscription_id: u64,
unsubscribe_method: std::string::String,
},
}
#[derive(Clone, Debug)] #[derive(Clone, Debug)]
enum WsWriteCommand { enum WsWriteCommand {
Send(WsOutgoingMessage), Send(WsOutgoingMessage),
@@ -140,6 +232,7 @@ pub struct WsClient {
state: std::sync::Arc<tokio::sync::RwLock<crate::KbConnectionState>>, state: std::sync::Arc<tokio::sync::RwLock<crate::KbConnectionState>>,
event_tx: tokio::sync::broadcast::Sender<WsEvent>, event_tx: tokio::sync::broadcast::Sender<WsEvent>,
runtime: std::sync::Arc<tokio::sync::Mutex<WsClientRuntime>>, runtime: std::sync::Arc<tokio::sync::Mutex<WsClientRuntime>>,
registry: std::sync::Arc<tokio::sync::Mutex<WsClientRegistry>>,
} }
impl WsClient { impl WsClient {
@@ -165,6 +258,7 @@ impl WsClient {
)), )),
event_tx, event_tx,
runtime: std::sync::Arc::new(tokio::sync::Mutex::new(WsClientRuntime::new())), runtime: std::sync::Arc::new(tokio::sync::Mutex::new(WsClientRuntime::new())),
registry: std::sync::Arc::new(tokio::sync::Mutex::new(WsClientRegistry::new())),
}) })
} }
@@ -200,6 +294,28 @@ impl WsClient {
*state_guard *state_guard
} }
/// Returns the number of tracked pending JSON-RPC requests.
pub async fn pending_request_count(&self) -> usize {
let registry_guard = self.registry.lock().await;
registry_guard.pending_requests.len()
}
/// Returns the number of tracked active subscriptions.
pub async fn active_subscription_count(&self) -> usize {
let registry_guard = self.registry.lock().await;
registry_guard.active_subscriptions.len()
}
/// Returns a snapshot of the tracked active subscriptions.
pub async fn active_subscriptions(&self) -> std::vec::Vec<WsSubscriptionInfo> {
let registry_guard = self.registry.lock().await;
registry_guard
.active_subscriptions
.values()
.cloned()
.collect()
}
/// Connects the client to its remote WebSocket endpoint. /// Connects the client to its remote WebSocket endpoint.
pub async fn connect(&self) -> Result<(), crate::KbError> { pub async fn connect(&self) -> Result<(), crate::KbError> {
if !self.endpoint.enabled { if !self.endpoint.enabled {
@@ -243,6 +359,7 @@ impl WsClient {
return Err(error); return Err(error);
} }
}; };
let (ws_stream, _response) = match connect_result { let (ws_stream, _response) = match connect_result {
Ok(parts) => parts, Ok(parts) => parts,
Err(error) => { Err(error) => {
@@ -393,7 +510,7 @@ impl WsClient {
self.send_text(text).await self.send_text(text).await
} }
/// Sends a prebuilt JSON-RPC request object. /// Sends a prebuilt JSON-RPC request object and tracks it when the request id is numeric.
pub async fn send_json_rpc_request_object( pub async fn send_json_rpc_request_object(
&self, &self,
request: &crate::KbJsonRpcWsRequest, request: &crate::KbJsonRpcWsRequest,
@@ -403,7 +520,24 @@ impl WsClient {
Ok(value) => value, Ok(value) => value,
Err(error) => return Err(error), Err(error) => return Err(error),
}; };
self.send_json_value(&value).await let tracked_request = kb_build_pending_json_rpc_request(request);
if let Some(tracked_request) = &tracked_request {
let mut registry_guard = self.registry.lock().await;
registry_guard
.pending_requests
.insert(tracked_request.request_id, tracked_request.clone());
}
let send_result = self.send_json_value(&value).await;
if let Err(error) = send_result {
if let Some(tracked_request) = tracked_request {
let mut registry_guard = self.registry.lock().await;
registry_guard
.pending_requests
.remove(&tracked_request.request_id);
}
return Err(error);
}
Ok(())
} }
/// Builds and sends a JSON-RPC request with a generated numeric identifier. /// Builds and sends a JSON-RPC request with a generated numeric identifier.
@@ -430,8 +564,9 @@ impl WsClient {
/// Disconnects the client from its remote endpoint. /// Disconnects the client from its remote endpoint.
/// ///
/// This method initiates a close handshake, signals shutdown, waits for the /// Before closing the transport, this method attempts to unsubscribe all
/// transport tasks to complete, and aborts them if the timeout is exceeded. /// currently tracked active subscriptions and waits up to
/// `unsubscribe_timeout_ms` for their removal.
pub async fn disconnect(&self) -> Result<(), crate::KbError> { pub async fn disconnect(&self) -> Result<(), crate::KbError> {
let current_state = self.connection_state().await; let current_state = self.connection_state().await;
if current_state == crate::KbConnectionState::Disconnected { if current_state == crate::KbConnectionState::Disconnected {
@@ -451,6 +586,14 @@ impl WsClient {
endpoint_name = %self.endpoint.name, endpoint_name = %self.endpoint.name,
"disconnecting websocket client" "disconnecting websocket client"
); );
let auto_unsubscribe_result = self.unsubscribe_all_active_subscriptions().await;
if let Err(error) = auto_unsubscribe_result {
tracing::warn!(
endpoint_name = %self.endpoint.name,
"automatic unsubscribe phase failed before disconnect: {}",
error
);
}
let ( let (
generation, generation,
writer_tx_option, writer_tx_option,
@@ -544,6 +687,51 @@ impl WsClient {
} }
} }
async fn unsubscribe_all_active_subscriptions(&self) -> Result<usize, crate::KbError> {
let subscriptions = self.active_subscriptions().await;
if subscriptions.is_empty() {
return Ok(0);
}
tracing::info!(
endpoint_name = %self.endpoint.name,
subscription_count = subscriptions.len(),
"sending automatic unsubscribe requests before disconnect"
);
for subscription in &subscriptions {
let unsubscribe_params = vec![serde_json::Value::from(subscription.subscription_id)];
let send_result = self
.send_json_rpc_request(subscription.unsubscribe_method.clone(), unsubscribe_params)
.await;
if let Err(error) = send_result {
tracing::warn!(
endpoint_name = %self.endpoint.name,
subscription_id = subscription.subscription_id,
unsubscribe_method = %subscription.unsubscribe_method,
"cannot send automatic unsubscribe request: {}",
error
);
}
}
let started_at = std::time::Instant::now();
let wait_timeout = std::time::Duration::from_millis(self.endpoint.unsubscribe_timeout_ms);
loop {
let active_count = self.active_subscription_count().await;
if active_count == 0 {
break;
}
if started_at.elapsed() >= wait_timeout {
tracing::warn!(
endpoint_name = %self.endpoint.name,
remaining_active_subscriptions = active_count,
"automatic unsubscribe wait timeout reached"
);
break;
}
tokio::time::sleep(std::time::Duration::from_millis(25)).await;
}
Ok(subscriptions.len())
}
async fn run_supervisor( async fn run_supervisor(
&self, &self,
generation: u64, generation: u64,
@@ -627,8 +815,9 @@ impl WsClient {
Ok(parsed_message) => { Ok(parsed_message) => {
self.emit_event(WsEvent::JsonRpcMessage { self.emit_event(WsEvent::JsonRpcMessage {
endpoint_name: self.endpoint.name.clone(), endpoint_name: self.endpoint.name.clone(),
message: parsed_message, message: parsed_message.clone(),
}); });
self.handle_incoming_json_rpc_message(&parsed_message).await;
} }
Err(error) => { Err(error) => {
self.emit_event(WsEvent::JsonRpcParseError { self.emit_event(WsEvent::JsonRpcParseError {
@@ -686,6 +875,114 @@ impl WsClient {
} }
} }
async fn handle_incoming_json_rpc_message(&self, message: &crate::KbJsonRpcWsIncomingMessage) {
match message {
crate::KbJsonRpcWsIncomingMessage::SuccessResponse(response) => {
let request_id_option = kb_json_value_to_u64(&response.id);
let request_id = match request_id_option {
Some(request_id) => request_id,
None => return,
};
let pending_request_option = {
let mut registry_guard = self.registry.lock().await;
registry_guard.pending_requests.remove(&request_id)
};
let pending_request = match pending_request_option {
Some(pending_request) => pending_request,
None => return,
};
match pending_request.kind {
WsPendingJsonRpcRequestKind::Generic => {}
WsPendingJsonRpcRequestKind::Subscribe {
notification_method,
unsubscribe_method,
params,
} => {
let subscription_id_option = response.result.as_u64();
let subscription_id = match subscription_id_option {
Some(subscription_id) => subscription_id,
None => return,
};
let subscription = WsSubscriptionInfo {
request_id,
subscription_id,
subscribe_method: pending_request.method,
unsubscribe_method,
notification_method,
params,
};
{
let mut registry_guard = self.registry.lock().await;
registry_guard
.active_subscriptions
.insert(subscription_id, subscription.clone());
}
self.emit_event(WsEvent::SubscriptionRegistered {
endpoint_name: self.endpoint.name.clone(),
subscription,
});
}
WsPendingJsonRpcRequestKind::Unsubscribe {
subscription_id,
unsubscribe_method,
} => {
let result_bool = response.result.as_bool();
if result_bool != Some(true) {
return;
}
let removed_subscription_option = {
let mut registry_guard = self.registry.lock().await;
registry_guard.active_subscriptions.remove(&subscription_id)
};
self.emit_event(WsEvent::SubscriptionUnregistered {
endpoint_name: self.endpoint.name.clone(),
subscription_id,
unsubscribe_method,
was_active: removed_subscription_option.is_some(),
});
}
}
}
crate::KbJsonRpcWsIncomingMessage::ErrorResponse(response) => {
let request_id_option = kb_json_value_to_u64(&response.id);
let request_id = match request_id_option {
Some(request_id) => request_id,
None => return,
};
let mut registry_guard = self.registry.lock().await;
registry_guard.pending_requests.remove(&request_id);
}
crate::KbJsonRpcWsIncomingMessage::Notification(notification) => {
let subscription_id = notification.params.subscription;
let matched_subscription_option = {
let registry_guard = self.registry.lock().await;
registry_guard
.active_subscriptions
.get(&subscription_id)
.cloned()
};
match matched_subscription_option {
Some(subscription) => {
let method_matches_registry =
subscription.notification_method == notification.method;
self.emit_event(WsEvent::SubscriptionNotification {
endpoint_name: self.endpoint.name.clone(),
subscription,
notification: notification.clone(),
method_matches_registry,
});
}
None => {
self.emit_event(WsEvent::JsonRpcNotificationWithoutSubscription {
endpoint_name: self.endpoint.name.clone(),
notification: notification.clone(),
});
}
}
}
}
}
async fn run_write_loop<S>( async fn run_write_loop<S>(
&self, &self,
mut write_half: futures_util::stream::SplitSink< mut write_half: futures_util::stream::SplitSink<
@@ -697,14 +994,12 @@ impl WsClient {
) where ) where
S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + std::marker::Send + 'static, S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + std::marker::Send + 'static,
{ {
let close_sent = false;
loop { loop {
tokio::select! { tokio::select! {
shutdown_result = shutdown_rx.changed() => { shutdown_result = shutdown_rx.changed() => {
match shutdown_result { match shutdown_result {
Ok(()) => { Ok(()) => {
if *shutdown_rx.borrow() { if *shutdown_rx.borrow() {
if !close_sent {
let close_result = write_half.send( let close_result = write_half.send(
tokio_tungstenite::tungstenite::Message::Close(None) tokio_tungstenite::tungstenite::Message::Close(None)
).await; ).await;
@@ -722,7 +1017,6 @@ impl WsClient {
} }
}, },
} }
}
let flush_result = write_half.flush().await; let flush_result = write_half.flush().await;
if let Err(error) = flush_result { if let Err(error) = flush_result {
if !kb_is_normal_close_error(&error) { if !kb_is_normal_close_error(&error) {
@@ -846,6 +1140,11 @@ impl WsClient {
if !should_clear { if !should_clear {
return false; return false;
} }
{
let mut registry_guard = self.registry.lock().await;
registry_guard.pending_requests.clear();
registry_guard.active_subscriptions.clear();
}
let mut state_guard = self.state.write().await; let mut state_guard = self.state.write().await;
if *state_guard == crate::KbConnectionState::Disconnected { if *state_guard == crate::KbConnectionState::Disconnected {
return false; return false;
@@ -887,6 +1186,81 @@ fn kb_is_normal_close_error(error: &tokio_tungstenite::tungstenite::Error) -> bo
} }
} }
fn kb_json_value_to_u64(value: &serde_json::Value) -> std::option::Option<u64> {
value.as_u64()
}
fn kb_is_subscribe_method(method: &str) -> bool {
method.ends_with("Subscribe")
}
fn kb_is_unsubscribe_method(method: &str) -> bool {
method.ends_with("Unsubscribe")
}
fn kb_infer_unsubscribe_method_from_subscribe(
subscribe_method: &str,
) -> std::option::Option<std::string::String> {
if !kb_is_subscribe_method(subscribe_method) {
return None;
}
let base = subscribe_method.trim_end_matches("Subscribe");
Some(format!("{base}Unsubscribe"))
}
fn kb_infer_notification_method_from_subscribe(
subscribe_method: &str,
) -> std::option::Option<std::string::String> {
if !kb_is_subscribe_method(subscribe_method) {
return None;
}
let base = subscribe_method.trim_end_matches("Subscribe");
Some(format!("{base}Notification"))
}
fn kb_build_pending_json_rpc_request(
request: &crate::KbJsonRpcWsRequest,
) -> std::option::Option<WsPendingJsonRpcRequest> {
let request_id_option = kb_json_value_to_u64(&request.id);
let request_id = request_id_option?;
if kb_is_subscribe_method(&request.method) {
let notification_method_option =
kb_infer_notification_method_from_subscribe(&request.method);
let unsubscribe_method_option = kb_infer_unsubscribe_method_from_subscribe(&request.method);
let notification_method = notification_method_option?;
let unsubscribe_method = unsubscribe_method_option?;
return Some(WsPendingJsonRpcRequest {
request_id,
method: request.method.clone(),
kind: WsPendingJsonRpcRequestKind::Subscribe {
notification_method,
unsubscribe_method,
params: request.params.clone(),
},
});
}
if kb_is_unsubscribe_method(&request.method) {
let first_param_option = request.params.first();
let first_param = first_param_option?;
let subscription_id_option = first_param.as_u64();
let subscription_id = subscription_id_option?;
return Some(WsPendingJsonRpcRequest {
request_id,
method: request.method.clone(),
kind: WsPendingJsonRpcRequestKind::Unsubscribe {
subscription_id,
unsubscribe_method: request.method.clone(),
},
});
}
Some(WsPendingJsonRpcRequest {
request_id,
method: request.method.clone(),
kind: WsPendingJsonRpcRequestKind::Generic,
})
}
#[cfg(test)] #[cfg(test)]
mod tests { mod tests {
use futures_util::SinkExt; use futures_util::SinkExt;
@@ -896,10 +1270,13 @@ mod tests {
struct TestWsServer { struct TestWsServer {
url: std::string::String, url: std::string::String,
shutdown_tx: std::option::Option<tokio::sync::oneshot::Sender<()>>, shutdown_tx: std::option::Option<tokio::sync::oneshot::Sender<()>>,
observed_methods: std::sync::Arc<tokio::sync::Mutex<std::vec::Vec<std::string::String>>>,
} }
impl TestWsServer { impl TestWsServer {
async fn spawn_echo_server() -> Self { async fn spawn_echo_server() -> Self {
let observed_methods =
std::sync::Arc::new(tokio::sync::Mutex::new(std::vec::Vec::new()));
let bind_result = tokio::net::TcpListener::bind("127.0.0.1:0").await; let bind_result = tokio::net::TcpListener::bind("127.0.0.1:0").await;
let listener = bind_result.expect("listener bind must succeed"); let listener = bind_result.expect("listener bind must succeed");
let local_addr = listener.local_addr().expect("local addr must be available"); let local_addr = listener.local_addr().expect("local addr must be available");
@@ -962,14 +1339,18 @@ mod tests {
Self { Self {
url: format!("ws://{}", local_addr), url: format!("ws://{}", local_addr),
shutdown_tx: Some(shutdown_tx), shutdown_tx: Some(shutdown_tx),
observed_methods,
} }
} }
async fn spawn_json_rpc_server() -> Self { async fn spawn_json_rpc_server() -> Self {
let observed_methods =
std::sync::Arc::new(tokio::sync::Mutex::new(std::vec::Vec::new()));
let bind_result = tokio::net::TcpListener::bind("127.0.0.1:0").await; let bind_result = tokio::net::TcpListener::bind("127.0.0.1:0").await;
let listener = bind_result.expect("listener bind must succeed"); let listener = bind_result.expect("listener bind must succeed");
let local_addr = listener.local_addr().expect("local addr must be available"); let local_addr = listener.local_addr().expect("local addr must be available");
let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>(); let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>();
let observed_methods_for_server = observed_methods.clone();
tokio::spawn(async move { tokio::spawn(async move {
loop { loop {
tokio::select! { tokio::select! {
@@ -978,6 +1359,7 @@ mod tests {
}, },
accept_result = listener.accept() => { accept_result = listener.accept() => {
let (stream, _peer_addr) = accept_result.expect("accept must succeed"); let (stream, _peer_addr) = accept_result.expect("accept must succeed");
let observed_methods_for_connection = observed_methods_for_server.clone();
tokio::spawn(async move { tokio::spawn(async move {
let accept_ws_result = tokio_tungstenite::accept_async(stream).await; let accept_ws_result = tokio_tungstenite::accept_async(stream).await;
let mut ws_stream = accept_ws_result.expect("websocket accept must succeed"); let mut ws_stream = accept_ws_result.expect("websocket accept must succeed");
@@ -994,8 +1376,13 @@ mod tests {
tokio_tungstenite::tungstenite::Message::Text(text) => { tokio_tungstenite::tungstenite::Message::Text(text) => {
let value: serde_json::Value = serde_json::from_str(text.as_ref()) let value: serde_json::Value = serde_json::from_str(text.as_ref())
.expect("request json must parse"); .expect("request json must parse");
let method = value["method"].as_str().expect("method must be a string"); let method = value["method"].as_str().expect("method must be a string").to_string();
let id = value["id"].clone(); let id = value["id"].clone();
{
let mut observed_methods_guard =
observed_methods_for_connection.lock().await;
observed_methods_guard.push(method.clone());
}
if method == "slotSubscribe" { if method == "slotSubscribe" {
let response = serde_json::json!({ let response = serde_json::json!({
"jsonrpc": "2.0", "jsonrpc": "2.0",
@@ -1020,6 +1407,15 @@ mod tests {
ws_stream.send( ws_stream.send(
tokio_tungstenite::tungstenite::Message::Text(notification.to_string().into()) tokio_tungstenite::tungstenite::Message::Text(notification.to_string().into())
).await.expect("notification send must succeed"); ).await.expect("notification send must succeed");
} else if method == "slotUnsubscribe" {
let response = serde_json::json!({
"jsonrpc": "2.0",
"result": true,
"id": id
});
ws_stream.send(
tokio_tungstenite::tungstenite::Message::Text(response.to_string().into())
).await.expect("unsubscribe response send must succeed");
} else { } else {
let response = serde_json::json!({ let response = serde_json::json!({
"jsonrpc": "2.0", "jsonrpc": "2.0",
@@ -1058,9 +1454,14 @@ mod tests {
Self { Self {
url: format!("ws://{}", local_addr), url: format!("ws://{}", local_addr),
shutdown_tx: Some(shutdown_tx), shutdown_tx: Some(shutdown_tx),
observed_methods,
} }
} }
async fn observed_methods_snapshot(&self) -> std::vec::Vec<std::string::String> {
let observed_methods_guard = self.observed_methods.lock().await;
observed_methods_guard.clone()
}
async fn shutdown(mut self) { async fn shutdown(mut self) {
if let Some(shutdown_tx) = self.shutdown_tx.take() { if let Some(shutdown_tx) = self.shutdown_tx.take() {
let _ = shutdown_tx.send(()); let _ = shutdown_tx.send(());
@@ -1211,7 +1612,7 @@ mod tests {
} }
#[tokio::test] #[tokio::test]
async fn send_json_rpc_request_emits_success_response_and_notification() { async fn subscribe_registers_subscription_and_routes_notification() {
let server = TestWsServer::spawn_json_rpc_server().await; let server = TestWsServer::spawn_json_rpc_server().await;
let endpoint = make_ws_endpoint(server.url.clone()); let endpoint = make_ws_endpoint(server.url.clone());
let client = crate::WsClient::new(endpoint).expect("client creation must succeed"); let client = crate::WsClient::new(endpoint).expect("client creation must succeed");
@@ -1223,54 +1624,79 @@ mod tests {
.await .await
.expect("json-rpc send must succeed"); .expect("json-rpc send must succeed");
assert_eq!(request_id, 1); assert_eq!(request_id, 1);
let mut success_seen = false;
let mut notification_seen = false; let mut subscription_registered_seen = false;
for _ in 0..6 { let mut subscription_notification_seen = false;
for _ in 0..8 {
let event = recv_event(&mut receiver).await; let event = recv_event(&mut receiver).await;
match event { match event {
crate::WsEvent::JsonRpcMessage { crate::WsEvent::SubscriptionRegistered {
endpoint_name, endpoint_name,
message, subscription,
} => { } => {
assert_eq!(endpoint_name, "test_ws"); assert_eq!(endpoint_name, "test_ws");
match message { assert_eq!(subscription.request_id, 1);
crate::KbJsonRpcWsIncomingMessage::SuccessResponse(response) => { assert_eq!(subscription.subscription_id, 77);
assert_eq!(response.id, serde_json::Value::from(1u64)); assert_eq!(subscription.subscribe_method, "slotSubscribe");
assert_eq!(response.result, serde_json::Value::from(77u64)); assert_eq!(subscription.unsubscribe_method, "slotUnsubscribe");
success_seen = true; assert_eq!(subscription.notification_method, "slotNotification");
subscription_registered_seen = true;
} }
crate::KbJsonRpcWsIncomingMessage::Notification(notification) => { crate::WsEvent::SubscriptionNotification {
endpoint_name,
subscription,
notification,
method_matches_registry,
} => {
assert_eq!(endpoint_name, "test_ws");
assert_eq!(subscription.subscription_id, 77);
assert!(method_matches_registry);
assert_eq!(notification.method, "slotNotification"); assert_eq!(notification.method, "slotNotification");
assert_eq!(notification.params.subscription, 77); assert_eq!(notification.params.subscription, 77);
assert_eq!( assert_eq!(
notification.params.result["slot"], notification.params.result["slot"],
serde_json::Value::from(12u64) serde_json::Value::from(12u64)
); );
notification_seen = true; subscription_notification_seen = true;
}
crate::KbJsonRpcWsIncomingMessage::ErrorResponse(other) => {
panic!("unexpected error response: {other:?}");
}
}
} }
crate::WsEvent::TextMessage { .. } => {} crate::WsEvent::TextMessage { .. } => {}
crate::WsEvent::JsonRpcMessage { .. } => {}
other => { other => {
panic!("unexpected event: {other:?}"); panic!("unexpected event: {other:?}");
} }
} }
if subscription_registered_seen && subscription_notification_seen {
if success_seen && notification_seen {
break; break;
} }
} }
assert!(success_seen, "json-rpc success response must be observed"); assert!(
assert!(notification_seen, "json-rpc notification must be observed"); subscription_registered_seen,
"subscription must be registered"
);
assert!(
subscription_notification_seen,
"subscription notification must be routed"
);
assert_eq!(client.active_subscription_count().await, 1);
assert_eq!(client.pending_request_count().await, 0);
client.disconnect().await.expect("disconnect must succeed"); client.disconnect().await.expect("disconnect must succeed");
let observed_methods = server.observed_methods_snapshot().await;
assert!(
observed_methods
.iter()
.any(|method| method == "slotSubscribe")
);
assert!(
observed_methods
.iter()
.any(|method| method == "slotUnsubscribe")
);
server.shutdown().await; server.shutdown().await;
} }
#[tokio::test] #[tokio::test]
async fn send_unknown_json_rpc_method_emits_error_response() { async fn unknown_json_rpc_method_emits_error_and_clears_pending_request() {
let server = TestWsServer::spawn_json_rpc_server().await; let server = TestWsServer::spawn_json_rpc_server().await;
let endpoint = make_ws_endpoint(server.url.clone()); let endpoint = make_ws_endpoint(server.url.clone());
let client = crate::WsClient::new(endpoint).expect("client creation must succeed"); let client = crate::WsClient::new(endpoint).expect("client creation must succeed");
@@ -1283,7 +1709,7 @@ mod tests {
.expect("json-rpc send must succeed"); .expect("json-rpc send must succeed");
assert_eq!(request_id, 1); assert_eq!(request_id, 1);
let mut error_seen = false; let mut error_seen = false;
for _ in 0..4 { for _ in 0..6 {
let event = recv_event(&mut receiver).await; let event = recv_event(&mut receiver).await;
match event { match event {
crate::WsEvent::JsonRpcMessage { message, .. } => match message { crate::WsEvent::JsonRpcMessage { message, .. } => match message {
@@ -1308,6 +1734,7 @@ mod tests {
} }
} }
assert!(error_seen, "json-rpc error response must be observed"); assert!(error_seen, "json-rpc error response must be observed");
assert_eq!(client.pending_request_count().await, 0);
client.disconnect().await.expect("disconnect must succeed"); client.disconnect().await.expect("disconnect must succeed");
server.shutdown().await; server.shutdown().await;
} }