#!/usr/bin/python
# Copyright 1999-2017. Plesk International GmbH. All rights reserved.
import os
import sys
import re
import socket
import string
import tempfile
import binascii
import ConfigParser
import logging


# Module-level state. Most of these are placeholders that init() fills in.

# MySQL user name and database name used when init() builds mysql_cmd.
db_login = 'admin'
# MySQL password; init() reads it from /etc/psa/.psa.shadow.
db_pass = None
db_name = 'psa'
# Settings parsed from /etc/psa/psa.conf by init() (name -> value).
psa_conf = {}
# NOTE(review): never assigned at module scope in the visible code; main()
# binds a *local* variable of the same name.
map_file = None
# Full path of the "ifmng" utility; set by init().
ifmng_bin = None
# Shell command prefix that invokes the mysql client with credentials and
# database selected; set by init().
mysql_cmd = None

# Comment header written at the top of a freshly generated map file by
# createMapFile(). Every line starts with "#", which the map file parser
# treats as a comment.
mapfile_header = """# You should edit IP addresses, netmasks and interfaces to reflect your
# future settings. If you don't want the IP to be changed - leave it untouched,
# comment out it's line or remove entire line from the file.

"""

# One psa.conf setting: a word (group 1, the name) followed by whitespace and
# a run of non-whitespace characters (group 2, the value). Used by init().
cfg_line_pattern = re.compile(r"(\w+)\s+(\S+)\s*")


def usage():
    print """Plesk reconfigurator - utility to change IP addresses used by Plesk

Usage: %s { <map_file> | --autoconfigure | --remap-ips | --help }

If <map_file> doesn't exists - template will be created, otherwise it will be used to map IP addresses.

--autoconfigure option will attempt to create and process IP mapping automatically. Any new excessive
or old unmapped IP addresses will retain their status and would need to be handled manually either by
rereading IP addresses or by passing a correct map file to this utility.

--remap-ips is an alias for --autoconfigure option.

--help option displays this help page.
""" % sys.argv[0]


def err(msg):
    """Print msg, followed by a newline, to standard error."""
    print >>sys.stderr, msg


def report(*args):
    msg = ""
    for i in args:
        msg += str(i)
        msg += " "
    print msg


def init():
    """Load credentials and settings and prepare the helper command lines.

    Reads the database password from /etc/psa/.psa.shadow, parses
    /etc/psa/psa.conf into the global psa_conf dictionary (lines that do not
    match cfg_line_pattern are skipped), and sets the globals ifmng_bin and
    mysql_cmd.

    Raises:
        IOError: if either file cannot be opened.
        KeyError: if psa.conf defines no PRODUCT_ROOT_D or MYSQL_BIN_D.
    """
    global db_pass, psa_conf, ifmng_bin, mysql_cmd

    def shell_quote(value):
        # POSIX single quoting: nothing is special inside '...' except the
        # quote itself, which is written as '\'' (close, escaped quote,
        # reopen). Without this a password containing ' would break the
        # command line or be interpreted by the shell.
        return "'" + value.replace("'", "'\\''") + "'"

    with open('/etc/psa/.psa.shadow') as fd:
        db_pass = fd.read().strip()

    with open('/etc/psa/psa.conf') as fd:
        for ln in fd:
            pm = cfg_line_pattern.match(ln)
            if pm:
                psa_conf[pm.group(1)] = pm.group(2)

    ifmng_bin = os.path.join(psa_conf["PRODUCT_ROOT_D"], 'admin/bin/ifmng')

    mysql_bin = os.path.join(psa_conf["MYSQL_BIN_D"], 'mysql')
    # The trailing space is kept so callers can append further arguments.
    # NOTE(review): the password is still passed on the command line and is
    # therefore visible in the process list and in execCmd() error messages.
    mysql_cmd = "%s %s %s %s " % (
        mysql_bin,
        shell_quote('-u' + db_login),
        shell_quote('-p' + db_pass),
        shell_quote('-D' + db_name),
    )


def readCmd(cmd):
    """Run cmd through os.popen() and return its standard output as a list of lines.

    Lines keep their trailing newline. The exit status of the command is
    not examined.
    """
    pipe = os.popen(cmd)
    try:
        output_lines = pipe.readlines()
    finally:
        pipe.close()
    return output_lines


