Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add support for validating the downstream ip of the connection #108

Merged
merged 3 commits into from
Jun 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 34 additions & 0 deletions policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,21 @@ import (
// In case an error is returned the connection is denied.
type PolicyFunc func(upstream net.Addr) (Policy, error)

// ConnPolicyFunc can be used to decide whether to trust the PROXY info
// based on connection policy options. If set, the connecting addresses
// (remote and local) are passed in as argument.
//
// See below for the different policies.
//
// In case an error is returned the connection is denied.
type ConnPolicyFunc func(connPolicyOptions ConnPolicyOptions) (Policy, error)

// ConnPolicyOptions contains the remote and local addresses of a connection.
type ConnPolicyOptions struct {
Upstream net.Addr
Downstream net.Addr
}

// Policy defines how a connection with a PROXY header address is treated.
type Policy int

Expand Down Expand Up @@ -170,3 +185,22 @@ func ipFromAddr(upstream net.Addr) (net.IP, error) {

return upstreamIP, nil
}

// IgnoreProxyHeaderNotOnInterface retuns a ConnPolicyFunc which can be used to
// decide whether to use or ignore PROXY headers depending on the connection
// being made on a specific interface. This policy can be used when the server
// is bound to multiple interfaces but wants to allow on only one interface.
func IgnoreProxyHeaderNotOnInterface(allowedIP net.IP) ConnPolicyFunc {
return func(connOpts ConnPolicyOptions) (Policy, error) {
ip, err := ipFromAddr(connOpts.Downstream)
if err != nil {
return REJECT, err
}

if allowedIP.Equal(ip) {
return USE, nil
}

return IGNORE, nil
}
}
38 changes: 38 additions & 0 deletions policy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -211,3 +211,41 @@ func TestSkipProxyHeaderForCIDR(t *testing.T) {
t.Errorf("Expected a REJECT policy for the %s address", upstream)
}
}

func TestIgnoreProxyHeaderNotOnInterface(t *testing.T) {
downstream, err := net.ResolveTCPAddr("tcp", "10.0.0.3:45738")
if err != nil {
t.Fatalf("err: %v", err)
}

var cases = []struct {
name string
policy ConnPolicyFunc
downstreamAddress net.Addr
expectedPolicy Policy
expectError bool
}{
{"ignore header for requests non on interface", IgnoreProxyHeaderNotOnInterface(net.ParseIP("192.0.2.1")), downstream, IGNORE, false},
{"use headers for requests on interface", IgnoreProxyHeaderNotOnInterface(net.ParseIP("10.0.0.3")), downstream, USE, false},
{"invalid address should return error", IgnoreProxyHeaderNotOnInterface(net.ParseIP("10.0.0.3")), failingAddr{}, REJECT, true},
}

for _, tc := range cases {
t.Run(tc.name, func(t *testing.T) {
policy, err := tc.policy(ConnPolicyOptions{
Downstream: tc.downstreamAddress,
})
if !tc.expectError && err != nil {
t.Fatalf("err: %v", err)
}
if tc.expectError && err == nil {
t.Fatal("Expected error, got none")
}

if policy != tc.expectedPolicy {
t.Fatalf("Expected policy %v, got %v", tc.expectedPolicy, policy)
}
})
}

}
21 changes: 18 additions & 3 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,14 @@ var DefaultReadHeaderTimeout = 10 * time.Second
// connections in order to prevent blocking operations. If no ReadHeaderTimeout
// is set, a default of 200ms will be used. This can be disabled by setting the
// timeout to < 0.
//
// Only one of Policy or ConnPolicy should be provided. If both are provided then
// a panic would occur during accept.
type Listener struct {
Listener net.Listener
Listener net.Listener
// Deprecated: use ConnPolicyFunc instead. This will be removed in future release.
Policy PolicyFunc
ConnPolicy ConnPolicyFunc
ValidateHeader Validator
ReadHeaderTimeout time.Duration
}
Expand Down Expand Up @@ -67,8 +72,18 @@ func (p *Listener) Accept() (net.Conn, error) {
}

