Compare commits

...

76 Commits

Author SHA1 Message Date
5a8725dac1 Merge pull request #13 from paullouisageneau/stop-method
Stop transports before destruction
2019-12-15 20:19:35 +00:00
cd66a3f987 Added stop method on transports to stop them before destroying them 2019-12-15 20:41:45 +01:00
fc4091a9fc Merge pull request #11 from murat-dogan/master
Fix List TRANSFORM command & address-of-packed-member option
2019-12-15 19:12:41 +00:00
7a49a0cfd8 Fix List TRANSFORM command & address-of-packed-member option 2019-12-15 20:49:57 +03:00
de5aff68e6 Fixed transport synchronization on destruction 2019-12-15 16:35:58 +01:00
5416b66116 Merge pull request #10 from paullouisageneau/big-messages
Large messages support
2019-12-14 22:11:29 +00:00
8b94f22aca Cleanup and added some comments 2019-12-14 22:41:55 +01:00
1ab81731e3 Changed buffer amount low behavior to prevent deadlock situations 2019-12-14 21:13:51 +01:00
5213f12f1a Proper handling of SCTP EOR flag at reception 2019-12-14 17:23:31 +01:00
c5e25bbdbc Implemented max message size negociation 2019-12-14 17:23:31 +01:00
58eea3fcf6 Cleanup and destruction fixes 2019-12-14 17:23:31 +01:00
59517cb0da Fixed transition to disconnected state 2019-12-14 17:23:31 +01:00
e77586fc81 Added support for deprecated PPIDs String Partial and Binary Partial 2019-12-14 17:23:23 +01:00
cafc674689 Changed SCTP to non-blocking to spare a thread and fix blocking on close 2019-12-14 12:52:11 +01:00
96f08cb3c8 Fixed DataChannel recv queue limit 2019-12-12 19:42:32 +01:00
3220bf8ae1 Fixed formatting 2019-12-12 17:31:43 +01:00
7236c06880 Merge pull request #9 from paullouisageneau/send-buffer
Enhanced DataChannel API
2019-12-12 16:26:12 +00:00
89ff113688 Changed sent callback to more generic bufferAmountLow 2019-12-12 17:18:49 +01:00
aa55aa76df Added sendBuffer() methods to DataChannel 2019-12-12 11:07:31 +01:00
2c0955fe57 Some performance tuning for SCTP transport 2019-12-11 23:16:51 +01:00
d7e417ce3f Use structured binding instead of tie 2019-12-11 19:57:04 +01:00
1df2fa559c Added rvalue push to Queue 2019-12-11 16:28:09 +01:00
2d5d2f0486 Merge pull request #8 from paullouisageneau/openssl
OpenSSL as alternative to GnuTLS
2019-12-10 15:20:14 +00:00
e75ae36ba8 Added OpenSSL as an alternative to GnuTLS 2019-12-10 16:18:03 +01:00
585450d13e Better libs decalration in Makefile 2019-12-08 23:06:31 +01:00
bc0666be05 Added Jamfile for easier integration 2019-12-08 23:04:26 +01:00
b14518238a Fixed reorder warning 2019-12-08 23:03:57 +01:00
04df12b581 Merge pull request #7 from paullouisageneau/back-pressure
Back-pressure
2019-12-06 23:04:03 +00:00
08931de03b Added proper destructor for synchronized_callback 2019-12-06 16:59:43 +01:00
e6da6a185f Make DataChannel keep a strong reference on PeerConnection 2019-12-06 16:21:55 +01:00
87bf676428 Fixed SctpTransport::process() 2019-12-06 11:12:06 +01:00
4029a9bb4a Do not receive messages in onMessage is not set 2019-12-04 12:16:39 +01:00
a1df562785 Added available and availableSize getters on DataChannel 2019-12-04 12:00:40 +01:00
5d57b4e214 Used synchronized callbacks for PeerConnection 2019-12-03 12:05:19 +01:00
abdf61e841 Added callback wrapper in Channel 2019-12-03 11:55:54 +01:00
d9bfcbd6be Added sent callback on DataChannel 2019-12-03 11:17:56 +01:00
e55d4e906b Handle state changes atomically 2019-12-02 22:18:29 +01:00
6363361e0c Send thread might now change SCTP state to disconnected 2019-12-02 21:28:37 +01:00
21f43611b6 Some cleanup to thread handling 2019-12-01 22:00:31 +01:00
b20f3b30c0 Remaned runConnect() to runConnectAndSendLoop() 2019-12-01 21:50:44 +01:00
fcb1a2571d Consider DTLS premature termination as remote closing 2019-12-01 21:46:50 +01:00
900c482146 Implemented reading back-pressure and callbacks synchronization 2019-12-01 16:03:50 +01:00
d27ed8aab0 Don't wait for resolve threads on destruction 2019-11-28 09:47:20 +01:00
84219d381d Implemented async sending in SCTP transport 2019-11-27 13:26:02 +01:00
75735fb8d8 Enforced setup:actpass on offers 2019-11-24 21:06:50 +01:00
d2b0d1e07f Added bundle line 2019-11-24 20:14:54 +01:00
6f09bc7a17 Added local and remote address accessors 2019-11-23 19:16:10 +01:00
ac6cae8fc4 Merge pull request #4 from paullouisageneau/dtls-on-connected
Do not wait for ICE Ready
2019-11-22 21:40:59 +01:00
000bef45f6 Renamed ICE transport Ready state to Completed for consistency with web API 2019-11-22 21:32:08 +01:00
c2bba83254 Start DTLS transport on ICE state Connected instead of Ready 2019-11-22 21:30:02 +01:00
5839e9d3db Enabled SCTP MTU discovery 2019-11-22 16:48:52 +01:00
2ff361ab29 Fixed DTLS transport send logic to also fail on non-fatal errors 2019-11-22 16:05:27 +01:00
4f6bdc5135 Merge pull request #3 from aaronalbers/aa_lifetime_fixes_
Fixed lifetime issues
2019-11-22 15:59:11 +01:00
65e584107c Fixed lifetime issues
- Channels are not longer immortal objects
- Fixed (or mitigated) crashes on cleanup
2019-11-21 09:03:15 -07:00
cd0f17e36d Implemented DTLS timeout handling 2019-11-19 15:45:34 +01:00
71bdc94804 Removed full-mode ICE option 2019-11-19 15:45:18 +01:00
648644895c Set DTLS timeouts 2019-11-19 14:14:11 +01:00
2e59e44a83 Fixed ICE controlling role and reduced STUN timeout for gathering 2019-11-19 14:07:35 +01:00
3afc127750 Reset end of candidates status when extracted from description 2019-11-19 12:34:40 +01:00
44cdbab8dc Changed remote description logic to resolve candidates asyncronously 2019-11-19 12:31:15 +01:00
f083815569 Prevent unresolved candidates from going through libnice 2019-11-17 14:19:31 +01:00
6f0bcbb1e6 Fixed nice_agent_parse_remote_sdp() check breaking non-trickle case 2019-11-17 14:00:39 +01:00
ae46162649 Fixed typo in gatheringState() 2019-11-14 19:29:09 +01:00
64ed232d1b Merge pull request #2 from aaronalbers/aa_fix_clang_build_
Fix clang build
2019-11-09 12:41:55 +01:00
0e2c992d1c Fix clang build
- Fix the build for the clang compiler since it doesn't support -Wno-error=format-truncation.
- Added missing `#include <string>` that caused `implicit instantiation of undefined
      template`
2019-11-08 15:53:27 -07:00
a10d47499b Excluded target datachannel-static from all 2019-10-18 20:55:48 +02:00
3528432c5c Removed useless usrsctp defines 2019-10-18 20:47:37 +02:00
dd3012ac35 Added tests to CMakeLists 2019-10-18 20:41:33 +02:00
640144e01d Fixed usrsctp compilation and include directory 2019-10-18 20:26:48 +02:00
dc7a59503a Fixed usrsctp cleanup not triggered 2019-10-09 08:56:11 +02:00
c5f7502397 Fixed compilation of usrsctp with -Wno-error=format-truncation 2019-10-07 21:08:59 +02:00
c8f2b79015 Updated Readme for CMake 2019-10-05 16:12:08 +02:00
defc230dba Added CMake support 2019-10-05 16:08:24 +02:00
81308b2095 Added links for implemented IETF drafts and RFCs 2019-10-01 09:20:47 +02:00
5842be3442 Fixed datachannel open state not toggled on incoming channels 2019-09-30 20:15:44 +02:00
bca4d89f93 Added bidirectional message exchange in test 2019-09-30 19:48:41 +02:00
31 changed files with 1841 additions and 469 deletions

1
.gitignore vendored
View File

@ -1,3 +1,4 @@
build/
*.d *.d
*.o *.o
*.a *.a

108
CMakeLists.txt Normal file
View File

@ -0,0 +1,108 @@
cmake_minimum_required (VERSION 3.7)
project (libdatachannel
DESCRIPTION "WebRTC Data Channels Library"
VERSION 0.2.1
LANGUAGES CXX)
set(CMAKE_POSITION_INDEPENDENT_CODE ON)
set(CMAKE_MODULE_PATH ${PROJECT_SOURCE_DIR}/cmake/Modules)
set(LIBDATACHANNEL_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/src/candidate.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/certificate.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/channel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/configuration.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/datachannel.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/description.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/dtlstransport.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/icetransport.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/peerconnection.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/rtc.cpp
${CMAKE_CURRENT_SOURCE_DIR}/src/sctptransport.cpp
)
set(TESTS_SOURCES
${CMAKE_CURRENT_SOURCE_DIR}/test/main.cpp
)
# Hack because usrsctp uses CMAKE_SOURCE_DIR instead of CMAKE_CURRENT_SOURCE_DIR
set(CMAKE_REQUIRED_FLAGS "-I${CMAKE_CURRENT_SOURCE_DIR}/usrsctp/usrsctplib")
add_subdirectory(usrsctp EXCLUDE_FROM_ALL)
# Set include directory and custom options to make usrsctp compile with recent g++
target_include_directories(usrsctp-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/usrsctp/usrsctplib)
if (CMAKE_CXX_COMPILER_ID MATCHES "Clang")
# using regular Clang or AppleClang: Needed since they don't have -Wno-error=format-truncation
target_compile_options(usrsctp-static PRIVATE -Wno-error=address-of-packed-member)
else()
if (CMAKE_CXX_COMPILER_ID MATCHES "GNU")
if(CMAKE_CXX_COMPILER_VERSION VERSION_LESS "9.0")
# GCC version below 9.0 does not support option address-of-packed-member
target_compile_options(usrsctp-static PRIVATE -Wno-error=format-truncation)
else()
target_compile_options(usrsctp-static PRIVATE -Wno-error=address-of-packed-member -Wno-error=format-truncation)
endif()
else()
# all other compilers
target_compile_options(usrsctp-static PRIVATE -Wno-error=address-of-packed-member -Wno-error=format-truncation)
endif()
endif()
option(USE_GNUTLS "Use GnuTLS instead of OpenSSL" OFF)
find_package(LibNice REQUIRED)
add_library(datachannel SHARED ${LIBDATACHANNEL_SOURCES})
set_target_properties(datachannel PROPERTIES
VERSION ${PROJECT_VERSION}
CXX_STANDARD 17)
target_include_directories(datachannel PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
target_include_directories(datachannel PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
target_link_libraries(datachannel usrsctp-static LibNice::LibNice)
add_library(datachannel-static STATIC EXCLUDE_FROM_ALL ${LIBDATACHANNEL_SOURCES})
set_target_properties(datachannel-static PROPERTIES
VERSION ${PROJECT_VERSION}
CXX_STANDARD 17)
target_include_directories(datachannel-static PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}/include)
target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/include/rtc)
target_include_directories(datachannel-static PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/src)
target_link_libraries(datachannel-static usrsctp-static LibNice::LibNice)
if (USE_GNUTLS)
find_package(GnuTLS REQUIRED)
if(NOT TARGET GnuTLS::GnuTLS)
add_library(GnuTLS::GnuTLS UNKNOWN IMPORTED)
set_target_properties(GnuTLS::GnuTLS PROPERTIES
INTERFACE_INCLUDE_DIRECTORIES "${GNUTLS_INCLUDE_DIRS}"
INTERFACE_COMPILE_DEFINITIONS "${GNUTLS_DEFINITIONS}"
IMPORTED_LINK_INTERFACE_LANGUAGES "C"
IMPORTED_LOCATION "${GNUTLS_LIBRARIES}")
endif()
target_compile_definitions(datachannel PRIVATE USE_GNUTLS=1)
target_link_libraries(datachannel GnuTLS::GnuTLS)
target_compile_definitions(datachannel-static PRIVATE USE_GNUTLS=1)
target_link_libraries(datachannel-static GnuTLS::GnuTLS)
else()
find_package(OpenSSL REQUIRED)
target_compile_definitions(datachannel PRIVATE USE_GNUTLS=0)
target_link_libraries(datachannel OpenSSL::SSL)
target_compile_definitions(datachannel-static PRIVATE USE_GNUTLS=0)
target_link_libraries(datachannel-static OpenSSL::SSL)
endif()
add_library(LibDataChannel::LibDataChannel ALIAS datachannel)
add_library(LibDataChannel::LibDataChannelStatic ALIAS datachannel-static)
add_executable(tests ${TESTS_SOURCES})
set_target_properties(tests PROPERTIES
VERSION ${PROJECT_VERSION}
CXX_STANDARD 17)
target_link_libraries(tests datachannel)

38
Jamfile Normal file
View File