def execCmd(*args):
    """Join args with spaces, run the result via os.system() and fail loudly.

    Returns None when the command exits with status 0.

    Raises:
        Exception: if the command could not be run, exited with a non-zero
            code, or was terminated by a signal. The message contains the
            command line and the decoded exit code or signal number.
    """
    cmd = ' '.join(args)
    status = os.system(cmd)
    if status == 0:
        return

    if status < 0:
        raise Exception("%s: failed to execute (status %s)" % (cmd, status))

    # os.system() returns a wait()-style status on Unix: the exit code and
    # the terminating signal are packed into one integer and must be decoded.
    if os.WIFSIGNALED(status):
        raise Exception("%s: killed by signal %s" % (cmd, os.WTERMSIG(status)))
    if os.WIFEXITED(status):
        raise Exception("%s: exited with non-zero code %s" % (cmd, os.WEXITSTATUS(status)))
    raise Exception("%s: terminated abnormally (wait status %s)" % (cmd, status))


def ifmng(*args):
    """Run the utility at ifmng_bin with the given arguments through execCmd()."""
    command = (ifmng_bin,) + args
    return execCmd(*command)


def execUtil(util, *args):
    """Run <PRODUCT_ROOT_D>/admin/bin/<util> with the given arguments through execCmd()."""
    util_path = os.path.join(psa_conf["PRODUCT_ROOT_D"], 'admin/bin', util)
    return execCmd(util_path, *args)


def readDBAddresses():
    """Return the IP addresses the database records for the local service node.

    Each result is an IPAddr built from the row's address, mask and interface,
    with the row id attached as the ``ident`` attribute.
    """
    query = """SELECT ip.id, ip.iface, ip.ip_address, ip.mask
                FROM IP_Addresses ip
                INNER JOIN ServiceNodes sn
                    ON ip.serviceNodeId = sn.id
                WHERE sn.ipAddress = 'local'"""
    addresses = []
    for row in db_query(query):
        ident, iface, ip, mask = row
        entry = IPAddr(ip, mask, iface)
        entry.ident = ident
        addresses.append(entry)
    return addresses


def createMapFile(fn):
    """Write a map file template to fn.

    The file starts with mapfile_header, followed by one "<addr> -> <addr>"
    line per address returned by readDBAddresses(), i.e. every address is
    initially mapped to itself.
    """
    with open(fn, "w") as out:
        out.write(mapfile_header)
        out.writelines("%s -> %s\n" % (ip, ip) for ip in readDBAddresses())


class IPAddr:
    """An IPv4 or IPv6 address with an optional netmask and interface name.

    Equality and hashing are based only on the packed binary address, so two
    instances with the same address but a different netmask or interface
    compare equal and share a dictionary slot.

    Attributes:
        addr: canonical text form as produced by socket.inet_ntop().
        binaddr: packed form as produced by socket.inet_pton().
        af: socket.AF_INET or socket.AF_INET6.
        netmask, iface: stored exactly as passed to the constructor.

    Other code in this file additionally attaches an ``ident`` attribute
    (a database row id) to instances after construction.
    """

    def __init__(self, addr, netmask=None, iface=None):
        """Parse addr as IPv4, falling back to IPv6.

        Raises:
            Exception: if addr is neither a valid IPv4 nor IPv6 literal.
        """
        self.netmask = netmask
        self.iface = iface
        last_error = None
        for family in (socket.AF_INET, socket.AF_INET6):
            try:
                self.binaddr = socket.inet_pton(family, addr)
            except Exception as e:
                # Remember the failure; the IPv6 one ends up in the message.
                last_error = e
            else:
                self.af = family
                break
        else:
            raise Exception("%s: is neither ipv6 nor ipv4 address (%s)" % (addr, last_error))
        self.addr = socket.inet_ntop(self.af, self.binaddr)

    # No __ne__ is defined on purpose: under Python 2 "a != b" on this class
    # does not consult __eq__, and existing callers may depend on that.
    def __eq__(self, rhs):
        if not isinstance(rhs, IPAddr):
            # Let Python fall back to its default handling instead of
            # failing with AttributeError on foreign operands.
            return NotImplemented
        return self.binaddr == rhs.binaddr

    def __hash__(self):
        return hash(self.binaddr)

    def __str__(self):
        # "<iface> <addr> <netmask>" - the format used in map files.
        return "%s %s %s" % (self.iface, self.addr, self.netmask)

    def __repr__(self):
        return "<%s>" % self.addr

    def fullAddr(self):
        """Return the address with IPv6 zero groups fully expanded.

        IPv6 addresses are rendered as eight 4-digit hexadecimal groups
        separated by colons; IPv4 addresses are returned unchanged.
        """
        if self.af == socket.AF_INET6:
            hexstr = binascii.hexlify(self.binaddr)
            return ':'.join(hexstr[i:i+4] for i in range(0, len(hexstr), 4))
        return self.addr

    def isIPv4(self):
        return self.af == socket.AF_INET

    def isIPv6(self):
        return self.af == socket.AF_INET6


