#!/usr/bin/python
#
# Copyright (c) 2011 Citrix Systems, Inc.
#
# This program is free software; you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation; version 2.1 only. with the special
# exception on linking described in file LICENSE.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU Lesser General Public License for more details.
#

import fcntl
import os
import os.path
import string
import subprocess
import sys
import syslog
import time
import common
from common import send_to_syslog, doexec

# Colon-separated list of the standard system binary directories.
# NOTE(review): not referenced in the visible part of this file - confirm
# whether it is still needed.
path = "/sbin:/usr/sbin:/bin:/usr/bin"

# The name this script was invoked as, and its basename (shown by usage()).
command_name = sys.argv[0]
short_command_name = os.path.basename(command_name)

# Executable used to manage the rules in bridge mode, and the lock file that
# serialises calls to it (see handle_bridge).
ebtables = "ebtables"
ebtables_lock_path = "/var/lock/ebtables.lock"

# Open vSwitch executables used in openvswitch mode.
vsctl = "ovs-vsctl"
ofctl = "ovs-ofctl"

# Executable used by ip_link_set to bring interfaces up and down.
ip = "/sbin/ip"

def get_host_network_mode(domuuid, devid):
    """Return the mode reported by the VIF identified by (domuuid, devid).

    The VIF is looked up through common.Interface; the entry point passes
    the returned value to main() as the network mode.
    """
    vif = common.Interface(domuuid, devid).get_vif()
    return vif.get_mode()

def get_locking_config(domuuid, devid):
    """Return the locking configuration of the VIF identified by (domuuid, devid).

    Callers in this file index the result with the keys "locking_mode",
    "mac", "ipv4_allowed" and "ipv6_allowed".
    """
    vif = common.Interface(domuuid, devid).get_vif()
    return vif.get_locking_mode()

def ip_link_set(device, direction):
    """Run "ip link set <device> <direction>".

    The callers in this file pass "up" or "down" as the direction.
    """
    command = [ip, "link", "set", device, direction]
    doexec(command)

###############################################################################
# Creation of ebtables rules in the case of bridge.
###############################################################################

def get_chain_name(vif_name):
    """Return the name of the ebtables chain dedicated to vif_name."""
    chain_name = "FORWARD_%s" % vif_name
    return chain_name

def do_chain_action(executable, action, chain_name, args=None):
    """Run executable with an action flag, a chain name and extra arguments.

    executable: the program to run (ebtables at every call site in this file).
    action:     the action flag, e.g. "-A", "-D", "-L".
    chain_name: the chain the action applies to.
    args:       optional sequence of further command-line arguments; None
                (the default) means no further arguments.

    Returns whatever doexec returns; chain_exists unpacks it as a 3-tuple
    whose first element is compared with 0.
    """
    # A None sentinel avoids sharing one mutable default list between calls;
    # list() additionally lets callers pass a tuple.
    extra_args = [] if args is None else list(args)
    return doexec([executable, action, chain_name] + extra_args)

def chain_exists(executable, chain_name):
    """Return True if listing chain_name ("-L") reports a return code of 0."""
    exit_code, _out, _err = do_chain_action(executable, "-L", chain_name)
    return exit_code == 0

def clear_bridge_rules(vif_name):
    """Remove the ebtables chain belonging to vif_name, if it exists.

    Does nothing when the VIF's chain is not present.
    """
    chain = get_chain_name(vif_name)
    if not chain_exists(ebtables, chain):
        return
    # Stop forwarding to this VIF's chain: incoming direction first, then
    # outgoing.
    for direction_flag in ("-i", "-o"):
        do_chain_action(ebtables, "-D", "FORWARD",
                        [direction_flag, vif_name, "-j", chain])
    # Flush the VIF's chain, then delete it.
    for chain_flag in ("-F", "-X"):
        do_chain_action(ebtables, chain_flag, chain)

