Skip to content

Commit

Permalink
Declare ExpiresOn as *time.Time in SessionState
Browse files Browse the repository at this point in the history
This eliminates time.Time zero value "0001-01-01T00:00:00Z" embedded in
encoded session state strings.

timePtr/timeVal: helper functions to convert time.Time values and
pointers each other.
  • Loading branch information
yaegashi committed Feb 22, 2019
1 parent 119420d commit ef8054f
Show file tree
Hide file tree
Showing 5 changed files with 55 additions and 37 deletions.
6 changes: 3 additions & 3 deletions providers/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err
s = &SessionState{
AccessToken: jsonResponse.AccessToken,
IDToken: jsonResponse.IDToken,
ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second),
ExpiresOn: timePtr(time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second)),
RefreshToken: jsonResponse.RefreshToken,
Email: email,
}
Expand Down Expand Up @@ -256,7 +256,7 @@ func (p *GoogleProvider) ValidateGroup(email string) bool {
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
if s == nil || timeVal(s.ExpiresOn).After(time.Now()) || s.RefreshToken == "" {
return false, nil
}

Expand All @@ -272,7 +272,7 @@ func (p *GoogleProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {

origExpiration := s.ExpiresOn
s.AccessToken = newToken
s.ExpiresOn = time.Now().Add(duration).Truncate(time.Second)
s.ExpiresOn = timePtr(time.Now().Add(duration).Truncate(time.Second))
log.Printf("refreshed access token %s (expired on %s)", s, origExpiration)
return true, nil
}
Expand Down
4 changes: 2 additions & 2 deletions providers/oidc.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er
// RefreshSessionIfNeeded checks if the session has expired and uses the
// RefreshToken to fetch a new ID token if required
func (p *OIDCProvider) RefreshSessionIfNeeded(s *SessionState) (bool, error) {
if s == nil || s.ExpiresOn.After(time.Now()) || s.RefreshToken == "" {
if s == nil || timeVal(s.ExpiresOn).After(time.Now()) || s.RefreshToken == "" {
return false, nil
}

Expand Down Expand Up @@ -124,7 +124,7 @@ func (p *OIDCProvider) createSessionState(ctx context.Context, token *oauth2.Tok
AccessToken: token.AccessToken,
IDToken: rawIDToken,
RefreshToken: token.RefreshToken,
ExpiresOn: token.Expiry,
ExpiresOn: &token.Expiry,
Email: claims.Email,
}, nil
}
Expand Down
2 changes: 1 addition & 1 deletion providers/provider_default_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
func TestRefresh(t *testing.T) {
p := &ProviderData{}
refreshed, err := p.RefreshSessionIfNeeded(&SessionState{
ExpiresOn: time.Now().Add(time.Duration(-11) * time.Minute),
ExpiresOn: timePtr(time.Now().Add(time.Duration(-11) * time.Minute)),
})
assert.Equal(t, false, refreshed)
assert.Equal(t, nil, err)
Expand Down
34 changes: 24 additions & 10 deletions providers/session_state.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,31 @@ import (

// SessionState is used to store information about the currently authenticated user session
type SessionState struct {
AccessToken string `json:",omitempty"`
IDToken string `json:",omitempty"`
ExpiresOn time.Time `json:",omitempty"`
RefreshToken string `json:",omitempty"`
Email string `json:",omitempty"`
User string `json:",omitempty"`
AccessToken string `json:",omitempty"`
IDToken string `json:",omitempty"`
ExpiresOn *time.Time `json:",omitempty"`
RefreshToken string `json:",omitempty"`
Email string `json:",omitempty"`
User string `json:",omitempty"`
}

// timePtr is a helper to convert time.Time value to pointer.
func timePtr(t time.Time) *time.Time {
return &t
}

// timeVal is a helper to convert time.Time pointer to value.
// It regards nil as a zero value.
func timeVal(t *time.Time) time.Time {
if t == nil {
return time.Time{}
}
return *t
}

// IsExpired checks whether the session has expired
func (s *SessionState) IsExpired() bool {
if !s.ExpiresOn.IsZero() && s.ExpiresOn.Before(time.Now()) {
if !timeVal(s.ExpiresOn).IsZero() && s.ExpiresOn.Before(time.Now()) {
return true
}
return false
Expand All @@ -37,7 +51,7 @@ func (s *SessionState) String() string {
if s.IDToken != "" {
o += " id_token:true"
}
if !s.ExpiresOn.IsZero() {
if !timeVal(s.ExpiresOn).IsZero() {
o += fmt.Sprintf(" expires:%s", s.ExpiresOn)
}
if s.RefreshToken != "" {
Expand Down Expand Up @@ -126,9 +140,9 @@ func legacyDecodeSessionState(v string, c *cookie.Cipher) (*SessionState, error)
i++
ts, err := strconv.Atoi(chunks[i])
if err != nil {
return nil, fmt.Errorf("invalid session state (legacy: wrong expiration time: %s)", err)
return nil, err
}
ss.ExpiresOn = time.Unix(int64(ts), 0)
ss.ExpiresOn = timePtr(time.Unix(int64(ts), 0))

i++
ss.RefreshToken = chunks[i]
Expand Down
46 changes: 25 additions & 21 deletions providers/session_state_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ import (
const secret = "0123456789abcdefghijklmnopqrstuv"
const altSecret = "0000000000abcdefghijklmnopqrstuv"

// timeUnix safely converts *time.Time to a Unix time.
func timeUnix(p *time.Time) int64 {
return timeVal(p).Unix()
}

func TestSessionStateSerialization(t *testing.T) {
c, err := cookie.NewCipher([]byte(secret))
assert.Equal(t, nil, err)
Expand All @@ -21,7 +26,7 @@ func TestSessionStateSerialization(t *testing.T) {
Email: "[email protected]",
AccessToken: "token1234",
IDToken: "rawtoken1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(c)
Expand All @@ -34,7 +39,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.IDToken, ss.IDToken)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, timeUnix(s.ExpiresOn), timeUnix(ss.ExpiresOn))
assert.Equal(t, s.RefreshToken, ss.RefreshToken)

// ensure a different cipher can't decode properly (ie: it gets gibberish)
Expand All @@ -43,7 +48,7 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, nil, err)
assert.Equal(t, "user", ss.User)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, timeUnix(s.ExpiresOn), timeUnix(ss.ExpiresOn))
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
assert.NotEqual(t, s.IDToken, ss.IDToken)
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
Expand All @@ -58,7 +63,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
User: "just-user",
Email: "[email protected]",
AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(c)
Expand All @@ -70,7 +75,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
assert.Equal(t, s.User, ss.User)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, timeUnix(s.ExpiresOn), timeUnix(ss.ExpiresOn))
assert.Equal(t, s.RefreshToken, ss.RefreshToken)

// ensure a different cipher can't decode properly (ie: it gets gibberish)
Expand All @@ -79,7 +84,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
assert.Equal(t, nil, err)
assert.Equal(t, s.User, ss.User)
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, timeUnix(s.ExpiresOn), timeUnix(ss.ExpiresOn))
assert.NotEqual(t, s.AccessToken, ss.AccessToken)
assert.NotEqual(t, s.RefreshToken, ss.RefreshToken)
}
Expand All @@ -88,7 +93,7 @@ func TestSessionStateSerializationNoCipher(t *testing.T) {
s := &SessionState{
Email: "[email protected]",
AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(nil)
Expand All @@ -108,7 +113,7 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
User: "just-user",
Email: "[email protected]",
AccessToken: "token1234",
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Hour)),
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(nil)
Expand All @@ -124,10 +129,10 @@ func TestSessionStateSerializationNoCipherWithUser(t *testing.T) {
}

func TestExpired(t *testing.T) {
s := &SessionState{ExpiresOn: time.Now().Add(time.Duration(-1) * time.Minute)}
s := &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(-1) * time.Minute))}
assert.Equal(t, true, s.IsExpired())

s = &SessionState{ExpiresOn: time.Now().Add(time.Duration(1) * time.Minute)}
s = &SessionState{ExpiresOn: timePtr(time.Now().Add(time.Duration(1) * time.Minute))}
assert.Equal(t, false, s.IsExpired())

s = &SessionState{}
Expand All @@ -145,18 +150,16 @@ type testCase struct {
//
// - Currently only tests without cipher here because we have no way to mock
// the random generator used in EncodeSessionState.
// - The zero value of time.Time is encoded to "0001-01-01T00:00:00Z"
// (`json:",omitempty"` is not effective for time.Time).
func TestEncodeSessionState(t *testing.T) {
e := time.Now().Add(time.Duration(1) * time.Hour)
e := timePtr(time.Now().Add(time.Duration(1) * time.Hour))

testCases := []testCase{
{
SessionState: SessionState{
Email: "[email protected]",
User: "just-user",
},
Encoded: `{"Email":"[email protected]","User":"just-user","ExpiresOn":"0001-01-01T00:00:00Z"}`,
Encoded: `{"Email":"[email protected]","User":"just-user"}`,
},
{
SessionState: SessionState{
Expand All @@ -167,7 +170,7 @@ func TestEncodeSessionState(t *testing.T) {
ExpiresOn: e,
RefreshToken: "refresh4321",
},
Encoded: `{"Email":"[email protected]","User":"just-user","ExpiresOn":"0001-01-01T00:00:00Z"}`,
Encoded: `{"Email":"[email protected]","User":"just-user"}`,
},
}

Expand All @@ -186,7 +189,7 @@ func TestEncodeSessionState(t *testing.T) {

// TestDecodeSessionState tests DecodeSessionState with the test vector
func TestDecodeSessionState(t *testing.T) {
e := time.Now().Add(time.Duration(1) * time.Hour)
e := timePtr(time.Now().Add(time.Duration(1) * time.Hour))
eJSON, _ := e.MarshalJSON()
eString := string(eJSON)
eUnix := e.Unix()
Expand All @@ -200,7 +203,7 @@ func TestDecodeSessionState(t *testing.T) {
Email: "[email protected]",
User: "just-user",
},
Encoded: `{"Email":"[email protected]","User":"just-user","ExpiresOn":"0001-01-01T00:00:00Z"}`,
Encoded: `{"Email":"[email protected]","User":"just-user"}`,
},
{
SessionState: SessionState{
Expand Down Expand Up @@ -239,16 +242,16 @@ func TestDecodeSessionState(t *testing.T) {
Email: "[email protected]",
User: "just-user",
},
Encoded: `{"Email":"[email protected]","User":"just-user","ExpiresOn":"0001-01-01T00:00:00Z"}`,
Encoded: `{"Email":"[email protected]","User":"just-user"}`,
Cipher: c,
},
{
Encoded: `{"Email":"[email protected]","User":"just-user","ExpiresOn":"0001-01-01T00:00:00Z","AccessToken":"X"}`,
Encoded: `{"Email":"[email protected]","User":"just-user","AccessToken":"X"}`,
Cipher: c,
Error: true,
},
{
Encoded: `{"Email":"[email protected]","User":"just-user","ExpiresOn":"0001-01-01T00:00:00Z","IDToken":"XXXX"}`,
Encoded: `{"Email":"[email protected]","User":"just-user","IDToken":"XXXX"}`,
Cipher: c,
Error: true,
},
Expand All @@ -273,6 +276,7 @@ func TestDecodeSessionState(t *testing.T) {
Cipher: c,
Error: true,
},

{
SessionState: SessionState{
Email: "[email protected]",
Expand Down Expand Up @@ -313,7 +317,7 @@ func TestDecodeSessionState(t *testing.T) {
assert.Equal(t, tc.AccessToken, ss.AccessToken)
assert.Equal(t, tc.RefreshToken, ss.RefreshToken)
assert.Equal(t, tc.IDToken, ss.IDToken)
assert.Equal(t, tc.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, timeUnix(tc.ExpiresOn), timeUnix(ss.ExpiresOn))
}
}
}

0 comments on commit ef8054f

Please sign in to comment.