/*
 * Authors:
 *   Christian Heimes
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; version 2 of the License.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License along
 * with this program; if not, write to the Free Software Foundation, Inc.,
 * 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
 *
 * Copyright (C) 2015 Red Hat, Inc.
 * All rights reserved.
 *
 * Custodia entrypoint for Docker
 */
package main

import (
	"bufio"
	"encoding/json"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"net/http"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"syscall"
)

// PREFIX marks every environment variable consumed by this entrypoint; all
// of them are removed from the environment handed to the child process.
const PREFIX string = "CUSTODIA_"

// SECRET_PREFIX marks variables of the form CUSTODIA_SECRET_<NAME>=<path>:
// the secret stored at <path> is exported to the child as <NAME>.
const SECRET_PREFIX string = PREFIX + "SECRET_"

// BASE_PATH is the URL prefix under which secrets are requested.
const BASE_PATH string = "http+unix://localhost/secrets/"

/* Unix Domain Socket transport
 *
 */

// UDSTransport is an http.RoundTripper that sends each request over a fresh
// connection to the Unix domain socket at SocketPath. Only URLs of the form
// http+unix://localhost/... are accepted.
type UDSTransport struct {
	SocketPath string
}

// RoundTrip implements http.RoundTripper. It validates the request, dials the
// socket, writes the request and parses the response. The connection is
// closed when the caller closes the returned response body, or immediately if
// any step fails.
func (uds *UDSTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	// Cases are evaluated in order, so req.URL is known to be non-nil below.
	switch {
	case req.URL == nil:
		return nil, rejectRequest(req, "uds: nil Request.URL")
	case req.Header == nil:
		return nil, rejectRequest(req, "uds: nil Request.Header")
	case req.URL.Scheme != "http+unix":
		return nil, rejectRequest(req, "uds: unsupported protocol scheme "+req.URL.Scheme)
	case req.URL.Host != "localhost":
		return nil, rejectRequest(req, "uds: Host must be 'localhost'")
	}

	conn, err := net.Dial("unix", uds.SocketPath)
	if err != nil {
		// The RoundTripper contract requires closing the body even on error.
		closeRequestBody(req)
		return nil, err
	}
	// Request.Write closes req.Body itself.
	if err := req.Write(conn); err != nil {
		conn.Close()
		return nil, err
	}
	resp, err := http.ReadResponse(bufio.NewReader(conn), req)
	if err != nil {
		conn.Close()
		return nil, err
	}
	// One connection per request: tie the socket's lifetime to the response
	// body so closing the body releases the connection instead of leaking it.
	resp.Body = &connClosingBody{ReadCloser: resp.Body, conn: conn}
	return resp, nil
}

// rejectRequest closes req's body and returns an error carrying msg.
func rejectRequest(req *http.Request, msg string) error {
	closeRequestBody(req)
	return errors.New(msg)
}

// closeRequestBody closes req.Body if there is one.
func closeRequestBody(req *http.Request) {
	if req.Body != nil {
		req.Body.Close()
	}
}

// connClosingBody wraps a response body and also closes the connection it was
// read from when the body is closed.
type connClosingBody struct {
	io.ReadCloser
	conn net.Conn
}

// Close closes the wrapped body and then the connection, returning the first
// error encountered.
func (b *connClosingBody) Close() error {
	err := b.ReadCloser.Close()
	if cerr := b.conn.Close(); err == nil {
		err = cerr
	}
	return err
}

// UnixClient returns an http.Client that routes http+unix:// URLs to the Unix
// domain socket at socketPath.
func UnixClient(socketPath string) *http.Client {
	transport := new(http.Transport)
	transport.RegisterProtocol("http+unix", &UDSTransport{SocketPath: socketPath})
	return &http.Client{Transport: transport}
}

/* Custodia client
 */

// CustodiaMessage is the JSON document returned by the Custodia server for a
// secret.
type CustodiaMessage struct {
	Type  string `json:"type"`
	Value string `json:"value"`
}

// CustodiaSecret describes one secret requested via the environment.
type CustodiaSecret struct {
	Name    string // variable name to export to the child process
	Value   string // secret path, taken from CUSTODIA_SECRET_<Name>
	Secret  string // secret value returned by the server
	Fetched bool   // true once a non-empty Secret has been received
}

// CustodiaClient collects secret requests from the environment, fetches them
// from a Custodia server over a Unix socket, and builds the child environment.
type CustodiaClient struct {
	SocketPath   string
	BasePath     string
	Prefix       string
	SecretPrefix string
	RemoteUser   string
	Secrets      []*CustodiaSecret
}

// NewCustodiaClient returns a client for the server listening on socketpath.
// If remoteuser is non-empty it is sent with every request as REMOTE_USER.
func NewCustodiaClient(socketpath, remoteuser string) *CustodiaClient {
	return &CustodiaClient{
		SocketPath:   socketpath,
		BasePath:     BASE_PATH,
		Prefix:       PREFIX,
		SecretPrefix: SECRET_PREFIX,
		RemoteUser:   remoteuser,
		Secrets:      []*CustodiaSecret{},
	}
}

