diff --git a/confuse.py b/confuse.py index 63efa61..976b79c 100644 --- a/confuse.py +++ b/confuse.py @@ -861,30 +861,34 @@ def restore_yaml_comments(data, default_data): Only works with comments that are on one or more own lines, i.e. not next to a yaml mapping. """ + + def has_comment(line): + if not line: + return True + elif re.match(r'^\s*#.*$', line): + return True + else: + return False + comment_map = dict() + comment = "" default_lines = iter(default_data.splitlines()) for line in default_lines: - if not line: - comment = "\n" - elif line.startswith("#"): - comment = "{0}\n".format(line) - else: - continue - while True: - line = next(default_lines) - if line and not line.startswith("#"): - break + if has_comment(line): comment += "{0}\n".format(line) - key = line.split(':')[0].strip() - comment_map[key] = comment - out_lines = iter(data.splitlines()) + else: + key = line.split(':')[0].strip() + if comment != "": + comment_map[key] = comment + comment = "" out_data = "" + out_lines = iter(data.splitlines()) for line in out_lines: key = line.split(':')[0].strip() if key in comment_map: - out_data += comment_map[key] + out_data += comment_map.pop(key, None) out_data += "{0}\n".format(line) - return out_data + return out_data + comment # Main interface. diff --git a/test/test_dump.py b/test/test_dump.py index e40add7..4cd4362 100644 --- a/test/test_dump.py +++ b/test/test_dump.py @@ -59,6 +59,33 @@ def test_dump_sans_defaults(self): yaml = config.dump(full=False).strip() self.assertEqual(yaml, "baz: qux") + def test_restore_yaml_comments(self): + odict = confuse.OrderedDict() + odict['foo'] = 'bar' + odict['bar'] = 'baz' + + config = confuse.Configuration('myapp', read=False) + config.add({'key1': odict}) + config.add({'key2': odict}) + data = config.dump() + default_data = textwrap.dedent(""" + # Comment 1 + key1: + # Comment 2 + foo: bar + bar: baz + + key2: + foo: bar + bar: baz + + # TODO: add more keys + """) + self.assertEqual( + default_data, + confuse.restore_yaml_comments(data, default_data) + ) + class RedactTest(unittest.TestCase): def test_no_redaction(self):