# @package portal.MessageLib
# Provides a low-level client interface to the Portal server
#
# The Portal MessageLib module contains a single API class that provides
# a low-level client interface for creating, sending, and receiving
# messages over a connection with a Portal server. The Portal message
# interface uses Google's protobuf package, and the messages (requests,
# replies, and data structures) are defined in the \*.proto files included
# in the Portal client package. This module relies on the Portal
# Connection module which sends and receives encoded messages over a
# secure ZeroMQ link.


import zmq
from . import Version_pb2
from . import SDMS_Anon_pb2 as anon
from . import SDMS_Auth_pb2 as auth
from . import Connection
from . import VERSION

from . import VersionUtils


##
# @class API
# @brief Provides a low-level messaging interface to the Portal core server.
#
# The Portal MessageLib.API class provides a low-level interface
# for creating, sending, and receiving messages to/from a Portal
# server. The Portal message interface uses Google's protobuf
# package, and the messages (requests, replies, and data structures)
# are defined in the \*.proto files included in the Portal client
# package. Basic functionality includes connectivity, authentication,
# and both synchronous and asynchronous message send/recv methods.
#
class API:
    # Length of a Z85-encoded CurveZMQ key (32 raw bytes -> 40 characters),
    # the format produced by zmq.curve_keypair().
    _CURVE_KEY_LEN = 40

    ##
    # @brief MessageLib.API class initialization method.
    # @param server_host The Portal core server hostname or IP address.
    # @param server_port Portal core server port number.
    # @param server_pub_key_file Portal core server public key file (full path).
    #        Mutually exclusive with server_pub_key.
    # @param server_pub_key Portal core server public key (string).
    #        Mutually exclusive with server_pub_key_file.
    # @param client_pub_key_file Client public key file (full path).
    #        Mutually exclusive with client_pub_key.
    # @param client_pub_key Client public key (string).
    # @param client_priv_key_file Client private key file (full path).
    #        Mutually exclusive with client_priv_key.
    # @param client_priv_key Client private key (string).
    # @param client_token If set, client key loading is bypassed (an
    #        ephemeral key pair is generated) and the client authenticates
    #        with this token via manualAuthByToken().
    # @param manual_auth Client intends to manually authenticate if True.
    #                    Bypasses client key loading (an ephemeral key pair
    #                    is generated instead).
    # @param kwargs Placeholder for any extra keyword arguments (ignored)
    # @exception Exception: On missing/conflicting arguments, server key load
    #            error, timeout, incompatible protocols, or token
    #            authentication failure.
    #
    # Attempts to create a secure connection with a specified Portal
    # server. A server key or key file must be provided, but if client
    # keys are not provided, an anonymous connection is established.
    # Provided client keys that cannot be read or do not have the expected
    # length are silently replaced with an ephemeral key pair. The
    # keysLoaded(), keysValid(), and getAuthStatus() methods may be used
    # to assess status. Also performs a best-effort check for a newer
    # client release (stored in new_client_avail) and checks client and
    # server protocol versions for compatibility.
    #
    def __init__(
        self,
        server_host=None,
        server_port=None,
        server_pub_key_file=None,
        server_pub_key=None,
        client_pub_key_file=None,
        client_pub_key=None,
        client_priv_key_file=None,
        client_priv_key=None,
        client_token=None,
        manual_auth=None,
        **kwargs,
    ):
        self._ctxt = 0
        self._auth = False
        self._uid = None
        self._nack_except = True
        # Default receive timeout in milliseconds.
        self._timeout = 50000

        if not server_host:
            raise Exception("Server host is not defined")

        if server_port is None:
            raise Exception("Server port is not defined")

        if not server_pub_key and not server_pub_key_file:
            raise Exception("Server public key or key file is not defined")

        if server_pub_key and server_pub_key_file:
            raise Exception("Cannot specify both server public key and key file")

        if client_pub_key and client_pub_key_file:
            raise Exception("Cannot specify both client public key and key file")

        if client_priv_key and client_priv_key_file:
            raise Exception("Cannot specify both client private key and key file")

        _server_pub_key = None
        _client_pub_key = None
        _client_priv_key = None

        # Use or load server public key
        if server_pub_key_file is not None:
            try:
                _server_pub_key = self._readKeyFile(server_pub_key_file)
            except (OSError, ValueError) as err:
                # ValueError covers UnicodeDecodeError for non-text files.
                raise Exception(
                    f"Could not open server public key file: {server_pub_key_file}"
                ) from err
        else:
            _server_pub_key = server_pub_key

        # Use, load, or generate client keys
        self._keys_loaded = False
        self._keys_valid = False

        use_ephemeral_keys = (
            manual_auth
            or client_token
            or not (
                client_pub_key_file
                or client_pub_key
                or client_priv_key_file
                or client_priv_key
            )
        )

        if use_ephemeral_keys:
            _client_pub_key, _client_priv_key = self._generateKeyPair()
        else:
            try:
                if client_pub_key_file:
                    _client_pub_key = self._readKeyFile(client_pub_key_file)
                else:
                    _client_pub_key = client_pub_key

                if client_priv_key_file:
                    _client_priv_key = self._readKeyFile(client_priv_key_file)
                else:
                    _client_priv_key = client_priv_key

                # Check for obviously bad keys; fall back to an ephemeral pair.
                if (
                    len(_client_pub_key) != self._CURVE_KEY_LEN
                    or len(_client_priv_key) != self._CURVE_KEY_LEN
                ):
                    _client_pub_key, _client_priv_key = self._generateKeyPair()
                else:
                    self._keys_valid = True
                self._keys_loaded = True
            except Exception:
                # Deliberate best effort: an unreadable key file, or only one
                # half of the pair being supplied (len(None) -> TypeError),
                # results in an ephemeral key pair instead of a failure.
                _client_pub_key, _client_priv_key = self._generateKeyPair()

        if not _client_pub_key:
            raise Exception("Client public key is not defined")

        if not _client_priv_key:
            raise Exception("Client private key is not defined")

        self._conn = Connection.Connection(
            server_host, server_port, _server_pub_key, _client_pub_key, _client_priv_key
        )

        self._conn.registerProtocol(anon)
        self._conn.registerProtocol(auth)

        # Best-effort check (request to PyPI) for a newer client release.
        latest_version_on_pypi = VersionUtils.get_latest_stable_version("portal")

        self.new_client_avail = False
        if latest_version_on_pypi:
            try:
                pypi_parts = latest_version_on_pypi.split(".")
                major, minor, patch_w_prerelease = VERSION.__version__.split(".")

                # Remove prerelease part from patch
                patch = VersionUtils.remove_after_prefix_with_numbers(
                    patch_w_prerelease
                )

                if self._isNewerVersion(pypi_parts, (major, minor, patch)):
                    self.new_client_avail = latest_version_on_pypi
            except (ValueError, TypeError):
                # Unparseable version string(s): skip the upgrade notice
                # rather than failing to connect.
                pass

        # Check for compatible protocol versions
        reply, mt = self._sendRecvOrNone(anon.VersionRequest(), 10000)
        if reply is None:
            raise Exception(
                "Timeout waiting for server connection. Make sure "
                "the right ports are open."
            )

        if reply.api_major != Version_pb2.PORTAL_COMMON_PROTOCOL_API_MAJOR:
            error_msg = (
                "Incompatible server api detected {}.{}.{}, you are running "
                "{}.{}.{} consider "
                "upgrading the portal python client.".format(
                    reply.api_major,
                    reply.api_minor,
                    reply.api_patch,
                    Version_pb2.PORTAL_COMMON_PROTOCOL_API_MAJOR,
                    Version_pb2.PORTAL_COMMON_PROTOCOL_API_MINOR,
                    Version_pb2.PORTAL_COMMON_PROTOCOL_API_PATCH,
                )
            )
            if self.new_client_avail:
                error_msg += (
                    "\nConsider upgrading the portal python client as"
                    f" a new version is available {latest_version_on_pypi} that"
                    " should be compatible with the API."
                )
            raise Exception(error_msg)

        if client_token:
            self.manualAuthByToken(client_token)
        else:
            # Check if server authenticated based on keys
            reply, mt = self.sendRecv(anon.GetAuthStatusRequest(), 10000)
            self._auth = reply.auth
            self._uid = reply.uid

    ##
    # @brief Read the full contents of a key file as text.
    # @param path Path to the key file.
    # @return The file contents (str), unmodified (no whitespace stripping).
    # @exception OSError If the file cannot be opened or read.
    # @exception UnicodeDecodeError If the file is not valid UTF-8 text.
    #
    @staticmethod
    def _readKeyFile(path):
        with open(path, "r", encoding="utf-8") as keyf:
            return keyf.read()

    ##
    # @brief Generate a fresh CurveZMQ key pair.
    # @return Tuple of (public_key, private_key) as str.
    #
    @staticmethod
    def _generateKeyPair():
        pub, priv = zmq.curve_keypair()
        return pub.decode("utf-8"), priv.decode("utf-8")

    ##
    # @brief Numerically compare two dotted version component sequences.
    # @param candidate Sequence of version components (str or int).
    # @param current Sequence of version components (str or int).
    # @return True if candidate is strictly newer than current.
    # @exception ValueError/TypeError If a component is not an integer.
    #
    # Components are compared as integers so that e.g. "10" > "9".
    #
    @staticmethod
    def _isNewerVersion(candidate, current):
        return tuple(int(p) for p in candidate) > tuple(int(p) for p in current)

    ##
    # @brief Determines if client security keys were loaded.
    #
    # @return True if keys were loaded; false otherwise.
    # @retval bool
    #
    def keysLoaded(self):
        return self._keys_loaded

    ##
    # @brief Determines if loaded client security keys had a valid format.
    #
    # Note that keys with valid format but invalid value will cause
    # a connection failure (exception or timeout).
    #
    # @return True if client key formats were valid; false otherwise.
    # @retval bool
    #
    def keysValid(self):
        return self._keys_valid

    ##
    # @brief Gets the client authentication status and user ID.
    #
    # @return A tuple of (bool,string) - The bool is True if client
    #    is authenticated; False otherwise. If authenticated, the
    #    string part is the Portal user ID of the client.
    # @retval (bool,str)
    #
    def getAuthStatus(self):
        return self._auth, self._uid

    ##
    # @brief Perform manual client authentication with Portal user ID and password.
    #
    # @param uid Client's Portal user ID.
    # @param password Client's Portal password.
    # @exception Exception: On communication timeout or authentication failure.
    #
    def manualAuthByPassword(self, uid, password):
        msg = anon.AuthenticateByPasswordRequest()
        msg.uid = uid
        msg.password = password
        self.sendRecv(msg)

        # Reset connection so server can re-authenticate
        self._conn.reset()

        # Test auth status
        reply, mt = self.sendRecv(anon.GetAuthStatusRequest())
        if not reply.auth:
            raise Exception("Password authentication failed.")

        self._auth = True
        self._uid = reply.uid

    ##
    # @brief Perform manual client authentication with a Portal access token.
    #
    # @param token Client's Portal authentication token.
    # @exception Exception: On communication timeout or authentication failure.
    #
    def manualAuthByToken(self, token):
        msg = anon.AuthenticateByTokenRequest()
        msg.token = token
        self.sendRecv(msg)

        # Reset connection so server can re-authenticate
        self._conn.reset()

        # Test auth status
        reply, mt = self.sendRecv(anon.GetAuthStatusRequest())

        if not reply.auth:
            raise Exception("Token authentication failed")

        self._auth = True
        self._uid = reply.uid

    ##
    # @brief Reset the server connection and clear local authentication state.
    #
    def logout(self):
        self._conn.reset()
        self._auth = False
        self._uid = None

    ##
    # @brief Get NackReply exception enable state.
    #
    # @return True if Nack exceptions are enabled; False otherwise.
    # @retval bool
    #
    def getNackExceptionEnabled(self):
        return self._nack_except

    ##
    # @brief Set NackReply exception enable state.
    #
    # If NackReply exceptions are enabled, any NackReply received by
    # the recv() or sendRecv() methods will be raised as an exception
    # containing the error message from the NackReply. When disabled,
    # NackReply messages are returned like any other reply.
    #
    # @param enabled: Sets exceptions to enabled (truthy) or disabled (falsy)
    #
    def setNackExceptionEnabled(self, enabled):
        self._nack_except = bool(enabled)

    ##
    # @brief Set the default receive timeout (milliseconds) used when
    #        send/recv calls are made without an explicit timeout.
    #
    def setDefaultTimeout(self, timeout):
        self._timeout = timeout

    ##
    # @brief Get the default receive timeout (milliseconds).
    #
    def getDefaultTimeout(self):
        return self._timeout

    ##
    # @brief Request the server's daily message.
    #
    # @return The message text from the DailyMessageRequest reply.
    # @exception Exception: On timeout (10 s) or NackReply (if enabled).
    #
    def getDailyMessage(self):
        # Get daily message, if set
        reply, mt = self._sendRecvOrNone(anon.DailyMessageRequest(), 10000)
        if reply is None:
            raise Exception("Timeout waiting for server connection.")

        return reply.message

    ##
    # @brief Send a message and wait for its reply, returning (None, None)
    #        on timeout instead of raising.
    #
    # @param msg Protobuf message to send to the server.
    # @param timeout Timeout in milliseconds (None = default timeout).
    # @param nack_except Override of the NackReply exception setting
    #        (None = use the instance setting).
    # @return (reply, type) or (None, None) on timeout.
    # @exception Exception: On message context mismatch (out of sync) or
    #            NackReply (if enabled).
    #
    def _sendRecvOrNone(self, msg, timeout=None, nack_except=None):
        self.send(msg)
        # recv() substitutes the default timeout when timeout is None.
        reply, mt, ctxt = self.recv(timeout, nack_except)
        if reply is None:
            return None, None
        if ctxt != self._ctxt:
            raise Exception(
                "Mismatched reply. Expected {} got {}".format(self._ctxt, ctxt)
            )
        return reply, mt

    ##
    # @brief Synchronously send a message then receive a reply to/from Portal server.
    #
    # @param msg: Protobuf message to send to the server
    # @param timeout: Timeout in milliseconds (None = default timeout)
    # @param nack_except: Override of the NackReply exception setting
    #   (None = use the instance setting)
    # @return A tuple consisting of (reply, type), where reply is
    #   the received protobuf message reply and type is the
    #   corresponding message type/name (string) of the reply.
    # @retval (obj,str)
    # @exception Exception: On timeout, on message context mismatch
    #   (out of sync), or on NackReply (if enabled).
    #
    def sendRecv(self, msg, timeout=None, nack_except=None):
        reply, mt = self._sendRecvOrNone(msg, timeout, nack_except)
        if reply is None:
            raise Exception("Timeout!!!!!!!!!")
        return reply, mt

    ##
    # @brief Asynchronously send a protobuf message to Portal server.
    #
    # @param msg: Protobuf message to send to the server
    # @return Auto-generated message re-association context int
    #    value (match to context in subsequent reply).
    # @retval int
    #
    def send(self, msg):
        self._ctxt += 1
        self._conn.send(msg, self._ctxt)
        return self._ctxt

    ##
    # @brief Receive a protobuf message (reply) from Portal server.
    #
    # @param timeout: Timeout in milliseconds (0 = don't wait, -1 =
    #   wait forever; None = default timeout).
    # @param nack_except: Override of the NackReply exception setting
    #   (None = use the instance setting).
    # @exception Exception: On NackReply (if Nack exceptions enabled).
    # @return Tuple of (reply, type, context) where reply is the
    #   received protobuf message, type is the corresponding
    #   message type/name (string) of the reply, and context
    #   is the reassociation value (int). On timeout, returns
    #   (None,None,None).
    # @retval (obj,str,int)
    #
    def recv(self, timeout=None, nack_except=None):
        _timeout = timeout if timeout is not None else self._timeout

        reply, msg_type, ctxt = self._conn.recv(_timeout)
        if reply is None:
            return None, None, None

        _nack_except = nack_except if nack_except is not None else self._nack_except

        if msg_type == "NackReply" and _nack_except:
            if reply.err_msg:
                raise Exception(reply.err_msg)
            else:
                raise Exception("Server error {}".format(reply.err_code))

        return reply, msg_type, ctxt
