I'm writing some cross-platform networking code, and have come across some inconsistent behavior in getnameinfo() on Windows and Linux (WSL).
The code below does the following:
Get an address using getaddrinfo().
Calls getnameinfo() on the address with:
NI_NAMEREQD set and not set.
NI_NUMERICHOST set and not set.
.
// INCLUDES
#if defined(PLATFORM_WINDOWS)
#include <winsock2.h>
#include <ws2tcpip.h>
#else
#include <cerrno>
#include <netdb.h>
#include <signal.h>
#include <sys/types.h>
#include <sys/socket.h>
#endif
#include <algorithm>
#include <cstring>
#include <functional>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
// DEBUG UTILS
// Debug helpers: break into the debugger on fatal conditions instead of
// aborting, so failures can be inspected live.
namespace debug
{
#if defined(PLATFORM_WINDOWS)
// Break into the attached debugger (MSVC intrinsic).
void die()
{
__debugbreak();
}
#else
// Break into the attached debugger (POSIX: deliver SIGTRAP to self).
void die()
{
raise(SIGTRAP);
}
#endif
// Trap if `condition` holds; used as a lightweight always-on assert.
void die_if(bool condition)
{
if (condition)
die();
}
} // debug
// NET CODE
// Portable subset of getnameinfo() failure reasons, mapped from the EAI_*
// codes on POSIX and the WSA* codes on Windows (see get_getnameinfo_error).
enum class error_code
{
no_error,
host_not_found,
try_again,
out_of_memory,
buffer_overflow,
unrecoverable_error,
system_error,
};
char const* get_error_string(error_code ec)
{
switch (ec)
{
case error_code::no_error: return "no_error";
case error_code::host_not_found: return "host_not_found";
case error_code::try_again: return "try_again";
case error_code::out_of_memory: return "out_of_memory";
case error_code::buffer_overflow: return "buffer_overflow";
case error_code::unrecoverable_error: return "unrecoverable_error";
case error_code::system_error: return "system_error";
}
debug::die();
return nullptr;
}
// Small platform-neutral vocabulary for address resolution.
namespace ip
{
// IP version selector; `unspecified` lets getaddrinfo() pick v4 and/or v6.
enum class address_family
{
v4, v6, unspecified,
};
// Transport protocol selector (maps to SOCK_STREAM/SOCK_DGRAM).
enum class protocol
{
tcp, udp,
};
} // ip
// RAII wrapper for per-process networking setup: WSAStartup/WSACleanup on
// Windows, a no-op elsewhere. Exactly one instance should exist for the
// lifetime of the program's networking use. Non-copyable, non-movable.
class platform_context
{
public:
#if defined (PLATFORM_WINDOWS)
	platform_context()
	{
		auto data = WSADATA();
		auto const result = WSAStartup(MAKEWORD(2, 2), &data);
		debug::die_if(result != 0);
		// Require exactly Winsock 2.2.
		debug::die_if(LOBYTE(data.wVersion) != 2 || HIBYTE(data.wVersion) != 2);
	}
	~platform_context()
	{
		auto const result = WSACleanup();
		debug::die_if(result != 0);
	}
#else
	platform_context() { }
	~platform_context() { }
#endif
	// Deleted assignment operators now return a reference, matching the
	// conventional operator= signature (the originals returned by value).
	platform_context(platform_context const&) = delete;
	platform_context& operator=(platform_context const&) = delete;
	platform_context(platform_context&&) = delete;
	platform_context& operator=(platform_context&&) = delete;
};
// Owning pointer for getaddrinfo() result lists; the deleter calls freeaddrinfo().
using addrinfo_ptr = std::unique_ptr<addrinfo, std::function<void(addrinfo*)>>;
// Translate ip::address_family to the matching AF_* constant; traps on
// values outside the enumeration.
int get_ai_family(ip::address_family family)
{
	if (family == ip::address_family::v4)
		return AF_INET;
	if (family == ip::address_family::v6)
		return AF_INET6;
	if (family == ip::address_family::unspecified)
		return AF_UNSPEC;
	debug::die();
	return AF_UNSPEC;
}
// Translate ip::protocol to the matching SOCK_* constant; traps on values
// outside the enumeration.
int get_ai_socktype(ip::protocol protocol)
{
	if (protocol == ip::protocol::tcp)
		return SOCK_STREAM;
	if (protocol == ip::protocol::udp)
		return SOCK_DGRAM;
	debug::die();
	return SOCK_STREAM;
}
// Translate ip::protocol to the matching IPPROTO_* constant; traps on
// values outside the enumeration.
int get_ai_protocol(ip::protocol protocol)
{
	if (protocol == ip::protocol::tcp)
		return IPPROTO_TCP;
	if (protocol == ip::protocol::udp)
		return IPPROTO_UDP;
	debug::die();
	return IPPROTO_TCP;
}
// Reverse mapping: AF_* constant back to ip::address_family; traps on any
// other AF_* value.
ip::address_family get_ip_address_family(int ai_family)
{
	if (ai_family == AF_INET)
		return ip::address_family::v4;
	if (ai_family == AF_INET6)
		return ip::address_family::v6;
	if (ai_family == AF_UNSPEC)
		return ip::address_family::unspecified;
	debug::die();
	return ip::address_family::unspecified;
}
// Owned copy of one socket address produced by getaddrinfo(), detached from
// the addrinfo list's lifetime.
struct end_point
{
	explicit end_point(addrinfo const& info):
		address_length(0),
		address{ 0 }
	{
		// ai_addrlen is an unsigned type on both Windows and POSIX, so the
		// original `info.ai_addrlen < 0` check was always false and has been
		// removed.
		debug::die_if(info.ai_addrlen > sizeof(sockaddr_storage));
		address_length = static_cast<std::size_t>(info.ai_addrlen);
		std::memcpy(&address, info.ai_addr, address_length);
	}
	// Address family (v4/v6/unspecified) of the stored sockaddr.
	ip::address_family get_address_family() const
	{
		return get_ip_address_family(address.ss_family);
	}
	std::size_t address_length; // valid bytes at the front of `address`
	sockaddr_storage address;   // large enough for any protocol's sockaddr
};
// Flatten the addrinfo linked list into a vector of owned end_points.
// A null list yields an empty vector.
std::vector<end_point> get_end_points(addrinfo_ptr const& info)
{
	auto result = std::vector<end_point>();
	for (auto node = info.get(); node != nullptr; node = node->ai_next)
		result.emplace_back(*node);
	return result;
}
// Resolve node/service via getaddrinfo(). At least one of node/service must
// be non-null. The error_code out-parameter is currently unused — errors
// trap instead (deliberate for this example).
addrinfo_ptr get_address(error_code&, char const* node, char const* service, ip::address_family family, ip::protocol protocol, int flags)
{
	debug::die_if(!node && !service);
	auto hints = addrinfo();
	std::memset(&hints, 0, sizeof(hints));
	hints.ai_family = get_ai_family(family);
	hints.ai_socktype = get_ai_socktype(protocol);
	hints.ai_protocol = get_ai_protocol(protocol);
	hints.ai_flags = flags;
	auto out = (addrinfo*) nullptr;
	auto const result = ::getaddrinfo(node, service, &hints, &out);
	// error handling ignored for this example
	// (make sure you have internet for testing remote end points)
	debug::die_if(result != 0);
	debug::die_if(out == nullptr);
	// A lambda is simpler and cheaper than std::bind for the deleter.
	return addrinfo_ptr(out, [] (addrinfo* p) { ::freeaddrinfo(p); });
}
// Resolve the passive/wildcard ("any") address: null node + AI_PASSIVE.
std::vector<end_point> get_wildcard_address(error_code& ec, ip::address_family family, ip::protocol protocol)
{
return get_end_points(get_address(ec, nullptr, "0", family, protocol, AI_PASSIVE));
}
// Resolve the loopback address: a null node WITHOUT AI_PASSIVE makes
// getaddrinfo() return the loopback address per the spec.
std::vector<end_point> get_loopback_address(error_code& ec, ip::address_family family, ip::protocol protocol)
{
return get_end_points(get_address(ec, nullptr, "0", family, protocol, 0));
}
// Convenience overload: resolve a named node/service pair (no extra flags).
std::vector<end_point> get_address(error_code& ec, std::string const& node, std::string const& service, ip::address_family family, ip::protocol protocol)
{
	return get_end_points(get_address(ec, node.c_str(), service.c_str(), family, protocol, 0));
} // note: stray semicolon after the function body removed
// Whether get_node_name() should return a numeric address string
// (NI_NUMERICHOST) or attempt a reverse name lookup.
enum class name_type
{
numeric,
name,
};
#if defined(PLATFORM_WINDOWS)
// Map a failed getnameinfo() call to error_code. On Windows the real
// reason lives in WSAGetLastError(), not in the return value.
error_code get_getnameinfo_error(int result)
{
debug::die_if(result == 0);
auto const error = WSAGetLastError();
// These indicate programmer error (bad arguments / missing WSAStartup),
// not a runtime resolution failure — trap instead of mapping them.
debug::die_if(error == WSANOTINITIALISED);
debug::die_if(error == WSAEAFNOSUPPORT);
debug::die_if(error == WSAEINVAL);
debug::die_if(error == WSAEFAULT);
switch (error)
{
case WSAHOST_NOT_FOUND: return error_code::host_not_found;
case WSATRY_AGAIN: return error_code::try_again;
case WSA_NOT_ENOUGH_MEMORY: return error_code::out_of_memory;
case WSANO_RECOVERY: return error_code::unrecoverable_error;
}
debug::die();
return error_code::no_error;
}
// Bounded strlen (secure-CRT variant).
std::size_t get_cstr_len(char const* string, std::size_t max)
{
return strnlen_s(string, max);
}
#else
// Map a failed getnameinfo() call to error_code. POSIX returns the EAI_*
// code directly from the function.
error_code get_getnameinfo_error(int result)
{
debug::die_if(result == 0);
auto const error = result;
// These indicate programmer error (bad family / bad flags) — trap.
debug::die_if(error == EAI_FAMILY);
debug::die_if(error == EAI_BADFLAGS);
switch (error)
{
case EAI_NONAME: return error_code::host_not_found;
case EAI_AGAIN: return error_code::try_again;
case EAI_MEMORY: return error_code::out_of_memory;
case EAI_OVERFLOW: return error_code::buffer_overflow;
case EAI_FAIL: return error_code::unrecoverable_error;
case EAI_SYSTEM: return error_code::system_error;
}
debug::die();
return error_code::no_error;
}
// Bounded strlen.
std::size_t get_cstr_len(char const* string, std::size_t max)
{
return strnlen(string, max);
}
#endif
//////////////
// Reverse-resolve the host part of `end_point` into `node`.
// node_type selects NI_NUMERICHOST; require_name adds NI_NAMEREQD.
// Returns false and sets `ec` when getnameinfo() fails.
bool get_node_name(error_code& ec, std::string& node, name_type node_type, end_point const& end_point, bool require_name)
{
	auto const numeric_flag = (node_type == name_type::numeric ? NI_NUMERICHOST : 0);
	auto const require_flag = (require_name ? NI_NAMEREQD : 0);
	char node_buffer[NI_MAXHOST] = { 0 };
	// Named casts instead of the original C-style casts.
	auto const result = ::getnameinfo(reinterpret_cast<sockaddr const*>(&end_point.address), static_cast<socklen_t>(end_point.address_length), node_buffer, NI_MAXHOST, nullptr, 0, numeric_flag | require_flag);
	if (result != 0)
	{
		ec = get_getnameinfo_error(result);
		return false;
	}
	// assign() replaces the original resize() + copy_n() pair in one call.
	node.assign(node_buffer, get_cstr_len(node_buffer, NI_MAXHOST));
	return true;
}
//////////////
// TEST CODE
// Run one get_node_name() combination and print a line describing the outcome.
void test_get_node_name(end_point const& e, name_type node_type, bool name_required)
{
	auto ec = error_code::no_error;
	auto node = std::string();
	auto const ok = get_node_name(ec, node, node_type, e, name_required);
	std::cout << "\t";
	std::cout << (name_required ? "required - " : "not required - ");
	if (node_type == name_type::numeric)
		std::cout << "numeric - ";
	if (ok)
		std::cout << "success (node name: '" << node << "')";
	else
		std::cout << "failed! (error: " << get_error_string(ec) << ")";
	std::cout << "\n";
}
int main()
{
platform_context context;
std::cout << "wildcard address:" << std::endl;
{
auto ec = error_code::no_error;
auto end_points = get_wildcard_address(ec, ip::address_family::unspecified, ip::protocol::tcp);
debug::die_if(end_points.empty());
test_get_node_name(end_points.front(), name_type::name, true);
test_get_node_name(end_points.front(), name_type::name, false);
test_get_node_name(end_points.front(), name_type::numeric, true);
test_get_node_name(end_points.front(), name_type::numeric, false);
}
std::cout << "loopback address:" << std::endl;
{
auto ec = error_code::no_error;
auto end_points = get_loopback_address(ec, ip::address_family::unspecified, ip::protocol::tcp);
debug::die_if(end_points.empty());
test_get_node_name(end_points.front(), name_type::name, true);
test_get_node_name(end_points.front(), name_type::name, false);
test_get_node_name(end_points.front(), name_type::numeric, true);
test_get_node_name(end_points.front(), name_type::numeric, false);
}
std::cout << "remote address:" << std::endl;
{
auto ec = error_code::no_error;
auto end_points = get_address(ec, "www.google.com", "443", ip::address_family::unspecified, ip::protocol::tcp);
debug::die_if(end_points.empty());
test_get_node_name(end_points.front(), name_type::name, true);
test_get_node_name(end_points.front(), name_type::name, false);
test_get_node_name(end_points.front(), name_type::numeric, true);
test_get_node_name(end_points.front(), name_type::numeric, false);
}
}
This can be compiled with cl main.cpp /DPLATFORM_WINDOWS /nologo /EHsc /W4 /WX ws2_32.lib on Windows, and g++ -Wall -Werror -std=c++17 -o main main.cpp on WSL.
I get the following output on my system:
Windows:
wildcard address:
required - success (node name: 'ComputerName')
not required - success (node name: 'ComputerName')
required - numeric - success (node name: '::')
not required - numeric - success (node name: '::')
loopback address:
required - success (node name: 'ComputerName')
not required - success (node name: 'ComputerName')
required - numeric - success (node name: '::1')
not required - numeric - success (node name: '::1')
remote address:
required - success (node name: 'lhr25s12-in-f4.1e100.net')
not required - success (node name: 'lhr25s12-in-f4.1e100.net')
required - numeric - success (node name: '216.58.204.36')
not required - numeric - success (node name: '216.58.204.36')
WSL:
wildcard address:
required - failed! (error: host_not_found)
not required - success (node name: '0.0.0.0')
required - numeric - failed! (error: host_not_found)
not required - numeric - success (node name: '0.0.0.0')
loopback address:
required - success (node name: 'ip6-localhost')
not required - success (node name: 'ip6-localhost')
required - numeric - failed! (error: host_not_found)
not required - numeric - success (node name: '::1')
remote address:
required - success (node name: 'lhr25s12-in-x04.1e100.net')
not required - success (node name: 'lhr25s12-in-x04.1e100.net')
required - numeric - failed! (error: host_not_found)
not required - numeric - success (node name: '2a00:1450:4009:80d::2004')
So the getnameinfo() behavioral differences are:
Non-numeric wildcard addresses work on Windows, but fail on WSL.
Numeric address lookups fail on WSL when NI_NAMEREQD is set.
Are these differences simply alternative interpretations of the specs? Is it reasonable for the Windows version to return the ComputerName as the host name?
I'm not yet sure why the wildcard lookups fail, but after digging around in glibc, it seems that NI_NUMERICHOST and NI_NAMEREQD simply don't work together:
/* Convert AF_INET or AF_INET6 socket address, host part. */
static int
gni_host_inet (struct scratch_buffer *tmpbuf,
const struct sockaddr *sa, socklen_t addrlen,
char *host, socklen_t hostlen, int flags)
{
if (!(flags & NI_NUMERICHOST))
{
int result = gni_host_inet_name
(tmpbuf, sa, addrlen, host, hostlen, flags);
if (result != EAI_NONAME)
return result;
}
if (flags & NI_NAMEREQD)
return EAI_NONAME;
else
return gni_host_inet_numeric
(tmpbuf, sa, addrlen, host, hostlen, flags);
}
Related
My goal is to load multiple PEM files in Boost so that it can perform the correct SSL handshake with a client depending on the SNI sent in the TLS ClientHello — meaning a single IP address can host multiple HTTPS sites. I have looked through the SSL section of the Boost documentation, but there is nothing there about SNI. My project depends on boost::asio, and I don't know whether this is possible, or how to adapt the example code s_server.c in the OpenSSL repository ( https://github.com/openssl/openssl/blob/master/apps/s_server.c#L127 ). Thank you for any hints.
And my https server code is here:
// HTTPS server built on boost::asio + SSL: resolves listen endpoints, runs
// one acceptor per endpoint, and hands accepted sockets to connection_ssl
// instances that dispatch requests through https_router_.
template<class service_pool_policy = io_service_pool>
class httpTLS_server_ : private noncopyable {
public:
	template<class... Args>
	explicit httpTLS_server_(Args&&... args) : io_service_pool_(std::forward<Args>(args)...)
		, ctx_(boost::asio::ssl::context::tls_server)
	{
		http_cache::get().set_cache_max_age(86400);
		init_conn_callback();
	}

