Skip to content
Closed
165 changes: 18 additions & 147 deletions listeners.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ import (
"github.com/caddyserver/caddy/v2/internal"
)

// listenFdsStart is the first file descriptor number for systemd socket activation.
// File descriptors 0, 1, 2 are reserved for stdin, stdout, stderr.
const listenFdsStart = 3

// NetworkAddress represents one or more network addresses.
// It contains the individual components for a parsed network
// address of the form accepted by ParseNetworkAddress().
Expand Down Expand Up @@ -137,31 +133,29 @@ func (na NetworkAddress) ListenAll(ctx context.Context, config net.ListenConfig)
// Listen synchronizes binds to unix domain sockets to avoid race conditions
// while an existing socket is unlinked.
func (na NetworkAddress) Listen(ctx context.Context, portOffset uint, config net.ListenConfig) (any, error) {
if na.IsUnixNetwork() {
unixSocketsMu.Lock()
defer unixSocketsMu.Unlock()
}
var (
ln any
err error
)

// check to see if plugin provides listener
if ln, err := getListenerFromPlugin(ctx, na.Network, na.Host, na.port(), portOffset, config); ln != nil || err != nil {
// check to see if network provides a listener
if ln, err = getListenerFromNetwork(ctx, na.Network, na.Host, na.port(), portOffset, config); ln != nil || err != nil {
return ln, err
}

// create (or reuse) the listener ourselves
return na.listen(ctx, portOffset, config)
}

func (na NetworkAddress) listen(ctx context.Context, portOffset uint, config net.ListenConfig) (any, error) {
var (
ln any
err error
address string
unixFileMode fs.FileMode
)

// lock other unix sockets from being bound and
// split unix socket addr early so lnKey
// is independent of permissions bits
if na.IsUnixNetwork() {
unixSocketsMu.Lock()
defer unixSocketsMu.Unlock()

address, unixFileMode, err = internal.SplitUnixSocketPermissionsBits(na.Host)
if err != nil {
return nil, err
Expand All @@ -172,7 +166,7 @@ func (na NetworkAddress) listen(ctx context.Context, portOffset uint, config net
address = na.JoinHostPort(portOffset)
}

if strings.HasPrefix(na.Network, "ip") {
if na.IsIpNetwork() {
ln, err = config.ListenPacket(ctx, na.Network, address)
} else {
if na.IsUnixNetwork() {
Expand Down Expand Up @@ -220,6 +214,12 @@ func (na NetworkAddress) IsFdNetwork() bool {
return IsFdNetwork(na.Network)
}

// IsIpNetwork returns true if na.Network starts with
// ip: ip4: or ip6:
func (na NetworkAddress) IsIpNetwork() bool {
return IsIpNetwork(na.Network)
}

// JoinHostPort is like net.JoinHostPort, but where the port
// is StartPort + offset.
func (na NetworkAddress) JoinHostPort(offset uint) string {
Expand Down Expand Up @@ -299,74 +299,6 @@ func (na NetworkAddress) String() string {
return JoinNetworkAddress(na.Network, na.Host, na.port())
}

// IsUnixNetwork returns true if the netw is a unix network.
func IsUnixNetwork(netw string) bool {
return strings.HasPrefix(netw, "unix")
}

// IsFdNetwork returns true if the netw is a fd network.
func IsFdNetwork(netw string) bool {
return strings.HasPrefix(netw, "fd")
}

// getFdByName returns the file descriptor number for the given
// socket name from systemd's LISTEN_FDNAMES environment variable.
// Socket names are provided by systemd via socket activation.
//
// The name can optionally include an index to handle multiple sockets
// with the same name: "web:0" for first, "web:1" for second, etc.
// If no index is specified, defaults to index 0 (first occurrence).
func getFdByName(nameWithIndex string) (int, error) {
if nameWithIndex == "" {
return 0, fmt.Errorf("socket name cannot be empty")
}

fdNamesStr := os.Getenv("LISTEN_FDNAMES")
if fdNamesStr == "" {
return 0, fmt.Errorf("LISTEN_FDNAMES environment variable not set")
}

// Parse name and optional index
parts := strings.Split(nameWithIndex, ":")
if len(parts) > 2 {
return 0, fmt.Errorf("invalid socket name format '%s': too many colons", nameWithIndex)
}

name := parts[0]
targetIndex := 0

if len(parts) > 1 {
var err error
targetIndex, err = strconv.Atoi(parts[1])
if err != nil {
return 0, fmt.Errorf("invalid socket index '%s': %v", parts[1], err)
}
if targetIndex < 0 {
return 0, fmt.Errorf("socket index cannot be negative: %d", targetIndex)
}
}

// Parse the socket names
names := strings.Split(fdNamesStr, ":")

// Find the Nth occurrence of the requested name
matchCount := 0
for i, fdName := range names {
if fdName == name {
if matchCount == targetIndex {
return listenFdsStart + i, nil
}
matchCount++
}
}

if matchCount == 0 {
return 0, fmt.Errorf("socket name '%s' not found in LISTEN_FDNAMES", name)
}

return 0, fmt.Errorf("socket name '%s' found %d times, but index %d requested", name, matchCount, targetIndex)
}

// ParseNetworkAddress parses addr into its individual
// components. The input string is expected to be of
// the form "network/host:port-range" where any part is
Expand Down Expand Up @@ -398,27 +330,9 @@ func ParseNetworkAddressWithDefaults(addr, defaultNetwork string, defaultPort ui
}, err
}
if IsFdNetwork(network) {
fdAddr := host

// Handle named socket activation (fdname/name, fdgramname/name)
if strings.HasPrefix(network, "fdname") || strings.HasPrefix(network, "fdgramname") {
fdNum, err := getFdByName(host)
if err != nil {
return NetworkAddress{}, fmt.Errorf("named socket activation: %v", err)
}
fdAddr = strconv.Itoa(fdNum)

// Normalize network to standard fd/fdgram
if strings.HasPrefix(network, "fdname") {
network = "fd"
} else {
network = "fdgram"
}
}

return NetworkAddress{
Network: network,
Host: fdAddr,
Host: host,
}, nil
}
var start, end uint64
Expand Down Expand Up @@ -713,55 +627,12 @@ func (fcql *fakeCloseQuicListener) Close() error {
return nil
}

// RegisterNetwork registers a network type with Caddy so that if a listener is
// created for that network type, getListener will be invoked to get the listener.
// This should be called during init() and will panic if the network type is standard
// or reserved, or if it is already registered. EXPERIMENTAL and subject to change.
func RegisterNetwork(network string, getListener ListenerFunc) {
network = strings.TrimSpace(strings.ToLower(network))

if network == "tcp" || network == "tcp4" || network == "tcp6" ||
network == "udp" || network == "udp4" || network == "udp6" ||
network == "unix" || network == "unixpacket" || network == "unixgram" ||
strings.HasPrefix(network, "ip:") || strings.HasPrefix(network, "ip4:") || strings.HasPrefix(network, "ip6:") ||
network == "fd" || network == "fdgram" {
panic("network type " + network + " is reserved")
}

if _, ok := networkTypes[strings.ToLower(network)]; ok {
panic("network type " + network + " is already registered")
}

networkTypes[network] = getListener
}

var unixSocketsMu sync.Mutex

// getListenerFromPlugin returns a listener on the given network and address
// if a plugin has registered the network name. It may return (nil, nil) if
// no plugin can provide a listener.
func getListenerFromPlugin(ctx context.Context, network, host, port string, portOffset uint, config net.ListenConfig) (any, error) {
// get listener from plugin if network type is registered
if getListener, ok := networkTypes[network]; ok {
Log().Debug("getting listener from plugin", zap.String("network", network))
return getListener(ctx, network, host, port, portOffset, config)
}

return nil, nil
}

func listenerKey(network, addr string) string {
return network + "/" + addr
}

// ListenerFunc is a function that can return a listener given a network and address.
// The listeners must be capable of overlapping: with Caddy, new configs are loaded
// before old ones are unloaded, so listeners may overlap briefly if the configs
// both need the same listener. EXPERIMENTAL and subject to change.
type ListenerFunc func(ctx context.Context, network, host, portRange string, portOffset uint, cfg net.ListenConfig) (any, error)

var networkTypes = map[string]ListenerFunc{}

// ListenerWrapper is a type that wraps a listener
// so it can modify the input listener's methods.
// Modules that implement this interface are found
Expand Down
Loading
Loading