#!/usr/bin/env python

"""
Produce a formatted report from sls hosts.
"""

import json
import logging
import pprint
import socket
import time
import urlparse
import warnings

from collections import OrderedDict
from multiprocessing.dummy import Process, Value
from optparse import OptionParser

import IPy
import jinja2
import requests
import tldextract

# Default URL of the JSON document listing the lookup hosts to query;
# used as the default value of the --active-hosts option.
ACTIVE_HOSTS = 'http://ps-west.es.net/lookup/activehosts.json'


class SlsReportException(Exception):
    """Custom SlsReport exception.

    The failure description is kept in ``value``; str() of the exception
    is the repr() of that value.
    """
    def __init__(self, value):
        # Initialise the base class too so the exception behaves like any
        # other Exception (args populated through the normal path).
        super(SlsReportException, self).__init__(value)
        self.value = value

    def __str__(self):
        return repr(self.value)


class SlsReportWarning(Warning):
    """Warning category used for non-fatal problems found while building the report."""


def setup_log(log_path=None):
    """
    Configure and return the "sls_report" logger at INFO level.

    Without log_path the records go through a StreamHandler; with it they
    are written to <log_path>/sls_report.log via a FileHandler. Each call
    attaches one more handler to the same named logger.

    Records are formatted as 'ts=<asctime> <message>'.
    """
    # pylint: disable=redefined-variable-type
    if log_path:
        # it's on you to make sure log_path is valid.
        handler = logging.FileHandler('{0}/sls_report.log'.format(log_path))
    else:
        handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter('ts=%(asctime)s %(message)s'))

    sls_logger = logging.getLogger("sls_report")
    sls_logger.addHandler(handler)
    sls_logger.setLevel(logging.INFO)
    return sls_logger

# Module-level logger used by _log(); no log_path is given, so it logs
# through a StreamHandler.
log = setup_log()  # pylint: disable=invalid-name


def _log(event, msg):
    """
    Emit an INFO record of the form 'event=<event> id=<epoch seconds> <msg>'.

    Usage:
    _log('main.start', 'happy simple log event')
    _log('launch', 'more={0}, complex={1} log=event'.format(100, 200))
    """
    stamp = int(time.time())
    log.info('event=%s id=%s %s', event, stamp, msg)


def _dummy_log(event, msg):  # pylint: disable=unused-argument
    """No-op with the same signature as _log(); used to completely silence default logging."""
    return None


class SlsReportBase(object):
    """
    Shared plumbing for the report classes: keeps the parsed options, the
    logging callable and a pretty printer, and offers verbosity-gated
    logging helpers.
    """
    def __init__(self, options, logger):
        self._pp = pprint.PrettyPrinter(indent=4)
        self._log = logger
        self._options = options

    def pretty_verbose(self, obj, level=1):
        """
        Log a readable rendering of obj when the verbosity count is at
        least `level`. SlsStatCapsule objects are rendered with str(),
        everything else with the pretty printer.
        """
        if self._options.verbose < level:
            return

        if isinstance(obj, SlsStatCapsule):
            rendered = str(obj)
        else:
            rendered = self._pp.pformat(obj)

        self._log('verbose.level.{level}'.format(level=level), rendered)

    def verbose_log(self, event, msg, level=1):
        """Pass event/msg to the logger when the verbosity count is at least `level`."""
        if self._options.verbose < level:
            return
        self._log(event, msg)

    @staticmethod
    def warn(msg):
        """Issue msg as an SlsReportWarning attributed to the caller of warn()."""
        warnings.warn(msg, SlsReportWarning, stacklevel=2)


