Parcourir la source

[master] Merge branch 'trac3373'

Python data.merge function now does "deep" merges
Thomas Markwalder il y a 11 ans
Parent
commit
da3b0d4f36
2 fichiers modifiés avec 37 ajouts et 20 suppressions
  1. 15 5
      src/lib/python/isc/cc/data.py
  2. 22 15
      src/lib/python/isc/cc/tests/data_test.py

+ 15 - 5
src/lib/python/isc/cc/data.py

@@ -1,4 +1,4 @@
-# Copyright (C) 2010  Internet Systems Consortium.
+# Copyright (C) 2010-2014  Internet Systems Consortium.
 #
 # Permission to use, copy, modify, and distribute this software for any
 # purpose with or without fee is hereby granted, provided that the above
@@ -51,14 +51,24 @@ def remove_identical(a, b):
         del(a[id])
 
 def merge(orig, new):
-    """Merges the contents of new into orig, think recursive update()
-       orig and new must both be dicts. If an element value is None in
-       new it will be removed in orig."""
+    """Merges the contents of one dictionary into another.
+       The merge is done element by element, in order to recursivley merge
+       any elements which are themselves dictionaries. If an element value
+       is None in new it will be removed in orig. Previously this method
+       relied on dict.update but this does not do deep merges properly.
+       Raises a DataTypeError if either argument is not a dict"""
     if type(orig) != dict or type(new) != dict:
         raise DataTypeError("Not a dict in merge()")
-    orig.update(new)
+
+    for key in new.keys():
+        if ((key in orig) and (type(orig[key]) == dict)):
+            merge(orig[key], new[key])
+        else:
+            orig[key] = new[key]
+
     remove_null_items(orig)
 
+
 def remove_null_items(d):
     """Recursively removes all (key,value) pairs from d where the
        value is None"""

+ 22 - 15
src/lib/python/isc/cc/tests/data_test.py

@@ -34,37 +34,37 @@ class TestData(unittest.TestCase):
         c = {}
         data.remove_identical(a, b)
         self.assertEqual(a, c)
-    
+
         a = { "a": 1, "b": [ 1, 2 ] }
         b = {}
         c = { "a": 1, "b": [ 1, 2 ] }
         data.remove_identical(a, b)
         self.assertEqual(a, c)
-    
+
         a = { "a": 1, "b": [ 1, 2 ] }
         b = { "a": 1, "b": [ 1, 2 ] }
         c = {}
         data.remove_identical(a, b)
         self.assertEqual(a, c)
-    
+
         a = { "a": 1, "b": [ 1, 2 ] }
         b = { "a": 1, "b": [ 1, 3 ] }
         c = { "b": [ 1, 2 ] }
         data.remove_identical(a, b)
         self.assertEqual(a, c)
-    
+
         a = { "a": { "b": "c" } }
         b = {}
         c = { "a": { "b": "c" } }
         data.remove_identical(a, b)
         self.assertEqual(a, c)
-    
+
         a = { "a": { "b": "c" } }
         b = { "a": { "b": "c" } }
         c = {}
         data.remove_identical(a, b)
         self.assertEqual(a, c)
-    
+
         a = { "a": { "b": "c" } }
         b = { "a": { "b": "d" } }
         c = { "a": { "b": "c" } }
@@ -75,7 +75,7 @@ class TestData(unittest.TestCase):
                           a, 1)
         self.assertRaises(data.DataTypeError, data.remove_identical,
                           1, b)
-        
+
     def test_merge(self):
         d1 = { 'a': 'a', 'b': 1, 'c': { 'd': 'd', 'e': 2 } }
         d2 = { 'a': None, 'c': { 'd': None, 'e': 3, 'f': [ 1 ] } }
@@ -87,6 +87,13 @@ class TestData(unittest.TestCase):
         self.assertRaises(data.DataTypeError, data.merge, 1, d2)
         self.assertRaises(data.DataTypeError, data.merge, None, None)
 