// FindEnvs appends a CustodiaSecret to cc.Secrets for every environment
// variable named SecretPrefix+NAME. Entries with an empty NAME are skipped,
// because exporting them would produce a malformed "=value" entry.
func (cc *CustodiaClient) FindEnvs() {
	for _, env := range os.Environ() {
		if !strings.HasPrefix(env, cc.SecretPrefix) {
			continue
		}
		pair := strings.SplitN(env[len(cc.SecretPrefix):], "=", 2)
		if len(pair) != 2 || pair[0] == "" {
			log.Printf("ignoring %q: no secret name after %s", env, cc.SecretPrefix)
			continue
		}
		cc.Secrets = append(cc.Secrets, &CustodiaSecret{
			Name:    pair[0],
			Value:   pair[1],
			Secret:  "",
			Fetched: false,
		})
	}
}

// QueryCustodia fetches every secret in cc.Secrets from the server. Any
// transport, HTTP status or JSON error is fatal and terminates the process.
func (cc *CustodiaClient) QueryCustodia() {
	if len(cc.Secrets) == 0 {
		return
	}
	client := UnixClient(cc.SocketPath)
	for _, sec := range cc.Secrets {
		if err := cc.fetchSecret(client, sec); err != nil {
			log.Fatal(err)
		}
	}
}

// fetchSecret requests a single secret and, if the server returns a non-empty
// value, stores it in sec. It lives in its own function so the response body
// is closed after each request rather than when the whole loop finishes.
func (cc *CustodiaClient) fetchSecret(client *http.Client, sec *CustodiaSecret) error {
	secretURL := cc.BasePath + sec.Value
	req, err := http.NewRequest("GET", secretURL, nil)
	if err != nil {
		return err
	}
	if cc.RemoteUser != "" {
		req.Header.Add("REMOTE_USER", cc.RemoteUser)
	}
	resp, err := client.Do(req)
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return fmt.Errorf("%v: %v", secretURL, resp.Status)
	}
	var m CustodiaMessage
	if err := json.NewDecoder(resp.Body).Decode(&m); err != nil {
		return fmt.Errorf("%v json error: %v", secretURL, err)
	}
	if len(m.Value) > 0 {
		sec.Secret = m.Value
		sec.Fetched = true
	}
	return nil
}

// MakeEnv returns the environment for the child process. It contains the
// current environment without any Prefix variables, plus NAME=secret for every
// fetched secret. An existing variable with the same name as a fetched secret
// is dropped, so the secret is not shadowed by it (libc getenv returns the
// first match).
func (cc CustodiaClient) MakeEnv() []string {
	overridden := make(map[string]bool, len(cc.Secrets))
	for _, sec := range cc.Secrets {
		if sec.Fetched {
			overridden[sec.Name] = true
		}
	}

	environ := []string{}
	for _, env := range os.Environ() {
		if strings.HasPrefix(env, cc.Prefix) || overridden[envName(env)] {
			continue
		}
		environ = append(environ, env)
	}
	for _, sec := range cc.Secrets {
		if sec.Fetched {
			environ = append(environ, sec.Name+"="+sec.Secret)
		}
	}
	return environ
}

// envName returns the variable name of a "NAME=value" environment entry.
func envName(env string) string {
	if i := strings.IndexByte(env, '='); i >= 0 {
		return env[:i]
	}
	return env
}

// Debug prints the collected secrets (including fetched values) and the
// resulting child environment to stdout.
func (cc CustodiaClient) Debug() {
	fmt.Printf("%v secret(s):\n", len(cc.Secrets))
	for _, sec := range cc.Secrets {
		fmt.Printf("  %+v\n", *sec)
	}
	fmt.Println("Environ:")
	for _, env := range cc.MakeEnv() {
		fmt.Printf("  %v\n", env)
	}
	fmt.Println()
}

// main fetches the requested secrets and then replaces this process with
// os.Args[1:], running in the prepared environment.
func main() {
	debug := os.Getenv("CUSTODIA_DEBUG") != ""

	if len(os.Args) < 2 {
		log.Fatalf("%s entrypoint [args]", os.Args[0])
	}

	socketpath := os.Getenv("CUSTODIA_SOCKET")
	if socketpath == "" {
		// default socket is in the same directory as program
		socketpath = filepath.Join(filepath.Dir(os.Args[0]), "server_socket")
	}
	remoteuser := os.Getenv("CUSTODIA_REMOTE_USER")

	client := NewCustodiaClient(socketpath, remoteuser)
	if debug {
		fmt.Printf("%+v\n", client)
	}
	client.FindEnvs()
	client.QueryCustodia()
	environ := client.MakeEnv()
	if debug {
		client.Debug()
	}

	args := os.Args[1:]
	// execve does not search PATH; resolve bare command names (the usual
	// Docker CMD form) the way a shell would.
	argv0, err := exec.LookPath(args[0])
	if err != nil {
		log.Fatalf("%s: %v", args[0], err)
	}
	// syscall.Exec only returns on failure; never exit 0 silently.
	if err := syscall.Exec(argv0, args, environ); err != nil {
		log.Fatalf("exec %s: %v", argv0, err)
	}
}