middleman/command/serve.go
2025-04-08 09:03:00 +02:00

467 lines
15 KiB
Go

package command
import (
"context"
"errors"
"fmt"
"gitea.illuad.fr/adrien/middleman"
"gitea.illuad.fr/adrien/middleman/flag"
"github.com/rs/zerolog/log"
"github.com/urfave/cli/v3"
"golang.org/x/sync/errgroup"
"net"
"net/http"
"net/http/httputil"
"net/netip"
"net/url"
"regexp"
"strings"
"time"
)
type (
	// serve holds the state of the serve command: the errgroup on which its
	// action starts the proxy server, the optional Docker socket healthcheck
	// loop and the health endpoint server.
	serve struct {
		group *errgroup.Group
	}

	// addrPorts holds the parsed listen addresses of the HTTP servers started
	// by the serve command.
	addrPorts struct {
		// srvAddrPort is the listen address of the Docker socket proxy.
		srvAddrPort netip.AddrPort
		// healthAddrPort is the listen address of the health endpoint.
		healthAddrPort netip.AddrPort
	}

	// methodRegex maps an HTTP method (GET, POST, DELETE...) to the compiled
	// regular expression a request path must match for that method.
	methodRegex = map[string]*regexp.Regexp

	// ProxyHandler takes an incoming request and sends it to another server,
	// proxying the response back to the client.
	ProxyHandler struct {
		rp *httputil.ReverseProxy
	}
)
// ErrHTTPMethodNotAllowed is returned if the request's HTTP method is not allowed for this container.
type ErrHTTPMethodNotAllowed struct {
	// httpMethod is the refused HTTP method of the request.
	httpMethod string
}

// ErrNoMatch is returned if the request's URL path is not allowed for this container.
type ErrNoMatch struct {
	// path is the refused request URL path; httpMethod is the request's HTTP method.
	path, httpMethod string
}

// Error implements the error interface.
func (e ErrHTTPMethodNotAllowed) Error() string {
	return fmt.Sprintf("%s is not in the list of HTTP methods allowed for this container", e.httpMethod)
}