@ -0,0 +1,38 @@
project libdatachannel ;
path-constant CWD : . ;
lib libdatachannel
: # sources
[ glob ./src/*.cpp ]
: # requirements
<include>./include/rtc
<define>USE_GNUTLS=0
<cxxflags>"`pkg-config --cflags openssl glib-2.0 gobject-2.0 nice`"
<library>/libdatachannel//usrsctp
: # default build
<link>static
: # usage requirements
<include>./include
<cxxflags>-pthread
<linkflags>"`pkg-config --libs openssl glib-2.0 gobject-2.0 nice`"
;
alias usrsctp
: # no sources
: # no build requirements
: # no default build
: # usage requirements
<include>./usrsctp/usrsctplib
<library>libusrsctp.a
;
make libusrsctp.a : : @make_libusrsctp ;
actions make_libusrsctp
{
(cd $(CWD)/usrsctp && \
./bootstrap && \
./configure --enable-static --disable-debug CFLAGS="-fPIC -Wno-address-of-packed-member" && \
make)
cp $(CWD)/usrsctp/usrsctplib/.libs/libusrsctp.a $(<)
}

View File

@ -7,11 +7,20 @@ RM=rm -f
CPPFLAGS=-O2 -pthread -fPIC -Wall -Wno-address-of-packed-member CPPFLAGS=-O2 -pthread -fPIC -Wall -Wno-address-of-packed-member
CXXFLAGS=-std=c++17 CXXFLAGS=-std=c++17
LDFLAGS=-pthread LDFLAGS=-pthread
LDLIBS= -lgnutls $(shell pkg-config --libs glib-2.0 gobject-2.0 nice) LIBS=glib-2.0 gobject-2.0 nice
INCLUDES=-Iinclude/rtc -I$(USRSCTP_DIR)/usrsctplib $(shell pkg-config --cflags glib-2.0 gobject-2.0 nice) USRSCTP_DIR=usrsctp
USRSCTP_DIR:=usrsctp USE_GNUTLS ?= 0
USRSCTP_DEFINES:=-DINET -DINET6 ifeq ($(USE_GNUTLS), 1)
CPPFLAGS+= -DUSE_GNUTLS=1
LIBS+= gnutls
else
CPPFLAGS+= -DUSE_GNUTLS=0
LIBS+= openssl
endif
LDLIBS= $(shell pkg-config --libs $(LIBS))
INCLUDES=-Iinclude/rtc -I$(USRSCTP_DIR)/usrsctplib $(shell pkg-config --cflags $(LIBS))
SRCS=$(shell printf "%s " src/*.cpp) SRCS=$(shell printf "%s " src/*.cpp)
OBJS=$(subst .cpp,.o,$(SRCS)) OBJS=$(subst .cpp,.o,$(SRCS))
@ -19,7 +28,7 @@ OBJS=$(subst .cpp,.o,$(SRCS))
all: $(NAME).a $(NAME).so tests all: $(NAME).a $(NAME).so tests
src/%.o: src/%.cpp src/%.o: src/%.cpp
$(CXX) $(CXXFLAGS) $(CPPFLAGS) $(INCLUDES) $(USRSCTP_DEFINES) -MMD -MP -o $@ -c $< $(CXX) $(CXXFLAGS) $(CPPFLAGS) $(INCLUDES) -MMD -MP -o $@ -c $<
test/%.o: test/%.cpp test/%.o: test/%.cpp
$(CXX) $(CXXFLAGS) $(CPPFLAGS) -Iinclude -MMD -MP -o $@ -c $< $(CXX) $(CXXFLAGS) $(CPPFLAGS) -Iinclude -MMD -MP -o $@ -c $<

View File

@ -8,12 +8,12 @@ Licensed under LGPLv2, see [LICENSE](https://github.com/paullouisageneau/libdata
## Compatibility ## Compatibility
This implementation has been tested to be compatible with Firefox and Chromium. It supports IPv6 and Multicast DNS candidates resolution provided the operating system also supports it. The library aims at fully implementing SCTP DataChannels ([draft-ietf-rtcweb-data-channel-13](https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13)) over DTLS/UDP ([RFC7350](https://tools.ietf.org/html/rfc7350) and [RFC8261](https://tools.ietf.org/html/rfc8261)) and has been tested to be compatible with Firefox and Chromium. It supports IPv6 and Multicast DNS candidates resolution ([draft-ietf-rtcweb-mdns-ice-candidates-03](https://tools.ietf.org/html/draft-ietf-rtcweb-mdns-ice-candidates-03)) provided the operating system also supports it.
## Dependencies ## Dependencies
- libnice: https://github.com/libnice/libnice - libnice: https://github.com/libnice/libnice
- GnuTLS: https://www.gnutls.org/ - GnuTLS: https://www.gnutls.org/ or OpenSSL: https://www.openssl.org/
Submodules: Submodules:
- usrsctp: https://github.com/sctplab/usrsctp - usrsctp: https://github.com/sctplab/usrsctp
@ -22,6 +22,9 @@ Submodules:
```bash ```bash
$ git submodule update --init --recursive $ git submodule update --init --recursive
$ mkdir build
$ cd build
$ cmake -DUSE_GNUTLS=1 ..
$ make $ make
``` ```
@ -47,6 +50,7 @@ pc->onLocalDescription([](const rtc::Description &sdp) {
}); });
pc->onLocalCandidate([](const rtc::Candidate &candidate) { pc->onLocalCandidate([](const rtc::Candidate &candidate) {
// Send the candidate to the remote peer
MY_SEND_CANDIDATE_TO_REMOTE(candidate.candidate(), candidate.mid()); MY_SEND_CANDIDATE_TO_REMOTE(candidate.candidate(), candidate.mid());
}); });

View File

@ -0,0 +1,123 @@
# - Try to find Glib and its components (gio, gobject etc)
# Once done, this will define
#
# GLIB_FOUND - system has Glib
# GLIB_INCLUDE_DIRS - the Glib include directories
# GLIB_LIBRARIES - link these to use Glib
#
# Optionally, the COMPONENTS keyword can be passed to find_package()
# and Glib components can be looked for. Currently, the following
# components can be used, and they define the following variables if
# found:
#
# gio: GLIB_GIO_LIBRARIES
# gobject: GLIB_GOBJECT_LIBRARIES
# gmodule: GLIB_GMODULE_LIBRARIES
# gthread: GLIB_GTHREAD_LIBRARIES
#
# Note that the respective _INCLUDE_DIR variables are not set, since
# all headers are in the same directory as GLIB_INCLUDE_DIRS.
#
# Copyright (C) 2012 Raphael Kubo da Costa <rakuco@webkit.org>
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# 1. Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# 2. Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER AND ITS CONTRIBUTORS ``AS
# IS'' AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
# THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR ITS
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
# OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY,
# WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR
# OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF
# ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
find_package(PkgConfig)
pkg_check_modules(PC_GLIB QUIET glib-2.0)
find_library(GLIB_LIBRARIES
NAMES glib-2.0
HINTS ${PC_GLIB_LIBDIR}
${PC_GLIB_LIBRARY_DIRS}
)
# Files in glib's main include path may include glibconfig.h, which,
# for some odd reason, is normally in $LIBDIR/glib-2.0/include.
get_filename_component(_GLIB_LIBRARY_DIR ${GLIB_LIBRARIES} PATH)
find_path(GLIBCONFIG_INCLUDE_DIR
NAMES glibconfig.h
HINTS ${PC_LIBDIR} ${PC_LIBRARY_DIRS} ${_GLIB_LIBRARY_DIR}
${PC_GLIB_INCLUDEDIR} ${PC_GLIB_INCLUDE_DIRS}
PATH_SUFFIXES glib-2.0/include
)
find_path(GLIB_INCLUDE_DIR
NAMES glib.h
HINTS ${PC_GLIB_INCLUDEDIR}
${PC_GLIB_INCLUDE_DIRS}
PATH_SUFFIXES glib-2.0
)
set(GLIB_INCLUDE_DIRS ${GLIB_INCLUDE_DIR} ${GLIBCONFIG_INCLUDE_DIR})
# Version detection
if (EXISTS "${GLIBCONFIG_INCLUDE_DIR}/glibconfig.h")
file(READ "${GLIBCONFIG_INCLUDE_DIR}/glibconfig.h" GLIBCONFIG_H_CONTENTS)
string(REGEX MATCH "#define GLIB_MAJOR_VERSION ([0-9]+)" _dummy "${GLIBCONFIG_H_CONTENTS}")
set(GLIB_VERSION_MAJOR "${CMAKE_MATCH_1}")
string(REGEX MATCH "#define GLIB_MINOR_VERSION ([0-9]+)" _dummy "${GLIBCONFIG_H_CONTENTS}")
set(GLIB_VERSION_MINOR "${CMAKE_MATCH_1}")
string(REGEX MATCH "#define GLIB_MICRO_VERSION ([0-9]+)" _dummy "${GLIBCONFIG_H_CONTENTS}")
set(GLIB_VERSION_MICRO "${CMAKE_MATCH_1}")
set(GLIB_VERSION "${GLIB_VERSION_MAJOR}.${GLIB_VERSION_MINOR}.${GLIB_VERSION_MICRO}")
endif ()
# Additional Glib components. We only look for libraries, as not all of them
# have corresponding headers and all headers are installed alongside the main
# glib ones.
foreach (_component ${GLIB_FIND_COMPONENTS})
if (${_component} STREQUAL "gio")
find_library(GLIB_GIO_LIBRARIES NAMES gio-2.0 HINTS ${_GLIB_LIBRARY_DIR})
set(ADDITIONAL_REQUIRED_VARS ${ADDITIONAL_REQUIRED_VARS} GLIB_GIO_LIBRARIES)
elseif (${_component} STREQUAL "gobject")
find_library(GLIB_GOBJECT_LIBRARIES NAMES gobject-2.0 HINTS ${_GLIB_LIBRARY_DIR})
set(ADDITIONAL_REQUIRED_VARS ${ADDITIONAL_REQUIRED_VARS} GLIB_GOBJECT_LIBRARIES)
elseif (${_component} STREQUAL "gmodule")
find_library(GLIB_GMODULE_LIBRARIES NAMES gmodule-2.0 HINTS ${_GLIB_LIBRARY_DIR})
set(ADDITIONAL_REQUIRED_VARS ${ADDITIONAL_REQUIRED_VARS} GLIB_GMODULE_LIBRARIES)
elseif (${_component} STREQUAL "gthread")
find_library(GLIB_GTHREAD_LIBRARIES NAMES gthread-2.0 HINTS ${_GLIB_LIBRARY_DIR})
set(ADDITIONAL_REQUIRED_VARS ${ADDITIONAL_REQUIRED_VARS} GLIB_GTHREAD_LIBRARIES)
elseif (${_component} STREQUAL "gio-unix")
# gio-unix is compiled as part of the gio library, but the include paths
# are separate from the shared glib ones. Since this is currently only used
# by WebKitGTK we don't go to extraordinary measures beyond pkg-config.
pkg_check_modules(GIO_UNIX QUIET gio-unix-2.0)
endif ()
endforeach ()
include(FindPackageHandleStandardArgs)
FIND_PACKAGE_HANDLE_STANDARD_ARGS(GLIB REQUIRED_VARS GLIB_INCLUDE_DIRS GLIB_LIBRARIES ${ADDITIONAL_REQUIRED_VARS}
VERSION_VAR GLIB_VERSION)
mark_as_advanced(
GLIBCONFIG_INCLUDE_DIR
GLIB_GIO_LIBRARIES
GLIB_GIO_UNIX_LIBRARIES
GLIB_GMODULE_LIBRARIES
GLIB_GOBJECT_LIBRARIES
GLIB_GTHREAD_LIBRARIES
GLIB_INCLUDE_DIR
GLIB_INCLUDE_DIRS
GLIB_LIBRARIES
)

View File

@ -0,0 +1,35 @@
if (NOT TARGET LibNice::LibNice)
find_package(PkgConfig)
pkg_check_modules(PC_LIBNICE nice)
set(LIBNICE_DEFINITIONS ${PC_LIBNICE_CFLAGS_OTHER})
find_path(LIBNICE_INCLUDE_DIR nice/agent.h
HINTS ${PC_LIBNICE_INCLUDEDIR} ${PC_LIBNICE_INCLUDE_DIRS}
PATH_SUFFICES libnice)
find_library(LIBNICE_LIBRARY NAMES nice libnice
HINTS ${PC_LIBNICE_LIBDIR} ${PC_LIBNICE_LIBRARY_DIRS})
include(FindPackageHandleStandardArgs)
find_package_handle_standard_args(Libnice DEFAULT_MSG
LIBNICE_LIBRARY LIBNICE_INCLUDE_DIR)
mark_as_advanced(LIBNICE_INCLUDE_DIR LIBNICE_LIBRARY)
set(LIBNICE_LIBRARIES ${LIBNICE_LIBRARY})
set(LIBNICE_INCLUDE_DIRS ${LIBNICE_INCLUDE_DIR})
find_package(GLIB REQUIRED COMPONENTS gio gobject gmodule gthread)
list(APPEND LIBNICE_INCLUDE_DIRS ${GLIB_INCLUDE_DIRS})
list(APPEND LIBNICE_LIBRARIES ${GLIB_GOBJECT_LIBRARIES} ${GLIB_LIBRARIES})
if (LIBNICE_FOUND)
add_library(LibNice::LibNice UNKNOWN IMPORTED)
set_target_properties(LibNice::LibNice PROPERTIES
IMPORTED_LOCATION "${LIBNICE_LIBRARY}"
INTERFACE_COMPILE_DEFINITIONS "_REENTRANT"
INTERFACE_INCLUDE_DIRECTORIES "${LIBNICE_INCLUDE_DIRS}"
INTERFACE_LINK_LIBRARIES "${LIBNICE_LIBRARIES}"
IMPORTED_LINK_INTERFACE_LANGUAGES "C")
endif ()
endif ()

View File

@ -29,6 +29,10 @@ class Candidate {
public: public:
Candidate(string candidate, string mid = ""); Candidate(string candidate, string mid = "");
enum class ResolveMode { Simple, Lookup };
bool resolve(ResolveMode mode = ResolveMode::Simple);
bool isResolved() const;
string candidate() const; string candidate() const;
string mid() const; string mid() const;
operator string() const; operator string() const;
@ -36,6 +40,7 @@ public:
private: private:
string mCandidate; string mCandidate;
string mMid; string mMid;
bool mIsResolved;
}; };
} // namespace rtc } // namespace rtc

View File

@ -21,6 +21,7 @@
#include "include.hpp" #include "include.hpp"
#include <atomic>
#include <functional> #include <functional>
#include <variant> #include <variant>
@ -28,32 +29,49 @@ namespace rtc {
class Channel { class Channel {
public: public:
virtual void close(void) = 0; virtual void close() = 0;
virtual void send(const std::variant<binary, string> &data) = 0; virtual bool send(const std::variant<binary, string> &data) = 0; // returns false if buffered
virtual std::optional<std::variant<binary, string>> receive() = 0; // only if onMessage unset
virtual bool isOpen(void) const = 0; virtual bool isOpen() const = 0;
virtual bool isClosed(void) const = 0; virtual bool isClosed() const = 0;
virtual size_t availableAmount() const; // total size available to receive
virtual size_t bufferedAmount() const; // total size buffered to send
void onOpen(std::function<void()> callback); void onOpen(std::function<void()> callback);
void onClosed(std::function<void()> callback); void onClosed(std::function<void()> callback);
void onError(std::function<void(const string &error)> callback); void onError(std::function<void(const string &error)> callback);
void onMessage(std::function<void(const std::variant<binary, string> &data)> callback); void onMessage(std::function<void(const std::variant<binary, string> &data)> callback);
void onMessage(std::function<void(const binary &data)> binaryCallback, void onMessage(std::function<void(const binary &data)> binaryCallback,
std::function<void(const string &data)> stringCallback); std::function<void(const string &data)> stringCallback);
void onAvailable(std::function<void()> callback);
void onBufferedAmountLow(std::function<void()> callback);
void setBufferedAmountLowThreshold(size_t amount);
protected: protected:
virtual void triggerOpen(void); virtual void triggerOpen();
virtual void triggerClosed(void); virtual void triggerClosed();
virtual void triggerError(const string &error); virtual void triggerError(const string &error);
virtual void triggerMessage(const std::variant<binary, string> &data); virtual void triggerAvailable(size_t count);
virtual void triggerBufferedAmount(size_t amount);
private: private:
std::function<void()> mOpenCallback; synchronized_callback<> mOpenCallback;
std::function<void()> mClosedCallback; synchronized_callback<> mClosedCallback;
std::function<void(const string &)> mErrorCallback; synchronized_callback<const string &> mErrorCallback;
std::function<void(const std::variant<binary, string> &)> mMessageCallback; synchronized_callback<const std::variant<binary, string> &> mMessageCallback;
synchronized_callback<> mAvailableCallback;
synchronized_callback<> mBufferedAmountLowCallback;
std::atomic<size_t> mBufferedAmount = 0;
std::atomic<size_t> mBufferedAmountLowThreshold = 0;
}; };
} // namespace rtc } // namespace rtc
#endif // RTC_CHANNEL_H #endif // RTC_CHANNEL_H

View File

@ -22,10 +22,13 @@
#include "channel.hpp" #include "channel.hpp"
#include "include.hpp" #include "include.hpp"
#include "message.hpp" #include "message.hpp"
#include "queue.hpp"
#include "reliability.hpp" #include "reliability.hpp"
#include <atomic>
#include <chrono> #include <chrono>
#include <functional> #include <functional>
#include <type_traits>
#include <variant> #include <variant>
namespace rtc { namespace rtc {
@ -35,39 +38,84 @@ class PeerConnection;
class DataChannel : public Channel { class DataChannel : public Channel {
public: public:
DataChannel(unsigned int stream_, string label_, string protocol_, Reliability reliability_); DataChannel(std::shared_ptr<PeerConnection> pc, unsigned int stream, string label,
DataChannel(unsigned int stream, std::shared_ptr<SctpTransport> sctpTransport); string protocol, Reliability reliability);
DataChannel(std::shared_ptr<PeerConnection> pc, std::shared_ptr<SctpTransport> transport,
unsigned int stream);
~DataChannel(); ~DataChannel();
void close(void); void close(void) override;
void send(const std::variant<binary, string> &data);
void send(const byte *data, size_t size); bool send(const std::variant<binary, string> &data) override;
bool send(const byte *data, size_t size);
template <typename Buffer> bool sendBuffer(const Buffer &buf);
template <typename Iterator> bool sendBuffer(Iterator first, Iterator last);
std::optional<std::variant<binary, string>> receive() override;
bool isOpen(void) const override;
bool isClosed(void) const override;
size_t availableAmount() const override;
size_t maxMessageSize() const; // maximum message size in a call to send or sendBuffer
unsigned int stream() const; unsigned int stream() const;
string label() const; string label() const;
string protocol() const; string protocol() const;
Reliability reliability() const; Reliability reliability() const;
bool isOpen(void) const;
bool isClosed(void) const;
private: private:
void open(std::shared_ptr<SctpTransport> sctpTransport); void open(std::shared_ptr<SctpTransport> sctpTransport);
bool outgoing(mutable_message_ptr message);
void incoming(message_ptr message); void incoming(message_ptr message);
void processOpenMessage(message_ptr message); void processOpenMessage(message_ptr message);
const unsigned int mStream; std::shared_ptr<PeerConnection> mPeerConnection;
std::shared_ptr<SctpTransport> mSctpTransport; std::shared_ptr<SctpTransport> mSctpTransport;
unsigned int mStream;
string mLabel; string mLabel;
string mProtocol; string mProtocol;
std::shared_ptr<Reliability> mReliability; std::shared_ptr<Reliability> mReliability;
bool mIsOpen = false; std::atomic<bool> mIsOpen = false;
bool mIsClosed = false; std::atomic<bool> mIsClosed = false;
Queue<message_ptr> mRecvQueue;
std::atomic<size_t> mRecvAmount = 0;
friend class PeerConnection; friend class PeerConnection;
}; };
template <typename Buffer> std::pair<const byte *, size_t> to_bytes(const Buffer &buf) {
using T = typename std::remove_pointer<decltype(buf.data())>::type;
using E = typename std::conditional<std::is_void<T>::value, byte, T>::type;
return std::make_pair(static_cast<const byte *>(static_cast<const void *>(buf.data())),
buf.size() * sizeof(E));
}
template <typename Buffer> bool DataChannel::sendBuffer(const Buffer &buf) {
auto [bytes, size] = to_bytes(buf);
auto message = std::make_shared<Message>(size);
std::copy(bytes, bytes + size, message->data());
return outgoing(message);
}
template <typename Iterator> bool DataChannel::sendBuffer(Iterator first, Iterator last) {
size_t size = 0;
for (Iterator it = first; it != last; ++it)
size += it->size();
auto message = std::make_shared<Message>(size);
auto pos = message->begin();
for (Iterator it = first; it != last; ++it) {
auto [bytes, size] = to_bytes(*it);
pos = std::copy(bytes, bytes + size, pos);
}
return outgoing(message);
}
} // namespace rtc } // namespace rtc
#endif #endif

View File

@ -44,11 +44,15 @@ public:
string mid() const; string mid() const;
std::optional<string> fingerprint() const; std::optional<string> fingerprint() const;
std::optional<uint16_t> sctpPort() const; std::optional<uint16_t> sctpPort() const;
std::optional<size_t> maxMessageSize() const;
void setFingerprint(string fingerprint); void setFingerprint(string fingerprint);
void setSctpPort(uint16_t port); void setSctpPort(uint16_t port);
void setMaxMessageSize(size_t size);
void addCandidate(Candidate candidate); void addCandidate(Candidate candidate);
void endCandidates(); void endCandidates();
std::vector<Candidate> extractCandidates();
operator string() const; operator string() const;
@ -60,6 +64,7 @@ private:
string mIceUfrag, mIcePwd; string mIceUfrag, mIcePwd;
std::optional<string> mFingerprint; std::optional<string> mFingerprint;
std::optional<uint16_t> mSctpPort; std::optional<uint16_t> mSctpPort;
std::optional<size_t> mMaxMessageSize;
std::vector<Candidate> mCandidates; std::vector<Candidate> mCandidates;
bool mTrickle; bool mTrickle;

View File

@ -20,8 +20,11 @@
#define RTC_INCLUDE_H #define RTC_INCLUDE_H
#include <cstddef> #include <cstddef>
#include <functional>
#include <memory> #include <memory>
#include <mutex>
#include <optional> #include <optional>
#include <string>
#include <vector> #include <vector>
namespace rtc { namespace rtc {
@ -32,6 +35,7 @@ using binary = std::vector<byte>;
using std::nullopt; using std::nullopt;
using std::size_t;
using std::uint16_t; using std::uint16_t;
using std::uint32_t; using std::uint32_t;
using std::uint64_t; using std::uint64_t;
@ -41,9 +45,35 @@ const size_t MAX_NUMERICNODE_LEN = 48; // Max IPv6 string representation length
const size_t MAX_NUMERICSERV_LEN = 6; // Max port string representation length const size_t MAX_NUMERICSERV_LEN = 6; // Max port string representation length
const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default const uint16_t DEFAULT_SCTP_PORT = 5000; // SCTP port to use by default
const size_t DEFAULT_MAX_MESSAGE_SIZE = 65536; // Remote max message size if not specified in SDP
const size_t LOCAL_MAX_MESSAGE_SIZE = 256 * 1024; // Local max message size
template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; }; template <class... Ts> struct overloaded : Ts... { using Ts::operator()...; };
template <class... Ts> overloaded(Ts...)->overloaded<Ts...>; template <class... Ts> overloaded(Ts...)->overloaded<Ts...>;
template <typename... P> class synchronized_callback {
public:
synchronized_callback() = default;
~synchronized_callback() { *this = nullptr; }
synchronized_callback &operator=(std::function<void(P...)> func) {
std::lock_guard<std::recursive_mutex> lock(mutex);
callback = func;
return *this;
}
void operator()(P... args) const {
std::lock_guard<std::recursive_mutex> lock(mutex);
if (callback)
callback(args...);
}
operator bool() const { return callback ? true : false; }
private:
std::function<void(P...)> callback;
mutable std::recursive_mutex mutex;
};
} }
#endif #endif

View File

@ -30,24 +30,33 @@ namespace rtc {
struct Message : binary { struct Message : binary {
enum Type { Binary, String, Control }; enum Type { Binary, String, Control };
Message(size_t size) : binary(size), type(Binary) {}
template <typename Iterator> template <typename Iterator>
Message(Iterator begin_, Iterator end_, Type type_ = Binary, unsigned int stream_ = 0, Message(Iterator begin_, Iterator end_, Type type_ = Binary)
std::shared_ptr<Reliability> reliability_ = nullptr) : binary(begin_, end_), type(type_) {}
: binary(begin_, end_), type(type_), stream(stream_), reliability(reliability_) {}
Type type; Type type;
unsigned int stream; unsigned int stream = 0;
std::shared_ptr<Reliability> reliability; std::shared_ptr<Reliability> reliability;
}; };
using message_ptr = std::shared_ptr<const Message>; using message_ptr = std::shared_ptr<const Message>;
using mutable_message_ptr = std::shared_ptr<Message>;
using message_callback = std::function<void(message_ptr message)>; using message_callback = std::function<void(message_ptr message)>;
constexpr auto message_size_func = [](const message_ptr &m) -> size_t {
return m->type != Message::Control ? m->size() : 0;
};
template <typename Iterator> template <typename Iterator>
message_ptr make_message(Iterator begin, Iterator end, Message::Type type = Message::Binary, message_ptr make_message(Iterator begin, Iterator end, Message::Type type = Message::Binary,
unsigned int stream = 0, unsigned int stream = 0,
std::shared_ptr<Reliability> reliability = nullptr) { std::shared_ptr<Reliability> reliability = nullptr) {
return std::make_shared<Message>(begin, end, type, stream, reliability); auto message = std::make_shared<Message>(begin, end, type);
message->stream = stream;
message->reliability = reliability;
return message;
} }
} // namespace rtc } // namespace rtc

View File

@ -30,6 +30,8 @@
#include <atomic> #include <atomic>
#include <functional> #include <functional>
#include <list>
#include <thread>
#include <unordered_map> #include <unordered_map>
namespace rtc { namespace rtc {
@ -39,7 +41,7 @@ class IceTransport;
class DtlsTransport; class DtlsTransport;
class SctpTransport; class SctpTransport;
class PeerConnection { class PeerConnection : public std::enable_shared_from_this<PeerConnection> {
public: public:
enum class State : int { enum class State : int {
New = RTC_NEW, New = RTC_NEW,
@ -65,6 +67,8 @@ public:
GatheringState gatheringState() const; GatheringState gatheringState() const;
std::optional<Description> localDescription() const; std::optional<Description> localDescription() const;
std::optional<Description> remoteDescription() const; std::optional<Description> remoteDescription() const;
std::optional<string> localAddress() const;
std::optional<string> remoteAddress() const;
void setRemoteDescription(Description description); void setRemoteDescription(Description description);
void addRemoteCandidate(Candidate candidate); void addRemoteCandidate(Candidate candidate);
@ -85,13 +89,14 @@ private:
bool checkFingerprint(const std::string &fingerprint) const; bool checkFingerprint(const std::string &fingerprint) const;
void forwardMessage(message_ptr message); void forwardMessage(message_ptr message);
void forwardBufferedAmount(uint16_t stream, size_t amount);
void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func); void iterateDataChannels(std::function<void(std::shared_ptr<DataChannel> channel)> func);
void openDataChannels(); void openDataChannels();
void closeDataChannels(); void closeDataChannels();
void processLocalDescription(Description description); void processLocalDescription(Description description);
void processLocalCandidate(Candidate candidate); void processLocalCandidate(Candidate candidate);
void triggerDataChannel(std::shared_ptr<DataChannel> dataChannel); void triggerDataChannel(std::weak_ptr<DataChannel> weakDataChannel);
void changeState(State state); void changeState(State state);
void changeGatheringState(GatheringState state); void changeGatheringState(GatheringState state);
@ -110,11 +115,11 @@ private:
std::atomic<State> mState; std::atomic<State> mState;
std::atomic<GatheringState> mGatheringState; std::atomic<GatheringState> mGatheringState;
std::function<void(std::shared_ptr<DataChannel> dataChannel)> mDataChannelCallback; synchronized_callback<std::shared_ptr<DataChannel>> mDataChannelCallback;
std::function<void(const Description &description)> mLocalDescriptionCallback; synchronized_callback<const Description &> mLocalDescriptionCallback;
std::function<void(const Candidate &candidate)> mLocalCandidateCallback; synchronized_callback<const Candidate &> mLocalCandidateCallback;
std::function<void(State state)> mStateChangeCallback; synchronized_callback<State> mStateChangeCallback;
std::function<void(GatheringState state)> mGatheringStateChangeCallback; synchronized_callback<GatheringState> mGatheringStateChangeCallback;
}; };
} // namespace rtc } // namespace rtc

138
include/rtc/queue.hpp Normal file
View File

@ -0,0 +1,138 @@
/**
* Copyright (c) 2019 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#ifndef RTC_QUEUE_H
#define RTC_QUEUE_H
#include "include.hpp"
#include <atomic>
#include <chrono>
#include <condition_variable>
#include <mutex>
#include <optional>
#include <queue>
namespace rtc {
template <typename T> class Queue {
public:
using amount_function = std::function<size_t(const T &element)>;
Queue(size_t limit = 0, amount_function func = nullptr);
~Queue();
void stop();
bool empty() const;
size_t size() const; // elements
size_t amount() const; // amount
void push(const T &element);
void push(T &&element);
std::optional<T> pop();
std::optional<T> peek();
void wait();
void wait(const std::chrono::milliseconds &duration);
private:
const size_t mLimit;
size_t mAmount;
std::queue<T> mQueue;
std::condition_variable mPopCondition, mPushCondition;
amount_function mAmountFunction;
bool mStopping = false;
mutable std::mutex mMutex;
};
template <typename T>
Queue<T>::Queue(size_t limit, amount_function func) : mLimit(limit), mAmount(0) {
mAmountFunction = func ? func : [](const T &element) -> size_t { return 1; };
}
template <typename T> Queue<T>::~Queue() { stop(); }
template <typename T> void Queue<T>::stop() {
std::lock_guard<std::mutex> lock(mMutex);
mStopping = true;
mPopCondition.notify_all();
mPushCondition.notify_all();
}
template <typename T> bool Queue<T>::empty() const {
std::lock_guard<std::mutex> lock(mMutex);
return mQueue.empty();
}
template <typename T> size_t Queue<T>::size() const {
std::lock_guard<std::mutex> lock(mMutex);
return mQueue.size();
}
template <typename T> size_t Queue<T>::amount() const {
std::lock_guard<std::mutex> lock(mMutex);
return mAmount;
}
template <typename T> void Queue<T>::push(const T &element) { push(T{element}); }
template <typename T> void Queue<T>::push(T &&element) {
std::unique_lock<std::mutex> lock(mMutex);
mPushCondition.wait(lock, [this]() { return !mLimit || mQueue.size() < mLimit || mStopping; });
if (!mStopping) {
mAmount += mAmountFunction(element);
mQueue.emplace(std::move(element));
mPopCondition.notify_one();
}
}
template <typename T> std::optional<T> Queue<T>::pop() {
std::unique_lock<std::mutex> lock(mMutex);
mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
if (!mQueue.empty()) {
mAmount -= mAmountFunction(mQueue.front());
std::optional<T> element{std::move(mQueue.front())};
mQueue.pop();
return element;
} else {
return nullopt;
}
}
template <typename T> std::optional<T> Queue<T>::peek() {
std::unique_lock<std::mutex> lock(mMutex);
if (!mQueue.empty()) {
return std::optional<T>{mQueue.front()};
} else {
return nullopt;
}
}
template <typename T> void Queue<T>::wait() {
std::unique_lock<std::mutex> lock(mMutex);
mPopCondition.wait(lock, [this]() { return !mQueue.empty() || mStopping; });
}
template <typename T> void Queue<T>::wait(const std::chrono::milliseconds &duration) {
std::unique_lock<std::mutex> lock(mMutex);
mPopCondition.wait_for(lock, duration, [this]() { return !mQueue.empty() || mStopping; });
}
} // namespace rtc
#endif

View File

@ -40,7 +40,7 @@ inline bool hasprefix(const string &str, const string &prefix) {
namespace rtc { namespace rtc {
Candidate::Candidate(string candidate, string mid) { Candidate::Candidate(string candidate, string mid) : mIsResolved(false) {
const std::array prefixes{"a=", "candidate:"}; const std::array prefixes{"a=", "candidate:"};
for (string prefix : prefixes) for (string prefix : prefixes)
if (hasprefix(candidate, prefix)) if (hasprefix(candidate, prefix))
@ -48,6 +48,11 @@ Candidate::Candidate(string candidate, string mid) {
mCandidate = std::move(candidate); mCandidate = std::move(candidate);
mMid = std::move(mid); mMid = std::move(mid);
}
bool Candidate::resolve(ResolveMode mode) {
if (mIsResolved)
return true;
// See RFC 5245 for format // See RFC 5245 for format
std::stringstream ss(mCandidate); std::stringstream ss(mCandidate);
@ -64,6 +69,10 @@ Candidate::Candidate(string candidate, string mid) {
hints.ai_socktype = SOCK_DGRAM; hints.ai_socktype = SOCK_DGRAM;
hints.ai_protocol = IPPROTO_UDP; hints.ai_protocol = IPPROTO_UDP;
} }
if (mode == ResolveMode::Simple)
hints.ai_flags |= AI_NUMERICHOST;
struct addrinfo *result = nullptr; struct addrinfo *result = nullptr;
if (getaddrinfo(node.c_str(), service.c_str(), &hints, &result) == 0) { if (getaddrinfo(node.c_str(), service.c_str(), &hints, &result) == 0) {
for (auto p = result; p; p = p->ai_next) for (auto p = result; p; p = p->ai_next)
@ -83,15 +92,19 @@ Candidate::Candidate(string candidate, string mid) {
if (!left.empty()) if (!left.empty())
ss << left; ss << left;
mCandidate = ss.str(); mCandidate = ss.str();
break; return mIsResolved = true;
} }
} }
} }
freeaddrinfo(result); freeaddrinfo(result);
} }
return false;
} }
bool Candidate::isResolved() const { return mIsResolved; }
string Candidate::candidate() const { return "candidate:" + mCandidate; } string Candidate::candidate() const { return "candidate:" + mCandidate; }
string Candidate::mid() const { return mMid; } string Candidate::mid() const { return mMid; }

View File

@ -25,10 +25,13 @@
#include <sstream> #include <sstream>
#include <unordered_map> #include <unordered_map>
#include <gnutls/crypto.h>
using std::shared_ptr; using std::shared_ptr;
using std::string; using std::string;
using std::unique_ptr;
#if USE_GNUTLS
#include <gnutls/crypto.h>
namespace { namespace {
@ -117,10 +120,10 @@ Certificate::Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey)
"Unable to set certificate and key pair in credentials"); "Unable to set certificate and key pair in credentials");
} }
string Certificate::fingerprint() const { return mFingerprint; }
gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; } gnutls_certificate_credentials_t Certificate::credentials() const { return *mCredentials; }
string Certificate::fingerprint() const { return mFingerprint; }
string make_fingerprint(gnutls_x509_crt_t crt) { string make_fingerprint(gnutls_x509_crt_t crt) {
const size_t size = 32; const size_t size = 32;
unsigned char buffer[size]; unsigned char buffer[size];
@ -177,3 +180,120 @@ shared_ptr<Certificate> make_certificate(const string &commonName) {
} }
} // namespace rtc } // namespace rtc
#else
#include <openssl/err.h>
#include <openssl/pem.h>
#include <openssl/ssl.h>
namespace rtc {
Certificate::Certificate(string crt_pem, string key_pem) {
BIO *bio;
bio = BIO_new(BIO_s_mem());
BIO_write(bio, crt_pem.data(), crt_pem.size());
mX509 = shared_ptr<X509>(PEM_read_bio_X509(bio, nullptr, 0, 0), X509_free);
BIO_free(bio);
if (!mX509)
throw std::invalid_argument("Unable to import certificate PEM");
bio = BIO_new(BIO_s_mem());
BIO_write(bio, key_pem.data(), key_pem.size());
mPKey = shared_ptr<EVP_PKEY>(PEM_read_bio_PrivateKey(bio, nullptr, 0, 0), EVP_PKEY_free);
BIO_free(bio);
if (!mPKey)
throw std::invalid_argument("Unable to import PEM key PEM");
mFingerprint = make_fingerprint(mX509.get());
}
Certificate::Certificate(shared_ptr<X509> x509, shared_ptr<EVP_PKEY> pkey) :
mX509(std::move(x509)), mPKey(std::move(pkey))
{
mFingerprint = make_fingerprint(mX509.get());
}
string Certificate::fingerprint() const { return mFingerprint; }
std::tuple<X509 *, EVP_PKEY *> Certificate::credentials() const { return {mX509.get(), mPKey.get()}; }
string make_fingerprint(X509 *x509) {
const size_t size = 32;
unsigned char buffer[size];
unsigned int len = size;
if (!X509_digest(x509, EVP_sha256(), buffer, &len))
throw std::runtime_error("X509 fingerprint error");
std::ostringstream oss;
oss << std::hex << std::uppercase << std::setfill('0');
for (size_t i = 0; i < len; ++i) {
if (i)
oss << std::setw(1) << ':';
oss << std::setw(2) << unsigned(buffer[i]);
}
return oss.str();
}
shared_ptr<Certificate> make_certificate(const string &commonName) {
static std::unordered_map<string, shared_ptr<Certificate>> cache;
static std::mutex cacheMutex;
std::lock_guard<std::mutex> lock(cacheMutex);
if (auto it = cache.find(commonName); it != cache.end())
return it->second;
if (cache.empty()) {
// This is the first call to OpenSSL
OPENSSL_init_ssl(0, NULL);
SSL_load_error_strings();
ERR_load_crypto_strings();
}
shared_ptr<X509> x509(X509_new(), X509_free);
shared_ptr<EVP_PKEY> pkey(EVP_PKEY_new(), EVP_PKEY_free);
unique_ptr<RSA, decltype(&RSA_free)> rsa(RSA_new(), RSA_free);
unique_ptr<BIGNUM, decltype(&BN_free)> exponent(BN_new(), BN_free);
unique_ptr<BIGNUM, decltype(&BN_free)> serial_number(BN_new(), BN_free);
unique_ptr<X509_NAME, decltype(&X509_NAME_free)> name(X509_NAME_new(), X509_NAME_free);
if (!x509 || !pkey || !rsa || !exponent || !serial_number || !name)
throw std::runtime_error("Unable allocate structures for certificate generation");
const int bits = 4096;
const unsigned int e = 65537; // 2^16 + 1
if (!pkey || !rsa || !exponent || !BN_set_word(exponent.get(), e) ||
!RSA_generate_key_ex(rsa.get(), bits, exponent.get(), NULL) ||
!EVP_PKEY_assign_RSA(pkey.get(), rsa.release())) // the key will be freed when pkey is freed
throw std::runtime_error("Unable to generate key pair");
const size_t serialSize = 16;
const auto *commonNameBytes = reinterpret_cast<const unsigned char *>(commonName.c_str());
if (!X509_gmtime_adj(X509_get_notBefore(x509.get()), 3600 * -1) ||
!X509_gmtime_adj(X509_get_notAfter(x509.get()), 3600 * 24 * 365) ||
!X509_set_version(x509.get(), 1) || !X509_set_pubkey(x509.get(), pkey.get()) ||
!BN_pseudo_rand(serial_number.get(), serialSize, 0, 0) ||
!BN_to_ASN1_INTEGER(serial_number.get(), X509_get_serialNumber(x509.get())) ||
!X509_NAME_add_entry_by_NID(name.get(), NID_commonName, MBSTRING_UTF8, commonNameBytes, -1,
-1, 0) ||
!X509_set_subject_name(x509.get(), name.get()) ||
!X509_set_issuer_name(x509.get(), name.get()))
throw std::runtime_error("Unable to set certificate properties");
if (!X509_sign(x509.get(), pkey.get(), EVP_sha256()))
throw std::runtime_error("Unable to auto-sign certificate");
auto certificate = std::make_shared<Certificate>(x509, pkey);
cache.emplace(std::make_pair(commonName, certificate));
return certificate;
}
} // namespace rtc
#endif

View File

@ -21,24 +21,47 @@
#include "include.hpp" #include "include.hpp"
#include <tuple>
#if USE_GNUTLS
#include <gnutls/x509.h> #include <gnutls/x509.h>
#else
#include <openssl/x509.h>
#endif
namespace rtc { namespace rtc {
class Certificate { class Certificate {
public: public:
Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey);
Certificate(string crt_pem, string key_pem); Certificate(string crt_pem, string key_pem);
string fingerprint() const; #if USE_GNUTLS
Certificate(gnutls_x509_crt_t crt, gnutls_x509_privkey_t privkey);
gnutls_certificate_credentials_t credentials() const; gnutls_certificate_credentials_t credentials() const;
#else
Certificate(std::shared_ptr<X509> x509, std::shared_ptr<EVP_PKEY> pkey);
std::tuple<X509 *, EVP_PKEY *> credentials() const;
#endif
string fingerprint() const;
private: private:
#if USE_GNUTLS
std::shared_ptr<gnutls_certificate_credentials_t> mCredentials; std::shared_ptr<gnutls_certificate_credentials_t> mCredentials;
#else
std::shared_ptr<X509> mX509;
std::shared_ptr<EVP_PKEY> mPKey;
#endif
string mFingerprint; string mFingerprint;
}; };
#if USE_GNUTLS
string make_fingerprint(gnutls_x509_crt_t crt); string make_fingerprint(gnutls_x509_crt_t crt);
#else
string make_fingerprint(X509 *x509);
#endif
std::shared_ptr<Certificate> make_certificate(const string &commonName); std::shared_ptr<Certificate> make_certificate(const string &commonName);
} // namespace rtc } // namespace rtc

View File

@ -18,11 +18,17 @@
#include "channel.hpp" #include "channel.hpp"
namespace {}
namespace rtc { namespace rtc {
void Channel::onOpen(std::function<void()> callback) { mOpenCallback = callback; } void Channel::onOpen(std::function<void()> callback) {
mOpenCallback = callback;
}
void Channel::onClosed(std::function<void()> callback) { mClosedCallback = callback; } void Channel::onClosed(std::function<void()> callback) {
mClosedCallback = callback;
}
void Channel::onError(std::function<void(const string &error)> callback) { void Channel::onError(std::function<void(const string &error)> callback) {
mErrorCallback = callback; mErrorCallback = callback;
@ -30,6 +36,10 @@ void Channel::onError(std::function<void(const string &error)> callback) {
void Channel::onMessage(std::function<void(const std::variant<binary, string> &data)> callback) { void Channel::onMessage(std::function<void(const std::variant<binary, string> &data)> callback) {
mMessageCallback = callback; mMessageCallback = callback;
// Pass pending messages
while (auto message = receive())
mMessageCallback(*message);
} }
void Channel::onMessage(std::function<void(const binary &data)> binaryCallback, void Channel::onMessage(std::function<void(const binary &data)> binaryCallback,
@ -39,24 +49,43 @@ void Channel::onMessage(std::function<void(const binary &data)> binaryCallback,
}); });
} }
void Channel::triggerOpen(void) { void Channel::onAvailable(std::function<void()> callback) {
if (mOpenCallback) mAvailableCallback = callback;
mOpenCallback();
} }
void Channel::triggerClosed(void) { void Channel::onBufferedAmountLow(std::function<void()> callback) {
if (mClosedCallback) mBufferedAmountLowCallback = callback;
mClosedCallback();
} }
void Channel::triggerError(const string &error) { size_t Channel::availableAmount() const { return 0; }
if (mErrorCallback)
mErrorCallback(error); size_t Channel::bufferedAmount() const { return mBufferedAmount; }
void Channel::setBufferedAmountLowThreshold(size_t amount) { mBufferedAmountLowThreshold = amount; }
void Channel::triggerOpen() { mOpenCallback(); }
void Channel::triggerClosed() { mClosedCallback(); }
void Channel::triggerError(const string &error) { mErrorCallback(error); }
void Channel::triggerAvailable(size_t count) {
if (count == 1)
mAvailableCallback();
while (mMessageCallback && count--) {
auto message = receive();
if (!message)
break;
mMessageCallback(*message);
}
} }
void Channel::triggerMessage(const std::variant<binary, string> &data) { void Channel::triggerBufferedAmount(size_t amount) {
if (mMessageCallback) size_t previous = mBufferedAmount.exchange(amount);
mMessageCallback(data); size_t threshold = mBufferedAmountLowThreshold.load();
if (previous > threshold && amount <= threshold)
mBufferedAmountLowCallback();
} }
} // namespace rtc } // namespace rtc

View File

@ -17,6 +17,7 @@
*/ */
#include "datachannel.hpp" #include "datachannel.hpp"
#include "include.hpp"
#include "peerconnection.hpp" #include "peerconnection.hpp"
#include "sctptransport.hpp" #include "sctptransport.hpp"
@ -57,48 +58,88 @@ struct CloseMessage {
}; };
#pragma pack(pop) #pragma pack(pop)
DataChannel::DataChannel(unsigned int stream, string label, string protocol, const size_t RECV_QUEUE_LIMIT = 1024 * 1024; // 1 MiB
Reliability reliability)
: mStream(stream), mLabel(std::move(label)), mProtocol(std::move(protocol)),
mReliability(std::make_shared<Reliability>(std::move(reliability))) {}
DataChannel::DataChannel(unsigned int stream, shared_ptr<SctpTransport> sctpTransport) DataChannel::DataChannel(shared_ptr<PeerConnection> pc, unsigned int stream, string label,
: mStream(stream), mSctpTransport(sctpTransport), string protocol, Reliability reliability)
mReliability(std::make_shared<Reliability>()) {} : mPeerConnection(std::move(pc)), mStream(stream), mLabel(std::move(label)),
mProtocol(std::move(protocol)),
mReliability(std::make_shared<Reliability>(std::move(reliability))),
mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
DataChannel::DataChannel(shared_ptr<PeerConnection> pc, shared_ptr<SctpTransport> transport,
unsigned int stream)
: mPeerConnection(std::move(pc)), mSctpTransport(transport), mStream(stream),
mReliability(std::make_shared<Reliability>()),
mRecvQueue(RECV_QUEUE_LIMIT, message_size_func) {}
DataChannel::~DataChannel() { close(); } DataChannel::~DataChannel() { close(); }
void DataChannel::close() { void DataChannel::close() {
mIsOpen = false; mIsOpen = false;
if (!mIsClosed) { if (!mIsClosed.exchange(true)) {
mIsClosed = true;
if (mSctpTransport) if (mSctpTransport)
mSctpTransport->reset(mStream); mSctpTransport->reset(mStream);
} }
// Reset mSctpTransport first so SctpTransport is never alive without PeerConnection
mSctpTransport.reset();
mPeerConnection.reset();
} }
void DataChannel::send(const std::variant<binary, string> &data) { bool DataChannel::send(const std::variant<binary, string> &data) {
if (mIsClosed || !mSctpTransport) return std::visit(
return; [&](const auto &d) {
std::visit(
[this](const auto &d) {
using T = std::decay_t<decltype(d)>; using T = std::decay_t<decltype(d)>;
constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary; constexpr auto type = std::is_same_v<T, string> ? Message::String : Message::Binary;
auto *b = reinterpret_cast<const byte *>(d.data()); auto *b = reinterpret_cast<const byte *>(d.data());
// Before the ACK has been received on a DataChannel, all messages must be sent ordered return outgoing(std::make_shared<Message>(b, b + d.size(), type));
auto reliability = mIsOpen ? mReliability : nullptr;
mSctpTransport->send(make_message(b, b + d.size(), type, mStream, reliability));
}, },
data); data);
} }
void DataChannel::send(const byte *data, size_t size) { bool DataChannel::send(const byte *data, size_t size) {
if (mIsClosed || !mSctpTransport) return outgoing(std::make_shared<Message>(data, data + size, Message::Binary));
return; }
auto reliability = mIsOpen ? mReliability : nullptr; std::optional<std::variant<binary, string>> DataChannel::receive() {
mSctpTransport->send(make_message(data, data + size, Message::Binary, mStream, reliability)); while (!mRecvQueue.empty()) {
auto message = *mRecvQueue.pop();
switch (message->type) {
case Message::Control: {
auto raw = reinterpret_cast<const uint8_t *>(message->data());
if (raw[0] == MESSAGE_CLOSE) {
if (mIsOpen) {
close();
triggerClosed();
}
}
break;
}
case Message::String:
return std::make_optional(
string(reinterpret_cast<const char *>(message->data()), message->size()));
case Message::Binary:
return std::make_optional(std::move(*message));
}
}
return nullopt;
}
bool DataChannel::isOpen(void) const { return mIsOpen; }
bool DataChannel::isClosed(void) const { return mIsClosed; }
size_t DataChannel::availableAmount() const { return mRecvQueue.amount(); }
size_t DataChannel::maxMessageSize() const {
size_t max = DEFAULT_MAX_MESSAGE_SIZE;
if (auto description = mPeerConnection->remoteDescription())
if (auto maxMessageSize = description->maxMessageSize())
return *maxMessageSize > 0 ? *maxMessageSize : LOCAL_MAX_MESSAGE_SIZE;
return std::min(max, LOCAL_MAX_MESSAGE_SIZE);
} }
unsigned int DataChannel::stream() const { return mStream; } unsigned int DataChannel::stream() const { return mStream; }
@ -109,10 +150,6 @@ string DataChannel::protocol() const { return mProtocol; }
Reliability DataChannel::reliability() const { return *mReliability; } Reliability DataChannel::reliability() const { return *mReliability; }
bool DataChannel::isOpen(void) const { return mIsOpen; }
bool DataChannel::isClosed(void) const { return mIsClosed; }
void DataChannel::open(shared_ptr<SctpTransport> sctpTransport) { void DataChannel::open(shared_ptr<SctpTransport> sctpTransport) {
mSctpTransport = sctpTransport; mSctpTransport = sctpTransport;
@ -144,6 +181,19 @@ void DataChannel::open(shared_ptr<SctpTransport> sctpTransport) {
mSctpTransport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream)); mSctpTransport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
} }
bool DataChannel::outgoing(mutable_message_ptr message) {
if (mIsClosed || !mSctpTransport)
throw std::runtime_error("DataChannel is closed");
if (message->size() > maxMessageSize())
throw std::runtime_error("Message size exceeds limit");
// Before the ACK has been received on a DataChannel, all messages must be sent ordered
message->reliability = mIsOpen ? mReliability : nullptr;
message->stream = mStream;
return mSctpTransport->send(message);
}
void DataChannel::incoming(message_ptr message) { void DataChannel::incoming(message_ptr message) {
switch (message->type) { switch (message->type) {
case Message::Control: { case Message::Control: {
@ -153,16 +203,14 @@ void DataChannel::incoming(message_ptr message) {
processOpenMessage(message); processOpenMessage(message);
break; break;
case MESSAGE_ACK: case MESSAGE_ACK:
if (!mIsOpen) { if (!mIsOpen.exchange(true)) {
mIsOpen = true;
triggerOpen(); triggerOpen();
} }
break; break;
case MESSAGE_CLOSE: case MESSAGE_CLOSE:
if (mIsOpen) { // The close message will be processed in-order in receive()
close(); mRecvQueue.push(message);
triggerClosed(); triggerAvailable(mRecvQueue.size());
}
break; break;
default: default:
// Ignore // Ignore
@ -170,15 +218,15 @@ void DataChannel::incoming(message_ptr message) {
} }
break; break;
} }
case Message::String: { case Message::String:
triggerMessage(string(reinterpret_cast<const char *>(message->data()), message->size())); case Message::Binary:
mRecvQueue.push(message);
triggerAvailable(mRecvQueue.size());
break; break;
} default:
case Message::Binary: { // Ignore
triggerMessage(*message);
break; break;
} }
}
} }
void DataChannel::processOpenMessage(message_ptr message) { void DataChannel::processOpenMessage(message_ptr message) {
@ -220,6 +268,7 @@ void DataChannel::processOpenMessage(message_ptr message) {
mSctpTransport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream)); mSctpTransport->send(make_message(buffer.begin(), buffer.end(), Message::Control, mStream));
mIsOpen = true;
triggerOpen(); triggerOpen();
} }

