From ad8982cafbe849707a2a07fa3f329f377b5aea33 Mon Sep 17 00:00:00 2001 From: gongzt Date: Tue, 17 Oct 2023 09:50:16 +0800 Subject: Fixed errors in 20.03-sp3, such as task progress, cve repair task, and host cve query MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apollo/database/proxy/cve.py | 2 +- apollo/database/proxy/host.py | 136 +++++++++++++--------------------- apollo/database/proxy/task.py | 2 +- 3 files changed, 54 insertions(+), 86 deletions(-) diff --git a/apollo/database/proxy/cve.py b/apollo/database/proxy/cve.py index 4125894..257083a 100644 --- a/apollo/database/proxy/cve.py +++ b/apollo/database/proxy/cve.py @@ -485,7 +485,7 @@ class CveMysqlProxy(MysqlProxy): CveHostAssociation.installed_rpm, CveHostAssociation.available_rpm, ) - .join(CveHostAssociation, Host.host_id == CveHostAssociation.host_id) + .join(Host, Host.host_id == CveHostAssociation.host_id) .filter(CveHostAssociation.cve_id.in_(cve_list)) .filter(*filters) .all() diff --git a/apollo/database/proxy/host.py b/apollo/database/proxy/host.py index 4b75f1c..b3cabb6 100644 --- a/apollo/database/proxy/host.py +++ b/apollo/database/proxy/host.py @@ -511,31 +511,27 @@ class HostProxy(HostMysqlProxy, CveEsProxy): result = {"total_count": 0, "total_page": 1, "result": []} host_id = data["host_id"] - filters = self._get_host_cve_filters(data.get("filter", {})) - host_cve_query = self._query_host_cve(data["username"], host_id, filters).all() + host_cve_query = self._query_host_cve(data["username"], host_id, data.get("filter", {})) - total_count = len(host_cve_query) + total_count = host_cve_query.count() if not total_count: return SUCCEED, result - cve_info_list, cve_packages_dict = self._preprocess_cve_list_query(host_cve_query) sort_column = data['sort'] if "sort" in data else "cve_id" direction, page, per_page = data.get('direction'), data.get('page'), data.get('per_page') - processed_cve_list, total_page = self._sort_and_page_host_cve_info( - 
cve_info_list, page, per_page, sort_column, direction - ) - description_dict = self._get_cve_description(list(cve_packages_dict.keys())) - result['result'] = self._add_additional_info_to_cve_list( - processed_cve_list, description_dict, cve_packages_dict - ) + host_cve_list, total_page = sort_and_page(host_cve_query, sort_column, direction, per_page, page) + + cve_id_list = [cve.cve_id for cve in host_cve_list] + description_dict = self._get_cve_description(cve_id_list) + result['result'] = self._add_additional_info_to_cve_list(host_cve_list, description_dict) result['total_page'] = total_page result['total_count'] = total_count return SUCCEED, result @staticmethod - def _get_host_cve_filters(filter_dict): + def _get_host_cve_filters(filter_dict, cve_affected_pkg_subquery): """ Generate filters to filter host's CVEs @@ -563,7 +559,7 @@ class HostProxy(HostMysqlProxy, CveEsProxy): filters.add( or_( CveHostAssociation.cve_id.like("%" + filter_dict["search_key"] + "%"), - CveAffectedPkgs.package.like("%" + filter_dict["search_key"] + "%"), + cve_affected_pkg_subquery.c.package.like("%" + filter_dict["search_key"] + "%"), ) ) if filter_dict.get("severity"): @@ -573,107 +569,79 @@ class HostProxy(HostMysqlProxy, CveEsProxy): filters.add(CveHostAssociation.affected == filter_dict["affected"]) return filters - def _query_host_cve(self, username: str, host_id: int, filters: set): + def _query_host_cve(self, username: str, host_id: int, filters_dict: dict): """ query needed host CVEs info Args: username (str): user name of the request host_id (int): host id - filters (set): filter given by user + filters_dict (dict): filter dict to filter host's CVEs, e.g. + { + "cve_id": "", + "severity": ["high", "unknown"], + "affected": True, + "package": "vim", + "fixed": False // The default is false if the field is null. 
+ } Returns: sqlalchemy.orm.query.Query """ + cve_affected_pkg_subquery = ( + self.session.query( + CveAffectedPkgs.cve_id, + func.group_concat(func.distinct(CveAffectedPkgs.package), SEPARATOR=",").label("package"), + ) + .group_by(CveAffectedPkgs.cve_id) + .distinct() + .subquery() + ) + + filters = self._get_host_cve_filters(filters_dict, cve_affected_pkg_subquery) + host_cve_query = ( self.session.query( CveHostAssociation.cve_id, case([(Cve.publish_time == None, "")], else_=Cve.publish_time).label("publish_time"), case([(Cve.severity == None, "")], else_=Cve.severity).label("severity"), case([(Cve.cvss_score == None, "")], else_=Cve.cvss_score).label("cvss_score"), - case([(CveAffectedPkgs.package == None, "")], else_=CveAffectedPkgs.package).label("package"), + case( + [(cve_affected_pkg_subquery.c.package == None, "")], else_=cve_affected_pkg_subquery.c.package + ).label("package"), ) .select_from(CveHostAssociation) .outerjoin(Cve, CveHostAssociation.cve_id == Cve.cve_id) - .outerjoin(CveAffectedPkgs, CveAffectedPkgs.cve_id == CveHostAssociation.cve_id) + .outerjoin(cve_affected_pkg_subquery, cve_affected_pkg_subquery.c.cve_id == CveHostAssociation.cve_id) .outerjoin(Host, Host.host_id == CveHostAssociation.host_id) .filter(CveHostAssociation.host_id == host_id, Host.user == username) .filter(*filters) - ).group_by(CveHostAssociation.cve_id, CveAffectedPkgs.package) + ).group_by(CveHostAssociation.cve_id, cve_affected_pkg_subquery.c.package) return host_cve_query @staticmethod - def _preprocess_cve_list_query(cve_list_query: List[dict]) -> Tuple[List[dict], dict]: - """ - get each cve's source package set and deduplication cve info list - - Args: - cve_list_query(list): cve info list(two rows may have same cve id - with different source package) - Returns: - list: list of cve info without package and description. - dict: key is cve id, value is cve affected source package set. 
e.g - { "cve-xxxx-xxxx": {"vim"}, "cve-xxxx-xxxx":{"kernel"} } - """ - cve_packages_dict, cve_info_list = defaultdict(set), [] - - for row in cve_list_query: - if row.cve_id not in cve_packages_dict: - cve_info_list.append( - { - "cve_id": row.cve_id, - "publish_time": row.publish_time, - "severity": row.severity, - "cvss_score": row.cvss_score, - } - ) - cve_packages_dict[row.cve_id].add(row.package) - - return cve_info_list, cve_packages_dict - - @staticmethod - def _sort_and_page_host_cve_info(cve_info_list, page, per_page, sort_column=None, direction="desc"): - """ - sort and page cve info list - - Args: - cve_info_list (list): cve info list. not empty. - page(int) - per_page(int): number of record per page - sort_column(str): the column that sort based on - direction(str): desc or asc - - Returns: - query result, total number of pages - """ - total_page = 1 - total_count = len(cve_info_list) - - if sort_column: - cve_info_list.sort( - key=lambda cve_info: cve_info[sort_column], reverse=True if direction == "desc" else False - ) - - if page and per_page: - total_page = math.ceil(total_count / per_page) - return cve_info_list[per_page * (page - 1) : per_page * page], total_page - - return cve_info_list, total_page - - @staticmethod - def _add_additional_info_to_cve_list(cve_info_list: list, description_dict: dict, cve_packages_dict: list) -> list: + def _add_additional_info_to_cve_list(host_cve_list: list, description_dict: dict) -> list: """ add description and affected source packages for each cve Args: - cve_info_list: + host_cve_list: description_dict (dict): key is cve's id, value is cve's description - cve_pkgs_dict (dict): key is cve's id, value is cve's affected packages set Returns: list """ - for cve_info in cve_info_list: - cve_id = cve_info["cve_id"] - cve_info["description"] = description_dict[cve_id] if description_dict.get(cve_id) else "" - cve_info["package"] = ",".join(cve_packages_dict[cve_id]) - return cve_info_list + host_cve_info_list = [] + 
for host_cve in host_cve_list: + description = description_dict[host_cve.cve_id] if description_dict.get(host_cve.cve_id) else "" + host_cve_info_list.append( + { + "cve_id": host_cve.cve_id, + "publish_time": host_cve.publish_time, + "severity": host_cve.severity, + "cvss_score": host_cve.cvss_score, + "package": host_cve.package, + "description": description, + } + ) + + return host_cve_info_list diff --git a/apollo/database/proxy/task.py b/apollo/database/proxy/task.py index e7c3b4e..de151b2 100644 --- a/apollo/database/proxy/task.py +++ b/apollo/database/proxy/task.py @@ -1291,7 +1291,7 @@ class TaskMysqlProxy(MysqlProxy): status = TaskStatus.UNKNOWN else: status = TaskStatus.SUCCEED - result[cve_id] = {"progress": row.total - row.running - row.none, "status": status} + result[cve_id] = {"progress": int(row.total - row.running - row.none), "status": status} succeed_list = list(result.keys()) fail_list = list(set(cve_list) - set(succeed_list)) -- Gitee