store stuff in sqlite

master
Jordan Orelli 10 years ago
parent a88957060e
commit e240a4f5a8

@ -6,9 +6,20 @@ import (
) )
func requestsHandler(w http.ResponseWriter, r *http.Request) { func requestsHandler(w http.ResponseWriter, r *http.Request) {
fmt.Printf("we have %d requests in history when user checked", len(requestHistory)) rows, err := db.Query("select * from requests limit 100")
for _, req := range requestHistory { if err != nil {
fmt.Fprintln(w, req.RequestURI) http.Error(w, fmt.Sprintf("unable to query db: %s", err), 500)
return
}
defer rows.Close()
for rows.Next() {
var (
id string
host string
path string
)
rows.Scan(&id, &host, &path)
fmt.Fprintf(w, "%s %s %s\n", id, host, path)
} }
} }

@ -2,12 +2,16 @@ package main
import ( import (
"bytes" "bytes"
"database/sql"
"fmt" "fmt"
_ "github.com/mattn/go-sqlite3"
"io/ioutil" "io/ioutil"
"log"
"net/http" "net/http"
) )
var lastId int var lastId int
var db *sql.DB
func freezeRequest(r *http.Request) error { func freezeRequest(r *http.Request) error {
var buf bytes.Buffer var buf bytes.Buffer
@ -22,4 +26,62 @@ func freezeRequest(r *http.Request) error {
return nil return nil
} }
var requestHistory []http.Request func openDB() error {
var dbpath string
var err error
err = conf.Get("dbpath", &dbpath)
if err != nil {
return err
}
log.Printf("opening sqlite file at %s", dbpath)
db, err = sql.Open("sqlite3", dbpath)
if err != nil {
return err
}
setupDB()
return nil
}
func setupDB() {
sql := `
create table if not exists domains (
hostname text primary key,
blocked integer not null default 0
);
create table if not exists requests (
id text primary key,
host text not null,
path text not null
);
`
res, err := db.Exec(sql)
if err != nil {
log.Printf("unable to setup db: %v", err)
} else {
log.Printf("db was set up: %v", res)
}
}
func saveHostname(hostname string) {
res, err := db.Exec(`insert or ignore into domains (hostname) values (?)`, hostname)
if err != nil {
log.Printf("unable to save hostname: %v", err)
return
}
if n, _ := res.RowsAffected(); n > 0 {
log.Printf("saved new hostname: %s", hostname)
}
}
func saveRequest(id RequestId, r *http.Request) {
_, err := db.Exec(`insert or ignore into requests (id, host, path)
values (?, ?, ?)`, id.String(), r.URL.Host, r.URL.Path)
if err != nil {
log.Printf("unable to save request: %v", err)
return
}
}

@ -5,6 +5,7 @@ import (
"fmt" "fmt"
"github.com/jordanorelli/moon/lib" "github.com/jordanorelli/moon/lib"
"io" "io"
"log"
"net/http" "net/http"
"net/http/httputil" "net/http/httputil"
"os" "os"
@ -18,6 +19,7 @@ var (
) )
func httpHandler(w http.ResponseWriter, r *http.Request) { func httpHandler(w http.ResponseWriter, r *http.Request) {
id := newRequestId()
fmt.Println("from:", r.RemoteAddr) fmt.Println("from:", r.RemoteAddr)
if err := freezeRequest(r); err != nil { if err := freezeRequest(r); err != nil {
fmt.Printf("error freezing request: %s\n", err) fmt.Printf("error freezing request: %s\n", err)
@ -53,11 +55,8 @@ func httpHandler(w http.ResponseWriter, r *http.Request) {
fmt.Printf("error copying body: %s\n", err) fmt.Printf("error copying body: %s\n", err)
} }
if requestHistory == nil {
requestHistory = make([]http.Request, 0, 100)
}
r.RequestURI = requestURI r.RequestURI = requestURI
requestHistory = append(requestHistory, *r) saveRequest(id, r)
} }
func bail(status int, t string, args ...interface{}) { func bail(status int, t string, args ...interface{}) {
@ -85,12 +84,18 @@ func main() {
flag.StringVar(&configPath, "config", "./prox_config.moon", "path to configuration file") flag.StringVar(&configPath, "config", "./prox_config.moon", "path to configuration file")
flag.Parse() flag.Parse()
log.Printf("reading config from %s", configPath)
var err error var err error
conf, err = moon.ReadFile(configPath) conf, err = moon.ReadFile(configPath)
if err != nil { if err != nil {
bail(1, "unable to read config: %s", err) bail(1, "unable to read config: %s", err)
} }
if err := openDB(); err != nil {
bail(1, "unable to open db: %s", err)
}
defer db.Close()
go appServer() go appServer()
proxyListener() proxyListener()
} }

@ -5,3 +5,5 @@ proxy_addr: ":8080"
# http address for the user app. Users navigate to this address to view their # http address for the user app. Users navigate to this address to view their
# prox history. # prox history.
app_addr: ":9000" app_addr: ":9000"
dbpath: "history.db"

@ -0,0 +1,84 @@
package main
import (
"crypto/md5"
"encoding/binary"
"fmt"
"os"
"sync/atomic"
"time"
)
// RequestId is used for tagging each incoming http request for logging
// purposes. The actual implementation is just the ObjectId implementation
// found in launchpad.net/mgo/bson.
type RequestId string
func (id RequestId) String() string {
return fmt.Sprintf("%x", string(id))
}
// Time returns the timestamp part of the id.
// It's a runtime error to call this method with an invalid id.
func (id RequestId) Time() time.Time {
secs := int64(binary.BigEndian.Uint32(id.byteSlice(0, 4)))
return time.Unix(secs, 0)
}
// byteSlice returns byte slice of id from start to end.
// Calling this function with an invalid id will cause a runtime panic.
func (id RequestId) byteSlice(start, end int) []byte {
if len(id) != 12 {
panic(fmt.Sprintf("Invalid RequestId: %q", string(id)))
}
return []byte(string(id)[start:end])
}
// requestIdCounter is atomically incremented when generating a new ObjectId
// using NewObjectId() function. It's used as a counter part of an id.
var requestIdCounter uint32 = 0
// machineId stores machine id generated once and used in subsequent calls
// to NewObjectId function.
var machineId []byte
// initMachineId generates machine id and puts it into the machineId global
// variable. If this function fails to get the hostname, it will cause
// a runtime error.
func initMachineId() {
var sum [3]byte
hostname, err := os.Hostname()
if err != nil {
panic("Failed to get hostname: " + err.Error())
}
hw := md5.New()
hw.Write([]byte(hostname))
copy(sum[:3], hw.Sum(nil))
machineId = sum[:]
}
// NewObjectId returns a new unique ObjectId.
// This function causes a runtime error if it fails to get the hostname
// of the current machine.
func newRequestId() RequestId {
b := make([]byte, 12)
// Timestamp, 4 bytes, big endian
binary.BigEndian.PutUint32(b, uint32(time.Now().Unix()))
// Machine, first 3 bytes of md5(hostname)
if machineId == nil {
initMachineId()
}
b[4] = machineId[0]
b[5] = machineId[1]
b[6] = machineId[2]
// Pid, 2 bytes, specs don't specify endianness, but we use big endian.
pid := os.Getpid()
b[7] = byte(pid >> 8)
b[8] = byte(pid)
// Increment, 3 bytes, big endian
i := atomic.AddUint32(&requestIdCounter, 1)
b[9] = byte(i >> 16)
b[10] = byte(i >> 8)
b[11] = byte(i)
return RequestId(b)
}
Loading…
Cancel
Save