From 175856ffd3c8377db2e631b99ef7a7c996fdae77 Mon Sep 17 00:00:00 2001 From: Oliver Tan Date: Tue, 12 Apr 2022 14:26:13 +1000 Subject: [PATCH] add GSS authentication to pgproto3 --- authentication_gss.go | 58 +++++++++++++++++++++++++++++ authentication_gss_continue.go | 67 ++++++++++++++++++++++++++++++++++ backend.go | 11 +++--- frontend.go | 6 ++- gss_response.go | 48 ++++++++++++++++++++++++ json_test.go | 39 ++++++++++++++++++++ 6 files changed, 222 insertions(+), 7 deletions(-) create mode 100644 authentication_gss.go create mode 100644 authentication_gss_continue.go create mode 100644 gss_response.go diff --git a/authentication_gss.go b/authentication_gss.go new file mode 100644 index 0000000..5a3f3b1 --- /dev/null +++ b/authentication_gss.go @@ -0,0 +1,58 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + "github.com/jackc/pgio" +) + +type AuthenticationGSS struct{} + +func (a *AuthenticationGSS) Backend() {} + +func (a *AuthenticationGSS) AuthenticationResponse() {} + +func (a *AuthenticationGSS) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeGSS { + return errors.New("bad auth type") + } + return nil +} + +func (a *AuthenticationGSS) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, 4) + dst = pgio.AppendUint32(dst, AuthTypeGSS) + return dst +} + +func (a *AuthenticationGSS) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "AuthenticationGSS", + }) +} + +func (a *AuthenticationGSS) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + return nil +} diff --git a/authentication_gss_continue.go b/authentication_gss_continue.go new file mode 100644 index 0000000..cf8b183 --- /dev/null +++ b/authentication_gss_continue.go @@ -0,0 +1,67 @@ +package pgproto3 + +import ( + "encoding/binary" + "encoding/json" + "errors" + "github.com/jackc/pgio" +) + +type AuthenticationGSSContinue struct { + Data []byte +} + +func (a *AuthenticationGSSContinue) Backend() {} + +func (a *AuthenticationGSSContinue) AuthenticationResponse() {} + +func (a *AuthenticationGSSContinue) Decode(src []byte) error { + if len(src) < 4 { + return errors.New("authentication message too short") + } + + authType := binary.BigEndian.Uint32(src) + + if authType != AuthTypeGSSCont { + return errors.New("bad auth type") + } + + a.Data = src[4:] + return nil +} + +func (a *AuthenticationGSSContinue) Encode(dst []byte) []byte { + dst = append(dst, 'R') + dst = pgio.AppendInt32(dst, int32(len(a.Data))+8) + dst = pgio.AppendUint32(dst, AuthTypeGSSCont) + dst = append(dst, a.Data...) + return dst +} + +func (a *AuthenticationGSSContinue) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "AuthenticationGSSContinue", + Data: a.Data, + }) +} + +func (a *AuthenticationGSSContinue) UnmarshalJSON(data []byte) error { + // Ignore null, like in the main JSON package. + if string(data) == "null" { + return nil + } + + var msg struct { + Type string + Data []byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + + a.Data = msg.Data + return nil +} diff --git a/backend.go b/backend.go index 9c42ad0..a48b66f 100644 --- a/backend.go +++ b/backend.go @@ -30,11 +30,10 @@ type Backend struct { sync Sync terminate Terminate - bodyLen int - msgType byte - partialMsg bool - authType uint32 - + bodyLen int + msgType byte + partialMsg bool + authType uint32 } const ( @@ -147,6 +146,8 @@ func (b *Backend) Receive() (FrontendMessage, error) { msg = &SASLResponse{} case AuthTypeSASLFinal: msg = &SASLResponse{} + case AuthTypeGSS, AuthTypeGSSCont: + msg = &GSSResponse{} case AuthTypeCleartextPassword, AuthTypeMD5Password: fallthrough default: diff --git a/frontend.go b/frontend.go index c33dfb0..f15a3e0 100644 --- a/frontend.go +++ b/frontend.go @@ -16,6 +16,8 @@ type Frontend struct { authenticationOk AuthenticationOk authenticationCleartextPassword AuthenticationCleartextPassword authenticationMD5Password AuthenticationMD5Password + authenticationGSS AuthenticationGSS + authenticationGSSContinue AuthenticationGSSContinue authenticationSASL AuthenticationSASL authenticationSASLContinue AuthenticationSASLContinue authenticationSASLFinal AuthenticationSASLFinal @@ -178,9 +180,9 @@ func (f *Frontend) findAuthenticationMessageType(src []byte) (BackendMessage, er case AuthTypeSCMCreds: return nil, errors.New("AuthTypeSCMCreds is unimplemented") case AuthTypeGSS: - return nil, errors.New("AuthTypeGSS is unimplemented") + return &f.authenticationGSS, nil case AuthTypeGSSCont: - return nil, errors.New("AuthTypeGSSCont is unimplemented") + return &f.authenticationGSSContinue, nil case AuthTypeSSPI: return nil, errors.New("AuthTypeSSPI is unimplemented") case AuthTypeSASL: diff --git a/gss_response.go b/gss_response.go new file mode 100644 index 0000000..62da99c --- /dev/null +++ b/gss_response.go @@ -0,0 +1,48 @@ +package pgproto3 + +import ( + "encoding/json" + "github.com/jackc/pgio" +) + +type GSSResponse struct { + Data []byte +} + +// Frontend identifies this message as sendable by a PostgreSQL frontend. +func (g *GSSResponse) Frontend() {} + +func (g *GSSResponse) Decode(data []byte) error { + g.Data = data + return nil +} + +func (g *GSSResponse) Encode(dst []byte) []byte { + dst = append(dst, 'p') + dst = pgio.AppendInt32(dst, int32(4+len(g.Data))) + dst = append(dst, g.Data...) + return dst +} + +// MarshalJSON implements encoding/json.Marshaler. +func (g *GSSResponse) MarshalJSON() ([]byte, error) { + return json.Marshal(struct { + Type string + Data []byte + }{ + Type: "GSSResponse", + Data: g.Data, + }) +} + +// UnmarshalJSON implements encoding/json.Unmarshaler. +func (g *GSSResponse) UnmarshalJSON(data []byte) error { + var msg struct { + Data []byte + } + if err := json.Unmarshal(data, &msg); err != nil { + return err + } + g.Data = msg.Data + return nil +} diff --git a/json_test.go b/json_test.go index eab2625..8fad4f8 100644 --- a/json_test.go +++ b/json_test.go @@ -37,6 +37,32 @@ func TestJSONUnmarshalAuthenticationSASL(t *testing.T) { } } +func TestJSONUnmarshalAuthenticationGSS(t *testing.T) { + data := []byte(`{"Type":"AuthenticationGSS"}`) + want := AuthenticationGSS{} + + var got AuthenticationGSS + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationGSS struct doesn't match expected value") + } +} + +func TestJSONUnmarshalAuthenticationGSSContinue(t *testing.T) { + data := []byte(`{"Type":"AuthenticationGSSContinue","Data":[1,2,3,4]}`) + want := AuthenticationGSSContinue{Data: []byte{1, 2, 3, 4}} + + var got AuthenticationGSSContinue + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled AuthenticationGSSContinue struct doesn't match expected value") + } +} + func TestJSONUnmarshalAuthenticationSASLContinue(t *testing.T) { data := []byte(`{"Type":"AuthenticationSASLContinue", "Data":"1"}`) want := AuthenticationSASLContinue{ @@ -551,6 +577,19 @@ func TestAuthenticationMD5Password(t *testing.T) { } } +func TestJSONUnmarshalGSSResponse(t *testing.T) { + data := []byte(`{"Type":"GSSResponse","Data":[10,20,30,40]}`) + want := GSSResponse{Data: []byte{10, 20, 30, 40}} + + var got GSSResponse + if err := json.Unmarshal(data, &got); err != nil { + t.Errorf("cannot JSON unmarshal %v", err) + } + if !reflect.DeepEqual(got, want) { + t.Error("unmarshaled GSSResponse struct doesn't match expected value") + } +} + func TestErrorResponse(t *testing.T) { data := []byte(`{"Type":"ErrorResponse","UnknownFields":{"112":"foo"},"Code": "Fail","Position":1,"Message":"this is an error"}`) want := ErrorResponse{