From 0295f347d6394294cb2c81741ece78548d4cafc6 Mon Sep 17 00:00:00 2001 From: wujing Date: Thu, 14 Jan 2021 10:53:07 +0800 Subject: [PATCH 09/53] fix low-probability coredump in CRI streaming services under high concurrency Signed-off-by: wujing --- .../cri/cri_container_manager_service_impl.cc | 16 +- src/daemon/entry/cri/request_cache.cc | 74 ++++++--- src/daemon/entry/cri/request_cache.h | 29 +++- .../cri/websocket/service/attach_serve.cc | 60 ++++--- .../cri/websocket/service/attach_serve.h | 3 +- .../entry/cri/websocket/service/exec_serve.cc | 71 +++++---- .../entry/cri/websocket/service/exec_serve.h | 3 +- .../entry/cri/websocket/service/ws_server.cc | 148 ++++++++++-------- .../entry/cri/websocket/service/ws_server.h | 20 ++- src/utils/cpputils/read_write_lock.cc | 59 +++++++ src/utils/cpputils/read_write_lock.h | 90 +++++++++++ src/utils/cpputils/stoppable_thread.cc | 4 - 12 files changed, 392 insertions(+), 185 deletions(-) create mode 100644 src/utils/cpputils/read_write_lock.cc create mode 100644 src/utils/cpputils/read_write_lock.h diff --git a/src/daemon/entry/cri/cri_container_manager_service_impl.cc b/src/daemon/entry/cri/cri_container_manager_service_impl.cc index 45ecf9f2..812469ee 100644 --- a/src/daemon/entry/cri/cri_container_manager_service_impl.cc +++ b/src/daemon/entry/cri/cri_container_manager_service_impl.cc @@ -1251,15 +1251,9 @@ void ContainerManagerServiceImpl::Exec(const runtime::v1alpha2::ExecRequest &req return; } RequestCache *cache = RequestCache::GetInstance(); - runtime::v1alpha2::ExecRequest *execReq = new (std::nothrow) runtime::v1alpha2::ExecRequest(req); - if (execReq == nullptr) { - error.SetError("Out of memory"); - return; - } - std::string token = cache->Insert(const_cast(execReq)); + std::string token = cache->InsertExecRequest(req); if (token.empty()) { error.SetError("failed to get a unique token!"); - delete execReq; return; } std::string url = BuildURL("exec", token); @@ -1303,15 +1297,9 @@ void 
ContainerManagerServiceImpl::Attach(const runtime::v1alpha2::AttachRequest return; } RequestCache *cache = RequestCache::GetInstance(); - runtime::v1alpha2::AttachRequest *attachReq = new (std::nothrow) runtime::v1alpha2::AttachRequest(req); - if (attachReq == nullptr) { - error.SetError("Out of memory"); - return; - } - std::string token = cache->Insert(const_cast(attachReq)); + std::string token = cache->InsertAttachRequest(req); if (token.empty()) { error.SetError("failed to get a unique token!"); - delete attachReq; return; } std::string url = BuildURL("attach", token); diff --git a/src/daemon/entry/cri/request_cache.cc b/src/daemon/entry/cri/request_cache.cc index a3cb3771..b502715a 100644 --- a/src/daemon/entry/cri/request_cache.cc +++ b/src/daemon/entry/cri/request_cache.cc @@ -41,12 +41,26 @@ RequestCache *RequestCache::GetInstance() noexcept return cache; } -std::string RequestCache::Insert(::google::protobuf::Message *req) +std::string RequestCache::InsertExecRequest(const runtime::v1alpha2::ExecRequest &req) { - if (req == nullptr) { - ERROR("invalid request"); + std::lock_guard lock(m_mutex); + // Remove expired entries. + GarbageCollection(); + // If the cache is full, reject the request. + if (m_ll.size() == MaxInFlight) { + ERROR("too many cache in flight!"); return ""; } + auto token = UniqueToken(); + CacheEntry tmp; + tmp.SetValue(token, &req, nullptr, std::chrono::system_clock::now() + std::chrono::minutes(1)); + m_ll.push_front(tmp); + m_tokens.insert(std::make_pair(token, tmp)); + return token; +} + +std::string RequestCache::InsertAttachRequest(const runtime::v1alpha2::AttachRequest &req) +{ std::lock_guard lock(m_mutex); // Remove expired entries. 
GarbageCollection(); @@ -56,7 +70,8 @@ std::string RequestCache::Insert(::google::protobuf::Message *req) return ""; } auto token = UniqueToken(); - CacheEntry tmp { token, req, std::chrono::system_clock::now() + std::chrono::minutes(1) }; + CacheEntry tmp; + tmp.SetValue(token, nullptr, &req, std::chrono::system_clock::now() + std::chrono::minutes(1)); m_ll.push_front(tmp); m_tokens.insert(std::make_pair(token, tmp)); return token; @@ -64,16 +79,12 @@ std::string RequestCache::Insert(::google::protobuf::Message *req) void RequestCache::GarbageCollection() { - std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); + auto now = std::chrono::system_clock::now(); while (!m_ll.empty()) { CacheEntry oldest = m_ll.back(); if (now < oldest.expireTime) { return; } - if (oldest.req != nullptr) { - delete oldest.req; - oldest.req = nullptr; - } m_ll.pop_back(); m_tokens.erase(oldest.token); } @@ -124,34 +135,59 @@ std::string RequestCache::UniqueToken() ERROR("create unique token failed!"); return ""; } + bool RequestCache::IsValidToken(const std::string &token) { + std::lock_guard lock(m_mutex); + return static_cast(m_tokens.count(token)); } // Consume the token (remove it from the cache) and return the cached request, if found. 
-::google::protobuf::Message *RequestCache::Consume(const std::string &token, bool &found) +runtime::v1alpha2::ExecRequest RequestCache::ConsumeExecRequest(const std::string &token) { std::lock_guard lock(m_mutex); - found = false; - if (!IsValidToken(token)) { + if (m_tokens.count(token) == 0 || m_tokens[token].execRequest.size() == 0) { ERROR("Invalid token"); - return nullptr; + return runtime::v1alpha2::ExecRequest(); } CacheEntry ele = m_tokens[token]; for (auto it = m_ll.begin(); it != m_ll.end(); it++) { - if (it->token == ele.token) { + if (it->token == token) { m_ll.erase(it); break; } } m_tokens.erase(token); - std::chrono::system_clock::time_point now = std::chrono::system_clock::now(); - if (now > ele.expireTime) { - return nullptr; + if (std::chrono::system_clock::now() > ele.expireTime) { + return runtime::v1alpha2::ExecRequest(); } - found = true; - return ele.req; + + return ele.execRequest.at(0); } + +runtime::v1alpha2::AttachRequest RequestCache::ConsumeAttachRequest(const std::string &token) +{ + std::lock_guard lock(m_mutex); + + if (m_tokens.count(token) == 0 || m_tokens[token].attachRequest.size() == 0) { + ERROR("Invalid token"); + return runtime::v1alpha2::AttachRequest(); + } + + CacheEntry ele = m_tokens[token]; + for (auto it = m_ll.begin(); it != m_ll.end(); it++) { + if (it->token == token) { + m_ll.erase(it); + break; + } + } + m_tokens.erase(token); + if (std::chrono::system_clock::now() > ele.expireTime) { + return runtime::v1alpha2::AttachRequest(); + } + + return ele.attachRequest.at(0); +} \ No newline at end of file diff --git a/src/daemon/entry/cri/request_cache.h b/src/daemon/entry/cri/request_cache.h index 024f3ba7..0f86a85e 100644 --- a/src/daemon/entry/cri/request_cache.h +++ b/src/daemon/entry/cri/request_cache.h @@ -21,19 +21,38 @@ #include #include #include +#include #include +#include "api.pb.h" -typedef struct sCacheEntry { +struct CacheEntry { std::string token; - ::google::protobuf::Message *req; + std::vector 
execRequest; + std::vector attachRequest; std::chrono::system_clock::time_point expireTime; -} CacheEntry, *pCacheEntry; + + void SetValue(const std::string &t, + const runtime::v1alpha2::ExecRequest *execReq, + const runtime::v1alpha2::AttachRequest *attachReq, + std::chrono::system_clock::time_point et) + { + token = t; + if (execReq != nullptr) { + execRequest.push_back(*execReq); + } else if (attachReq != nullptr) { + attachRequest.push_back(*attachReq); + } + expireTime = et; + } +}; class RequestCache { public: static RequestCache *GetInstance() noexcept; - std::string Insert(::google::protobuf::Message *req); - ::google::protobuf::Message *Consume(const std::string &token, bool &found); + std::string InsertExecRequest(const runtime::v1alpha2::ExecRequest &req); + std::string InsertAttachRequest(const runtime::v1alpha2::AttachRequest &req); + runtime::v1alpha2::ExecRequest ConsumeExecRequest(const std::string &token); + runtime::v1alpha2::AttachRequest ConsumeAttachRequest(const std::string &token); bool IsValidToken(const std::string &token); private: diff --git a/src/daemon/entry/cri/websocket/service/attach_serve.cc b/src/daemon/entry/cri/websocket/service/attach_serve.cc index caf02c74..01c6b9cf 100644 --- a/src/daemon/entry/cri/websocket/service/attach_serve.cc +++ b/src/daemon/entry/cri/websocket/service/attach_serve.cc @@ -18,54 +18,50 @@ int AttachServe::Execute(struct lws *wsi, const std::string &token, int read_pipe_fd) { - RequestCache *cache = RequestCache::GetInstance(); - bool found = false; - auto cachedRequest = cache->Consume(token, found); - if (!found) { - ERROR("invalid token :%s", token.c_str()); - return -1; - } - runtime::v1alpha2::AttachRequest *request = dynamic_cast(cachedRequest); - if (request == nullptr) { - ERROR("failed to get exec request!"); - return -1; - } - - container_attach_request *container_req = nullptr; - container_attach_response *container_res = nullptr; - service_executor_t *cb = get_service_executor(); if (cb == 
nullptr || cb->container.attach == nullptr) { return -1; } - int tret = 0; - tret = RequestFromCri(request, &container_req); - if (tret != 0) { - ERROR("Failed to transform grpc request!"); + + container_attach_request *container_req = nullptr; + if (GetContainerRequest(token, &container_req) != 0) { + ERROR("Failed to get contaner request"); return -1; } + struct io_write_wrapper stringWriter = { 0 }; stringWriter.context = (void *)wsi; stringWriter.write_func = WsWriteStdoutToClient; stringWriter.close_func = closeWsConnect; container_req->attach_stderr = false; + + container_attach_response *container_res = nullptr; int ret = cb->container.attach(container_req, &container_res, container_req->attach_stdin ? read_pipe_fd : -1, &stringWriter, nullptr); + if (ret != 0) { + ERROR("Failed to attach container: %s", container_req->container_id); + } + free_container_attach_request(container_req); free_container_attach_response(container_res); - if (request != nullptr) { - delete request; - request = nullptr; - } - if (tret != 0) { - ERROR("Failed to translate response to grpc, operation is %s", ret ? 
"failed" : "success"); + return ret; +} + +int AttachServe::GetContainerRequest(const std::string &token, container_attach_request **container_req) +{ + RequestCache *cache = RequestCache::GetInstance(); + auto request = cache->ConsumeAttachRequest(token); + + int ret = RequestFromCri(request, container_req); + if (ret != 0) { + ERROR("Failed to transform grpc request!"); } return ret; } -int AttachServe::RequestFromCri(const runtime::v1alpha2::AttachRequest *grequest, container_attach_request **request) +int AttachServe::RequestFromCri(const runtime::v1alpha2::AttachRequest &grequest, container_attach_request **request) { container_attach_request *tmpreq = nullptr; @@ -75,12 +71,12 @@ int AttachServe::RequestFromCri(const runtime::v1alpha2::AttachRequest *grequest return -1; } - if (!grequest->container_id().empty()) { - tmpreq->container_id = util_strdup_s(grequest->container_id().c_str()); + if (!grequest.container_id().empty()) { + tmpreq->container_id = util_strdup_s(grequest.container_id().c_str()); } - tmpreq->attach_stdin = grequest->stdin(); - tmpreq->attach_stdout = grequest->stdout(); - tmpreq->attach_stderr = grequest->stderr(); + tmpreq->attach_stdin = grequest.stdin(); + tmpreq->attach_stdout = grequest.stdout(); + tmpreq->attach_stderr = grequest.stderr(); *request = tmpreq; diff --git a/src/daemon/entry/cri/websocket/service/attach_serve.h b/src/daemon/entry/cri/websocket/service/attach_serve.h index 7d57b9a3..00e2b34e 100644 --- a/src/daemon/entry/cri/websocket/service/attach_serve.h +++ b/src/daemon/entry/cri/websocket/service/attach_serve.h @@ -35,8 +35,9 @@ public: virtual ~AttachServe() = default; int Execute(struct lws *wsi, const std::string &token, int read_pipe_fd) override; private: - int RequestFromCri(const runtime::v1alpha2::AttachRequest *grequest, + int RequestFromCri(const runtime::v1alpha2::AttachRequest &grequest, container_attach_request **request); + int GetContainerRequest(const std::string &token, container_attach_request 
**container_req); }; #endif // DAEMON_ENTRY_CRI_WEBSOCKET_SERVICE_ATTACH_SERVE_H diff --git a/src/daemon/entry/cri/websocket/service/exec_serve.cc b/src/daemon/entry/cri/websocket/service/exec_serve.cc index b1a3759d..855d28b8 100644 --- a/src/daemon/entry/cri/websocket/service/exec_serve.cc +++ b/src/daemon/entry/cri/websocket/service/exec_serve.cc @@ -19,37 +19,25 @@ int ExecServe::Execute(struct lws *wsi, const std::string &token, int read_pipe_fd) { - RequestCache *cache = RequestCache::GetInstance(); - bool found = false; - auto cachedRequest = cache->Consume(token, found); - if (!found) { - ERROR("invalid token :%s", token.c_str()); - return -1; - } - runtime::v1alpha2::ExecRequest *request = dynamic_cast(cachedRequest); - if (request == nullptr) { - ERROR("failed to get exec request!"); - return -1; - } - - container_exec_request *container_req = nullptr; - container_exec_response *container_res = nullptr; - service_executor_t *cb = get_service_executor(); if (cb == nullptr || cb->container.exec == nullptr) { return -1; } - int tret = RequestFromCri(request, &container_req); - if (tret != 0) { - ERROR("Failed to transform grpc request!"); + + container_exec_request *container_req = nullptr; + if (GetContainerRequest(token, &container_req) != 0) { + ERROR("Failed to get contaner request"); return -1; } + struct io_write_wrapper StdoutstringWriter = { 0 }; StdoutstringWriter.context = (void *)wsi; StdoutstringWriter.write_func = WsWriteStdoutToClient; struct io_write_wrapper StderrstringWriter = { 0 }; StderrstringWriter.context = (void *)wsi; StderrstringWriter.write_func = WsWriteStderrToClient; + + container_exec_response *container_res = nullptr; int ret = cb->container.exec(container_req, &container_res, container_req->attach_stdin ? read_pipe_fd : -1, container_req->attach_stdout ? &StdoutstringWriter : nullptr, container_req->attach_stderr ? 
&StderrstringWriter : nullptr); @@ -66,19 +54,29 @@ int ExecServe::Execute(struct lws *wsi, const std::string &token, int read_pipe_ std::string exit_info = "Exit code :" + std::to_string((int)container_res->exit_code) + "\n"; WsWriteStdoutToClient(wsi, exit_info.c_str(), exit_info.length()); } + free_container_exec_request(container_req); free_container_exec_response(container_res); - if (request != nullptr) { - delete request; - request = nullptr; - } (void)closeWsConnect((void*)wsi, nullptr); return ret; } -int ExecServe::RequestFromCri(const runtime::v1alpha2::ExecRequest *grequest, container_exec_request **request) +int ExecServe::GetContainerRequest(const std::string &token, container_exec_request **container_req) +{ + RequestCache *cache = RequestCache::GetInstance(); + auto request = cache->ConsumeExecRequest(token); + + int ret = RequestFromCri(request, container_req); + if (ret != 0) { + ERROR("Failed to transform grpc request!"); + } + + return ret; +} + +int ExecServe::RequestFromCri(const runtime::v1alpha2::ExecRequest &grequest, container_exec_request **request) { container_exec_request *tmpreq = nullptr; @@ -88,32 +86,33 @@ int ExecServe::RequestFromCri(const runtime::v1alpha2::ExecRequest *grequest, co return -1; } - tmpreq->tty = grequest->tty(); - tmpreq->attach_stdin = grequest->stdin(); - tmpreq->attach_stdout = grequest->stdout(); - tmpreq->attach_stderr = grequest->stderr(); + tmpreq->tty = grequest.tty(); + tmpreq->attach_stdin = grequest.stdin(); + tmpreq->attach_stdout = grequest.stdout(); + tmpreq->attach_stderr = grequest.stderr(); - if (!grequest->container_id().empty()) { - tmpreq->container_id = util_strdup_s(grequest->container_id().c_str()); + if (!grequest.container_id().empty()) { + tmpreq->container_id = util_strdup_s(grequest.container_id().c_str()); } - if (grequest->cmd_size() > 0) { - if ((size_t)grequest->cmd_size() > SIZE_MAX / sizeof(char *)) { + if (grequest.cmd_size() > 0) { + if ((size_t)grequest.cmd_size() > SIZE_MAX / 
sizeof(char *)) { ERROR("Too many arguments!"); free_container_exec_request(tmpreq); return -1; } - tmpreq->argv = (char **)util_common_calloc_s(sizeof(char *) * grequest->cmd_size()); + tmpreq->argv = (char **)util_common_calloc_s(sizeof(char *) * grequest.cmd_size()); if (tmpreq->argv == nullptr) { ERROR("Out of memory!"); free_container_exec_request(tmpreq); return -1; } - for (int i = 0; i < grequest->cmd_size(); i++) { - tmpreq->argv[i] = util_strdup_s(grequest->cmd(i).c_str()); + for (int i = 0; i < grequest.cmd_size(); i++) { + tmpreq->argv[i] = util_strdup_s(grequest.cmd(i).c_str()); } - tmpreq->argv_len = (size_t)grequest->cmd_size(); + tmpreq->argv_len = (size_t)grequest.cmd_size(); } + *request = tmpreq; return 0; } diff --git a/src/daemon/entry/cri/websocket/service/exec_serve.h b/src/daemon/entry/cri/websocket/service/exec_serve.h index ef474018..b29c3e1e 100644 --- a/src/daemon/entry/cri/websocket/service/exec_serve.h +++ b/src/daemon/entry/cri/websocket/service/exec_serve.h @@ -40,6 +40,7 @@ public: int Execute(struct lws *wsi, const std::string &token, int read_pipe_fd) override; private: - int RequestFromCri(const runtime::v1alpha2::ExecRequest *grequest, container_exec_request **request); + int RequestFromCri(const runtime::v1alpha2::ExecRequest &grequest, container_exec_request **request); + int GetContainerRequest(const std::string &token, container_exec_request **request); }; #endif // DAEMON_ENTRY_CRI_WEBSOCKET_SERVICE_EXEC_SERVE_H diff --git a/src/daemon/entry/cri/websocket/service/ws_server.cc b/src/daemon/entry/cri/websocket/service/ws_server.cc index c7e1b538..795d2c1e 100644 --- a/src/daemon/entry/cri/websocket/service/ws_server.cc +++ b/src/daemon/entry/cri/websocket/service/ws_server.cc @@ -28,22 +28,19 @@ struct lws_context *WebsocketServer::m_context = nullptr; std::atomic WebsocketServer::m_instance; -std::mutex WebsocketServer::m_mutex; -std::unordered_map WebsocketServer::m_wsis; +RWMutex WebsocketServer::m_mutex; 
+std::unordered_map WebsocketServer::m_wsis; +std::unordered_set WebsocketServer::m_activeSession; + WebsocketServer *WebsocketServer::GetInstance() noexcept { - WebsocketServer *server = m_instance.load(std::memory_order_relaxed); - std::atomic_thread_fence(std::memory_order_acquire); - if (server == nullptr) { - std::lock_guard lock(m_mutex); - server = m_instance.load(std::memory_order_relaxed); - if (server == nullptr) { - server = new WebsocketServer; - std::atomic_thread_fence(std::memory_order_release); - m_instance.store(server, std::memory_order_relaxed); - } - } - return server; + static std::once_flag flag; + + std::call_once(flag, [] { + m_instance = new WebsocketServer; + }); + + return m_instance; } WebsocketServer::WebsocketServer() @@ -62,14 +59,14 @@ url::URLDatum WebsocketServer::GetWebsocketUrl() return m_url; } -std::unordered_map &WebsocketServer::GetWsisData() +std::unordered_map &WebsocketServer::GetWsisData() { return m_wsis; } -void WebsocketServer::LockAllWsSession() +void WebsocketServer::ReadLockAllWsSession() { - m_mutex.lock(); + m_mutex.rdlock(); } void WebsocketServer::UnlockAllWsSession() @@ -160,7 +157,7 @@ void WebsocketServer::RegisterCallback(const std::string &path, void WebsocketServer::CloseAllWsSession() { - std::lock_guard lock(m_mutex); + WriteGuard lock(m_mutex); for (auto it = m_wsis.begin(); it != m_wsis.end(); ++it) { free(it->second.buf); close(it->second.pipes.at(0)); @@ -172,15 +169,10 @@ void WebsocketServer::CloseAllWsSession() m_wsis.clear(); } -void WebsocketServer::CloseWsSession(struct lws *wsi) +void WebsocketServer::CloseWsSession(int socketID) { - const int WAIT_PERIOD_MS = 50; - - auto it = m_wsis.find(wsi); + auto it = m_wsis.find(socketID); if (it != m_wsis.end()) { - while (it->second.GetProcessingStatus()) { - std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_PERIOD_MS)); - } free(it->second.buf); close(it->second.pipes.at(0)); close(it->second.pipes.at(1)); @@ -191,6 +183,21 @@ void 
WebsocketServer::CloseWsSession(struct lws *wsi) } } +void WebsocketServer::RecordSession(struct lws *wsi) +{ + m_activeSession.insert(wsi); +} + +void WebsocketServer::RemoveSession(struct lws *wsi) +{ + m_activeSession.erase(wsi); +} + +bool WebsocketServer::IsValidSession(struct lws *wsi) +{ + return m_activeSession.count(wsi) != 0; +} + int WebsocketServer::DumpHandshakeInfo(struct lws *wsi) noexcept { int read_pipe_fd[PIPE_FD_NUM]; @@ -200,15 +207,17 @@ int WebsocketServer::DumpHandshakeInfo(struct lws *wsi) noexcept session_data session; session.pipes = std::array { read_pipe_fd[0], read_pipe_fd[1] }; - m_wsis.insert(std::make_pair(wsi, session)); - m_wsis[wsi].buf = (unsigned char *)util_common_calloc_s(LWS_PRE + MAX_MSG_BUFFER_SIZE + 1); - if (m_wsis[wsi].buf == nullptr) { + + int socketID = lws_get_socket_fd(wsi); + m_wsis.insert(std::make_pair(socketID, std::move(session))); + m_wsis[socketID].buf = (unsigned char *)util_common_calloc_s(LWS_PRE + MAX_MSG_BUFFER_SIZE + 1); + if (m_wsis[socketID].buf == nullptr) { ERROR("Out of memory"); return -1; } - m_wsis[wsi].buf_mutex = new std::mutex; - m_wsis[wsi].sended_mutex = new std::mutex; - m_wsis[wsi].SetProcessingStatus(false); + m_wsis[socketID].buf_mutex = new std::mutex; + m_wsis[socketID].sended_mutex = new std::mutex; + m_wsis[socketID].SetProcessingStatus(false); int len; char buf[MAX_BUF_LEN] { 0 }; @@ -216,7 +225,7 @@ int WebsocketServer::DumpHandshakeInfo(struct lws *wsi) noexcept lws_hdr_copy(wsi, buf, sizeof(buf), WSI_TOKEN_GET_URI); if (strlen(buf) == 0) { ERROR("invalid url"); - CloseWsSession(wsi); + CloseWsSession(socketID); return -1; } @@ -228,14 +237,15 @@ int WebsocketServer::DumpHandshakeInfo(struct lws *wsi) noexcept !m_handler.IsValidMethod(vec.at(1)) || !cache->IsValidToken(vec.at(2))) { ERROR("invalid url(%s): incorrect format!", buf); - CloseWsSession(wsi); + CloseWsSession(socketID); return -1; } std::thread streamTh([ = ]() { - StreamTask(&m_handler, wsi, vec.at(1), vec.at(2), 
m_wsis[wsi].pipes.at(0)).Run(); + StreamTask(&m_handler, wsi, vec.at(1), vec.at(2), m_wsis[socketID].pipes.at(0)).Run(); }); streamTh.detach(); + RecordSession(wsi); int n = 0; const unsigned char *c = nullptr; do { @@ -260,7 +270,7 @@ int WebsocketServer::DumpHandshakeInfo(struct lws *wsi) noexcept int WebsocketServer::Wswrite(struct lws *wsi, void *in, size_t len) { - auto it = m_wsis.find(wsi); + auto it = m_wsis.find(lws_get_socket_fd(wsi)); if (it != m_wsis.end()) { if (it->second.close) { DEBUG("websocket session disconnected"); @@ -286,9 +296,9 @@ int WebsocketServer::Wswrite(struct lws *wsi, void *in, size_t len) return 0; } -void WebsocketServer::Receive(struct lws *wsi, void *in, size_t len) +void WebsocketServer::Receive(int socketID, void *in, size_t len) { - if (m_wsis.find(wsi) == m_wsis.end()) { + if (m_wsis.find(socketID) == m_wsis.end()) { ERROR("invailed websocket session!"); return; } @@ -298,20 +308,20 @@ void WebsocketServer::Receive(struct lws *wsi, void *in, size_t len) return; } - if (write(m_wsis[wsi].pipes.at(1), (void *)((char *)in + 1), len - 1) < 0) { + if (write(m_wsis[socketID].pipes.at(1), (void *)((char *)in + 1), len - 1) < 0) { ERROR("sub write over!"); return; } } -void WebsocketServer::SetLwsSendedFlag(struct lws *wsi, bool sended) +void WebsocketServer::SetLwsSendedFlag(int socketID, bool sended) { - auto it = m_wsis.find(wsi); - if (it != m_wsis.end()) { - it->second.sended_mutex->lock(); - it->second.sended = sended; - it->second.sended_mutex->unlock(); + if (m_wsis.count(socketID) == 0) { + return; } + m_wsis[socketID].sended_mutex->lock(); + m_wsis[socketID].sended = sended; + m_wsis[socketID].sended_mutex->unlock(); } int WebsocketServer::Callback(struct lws *wsi, enum lws_callback_reasons reason, @@ -323,7 +333,7 @@ int WebsocketServer::Callback(struct lws *wsi, enum lws_callback_reasons reason, // asking to upgrade the connection to a websocket one. 
return -1; case LWS_CALLBACK_FILTER_PROTOCOL_CONNECTION: { - std::lock_guard lock(m_mutex); + WriteGuard lock(m_mutex); if (WebsocketServer::GetInstance()->DumpHandshakeInfo(wsi)) { // return non-zero here and kill the connection return -1; @@ -335,22 +345,27 @@ int WebsocketServer::Callback(struct lws *wsi, enum lws_callback_reasons reason, } break; case LWS_CALLBACK_SERVER_WRITEABLE: { - std::lock_guard lock(m_mutex); + ReadGuard lock(m_mutex); + int socketID = lws_get_socket_fd(wsi); if (WebsocketServer::GetInstance()->Wswrite(wsi, in, len)) { - WebsocketServer::GetInstance()->SetLwsSendedFlag(wsi, true); + WebsocketServer::GetInstance()->SetLwsSendedFlag(socketID, true); + // return nonzero from the user callback to close the connection + // and callback with the reason of LWS_CALLBACK_CLOSED return -1; } - WebsocketServer::GetInstance()->SetLwsSendedFlag(wsi, true); + WebsocketServer::GetInstance()->SetLwsSendedFlag(socketID, true); } break; case LWS_CALLBACK_RECEIVE: { - std::lock_guard lock(m_mutex); - WebsocketServer::GetInstance()->Receive(wsi, (char *)in, len); + ReadGuard lock(m_mutex); + WebsocketServer::GetInstance()->Receive(lws_get_socket_fd(wsi), (char *)in, len); } break; case LWS_CALLBACK_CLOSED: { - std::lock_guard lock(m_mutex); - WebsocketServer::GetInstance()->CloseWsSession(wsi); + WriteGuard lock(m_mutex); + DEBUG("connection has been closed"); + WebsocketServer::GetInstance()->RemoveSession(wsi); + WebsocketServer::GetInstance()->CloseWsSession(lws_get_socket_fd(wsi)); } break; default: @@ -363,8 +378,7 @@ void WebsocketServer::ServiceWorkThread(int threadid) { int n = 0; while (n >= 0 && !m_force_exit) { - n = lws_service(m_context, 50); - std::this_thread::sleep_for(std::chrono::milliseconds(1)); + n = lws_service(m_context, 0); } } @@ -396,20 +410,19 @@ void WebsocketServer::Wait() } namespace { -auto PrepareWsiSession(struct lws *wsi) -> session_data * +auto PrepareWsiSession(int socketID) -> session_data * { WebsocketServer *server = 
WebsocketServer::GetInstance(); - server->LockAllWsSession(); + server->ReadLockAllWsSession(); - auto itor = server->GetWsisData().find(wsi); + auto itor = server->GetWsisData().find(socketID); if (itor == server->GetWsisData().end()) { ERROR("invalid session!"); server->UnlockAllWsSession(); return nullptr; } - itor->second.SetProcessingStatus(true); + server->SetLwsSendedFlag(socketID, false); server->UnlockAllWsSession(); - server->SetLwsSendedFlag(wsi, false); return &itor->second; } @@ -450,15 +463,13 @@ void EnsureWrited(struct lws *wsi, session_data *session) } std::this_thread::sleep_for(std::chrono::milliseconds(TRIGGER_PERIOD_MS)); } - - session->SetProcessingStatus(false); } ssize_t WsWriteToClient(void *context, const void *data, size_t len, WebsocketChannel channel) { struct lws *wsi = static_cast(context); - session_data *session = PrepareWsiSession(wsi); + session_data *session = PrepareWsiSession(lws_get_socket_fd(wsi)); if (session == nullptr) { return 0; } @@ -487,15 +498,20 @@ int closeWsConnect(void *context, char **err) struct lws *wsi = static_cast(context); WebsocketServer *server = WebsocketServer::GetInstance(); - auto it = server->GetWsisData().find(wsi); + server->ReadLockAllWsSession(); + auto it = server->GetWsisData().find(lws_get_socket_fd(wsi)); if (it == server->GetWsisData().end()) { + server->UnlockAllWsSession(); ERROR("websocket session not exist"); return -1; } + it->second.close = true; // close websocket session - lws_callback_on_writable(wsi); + if (server->IsValidSession(wsi)) { + lws_callback_on_writable(wsi); + } + server->UnlockAllWsSession(); + return 0; } - - diff --git a/src/daemon/entry/cri/websocket/service/ws_server.h b/src/daemon/entry/cri/websocket/service/ws_server.h index 1370c552..cb431f7f 100644 --- a/src/daemon/entry/cri/websocket/service/ws_server.h +++ b/src/daemon/entry/cri/websocket/service/ws_server.h @@ -17,6 +17,7 @@ #define DAEMON_ENTRY_CRI_WEBSOCKET_SERVICE_WS_SERVER_H #include #include +#include 
#include #include #include @@ -26,6 +27,7 @@ #include "route_callback_register.h" #include "url.h" #include "errors.h" +#include "read_write_lock.h" #define MAX_ECHO_PAYLOAD 4096 #define MAX_ARRAY_LEN 2 @@ -71,10 +73,11 @@ public: void Shutdown(); void RegisterCallback(const std::string &path, std::shared_ptr callback); url::URLDatum GetWebsocketUrl(); - std::unordered_map &GetWsisData(); - void SetLwsSendedFlag(struct lws *wsi, bool sended); - void LockAllWsSession(); + std::unordered_map &GetWsisData(); + void SetLwsSendedFlag(int socketID, bool sended); + void ReadLockAllWsSession(); void UnlockAllWsSession(); + bool IsValidSession(struct lws *wsi); private: WebsocketServer(); @@ -85,17 +88,19 @@ private: std::vector split(std::string str, char r); static void EmitLog(int level, const char *line); int CreateContext(); - inline void Receive(struct lws *client, void *in, size_t len); + inline void Receive(int socketID, void *in, size_t len); int Wswrite(struct lws *wsi, void *in, size_t len); inline int DumpHandshakeInfo(struct lws *wsi) noexcept; static int Callback(struct lws *wsi, enum lws_callback_reasons reason, void *user, void *in, size_t len); void ServiceWorkThread(int threadid); - void CloseWsSession(struct lws *wsi); + void CloseWsSession(int socketID); void CloseAllWsSession(); + void RecordSession(struct lws *wsi); + void RemoveSession(struct lws *wsi); private: - static std::mutex m_mutex; + static RWMutex m_mutex; static struct lws_context *m_context; volatile int m_force_exit = 0; std::thread m_pthread_service; @@ -104,7 +109,8 @@ private: { NULL, NULL, 0, 0 } }; RouteCallbackRegister m_handler; - static std::unordered_map m_wsis; + static std::unordered_map m_wsis; + static std::unordered_set m_activeSession; url::URLDatum m_url; int m_listenPort; }; diff --git a/src/utils/cpputils/read_write_lock.cc b/src/utils/cpputils/read_write_lock.cc new file mode 100644 index 00000000..c9f94dc8 --- /dev/null +++ b/src/utils/cpputils/read_write_lock.cc @@ 
-0,0 +1,59 @@ +/****************************************************************************** + * Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved. + * iSulad licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR + * PURPOSE. + * See the Mulan PSL v2 for more details. + * Author: wujing + * Create: 2021-01-18 + * Description: provide read write lock implementation + *********************************************************************************/ + +#include "read_write_lock.h" + +void RWMutex::rdlock() +{ + std::unique_lock autoLock(m_mutex); + ++m_waiting_readers; + m_read_cond.wait(autoLock, [&]() { + return m_waiting_writers == 0 && m_status >= 0; + }); + --m_waiting_readers; + ++m_status; +} + +void RWMutex::wrlock() +{ + std::unique_lock autoLock(m_mutex); + ++m_waiting_writers; + m_write_cond.wait(autoLock, [&]() { + return m_status == 0; + }); + --m_waiting_writers; + --m_status; +} + +void RWMutex::unlock() +{ + std::unique_lock autoLock(m_mutex); + + if (m_status == -1) { // one writer + m_status = 0; + } else if (m_status > 0) { // one or multiple readers + --m_status; + } else { // neither readers nor writers + return; + } + + if (m_waiting_writers > 0) { + if (m_status == 0) { + m_write_cond.notify_one(); + } + } else { + m_read_cond.notify_all(); + } +} diff --git a/src/utils/cpputils/read_write_lock.h b/src/utils/cpputils/read_write_lock.h new file mode 100644 index 00000000..0149e3a5 --- /dev/null +++ b/src/utils/cpputils/read_write_lock.h @@ -0,0 +1,90 @@ +/****************************************************************************** + * Copyright (c) Huawei Technologies 
Co., Ltd. 2021. All rights reserved. + * iSulad licensed under the Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR + * PURPOSE. + * See the Mulan PSL v2 for more details. + * Author: wujing + * Create: 2021-01-18 + * Description: provide read write lock definition + *********************************************************************************/ +#ifndef UTILS_CPPUTILS_READ_WRITE_LOCK_H +#define UTILS_CPPUTILS_READ_WRITE_LOCK_H + +#include +#include +#include +#include + +class RWMutex { +public: + RWMutex() = default; + ~RWMutex() = default; + RWMutex(const RWMutex &) = delete; + RWMutex(RWMutex &&) = delete; + RWMutex &operator = (const RWMutex &) = delete; + RWMutex &operator = (RWMutex &&) = delete; + + void rdlock(); + void wrlock(); + void unlock(); + +private: + volatile long m_status {0}; + volatile long m_waiting_readers {0}; + volatile long m_waiting_writers {0}; + std::mutex m_mutex; + std::condition_variable m_read_cond; + std::condition_variable m_write_cond; +}; + +template +class ReadGuard { +public: + explicit ReadGuard(RWMutexType &lock) : m_lock(lock) + { + m_lock.rdlock(); + } + virtual ~ReadGuard() + { + m_lock.unlock(); + } + + ReadGuard() = delete; + ReadGuard(const ReadGuard &) = delete; + ReadGuard &operator=(const ReadGuard &) = delete; + ReadGuard(const ReadGuard &&) = delete; + ReadGuard &operator = (const ReadGuard &&) = delete; + +private: + RWMutexType &m_lock; +}; + + +template +class WriteGuard { +public: + explicit WriteGuard(RWMutexType &lock) : m_lock(lock) + { + m_lock.wrlock(); + } + virtual ~WriteGuard() + { + m_lock.unlock(); + } + + WriteGuard() = delete; + WriteGuard(const WriteGuard 
&) = delete; + WriteGuard &operator=(const WriteGuard &) = delete; + WriteGuard(const WriteGuard &&) = delete; + WriteGuard &operator = (const WriteGuard &&) = delete; + +private: + RWMutexType &m_lock; +}; + +#endif // UTILS_CPPUTILS_READ_WRITE_LOCK_H diff --git a/src/utils/cpputils/stoppable_thread.cc b/src/utils/cpputils/stoppable_thread.cc index 0d15aa01..68f6d9b2 100644 --- a/src/utils/cpputils/stoppable_thread.cc +++ b/src/utils/cpputils/stoppable_thread.cc @@ -22,7 +22,6 @@ StoppableThread &StoppableThread::operator=(StoppableThread &&obj) return *this; } - bool StoppableThread::stopRequested() { if (m_future_obj.wait_for(std::chrono::milliseconds(0)) == std::future_status::timeout) { @@ -35,6 +34,3 @@ void StoppableThread::stop() { m_exit_signal.set_value(); } - - - -- 2.25.1