models.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273
  1. # -*- coding: utf-8 -*-
  2. from decimal import Decimal
  3. import json
  4. import os
  5. import itertools
  6. from datetime import datetime
  7. from . import db, app
  8. from .utils import dict_to_geojson
  9. import flask_sqlalchemy
  10. from sqlalchemy.types import TypeDecorator, VARCHAR
  11. from sqlalchemy.ext.mutable import MutableDict
  12. from sqlalchemy import event
  13. import geoalchemy as geo
  14. import whoosh
  15. from whoosh import fields, index, qparser
  16. class fakefloat(float):
  17. def __init__(self, value):
  18. self._value = value
  19. def __repr__(self):
  20. return str(self._value)
  21. def defaultencode(o):
  22. if isinstance(o, Decimal):
  23. # Subclass float with custom repr?
  24. return fakefloat(o)
  25. raise TypeError(repr(o) + " is not JSON serializable")
  26. class JSONEncodedDict(TypeDecorator):
  27. "Represents an immutable structure as a json-encoded string."
  28. impl = VARCHAR
  29. def process_bind_param(self, value, dialect):
  30. if value is not None:
  31. value = json.dumps(value, default=defaultencode)
  32. return value
  33. def process_result_value(self, value, dialect):
  34. if value is not None:
  35. value = json.loads(value)
  36. return value
  37. class ISP(db.Model):
  38. __tablename__ = 'isp'
  39. id = db.Column(db.Integer, primary_key=True)
  40. name = db.Column(db.String, nullable=False, index=True, unique=True)
  41. shortname = db.Column(db.String(12), index=True, unique=True)
  42. is_ffdn_member = db.Column(db.Boolean, default=False)
  43. is_disabled = db.Column(db.Boolean, default=False) # True = ISP will not appear
  44. json_url = db.Column(db.String)
  45. last_update_success = db.Column(db.DateTime)
  46. last_update_attempt = db.Column(db.DateTime)
  47. update_error_strike = db.Column(db.Integer, default=0) # if >= 3; then updates are disabled
  48. next_update = db.Column(db.DateTime, default=datetime.now())
  49. tech_email = db.Column(db.String)
  50. cache_info = db.Column(MutableDict.as_mutable(JSONEncodedDict))
  51. json = db.Column(MutableDict.as_mutable(JSONEncodedDict))
  52. covered_areas = db.relationship('CoveredArea', backref='isp')
  53. # covered_areas_query = db.relationship('CoveredArea', lazy='dynamic')
  54. registered_office = db.relationship('RegisteredOffice', uselist=False, backref='isp')
  55. def __init__(self, *args, **kwargs):
  56. super(ISP, self).__init__(*args, **kwargs)
  57. self.json={}
  58. def pre_save(self, *args):
  59. if 'name' in self.json:
  60. assert self.name == self.json['name']
  61. if 'shortname' in self.json:
  62. assert self.shortname == self.json['shortname']
  63. if db.inspect(self).attrs.json.history.has_changes():
  64. self._sync_covered_areas()
  65. def _sync_covered_areas(self):
  66. """
  67. Called to synchronise between json['coveredAreas'] and the
  68. covered_areas table, when json was modified.
  69. """
  70. # delete current covered areas
  71. self.covered_areas.delete()
  72. RegisteredOffice.query.filter_by(isp_id=self.id).delete()
  73. for ca_js in self.json.get('coveredAreas', []):
  74. ca=CoveredArea()
  75. ca.name=ca_js['name']
  76. area=ca_js.get('area')
  77. ca.area=db.func.CastToMultiPolygon(
  78. db.func.GeomFromGeoJSON(dict_to_geojson(area))
  79. ) if area else None
  80. self.covered_areas.append(ca)
  81. coords=self.json.get('coordinates')
  82. if coords:
  83. self.registered_office=RegisteredOffice(
  84. point=db.func.MakePoint(coords['longitude'], coords['latitude'], 4326)
  85. )
  86. def covered_areas_names(self):
  87. return [c['name'] for c in self.json.get('coveredAreas', [])]
  88. @property
  89. def is_local(self):
  90. return self.json_url is None
  91. @property
  92. def complete_name(self):
  93. if 'shortname' in self.json:
  94. return u'%s (%s)'%(self.json['shortname'], self.json['name'])
  95. else:
  96. return u'%s'%self.json['name']
  97. @staticmethod
  98. def str2date(_str):
  99. d=None
  100. try:
  101. d=datetime.strptime(_str, '%Y-%m-%d')
  102. except ValueError:
  103. pass
  104. if d is None:
  105. try:
  106. d=datetime.strptime(_str, '%Y-%m')
  107. except ValueError:
  108. pass
  109. return d
  110. def __repr__(self):
  111. return u'<ISP %r>' % (self.shortname if self.shortname else self.name,)
  112. class CoveredArea(db.Model):
  113. __tablename__ = 'covered_areas'
  114. id = db.Column(db.Integer, primary_key=True)
  115. isp_id = db.Column(db.Integer, db.ForeignKey('isp.id'))
  116. name = db.Column(db.String)
  117. area = geo.GeometryColumn(geo.MultiPolygon(2))
  118. area_geojson = db.column_property(db.func.AsGeoJSON(db.literal_column('area')), deferred=True)
  119. @classmethod
  120. def containing(cls, coords):
  121. """
  122. Return CoveredAreas containing point (lat,lon)
  123. """
  124. return cls.query.filter(
  125. cls.area != None,
  126. cls.area.gcontains(db.func.MakePoint(coords[1], coords[0])) == 1
  127. )
  128. def __repr__(self):
  129. return u'<CoveredArea %r>' % (self.name,)
  130. geo.GeometryDDL(CoveredArea.__table__)