def readSystemAddresses():
    """Return the addresses listed by "<ifmng_bin> -l" as IPAddr objects.

    Every output line must consist of exactly four whitespace-separated
    fields: address, mask, interface and a fourth field that is not used.
    A line with a different number of fields raises ValueError.
    """
    addresses = []
    for line in readCmd("%s -l" % ifmng_bin):
        fields = line.split()
        addr, mask, iface, _unused = fields
        addresses.append(IPAddr(addr, mask, iface))
    return addresses


def makeNotBlacklistedIpAddrPred():
    """Build a predicate that accepts IP addresses not blacklisted in panel.ini.

    The blacklists are read from <PRODUCT_ROOT_D>/admin/conf/panel.ini:
    option "blacklist" in section [ip] (addresses) and option "blacklist" in
    section [networkInterfaces] (interface names, default: ['docker']).
    Values are comma-separated, lower-cased and stripped of whitespace.

    The returned function takes an IPAddr-like object and returns False when
    its ``addr`` is blacklisted, or when its interface name (lower-cased,
    with trailing digits removed) is blacklisted; otherwise True.
    """
    panel_ini = os.path.join(psa_conf["PRODUCT_ROOT_D"], 'admin/conf/panel.ini')
    config = ConfigParser.RawConfigParser()
    config.read(panel_ini)

    def get_list(config, section, option, default=None):
        # A missing section/option falls back to the default (or an empty list).
        try:
            raw = config.get(section, option)
        except ConfigParser.Error:
            return default or []
        return [item.strip() for item in raw.lower().split(',')]

    blacklisted_ips = get_list(config, 'ip', 'blacklist')
    blacklisted_ifaces = get_list(config, 'networkInterfaces', 'blacklist', default=['docker'])

    def predicate(ip):
        # The interface is only examined when the address itself is allowed.
        return not (
            ip.addr in blacklisted_ips
            or ip.iface.lower().rstrip(string.digits) in blacklisted_ifaces
        )

    return predicate


# A line that carries no mapping: only whitespace and/or a "#" comment.
# The trailing "$" matters: these patterns are applied with .match(), and
# without the anchor every component is optional, so the pattern would
# succeed on *any* line and unparseable lines would never be reported.
empty_pattern = re.compile(r"\s*(#.*)?$")
# "<from> -> <to>" with an optional trailing "#" comment; groups 1 and 2 are
# the raw (unstripped) left and right sides.
map_pattern = re.compile(r"([^#]+)->([^#]+)(#.*)?")

# One side of a mapping: interface name (optionally followed by ":"),
# address, and a mask made of digits and dots.
addr_pattern = re.compile(r"\s*(\w+):?\s*(\S+)\s+([0-9.]+)")


def parseAddr(addr):
    """Split one side of a map file line into (iface, address, mask).

    Raises:
        Exception: if addr does not match addr_pattern.
    """
    am = addr_pattern.match(addr)
    if am is None:
        # Format the message with "%"; passing addr as a second constructor
        # argument would leave the "%s" placeholder unfilled.
        raise Exception("%s: cannot parse" % addr)
    return am.group(1), am.group(2), am.group(3)