+        # An example that failed when merge was relying on dict.update.
+        tnew = {'d2': {'port': 54000}}
+        torig = {'ifaces': ['p8p1'], 'db': {'type': 'memfile'}, 'd2': {'ip': '127.0.0.1', 'enable': True}}
+        tchk = {'ifaces': ['p8p1'], 'db': {'type': 'memfile'}, 'd2': {'ip': '127.0.0.1', 'enable': True, 'port': 54000}}
+        tmrg = torig
+        data.merge(tmrg, tnew)
+        self.assertEqual(tmrg, tchk)
 
     def test_split_identifier_list_indices(self):
         id, indices = data.split_identifier_list_indices('a')
@@ -103,15 +110,15 @@ class TestData(unittest.TestCase):
         id, indices = data.split_identifier_list_indices('a/b/c')
         self.assertEqual(id, 'a/b/c')
         self.assertEqual(indices, None)
-        
+
         id, indices = data.split_identifier_list_indices('a/b/c[1]')
         self.assertEqual(id, 'a/b/c')
         self.assertEqual(indices, [1])
-       
+
         id, indices = data.split_identifier_list_indices('a/b/c[1][2][3]')
         self.assertEqual(id, 'a/b/c')
         self.assertEqual(indices, [1, 2, 3])
-        
+
         id, indices = data.split_identifier_list_indices('a[0]/b[1]/c[2]')
         self.assertEqual(id, 'a[0]/b[1]/c')
         self.assertEqual(indices, [2])
@@ -124,7 +131,7 @@ class TestData(unittest.TestCase):
         self.assertRaises(data.DataTypeError, data.split_identifier_list_indices, 'a[0]a[1]')
 
         self.assertRaises(data.DataTypeError, data.split_identifier_list_indices, 1)
-        
+
 
     def test_find(self):
         d1 = { 'a': 'a', 'b': 1, 'c': { 'd': 'd', 'e': 2, 'more': { 'data': 'here' } } }
@@ -151,7 +158,7 @@ class TestData(unittest.TestCase):
         d3 = { 'a': [ { 'b': [ {}, { 'c': 'd' } ] } ] }
         self.assertEqual(data.find(d3, 'a[0]/b[1]/c'), 'd')
         self.assertRaises(data.DataNotFoundError, data.find, d3, 'a[1]/b[1]/c')
-        
+
     def test_set(self):
         d1 = { 'a': 'a', 'b': 1, 'c': { 'd': 'd', 'e': 2 } }
         d12 = { 'b': 1, 'c': { 'e': 3, 'f': [ 1 ] } }
@@ -170,7 +177,7 @@ class TestData(unittest.TestCase):
         self.assertEqual(d1, d14)
         data.set(d1, 'c/f[0]/g[1]', 3)
         self.assertEqual(d1, d15)
-        
+
         self.assertRaises(data.DataTypeError, data.set, d1, 1, 2)
         self.assertRaises(data.DataTypeError, data.set, 1, "", 2)
         self.assertRaises(data.DataTypeError, data.set, d1, 'c[1]', 2)
@@ -205,7 +212,7 @@ class TestData(unittest.TestCase):
         self.assertEqual(d3, { 'a': [ [ 1, 3 ] ] })
         data.unset(d3, 'a[0][1]')
         self.assertEqual(d3, { 'a': [ [ 1 ] ] })
-        
+
     def test_find_no_exc(self):
         d1 = { 'a': 'a', 'b': 1, 'c': { 'd': 'd', 'e': 2, 'more': { 'data': 'here' } } }
         self.assertEqual(data.find_no_exc(d1, ''), d1)
@@ -220,7 +227,7 @@ class TestData(unittest.TestCase):
         self.assertEqual(data.find_no_exc(None, 1), None)
         self.assertEqual(data.find_no_exc("123", ""), "123")
         self.assertEqual(data.find_no_exc("123", ""), "123")
-        
+
     def test_parse_value_str(self):
         self.assertEqual(data.parse_value_str("1"), 1)
         self.assertEqual(data.parse_value_str("true"), True)