diff --git a/cmd/gean/main.go b/cmd/gean/main.go index d1b514d..aab47c6 100644 --- a/cmd/gean/main.go +++ b/cmd/gean/main.go @@ -135,6 +135,7 @@ func main() { logger.Error("failed to initialize node", "err", err) os.Exit(1) } + defer n.Close() ctx, cancel := context.WithCancel(context.Background()) defer cancel() diff --git a/node/lifecycle.go b/node/lifecycle.go index a7d72c7..37f26f1 100644 --- a/node/lifecycle.go +++ b/node/lifecycle.go @@ -299,6 +299,10 @@ func loadValidatorKeys(log *slog.Logger, cfg Config) (map[uint64]forkchoice.Sign kp, err := leansig.LoadKeypair(pkPath, skPath) if err != nil { + // Clean up previously loaded keypairs to prevent Rust memory leaks. + // Modeled after zeam's errdefer keypair.deinit() pattern + // (cli/src/node.zig:433-469). + freeLoadedKeys(keys) return nil, fmt.Errorf("failed to load keypair for validator %d: %w", idx, err) } keys[idx] = kp @@ -307,6 +311,16 @@ func loadValidatorKeys(log *slog.Logger, cfg Config) (map[uint64]forkchoice.Sign return keys, nil } +// freeLoadedKeys releases Rust-allocated XMSS keypairs from a partially +// loaded key map. Called on error during loadValidatorKeys to prevent leaks. +func freeLoadedKeys(keys map[uint64]forkchoice.Signer) { + for _, key := range keys { + if f, ok := key.(interface{ Free() }); ok { + f.Free() + } + } +} + func startMetrics(log *slog.Logger, cfg Config) { if cfg.MetricsPort <= 0 { return diff --git a/node/node.go b/node/node.go index 3140f39..3935a0b 100644 --- a/node/node.go +++ b/node/node.go @@ -104,6 +104,14 @@ func (n *Node) Close() { if n.API != nil { n.API.Stop() } + // Free Rust-allocated XMSS keypairs. + if n.Validator != nil { + for _, key := range n.Validator.Keys { + if f, ok := key.(interface{ Free() }); ok { + f.Free() + } + } + } if n.dbCloser != nil { n.dbCloser.Close() } diff --git a/xmss/leanmultisig/leanmultisig.go b/xmss/leanmultisig/leanmultisig.go index d5899e4..74b05e6 100644 --- a/xmss/leanmultisig/leanmultisig.go +++ b/xmss/leanmultisig/leanmultisig.go @@ -10,6 +10,7 @@ package leanmultisig import "C" import ( "fmt" + "runtime" "sync" "unsafe" ) @@ -82,6 +83,7 @@ func Aggregate(pubkeys, signatures [][]byte, messageHash [MessageHashLength]byte &outData, &outLen, ) + runtime.KeepAlive(messageHash) if result != ResultOK { return nil, resultError("leanmultisig_aggregate", result) } @@ -118,6 +120,8 @@ func VerifyAggregated(pubkeys [][]byte, messageHash [MessageHashLength]byte, pro C.size_t(len(proofData)), C.uint32_t(epoch), ) + runtime.KeepAlive(messageHash) + runtime.KeepAlive(proofData) if result != ResultOK { return resultError("leanmultisig_verify_aggregated", result) } diff --git a/xmss/leansig/leansig.go b/xmss/leansig/leansig.go index bcfd2f3..3bad795 100644 --- a/xmss/leansig/leansig.go +++ b/xmss/leansig/leansig.go @@ -16,6 +16,7 @@ package leansig import "C" import ( "fmt" + "runtime" "unsafe" ) @@ -75,6 +76,8 @@ func RestoreKeypair(pkBytes []byte, skBytes []byte) (*Keypair, error) { skLen := C.size_t(len(skBytes)) result := C.leansig_keypair_restore(pkPtr, pkLen, skPtr, skLen, &kpPtr) + runtime.KeepAlive(pkBytes) + runtime.KeepAlive(skBytes) if result != ResultOK { return nil, fmt.Errorf("leansig_keypair_restore failed with code %d", result) } @@ -184,6 +187,7 @@ func (kp *Keypair) Sign(epoch uint32, message [MessageLength]byte) ([]byte, erro &sigData, &sigLen, ) + runtime.KeepAlive(message) if result != ResultOK { return nil, fmt.Errorf("leansig_sign failed with code %d", result) } @@ -206,6 +210,9 @@ func Verify(pubkeyBytes []byte, epoch uint32, message [MessageLength]byte, sigBy (*C.uint8_t)(unsafe.Pointer(&sigBytes[0])), C.size_t(len(sigBytes)), ) + runtime.KeepAlive(pubkeyBytes) + runtime.KeepAlive(message) + runtime.KeepAlive(sigBytes) if result == ResultOK { return nil } @@ -231,6 +238,8 @@ func (kp *Keypair) VerifyWithKeypair(epoch uint32, message [MessageLength]byte, (*C.uint8_t)(unsafe.Pointer(&sigBytes[0])), C.size_t(len(sigBytes)), ) + runtime.KeepAlive(message) + runtime.KeepAlive(sigBytes) if result == ResultOK { return nil }