def readMapping(fn):
    """Parse map file fn into a dictionary {old IPAddr: new IPAddr}.

    Lines matching map_pattern are parsed as "<iface> <addr> <mask> ->
    <iface> <addr> <mask>"; lines matching empty_pattern are skipped.
    Every problem is reported through err() and counted:
      * a line that cannot be parsed,
      * a source address that is mapped more than once,
      * a target address that is used more than once,
      * a mapping between different address families (IPv4 <-> IPv6).

    Raises:
        Exception: after the whole file has been read, if any problem was found.
        IOError: if fn cannot be opened.
    """
    errors = 0
    rv = {}
    with open(fn) as fd:
        for ln in fd:
            pmatch = map_pattern.match(ln)
            if not pmatch:
                if not empty_pattern.match(ln):
                    err("%s: cannot parse" % ln.strip())
                    errors += 1
                continue

            try:
                iface_f, addr_f, mask_f = parseAddr(pmatch.group(1))
                iface_t, addr_t, mask_t = parseAddr(pmatch.group(2))
                ipf = IPAddr(addr_f, mask_f, iface_f)
                ipt = IPAddr(addr_t, mask_t, iface_t)
            except Exception:
                # Only parse failures are expected here; KeyboardInterrupt
                # and SystemExit are deliberately not swallowed.
                err("%s: cannot parse" % ln.strip())
                errors += 1
                continue

            if ipf in rv:
                err("%s: is already mapped to %s" % (ipf.addr, rv[ipf].addr))
                errors += 1
            if ipt in rv.values():
                err("%s: is already used as target address" % (ipt.addr))
                errors += 1

            rv[ipf] = ipt

    for addr in rv:
        if addr.af != rv[addr].af:
            err("%s=>%s: it is not allowed to change IPv6 address to IPv4 and vice versa" % (addr.addr, rv[addr].addr))
            errors += 1
    if errors:
        raise Exception("%s: cannot parse, %s errors found." % (fn, errors))
    return rv


def db_query(stmt):
    """Run SQL statement stmt through the mysql client and return the rows.

    The statement is passed inside double quotes to "-e" on the mysql_cmd
    command line, with the -s and -N options. Each output line becomes one
    list of whitespace-separated fields.
    """
    shell_line = """%s -s -N -e "%s" """ % (mysql_cmd, stmt)
    rows = []
    for line in readCmd(shell_line):
        rows.append(line.split())
    return rows


def generateUpdateSQL(mapping):
    """Write the SQL script that applies mapping to the database.

    The script is wrapped in BEGIN/COMMIT and updates:
      * IP_Addresses rows, selected by the ``ident`` of each old address;
      * A/AAAA records in dns_recs whose value is a remapped address;
      * PTR records in dns_recs whose host is a remapped address;
      * records with dns_zone_id IS NULL whose host is a remapped address.

    Args:
        mapping: dictionary {old IPAddr: new IPAddr}; every key must carry
            an ``ident`` attribute.

    Returns:
        (path, domains): path of the temporary ".sql" file (the caller is
        responsible for deleting it) and the set of DNS zone names whose
        records were touched.

    On any failure the temporary file is closed and removed before the
    exception propagates.
    """
    fd, fn = tempfile.mkstemp('.sql')
    domains = set()
    host_update = "UPDATE dns_recs SET host='%s', displayHost='%s' WHERE id = '%s';\n"

    try:
        with os.fdopen(fd, "w") as sql:
            sql.write("BEGIN;\n")

            for addr in mapping:
                naddr = mapping[addr]
                sql.write("UPDATE IP_Addresses SET ip_address='%s', iface='%s', mask='%s' WHERE id='%s';\n" % (naddr.addr, naddr.iface, naddr.netmask, addr.ident))

            query_tpl = r"""SELECT rec.id, z.name, rec.type, rec.val FROM dns_recs rec JOIN dns_zone z ON (z.id = rec.dns_zone_id) WHERE rec.type IN (%s);"""

            for rec_id, domain_name, rec_type, rec_val in db_query(query_tpl % ("'A', 'AAAA'")):
                rec_ip = IPAddr(rec_val)
                if rec_ip in mapping:
                    new_addr = mapping[rec_ip]
                    # Keep the record type in line with the new address family.
                    if rec_type == 'A' and new_addr.af == socket.AF_INET6:
                        rec_type = 'AAAA'
                    elif rec_type == 'AAAA' and new_addr.af == socket.AF_INET:
                        rec_type = 'A'

                    domains.add(domain_name)  # this includes domains and domain aliases
                    sql.write("UPDATE dns_recs SET type='%s', val='%s', displayVal='%s' WHERE id = '%s';\n" % (rec_type, new_addr.addr, new_addr.addr, rec_id))

            query_ptr = r"""SELECT rec.id, z.name, rec.type, rec.host FROM dns_recs rec JOIN dns_zone z ON (z.id = rec.dns_zone_id) WHERE rec.type IN ('PTR');"""

            for rec_id, domain_name, rec_type, rec_host in db_query(query_ptr):
                rec_ip = IPAddr(rec_host)
                if rec_ip in mapping:
                    new_addr = mapping[rec_ip]
                    # PTR hosts are stored in fully expanded form.
                    sql.write(host_update % (new_addr.fullAddr(), new_addr.addr, rec_id))

                    domains.add(domain_name)

            for rec_id, rec_type, rec_host in db_query(r"SELECT rec.id, rec.type, rec.host FROM dns_recs rec WHERE dns_zone_id IS NULL"):
                # NOTE(review): IPAddr() raises when rec_host is not an IP
                # literal - confirm every such row stores an address in host.
                rec_ip = IPAddr(rec_host)
                if rec_ip in mapping:
                    new_addr = mapping[rec_ip]
                    if rec_type == 'PTR':
                        sql.write(host_update % (new_addr.fullAddr(), new_addr.addr, rec_id))
                    elif rec_type != 'none':
                        sql.write(host_update % (new_addr.addr, new_addr.addr, rec_id))

            sql.write("COMMIT;\n")
    except BaseException:
        # Do not leave a half-written script behind; nobody else knows its path.
        os.unlink(fn)
        raise

    return fn, domains


