#!/usr/bin/env python

from flask import Flask
from flask import request, render_template
from flask.ext.sqlalchemy import SQLAlchemy
from flask.ext.script import Server, Manager
from flask.ext.migrate import Migrate, MigrateCommand
#from flask import session, request, url_for, redirect, render_template

import netaddr
from netaddr import IPAddress, IPNetwork, IPSet
# Hack for python3
from netaddr.strategy.ipv4 import packed_to_int as unpack_v4
from netaddr.strategy.ipv6 import packed_to_int as unpack_v6
import socket
from datetime import datetime, timedelta
from uuid import uuid4

# Address ranges treated as belonging to dn42; the not_dn42 template filter
# renders addresses inside these ranges as an empty string.
# NOTE(review): dn42 documents 172.20.0.0/14 for IPv4 and fd00::/8 for IPv6;
# this list is narrower — confirm whether that is intentional.
DN42 = IPSet(['172.22.0.0/15', '172.31.0.0/16', '10.0.0.0/8'])

app = Flask(__name__)
# Settings used in this file (MAX_AGE, PEERFINDER_DN42, SMTP_SERVER,
# FROM_ADDRESS, ADMINS) are loaded from config.py.
app.config.from_pyfile('config.py')
db = SQLAlchemy(app)
# Wires Flask-Migrate to this app and database (exposed as the "db" command).
migrate = Migrate(app, db)

# Command-line entry point: "runserver" listens on all interfaces on port
# 8888, "db" gives access to the migration commands.
manager = Manager(app)
manager.add_command("runserver", Server(host='0.0.0.0', port=8888))
manager.add_command("db", MigrateCommand)


def unpack(ip):
    """Convert a packed binary address back into its integer value.

    :param ip: packed address bytes as stored in ``Target.ip`` -- 4 bytes
        for IPv4, 16 bytes for IPv6.
    :returns: the address as an integer, suitable for ``IPAddress()``.
    :raises ValueError: if ``ip`` is neither 4 nor 16 bytes long.
    """
    length = len(ip)
    if length == 4:
        return unpack_v4(ip)
    if length == 16:
        return unpack_v6(ip)
    # Fail loudly here instead of returning None and letting the caller
    # fail later with a less obvious error.
    raise ValueError("packed address must be 4 or 16 bytes, got %d" % length)

def is_valid_ip(ip):
    """Tell whether *ip* is a literal IPv4 or IPv6 address (not a host name)."""
    if netaddr.valid_ipv4(ip):
        return True
    return netaddr.valid_ipv6(ip)

def is_forbidden_ip(ip):
    """Tell whether the netaddr address *ip* must be refused as a ping target.

    Link-local, loopback, multicast and reserved addresses are refused, as is
    anything inside 0.0.0.0/8.
    """
    if ip.is_link_local() or ip.is_loopback() or ip.is_multicast():
        return True
    if ip.is_reserved():
        return True
    # 0.0.0.0/8 is reserved, but is_reserved() does not report it, so it is
    # tested explicitly.
    return ip in IPNetwork('0.0.0.0/8')

def resolve_name(hostname):
    """Resolve *hostname* to the list of distinct IP address strings.

    Addresses keep the order in which getaddrinfo() first reports them;
    repeats (one per socket type/protocol) are dropped.  An empty list is
    returned when the name does not resolve or cannot be encoded as a host
    name.
    """
    try:
        infos = socket.getaddrinfo(hostname, None)
    except (socket.gaierror, UnicodeError):
        # gaierror: the lookup failed.  UnicodeError: the IDNA encoding that
        # getaddrinfo applies to str host names rejected the input (for
        # example an empty or over-long label such as "a..b").
        return []
    addresses = []
    seen = set()
    for info in infos:
        address = info[4][0]
        if address not in seen:
            seen.add(address)
            addresses.append(address)
    return addresses


@app.template_filter()
def ipaddress_pp(addr):
    """Pretty-print an IP address for templates.

    IPv4-mapped IPv6 addresses (``::ffff:a.b.c.d``) are shown as plain IPv4;
    every other address is shown in its own family.
    """
    a = IPAddress(addr)
    # Check for v4-mapped explicitly: netaddr's ipv4() would also convert
    # IPv4-compatible addresses in ::/96, so e.g. ::1 would print as 0.0.0.1.
    if a.version == 6 and a.is_ipv4_mapped():
        return a.ipv4()
    return a

@app.template_filter()
def not_dn42(addr):
    """Template filter: blank out *addr* when it lies inside the DN42 ranges.

    Returns an empty string for dn42 addresses, otherwise the parsed
    ``IPAddress``.
    """
    ip = IPAddress(addr)
    return "" if ip in DN42 else ip

