Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Custom header support during reload: closes #60 #64

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 96 additions & 87 deletions apisprout.go
Original file line number Diff line number Diff line change
Expand Up @@ -597,6 +597,48 @@ var handler = func(rr *RefreshableRouter) http.Handler {
})
}

//
func loadSwaggerFromUri(uri string) (data []byte, err error) {
if strings.HasPrefix(uri, "http") {
req, httpErr := http.NewRequest("GET", uri, nil)
if httpErr != nil {
err = httpErr
return
}
if customHeader := viper.GetString("header"); customHeader != "" {
header := strings.Split(customHeader, ":")
if len(header) != 2 {
err = errors.New("Header format is invalid")
} else {
req.Header.Add(strings.TrimSpace(header[0]), strings.TrimSpace(header[1]))
}
}
if err != nil {
return
}

client := &http.Client{}
resp, httpErr := client.Do(req)
if httpErr != nil {
err = httpErr
return
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("Server at %s reported %d status code", uri, resp.StatusCode)
return
}
data, err = ioutil.ReadAll(resp.Body)
if err != nil {
return
}
} else {
data, err = ioutil.ReadFile(uri)
}

return data, err
}

// server loads an OpenAPI file and runs a mock server using the paths and
// examples defined in the file.
func server(cmd *cobra.Command, args []string) {
Expand All @@ -611,83 +653,58 @@ func server(cmd *cobra.Command, args []string) {

// Load either from an HTTP URL or from a local file depending on the passed
// in value.
if strings.HasPrefix(uri, "http") {
req, err := http.NewRequest("GET", uri, nil)
if err != nil {
log.Fatal(err)
}
if customHeader := viper.GetString("header"); customHeader != "" {
header := strings.Split(customHeader, ":")
if len(header) != 2 {
log.Fatal("Header format is invalid.")
}
req.Header.Add(strings.TrimSpace(header[0]), strings.TrimSpace(header[1]))
}
client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
log.Fatal(err)
}
data, err = loadSwaggerFromUri(uri)
if err != nil {
log.Fatal(err)
}

data, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
log.Fatal(err)
if viper.GetBool("watch") {
if strings.HasPrefix(uri, "http") {
log.Fatal(errors.New("Watching a URL is not supported."))
}

if viper.GetBool("watch") {
log.Fatal("Watching a URL is not supported.")
}
} else {
data, err = ioutil.ReadFile(uri)
// Set up a new filesystem watcher and reload the router every time
// the file has changed on disk.
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Fatal(err)
}

if viper.GetBool("watch") {
// Set up a new filesystem watcher and reload the router every time
// the file has changed on disk.
watcher, err := fsnotify.NewWatcher()
if err != nil {
log.Fatal(err)
}
defer watcher.Close()

go func() {
// Since waiting for events or errors is blocking, we do this in a
// goroutine. It loops forever here but will exit when the process
// is finished, e.g. when you `ctrl+c` to exit.
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
defer watcher.Close()

go func() {
// Since waiting for events or errors is blocking, we do this in a
// goroutine. It loops forever here but will exit when the process
// is finished, e.g. when you `ctrl+c` to exit.
for {
select {
case event, ok := <-watcher.Events:
if !ok {
return
}
if event.Op&fsnotify.Write == fsnotify.Write {
fmt.Printf("🌙 Reloading %s\n", uri)
data, err = loadSwaggerFromUri(uri)
if err != nil {
log.Printf("ERROR: %s", err)
}
if event.Op&fsnotify.Write == fsnotify.Write {
fmt.Printf("🌙 Reloading %s\n", uri)
data, err = ioutil.ReadFile(uri)
if err != nil {
log.Fatal(err)
}

if s, r, err := load(uri, data); err == nil {
swagger = s
rr.Set(r)
} else {
log.Printf("ERROR: Unable to load OpenAPI document: %s", err)
}
}
case err, ok := <-watcher.Errors:
if !ok {
return
if s, r, err := load(uri, data); err == nil {
swagger = s
rr.Set(r)
} else {
log.Printf("ERROR: Unable to load OpenAPI document: %s", err)
}
fmt.Println("error:", err)
}
case err, ok := <-watcher.Errors:
if !ok {
return
}
fmt.Println("error:", err)
}
}()
}
}()

watcher.Add(uri)
}
watcher.Add(uri)
}

swagger, router, err := load(uri, data)
Expand All @@ -699,31 +716,23 @@ func server(cmd *cobra.Command, args []string) {

if strings.HasPrefix(uri, "http") {
http.HandleFunc("/__reload", func(w http.ResponseWriter, r *http.Request) {
resp, err := http.Get(uri)
if err != nil {
log.Printf("ERROR: %v", err)
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("error while reloading"))
return
log.Printf("🌙 Reloading %s\n", uri)
data, err = loadSwaggerFromUri(uri)
if err == nil {
if s, r, err := load(uri, data); err == nil {
swagger = s
rr.Set(r)
}
}

data, err = ioutil.ReadAll(resp.Body)
resp.Body.Close()
if err != nil {
log.Printf("ERROR: %v", err)
if err == nil {
log.Printf("Reloaded from %s", uri)
w.WriteHeader(200)
w.Write([]byte("reloaded"))
} else {
log.Printf("ERROR: %s", err)
w.WriteHeader(http.StatusBadRequest)
w.Write([]byte("error while parsing"))
return
}

if s, r, err := load(uri, data); err == nil {
swagger = s
rr.Set(r)
w.Write([]byte("error while reloading"))
}

w.WriteHeader(200)
w.Write([]byte("reloaded"))
log.Printf("Reloaded from %s", uri)
})
}

Expand Down