You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

264 lines
6.5 KiB
Go

// package dox implements utilities for performing field-wise encryption of
// structured documents.
package dox
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"reflect"
)
// a Doc represents an encrypted document. Document keys are left in
// plaintext, while document fields are encrypted or hashed as specified by
// their struct tags.
type Doc struct {
Key []byte `json:"key"`
Fields map[string]interface{} `json:"fields"`
Blob []byte `json:"blob"`
}
func (d *Doc) Decrypt(key *rsa.PrivateKey, v interface{}) error {
aesKey, err := rsa.DecryptPKCS1v15(rand.Reader, key, d.Key)
if err != nil {
return fmt.Errorf("failed to decrypt aes key: %v", err)
}
var blob []byte
var blobvals map[string]interface{}
if d.Blob != nil {
blob, err = aesDecrypt(aesKey, d.Blob)
if err != nil {
return fmt.Errorf("failed to decrypt blob: %v", err)
}
if err := json.Unmarshal(blob, &blobvals); err != nil {
return fmt.Errorf("unable to unmarshal blobvals: %v", err)
}
}
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return fmt.Errorf("cannot Decrypt into non-pointer value of kind %v", rv.Kind())
}
rv = rv.Elem() // de-reference our pointer
rt := rv.Type()
for i := 0; i < rv.NumField(); i++ {
fv := rv.Field(i)
f := rt.Field(i)
tag := f.Tag.Get("dox")
switch tag {
case "":
val, ok := blobvals[f.Name]
if !ok {
return fmt.Errorf("doc blob is missing field %s", f.Name)
}
if !fv.CanSet() {
return fmt.Errorf("cannot set field value %s", f.Name)
}
fv.Set(reflect.ValueOf(val))
case "plaintext":
val, ok := d.Fields[f.Name]
if !ok {
return fmt.Errorf("doc fields is missing field %s", f.Name)
}
if !fv.CanSet() {
return fmt.Errorf("cannot set field value %s", f.Name)
}
fv.Set(reflect.ValueOf(val))
case "aes":
val, ok := d.Fields[f.Name]
if !ok {
return fmt.Errorf("doc fields is missing field %s", f.Name)
}
if !fv.CanSet() {
return fmt.Errorf("cannot set field value %s", f.Name)
}
b, ok := val.([]byte)
if !ok {
s, ok := val.(string)
if !ok {
return fmt.Errorf("doc is corrupt")
}
b = make([]byte, len(s))
n, err := base64.StdEncoding.Decode(b, []byte(s))
if err != nil {
return fmt.Errorf("couldn't base64 decode wtf %v", err)
}
b = b[:n]
}
rawVal, err := aesDecrypt(aesKey, b)
if err != nil {
return fmt.Errorf("couldn't decrypt field %s: %v", f.Name, err)
}
switch fv.Kind() {
case reflect.String:
fv.Set(reflect.ValueOf(string(rawVal)))
case reflect.Slice:
fv.Set(reflect.ValueOf(rawVal))
default:
panic("wtf")
}
default:
return fmt.Errorf("not there yet stop it")
}
}
return nil
}
func (d *Doc) setField(name string, v interface{}) {
if d.Fields == nil {
d.Fields = make(map[string]interface{}, 4)
}
d.Fields[name] = v
}
func DecryptJSON(key *rsa.PrivateKey, b []byte, v interface{}) error {
var doc Doc
if err := json.Unmarshal(b, &doc); err != nil {
return err
}
return doc.Decrypt(key, v)
}
func EncryptJSON(key *rsa.PublicKey, v interface{}) ([]byte, error) {
doc, err := Encrypt(key, v)
if err != nil {
return nil, err
}
b, err := json.Marshal(doc)
if err != nil {
return nil, fmt.Errorf("dox.Encrypt unable to marshal doc: %v", err)
}
return b, nil
}
func Encrypt(key *rsa.PublicKey, v interface{}) (*Doc, error) {
aesKey, err := randKey()
if err != nil {
return nil, fmt.Errorf("dox.Encrypt unable to generate document key: %v", err)
}
ckey, err := rsa.EncryptPKCS1v15(rand.Reader, key, aesKey)
if err != nil {
return nil, fmt.Errorf("dox.Encrypt unable to encrypt aes key: %v", err)
}
doc := &Doc{Key: ckey}
rv := reflect.ValueOf(v)
if rv.Kind() != reflect.Ptr {
return nil, fmt.Errorf("dox.Encrypt received non-pointer value")
}
rv = rv.Elem() // dereference our pointer
rt := rv.Type()
// blobvals stores the values of the struct to be collected into a single
// opaque blob.
blobvals := make(map[string]interface{})
for i := 0; i < rv.NumField(); i++ {
fv := rv.Field(i)
f := rt.Field(i)
tag := f.Tag.Get("dox")
switch tag {
case "":
blobvals[f.Name] = fv.Interface()
case "plaintext":
doc.setField(f.Name, fv.Interface())
case "aes":
switch value := fv.Interface().(type) {
case string:
cval, err := aesEncrypt(aesKey, []byte(value))
if err != nil {
return nil, fmt.Errorf("dox.Encrypt couldn't aes encrypt a field: %v", err)
}
doc.setField(f.Name, cval)
case []byte:
cval, err := aesEncrypt(aesKey, value)
if err != nil {
return nil, fmt.Errorf("dox.Encrypt couldn't aes encrypt a field: %v", err)
}
doc.setField(f.Name, cval)
default:
return nil, fmt.Errorf("dox.Encrypt can only aes encrypt fields of type string or []byte")
}
}
}
if len(blobvals) > 0 {
blob, err := json.Marshal(blobvals)
if err != nil {
return nil, fmt.Errorf("dox.Encrypt unable to marshal blob fields: %v", err)
}
cipherblob, err := aesEncrypt(aesKey, blob)
if err != nil {
return nil, fmt.Errorf("dox.Encrypt failed to encrypt blob: %v", err)
}
doc.Blob = cipherblob
}
return doc, nil
}
func aesEncrypt(key []byte, ptxt []byte) ([]byte, error) {
ptxt = append(ptxt, '|')
if len(ptxt)%aes.BlockSize != 0 {
pad := aes.BlockSize - len(ptxt)%aes.BlockSize
for i := 0; i < pad; i++ {
ptxt = append(ptxt, ' ')
}
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("couldn't aes encrypt: failed to make aes cipher: %v", err)
}
ctxt := make([]byte, aes.BlockSize+len(ptxt))
iv := ctxt[:aes.BlockSize]
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return nil, fmt.Errorf("couldn't encrypt note: failed to make aes iv: %v", err)
}
mode := cipher.NewCBCEncrypter(block, iv)
mode.CryptBlocks(ctxt[aes.BlockSize:], ptxt)
return ctxt, nil
}
func aesDecrypt(key []byte, ctxt []byte) ([]byte, error) {
block, err := aes.NewCipher(key)
if err != nil {
return nil, fmt.Errorf("unable to create aes cipher: %v", err)
}
iv := ctxt[:aes.BlockSize]
ptxt := make([]byte, len(ctxt)-aes.BlockSize)
mode := cipher.NewCBCDecrypter(block, iv)
mode.CryptBlocks(ptxt, ctxt[aes.BlockSize:])
for i := len(ptxt) - 1; i >= 0; i-- {
if ptxt[i] == '|' {
return ptxt[:i], nil
}
}
return ptxt, fmt.Errorf("unable to strip padding: %v", err)
}
func randslice(n int) ([]byte, error) {
b := make([]byte, n)
_, err := rand.Read(b)
if err != nil {
return nil, err
}
return b, nil
}
func randKey() ([]byte, error) {
return randslice(aes.BlockSize)
}