parse cli long args

master
Jordan Orelli 10 years ago
parent 36b292c9a0
commit ed40455f5d

@ -0,0 +1,91 @@
package moon
import (
"fmt"
"log"
"os"
"reflect"
"strings"
)
func parseArgs(args []string, dest interface{}) (map[string]interface{}, error) {
reqs, err := requirements(dest)
if err != nil {
return nil, fmt.Errorf("unable to parse args: bad requirements: %s", err)
}
out := make(map[string]interface{})
shorts := make(map[string]req, len(reqs))
longs := make(map[string]req, len(reqs))
for _, req := range reqs {
if req.short != "" {
shorts[req.short] = req
}
if req.long != "" {
longs[req.long] = req
}
}
for i := 1; i < len(args); i++ {
arg := args[i]
if arg == "help" {
showHelp(dest)
}
if arg == "--" {
break
}
if strings.HasPrefix(arg, "--") {
arg = strings.TrimPrefix(arg, "--")
var (
key string
val string
)
if strings.ContainsRune(arg, '=') {
parts := strings.SplitN(arg, "=", 2)
key, val = parts[0], parts[1]
} else {
key = arg
i++
if i >= len(args) {
return nil, fmt.Errorf("terminal arg %s is missing a value", key)
}
val = args[i]
}
req, ok := longs[key]
if !ok {
// ignore unknown options silently?
log.Printf("no such long opt: %s", key)
continue
}
if req.t.Kind() == reflect.Bool {
out[key] = true
continue
}
d, err := ReadString(fmt.Sprintf("%s: %s", key, val)) // :(
if err != nil {
return nil, fmt.Errorf("unable to parse cli argument %s: %s", key, err)
}
out[key] = d.items[key]
} else if strings.HasPrefix(arg, "-") {
panic("i'm not doing short args yet")
} else {
break
}
}
return out, nil
}
func showHelp(dest interface{}) {
reqs, err := requirements(dest)
if err != nil {
panic(err)
}
for _, req := range reqs {
req.writeHelpLine(os.Stdout)
}
os.Exit(1)
}

@ -0,0 +1,28 @@
package moon
import (
"testing"
)
func TestArgs(t *testing.T) {
var one struct {
Host string `name: host; short: h; default: localhost`
Port int `name: port; short: p; required: true`
UseSSL bool `name: ssl_enabled; long: ssl-enabled`
CertPath string `name: ssl_cert; long: ssl-cert`
}
args := []string{"program", "--host=example.com", "--port", "9000"}
vals, err := parseArgs(args, &one)
if err != nil {
t.Error(err)
return
}
if vals["host"] != "example.com" {
t.Errorf("expected host 'example.com', saw host '%s'", vals["host"])
}
if vals["port"] != 9000 {
t.Errorf("expected port 9000, saw port %d", vals["port"])
}
}

@ -31,6 +31,10 @@ func bail(status int, t string, args ...interface{}) {
}
func Parse(dest interface{}) {
cliArgs, err := parseArgs(os.Args, dest)
if err != nil {
bail(1, "unable to parse cli args: %s", err)
}
f, err := os.Open(DefaultPath)
if err != nil {
bail(1, "unable to open moon config file at path %s: %s", DefaultPath, err)
@ -42,6 +46,10 @@ func Parse(dest interface{}) {
bail(1, "unable to parse moon config file at path %s: %s", DefaultPath, err)
}
for k, v := range cliArgs {
doc.items[k] = v
}
if err := doc.Fill(dest); err != nil {
bail(1, "unable to fill moon config values: %s", err)
}

@ -2,7 +2,9 @@ package moon
import (
"fmt"
"io"
"reflect"
"unicode/utf8"
)
type req struct {
@ -28,9 +30,24 @@ func (r req) validate() error {
if r.required && r.d_fault != nil {
return fmt.Errorf("invalid requirement %s: a required value cannot have a default", r.name)
}
if utf8.RuneCountInString(r.short) > 1 {
return fmt.Errorf("invalid requirement %s: provided short flag (%s) is more than 1 rune",
r.name, r.short)
}
return nil
}
func (r req) writeHelpLine(w io.Writer) {
if r.short != "" {
fmt.Fprintf(w, "-%s\t%s\n\n", r.short, r.name)
fmt.Fprintf(w, "\t%s\n\n", r.help)
} else if r.long != "" {
fmt.Fprintf(w, "--%s\t%s\n\n", r.long, r.name)
fmt.Fprintf(w, "\t%s\n\n", r.help)
}
}
func field2req(field reflect.StructField) (*req, error) {
doc, err := ReadString(string(field.Tag))
if err != nil {
@ -41,6 +58,7 @@ func field2req(field reflect.StructField) (*req, error) {
name: field.Name,
cliName: field.Name,
t: field.Type,
long: field.Name,
}
// it's really easy to cause infinite recursion here, since this is used by
// some of the higher up functions, so we're going to hack the document directly
@ -55,6 +73,10 @@ func field2req(field reflect.StructField) (*req, error) {
"long": doc.Get("long", &req.long),
}
if req.long == field.Name && req.name != field.Name {
req.long = req.name
}
for fname, err := range errors {
if err == nil {
continue

Loading…
Cancel
Save