	void enable_http_cache(bool b) {
		http_cache::get().enable_cache(b);
	}

	// Configure the TLS context (certificate chain, private key, options).
	// NOTE(review): ssl_enable_v3 is unused — SSLv3 is always disabled by the
	// options below; confirm callers before removing the parameter.
	template<typename F>
	void init_ssl_context(bool ssl_enable_v3, F&& f, std::string certificate_chain_file,
		std::string private_key_file, std::string tmp_dh_file) {
		unsigned long ssl_options = boost::asio::ssl::context::default_workarounds
			| boost::asio::ssl::context::no_sslv2
			| boost::asio::ssl::context::no_sslv3
			| boost::asio::ssl::context::no_tlsv1
			| boost::asio::ssl::context::single_dh_use;
		ctx_.set_options(ssl_options);
		ctx_.set_password_callback(std::forward<F>(f));
		ctx_.use_certificate_chain_file(std::move(certificate_chain_file));
		ctx_.use_private_key_file(std::move(private_key_file), boost::asio::ssl::context::pem);
	}

	//address :
	// "0.0.0.0" : ipv4. use 'https://localhost/' to visit
	// "::1" : ipv6. use 'https://[::1]/' to visit
	// "" : ipv4 & ipv6.
	bool listen(std::string_view address, std::string_view port) {
		// NOTE(review): .data() assumes the views are null-terminated —
		// true for string literals/std::string, not for arbitrary views.
		boost::asio::ip::tcp::resolver::query query(address.data(), port.data());
		return listen(query);
	}

	//support ipv6 & ipv4
	bool listen(std::string_view port) {
		boost::asio::ip::tcp::resolver::query query(port.data());
		return listen(query);
	}

	// Open, bind and listen on every endpoint the query resolves to.
	// Returns true if at least one acceptor started successfully.
	bool listen(const boost::asio::ip::tcp::resolver::query & query) {
		boost::asio::ip::tcp::resolver resolver(io_service_pool_.get_io_service());
		boost::asio::ip::tcp::resolver::iterator endpoints = resolver.resolve(query);
		bool r = false;
		for (; endpoints != boost::asio::ip::tcp::resolver::iterator(); ++endpoints) {
			boost::asio::ip::tcp::endpoint endpoint = *endpoints;
			auto acceptor = std::make_shared<boost::asio::ip::tcp::acceptor>(io_service_pool_.get_io_service());
			acceptor->open(endpoint.protocol());
			acceptor->set_option(boost::asio::ip::tcp::acceptor::reuse_address(true));
			try {
				acceptor->bind(endpoint);
				acceptor->listen();
				start_accept(acceptor);
				r = true;
			}
			catch (const std::exception& ex) {
				std::cout << ex.what() << "\n";
				//LOG_INFO << e.what();
			}
		}
		return r;
	}

	void stop() {
		io_service_pool_.stop();
	}

	// Ensure the public/static directories exist, then block in the pool.
	void run() {
		if (!fs::exists(public_root_path_.data())) {
			fs::create_directories(public_root_path_.data());
		}
		if (!fs::exists(static_dir_.data())) {
			fs::create_directories(static_dir_.data());
		}
		io_service_pool_.run();
	}

	intptr_t run_one() {
		return io_service_pool_.run_one();
	}

	intptr_t poll() {
		return io_service_pool_.poll();
	}

	intptr_t poll_one() {
		return io_service_pool_.poll_one();
	}

	void set_static_dir(std::string&& path) {
		static_dir_ = public_root_path_+std::move(path)+"/";
	}

	const std::string& static_dir() const {
		return static_dir_;
	}

	//xM
	void set_max_req_buf_size(std::size_t max_buf_size) {
		max_req_buf_size_ = max_buf_size;
	}