class SlsHostList(SlsReportBase):
    """
    Query the activehost node, then the lookup hosts to generate a list
    of perfSONAR hosts, and the interface information for each host.

    Construction fetches the activehost document and raises
    SlsReportException if that fails; problems with an individual lookup
    host are reported as warnings and that host is skipped.
    """

    # Seconds to wait on an unresponsive server before giving up on a request.
    _REQUEST_TIMEOUT = 60

    def __init__(self, options, logger):
        super(SlsHostList, self).__init__(options, logger)

        self._lookup_hosts = list()  # the ls hosts
        self._ls_count = dict()  # records per ls host

        self._host_map = dict()  # key: hostname/value: record object

        self._duplicates = 0  # count dupe host records

        # Get the main lookup host list or it's a no go

        self._log('SlsHostList.init.run', 'fetching activehost list')

        try:
            r = requests.get(options.active_source, timeout=self._REQUEST_TIMEOUT)
        except requests.exceptions.RequestException as ex:
            # RequestException is the base of ConnectionError, Timeout, etc.
            raise SlsReportException(
                'activehost lookup connection error: {ex}'.format(ex=str(ex)))

        if r.status_code != 200:
            raise SlsReportException('Request to {host} failed: status:{code} {msg}'.format(
                host=options.active_source,
                code=r.status_code,
                msg=r.content
            ))

        try:
            host_source = json.loads(r.content)
        except ValueError as ex:
            raise SlsReportException('activehost list from {host} is not valid JSON: {ex}'.format(
                host=options.active_source, ex=str(ex)))

        hosts = host_source.get('hosts') if isinstance(host_source, dict) else None

        for entry in hosts or []:
            locator = entry.get('locator')
            # an entry without a locator cannot be queried, so drop it here
            if locator:
                self._lookup_hosts.append(locator)

        self.pretty_verbose(self._lookup_hosts, level=2)

        self._log('SlsHostList.init.end',
                  'got {count} lookup hosts'.format(count=len(self._lookup_hosts)))

    def _query_lookup_host(self, look, record_type='host'):
        """
        Fetch the records of one type ('host' or 'interface') from a
        single lookup host.

        Returns the decoded list of records, or an empty list (after a
        warning and an error log entry) when the request fails, returns
        a non-200 status, or the body is not a JSON list.
        """

        fail_msg = 'query to {host}?type={t} failed'.format(host=look, t=record_type)

        try:
            r = requests.get(look, params=dict(type=record_type),
                             timeout=self._REQUEST_TIMEOUT)
        except requests.exceptions.RequestException as ex:
            self.warn('Unable to connect to {host} : {ex} - skipping'.format(
                host=look, ex=str(ex)))
            self._log('generate_host_list.error', fail_msg)
            return []

        if r.status_code != 200:
            msg = 'non http 200 return code: {code} {msg}'.format(
                code=r.status_code, msg=r.content)
            self.warn(msg)
            self._log('generate_host_list.error', fail_msg)
            return []

        try:
            payload = json.loads(r.content)
        except ValueError as ex:
            self.warn('invalid JSON from {host} : {ex} - skipping'.format(
                host=look, ex=str(ex)))
            self._log('generate_host_list.error', fail_msg)
            return []

        if not isinstance(payload, list):
            # callers iterate the payload as a list of record dicts
            self.warn('unexpected payload from {host} - skipping'.format(host=look))
            self._log('generate_host_list.error', fail_msg)
            return []

        self.pretty_verbose(payload, level=4)

        return payload

    def generate_host_list(self):
        """
        Query the lookup hosts and build the map of hostname -> host
        record, with each record's interface uris replaced by the
        matching interface records. Also keeps a per lookup host count
        of the (non duplicate) hosts that were reported.

        Honours the --limit option (max hosts per lookup host) and the
        --single option (stop after the first lookup host).
        """

        hosts_processed = 0
        ifaces_processed = 0

        gen_start = time.time()

        for look in self._lookup_hosts:

            count_per_lookup_host = 0

            self._log('generate_host_list.run', 'querying {host}'.format(host=look))

            t_start = time.time()

            # Generate a dict of interface records, keyed by the uri element.
            iface_records = dict()
            for i_rec in self._query_lookup_host(look, record_type='interface'):
                iface_records[i_rec.get('uri')] = i_rec

            # now get the list of hosts (already logged at verbose level 4
            # by _query_lookup_host).
            payload = self._query_lookup_host(look)

            # loop through list hosts, lookup the interfaces from the cached
            # interface information, build main list and count.

            for rec in payload:
                iface_uris = rec.get('host-net-interfaces')
                if iface_uris is not None:
                    resolved = self._get_interfaces(iface_uris, iface_records)
                    rec['host-net-interfaces'] = resolved
                    ifaces_processed += len(resolved)

                hname = self._get_hostname(rec.get('host-name'))

                if not hname:
                    self.verbose_log('generate_host_list.error',
                                     'record missing hostname - skipping {rec}'.format(
                                         rec=self._pp.pformat(rec)))
                    continue

                # check for duplicate entries; a duplicate replaces the
                # earlier record but is not counted for this lookup host.
                if hname in self._host_map:
                    self.verbose_log('generate_host_list.warn',
                                     'duplicate found for {host}'.format(host=hname))
                    self._duplicates += 1
                else:
                    count_per_lookup_host += 1

                self._host_map[hname] = rec

                hosts_processed += 1

                if self._options.limit and count_per_lookup_host >= self._options.limit:
                    self._log('query_hosts.limit',
                              'reached host limit of {limit} - breaking'.format(
                                  limit=self._options.limit))
                    break

            self._log('generate_host_list.run', 'processed {host} in {sec} seconds'.format(
                host=look, sec=round(time.time() - t_start, 2)))

            # count how many records we got for this host for report.
            upar = urlparse.urlparse(look)
            self._ls_count[upar.netloc] = count_per_lookup_host

            if self._options.single:
                self._log('generate_host_list.break', 'got --single flag - breaking')
                break

        self.pretty_verbose(self._host_map, level=3)
        self._log('generate_host_list.end',
                  '{num} hosts and {inum} interfaces processed in {sec} seconds'.format(
                      num=hosts_processed, inum=ifaces_processed,
                      sec=round(time.time() - gen_start, 2)))

    def _get_interfaces(self, iface_uris, iface_records):
        """Generate a list of interface objects pulled from pre-cached
        interface records. Uris with no cached record are dropped."""

        iface_objects = list()

        for uri in iface_uris:
            iface = iface_records.get(uri)
            self.pretty_verbose(iface, level=4)
            if iface is not None:
                iface_objects.append(iface)

        return iface_objects

    @staticmethod
    def _get_hostname(hlist):
        """Return the first non-empty entry of the host-name element list,
        or None when the list is missing, empty or holds only empty values."""
        for candidate in hlist or []:
            if candidate:
                return candidate

        return None

    @staticmethod
    def quick_fetch(url):
        """method to run in thread to get the interfaces.

        Returns the decoded JSON body, or None when the request fails,
        the status is not 200, or the body is empty or not valid JSON."""
        try:
            r = requests.get(url, timeout=SlsHostList._REQUEST_TIMEOUT)
        except requests.exceptions.RequestException:
            return None

        if r.status_code != 200 or not r.content:
            return None

        try:
            return json.loads(r.content)
        except ValueError:
            return None

    @property
    def host_list(self):
        """hosts property"""
        return self._host_map.values()

    @property
    def duplicates(self):
        """duplicates property."""
        return self._duplicates

    @property
    def ls_count(self):
        """ls_count property."""
        return self._ls_count


