Skip to content

Commit

Permalink
add GSS authentication to pgproto3
Browse files Browse the repository at this point in the history
  • Loading branch information
otan authored and jackc committed Apr 12, 2022
1 parent c6ccb4b commit 175856f
Show file tree
Hide file tree
Showing 6 changed files with 222 additions and 7 deletions.
58 changes: 58 additions & 0 deletions authentication_gss.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
package pgproto3

import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)

type AuthenticationGSS struct{}

func (a *AuthenticationGSS) Backend() {}

func (a *AuthenticationGSS) AuthenticationResponse() {}

func (a *AuthenticationGSS) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}

authType := binary.BigEndian.Uint32(src)

if authType != AuthTypeGSS {
return errors.New("bad auth type")
}
return nil
}

func (a *AuthenticationGSS) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, 4)
dst = pgio.AppendUint32(dst, AuthTypeGSS)
return dst
}

func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "AuthenticationGSS",
})
}

func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}

var msg struct {
Type string
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
return nil
}
67 changes: 67 additions & 0 deletions authentication_gss_continue.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
package pgproto3

import (
"encoding/binary"
"encoding/json"
"errors"
"github.com/jackc/pgio"
)

type AuthenticationGSSContinue struct {
Data []byte
}

func (a *AuthenticationGSSContinue) Backend() {}

func (a *AuthenticationGSSContinue) AuthenticationResponse() {}

func (a *AuthenticationGSSContinue) Decode(src []byte) error {
if len(src) < 4 {
return errors.New("authentication message too short")
}

authType := binary.BigEndian.Uint32(src)

if authType != AuthTypeGSSCont {
return errors.New("bad auth type")
}

a.Data = src[4:]
return nil
}

func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte {
dst = append(dst, 'R')
dst = pgio.AppendInt32(dst, int32(len(a.Data))+8)
dst = pgio.AppendUint32(dst, AuthTypeGSSCont)
dst = append(dst, a.Data...)
return dst
}

func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "AuthenticationGSSContinue",
Data: a.Data,
})
}

func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error {
// Ignore null, like in the main JSON package.
if string(data) == "null" {
return nil
}

var msg struct {
Type string
Data []byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}

a.Data = msg.Data
return nil
}
11 changes: 6 additions & 5 deletions backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,10 @@ type Backend struct {
sync Sync
terminate Terminate

bodyLen int
msgType byte
partialMsg bool
authType uint32

bodyLen int
msgType byte
partialMsg bool
authType uint32
}

const (
Expand Down Expand Up @@ -147,6 +146,8 @@ func (b *Backend) Receive() (FrontendMessage, error) {
msg = &SASLResponse{}
case AuthTypeSASLFinal:
msg = &SASLResponse{}
case AuthTypeGSS, AuthTypeGSSCont:
msg = &GSSResponse{}
case AuthTypeCleartextPassword, AuthTypeMD5Password:
fallthrough
default:
Expand Down
6 changes: 4 additions & 2 deletions frontend.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ type Frontend struct {
authenticationOk AuthenticationOk
authenticationCleartextPassword AuthenticationCleartextPassword
authenticationMD5Password AuthenticationMD5Password
authenticationGSS AuthenticationGSS
authenticationGSSContinue AuthenticationGSSContinue
authenticationSASL AuthenticationSASL
authenticationSASLContinue AuthenticationSASLContinue
authenticationSASLFinal AuthenticationSASLFinal
Expand Down Expand Up @@ -178,9 +180,9 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er
case AuthTypeSCMCreds:
return nil, errors.New("AuthTypeSCMCreds is unimplemented")
case AuthTypeGSS:
return nil, errors.New("AuthTypeGSS is unimplemented")
return &f.authenticationGSS, nil
case AuthTypeGSSCont:
return nil, errors.New("AuthTypeGSSCont is unimplemented")
return &f.authenticationGSSContinue, nil
case AuthTypeSSPI:
return nil, errors.New("AuthTypeSSPI is unimplemented")
case AuthTypeSASL:
Expand Down
48 changes: 48 additions & 0 deletions gss_response.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package pgproto3

import (
"encoding/json"
"github.com/jackc/pgio"
)

type GSSResponse struct {
Data []byte
}

// Frontend identifies this message as sendable by a PostgreSQL frontend.
func (g *GSSResponse) Frontend() {}

func (g *GSSResponse) Decode(data []byte) error {
g.Data = data
return nil
}

func (g *GSSResponse) Encode(dst []byte) []byte {
dst = append(dst, 'p')
dst = pgio.AppendInt32(dst, int32(4+len(g.Data)))
dst = append(dst, g.Data...)
return dst
}

// MarshalJSON implements encoding/json.Marshaler.
func (g *GSSResponse) MarshalJSON() ([]byte, error) {
return json.Marshal(struct {
Type string
Data []byte
}{
Type: "GSSResponse",
Data: g.Data,
})
}

// UnmarshalJSON implements encoding/json.Unmarshaler.
func (g *GSSResponse) UnmarshalJSON(data []byte) error {
var msg struct {
Data []byte
}
if err := json.Unmarshal(data, &msg); err != nil {
return err
}
g.Data = msg.Data
return nil
}
39 changes: 39 additions & 0 deletions json_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,32 @@ func TestJSONUnmarshalAuthenticationSASL(t *testing.T) {
}
}

func TestJSONUnmarshalAuthenticationGSS(t *testing.T) {
data := []byte(`{"Type":"AuthenticationGSS"}`)
want := AuthenticationGSS{}

var got AuthenticationGSS
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled AuthenticationGSS struct doesn't match expected value")
}
}

func TestJSONUnmarshalAuthenticationGSSContinue(t *testing.T) {
data := []byte(`{"Type":"AuthenticationGSSContinue","Data":[1,2,3,4]}`)
want := AuthenticationGSSContinue{Data: []byte{1, 2, 3, 4}}

var got AuthenticationGSSContinue
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled AuthenticationGSSContinue struct doesn't match expected value")
}
}

func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) {
data := []byte(`{"Type":"AuthenticationSASLContinue", "Data":"1"}`)
want := AuthenticationSASLContinue{
Expand Down Expand Up @@ -551,6 +577,19 @@ func TestAuthenticationMD5Password(t *testing.T) {
}
}

func TestJSONUnmarshalGSSResponse(t *testing.T) {
data := []byte(`{"Type":"GSSResponse","Data":[10,20,30,40]}`)
want := GSSResponse{Data: []byte{10, 20, 30, 40}}

var got GSSResponse
if err := json.Unmarshal(data, &got); err != nil {
t.Errorf("cannot JSON unmarshal %v", err)
}
if !reflect.DeepEqual(got, want) {
t.Error("unmarshaled GSSResponse struct doesn't match expected value")
}
}

func TestErrorResponse(t *testing.T) {
data := []byte(`{"Type":"ErrorResponse","UnknownFields":{"112":"foo"},"Code": "Fail","Position":1,"Message":"this is an error"}`)
want := ErrorResponse{
Expand Down

0 comments on commit 175856f

Please sign in to comment.