467 lines
15 KiB
Go
467 lines
15 KiB
Go
package command
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"gitea.illuad.fr/adrien/middleman"
|
|
"gitea.illuad.fr/adrien/middleman/flag"
|
|
"github.com/rs/zerolog/log"
|
|
"github.com/urfave/cli/v3"
|
|
"golang.org/x/sync/errgroup"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/netip"
|
|
"net/url"
|
|
"regexp"
|
|
"strings"
|
|
"time"
|
|
)
|
|
|
|
// serve is the structure representing the command described below.
|
|
type serve struct {
|
|
group *errgroup.Group
|
|
}
|
|
|
|
// addrPorts groups the parsed listen addresses of the HTTP servers.
type addrPorts struct {
	// srvAddrPort is the address and port of the proxy server.
	srvAddrPort netip.AddrPort
	// healthAddrPort is the address and port of the health endpoint server.
	healthAddrPort netip.AddrPort
}
|
|
|
|
// methodRegex associates an HTTP method (GET, POST, DELETE...) with the compiled
// regular expression that request paths using this method are matched against.
type methodRegex = map[string]*regexp.Regexp
|
|
|
|
// ProxyHandler is an HTTP handler that forwards incoming requests to another server
// and relays the response back to the client.
type ProxyHandler struct {
	// rp is the reverse proxy the requests are handed to.
	rp *httputil.ReverseProxy
}
|
|
|
|
// ErrHTTPMethodNotAllowed is returned if the request's HTTP method is not allowed for this container.
type ErrHTTPMethodNotAllowed struct {
	// httpMethod is the rejected HTTP method.
	httpMethod string
}

// ErrNoMatch is returned if the request's URL path is not allowed for this container.
type ErrNoMatch struct {
	// path is the rejected URL path; httpMethod is the HTTP method it was requested with.
	path, httpMethod string
}

// Error implements the error interface. It names the HTTP method that was rejected.
func (e ErrHTTPMethodNotAllowed) Error() string {
	return fmt.Sprintf("%s is not in the list of HTTP methods allowed for this container", e.httpMethod)
}

