diff --git a/pkg/packetfilter/iptables/iptables.go b/pkg/packetfilter/iptables/iptables.go index 0a862872f..3fafa31e7 100644 --- a/pkg/packetfilter/iptables/iptables.go +++ b/pkg/packetfilter/iptables/iptables.go @@ -79,6 +79,18 @@ func New() (packetfilter.Driver, error) { if err != nil { return nil, errors.Wrap(err, "error creating IP tables") } + return new(ipt) +} + +func NewV6() (packetfilter.Driver, error) { + ipt, err := iptables.New(iptables.IPFamily(iptables.ProtocolIPv6), iptables.Timeout(5)) + if err != nil { + return nil, errors.Wrap(err, "error creating IP tables") + } + return new(ipt) +} + +func new(ipt *iptables.IPTables) (packetfilter.Driver, error) { ipSetIface := ipset.New() diff --git a/pkg/packetfilter/iptables/namedset.go b/pkg/packetfilter/iptables/namedset.go index 4f07f22ed..73919e6a6 100644 --- a/pkg/packetfilter/iptables/namedset.go +++ b/pkg/packetfilter/iptables/namedset.go @@ -31,6 +31,9 @@ type namedSet struct { func (p *packetFilter) NewNamedSet(set *packetfilter.SetInfo) packetfilter.NamedSet { hashFamily := ipset.ProtocolFamilyIPV4 + if set.Family == packetfilter.SetFamilyV6 { + hashFamily = ipset.ProtocolFamilyIPV6 + } return &namedSet{ ipSetIface: p.ipSetIface, diff --git a/pkg/packetfilter/packetfilter.go b/pkg/packetfilter/packetfilter.go index bfb2029a5..d56913db8 100644 --- a/pkg/packetfilter/packetfilter.go +++ b/pkg/packetfilter/packetfilter.go @@ -253,8 +253,9 @@ type ChainIPHook struct { type SetFamily uint32 const ( - // curently only IPV4 sets are supported. + // IPV4 and IPV6 sets are supported. SetFamilyV4 SetFamily = iota + SetFamilyV6 ) // named set. @@ -308,21 +309,34 @@ type Interface interface { } var newDriverFn func() (Driver, error) +var newDriverFnV6 func() (Driver, error) func SetNewDriverFn(f func() (Driver, error)) { newDriverFn = f } +func SetNewDriverFnV6(f func() (Driver, error)) { + newDriverFnV6 = f +} + type Adapter struct { Driver } func New() (Interface, error) { - if newDriverFn == nil { + return new(newDriverFn) +} + +func NewV6() (Interface, error) { + return new(newDriverFnV6) +} + +func new(f func() (Driver, error)) (Interface, error) { + if f == nil { return nil, errors.New("no driver registered") } - driver, err := newDriverFn() + driver, err := f() if err != nil { return nil, errors.Wrap(err, "error creating packet filter Driver") }