diff --git a/go.mod b/go.mod index 394b1fb0..6f6c1c98 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,12 @@ module github.com/omec-project/amf go 1.24.0 require ( - git.cs.nctu.edu.tw/calee/sctp v1.1.0 github.com/antihax/optional v1.0.0 github.com/gin-contrib/cors v1.7.6 github.com/gin-gonic/gin v1.11.0 github.com/go-viper/mapstructure/v2 v2.4.0 github.com/google/uuid v1.6.0 + github.com/ishidawataru/sctp v0.0.0-20250829011129-4b890084db30 github.com/mohae/deepcopy v0.0.0-20170929034955-c48cc78d4826 github.com/omec-project/nas v1.6.3 github.com/omec-project/ngap v1.6.1 diff --git a/go.sum b/go.sum index 9530778c..60e53890 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,3 @@ -git.cs.nctu.edu.tw/calee/sctp v1.1.0 h1:caiPJ0g2sH1QmDkC7x2yklKrc01Fuo1rqYW68Tq4mU0= -git.cs.nctu.edu.tw/calee/sctp v1.1.0/go.mod h1:NeOuBXO1iJBtldmNhkfSH8yFbnxlhI8eEJdUd7DZvws= github.com/aead/cmac v0.0.0-20160719120800-7af84192f0b1 h1:+JkXLHME8vLJafGhOH4aoV2Iu8bR55nU6iKMVfYVLjY= github.com/aead/cmac v0.0.0-20160719120800-7af84192f0b1/go.mod h1:nuudZmJhzWtx2212z+pkuy7B6nkBqa+xwNXZHL1j8cg= github.com/antihax/optional v1.0.0 h1:xK2lYat7ZLaVVcIuj82J8kIro4V6kDe0AUDFboUCwcg= @@ -64,6 +62,8 @@ github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2 h1:8Tjv8EJ+pM1xP8mK6egEbD1OgnV github.com/grpc-ecosystem/grpc-gateway/v2 v2.27.2/go.mod h1:pkJQ2tZHJ0aFOVEEot6oZmaVEZcRme73eIFmhiVuRWs= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542 h1:2VTzZjLZBgl62/EtslCrtky5vbi9dd7HrQPQIx6wqiw= github.com/h2non/parth v0.0.0-20190131123155-b4df798d6542/go.mod h1:Ow0tF8D4Kplbc8s8sSb3V2oUCygFHVp8gC3Dn6U4MNI= +github.com/ishidawataru/sctp v0.0.0-20250829011129-4b890084db30 h1:SF8DGX8bGAXMAvxtJvFFy2KIAPwxIEDP3XpzZVhz0i4= +github.com/ishidawataru/sctp v0.0.0-20250829011129-4b890084db30/go.mod h1:co9pwDoBCm1kGxawmb4sPq0cSIOOWNPT4KnHotMP1Zg= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= diff --git a/ngap/dispatcher.go b/ngap/dispatcher.go index c4c81db4..e64736f9 100644 --- a/ngap/dispatcher.go +++ b/ngap/dispatcher.go @@ -12,9 +12,7 @@ import ( "fmt" "net" "os" - "reflect" - "git.cs.nctu.edu.tw/calee/sctp" "github.com/omec-project/amf/context" "github.com/omec-project/amf/logger" "github.com/omec-project/amf/metrics" @@ -339,52 +337,6 @@ func DispatchNgapMsg(ctx ctxt.Context, ran *context.AmfRan, pdu *ngapType.NGAPPD } } -func HandleSCTPNotification(conn net.Conn, notification sctp.Notification) { - amfSelf := context.AMF_Self() - - logger.NgapLog.Infof("Handle SCTP Notification[addr: %+v]", conn.RemoteAddr()) - - ran, ok := amfSelf.AmfRanFindByConn(conn) - if !ok { - logger.NgapLog.Warnf("RAN context has been removed[addr: %+v]", conn.RemoteAddr()) - return - } - - // Removing Stale Connections in AmfRanPool - amfSelf.AmfRanPool.Range(func(key, value interface{}) bool { - amfRan := value.(*context.AmfRan) - - conn := amfRan.Conn.(*sctp.SCTPConn) - errorConn := sctp.NewSCTPConn(-1, nil) - if reflect.DeepEqual(conn, errorConn) { - amfRan.Remove() - ran.Log.Infoln("removed stale entry in AmfRan pool") - } - return true - }) - - switch notification.Type() { - case sctp.SCTP_ASSOC_CHANGE: - ran.Log.Infoln("SCTP_ASSOC_CHANGE notification") - event := notification.(*sctp.SCTPAssocChangeEvent) - switch event.State() { - case sctp.SCTP_COMM_LOST: - ran.Log.Infoln("SCTP state is SCTP_COMM_LOST, close the connection") - ran.Remove() - case sctp.SCTP_SHUTDOWN_COMP: - ran.Log.Infoln("SCTP state is SCTP_SHUTDOWN_COMP, close the connection") - ran.Remove() - default: - ran.Log.Warnf("SCTP state[%+v] is not handled", event.State()) - } - case sctp.SCTP_SHUTDOWN_EVENT: - ran.Log.Infoln("SCTP_SHUTDOWN_EVENT notification, close the connection") - ran.Remove() - default: - ran.Log.Warnf("Non handled notification type: 0x%x", notification.Type()) - } -} - func HandleSCTPNotificationLb(gnbId string) { logger.NgapLog.Infof("Handle SCTP Notification[GnbId: %+v]", gnbId) diff --git a/ngap/service/service.go b/ngap/service/service.go index 33f9b0ba..63b30100 100644 --- a/ngap/service/service.go +++ b/ngap/service/service.go @@ -9,35 +9,26 @@ package service import ( "encoding/hex" "io" - "math/bits" "net" "sync" "syscall" - "git.cs.nctu.edu.tw/calee/sctp" + "github.com/ishidawataru/sctp" "github.com/omec-project/amf/logger" "github.com/omec-project/ngap" ) -type NGAPHandler struct { - HandleMessage func(conn net.Conn, msg []byte) - HandleNotification func(conn net.Conn, notification sctp.Notification) -} +type NGAPHandler func(conn net.Conn, msg []byte) const readBufSize uint32 = 131072 -// set default read timeout to 2 seconds -var readTimeout syscall.Timeval = syscall.Timeval{Sec: 2, Usec: 0} - var ( sctpListener *sctp.SCTPListener connections sync.Map ) var sctpConfig sctp.SocketConfig = sctp.SocketConfig{ - InitMsg: sctp.InitMsg{NumOstreams: 3, MaxInstreams: 5, MaxAttempts: 2, MaxInitTimeout: 2}, - RtoInfo: &sctp.RtoInfo{SrtoAssocID: 0, SrtoInitial: 500, SrtoMax: 1500, StroMin: 100}, - AssocInfo: &sctp.AssocInfo{AsocMaxRxt: 4}, + InitMsg: sctp.InitMsg{NumOstreams: 3, MaxInstreams: 5, MaxAttempts: 2, MaxInitTimeout: 2}, } func Run(addresses []string, port int, handler NGAPHandler) { @@ -61,14 +52,14 @@ func Run(addresses []string, port int, handler NGAPHandler) { } func listenAndServe(addr *sctp.SCTPAddr, handler NGAPHandler) { - if listener, err := sctpConfig.Listen("sctp", addr); err != nil { + listener, err := sctpConfig.Listen("sctp", addr) + if err != nil { logger.NgapLog.Errorf("failed to listen: %+v", err) return - } else { - sctpListener = listener } + sctpListener = listener - logger.NgapLog.Infof("Listen on %s", sctpListener.Addr()) + logger.NgapLog.Infof("listen on %s", sctpListener.Addr()) for { newConn, err := sctpListener.AcceptSCTP() @@ -126,16 +117,6 @@ func listenAndServe(addr *sctp.SCTPAddr, handler NGAPHandler) { logger.NgapLog.Debugf("Set read buffer to %d bytes", readBufSize) } - if err := newConn.SetReadTimeout(readTimeout); err != nil { - logger.NgapLog.Errorf("set read timeout error: %+v, accept failed", err) - if err = newConn.Close(); err != nil { - logger.NgapLog.Errorf("close error: %+v", err) - } - continue - } else { - logger.NgapLog.Debugf("set read timeout: %+v", readTimeout) - } - logger.NgapLog.Infof("[AMF] SCTP Accept from: %s", newConn.RemoteAddr().String()) connections.Store(newConn, newConn) @@ -144,13 +125,13 @@ func listenAndServe(addr *sctp.SCTPAddr, handler NGAPHandler) { } func Stop() { - logger.NgapLog.Infoln("close SCTP server...") + logger.NgapLog.Infoln("close SCTP server") if err := sctpListener.Close(); err != nil { logger.NgapLog.Error(err) - logger.NgapLog.Infof("SCTP server may not close normally.") + logger.NgapLog.Infoln("SCTP server may not close normally") } - connections.Range(func(key, value interface{}) bool { + connections.Range(func(key, value any) bool { conn := value.(net.Conn) if err := conn.Close(); err != nil { logger.NgapLog.Error(err) @@ -173,7 +154,7 @@ func handleConnection(conn *sctp.SCTPConn, bufsize uint32, handler NGAPHandler) for { buf := make([]byte, bufsize) - n, info, notification, err := conn.SCTPRead(buf) + n, info, err := conn.SCTPRead(buf) if err != nil { switch err { case io.EOF, io.ErrUnexpectedEOF: @@ -191,23 +172,15 @@ func handleConnection(conn *sctp.SCTPConn, bufsize uint32, handler NGAPHandler) } } - if notification != nil { - if handler.HandleNotification != nil { - handler.HandleNotification(conn, notification) - } else { - logger.NgapLog.Warnf("received sctp notification[type 0x%x] but not handled", notification.Type()) - } - } else { - if info == nil || info.PPID != bits.ReverseBytes32(ngap.PPID) { - logger.NgapLog.Warnln("received SCTP PPID != 60, discard this packet") - continue - } + if info == nil || info.PPID != ngap.PPID { + logger.NgapLog.Warnln("received SCTP PPID != 60, discard this packet") + continue + } - logger.NgapLog.Debugf("Read %d bytes", n) - logger.NgapLog.Debugf("Packet content: %+v", hex.Dump(buf[:n])) + logger.NgapLog.Debugf("read %d bytes", n) + logger.NgapLog.Debugf("packet content: %+v", hex.Dump(buf[:n])) - // TODO: concurrent on per-UE message - handler.HandleMessage(conn, buf[:n]) - } + // TODO: concurrent on per-UE message + handler(conn, buf[:n]) } } diff --git a/service/init.go b/service/init.go index 39986e23..1eed8f3b 100644 --- a/service/init.go +++ b/service/init.go @@ -313,10 +313,7 @@ func (amf *AMF) Start() { addr := fmt.Sprintf("%s:%d", self.BindingIPv4, self.SBIPort) - ngapHandler := ngap_service.NGAPHandler{ - HandleMessage: ngap.Dispatch, - HandleNotification: ngap.HandleSCTPNotification, - } + ngapHandler := ngap_service.NGAPHandler(ngap.Dispatch) ngap_service.Run(self.NgapIpList, self.NgapPort, ngapHandler) if self.EnableNrfCaching {