// Error implements the error interface. It names the rejected path and the HTTP method
// whose regular expression it failed to match.
func (e ErrNoMatch) Error() string {
	return fmt.Sprintf("%s does not match any registered regular expression for this HTTP method (%s)", e.path, e.httpMethod)
}
|
|
|
|
// ServeCmd is the name under which the serve command is registered.
const ServeCmd = "serve"
|
|
|
|
// Names of the flags accepted by the serve command.
const (
	listenAddrFlagName                = "listen-addr"                  // listen address and port of the proxy
	dockerSocketPathFlagName          = "docker-socket-path"           // path of the Docker unix socket
	shutdownTimeoutFlagName           = "shutdown-timeout"             // server shutdown timeout
	noDockerSocketHealthcheckFlagName = "no-docker-socket-healthcheck" // disables the Docker unix socket healthcheck
	noShutdownFlagName                = "no-shutdown"                  // disables the application shutdown when the Docker unix socket healthcheck fails
	healthcheckRetryFlagName          = "healthcheck-retry"            // number of dial attempts with the Docker unix socket before the application shutdown (if enabled)
	healthEndpointFlagName            = "health-endpoint"              // health endpoint
	healthListenAddrFlagName          = "health-listen-addr"           // listen address and port of the health endpoint
	addRequestsFlagName               = "add-requests"                 // adds allowed requests
)
|
|
|
|
// Default values of the serve command flags.
const (
	defaultListenAddr              = "0.0.0.0:2375"                        // proxy listen address and port
	defaultDockerSocketPath        = "/var/run/docker.sock"                // Docker socket path
	defaultShutdownTimeout         = 5 * time.Second                       // server shutdown timeout
	defaultHealthcheckRetry uint64 = 3                                     // number of healthcheck retries
	defaultHealthEndpoint          = "/healthz"                            // health endpoint
	defaultHealthListenAddr        = "127.0.0.1:5732"                      // health endpoint listen address and port
	defaultAllowedRequest          = "^/(version|containers/.*|events.*)$" // regular expression of the paths allowed by default
)
|
|
|
|
// s is the package-level serve instance; Serve stores the errgroup in it and binds its action.
var s serve
|
|
|
|
var (
|
|
containerMethodRegex = map[string]methodRegex{
|
|
"*": {
|
|
http.MethodGet: regexp.MustCompile(defaultAllowedRequest),
|
|
},
|
|
}
|
|
applicatorURLRegex = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9_.-]+)\(((?:(?:GET|HEAD|POST|PUT|PATCH|DELETE|CONNECT|TRACE|OPTIONS)(?:,(?:GET|HEAD|POST|PUT|PATCH|DELETE|CONNECT|TRACE|OPTIONS))*)?)\):(.*)$`)
|
|
validHTTPMethods = map[string]struct{}{
|
|
http.MethodGet: {},
|
|
http.MethodHead: {},
|
|
http.MethodPost: {},
|
|
http.MethodPut: {},
|
|
http.MethodPatch: {},
|
|
http.MethodDelete: {},
|
|
http.MethodConnect: {},
|
|
http.MethodTrace: {},
|
|
http.MethodOptions: {},
|
|
}
|
|
containerNameRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]+`)
|
|
httpMethodsRegex = regexp.MustCompile(`(?:GET|HEAD|POST|PUT|PATCH|DELETE|CONNECT|TRACE|OPTIONS)(?:GET|HEAD|POST|PUT|PATCH|DELETE|CONNECT|TRACE|OPTIONS)*`)
|
|
)
|
|
|
|
// Serve describes the serve command.
|
|
func Serve(group *errgroup.Group) *cli.Command {
|
|
s.group = group
|
|
return &cli.Command{
|
|
Name: ServeCmd,
|
|
Aliases: middleman.PluckAlias(ServeCmd, ServeCmd),
|
|
Usage: "Runs serve mode",
|
|
Description: "Proxy requests to the Docker socket using defined access control",
|
|
Flags: []cli.Flag{
|
|
listenAddr(),
|
|
dockerSocketPath(),
|
|
shutdownTimeout(),
|
|
noDockerSocketHealthcheck(),
|
|
noShutdown(),
|
|
healthcheckRetry(),
|
|
healthEndpoint(),
|
|
healthListenAddr(),
|
|
addRequests(),
|
|
flag.LogFormat(ServeCmd),
|
|
flag.LogLevel(ServeCmd),
|
|
},
|
|
Before: before,
|
|
Action: s.action,
|
|
DisableSliceFlagSeparator: true,
|
|
}
|
|
}
|
|
|
|
func (ph *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
|
log.Debug().Str("http_method", r.Method).Str("path", r.URL.Path).Msg("incoming request")
|
|
mr, ok := containerMethodRegex["*"]
|
|
if ok {
|
|
if err := checkMethodPath(w, r, mr); err != nil {
|
|
log.Err(err).Send()
|
|
return
|
|
}
|
|
} else {
|
|
var (
|
|
containerName string
|
|
host, _, _ = net.SplitHostPort(r.RemoteAddr)
|
|
ip = net.ParseIP(host)
|
|
)
|
|
for containerName, mr = range containerMethodRegex {
|
|
resolvedIPs, err := net.LookupIP(containerName)
|
|
if err != nil {
|
|
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
|
return
|
|
}
|
|
for _, resolvedIP := range resolvedIPs {
|
|
if resolvedIP.Equal(ip) {
|
|
if err = checkMethodPath(w, r, mr); err != nil {
|
|
log.Err(err).Send()
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
ph.rp.ServeHTTP(w, r)
|
|
}
|
|
|
|
// checkMethodPath executes the regular expression on the path of the HTTP request if and only if
|
|
// the latter's HTTP method is actually present in the list of authorized HTTP methods.
|
|
func checkMethodPath(w http.ResponseWriter, r *http.Request, mr methodRegex) error {
|
|
req, ok := mr[r.Method]
|
|
if !ok {
|
|
http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
|
|
return ErrHTTPMethodNotAllowed{httpMethod: r.Method}
|
|
}
|
|
if !req.MatchString(r.URL.Path) {
|
|
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
|
|
return ErrNoMatch{
|
|
path: r.URL.Path,
|
|
httpMethod: r.Method,
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// action is executed when the ServeCmd command is called.
|
|
func (s serve) action(ctx context.Context, command *cli.Command) error {
|
|
if err := ctx.Err(); err != nil {
|
|
return err
|
|
}
|
|
ap, err := parseAddrPorts(command.String(listenAddrFlagName), command.String(healthListenAddrFlagName))
|
|
if err != nil {
|
|
return err
|
|
}
|
|
var l net.Listener
|
|
l, err = net.Listen("tcp", ap.srvAddrPort.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
dummyURL, _ := url.Parse("http://dummy")
|
|
rp := httputil.NewSingleHostReverseProxy(dummyURL)
|
|
rp.Transport = &http.Transport{
|
|
DialContext: func(_ context.Context, _ string, _ string) (net.Conn, error) {
|
|
return net.Dial("unix", command.String(dockerSocketPathFlagName))
|
|
},
|
|
}
|
|
srv := &http.Server{ // #nosec: G112
|
|
Handler: &ProxyHandler{rp: rp},
|
|
}
|
|
s.group.Go(func() error {
|
|
if err = srv.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
retry := command.Uint(healthcheckRetryFlagName)
|
|
if !command.Bool(noDockerSocketHealthcheckFlagName) {
|
|
ticker := time.Tick(2 * time.Second)
|
|
s.group.Go(func() error {
|
|
loop:
|
|
for {
|
|
if err = unixDial(command.String(dockerSocketPathFlagName)); err != nil {
|
|
if !command.Bool(noShutdownFlagName) {
|
|
if retry == 0 {
|
|
return err
|
|
}
|
|
log.Err(err).Uint64("retry_remaining", retry).Send()
|
|
retry--
|
|
}
|
|
}
|
|
select {
|
|
case <-ticker:
|
|
continue
|
|
case <-ctx.Done():
|
|
break loop
|
|
}
|
|
}
|
|
return ctx.Err()
|
|
})
|
|
}
|
|
mux := http.NewServeMux()
|
|
mux.HandleFunc(command.String(healthEndpointFlagName), func(w http.ResponseWriter, r *http.Request) {
|
|
log.Trace().Str("from", r.RemoteAddr).Msg("incoming healthcheck request")
|
|
if r.Method != http.MethodHead {
|
|
w.WriteHeader(http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
if err = unixDial(command.String(dockerSocketPathFlagName)); err != nil {
|
|
w.WriteHeader(http.StatusServiceUnavailable)
|
|
}
|
|
})
|
|
healthSrv := http.Server{
|
|
Addr: ap.healthAddrPort.String(),
|
|
Handler: mux,
|
|
ReadTimeout: 3 * time.Second,
|
|
ReadHeaderTimeout: 3 * time.Second,
|
|
WriteTimeout: 3 * time.Second,
|
|
}
|
|
s.group.Go(func() error {
|
|
if err = healthSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
|
return err
|
|
}
|
|
return nil
|
|
})
|
|
log.Info().
|
|
Stringer("listen_addr", ap.srvAddrPort).
|
|
Str("docker_socket_path", command.String(dockerSocketPathFlagName)).
|
|
Stringer("shutdown_timeout", command.Duration(shutdownTimeoutFlagName)).
|
|
Bool("docker_socket_healthcheck_disabled", command.Bool(noDockerSocketHealthcheckFlagName)).
|
|
Bool("shutdown_on_failure_disabled", command.Bool(noShutdownFlagName)).
|
|
Uint64("number_of_healthcheck_retry", command.Uint(healthcheckRetryFlagName)).
|
|
Str("health_endpoint", command.String(healthEndpointFlagName)).
|
|
Stringer("health_endpoint_listen_addr", ap.healthAddrPort).
|
|
Strs("requests", command.StringSlice(addRequestsFlagName)).
|
|
Str("log_format", command.String(flag.LogFormatFlagName)).
|
|
Str("log_level", command.String(flag.LogLevelFlagName)).
|
|
Msg("middleman started")
|
|
<-ctx.Done()
|
|
shutdownCtx, cancel := context.WithTimeout(context.Background(), command.Duration(shutdownTimeoutFlagName))
|
|
defer cancel()
|
|
if err = healthSrv.Close(); err != nil {
|
|
log.Err(err).Send()
|
|
}
|
|
return srv.Shutdown(shutdownCtx)
|
|
}
|
|
|
|
func parseAddrPorts(srvAddrPort, healthAddrPort string) (addrPorts, error) {
|
|
var ap addrPorts
|
|
if err := parseAddrPort(srvAddrPort, &ap.srvAddrPort); err != nil {
|
|
return ap, err
|
|
}
|
|
return ap, parseAddrPort(healthAddrPort, &ap.healthAddrPort)
|
|
}
|
|
|
|
// parseAddrPort parses addrPortStr as an "address:port" pair and stores the result in
// addrPort. On a parsing error, addrPort is left untouched and the error is returned.
func parseAddrPort(addrPortStr string, addrPort *netip.AddrPort) error {
	parsed, err := netip.ParseAddrPort(addrPortStr)
	if err == nil {
		*addrPort = parsed
	}
	return err
}
|
|
|
|
// unixDial checks that the unix socket at socketPath accepts connections: it dials it
// with a three-second timeout and closes the connection right away. It returns the dial
// error, or the error met while closing the connection.
func unixDial(socketPath string) error {
	conn, err := net.DialTimeout("unix", socketPath, 3*time.Second)
	if err == nil {
		err = conn.Close()
	}
	return err
}
|
|
|
|
func listenAddr() cli.Flag {
|
|
return &cli.StringFlag{
|
|
Name: listenAddrFlagName,
|
|
Category: "network",
|
|
Usage: "proxy listen address",
|
|
Sources: middleman.PluckEnvVar(EnvVarPrefix, listenAddrFlagName),
|
|
Value: defaultListenAddr,
|
|
Aliases: middleman.PluckAlias(ServeCmd, listenAddrFlagName),
|
|
}
|
|
}
|
|
|
|
func dockerSocketPath() cli.Flag {
|
|
return &cli.StringFlag{
|
|
Name: dockerSocketPathFlagName,
|
|
Category: "docker",
|
|
Usage: "Docker unix socket path",
|
|
Sources: middleman.PluckEnvVar(EnvVarPrefix, dockerSocketPathFlagName),
|
|
Value: defaultDockerSocketPath,
|
|
Aliases: middleman.PluckAlias(ServeCmd, dockerSocketPathFlagName),
|
|
}
|
|
}
|
|
|
|
func shutdownTimeout() cli.Flag {
|
|
return &cli.DurationFlag{
|
|
Name: shutdownTimeoutFlagName,
|
|
Category: "network",
|
|
Usage: "server shutdown timeout",
|
|
Sources: middleman.PluckEnvVar(EnvVarPrefix, shutdownTimeoutFlagName),
|
|
Value: defaultShutdownTimeout,
|
|
Aliases: middleman.PluckAlias(ServeCmd, shutdownTimeoutFlagName),
|
|
}
|
|
}
|
|
|
|
func noDockerSocketHealthcheck() cli.Flag {
|
|
return &cli.BoolFlag{
|
|
Name: noDockerSocketHealthcheckFlagName,
|
|
Category: "docker",
|
|
HideDefault: true,
|
|
Usage: "disable Docker unix socket healthcheck",
|
|
Sources: middleman.PluckEnvVar(EnvVarPrefix, noDockerSocketHealthcheckFlagName),
|
|
Aliases: middleman.PluckAlias(ServeCmd, noDockerSocketHealthcheckFlagName),
|
|
}
|
|
}
|
|
|
|
func noShutdown() cli.Flag {
|
|
return &cli.BoolFlag{
|
|
Name: noShutdownFlagName,
|
|
Category: "internal monitoring behavior",
|
|
HideDefault: true,
|
|
Usage: "disable application shutdown in case of Docker unix socket healthcheck failure",
|
|
Sources: middleman.PluckEnvVar(EnvVarPrefix, noShutdownFlagName),
|
|
Aliases: middleman.PluckAlias(ServeCmd, noShutdownFlagName),
|
|
}
|
|
}
|
|
|
|
func healthcheckRetry() cli.Flag {
|
|
return &cli.UintFlag{
|
|
Name: healthcheckRetryFlagName,
|
|
Category: "internal monitoring behavior",
|
|
Usage: "number of dial attempts with the Docker unix socket before the application shutdown (if enabled)",
|
|
Sources: middleman.PluckEnvVar(EnvVarPrefix, healthcheckRetryFlagName),
|
|
Value: defaultHealthcheckRetry,
|
|
Aliases: middleman.PluckAlias(ServeCmd, healthcheckRetryFlagName),
|
|
}
|
|
}
|
|
|
|
func healthEndpoint() cli.Flag {
|
|
return &cli.StringFlag{
|
|
Name: healthEndpointFlagName,
|
|
Category: "monitoring",
|
|
Usage: "health endpoint",
|
|
Sources: middleman.PluckEnvVar(EnvVarPrefix, healthEndpointFlagName),
|
|
Value: defaultHealthEndpoint,
|
|
Aliases: middleman.PluckAlias(ServeCmd, healthEndpointFlagName),
|
|
}
|
|
}
|
|
|
|
func healthListenAddr() cli.Flag {
|
|
return &cli.StringFlag{
|
|
Name: healthListenAddrFlagName,
|
|
Category: "monitoring",
|
|
Usage: "proxy health endpoint listen address",
|
|
Sources: middleman.PluckEnvVar(EnvVarPrefix, healthListenAddrFlagName),
|
|
Value: defaultHealthListenAddr,
|
|
Aliases: middleman.PluckAlias(ServeCmd, healthListenAddrFlagName),
|
|
}
|
|
}
|
|
|
|
func addRequests() cli.Flag {
|
|
return &cli.StringSliceFlag{
|
|
Name: addRequestsFlagName,
|
|
Category: "network",
|
|
Usage: "add requests",
|
|
Sources: middleman.PluckEnvVar(EnvVarPrefix, addRequestsFlagName),
|
|
Local: true, // Required to trigger the Action when this flag is set via the environment variable, see https://github.com/urfave/cli/issues/2041.
|
|
Value: []string{"*:" + defaultAllowedRequest},
|
|
Aliases: middleman.PluckAlias(ServeCmd, addRequestsFlagName),
|
|
Action: func(ctx context.Context, command *cli.Command, requests []string) error {
|
|
clear(containerMethodRegex)
|
|
for _, request := range requests {
|
|
// An applicator is a container name/HTTP method pair e.g., nginx(GET).
|
|
// The following regex extracts this applicator and the associated URL regex.
|
|
aur := applicatorURLRegex.FindString(request)
|
|
if aur == "" {
|
|
// If we are here, it means that user set a wildcard (kinda) applicator e.g., GET,POST:URL_REGEX.
|
|
// In its extended form, this applicator can be read as follows: *(GET,POST):URL_REGEX.
|
|
methods, urlRegex, ok := strings.Cut(request, ":")
|
|
if !ok { // || methods == ""?
|
|
return errors.New("HTTP method(s) must be specified before ':'")
|
|
}
|
|
if err := registerMethodRegex("*", urlRegex, strings.Split(methods, ",")); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
applicator, urlRegex, _ := strings.Cut(aur, ":")
|
|
if err := registerMethodRegex(containerNameRegex.FindString(applicator), urlRegex, httpMethodsRegex.FindAllString(applicator, -1)); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
return nil
|
|
},
|
|
}
|
|
}
|
|
|
|
func registerMethodRegex(containerName, urlRegex string, httpMethods []string) error {
|
|
r, err := regexp.Compile(urlRegex)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
for _, httpMethod := range httpMethods {
|
|
if _, ok := validHTTPMethods[httpMethod]; !ok {
|
|
return fmt.Errorf("%s is not a valid HTTP method", httpMethod)
|
|
}
|
|
if containerMethodRegex[containerName] == nil {
|
|
containerMethodRegex[containerName] = make(methodRegex)
|
|
}
|
|
containerMethodRegex[containerName][httpMethod] = r
|
|
}
|
|
return nil
|
|
}
|