iSulad/0009-fix-small-probability-of-coredump-in-CRI-streaming-s.patch
Li Feng 2c873f3fa9 iSulad: sync with upstream
Signed-off-by: Li Feng <lifeng2221dd1@zoho.com.cn>
(cherry picked from commit 230d9a529755704c38f8051b8fe35f4543e5b806)
2021-01-18 20:19:07 +08:00

1079 lines
39 KiB
Diff

From 0295f347d6394294cb2c81741ece78548d4cafc6 Mon Sep 17 00:00:00 2001
From: wujing <jing.woo@outlook.com>
Date: Thu, 14 Jan 2021 10:53:07 +0800
Subject: [PATCH 9/9] fix small probability of coredump in CRI streaming
services in high concurrency scenarios
Signed-off-by: wujing <wujing50@huawei.com>
---
.../cri/cri_container_manager_service_impl.cc | 16 +-
src/daemon/entry/cri/request_cache.cc | 74 ++++++---
src/daemon/entry/cri/request_cache.h | 29 +++-
.../cri/websocket/service/attach_serve.cc | 60 ++++---
.../cri/websocket/service/attach_serve.h | 3 +-
.../entry/cri/websocket/service/exec_serve.cc | 71 +++++----
.../entry/cri/websocket/service/exec_serve.h | 3 +-
.../entry/cri/websocket/service/ws_server.cc | 148 ++++++++++--------
.../entry/cri/websocket/service/ws_server.h | 20 ++-
src/utils/cpputils/read_write_lock.cc | 59 +++++++
src/utils/cpputils/read_write_lock.h | 90 +++++++++++
src/utils/cpputils/stoppable_thread.cc | 4 -
12 files changed, 392 insertions(+), 185 deletions(-)
create mode 100644 src/utils/cpputils/read_write_lock.cc
create mode 100644 src/utils/cpputils/read_write_lock.h
diff --git a/src/daemon/entry/cri/cri_container_manager_service_impl.cc b/src/daemon/entry/cri/cri_container_manager_service_impl.cc
index 45ecf9f2..812469ee 100644
--- a/src/daemon/entry/cri/cri_container_manager_service_impl.cc
+++ b/src/daemon/entry/cri/cri_container_manager_service_impl.cc
@@ -1251,15 +1251,9 @@ void ContainerManagerServiceImpl::Exec(const runtime::v1alpha2::ExecRequest &req
return;
}
RequestCache *cache = RequestCache::GetInstance();
- runtime::v1alpha2::ExecRequest *execReq = new (std::nothrow) runtime::v1alpha2::ExecRequest(req);
- if (execReq == nullptr) {
- error.SetError("Out of memory");
- return;
- }
- std::string token = cache->Insert(const_cast<runtime::v1alpha2::ExecRequest *>(execReq));
+ std::string token = cache->InsertExecRequest(req);
if (token.empty()) {
error.SetError("failed to get a unique token!");
- delete execReq;
return;
}
std::string url = BuildURL("exec", token);
@@ -1303,15 +1297,9 @@ void ContainerManagerServiceImpl::Attach(const runtime::v1alpha2::AttachRequest
return;
}
RequestCache *cache = RequestCache::GetInstance();
- runtime::v1alpha2::AttachRequest *attachReq = new (std::nothrow) runtime::v1alpha2::AttachRequest(req);
- if (attachReq == nullptr) {
- error.SetError("Out of memory");
- return;
- }
- std::string token = cache->Insert(const_cast<runtime::v1alpha2::AttachRequest *>(attachReq));
+ std::string token = cache->InsertAttachRequest(req);
if (token.empty()) {
error.SetError("failed to get a unique token!");
- delete attachReq;
return;
}
std::string url = BuildURL("attach", token);
diff --git a/src/daemon/entry/cri/request_cache.cc b/src/daemon/entry/cri/request_cache.cc
index a3cb3771..b502715a 100644
--- a/src/daemon/entry/cri/request_cache.cc
+++ b/src/daemon/entry/cri/request_cache.cc
@@ -41,12 +41,26 @@ RequestCache *RequestCache::GetInstance() noexcept
return cache;
}
-std::string RequestCache::Insert(::google::protobuf::Message *req)
+std::string RequestCache::InsertExecRequest(const runtime::v1alpha2::ExecRequest &req) // Caches a copy of req under a new token and returns that token; returns "" when the cache is full.
{
- if (req == nullptr) {
- ERROR("invalid request");
+ std::lock_guard<std::mutex> lock(m_mutex); // Held for the whole insert; m_ll and m_tokens are only touched below this point.
+ // Remove expired entries.
+ GarbageCollection();
+ // If the cache is full, reject the request.
+ if (m_ll.size() == MaxInFlight) {
+ ERROR("too many cache in flight!");
return "";
}
+ auto token = UniqueToken(); // NOTE(review): UniqueToken() returns "" on failure and that is not checked before inserting - confirm.
+ CacheEntry tmp;
+ tmp.SetValue(token, &req, nullptr, std::chrono::system_clock::now() + std::chrono::minutes(1)); // Exec slot filled with a copy of req, attach slot left empty; expires one minute from now.
+ m_ll.push_front(tmp); // Newest entry at the front; GarbageCollection() expires entries from the back.
+ m_tokens.insert(std::make_pair(token, tmp));
+ return token;
+}
+
+std::string RequestCache::InsertAttachRequest(const runtime::v1alpha2::AttachRequest &req)
+{
std::lock_guard<std::mutex> lock(m_mutex);
// Remove expired entries.
GarbageCollection();
@@ -56,7 +70,8 @@ std::string RequestCache::Insert(::google::protobuf::Message *req)
return "";
}
auto token = UniqueToken();
- CacheEntry tmp { token, req, std::chrono::system_clock::now() + std::chrono::minutes(1) };
+ CacheEntry tmp;
+ tmp.SetValue(token, nullptr, &req, std::chrono::system_clock::now() + std::chrono::minutes(1));
m_ll.push_front(tmp);
m_tokens.insert(std::make_pair(token, tmp));
return token;
@@ -64,16 +79,12 @@ std::string RequestCache::Insert(::google::protobuf::Message *req)
void RequestCache::GarbageCollection()
{
- std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
+ auto now = std::chrono::system_clock::now();
while (!m_ll.empty()) {
CacheEntry oldest = m_ll.back();
if (now < oldest.expireTime) {
return;
}
- if (oldest.req != nullptr) {
- delete oldest.req;
- oldest.req = nullptr;
- }
m_ll.pop_back();
m_tokens.erase(oldest.token);
}
@@ -124,34 +135,59 @@ std::string RequestCache::UniqueToken()
ERROR("create unique token failed!");
return "";
}
+
bool RequestCache::IsValidToken(const std::string &token)
{
+ std::lock_guard<std::mutex> lock(m_mutex);
+
return static_cast<bool>(m_tokens.count(token));
}
// Consume the token (remove it from the cache) and return the cached request, if found.
-::google::protobuf::Message *RequestCache::Consume(const std::string &token, bool &found)
+runtime::v1alpha2::ExecRequest RequestCache::ConsumeExecRequest(const std::string &token) // Removes the entry for token and returns its cached ExecRequest by value; a default-constructed request signals failure.
{
std::lock_guard<std::mutex> lock(m_mutex);
- found = false;
- if (!IsValidToken(token)) {
+ if (m_tokens.count(token) == 0 || m_tokens[token].execRequest.size() == 0) { // Direct lookup: IsValidToken() locks m_mutex itself, and m_mutex is already held here.
ERROR("Invalid token");
- return nullptr;
+ return runtime::v1alpha2::ExecRequest(); // Token unknown, or its entry holds no exec request; in the latter case the entry stays cached.
}
CacheEntry ele = m_tokens[token];
for (auto it = m_ll.begin(); it != m_ll.end(); it++) {
- if (it->token == ele.token) {
+ if (it->token == token) {
m_ll.erase(it);
break;
}
}
m_tokens.erase(token);
- std::chrono::system_clock::time_point now = std::chrono::system_clock::now();
- if (now > ele.expireTime) {
- return nullptr;
+ if (std::chrono::system_clock::now() > ele.expireTime) { // The entry has already been removed above, so an expired token is still consumed.
+ return runtime::v1alpha2::ExecRequest();
}
- found = true;
- return ele.req;
+
+ return ele.execRequest.at(0); // Copy of the request stored by CacheEntry::SetValue().
}
+
+runtime::v1alpha2::AttachRequest RequestCache::ConsumeAttachRequest(const std::string &token) // Removes the entry for token and returns its cached AttachRequest by value; a default-constructed request signals failure.
+{
+ std::lock_guard<std::mutex> lock(m_mutex);
+ // Reject a token that is unknown or whose entry holds no attach request; in the latter case the entry stays cached.
+ if (m_tokens.count(token) == 0 || m_tokens[token].attachRequest.size() == 0) {
+ ERROR("Invalid token");
+ return runtime::v1alpha2::AttachRequest();
+ }
+ // Take a copy of the entry first: it is erased from both containers below.
+ CacheEntry ele = m_tokens[token];
+ for (auto it = m_ll.begin(); it != m_ll.end(); it++) {
+ if (it->token == token) {
+ m_ll.erase(it);
+ break;
+ }
+ }
+ m_tokens.erase(token);
+ if (std::chrono::system_clock::now() > ele.expireTime) { // The entry has already been removed above, so an expired token is still consumed.
+ return runtime::v1alpha2::AttachRequest();
+ }
+
+ return ele.attachRequest.at(0); // Copy of the request stored by CacheEntry::SetValue().
+}
\ No newline at end of file
diff --git a/src/daemon/entry/cri/request_cache.h b/src/daemon/entry/cri/request_cache.h
index 024f3ba7..0f86a85e 100644
--- a/src/daemon/entry/cri/request_cache.h
+++ b/src/daemon/entry/cri/request_cache.h
@@ -21,19 +21,38 @@
#include <mutex>
#include <unordered_map>
#include <chrono>
+#include <typeinfo>
#include <google/protobuf/message.h>
+#include "api.pb.h"
-typedef struct sCacheEntry {
+struct CacheEntry { // One cached streaming request, stored by value together with its token and expiry time.
std::string token;
- ::google::protobuf::Message *req;
+ std::vector<runtime::v1alpha2::ExecRequest> execRequest; // Filled by SetValue() when an exec request is given; consumers read element 0.
+ std::vector<runtime::v1alpha2::AttachRequest> attachRequest; // Filled by SetValue() when an attach request is given; consumers read element 0.
std::chrono::system_clock::time_point expireTime;
-} CacheEntry, *pCacheEntry;
+
+ void SetValue(const std::string &t, // Token identifying this entry.
+ const runtime::v1alpha2::ExecRequest *execReq, // Copied into execRequest when non-null.
+ const runtime::v1alpha2::AttachRequest *attachReq, // Copied into attachRequest only when execReq is null.
+ std::chrono::system_clock::time_point et) // Point in time after which the entry counts as expired.
+ {
+ token = t;
+ if (execReq != nullptr) {
+ execRequest.push_back(*execReq);
+ } else if (attachReq != nullptr) { // If both pointers are null, neither vector is touched.
+ attachRequest.push_back(*attachReq);
+ }
+ expireTime = et;
+ }
+};
class RequestCache {
public:
static RequestCache *GetInstance() noexcept;
- std::string Insert(::google::protobuf::Message *req);
- ::google::protobuf::Message *Consume(const std::string &token, bool &found);
+ std::string InsertExecRequest(const runtime::v1alpha2::ExecRequest &req);
+ std::string InsertAttachRequest(const runtime::v1alpha2::AttachRequest &req);
+ runtime::v1alpha2::ExecRequest ConsumeExecRequest(const std::string &token);
+ runtime::v1alpha2::AttachRequest ConsumeAttachRequest(const std::string &token);
bool IsValidToken(const std::string &token);
private:
diff --git a/src/daemon/entry/cri/websocket/service/attach_serve.cc b/src/daemon/entry/cri/websocket/service/attach_serve.cc
index caf02c74..01c6b9cf 100644
--- a/src/daemon/entry/cri/websocket/service/attach_serve.cc
+++ b/src/daemon/entry/cri/websocket/service/attach_serve.cc
@@ -18,54 +18,50 @@
int AttachServe::Execute(struct lws *wsi, const std::string &token, int read_pipe_fd)
{
- RequestCache *cache = RequestCache::GetInstance();
- bool found = false;
- auto cachedRequest = cache->Consume(token, found);
- if (!found) {
- ERROR("invalid token :%s", token.c_str());
- return -1;
- }
- runtime::v1alpha2::AttachRequest *request = dynamic_cast<runtime::v1alpha2::AttachRequest *>(cachedRequest);
- if (request == nullptr) {
- ERROR("failed to get exec request!");
- return -1;
- }
-
- container_attach_request *container_req = nullptr;
- container_attach_response *container_res = nullptr;
-
service_executor_t *cb = get_service_executor();
if (cb == nullptr || cb->container.attach == nullptr) {
return -1;
}
- int tret = 0;
- tret = RequestFromCri(request, &container_req);
- if (tret != 0) {
- ERROR("Failed to transform grpc request!");
+
+ container_attach_request *container_req = nullptr;
+ if (GetContainerRequest(token, &container_req) != 0) {
+ ERROR("Failed to get contaner request");
return -1;
}
+
struct io_write_wrapper stringWriter = { 0 };
stringWriter.context = (void *)wsi;
stringWriter.write_func = WsWriteStdoutToClient;
stringWriter.close_func = closeWsConnect;
container_req->attach_stderr = false;
+
+ container_attach_response *container_res = nullptr;
int ret = cb->container.attach(container_req, &container_res, container_req->attach_stdin ? read_pipe_fd : -1,
&stringWriter, nullptr);
+ if (ret != 0) {
+ ERROR("Failed to attach container: %s", container_req->container_id);
+ }
+
free_container_attach_request(container_req);
free_container_attach_response(container_res);
- if (request != nullptr) {
- delete request;
- request = nullptr;
- }
- if (tret != 0) {
- ERROR("Failed to translate response to grpc, operation is %s", ret ? "failed" : "success");
+ return ret;
+}
+
+int AttachServe::GetContainerRequest(const std::string &token, container_attach_request **container_req) // Consumes the cached attach request for token and converts it into *container_req; returns 0 on success, -1 otherwise.
+{
+ RequestCache *cache = RequestCache::GetInstance();
+ auto request = cache->ConsumeAttachRequest(token);
+ // An unknown or expired token yields a default-constructed request (empty container id): fail instead of attaching to nothing.
+ int ret = request.container_id().empty() ? -1 : RequestFromCri(request, container_req);
+ if (ret != 0) {
+ ERROR("Failed to transform grpc request!");
}
return ret;
}
-int AttachServe::RequestFromCri(const runtime::v1alpha2::AttachRequest *grequest, container_attach_request **request)
+int AttachServe::RequestFromCri(const runtime::v1alpha2::AttachRequest &grequest, container_attach_request **request)
{
container_attach_request *tmpreq = nullptr;
@@ -75,12 +71,12 @@ int AttachServe::RequestFromCri(const runtime::v1alpha2::AttachRequest *grequest
return -1;
}
- if (!grequest->container_id().empty()) {
- tmpreq->container_id = util_strdup_s(grequest->container_id().c_str());
+ if (!grequest.container_id().empty()) {
+ tmpreq->container_id = util_strdup_s(grequest.container_id().c_str());
}
- tmpreq->attach_stdin = grequest->stdin();
- tmpreq->attach_stdout = grequest->stdout();
- tmpreq->attach_stderr = grequest->stderr();
+ tmpreq->attach_stdin = grequest.stdin();
+ tmpreq->attach_stdout = grequest.stdout();
+ tmpreq->attach_stderr = grequest.stderr();
*request = tmpreq;
diff --git a/src/daemon/entry/cri/websocket/service/attach_serve.h b/src/daemon/entry/cri/websocket/service/attach_serve.h
index 7d57b9a3..00e2b34e 100644
--- a/src/daemon/entry/cri/websocket/service/attach_serve.h
+++ b/src/daemon/entry/cri/websocket/service/attach_serve.h
@@ -35,8 +35,9 @@ public:
virtual ~AttachServe() = default;
int Execute(struct lws *wsi, const std::string &token, int read_pipe_fd) override;
private:
- int RequestFromCri(const runtime::v1alpha2::AttachRequest *grequest,
+ int RequestFromCri(const runtime::v1alpha2::AttachRequest &grequest,
container_attach_request **request);
+ int GetContainerRequest(const std::string &token, container_attach_request **container_req);
};
#endif // DAEMON_ENTRY_CRI_WEBSOCKET_SERVICE_ATTACH_SERVE_H
diff --git a/src/daemon/entry/cri/websocket/service/exec_serve.cc b/src/daemon/entry/cri/websocket/service/exec_serve.cc
index b1a3759d..855d28b8 100644
--- a/src/daemon/entry/cri/websocket/service/exec_serve.cc
+++ b/src/daemon/entry/cri/websocket/service/exec_serve.cc
@@ -19,37 +19,25 @@
int ExecServe::Execute(struct lws *wsi, const std::string &token, int read_pipe_fd)
{
- RequestCache *cache = RequestCache::GetInstance();
- bool found = false;
- auto cachedRequest = cache->Consume(token, found);
- if (!found) {
- ERROR("invalid token :%s", token.c_str());
- return -1;
- }
- runtime::v1alpha2::ExecRequest *request = dynamic_cast<runtime::v1alpha2::ExecRequest *>(cachedRequest);
- if (request == nullptr) {
- ERROR("failed to get exec request!");
- return -1;
- }
-
- container_exec_request *container_req = nullptr;
- container_exec_response *container_res = nullptr;
-
service_executor_t *cb = get_service_executor();
if (cb == nullptr || cb->container.exec == nullptr) {
return -1;
}
- int tret = RequestFromCri(request, &container_req);
- if (tret != 0) {
- ERROR("Failed to transform grpc request!");
+
+ container_exec_request *container_req = nullptr;
+ if (GetContainerRequest(token, &container_req) != 0) {
+ ERROR("Failed to get contaner request");
return -1;
}
+
struct io_write_wrapper StdoutstringWriter = { 0 };
StdoutstringWriter.context = (void *)wsi;
StdoutstringWriter.write_func = WsWriteStdoutToClient;
struct io_write_wrapper StderrstringWriter = { 0 };
StderrstringWriter.context = (void *)wsi;
StderrstringWriter.write_func = WsWriteStderrToClient;
+
+ container_exec_response *container_res = nullptr;
int ret = cb->container.exec(container_req, &container_res, container_req->attach_stdin ? read_pipe_fd : -1,
container_req->attach_stdout ? &StdoutstringWriter : nullptr,
container_req->attach_stderr ? &StderrstringWriter : nullptr);
@@ -66,19 +54,29 @@ int ExecServe::Execute(struct lws *wsi, const std::string &token, int read_pipe_
std::string exit_info = "Exit code :" + std::to_string((int)container_res->exit_code) + "\n";
WsWriteStdoutToClient(wsi, exit_info.c_str(), exit_info.length());
}
+
free_container_exec_request(container_req);
free_container_exec_response(container_res);
- if (request != nullptr) {
- delete request;
- request = nullptr;
- }
(void)closeWsConnect((void*)wsi, nullptr);
return ret;
}
-int ExecServe::RequestFromCri(const runtime::v1alpha2::ExecRequest *grequest, container_exec_request **request)
+int ExecServe::GetContainerRequest(const std::string &token, container_exec_request **container_req) // Consumes the cached exec request for token and converts it into *container_req; returns 0 on success, -1 otherwise.
+{
+ RequestCache *cache = RequestCache::GetInstance();
+ auto request = cache->ConsumeExecRequest(token);
+ // An unknown or expired token yields a default-constructed request (empty container id): fail instead of exec'ing in nothing.
+ int ret = request.container_id().empty() ? -1 : RequestFromCri(request, container_req);
+ if (ret != 0) {
+ ERROR("Failed to transform grpc request!");
+ }
+
+ return ret;
+}
+
+int ExecServe::RequestFromCri(const runtime::v1alpha2::ExecRequest &grequest, container_exec_request **request)
{
container_exec_request *tmpreq = nullptr;
@@ -88,32 +86,33 @@ int ExecServe::RequestFromCri(const runtime::v1alpha2::ExecRequest *grequest, co
return -1;
}
- tmpreq->tty = grequest->tty();
- tmpreq->attach_stdin = grequest->stdin();
- tmpreq->attach_stdout = grequest->stdout();
- tmpreq->attach_stderr = grequest->stderr();
+ tmpreq->tty = grequest.tty();
+ tmpreq->attach_stdin = grequest.stdin();
+ tmpreq->attach_stdout = grequest.stdout();
+ tmpreq->attach_stderr = grequest.stderr();
- if (!grequest->container_id().empty()) {
- tmpreq->container_id = util_strdup_s(grequest->container_id().c_str());
+ if (!grequest.container_id().empty()) {
+ tmpreq->container_id = util_strdup_s(grequest.container_id().c_str());
}
- if (grequest->cmd_size() > 0) {
- if ((size_t)grequest->cmd_size() > SIZE_MAX / sizeof(char *)) {
+ if (grequest.cmd_size() > 0) {
+ if ((size_t)grequest.cmd_size() > SIZE_MAX / sizeof(char *)) {
ERROR("Too many arguments!");
free_container_exec_request(tmpreq);
return -1;
}
- tmpreq->argv = (char **)util_common_calloc_s(sizeof(char *) * grequest->cmd_size());
+ tmpreq->argv = (char **)util_common_calloc_s(sizeof(char *) * grequest.cmd_size());
if (tmpreq->argv == nullptr) {
ERROR("Out of memory!");
free_container_exec_request(tmpreq);
return -1;
}
- for (int i = 0; i < grequest->cmd_size(); i++) {
- tmpreq->argv[i] = util_strdup_s(grequest->cmd(i).c_str());
+ for (int i = 0; i < grequest.cmd_size(); i++) {
+ tmpreq->argv[i] = util_strdup_s(grequest.cmd(i).c_str());
}
- tmpreq->argv_len = (size_t)grequest->cmd_size();
+ tmpreq->argv_len = (size_t)grequest.cmd_size();
}
+
*request = tmpreq;
return 0;
}
diff --git a/src/daemon/entry/cri/websocket/service/exec_serve.h b/src/daemon/entry/cri/websocket/service/exec_serve.h
index ef474018..b29c3e1e 100644
--- a/src/daemon/entry/cri/websocket/service/exec_serve.h
+++ b/src/daemon/entry/cri/websocket/service/exec_serve.h
@@ -40,6 +40,7 @@ public:
int Execute(struct lws *wsi, const std::string &token, int read_pipe_fd) override;
private:
- int RequestFromCri(const runtime::v1alpha2::ExecRequest *grequest, container_exec_request **request);
+ int RequestFromCri(const runtime::v1alpha2::ExecRequest &grequest, container_exec_request **request);
+ int GetContainerRequest(const std::string &token, container_exec_request **request);
};
#endif // DAEMON_ENTRY_CRI_WEBSOCKET_SERVICE_EXEC_SERVE_H
diff --git a/src/daemon/entry/cri/websocket/service/ws_server.cc b/src/daemon/entry/cri/websocket/service/ws_server.cc
index c7e1b538..795d2c1e 100644
--- a/src/daemon/entry/cri/websocket/service/ws_server.cc
+++ b/src/daemon/entry/cri/websocket/service/ws_server.cc
@@ -28,22 +28,19 @@
struct lws_context *WebsocketServer::m_context = nullptr;
std::atomic<WebsocketServer *> WebsocketServer::m_instance;
-std::mutex WebsocketServer::m_mutex;
-std::unordered_map<struct lws *, session_data> WebsocketServer::m_wsis;
+RWMutex WebsocketServer::m_mutex;
+std::unordered_map<int, session_data> WebsocketServer::m_wsis;
+std::unordered_set<struct lws *> WebsocketServer::m_activeSession;
+
WebsocketServer *WebsocketServer::GetInstance() noexcept
{
- WebsocketServer *server = m_instance.load(std::memory_order_relaxed);
- std::atomic_thread_fence(std::memory_order_acquire);
- if (server == nullptr) {
- std::lock_guard<std::mutex> lock(m_mutex);
- server = m_instance.load(std::memory_order_relaxed);
- if (server == nullptr) {
- server = new WebsocketServer;
- std::atomic_thread_fence(std::memory_order_release);
- m_instance.store(server, std::memory_order_relaxed);
- }
- }
- return server;
+ static std::once_flag flag; // One flag for the whole process: guards the single construction below.
+
+ std::call_once(flag, [] { // Runs the lambda once, however many threads call GetInstance() concurrently.
+ m_instance = new WebsocketServer; // NOTE(review): plain new can throw inside this noexcept function - confirm that terminating is acceptable.
+ });
+
+ return m_instance; // Implicit load of the std::atomic<WebsocketServer *> member.
}
WebsocketServer::WebsocketServer()
@@ -62,14 +59,14 @@ url::URLDatum WebsocketServer::GetWebsocketUrl()
return m_url;
}
-std::unordered_map<struct lws *, session_data> &WebsocketServer::GetWsisData()
+std::unordered_map<int, session_data> &WebsocketServer::GetWsisData()
{
return m_wsis;
}
-void WebsocketServer::LockAllWsSession()
+void WebsocketServer::ReadLockAllWsSession()
{
- m_mutex.lock();
+ m_mutex.rdlock();
}
void WebsocketServer::UnlockAllWsSession()
@@ -160,7 +157,7 @@ void WebsocketServer::RegisterCallback(const std::string &path,
void WebsocketServer::CloseAllWsSession()
{
- std::lock_guard<std::mutex> lock(m_mutex);
+ WriteGuard<RWMutex> lock(m_mutex);
for (auto it = m_wsis.begin(); it != m_wsis.end(); ++it) {
free(it->second.buf);
close(it->second.pipes.at(0));
@@ -172,15 +169,10 @@ void WebsocketServer::CloseAllWsSession()
m_wsis.clear();
}
-void WebsocketServer::CloseWsSession(struct lws *wsi)
+void WebsocketServer::CloseWsSession(int socketID)
{
- const int WAIT_PERIOD_MS = 50;
-
- auto it = m_wsis.find(wsi);
+ auto it = m_wsis.find(socketID);
if (it != m_wsis.end()) {
- while (it->second.GetProcessingStatus()) {
- std::this_thread::sleep_for(std::chrono::milliseconds(WAIT_PERIOD_MS));
- }
free(it->second.buf);
close(it->second.pipes.at(0));
close(it->second.pipes.at(1));
@@ -191,6 +183,21 @@ void WebsocketServer::CloseWsSession(struct lws *wsi)
}
}
+void WebsocketServer::RecordSession(struct lws *wsi) // Marks wsi as a live connection; takes no lock itself - its caller DumpHandshakeInfo() runs under a WriteGuard on m_mutex.
+{
+ m_activeSession.insert(wsi);
+}
+
+void WebsocketServer::RemoveSession(struct lws *wsi) // Forgets wsi once its connection is closed; takes no lock itself - the LWS_CALLBACK_CLOSED handler holds a WriteGuard on m_mutex.
+{
+ m_activeSession.erase(wsi); // Erasing a wsi that was never recorded is a no-op.
+}
+
+bool WebsocketServer::IsValidSession(struct lws *wsi) // True while wsi is recorded as live; takes no lock itself - closeWsConnect() calls it between ReadLockAllWsSession() and UnlockAllWsSession().
+{
+ return m_activeSession.count(wsi) != 0;
+}
+
int WebsocketServer::DumpHandshakeInfo(struct lws *wsi) noexcept
{
int read_pipe_fd[PIPE_FD_NUM];
@@ -200,15 +207,17 @@ int WebsocketServer::DumpHandshakeInfo(struct lws *wsi) noexcept
session_data session;
session.pipes = std::array<int, MAX_ARRAY_LEN> { read_pipe_fd[0], read_pipe_fd[1] };
- m_wsis.insert(std::make_pair(wsi, session));
- m_wsis[wsi].buf = (unsigned char *)util_common_calloc_s(LWS_PRE + MAX_MSG_BUFFER_SIZE + 1);
- if (m_wsis[wsi].buf == nullptr) {
+
+ int socketID = lws_get_socket_fd(wsi);
+ m_wsis.insert(std::make_pair(socketID, std::move(session)));
+ m_wsis[socketID].buf = (unsigned char *)util_common_calloc_s(LWS_PRE + MAX_MSG_BUFFER_SIZE + 1);
+ if (m_wsis[socketID].buf == nullptr) {
ERROR("Out of memory");
return -1;
}
- m_wsis[wsi].buf_mutex = new std::mutex;
- m_wsis[wsi].sended_mutex = new std::mutex;
- m_wsis[wsi].SetProcessingStatus(false);
+ m_wsis[socketID].buf_mutex = new std::mutex;
+ m_wsis[socketID].sended_mutex = new std::mutex;
+ m_wsis[socketID].SetProcessingStatus(false);
int len;
char buf[MAX_BUF_LEN] { 0 };
@@ -216,7 +225,7 @@ int WebsocketServer::DumpHandshakeInfo(struct lws *wsi) noexcept
lws_hdr_copy(wsi, buf, sizeof(buf), WSI_TOKEN_GET_URI);
if (strlen(buf) == 0) {
ERROR("invalid url");
- CloseWsSession(wsi);
+ CloseWsSession(socketID);
return -1;
}
@@ -228,14 +237,15 @@ int WebsocketServer::DumpHandshakeInfo(struct lws *wsi) noexcept
!m_handler.IsValidMethod(vec.at(1)) ||
!cache->IsValidToken(vec.at(2))) {
ERROR("invalid url(%s): incorrect format!", buf);
- CloseWsSession(wsi);
+ CloseWsSession(socketID);
return -1;
}
std::thread streamTh([ = ]() {
- StreamTask(&m_handler, wsi, vec.at(1), vec.at(2), m_wsis[wsi].pipes.at(0)).Run();
+ StreamTask(&m_handler, wsi, vec.at(1), vec.at(2), m_wsis[socketID].pipes.at(0)).Run();
});
streamTh.detach();
+ RecordSession(wsi);
int n = 0;
const unsigned char *c = nullptr;
do {
@@ -260,7 +270,7 @@ int WebsocketServer::DumpHandshakeInfo(struct lws *wsi) noexcept
int WebsocketServer::Wswrite(struct lws *wsi, void *in, size_t len)
{
- auto it = m_wsis.find(wsi);
+ auto it = m_wsis.find(lws_get_socket_fd(wsi));
if (it != m_wsis.end()) {
if (it->second.close) {
DEBUG("websocket session disconnected");
@@ -286,9 +296,9 @@ int WebsocketServer::Wswrite(struct lws *wsi, void *in, size_t len)
return 0;
}
-void WebsocketServer::Receive(struct lws *wsi, void *in, size_t len)
+void WebsocketServer::Receive(int socketID, void *in, size_t len)
{
- if (m_wsis.find(wsi) == m_wsis.end()) {
+ if (m_wsis.find(socketID) == m_wsis.end()) {
ERROR("invailed websocket session!");
return;
}
@@ -298,20 +308,20 @@ void WebsocketServer::Receive(struct lws *wsi, void *in, size_t len)
return;
}
- if (write(m_wsis[wsi].pipes.at(1), (void *)((char *)in + 1), len - 1) < 0) {
+ if (write(m_wsis[socketID].pipes.at(1), (void *)((char *)in + 1), len - 1) < 0) {
ERROR("sub write over!");
return;
}
}
-void WebsocketServer::SetLwsSendedFlag(struct lws *wsi, bool sended)
+void WebsocketServer::SetLwsSendedFlag(int socketID, bool sended)
{
- auto it = m_wsis.find(wsi);
- if (it != m_wsis.end()) {
- it->second.sended_mutex->lock();
- it->second.sended = sended;
- it->second.sended_mutex->unlock();
+ if (m_wsis.count(socketID) == 0) {
+ return;
}
+ m_wsis[socketID].sended_mutex->lock();
+ m_wsis[socketID].sended = sended;
+ m_wsis[socketID].sended_mutex->unlock();
}
int WebsocketServer::Callback(struct lws *wsi, enum lws_callback_reasons reason,
@@ -323,7 +333,7 @@ int WebsocketServer::Callback(struct lws *wsi, enum lws_callback_reasons reason,
// asking to upgrade the connection to a websocket one.
return -1;
case LWS_CALLBACK_FILTER_PROTOCOL_CONNECTION: {
- std::lock_guard<std::mutex> lock(m_mutex);
+ WriteGuard<RWMutex> lock(m_mutex);
if (WebsocketServer::GetInstance()->DumpHandshakeInfo(wsi)) {
// return non-zero here and kill the connection
return -1;
@@ -335,22 +345,27 @@ int WebsocketServer::Callback(struct lws *wsi, enum lws_callback_reasons reason,
}
break;
case LWS_CALLBACK_SERVER_WRITEABLE: {
- std::lock_guard<std::mutex> lock(m_mutex);
+ ReadGuard<RWMutex> lock(m_mutex);
+ int socketID = lws_get_socket_fd(wsi);
if (WebsocketServer::GetInstance()->Wswrite(wsi, in, len)) {
- WebsocketServer::GetInstance()->SetLwsSendedFlag(wsi, true);
+ WebsocketServer::GetInstance()->SetLwsSendedFlag(socketID, true);
+ // return nonzero from the user callback to close the connection
+ // and callback with the reason of LWS_CALLBACK_CLOSED
return -1;
}
- WebsocketServer::GetInstance()->SetLwsSendedFlag(wsi, true);
+ WebsocketServer::GetInstance()->SetLwsSendedFlag(socketID, true);
}
break;
case LWS_CALLBACK_RECEIVE: {
- std::lock_guard<std::mutex> lock(m_mutex);
- WebsocketServer::GetInstance()->Receive(wsi, (char *)in, len);
+ ReadGuard<RWMutex> lock(m_mutex);
+ WebsocketServer::GetInstance()->Receive(lws_get_socket_fd(wsi), (char *)in, len);
}
break;
case LWS_CALLBACK_CLOSED: {
- std::lock_guard<std::mutex> lock(m_mutex);
- WebsocketServer::GetInstance()->CloseWsSession(wsi);
+ WriteGuard<RWMutex> lock(m_mutex);
+ DEBUG("connection has been closed");
+ WebsocketServer::GetInstance()->RemoveSession(wsi);
+ WebsocketServer::GetInstance()->CloseWsSession(lws_get_socket_fd(wsi));
}
break;
default:
@@ -363,8 +378,7 @@ void WebsocketServer::ServiceWorkThread(int threadid)
{
int n = 0;
while (n >= 0 && !m_force_exit) {
- n = lws_service(m_context, 50);
- std::this_thread::sleep_for(std::chrono::milliseconds(1));
+ n = lws_service(m_context, 0);
}
}
@@ -396,20 +410,19 @@ void WebsocketServer::Wait()
}
namespace {
-auto PrepareWsiSession(struct lws *wsi) -> session_data *
+auto PrepareWsiSession(int socketID) -> session_data * // Finds the session keyed by socket fd and clears its sended flag; returns nullptr when absent. NOTE(review): the returned pointer is used after the lock is released, while CloseAllWsSession() clears m_wsis under the write lock - confirm the entry cannot vanish meanwhile.
{
WebsocketServer *server = WebsocketServer::GetInstance();
- server->LockAllWsSession();
+ server->ReadLockAllWsSession(); // Shared (read) lock on the session table.
- auto itor = server->GetWsisData().find(wsi);
+ auto itor = server->GetWsisData().find(socketID); // Sessions are keyed by socket fd, no longer by lws pointer.
if (itor == server->GetWsisData().end()) {
ERROR("invalid session!");
server->UnlockAllWsSession();
return nullptr;
}
- itor->second.SetProcessingStatus(true);
+ server->SetLwsSendedFlag(socketID, false); // Now done while the read lock is still held.
server->UnlockAllWsSession();
- server->SetLwsSendedFlag(wsi, false);
return &itor->second;
}
@@ -450,15 +463,13 @@ void EnsureWrited(struct lws *wsi, session_data *session)
}
std::this_thread::sleep_for(std::chrono::milliseconds(TRIGGER_PERIOD_MS));
}
-
- session->SetProcessingStatus(false);
}
ssize_t WsWriteToClient(void *context, const void *data, size_t len, WebsocketChannel channel)
{
struct lws *wsi = static_cast<struct lws *>(context);
- session_data *session = PrepareWsiSession(wsi);
+ session_data *session = PrepareWsiSession(lws_get_socket_fd(wsi));
if (session == nullptr) {
return 0;
}
@@ -487,15 +498,20 @@ int closeWsConnect(void *context, char **err)
struct lws *wsi = static_cast<struct lws *>(context);
WebsocketServer *server = WebsocketServer::GetInstance();
- auto it = server->GetWsisData().find(wsi);
+ server->ReadLockAllWsSession();
+ auto it = server->GetWsisData().find(lws_get_socket_fd(wsi));
if (it == server->GetWsisData().end()) {
+ server->UnlockAllWsSession();
ERROR("websocket session not exist");
return -1;
}
+
it->second.close = true;
// close websocket session
- lws_callback_on_writable(wsi);
+ if (server->IsValidSession(wsi)) {
+ lws_callback_on_writable(wsi);
+ }
+ server->UnlockAllWsSession();
+
return 0;
}
-
-
diff --git a/src/daemon/entry/cri/websocket/service/ws_server.h b/src/daemon/entry/cri/websocket/service/ws_server.h
index 1370c552..cb431f7f 100644
--- a/src/daemon/entry/cri/websocket/service/ws_server.h
+++ b/src/daemon/entry/cri/websocket/service/ws_server.h
@@ -17,6 +17,7 @@
#define DAEMON_ENTRY_CRI_WEBSOCKET_SERVICE_WS_SERVER_H
#include <vector>
#include <unordered_map>
+#include <unordered_set>
#include <string>
#include <mutex>
#include <atomic>
@@ -26,6 +27,7 @@
#include "route_callback_register.h"
#include "url.h"
#include "errors.h"
+#include "read_write_lock.h"
#define MAX_ECHO_PAYLOAD 4096
#define MAX_ARRAY_LEN 2
@@ -71,10 +73,11 @@ public:
void Shutdown();
void RegisterCallback(const std::string &path, std::shared_ptr<StreamingServeInterface> callback);
url::URLDatum GetWebsocketUrl();
- std::unordered_map<struct lws *, session_data> &GetWsisData();
- void SetLwsSendedFlag(struct lws *wsi, bool sended);
- void LockAllWsSession();
+ std::unordered_map<int, session_data> &GetWsisData();
+ void SetLwsSendedFlag(int socketID, bool sended);
+ void ReadLockAllWsSession();
void UnlockAllWsSession();
+ bool IsValidSession(struct lws *wsi);
private:
WebsocketServer();
@@ -85,17 +88,19 @@ private:
std::vector<std::string> split(std::string str, char r);
static void EmitLog(int level, const char *line);
int CreateContext();
- inline void Receive(struct lws *client, void *in, size_t len);
+ inline void Receive(int socketID, void *in, size_t len);
int Wswrite(struct lws *wsi, void *in, size_t len);
inline int DumpHandshakeInfo(struct lws *wsi) noexcept;
static int Callback(struct lws *wsi, enum lws_callback_reasons reason,
void *user, void *in, size_t len);
void ServiceWorkThread(int threadid);
- void CloseWsSession(struct lws *wsi);
+ void CloseWsSession(int socketID);
void CloseAllWsSession();
+ void RecordSession(struct lws *wsi);
+ void RemoveSession(struct lws *wsi);
private:
- static std::mutex m_mutex;
+ static RWMutex m_mutex;
static struct lws_context *m_context;
volatile int m_force_exit = 0;
std::thread m_pthread_service;
@@ -104,7 +109,8 @@ private:
{ NULL, NULL, 0, 0 }
};
RouteCallbackRegister m_handler;
- static std::unordered_map<struct lws *, session_data> m_wsis;
+ static std::unordered_map<int, session_data> m_wsis;
+ static std::unordered_set<struct lws *> m_activeSession;
url::URLDatum m_url;
int m_listenPort;
};
diff --git a/src/utils/cpputils/read_write_lock.cc b/src/utils/cpputils/read_write_lock.cc
new file mode 100644
index 00000000..c9f94dc8
--- /dev/null
+++ b/src/utils/cpputils/read_write_lock.cc
@@ -0,0 +1,59 @@
+/******************************************************************************
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
+ * iSulad licensed under the Mulan PSL v2.
+ * You can use this software according to the terms and conditions of the Mulan PSL v2.
+ * You may obtain a copy of Mulan PSL v2 at:
+ * http://license.coscl.org.cn/MulanPSL2
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR
+ * PURPOSE.
+ * See the Mulan PSL v2 for more details.
+ * Author: wujing
+ * Create: 2021-01-18
+ * Description: provide read write lock implementation
+ *********************************************************************************/
+
+#include "read_write_lock.h"
+
+void RWMutex::rdlock() // Acquires the lock for shared (read) access; blocks while a writer holds the lock or is waiting for it.
+{
+ std::unique_lock<std::mutex> autoLock(m_mutex);
+ ++m_waiting_readers; // Only incremented and decremented in this file; no visible code reads it.
+ m_read_cond.wait(autoLock, [&]() {
+ return m_waiting_writers == 0 && m_status >= 0; // Proceed only when no writer is queued and no writer holds the lock (m_status would be -1).
+ });
+ --m_waiting_readers;
+ ++m_status; // A positive m_status is the number of readers currently holding the lock.
+}
+
+void RWMutex::wrlock() // Acquires the lock for exclusive (write) access; blocks until no reader and no writer holds it.
+{
+ std::unique_lock<std::mutex> autoLock(m_mutex);
+ ++m_waiting_writers; // While this is non-zero, rdlock() keeps new readers waiting.
+ m_write_cond.wait(autoLock, [&]() {
+ return m_status == 0; // 0 means the lock is free.
+ });
+ --m_waiting_writers;
+ --m_status; // 0 -> -1 marks the lock as held by one writer.
+}
+
+void RWMutex::unlock() // Releases one hold on the lock (reader or writer) and wakes whoever may proceed next.
+{
+ std::unique_lock<std::mutex> autoLock(m_mutex);
+ // m_status encodes the state: -1 = held by one writer, >0 = number of readers, 0 = free.
+ if (m_status == -1) { // held by the single writer: release it
+ m_status = 0;
+ } else if (m_status > 0) { // held by one or more readers: drop one of them
+ --m_status;
+ } else { // lock is not held at all: nothing to release, nobody to wake
+ return;
+ }
+
+ if (m_waiting_writers > 0) { // Queued writers are served before readers.
+ if (m_status == 0) { // A writer can only proceed once the last holder is gone.
+ m_write_cond.notify_one();
+ }
+ } else {
+ m_read_cond.notify_all(); // No writer queued: let every waiting reader re-check its predicate.
+ }
+}
diff --git a/src/utils/cpputils/read_write_lock.h b/src/utils/cpputils/read_write_lock.h
new file mode 100644
index 00000000..0149e3a5
--- /dev/null
+++ b/src/utils/cpputils/read_write_lock.h
@@ -0,0 +1,90 @@
+/******************************************************************************
+ * Copyright (c) Huawei Technologies Co., Ltd. 2021. All rights reserved.
+ * iSulad licensed under the Mulan PSL v2.
+ * You can use this software according to the terms and conditions of the Mulan PSL v2.
+ * You may obtain a copy of Mulan PSL v2 at:
+ * http://license.coscl.org.cn/MulanPSL2
+ * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY OR FIT FOR A PARTICULAR
+ * PURPOSE.
+ * See the Mulan PSL v2 for more details.
+ * Author: wujing
+ * Create: 2021-01-18
+ * Description: provide read write lock definition
+ *********************************************************************************/
+#ifndef UTILS_CPPUTILS_READ_WRITE_LOCK_H
+#define UTILS_CPPUTILS_READ_WRITE_LOCK_H
+
+#include <iostream>
+#include <mutex>
+#include <condition_variable>
+#include <thread>
+
+class RWMutex { // Reader/writer lock built on one std::mutex and two condition variables; non-copyable and non-movable.
+public:
+ RWMutex() = default;
+ ~RWMutex() = default;
+ RWMutex(const RWMutex &) = delete;
+ RWMutex(RWMutex &&) = delete;
+ RWMutex &operator = (const RWMutex &) = delete;
+ RWMutex &operator = (RWMutex &&) = delete;
+
+ void rdlock(); // Acquire shared (read) access.
+ void wrlock(); // Acquire exclusive (write) access.
+ void unlock(); // Release one hold, whichever kind it was.
+
+private: // The counters are plain longs: rdlock(), wrlock() and unlock() only touch them while holding m_mutex, so volatile added nothing.
+ long m_status {0}; // -1 = held by one writer, 0 = free, >0 = number of readers holding the lock.
+ long m_waiting_readers {0}; // Readers currently blocked in rdlock().
+ long m_waiting_writers {0}; // Writers currently blocked in wrlock().
+ std::mutex m_mutex; // Guards the three counters above.
+ std::condition_variable m_read_cond; // Waited on by rdlock().
+ std::condition_variable m_write_cond; // Waited on by wrlock().
+};
+
+template<typename RWMutexType> // RWMutexType must provide rdlock() and unlock().
+class ReadGuard { // Scope guard: takes the read lock on construction and releases it on destruction.
+public:
+ explicit ReadGuard(RWMutexType &lock) : m_lock(lock)
+ {
+ m_lock.rdlock(); // Blocks until shared access is granted.
+ }
+ virtual ~ReadGuard()
+ {
+ m_lock.unlock();
+ }
+ // The guard is neither default-constructible, copyable nor movable.
+ ReadGuard() = delete;
+ ReadGuard(const ReadGuard &) = delete;
+ ReadGuard &operator=(const ReadGuard &) = delete;
+ ReadGuard(const ReadGuard &&) = delete;
+ ReadGuard &operator = (const ReadGuard &&) = delete;
+
+private:
+ RWMutexType &m_lock; // Held by reference only; the guard does not own the lock object.
+};
+
+
+template<typename RWMutexType> // RWMutexType must provide wrlock() and unlock().
+class WriteGuard { // Scope guard: takes the write lock on construction and releases it on destruction.
+public:
+ explicit WriteGuard(RWMutexType &lock) : m_lock(lock)
+ {
+ m_lock.wrlock(); // Blocks until exclusive access is granted.
+ }
+ virtual ~WriteGuard()
+ {
+ m_lock.unlock();
+ }
+ // The guard is neither default-constructible, copyable nor movable.
+ WriteGuard() = delete;
+ WriteGuard(const WriteGuard &) = delete;
+ WriteGuard &operator=(const WriteGuard &) = delete;
+ WriteGuard(const WriteGuard &&) = delete;
+ WriteGuard &operator = (const WriteGuard &&) = delete;
+
+private:
+ RWMutexType &m_lock; // Held by reference only; the guard does not own the lock object.
+};
+
+#endif // UTILS_CPPUTILS_READ_WRITE_LOCK_H
diff --git a/src/utils/cpputils/stoppable_thread.cc b/src/utils/cpputils/stoppable_thread.cc
index 0d15aa01..68f6d9b2 100644
--- a/src/utils/cpputils/stoppable_thread.cc
+++ b/src/utils/cpputils/stoppable_thread.cc
@@ -22,7 +22,6 @@ StoppableThread &StoppableThread::operator=(StoppableThread &&obj)
return *this;
}
-
bool StoppableThread::stopRequested()
{
if (m_future_obj.wait_for(std::chrono::milliseconds(0)) == std::future_status::timeout) {
@@ -35,6 +34,3 @@ void StoppableThread::stop()
{
m_exit_signal.set_value();
}
-
-
-
--
2.25.1