From ed40455f5d726267451fbbc64508e1a869618bb3 Mon Sep 17 00:00:00 2001 From: Jordan Orelli Date: Sun, 3 May 2015 20:26:36 -0400 Subject: [PATCH] parse cli long args --- lib/args.go | 91 ++++++++++++++++++++++++++++++++++++++++++++++++ lib/args_test.go | 28 +++++++++++++++ lib/parse.go | 8 +++++ lib/req.go | 22 ++++++++++++ 4 files changed, 149 insertions(+) create mode 100644 lib/args.go create mode 100644 lib/args_test.go diff --git a/lib/args.go b/lib/args.go new file mode 100644 index 0000000..f8c8a79 --- /dev/null +++ b/lib/args.go @@ -0,0 +1,91 @@ +package moon + +import ( + "fmt" + "log" + "os" + "reflect" + "strings" +) + +func parseArgs(args []string, dest interface{}) (map[string]interface{}, error) { + reqs, err := requirements(dest) + if err != nil { + return nil, fmt.Errorf("unable to parse args: bad requirements: %s", err) + } + + out := make(map[string]interface{}) + shorts := make(map[string]req, len(reqs)) + longs := make(map[string]req, len(reqs)) + for _, req := range reqs { + if req.short != "" { + shorts[req.short] = req + } + if req.long != "" { + longs[req.long] = req + } + } + + for i := 1; i < len(args); i++ { + arg := args[i] + if arg == "help" { + showHelp(dest) + } + if arg == "--" { + break + } + if strings.HasPrefix(arg, "--") { + arg = strings.TrimPrefix(arg, "--") + + var ( + key string + val string + ) + if strings.ContainsRune(arg, '=') { + parts := strings.SplitN(arg, "=", 2) + key, val = parts[0], parts[1] + } else { + key = arg + i++ + if i >= len(args) { + return nil, fmt.Errorf("terminal arg %s is missing a value", key) + } + val = args[i] + } + + req, ok := longs[key] + if !ok { + // ignore unknown options silently? + log.Printf("no such long opt: %s", key) + continue + } + if req.t.Kind() == reflect.Bool { + out[key] = true + continue + } + + d, err := ReadString(fmt.Sprintf("%s: %s", key, val)) // :( + if err != nil { + return nil, fmt.Errorf("unable to parse cli argument %s: %s", key, err) + } + out[key] = d.items[key] + } else if strings.HasPrefix(arg, "-") { + panic("i'm not doing short args yet") + } else { + break + } + } + return out, nil +} + +func showHelp(dest interface{}) { + reqs, err := requirements(dest) + if err != nil { + panic(err) + } + + for _, req := range reqs { + req.writeHelpLine(os.Stdout) + } + os.Exit(1) +} diff --git a/lib/args_test.go b/lib/args_test.go new file mode 100644 index 0000000..35658a4 --- /dev/null +++ b/lib/args_test.go @@ -0,0 +1,28 @@ +package moon + +import ( + "testing" +) + +func TestArgs(t *testing.T) { + var one struct { + Host string `name: host; short: h; default: localhost` + Port int `name: port; short: p; required: true` + UseSSL bool `name: ssl_enabled; long: ssl-enabled` + CertPath string `name: ssl_cert; long: ssl-cert` + } + args := []string{"program", "--host=example.com", "--port", "9000"} + vals, err := parseArgs(args, &one) + if err != nil { + t.Error(err) + return + } + + if vals["host"] != "example.com" { + t.Errorf("expected host 'example.com', saw host '%s'", vals["host"]) + } + + if vals["port"] != 9000 { + t.Errorf("expected port 9000, saw port %d", vals["port"]) + } +} diff --git a/lib/parse.go b/lib/parse.go index d8d6123..722e5e2 100644 --- a/lib/parse.go +++ b/lib/parse.go @@ -31,6 +31,10 @@ func bail(status int, t string, args ...interface{}) { } func Parse(dest interface{}) { + cliArgs, err := parseArgs(os.Args, dest) + if err != nil { + bail(1, "unable to parse cli args: %s", err) + } f, err := os.Open(DefaultPath) if err != nil { bail(1, "unable to open moon config file at path %s: %s", DefaultPath, err) @@ -42,6 +46,10 @@ func Parse(dest interface{}) { bail(1, "unable to parse moon config file at path %s: %s", DefaultPath, err) } + for k, v := range cliArgs { + doc.items[k] = v + } + if err := doc.Fill(dest); err != nil { bail(1, "unable to fill moon config values: %s", err) } diff --git a/lib/req.go b/lib/req.go index 0556cc3..b3463b7 100644 --- a/lib/req.go +++ b/lib/req.go @@ -2,7 +2,9 @@ package moon import ( "fmt" + "io" "reflect" + "unicode/utf8" ) type req struct { @@ -28,9 +30,24 @@ func (r req) validate() error { if r.required && r.d_fault != nil { return fmt.Errorf("invalid requirement %s: a required value cannot have a default", r.name) } + + if utf8.RuneCountInString(r.short) > 1 { + return fmt.Errorf("invalid requirement %s: provided short flag (%s) is more than 1 rune", + r.name, r.short) + } return nil } +func (r req) writeHelpLine(w io.Writer) { + if r.short != "" { + fmt.Fprintf(w, "-%s\t%s\n\n", r.short, r.name) + fmt.Fprintf(w, "\t%s\n\n", r.help) + } else if r.long != "" { + fmt.Fprintf(w, "--%s\t%s\n\n", r.long, r.name) + fmt.Fprintf(w, "\t%s\n\n", r.help) + } +} + func field2req(field reflect.StructField) (*req, error) { doc, err := ReadString(string(field.Tag)) if err != nil { @@ -41,6 +58,7 @@ func field2req(field reflect.StructField) (*req, error) { name: field.Name, cliName: field.Name, t: field.Type, + long: field.Name, } // it's really easy to cause infinite recursion here, since this is used by // some of the higher up functions, so we're going to hack the document directly @@ -55,6 +73,10 @@ func field2req(field reflect.StructField) (*req, error) { "long": doc.Get("long", &req.long), } + if req.long == field.Name && req.name != field.Name { + req.long = req.name + } + for fname, err := range errors { if err == nil { continue