def create_bridge_rules(vif_name, config):
    """Create the ebtables chain and filtering rules for a locked VIF.

    vif_name: name of the network interface connecting to the guest.
    config:   mapping providing "mac", "ipv4_allowed" and "ipv6_allowed".

    Only IPv4 multitenancy is fully supported with bridge: all IPv6 traffic
    is allowed as soon as any IPv6 address is associated with the VIF.
    """
    chain = get_chain_name(vif_name)

    def append_rule(rule):
        # Append a single rule to this VIF's own chain.
        do_chain_action(ebtables, "-A", chain, rule)

    # Forward all traffic on this VIF to a new chain, with default policy DROP.
    do_chain_action(ebtables, "-N", chain)
    for direction_flag in ("-i", "-o"):
        do_chain_action(ebtables, "-A", "FORWARD",
                        [direction_flag, vif_name, "-j", chain])
    do_chain_action(ebtables, "-P", chain, ["DROP"])

    # Everything below whitelists the valid traffic.
    guest_mac = config["mac"]
    allowed_v4 = config["ipv4_allowed"]
    allowed_v6 = config["ipv6_allowed"]

    # Accept all traffic going to the VM.
    append_rule(["-o", vif_name, "-j", "ACCEPT"])
    # Drop everything not coming from the correct MAC.
    append_rule(["-s", "!", guest_mac, "-i", vif_name, "-j", "DROP"])
    # Accept DHCP.
    append_rule(["-p", "IPv4", "-i", vif_name,
                 "--ip-protocol", "UDP", "--ip-dport", "67",
                 "-j", "ACCEPT"])

    for address in allowed_v4:
        # Accept ARP travelling from known IP addresses, also filtering ARP
        # replies by MAC.
        append_rule(["-p", "ARP", "-i", vif_name,
                     "--arp-opcode", "Request",
                     "--arp-ip-src", address,
                     "-j", "ACCEPT"])
        append_rule(["-p", "ARP", "-i", vif_name,
                     "--arp-opcode", "Reply",
                     "--arp-ip-src", address,
                     "--arp-mac-src", guest_mac,
                     "-j", "ACCEPT"])
        # Accept IP travelling from known IP addresses.
        append_rule(["-p", "IPv4", "-i", vif_name,
                     "--ip-src", address,
                     "-j", "ACCEPT"])

    if allowed_v6 != []:
        # Accept all IPv6 traffic.
        append_rule(["-p", "IPv6", "-i", vif_name, "-j", "ACCEPT"])

def acquire_lock(path):
    """Block until an exclusive lock on the file at path is held.

    The file is opened for writing (and created if necessary), then a
    non-blocking exclusive lockf is attempted once per second until it
    succeeds. Every attempt, failure and the final success are reported
    through send_to_syslog.

    Returns the open file object; the lock is held for as long as that
    file stays open.
    """
    lock_file = open(path, 'w')
    while True:
        send_to_syslog("attempting to acquire lock %s" % path)
        try:
            # Only the lock attempt is guarded, so that an IOError raised by
            # logging after the lock is held is not mistaken for contention.
            fcntl.lockf(lock_file, fcntl.LOCK_EX | fcntl.LOCK_NB)
        except IOError as e:
            send_to_syslog("failed to acquire lock %s because '%s' - waiting 1 second" % (path, e))
            time.sleep(1)
        else:
            send_to_syslog("acquired lock %s" % path)
            return lock_file

def handle_bridge(vif_name, domuuid, devid, action):
    """Apply a "clear" or "filter" action to a VIF on a Linux bridge.

    vif_name: name of the network interface connecting to the guest.
    domuuid:  UUID of the guest.
    devid:    device id of the VIF.
    action:   "clear" removes the VIF's rules and brings the interface up;
              "filter" removes them and then, depending on the locking mode,
              installs new rules ("locked") and brings the interface up
              ("locked" or "unlocked"). With any other locking mode the
              interface is left down. Any other action is ignored.
    """
    if action not in ("clear", "filter"):
        return
    # ebtables fails if called concurrently, so acquire a lock before starting
    # to call it.
    ebtables_lock_file = acquire_lock(ebtables_lock_path)
    try:
        ip_link_set(vif_name, "down")
        clear_bridge_rules(vif_name)
        if action == "filter":
            config = get_locking_config(domuuid, devid)
            locking_mode = config["locking_mode"]
            if locking_mode == "locked":
                create_bridge_rules(vif_name, config)
            if locking_mode in ["locked", "unlocked"]:
                ip_link_set(vif_name, "up")
        else:
            # action == "clear"
            ip_link_set(vif_name, "up")
    finally:
        # Release the lock deterministically (closing the file drops it)
        # instead of relying on garbage collection of the file object.
        ebtables_lock_file.close()