View File

@ -81,8 +81,10 @@ Description::Description(const string &sdp, Type type, Role role)
mIcePwd = line.substr(line.find(':') + 1); mIcePwd = line.substr(line.find(':') + 1);
} else if (hasprefix(line, "a=sctp-port")) { } else if (hasprefix(line, "a=sctp-port")) {
mSctpPort = uint16_t(std::stoul(line.substr(line.find(':') + 1))); mSctpPort = uint16_t(std::stoul(line.substr(line.find(':') + 1)));
} else if (hasprefix(line, "a=max-message-size")) {
mMaxMessageSize = size_t(std::stoul(line.substr(line.find(':') + 1)));
} else if (hasprefix(line, "a=candidate")) { } else if (hasprefix(line, "a=candidate")) {
mCandidates.emplace_back(Candidate(line.substr(2), mMid)); addCandidate(Candidate(line.substr(2), mMid));
} else if (hasprefix(line, "a=end-of-candidates")) { } else if (hasprefix(line, "a=end-of-candidates")) {
mTrickle = false; mTrickle = false;
} }
@ -103,27 +105,39 @@ std::optional<string> Description::fingerprint() const { return mFingerprint; }
std::optional<uint16_t> Description::sctpPort() const { return mSctpPort; } std::optional<uint16_t> Description::sctpPort() const { return mSctpPort; }
std::optional<size_t> Description::maxMessageSize() const { return mMaxMessageSize; }
void Description::setFingerprint(string fingerprint) { void Description::setFingerprint(string fingerprint) {
mFingerprint.emplace(std::move(fingerprint)); mFingerprint.emplace(std::move(fingerprint));
} }
void Description::setSctpPort(uint16_t port) { mSctpPort.emplace(port); } void Description::setSctpPort(uint16_t port) { mSctpPort.emplace(port); }
void Description::setMaxMessageSize(size_t size) { mMaxMessageSize.emplace(size); }
void Description::addCandidate(Candidate candidate) { void Description::addCandidate(Candidate candidate) {
mCandidates.emplace_back(std::move(candidate)); mCandidates.emplace_back(std::move(candidate));
} }
void Description::endCandidates() { mTrickle = false; } void Description::endCandidates() { mTrickle = false; }
std::vector<Candidate> Description::extractCandidates() {
std::vector<Candidate> result;
std::swap(mCandidates, result);
mTrickle = true;
return result;
}
Description::operator string() const { Description::operator string() const {
if (!mFingerprint) if (!mFingerprint)
throw std::logic_error("Fingerprint must be set to generate a SDP"); throw std::logic_error("Fingerprint must be set to generate a SDP");
std::ostringstream sdp; std::ostringstream sdp;
sdp << "v=0\n"; sdp << "v=0\n";
sdp << "o=- " << mSessionId << " 0 IN IP4 0.0.0.0\n"; sdp << "o=- " << mSessionId << " 0 IN IP4 127.0.0.1\n";
sdp << "s=-\n"; sdp << "s=-\n";
sdp << "t=0 0\n"; sdp << "t=0 0\n";
sdp << "a=group:BUNDLE 0\n";
sdp << "m=application 9 UDP/DTLS/SCTP webrtc-datachannel\n"; sdp << "m=application 9 UDP/DTLS/SCTP webrtc-datachannel\n";
sdp << "c=IN IP4 0.0.0.0\n"; sdp << "c=IN IP4 0.0.0.0\n";
sdp << "a=ice-ufrag:" << mIceUfrag << "\n"; sdp << "a=ice-ufrag:" << mIceUfrag << "\n";
@ -137,7 +151,8 @@ Description::operator string() const {
sdp << "a=fingerprint:sha-256 " << *mFingerprint << "\n"; sdp << "a=fingerprint:sha-256 " << *mFingerprint << "\n";
if (mSctpPort) if (mSctpPort)
sdp << "a=sctp-port:" << *mSctpPort << "\n"; sdp << "a=sctp-port:" << *mSctpPort << "\n";
if (mMaxMessageSize)
sdp << "a=max-message-size:" << *mMaxMessageSize << "\n";
for (const auto &candidate : mCandidates) { for (const auto &candidate : mCandidates) {
sdp << string(candidate) << "\n"; sdp << string(candidate) << "\n";
} }

View File

@ -24,10 +24,14 @@
#include <exception> #include <exception>
#include <iostream> #include <iostream>
#include <gnutls/dtls.h>
using std::shared_ptr; using std::shared_ptr;
using std::string; using std::string;
using std::unique_ptr;
using std::weak_ptr;
#if USE_GNUTLS
#include <gnutls/dtls.h>
namespace { namespace {
@ -44,8 +48,6 @@ static bool check_gnutls(int ret, const string &message = "GnuTLS error") {
namespace rtc { namespace rtc {
using std::shared_ptr;
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate, DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
verifier_callback verifierCallback, verifier_callback verifierCallback,
state_callback stateChangeCallback) state_callback stateChangeCallback)
@ -58,73 +60,110 @@ DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certific
unsigned int flags = GNUTLS_DATAGRAM | (active ? GNUTLS_CLIENT : GNUTLS_SERVER); unsigned int flags = GNUTLS_DATAGRAM | (active ? GNUTLS_CLIENT : GNUTLS_SERVER);
check_gnutls(gnutls_init(&mSession, flags)); check_gnutls(gnutls_init(&mSession, flags));
const char *priorities = "SECURE128:-VERS-SSL3.0:-VERS-TLS1.0:-ARCFOUR-128"; // RFC 8261: SCTP performs segmentation and reassembly based on the path MTU.
// Therefore, the DTLS layer MUST NOT use any compression algorithm.
// See https://tools.ietf.org/html/rfc8261#section-5
const char *priorities = "SECURE128:-VERS-SSL3.0:-ARCFOUR-128:-COMP-ALL";
const char *err_pos = NULL; const char *err_pos = NULL;
check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos), check_gnutls(gnutls_priority_set_direct(mSession, priorities, &err_pos),
"Unable to set TLS priorities"); "Unable to set TLS priorities");
check_gnutls(
gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCertificate->credentials()));
gnutls_dtls_set_mtu(mSession, 1280 - 40 - 8); // min MTU over UDP/IPv6 (only for handshake)
gnutls_dtls_set_timeouts(mSession, 400, 60000);
gnutls_handshake_set_timeout(mSession, 60000);
gnutls_session_set_ptr(mSession, this); gnutls_session_set_ptr(mSession, this);
gnutls_transport_set_ptr(mSession, this); gnutls_transport_set_ptr(mSession, this);
gnutls_transport_set_push_function(mSession, WriteCallback); gnutls_transport_set_push_function(mSession, WriteCallback);
gnutls_transport_set_pull_function(mSession, ReadCallback); gnutls_transport_set_pull_function(mSession, ReadCallback);
gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback); gnutls_transport_set_pull_timeout_function(mSession, TimeoutCallback);
check_gnutls(
gnutls_credentials_set(mSession, GNUTLS_CRD_CERTIFICATE, mCertificate->credentials()));
mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this); mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
} }
DtlsTransport::~DtlsTransport() { DtlsTransport::~DtlsTransport() {
mIncomingQueue.stop();
if (mRecvThread.joinable())
mRecvThread.join();
gnutls_bye(mSession, GNUTLS_SHUT_RDWR); gnutls_bye(mSession, GNUTLS_SHUT_RDWR);
gnutls_deinit(mSession); gnutls_deinit(mSession);
} }
DtlsTransport::State DtlsTransport::state() const { return mState; } DtlsTransport::State DtlsTransport::state() const { return mState; }
void DtlsTransport::stop() {
Transport::stop();
mIncomingQueue.stop();
mRecvThread.join();
}
bool DtlsTransport::send(message_ptr message) { bool DtlsTransport::send(message_ptr message) {
if (!message) if (!message || mState != State::Connected)
return false; return false;
while (true) { ssize_t ret;
ssize_t ret = gnutls_record_send(mSession, message->data(), message->size()); do {
if (check_gnutls(ret)) { ret = gnutls_record_send(mSession, message->data(), message->size());
return ret > 0; } while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
}
} if (ret == GNUTLS_E_LARGE_PACKET)
return false;
return check_gnutls(ret);
} }
void DtlsTransport::incoming(message_ptr message) { mIncomingQueue.push(message); } void DtlsTransport::incoming(message_ptr message) { mIncomingQueue.push(message); }
void DtlsTransport::changeState(State state) { void DtlsTransport::changeState(State state) {
mState = state; if (mState.exchange(state) != state)
mStateChangeCallback(state); mStateChangeCallback(state);
} }
void DtlsTransport::runRecvLoop() { void DtlsTransport::runRecvLoop() {
const size_t maxMtu = 4096;
// Handshake loop
try { try {
changeState(State::Connecting); changeState(State::Connecting);
while (!check_gnutls(gnutls_handshake(mSession), "TLS handshake failed")) { int ret;
} do {
ret = gnutls_handshake(mSession);
if (ret == GNUTLS_E_LARGE_PACKET)
throw std::runtime_error("MTU is too low");
} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN ||
!check_gnutls(ret, "TLS handshake failed"));
// RFC 8261: DTLS MUST support sending messages larger than the current path MTU
// See https://tools.ietf.org/html/rfc8261#section-5
gnutls_dtls_set_mtu(mSession, maxMtu + 1);
} catch (const std::exception &e) { } catch (const std::exception &e) {
std::cerr << "DTLS handshake: " << e.what() << std::endl; std::cerr << "DTLS handshake: " << e.what() << std::endl;
changeState(State::Failed); changeState(State::Failed);
return; return;
} }
// Receive loop
try { try {
changeState(State::Connected); changeState(State::Connected);
const size_t bufferSize = 2048; const size_t bufferSize = maxMtu;
char buffer[bufferSize]; char buffer[bufferSize];
while (true) { while (true) {
ssize_t ret = gnutls_record_recv(mSession, buffer, bufferSize); ssize_t ret;
do {
ret = gnutls_record_recv(mSession, buffer, bufferSize);
} while (ret == GNUTLS_E_INTERRUPTED || ret == GNUTLS_E_AGAIN);
// Consider premature termination as remote closing
if (ret == GNUTLS_E_PREMATURE_TERMINATION)
break;
if (check_gnutls(ret)) { if (check_gnutls(ret)) {
if (ret == 0) { if (ret == 0) {
// Closed // Closed
@ -198,7 +237,245 @@ ssize_t DtlsTransport::ReadCallback(gnutls_transport_ptr_t ptr, void *data, size
} }
int DtlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) { int DtlsTransport::TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms) {
return 1; // So ReadCallback is called DtlsTransport *t = static_cast<DtlsTransport *>(ptr);
if (ms != GNUTLS_INDEFINITE_TIMEOUT)
t->mIncomingQueue.wait(std::chrono::milliseconds(ms));
else
t->mIncomingQueue.wait();
return !t->mIncomingQueue.empty() ? 1 : 0;
} }
} // namespace rtc } // namespace rtc
#else
#include <openssl/bio.h>
#include <openssl/ec.h>
#include <openssl/err.h>
#include <openssl/ssl.h>
namespace {
const int BIO_EOF = -1;
string openssl_error_string(unsigned long err) {
const size_t bufferSize = 256;
char buffer[bufferSize];
ERR_error_string_n(err, buffer, bufferSize);
return string(buffer);
}
bool check_openssl(int success, const string &message = "OpenSSL error") {
if (success)
return true;
else
throw std::runtime_error(message + ": " + openssl_error_string(ERR_get_error()));
}
bool check_openssl_ret(SSL *ssl, int ret, const string &message = "OpenSSL error") {
if (ret == BIO_EOF)
return true;
unsigned long err = SSL_get_error(ssl, ret);
if (err == SSL_ERROR_NONE || err == SSL_ERROR_WANT_READ || err == SSL_ERROR_WANT_WRITE)
return true;
else if (err == SSL_ERROR_ZERO_RETURN)
return false;
else
throw std::runtime_error(message + ": " + openssl_error_string(err));
}
} // namespace
namespace rtc {
int DtlsTransport::TransportExIndex = -1;
std::mutex DtlsTransport::GlobalMutex;
void DtlsTransport::GlobalInit() {
std::lock_guard<std::mutex> lock(GlobalMutex);
if (TransportExIndex < 0) {
TransportExIndex = SSL_get_ex_new_index(0, NULL, NULL, NULL, NULL);
}
}
DtlsTransport::DtlsTransport(shared_ptr<IceTransport> lower, shared_ptr<Certificate> certificate,
verifier_callback verifierCallback, state_callback stateChangeCallback)
: Transport(lower), mCertificate(certificate), mState(State::Disconnected),
mVerifierCallback(std::move(verifierCallback)),
mStateChangeCallback(std::move(stateChangeCallback)) {
GlobalInit();
if (!(mCtx = SSL_CTX_new(DTLS_method())))
throw std::runtime_error("Unable to create SSL context");
check_openssl(SSL_CTX_set_cipher_list(mCtx, "ALL:!LOW:!EXP:!RC4:!MD5:@STRENGTH"),
"Unable to set SSL priorities");
// RFC 8261: SCTP performs segmentation and reassembly based on the path MTU.
// Therefore, the DTLS layer MUST NOT use any compression algorithm.
// See https://tools.ietf.org/html/rfc8261#section-5
SSL_CTX_set_options(mCtx, SSL_OP_NO_SSLv3 | SSL_OP_NO_COMPRESSION);
SSL_CTX_set_min_proto_version(mCtx, DTLS1_VERSION);
SSL_CTX_set_read_ahead(mCtx, 1);
SSL_CTX_set_quiet_shutdown(mCtx, 1);
SSL_CTX_set_info_callback(mCtx, InfoCallback);
SSL_CTX_set_verify(mCtx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
CertificateCallback);
SSL_CTX_set_verify_depth(mCtx, 1);
auto [x509, pkey] = mCertificate->credentials();
SSL_CTX_use_certificate(mCtx, x509);
SSL_CTX_use_PrivateKey(mCtx, pkey);
check_openssl(SSL_CTX_check_private_key(mCtx), "SSL local private key check failed");
if (!(mSsl = SSL_new(mCtx)))
throw std::runtime_error("Unable to create SSL instance");
SSL_set_ex_data(mSsl, TransportExIndex, this);
SSL_set_mtu(mSsl, 1280 - 40 - 8); // min MTU over UDP/IPv6
if (lower->role() == Description::Role::Active)
SSL_set_connect_state(mSsl);
else
SSL_set_accept_state(mSsl);
if (!(mInBio = BIO_new(BIO_s_mem())) || !(mOutBio = BIO_new(BIO_s_mem())))
throw std::runtime_error("Unable to create BIO");
BIO_set_mem_eof_return(mInBio, BIO_EOF);
BIO_set_mem_eof_return(mOutBio, BIO_EOF);
SSL_set_bio(mSsl, mInBio, mOutBio);
auto ecdh = unique_ptr<EC_KEY, decltype(&EC_KEY_free)>(
EC_KEY_new_by_curve_name(NID_X9_62_prime256v1), EC_KEY_free);
SSL_set_options(mSsl, SSL_OP_SINGLE_ECDH_USE);
SSL_set_tmp_ecdh(mSsl, ecdh.get());
mRecvThread = std::thread(&DtlsTransport::runRecvLoop, this);
}
DtlsTransport::~DtlsTransport() {
SSL_shutdown(mSsl);
SSL_free(mSsl);
SSL_CTX_free(mCtx);
}
void DtlsTransport::stop() {
Transport::stop();
mIncomingQueue.stop();
mRecvThread.join();
}
DtlsTransport::State DtlsTransport::state() const { return mState; }
bool DtlsTransport::send(message_ptr message) {
const size_t bufferSize = 4096;
byte buffer[bufferSize];
if (!message || mState != State::Connected)
return false;
int ret = SSL_write(mSsl, message->data(), message->size());
if (!check_openssl_ret(mSsl, ret)) {
return false;
}
while (BIO_ctrl_pending(mOutBio) > 0) {
int ret = BIO_read(mOutBio, buffer, bufferSize);
if (check_openssl_ret(mSsl, ret) && ret > 0)
outgoing(make_message(buffer, buffer + ret));
}
return true;
}
void DtlsTransport::incoming(message_ptr message) { mIncomingQueue.push(message); }
void DtlsTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(state);
}
void DtlsTransport::runRecvLoop() {
const size_t bufferSize = 4096;
byte buffer[bufferSize];
try {
changeState(State::Connecting);
SSL_do_handshake(mSsl);
while (BIO_ctrl_pending(mOutBio) > 0) {
int ret = BIO_read(mOutBio, buffer, bufferSize);
if (check_openssl_ret(mSsl, ret) && ret > 0)
outgoing(make_message(buffer, buffer + ret));
}
while (auto next = mIncomingQueue.pop()) {
auto message = *next;
BIO_write(mInBio, message->data(), message->size());
int ret = SSL_read(mSsl, buffer, bufferSize);
if (!check_openssl_ret(mSsl, ret))
break;
auto decrypted = ret > 0 ? make_message(buffer, buffer + ret) : nullptr;
if (mState == State::Connecting) {
if (unsigned long err = ERR_get_error())
throw std::runtime_error("handshake failed: " + openssl_error_string(err));
while (BIO_ctrl_pending(mOutBio) > 0) {
ret = BIO_read(mOutBio, buffer, bufferSize);
if (check_openssl_ret(mSsl, ret) && ret > 0)
outgoing(make_message(buffer, buffer + ret));
}
if (SSL_is_init_finished(mSsl))
changeState(State::Connected);
}
if (decrypted)
recv(decrypted);
}
} catch (const std::exception &e) {
std::cerr << "DTLS recv: " << e.what() << std::endl;
}
if (mState == State::Connected) {
changeState(State::Disconnected);
recv(nullptr);
} else {
changeState(State::Failed);
}
}
int DtlsTransport::CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx) {
SSL *ssl =
static_cast<SSL *>(X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx()));
DtlsTransport *t =
static_cast<DtlsTransport *>(SSL_get_ex_data(ssl, DtlsTransport::TransportExIndex));
X509 *crt = X509_STORE_CTX_get_current_cert(ctx);
std::string fingerprint = make_fingerprint(crt);
return t->mVerifierCallback(fingerprint) ? 1 : 0;
}
void DtlsTransport::InfoCallback(const SSL *ssl, int where, int ret) {
DtlsTransport *t =
static_cast<DtlsTransport *>(SSL_get_ex_data(ssl, DtlsTransport::TransportExIndex));
if (where & SSL_CB_ALERT) {
if (ret != 256) // Close Notify
std::cerr << "DTLS alert: " << SSL_alert_desc_string_long(ret) << std::endl;
t->mIncomingQueue.stop(); // Close the connection
}
}
} // namespace rtc
#endif

View File

@ -28,9 +28,14 @@
#include <atomic> #include <atomic>
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <mutex>
#include <thread> #include <thread>
#if USE_GNUTLS
#include <gnutls/gnutls.h> #include <gnutls/gnutls.h>
#else
#include <openssl/ssl.h>
#endif
namespace rtc { namespace rtc {
@ -49,7 +54,8 @@ public:
State state() const; State state() const;
bool send(message_ptr message); void stop() override;
bool send(message_ptr message); // false if dropped
private: private:
void incoming(message_ptr message); void incoming(message_ptr message);
@ -58,7 +64,6 @@ private:
const std::shared_ptr<Certificate> mCertificate; const std::shared_ptr<Certificate> mCertificate;
gnutls_session_t mSession;
Queue<message_ptr> mIncomingQueue; Queue<message_ptr> mIncomingQueue;
std::atomic<State> mState; std::atomic<State> mState;
std::thread mRecvThread; std::thread mRecvThread;
@ -66,10 +71,25 @@ private:
verifier_callback mVerifierCallback; verifier_callback mVerifierCallback;
state_callback mStateChangeCallback; state_callback mStateChangeCallback;
#if USE_GNUTLS
gnutls_session_t mSession;
static int CertificateCallback(gnutls_session_t session); static int CertificateCallback(gnutls_session_t session);
static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len); static ssize_t WriteCallback(gnutls_transport_ptr_t ptr, const void *data, size_t len);
static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen); static ssize_t ReadCallback(gnutls_transport_ptr_t ptr, void *data, size_t maxlen);
static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms); static int TimeoutCallback(gnutls_transport_ptr_t ptr, unsigned int ms);
#else
SSL_CTX *mCtx;
SSL *mSsl;
BIO *mInBio, *mOutBio;
static int TransportExIndex;
static std::mutex GlobalMutex;
static void GlobalInit();
static int CertificateCallback(int preverify_ok, X509_STORE_CTX *ctx);
static void InfoCallback(const SSL *ssl, int where, int ret);
#endif
}; };
} // namespace rtc } // namespace rtc

