import socket
import struct

from nfnetlink import *

# kzorp netlink message types; the values presumably mirror the kernel-side
# kzorp definitions and must stay in sync with them, so they are spelled out.
KZNL_MSG_START = 0
KZNL_MSG_COMMIT = 1
KZNL_MSG_FLUSH_ZONE = 2
KZNL_MSG_ADD_ZONE = 3
KZNL_MSG_ADD_ZONE_SVC_IN = 4
KZNL_MSG_ADD_ZONE_SVC_OUT = 5
KZNL_MSG_GET_ZONE = 6
KZNL_MSG_FLUSH_SERVICE = 7
KZNL_MSG_ADD_SERVICE = 8
KZNL_MSG_ADD_SERVICE_NAT_SRC = 9
KZNL_MSG_ADD_SERVICE_NAT_DST = 10
KZNL_MSG_GET_SERVICE = 11
KZNL_MSG_FLUSH_DISPATCHER = 12
KZNL_MSG_ADD_DISPATCHER = 13
KZNL_MSG_ADD_DISPATCHER_CSS = 14
KZNL_MSG_GET_DISPATCHER = 15
KZNL_MSG_QUERY = 16
KZNL_MSG_MAX = 16  # inclusive upper bound: equals the last message type above

# attribute types: the `type` argument passed to the create_*_attr helpers
# below when building message attributes
KZA_INVALID = 0
KZA_INSTANCE_NAME = 1
KZA_TR_PARAMS = 2
KZA_ZONE_PARAMS = 3
KZA_ZONE_NAME = 4
KZA_ZONE_UNAME = 5
KZA_ZONE_PNAME = 6
KZA_ZONE_RANGE = 7
KZA_SVC_PARAMS = 8
KZA_SVC_NAME = 9
KZA_SVC_ROUTER_DST = 10
KZA_SVC_NAT_SRC = 11
KZA_SVC_NAT_DST = 12
KZA_SVC_NAT_MAP = 13
KZA_SVC_SESSION_CNT = 14
KZA_DPT_PARAMS = 15
KZA_DPT_NAME = 16
KZA_DPT_BIND_ADDR = 17
KZA_DPT_BIND_IFACE = 18
KZA_DPT_BIND_IFGROUP = 19
KZA_DPT_CSS_CZONE = 20
KZA_DPT_CSS_SZONE = 21
KZA_DPT_CSS_SERVICE = 22
KZA_QUERY_SRC = 23
KZA_QUERY_DST = 24
KZA_QUERY_IFACE = 25
KZA_MAX = 25  # inclusive upper bound: equals the last attribute type above

# name of the global instance (presumably a value for the `name` argument of
# create_start_msg, which sends it as KZA_INSTANCE_NAME)
KZ_INSTANCE_GLOBAL = ".global"

# transaction types: sent as the one-byte KZA_TR_PARAMS attribute by
# create_start_msg
KZ_TR_TYPE_INVALID = 0
KZ_TR_TYPE_ZONE = 1
KZ_TR_TYPE_SERVICE = 2
KZ_TR_TYPE_DISPATCHER = 3

# zone flags (presumably OR-able bit flags; packed into KZA_ZONE_PARAMS by
# create_add_zone_msg)
KZF_ZONE_UMBRELLA = 1

# service types (packed into KZA_SVC_PARAMS by create_service_params_attr)
KZ_SVC_INVALID = 0
KZ_SVC_PROXY = 1
KZ_SVC_FORWARD = 2

# service flags (presumably OR-able bit flags; packed alongside the service type)
KZF_SVC_TRANSPARENT = 1
KZF_SVC_FORGE_ADDR = 2

# service NAT entry flags: the `flags` element of the NAT range tuples passed
# to create_add_service_nat_msg
KZ_SVC_NAT_MAP_IPS = 1
KZ_SVC_NAT_MAP_PROTO_SPECIFIC = 2

