Browse Source

Avoid duplicate code for fetching jobs

Baptiste Jonglez 10 years ago
parent
commit
8f5494dc5e
1 changed files with 13 additions and 17 deletions
  1. 13 17
      peerfinder.py

+ 13 - 17
peerfinder.py

@@ -237,25 +237,21 @@ def get_script():
     r = render_template('run.sh', peerfinder=app.config["PEERFINDER_DN42"])
     r = render_template('run.sh', peerfinder=app.config["PEERFINDER_DN42"])
     return r, 200, {'Content-Type': 'text/x-shellscript'}
     return r, 200, {'Content-Type': 'text/x-shellscript'}
 
 
+@app.route('/target/<uuid>/<family>')
 @app.route('/target/<uuid>')
 @app.route('/target/<uuid>')
-def get_next_target(uuid):
-    """"Returns the next target to ping for the given participant"""
-    target = get_targets(uuid).first()
-    if target is not None:
-        return "{} {}".format(target.id, target)
+def get_next_target(uuid, family="any"):
+    """"Returns the next target to ping for the given participant and family
+    ("any", "ipv4", or "ipv6")"""
+    if family not in ("ipv4", "ipv6", "any"):
+        return "Invalid family, should be 'any', 'ipv4' or 'ipv6'\n"
+    if family == "any":
+        targets = get_targets(uuid).all()
     else:
     else:
-        return ""
-
-@app.route('/target/<uuid>/<family>')
-def get_next_target_family(uuid, family):
-    """Same as above, but for a specific family ("ipv4" or "ipv6")"""
-    if family not in ("ipv4", "ipv6"):
-        return "Invalid family, should be ipv4 or ipv6\n"
-    predicate = lambda t: t.is_v4() if family == "ipv4" else t.is_v6()
-    targets = [t for t in get_targets(uuid).all() if predicate(t)]
-    if not targets:
-        return ""
-    return "{} {}".format(targets[0].id, targets[0])
+        predicate = lambda t: t.is_v4() if family == "ipv4" else t.is_v6()
+        targets = [t for t in get_targets(uuid).all() if predicate(t)]
+    if targets:
+        return "{} {}".format(targets[0].id, targets[0])
+    return ""
 
 
 @app.route('/result/report/<uuid>', methods=['POST'])
 @app.route('/result/report/<uuid>', methods=['POST'])
 def report_result(uuid):
 def report_result(uuid):