Skip to content

Commit

Permalink
improve error handling in writePacket (#1601)
Browse files Browse the repository at this point in the history
* handle error before success case.
* return io.ErrShortWrite if not all bytes were written but err is nil.
* return err instead of ErrInvalidConn.
  • Loading branch information
methane committed Jun 28, 2024
1 parent 52c1917 commit 3484db1
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 19 deletions.
6 changes: 4 additions & 2 deletions connection_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ func TestPingMarkBadConnection(t *testing.T) {
netConn: nc,
buf: newBuffer(nc),
maxAllowedPacket: defaultMaxAllowedPacket,
closech: make(chan struct{}),
cfg: NewConfig(),
}

err := mc.Ping(context.Background())
Expand All @@ -184,8 +186,8 @@ func TestPingErrInvalidConn(t *testing.T) {

err := mc.Ping(context.Background())

if err != ErrInvalidConn {
t.Errorf("expected ErrInvalidConn, got %#v", err)
if err != nc.err {
t.Errorf("expected %#v, got %#v", nc.err, err)
}
}

Expand Down
34 changes: 17 additions & 17 deletions packets.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,32 +124,32 @@ func (mc *mysqlConn) writePacket(data []byte) error {
}

n, err := mc.netConn.Write(data[:4+size])
if err == nil && n == 4+size {
mc.sequence++
if size != maxPacketSize {
return nil
}
pktLen -= size
data = data[size:]
continue
}

// Handle error
if err == nil { // n != len(data)
mc.cleanup()
mc.log(ErrMalformPkt)
} else {
if err != nil {
if cerr := mc.canceled.Value(); cerr != nil {
return cerr
}
mc.cleanup()
if n == 0 && pktLen == len(data)-4 {
// only for the first loop iteration when nothing was written yet
mc.log(err)
return errBadConnNoWrite
} else {
return err
}
}
if n != 4+size {
// io.Writer(b) must return a non-nil error if it cannot write len(b) bytes.
// The io.ErrShortWrite error is used to indicate that this rule has not been followed.
mc.cleanup()
mc.log(err)
return io.ErrShortWrite
}

mc.sequence++
if size != maxPacketSize {
return nil
}
return ErrInvalidConn
pktLen -= size
data = data[size:]
}
}

Expand Down

0 comments on commit 3484db1

Please sign in to comment.