spicetools/avs/ssl.cpp

1364 lines
42 KiB
C++
Raw Normal View History

2024-08-28 15:10:34 +00:00
#define SECURITY_WIN32
#include <cstring>
#include <iostream>
#include <vector>
#include <windows.h>
#include <schnlsp.h>
#include <sspi.h>
#include "util/logging.h"
#include "ssl.h"
#include "core.h"
#ifndef SP_PROT_TLS1_1_CLIENT
#define SP_PROT_TLS1_1_CLIENT 0x00000200
#endif
#ifndef SP_PROT_TLS1_2_CLIENT
#define SP_PROT_TLS1_2_CLIENT 0x00000800
#endif
#define IO_BUF_INITIAL_CAPACITY 524288
namespace avs {
// These structures are initialized by AVS and our functions are handed pointers to
// them, do not use C++ objects in them!
// Cursor-style byte buffer used for both socket I/O and TLS record staging.
// io_buf_validate() enforces the invariant pos <= limit <= capacity.
struct io_buf {
// start of the backing storage
uint8_t *bytes;
// total number of bytes available at `bytes`
size_t capacity;
// current read/write cursor (offset into `bytes`)
size_t pos;
// end of the active window; reads/writes never go past this offset
size_t limit;
};
// Per-protocol work area. Its size is reported to AVS through
// `proto_work_size` in the protocol descriptor below.
struct core::avs_net_proto_desc_work {
// not read or written anywhere in this file
void *proto_desc_work;
// NOTE(review): presumably slack space in case AVS expects a larger area — confirm
uint32_t padding[16];
};
// Per-socket work area. ssl_allocate_socket() zero-fills it; the remaining
// callbacks keep the TLS session state for one connection in here.
struct core::avs_net_sock_desc_work {
// not read or written anywhere in this file
void *proto_desc_work;
// underlying AVS socket descriptor, created in ssl_initialize_socket()
int sock_fd;
// local address/port recorded by ssl_socket_bind(); echoed back by recvfromv
uint32_t address;
avs_net_port_t port;
// strdup()'d peer host name used as the SSPI target name; freed in ssl_free_socket()
const char *hostname;
// cached copies of the values last applied through ssl_setsockopt()
uint32_t send_timeout;
uint32_t recv_timeout;
bool non_blocking;
// Schannel context, valid once the handshake in ssl_socket_connect() completed
CtxtHandle security_context;
// outbound Schannel credentials acquired in ssl_allocate_socket()
CredHandle credentials_handle;
// header/trailer/max-message sizes queried in tls_begin_buffer()
SecPkgContext_StreamSizes sizes;
// staging buffer for outgoing handshake tokens and encrypted records
struct io_buf send_buf;
// staging buffer for incoming ciphertext (decrypted in place)
struct io_buf recv_buf;
// decrypted bytes inside recv_buf that were not yet handed to the caller
uint8_t *recv_plaintext_tail;
size_t recv_plaintext_count;
// undecrypted bytes (SECBUFFER_EXTRA) left over after the last DecryptMessage
uint8_t *recv_ciphertext_tail;
size_t recv_ciphertext_count;
// set once the peer's TLS shutdown was observed by tls_recv()
bool got_shutdown;
// NOTE(review): presumably slack space in case AVS expects a larger area — confirm
uint32_t padding[16];
};
namespace ssl {
using avs::core::avs_iovec;
using avs::core::avs_net_port_t;
using avs::core::avs_net_poll_fd;
using avs::core::avs_net_poll_fd_opaque;
using avs::core::avs_net_pollfds_size_t;
using avs::core::avs_net_proto_desc_work;
using avs::core::avs_net_size_t;
using avs::core::avs_net_sock_desc_work;
using avs::core::avs_net_timeout_t;
using avs::core::AVS_ERROR_CLASS_NET;
using avs::core::AVS_ERROR_SUBCLASS_NET_TIMEOUT;
using avs::core::AVS_ERROR_SUBCLASS_SC_BADMSG;
using avs::core::AVS_ERROR_SUBCLASS_SC_INVAL;
using avs::core::AVS_NET_POLL_POLLIN;
using avs::core::AVS_NET_PROTOCOL_SSL_TLS_V1_1;
using avs::core::AVS_SO_SNDTIMEO;
using avs::core::AVS_SO_RCVTIMEO;
using avs::core::AVS_SO_NONBLOCK;
using avs::core::AVS_SO_SSL_PROTOCOL;
using avs::core::AVS_SO_SSL_VERIFY_CN;
using avs::core::T_NET_PROTO_ID_DEFAULT;
// Classifies what tls_recv_common() obtained from DecryptMessage.
enum tls_recv_payload {
// initial value before anything was received
TLS_RECV_PAYLOAD_NONE,
// DecryptMessage returned SEC_E_OK (application data may be available)
TLS_RECV_PAYLOAD_DATA,
// DecryptMessage returned SEC_I_CONTEXT_EXPIRED (peer shut the session down)
TLS_RECV_PAYLOAD_SHUTDOWN_TOKEN,
};
// Alignment (in bytes) requested from _aligned_malloc for every io_buf allocation.
static constexpr size_t alignment = 16;
/**
 * Allocates backing storage for `buf` and resets its cursor.
 *
 * On success the buffer covers `initial_capacity` bytes with pos = 0 and
 * limit = capacity.
 *
 * @param buf              descriptor to initialize
 * @param initial_capacity number of bytes to allocate
 * @return 0 on success, -1 if `buf` is null or the allocation failed (in the
 *         latter case `buf` is zero-filled)
 */
static int io_buf_init(struct io_buf *buf, size_t initial_capacity) {
    if (buf == nullptr) {
        return -1;
    }

    auto *storage = reinterpret_cast<uint8_t *>(_aligned_malloc(initial_capacity, alignment));
    if (storage == nullptr) {
        // leave the descriptor in a well-defined, empty state
        memset(buf, 0, sizeof(*buf));
        return -1;
    }

    buf->bytes = storage;
    buf->capacity = initial_capacity;
    buf->pos = 0;
    buf->limit = initial_capacity;

    return 0;
}
/**
 * Sanity-checks a buffer descriptor.
 *
 * @return 0 if `buf` has storage and satisfies pos <= limit <= capacity,
 *         -1 otherwise
 */
static inline int io_buf_validate(const struct io_buf *buf) {
    const bool has_storage = buf != nullptr && buf->bytes != nullptr;
    if (!has_storage) {
        return -1;
    }

    const bool window_ok = buf->pos <= buf->limit && buf->limit <= buf->capacity;
    return window_ok ? 0 : -1;
}
/**
 * Ensures `buf` can hold at least `min_capacity` bytes.
 *
 * The capacity is doubled until it is large enough; the first `pos` bytes are
 * preserved. `pos` and `limit` are left untouched, callers that want to use
 * the new space must raise `limit` themselves.
 *
 * A request that the buffer already satisfies is a successful no-op. It used
 * to be reported as an error, which made callers that merely wanted a minimum
 * size (handshake continuation, stream-size setup) fail spuriously.
 *
 * @param buf          buffer to grow (must pass io_buf_validate)
 * @param min_capacity required capacity in bytes
 * @return 0 on success, -1 on invalid buffer or allocation failure
 */
static int io_buf_grow_to(struct io_buf *buf, size_t min_capacity) {
    const auto err = io_buf_validate(buf);
    if (err != 0) {
        return err;
    }

    // already large enough, nothing to do
    if (min_capacity <= buf->capacity) {
        return 0;
    }

    // start from a non-zero size so that doubling always makes progress
    size_t new_capacity = buf->capacity != 0 ? buf->capacity : alignment;
    constexpr size_t half_max = static_cast<size_t>(-1) / 2;
    while (new_capacity < min_capacity) {
        if (new_capacity > half_max) {
            // doubling would overflow, fall back to the exact request
            new_capacity = min_capacity;
            break;
        }
        new_capacity *= 2;
    }

    auto *new_bytes = reinterpret_cast<uint8_t *>(_aligned_malloc(new_capacity, alignment));
    if (new_bytes == nullptr) {
        return -1;
    }

    // only the bytes before the cursor are considered live
    memcpy(new_bytes, buf->bytes, buf->pos);
    _aligned_free(buf->bytes);

    buf->bytes = new_bytes;
    buf->capacity = new_capacity;

    return 0;
}
/**
 * Doubles the capacity of `buf`.
 *
 * @return 0 on success, non-zero if the buffer is invalid or growing failed
 */
static int io_buf_grow(struct io_buf *buf) {
    if (const auto err = io_buf_validate(buf); err != 0) {
        return err;
    }

    const size_t doubled = buf->capacity * 2;
    return io_buf_grow_to(buf, doubled);
}
/**
 * Copies up to `*nbytes` bytes from `src` to the cursor of `buf`.
 *
 * If fewer bytes fit between `pos` and `limit`, the copy is truncated and
 * `*nbytes` is updated to the number of bytes actually copied.
 *
 * @param buf    destination buffer; `pos` is advanced by the copied amount
 * @param src    source bytes
 * @param nbytes in: bytes requested, out: bytes copied
 * @return 0 on success, -1 on invalid buffer or null arguments
 */
static int io_buf_append(struct io_buf *buf, const void *src, size_t *nbytes) {
    const auto err = io_buf_validate(buf);
    if (err != 0) {
        return err;
    }
    if (src == nullptr || nbytes == nullptr) {
        return -1;
    }

    // clamp the request to the space left in the window
    const size_t room = buf->limit - buf->pos;
    if (room < *nbytes) {
        *nbytes = room;
    }

    uint8_t *dst = buf->bytes + buf->pos;
    memcpy(dst, src, *nbytes);
    buf->pos += *nbytes;

    return 0;
}
/**
 * Switches `buf` from filling to draining: the window becomes [0, old pos).
 *
 * @return 0 on success, -1 if the buffer is invalid
 */
static int io_buf_flip(struct io_buf *buf) {
    if (const auto err = io_buf_validate(buf); err != 0) {
        return err;
    }

    const size_t written = buf->pos;
    buf->pos = 0;
    buf->limit = written;

    return 0;
}
/**
 * Releases the storage of `buf` and resets all fields.
 *
 * Does nothing when `buf` is null or owns no storage, so it may be called
 * more than once on the same descriptor.
 */
static void io_buf_finish(struct io_buf *buf) {
    if (buf == nullptr || buf->bytes == nullptr) {
        return;
    }

    _aligned_free(buf->bytes);

    buf->bytes = nullptr;
    buf->capacity = 0;
    buf->pos = 0;
    buf->limit = 0;
}
/**
 * Performs a single avs_net_recv into the free part of `buf`.
 *
 * @param work socket state providing the descriptor
 * @param buf  destination; `pos` is advanced by the number of bytes received
 * @return 0 if data was received, -2 if the peer closed the connection,
 *         -1 on invalid arguments, a full buffer or a receive error
 */
static int impl_socket_recv(struct avs_net_sock_desc_work *work, struct io_buf *buf) {
    if (work == nullptr || work->sock_fd < 0) {
        return -1;
    }
    if (io_buf_validate(buf) != 0) {
        return -1;
    }

    // no room left to receive into
    if (buf->pos == buf->limit) {
        return -1;
    }

    uint8_t *dst = &buf->bytes[buf->pos];
    const int received = core::avs_net_recv(work->sock_fd, dst, buf->limit - buf->pos);

    if (received < 0) {
        log_warning("avs::ssl", "avs_net_recv failed: 0x{:08x}", received);
        return -1;
    }
    if (received == 0) {
        log_misc("avs::ssl", "connection closed");
        return -2;
    }

    buf->pos += received;

    return 0;
}
static int impl_socket_recv_all(
struct avs_net_sock_desc_work *work,
struct io_buf *buf,
uint32_t recv_timeout)
{
uint8_t old_non_blocking_value = 0;
uint8_t non_blocking = 0;
avs_net_size_t old_non_blocking_value_size = 0;
struct avs_net_poll_fd poll_fds[1] {};
int ret = 0;
if (work == nullptr || work->sock_fd < 0) {
return -1;
}
auto err = io_buf_validate(buf);
if (err != 0) {
return -1;
}
if (buf->pos == buf->limit) {
return -1;
}
poll_fds[0].socket = work->sock_fd;
poll_fds[0].events = AVS_NET_POLL_POLLIN;
non_blocking = 1;
old_non_blocking_value_size = sizeof(old_non_blocking_value);
core::avs_net_getsockopt(work->sock_fd, AVS_SO_NONBLOCK, &old_non_blocking_value, &old_non_blocking_value_size);
core::avs_net_setsockopt(work->sock_fd, AVS_SO_NONBLOCK, &non_blocking, sizeof(non_blocking));
while (true) {
auto result = core::avs_net_poll(poll_fds, std::size(poll_fds), recv_timeout);
if (result < 0) {
log_warning("avs::ssl", "avs_net_poll failed: 0x{:08x}", result);
ret = -1;
goto out;
}
if (!result) {
#if 0
log_warning("avs::ssl", "socket timeout, no data received after {} milliseconds", recv_timeout);
#endif
goto poll_succeeded_or_empty;
}
if (poll_fds[0].r_events & AVS_NET_POLL_POLLIN) {
result = core::avs_net_recv(work->sock_fd, &buf->bytes[buf->pos], buf->limit - buf->pos);
if (result < 0) {
break;
}
if (result == 0) {
log_misc("avs::ssl", "connection closed");
ret = -1;
goto out;
}
#if 0
log_warning("avs::ssl", "socket({}) got {} bytes during handshake", work->sock_fd, result);
#endif
buf->pos += result;
}
}
poll_succeeded_or_empty:
ret = 0;
out:
core::avs_net_setsockopt(work->sock_fd, AVS_SO_NONBLOCK, &old_non_blocking_value, sizeof(old_non_blocking_value));
return ret;
}
/**
 * Sends the window [pos, limit) of `buf` with a single avs_net_send call.
 *
 * Anything other than a complete send is treated as a failure.
 *
 * @param work socket state providing the descriptor
 * @param buf  source; on success `pos` is moved to `limit`
 * @return 0 on success, -1 on invalid arguments or an incomplete/failed send
 */
static int impl_socket_send(struct avs_net_sock_desc_work *work, struct io_buf *buf) {
    if (work == nullptr || work->sock_fd < 0) {
        return -1;
    }
    if (io_buf_validate(buf) != 0) {
        return -1;
    }

    const auto pending = buf->limit - buf->pos;
    auto result = core::avs_net_send(work->sock_fd, &buf->bytes[buf->pos], pending);

#if 0
    log_warning("avs::ssl", "socket({}) sending {} bytes", work->sock_fd, buf->limit - buf->pos);
#endif

    const bool complete = result == static_cast<int>(pending);
    if (!complete) {
        log_warning("avs::ssl", "avs_net_send failed: 0x{:08x}", result);
        return -1;
    }

    buf->pos = buf->limit;

    return 0;
}
/**
 * Queries the Schannel stream sizes for the established context and makes
 * sure the send buffer can hold one maximum-sized record
 * (header + payload + trailer).
 *
 * The buffer is only grown when it is actually too small; a buffer that is
 * already big enough is not an error.
 *
 * @param work socket state with a completed security context
 * @return 0 on success, -1 if the sizes could not be queried or the send
 *         buffer could not be enlarged
 */
static int tls_begin_buffer(struct avs_net_sock_desc_work *work) {
    const auto status = QueryContextAttributes(&work->security_context, SECPKG_ATTR_STREAM_SIZES, &work->sizes);
    if (status != SEC_E_OK) {
        log_warning("avs::ssl", "QueryContextAttributes failed: {}", FMT_HRESULT(status));
        return -1;
    }

    // widen before adding so the sum cannot wrap in 32 bits
    const size_t record_size = static_cast<size_t>(work->sizes.cbHeader)
            + work->sizes.cbTrailer
            + work->sizes.cbMaximumMessage;

    if (work->send_buf.capacity >= record_size) {
        return 0;
    }

    if (io_buf_grow_to(&work->send_buf, record_size) != 0) {
        return -1;
    }

    return 0;
}
static int tls_send_chunk(struct avs_net_sock_desc_work *work, const uint8_t *bytes, size_t nbytes) {
int err = 0;
SecBuffer send_bufs[4];
SecBufferDesc send_vec;
SECURITY_STATUS status = 0;
send_vec.ulVersion = SECBUFFER_VERSION;
send_vec.pBuffers = send_bufs;
send_vec.cBuffers = std::size(send_bufs);
send_bufs[0].BufferType = SECBUFFER_STREAM_HEADER;
send_bufs[0].pvBuffer = &work->send_buf.bytes[0];
send_bufs[0].cbBuffer = work->sizes.cbHeader;
send_bufs[1].BufferType = SECBUFFER_DATA;
send_bufs[1].pvBuffer = &work->send_buf.bytes[work->sizes.cbHeader];
send_bufs[1].cbBuffer = nbytes;
send_bufs[2].BufferType = SECBUFFER_STREAM_TRAILER;
send_bufs[2].pvBuffer = &work->send_buf.bytes[work->sizes.cbHeader + nbytes];
send_bufs[2].cbBuffer = work->sizes.cbTrailer;
send_bufs[3].BufferType = SECBUFFER_EMPTY;
send_bufs[3].pvBuffer = nullptr;
send_bufs[3].cbBuffer = 0;
memcpy(send_bufs[1].pvBuffer, bytes, nbytes);
status = EncryptMessage(&work->security_context, 0, &send_vec, 0);
if (status != SEC_E_OK) {
log_warning("avs::ssl", "EncryptMessage failed: {}", FMT_HRESULT(status));
return -1;
}
work->send_buf.pos = 0;
work->send_buf.limit = send_bufs[0].cbBuffer + send_bufs[1].cbBuffer + send_bufs[2].cbBuffer;
err = impl_socket_send(work, &work->send_buf);
if (err != 0) {
log_warning("avs::ssl", "impl_socket_send failed: {}", err);
}
return 0;
}
/**
 * Sends the window [pos, limit) of `buf` as one or more TLS records, each
 * carrying at most `sizes.cbMaximumMessage` bytes of plaintext.
 *
 * @param work socket state with an established security context
 * @param buf  plaintext source; `pos` is advanced past every chunk sent
 * @return 0 on success, -1 if the buffer is invalid or empty, otherwise the
 *         error reported by tls_send_chunk
 */
static int tls_send(struct avs_net_sock_desc_work *work, struct io_buf *buf) {
    const int valid = io_buf_validate(buf);
    if (valid != 0) {
        return valid;
    }

    // nothing to send is treated as an error
    if (buf->pos == buf->limit) {
        return -1;
    }

    // the window is known to be non-empty here, so at least one chunk is sent
    do {
        size_t chunk = buf->limit - buf->pos;
        if (chunk > work->sizes.cbMaximumMessage) {
            chunk = work->sizes.cbMaximumMessage;
        }

        const int err = tls_send_chunk(work, &buf->bytes[buf->pos], chunk);
        if (err != 0) {
            return err;
        }

        buf->pos += chunk;
    } while (buf->pos < buf->limit);

    return 0;
}
/**
 * Moves pending decrypted bytes from the receive staging area into `buf`.
 *
 * As many bytes as fit into `buf` are copied; the copied region of the
 * staging area is wiped and the pending span is shrunk accordingly.
 *
 * @return 0 on success, otherwise the error reported by io_buf_append
 */
static int tls_recv_dequeue_plaintext(struct avs_net_sock_desc_work *work, struct io_buf *buf) {
    size_t moved = work->recv_plaintext_count;

    const int err = io_buf_append(buf, work->recv_plaintext_tail, &moved);
    if (err != 0) {
        return err;
    }

    // do not leave consumed plaintext lying around in the staging buffer
    SecureZeroMemory(work->recv_plaintext_tail, moved);

    work->recv_plaintext_tail += moved;
    work->recv_plaintext_count -= moved;

    return 0;
}
/**
 * Receives and decrypts the next TLS message.
 *
 * Leftover ciphertext from the previous call is moved to the front of the
 * receive buffer, then data is read from the socket until DecryptMessage
 * stops reporting SEC_E_INCOMPLETE_MESSAGE.
 *
 * On SEC_E_OK the decrypted span and any trailing undecrypted span are
 * recorded in `work` (recv_plaintext_* / recv_ciphertext_*).
 *
 * @param work    socket state with an established security context
 * @param payload out: TLS_RECV_PAYLOAD_DATA or TLS_RECV_PAYLOAD_SHUTDOWN_TOKEN
 *                on success; untouched on failure
 * @return 0 on success, -2 if the peer closed the connection, -1 on any
 *         other receive or decryption failure
 */
static int tls_recv_common(struct avs_net_sock_desc_work *work, enum tls_recv_payload *payload) {
    int err = 0;
    SecBuffer recv_bufs[4];
    SecBufferDesc recv_vec;
    SECURITY_STATUS status = SEC_E_OK;

    // Consolidate any leftover ciphertext at the start of the buffer.
    // There is none on the first call (null tail); passing a null pointer to
    // memmove is undefined even for a zero length, so skip the call then.
    const size_t carried = work->recv_ciphertext_tail != nullptr ? work->recv_ciphertext_count : 0;
    if (carried > 0) {
        memmove(work->recv_buf.bytes, work->recv_ciphertext_tail, carried);
    }
    work->recv_buf.pos = carried;
    work->recv_buf.limit = work->recv_buf.capacity;
    work->recv_ciphertext_tail = nullptr;
    work->recv_ciphertext_count = 0;

    recv_vec.ulVersion = SECBUFFER_VERSION;
    recv_vec.cBuffers = std::size(recv_bufs);
    recv_vec.pBuffers = recv_bufs;

    while (true) {
        // DecryptMessage rewrites the descriptors, so set them up on every attempt
        recv_bufs[0].BufferType = SECBUFFER_DATA;
        recv_bufs[0].pvBuffer = work->recv_buf.bytes;
        recv_bufs[0].cbBuffer = work->recv_buf.pos;

        for (size_t i = 1; i < std::size(recv_bufs); i++) {
            recv_bufs[i].BufferType = SECBUFFER_EMPTY;
            recv_bufs[i].pvBuffer = nullptr;
            recv_bufs[i].cbBuffer = 0;
        }

        status = DecryptMessage(&work->security_context, &recv_vec, 0, nullptr);
        if (status != SEC_E_INCOMPLETE_MESSAGE) {
            break;
        }

        // not enough ciphertext for a full record yet, fetch more
        err = impl_socket_recv(work, &work->recv_buf);
        if (err != 0) {
            if (err != -2) {
                log_warning("avs::ssl", "impl_socket_recv failed: {}", err);
            }
            return err;
        }
    }

    // Deal with whatever it is we received
    switch (status) {
        case SEC_E_OK:

            // Walk buffers and mark up the plaintext and ciphertext span within our
            // own io_buf as appropriate
            for (size_t i = 0; i < std::size(recv_bufs); i++) {
                switch (recv_bufs[i].BufferType) {
                    case SECBUFFER_DATA:
                        work->recv_plaintext_tail = reinterpret_cast<uint8_t *>(recv_bufs[i].pvBuffer);
                        work->recv_plaintext_count = recv_bufs[i].cbBuffer;
                        break;

                    case SECBUFFER_EXTRA:
                        work->recv_ciphertext_tail = reinterpret_cast<uint8_t *>(recv_bufs[i].pvBuffer);
                        work->recv_ciphertext_count = recv_bufs[i].cbBuffer;
                        break;

                    default:
                        break;
                }
            }

            *payload = TLS_RECV_PAYLOAD_DATA;
            return 0;

        case SEC_I_CONTEXT_EXPIRED:
            *payload = TLS_RECV_PAYLOAD_SHUTDOWN_TOKEN;
            return 0;

        default:
            log_warning("avs::ssl", "DecryptMessage failed: {}", FMT_HRESULT(status));

            // discard whatever is in the staging buffer
            work->recv_buf.pos = 0;
            work->recv_buf.limit = 0;
            return -1;
    }
}
/**
 * Receives decrypted application data into `buf`.
 *
 * Plaintext left over from an earlier record is served first; only when none
 * is pending is a new record received and decrypted.
 *
 * @param work socket state with an established security context
 * @param buf  destination window; `pos` is advanced by the bytes delivered
 * @return 0 on success, -2 when the peer shut the TLS session down or closed
 *         the connection, -1 (or another non-zero value) on failure
 */
static int tls_recv(struct avs_net_sock_desc_work *work, struct io_buf *buf) {
    const int valid = io_buf_validate(buf);
    if (valid != 0) {
        return valid;
    }

    // no room to deliver anything
    if (buf->pos == buf->limit) {
        return -1;
    }

    // Try to drain any leftover plaintext in the receive buffer
    if (work->recv_plaintext_count > 0) {
        return tls_recv_dequeue_plaintext(work, buf);
    }

    enum tls_recv_payload kind = TLS_RECV_PAYLOAD_NONE;
    const int err = tls_recv_common(work, &kind);
    if (err != 0) {
        return err;
    }

    if (kind == TLS_RECV_PAYLOAD_DATA) {
        return tls_recv_dequeue_plaintext(work, buf);
    }

    if (kind == TLS_RECV_PAYLOAD_SHUTDOWN_TOKEN) {
        work->got_shutdown = true;
        return -2;
    }

    return -1;
}
/**
 * Waits for the peer's TLS shutdown message.
 *
 * @return 0 if the shutdown was already seen or is the next message received,
 *         -1 if undelivered plaintext is still pending or the next message is
 *         not a shutdown, otherwise the error reported by tls_recv_common
 */
static int tls_recv_shutdown(struct avs_net_sock_desc_work *work) {
    // unread application data means the caller is not done with the stream
    if (work->recv_plaintext_count > 0) {
        return -1;
    }

    if (work->got_shutdown) {
        return 0;
    }

    enum tls_recv_payload kind = TLS_RECV_PAYLOAD_NONE;
    const auto err = tls_recv_common(work, &kind);
    if (err != 0) {
        return err;
    }

    return kind == TLS_RECV_PAYLOAD_SHUTDOWN_TOKEN ? 0 : -1;
}
static int tls_send_shutdown(struct avs_net_sock_desc_work *work) {
static uint32_t tls_shutdown_token = SCHANNEL_SHUTDOWN;
ULONG attrs = 0;
SecBuffer cmd_buf;
SecBufferDesc cmd_vec;
SECURITY_STATUS status = SEC_E_OK;
cmd_vec.ulVersion = SECBUFFER_VERSION;
cmd_vec.pBuffers = &cmd_buf;
cmd_vec.cBuffers = 1;
cmd_buf.BufferType = SECBUFFER_TOKEN;
cmd_buf.pvBuffer = static_cast<void *>(&tls_shutdown_token);
cmd_buf.cbBuffer = sizeof(tls_shutdown_token);
status = ApplyControlToken(&work->security_context, &cmd_vec);
if (status != SEC_E_OK) {
log_warning("avs::ssl", "{}: ApplyControlToken failed: {}", __func__, FMT_HRESULT(status));
return -1;
}
cmd_vec.ulVersion = SECBUFFER_VERSION;
cmd_vec.pBuffers = &cmd_buf;
cmd_vec.cBuffers = 1;
cmd_buf.BufferType = SECBUFFER_TOKEN;
cmd_buf.pvBuffer = work->send_buf.bytes;
cmd_buf.cbBuffer = work->send_buf.capacity;
log_info("avs::ssl", "calling InitializeSecurityContextA to generate token");
status = InitializeSecurityContextA(
&work->credentials_handle,
&work->security_context,
const_cast<SEC_CHAR *>(work->hostname),
0,
0,
0,
nullptr,
0,
nullptr,
&cmd_vec,
&attrs,
nullptr);
if (status != SEC_E_OK) {
log_warning("avs::ssl", "{}: InitializeSecurityContextA failed: {}", __func__, FMT_HRESULT(status));
return -1;
}
work->send_buf.pos = 0;
work->send_buf.limit = cmd_buf.cbBuffer;
auto err = impl_socket_send(work, &work->send_buf);
if (err != 0) {
log_warning("avs::ssl", "impl_socket_send failed: {}", err);
}
return err;
}
/**
 * Protocol-level setup hook. Nothing to set up; always returns 0.
 */
static int ssl_protocol_initialize([[maybe_unused]] struct avs_net_proto_desc_work *work) {
    constexpr int ok = 0;
    return ok;
}
/**
 * Protocol-level teardown hook. Nothing to tear down; always returns 0.
 */
static int ssl_protocol_finalize([[maybe_unused]] struct avs_net_proto_desc_work *work) {
    constexpr int ok = 0;
    return ok;
}
static int ssl_allocate_socket(struct avs_net_sock_desc_work *work) {
SCHANNEL_CRED credentials;
CredHandle credentials_handle;
memset(work, 0, sizeof(*work));
memset(&credentials, 0, sizeof(credentials));
credentials.dwVersion = SCHANNEL_CRED_VERSION;
credentials.cCreds = 0;
credentials.paCred = nullptr;
credentials.dwFlags = SCH_CRED_AUTO_CRED_VALIDATION | SCH_CRED_REVOCATION_CHECK_CHAIN_EXCLUDE_ROOT | SCH_CRED_IGNORE_REVOCATION_OFFLINE;
credentials.grbitEnabledProtocols = SP_PROT_TLS1_CLIENT | SP_PROT_TLS1_1_CLIENT | SP_PROT_TLS1_2_CLIENT;
auto status = AcquireCredentialsHandleA(
nullptr,
const_cast<SEC_CHAR *>(UNISP_NAME),
SECPKG_CRED_OUTBOUND,
nullptr,
&credentials,
nullptr,
nullptr,
&credentials_handle,
nullptr);
if (status != SEC_E_OK) {
log_warning("avs::ssl", "AcquireCredentialsHandleA failed: {}", status);
return core::avs_error_make(AVS_ERROR_CLASS_NET, AVS_ERROR_SUBCLASS_SC_INVAL);
}
work->credentials_handle = credentials_handle;
return 0;
}
static void ssl_free_socket(struct avs_net_sock_desc_work *work) {
if (work != nullptr) {
if (work->hostname != nullptr) {
free(const_cast<void *>(reinterpret_cast<const void *>(work->hostname)));
work->hostname = nullptr;
}
FreeCredentialsHandle(&work->credentials_handle);
}
}
/**
 * Creates the underlying AVS socket for this work area.
 *
 * @return 1 if a socket with a positive descriptor was created and stored in
 *         `work->sock_fd`, 0 otherwise
 */
static int ssl_initialize_socket(struct avs_net_sock_desc_work *work) {
    const auto sock_fd = core::avs_net_socket(T_NET_PROTO_ID_DEFAULT);

    if (sock_fd <= 0) {
        return 0;
    }

    work->sock_fd = sock_fd;
    return 1;
}
/**
 * Socket-level teardown hook. Intentionally does nothing.
 */
static void ssl_finalize_socket([[maybe_unused]] struct avs_net_sock_desc_work *work) {
    // nothing to do here
}
/**
 * Sets a socket option.
 *
 * The SSL specific options are only logged and otherwise ignored. All other
 * options are forwarded to the underlying socket; timeouts and the
 * non-blocking flag are additionally cached in `work`.
 *
 * @return 0 on success, otherwise the negative result of avs_net_setsockopt
 */
static int ssl_setsockopt(
    struct avs_net_sock_desc_work *work,
    unsigned int option_name,
    const void *option_value,
    avs_net_size_t option_len)
{
    // SSL specific options here
    if (option_name == AVS_SO_SSL_PROTOCOL) {
        log_info("avs::ssl", "AVS_SO_SSL_PROTOCOL = {}", *reinterpret_cast<const int *>(option_value));
        return 0;
    }
    if (option_name == AVS_SO_SSL_VERIFY_CN) {
        log_info("avs::ssl", "AVS_SO_SSL_VERIFY_CN = {}", *reinterpret_cast<const int *>(option_value));
        return 0;
    }

    // Generic network options here
    const auto result = core::avs_net_setsockopt(work->sock_fd, option_name, option_value, option_len);
    if (result < 0) {
        return result;
    }

    // mirror the values we need later on
    if (option_name == AVS_SO_SNDTIMEO) {
        work->send_timeout = *reinterpret_cast<const avs_net_timeout_t *>(option_value);
    } else if (option_name == AVS_SO_RCVTIMEO) {
        work->recv_timeout = *reinterpret_cast<const avs_net_timeout_t *>(option_value);
    } else if (option_name == AVS_SO_NONBLOCK) {
        work->non_blocking = *reinterpret_cast<const uint8_t *>(option_value);
    }

    return 0;
}
/**
 * Reads a socket option.
 *
 * Timeouts and the non-blocking flag are answered from the values cached in
 * `work`, the SSL protocol is always reported as TLS 1.1 and CN verification
 * as 0. Everything else is forwarded to the underlying socket.
 *
 * @return 0 for the locally answered options, otherwise the result of
 *         avs_net_getsockopt
 */
static int ssl_socket_getsockopt(
    struct avs_net_sock_desc_work *work,
    unsigned int option_name,
    void *option_value,
    avs_net_size_t *option_len)
{
    // writes `value` to the caller's buffer and reports its size
    const auto answer = [option_value, option_len](auto value) {
        *reinterpret_cast<decltype(value) *>(option_value) = value;
        *option_len = sizeof(value);
        return 0;
    };

    switch (option_name) {
        case AVS_SO_SNDTIMEO:
            return answer(static_cast<avs_net_timeout_t>(work->send_timeout));

        case AVS_SO_RCVTIMEO:
            return answer(static_cast<avs_net_timeout_t>(work->recv_timeout));

        case AVS_SO_NONBLOCK:
            return answer(static_cast<uint8_t>(work->non_blocking));

        case AVS_SO_SSL_PROTOCOL:
            return answer(static_cast<uint32_t>(AVS_NET_PROTOCOL_SSL_TLS_V1_1));

        case AVS_SO_SSL_VERIFY_CN:
            return answer(static_cast<uint8_t>(0));

        default:
            return core::avs_net_getsockopt(work->sock_fd, option_name, option_value, option_len);
    }
}
/**
 * Binds the underlying socket and remembers the address/port when
 * avs_net_bind reports a positive result.
 *
 * @return the result of avs_net_bind, unchanged
 */
static int ssl_socket_bind(
    struct avs_net_sock_desc_work *work,
    uint32_t address,
    avs_net_port_t port)
{
    const auto result = core::avs_net_bind(work->sock_fd, address, port);

    if (!(result > 0)) {
        return result;
    }

    work->address = address;
    work->port = port;

    return result;
}
// Connects the underlying socket to `address`:`port` and performs the TLS
// client handshake through Schannel.
//
// Steps: resolve `address` to a host name (used as the SSPI target name),
// allocate the send/receive staging buffers, connect the socket, then drive
// InitializeSecurityContextA in a loop until it reports SEC_E_OK. On success
// the security context is stored in `work` and the stream sizes are queried.
//
// Returns 0 on success. On failure returns the error from the failing AVS
// call, an AVS NET error code for handshake failures, or -1/-2 from the
// buffer/socket helpers; the staging buffers are released again.
static int ssl_socket_connect(
struct avs_net_sock_desc_work *work,
uint32_t address,
avs_net_port_t port)
{
// context requirements requested from Schannel for every handshake step
constexpr uint32_t security_context_flags = ISC_REQ_CONFIDENTIALITY |
ISC_REQ_INTEGRITY |
ISC_REQ_STREAM |
ISC_REQ_SEQUENCE_DETECT |
ISC_REQ_REPLAY_DETECT;
SecBuffer send_bufs[2] {};
SecBufferDesc send_vec {};
SecBuffer recv_bufs[2] {};
SecBufferDesc recv_vec {};
CtxtHandle security_context;
ULONG attributes = 0;
SECURITY_STATUS status = SEC_E_OK;
int result = 0;
int err = 0;
char hostname[256];
memset(&security_context, 0, sizeof(security_context));
// resolve the address to a name; it is passed to Schannel as the target name
result = core::avs_net_addrinfobyaddr(address, hostname, sizeof(hostname), 1);
if (result < 0) {
log_warning("avs::ssl", "avs_net_addrinfobyaddr failed: {}", FMT_HRESULT(result));
return result;
}
// NOTE(review): the strdup result is not checked for null; it is released in
// ssl_free_socket, not on the failure paths below
work->hostname = strdup(hostname);
err = io_buf_init(&work->send_buf, IO_BUF_INITIAL_CAPACITY);
if (err != 0) {
goto send_buf_fail;
}
err = io_buf_init(&work->recv_buf, IO_BUF_INITIAL_CAPACITY);
if (err != 0) {
goto recv_buf_fail;
}
result = core::avs_net_connect(work->sock_fd, address, port);
if (result < 0) {
log_warning("avs::ssl", "avs_net_connect failed: {}", FMT_HRESULT(result));
err = result;
goto connect_fail;
}
send_vec.ulVersion = SECBUFFER_VERSION;
send_vec.pBuffers = send_bufs;
send_vec.cBuffers = std::size(send_bufs);
recv_vec.ulVersion = SECBUFFER_VERSION;
recv_vec.pBuffers = recv_bufs;
recv_vec.cBuffers = std::size(recv_bufs);
// the first call has no input; Schannel writes the initial token into send_buf
send_bufs[0].BufferType = SECBUFFER_TOKEN;
send_bufs[0].pvBuffer = work->send_buf.bytes;
send_bufs[0].cbBuffer = work->send_buf.capacity;
send_bufs[1].BufferType = SECBUFFER_EMPTY;
send_bufs[1].pvBuffer = nullptr;
send_bufs[1].cbBuffer = 0;
status = InitializeSecurityContextA(
&work->credentials_handle,
nullptr,
const_cast<SEC_CHAR *>(work->hostname),
security_context_flags,
0,
0,
nullptr,
0,
&security_context,
&send_vec,
&attributes,
nullptr);
// anything but "continue" on the very first step is treated as a failure
if (status != SEC_I_CONTINUE_NEEDED) {
log_warning("ssl", "{}: InitializeSecurityContextA failed: {}", __func__, FMT_HRESULT(status));
err = core::avs_error_make(AVS_ERROR_CLASS_NET, AVS_ERROR_SUBCLASS_SC_BADMSG);
goto first_isc_fail;
}
// handshake loop: react to the last status, then call
// InitializeSecurityContextA again with the data gathered so far
while (status != SEC_E_OK) {
switch (status) {
case SEC_I_CONTINUE_NEEDED:
// send the token Schannel produced, then collect the peer's reply
work->send_buf.pos = 0;
work->send_buf.limit = send_bufs[0].cbBuffer;
// Only send data if we need to
if (send_bufs[0].cbBuffer > 0) {
err = impl_socket_send(work, &work->send_buf);
if (err != 0) {
goto loop_fail;
}
}
// start with an empty receive buffer and drain what the peer sent
work->recv_buf.pos = 0;
work->recv_buf.limit = work->recv_buf.capacity;
err = impl_socket_recv_all(work, &work->recv_buf, 100);
if (err != 0) {
goto loop_fail;
}
break;
case SEC_E_INCOMPLETE_MESSAGE:
// Schannel needs more input; the second buffer is expected to carry the
// number of missing bytes (SECBUFFER_MISSING)
if (recv_bufs[1].BufferType != SECBUFFER_MISSING) {
err = core::avs_error_make(AVS_ERROR_CLASS_NET, AVS_ERROR_SUBCLASS_SC_BADMSG);
goto loop_fail;
}
err = io_buf_grow_to(&work->recv_buf, work->recv_buf.pos + recv_bufs[1].cbBuffer);
if (err != 0) {
goto loop_fail;
}
// append to what was already received instead of starting over
work->recv_buf.limit = work->recv_buf.capacity;
err = impl_socket_recv(work, &work->recv_buf);
if (err != 0) {
goto loop_fail;
}
break;
case SEC_E_BUFFER_TOO_SMALL:
// output token did not fit, retry with a larger send buffer
err = io_buf_grow(&work->send_buf);
if (err != 0) {
goto loop_fail;
}
break;
case SEC_E_WRONG_PRINCIPAL:
case SEC_E_CERT_EXPIRED:
case SEC_E_UNTRUSTED_ROOT:
case SEC_E_ALGORITHM_MISMATCH:
case SEC_E_INCOMPLETE_CREDENTIALS:
log_warning("avs::ssl", "unable to verify server/client certificate: {}", FMT_HRESULT(status));
err = core::avs_error_make(AVS_ERROR_CLASS_NET, AVS_ERROR_SUBCLASS_SC_INVAL);
goto loop_fail;
default:
log_warning("avs::ssl", "TLS handshake failed with status: {}", FMT_HRESULT(status));
err = core::avs_error_make(AVS_ERROR_CLASS_NET, AVS_ERROR_SUBCLASS_SC_BADMSG);
goto loop_fail;
}
// the buffers may have been reallocated above, so rebuild both descriptors
send_bufs[0].BufferType = SECBUFFER_TOKEN;
send_bufs[0].pvBuffer = work->send_buf.bytes;
send_bufs[0].cbBuffer = work->send_buf.capacity;
send_bufs[1].BufferType = SECBUFFER_EMPTY;
send_bufs[1].pvBuffer = nullptr;
send_bufs[1].cbBuffer = 0;
// input is everything received so far: recv_buf[0, pos)
recv_bufs[0].BufferType = SECBUFFER_TOKEN;
recv_bufs[0].pvBuffer = work->recv_buf.bytes;
recv_bufs[0].cbBuffer = work->recv_buf.pos;
recv_bufs[1].BufferType = SECBUFFER_EMPTY;
recv_bufs[1].pvBuffer = nullptr;
recv_bufs[1].cbBuffer = 0;
status = InitializeSecurityContextA(
&work->credentials_handle,
&security_context,
reinterpret_cast<SEC_CHAR *>(hostname),
security_context_flags,
0,
0,
&recv_vec,
0,
nullptr,
&send_vec,
&attributes,
nullptr);
}
log_misc("avs::ssl", "TLS handshake complete");
work->security_context = security_context;
// NOTE(review): the return value of tls_begin_buffer is ignored here
tls_begin_buffer(work);
return 0;
// cleanup ladder: each label undoes what was set up before the failing step
loop_fail:
DeleteSecurityContext(&security_context);
connect_fail:
first_isc_fail:
io_buf_finish(&work->recv_buf);
recv_buf_fail:
io_buf_finish(&work->send_buf);
send_buf_fail:
return err;
}
/**
 * Listening is not implemented by this protocol; always returns -1.
 */
static int ssl_socket_listen(
    [[maybe_unused]] struct avs_net_sock_desc_work *work,
    [[maybe_unused]] int backlog)
{
    constexpr int unsupported = -1;
    return unsupported;
}
/**
 * Accepting connections is not implemented by this protocol; always
 * returns -1.
 */
static int ssl_socket_accept(
    [[maybe_unused]] struct avs_net_sock_desc_work *work,
    [[maybe_unused]] void *new_sock,
    [[maybe_unused]] uint32_t *address,
    [[maybe_unused]] avs_net_port_t *port)
{
    constexpr int unsupported = -1;
    return unsupported;
}
/**
 * Closes the underlying socket.
 *
 * @return the result of avs_net_close
 */
static int ssl_socket_close(struct avs_net_sock_desc_work *work) {
    const auto sock_fd = work->sock_fd;
    return core::avs_net_close(sock_fd);
}
/**
 * Shuts the TLS session and the underlying socket down.
 *
 * The TLS shutdown exchange is best effort: the peer's shutdown is only
 * awaited if ours could be sent, and failures of either step do not affect
 * the result. The receive buffer is released afterwards.
 *
 * @return the result of avs_net_shutdown
 */
static int ssl_socket_shutdown(struct avs_net_sock_desc_work *work, int how) {
    if (tls_send_shutdown(work) == 0) {
        // outcome intentionally ignored, the socket is shut down either way
        (void) tls_recv_shutdown(work);
    }

    io_buf_finish(&work->recv_buf);

    return core::avs_net_shutdown(work->sock_fd, how);
}
static int ssl_socket_sendtov(
struct avs_net_sock_desc_work *work,
const struct avs_iovec *iovec,
int iov_count,
uint32_t address,
avs_net_port_t port)
{
struct io_buf send_io_buf;
int err = 0;
int result = -1;
int bytes_sent = 0;
for (int i = 0; i < iov_count; i++) {
auto iovp = &iovec[i];
auto iov_len = iovp->iov_len;
err = io_buf_init(&send_io_buf, IO_BUF_INITIAL_CAPACITY);
if (err != 0) {
goto fail;
}
err = io_buf_append(&send_io_buf, iovp->iov_base, &iov_len);
if (err != 0) {
goto fail;
}
err = io_buf_flip(&send_io_buf);
if (err != 0) {
goto fail;
}
result = tls_send(work, &send_io_buf);
if (result < 0) {
log_warning("avs::ssl", "tls_send failed: {}", result);
return result;
}
// Use the original length
bytes_sent += iovp->iov_len;
io_buf_finish(&send_io_buf);
}
result = bytes_sent;
fail:
return result;
}
static int ssl_socket_recvfromv(
struct avs_net_sock_desc_work *work,
struct avs_iovec *iovec,
int iov_count,
uint32_t *address,
avs_net_port_t *port)
{
struct io_buf recv_io_buf;
int result = -1;
int bytes_received = 0;
for (int i = 0; i < iov_count; i++) {
auto iovp = &iovec[i];
auto iov_len = iovp->iov_len;
if (!iov_len) {
continue;
}
recv_io_buf.bytes = reinterpret_cast<uint8_t *>(iovp->iov_base);
recv_io_buf.pos = 0;
recv_io_buf.limit = iov_len;
recv_io_buf.capacity = iov_len;
result = tls_recv(work, &recv_io_buf);
if (result < 0) {
// connection closed returns -2, convert to -1
if (result == -2) {
result = -1;
} else {
log_warning("avs::ssl", "{}: tls_recv failed: {}", __func__, result);
}
return result;
}
bytes_received += recv_io_buf.pos;
}
if (address != nullptr) {
*address = work->address;
}
if (port != nullptr) {
*port = work->port;
}
#if 0
log_warning("avs::ssl", "socket({}) received {} bytes", work->sock_fd, bytes_received);
#endif
return bytes_received;
}
/**
 * Forwards to avs_net_pollfds_add for the underlying socket.
 *
 * @return the result of avs_net_pollfds_add, or an AVS NET/TIMEOUT error if
 *         the work area holds a negative descriptor
 */
static int ssl_socket_pollfds_add(
    struct avs_net_sock_desc_work *work,
    struct avs_net_poll_fd_opaque *fds,
    avs_net_pollfds_size_t fds_size,
    struct avs_net_poll_fd *events)
{
    if (work->sock_fd >= 0) {
        return core::avs_net_pollfds_add(work->sock_fd, fds, fds_size, events);
    }