class RegisteredOffice(db.Model):
    """The registered-office location of an ISP, stored as a single point."""
    __tablename__ = 'registered_offices'
    id = db.Column(db.Integer, primary_key=True)
    isp_id = db.Column(db.Integer, db.ForeignKey('isp.id'))
    # NOTE(review): ``geo.Point(0)`` passes 0 as the first positional argument,
    # whereas CoveredArea uses ``geo.MultiPolygon(2)``; confirm which
    # geoalchemy parameter this sets and that 0 is intended.
    point = geo.GeometryColumn(geo.Point(0))
# Hand the table to geoalchemy so its geometry column DDL is managed.
geo.GeometryDDL(RegisteredOffice.__table__)
  137. @event.listens_for(db.metadata, 'before_create')
  138. def init_spatialite_metadata(target, conn, **kwargs):
  139. conn.execute('SELECT InitSpatialMetaData(1)')
  140. def pre_save_hook(sess):
  141. for v in itertools.chain(sess.new, sess.dirty):
  142. if hasattr(v, 'pre_save') and hasattr(v.pre_save, '__call__'):
  143. v.pre_save(sess)
  144. class ISPWhoosh(object):
  145. """
  146. Helper class to index the ISP model with Whoosh to allow full-text search
  147. """
  148. schema = fields.Schema(
  149. id=fields.ID(unique=True, stored=True),
  150. is_ffdn_member=fields.BOOLEAN(),
  151. is_disabled=fields.BOOLEAN(),
  152. name=fields.TEXT(),
  153. shortname=fields.TEXT(),
  154. description=fields.TEXT(),
  155. covered_areas=fields.KEYWORD(scorable=True, commas=True, lowercase=True),
  156. step=fields.NUMERIC(signed=False),
  157. )
  158. primary_key=schema._fields['id']
  159. @staticmethod
  160. def get_index_dir():
  161. return app.config.get('WHOOSH_INDEX_DIR', 'whoosh')
  162. @classmethod
  163. def get_index(cls):
  164. idxdir=cls.get_index_dir()
  165. if index.exists_in(idxdir):
  166. idx = index.open_dir(idxdir)
  167. else:
  168. if not os.path.exists(idxdir):
  169. os.makedirs(idxdir)
  170. idx = index.create_in(idxdir, cls.schema)
  171. return idx
  172. @classmethod
  173. def _search(cls, s, terms):
  174. return s.search(qparser.MultifieldParser([
  175. 'name', 'shortname', 'description', 'covered_areas'
  176. ], schema=cls.schema).parse(terms),
  177. mask=whoosh.query.Term('is_disabled', True))
  178. @classmethod
  179. def search(cls, terms):
  180. with ISPWhoosh.get_index().searcher() as s:
  181. sres=cls._search(s, terms)
  182. ranks={}
  183. for rank, r in enumerate(sres):
  184. ranks[r['id']]=rank
  185. if not len(ranks):
  186. return []
  187. _res=ISP.query.filter(ISP.id.in_(ranks.keys()))
  188. return sorted(_res, key=lambda r: ranks[r.id])
  189. @classmethod
  190. def update_document(cls, writer, model):
  191. kw={
  192. 'id': unicode(model.id),
  193. '_stored_id': model.id,
  194. 'is_ffdn_member': model.is_ffdn_member,
  195. 'is_disabled': model.is_disabled,
  196. 'name': model.name,
  197. 'shortname': model.shortname,
  198. 'description': model.json.get('description'),
  199. 'covered_areas': model.covered_areas_names(),
  200. 'step': model.json.get('progressStatus')
  201. }
  202. writer.update_document(**kw)
  203. @classmethod
  204. def _after_flush(cls, app, changes):
  205. isp_changes = []
  206. for change in changes:
  207. if change[0].__class__ == ISP:
  208. update = change[1] in ('update', 'insert')
  209. isp_changes.append((update, change[0]))
  210. if not len(changes):
  211. return
  212. idx=cls.get_index()
  213. with idx.writer() as writer:
  214. for update, model in isp_changes:
  215. if update:
  216. cls.update_document(writer, model)
  217. else:
  218. writer.delete_by_term(cls.primary_key, model.id)
# Keep the Whoosh full-text index in step with committed ISP changes.
flask_sqlalchemy.models_committed.connect(ISPWhoosh._after_flush)
# Before every session commit, call pre_save() on new/dirty objects.
event.listen(flask_sqlalchemy.Session, 'before_commit', pre_save_hook)