# Copyright (C) 2009  Internet Systems Consortium.
#
# Permission to use, copy, modify, and distribute this software for any
# purpose with or without fee is hereby granted, provided that the above
# copyright notice and this permission notice appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND INTERNET SYSTEMS CONSORTIUM
# DISCLAIMS ALL WARRANTIES WITH REGARD TO THIS SOFTWARE INCLUDING ALL
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL
# INTERNET SYSTEMS CONSORTIUM BE LIABLE FOR ANY SPECIAL, DIRECT,
# INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES WHATSOEVER RESULTING
# FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION OF CONTRACT,
# NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN CONNECTION
# WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

import sys
import socket
import struct
import errno
import os
import threading
import bind10_config

import isc.cc.message

class ProtocolError(Exception): pass
class NetworkError(Exception): pass
class SessionError(Exception): pass
class SessionTimeout(Exception): pass

class Session:
    MSGQ_DEFAULT_TIMEOUT = 4000
    
    def __init__(self, socket_file=None):
        self._socket = None
        self._lname = None
        self._sequence = 1
        self._closed = False
        self._queue = []
        self._lock = threading.RLock()
        self.set_timeout(self.MSGQ_DEFAULT_TIMEOUT);
        self._recv_len_size = 0
        self._recv_size = 0

        if socket_file is None:
            if "BIND10_MSGQ_SOCKET_FILE" in os.environ:
                self.socket_file = os.environ["BIND10_MSGQ_SOCKET_FILE"]
            else:
                self.socket_file = bind10_config.BIND10_MSGQ_SOCKET_FILE
        else:
            self.socket_file = socket_file

        try:
            self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            self._socket.connect(self.socket_file)
            self.sendmsg({ "type": "getlname" })
            env, msg = self.recvmsg(False)
            if not env:
                raise ProtocolError("Could not get local name")
            self._lname = msg["lname"]
            if not self._lname:
                raise ProtocolError("Could not get local name")
        except socket.error as se:
                raise SessionError(se)

    @property
    def lname(self):
        return self._lname

    def close(self):
        self._socket.close()
        self._lname = None
        self._closed = True

    def sendmsg(self, env, msg = None):
        with self._lock:
            if self._closed:
                raise SessionError("Session has been closed.")
            if type(env) == dict:
                env = isc.cc.message.to_wire(env)
            if len(env) > 65535:
                raise ProtocolError("Envelope too large")
            if type(msg) == dict:
                msg = isc.cc.message.to_wire(msg)
            self._socket.setblocking(1)
            length = 2 + len(env);
            if msg:
                length += len(msg)
            self._socket.send(struct.pack("!I", length))
            self._socket.send(struct.pack("!H", len(env)))
            self._socket.send(env)
            if msg:
                self._socket.send(msg)

    def recvmsg(self, nonblock = True, seq = None):
        """Reads a message. If nonblock is true, and there is no
           message to read, it returns (None, None).
           If seq is not None, it should be a value as returned by
           group_sendmsg(), in which case only the response to
           that message is returned, and others will be queued until
           the next call to this method.
           If seq is None, only messages that are *not* responses
           will be returned, and responses will be queued.
           The queue is checked for relevant messages before data
           is read from the socket.
           Raises a SessionError if there is a JSON decode problem in
           the message that is read, or if the session has been closed
           prior to the call of recvmsg()"""
        with self._lock:
            if len(self._queue) > 0:
                i = 0;
                for env, msg in self._queue:
                    if seq != None and "reply" in env and seq == env["reply"]:
                        return self._queue.pop(i)
                    elif seq == None and "reply" not in env:
                        return self._queue.pop(i)
                    else:
                        i = i + 1
            if self._closed:
                raise SessionError("Session has been closed.")
            data = self._receive_full_buffer(nonblock)
            if data and len(data) > 2:
                header_length = struct.unpack('>H', data[0:2])[0]
                data_length = len(data) - 2 - header_length
                try:
                    if data_length > 0:
                        env = isc.cc.message.from_wire(data[2:header_length+2])
                        msg = isc.cc.message.from_wire(data[header_length + 2:])
                        if (seq == None and "reply" not in env) or (seq != None and "reply" in env and seq == env["reply"]):
                            return env, msg
                        else:
                            self._queue.append((env,msg))
                            return self.recvmsg(nonblock, seq)
                    else:
                        return isc.cc.message.from_wire(data[2:header_length+2]), None
                except ValueError as ve:
                    # TODO: when we have logging here, add a debug
                    # message printing the data that we were unable
                    # to parse as JSON
                    raise SessionError(ve)
            return None, None

    def _receive_bytes(self, size):
        """Try to get size bytes of data from the socket.
           Raises a ProtocolError if the size is 0.
           Raises any error from recv().
           Returns whatever data was available (if >0 bytes).
           """
        data = self._socket.recv(size)
        if len(data) == 0: # server closed connection
            raise ProtocolError("Read of 0 bytes: connection closed")
        return data
        
    def _receive_len_data(self):
        """Reads self._recv_len_size bytes of data from the socket into
           self._recv_len_data
           This is done through class variables so in the case of
           an EAGAIN we can continue on a subsequent call.
           Raises a ProtocolError, a socket.error (which may be
           timeout or eagain), or reads until we have all data we need.
           """
        while self._recv_len_size > 0:
            new_data = self._receive_bytes(self._recv_len_size)
            self._recv_len_data += new_data
            self._recv_len_size -= len(new_data)

    def _receive_data(self):
        """Reads self._recv_size bytes of data from the socket into
           self._recv_data.
           This is done through class variables so in the case of
           an EAGAIN we can continue on a subsequent call.
           Raises a ProtocolError, a socket.error (which may be
           timeout or eagain), or reads until we have all data we need.
        """
        while self._recv_size > 0:
            new_data = self._receive_bytes(self._recv_size)
            self._recv_data += new_data
            self._recv_size -= len(new_data)

    def _receive_full_buffer(self, nonblock):
        if nonblock:
            self._socket.setblocking(0)
        else:
            self._socket.setblocking(1)
            if self._socket_timeout == 0.0:
                self._socket.settimeout(None)
            else:
                self._socket.settimeout(self._socket_timeout)

        try:
            # we might be in a call following an EAGAIN, in which case
            # we simply continue. In the first case, either
            # recv_size or recv_len size are not zero
            # they may never both be non-zero (we are either starting
            # a full read, or continuing one of the reads
            assert self._recv_size == 0 or self._recv_len_size == 0
            
            if self._recv_size == 0:
                if self._recv_len_size == 0:
                    # both zero, start a new full read
                    self._recv_len_size = 4
                    self._recv_len_data = bytearray()
                self._receive_len_data()

                self._recv_size = struct.unpack('>I', self._recv_len_data)[0]
                self._recv_data = bytearray()
            self._receive_data()

            # no EAGAIN, so copy data and reset internal counters
            data = self._recv_data

            self._recv_len_size = 0
            self._recv_size = 0

            return (data)

        except socket.timeout:
            raise SessionTimeout("recv() on cc session timed out")
        except socket.error as se:
            # Only keep data in case of EAGAIN
            if se.errno == errno.EAGAIN:
                return None
            # unknown state otherwise, best to drop data
            self._recv_len_size = 0
            self._recv_size = 0
            # ctrl-c can result in EINTR, return None to prevent
            # stacktrace output
            if se.errno == errno.EINTR:
                return None
            raise se

    def _next_sequence(self):
        self._sequence += 1
        return self._sequence

    def group_subscribe(self, group, instance = "*"):
        self.sendmsg({
            "type": "subscribe",
            "group": group,
            "instance": instance,
        })

    def group_unsubscribe(self, group, instance = "*"):
        self.sendmsg({
            "type": "unsubscribe",
            "group": group,
            "instance": instance,
        })

    def group_sendmsg(self, msg, group, instance = "*", to = "*"):
        seq = self._next_sequence()
        self.sendmsg({
            "type": "send",
            "from": self._lname,
            "to": to,
            "group": group,
            "instance": instance,
            "seq": seq,
        }, isc.cc.message.to_wire(msg))
        return seq

    def has_queued_msgs(self):
        return len(self._queue) > 0

    def group_recvmsg(self, nonblock = True, seq = None):
        env, msg  = self.recvmsg(nonblock, seq)
        if env == None:
            # return none twice to match normal return value
            # (so caller won't get a type error on no data)
            return (None, None)
        return (msg, env)

    def group_reply(self, routing, msg):
        seq = self._next_sequence()
        self.sendmsg({
            "type": "send",
            "from": self._lname,
            "to": routing["from"],
            "group": routing["group"],
            "instance": routing["instance"],
            "seq": seq,
            "reply": routing["seq"],
        }, isc.cc.message.to_wire(msg))
        return seq

    def set_timeout(self, milliseconds):
        """Sets the socket timeout for blocking reads to the given
           number of milliseconds"""
        self._socket_timeout = milliseconds / 1000.0

    def get_timeout(self):
        """Returns the current timeout for blocking reads (in milliseconds)"""
        return self._socket_timeout * 1000.0

if __name__ == "__main__":
    import doctest
    doctest.testmod()