diff --git a/common/CMakeLists.txt b/common/CMakeLists.txt index e2e5c60a..920f9f31 100644 --- a/common/CMakeLists.txt +++ b/common/CMakeLists.txt @@ -71,6 +71,7 @@ add_library(${TARGET} STATIC train.cpp log.cpp log.h + http.h ngram-cache.cpp ngram-cache.h ngram-map.cpp diff --git a/common/common.cpp b/common/common.cpp index ac2e3226..4261865a 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -2397,6 +2397,10 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa params.webui = common_webui_from_name(std::string(argv[i])); return true; } + if (arg == "--webui-mcp-proxy" || arg == "--ui-mcp-proxy") { + params.webui_mcp_proxy = true; + return true; + } if (arg == "--api-key") { CHECK_ARG params.api_keys.push_back(argv[i]); @@ -3234,6 +3238,7 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param "- auto: default webui \n" "- llamacpp: llamacpp webui \n" "(default: auto)", }); + options.push_back({ "server", " --ui-mcp-proxy, --webui-mcp-proxy", "experimental: whether to enable MCP CORS proxy - do not enable in untrusted environments (default: disabled)" }); options.push_back({ "server", " --api-key KEY", "API key to use for authentication (default: none)" }); options.push_back({ "server", " --api-key-file FNAME", "path to file containing API keys (default: none)" }); options.push_back({ "server", " --ssl-key-file FNAME", "path to file a PEM-encoded SSL private key" }); diff --git a/common/common.h b/common/common.h index bc68ca0f..1ec0f095 100644 --- a/common/common.h +++ b/common/common.h @@ -501,6 +501,7 @@ struct gpt_params { // "advanced" endpoints are disabled by default for better security common_webui webui = COMMON_WEBUI_AUTO; + bool webui_mcp_proxy = false; bool endpoint_slots = true; bool endpoint_props = false; // only control POST requests, not GET bool endpoint_metrics = false; diff --git a/common/http.h b/common/http.h new file mode 100644 index 00000000..d3daccd6 --- /dev/null +++ b/common/http.h @@ -0,0 +1,99 @@ +#pragma once + +#include + +struct common_http_url { + std::string scheme; + std::string user; + std::string password; + std::string host; + int port; + std::string path; +}; + +static common_http_url common_http_parse_url(const std::string & url) { + common_http_url parts; + auto scheme_end = url.find("://"); + + if (scheme_end == std::string::npos) { + throw std::runtime_error("invalid URL: no scheme"); + } + parts.scheme = url.substr(0, scheme_end); + + if (parts.scheme != "http" && parts.scheme != "https") { + throw std::runtime_error("unsupported URL scheme: " + parts.scheme); + } + + auto rest = url.substr(scheme_end + 3); + auto at_pos = rest.find('@'); + + if (at_pos != std::string::npos) { + auto auth = rest.substr(0, at_pos); + auto colon_pos = auth.find(':'); + if (colon_pos != std::string::npos) { + parts.user = auth.substr(0, colon_pos); + parts.password = auth.substr(colon_pos + 1); + } else { + parts.user = auth; + } + rest = rest.substr(at_pos + 1); + } + + auto slash_pos = rest.find('/'); + + if (slash_pos != std::string::npos) { + parts.host = rest.substr(0, slash_pos); + parts.path = rest.substr(slash_pos); + } else { + parts.host = rest; + parts.path = "/"; + } + + auto colon_pos = parts.host.find(':'); + + if (colon_pos != std::string::npos) { + parts.port = std::stoi(parts.host.substr(colon_pos + 1)); + parts.host = parts.host.substr(0, colon_pos); + } else if (parts.scheme == "http") { + parts.port = 80; + } else if (parts.scheme == "https") { + parts.port = 443; + } else { + throw std::runtime_error("unsupported URL scheme: " + parts.scheme); + } + + return parts; +} + +static std::pair common_http_client(const std::string & url) { + common_http_url parts = common_http_parse_url(url); + + if (parts.host.empty()) { + throw std::runtime_error("error: invalid URL format"); + } + +#ifndef CPPHTTPLIB_OPENSSL_SUPPORT + if (parts.scheme == "https") { + throw std::runtime_error( + "HTTPS is not supported. Please rebuild with one of:\n" + " -DLLAMA_BUILD_BORINGSSL=ON\n" + " -DLLAMA_BUILD_LIBRESSL=ON\n" + " -DLLAMA_OPENSSL=ON (default, requires OpenSSL dev files installed)" + ); + } +#endif + + httplib::Client cli(parts.scheme + "://" + parts.host + ":" + std::to_string(parts.port)); + + if (!parts.user.empty()) { + cli.set_basic_auth(parts.user, parts.password); + } + + cli.set_follow_location(true); + + return { std::move(cli), std::move(parts) }; +} + +static std::string common_http_show_masked_url(const common_http_url & parts) { + return parts.scheme + "://" + (parts.user.empty() ? "" : "****:****@") + parts.host + parts.path; +} diff --git a/examples/server/server-cors-proxy.h b/examples/server/server-cors-proxy.h new file mode 100644 index 00000000..4deb3f3a --- /dev/null +++ b/examples/server/server-cors-proxy.h @@ -0,0 +1,170 @@ +#pragma once + +#include "common.h" +#include "http.h" +#include +#include +#include +#include + +static std::string to_lower_copy(const std::string & value) { + std::string lowered(value.size(), '\0'); + std::transform(value.begin(), value.end(), lowered.begin(), [](unsigned char c) { return std::tolower(c); }); + return lowered; +} + +static httplib::Request prepare_proxy_req_header(const std::string & method, + const std::string & scheme, + const std::string & host, + int port, + const std::string & path, + const std::map & headers, + const std::string & body, + const httplib::FormFiles & files) { + httplib::Request req; + bool has_files = !files.empty(); + req.form.files = files; + std::string effective_body = body; + std::string override_content_type; + req.method = method; + req.path = path; + for (const auto & [key, value] : headers) { + const auto lowered = to_lower_copy(key); + if (lowered == "accept-encoding") { + // disable Accept-Encoding to avoid compressed responses + continue; + } + if (lowered == "transfer-encoding") { + // the body is already decoded + continue; + } + if (lowered == "content-length") { + // let httplib calculate Content-Length from the actual body + continue; + } + if (lowered == "content-type") { + if (has_files) { + // we set our own Content-Type with the new boundary + continue; + } + // when no files but the original request was multipart, + // the body is now JSON, so correct the Content-Type + if (value.find("multipart/form-data") != std::string::npos) { + override_content_type = "application/json; charset=utf-8"; + continue; + } + } + if (lowered == "host") { + bool is_default_port = (scheme == "https" && port == 443) || (scheme == "http" && port == 80); + req.set_header(key, is_default_port ? host : host + ":" + std::to_string(port)); + } else { + req.set_header(key, value); + } + } + req.body = effective_body; + if (!override_content_type.empty()) { + req.set_header("Content-Type", override_content_type); + } + //req.response_handler = response_handler; + //req.content_receiver = content_receiver; + + return req; +} + +static std::string get_param(httplib::Params params,const std::string & key, const std::string & def = "") { + auto it = params.find("url"); + if (it != params.end()) { + return it->second; + } + return def; +} + +static void proxy_request(const httplib::Request & req, + httplib::Response & res, + const std::string & method) { + std::string target_url = get_param(req.params, "url"); + common_http_url parsed_url = common_http_parse_url(target_url); + if (parsed_url.host.empty()) { + throw std::runtime_error("invalid target URL: missing host"); + } + + if (parsed_url.path.empty()) { + parsed_url.path = "/"; + } + + if (!parsed_url.password.empty()) { + throw std::runtime_error("authentication in target URL is not supported"); + } + + if (parsed_url.scheme != "http" && parsed_url.scheme != "https") { + throw std::runtime_error("unsupported URL scheme in target URL: " + parsed_url.scheme); + } + + SRV_INF("proxying %s request to %s://%s:%i%s\n", method.c_str(), parsed_url.scheme.c_str(), parsed_url.host.c_str(), parsed_url.port, parsed_url.path.c_str()); + std::map headers; + for (auto [key, value] : req.headers) { + auto new_key = key; + if (string_starts_with(new_key, "x-proxy-header-")) { + string_replace_all(new_key, "x-proxy-header-", ""); + } + headers[new_key] = value; + } + + httplib::Request proxy_req = prepare_proxy_req_header(method, + parsed_url.scheme, + parsed_url.host, + parsed_url.port, + parsed_url.path, + headers, + req.body, + req.form.files); + + // Make the proxied request + httplib::Result proxy_res; + + if (parsed_url.scheme == "https") { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + httplib::SSLClient cli(parsed_url.host, parsed_url.port); + // set timeouts, follow redirects as needed + cli.set_connection_timeout(600); + cli.set_read_timeout(600); + cli.set_write_timeout(600); + cli.set_follow_location(true); + proxy_res = cli.send(proxy_req); +#else + res.status = 501; + res.set_content("HTTPS not supported (build with OpenSSL)", "text/plain"); + return; +#endif + } else { + httplib::Client cli(parsed_url.host, parsed_url.port); + cli.set_connection_timeout(600); + cli.set_read_timeout(600); + cli.set_write_timeout(600); + proxy_res = cli.send(std::move(proxy_req)); + } + + if (!proxy_res) { + std::string error_data = "Proxy failed: " + httplib::to_string(proxy_res.error()); + json final_response{ {"error", error_data} }; + res.set_content(safe_json_to_str(final_response), "application/json; charset=utf-8"); + res.status = json_value(error_data, "code", 500); + return; + } + + res.status = proxy_res->status; + res.set_content(proxy_res->body, proxy_res->get_header_value("Content-Type")); + for (const auto & h : proxy_res->headers) { + // skip hop-by-hop headers + if (h.first != "Transfer-Encoding" && h.first != "Connection") + res.set_header(h.first, h.second); + } +} + +static void proxy_handler_get(const httplib::Request & req, httplib::Response & res) { + proxy_request(req, res, "GET"); +} + +static void proxy_handler_post(const httplib::Request & req, httplib::Response & res) { + proxy_request(req, res, "POST"); +} diff --git a/examples/server/server.cpp b/examples/server/server.cpp index e7e55634..df2557f8 100644 --- a/examples/server/server.cpp +++ b/examples/server/server.cpp @@ -2,6 +2,7 @@ #include "server-context.h" #include "server-common.h" #include "server-chat.h" +#include "server-cors-proxy.h" #include "chat.h" #include "common.h" @@ -1020,7 +1021,8 @@ int main(int argc, char ** argv) { {"vision", ctx_server.chat_params.allow_image}, {"audio", ctx_server.chat_params.allow_audio}, } }, - { "n_ctx", ctx_server.n_ctx } + { "n_ctx", ctx_server.n_ctx }, + { "cors_proxy_enabled", ctx_server.params_base.webui_mcp_proxy}, }; @@ -2108,6 +2110,16 @@ int main(int argc, char ** argv) { } #endif } + + // CORS proxy (EXPERIMENTAL, only used by the Web UI for MCP) + if (params.webui_mcp_proxy) { + SRV_WRN("%s", "-----------------\n"); + SRV_WRN("%s", "CORS proxy is enabled, do not expose server to untrusted environments\n"); + SRV_WRN("%s", "This feature is EXPERIMENTAL and may be removed or changed in future versions\n"); + SRV_WRN("%s", "-----------------\n"); + svr->Get("/cors-proxy", proxy_handler_get); + svr->Post("/cors-proxy", proxy_handler_post); + } // // Start the server // diff --git a/vendor/cpp-httplib/CMakeLists.txt b/vendor/cpp-httplib/CMakeLists.txt index 3b42fc8c..739a3e36 100644 --- a/vendor/cpp-httplib/CMakeLists.txt +++ b/vendor/cpp-httplib/CMakeLists.txt @@ -22,7 +22,93 @@ target_compile_definitions(${TARGET} PRIVATE CPPHTTPLIB_TCP_NODELAY=1 ) -if (LLAMA_OPENSSL) +set(OPENSSL_NO_ASM ON CACHE BOOL "Disable OpenSSL ASM code when building BoringSSL or LibreSSL") + +if (LLAMA_BUILD_BORINGSSL) + set(FIPS OFF CACHE BOOL "Enable FIPS (BoringSSL)") + + set(BORINGSSL_GIT "https://boringssl.googlesource.com/boringssl" CACHE STRING "BoringSSL git repository") + set(BORINGSSL_VERSION "0.20260508.0" CACHE STRING "BoringSSL version") + + message(STATUS "Fetching BoringSSL version ${BORINGSSL_VERSION}") + + set(BORINGSSL_ARGS + GIT_REPOSITORY ${BORINGSSL_GIT} + GIT_TAG ${BORINGSSL_VERSION} + ) + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + list(APPEND BORINGSSL_ARGS EXCLUDE_FROM_ALL) + endif() + + include(FetchContent) + FetchContent_Declare(boringssl ${BORINGSSL_ARGS}) + + set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) + set(SAVED_BUILD_TESTING ${BUILD_TESTING}) + + set(BUILD_SHARED_LIBS OFF) + set(BUILD_TESTING OFF) + + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + FetchContent_MakeAvailable(boringssl) + else() + FetchContent_GetProperties(boringssl) + if(NOT boringssl_POPULATED) + FetchContent_Populate(boringssl) + add_subdirectory(${boringssl_SOURCE_DIR} ${boringssl_BINARY_DIR} EXCLUDE_FROM_ALL) + endif() + endif() + + set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS}) + set(BUILD_TESTING ${SAVED_BUILD_TESTING}) + + + set(CPPHTTPLIB_OPENSSL_SUPPORT TRUE) + target_link_libraries(${TARGET} PUBLIC ssl crypto) + +elseif (LLAMA_BUILD_LIBRESSL) + set(LIBRESSL_VERSION "4.3.1" CACHE STRING "LibreSSL version") + + message(STATUS "Fetching LibreSSL version ${LIBRESSL_VERSION}") + + set(LIBRESSL_ARGS + URL "https://cdn.openbsd.org/pub/OpenBSD/LibreSSL/libressl-${LIBRESSL_VERSION}.tar.gz" + ) + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.24) + list(APPEND LIBRESSL_ARGS DOWNLOAD_EXTRACT_TIMESTAMP TRUE) + endif() + + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + list(APPEND LIBRESSL_ARGS EXCLUDE_FROM_ALL) + endif() + + include(FetchContent) + FetchContent_Declare(libressl ${LIBRESSL_ARGS}) + + set(SAVED_BUILD_SHARED_LIBS ${BUILD_SHARED_LIBS}) + set(SAVED_BUILD_TESTING ${BUILD_TESTING}) + + set(BUILD_SHARED_LIBS OFF) + set(BUILD_TESTING OFF) + + if(CMAKE_VERSION VERSION_GREATER_EQUAL 3.28) + FetchContent_MakeAvailable(libressl) + else() + FetchContent_GetProperties(libressl) + if(NOT libressl_POPULATED) + FetchContent_Populate(libressl) + add_subdirectory(${libressl_SOURCE_DIR} ${libressl_BINARY_DIR} EXCLUDE_FROM_ALL) + endif() + endif() + + set(BUILD_SHARED_LIBS ${SAVED_BUILD_SHARED_LIBS}) + set(BUILD_TESTING ${SAVED_BUILD_TESTING}) + + + set(CPPHTTPLIB_OPENSSL_SUPPORT TRUE) + target_link_libraries(${TARGET} PUBLIC ssl crypto) + +elseif (LLAMA_OPENSSL) find_package(OpenSSL) if (OpenSSL_FOUND) include(CheckCSourceCompiles) @@ -44,17 +130,51 @@ if (LLAMA_OPENSSL) set(CMAKE_REQUIRED_INCLUDES ${SAVED_CMAKE_REQUIRED_INCLUDES}) if (OPENSSL_VERSION_SUPPORTED) message(STATUS "OpenSSL found: ${OPENSSL_VERSION}") - target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_OPENSSL_SUPPORT) + set(CPPHTTPLIB_OPENSSL_SUPPORT TRUE) target_link_libraries(${TARGET} PUBLIC OpenSSL::SSL OpenSSL::Crypto) - if (APPLE AND CMAKE_SYSTEM_NAME STREQUAL "Darwin") - target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) - find_library(CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED) - find_library(SECURITY_FRAMEWORK Security REQUIRED) - target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK}) - endif() endif() else() - message(STATUS "OpenSSL not found, SSL support disabled") + message(WARNING "OpenSSL not found, HTTPS support disabled") endif() endif() +# disable warnings in 3rd party code +if(LLAMA_BUILD_BORINGSSL OR LLAMA_BUILD_LIBRESSL) + if (CMAKE_CXX_COMPILER_ID STREQUAL "MSVC") + target_compile_options(ssl PRIVATE /w) + target_compile_options(crypto PRIVATE /w) + if(LLAMA_BUILD_BORINGSSL) + target_compile_options(fipsmodule PRIVATE /w) + endif() + if(LLAMA_BUILD_LIBRESSL) + target_compile_options(ssl_obj PRIVATE /w) + target_compile_options(bs_obj PRIVATE /w) + target_compile_options(compat_obj PRIVATE /w) + target_compile_options(crypto_obj PRIVATE /w) + endif() + else() + target_compile_options(ssl PRIVATE -w) + target_compile_options(crypto PRIVATE -w) + if(LLAMA_BUILD_BORINGSSL) + target_compile_options(fipsmodule PRIVATE -w) + endif() + if(LLAMA_BUILD_LIBRESSL) + target_compile_options(ssl_obj PRIVATE -w) + target_compile_options(bs_obj PRIVATE -w) + target_compile_options(compat_obj PRIVATE -w) + target_compile_options(crypto_obj PRIVATE -w) + endif() + endif() +endif() + +if (CPPHTTPLIB_OPENSSL_SUPPORT) + target_compile_definitions(${TARGET} PUBLIC CPPHTTPLIB_OPENSSL_SUPPORT) # used in server.cpp + if (APPLE AND CMAKE_SYSTEM_NAME STREQUAL "Darwin") + find_library(CORE_FOUNDATION_FRAMEWORK CoreFoundation REQUIRED) + find_library(SECURITY_FRAMEWORK Security REQUIRED) + target_link_libraries(${TARGET} PUBLIC ${CORE_FOUNDATION_FRAMEWORK} ${SECURITY_FRAMEWORK}) + endif() + if (WIN32 AND NOT MSVC) + target_link_libraries(${TARGET} PUBLIC crypt32) + endif() +endif() diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index 5432db69..f3555f2d 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -1,15 +1,484 @@ #include "httplib.h" namespace httplib { - /* * Implementation that will be part of the .cc file if split into .h + .cc. */ +namespace stream { + +// stream::Result implementations +Result::Result() : chunk_size_(8192) {} + +Result::Result(ClientImpl::StreamHandle &&handle, size_t chunk_size) + : handle_(std::move(handle)), chunk_size_(chunk_size) {} + +Result::Result(Result &&other) noexcept + : handle_(std::move(other.handle_)), buffer_(std::move(other.buffer_)), + current_size_(other.current_size_), chunk_size_(other.chunk_size_), + finished_(other.finished_) { + other.current_size_ = 0; + other.finished_ = true; +} + +Result &Result::operator=(Result &&other) noexcept { + if (this != &other) { + handle_ = std::move(other.handle_); + buffer_ = std::move(other.buffer_); + current_size_ = other.current_size_; + chunk_size_ = other.chunk_size_; + finished_ = other.finished_; + other.current_size_ = 0; + other.finished_ = true; + } + return *this; +} + +bool Result::is_valid() const { return handle_.is_valid(); } +Result::operator bool() const { return is_valid(); } + +int Result::status() const { + return handle_.response ? handle_.response->status : -1; +} + +const Headers &Result::headers() const { + static const Headers empty_headers; + return handle_.response ? handle_.response->headers : empty_headers; +} + +std::string Result::get_header_value(const std::string &key, + const char *def) const { + return handle_.response ? handle_.response->get_header_value(key, def) : def; +} + +bool Result::has_header(const std::string &key) const { + return handle_.response ? handle_.response->has_header(key) : false; +} + +Error Result::error() const { return handle_.error; } +Error Result::read_error() const { return handle_.get_read_error(); } +bool Result::has_read_error() const { return handle_.has_read_error(); } + +bool Result::next() { + if (!handle_.is_valid() || finished_) { return false; } + + if (buffer_.size() < chunk_size_) { buffer_.resize(chunk_size_); } + + ssize_t n = handle_.read(&buffer_[0], chunk_size_); + if (n > 0) { + current_size_ = static_cast(n); + return true; + } + + current_size_ = 0; + finished_ = true; + return false; +} + +const char *Result::data() const { return buffer_.data(); } +size_t Result::size() const { return current_size_; } + +std::string Result::read_all() { + std::string result; + while (next()) { + result.append(data(), size()); + } + return result; +} + +} // namespace stream + +namespace sse { + +// SSEMessage implementations +SSEMessage::SSEMessage() : event("message") {} + +void SSEMessage::clear() { + event = "message"; + data.clear(); + id.clear(); +} + +// SSEClient implementations +SSEClient::SSEClient(Client &client, const std::string &path) + : client_(client), path_(path) {} + +SSEClient::SSEClient(Client &client, const std::string &path, + const Headers &headers) + : client_(client), path_(path), headers_(headers) {} + +SSEClient::~SSEClient() { stop(); } + +SSEClient &SSEClient::on_message(MessageHandler handler) { + on_message_ = std::move(handler); + return *this; +} + +SSEClient &SSEClient::on_event(const std::string &type, + MessageHandler handler) { + event_handlers_[type] = std::move(handler); + return *this; +} + +SSEClient &SSEClient::on_open(OpenHandler handler) { + on_open_ = std::move(handler); + return *this; +} + +SSEClient &SSEClient::on_error(ErrorHandler handler) { + on_error_ = std::move(handler); + return *this; +} + +SSEClient &SSEClient::set_reconnect_interval(int ms) { + reconnect_interval_ms_ = ms; + return *this; +} + +SSEClient &SSEClient::set_max_reconnect_attempts(int n) { + max_reconnect_attempts_ = n; + return *this; +} + +SSEClient &SSEClient::set_headers(const Headers &headers) { + std::lock_guard lock(headers_mutex_); + headers_ = headers; + return *this; +} + +bool SSEClient::is_connected() const { return connected_.load(); } + +const std::string &SSEClient::last_event_id() const { + return last_event_id_; +} + +void SSEClient::start() { + running_.store(true); + run_event_loop(); +} + +void SSEClient::start_async() { + running_.store(true); + async_thread_ = std::thread([this]() { run_event_loop(); }); +} + +void SSEClient::stop() { + running_.store(false); + client_.stop(); // Cancel any pending operations + if (async_thread_.joinable()) { async_thread_.join(); } +} + +bool SSEClient::parse_sse_line(const std::string &line, SSEMessage &msg, + int &retry_ms) { + // Blank line signals end of event + if (line.empty() || line == "\r") { return true; } + + // Lines starting with ':' are comments (ignored) + if (!line.empty() && line[0] == ':') { return false; } + + // Find the colon separator + auto colon_pos = line.find(':'); + if (colon_pos == std::string::npos) { + // Line with no colon is treated as field name with empty value + return false; + } + + auto field = line.substr(0, colon_pos); + std::string value; + + // Value starts after colon, skip optional single space + if (colon_pos + 1 < line.size()) { + auto value_start = colon_pos + 1; + if (line[value_start] == ' ') { value_start++; } + value = line.substr(value_start); + // Remove trailing \r if present + if (!value.empty() && value.back() == '\r') { value.pop_back(); } + } + + // Handle known fields + if (field == "event") { + msg.event = value; + } else if (field == "data") { + // Multiple data lines are concatenated with newlines + if (!msg.data.empty()) { msg.data += "\n"; } + msg.data += value; + } else if (field == "id") { + // Empty id is valid (clears the last event ID) + msg.id = value; + } else if (field == "retry") { + // Parse retry interval in milliseconds + { + int v = 0; + auto res = + detail::from_chars(value.data(), value.data() + value.size(), v); + if (res.ec == std::errc{}) { retry_ms = v; } + } + } + // Unknown fields are ignored per SSE spec + + return false; +} + +void SSEClient::run_event_loop() { + auto reconnect_count = 0; + + while (running_.load()) { + // Build headers, including Last-Event-ID if we have one + Headers request_headers; + { + std::lock_guard lock(headers_mutex_); + request_headers = headers_; + } + if (!last_event_id_.empty()) { + request_headers.emplace("Last-Event-ID", last_event_id_); + } + + // Open streaming connection + auto result = stream::Get(client_, path_, request_headers); + + // Connection error handling + if (!result) { + connected_.store(false); + if (on_error_) { on_error_(result.error()); } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + continue; + } + + if (result.status() != StatusCode::OK_200) { + connected_.store(false); + if (on_error_) { on_error_(Error::Connection); } + + // For certain errors, don't reconnect. + // Note: 401 is intentionally absent so that handlers can refresh + // credentials via set_headers() and let the client reconnect. + if (result.status() == StatusCode::NoContent_204 || + result.status() == StatusCode::NotFound_404 || + result.status() == StatusCode::Forbidden_403) { + break; + } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + continue; + } + + // Connection successful + connected_.store(true); + reconnect_count = 0; + if (on_open_) { on_open_(); } + + // Event receiving loop + std::string buffer; + SSEMessage current_msg; + + while (running_.load() && result.next()) { + buffer.append(result.data(), result.size()); + + // Process complete lines in the buffer + size_t line_start = 0; + size_t newline_pos; + + while ((newline_pos = buffer.find('\n', line_start)) != + std::string::npos) { + auto line = buffer.substr(line_start, newline_pos - line_start); + line_start = newline_pos + 1; + + // Parse the line and check if event is complete + auto event_complete = + parse_sse_line(line, current_msg, reconnect_interval_ms_); + + if (event_complete && !current_msg.data.empty()) { + // Update last_event_id for reconnection + if (!current_msg.id.empty()) { last_event_id_ = current_msg.id; } + + // Dispatch event to appropriate handler + dispatch_event(current_msg); + + current_msg.clear(); + } + } + + // Keep unprocessed data in buffer + buffer.erase(0, line_start); + } + + // Connection ended + connected_.store(false); + + if (!running_.load()) { break; } + + // Check for read errors + if (result.has_read_error()) { + if (on_error_) { on_error_(result.read_error()); } + } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + } + + connected_.store(false); +} + +void SSEClient::dispatch_event(const SSEMessage &msg) { + // Check for specific event type handler first + auto it = event_handlers_.find(msg.event); + if (it != event_handlers_.end()) { + it->second(msg); + return; + } + + // Fall back to generic message handler + if (on_message_) { on_message_(msg); } +} + +bool SSEClient::should_reconnect(int count) const { + if (!running_.load()) { return false; } + if (max_reconnect_attempts_ == 0) { return true; } // unlimited + return count < max_reconnect_attempts_; +} + +void SSEClient::wait_for_reconnect() { + // Use small increments to check running_ flag frequently + auto waited = 0; + while (running_.load() && waited < reconnect_interval_ms_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + waited += 100; + } +} + +} // namespace sse + +#ifdef CPPHTTPLIB_SSL_ENABLED +/* + * TLS abstraction layer - internal function declarations + * These are implementation details and not part of the public API. + */ +namespace tls { + +// Client context +ctx_t create_client_context(); +void free_context(ctx_t ctx); +bool set_min_version(ctx_t ctx, Version version); +bool load_ca_pem(ctx_t ctx, const char *pem, size_t len); +bool load_ca_file(ctx_t ctx, const char *file_path); +bool load_ca_dir(ctx_t ctx, const char *dir_path); +bool load_system_certs(ctx_t ctx); +bool set_client_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password); +bool set_client_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password); + +// Server context +ctx_t create_server_context(); +bool set_server_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password); +bool set_server_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password); +bool set_client_ca_file(ctx_t ctx, const char *ca_file, const char *ca_dir); +void set_verify_client(ctx_t ctx, bool require); + +// Session management +session_t create_session(ctx_t ctx, socket_t sock); +void free_session(session_t session); +bool set_sni(session_t session, const char *hostname); +bool set_hostname(session_t session, const char *hostname); + +// Handshake (non-blocking capable) +TlsError connect(session_t session); +TlsError accept(session_t session); + +// Handshake with timeout (blocking until timeout) +bool connect_nonblocking(session_t session, socket_t sock, time_t timeout_sec, + time_t timeout_usec, TlsError *err); +bool accept_nonblocking(session_t session, socket_t sock, time_t timeout_sec, + time_t timeout_usec, TlsError *err); + +// I/O (non-blocking capable) +ssize_t read(session_t session, void *buf, size_t len, TlsError &err); +ssize_t write(session_t session, const void *buf, size_t len, TlsError &err); +int pending(const_session_t session); +void shutdown(session_t session, bool graceful); + +// Connection state +bool is_peer_closed(session_t session, socket_t sock); + +// Certificate verification +cert_t get_peer_cert(const_session_t session); +void free_cert(cert_t cert); +bool verify_hostname(cert_t cert, const char *hostname); +uint64_t hostname_mismatch_code(); +long get_verify_result(const_session_t session); + +// Certificate introspection +std::string get_cert_subject_cn(cert_t cert); +std::string get_cert_issuer_name(cert_t cert); +bool get_cert_sans(cert_t cert, std::vector &sans); +bool get_cert_validity(cert_t cert, time_t ¬_before, time_t ¬_after); +std::string get_cert_serial(cert_t cert); +bool get_cert_der(cert_t cert, std::vector &der); +const char *get_sni(const_session_t session); + +// CA store management +ca_store_t create_ca_store(const char *pem, size_t len); +void free_ca_store(ca_store_t store); +bool set_ca_store(ctx_t ctx, ca_store_t store); +size_t get_ca_certs(ctx_t ctx, std::vector &certs); +std::vector get_ca_names(ctx_t ctx); + +// Dynamic certificate update (for servers) +bool update_server_cert(ctx_t ctx, const char *cert_pem, const char *key_pem, + const char *password); +bool update_server_client_ca(ctx_t ctx, const char *ca_pem); + +// Certificate verification callback +bool set_verify_callback(ctx_t ctx, VerifyCallback callback); +long get_verify_error(const_session_t session); +std::string verify_error_string(long error_code); + +// TlsError information +uint64_t peek_error(); +uint64_t get_error(); +std::string error_string(uint64_t code); + +} // namespace tls +#endif // CPPHTTPLIB_SSL_ENABLED + +/* + * Group 1: detail namespace - Non-SSL utilities + */ + namespace detail { +bool set_socket_opt_impl(socket_t sock, int level, int optname, + const void *optval, socklen_t optlen) { + return setsockopt(sock, level, optname, +#ifdef _WIN32 + reinterpret_cast(optval), +#else + optval, +#endif + optlen) == 0; +} + +bool set_socket_opt_time(socket_t sock, int level, int optname, + time_t sec, time_t usec) { +#ifdef _WIN32 + auto timeout = static_cast(sec * 1000 + usec / 1000); +#else + timeval timeout; + timeout.tv_sec = static_cast(sec); + timeout.tv_usec = static_cast(usec); +#endif + return set_socket_opt_impl(sock, level, optname, &timeout, sizeof(timeout)); +} + bool is_hex(char c, int &v) { - if (0x20 <= c && isdigit(c)) { + if (isdigit(c)) { v = c - '0'; return true; } else if ('A' <= c && c <= 'F') { @@ -49,6 +518,92 @@ std::string from_i_to_hex(size_t n) { return ret; } +std::string compute_etag(const FileStat &fs) { + if (!fs.is_file()) { return std::string(); } + + // If mtime cannot be determined (negative value indicates an error + // or sentinel), do not generate an ETag. Returning a neutral / fixed + // value like 0 could collide with a real file that legitimately has + // mtime == 0 (epoch) and lead to misleading validators. + auto mtime_raw = fs.mtime(); + if (mtime_raw < 0) { return std::string(); } + + auto mtime = static_cast(mtime_raw); + auto size = fs.size(); + + return std::string("W/\"") + from_i_to_hex(mtime) + "-" + + from_i_to_hex(size) + "\""; +} + +// Format time_t as HTTP-date (RFC 9110 Section 5.6.7): "Sun, 06 Nov 1994 +// 08:49:37 GMT" This implementation is defensive: it validates `mtime`, checks +// return values from `gmtime_r`/`gmtime_s`, and ensures `strftime` succeeds. +std::string file_mtime_to_http_date(time_t mtime) { + if (mtime < 0) { return std::string(); } + + struct tm tm_buf; +#ifdef _WIN32 + if (gmtime_s(&tm_buf, &mtime) != 0) { return std::string(); } +#else + if (gmtime_r(&mtime, &tm_buf) == nullptr) { return std::string(); } +#endif + char buf[64]; + if (strftime(buf, sizeof(buf), "%a, %d %b %Y %H:%M:%S GMT", &tm_buf) == 0) { + return std::string(); + } + + return std::string(buf); +} + +// Parse HTTP-date (RFC 9110 Section 5.6.7) to time_t. Returns -1 on failure. +time_t parse_http_date(const std::string &date_str) { + struct tm tm_buf; + + // Create a classic locale object once for all parsing attempts + const std::locale classic_locale = std::locale::classic(); + + // Try to parse using std::get_time (C++11, cross-platform) + auto try_parse = [&](const char *fmt) -> bool { + std::istringstream ss(date_str); + ss.imbue(classic_locale); + + memset(&tm_buf, 0, sizeof(tm_buf)); + ss >> std::get_time(&tm_buf, fmt); + + return !ss.fail(); + }; + + // RFC 9110 preferred format (HTTP-date): "Sun, 06 Nov 1994 08:49:37 GMT" + if (!try_parse("%a, %d %b %Y %H:%M:%S")) { + // RFC 850 format: "Sunday, 06-Nov-94 08:49:37 GMT" + if (!try_parse("%A, %d-%b-%y %H:%M:%S")) { + // asctime format: "Sun Nov 6 08:49:37 1994" + if (!try_parse("%a %b %d %H:%M:%S %Y")) { + return static_cast(-1); + } + } + } + +#ifdef _WIN32 + return _mkgmtime(&tm_buf); +#elif defined _AIX + return mktime(&tm_buf); +#else + return timegm(&tm_buf); +#endif +} + +bool is_weak_etag(const std::string &s) { + // Check if the string is a weak ETag (starts with 'W/"') + return s.size() > 3 && s[0] == 'W' && s[1] == '/' && s[2] == '"'; +} + +bool is_strong_etag(const std::string &s) { + // Check if the string is a strong ETag (starts and ends with '"', at least 2 + // chars) + return s.size() >= 2 && s[0] == '"' && s.back() == '"'; +} + size_t to_utf8(int code, char *buff) { if (code < 0x0080) { buff[0] = static_cast(code & 0x7F); @@ -81,6 +636,56 @@ size_t to_utf8(int code, char *buff) { return 0; } +} // namespace detail + +namespace ws { +namespace impl { + +bool is_valid_utf8(const std::string &s) { + size_t i = 0; + auto n = s.size(); + while (i < n) { + auto c = static_cast(s[i]); + size_t len; + uint32_t cp; + if (c < 0x80) { + i++; + continue; + } else if ((c & 0xE0) == 0xC0) { + len = 2; + cp = c & 0x1F; + } else if ((c & 0xF0) == 0xE0) { + len = 3; + cp = c & 0x0F; + } else if ((c & 0xF8) == 0xF0) { + len = 4; + cp = c & 0x07; + } else { + return false; + } + if (i + len > n) { return false; } + for (size_t j = 1; j < len; j++) { + auto b = static_cast(s[i + j]); + if ((b & 0xC0) != 0x80) { return false; } + cp = (cp << 6) | (b & 0x3F); + } + // Overlong encoding check + if (len == 2 && cp < 0x80) { return false; } + if (len == 3 && cp < 0x800) { return false; } + if (len == 4 && cp < 0x10000) { return false; } + // Surrogate halves (U+D800..U+DFFF) and beyond U+10FFFF are invalid + if (cp >= 0xD800 && cp <= 0xDFFF) { return false; } + if (cp > 0x10FFFF) { return false; } + i += len; + } + return true; +} + +} // namespace impl +} // namespace ws + +namespace detail { + // NOTE: This code came up with the following stackoverflow post: // https://stackoverflow.com/questions/180947/base64-decode-snippet-in-c std::string base64_encode(const std::string &in) { @@ -111,6 +716,278 @@ std::string base64_encode(const std::string &in) { return out; } +std::string sha1(const std::string &input) { + // RFC 3174 SHA-1 implementation + auto left_rotate = [](uint32_t x, uint32_t n) -> uint32_t { + return (x << n) | (x >> (32 - n)); + }; + + uint32_t h0 = 0x67452301; + uint32_t h1 = 0xEFCDAB89; + uint32_t h2 = 0x98BADCFE; + uint32_t h3 = 0x10325476; + uint32_t h4 = 0xC3D2E1F0; + + // Pre-processing: adding padding bits + std::string msg = input; + uint64_t original_bit_len = static_cast(msg.size()) * 8; + msg.push_back(static_cast(0x80)); + while (msg.size() % 64 != 56) { + msg.push_back(0); + } + + // Append original length in bits as 64-bit big-endian + for (int i = 56; i >= 0; i -= 8) { + msg.push_back(static_cast((original_bit_len >> i) & 0xFF)); + } + + // Process each 512-bit chunk + for (size_t offset = 0; offset < msg.size(); offset += 64) { + uint32_t w[80]; + + for (size_t i = 0; i < 16; i++) { + w[i] = + (static_cast(static_cast(msg[offset + i * 4])) + << 24) | + (static_cast(static_cast(msg[offset + i * 4 + 1])) + << 16) | + (static_cast(static_cast(msg[offset + i * 4 + 2])) + << 8) | + (static_cast( + static_cast(msg[offset + i * 4 + 3]))); + } + + for (int i = 16; i < 80; i++) { + w[i] = left_rotate(w[i - 3] ^ w[i - 8] ^ w[i - 14] ^ w[i - 16], 1); + } + + uint32_t a = h0, b = h1, c = h2, d = h3, e = h4; + + for (int i = 0; i < 80; i++) { + uint32_t f, k; + if (i < 20) { + f = (b & c) | ((~b) & d); + k = 0x5A827999; + } else if (i < 40) { + f = b ^ c ^ d; + k = 0x6ED9EBA1; + } else if (i < 60) { + f = (b & c) | (b & d) | (c & d); + k = 0x8F1BBCDC; + } else { + f = b ^ c ^ d; + k = 0xCA62C1D6; + } + + uint32_t temp = left_rotate(a, 5) + f + e + k + w[i]; + e = d; + d = c; + c = left_rotate(b, 30); + b = a; + a = temp; + } + + h0 += a; + h1 += b; + h2 += c; + h3 += d; + h4 += e; + } + + // Produce the final hash as a 20-byte binary string + std::string hash(20, '\0'); + for (size_t i = 0; i < 4; i++) { + hash[i] = static_cast((h0 >> (24 - i * 8)) & 0xFF); + hash[4 + i] = static_cast((h1 >> (24 - i * 8)) & 0xFF); + hash[8 + i] = static_cast((h2 >> (24 - i * 8)) & 0xFF); + hash[12 + i] = static_cast((h3 >> (24 - i * 8)) & 0xFF); + hash[16 + i] = static_cast((h4 >> (24 - i * 8)) & 0xFF); + } + return hash; +} + +std::string websocket_accept_key(const std::string &client_key) { + const std::string magic = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + return base64_encode(sha1(client_key + magic)); +} + +bool is_websocket_upgrade(const Request &req) { + if (req.method != "GET") { return false; } + + // Check Upgrade: websocket (case-insensitive) + auto upgrade_it = req.headers.find("Upgrade"); + if (upgrade_it == req.headers.end()) { return false; } + auto upgrade_val = case_ignore::to_lower(upgrade_it->second); + if (upgrade_val != "websocket") { return false; } + + // Check Connection header contains "Upgrade" + auto connection_it = req.headers.find("Connection"); + if (connection_it == req.headers.end()) { return false; } + auto connection_val = case_ignore::to_lower(connection_it->second); + if (connection_val.find("upgrade") == std::string::npos) { return false; } + + // Check Sec-WebSocket-Key is a valid base64-encoded 16-byte value (24 chars) + // RFC 6455 Section 4.2.1 + auto ws_key = req.get_header_value("Sec-WebSocket-Key"); + if (ws_key.size() != 24 || ws_key[22] != '=' || ws_key[23] != '=') { + return false; + } + static const std::string b64chars = + "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; + for (size_t i = 0; i < 22; i++) { + if (b64chars.find(ws_key[i]) == std::string::npos) { return false; } + } + + // Check Sec-WebSocket-Version: 13 + auto version = req.get_header_value("Sec-WebSocket-Version"); + if (version != "13") { return false; } + + return true; +} + +bool write_websocket_frame(Stream &strm, ws::Opcode opcode, + const char *data, size_t len, bool fin, + bool mask) { + // First byte: FIN + opcode + uint8_t header[2]; + header[0] = static_cast((fin ? 0x80 : 0x00) | + (static_cast(opcode) & 0x0F)); + + // Second byte: MASK + payload length + if (len < 126) { + header[1] = static_cast(len); + if (mask) { header[1] |= 0x80; } + if (strm.write(reinterpret_cast(header), 2) < 0) { return false; } + } else if (len <= 0xFFFF) { + header[1] = 126; + if (mask) { header[1] |= 0x80; } + if (strm.write(reinterpret_cast(header), 2) < 0) { return false; } + uint8_t ext[2]; + ext[0] = static_cast((len >> 8) & 0xFF); + ext[1] = static_cast(len & 0xFF); + if (strm.write(reinterpret_cast(ext), 2) < 0) { return false; } + } else { + header[1] = 127; + if (mask) { header[1] |= 0x80; } + if (strm.write(reinterpret_cast(header), 2) < 0) { return false; } + uint8_t ext[8]; + for (int i = 7; i >= 0; i--) { + ext[7 - i] = + static_cast((static_cast(len) >> (i * 8)) & 0xFF); + } + if (strm.write(reinterpret_cast(ext), 8) < 0) { return false; } + } + + if (mask) { + // Generate random mask key + thread_local std::mt19937 rng(std::random_device{}()); + uint8_t mask_key[4]; + auto r = rng(); + std::memcpy(mask_key, &r, 4); + if (strm.write(reinterpret_cast(mask_key), 4) < 0) { return false; } + + // Write masked payload in chunks + const size_t chunk_size = 4096; + std::vector buf((std::min)(len, chunk_size)); + for (size_t offset = 0; offset < len; offset += chunk_size) { + size_t n = (std::min)(chunk_size, len - offset); + for (size_t i = 0; i < n; i++) { + buf[i] = + data[offset + i] ^ static_cast(mask_key[(offset + i) % 4]); + } + if (strm.write(buf.data(), n) < 0) { return false; } + } + } else { + if (len > 0) { + if (strm.write(data, len) < 0) { return false; } + } + } + + return true; +} + +} // namespace detail + +namespace ws { +namespace impl { + +bool read_websocket_frame(Stream &strm, Opcode &opcode, + std::string &payload, bool &fin, + bool expect_masked, size_t max_len) { + // Read first 2 bytes + uint8_t header[2]; + if (strm.read(reinterpret_cast(header), 2) != 2) { return false; } + + fin = (header[0] & 0x80) != 0; + + // RSV1, RSV2, RSV3 must be 0 when no extension is negotiated + if (header[0] & 0x70) { return false; } + + opcode = static_cast(header[0] & 0x0F); + bool masked = (header[1] & 0x80) != 0; + uint64_t payload_len = header[1] & 0x7F; + + // RFC 6455 Section 5.5: control frames MUST NOT be fragmented and + // MUST have a payload length of 125 bytes or less + bool is_control = (static_cast(opcode) & 0x08) != 0; + if (is_control) { + if (!fin) { return false; } + if (payload_len > 125) { return false; } + } + + if (masked != expect_masked) { return false; } + + // Extended payload length + if (payload_len == 126) { + uint8_t ext[2]; + if (strm.read(reinterpret_cast(ext), 2) != 2) { return false; } + payload_len = (static_cast(ext[0]) << 8) | ext[1]; + } else if (payload_len == 127) { + uint8_t ext[8]; + if (strm.read(reinterpret_cast(ext), 8) != 8) { return false; } + // RFC 6455 Section 5.2: the most significant bit MUST be 0 + if (ext[0] & 0x80) { return false; } + payload_len = 0; + for (int i = 0; i < 8; i++) { + payload_len = (payload_len << 8) | ext[i]; + } + } + + if (payload_len > max_len) { return false; } + + // Read mask key if present + uint8_t mask_key[4] = {0}; + if (masked) { + if (strm.read(reinterpret_cast(mask_key), 4) != 4) { return false; } + } + + // Read payload + payload.resize(static_cast(payload_len)); + if (payload_len > 0) { + size_t total_read = 0; + while (total_read < payload_len) { + auto n = strm.read(&payload[total_read], + static_cast(payload_len - total_read)); + if (n <= 0) { return false; } + total_read += static_cast(n); + } + } + + // Unmask if needed + if (masked) { + for (size_t i = 0; i < payload.size(); i++) { + payload[i] ^= static_cast(mask_key[i % 4]); + } + } + + return true; +} + +} // namespace impl +} // namespace ws + +namespace detail { + bool is_valid_path(const std::string &path) { size_t level = 0; size_t i = 0; @@ -153,6 +1030,35 @@ bool is_valid_path(const std::string &path) { return true; } +bool canonicalize_path(const char *path, std::string &resolved) { +#if defined(_WIN32) + char buf[_MAX_PATH]; + if (_fullpath(buf, path, _MAX_PATH) == nullptr) { return false; } + resolved = buf; +#elif defined(PATH_MAX) + char buf[PATH_MAX]; + if (realpath(path, buf) == nullptr) { return false; } + resolved = buf; +#else + auto buf = realpath(path, nullptr); + auto guard = scope_exit([&]() { std::free(buf); }); + if (buf == nullptr) { return false; } + resolved = buf; +#endif + return true; +} + +bool is_path_within_base(const std::string &resolved_path, + const std::string &resolved_base) { +#if defined(_WIN32) + return _strnicmp(resolved_path.c_str(), resolved_base.c_str(), + resolved_base.size()) == 0; +#else + return strncmp(resolved_path.c_str(), resolved_base.c_str(), + resolved_base.size()) == 0; +#endif +} + FileStat::FileStat(const std::string &path) { #if defined(_WIN32) auto wpath = u8string_to_wstring(path.c_str()); @@ -168,6 +1074,15 @@ bool FileStat::is_dir() const { return ret_ >= 0 && S_ISDIR(st_.st_mode); } +time_t FileStat::mtime() const { + return ret_ >= 0 ? static_cast(st_.st_mtime) + : static_cast(-1); +} + +size_t FileStat::size() const { + return ret_ >= 0 ? static_cast(st_.st_size) : 0; +} + std::string encode_path(const std::string &s) { std::string result; result.reserve(s.size()); @@ -209,6 +1124,148 @@ std::string file_extension(const std::string &path) { bool is_space_or_tab(char c) { return c == ' ' || c == '\t'; } +template +bool parse_header(const char *beg, const char *end, T fn); + +template +bool parse_header(const char *beg, const char *end, T fn) { + // Skip trailing spaces and tabs. + while (beg < end && is_space_or_tab(end[-1])) { + end--; + } + + auto p = beg; + while (p < end && *p != ':') { + p++; + } + + auto name = std::string(beg, p); + if (!detail::fields::is_field_name(name)) { return false; } + + if (p == end) { return false; } + + auto key_end = p; + + if (*p++ != ':') { return false; } + + while (p < end && is_space_or_tab(*p)) { + p++; + } + + if (p <= end) { + auto key_len = key_end - beg; + if (!key_len) { return false; } + + auto key = std::string(beg, key_end); + auto val = std::string(p, end); + + if (!detail::fields::is_field_value(val)) { return false; } + + // RFC 9110 ยง5.5: header field values are opaque octets and MUST NOT be + // percent-decoded by the recipient. Applications that need to interpret a + // value as a URI component should call httplib::decode_uri_component() + // (or decode_path_component()) explicitly. + fn(key, val); + + return true; + } + + return false; +} + +bool parse_trailers(stream_line_reader &line_reader, Headers &dest, + const Headers &src_headers) { + // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentions "The chunked + // transfer coding is complete when a chunk with a chunk-size of zero is + // received, possibly followed by a trailer section, and finally terminated by + // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 + // + // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section + // doesn't care for the existence of the final CRLF. In other words, it seems + // to be ok whether the final CRLF exists or not in the chunked data. + // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 + // + // According to the reference code in RFC 9112, cpp-httplib now allows + // chunked transfer coding data without the final CRLF. + + // RFC 7230 Section 4.1.2 - Headers prohibited in trailers + thread_local case_ignore::unordered_set prohibited_trailers = { + "transfer-encoding", + "content-length", + "host", + "authorization", + "www-authenticate", + "proxy-authenticate", + "proxy-authorization", + "cookie", + "set-cookie", + "cache-control", + "expect", + "max-forwards", + "pragma", + "range", + "te", + "age", + "expires", + "date", + "location", + "retry-after", + "vary", + "warning", + "content-encoding", + "content-type", + "content-range", + "trailer"}; + + case_ignore::unordered_set declared_trailers; + auto trailer_header = get_header_value(src_headers, "Trailer", "", 0); + if (trailer_header && std::strlen(trailer_header)) { + auto len = std::strlen(trailer_header); + split(trailer_header, trailer_header + len, ',', + [&](const char *b, const char *e) { + const char *kbeg = b; + const char *kend = e; + while (kbeg < kend && (*kbeg == ' ' || *kbeg == '\t')) { + ++kbeg; + } + while (kend > kbeg && (kend[-1] == ' ' || kend[-1] == '\t')) { + --kend; + } + std::string key(kbeg, static_cast(kend - kbeg)); + if (!key.empty() && + prohibited_trailers.find(key) == prohibited_trailers.end()) { + declared_trailers.insert(key); + } + }); + } + + size_t trailer_header_count = 0; + while (strcmp(line_reader.ptr(), "\r\n") != 0) { + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + if (trailer_header_count >= CPPHTTPLIB_HEADER_MAX_COUNT) { return false; } + + constexpr auto line_terminator_len = 2; + auto line_beg = line_reader.ptr(); + auto line_end = + line_reader.ptr() + line_reader.size() - line_terminator_len; + + if (!parse_header(line_beg, line_end, + [&](const std::string &key, const std::string &val) { + if (declared_trailers.find(key) != + declared_trailers.end()) { + dest.emplace(key, val); + trailer_header_count++; + } + })) { + return false; + } + + if (!line_reader.getline()) { return false; } + } + + return true; +} + std::pair trim(const char *b, const char *e, size_t left, size_t right) { while (b + left < e && is_space_or_tab(b[left])) { @@ -280,6 +1337,42 @@ void split(const char *b, const char *e, char d, size_t m, } } +bool split_find(const char *b, const char *e, char d, size_t m, + std::function fn) { + size_t i = 0; + size_t beg = 0; + size_t count = 1; + + while (e ? (b + i < e) : (b[i] != '\0')) { + if (b[i] == d && count < m) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + auto found = fn(&b[r.first], &b[r.second]); + if (found) { return true; } + } + beg = i + 1; + count++; + } + i++; + } + + if (i) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + auto found = fn(&b[r.first], &b[r.second]); + if (found) { return true; } + } + } + + return false; +} + +bool split_find(const char *b, const char *e, char d, + std::function fn) { + return split_find(b, e, d, (std::numeric_limits::max)(), + std::move(fn)); +} + stream_line_reader::stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) : strm_(strm), fixed_buffer_(fixed_buffer), @@ -370,8 +1463,9 @@ bool mmap::open(const char *path) { auto wpath = u8string_to_wstring(path); if (wpath.empty()) { return false; } - hFile_ = ::CreateFile2(wpath.c_str(), GENERIC_READ, FILE_SHARE_READ, - OPEN_EXISTING, NULL); + hFile_ = + ::CreateFile2(wpath.c_str(), GENERIC_READ, + FILE_SHARE_READ | FILE_SHARE_WRITE, OPEN_EXISTING, NULL); if (hFile_ == INVALID_HANDLE_VALUE) { return false; } @@ -473,7 +1567,7 @@ void mmap::close() { #endif size_ = 0; } -int close_socket(socket_t sock) { +int close_socket(socket_t sock) noexcept { #ifdef _WIN32 return closesocket(sock); #else @@ -481,7 +1575,7 @@ int close_socket(socket_t sock) { #endif } -template inline ssize_t handle_EINTR(T fn) { +template ssize_t handle_EINTR(T fn) { ssize_t res = 0; while (true) { res = fn(); @@ -527,78 +1621,32 @@ int poll_wrapper(struct pollfd *fds, nfds_t nfds, int timeout) { #endif } -template -ssize_t select_impl(socket_t sock, time_t sec, time_t usec) { -#ifdef __APPLE__ - if (sock >= FD_SETSIZE) { return -1; } - - fd_set fds, *rfds, *wfds; - FD_ZERO(&fds); - FD_SET(sock, &fds); - rfds = (Read ? &fds : nullptr); - wfds = (Read ? nullptr : &fds); - - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); - - return handle_EINTR([&]() { - return select(static_cast(sock + 1), rfds, wfds, nullptr, &tv); - }); -#else +ssize_t select_impl(socket_t sock, short events, time_t sec, + time_t usec) { struct pollfd pfd; pfd.fd = sock; - pfd.events = (Read ? POLLIN : POLLOUT); + pfd.events = events; + pfd.revents = 0; auto timeout = static_cast(sec * 1000 + usec / 1000); return handle_EINTR([&]() { return poll_wrapper(&pfd, 1, timeout); }); -#endif } ssize_t select_read(socket_t sock, time_t sec, time_t usec) { - return select_impl(sock, sec, usec); + return select_impl(sock, POLLIN, sec, usec); } ssize_t select_write(socket_t sock, time_t sec, time_t usec) { - return select_impl(sock, sec, usec); + return select_impl(sock, POLLOUT, sec, usec); } Error wait_until_socket_is_ready(socket_t sock, time_t sec, time_t usec) { -#ifdef __APPLE__ - if (sock >= FD_SETSIZE) { return Error::Connection; } - - fd_set fdsr, fdsw; - FD_ZERO(&fdsr); - FD_ZERO(&fdsw); - FD_SET(sock, &fdsr); - FD_SET(sock, &fdsw); - - timeval tv; - tv.tv_sec = static_cast(sec); - tv.tv_usec = static_cast(usec); - - auto ret = handle_EINTR([&]() { - return select(static_cast(sock + 1), &fdsr, &fdsw, nullptr, &tv); - }); - - if (ret == 0) { return Error::ConnectionTimeout; } - - if (ret > 0 && (FD_ISSET(sock, &fdsr) || FD_ISSET(sock, &fdsw))) { - auto error = 0; - socklen_t len = sizeof(error); - auto res = getsockopt(sock, SOL_SOCKET, SO_ERROR, - reinterpret_cast(&error), &len); - auto successful = res >= 0 && !error; - return successful ? Error::Success : Error::Connection; - } - - return Error::Connection; -#else struct pollfd pfd_read; pfd_read.fd = sock; pfd_read.events = POLLIN | POLLOUT; + pfd_read.revents = 0; auto timeout = static_cast(sec * 1000 + usec / 1000); @@ -617,7 +1665,6 @@ Error wait_until_socket_is_ready(socket_t sock, time_t sec, } return Error::Connection; -#endif } bool is_socket_alive(socket_t sock) { @@ -643,12 +1690,14 @@ public: bool is_readable() const override; bool wait_readable() const override; bool wait_writable() const override; + bool is_peer_alive() const override; ssize_t read(char *ptr, size_t size) override; ssize_t write(const char *ptr, size_t size) override; void get_remote_ip_and_port(std::string &ip, int &port) const override; void get_local_ip_and_port(std::string &ip, int &port) const override; socket_t socket() const override; time_t duration() const override; + void set_read_timeout(time_t sec, time_t usec = 0) override; private: socket_t sock_; @@ -666,39 +1715,6 @@ private: static const size_t read_buff_size_ = 1024l * 4; }; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -class SSLSocketStream final : public Stream { -public: - SSLSocketStream( - socket_t sock, SSL *ssl, time_t read_timeout_sec, - time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec, time_t max_timeout_msec = 0, - std::chrono::time_point start_time = - (std::chrono::steady_clock::time_point::min)()); - ~SSLSocketStream() override; - - bool is_readable() const override; - bool wait_readable() const override; - bool wait_writable() const override; - ssize_t read(char *ptr, size_t size) override; - ssize_t write(const char *ptr, size_t size) override; - void get_remote_ip_and_port(std::string &ip, int &port) const override; - void get_local_ip_and_port(std::string &ip, int &port) const override; - socket_t socket() const override; - time_t duration() const override; - -private: - socket_t sock_; - SSL *ssl_; - time_t read_timeout_sec_; - time_t read_timeout_usec_; - time_t write_timeout_sec_; - time_t write_timeout_usec_; - time_t max_timeout_msec_; - const std::chrono::time_point start_time_; -}; -#endif - bool keep_alive(const std::atomic &svr_sock, socket_t sock, time_t keep_alive_timeout_sec) { using namespace std::chrono; @@ -778,7 +1794,7 @@ bool process_client_socket( return callback(strm); } -int shutdown_socket(socket_t sock) { +int shutdown_socket(socket_t sock) noexcept { #ifdef _WIN32 return shutdown(sock, SD_BOTH); #else @@ -865,7 +1881,8 @@ int getaddrinfo_with_timeout(const char *node, const char *service, } return ret; -#elif TARGET_OS_MAC +#elif TARGET_OS_MAC && defined(__clang__) + if (!node) { return EAI_NONAME; } // macOS implementation using CFHost API for asynchronous DNS resolution CFStringRef hostname_ref = CFStringCreateWithCString( kCFAllocatorDefault, node, kCFStringEncodingUTF8); @@ -1012,9 +2029,9 @@ int getaddrinfo_with_timeout(const char *node, const char *service, memcpy((*current)->ai_addr, sockaddr_ptr, sockaddr_len); // Set port if service is specified - if (service && strlen(service) > 0) { - int port = atoi(service); - if (port > 0) { + if (service && *service) { + int port = 0; + if (parse_port(service, strlen(service), port)) { if (sockaddr_ptr->sa_family == AF_INET) { reinterpret_cast((*current)->ai_addr) ->sin_port = htons(static_cast(port)); @@ -1035,74 +2052,76 @@ int getaddrinfo_with_timeout(const char *node, const char *service, return 0; #elif defined(_GNU_SOURCE) && defined(__GLIBC__) && \ (__GLIBC__ > 2 || (__GLIBC__ == 2 && __GLIBC_MINOR__ >= 2)) - // Linux implementation using getaddrinfo_a for asynchronous DNS resolution - struct gaicb request; + // #2431: gai_cancel() is non-blocking and may return EAI_NOTCANCELED while + // the resolver worker still references the stack-local gaicb. The cancel + // path therefore waits (gai_suspend with no timeout) for the worker to + // actually finish before letting the stack frame go. The trade-off is that + // a wedged DNS server can hold this thread for the system resolver timeout + // (~30s by default) past the caller's connection timeout. + struct gaicb request {}; struct gaicb *requests[1] = {&request}; - struct sigevent sevp; - struct timespec timeout; + struct sigevent sevp {}; + struct timespec timeout { + timeout_sec, 0 + }; - // Initialize the request structure - memset(&request, 0, sizeof(request)); request.ar_name = node; request.ar_service = service; request.ar_request = hints; - - // Set up timeout - timeout.tv_sec = timeout_sec; - timeout.tv_nsec = 0; - - // Initialize sigevent structure (not used, but required) - memset(&sevp, 0, sizeof(sevp)); sevp.sigev_notify = SIGEV_NONE; - // Start asynchronous resolution - int start_result = getaddrinfo_a(GAI_NOWAIT, requests, 1, &sevp); - if (start_result != 0) { return start_result; } + int rc = getaddrinfo_a(GAI_NOWAIT, requests, 1, &sevp); + if (rc != 0) { return rc; } - // Wait for completion with timeout - int wait_result = - gai_suspend((const struct gaicb *const *)requests, 1, &timeout); + auto cleanup = scope_exit([&] { + if (request.ar_result) { freeaddrinfo(request.ar_result); } + }); + + int wait_result = gai_suspend(requests, 1, &timeout); if (wait_result == 0 || wait_result == EAI_ALLDONE) { - // Completed successfully, get the result int gai_result = gai_error(&request); if (gai_result == 0) { *res = request.ar_result; + request.ar_result = nullptr; return 0; - } else { - // Clean up on error - if (request.ar_result) { freeaddrinfo(request.ar_result); } - return gai_result; } - } else if (wait_result == EAI_AGAIN) { - // Timeout occurred, cancel the request - gai_cancel(&request); - return EAI_AGAIN; - } else { - // Other error occurred - gai_cancel(&request); - return wait_result; + return gai_result; } + + gai_cancel(&request); + while (gai_error(&request) == EAI_INPROGRESS) { + gai_suspend(requests, 1, nullptr); + } + return wait_result; #else - // Fallback implementation using thread-based timeout for other Unix systems + // Fallback implementation using thread-based timeout for other Unix systems. struct GetAddrInfoState { + ~GetAddrInfoState() { + if (info) { freeaddrinfo(info); } + } + std::mutex mutex; std::condition_variable result_cv; bool completed = false; int result = EAI_SYSTEM; - std::string node = node; - std::string service = service; - struct addrinfo hints = hints; + std::string node; + std::string service; + struct addrinfo hints; struct addrinfo *info = nullptr; }; // Allocate on the heap, so the resolver thread can keep using the data. auto state = std::make_shared(); + if (node) { state->node = node; } + state->service = service; + state->hints = *hints; - std::thread resolve_thread([=]() { - auto thread_result = getaddrinfo( - state->node.c_str(), state->service.c_str(), hints, &state->info); + std::thread resolve_thread([state]() { + auto thread_result = + getaddrinfo(state->node.c_str(), state->service.c_str(), &state->hints, + &state->info); std::lock_guard lock(state->mutex); state->result = thread_result; @@ -1120,6 +2139,7 @@ int getaddrinfo_with_timeout(const char *node, const char *service, // Operation completed within timeout resolve_thread.join(); *res = state->info; + state->info = nullptr; // Pass ownership to caller return state->result; } else { // Timeout occurred @@ -1192,7 +2212,7 @@ socket_t create_socket(const std::string &host, const std::string &ip, int port, #ifdef _WIN32 // Setting SO_REUSEADDR seems not to work well with AF_UNIX on windows, so // remove the option. - detail::set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0); + set_socket_opt(sock, SOL_SOCKET, SO_REUSEADDR, 0); #endif bool dummy; @@ -1485,6 +2505,10 @@ void get_remote_ip_and_port(socket_t sock, std::string &ip, int &port) { } } +// Recursive form retained so operator""_t below can compute hashes for +// switch-case labels at compile time (C++11 constexpr forbids loops). Do not +// call from runtime paths with arbitrary-length inputs โ€” use str2tag() +// instead, which is iterative and stack-safe. constexpr unsigned int str2tag_core(const char *s, size_t l, unsigned int h) { return (l == 0) @@ -1498,7 +2522,16 @@ constexpr unsigned int str2tag_core(const char *s, size_t l, } unsigned int str2tag(const std::string &s) { - return str2tag_core(s.data(), s.size(), 0); + // Iterative form of str2tag_core: the recursive constexpr version is kept + // for compile-time UDL evaluation of short string literals, but at runtime + // we may receive arbitrarily long inputs (e.g. fuzzed Content-Type) that + // would blow the stack with one frame per character. + unsigned int h = 0; + for (auto c : s) { + h = (((std::numeric_limits::max)() >> 6) & h * 33) ^ + static_cast(c); + } + return h; } namespace udl { @@ -1575,52 +2608,203 @@ find_content_type(const std::string &path, } } +std::string +extract_media_type(const std::string &content_type, + std::map *params = nullptr) { + // Extract type/subtype from Content-Type value (RFC 2045) + // e.g. "application/json; charset=utf-8" -> "application/json" + auto media_type = content_type; + auto semicolon_pos = media_type.find(';'); + if (semicolon_pos != std::string::npos) { + auto param_str = media_type.substr(semicolon_pos + 1); + media_type = media_type.substr(0, semicolon_pos); + + if (params) { + // Parse parameters: key=value pairs separated by ';' + split(param_str.data(), param_str.data() + param_str.size(), ';', + [&](const char *b, const char *e) { + std::string key; + std::string val; + split(b, e, '=', [&](const char *b2, const char *e2) { + if (key.empty()) { + key.assign(b2, e2); + } else { + val.assign(b2, e2); + } + }); + if (!key.empty()) { + params->emplace(trim_copy(key), trim_double_quotes_copy(val)); + } + }); + } + } + + // Trim whitespace from media type + return trim_copy(media_type); +} + bool can_compress_content_type(const std::string &content_type) { using udl::operator""_t; - auto tag = str2tag(content_type); + auto mime_type = extract_media_type(content_type); + auto tag = str2tag(mime_type); switch (tag) { case "image/svg+xml"_t: case "application/javascript"_t: + case "application/x-javascript"_t: case "application/json"_t: + case "application/ld+json"_t: case "application/xml"_t: - case "application/protobuf"_t: - case "application/xhtml+xml"_t: return true; + case "application/xhtml+xml"_t: + case "application/rss+xml"_t: + case "application/atom+xml"_t: + case "application/xslt+xml"_t: + case "application/protobuf"_t: return true; case "text/event-stream"_t: return false; - default: return !content_type.rfind("text/", 0); + default: return !mime_type.rfind("text/", 0); } } +bool parse_quality(const char *b, const char *e, std::string &token, + double &quality) { + quality = 1.0; + token.clear(); + + // Split on first ';': left = token name, right = parameters + const char *params_b = nullptr; + std::size_t params_len = 0; + + divide( + b, static_cast(e - b), ';', + [&](const char *lb, std::size_t llen, const char *rb, std::size_t rlen) { + auto r = trim(lb, lb + llen, 0, llen); + if (r.first < r.second) { token.assign(lb + r.first, lb + r.second); } + params_b = rb; + params_len = rlen; + }); + + if (token.empty()) { return false; } + if (params_len == 0) { return true; } + + // Scan parameters for q= (stops on first match) + bool invalid = false; + split_find(params_b, params_b + params_len, ';', + (std::numeric_limits::max)(), + [&](const char *pb, const char *pe) -> bool { + // Match exactly "q=" or "Q=" (not "query=" etc.) + auto len = static_cast(pe - pb); + if (len < 2) { return false; } + if ((pb[0] != 'q' && pb[0] != 'Q') || pb[1] != '=') { + return false; + } + + // Trim the value portion + auto r = trim(pb, pe, 2, len); + if (r.first >= r.second) { + invalid = true; + return true; + } + + double v = 0.0; + auto res = from_chars(pb + r.first, pb + r.second, v); + if (res.ec != std::errc{} || v < 0.0 || v > 1.0) { + invalid = true; + return true; + } + quality = v; + return true; + }); + + return !invalid; +} + EncodingType encoding_type(const Request &req, const Response &res) { - auto ret = - detail::can_compress_content_type(res.get_header_value("Content-Type")); - if (!ret) { return EncodingType::None; } + if (!can_compress_content_type(res.get_header_value("Content-Type"))) { + return EncodingType::None; + } const auto &s = req.get_header_value("Accept-Encoding"); - (void)(s); + if (s.empty()) { return EncodingType::None; } + // Single-pass: iterate tokens and track the best supported encoding. + // Server preference breaks ties (br > gzip > zstd). + EncodingType best = EncodingType::None; + double best_q = 0.0; // q=0 means "not acceptable" + + // Server preference: Brotli > Gzip > Zstd (lower = more preferred) + auto priority = [](EncodingType t) -> int { + switch (t) { + case EncodingType::Brotli: return 0; + case EncodingType::Gzip: return 1; + case EncodingType::Zstd: return 2; + default: return 3; + } + }; + + std::string name; + split(s.data(), s.data() + s.size(), ',', [&](const char *b, const char *e) { + double quality = 1.0; + if (!parse_quality(b, e, name, quality)) { return; } + if (quality <= 0.0) { return; } + + EncodingType type = EncodingType::None; #ifdef CPPHTTPLIB_BROTLI_SUPPORT - // TODO: 'Accept-Encoding' has br, not br;q=0 - ret = s.find("br") != std::string::npos; - if (ret) { return EncodingType::Brotli; } + if (case_ignore::equal(name, "br")) { type = EncodingType::Brotli; } #endif - #ifdef CPPHTTPLIB_ZLIB_SUPPORT - // TODO: 'Accept-Encoding' has gzip, not gzip;q=0 - ret = s.find("gzip") != std::string::npos; - if (ret) { return EncodingType::Gzip; } + if (type == EncodingType::None && case_ignore::equal(name, "gzip")) { + type = EncodingType::Gzip; + } #endif - #ifdef CPPHTTPLIB_ZSTD_SUPPORT - // TODO: 'Accept-Encoding' has zstd, not zstd;q=0 - ret = s.find("zstd") != std::string::npos; - if (ret) { return EncodingType::Zstd; } + if (type == EncodingType::None && case_ignore::equal(name, "zstd")) { + type = EncodingType::Zstd; + } #endif - return EncodingType::None; + if (type == EncodingType::None) { return; } + + // Higher q-value wins; for equal q, server preference breaks ties + if (quality > best_q || + (quality == best_q && priority(type) < priority(best))) { + best_q = quality; + best = type; + } + }); + + return best; +} + +std::unique_ptr make_compressor(EncodingType type) { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (type == EncodingType::Gzip) { + return detail::make_unique(); + } +#endif +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + if (type == EncodingType::Brotli) { + return detail::make_unique(); + } +#endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + if (type == EncodingType::Zstd) { + return detail::make_unique(); + } +#endif + (void)type; + return nullptr; +} + +const char *encoding_name(EncodingType type) { + switch (type) { + case EncodingType::Gzip: return "gzip"; + case EncodingType::Brotli: return "br"; + case EncodingType::Zstd: return "zstd"; + default: return ""; + } } bool nocompressor::compress(const char *data, size_t data_length, @@ -1883,6 +3067,42 @@ bool zstd_decompressor::decompress(const char *data, size_t data_length, } #endif +std::unique_ptr +create_decompressor(const std::string &encoding) { + std::unique_ptr decompressor; + + if (encoding == "gzip" || encoding == "deflate") { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor = detail::make_unique(); +#endif + } else if (encoding.find("br") != std::string::npos) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + decompressor = detail::make_unique(); +#endif + } else if (encoding == "zstd" || encoding.find("zstd") != std::string::npos) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + decompressor = detail::make_unique(); +#endif + } + + return decompressor; +} + +// Returns the best available compressor and its Content-Encoding name. +// Priority: Brotli > Gzip > Zstd (matches server-side preference). +std::pair, const char *> +create_compressor() { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + return {detail::make_unique(), "br"}; +#elif defined(CPPHTTPLIB_ZLIB_SUPPORT) + return {detail::make_unique(), "gzip"}; +#elif defined(CPPHTTPLIB_ZSTD_SUPPORT) + return {detail::make_unique(), "zstd"}; +#else + return {nullptr, nullptr}; +#endif +} + bool is_prohibited_header_name(const std::string &name) { using udl::operator""_t; @@ -1919,51 +3139,27 @@ const char *get_header_value(const Headers &headers, return def; } -template -bool parse_header(const char *beg, const char *end, T fn) { - // Skip trailing spaces and tabs. - while (beg < end && is_space_or_tab(end[-1])) { - end--; +size_t get_header_value_count(const Headers &headers, + const std::string &key) { + auto r = headers.equal_range(key); + return static_cast(std::distance(r.first, r.second)); +} + +template +typename Map::mapped_type +get_multimap_value(const Map &m, const std::string &key, size_t id) { + auto rng = m.equal_range(key); + auto it = rng.first; + std::advance(it, static_cast(id)); + if (it != rng.second) { return it->second; } + return typename Map::mapped_type(); +} + +void set_header(Headers &headers, const std::string &key, + const std::string &val) { + if (fields::is_field_name(key) && fields::is_field_value(val)) { + headers.emplace(key, val); } - - auto p = beg; - while (p < end && *p != ':') { - p++; - } - - auto name = std::string(beg, p); - if (!detail::fields::is_field_name(name)) { return false; } - - if (p == end) { return false; } - - auto key_end = p; - - if (*p++ != ':') { return false; } - - while (p < end && is_space_or_tab(*p)) { - p++; - } - - if (p <= end) { - auto key_len = key_end - beg; - if (!key_len) { return false; } - - auto key = std::string(beg, key_end); - auto val = std::string(p, end); - - if (!detail::fields::is_field_value(val)) { return false; } - - if (case_ignore::equal(key, "Location") || - case_ignore::equal(key, "Referer")) { - fn(key, val); - } else { - fn(key, decode_path_component(val)); - } - - return true; - } - - return false; } bool read_headers(Stream &strm, Headers &headers) { @@ -2009,40 +3205,57 @@ bool read_headers(Stream &strm, Headers &headers) { header_count++; } - return true; -} - -bool read_content_with_length(Stream &strm, size_t len, - DownloadProgress progress, - ContentReceiverWithProgress out) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; - - size_t r = 0; - while (r < len) { - auto read_len = static_cast(len - r); - auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); - if (n <= 0) { return false; } - - if (!out(buf, static_cast(n), r, len)) { return false; } - r += static_cast(n); - - if (progress) { - if (!progress(r, len)) { return false; } + // RFC 9110 Section 8.6: Reject requests with multiple Content-Length + // headers that have different values to prevent request smuggling. + auto cl_range = headers.equal_range("Content-Length"); + if (cl_range.first != cl_range.second) { + const auto &first_val = cl_range.first->second; + for (auto it = std::next(cl_range.first); it != cl_range.second; ++it) { + if (it->second != first_val) { return false; } } } return true; } -void skip_content_with_length(Stream &strm, size_t len) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; - size_t r = 0; - while (r < len) { - auto read_len = static_cast(len - r); - auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); - if (n <= 0) { return; } - r += static_cast(n); - } +bool read_websocket_upgrade_response(Stream &strm, + const std::string &expected_accept, + std::string &selected_subprotocol) { + // Read status line + const auto bufsiz = 2048; + char buf[bufsiz]; + stream_line_reader line_reader(strm, buf, bufsiz); + if (!line_reader.getline()) { return false; } + + // Check for "HTTP/1.1 101" + auto line = std::string(line_reader.ptr(), line_reader.size()); + if (line.find("HTTP/1.1 101") == std::string::npos) { return false; } + + // Parse headers using existing read_headers + Headers headers; + if (!read_headers(strm, headers)) { return false; } + + // Verify Upgrade: websocket (case-insensitive) + auto upgrade_it = headers.find("Upgrade"); + if (upgrade_it == headers.end()) { return false; } + auto upgrade_val = case_ignore::to_lower(upgrade_it->second); + if (upgrade_val != "websocket") { return false; } + + // Verify Connection header contains "Upgrade" (case-insensitive) + auto connection_it = headers.find("Connection"); + if (connection_it == headers.end()) { return false; } + auto connection_val = case_ignore::to_lower(connection_it->second); + if (connection_val.find("upgrade") == std::string::npos) { return false; } + + // Verify Sec-WebSocket-Accept header value + auto it = headers.find("Sec-WebSocket-Accept"); + if (it == headers.end() || it->second != expected_accept) { return false; } + + // Extract negotiated subprotocol + auto proto_it = headers.find("Sec-WebSocket-Protocol"); + if (proto_it != headers.end()) { selected_subprotocol = proto_it->second; } + + return true; } enum class ReadContentResult { @@ -2051,6 +3264,47 @@ enum class ReadContentResult { Error // An error occurred while reading the content }; +ReadContentResult read_content_with_length( + Stream &strm, size_t len, DownloadProgress progress, + ContentReceiverWithProgress out, + size_t payload_max_length = (std::numeric_limits::max)()) { + char buf[CPPHTTPLIB_RECV_BUFSIZ]; + + detail::BodyReader br; + br.stream = &strm; + br.has_content_length = true; + br.content_length = len; + br.payload_max_length = payload_max_length; + br.chunked = false; + br.bytes_read = 0; + br.last_error = Error::Success; + + size_t r = 0; + while (r < len) { + auto read_len = static_cast(len - r); + auto to_read = (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ); + auto n = detail::read_body_content(&strm, br, buf, to_read); + if (n <= 0) { + // Check if it was a payload size error + if (br.last_error == Error::ExceedMaxPayloadSize) { + return ReadContentResult::PayloadTooLarge; + } + return ReadContentResult::Error; + } + + if (!out(buf, static_cast(n), r, len)) { + return ReadContentResult::Error; + } + r += static_cast(n); + + if (progress) { + if (!progress(r, len)) { return ReadContentResult::Error; } + } + } + + return ReadContentResult::Success; +} + ReadContentResult read_content_without_length(Stream &strm, size_t payload_max_length, ContentReceiverWithProgress out) { @@ -2080,125 +3334,35 @@ template ReadContentResult read_content_chunked(Stream &strm, T &x, size_t payload_max_length, ContentReceiverWithProgress out) { - const auto bufsiz = 16; - char buf[bufsiz]; + detail::ChunkedDecoder dec(strm); - stream_line_reader line_reader(strm, buf, bufsiz); - - if (!line_reader.getline()) { return ReadContentResult::Error; } - - unsigned long chunk_len; + char buf[CPPHTTPLIB_RECV_BUFSIZ]; size_t total_len = 0; - while (true) { - char *end_ptr; - chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + for (;;) { + size_t chunk_offset = 0; + size_t chunk_total = 0; + auto n = dec.read_payload(buf, sizeof(buf), chunk_offset, chunk_total); + if (n < 0) { return ReadContentResult::Error; } - if (end_ptr == line_reader.ptr()) { return ReadContentResult::Error; } - if (chunk_len == ULONG_MAX) { return ReadContentResult::Error; } + if (n == 0) { + if (!dec.parse_trailers_into(x.trailers, x.headers)) { + return ReadContentResult::Error; + } + return ReadContentResult::Success; + } - if (chunk_len == 0) { break; } - - // Check if adding this chunk would exceed the payload limit if (total_len > payload_max_length || - payload_max_length - total_len < chunk_len) { + payload_max_length - total_len < static_cast(n)) { return ReadContentResult::PayloadTooLarge; } - total_len += chunk_len; - - if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + if (!out(buf, static_cast(n), chunk_offset, chunk_total)) { return ReadContentResult::Error; } - if (!line_reader.getline()) { return ReadContentResult::Error; } - - if (strcmp(line_reader.ptr(), "\r\n") != 0) { - return ReadContentResult::Error; - } - - if (!line_reader.getline()) { return ReadContentResult::Error; } + total_len += static_cast(n); } - - assert(chunk_len == 0); - - // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentions "The chunked - // transfer coding is complete when a chunk with a chunk-size of zero is - // received, possibly followed by a trailer section, and finally terminated by - // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 - // - // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section - // does't care for the existence of the final CRLF. In other words, it seems - // to be ok whether the final CRLF exists or not in the chunked data. - // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 - // - // According to the reference code in RFC 9112, cpp-httplib now allows - // chunked transfer coding data without the final CRLF. - if (!line_reader.getline()) { return ReadContentResult::Success; } - - // RFC 7230 Section 4.1.2 - Headers prohibited in trailers - thread_local case_ignore::unordered_set prohibited_trailers = { - // Message framing - "transfer-encoding", "content-length", - - // Routing - "host", - - // Authentication - "authorization", "www-authenticate", "proxy-authenticate", - "proxy-authorization", "cookie", "set-cookie", - - // Request modifiers - "cache-control", "expect", "max-forwards", "pragma", "range", "te", - - // Response control - "age", "expires", "date", "location", "retry-after", "vary", "warning", - - // Payload processing - "content-encoding", "content-type", "content-range", "trailer"}; - - // Parse declared trailer headers once for performance - case_ignore::unordered_set declared_trailers; - if (has_header(x.headers, "Trailer")) { - auto trailer_header = get_header_value(x.headers, "Trailer", "", 0); - auto len = std::strlen(trailer_header); - - split(trailer_header, trailer_header + len, ',', - [&](const char *b, const char *e) { - std::string key(b, e); - if (prohibited_trailers.find(key) == prohibited_trailers.end()) { - declared_trailers.insert(key); - } - }); - } - - size_t trailer_header_count = 0; - while (strcmp(line_reader.ptr(), "\r\n") != 0) { - if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { - return ReadContentResult::Error; - } - - // Check trailer header count limit - if (trailer_header_count >= CPPHTTPLIB_HEADER_MAX_COUNT) { - return ReadContentResult::Error; - } - - // Exclude line terminator - constexpr auto line_terminator_len = 2; - auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; - - parse_header(line_reader.ptr(), end, - [&](const std::string &key, const std::string &val) { - if (declared_trailers.find(key) != declared_trailers.end()) { - x.trailers.emplace(key, val); - trailer_header_count++; - } - }); - - if (!line_reader.getline()) { return ReadContentResult::Error; } - } - - return ReadContentResult::Success; } bool is_chunked_transfer_encoding(const Headers &headers) { @@ -2209,42 +3373,39 @@ bool is_chunked_transfer_encoding(const Headers &headers) { template bool prepare_content_receiver(T &x, int &status, ContentReceiverWithProgress receiver, - bool decompress, U callback) { + bool decompress, size_t payload_max_length, + bool &exceed_payload_max_length, U callback) { if (decompress) { std::string encoding = x.get_header_value("Content-Encoding"); std::unique_ptr decompressor; - if (encoding == "gzip" || encoding == "deflate") { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - decompressor = detail::make_unique(); -#else - status = StatusCode::UnsupportedMediaType_415; - return false; -#endif - } else if (encoding.find("br") != std::string::npos) { -#ifdef CPPHTTPLIB_BROTLI_SUPPORT - decompressor = detail::make_unique(); -#else - status = StatusCode::UnsupportedMediaType_415; - return false; -#endif - } else if (encoding == "zstd") { -#ifdef CPPHTTPLIB_ZSTD_SUPPORT - decompressor = detail::make_unique(); -#else - status = StatusCode::UnsupportedMediaType_415; - return false; -#endif + if (!encoding.empty()) { + decompressor = detail::create_decompressor(encoding); + if (!decompressor) { + // Unsupported encoding or no support compiled in + status = StatusCode::UnsupportedMediaType_415; + return false; + } } if (decompressor) { if (decompressor->is_valid()) { + size_t decompressed_size = 0; ContentReceiverWithProgress out = [&](const char *buf, size_t n, size_t off, size_t len) { - return decompressor->decompress(buf, n, - [&](const char *buf2, size_t n2) { - return receiver(buf2, n2, off, len); - }); + return decompressor->decompress( + buf, n, [&](const char *buf2, size_t n2) { + // Guard against zip-bomb: check + // decompressed size against limit. + if (payload_max_length > 0 && + (decompressed_size >= payload_max_length || + n2 > payload_max_length - decompressed_size)) { + exceed_payload_max_length = true; + return false; + } + decompressed_size += n2; + return receiver(buf2, n2, off, len); + }); }; return callback(std::move(out)); } else { @@ -2265,11 +3426,14 @@ template bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, DownloadProgress progress, ContentReceiverWithProgress receiver, bool decompress) { + bool exceed_payload_max_length = false; return prepare_content_receiver( - x, status, std::move(receiver), decompress, - [&](const ContentReceiverWithProgress &out) { + x, status, std::move(receiver), decompress, payload_max_length, + exceed_payload_max_length, [&](const ContentReceiverWithProgress &out) { auto ret = true; - auto exceed_payload_max_length = false; + // Note: exceed_payload_max_length may also be set by the decompressor + // wrapper in prepare_content_receiver when the decompressed payload + // size exceeds the limit. if (is_chunked_transfer_encoding(x.headers)) { auto result = read_content_chunked(strm, x, payload_max_length, out); @@ -2300,12 +3464,13 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, if (is_invalid_value) { ret = false; - } else if (len > payload_max_length) { - exceed_payload_max_length = true; - skip_content_with_length(strm, len); - ret = false; } else if (len > 0) { - ret = read_content_with_length(strm, len, std::move(progress), out); + auto result = read_content_with_length( + strm, len, std::move(progress), out, payload_max_length); + ret = (result == ReadContentResult::Success); + if (result == ReadContentResult::PayloadTooLarge) { + exceed_payload_max_length = true; + } } } @@ -2320,7 +3485,7 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, ssize_t write_request_line(Stream &strm, const std::string &method, const std::string &path) { std::string s = method; - s += " "; + s += ' '; s += path; s += " HTTP/1.1\r\n"; return strm.write(s.data(), s.size()); @@ -2329,7 +3494,7 @@ ssize_t write_request_line(Stream &strm, const std::string &method, ssize_t write_response_line(Stream &strm, int status) { std::string s = "HTTP/1.1 "; s += std::to_string(status); - s += " "; + s += ' '; s += httplib::status_message(status); s += "\r\n"; return strm.write(s.data(), s.size()); @@ -2395,10 +3560,10 @@ bool write_content_with_progress(Stream &strm, return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.is_peer_alive(); }; while (offset < end_offset && !is_shutting_down()) { - if (!strm.wait_writable()) { + if (!strm.wait_writable() || !strm.is_peer_alive()) { error = Error::Write; return false; } else if (!content_provider(offset, end_offset - offset, data_sink)) { @@ -2410,6 +3575,11 @@ bool write_content_with_progress(Stream &strm, } } + if (offset < end_offset) { // exited due to is_shutting_down(), not completion + error = Error::Write; + return false; + } + error = Error::Success; return true; } @@ -2449,12 +3619,12 @@ write_content_without_length(Stream &strm, return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.is_peer_alive(); }; data_sink.done = [&](void) { data_available = false; }; while (data_available && !is_shutting_down()) { - if (!strm.wait_writable()) { + if (!strm.wait_writable() || !strm.is_peer_alive()) { return false; } else if (!content_provider(offset, 0, data_sink)) { return false; @@ -2462,7 +3632,8 @@ write_content_without_length(Stream &strm, return false; } } - return true; + return !data_available; // true only if done() was called, false if shutting + // down } template @@ -2498,7 +3669,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, return ok; }; - data_sink.is_writable = [&]() -> bool { return strm.wait_writable(); }; + data_sink.is_writable = [&]() -> bool { return strm.is_peer_alive(); }; auto done_with_trailer = [&](const Headers *trailer) { if (!ok) { return; } @@ -2548,7 +3719,7 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, }; while (data_available && !is_shutting_down()) { - if (!strm.wait_writable()) { + if (!strm.wait_writable() || !strm.is_peer_alive()) { error = Error::Write; return false; } else if (!content_provider(offset, 0, data_sink)) { @@ -2560,6 +3731,11 @@ write_content_chunked(Stream &strm, const ContentProvider &content_provider, } } + if (data_available) { // exited due to is_shutting_down(), not done() + error = Error::Write; + return false; + } + error = Error::Success; return true; } @@ -2592,8 +3768,8 @@ bool redirect(T &cli, Request &req, Response &res, auto ret = cli.send(new_req, new_res, error); if (ret) { - req = new_req; - res = new_res; + req = std::move(new_req); + res = std::move(new_res); if (res.location.empty()) { res.location = location; } } @@ -2604,9 +3780,9 @@ std::string params_to_query_str(const Params ¶ms) { std::string query; for (auto it = params.begin(); it != params.end(); ++it) { - if (it != params.begin()) { query += "&"; } + if (it != params.begin()) { query += '&'; } query += encode_query_component(it->first); - query += "="; + query += '='; query += encode_query_component(it->second); } return query; @@ -2639,14 +3815,45 @@ void parse_query_text(const std::string &s, Params ¶ms) { parse_query_text(s.data(), s.size(), params); } +// Normalize a query string by decoding and re-encoding each key/value pair +// while preserving the original parameter order. This avoids double-encoding +// and ensures consistent encoding without reordering (unlike Params which +// uses std::multimap and sorts keys). +std::string normalize_query_string(const std::string &query) { + std::string result; + split(query.data(), query.data() + query.size(), '&', + [&](const char *b, const char *e) { + std::string key; + std::string val; + divide(b, static_cast(e - b), '=', + [&](const char *lhs_data, std::size_t lhs_size, + const char *rhs_data, std::size_t rhs_size) { + key.assign(lhs_data, lhs_size); + val.assign(rhs_data, rhs_size); + }); + + if (!key.empty()) { + auto dec_key = decode_query_component(key); + auto dec_val = decode_query_component(val); + + if (!result.empty()) { result += '&'; } + result += encode_query_component(dec_key); + if (!val.empty() || std::find(b, e, '=') != e) { + result += '='; + result += encode_query_component(dec_val); + } + } + }); + return result; +} + bool parse_multipart_boundary(const std::string &content_type, std::string &boundary) { - auto boundary_keyword = "boundary="; - auto pos = content_type.find(boundary_keyword); - if (pos == std::string::npos) { return false; } - auto end = content_type.find(';', pos); - auto beg = pos + strlen(boundary_keyword); - boundary = trim_double_quotes_copy(content_type.substr(beg, end - beg)); + std::map params; + extract_media_type(content_type, ¶ms); + auto it = params.find("boundary"); + if (it == params.end()) { return false; } + boundary = it->second; return !boundary.empty(); } @@ -2704,10 +3911,20 @@ bool parse_range_header(const std::string &s, Ranges &ranges) try { return; } - const auto first = - static_cast(lhs.empty() ? -1 : std::stoll(lhs)); - const auto last = - static_cast(rhs.empty() ? -1 : std::stoll(rhs)); + ssize_t first = -1; + if (!lhs.empty()) { + ssize_t v; + auto res = detail::from_chars(lhs.data(), lhs.data() + lhs.size(), v); + if (res.ec == std::errc{}) { first = v; } + } + + ssize_t last = -1; + if (!rhs.empty()) { + ssize_t v; + auto res = detail::from_chars(rhs.data(), rhs.data() + rhs.size(), v); + if (res.ec == std::errc{}) { last = v; } + } + if ((first == -1 && last == -1) || (first != -1 && last != -1 && first > last)) { all_valid_ranges = false; @@ -2741,7 +3958,7 @@ bool parse_accept_header(const std::string &s, struct AcceptEntry { std::string media_type; double quality; - int order; // Original order in header + int order; }; std::vector entries; @@ -2759,64 +3976,16 @@ bool parse_accept_header(const std::string &s, } AcceptEntry accept_entry; - accept_entry.quality = 1.0; // Default quality accept_entry.order = order++; - // Find q= parameter - auto q_pos = entry.find(";q="); - if (q_pos == std::string::npos) { q_pos = entry.find("; q="); } - - if (q_pos != std::string::npos) { - // Extract media type (before q parameter) - accept_entry.media_type = trim_copy(entry.substr(0, q_pos)); - - // Extract quality value - auto q_start = entry.find('=', q_pos) + 1; - auto q_end = entry.find(';', q_start); - if (q_end == std::string::npos) { q_end = entry.length(); } - - std::string quality_str = - trim_copy(entry.substr(q_start, q_end - q_start)); - if (quality_str.empty()) { - has_invalid_entry = true; - return; - } - -#ifdef CPPHTTPLIB_NO_EXCEPTIONS - { - std::istringstream iss(quality_str); - iss >> accept_entry.quality; - - // Check if conversion was successful and entire string was consumed - if (iss.fail() || !iss.eof()) { - has_invalid_entry = true; - return; - } - } -#else - try { - accept_entry.quality = std::stod(quality_str); - } catch (...) { - has_invalid_entry = true; - return; - } -#endif - // Check if quality is in valid range [0.0, 1.0] - if (accept_entry.quality < 0.0 || accept_entry.quality > 1.0) { - has_invalid_entry = true; - return; - } - } else { - // No quality parameter, use entire entry as media type - accept_entry.media_type = entry; + if (!parse_quality(entry.data(), entry.data() + entry.size(), + accept_entry.media_type, accept_entry.quality)) { + has_invalid_entry = true; + return; } // Remove additional parameters from media type - auto param_pos = accept_entry.media_type.find(';'); - if (param_pos != std::string::npos) { - accept_entry.media_type = - trim_copy(accept_entry.media_type.substr(0, param_pos)); - } + accept_entry.media_type = extract_media_type(accept_entry.media_type); // Basic validation of media type format if (accept_entry.media_type.empty()) { @@ -2831,7 +4000,7 @@ bool parse_accept_header(const std::string &s, return; } - entries.push_back(accept_entry); + entries.push_back(std::move(accept_entry)); }); // Return false if any invalid entry was found @@ -2848,8 +4017,8 @@ bool parse_accept_header(const std::string &s, // Extract sorted media types content_types.reserve(entries.size()); - for (const auto &entry : entries) { - content_types.push_back(entry.media_type); + for (auto &entry : entries) { + content_types.push_back(std::move(entry.media_type)); } return true; @@ -2860,7 +4029,7 @@ public: FormDataParser() = default; void set_boundary(std::string &&boundary) { - boundary_ = boundary; + boundary_ = std::move(boundary); dash_boundary_crlf_ = dash_ + boundary_ + crlf_; crlf_dash_boundary_ = crlf_ + dash_ + boundary_; } @@ -2925,14 +4094,10 @@ public: file_.content_type = trim_copy(header.substr(str_len(header_content_type))); } else { - thread_local const std::regex re_content_disposition( - R"~(^Content-Disposition:\s*form-data;\s*(.*)$)~", - std::regex_constants::icase); - - std::smatch m; - if (std::regex_match(header, m, re_content_disposition)) { + std::string disposition_params; + if (parse_content_disposition(header, disposition_params)) { Params params; - parse_disposition_params(m[1], params); + parse_disposition_params(disposition_params, params); auto it = params.find("name"); if (it != params.end()) { @@ -2947,13 +4112,14 @@ public: it = params.find("filename*"); if (it != params.end()) { - // Only allow UTF-8 encoding... - thread_local const std::regex re_rfc5987_encoding( - R"~(^UTF-8''(.+?)$)~", std::regex_constants::icase); - - std::smatch m2; - if (std::regex_match(it->second, m2, re_rfc5987_encoding)) { - file_.filename = decode_path_component(m2[1]); // override... + // RFC 5987: only UTF-8 encoding is allowed + const auto &val = it->second; + constexpr const char utf8_prefix[] = "UTF-8''"; + constexpr size_t prefix_len = str_len(utf8_prefix); + if (val.size() > prefix_len && + start_with_case_ignore(val, utf8_prefix)) { + file_.filename = decode_path_component( + val.substr(prefix_len)); // override... } else { is_valid_ = false; return false; @@ -3021,17 +4187,48 @@ private: file_.headers.clear(); } - bool start_with_case_ignore(const std::string &a, const char *b) const { + bool start_with_case_ignore(const std::string &a, const char *b, + size_t offset = 0) const { const auto b_len = strlen(b); - if (a.size() < b_len) { return false; } + if (a.size() < offset + b_len) { return false; } for (size_t i = 0; i < b_len; i++) { - if (case_ignore::to_lower(a[i]) != case_ignore::to_lower(b[i])) { + if (case_ignore::to_lower(a[offset + i]) != case_ignore::to_lower(b[i])) { return false; } } return true; } + // Parses "Content-Disposition: form-data; " without std::regex. + // Returns true if header matches, with the params portion in `params_out`. + bool parse_content_disposition(const std::string &header, + std::string ¶ms_out) const { + constexpr const char prefix[] = "Content-Disposition:"; + constexpr size_t prefix_len = str_len(prefix); + + if (!start_with_case_ignore(header, prefix)) { return false; } + + // Skip whitespace after "Content-Disposition:" + auto pos = prefix_len; + while (pos < header.size() && (header[pos] == ' ' || header[pos] == '\t')) { + pos++; + } + + // Match "form-data;" (case-insensitive) + constexpr const char form_data[] = "form-data;"; + constexpr size_t form_data_len = str_len(form_data); + if (!start_with_case_ignore(header, form_data, pos)) { return false; } + pos += form_data_len; + + // Skip whitespace after "form-data;" + while (pos < header.size() && (header[pos] == ' ' || header[pos] == '\t')) { + pos++; + } + + params_out = header.substr(pos); + return true; + } + const std::string dash_ = "--"; const std::string crlf_ = "\r\n"; std::string boundary_; @@ -3192,6 +4389,104 @@ serialize_multipart_formdata(const UploadFormDataItems &items, return body; } +size_t get_multipart_content_length(const UploadFormDataItems &items, + const std::string &boundary) { + size_t total = 0; + for (const auto &item : items) { + total += serialize_multipart_formdata_item_begin(item, boundary).size(); + total += item.content.size(); + total += serialize_multipart_formdata_item_end().size(); + } + total += serialize_multipart_formdata_finish(boundary).size(); + return total; +} + +struct MultipartSegment { + const char *data; + size_t size; +}; + +// NOTE: items must outlive the returned ContentProvider +// (safe for synchronous use inside Post/Put/Patch) +ContentProvider +make_multipart_content_provider(const UploadFormDataItems &items, + const std::string &boundary) { + // Own the per-item header strings and the finish string + std::vector owned; + owned.reserve(items.size() + 1); + for (const auto &item : items) + owned.push_back(serialize_multipart_formdata_item_begin(item, boundary)); + owned.push_back(serialize_multipart_formdata_finish(boundary)); + + // Flat segment list: [header, content, "\r\n"] * N + [finish] + std::vector segs; + segs.reserve(items.size() * 3 + 1); + static const char crlf[] = "\r\n"; + for (size_t i = 0; i < items.size(); i++) { + segs.push_back({owned[i].data(), owned[i].size()}); + segs.push_back({items[i].content.data(), items[i].content.size()}); + segs.push_back({crlf, 2}); + } + segs.push_back({owned.back().data(), owned.back().size()}); + + struct MultipartState { + std::vector owned; + std::vector segs; + std::vector buf = std::vector(CPPHTTPLIB_SEND_BUFSIZ); + }; + auto state = std::make_shared(); + state->owned = std::move(owned); + // `segs` holds raw pointers into owned strings; std::string move preserves + // the data pointer, so these pointers remain valid after the move above. + state->segs = std::move(segs); + + return [state](size_t offset, size_t length, DataSink &sink) -> bool { + // Buffer multiple small segments into fewer, larger writes to avoid + // excessive TCP packets when there are many form data items (#2410) + auto &buf = state->buf; + auto buf_size = buf.size(); + size_t buf_len = 0; + size_t remaining = length; + + // Find the first segment containing 'offset' + size_t pos = 0; + size_t seg_idx = 0; + for (; seg_idx < state->segs.size(); seg_idx++) { + const auto &seg = state->segs[seg_idx]; + if (seg.size > 0 && offset - pos < seg.size) { break; } + pos += seg.size; + } + + size_t seg_offset = (seg_idx < state->segs.size()) ? offset - pos : 0; + + for (; seg_idx < state->segs.size() && remaining > 0; seg_idx++) { + const auto &seg = state->segs[seg_idx]; + size_t available = seg.size - seg_offset; + size_t to_copy = (std::min)(available, remaining); + const char *src = seg.data + seg_offset; + seg_offset = 0; // only the first segment has a non-zero offset + + while (to_copy > 0) { + size_t space = buf_size - buf_len; + size_t chunk = (std::min)(to_copy, space); + std::memcpy(buf.data() + buf_len, src, chunk); + buf_len += chunk; + src += chunk; + to_copy -= chunk; + remaining -= chunk; + + if (buf_len == buf_size) { + if (!sink.write(buf.data(), buf_len)) { return false; } + buf_len = 0; + } + } + } + + if (buf_len > 0) { return sink.write(buf.data(), buf_len); } + return true; + }; +} + void coalesce_ranges(Ranges &ranges, size_t content_length) { if (ranges.size() <= 1) return; @@ -3323,7 +4618,8 @@ get_range_offset_and_length(Range r, size_t content_length) { assert(r.first <= r.second && r.second < static_cast(content_length)); (void)(content_length); - return std::make_pair(r.first, static_cast(r.second - r.first) + 1); + return std::make_pair(static_cast(r.first), + static_cast(r.second - r.first) + 1); } std::string make_content_range_header_field( @@ -3333,9 +4629,9 @@ std::string make_content_range_header_field( std::string field = "bytes "; field += std::to_string(st); - field += "-"; + field += '-'; field += std::to_string(ed); - field += "/"; + field += '/'; field += std::to_string(content_length); return field; } @@ -3427,28 +4723,233 @@ write_multipart_ranges_data(Stream &strm, const Request &req, Response &res, }); } +bool has_framed_body(const Request &req) { + return is_chunked_transfer_encoding(req.headers) || + req.get_header_value_u64("Content-Length") > 0; +} + +bool is_connection_persistent(const Request &req) { + auto conn = req.get_header_value("Connection"); + if (conn == "close") { return false; } + if (req.version == "HTTP/1.0" && conn != "Keep-Alive") { return false; } + return true; +} + bool expect_content(const Request &req) { if (req.method == "POST" || req.method == "PUT" || req.method == "PATCH" || req.method == "DELETE") { return true; } - if (req.has_header("Content-Length") && - req.get_header_value_u64("Content-Length") > 0) { - return true; + return has_framed_body(req); +} + +#ifdef _WIN32 +class WSInit { +public: + WSInit() { + WSADATA wsaData; + if (WSAStartup(0x0002, &wsaData) == 0) is_valid_ = true; + } + + ~WSInit() { + if (is_valid_) WSACleanup(); + } + + bool is_valid_ = false; +}; + +static WSInit wsinit_; +#endif + +bool parse_www_authenticate(const Response &res, + std::map &auth, + bool is_proxy) { + auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; + if (res.has_header(auth_key)) { + thread_local auto re = + std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); + auto s = res.get_header_value(auth_key); + auto pos = s.find(' '); + if (pos != std::string::npos) { + auto type = s.substr(0, pos); + if (type == "Basic") { + return false; + } else if (type == "Digest") { + s = s.substr(pos + 1); + auto beg = std::sregex_iterator(s.begin(), s.end(), re); + for (auto i = beg; i != std::sregex_iterator(); ++i) { + const auto &m = *i; + auto key = s.substr(static_cast(m.position(1)), + static_cast(m.length(1))); + auto val = m.length(2) > 0 + ? s.substr(static_cast(m.position(2)), + static_cast(m.length(2))) + : s.substr(static_cast(m.position(3)), + static_cast(m.length(3))); + auth[std::move(key)] = std::move(val); + } + return true; + } + } } - if (is_chunked_transfer_encoding(req.headers)) { return true; } return false; } -bool has_crlf(const std::string &s) { - auto p = s.c_str(); - while (*p) { - if (*p == '\r' || *p == '\n') { return true; } - p++; +class ContentProviderAdapter { +public: + explicit ContentProviderAdapter( + ContentProviderWithoutLength &&content_provider) + : content_provider_(std::move(content_provider)) {} + + bool operator()(size_t offset, size_t, DataSink &sink) { + return content_provider_(offset, sink); } - return false; + +private: + ContentProviderWithoutLength content_provider_; +}; + +// NOTE: https://www.rfc-editor.org/rfc/rfc9110#section-5 +namespace fields { + +bool is_token_char(char c) { + return std::isalnum(c) || c == '!' || c == '#' || c == '$' || c == '%' || + c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' || + c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~'; } +bool is_token(const std::string &s) { + if (s.empty()) { return false; } + for (auto c : s) { + if (!is_token_char(c)) { return false; } + } + return true; +} + +bool is_field_name(const std::string &s) { return is_token(s); } + +bool is_vchar(char c) { return c >= 33 && c <= 126; } + +bool is_obs_text(char c) { return 128 <= static_cast(c); } + +bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); } + +bool is_field_content(const std::string &s) { + if (s.empty()) { return true; } + + if (s.size() == 1) { + return is_field_vchar(s[0]); + } else if (s.size() == 2) { + return is_field_vchar(s[0]) && is_field_vchar(s[1]); + } else { + size_t i = 0; + + if (!is_field_vchar(s[i])) { return false; } + i++; + + while (i < s.size() - 1) { + auto c = s[i++]; + if (c == ' ' || c == '\t' || is_field_vchar(c)) { + } else { + return false; + } + } + + return is_field_vchar(s[i]); + } +} + +bool is_field_value(const std::string &s) { return is_field_content(s); } + +} // namespace fields + +bool perform_websocket_handshake(Stream &strm, const std::string &host, + int port, const std::string &path, + const Headers &headers, + std::string &selected_subprotocol) { + // Validate path and host + if (!fields::is_field_value(path) || !fields::is_field_value(host)) { + return false; + } + + // Validate user-provided headers + for (const auto &h : headers) { + if (!fields::is_field_name(h.first) || !fields::is_field_value(h.second)) { + return false; + } + } + + // Generate random Sec-WebSocket-Key + thread_local std::mt19937 rng(std::random_device{}()); + std::string key_bytes(16, '\0'); + for (size_t i = 0; i < 16; i += 4) { + auto r = rng(); + std::memcpy(&key_bytes[i], &r, (std::min)(size_t(4), size_t(16 - i))); + } + auto client_key = base64_encode(key_bytes); + + // Build upgrade request + std::string req_str = "GET " + path + " HTTP/1.1\r\n"; + req_str += "Host: " + host + ":" + std::to_string(port) + "\r\n"; + req_str += "Upgrade: websocket\r\n"; + req_str += "Connection: Upgrade\r\n"; + req_str += "Sec-WebSocket-Key: " + client_key + "\r\n"; + req_str += "Sec-WebSocket-Version: 13\r\n"; + for (const auto &h : headers) { + req_str += h.first + ": " + h.second + "\r\n"; + } + req_str += "\r\n"; + + if (strm.write(req_str.data(), req_str.size()) < 0) { return false; } + + // Verify 101 response and Sec-WebSocket-Accept header + auto expected_accept = websocket_accept_key(client_key); + return read_websocket_upgrade_response(strm, expected_accept, + selected_subprotocol); +} + +} // namespace detail + +/* + * Group 2: detail namespace - SSL common utilities + */ + +#ifdef CPPHTTPLIB_SSL_ENABLED +namespace detail { + +class SSLSocketStream final : public Stream { +public: + SSLSocketStream( + socket_t sock, tls::session_t session, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, time_t max_timeout_msec = 0, + std::chrono::time_point start_time = + (std::chrono::steady_clock::time_point::min)()); + ~SSLSocketStream() override; + + bool is_readable() const override; + bool wait_readable() const override; + bool wait_writable() const override; + bool is_peer_alive() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + time_t duration() const override; + void set_read_timeout(time_t sec, time_t usec = 0) override; + +private: + socket_t sock_; + tls::session_t session_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; + time_t max_timeout_msec_; + const std::chrono::time_point start_time_; +}; + #ifdef CPPHTTPLIB_OPENSSL_SUPPORT std::string message_digest(const std::string &s, const EVP_MD *algo) { auto context = std::unique_ptr( @@ -3481,6 +4982,122 @@ std::string SHA_256(const std::string &s) { std::string SHA_512(const std::string &s) { return message_digest(s, EVP_sha512()); } +#elif defined(CPPHTTPLIB_MBEDTLS_SUPPORT) +namespace { +template +std::string hash_to_hex(const unsigned char (&hash)[N]) { + std::stringstream ss; + for (size_t i = 0; i < N; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') + << static_cast(hash[i]); + } + return ss.str(); +} +} // namespace + +std::string MD5(const std::string &s) { + unsigned char hash[16]; +#ifdef CPPHTTPLIB_MBEDTLS_V3 + mbedtls_md5(reinterpret_cast(s.c_str()), s.size(), + hash); +#else + mbedtls_md5_ret(reinterpret_cast(s.c_str()), s.size(), + hash); +#endif + return hash_to_hex(hash); +} + +std::string SHA_256(const std::string &s) { + unsigned char hash[32]; +#ifdef CPPHTTPLIB_MBEDTLS_V3 + mbedtls_sha256(reinterpret_cast(s.c_str()), s.size(), + hash, 0); +#else + mbedtls_sha256_ret(reinterpret_cast(s.c_str()), + s.size(), hash, 0); +#endif + return hash_to_hex(hash); +} + +std::string SHA_512(const std::string &s) { + unsigned char hash[64]; +#ifdef CPPHTTPLIB_MBEDTLS_V3 + mbedtls_sha512(reinterpret_cast(s.c_str()), s.size(), + hash, 0); +#else + mbedtls_sha512_ret(reinterpret_cast(s.c_str()), + s.size(), hash, 0); +#endif + return hash_to_hex(hash); +} +#elif defined(CPPHTTPLIB_WOLFSSL_SUPPORT) +namespace { +template +std::string hash_to_hex(const unsigned char (&hash)[N]) { + std::stringstream ss; + for (size_t i = 0; i < N; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') + << static_cast(hash[i]); + } + return ss.str(); +} +} // namespace + +std::string MD5(const std::string &s) { + unsigned char hash[WC_MD5_DIGEST_SIZE]; + wc_Md5Hash(reinterpret_cast(s.c_str()), + static_cast(s.size()), hash); + return hash_to_hex(hash); +} + +std::string SHA_256(const std::string &s) { + unsigned char hash[WC_SHA256_DIGEST_SIZE]; + wc_Sha256Hash(reinterpret_cast(s.c_str()), + static_cast(s.size()), hash); + return hash_to_hex(hash); +} + +std::string SHA_512(const std::string &s) { + unsigned char hash[WC_SHA512_DIGEST_SIZE]; + wc_Sha512Hash(reinterpret_cast(s.c_str()), + static_cast(s.size()), hash); + return hash_to_hex(hash); +} +#endif + +bool is_ip_address(const std::string &host) { + struct in_addr addr4; + struct in6_addr addr6; + return inet_pton(AF_INET, host.c_str(), &addr4) == 1 || + inet_pton(AF_INET6, host.c_str(), &addr6) == 1; +} + +template +bool process_server_socket_ssl( + const std::atomic &svr_sock, tls::session_t session, + socket_t sock, size_t keep_alive_max_count, time_t keep_alive_timeout_sec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + return process_server_socket_core( + svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SSLSocketStream strm(sock, session, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +template +bool process_client_socket_ssl( + tls::session_t session, socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, time_t max_timeout_msec, + std::chrono::time_point start_time, T callback) { + SSLSocketStream strm(sock, session, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec, max_timeout_msec, + start_time); + return callback(strm); +} std::pair make_digest_authentication_header( const Request &req, const std::map &auth, @@ -3542,200 +5159,357 @@ std::pair make_digest_authentication_header( return std::make_pair(key, field); } -bool is_ssl_peer_could_be_closed(SSL *ssl, socket_t sock) { - detail::set_nonblocking(sock, true); - auto se = detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); +bool match_hostname(const std::string &pattern, + const std::string &hostname) { + // Exact match (case-insensitive) + if (detail::case_ignore::equal(hostname, pattern)) { return true; } - char buf[1]; - return !SSL_peek(ssl, buf, 1) && - SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN; -} + // Split both pattern and hostname into components by '.' + std::vector pattern_components; + if (!pattern.empty()) { + split(pattern.data(), pattern.data() + pattern.size(), '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(b, e); + }); + } -#ifdef _WIN32 -// NOTE: This code came up with the following stackoverflow post: -// https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store -bool load_system_certs_on_windows(X509_STORE *store) { - auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); - if (!hStore) { return false; } + std::vector host_components; + if (!hostname.empty()) { + split(hostname.data(), hostname.data() + hostname.size(), '.', + [&](const char *b, const char *e) { + host_components.emplace_back(b, e); + }); + } - auto result = false; - PCCERT_CONTEXT pContext = NULL; - while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != - nullptr) { - auto encoded_cert = - static_cast(pContext->pbCertEncoded); + // Component count must match + if (host_components.size() != pattern_components.size()) { return false; } - auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); - if (x509) { - X509_STORE_add_cert(store, x509); - X509_free(x509); - result = true; + // Compare each component with wildcard support + // Supports: "*" (full wildcard), "prefix*" (partial wildcard) + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + auto itr = pattern_components.begin(); + for (const auto &h : host_components) { + auto &p = *itr; + if (!detail::case_ignore::equal(p, h) && p != "*") { + bool partial_match = false; + if (!p.empty() && p[p.size() - 1] == '*') { + const auto prefix_length = p.size() - 1; + if (prefix_length == 0) { + partial_match = true; + } else if (h.size() >= prefix_length) { + partial_match = + std::equal(p.begin(), + p.begin() + static_cast( + prefix_length), + h.begin(), [](const char ca, const char cb) { + return detail::case_ignore::to_lower(ca) == + detail::case_ignore::to_lower(cb); + }); + } + } + if (!partial_match) { return false; } } + ++itr; } - CertFreeCertificateContext(pContext); - CertCloseStore(hStore, 0); - - return result; -} -#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && TARGET_OS_MAC -template -using CFObjectPtr = - std::unique_ptr::type, void (*)(CFTypeRef)>; - -void cf_object_ptr_deleter(CFTypeRef obj) { - if (obj) { CFRelease(obj); } -} - -bool retrieve_certs_from_keychain(CFObjectPtr &certs) { - CFStringRef keys[] = {kSecClass, kSecMatchLimit, kSecReturnRef}; - CFTypeRef values[] = {kSecClassCertificate, kSecMatchLimitAll, - kCFBooleanTrue}; - - CFObjectPtr query( - CFDictionaryCreate(nullptr, reinterpret_cast(keys), values, - sizeof(keys) / sizeof(keys[0]), - &kCFTypeDictionaryKeyCallBacks, - &kCFTypeDictionaryValueCallBacks), - cf_object_ptr_deleter); - - if (!query) { return false; } - - CFTypeRef security_items = nullptr; - if (SecItemCopyMatching(query.get(), &security_items) != errSecSuccess || - CFArrayGetTypeID() != CFGetTypeID(security_items)) { - return false; - } - - certs.reset(reinterpret_cast(security_items)); return true; } -bool retrieve_root_certs_from_keychain(CFObjectPtr &certs) { - CFArrayRef root_security_items = nullptr; - if (SecTrustCopyAnchorCertificates(&root_security_items) != errSecSuccess) { +#ifdef _WIN32 +// Verify certificate using Windows CertGetCertificateChain API. +// This provides real-time certificate validation with Windows Update +// integration, independent of the TLS backend (OpenSSL or MbedTLS). +bool +verify_cert_with_windows_schannel(const std::vector &der_cert, + const std::string &hostname, + bool verify_hostname, uint64_t &out_error) { + if (der_cert.empty()) { return false; } + + out_error = 0; + + // Create Windows certificate context from DER data + auto cert_context = CertCreateCertificateContext( + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, der_cert.data(), + static_cast(der_cert.size())); + + if (!cert_context) { + out_error = GetLastError(); return false; } - certs.reset(root_security_items); - return true; -} + auto cert_guard = + scope_exit([&] { CertFreeCertificateContext(cert_context); }); -bool add_certs_to_x509_store(CFArrayRef certs, X509_STORE *store) { - auto result = false; - for (auto i = 0; i < CFArrayGetCount(certs); ++i) { - const auto cert = reinterpret_cast( - CFArrayGetValueAtIndex(certs, i)); + // Setup chain parameters + CERT_CHAIN_PARA chain_para = {}; + chain_para.cbSize = sizeof(chain_para); - if (SecCertificateGetTypeID() != CFGetTypeID(cert)) { continue; } + // Build certificate chain with revocation checking + PCCERT_CHAIN_CONTEXT chain_context = nullptr; + auto chain_result = CertGetCertificateChain( + nullptr, cert_context, nullptr, cert_context->hCertStore, &chain_para, + CERT_CHAIN_CACHE_END_CERT | CERT_CHAIN_REVOCATION_CHECK_END_CERT | + CERT_CHAIN_REVOCATION_ACCUMULATIVE_TIMEOUT, + nullptr, &chain_context); - CFDataRef cert_data = nullptr; - if (SecItemExport(cert, kSecFormatX509Cert, 0, nullptr, &cert_data) != - errSecSuccess) { - continue; - } - - CFObjectPtr cert_data_ptr(cert_data, cf_object_ptr_deleter); - - auto encoded_cert = static_cast( - CFDataGetBytePtr(cert_data_ptr.get())); - - auto x509 = - d2i_X509(NULL, &encoded_cert, CFDataGetLength(cert_data_ptr.get())); - - if (x509) { - X509_STORE_add_cert(store, x509); - X509_free(x509); - result = true; - } + if (!chain_result || !chain_context) { + out_error = GetLastError(); + return false; } - return result; -} + auto chain_guard = + scope_exit([&] { CertFreeCertificateChain(chain_context); }); -bool load_system_certs_on_macos(X509_STORE *store) { - auto result = false; - CFObjectPtr certs(nullptr, cf_object_ptr_deleter); - if (retrieve_certs_from_keychain(certs) && certs) { - result = add_certs_to_x509_store(certs.get(), store); + // Check if chain has errors + if (chain_context->TrustStatus.dwErrorStatus != CERT_TRUST_NO_ERROR) { + out_error = chain_context->TrustStatus.dwErrorStatus; + return false; } - if (retrieve_root_certs_from_keychain(certs) && certs) { - result = add_certs_to_x509_store(certs.get(), store) || result; - } - - return result; -} -#endif // _WIN32 -#endif // CPPHTTPLIB_OPENSSL_SUPPORT - -#ifdef _WIN32 -class WSInit { -public: - WSInit() { - WSADATA wsaData; - if (WSAStartup(0x0002, &wsaData) == 0) is_valid_ = true; - } - - ~WSInit() { - if (is_valid_) WSACleanup(); - } - - bool is_valid_ = false; -}; - -static WSInit wsinit_; + // Verify SSL policy + SSL_EXTRA_CERT_CHAIN_POLICY_PARA extra_policy_para = {}; + extra_policy_para.cbSize = sizeof(extra_policy_para); +#ifdef AUTHTYPE_SERVER + extra_policy_para.dwAuthType = AUTHTYPE_SERVER; #endif -bool parse_www_authenticate(const Response &res, - std::map &auth, - bool is_proxy) { - auto auth_key = is_proxy ? "Proxy-Authenticate" : "WWW-Authenticate"; - if (res.has_header(auth_key)) { - thread_local auto re = - std::regex(R"~((?:(?:,\s*)?(.+?)=(?:"(.*?)"|([^,]*))))~"); - auto s = res.get_header_value(auth_key); - auto pos = s.find(' '); - if (pos != std::string::npos) { - auto type = s.substr(0, pos); - if (type == "Basic") { - return false; - } else if (type == "Digest") { - s = s.substr(pos + 1); - auto beg = std::sregex_iterator(s.begin(), s.end(), re); - for (auto i = beg; i != std::sregex_iterator(); ++i) { - const auto &m = *i; - auto key = s.substr(static_cast(m.position(1)), - static_cast(m.length(1))); - auto val = m.length(2) > 0 - ? s.substr(static_cast(m.position(2)), - static_cast(m.length(2))) - : s.substr(static_cast(m.position(3)), - static_cast(m.length(3))); - auth[key] = val; - } - return true; - } - } + std::wstring whost; + if (verify_hostname) { + whost = u8string_to_wstring(hostname.c_str()); + extra_policy_para.pwszServerName = const_cast(whost.c_str()); } - return false; + + CERT_CHAIN_POLICY_PARA policy_para = {}; + policy_para.cbSize = sizeof(policy_para); +#ifdef CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS + policy_para.dwFlags = CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS; +#else + policy_para.dwFlags = 0; +#endif + policy_para.pvExtraPolicyPara = &extra_policy_para; + + CERT_CHAIN_POLICY_STATUS policy_status = {}; + policy_status.cbSize = sizeof(policy_status); + + if (!CertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL, chain_context, + &policy_para, &policy_status)) { + out_error = GetLastError(); + return false; + } + + if (policy_status.dwError != 0) { + out_error = policy_status.dwError; + return false; + } + + return true; +} +#endif // _WIN32 + +bool setup_client_tls_session(const std::string &host, tls::ctx_t &ctx, + tls::session_t &session, socket_t sock, + bool server_certificate_verification, + const std::string &ca_cert_file_path, + tls::ca_store_t ca_cert_store, + time_t timeout_sec, time_t timeout_usec) { + using namespace tls; + + ctx = create_client_context(); + if (!ctx) { return false; } + + if (server_certificate_verification) { + if (!ca_cert_file_path.empty()) { + load_ca_file(ctx, ca_cert_file_path.c_str()); + } + if (ca_cert_store) { set_ca_store(ctx, ca_cert_store); } + load_system_certs(ctx); + } + + bool is_ip = is_ip_address(host); + +#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT + if (is_ip && server_certificate_verification) { + set_verify_client(ctx, false); + } else { + set_verify_client(ctx, server_certificate_verification); + } +#endif + + session = create_session(ctx, sock); + if (!session) { return false; } + + // RFC 6066: SNI must not be set for IP addresses + if (!is_ip) { set_sni(session, host.c_str()); } + if (server_certificate_verification) { set_hostname(session, host.c_str()); } + + if (!connect_nonblocking(session, sock, timeout_sec, timeout_usec, nullptr)) { + return false; + } + + if (server_certificate_verification) { + if (get_verify_result(session) != 0) { return false; } + } + + return true; } -class ContentProviderAdapter { -public: - explicit ContentProviderAdapter( - ContentProviderWithoutLength &&content_provider) - : content_provider_(content_provider) {} +} // namespace detail +#endif // CPPHTTPLIB_SSL_ENABLED - bool operator()(size_t offset, size_t, DataSink &sink) { - return content_provider_(offset, sink); +/* + * Group 3: httplib namespace - Non-SSL public API implementations + */ + +void default_socket_options(socket_t sock) { + set_socket_opt(sock, SOL_SOCKET, +#ifdef SO_REUSEPORT + SO_REUSEPORT, +#else + SO_REUSEADDR, +#endif + 1); +} + +bool set_socket_opt(socket_t sock, int level, int optname, int optval) { + return detail::set_socket_opt_impl(sock, level, optname, &optval, + sizeof(optval)); +} + +std::string get_bearer_token_auth(const Request &req) { + if (req.has_header("Authorization")) { + constexpr auto bearer_header_prefix_len = detail::str_len("Bearer "); + return req.get_header_value("Authorization") + .substr(bearer_header_prefix_len); + } + return ""; +} + +const char *status_message(int status) { + switch (status) { + case StatusCode::Continue_100: return "Continue"; + case StatusCode::SwitchingProtocol_101: return "Switching Protocol"; + case StatusCode::Processing_102: return "Processing"; + case StatusCode::EarlyHints_103: return "Early Hints"; + case StatusCode::OK_200: return "OK"; + case StatusCode::Created_201: return "Created"; + case StatusCode::Accepted_202: return "Accepted"; + case StatusCode::NonAuthoritativeInformation_203: + return "Non-Authoritative Information"; + case StatusCode::NoContent_204: return "No Content"; + case StatusCode::ResetContent_205: return "Reset Content"; + case StatusCode::PartialContent_206: return "Partial Content"; + case StatusCode::MultiStatus_207: return "Multi-Status"; + case StatusCode::AlreadyReported_208: return "Already Reported"; + case StatusCode::IMUsed_226: return "IM Used"; + case StatusCode::MultipleChoices_300: return "Multiple Choices"; + case StatusCode::MovedPermanently_301: return "Moved Permanently"; + case StatusCode::Found_302: return "Found"; + case StatusCode::SeeOther_303: return "See Other"; + case StatusCode::NotModified_304: return "Not Modified"; + case StatusCode::UseProxy_305: return "Use Proxy"; + case StatusCode::unused_306: return "unused"; + case StatusCode::TemporaryRedirect_307: return "Temporary Redirect"; + case StatusCode::PermanentRedirect_308: return "Permanent Redirect"; + case StatusCode::BadRequest_400: return "Bad Request"; + case StatusCode::Unauthorized_401: return "Unauthorized"; + case StatusCode::PaymentRequired_402: return "Payment Required"; + case StatusCode::Forbidden_403: return "Forbidden"; + case StatusCode::NotFound_404: return "Not Found"; + case StatusCode::MethodNotAllowed_405: return "Method Not Allowed"; + case StatusCode::NotAcceptable_406: return "Not Acceptable"; + case StatusCode::ProxyAuthenticationRequired_407: + return "Proxy Authentication Required"; + case StatusCode::RequestTimeout_408: return "Request Timeout"; + case StatusCode::Conflict_409: return "Conflict"; + case StatusCode::Gone_410: return "Gone"; + case StatusCode::LengthRequired_411: return "Length Required"; + case StatusCode::PreconditionFailed_412: return "Precondition Failed"; + case StatusCode::PayloadTooLarge_413: return "Payload Too Large"; + case StatusCode::UriTooLong_414: return "URI Too Long"; + case StatusCode::UnsupportedMediaType_415: return "Unsupported Media Type"; + case StatusCode::RangeNotSatisfiable_416: return "Range Not Satisfiable"; + case StatusCode::ExpectationFailed_417: return "Expectation Failed"; + case StatusCode::ImATeapot_418: return "I'm a teapot"; + case StatusCode::MisdirectedRequest_421: return "Misdirected Request"; + case StatusCode::UnprocessableContent_422: return "Unprocessable Content"; + case StatusCode::Locked_423: return "Locked"; + case StatusCode::FailedDependency_424: return "Failed Dependency"; + case StatusCode::TooEarly_425: return "Too Early"; + case StatusCode::UpgradeRequired_426: return "Upgrade Required"; + case StatusCode::PreconditionRequired_428: return "Precondition Required"; + case StatusCode::TooManyRequests_429: return "Too Many Requests"; + case StatusCode::RequestHeaderFieldsTooLarge_431: + return "Request Header Fields Too Large"; + case StatusCode::UnavailableForLegalReasons_451: + return "Unavailable For Legal Reasons"; + case StatusCode::NotImplemented_501: return "Not Implemented"; + case StatusCode::BadGateway_502: return "Bad Gateway"; + case StatusCode::ServiceUnavailable_503: return "Service Unavailable"; + case StatusCode::GatewayTimeout_504: return "Gateway Timeout"; + case StatusCode::HttpVersionNotSupported_505: + return "HTTP Version Not Supported"; + case StatusCode::VariantAlsoNegotiates_506: return "Variant Also Negotiates"; + case StatusCode::InsufficientStorage_507: return "Insufficient Storage"; + case StatusCode::LoopDetected_508: return "Loop Detected"; + case StatusCode::NotExtended_510: return "Not Extended"; + case StatusCode::NetworkAuthenticationRequired_511: + return "Network Authentication Required"; + + default: + case StatusCode::InternalServerError_500: return "Internal Server Error"; + } +} + +std::string to_string(const Error error) { + switch (error) { + case Error::Success: return "Success (no error)"; + case Error::Unknown: return "Unknown"; + case Error::Connection: return "Could not establish connection"; + case Error::BindIPAddress: return "Failed to bind IP address"; + case Error::Read: return "Failed to read connection"; + case Error::Write: return "Failed to write connection"; + case Error::ExceedRedirectCount: return "Maximum redirect count exceeded"; + case Error::Canceled: return "Connection handling canceled"; + case Error::SSLConnection: return "SSL connection failed"; + case Error::SSLLoadingCerts: return "SSL certificate loading failed"; + case Error::SSLServerVerification: return "SSL server verification failed"; + case Error::SSLServerHostnameVerification: + return "SSL server hostname verification failed"; + case Error::UnsupportedMultipartBoundaryChars: + return "Unsupported HTTP multipart boundary characters"; + case Error::Compression: return "Compression failed"; + case Error::ConnectionTimeout: return "Connection timed out"; + case Error::ProxyConnection: return "Proxy connection failed"; + case Error::ConnectionClosed: return "Connection closed by server"; + case Error::Timeout: return "Read timeout"; + case Error::ResourceExhaustion: return "Resource exhaustion"; + case Error::TooManyFormDataFiles: return "Too many form data files"; + case Error::ExceedMaxPayloadSize: return "Exceeded maximum payload size"; + case Error::ExceedUriMaxLength: return "Exceeded maximum URI length"; + case Error::ExceedMaxSocketDescriptorCount: + return "Exceeded maximum socket descriptor count"; + case Error::InvalidRequestLine: return "Invalid request line"; + case Error::InvalidHTTPMethod: return "Invalid HTTP method"; + case Error::InvalidHTTPVersion: return "Invalid HTTP version"; + case Error::InvalidHeaders: return "Invalid headers"; + case Error::MultipartParsing: return "Multipart parsing failed"; + case Error::OpenFile: return "Failed to open file"; + case Error::Listen: return "Failed to listen on socket"; + case Error::GetSockName: return "Failed to get socket name"; + case Error::UnsupportedAddressFamily: return "Unsupported address family"; + case Error::HTTPParsing: return "HTTP parsing failed"; + case Error::InvalidRangeHeader: return "Invalid Range header"; + default: break; } -private: - ContentProviderWithoutLength content_provider_; -}; + return "Invalid"; +} -} // namespace detail +std::ostream &operator<<(std::ostream &os, const Error &obj) { + os << to_string(obj); + os << " (" << static_cast::type>(obj) << ')'; + return os; +} std::string hosted_at(const std::string &hostname) { std::vector addrs; @@ -3770,7 +5544,7 @@ void hosted_at(const std::string &hostname, auto dummy = -1; if (detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage), ip, dummy)) { - addrs.push_back(ip); + addrs.emplace_back(std::move(ip)); } } } @@ -3903,7 +5677,8 @@ std::string decode_path_component(const std::string &component) { // Unicode %uXXXX encoding auto val = 0; if (detail::from_hex_to_i(component, i + 2, 4, val)) { - // 4 digits Unicode codes + // 4 digits Unicode codes: val is 0x0000-0xFFFF (from 4 hex digits), + // so to_utf8 writes at most 3 bytes. buff[4] is safe. char buff[4]; size_t len = detail::to_utf8(val, buff); if (len > 0) { result.append(buff, len); } @@ -4008,6 +5783,30 @@ std::string decode_query_component(const std::string &component, return result; } +std::string sanitize_filename(const std::string &filename) { + // Extract basename: find the last path separator (/ or \) + auto pos = filename.find_last_of("/\\"); + auto result = + (pos != std::string::npos) ? filename.substr(pos + 1) : filename; + + // Strip null bytes + result.erase(std::remove(result.begin(), result.end(), '\0'), result.end()); + + // Trim whitespace + { + auto start = result.find_first_not_of(" \t"); + auto end = result.find_last_not_of(" \t"); + result = (start == std::string::npos) + ? "" + : result.substr(start, end - start + 1); + } + + // Reject . and .. + if (result == "." || result == "..") { return ""; } + + return result; +} + std::string append_query_params(const std::string &path, const Params ¶ms) { std::string path_with_query = path; @@ -4049,6 +5848,11 @@ make_bearer_token_authentication_header(const std::string &token, } // Request implementation +size_t Request::get_header_value_u64(const std::string &key, size_t def, + size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + bool Request::has_header(const std::string &key) const { return detail::has_header(headers, key); } @@ -4059,16 +5863,12 @@ std::string Request::get_header_value(const std::string &key, } size_t Request::get_header_value_count(const std::string &key) const { - auto r = headers.equal_range(key); - return static_cast(std::distance(r.first, r.second)); + return detail::get_header_value_count(headers, key); } void Request::set_header(const std::string &key, const std::string &val) { - if (detail::fields::is_field_name(key) && - detail::fields::is_field_value(val)) { - headers.emplace(key, val); - } + detail::set_header(headers, key, val); } bool Request::has_trailer(const std::string &key) const { @@ -4077,11 +5877,7 @@ bool Request::has_trailer(const std::string &key) const { std::string Request::get_trailer_value(const std::string &key, size_t id) const { - auto rng = trailers.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second; } - return std::string(); + return detail::get_multimap_value(trailers, key, id); } size_t Request::get_trailer_value_count(const std::string &key) const { @@ -4095,11 +5891,18 @@ bool Request::has_param(const std::string &key) const { std::string Request::get_param_value(const std::string &key, size_t id) const { + return detail::get_multimap_value(params, key, id); +} + +std::vector +Request::get_param_values(const std::string &key) const { auto rng = params.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second; } - return std::string(); + std::vector values; + values.reserve(static_cast(std::distance(rng.first, rng.second))); + for (auto it = rng.first; it != rng.second; ++it) { + values.push_back(it->second); + } + return values; } size_t Request::get_param_value_count(const std::string &key) const { @@ -4109,7 +5912,7 @@ size_t Request::get_param_value_count(const std::string &key) const { bool Request::is_multipart_form_data() const { const auto &content_type = get_header_value("Content-Type"); - return !content_type.rfind("multipart/form-data", 0); + return detail::extract_media_type(content_type) == "multipart/form-data"; } // Multipart FormData implementation @@ -4143,11 +5946,7 @@ size_t MultipartFormData::get_field_count(const std::string &key) const { FormData MultipartFormData::get_file(const std::string &key, size_t id) const { - auto rng = files.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second; } - return FormData(); + return detail::get_multimap_value(files, key, id); } std::vector @@ -4170,6 +5969,11 @@ size_t MultipartFormData::get_file_count(const std::string &key) const { } // Response implementation +size_t Response::get_header_value_u64(const std::string &key, size_t def, + size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + bool Response::has_header(const std::string &key) const { return headers.find(key) != headers.end(); } @@ -4181,16 +5985,12 @@ std::string Response::get_header_value(const std::string &key, } size_t Response::get_header_value_count(const std::string &key) const { - auto r = headers.equal_range(key); - return static_cast(std::distance(r.first, r.second)); + return detail::get_header_value_count(headers, key); } void Response::set_header(const std::string &key, const std::string &val) { - if (detail::fields::is_field_name(key) && - detail::fields::is_field_value(val)) { - headers.emplace(key, val); - } + detail::set_header(headers, key, val); } bool Response::has_trailer(const std::string &key) const { return trailers.find(key) != trailers.end(); @@ -4198,11 +5998,7 @@ bool Response::has_trailer(const std::string &key) const { std::string Response::get_trailer_value(const std::string &key, size_t id) const { - auto rng = trailers.equal_range(key); - auto it = rng.first; - std::advance(it, static_cast(id)); - if (it != rng.second) { return it->second; } - return std::string(); + return detail::get_multimap_value(trailers, key, id); } size_t Response::get_trailer_value_count(const std::string &key) const { @@ -4285,6 +6081,12 @@ void Response::set_file_content(const std::string &path) { } // Result implementation +size_t Result::get_request_header_value_u64(const std::string &key, + size_t def, + size_t id) const { + return detail::get_header_value_u64(request_headers_, key, def, id); +} + bool Result::has_request_header(const std::string &key) const { return request_headers_.find(key) != request_headers_.end(); } @@ -4310,6 +6112,227 @@ ssize_t Stream::write(const std::string &s) { return write(s.data(), s.size()); } +// BodyReader implementation +ssize_t detail::BodyReader::read(char *buf, size_t len) { + if (!stream) { + last_error = Error::Connection; + return -1; + } + if (eof) { return 0; } + + if (!chunked) { + // Content-Length based reading + if (has_content_length && bytes_read >= content_length) { + eof = true; + return 0; + } + + auto to_read = len; + if (has_content_length) { + auto remaining = content_length - bytes_read; + to_read = (std::min)(len, remaining); + } + auto n = stream->read(buf, to_read); + + if (n < 0) { + last_error = stream->get_error(); + if (last_error == Error::Success) { last_error = Error::Read; } + eof = true; + return n; + } + if (n == 0) { + // Unexpected EOF before content_length + last_error = stream->get_error(); + if (last_error == Error::Success) { last_error = Error::Read; } + eof = true; + return 0; + } + + bytes_read += static_cast(n); + if (has_content_length && bytes_read >= content_length) { eof = true; } + if (payload_max_length > 0 && bytes_read > payload_max_length) { + last_error = Error::ExceedMaxPayloadSize; + eof = true; + return -1; + } + return n; + } + + // Chunked transfer encoding: delegate to shared decoder instance. + if (!chunked_decoder) { chunked_decoder.reset(new ChunkedDecoder(*stream)); } + + size_t chunk_offset = 0; + size_t chunk_total = 0; + auto n = chunked_decoder->read_payload(buf, len, chunk_offset, chunk_total); + if (n < 0) { + last_error = stream->get_error(); + if (last_error == Error::Success) { last_error = Error::Read; } + eof = true; + return n; + } + + if (n == 0) { + // Final chunk observed. Leave trailer parsing to the caller (StreamHandle). + eof = true; + return 0; + } + + bytes_read += static_cast(n); + if (payload_max_length > 0 && bytes_read > payload_max_length) { + last_error = Error::ExceedMaxPayloadSize; + eof = true; + return -1; + } + return n; +} + +// ThreadPool implementation +ThreadPool::ThreadPool(size_t n, size_t max_n, size_t mqr) + : base_thread_count_(n), max_queued_requests_(mqr), idle_thread_count_(0), + shutdown_(false) { +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + if (max_n != 0 && max_n < n) { + std::string msg = "max_threads must be >= base_threads"; + throw std::invalid_argument(msg); + } +#endif + max_thread_count_ = max_n == 0 ? n : max_n; + threads_.reserve(base_thread_count_); +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + try { +#endif + for (size_t i = 0; i < base_thread_count_; i++) { + threads_.emplace_back(std::thread([this]() { worker(false); })); + } +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + } catch (...) { + // If thread creation fails partway (e.g., pthread_create returns EAGAIN), + // signal the workers we already spawned to exit and join them so the + // vector destructor does not see joinable threads (which would call + // std::terminate). Then rethrow so the caller learns of the failure. + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + cond_.notify_all(); + for (auto &t : threads_) { + if (t.joinable()) { t.join(); } + } + throw; + } +#endif +} + +bool ThreadPool::enqueue(std::function fn) { + { + std::unique_lock lock(mutex_); + if (shutdown_) { return false; } + if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { + return false; + } + jobs_.push_back(std::move(fn)); + + // Spawn a dynamic thread if no idle threads and under max + if (idle_thread_count_ == 0 && + threads_.size() + dynamic_threads_.size() < max_thread_count_) { + cleanup_finished_threads(); + dynamic_threads_.emplace_back(std::thread([this]() { worker(true); })); + } + } + + cond_.notify_one(); + return true; +} + +void ThreadPool::shutdown() { + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + for (auto &t : threads_) { + if (t.joinable()) { t.join(); } + } + + // Move dynamic_threads_ to a local list under the lock to avoid racing + // with worker threads that call move_to_finished() concurrently. + std::list remaining_dynamic; + { + std::unique_lock lock(mutex_); + remaining_dynamic = std::move(dynamic_threads_); + } + for (auto &t : remaining_dynamic) { + if (t.joinable()) { t.join(); } + } + + std::unique_lock lock(mutex_); + cleanup_finished_threads(); +} + +void ThreadPool::move_to_finished(std::thread::id id) { + // Must be called with mutex_ held + for (auto it = dynamic_threads_.begin(); it != dynamic_threads_.end(); ++it) { + if (it->get_id() == id) { + finished_threads_.push_back(std::move(*it)); + dynamic_threads_.erase(it); + return; + } + } +} + +void ThreadPool::cleanup_finished_threads() { + // Must be called with mutex_ held + for (auto &t : finished_threads_) { + if (t.joinable()) { t.join(); } + } + finished_threads_.clear(); +} + +void ThreadPool::worker(bool is_dynamic) { + for (;;) { + std::function fn; + { + std::unique_lock lock(mutex_); + idle_thread_count_++; + + if (is_dynamic) { + auto has_work = cond_.wait_for( + lock, std::chrono::seconds(CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT), + [&] { return !jobs_.empty() || shutdown_; }); + if (!has_work) { + // Timed out with no work - exit this dynamic thread + idle_thread_count_--; + move_to_finished(std::this_thread::get_id()); + break; + } + } else { + cond_.wait(lock, [&] { return !jobs_.empty() || shutdown_; }); + } + + idle_thread_count_--; + + if (shutdown_ && jobs_.empty()) { break; } + + fn = std::move(jobs_.front()); + jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ + !defined(LIBRESSL_VERSION_NUMBER) + OPENSSL_thread_stop(); +#endif +} + +/* + * Group 1 (continued): detail namespace - Stream implementations + */ + namespace detail { void calc_actual_timeout(time_t max_timeout_msec, time_t duration_msec, @@ -4360,8 +6383,11 @@ bool SocketStream::wait_readable() const { } bool SocketStream::wait_writable() const { - return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && - is_socket_alive(sock_); + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0; +} + +bool SocketStream::is_peer_alive() const { + return detail::is_socket_alive(sock_); } ssize_t SocketStream::read(char *ptr, size_t size) { @@ -4386,7 +6412,10 @@ ssize_t SocketStream::read(char *ptr, size_t size) { } } - if (!wait_readable()) { return -1; } + if (!wait_readable()) { + error_ = Error::Timeout; + return -1; + } read_buff_off_ = 0; read_buff_content_size_ = 0; @@ -4395,6 +6424,11 @@ ssize_t SocketStream::read(char *ptr, size_t size) { auto n = read_socket(sock_, read_buff_.data(), read_buff_size_, CPPHTTPLIB_RECV_FLAGS); if (n <= 0) { + if (n == 0) { + error_ = Error::ConnectionClosed; + } else { + error_ = Error::Read; + } return n; } else if (n <= static_cast(size)) { memcpy(ptr, read_buff_.data(), static_cast(n)); @@ -4406,7 +6440,15 @@ ssize_t SocketStream::read(char *ptr, size_t size) { return static_cast(size); } } else { - return read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + auto n = read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + if (n <= 0) { + if (n == 0) { + error_ = Error::ConnectionClosed; + } else { + error_ = Error::Read; + } + } + return n; } } @@ -4439,6 +6481,11 @@ time_t SocketStream::duration() const { .count(); } +void SocketStream::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + // Buffer stream implementation bool BufferStream::is_readable() const { return true; } @@ -4570,19 +6617,22 @@ bool RegexMatcher::match(Request &request) const { return std::regex_match(request.path, request.matches, regex_); } -std::string make_host_and_port_string(const std::string &host, int port, - bool is_ssl) { - std::string result; - +// Enclose IPv6 address in brackets if needed +std::string prepare_host_string(const std::string &host) { // Enclose IPv6 address in brackets (but not if already enclosed) if (host.find(':') == std::string::npos || (!host.empty() && host[0] == '[')) { // IPv4, hostname, or already bracketed IPv6 - result = host; + return host; } else { // IPv6 address without brackets - result = "[" + host + "]"; + return "[" + host + "]"; } +} + +std::string make_host_and_port_string(const std::string &host, int port, + bool is_ssl) { + auto result = prepare_host_string(host); // Append port if not default if ((!is_ssl && port == 80) || (is_ssl && port == 443)) { @@ -4594,12 +6644,365 @@ std::string make_host_and_port_string(const std::string &host, int port, return result; } +// Create "host:port" string always including port number (for CONNECT method) +std::string +make_host_and_port_string_always_port(const std::string &host, int port) { + return prepare_host_string(host) + ":" + std::to_string(port); +} + +bool parse_no_proxy_entry(const std::string &token, NoProxyEntry &out); +NormalizedTarget normalize_target(const std::string &host); +bool ip_in_cidr(const IPBytes &ip, const IPBytes &net, int prefix_bits); +bool host_matches_no_proxy(const NormalizedTarget &target, + const std::vector &entries); + +bool ip_in_cidr(const IPBytes &ip, const IPBytes &net, int prefix_bits) { + if (prefix_bits < 0 || prefix_bits > 128) { return false; } + if (prefix_bits == 0) { return true; } + int full_bytes = prefix_bits / 8; + int rem_bits = prefix_bits % 8; + if (full_bytes > 0 && std::memcmp(ip.data(), net.data(), + static_cast(full_bytes)) != 0) { + return false; + } + if (rem_bits == 0) { return true; } + auto i = static_cast(full_bytes); + auto mask = static_cast(0xFFu << (8 - rem_bits)); + return (ip[i] & mask) == (net[i] & mask); +} + +bool parse_no_proxy_entry(const std::string &token, NoProxyEntry &out) { + if (token.empty()) { return false; } + + if (token == "*") { + out.kind = NoProxyKind::Wildcard; + return true; + } + + auto slash = token.find('/'); + std::string addr_part = + (slash == std::string::npos) ? token : token.substr(0, slash); + std::string prefix_part = + (slash == std::string::npos) ? std::string() : token.substr(slash + 1); + + // A bare slash or trailing-slash CIDR like "10.0.0.0/" is malformed; + // don't silently treat it as a /32 (or /128). + if (slash != std::string::npos && prefix_part.empty()) { return false; } + + // Accept the bracketed IPv6 form ("[::1]", "[fe80::]/10") as well as the + // bare form. Brackets have no meaning for IPv4, so skip the IPv4 attempt + // when brackets are present. + bool bracketed = addr_part.size() >= 2 && addr_part.front() == '[' && + addr_part.back() == ']'; + if (bracketed) { addr_part = addr_part.substr(1, addr_part.size() - 2); } + + if (!bracketed) { + struct in_addr v4; + if (inet_pton(AF_INET, addr_part.c_str(), &v4) == 1) { + int prefix = 32; + if (!prefix_part.empty()) { + auto r = from_chars(prefix_part.data(), + prefix_part.data() + prefix_part.size(), prefix); + if (r.ec != std::errc{} || + r.ptr != prefix_part.data() + prefix_part.size()) { + return false; + } + if (prefix < 0 || prefix > 32) { return false; } + } + out.kind = NoProxyKind::IPv4Cidr; + std::memcpy(out.net.data(), &v4, sizeof(v4)); + out.prefix_bits = prefix; + return true; + } + } + + struct in6_addr v6; + if (inet_pton(AF_INET6, addr_part.c_str(), &v6) == 1) { + int prefix = 128; + if (!prefix_part.empty()) { + auto r = from_chars(prefix_part.data(), + prefix_part.data() + prefix_part.size(), prefix); + if (r.ec != std::errc{} || + r.ptr != prefix_part.data() + prefix_part.size()) { + return false; + } + if (prefix < 0 || prefix > 128) { return false; } + } + out.kind = NoProxyKind::IPv6Cidr; + std::memcpy(out.net.data(), &v6, sizeof(v6)); + out.prefix_bits = prefix; + return true; + } + + // Bracketed entries can only be IPv6. If the IPv6 parse above failed, + // the entry is malformed โ€” don't fall through to the hostname branch. + if (bracketed) { return false; } + + // A '/' on a non-IP token means a CIDR prefix without an address. Reject. + if (slash != std::string::npos) { return false; } + // Port-specific entries (host:port) are not supported. + if (token.find(':') != std::string::npos) { return false; } + + std::string hostname = case_ignore::to_lower(token); + while (!hostname.empty() && hostname.front() == '.') { + hostname.erase(hostname.begin()); + } + while (!hostname.empty() && hostname.back() == '.') { + hostname.pop_back(); + } + if (hostname.empty()) { return false; } + + out.kind = NoProxyKind::HostnameSuffix; + out.hostname_pattern = std::move(hostname); + return true; +} + +NormalizedTarget normalize_target(const std::string &host) { + NormalizedTarget t; + std::string h = host; + + if (h.size() >= 2 && h.front() == '[' && h.back() == ']') { + h = h.substr(1, h.size() - 2); + } + + // Strip a single trailing dot so "example.com." canonicalizes to + // "example.com". + if (!h.empty() && h.back() == '.') { h.pop_back(); } + + t.hostname = case_ignore::to_lower(h); + + if (!t.hostname.empty()) { + struct in_addr v4; + struct in6_addr v6; + if (inet_pton(AF_INET, t.hostname.c_str(), &v4) == 1) { + t.is_ipv4 = true; + std::memcpy(t.ip.data(), &v4, sizeof(v4)); + } else if (inet_pton(AF_INET6, t.hostname.c_str(), &v6) == 1) { + t.is_ipv6 = true; + std::memcpy(t.ip.data(), &v6, sizeof(v6)); + } + } + return t; +} + +bool host_matches_no_proxy(const NormalizedTarget &target, + const std::vector &entries) { + if (target.hostname.empty()) { return false; } + for (const auto &e : entries) { + switch (e.kind) { + case NoProxyKind::Wildcard: return true; + case NoProxyKind::IPv4Cidr: + if (target.is_ipv4 && ip_in_cidr(target.ip, e.net, e.prefix_bits)) { + return true; + } + break; + case NoProxyKind::IPv6Cidr: + if (target.is_ipv6 && ip_in_cidr(target.ip, e.net, e.prefix_bits)) { + return true; + } + break; + case NoProxyKind::HostnameSuffix: + if (target.is_ipv4 || target.is_ipv6) { break; } + if (target.hostname == e.hostname_pattern) { return true; } + // Dot-boundary suffix match: prevents "evilexample.com" from matching + // an entry of "example.com". + if (target.hostname.size() > e.hostname_pattern.size() + 1) { + auto offset = target.hostname.size() - e.hostname_pattern.size(); + if (target.hostname[offset - 1] == '.' && + target.hostname.compare(offset, e.hostname_pattern.size(), + e.hostname_pattern) == 0) { + return true; + } + } + break; + } + } + return false; +} + +template +bool check_and_write_headers(Stream &strm, Headers &headers, + T header_writer, Error &error) { + for (const auto &h : headers) { + if (!detail::fields::is_field_name(h.first) || + !detail::fields::is_field_value(h.second)) { + error = Error::InvalidHeaders; + return false; + } + } + if (header_writer(strm, headers) <= 0) { + error = Error::Write; + return false; + } + return true; +} + } // namespace detail +/* + * Group 2 (continued): detail namespace - SSLSocketStream implementation + */ + +#ifdef CPPHTTPLIB_SSL_ENABLED +namespace detail { + +// SSL socket stream implementation +SSLSocketStream::SSLSocketStream( + socket_t sock, tls::session_t session, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, time_t max_timeout_msec, + std::chrono::time_point start_time) + : sock_(sock), session_(session), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec), + max_timeout_msec_(max_timeout_msec), start_time_(start_time) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // Clear AUTO_RETRY for proper non-blocking I/O timeout handling + // Note: create_session() also clears this, but SSLClient currently + // uses ssl_new() which does not. Until full TLS API migration is complete, + // we need to ensure AUTO_RETRY is cleared here regardless of how the + // SSL session was created. + SSL_clear_mode(static_cast(session), SSL_MODE_AUTO_RETRY); +#endif +} + +SSLSocketStream::~SSLSocketStream() = default; + +bool SSLSocketStream::is_readable() const { + return tls::pending(session_) > 0; +} + +bool SSLSocketStream::wait_readable() const { + if (max_timeout_msec_ <= 0) { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + } + + time_t read_timeout_sec; + time_t read_timeout_usec; + calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_, + read_timeout_usec_, read_timeout_sec, read_timeout_usec); + + return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; +} + +bool SSLSocketStream::wait_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && + !tls::is_peer_closed(session_, sock_); +} + +bool SSLSocketStream::is_peer_alive() const { + return !tls::is_peer_closed(session_, sock_); +} + +ssize_t SSLSocketStream::read(char *ptr, size_t size) { + if (tls::pending(session_) > 0) { + tls::TlsError err; + auto ret = tls::read(session_, ptr, size, err); + if (ret == 0 || err.code == tls::ErrorCode::PeerClosed) { + error_ = Error::ConnectionClosed; + } + return ret; + } else if (wait_readable()) { + tls::TlsError err; + auto ret = tls::read(session_, ptr, size, err); + if (ret < 0) { + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && (err.code == tls::ErrorCode::WantRead || + (err.code == tls::ErrorCode::SyscallError && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err.code == tls::ErrorCode::WantRead) { +#endif + if (tls::pending(session_) > 0) { + return tls::read(session_, ptr, size, err); + } else if (wait_readable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = tls::read(session_, ptr, size, err); + if (ret >= 0) { return ret; } + } else { + break; + } + } + assert(ret < 0); + } else if (ret == 0 || err.code == tls::ErrorCode::PeerClosed) { + error_ = Error::ConnectionClosed; + } + return ret; + } else { + error_ = Error::Timeout; + return -1; + } +} + +ssize_t SSLSocketStream::write(const char *ptr, size_t size) { + if (wait_writable()) { + auto handle_size = + std::min(size, (std::numeric_limits::max)()); + + tls::TlsError err; + auto ret = tls::write(session_, ptr, handle_size, err); + if (ret < 0) { + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && (err.code == tls::ErrorCode::WantWrite || + (err.code == tls::ErrorCode::SyscallError && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err.code == tls::ErrorCode::WantWrite) { +#endif + if (wait_writable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = tls::write(session_, ptr, handle_size, err); + if (ret >= 0) { return ret; } + } else { + break; + } + } + assert(ret < 0); + } + return ret; + } + return -1; +} + +void SSLSocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + detail::get_remote_ip_and_port(sock_, ip, port); +} + +void SSLSocketStream::get_local_ip_and_port(std::string &ip, + int &port) const { + detail::get_local_ip_and_port(sock_, ip, port); +} + +socket_t SSLSocketStream::socket() const { return sock_; } + +time_t SSLSocketStream::duration() const { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_time_) + .count(); +} + +void SSLSocketStream::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +} // namespace detail +#endif // CPPHTTPLIB_SSL_ENABLED + +/* + * Group 4: Server implementation + */ + // HTTP server implementation Server::Server() - : new_task_queue( - [] { return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT); }) { + : new_task_queue([] { + return new ThreadPool(CPPHTTPLIB_THREAD_POOL_COUNT, + CPPHTTPLIB_THREAD_POOL_MAX_COUNT); + }) { #ifndef _WIN32 signal(SIGPIPE, SIG_IGN); #endif @@ -4617,60 +7020,65 @@ Server::make_matcher(const std::string &pattern) { } Server &Server::Get(const std::string &pattern, Handler handler) { - get_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return add_handler(get_handlers_, pattern, std::move(handler)); } Server &Server::Post(const std::string &pattern, Handler handler) { - post_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return add_handler(post_handlers_, pattern, std::move(handler)); } Server &Server::Post(const std::string &pattern, HandlerWithContentReader handler) { - post_handlers_for_content_reader_.emplace_back(make_matcher(pattern), - std::move(handler)); - return *this; + return add_handler(post_handlers_for_content_reader_, pattern, + std::move(handler)); } Server &Server::Put(const std::string &pattern, Handler handler) { - put_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return add_handler(put_handlers_, pattern, std::move(handler)); } Server &Server::Put(const std::string &pattern, HandlerWithContentReader handler) { - put_handlers_for_content_reader_.emplace_back(make_matcher(pattern), - std::move(handler)); - return *this; + return add_handler(put_handlers_for_content_reader_, pattern, + std::move(handler)); } Server &Server::Patch(const std::string &pattern, Handler handler) { - patch_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return add_handler(patch_handlers_, pattern, std::move(handler)); } Server &Server::Patch(const std::string &pattern, HandlerWithContentReader handler) { - patch_handlers_for_content_reader_.emplace_back(make_matcher(pattern), - std::move(handler)); - return *this; + return add_handler(patch_handlers_for_content_reader_, pattern, + std::move(handler)); } Server &Server::Delete(const std::string &pattern, Handler handler) { - delete_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); - return *this; + return add_handler(delete_handlers_, pattern, std::move(handler)); } Server &Server::Delete(const std::string &pattern, HandlerWithContentReader handler) { - delete_handlers_for_content_reader_.emplace_back(make_matcher(pattern), - std::move(handler)); - return *this; + return add_handler(delete_handlers_for_content_reader_, pattern, + std::move(handler)); } Server &Server::Options(const std::string &pattern, Handler handler) { - options_handlers_.emplace_back(make_matcher(pattern), std::move(handler)); + return add_handler(options_handlers_, pattern, std::move(handler)); +} + +Server &Server::WebSocket(const std::string &pattern, + WebSocketHandler handler) { + websocket_handlers_.push_back( + {make_matcher(pattern), std::move(handler), nullptr}); + return *this; +} + +Server &Server::WebSocket(const std::string &pattern, + WebSocketHandler handler, + SubProtocolSelector sub_protocol_selector) { + websocket_handlers_.push_back({make_matcher(pattern), std::move(handler), + std::move(sub_protocol_selector)}); return *this; } @@ -4685,7 +7093,18 @@ bool Server::set_mount_point(const std::string &mount_point, if (stat.is_dir()) { std::string mnt = !mount_point.empty() ? mount_point : "/"; if (!mnt.empty() && mnt[0] == '/') { - base_dirs_.push_back({mnt, dir, std::move(headers)}); + std::string resolved_base; + if (detail::canonicalize_path(dir.c_str(), resolved_base)) { +#if defined(_WIN32) + if (resolved_base.back() != '\\' && resolved_base.back() != '/') { + resolved_base += '\\'; + } +#else + if (resolved_base.back() != '/') { resolved_base += '/'; } +#endif + } + base_dirs_.push_back( + {std::move(mnt), dir, std::move(resolved_base), std::move(headers)}); return true; } } @@ -4822,6 +7241,15 @@ Server &Server::set_keep_alive_timeout(time_t sec) { return *this; } +template +Server &Server::set_keep_alive_timeout( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t /*usec*/) { + set_keep_alive_timeout(sec); + }); + return *this; +} + Server &Server::set_read_timeout(time_t sec, time_t usec) { read_timeout_sec_ = sec; read_timeout_usec_ = usec; @@ -4845,6 +7273,25 @@ Server &Server::set_payload_max_length(size_t length) { return *this; } +Server &Server::set_websocket_max_missed_pongs(int count) { + websocket_max_missed_pongs_ = count; + return *this; +} + +Server &Server::set_websocket_ping_interval(time_t sec) { + websocket_ping_interval_sec_ = sec; + return *this; +} + +template +Server &Server::set_websocket_ping_interval( + const std::chrono::duration &duration) { + detail::duration_to_sec_and_usec(duration, [&](time_t sec, time_t /*usec*/) { + set_websocket_ping_interval(sec); + }); + return *this; +} + bool Server::bind_to_port(const std::string &host, int port, int socket_flags) { auto ret = bind_internal(host, port, socket_flags); @@ -4872,7 +7319,7 @@ void Server::wait_until_ready() const { } } -void Server::stop() { +void Server::stop() noexcept { if (is_running_) { assert(svr_sock_ != INVALID_SOCKET); std::atomic sock(svr_sock_.exchange(INVALID_SOCKET)); @@ -4970,7 +7417,8 @@ bool Server::write_response_core(Stream &strm, bool close_connection, if (need_apply_ranges) { apply_ranges(req, res, content_type, boundary); } // Prepare additional headers - if (close_connection || req.get_header_value("Connection") == "close") { + if (close_connection || req.get_header_value("Connection") == "close" || + 400 <= res.status) { // Don't leave connections open after errors res.set_header("Connection", "close"); } else { std::string s = "timeout="; @@ -4997,35 +7445,33 @@ bool Server::write_response_core(Stream &strm, bool close_connection, if (post_routing_handler_) { post_routing_handler_(req, res); } // Response line and headers - { - detail::BufferStream bstrm; - if (!detail::write_response_line(bstrm, res.status)) { return false; } - if (!header_writer_(bstrm, res.headers)) { return false; } + detail::BufferStream bstrm; + if (!detail::write_response_line(bstrm, res.status)) { return false; } + if (header_writer_(bstrm, res.headers) <= 0) { return false; } - // Flush buffer - auto &data = bstrm.get_buffer(); - detail::write_data(strm, data.data(), data.size()); + // Combine small body with headers to reduce write syscalls + if (req.method != "HEAD" && !res.body.empty() && !res.content_provider_) { + bstrm.write(res.body.data(), res.body.size()); } - // Body + // Log before writing to avoid race condition with client-side code that + // accesses logger-captured data immediately after receiving the response. + output_log(req, res); + + // Flush buffer + auto &data = bstrm.get_buffer(); + if (!detail::write_data(strm, data.data(), data.size())) { return false; } + + // Streaming body auto ret = true; - if (req.method != "HEAD") { - if (!res.body.empty()) { - if (!detail::write_data(strm, res.body.data(), res.body.size())) { - ret = false; - } - } else if (res.content_provider_) { - if (write_content_with_provider(strm, req, res, boundary, content_type)) { - res.content_provider_success_ = true; - } else { - ret = false; - } + if (req.method != "HEAD" && res.content_provider_) { + if (write_content_with_provider(strm, req, res, boundary, content_type)) { + res.content_provider_success_ = true; + } else { + ret = false; } } - // Log - output_log(req, res); - return ret; } @@ -5057,23 +7503,10 @@ Server::write_content_with_provider(Stream &strm, const Request &req, if (res.is_chunked_content_provider_) { auto type = detail::encoding_type(req, res); - std::unique_ptr compressor; - if (type == detail::EncodingType::Gzip) { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - compressor = detail::make_unique(); -#endif - } else if (type == detail::EncodingType::Brotli) { -#ifdef CPPHTTPLIB_BROTLI_SUPPORT - compressor = detail::make_unique(); -#endif - } else if (type == detail::EncodingType::Zstd) { -#ifdef CPPHTTPLIB_ZSTD_SUPPORT - compressor = detail::make_unique(); -#endif - } else { + auto compressor = detail::make_compressor(type); + if (!compressor) { compressor = detail::make_unique(); } - assert(compressor != nullptr); return detail::write_content_chunked(strm, res.content_provider_, is_shutting_down, *compressor); @@ -5093,7 +7526,26 @@ bool Server::read_content(Stream &strm, Request &req, Response &res) { strm, req, res, // Regular [&](const char *buf, size_t n) { - if (req.body.size() + n > req.body.max_size()) { return false; } + // Prevent arithmetic overflow when checking sizes. + // Avoid computing (req.body.size() + n) directly because + // adding two unsigned `size_t` values can wrap around and + // produce a small result instead of indicating overflow. + // Instead, check using subtraction: ensure `n` does not + // exceed the remaining capacity `max_size() - size()`. + if (req.body.size() >= req.body.max_size() || + n > req.body.max_size() - req.body.size()) { + return false; + } + + // Limit decompressed body size to payload_max_length_ to protect + // against "zip bomb" attacks where a small compressed payload + // decompresses to a massive size. + if (payload_max_length_ > 0 && + (req.body.size() >= payload_max_length_ || + n > payload_max_length_ - req.body.size())) { + return false; + } + req.body.append(buf, n); return true; }, @@ -5127,7 +7579,8 @@ bool Server::read_content(Stream &strm, Request &req, Response &res) { return true; })) { const auto &content_type = req.get_header_value("Content-Type"); - if (!content_type.find("application/x-www-form-urlencoded")) { + if (detail::extract_media_type(content_type) == + "application/x-www-form-urlencoded") { if (req.body.size() > CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH) { res.status = StatusCode::PayloadTooLarge_413; // NOTE: should be 414? output_error_log(Error::ExceedMaxPayloadSize, &req); @@ -5173,15 +7626,50 @@ bool Server::read_content_core( size_t /*len*/) { return receiver(buf, n); }; } - if (req.method == "DELETE" && !req.has_header("Content-Length")) { + // RFC 9112 ยง6: no Transfer-Encoding and no Content-Length means no body. + // For non-SSL builds we still scan non-persistent connections for stray + // body bytes so the payload limit is enforced (413). On keep-alive, + // pending bytes may be the next request (issue #2450), so skip. +#if !defined(CPPHTTPLIB_SSL_ENABLED) + if (!req.has_header("Content-Length") && + !detail::is_chunked_transfer_encoding(req.headers)) { + if (!detail::is_connection_persistent(req) && payload_max_length_ > 0 && + payload_max_length_ < (std::numeric_limits::max)()) { + auto has_data = strm.is_readable(); + if (!has_data) { + auto s = strm.socket(); + if (s != INVALID_SOCKET) { + has_data = detail::select_read(s, 0, 0) > 0; + } + } + if (has_data) { + auto result = + detail::read_content_without_length(strm, payload_max_length_, out); + if (result == detail::ReadContentResult::PayloadTooLarge) { + res.status = StatusCode::PayloadTooLarge_413; + return false; + } else if (result != detail::ReadContentResult::Success) { + return false; + } + return true; + } + } return true; } +#else + if (!req.has_header("Content-Length") && + !detail::is_chunked_transfer_encoding(req.headers)) { + return true; + } +#endif if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, out, true)) { return false; } + req.body_consumed_ = true; + if (req.is_multipart_form_data()) { if (!multipart_form_data_parser.is_valid()) { res.status = StatusCode::BadRequest_400; @@ -5193,7 +7681,7 @@ bool Server::read_content_core( return true; } -bool Server::handle_file_request(const Request &req, Response &res) { +bool Server::handle_file_request(Request &req, Response &res) { for (const auto &entry : base_dirs_) { // Prefix match if (!req.path.compare(0, entry.mount_point.size(), entry.mount_point)) { @@ -5202,6 +7690,18 @@ bool Server::handle_file_request(const Request &req, Response &res) { auto path = entry.base_dir + sub_path; if (path.back() == '/') { path += "index.html"; } + // Defense-in-depth: is_valid_path blocks ".." traversal in the URL, + // but symlinks/junctions can still escape the base directory. + if (!entry.resolved_base_dir.empty()) { + std::string resolved_path; + if (detail::canonicalize_path(path.c_str(), resolved_path) && + !detail::is_path_within_base(resolved_path, + entry.resolved_base_dir)) { + res.status = StatusCode::Forbidden_403; + return true; + } + } + detail::FileStat stat(path); if (stat.is_dir()) { @@ -5214,6 +7714,20 @@ bool Server::handle_file_request(const Request &req, Response &res) { res.set_header(kv.first, kv.second); } + auto etag = detail::compute_etag(stat); + if (!etag.empty()) { res.set_header("ETag", etag); } + + auto mtime = stat.mtime(); + + auto last_modified = detail::file_mtime_to_http_date(mtime); + if (!last_modified.empty()) { + res.set_header("Last-Modified", last_modified); + } + + if (check_if_not_modified(req, res, etag, mtime)) { return true; } + + check_if_range(req, etag, mtime); + auto mm = std::make_shared(path.c_str()); if (!mm->is_open()) { output_error_log(Error::OpenFile, &req); @@ -5243,6 +7757,81 @@ bool Server::handle_file_request(const Request &req, Response &res) { return false; } +bool Server::check_if_not_modified(const Request &req, Response &res, + const std::string &etag, + time_t mtime) const { + // Handle conditional GET: + // 1. If-None-Match takes precedence (RFC 9110 Section 13.1.2) + // 2. If-Modified-Since is checked only when If-None-Match is absent + if (req.has_header("If-None-Match")) { + if (!etag.empty()) { + auto val = req.get_header_value("If-None-Match"); + + // NOTE: We use exact string matching here. This works correctly + // because our server always generates weak ETags (W/"..."), and + // clients typically send back the same ETag they received. + // RFC 9110 Section 8.8.3.2 allows weak comparison for + // If-None-Match, where W/"x" and "x" would match, but this + // simplified implementation requires exact matches. + auto ret = detail::split_find(val.data(), val.data() + val.size(), ',', + [&](const char *b, const char *e) { + auto seg_len = static_cast(e - b); + return (seg_len == 1 && *b == '*') || + (seg_len == etag.size() && + std::equal(b, e, etag.begin())); + }); + + if (ret) { + res.status = StatusCode::NotModified_304; + return true; + } + } + } else if (req.has_header("If-Modified-Since")) { + auto val = req.get_header_value("If-Modified-Since"); + auto t = detail::parse_http_date(val); + + if (t != static_cast(-1) && mtime <= t) { + res.status = StatusCode::NotModified_304; + return true; + } + } + return false; +} + +bool Server::check_if_range(Request &req, const std::string &etag, + time_t mtime) const { + // Handle If-Range for partial content requests (RFC 9110 + // Section 13.1.5). If-Range is only evaluated when Range header is + // present. If the validator matches, serve partial content; otherwise + // serve full content. + if (!req.ranges.empty() && req.has_header("If-Range")) { + auto val = req.get_header_value("If-Range"); + + auto is_valid_range = [&]() { + if (detail::is_strong_etag(val)) { + // RFC 9110 Section 13.1.5: If-Range requires strong ETag + // comparison. + return (!etag.empty() && val == etag); + } else if (detail::is_weak_etag(val)) { + // Weak ETags are not valid for If-Range (RFC 9110 Section 13.1.5) + return false; + } else { + // HTTP-date comparison + auto t = detail::parse_http_date(val); + return (t != static_cast(-1) && mtime <= t); + } + }; + + if (!is_valid_range()) { + // Validator doesn't match: ignore Range and serve full content + req.ranges.clear(); + return false; + } + } + + return true; +} + socket_t Server::create_server_socket(const std::string &host, int port, int socket_flags, @@ -5351,6 +7940,8 @@ bool Server::listen_internal() { detail::set_socket_opt_time(sock, SOL_SOCKET, SO_SNDTIMEO, write_timeout_sec_, write_timeout_usec_); + if (tcp_nodelay_) { set_socket_opt(sock, IPPROTO_TCP, TCP_NODELAY, 1); } + if (!task_queue->enqueue( [this, sock]() { process_and_close_socket(sock); })) { output_error_log(Error::ResourceExhaustion, nullptr); @@ -5381,45 +7972,63 @@ bool Server::routing(Request &req, Response &res, Stream &strm) { if (detail::expect_content(req)) { // Content reader handler { + // Track whether the ContentReader was aborted due to the decompressed + // payload exceeding `payload_max_length_`. + // The user handler runs after the lambda returns, so we must restore the + // 413 status if the handler overwrites it. + bool content_reader_payload_too_large = false; + ContentReader reader( [&](ContentReceiver receiver) { auto result = read_content_with_content_receiver( strm, req, res, std::move(receiver), nullptr, nullptr); - if (!result) { output_error_log(Error::Read, &req); } + if (!result) { + output_error_log(Error::Read, &req); + if (res.status == StatusCode::PayloadTooLarge_413) { + content_reader_payload_too_large = true; + } + } return result; }, [&](FormDataHeader header, ContentReceiver receiver) { auto result = read_content_with_content_receiver( strm, req, res, nullptr, std::move(header), std::move(receiver)); - if (!result) { output_error_log(Error::Read, &req); } + if (!result) { + output_error_log(Error::Read, &req); + if (res.status == StatusCode::PayloadTooLarge_413) { + content_reader_payload_too_large = true; + } + } return result; }); + bool dispatched = false; if (req.method == "POST") { - if (dispatch_request_for_content_reader( - req, res, std::move(reader), - post_handlers_for_content_reader_)) { - return true; - } + dispatched = dispatch_request_for_content_reader( + req, res, std::move(reader), post_handlers_for_content_reader_); } else if (req.method == "PUT") { - if (dispatch_request_for_content_reader( - req, res, std::move(reader), - put_handlers_for_content_reader_)) { - return true; - } + dispatched = dispatch_request_for_content_reader( + req, res, std::move(reader), put_handlers_for_content_reader_); } else if (req.method == "PATCH") { - if (dispatch_request_for_content_reader( - req, res, std::move(reader), - patch_handlers_for_content_reader_)) { - return true; - } + dispatched = dispatch_request_for_content_reader( + req, res, std::move(reader), patch_handlers_for_content_reader_); } else if (req.method == "DELETE") { - if (dispatch_request_for_content_reader( - req, res, std::move(reader), - delete_handlers_for_content_reader_)) { - return true; + dispatched = dispatch_request_for_content_reader( + req, res, std::move(reader), delete_handlers_for_content_reader_); + } + + if (dispatched) { + if (content_reader_payload_too_large) { + // Enforce the limit: override any status the handler may have set + // and return false so the error path sends a plain 413 response. + res.status = StatusCode::PayloadTooLarge_413; + res.body.clear(); + res.content_length_ = 0; + res.content_provider_ = nullptr; + return false; } + return true; } } @@ -5508,12 +8117,9 @@ void Server::apply_ranges(const Request &req, Response &res, if (res.content_provider_) { if (res.is_chunked_content_provider_) { res.set_header("Transfer-Encoding", "chunked"); - if (type == detail::EncodingType::Gzip) { - res.set_header("Content-Encoding", "gzip"); - } else if (type == detail::EncodingType::Brotli) { - res.set_header("Content-Encoding", "br"); - } else if (type == detail::EncodingType::Zstd) { - res.set_header("Content-Encoding", "zstd"); + if (type != detail::EncodingType::None) { + res.set_header("Content-Encoding", detail::encoding_name(type)); + res.set_header("Vary", "Accept-Encoding"); } } } @@ -5543,27 +8149,7 @@ void Server::apply_ranges(const Request &req, Response &res, if (type != detail::EncodingType::None) { output_pre_compression_log(req, res); - std::unique_ptr compressor; - std::string content_encoding; - - if (type == detail::EncodingType::Gzip) { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - compressor = detail::make_unique(); - content_encoding = "gzip"; -#endif - } else if (type == detail::EncodingType::Brotli) { -#ifdef CPPHTTPLIB_BROTLI_SUPPORT - compressor = detail::make_unique(); - content_encoding = "br"; -#endif - } else if (type == detail::EncodingType::Zstd) { -#ifdef CPPHTTPLIB_ZSTD_SUPPORT - compressor = detail::make_unique(); - content_encoding = "zstd"; -#endif - } - - if (compressor) { + if (auto compressor = detail::make_compressor(type)) { std::string compressed; if (compressor->compress(res.body.data(), res.body.size(), true, [&](const char *data, size_t data_len) { @@ -5571,7 +8157,8 @@ void Server::apply_ranges(const Request &req, Response &res, return true; })) { res.body.swap(compressed); - res.set_header("Content-Encoding", content_encoding); + res.set_header("Content-Encoding", detail::encoding_name(type)); + res.set_header("Vary", "Accept-Encoding"); } } } @@ -5612,6 +8199,11 @@ get_client_ip(const std::string &x_forwarded_for, ip_list.emplace_back(std::string(b + r.first, b + r.second)); }); + // A malformed X-Forwarded-For (empty, comma-only, whitespace-only) yields + // no segments. Signal "no client IP derived" with an empty string so the + // caller can fall back to the connection-level remote address. + if (ip_list.empty()) { return std::string(); } + for (size_t i = 0; i < ip_list.size(); ++i) { auto ip = ip_list[i]; @@ -5639,7 +8231,8 @@ Server::process_request(Stream &strm, const std::string &remote_addr, int remote_port, const std::string &local_addr, int local_port, bool close_connection, bool &connection_closed, - const std::function &setup_request) { + const std::function &setup_request, + bool *websocket_upgraded) { std::array buf{}; detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); @@ -5649,22 +8242,15 @@ Server::process_request(Stream &strm, const std::string &remote_addr, Request req; req.start_time_ = std::chrono::steady_clock::now(); + req.remote_addr = remote_addr; + req.remote_port = remote_port; + req.local_addr = local_addr; + req.local_port = local_port; Response res; res.version = "HTTP/1.1"; res.headers = default_headers_; -#ifdef __APPLE__ - // Socket file descriptor exceeded FD_SETSIZE... - if (strm.socket() >= FD_SETSIZE) { - Headers dummy; - detail::read_headers(strm, dummy); - res.status = StatusCode::InternalServerError_500; - output_error_log(Error::ExceedMaxSocketDescriptorCount, &req); - return write_response(strm, close_connection, req, res); - } -#endif - // Request line and headers if (!parse_request_line(line_reader.ptr(), req)) { res.status = StatusCode::BadRequest_400; @@ -5679,10 +8265,19 @@ Server::process_request(Stream &strm, const std::string &remote_addr, return write_response(strm, close_connection, req, res); } + // RFC 9112 ยง6.3: Reject requests with both a non-zero Content-Length and + // any Transfer-Encoding to prevent request smuggling. Content-Length: 0 is + // tolerated for compatibility with existing clients. + if (req.get_header_value_u64("Content-Length") > 0 && + req.has_header("Transfer-Encoding")) { + connection_closed = true; + res.status = StatusCode::BadRequest_400; + return write_response(strm, close_connection, req, res); + } + // Check if the request URI doesn't exceed the limit if (req.target.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { - Headers dummy; - detail::read_headers(strm, dummy); + connection_closed = true; res.status = StatusCode::UriTooLong_414; output_error_log(Error::ExceedUriMaxLength, &req); return write_response(strm, close_connection, req, res); @@ -5699,7 +8294,8 @@ Server::process_request(Stream &strm, const std::string &remote_addr, if (!trusted_proxies_.empty() && req.has_header("X-Forwarded-For")) { auto x_forwarded_for = req.get_header_value("X-Forwarded-For"); - req.remote_addr = get_client_ip(x_forwarded_for, trusted_proxies_); + auto derived = get_client_ip(x_forwarded_for, trusted_proxies_); + req.remote_addr = derived.empty() ? remote_addr : derived; } else { req.remote_addr = remote_addr; } @@ -5711,6 +8307,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr, if (req.has_header("Accept")) { const auto &accept_header = req.get_header_value("Accept"); if (!detail::parse_accept_header(accept_header, req.accept_content_types)) { + connection_closed = true; res.status = StatusCode::BadRequest_400; output_error_log(Error::HTTPParsing, &req); return write_response(strm, close_connection, req, res); @@ -5720,6 +8317,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr, if (req.has_header("Range")) { const auto &range_header_value = req.get_header_value("Range"); if (!detail::parse_range_header(range_header_value, req.ranges)) { + connection_closed = true; res.status = StatusCode::RangeNotSatisfiable_416; output_error_log(Error::InvalidRangeHeader, &req); return write_response(strm, close_connection, req, res); @@ -5751,6 +8349,78 @@ Server::process_request(Stream &strm, const std::string &remote_addr, return !detail::is_socket_alive(sock); }; + // WebSocket upgrade + // Check pre_routing_handler_ before upgrading so that authentication + // and other middleware can reject the request with an HTTP response + // (e.g., 401) before the protocol switches. + if (detail::is_websocket_upgrade(req)) { + if (pre_routing_handler_ && + pre_routing_handler_(req, res) == HandlerResponse::Handled) { + if (res.status == -1) { res.status = StatusCode::OK_200; } + return write_response(strm, close_connection, req, res); + } + // Find matching WebSocket handler + for (const auto &entry : websocket_handlers_) { + if (entry.matcher->match(req)) { + // Compute accept key + auto client_key = req.get_header_value("Sec-WebSocket-Key"); + auto accept_key = detail::websocket_accept_key(client_key); + + // Negotiate subprotocol + std::string selected_subprotocol; + if (entry.sub_protocol_selector) { + auto protocol_header = req.get_header_value("Sec-WebSocket-Protocol"); + if (!protocol_header.empty()) { + std::vector protocols; + std::istringstream iss(protocol_header); + std::string token; + while (std::getline(iss, token, ',')) { + // Trim whitespace + auto start = token.find_first_not_of(' '); + auto end = token.find_last_not_of(' '); + if (start != std::string::npos) { + protocols.push_back(token.substr(start, end - start + 1)); + } + } + selected_subprotocol = entry.sub_protocol_selector(protocols); + } + } + + // Send 101 Switching Protocols + std::string handshake_response = "HTTP/1.1 101 Switching Protocols\r\n" + "Upgrade: websocket\r\n" + "Connection: Upgrade\r\n" + "Sec-WebSocket-Accept: " + + accept_key + "\r\n"; + if (!selected_subprotocol.empty()) { + if (!detail::fields::is_field_value(selected_subprotocol)) { + return false; + } + handshake_response += + "Sec-WebSocket-Protocol: " + selected_subprotocol + "\r\n"; + } + handshake_response += "\r\n"; + if (strm.write(handshake_response.data(), handshake_response.size()) < + 0) { + return false; + } + + connection_closed = true; + if (websocket_upgraded) { *websocket_upgraded = true; } + + { + // Use WebSocket-specific read timeout instead of HTTP timeout + strm.set_read_timeout(CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND, 0); + ws::WebSocket ws(strm, req, true, websocket_ping_interval_sec_, + websocket_max_missed_pongs_); + entry.handler(req, ws); + } + return true; + } + } + // No matching handler - fall through to 404 + } + // Routing auto routed = false; #ifdef CPPHTTPLIB_NO_EXCEPTIONS @@ -5758,23 +8428,13 @@ Server::process_request(Stream &strm, const std::string &remote_addr, #else try { routed = routing(req, res, strm); - } catch (std::exception &e) { + } catch (std::exception &) { if (exception_handler_) { auto ep = std::current_exception(); exception_handler_(req, res, ep); routed = true; } else { res.status = StatusCode::InternalServerError_500; - std::string val; - auto s = e.what(); - for (size_t i = 0; s[i]; i++) { - switch (s[i]) { - case '\r': val += "\\r"; break; - case '\n': val += "\\n"; break; - default: val += s[i]; break; - } - } - res.set_header("EXCEPTION_WHAT", val); } } catch (...) { if (exception_handler_) { @@ -5783,10 +8443,10 @@ Server::process_request(Stream &strm, const std::string &remote_addr, routed = true; } else { res.status = StatusCode::InternalServerError_500; - res.set_header("EXCEPTION_WHAT", "UNKNOWN"); } } #endif + auto ret = false; if (routed) { if (res.status == -1) { res.status = req.ranges.empty() ? StatusCode::OK_200 @@ -5794,6 +8454,7 @@ Server::process_request(Stream &strm, const std::string &remote_addr, } // Serve file content by using a content provider + auto file_open_error = false; if (!res.file_content_path_.empty()) { const auto &path = res.file_content_path_; auto mm = std::make_shared(path.c_str()); @@ -5803,37 +8464,52 @@ Server::process_request(Stream &strm, const std::string &remote_addr, res.content_provider_ = nullptr; res.status = StatusCode::NotFound_404; output_error_log(Error::OpenFile, &req); - return write_response(strm, close_connection, req, res); - } + file_open_error = true; + } else { + auto content_type = res.file_content_content_type_; + if (content_type.empty()) { + content_type = detail::find_content_type( + path, file_extension_and_mimetype_map_, default_file_mimetype_); + } - auto content_type = res.file_content_content_type_; - if (content_type.empty()) { - content_type = detail::find_content_type( - path, file_extension_and_mimetype_map_, default_file_mimetype_); + res.set_content_provider( + mm->size(), content_type, + [mm](size_t offset, size_t length, DataSink &sink) -> bool { + sink.write(mm->data() + offset, length); + return true; + }); } - - res.set_content_provider( - mm->size(), content_type, - [mm](size_t offset, size_t length, DataSink &sink) -> bool { - sink.write(mm->data() + offset, length); - return true; - }); } - if (detail::range_error(req, res)) { + if (file_open_error) { + ret = write_response(strm, close_connection, req, res); + } else if (detail::range_error(req, res)) { res.body.clear(); res.content_length_ = 0; res.content_provider_ = nullptr; res.status = StatusCode::RangeNotSatisfiable_416; - return write_response(strm, close_connection, req, res); + ret = write_response(strm, close_connection, req, res); + } else { + ret = write_response_with_content(strm, close_connection, req, res); } - - return write_response_with_content(strm, close_connection, req, res); } else { if (res.status == -1) { res.status = StatusCode::NotFound_404; } - - return write_response(strm, close_connection, req, res); + ret = write_response(strm, close_connection, req, res); } + + // Drain any unconsumed framed body to prevent request smuggling on + // keep-alive. Without framing there is no body to drain โ€” reading would + // consume the next request (issue #2450). + if (!req.body_consumed_ && detail::has_framed_body(req)) { + int dummy_status; + if (!detail::read_content( + strm, req, payload_max_length_, dummy_status, nullptr, + [](const char *, size_t, size_t, size_t) { return true; }, false)) { + connection_closed = true; + } + } + + return ret; } bool Server::is_valid() const { return true; } @@ -5847,6 +8523,7 @@ bool Server::process_and_close_socket(socket_t sock) { int local_port = 0; detail::get_local_ip_and_port(sock, local_addr, local_port); + bool websocket_upgraded = false; auto ret = detail::process_server_socket( svr_sock_, sock, keep_alive_max_count_, keep_alive_timeout_sec_, read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, @@ -5854,7 +8531,7 @@ bool Server::process_and_close_socket(socket_t sock) { [&](Stream &strm, bool close_connection, bool &connection_closed) { return process_request(strm, remote_addr, remote_port, local_addr, local_port, close_connection, connection_closed, - nullptr); + nullptr, &websocket_upgraded); }); detail::shutdown_socket(sock); @@ -5885,6 +8562,9 @@ void Server::output_error_log(const Error &err, } } +/* + * Group 5: ClientImpl and Client (Universal) implementation + */ // HTTP client implementation ClientImpl::ClientImpl(const std::string &host) : ClientImpl(host, 80, std::string(), std::string()) {} @@ -5896,7 +8576,6 @@ ClientImpl::ClientImpl(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) : host_(detail::escape_abstract_namespace_unix_domain(host)), port_(port), - host_and_port_(detail::make_host_and_port_string(host_, port, is_ssl())), client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} ClientImpl::~ClientImpl() { @@ -5929,10 +8608,6 @@ void ClientImpl::copy_settings(const ClientImpl &rhs) { basic_auth_username_ = rhs.basic_auth_username_; basic_auth_password_ = rhs.basic_auth_password_; bearer_token_auth_token_ = rhs.bearer_token_auth_token_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - digest_auth_username_ = rhs.digest_auth_username_; - digest_auth_password_ = rhs.digest_auth_password_; -#endif keep_alive_ = rhs.keep_alive_; follow_location_ = rhs.follow_location_; path_encode_ = rhs.path_encode_; @@ -5942,32 +8617,49 @@ void ClientImpl::copy_settings(const ClientImpl &rhs) { socket_options_ = rhs.socket_options_; compress_ = rhs.compress_; decompress_ = rhs.decompress_; + payload_max_length_ = rhs.payload_max_length_; + has_payload_max_length_ = rhs.has_payload_max_length_; interface_ = rhs.interface_; proxy_host_ = rhs.proxy_host_; proxy_port_ = rhs.proxy_port_; proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; - proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; -#endif -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - ca_cert_file_path_ = rhs.ca_cert_file_path_; - ca_cert_dir_path_ = rhs.ca_cert_dir_path_; - ca_cert_store_ = rhs.ca_cert_store_; -#endif -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - server_certificate_verification_ = rhs.server_certificate_verification_; - server_hostname_verification_ = rhs.server_hostname_verification_; - server_certificate_verifier_ = rhs.server_certificate_verifier_; -#endif + no_proxy_entries_ = rhs.no_proxy_entries_; logger_ = rhs.logger_; error_logger_ = rhs.error_logger_; + +#ifdef CPPHTTPLIB_SSL_ENABLED + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; + ca_cert_file_path_ = rhs.ca_cert_file_path_; + ca_cert_dir_path_ = rhs.ca_cert_dir_path_; + server_certificate_verification_ = rhs.server_certificate_verification_; + server_hostname_verification_ = rhs.server_hostname_verification_; +#endif +} + +bool +ClientImpl::is_proxy_enabled_for_host(const std::string &host) const { + if (proxy_host_.empty() || proxy_port_ == -1) { return false; } + if (no_proxy_entries_.empty()) { return true; } + // host_ is const so its normalized form is invariant; cache it. The + // cross-host path (setup_redirect_client passing next_host) re-normalizes. + if (host == host_) { + if (!host_normalized_valid_) { + host_normalized_ = detail::normalize_target(host_); + host_normalized_valid_ = true; + } + return !detail::host_matches_no_proxy(host_normalized_, no_proxy_entries_); + } + auto target = detail::normalize_target(host); + return !detail::host_matches_no_proxy(target, no_proxy_entries_); } socket_t ClientImpl::create_client_socket(Error &error) const { - if (!proxy_host_.empty() && proxy_port_ != -1) { + if (is_proxy_enabled_for_host(host_)) { return detail::create_client_socket( proxy_host_, std::string(), proxy_port_, address_family_, tcp_nodelay_, ipv6_v6only_, socket_options_, connection_timeout_sec_, @@ -5995,6 +8687,17 @@ bool ClientImpl::create_and_connect_socket(Socket &socket, return true; } +bool ClientImpl::ensure_socket_connection(Socket &socket, Error &error) { + return create_and_connect_socket(socket, error); +} + +bool ClientImpl::setup_proxy_connection( + Socket & /*socket*/, + std::chrono::time_point /*start_time*/, + Response & /*res*/, bool & /*success*/, Error & /*error*/) { + return true; +} + void ClientImpl::shutdown_ssl(Socket & /*socket*/, bool /*shutdown_gracefully*/) { // If there are any requests in flight from threads other than us, then it's @@ -6019,16 +8722,24 @@ void ClientImpl::close_socket(Socket &socket) { socket_requests_are_from_thread_ == std::this_thread::get_id()); // It is also a bug if this happens while SSL is still active -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED assert(socket.ssl == nullptr); #endif + if (socket.sock == INVALID_SOCKET) { return; } detail::close_socket(socket.sock); socket.sock = INVALID_SOCKET; } +void ClientImpl::disconnect(bool gracefully) { + shutdown_ssl(socket_, gracefully); + shutdown_socket(socket_); + close_socket(socket_); +} + bool ClientImpl::read_response_line(Stream &strm, const Request &req, - Response &res) const { + Response &res, + bool skip_100_continue) const { std::array buf{}; detail::stream_line_reader line_reader(strm, buf.data(), buf.size()); @@ -6049,8 +8760,8 @@ bool ClientImpl::read_response_line(Stream &strm, const Request &req, res.status = std::stoi(std::string(m[2])); res.reason = std::string(m[3]); - // Ignore '100 Continue' - while (res.status == StatusCode::Continue_100) { + // Ignore '100 Continue' (only when not using Expect: 100-continue explicitly) + while (skip_100_continue && res.status == StatusCode::Continue_100) { if (!line_reader.getline()) { return false; } // CRLF if (!line_reader.getline()) { return false; } // next response line @@ -6069,6 +8780,8 @@ bool ClientImpl::send(Request &req, Response &res, Error &error) { if (error == Error::SSLPeerCouldBeClosed_) { assert(!ret); ret = send_(req, res, error); + // If still failing with SSLPeerCouldBeClosed_, convert to Read error + if (error == Error::SSLPeerCouldBeClosed_) { error = Error::Read; } } return ret; } @@ -6086,51 +8799,34 @@ bool ClientImpl::send_(Request &req, Response &res, Error &error) { if (socket_.is_open()) { is_alive = detail::is_socket_alive(socket_.sock); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if (is_alive && is_ssl()) { - if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + if (tls::is_peer_closed(socket_.ssl, socket_.sock)) { is_alive = false; } } #endif if (!is_alive) { - // Attempt to avoid sigpipe by shutting down non-gracefully if it - // seems like the other side has already closed the connection Also, - // there cannot be any requests in flight from other threads since we - // locked request_mutex_, so safe to close everything immediately - const bool shutdown_gracefully = false; - shutdown_ssl(socket_, shutdown_gracefully); - shutdown_socket(socket_); - close_socket(socket_); + // Peer seems gone โ€” non-graceful shutdown to avoid SIGPIPE. + disconnect(/*gracefully=*/false); } } if (!is_alive) { - if (!create_and_connect_socket(socket_, error)) { + if (!ensure_socket_connection(socket_, error)) { output_error_log(error, &req); return false; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - // TODO: refactoring - if (is_ssl()) { - auto &scli = static_cast(*this); - if (!proxy_host_.empty() && proxy_port_ != -1) { - auto success = false; - if (!scli.connect_with_proxy(socket_, req.start_time_, res, success, - error)) { - if (!success) { output_error_log(error, &req); } - return success; - } - } - - if (!scli.initialize_ssl(socket_, error)) { - output_error_log(error, &req); - return false; + { + auto success = true; + if (!setup_proxy_connection(socket_, req.start_time_, res, success, + error)) { + if (!success) { output_error_log(error, &req); } + return success; } } -#endif } // Mark the current socket as being in use so that it cannot be closed by @@ -6163,9 +8859,7 @@ bool ClientImpl::send_(Request &req, Response &res, Error &error) { if (socket_should_be_closed_when_request_is_done_ || close_connection || !ret) { - shutdown_ssl(socket_, true); - shutdown_socket(socket_); - close_socket(socket_); + disconnect(/*gracefully=*/true); } }); @@ -6192,14 +8886,371 @@ Result ClientImpl::send_(Request &&req) { auto res = detail::make_unique(); auto error = Error::Success; auto ret = send(req, *res, error); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers), - last_ssl_error_, last_openssl_error_}; + last_ssl_error_, last_backend_error_}; #else return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers)}; #endif } +void ClientImpl::prepare_default_headers(Request &r, bool for_stream, + const std::string &ct) { + (void)for_stream; + for (const auto &header : default_headers_) { + if (!r.has_header(header.first)) { r.headers.insert(header); } + } + + if (!r.has_header("Host")) { + if (address_family_ == AF_UNIX) { + r.headers.emplace("Host", "localhost"); + } else { + r.headers.emplace( + "Host", detail::make_host_and_port_string(host_, port_, is_ssl())); + } + } + + if (!r.has_header("Accept")) { r.headers.emplace("Accept", "*/*"); } + + if (!r.content_receiver) { + if (!r.has_header("Accept-Encoding")) { + std::string accept_encoding; +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + accept_encoding = "br"; +#endif +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "gzip, deflate"; +#endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "zstd"; +#endif + r.set_header("Accept-Encoding", accept_encoding); + } + +#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT + if (!r.has_header("User-Agent")) { + auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; + r.set_header("User-Agent", agent); + } +#endif + } + + if (!r.body.empty()) { + if (!ct.empty() && !r.has_header("Content-Type")) { + r.headers.emplace("Content-Type", ct); + } + if (!r.has_header("Content-Length")) { + r.headers.emplace("Content-Length", std::to_string(r.body.size())); + } + } +} + +ClientImpl::StreamHandle +ClientImpl::open_stream(const std::string &method, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, + const std::string &content_type) { + StreamHandle handle; + handle.response = detail::make_unique(); + handle.error = Error::Success; + + auto query_path = params.empty() ? path : append_query_params(path, params); + handle.connection_ = detail::make_unique(); + + { + std::lock_guard guard(socket_mutex_); + + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::is_socket_alive(socket_.sock); +#ifdef CPPHTTPLIB_SSL_ENABLED + if (is_alive && is_ssl()) { + if (tls::is_peer_closed(socket_.ssl, socket_.sock)) { + is_alive = false; + } + } +#endif + if (!is_alive) { disconnect(/*gracefully=*/false); } + } + + if (!is_alive) { + if (!ensure_socket_connection(socket_, handle.error)) { + handle.response.reset(); + return handle; + } + + { + auto success = true; + auto start_time = std::chrono::steady_clock::now(); + if (!setup_proxy_connection(socket_, start_time, *handle.response, + success, handle.error)) { + if (!success) { handle.response.reset(); } + return handle; + } + } + } + + transfer_socket_ownership_to_handle(handle); + } + +#ifdef CPPHTTPLIB_SSL_ENABLED + if (is_ssl() && handle.connection_->session) { + handle.socket_stream_ = detail::make_unique( + handle.connection_->sock, handle.connection_->session, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_); + } else { + handle.socket_stream_ = detail::make_unique( + handle.connection_->sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_); + } +#else + handle.socket_stream_ = detail::make_unique( + handle.connection_->sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_); +#endif + handle.stream_ = handle.socket_stream_.get(); + + Request req; + req.method = method; + req.path = query_path; + req.headers = headers; + req.body = body; + + prepare_default_headers(req, true, content_type); + + auto &strm = *handle.stream_; + if (detail::write_request_line(strm, req.method, req.path) < 0) { + handle.error = Error::Write; + handle.response.reset(); + return handle; + } + + if (!detail::check_and_write_headers(strm, req.headers, header_writer_, + handle.error)) { + handle.response.reset(); + return handle; + } + + if (!body.empty()) { + if (strm.write(body.data(), body.size()) < 0) { + handle.error = Error::Write; + handle.response.reset(); + return handle; + } + } + + if (!read_response_line(strm, req, *handle.response) || + !detail::read_headers(strm, handle.response->headers)) { + handle.error = Error::Read; + handle.response.reset(); + return handle; + } + + handle.body_reader_.stream = handle.stream_; + handle.body_reader_.payload_max_length = payload_max_length_; + + if (handle.response->has_header("Content-Length")) { + bool is_invalid = false; + auto content_length = detail::get_header_value_u64( + handle.response->headers, "Content-Length", 0, 0, is_invalid); + if (is_invalid) { + handle.error = Error::Read; + handle.response.reset(); + return handle; + } + handle.body_reader_.has_content_length = true; + handle.body_reader_.content_length = content_length; + } + + auto transfer_encoding = + handle.response->get_header_value("Transfer-Encoding"); + handle.body_reader_.chunked = (transfer_encoding == "chunked"); + + auto content_encoding = handle.response->get_header_value("Content-Encoding"); + if (!content_encoding.empty()) { + handle.decompressor_ = detail::create_decompressor(content_encoding); + } + + return handle; +} + +ssize_t ClientImpl::StreamHandle::read(char *buf, size_t len) { + if (!is_valid() || !response) { return -1; } + + if (decompressor_) { return read_with_decompression(buf, len); } + auto n = detail::read_body_content(stream_, body_reader_, buf, len); + + if (n <= 0 && body_reader_.chunked && !trailers_parsed_ && stream_) { + trailers_parsed_ = true; + if (body_reader_.chunked_decoder) { + if (!body_reader_.chunked_decoder->parse_trailers_into( + response->trailers, response->headers)) { + return n; + } + } else { + detail::ChunkedDecoder dec(*stream_); + if (!dec.parse_trailers_into(response->trailers, response->headers)) { + return n; + } + } + } + + return n; +} + +ssize_t ClientImpl::StreamHandle::read_with_decompression(char *buf, + size_t len) { + if (decompress_offset_ < decompress_buffer_.size()) { + auto available = decompress_buffer_.size() - decompress_offset_; + auto to_copy = (std::min)(len, available); + std::memcpy(buf, decompress_buffer_.data() + decompress_offset_, to_copy); + decompress_offset_ += to_copy; + decompressed_bytes_read_ += to_copy; + return static_cast(to_copy); + } + + decompress_buffer_.clear(); + decompress_offset_ = 0; + + constexpr size_t kDecompressionBufferSize = 8192; + char compressed_buf[kDecompressionBufferSize]; + + while (true) { + auto n = detail::read_body_content(stream_, body_reader_, compressed_buf, + sizeof(compressed_buf)); + + if (n <= 0) { return n; } + + bool decompress_ok = decompressor_->decompress( + compressed_buf, static_cast(n), + [this](const char *data, size_t data_len) { + decompress_buffer_.append(data, data_len); + auto limit = body_reader_.payload_max_length; + if (decompressed_bytes_read_ + decompress_buffer_.size() > limit) { + return false; + } + return true; + }); + + if (!decompress_ok) { + body_reader_.last_error = Error::Read; + return -1; + } + + if (!decompress_buffer_.empty()) { break; } + } + + auto to_copy = (std::min)(len, decompress_buffer_.size()); + std::memcpy(buf, decompress_buffer_.data(), to_copy); + decompress_offset_ = to_copy; + decompressed_bytes_read_ += to_copy; + return static_cast(to_copy); +} + +void ClientImpl::StreamHandle::parse_trailers_if_needed() { + if (!response || !stream_ || !body_reader_.chunked || trailers_parsed_) { + return; + } + + trailers_parsed_ = true; + + const auto bufsiz = 128; + char line_buf[bufsiz]; + detail::stream_line_reader line_reader(*stream_, line_buf, bufsiz); + + if (!line_reader.getline()) { return; } + + if (!detail::parse_trailers(line_reader, response->trailers, + response->headers)) { + return; + } +} + +namespace detail { + +ChunkedDecoder::ChunkedDecoder(Stream &s) : strm(s) {} + +ssize_t ChunkedDecoder::read_payload(char *buf, size_t len, + size_t &out_chunk_offset, + size_t &out_chunk_total) { + if (finished) { return 0; } + + if (chunk_remaining == 0) { + stream_line_reader lr(strm, line_buf, sizeof(line_buf)); + if (!lr.getline()) { return -1; } + + // RFC 9112 ยง7.1: chunk-size = 1*HEXDIG + const char *p = lr.ptr(); + int v = 0; + if (!is_hex(*p, v)) { return -1; } + + size_t chunk_len = 0; + constexpr size_t chunk_len_max = (std::numeric_limits::max)(); + for (; is_hex(*p, v); ++p) { + if (chunk_len > (chunk_len_max >> 4)) { return -1; } + chunk_len = (chunk_len << 4) | static_cast(v); + } + + while (is_space_or_tab(*p)) { + ++p; + } + if (*p != '\0' && *p != ';' && *p != '\r' && *p != '\n') { return -1; } + + if (chunk_len == 0) { + chunk_remaining = 0; + finished = true; + out_chunk_offset = 0; + out_chunk_total = 0; + return 0; + } + + chunk_remaining = chunk_len; + last_chunk_total = chunk_remaining; + last_chunk_offset = 0; + } + + auto to_read = (std::min)(chunk_remaining, len); + auto n = strm.read(buf, to_read); + if (n <= 0) { return -1; } + + auto offset_before = last_chunk_offset; + last_chunk_offset += static_cast(n); + chunk_remaining -= static_cast(n); + + out_chunk_offset = offset_before; + out_chunk_total = last_chunk_total; + + if (chunk_remaining == 0) { + stream_line_reader lr(strm, line_buf, sizeof(line_buf)); + if (!lr.getline()) { return -1; } + if (std::strcmp(lr.ptr(), "\r\n") != 0) { return -1; } + } + + return n; +} + +bool ChunkedDecoder::parse_trailers_into(Headers &dest, + const Headers &src_headers) { + stream_line_reader lr(strm, line_buf, sizeof(line_buf)); + if (!lr.getline()) { return false; } + return parse_trailers(lr, dest, src_headers); +} + +} // namespace detail + +void +ClientImpl::transfer_socket_ownership_to_handle(StreamHandle &handle) { + handle.connection_->sock = socket_.sock; +#ifdef CPPHTTPLIB_SSL_ENABLED + handle.connection_->session = socket_.ssl; + socket_.ssl = nullptr; +#endif + socket_.sock = INVALID_SOCKET; +} + bool ClientImpl::handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error) { @@ -6213,11 +9264,13 @@ bool ClientImpl::handle_request(Stream &strm, Request &req, bool ret; - if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { + if (!is_ssl() && is_proxy_enabled_for_host(host_)) { auto req2 = req; - req2.path = "http://" + host_and_port_ + req.path; + req2.path = "http://" + + detail::make_host_and_port_string(host_, port_, false) + + req.path; ret = process_request(strm, req2, res, close_connection, error); - req = req2; + req = std::move(req2); req.path = req_save.path; } else { ret = process_request(strm, req, res, close_connection, error); @@ -6227,7 +9280,7 @@ bool ClientImpl::handle_request(Stream &strm, Request &req, if (res.get_header_value("Connection") == "close" || (res.version == "HTTP/1.0" && res.reason != "Connection established")) { - // TODO this requires a not-entirely-obvious chain of calls to be correct + // NOTE: this requires a not-entirely-obvious chain of calls to be correct // for this to be safe. // This is safe to call because handle_request is only called by send_ @@ -6235,21 +9288,27 @@ bool ClientImpl::handle_request(Stream &strm, Request &req, // to call it from a different thread since it's a thread-safety issue // to do these things to the socket if another thread is using the socket. std::lock_guard guard(socket_mutex_); - shutdown_ssl(socket_, true); - shutdown_socket(socket_); - close_socket(socket_); + disconnect(/*gracefully=*/true); } if (300 < res.status && res.status < 400 && follow_location_) { - req = req_save; + req = std::move(req_save); ret = redirect(req, res, error); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if ((res.status == StatusCode::Unauthorized_401 || res.status == StatusCode::ProxyAuthenticationRequired_407) && req.authorization_count_ < 5) { auto is_proxy = res.status == StatusCode::ProxyAuthenticationRequired_407; + + // Only retry when the 407 actually came from a proxy hop: plain HTTP + // through an enabled proxy. HTTPS via CONNECT tunnels the 407 from the + // origin (#2457); direct/bypassed origins have no proxy hop at all. + if (is_proxy && !(!is_ssl() && is_proxy_enabled_for_host(host_))) { + return ret; + } + const auto &username = is_proxy ? proxy_digest_auth_username_ : digest_auth_username_; const auto &password = @@ -6269,7 +9328,7 @@ bool ClientImpl::handle_request(Stream &strm, Request &req, Response new_res; ret = send(new_req, new_res, error); - if (ret) { res = new_res; } + if (ret) { res = std::move(new_res); } } } } @@ -6288,24 +9347,25 @@ bool ClientImpl::redirect(Request &req, Response &res, Error &error) { auto location = res.get_header_value("location"); if (location.empty()) { return false; } - thread_local const std::regex re( - R"((?:(https?):)?(?://(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)?([^?#]*)(\?[^#]*)?(?:#.*)?)"); + detail::UrlComponents uc; + if (!detail::parse_url(location, uc)) { return false; } - std::smatch m; - if (!std::regex_match(location, m, re)) { return false; } + // Only follow http/https redirects + if (!uc.scheme.empty() && uc.scheme != "http" && uc.scheme != "https") { + return false; + } auto scheme = is_ssl() ? "https" : "http"; - auto next_scheme = m[1].str(); - auto next_host = m[2].str(); - if (next_host.empty()) { next_host = m[3].str(); } - auto port_str = m[4].str(); - auto next_path = m[5].str(); - auto next_query = m[6].str(); + auto next_scheme = std::move(uc.scheme); + auto next_host = std::move(uc.host); + auto port_str = std::move(uc.port); + auto next_path = std::move(uc.path); + auto next_query = std::move(uc.query); auto next_port = port_; if (!port_str.empty()) { - next_port = std::stoi(port_str); + if (!detail::parse_port(port_str, next_port)) { return false; } } else if (!next_scheme.empty()) { next_port = next_scheme == "https" ? 443 : 80; } @@ -6314,7 +9374,7 @@ bool ClientImpl::redirect(Request &req, Response &res, Error &error) { if (next_host.empty()) { next_host = host_; } if (next_path.empty()) { next_path = "/"; } - auto path = decode_query_component(next_path, true) + next_query; + auto path = decode_path_component(next_path) + next_query; // Same host redirect - use current client if (next_scheme == scheme && next_host == host_ && next_port == port_) { @@ -6349,29 +9409,22 @@ bool ClientImpl::create_redirect_client( // Create appropriate client type and handle redirect if (need_ssl) { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED // Create SSL client for HTTPS redirect SSLClient redirect_client(host, port); // Setup basic client configuration first setup_redirect_client(redirect_client); - // SSL-specific configuration for proxy environments - if (!proxy_host_.empty() && proxy_port_ != -1) { - // Critical: Disable SSL verification for proxy environments - redirect_client.enable_server_certificate_verification(false); - redirect_client.enable_server_hostname_verification(false); - } else { - // For direct SSL connections, copy SSL verification settings - redirect_client.enable_server_certificate_verification( - server_certificate_verification_); - redirect_client.enable_server_hostname_verification( - server_hostname_verification_); - } + redirect_client.enable_server_certificate_verification( + server_certificate_verification_); + redirect_client.enable_server_hostname_verification( + server_hostname_verification_); - // Handle CA certificate store and paths if available - if (ca_cert_store_ && X509_STORE_up_ref(ca_cert_store_)) { - redirect_client.set_ca_cert_store(ca_cert_store_); + // Transfer CA certificate to redirect client + if (!ca_cert_pem_.empty()) { + redirect_client.load_ca_cert_store(ca_cert_pem_.c_str(), + ca_cert_pem_.size()); } if (!ca_cert_file_path_.empty()) { redirect_client.set_ca_cert_path(ca_cert_file_path_, ca_cert_dir_path_); @@ -6417,26 +9470,19 @@ void ClientImpl::setup_redirect_client(ClientType &client) { client.set_compress(compress_); client.set_decompress(decompress_); - // Copy authentication settings BEFORE proxy setup - if (!basic_auth_username_.empty()) { - client.set_basic_auth(basic_auth_username_, basic_auth_password_); - } - if (!bearer_token_auth_token_.empty()) { - client.set_bearer_token_auth(bearer_token_auth_token_); - } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (!digest_auth_username_.empty()) { - client.set_digest_auth(digest_auth_username_, digest_auth_password_); - } -#endif + // NOTE: Authentication credentials (basic auth, bearer token, digest auth) + // are intentionally NOT copied to the redirect client. Per RFC 9110 Section + // 15.4, credentials must not be forwarded when redirecting to a different + // host. This function is only called for cross-host redirects; same-host + // redirects are handled directly in ClientImpl::redirect(). - // Setup proxy configuration (CRITICAL ORDER - proxy must be set - // before proxy auth) + // Copy the proxy configuration unconditionally; the per-target bypass is + // re-evaluated at send time, so a later hop to a non-bypassed host can + // still use the proxy. + client.no_proxy_entries_ = no_proxy_entries_; if (!proxy_host_.empty() && proxy_port_ != -1) { - // First set proxy host and port client.set_proxy(proxy_host_, proxy_port_); - // Then set proxy authentication (order matters!) if (!proxy_basic_auth_username_.empty()) { client.set_proxy_basic_auth(proxy_basic_auth_username_, proxy_basic_auth_password_); @@ -6444,7 +9490,7 @@ void ClientImpl::setup_redirect_client(ClientType &client) { if (!proxy_bearer_token_auth_token_.empty()) { client.set_proxy_bearer_token_auth(proxy_bearer_token_auth_token_); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if (!proxy_digest_auth_username_.empty()) { client.set_proxy_digest_auth(proxy_digest_auth_username_, proxy_digest_auth_password_); @@ -6473,14 +9519,9 @@ bool ClientImpl::write_content_with_provider(Stream &strm, auto is_shutting_down = []() { return false; }; if (req.is_chunked_content_provider_) { - // TODO: Brotli support - std::unique_ptr compressor; -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - if (compress_) { - compressor = detail::make_unique(); - } else -#endif - { + auto compressor = compress_ ? detail::create_compressor().first + : std::unique_ptr(); + if (!compressor) { compressor = detail::make_unique(); } @@ -6494,7 +9535,8 @@ bool ClientImpl::write_content_with_provider(Stream &strm, } bool ClientImpl::write_request(Stream &strm, Request &req, - bool close_connection, Error &error) { + bool close_connection, Error &error, + bool skip_body) { // Prepare additional headers if (close_connection) { if (!req.has_header("Connection")) { @@ -6502,42 +9544,11 @@ bool ClientImpl::write_request(Stream &strm, Request &req, } } - if (!req.has_header("Host")) { - // For Unix socket connections, use "localhost" as Host header (similar to - // curl behavior) - if (address_family_ == AF_UNIX) { - req.set_header("Host", "localhost"); - } else { - req.set_header("Host", host_and_port_); - } + std::string ct_for_defaults; + if (!req.has_header("Content-Type") && !req.body.empty()) { + ct_for_defaults = "text/plain"; } - - if (!req.has_header("Accept")) { req.set_header("Accept", "*/*"); } - - if (!req.content_receiver) { - if (!req.has_header("Accept-Encoding")) { - std::string accept_encoding; -#ifdef CPPHTTPLIB_BROTLI_SUPPORT - accept_encoding = "br"; -#endif -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - if (!accept_encoding.empty()) { accept_encoding += ", "; } - accept_encoding += "gzip, deflate"; -#endif -#ifdef CPPHTTPLIB_ZSTD_SUPPORT - if (!accept_encoding.empty()) { accept_encoding += ", "; } - accept_encoding += "zstd"; -#endif - req.set_header("Accept-Encoding", accept_encoding); - } - -#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT - if (!req.has_header("User-Agent")) { - auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; - req.set_header("User-Agent", agent); - } -#endif - }; + prepare_default_headers(req, false, ct_for_defaults); if (req.body.empty()) { if (req.content_provider_) { @@ -6553,15 +9564,6 @@ bool ClientImpl::write_request(Stream &strm, Request &req, req.set_header("Content-Length", "0"); } } - } else { - if (!req.has_header("Content-Type")) { - req.set_header("Content-Type", "text/plain"); - } - - if (!req.has_header("Content-Length")) { - auto length = std::to_string(req.body.size()); - req.set_header("Content-Length", length); - } } if (!basic_auth_password_.empty() || !basic_auth_username_.empty()) { @@ -6571,14 +9573,6 @@ bool ClientImpl::write_request(Stream &strm, Request &req, } } - if (!proxy_basic_auth_username_.empty() && - !proxy_basic_auth_password_.empty()) { - if (!req.has_header("Proxy-Authorization")) { - req.headers.insert(make_basic_authentication_header( - proxy_basic_auth_username_, proxy_basic_auth_password_, true)); - } - } - if (!bearer_token_auth_token_.empty()) { if (!req.has_header("Authorization")) { req.headers.insert(make_bearer_token_authentication_header( @@ -6586,8 +9580,18 @@ bool ClientImpl::write_request(Stream &strm, Request &req, } } - if (!proxy_bearer_token_auth_token_.empty()) { - if (!req.has_header("Proxy-Authorization")) { + // Proxy-Authorization is only sent when the proxy is actually used for + // this target โ€” otherwise NO_PROXY-matched requests would leak proxy + // credentials directly to the destination server. + if (is_proxy_enabled_for_host(host_)) { + if (!proxy_basic_auth_username_.empty() && + !proxy_basic_auth_password_.empty() && + !req.has_header("Proxy-Authorization")) { + req.headers.insert(make_basic_authentication_header( + proxy_basic_auth_username_, proxy_basic_auth_password_, true)); + } + if (!proxy_bearer_token_auth_token_.empty() && + !req.has_header("Proxy-Authorization")) { req.headers.insert(make_bearer_token_authentication_header( proxy_bearer_token_auth_token_, true)); } @@ -6608,18 +9612,41 @@ bool ClientImpl::write_request(Stream &strm, Request &req, query_part = ""; } - // Encode path and query + // Encode path part. If the original `req.path` already contained a + // query component, preserve its raw query string (including parameter + // order) instead of reparsing and reassembling it which may reorder + // parameters due to container ordering (e.g. `Params` uses + // `std::multimap`). When there is no query in `req.path`, fall back to + // building a query from `req.params` so existing callers that pass + // `Params` continue to work. auto path_with_query = path_encode_ ? detail::encode_path(path_part) : path_part; - detail::parse_query_text(query_part, req.params); - if (!req.params.empty()) { - path_with_query = append_query_params(path_with_query, req.params); + if (!query_part.empty()) { + // Normalize the query string (decode then re-encode) while preserving + // the original parameter order. + auto normalized = detail::normalize_query_string(query_part); + if (!normalized.empty()) { path_with_query += '?' + normalized; } + + // Still populate req.params for handlers/users who read them. + detail::parse_query_text(query_part, req.params); + } else { + // No query in path; parse any query_part (empty) and append params + // from `req.params` when present (preserves prior behavior for + // callers who provide Params separately). + detail::parse_query_text(query_part, req.params); + if (!req.params.empty()) { + path_with_query = append_query_params(path_with_query, req.params); + } } // Write request line and headers detail::write_request_line(bstrm, req.method, path_with_query); - header_writer_(bstrm, req.headers); + if (!detail::check_and_write_headers(bstrm, req.headers, header_writer_, + error)) { + output_error_log(error, &req); + return false; + } // Flush buffer auto &data = bstrm.get_buffer(); @@ -6630,7 +9657,59 @@ bool ClientImpl::write_request(Stream &strm, Request &req, } } + // After sending request line and headers, wait briefly for an early server + // response (e.g. 4xx) and avoid sending a potentially large request body + // unnecessarily. This workaround is only enabled on Windows because Unix + // platforms surface write errors (EPIPE) earlier; on Windows kernel send + // buffering can accept large writes even when the peer already responded. + // Check the stream first (which covers SSL via `is_readable()`), then + // fall back to select on the socket. Only perform the wait for very large + // request bodies to avoid interfering with normal small requests and + // reduce side-effects. Poll briefly (up to 50ms as default) for an early + // response. Skip this check when using Expect: 100-continue, as the protocol + // handles early responses properly. +#if defined(_WIN32) + if (!skip_body && + req.body.size() > CPPHTTPLIB_WAIT_EARLY_SERVER_RESPONSE_THRESHOLD && + req.path.size() > CPPHTTPLIB_REQUEST_URI_MAX_LENGTH) { + auto start = std::chrono::high_resolution_clock::now(); + + for (;;) { + // Prefer socket-level readiness to avoid SSL_pending() false-positives + // from SSL internals. If the underlying socket is readable, assume an + // early response may be present. + auto sock = strm.socket(); + if (sock != INVALID_SOCKET && detail::select_read(sock, 0, 0) > 0) { + return false; + } + + // Fallback to stream-level check for non-socket streams or when the + // socket isn't reporting readable. Avoid using `is_readable()` for + // SSL, since `SSL_pending()` may report buffered records that do not + // indicate a complete application-level response yet. + if (!is_ssl() && strm.is_readable()) { return false; } + + auto now = std::chrono::high_resolution_clock::now(); + auto elapsed = + std::chrono::duration_cast(now - start) + .count(); + if (elapsed >= CPPHTTPLIB_WAIT_EARLY_SERVER_RESPONSE_TIMEOUT_MSECOND) { + break; + } + + std::this_thread::sleep_for(std::chrono::milliseconds(1)); + } + } +#endif + // Body + if (skip_body) { return true; } + + return write_request_body(strm, req, error); +} + +bool ClientImpl::write_request_body(Stream &strm, Request &req, + Error &error) { if (req.body.empty()) { return write_content_with_provider(strm, req, error); } @@ -6666,21 +9745,24 @@ bool ClientImpl::write_request(Stream &strm, Request &req, return true; } -std::unique_ptr ClientImpl::send_with_content_provider( +std::unique_ptr +ClientImpl::send_with_content_provider_and_receiver( Request &req, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type, Error &error) { + const std::string &content_type, ContentReceiver content_receiver, + Error &error) { if (!content_type.empty()) { req.set_header("Content-Type", content_type); } -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - if (compress_) { req.set_header("Content-Encoding", "gzip"); } -#endif + auto enc = compress_ + ? detail::create_compressor() + : std::pair, const char *>( + nullptr, nullptr); -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - if (compress_ && !content_provider_without_length) { - // TODO: Brotli support - detail::gzip_compressor compressor; + if (enc.second) { req.set_header("Content-Encoding", enc.second); } + + if (enc.first && !content_provider_without_length) { + auto &compressor = enc.first; if (content_provider) { auto ok = true; @@ -6691,7 +9773,7 @@ std::unique_ptr ClientImpl::send_with_content_provider( if (ok) { auto last = offset + data_len == content_length; - auto ret = compressor.compress( + auto ret = compressor->compress( data, data_len, last, [&](const char *compressed_data, size_t compressed_data_len) { req.body.append(compressed_data, compressed_data_len); @@ -6715,19 +9797,17 @@ std::unique_ptr ClientImpl::send_with_content_provider( } } } else { - if (!compressor.compress(body, content_length, true, - [&](const char *data, size_t data_len) { - req.body.append(data, data_len); - return true; - })) { + if (!compressor->compress(body, content_length, true, + [&](const char *data, size_t data_len) { + req.body.append(data, data_len); + return true; + })) { error = Error::Compression; output_error_log(error, &req); return nullptr; } } - } else -#endif - { + } else { if (content_provider) { req.content_length_ = content_length; req.content_provider_ = std::move(content_provider); @@ -6743,15 +9823,24 @@ std::unique_ptr ClientImpl::send_with_content_provider( } } + if (content_receiver) { + req.content_receiver = + [content_receiver](const char *data, size_t data_length, + size_t /*offset*/, size_t /*total_length*/) { + return content_receiver(data, data_length); + }; + } + auto res = detail::make_unique(); return send(req, *res, error) ? std::move(res) : nullptr; } -Result ClientImpl::send_with_content_provider( +Result ClientImpl::send_with_content_provider_and_receiver( const std::string &method, const std::string &path, const Headers &headers, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type, UploadProgress progress) { + const std::string &content_type, ContentReceiver content_receiver, + UploadProgress progress) { Request req; req.method = method; req.headers = headers; @@ -6763,13 +9852,14 @@ Result ClientImpl::send_with_content_provider( auto error = Error::Success; - auto res = send_with_content_provider( + auto res = send_with_content_provider_and_receiver( req, body, content_length, std::move(content_provider), - std::move(content_provider_without_length), content_type, error); + std::move(content_provider_without_length), content_type, + std::move(content_receiver), error); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED return Result{std::move(res), error, std::move(req.headers), last_ssl_error_, - last_openssl_error_}; + last_backend_error_}; #else return Result{std::move(res), error, std::move(req.headers)}; #endif @@ -6794,14 +9884,26 @@ void ClientImpl::output_error_log(const Error &err, bool ClientImpl::process_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error) { - // Send request - if (!write_request(strm, req, close_connection, error)) { return false; } + // Auto-add Expect: 100-continue for large bodies + if (CPPHTTPLIB_EXPECT_100_THRESHOLD > 0 && !req.has_header("Expect")) { + auto body_size = req.body.empty() ? req.content_length_ : req.body.size(); + if (body_size >= CPPHTTPLIB_EXPECT_100_THRESHOLD) { + req.set_header("Expect", "100-continue"); + } + } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (is_ssl()) { - auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; + // Check for Expect: 100-continue + auto expect_100_continue = req.get_header_value("Expect") == "100-continue"; + + // Send request (skip body if using Expect: 100-continue) + auto write_request_success = + write_request(strm, req, close_connection, error, expect_100_continue); + +#ifdef CPPHTTPLIB_SSL_ENABLED + if (is_ssl() && !expect_100_continue) { + auto is_proxy_enabled = is_proxy_enabled_for_host(host_); if (!is_proxy_enabled) { - if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + if (tls::is_peer_closed(socket_.ssl, socket_.sock)) { error = Error::SSLPeerCouldBeClosed_; output_error_log(error, &req); return false; @@ -6810,14 +9912,48 @@ bool ClientImpl::process_request(Stream &strm, Request &req, } #endif + // Handle Expect: 100-continue with timeout + if (expect_100_continue && CPPHTTPLIB_EXPECT_100_TIMEOUT_MSECOND > 0) { + time_t sec = CPPHTTPLIB_EXPECT_100_TIMEOUT_MSECOND / 1000; + time_t usec = (CPPHTTPLIB_EXPECT_100_TIMEOUT_MSECOND % 1000) * 1000; + auto ret = detail::select_read(strm.socket(), sec, usec); + if (ret <= 0) { + // Timeout or error: send body anyway (server didn't respond in time) + if (!write_request_body(strm, req, error)) { return false; } + expect_100_continue = false; // Switch to normal response handling + } + } + // Receive response and headers - if (!read_response_line(strm, req, res) || + // When using Expect: 100-continue, don't auto-skip `100 Continue` response + if (!read_response_line(strm, req, res, !expect_100_continue) || !detail::read_headers(strm, res.headers)) { - error = Error::Read; + if (write_request_success) { error = Error::Read; } output_error_log(error, &req); return false; } + if (!write_request_success) { return false; } + + // Handle Expect: 100-continue response + if (expect_100_continue) { + if (res.status == StatusCode::Continue_100) { + // Server accepted, send the body + if (!write_request_body(strm, req, error)) { return false; } + + // Read the actual response + res.headers.clear(); + res.body.clear(); + if (!read_response_line(strm, req, res) || + !detail::read_headers(strm, res.headers)) { + error = Error::Read; + output_error_log(error, &req); + return false; + } + } + // If not 100 Continue, server returned an error; proceed with that response + } + // Body if ((res.status != StatusCode::NoContent_204) && req.method != "HEAD" && req.method != "CONNECT") { @@ -6849,6 +9985,11 @@ bool ClientImpl::process_request(Stream &strm, Request &req, [&](const char *buf, size_t n, size_t /*off*/, size_t /*len*/) { assert(res.body.size() + n <= res.body.max_size()); + if (payload_max_length_ > 0 && + (res.body.size() >= payload_max_length_ || + n > payload_max_length_ - res.body.size())) { + return false; + } res.body.append(buf, n); return true; }); @@ -6871,15 +10012,26 @@ bool ClientImpl::process_request(Stream &strm, Request &req, output_error_log(error, &req); return false; } - res.body.reserve(static_cast(len)); + // Cap the reservation by payload_max_length_ to avoid OOM when a + // hostile or malformed server sends an enormous Content-Length. + // The actual body read below is bounded by payload_max_length_, + // so reserving more than that is never useful. + auto reserve_len = static_cast(len); + if (payload_max_length_ > 0 && reserve_len > payload_max_length_) { + reserve_len = payload_max_length_; + } + res.body.reserve(reserve_len); } } if (res.status != StatusCode::NotModified_304) { int dummy_status; - if (!detail::read_content(strm, res, (std::numeric_limits::max)(), - dummy_status, std::move(progress), - std::move(out), decompress_)) { + auto max_length = (!has_payload_max_length_ && req.content_receiver) + ? (std::numeric_limits::max)() + : payload_max_length_; + if (!detail::read_content(strm, res, max_length, dummy_status, + std::move(progress), std::move(out), + decompress_)) { if (error != Error::Canceled) { error = Error::Read; } output_error_log(error, &req); return false; @@ -7094,6 +10246,15 @@ Result ClientImpl::Post(const std::string &path, size_t content_length, content_type, progress); } +Result ClientImpl::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Post(path, Headers(), content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} + Result ClientImpl::Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -7102,6 +10263,15 @@ Result ClientImpl::Post(const std::string &path, progress); } +Result ClientImpl::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Post(path, Headers(), std::move(content_provider), content_type, + std::move(content_receiver), progress); +} + Result ClientImpl::Post(const std::string &path, const Headers &headers, const Params ¶ms) { auto query = detail::params_to_query_str(params); @@ -7120,8 +10290,10 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Post(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Post(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -7134,25 +10306,28 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Post(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Post(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, body, content_length, - nullptr, nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, body, content_length, nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, body.data(), - body.size(), nullptr, nullptr, content_type, - progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, body.data(), body.size(), nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -7160,18 +10335,40 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, ContentProvider content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, nullptr, - content_length, std::move(content_provider), - nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, nullptr, progress); +} + +Result ClientImpl::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, + std::move(content_receiver), std::move(progress)); } Result ClientImpl::Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("POST", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type, - progress); + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr, progress); +} + +Result ClientImpl::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return send_with_content_provider_and_receiver( + "POST", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, std::move(content_receiver), std::move(progress)); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -7181,10 +10378,10 @@ Result ClientImpl::Post(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - return send_with_content_provider( + return send_with_content_provider_and_receiver( "POST", path, headers, nullptr, 0, nullptr, get_multipart_content_provider(boundary, items, provider_items), - content_type, progress); + content_type, nullptr, progress); } Result ClientImpl::Post(const std::string &path, const Headers &headers, @@ -7246,6 +10443,15 @@ Result ClientImpl::Put(const std::string &path, size_t content_length, content_type, progress); } +Result ClientImpl::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Put(path, Headers(), content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} + Result ClientImpl::Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -7254,6 +10460,15 @@ Result ClientImpl::Put(const std::string &path, progress); } +Result ClientImpl::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Put(path, Headers(), std::move(content_provider), content_type, + std::move(content_receiver), progress); +} + Result ClientImpl::Put(const std::string &path, const Headers &headers, const Params ¶ms) { auto query = detail::params_to_query_str(params); @@ -7272,8 +10487,10 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Put(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Put(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -7286,25 +10503,28 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Put(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Put(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PUT", path, headers, body, content_length, - nullptr, nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, body, content_length, nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PUT", path, headers, body.data(), - body.size(), nullptr, nullptr, content_type, - progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, body.data(), body.size(), nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -7312,18 +10532,40 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, ContentProvider content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PUT", path, headers, nullptr, - content_length, std::move(content_provider), - nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, nullptr, progress); +} + +Result ClientImpl::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, + std::move(content_receiver), progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PUT", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type, - progress); + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr, progress); +} + +Result ClientImpl::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return send_with_content_provider_and_receiver( + "PUT", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, std::move(content_receiver), progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -7333,10 +10575,10 @@ Result ClientImpl::Put(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - return send_with_content_provider( + return send_with_content_provider_and_receiver( "PUT", path, headers, nullptr, 0, nullptr, get_multipart_content_provider(boundary, items, provider_items), - content_type, progress); + content_type, nullptr, progress); } Result ClientImpl::Put(const std::string &path, const Headers &headers, @@ -7400,6 +10642,15 @@ Result ClientImpl::Patch(const std::string &path, size_t content_length, content_type, progress); } +Result ClientImpl::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Patch(path, Headers(), content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} + Result ClientImpl::Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -7408,6 +10659,15 @@ Result ClientImpl::Patch(const std::string &path, progress); } +Result ClientImpl::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return Patch(path, Headers(), std::move(content_provider), content_type, + std::move(content_receiver), progress); +} + Result ClientImpl::Patch(const std::string &path, const Headers &headers, const Params ¶ms) { auto query = detail::params_to_query_str(params); @@ -7426,8 +10686,10 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Patch(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Patch(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -7440,26 +10702,28 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - const auto &body = detail::serialize_multipart_formdata(items, boundary); - return Patch(path, headers, body, content_type, progress); + auto content_length = detail::get_multipart_content_length(items, boundary); + return Patch(path, headers, content_length, + detail::make_multipart_content_provider(items, boundary), + content_type, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, body, - content_length, nullptr, nullptr, - content_type, progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, body, content_length, nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, body.data(), - body.size(), nullptr, nullptr, content_type, - progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, body.data(), body.size(), nullptr, nullptr, + content_type, nullptr, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -7467,18 +10731,40 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, ContentProvider content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, nullptr, - content_length, std::move(content_provider), - nullptr, content_type, progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, nullptr, progress); +} + +Result ClientImpl::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, content_length, + std::move(content_provider), nullptr, content_type, + std::move(content_receiver), progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { - return send_with_content_provider("PATCH", path, headers, nullptr, 0, nullptr, - std::move(content_provider), content_type, - progress); + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, nullptr, progress); +} + +Result ClientImpl::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return send_with_content_provider_and_receiver( + "PATCH", path, headers, nullptr, 0, nullptr, std::move(content_provider), + content_type, std::move(content_receiver), progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -7488,10 +10774,10 @@ Result ClientImpl::Patch(const std::string &path, const Headers &headers, const auto &boundary = detail::make_multipart_data_boundary(); const auto &content_type = detail::serialize_multipart_formdata_get_content_type(boundary); - return send_with_content_provider( + return send_with_content_provider_and_receiver( "PATCH", path, headers, nullptr, 0, nullptr, get_multipart_content_provider(boundary, items, provider_items), - content_type, progress); + content_type, nullptr, progress); } Result ClientImpl::Patch(const std::string &path, const Headers &headers, @@ -7622,10 +10908,7 @@ void ClientImpl::stop() { return; } - // Otherwise, still holding the mutex, we can shut everything down ourselves - shutdown_ssl(socket_, true); - shutdown_socket(socket_); - close_socket(socket_); + disconnect(/*gracefully=*/true); } std::string ClientImpl::host() const { return host_; } @@ -7668,14 +10951,6 @@ void ClientImpl::set_bearer_token_auth(const std::string &token) { bearer_token_auth_token_ = token; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -void ClientImpl::set_digest_auth(const std::string &username, - const std::string &password) { - digest_auth_username_ = username; - digest_auth_password_ = password; -} -#endif - void ClientImpl::set_keep_alive(bool on) { keep_alive_ = on; } void ClientImpl::set_follow_location(bool on) { follow_location_ = on; } @@ -7712,6 +10987,11 @@ void ClientImpl::set_compress(bool on) { compress_ = on; } void ClientImpl::set_decompress(bool on) { decompress_ = on; } +void ClientImpl::set_payload_max_length(size_t length) { + payload_max_length_ = length; + has_payload_max_length_ = true; +} + void ClientImpl::set_interface(const std::string &intf) { interface_ = intf; } @@ -7719,6 +10999,8 @@ void ClientImpl::set_interface(const std::string &intf) { void ClientImpl::set_proxy(const std::string &host, int port) { proxy_host_ = host; proxy_port_ = port; + std::lock_guard guard(socket_mutex_); + disconnect(/*gracefully=*/true); } void ClientImpl::set_proxy_basic_auth(const std::string &username, @@ -7731,11 +11013,27 @@ void ClientImpl::set_proxy_bearer_token_auth(const std::string &token) { proxy_bearer_token_auth_token_ = token; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -void ClientImpl::set_proxy_digest_auth(const std::string &username, - const std::string &password) { - proxy_digest_auth_username_ = username; - proxy_digest_auth_password_ = password; +void ClientImpl::set_no_proxy(const std::vector &patterns) { + std::vector parsed; + parsed.reserve(patterns.size()); + for (const auto &p : patterns) { + auto trimmed = detail::trim_copy(p); + if (trimmed.empty()) { continue; } + detail::NoProxyEntry entry; + if (detail::parse_no_proxy_entry(trimmed, entry)) { + parsed.push_back(std::move(entry)); + } + } + no_proxy_entries_ = std::move(parsed); + std::lock_guard guard(socket_mutex_); + disconnect(/*gracefully=*/true); +} + +#ifdef CPPHTTPLIB_SSL_ENABLED +void ClientImpl::set_digest_auth(const std::string &username, + const std::string &password) { + digest_auth_username_ = username; + digest_auth_password_ = password; } void ClientImpl::set_ca_cert_path(const std::string &ca_cert_file_path, @@ -7744,34 +11042,10 @@ void ClientImpl::set_ca_cert_path(const std::string &ca_cert_file_path, ca_cert_dir_path_ = ca_cert_dir_path; } -void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { - if (ca_cert_store && ca_cert_store != ca_cert_store_) { - ca_cert_store_ = ca_cert_store; - } -} - -X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, - std::size_t size) const { - auto mem = BIO_new_mem_buf(ca_cert, static_cast(size)); - auto se = detail::scope_exit([&] { BIO_free_all(mem); }); - if (!mem) { return nullptr; } - - auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); - if (!inf) { return nullptr; } - - auto cts = X509_STORE_new(); - if (cts) { - for (auto i = 0; i < static_cast(sk_X509_INFO_num(inf)); i++) { - auto itmp = sk_X509_INFO_value(inf, i); - if (!itmp) { continue; } - - if (itmp->x509) { X509_STORE_add_cert(cts, itmp->x509); } - if (itmp->crl) { X509_STORE_add_crl(cts, itmp->crl); } - } - } - - sk_X509_INFO_pop_free(inf, X509_INFO_free); - return cts; +void ClientImpl::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; } void ClientImpl::enable_server_certificate_verification(bool enabled) { @@ -7781,11 +11055,6 @@ void ClientImpl::enable_server_certificate_verification(bool enabled) { void ClientImpl::enable_server_hostname_verification(bool enabled) { server_hostname_verification_ = enabled; } - -void ClientImpl::set_server_certificate_verifier( - std::function verifier) { - server_certificate_verifier_ = verifier; -} #endif void ClientImpl::set_logger(Logger logger) { @@ -7797,933 +11066,24 @@ void ClientImpl::set_error_logger(ErrorLogger error_logger) { } /* - * SSL Implementation + * SSL/TLS Common Implementation */ -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -namespace detail { -bool is_ip_address(const std::string &host) { - struct in_addr addr4; - struct in6_addr addr6; - return inet_pton(AF_INET, host.c_str(), &addr4) == 1 || - inet_pton(AF_INET6, host.c_str(), &addr6) == 1; -} - -template -SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, - U SSL_connect_or_accept, V setup) { - SSL *ssl = nullptr; - { - std::lock_guard guard(ctx_mutex); - ssl = SSL_new(ctx); - } - - if (ssl) { - set_nonblocking(sock, true); - auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); - BIO_set_nbio(bio, 1); - SSL_set_bio(ssl, bio, bio); - - if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { - SSL_shutdown(ssl); - { - std::lock_guard guard(ctx_mutex); - SSL_free(ssl); - } - set_nonblocking(sock, false); - return nullptr; - } - BIO_set_nbio(bio, 0); - set_nonblocking(sock, false); - } - - return ssl; -} - -void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, - bool shutdown_gracefully) { - // sometimes we may want to skip this to try to avoid SIGPIPE if we know - // the remote has closed the network connection - // Note that it is not always possible to avoid SIGPIPE, this is merely a - // best-efforts. - if (shutdown_gracefully) { - (void)(sock); - // SSL_shutdown() returns 0 on first call (indicating close_notify alert - // sent) and 1 on subsequent call (indicating close_notify alert received) - if (SSL_shutdown(ssl) == 0) { - // Expected to return 1, but even if it doesn't, we free ssl - SSL_shutdown(ssl); - } - } - - std::lock_guard guard(ctx_mutex); - SSL_free(ssl); -} - -template -bool ssl_connect_or_accept_nonblocking(socket_t sock, SSL *ssl, - U ssl_connect_or_accept, - time_t timeout_sec, time_t timeout_usec, - int *ssl_error) { - auto res = 0; - while ((res = ssl_connect_or_accept(ssl)) != 1) { - auto err = SSL_get_error(ssl, res); - switch (err) { - case SSL_ERROR_WANT_READ: - if (select_read(sock, timeout_sec, timeout_usec) > 0) { continue; } - break; - case SSL_ERROR_WANT_WRITE: - if (select_write(sock, timeout_sec, timeout_usec) > 0) { continue; } - break; - default: break; - } - if (ssl_error) { *ssl_error = err; } - return false; - } - return true; -} - -template -bool process_server_socket_ssl( - const std::atomic &svr_sock, SSL *ssl, socket_t sock, - size_t keep_alive_max_count, time_t keep_alive_timeout_sec, - time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec, T callback) { - return process_server_socket_core( - svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, - [&](bool close_connection, bool &connection_closed) { - SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); - return callback(strm, close_connection, connection_closed); - }); -} - -template -bool process_client_socket_ssl( - SSL *ssl, socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, - time_t write_timeout_sec, time_t write_timeout_usec, - time_t max_timeout_msec, - std::chrono::time_point start_time, T callback) { - SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec, max_timeout_msec, - start_time); - return callback(strm); -} - -// SSL socket stream implementation -SSLSocketStream::SSLSocketStream( - socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, - time_t write_timeout_sec, time_t write_timeout_usec, - time_t max_timeout_msec, - std::chrono::time_point start_time) - : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), - read_timeout_usec_(read_timeout_usec), - write_timeout_sec_(write_timeout_sec), - write_timeout_usec_(write_timeout_usec), - max_timeout_msec_(max_timeout_msec), start_time_(start_time) { - SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); -} - -SSLSocketStream::~SSLSocketStream() = default; - -bool SSLSocketStream::is_readable() const { - return SSL_pending(ssl_) > 0; -} - -bool SSLSocketStream::wait_readable() const { - if (max_timeout_msec_ <= 0) { - return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; - } - - time_t read_timeout_sec; - time_t read_timeout_usec; - calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_, - read_timeout_usec_, read_timeout_sec, read_timeout_usec); - - return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; -} - -bool SSLSocketStream::wait_writable() const { - return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && - is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_); -} - -ssize_t SSLSocketStream::read(char *ptr, size_t size) { - if (SSL_pending(ssl_) > 0) { - return SSL_read(ssl_, ptr, static_cast(size)); - } else if (wait_readable()) { - auto ret = SSL_read(ssl_, ptr, static_cast(size)); - if (ret < 0) { - auto err = SSL_get_error(ssl_, ret); - auto n = 1000; -#ifdef _WIN32 - while (--n >= 0 && (err == SSL_ERROR_WANT_READ || - (err == SSL_ERROR_SYSCALL && - WSAGetLastError() == WSAETIMEDOUT))) { -#else - while (--n >= 0 && err == SSL_ERROR_WANT_READ) { -#endif - if (SSL_pending(ssl_) > 0) { - return SSL_read(ssl_, ptr, static_cast(size)); - } else if (wait_readable()) { - std::this_thread::sleep_for(std::chrono::microseconds{10}); - ret = SSL_read(ssl_, ptr, static_cast(size)); - if (ret >= 0) { return ret; } - err = SSL_get_error(ssl_, ret); - } else { - break; - } - } - assert(ret < 0); - } - return ret; - } else { - return -1; - } -} - -ssize_t SSLSocketStream::write(const char *ptr, size_t size) { - if (wait_writable()) { - auto handle_size = static_cast( - std::min(size, (std::numeric_limits::max)())); - - auto ret = SSL_write(ssl_, ptr, static_cast(handle_size)); - if (ret < 0) { - auto err = SSL_get_error(ssl_, ret); - auto n = 1000; -#ifdef _WIN32 - while (--n >= 0 && (err == SSL_ERROR_WANT_WRITE || - (err == SSL_ERROR_SYSCALL && - WSAGetLastError() == WSAETIMEDOUT))) { -#else - while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { -#endif - if (wait_writable()) { - std::this_thread::sleep_for(std::chrono::microseconds{10}); - ret = SSL_write(ssl_, ptr, static_cast(handle_size)); - if (ret >= 0) { return ret; } - err = SSL_get_error(ssl_, ret); - } else { - break; - } - } - assert(ret < 0); - } - return ret; - } - return -1; -} - -void SSLSocketStream::get_remote_ip_and_port(std::string &ip, - int &port) const { - detail::get_remote_ip_and_port(sock_, ip, port); -} - -void SSLSocketStream::get_local_ip_and_port(std::string &ip, - int &port) const { - detail::get_local_ip_and_port(sock_, ip, port); -} - -socket_t SSLSocketStream::socket() const { return sock_; } - -time_t SSLSocketStream::duration() const { - return std::chrono::duration_cast( - std::chrono::steady_clock::now() - start_time_) - .count(); -} - -} // namespace detail - -// SSL HTTP server implementation -SSLServer::SSLServer(const char *cert_path, const char *private_key_path, - const char *client_ca_cert_file_path, - const char *client_ca_cert_dir_path, - const char *private_key_password) { - ctx_ = SSL_CTX_new(TLS_server_method()); - - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_NO_COMPRESSION | - SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - - SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); - - if (private_key_password != nullptr && (private_key_password[0] != '\0')) { - SSL_CTX_set_default_passwd_cb_userdata( - ctx_, - reinterpret_cast(const_cast(private_key_password))); - } - - if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != - 1 || - SSL_CTX_check_private_key(ctx_) != 1) { - last_ssl_error_ = static_cast(ERR_get_error()); - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { - SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, - client_ca_cert_dir_path); - - // Set client CA list to be sent to clients during TLS handshake - if (client_ca_cert_file_path) { - auto ca_list = SSL_load_client_CA_file(client_ca_cert_file_path); - if (ca_list != nullptr) { - SSL_CTX_set_client_CA_list(ctx_, ca_list); - } else { - // Failed to load client CA list, but we continue since - // SSL_CTX_load_verify_locations already succeeded and - // certificate verification will still work - last_ssl_error_ = static_cast(ERR_get_error()); - } - } - - SSL_CTX_set_verify( - ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); - } - } -} - -SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store) { - ctx_ = SSL_CTX_new(TLS_server_method()); - - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_NO_COMPRESSION | - SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - - SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); - - if (SSL_CTX_use_certificate(ctx_, cert) != 1 || - SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } else if (client_ca_cert_store) { - SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); - - // Extract CA names from the store and set them as the client CA list - auto ca_list = extract_ca_names_from_x509_store(client_ca_cert_store); - if (ca_list) { - SSL_CTX_set_client_CA_list(ctx_, ca_list); - } else { - // Failed to extract CA names, record the error - last_ssl_error_ = static_cast(ERR_get_error()); - } - - SSL_CTX_set_verify( - ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); - } - } -} - -SSLServer::SSLServer( - const std::function &setup_ssl_ctx_callback) { - ctx_ = SSL_CTX_new(TLS_method()); - if (ctx_) { - if (!setup_ssl_ctx_callback(*ctx_)) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } - } -} - -SSLServer::~SSLServer() { - if (ctx_) { SSL_CTX_free(ctx_); } -} - -bool SSLServer::is_valid() const { return ctx_; } - -SSL_CTX *SSLServer::ssl_context() const { return ctx_; } - -void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store) { - - std::lock_guard guard(ctx_mutex_); - - SSL_CTX_use_certificate(ctx_, cert); - SSL_CTX_use_PrivateKey(ctx_, private_key); - - if (client_ca_cert_store != nullptr) { - SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); - } -} - -bool SSLServer::process_and_close_socket(socket_t sock) { - auto ssl = detail::ssl_new( - sock, ctx_, ctx_mutex_, - [&](SSL *ssl2) { - return detail::ssl_connect_or_accept_nonblocking( - sock, ssl2, SSL_accept, read_timeout_sec_, read_timeout_usec_, - &last_ssl_error_); - }, - [](SSL * /*ssl2*/) { return true; }); - - auto ret = false; - if (ssl) { - std::string remote_addr; - int remote_port = 0; - detail::get_remote_ip_and_port(sock, remote_addr, remote_port); - - std::string local_addr; - int local_port = 0; - detail::get_local_ip_and_port(sock, local_addr, local_port); - - ret = detail::process_server_socket_ssl( - svr_sock_, ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, - read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, - [&](Stream &strm, bool close_connection, bool &connection_closed) { - return process_request(strm, remote_addr, remote_port, local_addr, - local_port, close_connection, - connection_closed, - [&](Request &req) { req.ssl = ssl; }); - }); - - // Shutdown gracefully if the result seemed successful, non-gracefully if - // the connection appeared to be closed. - const bool shutdown_gracefully = ret; - detail::ssl_delete(ctx_mutex_, ssl, sock, shutdown_gracefully); - } - - detail::shutdown_socket(sock); - detail::close_socket(sock); - return ret; -} - -STACK_OF(X509_NAME) * SSLServer::extract_ca_names_from_x509_store( - X509_STORE *store) { - if (!store) { return nullptr; } - - auto ca_list = sk_X509_NAME_new_null(); - if (!ca_list) { return nullptr; } - - // Get all objects from the store - auto objs = X509_STORE_get0_objects(store); - if (!objs) { - sk_X509_NAME_free(ca_list); - return nullptr; - } - - // Iterate through objects and extract certificate subject names - for (int i = 0; i < sk_X509_OBJECT_num(objs); i++) { - auto obj = sk_X509_OBJECT_value(objs, i); - if (X509_OBJECT_get_type(obj) == X509_LU_X509) { - auto cert = X509_OBJECT_get0_X509(obj); - if (cert) { - auto subject = X509_get_subject_name(cert); - if (subject) { - auto name_dup = X509_NAME_dup(subject); - if (name_dup) { sk_X509_NAME_push(ca_list, name_dup); } - } - } - } - } - - // If no names were extracted, free the list and return nullptr - if (sk_X509_NAME_num(ca_list) == 0) { - sk_X509_NAME_free(ca_list); - return nullptr; - } - - return ca_list; -} - -// SSL HTTP client implementation -SSLClient::SSLClient(const std::string &host) - : SSLClient(host, 443, std::string(), std::string()) {} - -SSLClient::SSLClient(const std::string &host, int port) - : SSLClient(host, port, std::string(), std::string()) {} - -SSLClient::SSLClient(const std::string &host, int port, - const std::string &client_cert_path, - const std::string &client_key_path, - const std::string &private_key_password) - : ClientImpl(host, port, client_cert_path, client_key_path) { - ctx_ = SSL_CTX_new(TLS_client_method()); - - SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); - - detail::split(&host_[0], &host_[host_.size()], '.', - [&](const char *b, const char *e) { - host_components_.emplace_back(b, e); - }); - - if (!client_cert_path.empty() && !client_key_path.empty()) { - if (!private_key_password.empty()) { - SSL_CTX_set_default_passwd_cb_userdata( - ctx_, reinterpret_cast( - const_cast(private_key_password.c_str()))); - } - - if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), - SSL_FILETYPE_PEM) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), - SSL_FILETYPE_PEM) != 1) { - last_openssl_error_ = ERR_get_error(); - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } - } -} - -SSLClient::SSLClient(const std::string &host, int port, - X509 *client_cert, EVP_PKEY *client_key, - const std::string &private_key_password) - : ClientImpl(host, port) { - ctx_ = SSL_CTX_new(TLS_client_method()); - - detail::split(&host_[0], &host_[host_.size()], '.', - [&](const char *b, const char *e) { - host_components_.emplace_back(b, e); - }); - - if (client_cert != nullptr && client_key != nullptr) { - if (!private_key_password.empty()) { - SSL_CTX_set_default_passwd_cb_userdata( - ctx_, reinterpret_cast( - const_cast(private_key_password.c_str()))); - } - - if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || - SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { - last_openssl_error_ = ERR_get_error(); - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } - } -} - -SSLClient::~SSLClient() { - if (ctx_) { SSL_CTX_free(ctx_); } - // Make sure to shut down SSL since shutdown_ssl will resolve to the - // base function rather than the derived function once we get to the - // base class destructor, and won't free the SSL (causing a leak). - shutdown_ssl_impl(socket_, true); -} - -bool SSLClient::is_valid() const { return ctx_; } - -void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { - if (ca_cert_store) { - if (ctx_) { - if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store) { - // Free memory allocated for old cert and use new store - // `ca_cert_store` - SSL_CTX_set_cert_store(ctx_, ca_cert_store); - ca_cert_store_ = ca_cert_store; - } - } else { - X509_STORE_free(ca_cert_store); - } - } -} - -void SSLClient::load_ca_cert_store(const char *ca_cert, - std::size_t size) { - set_ca_cert_store(ClientImpl::create_ca_cert_store(ca_cert, size)); -} - -long SSLClient::get_openssl_verify_result() const { - return verify_result_; -} - -SSL_CTX *SSLClient::ssl_context() const { return ctx_; } - -bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { - if (!is_valid()) { - error = Error::SSLConnection; - return false; - } - return ClientImpl::create_and_connect_socket(socket, error); -} - -// Assumes that socket_mutex_ is locked and that there are no requests in -// flight -bool SSLClient::connect_with_proxy( - Socket &socket, - std::chrono::time_point start_time, - Response &res, bool &success, Error &error) { - success = true; - Response proxy_res; - if (!detail::process_client_socket( - socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, - start_time, [&](Stream &strm) { - Request req2; - req2.method = "CONNECT"; - req2.path = host_and_port_; - if (max_timeout_msec_ > 0) { - req2.start_time_ = std::chrono::steady_clock::now(); - } - return process_request(strm, req2, proxy_res, false, error); - })) { - // Thread-safe to close everything because we are assuming there are no - // requests in flight - shutdown_ssl(socket, true); - shutdown_socket(socket); - close_socket(socket); - success = false; - return false; - } - - if (proxy_res.status == StatusCode::ProxyAuthenticationRequired_407) { - if (!proxy_digest_auth_username_.empty() && - !proxy_digest_auth_password_.empty()) { - std::map auth; - if (detail::parse_www_authenticate(proxy_res, auth, true)) { - // Close the current socket and create a new one for the authenticated - // request - shutdown_ssl(socket, true); - shutdown_socket(socket); - close_socket(socket); - - // Create a new socket for the authenticated CONNECT request - if (!create_and_connect_socket(socket, error)) { - success = false; - output_error_log(error, nullptr); - return false; - } - - proxy_res = Response(); - if (!detail::process_client_socket( - socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, - start_time, [&](Stream &strm) { - Request req3; - req3.method = "CONNECT"; - req3.path = host_and_port_; - req3.headers.insert(detail::make_digest_authentication_header( - req3, auth, 1, detail::random_string(10), - proxy_digest_auth_username_, proxy_digest_auth_password_, - true)); - if (max_timeout_msec_ > 0) { - req3.start_time_ = std::chrono::steady_clock::now(); - } - return process_request(strm, req3, proxy_res, false, error); - })) { - // Thread-safe to close everything because we are assuming there are - // no requests in flight - shutdown_ssl(socket, true); - shutdown_socket(socket); - close_socket(socket); - success = false; - return false; - } - } - } - } - - // If status code is not 200, proxy request is failed. - // Set error to ProxyConnection and return proxy response - // as the response of the request - if (proxy_res.status != StatusCode::OK_200) { - error = Error::ProxyConnection; - output_error_log(error, nullptr); - res = std::move(proxy_res); - // Thread-safe to close everything because we are assuming there are - // no requests in flight - shutdown_ssl(socket, true); - shutdown_socket(socket); - close_socket(socket); - return false; - } - - return true; -} - -bool SSLClient::load_certs() { - auto ret = true; - - std::call_once(initialize_cert_, [&]() { - std::lock_guard guard(ctx_mutex_); - if (!ca_cert_file_path_.empty()) { - if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), - nullptr)) { - last_openssl_error_ = ERR_get_error(); - ret = false; - } - } else if (!ca_cert_dir_path_.empty()) { - if (!SSL_CTX_load_verify_locations(ctx_, nullptr, - ca_cert_dir_path_.c_str())) { - last_openssl_error_ = ERR_get_error(); - ret = false; - } - } else { - auto loaded = false; -#ifdef _WIN32 - loaded = - detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); -#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && TARGET_OS_MAC - loaded = detail::load_system_certs_on_macos(SSL_CTX_get_cert_store(ctx_)); -#endif // _WIN32 - if (!loaded) { SSL_CTX_set_default_verify_paths(ctx_); } - } - }); - - return ret; -} - -bool SSLClient::initialize_ssl(Socket &socket, Error &error) { - auto ssl = detail::ssl_new( - socket.sock, ctx_, ctx_mutex_, - [&](SSL *ssl2) { - if (server_certificate_verification_) { - if (!load_certs()) { - error = Error::SSLLoadingCerts; - output_error_log(error, nullptr); - return false; - } - SSL_set_verify(ssl2, SSL_VERIFY_NONE, nullptr); - } - - if (!detail::ssl_connect_or_accept_nonblocking( - socket.sock, ssl2, SSL_connect, connection_timeout_sec_, - connection_timeout_usec_, &last_ssl_error_)) { - error = Error::SSLConnection; - output_error_log(error, nullptr); - return false; - } - - if (server_certificate_verification_) { - auto verification_status = SSLVerifierResponse::NoDecisionMade; - - if (server_certificate_verifier_) { - verification_status = server_certificate_verifier_(ssl2); - } - - if (verification_status == SSLVerifierResponse::CertificateRejected) { - last_openssl_error_ = ERR_get_error(); - error = Error::SSLServerVerification; - output_error_log(error, nullptr); - return false; - } - - if (verification_status == SSLVerifierResponse::NoDecisionMade) { - verify_result_ = SSL_get_verify_result(ssl2); - - if (verify_result_ != X509_V_OK) { - last_openssl_error_ = static_cast(verify_result_); - error = Error::SSLServerVerification; - output_error_log(error, nullptr); - return false; - } - - auto server_cert = SSL_get1_peer_certificate(ssl2); - auto se = detail::scope_exit([&] { X509_free(server_cert); }); - - if (server_cert == nullptr) { - last_openssl_error_ = ERR_get_error(); - error = Error::SSLServerVerification; - output_error_log(error, nullptr); - return false; - } - - if (server_hostname_verification_) { - if (!verify_host(server_cert)) { - last_openssl_error_ = X509_V_ERR_HOSTNAME_MISMATCH; - error = Error::SSLServerHostnameVerification; - output_error_log(error, nullptr); - return false; - } - } - } - } - - return true; - }, - [&](SSL *ssl2) { - // Set SNI only if host is not IP address - if (!detail::is_ip_address(host_)) { -#if defined(OPENSSL_IS_BORINGSSL) - SSL_set_tlsext_host_name(ssl2, host_.c_str()); -#else - // NOTE: Direct call instead of using the OpenSSL macro to suppress - // -Wold-style-cast warning - SSL_ctrl(ssl2, SSL_CTRL_SET_TLSEXT_HOSTNAME, - TLSEXT_NAMETYPE_host_name, - static_cast(const_cast(host_.c_str()))); -#endif - } - return true; - }); - - if (ssl) { - socket.ssl = ssl; - return true; - } - - if (ctx_ == nullptr) { - error = Error::SSLConnection; - last_openssl_error_ = ERR_get_error(); - } - - shutdown_socket(socket); - close_socket(socket); - return false; -} - -void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { - shutdown_ssl_impl(socket, shutdown_gracefully); -} - -void SSLClient::shutdown_ssl_impl(Socket &socket, - bool shutdown_gracefully) { - if (socket.sock == INVALID_SOCKET) { - assert(socket.ssl == nullptr); - return; - } - if (socket.ssl) { - detail::ssl_delete(ctx_mutex_, socket.ssl, socket.sock, - shutdown_gracefully); - socket.ssl = nullptr; - } - assert(socket.ssl == nullptr); -} - -bool SSLClient::process_socket( - const Socket &socket, - std::chrono::time_point start_time, - std::function callback) { - assert(socket.ssl); - return detail::process_client_socket_ssl( - socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, start_time, - std::move(callback)); -} - -bool SSLClient::is_ssl() const { return true; } - -bool SSLClient::verify_host(X509 *server_cert) const { - /* Quote from RFC2818 section 3.1 "Server Identity" - - If a subjectAltName extension of type dNSName is present, that MUST - be used as the identity. Otherwise, the (most specific) Common Name - field in the Subject field of the certificate MUST be used. Although - the use of the Common Name is existing practice, it is deprecated and - Certification Authorities are encouraged to use the dNSName instead. - - Matching is performed using the matching rules specified by - [RFC2459]. If more than one identity of a given type is present in - the certificate (e.g., more than one dNSName name, a match in any one - of the set is considered acceptable.) Names may contain the wildcard - character * which is considered to match any single domain name - component or component fragment. E.g., *.a.com matches foo.a.com but - not bar.foo.a.com. f*.com matches foo.com but not bar.com. - - In some cases, the URI is specified as an IP address rather than a - hostname. In this case, the iPAddress subjectAltName must be present - in the certificate and must exactly match the IP in the URI. - - */ - return verify_host_with_subject_alt_name(server_cert) || - verify_host_with_common_name(server_cert); -} - -bool -SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { - auto ret = false; - - auto type = GEN_DNS; - - struct in6_addr addr6 = {}; - struct in_addr addr = {}; - size_t addr_len = 0; - -#ifndef __MINGW32__ - if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { - type = GEN_IPADD; - addr_len = sizeof(struct in6_addr); - } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { - type = GEN_IPADD; - addr_len = sizeof(struct in_addr); +ClientConnection::~ClientConnection() { +#ifdef CPPHTTPLIB_SSL_ENABLED + if (session) { + tls::shutdown(session, true); + tls::free_session(session); + session = nullptr; } #endif - auto alt_names = static_cast( - X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); - - if (alt_names) { - auto dsn_matched = false; - auto ip_matched = false; - - auto count = sk_GENERAL_NAME_num(alt_names); - - for (decltype(count) i = 0; i < count && !dsn_matched; i++) { - auto val = sk_GENERAL_NAME_value(alt_names, i); - if (!val || val->type != type) { continue; } - - auto name = - reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); - if (name == nullptr) { continue; } - - auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); - - switch (type) { - case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; - - case GEN_IPADD: - if (!memcmp(&addr6, name, addr_len) || !memcmp(&addr, name, addr_len)) { - ip_matched = true; - } - break; - } - } - - if (dsn_matched || ip_matched) { ret = true; } + if (sock != INVALID_SOCKET) { + detail::close_socket(sock); + sock = INVALID_SOCKET; } - - GENERAL_NAMES_free(const_cast( - reinterpret_cast(alt_names))); - return ret; } -bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { - const auto subject_name = X509_get_subject_name(server_cert); - - if (subject_name != nullptr) { - char name[BUFSIZ]; - auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, - name, sizeof(name)); - - if (name_len != -1) { - return check_host_name(name, static_cast(name_len)); - } - } - - return false; -} - -bool SSLClient::check_host_name(const char *pattern, - size_t pattern_len) const { - if (host_.size() == pattern_len && host_ == pattern) { return true; } - - // Wildcard match - // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 - std::vector pattern_components; - detail::split(&pattern[0], &pattern[pattern_len], '.', - [&](const char *b, const char *e) { - pattern_components.emplace_back(b, e); - }); - - if (host_components_.size() != pattern_components.size()) { return false; } - - auto itr = pattern_components.begin(); - for (const auto &h : host_components_) { - auto &p = *itr; - if (p != h && p != "*") { - auto partial_match = (p.size() > 0 && p[p.size() - 1] == '*' && - !p.compare(0, p.size() - 1, h)); - if (!partial_match) { return false; } - } - ++itr; - } - - return true; -} -#endif - // Universal client implementation Client::Client(const std::string &scheme_host_port) : Client(scheme_host_port, std::string(), std::string()) {} @@ -8731,14 +11091,11 @@ Client::Client(const std::string &scheme_host_port) Client::Client(const std::string &scheme_host_port, const std::string &client_cert_path, const std::string &client_key_path) { - const static std::regex re( - R"((?:([a-z]+):\/\/)?(?:\[([a-fA-F\d:]+)\]|([^:/?#]+))(?::(\d+))?)"); + detail::UrlComponents uc; + if (detail::parse_url(scheme_host_port, uc) && !uc.host.empty()) { + auto &scheme = uc.scheme; - std::smatch m; - if (std::regex_match(scheme_host_port, m, re)) { - auto scheme = m[1].str(); - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if (!scheme.empty() && (scheme != "http" && scheme != "https")) { #else if (!scheme.empty() && scheme != "http") { @@ -8752,14 +11109,13 @@ Client::Client(const std::string &scheme_host_port, auto is_ssl = scheme == "https"; - auto host = m[2].str(); - if (host.empty()) { host = m[3].str(); } + auto host = std::move(uc.host); - auto port_str = m[4].str(); - auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); + auto port = is_ssl ? 443 : 80; + if (!uc.port.empty() && !detail::parse_port(uc.port, port)) { return; } if (is_ssl) { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED cli_ = detail::make_unique(host, port, client_cert_path, client_key_path); is_ssl_ = is_ssl; @@ -8774,10 +11130,10 @@ Client::Client(const std::string &scheme_host_port, cli_ = detail::make_unique(scheme_host_port, 80, client_cert_path, client_key_path); } -} // namespace detail +} Client::Client(const std::string &host, int port) - : cli_(detail::make_unique(host, port)) {} + : Client(host, port, std::string(), std::string()) {} Client::Client(const std::string &host, int port, const std::string &client_cert_path, @@ -8883,12 +11239,28 @@ Result Client::Post(const std::string &path, size_t content_length, return cli_->Post(path, content_length, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Post(path, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { return cli_->Post(path, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Post(path, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, @@ -8897,6 +11269,15 @@ Result Client::Post(const std::string &path, const Headers &headers, return cli_->Post(path, headers, content_length, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Post(path, headers, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -8904,6 +11285,14 @@ Result Client::Post(const std::string &path, const Headers &headers, return cli_->Post(path, headers, std::move(content_provider), content_type, progress); } +Result Client::Post(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + DownloadProgress progress) { + return cli_->Post(path, headers, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Post(const std::string &path, const Params ¶ms) { return cli_->Post(path, params); } @@ -8938,8 +11327,8 @@ Result Client::Post(const std::string &path, const Headers &headers, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress) { - return cli_->Post(path, headers, body, content_type, content_receiver, - progress); + return cli_->Post(path, headers, body, content_type, + std::move(content_receiver), progress); } Result Client::Put(const std::string &path) { return cli_->Put(path); } @@ -8976,12 +11365,28 @@ Result Client::Put(const std::string &path, size_t content_length, return cli_->Put(path, content_length, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { return cli_->Put(path, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, @@ -8990,6 +11395,15 @@ Result Client::Put(const std::string &path, const Headers &headers, return cli_->Put(path, headers, content_length, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, headers, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -8997,6 +11411,14 @@ Result Client::Put(const std::string &path, const Headers &headers, return cli_->Put(path, headers, std::move(content_provider), content_type, progress); } +Result Client::Put(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Put(path, headers, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Put(const std::string &path, const Params ¶ms) { return cli_->Put(path, params); } @@ -9072,12 +11494,28 @@ Result Client::Patch(const std::string &path, size_t content_length, return cli_->Patch(path, content_length, std::move(content_provider), content_type, progress); } +Result Client::Patch(const std::string &path, size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress) { return cli_->Patch(path, std::move(content_provider), content_type, progress); } +Result Client::Patch(const std::string &path, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, @@ -9086,6 +11524,15 @@ Result Client::Patch(const std::string &path, const Headers &headers, return cli_->Patch(path, headers, content_length, std::move(content_provider), content_type, progress); } +Result Client::Patch(const std::string &path, const Headers &headers, + size_t content_length, + ContentProvider content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, headers, content_length, std::move(content_provider), + content_type, std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, @@ -9093,6 +11540,14 @@ Result Client::Patch(const std::string &path, const Headers &headers, return cli_->Patch(path, headers, std::move(content_provider), content_type, progress); } +Result Client::Patch(const std::string &path, const Headers &headers, + ContentProviderWithoutLength content_provider, + const std::string &content_type, + ContentReceiver content_receiver, + UploadProgress progress) { + return cli_->Patch(path, headers, std::move(content_provider), content_type, + std::move(content_receiver), progress); +} Result Client::Patch(const std::string &path, const Params ¶ms) { return cli_->Patch(path, params); } @@ -9179,6 +11634,13 @@ Result Client::Options(const std::string &path, const Headers &headers) { return cli_->Options(path, headers); } +ClientImpl::StreamHandle +Client::open_stream(const std::string &method, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type) { + return cli_->open_stream(method, path, params, headers, body, content_type); +} + bool Client::send(Request &req, Response &res, Error &error) { return cli_->send(req, res, error); } @@ -9238,12 +11700,6 @@ void Client::set_basic_auth(const std::string &username, void Client::set_bearer_token_auth(const std::string &token) { cli_->set_bearer_token_auth(token); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -void Client::set_digest_auth(const std::string &username, - const std::string &password) { - cli_->set_digest_auth(username, password); -} -#endif void Client::set_keep_alive(bool on) { cli_->set_keep_alive(on); } void Client::set_follow_location(bool on) { @@ -9252,15 +11708,14 @@ void Client::set_follow_location(bool on) { void Client::set_path_encode(bool on) { cli_->set_path_encode(on); } -[[deprecated("Use set_path_encode instead")]] -void Client::set_url_encode(bool on) { - cli_->set_path_encode(on); -} - void Client::set_compress(bool on) { cli_->set_compress(on); } void Client::set_decompress(bool on) { cli_->set_decompress(on); } +void Client::set_payload_max_length(size_t length) { + cli_->set_payload_max_length(length); +} + void Client::set_interface(const std::string &intf) { cli_->set_interface(intf); } @@ -9275,27 +11730,9 @@ void Client::set_proxy_basic_auth(const std::string &username, void Client::set_proxy_bearer_token_auth(const std::string &token) { cli_->set_proxy_bearer_token_auth(token); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -void Client::set_proxy_digest_auth(const std::string &username, - const std::string &password) { - cli_->set_proxy_digest_auth(username, password); +void Client::set_no_proxy(const std::vector &patterns) { + cli_->set_no_proxy(patterns); } -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -void Client::enable_server_certificate_verification(bool enabled) { - cli_->enable_server_certificate_verification(enabled); -} - -void Client::enable_server_hostname_verification(bool enabled) { - cli_->enable_server_hostname_verification(enabled); -} - -void Client::set_server_certificate_verifier( - std::function verifier) { - cli_->set_server_certificate_verifier(verifier); -} -#endif void Client::set_logger(Logger logger) { cli_->set_logger(std::move(logger)); @@ -9305,35 +11742,4769 @@ void Client::set_error_logger(ErrorLogger error_logger) { cli_->set_error_logger(std::move(error_logger)); } +/* + * Group 6: SSL Server and Client implementation + */ + +#ifdef CPPHTTPLIB_SSL_ENABLED + +// SSL HTTP server implementation +SSLServer::SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path, + const char *private_key_password) { + using namespace tls; + + ctx_ = create_server_context(); + if (!ctx_) { return; } + + // Load server certificate and private key + if (!set_server_cert_file(ctx_, cert_path, private_key_path, + private_key_password)) { + last_ssl_error_ = static_cast(get_error()); + free_context(ctx_); + ctx_ = nullptr; + return; + } + + // Load client CA certificates for client authentication + if (client_ca_cert_file_path || client_ca_cert_dir_path) { + if (!set_client_ca_file(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path)) { + last_ssl_error_ = static_cast(get_error()); + free_context(ctx_); + ctx_ = nullptr; + return; + } + // Enable client certificate verification + set_verify_client(ctx_, true); + } +} + +SSLServer::SSLServer(const PemMemory &pem) { + using namespace tls; + ctx_ = create_server_context(); + if (ctx_) { + if (!set_server_cert_pem(ctx_, pem.cert_pem, pem.key_pem, + pem.private_key_password)) { + last_ssl_error_ = static_cast(get_error()); + free_context(ctx_); + ctx_ = nullptr; + } else if (pem.client_ca_pem && pem.client_ca_pem_len > 0) { + if (!load_ca_pem(ctx_, pem.client_ca_pem, pem.client_ca_pem_len)) { + last_ssl_error_ = static_cast(get_error()); + free_context(ctx_); + ctx_ = nullptr; + } else { + set_verify_client(ctx_, true); + } + } + } +} + +SSLServer::SSLServer(const tls::ContextSetupCallback &setup_callback) { + using namespace tls; + ctx_ = create_server_context(); + if (ctx_) { + if (!setup_callback(ctx_)) { + free_context(ctx_); + ctx_ = nullptr; + } + } +} + +SSLServer::~SSLServer() { + if (ctx_) { tls::free_context(ctx_); } +} + +bool SSLServer::is_valid() const { return ctx_ != nullptr; } + +bool SSLServer::process_and_close_socket(socket_t sock) { + using namespace tls; + + // Create TLS session with mutex protection + session_t session = nullptr; + { + std::lock_guard guard(ctx_mutex_); + session = create_session(static_cast(ctx_), sock); + } + + if (!session) { + last_ssl_error_ = static_cast(get_error()); + detail::shutdown_socket(sock); + detail::close_socket(sock); + return false; + } + + // Use scope_exit to ensure cleanup on all paths (including exceptions) + bool handshake_done = false; + bool ret = false; + bool websocket_upgraded = false; + auto cleanup = detail::scope_exit([&] { + if (handshake_done) { shutdown(session, !websocket_upgraded && ret); } + free_session(session); + detail::shutdown_socket(sock); + detail::close_socket(sock); + }); + + // Perform TLS accept handshake with timeout + TlsError tls_err; + if (!accept_nonblocking(session, sock, read_timeout_sec_, read_timeout_usec_, + &tls_err)) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // Map TlsError to legacy ssl_error for backward compatibility + if (tls_err.code == ErrorCode::WantRead) { + last_ssl_error_ = SSL_ERROR_WANT_READ; + } else if (tls_err.code == ErrorCode::WantWrite) { + last_ssl_error_ = SSL_ERROR_WANT_WRITE; + } else { + last_ssl_error_ = SSL_ERROR_SSL; + } +#else + last_ssl_error_ = static_cast(get_error()); +#endif + return false; + } + + handshake_done = true; + + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + + ret = detail::process_server_socket_ssl( + svr_sock_, session, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request( + strm, remote_addr, remote_port, local_addr, local_port, + close_connection, connection_closed, + [&](Request &req) { req.ssl = session; }, &websocket_upgraded); + }); + + return ret; +} + +bool SSLServer::update_certs_pem(const char *cert_pem, + const char *key_pem, + const char *client_ca_pem, + const char *password) { + if (!ctx_) { return false; } + std::lock_guard guard(ctx_mutex_); + if (!tls::update_server_cert(ctx_, cert_pem, key_pem, password)) { + return false; + } + if (client_ca_pem) { + return tls::update_server_client_ca(ctx_, client_ca_pem); + } + return true; +} + +// SSL HTTP client implementation +SSLClient::~SSLClient() { + if (ctx_) { tls::free_context(ctx_); } + // Make sure to shut down SSL since shutdown_ssl will resolve to the + // base function rather than the derived function once we get to the + // base class destructor, and won't free the SSL (causing a leak). + shutdown_ssl_impl(socket_, true); +} + +bool SSLClient::is_valid() const { return ctx_ != nullptr; } + +void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { + shutdown_ssl_impl(socket, shutdown_gracefully); +} + +void SSLClient::shutdown_ssl_impl(Socket &socket, + bool shutdown_gracefully) { + if (socket.sock == INVALID_SOCKET) { + assert(socket.ssl == nullptr); + return; + } + if (socket.ssl) { + tls::shutdown(socket.ssl, shutdown_gracefully); + { + std::lock_guard guard(ctx_mutex_); + tls::free_session(socket.ssl); + } + socket.ssl = nullptr; + } + assert(socket.ssl == nullptr); +} + +bool SSLClient::process_socket( + const Socket &socket, + std::chrono::time_point start_time, + std::function callback) { + assert(socket.ssl); + return detail::process_client_socket_ssl( + socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, start_time, + std::move(callback)); +} + +bool SSLClient::is_ssl() const { return true; } + +bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { + if (!is_valid()) { + error = Error::SSLConnection; + return false; + } + return ClientImpl::create_and_connect_socket(socket, error); +} + +bool SSLClient::setup_proxy_connection( + Socket &socket, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error) { + if (!is_proxy_enabled_for_host(host_)) { return true; } + + if (!connect_with_proxy(socket, start_time, res, success, error)) { + return false; + } + + if (!initialize_ssl(socket, error)) { + success = false; + return false; + } + + return true; +} + +// Assumes that socket_mutex_ is locked and that there are no requests in +// flight +bool SSLClient::connect_with_proxy( + Socket &socket, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error) { + success = true; + Response proxy_res; + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, + start_time, [&](Stream &strm) { + Request req2; + req2.method = "CONNECT"; + req2.path = + detail::make_host_and_port_string_always_port(host_, port_); + if (max_timeout_msec_ > 0) { + req2.start_time_ = std::chrono::steady_clock::now(); + } + return process_request(strm, req2, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are no + // requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + + if (proxy_res.status == StatusCode::ProxyAuthenticationRequired_407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (detail::parse_www_authenticate(proxy_res, auth, true)) { + // Close the current socket and create a new one for the authenticated + // request + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + + // Create a new socket for the authenticated CONNECT request + if (!ensure_socket_connection(socket, error)) { + success = false; + output_error_log(error, nullptr); + return false; + } + + proxy_res = Response(); + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, + start_time, [&](Stream &strm) { + Request req3; + req3.method = "CONNECT"; + req3.path = detail::make_host_and_port_string_always_port( + host_, port_); + req3.headers.insert(detail::make_digest_authentication_header( + req3, auth, 1, detail::random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + if (max_timeout_msec_ > 0) { + req3.start_time_ = std::chrono::steady_clock::now(); + } + return process_request(strm, req3, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + } + } + } + + // If status code is not 200, proxy request is failed. + // Set error to ProxyConnection and return proxy response + // as the response of the request + if (proxy_res.status != StatusCode::OK_200) { + error = Error::ProxyConnection; + output_error_log(error, nullptr); + res = std::move(proxy_res); + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} + +bool SSLClient::ensure_socket_connection(Socket &socket, Error &error) { + if (!ClientImpl::ensure_socket_connection(socket, error)) { return false; } + + if (is_proxy_enabled_for_host(host_)) { return true; } + + if (!initialize_ssl(socket, error)) { + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} + +// SSL HTTP client implementation +SSLClient::SSLClient(const std::string &host) + : SSLClient(host, 443, std::string(), std::string()) {} + +SSLClient::SSLClient(const std::string &host, int port) + : SSLClient(host, port, std::string(), std::string()) {} + +void SSLClient::init_ctx() { + ctx_ = tls::create_client_context(); + if (ctx_) { tls::set_min_version(ctx_, tls::Version::TLS1_2); } +} + +void SSLClient::reset_ctx_on_error() { + last_backend_error_ = tls::get_error(); + tls::free_context(ctx_); + ctx_ = nullptr; +} + +SSLClient::SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path, + const std::string &private_key_password) + : ClientImpl(host, port, client_cert_path, client_key_path) { + init_ctx(); + if (!ctx_) { return; } + + if (!client_cert_path.empty() && !client_key_path.empty()) { + const char *password = + private_key_password.empty() ? nullptr : private_key_password.c_str(); + if (!tls::set_client_cert_file(ctx_, client_cert_path.c_str(), + client_key_path.c_str(), password)) { + reset_ctx_on_error(); + } + } +} + +SSLClient::SSLClient(const std::string &host, int port, + const PemMemory &pem) + : ClientImpl(host, port) { + init_ctx(); + if (!ctx_) { return; } + + if (pem.cert_pem && pem.key_pem) { + if (!tls::set_client_cert_pem(ctx_, pem.cert_pem, pem.key_pem, + pem.private_key_password)) { + reset_ctx_on_error(); + } + } +} + +void SSLClient::set_ca_cert_store(tls::ca_store_t ca_cert_store) { + if (ca_cert_store && ctx_) { + // set_ca_store takes ownership of ca_cert_store + tls::set_ca_store(ctx_, ca_cert_store); + } else if (ca_cert_store) { + tls::free_ca_store(ca_cert_store); + } +} + +void +SSLClient::set_server_certificate_verifier(tls::VerifyCallback verifier) { + if (!ctx_) { return; } + tls::set_verify_callback(ctx_, verifier); +} + +void SSLClient::set_session_verifier( + std::function verifier) { + session_verifier_ = std::move(verifier); +} + +#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE +void SSLClient::enable_windows_certificate_verification(bool enabled) { + enable_windows_cert_verification_ = enabled; +} +#endif + +void SSLClient::load_ca_cert_store(const char *ca_cert, + std::size_t size) { + if (ctx_ && ca_cert && size > 0) { + ca_cert_pem_.assign(ca_cert, size); // Store for redirect transfer + tls::load_ca_pem(ctx_, ca_cert, size); + } +} + +bool SSLClient::load_certs() { + auto ret = true; + + std::call_once(initialize_cert_, [&]() { + std::lock_guard guard(ctx_mutex_); + + if (!ca_cert_file_path_.empty()) { + if (!tls::load_ca_file(ctx_, ca_cert_file_path_.c_str())) { + last_backend_error_ = tls::get_error(); + ret = false; + } + } else if (!ca_cert_dir_path_.empty()) { + if (!tls::load_ca_dir(ctx_, ca_cert_dir_path_.c_str())) { + last_backend_error_ = tls::get_error(); + ret = false; + } + } else if (ca_cert_pem_.empty()) { + if (!tls::load_system_certs(ctx_)) { + last_backend_error_ = tls::get_error(); + } + } + }); + + return ret; +} + +bool SSLClient::initialize_ssl(Socket &socket, Error &error) { + using namespace tls; + + // Load CA certificates if server verification is enabled + if (server_certificate_verification_) { + if (!load_certs()) { + error = Error::SSLLoadingCerts; + output_error_log(error, nullptr); + return false; + } + } + + bool is_ip = detail::is_ip_address(host_); + +#if defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || defined(CPPHTTPLIB_WOLFSSL_SUPPORT) + // MbedTLS/wolfSSL need explicit verification mode (OpenSSL uses + // SSL_VERIFY_NONE by default and performs all verification post-handshake). + // For IP addresses with verification enabled, use OPTIONAL mode since + // these backends require hostname for strict verification. + if (is_ip && server_certificate_verification_) { + set_verify_client(ctx_, false); + } else { + set_verify_client(ctx_, server_certificate_verification_); + } +#endif + + // Create TLS session + session_t session = nullptr; + { + std::lock_guard guard(ctx_mutex_); + session = create_session(ctx_, socket.sock); + } + + if (!session) { + error = Error::SSLConnection; + last_backend_error_ = get_error(); + return false; + } + + // Use scope_exit to ensure session is freed on error paths + bool success = false; + auto session_guard = detail::scope_exit([&] { + if (!success) { free_session(session); } + }); + + // Set SNI extension (skip for IP addresses per RFC 6066). + // On MbedTLS, set_sni also enables hostname verification internally. + // On OpenSSL, set_sni only sets SNI; verification is done post-handshake. + if (!is_ip) { + if (!set_sni(session, host_.c_str())) { + error = Error::SSLConnection; + last_backend_error_ = get_error(); + return false; + } + } + + // Perform non-blocking TLS handshake with timeout + TlsError tls_err; + if (!connect_nonblocking(session, socket.sock, connection_timeout_sec_, + connection_timeout_usec_, &tls_err)) { + last_ssl_error_ = static_cast(tls_err.code); + last_backend_error_ = tls_err.backend_code; + if (tls_err.code == ErrorCode::CertVerifyFailed) { + error = Error::SSLServerVerification; + } else if (tls_err.code == ErrorCode::HostnameMismatch) { + error = Error::SSLServerHostnameVerification; + } else { + error = Error::SSLConnection; + } + output_error_log(error, nullptr); + return false; + } + + // Post-handshake session verifier callback + auto verification_status = SSLVerifierResponse::NoDecisionMade; + if (session_verifier_) { verification_status = session_verifier_(session); } + + if (verification_status == SSLVerifierResponse::CertificateRejected) { + last_backend_error_ = get_error(); + error = Error::SSLServerVerification; + output_error_log(error, nullptr); + return false; + } + + // Default server certificate verification + if (verification_status == SSLVerifierResponse::NoDecisionMade && + server_certificate_verification_) { + verify_result_ = tls::get_verify_result(session); + if (verify_result_ != 0) { + last_backend_error_ = static_cast(verify_result_); + error = Error::SSLServerVerification; + output_error_log(error, nullptr); + return false; + } + + auto server_cert = get_peer_cert(session); + if (!server_cert) { + last_backend_error_ = get_error(); + error = Error::SSLServerVerification; + output_error_log(error, nullptr); + return false; + } + auto cert_guard = detail::scope_exit([&] { free_cert(server_cert); }); + + // Hostname verification (post-handshake for all cases). + // On OpenSSL, verification is always post-handshake (SSL_VERIFY_NONE). + // On MbedTLS, set_sni already enabled hostname verification during + // handshake for non-IP hosts, but this check is still needed for IP + // addresses where SNI is not set. + if (server_hostname_verification_) { + if (!verify_hostname(server_cert, host_.c_str())) { + last_backend_error_ = hostname_mismatch_code(); + error = Error::SSLServerHostnameVerification; + output_error_log(error, nullptr); + return false; + } + } + +#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE + // Additional Windows Schannel verification. + // This provides real-time certificate validation with Windows Update + // integration, working with both OpenSSL and MbedTLS backends. + // Skip when a custom CA cert is specified, as the Windows certificate + // store would not know about user-provided CA certificates. + if (enable_windows_cert_verification_ && ca_cert_file_path_.empty() && + ca_cert_dir_path_.empty() && ca_cert_pem_.empty()) { + std::vector der; + if (get_cert_der(server_cert, der)) { + uint64_t wincrypt_error = 0; + if (!detail::verify_cert_with_windows_schannel( + der, host_, server_hostname_verification_, wincrypt_error)) { + last_backend_error_ = wincrypt_error; + error = Error::SSLServerVerification; + output_error_log(error, nullptr); + return false; + } + } + } +#endif + } + + success = true; + socket.ssl = session; + return true; +} + +void Client::set_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_digest_auth(username, password); +} + +void Client::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_proxy_digest_auth(username, password); +} + +void Client::enable_server_certificate_verification(bool enabled) { + cli_->enable_server_certificate_verification(enabled); +} + +void Client::enable_server_hostname_verification(bool enabled) { + cli_->enable_server_hostname_verification(enabled); +} + +#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE +void Client::enable_windows_certificate_verification(bool enabled) { + if (is_ssl_) { + static_cast(*cli_).enable_windows_certificate_verification( + enabled); + } +} +#endif + void Client::set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path) { cli_->set_ca_cert_path(ca_cert_file_path, ca_cert_dir_path); } -void Client::set_ca_cert_store(X509_STORE *ca_cert_store) { +void Client::set_ca_cert_store(tls::ca_store_t ca_cert_store) { if (is_ssl_) { static_cast(*cli_).set_ca_cert_store(ca_cert_store); - } else { - cli_->set_ca_cert_store(ca_cert_store); + } else if (ca_cert_store) { + tls::free_ca_store(ca_cert_store); } } void Client::load_ca_cert_store(const char *ca_cert, std::size_t size) { - set_ca_cert_store(cli_->create_ca_cert_store(ca_cert, size)); + set_ca_cert_store(tls::create_ca_store(ca_cert, size)); } -long Client::get_openssl_verify_result() const { +void +Client::set_server_certificate_verifier(tls::VerifyCallback verifier) { if (is_ssl_) { - return static_cast(*cli_).get_openssl_verify_result(); + static_cast(*cli_).set_server_certificate_verifier( + std::move(verifier)); } - return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? } -SSL_CTX *Client::ssl_context() const { - if (is_ssl_) { return static_cast(*cli_).ssl_context(); } +void Client::set_session_verifier( + std::function verifier) { + if (is_ssl_) { + static_cast(*cli_).set_session_verifier(std::move(verifier)); + } +} + +tls::ctx_t Client::tls_context() const { + if (is_ssl_) { return static_cast(*cli_).tls_context(); } return nullptr; } + +#endif // CPPHTTPLIB_SSL_ENABLED + +/* + * Group 7: TLS abstraction layer - Common API + */ + +#ifdef CPPHTTPLIB_SSL_ENABLED + +namespace tls { + +// Helper for PeerCert construction +PeerCert get_peer_cert_from_session(const_session_t session) { + return PeerCert(get_peer_cert(session)); +} + +namespace impl { + +VerifyCallback &get_verify_callback() { + static thread_local VerifyCallback callback; + return callback; +} + +VerifyCallback &get_mbedtls_verify_callback() { + static thread_local VerifyCallback callback; + return callback; +} + +// Check if a string is an IPv4 address +bool is_ipv4_address(const std::string &str) { + int dots = 0; + for (char c : str) { + if (c == '.') { + dots++; + } else if (!isdigit(static_cast(c))) { + return false; + } + } + return dots == 3; +} + +// Parse IPv4 address string to bytes +bool parse_ipv4(const std::string &str, unsigned char *out) { + const char *p = str.c_str(); + for (int i = 0; i < 4; i++) { + if (i > 0) { + if (*p != '.') { return false; } + p++; + } + int val = 0; + int digits = 0; + while (*p >= '0' && *p <= '9') { + val = val * 10 + (*p - '0'); + if (val > 255) { return false; } + p++; + digits++; + } + if (digits == 0) { return false; } + // Reject leading zeros (e.g., "01.002.03.04") to prevent ambiguity + if (digits > 1 && *(p - digits) == '0') { return false; } + out[i] = static_cast(val); + } + return *p == '\0'; +} + +#ifdef _WIN32 +// Enumerate Windows system certificates and call callback with DER data +template +bool enumerate_windows_system_certs(Callback cb) { + bool loaded = false; + static const wchar_t *store_names[] = {L"ROOT", L"CA"}; + for (auto store_name : store_names) { + HCERTSTORE hStore = CertOpenSystemStoreW(0, store_name); + if (hStore) { + PCCERT_CONTEXT pContext = nullptr; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != + nullptr) { + if (cb(pContext->pbCertEncoded, pContext->cbCertEncoded)) { + loaded = true; + } + } + CertCloseStore(hStore, 0); + } + } + return loaded; +} #endif +#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN +// Enumerate macOS Keychain certificates and call callback with DER data +template +bool enumerate_macos_keychain_certs(Callback cb) { + bool loaded = false; + const SecTrustSettingsDomain domains[] = { + kSecTrustSettingsDomainSystem, + kSecTrustSettingsDomainAdmin, + kSecTrustSettingsDomainUser, + }; + for (auto domain : domains) { + CFArrayRef certs = nullptr; + OSStatus status = SecTrustSettingsCopyCertificates(domain, &certs); + if (status != errSecSuccess || !certs) { + if (certs) CFRelease(certs); + continue; + } + CFIndex count = CFArrayGetCount(certs); + for (CFIndex i = 0; i < count; i++) { + SecCertificateRef cert = + (SecCertificateRef)CFArrayGetValueAtIndex(certs, i); + CFDataRef data = SecCertificateCopyData(cert); + if (data) { + if (cb(CFDataGetBytePtr(data), + static_cast(CFDataGetLength(data)))) { + loaded = true; + } + CFRelease(data); + } + } + CFRelease(certs); + } + return loaded; +} +#endif + +#if !defined(_WIN32) && !(defined(__APPLE__) && \ + defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN)) +// Common CA certificate file paths on Linux/Unix +const char **system_ca_paths() { + static const char *paths[] = { + "/etc/ssl/certs/ca-certificates.crt", // Debian/Ubuntu + "/etc/pki/tls/certs/ca-bundle.crt", // RHEL/CentOS + "/etc/ssl/ca-bundle.pem", // OpenSUSE + "/etc/pki/tls/cacert.pem", // OpenELEC + "/etc/ssl/cert.pem", // Alpine, FreeBSD + nullptr}; + return paths; +} + +// Common CA certificate directory paths on Linux/Unix +const char **system_ca_dirs() { + static const char *dirs[] = {"/etc/ssl/certs", // Debian/Ubuntu + "/etc/pki/tls/certs", // RHEL/CentOS + "/usr/share/ca-certificates", // Other + nullptr}; + return dirs; +} +#endif + +} // namespace impl + +bool set_client_ca_file(ctx_t ctx, const char *ca_file, + const char *ca_dir) { + if (!ctx) { return false; } + + bool success = true; + if (ca_file && *ca_file) { + if (!load_ca_file(ctx, ca_file)) { success = false; } + } + if (ca_dir && *ca_dir) { + if (!load_ca_dir(ctx, ca_dir)) { success = false; } + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // Set CA list for client certificate request (CertificateRequest message) + if (ca_file && *ca_file) { + auto list = SSL_load_client_CA_file(ca_file); + if (list) { SSL_CTX_set_client_CA_list(static_cast(ctx), list); } + } +#endif + + return success; +} + +bool set_server_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password) { + return set_client_cert_pem(ctx, cert, key, password); +} + +bool set_server_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password) { + return set_client_cert_file(ctx, cert_path, key_path, password); +} + +// PeerCert implementation +PeerCert::PeerCert() = default; + +PeerCert::PeerCert(cert_t cert) : cert_(cert) {} + +PeerCert::PeerCert(PeerCert &&other) noexcept : cert_(other.cert_) { + other.cert_ = nullptr; +} + +PeerCert &PeerCert::operator=(PeerCert &&other) noexcept { + if (this != &other) { + if (cert_) { free_cert(cert_); } + cert_ = other.cert_; + other.cert_ = nullptr; + } + return *this; +} + +PeerCert::~PeerCert() { + if (cert_) { free_cert(cert_); } +} + +PeerCert::operator bool() const { return cert_ != nullptr; } + +std::string PeerCert::subject_cn() const { + return cert_ ? get_cert_subject_cn(cert_) : std::string(); +} + +std::string PeerCert::issuer_name() const { + return cert_ ? get_cert_issuer_name(cert_) : std::string(); +} + +bool PeerCert::check_hostname(const char *hostname) const { + return cert_ ? verify_hostname(cert_, hostname) : false; +} + +std::vector PeerCert::sans() const { + std::vector result; + if (cert_) { get_cert_sans(cert_, result); } + return result; +} + +bool PeerCert::validity(time_t ¬_before, time_t ¬_after) const { + return cert_ ? get_cert_validity(cert_, not_before, not_after) : false; +} + +std::string PeerCert::serial() const { + return cert_ ? get_cert_serial(cert_) : std::string(); +} + +// VerifyContext method implementations +std::string VerifyContext::subject_cn() const { + return cert ? get_cert_subject_cn(cert) : std::string(); +} + +std::string VerifyContext::issuer_name() const { + return cert ? get_cert_issuer_name(cert) : std::string(); +} + +bool VerifyContext::check_hostname(const char *hostname) const { + return cert ? verify_hostname(cert, hostname) : false; +} + +std::vector VerifyContext::sans() const { + std::vector result; + if (cert) { get_cert_sans(cert, result); } + return result; +} + +bool VerifyContext::validity(time_t ¬_before, + time_t ¬_after) const { + return cert ? get_cert_validity(cert, not_before, not_after) : false; +} + +std::string VerifyContext::serial() const { + return cert ? get_cert_serial(cert) : std::string(); +} + +// TlsError static method implementation +std::string TlsError::verify_error_to_string(long error_code) { + return verify_error_string(error_code); +} + +} // namespace tls + +// Request::peer_cert() implementation +tls::PeerCert Request::peer_cert() const { + return tls::get_peer_cert_from_session(ssl); +} + +// Request::sni() implementation +std::string Request::sni() const { + if (!ssl) { return std::string(); } + const char *s = tls::get_sni(ssl); + return s ? std::string(s) : std::string(); +} + +#endif // CPPHTTPLIB_SSL_ENABLED + +/* + * Group 8: TLS abstraction layer - OpenSSL backend + */ + +/* + * OpenSSL Backend Implementation + */ + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +namespace tls { + +namespace impl { + +// Helper to map OpenSSL SSL_get_error to ErrorCode +ErrorCode map_ssl_error(int ssl_error, int &out_errno) { + switch (ssl_error) { + case SSL_ERROR_NONE: return ErrorCode::Success; + case SSL_ERROR_WANT_READ: return ErrorCode::WantRead; + case SSL_ERROR_WANT_WRITE: return ErrorCode::WantWrite; + case SSL_ERROR_ZERO_RETURN: return ErrorCode::PeerClosed; + case SSL_ERROR_SYSCALL: out_errno = errno; return ErrorCode::SyscallError; + case SSL_ERROR_SSL: + default: return ErrorCode::Fatal; + } +} + +// Helper: Create client CA list from PEM string +// Returns a new STACK_OF(X509_NAME)* or nullptr on failure +// Caller takes ownership of returned list +STACK_OF(X509_NAME) * + create_client_ca_list_from_pem(const char *ca_pem) { + if (!ca_pem) { return nullptr; } + + auto ca_list = sk_X509_NAME_new_null(); + if (!ca_list) { return nullptr; } + + BIO *bio = BIO_new_mem_buf(ca_pem, -1); + if (!bio) { + sk_X509_NAME_pop_free(ca_list, X509_NAME_free); + return nullptr; + } + + X509 *cert = nullptr; + while ((cert = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr)) != + nullptr) { + const X509_NAME *name = X509_get_subject_name(cert); + if (name) { + sk_X509_NAME_push(ca_list, X509_NAME_dup(const_cast(name))); + } + X509_free(cert); + } + BIO_free(bio); + + return ca_list; +} + +// OpenSSL verify callback wrapper +int openssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) { + auto &callback = get_verify_callback(); + if (!callback) { return preverify_ok; } + + // Get SSL object from X509_STORE_CTX + auto ssl = static_cast( + X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx())); + if (!ssl) { return preverify_ok; } + + // Get current certificate and depth + auto cert = X509_STORE_CTX_get_current_cert(ctx); + int depth = X509_STORE_CTX_get_error_depth(ctx); + int error = X509_STORE_CTX_get_error(ctx); + + // Build context + VerifyContext verify_ctx; + verify_ctx.session = static_cast(ssl); + verify_ctx.cert = static_cast(cert); + verify_ctx.depth = depth; + verify_ctx.preverify_ok = (preverify_ok != 0); + verify_ctx.error_code = error; + verify_ctx.error_string = + (error != X509_V_OK) ? X509_verify_cert_error_string(error) : nullptr; + + return callback(verify_ctx) ? 1 : 0; +} + +} // namespace impl + +ctx_t create_client_context() { + SSL_CTX *ctx = SSL_CTX_new(TLS_client_method()); + if (ctx) { + // Disable auto-retry to properly handle non-blocking I/O + SSL_CTX_clear_mode(ctx, SSL_MODE_AUTO_RETRY); + // Set minimum TLS version + SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION); + } + return static_cast(ctx); +} + +void free_context(ctx_t ctx) { + if (ctx) { SSL_CTX_free(static_cast(ctx)); } +} + +bool set_min_version(ctx_t ctx, Version version) { + if (!ctx) return false; + return SSL_CTX_set_min_proto_version(static_cast(ctx), + static_cast(version)) == 1; +} + +bool load_ca_pem(ctx_t ctx, const char *pem, size_t len) { + if (!ctx || !pem || len == 0) return false; + + auto ssl_ctx = static_cast(ctx); + auto store = SSL_CTX_get_cert_store(ssl_ctx); + if (!store) return false; + + auto bio = BIO_new_mem_buf(pem, static_cast(len)); + if (!bio) return false; + + bool ok = true; + X509 *cert = nullptr; + while ((cert = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr)) != + nullptr) { + if (X509_STORE_add_cert(store, cert) != 1) { + // Ignore duplicate errors + auto err = ERR_peek_last_error(); + if (ERR_GET_REASON(err) != X509_R_CERT_ALREADY_IN_HASH_TABLE) { + ok = false; + } + } + X509_free(cert); + if (!ok) break; + } + BIO_free(bio); + + // Clear any "no more certificates" errors + ERR_clear_error(); + return ok; +} + +bool load_ca_file(ctx_t ctx, const char *file_path) { + if (!ctx || !file_path) return false; + return SSL_CTX_load_verify_locations(static_cast(ctx), file_path, + nullptr) == 1; +} + +bool load_ca_dir(ctx_t ctx, const char *dir_path) { + if (!ctx || !dir_path) return false; + return SSL_CTX_load_verify_locations(static_cast(ctx), nullptr, + dir_path) == 1; +} + +bool load_system_certs(ctx_t ctx) { + if (!ctx) return false; + auto ssl_ctx = static_cast(ctx); + +#ifdef _WIN32 + // Windows: Load from system certificate store (ROOT and CA) + auto store = SSL_CTX_get_cert_store(ssl_ctx); + if (!store) return false; + + bool loaded_any = false; + static const wchar_t *store_names[] = {L"ROOT", L"CA"}; + for (auto store_name : store_names) { + auto hStore = CertOpenSystemStoreW(NULL, store_name); + if (!hStore) continue; + + PCCERT_CONTEXT pContext = nullptr; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != + nullptr) { + const unsigned char *data = pContext->pbCertEncoded; + auto x509 = d2i_X509(nullptr, &data, pContext->cbCertEncoded); + if (x509) { + if (X509_STORE_add_cert(store, x509) == 1) { loaded_any = true; } + X509_free(x509); + } + } + CertCloseStore(hStore, 0); + } + return loaded_any; + +#elif defined(__APPLE__) +#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN + // macOS: Load from Keychain + auto store = SSL_CTX_get_cert_store(ssl_ctx); + if (!store) return false; + + bool loaded_any = false; + const SecTrustSettingsDomain domains[] = { + kSecTrustSettingsDomainSystem, + kSecTrustSettingsDomainAdmin, + kSecTrustSettingsDomainUser, + }; + for (auto domain : domains) { + CFArrayRef certs = nullptr; + if (SecTrustSettingsCopyCertificates(domain, &certs) != errSecSuccess || + !certs) { + if (certs) CFRelease(certs); + continue; + } + auto count = CFArrayGetCount(certs); + for (CFIndex i = 0; i < count; i++) { + auto cert = reinterpret_cast( + const_cast(CFArrayGetValueAtIndex(certs, i))); + CFDataRef der = SecCertificateCopyData(cert); + if (der) { + const unsigned char *data = CFDataGetBytePtr(der); + auto x509 = d2i_X509(nullptr, &data, CFDataGetLength(der)); + if (x509) { + if (X509_STORE_add_cert(store, x509) == 1) { loaded_any = true; } + X509_free(x509); + } + CFRelease(der); + } + } + CFRelease(certs); + } + return loaded_any || SSL_CTX_set_default_verify_paths(ssl_ctx) == 1; +#else + return SSL_CTX_set_default_verify_paths(ssl_ctx) == 1; +#endif + +#else + // Other Unix: use default verify paths + return SSL_CTX_set_default_verify_paths(ssl_ctx) == 1; +#endif +} + +bool set_client_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password) { + if (!ctx || !cert || !key) return false; + + auto ssl_ctx = static_cast(ctx); + + // Load certificate + auto cert_bio = BIO_new_mem_buf(cert, -1); + if (!cert_bio) return false; + + auto x509 = PEM_read_bio_X509(cert_bio, nullptr, nullptr, nullptr); + BIO_free(cert_bio); + if (!x509) return false; + + auto cert_ok = SSL_CTX_use_certificate(ssl_ctx, x509) == 1; + X509_free(x509); + if (!cert_ok) return false; + + // Load private key + auto key_bio = BIO_new_mem_buf(key, -1); + if (!key_bio) return false; + + auto pkey = PEM_read_bio_PrivateKey(key_bio, nullptr, nullptr, + password ? const_cast(password) + : nullptr); + BIO_free(key_bio); + if (!pkey) return false; + + auto key_ok = SSL_CTX_use_PrivateKey(ssl_ctx, pkey) == 1; + EVP_PKEY_free(pkey); + + return key_ok && SSL_CTX_check_private_key(ssl_ctx) == 1; +} + +bool set_client_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password) { + if (!ctx || !cert_path || !key_path) return false; + + auto ssl_ctx = static_cast(ctx); + + if (password && password[0] != '\0') { + SSL_CTX_set_default_passwd_cb_userdata( + ssl_ctx, reinterpret_cast(const_cast(password))); + } + + return SSL_CTX_use_certificate_chain_file(ssl_ctx, cert_path) == 1 && + SSL_CTX_use_PrivateKey_file(ssl_ctx, key_path, SSL_FILETYPE_PEM) == 1; +} + +ctx_t create_server_context() { + SSL_CTX *ctx = SSL_CTX_new(TLS_server_method()); + if (ctx) { + SSL_CTX_set_options(ctx, SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION); + } + return static_cast(ctx); +} + +void set_verify_client(ctx_t ctx, bool require) { + if (!ctx) return; + SSL_CTX_set_verify(static_cast(ctx), + require + ? (SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT) + : SSL_VERIFY_NONE, + nullptr); +} + +session_t create_session(ctx_t ctx, socket_t sock) { + if (!ctx || sock == INVALID_SOCKET) return nullptr; + + auto ssl_ctx = static_cast(ctx); + SSL *ssl = SSL_new(ssl_ctx); + if (!ssl) return nullptr; + + // Disable auto-retry for proper non-blocking I/O handling + SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); + + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + if (!bio) { + SSL_free(ssl); + return nullptr; + } + + SSL_set_bio(ssl, bio, bio); + return static_cast(ssl); +} + +void free_session(session_t session) { + if (session) { SSL_free(static_cast(session)); } +} + +bool set_sni(session_t session, const char *hostname) { + if (!session || !hostname) return false; + + auto ssl = static_cast(session); + + // Set SNI (Server Name Indication) only - does not enable verification +#if defined(OPENSSL_IS_BORINGSSL) + return SSL_set_tlsext_host_name(ssl, hostname) == 1; +#else + // Direct call instead of macro to suppress -Wold-style-cast warning + return SSL_ctrl(ssl, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, + static_cast(const_cast(hostname))) == 1; +#endif +} + +bool set_hostname(session_t session, const char *hostname) { + if (!session || !hostname) return false; + + auto ssl = static_cast(session); + + // Set SNI (Server Name Indication) + if (!set_sni(session, hostname)) { return false; } + + // Enable hostname verification + auto param = SSL_get0_param(ssl); + if (!param) return false; + + X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS); + if (X509_VERIFY_PARAM_set1_host(param, hostname, 0) != 1) { return false; } + + SSL_set_verify(ssl, SSL_VERIFY_PEER, nullptr); + return true; +} + +TlsError connect(session_t session) { + if (!session) { return TlsError(); } + + auto ssl = static_cast(session); + auto ret = SSL_connect(ssl); + + TlsError err; + if (ret == 1) { + err.code = ErrorCode::Success; + } else { + auto ssl_err = SSL_get_error(ssl, ret); + err.code = impl::map_ssl_error(ssl_err, err.sys_errno); + err.backend_code = ERR_get_error(); + } + return err; +} + +TlsError accept(session_t session) { + if (!session) { return TlsError(); } + + auto ssl = static_cast(session); + auto ret = SSL_accept(ssl); + + TlsError err; + if (ret == 1) { + err.code = ErrorCode::Success; + } else { + auto ssl_err = SSL_get_error(ssl, ret); + err.code = impl::map_ssl_error(ssl_err, err.sys_errno); + err.backend_code = ERR_get_error(); + } + return err; +} + +bool connect_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto ssl = static_cast(session); + auto bio = SSL_get_rbio(ssl); + + // Set non-blocking mode for handshake + detail::set_nonblocking(sock, true); + if (bio) { BIO_set_nbio(bio, 1); } + + auto cleanup = detail::scope_exit([&]() { + // Restore blocking mode after handshake + if (bio) { BIO_set_nbio(bio, 0); } + detail::set_nonblocking(sock, false); + }); + + auto res = 0; + while ((res = SSL_connect(ssl)) != 1) { + auto ssl_err = SSL_get_error(ssl, res); + switch (ssl_err) { + case SSL_ERROR_WANT_READ: + if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + case SSL_ERROR_WANT_WRITE: + if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + default: break; + } + if (err) { + err->code = impl::map_ssl_error(ssl_err, err->sys_errno); + err->backend_code = ERR_get_error(); + } + return false; + } + if (err) { err->code = ErrorCode::Success; } + return true; +} + +bool accept_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto ssl = static_cast(session); + auto bio = SSL_get_rbio(ssl); + + // Set non-blocking mode for handshake + detail::set_nonblocking(sock, true); + if (bio) { BIO_set_nbio(bio, 1); } + + auto cleanup = detail::scope_exit([&]() { + // Restore blocking mode after handshake + if (bio) { BIO_set_nbio(bio, 0); } + detail::set_nonblocking(sock, false); + }); + + auto res = 0; + while ((res = SSL_accept(ssl)) != 1) { + auto ssl_err = SSL_get_error(ssl, res); + switch (ssl_err) { + case SSL_ERROR_WANT_READ: + if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + case SSL_ERROR_WANT_WRITE: + if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + default: break; + } + if (err) { + err->code = impl::map_ssl_error(ssl_err, err->sys_errno); + err->backend_code = ERR_get_error(); + } + return false; + } + if (err) { err->code = ErrorCode::Success; } + return true; +} + +ssize_t read(session_t session, void *buf, size_t len, TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto ssl = static_cast(session); + constexpr auto max_len = + static_cast((std::numeric_limits::max)()); + if (len > max_len) { len = max_len; } + auto ret = SSL_read(ssl, buf, static_cast(len)); + + if (ret > 0) { + err.code = ErrorCode::Success; + return ret; + } + + auto ssl_err = SSL_get_error(ssl, ret); + err.code = impl::map_ssl_error(ssl_err, err.sys_errno); + if (err.code == ErrorCode::PeerClosed) { + return 0; + } // Gracefully handle the peer closed state. + if (err.code == ErrorCode::Fatal) { err.backend_code = ERR_get_error(); } + return -1; +} + +ssize_t write(session_t session, const void *buf, size_t len, + TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto ssl = static_cast(session); + auto ret = SSL_write(ssl, buf, static_cast(len)); + + if (ret > 0) { + err.code = ErrorCode::Success; + return ret; + } + + auto ssl_err = SSL_get_error(ssl, ret); + err.code = impl::map_ssl_error(ssl_err, err.sys_errno); + if (err.code == ErrorCode::Fatal) { err.backend_code = ERR_get_error(); } + return -1; +} + +int pending(const_session_t session) { + if (!session) return 0; + return SSL_pending(static_cast(const_cast(session))); +} + +void shutdown(session_t session, bool graceful) { + if (!session) return; + + auto ssl = static_cast(session); + if (graceful) { + // First call sends close_notify + if (SSL_shutdown(ssl) == 0) { + // Second call waits for peer's close_notify + SSL_shutdown(ssl); + } + } +} + +bool is_peer_closed(session_t session, socket_t sock) { + if (!session) return true; + + // Temporarily set socket to non-blocking to avoid blocking on SSL_peek + detail::set_nonblocking(sock, true); + auto se = detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + auto ssl = static_cast(session); + char buf; + auto ret = SSL_peek(ssl, &buf, 1); + if (ret > 0) return false; + + auto err = SSL_get_error(ssl, ret); + return err == SSL_ERROR_ZERO_RETURN; +} + +cert_t get_peer_cert(const_session_t session) { + if (!session) return nullptr; + return static_cast(SSL_get1_peer_certificate( + static_cast(const_cast(session)))); +} + +void free_cert(cert_t cert) { + if (cert) { X509_free(static_cast(cert)); } +} + +bool verify_hostname(cert_t cert, const char *hostname) { + if (!cert || !hostname) return false; + + auto x509 = static_cast(cert); + + // Use X509_check_ip_asc for IP addresses, X509_check_host for DNS names + if (detail::is_ip_address(hostname)) { + return X509_check_ip_asc(x509, hostname, 0) == 1; + } + return X509_check_host(x509, hostname, strlen(hostname), 0, nullptr) == 1; +} + +uint64_t hostname_mismatch_code() { + return static_cast(X509_V_ERR_HOSTNAME_MISMATCH); +} + +long get_verify_result(const_session_t session) { + if (!session) return X509_V_ERR_UNSPECIFIED; + return SSL_get_verify_result(static_cast(const_cast(session))); +} + +std::string get_cert_subject_cn(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + auto subject_name = X509_get_subject_name(x509); + if (!subject_name) return ""; + + char buf[256]; + auto len = + X509_NAME_get_text_by_NID(subject_name, NID_commonName, buf, sizeof(buf)); + if (len < 0) return ""; + return std::string(buf, static_cast(len)); +} + +std::string get_cert_issuer_name(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + auto issuer_name = X509_get_issuer_name(x509); + if (!issuer_name) return ""; + + char buf[256]; + X509_NAME_oneline(issuer_name, buf, sizeof(buf)); + return std::string(buf); +} + +bool get_cert_sans(cert_t cert, std::vector &sans) { + sans.clear(); + if (!cert) return false; + auto x509 = static_cast(cert); + + auto names = static_cast( + X509_get_ext_d2i(x509, NID_subject_alt_name, nullptr, nullptr)); + if (!names) return true; // No SANs is valid + + auto count = sk_GENERAL_NAME_num(names); + for (decltype(count) i = 0; i < count; i++) { + auto gen = sk_GENERAL_NAME_value(names, i); + if (!gen) continue; + + SanEntry entry; + switch (gen->type) { + case GEN_DNS: + entry.type = SanType::DNS; + if (gen->d.dNSName) { + entry.value = std::string( + reinterpret_cast( + ASN1_STRING_get0_data(gen->d.dNSName)), + static_cast(ASN1_STRING_length(gen->d.dNSName))); + } + break; + case GEN_IPADD: + entry.type = SanType::IP; + if (gen->d.iPAddress) { + auto data = ASN1_STRING_get0_data(gen->d.iPAddress); + auto len = ASN1_STRING_length(gen->d.iPAddress); + if (len == 4) { + // IPv4 + char buf[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, data, buf, sizeof(buf)); + entry.value = buf; + } else if (len == 16) { + // IPv6 + char buf[INET6_ADDRSTRLEN]; + inet_ntop(AF_INET6, data, buf, sizeof(buf)); + entry.value = buf; + } + } + break; + case GEN_EMAIL: + entry.type = SanType::EMAIL; + if (gen->d.rfc822Name) { + entry.value = std::string( + reinterpret_cast( + ASN1_STRING_get0_data(gen->d.rfc822Name)), + static_cast(ASN1_STRING_length(gen->d.rfc822Name))); + } + break; + case GEN_URI: + entry.type = SanType::URI; + if (gen->d.uniformResourceIdentifier) { + entry.value = std::string( + reinterpret_cast( + ASN1_STRING_get0_data(gen->d.uniformResourceIdentifier)), + static_cast( + ASN1_STRING_length(gen->d.uniformResourceIdentifier))); + } + break; + default: entry.type = SanType::OTHER; break; + } + + if (!entry.value.empty()) { sans.push_back(std::move(entry)); } + } + + GENERAL_NAMES_free(names); + return true; +} + +bool get_cert_validity(cert_t cert, time_t ¬_before, + time_t ¬_after) { + if (!cert) return false; + auto x509 = static_cast(cert); + + auto nb = X509_get0_notBefore(x509); + auto na = X509_get0_notAfter(x509); + if (!nb || !na) return false; + + ASN1_TIME *epoch = ASN1_TIME_new(); + if (!epoch) return false; + auto se = detail::scope_exit([&] { ASN1_TIME_free(epoch); }); + + if (!ASN1_TIME_set(epoch, 0)) return false; + + int pday, psec; + + if (!ASN1_TIME_diff(&pday, &psec, epoch, nb)) return false; + not_before = 86400 * (time_t)pday + psec; + + if (!ASN1_TIME_diff(&pday, &psec, epoch, na)) return false; + not_after = 86400 * (time_t)pday + psec; + + return true; +} + +std::string get_cert_serial(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + auto serial = X509_get_serialNumber(x509); + if (!serial) return ""; + + auto bn = ASN1_INTEGER_to_BN(serial, nullptr); + if (!bn) return ""; + + auto hex = BN_bn2hex(bn); + BN_free(bn); + if (!hex) return ""; + + std::string result(hex); + OPENSSL_free(hex); + return result; +} + +bool get_cert_der(cert_t cert, std::vector &der) { + if (!cert) return false; + auto x509 = static_cast(cert); + auto len = i2d_X509(x509, nullptr); + if (len < 0) return false; + der.resize(static_cast(len)); + auto p = der.data(); + i2d_X509(x509, &p); + return true; +} + +const char *get_sni(const_session_t session) { + if (!session) return nullptr; + auto ssl = static_cast(const_cast(session)); + return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); +} + +uint64_t peek_error() { return ERR_peek_last_error(); } + +uint64_t get_error() { return ERR_get_error(); } + +std::string error_string(uint64_t code) { + char buf[256]; + ERR_error_string_n(static_cast(code), buf, sizeof(buf)); + return std::string(buf); +} + +ca_store_t create_ca_store(const char *pem, size_t len) { + auto mem = BIO_new_mem_buf(pem, static_cast(len)); + if (!mem) { return nullptr; } + auto mem_guard = detail::scope_exit([&] { BIO_free_all(mem); }); + + auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); + if (!inf) { return nullptr; } + + auto store = X509_STORE_new(); + if (store) { + for (auto i = 0; i < static_cast(sk_X509_INFO_num(inf)); i++) { + auto itmp = sk_X509_INFO_value(inf, i); + if (!itmp) { continue; } + if (itmp->x509) { X509_STORE_add_cert(store, itmp->x509); } + if (itmp->crl) { X509_STORE_add_crl(store, itmp->crl); } + } + } + + sk_X509_INFO_pop_free(inf, X509_INFO_free); + return static_cast(store); +} + +void free_ca_store(ca_store_t store) { + if (store) { X509_STORE_free(static_cast(store)); } +} + +bool set_ca_store(ctx_t ctx, ca_store_t store) { + if (!ctx || !store) { return false; } + auto ssl_ctx = static_cast(ctx); + auto x509_store = static_cast(store); + + // Check if same store is already set + if (SSL_CTX_get_cert_store(ssl_ctx) == x509_store) { return true; } + + // SSL_CTX_set_cert_store takes ownership and frees the old store + SSL_CTX_set_cert_store(ssl_ctx, x509_store); + return true; +} + +size_t get_ca_certs(ctx_t ctx, std::vector &certs) { + certs.clear(); + if (!ctx) { return 0; } + auto ssl_ctx = static_cast(ctx); + + auto store = SSL_CTX_get_cert_store(ssl_ctx); + if (!store) { return 0; } + + auto objs = X509_STORE_get0_objects(store); + if (!objs) { return 0; } + + auto count = sk_X509_OBJECT_num(objs); + for (decltype(count) i = 0; i < count; i++) { + auto obj = sk_X509_OBJECT_value(objs, i); + if (!obj) { continue; } + if (X509_OBJECT_get_type(obj) == X509_LU_X509) { + auto x509 = X509_OBJECT_get0_X509(obj); + if (x509) { + // Increment reference count so caller can free it + X509_up_ref(x509); + certs.push_back(static_cast(x509)); + } + } + } + return certs.size(); +} + +std::vector get_ca_names(ctx_t ctx) { + std::vector names; + if (!ctx) { return names; } + auto ssl_ctx = static_cast(ctx); + + auto store = SSL_CTX_get_cert_store(ssl_ctx); + if (!store) { return names; } + + auto objs = X509_STORE_get0_objects(store); + if (!objs) { return names; } + + auto count = sk_X509_OBJECT_num(objs); + for (decltype(count) i = 0; i < count; i++) { + auto obj = sk_X509_OBJECT_value(objs, i); + if (!obj) { continue; } + if (X509_OBJECT_get_type(obj) == X509_LU_X509) { + auto x509 = X509_OBJECT_get0_X509(obj); + if (x509) { + auto subject = X509_get_subject_name(x509); + if (subject) { + char buf[512]; + X509_NAME_oneline(subject, buf, sizeof(buf)); + names.push_back(buf); + } + } + } + } + return names; +} + +bool update_server_cert(ctx_t ctx, const char *cert_pem, + const char *key_pem, const char *password) { + if (!ctx || !cert_pem || !key_pem) { return false; } + auto ssl_ctx = static_cast(ctx); + + // Load certificate from PEM + auto cert_bio = BIO_new_mem_buf(cert_pem, -1); + if (!cert_bio) { return false; } + auto cert = PEM_read_bio_X509(cert_bio, nullptr, nullptr, nullptr); + BIO_free(cert_bio); + if (!cert) { return false; } + + // Load private key from PEM + auto key_bio = BIO_new_mem_buf(key_pem, -1); + if (!key_bio) { + X509_free(cert); + return false; + } + auto key = PEM_read_bio_PrivateKey(key_bio, nullptr, nullptr, + password ? const_cast(password) + : nullptr); + BIO_free(key_bio); + if (!key) { + X509_free(cert); + return false; + } + + // Update certificate and key + auto ret = SSL_CTX_use_certificate(ssl_ctx, cert) == 1 && + SSL_CTX_use_PrivateKey(ssl_ctx, key) == 1; + + X509_free(cert); + EVP_PKEY_free(key); + return ret; +} + +bool update_server_client_ca(ctx_t ctx, const char *ca_pem) { + if (!ctx || !ca_pem) { return false; } + auto ssl_ctx = static_cast(ctx); + + // Create new X509_STORE from PEM + auto store = create_ca_store(ca_pem, strlen(ca_pem)); + if (!store) { return false; } + + // SSL_CTX_set_cert_store takes ownership + SSL_CTX_set_cert_store(ssl_ctx, static_cast(store)); + + // Set client CA list for client certificate request + auto ca_list = impl::create_client_ca_list_from_pem(ca_pem); + if (ca_list) { + // SSL_CTX_set_client_CA_list takes ownership of ca_list + SSL_CTX_set_client_CA_list(ssl_ctx, ca_list); + } + + return true; +} + +bool set_verify_callback(ctx_t ctx, VerifyCallback callback) { + if (!ctx) { return false; } + auto ssl_ctx = static_cast(ctx); + + impl::get_verify_callback() = std::move(callback); + + if (impl::get_verify_callback()) { + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, impl::openssl_verify_callback); + } else { + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr); + } + return true; +} + +long get_verify_error(const_session_t session) { + if (!session) { return -1; } + auto ssl = static_cast(const_cast(session)); + return SSL_get_verify_result(ssl); +} + +std::string verify_error_string(long error_code) { + if (error_code == X509_V_OK) { return ""; } + const char *str = X509_verify_cert_error_string(static_cast(error_code)); + return str ? str : "unknown error"; +} + +} // namespace tls + +bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); +} + +bool +SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6 = {}; + struct in_addr addr = {}; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_matched = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (decltype(count) i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (!val || val->type != type) { continue; } + + auto name = + reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); + if (name == nullptr) { continue; } + + auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); + + switch (type) { + case GEN_DNS: + dsn_matched = + detail::match_hostname(std::string(name, name_len), host_); + break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || !memcmp(&addr, name, addr_len)) { + ip_matched = true; + } + break; + } + } + + if (dsn_matched || ip_matched) { ret = true; } + } + + GENERAL_NAMES_free(const_cast( + reinterpret_cast(alt_names))); + return ret; +} + +bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); + + if (name_len != -1) { + return detail::match_hostname( + std::string(name, static_cast(name_len)), host_); + } + } + + return false; +} + +#endif // CPPHTTPLIB_OPENSSL_SUPPORT + +/* + * Group 9: TLS abstraction layer - Mbed TLS backend + */ + +/* + * Mbed TLS Backend Implementation + */ + +#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT +namespace tls { + +namespace impl { + +// Mbed TLS session wrapper +struct MbedTlsSession { + mbedtls_ssl_context ssl; + socket_t sock = INVALID_SOCKET; + std::string hostname; // For client: set via set_sni + std::string sni_hostname; // For server: received from client via SNI callback + + MbedTlsSession() { mbedtls_ssl_init(&ssl); } + + ~MbedTlsSession() { mbedtls_ssl_free(&ssl); } + + MbedTlsSession(const MbedTlsSession &) = delete; + MbedTlsSession &operator=(const MbedTlsSession &) = delete; +}; + +// Thread-local error code accessor for Mbed TLS (since it doesn't have an error +// queue) +int &mbedtls_last_error() { + static thread_local int err = 0; + return err; +} + +// Helper to map Mbed TLS error to ErrorCode +ErrorCode map_mbedtls_error(int ret, int &out_errno) { + if (ret == 0) { return ErrorCode::Success; } + if (ret == MBEDTLS_ERR_SSL_WANT_READ) { return ErrorCode::WantRead; } + if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) { return ErrorCode::WantWrite; } + if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + return ErrorCode::PeerClosed; + } + if (ret == MBEDTLS_ERR_NET_CONN_RESET || ret == MBEDTLS_ERR_NET_SEND_FAILED || + ret == MBEDTLS_ERR_NET_RECV_FAILED) { + out_errno = errno; + return ErrorCode::SyscallError; + } + if (ret == MBEDTLS_ERR_X509_CERT_VERIFY_FAILED) { + return ErrorCode::CertVerifyFailed; + } + return ErrorCode::Fatal; +} + +// BIO-like send callback for Mbed TLS +int mbedtls_net_send_cb(void *ctx, const unsigned char *buf, + size_t len) { + auto sock = *static_cast(ctx); +#ifdef _WIN32 + auto ret = + send(sock, reinterpret_cast(buf), static_cast(len), 0); + if (ret == SOCKET_ERROR) { + int err = WSAGetLastError(); + if (err == WSAEWOULDBLOCK) { return MBEDTLS_ERR_SSL_WANT_WRITE; } + return MBEDTLS_ERR_NET_SEND_FAILED; + } +#else + auto ret = send(sock, buf, len, 0); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return MBEDTLS_ERR_SSL_WANT_WRITE; + } + return MBEDTLS_ERR_NET_SEND_FAILED; + } +#endif + return static_cast(ret); +} + +// BIO-like recv callback for Mbed TLS +int mbedtls_net_recv_cb(void *ctx, unsigned char *buf, size_t len) { + auto sock = *static_cast(ctx); +#ifdef _WIN32 + auto ret = + recv(sock, reinterpret_cast(buf), static_cast(len), 0); + if (ret == SOCKET_ERROR) { + int err = WSAGetLastError(); + if (err == WSAEWOULDBLOCK) { return MBEDTLS_ERR_SSL_WANT_READ; } + return MBEDTLS_ERR_NET_RECV_FAILED; + } +#else + auto ret = recv(sock, buf, len, 0); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return MBEDTLS_ERR_SSL_WANT_READ; + } + return MBEDTLS_ERR_NET_RECV_FAILED; + } +#endif + if (ret == 0) { return MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY; } + return static_cast(ret); +} + +// MbedTlsContext constructor/destructor implementations +MbedTlsContext::MbedTlsContext() { + mbedtls_ssl_config_init(&conf); + mbedtls_entropy_init(&entropy); + mbedtls_ctr_drbg_init(&ctr_drbg); + mbedtls_x509_crt_init(&ca_chain); + mbedtls_x509_crt_init(&own_cert); + mbedtls_pk_init(&own_key); +} + +MbedTlsContext::~MbedTlsContext() { + mbedtls_pk_free(&own_key); + mbedtls_x509_crt_free(&own_cert); + mbedtls_x509_crt_free(&ca_chain); + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free(&entropy); + mbedtls_ssl_config_free(&conf); +} + +// Thread-local storage for SNI captured during handshake +// This is needed because the SNI callback doesn't have a way to pass +// session-specific data before the session is fully set up +std::string &mbedpending_sni() { + static thread_local std::string sni; + return sni; +} + +// SNI callback for Mbed TLS server to capture client's SNI hostname +int mbedtls_sni_callback(void *p_ctx, mbedtls_ssl_context *ssl, + const unsigned char *name, size_t name_len) { + (void)p_ctx; + (void)ssl; + + // Store SNI name in thread-local storage + // It will be retrieved and stored in the session after handshake + if (name && name_len > 0) { + mbedpending_sni().assign(reinterpret_cast(name), name_len); + } else { + mbedpending_sni().clear(); + } + return 0; // Accept any SNI +} + +int mbedtls_verify_callback(void *data, mbedtls_x509_crt *crt, + int cert_depth, uint32_t *flags); + +// MbedTLS verify callback wrapper +int mbedtls_verify_callback(void *data, mbedtls_x509_crt *crt, + int cert_depth, uint32_t *flags) { + auto &callback = get_verify_callback(); + if (!callback) { return 0; } // Continue with default verification + + // data points to the MbedTlsSession + auto *session = static_cast(data); + + // Build context + VerifyContext verify_ctx; + verify_ctx.session = static_cast(session); + verify_ctx.cert = static_cast(crt); + verify_ctx.depth = cert_depth; + verify_ctx.preverify_ok = (*flags == 0); + verify_ctx.error_code = static_cast(*flags); + + // Convert Mbed TLS flags to error string + static thread_local char error_buf[256]; + if (*flags != 0) { + mbedtls_x509_crt_verify_info(error_buf, sizeof(error_buf), "", *flags); + verify_ctx.error_string = error_buf; + } else { + verify_ctx.error_string = nullptr; + } + + bool accepted = callback(verify_ctx); + + if (accepted) { + *flags = 0; // Clear all error flags + return 0; + } + return MBEDTLS_ERR_X509_CERT_VERIFY_FAILED; +} + +} // namespace impl + +ctx_t create_client_context() { + auto ctx = new (std::nothrow) impl::MbedTlsContext(); + if (!ctx) { return nullptr; } + + ctx->is_server = false; + + // Seed the random number generator + const char *pers = "httplib_client"; + int ret = mbedtls_ctr_drbg_seed( + &ctx->ctr_drbg, mbedtls_entropy_func, &ctx->entropy, + reinterpret_cast(pers), strlen(pers)); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + delete ctx; + return nullptr; + } + + // Set up SSL config for client + ret = mbedtls_ssl_config_defaults(&ctx->conf, MBEDTLS_SSL_IS_CLIENT, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + delete ctx; + return nullptr; + } + + // Set random number generator + mbedtls_ssl_conf_rng(&ctx->conf, mbedtls_ctr_drbg_random, &ctx->ctr_drbg); + + // Default: verify peer certificate + mbedtls_ssl_conf_authmode(&ctx->conf, MBEDTLS_SSL_VERIFY_REQUIRED); + + // Set minimum TLS version to 1.2 +#ifdef CPPHTTPLIB_MBEDTLS_V3 + mbedtls_ssl_conf_min_tls_version(&ctx->conf, MBEDTLS_SSL_VERSION_TLS1_2); +#else + mbedtls_ssl_conf_min_version(&ctx->conf, MBEDTLS_SSL_MAJOR_VERSION_3, + MBEDTLS_SSL_MINOR_VERSION_3); +#endif + + return static_cast(ctx); +} + +ctx_t create_server_context() { + auto ctx = new (std::nothrow) impl::MbedTlsContext(); + if (!ctx) { return nullptr; } + + ctx->is_server = true; + + // Seed the random number generator + const char *pers = "httplib_server"; + int ret = mbedtls_ctr_drbg_seed( + &ctx->ctr_drbg, mbedtls_entropy_func, &ctx->entropy, + reinterpret_cast(pers), strlen(pers)); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + delete ctx; + return nullptr; + } + + // Set up SSL config for server + ret = mbedtls_ssl_config_defaults(&ctx->conf, MBEDTLS_SSL_IS_SERVER, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + delete ctx; + return nullptr; + } + + // Set random number generator + mbedtls_ssl_conf_rng(&ctx->conf, mbedtls_ctr_drbg_random, &ctx->ctr_drbg); + + // Default: don't verify client + mbedtls_ssl_conf_authmode(&ctx->conf, MBEDTLS_SSL_VERIFY_NONE); + + // Set minimum TLS version to 1.2 +#ifdef CPPHTTPLIB_MBEDTLS_V3 + mbedtls_ssl_conf_min_tls_version(&ctx->conf, MBEDTLS_SSL_VERSION_TLS1_2); +#else + mbedtls_ssl_conf_min_version(&ctx->conf, MBEDTLS_SSL_MAJOR_VERSION_3, + MBEDTLS_SSL_MINOR_VERSION_3); +#endif + + // Set SNI callback to capture client's SNI hostname + mbedtls_ssl_conf_sni(&ctx->conf, impl::mbedtls_sni_callback, nullptr); + + return static_cast(ctx); +} + +void free_context(ctx_t ctx) { + if (ctx) { delete static_cast(ctx); } +} + +bool set_min_version(ctx_t ctx, Version version) { + if (!ctx) { return false; } + auto mctx = static_cast(ctx); + +#ifdef CPPHTTPLIB_MBEDTLS_V3 + // Mbed TLS 3.x uses mbedtls_ssl_protocol_version enum + mbedtls_ssl_protocol_version min_ver = MBEDTLS_SSL_VERSION_TLS1_2; + if (version >= Version::TLS1_3) { +#if defined(MBEDTLS_SSL_PROTO_TLS1_3) + min_ver = MBEDTLS_SSL_VERSION_TLS1_3; +#endif + } + mbedtls_ssl_conf_min_tls_version(&mctx->conf, min_ver); +#else + // Mbed TLS 2.x uses major/minor version numbers + int major = MBEDTLS_SSL_MAJOR_VERSION_3; + int minor = MBEDTLS_SSL_MINOR_VERSION_3; // TLS 1.2 + if (version >= Version::TLS1_3) { +#if defined(MBEDTLS_SSL_PROTO_TLS1_3) + minor = MBEDTLS_SSL_MINOR_VERSION_4; // TLS 1.3 +#else + minor = MBEDTLS_SSL_MINOR_VERSION_3; // Fall back to TLS 1.2 +#endif + } + mbedtls_ssl_conf_min_version(&mctx->conf, major, minor); +#endif + return true; +} + +bool load_ca_pem(ctx_t ctx, const char *pem, size_t len) { + if (!ctx || !pem) { return false; } + auto mctx = static_cast(ctx); + + // mbedtls_x509_crt_parse expects null-terminated string for PEM + // Add null terminator if not present + std::string pem_str(pem, len); + int ret = mbedtls_x509_crt_parse( + &mctx->ca_chain, reinterpret_cast(pem_str.c_str()), + pem_str.size() + 1); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + mbedtls_ssl_conf_ca_chain(&mctx->conf, &mctx->ca_chain, nullptr); + return true; +} + +bool load_ca_file(ctx_t ctx, const char *file_path) { + if (!ctx || !file_path) { return false; } + auto mctx = static_cast(ctx); + + int ret = mbedtls_x509_crt_parse_file(&mctx->ca_chain, file_path); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + mbedtls_ssl_conf_ca_chain(&mctx->conf, &mctx->ca_chain, nullptr); + return true; +} + +bool load_ca_dir(ctx_t ctx, const char *dir_path) { + if (!ctx || !dir_path) { return false; } + auto mctx = static_cast(ctx); + + int ret = mbedtls_x509_crt_parse_path(&mctx->ca_chain, dir_path); + if (ret < 0) { // Returns number of certs on success, negative on error + impl::mbedtls_last_error() = ret; + return false; + } + + mbedtls_ssl_conf_ca_chain(&mctx->conf, &mctx->ca_chain, nullptr); + return true; +} + +bool load_system_certs(ctx_t ctx) { + if (!ctx) { return false; } + auto mctx = static_cast(ctx); + bool loaded = false; + +#ifdef _WIN32 + loaded = impl::enumerate_windows_system_certs( + [&](const unsigned char *data, size_t len) { + return mbedtls_x509_crt_parse_der(&mctx->ca_chain, data, len) == 0; + }); +#elif defined(__APPLE__) && defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) + loaded = impl::enumerate_macos_keychain_certs( + [&](const unsigned char *data, size_t len) { + return mbedtls_x509_crt_parse_der(&mctx->ca_chain, data, len) == 0; + }); +#else + for (auto path = impl::system_ca_paths(); *path; ++path) { + if (mbedtls_x509_crt_parse_file(&mctx->ca_chain, *path) >= 0) { + loaded = true; + break; + } + } + + if (!loaded) { + for (auto dir = impl::system_ca_dirs(); *dir; ++dir) { + if (mbedtls_x509_crt_parse_path(&mctx->ca_chain, *dir) >= 0) { + loaded = true; + break; + } + } + } +#endif + + if (loaded) { + mbedtls_ssl_conf_ca_chain(&mctx->conf, &mctx->ca_chain, nullptr); + } + return loaded; +} + +bool set_client_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password) { + if (!ctx || !cert || !key) { return false; } + auto mctx = static_cast(ctx); + + // Parse certificate + std::string cert_str(cert); + int ret = mbedtls_x509_crt_parse( + &mctx->own_cert, + reinterpret_cast(cert_str.c_str()), + cert_str.size() + 1); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Parse private key + std::string key_str(key); + const unsigned char *pwd = + password ? reinterpret_cast(password) : nullptr; + size_t pwd_len = password ? strlen(password) : 0; + +#ifdef CPPHTTPLIB_MBEDTLS_V3 + ret = mbedtls_pk_parse_key( + &mctx->own_key, reinterpret_cast(key_str.c_str()), + key_str.size() + 1, pwd, pwd_len, mbedtls_ctr_drbg_random, + &mctx->ctr_drbg); +#else + ret = mbedtls_pk_parse_key( + &mctx->own_key, reinterpret_cast(key_str.c_str()), + key_str.size() + 1, pwd, pwd_len); +#endif + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Verify that the certificate and private key match +#ifdef CPPHTTPLIB_MBEDTLS_V3 + ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key, + mbedtls_ctr_drbg_random, &mctx->ctr_drbg); +#else + ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key); +#endif + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + ret = mbedtls_ssl_conf_own_cert(&mctx->conf, &mctx->own_cert, &mctx->own_key); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + return true; +} + +bool set_client_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password) { + if (!ctx || !cert_path || !key_path) { return false; } + auto mctx = static_cast(ctx); + + // Parse certificate file + int ret = mbedtls_x509_crt_parse_file(&mctx->own_cert, cert_path); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Parse private key file +#ifdef CPPHTTPLIB_MBEDTLS_V3 + ret = mbedtls_pk_parse_keyfile(&mctx->own_key, key_path, password, + mbedtls_ctr_drbg_random, &mctx->ctr_drbg); +#else + ret = mbedtls_pk_parse_keyfile(&mctx->own_key, key_path, password); +#endif + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Verify that the certificate and private key match +#ifdef CPPHTTPLIB_MBEDTLS_V3 + ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key, + mbedtls_ctr_drbg_random, &mctx->ctr_drbg); +#else + ret = mbedtls_pk_check_pair(&mctx->own_cert.pk, &mctx->own_key); +#endif + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + ret = mbedtls_ssl_conf_own_cert(&mctx->conf, &mctx->own_cert, &mctx->own_key); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + return true; +} + +void set_verify_client(ctx_t ctx, bool require) { + if (!ctx) { return; } + auto mctx = static_cast(ctx); + mctx->verify_client = require; + if (require) { + mbedtls_ssl_conf_authmode(&mctx->conf, MBEDTLS_SSL_VERIFY_REQUIRED); + } else { + // If a verify callback is set, use OPTIONAL mode to ensure the callback + // is called (matching OpenSSL behavior). Otherwise use NONE. + mbedtls_ssl_conf_authmode(&mctx->conf, mctx->has_verify_callback + ? MBEDTLS_SSL_VERIFY_OPTIONAL + : MBEDTLS_SSL_VERIFY_NONE); + } +} + +session_t create_session(ctx_t ctx, socket_t sock) { + if (!ctx || sock == INVALID_SOCKET) { return nullptr; } + auto mctx = static_cast(ctx); + + auto session = new (std::nothrow) impl::MbedTlsSession(); + if (!session) { return nullptr; } + + session->sock = sock; + + int ret = mbedtls_ssl_setup(&session->ssl, &mctx->conf); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + delete session; + return nullptr; + } + + // Set BIO callbacks + mbedtls_ssl_set_bio(&session->ssl, &session->sock, impl::mbedtls_net_send_cb, + impl::mbedtls_net_recv_cb, nullptr); + + // Set per-session verify callback with session pointer if callback is + // registered + if (mctx->has_verify_callback) { + mbedtls_ssl_set_verify(&session->ssl, impl::mbedtls_verify_callback, + session); + } + + return static_cast(session); +} + +void free_session(session_t session) { + if (session) { delete static_cast(session); } +} + +bool set_sni(session_t session, const char *hostname) { + if (!session || !hostname) { return false; } + auto msession = static_cast(session); + + int ret = mbedtls_ssl_set_hostname(&msession->ssl, hostname); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + msession->hostname = hostname; + return true; +} + +bool set_hostname(session_t session, const char *hostname) { + // In Mbed TLS, set_hostname also sets up hostname verification + return set_sni(session, hostname); +} + +TlsError connect(session_t session) { + TlsError err; + if (!session) { + err.code = ErrorCode::Fatal; + return err; + } + + auto msession = static_cast(session); + int ret = mbedtls_ssl_handshake(&msession->ssl); + + if (ret == 0) { + err.code = ErrorCode::Success; + } else { + err.code = impl::map_mbedtls_error(ret, err.sys_errno); + err.backend_code = static_cast(-ret); + impl::mbedtls_last_error() = ret; + } + + return err; +} + +TlsError accept(session_t session) { + // Same as connect for Mbed TLS - handshake works for both client and server + auto result = connect(session); + + // After successful handshake, capture SNI from thread-local storage + if (result.code == ErrorCode::Success && session) { + auto msession = static_cast(session); + msession->sni_hostname = std::move(impl::mbedpending_sni()); + impl::mbedpending_sni().clear(); + } + + return result; +} + +bool connect_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto msession = static_cast(session); + + // Set socket to non-blocking mode + detail::set_nonblocking(sock, true); + auto cleanup = + detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + int ret; + while ((ret = mbedtls_ssl_handshake(&msession->ssl)) != 0) { + if (ret == MBEDTLS_ERR_SSL_WANT_READ) { + if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } + + // TlsError or timeout + if (err) { + err->code = impl::map_mbedtls_error(ret, err->sys_errno); + err->backend_code = static_cast(-ret); + } + impl::mbedtls_last_error() = ret; + return false; + } + + if (err) { err->code = ErrorCode::Success; } + return true; +} + +bool accept_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + // Same implementation as connect for Mbed TLS + bool result = + connect_nonblocking(session, sock, timeout_sec, timeout_usec, err); + + // After successful handshake, capture SNI from thread-local storage + if (result && session) { + auto msession = static_cast(session); + msession->sni_hostname = std::move(impl::mbedpending_sni()); + impl::mbedpending_sni().clear(); + } + + return result; +} + +ssize_t read(session_t session, void *buf, size_t len, TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto msession = static_cast(session); + int ret = + mbedtls_ssl_read(&msession->ssl, static_cast(buf), len); + + if (ret > 0) { + err.code = ErrorCode::Success; + return static_cast(ret); + } + + if (ret == 0) { + err.code = ErrorCode::PeerClosed; + return 0; + } + + err.code = impl::map_mbedtls_error(ret, err.sys_errno); + err.backend_code = static_cast(-ret); + impl::mbedtls_last_error() = ret; + // mbedTLS signals a clean close_notify via a negative error code rather + // than 0; surface it as a clean EOF the way OpenSSL/wolfSSL do. + if (err.code == ErrorCode::PeerClosed) { return 0; } + return -1; +} + +ssize_t write(session_t session, const void *buf, size_t len, + TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto msession = static_cast(session); + int ret = mbedtls_ssl_write(&msession->ssl, + static_cast(buf), len); + + if (ret > 0) { + err.code = ErrorCode::Success; + return static_cast(ret); + } + + if (ret == 0) { + err.code = ErrorCode::PeerClosed; + return 0; + } + + err.code = impl::map_mbedtls_error(ret, err.sys_errno); + err.backend_code = static_cast(-ret); + impl::mbedtls_last_error() = ret; + return -1; +} + +int pending(const_session_t session) { + if (!session) { return 0; } + auto msession = + static_cast(const_cast(session)); + return static_cast(mbedtls_ssl_get_bytes_avail(&msession->ssl)); +} + +void shutdown(session_t session, bool graceful) { + if (!session) { return; } + auto msession = static_cast(session); + + if (graceful) { + // Try to send close_notify, but don't block forever + int ret; + int attempts = 0; + while ((ret = mbedtls_ssl_close_notify(&msession->ssl)) != 0 && + attempts < 3) { + if (ret != MBEDTLS_ERR_SSL_WANT_READ && + ret != MBEDTLS_ERR_SSL_WANT_WRITE) { + break; + } + attempts++; + } + } +} + +bool is_peer_closed(session_t session, socket_t sock) { + if (!session || sock == INVALID_SOCKET) { return true; } + auto msession = static_cast(session); + + // Check if there's already decrypted data available in the TLS buffer + // If so, the connection is definitely alive + if (mbedtls_ssl_get_bytes_avail(&msession->ssl) > 0) { return false; } + + // Set socket to non-blocking to avoid blocking on read + detail::set_nonblocking(sock, true); + auto cleanup = + detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + // Try a 1-byte read to check connection status + // Note: This will consume the byte if data is available, but for the + // purpose of checking if peer is closed, this should be acceptable + // since we're only called when we expect the connection might be closing + unsigned char buf; + int ret = mbedtls_ssl_read(&msession->ssl, &buf, 1); + + // If we got data or WANT_READ (would block), connection is alive + if (ret > 0 || ret == MBEDTLS_ERR_SSL_WANT_READ) { return false; } + + // If we get a peer close notify or a connection reset, the peer is closed + return ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || + ret == MBEDTLS_ERR_NET_CONN_RESET || ret == 0; +} + +cert_t get_peer_cert(const_session_t session) { + if (!session) { return nullptr; } + auto msession = + static_cast(const_cast(session)); + + // Mbed TLS returns a pointer to the internal peer cert chain. + // WARNING: This pointer is only valid while the session is active. + // Do not use the certificate after calling free_session(). + const mbedtls_x509_crt *cert = mbedtls_ssl_get_peer_cert(&msession->ssl); + return const_cast(cert); +} + +void free_cert(cert_t cert) { + // Mbed TLS: peer certificate is owned by the SSL context. + // No-op here, but callers should still call this for cross-backend + // portability. + (void)cert; +} + +bool verify_hostname(cert_t cert, const char *hostname) { + if (!cert || !hostname) { return false; } + auto mcert = static_cast(cert); + std::string host_str(hostname); + + // Check if hostname is an IP address + bool is_ip = impl::is_ipv4_address(host_str); + unsigned char ip_bytes[4]; + if (is_ip) { impl::parse_ipv4(host_str, ip_bytes); } + + // Check Subject Alternative Names (SAN) + // In Mbed TLS 3.x, subject_alt_names contains raw values without ASN.1 tags + // - DNS names: raw string bytes + // - IP addresses: raw IP bytes (4 for IPv4, 16 for IPv6) + const mbedtls_x509_sequence *san = &mcert->subject_alt_names; + while (san != nullptr && san->buf.p != nullptr && san->buf.len > 0) { + const unsigned char *p = san->buf.p; + size_t len = san->buf.len; + + if (is_ip) { + // Check if this SAN is an IPv4 address (4 bytes) + if (len == 4 && memcmp(p, ip_bytes, 4) == 0) { return true; } + // Check if this SAN is an IPv6 address (16 bytes) - skip for now + } else { + // Check if this SAN is a DNS name (printable ASCII string) + bool is_dns = len > 0; + for (size_t i = 0; i < len && is_dns; i++) { + if (p[i] < 32 || p[i] > 126) { is_dns = false; } + } + if (is_dns) { + std::string san_name(reinterpret_cast(p), len); + if (detail::match_hostname(san_name, host_str)) { return true; } + } + } + san = san->next; + } + + // Fallback: Check Common Name (CN) in subject + char cn[256]; + int ret = mbedtls_x509_dn_gets(cn, sizeof(cn), &mcert->subject); + if (ret > 0) { + std::string cn_str(cn); + + // Look for "CN=" in the DN string + size_t cn_pos = cn_str.find("CN="); + if (cn_pos != std::string::npos) { + size_t start = cn_pos + 3; + size_t end = cn_str.find(',', start); + std::string cn_value = + cn_str.substr(start, end == std::string::npos ? end : end - start); + + if (detail::match_hostname(cn_value, host_str)) { return true; } + } + } + + return false; +} + +uint64_t hostname_mismatch_code() { + return static_cast(MBEDTLS_X509_BADCERT_CN_MISMATCH); +} + +long get_verify_result(const_session_t session) { + if (!session) { return -1; } + auto msession = + static_cast(const_cast(session)); + uint32_t flags = mbedtls_ssl_get_verify_result(&msession->ssl); + // Return 0 (X509_V_OK equivalent) if verification passed + return flags == 0 ? 0 : static_cast(flags); +} + +std::string get_cert_subject_cn(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + // Find the CN in the subject + const mbedtls_x509_name *name = &x509->subject; + while (name != nullptr) { + if (MBEDTLS_OID_CMP(MBEDTLS_OID_AT_CN, &name->oid) == 0) { + return std::string(reinterpret_cast(name->val.p), + name->val.len); + } + name = name->next; + } + return ""; +} + +std::string get_cert_issuer_name(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + // Build a human-readable issuer name string + char buf[512]; + int ret = mbedtls_x509_dn_gets(buf, sizeof(buf), &x509->issuer); + if (ret < 0) return ""; + return std::string(buf); +} + +bool get_cert_sans(cert_t cert, std::vector &sans) { + sans.clear(); + if (!cert) return false; + auto x509 = static_cast(cert); + + // Parse the Subject Alternative Name extension + const mbedtls_x509_sequence *cur = &x509->subject_alt_names; + while (cur != nullptr) { + if (cur->buf.len > 0) { + // Mbed TLS stores SAN as ASN.1 sequences + // The tag byte indicates the type + const unsigned char *p = cur->buf.p; + size_t len = cur->buf.len; + + // First byte is the tag + unsigned char tag = *p; + p++; + len--; + + // Parse length (simple single-byte length assumed) + if (len > 0 && *p < 0x80) { + size_t value_len = *p; + p++; + len--; + + if (value_len <= len) { + SanEntry entry; + // ASN.1 context tags for GeneralName + switch (tag & 0x1F) { + case 2: // dNSName + entry.type = SanType::DNS; + entry.value = + std::string(reinterpret_cast(p), value_len); + break; + case 7: // iPAddress + entry.type = SanType::IP; + if (value_len == 4) { + // IPv4 + char buf[16]; + snprintf(buf, sizeof(buf), "%d.%d.%d.%d", p[0], p[1], p[2], p[3]); + entry.value = buf; + } else if (value_len == 16) { + // IPv6 + char buf[64]; + snprintf(buf, sizeof(buf), + "%02x%02x:%02x%02x:%02x%02x:%02x%02x:" + "%02x%02x:%02x%02x:%02x%02x:%02x%02x", + p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], + p[9], p[10], p[11], p[12], p[13], p[14], p[15]); + entry.value = buf; + } + break; + case 1: // rfc822Name (email) + entry.type = SanType::EMAIL; + entry.value = + std::string(reinterpret_cast(p), value_len); + break; + case 6: // uniformResourceIdentifier + entry.type = SanType::URI; + entry.value = + std::string(reinterpret_cast(p), value_len); + break; + default: entry.type = SanType::OTHER; break; + } + + if (!entry.value.empty()) { sans.push_back(std::move(entry)); } + } + } + } + cur = cur->next; + } + return true; +} + +bool get_cert_validity(cert_t cert, time_t ¬_before, + time_t ¬_after) { + if (!cert) return false; + auto x509 = static_cast(cert); + + // Convert mbedtls_x509_time to time_t + auto to_time_t = [](const mbedtls_x509_time &t) -> time_t { + struct tm tm_time = {}; + tm_time.tm_year = t.year - 1900; + tm_time.tm_mon = t.mon - 1; + tm_time.tm_mday = t.day; + tm_time.tm_hour = t.hour; + tm_time.tm_min = t.min; + tm_time.tm_sec = t.sec; +#ifdef _WIN32 + return _mkgmtime(&tm_time); +#else + return timegm(&tm_time); +#endif + }; + + not_before = to_time_t(x509->valid_from); + not_after = to_time_t(x509->valid_to); + return true; +} + +std::string get_cert_serial(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + // Convert serial number to hex string + std::string result; + result.reserve(x509->serial.len * 2); + for (size_t i = 0; i < x509->serial.len; i++) { + char hex[3]; + snprintf(hex, sizeof(hex), "%02X", x509->serial.p[i]); + result += hex; + } + return result; +} + +bool get_cert_der(cert_t cert, std::vector &der) { + if (!cert) return false; + auto crt = static_cast(cert); + if (!crt->raw.p || crt->raw.len == 0) return false; + der.assign(crt->raw.p, crt->raw.p + crt->raw.len); + return true; +} + +const char *get_sni(const_session_t session) { + if (!session) return nullptr; + auto msession = static_cast(session); + + // For server: return SNI received from client during handshake + if (!msession->sni_hostname.empty()) { + return msession->sni_hostname.c_str(); + } + + // For client: return the hostname set via set_sni + if (!msession->hostname.empty()) { return msession->hostname.c_str(); } + + return nullptr; +} + +uint64_t peek_error() { + // Mbed TLS doesn't have an error queue, return the last error + return static_cast(-impl::mbedtls_last_error()); +} + +uint64_t get_error() { + // Mbed TLS doesn't have an error queue, return and clear the last error + uint64_t err = static_cast(-impl::mbedtls_last_error()); + impl::mbedtls_last_error() = 0; + return err; +} + +std::string error_string(uint64_t code) { + char buf[256]; + mbedtls_strerror(-static_cast(code), buf, sizeof(buf)); + return std::string(buf); +} + +ca_store_t create_ca_store(const char *pem, size_t len) { + auto *ca_chain = new (std::nothrow) mbedtls_x509_crt; + if (!ca_chain) { return nullptr; } + + mbedtls_x509_crt_init(ca_chain); + + // mbedtls_x509_crt_parse expects null-terminated PEM + int ret = mbedtls_x509_crt_parse(ca_chain, + reinterpret_cast(pem), + len + 1); // +1 for null terminator + if (ret != 0) { + // Try without +1 in case PEM is already null-terminated + ret = mbedtls_x509_crt_parse( + ca_chain, reinterpret_cast(pem), len); + if (ret != 0) { + mbedtls_x509_crt_free(ca_chain); + delete ca_chain; + return nullptr; + } + } + + return static_cast(ca_chain); +} + +void free_ca_store(ca_store_t store) { + if (store) { + auto *ca_chain = static_cast(store); + mbedtls_x509_crt_free(ca_chain); + delete ca_chain; + } +} + +bool set_ca_store(ctx_t ctx, ca_store_t store) { + if (!ctx || !store) { return false; } + auto *mbed_ctx = static_cast(ctx); + auto *ca_chain = static_cast(store); + + // Free existing CA chain + mbedtls_x509_crt_free(&mbed_ctx->ca_chain); + mbedtls_x509_crt_init(&mbed_ctx->ca_chain); + + // Copy the CA chain (deep copy) + // Parse from the raw data of the source cert + mbedtls_x509_crt *src = ca_chain; + while (src != nullptr) { + int ret = mbedtls_x509_crt_parse_der(&mbed_ctx->ca_chain, src->raw.p, + src->raw.len); + if (ret != 0) { return false; } + src = src->next; + } + + // Update the SSL config to use the new CA chain + mbedtls_ssl_conf_ca_chain(&mbed_ctx->conf, &mbed_ctx->ca_chain, nullptr); + return true; +} + +size_t get_ca_certs(ctx_t ctx, std::vector &certs) { + certs.clear(); + if (!ctx) { return 0; } + auto *mbed_ctx = static_cast(ctx); + + // Iterate through the CA chain + mbedtls_x509_crt *cert = &mbed_ctx->ca_chain; + while (cert != nullptr && cert->raw.len > 0) { + // Create a copy of the certificate for the caller + auto *copy = new mbedtls_x509_crt; + mbedtls_x509_crt_init(copy); + int ret = mbedtls_x509_crt_parse_der(copy, cert->raw.p, cert->raw.len); + if (ret == 0) { + certs.push_back(static_cast(copy)); + } else { + mbedtls_x509_crt_free(copy); + delete copy; + } + cert = cert->next; + } + return certs.size(); +} + +std::vector get_ca_names(ctx_t ctx) { + std::vector names; + if (!ctx) { return names; } + auto *mbed_ctx = static_cast(ctx); + + // Iterate through the CA chain + mbedtls_x509_crt *cert = &mbed_ctx->ca_chain; + while (cert != nullptr && cert->raw.len > 0) { + char buf[512]; + int ret = mbedtls_x509_dn_gets(buf, sizeof(buf), &cert->subject); + if (ret > 0) { names.push_back(buf); } + cert = cert->next; + } + return names; +} + +bool update_server_cert(ctx_t ctx, const char *cert_pem, + const char *key_pem, const char *password) { + if (!ctx || !cert_pem || !key_pem) { return false; } + auto *mbed_ctx = static_cast(ctx); + + // Free existing certificate and key + mbedtls_x509_crt_free(&mbed_ctx->own_cert); + mbedtls_pk_free(&mbed_ctx->own_key); + mbedtls_x509_crt_init(&mbed_ctx->own_cert); + mbedtls_pk_init(&mbed_ctx->own_key); + + // Parse certificate PEM + int ret = mbedtls_x509_crt_parse( + &mbed_ctx->own_cert, reinterpret_cast(cert_pem), + strlen(cert_pem) + 1); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Parse private key PEM +#ifdef CPPHTTPLIB_MBEDTLS_V3 + ret = mbedtls_pk_parse_key( + &mbed_ctx->own_key, reinterpret_cast(key_pem), + strlen(key_pem) + 1, + password ? reinterpret_cast(password) : nullptr, + password ? strlen(password) : 0, mbedtls_ctr_drbg_random, + &mbed_ctx->ctr_drbg); +#else + ret = mbedtls_pk_parse_key( + &mbed_ctx->own_key, reinterpret_cast(key_pem), + strlen(key_pem) + 1, + password ? reinterpret_cast(password) : nullptr, + password ? strlen(password) : 0); +#endif + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Configure SSL to use the new certificate and key + ret = mbedtls_ssl_conf_own_cert(&mbed_ctx->conf, &mbed_ctx->own_cert, + &mbed_ctx->own_key); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + return true; +} + +bool update_server_client_ca(ctx_t ctx, const char *ca_pem) { + if (!ctx || !ca_pem) { return false; } + auto *mbed_ctx = static_cast(ctx); + + // Free existing CA chain + mbedtls_x509_crt_free(&mbed_ctx->ca_chain); + mbedtls_x509_crt_init(&mbed_ctx->ca_chain); + + // Parse CA PEM + int ret = mbedtls_x509_crt_parse( + &mbed_ctx->ca_chain, reinterpret_cast(ca_pem), + strlen(ca_pem) + 1); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Update SSL config to use new CA chain + mbedtls_ssl_conf_ca_chain(&mbed_ctx->conf, &mbed_ctx->ca_chain, nullptr); + return true; +} + +bool set_verify_callback(ctx_t ctx, VerifyCallback callback) { + if (!ctx) { return false; } + auto *mbed_ctx = static_cast(ctx); + + impl::get_verify_callback() = std::move(callback); + mbed_ctx->has_verify_callback = + static_cast(impl::get_verify_callback()); + + if (mbed_ctx->has_verify_callback) { + // Set OPTIONAL mode to ensure callback is called even when verification + // is disabled (matching OpenSSL behavior where SSL_VERIFY_PEER is set) + mbedtls_ssl_conf_authmode(&mbed_ctx->conf, MBEDTLS_SSL_VERIFY_OPTIONAL); + mbedtls_ssl_conf_verify(&mbed_ctx->conf, impl::mbedtls_verify_callback, + nullptr); + } else { + mbedtls_ssl_conf_verify(&mbed_ctx->conf, nullptr, nullptr); + } + return true; +} + +long get_verify_error(const_session_t session) { + if (!session) { return -1; } + auto *msession = + static_cast(const_cast(session)); + return static_cast(mbedtls_ssl_get_verify_result(&msession->ssl)); +} + +std::string verify_error_string(long error_code) { + if (error_code == 0) { return ""; } + char buf[256]; + mbedtls_x509_crt_verify_info(buf, sizeof(buf), "", + static_cast(error_code)); + // Remove trailing newline if present + std::string result(buf); + while (!result.empty() && (result.back() == '\n' || result.back() == ' ')) { + result.pop_back(); + } + return result; +} + +} // namespace tls + +#endif // CPPHTTPLIB_MBEDTLS_SUPPORT + +/* + * Group 10: TLS abstraction layer - wolfSSL backend + */ + +/* + * wolfSSL Backend Implementation + */ + +#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT +namespace tls { + +namespace impl { + +// wolfSSL session wrapper +struct WolfSSLSession { + WOLFSSL *ssl = nullptr; + socket_t sock = INVALID_SOCKET; + std::string hostname; // For client: set via set_sni + std::string sni_hostname; // For server: received from client via SNI callback + + WolfSSLSession() = default; + + ~WolfSSLSession() { + if (ssl) { wolfSSL_free(ssl); } + } + + WolfSSLSession(const WolfSSLSession &) = delete; + WolfSSLSession &operator=(const WolfSSLSession &) = delete; +}; + +// Thread-local error code accessor for wolfSSL +uint64_t &wolfssl_last_error() { + static thread_local uint64_t err = 0; + return err; +} + +// Helper to map wolfSSL error to ErrorCode. +// ssl_error is the value from wolfSSL_get_error(). +// raw_ret is the raw return value from the wolfSSL call (for low-level error). +ErrorCode map_wolfssl_error(WOLFSSL *ssl, int ssl_error, + int &out_errno) { + switch (ssl_error) { + case SSL_ERROR_NONE: return ErrorCode::Success; + case SSL_ERROR_WANT_READ: return ErrorCode::WantRead; + case SSL_ERROR_WANT_WRITE: return ErrorCode::WantWrite; + case SSL_ERROR_ZERO_RETURN: return ErrorCode::PeerClosed; + case SSL_ERROR_SYSCALL: out_errno = errno; return ErrorCode::SyscallError; + default: + if (ssl) { + // wolfSSL stores the low-level error code as a negative value. + // DOMAIN_NAME_MISMATCH (-322) indicates hostname verification failure. + int low_err = ssl_error; // wolfSSL_get_error returns the low-level code + if (low_err == DOMAIN_NAME_MISMATCH) { + return ErrorCode::HostnameMismatch; + } + // Check verify result to distinguish cert verification from generic SSL + // errors. + long vr = wolfSSL_get_verify_result(ssl); + if (vr != 0) { return ErrorCode::CertVerifyFailed; } + } + return ErrorCode::Fatal; + } +} + +// WolfSSLContext constructor/destructor implementations +WolfSSLContext::WolfSSLContext() { wolfSSL_Init(); } + +WolfSSLContext::~WolfSSLContext() { + if (ctx) { wolfSSL_CTX_free(ctx); } +} + +// Thread-local storage for SNI captured during handshake +std::string &wolfssl_pending_sni() { + static thread_local std::string sni; + return sni; +} + +// SNI callback for wolfSSL server to capture client's SNI hostname +int wolfssl_sni_callback(WOLFSSL *ssl, int *ret, void *exArg) { + (void)ret; + (void)exArg; + + void *name_data = nullptr; + unsigned short name_len = + wolfSSL_SNI_GetRequest(ssl, WOLFSSL_SNI_HOST_NAME, &name_data); + + if (name_data && name_len > 0) { + wolfssl_pending_sni().assign(static_cast(name_data), + name_len); + } else { + wolfssl_pending_sni().clear(); + } + return 0; // Continue regardless +} + +// wolfSSL verify callback wrapper +int wolfssl_verify_callback(int preverify_ok, + WOLFSSL_X509_STORE_CTX *x509_ctx) { + auto &callback = get_verify_callback(); + if (!callback) { return preverify_ok; } + + WOLFSSL_X509 *cert = wolfSSL_X509_STORE_CTX_get_current_cert(x509_ctx); + int depth = wolfSSL_X509_STORE_CTX_get_error_depth(x509_ctx); + int err = wolfSSL_X509_STORE_CTX_get_error(x509_ctx); + + // Get the WOLFSSL object from the X509_STORE_CTX + WOLFSSL *ssl = static_cast(wolfSSL_X509_STORE_CTX_get_ex_data( + x509_ctx, wolfSSL_get_ex_data_X509_STORE_CTX_idx())); + + VerifyContext verify_ctx; + verify_ctx.session = static_cast(ssl); + verify_ctx.cert = static_cast(cert); + verify_ctx.depth = depth; + verify_ctx.preverify_ok = (preverify_ok != 0); + verify_ctx.error_code = static_cast(err); + + if (err != 0) { + verify_ctx.error_string = wolfSSL_X509_verify_cert_error_string(err); + } else { + verify_ctx.error_string = nullptr; + } + + bool accepted = callback(verify_ctx); + return accepted ? 1 : 0; +} + +void set_wolfssl_password_cb(WOLFSSL_CTX *ctx, const char *password) { + wolfSSL_CTX_set_default_passwd_cb_userdata(ctx, const_cast(password)); + wolfSSL_CTX_set_default_passwd_cb( + ctx, [](char *buf, int size, int /*rwflag*/, void *userdata) -> int { + auto *pwd = static_cast(userdata); + if (!pwd) return 0; + auto len = static_cast(strlen(pwd)); + if (len > size) len = size; + memcpy(buf, pwd, static_cast(len)); + return len; + }); +} + +} // namespace impl + +ctx_t create_client_context() { + auto ctx = new (std::nothrow) impl::WolfSSLContext(); + if (!ctx) { return nullptr; } + + ctx->is_server = false; + + WOLFSSL_METHOD *method = wolfTLSv1_2_client_method(); + if (!method) { + delete ctx; + return nullptr; + } + + ctx->ctx = wolfSSL_CTX_new(method); + if (!ctx->ctx) { + delete ctx; + return nullptr; + } + + // Default: verify peer certificate + wolfSSL_CTX_set_verify(ctx->ctx, SSL_VERIFY_PEER, nullptr); + + return static_cast(ctx); +} + +ctx_t create_server_context() { + auto ctx = new (std::nothrow) impl::WolfSSLContext(); + if (!ctx) { return nullptr; } + + ctx->is_server = true; + + WOLFSSL_METHOD *method = wolfTLSv1_2_server_method(); + if (!method) { + delete ctx; + return nullptr; + } + + ctx->ctx = wolfSSL_CTX_new(method); + if (!ctx->ctx) { + delete ctx; + return nullptr; + } + + // Default: don't verify client + wolfSSL_CTX_set_verify(ctx->ctx, SSL_VERIFY_NONE, nullptr); + + // Enable SNI on server + wolfSSL_CTX_SNI_SetOptions(ctx->ctx, WOLFSSL_SNI_HOST_NAME, + WOLFSSL_SNI_CONTINUE_ON_MISMATCH); + wolfSSL_CTX_set_servername_callback(ctx->ctx, impl::wolfssl_sni_callback); + + return static_cast(ctx); +} + +void free_context(ctx_t ctx) { + if (ctx) { delete static_cast(ctx); } +} + +bool set_min_version(ctx_t ctx, Version version) { + if (!ctx) { return false; } + auto wctx = static_cast(ctx); + + int min_ver = WOLFSSL_TLSV1_2; + if (version >= Version::TLS1_3) { min_ver = WOLFSSL_TLSV1_3; } + + return wolfSSL_CTX_SetMinVersion(wctx->ctx, min_ver) == WOLFSSL_SUCCESS; +} + +bool load_ca_pem(ctx_t ctx, const char *pem, size_t len) { + if (!ctx || !pem) { return false; } + auto wctx = static_cast(ctx); + + int ret = wolfSSL_CTX_load_verify_buffer( + wctx->ctx, reinterpret_cast(pem), + static_cast(len), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + wctx->ca_pem_data_.append(pem, len); + return true; +} + +bool load_ca_file(ctx_t ctx, const char *file_path) { + if (!ctx || !file_path) { return false; } + auto wctx = static_cast(ctx); + + int ret = wolfSSL_CTX_load_verify_locations(wctx->ctx, file_path, nullptr); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + return true; +} + +bool load_ca_dir(ctx_t ctx, const char *dir_path) { + if (!ctx || !dir_path) { return false; } + auto wctx = static_cast(ctx); + + int ret = wolfSSL_CTX_load_verify_locations(wctx->ctx, nullptr, dir_path); + // wolfSSL may fail if the directory doesn't contain properly hashed certs. + // Unlike OpenSSL which lazily loads certs from directories, wolfSSL scans + // immediately. Return true even on failure since the CA file may have + // already been loaded, matching OpenSSL's lenient behavior. + (void)ret; + return true; +} + +bool load_system_certs(ctx_t ctx) { + if (!ctx) { return false; } + auto wctx = static_cast(ctx); + bool loaded = false; + +#ifdef _WIN32 + loaded = impl::enumerate_windows_system_certs( + [&](const unsigned char *data, size_t len) { + return wolfSSL_CTX_load_verify_buffer(wctx->ctx, data, + static_cast(len), + SSL_FILETYPE_ASN1) == SSL_SUCCESS; + }); +#elif defined(__APPLE__) && defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) + loaded = impl::enumerate_macos_keychain_certs( + [&](const unsigned char *data, size_t len) { + return wolfSSL_CTX_load_verify_buffer(wctx->ctx, data, + static_cast(len), + SSL_FILETYPE_ASN1) == SSL_SUCCESS; + }); +#else + for (auto path = impl::system_ca_paths(); *path; ++path) { + if (wolfSSL_CTX_load_verify_locations(wctx->ctx, *path, nullptr) == + SSL_SUCCESS) { + loaded = true; + break; + } + } + + if (!loaded) { + for (auto dir = impl::system_ca_dirs(); *dir; ++dir) { + if (wolfSSL_CTX_load_verify_locations(wctx->ctx, nullptr, *dir) == + SSL_SUCCESS) { + loaded = true; + break; + } + } + } +#endif + + return loaded; +} + +bool set_client_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password) { + if (!ctx || !cert || !key) { return false; } + auto wctx = static_cast(ctx); + + // Load certificate + int ret = wolfSSL_CTX_use_certificate_buffer( + wctx->ctx, reinterpret_cast(cert), + static_cast(strlen(cert)), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Set password callback if password is provided + if (password) { impl::set_wolfssl_password_cb(wctx->ctx, password); } + + // Load private key + ret = wolfSSL_CTX_use_PrivateKey_buffer( + wctx->ctx, reinterpret_cast(key), + static_cast(strlen(key)), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Verify that the certificate and private key match + return wolfSSL_CTX_check_private_key(wctx->ctx) == SSL_SUCCESS; +} + +bool set_client_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password) { + if (!ctx || !cert_path || !key_path) { return false; } + auto wctx = static_cast(ctx); + + // Load certificate file + int ret = + wolfSSL_CTX_use_certificate_file(wctx->ctx, cert_path, SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Set password callback if password is provided + if (password) { impl::set_wolfssl_password_cb(wctx->ctx, password); } + + // Load private key file + ret = wolfSSL_CTX_use_PrivateKey_file(wctx->ctx, key_path, SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Verify that the certificate and private key match + return wolfSSL_CTX_check_private_key(wctx->ctx) == SSL_SUCCESS; +} + +void set_verify_client(ctx_t ctx, bool require) { + if (!ctx) { return; } + auto wctx = static_cast(ctx); + wctx->verify_client = require; + if (require) { + wolfSSL_CTX_set_verify( + wctx->ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + wctx->has_verify_callback ? impl::wolfssl_verify_callback : nullptr); + } else { + if (wctx->has_verify_callback) { + wolfSSL_CTX_set_verify(wctx->ctx, SSL_VERIFY_PEER, + impl::wolfssl_verify_callback); + } else { + wolfSSL_CTX_set_verify(wctx->ctx, SSL_VERIFY_NONE, nullptr); + } + } +} + +session_t create_session(ctx_t ctx, socket_t sock) { + if (!ctx || sock == INVALID_SOCKET) { return nullptr; } + auto wctx = static_cast(ctx); + + auto session = new (std::nothrow) impl::WolfSSLSession(); + if (!session) { return nullptr; } + + session->sock = sock; + session->ssl = wolfSSL_new(wctx->ctx); + if (!session->ssl) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + delete session; + return nullptr; + } + + wolfSSL_set_fd(session->ssl, static_cast(sock)); + + return static_cast(session); +} + +void free_session(session_t session) { + if (session) { delete static_cast(session); } +} + +bool set_sni(session_t session, const char *hostname) { + if (!session || !hostname) { return false; } + auto wsession = static_cast(session); + + int ret = wolfSSL_UseSNI(wsession->ssl, WOLFSSL_SNI_HOST_NAME, hostname, + static_cast(strlen(hostname))); + if (ret != WOLFSSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Also set hostname for verification + wolfSSL_check_domain_name(wsession->ssl, hostname); + + wsession->hostname = hostname; + return true; +} + +bool set_hostname(session_t session, const char *hostname) { + // In wolfSSL, set_hostname also sets up hostname verification + return set_sni(session, hostname); +} + +TlsError connect(session_t session) { + TlsError err; + if (!session) { + err.code = ErrorCode::Fatal; + return err; + } + + auto wsession = static_cast(session); + int ret = wolfSSL_connect(wsession->ssl); + + if (ret == SSL_SUCCESS) { + err.code = ErrorCode::Success; + } else { + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + err.code = impl::map_wolfssl_error(wsession->ssl, ssl_error, err.sys_errno); + err.backend_code = static_cast(ssl_error); + impl::wolfssl_last_error() = err.backend_code; + } + + return err; +} + +TlsError accept(session_t session) { + TlsError err; + if (!session) { + err.code = ErrorCode::Fatal; + return err; + } + + auto wsession = static_cast(session); + int ret = wolfSSL_accept(wsession->ssl); + + if (ret == SSL_SUCCESS) { + err.code = ErrorCode::Success; + // Capture SNI from thread-local storage after successful handshake + wsession->sni_hostname = std::move(impl::wolfssl_pending_sni()); + impl::wolfssl_pending_sni().clear(); + } else { + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + err.code = impl::map_wolfssl_error(wsession->ssl, ssl_error, err.sys_errno); + err.backend_code = static_cast(ssl_error); + impl::wolfssl_last_error() = err.backend_code; + } + + return err; +} + +bool connect_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto wsession = static_cast(session); + + // Set socket to non-blocking mode + detail::set_nonblocking(sock, true); + auto cleanup = + detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + int ret; + while ((ret = wolfSSL_connect(wsession->ssl)) != SSL_SUCCESS) { + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + if (ssl_error == SSL_ERROR_WANT_READ) { + if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } else if (ssl_error == SSL_ERROR_WANT_WRITE) { + if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } + + // Error or timeout + if (err) { + err->code = + impl::map_wolfssl_error(wsession->ssl, ssl_error, err->sys_errno); + err->backend_code = static_cast(ssl_error); + } + impl::wolfssl_last_error() = static_cast(ssl_error); + return false; + } + + if (err) { err->code = ErrorCode::Success; } + return true; +} + +bool accept_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto wsession = static_cast(session); + + // Set socket to non-blocking mode + detail::set_nonblocking(sock, true); + auto cleanup = + detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + int ret; + while ((ret = wolfSSL_accept(wsession->ssl)) != SSL_SUCCESS) { + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + if (ssl_error == SSL_ERROR_WANT_READ) { + if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } else if (ssl_error == SSL_ERROR_WANT_WRITE) { + if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } + + // Error or timeout + if (err) { + err->code = + impl::map_wolfssl_error(wsession->ssl, ssl_error, err->sys_errno); + err->backend_code = static_cast(ssl_error); + } + impl::wolfssl_last_error() = static_cast(ssl_error); + return false; + } + + if (err) { err->code = ErrorCode::Success; } + + // Capture SNI from thread-local storage after successful handshake + wsession->sni_hostname = std::move(impl::wolfssl_pending_sni()); + impl::wolfssl_pending_sni().clear(); + + return true; +} + +ssize_t read(session_t session, void *buf, size_t len, TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto wsession = static_cast(session); + int ret = wolfSSL_read(wsession->ssl, buf, static_cast(len)); + + if (ret > 0) { + err.code = ErrorCode::Success; + return static_cast(ret); + } + + if (ret == 0) { + err.code = ErrorCode::PeerClosed; + return 0; + } + + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + err.code = impl::map_wolfssl_error(wsession->ssl, ssl_error, err.sys_errno); + err.backend_code = static_cast(ssl_error); + impl::wolfssl_last_error() = err.backend_code; + return -1; +} + +ssize_t write(session_t session, const void *buf, size_t len, + TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto wsession = static_cast(session); + int ret = wolfSSL_write(wsession->ssl, buf, static_cast(len)); + + if (ret > 0) { + err.code = ErrorCode::Success; + return static_cast(ret); + } + + // wolfSSL_write returns 0 when the peer has sent a close_notify. + // Treat this as an error (return -1) so callers don't spin in a + // write loop adding zero to the offset. + if (ret == 0) { + err.code = ErrorCode::PeerClosed; + return -1; + } + + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + err.code = impl::map_wolfssl_error(wsession->ssl, ssl_error, err.sys_errno); + err.backend_code = static_cast(ssl_error); + impl::wolfssl_last_error() = err.backend_code; + return -1; +} + +int pending(const_session_t session) { + if (!session) { return 0; } + auto wsession = + static_cast(const_cast(session)); + return wolfSSL_pending(wsession->ssl); +} + +void shutdown(session_t session, bool graceful) { + if (!session) { return; } + auto wsession = static_cast(session); + + if (graceful) { + int ret; + int attempts = 0; + while ((ret = wolfSSL_shutdown(wsession->ssl)) != SSL_SUCCESS && + attempts < 3) { + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + if (ssl_error != SSL_ERROR_WANT_READ && + ssl_error != SSL_ERROR_WANT_WRITE) { + break; + } + attempts++; + } + } else { + wolfSSL_shutdown(wsession->ssl); + } +} + +bool is_peer_closed(session_t session, socket_t sock) { + if (!session || sock == INVALID_SOCKET) { return true; } + auto wsession = static_cast(session); + + // Check if there's already decrypted data available + if (wolfSSL_pending(wsession->ssl) > 0) { return false; } + + // Set socket to non-blocking to avoid blocking on read + detail::set_nonblocking(sock, true); + auto cleanup = + detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + // Peek 1 byte to check connection status without consuming data + unsigned char buf; + int ret = wolfSSL_peek(wsession->ssl, &buf, 1); + + // If we got data or WANT_READ (would block), connection is alive + if (ret > 0) { return false; } + + int ssl_error = wolfSSL_get_error(wsession->ssl, ret); + if (ssl_error == SSL_ERROR_WANT_READ) { return false; } + + return ssl_error == SSL_ERROR_ZERO_RETURN || ssl_error == SSL_ERROR_SYSCALL || + ret == 0; +} + +cert_t get_peer_cert(const_session_t session) { + if (!session) { return nullptr; } + auto wsession = + static_cast(const_cast(session)); + + WOLFSSL_X509 *cert = wolfSSL_get_peer_certificate(wsession->ssl); + return static_cast(cert); +} + +void free_cert(cert_t cert) { + if (cert) { wolfSSL_X509_free(static_cast(cert)); } +} + +bool verify_hostname(cert_t cert, const char *hostname) { + if (!cert || !hostname) { return false; } + auto x509 = static_cast(cert); + std::string host_str(hostname); + + // Check if hostname is an IP address + bool is_ip = impl::is_ipv4_address(host_str); + unsigned char ip_bytes[4]; + if (is_ip) { impl::parse_ipv4(host_str, ip_bytes); } + + // Check Subject Alternative Names + auto *san_names = static_cast( + wolfSSL_X509_get_ext_d2i(x509, NID_subject_alt_name, nullptr, nullptr)); + + if (san_names) { + int san_count = wolfSSL_sk_num(san_names); + for (int i = 0; i < san_count; i++) { + auto *names = + static_cast(wolfSSL_sk_value(san_names, i)); + if (!names) continue; + + if (!is_ip && names->type == WOLFSSL_GEN_DNS) { + // DNS name + unsigned char *dns_name = nullptr; + int dns_len = wolfSSL_ASN1_STRING_to_UTF8(&dns_name, names->d.dNSName); + if (dns_name && dns_len > 0) { + std::string san_name(reinterpret_cast(dns_name), + static_cast(dns_len)); + XFREE(dns_name, nullptr, DYNAMIC_TYPE_OPENSSL); + if (detail::match_hostname(san_name, host_str)) { + wolfSSL_sk_free(san_names); + return true; + } + } + } else if (is_ip && names->type == WOLFSSL_GEN_IPADD) { + // IP address + unsigned char *ip_data = wolfSSL_ASN1_STRING_data(names->d.iPAddress); + int ip_len = wolfSSL_ASN1_STRING_length(names->d.iPAddress); + if (ip_data && ip_len == 4 && memcmp(ip_data, ip_bytes, 4) == 0) { + wolfSSL_sk_free(san_names); + return true; + } + } + } + wolfSSL_sk_free(san_names); + } + + // Fallback: Check Common Name (CN) in subject + WOLFSSL_X509_NAME *subject = wolfSSL_X509_get_subject_name(x509); + if (subject) { + char cn[256] = {}; + int cn_len = wolfSSL_X509_NAME_get_text_by_NID(subject, NID_commonName, cn, + sizeof(cn)); + if (cn_len > 0) { + std::string cn_str(cn, static_cast(cn_len)); + if (detail::match_hostname(cn_str, host_str)) { return true; } + } + } + + return false; +} + +uint64_t hostname_mismatch_code() { + return static_cast(DOMAIN_NAME_MISMATCH); +} + +long get_verify_result(const_session_t session) { + if (!session) { return -1; } + auto wsession = + static_cast(const_cast(session)); + long result = wolfSSL_get_verify_result(wsession->ssl); + return result; +} + +std::string get_cert_subject_cn(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + WOLFSSL_X509_NAME *subject = wolfSSL_X509_get_subject_name(x509); + if (!subject) return ""; + + char cn[256] = {}; + int cn_len = wolfSSL_X509_NAME_get_text_by_NID(subject, NID_commonName, cn, + sizeof(cn)); + if (cn_len <= 0) return ""; + return std::string(cn, static_cast(cn_len)); +} + +std::string get_cert_issuer_name(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + WOLFSSL_X509_NAME *issuer = wolfSSL_X509_get_issuer_name(x509); + if (!issuer) return ""; + + char *name_str = wolfSSL_X509_NAME_oneline(issuer, nullptr, 0); + if (!name_str) return ""; + + std::string result(name_str); + XFREE(name_str, nullptr, DYNAMIC_TYPE_OPENSSL); + return result; +} + +bool get_cert_sans(cert_t cert, std::vector &sans) { + sans.clear(); + if (!cert) return false; + auto x509 = static_cast(cert); + + auto *san_names = static_cast( + wolfSSL_X509_get_ext_d2i(x509, NID_subject_alt_name, nullptr, nullptr)); + if (!san_names) return true; // No SANs is not an error + + int count = wolfSSL_sk_num(san_names); + for (int i = 0; i < count; i++) { + auto *name = + static_cast(wolfSSL_sk_value(san_names, i)); + if (!name) continue; + + SanEntry entry; + switch (name->type) { + case WOLFSSL_GEN_DNS: { + entry.type = SanType::DNS; + unsigned char *dns_name = nullptr; + int dns_len = wolfSSL_ASN1_STRING_to_UTF8(&dns_name, name->d.dNSName); + if (dns_name && dns_len > 0) { + entry.value = std::string(reinterpret_cast(dns_name), + static_cast(dns_len)); + XFREE(dns_name, nullptr, DYNAMIC_TYPE_OPENSSL); + } + break; + } + case WOLFSSL_GEN_IPADD: { + entry.type = SanType::IP; + unsigned char *ip_data = wolfSSL_ASN1_STRING_data(name->d.iPAddress); + int ip_len = wolfSSL_ASN1_STRING_length(name->d.iPAddress); + if (ip_data && ip_len == 4) { + char buf[16]; + snprintf(buf, sizeof(buf), "%d.%d.%d.%d", ip_data[0], ip_data[1], + ip_data[2], ip_data[3]); + entry.value = buf; + } else if (ip_data && ip_len == 16) { + char buf[64]; + snprintf(buf, sizeof(buf), + "%02x%02x:%02x%02x:%02x%02x:%02x%02x:" + "%02x%02x:%02x%02x:%02x%02x:%02x%02x", + ip_data[0], ip_data[1], ip_data[2], ip_data[3], ip_data[4], + ip_data[5], ip_data[6], ip_data[7], ip_data[8], ip_data[9], + ip_data[10], ip_data[11], ip_data[12], ip_data[13], + ip_data[14], ip_data[15]); + entry.value = buf; + } + break; + } + case WOLFSSL_GEN_EMAIL: + entry.type = SanType::EMAIL; + { + unsigned char *email = nullptr; + int email_len = wolfSSL_ASN1_STRING_to_UTF8(&email, name->d.rfc822Name); + if (email && email_len > 0) { + entry.value = std::string(reinterpret_cast(email), + static_cast(email_len)); + XFREE(email, nullptr, DYNAMIC_TYPE_OPENSSL); + } + } + break; + case WOLFSSL_GEN_URI: + entry.type = SanType::URI; + { + unsigned char *uri = nullptr; + int uri_len = wolfSSL_ASN1_STRING_to_UTF8( + &uri, name->d.uniformResourceIdentifier); + if (uri && uri_len > 0) { + entry.value = std::string(reinterpret_cast(uri), + static_cast(uri_len)); + XFREE(uri, nullptr, DYNAMIC_TYPE_OPENSSL); + } + } + break; + default: entry.type = SanType::OTHER; break; + } + + if (!entry.value.empty()) { sans.push_back(std::move(entry)); } + } + wolfSSL_sk_free(san_names); + return true; +} + +bool get_cert_validity(cert_t cert, time_t ¬_before, + time_t ¬_after) { + if (!cert) return false; + auto x509 = static_cast(cert); + + const WOLFSSL_ASN1_TIME *nb = wolfSSL_X509_get_notBefore(x509); + const WOLFSSL_ASN1_TIME *na = wolfSSL_X509_get_notAfter(x509); + + if (!nb || !na) return false; + + // wolfSSL_ASN1_TIME_to_tm is available + struct tm tm_nb = {}, tm_na = {}; + if (wolfSSL_ASN1_TIME_to_tm(nb, &tm_nb) != WOLFSSL_SUCCESS) return false; + if (wolfSSL_ASN1_TIME_to_tm(na, &tm_na) != WOLFSSL_SUCCESS) return false; + +#ifdef _WIN32 + not_before = _mkgmtime(&tm_nb); + not_after = _mkgmtime(&tm_na); +#else + not_before = timegm(&tm_nb); + not_after = timegm(&tm_na); +#endif + return true; +} + +std::string get_cert_serial(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + WOLFSSL_ASN1_INTEGER *serial_asn1 = wolfSSL_X509_get_serialNumber(x509); + if (!serial_asn1) return ""; + + // Get the serial number data + int len = serial_asn1->length; + unsigned char *data = serial_asn1->data; + if (!data || len <= 0) return ""; + + std::string result; + result.reserve(static_cast(len) * 2); + for (int i = 0; i < len; i++) { + char hex[3]; + snprintf(hex, sizeof(hex), "%02X", data[i]); + result += hex; + } + return result; +} + +bool get_cert_der(cert_t cert, std::vector &der) { + if (!cert) return false; + auto x509 = static_cast(cert); + + int der_len = 0; + const unsigned char *der_data = wolfSSL_X509_get_der(x509, &der_len); + if (!der_data || der_len <= 0) return false; + + der.assign(der_data, der_data + der_len); + return true; +} + +const char *get_sni(const_session_t session) { + if (!session) return nullptr; + auto wsession = static_cast(session); + + // For server: return SNI received from client during handshake + if (!wsession->sni_hostname.empty()) { + return wsession->sni_hostname.c_str(); + } + + // For client: return the hostname set via set_sni + if (!wsession->hostname.empty()) { return wsession->hostname.c_str(); } + + return nullptr; +} + +uint64_t peek_error() { + return static_cast(wolfSSL_ERR_peek_last_error()); +} + +uint64_t get_error() { + uint64_t err = impl::wolfssl_last_error(); + impl::wolfssl_last_error() = 0; + return err; +} + +std::string error_string(uint64_t code) { + char buf[256]; + wolfSSL_ERR_error_string(static_cast(code), buf); + return std::string(buf); +} + +ca_store_t create_ca_store(const char *pem, size_t len) { + if (!pem || len == 0) { return nullptr; } + // Validate by attempting to load into a temporary ctx + WOLFSSL_CTX *tmp_ctx = wolfSSL_CTX_new(wolfTLSv1_2_client_method()); + if (!tmp_ctx) { return nullptr; } + int ret = wolfSSL_CTX_load_verify_buffer( + tmp_ctx, reinterpret_cast(pem), + static_cast(len), SSL_FILETYPE_PEM); + wolfSSL_CTX_free(tmp_ctx); + if (ret != SSL_SUCCESS) { return nullptr; } + return static_cast( + new impl::WolfSSLCAStore{std::string(pem, len)}); +} + +void free_ca_store(ca_store_t store) { + delete static_cast(store); +} + +bool set_ca_store(ctx_t ctx, ca_store_t store) { + if (!ctx || !store) { return false; } + auto *wctx = static_cast(ctx); + auto *ca = static_cast(store); + int ret = wolfSSL_CTX_load_verify_buffer( + wctx->ctx, reinterpret_cast(ca->pem_data.data()), + static_cast(ca->pem_data.size()), SSL_FILETYPE_PEM); + if (ret == SSL_SUCCESS) { wctx->ca_pem_data_ += ca->pem_data; } + return ret == SSL_SUCCESS; +} + +size_t get_ca_certs(ctx_t ctx, std::vector &certs) { + certs.clear(); + if (!ctx) { return 0; } + auto *wctx = static_cast(ctx); + if (wctx->ca_pem_data_.empty()) { return 0; } + + const std::string &pem = wctx->ca_pem_data_; + const std::string begin_marker = "-----BEGIN CERTIFICATE-----"; + const std::string end_marker = "-----END CERTIFICATE-----"; + size_t pos = 0; + while ((pos = pem.find(begin_marker, pos)) != std::string::npos) { + size_t end_pos = pem.find(end_marker, pos); + if (end_pos == std::string::npos) { break; } + end_pos += end_marker.size(); + std::string cert_pem = pem.substr(pos, end_pos - pos); + WOLFSSL_X509 *x509 = wolfSSL_X509_load_certificate_buffer( + reinterpret_cast(cert_pem.data()), + static_cast(cert_pem.size()), WOLFSSL_FILETYPE_PEM); + if (x509) { certs.push_back(static_cast(x509)); } + pos = end_pos; + } + return certs.size(); +} + +std::vector get_ca_names(ctx_t ctx) { + std::vector names; + if (!ctx) { return names; } + auto *wctx = static_cast(ctx); + if (wctx->ca_pem_data_.empty()) { return names; } + + const std::string &pem = wctx->ca_pem_data_; + const std::string begin_marker = "-----BEGIN CERTIFICATE-----"; + const std::string end_marker = "-----END CERTIFICATE-----"; + size_t pos = 0; + while ((pos = pem.find(begin_marker, pos)) != std::string::npos) { + size_t end_pos = pem.find(end_marker, pos); + if (end_pos == std::string::npos) { break; } + end_pos += end_marker.size(); + std::string cert_pem = pem.substr(pos, end_pos - pos); + WOLFSSL_X509 *x509 = wolfSSL_X509_load_certificate_buffer( + reinterpret_cast(cert_pem.data()), + static_cast(cert_pem.size()), WOLFSSL_FILETYPE_PEM); + if (x509) { + WOLFSSL_X509_NAME *subject = wolfSSL_X509_get_subject_name(x509); + if (subject) { + char *name_str = wolfSSL_X509_NAME_oneline(subject, nullptr, 0); + if (name_str) { + names.push_back(name_str); + XFREE(name_str, nullptr, DYNAMIC_TYPE_OPENSSL); + } + } + wolfSSL_X509_free(x509); + } + pos = end_pos; + } + return names; +} + +bool update_server_cert(ctx_t ctx, const char *cert_pem, + const char *key_pem, const char *password) { + if (!ctx || !cert_pem || !key_pem) { return false; } + auto *wctx = static_cast(ctx); + + // Load new certificate + int ret = wolfSSL_CTX_use_certificate_buffer( + wctx->ctx, reinterpret_cast(cert_pem), + static_cast(strlen(cert_pem)), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + // Set password if provided + if (password) { impl::set_wolfssl_password_cb(wctx->ctx, password); } + + // Load new private key + ret = wolfSSL_CTX_use_PrivateKey_buffer( + wctx->ctx, reinterpret_cast(key_pem), + static_cast(strlen(key_pem)), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + + return true; +} + +bool update_server_client_ca(ctx_t ctx, const char *ca_pem) { + if (!ctx || !ca_pem) { return false; } + auto *wctx = static_cast(ctx); + + int ret = wolfSSL_CTX_load_verify_buffer( + wctx->ctx, reinterpret_cast(ca_pem), + static_cast(strlen(ca_pem)), SSL_FILETYPE_PEM); + if (ret != SSL_SUCCESS) { + impl::wolfssl_last_error() = + static_cast(wolfSSL_ERR_peek_last_error()); + return false; + } + return true; +} + +bool set_verify_callback(ctx_t ctx, VerifyCallback callback) { + if (!ctx) { return false; } + auto *wctx = static_cast(ctx); + + impl::get_verify_callback() = std::move(callback); + wctx->has_verify_callback = static_cast(impl::get_verify_callback()); + + if (wctx->has_verify_callback) { + wolfSSL_CTX_set_verify(wctx->ctx, SSL_VERIFY_PEER, + impl::wolfssl_verify_callback); + } else { + wolfSSL_CTX_set_verify( + wctx->ctx, + wctx->verify_client + ? (SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT) + : SSL_VERIFY_NONE, + nullptr); + } + return true; +} + +long get_verify_error(const_session_t session) { + if (!session) { return -1; } + auto *wsession = + static_cast(const_cast(session)); + return wolfSSL_get_verify_result(wsession->ssl); +} + +std::string verify_error_string(long error_code) { + if (error_code == 0) { return ""; } + const char *str = + wolfSSL_X509_verify_cert_error_string(static_cast(error_code)); + return str ? std::string(str) : std::string(); +} + +} // namespace tls + +#endif // CPPHTTPLIB_WOLFSSL_SUPPORT + +// WebSocket implementation +namespace ws { + +bool WebSocket::send_frame(Opcode op, const char *data, size_t len, + bool fin) { + std::lock_guard lock(write_mutex_); + if (closed_) { return false; } + return detail::write_websocket_frame(strm_, op, data, len, fin, !is_server_); +} + +ReadResult WebSocket::read(std::string &msg) { + while (!closed_) { + Opcode opcode; + std::string payload; + bool fin; + + if (!impl::read_websocket_frame(strm_, opcode, payload, fin, is_server_, + CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH)) { + closed_ = true; + return Fail; + } + + switch (opcode) { + case Opcode::Ping: { + std::lock_guard lock(write_mutex_); + detail::write_websocket_frame(strm_, Opcode::Pong, payload.data(), + payload.size(), true, !is_server_); + continue; + } + case Opcode::Pong: { + std::lock_guard lock(ping_mutex_); + unacked_pings_ = 0; + continue; + } + case Opcode::Close: { + if (!closed_.exchange(true)) { + // Echo close frame back + std::lock_guard lock(write_mutex_); + detail::write_websocket_frame(strm_, Opcode::Close, payload.data(), + payload.size(), true, !is_server_); + } + return Fail; + } + case Opcode::Text: + case Opcode::Binary: { + auto result = opcode == Opcode::Text ? Text : Binary; + msg = std::move(payload); + + // Handle fragmentation + if (!fin) { + while (true) { + Opcode cont_opcode; + std::string cont_payload; + bool cont_fin; + if (!impl::read_websocket_frame( + strm_, cont_opcode, cont_payload, cont_fin, is_server_, + CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH)) { + closed_ = true; + return Fail; + } + if (cont_opcode == Opcode::Ping) { + std::lock_guard lock(write_mutex_); + detail::write_websocket_frame( + strm_, Opcode::Pong, cont_payload.data(), cont_payload.size(), + true, !is_server_); + continue; + } + if (cont_opcode == Opcode::Pong) { + std::lock_guard lock(ping_mutex_); + unacked_pings_ = 0; + continue; + } + if (cont_opcode == Opcode::Close) { + if (!closed_.exchange(true)) { + std::lock_guard lock(write_mutex_); + detail::write_websocket_frame( + strm_, Opcode::Close, cont_payload.data(), + cont_payload.size(), true, !is_server_); + } + return Fail; + } + // RFC 6455: continuation frames must use opcode 0x0 + if (cont_opcode != Opcode::Continuation) { + closed_ = true; + return Fail; + } + msg += cont_payload; + if (msg.size() > CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH) { + closed_ = true; + return Fail; + } + if (cont_fin) { break; } + } + } + // RFC 6455 Section 5.6: text frames must contain valid UTF-8 + if (result == Text && !impl::is_valid_utf8(msg)) { + close(CloseStatus::InvalidPayload, "invalid UTF-8"); + return Fail; + } + return result; + } + default: closed_ = true; return Fail; + } + } + return Fail; +} + +bool WebSocket::send(const std::string &data) { + return send_frame(Opcode::Text, data.data(), data.size()); +} + +bool WebSocket::send(const char *data, size_t len) { + return send_frame(Opcode::Binary, data, len); +} + +void WebSocket::close(CloseStatus status, const std::string &reason) { + if (closed_.exchange(true)) { return; } + ping_cv_.notify_all(); + std::string payload; + auto code = static_cast(status); + payload.push_back(static_cast((code >> 8) & 0xFF)); + payload.push_back(static_cast(code & 0xFF)); + // RFC 6455 Section 5.5: control frame payload must not exceed 125 bytes + // Close frame has 2-byte status code, so reason is limited to 123 bytes + payload += reason.substr(0, 123); + { + std::lock_guard lock(write_mutex_); + detail::write_websocket_frame(strm_, Opcode::Close, payload.data(), + payload.size(), true, !is_server_); + } + + // RFC 6455 Section 7.1.1: after sending a Close frame, wait for the peer's + // Close response before closing the TCP connection. Use a short timeout to + // avoid hanging if the peer doesn't respond. + strm_.set_read_timeout(CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND, 0); + Opcode op; + std::string resp; + bool fin; + while (impl::read_websocket_frame(strm_, op, resp, fin, is_server_, 125)) { + if (op == Opcode::Close) { break; } + } +} + +WebSocket::~WebSocket() { + { + std::lock_guard lock(ping_mutex_); + closed_ = true; + } + ping_cv_.notify_all(); + if (ping_thread_.joinable()) { ping_thread_.join(); } +} + +void WebSocket::start_heartbeat() { + if (ping_interval_sec_ == 0) { return; } + ping_thread_ = std::thread([this]() { + std::unique_lock lock(ping_mutex_); + while (!closed_) { + ping_cv_.wait_for(lock, std::chrono::seconds(ping_interval_sec_)); + if (closed_) { break; } + // If the peer has failed to respond to the previous pings, give up. + // RFC 6455 does not define a pong-timeout mechanism; this is an + // opt-in liveness check controlled by max_missed_pongs_. + if (max_missed_pongs_ > 0 && unacked_pings_ >= max_missed_pongs_) { + lock.unlock(); + close(CloseStatus::GoingAway, "pong timeout"); + return; + } + lock.unlock(); + if (!send_frame(Opcode::Ping, nullptr, 0)) { + lock.lock(); + closed_ = true; + break; + } + lock.lock(); + unacked_pings_++; + } + }); +} + +const Request &WebSocket::request() const { return req_; } + +bool WebSocket::is_open() const { return !closed_; } + +// WebSocketClient implementation +WebSocketClient::WebSocketClient( + const std::string &scheme_host_port_path, const Headers &headers) + : headers_(headers) { + detail::UrlComponents uc; + if (detail::parse_url(scheme_host_port_path, uc) && !uc.scheme.empty() && + !uc.host.empty() && !uc.path.empty()) { + auto &scheme = uc.scheme; + +#ifdef CPPHTTPLIB_SSL_ENABLED + if (scheme != "ws" && scheme != "wss") { +#else + if (scheme != "ws") { +#endif +#ifndef CPPHTTPLIB_NO_EXCEPTIONS + std::string msg = "'" + scheme + "' scheme is not supported."; + throw std::invalid_argument(msg); +#endif + return; + } + + auto is_ssl = scheme == "wss"; + + host_ = std::move(uc.host); + + port_ = is_ssl ? 443 : 80; + if (!uc.port.empty() && !detail::parse_port(uc.port, port_)) { return; } + + path_ = std::move(uc.path); + +#ifdef CPPHTTPLIB_SSL_ENABLED + is_ssl_ = is_ssl; +#else + if (is_ssl) { return; } +#endif + + is_valid_ = true; + } +} + +WebSocketClient::~WebSocketClient() { shutdown_and_close(); } + +bool WebSocketClient::is_valid() const { return is_valid_; } + +void WebSocketClient::shutdown_and_close() { +#ifdef CPPHTTPLIB_SSL_ENABLED + if (is_ssl_) { + if (tls_session_) { + tls::shutdown(tls_session_, true); + tls::free_session(tls_session_); + tls_session_ = nullptr; + } + if (tls_ctx_) { + tls::free_context(tls_ctx_); + tls_ctx_ = nullptr; + } + } +#endif + if (ws_ && ws_->is_open()) { ws_->close(); } + ws_.reset(); + if (sock_ != INVALID_SOCKET) { + detail::shutdown_socket(sock_); + detail::close_socket(sock_); + sock_ = INVALID_SOCKET; + } +} + +bool WebSocketClient::create_stream(std::unique_ptr &strm) { +#ifdef CPPHTTPLIB_SSL_ENABLED + if (is_ssl_) { + if (!detail::setup_client_tls_session( + host_, tls_ctx_, tls_session_, sock_, + server_certificate_verification_, ca_cert_file_path_, + ca_cert_store_, read_timeout_sec_, read_timeout_usec_)) { + return false; + } + + strm = std::unique_ptr(new detail::SSLSocketStream( + sock_, tls_session_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_)); + return true; + } +#endif + strm = std::unique_ptr( + new detail::SocketStream(sock_, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_)); + return true; +} + +bool WebSocketClient::connect() { + if (!is_valid_) { return false; } + shutdown_and_close(); + + Error error; + sock_ = detail::create_client_socket( + host_, std::string(), port_, address_family_, tcp_nodelay_, ipv6_v6only_, + socket_options_, connection_timeout_sec_, connection_timeout_usec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, interface_, error); + + if (sock_ == INVALID_SOCKET) { return false; } + + std::unique_ptr strm; + if (!create_stream(strm)) { + shutdown_and_close(); + return false; + } + + std::string selected_subprotocol; + if (!detail::perform_websocket_handshake(*strm, host_, port_, path_, headers_, + selected_subprotocol)) { + shutdown_and_close(); + return false; + } + subprotocol_ = std::move(selected_subprotocol); + + Request req; + req.method = "GET"; + req.path = path_; + ws_ = std::unique_ptr(new WebSocket(std::move(strm), req, false, + websocket_ping_interval_sec_, + websocket_max_missed_pongs_)); + return true; +} + +ReadResult WebSocketClient::read(std::string &msg) { + if (!ws_) { return Fail; } + return ws_->read(msg); +} + +bool WebSocketClient::send(const std::string &data) { + if (!ws_) { return false; } + return ws_->send(data); +} + +bool WebSocketClient::send(const char *data, size_t len) { + if (!ws_) { return false; } + return ws_->send(data, len); +} + +void WebSocketClient::close(CloseStatus status, + const std::string &reason) { + if (ws_) { ws_->close(status, reason); } +} + +bool WebSocketClient::is_open() const { return ws_ && ws_->is_open(); } + +const std::string &WebSocketClient::subprotocol() const { + return subprotocol_; +} + +void WebSocketClient::set_read_timeout(time_t sec, time_t usec) { + read_timeout_sec_ = sec; + read_timeout_usec_ = usec; +} + +void WebSocketClient::set_write_timeout(time_t sec, time_t usec) { + write_timeout_sec_ = sec; + write_timeout_usec_ = usec; +} + +void WebSocketClient::set_websocket_ping_interval(time_t sec) { + websocket_ping_interval_sec_ = sec; +} + +void WebSocketClient::set_websocket_max_missed_pongs(int count) { + websocket_max_missed_pongs_ = count; +} + +void WebSocketClient::set_tcp_nodelay(bool on) { tcp_nodelay_ = on; } + +void WebSocketClient::set_address_family(int family) { + address_family_ = family; +} + +void WebSocketClient::set_ipv6_v6only(bool on) { ipv6_v6only_ = on; } + +void WebSocketClient::set_socket_options(SocketOptions socket_options) { + socket_options_ = std::move(socket_options); +} + +void WebSocketClient::set_connection_timeout(time_t sec, time_t usec) { + connection_timeout_sec_ = sec; + connection_timeout_usec_ = usec; +} + +void WebSocketClient::set_interface(const std::string &intf) { + interface_ = intf; +} + +#ifdef CPPHTTPLIB_SSL_ENABLED + +void WebSocketClient::set_ca_cert_path(const std::string &path) { + ca_cert_file_path_ = path; +} + +void WebSocketClient::set_ca_cert_store(tls::ca_store_t store) { + ca_cert_store_ = store; +} + +void +WebSocketClient::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} + +#endif // CPPHTTPLIB_SSL_ENABLED + +} // namespace ws + } // namespace httplib diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index 083f7950..af856dd6 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -1,35 +1,15 @@ // // httplib.h // -// Copyright (c) 2025 Yuji Hirose. All rights reserved. +// Copyright (c) 2026 Yuji Hirose. All rights reserved. // MIT License // #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.27.0" -#define CPPHTTPLIB_VERSION_NUM "0x001B00" - -/* - * Platform compatibility check - */ - -#if defined(_WIN32) && !defined(_WIN64) -#if defined(_MSC_VER) -#pragma message( \ - "cpp-httplib doesn't support 32-bit Windows. Please use a 64-bit compiler.") -#else -#warning \ - "cpp-httplib doesn't support 32-bit Windows. Please use a 64-bit compiler." -#endif -#elif defined(__SIZEOF_POINTER__) && __SIZEOF_POINTER__ < 8 -#warning \ - "cpp-httplib doesn't support 32-bit platforms. Please use a 64-bit compiler." -#elif defined(__SIZEOF_SIZE_T__) && __SIZEOF_SIZE_T__ < 8 -#warning \ - "cpp-httplib doesn't support platforms where size_t is less than 64 bits." -#endif +#define CPPHTTPLIB_VERSION "0.46.0" +#define CPPHTTPLIB_VERSION_NUM "0x002e00" #ifdef _WIN32 #if defined(_WIN32_WINNT) && _WIN32_WINNT < 0x0A00 @@ -98,6 +78,22 @@ #define CPPHTTPLIB_CLIENT_MAX_TIMEOUT_MSECOND 0 #endif +#ifndef CPPHTTPLIB_EXPECT_100_THRESHOLD +#define CPPHTTPLIB_EXPECT_100_THRESHOLD 1024 +#endif + +#ifndef CPPHTTPLIB_EXPECT_100_TIMEOUT_MSECOND +#define CPPHTTPLIB_EXPECT_100_TIMEOUT_MSECOND 1000 +#endif + +#ifndef CPPHTTPLIB_WAIT_EARLY_SERVER_RESPONSE_THRESHOLD +#define CPPHTTPLIB_WAIT_EARLY_SERVER_RESPONSE_THRESHOLD (1024 * 1024) +#endif + +#ifndef CPPHTTPLIB_WAIT_EARLY_SERVER_RESPONSE_TIMEOUT_MSECOND +#define CPPHTTPLIB_WAIT_EARLY_SERVER_RESPONSE_TIMEOUT_MSECOND 50 +#endif + #ifndef CPPHTTPLIB_IDLE_INTERVAL_SECOND #define CPPHTTPLIB_IDLE_INTERVAL_SECOND 0 #endif @@ -131,7 +127,7 @@ #endif #ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH -#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits::max)()) +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH (100 * 1024 * 1024) // 100MB #endif #ifndef CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH @@ -169,6 +165,14 @@ : 0)) #endif +#ifndef CPPHTTPLIB_THREAD_POOL_MAX_COUNT +#define CPPHTTPLIB_THREAD_POOL_MAX_COUNT (CPPHTTPLIB_THREAD_POOL_COUNT * 4) +#endif + +#ifndef CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT +#define CPPHTTPLIB_THREAD_POOL_IDLE_TIMEOUT 3 // seconds +#endif + #ifndef CPPHTTPLIB_RECV_FLAGS #define CPPHTTPLIB_RECV_FLAGS 0 #endif @@ -185,6 +189,26 @@ #define CPPHTTPLIB_MAX_LINE_LENGTH 32768 #endif +#ifndef CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH +#define CPPHTTPLIB_WEBSOCKET_MAX_PAYLOAD_LENGTH 16777216 +#endif + +#ifndef CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND +#define CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND 300 +#endif + +#ifndef CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND +#define CPPHTTPLIB_WEBSOCKET_CLOSE_TIMEOUT_SECOND 5 +#endif + +#ifndef CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND +#define CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND 30 +#endif + +#ifndef CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS +#define CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS 0 +#endif + /* * Headers */ @@ -205,7 +229,10 @@ #pragma comment(lib, "ws2_32.lib") +#ifndef _SSIZE_T_DEFINED using ssize_t = __int64; +#define _SSIZE_T_DEFINED +#endif #endif // _MSC_VER #ifndef S_ISREG @@ -257,6 +284,7 @@ using socklen_t = int; #include #ifdef __linux__ #include +#undef _res // Undefine _res macro to avoid conflicts with user code (#2278) #endif #include #include @@ -282,12 +310,15 @@ using socket_t = int; #include #include #include +#include #include #include +#include #include #include #include #include +#include #include #include #include @@ -301,19 +332,48 @@ using socket_t = int; #include #include #include +#include #include #include #include #include +// On macOS with a TLS backend, enable Keychain root certificates by default +// unless the user explicitly opts out. Not enabled on iOS/tvOS/watchOS since +// the SecTrustSettings APIs used to enumerate anchor certificates are macOS +// only; on those platforms the user must provide a CA bundle explicitly. +#if defined(__APPLE__) && defined(__clang__) && \ + !defined(CPPHTTPLIB_DISABLE_MACOSX_AUTOMATIC_ROOT_CERTIFICATES) && \ + (defined(CPPHTTPLIB_OPENSSL_SUPPORT) || \ + defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || \ + defined(CPPHTTPLIB_WOLFSSL_SUPPORT)) +#if TARGET_OS_OSX +#ifndef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN +#define CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN +#endif +#endif +#endif + +#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && \ + defined(__APPLE__) && !TARGET_OS_OSX +#error \ + "CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN is only supported on macOS. On iOS/tvOS/watchOS, supply a CA bundle via set_ca_cert_path()." +#endif + +// On Windows, enable Schannel certificate verification by default +// unless the user explicitly opts out. +#if defined(_WIN32) && \ + !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) +#define CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE +#endif + #if defined(CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO) || \ defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) -#if TARGET_OS_MAC +#if TARGET_OS_MAC && defined(__clang__) #include #include #endif -#endif // CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO or - // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN +#endif #ifdef CPPHTTPLIB_OPENSSL_SUPPORT #ifdef _WIN32 @@ -331,11 +391,11 @@ using socket_t = int; #endif #endif // _WIN32 -#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) -#if TARGET_OS_MAC +#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN +#if TARGET_OS_OSX #include #endif -#endif // CPPHTTPLIB_USE_NON_BLOCKING_GETADDRINFO +#endif #include #include @@ -360,6 +420,81 @@ using socket_t = int; #endif // CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef _WIN32 +#include +#ifdef _MSC_VER +#pragma comment(lib, "crypt32.lib") +#endif +#endif // _WIN32 +#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN +#if TARGET_OS_OSX +#include +#endif +#endif + +// Mbed TLS 3.x API compatibility +#if MBEDTLS_VERSION_MAJOR >= 3 +#define CPPHTTPLIB_MBEDTLS_V3 +#endif + +#endif // CPPHTTPLIB_MBEDTLS_SUPPORT + +#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT +#include + +#include + +// Fallback definitions for older wolfSSL versions (e.g., 5.6.6) +#ifndef WOLFSSL_GEN_EMAIL +#define WOLFSSL_GEN_EMAIL 1 +#endif +#ifndef WOLFSSL_GEN_DNS +#define WOLFSSL_GEN_DNS 2 +#endif +#ifndef WOLFSSL_GEN_URI +#define WOLFSSL_GEN_URI 6 +#endif +#ifndef WOLFSSL_GEN_IPADD +#define WOLFSSL_GEN_IPADD 7 +#endif + +#include +#include +#include +#include +#include +#ifdef _WIN32 +#include +#ifdef _MSC_VER +#pragma comment(lib, "crypt32.lib") +#endif +#endif // _WIN32 +#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN +#if TARGET_OS_OSX +#include +#endif +#endif +#endif // CPPHTTPLIB_WOLFSSL_SUPPORT + +// Define CPPHTTPLIB_SSL_ENABLED if any SSL backend is available +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) || \ + defined(CPPHTTPLIB_MBEDTLS_SUPPORT) || defined(CPPHTTPLIB_WOLFSSL_SUPPORT) +#define CPPHTTPLIB_SSL_ENABLED +#endif + #ifdef CPPHTTPLIB_ZLIB_SUPPORT #include #endif @@ -378,6 +513,10 @@ using socket_t = int; */ namespace httplib { +namespace ws { +class WebSocket; +} // namespace ws + namespace detail { /* @@ -427,6 +566,14 @@ inline unsigned char to_lower(int c) { return table[(unsigned char)(char)c]; } +inline std::string to_lower(const std::string &s) { + std::string result = s; + std::transform( + result.begin(), result.end(), result.begin(), + [](unsigned char c) { return static_cast(to_lower(c)); }); + return result; +} + inline bool equal(const std::string &a, const std::string &b) { return a.size() == b.size() && std::equal(a.begin(), a.end(), b.begin(), [](char ca, char cb) { @@ -490,9 +637,171 @@ private: bool execute_on_destruction; }; +// Simple from_chars implementation for integer and double types (C++17 +// substitute) +template struct from_chars_result { + const char *ptr; + std::errc ec; +}; + +template +inline from_chars_result from_chars(const char *first, const char *last, + T &value, int base = 10) { + value = 0; + const char *p = first; + bool negative = false; + + if (p != last && *p == '-') { + negative = true; + ++p; + } + if (p == last) { return {first, std::errc::invalid_argument}; } + + T result = 0; + for (; p != last; ++p) { + char c = *p; + int digit = -1; + if ('0' <= c && c <= '9') { + digit = c - '0'; + } else if ('a' <= c && c <= 'z') { + digit = c - 'a' + 10; + } else if ('A' <= c && c <= 'Z') { + digit = c - 'A' + 10; + } else { + break; + } + + if (digit < 0 || digit >= base) { break; } + if (result > ((std::numeric_limits::max)() - digit) / base) { + return {p, std::errc::result_out_of_range}; + } + result = result * base + digit; + } + + if (p == first || (negative && p == first + 1)) { + return {first, std::errc::invalid_argument}; + } + + value = negative ? -result : result; + return {p, std::errc{}}; +} + +// from_chars for double (simple wrapper for strtod) +inline from_chars_result from_chars(const char *first, const char *last, + double &value) { + std::string s(first, last); + char *endptr = nullptr; + errno = 0; + value = std::strtod(s.c_str(), &endptr); + if (endptr == s.c_str()) { return {first, std::errc::invalid_argument}; } + if (errno == ERANGE) { + return {first + (endptr - s.c_str()), std::errc::result_out_of_range}; + } + return {first + (endptr - s.c_str()), std::errc{}}; +} + +inline bool parse_port(const char *s, size_t len, int &port) { + int val = 0; + auto r = from_chars(s, s + len, val); + if (r.ec != std::errc{} || val < 1 || val > 65535) { return false; } + port = val; + return true; +} + +inline bool parse_port(const std::string &s, int &port) { + return parse_port(s.data(), s.size(), port); +} + +struct UrlComponents { + std::string scheme; + std::string host; + std::string port; + std::string path; + std::string query; +}; + +inline bool parse_url(const std::string &url, UrlComponents &uc) { + uc = {}; + size_t pos = 0; + + auto sep = url.find("://"); + if (sep != std::string::npos) { + uc.scheme = url.substr(0, sep); + + // Scheme must be [a-z]+ only + if (uc.scheme.empty()) { return false; } + for (auto c : uc.scheme) { + if (c < 'a' || c > 'z') { return false; } + } + + pos = sep + 3; + } else if (url.compare(0, 2, "//") == 0) { + pos = 2; + } + + auto has_authority_prefix = pos > 0; + auto has_authority = has_authority_prefix || (!url.empty() && url[0] != '/' && + url[0] != '?' && url[0] != '#'); + if (has_authority) { + if (pos < url.size() && url[pos] == '[') { + auto close = url.find(']', pos); + if (close == std::string::npos) { return false; } + uc.host = url.substr(pos + 1, close - pos - 1); + + // IPv6 host must be [a-fA-F0-9:]+ only + if (uc.host.empty()) { return false; } + for (auto c : uc.host) { + if (!((c >= 'a' && c <= 'f') || (c >= 'A' && c <= 'F') || + (c >= '0' && c <= '9') || c == ':')) { + return false; + } + } + + pos = close + 1; + } else { + auto end = url.find_first_of(":/?#", pos); + if (end == std::string::npos) { end = url.size(); } + uc.host = url.substr(pos, end - pos); + pos = end; + } + + if (pos < url.size() && url[pos] == ':') { + ++pos; + auto end = url.find_first_of("/?#", pos); + if (end == std::string::npos) { end = url.size(); } + uc.port = url.substr(pos, end - pos); + pos = end; + } + + // Without :// or //, the entire input must be consumed as host[:port]. + // If there is leftover (path, query, etc.), this is not a valid + // host[:port] string โ€” clear and reparse as a plain path. + if (!has_authority_prefix && pos < url.size()) { + uc.host.clear(); + uc.port.clear(); + pos = 0; + } + } + + if (pos < url.size() && url[pos] != '?' && url[pos] != '#') { + auto end = url.find_first_of("?#", pos); + if (end == std::string::npos) { end = url.size(); } + uc.path = url.substr(pos, end - pos); + pos = end; + } + + if (pos < url.size() && url[pos] == '?') { + auto end = url.find('#', pos); + if (end == std::string::npos) { end = url.size(); } + uc.query = url.substr(pos, end - pos); + } + + return true; +} + } // namespace detail -enum SSLVerifierResponse { +enum class SSLVerifierResponse { // no decision has been made, use the built-in certificate verifier NoDecisionMade, // connection certificate is verified and accepted @@ -586,6 +895,91 @@ using Match = std::smatch; using DownloadProgress = std::function; using UploadProgress = std::function; +/* + * detail: type-erased storage used by UserData. + * ABI-stable regardless of C++ standard โ€” always uses this custom + * implementation instead of std::any. + */ +namespace detail { + +using any_type_id = const void *; + +template any_type_id any_typeid() noexcept { + static const char id = 0; + return &id; +} + +struct any_storage { + virtual ~any_storage() = default; + virtual std::unique_ptr clone() const = 0; + virtual any_type_id type_id() const noexcept = 0; +}; + +template struct any_value final : any_storage { + T value; + template explicit any_value(U &&v) : value(std::forward(v)) {} + std::unique_ptr clone() const override { + return std::unique_ptr(new any_value(value)); + } + any_type_id type_id() const noexcept override { return any_typeid(); } +}; + +} // namespace detail + +class UserData { +public: + UserData() = default; + UserData(UserData &&) noexcept = default; + UserData &operator=(UserData &&) noexcept = default; + + UserData(const UserData &o) { + for (const auto &e : o.entries_) { + if (e.second) { entries_[e.first] = e.second->clone(); } + } + } + + UserData &operator=(const UserData &o) { + if (this != &o) { + entries_.clear(); + for (const auto &e : o.entries_) { + if (e.second) { entries_[e.first] = e.second->clone(); } + } + } + return *this; + } + + template void set(const std::string &key, T &&value) { + using D = typename std::decay::type; + entries_[key].reset(new detail::any_value(std::forward(value))); + } + + template T *get(const std::string &key) noexcept { + auto it = entries_.find(key); + if (it == entries_.end() || !it->second) { return nullptr; } + if (it->second->type_id() != detail::any_typeid()) { return nullptr; } + return &static_cast *>(it->second.get())->value; + } + + template const T *get(const std::string &key) const noexcept { + auto it = entries_.find(key); + if (it == entries_.end() || !it->second) { return nullptr; } + if (it->second->type_id() != detail::any_typeid()) { return nullptr; } + return &static_cast *>(it->second.get())->value; + } + + bool has(const std::string &key) const noexcept { + return entries_.find(key) != entries_.end(); + } + + void erase(const std::string &key) { entries_.erase(key); } + + void clear() noexcept { entries_.clear(); } + +private: + std::unordered_map> + entries_; +}; + struct Response; using ResponseHandler = std::function; @@ -653,8 +1047,8 @@ private: protected: std::streamsize xsputn(const char *s, std::streamsize n) override { - sink_.write(s, static_cast(n)); - return n; + if (sink_.write(s, static_cast(n))) { return n; } + return 0; } private: @@ -680,6 +1074,63 @@ struct FormDataProvider { }; using FormDataProviderItems = std::vector; +inline FormDataProvider +make_file_provider(const std::string &name, const std::string &filepath, + const std::string &filename = std::string(), + const std::string &content_type = std::string()) { + FormDataProvider fdp; + fdp.name = name; + fdp.filename = filename.empty() ? filepath : filename; + fdp.content_type = content_type; + fdp.provider = [filepath](size_t offset, DataSink &sink) -> bool { + std::ifstream f(filepath, std::ios::binary); + if (!f) { return false; } + if (offset > 0) { + f.seekg(static_cast(offset)); + if (!f.good()) { + sink.done(); + return true; + } + } + char buf[8192]; + f.read(buf, sizeof(buf)); + auto n = static_cast(f.gcount()); + if (n > 0) { return sink.write(buf, n); } + sink.done(); // EOF + return true; + }; + return fdp; +} + +inline std::pair +make_file_body(const std::string &filepath) { + size_t size = 0; + { + std::ifstream f(filepath, std::ios::binary | std::ios::ate); + if (!f) { return {0, ContentProvider{}}; } + size = static_cast(f.tellg()); + } + + ContentProvider provider = [filepath](size_t offset, size_t length, + DataSink &sink) -> bool { + std::ifstream f(filepath, std::ios::binary); + if (!f) { return false; } + f.seekg(static_cast(offset)); + if (!f.good()) { return false; } + char buf[8192]; + while (length > 0) { + auto to_read = (std::min)(sizeof(buf), length); + f.read(buf, static_cast(to_read)); + auto n = static_cast(f.gcount()); + if (n == 0) { break; } + if (!sink.write(buf, n)) { return false; } + length -= n; + } + return true; + }; + return {size, std::move(provider)}; +} + using ContentReceiverWithProgress = std::function; @@ -713,6 +1164,105 @@ public: using Range = std::pair; using Ranges = std::vector; +#ifdef CPPHTTPLIB_SSL_ENABLED +// TLS abstraction layer - public type definitions and API +namespace tls { + +// Opaque handles (defined as void* for abstraction) +using ctx_t = void *; +using session_t = void *; +using const_session_t = const void *; // For read-only session access +using cert_t = void *; +using ca_store_t = void *; + +// TLS versions +enum class Version { + TLS1_2 = 0x0303, + TLS1_3 = 0x0304, +}; + +// Subject Alternative Names (SAN) entry types +enum class SanType { DNS, IP, EMAIL, URI, OTHER }; + +// SAN entry structure +struct SanEntry { + SanType type; + std::string value; +}; + +// Verification context for certificate verification callback +struct VerifyContext { + session_t session; // TLS session handle + cert_t cert; // Current certificate being verified + int depth; // Certificate chain depth (0 = leaf) + bool preverify_ok; // OpenSSL/Mbed TLS pre-verification result + long error_code; // Backend-specific error code (0 = no error) + const char *error_string; // Human-readable error description + + // Certificate introspection methods + std::string subject_cn() const; + std::string issuer_name() const; + bool check_hostname(const char *hostname) const; + std::vector sans() const; + bool validity(time_t ¬_before, time_t ¬_after) const; + std::string serial() const; +}; + +using VerifyCallback = std::function; + +// TlsError codes for TLS operations (backend-independent) +enum class ErrorCode : int { + Success = 0, + WantRead, // Non-blocking: need to wait for read + WantWrite, // Non-blocking: need to wait for write + PeerClosed, // Peer closed the connection + Fatal, // Unrecoverable error + SyscallError, // System call error (check sys_errno) + CertVerifyFailed, // Certificate verification failed + HostnameMismatch, // Hostname verification failed +}; + +// TLS error information +struct TlsError { + ErrorCode code = ErrorCode::Fatal; + uint64_t backend_code = 0; // OpenSSL: ERR_get_error(), mbedTLS: return value + int sys_errno = 0; // errno when SyscallError + + // Convert verification error code to human-readable string + static std::string verify_error_to_string(long error_code); +}; + +// RAII wrapper for peer certificate +class PeerCert { +public: + PeerCert(); + PeerCert(PeerCert &&other) noexcept; + PeerCert &operator=(PeerCert &&other) noexcept; + ~PeerCert(); + + PeerCert(const PeerCert &) = delete; + PeerCert &operator=(const PeerCert &) = delete; + + explicit operator bool() const; + std::string subject_cn() const; + std::string issuer_name() const; + bool check_hostname(const char *hostname) const; + std::vector sans() const; + bool validity(time_t ¬_before, time_t ¬_after) const; + std::string serial() const; + +private: + explicit PeerCert(cert_t cert); + cert_t cert_ = nullptr; + friend PeerCert get_peer_cert_from_session(const_session_t session); +}; + +// Callback for TLS context setup (used by SSLServer constructor) +using ContextSetupCallback = std::function; + +} // namespace tls +#endif + struct Request { std::string method; std::string path; @@ -742,9 +1292,6 @@ struct Request { ContentReceiverWithProgress content_receiver; DownloadProgress download_progress; UploadProgress upload_progress; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - const SSL *ssl = nullptr; -#endif bool has_header(const std::string &key) const; std::string get_header_value(const std::string &key, const char *def = "", @@ -760,11 +1307,13 @@ struct Request { bool has_param(const std::string &key) const; std::string get_param_value(const std::string &key, size_t id = 0) const; + std::vector get_param_values(const std::string &key) const; size_t get_param_value_count(const std::string &key) const; bool is_multipart_form_data() const; // private members... + bool body_consumed_ = false; size_t redirect_count_ = CPPHTTPLIB_REDIRECT_MAX_COUNT; size_t content_length_ = 0; ContentProvider content_provider_; @@ -772,6 +1321,12 @@ struct Request { size_t authorization_count_ = 0; std::chrono::time_point start_time_ = (std::chrono::steady_clock::time_point::min)(); + +#ifdef CPPHTTPLIB_SSL_ENABLED + tls::const_session_t ssl = nullptr; + tls::PeerCert peer_cert() const; + std::string sni() const; +#endif }; struct Response { @@ -783,6 +1338,10 @@ struct Response { std::string body; std::string location; // Redirect location + // User-defined context โ€” set by pre-routing/pre-request handlers and read + // by route handlers to pass arbitrary data (e.g. decoded auth tokens). + UserData user_data; + bool has_header(const std::string &key) const; std::string get_header_value(const std::string &key, const char *def = "", size_t id = 0) const; @@ -837,6 +1396,50 @@ struct Response { std::string file_content_content_type_; }; +enum class Error { + Success = 0, + Unknown, + Connection, + BindIPAddress, + Read, + Write, + ExceedRedirectCount, + Canceled, + SSLConnection, + SSLLoadingCerts, + SSLServerVerification, + SSLServerHostnameVerification, + UnsupportedMultipartBoundaryChars, + Compression, + ConnectionTimeout, + ProxyConnection, + ConnectionClosed, + Timeout, + ResourceExhaustion, + TooManyFormDataFiles, + ExceedMaxPayloadSize, + ExceedUriMaxLength, + ExceedMaxSocketDescriptorCount, + InvalidRequestLine, + InvalidHTTPMethod, + InvalidHTTPVersion, + InvalidHeaders, + MultipartParsing, + OpenFile, + Listen, + GetSockName, + UnsupportedAddressFamily, + HTTPParsing, + InvalidRangeHeader, + + // For internal use only + SSLPeerCouldBeClosed_, +}; + +std::string to_string(Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + class Stream { public: virtual ~Stream() = default; @@ -844,6 +1447,7 @@ public: virtual bool is_readable() const = 0; virtual bool wait_readable() const = 0; virtual bool wait_writable() const = 0; + virtual bool is_peer_alive() const { return wait_writable(); } virtual ssize_t read(char *ptr, size_t size) = 0; virtual ssize_t write(const char *ptr, size_t size) = 0; @@ -853,8 +1457,18 @@ public: virtual time_t duration() const = 0; + virtual void set_read_timeout(time_t sec, time_t usec = 0) { + (void)sec; + (void)usec; + } + ssize_t write(const char *ptr); ssize_t write(const std::string &s); + + Error get_error() const { return error_; } + +protected: + Error error_ = Error::Success; }; class TaskQueue { @@ -870,83 +1484,30 @@ public: class ThreadPool final : public TaskQueue { public: - explicit ThreadPool(size_t n, size_t mqr = 0) - : shutdown_(false), max_queued_requests_(mqr) { - while (n) { - threads_.emplace_back(worker(*this)); - n--; - } - } - + explicit ThreadPool(size_t n, size_t max_n = 0, size_t mqr = 0); ThreadPool(const ThreadPool &) = delete; ~ThreadPool() override = default; - bool enqueue(std::function fn) override { - { - std::unique_lock lock(mutex_); - if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { - return false; - } - jobs_.push_back(std::move(fn)); - } - - cond_.notify_one(); - return true; - } - - void shutdown() override { - // Stop all worker threads... - { - std::unique_lock lock(mutex_); - shutdown_ = true; - } - - cond_.notify_all(); - - // Join... - for (auto &t : threads_) { - t.join(); - } - } + bool enqueue(std::function fn) override; + void shutdown() override; private: - struct worker { - explicit worker(ThreadPool &pool) : pool_(pool) {} + void worker(bool is_dynamic); + void move_to_finished(std::thread::id id); + void cleanup_finished_threads(); - void operator()() { - for (;;) { - std::function fn; - { - std::unique_lock lock(pool_.mutex_); - - pool_.cond_.wait( - lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); - - if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } - - fn = pool_.jobs_.front(); - pool_.jobs_.pop_front(); - } - - assert(true == static_cast(fn)); - fn(); - } - -#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ - !defined(LIBRESSL_VERSION_NUMBER) - OPENSSL_thread_stop(); -#endif - } - - ThreadPool &pool_; - }; - friend struct worker; - - std::vector threads_; - std::list> jobs_; + size_t base_thread_count_; + size_t max_thread_count_; + size_t max_queued_requests_; + size_t idle_thread_count_; bool shutdown_; - size_t max_queued_requests_ = 0; + + std::list> jobs_; + std::vector threads_; // base threads + std::list dynamic_threads_; // dynamic threads + std::vector + finished_threads_; // exited dynamic threads awaiting join std::condition_variable cond_; std::mutex mutex_; @@ -960,27 +1521,23 @@ using ErrorLogger = std::function; using SocketOptions = std::function; -namespace detail { - -bool set_socket_opt_impl(socket_t sock, int level, int optname, - const void *optval, socklen_t optlen); -bool set_socket_opt(socket_t sock, int level, int optname, int opt); -bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec, - time_t usec); - -} // namespace detail - void default_socket_options(socket_t sock); +bool set_socket_opt(socket_t sock, int level, int optname, int optval); + const char *status_message(int status); +std::string to_string(Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + std::string get_bearer_token_auth(const Request &req); namespace detail { class MatcherBase { public: - MatcherBase(std::string pattern) : pattern_(pattern) {} + MatcherBase(std::string pattern) : pattern_(std::move(pattern)) {} virtual ~MatcherBase() = default; const std::string &pattern() const { return pattern_; } @@ -1050,10 +1607,19 @@ private: std::regex regex_; }; +int close_socket(socket_t sock) noexcept; + ssize_t write_headers(Stream &strm, const Headers &headers); -std::string make_host_and_port_string(const std::string &host, int port, - bool is_ssl); +bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec, + time_t usec); + +size_t get_multipart_content_length(const UploadFormDataItems &items, + const std::string &boundary); + +ContentProvider +make_multipart_content_provider(const UploadFormDataItems &items, + const std::string &boundary); } // namespace detail @@ -1077,6 +1643,11 @@ public: using Expect100ContinueHandler = std::function; + using WebSocketHandler = + std::function; + using SubProtocolSelector = + std::function &protocols)>; + Server(); virtual ~Server(); @@ -1094,6 +1665,10 @@ public: Server &Delete(const std::string &pattern, HandlerWithContentReader handler); Server &Options(const std::string &pattern, Handler handler); + Server &WebSocket(const std::string &pattern, WebSocketHandler handler); + Server &WebSocket(const std::string &pattern, WebSocketHandler handler, + SubProtocolSelector sub_protocol_selector); + bool set_base_dir(const std::string &dir, const std::string &mount_point = std::string()); bool set_mount_point(const std::string &mount_point, const std::string &dir, @@ -1136,6 +1711,9 @@ public: Server &set_keep_alive_max_count(size_t count); Server &set_keep_alive_timeout(time_t sec); + template + Server & + set_keep_alive_timeout(const std::chrono::duration &duration); Server &set_read_timeout(time_t sec, time_t usec = 0); template @@ -1151,6 +1729,13 @@ public: Server &set_payload_max_length(size_t length); + Server &set_websocket_ping_interval(time_t sec); + template + Server &set_websocket_ping_interval( + const std::chrono::duration &duration); + + Server &set_websocket_max_missed_pongs(int count); + bool bind_to_port(const std::string &host, int port, int socket_flags = 0); int bind_to_any_port(const std::string &host, int socket_flags = 0); bool listen_after_bind(); @@ -1159,7 +1744,7 @@ public: bool is_running() const; void wait_until_ready() const; - void stop(); + void stop() noexcept; void decommission(); std::function new_task_queue; @@ -1169,7 +1754,8 @@ protected: int remote_port, const std::string &local_addr, int local_port, bool close_connection, bool &connection_closed, - const std::function &setup_request); + const std::function &setup_request, + bool *websocket_upgraded = nullptr); std::atomic svr_sock_{INVALID_SOCKET}; @@ -1184,6 +1770,9 @@ protected: time_t idle_interval_sec_ = CPPHTTPLIB_IDLE_INTERVAL_SECOND; time_t idle_interval_usec_ = CPPHTTPLIB_IDLE_INTERVAL_USECOND; size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + time_t websocket_ping_interval_sec_ = + CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND; + int websocket_max_missed_pongs_ = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS; private: using Handlers = @@ -1195,6 +1784,14 @@ private: static std::unique_ptr make_matcher(const std::string &pattern); + template + Server &add_handler( + std::vector, H>> &handlers, + const std::string &pattern, H handler) { + handlers.emplace_back(make_matcher(pattern), std::move(handler)); + return *this; + } + Server &set_error_handler_core(HandlerWithResponse handler, std::true_type); Server &set_error_handler_core(Handler handler, std::false_type); @@ -1205,7 +1802,11 @@ private: bool listen_internal(); bool routing(Request &req, Response &res, Stream &strm); - bool handle_file_request(const Request &req, Response &res); + bool handle_file_request(Request &req, Response &res); + bool check_if_not_modified(const Request &req, Response &res, + const std::string &etag, time_t mtime) const; + bool check_if_range(Request &req, const std::string &etag, + time_t mtime) const; bool dispatch_request(Request &req, Response &res, const Handlers &handlers) const; bool dispatch_request_for_content_reader( @@ -1249,6 +1850,7 @@ private: struct MountPointEntry { std::string mount_point; std::string base_dir; + std::string resolved_base_dir; Headers headers; }; std::vector base_dirs_; @@ -1267,6 +1869,14 @@ private: HandlersForContentReader delete_handlers_for_content_reader_; Handlers options_handlers_; + struct WebSocketHandlerEntry { + std::unique_ptr matcher; + WebSocketHandler handler; + SubProtocolSelector sub_protocol_selector; + }; + using WebSocketHandlers = std::vector; + WebSocketHandlers websocket_handlers_; + HandlerWithResponse error_handler_; ExceptionHandler exception_handler_; HandlerWithResponse pre_routing_handler_; @@ -1289,48 +1899,6 @@ private: detail::write_headers; }; -enum class Error { - Success = 0, - Unknown, - Connection, - BindIPAddress, - Read, - Write, - ExceedRedirectCount, - Canceled, - SSLConnection, - SSLLoadingCerts, - SSLServerVerification, - SSLServerHostnameVerification, - UnsupportedMultipartBoundaryChars, - Compression, - ConnectionTimeout, - ProxyConnection, - ResourceExhaustion, - TooManyFormDataFiles, - ExceedMaxPayloadSize, - ExceedUriMaxLength, - ExceedMaxSocketDescriptorCount, - InvalidRequestLine, - InvalidHTTPMethod, - InvalidHTTPVersion, - InvalidHeaders, - MultipartParsing, - OpenFile, - Listen, - GetSockName, - UnsupportedAddressFamily, - HTTPParsing, - InvalidRangeHeader, - - // For internal use only - SSLPeerCouldBeClosed_, -}; - -std::string to_string(Error error); - -std::ostream &operator<<(std::ostream &os, const Error &obj); - class Result { public: Result() = default; @@ -1338,17 +1906,6 @@ public: Headers &&request_headers = Headers{}) : res_(std::move(res)), err_(err), request_headers_(std::move(request_headers)) {} -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - Result(std::unique_ptr &&res, Error err, Headers &&request_headers, - int ssl_error) - : res_(std::move(res)), err_(err), - request_headers_(std::move(request_headers)), ssl_error_(ssl_error) {} - Result(std::unique_ptr &&res, Error err, Headers &&request_headers, - int ssl_error, unsigned long ssl_openssl_error) - : res_(std::move(res)), err_(err), - request_headers_(std::move(request_headers)), ssl_error_(ssl_error), - ssl_openssl_error_(ssl_openssl_error) {} -#endif // Response operator bool() const { return res_ != nullptr; } bool operator==(std::nullptr_t) const { return res_ == nullptr; } @@ -1363,13 +1920,6 @@ public: // Error Error error() const { return err_; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - // SSL Error - int ssl_error() const { return ssl_error_; } - // OpenSSL Error - unsigned long ssl_openssl_error() const { return ssl_openssl_error_; } -#endif - // Request Headers bool has_request_header(const std::string &key) const; std::string get_request_header_value(const std::string &key, @@ -1383,12 +1933,124 @@ private: std::unique_ptr res_; Error err_ = Error::Unknown; Headers request_headers_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + +#ifdef CPPHTTPLIB_SSL_ENABLED +public: + Result(std::unique_ptr &&res, Error err, Headers &&request_headers, + int ssl_error) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)), ssl_error_(ssl_error) {} + Result(std::unique_ptr &&res, Error err, Headers &&request_headers, + int ssl_error, uint64_t ssl_backend_error) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)), ssl_error_(ssl_error), + ssl_backend_error_(ssl_backend_error) {} + + int ssl_error() const { return ssl_error_; } + uint64_t ssl_backend_error() const { return ssl_backend_error_; } + +private: int ssl_error_ = 0; - unsigned long ssl_openssl_error_ = 0; + uint64_t ssl_backend_error_ = 0; #endif }; +struct ClientConnection { + socket_t sock = INVALID_SOCKET; + + bool is_open() const { return sock != INVALID_SOCKET; } + + ClientConnection() = default; + + ~ClientConnection(); + + ClientConnection(const ClientConnection &) = delete; + ClientConnection &operator=(const ClientConnection &) = delete; + + ClientConnection(ClientConnection &&other) noexcept + : sock(other.sock) +#ifdef CPPHTTPLIB_SSL_ENABLED + , + session(other.session) +#endif + { + other.sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_SSL_ENABLED + other.session = nullptr; +#endif + } + + ClientConnection &operator=(ClientConnection &&other) noexcept { + if (this != &other) { + sock = other.sock; + other.sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_SSL_ENABLED + session = other.session; + other.session = nullptr; +#endif + } + return *this; + } + +#ifdef CPPHTTPLIB_SSL_ENABLED + tls::session_t session = nullptr; +#endif +}; + +namespace detail { + +struct ChunkedDecoder; + +struct BodyReader { + Stream *stream = nullptr; + bool has_content_length = false; + size_t content_length = 0; + size_t payload_max_length = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + size_t bytes_read = 0; + bool chunked = false; + bool eof = false; + std::unique_ptr chunked_decoder; + Error last_error = Error::Success; + + ssize_t read(char *buf, size_t len); + bool has_error() const { return last_error != Error::Success; } +}; + +inline ssize_t read_body_content(Stream *stream, BodyReader &br, char *buf, + size_t len) { + (void)stream; + return br.read(buf, len); +} + +class decompressor; + +enum class NoProxyKind { + Wildcard, // "*" + HostnameSuffix, // "example.com" or ".example.com" + IPv4Cidr, // "10.0.0.0/8" (or single IP, treated as /32) + IPv6Cidr, // "fe80::/10" (or single IP, treated as /128) +}; + +// Unified 16-byte buffer holding either a v4 (first 4 bytes) or v6 address. +// Lets one CIDR matcher cover both families. +using IPBytes = std::array; + +struct NoProxyEntry { + NoProxyKind kind = NoProxyKind::Wildcard; + std::string hostname_pattern; // lowercased, leading/trailing dot stripped + IPBytes net{}; + int prefix_bits = 0; +}; + +struct NormalizedTarget { + std::string hostname; // lowercase; brackets and trailing dot removed + bool is_ipv4 = false; + bool is_ipv6 = false; + IPBytes ip{}; +}; + +} // namespace detail + class ClientImpl { public: explicit ClientImpl(const std::string &host); @@ -1403,6 +2065,44 @@ public: virtual bool is_valid() const; + struct StreamHandle { + std::unique_ptr response; + Error error = Error::Success; + + StreamHandle() = default; + StreamHandle(const StreamHandle &) = delete; + StreamHandle &operator=(const StreamHandle &) = delete; + StreamHandle(StreamHandle &&) = default; + StreamHandle &operator=(StreamHandle &&) = default; + ~StreamHandle() = default; + + bool is_valid() const { + return response != nullptr && error == Error::Success; + } + + ssize_t read(char *buf, size_t len); + void parse_trailers_if_needed(); + Error get_read_error() const { return body_reader_.last_error; } + bool has_read_error() const { return body_reader_.has_error(); } + + bool trailers_parsed_ = false; + + private: + friend class ClientImpl; + + ssize_t read_with_decompression(char *buf, size_t len); + + std::unique_ptr connection_; + std::unique_ptr socket_stream_; + Stream *stream_ = nullptr; + detail::BodyReader body_reader_; + + std::unique_ptr decompressor_; + std::string decompress_buffer_; + size_t decompress_offset_ = 0; + size_t decompressed_bytes_read_ = 0; + }; + // clang-format off Result Get(const std::string &path, DownloadProgress progress = nullptr); Result Get(const std::string &path, ContentReceiver content_receiver, DownloadProgress progress = nullptr); @@ -1421,14 +2121,18 @@ public: Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Post(const std::string &path, const Params ¶ms); Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers); Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const Params ¶ms); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1439,14 +2143,18 @@ public: Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Params ¶ms); Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers); Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const Params ¶ms); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1457,14 +2165,18 @@ public: Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Params ¶ms); Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const Params ¶ms); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1484,6 +2196,15 @@ public: Result Options(const std::string &path, const Headers &headers); // clang-format on + // Streaming API: Open a stream for reading response body incrementally + // Socket ownership is transferred to StreamHandle for true streaming + // Supports all HTTP methods (GET, POST, PUT, PATCH, DELETE, etc.) + StreamHandle open_stream(const std::string &method, const std::string &path, + const Params ¶ms = {}, + const Headers &headers = {}, + const std::string &body = {}, + const std::string &content_type = {}); + bool send(Request &req, Response &res, Error &error); Result send(const Request &req); @@ -1526,10 +2247,6 @@ public: void set_basic_auth(const std::string &username, const std::string &password); void set_bearer_token_auth(const std::string &token); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_digest_auth(const std::string &username, - const std::string &password); -#endif void set_keep_alive(bool on); void set_follow_location(bool on); @@ -1540,30 +2257,15 @@ public: void set_decompress(bool on); + void set_payload_max_length(size_t length); + void set_interface(const std::string &intf); void set_proxy(const std::string &host, int port); void set_proxy_basic_auth(const std::string &username, const std::string &password); void set_proxy_bearer_token_auth(const std::string &token); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_proxy_digest_auth(const std::string &username, - const std::string &password); -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_ca_cert_path(const std::string &ca_cert_file_path, - const std::string &ca_cert_dir_path = std::string()); - void set_ca_cert_store(X509_STORE *ca_cert_store); - X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const; -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void enable_server_certificate_verification(bool enabled); - void enable_server_hostname_verification(bool enabled); - void set_server_certificate_verifier( - std::function verifier); -#endif + void set_no_proxy(const std::vector &patterns); void set_logger(Logger logger); void set_error_logger(ErrorLogger error_logger); @@ -1571,25 +2273,37 @@ public: protected: struct Socket { socket_t sock = INVALID_SOCKET; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - SSL *ssl = nullptr; -#endif + + // For Mbed TLS compatibility: start_time for request timeout tracking + std::chrono::time_point start_time_; bool is_open() const { return sock != INVALID_SOCKET; } + +#ifdef CPPHTTPLIB_SSL_ENABLED + tls::session_t ssl = nullptr; +#endif }; virtual bool create_and_connect_socket(Socket &socket, Error &error); + virtual bool ensure_socket_connection(Socket &socket, Error &error); + virtual bool setup_proxy_connection( + Socket &socket, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error); + + bool is_proxy_enabled_for_host(const std::string &host) const; // All of: // shutdown_ssl // shutdown_socket // close_socket - // should ONLY be called when socket_mutex_ is locked. - // Also, shutdown_ssl and close_socket should also NOT be called concurrently - // with a DIFFERENT thread sending requests using that socket. + // disconnect + // should ONLY be called when socket_mutex_ is locked, and only when + // no other thread is using the socket. virtual void shutdown_ssl(Socket &socket, bool shutdown_gracefully); void shutdown_socket(Socket &socket) const; void close_socket(Socket &socket); + void disconnect(bool gracefully); bool process_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error); @@ -1605,7 +2319,6 @@ protected: // Socket endpoint information const std::string host_; const int port_; - const std::string host_and_port_; // Current open socket Socket socket_; @@ -1642,10 +2355,6 @@ protected: std::string basic_auth_username_; std::string basic_auth_password_; std::string bearer_token_auth_token_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::string digest_auth_username_; - std::string digest_auth_password_; -#endif bool keep_alive_ = false; bool follow_location_ = false; @@ -1660,6 +2369,9 @@ protected: bool compress_ = false; bool decompress_ = true; + size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + bool has_payload_max_length_ = false; + std::string interface_; std::string proxy_host_; @@ -1668,42 +2380,28 @@ protected: std::string proxy_basic_auth_username_; std::string proxy_basic_auth_password_; std::string proxy_bearer_token_auth_token_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::string proxy_digest_auth_username_; - std::string proxy_digest_auth_password_; -#endif -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::string ca_cert_file_path_; - std::string ca_cert_dir_path_; + std::vector no_proxy_entries_; - X509_STORE *ca_cert_store_ = nullptr; -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - bool server_certificate_verification_ = true; - bool server_hostname_verification_ = true; - std::function server_certificate_verifier_; -#endif + mutable detail::NormalizedTarget host_normalized_; + mutable bool host_normalized_valid_ = false; mutable std::mutex logger_mutex_; Logger logger_; ErrorLogger error_logger_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - int last_ssl_error_ = 0; - unsigned long last_openssl_error_ = 0; -#endif - private: bool send_(Request &req, Response &res, Error &error); Result send_(Request &&req); socket_t create_client_socket(Error &error) const; - bool read_response_line(Stream &strm, const Request &req, - Response &res) const; + bool read_response_line(Stream &strm, const Request &req, Response &res, + bool skip_100_continue = true) const; bool write_request(Stream &strm, Request &req, bool close_connection, - Error &error); + Error &error, bool skip_body = false); + bool write_request_body(Stream &strm, Request &req, Error &error); + void prepare_default_headers(Request &r, bool for_stream, + const std::string &ct); bool redirect(Request &req, Response &res, Error &error); bool create_redirect_client(const std::string &scheme, const std::string &host, int port, Request &req, @@ -1712,17 +2410,19 @@ private: template void setup_redirect_client(ClientType &client); bool handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error); - std::unique_ptr send_with_content_provider( + std::unique_ptr send_with_content_provider_and_receiver( Request &req, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type, Error &error); - Result send_with_content_provider( + const std::string &content_type, ContentReceiver content_receiver, + Error &error); + Result send_with_content_provider_and_receiver( const std::string &method, const std::string &path, const Headers &headers, const char *body, size_t content_length, ContentProvider content_provider, ContentProviderWithoutLength content_provider_without_length, - const std::string &content_type, UploadProgress progress); + const std::string &content_type, ContentReceiver content_receiver, + UploadProgress progress); ContentProviderWithoutLength get_multipart_content_provider( const std::string &boundary, const UploadFormDataItems &items, const FormDataProviderItems &provider_items) const; @@ -1732,6 +2432,33 @@ private: std::chrono::time_point start_time, std::function callback); virtual bool is_ssl() const; + + void transfer_socket_ownership_to_handle(StreamHandle &handle); + +#ifdef CPPHTTPLIB_SSL_ENABLED +public: + void set_digest_auth(const std::string &username, + const std::string &password); + void set_proxy_digest_auth(const std::string &username, + const std::string &password); + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + +protected: + std::string digest_auth_username_; + std::string digest_auth_password_; + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + bool server_certificate_verification_ = true; + bool server_hostname_verification_ = true; + std::string ca_cert_pem_; // Store CA cert PEM for redirect transfer + int last_ssl_error_ = 0; + uint64_t last_backend_error_ = 0; +#endif }; class Client { @@ -1775,14 +2502,18 @@ public: Result Post(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Post(const std::string &path, const Params ¶ms); Result Post(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers); Result Post(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Post(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, DownloadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const Params ¶ms); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Post(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1793,14 +2524,18 @@ public: Result Put(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Params ¶ms); Result Put(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers); Result Put(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Put(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const Params ¶ms); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Put(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1811,14 +2546,18 @@ public: Result Patch(const std::string &path, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Params ¶ms); Result Patch(const std::string &path, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers); Result Patch(const std::string &path, const Headers &headers, const char *body, size_t content_length, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const std::string &body, const std::string &content_type, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, size_t content_length, ContentProvider content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, UploadProgress progress = nullptr); + Result Patch(const std::string &path, const Headers &headers, ContentProviderWithoutLength content_provider, const std::string &content_type, ContentReceiver content_receiver, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const Params ¶ms); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, UploadProgress progress = nullptr); Result Patch(const std::string &path, const Headers &headers, const UploadFormDataItems &items, const std::string &boundary, UploadProgress progress = nullptr); @@ -1838,6 +2577,16 @@ public: Result Options(const std::string &path, const Headers &headers); // clang-format on + // Streaming API: Open a stream for reading response body incrementally + // Socket ownership is transferred to StreamHandle for true streaming + // Supports all HTTP methods (GET, POST, PUT, PATCH, DELETE, etc.) + ClientImpl::StreamHandle open_stream(const std::string &method, + const std::string &path, + const Params ¶ms = {}, + const Headers &headers = {}, + const std::string &body = {}, + const std::string &content_type = {}); + bool send(Request &req, Response &res, Error &error); Result send(const Request &req); @@ -1879,64 +2628,62 @@ public: void set_basic_auth(const std::string &username, const std::string &password); void set_bearer_token_auth(const std::string &token); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_digest_auth(const std::string &username, - const std::string &password); -#endif void set_keep_alive(bool on); void set_follow_location(bool on); void set_path_encode(bool on); - void set_url_encode(bool on); void set_compress(bool on); void set_decompress(bool on); + void set_payload_max_length(size_t length); + void set_interface(const std::string &intf); void set_proxy(const std::string &host, int port); void set_proxy_basic_auth(const std::string &username, const std::string &password); void set_proxy_bearer_token_auth(const std::string &token); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_proxy_digest_auth(const std::string &username, - const std::string &password); -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void enable_server_certificate_verification(bool enabled); - void enable_server_hostname_verification(bool enabled); - void set_server_certificate_verifier( - std::function verifier); -#endif - + void set_no_proxy(const std::vector &patterns); void set_logger(Logger logger); void set_error_logger(ErrorLogger error_logger); - // SSL -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_ca_cert_path(const std::string &ca_cert_file_path, - const std::string &ca_cert_dir_path = std::string()); - - void set_ca_cert_store(X509_STORE *ca_cert_store); - void load_ca_cert_store(const char *ca_cert, std::size_t size); - - long get_openssl_verify_result() const; - - SSL_CTX *ssl_context() const; -#endif - private: std::unique_ptr cli_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED +public: + void set_digest_auth(const std::string &username, + const std::string &password); + void set_proxy_digest_auth(const std::string &username, + const std::string &password); + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + + void set_ca_cert_store(tls::ca_store_t ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + void set_server_certificate_verifier(tls::VerifyCallback verifier); + + void set_session_verifier( + std::function verifier); + + tls::ctx_t tls_context() const; + +#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE + void enable_windows_certificate_verification(bool enabled); +#endif + +private: bool is_ssl_ = false; #endif }; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED class SSLServer : public Server { public: SSLServer(const char *cert_path, const char *private_key_path, @@ -1944,29 +2691,38 @@ public: const char *client_ca_cert_dir_path = nullptr, const char *private_key_password = nullptr); - SSLServer(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store = nullptr); + struct PemMemory { + const char *cert_pem; + size_t cert_pem_len; + const char *key_pem; + size_t key_pem_len; + const char *client_ca_pem; + size_t client_ca_pem_len; + const char *private_key_password; + }; + explicit SSLServer(const PemMemory &pem); - SSLServer( - const std::function &setup_ssl_ctx_callback); + // The callback receives the ctx_t handle which can be cast to the + // appropriate backend type (SSL_CTX* for OpenSSL, + // tls::impl::MbedTlsContext* for Mbed TLS) + explicit SSLServer(const tls::ContextSetupCallback &setup_callback); ~SSLServer() override; bool is_valid() const override; - SSL_CTX *ssl_context() const; + bool update_certs_pem(const char *cert_pem, const char *key_pem, + const char *client_ca_pem = nullptr, + const char *password = nullptr); - void update_certs(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store = nullptr); + tls::ctx_t tls_context() const { return ctx_; } int ssl_last_error() const { return last_ssl_error_; } private: bool process_and_close_socket(socket_t sock) override; - STACK_OF(X509_NAME) * extract_ca_names_from_x509_store(X509_STORE *store); - - SSL_CTX *ctx_; + tls::ctx_t ctx_ = nullptr; std::mutex ctx_mutex_; int last_ssl_error_ = 0; @@ -1983,23 +2739,37 @@ public: const std::string &client_key_path, const std::string &private_key_password = std::string()); - explicit SSLClient(const std::string &host, int port, X509 *client_cert, - EVP_PKEY *client_key, - const std::string &private_key_password = std::string()); + struct PemMemory { + const char *cert_pem; + size_t cert_pem_len; + const char *key_pem; + size_t key_pem_len; + const char *private_key_password; + }; + explicit SSLClient(const std::string &host, int port, const PemMemory &pem); ~SSLClient() override; bool is_valid() const override; - void set_ca_cert_store(X509_STORE *ca_cert_store); + void set_ca_cert_store(tls::ca_store_t ca_cert_store); void load_ca_cert_store(const char *ca_cert, std::size_t size); - long get_openssl_verify_result() const; + void set_server_certificate_verifier(tls::VerifyCallback verifier); - SSL_CTX *ssl_context() const; + // Post-handshake session verifier (backend-independent) + void set_session_verifier( + std::function verifier); + + tls::ctx_t tls_context() const { return ctx_; } + +#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE + void enable_windows_certificate_verification(bool enabled); +#endif private: bool create_and_connect_socket(Socket &socket, Error &error) override; + bool ensure_socket_connection(Socket &socket, Error &error) override; void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override; void shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully); @@ -2009,34 +2779,43 @@ private: std::function callback) override; bool is_ssl() const override; + bool setup_proxy_connection( + Socket &socket, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error) override; bool connect_with_proxy( Socket &sock, std::chrono::time_point start_time, Response &res, bool &success, Error &error); bool initialize_ssl(Socket &socket, Error &error); + void init_ctx(); + void reset_ctx_on_error(); + bool load_certs(); - bool verify_host(X509 *server_cert) const; - bool verify_host_with_subject_alt_name(X509 *server_cert) const; - bool verify_host_with_common_name(X509 *server_cert) const; - bool check_host_name(const char *pattern, size_t pattern_len) const; - - SSL_CTX *ctx_; + tls::ctx_t ctx_ = nullptr; std::mutex ctx_mutex_; std::once_flag initialize_cert_; - std::vector host_components_; - long verify_result_ = 0; - friend class ClientImpl; -}; + std::function session_verifier_; + +#ifdef CPPHTTPLIB_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE + bool enable_windows_cert_verification_ = true; #endif -/* - * Implementation of template methods. - */ + friend class ClientImpl; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +private: + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; +#endif +}; +#endif // CPPHTTPLIB_SSL_ENABLED namespace detail { @@ -2068,7 +2847,7 @@ inline size_t get_header_value_u64(const Headers &headers, std::advance(it, static_cast(id)); if (it != rng.second) { if (is_numeric(it->second)) { - return std::strtoull(it->second.data(), nullptr, 10); + return static_cast(std::strtoull(it->second.data(), nullptr, 10)); } else { is_invalid_value = true; } @@ -2085,142 +2864,6 @@ inline size_t get_header_value_u64(const Headers &headers, } // namespace detail -inline size_t Request::get_header_value_u64(const std::string &key, size_t def, - size_t id) const { - return detail::get_header_value_u64(headers, key, def, id); -} - -inline size_t Response::get_header_value_u64(const std::string &key, size_t def, - size_t id) const { - return detail::get_header_value_u64(headers, key, def, id); -} - -namespace detail { - -inline bool set_socket_opt_impl(socket_t sock, int level, int optname, - const void *optval, socklen_t optlen) { - return setsockopt(sock, level, optname, -#ifdef _WIN32 - reinterpret_cast(optval), -#else - optval, -#endif - optlen) == 0; -} - -inline bool set_socket_opt(socket_t sock, int level, int optname, int optval) { - return set_socket_opt_impl(sock, level, optname, &optval, sizeof(optval)); -} - -inline bool set_socket_opt_time(socket_t sock, int level, int optname, - time_t sec, time_t usec) { -#ifdef _WIN32 - auto timeout = static_cast(sec * 1000 + usec / 1000); -#else - timeval timeout; - timeout.tv_sec = static_cast(sec); - timeout.tv_usec = static_cast(usec); -#endif - return set_socket_opt_impl(sock, level, optname, &timeout, sizeof(timeout)); -} - -} // namespace detail - -inline void default_socket_options(socket_t sock) { - detail::set_socket_opt(sock, SOL_SOCKET, -#ifdef SO_REUSEPORT - SO_REUSEPORT, -#else - SO_REUSEADDR, -#endif - 1); -} - -inline const char *status_message(int status) { - switch (status) { - case StatusCode::Continue_100: return "Continue"; - case StatusCode::SwitchingProtocol_101: return "Switching Protocol"; - case StatusCode::Processing_102: return "Processing"; - case StatusCode::EarlyHints_103: return "Early Hints"; - case StatusCode::OK_200: return "OK"; - case StatusCode::Created_201: return "Created"; - case StatusCode::Accepted_202: return "Accepted"; - case StatusCode::NonAuthoritativeInformation_203: - return "Non-Authoritative Information"; - case StatusCode::NoContent_204: return "No Content"; - case StatusCode::ResetContent_205: return "Reset Content"; - case StatusCode::PartialContent_206: return "Partial Content"; - case StatusCode::MultiStatus_207: return "Multi-Status"; - case StatusCode::AlreadyReported_208: return "Already Reported"; - case StatusCode::IMUsed_226: return "IM Used"; - case StatusCode::MultipleChoices_300: return "Multiple Choices"; - case StatusCode::MovedPermanently_301: return "Moved Permanently"; - case StatusCode::Found_302: return "Found"; - case StatusCode::SeeOther_303: return "See Other"; - case StatusCode::NotModified_304: return "Not Modified"; - case StatusCode::UseProxy_305: return "Use Proxy"; - case StatusCode::unused_306: return "unused"; - case StatusCode::TemporaryRedirect_307: return "Temporary Redirect"; - case StatusCode::PermanentRedirect_308: return "Permanent Redirect"; - case StatusCode::BadRequest_400: return "Bad Request"; - case StatusCode::Unauthorized_401: return "Unauthorized"; - case StatusCode::PaymentRequired_402: return "Payment Required"; - case StatusCode::Forbidden_403: return "Forbidden"; - case StatusCode::NotFound_404: return "Not Found"; - case StatusCode::MethodNotAllowed_405: return "Method Not Allowed"; - case StatusCode::NotAcceptable_406: return "Not Acceptable"; - case StatusCode::ProxyAuthenticationRequired_407: - return "Proxy Authentication Required"; - case StatusCode::RequestTimeout_408: return "Request Timeout"; - case StatusCode::Conflict_409: return "Conflict"; - case StatusCode::Gone_410: return "Gone"; - case StatusCode::LengthRequired_411: return "Length Required"; - case StatusCode::PreconditionFailed_412: return "Precondition Failed"; - case StatusCode::PayloadTooLarge_413: return "Payload Too Large"; - case StatusCode::UriTooLong_414: return "URI Too Long"; - case StatusCode::UnsupportedMediaType_415: return "Unsupported Media Type"; - case StatusCode::RangeNotSatisfiable_416: return "Range Not Satisfiable"; - case StatusCode::ExpectationFailed_417: return "Expectation Failed"; - case StatusCode::ImATeapot_418: return "I'm a teapot"; - case StatusCode::MisdirectedRequest_421: return "Misdirected Request"; - case StatusCode::UnprocessableContent_422: return "Unprocessable Content"; - case StatusCode::Locked_423: return "Locked"; - case StatusCode::FailedDependency_424: return "Failed Dependency"; - case StatusCode::TooEarly_425: return "Too Early"; - case StatusCode::UpgradeRequired_426: return "Upgrade Required"; - case StatusCode::PreconditionRequired_428: return "Precondition Required"; - case StatusCode::TooManyRequests_429: return "Too Many Requests"; - case StatusCode::RequestHeaderFieldsTooLarge_431: - return "Request Header Fields Too Large"; - case StatusCode::UnavailableForLegalReasons_451: - return "Unavailable For Legal Reasons"; - case StatusCode::NotImplemented_501: return "Not Implemented"; - case StatusCode::BadGateway_502: return "Bad Gateway"; - case StatusCode::ServiceUnavailable_503: return "Service Unavailable"; - case StatusCode::GatewayTimeout_504: return "Gateway Timeout"; - case StatusCode::HttpVersionNotSupported_505: - return "HTTP Version Not Supported"; - case StatusCode::VariantAlsoNegotiates_506: return "Variant Also Negotiates"; - case StatusCode::InsufficientStorage_507: return "Insufficient Storage"; - case StatusCode::LoopDetected_508: return "Loop Detected"; - case StatusCode::NotExtended_510: return "Not Extended"; - case StatusCode::NetworkAuthenticationRequired_511: - return "Network Authentication Required"; - - default: - case StatusCode::InternalServerError_500: return "Internal Server Error"; - } -} - -inline std::string get_bearer_token_auth(const Request &req) { - if (req.has_header("Authorization")) { - constexpr auto bearer_header_prefix_len = detail::str_len("Bearer "); - return req.get_header_value("Authorization") - .substr(bearer_header_prefix_len); - } - return ""; -} - template inline Server & Server::set_read_timeout(const std::chrono::duration &duration) { @@ -2245,61 +2888,6 @@ Server::set_idle_interval(const std::chrono::duration &duration) { return *this; } -inline std::string to_string(const Error error) { - switch (error) { - case Error::Success: return "Success (no error)"; - case Error::Unknown: return "Unknown"; - case Error::Connection: return "Could not establish connection"; - case Error::BindIPAddress: return "Failed to bind IP address"; - case Error::Read: return "Failed to read connection"; - case Error::Write: return "Failed to write connection"; - case Error::ExceedRedirectCount: return "Maximum redirect count exceeded"; - case Error::Canceled: return "Connection handling canceled"; - case Error::SSLConnection: return "SSL connection failed"; - case Error::SSLLoadingCerts: return "SSL certificate loading failed"; - case Error::SSLServerVerification: return "SSL server verification failed"; - case Error::SSLServerHostnameVerification: - return "SSL server hostname verification failed"; - case Error::UnsupportedMultipartBoundaryChars: - return "Unsupported HTTP multipart boundary characters"; - case Error::Compression: return "Compression failed"; - case Error::ConnectionTimeout: return "Connection timed out"; - case Error::ProxyConnection: return "Proxy connection failed"; - case Error::ResourceExhaustion: return "Resource exhaustion"; - case Error::TooManyFormDataFiles: return "Too many form data files"; - case Error::ExceedMaxPayloadSize: return "Exceeded maximum payload size"; - case Error::ExceedUriMaxLength: return "Exceeded maximum URI length"; - case Error::ExceedMaxSocketDescriptorCount: - return "Exceeded maximum socket descriptor count"; - case Error::InvalidRequestLine: return "Invalid request line"; - case Error::InvalidHTTPMethod: return "Invalid HTTP method"; - case Error::InvalidHTTPVersion: return "Invalid HTTP version"; - case Error::InvalidHeaders: return "Invalid headers"; - case Error::MultipartParsing: return "Multipart parsing failed"; - case Error::OpenFile: return "Failed to open file"; - case Error::Listen: return "Failed to listen on socket"; - case Error::GetSockName: return "Failed to get socket name"; - case Error::UnsupportedAddressFamily: return "Unsupported address family"; - case Error::HTTPParsing: return "HTTP parsing failed"; - case Error::InvalidRangeHeader: return "Invalid Range header"; - default: break; - } - - return "Invalid"; -} - -inline std::ostream &operator<<(std::ostream &os, const Error &obj) { - os << to_string(obj); - os << " (" << static_cast::type>(obj) << ')'; - return os; -} - -inline size_t Result::get_request_header_value_u64(const std::string &key, - size_t def, - size_t id) const { - return detail::get_header_value_u64(request_headers_, key, def, id); -} - template inline void ClientImpl::set_connection_timeout( const std::chrono::duration &duration) { @@ -2381,6 +2969,8 @@ std::string encode_query_component(const std::string &component, std::string decode_query_component(const std::string &component, bool plus_as_space = true); +std::string sanitize_filename(const std::string &filename); + std::string append_query_params(const std::string &path, const Params ¶ms); std::pair make_range_header(const Ranges &ranges); @@ -2394,16 +2984,20 @@ namespace detail { #if defined(_WIN32) inline std::wstring u8string_to_wstring(const char *s) { - std::wstring ws; + if (!s) { return std::wstring(); } + auto len = static_cast(strlen(s)); + if (!len) { return std::wstring(); } + auto wlen = ::MultiByteToWideChar(CP_UTF8, 0, s, len, nullptr, 0); - if (wlen > 0) { - ws.resize(wlen); - wlen = ::MultiByteToWideChar( - CP_UTF8, 0, s, len, - const_cast(reinterpret_cast(ws.data())), wlen); - if (wlen != static_cast(ws.size())) { ws.clear(); } - } + if (!wlen) { return std::wstring(); } + + std::wstring ws; + ws.resize(wlen); + wlen = ::MultiByteToWideChar( + CP_UTF8, 0, s, len, + const_cast(reinterpret_cast(ws.data())), wlen); + if (wlen != static_cast(ws.size())) { ws.clear(); } return ws; } #endif @@ -2412,6 +3006,8 @@ struct FileStat { FileStat(const std::string &path); bool is_file() const; bool is_dir() const; + time_t mtime() const; + size_t size() const; private: #if defined(_WIN32) @@ -2422,6 +3018,9 @@ private: int ret_ = -1; }; +std::string make_host_and_port_string(const std::string &host, int port, + bool is_ssl); + std::string trim_copy(const std::string &s); void divide( @@ -2474,8 +3073,6 @@ bool parse_range_header(const std::string &s, Ranges &ranges); bool parse_accept_header(const std::string &s, std::vector &content_types); -int close_socket(socket_t sock); - ssize_t send_socket(socket_t sock, const void *ptr, size_t size, int flags); ssize_t read_socket(socket_t sock, void *ptr, size_t size, int flags); @@ -2642,6 +3239,25 @@ private: std::string growable_buffer_; }; +bool parse_trailers(stream_line_reader &line_reader, Headers &dest, + const Headers &src_headers); + +struct ChunkedDecoder { + Stream &strm; + size_t chunk_remaining = 0; + bool finished = false; + char line_buf[64]; + size_t last_chunk_total = 0; + size_t last_chunk_offset = 0; + + explicit ChunkedDecoder(Stream &s); + + ssize_t read_payload(char *buf, size_t len, size_t &out_chunk_offset, + size_t &out_chunk_total); + + bool parse_trailers_into(Headers &dest, const Headers &src_headers); +}; + class mmap { public: mmap(const char *path); @@ -2669,59 +3285,617 @@ private: // NOTE: https://www.rfc-editor.org/rfc/rfc9110#section-5 namespace fields { -inline bool is_token_char(char c) { - return std::isalnum(c) || c == '!' || c == '#' || c == '$' || c == '%' || - c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' || - c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~'; -} - -inline bool is_token(const std::string &s) { - if (s.empty()) { return false; } - for (auto c : s) { - if (!is_token_char(c)) { return false; } - } - return true; -} - -inline bool is_field_name(const std::string &s) { return is_token(s); } - -inline bool is_vchar(char c) { return c >= 33 && c <= 126; } - -inline bool is_obs_text(char c) { return 128 <= static_cast(c); } - -inline bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); } - -inline bool is_field_content(const std::string &s) { - if (s.empty()) { return true; } - - if (s.size() == 1) { - return is_field_vchar(s[0]); - } else if (s.size() == 2) { - return is_field_vchar(s[0]) && is_field_vchar(s[1]); - } else { - size_t i = 0; - - if (!is_field_vchar(s[i])) { return false; } - i++; - - while (i < s.size() - 1) { - auto c = s[i++]; - if (c == ' ' || c == '\t' || is_field_vchar(c)) { - } else { - return false; - } - } - - return is_field_vchar(s[i]); - } -} - -inline bool is_field_value(const std::string &s) { return is_field_content(s); } +bool is_token_char(char c); +bool is_token(const std::string &s); +bool is_field_name(const std::string &s); +bool is_vchar(char c); +bool is_obs_text(char c); +bool is_field_vchar(char c); +bool is_field_content(const std::string &s); +bool is_field_value(const std::string &s); } // namespace fields - } // namespace detail +/* + * TLS Abstraction Layer Declarations + */ + +#ifdef CPPHTTPLIB_SSL_ENABLED +// TLS abstraction layer - backend-specific type declarations +#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT +namespace tls { +namespace impl { + +// Mbed TLS context wrapper (holds config, entropy, DRBG, CA chain, own +// cert/key). This struct is accessible via tls::impl for use in SSL context +// setup callbacks (cast ctx_t to tls::impl::MbedTlsContext*). +struct MbedTlsContext { + mbedtls_ssl_config conf; + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context ctr_drbg; + mbedtls_x509_crt ca_chain; + mbedtls_x509_crt own_cert; + mbedtls_pk_context own_key; + bool is_server = false; + bool verify_client = false; + bool has_verify_callback = false; + + MbedTlsContext(); + ~MbedTlsContext(); + + MbedTlsContext(const MbedTlsContext &) = delete; + MbedTlsContext &operator=(const MbedTlsContext &) = delete; +}; + +} // namespace impl +} // namespace tls +#endif + +#ifdef CPPHTTPLIB_WOLFSSL_SUPPORT +namespace tls { +namespace impl { + +// wolfSSL context wrapper (holds WOLFSSL_CTX and related state). +// This struct is accessible via tls::impl for use in SSL context +// setup callbacks (cast ctx_t to tls::impl::WolfSSLContext*). +struct WolfSSLContext { + WOLFSSL_CTX *ctx = nullptr; + bool is_server = false; + bool verify_client = false; + bool has_verify_callback = false; + std::string ca_pem_data_; // accumulated PEM for get_ca_names/get_ca_certs + + WolfSSLContext(); + ~WolfSSLContext(); + + WolfSSLContext(const WolfSSLContext &) = delete; + WolfSSLContext &operator=(const WolfSSLContext &) = delete; +}; + +// CA store for wolfSSL: holds raw PEM bytes to allow reloading into any ctx +struct WolfSSLCAStore { + std::string pem_data; +}; + +} // namespace impl +} // namespace tls +#endif + +#endif // CPPHTTPLIB_SSL_ENABLED + +namespace stream { + +class Result { +public: + Result(); + explicit Result(ClientImpl::StreamHandle &&handle, size_t chunk_size = 8192); + Result(Result &&other) noexcept; + Result &operator=(Result &&other) noexcept; + Result(const Result &) = delete; + Result &operator=(const Result &) = delete; + + // Response info + bool is_valid() const; + explicit operator bool() const; + int status() const; + const Headers &headers() const; + std::string get_header_value(const std::string &key, + const char *def = "") const; + bool has_header(const std::string &key) const; + Error error() const; + Error read_error() const; + bool has_read_error() const; + + // Stream reading + bool next(); + const char *data() const; + size_t size() const; + std::string read_all(); + +private: + ClientImpl::StreamHandle handle_; + std::string buffer_; + size_t current_size_ = 0; + size_t chunk_size_; + bool finished_ = false; +}; + +// GET +template +inline Result Get(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path), chunk_size}; +} + +template +inline Result Get(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path, {}, headers), chunk_size}; +} + +template +inline Result Get(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path, params), chunk_size}; +} + +template +inline Result Get(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path, params, headers), chunk_size}; +} + +// POST +template +inline Result Post(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("POST", path, {}, {}, body, content_type), + chunk_size}; +} + +template +inline Result Post(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("POST", path, {}, headers, body, content_type), + chunk_size}; +} + +template +inline Result Post(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("POST", path, params, {}, body, content_type), + chunk_size}; +} + +template +inline Result Post(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("POST", path, params, headers, body, content_type), + chunk_size}; +} + +// PUT +template +inline Result Put(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("PUT", path, {}, {}, body, content_type), + chunk_size}; +} + +template +inline Result Put(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PUT", path, {}, headers, body, content_type), + chunk_size}; +} + +template +inline Result Put(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PUT", path, params, {}, body, content_type), + chunk_size}; +} + +template +inline Result Put(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("PUT", path, params, headers, body, content_type), + chunk_size}; +} + +// PATCH +template +inline Result Patch(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("PATCH", path, {}, {}, body, content_type), + chunk_size}; +} + +template +inline Result Patch(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PATCH", path, {}, headers, body, content_type), + chunk_size}; +} + +template +inline Result Patch(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PATCH", path, params, {}, body, content_type), + chunk_size}; +} + +template +inline Result Patch(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("PATCH", path, params, headers, body, content_type), + chunk_size}; +} + +// DELETE +template +inline Result Delete(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, {}, headers), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, {}, {}, body, content_type), + chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("DELETE", path, {}, headers, body, content_type), + chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, params), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, params, headers), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, params, {}, body, content_type), + chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("DELETE", path, params, headers, body, content_type), + chunk_size}; +} + +// HEAD +template +inline Result Head(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path), chunk_size}; +} + +template +inline Result Head(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path, {}, headers), chunk_size}; +} + +template +inline Result Head(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path, params), chunk_size}; +} + +template +inline Result Head(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path, params, headers), chunk_size}; +} + +// OPTIONS +template +inline Result Options(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path), chunk_size}; +} + +template +inline Result Options(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path, {}, headers), chunk_size}; +} + +template +inline Result Options(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path, params), chunk_size}; +} + +template +inline Result Options(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path, params, headers), chunk_size}; +} + +} // namespace stream + +namespace sse { + +struct SSEMessage { + std::string event; // Event type (default: "message") + std::string data; // Event payload + std::string id; // Event ID for Last-Event-ID header + + SSEMessage(); + void clear(); +}; + +class SSEClient { +public: + using MessageHandler = std::function; + using ErrorHandler = std::function; + using OpenHandler = std::function; + + SSEClient(Client &client, const std::string &path); + SSEClient(Client &client, const std::string &path, const Headers &headers); + ~SSEClient(); + + SSEClient(const SSEClient &) = delete; + SSEClient &operator=(const SSEClient &) = delete; + + // Event handlers + SSEClient &on_message(MessageHandler handler); + SSEClient &on_event(const std::string &type, MessageHandler handler); + SSEClient &on_open(OpenHandler handler); + SSEClient &on_error(ErrorHandler handler); + SSEClient &set_reconnect_interval(int ms); + SSEClient &set_max_reconnect_attempts(int n); + + // Update headers (thread-safe) + SSEClient &set_headers(const Headers &headers); + + // State accessors + bool is_connected() const; + const std::string &last_event_id() const; + + // Blocking start - runs event loop with auto-reconnect + void start(); + + // Non-blocking start - runs in background thread + void start_async(); + + // Stop the client (thread-safe) + void stop(); + +private: + bool parse_sse_line(const std::string &line, SSEMessage &msg, int &retry_ms); + void run_event_loop(); + void dispatch_event(const SSEMessage &msg); + bool should_reconnect(int count) const; + void wait_for_reconnect(); + + // Client and path + Client &client_; + std::string path_; + Headers headers_; + mutable std::mutex headers_mutex_; + + // Callbacks + MessageHandler on_message_; + std::map event_handlers_; + OpenHandler on_open_; + ErrorHandler on_error_; + + // Configuration + int reconnect_interval_ms_ = 3000; + int max_reconnect_attempts_ = 0; // 0 = unlimited + + // State + std::atomic running_{false}; + std::atomic connected_{false}; + std::string last_event_id_; + + // Async support + std::thread async_thread_; +}; + +} // namespace sse + +namespace ws { + +enum class Opcode : uint8_t { + Continuation = 0x0, + Text = 0x1, + Binary = 0x2, + Close = 0x8, + Ping = 0x9, + Pong = 0xA, +}; + +enum class CloseStatus : uint16_t { + Normal = 1000, + GoingAway = 1001, + ProtocolError = 1002, + UnsupportedData = 1003, + NoStatus = 1005, + Abnormal = 1006, + InvalidPayload = 1007, + PolicyViolation = 1008, + MessageTooBig = 1009, + MandatoryExtension = 1010, + InternalError = 1011, +}; + +enum ReadResult : int { Fail = 0, Text = 1, Binary = 2 }; + +class WebSocket { +public: + WebSocket(const WebSocket &) = delete; + WebSocket &operator=(const WebSocket &) = delete; + ~WebSocket(); + + ReadResult read(std::string &msg); + bool send(const std::string &data); + bool send(const char *data, size_t len); + void close(CloseStatus status = CloseStatus::Normal, + const std::string &reason = ""); + const Request &request() const; + bool is_open() const; + +private: + friend class httplib::Server; + friend class WebSocketClient; + + WebSocket( + Stream &strm, const Request &req, bool is_server, + time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND, + int max_missed_pongs = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS) + : strm_(strm), req_(req), is_server_(is_server), + ping_interval_sec_(ping_interval_sec), + max_missed_pongs_(max_missed_pongs) { + start_heartbeat(); + } + + WebSocket( + std::unique_ptr &&owned_strm, const Request &req, bool is_server, + time_t ping_interval_sec = CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND, + int max_missed_pongs = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS) + : strm_(*owned_strm), owned_strm_(std::move(owned_strm)), req_(req), + is_server_(is_server), ping_interval_sec_(ping_interval_sec), + max_missed_pongs_(max_missed_pongs) { + start_heartbeat(); + } + + void start_heartbeat(); + bool send_frame(Opcode op, const char *data, size_t len, bool fin = true); + + Stream &strm_; + std::unique_ptr owned_strm_; + Request req_; + bool is_server_; + time_t ping_interval_sec_; + int max_missed_pongs_; + int unacked_pings_ = 0; + std::atomic closed_{false}; + std::mutex write_mutex_; + std::thread ping_thread_; + std::mutex ping_mutex_; + std::condition_variable ping_cv_; +}; + +class WebSocketClient { +public: + explicit WebSocketClient(const std::string &scheme_host_port_path, + const Headers &headers = {}); + + ~WebSocketClient(); + WebSocketClient(const WebSocketClient &) = delete; + WebSocketClient &operator=(const WebSocketClient &) = delete; + + bool is_valid() const; + + bool connect(); + ReadResult read(std::string &msg); + bool send(const std::string &data); + bool send(const char *data, size_t len); + void close(CloseStatus status = CloseStatus::Normal, + const std::string &reason = ""); + bool is_open() const; + const std::string &subprotocol() const; + void set_read_timeout(time_t sec, time_t usec = 0); + void set_write_timeout(time_t sec, time_t usec = 0); + void set_websocket_ping_interval(time_t sec); + void set_websocket_max_missed_pongs(int count); + void set_tcp_nodelay(bool on); + void set_address_family(int family); + void set_ipv6_v6only(bool on); + void set_socket_options(SocketOptions socket_options); + void set_connection_timeout(time_t sec, time_t usec = 0); + void set_interface(const std::string &intf); + +#ifdef CPPHTTPLIB_SSL_ENABLED + void set_ca_cert_path(const std::string &path); + void set_ca_cert_store(tls::ca_store_t store); + void enable_server_certificate_verification(bool enabled); +#endif + +private: + void shutdown_and_close(); + bool create_stream(std::unique_ptr &strm); + + std::string host_; + int port_; + std::string path_; + Headers headers_; + std::string subprotocol_; + bool is_valid_ = false; + socket_t sock_ = INVALID_SOCKET; + std::unique_ptr ws_; + time_t read_timeout_sec_ = CPPHTTPLIB_WEBSOCKET_READ_TIMEOUT_SECOND; + time_t read_timeout_usec_ = 0; + time_t write_timeout_sec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_SECOND; + time_t write_timeout_usec_ = CPPHTTPLIB_CLIENT_WRITE_TIMEOUT_USECOND; + time_t websocket_ping_interval_sec_ = + CPPHTTPLIB_WEBSOCKET_PING_INTERVAL_SECOND; + int websocket_max_missed_pongs_ = CPPHTTPLIB_WEBSOCKET_MAX_MISSED_PONGS; + int address_family_ = AF_UNSPEC; + bool tcp_nodelay_ = CPPHTTPLIB_TCP_NODELAY; + bool ipv6_v6only_ = CPPHTTPLIB_IPV6_V6ONLY; + SocketOptions socket_options_ = nullptr; + time_t connection_timeout_sec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_SECOND; + time_t connection_timeout_usec_ = CPPHTTPLIB_CONNECTION_TIMEOUT_USECOND; + std::string interface_; + +#ifdef CPPHTTPLIB_SSL_ENABLED + bool is_ssl_ = false; + tls::ctx_t tls_ctx_ = nullptr; + tls::session_t tls_session_ = nullptr; + std::string ca_cert_file_path_; + tls::ca_store_t ca_cert_store_ = nullptr; + bool server_certificate_verification_ = true; +#endif +}; + +namespace impl { + +bool is_valid_utf8(const std::string &s); + +bool read_websocket_frame(Stream &strm, Opcode &opcode, std::string &payload, + bool &fin, bool expect_masked, size_t max_len); + +} // namespace impl + +} // namespace ws } // namespace httplib