Skip to content

Commit

Permalink
F/rendezvous id bug (#36)
Browse files Browse the repository at this point in the history
* feat: make conn context-aware

* feat: attempt to solve id allocation bug

 * decided to make the mailbox communication is split into two channels,
  as it is easier to manage

* feat: expand makefile

---------

Co-authored-by: Zino Kader <[email protected]>
  • Loading branch information
mellonnen and ZinoKader committed Feb 16, 2023
1 parent 458213f commit 1ed2c82
Show file tree
Hide file tree
Showing 18 changed files with 288 additions and 216 deletions.
18 changes: 10 additions & 8 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
.PHONY: build run lint test test-e2e build-wasm
.PHONY: serve lint test test-e2e build-wasm image

LINKER_FLAGS = '-s -X main.version=${PORTAL_VERSION}'

lint:
golangci-lint run --timeout 5m ./...

build:
go build -o portal ./cmd/portal/
build:
go build -o portal-bin ./cmd/portal/

build-production:
CGO=0 go build -ldflags=${LINKER_FLAGS} -o portal ./cmd/portal/

image:
docker build --tag rendezvous:latest .

serve: image
docker run -dp 8080:8080 rendezvous:latest

build-wasm:
GOOS=js GOARCH=wasm go build -o portal.wasm ./cmd/wasm/main.go

run: build
./portal -p 8080

test:
go test -v -race -covermode=atomic -coverprofile=coverage.out -failfast -short ./...

test-e2e:
docker build --tag rendezvous:latest .
test-e2e: image
go test -v -race -covermode=atomic -coverprofile=coverage.out -failfast ./...
5 changes: 3 additions & 2 deletions cmd/wasm/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ package main

import (
"bytes"
"context"
"syscall/js"

"github.com/SpatiumPortae/portal/portal"
Expand Down Expand Up @@ -39,7 +40,7 @@ func SendJs() js.Func {
}
// Top-level promise.
transferHandler := promiseHandler(func(resolve, reject js.Value) {
password, err, errCh := portal.Send(payload, int64(payload.Len()), cnf)
password, err, errCh := portal.Send(context.Background(), payload, int64(payload.Len()), cnf)
if err != nil {
reject.Invoke(Error.New(err.Error()))
return
Expand Down Expand Up @@ -75,7 +76,7 @@ func ReceiveJs() js.Func {

var buf bytes.Buffer
transferHandler := promiseHandler(func(resolve, reject js.Value) {
if err := portal.Receive(&buf, password, cnf); err != nil {
if err := portal.Receive(context.Background(), &buf, password, cnf); err != nil {
reject.Invoke(Error.New(err.Error()))
return
}
Expand Down
52 changes: 26 additions & 26 deletions internal/conn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ const MESSAGE_SIZE_LIMIT_BYTES = math.MaxInt64 - 1

// Conn is an interface that wraps a network connection.
type Conn interface {
Read() ([]byte, error)
Write([]byte) error
Read(context.Context) ([]byte, error)
Write(context.Context, []byte) error
}

// ------------------ Conn implementations ------------------
Expand All @@ -26,15 +26,15 @@ type WS struct {
Conn *websocket.Conn
}

func (ws *WS) Read() ([]byte, error) {
func (ws *WS) Read(ctx context.Context) ([]byte, error) {
// this limit is per-message and thus needs to be set before each read
ws.Conn.SetReadLimit(MESSAGE_SIZE_LIMIT_BYTES)
_, payload, err := ws.Conn.Read(context.Background())
_, payload, err := ws.Conn.Read(ctx)
return payload, err
}

func (ws *WS) Write(payload []byte) error {
return ws.Conn.Write(context.Background(), websocket.MessageBinary, payload)
func (ws *WS) Write(ctx context.Context, payload []byte) error {
return ws.Conn.Write(ctx, websocket.MessageBinary, payload)
}

// ------------------ Rendezvous Conn ------------------------
Expand All @@ -44,24 +44,24 @@ type Rendezvous struct {
Conn Conn
}

// ReadBytes reads raw bytes from the underlying connection.
func (r Rendezvous) ReadBytes() ([]byte, error) {
b, err := r.Conn.Read()
// ReadRaw reads raw bytes from the underlying connection.
func (r Rendezvous) ReadRaw(ctx context.Context) ([]byte, error) {
b, err := r.Conn.Read(ctx)
if err != nil {
return nil, err
}
return b, err
}

// WriteBytes writes raw bytes to the underlying connection.
func (r Rendezvous) WriteBytes(b []byte) error {
err := r.Conn.Write(b)
// WriteRaw writes raw bytes to the underlying connection.
func (r Rendezvous) WriteRaw(ctx context.Context, b []byte) error {
err := r.Conn.Write(ctx, b)
return err
}

// ReadMsg reads a rendezvous message from the underlying connection.
func (r Rendezvous) ReadMsg(expected ...rendezvous.MsgType) (rendezvous.Msg, error) {
b, err := r.Conn.Read()
func (r Rendezvous) ReadMsg(ctx context.Context, expected ...rendezvous.MsgType) (rendezvous.Msg, error) {
b, err := r.Conn.Read(ctx)
if err != nil {
return rendezvous.Msg{}, err
}
Expand All @@ -76,12 +76,12 @@ func (r Rendezvous) ReadMsg(expected ...rendezvous.MsgType) (rendezvous.Msg, err
}

// WriteMsg writes a rendezvous message to the underlying connection.
func (r Rendezvous) WriteMsg(msg rendezvous.Msg) error {
func (r Rendezvous) WriteMsg(ctx context.Context, msg rendezvous.Msg) error {
payload, err := json.Marshal(msg)
if err != nil {
return err
}
return r.Conn.Write(payload)
return r.Conn.Write(ctx, payload)
}

// ------------------ Transfer Conn ----------------------------
Expand Down Expand Up @@ -114,27 +114,27 @@ func (tc Transfer) Key() []byte {
return tc.crypt.Key
}

// ReadEncryptedBytes reads and decrypts bytes from the underlying connection.
func (t Transfer) ReadEncryptedBytes() ([]byte, error) {
b, err := t.Conn.Read()
// ReadRaw reads and decrypts raw bytes from the underlying connection.
func (t Transfer) ReadRaw(ctx context.Context) ([]byte, error) {
b, err := t.Conn.Read(ctx)
if err != nil {
return nil, err
}
return t.crypt.Decrypt(b)
}

// WriteEncryptedBytes encrypts and writes the specified bytes to the underlying connection.
func (t Transfer) WriteEncryptedBytes(b []byte) error {
// WriteRaw encrypts and writes the raw bytes to the underlying connection.
func (t Transfer) WriteRaw(ctx context.Context, b []byte) error {
enc, err := t.crypt.Encrypt(b)
if err != nil {
return nil
}
return t.Conn.Write(enc)
return t.Conn.Write(ctx, enc)
}

// ReadMsg reads and decrypts the specified transfer message from the underlying connection.
func (t Transfer) ReadMsg(expected ...transfer.MsgType) (transfer.Msg, error) {
dec, err := t.ReadEncryptedBytes()
func (t Transfer) ReadMsg(ctx context.Context, expected ...transfer.MsgType) (transfer.Msg, error) {
dec, err := t.ReadRaw(ctx)
if err != nil {
return transfer.Msg{}, err
}
Expand All @@ -150,10 +150,10 @@ func (t Transfer) ReadMsg(expected ...transfer.MsgType) (transfer.Msg, error) {
}

// WriteMsg encrypts and writes the specified transfer message to the underlying connection.
func (t Transfer) WriteMsg(msg transfer.Msg) error {
func (t Transfer) WriteMsg(ctx context.Context, msg transfer.Msg) error {
b, err := json.Marshal(msg)
if err != nil {
return err
}
return t.WriteEncryptedBytes(b)
return t.WriteRaw(ctx, b)
}
15 changes: 9 additions & 6 deletions internal/conn/conn_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package conn_test

import (
"context"
"crypto/rand"
"testing"

Expand All @@ -14,12 +15,12 @@ type mockConn struct {
conn chan []byte
}

func (m mockConn) Write(b []byte) error {
func (m mockConn) Write(ctx context.Context, b []byte) error {
m.conn <- b
return nil
}

func (m mockConn) Read() ([]byte, error) {
func (m mockConn) Read(ctx context.Context) ([]byte, error) {
return <-m.conn, nil
}

Expand All @@ -32,10 +33,11 @@ func TestConn(t *testing.T) {
r1 := conn.Rendezvous{Conn: conn1}
r2 := conn.Rendezvous{Conn: conn2}

err := r1.WriteMsg(rendezvous.Msg{Type: rendezvous.SenderToRendezvousEstablish})
ctx := context.Background()
err := r1.WriteMsg(ctx, rendezvous.Msg{Type: rendezvous.SenderToRendezvousEstablish})
assert.NoError(t, err)

msg, err := r2.ReadMsg()
msg, err := r2.ReadMsg(ctx)
assert.NoError(t, err)
assert.Equal(t, msg.Type, rendezvous.SenderToRendezvousEstablish)
})
Expand All @@ -46,13 +48,14 @@ func TestConn(t *testing.T) {
_, err := rand.Read(salt)
assert.NoError(t, err)

ctx := context.Background()
t1 := conn.TransferFromSession(&conn1, sessionkey, salt)
t2 := conn.TransferFromSession(&conn2, sessionkey, salt)

err = t1.WriteMsg(transfer.Msg{Type: transfer.ReceiverHandshake})
err = t1.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverHandshake})
assert.NoError(t, err)

msg, err := t2.ReadMsg()
msg, err := t2.ReadMsg(ctx)
assert.NoError(t, err)
assert.Equal(t, msg.Type, transfer.ReceiverHandshake)
})
Expand Down
22 changes: 11 additions & 11 deletions internal/receiver/receive.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ import (

// doReceive performs the transfer protocol on the receiving end.
// This function is built for all platforms except js
func doReceive(relay conn.Transfer, addr string, dst io.Writer, msgs ...chan interface{}) error {
func doReceive(ctx context.Context, relay conn.Transfer, addr string, dst io.Writer, msgs ...chan interface{}) error {

// Retrieve a unencrypted channel to rendezvous.
rc := conn.Rendezvous{Conn: relay.Conn}
Expand All @@ -27,10 +27,10 @@ func doReceive(relay conn.Transfer, addr string, dst io.Writer, msgs ...chan int
if err != nil {
tc = relay
// Communicate to the sender that we are using relay transfer.
if err := relay.WriteMsg(transfer.Msg{Type: transfer.ReceiverRelayCommunication}); err != nil {
if err := relay.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverRelayCommunication}); err != nil {
return err
}
_, err := relay.ReadMsg(transfer.SenderRelayAck)
_, err := relay.ReadMsg(ctx, transfer.SenderRelayAck)
if err != nil {
return err
}
Expand All @@ -41,12 +41,12 @@ func doReceive(relay conn.Transfer, addr string, dst io.Writer, msgs ...chan int
} else {
tc = direct
// Communicate to the sender that we are doing direct communication.
if err := relay.WriteMsg(transfer.Msg{Type: transfer.ReceiverDirectCommunication}); err != nil {
if err := relay.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverDirectCommunication}); err != nil {
return err
}

// Tell rendezvous server that we can close the connection.
if err := rc.WriteMsg(rendezvous.Msg{Type: rendezvous.ReceiverToRendezvousClose}); err != nil {
if err := rc.WriteMsg(ctx, rendezvous.Msg{Type: rendezvous.ReceiverToRendezvousClose}); err != nil {
return err
}

Expand All @@ -56,30 +56,30 @@ func doReceive(relay conn.Transfer, addr string, dst io.Writer, msgs ...chan int
}

// Request the payload and receive it.
if tc.WriteMsg(transfer.Msg{Type: transfer.ReceiverRequestPayload}) != nil {
if tc.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverRequestPayload}) != nil {
return err
}
if err := receivePayload(tc, dst, msgs...); err != nil {
if err := receivePayload(ctx, tc, dst, msgs...); err != nil {
return err
}

// Closing handshake.
if err := tc.WriteMsg(transfer.Msg{Type: transfer.ReceiverPayloadAck}); err != nil {
if err := tc.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverPayloadAck}); err != nil {
return err
}

_, err = tc.ReadMsg(transfer.SenderClosing)
_, err = tc.ReadMsg(ctx, transfer.SenderClosing)

if err != nil {
return err
}

if err := tc.WriteMsg(transfer.Msg{Type: transfer.ReceiverClosingAck}); err != nil {
if err := tc.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverClosingAck}); err != nil {
return err
}

// Tell rendezvous to close connection.
if err := rc.WriteMsg(rendezvous.Msg{Type: rendezvous.ReceiverToRendezvousClose}); err != nil {
if err := rc.WriteMsg(ctx, rendezvous.Msg{Type: rendezvous.ReceiverToRendezvousClose}); err != nil {
return err
}
return nil
Expand Down
19 changes: 10 additions & 9 deletions internal/receiver/receive_wasm.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package receiver

import (
"context"
"io"

"github.com/SpatiumPortae/portal/internal/conn"
Expand All @@ -12,12 +13,12 @@ import (

// doReceive performs the transfer protocol on the receiving end.
// This function is only built for the js platform.
func doReceive(relayTc conn.Transfer, addr string, dst io.Writer, msgs ...chan interface{}) error {
func doReceive(ctx context.Context, relayTc conn.Transfer, addr string, dst io.Writer, msgs ...chan interface{}) error {
// Communicate to the sender that we are using relay transfer.
if err := relayTc.WriteMsg(transfer.Msg{Type: transfer.ReceiverRelayCommunication}); err != nil {
if err := relayTc.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverRelayCommunication}); err != nil {
return err
}
_, err := relayTc.ReadMsg(transfer.SenderRelayAck)
_, err := relayTc.ReadMsg(ctx, transfer.SenderRelayAck)
if err != nil {
return err
}
Expand All @@ -27,31 +28,31 @@ func doReceive(relayTc conn.Transfer, addr string, dst io.Writer, msgs ...chan i
}

// Request the payload and receive it.
if relayTc.WriteMsg(transfer.Msg{Type: transfer.ReceiverRequestPayload}) != nil {
if relayTc.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverRequestPayload}) != nil {
return err
}
if err := receivePayload(relayTc, dst, msgs...); err != nil {
if err := receivePayload(ctx, relayTc, dst, msgs...); err != nil {
return err
}

// Closing handshake.
if err := relayTc.WriteMsg(transfer.Msg{Type: transfer.ReceiverPayloadAck}); err != nil {
if err := relayTc.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverPayloadAck}); err != nil {
return err
}

_, err = relayTc.ReadMsg(transfer.SenderClosing)
_, err = relayTc.ReadMsg(ctx, transfer.SenderClosing)

if err != nil {
return err
}

if err := relayTc.WriteMsg(transfer.Msg{Type: transfer.ReceiverClosingAck}); err != nil {
if err := relayTc.WriteMsg(ctx, transfer.Msg{Type: transfer.ReceiverClosingAck}); err != nil {
return err
}

// Retrieve a unencrypted channel to rendezvous.
rc := conn.Rendezvous{Conn: relayTc.Conn}
if err := rc.WriteMsg(rendezvous.Msg{Type: rendezvous.ReceiverToRendezvousClose}); err != nil {
if err := rc.WriteMsg(ctx, rendezvous.Msg{Type: rendezvous.ReceiverToRendezvousClose}); err != nil {
return err
}
return nil
Expand Down
Loading

0 comments on commit 1ed2c82

Please sign in to comment.