// Error implements the error interface.
func (e ErrNoMatch) Error() string {
	return fmt.Sprintf("%s does not match any registered regular expression for this HTTP method (%s)", e.path, e.httpMethod)
}
const (
	// ServeCmd is the command name.
	ServeCmd = "serve"

	// Flag names of the serve command.

	// listenAddrFlagName names the flag setting the proxy listen address and port.
	listenAddrFlagName = "listen-addr"
	// dockerSocketPathFlagName names the flag setting the Docker unix socket path.
	dockerSocketPathFlagName = "docker-socket-path"
	// shutdownTimeoutFlagName names the flag setting the server shutdown timeout.
	shutdownTimeoutFlagName = "shutdown-timeout"
	// noDockerSocketHealthcheckFlagName names the flag disabling the Docker unix socket healthcheck.
	noDockerSocketHealthcheckFlagName = "no-docker-socket-healthcheck"
	// noShutdownFlagName names the flag disabling the application shutdown when the
	// Docker unix socket healthcheck fails.
	noShutdownFlagName = "no-shutdown"
	// healthcheckRetryFlagName names the flag setting how many dial attempts with the
	// Docker unix socket may fail before the application shuts down (if enabled).
	healthcheckRetryFlagName = "healthcheck-retry"
	// healthEndpointFlagName names the flag setting the health endpoint path.
	healthEndpointFlagName = "health-endpoint"
	// healthListenAddrFlagName names the flag setting the health endpoint listen address and port.
	healthListenAddrFlagName = "health-listen-addr"
	// addRequestsFlagName names the flag adding allowed requests.
	addRequestsFlagName = "add-requests"

	// Default flag values of the serve command.

	// defaultListenAddr is the default proxy listen address and port.
	defaultListenAddr = "0.0.0.0:2375"
	// defaultDockerSocketPath is the default Docker unix socket path.
	defaultDockerSocketPath = "/var/run/docker.sock"
	// defaultShutdownTimeout is the default server shutdown timeout.
	defaultShutdownTimeout = 5 * time.Second
	// defaultHealthcheckRetry is the default number of healthcheck retries.
	defaultHealthcheckRetry uint64 = 3
	// defaultHealthEndpoint is the default health endpoint path.
	defaultHealthEndpoint = "/healthz"
	// defaultHealthListenAddr is the default health endpoint listen address and port.
	defaultHealthListenAddr = "127.0.0.1:5732"
	// defaultAllowedRequest is the URL path regular expression allowed by default
	// (for GET requests from any client).
	defaultAllowedRequest = "^/(version|containers/.*|events.*)$"
)
var (
	// s is the serve command state: Serve stores the errgroup in it and its
	// action method is registered as the command Action.
	s serve

	// containerMethodRegex maps a container name ("*" for any client) to the HTTP
	// methods and URL path regular expressions it may use. It starts with the
	// default rule; the add-requests flag Action clears it and refills it.
	containerMethodRegex = map[string]methodRegex{
		"*": {
			http.MethodGet: regexp.MustCompile(defaultAllowedRequest),
		},
	}

	// applicatorURLRegex matches a whole container-specific request such as
	// nginx(GET,POST):URL_REGEX. Its groups capture the container name, the
	// comma-separated HTTP methods (possibly empty) and the URL regular expression.
	applicatorURLRegex = regexp.MustCompile(`^([a-zA-Z0-9][a-zA-Z0-9_.-]+)\(((?:(?:GET|HEAD|POST|PUT|PATCH|DELETE|CONNECT|TRACE|OPTIONS)(?:,(?:GET|HEAD|POST|PUT|PATCH|DELETE|CONNECT|TRACE|OPTIONS))*)?)\):(.*)$`)

	// containerNameRegex matches the container name at the start of an applicator.
	containerNameRegex = regexp.MustCompile(`^[a-zA-Z0-9][a-zA-Z0-9_.-]+`)

	// httpMethodsRegex matches runs of HTTP method names. NOTE(review): there is no
	// ',' between the two groups, so concatenations like GETPOST match as a single
	// token, and method names embedded in other text (e.g. a container name) match too.
	httpMethodsRegex = regexp.MustCompile(`(?:GET|HEAD|POST|PUT|PATCH|DELETE|CONNECT|TRACE|OPTIONS)(?:GET|HEAD|POST|PUT|PATCH|DELETE|CONNECT|TRACE|OPTIONS)*`)

	// validHTTPMethods is the set of HTTP methods accepted in allowed requests.
	validHTTPMethods = map[string]struct{}{
		http.MethodGet:     {},
		http.MethodHead:    {},
		http.MethodPost:    {},
		http.MethodPut:     {},
		http.MethodPatch:   {},
		http.MethodDelete:  {},
		http.MethodConnect: {},
		http.MethodTrace:   {},
		http.MethodOptions: {},
	}
)
// Serve describes the serve command.
//
// group is stored in the package-level serve state; the command's action starts
// the proxy server, the optional healthcheck loop and the health server on it.
func Serve(group *errgroup.Group) *cli.Command {
	s.group = group
	aliases := middleman.PluckAlias(ServeCmd, ServeCmd)
	flags := []cli.Flag{
		listenAddr(),
		dockerSocketPath(),
		shutdownTimeout(),
		noDockerSocketHealthcheck(),
		noShutdown(),
		healthcheckRetry(),
		healthEndpoint(),
		healthListenAddr(),
		addRequests(),
		flag.LogFormat(ServeCmd),
		flag.LogLevel(ServeCmd),
	}
	cmd := &cli.Command{
		Name:        ServeCmd,
		Aliases:     aliases,
		Usage:       "Runs serve mode",
		Description: "Proxy requests to the Docker socket using defined access control",
		Flags:       flags,
		Before:      before,
		// action has a value receiver, so this method value captures a copy of s:
		// s.group must already be set at this point.
		Action: s.action,
		// Keeps add-requests values such as GET,POST:URL_REGEX in one piece
		// instead of splitting them on ','.
		DisableSliceFlagSeparator: true,
	}
	return cmd
}
// ServeHTTP authorizes the incoming request and, if it is allowed, forwards it
// to the Docker socket through the reverse proxy.
//
// When containerMethodRegex has a "*" entry, its rules apply to every client.
// Otherwise the rules of the registered container whose name resolves to the
// client's IP address apply. A client that matches no registered container is
// refused with 403 Forbidden instead of being proxied unchecked.
func (ph *ProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	log.Debug().Str("http_method", r.Method).Str("path", r.URL.Path).Msg("incoming request")
	mr, ok := containerMethodRegex["*"]
	if !ok {
		var containerName string
		containerName, mr, ok = clientMethodRegex(r)
		if !ok {
			// Deny by default: only registered containers may reach the Docker socket.
			http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
			log.Warn().Str("remote_addr", r.RemoteAddr).Msg("request from a client matching no registered container")
			return
		}
		log.Debug().Str("container_name", containerName).Msg("request matched a registered container")
	}
	if err := checkMethodPath(w, r, mr); err != nil {
		log.Err(err).Send()
		return
	}
	ph.rp.ServeHTTP(w, r)
}

