spicetools/external/headsocket.h

2192 lines
58 KiB
C
Raw Normal View History

2024-08-28 15:10:34 +00:00
/*/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
***** HeadSocket v0.1, created by Jan Pinter **** Minimalistic header only WebSocket server implementation in C++ *****
Sources: https://github.com/P-i-N/HeadSocket, contact: Pinter.Jan@gmail.com
PUBLIC DOMAIN - no warranty implied or offered, use this at your own risk
-----------------------------------------------------------------------------------------------------------------------
Usage:
- use this as a regular header file, but in EXACTLY one of your C++ files (e.g. main.cpp) you must define
HEADSOCKET_IMPLEMENTATION before including it, like this:
#define HEADSOCKET_IMPLEMENTATION
#include <headsocket.h>
/*/////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
#ifndef __HEADSOCKET_H__
#define __HEADSOCKET_H__
#include <algorithm>
#include <cctype>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <map>
#include <memory>
#include <string>
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
namespace headsocket {
/* Forward declarations */
class connection;
class basic_tcp_server;
class basic_tcp_client;
class tcp_client;
class async_tcp_client;
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename T> using ptr = std::shared_ptr<T>;
typedef size_t id_t;
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
/* Forward declarations */
struct connection_impl;
struct basic_tcp_server_impl;
struct basic_tcp_client_impl;
struct async_tcp_client_impl;
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
static bool handshake_websocket(connection &conn);
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Case-insensitive "less than" ordering for std::string keys.
// Used as the comparator of http_server::parameters_t so that parameter names are looked up
// without regard to letter case.
struct less_comparator
{
  // Returns true when s1 orders before s2, comparing the strings character by character after
  // lower-casing each character.
  bool operator()(const std::string &s1, const std::string &s2) const
  {
    return std::lexicographical_compare(s1.begin(), s1.end(), s2.begin(), s2.end(), [](char c1, char c2) -> bool
    {
      // The <cctype> functions are only defined for values representable as unsigned char (or EOF);
      // passing a negative plain char is undefined behavior, so convert first.
      return std::tolower(static_cast<unsigned char>(c1)) < std::tolower(static_cast<unsigned char>(c2));
    });
  }
};
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Thin wrapper around a single socket connection.
// Holds its own copy of the detail::connection_impl it is constructed from (socket handle, peer
// address and id); the destructor does not close the socket.
class connection
{
public:
// Copies the socket handle, peer address and id out of `impl`.
connection(const detail::connection_impl &impl);
~connection();
// Raw access to the implementation object owned by this connection.
detail::connection_impl *impl() const { return _p.get(); }
// True while the wrapped socket handle is not the platform's invalid-socket value.
bool is_valid() const;
// Identifier copied from the implementation object.
id_t id() const;
// Single send() call. Returns the number of bytes sent, 0 when nothing was sent (null pointer,
// zero length or send failure), or the socket-error value converted to size_t when the
// connection is not valid.
size_t write(const void *ptr, size_t length);
// Convenience overload: writes the characters of `text` (without a terminator).
size_t write(const std::string &text) { return write(text.c_str(), text.length()); }
// Single recv() call; return value follows the same convention as write().
size_t read(void *ptr, size_t length);
// Keeps calling send() until all `length` bytes went out; false on the first failure.
bool force_write(const void *ptr, size_t length);
// Keeps calling recv() until exactly `length` bytes were received; false on the first failure.
bool force_read(void *ptr, size_t length);
// Reads one '\n' terminated line into `output`, dropping '\r' and the terminator itself.
bool read_line(std::string &output);
private:
std::unique_ptr<detail::connection_impl> _p;
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Kind of payload carried by a data block / WebSocket frame.
// NOTE(review): the numeric values look like the RFC 6455 frame opcodes - confirm before relying on it.
enum class opcode
{
continuation = 0x00, // remainder of a block whose first part was already consumed
text = 0x01,
binary = 0x02,
connection_close = 0x08,
ping = 0x09,
pong = 0x0A
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct data_block
{
opcode op;
size_t offset;
size_t length = 0;
bool is_completed = false;
data_block(opcode opc, size_t off)
: op(opc)
, offset(off)
{
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Base class of all servers: owns the listening socket, the list of connected clients and the
// accept / disconnect worker threads. Concrete servers supply the handshake and the client factory.
class basic_tcp_server : public std::enable_shared_from_this<basic_tcp_server>
{
public:
// Port the server was successfully bound to (0 when startup failed).
int port() const;
// Closes the listening socket, disconnects every client and joins the worker threads.
void stop();
// True between a successful startup and the first call to stop().
bool is_running() const;
// Disconnects the given client if it belongs to this server; returns whether it was found.
bool disconnect(ptr<basic_tcp_client> client);
// Disconnects the client with the given id; returns whether such a client was found.
bool disconnect(id_t id);
protected:
struct protected_tag { };
// Hook invoked by the constructors that the HEADSOCKET_SERVER macro generates.
virtual void init() { }
// Creates the listening socket on `port` and starts the worker threads.
explicit basic_tcp_server(int port);
virtual ~basic_tcp_server();
// Protocol handshake performed on a freshly accepted connection.
virtual bool handshake(connection &conn) = 0;
// Creates the client object for an accepted connection.
virtual ptr<basic_tcp_client> accept(connection &conn) = 0;
// Notifications about clients joining / leaving the server.
virtual void client_connected(ptr<basic_tcp_client> client) = 0;
virtual void client_disconnected(ptr<basic_tcp_client> client) = 0;
std::unique_ptr<detail::basic_tcp_server_impl> _p;
private:
// tcp_server<T>::enumerator uses the private client-list helpers below.
template <typename T> friend class tcp_server;
void remove_disconnected() const;
// Increments the reference count of every client entry and returns the number of clients.
size_t acquire_clients() const;
void release_clients() const;
// Client at `index`, or nullptr when the index is out of range.
ptr<basic_tcp_client> client_at(size_t index) const;
size_t num_clients() const;
// Entry points of the two worker threads started by the constructor.
void accept_thread();
void disconnect_thread();
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Base class of all client objects, both outgoing (address + port) and server-side (accepted connection).
class basic_tcp_client
{
public:
// Marker checked at compile time by tcp_server<T> (T::is_basic_tcp_client must exist).
enum { is_basic_tcp_client };
// Sentinel equal to the largest size_t value.
static const size_t invalid_operation = static_cast<size_t>(-1);
virtual ~basic_tcp_client();
bool disconnect();
bool is_connected() const;
// Server this client belongs to.
// NOTE(review): presumably empty for clients created from an address - confirm in the implementation.
ptr<basic_tcp_server> server() const;
id_t id() const;
protected:
struct protected_tag { };
friend class basic_tcp_server;
// Hooks for derived classes; the defaults do nothing.
virtual void on_accept() { }
virtual void on_disconnect() { }
// Outgoing connection to address:port.
basic_tcp_client(const std::string &address, int port);
// Server-side client wrapping an accepted connection.
basic_tcp_client(ptr<basic_tcp_server> server, connection &conn);
std::unique_ptr<detail::basic_tcp_client_impl> _p;
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Tag type taken by the public constructors that the HEADSOCKET_* macros generate; those
// constructors exist so that std::make_shared can reach them from the static create() factories.
struct protected_tag { };
// Boilerplate for a server class `className` deriving from `baseClassName`. Generates:
// - a protected constructor className(int port) that forwards to the base class and calls init(),
// - the typedef base_t,
// - a public tag constructor plus the static factory create(port) returning a shared pointer,
// - the opening of a protected init() override: the macro must be followed by the body of init(),
//   e.g. HEADSOCKET_SERVER(my_server, base) { /* setup */ }
#define HEADSOCKET_SERVER(className, baseClassName) \
protected: \
explicit className(int port): baseClassName(port) { init(); } \
public: \
typedef baseClassName base_t; \
className(const protected_tag &, int port): className(port) { } \
static headsocket::ptr<className> create(int port) { return std::make_shared<className>(protected_tag{}, port); } \
protected: \
void init()
// Server whose clients are all of type T. T must derive from basic_tcp_client and provide the
// static create(server, connection) factory generated by the HEADSOCKET_CLIENT* macros.
template <typename T>
class tcp_server : public basic_tcp_server
{
HEADSOCKET_SERVER(tcp_server, basic_tcp_server) { }
public:
typedef T client_t;
typedef ptr<client_t> client_ptr;
virtual ~tcp_server()
{
base_t::stop();
}
// Scoped view over the server's client list: the constructor calls acquire_clients() and the
// destructor calls release_clients(). Usable in a range-based for loop.
class enumerator
{
public:
explicit enumerator(const tcp_server &server)
: _server(server)
, _count(server.acquire_clients())
{
}
~enumerator()
{
_server.release_clients();
}
const tcp_server &server() const { return _server; }
// Number of clients at the time the enumerator was created.
size_t size() const { return _count; }
// Minimal forward iterator: just an index into the enumerator's server.
struct iterator
{
enumerator &e;
size_t index;
iterator(enumerator &enu, size_t idx)
: e(enu)
, index(idx)
{
}
bool operator==(const iterator &iter) const { return iter.index == index && &iter.e == &e; }
bool operator!=(const iterator &iter) const { return iter.index != index || &iter.e != &e; }
// Yields the client at the current index cast to T (null when the cast fails or the index is out of range).
ptr<T> operator*() const { return std::dynamic_pointer_cast<T>(e.server().client_at(index)); }
iterator &operator++()
{
++index;
return *this;
}
};
iterator begin() { return iterator(*this, 0); }
iterator end() { return iterator(*this, _count); }
private:
const tcp_server &_server;
size_t _count;
};
// Returns an enumerator over the currently connected clients.
enumerator clients() const { return enumerator(*this); }
protected:
// Plain TCP needs no protocol handshake, so every connection is accepted.
bool handshake(connection &conn) override { return true; }
// Typed notifications for derived servers; the defaults do nothing.
virtual void client_connected(client_ptr client) { }
virtual void client_disconnected(client_ptr client) { }
private:
// Compile-time check: fails to compile unless T exposes is_basic_tcp_client.
enum { needs_basic_tcp_client = T::is_basic_tcp_client };
// Creates a T for the accepted connection; returns nullptr if the new client is not connected.
ptr<basic_tcp_client> accept(connection &conn) override
{
ptr<basic_tcp_client> newClient = T::create(shared_from_this(), conn);
return newClient->is_connected() ? newClient : nullptr;
}
// Adapters from the untyped base-class notifications to the typed ones above.
void client_connected(ptr<basic_tcp_client> client) override
{
client_connected(std::dynamic_pointer_cast<T>(client));
}
void client_disconnected(ptr<basic_tcp_client> client) override
{
client_disconnected(std::dynamic_pointer_cast<T>(client));
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Generates, for client class `className`, the two public tag constructors and the two static
// create() factories (outgoing address/port client and server-side accepted client).
// The tag constructors forward to the protected constructors with the same remaining arguments.
#define __HEADSOCKET_CLIENT_STATIC_CTORS(className) \
className(const protected_tag &, const std::string &address, int port): className(address, port) { } \
className(const protected_tag &, headsocket::ptr<headsocket::basic_tcp_server> server, headsocket::connection &conn): className(server, conn) { } \
static headsocket::ptr<className> create(const std::string &address, int port) { return std::make_shared<className>(protected_tag{}, address, port); } \
static headsocket::ptr<className> create(headsocket::ptr<headsocket::basic_tcp_server> server, headsocket::connection &conn) { return std::make_shared<className>(protected_tag{}, server, conn); }
// Variant for the library's own client classes: only DECLARES the protected constructors
// (they are defined out of line) and adds the factories.
#define HEADSOCKET_CLIENT_BASE(className) \
protected: \
className(const std::string &address, int port); \
className(ptr<basic_tcp_server> server, connection &conn); \
public: \
__HEADSOCKET_CLIENT_STATIC_CTORS(className)
// Variant for user classes: DEFINES protected constructors that simply forward to
// `baseClassName`, then adds the factories.
#define HEADSOCKET_CLIENT(className, baseClassName) \
protected: \
className(const std::string &address, int port): baseClassName(address, port) { } \
className(headsocket::ptr<headsocket::basic_tcp_server> server, headsocket::connection &conn): baseClassName(server, conn) { } \
public: \
__HEADSOCKET_CLIENT_STATIC_CTORS(className)
// Client with a direct read/write interface.
// NOTE(review): presumably these calls block on the socket - the definitions are outside this view.
class tcp_client : public basic_tcp_client
{
HEADSOCKET_CLIENT_BASE(tcp_client)
public:
typedef basic_tcp_client base_t;
// Marker enum usable for compile-time checks of the client type.
enum { is_tcp_client };
virtual ~tcp_client();
virtual size_t write(const void *ptr, size_t length);
virtual size_t read(void *ptr, size_t length);
bool force_write(const void *ptr, size_t length);
bool force_read(void *ptr, size_t length);
bool read_line(std::string &output);
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Client with a queue-style interface (push / peek / pop) backed by a write thread and a read thread.
class async_tcp_client : public basic_tcp_client
{
HEADSOCKET_CLIENT_BASE(async_tcp_client)
public:
typedef basic_tcp_client base_t;
// Marker enum usable for compile-time checks of the client type.
enum { is_async_tcp_client };
virtual ~async_tcp_client();
// Queue data for sending.
void push(const void *ptr, size_t length);
void push(const std::string &text);
// NOTE(review): presumably the size of the next received block (0 if none) - confirm in the implementation.
size_t peek() const;
// Copies received data into `ptr`; returns the number of bytes copied.
size_t pop(void *ptr, size_t length);
protected:
// The worker threads live exactly as long as the connection is accepted.
void on_accept() override { init_threads(); }
void on_disconnect() override { kill_threads(); }
virtual void init_threads();
// Hooks called by the worker threads; overridden by web_socket_client to add framing.
virtual size_t async_write_handler(uint8_t *ptr, size_t length);
virtual size_t async_read_handler(uint8_t *ptr, size_t length);
// Callback for received data; the default returns false.
virtual bool async_received_data(const data_block &db, uint8_t *ptr, size_t length) { return false; }
virtual void push(const void *ptr, size_t length, opcode opcode);
void kill_threads();
std::unique_ptr<detail::async_tcp_client_impl> _ap;
private:
void write_thread();
void read_thread();
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Asynchronous client that overrides the async read/write handlers to process WebSocket frames.
class web_socket_client : public async_tcp_client
{
HEADSOCKET_CLIENT_BASE(web_socket_client)
public:
// 128 KiB. NOTE(review): presumably the largest payload placed into a single frame - confirm.
static const size_t frame_size_limit = 128 * 1024;
typedef async_tcp_client base_t;
// Marker checked at compile time by web_socket_server<T>.
enum { is_web_socket_client };
virtual ~web_socket_client();
// Like async_tcp_client::peek(), additionally reporting the opcode through `op`.
size_t peek(opcode *op) const;
protected:
size_t async_write_handler(uint8_t *ptr, size_t length) override;
size_t async_read_handler(uint8_t *ptr, size_t length) override;
private:
// Decoded form of a frame header.
struct frame_header
{
bool fin;
opcode op;
bool masked;
size_t payload_length;
uint32_t masking_key;
// Serialize into / parse from a byte buffer of `length` bytes.
// NOTE(review): return value is presumably the header size in bytes - confirm.
size_t write(uint8_t *ptr, size_t length) const;
size_t read(const uint8_t *ptr, size_t length);
};
size_t _payload_size = 0;
// Not initialized here; its fields are indeterminate until assigned.
frame_header _current_header;
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// TCP server that performs the WebSocket opening handshake on every accepted connection.
// T must be (derived from) web_socket_client.
template <typename T>
class web_socket_server : public tcp_server<T>
{
HEADSOCKET_SERVER(web_socket_server, tcp_server<T>) { }
public:
virtual ~web_socket_server()
{
base_t::stop();
}
protected:
// Delegates to the shared handshake routine; a failed handshake rejects the connection.
bool handshake(connection &conn) override { return detail::handshake_websocket(conn); }
private:
// Compile-time check: fails to compile unless T exposes is_web_socket_client.
enum { needs_web_socket_client = T::is_web_socket_client };
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Minimal HTTP server: derived classes override request() to produce a response.
// accept() always returns nullptr, so no client objects are ever created by this server.
class http_server : public tcp_server<tcp_client>
{
HEADSOCKET_SERVER(http_server, tcp_server<tcp_client>) { }
public:
~http_server()
{
stop();
}
// Result filled in by request().
struct response
{
std::string content_type = "text/html";
std::string message = "";
};
// One request parameter.
// NOTE(review): boolean / integer / real are presumably pre-parsed forms of `value`; they have no
// default initializers, so they are indeterminate unless the parser sets them - confirm.
struct parameter
{
std::string name;
std::string value;
bool boolean;
int integer;
double real;
};
// Parameters keyed by name, compared case-insensitively.
typedef std::map<std::string, parameter, detail::less_comparator> parameters_t;
protected:
// Request handler hook; the default implementation returns false.
virtual bool request(const std::string &path, const parameters_t &params, response &resp) { return false; }
private:
bool handshake(connection &conn) final override;
ptr<basic_tcp_client> accept(connection &conn) final override { return nullptr; }
void client_connected(client_ptr client) final override { }
void client_disconnected(client_ptr client) final override { }
};
}
#endif // __HEADSOCKET_H__
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef HEADSOCKET_IMPLEMENTATION
#ifndef __HEADSOCKET_H_IMPL__
#define __HEADSOCKET_H_IMPL__
#include <thread>
#include <atomic>
#include <iomanip>
#include <vector>
#include <mutex>
#include <condition_variable>
#include <memory>
#include <sstream>
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Platform detection. Define HEADSOCKET_PLATFORM_OVERRIDE (plus the desired HEADSOCKET_PLATFORM_*
// macros) before including this file to bypass it.
// The rest of the implementation only distinguishes HEADSOCKET_PLATFORM_WINDOWS (Winsock) from
// HEADSOCKET_PLATFORM_NIX / HEADSOCKET_PLATFORM_ANDROID (BSD sockets).
#ifndef HEADSOCKET_PLATFORM_OVERRIDE
#if defined(_WIN32)
#define HEADSOCKET_PLATFORM_WINDOWS
#elif defined(__ANDROID__)
#define HEADSOCKET_PLATFORM_ANDROID
#define HEADSOCKET_PLATFORM_NIX
#elif defined(__APPLE__)
#include "TargetConditionals.h"
#ifdef TARGET_OS_MAC
#define HEADSOCKET_PLATFORM_MAC
// Apple platforms use the BSD socket API as well; without this the socket headers and the
// detail::socket_type definitions further below are never selected.
#define HEADSOCKET_PLATFORM_NIX
#endif
#elif defined(__linux) || defined(__linux__)
#define HEADSOCKET_PLATFORM_NIX
#elif defined(__unix) || defined(__unix__)
#define HEADSOCKET_PLATFORM_NIX
#elif defined(__posix)
#define HEADSOCKET_PLATFORM_NIX
#else
#error Unsupported platform!
#endif
#endif
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(HEADSOCKET_PLATFORM_WINDOWS)
#pragma comment(lib, "ws2_32.lib")
#include <winsock2.h>
#include <windows.h>
#include <ws2tcpip.h>
#include <functional>
#elif defined(HEADSOCKET_PLATFORM_ANDROID) || defined(HEADSOCKET_PLATFORM_NIX)
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/ip.h>
#include <unistd.h>
#include <netdb.h>
#endif
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// HEADSOCKET_LOCK(var) locks `var` for the rest of the enclosing scope by declaring a
// std::lock_guard named __scope_lock<line number>; `var` therefore needs lock() and unlock().
// The intermediate HEADSOCKET_LOCK_SUFFIX2 exists so that __LINE__ is expanded to its number
// before it is pasted onto the variable name.
#define HEADSOCKET_LOCK_SUFFIX(var, suffix) std::lock_guard<decltype(var)> __scope_lock##suffix(var);
#define HEADSOCKET_LOCK_SUFFIX2(var, suffix) HEADSOCKET_LOCK_SUFFIX(var, suffix)
#define HEADSOCKET_LOCK(var) HEADSOCKET_LOCK_SUFFIX2(var, __LINE__)
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
namespace headsocket {
// Platform abstraction for the socket handle type, its sentinel values, the close call and the
// sprintf flavour used by the implementation.
namespace detail {
#if defined(HEADSOCKET_PLATFORM_WINDOWS)
// Winsock
typedef SOCKET socket_type;
static const int socket_error = SOCKET_ERROR;
static const SOCKET invalid_socket = INVALID_SOCKET;
void close_socket(socket_type s) { closesocket(s); }
#define HEADSOCKET_SPRINTF sprintf_s
#elif defined(HEADSOCKET_PLATFORM_ANDROID) || defined(HEADSOCKET_PLATFORM_NIX)
// BSD sockets: handles are plain file descriptors and -1 signals both failure and "no socket".
typedef int socket_type;
static const int socket_error = -1;
static const int invalid_socket = -1;
void close_socket(socket_type s) { close(s); }
#define HEADSOCKET_SPRINTF sprintf
#endif
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
// Incremental SHA-1 hash.
//
// Usage: feed data with process_byte() / process_block() / process_bytes(), then call
// get_digest() or get_digest_bytes() exactly once. Finalization appends the padding to the
// internal state, so an instance cannot be reused or finalized a second time.
class sha1
{
public:
  typedef uint32_t digest32_t[5];
  typedef uint8_t digest8_t[20];

  // Rotates a 32-bit value left by `count` bits; `count` must be in the range 1..31.
  inline static uint32_t rotate_left(uint32_t value, size_t count) { return (value << count) | (value >> (32 - count)); }

  // Starts a new hash with the standard SHA-1 initial state.
  sha1()
  {
    _digest[0] = 0x67452301;
    _digest[1] = 0xEFCDAB89;
    _digest[2] = 0x98BADCFE;
    _digest[3] = 0x10325476;
    _digest[4] = 0xC3D2E1F0;
  }

  // Appends one byte to the message; compresses the block once 64 bytes have accumulated.
  void process_byte(uint8_t octet)
  {
    _block[_block_byte_index++] = octet;
    ++_byte_count;
    if (_block_byte_index == 64)
    {
      _block_byte_index = 0;
      process_block();
    }
  }

  // Appends the bytes in the half-open range [start, end).
  void process_block(const void *start, const void *end)
  {
    const uint8_t *cursor = static_cast<const uint8_t *>(start);
    const uint8_t *last = static_cast<const uint8_t *>(end);
    while (cursor != last)
      process_byte(*cursor++);
  }

  // Appends `len` bytes starting at `data`.
  void process_bytes(const void *data, size_t len)
  {
    process_block(data, static_cast<const uint8_t *>(data) + len);
  }

  // Finalizes the hash and stores it as five 32-bit words in `digest`. Returns `digest`.
  const uint32_t *get_digest(digest32_t digest)
  {
    // Capture the message length first: the padding bytes below also go through process_byte().
    const uint64_t bit_count = static_cast<uint64_t>(_byte_count) * 8;

    process_byte(0x80);
    // The 8-byte length trailer must end a block; if it no longer fits, finish the current block.
    if (_block_byte_index > 56)
      while (_block_byte_index != 0)
        process_byte(0);
    while (_block_byte_index < 56)
      process_byte(0);

    // Message length in bits as a 64-bit big-endian integer (all eight bytes, so that inputs of
    // 512 MiB and more are also hashed correctly).
    for (int shift = 56; shift >= 0; shift -= 8)
      process_byte(static_cast<uint8_t>((bit_count >> shift) & 0xFF));

    for (size_t i = 0; i < 5; ++i)
      digest[i] = _digest[i];
    return digest;
  }

  // Finalizes the hash and stores it as 20 bytes (big-endian words) in `digest`. Returns `digest`.
  const uint8_t *get_digest_bytes(digest8_t digest)
  {
    digest32_t words;
    get_digest(words);
    for (size_t i = 0; i < 20; ++i)
      digest[i] = static_cast<uint8_t>((words[i >> 2] >> (24 - 8 * (i & 3))) & 0xFF);
    return digest;
  }

private:
  // Compresses the 64 bytes currently held in _block into the running digest.
  void process_block()
  {
    uint32_t w[80];

    // Load the block as sixteen big-endian words. Each byte is widened to uint32_t before the
    // shift: shifting the promoted int left by 24 would overflow into the sign bit.
    for (size_t i = 0; i < 16; ++i)
      w[i] = (static_cast<uint32_t>(_block[i * 4]) << 24) |
             (static_cast<uint32_t>(_block[i * 4 + 1]) << 16) |
             (static_cast<uint32_t>(_block[i * 4 + 2]) << 8) |
             static_cast<uint32_t>(_block[i * 4 + 3]);

    // Message schedule expansion.
    for (size_t i = 16; i < 80; ++i)
      w[i] = rotate_left(w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16], 1);

    uint32_t a = _digest[0], b = _digest[1], c = _digest[2], d = _digest[3], e = _digest[4];

    for (size_t i = 0; i < 80; ++i)
    {
      uint32_t f, k;
      if (i < 20)
      {
        f = (b & c) | (~b & d);
        k = 0x5A827999;
      }
      else if (i < 40)
      {
        f = b ^ c ^ d;
        k = 0x6ED9EBA1;
      }
      else if (i < 60)
      {
        f = (b & c) | (b & d) | (c & d);
        k = 0x8F1BBCDC;
      }
      else
      {
        f = b ^ c ^ d;
        k = 0xCA62C1D6;
      }

      const uint32_t temp = rotate_left(a, 5) + f + e + k + w[i];
      e = d;
      d = c;
      c = rotate_left(b, 30);
      b = a;
      a = temp;
    }

    _digest[0] += a;
    _digest[1] += b;
    _digest[2] += c;
    _digest[3] += d;
    _digest[4] += e;
  }

  digest32_t _digest;
  uint8_t _block[64];
  size_t _block_byte_index = 0;
  size_t _byte_count = 0;
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Stateless string / byte helpers used by the implementation.
struct utils
{
  // Returns the standard (RFC 4648, with '=' padding) Base64 encoding of `length` bytes at `ptr`.
  // When `ptr` is null the result consists of padding characters only.
  static std::string base64_encode(const void *ptr, size_t length)
  {
    static const char *encoding_table = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Number of '=' characters needed for a given (length % 3).
    static const size_t mod_table[] = { 0, 2, 1 };

    std::string result(4 * ((length + 2) / 3), '=');
    if (ptr && length)
    {
      const uint8_t *input = static_cast<const uint8_t *>(ptr);
      for (size_t i = 0, j = 0; i < length;)
      {
        // Pack up to three input bytes (zero-filled at the end) into 24 bits ...
        size_t triplet = 0;
        for (size_t k = 0; k < 3; ++k)
          triplet = (triplet << 8) | (i < length ? input[i++] : 0);
        // ... and emit them as four 6-bit symbols, most significant first.
        for (size_t k = 4; k--;)
          result[j++] = encoding_table[(triplet >> (k * 6)) & 0x3F];
      }
      // Overwrite the symbols produced from the zero fill with padding.
      for (size_t i = 0; i < mod_table[length % 3]; i++)
        result[result.length() - 1 - i] = '=';
    }
    return result;
  }

  // XORs `length` bytes at `ptr` in place with the four bytes of `key`, taken in memory order and
  // repeated cyclically. Returns `length`.
  static size_t xor32(uint32_t key, void *ptr, size_t length)
  {
    uint8_t *data = reinterpret_cast<uint8_t *>(ptr);
    const uint8_t *mask = reinterpret_cast<const uint8_t *>(&key);
    for (size_t i = 0; i < length; ++i, ++data)
      *data = (*data) ^ mask[i % 4];
    return length;
  }

  // Percent-encodes every byte except ASCII letters, digits and "-_.~".
  // Escapes are written as '%' followed by two lower-case hex digits.
  static std::string url_encode(const std::string &str)
  {
    std::ostringstream result;
    result.fill('0');
    result << std::hex;
    for (const char ch : str)
    {
      // Work on the unsigned byte value: a negative char would be undefined for isalnum() and
      // would print as a sign-extended "%ffffffXX" escape.
      const unsigned char byte = static_cast<unsigned char>(ch);
      if (std::isalnum(byte) || ch == '-' || ch == '_' || ch == '.' || ch == '~')
        result << ch;
      else
        result << '%' << std::setw(2) << static_cast<int>(byte);
    }
    return result.str();
  }

  // Decodes "%XX" escapes and turns '+' into a space.
  // A '%' that is not followed by two hex digits (malformed or truncated input) is copied
  // through unchanged instead of reading past the end of the string.
  static std::string url_decode(const std::string &str)
  {
    const auto hex_value = [](char ch) -> int
    {
      if (ch >= '0' && ch <= '9') return ch - '0';
      if (ch >= 'a' && ch <= 'f') return ch - 'a' + 10;
      if (ch >= 'A' && ch <= 'F') return ch - 'A' + 10;
      return -1;
    };

    std::string result;
    result.reserve(str.length());
    for (size_t i = 0, S = str.length(); i < S; ++i)
    {
      const char c = str[i];
      if (c == '%')
      {
        const int high = (i + 2 < S) ? hex_value(str[i + 1]) : -1;
        const int low = (i + 2 < S) ? hex_value(str[i + 2]) : -1;
        if (high >= 0 && low >= 0)
        {
          result += static_cast<char>((high << 4) | low);
          i += 2;
        }
        else
          result += c;
      }
      else if (c == '+')
        result += ' ';
      else
        result += c;
    }
    return result;
  }

  // Byte-order reversal helpers.
  static uint16_t swap16bits(uint16_t x) { return ((x & 0x00FF) << 8) | ((x & 0xFF00) >> 8); }
  static uint32_t swap32bits(uint32_t x)
  {
    return ((x & 0x000000FF) << 24) | ((x & 0x0000FF00) << 8) | ((x & 0x00FF0000) >> 8) | ((x & 0xFF000000) >> 24);
  }
  static uint64_t swap64bits(uint64_t x)
  {
    return
      ((x & 0x00000000000000FFULL) << 56) | ((x & 0x000000000000FF00ULL) << 40) | ((x & 0x0000000000FF0000ULL) << 24) |
      ((x & 0x00000000FF000000ULL) << 8) | ((x & 0x000000FF00000000ULL) >> 8) | ((x & 0x0000FF0000000000ULL) >> 24) |
      ((x & 0x00FF000000000000ULL) >> 40) | ((x & 0xFF00000000000000ULL) >> 56);
  }

  // Returns `str` without leading and trailing whitespace (as classified by isspace).
  static std::string trim(const std::string &str)
  {
    const auto is_space = [](char ch) { return std::isspace(static_cast<unsigned char>(ch)) != 0; };
    size_t first = 0, last = str.length();
    while (first < last && is_space(str[first]))
      ++first;
    while (last > first && is_space(str[last - 1]))
      --last;
    return str.substr(first, last - first);
  }

  // Splits `str` at a delimiter: returns the part in front of it and leaves the part behind it
  // in `str`. `first` selects the first (true) or last (false) occurrence of the delimiter.
  // When the delimiter is missing, `hungry` decides whether the whole string is returned (and
  // `str` emptied) or nothing is taken.
  static std::string cut_front(std::string &str, char delimiter = ' ', bool first = true, bool hungry = true)
  {
    std::string result;
    const auto pos = first ? str.find(delimiter) : str.rfind(delimiter);
    if (pos == std::string::npos)
    {
      if (hungry)
      {
        result = str;
        str = "";
      }
    }
    else
    {
      result = str.substr(0, pos);
      str = str.substr(pos + 1);
    }
    return result;
  }

  // Mirror image of cut_front(): returns the part behind the delimiter and leaves the part in
  // front of it in `str`. Here `first` counts from the back (true = last occurrence).
  static std::string cut_back(std::string &str, char delimiter = ' ', bool first = true, bool hungry = true)
  {
    std::string result;
    const auto pos = first ? str.rfind(delimiter) : str.find(delimiter);
    if (pos == std::string::npos)
    {
      if (hungry)
      {
        result = str;
        str = "";
      }
    }
    else
    {
      result = str.substr(pos + 1);
      str = str.substr(0, pos);
    }
    return result;
  }
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Busy-waiting lock built on a single atomic flag; has the lock()/unlock() pair that
// std::lock_guard expects. Both operations are const so that const objects can be locked too.
struct critical_section
{
  mutable std::atomic_bool consumer_lock; // true while some thread holds the lock

  critical_section()
  {
    consumer_lock.store(false);
  }

  // Spins until this thread is the one that flips the flag from false to true.
  void lock() const
  {
    for (;;)
    {
      if (!consumer_lock.exchange(true))
        return;
    }
  }

  // Releases the lock.
  void unlock() const
  {
    consumer_lock.store(false);
  }
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// A value of type T bundled with a lock of type M (critical_section by default). The lock
// interface is inherited from M, so the object itself can be passed to HEADSOCKET_LOCK.
// Accessing `value` (directly or through operator->) does NOT take the lock.
template <typename T, typename M = critical_section>
struct lockable_value : M
{
T value;
T *operator->() { return &value; }
const T *operator->() const { return &value; }
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Counter guarded by a mutex and a condition variable.
// lock() blocks until the counter exceeds a threshold and returns with the mutex still held;
// unlock() releases that mutex.
struct semaphore
{
mutable std::atomic_size_t count;
mutable std::mutex mutex;
mutable std::condition_variable cv;
semaphore()
{
count = 0;
}
// Waits until count > minCount. The unique_lock is release()d, i.e. detached without unlocking,
// so the mutex stays locked after this function returns; the caller must call unlock().
void lock(size_t minCount = 0) const
{
std::unique_lock<std::mutex> lock(mutex);
cv.wait(lock, [&]()->bool { return count > minCount; });
lock.release();
}
// Releases the mutex that a previous lock() left locked.
void unlock()
{
mutex.unlock();
}
// Increments the counter under the mutex and wakes one waiter.
void notify()
{
{
std::lock_guard<std::mutex> lock(mutex);
++count;
}
cv.notify_one();
}
// Decrements the counter unless it is already zero. Does not take the mutex itself.
void consume() const
{
if (count)
--count;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
struct data_block_buffer
{
std::vector<data_block> blocks;
std::vector<uint8_t> buffer;
data_block_buffer()
{
buffer.reserve(65536);
}
data_block &block_begin(opcode op)
{
blocks.emplace_back(op, buffer.size());
return blocks.back();
}
data_block &block_end()
{
blocks.back().is_completed = true;
return blocks.back();
}
void block_remove()
{
if (blocks.empty())
return;
buffer.resize(blocks.back().offset);
blocks.pop_back();
}
void write(const void *ptr, size_t length)
{
if (!length)
return;
buffer.resize(buffer.size() + length);
memcpy(buffer.data() + buffer.size() - length, reinterpret_cast<const char *>(ptr), length);
blocks.back().length += length;
}
size_t read(void *ptr, size_t length)
{
if (!ptr || blocks.empty() || !blocks.front().is_completed)
return 0;
data_block &db = blocks.front();
size_t result = db.length >= length ? length : db.length;
if (result)
{
memcpy(ptr, buffer.data() + db.offset, result);
buffer.erase(buffer.begin(), buffer.begin() + result);
}
if (!(db.length -= result))
blocks.erase(blocks.begin());
else
blocks.front().op = opcode::continuation;
if (result) for (auto &block : blocks) if (block.offset > db.offset)
block.offset -= result;
return result;
}
size_t peek(opcode *op = nullptr) const
{
if (blocks.empty() || !blocks.front().is_completed)
return 0;
if (op)
*op = blocks.front().op;
return blocks.front().length;
}
};
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//---------------------------------------------------------------------------------------------------------------------
// Performs the server side of the WebSocket opening handshake on `conn`.
// Reads request header lines up to the first empty line, takes the value of the
// "Sec-WebSocket-Key: " header, and answers with "101 Switching Protocols" carrying the matching
// Sec-WebSocket-Accept value (Base64 of the SHA-1 of key + fixed GUID).
// Returns false when no key header was received or the response could not be written.
bool handshake_websocket(connection &conn)
{
  static const char key_header[] = "Sec-WebSocket-Key: ";
  static const size_t key_header_length = sizeof(key_header) - 1;

  std::string line, key;
  while (conn.read_line(line))
  {
    if (line.empty())
      break;
    // std::string::compare never looks beyond the end of `line`, unlike a fixed-length memcmp
    // on lines shorter than the header name.
    if (line.compare(0, key_header_length, key_header) == 0)
      key = line.substr(key_header_length);
  }

  if (key.empty())
    return false;

  key += "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

  detail::sha1 sha;
  detail::sha1::digest8_t digest;
  sha.process_bytes(key.c_str(), key.length());

  std::string response = "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: ";
  response += detail::utils::base64_encode(sha.get_digest_bytes(digest), 20);
  response += "\r\n\r\n";
  return conn.force_write(response.c_str(), response.length());
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Placeholder for naming the calling thread (e.g. for debuggers).
// Both platform variants are currently empty, so calling this has no effect.
#ifdef HEADSOCKET_PLATFORM_WINDOWS
void set_thread_name(const char *name)
{
}
#else
void set_thread_name(const char *name)
{
}
#endif
} // namespace detail;
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
// Plain data behind a connection: the socket handle, the peer address and the connection id.
struct connection_impl
{
  detail::socket_type socket = detail::invalid_socket;
  sockaddr_in from;
  size_t id = 0;

  // Takes over all three fields of `impl`.
  void assign(const connection_impl &impl)
  {
    id = impl.id;
    from = impl.from;
    socket = impl.socket;
  }

  // Closes the socket if one is open and resets the handle; safe to call repeatedly.
  void close()
  {
    if (socket == detail::invalid_socket)
      return;

    detail::close_socket(socket);
    socket = detail::invalid_socket;
  }
};
}
//---------------------------------------------------------------------------------------------------------------------
// Creates a connection holding its own copy of `impl` (socket handle, peer address and id).
connection::connection(const detail::connection_impl &impl)
  : _p(std::make_unique<detail::connection_impl>(impl))
{
}
//---------------------------------------------------------------------------------------------------------------------
// Defined out of line so that the unique_ptr member is destroyed where connection_impl is a
// complete type. The socket itself is not closed here.
connection::~connection()
{
}
//---------------------------------------------------------------------------------------------------------------------
// A connection is valid while its socket handle differs from the invalid-socket sentinel.
bool connection::is_valid() const { return _p->socket != detail::invalid_socket; }
//---------------------------------------------------------------------------------------------------------------------
// Identifier stored in the implementation object.
size_t connection::id() const { return _p->id; }
//---------------------------------------------------------------------------------------------------------------------
size_t connection::write(const void *ptr, size_t length)
{
if (!is_valid())
return detail::socket_error;
if (!ptr || !length)
return 0;
int result = send(_p->socket, static_cast<const char *>(ptr), static_cast<int>(length), 0);
if (!result || result == detail::socket_error)
return 0;
return static_cast<size_t>(result);
}
//---------------------------------------------------------------------------------------------------------------------
bool connection::force_write(const void *ptr, size_t length)
{
if (!is_valid())
return false;
if (!ptr)
return true;
const char *chPtr = static_cast<const char *>(ptr);
while (length)
{
int result = send(_p->socket, chPtr, static_cast<int>(length), 0);
if (!result || result == detail::socket_error)
return false;
length -= static_cast<size_t>(result);
chPtr += result;
}
return true;
}
//---------------------------------------------------------------------------------------------------------------------
// Receives up to `length` bytes with a single recv() call.
// Returns the number of bytes received; 0 when no buffer was given or recv() failed or returned
// nothing; the socket-error value (converted to size_t) when the connection is not valid.
size_t connection::read(void *ptr, size_t length)
{
  if (!is_valid())
    return detail::socket_error;

  if (ptr == nullptr || length == 0)
    return 0;

  const int received = recv(_p->socket, static_cast<char *>(ptr), static_cast<int>(length), 0);
  const bool failed = (received == 0) || (received == detail::socket_error);
  return failed ? 0 : static_cast<size_t>(received);
}
//---------------------------------------------------------------------------------------------------------------------
// Receives exactly `length` bytes, calling recv() as often as needed.
// Returns false when the connection is invalid or any recv() call fails or returns nothing;
// a null `ptr` is treated as "nothing to do" and reports success.
bool connection::force_read(void *ptr, size_t length)
{
  if (!is_valid())
    return false;

  if (ptr == nullptr)
    return true;

  char *cursor = static_cast<char *>(ptr);
  for (size_t remaining = length; remaining != 0;)
  {
    const int received = recv(_p->socket, cursor, static_cast<int>(remaining), 0);
    if (received == 0 || received == detail::socket_error)
      return false;

    cursor += received;
    remaining -= static_cast<size_t>(received);
  }

  return true;
}
//---------------------------------------------------------------------------------------------------------------------
// Reads one line from the socket, one byte at a time, into `output`.
// The terminating '\n' and any '\r' characters are not stored. Returns false when the
// connection is invalid or recv() fails or returns nothing before a '\n' arrives; `output`
// then holds whatever was read so far.
bool connection::read_line(std::string &output)
{
  if (!is_valid())
    return false;

  output.clear();
  for (;;)
  {
    char ch;
    const int received = recv(_p->socket, &ch, 1, 0);
    if (received == 0 || received == detail::socket_error)
      return false;

    if (ch == '\n')
      return true;

    if (ch != '\r')
      output += ch;
  }
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
// Entry of the server's client list: the client plus a reference count.
// The count is incremented by basic_tcp_server::acquire_clients().
// NOTE(review): presumably decremented by release_clients() and checked before an entry is
// removed - confirm in the part of the file outside this view.
struct basic_tcp_client_ref
{
size_t refCount = 0;
ptr<basic_tcp_client> client;
basic_tcp_client_ref(ptr<basic_tcp_client> c)
: client(c)
{
}
};
// Private state of basic_tcp_server.
struct basic_tcp_server_impl
{
// True between a successful startup and stop().
std::atomic_bool isRunning;
// Set by stop() before it wakes and joins the disconnect thread.
std::atomic_bool disconnectThreadQuit;
// Local address the listening socket is bound to (filled in by the server constructor).
sockaddr_in local;
// Connected clients; lock the object itself (HEADSOCKET_LOCK) before touching `value`.
detail::lockable_value<std::vector<basic_tcp_client_ref>> connections;
// Notified when a client was disconnected and when the server stops.
detail::semaphore disconnectSemaphore;
int port = 0;
detail::socket_type serverSocket = invalid_socket;
std::unique_ptr<std::thread> acceptThread;
std::unique_ptr<std::thread> disconnectThread;
id_t nextClientID = 1;
basic_tcp_server_impl()
{
isRunning = false;
disconnectThreadQuit = false;
}
};
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//---------------------------------------------------------------------------------------------------------------------
// Creates a listening TCP socket on `port` (all local interfaces) and starts the accept and
// disconnect worker threads.
// Failures are not reported by exception: if the socket cannot be created, bound or put into
// listening mode, the server is simply left in the "not running" state (see is_running()).
// NOTE(review): the worker threads start while the derived class is still under construction -
// confirm that accept_thread() cannot reach the pure virtual hooks that early.
basic_tcp_server::basic_tcp_server(int port)
  : _p(std::make_unique<detail::basic_tcp_server_impl>())
{
#ifdef HEADSOCKET_PLATFORM_WINDOWS
  WSADATA wsaData;
  WSAStartup(0x101, &wsaData);
#endif

  // Start from an all-zero address so that the padding bytes of sockaddr_in are defined.
  _p->local = sockaddr_in();
  _p->local.sin_family = AF_INET;
  _p->local.sin_addr.s_addr = INADDR_ANY;
  _p->local.sin_port = htons(static_cast<unsigned short>(port));

  _p->serverSocket = socket(AF_INET, SOCK_STREAM, 0);
  if (_p->serverSocket == detail::invalid_socket)
    return;

  int opt_enable = 1;
  setsockopt(_p->serverSocket, SOL_SOCKET, SO_REUSEADDR, reinterpret_cast<const char *>(&opt_enable), sizeof(opt_enable));

  if (bind(_p->serverSocket, reinterpret_cast<sockaddr *>(&_p->local), sizeof(_p->local)) != 0 ||
      listen(_p->serverSocket, 8) != 0)
  {
    // stop() only closes the socket of a running server, so release it here to avoid a leak.
    detail::close_socket(_p->serverSocket);
    _p->serverSocket = detail::invalid_socket;
    return;
  }

  _p->isRunning = true;
  _p->port = port;

  // Lambdas instead of std::bind: <functional> is only included on Windows.
  _p->acceptThread = std::make_unique<std::thread>([this] { accept_thread(); });
  _p->disconnectThread = std::make_unique<std::thread>([this] { disconnect_thread(); });
}
//---------------------------------------------------------------------------------------------------------------------
// Stops the server (a no-op if it is not running) and, on Windows, balances the WSAStartup()
// call made by the constructor.
basic_tcp_server::~basic_tcp_server()
{
stop();
#ifdef HEADSOCKET_PLATFORM_WINDOWS
WSACleanup();
#endif
}
//---------------------------------------------------------------------------------------------------------------------
// Port passed to the constructor; stays 0 when the listening socket could not be set up.
int basic_tcp_server::port() const { return _p->port; }
//---------------------------------------------------------------------------------------------------------------------
// Shuts the server down. Only the first call after a successful start does anything, because
// the running flag is cleared atomically with exchange().
void basic_tcp_server::stop()
{
if (_p->isRunning.exchange(false))
{
// Close the listening socket first.
// NOTE(review): presumably this is what unblocks accept_thread() - confirm per platform.
detail::close_socket(_p->serverSocket);
{
// Hold a reference on all client entries while they are being disconnected.
acquire_clients();
for (size_t i = 0, S = num_clients(); i < S; ++i)
client_at(i)->disconnect();
release_clients();
}
if (_p->acceptThread)
{
_p->acceptThread->join();
_p->acceptThread = nullptr;
}
if (_p->disconnectThread)
{
// Set the quit flag before notifying so that the woken thread sees it.
_p->disconnectThreadQuit = true;
_p->disconnectSemaphore.notify();
_p->disconnectThread->join();
_p->disconnectThread = nullptr;
}
}
}
//---------------------------------------------------------------------------------------------------------------------
// True after a successful startup until stop() is called.
bool basic_tcp_server::is_running() const { return _p->isRunning; }
//---------------------------------------------------------------------------------------------------------------------
// Disconnects the given client if it belongs to this server.
// Returns true when the client was found in the connection list, false otherwise (including a
// null client).
bool basic_tcp_server::disconnect(ptr<basic_tcp_client> client)
{
  if (!client)
    return false;

  bool isKnown = false;
  {
    HEADSOCKET_LOCK(_p->connections);
    for (const auto &entry : _p->connections.value)
    {
      if (entry.client == client)
      {
        isKnown = true;
        break;
      }
    }
  }

  if (!isKnown)
    return false;

  // client->disconnect() returns false when the client was already disconnected; only then is
  // the disconnect reported and the disconnect semaphore notified.
  if (!client->disconnect())
  {
    client_disconnected(client);
    _p->disconnectSemaphore.notify();
  }

  return true;
}
//---------------------------------------------------------------------------------------------------------------------
// Disconnects the client with the given ID.
// Returns true when a client with that ID was found, false otherwise (an ID of 0 is never
// looked up).
bool basic_tcp_server::disconnect(id_t id)
{
  if (!id)
    return false;

  ptr<basic_tcp_client> match;
  {
    HEADSOCKET_LOCK(_p->connections);
    for (const auto &entry : _p->connections.value)
    {
      if (entry.client->id() == id)
      {
        match = entry.client;
        break;
      }
    }
  }

  if (!match)
    return false;

  // match->disconnect() returns false when the client was already disconnected; only then is
  // the disconnect reported and the disconnect semaphore notified.
  if (!match->disconnect())
  {
    client_disconnected(match);
    _p->disconnectSemaphore.notify();
  }

  return true;
}
//---------------------------------------------------------------------------------------------------------------------
// Returns the client stored at the given position of the connection list, or nullptr when the
// index is out of range.
ptr<basic_tcp_client> basic_tcp_server::client_at(size_t index) const
{
  HEADSOCKET_LOCK(_p->connections);

  if (index >= _p->connections->size())
    return nullptr;

  return _p->connections->at(index).client;
}
//---------------------------------------------------------------------------------------------------------------------
// Returns the number of entries in the connection list, read under the connections lock.
size_t basic_tcp_server::num_clients() const
{
HEADSOCKET_LOCK(_p->connections);
return _p->connections->size();
}
//---------------------------------------------------------------------------------------------------------------------
// Increments the reference count of every connection entry and returns the number of entries.
// remove_disconnected() only erases entries whose reference count is 0.
size_t basic_tcp_server::acquire_clients() const
{
  HEADSOCKET_LOCK(_p->connections);

  auto &entries = _p->connections.value;
  for (size_t index = 0; index < entries.size(); ++index)
    ++entries[index].refCount;

  return _p->connections->size();
}
//---------------------------------------------------------------------------------------------------------------------
// Decrements the reference count of every connection entry, then removes the entries that are
// disconnected and no longer referenced.
void basic_tcp_server::release_clients() const
{
  HEADSOCKET_LOCK(_p->connections);

  auto &entries = _p->connections.value;
  for (size_t index = 0; index < entries.size(); ++index)
    --entries[index].refCount;

  remove_disconnected();
}
//---------------------------------------------------------------------------------------------------------------------
// Erases every connection entry whose client is disconnected and whose reference count is 0,
// calling the client's on_disconnect() just before it is erased.
// Takes no lock itself; release_clients() and disconnect_thread() call it with the connections
// lock held.
void basic_tcp_server::remove_disconnected() const
{
  auto &entries = _p->connections.value;

  for (size_t index = 0; index < entries.size();)
  {
    auto &entry = entries[index];
    const bool removable = !entry.client->is_connected() && entry.refCount == 0;

    if (!removable)
    {
      ++index;
      continue;
    }

    // Erasing shifts the following entries down, so the index is not advanced here.
    entry.client->on_disconnect();
    entries.erase(entries.begin() + index);
  }
}
//---------------------------------------------------------------------------------------------------------------------
// Body of the accept thread: accepts incoming connections while the server is running.
// Each accepted socket goes through handshake() and accept(); on success the new client is
// appended to the connection list and reported via client_connected(), otherwise the socket is
// closed and the client ID taken for it is given back.
void basic_tcp_server::accept_thread()
{
detail::set_thread_name("BaseTcpServer::acceptThread");
while (_p->isRunning)
{
detail::connection_impl conn_impl;
// NOTE(review): presumably blocks until a peer connects or stop() closes serverSocket - confirm.
conn_impl.socket = ::accept(_p->serverSocket, reinterpret_cast<struct sockaddr *>(&conn_impl.from), nullptr);
// Take the next client ID; the extra increment skips 0 when the counter wraps
// (disconnect(id_t) ignores an ID of 0).
// Note the ID is taken before the socket is checked, so a failed ::accept() also uses one up.
conn_impl.id = _p->nextClientID++;
if (!_p->nextClientID)
++_p->nextClientID;
// The server may have been stopped while ::accept() was waiting.
if (!_p->isRunning)
break;
if (conn_impl.socket != detail::invalid_socket)
{
connection conn(conn_impl);
ptr<basic_tcp_client> newClient;
bool failed = false;
if (handshake(conn))
{
// accept() yields the client object for this connection; a null result counts as a failure.
if ((newClient = accept(conn)))
{
newClient->on_accept();
HEADSOCKET_LOCK(_p->connections);
_p->connections->push_back(newClient);
}
else {
failed = true;
}
}
else {
failed = true;
}
if (failed)
{
conn_impl.close();
// Give the ID back, again skipping 0.
--_p->nextClientID;
if (!_p->nextClientID)
--_p->nextClientID;
}
else
client_connected(newClient);
}
}
}
//---------------------------------------------------------------------------------------------------------------------
// Body of the disconnect thread: until disconnectThreadQuit is set, repeatedly runs
// remove_disconnected() with the connections lock held and then consumes the semaphore.
// NOTE(review): HEADSOCKET_LOCK(_p->disconnectSemaphore) presumably waits until the semaphore has
// been notified - confirm against detail::semaphore.
void basic_tcp_server::disconnect_thread()
{
detail::set_thread_name("BaseTcpServer::disconnectThread");
while (!_p->disconnectThreadQuit)
{
{
HEADSOCKET_LOCK(_p->disconnectSemaphore);
HEADSOCKET_LOCK(_p->connections);
remove_disconnected();
_p->disconnectSemaphore.consume();
}
}
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {

// Private state of basic_tcp_client.
struct basic_tcp_client_impl
{
  std::atomic_int refCount;
  std::atomic_bool isConnected;            // set by the constructors, cleared by disconnect()
  std::weak_ptr<basic_tcp_server> server;  // owning server; stays empty for outgoing clients
  connection conn = detail::connection_impl();
  std::string address;                     // remote address of an outgoing connection
  int port = 0;                            // remote port of an outgoing connection

  basic_tcp_client_impl()
    : refCount(0)
    , isConnected(false)
  {
  }
};

}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//---------------------------------------------------------------------------------------------------------------------
// Creates an outgoing client and tries to connect it to address:port.
//
// The address is resolved with getaddrinfo() and every returned candidate is tried in order until
// one connects. On failure the object is still constructed but is_connected() stays false.
//
// address: host name or numeric address to connect to.
// port:    TCP port, passed to getaddrinfo() as a decimal string.
basic_tcp_client::basic_tcp_client(const std::string &address, int port)
  : _p(std::make_unique<detail::basic_tcp_client_impl>())
{
  struct addrinfo *result = nullptr, *ptr = nullptr, hints;
  memset(&hints, 0, sizeof(hints));
  hints.ai_family = AF_UNSPEC;
  hints.ai_socktype = SOCK_STREAM;
  hints.ai_protocol = IPPROTO_TCP;

  char buff[16];
  HEADSOCKET_SPRINTF(buff, "%d", port);

  if (getaddrinfo(address.c_str(), buff, &hints, &result))
    return;

  for (ptr = result; ptr != nullptr; ptr = ptr->ai_next)
  {
    _p->conn.impl()->socket = socket(ptr->ai_family, ptr->ai_socktype, ptr->ai_protocol);

    // A socket that cannot be created for this candidate must not end the constructor here:
    // returning would skip freeaddrinfo() below and leak the address list. Try the next one.
    if (!_p->conn.is_valid())
      continue;

    if (connect(_p->conn.impl()->socket, ptr->ai_addr, static_cast<int>(ptr->ai_addrlen)) == detail::socket_error)
    {
      detail::close_socket(_p->conn.impl()->socket);
      _p->conn.impl()->socket = detail::invalid_socket;
      continue;
    }

    break;
  }

  freeaddrinfo(result);

  // No candidate produced a connected socket.
  if (!_p->conn.is_valid())
    return;

  _p->address = address;
  _p->port = port;
  _p->isConnected = true;
}
//---------------------------------------------------------------------------------------------------------------------
// Creates a client for a connection accepted by a server.
// Stores a weak reference to the server, takes over the state of conn via assign() and marks the
// client as connected.
basic_tcp_client::basic_tcp_client(ptr<basic_tcp_server> server, connection &conn)
: _p(std::make_unique<detail::basic_tcp_client_impl>())
{
_p->server = server;
_p->conn.impl()->assign(*(conn.impl()));
_p->isConnected = true;
}
//---------------------------------------------------------------------------------------------------------------------
// Disconnects the client on destruction; disconnect() does nothing if it is already disconnected.
basic_tcp_client::~basic_tcp_client()
{
disconnect();
}
//---------------------------------------------------------------------------------------------------------------------
bool basic_tcp_client::disconnect()
{
bool wasConnected = _p->isConnected.exchange(false);
if (wasConnected)
{
_p->conn.impl()->close();
ptr<basic_tcp_server> s = server();
if (s)
s->disconnect(_p->conn.id());
}
return wasConnected;
}
//---------------------------------------------------------------------------------------------------------------------
// Returns the current value of the isConnected flag.
bool basic_tcp_client::is_connected() const { return _p->isConnected; }
//---------------------------------------------------------------------------------------------------------------------
// Returns the owning server by locking the stored weak reference; the result is empty when no
// server was set or the server no longer exists.
ptr<basic_tcp_server> basic_tcp_client::server() const { return _p->server.lock(); }
//---------------------------------------------------------------------------------------------------------------------
// Returns the ID of the underlying connection.
id_t basic_tcp_client::id() const { return _p->conn.id(); }
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//---------------------------------------------------------------------------------------------------------------------
// Creates an outgoing client; all work is done by the base_t constructor taking (address, port).
tcp_client::tcp_client(const std::string &address, int port)
: base_t(address, port)
{
}
//---------------------------------------------------------------------------------------------------------------------
// Creates a client for a server-side connection; all work is done by the base_t constructor
// taking (server, conn).
tcp_client::tcp_client(ptr<basic_tcp_server> server, connection &conn)
: base_t(server, conn)
{
}
//---------------------------------------------------------------------------------------------------------------------
// Nothing to release here; the body is intentionally empty.
tcp_client::~tcp_client()
{
}
//---------------------------------------------------------------------------------------------------------------------
// Forwards to connection::write() on the client's connection and returns its result.
size_t tcp_client::write(const void *ptr, size_t length) { return _p->conn.write(ptr, length); }
//---------------------------------------------------------------------------------------------------------------------
// Forwards to connection::read() on the client's connection and returns its result.
size_t tcp_client::read(void *ptr, size_t length) { return _p->conn.read(ptr, length); }
//---------------------------------------------------------------------------------------------------------------------
// Forwards to connection::force_write() on the client's connection and returns its result.
bool tcp_client::force_write(const void *ptr, size_t length) { return _p->conn.force_write(ptr, length); }
//---------------------------------------------------------------------------------------------------------------------
// Forwards to connection::force_read() on the client's connection and returns its result.
bool tcp_client::force_read(void *ptr, size_t length) { return _p->conn.force_read(ptr, length); }
//---------------------------------------------------------------------------------------------------------------------
// Forwards to connection::read_line() on the client's connection and returns its result.
bool tcp_client::read_line(std::string &output) { return _p->conn.read_line(output); }
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
// Private state of async_tcp_client.
struct async_tcp_client_impl
{
// Notified by push() after data is queued; the write thread locks on it before writing.
detail::semaphore writeSemaphore;
// Outgoing data blocks queued by push() and drained by async_write_handler().
detail::lockable_value<detail::data_block_buffer> writeBlocks;
// Incoming data blocks filled by async_read_handler() and drained by pop().
detail::lockable_value<detail::data_block_buffer> readBlocks;
// Worker threads; both are created by init_threads().
std::unique_ptr<std::thread> writeThread;
std::unique_ptr<std::thread> readThread;
// Incremented when a worker thread starts and decremented when it finishes.
std::atomic_int threadCounter = { 0 };
};
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//---------------------------------------------------------------------------------------------------------------------
// Creates an outgoing asynchronous client and allocates its private state.
// The worker threads are not started here; that is done by init_threads().
async_tcp_client::async_tcp_client(const std::string &address, int port)
: base_t(address, port)
, _ap(std::make_unique<detail::async_tcp_client_impl>())
{
}
//---------------------------------------------------------------------------------------------------------------------
// Creates an asynchronous client for a server-side connection and allocates its private state.
// The worker threads are not started here; that is done by init_threads().
async_tcp_client::async_tcp_client(ptr<basic_tcp_server> server, connection &conn)
  : base_t(server, conn)
  , _ap(std::make_unique<detail::async_tcp_client_impl>())
{
}
//---------------------------------------------------------------------------------------------------------------------
// Disconnects the client, wakes the write thread and waits for both worker threads to finish.
async_tcp_client::~async_tcp_client()
{
  disconnect();

  // Wake the write thread so it can observe the cleared isConnected flag and exit.
  _ap->writeSemaphore.notify();

  // The threads only exist after init_threads() has run; the constructors do not create them,
  // so both pointers may still be null here.
  if (_ap->writeThread && _ap->writeThread->joinable())
    _ap->writeThread->join();

  if (_ap->readThread && _ap->readThread->joinable())
    _ap->readThread->join();
}
//---------------------------------------------------------------------------------------------------------------------
// Queues one outgoing data block tagged with the given opcode and notifies the write semaphore.
// A null ptr is ignored and nothing is queued.
void async_tcp_client::push(const void *ptr, size_t length, opcode opcode)
{
  if (ptr)
  {
    {
      // Keep the queue locked only while the block is being appended.
      HEADSOCKET_LOCK(_ap->writeBlocks);
      _ap->writeBlocks->block_begin(opcode);
      _ap->writeBlocks->write(ptr, length);
      _ap->writeBlocks->block_end();
    }

    _ap->writeSemaphore.notify();
  }
}
//---------------------------------------------------------------------------------------------------------------------
// Queues a block of raw bytes, tagged as opcode::binary.
void async_tcp_client::push(const void *ptr, size_t length)
{
push(ptr, length, opcode::binary);
}
//---------------------------------------------------------------------------------------------------------------------
// Queues the characters of text (without a terminating null) as a block tagged opcode::text.
void async_tcp_client::push(const std::string &text)
{
push(text.c_str(), text.length(), opcode::text);
}
//---------------------------------------------------------------------------------------------------------------------
// Returns readBlocks->peek(nullptr), evaluated under the read-queue lock.
// NOTE(review): presumably the size of the next received block, 0 when there is none - confirm
// against data_block_buffer::peek.
size_t async_tcp_client::peek() const
{
HEADSOCKET_LOCK(_ap->readBlocks);
return _ap->readBlocks->peek(nullptr);
}
//---------------------------------------------------------------------------------------------------------------------
// Reads received data from the read queue into ptr.
// Returns invalid_operation for a null ptr, 0 for a zero length, otherwise the result of
// readBlocks->read(ptr, length).
size_t async_tcp_client::pop(void *ptr, size_t length)
{
  if (ptr == nullptr)
    return invalid_operation;

  if (length == 0)
    return 0;

  HEADSOCKET_LOCK(_ap->readBlocks);
  const size_t bytesRead = _ap->readBlocks->read(ptr, length);
  return bytesRead;
}
//---------------------------------------------------------------------------------------------------------------------
void async_tcp_client::init_threads()
{
_ap->threadCounter = 0;
_ap->writeThread = std::make_unique<std::thread>(std::bind(&async_tcp_client::write_thread, this));
_ap->readThread = std::make_unique<std::thread>(std::bind(&async_tcp_client::read_thread, this));
}
//---------------------------------------------------------------------------------------------------------------------
// Body of the write thread: while the client is connected, asks async_write_handler() to fill a
// buffer with outgoing bytes and sends them over the socket.
// On exit it calls kill_threads() and decrements the thread counter.
void async_tcp_client::write_thread()
{
++_ap->threadCounter;
detail::set_thread_name("AsyncTcpClient::writeThread");
std::vector<uint8_t> buffer(1024 * 1024);
while (_p->isConnected)
{
size_t written = 0;
{
// NOTE(review): presumably waits here until writeSemaphore has been notified - confirm against
// detail::semaphore.
HEADSOCKET_LOCK(_ap->writeSemaphore);
// The client may have been disconnected while this thread was waiting.
if (!_p->isConnected)
break;
written = async_write_handler(buffer.data(), buffer.size());
}
if (written == invalid_operation)
break;
// A result of 0 is treated as "buffer too small": the buffer is doubled and the handler retried.
// NOTE(review): the base async_write_handler() also returns 0 when nothing is queued, which
// would grow the buffer as well - verify.
if (!written)
buffer.resize(buffer.size() * 2);
else
{
// send() may transmit only part of the data, so loop until everything is sent.
const char *cursor = reinterpret_cast<const char *>(buffer.data());
while (written)
{
int result = send(_p->conn.impl()->socket, cursor, static_cast<int>(written), 0);
// A failed send only leaves this inner loop; the outer loop keeps running.
if (!result || result == detail::socket_error)
break;
cursor += result;
written -= static_cast<size_t>(result);
}
}
}
kill_threads();
--_ap->threadCounter;
}
//---------------------------------------------------------------------------------------------------------------------
// Copies queued outgoing bytes into the buffer handed over by the write thread.
// ptr/length describe the destination buffer. Returns the number of bytes copied, which is the
// smaller of length and the size reported by writeBlocks->peek().
size_t async_tcp_client::async_write_handler(uint8_t *ptr, size_t length)
{
  HEADSOCKET_LOCK(_ap->writeBlocks);

  const size_t pending = _ap->writeBlocks->peek(nullptr);
  const size_t chunk = pending < length ? pending : length;

  _ap->writeBlocks->read(ptr, chunk);

  // Everything reported by peek() fitted into the buffer, so consume the semaphore.
  if (chunk == pending)
    _ap->writeSemaphore.consume();

  return chunk;
}
//---------------------------------------------------------------------------------------------------------------------
// Stores the received bytes as one opcode::binary block in the read queue.
// Returns length, i.e. all passed bytes are always consumed.
size_t async_tcp_client::async_read_handler(uint8_t *ptr, size_t length)
{
HEADSOCKET_LOCK(_ap->readBlocks);
_ap->readBlocks->block_begin(opcode::binary);
_ap->readBlocks->write(ptr, length);
_ap->readBlocks->block_end();
return length;
}
//---------------------------------------------------------------------------------------------------------------------
// Body of the read thread: receives bytes from the socket into a growable buffer and hands them
// to async_read_handler() until the client disconnects or the connection fails.
//
// The handler's result is interpreted as: the number of bytes consumed from the front of the
// buffer, 0 to request more data, or invalid_operation to stop the thread.
// On exit it calls kill_threads() and decrements the thread counter.
void async_tcp_client::read_thread()
{
  ++_ap->threadCounter;
  detail::set_thread_name("AsyncTcpClient::readThread");

  std::vector<uint8_t> buffer(1024 * 1024);
  size_t bufferBytes = 0, consumed = 0;

  while (_p->isConnected)
  {
    while (true)
    {
      int result = static_cast<int>(bufferBytes);

      // Receive only when the buffer is empty or the handler asked for more data; otherwise
      // let the handler work on the bytes left over from the previous round first.
      if (!result || !consumed)
      {
        result = recv(
          _p->conn.impl()->socket,
          reinterpret_cast<char *>(buffer.data() + bufferBytes),
          static_cast<int>(buffer.size() - bufferBytes),
          0);

        // 0 means the peer closed the connection, socket_error means the receive failed.
        if (!result || result == detail::socket_error)
        {
          consumed = invalid_operation;
          break;
        }

        bufferBytes += static_cast<size_t>(result);
      }

      consumed = async_read_handler(buffer.data(), bufferBytes);
      if (!consumed)
      {
        // The handler needs more data but the buffer is full: grow it before receiving again.
        if (bufferBytes == buffer.size())
          buffer.resize(buffer.size() * 2);
      }
      else
        break;
    }

    if (consumed == invalid_operation)
      break;

    // Move the unconsumed tail to the front of the buffer. Source and destination overlap
    // whenever more bytes remain than were consumed, so memmove is required (memcpy on
    // overlapping ranges is undefined behaviour).
    bufferBytes -= consumed;
    if (bufferBytes)
      memmove(buffer.data(), buffer.data() + consumed, bufferBytes);
  }

  kill_threads();
  --_ap->threadCounter;
}
//---------------------------------------------------------------------------------------------------------------------
// Called by a worker thread when it finishes (and on a WebSocket close frame): disconnects the
// client so the other worker's loop ends as well.
void async_tcp_client::kill_threads()
{
// When called from the read thread, also notify the write semaphore that the write thread
// locks on.
// NOTE(review): readThread is dereferenced without a null check - confirm it is always assigned
// before either worker can get here.
if (std::this_thread::get_id() == _ap->readThread->get_id())
_ap->writeSemaphore.notify();
disconnect();
}
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
//---------------------------------------------------------------------------------------------------------------------
// Creates an outgoing WebSocket client; all work is done by the base_t constructor taking
// (address, port).
web_socket_client::web_socket_client(const std::string &address, int port)
: base_t(address, port)
{
}
//---------------------------------------------------------------------------------------------------------------------
// Creates a WebSocket client for a server-side connection; all work is done by the base_t
// constructor taking (server, conn).
web_socket_client::web_socket_client(ptr<basic_tcp_server> server, connection &conn)
: base_t(server, conn)
{
}
//---------------------------------------------------------------------------------------------------------------------
// Nothing to release here; the body is intentionally empty.
web_socket_client::~web_socket_client()
{
}
//---------------------------------------------------------------------------------------------------------------------
// Returns readBlocks->peek(op), evaluated under the read-queue lock; op is passed through to the
// block buffer.
// NOTE(review): presumably reports the size of the next received block and stores its opcode in
// *op - confirm against data_block_buffer::peek.
size_t web_socket_client::peek(opcode *op) const
{
HEADSOCKET_LOCK(_ap->readBlocks);
return _ap->readBlocks->peek(op);
}
//---------------------------------------------------------------------------------------------------------------------
// Packs queued outgoing blocks into WebSocket frames inside the buffer ptr/length.
// Each loop iteration writes one unmasked frame (header followed by payload) and stops when the
// buffer has fewer than 16 bytes left or the write queue reports no more data.
// Returns the total number of bytes written to ptr.
size_t web_socket_client::async_write_handler(uint8_t *ptr, size_t length)
{
uint8_t *cursor = ptr;
HEADSOCKET_LOCK(_ap->writeBlocks);
while (length >= 16)
{
opcode op{};
size_t toWrite = _ap->writeBlocks->peek(&op);
// Payload size for this frame: limited by the remaining buffer minus 15 bytes (presumably
// reserved for the frame header - confirm), by frame_size_limit and by the pending data.
size_t toConsume = (length - 15) > frame_size_limit ? frame_size_limit : (length - 15);
toConsume = toConsume > toWrite ? toWrite : toConsume;
frame_header header;
// FIN is set when this frame carries everything that peek() reported as pending.
header.fin = (toWrite - toConsume) == 0;
header.op = op;
header.masked = false;
header.payload_length = toConsume;
size_t headerSize = header.write(cursor, length);
cursor += headerSize;
length -= headerSize;
_ap->writeBlocks->read(cursor, toConsume);
cursor += toConsume;
length -= toConsume;
if (header.fin)
_ap->writeSemaphore.consume();
// Stop once the write queue reports nothing more to send.
if (!_ap->writeBlocks->peek(&op))
break;
}
return cursor - ptr;
}
//---------------------------------------------------------------------------------------------------------------------
// Consumes received bytes as WebSocket frames.
// _payload_size holds the number of payload bytes still expected for the current frame; when it
// is 0 a new frame header is parsed first. Payload bytes are appended to the read queue, and once
// a frame is complete it is unmasked and, if it is final, handled according to its opcode.
// Returns the number of bytes consumed from ptr, 0 when the header is still incomplete, or
// invalid_operation when the header parser reports it.
size_t web_socket_client::async_read_handler(uint8_t *ptr, size_t length)
{
uint8_t *cursor = ptr;
HEADSOCKET_LOCK(_ap->readBlocks);
if (!_payload_size)
{
// Remember the opcode of the message in progress; continuation frames restore it below.
opcode prevOpcode = _current_header.op;
size_t headerSize = _current_header.read(cursor, length);
if (!headerSize)
return 0;
else if (headerSize == invalid_operation)
return invalid_operation;
_payload_size = _current_header.payload_length;
cursor += headerSize;
length -= headerSize;
// Any frame other than a continuation starts a new block in the read queue.
if (_current_header.op != opcode::continuation)
_ap->readBlocks->block_begin(_current_header.op);
else
_current_header.op = prevOpcode;
}
if (_payload_size)
{
// Take as much payload as is available, up to what the frame still expects.
size_t toConsume = length >= _payload_size ? _payload_size : length;
if (toConsume)
{
_ap->readBlocks->write(cursor, toConsume);
_payload_size -= toConsume;
cursor += toConsume;
length -= toConsume;
}
}
if (!_payload_size)
{
if (_current_header.masked)
{
//data_block &db = _ap->readBlocks->blocks.back();
// Unmask in place: the frame's payload is the last payload_length bytes of the buffer.
size_t len = _current_header.payload_length;
detail::utils::xor32(_current_header.masking_key, _ap->readBlocks->buffer.data() + _ap->readBlocks->buffer.size() - len, len);
}
if (_current_header.fin)
{
// NOTE(review): blocks.back() assumes at least one block exists - confirm this holds for a
// continuation frame that arrives without a preceding start frame.
data_block &db = _ap->readBlocks->blocks.back();
switch (_current_header.op)
{
case opcode::ping:
// Answer a ping by queueing a pong with the same payload.
push(_ap->readBlocks->buffer.data() + db.offset, db.length, opcode::pong);
break;
case opcode::text:
// Append a terminating zero byte to text messages and count it in the block length.
_ap->readBlocks->buffer.push_back(0);
++db.length;
break;
case opcode::connection_close:
kill_threads();
break;
default:
break;
}
if (_current_header.op == opcode::text || _current_header.op == opcode::binary)
{
_ap->readBlocks->block_end();
// The block is removed from the queue when async_received_data() returns true.
if (async_received_data(db, _ap->readBlocks->buffer.data() + db.offset, db.length))
_ap->readBlocks->block_remove();
}
}
}
return cursor - ptr;
}
//---------------------------------------------------------------------------------------------------------------------
#define HAVE_ENOUGH_BYTES(num) if (length < num) return 0; else length -= num;
// Parses a WebSocket frame header from the start of ptr into this object.
//
// ptr:    first byte of the frame.
// length: number of bytes available at ptr.
//
// Returns the size of the header in bytes, or 0 when length is too short for the complete header
// (HAVE_ENOUGH_BYTES returns 0 in that case). Fields parsed before the shortage was detected may
// already have been overwritten when 0 is returned.
size_t web_socket_client::frame_header::read(const uint8_t *ptr, size_t length)
{
  const uint8_t *cursor = ptr;

  // Byte 0: FIN flag and opcode. Byte 1: MASK flag and 7-bit length indicator.
  HAVE_ENOUGH_BYTES(2);
  this->fin = ((*cursor) & 0x80) != 0;
  this->op = static_cast<opcode>((*cursor++) & 0x0F);
  this->masked = ((*cursor) & 0x80) != 0;
  uint8_t byte = (*cursor++) & 0x7F;

  if (byte < 126)
    this->payload_length = byte;
  else if (byte == 126)
  {
    // 16-bit extended length. The field is not guaranteed to be aligned for uint16_t, so it is
    // copied out with memcpy instead of being read through a cast pointer.
    HAVE_ENOUGH_BYTES(2);
    uint16_t length16 = 0;
    memcpy(&length16, cursor, sizeof(length16));
    this->payload_length = detail::utils::swap16bits(length16);
    cursor += 2;
  }
  else if (byte == 127)
  {
    // 64-bit extended length; the most significant bit is masked off.
    HAVE_ENOUGH_BYTES(8);
    uint64_t rawLength64 = 0;
    memcpy(&rawLength64, cursor, sizeof(rawLength64));
    uint64_t length64 = detail::utils::swap64bits(rawLength64) & 0x7FFFFFFFFFFFFFFFULL;
    this->payload_length = static_cast<size_t>(length64);
    cursor += 8;
  }

  if (this->masked)
  {
    // The masking key is stored exactly as it appears in the frame (no byte swapping).
    HAVE_ENOUGH_BYTES(4);
    uint32_t key = 0;
    memcpy(&key, cursor, sizeof(key));
    this->masking_key = key;
    cursor += 4;
  }

  return cursor - ptr;
}
//---------------------------------------------------------------------------------------------------------------------
// Serializes this frame header to the start of ptr.
//
// ptr:    destination buffer.
// length: number of bytes available at ptr.
//
// Returns the size of the written header in bytes, or 0 when length is too short
// (HAVE_ENOUGH_BYTES returns 0 in that case; the first bytes may already have been written).
size_t web_socket_client::frame_header::write(uint8_t *ptr, size_t length) const
{
  uint8_t *cursor = ptr;

  // Byte 0: FIN flag and opcode. Byte 1: MASK flag and 7-bit length indicator.
  HAVE_ENOUGH_BYTES(2);
  *cursor = this->fin ? 0x80 : 0x00;
  *cursor++ |= static_cast<uint8_t>(this->op);
  *cursor = this->masked ? 0x80 : 0x00;

  if (this->payload_length < 126)
    *cursor++ |= static_cast<uint8_t>(this->payload_length);
  else if (this->payload_length < 65536)
  {
    // 16-bit extended length. The destination is not guaranteed to be aligned for uint16_t, so
    // the value is stored with memcpy instead of through a cast pointer.
    HAVE_ENOUGH_BYTES(2);
    *cursor++ |= 126;
    const uint16_t length16 = detail::utils::swap16bits(static_cast<uint16_t>(this->payload_length));
    memcpy(cursor, &length16, sizeof(length16));
    cursor += 2;
  }
  else
  {
    // 64-bit extended length.
    HAVE_ENOUGH_BYTES(8);
    *cursor++ |= 127;
    const uint64_t length64 = detail::utils::swap64bits(static_cast<uint64_t>(this->payload_length));
    memcpy(cursor, &length64, sizeof(length64));
    cursor += 8;
  }

  if (this->masked)
  {
    // The masking key is written exactly as stored (no byte swapping).
    HAVE_ENOUGH_BYTES(4);
    const uint32_t key = this->masking_key;
    memcpy(cursor, &key, sizeof(key));
    cursor += 4;
  }

  return cursor - ptr;
}
#undef HAVE_ENOUGH_BYTES
///////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail {
}
//---------------------------------------------------------------------------------------------------------------------
// Serves a single HTTP request on a freshly accepted connection.
//
// Reads the request line, skips the header lines, splits the request target into a path and its
// query parameters and passes them to request(). When request() succeeds a "200 OK" response with
// the produced message is written; otherwise (and always for "favicon.ico") an empty
// "404 Not Found" response is written.
//
// Always returns false.
// NOTE(review): presumably so that the caller closes the connection after the response instead of
// turning it into a client (basic_tcp_server::accept_thread closes the socket when handshake()
// fails) - confirm.
bool http_server::handshake(connection &conn)
{
  std::string requestLine;
  if (!conn.read_line(requestLine))
    return false;

  // The request headers are read and discarded up to the empty line that ends them.
  std::string headerLine;
  while (conn.read_line(headerLine))
  {
    if (headerLine.empty())
      break;
  }

  // NOTE(review): cut_front() presumably removes and returns the leading token of its argument,
  // so the three calls on requestLine must stay in this order (method, target, version) - confirm.
  std::string method = detail::utils::cut_front(requestLine);
  std::string path = detail::utils::url_decode(detail::utils::cut_front(requestLine));

  // Strip one leading and one trailing slash from the path.
  if (!path.empty() && path.front() == '/') path = path.substr(1);
  if (!path.empty() && path.back() == '/') path = path.substr(0, path.length() - 1);

  // NOTE(review): cut_back() presumably splits the query string off the path at '?' - confirm.
  std::string params_get = detail::utils::cut_back(path, '?', false, false);
  std::string version = detail::utils::cut_front(requestLine);

  // Parse "name=value" pairs separated by '&'; parsing stops at the first empty segment.
  parameters_t params;
  std::string param_str;
  while (!(param_str = detail::utils::cut_front(params_get, '&')).empty())
  {
    parameter param;
    param.name = detail::utils::cut_front(param_str, '=');
    param.value = param_str;
    param.integer = atoi(param_str.c_str());
    param.real = atof(param_str.c_str());
    param.boolean = (param.integer != 0) || (param_str == "true");
    params[param.name] = param;
  }

  response resp;
  if (path != "favicon.ico" && request(path, params, resp))
  {
    std::stringstream ss;
    ss << version << " 200 OK\r\n";
    ss << "Content-Type: " << resp.content_type << "\r\n";
    ss << "Content-Length: " << resp.message.length() << "\r\n\r\n";
    ss << resp.message;
    conn.write(ss.str());
  }
  else
  {
    // The header section must be terminated by an empty line; without it the response is
    // malformed. An explicit zero Content-Length marks the body as empty.
    conn.write(version);
    conn.write(" 404 Not Found\r\nContent-Length: 0\r\n\r\n");
  }

  return false;
}
}
#endif
#endif