Skip to content

Commit

Permalink
fix: sort ports and merge adjacent ones in the nft rule
Browse files Browse the repository at this point in the history
Fixes #9009

When building a port interval set, sort the ports and merge adjacent
ranges to prevent mismatch on the nftables side.

With address sets, this was already the case due to the way IPRange
builder works, but ports need a manual implementation.

Signed-off-by: Andrey Smirnov <[email protected]>
(cherry picked from commit f14c479)
  • Loading branch information
smira committed Aug 6, 2024
1 parent d692ab1 commit 44827e4
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 9 deletions.
27 changes: 25 additions & 2 deletions internal/app/machined/pkg/adapters/network/nftables_rule.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package network

import (
"cmp"
"fmt"
"net/netip"
"os"
Expand Down Expand Up @@ -109,9 +110,11 @@ func (set NfTablesSet) SetElements() []nftables.SetElement {

return elements
case SetKindPort:
elements := make([]nftables.SetElement, 0, len(set.Ports))
ports := mergeAdjacentPorts(set.Ports)

for _, p := range set.Ports {
elements := make([]nftables.SetElement, 0, len(ports))

for _, p := range ports {
from := binaryutil.BigEndian.PutUint16(p[0])
to := binaryutil.BigEndian.PutUint16(p[1] + 1)

Expand Down Expand Up @@ -157,6 +160,26 @@ func (set NfTablesSet) SetElements() []nftables.SetElement {
}
}

func mergeAdjacentPorts(in [][2]uint16) [][2]uint16 {
ports := slices.Clone(in)

slices.SortFunc(ports, func(a, b [2]uint16) int {
// sort by the lower bound of the range, assume no overlap
return cmp.Compare(a[0], b[0])
})

for i := 0; i < len(ports)-1; {
if ports[i][1]+1 >= ports[i+1][0] {
ports[i][1] = ports[i+1][1]
ports = append(ports[:i+1], ports[i+2:]...)
} else {
i++
}
}

return ports
}

// NfTablesCompiled is a compiled representation of the rule.
type NfTablesCompiled struct {
Rules [][]expr.Any
Expand Down
55 changes: 50 additions & 5 deletions internal/app/machined/pkg/adapters/network/nftables_rule_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -526,14 +526,14 @@ func TestNfTablesRuleCompile(t *testing.T) { //nolint:tparallel
Protocol: nethelpers.ProtocolTCP,
MatchSourcePort: &networkres.NfTablesPortMatch{
Ranges: []networkres.PortRange{
{
Lo: 1000,
Hi: 1025,
},
{
Lo: 2000,
Hi: 2000,
},
{
Lo: 1000,
Hi: 1025,
},
},
},
},
Expand Down Expand Up @@ -562,8 +562,8 @@ func TestNfTablesRuleCompile(t *testing.T) { //nolint:tparallel
{
Kind: network.SetKindPort,
Ports: [][2]uint16{
{1000, 1025},
{2000, 2000},
{1000, 1025},
},
},
},
Expand Down Expand Up @@ -713,3 +713,48 @@ func TestNfTablesRuleCompile(t *testing.T) { //nolint:tparallel
})
}
}

func TestNftablesSet(t *testing.T) { //nolint:tparallel
t.Parallel()

for _, test := range []struct {
name string

set network.NfTablesSet

expectedKeyType nftables.SetDatatype
expectedInterval bool
expectedData []nftables.SetElement
}{
{
name: "ports",

set: network.NfTablesSet{
Kind: network.SetKindPort,
Ports: [][2]uint16{
{443, 443},
{80, 81},
{5000, 5000},
{5001, 5001},
},
},

expectedKeyType: nftables.TypeInetService,
expectedInterval: true,
expectedData: []nftables.SetElement{ // network byte order
{Key: []uint8{0x0, 80}, IntervalEnd: false}, // 80 - 81
{Key: []uint8{0x0, 82}, IntervalEnd: true},
{Key: []uint8{0x1, 0xbb}, IntervalEnd: false}, // 443-443
{Key: []uint8{0x1, 0xbc}, IntervalEnd: true},
{Key: []uint8{0x13, 0x88}, IntervalEnd: false}, // 5000-5001
{Key: []uint8{0x13, 0x8a}, IntervalEnd: true},
},
},
} {
t.Run(test.name, func(t *testing.T) {
assert.Equal(t, test.expectedKeyType, test.set.KeyType())
assert.Equal(t, test.expectedInterval, test.set.IsInterval())
assert.Equal(t, test.expectedData, test.set.SetElements())
})
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -442,8 +442,60 @@ func (s *NfTablesChainSuite) TestL4Match2() {
s.checkNftOutput(`table inet talos-test {
chain test-tcp {
type filter hook input priority filter; policy accept;
ip saddr != { 10.0.0.0/8 } tcp dport { 1023, 1024 } drop
meta nfproto ipv6 tcp dport { 1023, 1024 } drop
ip saddr != { 10.0.0.0/8 } tcp dport { 1023-1024 } drop
meta nfproto ipv6 tcp dport { 1023-1024 } drop
}
}`)
}

func (s *NfTablesChainSuite) TestL4MatchAdjacentPorts() {
chain := network.NewNfTablesChain(network.NamespaceName, "test-tcp")
chain.TypedSpec().Type = nethelpers.ChainTypeFilter
chain.TypedSpec().Hook = nethelpers.ChainHookInput
chain.TypedSpec().Priority = nethelpers.ChainPriorityFilter
chain.TypedSpec().Policy = nethelpers.VerdictAccept
chain.TypedSpec().Rules = []network.NfTablesRule{
{
MatchSourceAddress: &network.NfTablesAddressMatch{
IncludeSubnets: []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
},
Invert: true,
},
MatchLayer4: &network.NfTablesLayer4Match{
Protocol: nethelpers.ProtocolTCP,
MatchDestinationPort: &network.NfTablesPortMatch{
Ranges: []network.PortRange{
{
Lo: 5000,
Hi: 5000,
},
{
Lo: 5001,
Hi: 5001,
},
{
Lo: 10250,
Hi: 10250,
},
{
Lo: 4240,
Hi: 4240,
},
},
},
},
Verdict: pointer.To(nethelpers.VerdictDrop),
},
}

s.Require().NoError(s.State().Create(s.Ctx(), chain))

s.checkNftOutput(`table inet talos-test {
chain test-tcp {
type filter hook input priority filter; policy accept;
ip saddr != { 10.0.0.0/8 } tcp dport { 4240, 5000-5001, 10250 } drop
meta nfproto ipv6 tcp dport { 4240, 5000-5001, 10250 } drop
}
}`)
}
Expand Down

0 comments on commit 44827e4

Please sign in to comment.