diff --git a/Cargo.toml b/Cargo.toml index aacc02204..47213916f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -57,6 +57,7 @@ console-subscriber = "0.1" env_logger = "0.11" futures = "0.3" futures-util = { version = "0.3", default-features = false } +flate2 = { version = "1.0", default-features = false, features = ["rust_backend"] } lazy_static = "1.4" log = "0.4" parking_lot = "0.12" diff --git a/livekit-api/src/signal_client/mod.rs b/livekit-api/src/signal_client/mod.rs index 32867d17a..a1e7ae36a 100644 --- a/livekit-api/src/signal_client/mod.rs +++ b/livekit-api/src/signal_client/mod.rs @@ -565,6 +565,7 @@ fn create_join_request_param( sdk: proto::client_info::Sdk::Rust as i32, version: options.sdk_options.sdk_version.clone().unwrap_or_default(), protocol: PROTOCOL_VERSION as i32, + client_protocol: proto::RPC_GZIP_CLIENT_PROTOCOL, os: std::env::consts::OS.to_string(), ..Default::default() }; @@ -644,13 +645,16 @@ fn get_livekit_url( create_join_request_param(options, reconnect, reconnect_reason, participant_sid); lk_url.query_pairs_mut().append_pair("join_request", &join_request_param); } else { + let client_protocol = proto::RPC_GZIP_CLIENT_PROTOCOL.to_string(); // For v0 path (dual PC mode): use URL query parameters lk_url .query_pairs_mut() .append_pair("sdk", options.sdk_options.sdk.as_str()) .append_pair("protocol", PROTOCOL_VERSION.to_string().as_str()) .append_pair("auto_subscribe", if options.auto_subscribe { "1" } else { "0" }) - .append_pair("adaptive_stream", if options.adaptive_stream { "1" } else { "0" }); + .append_pair("adaptive_stream", if options.adaptive_stream { "1" } else { "0" }) + // `client_protocol=1` indicates support for gzip RPC compression. + .append_pair("client_protocol", &client_protocol); if let Some(sdk_version) = &options.sdk_options.sdk_version { lk_url.query_pairs_mut().append_pair("version", sdk_version.as_str()); diff --git a/livekit-api/src/signal_client/signal_stream.rs b/livekit-api/src/signal_client/signal_stream.rs index 89c429fdc..4025f492f 100644 --- a/livekit-api/src/signal_client/signal_stream.rs +++ b/livekit-api/src/signal_client/signal_stream.rs @@ -63,6 +63,10 @@ enum InternalMessage { signal: proto::signal_request::Message, response_chn: oneshot::Sender>, }, + RawBytes { + data: Vec, + response_chn: oneshot::Sender>, + }, Pong { ping_data: Vec, }, @@ -348,6 +352,15 @@ impl SignalStream { recv.await.map_err(|_| SignalError::SendError)? } + /// Send raw bytes to the websocket (for JoinRequest, etc.) + /// It also waits for the message to be sent + pub async fn send_raw(&self, data: Vec) -> SignalResult<()> { + let (send, recv) = oneshot::channel(); + let msg = InternalMessage::RawBytes { data, response_chn: send }; + let _ = self.internal_tx.send(msg).await; + recv.await.map_err(|_| SignalError::SendError)? + } + /// This task is used to send messages to the websocket /// It is also responsible for closing the connection async fn write_task( @@ -366,6 +379,14 @@ impl SignalStream { let _ = response_chn.send(Ok(())); } + InternalMessage::RawBytes { data, response_chn } => { + if let Err(err) = ws_writer.send(Message::Binary(data)).await { + let _ = response_chn.send(Err(err.into())); + break; + } + + let _ = response_chn.send(Ok(())); + } InternalMessage::Pong { ping_data } => { if let Err(err) = ws_writer.send(Message::Pong(ping_data)).await { log::error!("failed to send pong message: {:?}", err); diff --git a/livekit-protocol/protocol b/livekit-protocol/protocol index aec2833df..e05f7b7a6 160000 --- a/livekit-protocol/protocol +++ b/livekit-protocol/protocol @@ -1 +1 @@ -Subproject commit aec2833dffcbc4525735f29c96238c13c10bcf64 +Subproject commit e05f7b7a61466a2e051ae367e36db361cdbe699b diff --git a/livekit-protocol/src/lib.rs b/livekit-protocol/src/lib.rs index f8ea61f9d..4f41e167b 100644 --- a/livekit-protocol/src/lib.rs +++ b/livekit-protocol/src/lib.rs @@ -20,6 +20,9 @@ pub mod enum_dispatch; pub mod observer; pub mod promise; +/// `client_protocol=1` indicates support for RPC `compressed_payload` using gzip. +pub const RPC_GZIP_CLIENT_PROTOCOL: i32 = 1; + include!("livekit.rs"); #[cfg(feature = "serde")] diff --git a/livekit-protocol/src/livekit.rs b/livekit-protocol/src/livekit.rs index 106c5c86e..62322c815 100644 --- a/livekit-protocol/src/livekit.rs +++ b/livekit-protocol/src/livekit.rs @@ -946,6 +946,9 @@ pub struct RpcRequest { pub response_timeout_ms: u32, #[prost(uint32, tag="5")] pub version: u32, + /// Compressed payload data. When set, this field is used instead of `payload`. + #[prost(bytes="vec", tag="6")] + pub compressed_payload: ::prost::alloc::vec::Vec, } #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] @@ -958,7 +961,7 @@ pub struct RpcAck { pub struct RpcResponse { #[prost(string, tag="1")] pub request_id: ::prost::alloc::string::String, - #[prost(oneof="rpc_response::Value", tags="2, 3")] + #[prost(oneof="rpc_response::Value", tags="2, 3, 4")] pub value: ::core::option::Option, } /// Nested message and enum types in `RpcResponse`. @@ -970,6 +973,9 @@ pub mod rpc_response { Payload(::prost::alloc::string::String), #[prost(message, tag="3")] Error(super::RpcError), + /// Compressed payload data. When set, this field is used instead of `payload`. + #[prost(bytes, tag="4")] + CompressedPayload(::prost::alloc::vec::Vec), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -4243,7 +4249,7 @@ pub struct JobState { #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct WorkerMessage { - #[prost(oneof="worker_message::Message", tags="1, 2, 3, 4, 5, 6, 7, 8, 9")] + #[prost(oneof="worker_message::Message", tags="1, 2, 3, 4, 5, 6, 7")] pub message: ::core::option::Option, } /// Nested message and enum types in `WorkerMessage`. @@ -4269,17 +4275,13 @@ pub mod worker_message { SimulateJob(super::SimulateJobRequest), #[prost(message, tag="7")] MigrateJob(super::MigrateJobRequest), - #[prost(message, tag="8")] - TextResponse(super::TextMessageResponse), - #[prost(message, tag="9")] - PushText(super::PushTextRequest), } } /// from Server to Worker #[allow(clippy::derive_partial_eq_without_eq)] #[derive(Clone, PartialEq, ::prost::Message)] pub struct ServerMessage { - #[prost(oneof="server_message::Message", tags="1, 2, 3, 5, 4, 6")] + #[prost(oneof="server_message::Message", tags="1, 2, 3, 5, 4")] pub message: ::core::option::Option, } /// Nested message and enum types in `ServerMessage`. @@ -4299,8 +4301,6 @@ pub mod server_message { Termination(super::JobTermination), #[prost(message, tag="4")] Pong(super::WorkerPong), - #[prost(message, tag="6")] - TextRequest(super::TextMessageRequest), } } #[allow(clippy::derive_partial_eq_without_eq)] @@ -4430,61 +4430,6 @@ pub struct JobTermination { #[prost(string, tag="1")] pub job_id: ::prost::alloc::string::String, } -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct AgentSessionState { - #[prost(uint64, tag="1")] - pub version: u64, - #[prost(oneof="agent_session_state::Data", tags="2, 3")] - pub data: ::core::option::Option, -} -/// Nested message and enum types in `AgentSessionState`. -pub mod agent_session_state { - #[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Oneof)] - pub enum Data { - #[prost(bytes, tag="2")] - Snapshot(::prost::alloc::vec::Vec), - #[prost(bytes, tag="3")] - Delta(::prost::alloc::vec::Vec), - } -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct TextMessageRequest { - #[prost(string, tag="1")] - pub message_id: ::prost::alloc::string::String, - #[prost(string, tag="2")] - pub session_id: ::prost::alloc::string::String, - #[prost(string, tag="3")] - pub agent_name: ::prost::alloc::string::String, - #[prost(string, tag="4")] - pub metadata: ::prost::alloc::string::String, - #[prost(message, optional, tag="5")] - pub session_state: ::core::option::Option, - #[prost(string, tag="6")] - pub text: ::prost::alloc::string::String, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct PushTextRequest { - /// The message_id of the TextMessageRequest that this push is for - #[prost(string, tag="1")] - pub message_id: ::prost::alloc::string::String, - #[prost(string, tag="2")] - pub content: ::prost::alloc::string::String, -} -#[allow(clippy::derive_partial_eq_without_eq)] -#[derive(Clone, PartialEq, ::prost::Message)] -pub struct TextMessageResponse { - /// Indicate the request is completed - #[prost(string, tag="1")] - pub message_id: ::prost::alloc::string::String, - #[prost(message, optional, tag="2")] - pub session_state: ::core::option::Option, - #[prost(string, tag="3")] - pub error: ::prost::alloc::string::String, -} #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum JobType { diff --git a/livekit-protocol/src/livekit.serde.rs b/livekit-protocol/src/livekit.serde.rs index 4d35e2a0d..e258fbd8a 100644 --- a/livekit-protocol/src/livekit.serde.rs +++ b/livekit-protocol/src/livekit.serde.rs @@ -1193,142 +1193,6 @@ impl<'de> serde::Deserialize<'de> for AgentDispatchState { deserializer.deserialize_struct("livekit.AgentDispatchState", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for AgentSessionState { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if self.version != 0 { - len += 1; - } - if self.data.is_some() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("livekit.AgentSessionState", len)?; - if self.version != 0 { - #[allow(clippy::needless_borrow)] - #[allow(clippy::needless_borrows_for_generic_args)] - struct_ser.serialize_field("version", ToString::to_string(&self.version).as_str())?; - } - if let Some(v) = self.data.as_ref() { - match v { - agent_session_state::Data::Snapshot(v) => { - #[allow(clippy::needless_borrow)] - #[allow(clippy::needless_borrows_for_generic_args)] - struct_ser.serialize_field("snapshot", pbjson::private::base64::encode(&v).as_str())?; - } - agent_session_state::Data::Delta(v) => { - #[allow(clippy::needless_borrow)] - #[allow(clippy::needless_borrows_for_generic_args)] - struct_ser.serialize_field("delta", pbjson::private::base64::encode(&v).as_str())?; - } - } - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for AgentSessionState { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "version", - "snapshot", - "delta", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - Version, - Snapshot, - Delta, - __SkipField__, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "version" => Ok(GeneratedField::Version), - "snapshot" => Ok(GeneratedField::Snapshot), - "delta" => Ok(GeneratedField::Delta), - _ => Ok(GeneratedField::__SkipField__), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = AgentSessionState; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct livekit.AgentSessionState") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut version__ = None; - let mut data__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::Version => { - if version__.is_some() { - return Err(serde::de::Error::duplicate_field("version")); - } - version__ = - Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) - ; - } - GeneratedField::Snapshot => { - if data__.is_some() { - return Err(serde::de::Error::duplicate_field("snapshot")); - } - data__ = map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| agent_session_state::Data::Snapshot(x.0)); - } - GeneratedField::Delta => { - if data__.is_some() { - return Err(serde::de::Error::duplicate_field("delta")); - } - data__ = map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| agent_session_state::Data::Delta(x.0)); - } - GeneratedField::__SkipField__ => { - let _ = map_.next_value::()?; - } - } - } - Ok(AgentSessionState { - version: version__.unwrap_or_default(), - data: data__, - }) - } - } - deserializer.deserialize_struct("livekit.AgentSessionState", FIELDS, GeneratedVisitor) - } -} impl serde::Serialize for AliOssUpload { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -24339,119 +24203,6 @@ impl<'de> serde::Deserialize<'de> for PublishDataTrackResponse { deserializer.deserialize_struct("livekit.PublishDataTrackResponse", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for PushTextRequest { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.message_id.is_empty() { - len += 1; - } - if !self.content.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("livekit.PushTextRequest", len)?; - if !self.message_id.is_empty() { - struct_ser.serialize_field("messageId", &self.message_id)?; - } - if !self.content.is_empty() { - struct_ser.serialize_field("content", &self.content)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for PushTextRequest { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "message_id", - "messageId", - "content", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - MessageId, - Content, - __SkipField__, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "messageId" | "message_id" => Ok(GeneratedField::MessageId), - "content" => Ok(GeneratedField::Content), - _ => Ok(GeneratedField::__SkipField__), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = PushTextRequest; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct livekit.PushTextRequest") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut message_id__ = None; - let mut content__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::MessageId => { - if message_id__.is_some() { - return Err(serde::de::Error::duplicate_field("messageId")); - } - message_id__ = Some(map_.next_value()?); - } - GeneratedField::Content => { - if content__.is_some() { - return Err(serde::de::Error::duplicate_field("content")); - } - content__ = Some(map_.next_value()?); - } - GeneratedField::__SkipField__ => { - let _ = map_.next_value::()?; - } - } - } - Ok(PushTextRequest { - message_id: message_id__.unwrap_or_default(), - content: content__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("livekit.PushTextRequest", FIELDS, GeneratedVisitor) - } -} impl serde::Serialize for RtcpSenderReportState { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -29436,6 +29187,9 @@ impl serde::Serialize for RpcRequest { if self.version != 0 { len += 1; } + if !self.compressed_payload.is_empty() { + len += 1; + } let mut struct_ser = serializer.serialize_struct("livekit.RpcRequest", len)?; if !self.id.is_empty() { struct_ser.serialize_field("id", &self.id)?; @@ -29452,6 +29206,11 @@ impl serde::Serialize for RpcRequest { if self.version != 0 { struct_ser.serialize_field("version", &self.version)?; } + if !self.compressed_payload.is_empty() { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("compressedPayload", pbjson::private::base64::encode(&self.compressed_payload).as_str())?; + } struct_ser.end() } } @@ -29468,6 +29227,8 @@ impl<'de> serde::Deserialize<'de> for RpcRequest { "response_timeout_ms", "responseTimeoutMs", "version", + "compressed_payload", + "compressedPayload", ]; #[allow(clippy::enum_variant_names)] @@ -29477,6 +29238,7 @@ impl<'de> serde::Deserialize<'de> for RpcRequest { Payload, ResponseTimeoutMs, Version, + CompressedPayload, __SkipField__, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -29504,6 +29266,7 @@ impl<'de> serde::Deserialize<'de> for RpcRequest { "payload" => Ok(GeneratedField::Payload), "responseTimeoutMs" | "response_timeout_ms" => Ok(GeneratedField::ResponseTimeoutMs), "version" => Ok(GeneratedField::Version), + "compressedPayload" | "compressed_payload" => Ok(GeneratedField::CompressedPayload), _ => Ok(GeneratedField::__SkipField__), } } @@ -29528,6 +29291,7 @@ impl<'de> serde::Deserialize<'de> for RpcRequest { let mut payload__ = None; let mut response_timeout_ms__ = None; let mut version__ = None; + let mut compressed_payload__ = None; while let Some(k) = map_.next_key()? { match k { GeneratedField::Id => { @@ -29564,6 +29328,14 @@ impl<'de> serde::Deserialize<'de> for RpcRequest { Some(map_.next_value::<::pbjson::private::NumberDeserialize<_>>()?.0) ; } + GeneratedField::CompressedPayload => { + if compressed_payload__.is_some() { + return Err(serde::de::Error::duplicate_field("compressedPayload")); + } + compressed_payload__ = + Some(map_.next_value::<::pbjson::private::BytesDeserialize<_>>()?.0) + ; + } GeneratedField::__SkipField__ => { let _ = map_.next_value::()?; } @@ -29575,6 +29347,7 @@ impl<'de> serde::Deserialize<'de> for RpcRequest { payload: payload__.unwrap_or_default(), response_timeout_ms: response_timeout_ms__.unwrap_or_default(), version: version__.unwrap_or_default(), + compressed_payload: compressed_payload__.unwrap_or_default(), }) } } @@ -29607,6 +29380,11 @@ impl serde::Serialize for RpcResponse { rpc_response::Value::Error(v) => { struct_ser.serialize_field("error", v)?; } + rpc_response::Value::CompressedPayload(v) => { + #[allow(clippy::needless_borrow)] + #[allow(clippy::needless_borrows_for_generic_args)] + struct_ser.serialize_field("compressedPayload", pbjson::private::base64::encode(&v).as_str())?; + } } } struct_ser.end() @@ -29623,6 +29401,8 @@ impl<'de> serde::Deserialize<'de> for RpcResponse { "requestId", "payload", "error", + "compressed_payload", + "compressedPayload", ]; #[allow(clippy::enum_variant_names)] @@ -29630,6 +29410,7 @@ impl<'de> serde::Deserialize<'de> for RpcResponse { RequestId, Payload, Error, + CompressedPayload, __SkipField__, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -29655,6 +29436,7 @@ impl<'de> serde::Deserialize<'de> for RpcResponse { "requestId" | "request_id" => Ok(GeneratedField::RequestId), "payload" => Ok(GeneratedField::Payload), "error" => Ok(GeneratedField::Error), + "compressedPayload" | "compressed_payload" => Ok(GeneratedField::CompressedPayload), _ => Ok(GeneratedField::__SkipField__), } } @@ -29697,6 +29479,12 @@ impl<'de> serde::Deserialize<'de> for RpcResponse { value__ = map_.next_value::<::std::option::Option<_>>()?.map(rpc_response::Value::Error) ; } + GeneratedField::CompressedPayload => { + if value__.is_some() { + return Err(serde::de::Error::duplicate_field("compressedPayload")); + } + value__ = map_.next_value::<::std::option::Option<::pbjson::private::BytesDeserialize<_>>>()?.map(|x| rpc_response::Value::CompressedPayload(x.0)); + } GeneratedField::__SkipField__ => { let _ = map_.next_value::()?; } @@ -36132,9 +35920,6 @@ impl serde::Serialize for ServerMessage { server_message::Message::Pong(v) => { struct_ser.serialize_field("pong", v)?; } - server_message::Message::TextRequest(v) => { - struct_ser.serialize_field("textRequest", v)?; - } } } struct_ser.end() @@ -36152,8 +35937,6 @@ impl<'de> serde::Deserialize<'de> for ServerMessage { "assignment", "termination", "pong", - "text_request", - "textRequest", ]; #[allow(clippy::enum_variant_names)] @@ -36163,7 +35946,6 @@ impl<'de> serde::Deserialize<'de> for ServerMessage { Assignment, Termination, Pong, - TextRequest, __SkipField__, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -36191,7 +35973,6 @@ impl<'de> serde::Deserialize<'de> for ServerMessage { "assignment" => Ok(GeneratedField::Assignment), "termination" => Ok(GeneratedField::Termination), "pong" => Ok(GeneratedField::Pong), - "textRequest" | "text_request" => Ok(GeneratedField::TextRequest), _ => Ok(GeneratedField::__SkipField__), } } @@ -36247,13 +36028,6 @@ impl<'de> serde::Deserialize<'de> for ServerMessage { return Err(serde::de::Error::duplicate_field("pong")); } message__ = map_.next_value::<::std::option::Option<_>>()?.map(server_message::Message::Pong) -; - } - GeneratedField::TextRequest => { - if message__.is_some() { - return Err(serde::de::Error::duplicate_field("textRequest")); - } - message__ = map_.next_value::<::std::option::Option<_>>()?.map(server_message::Message::TextRequest) ; } GeneratedField::__SkipField__ => { @@ -40560,321 +40334,6 @@ impl<'de> serde::Deserialize<'de> for SyncState { deserializer.deserialize_struct("livekit.SyncState", FIELDS, GeneratedVisitor) } } -impl serde::Serialize for TextMessageRequest { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.message_id.is_empty() { - len += 1; - } - if !self.session_id.is_empty() { - len += 1; - } - if !self.agent_name.is_empty() { - len += 1; - } - if !self.metadata.is_empty() { - len += 1; - } - if self.session_state.is_some() { - len += 1; - } - if !self.text.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("livekit.TextMessageRequest", len)?; - if !self.message_id.is_empty() { - struct_ser.serialize_field("messageId", &self.message_id)?; - } - if !self.session_id.is_empty() { - struct_ser.serialize_field("sessionId", &self.session_id)?; - } - if !self.agent_name.is_empty() { - struct_ser.serialize_field("agentName", &self.agent_name)?; - } - if !self.metadata.is_empty() { - struct_ser.serialize_field("metadata", &self.metadata)?; - } - if let Some(v) = self.session_state.as_ref() { - struct_ser.serialize_field("sessionState", v)?; - } - if !self.text.is_empty() { - struct_ser.serialize_field("text", &self.text)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for TextMessageRequest { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "message_id", - "messageId", - "session_id", - "sessionId", - "agent_name", - "agentName", - "metadata", - "session_state", - "sessionState", - "text", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - MessageId, - SessionId, - AgentName, - Metadata, - SessionState, - Text, - __SkipField__, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "messageId" | "message_id" => Ok(GeneratedField::MessageId), - "sessionId" | "session_id" => Ok(GeneratedField::SessionId), - "agentName" | "agent_name" => Ok(GeneratedField::AgentName), - "metadata" => Ok(GeneratedField::Metadata), - "sessionState" | "session_state" => Ok(GeneratedField::SessionState), - "text" => Ok(GeneratedField::Text), - _ => Ok(GeneratedField::__SkipField__), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = TextMessageRequest; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct livekit.TextMessageRequest") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut message_id__ = None; - let mut session_id__ = None; - let mut agent_name__ = None; - let mut metadata__ = None; - let mut session_state__ = None; - let mut text__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::MessageId => { - if message_id__.is_some() { - return Err(serde::de::Error::duplicate_field("messageId")); - } - message_id__ = Some(map_.next_value()?); - } - GeneratedField::SessionId => { - if session_id__.is_some() { - return Err(serde::de::Error::duplicate_field("sessionId")); - } - session_id__ = Some(map_.next_value()?); - } - GeneratedField::AgentName => { - if agent_name__.is_some() { - return Err(serde::de::Error::duplicate_field("agentName")); - } - agent_name__ = Some(map_.next_value()?); - } - GeneratedField::Metadata => { - if metadata__.is_some() { - return Err(serde::de::Error::duplicate_field("metadata")); - } - metadata__ = Some(map_.next_value()?); - } - GeneratedField::SessionState => { - if session_state__.is_some() { - return Err(serde::de::Error::duplicate_field("sessionState")); - } - session_state__ = map_.next_value()?; - } - GeneratedField::Text => { - if text__.is_some() { - return Err(serde::de::Error::duplicate_field("text")); - } - text__ = Some(map_.next_value()?); - } - GeneratedField::__SkipField__ => { - let _ = map_.next_value::()?; - } - } - } - Ok(TextMessageRequest { - message_id: message_id__.unwrap_or_default(), - session_id: session_id__.unwrap_or_default(), - agent_name: agent_name__.unwrap_or_default(), - metadata: metadata__.unwrap_or_default(), - session_state: session_state__, - text: text__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("livekit.TextMessageRequest", FIELDS, GeneratedVisitor) - } -} -impl serde::Serialize for TextMessageResponse { - #[allow(deprecated)] - fn serialize(&self, serializer: S) -> std::result::Result - where - S: serde::Serializer, - { - use serde::ser::SerializeStruct; - let mut len = 0; - if !self.message_id.is_empty() { - len += 1; - } - if self.session_state.is_some() { - len += 1; - } - if !self.error.is_empty() { - len += 1; - } - let mut struct_ser = serializer.serialize_struct("livekit.TextMessageResponse", len)?; - if !self.message_id.is_empty() { - struct_ser.serialize_field("messageId", &self.message_id)?; - } - if let Some(v) = self.session_state.as_ref() { - struct_ser.serialize_field("sessionState", v)?; - } - if !self.error.is_empty() { - struct_ser.serialize_field("error", &self.error)?; - } - struct_ser.end() - } -} -impl<'de> serde::Deserialize<'de> for TextMessageResponse { - #[allow(deprecated)] - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - const FIELDS: &[&str] = &[ - "message_id", - "messageId", - "session_state", - "sessionState", - "error", - ]; - - #[allow(clippy::enum_variant_names)] - enum GeneratedField { - MessageId, - SessionState, - Error, - __SkipField__, - } - impl<'de> serde::Deserialize<'de> for GeneratedField { - fn deserialize(deserializer: D) -> std::result::Result - where - D: serde::Deserializer<'de>, - { - struct GeneratedVisitor; - - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = GeneratedField; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "expected one of: {:?}", &FIELDS) - } - - #[allow(unused_variables)] - fn visit_str(self, value: &str) -> std::result::Result - where - E: serde::de::Error, - { - match value { - "messageId" | "message_id" => Ok(GeneratedField::MessageId), - "sessionState" | "session_state" => Ok(GeneratedField::SessionState), - "error" => Ok(GeneratedField::Error), - _ => Ok(GeneratedField::__SkipField__), - } - } - } - deserializer.deserialize_identifier(GeneratedVisitor) - } - } - struct GeneratedVisitor; - impl<'de> serde::de::Visitor<'de> for GeneratedVisitor { - type Value = TextMessageResponse; - - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - formatter.write_str("struct livekit.TextMessageResponse") - } - - fn visit_map(self, mut map_: V) -> std::result::Result - where - V: serde::de::MapAccess<'de>, - { - let mut message_id__ = None; - let mut session_state__ = None; - let mut error__ = None; - while let Some(k) = map_.next_key()? { - match k { - GeneratedField::MessageId => { - if message_id__.is_some() { - return Err(serde::de::Error::duplicate_field("messageId")); - } - message_id__ = Some(map_.next_value()?); - } - GeneratedField::SessionState => { - if session_state__.is_some() { - return Err(serde::de::Error::duplicate_field("sessionState")); - } - session_state__ = map_.next_value()?; - } - GeneratedField::Error => { - if error__.is_some() { - return Err(serde::de::Error::duplicate_field("error")); - } - error__ = Some(map_.next_value()?); - } - GeneratedField::__SkipField__ => { - let _ = map_.next_value::()?; - } - } - } - Ok(TextMessageResponse { - message_id: message_id__.unwrap_or_default(), - session_state: session_state__, - error: error__.unwrap_or_default(), - }) - } - } - deserializer.deserialize_struct("livekit.TextMessageResponse", FIELDS, GeneratedVisitor) - } -} impl serde::Serialize for TimeSeriesMetric { #[allow(deprecated)] fn serialize(&self, serializer: S) -> std::result::Result @@ -48357,12 +47816,6 @@ impl serde::Serialize for WorkerMessage { worker_message::Message::MigrateJob(v) => { struct_ser.serialize_field("migrateJob", v)?; } - worker_message::Message::TextResponse(v) => { - struct_ser.serialize_field("textResponse", v)?; - } - worker_message::Message::PushText(v) => { - struct_ser.serialize_field("pushText", v)?; - } } } struct_ser.end() @@ -48386,10 +47839,6 @@ impl<'de> serde::Deserialize<'de> for WorkerMessage { "simulateJob", "migrate_job", "migrateJob", - "text_response", - "textResponse", - "push_text", - "pushText", ]; #[allow(clippy::enum_variant_names)] @@ -48401,8 +47850,6 @@ impl<'de> serde::Deserialize<'de> for WorkerMessage { Ping, SimulateJob, MigrateJob, - TextResponse, - PushText, __SkipField__, } impl<'de> serde::Deserialize<'de> for GeneratedField { @@ -48432,8 +47879,6 @@ impl<'de> serde::Deserialize<'de> for WorkerMessage { "ping" => Ok(GeneratedField::Ping), "simulateJob" | "simulate_job" => Ok(GeneratedField::SimulateJob), "migrateJob" | "migrate_job" => Ok(GeneratedField::MigrateJob), - "textResponse" | "text_response" => Ok(GeneratedField::TextResponse), - "pushText" | "push_text" => Ok(GeneratedField::PushText), _ => Ok(GeneratedField::__SkipField__), } } @@ -48503,20 +47948,6 @@ impl<'de> serde::Deserialize<'de> for WorkerMessage { return Err(serde::de::Error::duplicate_field("migrateJob")); } message__ = map_.next_value::<::std::option::Option<_>>()?.map(worker_message::Message::MigrateJob) -; - } - GeneratedField::TextResponse => { - if message__.is_some() { - return Err(serde::de::Error::duplicate_field("textResponse")); - } - message__ = map_.next_value::<::std::option::Option<_>>()?.map(worker_message::Message::TextResponse) -; - } - GeneratedField::PushText => { - if message__.is_some() { - return Err(serde::de::Error::duplicate_field("pushText")); - } - message__ = map_.next_value::<::std::option::Option<_>>()?.map(worker_message::Message::PushText) ; } GeneratedField::__SkipField__ => { diff --git a/livekit/Cargo.toml b/livekit/Cargo.toml index 9881ebfa7..6eff52c73 100644 --- a/livekit/Cargo.toml +++ b/livekit/Cargo.toml @@ -45,6 +45,7 @@ libloading = { version = "0.8.6" } bytes = { workspace = true } bmrng = "0.5.2" base64 = "0.22" +flate2 = { workspace = true } [dev-dependencies] anyhow = { workspace = true } diff --git a/livekit/src/room/mod.rs b/livekit/src/room/mod.rs index ffd383e56..fc47189d1 100644 --- a/livekit/src/room/mod.rs +++ b/livekit/src/room/mod.rs @@ -521,6 +521,7 @@ impl Room { pi.attributes, e2ee_manager.encryption_type(), pi.permission, + pi.client_protocol, ); let dispatcher = Dispatcher::::default(); @@ -674,6 +675,7 @@ impl Room { pi.metadata, pi.attributes, pi.permission, + pi.client_protocol, ) }; participant.update_info(pi.clone()); @@ -1043,6 +1045,7 @@ impl RoomSession { pi.metadata, pi.attributes, pi.permission, + pi.client_protocol, ) }; @@ -1712,6 +1715,7 @@ impl RoomSession { metadata: String, attributes: HashMap, permission: Option, + client_protocol: i32, ) -> RemoteParticipant { let participant = RemoteParticipant::new( self.rtc_engine.clone(), @@ -1724,6 +1728,7 @@ impl RoomSession { attributes, self.options.auto_subscribe, permission, + client_protocol, ); participant.on_track_published({ @@ -1859,7 +1864,7 @@ impl RoomSession { self.remote_participants.read().values().find(|x| &x.sid() == sid).cloned() } - fn get_participant_by_identity( + pub(crate) fn get_participant_by_identity( &self, identity: &ParticipantIdentity, ) -> Option { diff --git a/livekit/src/room/participant/local_participant.rs b/livekit/src/room/participant/local_participant.rs index c72b5f4c4..27b70e93d 100644 --- a/livekit/src/room/participant/local_participant.rs +++ b/livekit/src/room/participant/local_participant.rs @@ -34,7 +34,10 @@ use crate::{ e2ee::EncryptionType, options::{self, compute_video_encodings, video_layers_from_encodings, TrackPublishOptions}, prelude::*, - room::participant::rpc::{RpcError, RpcErrorCode, RpcInvocationData, MAX_PAYLOAD_BYTES}, + room::participant::rpc::{ + compress_rpc_payload_bytes, RpcError, RpcErrorCode, RpcInvocationData, MAX_PAYLOAD_BYTES, + RPC_GZIP_CLIENT_PROTOCOL, + }, rtc_engine::{EngineError, RtcEngine}, ChatMessage, DataPacket, RoomSession, RpcAck, RpcRequest, RpcResponse, SipDTMF, Transcription, }; @@ -117,6 +120,7 @@ impl LocalParticipant { attributes: HashMap, encryption_type: EncryptionType, permission: Option, + client_protocol: i32, ) -> Self { Self { inner: super::new_inner( @@ -129,6 +133,7 @@ impl LocalParticipant { kind, kind_details, permission, + client_protocol, ), local: Arc::new(LocalInfo { events: LocalEvents::default(), @@ -613,15 +618,42 @@ impl LocalParticipant { .map_err(Into::into) } + /// Check if a remote participant supports RPC compression. + /// Returns true if the participant advertises the gzip compression protocol. + fn destination_supports_compression(&self, destination_identity: &str) -> bool { + let Some(session) = self.session() else { + return false; + }; + let participant_identity: ParticipantIdentity = destination_identity.to_string().into(); + let Some(participant) = session.get_participant_by_identity(&participant_identity) else { + return false; + }; + participant.client_protocol() >= RPC_GZIP_CLIENT_PROTOCOL + } + async fn publish_rpc_request(&self, rpc_request: RpcRequest) -> RoomResult<()> { + let supports_compression = + self.destination_supports_compression(&rpc_request.destination_identity); + + // Use compressed_payload field (raw bytes) when compression is beneficial + let (payload, compressed_payload) = if supports_compression { + match compress_rpc_payload_bytes(&rpc_request.payload) { + Some(compressed) => (String::new(), compressed), + None => (rpc_request.payload, Vec::new()), + } + } else { + (rpc_request.payload, Vec::new()) + }; + let destination_identities = vec![rpc_request.destination_identity]; + let rpc_request_message = proto::RpcRequest { id: rpc_request.id, method: rpc_request.method, - payload: rpc_request.payload, + payload, + compressed_payload, response_timeout_ms: rpc_request.response_timeout.as_millis() as u32, version: rpc_request.version, - ..Default::default() }; let data = proto::DataPacket { @@ -638,23 +670,38 @@ impl LocalParticipant { } async fn publish_rpc_response(&self, rpc_response: RpcResponse) -> RoomResult<()> { + let supports_compression = + self.destination_supports_compression(&rpc_response.destination_identity); let destination_identities = vec![rpc_response.destination_identity]; - let rpc_response_message = proto::RpcResponse { - request_id: rpc_response.request_id, - value: Some(match rpc_response.error { - Some(error) => proto::rpc_response::Value::Error(proto::RpcError { - code: error.code, - message: error.message, - data: error.data, - }), - None => proto::rpc_response::Value::Payload(rpc_response.payload.unwrap()), + + // Determine the response value (error, compressed payload, or plain payload) + let response_value = match rpc_response.error { + Some(error) => proto::rpc_response::Value::Error(proto::RpcError { + code: error.code, + message: error.message, + data: error.data, }), - ..Default::default() + None => { + let payload = rpc_response.payload.unwrap_or_default(); + if supports_compression { + match compress_rpc_payload_bytes(&payload) { + Some(compressed) => { + proto::rpc_response::Value::CompressedPayload(compressed) + } + None => proto::rpc_response::Value::Payload(payload), + } + } else { + proto::rpc_response::Value::Payload(payload) + } + } }; + let rpc_response_message = + proto::RpcResponse { request_id: rpc_response.request_id, value: Some(response_value) }; + let data = proto::DataPacket { value: Some(proto::data_packet::Value::RpcResponse(rpc_response_message)), - destination_identities: destination_identities.clone(), + destination_identities, ..Default::default() }; @@ -666,13 +713,12 @@ impl LocalParticipant { } async fn publish_rpc_ack(&self, rpc_ack: RpcAck) -> RoomResult<()> { - let destination_identities = vec![rpc_ack.destination_identity]; let rpc_ack_message = proto::RpcAck { request_id: rpc_ack.request_id, ..Default::default() }; let data = proto::DataPacket { value: Some(proto::data_packet::Value::RpcAck(rpc_ack_message)), - destination_identities: destination_identities.clone(), + destination_identities: vec![rpc_ack.destination_identity], ..Default::default() }; @@ -734,6 +780,10 @@ impl LocalParticipant { self.inner.info.read().attributes.clone() } + pub fn client_protocol(&self) -> i32 { + self.inner.info.read().client_protocol + } + pub fn is_speaking(&self) -> bool { self.inner.info.read().speaking } @@ -786,6 +836,13 @@ impl LocalParticipant { let min_effective_timeout = Duration::from_millis(1000); if data.payload.len() > MAX_PAYLOAD_BYTES { + log::error!( + "RPC request payload too large: {} bytes (max: {} bytes), method: {}, destination: {}", + data.payload.len(), + MAX_PAYLOAD_BYTES, + data.method, + data.destination_identity + ); return Err(RpcError::built_in(RpcErrorCode::RequestPayloadTooLarge, None)); } @@ -796,6 +853,13 @@ impl LocalParticipant { let server_version = Version::parse(&server_info.version).unwrap(); let min_required_version = Version::parse("1.8.0").unwrap(); if server_version < min_required_version { + log::error!( + "RPC error code {}: Server version {} does not support RPC (requires >= 1.8.0), method: {}, destination: {}", + RpcErrorCode::UnsupportedServer as u32, + server_info.version, + data.method, + data.destination_identity + ); return Err(RpcError::built_in(RpcErrorCode::UnsupportedServer, None)); } } @@ -839,6 +903,14 @@ impl LocalParticipant { // Wait for ack timeout match tokio::time::timeout(max_round_trip_latency, ack_rx).await { Err(_) => { + log::error!( + "RPC error code {}: Connection timeout waiting for ACK (timeout: {:?}), request_id: {}, method: {}, destination: {}", + RpcErrorCode::ConnectionTimeout as u32, + max_round_trip_latency, + id, + data.method, + data.destination_identity + ); let mut rpc_state = self.local.rpc_state.lock(); rpc_state.pending_acks.remove(&id); rpc_state.pending_responses.remove(&id); @@ -849,9 +921,17 @@ impl LocalParticipant { } } - // Wait for response timout + // Wait for response timeout let response = match tokio::time::timeout(data.response_timeout, response_rx).await { Err(_) => { + log::error!( + "RPC error code {}: Response timeout (timeout: {:?}), request_id: {}, method: {}, destination: {}", + RpcErrorCode::ResponseTimeout as u32, + data.response_timeout, + id, + data.method, + data.destination_identity + ); self.local.rpc_state.lock().pending_responses.remove(&id); return Err(RpcError::built_in(RpcErrorCode::ResponseTimeout, None)); } @@ -861,10 +941,26 @@ impl LocalParticipant { match response { Err(_) => { // Something went wrong locally + log::error!( + "RPC error code {}: Recipient disconnected, request_id: {}, method: {}, destination: {}", + RpcErrorCode::RecipientDisconnected as u32, + id, + data.method, + data.destination_identity + ); Err(RpcError::built_in(RpcErrorCode::RecipientDisconnected, None)) } Ok(Err(e)) => { // RPC error from remote, forward it + log::error!( + "RPC error code {}: {} (from remote), request_id: {}, method: {}, destination: {}, data: {:?}", + e.code, + e.message, + id, + data.method, + data.destination_identity, + e.data + ); Err(e) } Ok(Ok(payload)) => { @@ -911,10 +1007,11 @@ impl LocalParticipant { ) { let mut rpc_state = self.local.rpc_state.lock(); if let Some(tx) = rpc_state.pending_responses.remove(&request_id) { - let _ = tx.send(match error { + let result = match error { Some(e) => Err(RpcError::from_proto(e)), None => Ok(payload.unwrap_or_default()), - }); + }; + let _ = tx.send(result); } else { log::error!("Response received for unexpected RPC request: {}", request_id); } @@ -943,6 +1040,14 @@ impl LocalParticipant { let request_id_2 = request_id.clone(); let response = if version != 1 { + log::error!( + "RPC error code {}: Unsupported RPC version {}, request_id: {}, method: {}, caller: {}", + RpcErrorCode::UnsupportedVersion as u32, + version, + request_id, + method, + caller_identity + ); Err(RpcError::built_in(RpcErrorCode::UnsupportedVersion, None)) } else { let handler = self.local.rpc_state.lock().handlers.get(&method).cloned(); @@ -951,9 +1056,9 @@ impl LocalParticipant { Some(handler) => { match tokio::task::spawn(async move { handler(RpcInvocationData { - request_id: request_id.clone(), - caller_identity: caller_identity.clone(), - payload: payload.clone(), + request_id, + caller_identity, + payload, response_timeout, }) .await @@ -962,12 +1067,27 @@ impl LocalParticipant { { Ok(result) => result, Err(e) => { - log::error!("RPC method handler returned an error: {:?}", e); + log::error!( + "RPC error code {}: Method handler panicked: {:?}, request_id: {}, method: {}", + RpcErrorCode::ApplicationError as u32, + e, + request_id_2, + method + ); Err(RpcError::built_in(RpcErrorCode::ApplicationError, None)) } } } - None => Err(RpcError::built_in(RpcErrorCode::UnsupportedMethod, None)), + None => { + log::error!( + "RPC error code {}: Unsupported method '{}', request_id: {}, caller: {}", + RpcErrorCode::UnsupportedMethod as u32, + method, + request_id_2, + caller_identity_2 + ); + Err(RpcError::built_in(RpcErrorCode::UnsupportedMethod, None)) + } } }; @@ -975,7 +1095,17 @@ impl LocalParticipant { Ok(response_payload) if response_payload.len() <= MAX_PAYLOAD_BYTES => { (Some(response_payload), None) } - Ok(_) => (None, Some(RpcError::built_in(RpcErrorCode::ResponsePayloadTooLarge, None))), + Ok(response_payload) => { + log::error!( + "RPC error code {}: Response payload too large: {} bytes (max: {} bytes), request_id: {}, caller: {}", + RpcErrorCode::ResponsePayloadTooLarge as u32, + response_payload.len(), + MAX_PAYLOAD_BYTES, + request_id_2, + caller_identity_2 + ); + (None, Some(RpcError::built_in(RpcErrorCode::ResponsePayloadTooLarge, None))) + } Err(e) => (None, Some(e.into())), }; diff --git a/livekit/src/room/participant/mod.rs b/livekit/src/room/participant/mod.rs index 9c660dc3a..eb36ed828 100644 --- a/livekit/src/room/participant/mod.rs +++ b/livekit/src/room/participant/mod.rs @@ -91,6 +91,7 @@ impl Participant { pub fn name(self: &Self) -> String; pub fn metadata(self: &Self) -> String; pub fn attributes(self: &Self) -> HashMap; + pub fn client_protocol(self: &Self) -> i32; pub fn is_speaking(self: &Self) -> bool; pub fn audio_level(self: &Self) -> f32; pub fn connection_quality(self: &Self) -> ConnectionQuality; @@ -132,6 +133,8 @@ struct ParticipantInfo { pub kind_details: Vec, pub disconnect_reason: DisconnectReason, pub permission: Option, + /// Client protocol version indicating feature support (e.g., 1 = compression support) + pub client_protocol: i32, } type TrackMutedHandler = Box; @@ -180,6 +183,7 @@ pub(super) fn new_inner( kind: ParticipantKind, kind_details: Vec, permission: Option, + client_protocol: i32, ) -> Arc { Arc::new(ParticipantInner { rtc_engine, @@ -196,6 +200,7 @@ pub(super) fn new_inner( connection_quality: ConnectionQuality::Excellent, disconnect_reason: DisconnectReason::UnknownReason, permission, + client_protocol, }), track_publications: Default::default(), events: Default::default(), @@ -245,6 +250,9 @@ pub(super) fn update_info( cb(participant.clone(), new_info.permission.clone()); } } + + // Update client_protocol + info.client_protocol = new_info.client_protocol; } pub(super) fn set_speaking( diff --git a/livekit/src/room/participant/remote_participant.rs b/livekit/src/room/participant/remote_participant.rs index 9d4245ace..25da872b3 100644 --- a/livekit/src/room/participant/remote_participant.rs +++ b/livekit/src/room/participant/remote_participant.rs @@ -85,6 +85,7 @@ impl RemoteParticipant { attributes: HashMap, auto_subscribe: bool, permission: Option, + client_protocol: i32, ) -> Self { Self { inner: super::new_inner( @@ -97,6 +98,7 @@ impl RemoteParticipant { kind, kind_details, permission, + client_protocol, ), remote: Arc::new(RemoteInfo { events: Default::default(), auto_subscribe }), } @@ -518,6 +520,10 @@ impl RemoteParticipant { self.inner.info.read().attributes.clone() } + pub fn client_protocol(&self) -> i32 { + self.inner.info.read().client_protocol + } + pub fn is_speaking(&self) -> bool { self.inner.info.read().speaking } diff --git a/livekit/src/room/participant/rpc.rs b/livekit/src/room/participant/rpc.rs index b04691dda..7f3531f3a 100644 --- a/livekit/src/room/participant/rpc.rs +++ b/livekit/src/room/participant/rpc.rs @@ -146,6 +146,89 @@ impl RpcError { /// Maximum payload size in bytes pub const MAX_PAYLOAD_BYTES: usize = 15360; // 15 KB +pub use livekit_protocol::RPC_GZIP_CLIENT_PROTOCOL; + +/// Minimum payload size to trigger compression (1 KB) +pub const COMPRESSION_THRESHOLD_BYTES: usize = 1024; + +fn compress_rpc_payload_gzip(payload_bytes: &[u8]) -> Option> { + use flate2::{read::GzEncoder, Compression}; + use std::io::Read; + + // Compress the payload + let mut encoder = GzEncoder::new(payload_bytes, Compression::fast()); + let mut compressed = Vec::new(); + match encoder.read_to_end(&mut compressed) { + Ok(_) => { + // Only use compressed version if it's actually smaller + if compressed.len() < payload_bytes.len() { + return Some(compressed); + } + // Compression didn't help + None + } + Err(e) => { + log::warn!("Failed to compress RPC payload: {}", e); + None + } + } +} + +/// Compress an RPC payload to raw bytes using gzip. +/// Returns Some(compressed_bytes) if compression is beneficial, None otherwise. +/// This is used with the `compressed_payload` proto field (no base64 overhead). +pub fn compress_rpc_payload_bytes(payload: &str) -> Option> { + let payload_bytes = payload.as_bytes(); + + // Only compress if payload is large enough + if payload_bytes.len() < COMPRESSION_THRESHOLD_BYTES { + return None; + } + + compress_rpc_payload_gzip(payload_bytes) +} + +/// Decompress raw bytes RPC payload using Gzip. +/// Returns the decompressed string payload. +pub fn decompress_rpc_payload_bytes_gzip(compressed: &[u8]) -> Result { + use flate2::read::GzDecoder; + use std::io::Read; + + let mut decoder = GzDecoder::new(compressed); + let mut decompressed = Vec::new(); + let mut chunk = [0_u8; 4096]; + + loop { + let bytes_read = decoder + .read(&mut chunk) + .map_err(|e| format!("Failed to decompress RPC payload: {}", e))?; + if bytes_read == 0 { + break; + } + + if decompressed.len() + bytes_read > MAX_PAYLOAD_BYTES { + return Err(format!( + "Decompressed RPC payload exceeds max size: {} bytes", + MAX_PAYLOAD_BYTES + )); + } + + decompressed.extend_from_slice(&chunk[..bytes_read]); + } + + match String::from_utf8(decompressed) { + Ok(s) => Ok(s), + Err(e) => Err(format!("Failed to decode decompressed RPC payload as UTF-8: {}", e)), + } +} + +/// Decompress raw bytes RPC payload using Gzip. +/// Returns the decompressed string payload. +/// This is used with the `compressed_payload` proto field. +pub fn decompress_rpc_payload_bytes(compressed: &[u8]) -> Result { + decompress_rpc_payload_bytes_gzip(compressed) +} + /// Calculate the byte length of a string pub(crate) fn byte_length(s: &str) -> usize { s.as_bytes().len() diff --git a/livekit/src/rtc_engine/rtc_session.rs b/livekit/src/rtc_engine/rtc_session.rs index 8337a7f5b..7d1fce88c 100644 --- a/livekit/src/rtc_engine/rtc_session.rs +++ b/livekit/src/rtc_engine/rtc_session.rs @@ -39,6 +39,7 @@ use tokio::sync::{mpsc, oneshot, watch, Notify}; use super::{rtc_events, EngineError, EngineOptions, EngineResult, SimulateScenario}; use crate::{ id::ParticipantIdentity, + room::participant::{decompress_rpc_payload_bytes, RpcErrorCode}, utils::{ ttl_map::TtlMap, tx_queue::{TxQueue, TxQueueItem}, @@ -1295,19 +1296,66 @@ impl SessionInner { } proto::data_packet::Value::RpcRequest(rpc_request) => { let caller_identity = participant_identity; - self.emitter.send(SessionEvent::RpcRequest { - caller_identity, - request_id: rpc_request.id.clone(), - method: rpc_request.method, - payload: rpc_request.payload, - response_timeout: Duration::from_millis(rpc_request.response_timeout_ms as u64), - version: rpc_request.version, - }) + // Prefer compressed payload when present. + // If decompression fails, only fall back to plain payload when it is non-empty. + let payload = if !rpc_request.compressed_payload.is_empty() { + match decompress_rpc_payload_bytes(&rpc_request.compressed_payload) { + Ok(decompressed) => Some(decompressed), + Err(e) => { + if rpc_request.payload.is_empty() { + log::error!( + "Failed to decompress RPC request payload and plain payload is empty: {}", + e + ); + None + } else { + log::error!( + "Failed to decompress RPC request payload, falling back to plain payload: {}", + e + ); + Some(rpc_request.payload) + } + } + } + } else { + Some(rpc_request.payload) + }; + if let Some(payload) = payload { + self.emitter.send(SessionEvent::RpcRequest { + caller_identity, + request_id: rpc_request.id.clone(), + method: rpc_request.method, + payload, + response_timeout: Duration::from_millis( + rpc_request.response_timeout_ms as u64, + ), + version: rpc_request.version, + }) + } else { + Ok(()) + } } proto::data_packet::Value::RpcResponse(rpc_response) => { let (payload, error) = match rpc_response.value { None => (None, None), Some(proto::rpc_response::Value::Payload(payload)) => (Some(payload), None), + Some(proto::rpc_response::Value::CompressedPayload(compressed)) => { + match decompress_rpc_payload_bytes(&compressed) { + Ok(decompressed) => (Some(decompressed), None), + Err(e) => { + log::error!("Failed to decompress RPC response payload: {}", e); + ( + None, + Some(proto::RpcError { + code: RpcErrorCode::ApplicationError as u32, + message: "Failed to decompress RPC response payload" + .to_string(), + data: e, + }), + ) + } + } + } Some(proto::rpc_response::Value::Error(err)) => (None, Some(err)), }; self.emitter.send(SessionEvent::RpcResponse {