	void set_keep_alive_timeout(long seconds) {
		keep_alive_timeout_ = seconds;
	}

	// True when the argument is an enable_cache<bool> carrying `true`.
	template<typename T>
	bool need_cache(T&& t) {
		if constexpr(std::is_same_v<T, enable_cache<bool>>) {
			return t.value;
		}
		else {
			return false;
		}
	}

	//set http handlers
	template<http_method... Is, typename Function, typename... AP>
	void set_http_handler(std::string_view name, Function&& f, AP&&... ap) {
		if constexpr(has_type<enable_cache<bool>, std::tuple<std::decay_t<AP>...>>::value) {//for cache
			bool b = false;
			((!b&&(b = need_cache(std::forward<AP>(ap)))),...);
			if (!b) {
				http_cache::get().add_skip(name);
			}else{
				http_cache::get().add_single_cache(name);
			}
			// Register with the cache marker filtered out of the aspect pack.
			auto tp = filter<enable_cache<bool>>(std::forward<AP>(ap)...);
			auto lm = [this, name, f = std::move(f)](auto... ap) {
				https_router_.register_handler<Is...>(name, std::move(f), std::move(ap)...);
			};
			std::apply(lm, std::move(tp));
		}
		else {
			https_router_.register_handler<Is...>(name, std::forward<Function>(f), std::forward<AP>(ap)...);
		}
	}

	// Store the base-path pair used when rendering responses.
	void set_base_path(const std::string& key,const std::string& path)
	{
		// BUG FIX: std::move on a const reference is a no-op (it still
		// copies), so plain assignment is used instead.
		base_path_[0] = key;
		base_path_[1] = path;
	}

	void set_res_cache_max_age(std::time_t seconds)
	{
		static_res_cache_max_age_ = seconds;
	}

	std::time_t get_res_cache_max_age()
	{
		return static_res_cache_max_age_;
	}

	void set_cache_max_age(std::time_t seconds)
	{
		http_cache::get().set_cache_max_age(seconds);
	}

	std::time_t get_cache_max_age()
	{
		return http_cache::get().get_cache_max_age();
	}

	//don't begin with "./" or "/", not absolutely path
	void set_public_root_directory(const std::string& name)
	{
		if(!name.empty()){
			public_root_path_ = "./"+name+"/";
		}
		else {
			public_root_path_ = "./";
		}
	}

	std::string get_public_root_directory()
	{
		return public_root_path_;
	}

	void set_download_check(std::function<bool(request_ssl& req, response& res)> checker) {
		download_check_ = std::move(checker);
	}

	//should be called before listen
	void set_upload_check(std::function<bool(request_ssl& req, response& res)> checker) {
		upload_check_ = std::move(checker);
	}

	void mapping_to_root_path(std::string relate_path) {
		relate_paths_.emplace_back("."+std::move(relate_path));
	}

private:
	// Post one async accept; re-arms itself after every completion so the
	// acceptor keeps accepting for the server's lifetime.
	void start_accept(std::shared_ptr<boost::asio::ip::tcp::acceptor> const& acceptor) {
		auto new_conn = std::make_shared<connection_ssl<Socket_ssl>>(
			io_service_pool_.get_io_service(), max_req_buf_size_, keep_alive_timeout_, https_handler_, static_dir_,
			upload_check_?&upload_check_ : nullptr
			, ctx_
			);
		acceptor->async_accept(new_conn->socket(), [this, new_conn, acceptor](const boost::system::error_code& e) {
			if (!e) {
				new_conn->socket().set_option(boost::asio::ip::tcp::no_delay(true));
				new_conn->start();
			}
			else {
				std::cout << "server::handle_accept: " << e.message();
				//LOG_INFO << "server::handle_accept: " << e.message();
			}
			start_accept(acceptor);
		});
	}

	// Install the request dispatcher: route to a handler, redirect to "/"
	// on unknown URLs, and convert exceptions to 500 responses.
	void init_conn_callback() {
		set_static_res_handler();
		https_handler_ = [this](request_ssl& req, response& res) {
			res.set_base_path(this->base_path_[0],this->base_path_[1]);
			res.set_url(req.get_url());
			try {
				bool success = https_router_.route(req.get_method(), req.get_url(), req, res);
				if (!success) {
					//updated by neo
					//res.set_status_and_content(status_type::bad_request, "the url is not right");
					res.redirect("/");
					//updated end
				}
			}
			catch (const std::exception& ex) {
				res.set_status_and_content(status_type::internal_server_error, ex.what()+std::string(" exception in business function"));
			}
			catch (...) {
				res.set_status_and_content(status_type::internal_server_error, "unknown exception in business function");
			}
		};
	}