View File

@ -59,10 +59,15 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
throw std::runtime_error("Failed to create the nice agent"); throw std::runtime_error("Failed to create the nice agent");
mMainLoopThread = std::thread(g_main_loop_run, mMainLoop.get()); mMainLoopThread = std::thread(g_main_loop_run, mMainLoop.get());
g_object_set(G_OBJECT(mNiceAgent.get()), "upnp", FALSE, nullptr);
g_object_set(G_OBJECT(mNiceAgent.get()), "controlling-mode", FALSE, nullptr); g_object_set(G_OBJECT(mNiceAgent.get()), "controlling-mode", TRUE, nullptr);
g_object_set(G_OBJECT(mNiceAgent.get()), "ice-udp", TRUE, nullptr); g_object_set(G_OBJECT(mNiceAgent.get()), "ice-udp", TRUE, nullptr);
g_object_set(G_OBJECT(mNiceAgent.get()), "ice-tcp", FALSE, nullptr); g_object_set(G_OBJECT(mNiceAgent.get()), "ice-tcp", FALSE, nullptr);
g_object_set(G_OBJECT(mNiceAgent.get()), "stun-initial-timeout", 200, nullptr);
g_object_set(G_OBJECT(mNiceAgent.get()), "stun-max-retransmissions", 3, nullptr);
g_object_set(G_OBJECT(mNiceAgent.get()), "stun-pacing-timer", 20, nullptr);
g_object_set(G_OBJECT(mNiceAgent.get()), "upnp", FALSE, nullptr);
g_object_set(G_OBJECT(mNiceAgent.get()), "upnp-timeout", 200, nullptr);
std::vector<IceServer> servers = config.iceServers; std::vector<IceServer> servers = config.iceServers;
unsigned seed = std::chrono::system_clock::now().time_since_epoch().count(); unsigned seed = std::chrono::system_clock::now().time_since_epoch().count();
@ -89,6 +94,7 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
char nodebuffer[MAX_NUMERICNODE_LEN]; char nodebuffer[MAX_NUMERICNODE_LEN];
char servbuffer[MAX_NUMERICSERV_LEN]; char servbuffer[MAX_NUMERICSERV_LEN];
if (getnameinfo(p->ai_addr, p->ai_addrlen, nodebuffer, MAX_NUMERICNODE_LEN, if (getnameinfo(p->ai_addr, p->ai_addrlen, nodebuffer, MAX_NUMERICNODE_LEN,
servbuffer, MAX_NUMERICNODE_LEN, servbuffer, MAX_NUMERICNODE_LEN,
NI_NUMERICHOST | NI_NUMERICSERV) == 0) { NI_NUMERICHOST | NI_NUMERICSERV) == 0) {
g_object_set(G_OBJECT(mNiceAgent.get()), "stun-server", nodebuffer, nullptr); g_object_set(G_OBJECT(mNiceAgent.get()), "stun-server", nodebuffer, nullptr);
@ -124,10 +130,11 @@ IceTransport::IceTransport(const Configuration &config, Description::Role role,
RecvCallback, this); RecvCallback, this);
} }
IceTransport::~IceTransport() { IceTransport::~IceTransport() {}
void IceTransport::stop() {
g_main_loop_quit(mMainLoop.get()); g_main_loop_quit(mMainLoop.get());
if (mMainLoopThread.joinable()) mMainLoopThread.join();
mMainLoopThread.join();
} }
Description::Role IceTransport::role() const { return mRole; } Description::Role IceTransport::role() const { return mRole; }
@ -135,6 +142,11 @@ Description::Role IceTransport::role() const { return mRole; }
IceTransport::State IceTransport::state() const { return mState; } IceTransport::State IceTransport::state() const { return mState; }
Description IceTransport::getLocalDescription(Description::Type type) const { Description IceTransport::getLocalDescription(Description::Type type) const {
// RFC 5245: The agent that generated the offer which started the ICE processing MUST take the
// controlling role, and the other MUST take the controlled role.
g_object_set(G_OBJECT(mNiceAgent.get()), "controlling-mode",
type == Description::Type::Offer ? TRUE : FALSE, nullptr);
std::unique_ptr<gchar[], void (*)(void *)> sdp(nice_agent_generate_local_sdp(mNiceAgent.get()), std::unique_ptr<gchar[], void (*)(void *)> sdp(nice_agent_generate_local_sdp(mNiceAgent.get()),
g_free); g_free);
return Description(string(sdp.get()), type, mRole); return Description(string(sdp.get()), type, mRole);
@ -145,20 +157,16 @@ void IceTransport::setRemoteDescription(const Description &description) {
: Description::Role::Active; : Description::Role::Active;
mMid = description.mid(); mMid = description.mid();
if (nice_agent_parse_remote_sdp(mNiceAgent.get(), string(description).c_str())) if (nice_agent_parse_remote_sdp(mNiceAgent.get(), string(description).c_str()) < 0)
throw std::runtime_error("Failed to parse remote SDP"); throw std::runtime_error("Failed to parse remote SDP");
} }
void IceTransport::gatherLocalCandidates() {
// Change state now as candidates calls can be synchronous
changeGatheringState(GatheringState::InProgress);
if (!nice_agent_gather_candidates(mNiceAgent.get(), mStreamId)) {
throw std::runtime_error("Failed to gather local ICE candidates");
}
}
bool IceTransport::addRemoteCandidate(const Candidate &candidate) { bool IceTransport::addRemoteCandidate(const Candidate &candidate) {
// Don't try to pass unresolved candidates to libnice for more safety
if (!candidate.isResolved())
return false;
// Warning: the candidate string must start with "a=candidate:" and it must not end with a // Warning: the candidate string must start with "a=candidate:" and it must not end with a
// newline, else libnice will reject it. // newline, else libnice will reject it.
string sdp(candidate); string sdp(candidate);
@ -174,6 +182,32 @@ bool IceTransport::addRemoteCandidate(const Candidate &candidate) {
return ret > 0; return ret > 0;
} }
void IceTransport::gatherLocalCandidates() {
// Change state now as candidates calls can be synchronous
changeGatheringState(GatheringState::InProgress);
if (!nice_agent_gather_candidates(mNiceAgent.get(), mStreamId)) {
throw std::runtime_error("Failed to gather local ICE candidates");
}
}
std::optional<string> IceTransport::getLocalAddress() const {
NiceCandidate *local = nullptr;
NiceCandidate *remote = nullptr;
if (nice_agent_get_selected_pair(mNiceAgent.get(), mStreamId, 1, &local, &remote)) {
return std::make_optional(AddressToString(local->addr));
}
return nullopt;
}
std::optional<string> IceTransport::getRemoteAddress() const {
NiceCandidate *local = nullptr;
NiceCandidate *remote = nullptr;
if (nice_agent_get_selected_pair(mNiceAgent.get(), mStreamId, 1, &local, &remote)) {
return std::make_optional(AddressToString(remote->addr));
}
return nullopt;
}
bool IceTransport::send(message_ptr message) { bool IceTransport::send(message_ptr message) {
if (!message || !mStreamId) if (!message || !mStreamId)
return false; return false;
@ -194,8 +228,8 @@ void IceTransport::outgoing(message_ptr message) {
} }
void IceTransport::changeState(State state) { void IceTransport::changeState(State state) {
mState = state; if (mState.exchange(state) != state)
mStateChangeCallback(mState); mStateChangeCallback(mState);
} }
void IceTransport::changeGatheringState(GatheringState state) { void IceTransport::changeGatheringState(GatheringState state) {
@ -214,6 +248,15 @@ void IceTransport::processStateChange(uint32_t state) {
changeState(static_cast<State>(state)); changeState(static_cast<State>(state));
} }
string IceTransport::AddressToString(const NiceAddress &addr) {
char buffer[NICE_ADDRESS_STRING_LEN];
nice_address_to_string(&addr, buffer);
unsigned int port = nice_address_get_port(&addr);
std::ostringstream ss;
ss << buffer << ":" << port;
return ss.str();
}
void IceTransport::CandidateCallback(NiceAgent *agent, NiceCandidate *candidate, void IceTransport::CandidateCallback(NiceAgent *agent, NiceCandidate *candidate,
gpointer userData) { gpointer userData) {
auto iceTransport = static_cast<rtc::IceTransport *>(userData); auto iceTransport = static_cast<rtc::IceTransport *>(userData);

View File

@ -41,7 +41,7 @@ public:
Disconnected = NICE_COMPONENT_STATE_DISCONNECTED, Disconnected = NICE_COMPONENT_STATE_DISCONNECTED,
Connecting = NICE_COMPONENT_STATE_CONNECTING, Connecting = NICE_COMPONENT_STATE_CONNECTING,
Connected = NICE_COMPONENT_STATE_CONNECTED, Connected = NICE_COMPONENT_STATE_CONNECTED,
Ready = NICE_COMPONENT_STATE_READY, Completed = NICE_COMPONENT_STATE_READY,
Failed = NICE_COMPONENT_STATE_FAILED Failed = NICE_COMPONENT_STATE_FAILED
}; };
@ -58,13 +58,17 @@ public:
Description::Role role() const; Description::Role role() const;
State state() const; State state() const;
GatheringState gyyatheringState() const; GatheringState gatheringState() const;
Description getLocalDescription(Description::Type type) const; Description getLocalDescription(Description::Type type) const;
void setRemoteDescription(const Description &description); void setRemoteDescription(const Description &description);
void gatherLocalCandidates();
bool addRemoteCandidate(const Candidate &candidate); bool addRemoteCandidate(const Candidate &candidate);
void gatherLocalCandidates();
bool send(message_ptr message); std::optional<string> getLocalAddress() const;
std::optional<string> getRemoteAddress() const;
void stop() override;
bool send(message_ptr message) override; // false if dropped
private: private:
void incoming(message_ptr message); void incoming(message_ptr message);
@ -92,6 +96,8 @@ private:
state_callback mStateChangeCallback; state_callback mStateChangeCallback;
gathering_state_callback mGatheringStateChangeCallback; gathering_state_callback mGatheringStateChangeCallback;
static string AddressToString(const NiceAddress &addr);
static void CandidateCallback(NiceAgent *agent, NiceCandidate *candidate, gpointer userData); static void CandidateCallback(NiceAgent *agent, NiceCandidate *candidate, gpointer userData);
static void GatheringDoneCallback(NiceAgent *agent, guint streamId, gpointer userData); static void GatheringDoneCallback(NiceAgent *agent, guint streamId, gpointer userData);
static void StateChangeCallback(NiceAgent *agent, guint streamId, guint componentId, static void StateChangeCallback(NiceAgent *agent, guint streamId, guint componentId,

View File

@ -28,15 +28,26 @@ namespace rtc {
using namespace std::placeholders; using namespace std::placeholders;
using std::function;
using std::shared_ptr; using std::shared_ptr;
using std::weak_ptr;
PeerConnection::PeerConnection() : PeerConnection(Configuration()) {} PeerConnection::PeerConnection() : PeerConnection(Configuration()) {}
PeerConnection::PeerConnection(const Configuration &config) PeerConnection::PeerConnection(const Configuration &config)
: mConfig(config), mCertificate(make_certificate("libdatachannel")), mState(State::New) {} : mConfig(config), mCertificate(make_certificate("libdatachannel")), mState(State::New) {}
PeerConnection::~PeerConnection() {} PeerConnection::~PeerConnection() {
if (mIceTransport)
mIceTransport->stop();
if (mDtlsTransport)
mDtlsTransport->stop();
if (mSctpTransport)
mSctpTransport->stop();
mSctpTransport.reset();
mDtlsTransport.reset();
mIceTransport.reset();
}
const Configuration *PeerConnection::config() const { return &mConfig; } const Configuration *PeerConnection::config() const { return &mConfig; }
@ -49,33 +60,77 @@ std::optional<Description> PeerConnection::localDescription() const { return mLo
std::optional<Description> PeerConnection::remoteDescription() const { return mRemoteDescription; } std::optional<Description> PeerConnection::remoteDescription() const { return mRemoteDescription; }
void PeerConnection::setRemoteDescription(Description description) { void PeerConnection::setRemoteDescription(Description description) {
if (!mIceTransport) { auto remoteCandidates = description.extractCandidates();
mRemoteDescription.emplace(std::move(description));
if (!mIceTransport)
initIceTransport(Description::Role::ActPass); initIceTransport(Description::Role::ActPass);
mIceTransport->setRemoteDescription(description);
mIceTransport->setRemoteDescription(*mRemoteDescription);
if (mRemoteDescription->type() == Description::Type::Offer) {
// This is an offer and we are the answerer.
processLocalDescription(mIceTransport->getLocalDescription(Description::Type::Answer)); processLocalDescription(mIceTransport->getLocalDescription(Description::Type::Answer));
mIceTransport->gatherLocalCandidates(); mIceTransport->gatherLocalCandidates();
} else { } else {
mIceTransport->setRemoteDescription(description); // This is an answer and we are the offerer.
if (!mSctpTransport && mIceTransport->role() == Description::Role::Active) {
// Since we assumed passive role during DataChannel creation, we need to shift the
// stream numbers by one to shift them from odd to even.
decltype(mDataChannels) newDataChannels;
iterateDataChannels([&](shared_ptr<DataChannel> channel) {
if (channel->stream() % 2 == 1)
channel->mStream -= 1;
newDataChannels.emplace(channel->stream(), channel);
});
std::swap(mDataChannels, newDataChannels);
}
} }
mRemoteDescription.emplace(std::move(description)); for (const auto &candidate : remoteCandidates)
addRemoteCandidate(candidate);
} }
void PeerConnection::addRemoteCandidate(Candidate candidate) { void PeerConnection::addRemoteCandidate(Candidate candidate) {
if (!mRemoteDescription || !mIceTransport) if (!mRemoteDescription || !mIceTransport)
throw std::logic_error("Remote candidate set without remote description"); throw std::logic_error("Remote candidate set without remote description");
if (mIceTransport->addRemoteCandidate(candidate)) mRemoteDescription->addCandidate(candidate);
mRemoteDescription->addCandidate(std::move(candidate));
if (candidate.resolve(Candidate::ResolveMode::Simple)) {
mIceTransport->addRemoteCandidate(candidate);
} else {
// OK, we might need a lookup, do it asynchronously
weak_ptr<IceTransport> weakIceTransport{mIceTransport};
std::thread t([weakIceTransport, candidate]() mutable {
if (candidate.resolve(Candidate::ResolveMode::Lookup))
if (auto iceTransport = weakIceTransport.lock())
iceTransport->addRemoteCandidate(candidate);
});
t.detach();
}
}
std::optional<string> PeerConnection::localAddress() const {
return mIceTransport ? mIceTransport->getLocalAddress() : nullopt;
}
std::optional<string> PeerConnection::remoteAddress() const {
return mIceTransport ? mIceTransport->getRemoteAddress() : nullopt;
} }
shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label, shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
const string &protocol, const string &protocol,
const Reliability &reliability) { const Reliability &reliability) {
// RFC 5763: The answerer MUST use either a setup attribute value of setup:active or
// setup:passive. [...] Thus, setup:active is RECOMMENDED.
// See https://tools.ietf.org/html/rfc5763#section-5
// Therefore, we assume passive role when we are the offerer.
auto role = mIceTransport ? mIceTransport->role() : Description::Role::Passive;
// The active side must use streams with even identifiers, whereas the passive side must use // The active side must use streams with even identifiers, whereas the passive side must use
// streams with odd identifiers. // streams with odd identifiers.
// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09#section-6 // See https://tools.ietf.org/html/draft-ietf-rtcweb-data-protocol-09#section-6
auto role = mIceTransport ? mIceTransport->role() : Description::Role::Active;
unsigned int stream = (role == Description::Role::Active) ? 0 : 1; unsigned int stream = (role == Description::Role::Active) ? 0 : 1;
while (mDataChannels.find(stream) != mDataChannels.end()) { while (mDataChannels.find(stream) != mDataChannels.end()) {
stream += 2; stream += 2;
@ -83,11 +138,15 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
throw std::runtime_error("Too many DataChannels"); throw std::runtime_error("Too many DataChannels");
} }
auto channel = std::make_shared<DataChannel>(stream, label, protocol, reliability); auto channel =
std::make_shared<DataChannel>(shared_from_this(), stream, label, protocol, reliability);
mDataChannels.insert(std::make_pair(stream, channel)); mDataChannels.insert(std::make_pair(stream, channel));
if (!mIceTransport) { if (!mIceTransport) {
initIceTransport(Description::Role::Active); // RFC 5763: The endpoint that is the offerer MUST use the setup attribute value of
// setup:actpass.
// See https://tools.ietf.org/html/rfc5763#section-5
initIceTransport(Description::Role::ActPass);
processLocalDescription(mIceTransport->getLocalDescription(Description::Type::Offer)); processLocalDescription(mIceTransport->getLocalDescription(Description::Type::Offer));
mIceTransport->gatherLocalCandidates(); mIceTransport->gatherLocalCandidates();
} else if (mSctpTransport && mSctpTransport->state() == SctpTransport::State::Connected) { } else if (mSctpTransport && mSctpTransport->state() == SctpTransport::State::Connected) {
@ -97,7 +156,7 @@ shared_ptr<DataChannel> PeerConnection::createDataChannel(const string &label,
} }
void PeerConnection::onDataChannel( void PeerConnection::onDataChannel(
std::function<void(std::shared_ptr<DataChannel> dataChannel)> callback) { std::function<void(shared_ptr<DataChannel> dataChannel)> callback) {
mDataChannelCallback = callback; mDataChannelCallback = callback;
} }
@ -129,7 +188,7 @@ void PeerConnection::initIceTransport(Description::Role role) {
case IceTransport::State::Failed: case IceTransport::State::Failed:
changeState(State::Failed); changeState(State::Failed);
break; break;
case IceTransport::State::Ready: case IceTransport::State::Connected:
initDtlsTransport(); initDtlsTransport();
break; break;
default: default:
@ -176,6 +235,7 @@ void PeerConnection::initSctpTransport() {
uint16_t sctpPort = mRemoteDescription->sctpPort().value_or(DEFAULT_SCTP_PORT); uint16_t sctpPort = mRemoteDescription->sctpPort().value_or(DEFAULT_SCTP_PORT);
mSctpTransport = std::make_shared<SctpTransport>( mSctpTransport = std::make_shared<SctpTransport>(
mDtlsTransport, sctpPort, std::bind(&PeerConnection::forwardMessage, this, _1), mDtlsTransport, sctpPort, std::bind(&PeerConnection::forwardMessage, this, _1),
std::bind(&PeerConnection::forwardBufferedAmount, this, _1, _2),
[this](SctpTransport::State state) { [this](SctpTransport::State state) {
switch (state) { switch (state) {
case SctpTransport::State::Connected: case SctpTransport::State::Connected:
@ -226,8 +286,10 @@ void PeerConnection::forwardMessage(message_ptr message) {
unsigned int remoteParity = (mIceTransport->role() == Description::Role::Active) ? 1 : 0; unsigned int remoteParity = (mIceTransport->role() == Description::Role::Active) ? 1 : 0;
if (message->type == Message::Control && *message->data() == dataChannelOpenMessage && if (message->type == Message::Control && *message->data() == dataChannelOpenMessage &&
message->stream % 2 == remoteParity) { message->stream % 2 == remoteParity) {
channel = std::make_shared<DataChannel>(message->stream, mSctpTransport); channel =
channel->onOpen(std::bind(&PeerConnection::triggerDataChannel, this, channel)); std::make_shared<DataChannel>(shared_from_this(), mSctpTransport, message->stream);
channel->onOpen(std::bind(&PeerConnection::triggerDataChannel, this,
weak_ptr<DataChannel>{channel}));
mDataChannels.insert(std::make_pair(message->stream, channel)); mDataChannels.insert(std::make_pair(message->stream, channel));
} else { } else {
// Invalid, close the DataChannel by resetting the stream // Invalid, close the DataChannel by resetting the stream
@ -239,6 +301,20 @@ void PeerConnection::forwardMessage(message_ptr message) {
channel->incoming(message); channel->incoming(message);
} }
void PeerConnection::forwardBufferedAmount(uint16_t stream, size_t amount) {
shared_ptr<DataChannel> channel;
if (auto it = mDataChannels.find(stream); it != mDataChannels.end()) {
channel = it->second.lock();
if (!channel || channel->isClosed()) {
mDataChannels.erase(it);
channel = nullptr;
}
}
if (channel)
channel->triggerBufferedAmount(amount);
}
void PeerConnection::iterateDataChannels( void PeerConnection::iterateDataChannels(
std::function<void(shared_ptr<DataChannel> channel)> func) { std::function<void(shared_ptr<DataChannel> channel)> func) {
auto it = mDataChannels.begin(); auto it = mDataChannels.begin();
@ -264,12 +340,12 @@ void PeerConnection::closeDataChannels() {
void PeerConnection::processLocalDescription(Description description) { void PeerConnection::processLocalDescription(Description description) {
auto remoteSctpPort = mRemoteDescription ? mRemoteDescription->sctpPort() : nullopt; auto remoteSctpPort = mRemoteDescription ? mRemoteDescription->sctpPort() : nullopt;
description.setFingerprint(mCertificate->fingerprint());
description.setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT));
mLocalDescription.emplace(std::move(description)); mLocalDescription.emplace(std::move(description));
mLocalDescription->setFingerprint(mCertificate->fingerprint());
mLocalDescription->setSctpPort(remoteSctpPort.value_or(DEFAULT_SCTP_PORT));
mLocalDescription->setMaxMessageSize(LOCAL_MAX_MESSAGE_SIZE);
if (mLocalDescriptionCallback) mLocalDescriptionCallback(*mLocalDescription);
mLocalDescriptionCallback(*mLocalDescription);
} }
void PeerConnection::processLocalCandidate(Candidate candidate) { void PeerConnection::processLocalCandidate(Candidate candidate) {
@ -278,24 +354,24 @@ void PeerConnection::processLocalCandidate(Candidate candidate) {
mLocalDescription->addCandidate(candidate); mLocalDescription->addCandidate(candidate);
if (mLocalCandidateCallback) mLocalCandidateCallback(candidate);
mLocalCandidateCallback(candidate);
} }
void PeerConnection::triggerDataChannel(std::shared_ptr<DataChannel> dataChannel) { void PeerConnection::triggerDataChannel(weak_ptr<DataChannel> weakDataChannel) {
if (mDataChannelCallback) auto dataChannel = weakDataChannel.lock();
mDataChannelCallback(dataChannel); if (!dataChannel)
return;
mDataChannelCallback(dataChannel);
} }
void PeerConnection::changeState(State state) { void PeerConnection::changeState(State state) {
mState = state; if (mState.exchange(state) != state)
if (mStateChangeCallback)
mStateChangeCallback(state); mStateChangeCallback(state);
} }
void PeerConnection::changeGatheringState(GatheringState state) { void PeerConnection::changeGatheringState(GatheringState state) {
mGatheringState = state; if (mGatheringState.exchange(state) != state)
if (mGatheringStateChangeCallback)
mGatheringStateChangeCallback(state); mGatheringStateChangeCallback(state);
} }

View File

@ -1,90 +0,0 @@
/**
* Copyright (c) 2019 Paul-Louis Ageneau
*
* This library is free software; you can redistribute it and/or
* modify it under the terms of the GNU Lesser General Public
* License as published by the Free Software Foundation; either
* version 2.1 of the License, or (at your option) any later version.
*
* This library is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
* Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public
* License along with this library; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*/
#ifndef RTC_QUEUE_H
#define RTC_QUEUE_H
#include "include.hpp"
#include <atomic>
#include <condition_variable>
#include <mutex>
#include <optional>
#include <queue>
namespace rtc {
template <typename T> class Queue {
public:
Queue();
~Queue();
void stop();
void push(const T &element);
std::optional<T> pop();
bool empty() const;
private:
std::queue<T> mQueue;
std::condition_variable mCondition;
std::atomic<bool> mStopping;
mutable std::mutex mMutex;
};
template <typename T> Queue<T>::Queue() : mStopping(false) {}
template <typename T> Queue<T>::~Queue() { stop(); }
template <typename T> void Queue<T>::stop() {
std::lock_guard<std::mutex> lock(mMutex);
mStopping = true;
mCondition.notify_all();
}
template <typename T> void Queue<T>::push(const T &element) {
std::lock_guard<std::mutex> lock(mMutex);
if (mStopping)
return;
mQueue.push(element);
mCondition.notify_one();
}
template <typename T> std::optional<T> Queue<T>::pop() {
std::unique_lock<std::mutex> lock(mMutex);
while (mQueue.empty()) {
if (mStopping)
return nullopt;
mCondition.wait(lock);
}
std::optional<T> element = mQueue.front();
mQueue.pop();
return element;
}
template <typename T> bool Queue<T>::empty() const {
std::lock_guard<std::mutex> lock(mMutex);
return mQueue.empty();
}
} // namespace rtc
#endif

View File

@ -42,25 +42,31 @@ void SctpTransport::GlobalInit() {
void SctpTransport::GlobalCleanup() { void SctpTransport::GlobalCleanup() {
std::unique_lock<std::mutex> lock(GlobalMutex); std::unique_lock<std::mutex> lock(GlobalMutex);
if (InstancesCount-- == 0) { if (--InstancesCount == 0) {
usrsctp_finish(); usrsctp_finish();
} }
} }
SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recv, SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port,
message_callback recvCallback, amount_callback bufferedAmountCallback,
state_callback stateChangeCallback) state_callback stateChangeCallback)
: Transport(lower), mPort(port), mState(State::Disconnected), : Transport(lower), mPort(port), mSendQueue(0, message_size_func),
mStateChangeCallback(std::move(stateChangeCallback)) { mBufferedAmountCallback(std::move(bufferedAmountCallback)),
mStateChangeCallback(std::move(stateChangeCallback)), mState(State::Disconnected) {
onRecv(recv); onRecv(recvCallback);
GlobalInit(); GlobalInit();
usrsctp_register_address(this);
mSock = usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, &SctpTransport::ReadCallback,
nullptr, 0, this);
if (!mSock)
throw std::runtime_error("Could not create usrsctp socket, errno=" + std::to_string(errno));
usrsctp_register_address(this);
mSock = usrsctp_socket(AF_CONN, SOCK_STREAM, IPPROTO_SCTP, &SctpTransport::RecvCallback,
&SctpTransport::SendCallback, 0, this);
if (!mSock)
throw std::runtime_error("Could not create SCTP socket, errno=" + std::to_string(errno));
if (usrsctp_set_non_blocking(mSock, 1))
throw std::runtime_error("Unable to set non-blocking mode, errno=" + std::to_string(errno));
// SCTP must stop sending after the lower layer is shut down, so disable linger
struct linger sol = {}; struct linger sol = {};
sol.l_onoff = 1; sol.l_onoff = 1;
sol.l_linger = 0; sol.l_linger = 0;
@ -68,14 +74,6 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, me
throw std::runtime_error("Could not set socket option SO_LINGER, errno=" + throw std::runtime_error("Could not set socket option SO_LINGER, errno=" +
std::to_string(errno)); std::to_string(errno));
struct sctp_paddrparams spp = {};
spp.spp_flags = SPP_PMTUD_DISABLE;
spp.spp_pathmtu = 1200; // Max safe value recommended by RFC 8261
// See https://tools.ietf.org/html/rfc8261#section-5
if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &spp, sizeof(spp)))
throw std::runtime_error("Could not set socket option SCTP_PEER_ADDR_PARAMS, errno=" +
std::to_string(errno));
struct sctp_assoc_value av = {}; struct sctp_assoc_value av = {};
av.assoc_id = SCTP_ALL_ASSOC; av.assoc_id = SCTP_ALL_ASSOC;
av.assoc_value = 1; av.assoc_value = 1;
@ -83,17 +81,42 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, me
throw std::runtime_error("Could not set socket option SCTP_ENABLE_STREAM_RESET, errno=" + throw std::runtime_error("Could not set socket option SCTP_ENABLE_STREAM_RESET, errno=" +
std::to_string(errno)); std::to_string(errno));
uint32_t nodelay = 1; struct sctp_event se = {};
se.se_assoc_id = SCTP_ALL_ASSOC;
se.se_on = 1;
se.se_type = SCTP_ASSOC_CHANGE;
if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se)))
throw std::runtime_error("Could not subscribe to event SCTP_ASSOC_CHANGE, errno=" +
std::to_string(errno));
se.se_type = SCTP_SENDER_DRY_EVENT;
if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se)))
throw std::runtime_error("Could not subscribe to event SCTP_SENDER_DRY_EVENT, errno=" +
std::to_string(errno));
se.se_type = SCTP_STREAM_RESET_EVENT;
if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se)))
throw std::runtime_error("Could not subscribe to event SCTP_STREAM_RESET_EVENT, errno=" +
std::to_string(errno));
// The sender SHOULD disable the Nagle algorithm (see RFC1122) to minimize the latency.
// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.6
int nodelay = 1;
if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_NODELAY, &nodelay, sizeof(nodelay))) if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_NODELAY, &nodelay, sizeof(nodelay)))
throw std::runtime_error("Could not set socket option SCTP_NODELAY, errno=" + throw std::runtime_error("Could not set socket option SCTP_NODELAY, errno=" +
std::to_string(errno)); std::to_string(errno));
struct sctp_event se = {}; struct sctp_paddrparams spp = {};
se.se_assoc_id = SCTP_ALL_ASSOC; #ifdef __linux__
se.se_on = 1; // Linux UDP does path MTU discovery by default (setting DF and returning EMSGSIZE).
se.se_type = SCTP_STREAM_RESET_EVENT; // It should be safe to enable discovery for SCTP.
if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_EVENT, &se, sizeof(se))) spp.spp_flags = SPP_PMTUD_ENABLE;
throw std::runtime_error("Could not set socket option SCTP_EVENT, errno=" + #else
// Otherwise, fall back to a safe MTU value.
spp.spp_flags = SPP_PMTUD_DISABLE;
spp.spp_pathmtu = 1200; // Max safe value recommended by RFC 8261
// See https://tools.ietf.org/html/rfc8261#section-5
#endif
if (usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_PEER_ADDR_PARAMS, &spp, sizeof(spp)))
throw std::runtime_error("Could not set socket option SCTP_PEER_ADDR_PARAMS, errno=" +
std::to_string(errno)); std::to_string(errno));
// The IETF draft recommends the number of streams negotiated during SCTP association to be // The IETF draft recommends the number of streams negotiated during SCTP association to be
@ -105,6 +128,48 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, me
throw std::runtime_error("Could not set socket option SCTP_INITMSG, errno=" + throw std::runtime_error("Could not set socket option SCTP_INITMSG, errno=" +
std::to_string(errno)); std::to_string(errno));
// The default send and receive window size of usrsctp is 256KiB, which is too small for
// realistic RTTs, therefore we increase it to 1MiB for better performance.
// See https://bugzilla.mozilla.org/show_bug.cgi?id=1051685
int bufSize = 1024 * 1024;
if (usrsctp_setsockopt(mSock, SOL_SOCKET, SO_RCVBUF, &bufSize, sizeof(bufSize)))
throw std::runtime_error("Could not set SCTP recv buffer size, errno=" +
std::to_string(errno));
if (usrsctp_setsockopt(mSock, SOL_SOCKET, SO_SNDBUF, &bufSize, sizeof(bufSize)))
throw std::runtime_error("Could not set SCTP send buffer size, errno=" +
std::to_string(errno));
connect();
}
SctpTransport::~SctpTransport() {
if (mSock) {
usrsctp_shutdown(mSock, SHUT_RDWR);
usrsctp_close(mSock);
}
usrsctp_deregister_address(this);
GlobalCleanup();
}
SctpTransport::State SctpTransport::state() const { return mState; }
void SctpTransport::stop() {
Transport::stop();
mSendQueue.stop();
// Unblock incoming
if (!mConnectDataSent) {
std::unique_lock<std::mutex> lock(mConnectMutex);
mConnectDataSent = true;
mConnectCondition.notify_all();
}
}
void SctpTransport::connect() {
changeState(State::Connecting);
struct sockaddr_conn sconn = {}; struct sockaddr_conn sconn = {};
sconn.sconn_family = AF_CONN; sconn.sconn_family = AF_CONN;
sconn.sconn_port = htons(mPort); sconn.sconn_port = htons(mPort);
@ -116,34 +181,83 @@ SctpTransport::SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, me
if (usrsctp_bind(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn))) if (usrsctp_bind(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn)))
throw std::runtime_error("Could not bind usrsctp socket, errno=" + std::to_string(errno)); throw std::runtime_error("Could not bind usrsctp socket, errno=" + std::to_string(errno));
mConnectThread = std::thread(&SctpTransport::runConnect, this); // According to the IETF draft, both endpoints must initiate the SCTP association, in a
// simultaneous-open manner, irrelevent to the SDP setup role.
// See https://tools.ietf.org/html/draft-ietf-mmusic-sctp-sdp-26#section-9.3
int ret = usrsctp_connect(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn));
if (ret && errno != EINPROGRESS)
throw std::runtime_error("Connection attempt failed, errno=" + std::to_string(errno));
} }
SctpTransport::~SctpTransport() {
mStopping = true;
mConnectCondition.notify_all();
if (mConnectThread.joinable())
mConnectThread.join();
if (mSock) {
usrsctp_shutdown(mSock, SHUT_RDWR);
usrsctp_close(mSock);
}
usrsctp_deregister_address(this);
GlobalCleanup();
}
SctpTransport::State SctpTransport::state() const { return mState; }
bool SctpTransport::send(message_ptr message) { bool SctpTransport::send(message_ptr message) {
std::lock_guard<std::mutex> lock(mSendMutex);
if (!message) if (!message)
return false; return mSendQueue.empty();
// If nothing is pending, try to send directly
if (mSendQueue.empty() && trySendMessage(message))
return true;
mSendQueue.push(message);
updateBufferedAmount(message->stream, message_size_func(message));
return false;
}
void SctpTransport::reset(unsigned int stream) {
using srs_t = struct sctp_reset_streams;
const size_t len = sizeof(srs_t) + sizeof(uint16_t);
byte buffer[len] = {};
srs_t &srs = *reinterpret_cast<srs_t *>(buffer);
srs.srs_flags = SCTP_STREAM_RESET_OUTGOING;
srs.srs_number_streams = 1;
srs.srs_stream_list[0] = uint16_t(stream);
usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len);
}
void SctpTransport::incoming(message_ptr message) {
if (!message) {
changeState(State::Disconnected);
recv(nullptr);
return;
}
// There could be a race condition here where we receive the remote INIT before the local one is
// sent, which would result in the connection being aborted. Therefore, we need to wait for data
// to be sent on our side (i.e. the local INIT) before proceeding.
if (!mConnectDataSent) {
std::unique_lock<std::mutex> lock(mConnectMutex);
mConnectCondition.wait(lock, [this]() -> bool { return mConnectDataSent; });
}
usrsctp_conninput(this, message->data(), message->size(), 0);
}
void SctpTransport::changeState(State state) {
if (mState.exchange(state) != state)
mStateChangeCallback(state);
}
bool SctpTransport::trySendQueue() {
// Requires mSendMutex to be locked
while (auto next = mSendQueue.peek()) {
auto message = *next;
if (!trySendMessage(message))
return false;
mSendQueue.pop();
updateBufferedAmount(message->stream, -message_size_func(message));
}
return true;
}
bool SctpTransport::trySendMessage(message_ptr message) {
// Requires mSendMutex to be locked
//
// TODO: Implement SCTP ndata specification draft when supported everywhere
// See https://tools.ietf.org/html/draft-ietf-tsvwg-sctp-ndata-08
const Reliability reliability = message->reliability ? *message->reliability : Reliability(); const Reliability reliability = message->reliability ? *message->reliability : Reliability();
struct sctp_sendv_spa spa = {};
uint32_t ppid; uint32_t ppid;
switch (message->type) { switch (message->type) {
case Message::String: case Message::String:
@ -157,11 +271,13 @@ bool SctpTransport::send(message_ptr message) {
break; break;
} }
struct sctp_sendv_spa spa = {};
// set sndinfo // set sndinfo
spa.sendv_flags |= SCTP_SEND_SNDINFO_VALID; spa.sendv_flags |= SCTP_SEND_SNDINFO_VALID;
spa.sendv_sndinfo.snd_sid = uint16_t(message->stream); spa.sendv_sndinfo.snd_sid = uint16_t(message->stream);
spa.sendv_sndinfo.snd_ppid = htonl(ppid); spa.sendv_sndinfo.snd_ppid = htonl(ppid);
spa.sendv_sndinfo.snd_flags |= SCTP_EOR; spa.sendv_sndinfo.snd_flags |= SCTP_EOR; // implicit here
// set prinfo // set prinfo
spa.sendv_flags |= SCTP_SEND_PRINFO_VALID; spa.sendv_flags |= SCTP_SEND_PRINFO_VALID;
@ -185,131 +301,149 @@ bool SctpTransport::send(message_ptr message) {
break; break;
} }
ssize_t ret;
if (!message->empty()) { if (!message->empty()) {
return usrsctp_sendv(mSock, message->data(), message->size(), nullptr, 0, &spa, sizeof(spa), ret = usrsctp_sendv(mSock, message->data(), message->size(), nullptr, 0, &spa, sizeof(spa),
SCTP_SENDV_SPA, 0) > 0; SCTP_SENDV_SPA, 0);
} else { } else {
const char zero = 0; const char zero = 0;
return usrsctp_sendv(mSock, &zero, 1, nullptr, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0) > 0; ret = usrsctp_sendv(mSock, &zero, 1, nullptr, 0, &spa, sizeof(spa), SCTP_SENDV_SPA, 0);
}
}
void SctpTransport::reset(unsigned int stream) {
using srs_t = struct sctp_reset_streams;
const size_t len = sizeof(srs_t) + sizeof(uint16_t);
byte buffer[len] = {};
srs_t &srs = *reinterpret_cast<srs_t *>(buffer);
srs.srs_flags = SCTP_STREAM_RESET_OUTGOING;
srs.srs_number_streams = 1;
srs.srs_stream_list[0] = uint16_t(stream);
usrsctp_setsockopt(mSock, IPPROTO_SCTP, SCTP_RESET_STREAMS, &srs, len);
}
void SctpTransport::incoming(message_ptr message) {
if (!message) {
changeState(State::Disconnected);
recv(nullptr);
return;
} }
// There could be a race condition here where we receive the remote INIT before the thread in if (ret >= 0)
// usrsctp_connect sends the local one, which would result in the connection being aborted. return true;
// Therefore, we need to wait for data to be sent on our side (i.e. the local INIT) before else if (errno == EWOULDBLOCK && errno == EAGAIN)
// proceeding. return false;
if (!mConnectDataSent) { else
std::unique_lock<std::mutex> lock(mConnectMutex); throw std::runtime_error("Sending failed, errno=" + std::to_string(errno));
mConnectCondition.wait(lock, [this] { return mConnectDataSent || mStopping; });
}
if (!mStopping)
usrsctp_conninput(this, message->data(), message->size(), 0);
} }
void SctpTransport::changeState(State state) { void SctpTransport::updateBufferedAmount(uint16_t streamId, long delta) {
mState = state; // Requires mSendMutex to be locked
mStateChangeCallback(state); auto it = mBufferedAmount.insert(std::make_pair(streamId, 0)).first;
size_t amount = it->second;
amount = size_t(std::max(long(amount) + delta, long(0)));
if (amount == 0)
mBufferedAmount.erase(it);
mBufferedAmountCallback(streamId, amount);
} }
void SctpTransport::runConnect() { int SctpTransport::handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data,
size_t len, struct sctp_rcvinfo info, int flags) {
try { try {
changeState(State::Connecting); if (!data) {
recv(nullptr);
struct sockaddr_conn sconn = {}; return 0;
sconn.sconn_family = AF_CONN;
sconn.sconn_port = htons(mPort);
sconn.sconn_addr = this;
#ifdef HAVE_SCONN_LEN
sconn.sconn_len = sizeof(sconn);
#endif
// According to the IETF draft, both endpoints must initiate the SCTP association, in a
// simultaneous-open manner, irrelevent to the SDP setup role.
// See https://tools.ietf.org/html/draft-ietf-mmusic-sctp-sdp-26#section-9.3
if (usrsctp_connect(mSock, reinterpret_cast<struct sockaddr *>(&sconn), sizeof(sconn)) !=
0) {
std::cerr << "SCTP connection failed, errno=" << errno << std::endl;
changeState(State::Failed);
mStopping = true;
return;
} }
if (flags & MSG_EOR) {
if (!mPartialRecv.empty()) {
mPartialRecv.insert(mPartialRecv.end(), data, data + len);
data = mPartialRecv.data();
len = mPartialRecv.size();
}
// Message is complete, process it
if (flags & MSG_NOTIFICATION)
processNotification(reinterpret_cast<const union sctp_notification *>(data), len);
else
processData(data, len, info.rcv_sid, PayloadId(htonl(info.rcv_ppid)));
if (!mStopping) mPartialRecv.clear();
changeState(State::Connected); } else {
// Message is not complete
mPartialRecv.insert(mPartialRecv.end(), data, data + len);
}
} catch (const std::exception &e) { } catch (const std::exception &e) {
std::cerr << "SCTP connect: " << e.what() << std::endl; std::cerr << "SCTP recv: " << e.what() << std::endl;
} return -1;
}
int SctpTransport::handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df) {
byte *b = reinterpret_cast<byte *>(data);
outgoing(make_message(b, b + len));
if (!mConnectDataSent) {
std::unique_lock<std::mutex> lock(mConnectMutex);
mConnectDataSent = true;
mConnectCondition.notify_all();
} }
return 0; // success return 0; // success
} }
int SctpTransport::process(struct socket *sock, union sctp_sockstore addr, void *data, size_t len, int SctpTransport::handleSend(size_t free) {
struct sctp_rcvinfo info, int flags) { try {
if (flags & MSG_NOTIFICATION) { std::lock_guard<std::mutex> lock(mSendMutex);
processNotification((union sctp_notification *)data, len); trySendQueue();
} else { } catch (const std::exception &e) {
processData((const byte *)data, len, info.rcv_sid, PayloadId(htonl(info.rcv_ppid))); std::cerr << "SCTP send: " << e.what() << std::endl;
return -1;
} }
free(data); return 0; // success
return 0; }
int SctpTransport::handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df) {
try {
outgoing(make_message(data, data + len));
if (!mConnectDataSent) {
std::unique_lock<std::mutex> lock(mConnectMutex);
mConnectDataSent = true;
mConnectCondition.notify_all();
}
} catch (const std::exception &e) {
std::cerr << "SCTP write: " << e.what() << std::endl;
return -1;
}
return 0; // success
} }
void SctpTransport::processData(const byte *data, size_t len, uint16_t sid, PayloadId ppid) { void SctpTransport::processData(const byte *data, size_t len, uint16_t sid, PayloadId ppid) {
Message::Type type; // The usage of the PPIDs "WebRTC String Partial" and "WebRTC Binary Partial" is deprecated.
// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-6.6
// We handle them at reception for compatibility reasons but should never send them.
switch (ppid) { switch (ppid) {
case PPID_STRING:
type = Message::String;
break;
case PPID_STRING_EMPTY:
type = Message::String;
len = 0;
break;
case PPID_BINARY:
type = Message::Binary;
break;
case PPID_BINARY_EMPTY:
type = Message::Binary;
len = 0;
break;
case PPID_CONTROL: case PPID_CONTROL:
type = Message::Control; recv(make_message(data, data + len, Message::Control, sid));
break; break;
case PPID_STRING_PARTIAL: // deprecated
mPartialStringData.insert(mPartialStringData.end(), data, data + len);
break;
case PPID_STRING:
if (mPartialStringData.empty()) {
recv(make_message(data, data + len, Message::String, sid));
} else {
mPartialStringData.insert(mPartialStringData.end(), data, data + len);
recv(make_message(mPartialStringData.begin(), mPartialStringData.end(), Message::String,
sid));
mPartialStringData.clear();
}
break;
case PPID_STRING_EMPTY:
// This only accounts for when the partial data is empty
recv(make_message(mPartialStringData.begin(), mPartialStringData.end(), Message::String,
sid));
mPartialStringData.clear();
break;
case PPID_BINARY_PARTIAL: // deprecated
mPartialBinaryData.insert(mPartialBinaryData.end(), data, data + len);
break;
case PPID_BINARY:
if (mPartialBinaryData.empty()) {
recv(make_message(data, data + len, Message::Binary, sid));
} else {
mPartialBinaryData.insert(mPartialBinaryData.end(), data, data + len);
recv(make_message(mPartialBinaryData.begin(), mPartialBinaryData.end(), Message::Binary,
sid));
mPartialBinaryData.clear();
}
break;
case PPID_BINARY_EMPTY:
// This only accounts for when the partial data is empty
recv(make_message(mPartialBinaryData.begin(), mPartialBinaryData.end(), Message::Binary,
sid));
mPartialBinaryData.clear();
break;
default: default:
// Unknown // Unknown
std::cerr << "Unknown PPID: " << uint32_t(ppid) << std::endl; std::cerr << "Unknown PPID: " << uint32_t(ppid) << std::endl;
return; return;
} }
recv(make_message(data, data + len, type, sid));
} }
void SctpTransport::processNotification(const union sctp_notification *notify, size_t len) { void SctpTransport::processNotification(const union sctp_notification *notify, size_t len) {
@ -317,21 +451,41 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
return; return;
switch (notify->sn_header.sn_type) { switch (notify->sn_header.sn_type) {
case SCTP_ASSOC_CHANGE: {
const struct sctp_assoc_change &assoc_change = notify->sn_assoc_change;
std::unique_lock<std::mutex> lock(mConnectMutex);
if (assoc_change.sac_state == SCTP_COMM_UP) {
changeState(State::Connected);
} else {
if (mState == State::Connecting) {
std::cerr << "SCTP connection failed" << std::endl;
changeState(State::Failed);
} else {
changeState(State::Disconnected);
}
}
}
case SCTP_SENDER_DRY_EVENT: {
// It not should be necessary since the send callback should have been called already,
// but to be sure, let's try to send now.
std::lock_guard<std::mutex> lock(mSendMutex);
trySendQueue();
}
case SCTP_STREAM_RESET_EVENT: { case SCTP_STREAM_RESET_EVENT: {
const struct sctp_stream_reset_event *reset_event = &notify->sn_strreset_event; const struct sctp_stream_reset_event &reset_event = notify->sn_strreset_event;
const int count = (reset_event->strreset_length - sizeof(*reset_event)) / sizeof(uint16_t); const int count = (reset_event.strreset_length - sizeof(reset_event)) / sizeof(uint16_t);
if (reset_event->strreset_flags & SCTP_STREAM_RESET_INCOMING_SSN) { if (reset_event.strreset_flags & SCTP_STREAM_RESET_INCOMING_SSN) {
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
uint16_t streamId = reset_event->strreset_stream_list[i]; uint16_t streamId = reset_event.strreset_stream_list[i];
reset(streamId); reset(streamId);
} }
} }
if (reset_event->strreset_flags & SCTP_STREAM_RESET_OUTGOING_SSN) { if (reset_event.strreset_flags & SCTP_STREAM_RESET_OUTGOING_SSN) {
const byte dataChannelCloseMessage{0x04}; const byte dataChannelCloseMessage{0x04};
for (int i = 0; i < count; ++i) { for (int i = 0; i < count; ++i) {
uint16_t streamId = reset_event->strreset_stream_list[i]; uint16_t streamId = reset_event.strreset_stream_list[i];
recv(make_message(&dataChannelCloseMessage, &dataChannelCloseMessage + 1, recv(make_message(&dataChannelCloseMessage, &dataChannelCloseMessage + 1,
Message::Control, streamId)); Message::Control, streamId));
} }
@ -344,16 +498,29 @@ void SctpTransport::processNotification(const union sctp_notification *notify, s
break; break;
} }
} }
int SctpTransport::WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos,
uint8_t set_df) { int SctpTransport::RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data,
return static_cast<SctpTransport *>(sctp_ptr)->handleWrite(data, len, tos, set_df); size_t len, struct sctp_rcvinfo recv_info, int flags, void *ptr) {
int ret = static_cast<SctpTransport *>(ptr)->handleRecv(
sock, addr, static_cast<const byte *>(data), len, recv_info, flags);
free(data);
return ret;
} }
int SctpTransport::ReadCallback(struct socket *sock, union sctp_sockstore addr, void *data, int SctpTransport::SendCallback(struct socket *sock, uint32_t sb_free) {
size_t len, struct sctp_rcvinfo recv_info, int flags, struct sctp_paddrinfo paddrinfo = {};
void *user_data) { socklen_t len = sizeof(paddrinfo);
return static_cast<SctpTransport *>(user_data)->process(sock, addr, data, len, recv_info, if (usrsctp_getsockopt(sock, IPPROTO_SCTP, SCTP_GET_PEER_ADDR_INFO, &paddrinfo, &len))
flags); return -1;
auto sconn = reinterpret_cast<struct sockaddr_conn *>(&paddrinfo.spinfo_address);
void *ptr = sconn->sconn_addr;
return static_cast<SctpTransport *>(ptr)->handleSend(size_t(sb_free));
}
int SctpTransport::WriteCallback(void *ptr, void *data, size_t len, uint8_t tos, uint8_t set_df) {
return static_cast<SctpTransport *>(ptr)->handleWrite(static_cast<byte *>(data), len, tos,
set_df);
} }
} // namespace rtc } // namespace rtc

