Skip to content

Commit

Permalink
ssh: add (*Client).DialContext method
Browse files Browse the repository at this point in the history
This change adds DialContext to ssh.Client, which opens a TCP-IP
connection tunneled over the SSH connection. This is useful for
proxying network connections, e.g. setting
(net/http.Transport).DialContext.

Fixes golang/go#20288.

Change-Id: I110494c00962424ea803065535ebe2209364ac27
GitHub-Last-Rev: 3176984
GitHub-Pull-Request: #260
Reviewed-on: https://go-review.googlesource.com/c/crypto/+/504735
Run-TryBot: Nicola Murino <[email protected]>
Run-TryBot: Han-Wen Nienhuys <[email protected]>
Auto-Submit: Nicola Murino <[email protected]>
Reviewed-by: Han-Wen Nienhuys <[email protected]>
Reviewed-by: Dmitri Shuralyov <[email protected]>
TryBot-Result: Gopher Robot <[email protected]>
Reviewed-by: Nicola Murino <[email protected]>
Commit-Queue: Nicola Murino <[email protected]>
  • Loading branch information
ydnar authored and gopherbot committed Nov 27, 2023
1 parent 1c17e20 commit b2d7c26
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 1 deletion.
35 changes: 35 additions & 0 deletions ssh/tcpip.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
package ssh

import (
"context"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -332,6 +333,40 @@ func (l *tcpListener) Addr() net.Addr {
return l.laddr
}

// DialContext initiates a connection to the addr from the remote host.
//
// The provided Context must be non-nil. If the context expires before the
// connection is complete, an error is returned. Once successfully connected,
// any expiration of the context will not affect the connection.
//
// See func Dial for additional information.
func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
type connErr struct {
conn net.Conn
err error
}
ch := make(chan connErr)
go func() {
conn, err := c.Dial(n, addr)
select {
case ch <- connErr{conn, err}:
case <-ctx.Done():
if conn != nil {
conn.Close()
}
}
}()
select {
case res := <-ch:
return res.conn, res.err
case <-ctx.Done():
return nil, ctx.Err()
}
}

// Dial initiates a connection to the addr from the remote host.
// The resulting connection has a zero LocalAddr() and RemoteAddr().
func (c *Client) Dial(n, addr string) (net.Conn, error) {
Expand Down
33 changes: 33 additions & 0 deletions ssh/tcpip_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,10 @@
package ssh

import (
"context"
"net"
"testing"
"time"
)

func TestAutoPortListenBroken(t *testing.T) {
Expand All @@ -18,3 +21,33 @@ func TestAutoPortListenBroken(t *testing.T) {
t.Errorf("version %q marked as broken", works)
}
}

func TestClientImplementsDialContext(t *testing.T) {
type ContextDialer interface {
DialContext(context.Context, string, string) (net.Conn, error)
}
// Belt and suspenders assertion, since package net does not
// declare a ContextDialer type.
var _ ContextDialer = &net.Dialer{}
var _ ContextDialer = &Client{}
}

func TestClientDialContextWithCancel(t *testing.T) {
c := &Client{}
ctx, cancel := context.WithCancel(context.Background())
cancel()
_, err := c.DialContext(ctx, "tcp", "localhost:1000")
if err != context.Canceled {
t.Errorf("DialContext: got nil error, expected %v", context.Canceled)
}
}

func TestClientDialContextWithDeadline(t *testing.T) {
c := &Client{}
ctx, cancel := context.WithDeadline(context.Background(), time.Now())
defer cancel()
_, err := c.DialContext(ctx, "tcp", "localhost:1000")
if err != context.DeadlineExceeded {
t.Errorf("DialContext: got nil error, expected %v", context.DeadlineExceeded)
}
}
7 changes: 6 additions & 1 deletion ssh/test/dial_unix_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package test
// direct-tcpip and direct-streamlocal functional tests

import (
"context"
"fmt"
"io"
"net"
Expand Down Expand Up @@ -46,7 +47,11 @@ func testDial(t *testing.T, n, listenAddr string, x dialTester) {
}
}()

conn, err := sshConn.Dial(n, l.Addr().String())
ctx, cancel := context.WithCancel(context.Background())
conn, err := sshConn.DialContext(ctx, n, l.Addr().String())
// Canceling the context after dial should have no effect
// on the opened connection.
cancel()
if err != nil {
t.Fatalf("Dial: %v", err)
}
Expand Down

0 comments on commit b2d7c26

Please sign in to comment.