class SlsStatCapsule(SlsReportBase):
    """
    Encapsulation class and accessors that manage the statistics
    gathered by SlsStatistics.

    Everything is stored as named "counter" dicts (value -> number of
    times seen) which to_dict() turns into count/percentage tables.
    """

    # Seconds to wait for a reverse DNS lookup thread before moving on.
    _LOOKUP_TIMEOUT = 8
    # Lookups slower than this many seconds mark the host as "hanging".
    _SLOW_LOOKUP = 5

    def __init__(self, options, logger):
        super(SlsStatCapsule, self).__init__(options, logger)
        self._host_count = 0
        self._counter_stats = dict()

    def tick_host_count(self):
        """increment host count."""
        self._host_count += 1

    # stats to count numbers of things.

    def _counter(self, dname, val):
        """Interface to the underlying dict of dicts meant to count instances of things."""
        bucket = self._counter_stats.setdefault(dname, dict())
        bucket[val] = bucket.get(val, 0) + 1

    def _count_country(self, extract):
        """Count the suffix of a tldextract result as a country, or 'unknown' if blank."""
        if extract.suffix.strip():
            self._counter('country', extract.suffix)
        else:
            self._counter('country', 'unknown')

    def mtu(self, mtu):
        """store mtu values."""
        for val in mtu or []:
            self._counter('mtu', val)

    def nic_speed(self, speed):
        """store nic speed values.

        Known capacities are translated to a readable label; any other
        value is kept as '<value> bps' instead of being lumped together
        under a None key."""

        xlate = {
            '40000000000': u'40 Gbps',
            '20000000000': u'20 Gbps',
            '10000000000': u'10 Gbps',
            '1000000000': u'1 Gbps',
            '2000000000': u'2 Gbps',
            '100000000': u'100 Mbps',
            '10000000': u'10 Mbps',
            '0': u'undefined',
        }

        for val in speed or []:
            label = xlate.get(val)
            if label is None:
                label = u'{0} bps'.format(val)
            self._counter('speed', label)

    def _summarize_group_domains(self, address):
        """Quick and dirty domain/country counts from the group-domains element."""
        seen = set()

        for name in address.get('group-domains') or []:
            if not name:
                continue
            self._counter('domains', name)
            # de-duplicate case-insensitively before counting countries
            seen.add(name.lower())

        for name in seen:
            self._count_country(tldextract.extract(name))

    @staticmethod
    def _reverse_lookup(holder, ip):
        """wrapper to run in thread to deal with timeouts; stores the
        resolved name in holder.value and leaves it untouched on failure."""
        try:
            holder.value = socket.gethostbyaddr(ip)[0]
        except socket.error:
            # herror/gaierror etc: no name for this address, best effort only
            pass

    def interface_summary(self, address, skip_reverse_dns):
        """
        Gather the information from one interface object:
        public/private, ipv4/6, domain and country information.

        address is the interface record; skip_reverse_dns says whether
        an earlier interface of the same host already had a slow reverse
        lookup. Returns True when reverse lookups should be skipped for
        the remaining interfaces of this host, otherwise the value of
        skip_reverse_dns that was passed in.
        """

        # if --domain flag not given, do a quick and dirty domain
        # and country report from the group-domains element
        # from the interface object.
        if not self._options.domain:
            self._summarize_group_domains(address)

        # get information on the addresses on the interface and
        # do the "live" dns/country resolution if that was requested.

        lookup_hang = skip_reverse_dns

        for addy in address.get('interface-addresses') or []:

            t_start = time.time()

            try:
                ip = IPy.IP(addy)
            except ValueError:
                # fqdn -> ip
                try:
                    ip = IPy.IP(socket.gethostbyname(addy))
                except socket.error:
                    # time to give up
                    continue

            self._counter('iface_type', ip.version())

            if ip.iptype() == 'PRIVATE':
                self._counter('iface_pub', 'private')
            else:
                self._counter('iface_pub', 'public')

            # only do the live lookup if that's what they requested.
            if not self._options.domain:
                continue

            # this little song and dance here is because gethostbyaddr()
            # does not have a timeout, nor does it honor when you set
            # socket.setdefaulttimeout() - and some of these lookups
            # take a good long while to timeout. So the theory here
            # is to do the lookup in a single thread and call join() to
            # enforce the desired timeout.
            #
            # Yes, the thread will keep running until it completes, but
            # I don't think that'll be the end of the world vs. the
            # performance increase. It's only a few lookups that cause
            # things to hang, most bad ones raise the exception almost
            # immediately.
            #
            # Also, if a reverse lookup times out for an address on one
            # interface, presume any additional addresses on this interface
            # or other interfaces on the same host will do the same so bail.

            if lookup_hang:
                self._log('interface_summary.skip', 'skipping lookup for {addy}'.format(
                    addy=addy))
                continue

            val = Value('i', 0)  # result holder handed to the lookup thread
            lookup = Process(target=self._reverse_lookup, args=(val, ip.strNormal()))
            lookup.start()
            lookup.join(self._LOOKUP_TIMEOUT)  # sorry buddy, time's up

            if val.value:
                extract = tldextract.extract(val.value)
                self._counter('domains', extract.registered_domain)
                self._count_country(extract)

            # notify if the lookup takes a while
            duration = round(time.time() - t_start, 2)
            if duration > self._SLOW_LOOKUP:
                self._log('ip_summary.notice', 'took {sec} seconds resolving {addr}'.format(
                    sec=duration, addr=addy))
                # stop additional iterations on this interface and subsequent
                # interfaces on this host since they will likely hang as well.
                lookup_hang = True

        return lookup_hang

    def ps_toolkit_version(self, version):
        """store the ps toolkit version."""
        for ver in version or []:
            self._counter('ps_version', ver)

    def os_and_architecture(self, os, version, kernel):  # pylint: disable=invalid-name
        """handle storing OS/etc information.

        Each argument must be a non-empty list, otherwise nothing is
        stored. The architecture is taken as the last dot separated
        part of the first kernel string."""

        for element in (os, version, kernel):
            if not isinstance(element, list) or not element:
                return

        os_info = '{os} {version}'.format(
            os=os[0], version=version[0]
        )

        self._counter('os', os_info)

        arch = kernel[0].split('.')[-1]

        self._counter('arch', arch)

    # formatting, etc.

    def _dict_to_stats(self, dname):
        """convert a 'counter' sub-dict to a sorted dict of percentages.

        Returns an OrderedDict of str(key) -> (count, percent) string
        tuples, highest count first. When every count is zero the
        percentages are reported as 0.0 rather than dividing by zero."""
        sdict = self._counter_stats.get(dname, {})
        total = sum(sdict.values())

        stat_dict = OrderedDict()

        # sort by the values, with the highest value first
        for order in sorted(sdict, key=sdict.get, reverse=True):
            count = sdict[order]

            if total:
                val = round((float(count) / total) * 100, 2)
            else:
                val = 0.0

            # make everything strings for the templating library
            stat_dict[str(order)] = (str(count), str(val),)

        return stat_dict

    def to_dict(self):
        """Return a sorted dict of the statistics. The order of the tuples
        in the OrderedDict defines how the report is ordered."""
        doc = OrderedDict(
            [
                ('ls_count', self._dict_to_stats('ls_count')),
                ('ps_version', self._dict_to_stats('ps_version')),
                ('iface_speed', self._dict_to_stats('speed')),
                ('architecture', self._dict_to_stats('arch')),
                ('os_info', self._dict_to_stats('os')),
                ('ipv4/v6', self._dict_to_stats('iface_type')),
                ('mtu_summary', self._dict_to_stats('mtu')),
                ('public_private', self._dict_to_stats('iface_pub')),
                ('top_level_domain/country', self._dict_to_stats('country')),
                ('domains', self._dict_to_stats('domains')),
            ]
        )

        return doc

    def add_counter_dict(self, dname, cdict):
        """
        Allow the calling stats class to add a new dict to the
        _counter_stats dict o' dicts for the report rather than
        using _counter()
        """
        self._counter_stats[dname] = cdict

    @property
    def host_count(self):
        """get the host count"""
        return self._host_count

    def __str__(self):
        """str repr"""
        return json.dumps(self.to_dict())