# dispatcher types (packed into KZA_DPT_PARAMS by create_dispatcher_params_attr)
KZ_DPT_TYPE_INVALID = 0
KZ_DPT_TYPE_INET = 1
KZ_DPT_TYPE_IFACE = 2
KZ_DPT_TYPE_IFGROUP = 3

# dispatcher flags (presumably OR-able bit flags)
KZF_DPT_TRANSPARENT = 1
KZF_DPT_FOLLOW_PARENT = 2

# dispatcher bind address port ranges: maximum number of (start, end) port
# ranges a bind attribute can hold; unused slots are zero-padded on the wire
KZF_DPT_PORT_RANGE_SIZE = 8

###########################################################################
# helper functions to create/parse kzorp attributes
###########################################################################
def create_name_attr(type, name):
        """Build an attribute holding a length-prefixed name.

        Payload: the length of `name` as an unsigned big-endian 16-bit integer,
        followed by the raw name characters (no NUL terminator is appended).
        """
        length_prefix = struct.pack('>H', len(name))
        return NfnetlinkAttribute(type, length_prefix + name)

def parse_name_attr(attr):
        """Decode a name attribute produced by create_name_attr.

        Reads the big-endian 16-bit length prefix and returns exactly that many
        following characters; raises struct.error if the payload is truncated.
        """
        payload = attr.get_data()
        (length,) = struct.unpack('>H', payload[:2])
        name_format = '%ds' % length
        return struct.unpack(name_format, payload[2:2 + length])[0]

def create_int8_attr(type, value):
        """Build an attribute whose payload is `value` as one unsigned byte."""
        payload = struct.pack('B', value)
        return NfnetlinkAttribute(type, payload)

def parse_int8_attr(attr):
        """Decode an attribute whose payload starts with one unsigned byte.

        The first byte is sliced (not indexed) like in the other parse_*
        helpers: indexing a Python 3 bytes object yields an int, which
        struct.unpack rejects, whereas a one-byte slice works on both str and
        bytes. An empty payload raises struct.error.
        """
        (value,) = struct.unpack('B', attr.get_data()[:1])
        return value

def create_int32_attr(type, value):
        """Build an attribute whose payload is `value` as a big-endian unsigned 32-bit integer."""
        payload = struct.pack('>I', value)
        return NfnetlinkAttribute(type, payload)

def parse_int32_attr(attr):
        """Decode the leading big-endian unsigned 32-bit integer of an attribute payload."""
        return struct.unpack_from('>I', attr.get_data())[0]

def create_inet_range_attr(type, address, mask):
        """Build an address-range attribute.

        Payload: `address` then `mask`, each a big-endian unsigned 32-bit
        integer (presumably IPv4 values in integer form).
        """
        payload = struct.pack('>II', address, mask)
        return NfnetlinkAttribute(type, payload)

def parse_inet_range_attr(attr):
        """Decode an attribute built by create_inet_range_attr into (address, mask)."""
        return struct.unpack_from('>II', attr.get_data(), 0)

def create_nat_range_attr(type, flags, min_ip, max_ip, min_port, max_port):
        """Build a NAT range attribute.

        Payload (16 bytes, big-endian): flags, min_ip, max_ip as unsigned
        32-bit integers, then min_port, max_port as unsigned 16-bit integers.
        """
        fields = (flags, min_ip, max_ip, min_port, max_port)
        return NfnetlinkAttribute(type, struct.pack('>IIIHH', *fields))

def parse_nat_range_attr(attr):
        """Decode an attribute built by create_nat_range_attr.

        Returns (flags, min_ip, max_ip, min_port, max_port).
        """
        return struct.unpack_from('>IIIHH', attr.get_data(), 0)

def create_address_attr(type, proto, ip, port):
        """Build an address attribute.

        Payload (7 bytes, big-endian, no padding): ip as unsigned 32-bit,
        port as unsigned 16-bit, proto as one unsigned byte. Note the wire
        order differs from the argument order.
        """
        payload = struct.pack('>IHB', ip, port, proto)
        return NfnetlinkAttribute(type, payload)

