From d682d3edd80bdebe1d8f68010ae8b95293269c9a Mon Sep 17 00:00:00 2001 From: Massimo Gengarelli Date: Thu, 13 Jun 2024 12:14:27 +0200 Subject: [PATCH] feat: use contexts instead of forwarding signals to terminate goroutines Other minor modifications include: - Build flake in GitHub CI - Makefile to simplify the commands - Better versioning - Clean flake.nix file a little --- .github/workflows/ci.yaml | 17 ++++++-- .gitignore | 1 + Makefile | 15 +++++++ cmd/protrans/main.go | 38 ++++++++++------ flake.nix | 14 ++++-- pkg/config/config.go | 1 + pkg/flow/flow.go | 26 +++++------ pkg/flow/flow_test.go | 91 +++++++++++++++++++-------------------- 8 files changed, 125 insertions(+), 78 deletions(-) create mode 100644 Makefile diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml index 1808ead..3e208b3 100644 --- a/.github/workflows/ci.yaml +++ b/.github/workflows/ci.yaml @@ -4,7 +4,7 @@ on: [push] jobs: build: - name: Build + name: Build raw binary runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -12,7 +12,8 @@ jobs: with: go-version: ">=1.22" - run: go version - - run: go build -v ./cmd/protrans + - run: make all + - run: ldd ./protrans test: name: Test runs-on: ubuntu-latest @@ -22,4 +23,14 @@ jobs: with: go-version: ">=1.22" - run: go version - - run: go test -v ./... + - run: make test + flake: + name: Build Flake + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: cachix/install-nix-action@v27 + with: + github_access_token: ${{ secrets.GITHUB_TOKEN }} + - run: nix build + - run: ldd ./result/bin/protrans diff --git a/.gitignore b/.gitignore index 9b42106..b8085d2 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ .direnv/ +protrans diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..bd3f4d5 --- /dev/null +++ b/Makefile @@ -0,0 +1,15 @@ +GO := $(shell which go) +VERSION = 1.1 + +.PHONY: clean + +all: protrans + +protrans: cmd/protrans/main.go + $(GO) build -ldflags "-X 'main.Version=${VERSION}'" -o $@ -v $^ + +test: + $(GO) test -v ./... + +clean: + rm -fr protrans diff --git a/cmd/protrans/main.go b/cmd/protrans/main.go index 2dc86c8..9de7386 100644 --- a/cmd/protrans/main.go +++ b/cmd/protrans/main.go @@ -1,6 +1,7 @@ package main import ( + "context" "fmt" "os" "os/signal" @@ -23,9 +24,14 @@ func dumpTransmissionConfiguration(conf *config.ProtransConfiguration) string { return fmt.Sprintf("\tHost: %s\n\tPort: %d\n\tUsername: %s\n", conf.Transmission.Host, conf.Transmission.Port, conf.Transmission.Username) } +var Version string + func main() { var configurationPath string + var wg sync.WaitGroup + defer wg.Wait() + if len(os.Args) > 1 { configurationPath = os.Args[1] logrus.Infof("Parsing configuration from '%s'", configurationPath) @@ -35,7 +41,9 @@ func main() { conf := config.NewConfiguration(configurationPath, true) logrus.SetLevel(conf.LogrusLogLevel()) - logrus.Infof("Log level: %s\n", conf.LogrusLogLevel().String()) + + logrus.Infof("Protrans version: %s", Version) + logrus.Infof("Log level: %s", conf.LogrusLogLevel().String()) logrus.Infof("NAT Configuration:\n%s", dumpNatConfiguration(conf)) logrus.Infof("Transmission Configuration:\n%s", dumpTransmissionConfiguration(conf)) @@ -47,12 +55,11 @@ func main() { logrus.Panic(err) } + ctx, cancel := context.WithCancel(context.Background()) + // Register to some signals - done := make(chan os.Signal, 126) - signal.Notify(done, syscall.SIGTERM) - signal.Notify(done, syscall.SIGINT) - signal.Notify(done, syscall.SIGABRT) - signal.Notify(done, syscall.SIGHUP) + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGTERM, syscall.SIGINT, syscall.SIGABRT, syscall.SIGHUP) // Buffered channel to make sure we're refreshing it constantly ipChan := make(chan string, 30) @@ -62,20 +69,27 @@ func main() { defer func() { close(ipChan) close(portChan) - close(done) + close(sigChan) }() - var wg sync.WaitGroup wg.Add(3) // This goroutine will constantly check the external IP address and send it to a channel - go flow.FetchExternalIP(natClient, &wg, ipChan, done) + go flow.FetchExternalIP(ctx, natClient, &wg, ipChan) // This goroutine will receive the IP address and create a port mapping which will be sent to another channel - go flow.MapPorts(natClient, int(conf.Nat.PortLifetime), &wg, ipChan, portChan, done) + go flow.MapPorts(ctx, natClient, int(conf.Nat.PortLifetime), &wg, ipChan, portChan) // This goroutine will receive the mapped port and send it to Transmission if connected - go flow.TransmissionArgSetter(transmissionClient, &wg, portChan, done) + go flow.TransmissionArgSetter(ctx, transmissionClient, &wg, portChan) + + select { + case <-ctx.Done(): + logrus.Infof("Context closed, leaving") + case s := <-sigChan: + logrus.Infof("Received signal %q, leaving", s) + cancel() + } - wg.Wait() + logrus.Info("Waiting for goroutines to finish") } diff --git a/flake.nix b/flake.nix index e1b7771..aad9dec 100644 --- a/flake.nix +++ b/flake.nix @@ -5,20 +5,26 @@ let system = "x86_64-linux"; pkgs = import nixpkgs { inherit system; }; + hardeningDisable = [ "fortify" ]; + version = "1.1"; in { devShells.${system}.default = pkgs.mkShell { + inherit hardeningDisable; packages = with pkgs; [ go ]; - hardeningDisable = [ "fortify" ]; }; packages.${system}.default = pkgs.buildGoModule { + inherit hardeningDisable version; pname = "protrans"; - version = "1.0"; - hardeningDisable = [ "fortify" ]; - src = ./.; + src = + let + noSrcs = [ ".vscode" ".git" ".github" ".gitignore" ".envrc" ]; + in + builtins.filterSource (path: _: ! builtins.elem (baseNameOf path) noSrcs) ./.; + ldflags = [ "-X 'main.Version=${version}'" ]; vendorHash = "sha256-H79018dCud68fYT0l3IGZXQvD22byhnw/GchsiYJc68="; }; diff --git a/pkg/config/config.go b/pkg/config/config.go index e335a39..b6f9709 100644 --- a/pkg/config/config.go +++ b/pkg/config/config.go @@ -33,6 +33,7 @@ type ( func overrideWithEnvironment(iface, concrete any) { types := reflect.TypeOf(iface) values := reflect.ValueOf(concrete).Elem() + for i := 0; i < types.NumField(); i++ { typeField := types.Field(i) valueField := values.FieldByName(typeField.Name) diff --git a/pkg/flow/flow.go b/pkg/flow/flow.go index c33e4e7..c55e65d 100644 --- a/pkg/flow/flow.go +++ b/pkg/flow/flow.go @@ -1,7 +1,7 @@ package flow import ( - "os" + "context" "sync" "time" @@ -10,9 +10,12 @@ import ( "github.com/sirupsen/logrus" ) -func FetchExternalIP(natClient nat.NatClientI, wg *sync.WaitGroup, ipChan chan<- string, done chan os.Signal) { +func FetchExternalIP(ctx context.Context, natClient nat.NatClientI, wg *sync.WaitGroup, ipChan chan<- string) { running := true + timer := time.NewTimer(30 * time.Second) + defer timer.Stop() + for running { ip, err := nat.GetExternalIP(natClient) if err != nil { @@ -23,26 +26,25 @@ func FetchExternalIP(natClient nat.NatClientI, wg *sync.WaitGroup, ipChan chan<- } select { - case s := <-done: + case <-ctx.Done(): logrus.Info("Gracefully stopping gateway detector") running = false - done <- s // Make sure everyone is leaving - case <-time.After(30 * time.Second): + case <-timer.C: logrus.Debug("No signals received in 30 seconds, refreshing IP...") + timer.Reset(30 * time.Second) } } wg.Done() } -func MapPorts(natClient nat.NatClientI, portLifetime int, wg *sync.WaitGroup, ipChan <-chan string, portChan chan<- int, done chan os.Signal) { +func MapPorts(ctx context.Context, natClient nat.NatClientI, portLifetime int, wg *sync.WaitGroup, ipChan <-chan string, portChan chan<- int) { running := true for running { select { case ip := <-ipChan: - var mappedTcpPort int - var mappedUdpPort int + var mappedTcpPort, mappedUdpPort int var err error logrus.Debugf("Mapping port for external IP: %s", ip) @@ -65,17 +67,16 @@ func MapPorts(natClient nat.NatClientI, portLifetime int, wg *sync.WaitGroup, ip logrus.Debugf("Sending port %d to channel", mappedTcpPort) portChan <- mappedTcpPort - case s := <-done: + case <-ctx.Done(): logrus.Info("Gracefully stopping port mapper") running = false - done <- s } } wg.Done() } -func TransmissionArgSetter(transmissionClient transmission.TransmissionClient, wg *sync.WaitGroup, portChan <-chan int, done chan os.Signal) { +func TransmissionArgSetter(ctx context.Context, transmissionClient transmission.TransmissionClient, wg *sync.WaitGroup, portChan <-chan int) { running := true for running { @@ -104,10 +105,9 @@ func TransmissionArgSetter(transmissionClient transmission.TransmissionClient, w } else { logrus.Warnf("Should set port: %d but Transmission is not connected", mappedPort) } - case s := <-done: + case <-ctx.Done(): logrus.Info("Gracefully stopping Transmission Client") running = false - done <- s } } diff --git a/pkg/flow/flow_test.go b/pkg/flow/flow_test.go index bfbc0db..b688690 100644 --- a/pkg/flow/flow_test.go +++ b/pkg/flow/flow_test.go @@ -1,8 +1,8 @@ package flow_test import ( + "context" "errors" - "os" "sync" "testing" "time" @@ -87,6 +87,10 @@ func reset() { peerPortSet = false } +func createContext() (context.Context, context.CancelFunc) { + return context.WithDeadline(context.Background(), time.Now().Add(3*time.Second)) +} + // Happy path test, everything is ok and working func Test_Flow_OK(t *testing.T) { var wg sync.WaitGroup @@ -96,22 +100,22 @@ func Test_Flow_OK(t *testing.T) { nc := &fakeNatClient{true, true} tc := &fakeTransmissionClient{true, false, true} + ctx, cancel := createContext() + defer cancel() + ipChan := make(chan string) portChan := make(chan int) - done := make(chan os.Signal, 1) defer func() { close(ipChan) close(portChan) - close(done) }() // Start the flow - go flow.FetchExternalIP(nc, &wg, ipChan, done) - go flow.MapPorts(nc, 600, &wg, ipChan, portChan, done) - go flow.TransmissionArgSetter(tc, &wg, portChan, done) + go flow.FetchExternalIP(ctx, nc, &wg, ipChan) + go flow.MapPorts(ctx, nc, 600, &wg, ipChan, portChan) + go flow.TransmissionArgSetter(ctx, tc, &wg, portChan) - <-time.After(3 * time.Second) - done <- os.Interrupt + <-ctx.Done() // Wait for everyone to finish wg.Wait() @@ -127,23 +131,21 @@ func Test_Flow_NoExternalIP(t *testing.T) { reset() wg.Add(1) + ctx, cancel := createContext() + defer cancel() + nc := &fakeNatClient{false, true} ipChan := make(chan string) - done := make(chan os.Signal, 1) - defer func() { - close(ipChan) - close(done) - }() + defer close(ipChan) - go flow.FetchExternalIP(nc, &wg, ipChan, done) + go flow.FetchExternalIP(ctx, nc, &wg, ipChan) // We should fall into the timeout here select { case ip := <-ipChan: t.Fatalf("Should not have received an IP, received: %s instead?", ip) - case <-time.After(5 * time.Second): - done <- os.Interrupt + case <-ctx.Done(): } wg.Wait() @@ -160,24 +162,23 @@ func Test_Flow_NoPortMapping(t *testing.T) { wg.Add(2) nc := &fakeNatClient{true, false} + ctx, cancel := createContext() + defer cancel() ipChan := make(chan string) portChan := make(chan int) - done := make(chan os.Signal, 1) defer func() { close(ipChan) close(portChan) - close(done) }() - go flow.FetchExternalIP(nc, &wg, ipChan, done) - go flow.MapPorts(nc, 600, &wg, ipChan, portChan, done) + go flow.FetchExternalIP(ctx, nc, &wg, ipChan) + go flow.MapPorts(ctx, nc, 600, &wg, ipChan, portChan) select { case p := <-portChan: t.Fatalf("Should not have been able to open a port, opened %d instead?", p) - case <-time.After(5 * time.Second): - done <- os.Interrupt + case <-ctx.Done(): } wg.Wait() @@ -196,23 +197,23 @@ func Test_Flow_NoTransmissionConnection(t *testing.T) { nc := &fakeNatClient{true, true} tc := &fakeTransmissionClient{false, false, false} + ctx, cancel := createContext() + defer cancel() + ipChan := make(chan string) portChan := make(chan int) - done := make(chan os.Signal, 1) defer func() { close(ipChan) close(portChan) - close(done) }() // Start the flow - go flow.FetchExternalIP(nc, &wg, ipChan, done) - go flow.MapPorts(nc, 600, &wg, ipChan, portChan, done) - go flow.TransmissionArgSetter(tc, &wg, portChan, done) + go flow.FetchExternalIP(ctx, nc, &wg, ipChan) + go flow.MapPorts(ctx, nc, 600, &wg, ipChan, portChan) + go flow.TransmissionArgSetter(ctx, tc, &wg, portChan) - // Wait for everyone to finish - <-time.After(5 * time.Second) - done <- os.Interrupt + // Wait for the context to expire + <-ctx.Done() wg.Wait() assert.True(t, externalIPRetrieved) @@ -228,24 +229,23 @@ func Test_Flow_TransmissionPortAlreadyOpen(t *testing.T) { nc := &fakeNatClient{true, true} tc := &fakeTransmissionClient{true, true, true} + ctx, cancel := createContext() + defer cancel() ipChan := make(chan string) portChan := make(chan int) - done := make(chan os.Signal, 1) defer func() { close(ipChan) close(portChan) - close(done) }() // Start the flow - go flow.FetchExternalIP(nc, &wg, ipChan, done) - go flow.MapPorts(nc, 600, &wg, ipChan, portChan, done) - go flow.TransmissionArgSetter(tc, &wg, portChan, done) + go flow.FetchExternalIP(ctx, nc, &wg, ipChan) + go flow.MapPorts(ctx, nc, 600, &wg, ipChan, portChan) + go flow.TransmissionArgSetter(ctx, tc, &wg, portChan) - // Wait for everyone to finish - <-time.After(5 * time.Second) - done <- os.Interrupt + // Wait for the context to expire + <-ctx.Done() wg.Wait() assert.True(t, externalIPRetrieved) @@ -261,24 +261,23 @@ func Test_Flow_TransmissionUnableToSet(t *testing.T) { nc := &fakeNatClient{true, true} tc := &fakeTransmissionClient{true, false, false} + ctx, cancel := createContext() + defer cancel() ipChan := make(chan string) portChan := make(chan int) - done := make(chan os.Signal, 1) defer func() { close(ipChan) close(portChan) - close(done) }() // Start the flow - go flow.FetchExternalIP(nc, &wg, ipChan, done) - go flow.MapPorts(nc, 600, &wg, ipChan, portChan, done) - go flow.TransmissionArgSetter(tc, &wg, portChan, done) + go flow.FetchExternalIP(ctx, nc, &wg, ipChan) + go flow.MapPorts(ctx, nc, 600, &wg, ipChan, portChan) + go flow.TransmissionArgSetter(ctx, tc, &wg, portChan) - // Wait for everyone to finish - <-time.After(5 * time.Second) - done <- os.Interrupt + // Wait for the context to expire + <-ctx.Done() wg.Wait() assert.True(t, externalIPRetrieved)