from flask.sessions import SessionInterface, SessionMixin
from werkzeug.datastructures import CallbackDict

from sqlalchemy import Table, Column, String, LargeBinary, DateTime,\
    select

from random import SystemRandom, randrange
import string
from datetime import datetime, timedelta
import cPickle

# Module-wide SystemRandom instance (draws from os.urandom) used to pick the
# characters of new session ids in MySessionInterface; a predictable PRNG would
# make session ids guessable.
# NOTE: inside this module the name ``random`` refers to this instance, not to
# the stdlib module; ``randrange`` imported above is still the plain
# module-level (non-cryptographic) function.
random = SystemRandom()


class SQLSession(CallbackDict, SessionMixin):
    """Server-side session whose contents are stored in one SQL table row.

    The dict contents are pickled into the ``value`` column of ``table``,
    keyed by ``session_id``. Any mutation of the dict sets ``modified``
    (through the CallbackDict update hook), so the session interface can
    tell whether :meth:`save` needs to run.

    :param sid: the session id carried in the client cookie.
    :param db: object with an ``execute`` method; the session interface in
        this module passes a SQLAlchemy engine.
    :param table: SQLAlchemy ``Table`` with ``session_id``, ``expire`` and
        ``value`` columns.
    :param new: ``True`` when no row exists yet for ``sid``.
    :param initial: initial session contents (e.g. the unpickled row data).
    :param lifetime: ``timedelta`` added to the current UTC time and written
        to ``expire`` on every save; defaults to :attr:`DEFAULT_LIFETIME`.
    """

    #: Validity window written to the ``expire`` column when none is given.
    DEFAULT_LIFETIME = timedelta(hours=1)

    def __init__(self, sid, db, table, new=False, initial=None, lifetime=None):
        self.sid = sid
        self.db = db
        self.table = table
        self.modified = False
        self.new = new
        self.lifetime = self.DEFAULT_LIFETIME if lifetime is None else lifetime

        # CallbackDict calls the hook with the dict itself, i.e. this very
        # session object, after every mutating operation.
        def _on_update(session):
            session.modified = True
        super(SQLSession, self).__init__(initial, _on_update)

    def _row_values(self):
        """Return the columns to persist: a fresh expiry and the pickled data."""
        return {
            'expire': datetime.utcnow() + self.lifetime,
            'value': cPickle.dumps(dict(self), -1),
        }

    def save(self):
        """Write the session contents to the database and extend its expiry.

        A new session is inserted. An existing session is updated in place;
        if that UPDATE matches no row (for instance because the row was purged
        as expired after this session was loaded), the row is inserted again
        instead of silently discarding the changes.
        """
        values = self._row_values()
        if not self.new:
            result = self.db.execute(self.table.update(
                self.table.c.session_id == self.sid,
                values
            ))
            if result.rowcount:
                return
        values['session_id'] = self.sid
        self.db.execute(self.table.insert(values))
        self.new = False


class MySessionInterface(SessionInterface):
    """Flask session interface that keeps session data server-side in SQL.

    Only a random session id travels in the cookie; the pickled session
    contents live in the ``flask_sessions`` table, which is registered on
    ``db.metadata`` (this class does not create the table itself).

    :param db: object exposing ``metadata`` and ``engine`` — presumably a
        Flask-SQLAlchemy ``SQLAlchemy`` instance; confirm at the call site.
    :param cleanup_chance: expired rows are purged on about one in
        ``cleanup_chance`` saves (a positive int; ``1`` purges on every save,
        ``0``/``None`` disables purging). Defaults to 20.
    """

    #: Characters used for generated session ids.
    SID_ALPHABET = string.ascii_letters + string.digits
    #: Length of generated session ids; also the width of the key column.
    SID_LENGTH = 32
    #: Default purge frequency, also used by subclasses that skip __init__.
    cleanup_chance = 20

    def __init__(self, db, cleanup_chance=20):
        self.db = db
        self.cleanup_chance = cleanup_chance

        self.table = Table('flask_sessions', db.metadata,
            Column('session_id', String(self.SID_LENGTH), primary_key=True),
            Column('expire', DateTime, index=True),
            Column('value', LargeBinary, nullable=False)
        )

    def open_session(self, app, request):
        """Load the session named by the request cookie, or start a new one.

        A cookie id with no unexpired row is ignored and a freshly generated
        id is used instead, so a client cannot choose its own session id.
        """
        # Same config key that the removed ``app.session_cookie_name`` read.
        sid = request.cookies.get(app.config['SESSION_COOKIE_NAME'])
        if sid:
            data = self._load(sid)
            if data is not None:
                return SQLSession(sid, self.db.engine, self.table, False, data)

        return SQLSession(self._new_sid(), self.db.engine, self.table, True)

    def _load(self, sid):
        """Return the unpickled contents of unexpired session ``sid``, or None."""
        table = self.table
        row = self.db.engine.execute(select(
            [table.c.value],
            (table.c.session_id == sid) & (table.c.expire > datetime.utcnow())
        )).first()
        if row is None:
            return None
        # The blob is written by SQLSession.save, never taken from the client;
        # unpickling it is only safe while nothing else can write this table.
        return cPickle.loads(row[0])

    def _new_sid(self):
        """Return a random session id that is not yet present in the table."""
        key = self.table.c.session_id
        while True:
            sid = ''.join(random.choice(self.SID_ALPHABET)
                          for _ in range(self.SID_LENGTH))
            # Probe the key only; there is no need to fetch the pickled blob.
            taken = self.db.engine.execute(select([key], key == sid)).first()
            if taken is None:
                return sid

    def _purge_expired(self):
        """Delete every session row whose expiry time has passed."""
        self.db.engine.execute(
            self.table.delete(self.table.c.expire <= datetime.utcnow()))

    def save_session(self, app, session, response):
        """Persist ``session`` when needed and (re)issue the session cookie.

        An empty session that was never modified is skipped entirely: no row
        is written and no cookie is set.
        """
        if not session and not session.modified:
            return  # empty/unused session

        # NOTE(review): the row's ``expire`` only moves forward when the
        # session is saved (i.e. modified), while the cookie below is reissued
        # on every response. A session that is read but never written can lose
        # its row while the client's cookie is still valid.
        if session.modified:
            session.save()

        # Probabilistic garbage collection of expired rows.
        if self.cleanup_chance and randrange(self.cleanup_chance) == 0:
            self._purge_expired()

        response.set_cookie(app.config['SESSION_COOKIE_NAME'], session.sid,
                            expires=self.get_expiration_time(app, session),
                            domain=self.get_cookie_domain(app),
                            path=self.get_cookie_path(app),
                            secure=self.get_cookie_secure(app),
                            httponly=self.get_cookie_httponly(app))