def parse_address_attr(attr):
        """Decode an attribute built by create_address_attr into (ip, port, proto)."""
        return struct.unpack_from('>IHB', attr.get_data(), 0)

def create_bind_addr_attr(type, proto, ip, ports):
        """Build a dispatcher bind-to-address attribute.

        Payload layout (38 bytes with KZF_DPT_PORT_RANGE_SIZE == 8), big-endian:
          ip                          unsigned 32-bit
          port ranges                 KZF_DPT_PORT_RANGE_SIZE slots of
                                      (start, end) unsigned 16-bit pairs;
                                      unused slots are zero-filled
          number of ranges, proto     two unsigned bytes

        :param type: attribute type (e.g. KZA_DPT_BIND_ADDR)
        :param proto: protocol number, packed as one unsigned byte
        :param ip: address as an unsigned 32-bit integer
        :param ports: sequence of (start, end) pairs; only the first two
                      items of each entry are used
        :raises ValueError: if `ports` has more than KZF_DPT_PORT_RANGE_SIZE entries
        """
        if len(ports) > KZF_DPT_PORT_RANGE_SIZE:
                raise ValueError("bind address contains too many port ranges, %s allowed" % KZF_DPT_PORT_RANGE_SIZE)
        parts = [struct.pack('>I', ip)]
        parts.extend(struct.pack('>HH', r[0], r[1]) for r in ports)
        # Zero-fill the unused slots so the trailer always lands at a fixed
        # offset (an empty string when all slots are in use).
        parts.append("\0" * 4 * (KZF_DPT_PORT_RANGE_SIZE - len(ports)))
        parts.append(struct.pack('BB', len(ports), proto))
        return NfnetlinkAttribute(type, "".join(parts))

def parse_bind_addr_attr(attr):
        """Decode an attribute built by create_bind_addr_attr.

        Returns (proto, addr, ports), where ports is a list of (start, end)
        tuples. Only the first num_ports slots are decoded; the count is read
        from offset 36.
        """
        raw = attr.get_data()
        (addr,) = struct.unpack_from('>I', raw, 0)
        num_ports, proto = struct.unpack_from('BB', raw, 36)
        ports = [struct.unpack_from('>HH', raw, 4 + 4 * slot) for slot in range(num_ports)]
        return (proto, addr, ports)

def create_bind_iface_attr(type, proto, iface, ports, pref_addr):
        """Build a dispatcher bind-to-interface attribute.

        Payload layout (54 bytes with KZF_DPT_PORT_RANGE_SIZE == 8), big-endian:
          pref_addr                   unsigned 32-bit
          port ranges                 KZF_DPT_PORT_RANGE_SIZE slots of
                                      (start, end) unsigned 16-bit pairs;
                                      unused slots are zero-filled
          number of ranges, proto     two unsigned bytes
          iface                       16-byte field, NUL-padded

        :param type: attribute type (e.g. KZA_DPT_BIND_IFACE)
        :param proto: protocol number, packed as one unsigned byte
        :param iface: interface name, at most 16 characters
        :param ports: sequence of (start, end) pairs
        :param pref_addr: preferred address as an unsigned 32-bit integer
        :raises ValueError: if `ports` has more than KZF_DPT_PORT_RANGE_SIZE
                            entries or `iface` does not fit the 16-byte field
        """
        ifname_size = 16
        if len(ports) > KZF_DPT_PORT_RANGE_SIZE:
                raise ValueError("bind address contains too many port ranges, %s allowed" % KZF_DPT_PORT_RANGE_SIZE)
        if len(iface) > ifname_size:
                # A longer name used to be appended unpadded, silently producing
                # an attribute larger than the fixed-size field allows.
                raise ValueError("interface name too long, at most %d characters allowed" % ifname_size)
        parts = [struct.pack('>I', pref_addr)]
        parts.extend(struct.pack('>HH', r[0], r[1]) for r in ports)
        parts.append("\0" * 4 * (KZF_DPT_PORT_RANGE_SIZE - len(ports)))
        parts.append(struct.pack('BB', len(ports), proto))
        parts.append(iface)
        parts.append("\0" * (ifname_size - len(iface)))
        return NfnetlinkAttribute(type, "".join(parts))

