From a5291bcee84b56b30aac38544d85fb601fe6a25a Mon Sep 17 00:00:00 2001 From: Eric Garver Date: Tue, 17 Mar 2020 13:51:43 -0400 Subject: [PATCH] improvement: port: allow coalescing and breaking of ranges --- src/firewall/core/fw_zone.py | 106 +++++++++++++++++++++++++++---------- src/firewall/functions.py | 85 +++++++++++++++++++++++++++++ src/firewall/server/config_zone.py | 55 ++++++++++++++----- 3 files changed, 204 insertions(+), 42 deletions(-) diff --git a/src/firewall/core/fw_zone.py b/src/firewall/core/fw_zone.py index 2bc94e3..d32d7a8 100644 --- a/src/firewall/core/fw_zone.py +++ b/src/firewall/core/fw_zone.py @@ -25,7 +25,7 @@ from firewall.core.base import SHORTCUTS, DEFAULT_ZONE_TARGET, \ from firewall.core.logger import log from firewall.functions import portStr, checkIPnMask, checkIP6nMask, \ checkProtocol, enable_ip_forwarding, check_single_address, check_mac, \ - portInPortRange, get_nf_conntrack_short_name + portInPortRange, get_nf_conntrack_short_name, coalescePortRange, breakPortRange from firewall.core.rich import Rich_Rule, Rich_Accept, \ Rich_Mark, Rich_Service, Rich_Port, Rich_Protocol, \ Rich_Masquerade, Rich_ForwardPort, Rich_SourcePort, Rich_IcmpBlock, \ @@ -857,11 +857,13 @@ class FirewallZone(object): self._fw.check_panic() _obj = self._zones[_zone] - port_id = self.__port_id(port, protocol) - if port_id in _obj.settings["ports"]: - raise FirewallError(errors.ALREADY_ENABLED, - "'%s:%s' already in '%s'" % (port, protocol, - _zone)) + existing_port_ids = list(filter(lambda x: x[1] == protocol, _obj.settings["ports"])) + for port_id in existing_port_ids: + if portInPortRange(port, port_id[0]): + raise FirewallError(errors.ALREADY_ENABLED, + "'%s:%s' already in '%s'" % (port, protocol, _zone)) + + added_ranges, removed_ranges = coalescePortRange(port, [_port for (_port, _protocol) in existing_port_ids]) if use_transaction is None: transaction = self.new_transaction() @@ -869,10 +871,18 @@ class FirewallZone(object): transaction = use_transaction if _obj.applied: - self._port(True, _zone, port, protocol, transaction) - - self.__register_port(_obj, port_id, timeout, sender) - transaction.add_fail(self.__unregister_port, _obj, port_id) + for range in added_ranges: + self._port(True, _zone, portStr(range, "-"), protocol, transaction) + for range in removed_ranges: + self._port(False, _zone, portStr(range, "-"), protocol, transaction) + + for range in added_ranges: + port_id = self.__port_id(range, protocol) + self.__register_port(_obj, port_id, timeout, sender) + transaction.add_fail(self.__unregister_port, _obj, port_id) + for range in removed_ranges: + port_id = self.__port_id(range, protocol) + transaction.add_post(self.__unregister_port, _obj, port_id) if use_transaction is None: transaction.execute(True) @@ -889,20 +899,34 @@ class FirewallZone(object): self._fw.check_panic() _obj = self._zones[_zone] - port_id = self.__port_id(port, protocol) - if port_id not in _obj.settings["ports"]: + existing_port_ids = list(filter(lambda x: x[1] == protocol, _obj.settings["ports"])) + for port_id in existing_port_ids: + if portInPortRange(port, port_id[0]): + break + else: raise FirewallError(errors.NOT_ENABLED, "'%s:%s' not in '%s'" % (port, protocol, _zone)) + added_ranges, removed_ranges = breakPortRange(port, [_port for (_port, _protocol) in existing_port_ids]) + if use_transaction is None: transaction = self.new_transaction() else: transaction = use_transaction if _obj.applied: - self._port(False, _zone, port, protocol, transaction) - - transaction.add_post(self.__unregister_port, _obj, port_id) + for range in added_ranges: + self._port(True, _zone, portStr(range, "-"), protocol, transaction) + for range in removed_ranges: + self._port(False, _zone, portStr(range, "-"), protocol, transaction) + + for range in added_ranges: + port_id = self.__port_id(range, protocol) + self.__register_port(_obj, port_id, 0, None) + transaction.add_fail(self.__unregister_port, _obj, port_id) + for range in removed_ranges: + port_id = self.__port_id(range, protocol) + transaction.add_post(self.__unregister_port, _obj, port_id) if use_transaction is None: transaction.execute(True) @@ -1015,11 +1039,13 @@ class FirewallZone(object): self._fw.check_panic() _obj = self._zones[_zone] - port_id = self.__source_port_id(port, protocol) - if port_id in _obj.settings["source_ports"]: - raise FirewallError(errors.ALREADY_ENABLED, - "'%s:%s' already in '%s'" % (port, protocol, - _zone)) + existing_port_ids = list(filter(lambda x: x[1] == protocol, _obj.settings["source_ports"])) + for port_id in existing_port_ids: + if portInPortRange(port, port_id[0]): + raise FirewallError(errors.ALREADY_ENABLED, + "'%s:%s' already in '%s'" % (port, protocol, _zone)) + + added_ranges, removed_ranges = coalescePortRange(port, [_port for (_port, _protocol) in existing_port_ids]) if use_transaction is None: transaction = self.new_transaction() @@ -1027,10 +1053,18 @@ class FirewallZone(object): transaction = use_transaction if _obj.applied: - self._source_port(True, _zone, port, protocol, transaction) - - self.__register_source_port(_obj, port_id, timeout, sender) - transaction.add_fail(self.__unregister_source_port, _obj, port_id) + for range in added_ranges: + self._source_port(True, _zone, portStr(range, "-"), protocol, transaction) + for range in removed_ranges: + self._source_port(False, _zone, portStr(range, "-"), protocol, transaction) + + for range in added_ranges: + port_id = self.__source_port_id(range, protocol) + self.__register_source_port(_obj, port_id, timeout, sender) + transaction.add_fail(self.__unregister_source_port, _obj, port_id) + for range in removed_ranges: + port_id = self.__source_port_id(range, protocol) + transaction.add_post(self.__unregister_source_port, _obj, port_id) if use_transaction is None: transaction.execute(True) @@ -1047,20 +1081,34 @@ class FirewallZone(object): self._fw.check_panic() _obj = self._zones[_zone] - port_id = self.__source_port_id(port, protocol) - if port_id not in _obj.settings["source_ports"]: + existing_port_ids = list(filter(lambda x: x[1] == protocol, _obj.settings["source_ports"])) + for port_id in existing_port_ids: + if portInPortRange(port, port_id[0]): + break + else: raise FirewallError(errors.NOT_ENABLED, "'%s:%s' not in '%s'" % (port, protocol, _zone)) + added_ranges, removed_ranges = breakPortRange(port, [_port for (_port, _protocol) in existing_port_ids]) + if use_transaction is None: transaction = self.new_transaction() else: transaction = use_transaction if _obj.applied: - self._source_port(False, _zone, port, protocol, transaction) - - transaction.add_post(self.__unregister_source_port, _obj, port_id) + for range in added_ranges: + self._source_port(True, _zone, portStr(range, "-"), protocol, transaction) + for range in removed_ranges: + self._source_port(False, _zone, portStr(range, "-"), protocol, transaction) + + for range in added_ranges: + port_id = self.__source_port_id(range, protocol) + self.__register_source_port(_obj, port_id, 0, None) + transaction.add_fail(self.__unregister_source_port, _obj, port_id) + for range in removed_ranges: + port_id = self.__source_port_id(range, protocol) + transaction.add_post(self.__unregister_source_port, _obj, port_id) if use_transaction is None: transaction.execute(True) diff --git a/src/firewall/functions.py b/src/firewall/functions.py index 6af2206..6bc52d9 100644 --- a/src/firewall/functions.py +++ b/src/firewall/functions.py @@ -72,6 +72,10 @@ def getPortRange(ports): @return Array containing start and end port id for a valid range or -1 if port can not be found and -2 if port is too big for integer input or -1 for invalid ranges or None if the range is ambiguous. """ + # (port, port) or [port, port] case + if isinstance(ports, tuple) or isinstance(ports, list): + return ports + # "" case if isinstance(ports, int) or ports.isdigit(): id1 = getPortID(ports) @@ -155,6 +159,87 @@ def portInPortRange(port, range): return False +def coalescePortRange(new_range, ranges): + """ Coalesce a port range with existing list of port ranges + + @param new_range tuple/list/string + @param ranges list of tuple/list/string + @return tuple of (list of ranges added after coalescing, list of removed original ranges) + """ + + coalesced_range = getPortRange(new_range) + # normalize singleton ranges, e.g. (x,) --> (x,x) + if len(coalesced_range) == 1: + coalesced_range = (coalesced_range[0], coalesced_range[0]) + _ranges = map(getPortRange, ranges) + _ranges = sorted(map(lambda x: (x[0],x[0]) if len(x) == 1 else x, _ranges), key=lambda x: x[0]) + + removed_ranges = [] + for range in _ranges: + if coalesced_range[0] <= range[0] and coalesced_range[1] >= range[1]: + # new range covers this + removed_ranges.append(range) + elif coalesced_range[0] <= range[0] and coalesced_range[1] < range[1] and \ + coalesced_range[1] >= range[0]: + # expand beginning of range + removed_ranges.append(range) + coalesced_range = (coalesced_range[0], range[1]) + elif coalesced_range[0] > range[0] and coalesced_range[1] >= range[1] and \ + coalesced_range[0] <= range[1]: + # expand end of range + removed_ranges.append(range) + coalesced_range = (range[0], coalesced_range[1]) + + # normalize singleton ranges, e.g. (x,x) --> (x,) + removed_ranges = list(map(lambda x: (x[0],) if x[0] == x[1] else x, removed_ranges)) + if coalesced_range[0] == coalesced_range[1]: + coalesced_range = (coalesced_range[0],) + + return ([coalesced_range], removed_ranges) + +def breakPortRange(remove_range, ranges): + """ break a port range from existing list of port ranges + + @param remove_range tuple/list/string + @param ranges list of tuple/list/string + @return tuple of (list of ranges added after breaking up, list of removed original ranges) + """ + + remove_range = getPortRange(remove_range) + # normalize singleton ranges, e.g. (x,) --> (x,x) + if len(remove_range) == 1: + remove_range = (remove_range[0], remove_range[0]) + _ranges = map(getPortRange, ranges) + _ranges = sorted(map(lambda x: (x[0],x[0]) if len(x) == 1 else x, _ranges), key=lambda x: x[0]) + + removed_ranges = [] + added_ranges = [] + for range in _ranges: + if remove_range[0] <= range[0] and remove_range[1] >= range[1]: + # remove entire range + removed_ranges.append(range) + elif remove_range[0] <= range[0] and remove_range[1] < range[1] and \ + remove_range[1] >= range[0]: + # remove from beginning of range + removed_ranges.append(range) + added_ranges.append((remove_range[1] + 1, range[1])) + elif remove_range[0] > range[0] and remove_range[1] >= range[1] and \ + remove_range[0] <= range[1]: + # remove from end of range + removed_ranges.append(range) + added_ranges.append((range[0], remove_range[0] - 1)) + elif remove_range[0] > range[0] and remove_range[1] < range[1]: + # remove inside range + removed_ranges.append(range) + added_ranges.append((range[0], remove_range[0] - 1)) + added_ranges.append((remove_range[1] + 1, range[1])) + + # normalize singleton ranges, e.g. (x,x) --> (x,) + removed_ranges = list(map(lambda x: (x[0],) if x[0] == x[1] else x, removed_ranges)) + added_ranges = list(map(lambda x: (x[0],) if x[0] == x[1] else x, added_ranges)) + + return (added_ranges, removed_ranges) + def getServiceName(port, proto): """ Check and Get service name from port and proto string combination using socket.getservbyport diff --git a/src/firewall/server/config_zone.py b/src/firewall/server/config_zone.py index 1ae20ce..1c05318 100644 --- a/src/firewall/server/config_zone.py +++ b/src/firewall/server/config_zone.py @@ -41,7 +41,8 @@ from firewall.server.decorators import handle_exceptions, \ dbus_handle_exceptions, dbus_service_method from firewall import errors from firewall.errors import FirewallError -from firewall.functions import portInPortRange +from firewall.functions import portStr, portInPortRange, coalescePortRange, \ + breakPortRange ############################################################################ # @@ -455,10 +456,16 @@ class FirewallDConfigZone(slip.dbus.service.Object): protocol) self.parent.accessCheck(sender) settings = list(self.getSettings()) - if (port,protocol) in settings[6]: - raise FirewallError(errors.ALREADY_ENABLED, - "%s:%s" % (port, protocol)) - settings[6].append((port,protocol)) + existing_port_ids = list(filter(lambda x: x[1] == protocol, settings[6])) + for port_id in existing_port_ids: + if portInPortRange(port, port_id[0]): + raise FirewallError(errors.ALREADY_ENABLED, + "%s:%s" % (port, protocol)) + added_ranges, removed_ranges = coalescePortRange(port, [_port for (_port, _protocol) in existing_port_ids]) + for range in removed_ranges: + settings[6].remove((portStr(range, "-"), protocol)) + for range in added_ranges: + settings[6].append((portStr(range, "-"), protocol)) self.update(settings) @dbus_service_method(config.dbus.DBUS_INTERFACE_CONFIG_ZONE, @@ -471,9 +478,17 @@ class FirewallDConfigZone(slip.dbus.service.Object): protocol) self.parent.accessCheck(sender) settings = list(self.getSettings()) - if (port,protocol) not in settings[6]: + existing_port_ids = list(filter(lambda x: x[1] == protocol, settings[6])) + for port_id in existing_port_ids: + if portInPortRange(port, port_id[0]): + break + else: raise FirewallError(errors.NOT_ENABLED, "%s:%s" % (port, protocol)) - settings[6].remove((port,protocol)) + added_ranges, removed_ranges = breakPortRange(port, [_port for (_port, _protocol) in existing_port_ids]) + for range in removed_ranges: + settings[6].remove((portStr(range, "-"), protocol)) + for range in added_ranges: + settings[6].append((portStr(range, "-"), protocol)) self.update(settings) @dbus_service_method(config.dbus.DBUS_INTERFACE_CONFIG_ZONE, @@ -583,10 +598,16 @@ class FirewallDConfigZone(slip.dbus.service.Object): protocol) self.parent.accessCheck(sender) settings = list(self.getSettings()) - if (port,protocol) in settings[14]: - raise FirewallError(errors.ALREADY_ENABLED, - "%s:%s" % (port, protocol)) - settings[14].append((port,protocol)) + existing_port_ids = list(filter(lambda x: x[1] == protocol, settings[14])) + for port_id in existing_port_ids: + if portInPortRange(port, port_id[0]): + raise FirewallError(errors.ALREADY_ENABLED, + "%s:%s" % (port, protocol)) + added_ranges, removed_ranges = coalescePortRange(port, [_port for (_port, _protocol) in existing_port_ids]) + for range in removed_ranges: + settings[14].remove((portStr(range, "-"), protocol)) + for range in added_ranges: + settings[14].append((portStr(range, "-"), protocol)) self.update(settings) @dbus_service_method(config.dbus.DBUS_INTERFACE_CONFIG_ZONE, @@ -599,9 +620,17 @@ class FirewallDConfigZone(slip.dbus.service.Object): protocol) self.parent.accessCheck(sender) settings = list(self.getSettings()) - if (port,protocol) not in settings[14]: + existing_port_ids = list(filter(lambda x: x[1] == protocol, settings[14])) + for port_id in existing_port_ids: + if portInPortRange(port, port_id[0]): + break + else: raise FirewallError(errors.NOT_ENABLED, "%s:%s" % (port, protocol)) - settings[14].remove((port,protocol)) + added_ranges, removed_ranges = breakPortRange(port, [_port for (_port, _protocol) in existing_port_ids]) + for range in removed_ranges: + settings[14].remove((portStr(range, "-"), protocol)) + for range in added_ranges: + settings[14].append((portStr(range, "-"), protocol)) self.update(settings) @dbus_service_method(config.dbus.DBUS_INTERFACE_CONFIG_ZONE, -- 1.8.3.1