class SlsStatistics(SlsReportBase):
    """
    Take an SlsHostList object and feed the host records it gathered
    into an SlsStatCapsule.

    The resulting statistics are then used to generate reports.
    """
    def __init__(self, sls_hosts, options, logger):
        super(SlsStatistics, self).__init__(options, logger)

        self._sls_hosts = sls_hosts

        self._query_count = 0
        self._good_record = False

        self._capsule = SlsStatCapsule(options, logger)
        # the per lookup host counts come straight from the host list
        self._capsule.add_counter_dict('ls_count', self._sls_hosts.ls_count)

    def process_hosts(self):
        """
        Process the hosts that were gathered up from the lookup hosts.
        """

        self._log('process_hosts.begin', 'start')

        # fetch once: the property builds the collection on each access
        hosts = self._sls_hosts.host_list
        total_hosts = len(hosts)

        for res_host in hosts:

            self.verbose_log('process_hosts.run', 'processing {host} [{i}/{len}]'.format(
                host=res_host.get('host-name')[0],
                i=self._capsule.host_count + 1,
                len=total_hosts), level=2)

            self.pretty_verbose(res_host, level=3)

            self._capsule.tick_host_count()

            self._capsule.ps_toolkit_version(res_host.get('pshost-toolkitversion'))

            self._capsule.os_and_architecture(
                res_host.get('host-os-name'),
                res_host.get('host-os-version'),
                res_host.get('host-os-kernel'),
            )

            # reset this when processing a new host
            skip_reverse_dns = False

            # "or []" also covers records where the element is present but null
            for iface in res_host.get('host-net-interfaces') or []:
                if not isinstance(iface, dict):
                    continue
                self._capsule.mtu(iface.get('interface-mtu'))
                self._capsule.nic_speed(iface.get('interface-capacity'))
                # if there was a reverse dns lookup hang in an interface,
                # note that state and don't do the lookup in subsequent
                # interfaces on this host.
                skip_reverse_dns = self._capsule.interface_summary(
                    iface, skip_reverse_dns)

            self.pretty_verbose(self._capsule, level=2)

        self._log('process_hosts.end',
                  'client lookup on {num} hosts'.format(num=self._capsule.host_count))

    @property
    def get_report(self):
        """Get the raw statistics dict o' dicts."""
        return self._capsule.to_dict()

    @property
    def host_count(self):
        """Get the underlying number of hosts from the capsule."""
        return self._capsule.host_count

    @property
    def duplicates(self):
        """Get the duplicates from the underlying host list object."""
        return self._sls_hosts.duplicates