def reconfigure(mapping):
    """Apply an IP mapping to the database and reconfigure dependent services.

    Steps, in order: generate and run the SQL update script (the script file
    is deleted after a successful run), update every affected DNS zone and
    restart DNS, reconfigure Apache and Proftpd, run the Watchdog and Postfix
    helpers when their executables exist, and refresh the trusted IP list.
    Any failing external command raises an exception from execCmd().
    """
    report("Generating DB update script... ")
    sql_path, dns_domains = generateUpdateSQL(mapping)
    report("ok")

    report("Updating database... ")
    execCmd(mysql_cmd, " < ", sql_path)
    os.unlink(sql_path)
    report("ok")

    if dns_domains:
        report("Reconfiguring DNS:")
        for zone in dns_domains:
            report("domain %s..." % zone)
            execUtil('dnsmng', '--update', zone)
            report("ok")
        report("Restarting DNS service...")
        execUtil('dnsmng', '--restart')
        report("ok")

    # The list of affected domains comes from DNS records, so it cannot be
    # relied upon when DNS is not installed: web and FTP are always
    # reconfigured for all domains.
    always_reconfigured = (
        ("Reconfiguring Apache...", 'httpdmng'),
        ("Reconfiguring Proftpd...", 'ftpmng'),
    )
    for title, util in always_reconfigured:
        report(title)
        execUtil(util, '--reconfigure-all')
        report("ok")

    # Optional helpers: each one runs only if its executable is present.
    optional_helpers = (
        (os.path.join(psa_conf["PRODUCT_ROOT_D"], 'admin/bin/modules/watchdog/wd'),
         "Reconfiguring Watchdog module...", ("--adapt",)),
        ('/usr/lib/plesk-9.0/mail_postfix_transport_restore',
         "Rebuilding Postfix transport map...", ()),
    )
    for helper, title, helper_args in optional_helpers:
        if os.path.exists(helper):
            report(title)
            execCmd(helper, *helper_args)
            report("ok")

    report("Refresh trusted IPs for site preview")
    script = os.path.join(
        psa_conf["PRODUCT_ROOT_D"],
        'admin/plib/scripts/refresh-trusted-ips.php')
    execUtil('php', script)
    report("ok")


def fileMapping(mapfile):
    """Read a map file, validate it and prepare the system for the new addresses.

    Validation: every source address must be known to the database, and no
    target may be an address the database already uses unless that address
    is itself being remapped. Lines whose two sides are identical (same
    address, interface and netmask) are dropped. For every remaining target,
    a system address with the same IP but a different interface or netmask
    is deleted, and the target is added via ifmng unless it already exists.

    Returns:
        Dictionary {old IPAddr: new IPAddr} of real changes; keys carry the
        database row id in ``ident``.

    Raises:
        Exception: if the map file cannot be parsed or conflicts are found.
    """
    is_allowed = makeNotBlacklistedIpAddrPred()
    system_addresses = [ip for ip in readSystemAddresses() if is_allowed(ip)]
    mapping = readMapping(mapfile)
    db_addresses = readDBAddresses()

    errors = 0
    for addr in mapping:
        if addr in db_addresses:
            # Carry over the DB id so the UPDATE statement can find the row.
            addr.ident = db_addresses[db_addresses.index(addr)].ident
        else:
            err("%s: address is not used by Plesk" % addr.addr)
            errors += 1
    targets = list(mapping.values())
    for addr in db_addresses:
        if addr not in mapping and addr in targets:
            err("%s: address is already used by Plesk" % addr.addr)
            errors += 1
    if errors:
        raise Exception("%s: %d conflicts found" % (mapfile, errors))

    # Keep only real changes. The comparison is spelled out field by field:
    # IPAddr equality looks at the address alone, and "!=" on IPAddr is an
    # identity test under Python 2, so "mapping[addr] != addr" would keep
    # every line, including untouched ones.
    changed = {}
    for old, new in mapping.items():
        untouched = (new == old
                     and new.iface == old.iface
                     and new.netmask == old.netmask)
        if not untouched:
            changed[old] = new
    mapping = changed

    for addr in mapping.values():
        exists = False
        for eaddr in system_addresses:
            if eaddr == addr:
                if eaddr.iface != addr.iface or eaddr.netmask != addr.netmask:
                    # Same IP on another interface/netmask: remove it so it
                    # can be re-added with the requested settings below.
                    ifmng("--del", eaddr.addr)
                else:
                    exists = True
        if not exists:
            report("Adding %s..." % addr)
            ifmng("--add", addr.iface, addr.addr, addr.netmask)
            report("ok")

    return mapping


