Add some basic Win32 async socket abstractions.

- Adds cppcoro::net::socket class with support for TCP/UDP
  over IPv4 or IPv6.
- Adds async accept/connect/disconnect/send/recv operations in
  both cancellable and non-cancellable variants.
- Still need to add send_to/recv_from as well as overloads for
  multi-buffer send/recv.
This commit is contained in:
Lewis Baker 2017-12-28 23:05:32 +10:30
parent 6ba0778c4e
commit 664537595e
22 changed files with 2425 additions and 3 deletions

View file

@ -58,6 +58,16 @@ namespace cppcoro
struct wsabuf
{
constexpr wsabuf() noexcept
: len(0)
, buf(nullptr)
{}
constexpr wsabuf(void* ptr, std::size_t size)
: len(size <= ulong_t(-1) ? ulong_t(size) : ulong_t(-1))
, buf(static_cast<char*>(ptr))
{}
ulong_t len;
char* buf;
};

View file

@ -18,6 +18,7 @@
#include <cstdint>
#include <atomic>
#include <utility>
#include <mutex>
#include <experimental/coroutine>
namespace cppcoro
@ -121,7 +122,7 @@ namespace cppcoro
/// Call this after a call to stop() to allow calls to process_xxx() methods
/// to process events.
///
/// After calling stop() you must ensure that all threads have returned from
/// After calling stop() you should ensure that all threads have returned from
/// calls to process_xxx() methods before calling reset().
void reset();
@ -133,6 +134,7 @@ namespace cppcoro
#if CPPCORO_OS_WINNT
detail::win32::handle_t native_iocp_handle() noexcept;
void ensure_winsock_initialised();
#endif
private:
@ -167,6 +169,9 @@ namespace cppcoro
#if CPPCORO_OS_WINNT
detail::win32::safe_handle m_iocpHandle;
std::atomic<bool> m_winsockInitialised;
std::mutex m_winsockInitialisationMutex;
#endif
// Head of a linked-list of schedule operations that are

View file