	service_pool_policy io_service_pool_;
	std::size_t max_req_buf_size_ = 3 * 1024 * 1024; //max request buffer size 3M
	long keep_alive_timeout_ = 60; //max request timeout 60s
	https_router https_router_;
	std::string static_dir_ = "./public/static/"; //default
	std::string base_path_[2] = {"base_path","/"};
	std::time_t static_res_cache_max_age_ = 0;
	std::string public_root_path_ = "./";
	boost::asio::ssl::context ctx_;
	//SSL_CTX ctx_ ;
	//SSL_CTX *ctx2 = NULL;
	https_handler https_handler_ = nullptr;
	std::function<bool(request_ssl& req, response& res)> download_check_;
	std::vector<std::string> relate_paths_;
	std::function<bool(request_ssl& req, response& res)> upload_check_ = nullptr;
};
// Default instantiation used by the application.
using httpTLS_server = httpTLS_server_<io_service_pool>;
Closed. This question needs to be more focused. It is not currently accepting answers.
Want to improve this question? Update the question so it focuses on one problem only by editing this post.
Closed 1 year ago.
Improve this question
I have a C++ project. After C++20 support was released, I wanted to upgrade the std setting in my makefile from 17 to 20. After that change, my compiler (GCC 10.2) gives me an error like this:
Error
In file included from /usr/local/lib/gcc10/include/c++/bits/node_handle.h:39,
from /usr/local/lib/gcc10/include/c++/bits/stl_tree.h:72,
from /usr/local/lib/gcc10/include/c++/map:60,
from AsyncSQL.h:10,
from AsyncSQL.cpp:4:
/usr/local/lib/gcc10/include/c++/optional: In function 'constexpr std::strong_ordering std::operator<=>(const std::optional<_Tp>&, std::nullopt_t)':
/usr/local/lib/gcc10/include/c++/optional:1052:24: error: invalid operands of types 'bool' and 'int' to binary 'operator<=>'
1052 | { return bool(__x) <=> false; }
| ~~~~~~~~~ ^~~
| |
| bool
gmake[2]: *** [Makefile:23: AsyncSQL.o] Error 1
This is my AsyncSQL.cpp ;
#include <sys/time.h>
#include <cstdlib>
#include <cstring>
#include "AsyncSQL.h"
// Thin wrappers so lock call sites stay short; no error checking is done.
#define MUTEX_LOCK(mtx) pthread_mutex_lock(mtx)
#define MUTEX_UNLOCK(mtx) pthread_mutex_unlock(mtx)
// Zero-initialize all connection state; the actual connection is made later
// in Setup()/Connect().
CAsyncSQL::CAsyncSQL(): m_stHost (""), m_stUser (""), m_stPassword (""), m_stDB (""), m_stLocale (""), m_iMsgCount (0), m_iPort (0), m_bEnd (false), m_hThread (0), m_mtxQuery (NULL), m_mtxResult (NULL), m_iQueryFinished (0), m_ulThreadID (0), m_bConnected (false), m_iCopiedQuery (0)
{
// The MYSQL handle must be zeroed before mysql_init() is called on it.
memset (&m_hDB, 0, sizeof (m_hDB));
m_aiPipe[0] = 0;
m_aiPipe[1] = 0;
}
// Stop the worker thread first, then release the connection and mutexes.
CAsyncSQL::~CAsyncSQL()
{
Quit();
Destroy();
}
// Release the MySQL handle and the two mutexes allocated in Setup().
// Safe to call when nothing was ever connected/allocated.
void CAsyncSQL::Destroy()
{
	if (m_hDB.host)
	{
		sys_log (0, "AsyncSQL: closing mysql connection.");
		mysql_close (&m_hDB);
		m_hDB.host = NULL;
	}
	if (m_mtxQuery)
	{
		pthread_mutex_destroy (m_mtxQuery);
		delete m_mtxQuery;
		m_mtxQuery = NULL;
	}
	if (m_mtxResult)
	{
		pthread_mutex_destroy (m_mtxResult);
		delete m_mtxResult;
		// BUG FIX: this branch previously reset m_mtxQuery (again) instead of
		// m_mtxResult, leaving m_mtxResult dangling after the delete.
		m_mtxResult = NULL;
	}
}
void* AsyncSQLThread (void* arg)
{
CAsyncSQL* pSQL = ((CAsyncSQL*) arg);
if (!pSQL->Connect())
{
return NULL;
}
pSQL->ChildLoop();
return NULL;
}
// Push m_stLocale to the server as the connection character set.
// Returns true when there is nothing to do (empty locale) or on success;
// false only when mysql_set_character_set() itself fails.
bool CAsyncSQL::QueryLocaleSet()
{
if (0 == m_stLocale.length())
{
sys_err ("m_stLocale == 0");
return true;
}
if (mysql_set_character_set (&m_hDB, m_stLocale.c_str()))
{
sys_err ("cannot set locale %s by 'mysql_set_character_set', errno %u %s", m_stLocale.c_str(), mysql_errno (&m_hDB) , mysql_error (&m_hDB));
return false;
}
sys_log (0, "\t--mysql_set_character_set(%s)", m_stLocale.c_str());
return true;
}
// Establish the MySQL connection and configure charset + auto-reconnect.
// Returns false on init/connect failure (option failures are only logged).
bool CAsyncSQL::Connect()
{
	if (0 == mysql_init (&m_hDB))
	{
		fprintf (stderr, "mysql_init failed\n");
		return false;
	}
	if (!m_stLocale.empty())
	{
		if (mysql_options (&m_hDB, MYSQL_SET_CHARSET_NAME, m_stLocale.c_str()) != 0)
		{
			fprintf (stderr, "mysql_option failed : MYSQL_SET_CHARSET_NAME %s ", mysql_error(&m_hDB));
		}
	}
	// BUG FIX: the MySQL C API documents that mysql_options() must be called
	// between mysql_init() and mysql_real_connect(); setting
	// MYSQL_OPT_RECONNECT after connecting is not guaranteed to take effect
	// on all client library versions, so it is set before connecting now.
	my_bool reconnect = true;
	if (0 != mysql_options (&m_hDB, MYSQL_OPT_RECONNECT, &reconnect))
	{
		fprintf (stderr, "mysql_option: %s\n", mysql_error(&m_hDB));
	}
	if (!mysql_real_connect (&m_hDB, m_stHost.c_str(), m_stUser.c_str(), m_stPassword.c_str(), m_stDB.c_str(), m_iPort, NULL, CLIENT_MULTI_STATEMENTS))
	{
		fprintf (stderr, "mysql_real_connect: %s\n", mysql_error(&m_hDB));
		return false;
	}
	// Remember the server-side thread id so reconnects can be detected later.
	m_ulThreadID = mysql_thread_id (&m_hDB);
	m_bConnected = true;
	return true;
}
// Clone the connection parameters of another instance and run Setup().
bool CAsyncSQL::Setup (CAsyncSQL* sql, bool bNoThread)
{
	return Setup (sql->m_stHost.c_str(),
	              sql->m_stUser.c_str(),
	              sql->m_stPassword.c_str(),
	              sql->m_stDB.c_str(),
	              sql->m_stLocale.c_str(),
	              bNoThread,
	              sql->m_iPort);
}
// Store connection parameters, then either connect synchronously (bNoThread)
// or allocate the queue mutexes and spawn the worker thread.
bool CAsyncSQL::Setup (const char* c_pszHost, const char* c_pszUser, const char* c_pszPassword, const char* c_pszDB, const char* c_pszLocale, bool bNoThread, int iPort)
{
	m_stHost = c_pszHost;
	m_stUser = c_pszUser;
	m_stPassword = c_pszPassword;
	m_stDB = c_pszDB;
	m_iPort = iPort;
	if (c_pszLocale)
	{
		m_stLocale = c_pszLocale;
		sys_log (0, "AsyncSQL: locale %s", m_stLocale.c_str());
	}
	if (bNoThread)
	{
		// Synchronous mode: no worker thread, no mutexes.
		return Connect();
	}
	m_mtxQuery = new pthread_mutex_t;
	m_mtxResult = new pthread_mutex_t;
	if (0 != pthread_mutex_init (m_mtxQuery, NULL))
	{
		perror ("pthread_mutex_init");
		// BUG FIX: exit(0) signalled success to the parent process; a failed
		// mutex init is a fatal error and must report failure.
		exit (EXIT_FAILURE);
	}
	if (0 != pthread_mutex_init (m_mtxResult, NULL))
	{
		perror ("pthread_mutex_init");
		exit (EXIT_FAILURE);
	}
	pthread_create (&m_hThread, NULL, AsyncSQLThread, this);
	return true;
}
// Signal the worker loop to stop, wake it, and join the thread.
void CAsyncSQL::Quit()
{
m_bEnd = true;
// Wake the worker even when no query is pending so it can observe m_bEnd.
m_sem.Release();
if (m_hThread)
{
pthread_join (m_hThread, NULL);
m_hThread = NULL;
}
}
// Execute a query synchronously on the caller's thread.
// Returns a heap-allocated SQLMsg holding the stored result sets; the
// caller takes ownership. On query failure uiSQLErrno is set.
SQLMsg* CAsyncSQL::DirectQuery (const char* c_pszQuery)
{
	// Re-apply the locale if the client library reconnected underneath us.
	if (m_ulThreadID != mysql_thread_id (&m_hDB))
	{
		sys_log (0, "MySQL connection was reconnected. querying locale set");
		while (!QueryLocaleSet());
		m_ulThreadID = mysql_thread_id (&m_hDB);
	}
	SQLMsg* p = new SQLMsg;
	p->m_pkSQL = &m_hDB;
	p->iID = ++m_iMsgCount;
	p->stQuery = c_pszQuery;
	if (mysql_real_query (&m_hDB, p->stQuery.c_str(), p->stQuery.length()))
	{
		char buf[1024];
		snprintf (buf, sizeof(buf), "AsyncSQL::DirectQuery : mysql_query error: %s\nquery: %s", mysql_error (&m_hDB), p->stQuery.c_str());
		// BUG FIX: pass buf as an argument, not as the format string — the
		// query text may contain '%' conversions (format-string vulnerability).
		sys_err ("%s", buf);
		p->uiSQLErrno = mysql_errno (&m_hDB);
	}
	p->Store();
	return p;
}
void CAsyncSQL::AsyncQuery (const char* c_pszQuery)
{
auto p = new SQLMsg;
p->m_pkSQL = &m_hDB;
p->iID = ++m_iMsgCount;
p->stQuery = c_pszQuery;
PushQuery (p);
}
void CAsyncSQL::ReturnQuery (const char* c_pszQuery, void* pvUserData)
{
auto p = new SQLMsg;
p->m_pkSQL = &m_hDB;
p->iID = ++m_iMsgCount;
p->stQuery = c_pszQuery;
p->bReturn = true;
p->pvUserData = pvUserData;
PushQuery (p);
}
// Hand a completed message to the result queue (called by the worker thread).
void CAsyncSQL::PushResult (SQLMsg* p)
{
MUTEX_LOCK (m_mtxResult);
m_queue_result.push (p);
MUTEX_UNLOCK (m_mtxResult);
}
// Pop one completed message; returns false when the queue is empty.
// On success the caller takes ownership of *pp.
bool CAsyncSQL::PopResult(SQLMsg** pp)
{
MUTEX_LOCK (m_mtxResult);
if (m_queue_result.empty())
{
MUTEX_UNLOCK (m_mtxResult);
return false;
}
*pp = m_queue_result.front();
m_queue_result.pop();
MUTEX_UNLOCK (m_mtxResult);
return true;
}
// Enqueue a message for the worker and wake it via the semaphore.
void CAsyncSQL::PushQuery (SQLMsg* p)
{
MUTEX_LOCK (m_mtxQuery);
m_queue_query.push (p);
// Released while holding the lock; the worker re-acquires it in CopyQuery().
m_sem.Release();
MUTEX_UNLOCK (m_mtxQuery);
}
// Look at the front of the query queue without removing it.
// Returns false when the queue is empty.
bool CAsyncSQL::PeekQuery (SQLMsg** pp)
{
MUTEX_LOCK (m_mtxQuery);
if (m_queue_query.empty())
{
MUTEX_UNLOCK (m_mtxQuery);
return false;
}
*pp = m_queue_query.front();
MUTEX_UNLOCK (m_mtxQuery);
return true;
}
// Pop the front of the query queue.
// NOTE(review): the iID parameter is unused — the front element is popped
// unconditionally regardless of the id passed; confirm callers expect this.
bool CAsyncSQL::PopQuery (int iID)
{
MUTEX_LOCK (m_mtxQuery);
if (m_queue_query.empty())
{
MUTEX_UNLOCK (m_mtxQuery);
return false;
}
m_queue_query.pop();
MUTEX_UNLOCK (m_mtxQuery);
return true;
}
// Look at the front of the worker-local copy queue (no lock: only the
// worker thread touches m_queue_query_copy).
bool CAsyncSQL::PeekQueryFromCopyQueue (SQLMsg** pp)
{
if (m_queue_query_copy.empty())
{
return false;
}
*pp = m_queue_query_copy.front();
return true;
}
// Drain the shared query queue into the worker-local copy queue under one
// lock acquisition. Returns -1 when nothing was pending, otherwise the
// total size of the copy queue.
int CAsyncSQL::CopyQuery()
{
MUTEX_LOCK (m_mtxQuery);
if (m_queue_query.empty())
{
MUTEX_UNLOCK (m_mtxQuery);
return -1;
}
while (!m_queue_query.empty())
{
SQLMsg* p = m_queue_query.front();
m_queue_query_copy.push (p);
m_queue_query.pop();
}
int count = m_queue_query_copy.size();
MUTEX_UNLOCK (m_mtxQuery);
return count;
}
// Drop the front element of the worker-local copy queue.
// Returns false when the queue was already empty.
bool CAsyncSQL::PopQueryFromCopyQueue()
{
	if (m_queue_query_copy.empty())
		return false;
	m_queue_query_copy.pop();
	return true;
}
// Copied-query statistics (worker throughput accounting).
int CAsyncSQL::GetCopiedQueryCount()
{
return m_iCopiedQuery;
}
void CAsyncSQL::ResetCopiedQueryCount()
{
m_iCopiedQuery = 0;
}
void CAsyncSQL::AddCopiedQueryCount (int iCopiedQuery)
{
m_iCopiedQuery += iCopiedQuery;
}
// Queue depths. NOTE(review): these read the queues without taking the
// mutexes, so the returned values are racy hints only.
DWORD CAsyncSQL::CountQuery()
{
return m_queue_query.size();
}
DWORD CAsyncSQL::CountResult()
{
return m_queue_result.size();
}
// Compute *rslt = *a - *b as a timeval, clamping to zero when a < b.
// (Name note: identifiers starting with a double underscore are reserved;
// kept unchanged because callers depend on it.)
void __timediff (struct timeval* a, struct timeval* b, struct timeval* rslt)
{
	if (a->tv_sec < b->tv_sec)
	{
		// a is wholly earlier than b: clamp to zero.
		rslt->tv_sec = 0;
		rslt->tv_usec = 0;
		return;
	}
	if (a->tv_sec == b->tv_sec)
	{
		// Same second: clamp to zero when a's microseconds are smaller.
		rslt->tv_sec = 0;
		rslt->tv_usec = (a->tv_usec < b->tv_usec) ? 0 : a->tv_usec - b->tv_usec;
		return;
	}
	// a->tv_sec > b->tv_sec: subtract, borrowing from seconds if needed.
	rslt->tv_sec = a->tv_sec - b->tv_sec;
	if (a->tv_usec < b->tv_usec)
	{
		rslt->tv_usec = a->tv_usec + 1000000 - b->tv_usec;
		rslt->tv_sec--;
	}
	else
	{
		rslt->tv_usec = a->tv_usec - b->tv_usec;
	}
}
// Forward declaration; defined earlier in this file.
void __timediff (struct timeval* a, struct timeval* b, struct timeval* rslt);
// Lightweight stopwatch used to flag queries that run longer than
// m_nInterval microseconds.
class cProfiler
{
public:
	// BUG FIX: the class previously declared both cProfiler() and
	// cProfiler(int nInterval = 100000); any default construction was
	// ambiguous and would not compile. A single constructor with a default
	// argument covers both uses unambiguously.
	cProfiler (int nInterval = 100000)
	{
		m_nInterval = nInterval;
		memset (&prev, 0, sizeof (prev));
		memset (&now, 0, sizeof (now));
		memset (&interval, 0, sizeof (interval));
		Start();
	}
	// Begin timing.
	void Start()
	{
		gettimeofday (&prev , (struct timezone*) 0);
	}
	// End timing and compute the elapsed interval.
	void Stop()
	{
		gettimeofday (&now, (struct timezone*) 0);
		__timediff (&now, &prev, &interval);
	}
	// True while the measured interval stays within the configured budget.
	bool IsOk()
	{
		if (interval.tv_sec > (m_nInterval / 1000000))
		{
			return false;
		}
		if (interval.tv_usec > m_nInterval)
		{
			return false;
		}
		return true;
	}
	struct timeval* GetResult()
	{
		return &interval;
	}
	long GetResultSec()
	{
		return interval.tv_sec;
	}
	long GetResultUSec()
	{
		return interval.tv_usec;
	}
private:
	int m_nInterval;         // budget in microseconds
	struct timeval prev;     // Start() timestamp
	struct timeval now;      // Stop() timestamp
	struct timeval interval; // now - prev, computed by Stop()
};
// Worker thread main loop: wait for queued queries, execute them, and either
// push results back (bReturn) or delete the message. After m_bEnd is set,
// flush whatever still sits in the shared queue before returning.
void CAsyncSQL::ChildLoop()
{
cProfiler profiler(500000);
while (!m_bEnd)
{
// Sleep until PushQuery() or Quit() releases the semaphore.
m_sem.Wait();
// Move pending work into the worker-local copy queue in one lock grab.
int count = CopyQuery();
if (count <= 0)
{
continue;
}
AddCopiedQueryCount (count);
SQLMsg* p;
while (count--)
{
profiler.Start();
if (!PeekQueryFromCopyQueue (&p))
{
continue;
}
// Re-apply the locale if the client library reconnected underneath us.
if (m_ulThreadID != mysql_thread_id (&m_hDB))
{
sys_log (0, "MySQL connection was reconnected. querying locale set");
while (!QueryLocaleSet());
m_ulThreadID = mysql_thread_id (&m_hDB);
}
if (mysql_real_query (&m_hDB, p->stQuery.c_str(), p->stQuery.length()))
{
p->uiSQLErrno = mysql_errno (&m_hDB);
sys_err ("AsyncSQL: query failed: %s (query: %s errno: %d)", mysql_error (&m_hDB), p->stQuery.c_str(), p->uiSQLErrno);
// Connection-level errors: leave the message at the front of the copy
// queue (no Pop below because of the continue) and wake ourselves to
// retry it on the next pass.
switch (p->uiSQLErrno)
{
case CR_SOCKET_CREATE_ERROR:
case CR_CONNECTION_ERROR:
case CR_IPSOCK_ERROR:
case CR_UNKNOWN_HOST:
case CR_SERVER_GONE_ERROR:
case CR_CONN_HOST_ERROR:
case ER_NOT_KEYFILE:
case ER_CRASHED_ON_USAGE:
case ER_CANT_OPEN_FILE:
case ER_HOST_NOT_PRIVILEGED:
case ER_HOST_IS_BLOCKED:
case ER_PASSWORD_NOT_ALLOWED:
case ER_PASSWORD_NO_MATCH:
case ER_CANT_CREATE_THREAD:
case ER_INVALID_USE_OF_NULL:
m_sem.Release();
sys_err ("AsyncSQL: retrying");
continue;
}
}
profiler.Stop();
// Log queries that exceeded the 500ms budget.
if (!profiler.IsOk())
{
sys_log (0, "[QUERY : LONG INTERVAL(OverSec %ld.%ld)] : %s", profiler.GetResultSec(), profiler.GetResultUSec(), p->stQuery.c_str());
}
PopQueryFromCopyQueue();
if (p->bReturn)
{
// Caller wants the result: fetch all result sets, then hand it back.
p->Store();
PushResult (p);
}
else
{
delete p;
}
++m_iQueryFinished;
}
}
// Shutdown path: drain queries still sitting in the shared queue so
// nothing queued before Quit() is silently dropped.
SQLMsg* p;
while (PeekQuery (&p))
{
if (m_ulThreadID != mysql_thread_id (&m_hDB))
{
sys_log (0, "MySQL connection was reconnected. querying locale set");
while (!QueryLocaleSet());
m_ulThreadID = mysql_thread_id (&m_hDB);
}
if (mysql_real_query (&m_hDB, p->stQuery.c_str(), p->stQuery.length()))
{
p->uiSQLErrno = mysql_errno (&m_hDB);
sys_err ("AsyncSQL::ChildLoop : mysql_query error: %s:\nquery: %s", mysql_error (&m_hDB), p->stQuery.c_str());
// Same connection-level error set as above; retry the message.
switch (p->uiSQLErrno)
{
case CR_SOCKET_CREATE_ERROR:
case CR_CONNECTION_ERROR:
case CR_IPSOCK_ERROR:
case CR_UNKNOWN_HOST:
case CR_SERVER_GONE_ERROR:
case CR_CONN_HOST_ERROR:
case ER_NOT_KEYFILE:
case ER_CRASHED_ON_USAGE:
case ER_CANT_OPEN_FILE:
case ER_HOST_NOT_PRIVILEGED:
case ER_HOST_IS_BLOCKED:
case ER_PASSWORD_NOT_ALLOWED:
case ER_PASSWORD_NO_MATCH:
case ER_CANT_CREATE_THREAD:
case ER_INVALID_USE_OF_NULL:
continue;
}
}
sys_log (0, "QUERY_FLUSH: %s", p->stQuery.c_str());
PopQuery (p->iID);
if (p->bReturn)
{
p->Store();
PushResult (p);
}
else
{
delete p;
}
++m_iQueryFinished;
}
}
// Completed-query counter (incremented by the worker loop).
int CAsyncSQL::CountQueryFinished()
{
return m_iQueryFinished;
}
void CAsyncSQL::ResetQueryFinished()
{
m_iQueryFinished = 0;
}
// Raw handle accessor for code that calls the MySQL C API directly.
MYSQL* CAsyncSQL::GetSQLHandle()
{
return &m_hDB;
}
size_t CAsyncSQL::EscapeString (char* dst, size_t dstSize, const char* src, size_t srcSize)
{
if (0 == srcSize)
{
memset (dst, 0, dstSize);
return 0;
}
if (0 == dstSize)
{
return 0;
}
if (dstSize < srcSize * 2 + 1)
{
char tmp[256];
size_t tmpLen = sizeof (tmp) > srcSize ? srcSize : sizeof (tmp);
strlcpy (tmp, src, tmpLen);
sys_err ("FATAL ERROR!! not enough buffer size (dstSize %u srcSize %u src%s: %s)", dstSize, srcSize, tmpLen != srcSize ? "(trimmed to 255 characters)" : "", tmp);
dst[0] = '\0';
return 0;
}
return mysql_real_escape_string (GetSQLHandle(), dst, src, srcSize);
}
// Update the stored locale and immediately push it to the live connection.
void CAsyncSQL2::SetLocale (const std::string & stLocale)
{
m_stLocale = stLocale;
QueryLocaleSet();
}
This is my AsyncSQL.h
#ifndef __INC_METIN_II_ASYNCSQL_H__
#define __INC_METIN_II_ASYNCSQL_H__
#include "../../libthecore/src/stdafx.h"
#include "../../libthecore/src/log.h"
#include "../../Ayarlar.h"
#include <string>
#include <queue>
#include <vector>
#include <map>
#include <mysql/server/mysql.h>
#include <mysql/server/errmsg.h>
#include <mysql/server/mysqld_error.h>
#include "Semaphore.h"
// Holds one result set returned by the server together with the counters
// that belong to it.  Owns pSQLResult and frees it on destruction.
typedef struct _SQLResult
{
	_SQLResult(): pSQLResult (NULL), uiNumRows (0), uiAffectedRows (0), uiInsertID (0) {}

	~_SQLResult()
	{
		// Release the MySQL result set exactly once.
		if (pSQLResult != NULL)
		{
			mysql_free_result (pSQLResult);
			pSQLResult = NULL;
		}
	}

	MYSQL_RES* pSQLResult;   // raw result set (NULL for statements without one)
	uint32_t uiNumRows;      // rows in pSQLResult, 0 when there is no result set
	uint32_t uiAffectedRows; // rows touched by the statement
	uint32_t uiInsertID;     // AUTO_INCREMENT id produced by the statement
} SQLResult;
typedef struct _SQLMsg
{
_SQLMsg() : m_pkSQL (NULL), iID (0), uiResultPos (0), pvUserData (NULL), bReturn (false), uiSQLErrno (0) {}
~_SQLMsg()
{
auto first = vec_pkResult.begin();
auto past = vec_pkResult.end();
while (first != past)
{
delete * (first++);
}
vec_pkResult.clear();
}
void Store()
{
do
{
SQLResult* pRes = new SQLResult;
pRes->pSQLResult = mysql_store_result (m_pkSQL);
pRes->uiInsertID = mysql_insert_id (m_pkSQL);
pRes->uiAffectedRows = mysql_affected_rows (m_pkSQL);
if (pRes->pSQLResult)
{
pRes->uiNumRows = mysql_num_rows (pRes->pSQLResult);
}
else
{
pRes->uiNumRows = 0;
}
vec_pkResult.push_back (pRes);
}
while (!mysql_next_result (m_pkSQL));
}
SQLResult* Get()
{
if (uiResultPos >= vec_pkResult.size())
{
return NULL;
}
return vec_pkResult[uiResultPos];
}
bool Next()
{
if (uiResultPos + 1 >= vec_pkResult.size())
{
return false;
}
++uiResultPos;
return true;
}
MYSQL* m_pkSQL;
int iID;
std::string stQuery;
std::vector<SQLResult *> vec_pkResult;
unsigned int uiResultPos;
void* pvUserData;
bool bReturn;
unsigned int uiSQLErrno;
} SQLMsg;
// Asynchronous MySQL client: queries are pushed onto a queue, executed by
// a worker thread (ChildLoop), and results are handed back via a result
// queue.  DirectQuery bypasses the queue and runs synchronously.
class CAsyncSQL
{
public:
CAsyncSQL();
virtual ~CAsyncSQL();
// Signals the worker thread to stop and joins it.
void Quit();
// Configures connection parameters; bNoThread skips spawning ChildLoop.
bool Setup (const char* c_pszHost, const char* c_pszUser, const char* c_pszPassword, const char* c_pszDB, const char* c_pszLocale, bool bNoThread = false, int iPort = 0);
// Copies connection parameters from an existing instance.
bool Setup (CAsyncSQL* sql, bool bNoThread = false);
bool Connect();
bool IsConnected()
{
return m_bConnected;
}
// Sends the locale/charset setup query for m_stLocale.
bool QueryLocaleSet();
// Fire-and-forget query (no result delivered).
void AsyncQuery (const char* c_pszQuery);
// Queued query whose results are delivered through PopResult().
void ReturnQuery (const char* c_pszQuery, void* pvUserData);
// Synchronous query on the caller's thread; caller owns the SQLMsg.
SQLMsg* DirectQuery (const char* c_pszQuery);
DWORD CountQuery();
DWORD CountResult();
void PushResult (SQLMsg* p);
bool PopResult (SQLMsg** pp);
// Worker-thread main loop: pops queries, runs them, pushes results.
void ChildLoop();
MYSQL* GetSQLHandle();
int CountQueryFinished();
void ResetQueryFinished();
// Escapes src into dst for SQL; dst needs at least srcSize * 2 + 1 bytes.
size_t EscapeString (char* dst, size_t dstSize, const char* src, size_t srcSize);
protected:
void Destroy();
void PushQuery (SQLMsg* p);
bool PeekQuery (SQLMsg** pp);
bool PopQuery (int iID);
bool PeekQueryFromCopyQueue (SQLMsg** pp );
INT CopyQuery();
bool PopQueryFromCopyQueue();
public:
int GetCopiedQueryCount();
void ResetCopiedQueryCount();
void AddCopiedQueryCount (int iCopiedQuery);
protected:
MYSQL m_hDB;                          // worker-thread connection handle
std::string m_stHost;
std::string m_stUser;
std::string m_stPassword;
std::string m_stDB;
std::string m_stLocale;
int m_iMsgCount;
int m_aiPipe[2];
int m_iPort;
std::queue<SQLMsg*> m_queue_query;      // pending queries (guarded by m_mtxQuery)
std::queue<SQLMsg*> m_queue_query_copy; // worker-local copy of the queue
std::queue<SQLMsg*> m_queue_result;     // finished queries (guarded by m_mtxResult)
volatile bool m_bEnd;                   // shutdown flag read by ChildLoop
pthread_t m_hThread;
pthread_mutex_t* m_mtxQuery;
pthread_mutex_t* m_mtxResult;
CSemaphore m_sem;                       // wakes the worker when queries arrive
int m_iQueryFinished;
unsigned long m_ulThreadID;             // mysql_thread_id, used to detect reconnects
bool m_bConnected;
int m_iCopiedQuery;
};
// CAsyncSQL variant that allows switching the locale after setup.
class CAsyncSQL2 : public CAsyncSQL
{
public:
// Stores the locale and immediately issues the locale query.
void SetLocale (const std::string & stLocale);
};
#endif
And this is the function that triggers the error (from <optional>, line 1052):
#ifdef __cpp_lib_three_way_comparison
template<typename _Tp>
constexpr strong_ordering
operator<=>(const optional<_Tp>& __x, nullopt_t) noexcept
{ return bool(__x) <=> false; }
#else
After reading a document Microsoft released, I tried writing it as `<= > false;`, and got an error again.
Best Regards.
I have no idea why it treats `bool(__x) <=> false` as a comparison between a bool and an int.
I would think you got some strange macro in your files included before to include the header that is going to break the standard code.
I would suggest you try to move above the standard headers and below them your 'user defined' headers.
#include <string>
#include <queue>
#include <vector>
#include <map>
#include <mysql/server/mysql.h>
#include <mysql/server/errmsg.h>
#include <mysql/server/mysqld_error.h>
#include "../../libthecore/src/stdafx.h"
#include "../../libthecore/src/log.h"
#include "../../Ayarlar.h"
#include "Semaphore.h"
EDIT:
i ve found the cause of the problem.
a macro defined in "libthecore/stdafx.h" (I have the files the author is using; they are public).
#ifndef false
#define false 0
#define true (!false)
#endif
It causes `false` to be read as an int, which makes the spaceship operator fail with the error the author showed. Move the standard headers above your own, or remove the macro, to solve the error.
I wrote a C++ demo of a TCP server with libuv. When I checked the CPU usage, I found the demo runs on a single thread — how can I implement it with multiple threads?
Currently, the demo can handle 100,000+ TCP requests per second, but it can only use 1 CPU.
Code:
#include <atomic>
#include <cstring>
#include <iostream>
#include <map>
#include <mutex>
#include <thread>