def autoMapping():
    """Derive an IP mapping by comparing database and system addresses.

    Addresses present in the database but not on the system count as
    removed; addresses on the system (and not blacklisted) but not in the
    database count as added. Removed and added addresses are paired in
    order, separately for IPv4 and IPv6; whatever cannot be paired is only
    reported.

    Returns:
        Dictionary {removed IPAddr: added IPAddr}.
    """
    dbData = readDBAddresses()
    is_allowed = makeNotBlacklistedIpAddrPred()
    sysData = [ip for ip in readSystemAddresses() if is_allowed(ip)]

    report("Database:", dbData)
    report("Actual:", sysData)

    removed = [ip for ip in dbData if ip not in sysData]
    added = [ip for ip in sysData if ip not in dbData]

    removedIPv4s = [ip for ip in removed if ip.isIPv4()]
    removedIPv6s = [ip for ip in removed if ip.isIPv6()]
    addedIPv4s = [ip for ip in added if ip.isIPv4()]
    addedIPv6s = [ip for ip in added if ip.isIPv6()]

    report("Removed IPs:", removedIPv4s, removedIPv6s)
    report("Added IPs:", addedIPv4s, addedIPv6s)

    mapping = {}
    not_remapped = []
    not_used = []
    for old_ips, new_ips in ((removedIPv4s, addedIPv4s), (removedIPv6s, addedIPv6s)):
        # zip() stops at the shorter list; the tails are the leftovers.
        for ipFrom, ipTo in zip(old_ips, new_ips):
            mapping[ipFrom] = ipTo
        paired = min(len(old_ips), len(new_ips))
        not_remapped.extend(old_ips[paired:])
        not_used.extend(new_ips[paired:])

    report("Mapping:", mapping)
    report("Old not remapped:", not_remapped)
    report("New not used:", not_used)
    return mapping


def main():
    """Command-line entry point.

    Exactly one argument is accepted:
      * --help / -h (or a wrong number of arguments): print usage, exit 1;
      * --autoconfigure / --remap-ips: compute the mapping automatically;
      * anything else is a map file path: if the file does not exist a
        template is created and the program returns, otherwise the file is
        used as the mapping.
    A non-empty mapping is applied with reconfigure().
    """
    if len(sys.argv) != 2 or sys.argv[1] in ("--help", "-h"):
        usage()
        sys.exit(1)
    init()

    option = sys.argv[1]

    if option in ("--autoconfigure", "--remap-ips"):
        mapping = autoMapping()
    elif not os.path.exists(option):
        createMapFile(option)
        # The map file is passed as the sole positional argument; there is
        # no "--file" option.
        report("""IP map file template '%s' is successfully created.
Edit it to declare desired configuration, and start reconfigurator again with '%s' as the argument.
""" % (option, option))
        return
    else:
        mapping = fileMapping(option)

    if mapping:
        reconfigure(mapping)
        report("IP addresses are successfully changed.")
    else:
        report("Nothing to do.")


# Script entry point: run main() and turn any uncaught exception into a
# logged traceback plus exit status 1. sys.exit() calls inside main() raise
# SystemExit, which is not an Exception subclass and therefore passes through.
if __name__ == "__main__":
    try:
        main()
    except Exception:
        logging.exception("Exception occurred")
        sys.exit(1)

# vim: ft=python ts=4 sts=4 sw=4 et :