@ -53,6 +53,11 @@ namespace cppcoro::net
std::uint32_t(m_bytes[3]);
}
static constexpr ipv4_address loopback()
{
return ipv4_address(127, 0, 0, 1);
}
constexpr bool is_loopback() const
{
return m_bytes[0] == 127;

View file

@ -0,0 +1,251 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#ifndef CPPCORO_NET_SOCKET_HPP_INCLUDED
#define CPPCORO_NET_SOCKET_HPP_INCLUDED
#include <cppcoro/config.hpp>
#include <cppcoro/net/ip_endpoint.hpp>
#include <cppcoro/net/socket_accept_operation.hpp>
#include <cppcoro/net/socket_connect_operation.hpp>
#include <cppcoro/net/socket_disconnect_operation.hpp>
#include <cppcoro/net/socket_recv_operation.hpp>
#include <cppcoro/net/socket_send_operation.hpp>
#include <cppcoro/cancellation_token.hpp>
#if CPPCORO_OS_WINNT
# include <cppcoro/detail/win32.hpp>
#endif
namespace cppcoro
{
class io_service;
namespace net
{
class socket_send_operation;
class socket_recv_operation;
class socket
{
public:
/// Create a socket that can be used to communicate using TCP/IPv4 protocol.
///
/// \param ioSvc
/// The I/O service the socket will use for dispatching I/O completion events.
///
/// \return
/// The newly created socket.
///
/// \throws std::system_error
/// If the socket could not be created for some reason.
static socket create_tcpv4(io_service& ioSvc);
/// Create a socket that can be used to communicate using TCP/IPv6 protocol.
///
/// \param ioSvc
/// The I/O service the socket will use for dispatching I/O completion events.
///
/// \return
/// The newly created socket.
///
/// \throws std::system_error
/// If the socket could not be created for some reason.
static socket create_tcpv6(io_service& ioSvc);
/// Create a socket that can be used to communicate using UDP/IPv4 protocol.
///
/// \param ioSvc
/// The I/O service the socket will use for dispatching I/O completion events.
///
/// \return
/// The newly created socket.
///
/// \throws std::system_error
/// If the socket could not be created for some reason.
static socket create_udpv4(io_service& ioSvc);
/// Create a socket that can be used to communicate using UDP/IPv6 protocol.
///
/// \param ioSvc
/// The I/O service the socket will use for dispatching I/O completion events.
///
/// \return
/// The newly created socket.
///
/// \throws std::system_error
/// If the socket could not be created for some reason.
static socket create_udpv6(io_service& ioSvc);
socket(socket&& other) noexcept;
/// Closes the socket, releasing any associated resources.
///
/// If the socket still has an open connection then the connection will be
/// reset. The destructor will not block waiting for queueud data to be sent.
/// If you need to ensure that queued data is delivered then you must call
/// disconnect() and wait until the disconnect operation completes.
~socket();
socket& operator=(socket&& other) noexcept;
#if CPPCORO_OS_WINNT
/// Get the Win32 socket handle assocaited with this socket.
cppcoro::detail::win32::socket_t native_handle() noexcept { return m_handle; }
/// Query whether I/O operations that complete synchronously will skip posting
/// an I/O completion event to the I/O completion port.
///
/// The operation class implementations can use this to determine whether or not
/// it should immediately resume the coroutine on the current thread upon an
/// operation completing synchronously or whether it should suspend the coroutine
/// and wait until the I/O completion event is dispatched to an I/O thread.
bool skip_completion_on_success() noexcept { return m_skipCompletionOnSuccess; }
#endif
/// Get the address and port of the local end-point.
///
/// If the socket is not bound then this will be the unspecified end-point
/// of the socket's associated address-family.
const ip_endpoint& local_endpoint() const noexcept { return m_localEndPoint; }
/// Get the address and port of the remote end-point.
///
/// If the socket is not in the connected state then this will be the unspecified
/// end-point of the socket's associated address-family.
const ip_endpoint& remote_endpoint() const noexcept { return m_remoteEndPoint; }
/// Bind the local end of this socket to the specified local end-point.
///
/// \param localEndPoint
/// The end-point to bind to.
/// This can be either an unspecified address (in which case it binds to all available
/// interfaces) and/or an unspecified port (in which case a random port is allocated).
///
/// \throws std::system_error
/// If the socket could not be bound for some reason.
void bind(const ip_endpoint& localEndPoint);
/// Put the socket into a passive listening state that will start acknowledging
/// and queueing up new connections ready to be accepted by a call to 'accept()'.
///
/// The backlog of connections ready to be accepted will be set to some default
/// suitable large value, depending on the network provider. If you need more
/// control over the size of the queue then use the overload of listen()
/// that accepts a 'backlog' parameter.
///
/// \throws std::system_error
/// If the socket could not be placed into a listening mode.
void listen();
/// Put the socket into a passive listening state that will start acknowledging
/// and queueing up new connections ready to be accepted by a call to 'accept()'.
///
/// \param backlog
/// The maximum number of pending connections to allow in the queue of ready-to-accept
/// connections.
///
/// \throws std::system_error
/// If the socket could not be placed into a listening mode.
void listen(std::uint32_t backlog);
/// Connect the socket to the specified remote end-point.
///
/// The socket must be in a bound but unconnected state prior to this call.
///
/// \param remoteEndPoint
/// The IP address and port-number to connect to.
///
/// \return
/// An awaitable object that must be co_await'ed to perform the async connect
/// operation. The result of the co_await expression is type void.
[[nodiscard]]
socket_connect_operation connect(const ip_endpoint& remoteEndPoint);
/// Connect to the specified remote end-point.
///
/// \param remoteEndPoint
/// The IP address and port of the remote end-point to connect to.
///
/// \param ct
/// A cancellation token that can be used to communicate a request to
/// later cancel the operation. If the operation is successfully
/// cancelled then it will complete by throwing a cppcoro::operation_cancelled
/// exception.
///
/// \return
/// An awaitable object that will start the connect operation when co_await'ed
/// and will suspend the coroutine, resuming it when the operation completes.
/// The result of the co_await expression has type 'void'.
[[nodiscard]]
socket_connect_operation_cancellable connect(
const ip_endpoint& remoteEndPoint,
cancellation_token ct);
[[nodiscard]]
socket_accept_operation accept(socket& acceptingSocket);
[[nodiscard]]
socket_accept_operation_cancellable accept(
socket& acceptingSocket,
cancellation_token ct);
[[nodiscard]]
socket_disconnect_operation disconnect();
[[nodiscard]]
socket_disconnect_operation_cancellable disconnect(cancellation_token ct);
[[nodiscard]]
socket_send_operation send(
const void* buffer,
std::size_t size);
[[nodiscard]]
socket_send_operation_cancellable send(
const void* buffer,
std::size_t size,
cancellation_token ct);
[[nodiscard]]
socket_recv_operation recv(
void* buffer,
std::size_t size);
[[nodiscard]]
socket_recv_operation_cancellable recv(
void* buffer,
std::size_t size,
cancellation_token ct);
void close_send();
void close_recv();
private:
friend class socket_accept_operation;
friend class socket_accept_operation_cancellable;
friend class socket_connect_operation;
friend class socket_connect_operation_cancellable;
friend class socket_disconnect_operation;
friend class socket_disconnect_operation_cancellable;
#if CPPCORO_OS_WINNT
explicit socket(
cppcoro::detail::win32::socket_t handle,
bool skipCompletionOnSuccess) noexcept;
#endif
#if CPPCORO_OS_WINNT
cppcoro::detail::win32::socket_t m_handle;
bool m_skipCompletionOnSuccess;
#endif
ip_endpoint m_localEndPoint;
ip_endpoint m_remoteEndPoint;
};
}
}
#endif

View file

@ -0,0 +1,101 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#ifndef CPPCORO_NET_SOCKET_ACCEPT_OPERATION_HPP_INCLUDED
#define CPPCORO_NET_SOCKET_ACCEPT_OPERATION_HPP_INCLUDED
#include <cppcoro/config.hpp>
#include <cppcoro/cancellation_token.hpp>
#include <cppcoro/cancellation_registration.hpp>
#if CPPCORO_OS_WINNT
# include <cppcoro/detail/win32.hpp>
# include <cppcoro/detail/win32_overlapped_operation.hpp>
# include <atomic>
# include <optional>
# include <experimental/coroutine>
namespace cppcoro
{
namespace net
{
class socket;
class socket_accept_operation
: public cppcoro::detail::win32_overlapped_operation<socket_accept_operation>
{
public:
socket_accept_operation(
socket& listeningSocket,
socket& acceptingSocket) noexcept
: cppcoro::detail::win32_overlapped_operation<socket_accept_operation>()
, m_listeningSocket(listeningSocket)
, m_acceptingSocket(acceptingSocket)
{}
private:
friend class cppcoro::detail::win32_overlapped_operation<socket_accept_operation>;
bool try_start() noexcept;
void get_result();
#if CPPCORO_COMPILER_MSVC
# pragma warning(push)
# pragma warning(disable : 4324) // Structure padded due to alignment
#endif
socket& m_listeningSocket;
socket& m_acceptingSocket;
alignas(8) std::uint8_t m_addressBuffer[88];
#if CPPCORO_COMPILER_MSVC
# pragma warning(pop)
#endif
};
class socket_accept_operation_cancellable
: public cppcoro::detail::win32_overlapped_operation_cancellable<socket_accept_operation_cancellable>
{
public:
socket_accept_operation_cancellable(
socket& listeningSocket,
socket& acceptingSocket,
cancellation_token&& ct) noexcept
: cppcoro::detail::win32_overlapped_operation_cancellable<socket_accept_operation_cancellable>(std::move(ct))
, m_listeningSocket(listeningSocket)
, m_acceptingSocket(acceptingSocket)
{}
private:
friend class cppcoro::detail::win32_overlapped_operation<socket_accept_operation>;
bool try_start() noexcept;
void cancel() noexcept;
void get_result();
#if CPPCORO_COMPILER_MSVC
# pragma warning(push)
# pragma warning(disable : 4324) // Structure padded due to alignment
#endif
socket& m_listeningSocket;
socket& m_acceptingSocket;
alignas(8) std::uint8_t m_addressBuffer[88];
#if CPPCORO_COMPILER_MSVC
# pragma warning(pop)
#endif
};
}
}
#endif // CPPCORO_OS_WINNT
#endif

View file

@ -0,0 +1,78 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#ifndef CPPCORO_NET_SOCKET_CONNECT_OPERATION_HPP_INCLUDED
#define CPPCORO_NET_SOCKET_CONNECT_OPERATION_HPP_INCLUDED
#include <cppcoro/config.hpp>
#include <cppcoro/cancellation_token.hpp>
#include <cppcoro/net/ip_endpoint.hpp>
#if CPPCORO_OS_WINNT
# include <cppcoro/detail/win32.hpp>
# include <cppcoro/detail/win32_overlapped_operation.hpp>
namespace cppcoro
{
namespace net
{
class socket;
class socket_connect_operation
: public cppcoro::detail::win32_overlapped_operation<socket_connect_operation>
{
public:
socket_connect_operation(
socket& socket,
const ip_endpoint& remoteEndPoint) noexcept
: cppcoro::detail::win32_overlapped_operation<socket_connect_operation>()
, m_socket(socket)
, m_remoteEndPoint(remoteEndPoint)
{}
private:
friend class cppcoro::detail::win32_overlapped_operation<socket_connect_operation>;
bool try_start() noexcept;
void get_result();
socket& m_socket;
ip_endpoint m_remoteEndPoint;
};
class socket_connect_operation_cancellable
: public cppcoro::detail::win32_overlapped_operation_cancellable<socket_connect_operation_cancellable>
{
public:
socket_connect_operation_cancellable(
socket& socket,
const ip_endpoint& remoteEndPoint,
cancellation_token&& ct) noexcept
: cppcoro::detail::win32_overlapped_operation_cancellable<socket_connect_operation_cancellable>(std::move(ct))
, m_socket(socket)
, m_remoteEndPoint(remoteEndPoint)
{}
private:
friend class cppcoro::detail::win32_overlapped_operation<socket_connect_operation>;
bool try_start() noexcept;
void cancel() noexcept;
void get_result();
socket& m_socket;
ip_endpoint m_remoteEndPoint;
};
}
}
#endif // CPPCORO_OS_WINNT
#endif

View file

@ -0,0 +1,67 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#ifndef CPPCORO_NET_SOCKET_DISCONNECT_OPERATION_HPP_INCLUDED
#define CPPCORO_NET_SOCKET_DISCONNECT_OPERATION_HPP_INCLUDED
#include <cppcoro/config.hpp>
#include <cppcoro/cancellation_token.hpp>
#if CPPCORO_OS_WINNT
# include <cppcoro/detail/win32.hpp>
# include <cppcoro/detail/win32_overlapped_operation.hpp>
namespace cppcoro
{
namespace net
{
class socket;
class socket_disconnect_operation
: public cppcoro::detail::win32_overlapped_operation<socket_disconnect_operation>
{
public:
socket_disconnect_operation(socket& s) noexcept
: m_socket(s)
{}
private:
friend class cppcoro::detail::win32_overlapped_operation<socket_disconnect_operation>;
bool try_start() noexcept;
void get_result();
socket& m_socket;
};
class socket_disconnect_operation_cancellable
: public cppcoro::detail::win32_overlapped_operation_cancellable<socket_disconnect_operation_cancellable>
{
public:
socket_disconnect_operation_cancellable(socket& s, cancellation_token&& ct) noexcept
: cppcoro::detail::win32_overlapped_operation_cancellable<socket_disconnect_operation_cancellable>(std::move(ct))
, m_socket(s)
{}
private:
friend class cppcoro::detail::win32_overlapped_operation<socket_disconnect_operation>;
bool try_start() noexcept;
void cancel() noexcept;
void get_result();
socket& m_socket;
};
}
}
#endif // CPPCORO_OS_WINNT
#endif

View file

@ -0,0 +1,71 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#ifndef CPPCORO_NET_SOCKET_RECV_OPERATION_HPP_INCLUDED
#define CPPCORO_NET_SOCKET_RECV_OPERATION_HPP_INCLUDED
#include <cppcoro/config.hpp>
#include <cppcoro/cancellation_token.hpp>
#include <cstdint>
#if CPPCORO_OS_WINNT
# include <cppcoro/detail/win32.hpp>
# include <cppcoro/detail/win32_overlapped_operation.hpp>
namespace cppcoro::net
{
class socket;
class socket_recv_operation
: public cppcoro::detail::win32_overlapped_operation<socket_recv_operation>
{
public:
socket_recv_operation(
socket& s,
void* buffer,
std::size_t byteCount) noexcept;
private:
friend class cppcoro::detail::win32_overlapped_operation<socket_recv_operation>;
bool try_start() noexcept;
cppcoro::detail::win32::socket_t m_socketHandle;
cppcoro::detail::win32::wsabuf m_buffer;
bool m_skipCompletionOnSuccess;
};
class socket_recv_operation_cancellable
: public cppcoro::detail::win32_overlapped_operation_cancellable<socket_recv_operation_cancellable>
{
public:
socket_recv_operation_cancellable(
socket& s,
void* buffer,
std::size_t byteCount,
cancellation_token&& ct) noexcept;
private:
friend class cppcoro::detail::win32_overlapped_operation<socket_recv_operation>;
bool try_start() noexcept;
void cancel() noexcept;
cppcoro::detail::win32::socket_t m_socketHandle;
cppcoro::detail::win32::wsabuf m_buffer;
bool m_skipCompletionOnSuccess;
};
}
#endif // CPPCORO_OS_WINNT
#endif

View file

@ -0,0 +1,71 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#ifndef CPPCORO_NET_SOCKET_SEND_OPERATION_HPP_INCLUDED
#define CPPCORO_NET_SOCKET_SEND_OPERATION_HPP_INCLUDED
#include <cppcoro/config.hpp>
#include <cppcoro/cancellation_token.hpp>
#include <cstdint>
#if CPPCORO_OS_WINNT
# include <cppcoro/detail/win32.hpp>
# include <cppcoro/detail/win32_overlapped_operation.hpp>
namespace cppcoro::net
{
class socket;
class socket_send_operation
: public cppcoro::detail::win32_overlapped_operation<socket_send_operation>
{
public:
socket_send_operation(
socket& s,
const void* buffer,
std::size_t byteCount) noexcept;
private:
friend class cppcoro::detail::win32_overlapped_operation<socket_send_operation>;
bool try_start() noexcept;
cppcoro::detail::win32::socket_t m_socketHandle;
cppcoro::detail::win32::wsabuf m_buffer;
bool m_skipCompletionOnSuccess;
};
class socket_send_operation_cancellable
: public cppcoro::detail::win32_overlapped_operation_cancellable<socket_send_operation_cancellable>
{
public:
socket_send_operation_cancellable(
socket& s,
const void* buffer,
std::size_t byteCount,
cancellation_token&& ct) noexcept;
private:
friend class cppcoro::detail::win32_overlapped_operation<socket_send_operation>;
bool try_start() noexcept;
void cancel() noexcept;
cppcoro::detail::win32::socket_t m_socketHandle;
cppcoro::detail::win32::wsabuf m_buffer;
bool m_skipCompletionOnSuccess;
};
}
#endif // CPPCORO_OS_WINNT
#endif

View file

@ -54,6 +54,7 @@ netIncludes = cake.path.join(env.expand('${CPPCORO}'), 'include', 'cppcoro', 'ne
'ipv4_endpoint.hpp.',
'ipv6_address.hpp',
'ipv6_endpoint.hpp',
'socket.hpp',
])
detailIncludes = cake.path.join(env.expand('${CPPCORO}'), 'include', 'cppcoro', 'detail', [
@ -65,6 +66,7 @@ detailIncludes = cake.path.join(env.expand('${CPPCORO}'), 'include', 'cppcoro',
privateHeaders = script.cwd([
'cancellation_state.hpp',
'socket_helpers.hpp',
])
sources = script.cwd([
@ -92,8 +94,16 @@ extras = script.cwd([
if variant.platform == "windows":
detailIncludes.extend(cake.path.join(env.expand('${CPPCORO}'), 'include', 'cppcoro', 'detail', [
'win32.hpp',
'win32_overlapped_operation.hpp',
'win32_overlapped_operation.hpp',
]))
netIncludes.extend(cake.path.join(env.expand('${CPPCORO}'), 'include', 'cppcoro', 'net', [
'socket.hpp',
'socket_accept_operation.hpp',
'socket_connect_operation.hpp',
'socket_disconnect_operation.hpp',
'socket_send_operation.hpp',
'socket_recv_operation.hpp',
]))
sources.extend(script.cwd([
'win32.cpp',
'io_service.cpp',
@ -105,6 +115,13 @@ if variant.platform == "windows":
'read_write_file.cpp',
'file_read_operation.cpp',
'file_write_operation.cpp',
'socket_helpers.cpp',
'socket.cpp',
'socket_accept_operation.cpp',
'socket_connect_operation.cpp',
'socket_disconnect_operation.cpp',
'socket_send_operation.cpp',
'socket_recv_operation.cpp',
]))
buildDir = env.expand('${CPPCORO_BUILD}')

View file

@ -12,8 +12,10 @@
#include <thread>
#if CPPCORO_OS_WINNT
# define WIN32_LEAN_AND_MEAN
# define NOMINMAX
# include <WinSock2.h>
# include <WS2tcpip.h>
# include <MSWSock.h>
# include <Windows.h>
#endif
@ -326,6 +328,8 @@ cppcoro::io_service::io_service(std::uint32_t concurrencyHint)
, m_workCount(0)
#if CPPCORO_OS_WINNT
, m_iocpHandle(create_io_completion_port(concurrencyHint))
, m_winsockInitialised(false)
, m_winsockInitialisationMutex()
#endif
, m_scheduleOperations(nullptr)
, m_timerState(nullptr)
@ -338,6 +342,15 @@ cppcoro::io_service::~io_service()
assert(m_threadState.load(std::memory_order_relaxed) < active_thread_count_increment);
delete m_timerState.load(std::memory_order_relaxed);
#if CPPCORO_OS_WINNT
if (m_winsockInitialised.load(std::memory_order_relaxed))
{
// TODO: Should we be checking return-code here?
// Don't want to throw from the destructor, so perhaps just log an error?
(void)::WSACleanup();
}
#endif
}
cppcoro::io_service::schedule_operation cppcoro::io_service::schedule() noexcept
@ -458,6 +471,34 @@ cppcoro::detail::win32::handle_t cppcoro::io_service::native_iocp_handle() noexc
return m_iocpHandle.handle();
}
#if CPPCORO_OS_WINNT
void cppcoro::io_service::ensure_winsock_initialised()
{
if (!m_winsockInitialised.load(std::memory_order_acquire))
{
std::lock_guard<std::mutex> lock(m_winsockInitialisationMutex);
if (!m_winsockInitialised.load(std::memory_order_acquire))
{
const WORD requestedVersion = MAKEWORD(2, 2);
WSADATA winsockData;
const int result = ::WSAStartup(requestedVersion, &winsockData);
if (result == SOCKET_ERROR)
{
const int errorCode = ::WSAGetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"Error initialsing winsock: WSAStartup");
}
m_winsockInitialised.store(true, std::memory_order_release);
}
}
}
#endif // CPPCORO_OS_WINNT
void cppcoro::io_service::schedule_impl(schedule_operation* operation) noexcept
{
#if CPPCORO_OS_WINNT

469
lib/socket.cpp Normal file
View file

@ -0,0 +1,469 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#include <cppcoro/net/socket.hpp>
#include <cppcoro/net/socket_accept_operation.hpp>
#include <cppcoro/net/socket_connect_operation.hpp>
#include <cppcoro/net/socket_disconnect_operation.hpp>
#include <cppcoro/net/socket_recv_operation.hpp>
#include <cppcoro/net/socket_send_operation.hpp>
#include <cppcoro/io_service.hpp>
#include <cppcoro/on_scope_exit.hpp>
#include "socket_helpers.hpp"
#if CPPCORO_OS_WINNT
# include <WinSock2.h>
# include <WS2tcpip.h>
# include <MSWSock.h>
# include <Windows.h>
namespace
{
namespace local
{
std::tuple<SOCKET, bool> create_socket(
int addressFamily,
int socketType,
int protocol,
HANDLE ioCompletionPort)
{
// Enumerate available protocol providers for the specified socket type.
WSAPROTOCOL_INFOW stackInfos[4];
std::unique_ptr<WSAPROTOCOL_INFOW[]> heapInfos;
WSAPROTOCOL_INFOW* selectedProtocolInfo = nullptr;
{
INT protocols[] = { protocol, 0 };
DWORD bufferSize = sizeof(stackInfos);
WSAPROTOCOL_INFOW* infos = stackInfos;
int protocolCount = ::WSAEnumProtocolsW(protocols, infos, &bufferSize);
if (protocolCount == SOCKET_ERROR)
{
int errorCode = ::WSAGetLastError();
if (errorCode == WSAENOBUFS)
{
DWORD requiredElementCount = bufferSize / sizeof(WSAPROTOCOL_INFOW);
heapInfos = std::make_unique<WSAPROTOCOL_INFOW[]>(requiredElementCount);
bufferSize = requiredElementCount * sizeof(WSAPROTOCOL_INFOW);
infos = heapInfos.get();
protocolCount = ::WSAEnumProtocolsW(protocols, infos, &bufferSize);
if (protocolCount == SOCKET_ERROR)
{
errorCode = ::WSAGetLastError();
}
}
if (protocolCount == SOCKET_ERROR)
{
throw std::system_error(
errorCode,
std::system_category(),
"Error creating socket: WSAEnumProtocolsW");
}
}
if (protocolCount == 0)
{
throw std::system_error(
std::make_error_code(std::errc::protocol_not_supported));
}
for (int i = 0; i < protocolCount; ++i)
{
auto& info = infos[i];
if (info.iAddressFamily == addressFamily && info.iProtocol == protocol && info.iSocketType == socketType)
{
selectedProtocolInfo = &info;
break;
}
}
if (selectedProtocolInfo == nullptr)
{
throw std::system_error(
std::make_error_code(std::errc::address_family_not_supported));
}
}
// WSA_FLAG_NO_HANDLE_INHERIT for SDKs earlier than Windows 7.
constexpr DWORD flagNoInherit = 0x80;
const DWORD flags = WSA_FLAG_OVERLAPPED | flagNoInherit;
const SOCKET socketHandle = ::WSASocketW(
addressFamily, socketType, protocol, selectedProtocolInfo, 0, flags);
if (socketHandle == INVALID_SOCKET)
{
const int errorCode = ::WSAGetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"Error creating socket: WSASocketW");
}
auto closeSocketOnFailure = cppcoro::on_scope_failure([&]
{
::closesocket(socketHandle);
});
// This is needed on operating systems earlier than Windows 7 to prevent
// socket handles from being inherited. On Windows 7 or later this is
// redundant as the WSA_FLAG_NO_HANDLE_INHERIT flag passed to creation
// above causes the socket to be atomically created with this flag cleared.
if (!::SetHandleInformation((HANDLE)socketHandle, HANDLE_FLAG_INHERIT, 0))
{
const DWORD errorCode = ::GetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"Error creating socket: SetHandleInformation");
}
// Associate the socket with the I/O completion port.
{
const HANDLE result = ::CreateIoCompletionPort(
(HANDLE)socketHandle,
ioCompletionPort,
ULONG_PTR(0),
DWORD(0));
if (result == nullptr)
{
const DWORD errorCode = ::GetLastError();
throw std::system_error(
static_cast<int>(errorCode),
std::system_category(),
"Error creating socket: CreateIoCompletionPort");
}
}
const bool skipCompletionPortOnSuccess =
(selectedProtocolInfo->dwServiceFlags1 & XP1_IFS_HANDLES) != 0;
{
UCHAR completionModeFlags = FILE_SKIP_SET_EVENT_ON_HANDLE;
if (skipCompletionPortOnSuccess)
{
completionModeFlags |= FILE_SKIP_COMPLETION_PORT_ON_SUCCESS;
}
const BOOL ok = ::SetFileCompletionNotificationModes(
(HANDLE)socketHandle,
completionModeFlags);
if (!ok)
{
const DWORD errorCode = ::GetLastError();
throw std::system_error(
static_cast<int>(errorCode),
std::system_category(),
"Error creating socket: SetFileCompletionNotificationModes");
}
}
if (socketType == SOCK_STREAM)
{
// Turn off linger so that the destructor doesn't block while closing
// the socket or silently continue to flush remaining data in the
// background after ::closesocket() is called, which could fail and
// we'd never know about it.
// We expect clients to call Disconnect() or use CloseSend() to cleanly
// shut-down connections instead.
BOOL value = TRUE;
const int result = ::setsockopt(socketHandle,
SOL_SOCKET,
SO_DONTLINGER,
reinterpret_cast<const char*>(&value),
sizeof(value));
if (result == SOCKET_ERROR)
{
const int errorCode = ::WSAGetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"Error creating socket: setsockopt(SO_DONTLINGER)");
}
}
return std::make_tuple(socketHandle, skipCompletionPortOnSuccess);
}
}
}
cppcoro::net::socket cppcoro::net::socket::create_tcpv4(io_service& ioSvc)
{
ioSvc.ensure_winsock_initialised();
auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket(
AF_INET, SOCK_STREAM, IPPROTO_TCP, ioSvc.native_iocp_handle());
socket result(socketHandle, skipCompletionPortOnSuccess);
result.m_localEndPoint = ipv4_endpoint();
result.m_remoteEndPoint = ipv4_endpoint();
return result;
}
cppcoro::net::socket cppcoro::net::socket::create_tcpv6(io_service& ioSvc)
{
ioSvc.ensure_winsock_initialised();
auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket(
AF_INET6, SOCK_STREAM, IPPROTO_TCP, ioSvc.native_iocp_handle());
socket result(socketHandle, skipCompletionPortOnSuccess);
result.m_localEndPoint = ipv6_endpoint();
result.m_remoteEndPoint = ipv6_endpoint();
return result;
}
cppcoro::net::socket cppcoro::net::socket::create_udpv4(io_service& ioSvc)
{
ioSvc.ensure_winsock_initialised();
auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket(
AF_INET, SOCK_DGRAM, IPPROTO_UDP, ioSvc.native_iocp_handle());
socket result(socketHandle, skipCompletionPortOnSuccess);
result.m_localEndPoint = ipv4_endpoint();
result.m_remoteEndPoint = ipv4_endpoint();
return result;
}
cppcoro::net::socket cppcoro::net::socket::create_udpv6(io_service& ioSvc)
{
ioSvc.ensure_winsock_initialised();
auto[socketHandle, skipCompletionPortOnSuccess] = local::create_socket(
AF_INET6, SOCK_DGRAM, IPPROTO_UDP, ioSvc.native_iocp_handle());
socket result(socketHandle, skipCompletionPortOnSuccess);
result.m_localEndPoint = ipv6_endpoint();
result.m_remoteEndPoint = ipv6_endpoint();
return result;
}
cppcoro::net::socket::socket(socket&& other) noexcept
: m_handle(std::exchange(other.m_handle, INVALID_SOCKET))
, m_skipCompletionOnSuccess(other.m_skipCompletionOnSuccess)
, m_localEndPoint(std::move(other.m_localEndPoint))
, m_remoteEndPoint(std::move(other.m_remoteEndPoint))
{}
cppcoro::net::socket::~socket()
{
if (m_handle != INVALID_SOCKET)
{
::closesocket(m_handle);
}
}
cppcoro::net::socket&
cppcoro::net::socket::operator=(socket&& other) noexcept
{
auto handle = std::exchange(other.m_handle, INVALID_SOCKET);
if (m_handle != INVALID_SOCKET)
{
::closesocket(m_handle);
}
m_handle = handle;
m_skipCompletionOnSuccess = other.m_skipCompletionOnSuccess;
m_localEndPoint = other.m_localEndPoint;
m_remoteEndPoint = other.m_remoteEndPoint;
return *this;
}
void cppcoro::net::socket::bind(const ip_endpoint& localEndPoint)
{
SOCKADDR_STORAGE sockaddrStorage = { 0 };
SOCKADDR* sockaddr = reinterpret_cast<SOCKADDR*>(&sockaddrStorage);
if (localEndPoint.is_ipv4())
{
SOCKADDR_IN& ipv4Sockaddr = *reinterpret_cast<SOCKADDR_IN*>(sockaddr);
ipv4Sockaddr.sin_family = AF_INET;
std::memcpy(&ipv4Sockaddr.sin_addr, localEndPoint.to_ipv4().address().bytes(), 4);
ipv4Sockaddr.sin_port = localEndPoint.to_ipv4().port();
}
else
{
SOCKADDR_IN6& ipv6Sockaddr = *reinterpret_cast<SOCKADDR_IN6*>(sockaddr);
ipv6Sockaddr.sin6_family = AF_INET6;
std::memcpy(&ipv6Sockaddr.sin6_addr, localEndPoint.to_ipv6().address().bytes(), 16);
ipv6Sockaddr.sin6_port = localEndPoint.to_ipv6().port();
}
int result = ::bind(m_handle, sockaddr, sizeof(sockaddrStorage));
if (result != 0)
{
// WSANOTINITIALISED: WSAStartup not called
// WSAENETDOWN: network subsystem failed
// WSAEACCES: access denied
// WSAEADDRINUSE: port in use
// WSAEADDRNOTAVAIL: address is not an address that can be bound to
// WSAEFAULT: invalid pointer passed to bind()
// WSAEINPROGRESS: a callback is in progress
// WSAEINVAL: socket already bound
// WSAENOBUFS: system failed to allocate memory
// WSAENOTSOCK: socket was not a valid socket.
int errorCode = ::WSAGetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"Error binding to endpoint: bind()");
}
int sockaddrLen = sizeof(sockaddrStorage);
result = ::getsockname(m_handle, sockaddr, &sockaddrLen);
if (result == 0)
{
m_localEndPoint = cppcoro::net::detail::sockaddr_to_ip_endpoint(*sockaddr);
}
else
{
m_localEndPoint = localEndPoint;
}
}
void cppcoro::net::socket::listen()
{
int result = ::listen(m_handle, SOMAXCONN);
if (result != 0)
{
int errorCode = ::WSAGetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"Failed to start listening on bound endpoint: listen");
}
}
void cppcoro::net::socket::listen(std::uint32_t backlog)
{
if (backlog > 0x7FFFFFFF)
{
backlog = 0x7FFFFFFF;
}
int result = ::listen(m_handle, (int)backlog);
if (result != 0)
{
// WSANOTINITIALISED: WSAStartup not called
// WSAENETDOWN: network subsystem failed
// WSAEADDRINUSE: port in use
// WSAEINPROGRESS: a callback is in progress
// WSAEINVAL: socket not yet bound
// WSAEISCONN: socket already connected
// WSAEMFILE: no more socket descriptors available
// WSAENOBUFS: system failed to allocate memory
// WSAENOTSOCK: socket was not a valid socket.
// WSAEOPNOTSUPP: The socket does not support listening
int errorCode = ::WSAGetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"Failed to start listening on bound endpoint: listen");
}
}
cppcoro::net::socket_accept_operation
cppcoro::net::socket::accept(socket& acceptingSocket)
{
return socket_accept_operation{ *this, acceptingSocket };
}
cppcoro::net::socket_accept_operation_cancellable
cppcoro::net::socket::accept(socket& acceptingSocket, cancellation_token ct)
{
return socket_accept_operation_cancellable{ *this, acceptingSocket, std::move(ct) };
}
cppcoro::net::socket_connect_operation
cppcoro::net::socket::connect(const ip_endpoint& remoteEndPoint)
{
return socket_connect_operation{ *this, remoteEndPoint };
}
cppcoro::net::socket_connect_operation_cancellable
cppcoro::net::socket::connect(const ip_endpoint& remoteEndPoint, cancellation_token ct)
{
return socket_connect_operation_cancellable{ *this, remoteEndPoint, std::move(ct) };
}
cppcoro::net::socket_disconnect_operation
cppcoro::net::socket::disconnect()
{
return socket_disconnect_operation(*this);
}
cppcoro::net::socket_disconnect_operation_cancellable
cppcoro::net::socket::disconnect(cancellation_token ct)
{
return socket_disconnect_operation_cancellable{ *this, std::move(ct) };
}
cppcoro::net::socket_send_operation
cppcoro::net::socket::send(const void* buffer, std::size_t byteCount)
{
return socket_send_operation{ *this, buffer, byteCount };
}
cppcoro::net::socket_send_operation_cancellable
cppcoro::net::socket::send(const void* buffer, std::size_t byteCount, cancellation_token ct)
{
return socket_send_operation_cancellable{ *this, buffer, byteCount, std::move(ct) };
}
cppcoro::net::socket_recv_operation
cppcoro::net::socket::recv(void* buffer, std::size_t byteCount)
{
return socket_recv_operation{ *this, buffer, byteCount };
}
cppcoro::net::socket_recv_operation_cancellable
cppcoro::net::socket::recv(void* buffer, std::size_t byteCount, cancellation_token ct)
{
return socket_recv_operation_cancellable{ *this, buffer, byteCount, std::move(ct) };
}
void cppcoro::net::socket::close_send()
{
int result = ::shutdown(m_handle, SD_SEND);
if (result == SOCKET_ERROR)
{
int errorCode = ::WSAGetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"failed to close socket send stream: shutdown(SD_SEND)");
}
}
void cppcoro::net::socket::close_recv()
{
int result = ::shutdown(m_handle, SD_RECEIVE);
if (result == SOCKET_ERROR)
{
int errorCode = ::WSAGetLastError();
throw std::system_error(
errorCode,
std::system_category(),
"failed to close socket receive stream: shutdown(SD_RECEIVE)");
}
}
cppcoro::net::socket::socket(
cppcoro::detail::win32::socket_t handle,
bool skipCompletionOnSuccess) noexcept
: m_handle(handle)
, m_skipCompletionOnSuccess(skipCompletionOnSuccess)
{
}
#endif

