models.py 9.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293
  1. # -*- coding: utf-8 -*-
  2. from decimal import Decimal
  3. import json
  4. import os
  5. import itertools
  6. from datetime import datetime
  7. import pytz
  8. from . import db
  9. from .utils import dict_to_geojson, utcnow
  10. from flask import current_app
  11. import flask_sqlalchemy
  12. from sqlalchemy.types import TypeDecorator, VARCHAR, DateTime
  13. from sqlalchemy.ext.mutable import MutableDict
  14. from sqlalchemy import event
  15. import geoalchemy as geo
  16. import whoosh
  17. from whoosh import fields, index, qparser
  18. class fakefloat(float):
  19. def __init__(self, value):
  20. self._value = value
  21. def __repr__(self):
  22. return str(self._value)
  23. def defaultencode(o):
  24. if isinstance(o, Decimal):
  25. # Subclass float with custom repr?
  26. return fakefloat(o)
  27. raise TypeError(repr(o) + " is not JSON serializable")
  28. class JSONEncodedDict(TypeDecorator):
  29. "Represents an immutable structure as a json-encoded string."
  30. impl = VARCHAR
  31. def process_bind_param(self, value, dialect):
  32. if value is not None:
  33. value = json.dumps(value, default=defaultencode)
  34. return value
  35. def process_result_value(self, value, dialect):
  36. if value is not None:
  37. value = json.loads(value)
  38. return value
  39. class UTCDateTime(TypeDecorator):
  40. impl = DateTime
  41. def process_bind_param(self, value, engine):
  42. if value is not None:
  43. return value.astimezone(pytz.utc)
  44. def process_result_value(self, value, engine):
  45. if value is not None:
  46. return datetime(value.year, value.month, value.day,
  47. value.hour, value.minute, value.second,
  48. value.microsecond, tzinfo=pytz.utc)
  49. class ISP(db.Model):
  50. __tablename__ = 'isp'
  51. id = db.Column(db.Integer, primary_key=True)
  52. name = db.Column(db.String, nullable=False, index=True, unique=True)
  53. shortname = db.Column(db.String(12), index=True, unique=True)
  54. is_ffdn_member = db.Column(db.Boolean, default=False)
  55. is_disabled = db.Column(db.Boolean, default=False) # True = ISP will not appear
  56. json_url = db.Column(db.String)
  57. date_added = db.Column(UTCDateTime, default=utcnow)
  58. last_update_success = db.Column(UTCDateTime)
  59. last_update_attempt = db.Column(UTCDateTime)
  60. update_error_strike = db.Column(db.Integer, default=0) # if >= 3; then updates are disabled
  61. next_update = db.Column(UTCDateTime, default=utcnow)
  62. tech_email = db.Column(db.String)
  63. cache_info = db.Column(MutableDict.as_mutable(JSONEncodedDict))
  64. json = db.Column(MutableDict.as_mutable(JSONEncodedDict))
  65. covered_areas = db.relationship('CoveredArea', backref='isp')
  66. # covered_areas_query = db.relationship('CoveredArea', lazy='dynamic')
  67. registered_office = db.relationship('RegisteredOffice', uselist=False, backref='isp')
  68. def __init__(self, *args, **kwargs):
  69. super(ISP, self).__init__(*args, **kwargs)
  70. self.json = {}
  71. def pre_save(self, *args):
  72. if 'name' in self.json:
  73. assert self.name == self.json['name']
  74. if 'shortname' in self.json:
  75. assert self.shortname == self.json['shortname']
  76. if db.inspect(self).attrs.json.history.has_changes():
  77. self._sync_covered_areas()
  78. def _sync_covered_areas(self):
  79. """
  80. Called to synchronise between json['coveredAreas'] and the
  81. covered_areas table, when json was modified.
  82. """
  83. # delete current covered areas & registered office
  84. CoveredArea.query.filter_by(isp_id=self.id).delete()
  85. RegisteredOffice.query.filter_by(isp_id=self.id).delete()
  86. for ca_js in self.json.get('coveredAreas', []):
  87. ca = CoveredArea()
  88. ca.name = ca_js['name']
  89. area = ca_js.get('area')
  90. ca.area = db.func.CastToMultiPolygon(
  91. db.func.GeomFromGeoJSON(dict_to_geojson(area))
  92. ) if area else None
  93. self.covered_areas.append(ca)
  94. coords = self.json.get('coordinates')
  95. if coords:
  96. self.registered_office = RegisteredOffice(
  97. point=db.func.MakePoint(coords['longitude'], coords['latitude'], 4326)
  98. )
  99. def covered_areas_names(self):
  100. return [c['name'] for c in self.json.get('coveredAreas', [])]
  101. @property
  102. def is_local(self):
  103. return self.json_url is None
  104. @property
  105. def complete_name(self):
  106. if 'shortname' in self.json:
  107. return u'%s (%s)' % (self.json['shortname'], self.json['name'])
  108. else:
  109. return u'%s' % self.json['name']
  110. @staticmethod
  111. def str2date(_str):
  112. d = None
  113. try:
  114. d = datetime.strptime(_str, '%Y-%m-%d')
  115. except ValueError:
  116. pass
  117. if d is None:
  118. try:
  119. d = datetime.strptime(_str, '%Y-%m')
  120. except ValueError:
  121. pass
  122. return d
  123. def __repr__(self):
  124. return u'<ISP %r>' % (self.shortname if self.shortname else self.name,)
  125. class CoveredArea(db.Model):
  126. __tablename__ = 'covered_areas'
  127. id = db.Column(db.Integer, primary_key=True)
  128. isp_id = db.Column(db.Integer, db.ForeignKey('isp.id'))
  129. name = db.Column(db.String)
  130. area = geo.GeometryColumn(geo.MultiPolygon(2))
  131. area_geojson = db.column_property(db.func.AsGeoJSON(db.literal_column('area')), deferred=True)
  132. @classmethod
  133. def containing(cls, coords):
  134. """
  135. Return CoveredAreas containing point (lat,lon)
  136. """
  137. return cls.query.filter(
  138. cls.area != None,
  139. cls.area.gcontains(db.func.MakePoint(coords[1], coords[0])) == 1
  140. )
  141. def __repr__(self):
  142. return u'<CoveredArea %r>' % (self.name,)
  143. geo.GeometryDDL(CoveredArea.__table__)
  144. class RegisteredOffice(db.Model):
  145. __tablename__ = 'registered_offices'
  146. id = db.Column(db.Integer, primary_key=True)
  147. isp_id = db.Column(db.Integer, db.ForeignKey('isp.id'))
  148. point = geo.GeometryColumn(geo.Point(0))
  149. geo.GeometryDDL(RegisteredOffice.__table__)
  150. @event.listens_for(db.metadata, 'before_create')
  151. def init_spatialite_metadata(target, conn, **kwargs):
  152. conn.execute('SELECT InitSpatialMetaData(1)')
  153. def pre_save_hook(sess):
  154. for v in itertools.chain(sess.new, sess.dirty):
  155. if hasattr(v, 'pre_save') and hasattr(v.pre_save, '__call__'):
  156. v.pre_save(sess)
  157. class ISPWhoosh(object):
  158. """
  159. Helper class to index the ISP model with Whoosh to allow full-text search
  160. """
  161. schema = fields.Schema(
  162. id=fields.ID(unique=True, stored=True),
  163. is_ffdn_member=fields.BOOLEAN(),
  164. is_disabled=fields.BOOLEAN(),
  165. name=fields.TEXT(),
  166. shortname=fields.TEXT(),
  167. description=fields.TEXT(),
  168. covered_areas=fields.KEYWORD(scorable=True, commas=True, lowercase=True),
  169. step=fields.NUMERIC(signed=False),
  170. )
  171. primary_key = schema._fields['id']
  172. @staticmethod
  173. def get_index_dir():
  174. return current_app.config.get('WHOOSH_INDEX_DIR', 'whoosh')
  175. @classmethod
  176. def get_index(cls):
  177. idxdir = cls.get_index_dir()
  178. if index.exists_in(idxdir):
  179. idx = index.open_dir(idxdir)
  180. else:
  181. if not os.path.exists(idxdir):
  182. os.makedirs(idxdir)
  183. idx = index.create_in(idxdir, cls.schema)
  184. return idx
  185. @classmethod
  186. def _search(cls, s, terms):
  187. return s.search(qparser.MultifieldParser([
  188. 'name', 'shortname', 'description', 'covered_areas'
  189. ], schema=cls.schema).parse(terms),
  190. mask=whoosh.query.Term('is_disabled', True))
  191. @classmethod
  192. def search(cls, terms):
  193. with ISPWhoosh.get_index().searcher() as s:
  194. sres = cls._search(s, terms)
  195. ranks = {}
  196. for rank, r in enumerate(sres):
  197. ranks[r['id']] = rank
  198. if not len(ranks):
  199. return []
  200. _res = ISP.query.filter(ISP.id.in_(ranks.keys()))
  201. return sorted(_res, key=lambda r: ranks[r.id])
  202. @classmethod
  203. def update_document(cls, writer, model):
  204. kw = {
  205. 'id': unicode(model.id),
  206. '_stored_id': model.id,
  207. 'is_ffdn_member': model.is_ffdn_member,
  208. 'is_disabled': model.is_disabled,
  209. 'name': model.name,
  210. 'shortname': model.shortname,
  211. 'description': model.json.get('description'),
  212. 'covered_areas': ','.join(model.covered_areas_names()),
  213. 'step': model.json.get('progressStatus')
  214. }
  215. writer.update_document(**kw)
  216. @classmethod
  217. def _after_flush(cls, app, changes):
  218. isp_changes = []
  219. for change in changes:
  220. if change[0].__class__ == ISP:
  221. update = change[1] in ('update', 'insert')
  222. isp_changes.append((update, change[0]))
  223. if not len(changes):
  224. return
  225. idx = cls.get_index()
  226. with idx.writer() as writer:
  227. for update, model in isp_changes:
  228. if update:
  229. cls.update_document(writer, model)
  230. else:
  231. writer.delete_by_term(cls.primary_key, model.id)
  232. flask_sqlalchemy.models_committed.connect(ISPWhoosh._after_flush)
  233. event.listen(flask_sqlalchemy.Session, 'before_commit', pre_save_hook)