diff --git a/src/openvpn/forward.c b/src/openvpn/forward.c index 0798337c19d..0f2ec072530 100644 --- a/src/openvpn/forward.c +++ b/src/openvpn/forward.c @@ -1379,8 +1379,6 @@ drop_if_recursive_routing(struct context *c, struct buffer *buf) if (proto_ver == 4) { - const struct openvpn_iphdr *pip; - /* make sure we got whole IP header */ if (BLEN(buf) < ((int) sizeof(struct openvpn_iphdr) + ip_hdr_offset)) { @@ -1393,18 +1391,16 @@ drop_if_recursive_routing(struct context *c, struct buffer *buf) return; } - pip = (struct openvpn_iphdr *) (BPTR(buf) + ip_hdr_offset); + struct openvpn_iphdr *pip = (struct openvpn_iphdr *) (BPTR(buf) + ip_hdr_offset); /* drop packets with same dest addr as gateway */ - if (tun_sa.addr.in4.sin_addr.s_addr == pip->daddr) + if (memcmp(&tun_sa.addr.in4.sin_addr.s_addr, &pip->daddr, sizeof(pip->daddr)) == 0) { drop = true; } } else if (proto_ver == 6) { - const struct openvpn_ipv6hdr *pip6; - /* make sure we got whole IPv6 header */ if (BLEN(buf) < ((int) sizeof(struct openvpn_ipv6hdr) + ip_hdr_offset)) { @@ -1417,9 +1413,10 @@ drop_if_recursive_routing(struct context *c, struct buffer *buf) return; } + struct openvpn_ipv6hdr *pip6 = (struct openvpn_ipv6hdr *) (BPTR(buf) + ip_hdr_offset); + /* drop packets with same dest addr as gateway */ - pip6 = (struct openvpn_ipv6hdr *) (BPTR(buf) + ip_hdr_offset); - if (IN6_ARE_ADDR_EQUAL(&tun_sa.addr.in6.sin6_addr, &pip6->daddr)) + if (OPENVPN_IN6_ARE_ADDR_EQUAL(&tun_sa.addr.in6.sin6_addr, &pip6->daddr)) { drop = true; } diff --git a/src/openvpn/proto.h b/src/openvpn/proto.h index 4b6d6d6efca..a21fc99cf71 100644 --- a/src/openvpn/proto.h +++ b/src/openvpn/proto.h @@ -103,6 +103,12 @@ struct openvpn_arp { in_addr_t ip_dest; }; +/** Version of IN6_ARE_ADDR_EQUAL that is guaranteed to work for + * unaligned access. E.g. Linux uses 32bit compares which are + * not safe if the struct is unaligned. */ +#define OPENVPN_IN6_ARE_ADDR_EQUAL(a, b) \ + (memcmp(a, b, sizeof(struct in6_addr)) == 0) + struct openvpn_iphdr { #define OPENVPN_IPH_GET_VER(v) (((v) >> 4) & 0x0F) #define OPENVPN_IPH_GET_LEN(v) (((v) & 0x0F) << 2)