View file

@ -0,0 +1,212 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#include <cppcoro/net/socket_accept_operation.hpp>
#include <cppcoro/net/socket.hpp>
#include "socket_helpers.hpp"
#include <system_error>
#if CPPCORO_OS_WINNT
# include <WinSock2.h>
# include <WS2tcpip.h>
# include <MSWSock.h>
# include <Windows.h>
// TODO: Eliminate duplication of implementation between socket_accept_operation
// and socket_accept_operation_cancellable.
bool cppcoro::net::socket_accept_operation::try_start() noexcept
{
static_assert(
(sizeof(m_addressBuffer) / 2) >= (16 + sizeof(SOCKADDR_IN)) &&
(sizeof(m_addressBuffer) / 2) >= (16 + sizeof(SOCKADDR_IN6)),
"AcceptEx requires address buffer to be at least 16 bytes more than largest address.");
DWORD bytesReceived = 0;
BOOL ok = ::AcceptEx(
m_listeningSocket.native_handle(),
m_acceptingSocket.native_handle(),
m_addressBuffer,
0,
sizeof(m_addressBuffer) / 2,
sizeof(m_addressBuffer) / 2,
&bytesReceived,
get_overlapped());
if (!ok)
{
int errorCode = ::WSAGetLastError();
if (errorCode != ERROR_IO_PENDING)
{
m_errorCode = static_cast<DWORD>(errorCode);
return false;
}
}
else if (m_listeningSocket.m_skipCompletionOnSuccess)
{
m_errorCode = ERROR_SUCCESS;
return false;
}
return true;
}
void cppcoro::net::socket_accept_operation::get_result()
{
if (m_errorCode != ERROR_SUCCESS)
{
throw std::system_error{
static_cast<int>(m_errorCode),
std::system_category(),
"Accepting a connection failed: AcceptEx"
};
}
sockaddr* localSockaddr = nullptr;
sockaddr* remoteSockaddr = nullptr;
INT localSockaddrLength;
INT remoteSockaddrLength;
::GetAcceptExSockaddrs(
m_addressBuffer,
0,
sizeof(m_addressBuffer) / 2,
sizeof(m_addressBuffer) / 2,
&localSockaddr,
&localSockaddrLength,
&remoteSockaddr,
&remoteSockaddrLength);
m_acceptingSocket.m_localEndPoint =
detail::sockaddr_to_ip_endpoint(*localSockaddr);
m_acceptingSocket.m_remoteEndPoint =
detail::sockaddr_to_ip_endpoint(*remoteSockaddr);
{
// Need to set SO_UPDATE_ACCEPT_CONTEXT after the accept completes
// to ensure that ::shutdown() and ::setsockopt() calls work on the
// accepted socket.
SOCKET listenSocket = m_listeningSocket.native_handle();
const int result = ::setsockopt(
m_acceptingSocket.native_handle(),
SOL_SOCKET,
SO_UPDATE_ACCEPT_CONTEXT,
(const char*)&listenSocket,
sizeof(SOCKET));
if (result == SOCKET_ERROR)
{
const int errorCode = ::WSAGetLastError();
throw std::system_error{
errorCode,
std::system_category(),
"Socket accept operation failed: setsockopt(SO_UPDATE_ACCEPT_CONTEXT)"
};
}
}
}
bool cppcoro::net::socket_accept_operation_cancellable::try_start() noexcept
{
static_assert(
(sizeof(m_addressBuffer) / 2) >= (16 + sizeof(IN6_ADDR)),
"AcceptEx requires address buffer to be at least 16 bytes more than largest address.");
DWORD bytesReceived = 0;
BOOL ok = ::AcceptEx(
m_listeningSocket.native_handle(),
m_acceptingSocket.native_handle(),
m_addressBuffer,
0,
sizeof(m_addressBuffer) / 2,
sizeof(m_addressBuffer) / 2,
&bytesReceived,
get_overlapped());
if (!ok)
{
int errorCode = ::WSAGetLastError();
if (errorCode != ERROR_IO_PENDING)
{
m_errorCode = static_cast<DWORD>(errorCode);
return false;
}
}
else if (m_listeningSocket.m_skipCompletionOnSuccess)
{
m_errorCode = ERROR_SUCCESS;
return false;
}
return true;
}
void cppcoro::net::socket_accept_operation_cancellable::cancel() noexcept
{
(void)::CancelIoEx(
reinterpret_cast<HANDLE>(m_listeningSocket.native_handle()),
get_overlapped());
}
void cppcoro::net::socket_accept_operation_cancellable::get_result()
{
if (m_errorCode != ERROR_SUCCESS)
{
throw std::system_error{
static_cast<int>(m_errorCode),
std::system_category(),
"Accepting a connection failed: AcceptEx"
};
}
sockaddr* localSockaddr = nullptr;
sockaddr* remoteSockaddr = nullptr;
INT localSockaddrLength;
INT remoteSockaddrLength;
::GetAcceptExSockaddrs(
m_addressBuffer,
0,
sizeof(m_addressBuffer) / 2,
sizeof(m_addressBuffer) / 2,
&localSockaddr,
&localSockaddrLength,
&remoteSockaddr,
&remoteSockaddrLength);
m_acceptingSocket.m_localEndPoint =
detail::sockaddr_to_ip_endpoint(*localSockaddr);
m_acceptingSocket.m_remoteEndPoint =
detail::sockaddr_to_ip_endpoint(*remoteSockaddr);
{
// Need to set SO_UPDATE_ACCEPT_CONTEXT after the accept completes
// to ensure that ::shutdown() and ::setsockopt() calls work on the
// accepted socket.
SOCKET listenSocket = m_listeningSocket.native_handle();
const int result = ::setsockopt(
m_acceptingSocket.native_handle(),
SOL_SOCKET,
SO_UPDATE_ACCEPT_CONTEXT,
(const char*)&listenSocket,
sizeof(SOCKET));
if (result == SOCKET_ERROR)
{
const int errorCode = ::WSAGetLastError();
throw std::system_error{
errorCode,
std::system_category(),
"Socket accept operation failed: setsockopt(SO_UPDATE_ACCEPT_CONTEXT)"
};
}
}
}
#endif

