diff --git a/error.go b/error.go new file mode 100644 index 0000000..74bac28 --- /dev/null +++ b/error.go @@ -0,0 +1,34 @@ +package main + +import ( + "net/http" +) + +type apiError int + +func (e apiError) Error() string { + return http.StatusText(int(e)) +} + +type errorNode struct { + err error + parent error +} + +func (e errorNode) Error() string { + return e.err.Error() +} + +func (e errorNode) Unwrap() error { + return e.parent +} + +func joinErrors(e, e2 error) error { + if e == nil { + return e2 + } + if e2 == nil { + return e + } + return errorNode{err: e, parent: e2} +} diff --git a/handler.go b/handler.go index 923b140..52d6bfa 100644 --- a/handler.go +++ b/handler.go @@ -3,7 +3,9 @@ package main import ( "context" "embed" + "errors" "fmt" + "io/fs" "net" "net/http" "os" @@ -183,6 +185,7 @@ func (h handler) locate(modpath string) ([]os.DirEntry, error) { } func writeError(w http.ResponseWriter, err error) { + // this sucks and is wrong if os.IsNotExist(err) { log_info.Printf("404 %v", err) w.WriteHeader(http.StatusNotFound) @@ -190,34 +193,39 @@ func writeError(w http.ResponseWriter, err error) { return } + var status apiError + if errors.As(err, &status) { + w.WriteHeader(int(status)) + log_error.Printf("%d %v", status, err) + fmt.Fprintf(w, err.Error()) + return + } + w.WriteHeader(http.StatusInternalServerError) fmt.Fprintf(w, "internal server error") log_error.Printf("500 %v", err) return } -// latest serves the @latest endpoint -func (h handler) latest(modpath string, w http.ResponseWriter, r *http.Request) { -} - -// list serves the $base/$module/@v/list endpoint -func (h handler) list(modpath string, w http.ResponseWriter, r *http.Request) { - log_info.Printf("list: %s", modpath) +// getVersions gets the list of versions available for a module +func (h handler) getVersions(modpath string) ([]string, error) { dirpath, _ := filepath.Split(modpath) - log_info.Printf("dirpath: %s", dirpath) localDir := filepath.Join(h.root, "modules", dirpath) - log_info.Printf("localDir: %s", localDir) files, err := os.ReadDir(localDir) if err != nil { - writeError(w, err) - return + if errors.Is(err, fs.ErrNotExist) { + return nil, apiError(http.StatusNotFound) + } + if errors.Is(err, fs.ErrPermission) { + return nil, apiError(http.StatusForbidden) + } + return nil, joinErrors(err, apiError(http.StatusInternalServerError)) } allVersions := make([]string, 0, len(files)) for _, f := range files { name := f.Name() if filepath.Ext(name) != ".zip" { - log_info.Printf("not a zip: %s", name) continue } parts := strings.Split(name, "@") @@ -230,13 +238,27 @@ func (h handler) list(modpath string, w http.ResponseWriter, r *http.Request) { allVersions = append(allVersions, parts[1]) } - semver.Sort(allVersions) if len(allVersions) == 0 { - w.WriteHeader(http.StatusNotFound) - fmt.Fprint(w, "not found") + return nil, apiError(http.StatusNotFound) + } + semver.Sort(allVersions) + return allVersions, nil +} + +// latest serves the @latest endpoint +func (h handler) latest(modpath string, w http.ResponseWriter, r *http.Request) { + +} + +// list serves the $base/$module/@v/list endpoint +func (h handler) list(modpath string, w http.ResponseWriter, r *http.Request) { + versions, err := h.getVersions(modpath) + if err != nil { + writeError(w, err) return } - for _, version := range allVersions { + + for _, version := range versions { fmt.Fprint(w, version) } }