// clientMethodRegex returns the name and rules of a registered container whose
// name resolves to the IP address of the client that sent r; ok is false when
// none does or when r.RemoteAddr cannot be parsed. Names that fail to resolve
// are skipped so that one unresolvable container does not lock out every other
// client. If several names resolve to the client's address, which one is
// returned is unspecified (map iteration order).
func clientMethodRegex(r *http.Request) (containerName string, mr methodRegex, ok bool) {
	host, _, err := net.SplitHostPort(r.RemoteAddr)
	if err != nil {
		return "", nil, false
	}
	clientIP := net.ParseIP(host)
	if clientIP == nil {
		return "", nil, false
	}
	for name, rules := range containerMethodRegex {
		// Using the request context aborts the lookup if the client goes away.
		resolvedIPs, lookupErr := net.DefaultResolver.LookupIP(r.Context(), "ip", name)
		if lookupErr != nil {
			log.Debug().Err(lookupErr).Str("container_name", name).Msg("cannot resolve container name")
			continue
		}
		for _, resolvedIP := range resolvedIPs {
			if resolvedIP.Equal(clientIP) {
				return name, rules, true
			}
		}
	}
	return "", nil, false
}
// checkMethodPath authorizes r against mr. The request's HTTP method must have
// an entry in mr, and that entry's regular expression must match r.URL.Path.
//
// On refusal it writes the error response to w and returns the matching error:
// 405 with ErrHTTPMethodNotAllowed for an unlisted method, and 403 with
// ErrNoMatch for a non-matching path. It writes nothing and returns nil when the
// request is allowed.
func checkMethodPath(w http.ResponseWriter, r *http.Request, mr methodRegex) error {
	method, path := r.Method, r.URL.Path
	pathRegex, allowed := mr[method]
	switch {
	case !allowed:
		http.Error(w, http.StatusText(http.StatusMethodNotAllowed), http.StatusMethodNotAllowed)
		return ErrHTTPMethodNotAllowed{httpMethod: method}
	case !pathRegex.MatchString(path):
		http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
		return ErrNoMatch{path: path, httpMethod: method}
	default:
		return nil
	}
}
// action is executed when the ServeCmd command is called.
//
// It listens on the proxy address and then starts the following on s.group:
//   - the Docker socket proxy server;
//   - the periodic Docker unix socket healthcheck, unless it is disabled;
//   - the health endpoint server.
//
// It then blocks until ctx is done. After that it closes the health server and
// gracefully shuts down the proxy server, waiting at most the configured
// shutdown timeout.
func (s serve) action(ctx context.Context, command *cli.Command) error {
	if err := ctx.Err(); err != nil {
		return err
	}
	ap, err := parseAddrPorts(command.String(listenAddrFlagName), command.String(healthListenAddrFlagName))
	if err != nil {
		return err
	}
	l, err := net.Listen("tcp", ap.srvAddrPort.String())
	if err != nil {
		return err
	}
	socketPath := command.String(dockerSocketPathFlagName)
	srv := &http.Server{ // #nosec: G112
		Handler: &ProxyHandler{rp: newDockerSocketProxy(socketPath)},
	}
	// Each goroutine (and the health handler) declares its own err. Assigning
	// the enclosing err from several goroutines would be a data race.
	s.group.Go(func() error {
		if err := srv.Serve(l); err != nil && !errors.Is(err, http.ErrServerClosed) {
			return err
		}
		return nil
	})
	if !command.Bool(noDockerSocketHealthcheckFlagName) {
		retry := command.Uint(healthcheckRetryFlagName)
		shutdownOnFailure := !command.Bool(noShutdownFlagName)
		s.group.Go(func() error {
			return watchDockerSocket(ctx, socketPath, retry, shutdownOnFailure)
		})
	}
	healthSrv := newHealthServer(ap.healthAddrPort.String(), command.String(healthEndpointFlagName), socketPath)
	s.group.Go(func() error {
		if err := healthSrv.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			return err
		}
		return nil
	})
	log.Info().
		Stringer("listen_addr", ap.srvAddrPort).
		Str("docker_socket_path", socketPath).
		Stringer("shutdown_timeout", command.Duration(shutdownTimeoutFlagName)).
		Bool("docker_socket_healthcheck_disabled", command.Bool(noDockerSocketHealthcheckFlagName)).
		Bool("shutdown_on_failure_disabled", command.Bool(noShutdownFlagName)).
		Uint64("number_of_healthcheck_retry", command.Uint(healthcheckRetryFlagName)).
		Str("health_endpoint", command.String(healthEndpointFlagName)).
		Stringer("health_endpoint_listen_addr", ap.healthAddrPort).
		Strs("requests", command.StringSlice(addRequestsFlagName)).
		Str("log_format", command.String(flag.LogFormatFlagName)).
		Str("log_level", command.String(flag.LogLevelFlagName)).
		Msg("middleman started")
	<-ctx.Done()
	shutdownCtx, cancel := context.WithTimeout(context.Background(), command.Duration(shutdownTimeoutFlagName))
	defer cancel()
	if err := healthSrv.Close(); err != nil {
		log.Err(err).Send()
	}
	return srv.Shutdown(shutdownCtx)
}