#include "uv.h"
using namespace std;
// Single default event loop shared by the whole program (runs on one thread).
auto loop = uv_default_loop();
// Listen address, filled in by main() via uv_ip4_addr().
struct sockaddr_in addr;
// A pending write: the libuv request plus the heap buffer it owns until
// echo_write() completes and free_write_req() reclaims both.
typedef struct {
uv_write_t req;
uv_buf_t buf;
} write_req_t;
// Arguments for starting a read from a worker (see acceptClientRead);
// currently unused because the uv_queue_work call is commented out.
typedef struct {
uv_stream_t* client;
uv_alloc_cb alloc_cb;
uv_read_cb read_cb;
} begin_read_req;
// libuv allocation callback: hands the reader a freshly malloc'd buffer.
// On allocation failure, report a zero-length buffer instead of leaving
// len set while base is NULL — libuv then delivers an error to the read
// callback rather than writing through a null pointer.
void alloc_buffer(uv_handle_t *handle, size_t suggested_size, uv_buf_t *buf) {
    buf->base = (char*)malloc(suggested_size);
    buf->len = buf->base ? suggested_size : 0;
}
// Releases a write_req_t together with the payload buffer it owns.
void free_write_req(uv_write_t *req) {
    write_req_t *write_request = (write_req_t*)req;
    free(write_request->buf.base);
    free(write_request);
}
// Write-completion callback: logs a failure, then reclaims the request
// and its buffer unconditionally.
void echo_write(uv_write_t *req, int status) {
    if (status != 0)
        fprintf(stderr, "Write error %s\n", uv_strerror(status));
    free_write_req(req);
}
// Read callback: answers every chunk of incoming data with "+OK\r\n".
// Fixes over the original: the malloc results are checked, the reply is
// built with memcpy, and the uv_tcp_t allocated in on_new_connection is
// freed in the close callback (closing with a null callback leaked it).
void echo_read(uv_stream_t *client, ssize_t nread, const uv_buf_t *buf) {
    if (nread > 0) {
        auto req = (write_req_t*)malloc(sizeof(write_req_t));
        auto *reply = (char*)malloc(5);
        if (req && reply) {
            memcpy(reply, "+OK\r\n", 5);
            req->buf = uv_buf_init(reply, 5);
            uv_write((uv_write_t*)req, client, &req->buf, 1, echo_write);
        } else {
            // Out of memory: drop the response, don't leak the halves.
            free(req);
            free(reply);
        }
    }
    if (nread < 0) {
        if (nread != UV_EOF)
            fprintf(stderr, "Read error %s\n", uv_err_name(static_cast<unsigned int>(nread)));
        // Free the handle once libuv has released it.
        uv_close((uv_handle_t*)client, [](uv_handle_t *h) { free(h); });
    }
    free(buf->base);
}
// Worker-pool entry that starts reading on a client; currently unused —
// see the commented-out uv_queue_work call in on_new_connection.
// NOTE(review): uv_read_start is not safe to call from a uv_queue_work
// worker thread — it would race with the loop thread.  Confirm before
// re-enabling this path.
void acceptClientRead(uv_work_t *req) {
begin_read_req *data = (begin_read_req *)req->data;
uv_read_start(data->client, data->alloc_cb, data->read_cb);
}
// Accept callback for the listening socket: initializes a client handle,
// accepts the connection and starts echo reads on it.
//
// The original also malloc'd a uv_work_t and a begin_read_req on every
// connection for a disabled uv_queue_work path and never freed either;
// those per-connection leaks are removed along with the dead code.
void on_new_connection(uv_stream_t *server, int status) {
    if (status < 0) {
        cout << "New connection error:" << uv_strerror(status);
        return;
    }
    uv_tcp_t *client = (uv_tcp_t *)malloc(sizeof(uv_tcp_t));
    if (client == nullptr)
        return;  // out of memory: drop the connection
    uv_tcp_init(loop, client);
    if (uv_accept(server, (uv_stream_t *)client) == 0) {
        uv_read_start((uv_stream_t *)client, alloc_buffer, echo_read);
    }
    else {
        // Close with a callback that frees the handle; a null close
        // callback leaked the malloc'd uv_tcp_t.
        uv_close((uv_handle_t *)client, [](uv_handle_t *h) { free(h); });
    }
}
// Periodic heartbeat: prints the id of the thread running the loop.
void timer_callback(uv_timer_t* handle) {
    auto const tid = std::this_thread::get_id();
    cout << tid << "---------" << "hello" << endl;
}
int main() {
uv_tcp_t server{};
uv_tcp_init(loop, &server);
uv_ip4_addr("0.0.0.0", 8790, &addr);
uv_tcp_bind(&server, (const struct sockaddr *) &addr, 0);
uv_listen((uv_stream_t *)&server, 511, on_new_connection);
uv_run(loop, UV_RUN_DEFAULT);
return 0;
}
Of course, I could make the write step asynchronous in the method "echo_read", but since I don't do any real work before the write, can I make the demo multi-threaded in another way to improve the throughput?
Closed. This question needs debugging details. It is not currently accepting answers.
Edit the question to include desired behavior, a specific problem or error, and the shortest code necessary to reproduce the problem. This will help others answer the question.
Closed 5 years ago.
Improve this question
I am converting our code to use IOCP and I got the communication relatively stable, but the memory usage of the application is increasing. Looks like I am getting back (on completion function calls) much fewer objects of OverlappedEx than I create. My code is below. What am I doing wrong?
#ifndef NETWORK_DATA
#define NETWORK_DATA
#include <afxwin.h>
#include <vector>
#include <string>
#include "CriticalSectionLocker.h"
using namespace std;
// IOCP worker: dequeues completions and routes them to the owning
// ServerData (passed as the completion key).
//
// Fixes over the original:
//  * the completion key must be received through a ULONG_PTR — the old
//    (LPDWORD)&lpContext cast truncates the pointer on Win64;
//  * failed and zero-byte completions are now dispatched too, so the
//    OverlappedEx is always returned to its pool instead of leaking
//    (the growing-memory symptom described in the question).
DWORD NetworkManager::NetworkThread(void* param)
{
    bool bRun = true;
    while (bRun)
    {
        DWORD wait = ::WaitForSingleObject(CCommunicationManager::s_hShutdownEvent, 0);
        if (WAIT_OBJECT_0 == wait)
        {
            bRun = false;
            DEBUG_LOG0("Shutdown event was signalled thread");
        }
        else
        {
            DWORD dwBytesTransfered = 0;
            ULONG_PTR completionKey = 0;
            OVERLAPPED* pOverlapped = nullptr;
            BOOL bReturn = GetQueuedCompletionStatus(s_IOCompletionPort,
                                                     &dwBytesTransfered,
                                                     &completionKey,
                                                     &pOverlapped,
                                                     INFINITE);
            ServerData* networkData = reinterpret_cast<ServerData*>(completionKey);
            if (nullptr == networkData)
            {
                DEBUG_LOG0("invalid context");
            }
            else if (pOverlapped != nullptr)
            {
                // OVERLAPPED is the base of OverlappedEx, so a checked
                // static_cast is the right tool here.
                OverlappedEx* data = static_cast<OverlappedEx*>(pOverlapped);
                // Dispatch even when bReturn is FALSE or no bytes were
                // transferred: Complete* must run so the OverlappedEx is
                // released back to the pool.
                switch (data->m_opType)
                {
                case OverlappedEx::OP_READ:
                    networkData->CompleteReceive(dwBytesTransfered, data);
                    break;
                case OverlappedEx::OP_WRITE:
                    networkData->CompleteSend(dwBytesTransfered, data);
                    break;
                }
            }
            // else: GetQueuedCompletionStatus itself failed (no packet
            // was dequeued) — nothing to release.
        }
    }
    return 0;
}
// Transport used by a ServerData endpoint; selects WSARecvFrom/WSASendTo
// (UDP) vs WSARecv/WSASend (TCP) in StartReceiving/StartSending.
enum NetworkType
{
UDP,
TCP
};
// An OVERLAPPED extended with the data buffer and bookkeeping for one
// I/O operation; recovered from the OVERLAPPED* that
// GetQueuedCompletionStatus returns.
struct OverlappedEx : public OVERLAPPED
{
enum OperationType
{
OP_READ,
OP_WRITE
};
const static int MAX_PACKET_SIZE = 2048;
WSABUF m_wBuf;                   // points into m_buffer
char m_buffer[MAX_PACKET_SIZE];
OperationType m_opType;          // how NetworkThread dispatches this completion
OverlappedEx()
{
Clear();
m_refCount = 1;
}
void AddRef()
{
::InterlockedIncrement(&m_refCount);
}
// NOTE(review): Release() only decrements the counter; nothing ever
// deletes the object when the count reaches zero, so the reference
// count has no effect on the object's lifetime (see the growing-memory
// discussion in the question).
void Release()
{
::InterlockedDecrement(&m_refCount);
}
int Refcount() const
{
// Atomic read implemented as "add 0"; the cast away from volatile
// LONG is dubious but preserved as written.
return InterlockedExchangeAdd((unsigned long*)&m_refCount, 0UL);
}
~OverlappedEx()
{
Clear();
}
// Resets the embedded OVERLAPPED fields and points m_wBuf back at the
// internal buffer; must be called before every reuse from the pool.
void Clear()
{
memset(m_buffer, 0, MAX_PACKET_SIZE);
m_wBuf.buf = m_buffer;
m_wBuf.len = MAX_PACKET_SIZE;
Internal = 0;
InternalHigh = 0;
Offset = 0;
OffsetHigh = 0;
hEvent = nullptr;
m_opType = OP_READ;
}
private:
volatile LONG m_refCount;
};
// One network endpoint (UDP or TCP socket) registered with the I/O
// completion port.  Owns a pool of OverlappedEx buffers (m_olps), a queue
// of received telegrams (m_receiveDataQueue) and the locks protecting
// them.  CompleteReceive/CompleteSend run on the IOCP worker threads.
class ServerData
{
public:
// NOTE(review): MAX_REVEIVE_QUEUE_SIZE (typo) duplicates
// MAX_RECEIVE_QUEUE_SIZE below; only the correctly spelled one is used
// in this class.
const static int MAX_REVEIVE_QUEUE_SIZE = 100;
const static int MAX_PACKET_SIZE = 2048;
const static int MAX_SEND_QUEUE_SIZE = 10;
const static int MAX_RECEIVE_QUEUE_SIZE = 100;
const static int MAX_OVERLAPPED_STRUCTS = 20;  // initial size of the buffer pool
// Pre-allocates the OverlappedEx pool and associates the socket with the
// completion port, using `this` as the completion key that the worker
// thread casts back to ServerData*.
ServerData(NetworkType netType, const string& sName, CCommunicationManager::CommHandle handle,
SOCKET sock, HANDLE IOPort) :
m_sName(sName)
{
InitializeCriticalSection(&m_receiveQueLock);
InitializeCriticalSection(&m_objectLock);
m_Handle = handle;
m_Socket = sock;
m_nIPAddress = 0;
m_netType = netType;
m_bEnabled = true;
m_ovlpIndex = 0;
for (int i = 0; i < MAX_OVERLAPPED_STRUCTS; ++i)
{
m_olps.push_back(new OverlappedEx);
}
/* Associate socket with completion handle */
if (m_Socket != 0)
{
CreateIoCompletionPort( reinterpret_cast<HANDLE>(m_Socket), IOPort, reinterpret_cast<ULONG_PTR>(this), 0 );
}
}
// NOTE(review): the pooled OverlappedEx objects in m_olps are never
// deleted here, and the socket is closed while I/O may still be in
// flight — confirm shutdown ordering with the IOCP worker threads.
~ServerData()
{
CriticalSectionLocker lock(&m_receiveQueLock);
DeleteCriticalSection(&m_receiveQueLock);
DeleteCriticalSection(&m_objectLock);
closesocket(m_Socket);
}
const string& Name() const { return m_sName; }
bool Enabled() const { return m_bEnabled; }
void SetEnabled(bool bEnabled)
{
m_bEnabled = bEnabled;
}
int Handle() const { return m_Handle; }
void SetHandle(int handle)
{
m_Handle = handle;
}
unsigned long IPAddress() const { return m_nIPAddress; }
SOCKET Socket() const
{
return m_Socket;
}
void SetSocket(SOCKET sock)
{
m_Socket = sock;
}
void SetIPAddress(unsigned long nIP)
{
m_nIPAddress = nIP;
}
// Placeholder: always rejects. Presumably telegram validation was
// elided from the question.
bool ValidTelegram(const vector<char>& telegram) const
{
return false;
}
// Pops a buffer from the pool, or NULL when the pool is empty.
// NOTE(review): reads m_olps without taking m_objectLock itself — all
// current callers hold the lock; keep it that way.
OverlappedEx* GetBuffer()
{
OverlappedEx* ret = nullptr;
if (!m_olps.empty())
{
ret = m_olps.front();
m_olps.pop_front();
}
return ret;
}
// Called from the IOCP thread when a read completes: queues the payload,
// returns the buffer to the pool and posts the next read.
// NOTE(review): the buffer is pushed back into the pool right after
// Release() — see the answer below about the refcount having no effect.
void CompleteReceive(size_t numBytes, OverlappedEx* data)
{
//DEBUG_LOG1("%d buffers are available", AvailableBufferCount());
if (numBytes > 0)
{
vector<char> v(data->m_buffer, data->m_buffer + numBytes);
ReceivedData rd;
rd.SetData(v);
EnqueReceiveMessage(rd);
}
data->Release();
{
CriticalSectionLocker lock(&m_objectLock);
m_olps.push_back(data);
// DEBUG_LOG1("Queue size: %d", m_olps.size());
}
StartReceiving();
}
// Called from the IOCP thread when a write completes: returns the buffer
// to the pool.
void CompleteSend(size_t numBytes, OverlappedEx* data)
{
data->Release();
{
CriticalSectionLocker lock(&m_objectLock);
m_olps.push_back(data);
//DEBUG_LOG1("Queue size: %d", m_olps.size());
}
//DEBUG_LOG2("Object: %s num sent: %d", Name().c_str(), numBytes);
}
// Posts an overlapped receive (WSARecvFrom for UDP, WSARecv for TCP).
// On immediate failure the buffer is returned to the pool.
void StartReceiving()
{
DWORD bytesRecv = 0;
sockaddr_in senderAddr;
DWORD flags = 0;
int senderAddrSize = sizeof(senderAddr);
int rc = 0;
CriticalSectionLocker lock(&m_objectLock);
auto olp = GetBuffer();
if (!olp)
{
// (condition elided in the question)
if (...)
{
m_olps.push_back(new OverlappedEx);
olp = GetBuffer();
}
else
{
if (...)
{
DEBUG_LOG1("Name: %s ************* NO AVAILABLE BUFFERS - bailing ***************", Name().c_str());
}
return;
}
}
olp->Clear();
olp->m_opType = OverlappedEx::OP_READ;
olp->AddRef();
switch(GetNetworkType())
{
case UDP:
{
rc = WSARecvFrom(Socket(),
&olp->m_wBuf,
1,
&bytesRecv,
&flags,
(SOCKADDR *)&senderAddr,
&senderAddrSize, (OVERLAPPED*)olp, NULL);
}
break;
case TCP:
{
rc = WSARecv(Socket(),
&olp->m_wBuf,
1,
&bytesRecv,
&flags,
(OVERLAPPED*)olp, NULL);
}
break;
}
if (SOCKET_ERROR == rc)
{
DWORD err = WSAGetLastError();
if (err != WSA_IO_PENDING)
{
olp->Release();
m_olps.push_back(olp);
}
}
}
// Copies the outgoing payload into the overlapped buffer.
// NOTE(review): payloads longer than MAX_PACKET_SIZE are silently
// truncated by the min() below.
void SetWriteBuf(const SendData& msg, OverlappedEx* data)
{
int len = min(msg.Data().size(), MAX_PACKET_SIZE);
memcpy(data->m_buffer, &msg.Data()[0], len);
data->m_wBuf.buf = data->m_buffer;
data->m_wBuf.len = len;
}
// Posts an overlapped send (WSASendTo for UDP, WSASend for TCP).
// On immediate failure the buffer is returned to the pool.
void StartSending(const SendData& msg)
{
DEBUG_LOG1("device name: %s", Name().c_str());
int rc = 0;
DWORD bytesSent = 0;
DWORD flags = 0;
SOCKET sock = Socket();
int addrSize = sizeof(sockaddr_in);
CriticalSectionLocker lock(&m_objectLock);
//UpdateOverlapped(OverlappedEx::OP_WRITE);
auto olp = GetBuffer();
if (!olp)
{
// (condition elided in the question)
if (...)
{
m_olps.push_back(new OverlappedEx);
olp = GetBuffer();
DEBUG_LOG2("name: %s ************* NO AVAILABLE BUFFERS new size: %d ***************", Name().c_str(), m_olps.size());
}
else
{
if (...)
{
DEBUG_LOG1("Name: %s ************* NO AVAILABLE BUFFERS - bailing ***************", Name().c_str());
}
return;
}
}
olp->Clear();
olp->m_opType = OverlappedEx::OP_WRITE;
olp->AddRef();
SetWriteBuf(msg, olp);
switch(GetNetworkType())
{
case UDP:
rc = WSASendTo(Socket(), &olp->m_wBuf, 1,
&bytesSent, flags, (sockaddr*)&msg.SendAddress(),
addrSize, (OVERLAPPED*)olp, NULL);
break;
case TCP:
rc = WSASend(Socket(), &olp->m_wBuf, 1,
&bytesSent, flags, (OVERLAPPED*)olp, NULL);
break;
}
if (SOCKET_ERROR == rc)
{
DWORD err = WSAGetLastError();
if (err != WSA_IO_PENDING)
{
olp->Release();
m_olps.push_back(olp);
}
}
}
size_t ReceiveQueueSize()
{
CriticalSectionLocker lock(&m_receiveQueLock);
return m_receiveDataQueue.size();
}
// Drains the whole receive queue into `data`.
void GetAllData(vector <ReceivedData> & data)
{
CriticalSectionLocker lock(&m_receiveQueLock);
while (m_receiveDataQueue.size() > 0)
{
data.push_back(m_receiveDataQueue.front());
m_receiveDataQueue.pop_front();
}
}
// Pops one message if available; leaves `msg` untouched when empty.
void DequeReceiveMessage(ReceivedData& msg)
{
CriticalSectionLocker lock(&m_receiveQueLock);
if (m_receiveDataQueue.size() > 0)
{
msg = m_receiveDataQueue.front();
m_receiveDataQueue.pop_front();
}
}
// Appends to the receive queue, dropping (and rate-limited logging) when
// the queue is full.
// NOTE(review): takes T&& but pushes `data` without std::forward, so the
// argument is always copied — confirm whether a move was intended.
template <class T>
void EnqueReceiveMessage(T&& data)
{
CriticalSectionLocker lock(&m_receiveQueLock);
if (m_receiveDataQueue.size() <= MAX_RECEIVE_QUEUE_SIZE)
{
m_receiveDataQueue.push_back(data);
}
else
{
static int s_nLogCount = 0;
if (s_nLogCount % 100 == 0)
{
DEBUG_LOG2("Max queue size was reached handle id: %d in %s", Handle(), Name().c_str());
}
s_nLogCount++;
}
}
NetworkType GetNetworkType() const
{
return m_netType;
}
private:
// Non-copyable: owns a socket and raw buffer pointers.
ServerData(const ServerData&);
ServerData& operator=(const ServerData&);
private:
bool m_bEnabled; //!< This member flags if this reciever is enabled for receiving incoming connections.
int m_Handle; //!< This member holds the handle for this receiver.
SOCKET m_Socket; //!< This member holds the socket information for this receiver.
unsigned long m_nIPAddress; //!< This member holds an IP address the socket is bound to.
deque < ReceivedData > m_receiveDataQueue;   // guarded by m_receiveQueLock
CRITICAL_SECTION m_receiveQueLock;
CRITICAL_SECTION m_objectLock;               // guards m_olps
string m_sName;
NetworkType m_netType;
deque<OverlappedEx*> m_olps;                 // pool of reusable I/O buffers
size_t m_ovlpIndex;
};
#endif
Your implementation of void Release() makes no sense — you decrement m_refCount, and then what? It must be:
void Release()
{
if (!InterlockedDecrement(&m_refCount)) delete this;
}
As a result you never free the OverlappedEx* data — this is just what I spotted at a glance, and it causes the memory leak.
Also, some advice: WaitForSingleObject(CCommunicationManager::s_hShutdownEvent, 0) is a bad way to detect shutdown. Call only GetQueuedCompletionStatus, and for shutdown call PostQueuedCompletionStatus(s_IOCompletionPort, 0, 0, 0) several times (once per thread listening on s_IOCompletionPort); when a thread sees pOverlapped == 0, it should just exit.
use
OverlappedEx* data = static_cast<OverlappedEx*>(pOverlapped);
instead of reinterpret_cast
make ~OverlappedEx() private - it must not be direct called, only via Release
olp->Release();
m_olps.push_back(olp);
After you call Release() on an object you must not access it again, so use either olp->Release() or m_olps.push_back(olp), but not both — doing both defeats the whole point of Release. Perhaps you should overload operator delete for OverlappedEx to call m_olps.push_back(olp) (and of course overload operator new as well).
Again, why the cast in (OVERLAPPED*)olp? Because you inherit your struct from OVERLAPPED, the compiler performs that conversion automatically.
I am writing a C++ wrapper for sockets on Linux.
I can connect read/write to a http server my poll function works perfectly for reading but for some reason it won't work with writing. I have tried using gdb and it appears poll sets fd.revents to 0 when fd.events is POLLOUT
poll code:
/**
 * @brief assigns request.reply based on fd.revents
 *
 * Tests POLLIN/POLLOUT as independent bits instead of switching on the
 * exact revents value: the original switch was missing a `break` after
 * the POLLOUT case, so it fell through to `default` and overwrote the
 * reply with 0, and an exact-match switch also fails whenever the kernel
 * reports any additional bit alongside POLLIN/POLLOUT.
 */
