Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix webserver thread safety #330

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 10 additions & 5 deletions src/httpserver/webserver.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include <map>
#include <memory>
#include <set>
#include <shared_mutex>
#include <string>

#include "httpserver/http_utils.hpp"
Expand Down Expand Up @@ -167,17 +168,21 @@ class webserver {
const std::string file_upload_dir;
const bool generate_random_filename_on_upload;
const bool deferred_enabled;
bool single_resource;
bool tcp_nodelay;
const bool single_resource;
const bool tcp_nodelay;
pthread_mutex_t mutexwait;
pthread_cond_t mutexcond;
render_ptr not_found_resource;
render_ptr method_not_allowed_resource;
render_ptr internal_error_resource;
const render_ptr not_found_resource;
const render_ptr method_not_allowed_resource;
const render_ptr internal_error_resource;
std::shared_mutex registered_resources_mutex;
std::map<details::http_endpoint, http_resource*> registered_resources;
std::map<std::string, http_resource*> registered_resources_str;

std::shared_mutex bans_mutex;
std::set<http::ip_representation> bans;

std::shared_mutex allowances_mutex;
std::set<http::ip_representation> allowances;

struct MHD_Daemon* daemon;
Expand Down
98 changes: 56 additions & 42 deletions src/webserver.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
#include <iosfwd>
#include <cstring>
#include <memory>
#include <mutex>
#include <stdexcept>
#include <utility>
#include <vector>
Expand Down Expand Up @@ -192,6 +193,7 @@ bool webserver::register_resource(const std::string& resource, http_resource* hr

details::http_endpoint idx(resource, family, true, regex_checking);

std::unique_lock registered_resources_lock(registered_resources_mutex);
pair<map<details::http_endpoint, http_resource*>::iterator, bool> result = registered_resources.insert(map<details::http_endpoint, http_resource*>::value_type(idx, hrm));

if (!family && result.second) {
Expand Down Expand Up @@ -361,12 +363,14 @@ bool webserver::stop() {
void webserver::unregister_resource(const string& resource) {
// family does not matter - it just checks the url_normalized anyhow
details::http_endpoint he(resource, false, true, regex_checking);
std::unique_lock registered_resources_lock(registered_resources_mutex);
registered_resources.erase(he);
registered_resources.erase(he.get_url_complete());
registered_resources_str.erase(he.get_url_complete());
}

void webserver::ban_ip(const string& ip) {
std::unique_lock bans_lock(bans_mutex);
ip_representation t_ip(ip);
set<ip_representation>::iterator it = bans.find(t_ip);
if (it != bans.end() && (t_ip.weight() < (*it).weight())) {
Expand All @@ -378,6 +382,7 @@ void webserver::ban_ip(const string& ip) {
}

void webserver::allow_ip(const string& ip) {
std::unique_lock allowances_lock(allowances_mutex);
ip_representation t_ip(ip);
set<ip_representation>::iterator it = allowances.find(t_ip);
if (it != allowances.end() && (t_ip.weight() < (*it).weight())) {
Expand All @@ -389,25 +394,31 @@ void webserver::allow_ip(const string& ip) {
}

void webserver::unban_ip(const string& ip) {
std::unique_lock bans_lock(bans_mutex);
bans.erase(ip_representation(ip));
}

void webserver::disallow_ip(const string& ip) {
std::unique_lock allowances_lock(allowances_mutex);
allowances.erase(ip_representation(ip));
}

MHD_Result policy_callback(void *cls, const struct sockaddr* addr, socklen_t addrlen) {
// Parameter needed to respect MHD interface, but not needed here.
std::ignore = addrlen;

if (!(static_cast<webserver*>(cls))->ban_system_enabled) return MHD_YES;
const auto ws = static_cast<webserver*>(cls);

if ((((static_cast<webserver*>(cls))->default_policy == http_utils::ACCEPT) &&
((static_cast<webserver*>(cls))->bans.count(ip_representation(addr))) &&
(!(static_cast<webserver*>(cls))->allowances.count(ip_representation(addr)))) ||
(((static_cast<webserver*>(cls))->default_policy == http_utils::REJECT) &&
((!(static_cast<webserver*>(cls))->allowances.count(ip_representation(addr))) ||
((static_cast<webserver*>(cls))->bans.count(ip_representation(addr)))))) {
if (!ws->ban_system_enabled) return MHD_YES;

std::shared_lock bans_lock(ws->bans_mutex);
std::shared_lock allowances_lock(ws->allowances_mutex);
const bool is_banned = ws->bans.count(ip_representation(addr));
const bool is_allowed = ws->allowances.count(ip_representation(addr));

if ((ws->default_policy == http_utils::ACCEPT && is_banned && !is_allowed) ||
(ws->default_policy == http_utils::REJECT && (!is_allowed || is_banned)))
{
return MHD_NO;
}

Expand Down Expand Up @@ -626,51 +637,54 @@ MHD_Result webserver::finalize_answer(MHD_Connection* connection, struct details

bool found = false;
struct MHD_Response* raw_response;
if (!single_resource) {
const char* st_url = mr->standardized_url->c_str();
fe = registered_resources_str.find(st_url);
if (fe == registered_resources_str.end()) {
if (regex_checking) {
map<details::http_endpoint, http_resource*>::iterator found_endpoint;

details::http_endpoint endpoint(st_url, false, false, false);

map<details::http_endpoint, http_resource*>::iterator it;

size_t len = 0;
size_t tot_len = 0;
for (it = registered_resources.begin(); it != registered_resources.end(); ++it) {
size_t endpoint_pieces_len = (*it).first.get_url_pieces().size();
size_t endpoint_tot_len = (*it).first.get_url_complete().size();
if (!found || endpoint_pieces_len > len || (endpoint_pieces_len == len && endpoint_tot_len > tot_len)) {
if ((*it).first.match(endpoint)) {
found = true;
len = endpoint_pieces_len;
tot_len = endpoint_tot_len;
found_endpoint = it;
{
std::shared_lock registered_resources_lock(registered_resources_mutex);
if (!single_resource) {
const char* st_url = mr->standardized_url->c_str();
fe = registered_resources_str.find(st_url);
if (fe == registered_resources_str.end()) {
if (regex_checking) {
map<details::http_endpoint, http_resource*>::iterator found_endpoint;

details::http_endpoint endpoint(st_url, false, false, false);

map<details::http_endpoint, http_resource*>::iterator it;

size_t len = 0;
size_t tot_len = 0;
for (it = registered_resources.begin(); it != registered_resources.end(); ++it) {
size_t endpoint_pieces_len = (*it).first.get_url_pieces().size();
size_t endpoint_tot_len = (*it).first.get_url_complete().size();
if (!found || endpoint_pieces_len > len || (endpoint_pieces_len == len && endpoint_tot_len > tot_len)) {
if ((*it).first.match(endpoint)) {
found = true;
len = endpoint_pieces_len;
tot_len = endpoint_tot_len;
found_endpoint = it;
}
}
}
}

if (found) {
vector<string> url_pars = found_endpoint->first.get_url_pars();
if (found) {
vector<string> url_pars = found_endpoint->first.get_url_pars();

vector<string> url_pieces = endpoint.get_url_pieces();
vector<int> chunks = found_endpoint->first.get_chunk_positions();
for (unsigned int i = 0; i < url_pars.size(); i++) {
mr->dhr->set_arg(url_pars[i], url_pieces[chunks[i]]);
}
vector<string> url_pieces = endpoint.get_url_pieces();
vector<int> chunks = found_endpoint->first.get_chunk_positions();
for (unsigned int i = 0; i < url_pars.size(); i++) {
mr->dhr->set_arg(url_pars[i], url_pieces[chunks[i]]);
}

hrm = found_endpoint->second;
hrm = found_endpoint->second;
}
}
} else {
hrm = fe->second;
found = true;
}
} else {
hrm = fe->second;
hrm = registered_resources.begin()->second;
found = true;
}
} else {
hrm = registered_resources.begin()->second;
found = true;
}

if (found) {
Expand Down
42 changes: 42 additions & 0 deletions test/integ/basic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@
*/

#include <curl/curl.h>
#include <atomic>
#include <map>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <thread>

#include "./httpserver.hpp"
#include "httpserver/string_utilities.hpp"
Expand Down Expand Up @@ -1462,6 +1464,46 @@ LT_BEGIN_AUTO_TEST(basic_suite, method_not_allowed_header)
curl_easy_cleanup(curl);
LT_END_AUTO_TEST(method_not_allowed_header)

LT_BEGIN_AUTO_TEST(basic_suite, thread_safety)
simple_resource resource;

std::atomic_bool done = false;
auto register_thread = std::thread([&]() {
int i = 0;
using namespace std::chrono;
while (!done) {
ws->register_resource(
std::string("/route") + std::to_string(++i), &resource);
}
});

auto get_thread = std::thread([&](){
while (!done) {
CURL *curl = curl_easy_init();
std::string s;
std::string url = "localhost:" PORT_STRING "/route" + std::to_string(
(int)((rand() * 10000000.0) / RAND_MAX));
curl_easy_setopt(curl, CURLOPT_URL, url.c_str());
curl_easy_setopt(curl, CURLOPT_HTTPGET, 1L);
curl_easy_setopt(curl, CURLOPT_WRITEFUNCTION, writefunc);
curl_easy_setopt(curl, CURLOPT_WRITEDATA, &s);
curl_easy_perform(curl);
curl_easy_cleanup(curl);
}
});

using namespace std::chrono_literals;
std::this_thread::sleep_for(10s);
done = true;
if (register_thread.joinable()) {
register_thread.join();
}
if (get_thread.joinable()) {
get_thread.join();
}
LT_CHECK_EQ(1, 1);
LT_END_AUTO_TEST(thread_safety)

LT_BEGIN_AUTO_TEST_ENV()
AUTORUN_TESTS()
LT_END_AUTO_TEST_ENV()