class Target(db.Model):
    """An IP address that participants are asked to ping.

    Results pages are addressed by a random ``unique_id`` rather than the
    numeric primary key, so they cannot be enumerated.
    """
    id = db.Column(db.Integer, primary_key=True)
    # Random UUID4 string used in the results URL (privacy reasons)
    unique_id = db.Column(db.String)
    # Packed binary address: 4 bytes for IPv4, 16 bytes for IPv6.
    # NOTE(review): backends that pad BINARY(16) values to full length would
    # make IPv4 rows unpack as IPv6 -- confirm on the deployed database.
    ip = db.Column(db.BINARY(length=16))
    # Local time at which the measurement request was submitted
    submitted = db.Column(db.DateTime)
    # Whether the target is listed on the public home page
    public = db.Column(db.Boolean)

    def __init__(self, ip, public=False):
        """Create a target from an address string (raises on unparsable input)."""
        self.unique_id = str(uuid4())
        self.ip = IPAddress(ip).packed
        self.submitted, self.public = datetime.now(), public

    def get_ip(self):
        """Return the stored address as a netaddr ``IPAddress``."""
        packed = self.ip
        return IPAddress(unpack(packed))

    def is_v4(self):
        """True when the stored address is IPv4."""
        version = self.get_ip().version
        return version == 4

    def is_v6(self):
        """True when the stored address is IPv6."""
        version = self.get_ip().version
        return version == 6

    def __repr__(self):
        return repr(self.get_ip())

    def __str__(self):
        return "%s" % self.get_ip()


# Many-to-many association between targets and participants, exposed as
# Participant.targets and Target.participants.  In this file a row is added
# when a participant reports a result for a target (report_result), and
# get_targets() uses it to stop offering that target to the participant.
# Nothing enforces uniqueness of the (target_id, participant_id) pair.
handled_targets = db.Table('handled_targets',
    db.Column('target_id', db.Integer, db.ForeignKey('target.id')),
    db.Column('participant_id', db.Integer, db.ForeignKey('participant.id'))
)


class Participant(db.Model):
    """A machine that fetches targets, pings them and reports results."""
    id = db.Column(db.Integer, primary_key=True)
    # Random UUID4 string: both the participant's identifier and its secret
    uuid = db.Column(db.String, unique=True)
    # Name of the machine
    name = db.Column(db.String)
    # Free-form contact information (nick, mail address, ...)
    contact = db.Column(db.String)
    # Optional country
    country = db.Column(db.String)
    # Free-form description (peering technology, DSL or fiber, etc)
    comment = db.Column(db.String)
    # Only active participants may fetch targets or report results
    active = db.Column(db.Boolean)
    # Targets this participant is linked to through handled_targets
    targets = db.relationship('Target',
        secondary=handled_targets,
        backref=db.backref('participants', lazy='dynamic'),
        lazy='dynamic')

    def __init__(self, name, contact, country, comment):
        """Register a new participant with a fresh UUID, initially inactive.

        NOTE(review): activation is not done in this file -- presumably an
        administrator flips ``active`` by hand.
        """
        self.uuid = str(uuid4())
        self.name, self.contact = name, contact
        self.country, self.comment = country, comment
        self.active = False

    def __str__(self):
        return "%s (%s)" % (self.name, self.contact)


class Result(db.Model):
    """One ping measurement reported by a participant for a target."""
    id = db.Column(db.Integer, primary_key=True)
    target_id = db.Column(db.Integer, db.ForeignKey('target.id'))
    target = db.relationship('Target',
                             backref=db.backref('results', lazy='dynamic'))
    participant_id = db.Column(db.Integer, db.ForeignKey('participant.id'))
    participant = db.relationship('Participant',
                                  backref=db.backref('results', lazy='dynamic'))
    # Local time at which the result was reported back to us
    date = db.Column(db.DateTime)
    # Average round-trip time, in milliseconds (mandatory)
    avgrtt = db.Column(db.Float)
    # The remaining measurements are optional and may be NULL
    minrtt = db.Column(db.Float)
    maxrtt = db.Column(db.Float)
    jitter = db.Column(db.Float)
    # Number of ping requests sent
    probes_sent = db.Column(db.Integer)
    # Number of probes that got a reply
    probes_received = db.Column(db.Integer)

    def __init__(self, target_id, participant_uuid, avgrtt, minrtt, maxrtt,
                 jitter, probes_sent, probes_received):
        """Build a result from raw (string) form values.

        Aborts with 404 when the target does not exist or the participant is
        unknown or inactive; raises ValueError when a number cannot be parsed.
        Optional values given as None are stored as NULL.
        """
        def optional(convert, raw):
            return None if raw is None else convert(raw)

        # Both lookups happen before any attribute is assigned.
        target = Target.query.get_or_404(int(target_id))
        participant = Participant.query.filter_by(uuid=participant_uuid,
                                                  active=True).first_or_404()
        self.target = target
        self.participant = participant
        self.date = datetime.now()
        self.avgrtt = float(avgrtt)
        self.minrtt = optional(float, minrtt)
        self.maxrtt = optional(float, maxrtt)
        self.jitter = optional(float, jitter)
        self.probes_sent = optional(int, probes_sent)
        self.probes_received = optional(int, probes_received)