View file

@ -0,0 +1,307 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#include <cppcoro/net/socket_connect_operation.hpp>
#include <cppcoro/net/socket.hpp>
#include "socket_helpers.hpp"
#include <cassert>
#include <system_error>
#if CPPCORO_OS_WINNT
# include <WinSock2.h>
# include <WS2tcpip.h>
# include <MSWSock.h>
# include <Windows.h>
bool cppcoro::net::socket_connect_operation::try_start() noexcept
{
// Lookup the address of the ConnectEx function pointer for this socket.
LPFN_CONNECTEX connectExPtr;
{
GUID connectExGuid = WSAID_CONNECTEX;
DWORD byteCount = 0;
int result = ::WSAIoctl(
m_socket.native_handle(),
SIO_GET_EXTENSION_FUNCTION_POINTER,
static_cast<void*>(&connectExGuid),
sizeof(connectExGuid),
static_cast<void*>(&connectExPtr),
sizeof(connectExPtr),
&byteCount,
nullptr,
nullptr);
if (result == SOCKET_ERROR)
{
m_errorCode = ::WSAGetLastError();
return false;
}
}
SOCKADDR_STORAGE remoteSockaddrStorage;
const int sockaddrNameLength = cppcoro::net::detail::ip_endpoint_to_sockaddr(
m_remoteEndPoint,
std::ref(remoteSockaddrStorage));
DWORD bytesSent = 0;
const BOOL ok = connectExPtr(
m_socket.native_handle(),
reinterpret_cast<const SOCKADDR*>(&remoteSockaddrStorage),
sockaddrNameLength,
nullptr, // send buffer
0, // size of send buffer
&bytesSent,
get_overlapped());
if (!ok)
{
const int errorCode = ::WSAGetLastError();
if (errorCode != ERROR_IO_PENDING)
{
// Failed synchronously.
m_errorCode = static_cast<DWORD>(errorCode);
return false;
}
}
else if (m_socket.m_skipCompletionOnSuccess)
{
// Successfully completed synchronously and no completion event
// will be posted to an I/O thread so we can return without suspending.
m_errorCode = ERROR_SUCCESS;
return false;
}
return true;
}
void cppcoro::net::socket_connect_operation::get_result()
{
if (m_errorCode != ERROR_SUCCESS)
{
if (m_errorCode == ERROR_OPERATION_ABORTED)
{
throw operation_cancelled{};
}
throw std::system_error{
static_cast<int>(m_errorCode),
std::system_category(),
"Connect operation failed: ConnectEx"
};
}
// We need to call setsockopt() to update the socket state with information
// about the connection now that it has been successfully connected.
{
const int result = ::setsockopt(
m_socket.native_handle(),
SOL_SOCKET,
SO_UPDATE_CONNECT_CONTEXT,
nullptr,
0);
if (result == SOCKET_ERROR)
{
// This shouldn't fail, but just in case it does we fall back to
// setting the remote address as specified in the call to Connect().
//
// Don't really want to throw an exception here since the connection
// has actually been established.
m_socket.m_remoteEndPoint = m_remoteEndPoint;
return;
}
}
{
SOCKADDR_STORAGE localSockaddr;
int nameLength = sizeof(localSockaddr);
const int result = ::getsockname(
m_socket.native_handle(),
reinterpret_cast<SOCKADDR*>(&localSockaddr),
&nameLength);
if (result == 0)
{
m_socket.m_localEndPoint = cppcoro::net::detail::sockaddr_to_ip_endpoint(
*reinterpret_cast<const SOCKADDR*>(&localSockaddr));
}
else
{
// Failed to get the updated local-end-point
// Just leave m_localEndPoint set to whatever bind() left it as.
//
// TODO: Should we be throwing an exception here instead?
}
}
{
SOCKADDR_STORAGE remoteSockaddr;
int nameLength = sizeof(remoteSockaddr);
const int result = ::getpeername(
m_socket.native_handle(),
reinterpret_cast<SOCKADDR*>(&remoteSockaddr),
&nameLength);
if (result == 0)
{
m_socket.m_remoteEndPoint = cppcoro::net::detail::sockaddr_to_ip_endpoint(
*reinterpret_cast<const SOCKADDR*>(&remoteSockaddr));
}
else
{
// Failed to get the actual remote end-point so just fall back to
// remembering the actual end-point that was passed to connect().
//
// TODO: Should we be throwing an exception here instead?
m_socket.m_remoteEndPoint = m_remoteEndPoint;
}
}
}
bool cppcoro::net::socket_connect_operation_cancellable::try_start() noexcept
{
// Lookup the address of the ConnectEx function pointer for this socket.
LPFN_CONNECTEX connectExPtr;
{
GUID connectExGuid = WSAID_CONNECTEX;
DWORD byteCount = 0;
int result = ::WSAIoctl(
m_socket.native_handle(),
SIO_GET_EXTENSION_FUNCTION_POINTER,
static_cast<void*>(&connectExGuid),
sizeof(connectExGuid),
static_cast<void*>(&connectExPtr),
sizeof(connectExPtr),
&byteCount,
nullptr,
nullptr);
if (result == SOCKET_ERROR)
{
m_errorCode = ::WSAGetLastError();
return false;
}
}
SOCKADDR_STORAGE remoteSockaddrStorage;
const int sockaddrNameLength = cppcoro::net::detail::ip_endpoint_to_sockaddr(
m_remoteEndPoint,
std::ref(remoteSockaddrStorage));
DWORD bytesSent = 0;
const BOOL ok = connectExPtr(
m_socket.native_handle(),
reinterpret_cast<const SOCKADDR*>(&remoteSockaddrStorage),
sockaddrNameLength,
nullptr, // send buffer
0, // size of send buffer
&bytesSent,
get_overlapped());
if (!ok)
{
const int errorCode = ::WSAGetLastError();
if (errorCode != ERROR_IO_PENDING)
{
// Failed synchronously.
m_errorCode = static_cast<DWORD>(errorCode);
return false;
}
}
else if (m_socket.m_skipCompletionOnSuccess)
{
// Successfully completed synchronously and no completion event
// will be posted to an I/O thread so we can return without suspending.
m_errorCode = ERROR_SUCCESS;
return false;
}
return true;
}
void cppcoro::net::socket_connect_operation_cancellable::get_result()
{
if (m_errorCode != ERROR_SUCCESS)
{
if (m_errorCode == ERROR_OPERATION_ABORTED)
{
throw operation_cancelled{};
}
throw std::system_error{
static_cast<int>(m_errorCode),
std::system_category(),
"Connect operation failed: ConnectEx"
};
}
// We need to call setsockopt() to update the socket state with information
// about the connection now that it has been successfully connected.
{
const int result = ::setsockopt(
m_socket.native_handle(),
SOL_SOCKET,
SO_UPDATE_CONNECT_CONTEXT,
nullptr,
0);
if (result == SOCKET_ERROR)
{
// This shouldn't fail, but just in case it does we fall back to
// setting the remote address as specified in the call to Connect().
//
// Don't really want to throw an exception here since the connection
// has actually been established.
m_socket.m_remoteEndPoint = m_remoteEndPoint;
return;
}
}
{
SOCKADDR_STORAGE localSockaddr;
int nameLength = sizeof(localSockaddr);
const int result = ::getsockname(
m_socket.native_handle(),
reinterpret_cast<SOCKADDR*>(&localSockaddr),
&nameLength);
if (result == 0)
{
m_socket.m_localEndPoint = cppcoro::net::detail::sockaddr_to_ip_endpoint(
*reinterpret_cast<const SOCKADDR*>(&localSockaddr));
}
else
{
// Failed to get the updated local-end-point
// Just leave m_localEndPoint set to whatever bind() left it as.
//
// TODO: Should we be throwing an exception here instead?
}
}
{
SOCKADDR_STORAGE remoteSockaddr;
int nameLength = sizeof(remoteSockaddr);
const int result = ::getpeername(
m_socket.native_handle(),
reinterpret_cast<SOCKADDR*>(&remoteSockaddr),
&nameLength);
if (result == 0)
{
m_socket.m_remoteEndPoint = cppcoro::net::detail::sockaddr_to_ip_endpoint(
*reinterpret_cast<const SOCKADDR*>(&remoteSockaddr));
}
else
{
// Failed to get the actual remote end-point so just fall back to
// remembering the actual end-point that was passed to connect().
//
// TODO: Should we be throwing an exception here instead?
m_socket.m_remoteEndPoint = m_remoteEndPoint;
}
}
}
void cppcoro::net::socket_connect_operation_cancellable::cancel() noexcept
{
(void)::CancelIoEx(
reinterpret_cast<HANDLE>(m_socket.native_handle()),
get_overlapped());
}
#endif

