models.py 8.5 KB


  1. # -*- coding: utf-8 -*-
  2. from decimal import Decimal
  3. import json
  4. import os
  5. import itertools
  6. from datetime import datetime
  7. from . import db
  8. from .utils import dict_to_geojson
  9. from flask import current_app
  10. import flask_sqlalchemy
  11. from sqlalchemy.types import TypeDecorator, VARCHAR
  12. from sqlalchemy.ext.mutable import MutableDict
  13. from sqlalchemy import event
  14. import geoalchemy as geo
  15. import whoosh
  16. from whoosh import fields, index, qparser
  17. class fakefloat(float):
  18. def __init__(self, value):
  19. self._value = value
  20. def __repr__(self):
  21. return str(self._value)
  22. def defaultencode(o):
  23. if isinstance(o, Decimal):
  24. # Subclass float with custom repr?
  25. return fakefloat(o)
  26. raise TypeError(repr(o) + " is not JSON serializable")
  27. class JSONEncodedDict(TypeDecorator):
  28. "Represents an immutable structure as a json-encoded string."
  29. impl = VARCHAR
  30. def process_bind_param(self, value, dialect):
  31. if value is not None:
  32. value = json.dumps(value, default=defaultencode)
  33. return value
  34. def process_result_value(self, value, dialect):
  35. if value is not None:
  36. value = json.loads(value)
  37. return value
  38. class ISP(db.Model):
  39. __tablename__ = 'isp'
  40. id = db.Column(db.Integer, primary_key=True)
  41. name = db.Column(db.String, nullable=False, index=True, unique=True)
  42. shortname = db.Column(db.String(12), index=True, unique=True)
  43. is_ffdn_member = db.Column(db.Boolean, default=False)
  44. is_disabled = db.Column(db.Boolean, default=False) # True = ISP will not appear
  45. json_url = db.Column(db.String)
  46. last_update_success = db.Column(db.DateTime)
  47. last_update_attempt = db.Column(db.DateTime)
  48. update_error_strike = db.Column(db.Integer, default=0) # if >= 3; then updates are disabled
  49. next_update = db.Column(db.DateTime, default=datetime.now())
  50. tech_email = db.Column(db.String)
  51. cache_info = db.Column(MutableDict.as_mutable(JSONEncodedDict))
  52. json = db.Column(MutableDict.as_mutable(JSONEncodedDict))
  53. covered_areas = db.relationship('CoveredArea', backref='isp')
  54. # covered_areas_query = db.relationship('CoveredArea', lazy='dynamic')
  55. registered_office = db.relationship('RegisteredOffice', uselist=False, backref='isp')
  56. def __init__(self, *args, **kwargs):
  57. super(ISP, self).__init__(*args, **kwargs)
  58. self.json={}
  59. def pre_save(self, *args):
  60. if 'name' in self.json:
  61. assert self.name == self.json['name']
  62. if 'shortname' in self.json:
  63. assert self.shortname == self.json['shortname']
  64. if db.inspect(self).attrs.json.history.has_changes():
  65. self._sync_covered_areas()
  66. def _sync_covered_areas(self):
  67. """
  68. Called to synchronise between json['coveredAreas'] and the
  69. covered_areas table, when json was modified.
  70. """
  71. # delete current covered areas & registered office
  72. CoveredArea.query.filter_by(isp_id=self.id).delete()
  73. RegisteredOffice.query.filter_by(isp_id=self.id).delete()
  74. for ca_js in self.json.get('coveredAreas', []):
  75. ca=CoveredArea()
  76. ca.name=ca_js['name']
  77. area=ca_js.get('area')
  78. ca.area=db.func.CastToMultiPolygon(
  79. db.func.GeomFromGeoJSON(dict_to_geojson(area))
  80. ) if area else None
  81. self.covered_areas.append(ca)
  82. coords=self.json.get('coordinates')
  83. if coords:
  84. self.registered_office=RegisteredOffice(
  85. point=db.func.MakePoint(coords['longitude'], coords['latitude'], 4326)
  86. )
  87. def covered_areas_names(self):
  88. return [c['name'] for c in self.json.get('coveredAreas', [])]
  89. @property
  90. def is_local(self):
  91. return self.json_url is None
  92. @property
  93. def complete_name(self):
  94. if 'shortname' in self.json:
  95. return u'%s (%s)'%(self.json['shortname'], self.json['name'])
  96. else:
  97. return u'%s'%self.json['name']
  98. @staticmethod
  99. def str2date(_str):
  100. d=None
  101. try:
  102. d=datetime.strptime(_str, '%Y-%m-%d')
  103. except ValueError:
  104. pass
  105. if d is None:
  106. try:
  107. d=datetime.strptime(_str, '%Y-%m')
  108. except ValueError:
  109. pass
  110. return d
  111. def __repr__(self):
  112. return u'<ISP %r>' % (self.shortname if self.shortname else self.name,)
  113. class CoveredArea(db.Model):
  114. __tablename__ = 'covered_areas'
  115. id = db.Column(db.Integer, primary_key=True)
  116. isp_id = db.Column(db.Integer, db.ForeignKey('isp.id'))
  117. name = db.Column(db.String)
  118. area = geo.GeometryColumn(geo.MultiPolygon(2))
  119. area_geojson = db.column_property(db.func.AsGeoJSON(db.literal_column('area')), deferred=True)
  120. @classmethod
  121. def containing(cls, coords):
  122. """
  123. Return CoveredAreas containing point (lat,lon)
  124. """
  125. return cls.query.filter(
  126. cls.area != None,
  127. cls.area.gcontains(db.func.MakePoint(coords[1], coords[0])) == 1
  128. )
  129. def __repr__(self):
  130. return u'<CoveredArea %r>' % (self.name,)
  131. geo.GeometryDDL(CoveredArea.__table__)
