diff --git a/auth.go b/auth.go index 36ce19e..5125c98 100644 --- a/auth.go +++ b/auth.go @@ -43,7 +43,7 @@ func authPlain(socket io.ReadWriter, decoder *xml.Decoder, user string, password // v.Any is type of sub-element in failure, which gives a description of what failed. return errors.New("auth failure: " + v.Any.Local) default: - return errors.New("expected SASL success or failure, got " + v.Name()) + return errors.New("expected SASL success or failure, got " + v.Name().Local) } return err } @@ -60,8 +60,8 @@ type SASLSuccess struct { XMLName xml.Name `xml:"urn:ietf:params:xml:ns:xmpp-sasl success"` } -func (SASLSuccess) Name() string { - return "sasl:success" +func (s SASLSuccess) Name() xml.Name { + return s.XMLName } type saslSuccessDecoder struct{} @@ -82,8 +82,8 @@ type SASLFailure struct { Any xml.Name // error reason is a subelement } -func (SASLFailure) Name() string { - return "sasl:failure" +func (x SASLFailure) Name() xml.Name { + return x.XMLName } type saslFailureDecoder struct{} @@ -111,6 +111,10 @@ type BindBind struct { Jid string `xml:"jid,omitempty"` } +func (b BindBind) Name() xml.Name { + return b.XMLName +} + // Session is obsolete in RFC 6121. // Added for compliance with RFC 3121. // Remove when ejabberd purely conforms to RFC 6121. diff --git a/check_cert.go b/check_cert.go index 6190d96..96d3dad 100644 --- a/check_cert.go +++ b/check_cert.go @@ -73,7 +73,7 @@ func (c *ServerCheck) Check() error { case StreamError: return errors.New("open stream error: " + p.Error.Local) default: - return errors.New("expected packet received while expecting features, got " + p.Name()) + return errors.New("expected packet received while expecting features, got " + p.Name().Local) } startTLSFeature := f.StartTLS.XMLName.Space + " " + f.StartTLS.XMLName.Local diff --git a/component.go b/component.go index 67c4bfd..0486dc0 100644 --- a/component.go +++ b/component.go @@ -67,7 +67,7 @@ func (c *Component) Connect(connStr string) error { case Handshake: return nil default: - return errors.New("unexpected packet, got " + v.Name()) + return errors.New("unexpected packet, got " + v.Name().Local) } panic("unreachable") } @@ -128,8 +128,8 @@ type Handshake struct { // Value string `xml:",innerxml"` } -func (Handshake) Name() string { - return "component:handshake" +func (h Handshake) Name() xml.Name { + return h.XMLName } // Handshake decoding wrapper diff --git a/iot_control.go b/iot_control.go index 795e5e3..cdf1c15 100644 --- a/iot_control.go +++ b/iot_control.go @@ -1,4 +1,4 @@ -package xmpp // import "gosrc.io/xmpp/iot" +package xmpp // import "gosrc.io/xmpp" import ( "encoding/xml" @@ -10,6 +10,10 @@ type ControlSet struct { Fields []ControlField `xml:",any"` } +func (c ControlSet) Name() xml.Name { + return c.XMLName +} + type ControlGetForm struct { XMLName xml.Name `xml:"urn:xmpp:iot:control getForm"` } diff --git a/iq.go b/iq.go index bfc13e3..253181b 100644 --- a/iq.go +++ b/iq.go @@ -154,8 +154,8 @@ func (iq IQ) MakeError(xerror Err) IQ { return iq } -func (IQ) Name() string { - return "iq" +func (iq IQ) Name() xml.Name { + return iq.XMLName } type iqDecoder struct{} @@ -289,6 +289,10 @@ type DiscoInfo struct { Features []Feature `xml:"feature"` } +func (d DiscoInfo) Name() xml.Name { + return d.XMLName +} + type Identity struct { XMLName xml.Name `xml:"identity,omitempty"` Name string `xml:"name,attr,omitempty"` @@ -310,6 +314,10 @@ type DiscoItems struct { Items []DiscoItem `xml:"item"` } +func (d DiscoItems) Name() xml.Name { + return d.XMLName +} + type DiscoItem struct { XMLName xml.Name `xml:"item"` Name string `xml:"name,attr,omitempty"` @@ -318,8 +326,8 @@ type DiscoItem struct { } func init() { - typeRegistry.MapExtension(PKTIQ, xml.Name{"http://jabber.org/protocol/disco#info", "query"}, DiscoInfo{}) - typeRegistry.MapExtension(PKTIQ, xml.Name{"http://jabber.org/protocol/disco#items", "query"}, DiscoItems{}) - typeRegistry.MapExtension(PKTIQ, xml.Name{"urn:ietf:params:xml:ns:xmpp-bind", "bind"}, BindBind{}) - typeRegistry.MapExtension(PKTIQ, xml.Name{"urn:xmpp:iot:control", "set"}, ControlSet{}) + typeRegistry.MapExtension(PKTIQ, DiscoInfo{}) + typeRegistry.MapExtension(PKTIQ, DiscoItems{}) + typeRegistry.MapExtension(PKTIQ, BindBind{}) + typeRegistry.MapExtension(PKTIQ, ControlSet{}) } diff --git a/message.go b/message.go index 5b99348..08f7ffa 100644 --- a/message.go +++ b/message.go @@ -18,8 +18,8 @@ type Message struct { Extensions []MsgExtension `xml:",omitempty"` } -func (Message) Name() string { - return "message" +func (msg Message) Name() xml.Name { + return msg.XMLName } func NewMessage(msgtype, from, to, id, lang string) Message { diff --git a/msg_chat_markers.go b/msg_chat_markers.go index 2940c1f..c34fba7 100644 --- a/msg_chat_markers.go +++ b/msg_chat_markers.go @@ -12,13 +12,21 @@ type Markable struct { XMLName xml.Name `xml:"urn:xmpp:chat-markers:0 markable"` } +func (m Markable) Name() xml.Name { + return m.XMLName +} + type MarkReceived struct { MsgExtension XMLName xml.Name `xml:"urn:xmpp:chat-markers:0 received"` ID string } +func (m MarkReceived) Name() xml.Name { + return m.XMLName +} + func init() { - typeRegistry.MapExtension(PKTMessage, xml.Name{"urn:xmpp:chat-markers:0", "markable"}, Markable{}) - typeRegistry.MapExtension(PKTMessage, xml.Name{"urn:xmpp:chat-markers:0", "received"}, MarkReceived{}) + typeRegistry.MapExtension(PKTMessage, Markable{}) + typeRegistry.MapExtension(PKTMessage, MarkReceived{}) } diff --git a/msg_oob.go b/msg_oob.go index 4b2188e..54c0d4c 100644 --- a/msg_oob.go +++ b/msg_oob.go @@ -14,6 +14,10 @@ type OOB struct { Desc string `xml:"desc,omitempty"` } +func (o OOB) Name() xml.Name { + return o.XMLName +} + func init() { - typeRegistry.MapExtension(PKTMessage, xml.Name{"jabber:x:oob", "x"}, OOB{}) + typeRegistry.MapExtension(PKTMessage, OOB{}) } diff --git a/msg_receipts.go b/msg_receipts.go index 76958a8..183e4f4 100644 --- a/msg_receipts.go +++ b/msg_receipts.go @@ -13,13 +13,21 @@ type ReceiptRequest struct { XMLName xml.Name `xml:"urn:xmpp:receipts request"` } +func (r ReceiptRequest) Name() xml.Name { + return r.XMLName +} + type ReceiptReceived struct { MsgExtension XMLName xml.Name `xml:"urn:xmpp:receipts received"` ID string } +func (r ReceiptReceived) Name() xml.Name { + return r.XMLName +} + func init() { - typeRegistry.MapExtension(PKTMessage, xml.Name{"urn:xmpp:receipts", "request"}, ReceiptRequest{}) - typeRegistry.MapExtension(PKTMessage, xml.Name{"urn:xmpp:receipts", "received"}, ReceiptReceived{}) + typeRegistry.MapExtension(PKTMessage, ReceiptRequest{}) + typeRegistry.MapExtension(PKTMessage, ReceiptReceived{}) } diff --git a/packet.go b/packet.go index 89ec126..886fa40 100644 --- a/packet.go +++ b/packet.go @@ -1,7 +1,9 @@ package xmpp // import "gosrc.io/xmpp" +import "encoding/xml" + type Packet interface { - Name() string + Name() xml.Name } // PacketAttrs represents the common structure for base XMPP packets. diff --git a/parser.go b/parser.go index 0ad24fc..5489ad9 100644 --- a/parser.go +++ b/parser.go @@ -88,54 +88,54 @@ func next(p *xml.Decoder) (Packet, error) { func decodeStream(p *xml.Decoder, se xml.StartElement) (Packet, error) { switch se.Name.Local { - case "error": + case StreamError{}.Name().Local: return streamError.decode(p, se) - case "features": + case StreamFeatures{}.Name().Local: return streamFeatures.decode(p, se) default: - return nil, errors.New("unexpected XMPP packet " + + return nil, errors.New("unexpected stream XMPP packet " + se.Name.Space + " <" + se.Name.Local + "/>") } } func decodeSASL(p *xml.Decoder, se xml.StartElement) (Packet, error) { switch se.Name.Local { - case "success": + case SASLSuccess{}.Name().Local: return saslSuccess.decode(p, se) - case "failure": + case SASLFailure{}.Name().Local: return saslFailure.decode(p, se) default: - return nil, errors.New("unexpected XMPP packet " + + return nil, errors.New("unexpected sasl XMPP packet " + se.Name.Space + " <" + se.Name.Local + "/>") } } func decodeClient(p *xml.Decoder, se xml.StartElement) (Packet, error) { switch se.Name.Local { - case "message": + case Message{}.Name().Local: return message.decode(p, se) - case "presence": + case Presence{}.Name().Local: return presence.decode(p, se) - case "iq": + case IQ{}.Name().Local: return iq.decode(p, se) default: - return nil, errors.New("unexpected XMPP packet " + + return nil, errors.New("unexpected client XMPP packet " + se.Name.Space + " <" + se.Name.Local + "/>") } } func decodeComponent(p *xml.Decoder, se xml.StartElement) (Packet, error) { switch se.Name.Local { - case "handshake": + case Handshake{}.Name().Local: return handshake.decode(p, se) - case "message": + case Message{}.Name().Local: return message.decode(p, se) - case "presence": + case Presence{}.Name().Local: return presence.decode(p, se) - case "iq": + case IQ{}.Name().Local: return iq.decode(p, se) default: - return nil, errors.New("unexpected XMPP packet " + + return nil, errors.New("unexpected component XMPP packet " + se.Name.Space + " <" + se.Name.Local + "/>") } } diff --git a/pep.go b/pep.go index 1870ca8..6ead8ae 100644 --- a/pep.go +++ b/pep.go @@ -1,4 +1,4 @@ -package xmpp // import "gosrc.io/xmpp/pep" +package xmpp // import "gosrc.io/xmpp" import ( "encoding/xml" diff --git a/presence.go b/presence.go index 0cadeaf..6845a73 100644 --- a/presence.go +++ b/presence.go @@ -14,8 +14,8 @@ type Presence struct { Error Err `xml:"error,omitempty"` } -func (Presence) Name() string { - return "presence" +func (p Presence) Name() xml.Name { + return p.XMLName } func NewPresence(from, to, id, lang string) Presence { diff --git a/registry.go b/registry.go index 49e6178..3270130 100644 --- a/registry.go +++ b/registry.go @@ -6,7 +6,9 @@ import ( "sync" ) -type MsgExtension interface{} +type MsgExtension interface { + Name() xml.Name +} // The Registry for msg and IQ types is a global variable. // TODO: Move to the client init process to remove the dependency on a global variable. @@ -49,7 +51,8 @@ func newRegistry() *registry { // The match is done per packetType (iq, message, or presence) and XML tag name. // You can use the alias "*" as local XML name to be able to match all unknown tag name for that // packet type and namespace. -func (r *registry) MapExtension(pktType packetType, name xml.Name, extension MsgExtension) { +func (r *registry) MapExtension(pktType packetType, extension MsgExtension) { + name := extension.Name() key := registryKey{pktType, name.Space} r.msgTypesLock.RLock() store := r.msgTypes[key] diff --git a/registry_test.go b/registry_test.go index 4b3ba81..500be19 100644 --- a/registry_test.go +++ b/registry_test.go @@ -1,7 +1,6 @@ package xmpp // import "gosrc.io/xmpp" import ( - "encoding/xml" "reflect" "testing" ) @@ -11,11 +10,14 @@ func TestRegistry_RegisterMsgExt(t *testing.T) { typeRegistry := newRegistry() // Register an element - name := xml.Name{Space: "urn:xmpp:receipts", Local: "request"} - typeRegistry.MapExtension(PKTMessage, name, ReceiptRequest{}) + req := ReceiptRequest{} + res := ReceiptReceived{} + + typeRegistry.MapExtension(PKTMessage, req) + typeRegistry.MapExtension(PKTMessage, res) // Match that element - receipt := typeRegistry.GetMsgExtension(name) + receipt := typeRegistry.GetMsgExtension(req.Name()) if receipt == nil { t.Error("cannot read element type from registry") return @@ -33,12 +35,12 @@ func BenchmarkRegistryGet(b *testing.B) { typeRegistry := newRegistry() // Register an element - name := xml.Name{Space: "urn:xmpp:receipts", Local: "request"} - typeRegistry.MapExtension(PKTMessage, name, ReceiptRequest{}) + req := ReceiptRequest{} + typeRegistry.MapExtension(PKTMessage, req) for i := 0; i < b.N; i++ { // Match that element - receipt := typeRegistry.GetExtensionType(PKTMessage, name) + receipt := typeRegistry.GetExtensionType(PKTMessage, req.Name()) if receipt == nil { b.Error("cannot read element type from registry") return diff --git a/stream.go b/stream.go index 1e0df97..87b9476 100644 --- a/stream.go +++ b/stream.go @@ -17,8 +17,8 @@ type StreamFeatures struct { Any []xml.Name `xml:",any"` } -func (StreamFeatures) Name() string { - return "stream:features" +func (s StreamFeatures) Name() xml.Name { + return s.XMLName } type streamFeatureDecoder struct{} @@ -39,8 +39,8 @@ type StreamError struct { Error xml.Name `xml:",any"` } -func (StreamError) Name() string { - return "stream:error" +func (s StreamError) Name() xml.Name { + return s.XMLName } type streamErrorDecoder struct{}