Skip to content

Commit

Permalink
Adds protection against cross-site WebSocket hijacking using CSRF tok…
Browse files Browse the repository at this point in the history
…ens (#6)

To protect against CSWSH attacks we need a way to identify cross-site requests and prevent them from connecting to the server. The easiest way to do this is by ensuring that the request Host and Origin are the same but unfortunately, kubectl proxy modifies the request Host so we can't use this method. Another easy method is to check the Sec-Fetch-Site header but unfortunately it isn't implemented in some popular browsers (see #3) so we can't use this method either. Instead, this PR uses the old-school method of CSRF token validation to identify cross-site requests and block them. After a WebSocket connection is made, the client is required to authenticate using the CSRF token value. If the token fails validation the connection is closed, otherwise it is allowed to continue.

This PR also moves the GraphiQL playground interface to a static page accessible at /graphiql.
  • Loading branch information
amorey committed Feb 15, 2024
1 parent d095230 commit 1fc41c1
Show file tree
Hide file tree
Showing 11 changed files with 216 additions and 54 deletions.
2 changes: 0 additions & 2 deletions backend/go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,9 @@ github.com/99designs/gqlgen v0.17.43 h1:I4SYg6ahjowErAQcHFVKy5EcWuwJ3+Xw9z2fLpuF
github.com/99designs/gqlgen v0.17.43/go.mod h1:lO0Zjy8MkZgBdv4T1U91x09r0e0WFOdhVUutlQs1Rsc=
github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU=
github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym/WlBOVXweHU+Q+/VP0lqqI8lqeDx9IjBqo=
github.com/PuerkitoBio/goquery v1.8.1 h1:uQxhNlArOIdbrH1tr0UXwdVFgDcZDrZVdcpygAcwmWM=
github.com/agnivade/levenshtein v1.1.1 h1:QY8M92nrzkmr798gCo3kmMyqXFzdQVpxLlGPRBij0P8=
github.com/agnivade/levenshtein v1.1.1/go.mod h1:veldBMzWxcCG2ZvUTKD2kJNRdCk5hVbJomOvKkmgYbo=
github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ=
github.com/andybalholm/cascadia v1.3.1 h1:nhxRkql1kdYCc8Snf7D5/D3spOX+dBgjA6u8x004T2c=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0 h1:jfIu9sQUG6Ig+0+Ap1h4unLjW6YQJpKZVmUzxsD4E/Q=
github.com/arbovm/levenshtein v0.0.0-20160628152529-48b4e1c0c4d0/go.mod h1:t2tdKJDJF9BV14lnkjHmOQgcvEKgtqs5a1N3LNdJhGE=
github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM=
Expand Down
6 changes: 5 additions & 1 deletion backend/graph/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,11 @@ func NewHandler(r *Resolver, options *HandlerOptions) *handler.Server {
h.AddTransport(&transport.Websocket{
Upgrader: websocket.Upgrader{
CheckOrigin: func(r *http.Request) bool {
return (r.Header.Get("Sec-Fetch-Site") == "same-origin" || r.Header.Get("Origin") == "")
// We have to return true here because `kubectl proxy` modifies the Host header
// so requests will fail same-origin tests and unfortunately not all browsers
// have implemented `sec-fetch-site` header. Instead, we will use CSRF token
// validation to ensure requests are coming from the same site.
return true
},
ReadBufferSize: 1024,
WriteBufferSize: 1024,
Expand Down
18 changes: 10 additions & 8 deletions backend/internal/ginapp/ginapp.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import (
"os"
"path"

"github.com/99designs/gqlgen/graphql/playground"
"github.com/gin-contrib/gzip"
"github.com/gin-contrib/requestid"
"github.com/gin-contrib/secure"
Expand Down Expand Up @@ -115,9 +114,11 @@ func NewGinApp(config Config) (*GinApp, error) {
c.Next()
})

var csrfProtect func(http.Handler) http.Handler

// csrf middleware
if config.CSRF.Enabled {
dynamicRoutes.Use(adapter.Wrap(csrf.Protect(
csrfProtect = csrf.Protect(
[]byte(config.CSRF.Secret),
csrf.FieldName(config.CSRF.FieldName),
csrf.CookieName(config.CSRF.Cookie.Name),
Expand All @@ -127,8 +128,12 @@ func NewGinApp(config Config) (*GinApp, error) {
csrf.Secure(config.CSRF.Cookie.Secure),
csrf.HttpOnly(config.CSRF.Cookie.HttpOnly),
csrf.SameSite(config.CSRF.Cookie.SameSite),
)))
)

// add to gin middleware
dynamicRoutes.Use(adapter.Wrap(csrfProtect))

// token fetcher helper
dynamicRoutes.GET("/csrf-token", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"value": csrf.Token(c.Request)})
})
Expand Down Expand Up @@ -156,17 +161,13 @@ func NewGinApp(config Config) (*GinApp, error) {

// graphql handler
h := &GraphQLHandlers{app}
endpointHandler := h.EndpointHandler(k8sCfg, config.Namespace)
endpointHandler := h.EndpointHandler(k8sCfg, config.Namespace, csrfProtect)
graphql.GET("", endpointHandler)
graphql.POST("", endpointHandler)
}
}
app.dynamicroutes = dynamicRoutes // for unit tests

// graphiql
h := playground.Handler("GraphQL Playground", "/graphql")
app.GET("/graphiql", gin.WrapH(h))

// healthz
app.GET("/healthz", func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
Expand All @@ -183,6 +184,7 @@ func NewGinApp(config Config) (*GinApp, error) {

app.StaticFile("/", path.Join(websiteDir, "/index.html"))
app.StaticFile("/favicon.ico", path.Join(websiteDir, "/favicon.ico"))
app.StaticFile("/graphiql", path.Join(websiteDir, "/graphiql.html"))
app.Static("/assets", path.Join(websiteDir, "/assets"))

// use react app for unknown routes
Expand Down
50 changes: 44 additions & 6 deletions backend/internal/ginapp/graphql.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ package ginapp

import (
"context"
"errors"
"net/http"
"net/http/httptest"

"github.com/99designs/gqlgen/graphql/handler/transport"
"github.com/gin-gonic/gin"
Expand All @@ -24,30 +27,60 @@ import (
"github.com/kubetail-org/kubetail/graph"
)

type key int

const graphQLCookiesCtxKey key = iota

type GraphQLHandlers struct {
*GinApp
}

// GET|POST "/graphql": GraphQL query endpoint
func (app *GraphQLHandlers) EndpointHandler(cfg *rest.Config, namespace string) gin.HandlerFunc {
func (app *GraphQLHandlers) EndpointHandler(cfg *rest.Config, namespace string, csrfProtect func(http.Handler) http.Handler) gin.HandlerFunc {
// init resolver
r, err := graph.NewResolver(cfg, namespace)
if err != nil {
panic(err)
}

csrfTestServer := http.NewServeMux()
csrfTestServer.HandleFunc("/", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))

// init handler options
opts := graph.NewDefaultHandlerOptions()

// Because we had to disable same-origin checks in the CheckOrigin() handler
// we will use use CSRF token validation to ensure requests are coming from
// the same site. (See https://dev.to/pssingh21/websockets-bypassing-sop-cors-5ajm)
opts.WSInitFunc = func(ctx context.Context, initPayload transport.InitPayload) (context.Context, *transport.InitPayload, error) {
token := initPayload.Authorization()
// check if csrf protection is disabled
if csrfProtect == nil {
return ctx, &initPayload, nil
}

csrfToken := initPayload.Authorization()

// add token to context
if token != "" {
ctx = context.WithValue(ctx, graph.K8STokenCtxKey, token)
cookies, ok := ctx.Value(graphQLCookiesCtxKey).([]*http.Cookie)
if !ok {
return ctx, nil, errors.New("AUTHORIZATION_REQUIRED")
}

// make mock request
r, _ := http.NewRequest("POST", "/", nil)
for _, cookie := range cookies {
r.AddCookie(cookie)
}
r.Header.Set("X-CSRF-Token", csrfToken)

// run request through csrf protect function
rr := httptest.NewRecorder()
p := csrfProtect(csrfTestServer)
p.ServeHTTP(rr, r)

if rr.Code != 200 {
return ctx, nil, errors.New("AUTHORIZATION_REQUIRED")
}

// return
return ctx, &initPayload, nil
}

Expand All @@ -56,6 +89,11 @@ func (app *GraphQLHandlers) EndpointHandler(cfg *rest.Config, namespace string)

// return gin handler func
return func(c *gin.Context) {
// save cookies for use in WSInitFunc
ctx := context.WithValue(c.Request.Context(), graphQLCookiesCtxKey, c.Request.Cookies())
c.Request = c.Request.WithContext(ctx)

// execute
h.ServeHTTP(c.Writer, c.Request)
}
}
47 changes: 42 additions & 5 deletions backend/internal/ginapp/graphql_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,15 +98,52 @@ func (suite *GraphQLTestSuite) TestAccess() {
suite.Equal(http.StatusNotFound, resp.StatusCode)
})

suite.Run("cross-origin subscriptions aren't allowed", func() {
suite.Run("cross-origin websocket requests are allowed when csrf protection is disabled", func() {
// init websocket connection
u := "ws" + strings.TrimPrefix(suite.defaultclient.testserver.URL, "http") + "/graphql"
h := http.Header{}
h.Add("Origin", "not-the-host.com")
_, _, err := websocket.DefaultDialer.Dial(u, h)
conn, resp, err := websocket.DefaultDialer.Dial(u, h)

// check response
suite.NotNil(err)
// check that response was ok
suite.Nil(err)
suite.NotNil(conn)
suite.Equal(101, resp.StatusCode)
defer conn.Close()

// write
conn.WriteJSON(map[string]string{"type": "connection_init"})

// read
_, msg, err := conn.ReadMessage()
suite.Nil(err)
suite.Contains(string(msg), "connection_ack")
})

suite.Run("websocket requests require csrf validation when csrf protection is enabled", func() {
// init client
cfg := NewTestConfig()
cfg.CSRF.Enabled = true
client := NewWebTestClient(suite.T(), NewTestApp(cfg))
defer client.Teardown()

// init websocket connection
u := "ws" + strings.TrimPrefix(client.testserver.URL, "http") + "/graphql"
h := http.Header{}
conn, resp, err := websocket.DefaultDialer.Dial(u, h)

// check that response was ok
suite.Nil(err)
suite.NotNil(conn)
suite.Equal(101, resp.StatusCode)
defer conn.Close()

// write
conn.WriteJSON(map[string]string{"type": "connection_init"})

// read
_, msg, err := conn.ReadMessage()
suite.Nil(err)
suite.Contains(string(msg), "connection_error")
})
})
}
Expand Down
4 changes: 2 additions & 2 deletions frontend/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
"date-fns": "^3.3.1",
"distinct-colors": "^3.0.0",
"graphql": "^16.8.1",
"graphql-ws": "^5.14.3",
"graphql-ws": "^5.15.0",
"kubetail-ui": "github:kubetail-org/kubetail-ui#v0.1.0",
"lucide-react": "^0.303.0",
"react": "^18.2.0",
Expand Down Expand Up @@ -63,7 +63,7 @@
"rollup-plugin-visualizer": "^5.12.0",
"tailwindcss": "^3.4.1",
"typescript": "^5.3.3",
"vite": "^5.1.0",
"vite": "^5.1.2",
"vite-plugin-svgr": "^4.2.0",
"vite-tsconfig-paths": "^4.3.1",
"vitest": "^1.2.2"
Expand Down
Loading

0 comments on commit 1fc41c1

Please sign in to comment.