    return core::avs_error_make(AVS_ERROR_CLASS_NET, AVS_ERROR_SUBCLASS_NET_TIMEOUT);
}
/**
 * Forwards to avs_net_pollfds_get for the underlying socket.
 *
 * @return the result of avs_net_pollfds_get, or an AVS NET/TIMEOUT error if
 *         the work area holds a negative descriptor
 */
static int ssl_socket_pollfds_get(
    struct avs_net_sock_desc_work *work,
    struct avs_net_poll_fd *events,
    struct avs_net_poll_fd_opaque *fds)
{
    if (work->sock_fd >= 0) {
        return core::avs_net_pollfds_get(work->sock_fd, events, fds);
    }

    return core::avs_error_make(AVS_ERROR_CLASS_NET, AVS_ERROR_SUBCLASS_NET_TIMEOUT);
}
/**
 * Queries either end of the connection.
 *
 * @param peer_name true to query the remote end (avs_net_get_peername),
 *                  false for the local end (avs_net_get_sockname)
 * @return the result of the forwarded call
 */
static int ssl_socket_sockpeer(
    struct avs_net_sock_desc_work *work,
    bool peer_name,
    uint32_t *address,
    avs_net_port_t *port)
{
    if (!peer_name) {
        return core::avs_net_get_sockname(work->sock_fd, address, port);
    }

    return core::avs_net_get_peername(work->sock_fd, address, port);
}
// Callback table handed to AVS; wires every protocol operation to the
// ssl_* implementation defined above.
static struct core::avs_net_protocol_ops ssl_protocol_ops {
.protocol_initialize = ssl_protocol_initialize,
.protocol_finalize = ssl_protocol_finalize,
.allocate_socket = ssl_allocate_socket,
.free_socket = ssl_free_socket,
.initialize_socket = ssl_initialize_socket,
.finalize_socket = ssl_finalize_socket,
.setsockopt = ssl_setsockopt,
.getsockopt = ssl_socket_getsockopt,
.bind = ssl_socket_bind,
.connect = ssl_socket_connect,
// listen/accept are stubs that always fail
.listen = ssl_socket_listen,
.accept = ssl_socket_accept,
.close = ssl_socket_close,
.shutdown = ssl_socket_shutdown,
.sendtov = ssl_socket_sendtov,
.recvfromv = ssl_socket_recvfromv,
.pollfds_add = ssl_socket_pollfds_add,
.pollfds_get = ssl_socket_pollfds_get,
.sockpeer = ssl_socket_sockpeer
};
// Protocol descriptor registered through avs_net_add_protocol() in init()
// (the non-legacy path).
static struct core::avs_net_protocol ssl_protocol {
.ops = &ssl_protocol_ops,
.magic = core::AVS_NET_PROTOCOL_MAGIC,
.protocol_id = SSL_PROTOCOL_ID,
// sizes of the work areas defined at the top of this file
.proto_work_size = sizeof(struct avs_net_proto_desc_work),
.sock_work_size = sizeof(struct avs_net_sock_desc_work),
};
// Protocol descriptor registered through avs_net_add_protocol_legacy() in
// init() when core::VERSION is core::AVSLEGACY.
static struct core::avs_net_protocol_legacy ssl_protocol_legacy {
.ops = &ssl_protocol_ops,
.protocol_id = SSL_PROTOCOL_ID,
// NOTE(review): meaning of this field is not visible here — confirm
.mystery = 0,
// size of the per-socket work area defined at the top of this file
.sz_work = sizeof(struct avs_net_sock_desc_work),
};
/**
 * Registers the Schannel based SSL protocol with AVS.
 *
 * Any protocol already registered under SSL_PROTOCOL_ID is removed first.
 * The legacy descriptor is used when core::VERSION is core::AVSLEGACY.
 * Does nothing (besides logging) if the required AVS import is missing.
 */
void init() {
    log_info("ssl", "initializing");

    // the registration import is optional, bail out if it was not resolved
    if (!core::avs_net_add_protocol) {
        log_warning("ssl", "missing optional avs imports which are required for this module to work");
        return;
    }

    // replace whatever is currently registered under our protocol id
    core::avs_net_del_protocol(SSL_PROTOCOL_ID);

    // NOTE(review): the results of the registration calls are not captured,
    // so `regist_res` stays 0 and the failure check below never triggers —
    // confirm whether the return values should be assigned here
    int regist_res = 0;

    const bool legacy = core::VERSION == core::AVSLEGACY;
    if (legacy) {
        core::avs_net_add_protocol_legacy(&ssl_protocol_legacy);
    } else {
        core::avs_net_add_protocol(&ssl_protocol);
    }

    if (regist_res != 0) {
        log_fatal("ssl", "failed to register protocol");
    }
}
}
}