Skip to content

Commit f4d6e83

Browse files
shailend-ggvisor-bot
authored andcommitted
Allow IP_(ADD|REMOVE)_MEMBERSHIP on AF_INET6 sockets
Just like Linux does. The new test `AddV4MembershipToV6Socket` passes only if there is a routing table entry to resolve the ipv4 multicast address, and that is easiest to setup via a "external networking" test. Now it turns out that the ipv6 "external networking" test did not actually have a second non-loopback (external) interface, so I've added that. PiperOrigin-RevId: 807959283
1 parent 6448dbe commit f4d6e83

File tree

9 files changed

+185
-73
lines changed

9 files changed

+185
-73
lines changed

pkg/tcpip/tcpip.go

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -49,10 +49,8 @@ import (
4949

5050
// Using the header package here would cause an import cycle.
5151
const (
52-
ipv4AddressSize = 4
53-
ipv4ProtocolNumber = 0x0800
54-
ipv6AddressSize = 16
55-
ipv6ProtocolNumber = 0x86dd
52+
ipv4AddressSize = 4
53+
ipv6AddressSize = 16
5654
)
5755

5856
const (

pkg/tcpip/tests/integration/BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,6 @@ go_test(
128128
"//pkg/tcpip/tests/utils",
129129
"//pkg/tcpip/testutil",
130130
"//pkg/tcpip/transport/icmp",
131-
"//pkg/tcpip/transport/raw",
132131
"//pkg/tcpip/transport/udp",
133132
"//pkg/waiter",
134133
"@com_github_google_go_cmp//cmp:go_default_library",

pkg/tcpip/tests/integration/multicast_broadcast_test.go

Lines changed: 0 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,6 @@ import (
3333
"gvisor.dev/gvisor/pkg/tcpip/tests/utils"
3434
"gvisor.dev/gvisor/pkg/tcpip/testutil"
3535
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
36-
"gvisor.dev/gvisor/pkg/tcpip/transport/raw"
3736
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
3837
"gvisor.dev/gvisor/pkg/waiter"
3938
)
@@ -779,54 +778,3 @@ func TestAddMembershipInterfacePrecedence(t *testing.T) {
779778
t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err)
780779
}
781780
}
782-
783-
func TestMismatchedMulticastAddressAndProtocol(t *testing.T) {
784-
const nicID = 1
785-
// MulticastAddr is IPv4, but proto is IPv6.
786-
multicastAddr := tcpip.AddrFromSlice([]byte("\xe0\x01\x02\x03"))
787-
s := stack.New(stack.Options{
788-
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol},
789-
TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol6},
790-
RawFactory: raw.EndpointFactory{},
791-
})
792-
e := channel.New(0, defaultMTU, "")
793-
defer e.Close()
794-
if err := s.CreateNIC(nicID, e); err != nil {
795-
t.Fatalf("CreateNIC(%d, _): %s", nicID, err)
796-
}
797-
protoAddr := tcpip.ProtocolAddress{Protocol: header.IPv4ProtocolNumber, AddressWithPrefix: utils.Ipv4Addr}
798-
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
799-
t.Fatalf("AddProtocolAddress(%d, %+v, {}): %s", nicID, protoAddr, err)
800-
}
801-
802-
var wq waiter.Queue
803-
ep, err := s.NewRawEndpoint(header.ICMPv6ProtocolNumber, header.IPv6ProtocolNumber, &wq, false)
804-
if err != nil {
805-
t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, header.IPv6ProtocolNumber, err)
806-
}
807-
defer ep.Close()
808-
809-
bindAddr := tcpip.FullAddress{Port: utils.LocalPort}
810-
if err := ep.Bind(bindAddr); err != nil {
811-
t.Fatalf("ep.Bind(%#v): %s", bindAddr, err)
812-
}
813-
814-
memOpt := tcpip.MembershipOption{
815-
MulticastAddr: multicastAddr,
816-
NIC: 0,
817-
InterfaceAddr: utils.Ipv4Addr.Address,
818-
}
819-
820-
// Add/remove membership should succeed when the interface index is specified,
821-
// even if a bad interface address is specified.
822-
addOpt := tcpip.AddMembershipOption(memOpt)
823-
expErr := &tcpip.ErrInvalidOptionValue{}
824-
if err := ep.SetSockOpt(&addOpt); err != expErr {
825-
t.Fatalf("ep.SetSockOpt(&%#v): want %q, got %q", addOpt, expErr, err)
826-
}
827-
828-
removeOpt := tcpip.RemoveMembershipOption(memOpt)
829-
if err := ep.SetSockOpt(&removeOpt); err != expErr {
830-
t.Fatalf("ep.SetSockOpt(&%#v): want %q, got %q", addOpt, expErr, err)
831-
}
832-
}