def init_db():
    """Create the database tables for the models declared on ``db``.

    Called from the ``__main__`` block before the command manager runs.
    """
    db.create_all()


def get_targets(uuid):
    """Return a query of targets the given participant still has to measure.

    A target is excluded once it is linked to the participant in
    handled_targets; in this file that link is created by report_result.
    When the MAX_AGE setting (seconds) is non-zero, only targets submitted
    within that many seconds are kept.  Aborts with 404 if ``uuid`` does not
    belong to an active participant.  The query is returned unexecuted.
    """
    participant = Participant.query.filter_by(uuid=uuid, active=True).first_or_404()
    # Sub-query: ids of targets already linked to this participant.
    done_ids = (Target.query
                .join(handled_targets)
                .filter_by(participant_id=participant.id)
                .with_entities(Target.id))
    pending = Target.query.filter(~Target.id.in_(done_ids))
    max_age = app.config.get('MAX_AGE', 0)
    if max_age != 0:
        cutoff = datetime.now() - timedelta(seconds=max_age)
        pending = pending.filter(Target.submitted >= cutoff)
    return pending


@app.route('/')
def homepage():
    """Render the home page listing public targets, newest submission first."""
    # Use a column expression instead of the raw string "submitted DESC":
    # SQLAlchemy deprecated implicit coercion of textual ORDER BY fragments
    # to text() and newer releases reject it.
    public_targets = (Target.query
                      .filter_by(public=True)
                      .order_by(Target.submitted.desc())
                      .all())
    return render_template('home.html', targets=public_targets)

@app.route('/about')
def about():
    """Render the static ``about.html`` page."""
    return render_template('about.html')

@app.route('/participate')
def participate():
    """Render the ``participate.html`` page (participant sign-up)."""
    return render_template('participate.html')

@app.route('/privacy')
def privacy():
    """Render the static ``privacy.html`` page."""
    return render_template('privacy.html')

@app.route('/dev')
def dev():
    """Render the static ``dev.html`` page."""
    return render_template('dev.html')

@app.route('/static/<path:path>')
def static_proxy(path):
    """Serve ``path`` from the application's static folder.

    NOTE(review): Flask normally registers its own ``/static/<path:filename>``
    route for the static folder, so this handler may never be reached --
    confirm before relying on it.
    """
    # send_static_file will guess the correct MIME type
    return app.send_static_file(path)

@app.route('/robots.txt')
def robots():
    """Serve ``robots.txt`` from the static folder at the site root."""
    return app.send_static_file("robots.txt")

@app.route('/submit', methods=['POST'])
def submit_job():
    """Handle the target submission form.

    The ``target`` field is an IP address literal or a host name; a host name
    is resolved and one Target is created per resolved address.  Any
    non-empty ``public`` value lists the targets on the home page.  Nothing
    is stored if no address could be obtained, if an address cannot be
    parsed, or if any resulting address is forbidden by is_forbidden_ip().
    """
    if 'target' not in request.form:
        return "Invalid arguments"
    raw_target = request.form['target']
    query = raw_target.strip()
    public = bool(request.form.get('public'))
    if is_valid_ip(query):
        # Explicit IP
        addresses = [query]
    else:
        # DNS name, might give multiple IPs
        addresses = resolve_name(query)
    try:
        # Target() raises AddrFormatError for anything netaddr cannot parse
        # (e.g. scoped results such as "ff02::1%eth0"); both paths share this
        # handling so a parse failure never becomes a server error.
        targets = [Target(ip, public) for ip in addresses]
    except netaddr.core.AddrFormatError:
        return render_template('submit_error.html', target=raw_target)
    if not targets:
        return render_template('submit_error.html', target=raw_target)
    # Reject the whole submission if any address is forbidden.
    for candidate in targets:
        ip = candidate.get_ip()
        if is_forbidden_ip(ip):
            return render_template('submit_error_forbidden.html', ip=ip)
    db.session.add_all(targets)
    db.session.commit()
    return render_template('submit.html', targets=targets)

@app.route('/create/participant', methods=['POST'])
def create_participant():
    """Register a new (inactive) participant from the sign-up form.

    All of name, contact, country and comment must be present, and name must
    be non-empty; otherwise a plain "Invalid arguments" body is returned.
    """
    fields = ('name', 'contact', 'country', 'comment')
    form = request.form
    if not (all(field in form for field in fields) and form['name']):
        return "Invalid arguments"
    participant = Participant(*[form[field] for field in fields])
    db.session.add(participant)
    db.session.commit()
    return render_template('participant.html', participant=participant,
                           uuid=participant.uuid,
                           peerfinder=app.config["PEERFINDER_DN42"])

