models.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297
  1. # -*- coding: utf-8 -*-
  2. from decimal import Decimal
  3. import json
  4. import os
  5. import itertools
  6. from datetime import datetime
  7. import pytz
  8. from . import db
  9. from .utils import dict_to_geojson, utcnow
  10. from flask import current_app
  11. import flask_sqlalchemy
  12. from sqlalchemy.types import TypeDecorator, VARCHAR, DateTime
  13. from sqlalchemy.ext.mutable import MutableDict
  14. from sqlalchemy import event
  15. import geoalchemy as geo
  16. import whoosh
  17. from whoosh import fields, index, qparser
  18. class fakefloat(float):
  19. def __init__(self, value):
  20. self._value = value
  21. def __repr__(self):
  22. return str(self._value)
  23. def defaultencode(o):
  24. if isinstance(o, Decimal):
  25. # Subclass float with custom repr?
  26. return fakefloat(o)
  27. raise TypeError(repr(o) + " is not JSON serializable")
  28. class JSONEncodedDict(TypeDecorator):
  29. "Represents an immutable structure as a json-encoded string."
  30. impl = VARCHAR
  31. def process_bind_param(self, value, dialect):
  32. if value is not None:
  33. value = json.dumps(value, default=defaultencode)
  34. return value
  35. def process_result_value(self, value, dialect):
  36. if value is not None:
  37. value = json.loads(value)
  38. return value
  39. class UTCDateTime(TypeDecorator):
  40. impl = DateTime
  41. def process_bind_param(self, value, engine):
  42. if value is not None:
  43. return value.astimezone(pytz.utc)
  44. def process_result_value(self, value, engine):
  45. if value is not None:
  46. return datetime(value.year, value.month, value.day,
  47. value.hour, value.minute, value.second,
  48. value.microsecond, tzinfo=pytz.utc)
  49. class ISP(db.Model):
  50. __tablename__ = 'isp'
  51. id = db.Column(db.Integer, primary_key=True)
  52. name = db.Column(db.String, nullable=False, index=True, unique=True)
  53. shortname = db.Column(db.String(12), index=True, unique=True)
  54. is_ffdn_member = db.Column(db.Boolean, default=False)
  55. is_disabled = db.Column(db.Boolean, default=False) # True = ISP will not appear
  56. json_url = db.Column(db.String)
  57. date_added = db.Column(UTCDateTime, default=utcnow)
  58. last_update_success = db.Column(UTCDateTime)
  59. last_update_attempt = db.Column(UTCDateTime)
  60. update_error_strike = db.Column(db.Integer, default=0) # if >= 3; then updates are disabled
  61. next_update = db.Column(UTCDateTime, default=utcnow)
  62. tech_email = db.Column(db.String)
  63. cache_info = db.Column(MutableDict.as_mutable(JSONEncodedDict))
  64. json = db.Column(MutableDict.as_mutable(JSONEncodedDict))
  65. covered_areas = db.relationship('CoveredArea', backref='isp')
  66. # covered_areas_query = db.relationship('CoveredArea', lazy='dynamic')
  67. registered_office = db.relationship('RegisteredOffice', uselist=False, backref='isp')
  68. def __init__(self, *args, **kwargs):
  69. super(ISP, self).__init__(*args, **kwargs)
  70. self.json = {}
  71. def pre_save(self, *args):
  72. if db.inspect(self).attrs.json.history.has_changes():
  73. self._sync_covered_areas()
  74. def _sync_covered_areas(self):
  75. """
  76. Called to synchronise between json['coveredAreas'] and the
  77. covered_areas table, when json was modified.
  78. """
  79. # delete current covered areas & registered office
  80. CoveredArea.query.filter_by(isp_id=self.id).delete()
  81. RegisteredOffice.query.filter_by(isp_id=self.id).delete()
  82. for ca_js in self.json.get('coveredAreas', []):
  83. ca = CoveredArea()
  84. ca.name = ca_js['name']
  85. area = ca_js.get('area')
  86. ca.area = db.func.CastToMultiPolygon(
  87. db.func.GeomFromGeoJSON(dict_to_geojson(area))
  88. ) if area else None
  89. self.covered_areas.append(ca)
  90. coords = self.json.get('coordinates')
  91. if coords:
  92. self.registered_office = RegisteredOffice(
  93. point=db.func.MakePoint(coords['longitude'], coords['latitude'], 4326)
  94. )
  95. def covered_areas_names(self):
  96. return [c['name'] for c in self.json.get('coveredAreas', [])]
  97. @property
  98. def is_local(self):
  99. return self.json_url is None
  100. @property
  101. def complete_name(self):
  102. if 'shortname' in self.json:
  103. return u'%s (%s)' % (self.json['shortname'], self.json['name'])
  104. else:
  105. return u'%s' % self.json['name']
  106. @staticmethod
  107. def str2date(_str):
  108. d = None
  109. try:
  110. d = datetime.strptime(_str, '%Y-%m-%d')
  111. except ValueError:
  112. pass
  113. if d is None:
  114. try:
  115. d = datetime.strptime(_str, '%Y-%m')
  116. except ValueError:
  117. pass
  118. return d
  119. def has_technology(self, technology):
  120. for i in self.json["coveredAreas"]:
  121. if technology in i["technologies"]:
  122. return True
  123. return False
  124. def __repr__(self):
  125. return u'<ISP %r>' % (self.shortname if self.shortname else self.name,)
  126. class CoveredArea(db.Model):
  127. __tablename__ = 'covered_areas'
  128. id = db.Column(db.Integer, primary_key=True)
  129. isp_id = db.Column(db.Integer, db.ForeignKey('isp.id'))
  130. name = db.Column(db.String)
  131. area = geo.GeometryColumn(geo.MultiPolygon(2))
  132. area_geojson = db.column_property(db.func.AsGeoJSON(db.literal_column('area')), deferred=True)
  133. @classmethod
  134. def containing(cls, coords):
  135. """
  136. Return CoveredAreas containing point (lat,lon)
  137. """
  138. return cls.query.filter(
  139. cls.area != None,
  140. cls.area.gcontains(db.func.MakePoint(coords[1], coords[0])) == 1
  141. )
  142. def __repr__(self):
  143. return u'<CoveredArea %r>' % (self.name,)
  144. geo.GeometryDDL(CoveredArea.__table__)
  145. class RegisteredOffice(db.Model):
  146. __tablename__ = 'registered_offices'
  147. id = db.Column(db.Integer, primary_key=True)
  148. isp_id = db.Column(db.Integer, db.ForeignKey('isp.id'))
  149. point = geo.GeometryColumn(geo.Point(0))
  150. geo.GeometryDDL(RegisteredOffice.__table__)
  151. @event.listens_for(db.metadata, 'before_create')
  152. def init_spatialite_metadata(target, conn, **kwargs):
  153. conn.execute('SELECT InitSpatialMetaData(1)')
  154. def pre_save_hook(sess):
  155. for v in itertools.chain(sess.new, sess.dirty):
  156. if hasattr(v, 'pre_save') and hasattr(v.pre_save, '__call__'):
  157. v.pre_save(sess)
  158. class ISPWhoosh(object):
  159. """
  160. Helper class to index the ISP model with Whoosh to allow full-text search
  161. """
  162. schema = fields.Schema(
  163. id=fields.ID(unique=True, stored=True),
  164. is_ffdn_member=fields.BOOLEAN(),
  165. is_disabled=fields.BOOLEAN(),
  166. name=fields.TEXT(),
  167. shortname=fields.TEXT(),
  168. description=fields.TEXT(),
  169. covered_areas=fields.KEYWORD(scorable=True, commas=True, lowercase=True),
  170. step=fields.NUMERIC(signed=False),
  171. )
  172. primary_key = schema._fields['id']
  173. @staticmethod
  174. def get_index_dir():
  175. return current_app.config.get('WHOOSH_INDEX_DIR', 'whoosh')
  176. @classmethod
  177. def get_index(cls):
  178. idxdir = cls.get_index_dir()
  179. if index.exists_in(idxdir):
  180. idx = index.open_dir(idxdir)
  181. else:
  182. if not os.path.exists(idxdir):
  183. os.makedirs(idxdir)
  184. idx = index.create_in(idxdir, cls.schema)
  185. return idx
  186. @classmethod
  187. def _search(cls, s, terms):
  188. return s.search(qparser.MultifieldParser([
  189. 'name', 'shortname', 'description', 'covered_areas'
  190. ], schema=cls.schema).parse(terms),
  191. mask=whoosh.query.Term('is_disabled', True))
  192. @classmethod
  193. def search(cls, terms):
  194. with ISPWhoosh.get_index().searcher() as s:
  195. sres = cls._search(s, terms)
  196. ranks = {}
  197. for rank, r in enumerate(sres):
  198. ranks[r['id']] = rank
  199. if not len(ranks):
  200. return []
  201. _res = ISP.query.filter(ISP.id.in_(ranks.keys()))
  202. return sorted(_res, key=lambda r: ranks[r.id])
  203. @classmethod
  204. def update_document(cls, writer, model):
  205. kw = {
  206. 'id': unicode(model.id),
  207. '_stored_id': model.id,
  208. 'is_ffdn_member': model.is_ffdn_member,
  209. 'is_disabled': model.is_disabled,
  210. 'name': model.name,
  211. 'shortname': model.shortname,
  212. 'description': model.json.get('description'),
  213. 'covered_areas': unicode(','.join(model.covered_areas_names())),
  214. 'step': model.json.get('progressStatus')
  215. }
  216. writer.update_document(**kw)
  217. @classmethod
  218. def _after_flush(cls, app, changes):
  219. try:
  220. isp_changes = []
  221. for change in changes:
  222. if change[0].__class__ == ISP:
  223. update = change[1] in ('update', 'insert')
  224. isp_changes.append((update, change[0]))
  225. if not len(changes):
  226. return
  227. idx = cls.get_index()
  228. with idx.writer() as writer:
  229. for update, model in isp_changes:
  230. if update:
  231. cls.update_document(writer, model)
  232. else:
  233. writer.delete_by_term(cls.primary_key, model.id)
  234. except Exception as e:
  235. print("Error while updating woosh db. Cause: {}".format(e))
  236. flask_sqlalchemy.models_committed.connect(ISPWhoosh._after_flush)
  237. event.listen(flask_sqlalchemy.Session, 'before_commit', pre_save_hook)