mirror of
https://codeberg.org/forgejo/forgejo.git
synced 2025-01-06 15:06:06 -05:00
1f0638747b
By default Gitea will always unlink any sockets that are provided using the `LISTEN_FDS` environment variable. This is because Gitea itself uses this variable to pass its listening sockets to the new process when it performs a graceful restart. However, the same mechanism is used by systemd, which explicitly expects that passed-in sockets are not unlinked by the receiving process. This PR adjusts Gitea's graceful restart mechanism to use an additional environment variable that tracks whether a listening socket was opened by Gitea - and therefore should be unlinked on shutdown by Gitea. Fix #20490 Co-authored-by: zeripath <art27@cantab.net>
257 lines
7.3 KiB
Go
257 lines
7.3 KiB
Go
// Copyright 2019 The Gitea Authors. All rights reserved.
|
|
// Use of this source code is governed by a MIT-style
|
|
// license that can be found in the LICENSE file.
|
|
// This code is heavily inspired by the archived gofacebook/gracenet/net.go handler
|
|
|
|
//go:build !windows
|
|
|
|
package graceful
|
|
|
|
import (
|
|
"fmt"
|
|
"net"
|
|
"os"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
|
|
"code.gitea.io/gitea/modules/log"
|
|
"code.gitea.io/gitea/modules/setting"
|
|
"code.gitea.io/gitea/modules/util"
|
|
)
|
|
|
|
const (
	// listenFDs names the environment variable that carries the number of
	// listening sockets handed to this process.
	listenFDs = "LISTEN_FDS"
	// unlinkFDs names the environment variable holding a comma-separated list
	// of zero-based indices of the handed-over sockets that Gitea opened
	// itself and must therefore unlink when they are closed.
	unlinkFDs = "GITEA_UNLINK_FDS"

	// startFD is the first descriptor number used for handed-over sockets
	// (descriptors 0-2 are stdin, stdout and stderr).
	startFD = 3
)
|
|
|
|
// originalWD is the working directory captured when the package is
// initialised, so that the directory can be kept the same as at startup.
// If os.Getwd fails the error is deliberately discarded and originalWD is
// the empty string.
var originalWD = func() string {
	wd, _ := os.Getwd()
	return wd
}()
|
|
|
|
var (
	// once ensures the provided file descriptors are inherited only a single time.
	once sync.Once
	// mutex guards every listener slice declared below.
	mutex sync.Mutex

	// providedListeners holds inherited listeners that have not been claimed
	// yet; providedListenersToUnlink[i] tells whether providedListeners[i]
	// should be unlinked on close. The two slices are kept index-aligned.
	providedListeners         = []net.Listener{}
	providedListenersToUnlink = []bool{}

	// activeListeners holds every listener handed out to callers;
	// activeListenersToUnlink[i] is the unlink flag for activeListeners[i].
	activeListeners         = []net.Listener{}
	activeListenersToUnlink = []bool{}
)
|
|
|
|
func getProvidedFDs() (savedErr error) {
|
|
// Only inherit the provided FDS once but we will save the error so that repeated calls to this function will return the same error
|
|
once.Do(func() {
|
|
mutex.Lock()
|
|
defer mutex.Unlock()
|
|
|
|
numFDs := os.Getenv(listenFDs)
|
|
if numFDs == "" {
|
|
return
|
|
}
|
|
n, err := strconv.Atoi(numFDs)
|
|
if err != nil {
|
|
savedErr = fmt.Errorf("%s is not a number: %s. Err: %v", listenFDs, numFDs, err)
|
|
return
|
|
}
|
|
|
|
fdsToUnlinkStr := strings.Split(os.Getenv(unlinkFDs), ",")
|
|
providedListenersToUnlink = make([]bool, n)
|
|
for _, fdStr := range fdsToUnlinkStr {
|
|
i, err := strconv.Atoi(fdStr)
|
|
if err != nil || i < 0 || i >= n {
|
|
continue
|
|
}
|
|
providedListenersToUnlink[i] = true
|
|
}
|
|
|
|
for i := startFD; i < n+startFD; i++ {
|
|
file := os.NewFile(uintptr(i), fmt.Sprintf("listener_FD%d", i))
|
|
|
|
l, err := net.FileListener(file)
|
|
if err == nil {
|
|
// Close the inherited file if it's a listener
|
|
if err = file.Close(); err != nil {
|
|
savedErr = fmt.Errorf("error closing provided socket fd %d: %s", i, err)
|
|
return
|
|
}
|
|
providedListeners = append(providedListeners, l)
|
|
continue
|
|
}
|
|
|
|
// If needed we can handle packetconns here.
|
|
savedErr = fmt.Errorf("Error getting provided socket fd %d: %v", i, err)
|
|
return
|
|
}
|
|
})
|
|
return savedErr
|
|
}
|
|
|
|
// CloseProvidedListeners closes all unused provided listeners.
|
|
func CloseProvidedListeners() error {
|
|
mutex.Lock()
|
|
defer mutex.Unlock()
|
|
var returnableError error
|
|
for _, l := range providedListeners {
|
|
err := l.Close()
|
|
if err != nil {
|
|
log.Error("Error in closing unused provided listener: %v", err)
|
|
if returnableError != nil {
|
|
returnableError = fmt.Errorf("%v & %v", returnableError, err)
|
|
} else {
|
|
returnableError = err
|
|
}
|
|
}
|
|
}
|
|
providedListeners = []net.Listener{}
|
|
|
|
return returnableError
|
|
}
|
|
|
|
// GetListener obtains a listener for the local network address. The network must be
|
|
// a stream-oriented network: "tcp", "tcp4", "tcp6", "unix" or "unixpacket". It
|
|
// returns an provided net.Listener for the matching network and address, or
|
|
// creates a new one using net.Listen.
|
|
func GetListener(network, address string) (net.Listener, error) {
|
|
// Add a deferral to say that we've tried to grab a listener
|
|
defer GetManager().InformCleanup()
|
|
switch network {
|
|
case "tcp", "tcp4", "tcp6":
|
|
tcpAddr, err := net.ResolveTCPAddr(network, address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return GetListenerTCP(network, tcpAddr)
|
|
case "unix", "unixpacket":
|
|
unixAddr, err := net.ResolveUnixAddr(network, address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return GetListenerUnix(network, unixAddr)
|
|
default:
|
|
return nil, net.UnknownNetworkError(network)
|
|
}
|
|
}
|
|
|
|
// GetListenerTCP announces on the local network address. The network must be:
|
|
// "tcp", "tcp4" or "tcp6". It returns a provided net.Listener for the
|
|
// matching network and address, or creates a new one using net.ListenTCP.
|
|
func GetListenerTCP(network string, address *net.TCPAddr) (*net.TCPListener, error) {
|
|
if err := getProvidedFDs(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
mutex.Lock()
|
|
defer mutex.Unlock()
|
|
|
|
// look for a provided listener
|
|
for i, l := range providedListeners {
|
|
if isSameAddr(l.Addr(), address) {
|
|
providedListeners = append(providedListeners[:i], providedListeners[i+1:]...)
|
|
needsUnlink := providedListenersToUnlink[i]
|
|
providedListenersToUnlink = append(providedListenersToUnlink[:i], providedListenersToUnlink[i+1:]...)
|
|
|
|
activeListeners = append(activeListeners, l)
|
|
activeListenersToUnlink = append(activeListenersToUnlink, needsUnlink)
|
|
return l.(*net.TCPListener), nil
|
|
}
|
|
}
|
|
|
|
// no provided listener for this address -> make a fresh listener
|
|
l, err := net.ListenTCP(network, address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
activeListeners = append(activeListeners, l)
|
|
activeListenersToUnlink = append(activeListenersToUnlink, false)
|
|
return l, nil
|
|
}
|
|
|
|
// GetListenerUnix announces on the local network address. The network must be:
|
|
// "unix" or "unixpacket". It returns a provided net.Listener for the
|
|
// matching network and address, or creates a new one using net.ListenUnix.
|
|
func GetListenerUnix(network string, address *net.UnixAddr) (*net.UnixListener, error) {
|
|
if err := getProvidedFDs(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
mutex.Lock()
|
|
defer mutex.Unlock()
|
|
|
|
// look for a provided listener
|
|
for i, l := range providedListeners {
|
|
if isSameAddr(l.Addr(), address) {
|
|
providedListeners = append(providedListeners[:i], providedListeners[i+1:]...)
|
|
needsUnlink := providedListenersToUnlink[i]
|
|
providedListenersToUnlink = append(providedListenersToUnlink[:i], providedListenersToUnlink[i+1:]...)
|
|
|
|
activeListenersToUnlink = append(activeListenersToUnlink, needsUnlink)
|
|
activeListeners = append(activeListeners, l)
|
|
unixListener := l.(*net.UnixListener)
|
|
if needsUnlink {
|
|
unixListener.SetUnlinkOnClose(true)
|
|
}
|
|
return unixListener, nil
|
|
}
|
|
}
|
|
|
|
// make a fresh listener
|
|
if err := util.Remove(address.Name); err != nil && !os.IsNotExist(err) {
|
|
return nil, fmt.Errorf("Failed to remove unix socket %s: %v", address.Name, err)
|
|
}
|
|
|
|
l, err := net.ListenUnix(network, address)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
fileMode := os.FileMode(setting.UnixSocketPermission)
|
|
if err = os.Chmod(address.Name, fileMode); err != nil {
|
|
return nil, fmt.Errorf("Failed to set permission of unix socket to %s: %v", fileMode.String(), err)
|
|
}
|
|
|
|
activeListeners = append(activeListeners, l)
|
|
activeListenersToUnlink = append(activeListenersToUnlink, true)
|
|
return l, nil
|
|
}
|
|
|
|
// isSameAddr reports whether a1 and a2 refer to the same listening address.
// Addresses on different networks never match. Otherwise the string forms are
// compared after a leading IPv6 wildcard "[::]" and then a leading IPv4
// wildcard "0.0.0.0" are stripped. This lets "[::]:3000", "0.0.0.0:3000" and
// ":3000" compare as equal, which is common when listening on all interfaces.
func isSameAddr(a1, a2 net.Addr) bool {
	if a1.Network() != a2.Network() {
		return false
	}

	stripWildcard := func(s string) string {
		s = strings.TrimPrefix(s, "[::]")
		return strings.TrimPrefix(s, "0.0.0.0")
	}

	first, second := a1.String(), a2.String()
	return first == second || stripWildcard(first) == stripWildcard(second)
}
|
|
|
|
func getActiveListeners() []net.Listener {
|
|
mutex.Lock()
|
|
defer mutex.Unlock()
|
|
listeners := make([]net.Listener, len(activeListeners))
|
|
copy(listeners, activeListeners)
|
|
return listeners
|
|
}
|
|
|
|
func getActiveListenersToUnlink() []bool {
|
|
mutex.Lock()
|
|
defer mutex.Unlock()
|
|
listenersToUnlink := make([]bool, len(activeListenersToUnlink))
|
|
copy(listenersToUnlink, activeListenersToUnlink)
|
|
return listenersToUnlink
|
|
}
|