View file

@ -0,0 +1,158 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#include <cppcoro/net/socket_disconnect_operation.hpp>
#include <cppcoro/net/socket.hpp>
#include "socket_helpers.hpp"
#include <system_error>
#if CPPCORO_OS_WINNT
# include <WinSock2.h>
# include <WS2tcpip.h>
# include <MSWSock.h>
# include <Windows.h>
bool cppcoro::net::socket_disconnect_operation::try_start() noexcept
{
// Lookup the address of the DisconnectEx function pointer for this socket.
LPFN_DISCONNECTEX disconnectExPtr;
{
GUID disconnectExGuid = WSAID_DISCONNECTEX;
DWORD byteCount = 0;
const int result = ::WSAIoctl(
m_socket.native_handle(),
SIO_GET_EXTENSION_FUNCTION_POINTER,
static_cast<void*>(&disconnectExGuid),
sizeof(disconnectExGuid),
static_cast<void*>(&disconnectExPtr),
sizeof(disconnectExPtr),
&byteCount,
nullptr,
nullptr);
if (result == SOCKET_ERROR)
{
m_errorCode = static_cast<DWORD>(::WSAGetLastError());
return false;
}
}
// Need to add TF_REUSE_SOCKET to these flags if we want to allow reusing
// a socket for subsequent connections once the disconnect operation
// completes.
const DWORD flags = 0;
const BOOL ok = disconnectExPtr(
m_socket.native_handle(),
get_overlapped(),
flags,
0);
if (!ok)
{
const int errorCode = ::WSAGetLastError();
if (errorCode != ERROR_IO_PENDING)
{
// Failed synchronously.
m_errorCode = static_cast<DWORD>(errorCode);
return false;
}
}
else if (m_socket.m_skipCompletionOnSuccess)
{
// Successfully completed synchronously and no completion event
// will be posted to an I/O thread so we can return without suspending.
m_errorCode = ERROR_SUCCESS;
return false;
}
return true;
}
void cppcoro::net::socket_disconnect_operation::get_result()
{
if (m_errorCode != ERROR_SUCCESS)
{
throw std::system_error(
static_cast<int>(m_errorCode),
std::system_category(),
"Disconnect operation failed: DisconnectEx");
}
}
bool cppcoro::net::socket_disconnect_operation_cancellable::try_start() noexcept
{
// Lookup the address of the DisconnectEx function pointer for this socket.
LPFN_DISCONNECTEX disconnectExPtr;
{
GUID disconnectExGuid = WSAID_DISCONNECTEX;
DWORD byteCount = 0;
const int result = ::WSAIoctl(
m_socket.native_handle(),
SIO_GET_EXTENSION_FUNCTION_POINTER,
static_cast<void*>(&disconnectExGuid),
sizeof(disconnectExGuid),
static_cast<void*>(&disconnectExPtr),
sizeof(disconnectExPtr),
&byteCount,
nullptr,
nullptr);
if (result == SOCKET_ERROR)
{
m_errorCode = static_cast<DWORD>(::WSAGetLastError());
return false;
}
}
// Need to add TF_REUSE_SOCKET to these flags if we want to allow reusing
// a socket for subsequent connections once the disconnect operation
// completes.
const DWORD flags = 0;
const BOOL ok = disconnectExPtr(
m_socket.native_handle(),
get_overlapped(),
flags,
0);
if (!ok)
{
const int errorCode = ::WSAGetLastError();
if (errorCode != ERROR_IO_PENDING)
{
// Failed synchronously.
m_errorCode = static_cast<DWORD>(errorCode);
return false;
}
}
else if (m_socket.m_skipCompletionOnSuccess)
{
// Successfully completed synchronously and no completion event
// will be posted to an I/O thread so we can return without suspending.
m_errorCode = ERROR_SUCCESS;
return false;
}
return true;
}
void cppcoro::net::socket_disconnect_operation_cancellable::cancel() noexcept
{
(void)::CancelIoEx(
reinterpret_cast<HANDLE>(m_socket.native_handle()),
get_overlapped());
}
void cppcoro::net::socket_disconnect_operation_cancellable::get_result()
{
if (m_errorCode != ERROR_SUCCESS)
{
throw std::system_error(
static_cast<int>(m_errorCode),
std::system_category(),
"Disconnect operation failed: DisconnectEx");
}
}
#endif

