From 5acc21eb4c588756c0e688271d642640f981a5d8 Mon Sep 17 00:00:00 2001 From: Geoffrey Daniels Date: Tue, 3 Oct 2023 16:33:07 +0100 Subject: [PATCH] Adding socket class. --- README.md | 1 + source/io/socket | 626 +++++++++++++++++++++++++++++++++++++++ tests/io/socket.test.cpp | 279 +++++++++++++++++ 3 files changed, 906 insertions(+) create mode 100644 source/io/socket create mode 100644 tests/io/socket.test.cpp diff --git a/README.md b/README.md index 7793394..033883a 100644 --- a/README.md +++ b/README.md @@ -61,6 +61,7 @@ The current classes are as described below: | [hash](source/hash) | [sha3](source/hash/sha3) | An implementation of the sha3 hashing function for 224, 256, 384, and 512 bits. | :heavy_check_mark: | | [io](source/io) | [file](source/io/file) | An RAII file handle that wraps file operation functions. | :construction: | | [io](source/io) | [paths](source/io/paths) | Collection of cross platform functions to provide useful paths. | :heavy_check_mark: | +| [io](source/io) | [socket](source/io/socket) | Cross platform socket class, supporting tcp (server and client) and udp protocols. | :construction: | | [math](source/math) | [big_integer](source/math/big_integer) | Arbitrary sized signed integers. | :heavy_check_mark: | | [math](source/math) | [big_unsigned](source/math/big_unsigned) | Arbitrary sized unsigned integers. | :heavy_check_mark: | | [math](source/math) | [symbolic](source/math/symbolic) | Compile time symbolic differentiation using template metaprogramming. | :construction: | diff --git a/source/io/socket b/source/io/socket new file mode 100644 index 0000000..e7a657a --- /dev/null +++ b/source/io/socket @@ -0,0 +1,626 @@ +/* +Copyright (C) 2018-2023 Geoffrey Daniels. https://gpdaniels.com/ + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, version 3 of the License only. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . +*/ + +#pragma once +#ifndef GTL_IO_SOCKET_HPP +#define GTL_IO_SOCKET_HPP + +// Summary: Cross platform socket class, supporting tcp (server and client) and udp protocols. [wip] + +#if defined(linux) || defined(__linux) || defined(__linux__) + #include + #include + #include + #include + #include + + #define SOCKET int + #define INVALID_SOCKET -1 + #define closesocket ::close + #define SD_BOTH SHUT_RDWR + #define ioctlsocket ioctl +#endif + +#if defined(_WIN32) + + #if defined(_MSC_VER) + # pragma warning(push, 0) + #endif + + #include + #include + + #pragma comment(lib, "Ws2_32.lib") + + #if defined(_MSC_VER) + # pragma warning(pop) + #endif + +#endif + +#if defined(__APPLE__) + #include + #include + #include + #include + #include + + #define SOCKET int + #define INVALID_SOCKET -1 + #define closesocket ::close + #define SD_BOTH SHUT_RDWR + #define ioctlsocket ioctl +#endif + +static_assert(INVALID_SOCKET == -1, "The invalid socket definition must be negative one."); + +namespace gtl { + class socket final { + public: + struct ip { + unsigned char segment[4]; + bool operator==(const ip& other) const { + return ((this->segment[0] == other.segment[0]) && (this->segment[1] == other.segment[1]) && (this->segment[2] == other.segment[2]) && (this->segment[3] == other.segment[3])); + } + }; + + constexpr static const ip ip_any = {0, 0, 0, 0}; + constexpr static const ip ip_loopback = {127, 0, 0, 1}; + constexpr static const ip ip_broadcast = {255, 255, 255, 255}; + + constexpr static const unsigned short port_any = 0; + constexpr static const unsigned short port_ssh = 22; + constexpr static const unsigned short port_smtp = 25; + constexpr static const unsigned short port_dns = 53; + constexpr static const unsigned short port_http = 80; + constexpr static const unsigned short port_ssl = 443; + + struct tcp_server { + ip address; + unsigned short port; + }; + + struct tcp_client { + ip address_local; + unsigned short port_local; + ip address_remote; + unsigned short port_remote; + }; + + struct udp_client { + ip address; + unsigned short port; + }; + + private: + // The socket handle. + SOCKET handle; + + public: + ~socket() { + this->close(); + #if defined(_WIN32) + WSACleanup(); + #endif + } + + socket() + : handle(INVALID_SOCKET) { + #if defined(_WIN32) + WORD VersionRequested = MAKEWORD(2, 2); + WSADATA WinSockData; + if (WSAStartup(VersionRequested, &WinSockData) != 0) { + return; + } + // Confirm that the WinSock DLL supports 2.2. + if (LOBYTE(WinSockData.wVersion) != 2 || HIBYTE(WinSockData.wVersion) != 2) { + return; + } + #endif + } + + socket(const socket&) = delete; + + socket(socket&& other) { + this->handle = other.handle; + other.handle = INVALID_SOCKET; + } + + socket& operator=(const socket&) = delete; + + socket& operator=(socket&& other) { + if (this != &other) { + this->handle = other.handle; + other.handle = INVALID_SOCKET; + } + return *this; + } + + public: + bool is_open() const { + return ((this->handle >= 0) && (this->handle != INVALID_SOCKET)); + } + + bool open(const tcp_server& connection) { + // Ensure closed. + this->close(); + + // Open a socket. + this->handle = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + + // Check the socket is valid. + if ((this->handle < 0) || (this->handle == INVALID_SOCKET)) { + return false; + } + + // Enable port and address reuse. + const int reuseaddr_value = 1; + if (setsockopt(this->handle, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&reuseaddr_value), sizeof(int)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + + // Configure keepalive options for the socket. + const int keepalive_value = 1; + if (setsockopt(this->handle, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast(&keepalive_value), sizeof(int)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + + // These are not supported by apple. + #if !defined(__APPLE__) + // Configure other keepalive options for the socket. + const int keepidle_value = 1; + if (setsockopt(this->handle, IPPROTO_TCP, TCP_KEEPIDLE, reinterpret_cast(&keepidle_value), sizeof(int)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + const int keepintvl_value = 1; + if (setsockopt(this->handle, IPPROTO_TCP, TCP_KEEPINTVL, reinterpret_cast(&keepintvl_value), sizeof(int)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + const int keepcnt_value = 10; + if (setsockopt(this->handle, IPPROTO_TCP, TCP_KEEPCNT, reinterpret_cast(&keepcnt_value), sizeof(int)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + #endif + + // Configure the socket as blocking. + #if defined(_WIN32) + unsigned long int non_blocking_value = 0; + #else + int non_blocking_value = 0; + #endif + if (ioctlsocket(this->handle, FIONBIO, &non_blocking_value) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + + // Setup the local address and port for the bind call. + sockaddr_in address_bind = {}; + address_bind.sin_family = AF_INET; + address_bind.sin_addr.s_addr = + (static_cast(connection.address.segment[3]) << 24) | + (static_cast(connection.address.segment[2]) << 16) | + (static_cast(connection.address.segment[1]) << 8) | + (static_cast(connection.address.segment[0]) << 0); + address_bind.sin_port = htons(connection.port); + + // Try and bind to the address and port. + if (bind(this->handle, reinterpret_cast(&address_bind), sizeof(sockaddr_in)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + + // Now set the socket to listen for incoming connections. + if (listen(this->handle, SOMAXCONN) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + + return true; + } + + bool open(const tcp_client& connection) { + // Ensure closed. + this->close(); + + // Open a socket. + this->handle = ::socket(AF_INET, SOCK_STREAM, IPPROTO_TCP); + + // Check the socket is valid. + if ((this->handle < 0) || (this->handle == INVALID_SOCKET)) { + return false; + } + + // Enable port and address reuse. + const int reuseaddr_value = 1; + if (setsockopt(this->handle, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&reuseaddr_value), sizeof(int)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + + // Configure keepalive options for the socket. + const int keepalive_value = 1; + if (setsockopt(this->handle, SOL_SOCKET, SO_KEEPALIVE, reinterpret_cast(&keepalive_value), sizeof(int)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + + // These are not supported by apple. + #if !defined(__APPLE__) + // Configure other keepalive options for the socket. + const int keepidle_value = 1; + if (setsockopt(this->handle, IPPROTO_TCP, TCP_KEEPIDLE, reinterpret_cast(&keepidle_value), sizeof(int)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + const int keepintvl_value = 1; + if (setsockopt(this->handle, IPPROTO_TCP, TCP_KEEPINTVL, reinterpret_cast(&keepintvl_value), sizeof(int)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + const int keepcnt_value = 10; + if (setsockopt(this->handle, IPPROTO_TCP, TCP_KEEPCNT, reinterpret_cast(&keepcnt_value), sizeof(int)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + #endif + + // Configure the socket as blocking. + #if defined(_WIN32) + unsigned long int non_blocking_value = 0; + #else + int non_blocking_value = 0; + #endif + if (ioctlsocket(this->handle, FIONBIO, &non_blocking_value) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + + // Setup the local address and port for the bind call. + sockaddr_in address_bind = {}; + address_bind.sin_family = AF_INET; + address_bind.sin_addr.s_addr = + (static_cast(connection.address_local.segment[3]) << 24) | + (static_cast(connection.address_local.segment[2]) << 16) | + (static_cast(connection.address_local.segment[1]) << 8) | + (static_cast(connection.address_local.segment[0]) << 0); + address_bind.sin_port = htons(connection.port_local); + + // Try and bind to the address and port. + if (bind(this->handle, reinterpret_cast(&address_bind), sizeof(sockaddr_in)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + + // Setup the address and port + sockaddr_in address_connect = {}; + address_connect.sin_family = AF_INET; + address_connect.sin_addr.s_addr = + (static_cast(connection.address_remote.segment[3]) << 24) | + (static_cast(connection.address_remote.segment[2]) << 16) | + (static_cast(connection.address_remote.segment[1]) << 8) | + (static_cast(connection.address_remote.segment[0]) << 0); + address_connect.sin_port = htons(connection.port_remote); + + // Try to connect to the remote host + if (connect(this->handle, reinterpret_cast(&address_connect), sizeof(sockaddr_in)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + + return true; + } + + bool open(const udp_client& connection) { + // Ensure closed. + this->close(); + + // Open a socket. + this->handle = ::socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP); + + // Check the socket is valid. + if ((this->handle < 0) || (this->handle == INVALID_SOCKET)) { + return false; + } + + // Enable port and address reuse. + const int reuseaddr_value = 1; + if (setsockopt(this->handle, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast(&reuseaddr_value), sizeof(int)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + + // Configure broadcast options for the socket. + const int broadcast_value = 1; + if (setsockopt(this->handle, SOL_SOCKET, SO_BROADCAST, reinterpret_cast(&broadcast_value), sizeof(int)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + + // Setup the local address and port for the bind call. + sockaddr_in address_bind = {}; + address_bind.sin_family = AF_INET; + address_bind.sin_addr.s_addr = + (static_cast(connection.address.segment[3]) << 24) | + (static_cast(connection.address.segment[2]) << 16) | + (static_cast(connection.address.segment[1]) << 8) | + (static_cast(connection.address.segment[0]) << 0); + address_bind.sin_port = htons(connection.port); + + // Try and bind to the address and port. + if (bind(this->handle, reinterpret_cast(&address_bind), sizeof(sockaddr_in)) < 0) { + closesocket(this->handle); + this->handle = INVALID_SOCKET; + return false; + } + + return true; + } + + void close() { + if (!this->is_open()) { + return; + } + + sockaddr_in address_null = {}; + address_null.sin_family = AF_UNSPEC; + + // Try to disconnect + if (connect(this->handle, reinterpret_cast(&address_null), sizeof(sockaddr_in)) < 0) { + // Failed to disconnect. + } + + if (shutdown(this->handle, SD_BOTH) < 0) { + // Failed to shutdown. + } + + if (closesocket(this->handle) < 0) { + // Failed to close. + } + + this->handle = INVALID_SOCKET; + } + + bool accept(socket& connection) const { + ip address; + unsigned short port; + return this->accept(connection, address, port); + } + + bool accept(socket& connection, ip& address, unsigned short& port) const { + if (!this->is_open()) { + return false; + } + + // Ensure the connection to be created is closed initially. + connection.close(); + + // Prepare an address to store the sender's address. + sockaddr_in address_accepted = {}; + address_accepted.sin_family = AF_INET; + socklen_t address_length = sizeof(sockaddr_in); + + // Accept a connection. + SOCKET handle_accepted = ::accept(this->handle, reinterpret_cast(&address_accepted), &address_length); + if (handle_accepted == INVALID_SOCKET) { + return false; + } + + // Get the sender. + address.segment[0] = (address_accepted.sin_addr.s_addr >> 0) & 0xFF; + address.segment[1] = (address_accepted.sin_addr.s_addr >> 8) & 0xFF; + address.segment[2] = (address_accepted.sin_addr.s_addr >> 16) & 0xFF; + address.segment[3] = (address_accepted.sin_addr.s_addr >> 24) & 0xFF; + port = ntohs(address_accepted.sin_port); + + // "Open" the connection. + connection.handle = handle_accepted; + + return true; + } + + // Returns the configuration of a socket, used after listen to get port number. + bool get_config(ip& address, unsigned short& port) { + if (!this->is_open()) { + return false; + } + sockaddr_in address_config = {}; + socklen_t address_length = sizeof(sockaddr_in); + if (getsockname(this->handle, reinterpret_cast(&address_config), &address_length) < 0) { + return false; + } + address.segment[0] = (address_config.sin_addr.s_addr >> 0) & 0xFF; + address.segment[1] = (address_config.sin_addr.s_addr >> 8) & 0xFF; + address.segment[2] = (address_config.sin_addr.s_addr >> 16) & 0xFF; + address.segment[3] = (address_config.sin_addr.s_addr >> 24) & 0xFF; + port = ntohs(address_config.sin_port); + return true; + } + + bool read(unsigned char* buffer, unsigned long long int& length) const { + ip address; + unsigned short port; + return this->read(buffer, length, address, port); + } + + bool read(unsigned char* buffer, unsigned long long int& length, ip& address, unsigned short& port) const { + if (!this->is_open()) { + return false; + } + + if (length == 0) { + return true; + } + + // The available data length. + unsigned long long int length_available = 0; + + // // Get the available data length in linux. + #if defined(linux) || defined(__linux) || defined(__linux__) + int length_available_int = 0; + if (ioctlsocket(this->handle, FIONREAD, &length_available_int) < 0) { + return false; + } + length_available = static_cast(length_available_int); + #endif + + // Get the available data length in windows. + #if defined(_WIN32) + unsigned long int length_available_int = 0; + if (ioctlsocket(this->handle, FIONREAD, &length_available_int) < 0) { + return false; + } + length_available = static_cast(length_available_int); + #endif + + // Get the available data length in macos. + #if defined(__APPLE__) + int length_available_int; + socklen_t option_length = sizeof(int); + if (getsockopt(this->handle, SOL_SOCKET, SO_NREAD, &length_available_int, &option_length) < 0) { + return false; + } + length_available = static_cast(length_available_int); + #endif + + if (length_available == 0) { + length = 0; + return true; + } + + // Limit the amount received to fit in the buffer we have. + if (length_available > length) { + length_available = length; + } + + // Prepare an address to store the sender's address. + sockaddr_in address_source = {}; + socklen_t address_length = sizeof(sockaddr_in); + + // Get data + #if defined(_WIN32) + const long long int length_received = recvfrom(this->handle, reinterpret_cast(buffer), static_cast(length_available), 0, reinterpret_cast(&address_source), &address_length); + #else + const long long int length_received = recvfrom(this->handle, buffer, length_available, 0, reinterpret_cast(&address_source), &address_length); + #endif + + // Check the read was successful. + if (length_received < 0) { + return false; + } + + // Return the length of the message has been received. + length = static_cast(length_received); + + // Get the sender. + address.segment[0] = (address_source.sin_addr.s_addr >> 0) & 0xFF; + address.segment[1] = (address_source.sin_addr.s_addr >> 8) & 0xFF; + address.segment[2] = (address_source.sin_addr.s_addr >> 16) & 0xFF; + address.segment[3] = (address_source.sin_addr.s_addr >> 24) & 0xFF; + port = ntohs(address_source.sin_port); + + return true; + } + + bool write(const unsigned char* buffer, unsigned long long int& length) const { + if (!this->is_open()) { + return false; + } + + if (length == 0) { + return true; + } + + // Write out the whole buffer as a single message. + #if defined(_WIN32) + const long long int length_written = send(this->handle, reinterpret_cast(buffer), static_cast(length), 0); + #else + const long long int length_written = send(this->handle, buffer, length, 0); + #endif + + // Check the write was successful. + if (length_written < 0) { + return false; + } + + // Return the length of the message has been sent. + length = static_cast(length_written); + + return true; + } + + bool write(const unsigned char* buffer, unsigned long long int& length, const ip address, const unsigned short port) const { + if (!this->is_open()) { + return false; + } + + if (length == 0) { + return true; + } + + // Setup the target address. + sockaddr_in address_target = {}; + address_target.sin_family = AF_INET; + address_target.sin_addr.s_addr = + (static_cast(address.segment[3]) << 24) | + (static_cast(address.segment[2]) << 16) | + (static_cast(address.segment[1]) << 8) | + (static_cast(address.segment[0]) << 0); + address_target.sin_port = htons(port); + + // Write out the whole buffer as a single message. + #if defined(_WIN32) + const long long int length_written = sendto(this->handle, reinterpret_cast(buffer), static_cast(length), 0, reinterpret_cast(&address_target), sizeof(sockaddr_in)); + #else + const long long int length_written = sendto(this->handle, buffer, length, 0, reinterpret_cast(&address_target), sizeof(sockaddr_in)); + #endif + + // Check the write was successful. + if (length_written < 0) { + return false; + } + + // Return the length of the message has been sent. + length = static_cast(length_written); + + return true; + } + }; +} + +#endif // GTL_IO_SOCKET_HPP diff --git a/tests/io/socket.test.cpp b/tests/io/socket.test.cpp new file mode 100644 index 0000000..bbc30fa --- /dev/null +++ b/tests/io/socket.test.cpp @@ -0,0 +1,279 @@ +/* +Copyright (C) 2018-2023 Geoffrey Daniels. https://gpdaniels.com/ + +This program is free software: you can redistribute it and/or modify +it under the terms of the GNU General Public License as published by +the Free Software Foundation, version 3 of the License only. + +This program is distributed in the hope that it will be useful, +but WITHOUT ANY WARRANTY; without even the implied warranty of +MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +GNU General Public License for more details. + +You should have received a copy of the GNU General Public License +along with this program. If not, see . +*/ + +#include +#include +#include +#include + +#include + +#include + +#if defined(_MSC_VER) +# pragma warning(push, 0) +#endif + +#include +#include +#include + +#if defined(_MSC_VER) +# pragma warning(pop) +#endif + +class local_server { +private: + gtl::socket::tcp_server config_server; + gtl::socket socket_server; + std::thread thread_server; + std::mutex mutex_client; + gtl::socket socket_client; + unsigned int accepted_connections = 0; +public: + ~local_server() { + this->socket_server.close(); + this->thread_server.join(); + } + local_server() { + // Create connection accepting server on a thread. + gtl::barrier barrier(2); + this->thread_server = std::thread([this, &barrier](){ + REQUIRE(this->socket_server.open(gtl::socket::tcp_server{gtl::socket::ip_any, gtl::socket::port_any})); + REQUIRE(this->socket_server.is_open()); + // Get the port number. + this->socket_server.get_config(this->config_server.address, this->config_server.port); + // Syncronise with the client thread. + barrier.sync(); + // Server connection accepting loop. + while (this->socket_server.is_open()) { + gtl::socket client; + if (this->socket_server.accept(client)) { + std::lock_guard lock(this->mutex_client); + ++this->accepted_connections; + static_cast(lock); + // Replace previous client with new client. + this->socket_client.close(); + this->socket_client = std::move(client); + } + std::this_thread::yield(); + } + }); + // Wait for the server to be up. + barrier.sync(); + // Connect. + gtl::socket client; + REQUIRE(client.is_open() == false); + for (int i = 0; i < 100; ++i) { + if (client.open(gtl::socket::tcp_client{gtl::socket::ip_any, gtl::socket::port_any, gtl::socket::ip_loopback, this->config_server.port})) { + break; + } + std::this_thread::sleep_for(std::chrono::milliseconds(10)); + } + REQUIRE(client.is_open() == true); + } + unsigned short get_port() const { + return this->config_server.port; + } + unsigned int get_accepted_connections_count() { + std::lock_guard lock(this->mutex_client); + return (this->accepted_connections - 1); + } + bool send_to_last_client(const unsigned char* message, unsigned long long int length) { + std::lock_guard lock(this->mutex_client); + static_cast(lock); + if (!this->socket_client.is_open()) { + return false; + } + unsigned long long int written_length = length; + REQUIRE(this->socket_client.write(message, written_length)); + REQUIRE(written_length == length); + return true; + } +}; + +TEST(socket, traits, standard) { + REQUIRE((std::is_pod::value == false)); + + REQUIRE((std::is_trivial::value == false)); + + REQUIRE((std::is_trivially_copyable::value == false)); + + REQUIRE((std::is_standard_layout::value == true)); +} + +TEST(socket, constructor, empty) { + gtl::socket socket; + testbench::do_not_optimise_away(socket); +} + +TEST(socket, function, is_open) { + gtl::socket socket; + REQUIRE(socket.is_open() == false); +} + +TEST(socket, function, open) { + { + gtl::socket socket; + REQUIRE(socket.open(gtl::socket::udp_client{gtl::socket::ip_any, gtl::socket::port_any}) == true); + } + { + gtl::socket socket; + REQUIRE(socket.open(gtl::socket::udp_client{gtl::socket::ip_any, gtl::socket::port_any}) == true); + } + + { + gtl::socket socket; + REQUIRE(socket.open(gtl::socket::tcp_client{gtl::socket::ip_any, gtl::socket::port_any, gtl::socket::ip_loopback, 1234}) == false); + } + { + gtl::socket socket; + REQUIRE(socket.open(gtl::socket::tcp_client{gtl::socket::ip_any, gtl::socket::port_any, gtl::socket::ip_loopback, 5678}) == false); + } + + local_server server; + + { + gtl::socket socket; + REQUIRE(socket.open(gtl::socket::tcp_client{gtl::socket::ip_any, gtl::socket::port_any, gtl::socket::ip_loopback, static_cast(server.get_port() + 1234)}) == false); + } + { + gtl::socket socket; + REQUIRE(socket.open(gtl::socket::tcp_client{gtl::socket::ip_any, gtl::socket::port_any, gtl::socket::ip_loopback, server.get_port()}) == true); + } +} + +TEST(socket, function, close) { + gtl::socket socket; + REQUIRE(socket.is_open() == false); + socket.close(); + REQUIRE(socket.is_open() == false); + REQUIRE(socket.open(gtl::socket::udp_client{gtl::socket::ip_any, gtl::socket::port_any}) == true); + REQUIRE(socket.is_open() == true); + socket.close(); + REQUIRE(socket.is_open() == false); + socket.close(); + REQUIRE(socket.is_open() == false); +} + +TEST(socket, function, accept) { + { + gtl::socket socket; + REQUIRE(socket.open(gtl::socket::udp_client{gtl::socket::ip_any, gtl::socket::port_any}) == true); + gtl::socket client; + REQUIRE(socket.accept(client) == false); + } + + { + gtl::socket socket; + REQUIRE(socket.open(gtl::socket::tcp_client{gtl::socket::ip_any, gtl::socket::port_any, gtl::socket::ip_loopback, 1234}) == false); + } + { + gtl::socket socket; + REQUIRE(socket.open(gtl::socket::tcp_client{gtl::socket::ip_any, gtl::socket::port_any, gtl::socket::ip_loopback, 5678}) == false); + } + + local_server server; + + { + gtl::socket socket; + REQUIRE(socket.open(gtl::socket::tcp_client{gtl::socket::ip_any, gtl::socket::port_any, gtl::socket::ip_loopback, static_cast(server.get_port() + 1234)}) == false); + } + { + gtl::socket socket; + REQUIRE(socket.open(gtl::socket::tcp_client{gtl::socket::ip_any, gtl::socket::port_any, gtl::socket::ip_loopback, server.get_port()}) == true); + } + { + gtl::socket socket; + REQUIRE(socket.open(gtl::socket::tcp_client{gtl::socket::ip_any, gtl::socket::port_any, gtl::socket::ip_loopback, static_cast(server.get_port() + 1234)}) == false); + } + { + gtl::socket socket; + REQUIRE(socket.open(gtl::socket::tcp_client{gtl::socket::ip_any, gtl::socket::port_any, gtl::socket::ip_loopback, server.get_port()}) == true); + } +} + +TEST(socket, function, read_write_tcp) { + local_server server; + gtl::socket socket; + REQUIRE(socket.open(gtl::socket::tcp_client{gtl::socket::ip_any, gtl::socket::port_any, gtl::socket::ip_loopback, server.get_port()}) == true); + while (server.get_accepted_connections_count() != 1) { + std::this_thread::yield(); + } + const char* sent_message = "Test message"; + const unsigned long long int sent_length = testbench::string_length(sent_message); + REQUIRE(server.send_to_last_client(reinterpret_cast(sent_message), sent_length)); + unsigned char received_message[128] = {}; + unsigned long long int received_length = 128; + do { + std::this_thread::yield(); + REQUIRE(socket.read(received_message, received_length)); + } while (received_length == 0); + received_message[127] = 0; + + REQUIRE(received_length == sent_length, "Mismatching lengths of sent (%llu) and receieved (%llu) data.", sent_length, received_length); + REQUIRE(testbench::is_memory_same(sent_message, received_message, sent_length), "Mismatching sent and receieved data. '%s' != '%s'", sent_message, received_message); +} + +TEST(socket, function, read_write_udp) { + constexpr static const unsigned short socket1_port = 1234; + constexpr static const unsigned short socket2_port = 5678; + + gtl::socket socket1; + REQUIRE(socket1.open(gtl::socket::udp_client{gtl::socket::ip_loopback, socket1_port}) == true); + gtl::socket socket2; + REQUIRE(socket2.open(gtl::socket::udp_client{gtl::socket::ip_loopback, socket2_port}) == true); + + const char* sent_message1 = "Test message 1"; + unsigned long long int sent_length1 = testbench::string_length(sent_message1); + REQUIRE(socket1.write(reinterpret_cast(sent_message1), sent_length1, gtl::socket::ip_loopback, socket2_port)); + + unsigned char received_message1[128] = {}; + unsigned long long int received_length1 = 128; + gtl::socket::ip address1; + unsigned short port1 = 0; + do { + std::this_thread::yield(); + REQUIRE(socket2.read(received_message1, received_length1, address1, port1)); + } while (received_length1 == 0); + received_message1[127] = 0; + + REQUIRE(address1 == gtl::socket::ip_loopback, "Socket 1 ip does not match loopback, got %d.%d.%d.%d.", address1.segment[0], address1.segment[1], address1.segment[2], address1.segment[3]); + REQUIRE(port1 == socket1_port, "Socket1 port does not matched expected value expected %d, got %d.", socket1_port, port1); + REQUIRE(received_length1 == sent_length1, "Mismatching lengths of sent (%llu) and receieved (%llu) data.", sent_length1, received_length1); + REQUIRE(testbench::is_memory_same(sent_message1, received_message1, sent_length1), "Mismatching sent and receieved data. '%s' != '%s'", sent_message1, received_message1); + + const char* sent_message2 = "Test message the second"; + unsigned long long int sent_length2 = testbench::string_length(sent_message2); + REQUIRE(socket2.write(reinterpret_cast(sent_message2), sent_length2, gtl::socket::ip_loopback, socket1_port)); + + unsigned char received_message2[128] = {}; + unsigned long long int received_length2 = 128; + gtl::socket::ip address2; + unsigned short port2 = 0; + do { + std::this_thread::yield(); + REQUIRE(socket1.read(received_message2, received_length2, address2, port2)); + } while (received_length2 == 0); + received_message2[127] = 0; + + REQUIRE(address2 == gtl::socket::ip_loopback, "Socket 2 ip does not match loopback, got %d.%d.%d.%d.", address2.segment[0], address2.segment[1], address2.segment[2], address2.segment[3]); + REQUIRE(port2 == socket2_port, "Socket2 port does not matched expected value expected %d, got %d.", socket2_port, port2); + REQUIRE(received_length2 == sent_length2, "Mismatching lengths of sent (%llu) and receieved (%llu) data.", sent_length2, received_length2); + REQUIRE(testbench::is_memory_same(sent_message2, received_message2, sent_length2), "Mismatching sent and receieved data. '%s' != '%s'", sent_message2, received_message2); +} + +