static void translate_request(mizaru::PollRequest &request, pollfd &fd)
{
	assert( (fd.revents & POLLHUP) == 0);
	assert( (fd.revents & POLLERR) == 0);
	assert( (fd.revents & POLLNVAL) == 0);

	request.reply = 0;
	if (fd.revents & POLLIN)
		request.reply |= mizaru::POLL_REPLY_READ;
	if (fd.revents & POLLOUT)
		request.reply |= mizaru::POLL_REPLY_WRITE;
}
/**
 * @brief fills in fd.events based on request.request
 * and fd.fd based on request.sock
 */
static void prep_request(mizaru::PollRequest &request, pollfd &fd)
{
	fd.fd = request.sock.get_handle();

	if (request.request == mizaru::PollType::POLL_READ)
		fd.events = POLLIN;
	else if (request.request == mizaru::PollType::POLL_WRITE)
		fd.events = POLLOUT;
	else
		fd.events = POLLIN | POLLOUT; // POLL_RW and anything else
}
// Polls a single socket and stores the outcome in request.reply.
//
// The pollfd is value-initialized so revents is a defined 0; on error
// ::poll() leaves revents untouched, and the original then translated an
// uninitialized field.  The return value of ::poll() is now checked: on
// timeout (0) or error (-1) revents is forced to 0 so the caller sees
// POLL_REPLY_FAIL.  ::poll is called with explicit qualification rather
// than relying on ADL to skip past mizaru::trans::poll.
void mizaru::trans::poll(mizaru::PollRequest &request,const std::chrono::milliseconds& wait_time) noexcept
{
	pollfd fd{};
	prep_request(request, fd);
	if (::poll(&fd, 1, static_cast<int>(wait_time.count())) <= 0)
		fd.revents = 0; // timeout or error -> nothing is ready
	translate_request(request, fd);
}
mizaru::PollRequest struct :
// One poll transaction: which socket to watch (sock), what to wait for
// (request) and, after mizaru::poll() returns, what was actually ready
// (reply — a POLL_REPLY_* bitmask, POLL_REPLY_FAIL when nothing was).
struct PollRequest
{
PollRequest(PollType request, const SyncSocket &sock) : sock(sock), request(request), reply(POLL_REPLY_FAIL) {}
const SyncSocket & sock; // NOTE(review): reference member — the socket must outlive this request
PollType request;
uint8_t reply;
};
SyncSocket.get_handle() just returns the fd returned by socket(int,int,int) in sys/socket.h
test function :
// Integration test against a live HTTP server: a freshly connected socket
// should be writable but not readable; after sending a request it should
// become readable.  Returns true when every expectation holds.
bool test_poll()
{
mizaru::IPv4 ip ( "54.225.138.124" );
mizaru::SyncSocketTCP sock ( ip, 80 ,true);
mizaru::PollRequest request{mizaru::PollType::POLL_READ, sock};
std::chrono::milliseconds time(1000);
// Nothing has been sent yet, so the server has nothing for us to read:
// a READ reply here means the poll wrapper misbehaved.
mizaru::poll(request, time);
if(request.reply == mizaru::POLL_REPLY_READ)
{
std::cout << "fail test_poll first read" <<std::endl;
return false;
}
// A connected TCP socket with an empty send buffer should be writable.
request.request = mizaru::PollType::POLL_WRITE;
mizaru::poll(request, time);
if(request.reply != mizaru::POLL_REPLY_WRITE)
{
std::cout << "fail test_poll first write" << std::endl;
return false;
}
std::string toWrite ( "GET / http/1.1\nHost: httpbin.org\n\n" );
mizaru::byte_buffer write_buff;
for ( char c : toWrite )
{
write_buff.push_back ( c );
}
unsigned int r_value = sock.write ( write_buff );
if (r_value != toWrite.size())
{
std::cout << "fail test_poll r_value" << std::endl;
return false;
}
// The server's response should make the socket readable within 1s.
request.request = mizaru::PollType::POLL_READ;
mizaru::poll(request, time);
if(request.reply != mizaru::POLL_REPLY_READ)
{
std::cout << "fail test_poll second read" << std::endl;
return false;
}
// ...and it should still be writable (this is the case that exposed the
// missing break in translate_request).
request.request = mizaru::PollType::POLL_WRITE;
mizaru::poll(request, time);
if(request.reply != mizaru::POLL_REPLY_WRITE)
{
std::cout << "fail test_poll second write " << std::endl;
return false;
}
return true;
}
PollType enum :
enum PollType {POLL_READ, POLL_WRITE, POLL_RW}; // readiness to wait for
POLL_REPLY_* constants :
constexpr uint8_t POLL_REPLY_READ = 0x01;  // socket is readable
constexpr uint8_t POLL_REPLY_WRITE = 0x02; // socket is writable
constexpr uint8_t POLL_REPLY_RW = POLL_REPLY_READ | POLL_REPLY_WRITE;
constexpr uint8_t POLL_REPLY_FAIL = 0;     // nothing ready / poll failed
I am sorry that the sample code is not directly compilable; I was trying to keep it short.
Since I get a proper HTTP 200 OK reply, you can assume that connecting and read/write are handled properly. The test fails when polling for write.
In translate_request():
switch(fd.revents)
{
case (POLLIN | POLLOUT) :
request.reply = mizaru::POLL_REPLY_RW;
break;
case POLLIN :
request.reply = mizaru::POLL_REPLY_READ;
break;
case POLLOUT :
request.reply = mizaru::POLL_REPLY_WRITE;
default :
request.reply = 0;
}
You are missing a break in both the POLLOUT and default cases. POLLOUT falls through to default, erasing the evidence that POLLOUT occurred.