###############################################################################
# Creation of openflow rules in the case of openvswitch.
###############################################################################

def get_vswitch_port(vif_name):
    """Return the "ofport" value ovs-vsctl reports for interface vif_name.

    The value is the first line of the command's output, stripped of
    surrounding whitespace.
    """
    query = [vsctl, "get", "interface", vif_name, "ofport"]
    _rc, output, _errors = doexec(query)
    first_line = output.readline()
    return first_line.strip()

def clear_vswitch_rules(bridge_name, port):
    """Delete every flow on bridge_name that matches the given in_port."""
    match = "in_port=%s" % port
    doexec([ofctl, "del-flows", bridge_name, match])

def add_flow(bridge_name, args):
    """Add one flow, described by the string args, to bridge_name via ovs-ofctl."""
    command = [ofctl, "add-flow", bridge_name, args]
    doexec(command)

def create_vswitch_rules(bridge_name, port, config):
    """Install the openflow rules that lock a VIF to its allowed addresses.

    bridge_name: bridge on which the flows are added.
    port:        value used for in_port in every flow.
    config:      mapping providing "mac", "ipv4_allowed" and "ipv6_allowed".

    All flows match on in_port=<port>. Higher-priority flows allow specific
    traffic; the final, lowest-priority flow drops everything else.
    """
    guest_mac = config["mac"]
    allowed_v4 = config["ipv4_allowed"]
    allowed_v6 = config["ipv6_allowed"]

    def install(match_and_action):
        # Every flow is scoped to traffic entering through this VIF's port.
        add_flow(bridge_name, "in_port=%s,%s" % (port, match_and_action))

    # ARP flow shared by the 0.0.0.0 rule and the per-address rules.
    arp_flow = ("priority=7000,dl_type=0x0806,dl_src=%s,arp_sha=%s,"
                "nw_src=%s,idle_timeout=0,action=normal")

    # Allow DHCP traffic (outgoing UDP on port 67).
    install("priority=8000,dl_type=0x0800,nw_proto=0x11,"
            "tp_dst=67,dl_src=%s,idle_timeout=0,action=normal" % guest_mac)
    # Filter ARP requests.
    install(arp_flow % (guest_mac, guest_mac, "0.0.0.0"))

    for address in allowed_v4:
        # Filter ARP responses.
        install(arp_flow % (guest_mac, guest_mac, address))
        # Allow traffic from specified ipv4 addresses.
        install("priority=6000,dl_type=0x0800,nw_src=%s,"
                "dl_src=%s,idle_timeout=0,action=normal" % (address, guest_mac))

    for address in allowed_v6:
        # Neighbour solicitation.
        install("priority=8000,dl_src=%s,icmp6,ipv6_src=%s,"
                "icmp_type=135,nd_sll=%s,idle_timeout=0,action=normal"
                % (guest_mac, address, guest_mac))
        # Neighbour advertisement.
        install("priority=8000,dl_src=%s,icmp6,ipv6_src=%s,"
                "icmp_type=136,nd_target=%s,idle_timeout=0,action=normal"
                % (guest_mac, address, address))
        # Allow traffic from specified ipv6 addresses.
        for protocol in ("icmp6", "tcp6", "udp6"):
            install("priority=5000,dl_src=%s,ipv6_src=%s,%s,action=normal"
                    % (guest_mac, address, protocol))

    # Drop all other neighbour discovery (solicitation, advertisement).
    for icmp_type in ("135", "136"):
        install("priority=7000,icmp6,icmp_type=%s,action=drop" % icmp_type)

    # Drop other specific ICMPv6 types.
    dropped_icmp6_types = (
        "134",  # Router advertisement.
        "137",  # Redirect gateway.
        "146",  # Mobile prefix solicitation.
        "147",  # Mobile prefix advertisement.
        "151",  # Multicast router advertisement.
        "152",  # Multicast router solicitation.
        "153",  # Multicast router termination.
    )
    for icmp_type in dropped_icmp6_types:
        install("priority=6000,icmp6,icmp_type=%s,action=drop" % icmp_type)

    # Drop everything else.
    install("priority=4000,idle_timeout=0,action=drop")

