diff --git a/CHANGELOG.md b/CHANGELOG.md index 931d15a..bd7cdca 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,5 @@ 0.0.1 - initial skel -0.0.2 - Socle conforme \ No newline at end of file +0.0.2 - Socle conforme +0.1.0 - Transport WebSocket générique \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index 20fd36a..e419c6e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ members = [ ] [workspace.package] -version = "0.0.2" +version = "0.1.0" edition = "2024" license = "MIT" repository = "https://git.sasedev.com/Sasedev/khadhroony-bobobot" @@ -23,7 +23,7 @@ chacha20poly1305 = { version = "^0.10", features = ["std", "stream"] } chrono = { version = "^0.4", features = ["serde"] } fs2 = { version = "^0.4", features = [] } futures-util = { version = "^0.3", features = ["default", "std" ,"futures-sink"] } -jsonschema = { version = "^0.40", features = [] } +jsonschema = { version = "^0.46", features = [] } rand = { version = "^0.10", features = ["std", "serde", "sys_rng"] } reqwest = { version = "^0.13", default-features = false, features = ["charset", "cookies", "deflate", "form", "gzip", "http2", "json", "multipart", "query", "rustls", "socks", "stream", "zstd"] } rustls = { version = "^0.23", features = ["aws-lc-rs"] } diff --git a/kb_lib/src/config.rs b/kb_lib/src/config.rs index 511cde4..e01cf28 100644 --- a/kb_lib/src/config.rs +++ b/kb_lib/src/config.rs @@ -15,124 +15,6 @@ pub struct KbConfig { pub solana: KbSolanaConfig, } -/// Generic application settings. -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -pub struct KbAppConfig { - /// Human-readable application name. - pub name: std::string::String, - /// Current environment name such as `development` or `production`. - pub environment: std::string::String, - /// Default reconnection preference used by future UI settings. - pub auto_reconnect_default: bool, -} - -/// Logging and tracing configuration. -/// -/// In version `0.0.2`, the project actively uses: -/// `level`, `console_enabled`, `console_ansi`, `file_enabled`, -/// `directory`, `file_prefix`, and `rotation`. -/// -/// The fields `message_format` and `time_format` are already stored in the -/// configuration so that the format policy is stabilized early, even though -/// their handling will be refined in later versions. -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -pub struct KbLoggingConfig { - /// Global default log level. - pub level: std::string::String, - /// Enables console logging. - pub console_enabled: bool, - /// Enables ANSI colors on console output. - pub console_ansi: bool, - /// Enables file logging. - pub file_enabled: bool, - /// Directory where log files are stored. - pub directory: std::string::String, - /// Prefix used for log file names. - pub file_prefix: std::string::String, - /// File rotation strategy such as `daily`, `hourly`, or `never`. - pub rotation: std::string::String, - /// Preferred message formatting preset. - pub message_format: std::string::String, - /// Preferred time formatting preset. - pub time_format: std::string::String, - /// Per-target log level overrides. - pub target_filters: std::collections::BTreeMap, -} - -/// Local data paths used by the application. -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -pub struct KbDataConfig { - /// SQLite database path. - pub sqlite_path: std::string::String, - /// Directory storing Solana wallets and related material in future versions. - pub wallets_directory: std::string::String, -} - -/// Solana transport configuration. -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -pub struct KbSolanaConfig { - /// Named HTTP endpoints. - pub http_endpoints: std::vec::Vec, - /// Named WebSocket endpoints. - pub ws_endpoints: std::vec::Vec, -} - -/// HTTP endpoint configuration. -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -pub struct KbHttpEndpointConfig { - /// Stable internal endpoint name used by the application. - pub name: std::string::String, - /// Enables or disables the endpoint. - pub enabled: bool, - /// Provider name such as `solana-public`, `helius`, or `custom`. - pub provider: std::string::String, - /// Base HTTP RPC URL. - pub url: std::string::String, - /// Optional environment variable name used to resolve an API key later. - pub api_key_env_var: std::option::Option, - /// Logical roles assigned to this endpoint. - pub roles: std::vec::Vec, - /// Allowed average request rate. - pub requests_per_second: u32, - /// Burst capacity for future rate-limiting. - pub burst: u32, - /// HTTP connect timeout in milliseconds. - pub connect_timeout_ms: u64, - /// HTTP request timeout in milliseconds. - pub request_timeout_ms: u64, -} - -/// WebSocket endpoint configuration. -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -pub struct KbWsEndpointConfig { - /// Stable internal endpoint name used by the application. - pub name: std::string::String, - /// Enables or disables the endpoint. - pub enabled: bool, - /// Provider name such as `solana-public`, `helius`, or `custom`. - pub provider: std::string::String, - /// Base WebSocket RPC URL. - pub url: std::string::String, - /// Optional environment variable name used to resolve an API key later. - pub api_key_env_var: std::option::Option, - /// Logical roles assigned to this endpoint. - pub roles: std::vec::Vec, - /// Maximum number of subscriptions allowed on this endpoint. - pub max_subscriptions: u32, - /// WebSocket connect timeout in milliseconds. - pub connect_timeout_ms: u64, - /// Timeout for request/response round-trips in milliseconds. - pub request_timeout_ms: u64, - /// Timeout used during unsubscribe on disconnect in milliseconds. - pub unsubscribe_timeout_ms: u64, - /// Capacity of the future outgoing write channel. - pub write_channel_capacity: usize, - /// Capacity of the future event channel. - pub event_channel_capacity: usize, - /// Enables future automatic reconnection behavior. - pub auto_reconnect: bool, -} - impl KbConfig { /// Returns the default path of the JSON configuration file. pub fn default_path() -> std::path::PathBuf { @@ -419,6 +301,157 @@ impl KbConfig { } } +/// Generic application settings. +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct KbAppConfig { + /// Human-readable application name. + pub name: std::string::String, + /// Current environment name such as `development` or `production`. + pub environment: std::string::String, + /// Default reconnection preference used by future UI settings. + pub auto_reconnect_default: bool, +} + +/// Logging and tracing configuration. +/// +/// In version `0.0.2`, the project actively uses: +/// `level`, `console_enabled`, `console_ansi`, `file_enabled`, +/// `directory`, `file_prefix`, and `rotation`. +/// +/// The fields `message_format` and `time_format` are already stored in the +/// configuration so that the format policy is stabilized early, even though +/// their handling will be refined in later versions. +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct KbLoggingConfig { + /// Global default log level. + pub level: std::string::String, + /// Enables console logging. + pub console_enabled: bool, + /// Enables ANSI colors on console output. + pub console_ansi: bool, + /// Enables file logging. + pub file_enabled: bool, + /// Directory where log files are stored. + pub directory: std::string::String, + /// Prefix used for log file names. + pub file_prefix: std::string::String, + /// File rotation strategy such as `daily`, `hourly`, or `never`. + pub rotation: std::string::String, + /// Preferred message formatting preset. + pub message_format: std::string::String, + /// Preferred time formatting preset. + pub time_format: std::string::String, + /// Per-target log level overrides. + pub target_filters: std::collections::BTreeMap, +} + +impl KbLoggingConfig { + /// Returns the resolved logging directory path. + pub fn directory_path(&self) -> std::path::PathBuf { + kb_resolve_workspace_relative_path(&self.directory) + } +} + +/// Local data paths used by the application. +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct KbDataConfig { + /// SQLite database path. + pub sqlite_path: std::string::String, + /// Directory storing Solana wallets and related material in future versions. + pub wallets_directory: std::string::String, +} + +impl KbDataConfig { + /// Returns the resolved SQLite database path. + pub fn sqlite_path_buf(&self) -> std::path::PathBuf { + kb_resolve_workspace_relative_path(&self.sqlite_path) + } + + /// Returns the resolved wallets directory path. + pub fn wallets_directory_path(&self) -> std::path::PathBuf { + kb_resolve_workspace_relative_path(&self.wallets_directory) + } +} + +/// Solana transport configuration. +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct KbSolanaConfig { + /// Named HTTP endpoints. + pub http_endpoints: std::vec::Vec, + /// Named WebSocket endpoints. + pub ws_endpoints: std::vec::Vec, +} + +/// HTTP endpoint configuration. +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct KbHttpEndpointConfig { + /// Stable internal endpoint name used by the application. + pub name: std::string::String, + /// Enables or disables the endpoint. + pub enabled: bool, + /// Provider name such as `solana-public`, `helius`, or `custom`. + pub provider: std::string::String, + /// Base HTTP RPC URL. + pub url: std::string::String, + /// Optional environment variable name used to resolve an API key later. + pub api_key_env_var: std::option::Option, + /// Logical roles assigned to this endpoint. + pub roles: std::vec::Vec, + /// Allowed average request rate. + pub requests_per_second: u32, + /// Burst capacity for future rate-limiting. + pub burst: u32, + /// HTTP connect timeout in milliseconds. + pub connect_timeout_ms: u64, + /// HTTP request timeout in milliseconds. + pub request_timeout_ms: u64, +} + +impl KbHttpEndpointConfig { + /// Returns the resolved endpoint URL. + pub fn resolved_url(&self) -> Result { + kb_resolve_endpoint_url(&self.url, &self.api_key_env_var) + } +} + +/// WebSocket endpoint configuration. +#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] +pub struct KbWsEndpointConfig { + /// Stable internal endpoint name used by the application. + pub name: std::string::String, + /// Enables or disables the endpoint. + pub enabled: bool, + /// Provider name such as `solana-public`, `helius`, or `custom`. + pub provider: std::string::String, + /// Base WebSocket RPC URL. + pub url: std::string::String, + /// Optional environment variable name used to resolve an API key later. + pub api_key_env_var: std::option::Option, + /// Logical roles assigned to this endpoint. + pub roles: std::vec::Vec, + /// Maximum number of subscriptions allowed on this endpoint. + pub max_subscriptions: u32, + /// WebSocket connect timeout in milliseconds. + pub connect_timeout_ms: u64, + /// Timeout for request/response round-trips in milliseconds. + pub request_timeout_ms: u64, + /// Timeout used during unsubscribe on disconnect in milliseconds. + pub unsubscribe_timeout_ms: u64, + /// Capacity of the future outgoing write channel. + pub write_channel_capacity: usize, + /// Capacity of the future event channel. + pub event_channel_capacity: usize, + /// Enables future automatic reconnection behavior. + pub auto_reconnect: bool, +} + +impl KbWsEndpointConfig { + /// Returns the resolved endpoint URL. + pub fn resolved_url(&self) -> Result { + kb_resolve_endpoint_url(&self.url, &self.api_key_env_var) + } +} + fn kb_workspace_root_dir() -> std::path::PathBuf { let manifest_dir = std::path::PathBuf::from(env!("CARGO_MANIFEST_DIR")); match manifest_dir.parent() { @@ -435,21 +468,30 @@ fn kb_resolve_workspace_relative_path>(path: P) -> std kb_workspace_root_dir().join(input_path) } -impl KbLoggingConfig { - /// Returns the resolved logging directory path. - pub fn directory_path(&self) -> std::path::PathBuf { - kb_resolve_workspace_relative_path(&self.directory) - } -} - -impl KbDataConfig { - /// Returns the resolved SQLite database path. - pub fn sqlite_path_buf(&self) -> std::path::PathBuf { - kb_resolve_workspace_relative_path(&self.sqlite_path) - } - - /// Returns the resolved wallets directory path. - pub fn wallets_directory_path(&self) -> std::path::PathBuf { - kb_resolve_workspace_relative_path(&self.wallets_directory) +fn kb_resolve_endpoint_url( + url: &str, + api_key_env_var: &std::option::Option, +) -> Result { + let env_var_name_option = api_key_env_var.as_deref(); + let env_var_name = match env_var_name_option { + Some(env_var_name) => env_var_name, + None => { + return Ok(url.to_string()); + } + }; + let placeholder = format!("${{{env_var_name}}}"); + if !url.contains(&placeholder) { + return Ok(url.to_string()); } + let env_value_result = std::env::var(env_var_name); + let env_value = match env_value_result { + Ok(env_value) => env_value, + Err(error) => { + return Err(crate::KbError::Config(format!( + "environment variable '{}' is required to resolve endpoint url '{}': {error}", + env_var_name, url + ))); + } + }; + Ok(url.replace(&placeholder, &env_value)) } diff --git a/kb_lib/src/lib.rs b/kb_lib/src/lib.rs index 565bffe..1fb2a6e 100644 --- a/kb_lib/src/lib.rs +++ b/kb_lib/src/lib.rs @@ -32,3 +32,5 @@ pub use crate::tracing::KbTracingGuard; pub use crate::tracing::init_tracing; pub use crate::types::KbConnectionState; pub use crate::ws_client::WsClient; +pub use crate::ws_client::WsEvent; +pub use crate::ws_client::WsOutgoingMessage; diff --git a/kb_lib/src/ws_client.rs b/kb_lib/src/ws_client.rs index 2d38906..c1087bc 100644 --- a/kb_lib/src/ws_client.rs +++ b/kb_lib/src/ws_client.rs @@ -1,34 +1,165 @@ // file: kb_lib/src/ws_client.rs -//! Generic asynchronous WebSocket client skeleton. +//! Generic asynchronous WebSocket transport client. //! -//! This module prepares the shape of the future Solana WebSocket transport. -//! The actual transport loop, split read/write tasks, request tracking, -//! subscribe registry, and notification routing are scheduled for `0.1.x` -//! and `0.2.x` / `0.3.x`. +//! Version `0.1.x` provides a reusable transport-level client built on top of +//! `tokio-tungstenite`. +//! +//! Scope of this version: +//! - explicit connect / disconnect +//! - separate read and write tasks +//! - bounded outgoing channel +//! - broadcast event stream +//! - incremental request identifier generator +//! - graceful close with timeout and task cancellation fallback +//! +//! JSON-RPC request / response matching, subscribe / unsubscribe tracking, +//! and Solana-specific notification routing are intentionally left for later +//! versions. -/// Generic asynchronous WebSocket client placeholder. +use futures_util::SinkExt; +use futures_util::StreamExt; + +/// Outgoing transport-level WebSocket message. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum WsOutgoingMessage { + /// UTF-8 text message. + Text(std::string::String), + /// Binary message. + Binary(std::vec::Vec), + /// Ping message. + Ping(std::vec::Vec), + /// Pong message. + Pong(std::vec::Vec), + /// Close handshake initiation. + Close, +} + +/// Incoming WebSocket transport event emitted by [`crate::WsClient`]. +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum WsEvent { + /// Connection established successfully. + Connected { + /// Stable endpoint name from configuration. + endpoint_name: std::string::String, + /// Resolved endpoint URL actually used for the connection attempt. + endpoint_url: std::string::String, + }, + /// Text message received. + TextMessage { + /// Stable endpoint name from configuration. + endpoint_name: std::string::String, + /// Received text payload. + text: std::string::String, + }, + /// Binary message received. + BinaryMessage { + /// Stable endpoint name from configuration. + endpoint_name: std::string::String, + /// Received binary payload. + data: std::vec::Vec, + }, + /// Ping frame received. + Ping { + /// Stable endpoint name from configuration. + endpoint_name: std::string::String, + /// Ping payload. + data: std::vec::Vec, + }, + /// Pong frame received. + Pong { + /// Stable endpoint name from configuration. + endpoint_name: std::string::String, + /// Pong payload. + data: std::vec::Vec, + }, + /// Close frame received. + CloseReceived { + /// Stable endpoint name from configuration. + endpoint_name: std::string::String, + /// Optional close code. + code: std::option::Option, + /// Optional textual reason. + reason: std::option::Option, + }, + /// Connection lifecycle reached the disconnected state. + Disconnected { + /// Stable endpoint name from configuration. + endpoint_name: std::string::String, + }, + /// Transport-level error. + Error { + /// Stable endpoint name from configuration. + endpoint_name: std::string::String, + /// Captured error. + error: crate::KbError, + }, +} + +#[derive(Debug)] +struct WsClientRuntime { + generation: u64, + writer_tx: std::option::Option>, + shutdown_tx: std::option::Option>, + completion_notify: std::option::Option>, + read_abort_handle: std::option::Option, + write_abort_handle: std::option::Option, + supervisor_abort_handle: std::option::Option, +} + +impl WsClientRuntime { + fn new() -> Self { + Self { + generation: 0, + writer_tx: None, + shutdown_tx: None, + completion_notify: None, + read_abort_handle: None, + write_abort_handle: None, + supervisor_abort_handle: None, + } + } +} + +#[derive(Clone, Debug)] +enum WsWriteCommand { + Send(WsOutgoingMessage), +} + +/// Generic asynchronous WebSocket client. #[derive(Clone, Debug)] pub struct WsClient { endpoint: crate::KbWsEndpointConfig, + resolved_url: std::string::String, next_request_id: std::sync::Arc, state: std::sync::Arc>, + event_tx: tokio::sync::broadcast::Sender, + runtime: std::sync::Arc>, } impl WsClient { - /// Creates a new WebSocket client bound to a named endpoint configuration. + /// Creates a new WebSocket client bound to an endpoint configuration. pub fn new(endpoint: crate::KbWsEndpointConfig) -> Result { if endpoint.name.trim().is_empty() { return Err(crate::KbError::Config( "ws client endpoint name must not be empty".to_string(), )); } + let resolved_url_result = endpoint.resolved_url(); + let resolved_url = match resolved_url_result { + Ok(resolved_url) => resolved_url, + Err(error) => return Err(error), + }; + let (event_tx, _) = tokio::sync::broadcast::channel(endpoint.event_channel_capacity); Ok(Self { endpoint, + resolved_url, next_request_id: std::sync::Arc::new(std::sync::atomic::AtomicU64::new(1)), state: std::sync::Arc::new(tokio::sync::RwLock::new( crate::KbConnectionState::Disconnected, )), + event_tx, + runtime: std::sync::Arc::new(tokio::sync::Mutex::new(WsClientRuntime::new())), }) } @@ -37,9 +168,9 @@ impl WsClient { &self.endpoint.name } - /// Returns the endpoint URL of this client. + /// Returns the resolved endpoint URL of this client. pub fn endpoint_url(&self) -> &str { - &self.endpoint.url + &self.resolved_url } /// Returns the endpoint configuration of this client. @@ -47,6 +178,11 @@ impl WsClient { &self.endpoint } + /// Returns a new broadcast receiver subscribed to transport events. + pub fn subscribe_events(&self) -> tokio::sync::broadcast::Receiver { + self.event_tx.subscribe() + } + /// Returns the next request identifier and increments the internal counter. pub fn next_request_id(&self) -> u64 { self.next_request_id @@ -61,31 +197,867 @@ impl WsClient { /// Connects the client to its remote WebSocket endpoint. pub async fn connect(&self) -> Result<(), crate::KbError> { - Err(crate::KbError::NotImplemented( - "WsClient::connect is scheduled for version 0.1.x".to_string(), - )) + if !self.endpoint.enabled { + return Err(crate::KbError::InvalidState(format!( + "ws endpoint '{}' is disabled in configuration", + self.endpoint.name + ))); + } + let current_state = self.connection_state().await; + if current_state != crate::KbConnectionState::Disconnected { + return Err(crate::KbError::InvalidState(format!( + "ws client '{}' cannot connect from state {:?}", + self.endpoint.name, current_state + ))); + } + { + let mut state_guard = self.state.write().await; + *state_guard = crate::KbConnectionState::Connecting; + } + tracing::info!( + endpoint_name = %self.endpoint.name, + endpoint_url = %self.resolved_url, + "connecting websocket client" + ); + let connect_timeout = std::time::Duration::from_millis(self.endpoint.connect_timeout_ms); + let connect_future = tokio_tungstenite::connect_async(self.resolved_url.clone()); + let timeout_result = tokio::time::timeout(connect_timeout, connect_future).await; + let connect_result = match timeout_result { + Ok(connect_result) => connect_result, + Err(_) => { + let error = crate::KbError::Ws(format!( + "timeout while connecting websocket endpoint '{}'", + self.endpoint.name + )); + self.emit_event(WsEvent::Error { + endpoint_name: self.endpoint.name.clone(), + error: error.clone(), + }); + let mut state_guard = self.state.write().await; + *state_guard = crate::KbConnectionState::Disconnected; + return Err(error); + } + }; + let (ws_stream, _response) = match connect_result { + Ok(parts) => parts, + Err(error) => { + let kb_error = crate::KbError::Ws(format!( + "cannot connect websocket endpoint '{}': {error}", + self.endpoint.name + )); + self.emit_event(WsEvent::Error { + endpoint_name: self.endpoint.name.clone(), + error: kb_error.clone(), + }); + let mut state_guard = self.state.write().await; + *state_guard = crate::KbConnectionState::Disconnected; + return Err(kb_error); + } + }; + + let (write_half, read_half) = ws_stream.split(); + let (writer_tx, writer_rx) = + tokio::sync::mpsc::channel(self.endpoint.write_channel_capacity); + let (shutdown_tx, shutdown_rx) = tokio::sync::watch::channel(false); + let completion_notify = std::sync::Arc::new(tokio::sync::Notify::new()); + let read_client = self.clone(); + let read_shutdown_tx = shutdown_tx.clone(); + let read_handle = tokio::spawn(async move { + read_client.run_read_loop(read_half, read_shutdown_tx).await; + }); + let read_abort_handle = read_handle.abort_handle(); + let write_client = self.clone(); + let write_handle = tokio::spawn(async move { + write_client + .run_write_loop(write_half, writer_rx, shutdown_rx) + .await; + }); + let write_abort_handle = write_handle.abort_handle(); + let generation = { + let mut runtime_guard = self.runtime.lock().await; + if runtime_guard.writer_tx.is_some() { + let mut state_guard = self.state.write().await; + *state_guard = crate::KbConnectionState::Disconnected; + return Err(crate::KbError::InvalidState(format!( + "ws client '{}' runtime is already active", + self.endpoint.name + ))); + } + runtime_guard.generation = runtime_guard.generation.saturating_add(1); + runtime_guard.writer_tx = Some(writer_tx.clone()); + runtime_guard.shutdown_tx = Some(shutdown_tx.clone()); + runtime_guard.completion_notify = Some(completion_notify.clone()); + runtime_guard.read_abort_handle = Some(read_abort_handle); + runtime_guard.write_abort_handle = Some(write_abort_handle); + runtime_guard.generation + }; + let supervisor_client = self.clone(); + let supervisor_notify = completion_notify.clone(); + let supervisor_handle = tokio::spawn(async move { + supervisor_client + .run_supervisor(generation, read_handle, write_handle, supervisor_notify) + .await; + }); + let supervisor_abort_handle = supervisor_handle.abort_handle(); + { + let mut runtime_guard = self.runtime.lock().await; + if runtime_guard.generation == generation { + runtime_guard.supervisor_abort_handle = Some(supervisor_abort_handle); + } + } + { + let mut state_guard = self.state.write().await; + *state_guard = crate::KbConnectionState::Connected; + } + self.emit_event(WsEvent::Connected { + endpoint_name: self.endpoint.name.clone(), + endpoint_url: self.resolved_url.clone(), + }); + tracing::info!( + endpoint_name = %self.endpoint.name, + endpoint_url = %self.resolved_url, + "websocket client connected" + ); + Ok(()) } - /// Sends a text frame through the WebSocket connection. - pub async fn send_text(&self, _text: std::string::String) -> Result<(), crate::KbError> { - Err(crate::KbError::NotImplemented( - "WsClient::send_text is scheduled for version 0.1.x".to_string(), - )) + /// Sends a transport-level WebSocket message. + pub async fn send_message(&self, message: WsOutgoingMessage) -> Result<(), crate::KbError> { + let writer_tx_option = { + let runtime_guard = self.runtime.lock().await; + runtime_guard.writer_tx.clone() + }; + let writer_tx = match writer_tx_option { + Some(writer_tx) => writer_tx, + None => { + return Err(crate::KbError::NotConnected(format!( + "ws client '{}' is not connected", + self.endpoint.name + ))); + } + }; + let queue_timeout = std::time::Duration::from_millis(self.endpoint.request_timeout_ms); + let send_future = writer_tx.send(WsWriteCommand::Send(message)); + let timeout_result = tokio::time::timeout(queue_timeout, send_future).await; + match timeout_result { + Ok(send_result) => match send_result { + Ok(()) => Ok(()), + Err(error) => Err(crate::KbError::Ws(format!( + "cannot queue outgoing websocket message for endpoint '{}': {error}", + self.endpoint.name + ))), + }, + Err(_) => Err(crate::KbError::Ws(format!( + "timeout while queueing outgoing websocket message for endpoint '{}'", + self.endpoint.name + ))), + } } - /// Sends a JSON value through the WebSocket connection. - pub async fn send_json_value(&self, _value: &serde_json::Value) -> Result<(), crate::KbError> { - Err(crate::KbError::NotImplemented( - "WsClient::send_json_value is scheduled for version 0.2.x".to_string(), - )) + /// Sends a UTF-8 text message. + pub async fn send_text(&self, text: std::string::String) -> Result<(), crate::KbError> { + self.send_message(WsOutgoingMessage::Text(text)).await + } + + /// Sends a binary message. + pub async fn send_binary(&self, data: std::vec::Vec) -> Result<(), crate::KbError> { + self.send_message(WsOutgoingMessage::Binary(data)).await + } + + /// Sends a ping message. + pub async fn send_ping(&self, data: std::vec::Vec) -> Result<(), crate::KbError> { + self.send_message(WsOutgoingMessage::Ping(data)).await + } + + /// Sends a pong message. + pub async fn send_pong(&self, data: std::vec::Vec) -> Result<(), crate::KbError> { + self.send_message(WsOutgoingMessage::Pong(data)).await + } + + /// Serializes and sends a JSON value as a text message. + pub async fn send_json_value(&self, value: &serde_json::Value) -> Result<(), crate::KbError> { + let serialization_result = serde_json::to_string(value); + let text = match serialization_result { + Ok(text) => text, + Err(error) => { + return Err(crate::KbError::Json(format!( + "cannot serialize websocket json payload for endpoint '{}': {error}", + self.endpoint.name + ))); + } + }; + self.send_text(text).await + } + + /// Initiates the close handshake. + pub async fn send_close(&self) -> Result<(), crate::KbError> { + self.send_message(WsOutgoingMessage::Close).await } /// Disconnects the client from its remote endpoint. /// - /// The final implementation will unsubscribe with timeout before close. + /// This method initiates a close handshake, signals shutdown, waits for the + /// transport tasks to complete, and aborts them if the timeout is exceeded. pub async fn disconnect(&self) -> Result<(), crate::KbError> { - Err(crate::KbError::NotImplemented( - "WsClient::disconnect is scheduled for version 0.1.x / 0.3.x".to_string(), - )) + let current_state = self.connection_state().await; + if current_state == crate::KbConnectionState::Disconnected { + return Ok(()); + } + if current_state == crate::KbConnectionState::Connecting { + return Err(crate::KbError::InvalidState(format!( + "ws client '{}' cannot disconnect while still connecting", + self.endpoint.name + ))); + } + { + let mut state_guard = self.state.write().await; + *state_guard = crate::KbConnectionState::Disconnecting; + } + tracing::info!( + endpoint_name = %self.endpoint.name, + "disconnecting websocket client" + ); + let ( + generation, + writer_tx_option, + shutdown_tx_option, + completion_notify_option, + read_abort_handle_option, + write_abort_handle_option, + supervisor_abort_handle_option, + ) = { + let runtime_guard = self.runtime.lock().await; + ( + runtime_guard.generation, + runtime_guard.writer_tx.clone(), + runtime_guard.shutdown_tx.clone(), + runtime_guard.completion_notify.clone(), + runtime_guard.read_abort_handle.clone(), + runtime_guard.write_abort_handle.clone(), + runtime_guard.supervisor_abort_handle.clone(), + ) + }; + if let Some(writer_tx) = writer_tx_option { + let close_timeout = std::time::Duration::from_millis(self.endpoint.request_timeout_ms); + let close_send_future = writer_tx.send(WsWriteCommand::Send(WsOutgoingMessage::Close)); + let close_send_timeout_result = + tokio::time::timeout(close_timeout, close_send_future).await; + match close_send_timeout_result { + Ok(close_send_result) => { + if let Err(error) = close_send_result { + tracing::warn!( + endpoint_name = %self.endpoint.name, + "cannot queue close frame during disconnect: {error}" + ); + } + } + Err(_) => { + tracing::warn!( + endpoint_name = %self.endpoint.name, + "timeout while queueing close frame during disconnect" + ); + } + } + } + if let Some(shutdown_tx) = shutdown_tx_option { + let send_result = shutdown_tx.send(true); + if let Err(error) = send_result { + tracing::debug!( + endpoint_name = %self.endpoint.name, + "shutdown signal could not be delivered because receiver is already gone: {error}" + ); + } + } + let completion_notify = match completion_notify_option { + Some(completion_notify) => completion_notify, + None => { + let state_changed = self.clear_runtime_after_generation(generation).await; + if state_changed { + self.emit_event(WsEvent::Disconnected { + endpoint_name: self.endpoint.name.clone(), + }); + } + return Ok(()); + } + }; + let completion_timeout = std::time::Duration::from_millis(self.endpoint.request_timeout_ms); + let notified_future = completion_notify.notified(); + let completion_result = tokio::time::timeout(completion_timeout, notified_future).await; + match completion_result { + Ok(()) => Ok(()), + Err(_) => { + tracing::warn!( + endpoint_name = %self.endpoint.name, + "disconnect timeout reached, aborting websocket tasks" + ); + if let Some(read_abort_handle) = read_abort_handle_option { + read_abort_handle.abort(); + } + if let Some(write_abort_handle) = write_abort_handle_option { + write_abort_handle.abort(); + } + if let Some(supervisor_abort_handle) = supervisor_abort_handle_option { + supervisor_abort_handle.abort(); + } + let state_changed = self.clear_runtime_after_generation(generation).await; + if state_changed { + self.emit_event(WsEvent::Disconnected { + endpoint_name: self.endpoint.name.clone(), + }); + } + Ok(()) + } + } + } + + async fn run_supervisor( + &self, + generation: u64, + read_handle: tokio::task::JoinHandle<()>, + write_handle: tokio::task::JoinHandle<()>, + completion_notify: std::sync::Arc, + ) { + let read_join_result = read_handle.await; + if let Err(error) = read_join_result { + let kb_error = crate::KbError::Ws(format!( + "read task for endpoint '{}' ended with join error: {error}", + self.endpoint.name + )); + self.emit_event(WsEvent::Error { + endpoint_name: self.endpoint.name.clone(), + error: kb_error, + }); + } + let write_join_result = write_handle.await; + if let Err(error) = write_join_result { + let kb_error = crate::KbError::Ws(format!( + "write task for endpoint '{}' ended with join error: {error}", + self.endpoint.name + )); + self.emit_event(WsEvent::Error { + endpoint_name: self.endpoint.name.clone(), + error: kb_error, + }); + } + let state_changed = self.clear_runtime_after_generation(generation).await; + completion_notify.notify_waiters(); + if state_changed { + self.emit_event(WsEvent::Disconnected { + endpoint_name: self.endpoint.name.clone(), + }); + } + } + + async fn run_read_loop( + &self, + mut read_half: futures_util::stream::SplitStream>, + shutdown_tx: tokio::sync::watch::Sender, + ) where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + std::marker::Send + 'static, + { + loop { + let message_option = read_half.next().await; + let message_result = match message_option { + Some(message_result) => message_result, + None => { + let _ = shutdown_tx.send(true); + break; + } + }; + let message = match message_result { + Ok(message) => message, + Err(error) => { + if !kb_is_normal_close_error(&error) { + self.emit_event(WsEvent::Error { + endpoint_name: self.endpoint.name.clone(), + error: crate::KbError::Ws(format!( + "read error on endpoint '{}': {error}", + self.endpoint.name + )), + }); + } + let _ = shutdown_tx.send(true); + break; + } + }; + match message { + tokio_tungstenite::tungstenite::Message::Text(text) => { + self.emit_event(WsEvent::TextMessage { + endpoint_name: self.endpoint.name.clone(), + text: text.to_string(), + }); + } + tokio_tungstenite::tungstenite::Message::Binary(data) => { + self.emit_event(WsEvent::BinaryMessage { + endpoint_name: self.endpoint.name.clone(), + data: data.to_vec(), + }); + } + tokio_tungstenite::tungstenite::Message::Ping(data) => { + self.emit_event(WsEvent::Ping { + endpoint_name: self.endpoint.name.clone(), + data: data.to_vec(), + }); + } + tokio_tungstenite::tungstenite::Message::Pong(data) => { + self.emit_event(WsEvent::Pong { + endpoint_name: self.endpoint.name.clone(), + data: data.to_vec(), + }); + } + tokio_tungstenite::tungstenite::Message::Close(frame_option) => { + let mut code_option = None; + let mut reason_option = None; + if let Some(frame) = frame_option { + let code: u16 = frame.code.into(); + code_option = Some(code); + if !frame.reason.is_empty() { + reason_option = Some(frame.reason.to_string()); + } + } + self.emit_event(WsEvent::CloseReceived { + endpoint_name: self.endpoint.name.clone(), + code: code_option, + reason: reason_option, + }); + let _ = shutdown_tx.send(true); + break; + } + tokio_tungstenite::tungstenite::Message::Frame(_frame) => { + tracing::trace!( + endpoint_name = %self.endpoint.name, + "ignoring internal tungstenite frame variant" + ); + } + } + } + } + + async fn run_write_loop( + &self, + mut write_half: futures_util::stream::SplitSink< + tokio_tungstenite::WebSocketStream, + tokio_tungstenite::tungstenite::Message, + >, + mut writer_rx: tokio::sync::mpsc::Receiver, + mut shutdown_rx: tokio::sync::watch::Receiver, + ) where + S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + std::marker::Send + 'static, + { + let close_sent = false; + loop { + tokio::select! { + shutdown_result = shutdown_rx.changed() => { + match shutdown_result { + Ok(()) => { + if *shutdown_rx.borrow() { + if !close_sent { + let close_result = write_half.send( + tokio_tungstenite::tungstenite::Message::Close(None) + ).await; + match close_result { + Ok(()) => {}, + Err(error) => { + if !kb_is_normal_close_error(&error) { + self.emit_event(WsEvent::Error { + endpoint_name: self.endpoint.name.clone(), + error: crate::KbError::Ws(format!( + "write close error on endpoint '{}': {error}", + self.endpoint.name + )), + }); + } + }, + } + } + let flush_result = write_half.flush().await; + if let Err(error) = flush_result { + if !kb_is_normal_close_error(&error) { + self.emit_event(WsEvent::Error { + endpoint_name: self.endpoint.name.clone(), + error: crate::KbError::Ws(format!( + "write flush error on endpoint '{}': {error}", + self.endpoint.name + )), + }); + } + } + let close_sink_result = write_half.close().await; + if let Err(error) = close_sink_result { + if !kb_is_normal_close_error(&error) { + self.emit_event(WsEvent::Error { + endpoint_name: self.endpoint.name.clone(), + error: crate::KbError::Ws(format!( + "sink close error on endpoint '{}': {error}", + self.endpoint.name + )), + }); + } + } + break; + } + }, + Err(_error) => { + break; + }, + } + }, + command_option = writer_rx.recv() => { + let command = match command_option { + Some(command) => command, + None => { + let flush_result = write_half.flush().await; + if let Err(error) = flush_result { + if !kb_is_normal_close_error(&error) { + self.emit_event(WsEvent::Error { + endpoint_name: self.endpoint.name.clone(), + error: crate::KbError::Ws(format!( + "write flush error on endpoint '{}': {error}", + self.endpoint.name + )), + }); + } + } + let close_sink_result = write_half.close().await; + if let Err(error) = close_sink_result { + if !kb_is_normal_close_error(&error) { + self.emit_event(WsEvent::Error { + endpoint_name: self.endpoint.name.clone(), + error: crate::KbError::Ws(format!( + "sink close error on endpoint '{}': {error}", + self.endpoint.name + )), + }); + } + } + break; + }, + }; + match command { + WsWriteCommand::Send(message) => { + let tungstenite_message = kb_convert_outgoing_message(message.clone()); + let send_result = write_half.send(tungstenite_message).await; + match send_result { + Ok(()) => { + if message == WsOutgoingMessage::Close { + let flush_result = write_half.flush().await; + if let Err(error) = flush_result { + if !kb_is_normal_close_error(&error) { + self.emit_event(WsEvent::Error { + endpoint_name: self.endpoint.name.clone(), + error: crate::KbError::Ws(format!( + "write flush error on endpoint '{}': {error}", + self.endpoint.name + )), + }); + } + } + break; + } + }, + Err(error) => { + if !kb_is_normal_close_error(&error) { + self.emit_event(WsEvent::Error { + endpoint_name: self.endpoint.name.clone(), + error: crate::KbError::Ws(format!( + "write error on endpoint '{}': {error}", + self.endpoint.name + )), + }); + } + break; + }, + } + }, + } + }, + } + } + } + + async fn clear_runtime_after_generation(&self, generation: u64) -> bool { + let should_clear = { + let mut runtime_guard = self.runtime.lock().await; + if runtime_guard.generation != generation { + false + } else { + runtime_guard.writer_tx = None; + runtime_guard.shutdown_tx = None; + runtime_guard.completion_notify = None; + runtime_guard.read_abort_handle = None; + runtime_guard.write_abort_handle = None; + runtime_guard.supervisor_abort_handle = None; + true + } + }; + if !should_clear { + return false; + } + let mut state_guard = self.state.write().await; + if *state_guard == crate::KbConnectionState::Disconnected { + return false; + } + *state_guard = crate::KbConnectionState::Disconnected; + true + } + + fn emit_event(&self, event: WsEvent) { + let send_result = self.event_tx.send(event); + if let Err(error) = send_result { + tracing::trace!( + endpoint_name = %self.endpoint.name, + "websocket event dropped because no receiver is currently subscribed: {error}" + ); + } + } +} + +fn kb_convert_outgoing_message( + message: WsOutgoingMessage, +) -> tokio_tungstenite::tungstenite::Message { + match message { + WsOutgoingMessage::Text(text) => tokio_tungstenite::tungstenite::Message::Text(text.into()), + WsOutgoingMessage::Binary(data) => { + tokio_tungstenite::tungstenite::Message::Binary(data.into()) + } + WsOutgoingMessage::Ping(data) => tokio_tungstenite::tungstenite::Message::Ping(data.into()), + WsOutgoingMessage::Pong(data) => tokio_tungstenite::tungstenite::Message::Pong(data.into()), + WsOutgoingMessage::Close => tokio_tungstenite::tungstenite::Message::Close(None), + } +} + +fn kb_is_normal_close_error(error: &tokio_tungstenite::tungstenite::Error) -> bool { + match error { + tokio_tungstenite::tungstenite::Error::ConnectionClosed => true, + tokio_tungstenite::tungstenite::Error::AlreadyClosed => true, + _ => false, + } +} + +#[cfg(test)] +mod tests { + use futures_util::SinkExt; + use futures_util::StreamExt; + + #[derive(Debug)] + struct TestWsServer { + url: std::string::String, + shutdown_tx: std::option::Option>, + } + + impl TestWsServer { + async fn spawn() -> Self { + let bind_result = tokio::net::TcpListener::bind("127.0.0.1:0").await; + let listener = bind_result.expect("listener bind must succeed"); + let local_addr = listener.local_addr().expect("local addr must be available"); + let (shutdown_tx, mut shutdown_rx) = tokio::sync::oneshot::channel::<()>(); + tokio::spawn(async move { + loop { + tokio::select! { + _ = &mut shutdown_rx => { + break; + }, + accept_result = listener.accept() => { + let (stream, _peer_addr) = accept_result.expect("accept must succeed"); + tokio::spawn(async move { + let accept_ws_result = tokio_tungstenite::accept_async(stream).await; + let mut ws_stream = accept_ws_result.expect("websocket accept must succeed"); + loop { + let message_option = ws_stream.next().await; + let message_result = match message_option { + Some(message_result) => message_result, + None => { + break; + }, + }; + let message = message_result.expect("message read must succeed"); + match message { + tokio_tungstenite::tungstenite::Message::Text(text) => { + ws_stream.send( + tokio_tungstenite::tungstenite::Message::Text(text) + ).await.expect("text echo must succeed"); + }, + tokio_tungstenite::tungstenite::Message::Binary(data) => { + ws_stream.send( + tokio_tungstenite::tungstenite::Message::Binary(data) + ).await.expect("binary echo must succeed"); + }, + tokio_tungstenite::tungstenite::Message::Ping(data) => { + ws_stream.send( + tokio_tungstenite::tungstenite::Message::Pong(data) + ).await.expect("pong reply must succeed"); + }, + tokio_tungstenite::tungstenite::Message::Pong(data) => { + ws_stream.send( + tokio_tungstenite::tungstenite::Message::Pong(data) + ).await.expect("pong echo must succeed"); + }, + tokio_tungstenite::tungstenite::Message::Close(frame) => { + let _ = ws_stream.send( + tokio_tungstenite::tungstenite::Message::Close(frame) + ).await; + break; + }, + tokio_tungstenite::tungstenite::Message::Frame(_frame) => {}, + } + } + }); + }, + } + } + }); + Self { + url: format!("ws://{}", local_addr), + shutdown_tx: Some(shutdown_tx), + } + } + + async fn shutdown(mut self) { + if let Some(shutdown_tx) = self.shutdown_tx.take() { + let _ = shutdown_tx.send(()); + } + } + } + + fn make_ws_endpoint(url: std::string::String) -> crate::KbWsEndpointConfig { + crate::KbWsEndpointConfig { + name: "test_ws".to_string(), + enabled: true, + provider: "test".to_string(), + url, + api_key_env_var: None, + roles: vec!["test".to_string()], + max_subscriptions: 16, + connect_timeout_ms: 2000, + request_timeout_ms: 2000, + unsubscribe_timeout_ms: 1000, + write_channel_capacity: 32, + event_channel_capacity: 64, + auto_reconnect: false, + } + } + + async fn recv_event( + receiver: &mut tokio::sync::broadcast::Receiver, + ) -> crate::WsEvent { + let timeout_result = + tokio::time::timeout(std::time::Duration::from_secs(2), receiver.recv()).await; + let recv_result = timeout_result.expect("event receive timeout must not occur"); + recv_result.expect("event receive must succeed") + } + + #[tokio::test] + async fn next_request_id_is_shared_between_clones() { + let endpoint = make_ws_endpoint("ws://127.0.0.1:65535".to_string()); + let client = crate::WsClient::new(endpoint).expect("client creation must succeed"); + let cloned = client.clone(); + assert_eq!(client.next_request_id(), 1); + assert_eq!(cloned.next_request_id(), 2); + assert_eq!(client.next_request_id(), 3); + } + + #[tokio::test] + async fn connect_send_text_and_disconnect() { + let server = TestWsServer::spawn().await; + let endpoint = make_ws_endpoint(server.url.clone()); + let client = crate::WsClient::new(endpoint).expect("client creation must succeed"); + let mut receiver = client.subscribe_events(); + client.connect().await.expect("connect must succeed"); + let connected_event = recv_event(&mut receiver).await; + match connected_event { + crate::WsEvent::Connected { + endpoint_name, + endpoint_url, + } => { + assert_eq!(endpoint_name, "test_ws"); + assert_eq!(endpoint_url, server.url); + } + other => { + panic!("unexpected connected event: {other:?}"); + } + } + client + .send_text("hello".to_string()) + .await + .expect("text send must succeed"); + let text_event = recv_event(&mut receiver).await; + match text_event { + crate::WsEvent::TextMessage { + endpoint_name, + text, + } => { + assert_eq!(endpoint_name, "test_ws"); + assert_eq!(text, "hello"); + } + other => { + panic!("unexpected text event: {other:?}"); + } + } + client.disconnect().await.expect("disconnect must succeed"); + assert_eq!( + client.connection_state().await, + crate::KbConnectionState::Disconnected + ); + let mut disconnected_seen = false; + for _ in 0..4 { + let event = recv_event(&mut receiver).await; + if let crate::WsEvent::Disconnected { endpoint_name } = event { + assert_eq!(endpoint_name, "test_ws"); + disconnected_seen = true; + break; + } + } + assert!(disconnected_seen, "disconnected event must be observed"); + server.shutdown().await; + } + + #[tokio::test] + async fn connect_twice_returns_invalid_state() { + let server = TestWsServer::spawn().await; + let endpoint = make_ws_endpoint(server.url.clone()); + let client = crate::WsClient::new(endpoint).expect("client creation must succeed"); + client.connect().await.expect("first connect must succeed"); + let second_connect_result = client.connect().await; + assert!(second_connect_result.is_err()); + let error = second_connect_result.expect_err("second connect must fail"); + match error { + crate::KbError::InvalidState(message) => { + assert!(message.contains("cannot connect")); + } + other => { + panic!("unexpected error variant: {other:?}"); + } + } + client.disconnect().await.expect("disconnect must succeed"); + server.shutdown().await; + } + + #[tokio::test] + async fn send_ping_receives_pong_event() { + let server = TestWsServer::spawn().await; + let endpoint = make_ws_endpoint(server.url.clone()); + let client = crate::WsClient::new(endpoint).expect("client creation must succeed"); + let mut receiver = client.subscribe_events(); + client.connect().await.expect("connect must succeed"); + let _ = recv_event(&mut receiver).await; + client + .send_ping(vec![1, 2, 3, 4]) + .await + .expect("ping send must succeed"); + let event = recv_event(&mut receiver).await; + match event { + crate::WsEvent::Pong { + endpoint_name, + data, + } => { + assert_eq!(endpoint_name, "test_ws"); + assert_eq!(data, vec![1, 2, 3, 4]); + } + other => { + panic!("unexpected event: {other:?}"); + } + } + client.disconnect().await.expect("disconnect must succeed"); + server.shutdown().await; } }