proxyHeaderPolicy := USE
if p.Policy != nil {
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
if p.Policy != nil && p.ConnPolicy != nil {
panic("only one of policy or connpolicy must be provided.")
}
if p.Policy != nil || p.ConnPolicy != nil {
if p.Policy != nil {
proxyHeaderPolicy, err = p.Policy(conn.RemoteAddr())
} else {
proxyHeaderPolicy, err = p.ConnPolicy(ConnPolicyOptions{
Upstream: conn.RemoteAddr(),
Downstream: conn.LocalAddr(),
})
}
if err != nil {
// can't decide the policy, we can't accept the connection
conn.Close()
Expand Down
133 changes: 133 additions & 0 deletions protocol_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,82 @@ func TestAcceptReturnsErrorWhenPolicyFuncErrors(t *testing.T) {
}
}

func TestPanicIfPolicyAndConnPolicySet(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}

connPolicyFunc := func(connopts ConnPolicyOptions) (Policy, error) { return USE, nil }
policyFunc := func(upstream net.Addr) (Policy, error) { return USE, nil }

pl := &Listener{Listener: l, ConnPolicy: connPolicyFunc, Policy: policyFunc}

cliResult := make(chan error)
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
cliResult <- err
return
}
defer conn.Close()

close(cliResult)
}()

defer func() {
if r := recover(); r != nil {
fmt.Printf("accept did panic as expected with error, %v", r)
}
}()
conn, err := pl.Accept()
if err != nil {
t.Fatalf("Expected the accept to panic but did not and error is returned, got %v", err)
}

if conn != nil {
t.Fatalf("xpected the accept to panic but did not, got %v", conn)
}
t.Fatalf("expected the accept to panic but did not")
}

func TestAcceptReturnsErrorWhenConnPolicyFuncErrors(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}

expectedErr := fmt.Errorf("failure")
connPolicyFunc := func(connopts ConnPolicyOptions) (Policy, error) { return USE, expectedErr }

pl := &Listener{Listener: l, ConnPolicy: connPolicyFunc}

cliResult := make(chan error)
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
cliResult <- err
return
}
defer conn.Close()

close(cliResult)
}()

conn, err := pl.Accept()
if err != expectedErr {
t.Fatalf("Expected error %v, got %v", expectedErr, err)
}

if conn != nil {
t.Fatalf("Expected no connection, got %v", conn)
}
err = <-cliResult
if err != nil {
t.Fatalf("client error: %v", err)
}
}

func TestReadingIsRefusedWhenProxyHeaderRequiredButMissing(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
Expand Down Expand Up @@ -979,6 +1055,63 @@ func TestSkipProxyProtocolPolicy(t *testing.T) {
t.Fatalf("err: %v", err)
}

connPolicyFunc := func(connopts ConnPolicyOptions) (Policy, error) { return SKIP, nil }

pl := &Listener{
Listener: l,
ConnPolicy: connPolicyFunc,
}

cliResult := make(chan error)
ping := []byte("ping")
go func() {
conn, err := net.Dial("tcp", pl.Addr().String())
if err != nil {
cliResult <- err
return
}
defer conn.Close()

if _, err := conn.Write(ping); err != nil {
cliResult <- err
return
}

close(cliResult)
}()

conn, err := pl.Accept()
if err != nil {
t.Fatalf("err: %v", err)
}
defer conn.Close()

_, ok := conn.(*net.TCPConn)
if !ok {
t.Fatal("err: should be a tcp connection")
}
_ = conn.LocalAddr()
recv := make([]byte, 4)
if _, err = conn.Read(recv); err != nil {
t.Fatalf("Unexpected read error: %v", err)
}

if !bytes.Equal(ping, recv) {
t.Fatalf("Unexpected %s data while expected %s", recv, ping)
}

err = <-cliResult
if err != nil {
t.Fatalf("client error: %v", err)
}
}

func TestSkipProxyProtocolConnPolicy(t *testing.T) {
l, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("err: %v", err)
}

policyFunc := func(upstream net.Addr) (Policy, error) { return SKIP, nil }

pl := &Listener{
Expand Down
Loading