def parse_bind_iface_attr(attr):
        """Decode an attribute built by create_bind_iface_attr.

        Returns (proto, iface, ports, pref_addr), where ports is a list of
        (start, end) tuples (count read from offset 36). The interface name
        is read from the fixed 16-byte field at offset 38 and ends at the
        first NUL, as for a C string.
        """
        data = attr.get_data()
        (pref_addr,) = struct.unpack('>I', data[:4])
        (num_ports, proto) = struct.unpack('BB', data[36:38])
        ports = []
        for i in range(num_ports):
                (start, end) = struct.unpack('>HH', data[4 + 4 * i : 8 + 4 * i])
                ports.append((start, end))
        # rstrip() on the whole tail kept any bytes after the terminator.
        # struct.pack('x') yields a NUL of the native byte-string type (str on
        # Python 2, bytes on Python 3), so the split works on either.
        nul = struct.pack('x')
        iface = data[38:38 + 16].split(nul, 1)[0]
        return (proto, iface, ports, pref_addr)

def create_bind_ifgroup_attr(type, proto, group, mask, ports, pref_addr):
        """Build a dispatcher bind-to-interface-group attribute.

        Payload layout (46 bytes with KZF_DPT_PORT_RANGE_SIZE == 8), big-endian:
          group, mask, pref_addr      three unsigned 32-bit integers
          port ranges                 KZF_DPT_PORT_RANGE_SIZE slots of
                                      (start, end) unsigned 16-bit pairs;
                                      unused slots are zero-filled
          number of ranges, proto     two unsigned bytes

        :param type: attribute type (e.g. KZA_DPT_BIND_IFGROUP)
        :param proto: protocol number, packed as one unsigned byte
        :param group: interface group value
        :param mask: mask for the group value (presumably applied to it)
        :param ports: sequence of (start, end) pairs
        :param pref_addr: preferred address as an unsigned 32-bit integer
        :raises ValueError: if `ports` has more than KZF_DPT_PORT_RANGE_SIZE entries
        """
        if len(ports) > KZF_DPT_PORT_RANGE_SIZE:
                # '%' formats the limit into the message; the previous '&' made
                # this line raise TypeError instead of the intended ValueError.
                raise ValueError("bind address contains too many port ranges, %s allowed" % KZF_DPT_PORT_RANGE_SIZE)
        parts = [struct.pack('>III', group, mask, pref_addr)]
        parts.extend(struct.pack('>HH', r[0], r[1]) for r in ports)
        parts.append("\0" * 4 * (KZF_DPT_PORT_RANGE_SIZE - len(ports)))
        parts.append(struct.pack('BB', len(ports), proto))
        return NfnetlinkAttribute(type, "".join(parts))

def parse_bind_ifgroup_attr(attr):
        """Decode an attribute built by create_bind_ifgroup_attr.

        Returns (proto, group, mask, ports, pref_addr), where ports is a list
        of (start, end) tuples. Only the first num_ports slots are decoded; the
        count is read from offset 44.
        """
        raw = attr.get_data()
        group, mask, pref_addr = struct.unpack_from('>III', raw, 0)
        num_ports, proto = struct.unpack_from('BB', raw, 44)
        ports = [struct.unpack_from('>HH', raw, 12 + 4 * slot) for slot in range(num_ports)]
        return (proto, group, mask, ports, pref_addr)

def create_dispatcher_params_attr(type, dpt_type, dpt_flags, proxy_port):
        """Build a dispatcher parameters attribute.

        Payload (7 bytes, big-endian): dpt_flags as unsigned 32-bit,
        proxy_port as unsigned 16-bit, dpt_type as one unsigned byte.
        """
        payload = struct.pack('>IHB', dpt_flags, proxy_port, dpt_type)
        return NfnetlinkAttribute(type, payload)

