Browse Source

Add an unit test, refactor the session handler

Gu1 11 years ago
parent
commit
265a66ea07
3 changed files with 45 additions and 9 deletions
  1. 1 1
      ffdnispdb/__init__.py
  2. 8 8
      ffdnispdb/sessions.py
  3. 36 0
      test_ffdnispdb.py

+ 1 - 1
ffdnispdb/__init__.py

@@ -11,7 +11,7 @@ app = Flask(__name__)
 app.config.from_object('config')
 babel = Babel(app)
 db = SQLAlchemy(app)
-app.session_interface = MySessionInterface(db.engine, db.metadata)
+app.session_interface = MySessionInterface(db)
 cache = NullCache()
 
 @event.listens_for(db.engine, "connect")

+ 8 - 8
ffdnispdb/sessions.py

@@ -45,10 +45,10 @@ class SQLSession(CallbackDict, SessionMixin):
 
 
 class MySessionInterface(SessionInterface):
-    def __init__(self, engine, metadata):
-        self.engine = engine
+    def __init__(self, db):
+        self.db = db
 
-        self.table = Table('flask_sessions', metadata,
+        self.table = Table('flask_sessions', db.metadata,
             Column('session_id', String(32), primary_key=True),
             Column('expire', DateTime, index=True),
             Column('value', LargeBinary, nullable=False)
@@ -57,18 +57,18 @@ class MySessionInterface(SessionInterface):
     def open_session(self, app, request):
         sid = request.cookies.get(app.session_cookie_name)
         if sid:
-            res=self.engine.execute(select([self.table.c.value], (self.table.c.session_id == sid) &
+            res=self.db.engine.execute(select([self.table.c.value], (self.table.c.session_id == sid) &
                                                                  (self.table.c.expire > datetime.now()))).first()
             if res:
-                return SQLSession(sid, self.engine, self.table, False, cPickle.loads(res[0]))
+                return SQLSession(sid, self.db.engine, self.table, False, cPickle.loads(res[0]))
 
         while True:
             sid=''.join(random.choice(string.ascii_letters+string.digits) for i in range(32))
-            res=self.engine.execute(select([self.table.c.value], self.table.c.session_id == sid)).first()
+            res=self.db.engine.execute(select([self.table.c.value], self.table.c.session_id == sid)).first()
             if not res:
                 break
 
-        return SQLSession(sid, self.engine, self.table, True)
+        return SQLSession(sid, self.db.engine, self.table, True)
 
     def save_session(self, app, session, response):
         if session.modified:
@@ -76,7 +76,7 @@ class MySessionInterface(SessionInterface):
 
         # remove expired sessions.. or maybe not
         if randrange(20) % 20 == 0:
-            self.engine.execute(self.table.delete(self.table.c.expire <= datetime.now()))
+            self.db.engine.execute(self.table.delete(self.table.c.expire <= datetime.now()))
 
         response.set_cookie(app.session_cookie_name, session.sid,
                             expires=self.get_expiration_time(app, session),

+ 36 - 0
test_ffdnispdb.py

@@ -0,0 +1,36 @@
+
+from ffdnispdb import app, db
+from ffdnispdb.models import ISP
+from flask import Flask
+from flask.ext.sqlalchemy import SQLAlchemy
+import unittest
+import os
+
+
+class TestCase(unittest.TestCase):
+
+    def setUp(self):
+        app.config['TESTING'] = True
+        app.config['WTF_CSRF_ENABLED'] = False
+        # Ugly, but should work in this context... ?
+        app.config['SQLALCHEMY_DATABASE_URI'] = 'sqlite://'
+        self.app = app.test_client()
+        db.create_all()
+
+    def tearDown(self):
+        db.drop_all()
+
+    def test_projectform(self):
+        resp = self.app.post('/create/form', data={
+            'name': 'Test',
+            'step': '1',
+            'covered_areas-0-name': 'Somewhere over the rainbow',
+            'covered_areas-0-technologies': 'dsl',
+            'covered_areas-0-technologies': 'ftth'
+        })
+        self.assertNotEqual(resp.location, None)
+        self.assertEqual(ISP.query.filter_by(name='Test').count(), 1)
+
+
+if __name__ == '__main__':
+    unittest.main()