class RegisteredOffice(db.Model):
    """Location of an ISP's registered office (at most one per ISP).

    Rows are rebuilt from the ISP's ``json['coordinates']`` by
    ``ISP._sync_covered_areas``.
    """
    __tablename__ = 'registered_offices'
    id = db.Column(db.Integer, primary_key=True)
    isp_id = db.Column(db.Integer, db.ForeignKey('isp.id'))
    # NOTE(review): the ``0`` passed to ``geo.Point`` is presumably the
    # geometry dimension argument -- confirm against the GeoAlchemy docs.
    point = geo.GeometryColumn(geo.Point(0))
# Presumably registers GeoAlchemy's DDL handling for the geometry column of
# this table -- confirm against the GeoAlchemy docs.
geo.GeometryDDL(RegisteredOffice.__table__)
  138. @event.listens_for(db.metadata, 'before_create')
  139. def init_spatialite_metadata(target, conn, **kwargs):
  140. conn.execute('SELECT InitSpatialMetaData(1)')
  141. def pre_save_hook(sess):
  142. for v in itertools.chain(sess.new, sess.dirty):
  143. if hasattr(v, 'pre_save') and hasattr(v.pre_save, '__call__'):
  144. v.pre_save(sess)
  145. class ISPWhoosh(object):
  146. """
  147. Helper class to index the ISP model with Whoosh to allow full-text search
  148. """
  149. schema = fields.Schema(
  150. id=fields.ID(unique=True, stored=True),
  151. is_ffdn_member=fields.BOOLEAN(),
  152. is_disabled=fields.BOOLEAN(),
  153. name=fields.TEXT(),
  154. shortname=fields.TEXT(),
  155. description=fields.TEXT(),
  156. covered_areas=fields.KEYWORD(scorable=True, commas=True, lowercase=True),
  157. step=fields.NUMERIC(signed=False),
  158. )
  159. primary_key=schema._fields['id']
  160. @staticmethod
  161. def get_index_dir():
  162. return current_app.config.get('WHOOSH_INDEX_DIR', 'whoosh')
  163. @classmethod
  164. def get_index(cls):
  165. idxdir=cls.get_index_dir()
  166. if index.exists_in(idxdir):
  167. idx = index.open_dir(idxdir)
  168. else:
  169. if not os.path.exists(idxdir):
  170. os.makedirs(idxdir)
  171. idx = index.create_in(idxdir, cls.schema)
  172. return idx
  173. @classmethod
  174. def _search(cls, s, terms):
  175. return s.search(qparser.MultifieldParser([
  176. 'name', 'shortname', 'description', 'covered_areas'
  177. ], schema=cls.schema).parse(terms),
  178. mask=whoosh.query.Term('is_disabled', True))
  179. @classmethod
  180. def search(cls, terms):
  181. with ISPWhoosh.get_index().searcher() as s:
  182. sres=cls._search(s, terms)
  183. ranks={}
  184. for rank, r in enumerate(sres):
  185. ranks[r['id']]=rank
  186. if not len(ranks):
  187. return []
  188. _res=ISP.query.filter(ISP.id.in_(ranks.keys()))
  189. return sorted(_res, key=lambda r: ranks[r.id])
  190. @classmethod
  191. def update_document(cls, writer, model):
  192. kw={
  193. 'id': unicode(model.id),
  194. '_stored_id': model.id,
  195. 'is_ffdn_member': model.is_ffdn_member,
  196. 'is_disabled': model.is_disabled,
  197. 'name': model.name,
  198. 'shortname': model.shortname,
  199. 'description': model.json.get('description'),
  200. 'covered_areas': model.covered_areas_names(),
  201. 'step': model.json.get('progressStatus')
  202. }
  203. writer.update_document(**kw)
  204. @classmethod
  205. def _after_flush(cls, app, changes):
  206. isp_changes = []
  207. for change in changes:
  208. if change[0].__class__ == ISP:
  209. update = change[1] in ('update', 'insert')
  210. isp_changes.append((update, change[0]))
  211. if not len(changes):
  212. return
  213. idx=cls.get_index()
  214. with idx.writer() as writer:
  215. for update, model in isp_changes:
  216. if update:
  217. cls.update_document(writer, model)
  218. else:
  219. writer.delete_by_term(cls.primary_key, model.id)
# Keep the Whoosh full-text index in step with committed ISP changes.
flask_sqlalchemy.models_committed.connect(ISPWhoosh._after_flush)
# Run each model's ``pre_save`` (via pre_save_hook) just before a session commits.
event.listen(flask_sqlalchemy.Session, 'before_commit', pre_save_hook)