def parse_dispatcher_params_attr(attr):
        """Decode a dispatcher parameters attribute into (dpt_flags, proxy_port, dpt_type)."""
        return struct.unpack_from('>IHB', attr.get_data(), 0)

def create_service_params_attr(type, svc_type, svc_flags):
        """Build a service parameters attribute.

        Payload (5 bytes, big-endian): svc_flags as unsigned 32-bit, then
        svc_type as one unsigned byte.
        """
        payload = struct.pack('>IB', svc_flags, svc_type)
        return NfnetlinkAttribute(type, payload)

def parse_service_params_attr(attr):
        """Decode a service parameters attribute into (svc_flags, svc_type)."""
        return struct.unpack_from('>IB', attr.get_data(), 0)

###########################################################################
# helper functions to assemble kzorp messages
###########################################################################
def create_start_msg(type, name):
        """Build the body of a transaction-start message.

        Carries the instance name (KZA_INSTANCE_NAME) and the transaction type
        (KZA_TR_PARAMS, one byte; presumably a KZ_TR_TYPE_* value). The
        netlink message type itself is not set here.
        """
        msg = NfnetlinkMessage(socket.AF_NETLINK, 0, 0)
        for attribute in (create_name_attr(KZA_INSTANCE_NAME, name),
                          create_int8_attr(KZA_TR_PARAMS, type)):
                msg.append_attribute(attribute)
        return msg

def create_commit_msg():
        """Build an attribute-less message body (presumably sent as KZNL_MSG_COMMIT)."""
        return NfnetlinkMessage(socket.AF_NETLINK, 0, 0)
        
def create_flush_msg():
        """Build an attribute-less message body (presumably sent as one of the KZNL_MSG_FLUSH_* types)."""
        return NfnetlinkMessage(socket.AF_NETLINK, 0, 0)

# service
def create_add_proxyservice_msg(name):
        """Build an add-service message for a proxy service.

        Sends service parameters of type KZ_SVC_PROXY with no flags, followed
        by the service name.
        """
        params = create_service_params_attr(KZA_SVC_PARAMS, KZ_SVC_PROXY, 0)
        svc_name = create_name_attr(KZA_SVC_NAME, name)
        msg = NfnetlinkMessage(socket.AF_NETLINK, 0, 0)
        msg.append_attribute(params)
        msg.append_attribute(svc_name)
        return msg

def create_add_pfservice_msg(name, flags, dst_ip = None, dst_port = None):
        """Build an add-service message for a packet-forwarding service.

        Sends service parameters of type KZ_SVC_FORWARD with `flags` and the
        service name. A KZA_SVC_ROUTER_DST address (protocol 0) is added only
        when both `dst_ip` and `dst_port` are truthy.
        """
        msg = NfnetlinkMessage(socket.AF_NETLINK, 0, 0)
        msg.append_attribute(create_service_params_attr(KZA_SVC_PARAMS, KZ_SVC_FORWARD, flags))
        msg.append_attribute(create_name_attr(KZA_SVC_NAME, name))
        if not (dst_ip and dst_port):
                return msg
        msg.append_attribute(create_address_attr(KZA_SVC_ROUTER_DST, 0, dst_ip, dst_port))
        return msg

def create_add_service_nat_msg(name, mapping):
        """Build a message adding a NAT entry to a service.

        `mapping` is a 3-sequence (src, dst, map); each element is a
        (flags, min_ip, max_ip, min_port, max_port) sequence. `dst` may be
        falsy, in which case no KZA_SVC_NAT_DST attribute is sent.
        """
        def nat_attr(attr_type, entry):
                return create_nat_range_attr(attr_type, entry[0], entry[1], entry[2], entry[3], entry[4])

        msg = NfnetlinkMessage(socket.AF_NETLINK, 0, 0)
        msg.append_attribute(create_name_attr(KZA_SVC_NAME, name))
        src_range, dst_range, map_range = mapping
        msg.append_attribute(nat_attr(KZA_SVC_NAT_SRC, src_range))
        if dst_range:
                msg.append_attribute(nat_attr(KZA_SVC_NAT_DST, dst_range))
        msg.append_attribute(nat_attr(KZA_SVC_NAT_MAP, map_range))
        return msg