def get_bridge_name_vswitch(vif_name):
    """Return the name of the real bridge that vif_name belongs to."""
    _rc, output, _errors = doexec([vsctl, "iface-to-br", vif_name])
    reported_bridge = output.readline().strip()
    # Get the bridge's parent, in case we were given a fake bridge device.
    # This returns the same name if it is already a real bridge.
    _rc, output, _errors = doexec([vsctl, "br-to-parent", reported_bridge])
    return output.readline().strip()

def handle_vswitch(vif_name, domuuid, devid, action):
    """Apply a "clear" or "filter" action to a VIF on an Open vSwitch bridge.

    "clear" deletes the VIF's flows and brings the interface up. "filter"
    deletes them and then, depending on the locking mode, installs new flows
    ("locked") and brings the interface up ("locked" or "unlocked"); with any
    other locking mode the interface is left down. Any other action is
    ignored.
    """
    if action != "clear" and action != "filter":
        return

    bridge = get_bridge_name_vswitch(vif_name)
    ip_link_set(vif_name, "down")
    ofport = get_vswitch_port(vif_name)
    clear_vswitch_rules(bridge, ofport)

    if action == "filter":
        locking_config = get_locking_config(domuuid, devid)
        mode = locking_config["locking_mode"]
        if mode == "locked":
            create_vswitch_rules(bridge, ofport, locking_config)
        if mode in ["locked", "unlocked"]:
            ip_link_set(vif_name, "up")
    else:
        # action == "clear"
        ip_link_set(vif_name, "up")

###############################################################################
# Executable entry point.
###############################################################################

def main(vif_name, domuuid, devid, network_mode, action):
    """Dispatch the action to the handler for the host's network mode.

    Network modes other than "bridge" and "openvswitch" are ignored.
    """
    handlers = {
        "bridge": handle_bridge,
        "openvswitch": handle_vswitch,
    }
    handler = handlers.get(network_mode)
    if handler is not None:
        handler(vif_name, domuuid, devid, action)

def usage():
    """Write the command-line usage text to standard output."""
    usage_lines = [
        "Usage:",
        "%s xenopsd_backend interface uuid devid action" % short_command_name,
        "",
        "interface:",
        "    The name of the network interface connecting to the guest.",
        "uuid:",
        "    The UUID of the guest.",
        "devid:",
        "    The number of the network to which the interface is connected.",
        "action: [clear|filter]",
        "    Specifies whether to create filtering rules for the interface, or to clear all rules associated with the interface.",
    ]
    # One write of the newline-terminated lines produces the same output as
    # printing each line in turn.
    sys.stdout.write("\n".join(usage_lines) + "\n")

if __name__ == "__main__":
    # Exactly five arguments are required after the script name.
    if len(sys.argv) != 6:
        usage()
        sys.exit(1)

    xenopsd_backend, vif_name, domuuid, devid, action = sys.argv[1:6]
    common.set_xenopsd_backend(xenopsd_backend)
    # The network mode is looked up from the VIF rather than passed in.
    network_mode = get_host_network_mode(domuuid, devid)

    send_to_syslog(
        "Called with interface=%s, uuid=%s, devid=%s, network_mode=%s, action=%s"
        % (vif_name, domuuid, devid, network_mode, action))
    main(vif_name, domuuid, devid, network_mode, action)