pkg/tcpip/transport/internal/network/endpoint.go

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,11 @@ func (e *Endpoint) Close() {
181181
}
182182

183183
for mem := range e.multicastMemberships {
184-
e.stack.LeaveGroup(e.netProto, mem.nicID, mem.multicastAddr)
184+
proto, err := e.multicastNetProto(mem.multicastAddr)
185+
if err != nil {
186+
panic("non multicast address in an existing membership")
187+
}
188+
e.stack.LeaveGroup(proto, mem.nicID, mem.multicastAddr)
185189
}
186190
e.multicastMemberships = nil
187191

@@ -912,6 +916,19 @@ func (e *Endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, tcpip.Error) {
912916
}
913917
}
914918

919+
// multicastNetProto returns the network protocol of a given multicast address.
920+
// Returns an error if the address is not a multicast address.
921+
func (e *Endpoint) multicastNetProto(addr tcpip.Address) (tcpip.NetworkProtocolNumber, tcpip.Error) {
922+
switch {
923+
case header.IsV4MulticastAddress(addr):
924+
return header.IPv4ProtocolNumber, nil
925+
case header.IsV6MulticastAddress(addr):
926+
return header.IPv6ProtocolNumber, nil
927+
default:
928+
return 0, &tcpip.ErrInvalidOptionValue{}
929+
}
930+
}
931+
915932
// SetSockOpt sets the socket option.
916933
func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
917934
switch v := opt.(type) {
@@ -952,21 +969,21 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
952969
e.multicastAddr = addr
953970

954971
case *tcpip.AddMembershipOption:
955-
if !(header.IsV4MulticastAddress(v.MulticastAddr) && e.netProto == header.IPv4ProtocolNumber) && !(header.IsV6MulticastAddress(v.MulticastAddr) && e.netProto == header.IPv6ProtocolNumber) {
956-
return &tcpip.ErrInvalidOptionValue{}
972+
proto, err := e.multicastNetProto(v.MulticastAddr)
973+
if err != nil {
974+
return err
957975
}
958976

959977
nicID := v.NIC
960-
961978
if v.InterfaceAddr.Unspecified() {
962979
if nicID == 0 {
963-
if r, err := e.stack.FindRoute(0, tcpip.Address{}, v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
980+
if r, err := e.stack.FindRoute(0, tcpip.Address{}, v.MulticastAddr, proto, false /* multicastLoop */); err == nil {
964981
nicID = r.NICID()
965982
r.Release()
966983
}
967984
}
968985
} else {
969-
nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
986+
nicID = e.stack.CheckLocalAddress(nicID, proto, v.InterfaceAddr)
970987
}
971988
if nicID == 0 {
972989
return &tcpip.ErrUnknownDevice{}
@@ -981,27 +998,28 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
981998
return &tcpip.ErrPortInUse{}
982999
}
9831000

984-
if err := e.stack.JoinGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
1001+
if err := e.stack.JoinGroup(proto, nicID, v.MulticastAddr); err != nil {
9851002
return err
9861003
}
9871004

9881005
e.multicastMemberships[memToInsert] = struct{}{}
9891006

9901007
case *tcpip.RemoveMembershipOption:
991-
if !(header.IsV4MulticastAddress(v.MulticastAddr) && e.netProto == header.IPv4ProtocolNumber) && !(header.IsV6MulticastAddress(v.MulticastAddr) && e.netProto == header.IPv6ProtocolNumber) {
992-
return &tcpip.ErrInvalidOptionValue{}
1008+
proto, err := e.multicastNetProto(v.MulticastAddr)
1009+
if err != nil {
1010+
return err
9931011
}
9941012

9951013
nicID := v.NIC
9961014
if v.InterfaceAddr.Unspecified() {
9971015
if nicID == 0 {
998-
if r, err := e.stack.FindRoute(0, tcpip.Address{}, v.MulticastAddr, e.netProto, false /* multicastLoop */); err == nil {
1016+
if r, err := e.stack.FindRoute(0, tcpip.Address{}, v.MulticastAddr, proto, false /* multicastLoop */); err == nil {
9991017
nicID = r.NICID()
10001018
r.Release()
10011019
}
10021020
}
10031021
} else {
1004-
nicID = e.stack.CheckLocalAddress(nicID, e.netProto, v.InterfaceAddr)
1022+
nicID = e.stack.CheckLocalAddress(nicID, proto, v.InterfaceAddr)
10051023
}
10061024
if nicID == 0 {
10071025
return &tcpip.ErrUnknownDevice{}
@@ -1016,7 +1034,7 @@ func (e *Endpoint) SetSockOpt(opt tcpip.SettableSocketOption) tcpip.Error {
10161034
return &tcpip.ErrBadLocalAddress{}
10171035
}
10181036

1019-
if err := e.stack.LeaveGroup(e.netProto, nicID, v.MulticastAddr); err != nil {
1037+
if err := e.stack.LeaveGroup(proto, nicID, v.MulticastAddr); err != nil {
10201038
return err
10211039
}
10221040

pkg/tcpip/transport/internal/network/endpoint_state.go

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,8 +29,12 @@ func (e *Endpoint) Resume(s *stack.Stack) error {
2929

3030
e.stack = s
3131
for m := range e.multicastMemberships {
32-
if err := e.stack.JoinGroup(e.netProto, m.nicID, m.multicastAddr); err != nil {
33-
return fmt.Errorf("e.stack.JoinGroup(%d, %d, %s): %s", e.netProto, m.nicID, m.multicastAddr, err)
32+
proto, err := e.multicastNetProto(m.multicastAddr)
33+
if err != nil {
34+
panic("non multicast address in an existing membership during Resume")
35+
}
36+
if err := e.stack.JoinGroup(proto, m.nicID, m.multicastAddr); err != nil {
37+
return fmt.Errorf("e.stack.JoinGroup(%d, %d, %s): %s", proto, m.nicID, m.multicastAddr, err)
3438
}
3539
}
3640

test/syscalls/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -810,6 +810,12 @@ syscall_test(
810810
test = "//test/syscalls/linux:socket_ipv4_udp_unbound_external_networking_test",
811811
)
812812

813+
syscall_test(
814+
# FIXME: TestJoinLeaveMulticast fails with add_hostinet.
815+
netstack_sr = True,
816+
test = "//test/syscalls/linux:socket_ipv6_udp_unbound_external_networking_test",
817+
)
818+
813819
syscall_test(
814820
size = "large",
815821
add_hostinet = True,

test/syscalls/linux/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2896,7 +2896,13 @@ cc_library(
28962896
"socket_ipv6_udp_unbound_external_networking.h",
28972897
],
28982898
deps = select_gtest() + [
2899+
":ip_socket_test_util",
28992900
":socket_ip_udp_unbound_external_networking",
2901+
"//test/util:file_descriptor",
2902+
"//test/util:posix_error",
2903+
"//test/util:socket_util",
2904+
"//test/util:test_util",
2905+
"@com_google_absl//absl/cleanup",
29002906
],
29012907
alwayslink = 1,
29022908
)

test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.cc

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,72 @@
1414

1515
#include "test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.h"
1616

17+
#include <net/if.h>
18+
#include <sys/socket.h>
19+
20+
#include <cerrno>
21+
#include <cstring>
22+
#include <iostream>
23+
#include <optional>
24+
#include <utility>
25+
26+
#include "gmock/gmock.h"
27+
#include "gtest/gtest.h"
28+
#include "absl/cleanup/cleanup.h"
29+
#include "test/syscalls/linux/ip_socket_test_util.h"
30+
#include "test/util/file_descriptor.h"
31+
#include "test/util/posix_error.h"
32+
#include "test/util/socket_util.h"
33+
#include "test/util/test_util.h"
34+
1735
namespace gvisor {
1836
namespace testing {
1937

38+
void IPv6UDPUnboundExternalNetworkingSocketTest::SetUp() {
39+
#ifdef ANDROID
40+
GTEST_SKIP() << "Android does not support getifaddrs in r22";
41+
#endif
42+
43+
ifaddrs* ifaddr;
44+
ASSERT_THAT(getifaddrs(&ifaddr), SyscallSucceeds());
45+
auto cleanup = absl::MakeCleanup([ifaddr] { freeifaddrs(ifaddr); });
46+
47+
for (const ifaddrs* ifa = ifaddr; ifa != nullptr; ifa = ifa->ifa_next) {
48+
ASSERT_NE(ifa->ifa_name, nullptr);
49+
ASSERT_NE(ifa->ifa_addr, nullptr);
50+
51+
if (ifa->ifa_addr->sa_family != AF_INET6) {
52+
continue;
53+
}
54+
55+
std::optional<std::pair<int, sockaddr_in6>>& if_pair = *[this, ifa]() {
56+
if (strcmp(ifa->ifa_name, "lo") == 0) {
57+
return &lo_if_;
58+
}
59+
return &eth_if_;
60+
}();
61+
62+
const int if_index =
63+
ASSERT_NO_ERRNO_AND_VALUE(InterfaceIndex(ifa->ifa_name));
64+
65+
std::cout << " name=" << ifa->ifa_name
66+
<< " addr=" << GetAddrStr(ifa->ifa_addr) << " index=" << if_index
67+
<< " has_value=" << if_pair.has_value() << std::endl;
68+
69+
if (if_pair.has_value()) {
70+
continue;
71+
}
72+
73+
if_pair = std::make_pair(
74+
if_index, *reinterpret_cast<const sockaddr_in6*>(ifa->ifa_addr));
75+
}
76+
77+
if (!(eth_if_.has_value() && lo_if_.has_value())) {
78+
GTEST_SKIP() << " eth_if_.has_value()=" << eth_if_.has_value()
79+
<< " lo_if_.has_value()=" << lo_if_.has_value();
80+
}
81+
}
82+
2083
TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, TestJoinLeaveMulticast) {
2184
auto sender = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
2285
auto receiver = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
@@ -82,5 +145,58 @@ TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, TestJoinLeaveMulticast) {
82145
SyscallFailsWithErrno(EAGAIN));
83146
}
84147

148+
// Test that an AF_INET6 socket can set the IP_ADD_MEMBERSHIP socket option.
149+
TEST_P(IPv6UDPUnboundExternalNetworkingSocketTest, AddV4MembershipToV6Socket) {
150+
TestAddress send_addr = V4Multicast();
151+
sockaddr_in* send_addr_in = reinterpret_cast<sockaddr_in*>(&send_addr.addr);
152+
153+
// recv is an AF_INET6 socket while send is an AF_INET socket.
154+
auto recv = ASSERT_NO_ERRNO_AND_VALUE(NewSocket());
155+
FileDescriptor send =
156+
ASSERT_NO_ERRNO_AND_VALUE(Socket(AF_INET, SOCK_DGRAM, IPPROTO_UDP));
157+
158+
// Make recv join the multicast group with address `send_addr`.
159+
// Note that IP_ADD_MEMBERSHIP is used instead of IPV6_ADD_MEMBERSHIP, and
160+
// the group address is an IPv4 address.
161+
struct ip_mreq mreq;
162+
mreq.imr_multiaddr.s_addr = send_addr_in->sin_addr.s_addr;
163+
mreq.imr_interface.s_addr = htonl(INADDR_ANY);
164+
ASSERT_THAT(
165+
setsockopt(recv->get(), SOL_IP, IP_ADD_MEMBERSHIP, &mreq, sizeof(mreq)),
166+
SyscallSucceeds());
167+
168+
// Bind recv to ::.
169+
auto recv_addr = V6Any();
170+
ASSERT_THAT(
171+
bind(recv->get(), AsSockAddr(&recv_addr.addr), recv_addr.addr_len),
172+
SyscallSucceeds());
173+
socklen_t recv_addr_len = recv_addr.addr_len;
174+
ASSERT_THAT(
175+
getsockname(recv->get(), AsSockAddr(&recv_addr.addr), &recv_addr_len),
176+
SyscallSucceeds());
177+
EXPECT_EQ(recv_addr_len, recv_addr.addr_len);
178+
179+
// Send a multicast packet...
180+
send_addr_in->sin_port =
181+
reinterpret_cast<sockaddr_in*>(&recv_addr.addr)->sin_port;
182+
char send_buf[200];
183+
RandomizeBuffer(send_buf, sizeof(send_buf));
184+
ASSERT_THAT(
185+
RetryEINTR(sendto)(send.get(), send_buf, sizeof(send_buf), 0,
186+
AsSockAddr(&send_addr.addr), send_addr.addr_len),
187+
SyscallSucceedsWithValue(sizeof(send_buf)));
188+
189+
// ...and check that it was received.
190+
char recv_buf[sizeof(send_buf)] = {};
191+
ASSERT_THAT(
192+
RecvTimeout(recv->get(), recv_buf, sizeof(recv_buf), 1 /*timeout*/),
193+
IsPosixErrorOkAndHolds(sizeof(recv_buf)));
194+
EXPECT_EQ(0, memcmp(send_buf, recv_buf, sizeof(send_buf)));
195+
196+
ASSERT_THAT(
197+
setsockopt(recv->get(), SOL_IP, IP_DROP_MEMBERSHIP, &mreq, sizeof(mreq)),
198+
SyscallSucceeds());
199+
}
200+
85201
} // namespace testing
86202
} // namespace gvisor

test/syscalls/linux/socket_ipv6_udp_unbound_external_networking.h

Lines changed: 19 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,32 @@
1515
#ifndef GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_EXTERNAL_NETWORKING_H_
1616
#define GVISOR_TEST_SYSCALLS_LINUX_SOCKET_IPV6_UDP_UNBOUND_EXTERNAL_NETWORKING_H_
1717

18+
#include <optional>
19+
#include <utility>
20+
1821
#include "test/syscalls/linux/socket_ip_udp_unbound_external_networking.h"
1922

2023
namespace gvisor {
2124
namespace testing {
2225

2326
// Test fixture for tests that apply to unbound IPv6 UDP sockets in a sandbox
2427
// with external networking support.
25-
using IPv6UDPUnboundExternalNetworkingSocketTest =
26-
IPUDPUnboundExternalNetworkingSocketTest;
28+
class IPv6UDPUnboundExternalNetworkingSocketTest
29+
: public IPUDPUnboundExternalNetworkingSocketTest {
30+
protected:
31+
void SetUp() override;
32+
33+
int lo_if_idx() const { return std::get<0>(lo_if_.value()); }
34+
int eth_if_idx() const { return std::get<0>(eth_if_.value()); }
35+
36+
const sockaddr_in6& lo_if_addr() const { return std::get<1>(lo_if_.value()); }
37+
const sockaddr_in6& eth_if_addr() const {
38+
return std::get<1>(eth_if_.value());
39+
}
40+
41+
private:
42+
std::optional<std::pair<int, sockaddr_in6>> lo_if_, eth_if_;
43+
};
2744

2845
} // namespace testing
2946
} // namespace gvisor

0 commit comments

Comments
 (0)