View File

@ -21,16 +21,19 @@
#include "include.hpp" #include "include.hpp"
#include "peerconnection.hpp" #include "peerconnection.hpp"
#include "queue.hpp"
#include "transport.hpp" #include "transport.hpp"
#include <condition_variable> #include <condition_variable>
#include <functional> #include <functional>
#include <map>
#include <mutex> #include <mutex>
#include <thread> #include <thread>
#include <sys/socket.h> #include <sys/socket.h>
#include <sys/types.h> #include <sys/types.h>
#include <usrsctp.h>
#include "usrsctp.h"
namespace rtc { namespace rtc {
@ -38,54 +41,70 @@ class SctpTransport : public Transport {
public: public:
enum class State { Disconnected, Connecting, Connected, Failed }; enum class State { Disconnected, Connecting, Connected, Failed };
using amount_callback = std::function<void(uint16_t streamId, size_t amount)>;
using state_callback = std::function<void(State state)>; using state_callback = std::function<void(State state)>;
SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recv, SctpTransport(std::shared_ptr<Transport> lower, uint16_t port, message_callback recvCallback,
state_callback stateChangeCallback); amount_callback bufferedAmountCallback, state_callback stateChangeCallback);
~SctpTransport(); ~SctpTransport();
State state() const; State state() const;
bool send(message_ptr message); void stop() override;
bool send(message_ptr message) override; // false if buffered
void reset(unsigned int stream); void reset(unsigned int stream);
private: private:
// Order seems wrong but these are the actual values
// See https://tools.ietf.org/html/draft-ietf-rtcweb-data-channel-13#section-8
enum PayloadId : uint32_t { enum PayloadId : uint32_t {
PPID_CONTROL = 50, PPID_CONTROL = 50,
PPID_STRING = 51, PPID_STRING = 51,
PPID_BINARY_PARTIAL = 52,
PPID_BINARY = 53, PPID_BINARY = 53,
PPID_STRING_PARTIAL = 54,
PPID_STRING_EMPTY = 56, PPID_STRING_EMPTY = 56,
PPID_BINARY_EMPTY = 57 PPID_BINARY_EMPTY = 57
}; };
void connect();
void incoming(message_ptr message); void incoming(message_ptr message);
void changeState(State state); void changeState(State state);
void runConnect();
int handleWrite(void *data, size_t len, uint8_t tos, uint8_t set_df); bool trySendQueue();
bool trySendMessage(message_ptr message);
void updateBufferedAmount(uint16_t streamId, long delta);
int process(struct socket *sock, union sctp_sockstore addr, void *data, size_t len, int handleRecv(struct socket *sock, union sctp_sockstore addr, const byte *data, size_t len,
struct sctp_rcvinfo recv_info, int flags); struct sctp_rcvinfo recv_info, int flags);
int handleSend(size_t free);
int handleWrite(byte *data, size_t len, uint8_t tos, uint8_t set_df);
void processData(const byte *data, size_t len, uint16_t streamId, PayloadId ppid); void processData(const byte *data, size_t len, uint16_t streamId, PayloadId ppid);
void processNotification(const union sctp_notification *notify, size_t len); void processNotification(const union sctp_notification *notify, size_t len);
const uint16_t mPort;
struct socket *mSock; struct socket *mSock;
uint16_t mPort;
std::thread mConnectThread; std::mutex mSendMutex;
Queue<message_ptr> mSendQueue;
std::map<uint16_t, size_t> mBufferedAmount;
amount_callback mBufferedAmountCallback;
std::mutex mConnectMutex; std::mutex mConnectMutex;
std::condition_variable mConnectCondition; std::condition_variable mConnectCondition;
std::atomic<bool> mConnectDataSent = false; std::atomic<bool> mConnectDataSent = false;
std::atomic<bool> mStopping = false; std::atomic<bool> mStopping = false;
state_callback mStateChangeCallback;
std::atomic<State> mState; std::atomic<State> mState;
state_callback mStateChangeCallback; binary mPartialRecv, mPartialStringData, mPartialBinaryData;
static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df); static int RecvCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
static int ReadCallback(struct socket *sock, union sctp_sockstore addr, void *data, size_t len,
struct sctp_rcvinfo recv_info, int flags, void *user_data); struct sctp_rcvinfo recv_info, int flags, void *user_data);
static int SendCallback(struct socket *sock, uint32_t sb_free);
static int WriteCallback(void *sctp_ptr, void *data, size_t len, uint8_t tos, uint8_t set_df);
void GlobalInit(); void GlobalInit();
void GlobalCleanup(); void GlobalCleanup();