75
lib/socket_helpers.cpp Normal file
View file

@ -0,0 +1,75 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#include "socket_helpers.hpp"
#include <cppcoro/net/ip_endpoint.hpp>
#if CPPCORO_OS_WINNT
#include <cstring>
#include <cassert>
#include <WinSock2.h>
#include <WS2tcpip.h>
#include <MSWSock.h>
#include <Windows.h>
cppcoro::net::ip_endpoint
cppcoro::net::detail::sockaddr_to_ip_endpoint(const sockaddr& address) noexcept
{
if (address.sa_family == AF_INET)
{
const auto& ipv4Address = *reinterpret_cast<const sockaddr_in*>(&address);
std::uint8_t addressBytes[4];
std::memcpy(addressBytes, &ipv4Address.sin_addr, 4);
return ipv4_endpoint{
ipv4_address{ addressBytes },
ipv4Address.sin_port
};
}
else
{
assert(address.sa_family == AF_INET6);
const auto& ipv6Address = *reinterpret_cast<const sockaddr_in6*>(&address);
return ipv6_endpoint{
ipv6_address{ ipv6Address.sin6_addr.u.Byte },
ipv6Address.sin6_port
};
}
}
int cppcoro::net::detail::ip_endpoint_to_sockaddr(
const ip_endpoint& endPoint,
std::reference_wrapper<sockaddr_storage> address) noexcept
{
if (endPoint.is_ipv4())
{
const auto& ipv4EndPoint = endPoint.to_ipv4();
auto& ipv4Address = *reinterpret_cast<SOCKADDR_IN*>(&address.get());
ipv4Address.sin_family = AF_INET;
std::memcpy(&ipv4Address.sin_addr, ipv4EndPoint.address().bytes(), 4);
ipv4Address.sin_port = ipv4EndPoint.port();
std::memset(&ipv4Address.sin_zero, 0, sizeof(ipv4Address.sin_zero));
return sizeof(SOCKADDR_IN);
}
else
{
const auto& ipv6EndPoint = endPoint.to_ipv6();
auto& ipv6Address = *reinterpret_cast<SOCKADDR_IN6*>(&address.get());
ipv6Address.sin6_family = AF_INET6;
std::memcpy(&ipv6Address.sin6_addr, ipv6EndPoint.address().bytes(), 16);
ipv6Address.sin6_port = ipv6EndPoint.port();
ipv6Address.sin6_flowinfo = 0;
ipv6Address.sin6_scope_struct = SCOPEID_UNSPECIFIED_INIT;
return sizeof(SOCKADDR_IN6);
}
}
#endif // CPPCORO_OS_WINNT

