diff options
37 files changed, 4671 insertions, 239 deletions
diff --git a/.gitattributes b/.gitattributes index d979c14..4d56b9a 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,2 +1,6 @@ system/admin/static/editor/sass/* linguist-vendored -cmd/ponzu/vendor/* linguist-vendored=false
\ No newline at end of file +cmd/ponzu/vendor/* linguist-vendored=false +system/admin/static/dashboard/js/* linguist-vendored +system/admin/static/editor/js/* linguist-vendored +system/admin/static/common/js/* linguist-vendored +system/admin/static/common/js/util.js linguist-vendored=false
\ No newline at end of file diff --git a/cmd/ponzu/main.go b/cmd/ponzu/main.go index ba85f36..234d3b6 100644 --- a/cmd/ponzu/main.go +++ b/cmd/ponzu/main.go @@ -11,7 +11,9 @@ import ( "github.com/bosssauce/ponzu/system/admin" "github.com/bosssauce/ponzu/system/api" + "github.com/bosssauce/ponzu/system/api/analytics" "github.com/bosssauce/ponzu/system/db" + "github.com/bosssauce/ponzu/system/tls" ) var usage = ` @@ -45,7 +47,7 @@ generate, gen, g <type>: -[[--port=8080] [--tls]] run <service(,service)>: +[[--port=8080] [--https]] run <service(,service)>: Starts the 'ponzu' HTTP server for the JSON API, Admin System, or both. The segments, separated by a comma, describe which services to start, either @@ -54,7 +56,7 @@ generate, gen, g <type>: automatically managed using Let's Encrypt (https://letsencrypt.org) Example: - $ ponzu --port=8080 --tls run admin,api + $ ponzu --port=8080 --https run admin,api (or) $ ponzu run admin (or) @@ -71,8 +73,8 @@ generate, gen, g <type>: ` var ( - port int - tls bool + port int + https bool // for ponzu internal / core development dev bool @@ -87,7 +89,7 @@ func init() { func main() { flag.IntVar(&port, "port", 8080, "port for ponzu to bind its listener") - flag.BoolVar(&tls, "tls", false, "enable automatic TLS/SSL certificate management") + flag.BoolVar(&https, "https", false, "enable automatic TLS/SSL certificate management") flag.BoolVar(&dev, "dev", false, "modify environment for Ponzu core development") flag.StringVar(&fork, "fork", "", "modify repo source for Ponzu core development") flag.Parse() @@ -132,10 +134,10 @@ func main() { case "run": var addTLS string - if tls { - addTLS = "--tls" + if https { + addTLS = "--https" } else { - addTLS = "--tls=false" + addTLS = "--https=false" } var services string @@ -168,6 +170,11 @@ func main() { case "serve", "s": db.Init() + defer db.Close() + + analytics.Init() + defer analytics.Close() + if len(args) > 1 { services := strings.Split(args[1], ",") @@ -184,10 +191,9 @@ func main() { } } - if tls { - fmt.Println("TLS through Let's Encrypt is not implemented yet.") - fmt.Println("Please run 'ponzu serve' without the --tls flag for now.") - os.Exit(1) + if https { + fmt.Println("Enabling HTTPS...") + tls.Enable() } log.Fatal(http.ListenAndServe(fmt.Sprintf(":%d", port), nil)) diff --git a/cmd/ponzu/options.go b/cmd/ponzu/options.go index cabe91a..b23ab2d 100644 --- a/cmd/ponzu/options.go +++ b/cmd/ponzu/options.go @@ -72,36 +72,16 @@ type {{ .name }} struct { editor editor.Editor // required: all maintained {{ .name }} fields must have json tags! - Title string ` + "`json:" + `"title"` + "`" + ` - Content string ` + "`json:" + `"content"` + "`" + ` - Author string ` + "`json:" + `"author"` + "`" + ` - Picture string ` + "`json:" + `"picture"` + "`" + ` - Category []string ` + "`json:" + `"category"` + "`" + ` - ThemeStyle string ` + "`json:" + `"theme"` + "`" + ` + Title string ` + "`json:" + `"title"` + "`" + ` + Content string ` + "`json:" + `"content"` + "`" + ` + Author string ` + "`json:" + `"author"` + "`" + ` + Photo string ` + "`json:" + `"photo"` + "`" + ` + Category []string ` + "`json:" + `"category"` + "`" + ` + Theme string ` + "`json:" + `"theme"` + "`" + ` } -func init() { - Types["{{ .name }}"] = func() interface{} { return new({{ .name }}) } -} - -// SetContentID partially implements editor.Editable -func ({{ .initial }} *{{ .name }}) SetContentID(id int) { {{ .initial }}.ID = id } - -// ContentID partially implements editor.Editable -func ({{ .initial }} *{{ .name }}) ContentID() int { return {{ .initial }}.ID } - -// ContentName partially implements editor.Editable -func ({{ .initial }} *{{ .name }}) ContentName() string { return {{ .initial }}.Title } - -// SetSlug partially implements editor.Editable -func ({{ .initial }} *{{ .name }}) SetSlug(slug string) { {{ .initial }}.Slug = slug } - -// Editor partially implements editor.Editable -func ({{ .initial }} *{{ .name }}) Editor() *editor.Editor { return &{{ .initial }}.editor } - // MarshalEditor writes a buffer of html to edit a {{ .name }} and partially implements editor.Editable func ({{ .initial }} *{{ .name }}) MarshalEditor() ([]byte, error) { -/* EXAMPLE CODE (from post.go, the default content type) */ view, err := editor.Form({{ .initial }}, editor.Field{ // Take careful note that the first argument to these Input-like methods @@ -127,22 +107,18 @@ func ({{ .initial }} *{{ .name }}) MarshalEditor() ([]byte, error) { }), }, editor.Field{ - View: editor.File("Picture", {{ .initial }}, map[string]string{ + View: editor.File("Photo", {{ .initial }}, map[string]string{ "label": "Author Photo", "placeholder": "Upload a profile picture for the author", }), }, editor.Field{ - View: editor.Checkbox("Category", {{ .initial }}, map[string]string{ + View: editor.Tags("Category", {{ .initial }}, map[string]string{ "label": "{{ .name }} Category", - }, map[string]string{ - "important": "Important", - "active": "Active", - "unplanned": "Unplanned", }), }, editor.Field{ - View: editor.Select("ThemeStyle", {{ .initial }}, map[string]string{ + View: editor.Select("Theme", {{ .initial }}, map[string]string{ "label": "Theme Style", }, map[string]string{ "dark": "Dark", @@ -157,6 +133,26 @@ func ({{ .initial }} *{{ .name }}) MarshalEditor() ([]byte, error) { return view, nil } + +func init() { + Types["{{ .name }}"] = func() interface{} { return new({{ .name }}) } +} + +// SetContentID partially implements editor.Editable +func ({{ .initial }} *{{ .name }}) SetContentID(id int) { {{ .initial }}.ID = id } + +// ContentID partially implements editor.Editable +func ({{ .initial }} *{{ .name }}) ContentID() int { return {{ .initial }}.ID } + +// ContentName partially implements editor.Editable +func ({{ .initial }} *{{ .name }}) ContentName() string { return {{ .initial }}.Title } + +// SetSlug partially implements editor.Editable +func ({{ .initial }} *{{ .name }}) SetSlug(slug string) { {{ .initial }}.Slug = slug } + +// Editor partially implements editor.Editable +func ({{ .initial }} *{{ .name }}) Editor() *editor.Editor { return &{{ .initial }}.editor } + ` func newProjectInDir(path string) error { diff --git a/cmd/ponzu/vendor/golang.org/x/crypto/autocert/autocert.go b/cmd/ponzu/vendor/golang.org/x/crypto/autocert/autocert.go new file mode 100644 index 0000000..12c9010 --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/crypto/autocert/autocert.go @@ -0,0 +1,776 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package autocert provides automatic access to certificates from Let's Encrypt +// and any other ACME-based CA. +// +// This package is a work in progress and makes no API stability promises. +package autocert + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "fmt" + "io" + mathrand "math/rand" + "net/http" + "strconv" + "strings" + "sync" + "time" + + "golang.org/x/crypto/acme" + "golang.org/x/net/context" +) + +// pseudoRand is safe for concurrent use. +var pseudoRand *lockedMathRand + +func init() { + src := mathrand.NewSource(timeNow().UnixNano()) + pseudoRand = &lockedMathRand{rnd: mathrand.New(src)} +} + +// AcceptTOS always returns true to indicate the acceptance of a CA Terms of Service +// during account registration. +func AcceptTOS(tosURL string) bool { return true } + +// HostPolicy specifies which host names the Manager is allowed to respond to. +// It returns a non-nil error if the host should be rejected. +// The returned error is accessible via tls.Conn.Handshake and its callers. +// See Manager's HostPolicy field and GetCertificate method docs for more details. +type HostPolicy func(ctx context.Context, host string) error + +// HostWhitelist returns a policy where only the specified host names are allowed. +// Only exact matches are currently supported. Subdomains, regexp or wildcard +// will not match. +func HostWhitelist(hosts ...string) HostPolicy { + whitelist := make(map[string]bool, len(hosts)) + for _, h := range hosts { + whitelist[h] = true + } + return func(_ context.Context, host string) error { + if !whitelist[host] { + return errors.New("acme/autocert: host not configured") + } + return nil + } +} + +// defaultHostPolicy is used when Manager.HostPolicy is not set. +func defaultHostPolicy(context.Context, string) error { + return nil +} + +// Manager is a stateful certificate manager built on top of acme.Client. +// It obtains and refreshes certificates automatically, +// as well as providing them to a TLS server via tls.Config. +// +// A simple usage example: +// +// m := autocert.Manager{ +// Prompt: autocert.AcceptTOS, +// HostPolicy: autocert.HostWhitelist("example.org"), +// } +// s := &http.Server{ +// Addr: ":https", +// TLSConfig: &tls.Config{GetCertificate: m.GetCertificate}, +// } +// s.ListenAndServeTLS("", "") +// +// To preserve issued certificates and improve overall performance, +// use a cache implementation of Cache. For instance, DirCache. +type Manager struct { + // Prompt specifies a callback function to conditionally accept a CA's Terms of Service (TOS). + // The registration may require the caller to agree to the CA's TOS. + // If so, Manager calls Prompt with a TOS URL provided by the CA. Prompt should report + // whether the caller agrees to the terms. + // + // To always accept the terms, the callers can use AcceptTOS. + Prompt func(tosURL string) bool + + // Cache optionally stores and retrieves previously-obtained certificates. + // If nil, certs will only be cached for the lifetime of the Manager. + // + // Manager passes the Cache certificates data encoded in PEM, with private/public + // parts combined in a single Cache.Put call, private key first. + Cache Cache + + // HostPolicy controls which domains the Manager will attempt + // to retrieve new certificates for. It does not affect cached certs. + // + // If non-nil, HostPolicy is called before requesting a new cert. + // If nil, all hosts are currently allowed. This is not recommended, + // as it opens a potential attack where clients connect to a server + // by IP address and pretend to be asking for an incorrect host name. + // Manager will attempt to obtain a certificate for that host, incorrectly, + // eventually reaching the CA's rate limit for certificate requests + // and making it impossible to obtain actual certificates. + // + // See GetCertificate for more details. + HostPolicy HostPolicy + + // RenewBefore optionally specifies how early certificates should + // be renewed before they expire. + // + // If zero, they're renewed 1 week before expiration. + RenewBefore time.Duration + + // Client is used to perform low-level operations, such as account registration + // and requesting new certificates. + // If Client is nil, a zero-value acme.Client is used with acme.LetsEncryptURL + // directory endpoint and a newly-generated ECDSA P-256 key. + // + // Mutating the field after the first call of GetCertificate method will have no effect. + Client *acme.Client + + // Email optionally specifies a contact email address. + // This is used by CAs, such as Let's Encrypt, to notify about problems + // with issued certificates. + // + // If the Client's account key is already registered, Email is not used. + Email string + + clientMu sync.Mutex + client *acme.Client // initialized by acmeClient method + + stateMu sync.Mutex + state map[string]*certState // keyed by domain name + + // tokenCert is keyed by token domain name, which matches server name + // of ClientHello. Keys always have ".acme.invalid" suffix. + tokenCertMu sync.RWMutex + tokenCert map[string]*tls.Certificate + + // renewal tracks the set of domains currently running renewal timers. + // It is keyed by domain name. + renewalMu sync.Mutex + renewal map[string]*domainRenewal +} + +// GetCertificate implements the tls.Config.GetCertificate hook. +// It provides a TLS certificate for hello.ServerName host, including answering +// *.acme.invalid (TLS-SNI) challenges. All other fields of hello are ignored. +// +// If m.HostPolicy is non-nil, GetCertificate calls the policy before requesting +// a new cert. A non-nil error returned from m.HostPolicy halts TLS negotiation. +// The error is propagated back to the caller of GetCertificate and is user-visible. +// This does not affect cached certs. See HostPolicy field description for more details. +func (m *Manager) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, error) { + name := hello.ServerName + if name == "" { + return nil, errors.New("acme/autocert: missing server name") + } + + // check whether this is a token cert requested for TLS-SNI challenge + if strings.HasSuffix(name, ".acme.invalid") { + m.tokenCertMu.RLock() + defer m.tokenCertMu.RUnlock() + if cert := m.tokenCert[name]; cert != nil { + return cert, nil + } + if cert, err := m.cacheGet(name); err == nil { + return cert, nil + } + // TODO: cache error results? + return nil, fmt.Errorf("acme/autocert: no token cert for %q", name) + } + + // regular domain + cert, err := m.cert(name) + if err == nil { + return cert, nil + } + if err != ErrCacheMiss { + return nil, err + } + + // first-time + ctx := context.Background() // TODO: use a deadline? + if err := m.hostPolicy()(ctx, name); err != nil { + return nil, err + } + cert, err = m.createCert(ctx, name) + if err != nil { + return nil, err + } + m.cachePut(name, cert) + return cert, nil +} + +// cert returns an existing certificate either from m.state or cache. +// If a certificate is found in cache but not in m.state, the latter will be filled +// with the cached value. +func (m *Manager) cert(name string) (*tls.Certificate, error) { + m.stateMu.Lock() + if s, ok := m.state[name]; ok { + m.stateMu.Unlock() + s.RLock() + defer s.RUnlock() + return s.tlscert() + } + defer m.stateMu.Unlock() + cert, err := m.cacheGet(name) + if err != nil { + return nil, err + } + signer, ok := cert.PrivateKey.(crypto.Signer) + if !ok { + return nil, errors.New("acme/autocert: private key cannot sign") + } + if m.state == nil { + m.state = make(map[string]*certState) + } + s := &certState{ + key: signer, + cert: cert.Certificate, + leaf: cert.Leaf, + } + m.state[name] = s + go m.renew(name, s.key, s.leaf.NotAfter) + return cert, nil +} + +// cacheGet always returns a valid certificate, or an error otherwise. +func (m *Manager) cacheGet(domain string) (*tls.Certificate, error) { + if m.Cache == nil { + return nil, ErrCacheMiss + } + // TODO: might want to define a cache timeout on m + ctx := context.Background() + data, err := m.Cache.Get(ctx, domain) + if err != nil { + return nil, err + } + + // private + priv, pub := pem.Decode(data) + if priv == nil || !strings.Contains(priv.Type, "PRIVATE") { + return nil, errors.New("acme/autocert: no private key found in cache") + } + privKey, err := parsePrivateKey(priv.Bytes) + if err != nil { + return nil, err + } + + // public + var pubDER [][]byte + for len(pub) > 0 { + var b *pem.Block + b, pub = pem.Decode(pub) + if b == nil { + break + } + pubDER = append(pubDER, b.Bytes) + } + if len(pub) > 0 { + return nil, errors.New("acme/autocert: invalid public key") + } + + // verify and create TLS cert + leaf, err := validCert(domain, pubDER, privKey) + if err != nil { + return nil, err + } + tlscert := &tls.Certificate{ + Certificate: pubDER, + PrivateKey: privKey, + Leaf: leaf, + } + return tlscert, nil +} + +func (m *Manager) cachePut(domain string, tlscert *tls.Certificate) error { + if m.Cache == nil { + return nil + } + + // contains PEM-encoded data + var buf bytes.Buffer + + // private + switch key := tlscert.PrivateKey.(type) { + case *ecdsa.PrivateKey: + if err := encodeECDSAKey(&buf, key); err != nil { + return err + } + case *rsa.PrivateKey: + b := x509.MarshalPKCS1PrivateKey(key) + pb := &pem.Block{Type: "RSA PRIVATE KEY", Bytes: b} + if err := pem.Encode(&buf, pb); err != nil { + return err + } + default: + return errors.New("acme/autocert: unknown private key type") + } + + // public + for _, b := range tlscert.Certificate { + pb := &pem.Block{Type: "CERTIFICATE", Bytes: b} + if err := pem.Encode(&buf, pb); err != nil { + return err + } + } + + // TODO: might want to define a cache timeout on m + ctx := context.Background() + return m.Cache.Put(ctx, domain, buf.Bytes()) +} + +func encodeECDSAKey(w io.Writer, key *ecdsa.PrivateKey) error { + b, err := x509.MarshalECPrivateKey(key) + if err != nil { + return err + } + pb := &pem.Block{Type: "EC PRIVATE KEY", Bytes: b} + return pem.Encode(w, pb) +} + +// createCert starts the domain ownership verification and returns a certificate +// for that domain upon success. +// +// If the domain is already being verified, it waits for the existing verification to complete. +// Either way, createCert blocks for the duration of the whole process. +func (m *Manager) createCert(ctx context.Context, domain string) (*tls.Certificate, error) { + // TODO: maybe rewrite this whole piece using sync.Once + state, err := m.certState(domain) + if err != nil { + return nil, err + } + // state may exist if another goroutine is already working on it + // in which case just wait for it to finish + if !state.locked { + state.RLock() + defer state.RUnlock() + return state.tlscert() + } + + // We are the first; state is locked. + // Unblock the readers when domain ownership is verified + // and the we got the cert or the process failed. + defer state.Unlock() + state.locked = false + + der, leaf, err := m.authorizedCert(ctx, state.key, domain) + if err != nil { + return nil, err + } + state.cert = der + state.leaf = leaf + go m.renew(domain, state.key, state.leaf.NotAfter) + return state.tlscert() +} + +// certState returns a new or existing certState. +// If a new certState is returned, state.exist is false and the state is locked. +// The returned error is non-nil only in the case where a new state could not be created. +func (m *Manager) certState(domain string) (*certState, error) { + m.stateMu.Lock() + defer m.stateMu.Unlock() + if m.state == nil { + m.state = make(map[string]*certState) + } + // existing state + if state, ok := m.state[domain]; ok { + return state, nil + } + // new locked state + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, err + } + state := &certState{ + key: key, + locked: true, + } + state.Lock() // will be unlocked by m.certState caller + m.state[domain] = state + return state, nil +} + +// authorizedCert starts domain ownership verification process and requests a new cert upon success. +// The key argument is the certificate private key. +func (m *Manager) authorizedCert(ctx context.Context, key crypto.Signer, domain string) (der [][]byte, leaf *x509.Certificate, err error) { + // TODO: make m.verify retry or retry m.verify calls here + if err := m.verify(ctx, domain); err != nil { + return nil, nil, err + } + client, err := m.acmeClient(ctx) + if err != nil { + return nil, nil, err + } + csr, err := certRequest(key, domain) + if err != nil { + return nil, nil, err + } + der, _, err = client.CreateCert(ctx, csr, 0, true) + if err != nil { + return nil, nil, err + } + leaf, err = validCert(domain, der, key) + if err != nil { + return nil, nil, err + } + return der, leaf, nil +} + +// verify starts a new identifier (domain) authorization flow. +// It prepares a challenge response and then blocks until the authorization +// is marked as "completed" by the CA (either succeeded or failed). +// +// verify returns nil iff the verification was successful. +func (m *Manager) verify(ctx context.Context, domain string) error { + client, err := m.acmeClient(ctx) + if err != nil { + return err + } + + // start domain authorization and get the challenge + authz, err := client.Authorize(ctx, domain) + if err != nil { + return err + } + // maybe don't need to at all + if authz.Status == acme.StatusValid { + return nil + } + + // pick a challenge: prefer tls-sni-02 over tls-sni-01 + // TODO: consider authz.Combinations + var chal *acme.Challenge + for _, c := range authz.Challenges { + if c.Type == "tls-sni-02" { + chal = c + break + } + if c.Type == "tls-sni-01" { + chal = c + } + } + if chal == nil { + return errors.New("acme/autocert: no supported challenge type found") + } + + // create a token cert for the challenge response + var ( + cert tls.Certificate + name string + ) + switch chal.Type { + case "tls-sni-01": + cert, name, err = client.TLSSNI01ChallengeCert(chal.Token) + case "tls-sni-02": + cert, name, err = client.TLSSNI02ChallengeCert(chal.Token) + default: + err = fmt.Errorf("acme/autocert: unknown challenge type %q", chal.Type) + } + if err != nil { + return err + } + m.putTokenCert(name, &cert) + defer func() { + // verification has ended at this point + // don't need token cert anymore + go m.deleteTokenCert(name) + }() + + // ready to fulfill the challenge + if _, err := client.Accept(ctx, chal); err != nil { + return err + } + // wait for the CA to validate + _, err = client.WaitAuthorization(ctx, authz.URI) + return err +} + +// putTokenCert stores the cert under the named key in both m.tokenCert map +// and m.Cache. +func (m *Manager) putTokenCert(name string, cert *tls.Certificate) { + m.tokenCertMu.Lock() + defer m.tokenCertMu.Unlock() + if m.tokenCert == nil { + m.tokenCert = make(map[string]*tls.Certificate) + } + m.tokenCert[name] = cert + m.cachePut(name, cert) +} + +// deleteTokenCert removes the token certificate for the specified domain name +// from both m.tokenCert map and m.Cache. +func (m *Manager) deleteTokenCert(name string) { + m.tokenCertMu.Lock() + defer m.tokenCertMu.Unlock() + delete(m.tokenCert, name) + if m.Cache != nil { + m.Cache.Delete(context.Background(), name) + } +} + +// renew starts a cert renewal timer loop, one per domain. +// +// The loop is scheduled in two cases: +// - a cert was fetched from cache for the first time (wasn't in m.state) +// - a new cert was created by m.createCert +// +// The key argument is a certificate private key. +// The exp argument is the cert expiration time (NotAfter). +func (m *Manager) renew(domain string, key crypto.Signer, exp time.Time) { + m.renewalMu.Lock() + defer m.renewalMu.Unlock() + if m.renewal[domain] != nil { + // another goroutine is already on it + return + } + if m.renewal == nil { + m.renewal = make(map[string]*domainRenewal) + } + dr := &domainRenewal{m: m, domain: domain, key: key} + m.renewal[domain] = dr + dr.start(exp) +} + +// stopRenew stops all currently running cert renewal timers. +// The timers are not restarted during the lifetime of the Manager. +func (m *Manager) stopRenew() { + m.renewalMu.Lock() + defer m.renewalMu.Unlock() + for name, dr := range m.renewal { + delete(m.renewal, name) + dr.stop() + } +} + +func (m *Manager) accountKey(ctx context.Context) (crypto.Signer, error) { + const keyName = "acme_account.key" + + genKey := func() (*ecdsa.PrivateKey, error) { + return ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + } + + if m.Cache == nil { + return genKey() + } + + data, err := m.Cache.Get(ctx, keyName) + if err == ErrCacheMiss { + key, err := genKey() + if err != nil { + return nil, err + } + var buf bytes.Buffer + if err := encodeECDSAKey(&buf, key); err != nil { + return nil, err + } + if err := m.Cache.Put(ctx, keyName, buf.Bytes()); err != nil { + return nil, err + } + return key, nil + } + if err != nil { + return nil, err + } + + priv, _ := pem.Decode(data) + if priv == nil || !strings.Contains(priv.Type, "PRIVATE") { + return nil, errors.New("acme/autocert: invalid account key found in cache") + } + return parsePrivateKey(priv.Bytes) +} + +func (m *Manager) acmeClient(ctx context.Context) (*acme.Client, error) { + m.clientMu.Lock() + defer m.clientMu.Unlock() + if m.client != nil { + return m.client, nil + } + + client := m.Client + if client == nil { + client = &acme.Client{DirectoryURL: acme.LetsEncryptURL} + } + if client.Key == nil { + var err error + client.Key, err = m.accountKey(ctx) + if err != nil { + return nil, err + } + } + var contact []string + if m.Email != "" { + contact = []string{"mailto:" + m.Email} + } + a := &acme.Account{Contact: contact} + _, err := client.Register(ctx, a, m.Prompt) + if ae, ok := err.(*acme.Error); err == nil || ok && ae.StatusCode == http.StatusConflict { + // conflict indicates the key is already registered + m.client = client + err = nil + } + return m.client, err +} + +func (m *Manager) hostPolicy() HostPolicy { + if m.HostPolicy != nil { + return m.HostPolicy + } + return defaultHostPolicy +} + +func (m *Manager) renewBefore() time.Duration { + if m.RenewBefore > maxRandRenew { + return m.RenewBefore + } + return 7 * 24 * time.Hour // 1 week +} + +// certState is ready when its mutex is unlocked for reading. +type certState struct { + sync.RWMutex + locked bool // locked for read/write + key crypto.Signer // private key for cert + cert [][]byte // DER encoding + leaf *x509.Certificate // parsed cert[0]; always non-nil if cert != nil +} + +// tlscert creates a tls.Certificate from s.key and s.cert. +// Callers should wrap it in s.RLock() and s.RUnlock(). +func (s *certState) tlscert() (*tls.Certificate, error) { + if s.key == nil { + return nil, errors.New("acme/autocert: missing signer") + } + if len(s.cert) == 0 { + return nil, errors.New("acme/autocert: missing certificate") + } + return &tls.Certificate{ + PrivateKey: s.key, + Certificate: s.cert, + Leaf: s.leaf, + }, nil +} + +// certRequest creates a certificate request for the given common name cn +// and optional SANs. +func certRequest(key crypto.Signer, cn string, san ...string) ([]byte, error) { + req := &x509.CertificateRequest{ + Subject: pkix.Name{CommonName: cn}, + DNSNames: san, + } + return x509.CreateCertificateRequest(rand.Reader, req, key) +} + +// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates +// PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys. +// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. +// +// Inspired by parsePrivateKey in crypto/tls/tls.go. +func parsePrivateKey(der []byte) (crypto.Signer, error) { + if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { + return key, nil + } + if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { + switch key := key.(type) { + case *rsa.PrivateKey: + return key, nil + case *ecdsa.PrivateKey: + return key, nil + default: + return nil, errors.New("acme/autocert: unknown private key type in PKCS#8 wrapping") + } + } + if key, err := x509.ParseECPrivateKey(der); err == nil { + return key, nil + } + + return nil, errors.New("acme/autocert: failed to parse private key") +} + +// validCert parses a cert chain provided as der argument and verifies the leaf, der[0], +// corresponds to the private key, as well as the domain match and expiration dates. +// It doesn't do any revocation checking. +// +// The returned value is the verified leaf cert. +func validCert(domain string, der [][]byte, key crypto.Signer) (leaf *x509.Certificate, err error) { + // parse public part(s) + var n int + for _, b := range der { + n += len(b) + } + pub := make([]byte, n) + n = 0 + for _, b := range der { + n += copy(pub[n:], b) + } + x509Cert, err := x509.ParseCertificates(pub) + if len(x509Cert) == 0 { + return nil, errors.New("acme/autocert: no public key found") + } + // verify the leaf is not expired and matches the domain name + leaf = x509Cert[0] + now := timeNow() + if now.Before(leaf.NotBefore) { + return nil, errors.New("acme/autocert: certificate is not valid yet") + } + if now.After(leaf.NotAfter) { + return nil, errors.New("acme/autocert: expired certificate") + } + if err := leaf.VerifyHostname(domain); err != nil { + return nil, err + } + // ensure the leaf corresponds to the private key + switch pub := leaf.PublicKey.(type) { + case *rsa.PublicKey: + prv, ok := key.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("acme/autocert: private key type does not match public key type") + } + if pub.N.Cmp(prv.N) != 0 { + return nil, errors.New("acme/autocert: private key does not match public key") + } + case *ecdsa.PublicKey: + prv, ok := key.(*ecdsa.PrivateKey) + if !ok { + return nil, errors.New("acme/autocert: private key type does not match public key type") + } + if pub.X.Cmp(prv.X) != 0 || pub.Y.Cmp(prv.Y) != 0 { + return nil, errors.New("acme/autocert: private key does not match public key") + } + default: + return nil, errors.New("acme/autocert: unknown public key algorithm") + } + return leaf, nil +} + +func retryAfter(v string) time.Duration { + if i, err := strconv.Atoi(v); err == nil { + return time.Duration(i) * time.Second + } + if t, err := http.ParseTime(v); err == nil { + return t.Sub(timeNow()) + } + return time.Second +} + +type lockedMathRand struct { + sync.Mutex + rnd *mathrand.Rand +} + +func (r *lockedMathRand) int63n(max int64) int64 { + r.Lock() + n := r.rnd.Int63n(max) + r.Unlock() + return n +} + +// for easier testing +var timeNow = time.Now diff --git a/cmd/ponzu/vendor/golang.org/x/crypto/autocert/autocert_test.go b/cmd/ponzu/vendor/golang.org/x/crypto/autocert/autocert_test.go new file mode 100644 index 0000000..3a9daa1 --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/crypto/autocert/autocert_test.go @@ -0,0 +1,390 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package autocert + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/rsa" + "crypto/tls" + "crypto/x509" + "crypto/x509/pkix" + "encoding/base64" + "encoding/json" + "fmt" + "html/template" + "io" + "math/big" + "net/http" + "net/http/httptest" + "reflect" + "testing" + "time" + + "golang.org/x/crypto/acme" + "golang.org/x/net/context" +) + +var discoTmpl = template.Must(template.New("disco").Parse(`{ + "new-reg": "{{.}}/new-reg", + "new-authz": "{{.}}/new-authz", + "new-cert": "{{.}}/new-cert" +}`)) + +var authzTmpl = template.Must(template.New("authz").Parse(`{ + "status": "pending", + "challenges": [ + { + "uri": "{{.}}/challenge/1", + "type": "tls-sni-01", + "token": "token-01" + }, + { + "uri": "{{.}}/challenge/2", + "type": "tls-sni-02", + "token": "token-02" + } + ] +}`)) + +type memCache map[string][]byte + +func (m memCache) Get(ctx context.Context, key string) ([]byte, error) { + v, ok := m[key] + if !ok { + return nil, ErrCacheMiss + } + return v, nil +} + +func (m memCache) Put(ctx context.Context, key string, data []byte) error { + m[key] = data + return nil +} + +func (m memCache) Delete(ctx context.Context, key string) error { + delete(m, key) + return nil +} + +func dummyCert(pub interface{}, san ...string) ([]byte, error) { + return dateDummyCert(pub, time.Now(), time.Now().Add(90*24*time.Hour), san...) +} + +func dateDummyCert(pub interface{}, start, end time.Time, san ...string) ([]byte, error) { + // use EC key to run faster on 386 + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + return nil, err + } + t := &x509.Certificate{ + SerialNumber: big.NewInt(1), + NotBefore: start, + NotAfter: end, + BasicConstraintsValid: true, + KeyUsage: x509.KeyUsageKeyEncipherment, + DNSNames: san, + } + if pub == nil { + pub = &key.PublicKey + } + return x509.CreateCertificate(rand.Reader, t, t, pub, key) +} + +func decodePayload(v interface{}, r io.Reader) error { + var req struct{ Payload string } + if err := json.NewDecoder(r).Decode(&req); err != nil { + return err + } + payload, err := base64.RawURLEncoding.DecodeString(req.Payload) + if err != nil { + return err + } + return json.Unmarshal(payload, v) +} + +func TestGetCertificate(t *testing.T) { + const domain = "example.org" + man := &Manager{Prompt: AcceptTOS} + defer man.stopRenew() + + // echo token-02 | shasum -a 256 + // then divide result in 2 parts separated by dot + tokenCertName := "4e8eb87631187e9ff2153b56b13a4dec.13a35d002e485d60ff37354b32f665d9.token.acme.invalid" + verifyTokenCert := func() { + hello := &tls.ClientHelloInfo{ServerName: tokenCertName} + _, err := man.GetCertificate(hello) + if err != nil { + t.Errorf("verifyTokenCert: GetCertificate(%q): %v", tokenCertName, err) + return + } + } + + // ACME CA server stub + var ca *httptest.Server + ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("replay-nonce", "nonce") + if r.Method == "HEAD" { + // a nonce request + return + } + + switch r.URL.Path { + // discovery + case "/": + if err := discoTmpl.Execute(w, ca.URL); err != nil { + t.Fatalf("discoTmpl: %v", err) + } + // client key registration + case "/new-reg": + w.Write([]byte("{}")) + // domain authorization + case "/new-authz": + w.Header().Set("location", ca.URL+"/authz/1") + w.WriteHeader(http.StatusCreated) + if err := authzTmpl.Execute(w, ca.URL); err != nil { + t.Fatalf("authzTmpl: %v", err) + } + // accept tls-sni-02 challenge + case "/challenge/2": + verifyTokenCert() + w.Write([]byte("{}")) + // authorization status + case "/authz/1": + w.Write([]byte(`{"status": "valid"}`)) + // cert request + case "/new-cert": + var req struct { + CSR string `json:"csr"` + } + decodePayload(&req, r.Body) + b, _ := base64.RawURLEncoding.DecodeString(req.CSR) + csr, err := x509.ParseCertificateRequest(b) + if err != nil { + t.Fatalf("new-cert: CSR: %v", err) + } + der, err := dummyCert(csr.PublicKey, domain) + if err != nil { + t.Fatalf("new-cert: dummyCert: %v", err) + } + chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL) + w.Header().Set("link", chainUp) + w.WriteHeader(http.StatusCreated) + w.Write(der) + // CA chain cert + case "/ca-cert": + der, err := dummyCert(nil, "ca") + if err != nil { + t.Fatalf("ca-cert: dummyCert: %v", err) + } + w.Write(der) + default: + t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path) + } + })) + defer ca.Close() + + // use EC key to run faster on 386 + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + man.Client = &acme.Client{ + Key: key, + DirectoryURL: ca.URL, + } + + // simulate tls.Config.GetCertificate + var tlscert *tls.Certificate + done := make(chan struct{}) + go func() { + hello := &tls.ClientHelloInfo{ServerName: domain} + tlscert, err = man.GetCertificate(hello) + close(done) + }() + select { + case <-time.After(time.Minute): + t.Fatal("man.GetCertificate took too long to return") + case <-done: + } + if err != nil { + t.Fatalf("man.GetCertificate: %v", err) + } + + // verify the tlscert is the same we responded with from the CA stub + if len(tlscert.Certificate) == 0 { + t.Fatal("len(tlscert.Certificate) is 0") + } + cert, err := x509.ParseCertificate(tlscert.Certificate[0]) + if err != nil { + t.Fatalf("x509.ParseCertificate: %v", err) + } + if len(cert.DNSNames) == 0 || cert.DNSNames[0] != domain { + t.Errorf("cert.DNSNames = %v; want %q", cert.DNSNames, domain) + } + + // make sure token cert was removed + done = make(chan struct{}) + go func() { + for { + hello := &tls.ClientHelloInfo{ServerName: tokenCertName} + if _, err := man.GetCertificate(hello); err != nil { + break + } + time.Sleep(100 * time.Millisecond) + } + close(done) + }() + select { + case <-time.After(5 * time.Second): + t.Error("token cert was not removed") + case <-done: + } +} + +func TestAccountKeyCache(t *testing.T) { + cache := make(memCache) + m := Manager{Cache: cache} + ctx := context.Background() + k1, err := m.accountKey(ctx) + if err != nil { + t.Fatal(err) + } + k2, err := m.accountKey(ctx) + if err != nil { + t.Fatal(err) + } + if !reflect.DeepEqual(k1, k2) { + t.Errorf("account keys don't match: k1 = %#v; k2 = %#v", k1, k2) + } +} + +func TestCache(t *testing.T) { + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(1), + Subject: pkix.Name{CommonName: "example.org"}, + NotAfter: time.Now().Add(time.Hour), + } + pub, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &privKey.PublicKey, privKey) + if err != nil { + t.Fatal(err) + } + tlscert := &tls.Certificate{ + Certificate: [][]byte{pub}, + PrivateKey: privKey, + } + + cache := make(memCache) + man := &Manager{Cache: cache} + defer man.stopRenew() + if err := man.cachePut("example.org", tlscert); err != nil { + t.Fatalf("man.cachePut: %v", err) + } + res, err := man.cacheGet("example.org") + if err != nil { + t.Fatalf("man.cacheGet: %v", err) + } + if res == nil { + t.Fatal("res is nil") + } +} + +func TestHostWhitelist(t *testing.T) { + policy := HostWhitelist("example.com", "example.org", "*.example.net") + tt := []struct { + host string + allow bool + }{ + {"example.com", true}, + {"example.org", true}, + {"one.example.com", false}, + {"two.example.org", false}, + {"three.example.net", false}, + {"dummy", false}, + } + for i, test := range tt { + err := policy(nil, test.host) + if err != nil && test.allow { + t.Errorf("%d: policy(%q): %v; want nil", i, test.host, err) + } + if err == nil && !test.allow { + t.Errorf("%d: policy(%q): nil; want an error", i, test.host) + } + } +} + +func TestValidCert(t *testing.T) { + key1, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + key2, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + key3, err := rsa.GenerateKey(rand.Reader, 512) + if err != nil { + t.Fatal(err) + } + cert1, err := dummyCert(key1.Public(), "example.org") + if err != nil { + t.Fatal(err) + } + cert2, err := dummyCert(key2.Public(), "example.org") + if err != nil { + t.Fatal(err) + } + cert3, err := dummyCert(key3.Public(), "example.org") + if err != nil { + t.Fatal(err) + } + now := time.Now() + early, err := dateDummyCert(key1.Public(), now.Add(time.Hour), now.Add(2*time.Hour), "example.org") + if err != nil { + t.Fatal(err) + } + expired, err := dateDummyCert(key1.Public(), now.Add(-2*time.Hour), now.Add(-time.Hour), "example.org") + if err != nil { + t.Fatal(err) + } + + tt := []struct { + domain string + key crypto.Signer + cert [][]byte + ok bool + }{ + {"example.org", key1, [][]byte{cert1}, true}, + {"example.org", key3, [][]byte{cert3}, true}, + {"example.org", key1, [][]byte{cert1, cert2, cert3}, true}, + {"example.org", key1, [][]byte{cert1, {1}}, false}, + {"example.org", key1, [][]byte{{1}}, false}, + {"example.org", key1, [][]byte{cert2}, false}, + {"example.org", key2, [][]byte{cert1}, false}, + {"example.org", key1, [][]byte{cert3}, false}, + {"example.org", key3, [][]byte{cert1}, false}, + {"example.net", key1, [][]byte{cert1}, false}, + {"example.org", key1, [][]byte{early}, false}, + {"example.org", key1, [][]byte{expired}, false}, + } + for i, test := range tt { + leaf, err := validCert(test.domain, test.cert, test.key) + if err != nil && test.ok { + t.Errorf("%d: err = %v", i, err) + } + if err == nil && !test.ok { + t.Errorf("%d: err is nil", i) + } + if err == nil && test.ok && leaf == nil { + t.Errorf("%d: leaf is nil", i) + } + } +} diff --git a/cmd/ponzu/vendor/golang.org/x/crypto/autocert/cache.go b/cmd/ponzu/vendor/golang.org/x/crypto/autocert/cache.go new file mode 100644 index 0000000..1c67f6c --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/crypto/autocert/cache.go @@ -0,0 +1,130 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package autocert + +import ( + "errors" + "io/ioutil" + "os" + "path/filepath" + + "golang.org/x/net/context" +) + +// ErrCacheMiss is returned when a certificate is not found in cache. +var ErrCacheMiss = errors.New("acme/autocert: certificate cache miss") + +// Cache is used by Manager to store and retrieve previously obtained certificates +// as opaque data. +// +// The key argument of the methods refers to a domain name but need not be an FQDN. +// Cache implementations should not rely on the key naming pattern. +type Cache interface { + // Get returns a certificate data for the specified key. + // If there's no such key, Get returns ErrCacheMiss. + Get(ctx context.Context, key string) ([]byte, error) + + // Put stores the data in the cache under the specified key. + // Inderlying implementations may use any data storage format, + // as long as the reverse operation, Get, results in the original data. + Put(ctx context.Context, key string, data []byte) error + + // Delete removes a certificate data from the cache under the specified key. + // If there's no such key in the cache, Delete returns nil. + Delete(ctx context.Context, key string) error +} + +// DirCache implements Cache using a directory on the local filesystem. +// If the directory does not exist, it will be created with 0700 permissions. +type DirCache string + +// Get reads a certificate data from the specified file name. +func (d DirCache) Get(ctx context.Context, name string) ([]byte, error) { + name = filepath.Join(string(d), name) + var ( + data []byte + err error + done = make(chan struct{}) + ) + go func() { + data, err = ioutil.ReadFile(name) + close(done) + }() + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-done: + } + if os.IsNotExist(err) { + return nil, ErrCacheMiss + } + return data, err +} + +// Put writes the certificate data to the specified file name. +// The file will be created with 0600 permissions. +func (d DirCache) Put(ctx context.Context, name string, data []byte) error { + if err := os.MkdirAll(string(d), 0700); err != nil { + return err + } + + done := make(chan struct{}) + var err error + go func() { + defer close(done) + var tmp string + if tmp, err = d.writeTempFile(name, data); err != nil { + return + } + // prevent overwriting the file if the context was cancelled + if ctx.Err() != nil { + return // no need to set err + } + name = filepath.Join(string(d), name) + err = os.Rename(tmp, name) + }() + select { + case <-ctx.Done(): + return ctx.Err() + case <-done: + } + return err +} + +// Delete removes the specified file name. +func (d DirCache) Delete(ctx context.Context, name string) error { + name = filepath.Join(string(d), name) + var ( + err error + done = make(chan struct{}) + ) + go func() { + err = os.Remove(name) + close(done) + }() + select { + case <-ctx.Done(): + return ctx.Err() + case <-done: + } + if err != nil && !os.IsNotExist(err) { + return err + } + return nil +} + +// writeTempFile writes b to a temporary file, closes the file and returns its path. +func (d DirCache) writeTempFile(prefix string, b []byte) (string, error) { + // TempFile uses 0600 permissions + f, err := ioutil.TempFile(string(d), prefix) + if err != nil { + return "", err + } + if _, err := f.Write(b); err != nil { + f.Close() + return "", err + } + return f.Name(), f.Close() +} diff --git a/cmd/ponzu/vendor/golang.org/x/crypto/autocert/cache_test.go b/cmd/ponzu/vendor/golang.org/x/crypto/autocert/cache_test.go new file mode 100644 index 0000000..ad6d4a4 --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/crypto/autocert/cache_test.go @@ -0,0 +1,58 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package autocert + +import ( + "io/ioutil" + "os" + "path/filepath" + "reflect" + "testing" + + "golang.org/x/net/context" +) + +// make sure DirCache satisfies Cache interface +var _ Cache = DirCache("/") + +func TestDirCache(t *testing.T) { + dir, err := ioutil.TempDir("", "autocert") + if err != nil { + t.Fatal(err) + } + dir = filepath.Join(dir, "certs") // a nonexistent dir + cache := DirCache(dir) + ctx := context.Background() + + // test cache miss + if _, err := cache.Get(ctx, "nonexistent"); err != ErrCacheMiss { + t.Errorf("get: %v; want ErrCacheMiss", err) + } + + // test put/get + b1 := []byte{1} + if err := cache.Put(ctx, "dummy", b1); err != nil { + t.Fatalf("put: %v", err) + } + b2, err := cache.Get(ctx, "dummy") + if err != nil { + t.Fatalf("get: %v", err) + } + if !reflect.DeepEqual(b1, b2) { + t.Errorf("b1 = %v; want %v", b1, b2) + } + name := filepath.Join(dir, "dummy") + if _, err := os.Stat(name); err != nil { + t.Error(err) + } + + // test delete + if err := cache.Delete(ctx, "dummy"); err != nil { + t.Fatalf("delete: %v", err) + } + if _, err := cache.Get(ctx, "dummy"); err != ErrCacheMiss { + t.Errorf("get: %v; want ErrCacheMiss", err) + } +} diff --git a/cmd/ponzu/vendor/golang.org/x/crypto/autocert/renewal.go b/cmd/ponzu/vendor/golang.org/x/crypto/autocert/renewal.go new file mode 100644 index 0000000..1a5018c --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/crypto/autocert/renewal.go @@ -0,0 +1,125 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package autocert + +import ( + "crypto" + "sync" + "time" + + "golang.org/x/net/context" +) + +// maxRandRenew is a maximum deviation from Manager.RenewBefore. +const maxRandRenew = time.Hour + +// domainRenewal tracks the state used by the periodic timers +// renewing a single domain's cert. +type domainRenewal struct { + m *Manager + domain string + key crypto.Signer + + timerMu sync.Mutex + timer *time.Timer +} + +// start starts a cert renewal timer at the time +// defined by the certificate expiration time exp. +// +// If the timer is already started, calling start is a noop. +func (dr *domainRenewal) start(exp time.Time) { + dr.timerMu.Lock() + defer dr.timerMu.Unlock() + if dr.timer != nil { + return + } + dr.timer = time.AfterFunc(dr.next(exp), dr.renew) +} + +// stop stops the cert renewal timer. +// If the timer is already stopped, calling stop is a noop. +func (dr *domainRenewal) stop() { + dr.timerMu.Lock() + defer dr.timerMu.Unlock() + if dr.timer == nil { + return + } + dr.timer.Stop() + dr.timer = nil +} + +// renew is called periodically by a timer. +// The first renew call is kicked off by dr.start. +func (dr *domainRenewal) renew() { + dr.timerMu.Lock() + defer dr.timerMu.Unlock() + if dr.timer == nil { + return + } + + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Minute) + defer cancel() + // TODO: rotate dr.key at some point? + next, err := dr.do(ctx) + if err != nil { + next = maxRandRenew / 2 + next += time.Duration(pseudoRand.int63n(int64(next))) + } + dr.timer = time.AfterFunc(next, dr.renew) + testDidRenewLoop(next, err) +} + +// do is similar to Manager.createCert but it doesn't lock a Manager.state item. +// Instead, it requests a new certificate independently and, upon success, +// replaces dr.m.state item with a new one and updates cache for the given domain. +// +// It may return immediately if the expiration date of the currently cached cert +// is far enough in the future. +// +// The returned value is a time interval after which the renewal should occur again. +func (dr *domainRenewal) do(ctx context.Context) (time.Duration, error) { + // a race is likely unavoidable in a distributed environment + // but we try nonetheless + if tlscert, err := dr.m.cacheGet(dr.domain); err == nil { + next := dr.next(tlscert.Leaf.NotAfter) + if next > dr.m.renewBefore()+maxRandRenew { + return next, nil + } + } + + der, leaf, err := dr.m.authorizedCert(ctx, dr.key, dr.domain) + if err != nil { + return 0, err + } + state := &certState{ + key: dr.key, + cert: der, + leaf: leaf, + } + tlscert, err := state.tlscert() + if err != nil { + return 0, err + } + dr.m.cachePut(dr.domain, tlscert) + dr.m.stateMu.Lock() + defer dr.m.stateMu.Unlock() + // m.state is guaranteed to be non-nil at this point + dr.m.state[dr.domain] = state + return dr.next(leaf.NotAfter), nil +} + +func (dr *domainRenewal) next(expiry time.Time) time.Duration { + d := expiry.Sub(timeNow()) - dr.m.renewBefore() + // add a bit of randomness to renew deadline + n := pseudoRand.int63n(int64(maxRandRenew)) + d -= time.Duration(n) + if d < 0 { + return 0 + } + return d +} + +var testDidRenewLoop = func(next time.Duration, err error) {} diff --git a/cmd/ponzu/vendor/golang.org/x/crypto/autocert/renewal_test.go b/cmd/ponzu/vendor/golang.org/x/crypto/autocert/renewal_test.go new file mode 100644 index 0000000..d1ec52f --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/crypto/autocert/renewal_test.go @@ -0,0 +1,190 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package autocert + +import ( + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rand" + "crypto/tls" + "crypto/x509" + "encoding/base64" + "fmt" + "net/http" + "net/http/httptest" + "testing" + "time" + + "golang.org/x/crypto/acme" +) + +func TestRenewalNext(t *testing.T) { + now := time.Now() + timeNow = func() time.Time { return now } + defer func() { timeNow = time.Now }() + + man := &Manager{RenewBefore: 7 * 24 * time.Hour} + defer man.stopRenew() + tt := []struct { + expiry time.Time + min, max time.Duration + }{ + {now.Add(90 * 24 * time.Hour), 83*24*time.Hour - maxRandRenew, 83 * 24 * time.Hour}, + {now.Add(time.Hour), 0, 1}, + {now, 0, 1}, + {now.Add(-time.Hour), 0, 1}, + } + + dr := &domainRenewal{m: man} + for i, test := range tt { + next := dr.next(test.expiry) + if next < test.min || test.max < next { + t.Errorf("%d: next = %v; want between %v and %v", i, next, test.min, test.max) + } + } +} + +func TestRenewFromCache(t *testing.T) { + const domain = "example.org" + + // ACME CA server stub + var ca *httptest.Server + ca = httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("replay-nonce", "nonce") + if r.Method == "HEAD" { + // a nonce request + return + } + + switch r.URL.Path { + // discovery + case "/": + if err := discoTmpl.Execute(w, ca.URL); err != nil { + t.Fatalf("discoTmpl: %v", err) + } + // client key registration + case "/new-reg": + w.Write([]byte("{}")) + // domain authorization + case "/new-authz": + w.Header().Set("location", ca.URL+"/authz/1") + w.WriteHeader(http.StatusCreated) + w.Write([]byte(`{"status": "valid"}`)) + // cert request + case "/new-cert": + var req struct { + CSR string `json:"csr"` + } + decodePayload(&req, r.Body) + b, _ := base64.RawURLEncoding.DecodeString(req.CSR) + csr, err := x509.ParseCertificateRequest(b) + if err != nil { + t.Fatalf("new-cert: CSR: %v", err) + } + der, err := dummyCert(csr.PublicKey, domain) + if err != nil { + t.Fatalf("new-cert: dummyCert: %v", err) + } + chainUp := fmt.Sprintf("<%s/ca-cert>; rel=up", ca.URL) + w.Header().Set("link", chainUp) + w.WriteHeader(http.StatusCreated) + w.Write(der) + // CA chain cert + case "/ca-cert": + der, err := dummyCert(nil, "ca") + if err != nil { + t.Fatalf("ca-cert: dummyCert: %v", err) + } + w.Write(der) + default: + t.Errorf("unrecognized r.URL.Path: %s", r.URL.Path) + } + })) + defer ca.Close() + + // use EC key to run faster on 386 + key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) + if err != nil { + t.Fatal(err) + } + man := &Manager{ + Prompt: AcceptTOS, + Cache: make(memCache), + RenewBefore: 24 * time.Hour, + Client: &acme.Client{ + Key: key, + DirectoryURL: ca.URL, + }, + } + defer man.stopRenew() + + // cache an almost expired cert + now := time.Now() + cert, err := dateDummyCert(key.Public(), now.Add(-2*time.Hour), now.Add(time.Minute), domain) + if err != nil { + t.Fatal(err) + } + tlscert := &tls.Certificate{PrivateKey: key, Certificate: [][]byte{cert}} + if err := man.cachePut(domain, tlscert); err != nil { + t.Fatal(err) + } + + // veriy the renewal happened + defer func() { + testDidRenewLoop = func(next time.Duration, err error) {} + }() + done := make(chan struct{}) + testDidRenewLoop = func(next time.Duration, err error) { + defer close(done) + if err != nil { + t.Errorf("testDidRenewLoop: %v", err) + } + // Next should be about 90 days: + // dummyCert creates 90days expiry + account for man.RenewBefore. + // Previous expiration was within 1 min. + future := 88 * 24 * time.Hour + if next < future { + t.Errorf("testDidRenewLoop: next = %v; want >= %v", next, future) + } + + // ensure the new cert is cached + after := time.Now().Add(future) + tlscert, err := man.cacheGet(domain) + if err != nil { + t.Fatalf("man.cacheGet: %v", err) + } + if !tlscert.Leaf.NotAfter.After(after) { + t.Errorf("cache leaf.NotAfter = %v; want > %v", tlscert.Leaf.NotAfter, after) + } + + // verify the old cert is also replaced in memory + man.stateMu.Lock() + defer man.stateMu.Unlock() + s := man.state[domain] + if s == nil { + t.Fatalf("m.state[%q] is nil", domain) + } + tlscert, err = s.tlscert() + if err != nil { + t.Fatalf("s.tlscert: %v", err) + } + if !tlscert.Leaf.NotAfter.After(after) { + t.Errorf("state leaf.NotAfter = %v; want > %v", tlscert.Leaf.NotAfter, after) + } + } + + // trigger renew + hello := &tls.ClientHelloInfo{ServerName: domain} + if _, err := man.GetCertificate(hello); err != nil { + t.Fatal(err) + } + + // wait for renew loop + select { + case <-time.After(10 * time.Second): + t.Fatal("renew took too long to occur") + case <-done: + } +} diff --git a/cmd/ponzu/vendor/golang.org/x/net/context/context.go b/cmd/ponzu/vendor/golang.org/x/net/context/context.go new file mode 100644 index 0000000..134654c --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/net/context/context.go @@ -0,0 +1,156 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package context defines the Context type, which carries deadlines, +// cancelation signals, and other request-scoped values across API boundaries +// and between processes. +// +// Incoming requests to a server should create a Context, and outgoing calls to +// servers should accept a Context. The chain of function calls between must +// propagate the Context, optionally replacing it with a modified copy created +// using WithDeadline, WithTimeout, WithCancel, or WithValue. +// +// Programs that use Contexts should follow these rules to keep interfaces +// consistent across packages and enable static analysis tools to check context +// propagation: +// +// Do not store Contexts inside a struct type; instead, pass a Context +// explicitly to each function that needs it. The Context should be the first +// parameter, typically named ctx: +// +// func DoSomething(ctx context.Context, arg Arg) error { +// // ... use ctx ... +// } +// +// Do not pass a nil Context, even if a function permits it. Pass context.TODO +// if you are unsure about which Context to use. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +// +// The same Context may be passed to functions running in different goroutines; +// Contexts are safe for simultaneous use by multiple goroutines. +// +// See http://blog.golang.org/context for example code for a server that uses +// Contexts. +package context // import "golang.org/x/net/context" + +import "time" + +// A Context carries a deadline, a cancelation signal, and other values across +// API boundaries. +// +// Context's methods may be called by multiple goroutines simultaneously. +type Context interface { + // Deadline returns the time when work done on behalf of this context + // should be canceled. Deadline returns ok==false when no deadline is + // set. Successive calls to Deadline return the same results. + Deadline() (deadline time.Time, ok bool) + + // Done returns a channel that's closed when work done on behalf of this + // context should be canceled. Done may return nil if this context can + // never be canceled. Successive calls to Done return the same value. + // + // WithCancel arranges for Done to be closed when cancel is called; + // WithDeadline arranges for Done to be closed when the deadline + // expires; WithTimeout arranges for Done to be closed when the timeout + // elapses. + // + // Done is provided for use in select statements: + // + // // Stream generates values with DoSomething and sends them to out + // // until DoSomething returns an error or ctx.Done is closed. + // func Stream(ctx context.Context, out chan<- Value) error { + // for { + // v, err := DoSomething(ctx) + // if err != nil { + // return err + // } + // select { + // case <-ctx.Done(): + // return ctx.Err() + // case out <- v: + // } + // } + // } + // + // See http://blog.golang.org/pipelines for more examples of how to use + // a Done channel for cancelation. + Done() <-chan struct{} + + // Err returns a non-nil error value after Done is closed. Err returns + // Canceled if the context was canceled or DeadlineExceeded if the + // context's deadline passed. No other values for Err are defined. + // After Done is closed, successive calls to Err return the same value. + Err() error + + // Value returns the value associated with this context for key, or nil + // if no value is associated with key. Successive calls to Value with + // the same key returns the same result. + // + // Use context values only for request-scoped data that transits + // processes and API boundaries, not for passing optional parameters to + // functions. + // + // A key identifies a specific value in a Context. Functions that wish + // to store values in Context typically allocate a key in a global + // variable then use that key as the argument to context.WithValue and + // Context.Value. A key can be any type that supports equality; + // packages should define keys as an unexported type to avoid + // collisions. + // + // Packages that define a Context key should provide type-safe accessors + // for the values stores using that key: + // + // // Package user defines a User type that's stored in Contexts. + // package user + // + // import "golang.org/x/net/context" + // + // // User is the type of value stored in the Contexts. + // type User struct {...} + // + // // key is an unexported type for keys defined in this package. + // // This prevents collisions with keys defined in other packages. + // type key int + // + // // userKey is the key for user.User values in Contexts. It is + // // unexported; clients use user.NewContext and user.FromContext + // // instead of using this key directly. + // var userKey key = 0 + // + // // NewContext returns a new Context that carries value u. + // func NewContext(ctx context.Context, u *User) context.Context { + // return context.WithValue(ctx, userKey, u) + // } + // + // // FromContext returns the User value stored in ctx, if any. + // func FromContext(ctx context.Context) (*User, bool) { + // u, ok := ctx.Value(userKey).(*User) + // return u, ok + // } + Value(key interface{}) interface{} +} + +// Background returns a non-nil, empty Context. It is never canceled, has no +// values, and has no deadline. It is typically used by the main function, +// initialization, and tests, and as the top-level Context for incoming +// requests. +func Background() Context { + return background +} + +// TODO returns a non-nil, empty Context. Code should use context.TODO when +// it's unclear which Context to use or it is not yet available (because the +// surrounding function has not yet been extended to accept a Context +// parameter). TODO is recognized by static analysis tools that determine +// whether Contexts are propagated correctly in a program. +func TODO() Context { + return todo +} + +// A CancelFunc tells an operation to abandon its work. +// A CancelFunc does not wait for the work to stop. +// After the first call, subsequent calls to a CancelFunc do nothing. +type CancelFunc func() diff --git a/cmd/ponzu/vendor/golang.org/x/net/context/context_test.go b/cmd/ponzu/vendor/golang.org/x/net/context/context_test.go new file mode 100644 index 0000000..9554dcf --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/net/context/context_test.go @@ -0,0 +1,577 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.7 + +package context + +import ( + "fmt" + "math/rand" + "runtime" + "strings" + "sync" + "testing" + "time" +) + +// otherContext is a Context that's not one of the types defined in context.go. +// This lets us test code paths that differ based on the underlying type of the +// Context. +type otherContext struct { + Context +} + +func TestBackground(t *testing.T) { + c := Background() + if c == nil { + t.Fatalf("Background returned nil") + } + select { + case x := <-c.Done(): + t.Errorf("<-c.Done() == %v want nothing (it should block)", x) + default: + } + if got, want := fmt.Sprint(c), "context.Background"; got != want { + t.Errorf("Background().String() = %q want %q", got, want) + } +} + +func TestTODO(t *testing.T) { + c := TODO() + if c == nil { + t.Fatalf("TODO returned nil") + } + select { + case x := <-c.Done(): + t.Errorf("<-c.Done() == %v want nothing (it should block)", x) + default: + } + if got, want := fmt.Sprint(c), "context.TODO"; got != want { + t.Errorf("TODO().String() = %q want %q", got, want) + } +} + +func TestWithCancel(t *testing.T) { + c1, cancel := WithCancel(Background()) + + if got, want := fmt.Sprint(c1), "context.Background.WithCancel"; got != want { + t.Errorf("c1.String() = %q want %q", got, want) + } + + o := otherContext{c1} + c2, _ := WithCancel(o) + contexts := []Context{c1, o, c2} + + for i, c := range contexts { + if d := c.Done(); d == nil { + t.Errorf("c[%d].Done() == %v want non-nil", i, d) + } + if e := c.Err(); e != nil { + t.Errorf("c[%d].Err() == %v want nil", i, e) + } + + select { + case x := <-c.Done(): + t.Errorf("<-c.Done() == %v want nothing (it should block)", x) + default: + } + } + + cancel() + time.Sleep(100 * time.Millisecond) // let cancelation propagate + + for i, c := range contexts { + select { + case <-c.Done(): + default: + t.Errorf("<-c[%d].Done() blocked, but shouldn't have", i) + } + if e := c.Err(); e != Canceled { + t.Errorf("c[%d].Err() == %v want %v", i, e, Canceled) + } + } +} + +func TestParentFinishesChild(t *testing.T) { + // Context tree: + // parent -> cancelChild + // parent -> valueChild -> timerChild + parent, cancel := WithCancel(Background()) + cancelChild, stop := WithCancel(parent) + defer stop() + valueChild := WithValue(parent, "key", "value") + timerChild, stop := WithTimeout(valueChild, 10000*time.Hour) + defer stop() + + select { + case x := <-parent.Done(): + t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) + case x := <-cancelChild.Done(): + t.Errorf("<-cancelChild.Done() == %v want nothing (it should block)", x) + case x := <-timerChild.Done(): + t.Errorf("<-timerChild.Done() == %v want nothing (it should block)", x) + case x := <-valueChild.Done(): + t.Errorf("<-valueChild.Done() == %v want nothing (it should block)", x) + default: + } + + // The parent's children should contain the two cancelable children. + pc := parent.(*cancelCtx) + cc := cancelChild.(*cancelCtx) + tc := timerChild.(*timerCtx) + pc.mu.Lock() + if len(pc.children) != 2 || !pc.children[cc] || !pc.children[tc] { + t.Errorf("bad linkage: pc.children = %v, want %v and %v", + pc.children, cc, tc) + } + pc.mu.Unlock() + + if p, ok := parentCancelCtx(cc.Context); !ok || p != pc { + t.Errorf("bad linkage: parentCancelCtx(cancelChild.Context) = %v, %v want %v, true", p, ok, pc) + } + if p, ok := parentCancelCtx(tc.Context); !ok || p != pc { + t.Errorf("bad linkage: parentCancelCtx(timerChild.Context) = %v, %v want %v, true", p, ok, pc) + } + + cancel() + + pc.mu.Lock() + if len(pc.children) != 0 { + t.Errorf("pc.cancel didn't clear pc.children = %v", pc.children) + } + pc.mu.Unlock() + + // parent and children should all be finished. + check := func(ctx Context, name string) { + select { + case <-ctx.Done(): + default: + t.Errorf("<-%s.Done() blocked, but shouldn't have", name) + } + if e := ctx.Err(); e != Canceled { + t.Errorf("%s.Err() == %v want %v", name, e, Canceled) + } + } + check(parent, "parent") + check(cancelChild, "cancelChild") + check(valueChild, "valueChild") + check(timerChild, "timerChild") + + // WithCancel should return a canceled context on a canceled parent. + precanceledChild := WithValue(parent, "key", "value") + select { + case <-precanceledChild.Done(): + default: + t.Errorf("<-precanceledChild.Done() blocked, but shouldn't have") + } + if e := precanceledChild.Err(); e != Canceled { + t.Errorf("precanceledChild.Err() == %v want %v", e, Canceled) + } +} + +func TestChildFinishesFirst(t *testing.T) { + cancelable, stop := WithCancel(Background()) + defer stop() + for _, parent := range []Context{Background(), cancelable} { + child, cancel := WithCancel(parent) + + select { + case x := <-parent.Done(): + t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) + case x := <-child.Done(): + t.Errorf("<-child.Done() == %v want nothing (it should block)", x) + default: + } + + cc := child.(*cancelCtx) + pc, pcok := parent.(*cancelCtx) // pcok == false when parent == Background() + if p, ok := parentCancelCtx(cc.Context); ok != pcok || (ok && pc != p) { + t.Errorf("bad linkage: parentCancelCtx(cc.Context) = %v, %v want %v, %v", p, ok, pc, pcok) + } + + if pcok { + pc.mu.Lock() + if len(pc.children) != 1 || !pc.children[cc] { + t.Errorf("bad linkage: pc.children = %v, cc = %v", pc.children, cc) + } + pc.mu.Unlock() + } + + cancel() + + if pcok { + pc.mu.Lock() + if len(pc.children) != 0 { + t.Errorf("child's cancel didn't remove self from pc.children = %v", pc.children) + } + pc.mu.Unlock() + } + + // child should be finished. + select { + case <-child.Done(): + default: + t.Errorf("<-child.Done() blocked, but shouldn't have") + } + if e := child.Err(); e != Canceled { + t.Errorf("child.Err() == %v want %v", e, Canceled) + } + + // parent should not be finished. + select { + case x := <-parent.Done(): + t.Errorf("<-parent.Done() == %v want nothing (it should block)", x) + default: + } + if e := parent.Err(); e != nil { + t.Errorf("parent.Err() == %v want nil", e) + } + } +} + +func testDeadline(c Context, wait time.Duration, t *testing.T) { + select { + case <-time.After(wait): + t.Fatalf("context should have timed out") + case <-c.Done(): + } + if e := c.Err(); e != DeadlineExceeded { + t.Errorf("c.Err() == %v want %v", e, DeadlineExceeded) + } +} + +func TestDeadline(t *testing.T) { + c, _ := WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) + if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { + t.Errorf("c.String() = %q want prefix %q", got, prefix) + } + testDeadline(c, 200*time.Millisecond, t) + + c, _ = WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) + o := otherContext{c} + testDeadline(o, 200*time.Millisecond, t) + + c, _ = WithDeadline(Background(), time.Now().Add(100*time.Millisecond)) + o = otherContext{c} + c, _ = WithDeadline(o, time.Now().Add(300*time.Millisecond)) + testDeadline(c, 200*time.Millisecond, t) +} + +func TestTimeout(t *testing.T) { + c, _ := WithTimeout(Background(), 100*time.Millisecond) + if got, prefix := fmt.Sprint(c), "context.Background.WithDeadline("; !strings.HasPrefix(got, prefix) { + t.Errorf("c.String() = %q want prefix %q", got, prefix) + } + testDeadline(c, 200*time.Millisecond, t) + + c, _ = WithTimeout(Background(), 100*time.Millisecond) + o := otherContext{c} + testDeadline(o, 200*time.Millisecond, t) + + c, _ = WithTimeout(Background(), 100*time.Millisecond) + o = otherContext{c} + c, _ = WithTimeout(o, 300*time.Millisecond) + testDeadline(c, 200*time.Millisecond, t) +} + +func TestCanceledTimeout(t *testing.T) { + c, _ := WithTimeout(Background(), 200*time.Millisecond) + o := otherContext{c} + c, cancel := WithTimeout(o, 400*time.Millisecond) + cancel() + time.Sleep(100 * time.Millisecond) // let cancelation propagate + select { + case <-c.Done(): + default: + t.Errorf("<-c.Done() blocked, but shouldn't have") + } + if e := c.Err(); e != Canceled { + t.Errorf("c.Err() == %v want %v", e, Canceled) + } +} + +type key1 int +type key2 int + +var k1 = key1(1) +var k2 = key2(1) // same int as k1, different type +var k3 = key2(3) // same type as k2, different int + +func TestValues(t *testing.T) { + check := func(c Context, nm, v1, v2, v3 string) { + if v, ok := c.Value(k1).(string); ok == (len(v1) == 0) || v != v1 { + t.Errorf(`%s.Value(k1).(string) = %q, %t want %q, %t`, nm, v, ok, v1, len(v1) != 0) + } + if v, ok := c.Value(k2).(string); ok == (len(v2) == 0) || v != v2 { + t.Errorf(`%s.Value(k2).(string) = %q, %t want %q, %t`, nm, v, ok, v2, len(v2) != 0) + } + if v, ok := c.Value(k3).(string); ok == (len(v3) == 0) || v != v3 { + t.Errorf(`%s.Value(k3).(string) = %q, %t want %q, %t`, nm, v, ok, v3, len(v3) != 0) + } + } + + c0 := Background() + check(c0, "c0", "", "", "") + + c1 := WithValue(Background(), k1, "c1k1") + check(c1, "c1", "c1k1", "", "") + + if got, want := fmt.Sprint(c1), `context.Background.WithValue(1, "c1k1")`; got != want { + t.Errorf("c.String() = %q want %q", got, want) + } + + c2 := WithValue(c1, k2, "c2k2") + check(c2, "c2", "c1k1", "c2k2", "") + + c3 := WithValue(c2, k3, "c3k3") + check(c3, "c2", "c1k1", "c2k2", "c3k3") + + c4 := WithValue(c3, k1, nil) + check(c4, "c4", "", "c2k2", "c3k3") + + o0 := otherContext{Background()} + check(o0, "o0", "", "", "") + + o1 := otherContext{WithValue(Background(), k1, "c1k1")} + check(o1, "o1", "c1k1", "", "") + + o2 := WithValue(o1, k2, "o2k2") + check(o2, "o2", "c1k1", "o2k2", "") + + o3 := otherContext{c4} + check(o3, "o3", "", "c2k2", "c3k3") + + o4 := WithValue(o3, k3, nil) + check(o4, "o4", "", "c2k2", "") +} + +func TestAllocs(t *testing.T) { + bg := Background() + for _, test := range []struct { + desc string + f func() + limit float64 + gccgoLimit float64 + }{ + { + desc: "Background()", + f: func() { Background() }, + limit: 0, + gccgoLimit: 0, + }, + { + desc: fmt.Sprintf("WithValue(bg, %v, nil)", k1), + f: func() { + c := WithValue(bg, k1, nil) + c.Value(k1) + }, + limit: 3, + gccgoLimit: 3, + }, + { + desc: "WithTimeout(bg, 15*time.Millisecond)", + f: func() { + c, _ := WithTimeout(bg, 15*time.Millisecond) + <-c.Done() + }, + limit: 8, + gccgoLimit: 16, + }, + { + desc: "WithCancel(bg)", + f: func() { + c, cancel := WithCancel(bg) + cancel() + <-c.Done() + }, + limit: 5, + gccgoLimit: 8, + }, + { + desc: "WithTimeout(bg, 100*time.Millisecond)", + f: func() { + c, cancel := WithTimeout(bg, 100*time.Millisecond) + cancel() + <-c.Done() + }, + limit: 8, + gccgoLimit: 25, + }, + } { + limit := test.limit + if runtime.Compiler == "gccgo" { + // gccgo does not yet do escape analysis. + // TODO(iant): Remove this when gccgo does do escape analysis. + limit = test.gccgoLimit + } + if n := testing.AllocsPerRun(100, test.f); n > limit { + t.Errorf("%s allocs = %f want %d", test.desc, n, int(limit)) + } + } +} + +func TestSimultaneousCancels(t *testing.T) { + root, cancel := WithCancel(Background()) + m := map[Context]CancelFunc{root: cancel} + q := []Context{root} + // Create a tree of contexts. + for len(q) != 0 && len(m) < 100 { + parent := q[0] + q = q[1:] + for i := 0; i < 4; i++ { + ctx, cancel := WithCancel(parent) + m[ctx] = cancel + q = append(q, ctx) + } + } + // Start all the cancels in a random order. + var wg sync.WaitGroup + wg.Add(len(m)) + for _, cancel := range m { + go func(cancel CancelFunc) { + cancel() + wg.Done() + }(cancel) + } + // Wait on all the contexts in a random order. + for ctx := range m { + select { + case <-ctx.Done(): + case <-time.After(1 * time.Second): + buf := make([]byte, 10<<10) + n := runtime.Stack(buf, true) + t.Fatalf("timed out waiting for <-ctx.Done(); stacks:\n%s", buf[:n]) + } + } + // Wait for all the cancel functions to return. + done := make(chan struct{}) + go func() { + wg.Wait() + close(done) + }() + select { + case <-done: + case <-time.After(1 * time.Second): + buf := make([]byte, 10<<10) + n := runtime.Stack(buf, true) + t.Fatalf("timed out waiting for cancel functions; stacks:\n%s", buf[:n]) + } +} + +func TestInterlockedCancels(t *testing.T) { + parent, cancelParent := WithCancel(Background()) + child, cancelChild := WithCancel(parent) + go func() { + parent.Done() + cancelChild() + }() + cancelParent() + select { + case <-child.Done(): + case <-time.After(1 * time.Second): + buf := make([]byte, 10<<10) + n := runtime.Stack(buf, true) + t.Fatalf("timed out waiting for child.Done(); stacks:\n%s", buf[:n]) + } +} + +func TestLayersCancel(t *testing.T) { + testLayers(t, time.Now().UnixNano(), false) +} + +func TestLayersTimeout(t *testing.T) { + testLayers(t, time.Now().UnixNano(), true) +} + +func testLayers(t *testing.T, seed int64, testTimeout bool) { + rand.Seed(seed) + errorf := func(format string, a ...interface{}) { + t.Errorf(fmt.Sprintf("seed=%d: %s", seed, format), a...) + } + const ( + timeout = 200 * time.Millisecond + minLayers = 30 + ) + type value int + var ( + vals []*value + cancels []CancelFunc + numTimers int + ctx = Background() + ) + for i := 0; i < minLayers || numTimers == 0 || len(cancels) == 0 || len(vals) == 0; i++ { + switch rand.Intn(3) { + case 0: + v := new(value) + ctx = WithValue(ctx, v, v) + vals = append(vals, v) + case 1: + var cancel CancelFunc + ctx, cancel = WithCancel(ctx) + cancels = append(cancels, cancel) + case 2: + var cancel CancelFunc + ctx, cancel = WithTimeout(ctx, timeout) + cancels = append(cancels, cancel) + numTimers++ + } + } + checkValues := func(when string) { + for _, key := range vals { + if val := ctx.Value(key).(*value); key != val { + errorf("%s: ctx.Value(%p) = %p want %p", when, key, val, key) + } + } + } + select { + case <-ctx.Done(): + errorf("ctx should not be canceled yet") + default: + } + if s, prefix := fmt.Sprint(ctx), "context.Background."; !strings.HasPrefix(s, prefix) { + t.Errorf("ctx.String() = %q want prefix %q", s, prefix) + } + t.Log(ctx) + checkValues("before cancel") + if testTimeout { + select { + case <-ctx.Done(): + case <-time.After(timeout + 100*time.Millisecond): + errorf("ctx should have timed out") + } + checkValues("after timeout") + } else { + cancel := cancels[rand.Intn(len(cancels))] + cancel() + select { + case <-ctx.Done(): + default: + errorf("ctx should be canceled") + } + checkValues("after cancel") + } +} + +func TestCancelRemoves(t *testing.T) { + checkChildren := func(when string, ctx Context, want int) { + if got := len(ctx.(*cancelCtx).children); got != want { + t.Errorf("%s: context has %d children, want %d", when, got, want) + } + } + + ctx, _ := WithCancel(Background()) + checkChildren("after creation", ctx, 0) + _, cancel := WithCancel(ctx) + checkChildren("with WithCancel child ", ctx, 1) + cancel() + checkChildren("after cancelling WithCancel child", ctx, 0) + + ctx, _ = WithCancel(Background()) + checkChildren("after creation", ctx, 0) + _, cancel = WithTimeout(ctx, 60*time.Minute) + checkChildren("with WithTimeout child ", ctx, 1) + cancel() + checkChildren("after cancelling WithTimeout child", ctx, 0) +} diff --git a/cmd/ponzu/vendor/golang.org/x/net/context/ctxhttp/ctxhttp.go b/cmd/ponzu/vendor/golang.org/x/net/context/ctxhttp/ctxhttp.go new file mode 100644 index 0000000..606cf1f --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/net/context/ctxhttp/ctxhttp.go @@ -0,0 +1,74 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.7 + +// Package ctxhttp provides helper functions for performing context-aware HTTP requests. +package ctxhttp // import "golang.org/x/net/context/ctxhttp" + +import ( + "io" + "net/http" + "net/url" + "strings" + + "golang.org/x/net/context" +) + +// Do sends an HTTP request with the provided http.Client and returns +// an HTTP response. +// +// If the client is nil, http.DefaultClient is used. +// +// The provided ctx must be non-nil. If it is canceled or times out, +// ctx.Err() will be returned. +func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { + if client == nil { + client = http.DefaultClient + } + resp, err := client.Do(req.WithContext(ctx)) + // If we got an error, and the context has been canceled, + // the context's error is probably more useful. + if err != nil { + select { + case <-ctx.Done(): + err = ctx.Err() + default: + } + } + return resp, err +} + +// Get issues a GET request via the Do function. +func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return Do(ctx, client, req) +} + +// Head issues a HEAD request via the Do function. +func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) { + req, err := http.NewRequest("HEAD", url, nil) + if err != nil { + return nil, err + } + return Do(ctx, client, req) +} + +// Post issues a POST request via the Do function. +func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequest("POST", url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", bodyType) + return Do(ctx, client, req) +} + +// PostForm issues a POST request via the Do function. +func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) { + return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) +} diff --git a/cmd/ponzu/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_17_test.go b/cmd/ponzu/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_17_test.go new file mode 100644 index 0000000..9f0f90f --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_17_test.go @@ -0,0 +1,28 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9,go1.7 + +package ctxhttp + +import ( + "io" + "net/http" + "net/http/httptest" + "testing" + + "context" +) + +func TestGo17Context(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + io.WriteString(w, "ok") + })) + ctx := context.Background() + resp, err := Get(ctx, http.DefaultClient, ts.URL) + if resp == nil || err != nil { + t.Fatalf("error received from client: %v %v", err, resp) + } + resp.Body.Close() +} diff --git a/cmd/ponzu/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17.go b/cmd/ponzu/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17.go new file mode 100644 index 0000000..926870c --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17.go @@ -0,0 +1,147 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.7 + +package ctxhttp // import "golang.org/x/net/context/ctxhttp" + +import ( + "io" + "net/http" + "net/url" + "strings" + + "golang.org/x/net/context" +) + +func nop() {} + +var ( + testHookContextDoneBeforeHeaders = nop + testHookDoReturned = nop + testHookDidBodyClose = nop +) + +// Do sends an HTTP request with the provided http.Client and returns an HTTP response. +// If the client is nil, http.DefaultClient is used. +// If the context is canceled or times out, ctx.Err() will be returned. +func Do(ctx context.Context, client *http.Client, req *http.Request) (*http.Response, error) { + if client == nil { + client = http.DefaultClient + } + + // TODO(djd): Respect any existing value of req.Cancel. + cancel := make(chan struct{}) + req.Cancel = cancel + + type responseAndError struct { + resp *http.Response + err error + } + result := make(chan responseAndError, 1) + + // Make local copies of test hooks closed over by goroutines below. + // Prevents data races in tests. + testHookDoReturned := testHookDoReturned + testHookDidBodyClose := testHookDidBodyClose + + go func() { + resp, err := client.Do(req) + testHookDoReturned() + result <- responseAndError{resp, err} + }() + + var resp *http.Response + + select { + case <-ctx.Done(): + testHookContextDoneBeforeHeaders() + close(cancel) + // Clean up after the goroutine calling client.Do: + go func() { + if r := <-result; r.resp != nil { + testHookDidBodyClose() + r.resp.Body.Close() + } + }() + return nil, ctx.Err() + case r := <-result: + var err error + resp, err = r.resp, r.err + if err != nil { + return resp, err + } + } + + c := make(chan struct{}) + go func() { + select { + case <-ctx.Done(): + close(cancel) + case <-c: + // The response's Body is closed. + } + }() + resp.Body = ¬ifyingReader{resp.Body, c} + + return resp, nil +} + +// Get issues a GET request via the Do function. +func Get(ctx context.Context, client *http.Client, url string) (*http.Response, error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + return Do(ctx, client, req) +} + +// Head issues a HEAD request via the Do function. +func Head(ctx context.Context, client *http.Client, url string) (*http.Response, error) { + req, err := http.NewRequest("HEAD", url, nil) + if err != nil { + return nil, err + } + return Do(ctx, client, req) +} + +// Post issues a POST request via the Do function. +func Post(ctx context.Context, client *http.Client, url string, bodyType string, body io.Reader) (*http.Response, error) { + req, err := http.NewRequest("POST", url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", bodyType) + return Do(ctx, client, req) +} + +// PostForm issues a POST request via the Do function. +func PostForm(ctx context.Context, client *http.Client, url string, data url.Values) (*http.Response, error) { + return Post(ctx, client, url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) +} + +// notifyingReader is an io.ReadCloser that closes the notify channel after +// Close is called or a Read fails on the underlying ReadCloser. +type notifyingReader struct { + io.ReadCloser + notify chan<- struct{} +} + +func (r *notifyingReader) Read(p []byte) (int, error) { + n, err := r.ReadCloser.Read(p) + if err != nil && r.notify != nil { + close(r.notify) + r.notify = nil + } + return n, err +} + +func (r *notifyingReader) Close() error { + err := r.ReadCloser.Close() + if r.notify != nil { + close(r.notify) + r.notify = nil + } + return err +} diff --git a/cmd/ponzu/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17_test.go b/cmd/ponzu/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17_test.go new file mode 100644 index 0000000..9159cf0 --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_pre17_test.go @@ -0,0 +1,79 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9,!go1.7 + +package ctxhttp + +import ( + "net" + "net/http" + "net/http/httptest" + "sync" + "testing" + "time" + + "golang.org/x/net/context" +) + +// golang.org/issue/14065 +func TestClosesResponseBodyOnCancel(t *testing.T) { + defer func() { testHookContextDoneBeforeHeaders = nop }() + defer func() { testHookDoReturned = nop }() + defer func() { testHookDidBodyClose = nop }() + + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})) + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + + // closed when Do enters select case <-ctx.Done() + enteredDonePath := make(chan struct{}) + + testHookContextDoneBeforeHeaders = func() { + close(enteredDonePath) + } + + testHookDoReturned = func() { + // We now have the result (the Flush'd headers) at least, + // so we can cancel the request. + cancel() + + // But block the client.Do goroutine from sending + // until Do enters into the <-ctx.Done() path, since + // otherwise if both channels are readable, select + // picks a random one. + <-enteredDonePath + } + + sawBodyClose := make(chan struct{}) + testHookDidBodyClose = func() { close(sawBodyClose) } + + tr := &http.Transport{} + defer tr.CloseIdleConnections() + c := &http.Client{Transport: tr} + req, _ := http.NewRequest("GET", ts.URL, nil) + _, doErr := Do(ctx, c, req) + + select { + case <-sawBodyClose: + case <-time.After(5 * time.Second): + t.Fatal("timeout waiting for body to close") + } + + if doErr != ctx.Err() { + t.Errorf("Do error = %v; want %v", doErr, ctx.Err()) + } +} + +type noteCloseConn struct { + net.Conn + onceClose sync.Once + closefn func() +} + +func (c *noteCloseConn) Close() error { + c.onceClose.Do(c.closefn) + return c.Conn.Close() +} diff --git a/cmd/ponzu/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_test.go b/cmd/ponzu/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_test.go new file mode 100644 index 0000000..1e41551 --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/net/context/ctxhttp/ctxhttp_test.go @@ -0,0 +1,105 @@ +// Copyright 2015 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !plan9 + +package ctxhttp + +import ( + "io" + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + "time" + + "golang.org/x/net/context" +) + +const ( + requestDuration = 100 * time.Millisecond + requestBody = "ok" +) + +func okHandler(w http.ResponseWriter, r *http.Request) { + time.Sleep(requestDuration) + io.WriteString(w, requestBody) +} + +func TestNoTimeout(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(okHandler)) + defer ts.Close() + + ctx := context.Background() + res, err := Get(ctx, nil, ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != requestBody { + t.Errorf("body = %q; want %q", slurp, requestBody) + } +} + +func TestCancelBeforeHeaders(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + + blockServer := make(chan struct{}) + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + cancel() + <-blockServer + io.WriteString(w, requestBody) + })) + defer ts.Close() + defer close(blockServer) + + res, err := Get(ctx, nil, ts.URL) + if err == nil { + res.Body.Close() + t.Fatal("Get returned unexpected nil error") + } + if err != context.Canceled { + t.Errorf("err = %v; want %v", err, context.Canceled) + } +} + +func TestCancelAfterHangingRequest(t *testing.T) { + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + w.(http.Flusher).Flush() + <-w.(http.CloseNotifier).CloseNotify() + })) + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + resp, err := Get(ctx, nil, ts.URL) + if err != nil { + t.Fatalf("unexpected error in Get: %v", err) + } + + // Cancel befer reading the body. + // Reading Request.Body should fail, since the request was + // canceled before anything was written. + cancel() + + done := make(chan struct{}) + + go func() { + b, err := ioutil.ReadAll(resp.Body) + if len(b) != 0 || err == nil { + t.Errorf(`Read got (%q, %v); want ("", error)`, b, err) + } + close(done) + }() + + select { + case <-time.After(1 * time.Second): + t.Errorf("Test timed out") + case <-done: + } +} diff --git a/cmd/ponzu/vendor/golang.org/x/net/context/go17.go b/cmd/ponzu/vendor/golang.org/x/net/context/go17.go new file mode 100644 index 0000000..f8cda19 --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/net/context/go17.go @@ -0,0 +1,72 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build go1.7 + +package context + +import ( + "context" // standard library's context, as of Go 1.7 + "time" +) + +var ( + todo = context.TODO() + background = context.Background() +) + +// Canceled is the error returned by Context.Err when the context is canceled. +var Canceled = context.Canceled + +// DeadlineExceeded is the error returned by Context.Err when the context's +// deadline passes. +var DeadlineExceeded = context.DeadlineExceeded + +// WithCancel returns a copy of parent with a new Done channel. The returned +// context's Done channel is closed when the returned cancel function is called +// or when the parent context's Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { + ctx, f := context.WithCancel(parent) + return ctx, CancelFunc(f) +} + +// WithDeadline returns a copy of the parent context with the deadline adjusted +// to be no later than d. If the parent's deadline is already earlier than d, +// WithDeadline(parent, d) is semantically equivalent to parent. The returned +// context's Done channel is closed when the deadline expires, when the returned +// cancel function is called, or when the parent context's Done channel is +// closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { + ctx, f := context.WithDeadline(parent, deadline) + return ctx, CancelFunc(f) +} + +// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete: +// +// func slowOperationWithTimeout(ctx context.Context) (Result, error) { +// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) +// defer cancel() // releases resources if slowOperation completes before timeout elapses +// return slowOperation(ctx) +// } +func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { + return WithDeadline(parent, time.Now().Add(timeout)) +} + +// WithValue returns a copy of parent in which the value associated with key is +// val. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +func WithValue(parent Context, key interface{}, val interface{}) Context { + return context.WithValue(parent, key, val) +} diff --git a/cmd/ponzu/vendor/golang.org/x/net/context/pre_go17.go b/cmd/ponzu/vendor/golang.org/x/net/context/pre_go17.go new file mode 100644 index 0000000..5a30aca --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/net/context/pre_go17.go @@ -0,0 +1,300 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build !go1.7 + +package context + +import ( + "errors" + "fmt" + "sync" + "time" +) + +// An emptyCtx is never canceled, has no values, and has no deadline. It is not +// struct{}, since vars of this type must have distinct addresses. +type emptyCtx int + +func (*emptyCtx) Deadline() (deadline time.Time, ok bool) { + return +} + +func (*emptyCtx) Done() <-chan struct{} { + return nil +} + +func (*emptyCtx) Err() error { + return nil +} + +func (*emptyCtx) Value(key interface{}) interface{} { + return nil +} + +func (e *emptyCtx) String() string { + switch e { + case background: + return "context.Background" + case todo: + return "context.TODO" + } + return "unknown empty Context" +} + +var ( + background = new(emptyCtx) + todo = new(emptyCtx) +) + +// Canceled is the error returned by Context.Err when the context is canceled. +var Canceled = errors.New("context canceled") + +// DeadlineExceeded is the error returned by Context.Err when the context's +// deadline passes. +var DeadlineExceeded = errors.New("context deadline exceeded") + +// WithCancel returns a copy of parent with a new Done channel. The returned +// context's Done channel is closed when the returned cancel function is called +// or when the parent context's Done channel is closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { + c := newCancelCtx(parent) + propagateCancel(parent, c) + return c, func() { c.cancel(true, Canceled) } +} + +// newCancelCtx returns an initialized cancelCtx. +func newCancelCtx(parent Context) *cancelCtx { + return &cancelCtx{ + Context: parent, + done: make(chan struct{}), + } +} + +// propagateCancel arranges for child to be canceled when parent is. +func propagateCancel(parent Context, child canceler) { + if parent.Done() == nil { + return // parent is never canceled + } + if p, ok := parentCancelCtx(parent); ok { + p.mu.Lock() + if p.err != nil { + // parent has already been canceled + child.cancel(false, p.err) + } else { + if p.children == nil { + p.children = make(map[canceler]bool) + } + p.children[child] = true + } + p.mu.Unlock() + } else { + go func() { + select { + case <-parent.Done(): + child.cancel(false, parent.Err()) + case <-child.Done(): + } + }() + } +} + +// parentCancelCtx follows a chain of parent references until it finds a +// *cancelCtx. This function understands how each of the concrete types in this +// package represents its parent. +func parentCancelCtx(parent Context) (*cancelCtx, bool) { + for { + switch c := parent.(type) { + case *cancelCtx: + return c, true + case *timerCtx: + return c.cancelCtx, true + case *valueCtx: + parent = c.Context + default: + return nil, false + } + } +} + +// removeChild removes a context from its parent. +func removeChild(parent Context, child canceler) { + p, ok := parentCancelCtx(parent) + if !ok { + return + } + p.mu.Lock() + if p.children != nil { + delete(p.children, child) + } + p.mu.Unlock() +} + +// A canceler is a context type that can be canceled directly. The +// implementations are *cancelCtx and *timerCtx. +type canceler interface { + cancel(removeFromParent bool, err error) + Done() <-chan struct{} +} + +// A cancelCtx can be canceled. When canceled, it also cancels any children +// that implement canceler. +type cancelCtx struct { + Context + + done chan struct{} // closed by the first cancel call. + + mu sync.Mutex + children map[canceler]bool // set to nil by the first cancel call + err error // set to non-nil by the first cancel call +} + +func (c *cancelCtx) Done() <-chan struct{} { + return c.done +} + +func (c *cancelCtx) Err() error { + c.mu.Lock() + defer c.mu.Unlock() + return c.err +} + +func (c *cancelCtx) String() string { + return fmt.Sprintf("%v.WithCancel", c.Context) +} + +// cancel closes c.done, cancels each of c's children, and, if +// removeFromParent is true, removes c from its parent's children. +func (c *cancelCtx) cancel(removeFromParent bool, err error) { + if err == nil { + panic("context: internal error: missing cancel error") + } + c.mu.Lock() + if c.err != nil { + c.mu.Unlock() + return // already canceled + } + c.err = err + close(c.done) + for child := range c.children { + // NOTE: acquiring the child's lock while holding parent's lock. + child.cancel(false, err) + } + c.children = nil + c.mu.Unlock() + + if removeFromParent { + removeChild(c.Context, c) + } +} + +// WithDeadline returns a copy of the parent context with the deadline adjusted +// to be no later than d. If the parent's deadline is already earlier than d, +// WithDeadline(parent, d) is semantically equivalent to parent. The returned +// context's Done channel is closed when the deadline expires, when the returned +// cancel function is called, or when the parent context's Done channel is +// closed, whichever happens first. +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete. +func WithDeadline(parent Context, deadline time.Time) (Context, CancelFunc) { + if cur, ok := parent.Deadline(); ok && cur.Before(deadline) { + // The current deadline is already sooner than the new one. + return WithCancel(parent) + } + c := &timerCtx{ + cancelCtx: newCancelCtx(parent), + deadline: deadline, + } + propagateCancel(parent, c) + d := deadline.Sub(time.Now()) + if d <= 0 { + c.cancel(true, DeadlineExceeded) // deadline has already passed + return c, func() { c.cancel(true, Canceled) } + } + c.mu.Lock() + defer c.mu.Unlock() + if c.err == nil { + c.timer = time.AfterFunc(d, func() { + c.cancel(true, DeadlineExceeded) + }) + } + return c, func() { c.cancel(true, Canceled) } +} + +// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to +// implement Done and Err. It implements cancel by stopping its timer then +// delegating to cancelCtx.cancel. +type timerCtx struct { + *cancelCtx + timer *time.Timer // Under cancelCtx.mu. + + deadline time.Time +} + +func (c *timerCtx) Deadline() (deadline time.Time, ok bool) { + return c.deadline, true +} + +func (c *timerCtx) String() string { + return fmt.Sprintf("%v.WithDeadline(%s [%s])", c.cancelCtx.Context, c.deadline, c.deadline.Sub(time.Now())) +} + +func (c *timerCtx) cancel(removeFromParent bool, err error) { + c.cancelCtx.cancel(false, err) + if removeFromParent { + // Remove this timerCtx from its parent cancelCtx's children. + removeChild(c.cancelCtx.Context, c) + } + c.mu.Lock() + if c.timer != nil { + c.timer.Stop() + c.timer = nil + } + c.mu.Unlock() +} + +// WithTimeout returns WithDeadline(parent, time.Now().Add(timeout)). +// +// Canceling this context releases resources associated with it, so code should +// call cancel as soon as the operations running in this Context complete: +// +// func slowOperationWithTimeout(ctx context.Context) (Result, error) { +// ctx, cancel := context.WithTimeout(ctx, 100*time.Millisecond) +// defer cancel() // releases resources if slowOperation completes before timeout elapses +// return slowOperation(ctx) +// } +func WithTimeout(parent Context, timeout time.Duration) (Context, CancelFunc) { + return WithDeadline(parent, time.Now().Add(timeout)) +} + +// WithValue returns a copy of parent in which the value associated with key is +// val. +// +// Use context Values only for request-scoped data that transits processes and +// APIs, not for passing optional parameters to functions. +func WithValue(parent Context, key interface{}, val interface{}) Context { + return &valueCtx{parent, key, val} +} + +// A valueCtx carries a key-value pair. It implements Value for that key and +// delegates all other calls to the embedded Context. +type valueCtx struct { + Context + key, val interface{} +} + +func (c *valueCtx) String() string { + return fmt.Sprintf("%v.WithValue(%#v, %#v)", c.Context, c.key, c.val) +} + +func (c *valueCtx) Value(key interface{}) interface{} { + if c.key == key { + return c.val + } + return c.Context.Value(key) +} diff --git a/cmd/ponzu/vendor/golang.org/x/net/context/withtimeout_test.go b/cmd/ponzu/vendor/golang.org/x/net/context/withtimeout_test.go new file mode 100644 index 0000000..a6754dc --- /dev/null +++ b/cmd/ponzu/vendor/golang.org/x/net/context/withtimeout_test.go @@ -0,0 +1,26 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package context_test + +import ( + "fmt" + "time" + + "golang.org/x/net/context" +) + +func ExampleWithTimeout() { + // Pass a context with a timeout to tell a blocking function that it + // should abandon its work after the timeout elapses. + ctx, _ := context.WithTimeout(context.Background(), 100*time.Millisecond) + select { + case <-time.After(200 * time.Millisecond): + fmt.Println("overslept") + case <-ctx.Done(): + fmt.Println(ctx.Err()) // prints "context deadline exceeded" + } + // Output: + // context deadline exceeded +} diff --git a/content/post.go b/content/post.go index 7dc99c1..dcdfeff 100644 --- a/content/post.go +++ b/content/post.go @@ -11,33 +11,14 @@ type Post struct { Item editor editor.Editor - Title string `json:"title"` - Content string `json:"content"` - Photo string `json:"photo"` - Author string `json:"author"` - Category []string `json:"category"` - ThemeStyle string `json:"theme"` + Title string `json:"title"` + Content string `json:"content"` + Photo string `json:"photo"` + Author string `json:"author"` + Category []string `json:"category"` + Theme string `json:"theme"` } -func init() { - Types["Post"] = func() interface{} { return new(Post) } -} - -// SetContentID partially implements editor.Editable -func (p *Post) SetContentID(id int) { p.ID = id } - -// ContentID partially implements editor.Editable -func (p *Post) ContentID() int { return p.ID } - -// ContentName partially implements editor.Editable -func (p *Post) ContentName() string { return p.Title } - -// SetSlug partially implements editor.Editable -func (p *Post) SetSlug(slug string) { p.Slug = slug } - -// Editor partially implements editor.Editable -func (p *Post) Editor() *editor.Editor { return &p.editor } - // MarshalEditor writes a buffer of html to edit a Post and partially implements editor.Editable func (p *Post) MarshalEditor() ([]byte, error) { view, err := editor.Form(p, @@ -55,7 +36,7 @@ func (p *Post) MarshalEditor() ([]byte, error) { }), }, editor.Field{ - View: editor.File("Picture", p, map[string]string{ + View: editor.File("Photo", p, map[string]string{ "label": "Author Photo", "placeholder": "Upload a profile picture for the author", }), @@ -68,16 +49,12 @@ func (p *Post) MarshalEditor() ([]byte, error) { }), }, editor.Field{ - View: editor.Checkbox("Category", p, map[string]string{ - "label": "Post Category", - }, map[string]string{ - "important": "Important", - "active": "Active", - "unplanned": "Unplanned", + View: editor.Tags("Category", p, map[string]string{ + "label": "Post Categories", }), }, editor.Field{ - View: editor.Select("ThemeStyle", p, map[string]string{ + View: editor.Select("Theme", p, map[string]string{ "label": "Theme Style", }, map[string]string{ "dark": "Dark", @@ -92,3 +69,22 @@ func (p *Post) MarshalEditor() ([]byte, error) { return view, nil } + +func init() { + Types["Post"] = func() interface{} { return new(Post) } +} + +// SetContentID partially implements editor.Editable +func (p *Post) SetContentID(id int) { p.ID = id } + +// ContentID partially implements editor.Editable +func (p *Post) ContentID() int { return p.ID } + +// ContentName partially implements editor.Editable +func (p *Post) ContentName() string { return p.Title } + +// SetSlug partially implements editor.Editable +func (p *Post) SetSlug(slug string) { p.Slug = slug } + +// Editor partially implements editor.Editable +func (p *Post) Editor() *editor.Editor { return &p.editor } diff --git a/management/editor/editor.go b/management/editor/editor.go index 3b26adb..68b787b 100644 --- a/management/editor/editor.go +++ b/management/editor/editor.go @@ -110,17 +110,31 @@ func Form(post Editable, fields ...Field) ([]byte, error) { <button class="right waves-effect waves-light btn red delete-post" type="submit">Delete</button> </div> +<div class="row external post-controls"> + <div class="col s12 input-field"> + <button class="right waves-effect waves-light btn blue approve-post" type="submit">Approve</button> + </div> + <label class="approve-details right-align col s12">This content is pending approval. By clicking 'Approve', it will be immediately published.</label> +</div> + <script> $(function() { var form = $('form'), del = form.find('button.delete-post'), + approve = form.find('.post-controls.external'), id = form.find('input[name=id]'); - // hide delete button if this is a new post, or a non-post editor page + // hide if this is a new post, or a non-post editor page if (id.val() === '-1' || form.attr('action') !== '/admin/edit') { del.hide(); + approve.hide(); } + // hide approval if not on a pending content item + if (getParam("status") !== "pending") { + approve.hide(); + } + del.on('click', function(e) { e.preventDefault(); var action = form.attr('action'); @@ -131,6 +145,15 @@ func Form(post Editable, fields ...Field) ([]byte, error) { form.submit(); } }); + + approve.find('button').on('click', function(e) { + e.preventDefault(); + var action = form.attr('action'); + action = action + '/approve'; + form.attr('action', action); + + form.submit(); + }); }); </script> ` diff --git a/management/editor/elements.go b/management/editor/elements.go index 4d829ad..2326358 100644 --- a/management/editor/elements.go +++ b/management/editor/elements.go @@ -332,6 +332,94 @@ func Checkbox(fieldName string, p interface{}, attrs, options map[string]string) return domElementWithChildrenCheckbox(div, opts) } +// Tags returns the []byte of a tag input (in the style of Materialze 'Chips') with a label. +// IMPORTANT: +// The `fieldName` argument will cause a panic if it is not exactly the string +// form of the struct field that this editor input is representing +func Tags(fieldName string, p interface{}, attrs map[string]string) []byte { + name := tagNameFromStructField(fieldName, p) + + // get the saved tags if this is already an existing post + values := valueFromStructField(fieldName, p) // returns refelct.Value + tags := values.Slice(0, values.Len()).Interface().([]string) // casts reflect.Value to []string + + html := ` + <div class="col s12 tags ` + name + `"> + <label class="active">` + attrs["label"] + ` (Type and press "Enter")</label> + <div class="chips ` + name + `"></div> + ` + + var initial []string + i := 0 + for _, tag := range tags { + tagName := tagNameFromStructFieldMulti(fieldName, i, p) + html += `<input type="hidden" class="tag ` + tag + `" name=` + tagName + ` value="` + tag + `"/>` + initial = append(initial, `{tag: '`+tag+`'}`) + i++ + } + + script := ` + <script> + $(function() { + var tags = $('.tags.` + name + `'); + $('.chips.` + name + `').material_chip({ + data: [` + strings.Join(initial, ",") + `], + secondaryPlaceholder: '+` + name + `' + }); + + // handle events specific to tags + var chips = tags.find('.chips'); + + chips.on('chip.add', function(e, chip) { + chips.parent().find('.empty-tag').remove(); + + var input = $('<input>'); + input.attr({ + class: 'tag '+chip.tag, + name: '` + name + `.'+String(tags.find('input[type=hidden]').length), + value: chip.tag, + type: 'hidden' + }); + + tags.append(input); + }); + + chips.on('chip.delete', function(e, chip) { + // convert tag string to class-like selector "some tag" -> ".some.tag" + var sel = '.tag.'+chip.tag.split(' ').join('.'); + console.log(sel); + console.log(chips.parent().find(sel)); + chips.parent().find(sel).remove(); + + // iterate through all hidden tag inputs to re-name them with the correct ` + name + `.index + var hidden = chips.parent().find('input[type=hidden]'); + + // if there are no tags, set a blank + if (hidden.length === 0) { + var input = $('<input>'); + input.attr({ + class: 'empty-tag', + name: '` + name + `', + type: 'hidden' + }); + + tags.append(input); + return; + } + + for (var i = 0; i < hidden.length; i++) { + $(hidden[i]).attr('name', '` + name + `.'+String(i)); + } + }); + }); + </script> + ` + + html += `</div>` + + return []byte(html + script) +} + // domElementSelfClose is a special DOM element which is parsed as a // self-closing tag and thus needs to be created differently func domElementSelfClose(e *element) []byte { diff --git a/system/admin/admin.go b/system/admin/admin.go index b375eb6..9ee86c3 100644 --- a/system/admin/admin.go +++ b/system/admin/admin.go @@ -4,9 +4,12 @@ package admin import ( "bytes" + "encoding/json" "html/template" + "net/http" "github.com/bosssauce/ponzu/content" + "github.com/bosssauce/ponzu/system/admin/user" "github.com/bosssauce/ponzu/system/db" ) @@ -241,6 +244,129 @@ func Login() ([]byte, error) { return buf.Bytes(), nil } +// UsersList ... +func UsersList(req *http.Request) ([]byte, error) { + html := ` + <div class="card user-management"> + <div class="card-title">Edit your account:</div> + <form class="row" enctype="multipart/form-data" action="/admin/configure/users/edit" method="post"> + <div class="input-feild col s9"> + <label class="active">Email Address</label> + <input type="email" name="email" value="{{ .User.Email }}"/> + </div> + + <div class="input-feild col s9"> + <div>To approve changes, enter your password:</div> + + <label class="active">Current Password</label> + <input type="password" name="password"/> + </div> + + <div class="input-feild col s9"> + <label class="active">New Password: (leave blank if no password change needed)</label> + <input name="new_password" type="password"/> + </div> + + <div class="input-feild col s9"> + <button class="btn waves-effect waves-light green right" type="submit">Save</button> + </div> + </form> + + <div class="card-title">Add a new user:</div> + <form class="row" enctype="multipart/form-data" action="/admin/configure/users" method="post"> + <div class="input-feild col s9"> + <label class="active">Email Address</label> + <input type="email" name="email" value=""/> + </div> + + <div class="input-feild col s9"> + <label class="active">Password</label> + <input type="password" name="password"/> + </div> + + <div class="input-feild col s9"> + <button class="btn waves-effect waves-light green right" type="submit">Add User</button> + </div> + </form> + + <div class="card-title">Remove Admin Users</div> + <ul class="users row"> + {{ range .Users }} + <li class="col s9"> + {{ .Email }} + <form enctype="multipart/form-data" class="delete-user __ponzu right" action="/admin/configure/users/delete" method="post"> + <span>Delete</span> + <input type="hidden" name="email" value="{{ .Email }}"/> + <input type="hidden" name="id" value="{{ .ID }}"/> + </form> + </li> + {{ end }} + </ul> + </div> + ` + script := ` + <script> + $(function() { + var del = $('.delete-user.__ponzu span'); + del.on('click', function(e) { + if (confirm("[Ponzu] Please confirm:\n\nAre you sure you want to delete this user?\nThis cannot be undone.")) { + $(e.target).parent().submit(); + } + }); + }); + </script> + ` + // get current user out to pass as data to execute template + j, err := db.CurrentUser(req) + if err != nil { + return nil, err + } + + var usr user.User + err = json.Unmarshal(j, &usr) + if err != nil { + return nil, err + } + + // get all users to list + jj, err := db.UserAll() + if err != nil { + return nil, err + } + + var usrs []user.User + for i := range jj { + var u user.User + err = json.Unmarshal(jj[i], &u) + if err != nil { + return nil, err + } + if u.Email != usr.Email { + usrs = append(usrs, u) + } + } + + // make buffer to execute html into then pass buffer's bytes to Admin + buf := &bytes.Buffer{} + tmpl := template.Must(template.New("users").Parse(html + script)) + data := map[string]interface{}{ + "User": usr, + "Users": usrs, + } + + err = tmpl.Execute(buf, data) + if err != nil { + return nil, err + } + + view, err := Admin(buf.Bytes()) + if err != nil { + return nil, err + } + + return view, nil +} + var err400HTML = ` <div class="error-page e400 col s6"> <div class="card"> @@ -288,7 +414,7 @@ var err405HTML = ` <div class="card"> <div class="card-content"> <div class="card-title"><b>405</b> Error: Method Not Allowed</div> - <blockquote>Sorry, the page you requested could not be found.</blockquote> + <blockquote>Sorry, the method of your request is not allowed.</blockquote> </div> </div> </div> diff --git a/system/admin/config/config.go b/system/admin/config/config.go index c83c311..66f767d 100644 --- a/system/admin/config/config.go +++ b/system/admin/config/config.go @@ -12,6 +12,7 @@ type Config struct { Name string `json:"name"` Domain string `json:"domain"` + AdminEmail string `json:"admin_email"` ClientSecret string `json:"client_secret"` Etag string `json:"etag"` CacheInvalidate []string `json:"-"` @@ -48,8 +49,13 @@ func (c *Config) MarshalEditor() ([]byte, error) { }), }, editor.Field{ + View: editor.Input("AdminEmail", c, map[string]string{ + "label": "Adminstrator Email (will be notified of internal system information)", + }), + }, + editor.Field{ View: editor.Input("ClientSecret", c, map[string]string{ - "label": "Client Secret (used to validate requests)", + "label": "Client Secret (used to validate requests, DO NOT SHARE)", "disabled": "true", }), }, diff --git a/system/admin/handlers.go b/system/admin/handlers.go index de340ae..742b898 100644 --- a/system/admin/handlers.go +++ b/system/admin/handlers.go @@ -14,16 +14,19 @@ import ( "github.com/bosssauce/ponzu/management/editor" "github.com/bosssauce/ponzu/management/manager" "github.com/bosssauce/ponzu/system/admin/config" + "github.com/bosssauce/ponzu/system/admin/upload" "github.com/bosssauce/ponzu/system/admin/user" + "github.com/bosssauce/ponzu/system/api" "github.com/bosssauce/ponzu/system/db" + "github.com/gorilla/schema" "github.com/nilslice/jwt" ) func adminHandler(res http.ResponseWriter, req *http.Request) { view, err := Admin(nil) if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } @@ -42,7 +45,7 @@ func initHandler(res http.ResponseWriter, req *http.Request) { case http.MethodGet: view, err := Init() if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } @@ -52,7 +55,7 @@ func initHandler(res http.ResponseWriter, req *http.Request) { case http.MethodPost: err := req.ParseForm() if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } @@ -66,20 +69,22 @@ func initHandler(res http.ResponseWriter, req *http.Request) { etag := db.NewEtag() req.Form.Set("etag", etag) - err = db.SetConfig(req.Form) - if err != nil { - fmt.Println(err) - res.WriteHeader(http.StatusInternalServerError) - return - } - email := strings.ToLower(req.FormValue("email")) password := req.FormValue("password") usr := user.NewUser(email, password) _, err = db.SetUser(usr) if err != nil { - fmt.Println(err) + log.Println(err) + res.WriteHeader(http.StatusInternalServerError) + return + } + + // set initial user email as admin_email and make config + req.Form.Set("admin_email", email) + err = db.SetConfig(req.Form) + if err != nil { + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } @@ -98,6 +103,7 @@ func initHandler(res http.ResponseWriter, req *http.Request) { Name: "_token", Value: token, Expires: week, + Path: "/", }) redir := strings.TrimSuffix(req.URL.String(), "/init") @@ -113,7 +119,7 @@ func configHandler(res http.ResponseWriter, req *http.Request) { case http.MethodGet: data, err := db.ConfigAll() if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } @@ -122,21 +128,21 @@ func configHandler(res http.ResponseWriter, req *http.Request) { err = json.Unmarshal(data, c) if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } cfg, err := c.MarshalEditor() if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } adminView, err := Admin(cfg) if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } @@ -147,14 +153,14 @@ func configHandler(res http.ResponseWriter, req *http.Request) { case http.MethodPost: err := req.ParseForm() if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } err = db.SetConfig(req.Form) if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } @@ -170,10 +176,246 @@ func configHandler(res http.ResponseWriter, req *http.Request) { func configUsersHandler(res http.ResponseWriter, req *http.Request) { switch req.Method { case http.MethodGet: - // list all users and delete buttons + view, err := UsersList(req) + if err != nil { + log.Println(err) + res.WriteHeader(http.StatusInternalServerError) + errView, err := Error500() + if err != nil { + return + } + + res.Write(errView) + return + } + + res.Write(view) case http.MethodPost: // create new user + err := req.ParseMultipartForm(1024 * 1024 * 4) // maxMemory 4MB + if err != nil { + log.Println(err) + res.WriteHeader(http.StatusInternalServerError) + errView, err := Error500() + if err != nil { + return + } + + res.Write(errView) + return + } + + email := strings.ToLower(req.FormValue("email")) + password := req.PostFormValue("password") + + if email == "" || password == "" { + res.WriteHeader(http.StatusInternalServerError) + errView, err := Error500() + if err != nil { + return + } + + res.Write(errView) + return + } + + usr := user.NewUser(email, password) + + _, err = db.SetUser(usr) + if err != nil { + log.Println(err) + res.WriteHeader(http.StatusInternalServerError) + errView, err := Error500() + if err != nil { + return + } + + res.Write(errView) + return + } + + http.Redirect(res, req, req.URL.String(), http.StatusFound) + + default: + res.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func configUsersEditHandler(res http.ResponseWriter, req *http.Request) { + switch req.Method { + case http.MethodPost: + err := req.ParseMultipartForm(1024 * 1024 * 4) // maxMemory 4MB + + // check if user to be edited is current user + j, err := db.CurrentUser(req) + if err != nil { + log.Println(err) + res.WriteHeader(http.StatusInternalServerError) + errView, err := Error500() + if err != nil { + return + } + + res.Write(errView) + return + } + + usr := &user.User{} + err = json.Unmarshal(j, usr) + if err != nil { + log.Println(err) + res.WriteHeader(http.StatusInternalServerError) + errView, err := Error500() + if err != nil { + return + } + + res.Write(errView) + return + } + + // check if password matches + password := req.PostFormValue("password") + + if !user.IsUser(usr, password) { + log.Println("Unexpected user/password combination for", usr.Email) + res.WriteHeader(http.StatusBadRequest) + errView, err := Error405() + if err != nil { + return + } + + res.Write(errView) + return + } + + email := strings.ToLower(req.PostFormValue("email")) + newPassword := req.PostFormValue("new_password") + var updatedUser *user.User + if newPassword != "" { + updatedUser = user.NewUser(email, newPassword) + } else { + updatedUser = user.NewUser(email, password) + } + + // set the ID to the same ID as current user + updatedUser.ID = usr.ID + + // set user in db + err = db.UpdateUser(usr, updatedUser) + if err != nil { + log.Println(err) + res.WriteHeader(http.StatusInternalServerError) + errView, err := Error500() + if err != nil { + return + } + + res.Write(errView) + return + } + + // create new token + week := time.Now().Add(time.Hour * 24 * 7) + claims := map[string]interface{}{ + "exp": week, + "user": updatedUser.Email, + } + token, err := jwt.New(claims) + if err != nil { + log.Println(err) + res.WriteHeader(http.StatusInternalServerError) + errView, err := Error500() + if err != nil { + return + } + + res.Write(errView) + return + } + + // add token to cookie +1 week expiration + cookie := &http.Cookie{ + Name: "_token", + Value: token, + Expires: week, + Path: "/", + } + http.SetCookie(res, cookie) + + // add new token cookie to the request + req.AddCookie(cookie) + + http.Redirect(res, req, strings.TrimSuffix(req.URL.String(), "/edit"), http.StatusFound) + + default: + res.WriteHeader(http.StatusMethodNotAllowed) + } +} + +func configUsersDeleteHandler(res http.ResponseWriter, req *http.Request) { + switch req.Method { + case http.MethodPost: + err := req.ParseMultipartForm(1024 * 1024 * 4) // maxMemory 4MB + + // do not allow current user to delete themselves + j, err := db.CurrentUser(req) + if err != nil { + log.Println(err) + res.WriteHeader(http.StatusInternalServerError) + errView, err := Error500() + if err != nil { + return + } + + res.Write(errView) + return + } + + usr := &user.User{} + err = json.Unmarshal(j, &usr) + if err != nil { + log.Println(err) + res.WriteHeader(http.StatusInternalServerError) + errView, err := Error500() + if err != nil { + return + } + + res.Write(errView) + return + } + + email := strings.ToLower(req.PostFormValue("email")) + + if usr.Email == email { + log.Println(err) + res.WriteHeader(http.StatusBadRequest) + errView, err := Error405() + if err != nil { + return + } + + res.Write(errView) + return + } + + // delete existing user + err = db.DeleteUser(email) + if err != nil { + log.Println(err) + res.WriteHeader(http.StatusInternalServerError) + errView, err := Error500() + if err != nil { + return + } + + res.Write(errView) + return + } + + http.Redirect(res, req, strings.TrimSuffix(req.URL.String(), "/delete"), http.StatusFound) default: res.WriteHeader(http.StatusMethodNotAllowed) @@ -196,7 +438,7 @@ func loginHandler(res http.ResponseWriter, req *http.Request) { view, err := Login() if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } @@ -212,7 +454,7 @@ func loginHandler(res http.ResponseWriter, req *http.Request) { err := req.ParseForm() if err != nil { - fmt.Println(err) + log.Println(err) http.Redirect(res, req, req.URL.String(), http.StatusFound) return } @@ -220,7 +462,7 @@ func loginHandler(res http.ResponseWriter, req *http.Request) { // check email & password j, err := db.User(strings.ToLower(req.FormValue("email"))) if err != nil { - fmt.Println(err) + log.Println(err) http.Redirect(res, req, req.URL.String(), http.StatusFound) return } @@ -233,7 +475,7 @@ func loginHandler(res http.ResponseWriter, req *http.Request) { usr := &user.User{} err = json.Unmarshal(j, usr) if err != nil { - fmt.Println(err) + log.Println(err) http.Redirect(res, req, req.URL.String(), http.StatusFound) return } @@ -250,7 +492,7 @@ func loginHandler(res http.ResponseWriter, req *http.Request) { } token, err := jwt.New(claims) if err != nil { - fmt.Println(err) + log.Println(err) http.Redirect(res, req, req.URL.String(), http.StatusFound) return } @@ -260,6 +502,7 @@ func loginHandler(res http.ResponseWriter, req *http.Request) { Name: "_token", Value: token, Expires: week, + Path: "/", }) http.Redirect(res, req, strings.TrimSuffix(req.URL.String(), "/login"), http.StatusFound) @@ -271,6 +514,7 @@ func logoutHandler(res http.ResponseWriter, req *http.Request) { Name: "_token", Expires: time.Unix(0, 0), Value: "", + Path: "/", }) http.Redirect(res, req, req.URL.Scheme+req.URL.Host+"/admin/login", http.StatusFound) @@ -291,10 +535,25 @@ func postsHandler(res http.ResponseWriter, req *http.Request) { } order := strings.ToLower(q.Get("order")) + status := q.Get("status") posts := db.ContentAll(t + "_sorted") b := &bytes.Buffer{} - p, ok := content.Types[t]().(editor.Editable) + + if _, ok := content.Types[t]; !ok { + res.WriteHeader(http.StatusBadRequest) + errView, err := Error405() + if err != nil { + return + } + + res.Write(errView) + return + } + + pt := content.Types[t]() + + p, ok := pt.(editor.Editable) if !ok { res.WriteHeader(http.StatusInternalServerError) errView, err := Error500() @@ -306,6 +565,12 @@ func postsHandler(res http.ResponseWriter, req *http.Request) { return } + var hasExt bool + _, ok = pt.(api.Externalable) + if ok { + hasExt = true + } + html := `<div class="col s9 card"> <div class="card-content"> <div class="row"> @@ -321,21 +586,6 @@ func postsHandler(res http.ResponseWriter, req *http.Request) { </div> <script> $(function() { - var getParam = function(param) { - var qs = window.location.search.substring(1); - var qp = qs.split('&'); - var t = ''; - - for (var i = 0; i < qp.length; i++) { - var p = qp[i].split('=') - if (p[0] === param) { - t = p[1]; - } - } - - return t; - } - var sort = $('select.__ponzu.sort-order'); sort.on('change', function() { @@ -363,39 +613,107 @@ func postsHandler(res http.ResponseWriter, req *http.Request) { <input type="hidden" name="type" value="` + t + `" /> </div> </form> - </div> - <ul class="posts row">` + </div>` + if hasExt { + if status == "" { + q.Add("status", "public") + } - if order == "desc" || order == "" { - // keep natural order of posts slice, as returned from sorted bucket - for i := range posts { - err := json.Unmarshal(posts[i], &p) - if err != nil { - log.Println("Error unmarshal json into", t, err, posts[i]) + q.Set("status", "public") + publicURL := req.URL.Path + "?" + q.Encode() + + q.Set("status", "pending") + pendingURL := req.URL.Path + "?" + q.Encode() + + switch status { + case "public", "": + html += `<div class="row externalable"> + <span class="description">Status:</span> + <span class="active">Public</span> + | + <a href="` + pendingURL + `">Pending</a> + </div>` + + case "pending": + // get _pending posts of type t from the db + posts = db.ContentAll(t + "_pending") + + html += `<div class="row externalable"> + <span class="description">Status:</span> + <a href="` + publicURL + `">Public</a> + | + <span class="active">Pending</span> + </div>` + } + + } + html += `<ul class="posts row">` + + switch order { + case "desc", "": + if hasExt { + // reverse the order of posts slice + for i := len(posts) - 1; i >= 0; i-- { + err := json.Unmarshal(posts[i], &p) + if err != nil { + log.Println("Error unmarshal json into", t, err, posts[i]) + + post := `<li class="col s12">Error decoding data. Possible file corruption.</li>` + b.Write([]byte(post)) + continue + } - post := `<li class="col s12">Error decoding data. Possible file corruption.</li>` - b.Write([]byte(post)) - continue + post := adminPostListItem(p, t, status) + b.Write(post) } + } else { + // keep natural order of posts slice, as returned from sorted bucket + for i := range posts { + err := json.Unmarshal(posts[i], &p) + if err != nil { + log.Println("Error unmarshal json into", t, err, posts[i]) + + post := `<li class="col s12">Error decoding data. Possible file corruption.</li>` + b.Write([]byte(post)) + continue + } - post := adminPostListItem(p, t) - b.Write(post) + post := adminPostListItem(p, t, status) + b.Write(post) + } } - } else if order == "asc" { - // reverse the order of posts slice - for i := len(posts) - 1; i >= 0; i-- { - err := json.Unmarshal(posts[i], &p) - if err != nil { - log.Println("Error unmarshal json into", t, err, posts[i]) + case "asc": + if hasExt { + // keep natural order of posts slice, as returned from sorted bucket + for i := range posts { + err := json.Unmarshal(posts[i], &p) + if err != nil { + log.Println("Error unmarshal json into", t, err, posts[i]) + + post := `<li class="col s12">Error decoding data. Possible file corruption.</li>` + b.Write([]byte(post)) + continue + } - post := `<li class="col s12">Error decoding data. Possible file corruption.</li>` - b.Write([]byte(post)) - continue + post := adminPostListItem(p, t, status) + b.Write(post) } + } else { + // reverse the order of posts slice + for i := len(posts) - 1; i >= 0; i-- { + err := json.Unmarshal(posts[i], &p) + if err != nil { + log.Println("Error unmarshal json into", t, err, posts[i]) - post := adminPostListItem(p, t) - b.Write(post) + post := `<li class="col s12">Error decoding data. Possible file corruption.</li>` + b.Write([]byte(post)) + continue + } + + post := adminPostListItem(p, t, status) + b.Write(post) + } } } @@ -419,7 +737,7 @@ func postsHandler(res http.ResponseWriter, req *http.Request) { adminView, err := Admin([]byte(html)) if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } @@ -430,7 +748,8 @@ func postsHandler(res http.ResponseWriter, req *http.Request) { // adminPostListItem is a helper to create the li containing a post. // p is the asserted post as an Editable, t is the Type of the post. -func adminPostListItem(p editor.Editable, t string) []byte { +// specifier is passed to append a name to a namespace like _pending +func adminPostListItem(p editor.Editable, t, status string) []byte { s, ok := p.(editor.Sortable) if !ok { log.Println("Content type", t, "doesn't implement editor.Sortable") @@ -446,28 +765,127 @@ func adminPostListItem(p editor.Editable, t string) []byte { cid := fmt.Sprintf("%d", p.ContentID()) + if status == "public" { + status = "" + } else { + status = "_" + status + } + post := ` <li class="col s12"> - <a href="/admin/edit?type=` + t + `&id=` + cid + `">` + p.ContentName() + `</a> + <a href="/admin/edit?type=` + t + `&status=` + strings.TrimPrefix(status, "_") + `&id=` + cid + `">` + p.ContentName() + `</a> <span class="post-detail">Updated: ` + updatedTime + `</span> <span class="publish-date right">` + publishTime + `</span> <form enctype="multipart/form-data" class="quick-delete-post __ponzu right" action="/admin/edit/delete" method="post"> <span>Delete</span> <input type="hidden" name="id" value="` + cid + `" /> - <input type="hidden" name="type" value="` + t + `" /> + <input type="hidden" name="type" value="` + t + status + `" /> </form> </li>` return []byte(post) } +func approvePostHandler(res http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + res.WriteHeader(http.StatusMethodNotAllowed) + errView, err := Error405() + if err != nil { + return + } + + res.Write(errView) + return + } + + err := req.ParseMultipartForm(1024 * 1024 * 4) // maxMemory 4MB + if err != nil { + res.WriteHeader(http.StatusInternalServerError) + errView, err := Error500() + if err != nil { + return + } + + res.Write(errView) + return + } + + t := req.FormValue("type") + if strings.Contains(t, "_") { + t = strings.Split(t, "_")[0] + } + + post := content.Types[t]() + + // check if we have a Mergeable + m, ok := post.(api.Mergeable) + if !ok { + res.WriteHeader(http.StatusBadRequest) + errView, err := Error400() + if err != nil { + return + } + + res.Write(errView) + return + } + + dec := schema.NewDecoder() + dec.IgnoreUnknownKeys(true) + dec.SetAliasTag("json") + err = dec.Decode(post, req.Form) + if err != nil { + res.WriteHeader(http.StatusInternalServerError) + errView, err := Error500() + if err != nil { + return + } + + res.Write(errView) + return + } + + // call its Approve method + err = m.Approve(req) + if err != nil { + res.WriteHeader(http.StatusInternalServerError) + errView, err := Error500() + if err != nil { + return + } + + res.Write(errView) + return + } + + // Store the content in the bucket t + id, err := db.SetContent(t+":-1", req.Form) + if err != nil { + res.WriteHeader(http.StatusInternalServerError) + errView, err := Error500() + if err != nil { + return + } + + res.Write(errView) + return + } + + // redirect to the new approved content's editor + redir := req.URL.Scheme + req.URL.Host + strings.TrimSuffix(req.URL.Path, "/approve") + redir += fmt.Sprintf("?type=%s&id=%d", t, id) + http.Redirect(res, req, redir, http.StatusFound) +} + func editHandler(res http.ResponseWriter, req *http.Request) { switch req.Method { case http.MethodGet: q := req.URL.Query() i := q.Get("id") t := q.Get("type") + status := q.Get("status") + contentType, ok := content.Types[t] if !ok { fmt.Fprintf(res, content.ErrTypeNotRegistered, t) @@ -476,9 +894,13 @@ func editHandler(res http.ResponseWriter, req *http.Request) { post := contentType() if i != "" { + if status == "pending" { + t = t + "_pending" + } + data, err := db.Content(t + ":" + i) if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) errView, err := Error500() if err != nil { @@ -490,6 +912,7 @@ func editHandler(res http.ResponseWriter, req *http.Request) { } if len(data) < 1 || data == nil { + fmt.Println(string(data)) res.WriteHeader(http.StatusNotFound) errView, err := Error404() if err != nil { @@ -502,7 +925,7 @@ func editHandler(res http.ResponseWriter, req *http.Request) { err = json.Unmarshal(data, post) if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) errView, err := Error500() if err != nil { @@ -518,7 +941,7 @@ func editHandler(res http.ResponseWriter, req *http.Request) { m, err := manager.Manage(post.(editor.Editable), t) if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) errView, err := Error500() if err != nil { @@ -531,7 +954,7 @@ func editHandler(res http.ResponseWriter, req *http.Request) { adminView, err := Admin(m) if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } @@ -542,7 +965,7 @@ func editHandler(res http.ResponseWriter, req *http.Request) { case http.MethodPost: err := req.ParseMultipartForm(1024 * 1024 * 4) // maxMemory 4MB if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusBadRequest) errView, err := Error405() if err != nil { @@ -568,9 +991,9 @@ func editHandler(res http.ResponseWriter, req *http.Request) { req.PostForm.Set("updated", ts) } - urlPaths, err := storeFileUploads(req) + urlPaths, err := upload.StoreFiles(req) if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) errView, err := Error500() if err != nil { @@ -608,7 +1031,7 @@ func editHandler(res http.ResponseWriter, req *http.Request) { id, err := db.SetContent(t+":"+cid, req.PostForm) if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) errView, err := Error500() if err != nil { @@ -639,7 +1062,7 @@ func deleteHandler(res http.ResponseWriter, req *http.Request) { err := req.ParseMultipartForm(1024 * 1024 * 4) // maxMemory 4MB if err != nil { - fmt.Println("req.ParseMPF") + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } @@ -654,11 +1077,17 @@ func deleteHandler(res http.ResponseWriter, req *http.Request) { err = db.DeleteContent(t + ":" + id) if err != nil { - fmt.Println("db.DeleteContent") + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } + // catch specifier suffix from delete form value + if strings.Contains(t, "_") { + spec := strings.Split(t, "_") + t = spec[0] + } + redir := strings.TrimSuffix(req.URL.Scheme+req.URL.Host+req.URL.Path, "/edit/delete") redir = redir + "/posts?type=" + t http.Redirect(res, req, redir, http.StatusFound) @@ -670,9 +1099,9 @@ func editUploadHandler(res http.ResponseWriter, req *http.Request) { return } - urlPaths, err := storeFileUploads(req) + urlPaths, err := upload.StoreFiles(req) if err != nil { - fmt.Println("Couldn't store file uploads.", err) + log.Println("Couldn't store file uploads.", err) res.WriteHeader(http.StatusInternalServerError) return } @@ -726,7 +1155,7 @@ func searchHandler(res http.ResponseWriter, req *http.Request) { continue } - post := adminPostListItem(p, t) + post := adminPostListItem(p, t, "") b.Write([]byte(post)) } @@ -737,7 +1166,7 @@ func searchHandler(res http.ResponseWriter, req *http.Request) { adminView, err := Admin([]byte(html)) if err != nil { - fmt.Println(err) + log.Println(err) res.WriteHeader(http.StatusInternalServerError) return } diff --git a/system/admin/server.go b/system/admin/server.go index b3b128d..5d93d84 100644 --- a/system/admin/server.go +++ b/system/admin/server.go @@ -20,12 +20,15 @@ func Run() { http.HandleFunc("/admin/configure", user.Auth(configHandler)) http.HandleFunc("/admin/configure/users", user.Auth(configUsersHandler)) + http.HandleFunc("/admin/configure/users/edit", user.Auth(configUsersEditHandler)) + http.HandleFunc("/admin/configure/users/delete", user.Auth(configUsersDeleteHandler)) http.HandleFunc("/admin/posts", user.Auth(postsHandler)) http.HandleFunc("/admin/posts/search", user.Auth(searchHandler)) http.HandleFunc("/admin/edit", user.Auth(editHandler)) http.HandleFunc("/admin/edit/delete", user.Auth(deleteHandler)) + http.HandleFunc("/admin/edit/approve", user.Auth(approvePostHandler)) http.HandleFunc("/admin/edit/upload", user.Auth(editUploadHandler)) pwd, err := os.Getwd() diff --git a/system/admin/static/common/js/util.js b/system/admin/static/common/js/util.js index 7f4c8ab..8d5e74b 100644 --- a/system/admin/static/common/js/util.js +++ b/system/admin/static/common/js/util.js @@ -67,4 +67,20 @@ function getPartialDate(unix) { d.dd = day; return d; +} + +// Returns a part of the window URL 'search' string +function getParam(param) { + var qs = window.location.search.substring(1); + var qp = qs.split('&'); + var t = ''; + + for (var i = 0; i < qp.length; i++) { + var p = qp[i].split('=') + if (p[0] === param) { + t = p[1]; + } + } + + return t; }
\ No newline at end of file diff --git a/system/admin/static/dashboard/css/admin.css b/system/admin/static/dashboard/css/admin.css index e23658e..3cffc5d 100644 --- a/system/admin/static/dashboard/css/admin.css +++ b/system/admin/static/dashboard/css/admin.css @@ -44,7 +44,7 @@ padding: 0px !important; } -ul.posts li { +ul.posts li, ul.users li { display: block; margin: 0 0 20px 0; padding: 0 0 20px 0 !important; @@ -164,15 +164,15 @@ span.post-detail { font-style: italic; } -.quick-delete-post { +.quick-delete-post, .delete-user { display: none; } -li:hover .quick-delete-post { +li:hover .quick-delete-post, li:hover .delete-user { display: inline-block; } -.quick-delete-post span { +.quick-delete-post span, .delete-user span { cursor: pointer; color: #F44336; text-transform: uppercase; @@ -181,6 +181,10 @@ li:hover .quick-delete-post { margin-right: 20px; } +.user-management { + padding: 20px; +} + /* OVERRIDE Bootstrap + Materialize conflicts */ .iso-texteditor.input-field label { @@ -199,3 +203,17 @@ li:hover .quick-delete-post { -o-transform: translateY(-140%); transform: translateY(-140%); } + +.chips { + margin-top: 10px; +} + +.external.post-controls .col.input-field { + margin-top: 40px; + padding: 0; +} + +.approve-details { + text-align: right; + padding: 10px 0 !important; +}
\ No newline at end of file diff --git a/system/admin/upload.go b/system/admin/upload/upload.go index 7f2a4fa..169bffe 100644 --- a/system/admin/upload.go +++ b/system/admin/upload/upload.go @@ -1,4 +1,4 @@ -package admin +package upload import ( "fmt" @@ -10,7 +10,8 @@ import ( "time" ) -func storeFileUploads(req *http.Request) (map[string]string, error) { +// StoreFiles stores file uploads at paths like /YYYY/MM/filename.ext +func StoreFiles(req *http.Request) (map[string]string, error) { err := req.ParseMultipartForm(1024 * 1024 * 4) // maxMemory 4MB if err != nil { return nil, fmt.Errorf("%s", err) diff --git a/system/api/analytics/init.go b/system/api/analytics/init.go new file mode 100644 index 0000000..c351bed --- /dev/null +++ b/system/api/analytics/init.go @@ -0,0 +1,101 @@ +// Package analytics provides the methods to run an analytics reporting system +// for API requests which may be useful to users for measuring access and +// possibly identifying bad actors abusing requests. +package analytics + +import ( + "log" + "net/http" + "strings" + "time" + + "github.com/boltdb/bolt" +) + +type apiRequest struct { + URL string `json:"url"` + Method string `json:"http_method"` + RemoteAddr string `json:"ip_address"` + Timestamp int64 `json:"timestamp"` + External bool `json:"external"` +} + +var ( + store *bolt.DB + recordChan chan apiRequest +) + +// Record queues an apiRequest for metrics +func Record(req *http.Request) { + external := strings.Contains(req.URL.Path, "/external/") + + r := apiRequest{ + URL: req.URL.String(), + Method: req.Method, + RemoteAddr: req.RemoteAddr, + Timestamp: time.Now().Unix() * 1000, + External: external, + } + + // put r on buffered recordChan to take advantage of batch insertion in DB + recordChan <- r + +} + +// Close exports the abillity to close our db file. Should be called with defer +// after call to Init() from the same place. +func Close() { + err := store.Close() + if err != nil { + log.Println(err) + } +} + +// Init creates a db connection, should run an initial prune of old data, and +// sets up the queue/batching channel +func Init() { + var err error + store, err = bolt.Open("analytics.db", 0666, nil) + if err != nil { + log.Fatalln(err) + } + + recordChan = make(chan apiRequest, 1024*128) + + go serve() + + err = store.Update(func(tx *bolt.Tx) error { + + return nil + }) + if err != nil { + log.Fatalln(err) + } +} + +func serve() { + // make timer to notify select to batch request insert from recordChan + // interval: 1 minute + apiRequestTimer := time.NewTicker(time.Minute * 1) + + // make timer to notify select to remove old analytics + // interval: 2 weeks + // TODO: enable analytics backup service to cloud + pruneDBTimer := time.NewTicker(time.Hour * 24 * 14) + + for { + select { + case <-apiRequestTimer.C: + var reqs []apiRequest + batchSize := len(recordChan) + + for i := 0; i < batchSize; i++ { + reqs = append(reqs, <-recordChan) + } + + case <-pruneDBTimer.C: + + default: + } + } +} diff --git a/system/api/external.go b/system/api/external.go new file mode 100644 index 0000000..4a42eb5 --- /dev/null +++ b/system/api/external.go @@ -0,0 +1,109 @@ +package api + +import ( + "fmt" + "log" + "net/http" + "strings" + "time" + + "github.com/bosssauce/ponzu/content" + "github.com/bosssauce/ponzu/system/admin/upload" + "github.com/bosssauce/ponzu/system/db" +) + +// Externalable accepts or rejects external POST requests to endpoints such as: +// /external/posts?type=Review +type Externalable interface { + // Accepts determines whether a type will allow external submissions + Accepts() bool +} + +// Mergeable allows external post content to be approved and published through +// the public-facing API +type Mergeable interface { + // Approve copies an external post to the internal collection and triggers + // a re-sort of its content type posts + Approve(req *http.Request) error +} + +func externalPostsHandler(res http.ResponseWriter, req *http.Request) { + if req.Method != http.MethodPost { + res.WriteHeader(http.StatusMethodNotAllowed) + return + } + + err := req.ParseMultipartForm(1024 * 1024 * 4) // maxMemory 4MB + if err != nil { + log.Println("[External] error:", err) + res.WriteHeader(http.StatusInternalServerError) + return + } + + t := req.URL.Query().Get("type") + if t == "" { + res.WriteHeader(http.StatusBadRequest) + return + } + + p, found := content.Types[t] + if !found { + log.Println("[External] attempt to submit unknown type:", t, "from:", req.RemoteAddr) + res.WriteHeader(http.StatusNotFound) + return + } + + post := p() + + ext, ok := post.(Externalable) + if !ok { + log.Println("[External] rejected non-externalable type:", t, "from:", req.RemoteAddr) + res.WriteHeader(http.StatusBadRequest) + return + } + + if ext.Accepts() { + ts := fmt.Sprintf("%d", time.Now().Unix()*1000) + req.PostForm.Set("timestamp", ts) + req.PostForm.Set("updated", ts) + + urlPaths, err := upload.StoreFiles(req) + if err != nil { + log.Println(err) + res.WriteHeader(http.StatusInternalServerError) + return + } + + for name, urlPath := range urlPaths { + req.PostForm.Add(name, urlPath) + } + + // check for any multi-value fields (ex. checkbox fields) + // and correctly format for db storage. Essentially, we need + // fieldX.0: value1, fieldX.1: value2 => fieldX: []string{value1, value2} + var discardKeys []string + for k, v := range req.PostForm { + if strings.Contains(k, ".") { + key := strings.Split(k, ".")[0] + + if req.PostForm.Get(key) == "" { + req.PostForm.Set(key, v[0]) + discardKeys = append(discardKeys, k) + } else { + req.PostForm.Add(key, v[0]) + } + } + } + + for _, discardKey := range discardKeys { + req.PostForm.Del(discardKey) + } + + _, err = db.SetContent(t+"_pending:-1", req.PostForm) + if err != nil { + log.Println("[External] error:", err) + res.WriteHeader(http.StatusInternalServerError) + return + } + } +} diff --git a/system/api/handlers.go b/system/api/handlers.go index 0c9139f..8356683 100644 --- a/system/api/handlers.go +++ b/system/api/handlers.go @@ -5,8 +5,11 @@ import ( "encoding/json" "log" "net/http" + "strconv" + "strings" "github.com/bosssauce/ponzu/content" + "github.com/bosssauce/ponzu/system/api/analytics" "github.com/bosssauce/ponzu/system/db" ) @@ -28,24 +31,72 @@ func typesHandler(res http.ResponseWriter, req *http.Request) { func postsHandler(res http.ResponseWriter, req *http.Request) { q := req.URL.Query() t := q.Get("type") - // TODO: implement pagination - // num := q.Get("num") - // page := q.Get("page") - - // TODO: inplement time-based ?after=time.Time, ?before=time.Time between=time.Time|time.Time - if t == "" { res.WriteHeader(http.StatusBadRequest) return } - posts := db.ContentAll(t) + count, err := strconv.Atoi(q.Get("count")) // int: determines number of posts to return (10 default, -1 is all) + if err != nil { + if q.Get("count") == "" { + count = 10 + } else { + res.WriteHeader(http.StatusInternalServerError) + return + } + } + + offset, err := strconv.Atoi(q.Get("offset")) // int: multiplier of count for pagination (0 default) + if err != nil { + if q.Get("offset") == "" { + offset = 0 + } else { + res.WriteHeader(http.StatusInternalServerError) + return + } + } + + order := strings.ToLower(q.Get("order")) // string: sort order of posts by timestamp ASC / DESC (DESC default) + if order != "asc" || order == "" { + order = "desc" + } + + // TODO: time-based ?after=time.Time, ?before=time.Time between=time.Time|time.Time + + posts := db.ContentAll(t + "_sorted") var all = []json.RawMessage{} for _, post := range posts { all = append(all, post) } - j, err := fmtJSON(all...) + var start, end int + switch count { + case -1: + start = 0 + end = len(posts) + + default: + start = count * offset + end = start + count + } + + // bounds check on posts given the start & end count + if start > len(posts) { + start = len(posts) - count + } + if end > len(posts) { + end = len(posts) + } + + // reverse the sorted order if ASC + if order == "asc" { + all = []json.RawMessage{} + for i := len(posts) - 1; i >= 0; i-- { + all = append(all, posts[i]) + } + } + + j, err := fmtJSON(all[start:end]...) if err != nil { res.WriteHeader(http.StatusInternalServerError) return @@ -150,6 +201,7 @@ func SendJSON(res http.ResponseWriter, j map[string]interface{}) { data, err = json.Marshal(j) if err != nil { + log.Println(err) data, _ = json.Marshal(map[string]interface{}{ "status": "fail", "message": err.Error(), @@ -159,9 +211,6 @@ func SendJSON(res http.ResponseWriter, j map[string]interface{}) { sendData(res, data, 200) } -// ResponseFunc ... -type ResponseFunc func(http.ResponseWriter, *http.Request) - // CORS wraps a HandleFunc to response to OPTIONS requests properly func CORS(next http.HandlerFunc) http.HandlerFunc { return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { @@ -173,3 +222,12 @@ func CORS(next http.HandlerFunc) http.HandlerFunc { next.ServeHTTP(res, req) }) } + +// Record wraps a HandleFunc to record API requests for analytical purposes +func Record(next http.HandlerFunc) http.HandlerFunc { + return http.HandlerFunc(func(res http.ResponseWriter, req *http.Request) { + go analytics.Record(req) + + next.ServeHTTP(res, req) + }) +} diff --git a/system/api/server.go b/system/api/server.go index da73382..816bc21 100644 --- a/system/api/server.go +++ b/system/api/server.go @@ -9,4 +9,6 @@ func Run() { http.HandleFunc("/api/posts", CORS(postsHandler)) http.HandleFunc("/api/post", CORS(postHandler)) + + http.HandleFunc("/api/external/posts", CORS(externalPostsHandler)) } diff --git a/system/db/content.go b/system/db/content.go index 2d6e8ea..bd1ee4b 100644 --- a/system/db/content.go +++ b/system/db/content.go @@ -38,13 +38,20 @@ func SetContent(target string, data url.Values) (int, error) { } func update(ns, id string, data url.Values) (int, error) { + var specifier string // i.e. _pending, _sorted, etc. + if strings.Contains(ns, "_") { + spec := strings.Split(ns, "_") + ns = spec[0] + specifier = "_" + spec[1] + } + cid, err := strconv.Atoi(id) if err != nil { return 0, err } err = store.Update(func(tx *bolt.Tx) error { - b, err := tx.CreateBucketIfNotExists([]byte(ns)) + b, err := tx.CreateBucketIfNotExists([]byte(ns + specifier)) if err != nil { return err } @@ -65,15 +72,24 @@ func update(ns, id string, data url.Values) (int, error) { return 0, nil } - go SortContent(ns) + if specifier == "" { + go SortContent(ns) + } return cid, nil } func insert(ns string, data url.Values) (int, error) { var effectedID int + var specifier string // i.e. _pending, _sorted, etc. + if strings.Contains(ns, "_") { + spec := strings.Split(ns, "_") + ns = spec[0] + specifier = "_" + spec[1] + } + err := store.Update(func(tx *bolt.Tx) error { - b, err := tx.CreateBucketIfNotExists([]byte(ns)) + b, err := tx.CreateBucketIfNotExists([]byte(ns + specifier)) if err != nil { return err } @@ -89,7 +105,7 @@ func insert(ns string, data url.Values) (int, error) { if err != nil { return err } - data.Add("id", cid) + data.Set("id", cid) j, err := postToJSON(ns, data) if err != nil { @@ -107,40 +123,11 @@ func insert(ns string, data url.Values) (int, error) { return 0, err } - go SortContent(ns) - - return effectedID, nil -} - -func postToJSON(ns string, data url.Values) ([]byte, error) { - // find the content type and decode values into it - t, ok := content.Types[ns] - if !ok { - return nil, fmt.Errorf(content.ErrTypeNotRegistered, ns) - } - post := t() - - dec := schema.NewDecoder() - dec.SetAliasTag("json") // allows simpler struct tagging when creating a content type - dec.IgnoreUnknownKeys(true) // will skip over form values submitted, but not in struct - err := dec.Decode(post, data) - if err != nil { - return nil, err - } - - slug, err := manager.Slug(post.(editor.Editable)) - if err != nil { - return nil, err - } - post.(editor.Editable).SetSlug(slug) - - // marshall content struct to json for db storage - j, err := json.Marshal(post) - if err != nil { - return nil, err + if specifier == "" { + go SortContent(ns) } - return j, nil + return effectedID, nil } // DeleteContent removes an item from the database. Deleting a non-existent item @@ -153,7 +140,6 @@ func DeleteContent(target string) error { tx.Bucket([]byte(ns)).Delete([]byte(id)) return nil }) - if err != nil { return err } @@ -178,7 +164,7 @@ func Content(target string) ([]byte, error) { b := tx.Bucket([]byte(ns)) _, err := val.Write(b.Get([]byte(id))) if err != nil { - fmt.Println(err) + log.Println(err) return err } @@ -197,8 +183,12 @@ func ContentAll(namespace string) [][]byte { store.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte(namespace)) - len := b.Stats().KeyN - posts = make([][]byte, 0, len) + if b == nil { + return nil + } + + numKeys := b.Stats().KeyN + posts = make([][]byte, 0, numKeys) b.ForEach(func(k, v []byte) error { posts = append(posts, v) @@ -216,10 +206,15 @@ func ContentAll(namespace string) [][]byte { // in descending order, from most recent to least recent // Should be called from a goroutine after SetContent is successful func SortContent(namespace string) { + // only sort main content types i.e. Post + if strings.Contains(namespace, "_") { + return + } + all := ContentAll(namespace) var posts sortablePosts - // decode each (json) into Editable + // decode each (json) into type to then sort for i := range all { j := all[i] post := content.Types[namespace]() @@ -238,18 +233,14 @@ func SortContent(namespace string) { // store in <namespace>_sorted bucket, first delete existing err := store.Update(func(tx *bolt.Tx) error { - err := tx.DeleteBucket([]byte(namespace + "_sorted")) + bname := []byte(namespace + "_sorted") + err := tx.DeleteBucket(bname) if err != nil { return err } - b, err := tx.CreateBucket([]byte(namespace + "_sorted")) + b, err := tx.CreateBucketIfNotExists(bname) if err != nil { - err := tx.Rollback() - if err != nil { - return err - } - return err } @@ -263,11 +254,6 @@ func SortContent(namespace string) { cid := fmt.Sprintf("%d:%d", i, posts[i].Time()) err = b.Put([]byte(cid), j) if err != nil { - err := tx.Rollback() - if err != nil { - return err - } - return err } } @@ -293,3 +279,35 @@ func (s sortablePosts) Less(i, j int) bool { func (s sortablePosts) Swap(i, j int) { s[i], s[j] = s[j], s[i] } + +func postToJSON(ns string, data url.Values) ([]byte, error) { + // find the content type and decode values into it + ns = strings.TrimSuffix(ns, "_external") + t, ok := content.Types[ns] + if !ok { + return nil, fmt.Errorf(content.ErrTypeNotRegistered, ns) + } + post := t() + + dec := schema.NewDecoder() + dec.SetAliasTag("json") // allows simpler struct tagging when creating a content type + dec.IgnoreUnknownKeys(true) // will skip over form values submitted, but not in struct + err := dec.Decode(post, data) + if err != nil { + return nil, err + } + + slug, err := manager.Slug(post.(editor.Editable)) + if err != nil { + return nil, err + } + post.(editor.Editable).SetSlug(slug) + + // marshall content struct to json for db storage + j, err := json.Marshal(post) + if err != nil { + return nil, err + } + + return j, nil +} diff --git a/system/db/init.go b/system/db/init.go index 1a5ed25..63804e1 100644 --- a/system/db/init.go +++ b/system/db/init.go @@ -13,12 +13,21 @@ import ( var store *bolt.DB +// Close exports the abillity to close our db file. Should be called with defer +// after call to Init() from the same place. +func Close() { + err := store.Close() + if err != nil { + log.Println(err) + } +} + // Init creates a db connection, initializes db with required info, sets secrets func Init() { var err error - store, err = bolt.Open("store.db", 0666, nil) + store, err = bolt.Open("system.db", 0666, nil) if err != nil { - log.Fatal(err) + log.Fatalln(err) } err = store.Update(func(tx *bolt.Tx) error { @@ -67,10 +76,9 @@ func Init() { return nil }) if err != nil { - log.Fatal("Coudn't initialize db with buckets.", err) + log.Fatalln("Coudn't initialize db with buckets.", err) } - // sort all content into type_sorted buckets go func() { for t := range content.Types { SortContent(t) @@ -99,7 +107,7 @@ func SystemInitComplete() bool { }) if err != nil { complete = false - log.Fatal(err) + log.Fatalln(err) } return complete diff --git a/system/db/user.go b/system/db/user.go index a3a0be3..d2dc3a9 100644 --- a/system/db/user.go +++ b/system/db/user.go @@ -4,15 +4,21 @@ import ( "bytes" "encoding/json" "errors" + "fmt" + "net/http" "github.com/bosssauce/ponzu/system/admin/user" "github.com/boltdb/bolt" + "github.com/nilslice/jwt" ) // ErrUserExists is used for the db to report to admin user of existing user var ErrUserExists = errors.New("Error. User exists.") +// ErrNoUserExists is used for the db to report to admin user of non-existing user +var ErrNoUserExists = errors.New("Error. No user exists.") + // SetUser sets key:value pairs in the db for user settings func SetUser(usr *user.User) (int, error) { err := store.Update(func(tx *bolt.Tx) error { @@ -38,7 +44,7 @@ func SetUser(usr *user.User) (int, error) { return err } - err = users.Put([]byte(usr.Email), j) + err = users.Put(email, j) if err != nil { return err } @@ -52,6 +58,65 @@ func SetUser(usr *user.User) (int, error) { return usr.ID, nil } +// UpdateUser sets key:value pairs in the db for existing user settings +func UpdateUser(usr, updatedUsr *user.User) error { + err := store.Update(func(tx *bolt.Tx) error { + users := tx.Bucket([]byte("_users")) + + // check if user is found by email, fail if nil + exists := users.Get([]byte(usr.Email)) + if exists == nil { + return ErrNoUserExists + } + + // marshal User to json and put into bucket + j, err := json.Marshal(updatedUsr) + if err != nil { + return err + } + + err = users.Put([]byte(updatedUsr.Email), j) + if err != nil { + return err + } + + // if email address was changed, delete the old record of former + // user with original email address + if usr.Email != updatedUsr.Email { + err = users.Delete([]byte(usr.Email)) + if err != nil { + return err + } + + } + + return nil + }) + if err != nil { + return err + } + + return nil +} + +// DeleteUser deletes a user from the db by email +func DeleteUser(email string) error { + err := store.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("_users")) + err := b.Delete([]byte(email)) + if err != nil { + return err + } + + return nil + }) + if err != nil { + return err + } + + return nil +} + // User gets the user by email from the db func User(email string) ([]byte, error) { val := &bytes.Buffer{} @@ -72,3 +137,50 @@ func User(email string) ([]byte, error) { return val.Bytes(), nil } + +// UserAll returns all users from the db +func UserAll() ([][]byte, error) { + var users [][]byte + err := store.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("_users")) + err := b.ForEach(func(k, v []byte) error { + users = append(users, v) + return nil + }) + if err != nil { + return err + } + + return nil + }) + if err != nil { + return nil, err + } + + return users, nil +} + +// CurrentUser extracts the user from the request data and returns the current user from the db +func CurrentUser(req *http.Request) ([]byte, error) { + if !user.IsValid(req) { + return nil, fmt.Errorf("Error. Invalid User.") + } + + token, err := req.Cookie("_token") + if err != nil { + return nil, err + } + + claims := jwt.GetClaims(token.Value) + email, ok := claims["user"] + if !ok { + return nil, fmt.Errorf("Error. No user data found in request token.") + } + + usr, err := User(email.(string)) + if err != nil { + return nil, err + } + + return usr, nil +} diff --git a/system/tls/enable.go b/system/tls/enable.go new file mode 100644 index 0000000..c53fac6 --- /dev/null +++ b/system/tls/enable.go @@ -0,0 +1,79 @@ +package tls + +import ( + "crypto/tls" + "fmt" + "log" + "net/http" + "os" + "path/filepath" + "time" + + "github.com/bosssauce/ponzu/system/db" + + "golang.org/x/crypto/acme/autocert" +) + +var m autocert.Manager + +// setup attempts to locate or create the cert cache directory and the certs for TLS encryption +func setup() { + pwd, err := os.Getwd() + if err != nil { + log.Fatalln("Couldn't find working directory to locate or save certificates.") + } + + cache := autocert.DirCache(filepath.Join(pwd, "system", "tls", "certs")) + if _, err := os.Stat(string(cache)); os.IsNotExist(err) { + err := os.MkdirAll(string(cache), os.ModePerm|os.ModeDir) + if err != nil { + log.Fatalln("Couldn't create cert directory at", cache) + } + } + + // get host/domain and email from Config to use for TLS request to Let's encryption. + // we will fail fatally if either are not found since Let's Encrypt will rate-limit + // and sending incomplete requests is wasteful and guarenteed to fail its check + host, err := db.Config("domain") + if err != nil { + log.Fatalln("Error identifying host/domain during TLS set-up.", err) + } + + if host == nil { + log.Fatalln("No 'domain' field set in Configuration. Please add a domain before attempting to make certificates.") + } + fmt.Println("Using", host, "as host/domain for certificate...") + fmt.Println("NOTE: if the host/domain is not configured properly or is unreachable, HTTPS set-up will fail.") + + email, err := db.Config("admin_email") + if err != nil { + log.Fatalln("Error identifying admin email during TLS set-up.", err) + } + + if email == nil { + log.Fatalln("No 'admin_email' field set in Configuration. Please add an admin email before attempting to make certificates.") + } + fmt.Println("Using", email, "as contact email for certificate...") + + m = autocert.Manager{ + Prompt: autocert.AcceptTOS, + Cache: cache, + HostPolicy: autocert.HostWhitelist(string(host)), + RenewBefore: time.Hour * 24 * 30, + Email: string(email), + } + +} + +// Enable runs the setup for creating or locating certificates and starts the TLS server +func Enable() { + setup() + + server := &http.Server{ + Addr: ":443", + TLSConfig: &tls.Config{GetCertificate: m.GetCertificate}, + } + + go log.Fatalln(server.ListenAndServeTLS("", "")) + fmt.Println("Server listening for HTTPS requests...") +} |