diff --git a/lago/prefix.py b/lago/prefix.py
index e51722a5..ae7c44be 100644
--- a/lago/prefix.py
+++ b/lago/prefix.py
@@ -390,10 +390,11 @@ def _add_dns_records(self, conf, mgmts):
LOGGER.debug('Using network %s as main DNS server', dns_mgmt)
forward = conf['nets'][dns_mgmt].get('gw')
dns_records = {}
+
for net_name, net_spec in nets.iteritems():
dns_records.update(net_spec['mapping'].copy())
if net_name not in mgmts:
- net_spec['dns_forward'] = forward
+ net_spec['dns_forwarders'] = [{'addr': forward}]
for mgmt in mgmts:
if nets[mgmt].get('dns_records'):
diff --git a/lago/providers/libvirt/network.py b/lago/providers/libvirt/network.py
index 66e024c5..079bfc06 100644
--- a/lago/providers/libvirt/network.py
+++ b/lago/providers/libvirt/network.py
@@ -20,6 +20,7 @@
from future.builtins import super
from collections import defaultdict
import functools
+import itertools
import logging
import time
from copy import deepcopy
@@ -52,6 +53,10 @@ def name(self):
def gw(self):
return self._spec.get('gw')
+ @property
+ def subnet(self):
+ return self.gw().split('.')[2]
+
def mtu(self):
if self.libvirt_con.getLibVersion() > 3001001:
return self._spec.get('mtu', '1500')
@@ -145,50 +150,189 @@ def spec(self):
return deepcopy(self._spec)
-class NATNetwork(Network):
- def _generate_dns_forward(self, forward_ip):
- dns = ET.Element('dns', forwardPlainNames='yes')
- dns.append(ET.Element('forwarder', addr=forward_ip))
- return dns
+class LibvirtDNS(object):
+ """
+ This class represents the `dns` element in Libvirt's
+ Network XML.
+
+ This class has convenient "class methods" for generating
+ specific dns elements.
+
+ For more information please refer to Libvirt's documentation:
+ https://libvirt.org/formatnetwork.html
+
+ Attributes:
+ _dns(lxml.etree.Element): The root of the `dns` element
+ """
+
+ def __init__(self, enable=True, forward_plain_names=True):
+ """
+ Args:
+ enabled(bool): If false, don't create a dns server
+ forward_plain_names(bool): If false, names that are not FQDNs
+ will not be forwarded to the host's upstream server.
+ """
+ forward_plain_name = 'yes' if forward_plain_names else 'no'
+
+ if enable:
+ self._dns = ET.Element(
+ 'dns',
+ enable='yes',
+ forwardPlainNames=forward_plain_name,
+ )
+ else:
+ self._dns = ET.Element('dns', enable='no')
+
+ @classmethod
+ def generate_dns_forward(cls, forwarders):
+ """
+ Generate a dns server that forwards request to one or more dns servers.
+
+ Args:
+ forwarders(list of dicts): Each dict represents a `forwarder`
+ that will be added to the dns server. The dict's items
+ will be added as attributes to the forwarder element.
+ """
+ dns = cls()
+ dns.add_forwarders(*forwarders)
+
+ return dns.get_xml_object()
- def _generate_dns_disable(self):
- dns = ET.Element('dns', enable='no')
- return dns
+ @classmethod
+ def generate_dns_disable(cls):
+ """
+ Generate a disbaled dns server.
+ """
+ dns = cls(enable=False)
+
+ return dns.get_xml_object()
+
+ @classmethod
+ def generate_default_dns(cls):
+ """
+ Generate a default dns server.
+ Please refer to the default values of the `__init__` method
+ in order to see the properties of a default dns server.
+ """
+ return cls().get_xml_object()
+
+ @classmethod
+ def generate_main_dns(cls, records, forwarders, forward_plain_names):
+ """
+ Generate a dns server for Lago's management network.
+
+ Args:
+ records(list of tuples): For a tuple "t", t[0] is a hostname
+ and t[1] is its IP. Each tuple will be added as a `host`
+ element to the dns server.
+ forwarders(list of dicts): List of forwarders that will be added
+ ths dns server.
+ forward_plain_name(bool): If false, names that are not
+ FQDNs will not be forwarded to the host's upstream server.
+ """
+ dns = cls(forward_plain_names=forward_plain_names)
- def _generate_main_dns(self, records, subnet, forward_plain='no'):
- dns = ET.Element('dns', forwardPlainNames=forward_plain)
reverse_records = defaultdict(list)
- ipv6_prefix = self._ipv6_prefix(subnet=subnet)
- for hostname, ip in records.iteritems():
+ for hostname, ip in records:
reverse_records[ip] = reverse_records[ip] + [hostname]
+
for ip, hostnames in reverse_records.iteritems():
- record_ipv4 = ET.Element('host', ip=ip)
- record_ipv6 = ET.Element('host', ip=ipv6_prefix + ip)
- for hostname in sorted(hostnames):
- host = ET.Element('hostname')
- host.text = hostname
- record_ipv4.append(host)
- record_ipv6.append(deepcopy(host))
- dns.append(record_ipv4)
- dns.append(record_ipv6)
-
- return dns
+ dns.add_host(ip, *hostnames)
+
+ dns.add_forwarders(*forwarders)
+
+ return dns.get_xml_object()
+
+ def add_forwarders(self, *forwarders):
+ """
+ Add `forwarder(s)` to the dns server.
+
+ Args:
+ forwarders(dicts): One or more dicts that represents a forwader.
+ Each item in the dict will be mapped to an attribute of
+ the forwarder.
+ """
+ for forwarder in forwarders:
+ self._dns.append(ET.Element('forwarder', **forwarder))
+ def add_host(self, ip, *hostnames):
+ """
+ Add `host entry(s)` to the dns server.
+
+ Args:
+ ip(str): The host's IP address.
+ hostnames(str): One or more hostnames that will be mapped to `ip`.
+ """
+ host_element = ET.Element('host', ip=ip)
+
+ for hostname in sorted(hostnames):
+ hostname_element = ET.Element('hostname')
+ hostname_element.text = hostname
+ host_element.append(hostname_element)
+
+ self._dns.append(host_element)
+
+ def get_xml_object(self):
+ """
+ Returns:
+ (lxml.etree.Element): The dns server
+ """
+ return deepcopy(self._dns)
+
+
+class NATNetwork(Network):
def _ipv6_prefix(self, subnet, const='fd8f:1391:3a82:'):
return '{0}{1}::'.format(const, subnet)
+ def get_ipv6_dns_records(self, mapping):
+ """
+ Given a mapping between host names and an IPv4 addresses,
+ return a new mapping from hostnames and their IPv6.
+ The IPv6 address is gernerate from the host's IPv4.
+
+ Args:
+ mapping(dict): A mapping between host names and their
+ IPv4 addresses.
+
+ Returns:
+ (dict): A mapping between host names and their IPv6 addresses.
+ """
+ return {
+ hostname: self.ipv6_prefix + ip
+ for hostname, ip in mapping.items()
+ }
+
+ def get_ipv4_and_ipv6_dns_records(self, mapping_name):
+ """
+ Get a chain of tuples that represent a mapping between a hostname
+ and its IP address. The chain will include tuples for IPv4 and IPv6
+ addresses.
+
+ Args:
+ mapping_name(str): From which dict of this network spec the chain
+ should be built.
+
+ Returns:
+ (itertools.chain): A chain of tuples.
+ """
+ return itertools.chain(
+ self.spec[mapping_name].iteritems(),
+ self.get_ipv6_dns_records(self.spec[mapping_name]).iteritems(),
+ )
+
+ @property
+ def ipv6_prefix(self):
+ return self._ipv6_prefix(self.subnet)
+
def _libvirt_xml(self):
net_raw_xml = libvirt_utils.get_template('net_nat_template.xml')
-
- subnet = self.gw().split('.')[2]
- ipv6_prefix = self._ipv6_prefix(subnet=subnet)
mtu = self.mtu()
replacements = {
'@NAME@': self._libvirt_name(),
'@BR_NAME@': ('%s-nic' % self._libvirt_name())[:12],
'@GW_ADDR@': self.gw(),
- '@SUBNET@': subnet
+ '@SUBNET@': self.subnet
}
for k, v in replacements.items():
net_raw_xml = net_raw_xml.replace(k, v, 1)
@@ -223,8 +367,10 @@ def make_ipv4(last):
dhcpv6.append(
ET.Element(
'range',
- start=ipv6_prefix + make_ipv4(self._spec['dhcp']['start']),
- end=ipv6_prefix + make_ipv4(self._spec['dhcp']['end']),
+ start=self.ipv6_prefix +
+ make_ipv4(self._spec['dhcp']['start']),
+ end=self.ipv6_prefix +
+ make_ipv4(self._spec['dhcp']['end']),
)
)
@@ -247,7 +393,7 @@ def make_ipv4(last):
ET.Element(
'host',
id='0:3:0:1:' + utils.ipv4_to_mac(ip4),
- ip=ipv6_prefix + ip4,
+ ip=self.ipv6_prefix + ip4,
name=hostname
)
)
@@ -260,16 +406,23 @@ def make_ipv4(last):
localOnly='yes'
)
net_xml.append(domain_xml)
+
net_xml.append(
- self._generate_main_dns(self._spec['dns_records'], subnet)
+ LibvirtDNS.generate_main_dns(
+ self.get_ipv4_and_ipv6_dns_records('dns_records'),
+ self._spec.get('dns_forwarders', []),
+ forward_plain_names=False
+ )
)
else:
if self.libvirt_con.getLibVersion() < 2002000:
net_xml.append(
- self._generate_dns_forward(self._spec['dns_forward'])
+ LibvirtDNS.generate_dns_forward(
+ self._spec['dns_forwarders']
+ )
)
else:
- net_xml.append(self._generate_dns_disable())
+ net_xml.append(LibvirtDNS.generate_dns_disable())
else:
LOGGER.debug(
'Generating network XML with compatibility prior to %s',
@@ -286,14 +439,15 @@ def make_ipv4(last):
)
net_xml.append(domain_xml)
- net_xml.append(
- self._generate_main_dns(
- self._spec['mapping'], subnet, forward_plain='yes'
- )
+ dns_element = LibvirtDNS.generate_main_dns(
+ self.get_ipv4_and_ipv6_dns_records('mapping'),
+ [],
+ forward_plain_names=True,
)
+
+ net_xml.append(dns_element)
else:
- dns = ET.Element('dns', forwardPlainNames='yes', enable='yes')
- net_xml.append(dns)
+ net_xml.append(LibvirtDNS.generate_default_dns())
LOGGER.debug(
'Generated Network XML\n {0}'.format(
diff --git a/tests/unit/lago/providers/libvirt/test_dns.py b/tests/unit/lago/providers/libvirt/test_dns.py
new file mode 100644
index 00000000..158a5b72
--- /dev/null
+++ b/tests/unit/lago/providers/libvirt/test_dns.py
@@ -0,0 +1,77 @@
+#
+# Copyright 2017 Red Hat, Inc.
+#
+# This program is free software; you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation; either version 2 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program; if not, write to the Free Software
+# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
+#
+# Refer to the README and COPYING files for full details of the license
+#
+import lxml.etree as ET
+from xmlunittest import XmlTestCase
+
+from lago.providers.libvirt.network import LibvirtDNS
+
+
+class TestDNS(XmlTestCase):
+ def test_dns_disable(self):
+ _xml = ''
+ dns = LibvirtDNS.generate_dns_disable()
+
+ self.assertXmlEquivalentOutputs(ET.tostring(dns), _xml)
+
+ def test_default_dns(self):
+ _xml = ''
+ dns = LibvirtDNS.generate_default_dns()
+
+ self.assertXmlEquivalentOutputs(ET.tostring(dns), _xml)
+
+ def test_forward_dns(self):
+ _xml = """
+
+
+
+ """
+
+ dns = LibvirtDNS.generate_dns_forward([{'addr': '8.8.8.8'}])
+
+ self.assertXmlEquivalentOutputs(ET.tostring(dns), _xml)
+
+ def test_main_dns(self):
+ _xml = """
+
+
+ myhost
+ myhostalias
+
+
+
+
+ """
+
+ records = [
+ ('myhost', '192.168.122.2'), ('myhostalias', '192.168.122.2')
+ ]
+
+ forwarders = [
+ {
+ 'addr': '8.8.8.8'
+ }, {
+ 'addr': '8.8.4.4',
+ 'domain': 'example.com'
+ }
+ ]
+
+ dns = LibvirtDNS.generate_main_dns(records, forwarders, True)
+
+ self.assertXmlEquivalentOutputs(ET.tostring(dns), _xml)