// newDockerSocketProxy returns a reverse proxy that forwards every request to
// the Docker daemon listening on the unix socket at socketPath.
func newDockerSocketProxy(socketPath string) *httputil.ReverseProxy {
	// The target host is never contacted: every connection is dialed to the
	// unix socket below. Parsing this constant URL cannot fail.
	target, _ := url.Parse("http://dummy")
	rp := httputil.NewSingleHostReverseProxy(target)
	var dialer net.Dialer
	rp.Transport = &http.Transport{
		// Dial with the request's context so that a canceled request also aborts the dial.
		DialContext: func(ctx context.Context, _, _ string) (net.Conn, error) {
			return dialer.DialContext(ctx, "unix", socketPath)
		},
	}
	return rp
}

// watchDockerSocket dials the Docker unix socket at socketPath immediately and
// then every two seconds. When ctx is done, it returns ctx.Err().
//
// When shutdownOnFailure is true, each failed dial is logged and consumes one of
// the retry attempts. The first failure after the attempts are exhausted is
// returned. Successful dials do not restore consumed attempts. When
// shutdownOnFailure is false, failures are only logged.
func watchDockerSocket(ctx context.Context, socketPath string, retry uint64, shutdownOnFailure bool) error {
	// Unlike time.Tick, a Ticker can be stopped once the loop returns.
	ticker := time.NewTicker(2 * time.Second)
	defer ticker.Stop()
	for {
		if err := unixDial(socketPath); err != nil {
			switch {
			case !shutdownOnFailure:
				log.Warn().Err(err).Msg("docker socket healthcheck failed")
			case retry == 0:
				return err
			default:
				log.Err(err).Uint64("retry_remaining", retry).Send()
				retry--
			}
		}
		select {
		case <-ticker.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
}

// newHealthServer returns the HTTP server exposing the health endpoint on addr.
//
// The endpoint only accepts HEAD requests and answers 405 to any other method.
// For a HEAD request it answers 503 when the Docker unix socket at socketPath
// cannot be dialed, and the default 200 otherwise.
func newHealthServer(addr, endpoint, socketPath string) *http.Server {
	mux := http.NewServeMux()
	mux.HandleFunc(endpoint, func(w http.ResponseWriter, r *http.Request) {
		log.Trace().Str("from", r.RemoteAddr).Msg("incoming healthcheck request")
		if r.Method != http.MethodHead {
			// A 405 response must list the allowed methods (RFC 9110, section 15.5.6).
			w.Header().Set("Allow", http.MethodHead)
			w.WriteHeader(http.StatusMethodNotAllowed)
			return
		}
		if err := unixDial(socketPath); err != nil {
			w.WriteHeader(http.StatusServiceUnavailable)
		}
	})
	return &http.Server{
		Addr:              addr,
		Handler:           mux,
		ReadTimeout:       3 * time.Second,
		ReadHeaderTimeout: 3 * time.Second,
		WriteTimeout:      3 * time.Second,
	}
}
// parseAddrPorts parses the proxy listen address (srvAddrPort) and the health
// endpoint listen address (healthAddrPort). Both must be in ip:port form.
//
// On error the returned addrPorts is the zero value. The error names which
// address was invalid and wraps the underlying netip parsing error.
func parseAddrPorts(srvAddrPort, healthAddrPort string) (addrPorts, error) {
	var ap addrPorts
	if err := parseAddrPort(srvAddrPort, &ap.srvAddrPort); err != nil {
		return addrPorts{}, fmt.Errorf("parsing listen address %q: %w", srvAddrPort, err)
	}
	if err := parseAddrPort(healthAddrPort, &ap.healthAddrPort); err != nil {
		return addrPorts{}, fmt.Errorf("parsing health listen address %q: %w", healthAddrPort, err)
	}
	return ap, nil
}
// parseAddrPort parses addrPortStr as an ip:port pair. On success it stores the
// result in *addrPort and returns nil. On failure it leaves *addrPort untouched
// and returns the netip parsing error.
func parseAddrPort(addrPortStr string, addrPort *netip.AddrPort) error {
	parsed, err := netip.ParseAddrPort(addrPortStr)
	if err == nil {
		*addrPort = parsed
	}
	return err
}
// unixDial checks that the unix socket at socketPath accepts connections. It
// dials the socket with a three-second timeout and closes the connection right
// away. It returns the dial error, or otherwise the error from closing the
// connection.
func unixDial(socketPath string) error {
	const dialTimeout = 3 * time.Second
	conn, dialErr := net.DialTimeout("unix", socketPath, dialTimeout)
	if dialErr == nil {
		return conn.Close()
	}
	return dialErr
}
// listenAddr returns the flag holding the address and port the proxy listens on.
func listenAddr() cli.Flag {
	f := &cli.StringFlag{
		Name:     listenAddrFlagName,
		Usage:    "proxy listen address",
		Category: "network",
		Value:    defaultListenAddr,
		Sources:  middleman.PluckEnvVar(EnvVarPrefix, listenAddrFlagName),
		Aliases:  middleman.PluckAlias(ServeCmd, listenAddrFlagName),
	}
	return f
}
// dockerSocketPath returns the flag holding the path of the Docker unix socket
// that requests are proxied to.
func dockerSocketPath() cli.Flag {
	f := &cli.StringFlag{
		Name:     dockerSocketPathFlagName,
		Usage:    "Docker unix socket path",
		Category: "docker",
		Value:    defaultDockerSocketPath,
		Sources:  middleman.PluckEnvVar(EnvVarPrefix, dockerSocketPathFlagName),
		Aliases:  middleman.PluckAlias(ServeCmd, dockerSocketPathFlagName),
	}
	return f
}
// shutdownTimeout returns the flag holding how long the proxy server is given
// to shut down gracefully.
func shutdownTimeout() cli.Flag {
	f := &cli.DurationFlag{
		Name:     shutdownTimeoutFlagName,
		Usage:    "server shutdown timeout",
		Category: "network",
		Value:    defaultShutdownTimeout,
		Sources:  middleman.PluckEnvVar(EnvVarPrefix, shutdownTimeoutFlagName),
		Aliases:  middleman.PluckAlias(ServeCmd, shutdownTimeoutFlagName),
	}
	return f
}
// noDockerSocketHealthcheck returns the flag that disables the periodic Docker
// unix socket healthcheck.
func noDockerSocketHealthcheck() cli.Flag {
	f := &cli.BoolFlag{
		Name:        noDockerSocketHealthcheckFlagName,
		Usage:       "disable Docker unix socket healthcheck",
		Category:    "docker",
		HideDefault: true,
		Sources:     middleman.PluckEnvVar(EnvVarPrefix, noDockerSocketHealthcheckFlagName),
		Aliases:     middleman.PluckAlias(ServeCmd, noDockerSocketHealthcheckFlagName),
	}
	return f
}
// noShutdown returns the flag that keeps the application running when the
// Docker unix socket healthcheck fails.
func noShutdown() cli.Flag {
	f := &cli.BoolFlag{
		Name:        noShutdownFlagName,
		Usage:       "disable application shutdown in case of Docker unix socket healthcheck failure",
		Category:    "internal monitoring behavior",
		HideDefault: true,
		Sources:     middleman.PluckEnvVar(EnvVarPrefix, noShutdownFlagName),
		Aliases:     middleman.PluckAlias(ServeCmd, noShutdownFlagName),
	}
	return f
}
// healthcheckRetry returns the flag holding how many failed Docker unix socket
// dials are tolerated before the application shuts down (if enabled).
func healthcheckRetry() cli.Flag {
	f := &cli.UintFlag{
		Name:     healthcheckRetryFlagName,
		Usage:    "number of dial attempts with the Docker unix socket before the application shutdown (if enabled)",
		Category: "internal monitoring behavior",
		Value:    defaultHealthcheckRetry,
		Sources:  middleman.PluckEnvVar(EnvVarPrefix, healthcheckRetryFlagName),
		Aliases:  middleman.PluckAlias(ServeCmd, healthcheckRetryFlagName),
	}
	return f
}
// healthEndpoint returns the flag holding the path served by the health server.
func healthEndpoint() cli.Flag {
	f := &cli.StringFlag{
		Name:     healthEndpointFlagName,
		Usage:    "health endpoint",
		Category: "monitoring",
		Value:    defaultHealthEndpoint,
		Sources:  middleman.PluckEnvVar(EnvVarPrefix, healthEndpointFlagName),
		Aliases:  middleman.PluckAlias(ServeCmd, healthEndpointFlagName),
	}
	return f
}
// healthListenAddr returns the flag holding the address and port the health
// server listens on.
func healthListenAddr() cli.Flag {
	f := &cli.StringFlag{
		Name:     healthListenAddrFlagName,
		Usage:    "proxy health endpoint listen address",
		Category: "monitoring",
		Value:    defaultHealthListenAddr,
		Sources:  middleman.PluckEnvVar(EnvVarPrefix, healthListenAddrFlagName),
		Aliases:  middleman.PluckAlias(ServeCmd, healthListenAddrFlagName),
	}
	return f
}
// addRequests returns the flag used to replace the default allowed requests.
//
// Each value is either a container-specific rule, NAME(METHOD[,METHOD...]):URL_REGEX
// (e.g. nginx(GET):^/version$), or a rule for every client,
// METHOD[,METHOD...]:URL_REGEX, which reads as *(METHOD...):URL_REGEX. When the
// flag is set, its Action clears containerMethodRegex, which drops the default
// rule, and then registers every value.
//
// NOTE(review): the displayed default "*:"+defaultAllowedRequest is not itself
// accepted by the Action ("*" is not a valid HTTP method); confirm the intended syntax.
func addRequests() cli.Flag {
	return &cli.StringSliceFlag{
		Name:     addRequestsFlagName,
		Category: "network",
		Usage:    "add requests",
		Sources:  middleman.PluckEnvVar(EnvVarPrefix, addRequestsFlagName),
		Local:    true, // Required to trigger the Action when this flag is set via the environment variable, see https://github.com/urfave/cli/issues/2041.
		Value:    []string{"*:" + defaultAllowedRequest},
		Aliases:  middleman.PluckAlias(ServeCmd, addRequestsFlagName),
		Action: func(_ context.Context, _ *cli.Command, requests []string) error {
			clear(containerMethodRegex)
			for _, request := range requests {
				containerName, urlRegex, httpMethods, err := parseAllowedRequest(request)
				if err != nil {
					return err
				}
				if err = registerMethodRegex(containerName, urlRegex, httpMethods); err != nil {
					return err
				}
			}
			return nil
		},
	}
}

// parseAllowedRequest splits one add-requests value into the container name it
// applies to ("*" for every client), its URL regular expression and its HTTP
// methods. Methods are not validated here.
func parseAllowedRequest(request string) (containerName, urlRegex string, httpMethods []string, err error) {
	// An applicator is a container name/HTTP method pair e.g., nginx(GET).
	// The capture groups of applicatorURLRegex are used rather than searching the
	// whole applicator for method names. Otherwise a container named e.g. "GETapp"
	// would be granted GET on top of the methods between parentheses.
	if m := applicatorURLRegex.FindStringSubmatch(request); m != nil {
		// m[1]: container name, m[2]: comma-separated methods (possibly empty), m[3]: URL regex.
		if m[2] != "" {
			httpMethods = strings.Split(m[2], ",")
		}
		return m[1], m[3], httpMethods, nil
	}
	// Wildcard (kinda) applicator e.g., GET,POST:URL_REGEX.
	// In its extended form, this applicator can be read as follows: *(GET,POST):URL_REGEX.
	methods, rest, ok := strings.Cut(request, ":")
	if !ok {
		return "", "", nil, errors.New("HTTP method(s) must be specified before ':'")
	}
	return "*", rest, strings.Split(methods, ","), nil
}
// registerMethodRegex compiles urlRegex and allows it for containerName for every
// HTTP method in httpMethods ("*" as containerName applies to every client).
//
// It returns an error if urlRegex does not compile or if any method is not in
// validHTTPMethods. In that case containerMethodRegex is left unchanged.
// An empty httpMethods registers nothing.
func registerMethodRegex(containerName, urlRegex string, httpMethods []string) error {
	r, err := regexp.Compile(urlRegex)
	if err != nil {
		return err
	}
	// Validate every method before touching containerMethodRegex so that an
	// invalid entry never leaves a partially registered rule behind.
	for _, httpMethod := range httpMethods {
		if _, ok := validHTTPMethods[httpMethod]; !ok {
			// %q keeps empty or whitespace-padded user input visible.
			return fmt.Errorf("%q is not a valid HTTP method", httpMethod)
		}
	}
	if len(httpMethods) == 0 {
		return nil
	}
	mr := containerMethodRegex[containerName]
	if mr == nil {
		mr = make(methodRegex, len(httpMethods))
		containerMethodRegex[containerName] = mr
	}
	for _, httpMethod := range httpMethods {
		mr[httpMethod] = r
	}
	return nil
}