def create_get_service_msg(name):
        """Build a get-service message.

        When `name` is truthy it is sent as KZA_SVC_NAME; otherwise the message
        has no attributes (presumably requesting all services; confirm against
        the kernel side).

        The attribute type was previously KZA_SERVICE_NAME, which is not
        defined anywhere, so any call with a name raised NameError.
        """
        m = NfnetlinkMessage(socket.AF_NETLINK, 0, 0)
        if name:
                m.append_attribute(create_name_attr(KZA_SVC_NAME, name))
        return m


# zone
def create_add_zone_msg(name, flags, address=None, mask=None, uname=None, pname=None):
        """Build an add-zone message.

        :param name: zone name, sent as KZA_ZONE_NAME
        :param flags: zone flags (e.g. KZF_ZONE_UMBRELLA), sent as a 32-bit
                      KZA_ZONE_PARAMS attribute
        :param address: optional range address (unsigned 32-bit integer)
        :param mask: optional range mask; KZA_ZONE_RANGE is sent only when both
                     `address` and `mask` are given
        :param uname: optional, sent as KZA_ZONE_UNAME when truthy
                      (presumably a unique zone name)
        :param pname: optional, sent as KZA_ZONE_PNAME when truthy
                      (presumably the parent zone name)
        """
        m = NfnetlinkMessage(socket.AF_NETLINK, 0, 0)
        m.append_attribute(create_int32_attr(KZA_ZONE_PARAMS, flags))
        m.append_attribute(create_name_attr(KZA_ZONE_NAME, name))
        if uname:
                m.append_attribute(create_name_attr(KZA_ZONE_UNAME, uname))
        if pname:
                m.append_attribute(create_name_attr(KZA_ZONE_PNAME, pname))
        # Explicit None checks (not truthiness) so a 0 address or mask is still sent.
        if address is not None and mask is not None:
                m.append_attribute(create_inet_range_attr(KZA_ZONE_RANGE, address, mask))
        return m

def create_add_zone_svc_msg(name, service):
        """Build a message associating a service with a zone.

        The zone is identified by a KZA_ZONE_UNAME attribute (not
        KZA_ZONE_NAME), followed by the service name as KZA_SVC_NAME. The
        in/out direction is presumably selected by the netlink message type
        the caller uses.
        """
        msg = NfnetlinkMessage(socket.AF_NETLINK, 0, 0)
        for attribute in (create_name_attr(KZA_ZONE_UNAME, name),
                          create_name_attr(KZA_SVC_NAME, service)):
                msg.append_attribute(attribute)
        return msg

def create_get_zone_msg(name):
        """Build a get-zone message; KZA_ZONE_NAME is added only when `name` is truthy."""
        msg = NfnetlinkMessage(socket.AF_NETLINK, 0, 0)
        if not name:
                return msg
        msg.append_attribute(create_name_attr(KZA_ZONE_NAME, name))
        return msg

# dispatcher
def create_add_dispatcher_sabind_msg(name, flags, proto, proxy_port, rule_addr, rule_ports):
        """Build an add-dispatcher message for a dispatcher bound to an address.

        Sends dispatcher parameters of type KZ_DPT_TYPE_INET, the dispatcher
        name, and a KZA_DPT_BIND_ADDR attribute built from `proto`,
        `rule_addr` and `rule_ports`.
        """
        msg = NfnetlinkMessage(socket.AF_NETLINK, 0, 0)
        attributes = (
                create_dispatcher_params_attr(KZA_DPT_PARAMS, KZ_DPT_TYPE_INET, flags, proxy_port),
                create_name_attr(KZA_DPT_NAME, name),
                create_bind_addr_attr(KZA_DPT_BIND_ADDR, proto, rule_addr, rule_ports),
        )
        for attribute in attributes:
                msg.append_attribute(attribute)
        return msg
        
