package command import ( "context" "errors" "fmt" "gitea.illuad.fr/adrien/middleman" "gitea.illuad.fr/adrien/middleman/flag" "github.com/rs/zerolog/log" "github.com/urfave/cli/v3" "golang.org/x/sync/errgroup" "net" "net/http" "net/http/httputil" "net/netip" "net/url" "regexp" "strings" "time" ) // serve is the structure representing the command described below. type serve struct { group *errgroup.Group } // addrPorts holds netip.AddrPort of different HTTP servers. type addrPorts struct { srvAddrPort, healthAddrPort netip.AddrPort } // methodRegex maps HTTP methods (GET, POST, DELETE...) to a compiled regular expression. type methodRegex = map[string]*regexp.Regexp // ProxyHandler takes an incoming request and sends it to another server, proxying the response back to the client. type ProxyHandler struct { rp *httputil.ReverseProxy } // ErrHTTPMethodNotAllowed is returned if the request's HTTP method is not allowed for this container. type ErrHTTPMethodNotAllowed struct { httpMethod string } // ErrNoMatch is returned if the request's URL path is not allowed for this container. type ErrNoMatch struct { path, httpMethod string } func (e ErrHTTPMethodNotAllowed) Error() string { return fmt.Sprintf("%s is not in the list of HTTP methods allowed this container", e.httpMethod) } func (e ErrNoMatch) Error() string { return fmt.Sprintf("%s does not match any registered regular expression for this HTTP method (%s)", e.path, e.httpMethod) } // ServeCmd is the command name. const ServeCmd = "serve" const ( // listenAddrFlagName is the flag name used to set the listen address and port. listenAddrFlagName = "listen-addr" // dockerSocketPathFlagName is the flag name used to set the Docker unix socket path. dockerSocketPathFlagName = "docker-socket-path" // shutdownTimeoutFlagName is the flag name used to set the server shutdown timeout. shutdownTimeoutFlagName = "shutdown-timeout" // noDockerSocketHealthcheckFlagName is the flag name used to disable the Docker unix socket healthcheck. noDockerSocketHealthcheckFlagName = "no-docker-socket-healthcheck" // noShutdownFlagName is the flag name used to disable the application shutdown in case of Docker unix socket healthcheck failure. noShutdownFlagName = "no-shutdown" // healthcheckRetryFlagName is the flag name used to set the number of dial attempts with the Docker unix socket before the application shutdown (if enabled). healthcheckRetryFlagName = "healthcheck-retry" // healthEndpointFlagName is the flag name used to set the health endpoint. healthEndpointFlagName = "health-endpoint" // healthListenAddrFlagName is the flag name used to set the health endpoint listen address and port. healthListenAddrFlagName = "health-listen-addr" // addRequestsFlagName is the flag name used to add allowed requests. addRequestsFlagName = "add-requests" ) const ( // defaultListenAddr is the proxy default listen address and port. defaultListenAddr = "0.0.0.0:2375" // defaultDockerSocketPath is the default Docker socket path. defaultDockerSocketPath = "/var/run/docker.sock" // defaultShutdownTimeout is the default server shutdown timeout. defaultShutdownTimeout = 5 * time.Second // defaultHealthcheckRetry is the default number of healthcheck retry defaultHealthcheckRetry uint64 = 3 // defaultHealthEndpoint is the default health endpoint. defaultHealthEndpoint = "/healthz" // defaultHealthListenAddr is the proxy health endpoint default listen address and port. defaultHealthListenAddr = "127.0.0.1:5732" // defaultAllowedRequest is the proxy default allowed request. defaultAllowedRequest = "^/(version|containers/.*|events.*)$" ) var s serve var ( containerMethodRegex = map[string]methodRegex{ "*": { http.MethodGet: regexp.MustCompile(defaultAllowedRequest), }, } applicatorURLRegex = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9_.-]+)\(((?:(?:GET|HEAD|POST|PUT|PATCH|DELETE|CONNECT|TRACE|OPTIONS)(?:,(?:GET|HEAD|POST|PUT|PATCH|DELETE|CONNECT|TRACE|OPTIONS))*)?)\):(.*)$`) validHTTPMethods = map[string]struct{}{ http.MethodGet: {}, http.MethodHead: {}, http.MethodPost: {}, http.MethodPut: {}, http.MethodPatch: {}, http.MethodDelete: {}, http.MethodConnect: {}, http.MethodTrace: {}, http.MethodOptions: {}, } containerNameRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]+`) httpMethodsRegex = regexp.MustCompile(`(?:GET|HEAD|POST|PUT|PATCH|DELETE|CONNECT|TRACE|OPTIONS)(?:GET|HEAD|POST|PUT|PATCH|DELETE|CONNECT|TRACE|OPTIONS)*`) ) // Serve describes the serve command. func Serve(group *errgroup.Group) *cli.Command { s.group = group return &cli.Command{ Name: ServeCmd, Aliases: middleman.PluckAlias(ServeCmd, ServeCmd), Usage: "Runs serve mode", Description: "Proxy requests to the Docker socket using defined access control", Flags: []cli.Flag{ listenAddr(), dockerSocketPath(), shutdownTimeout(), noDockerSocketHealthcheck(), noShutdown(), healthcheckRetry(), healthEndpoint(), healthListenAddr(), addRequests(), flag.LogFormat(ServeCmd), flag.LogLevel(ServeCmd), }, Before: before, Action: s.action, DisableSliceFlagSeparator: true, } } func (ph *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Debug().Str("remote_addr", r.RemoteAddr).Str("method", r.Method).Str("path", r.URL.Path).Msg("incoming request") mr, ok := containerMethodRegex["*"] if ok { var req *regexp.Regexp req, ok = mr[r.Method] if !ok { log.Error(). Str("remote_addr", r.RemoteAddr). Str("method", r.Method). Str("path", r.URL.Path). Str("decision", "denied"). Msg("this HTTP method is not in the list of those authorized for this container") http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) return } if !req.MatchString(r.URL.Path) { log.Error(). Str("remote_addr", r.RemoteAddr). Str("method", r.Method). Str("path", r.URL.Path). Str("decision", "denied"). Msg("this path does not match any regular expression for this HTTP method") http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } log.Info(). Str("remote_addr", r.RemoteAddr). Str("method", r.Method). Str("path", r.URL.Path). Str("decision", "authorized"). Msg("incoming request matches a registered regular expression") return /* if err := checkMethodPath(r, mr); err != nil { handleError(w, err) log.Err(err).Send() return } */ } else { var ( containerName string host, _, _ = net.SplitHostPort(r.RemoteAddr) ip = net.ParseIP(host) ) for containerName, mr = range containerMethodRegex { resolvedIPs, err := net.LookupIP(containerName) if err != nil { // log.Warn().Err(err).Msg("this error may be transient due to the unavailability of one of the services") continue } for _, resolvedIP := range resolvedIPs { if resolvedIP.Equal(ip) { var req *regexp.Regexp req, ok = mr[r.Method] if !ok { log.Error(). Str("remote_addr", r.RemoteAddr). Str("method", r.Method). Str("path", r.URL.Path). Str("decision", "denied"). Msg("this HTTP method is not in the list of those authorized for this container") http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed) return } if !req.MatchString(r.URL.Path) { log.Error(). Str("remote_addr", r.RemoteAddr). Str("method", r.Method). Str("path", r.URL.Path). Str("decision", "denied"). Msg("this path does not match any regular expression for this HTTP method") http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) return } /* if err = checkMethodPath(r, mr); err != nil { handleError(w, err) log.Err(err).Send() return } */ log.Info(). Str("remote_addr", r.RemoteAddr). Str("method", r.Method). Str("path", r.URL.Path). Str("decision", "authorized"). Str("from", containerName). Msg("incoming request matches a registered regular expression") ph.rp.ServeHTTP(w, r) return } } } } log.Warn(). Str("remote_addr", r.RemoteAddr). Str("method", r.Method). Str("path", r.URL.Path). Str("decision", "denied"). Msg("this error may be transient due to the unavailability of one of the services") http.Error(w, http.StatusText(http.StatusServiceUnavailable), http.StatusServiceUnavailable) return } // checkMethodPath executes the regular expression on the path of the HTTP request if and only if // the latter's HTTP method is actually present in the list of authorized HTTP methods. func checkMethodPath(r *http.Request, mr methodRegex) error { req, ok := mr[r.Method] if !ok { return ErrHTTPMethodNotAllowed{httpMethod: r.Method} } if !req.MatchString(r.URL.Path) { return ErrNoMatch{path: r.URL.Path, httpMethod: r.Method} } return nil } // action is executed when the ServeCmd command is called. func (s serve) action(ctx context.Context, command *cli.Command) error { if err := ctx.Err(); err != nil { return err } ap, err := parseAddrPorts(command.String(listenAddrFlagName), command.String(healthListenAddrFlagName)) if err != nil { return err } var l net.Listener l, err = net.Listen("tcp", ap.srvAddrPort.String()) if err != nil { return err } dummyURL, _ := url.Parse("http://dummy") rp := httputil.NewSingleHostReverseProxy(dummyURL) rp.Transport = &http.Transport{ DialContext: func(_ context.Context, _ string, _ string) (net.Conn, error) { return net.Dial("unix", command.String(dockerSocketPathFlagName)) }, } srv := &http.Server{ // #nosec: G112 Handler: &ProxyHandler{rp: rp}, } s.group.Go(func() error { if err = srv.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) { return err } return nil }) retry := command.Uint(healthcheckRetryFlagName) if !command.Bool(noDockerSocketHealthcheckFlagName) { ticker := time.Tick(2 * time.Second) s.group.Go(func() error { loop: for { if err = unixDial(command.String(dockerSocketPathFlagName)); err != nil { if !command.Bool(noShutdownFlagName) { if retry == 0 { return err } log.Err(err).Uint64("retry_remaining", retry).Send() retry-- } } select { case <-ticker: continue case <-ctx.Done(): break loop } } return ctx.Err() }) } mux := http.NewServeMux() mux.HandleFunc(command.String(healthEndpointFlagName), func(w http.ResponseWriter, r *http.Request) { log.Trace().Str("from", r.RemoteAddr).Msg("incoming healthcheck request") if r.Method != http.MethodHead { w.WriteHeader(http.StatusMethodNotAllowed) return } if err = unixDial(command.String(dockerSocketPathFlagName)); err != nil { w.WriteHeader(http.StatusServiceUnavailable) } }) healthSrv := http.Server{ Addr: ap.healthAddrPort.String(), Handler: mux, ReadTimeout: 3 * time.Second, ReadHeaderTimeout: 3 * time.Second, WriteTimeout: 3 * time.Second, } s.group.Go(func() error { if err = healthSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) { return err } return nil }) log.Info(). Stringer("listen_addr", ap.srvAddrPort). Str("docker_socket_path", command.String(dockerSocketPathFlagName)). Stringer("shutdown_timeout", command.Duration(shutdownTimeoutFlagName)). Bool("docker_socket_healthcheck_disabled", command.Bool(noDockerSocketHealthcheckFlagName)). Bool("shutdown_on_failure_disabled", command.Bool(noShutdownFlagName)). Uint64("number_of_healthcheck_retry", command.Uint(healthcheckRetryFlagName)). Str("health_endpoint", command.String(healthEndpointFlagName)). Stringer("health_endpoint_listen_addr", ap.healthAddrPort). Strs("requests", command.StringSlice(addRequestsFlagName)). Str("log_format", command.String(flag.LogFormatFlagName)). Str("log_level", command.String(flag.LogLevelFlagName)). Msg("middleman started") <-ctx.Done() shutdownCtx, cancel := context.WithTimeout(context.Background(), command.Duration(shutdownTimeoutFlagName)) defer cancel() if err = healthSrv.Close(); err != nil { log.Err(err).Send() } return srv.Shutdown(shutdownCtx) } func parseAddrPorts(srvAddrPort, healthAddrPort string) (addrPorts, error) { var ap addrPorts if err := parseAddrPort(srvAddrPort, &ap.srvAddrPort); err != nil { return ap, err } return ap, parseAddrPort(healthAddrPort, &ap.healthAddrPort) } func parseAddrPort(addrPortStr string, addrPort *netip.AddrPort) error { ap, err := netip.ParseAddrPort(addrPortStr) if err != nil { return err } *addrPort = ap return nil } func unixDial(socketPath string) error { conn, err := net.DialTimeout("unix", socketPath, 3*time.Second) if err != nil { return err } return conn.Close() } func listenAddr() cli.Flag { return &cli.StringFlag{ Name: listenAddrFlagName, Category: "network", Usage: "proxy listen address", Sources: middleman.PluckEnvVar(EnvVarPrefix, listenAddrFlagName), Value: defaultListenAddr, Aliases: middleman.PluckAlias(ServeCmd, listenAddrFlagName), } } func dockerSocketPath() cli.Flag { return &cli.StringFlag{ Name: dockerSocketPathFlagName, Category: "docker", Usage: "Docker unix socket path", Sources: middleman.PluckEnvVar(EnvVarPrefix, dockerSocketPathFlagName), Value: defaultDockerSocketPath, Aliases: middleman.PluckAlias(ServeCmd, dockerSocketPathFlagName), } } func shutdownTimeout() cli.Flag { return &cli.DurationFlag{ Name: shutdownTimeoutFlagName, Category: "network", Usage: "server shutdown timeout", Sources: middleman.PluckEnvVar(EnvVarPrefix, shutdownTimeoutFlagName), Value: defaultShutdownTimeout, Aliases: middleman.PluckAlias(ServeCmd, shutdownTimeoutFlagName), } } func noDockerSocketHealthcheck() cli.Flag { return &cli.BoolFlag{ Name: noDockerSocketHealthcheckFlagName, Category: "docker", HideDefault: true, Usage: "disable Docker unix socket healthcheck", Sources: middleman.PluckEnvVar(EnvVarPrefix, noDockerSocketHealthcheckFlagName), Aliases: middleman.PluckAlias(ServeCmd, noDockerSocketHealthcheckFlagName), } } func noShutdown() cli.Flag { return &cli.BoolFlag{ Name: noShutdownFlagName, Category: "internal monitoring behavior", HideDefault: true, Usage: "disable application shutdown in case of Docker unix socket healthcheck failure", Sources: middleman.PluckEnvVar(EnvVarPrefix, noShutdownFlagName), Aliases: middleman.PluckAlias(ServeCmd, noShutdownFlagName), } } func healthcheckRetry() cli.Flag { return &cli.UintFlag{ Name: healthcheckRetryFlagName, Category: "internal monitoring behavior", Usage: "number of dial attempts with the Docker unix socket before the application shutdown (if enabled)", Sources: middleman.PluckEnvVar(EnvVarPrefix, healthcheckRetryFlagName), Value: defaultHealthcheckRetry, Aliases: middleman.PluckAlias(ServeCmd, healthcheckRetryFlagName), } } func healthEndpoint() cli.Flag { return &cli.StringFlag{ Name: healthEndpointFlagName, Category: "monitoring", Usage: "health endpoint", Sources: middleman.PluckEnvVar(EnvVarPrefix, healthEndpointFlagName), Value: defaultHealthEndpoint, Aliases: middleman.PluckAlias(ServeCmd, healthEndpointFlagName), } } func healthListenAddr() cli.Flag { return &cli.StringFlag{ Name: healthListenAddrFlagName, Category: "monitoring", Usage: "proxy health endpoint listen address", Sources: middleman.PluckEnvVar(EnvVarPrefix, healthListenAddrFlagName), Value: defaultHealthListenAddr, Aliases: middleman.PluckAlias(ServeCmd, healthListenAddrFlagName), } } func addRequests() cli.Flag { return &cli.StringSliceFlag{ Name: addRequestsFlagName, Category: "network", Usage: "add requests", Sources: middleman.PluckEnvVar(EnvVarPrefix, addRequestsFlagName), // Local: true, // Required to trigger the Action when this flag is set via the environment variable, see https://github.com/urfave/cli/issues/2041. Value: []string{"*:" + defaultAllowedRequest}, Aliases: middleman.PluckAlias(ServeCmd, addRequestsFlagName), Action: func(ctx context.Context, command *cli.Command, requests []string) error { clear(containerMethodRegex) for _, request := range requests { // An applicator is a container name/HTTP method pair e.g., nginx(GET). // The following regex extracts this applicator and the associated URL regex. aur := applicatorURLRegex.FindString(request) if aur == "" { // If we are here, it means that user set a wildcard (kinda) applicator e.g., GET,POST:URL_REGEX. // In its extended form, this applicator can be read as follows: *(GET,POST):URL_REGEX. methods, urlRegex, ok := strings.Cut(request, ":") if !ok { // || methods == ""? return errors.New("HTTP method(s) must be specified before ':'") } if err := registerMethodRegex("*", urlRegex, strings.Split(methods, ",")); err != nil { return err } } applicator, urlRegex, _ := strings.Cut(aur, ":") if err := registerMethodRegex(containerNameRegex.FindString(applicator), urlRegex, httpMethodsRegex.FindAllString(applicator, -1)); err != nil { return err } } return nil }, } } func registerMethodRegex(containerName, urlRegex string, httpMethods []string) error { r, err := regexp.Compile(urlRegex) if err != nil { return err } for _, httpMethod := range httpMethods { if _, ok := validHTTPMethods[httpMethod]; !ok { return fmt.Errorf("%s is not a valid HTTP method", httpMethod) } if containerMethodRegex[containerName] == nil { containerMethodRegex[containerName] = make(methodRegex) } containerMethodRegex[containerName][httpMethod] = r } return nil } func handleError(w http.ResponseWriter, err error) { var methodNotAllowedErr ErrHTTPMethodNotAllowed var noMatchErr ErrNoMatch if errors.As(err, &methodNotAllowedErr) { http.Error(w, err.Error(), http.StatusMethodNotAllowed) } else if errors.As(err, &noMatchErr) { http.Error(w, err.Error(), http.StatusForbidden) } else { http.Error(w, err.Error(), http.StatusInternalServerError) } }