diff options
author | Steve <nilslice@gmail.com> | 2017-01-16 16:14:00 -0800 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-01-16 16:14:00 -0800 |
commit | 2af951230eddc45ba429cff10d7566ad98fd343b (patch) | |
tree | 7543be03fae8aeeacc8eb48dbe16ab2d42fbca0b | |
parent | 3249b82b2a4f1aa0ae9e6943cd72dd7eebae8a4a (diff) |
[core] Adding toggle for CORS, GZIP in admin/cms configuration (#30)
This PR enables admins to disable/enable CORS and GZIP from within the admin CMS configuration page. Both are enabled by default.
Note: currently, the GZIP implementation is 100% on the fly, for every qualifying API endpoint request. This could add significant CPU usage, but dramatically decreases bandwidth. Will be considering other better implementations, but for now YMMV.
Possible optimizations:
- pooling gzip Writers vs. creating a new one for each response
- caching gzipped responses (in memory? on disk?)
- enforcing size threshold (only gzip content larger than N bytes)
-rw-r--r-- | README.md | 21 | ||||
-rw-r--r-- | cmd/ponzu/main.go | 31 | ||||
-rw-r--r-- | cmd/ponzu/usage.go | 20 | ||||
-rw-r--r-- | system/addon/api.go | 8 | ||||
-rw-r--r-- | system/admin/config/config.go | 20 | ||||
-rw-r--r-- | system/admin/handlers.go | 4 | ||||
-rw-r--r-- | system/admin/server.go | 2 | ||||
-rw-r--r-- | system/admin/upload/upload.go | 2 | ||||
-rw-r--r-- | system/api/external.go | 2 | ||||
-rw-r--r-- | system/api/handlers.go | 101 | ||||
-rw-r--r-- | system/api/server.go | 8 | ||||
-rw-r--r-- | system/db/addon.go | 22 | ||||
-rw-r--r-- | system/db/cache.go | 2 | ||||
-rw-r--r-- | system/db/config.go | 143 | ||||
-rw-r--r-- | system/db/content.go | 49 | ||||
-rw-r--r-- | system/db/init.go | 36 | ||||
-rw-r--r-- | system/db/user.go | 20 | ||||
-rw-r--r-- | system/tls/devcerts.go | 2 | ||||
-rw-r--r-- | system/tls/enable.go | 2 |
19 files changed, 347 insertions, 148 deletions
@@ -80,22 +80,17 @@ Errors will be reported, but successful commands return nothing. ### generate, gen, g -Generate a content type file with boilerplate code to implement -the editor.Editable interface. Must be given one (1) parameter of -the name of the type for the new content. The fields following a -type determine the field names and types of the content struct to -be generated. These must be in the following format: -fieldName:"T" +Generate boilerplate code for various Ponzu components, such as `content`. Example: ```bash - struct fields and built-in types... - | - v -$ ponzu gen review title:"string" body:"string" rating:"int" tags:"[]string" - ^ - | - struct type + generator struct fields and built-in types... + | | + v v +$ ponzu gen content review title:"string" body:"string" rating:"int" tags:"[]string" + ^ + | + struct type ``` The command above will generate the file `content/review.go` with boilerplate diff --git a/cmd/ponzu/main.go b/cmd/ponzu/main.go index 90ad613..e90d318 100644 --- a/cmd/ponzu/main.go +++ b/cmd/ponzu/main.go @@ -80,7 +80,7 @@ func main() { case "new": if len(args) < 2 { - fmt.Println(usage) + fmt.Println(usageNew) os.Exit(0) } @@ -91,15 +91,22 @@ func main() { } case "generate", "gen", "g": - if len(args) < 2 { - flag.PrintDefaults() + if len(args) < 3 { + fmt.Println(usageGenerate) os.Exit(0) } - err := generateContentType(args[1:]) - if err != nil { - fmt.Println(err) - os.Exit(1) + // check what we are asked to generate + switch args[1] { + case "content", "c": + err := generateContentType(args[2:]) + if err != nil { + fmt.Println(err) + os.Exit(1) + } + default: + msg := fmt.Sprintf("Generator '%s' is not implemented.", args[1]) + fmt.Println(msg) } case "build": @@ -163,8 +170,10 @@ func main() { for i := range services { if services[i] == "api" { api.Run() + } else if services[i] == "admin" { admin.Run() + } else { fmt.Println("To execute 'ponzu serve', you must specify which service to run.") fmt.Println("$ ponzu --help") @@ -176,7 +185,7 @@ func main() { // save the https port the system is listening on err := db.PutConfig("https_port", fmt.Sprintf("%d", httpsport)) if err != nil { - log.Fatalln("System failed to save config. Please try to run again.") + log.Fatalln("System failed to save config. Please try to run again.", err) } // cannot run production HTTPS and development HTTPS together @@ -193,16 +202,18 @@ func main() { fmt.Println("Enabling HTTPS...") go tls.Enable() - fmt.Printf("Server listening on :%s for HTTPS requests...\n", db.ConfigCache("https_port")) + fmt.Printf("Server listening on :%s for HTTPS requests...\n", db.ConfigCache("https_port").(string)) } // save the https port the system is listening on so internal system can make // HTTP api calls while in dev or production w/o adding more cli flags err = db.PutConfig("http_port", fmt.Sprintf("%d", port)) if err != nil { - log.Fatalln("System failed to save config. Please try to run again.") + log.Fatalln("System failed to save config. Please try to run again.", err) } + fmt.Printf("Server listening on :%d for HTTP requests...\n", port) + fmt.Println("\nvisit `/admin` to get started.") log.Fatalln(http.ListenAndServe(fmt.Sprintf(":%d", port), nil)) case "": diff --git a/cmd/ponzu/usage.go b/cmd/ponzu/usage.go index 2dca46d..2d7a3b8 100644 --- a/cmd/ponzu/usage.go +++ b/cmd/ponzu/usage.go @@ -10,11 +10,10 @@ var year = fmt.Sprintf("%d", time.Now().Year()) var usageHeader = ` $ ponzu [flags] command <params> -Ponzu is a powerful and efficient open-source "Content-as-a-Service" system -framework and CMS. It provides automatic, free, and secure HTTP/2 over TLS -(certificates obtained via Let's Encrypt - https://letsencrypt.org), a useful -CMS and scaffolding to generate content editors, and a fast HTTP API on which -to build modern applications. +Ponzu is a powerful and efficient open-source HTTP server framework and CMS. It +provides automatic, free, and secure HTTP/2 over TLS (certificates obtained via +[Let's Encrypt](https://letsencrypt.org)), a useful CMS and scaffolding to +generate content editors, and a fast HTTP API on which to build modern applications. Ponzu is released under the BSD-3-Clause license (see LICENSE). (c) ` + year + ` Boss Sauce Creative, LLC @@ -55,17 +54,12 @@ new <directory>: ` var usageGenerate = ` -generate, gen, g <type (,...fields)>: +generate, gen, g <generator type (,...fields)>: - Generate a content type file with boilerplate code to implement - the editor.Editable interface. Must be given one (1) parameter of - the name of the type for the new content. The fields following a - type determine the field names and types of the content struct to - be generated. These must be in the following format: - fieldName:"T" + Generate boilerplate code for various Ponzu components, such as 'content'. Example: - $ ponzu gen review title:"string" body:"string" rating:"int" tags:"[]string" + $ ponzu gen content review title:"string" body:"string" rating:"int" tags:"[]string" The command above will generate a file 'content/review.go' with boilerplate methods, as well as struct definition, and cooresponding field tags like: diff --git a/system/addon/api.go b/system/addon/api.go index 9b54d6e..cd792aa 100644 --- a/system/addon/api.go +++ b/system/addon/api.go @@ -18,8 +18,8 @@ type QueryOptions db.QueryOptions // ContentAll retrives all items from the HTTP API within the provided namespace func ContentAll(namespace string) []byte { - host := db.ConfigCache("domain") - port := db.ConfigCache("http_port") + host := db.ConfigCache("domain").(string) + port := db.ConfigCache("http_port").(string) endpoint := "http://%s:%s/api/contents?type=%s&count=-1" URL := fmt.Sprintf(endpoint, host, port, namespace) @@ -35,8 +35,8 @@ func ContentAll(namespace string) []byte { // Query retrieves a set of content from the HTTP API based on options // and returns the total number of content in the namespace and the content func Query(namespace string, opts QueryOptions) []byte { - host := db.ConfigCache("domain") - port := db.ConfigCache("http_port") + host := db.ConfigCache("domain").(string) + port := db.ConfigCache("http_port").(string) endpoint := "http://%s:%s/api/contents?type=%s&count=%d&offset=%d&order=%s" URL := fmt.Sprintf(endpoint, host, port, namespace, opts.Count, opts.Offset, opts.Order) diff --git a/system/admin/config/config.go b/system/admin/config/config.go index 7b57dc0..0d55700 100644 --- a/system/admin/config/config.go +++ b/system/admin/config/config.go @@ -16,6 +16,8 @@ type Config struct { AdminEmail string `json:"admin_email"` ClientSecret string `json:"client_secret"` Etag string `json:"etag"` + DisableCORS bool `json:"cors_disabled"` + DisableGZIP bool `json:"gzip_disabled"` CacheInvalidate []string `json:"cache"` } @@ -49,7 +51,7 @@ func (c *Config) MarshalEditor() ([]byte, error) { }, editor.Field{ View: editor.Input("AdminEmail", c, map[string]string{ - "label": "Adminstrator Email (will be notified of internal system information)", + "label": "Adminstrator Email (notified of internal system information)", }), }, editor.Field{ @@ -65,7 +67,7 @@ func (c *Config) MarshalEditor() ([]byte, error) { }, editor.Field{ View: editor.Input("Etag", c, map[string]string{ - "label": "Etag Header (used for static asset cache)", + "label": "Etag Header (used to cache resources)", "disabled": "true", }), }, @@ -75,6 +77,20 @@ func (c *Config) MarshalEditor() ([]byte, error) { }), }, editor.Field{ + View: editor.Checkbox("DisableCORS", c, map[string]string{ + "label": "Disable CORS (so only " + c.Domain + " can fetch your data)", + }, map[string]string{ + "true": "Disable CORS", + }), + }, + editor.Field{ + View: editor.Checkbox("DisableGZIP", c, map[string]string{ + "label": "Disable GZIP (will increase server speed, but also bandwidth)", + }, map[string]string{ + "true": "Disable GZIP", + }), + }, + editor.Field{ View: editor.Checkbox("CacheInvalidate", c, map[string]string{ "label": "Invalidate cache on save", }, map[string]string{ diff --git a/system/admin/handlers.go b/system/admin/handlers.go index c39fee4..2bea356 100644 --- a/system/admin/handlers.go +++ b/system/admin/handlers.go @@ -92,7 +92,7 @@ func initHandler(res http.ResponseWriter, req *http.Request) { } // set HTTP port which should be previously added to config cache - port := db.ConfigCache("http_port") + port := db.ConfigCache("http_port").(string) req.Form.Set("http_port", port) // set initial user email as admin_email and make config @@ -1533,7 +1533,7 @@ func editHandler(res http.ResponseWriter, req *http.Request) { // create a timestamp if one was not set if ts == "" { - ts := fmt.Sprintf("%d", time.Now().Unix()*1000) + ts = fmt.Sprintf("%d", int64(time.Nanosecond)*time.Now().UnixNano()/int64(time.Millisecond)) req.PostForm.Set("timestamp", ts) } diff --git a/system/admin/server.go b/system/admin/server.go index f2bf244..11bfe6f 100644 --- a/system/admin/server.go +++ b/system/admin/server.go @@ -51,5 +51,5 @@ func Run() { // even if the API server is not running. Otherwise, images/files uploaded // through the editor will not load within the admin system. uploadsDir := filepath.Join(pwd, "uploads") - http.Handle("/api/uploads/", api.Record(db.CacheControl(http.StripPrefix("/api/uploads/", http.FileServer(restrict(http.Dir(uploadsDir))))))) + http.Handle("/api/uploads/", api.Record(api.CORS(db.CacheControl(http.StripPrefix("/api/uploads/", http.FileServer(restrict(http.Dir(uploadsDir)))))))) } diff --git a/system/admin/upload/upload.go b/system/admin/upload/upload.go index 486f55c..6b99dfc 100644 --- a/system/admin/upload/upload.go +++ b/system/admin/upload/upload.go @@ -20,7 +20,7 @@ func StoreFiles(req *http.Request) (map[string]string, error) { ts := req.FormValue("timestamp") // timestamp in milliseconds since unix epoch if ts == "" { - ts = fmt.Sprintf("%d", time.Now().Unix()*1000) // Unix() returns seconds since unix epoch + ts = fmt.Sprintf("%d", int64(time.Nanosecond)*time.Now().UnixNano()/int64(time.Millisecond)) // Unix() returns seconds since unix epoch } req.Form.Set("timestamp", ts) diff --git a/system/api/external.go b/system/api/external.go index 662fc07..5d4b302 100644 --- a/system/api/external.go +++ b/system/api/external.go @@ -61,7 +61,7 @@ func externalContentHandler(res http.ResponseWriter, req *http.Request) { return } - ts := fmt.Sprintf("%d", time.Now().Unix()*1000) + ts := fmt.Sprintf("%d", int64(time.Nanosecond)*time.Now().UnixNano()/int64(time.Millisecond)) req.PostForm.Set("timestamp", ts) req.PostForm.Set("updated", ts) diff --git a/system/api/handlers.go b/system/api/handlers.go index 1bc4fbb..9292e15 100644 --- a/system/api/handlers.go +++ b/system/api/handlers.go @@ -2,9 +2,11 @@ package api import ( "bytes" + "compress/gzip" "encoding/json" "log" "net/http" + "net/url" "strconv" "strings" @@ -13,6 +15,7 @@ import ( "github.com/ponzu-cms/ponzu/system/item" ) +// deprecating from API, but going to provide code here in case someone wants it func typesHandler(res http.ResponseWriter, req *http.Request) { var types = []string{} for t, fn := range item.Types { @@ -27,7 +30,7 @@ func typesHandler(res http.ResponseWriter, req *http.Request) { return } - sendData(res, j, http.StatusOK) + sendData(res, req, j) } func contentsHandler(res http.ResponseWriter, req *http.Request) { @@ -91,7 +94,7 @@ func contentsHandler(res http.ResponseWriter, req *http.Request) { return } - sendData(res, j, http.StatusOK) + sendData(res, req, j) } func contentHandler(res http.ResponseWriter, req *http.Request) { @@ -134,7 +137,7 @@ func contentHandler(res http.ResponseWriter, req *http.Request) { return } - sendData(res, j, http.StatusOK) + sendData(res, req, j) } func contentHandlerBySlug(res http.ResponseWriter, req *http.Request) { @@ -171,7 +174,7 @@ func contentHandlerBySlug(res http.ResponseWriter, req *http.Request) { return } - sendData(res, j, http.StatusOK) + sendData(res, req, j) } func hide(it interface{}, res http.ResponseWriter, req *http.Request) bool { @@ -231,13 +234,12 @@ func toJSON(data []string) ([]byte, error) { return buf.Bytes(), nil } -// sendData() should be used any time you want to communicate +// sendData should be used any time you want to communicate // data back to a foreign client -func sendData(res http.ResponseWriter, data []byte, code int) { - res.Header().Set("Access-Control-Allow-Headers", "Accept, Authorization, Content-Type") - res.Header().Set("Access-Control-Allow-Origin", "*") +func sendData(res http.ResponseWriter, req *http.Request, data []byte) { res.Header().Set("Content-Type", "application/json") - res.WriteHeader(code) + res.Header().Set("Vary", "Accept-Encoding") + _, err := res.Write(data) if err != nil { log.Println("Error writing to response in sendData") @@ -252,9 +254,54 @@ func sendPreflight(res http.ResponseWriter) { return } -// CORS wraps a HandleFunc to respond to OPTIONS requests properly +func responseWithCORS(res http.ResponseWriter, req *http.Request) (http.ResponseWriter, bool) { + if db.ConfigCache("cors_disabled").(bool) == true { + // check origin matches config domain + domain := db.ConfigCache("domain").(string) + origin := req.Header.Get("Origin") + u, err := url.Parse(origin) + if err != nil { + log.Println("Error parsing URL from request Origin header:", origin) + return res, false + } + + // hack to get dev environments to bypass cors since u.Host (below) will + // be empty, based on Go's url.Parse function + if domain == "localhost" { + domain = "" + } + origin = u.Host + + // currently, this will check for exact match. will need feedback to + // determine if subdomains should be allowed or allow multiple domains + // in config + if origin == domain { + // apply limited CORS headers and return + res.Header().Set("Access-Control-Allow-Headers", "Accept, Authorization, Content-Type") + res.Header().Set("Access-Control-Allow-Origin", domain) + return res, true + } + + // disallow request + res.WriteHeader(http.StatusForbidden) + return res, false + } + + // apply full CORS headers and return + res.Header().Set("Access-Control-Allow-Headers", "Accept, Authorization, Content-Type") + res.Header().Set("Access-Control-Allow-Origin", "*") + + return res, true +} + +// CORS wraps a HandlerFunc to respond to OPTIONS requests properly func CORS(next http.HandlerFunc) http.HandlerFunc { return db.CacheControl(http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + res, cors := responseWithCORS(res, req) + if !cors { + return + } + if req.Method == http.MethodOptions { sendPreflight(res) return @@ -264,7 +311,7 @@ func CORS(next http.HandlerFunc) http.HandlerFunc { })) } -// Record wraps a HandleFunc to record API requests for analytical purposes +// Record wraps a HandlerFunc to record API requests for analytical purposes func Record(next http.HandlerFunc) http.HandlerFunc { return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { go analytics.Record(req) @@ -272,3 +319,35 @@ func Record(next http.HandlerFunc) http.HandlerFunc { next.ServeHTTP(res, req) }) } + +// Gzip wraps a HandlerFunc to compress responses when possible +func Gzip(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + if db.ConfigCache("gzip_disabled").(bool) == true { + next.ServeHTTP(res, req) + return + } + + // check if req header content-encoding supports gzip + if strings.Contains(req.Header.Get("Accept-Encoding"), "gzip") { + // gzip response data + res.Header().Set("Content-Encoding", "gzip") + gzres := gzipResponseWriter{res, gzip.NewWriter(res)} + + next.ServeHTTP(gzres, req) + return + } + + next.ServeHTTP(res, req) + }) +} + +type gzipResponseWriter struct { + http.ResponseWriter + gw *gzip.Writer +} + +func (gzw gzipResponseWriter) Write(p []byte) (int, error) { + defer gzw.gw.Close() + return gzw.gw.Write(p) +} diff --git a/system/api/server.go b/system/api/server.go index f31a748..8e103c4 100644 --- a/system/api/server.go +++ b/system/api/server.go @@ -4,11 +4,9 @@ import "net/http" // Run adds Handlers to default http listener for API func Run() { - http.HandleFunc("/api/types", CORS(Record(typesHandler))) + http.HandleFunc("/api/contents", Record(CORS(Gzip(contentsHandler)))) - http.HandleFunc("/api/contents", CORS(Record(contentsHandler))) + http.HandleFunc("/api/content", Record(CORS(Gzip(contentHandler)))) - http.HandleFunc("/api/content", CORS(Record(contentHandler))) - - http.HandleFunc("/api/content/external", CORS(Record(externalContentHandler))) + http.HandleFunc("/api/content/external", Record(externalContentHandler)) } diff --git a/system/db/addon.go b/system/db/addon.go index f4621fa..0f63405 100644 --- a/system/db/addon.go +++ b/system/db/addon.go @@ -24,6 +24,9 @@ func Addon(key string) ([]byte, error) { err := store.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("__addons")) + if b == nil { + return bolt.ErrBucketNotFound + } val := b.Get([]byte(key)) @@ -56,12 +59,16 @@ func SetAddon(data url.Values, kind interface{}) error { v, err := json.Marshal(kind) + k := data.Get("addon_reverse_dns") + if k == "" { + name := data.Get("addon_name") + return fmt.Errorf(`Addon "%s" has no identifier to use as key.`, name) + } + err = store.Update(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("__addons")) - k := data.Get("addon_reverse_dns") - if k == "" { - name := data.Get("addon_name") - return fmt.Errorf(`Addon "%s" has no identifier to use as key.`, name) + if b == nil { + return bolt.ErrBucketNotFound } err := b.Put([]byte(k), v) @@ -84,6 +91,10 @@ func AddonAll() [][]byte { err := store.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("__addons")) + if b == nil { + return bolt.ErrBucketNotFound + } + err := b.ForEach(func(k, v []byte) error { all = append(all, v) @@ -107,6 +118,9 @@ func AddonAll() [][]byte { func DeleteAddon(key string) error { err := store.Update(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("__addons")) + if b == nil { + return bolt.ErrBucketNotFound + } if err := b.Delete([]byte(key)); err != nil { return err diff --git a/system/db/cache.go b/system/db/cache.go index 30ecf5a..0120147 100644 --- a/system/db/cache.go +++ b/system/db/cache.go @@ -11,7 +11,7 @@ import ( // CacheControl sets the default cache policy on static asset responses func CacheControl(next http.Handler) http.HandlerFunc { return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { - etag := ConfigCache("etag") + etag := ConfigCache("etag").(string) policy := fmt.Sprintf("max-age=%d, public", 60*60*24*30) res.Header().Add("ETag", etag) res.Header().Add("Cache-Control", policy) diff --git a/system/db/config.go b/system/db/config.go index 45b3952..5a93353 100644 --- a/system/db/config.go +++ b/system/db/config.go @@ -13,57 +13,62 @@ import ( "github.com/gorilla/schema" ) -var configCache url.Values +var configCache map[string]interface{} func init() { - configCache = make(url.Values) + configCache = make(map[string]interface{}) } // SetConfig sets key:value pairs in the db for configuration settings func SetConfig(data url.Values) error { - err := store.Update(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("__config")) - - // check for any multi-value fields (ex. checkbox fields) - // and correctly format for db storage. Essentially, we need - // fieldX.0: value1, fieldX.1: value2 => fieldX: []string{value1, value2} - var discardKeys []string - for k, v := range data { - if strings.Contains(k, ".") { - key := strings.Split(k, ".")[0] - - if data.Get(key) == "" { - data.Set(key, v[0]) - } else { - data.Add(key, v[0]) - } - - discardKeys = append(discardKeys, k) + var j []byte + + // check for any multi-value fields (ex. checkbox fields) + // and correctly format for db storage. Essentially, we need + // fieldX.0: value1, fieldX.1: value2 => fieldX: []string{value1, value2} + var discardKeys []string + for k, v := range data { + if strings.Contains(k, ".") { + key := strings.Split(k, ".")[0] + + if data.Get(key) == "" { + data.Set(key, v[0]) + } else { + data.Add(key, v[0]) } - } - for _, discardKey := range discardKeys { - data.Del(discardKey) + discardKeys = append(discardKeys, k) } + } - cfg := &config.Config{} - dec := schema.NewDecoder() - dec.SetAliasTag("json") // allows simpler struct tagging when creating a content type - dec.IgnoreUnknownKeys(true) // will skip over form values submitted, but not in struct - err := dec.Decode(cfg, data) - if err != nil { - return err - } + for _, discardKey := range discardKeys { + data.Del(discardKey) + } - // check for "invalidate" value to reset the Etag - if len(cfg.CacheInvalidate) > 0 && cfg.CacheInvalidate[0] == "invalidate" { - cfg.Etag = NewEtag() - cfg.CacheInvalidate = []string{} - } + cfg := &config.Config{} + dec := schema.NewDecoder() + dec.SetAliasTag("json") // allows simpler struct tagging when creating a content type + dec.IgnoreUnknownKeys(true) // will skip over form values submitted, but not in struct + err := dec.Decode(cfg, data) + if err != nil { + return err + } - j, err := json.Marshal(cfg) - if err != nil { - return err + // check for "invalidate" value to reset the Etag + if len(cfg.CacheInvalidate) > 0 && cfg.CacheInvalidate[0] == "invalidate" { + cfg.Etag = NewEtag() + cfg.CacheInvalidate = []string{} + } + + j, err = json.Marshal(cfg) + if err != nil { + return err + } + + err = store.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("__config")) + if b == nil { + return bolt.ErrBucketNotFound } err = b.Put([]byte("settings"), j) @@ -77,7 +82,14 @@ func SetConfig(data url.Values) error { return err } - configCache = data + // convert json => map[string]interface{} + var kv map[string]interface{} + err = json.Unmarshal(j, &kv) + if err != nil { + return err + } + + configCache = kv return nil } @@ -108,6 +120,9 @@ func ConfigAll() ([]byte, error) { val := &bytes.Buffer{} err := store.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("__config")) + if b == nil { + return bolt.ErrBucketNotFound + } _, err := val.Write(b.Get([]byte("settings"))) if err != nil { return err @@ -131,6 +146,13 @@ func PutConfig(key string, value interface{}) error { return err } + if c == nil { + c, err = emptyConfig() + if err != nil { + return err + } + } + err = json.Unmarshal(c, &kv) if err != nil { return err @@ -166,6 +188,43 @@ func PutConfig(key string, value interface{}) error { // ConfigCache is a in-memory cache of the Configs for quicker lookups // 'key' is the JSON tag associated with the config field -func ConfigCache(key string) string { - return configCache.Get(key) +func ConfigCache(key string) interface{} { + return configCache[key] +} + +// LoadCacheConfig loads the config into a cache to be accessed by ConfigCache() +func LoadCacheConfig() error { + c, err := ConfigAll() + if err != nil { + return err + } + + if c == nil { + c, err = emptyConfig() + if err != nil { + return err + } + } + + // convert json => map[string]interface{} + var kv map[string]interface{} + err = json.Unmarshal(c, &kv) + if err != nil { + return err + } + + configCache = kv + + return nil +} + +func emptyConfig() ([]byte, error) { + cfg := &config.Config{} + + data, err := json.Marshal(cfg) + if err != nil { + return nil, err + } + + return data, nil } diff --git a/system/db/content.go b/system/db/content.go index b8d9cb8..d9096ae 100644 --- a/system/db/content.go +++ b/system/db/content.go @@ -49,17 +49,17 @@ func update(ns, id string, data url.Values) (int, error) { return 0, err } + j, err := postToJSON(ns, data) + if err != nil { + return 0, err + } + err = store.Update(func(tx *bolt.Tx) error { b, err := tx.CreateBucketIfNotExists([]byte(ns + specifier)) if err != nil { return err } - j, err := postToJSON(ns, data) - if err != nil { - return err - } - err = b.Put([]byte(fmt.Sprintf("%d", cid)), j) if err != nil { return err @@ -134,6 +134,10 @@ func insert(ns string, data url.Values) (int, error) { // store the slug,type:id in contentIndex if public content if specifier == "" { ci := tx.Bucket([]byte("__contentIndex")) + if ci == nil { + return bolt.ErrBucketNotFound + } + k := []byte(data.Get("slug")) v := []byte(fmt.Sprintf("%s:%d", ns, effectedID)) err := ci.Put(k, v) @@ -168,7 +172,12 @@ func DeleteContent(target string, data url.Values) error { ns, id := t[0], t[1] err := store.Update(func(tx *bolt.Tx) error { - err := tx.Bucket([]byte(ns)).Delete([]byte(id)) + b := tx.Bucket([]byte(ns)) + if b == nil { + return bolt.ErrBucketNotFound + } + + err := b.Delete([]byte(id)) if err != nil { return err } @@ -176,7 +185,12 @@ func DeleteContent(target string, data url.Values) error { // if content has a slug, also delete it from __contentIndex slug := data.Get("slug") if slug != "" { - err := tx.Bucket([]byte("__contentIndex")).Delete([]byte(slug)) + ci := tx.Bucket([]byte("__contentIndex")) + if ci == nil { + return bolt.ErrBucketNotFound + } + + err := ci.Delete([]byte(slug)) if err != nil { return err } @@ -212,6 +226,10 @@ func Content(target string) ([]byte, error) { val := &bytes.Buffer{} err := store.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte(ns)) + if b == nil { + return bolt.ErrBucketNotFound + } + _, err := val.Write(b.Get([]byte(id))) if err != nil { log.Println(err) @@ -235,6 +253,9 @@ func ContentBySlug(slug string) (string, []byte, error) { var t, id string err := store.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("__contentIndex")) + if b == nil { + return bolt.ErrBucketNotFound + } idx := b.Get([]byte(slug)) if idx != nil { @@ -248,6 +269,9 @@ func ContentBySlug(slug string) (string, []byte, error) { } c := tx.Bucket([]byte(t)) + if c == nil { + return bolt.ErrBucketNotFound + } _, err := val.Write(c.Get([]byte(id))) if err != nil { return err @@ -267,9 +291,8 @@ func ContentAll(namespace string) [][]byte { var posts [][]byte store.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte(namespace)) - if b == nil { - return nil + return bolt.ErrBucketNotFound } numKeys := b.Stats().KeyN @@ -313,7 +336,7 @@ func Query(namespace string, opts QueryOptions) (int, [][]byte) { store.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte(namespace)) if b == nil { - return nil + return bolt.ErrBucketNotFound } c := b.Cursor() @@ -450,13 +473,11 @@ func SortContent(namespace string) { bname := []byte(namespace + "__sorted") err := tx.DeleteBucket(bname) if err != nil && err != bolt.ErrBucketNotFound { - fmt.Println("Error in DeleteBucket") return err } b, err := tx.CreateBucketIfNotExists(bname) if err != nil { - fmt.Println("Error in CreateBucketIfNotExists") return err } @@ -465,7 +486,6 @@ func SortContent(namespace string) { cid := fmt.Sprintf("%d:%d", i, posts[i].Time()) err = b.Put([]byte(cid), bb[i]) if err != nil { - fmt.Println("Error in Put") return err } } @@ -538,6 +558,9 @@ func checkSlugForDuplicate(slug string) (string, error) { // check for existing slug in __contentIndex err := store.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("__contentIndex")) + if b == nil { + return bolt.ErrBucketNotFound + } original := slug exists := true i := 0 diff --git a/system/db/init.go b/system/db/init.go index eaf6d76..9125d3b 100644 --- a/system/db/init.go +++ b/system/db/init.go @@ -1,10 +1,8 @@ package db import ( - "encoding/json" "log" - "github.com/ponzu-cms/ponzu/system/admin/config" "github.com/ponzu-cms/ponzu/system/item" "github.com/boltdb/bolt" @@ -57,32 +55,23 @@ func Init() { } } - // seed db with configs structure if not present - b := tx.Bucket([]byte("__config")) - if b.Get([]byte("settings")) == nil { - j, err := json.Marshal(&config.Config{}) - if err != nil { - return err - } - - err = b.Put([]byte("settings"), j) - if err != nil { - return err - } - } - - clientSecret := ConfigCache("client_secret") - - if clientSecret != "" { - jwt.Secret([]byte(clientSecret)) - } - return nil }) if err != nil { log.Fatalln("Coudn't initialize db with buckets.", err) } + err = LoadCacheConfig() + if err != nil { + log.Fatalln("Failed to load config cache.", err) + } + + clientSecret := ConfigCache("client_secret").(string) + + if clientSecret != "" { + jwt.Secret([]byte(clientSecret)) + } + // invalidate cache on system start err = InvalidateCache() if err != nil { @@ -103,6 +92,9 @@ func SystemInitComplete() bool { err := store.View(func(tx *bolt.Tx) error { users := tx.Bucket([]byte("__users")) + if users == nil { + return bolt.ErrBucketNotFound + } err := users.ForEach(func(k, v []byte) error { complete = true diff --git a/system/db/user.go b/system/db/user.go index 02fda95..164ae7b 100644 --- a/system/db/user.go +++ b/system/db/user.go @@ -26,6 +26,9 @@ func SetUser(usr *user.User) (int, error) { err := store.Update(func(tx *bolt.Tx) error { email := []byte(usr.Email) users := tx.Bucket([]byte("__users")) + if users == nil { + return bolt.ErrBucketNotFound + } // check if user is found by email, fail if nil exists := users.Get(email) @@ -69,6 +72,9 @@ func UpdateUser(usr, updatedUsr *user.User) error { err := store.Update(func(tx *bolt.Tx) error { users := tx.Bucket([]byte("__users")) + if users == nil { + return bolt.ErrBucketNotFound + } // check if user is found by email, fail if nil exists := users.Get([]byte(usr.Email)) @@ -110,6 +116,10 @@ func UpdateUser(usr, updatedUsr *user.User) error { func DeleteUser(email string) error { err := store.Update(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("__users")) + if b == nil { + return bolt.ErrBucketNotFound + } + err := b.Delete([]byte(email)) if err != nil { return err @@ -129,6 +139,10 @@ func User(email string) ([]byte, error) { val := &bytes.Buffer{} err := store.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("__users")) + if b == nil { + return bolt.ErrBucketNotFound + } + usr := b.Get([]byte(email)) _, err := val.Write(usr) @@ -154,6 +168,10 @@ func UserAll() ([][]byte, error) { var users [][]byte err := store.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("__users")) + if b == nil { + return bolt.ErrBucketNotFound + } + err := b.ForEach(func(k, v []byte) error { users = append(users, v) return nil @@ -230,7 +248,7 @@ func RecoveryKey(email string) (string, error) { err := store.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("__recoveryKeys")) if b == nil { - return errors.New("No database found for checking keys.") + return bolt.ErrBucketNotFound } _, err := key.Write(b.Get([]byte(email))) diff --git a/system/tls/devcerts.go b/system/tls/devcerts.go index f4dc18f..0554aa4 100644 --- a/system/tls/devcerts.go +++ b/system/tls/devcerts.go @@ -89,7 +89,7 @@ func setupDev() { } hosts := []string{"localhost", "0.0.0.0"} - domain := db.ConfigCache("domain") + domain := db.ConfigCache("domain").(string) if domain != "" { hosts = append(hosts, domain) } diff --git a/system/tls/enable.go b/system/tls/enable.go index f9c16d8..4279b55 100644 --- a/system/tls/enable.go +++ b/system/tls/enable.go @@ -70,7 +70,7 @@ func Enable() { setup() server := &http.Server{ - Addr: fmt.Sprintf(":%s", db.ConfigCache("https_port")), + Addr: fmt.Sprintf(":%s", db.ConfigCache("https_port").(string)), TLSConfig: &tls.Config{GetCertificate: m.GetCertificate}, } |