@app.route('/script.sh')
def get_script():
    """Serve the participant measurement script rendered from ``run.sh``."""
    body = render_template('run.sh', peerfinder=app.config["PEERFINDER_DN42"])
    headers = {'Content-Type': 'text/x-shellscript'}
    return body, 200, headers

@app.route('/cron.sh')
def get_cron():
    """Serve the shell script rendered from the ``cron.sh`` template."""
    body = render_template('cron.sh', peerfinder=app.config["PEERFINDER_DN42"])
    headers = {'Content-Type': 'text/x-shellscript'}
    return body, 200, headers

@app.route('/target/<uuid>/<family>')
@app.route('/target/<uuid>')
def get_next_target(uuid, family="any"):
    """Return the next target the given participant should ping.

    ``family`` is "any", "ipv4" or "ipv6".  The response body is
    "<target id> <address>", or an empty string when no target is left for
    this participant.  Which pending target is returned is unspecified (the
    query has no ORDER BY).
    """
    if family not in ("ipv4", "ipv6", "any"):
        return "Invalid family, should be 'any', 'ipv4' or 'ipv6'\n"
    candidates = get_targets(uuid)
    if family == "any":
        # Let the database stop after one row instead of loading them all.
        target = candidates.first()
    else:
        # The family is only known after unpacking the stored bytes, so this
        # filter runs in Python; stop at the first match.
        def matches(t):
            return t.is_v4() if family == "ipv4" else t.is_v6()
        target = next((t for t in candidates if matches(t)), None)
    if target is None:
        return ""
    return "{} {}".format(target.id, target)

@app.route('/result/report/<uuid>', methods=['POST'])
def report_result(uuid):
    """Store a ping result sent by the participant identified by ``uuid``.

    Required form fields: ``target`` (target id) and ``avgrtt``.  Optional:
    ``minrtt``, ``maxrtt``, ``jitter``, ``probes_sent``, ``probes_received``.
    The body is ``OK`` on success and ``Invalid arguments`` when a required
    field is missing or a numeric field cannot be parsed (both followed by a
    newline).  An unknown target or unknown/inactive participant aborts with
    404 from the Result constructor.
    """
    if not {'avgrtt', 'target'}.issubset(request.form):
        return "Invalid arguments\n"
    optional_args = [request.form.get(f) for f in
                     ('minrtt', 'maxrtt', 'jitter', 'probes_sent',
                      'probes_received')]
    try:
        result = Result(request.form['target'], uuid, request.form['avgrtt'],
                        *optional_args)
    except ValueError:
        # Malformed numbers from the client are a bad request, not a server
        # error; discard anything the half-built result may have attached.
        db.session.rollback()
        return "Invalid arguments\n"
    db.session.add(result)
    # Record that the participant has returned a result, so get_targets()
    # stops offering this target.  handled_targets has no uniqueness
    # constraint, so skip the insert if the pair is already recorded
    # (e.g. a repeated report).
    participant = result.participant
    target = result.target
    if participant.targets.filter_by(id=target.id).first() is None:
        participant.targets.append(target)
    db.session.commit()
    return "OK\n"

@app.route('/result/show/<target_uniqueid>')
def show_results(target_uniqueid):
    """Render every result for the target with the given public unique_id.

    Results are sorted by ascending average RTT.  Unknown ids abort with 404.
    """
    target = Target.query.filter_by(unique_id=target_uniqueid).first_or_404()
    # Order by the mapped column rather than the string 'avgrtt': a typo then
    # fails immediately with AttributeError instead of being resolved from
    # text by SQLAlchemy at query compile time.
    results = target.results.order_by(Result.avgrtt).all()
    return render_template('results.html', target=target, results=results)


if __name__ == '__main__':
    def _mail_errors_to_admins():
        """Attach an SMTP handler that emails ERROR log records to ADMINS.

        Does nothing when ADMINS is empty or unset.
        """
        import logging
        from logging.handlers import SMTPHandler
        recipients = app.config.get('ADMINS', [])
        if not recipients:
            return
        handler = SMTPHandler(app.config.get('SMTP_SERVER', "127.0.0.1"),
                              app.config.get('FROM_ADDRESS', "peerfinder@example.com"),
                              recipients, 'Peerfinder error')
        handler.setLevel(logging.ERROR)
        app.logger.addHandler(handler)

    # Error mails are only configured when the app is not in debug mode.
    if not app.debug:
        _mail_errors_to_admins()
    init_db()
    manager.run()