47
lib/socket_helpers.hpp Normal file
View file

@ -0,0 +1,47 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#ifndef CPPCORO_PRIVATE_SOCKET_HELPERS_HPP_INCLUDED
#define CPPCORO_PRIVATE_SOCKET_HELPERS_HPP_INCLUDED
#include <cppcoro/config.hpp>
#if CPPCORO_OS_WINNT
# include <cppcoro/detail/win32.hpp>
struct sockaddr;
struct sockaddr_storage;
#endif
namespace cppcoro
{
namespace net
{
class ip_endpoint;
namespace detail
{
#if CPPCORO_OS_WINNT
/// Convert a sockaddr to an IP endpoint.
ip_endpoint sockaddr_to_ip_endpoint(const sockaddr& address) noexcept;
/// Converts an ip_endpoint to a sockaddr structure.
///
/// \param endPoint
/// The IP endpoint to convert to a sockaddr structure.
///
/// \param address
/// The sockaddr structure to populate.
///
/// \return
/// The length of the sockaddr structure that was populated.
int ip_endpoint_to_sockaddr(
const ip_endpoint& endPoint,
std::reference_wrapper<sockaddr_storage> address) noexcept;
#endif
}
}
}
#endif

View file

@ -0,0 +1,115 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#include <cppcoro/net/socket_recv_operation.hpp>
#include <cppcoro/net/socket.hpp>
#if CPPCORO_OS_WINNT
# include <WinSock2.h>
# include <WS2tcpip.h>
# include <MSWSock.h>
# include <Windows.h>
cppcoro::net::socket_recv_operation::socket_recv_operation(
socket& s,
void* buffer,
std::size_t byteCount) noexcept
: m_socketHandle(s.native_handle())
, m_skipCompletionOnSuccess(s.skip_completion_on_success())
, m_buffer(buffer, byteCount)
{
}
bool cppcoro::net::socket_recv_operation::try_start() noexcept
{
DWORD numberOfBytesReceived = 0;
DWORD flags = 0;
int result = ::WSARecv(
m_socketHandle,
reinterpret_cast<WSABUF*>(&m_buffer),
1, // buffer count
&numberOfBytesReceived,
&flags,
get_overlapped(),
nullptr);
if (result == SOCKET_ERROR)
{
int errorCode = ::WSAGetLastError();
if (errorCode != WSA_IO_PENDING)
{
// Failed synchronously.
m_errorCode = static_cast<DWORD>(errorCode);
m_numberOfBytesTransferred = numberOfBytesReceived;
return false;
}
}
else if (m_skipCompletionOnSuccess)
{
// Completed synchronously, no completion event will be posted to the IOCP.
m_errorCode = ERROR_SUCCESS;
m_numberOfBytesTransferred = numberOfBytesReceived;
return false;
}
// Operation will complete asynchronously.
return true;
}
cppcoro::net::socket_recv_operation_cancellable::socket_recv_operation_cancellable(
socket& s,
void* buffer,
std::size_t byteCount,
cancellation_token&& ct) noexcept
: cppcoro::detail::win32_overlapped_operation_cancellable<cppcoro::net::socket_recv_operation_cancellable>(
std::move(ct))
, m_socketHandle(s.native_handle())
, m_skipCompletionOnSuccess(s.skip_completion_on_success())
, m_buffer(buffer, byteCount)
{
}
bool cppcoro::net::socket_recv_operation_cancellable::try_start() noexcept
{
DWORD numberOfBytesReceived = 0;
DWORD flags = 0;
int result = ::WSARecv(
m_socketHandle,
reinterpret_cast<WSABUF*>(&m_buffer),
1, // buffer count
&numberOfBytesReceived,
&flags,
get_overlapped(),
nullptr);
if (result == SOCKET_ERROR)
{
int errorCode = ::WSAGetLastError();
if (errorCode != WSA_IO_PENDING)
{
// Failed synchronously.
m_errorCode = static_cast<DWORD>(errorCode);
m_numberOfBytesTransferred = numberOfBytesReceived;
return false;
}
}
else if (m_skipCompletionOnSuccess)
{
// Completed synchronously, no completion event will be posted to the IOCP.
m_errorCode = ERROR_SUCCESS;
m_numberOfBytesTransferred = numberOfBytesReceived;
return false;
}
// Operation will complete asynchronously.
return true;
}
void cppcoro::net::socket_recv_operation_cancellable::cancel() noexcept
{
(void)::CancelIoEx(
reinterpret_cast<HANDLE>(m_socketHandle),
get_overlapped());
}
#endif