class SlsReportFormatter(object):  # pylint: disable=too-few-public-methods
    """Renders the statistics report as console text, one table per section."""
    def __init__(self, report, count, duplicates):
        self._duplicates = duplicates
        self._count = count
        self._report = report

        # jinja template strings
        # default name/count/percent columns
        self._columns = """{{section}} ({{count}}):
                                                     count    percent
{% for key, val in _dict.iteritems() %}  {{ key.ljust(52) }}{{ val[0].rjust(4) }}{{ val[1].rjust(9) }} %
{% endfor %}"""

        self._column_template = jinja2.Template(self._columns)

    def get_console_report(self):
        """Build the full report text: a header line with the host and
        duplicate counts followed by one rendered table per section."""
        header = '\nLookup data for {count} hosts ({dupe} duplicates found):\n\n'.format(
            count=self._count, dupe=self._duplicates)

        # singular wording when exactly one duplicate was seen
        if self._duplicates == 1:
            header = header.replace('duplicates', 'duplicate')

        pieces = [header]

        for name, stats in self._report.items():
            title = name.replace('_', ' ').title()
            pieces.append(self._column_template.render(
                section=title, _dict=stats, count=len(stats.keys())))
            pieces.append('\n')

        return ''.join(pieces)

    def __str__(self):
        """str() of the formatter is the console report."""
        return self.get_console_report()