def create_add_dispatcher_ifacebind_msg(name, flags, proto, proxy_port, ifname, rule_ports, pref_addr = None):
        """Build an add-dispatcher message for a dispatcher bound to an interface.

        Sends dispatcher parameters of type KZ_DPT_TYPE_IFACE, the dispatcher
        name, and a KZA_DPT_BIND_IFACE attribute. A falsy `pref_addr`
        (including the default None) is sent as 0.
        """
        preferred = pref_addr or 0
        msg = NfnetlinkMessage(socket.AF_NETLINK, 0, 0)
        msg.append_attribute(create_dispatcher_params_attr(KZA_DPT_PARAMS, KZ_DPT_TYPE_IFACE, flags, proxy_port))
        msg.append_attribute(create_name_attr(KZA_DPT_NAME, name))
        msg.append_attribute(create_bind_iface_attr(KZA_DPT_BIND_IFACE, proto, ifname, rule_ports, preferred))
        return msg

def create_add_dispatcher_ifgroupbind_msg(name, flags, proto, proxy_port, ifgroup, ifmask, rule_ports, pref_addr = None):
        """Build an add-dispatcher message for a dispatcher bound to an interface group.

        Sends dispatcher parameters of type KZ_DPT_TYPE_IFGROUP, the dispatcher
        name, and a KZA_DPT_BIND_IFGROUP attribute. A falsy `pref_addr`
        (including the default None) is sent as 0.
        """
        preferred = pref_addr or 0
        msg = NfnetlinkMessage(socket.AF_NETLINK, 0, 0)
        msg.append_attribute(create_dispatcher_params_attr(KZA_DPT_PARAMS, KZ_DPT_TYPE_IFGROUP, flags, proxy_port))
        msg.append_attribute(create_name_attr(KZA_DPT_NAME, name))
        msg.append_attribute(create_bind_ifgroup_attr(KZA_DPT_BIND_IFGROUP, proto, ifgroup, ifmask, rule_ports, preferred))
        return msg

def create_add_dispatcher_css_msg(name, service, czone = None, szone = None):
        """Build a message adding a (client zone, server zone, service) entry to a dispatcher.

        A zone argument that is falsy or the wildcard '*' is omitted from the
        message; the service name is always sent last.
        """
        msg = NfnetlinkMessage(socket.AF_NETLINK, 0, 0)
        msg.append_attribute(create_name_attr(KZA_DPT_NAME, name))
        for attr_type, zone in ((KZA_DPT_CSS_CZONE, czone), (KZA_DPT_CSS_SZONE, szone)):
                if zone and zone != '*':
                        msg.append_attribute(create_name_attr(attr_type, zone))
        msg.append_attribute(create_name_attr(KZA_DPT_CSS_SERVICE, service))
        return msg

def create_get_dispatcher_msg(name):
        """Build a get-dispatcher message.

        When `name` is truthy it is sent as KZA_DPT_NAME; otherwise the message
        has no attributes (presumably requesting all dispatchers; confirm
        against the kernel side).

        The attribute type was previously KZA_DISPATCHER_NAME, which is not
        defined anywhere, so any call with a name raised NameError.
        """
        m = NfnetlinkMessage(socket.AF_NETLINK, 0, 0)
        if name:
                m.append_attribute(create_name_attr(KZA_DPT_NAME, name))
        return m

def create_query_msg(proto, saddr, sport, daddr, dport, iface):
        """Build a query message describing a connection.

        Sends the source and destination endpoints (each with `proto`) as
        KZA_QUERY_SRC / KZA_QUERY_DST address attributes, then the interface
        name as KZA_QUERY_IFACE.
        """
        msg = NfnetlinkMessage(socket.AF_NETLINK, 0, 0)
        endpoints = ((KZA_QUERY_SRC, saddr, sport), (KZA_QUERY_DST, daddr, dport))
        for attr_type, addr, port in endpoints:
                msg.append_attribute(create_address_attr(attr_type, proto, addr, port))
        msg.append_attribute(create_name_attr(KZA_QUERY_IFACE, iface))
        return msg