View file

@ -0,0 +1,113 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#include <cppcoro/net/socket_send_operation.hpp>
#include <cppcoro/net/socket.hpp>
#if CPPCORO_OS_WINNT
# include <WinSock2.h>
# include <WS2tcpip.h>
# include <MSWSock.h>
# include <Windows.h>
cppcoro::net::socket_send_operation::socket_send_operation(
socket& s,
const void* buffer,
std::size_t byteCount) noexcept
: m_socketHandle(s.native_handle())
, m_skipCompletionOnSuccess(s.skip_completion_on_success())
, m_buffer(const_cast<void*>(buffer), byteCount)
{
}
bool cppcoro::net::socket_send_operation::try_start() noexcept
{
DWORD numberOfBytesSent = 0;
int result = ::WSASend(
m_socketHandle,
reinterpret_cast<WSABUF*>(&m_buffer),
1, // buffer count
&numberOfBytesSent,
0, // flags
get_overlapped(),
nullptr);
if (result == SOCKET_ERROR)
{
int errorCode = ::WSAGetLastError();
if (errorCode != WSA_IO_PENDING)
{
// Failed synchronously.
m_errorCode = static_cast<DWORD>(errorCode);
m_numberOfBytesTransferred = numberOfBytesSent;
return false;
}
}
else if (m_skipCompletionOnSuccess)
{
// Completed synchronously, no completion event will be posted to the IOCP.
m_errorCode = ERROR_SUCCESS;
m_numberOfBytesTransferred = numberOfBytesSent;
return false;
}
// Operation will complete asynchronously.
return true;
}
cppcoro::net::socket_send_operation_cancellable::socket_send_operation_cancellable(
socket& s,
const void* buffer,
std::size_t byteCount,
cancellation_token&& ct) noexcept
: cppcoro::detail::win32_overlapped_operation_cancellable<cppcoro::net::socket_send_operation_cancellable>(
std::move(ct))
, m_socketHandle(s.native_handle())
, m_skipCompletionOnSuccess(s.skip_completion_on_success())
, m_buffer(const_cast<void*>(buffer), byteCount)
{
}
bool cppcoro::net::socket_send_operation_cancellable::try_start() noexcept
{
DWORD numberOfBytesSent = 0;
int result = ::WSASend(
m_socketHandle,
reinterpret_cast<WSABUF*>(&m_buffer),
1, // buffer count
&numberOfBytesSent,
0, // flags
get_overlapped(),
nullptr);
if (result == SOCKET_ERROR)
{
int errorCode = ::WSAGetLastError();
if (errorCode != WSA_IO_PENDING)
{
// Failed synchronously.
m_errorCode = static_cast<DWORD>(errorCode);
m_numberOfBytesTransferred = numberOfBytesSent;
return false;
}
}
else if (m_skipCompletionOnSuccess)
{
// Completed synchronously, no completion event will be posted to the IOCP.
m_errorCode = ERROR_SUCCESS;
m_numberOfBytesTransferred = numberOfBytesSent;
return false;
}
// Operation will complete asynchronously.
return true;
}
void cppcoro::net::socket_send_operation_cancellable::cancel() noexcept
{
(void)::CancelIoEx(
reinterpret_cast<HANDLE>(m_socketHandle),
get_overlapped());
}
#endif

View file

@ -15,4 +15,6 @@ compiler.addLibrary(buildScript.getResult('library'))
if variant.platform == "windows":
compiler.addLibrary("Synchronization")
compiler.addLibrary("kernel32")
compiler.addLibrary("WS2_32")
compiler.addLibrary("Mswsock")

View file

@ -46,6 +46,7 @@ if variant.platform == 'windows':
'scheduling_operator_tests.cpp',
'io_service_tests.cpp',
'file_tests.cpp',
'socket_tests.cpp',
])
extras = script.cwd([

206
test/socket_tests.cpp Normal file
View file

@ -0,0 +1,206 @@
///////////////////////////////////////////////////////////////////////////////
// Copyright (c) Lewis Baker
// Licenced under MIT license. See LICENSE.txt for details.
///////////////////////////////////////////////////////////////////////////////
#include <cppcoro/io_service.hpp>
#include <cppcoro/net/socket.hpp>
#include <cppcoro/task.hpp>
#include <cppcoro/when_all.hpp>
#include <cppcoro/sync_wait.hpp>
#include <cppcoro/on_scope_exit.hpp>
#include "doctest/doctest.h"
using namespace cppcoro;
using namespace cppcoro::net;
TEST_SUITE_BEGIN("socket");
TEST_CASE("create TCP/IPv4")
{
io_service ioSvc;
auto socket = socket::create_tcpv4(ioSvc);
}
TEST_CASE("create TCP/IPv6")
{
io_service ioSvc;
auto socket = socket::create_tcpv6(ioSvc);
}
TEST_CASE("create UDP/IPv4")
{
io_service ioSvc;
auto socket = socket::create_udpv4(ioSvc);
}
TEST_CASE("create UDP/IPv6")
{
io_service ioSvc;
auto socket = socket::create_udpv6(ioSvc);
}
TEST_CASE("TCP/IPv4 connect/disconnect")
{
io_service ioSvc;
ip_endpoint serverAddress;
task<int> serverTask;
auto server = [&](socket listeningSocket) -> task<int>
{
auto s = socket::create_tcpv4(ioSvc);
co_await listeningSocket.accept(s);
co_await s.disconnect();
co_return 0;
};
{
auto serverSocket = socket::create_tcpv4(ioSvc);
serverSocket.bind(ipv4_endpoint{ ipv4_address::loopback(), 0 });
serverSocket.listen(3);
serverAddress = serverSocket.local_endpoint();
serverTask = server(std::move(serverSocket));
}
auto client = [&]() -> task<int>
{
auto s = socket::create_tcpv4(ioSvc);
s.bind(ipv4_endpoint{ ipv4_address::loopback(), 0 });
co_await s.connect(serverAddress);
co_await s.disconnect();
co_return 0;
};
task<int> clientTask = client();
(void)sync_wait(when_all(
[&]() -> task<int>
{
auto stopOnExit = on_scope_exit([&] { ioSvc.stop(); });
(void)co_await when_all(std::move(serverTask), std::move(clientTask));
co_return 0;
}(),
[&]() -> task<int>
{
ioSvc.process_events();
co_return 0;
}()));
}
TEST_CASE("send/recv TCP/IPv4")
{
io_service ioSvc;
auto listeningSocket = socket::create_tcpv4(ioSvc);
listeningSocket.bind(ipv4_endpoint{ ipv4_address::loopback(), 0 });
listeningSocket.listen(3);
auto echoServer = [&]() -> task<int>
{
auto acceptingSocket = socket::create_tcpv4(ioSvc);
co_await listeningSocket.accept(acceptingSocket);
std::uint8_t buffer[64];
std::size_t bytesReceived;
do
{
bytesReceived = co_await acceptingSocket.recv(buffer, sizeof(buffer));
if (bytesReceived > 0)
{
std::size_t bytesSent = 0;
do
{
bytesSent += co_await acceptingSocket.send(
buffer + bytesSent,
bytesReceived - bytesSent);
} while (bytesSent < bytesReceived);
}
} while (bytesReceived > 0);
acceptingSocket.close_send();
co_await acceptingSocket.disconnect();
co_return 0;
};
auto echoClient = [&]() -> task<int>
{
auto connectingSocket = socket::create_tcpv4(ioSvc);
connectingSocket.bind(ipv4_endpoint{});
co_await connectingSocket.connect(listeningSocket.local_endpoint());
auto receive = [&]() -> task<int>
{
std::uint8_t buffer[100];
std::uint64_t totalBytesReceived = 0;
std::size_t bytesReceived;
do
{
bytesReceived = co_await connectingSocket.recv(buffer, sizeof(buffer));
for (std::size_t i = 0; i < bytesReceived; ++i)
{
std::uint64_t byteIndex = totalBytesReceived + i;
std::uint8_t expectedByte = 'a' + (byteIndex % 26);
CHECK(buffer[i] == expectedByte);
}
totalBytesReceived += bytesReceived;
} while (bytesReceived > 0);
CHECK(totalBytesReceived == 1000);
co_return 0;
};
auto send = [&]() -> task<int>
{
std::uint8_t buffer[100];
for (std::uint64_t i = 0; i < 1000; i += sizeof(buffer))
{
for (std::size_t j = 0; j < sizeof(buffer); ++j)
{
buffer[j] = 'a' + ((i + j) % 26);
}
std::size_t bytesSent = 0;
do
{
bytesSent += co_await connectingSocket.send(buffer + bytesSent, sizeof(buffer) - bytesSent);
} while (bytesSent < sizeof(buffer));
}
connectingSocket.close_send();
co_return 0;
};
co_await when_all(send(), receive());
co_await connectingSocket.disconnect();
co_return 0;
};
(void)sync_wait(when_all(
[&]() -> task<int>
{
auto stopOnExit = on_scope_exit([&] { ioSvc.stop(); });
(void)co_await when_all(echoClient(), echoServer());
co_return 0;
}(),
[&]() -> task<int>
{
ioSvc.process_events();
co_return 0;
}()));
}
TEST_SUITE_END();