View File

@ -22,6 +22,7 @@
#include "include.hpp" #include "include.hpp"
#include "message.hpp" #include "message.hpp"
#include <atomic>
#include <functional> #include <functional>
#include <memory> #include <memory>
@ -31,36 +32,34 @@ using namespace std::placeholders;
class Transport { class Transport {
public: public:
Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(lower) { init(); } Transport(std::shared_ptr<Transport> lower = nullptr) : mLower(std::move(lower)) {
if (auto lower = std::atomic_load(&mLower))
lower->onRecv(std::bind(&Transport::incoming, this, _1));
}
virtual ~Transport() {} virtual ~Transport() {}
virtual void stop() { resetLower(); }
virtual bool send(message_ptr message) = 0; virtual bool send(message_ptr message) = 0;
void onRecv(message_callback callback) { mRecvCallback = std::move(callback); } void onRecv(message_callback callback) { mRecvCallback = std::move(callback); }
protected: protected:
void recv(message_ptr message) { void recv(message_ptr message) { mRecvCallback(message); }
if (mRecvCallback)
mRecvCallback(message); void resetLower() {
if (auto lower = std::atomic_exchange(&mLower, std::shared_ptr<Transport>(nullptr)))
lower->onRecv(nullptr);
} }
virtual void incoming(message_ptr message) = 0; virtual void incoming(message_ptr message) = 0;
virtual void outgoing(message_ptr message) { getLower()->send(message); } virtual void outgoing(message_ptr message) {
if (auto lower = std::atomic_load(&mLower))
lower->send(message);
}
private: private:
void init() {
if (mLower)
mLower->onRecv(std::bind(&Transport::incoming, this, _1));
}
std::shared_ptr<Transport> getLower() {
if (mLower)
return mLower;
else
throw std::logic_error("No lower transport to call");
}
std::shared_ptr<Transport> mLower; std::shared_ptr<Transport> mLower;
message_callback mRecvCallback; synchronized_callback<message_ptr> mRecvCallback;
}; };
} // namespace rtc } // namespace rtc

