# sessions.py (3.1 KB)
  1. from flask.sessions import SessionInterface, SessionMixin
  2. from werkzeug.datastructures import CallbackDict
  3. from sqlalchemy import Table, Column, String, LargeBinary, DateTime,\
  4. select
  5. from random import SystemRandom, randrange
  6. import string
  7. from datetime import datetime, timedelta
  8. import cPickle
# Module-level SystemRandom instance (OS-provided randomness); session ids
# are drawn from it via random.choice(). Note that the name deliberately
# takes the place of the stdlib module name ``random`` within this file.
random = SystemRandom()
  10. class SQLSession(CallbackDict, SessionMixin):
  11. def __init__(self, sid, db, table, new=False, initial=None):
  12. self.sid = sid
  13. self.db = db
  14. self.table = table
  15. self.modified = False
  16. self.new = new
  17. def _on_update(self):
  18. self.modified = True
  19. super(SQLSession, self).__init__(initial, _on_update)
  20. def save(self):
  21. if self.new:
  22. self.db.execute(self.table.insert({
  23. 'session_id': self.sid,
  24. 'expire': datetime.utcnow() + timedelta(hours=1),
  25. 'value': cPickle.dumps(dict(self), -1)
  26. }))
  27. self.new = False
  28. else:
  29. self.db.execute(self.table.update(
  30. self.table.c.session_id == self.sid,
  31. {
  32. 'expire': datetime.utcnow() + timedelta(hours=1),
  33. 'value': cPickle.dumps(dict(self), -1)
  34. }
  35. ))
  36. class MySessionInterface(SessionInterface):
  37. def __init__(self, db):
  38. self.db = db
  39. self.table = Table('flask_sessions', db.metadata,
  40. Column('session_id', String(32), primary_key=True),
  41. Column('expire', DateTime, index=True),
  42. Column('value', LargeBinary, nullable=False)
  43. )
  44. def open_session(self, app, request):
  45. sid = request.cookies.get(app.session_cookie_name)
  46. if sid:
  47. res = self.db.engine.execute(select([self.table.c.value], (self.table.c.session_id == sid) &
  48. (self.table.c.expire > datetime.utcnow()))).first()
  49. if res:
  50. return SQLSession(sid, self.db.engine, self.table, False, cPickle.loads(res[0]))
  51. while True:
  52. sid = ''.join(random.choice(string.ascii_letters + string.digits) for i in range(32))
  53. res = self.db.engine.execute(select([self.table.c.value], self.table.c.session_id == sid)).first()
  54. if not res:
  55. break
  56. return SQLSession(sid, self.db.engine, self.table, True)
  57. def save_session(self, app, session, response):
  58. if not session and not session.modified:
  59. return # empty/unused session
  60. if session.modified:
  61. session.save()
  62. # remove expired sessions.. or maybe not
  63. if randrange(20) % 20 == 0:
  64. self.db.engine.execute(self.table.delete(self.table.c.expire <= datetime.utcnow()))
  65. response.set_cookie(app.session_cookie_name, session.sid,
  66. expires=self.get_expiration_time(app, session),
  67. domain=self.get_cookie_domain(app),
  68. path=self.get_cookie_path(app),
  69. secure=self.get_cookie_secure(app),
  70. httponly=self.get_cookie_httponly(app))