From 1ed2c822eac2236e463fc0b47abe19a4e01f25c8 Mon Sep 17 00:00:00 2001 From: Arvid Gotthard <66034456+mellonnen@users.noreply.github.com> Date: Thu, 16 Feb 2023 14:01:13 +0100 Subject: [PATCH] F/rendezvous id bug (#36) * feat: make conn context-aware * feat: attempt to solve id allocation bug * decided to make the mailbox communication is split into two channels, as it is easier to manage * feat: expand makefile --------- Co-authored-by: Zino Kader --- Makefile | 18 +-- cmd/wasm/main.go | 5 +- internal/conn/conn.go | 52 +++---- internal/conn/conn_test.go | 15 +- internal/receiver/receive.go | 22 +-- internal/receiver/receive_wasm.go | 19 +-- internal/receiver/receiver.go | 22 +-- internal/rendezvous/handlers.go | 220 +++++++++++++++++++----------- internal/rendezvous/mailbox.go | 7 +- internal/rendezvous/routes.go | 3 +- internal/sender/sender.go | 36 ++--- internal/sender/server.go | 2 +- internal/sender/transfer.go | 15 +- internal/sender/transfer_wasm.go | 13 +- portal/portal.go | 15 +- portal/portal_test.go | 4 +- ui/receiver/receiver.go | 15 +- ui/sender/sender.go | 21 +-- 18 files changed, 288 insertions(+), 216 deletions(-) diff --git a/Makefile b/Makefile index 59eb8bf..d0c7f6d 100644 --- a/Makefile +++ b/Makefile @@ -1,25 +1,27 @@ -.PHONY: build run lint test test-e2e build-wasm +.PHONY: serve lint test test-e2e build-wasm image LINKER_FLAGS = '-s -X main.version=${PORTAL_VERSION}' lint: golangci-lint run --timeout 5m ./... -build: - go build -o portal ./cmd/portal/ +build: + go build -o portal-bin ./cmd/portal/ build-production: CGO=0 go build -ldflags=${LINKER_FLAGS} -o portal ./cmd/portal/ +image: + docker build --tag rendezvous:latest . + +serve: image + docker run -dp 8080:8080 rendezvous:latest + build-wasm: GOOS=js GOARCH=wasm go build -o portal.wasm ./cmd/wasm/main.go -run: build - ./portal -p 8080 - test: go test -v -race -covermode=atomic -coverprofile=coverage.out -failfast -short ./... -test-e2e: - docker build --tag rendezvous:latest . +test-e2e: image go test -v -race -covermode=atomic -coverprofile=coverage.out -failfast ./... diff --git a/cmd/wasm/main.go b/cmd/wasm/main.go index 828f2ce..cc01911 100644 --- a/cmd/wasm/main.go +++ b/cmd/wasm/main.go @@ -4,6 +4,7 @@ package main import ( "bytes" + "context" "syscall/js" "github.com/SpatiumPortae/portal/portal" @@ -39,7 +40,7 @@ func SendJs() js.Func { } // Top-level promise. transferHandler := promiseHandler(func(resolve, reject js.Value) { - password, err, errCh := portal.Send(payload, int64(payload.Len()), cnf) + password, err, errCh := portal.Send(context.Background(), payload, int64(payload.Len()), cnf) if err != nil { reject.Invoke(Error.New(err.Error())) return @@ -75,7 +76,7 @@ func ReceiveJs() js.Func { var buf bytes.Buffer transferHandler := promiseHandler(func(resolve, reject js.Value) { - if err := portal.Receive(&buf, password, cnf); err != nil { + if err := portal.Receive(context.Background(), &buf, password, cnf); err != nil { reject.Invoke(Error.New(err.Error())) return } diff --git a/internal/conn/conn.go b/internal/conn/conn.go index 286f934..f1b9c3c 100644 --- a/internal/conn/conn.go +++ b/internal/conn/conn.go @@ -15,8 +15,8 @@ const MESSAGE_SIZE_LIMIT_BYTES = math.MaxInt64 - 1 // Conn is an interface that wraps a network connection. type Conn interface { - Read() ([]byte, error) - Write([]byte) error + Read(context.Context) ([]byte, error) + Write(context.Context, []byte) error } // ------------------ Conn implementations ------------------ @@ -26,15 +26,15 @@ type WS struct { Conn *websocket.Conn } -func (ws *WS) Read() ([]byte, error) { +func (ws *WS) Read(ctx context.Context) ([]byte, error) { // this limit is per-message and thus needs to be set before each read ws.Conn.SetReadLimit(MESSAGE_SIZE_LIMIT_BYTES) - _, payload, err := ws.Conn.Read(context.Background()) + _, payload, err := ws.Conn.Read(ctx) return payload, err } -func (ws *WS) Write(payload []byte) error { - return ws.Conn.Write(context.Background(), websocket.MessageBinary, payload) +func (ws *WS) Write(ctx context.Context, payload []byte) error { + return ws.Conn.Write(ctx, websocket.MessageBinary, payload) } // ------------------ Rendezvous Conn ------------------------ @@ -44,24 +44,24 @@ type Rendezvous struct { Conn Conn } -// ReadBytes reads raw bytes from the underlying connection. -func (r Rendezvous) ReadBytes() ([]byte, error) { - b, err := r.Conn.Read() +// ReadRaw reads raw bytes from the underlying connection. +func (r Rendezvous) ReadRaw(ctx context.Context) ([]byte, error) { + b, err := r.Conn.Read(ctx) if err != nil { return nil, err } return b, err } -// WriteBytes writes raw bytes to the underlying connection. -func (r Rendezvous) WriteBytes(b []byte) error { - err := r.Conn.Write(b) +// WriteRaw writes raw bytes to the underlying connection. +func (r Rendezvous) WriteRaw(ctx context.Context, b []byte) error { + err := r.Conn.Write(ctx, b) return err } // ReadMsg reads a rendezvous message from the underlying connection. -func (r Rendezvous) ReadMsg(expected ...rendezvous.MsgType) (rendezvous.Msg, error) { - b, err := r.Conn.Read() +func (r Rendezvous) ReadMsg(ctx context.Context, expected ...rendezvous.MsgType) (rendezvous.Msg, error) { + b, err := r.Conn.Read(ctx) if err != nil { return rendezvous.Msg{}, err } @@ -76,12 +76,12 @@ func (r Rendezvous) ReadMsg(expected ...rendezvous.MsgType) (rendezvous.Msg, err } // WriteMsg writes a rendezvous message to the underlying connection. -func (r Rendezvous) WriteMsg(msg rendezvous.Msg) error { +func (r Rendezvous) WriteMsg(ctx context.Context, msg rendezvous.Msg) error { payload, err := json.Marshal(msg) if err != nil { return err } - return r.Conn.Write(payload) + return r.Conn.Write(ctx, payload) } // ------------------ Transfer Conn ---------------------------- @@ -114,27 +114,27 @@ func (tc Transfer) Key() []byte { return tc.crypt.Key } -// ReadEncryptedBytes reads and decrypts bytes from the underlying connection. -func (t Transfer) ReadEncryptedBytes() ([]byte, error) { - b, err := t.Conn.Read() +// ReadRaw reads and decrypts raw bytes from the underlying connection. +func (t Transfer) ReadRaw(ctx context.Context) ([]byte, error) { + b, err := t.Conn.Read(ctx) if err != nil { return nil, err } return t.crypt.Decrypt(b) } -// WriteEncryptedBytes encrypts and writes the specified bytes to the underlying connection. -func (t Transfer) WriteEncryptedBytes(b []byte) error { +// WriteRaw encrypts and writes the raw bytes to the underlying connection. +func (t Transfer) WriteRaw(ctx context.Context, b []byte) error { enc, err := t.crypt.Encrypt(b) if err != nil { return nil } - return t.Conn.Write(enc) + return t.Conn.Write(ctx, enc) } // ReadMsg reads and decrypts the specified transfer message from the underlying connection. -func (t Transfer) ReadMsg(expected ...transfer.MsgType) (transfer.Msg, error) { - dec, err := t.ReadEncryptedBytes() +func (t Transfer) ReadMsg(ctx context.Context, expected ...transfer.MsgType) (transfer.Msg, error) { + dec, err := t.ReadRaw(ctx) if err != nil { return transfer.Msg{}, err } @@ -150,10 +150,10 @@ func (t Transfer) ReadMsg(expected ...transfer.MsgType) (transfer.Msg, error) { } // WriteMsg encrypts and writes the specified transfer message to the underlying connection. -func (t Transfer) WriteMsg(msg transfer.Msg) error { +func (t Transfer) WriteMsg(ctx context.Context, msg transfer.Msg) error { b, err := json.Marshal(msg) if err != nil { return err } - return t.WriteEncryptedBytes(b) + return t.WriteRaw(ctx, b) } diff --git a/internal/conn/conn_test.go b/internal/conn/conn_test.go index ea7537c..4af8379 100644 --- a/internal/conn/conn_test.go +++ b/internal/conn/conn_test.go @@ -1,6 +1,7 @@ package conn_test import ( + "context" "crypto/rand" "testing" @@ -14,12 +15,12 @@ type mockConn struct { conn chan []byte } -func (m mockConn) Write(b []byte) error { +func (m mockConn) Write(ctx context.Context, b []byte) error { m.conn <- b return nil } -func (m mockConn) Read() ([]byte, error) { +func (m mockConn) Read(ctx context.Context) ([]byte, error) { return <-m.conn, nil } @@ -32,10 +33,11 @@ func TestConn(t *testing.T) { r1 := conn.Rendezvous{Conn: conn1} r2 := conn.Rendezvous{Conn: conn2} - err := r1.WriteMsg(rendezvous.Msg{Type: rendezvous.SenderToRendezvousEstablish}) + ctx := context.Background() + err := r1.WriteMsg(ctx, rendezvous.Msg{Type: rendezvous.SenderToRendezvousEstablish}) assert.NoError(t, err) - msg, err := r2.ReadMsg() + msg, err := r2.ReadMsg(ctx) assert.NoError(t, err) assert.Equal(t, msg.Type, rendezvous.SenderToRendezvousEstablish) }) @@ -46,13 +48,14 @@ func TestConn(t *testing.T) { _, err := rand.Read(salt) assert.NoError(t, err) + ctx := context.Background() t1 := conn.TransferFromSession(&conn1, sessionkey, salt) t2 := conn.TransferFromSession(&conn2, sessionkey, salt) - err = t1.WriteMsg(transfer.Msg{Type: transfer.ReceiverHandshake}) + err = t1.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverHandshake}) assert.NoError(t, err) - msg, err := t2.ReadMsg() + msg, err := t2.ReadMsg(ctx) assert.NoError(t, err) assert.Equal(t, msg.Type, transfer.ReceiverHandshake) }) diff --git a/internal/receiver/receive.go b/internal/receiver/receive.go index 484ea28..ed9c608 100644 --- a/internal/receiver/receive.go +++ b/internal/receiver/receive.go @@ -17,7 +17,7 @@ import ( // doReceive performs the transfer protocol on the receiving end. // This function is built for all platforms except js -func doReceive(relay conn.Transfer, addr string, dst io.Writer, msgs ...chan interface{}) error { +func doReceive(ctx context.Context, relay conn.Transfer, addr string, dst io.Writer, msgs ...chan interface{}) error { // Retrieve a unencrypted channel to rendezvous. rc := conn.Rendezvous{Conn: relay.Conn} @@ -27,10 +27,10 @@ func doReceive(relay conn.Transfer, addr string, dst io.Writer, msgs ...chan int if err != nil { tc = relay // Communicate to the sender that we are using relay transfer. - if err := relay.WriteMsg(transfer.Msg{Type: transfer.ReceiverRelayCommunication}); err != nil { + if err := relay.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverRelayCommunication}); err != nil { return err } - _, err := relay.ReadMsg(transfer.SenderRelayAck) + _, err := relay.ReadMsg(ctx, transfer.SenderRelayAck) if err != nil { return err } @@ -41,12 +41,12 @@ func doReceive(relay conn.Transfer, addr string, dst io.Writer, msgs ...chan int } else { tc = direct // Communicate to the sender that we are doing direct communication. - if err := relay.WriteMsg(transfer.Msg{Type: transfer.ReceiverDirectCommunication}); err != nil { + if err := relay.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverDirectCommunication}); err != nil { return err } // Tell rendezvous server that we can close the connection. - if err := rc.WriteMsg(rendezvous.Msg{Type: rendezvous.ReceiverToRendezvousClose}); err != nil { + if err := rc.WriteMsg(ctx, rendezvous.Msg{Type: rendezvous.ReceiverToRendezvousClose}); err != nil { return err } @@ -56,30 +56,30 @@ func doReceive(relay conn.Transfer, addr string, dst io.Writer, msgs ...chan int } // Request the payload and receive it. - if tc.WriteMsg(transfer.Msg{Type: transfer.ReceiverRequestPayload}) != nil { + if tc.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverRequestPayload}) != nil { return err } - if err := receivePayload(tc, dst, msgs...); err != nil { + if err := receivePayload(ctx, tc, dst, msgs...); err != nil { return err } // Closing handshake. - if err := tc.WriteMsg(transfer.Msg{Type: transfer.ReceiverPayloadAck}); err != nil { + if err := tc.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverPayloadAck}); err != nil { return err } - _, err = tc.ReadMsg(transfer.SenderClosing) + _, err = tc.ReadMsg(ctx, transfer.SenderClosing) if err != nil { return err } - if err := tc.WriteMsg(transfer.Msg{Type: transfer.ReceiverClosingAck}); err != nil { + if err := tc.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverClosingAck}); err != nil { return err } // Tell rendezvous to close connection. - if err := rc.WriteMsg(rendezvous.Msg{Type: rendezvous.ReceiverToRendezvousClose}); err != nil { + if err := rc.WriteMsg(ctx, rendezvous.Msg{Type: rendezvous.ReceiverToRendezvousClose}); err != nil { return err } return nil diff --git a/internal/receiver/receive_wasm.go b/internal/receiver/receive_wasm.go index a01d327..0069c7d 100644 --- a/internal/receiver/receive_wasm.go +++ b/internal/receiver/receive_wasm.go @@ -3,6 +3,7 @@ package receiver import ( + "context" "io" "github.com/SpatiumPortae/portal/internal/conn" @@ -12,12 +13,12 @@ import ( // doReceive performs the transfer protocol on the receiving end. // This function is only built for the js platform. -func doReceive(relayTc conn.Transfer, addr string, dst io.Writer, msgs ...chan interface{}) error { +func doReceive(ctx context.Context, relayTc conn.Transfer, addr string, dst io.Writer, msgs ...chan interface{}) error { // Communicate to the sender that we are using relay transfer. - if err := relayTc.WriteMsg(transfer.Msg{Type: transfer.ReceiverRelayCommunication}); err != nil { + if err := relayTc.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverRelayCommunication}); err != nil { return err } - _, err := relayTc.ReadMsg(transfer.SenderRelayAck) + _, err := relayTc.ReadMsg(ctx, transfer.SenderRelayAck) if err != nil { return err } @@ -27,31 +28,31 @@ func doReceive(relayTc conn.Transfer, addr string, dst io.Writer, msgs ...chan i } // Request the payload and receive it. - if relayTc.WriteMsg(transfer.Msg{Type: transfer.ReceiverRequestPayload}) != nil { + if relayTc.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverRequestPayload}) != nil { return err } - if err := receivePayload(relayTc, dst, msgs...); err != nil { + if err := receivePayload(ctx, relayTc, dst, msgs...); err != nil { return err } // Closing handshake. - if err := relayTc.WriteMsg(transfer.Msg{Type: transfer.ReceiverPayloadAck}); err != nil { + if err := relayTc.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverPayloadAck}); err != nil { return err } - _, err = relayTc.ReadMsg(transfer.SenderClosing) + _, err = relayTc.ReadMsg(ctx, transfer.SenderClosing) if err != nil { return err } - if err := relayTc.WriteMsg(transfer.Msg{Type: transfer.ReceiverClosingAck}); err != nil { + if err := relayTc.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverClosingAck}); err != nil { return err } // Retrieve a unencrypted channel to rendezvous. rc := conn.Rendezvous{Conn: relayTc.Conn} - if err := rc.WriteMsg(rendezvous.Msg{Type: rendezvous.ReceiverToRendezvousClose}); err != nil { + if err := rc.WriteMsg(ctx, rendezvous.Msg{Type: rendezvous.ReceiverToRendezvousClose}); err != nil { return err } return nil diff --git a/internal/receiver/receiver.go b/internal/receiver/receiver.go index 75733b0..340883e 100644 --- a/internal/receiver/receiver.go +++ b/internal/receiver/receiver.go @@ -24,7 +24,7 @@ func ConnectRendezvous(addr string) (conn.Rendezvous, error) { } // SecureConnection performs the cryptographic handshake to resolve a secure connection. -func SecureConnection(rc conn.Rendezvous, pass string) (conn.Transfer, error) { +func SecureConnection(ctx context.Context, rc conn.Rendezvous, pass string) (conn.Transfer, error) { // Convenience for messaging in this function. type pakeMsg struct { pake *pake.Pake @@ -38,7 +38,7 @@ func SecureConnection(rc conn.Rendezvous, pass string) (conn.Transfer, error) { pakeCh <- pakeMsg{pake: p, err: err} }() - if err := rc.WriteMsg(rendezvous.Msg{ + if err := rc.WriteMsg(ctx, rendezvous.Msg{ Type: rendezvous.ReceiverToRendezvousEstablish, Payload: rendezvous.Payload{ Password: password.Hashed(pass), @@ -47,7 +47,7 @@ func SecureConnection(rc conn.Rendezvous, pass string) (conn.Transfer, error) { return conn.Transfer{}, err } - msg, err := rc.ReadMsg(rendezvous.RendezvousToReceiverPAKE) + msg, err := rc.ReadMsg(ctx, rendezvous.RendezvousToReceiverPAKE) if err != nil { return conn.Transfer{}, err } @@ -63,7 +63,7 @@ func SecureConnection(rc conn.Rendezvous, pass string) (conn.Transfer, error) { return conn.Transfer{}, err } - if err = rc.WriteMsg(rendezvous.Msg{ + if err = rc.WriteMsg(ctx, rendezvous.Msg{ Type: rendezvous.ReceiverToRendezvousPAKE, Payload: rendezvous.Payload{ Bytes: p.Bytes(), @@ -77,7 +77,7 @@ func SecureConnection(rc conn.Rendezvous, pass string) (conn.Transfer, error) { return conn.Transfer{}, err } - msg, err = rc.ReadMsg(rendezvous.RendezvousToReceiverSalt) + msg, err = rc.ReadMsg(ctx, rendezvous.RendezvousToReceiverSalt) if err != nil { return conn.Transfer{}, err } @@ -88,12 +88,12 @@ func SecureConnection(rc conn.Rendezvous, pass string) (conn.Transfer, error) { // Receive receives the payload over the transfer connection and writes it into the provided destination. // The Transfer can either be direct or using a relay. // The msgs channel communicates information about the receiving process while running. -func Receive(tc conn.Transfer, dst io.Writer, msgs ...chan interface{}) error { - if err := tc.WriteMsg(transfer.Msg{Type: transfer.ReceiverHandshake}); err != nil { +func Receive(ctx context.Context, tc conn.Transfer, dst io.Writer, msgs ...chan interface{}) error { + if err := tc.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverHandshake}); err != nil { return err } - msg, err := tc.ReadMsg(transfer.SenderHandshake) + msg, err := tc.ReadMsg(ctx, transfer.SenderHandshake) if err != nil { return err } @@ -101,14 +101,14 @@ func Receive(tc conn.Transfer, dst io.Writer, msgs ...chan interface{}) error { if len(msgs) > 0 { msgs[0] <- msg.Payload.PayloadSize } - return doReceive(tc, fmt.Sprintf("%s:%d", msg.Payload.IP, msg.Payload.Port), dst, msgs...) + return doReceive(ctx, tc, fmt.Sprintf("%s:%d", msg.Payload.IP, msg.Payload.Port), dst, msgs...) } // receivePayload receives the payload over the provided connection and writes it into the desired location. -func receivePayload(tc conn.Transfer, dst io.Writer, msgs ...chan interface{}) error { +func receivePayload(ctx context.Context, tc conn.Transfer, dst io.Writer, msgs ...chan interface{}) error { writtenBytes := 0 for { - b, err := tc.ReadEncryptedBytes() + b, err := tc.ReadRaw(ctx) if err != nil { return err } diff --git a/internal/rendezvous/handlers.go b/internal/rendezvous/handlers.go index 8aee961..8b9140b 100644 --- a/internal/rendezvous/handlers.go +++ b/internal/rendezvous/handlers.go @@ -2,52 +2,59 @@ package rendezvous import ( + "context" "encoding/json" + "errors" + "io" "net/http" + "sync" "time" "github.com/SpatiumPortae/portal/internal/conn" "github.com/SpatiumPortae/portal/internal/logger" "github.com/SpatiumPortae/portal/protocol/rendezvous" "go.uber.org/zap" + "nhooyr.io/websocket" ) +// ------------------------------------------------------ Handlers ----------------------------------------------------- + // handleEstablishSender returns a websocket handler that communicates with the sender. func (s *Server) handleEstablishSender() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() logger, err := logger.FromContext(ctx) if err != nil { - w.WriteHeader(http.StatusInternalServerError) + return } c, err := conn.FromContext(ctx) if err != nil { logger.Error("getting Conn from request context", zap.Error(err)) - w.WriteHeader(http.StatusInternalServerError) return } - rc := conn.Rendezvous{Conn: c} logger.Info("sender connected") - // Bind an ID to this communication and send to the sender + id := s.ids.Bind() + logger = logger.With(zap.Int("id", id)) + logger.Info("bound id") defer func() { s.ids.Delete(id) - logger.Info("freed id", zap.Int("id", id)) + logger.Info("freed id") }() - err = rc.WriteMsg(rendezvous.Msg{ + + err = rc.WriteMsg(ctx, rendezvous.Msg{ Type: rendezvous.RendezvousToSenderBind, Payload: rendezvous.Payload{ ID: id, }, }) - logger.Info("bound id", zap.Int("id", id)) if err != nil { logger.Error("binding communcation ID", zap.Error(err)) return } - msg, err := rc.ReadMsg(rendezvous.SenderToRendezvousEstablish) + msg, err := rc.ReadMsg(ctx, rendezvous.SenderToRendezvousEstablish) if err != nil { logger.Error("establishing sender", zap.Error(err)) return @@ -55,8 +62,8 @@ func (s *Server) handleEstablishSender() http.HandlerFunc { // Allocate a mailbox for this communication. mailbox := &Mailbox{ - CommunicationChannel: make(chan []byte), - Quit: make(chan bool), + Sender: make(chan []byte), + Receiver: make(chan []byte), } s.mailboxes.StoreMailbox(msg.Payload.Password, mailbox) password := msg.Payload.Password @@ -64,14 +71,20 @@ func (s *Server) handleEstablishSender() http.HandlerFunc { // wait for receiver to connect or connection timeout timeout := time.NewTimer(RECEIVER_CONNECT_TIMEOUT) select { + case <-ctx.Done(): + if ctx.Err() != nil { + logger.Error("context error while waiting for receiver", zap.Error(ctx.Err())) + } + logger.Info("closing handler") + return case <-timeout.C: - logger.Warn("waiting for receiver timeout") + logger.Warn("waiting for receiver timed out") return - case <-mailbox.CommunicationChannel: + case <-mailbox.Sender: break } - err = rc.WriteMsg(rendezvous.Msg{ + err = rc.WriteMsg(ctx, rendezvous.Msg{ Type: rendezvous.RendezvousToSenderReady, }) @@ -80,19 +93,19 @@ func (s *Server) handleEstablishSender() http.HandlerFunc { return } - msg, err = rc.ReadMsg(rendezvous.SenderToRendezvousPAKE) + msg, err = rc.ReadMsg(ctx, rendezvous.SenderToRendezvousPAKE) if err != nil { w.WriteHeader(http.StatusBadRequest) logger.Error("performing PAKE exchange", zap.Error(err)) return } // send PAKE bytes to receiver - mailbox.CommunicationChannel <- msg.Payload.Bytes + mailbox.Receiver <- msg.Payload.Bytes // respond with receiver PAKE bytes - err = rc.WriteMsg(rendezvous.Msg{ + err = rc.WriteMsg(ctx, rendezvous.Msg{ Type: rendezvous.RendezvousToSenderPAKE, Payload: rendezvous.Payload{ - Bytes: <-mailbox.CommunicationChannel, + Bytes: <-mailbox.Sender, }, }) if err != nil { @@ -100,7 +113,7 @@ func (s *Server) handleEstablishSender() http.HandlerFunc { return } - msg, err = rc.ReadMsg(rendezvous.SenderToRendezvousSalt) + msg, err = rc.ReadMsg(ctx, rendezvous.SenderToRendezvousSalt) if err != nil { w.WriteHeader(http.StatusBadRequest) logger.Error("performing salt exchange", zap.Error(err)) @@ -108,21 +121,36 @@ func (s *Server) handleEstablishSender() http.HandlerFunc { } // Send the salt to the receiver. - mailbox.CommunicationChannel <- msg.Payload.Salt - // Start the relay of messages between the sender and receiver handlers. - logger.Info("starting relay service") - startRelay(s, rc, mailbox, password, logger) + mailbox.Receiver <- msg.Payload.Salt + // Start forwarder and relay + forward := make(chan []byte) + wg := sync.WaitGroup{} + relayCtx, cancel := context.WithCancel(ctx) + + wg.Add(2) + go s.forwarder(relayCtx, &wg, rc, forward, logger) + s.relay(relayCtx, &wg, rc, forward, mailbox.Sender, mailbox.Receiver, logger) + + // We want to make sure that the both forwarder and relay have terminated + cancel() + wg.Wait() + + // Deallocate mailbox + logger.Info("deallocating mailbox") + s.mailboxes.Delete(password) + logger.Info("sender closing") } } // handleEstablishReceiver returns a websocket handler that communicates with the sender. func (s *Server) handleEstablishReceiver() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - logger, err := logger.FromContext(r.Context()) + ctx := r.Context() + logger, err := logger.FromContext(ctx) if err != nil { w.WriteHeader(http.StatusInternalServerError) } - c, err := conn.FromContext(r.Context()) + c, err := conn.FromContext(ctx) if err != nil { w.WriteHeader(http.StatusInternalServerError) logger.Error("getting Conn from request context", zap.Error(err)) @@ -132,7 +160,7 @@ func (s *Server) handleEstablishReceiver() http.HandlerFunc { logger.Info("receiver connected") // Establish receiver. - msg, err := rc.ReadMsg(rendezvous.ReceiverToRendezvousEstablish) + msg, err := rc.ReadMsg(ctx, rendezvous.ReceiverToRendezvousEstablish) if err != nil { w.WriteHeader(http.StatusBadRequest) logger.Error("establishing receiver", zap.Error(err)) @@ -153,15 +181,14 @@ func (s *Server) handleEstablishReceiver() http.HandlerFunc { // this receiver was first, reserve this mailbox for it to receive mailbox.hasReceiver = true s.mailboxes.StoreMailbox(msg.Payload.Password, mailbox) - password := msg.Payload.Password // notify sender we are connected - mailbox.CommunicationChannel <- []byte{} + mailbox.Sender <- []byte{} // send back received sender PAKE bytes - err = rc.WriteMsg(rendezvous.Msg{ + err = rc.WriteMsg(ctx, rendezvous.Msg{ Type: rendezvous.RendezvousToReceiverPAKE, Payload: rendezvous.Payload{ - Bytes: <-mailbox.CommunicationChannel, + Bytes: <-mailbox.Receiver, }, }) if err != nil { @@ -169,18 +196,18 @@ func (s *Server) handleEstablishReceiver() http.HandlerFunc { return } - msg, err = rc.ReadMsg(rendezvous.ReceiverToRendezvousPAKE) + msg, err = rc.ReadMsg(ctx, rendezvous.ReceiverToRendezvousPAKE) if err != nil { w.WriteHeader(http.StatusBadRequest) logger.Error("performing PAKE exchange", zap.Error(err)) return } - mailbox.CommunicationChannel <- msg.Payload.Bytes - err = rc.WriteMsg(rendezvous.Msg{ + mailbox.Sender <- msg.Payload.Bytes + err = rc.WriteMsg(ctx, rendezvous.Msg{ Type: rendezvous.RendezvousToReceiverSalt, Payload: rendezvous.Payload{ - Salt: <-mailbox.CommunicationChannel, + Salt: <-mailbox.Receiver, }, }) if err != nil { @@ -188,66 +215,95 @@ func (s *Server) handleEstablishReceiver() http.HandlerFunc { logger.Error("exchanging salt", zap.Error(err)) } - logger.Info("start relay service") - startRelay(s, rc, mailbox, password, logger) + // Start forwarder and relay + forward := make(chan []byte) + wg := sync.WaitGroup{} + subCtx, cancel := context.WithCancel(ctx) + + wg.Add(2) + go s.forwarder(subCtx, &wg, rc, forward, logger) + s.relay(subCtx, &wg, rc, forward, mailbox.Receiver, mailbox.Sender, logger) + cancel() + + wg.Wait() + + logger.Info("receiver closing") } } -// starts the relay service, closing it on request (if i.e. clients can communicate directly) -func startRelay(s *Server, conn conn.Rendezvous, mailbox *Mailbox, mailboxPassword string, logger *zap.Logger) { - relayForwardCh := make(chan []byte) - // listen for incoming websocket messages from currently handled client - go func() { - for { - // read raw bytes and pass them on - payload, err := conn.ReadBytes() - if err != nil { - logger.Error("listening to incoming client messages", zap.Error(err)) - mailbox.Quit <- true - return - } - relayForwardCh <- payload - } - }() +//nolint:errcheck +func (s *Server) ping() http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.Write([]byte("pong")) + } +} + +// ------------------------------------------------------ Helpers ------------------------------------------------------ +// forwarder reads from the connection and forwards the message to the provided channel. +// Transient errors are logged on the provided logger. +func (s *Server) forwarder(ctx context.Context, wg *sync.WaitGroup, rc conn.Rendezvous, forward chan<- []byte, logger *zap.Logger) { + forwardLogger := logger.With(zap.String("component", "forwarder")) + forwardLogger.Info("starting forwarder") + defer wg.Done() + defer close(forward) for { - select { - // received payload from __other client__, relay it to our currently handled client - case relayReceivePayload := <-mailbox.CommunicationChannel: - err := conn.WriteBytes(relayReceivePayload) // send raw binary data - if err != nil { - logger.Error("relaying bytes, closing relay service", zap.Error(err)) - // close the relay service if writing failed - mailbox.Quit <- true - return - } + payload, err := rc.ReadRaw(ctx) + switch { + case errors.Is(err, io.EOF): + forwardLogger.Error("connection forcefully closed", zap.Error(err)) + return - // received payload from __currently handled__ client, relay it to other client - case relayForwardPayload := <-relayForwardCh: - var msg rendezvous.Msg - err := json.Unmarshal(relayForwardPayload, &msg) - // failed to unmarshal, we are in (encrypted) relay-mode, forward message directly to client - if err != nil { - mailbox.CommunicationChannel <- relayForwardPayload - } else { - logger.Info("closing relay service") - // close the relay service if sender requested it - mailbox.Quit <- true - return - } + // TODO: Extract closure status out to the Conn implementation + // Would be better to return a custom error, so we are not + // as heavily coupled with the websocket library - // deallocate mailbox and quit - case <-mailbox.Quit: - s.mailboxes.Delete(mailboxPassword) + case websocket.CloseStatus(err) == websocket.StatusNormalClosure: + forwardLogger.Info("connection closed, closing forwarder") + return + case errors.Is(err, context.Canceled): + forwardLogger.Info("context canceled, closing forwarder") + return + case err != nil: + forwardLogger.Error("error reading from connection, closing forwarder", zap.Error(err)) return } + + var msg rendezvous.Msg + if err := json.Unmarshal(payload, &msg); err == nil { + logger.Info("received unencrypted message, closing forwarder") + return + } + forward <- payload } } -//nolint:errcheck -func (s *Server) ping() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusOK) - w.Write([]byte("pong")) +func (s *Server) relay(ctx context.Context, wg *sync.WaitGroup, rc conn.Rendezvous, forward, relayIn <-chan []byte, relayOut chan<- []byte, logger *zap.Logger) { + relayLogger := logger.With(zap.String("component", "relay")) + relayLogger.Info("starting") + defer wg.Done() + defer close(relayOut) + for { + select { + case <-ctx.Done(): + relayLogger.Info("received context done signal") + return + case forwarded, more := <-forward: + if !more { + relayLogger.Info("forwarding channel closed, closing relay") + return + } + relayOut <- forwarded + case relayed, more := <-relayIn: + if !more { + relayLogger.Info("relay channel closed, closing relay") + return + } + if err := rc.WriteRaw(ctx, relayed); err != nil { + relayLogger.Error("writing relayed message to connection") + return + } + } } } diff --git a/internal/rendezvous/mailbox.go b/internal/rendezvous/mailbox.go index 434a5d3..d683890 100644 --- a/internal/rendezvous/mailbox.go +++ b/internal/rendezvous/mailbox.go @@ -8,9 +8,10 @@ import ( // Mailbox is a data structure that links together a sender and a receiver client. type Mailbox struct { - hasReceiver bool - CommunicationChannel chan []byte - Quit chan bool + hasReceiver bool + + Receiver chan []byte // messages to Receiver + Sender chan []byte // messages to Sender } type Mailboxes struct{ *sync.Map } diff --git a/internal/rendezvous/routes.go b/internal/rendezvous/routes.go index 54249f2..40508ba 100644 --- a/internal/rendezvous/routes.go +++ b/internal/rendezvous/routes.go @@ -8,8 +8,7 @@ import ( func (s *Server) routes() { s.router.HandleFunc("/ping", s.ping()) portal := s.router.PathPrefix("").Subrouter() - portal.Use(logger.Middleware(s.logger)) - portal.Use(conn.Middleware()) + portal.Use(logger.Middleware(s.logger), conn.Middleware()) portal.HandleFunc("/establish-sender", s.handleEstablishSender()) portal.HandleFunc("/establish-receiver", s.handleEstablishReceiver()) } diff --git a/internal/sender/sender.go b/internal/sender/sender.go index 003a2a2..b762752 100644 --- a/internal/sender/sender.go +++ b/internal/sender/sender.go @@ -26,7 +26,7 @@ func Init() error { } // ConnectRendezvous creates a connection with the rendezvous server and acquires a password associated with the connection -func ConnectRendezvous(addr string) (conn.Rendezvous, string, error) { +func ConnectRendezvous(ctx context.Context, addr string) (conn.Rendezvous, string, error) { ws, _, err := websocket.Dial(context.Background(), fmt.Sprintf("ws://%s/establish-sender", addr), nil) if err != nil { return conn.Rendezvous{}, "", err @@ -34,13 +34,13 @@ func ConnectRendezvous(addr string) (conn.Rendezvous, string, error) { rc := conn.Rendezvous{Conn: &conn.WS{Conn: ws}} - msg, err := rc.ReadMsg(rendezvous.RendezvousToSenderBind) + msg, err := rc.ReadMsg(ctx, rendezvous.RendezvousToSenderBind) if err != nil { return conn.Rendezvous{}, "", err } pass := password.Generate(msg.Payload.ID) - if err := rc.WriteMsg(rendezvous.Msg{ + if err := rc.WriteMsg(ctx, rendezvous.Msg{ Type: rendezvous.SenderToRendezvousEstablish, Payload: rendezvous.Payload{ Password: password.Hashed(pass), @@ -52,20 +52,20 @@ func ConnectRendezvous(addr string) (conn.Rendezvous, string, error) { } // SecureConnection does the cryptographic handshake in order to resolve a secure channel to do file transfer over. -func SecureConnection(rc conn.Rendezvous, password string) (conn.Transfer, error) { +func SecureConnection(ctx context.Context, rc conn.Rendezvous, password string) (conn.Transfer, error) { p, err := pake.InitCurve([]byte(password), 0, "p256") if err != nil { return conn.Transfer{}, err } // Wait for for the receiver to be ready. - _, err = rc.ReadMsg(rendezvous.RendezvousToSenderReady) + _, err = rc.ReadMsg(ctx, rendezvous.RendezvousToSenderReady) if err != nil { return conn.Transfer{}, err } // Start the key exchange. - err = rc.WriteMsg(rendezvous.Msg{ + err = rc.WriteMsg(ctx, rendezvous.Msg{ Type: rendezvous.SenderToRendezvousPAKE, Payload: rendezvous.Payload{ Bytes: p.Bytes(), @@ -75,7 +75,7 @@ func SecureConnection(rc conn.Rendezvous, password string) (conn.Transfer, error return conn.Transfer{}, err } - msg, err := rc.ReadMsg() + msg, err := rc.ReadMsg(ctx) if err != nil { return conn.Transfer{}, err } @@ -95,7 +95,7 @@ func SecureConnection(rc conn.Rendezvous, password string) (conn.Transfer, error return conn.Transfer{}, err } - err = rc.WriteMsg(rendezvous.Msg{ + err = rc.WriteMsg(ctx, rendezvous.Msg{ Type: rendezvous.SenderToRendezvousSalt, Payload: rendezvous.Payload{ Salt: salt, @@ -109,13 +109,13 @@ func SecureConnection(rc conn.Rendezvous, password string) (conn.Transfer, error } // Transfer performs the file transfer, either directly or using the Rendezvous server as a relay. -func Transfer(tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ...chan interface{}) error { - return doTransfer(tc, payload, payloadSize, msgs...) +func Transfer(ctx context.Context, tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ...chan interface{}) error { + return doTransfer(ctx, tc, payload, payloadSize, msgs...) } // transferSequence is a helper method that actually performs the transfer sequence. -func transferSequence(tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ...chan interface{}) error { - _, err := tc.ReadMsg(transfer.ReceiverRequestPayload) +func transferSequence(ctx context.Context, tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ...chan interface{}) error { + _, err := tc.ReadMsg(ctx, transfer.ReceiverRequestPayload) if err != nil { return err } @@ -124,20 +124,20 @@ func transferSequence(tc conn.Transfer, payload io.Reader, payloadSize int64, ms msgs[0] <- transfer.ReceiverRequestPayload } - if err := transferPayload(tc, payload, payloadSize, msgs...); err != nil { + if err := transferPayload(ctx, tc, payload, payloadSize, msgs...); err != nil { return err } - if err := tc.WriteMsg(transfer.Msg{Type: transfer.SenderPayloadSent}); err != nil { + if err := tc.WriteMsg(ctx, transfer.Msg{Type: transfer.SenderPayloadSent}); err != nil { return err } - _, err = tc.ReadMsg(transfer.ReceiverPayloadAck) + _, err = tc.ReadMsg(ctx, transfer.ReceiverPayloadAck) if err != nil { return err } - if err := tc.WriteMsg(transfer.Msg{Type: transfer.SenderClosing}); err != nil { + if err := tc.WriteMsg(ctx, transfer.Msg{Type: transfer.SenderClosing}); err != nil { return err } @@ -145,7 +145,7 @@ func transferSequence(tc conn.Transfer, payload io.Reader, payloadSize int64, ms } // transferPayload sends the files in chunks to the sender. -func transferPayload(tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ...chan interface{}) error { +func transferPayload(ctx context.Context, tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ...chan interface{}) error { bufReader := bufio.NewReader(payload) buffer := make([]byte, chunkSize(payloadSize)) bytesSent := 0 @@ -158,7 +158,7 @@ func transferPayload(tc conn.Transfer, payload io.Reader, payloadSize int64, msg if err != nil { return err } - err = tc.WriteEncryptedBytes(buffer[:n]) + err = tc.WriteRaw(ctx, buffer[:n]) if err != nil { return err } diff --git a/internal/sender/server.go b/internal/sender/server.go index cf8252f..691b61e 100644 --- a/internal/sender/server.go +++ b/internal/sender/server.go @@ -84,7 +84,7 @@ func (s *server) handleTransfer(key []byte, payload io.Reader, payloadSize int64 return } tc := conn.TransferFromKey(&conn.WS{Conn: ws}, key) - if err != transferSequence(tc, payload, payloadSize, msgs...) { + if err != transferSequence(context.Background(), tc, payload, payloadSize, msgs...) { s.Err = err return } diff --git a/internal/sender/transfer.go b/internal/sender/transfer.go index 6edd676..b1314b2 100644 --- a/internal/sender/transfer.go +++ b/internal/sender/transfer.go @@ -3,6 +3,7 @@ package sender import ( + "context" "fmt" "io" "log" @@ -14,8 +15,8 @@ import ( // doTransfer performs the file transfer, either directly or using the Rendezvous server as a relay. // This version is built for other platforms other than js (wasm) -func doTransfer(tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ...chan interface{}) error { - _, err := tc.ReadMsg(transfer.ReceiverHandshake) +func doTransfer(ctx context.Context, tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ...chan interface{}) error { + _, err := tc.ReadMsg(ctx, transfer.ReceiverHandshake) if err != nil { return err } @@ -39,7 +40,7 @@ func doTransfer(tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ... return err } - if err := tc.WriteMsg(transfer.Msg{ + if err := tc.WriteMsg(ctx, transfer.Msg{ Type: transfer.SenderHandshake, Payload: transfer.Payload{ IP: ip, @@ -50,7 +51,7 @@ func doTransfer(tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ... return err } - msg, err := tc.ReadMsg() + msg, err := tc.ReadMsg(ctx) if err != nil { return err } @@ -61,7 +62,7 @@ func doTransfer(tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ... if len(msgs) > 0 { msgs[0] <- transfer.Direct } - if err := tc.WriteMsg(transfer.Msg{Type: transfer.SenderDirectAck}); err != nil { + if err := tc.WriteMsg(ctx, transfer.Msg{Type: transfer.SenderDirectAck}); err != nil { return err } @@ -74,11 +75,11 @@ func doTransfer(tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ... if len(msgs) > 0 { msgs[0] <- transfer.Relay } - if err := tc.WriteMsg(transfer.Msg{Type: transfer.SenderRelayAck}); err != nil { + if err := tc.WriteMsg(ctx, transfer.Msg{Type: transfer.SenderRelayAck}); err != nil { return err } - return transferSequence(tc, payload, payloadSize, msgs...) + return transferSequence(ctx, tc, payload, payloadSize, msgs...) default: return transfer.Error{ diff --git a/internal/sender/transfer_wasm.go b/internal/sender/transfer_wasm.go index dcbab62..c15c830 100644 --- a/internal/sender/transfer_wasm.go +++ b/internal/sender/transfer_wasm.go @@ -3,6 +3,7 @@ package sender import ( + "context" "io" "net" @@ -12,13 +13,13 @@ import ( // doTransfer performs the file transfer directly, no relay. This function is only built for the // js platform (wasm) -func doTransfer(tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ...chan interface{}) error { - _, err := tc.ReadMsg(transfer.ReceiverHandshake) +func doTransfer(ctx context.Context, tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ...chan interface{}) error { + _, err := tc.ReadMsg(ctx, transfer.ReceiverHandshake) if err != nil { return err } - if err := tc.WriteMsg(transfer.Msg{ + if err := tc.WriteMsg(ctx, transfer.Msg{ Type: transfer.SenderHandshake, Payload: transfer.Payload{ IP: net.IP{}, @@ -29,7 +30,7 @@ func doTransfer(tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ... return err } - msg, err := tc.ReadMsg() + msg, err := tc.ReadMsg(ctx) if err != nil { return err } @@ -37,10 +38,10 @@ func doTransfer(tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ... switch msg.Type { // Direct transfer. case transfer.ReceiverRelayCommunication: - if err := tc.WriteMsg(transfer.Msg{Type: transfer.SenderRelayAck}); err != nil { + if err := tc.WriteMsg(ctx, transfer.Msg{Type: transfer.SenderRelayAck}); err != nil { return err } - return transferSequence(tc, payload, payloadSize) + return transferSequence(ctx, tc, payload, payloadSize) default: return transfer.Error{ diff --git a/portal/portal.go b/portal/portal.go index 8b8dd02..ad260bc 100644 --- a/portal/portal.go +++ b/portal/portal.go @@ -1,6 +1,7 @@ package portal import ( + "context" "io" "github.com/SpatiumPortae/portal/internal/receiver" @@ -12,24 +13,24 @@ import ( // asynchronously. The function returns a portal password, a error from the rendezvous // intial rendezvous connection, and a channel on which errors from the transfer sequence // can be listend to. The provided config will be merged with the default config. -func Send(payload io.Reader, payloadSize int64, config *Config) (string, error, chan error) { +func Send(ctx context.Context, payload io.Reader, payloadSize int64, config *Config) (string, error, chan error) { merged := MergeConfig(defaultConfig, config) if err := sender.Init(); err != nil { return "", err, nil } errC := make(chan error, 1) // buffer channel as to not block send. - rc, password, err := sender.ConnectRendezvous(merged.RendezvousAddr) + rc, password, err := sender.ConnectRendezvous(ctx, merged.RendezvousAddr) if err != nil { return "", err, nil } go func() { defer close(errC) - tc, err := sender.SecureConnection(rc, password) + tc, err := sender.SecureConnection(ctx, rc, password) if err != nil { errC <- err return } - if err := sender.Transfer(tc, payload, payloadSize); err != nil { + if err := sender.Transfer(ctx, tc, payload, payloadSize); err != nil { errC <- err return } @@ -40,17 +41,17 @@ func Send(payload io.Reader, payloadSize int64, config *Config) (string, error, // Receive executes the portal receive sequence. The payload is written // to the provided writer. The provided config will be merged with the // default config. -func Receive(dst io.Writer, password string, config *Config) error { +func Receive(ctx context.Context, dst io.Writer, password string, config *Config) error { merged := MergeConfig(defaultConfig, config) rc, err := receiver.ConnectRendezvous(merged.RendezvousAddr) if err != nil { return err } - tc, err := receiver.SecureConnection(rc, password) + tc, err := receiver.SecureConnection(ctx, rc, password) if err != nil { return err } - if err := receiver.Receive(tc, dst); err != nil { + if err := receiver.Receive(ctx, tc, dst); err != nil { return err } return nil diff --git a/portal/portal_test.go b/portal/portal_test.go index aed5949..65b0c43 100644 --- a/portal/portal_test.go +++ b/portal/portal_test.go @@ -42,10 +42,10 @@ func TestE2E(t *testing.T) { in := bytes.NewBufferString(oracle) out := &bytes.Buffer{} - password, err, errC := portal.Send(in, int64(in.Len()), &config) + password, err, errC := portal.Send(context.Background(), in, int64(in.Len()), &config) assert.Nil(t, err) - err = portal.Receive(out, password, &config) + err = portal.Receive(context.Background(), out, password, &config) assert.Nil(t, err) assert.Nil(t, <-errC) assert.Equal(t, oracle, out.String()) diff --git a/ui/receiver/receiver.go b/ui/receiver/receiver.go index 013c1fc..045f8da 100644 --- a/ui/receiver/receiver.go +++ b/ui/receiver/receiver.go @@ -1,6 +1,7 @@ package receiver import ( + "context" "fmt" "math" "os" @@ -67,6 +68,7 @@ type model struct { password string errorMessage string + ctx context.Context msgs chan interface{} rendezvousAddr string @@ -94,6 +96,7 @@ func New(addr string, password string, opts ...Option) *tea.Program { rendezvousAddr: addr, help: help.New(), keys: ui.Keys, + ctx: context.Background(), } for _, opt := range opts { opt(&m) @@ -133,12 +136,12 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { case connectMsg: message := fmt.Sprintf("Connected to Portal server (%s)", m.rendezvousAddr) - return m, ui.TaskCmd(message, secureCmd(msg.conn, m.password)) + return m, ui.TaskCmd(message, secureCmd(m.ctx, msg.conn, m.password)) case ui.SecureMsg: message := "Established encrypted connection to sender" return m, ui.TaskCmd(message, - tea.Batch(listenReceiveCmd(m.msgs), receiveCmd(msg.Conn, m.msgs))) + tea.Batch(listenReceiveCmd(m.msgs), receiveCmd(m.ctx, msg.Conn, m.msgs))) case payloadSizeMsg: m.payloadSize = msg.size @@ -283,9 +286,9 @@ func connectCmd(addr string) tea.Cmd { } } -func secureCmd(rc conn.Rendezvous, password string) tea.Cmd { +func secureCmd(ctx context.Context, rc conn.Rendezvous, password string) tea.Cmd { return func() tea.Msg { - tc, err := receiver.SecureConnection(rc, password) + tc, err := receiver.SecureConnection(ctx, rc, password) if err != nil { return ui.ErrorMsg(err) } @@ -293,13 +296,13 @@ func secureCmd(rc conn.Rendezvous, password string) tea.Cmd { } } -func receiveCmd(tc conn.Transfer, msgs ...chan interface{}) tea.Cmd { +func receiveCmd(ctx context.Context, tc conn.Transfer, msgs ...chan interface{}) tea.Cmd { return func() tea.Msg { temp, err := os.CreateTemp(os.TempDir(), file.RECEIVE_TEMP_FILE_NAME_PREFIX) if err != nil { return ui.ErrorMsg(err) } - if err := receiver.Receive(tc, temp, msgs...); err != nil { + if err := receiver.Receive(ctx, tc, temp, msgs...); err != nil { return ui.ErrorMsg(err) } return receiveDoneMsg{temp: temp} diff --git a/ui/sender/sender.go b/ui/sender/sender.go index 05d0e2c..21eaac1 100644 --- a/ui/sender/sender.go +++ b/ui/sender/sender.go @@ -1,6 +1,7 @@ package sender import ( + "context" "fmt" "io" "os" @@ -72,6 +73,7 @@ type model struct { transferType transfer.Type // defaults to 0 (Unknown) errorMessage string readyToSend bool + ctx context.Context msgs chan interface{} @@ -104,6 +106,7 @@ func New(filenames []string, addr string, opts ...Option) *tea.Program { help: help.New(), keys: ui.Keys, copyMessageTimer: timer.NewWithInterval(ui.TEMP_UI_MESSAGE_DURATION, 100*time.Millisecond), + ctx: context.Background(), } m.keys.FileListUp.SetEnabled(true) m.keys.FileListDown.SetEnabled(true) @@ -119,7 +122,7 @@ func (m model) Init() tea.Cmd { if m.version != nil { versionCmd = ui.VersionCmd(*m.version) } - return tea.Sequence(versionCmd, tea.Batch(m.spinner.Tick, readFilesCmd(m.fileNames), connectCmd(m.rendezvousAddr))) + return tea.Sequence(versionCmd, tea.Batch(m.spinner.Tick, readFilesCmd(m.fileNames), connectCmd(m.ctx, m.rendezvousAddr))) } // ------------------------------------------------------- Update ------------------------------------------------------ @@ -170,7 +173,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { m.keys.CopyPassword.SetEnabled(true) m.password = msg.password connectMessage := fmt.Sprintf("Connected to Portal server (%s)", m.rendezvousAddr) - return m, ui.TaskCmd(connectMessage, secureCmd(msg.conn, msg.password)) + return m, ui.TaskCmd(connectMessage, secureCmd(m.ctx, msg.conn, msg.password)) case timer.TickMsg: var cmd tea.Cmd @@ -207,7 +210,7 @@ func (m model) Update(msg tea.Msg) (tea.Model, tea.Cmd) { } cmd := tea.Batch( listenTransferCmd(m.msgs), - transferCmd(msg.Conn, m.payload, m.payloadSize, m.msgs)) + transferCmd(m.ctx, msg.Conn, m.payload, m.payloadSize, m.msgs)) return m, cmd case ui.TransferStateMessage: @@ -350,9 +353,9 @@ func (m model) View() string { // ------------------------------------------------------ Commands ----------------------------------------------------- // connectCmd command that connects to the rendezvous server. -func connectCmd(addr string) tea.Cmd { +func connectCmd(ctx context.Context, addr string) tea.Cmd { return func() tea.Msg { - rc, password, err := sender.ConnectRendezvous(addr) + rc, password, err := sender.ConnectRendezvous(ctx, addr) if err != nil { return ui.ErrorMsg(err) } @@ -361,9 +364,9 @@ func connectCmd(addr string) tea.Cmd { } // secureCmd command that secures a connection for transfer. -func secureCmd(rc conn.Rendezvous, password string) tea.Cmd { +func secureCmd(ctx context.Context, rc conn.Rendezvous, password string) tea.Cmd { return func() tea.Msg { - tc, err := sender.SecureConnection(rc, password) + tc, err := sender.SecureConnection(ctx, rc, password) if err != nil { return ui.ErrorMsg(err) } @@ -373,9 +376,9 @@ func secureCmd(rc conn.Rendezvous, password string) tea.Cmd { // transferCmd command that does the transfer sequence. // The msgs channel is used to provide intermediate messages to the ui. -func transferCmd(tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ...chan interface{}) tea.Cmd { +func transferCmd(ctx context.Context, tc conn.Transfer, payload io.Reader, payloadSize int64, msgs ...chan interface{}) tea.Cmd { return func() tea.Msg { - err := sender.Transfer(tc, payload, payloadSize, msgs...) + err := sender.Transfer(ctx, tc, payload, payloadSize, msgs...) if err != nil { return ui.ErrorMsg(err) }