View File

@ -26,6 +26,9 @@
using namespace rtc; using namespace rtc;
using namespace std; using namespace std;
template <class T>
weak_ptr<T> make_weak_ptr(shared_ptr<T> ptr) { return ptr; }
int main(int argc, char **argv) { int main(int argc, char **argv) {
rtc::Configuration config; rtc::Configuration config;
// config.iceServers.emplace_back("stun.l.google.com:19302"); // config.iceServers.emplace_back("stun.l.google.com:19302");
@ -33,12 +36,16 @@ int main(int argc, char **argv) {
auto pc1 = std::make_shared<PeerConnection>(config); auto pc1 = std::make_shared<PeerConnection>(config);
auto pc2 = std::make_shared<PeerConnection>(config); auto pc2 = std::make_shared<PeerConnection>(config);
pc1->onLocalDescription([pc2](const Description &sdp) { pc1->onLocalDescription([wpc2 = make_weak_ptr(pc2)](const Description &sdp) {
auto pc2 = wpc2.lock();
if (!pc2) return;
cout << "Description 1: " << sdp << endl; cout << "Description 1: " << sdp << endl;
pc2->setRemoteDescription(sdp); pc2->setRemoteDescription(sdp);
}); });
pc1->onLocalCandidate([pc2](const Candidate &candidate) { pc1->onLocalCandidate([wpc2 = make_weak_ptr(pc2)](const Candidate &candidate) {
auto pc2 = wpc2.lock();
if (!pc2) return;
cout << "Candidate 1: " << candidate << endl; cout << "Candidate 1: " << candidate << endl;
pc2->addRemoteCandidate(candidate); pc2->addRemoteCandidate(candidate);
}); });
@ -48,12 +55,16 @@ int main(int argc, char **argv) {
cout << "Gathering state 1: " << state << endl; cout << "Gathering state 1: " << state << endl;
}); });
pc2->onLocalDescription([pc1](const Description &sdp) { pc2->onLocalDescription([wpc1 = make_weak_ptr(pc1)](const Description &sdp) {
auto pc1 = wpc1.lock();
if (!pc1) return;
cout << "Description 2: " << sdp << endl; cout << "Description 2: " << sdp << endl;
pc1->setRemoteDescription(sdp); pc1->setRemoteDescription(sdp);
}); });
pc2->onLocalCandidate([pc1](const Candidate &candidate) { pc2->onLocalCandidate([wpc1 = make_weak_ptr(pc1)](const Candidate &candidate) {
auto pc1 = wpc1.lock();
if (!pc1) return;
cout << "Candidate 2: " << candidate << endl; cout << "Candidate 2: " << candidate << endl;
pc1->addRemoteCandidate(candidate); pc1->addRemoteCandidate(candidate);
}); });
@ -67,19 +78,38 @@ int main(int argc, char **argv) {
pc2->onDataChannel([&dc2](shared_ptr<DataChannel> dc) { pc2->onDataChannel([&dc2](shared_ptr<DataChannel> dc) {
cout << "Got a DataChannel with label: " << dc->label() << endl; cout << "Got a DataChannel with label: " << dc->label() << endl;
dc2 = dc; dc2 = dc;
dc2->send("Hello world!"); dc2->onMessage([](const variant<binary, string> &message) {
if (holds_alternative<string>(message)) {
cout << "Received 2: " << get<string>(message) << endl;
}
});
dc2->send("Hello from 2");
}); });
auto dc1 = pc1->createDataChannel("test"); auto dc1 = pc1->createDataChannel("test");
dc1->onOpen([dc1]() { dc1->onOpen([wdc1 = make_weak_ptr(dc1)]() {
auto dc1 = wdc1.lock();
if (!dc1) return;
cout << "DataChannel open: " << dc1->label() << endl; cout << "DataChannel open: " << dc1->label() << endl;
dc1->send("Hello from 1");
}); });
dc1->onMessage([](const variant<binary, string> &message) { dc1->onMessage([](const variant<binary, string> &message) {
if (holds_alternative<string>(message)) { if (holds_alternative<string>(message)) {
cout << "Received: " << get<string>(message) << endl; cout << "Received 1: " << get<string>(message) << endl;
} }
}); });
this_thread::sleep_for(10s); this_thread::sleep_for(3s);
if (dc1->isOpen() && dc2->isOpen()) {
dc1->close();
dc2->close();
cout << "Success" << endl;
return 0;
} else {
cout << "Failure" << endl;
return 1;
}
} }