diff --git a/srtp/src/session/mod.rs b/srtp/src/session/mod.rs index efee1bef8..81edefe32 100644 --- a/srtp/src/session/mod.rs +++ b/srtp/src/session/mod.rs @@ -28,7 +28,7 @@ const DEFAULT_SESSION_SRTCP_REPLAY_PROTECTION_WINDOW: usize = 64; pub struct Session { local_context: Arc>, streams_map: Arc>>>, - new_stream_rx: Arc>>>, + new_stream_rx: Arc, Option)>>>, close_stream_tx: mpsc::Sender, close_session_tx: mpsc::Sender<()>, pub(crate) udp_tx: Arc, @@ -128,7 +128,7 @@ impl Session { buf: &mut [u8], streams_map: &Arc>>>, close_stream_tx: &mpsc::Sender, - new_stream_tx: &mut mpsc::Sender>, + new_stream_tx: &mut mpsc::Sender<(Arc, Option)>, remote_context: &mut Context, is_rtp: bool, ) -> Result<()> { @@ -144,24 +144,28 @@ impl Session { }; let mut buf = &decrypted[..]; - let ssrcs = if is_rtp { - vec![rtp::header::Header::unmarshal(&mut buf)?.ssrc] + let (ssrcs, header) = if is_rtp { + let header = rtp::header::Header::unmarshal(&mut buf)?; + (vec![header.ssrc], Some(header)) } else { let pkts = rtcp::packet::unmarshal(&mut buf)?; - destination_ssrc(&pkts) + (destination_ssrc(&pkts), None) }; for ssrc in ssrcs { let (stream, is_new) = Session::get_or_create_stream(streams_map, close_stream_tx.clone(), is_rtp, ssrc) .await; + if is_new { log::trace!( "srtp session got new {} stream {}", if is_rtp { "rtp" } else { "rtcp" }, ssrc ); - new_stream_tx.send(Arc::clone(&stream)).await?; + new_stream_tx + .send((Arc::clone(&stream), header.clone())) + .await?; } match stream.buffer.write(&decrypted).await { @@ -210,14 +214,13 @@ impl Session { } /// accept returns a stream to handle RTCP for a single SSRC - pub async fn accept(&self) -> Result> { + pub async fn accept(&self) -> Result<(Arc, Option)> { let mut new_stream_rx = self.new_stream_rx.lock().await; - let result = new_stream_rx.recv().await; - if let Some(stream) = result { - Ok(stream) - } else { - Err(Error::SessionSrtpAlreadyClosed) - } + + new_stream_rx + .recv() + .await + .ok_or(Error::SessionSrtpAlreadyClosed) } pub async fn close(&self) -> Result<()> { diff --git a/srtp/src/session/session_rtcp_test.rs b/srtp/src/session/session_rtcp_test.rs index bff477c43..79a097fd9 100644 --- a/srtp/src/session/session_rtcp_test.rs +++ b/srtp/src/session/session_rtcp_test.rs @@ -84,7 +84,7 @@ async fn test_session_srtcp_accept() -> Result<()> { let test_payload = rtcp_packet.marshal()?; sa.write_rtcp(&rtcp_packet).await?; - let read_stream = sb.accept().await?; + let (read_stream, _) = sb.accept().await?; let ssrc = read_stream.get_ssrc(); assert_eq!( ssrc, TEST_SSRC, diff --git a/srtp/src/session/session_rtp_test.rs b/srtp/src/session/session_rtp_test.rs index 5764d1346..7c38d8127 100644 --- a/srtp/src/session/session_rtp_test.rs +++ b/srtp/src/session/session_rtp_test.rs @@ -86,19 +86,22 @@ async fn test_session_srtp_accept() -> Result<()> { let packet = rtp::packet::Packet { header: rtp::header::Header { + version: 2, ssrc: TEST_SSRC, + payload_type: 96, ..Default::default() }, payload: test_payload.clone(), }; sa.write_rtp(&packet).await?; - let read_stream = sb.accept().await?; + let (read_stream, header) = sb.accept().await?; let ssrc = read_stream.get_ssrc(); assert_eq!( ssrc, TEST_SSRC, "SSRC mismatch during accept exp({TEST_SSRC}) actual({ssrc})" ); + assert_eq!(header, Some(packet.header)); read_stream.read(&mut read_buffer).await?; diff --git a/webrtc/src/peer_connection/mod.rs b/webrtc/src/peer_connection/mod.rs index 0bb19cc1e..a748cec32 100644 --- a/webrtc/src/peer_connection/mod.rs +++ b/webrtc/src/peer_connection/mod.rs @@ -81,8 +81,7 @@ use crate::rtp_transceiver::rtp_receiver::RTCRtpReceiver; use crate::rtp_transceiver::rtp_sender::RTCRtpSender; use crate::rtp_transceiver::rtp_transceiver_direction::RTCRtpTransceiverDirection; use crate::rtp_transceiver::{ - find_by_mid, handle_unknown_rtp_packet, satisfy_type_and_direction, RTCRtpTransceiver, - RTCRtpTransceiverInit, SSRC, + find_by_mid, satisfy_type_and_direction, RTCRtpTransceiver, RTCRtpTransceiverInit, SSRC, }; use crate::sctp_transport::sctp_transport_capabilities::SCTPTransportCapabilities; use crate::sctp_transport::sctp_transport_state::RTCSctpTransportState; diff --git a/webrtc/src/peer_connection/peer_connection_internal.rs b/webrtc/src/peer_connection/peer_connection_internal.rs index 3931553ca..5dddb62ce 100644 --- a/webrtc/src/peer_connection/peer_connection_internal.rs +++ b/webrtc/src/peer_connection/peer_connection_internal.rs @@ -2,7 +2,7 @@ use std::collections::VecDeque; use std::sync::Weak; use super::*; -use crate::rtp_transceiver::create_stream_info; +use crate::rtp_transceiver::{create_stream_info, PayloadType}; use crate::stats::stats_collector::StatsCollector; use crate::stats::{ InboundRTPStats, OutboundRTPStats, RTCStatsType, RemoteInboundRTPStats, RemoteOutboundRTPStats, @@ -15,7 +15,6 @@ use arc_swap::ArcSwapOption; use portable_atomic::AtomicIsize; use smol_str::SmolStr; use tokio::time::Instant; -use util::Unmarshal; pub(crate) struct PeerConnectionInternal { /// a value containing the last known greater mid value @@ -309,8 +308,12 @@ impl PeerConnectionInternal { } }; - let stream = match srtp_session.accept().await { - Ok(stream) => stream, + let (stream, header) = match srtp_session.accept().await { + Ok((stream, Some(header))) => (stream, header), + Ok((_, None)) => { + log::error!("Accepting RTP session, without RTP header?"); + return; + } Err(err) => { log::warn!("Failed to accept RTP {}", err); return; @@ -338,16 +341,16 @@ impl PeerConnectionInternal { let pci = Arc::clone(&pci); tokio::spawn(async move { let ssrc = stream.get_ssrc(); - dtls_transport .store_simulcast_stream(ssrc, Arc::clone(&stream)) .await; - if let Err(err) = pci.handle_incoming_ssrc(stream, ssrc).await { + if let Err(err) = pci + .handle_incoming_rtp_stream(stream, header.payload_type) + .await + { log::warn!( - "Incoming unhandled RTP ssrc({}), on_track will not be fired. {}", - ssrc, - err + "Incoming unhandled RTP ssrc({ssrc}), on_track will not be fired. {err}" ); } @@ -370,17 +373,18 @@ impl PeerConnectionInternal { } }; - let stream = match srtcp_session.accept().await { - Ok(stream) => stream, + match srtcp_session.accept().await { + Ok((stream, _)) => { + let ssrc = stream.get_ssrc(); + log::warn!( + "Incoming unhandled RTCP ssrc({ssrc}), on_track will not be fired" + ); + } Err(err) => { - log::warn!("Failed to accept RTCP {}", err); + log::warn!("Failed to accept RTCP {err}"); return; } }; - log::warn!( - "Incoming unhandled RTCP ssrc({}), on_track will not be fired", - stream.get_ssrc() - ); } }); } @@ -1002,18 +1006,18 @@ impl PeerConnectionInternal { Ok(true) } - async fn handle_incoming_ssrc( + async fn handle_incoming_rtp_stream( self: &Arc, rtp_stream: Arc, - ssrc: SSRC, + payload_type: PayloadType, ) -> Result<()> { + let ssrc = rtp_stream.get_ssrc(); let parsed = match self.remote_description().await.and_then(|rd| rd.parsed) { Some(r) => r, None => return Err(Error::ErrPeerConnRemoteDescriptionNil), }; // If the remote SDP was only one media section the ssrc doesn't have to be explicitly declared - let handled = self.handle_undeclared_ssrc(ssrc, &parsed).await?; - if handled { + if self.handle_undeclared_ssrc(ssrc, &parsed).await? { return Ok(()); } @@ -1046,26 +1050,6 @@ impl PeerConnectionInternal { }) .await; - // Packets that we read as part of simulcast probing that we need to make available - // if we do find a track later. - let mut buffered_packets: VecDeque<(rtp::packet::Packet, Attributes)> = VecDeque::default(); - - let mut buf = vec![0u8; self.setting_engine.get_receive_mtu()]; - let n = rtp_stream.read(&mut buf).await?; - let mut b = &buf[..n]; - - let (mut mid, mut rid, mut rsid, payload_type) = handle_unknown_rtp_packet( - b, - mid_extension_id as u8, - sid_extension_id as u8, - rsid_extension_id as u8, - )?; - - let packet = rtp::packet::Packet::unmarshal(&mut b).unwrap(); - - // TODO: Can we have attributes on the first packets? - buffered_packets.push_back((packet, Attributes::new())); - let params = self .media_engine .get_rtp_parameters_by_payload_type(payload_type) @@ -1089,21 +1073,24 @@ impl PeerConnectionInternal { .streams_for_ssrc(ssrc, &stream_info, &icpr) .await?; - let a = Attributes::new(); + // Packets that we read as part of simulcast probing that we need to make available + // if we do find a track later. + let mut buffered_packets: VecDeque<(rtp::packet::Packet, Attributes)> = VecDeque::default(); + let mut buf = vec![0u8; self.setting_engine.get_receive_mtu()]; + for _ in 0..=SIMULCAST_PROBE_COUNT { + let (pkt, a) = rtp_interceptor + .read(&mut buf, &stream_info.attributes) + .await?; + let (mid, rid, rsid) = get_stream_mid_rid( + &pkt.header, + mid_extension_id as u8, + sid_extension_id as u8, + rsid_extension_id as u8, + )?; + buffered_packets.push_back((pkt, a.clone())); + if mid.is_empty() || (rid.is_empty() && rsid.is_empty()) { - let (pkt, _) = rtp_interceptor.read(&mut buf, &a).await?; - let (m, r, rs, _) = handle_unknown_rtp_packet( - &buf[..n], - mid_extension_id as u8, - sid_extension_id as u8, - rsid_extension_id as u8, - )?; - mid = m; - rid = r; - rsid = rs; - - buffered_packets.push_back((pkt, a.clone())); continue; } @@ -1544,3 +1531,34 @@ fn capitalize(s: &str) -> String { result } + +fn get_stream_mid_rid( + header: &rtp::header::Header, + mid_extension_id: u8, + sid_extension_id: u8, + rsid_extension_id: u8, +) -> Result<(String, String, String)> { + if !header.extension { + return Ok((String::new(), String::new(), String::new())); + } + + let mid = if let Some(payload) = header.get_extension(mid_extension_id) { + String::from_utf8(payload.to_vec())? + } else { + String::new() + }; + + let rid = if let Some(payload) = header.get_extension(sid_extension_id) { + String::from_utf8(payload.to_vec())? + } else { + String::new() + }; + + let srid = if let Some(payload) = header.get_extension(rsid_extension_id) { + String::from_utf8(payload.to_vec())? + } else { + String::new() + }; + + Ok((mid, rid, srid)) +} diff --git a/webrtc/src/rtp_transceiver/mod.rs b/webrtc/src/rtp_transceiver/mod.rs index d15c09886..95d3e6f97 100644 --- a/webrtc/src/rtp_transceiver/mod.rs +++ b/webrtc/src/rtp_transceiver/mod.rs @@ -14,7 +14,6 @@ use portable_atomic::{AtomicBool, AtomicU8}; use serde::{Deserialize, Serialize}; use smol_str::SmolStr; use tokio::sync::{Mutex, OnceCell}; -use util::Unmarshal; use crate::api::media_engine::MediaEngine; use crate::error::{Error, Result}; @@ -523,41 +522,3 @@ pub(crate) async fn satisfy_type_and_direction( None } - -/// handle_unknown_rtp_packet consumes a single RTP Packet and returns information that is helpful -/// for demuxing and handling an unknown SSRC (usually for Simulcast) -pub(crate) fn handle_unknown_rtp_packet( - buf: &[u8], - mid_extension_id: u8, - sid_extension_id: u8, - rsid_extension_id: u8, -) -> Result<(String, String, String, PayloadType)> { - let mut reader = buf; - let rp = rtp::packet::Packet::unmarshal(&mut reader)?; - - if !rp.header.extension { - return Ok((String::new(), String::new(), String::new(), 0)); - } - - let payload_type = rp.header.payload_type; - - let mid = if let Some(payload) = rp.header.get_extension(mid_extension_id) { - String::from_utf8(payload.to_vec())? - } else { - String::new() - }; - - let rid = if let Some(payload) = rp.header.get_extension(sid_extension_id) { - String::from_utf8(payload.to_vec())? - } else { - String::new() - }; - - let srid = if let Some(payload) = rp.header.get_extension(rsid_extension_id) { - String::from_utf8(payload.to_vec())? - } else { - String::new() - }; - - Ok((mid, rid, srid, payload_type)) -}