From a095d4269f6c944fb5aaed0e4686e81c5ff303ba Mon Sep 17 00:00:00 2001 From: mellonnen Date: Thu, 23 Feb 2023 11:41:28 +0100 Subject: [PATCH] feat: rework relay validation --- cmd/portal/main.go | 69 ------------------------------------------ cmd/portal/receive.go | 5 ++- cmd/portal/send.go | 5 +-- cmd/portal/validate.go | 43 ++++++++++++++++++++++++++ 4 files changed, 48 insertions(+), 74 deletions(-) create mode 100644 cmd/portal/validate.go diff --git a/cmd/portal/main.go b/cmd/portal/main.go index 3505440..9ae9bd2 100644 --- a/cmd/portal/main.go +++ b/cmd/portal/main.go @@ -1,14 +1,11 @@ package main import ( - "errors" "fmt" "io" "log" - "net" "os" "path/filepath" - "unicode/utf8" tea "github.com/charmbracelet/bubbletea" homedir "github.com/mitchellh/go-homedir" @@ -106,17 +103,6 @@ func initViperConfig() { } } -// validateRendezvousAddressInViper validates that the `rendezvousAddress` value in viper is a valid hostname or IP -func validateRendezvousAddressInViper() error { - rendezvouzAdress := net.ParseIP(viper.GetString("rendezvousAddress")) - err := validateHostname(viper.GetString("rendezvousAddress")) - // neither a valid IP nor a valid hostname was provided - if (rendezvouzAdress == nil) && err != nil { - return errors.New("invalid IP or hostname provided") - } - return nil -} - func setupLoggingFromViper(cmd string) (*os.File, error) { if viper.GetBool("verbose") { f, err := tea.LogToFile(fmt.Sprintf(".portal-%s.log", cmd), fmt.Sprintf("portal-%s: \n", cmd)) @@ -128,58 +114,3 @@ func setupLoggingFromViper(cmd string) (*os.File, error) { log.SetOutput(io.Discard) return nil, nil } - -// validateHostname returns an error if the domain name is not valid -// See https://tools.ietf.org/html/rfc1034#section-3.5 and -// https://tools.ietf.org/html/rfc1123#section-2. -// source: https://gist.github.com/chmike/d4126a3247a6d9a70922fc0e8b4f4013 -func validateHostname(name string) error { - switch { - case len(name) == 0: - return nil - case len(name) > 255: - return fmt.Errorf("name length is %d, can't exceed 255", len(name)) - } - var l int - for i := 0; i < len(name); i++ { - b := name[i] - if b == '.' { - // check domain labels validity - switch { - case i == l: - return fmt.Errorf("invalid character '%c' at offset %d: label can't begin with a period", b, i) - case i-l > 63: - return fmt.Errorf("byte length of label '%s' is %d, can't exceed 63", name[l:i], i-l) - case name[l] == '-': - return fmt.Errorf("label '%s' at offset %d begins with a hyphen", name[l:i], l) - case name[i-1] == '-': - return fmt.Errorf("label '%s' at offset %d ends with a hyphen", name[l:i], l) - } - l = i + 1 - continue - } - // test label character validity, note: tests are ordered by decreasing validity frequency - if !(b >= 'a' && b <= 'z' || b >= '0' && b <= '9' || b == '-' || b >= 'A' && b <= 'Z') { - // show the printable unicode character starting at byte offset i - c, _ := utf8.DecodeRuneInString(name[i:]) - if c == utf8.RuneError { - return fmt.Errorf("invalid rune at offset %d", i) - } - return fmt.Errorf("invalid character '%c' at offset %d", c, i) - } - } - // check top level domain validity - switch { - case l == len(name): - return fmt.Errorf("missing top level domain, domain can't end with a period") - case len(name)-l > 63: - return fmt.Errorf("byte length of top level domain '%s' is %d, can't exceed 63", name[l:], len(name)-l) - case name[l] == '-': - return fmt.Errorf("top level domain '%s' at offset %d begins with a hyphen", name[l:], l) - case name[len(name)-1] == '-': - return fmt.Errorf("top level domain '%s' at offset %d ends with a hyphen", name[l:], l) - case name[l] >= '0' && name[l] <= '9': - return fmt.Errorf("top level domain '%s' at offset %d begins with a digit", name[l:], l) - } - return nil -} diff --git a/cmd/portal/receive.go b/cmd/portal/receive.go index 5694b73..36d247d 100644 --- a/cmd/portal/receive.go +++ b/cmd/portal/receive.go @@ -32,7 +32,7 @@ var receiveCmd = &cobra.Command{ Args: cobra.ExactArgs(1), ValidArgsFunction: passwordCompletion, PreRunE: func(cmd *cobra.Command, args []string) error { - // Bind flags to viper + // BindvalidateRelayInViper if err := viper.BindPFlag("relay", cmd.Flags().Lookup("relay")); err != nil { return fmt.Errorf("binding relay flag: %w", err) } @@ -40,8 +40,7 @@ var receiveCmd = &cobra.Command{ }, RunE: func(cmd *cobra.Command, args []string) error { file.RemoveTemporaryFiles(file.RECEIVE_TEMP_FILE_NAME_PREFIX) - err := validateRendezvousAddressInViper() - if err != nil { + if err := validateRelayInViper(); err != nil { return err } logFile, err := setupLoggingFromViper("receive") diff --git a/cmd/portal/send.go b/cmd/portal/send.go index 83d1bbb..011c39e 100644 --- a/cmd/portal/send.go +++ b/cmd/portal/send.go @@ -15,7 +15,7 @@ import ( // Set flags. func init() { // Add subcommand flags (dummy default values as default values are handled through viper) - sendCmd.Flags().StringP("relay", "r", "", "address of the relay server") + sendCmd.Flags().StringP("relay", "r", "", "Address of the relay server ()") } // ------------------------------------------------------ Command ------------------------------------------------------ @@ -39,7 +39,8 @@ var sendCmd = &cobra.Command{ return err } file.RemoveTemporaryFiles(file.SEND_TEMP_FILE_NAME_PREFIX) - if err := validateRendezvousAddressInViper(); err != nil { + + if err := validateRelayInViper(); err != nil { return err } diff --git a/cmd/portal/validate.go b/cmd/portal/validate.go new file mode 100644 index 0000000..3b34635 --- /dev/null +++ b/cmd/portal/validate.go @@ -0,0 +1,43 @@ +package main + +import ( + "errors" + "net" + "regexp" + "strings" + + "github.com/spf13/viper" + "golang.org/x/net/idna" +) + +var ErrInvalidRelay = errors.New("invalid relay provided") + +var ipv6Rex = regexp.MustCompile(`\[(.*?)\]`) + +func stripPort(addr string) string { + split := strings.Split(addr, ":") + if len(split) == 2 { + return split[0] + } + + matches := ipv6Rex.FindStringSubmatch(addr) + if len(matches) >= 2 { + return matches[1] + } + return addr +} + +// validateRelayInViper validates that the `rendezvousAddress` value in viper is a valid hostname or IP +func validateRelayInViper() error { + relayAddr := viper.GetString("relay") + + if ip := net.ParseIP(stripPort(relayAddr)); ip != nil { + return nil + } + + if _, err := idna.Lookup.ToASCII(relayAddr); err == nil { + return nil + } + + return ErrInvalidRelay +}