def main():
    """Execute the report.

    Parses the command line, gathers the host records from the lookup
    hosts, builds the statistics and prints the formatted report.

    Returns -1 when the activehost list cannot be fetched, otherwise None.
    """
    # -h is optparse's built-in help flag; the source option is -a.
    usage = '%prog [ -a ACTIVE_SOURCE ] [ -s ] [ -l NUM ] [ -d ] [ -v ] [ -q ]'
    parser = OptionParser(usage=usage)
    parser.add_option('-a', '--active-hosts', metavar='ACTIVE_SOURCE',
                      type='string', dest='active_source', default=ACTIVE_HOSTS,
                      help='Source of the seed active hosts file (default: %default) .')
    parser.add_option('-s', '--single',
                      dest='single', action='store_true', default=False,
                      help='Query a single lookup host and then break. Primarily for development.')
    parser.add_option('-l', '--limit', metavar='NUM',
                      type='int', dest='limit', default=0,
                      help='Limit to NUM hosts from each lookup host.')
    parser.add_option('-d', '--domain',
                      dest='domain', action='store_true', default=False,
                      help='Live domain resolution - greatly increases accuracy and processing time.')  # pylint: disable=line-too-long
    parser.add_option('-v', '--verbose',
                      dest='verbose', action='count', default=0,
                      help='Verbose output. Can take multiple flags for more output.')
    parser.add_option('-q', '--quiet',
                      dest='quiet', action='store_true', default=False,
                      help='Silence all default log output - meant for cron.')
    options, _ = parser.parse_args()

    t_start = time.time()

    # silence even default logging
    _logger = _dummy_log if options.quiet else _log

    try:
        lookup_hosts = SlsHostList(options, _logger)
    except SlsReportException as ex:
        print(str(ex))
        return -1

    lookup_hosts.generate_host_list()

    # get detailed information on the hosts in the host list
    report = SlsStatistics(lookup_hosts, options, _logger)
    report.process_hosts()

    _logger('main.end', 'completed in {sec} seconds'.format(sec=round(time.time() - t_start, 2)))

    formatter = SlsReportFormatter(report.get_report, report.host_count, report.duplicates)
    print(formatter)

if __name__ == '__main__':
    # Hand main()'s return value to the interpreter as the exit status:
    # -1 on a failed activehost fetch becomes a non-zero status, None
    # (normal completion) becomes 0, so cron/shell callers can see failures.
    raise SystemExit(main())
