Skip to content

Commit

Permalink
Simplify checksum calculation
Browse files Browse the repository at this point in the history
Hope I didn't mess this up...
  • Loading branch information
hack3ric committed Apr 1, 2024
1 parent 338ad73 commit 813b373
Show file tree
Hide file tree
Showing 3 changed files with 32 additions and 48 deletions.
26 changes: 13 additions & 13 deletions src/bpf/ingress.c
Original file line number Diff line number Diff line change
Expand Up @@ -190,9 +190,10 @@ int ingress_handler(struct xdp_md* xdp) {
ipv4->tot_len = new_len;
ipv4->protocol = IPPROTO_UDP;

// See RFC 1624
__u32 ipv4_csum = (__u16)~ntohs(ipv4->check);
update_csum(&ipv4_csum, -(__s32)TCP_UDP_HEADER_DIFF);
update_csum(&ipv4_csum, IPPROTO_UDP - IPPROTO_TCP);
ipv4_csum += 0xffff - TCP_UDP_HEADER_DIFF;
ipv4_csum += ~IPPROTO_UDP + IPPROTO_TCP;
ipv4->check = htons(csum_fold(ipv4_csum));
} else if (ipv6) {
ipv6_saddr = ipv6->saddr, ipv6_daddr = ipv6->daddr;
Expand All @@ -208,21 +209,20 @@ int ingress_handler(struct xdp_md* xdp) {

__u32 csum = 0;
if (ipv4) {
update_csum_ul(&csum, ntohl(ipv4_saddr));
update_csum_ul(&csum, ntohl(ipv4_daddr));
csum += u32_fold(ntohl(ipv4_saddr));
csum += u32_fold(ntohl(ipv4_daddr));
} else if (ipv6) {
for (int i = 0; i < 8; i++) {
update_csum(&csum, ntohs(ipv6_saddr.in6_u.u6_addr16[i]));
update_csum(&csum, ntohs(ipv6_daddr.in6_u.u6_addr16[i]));
csum += ntohs(ipv6_saddr.in6_u.u6_addr16[i]);
csum += ntohs(ipv6_daddr.in6_u.u6_addr16[i]);
}
}
update_csum(&csum, IPPROTO_UDP);
update_csum(&csum, udp_len);
update_csum(&csum, ntohs(udp->source));
update_csum(&csum, ntohs(udp->dest));
update_csum(&csum, udp_len);

update_csum_data(xdp, &csum, ip_end + sizeof(*udp));
csum += IPPROTO_UDP;
csum += udp_len;
csum += ntohs(udp->source);
csum += ntohs(udp->dest);
csum += udp_len;
csum += calc_csum_ctx(xdp, ip_end + sizeof(*udp));
udp->check = htons(csum_fold(csum));

return XDP_PASS;
Expand Down
22 changes: 11 additions & 11 deletions src/run.c
Original file line number Diff line number Diff line change
Expand Up @@ -182,20 +182,20 @@ static inline int send_ctrl_packet(struct send_options* s) {
__u32 local = s->conn.local.v4, remote = s->conn.remote.v4;
*(struct sockaddr_in*)&saddr = (struct sockaddr_in){.sin_family = AF_INET, .sin_addr = local, .sin_port = 0};
*(struct sockaddr_in*)&daddr = (struct sockaddr_in){.sin_family = AF_INET, .sin_addr = remote, .sin_port = 0};
update_csum_ul(&csum, ntohl(local));
update_csum_ul(&csum, ntohl(remote));
csum += u32_fold(ntohl(local));
csum += u32_fold(ntohl(remote));
} else {
*(struct sockaddr_in6*)&saddr =
(struct sockaddr_in6){.sin6_family = AF_INET6, .sin6_addr = s->conn.local.v6, .sin6_port = 0};
*(struct sockaddr_in6*)&daddr =
(struct sockaddr_in6){.sin6_family = AF_INET6, .sin6_addr = s->conn.remote.v6, .sin6_port = 0};
for (int i = 0; i < 8; i++) {
update_csum(&csum, ntohs(s->conn.local.v6.s6_addr16[i]));
update_csum(&csum, ntohs(s->conn.remote.v6.s6_addr16[i]));
csum += ntohs(s->conn.local.v6.s6_addr16[i]);
csum += ntohs(s->conn.remote.v6.s6_addr16[i]);
}
}
update_csum(&csum, IPPROTO_TCP);
update_csum(&csum, sizeof(struct tcphdr));
csum += IPPROTO_TCP;
csum += sizeof(struct tcphdr);
try(bind(sk, (struct sockaddr*)&saddr, sizeof(saddr)), _("failed to bind: %s"), strerror(-_ret));

struct tcphdr tcp = {
Expand All @@ -210,11 +210,11 @@ static inline int send_ctrl_packet(struct send_options* s) {
.window = htons(0xfff),
.urg_ptr = 0,
};
update_csum(&csum, ntohs(tcp.source));
update_csum(&csum, ntohs(tcp.dest));
update_csum_ul(&csum, s->seq);
update_csum_ul(&csum, s->ack_seq);
update_csum_ul(&csum, ntohl(tcp_flag_word(&tcp)));
csum += ntohs(tcp.source);
csum += ntohs(tcp.dest);
csum += u32_fold(s->seq);
csum += u32_fold(s->ack_seq);
csum += u32_fold(ntohl(tcp_flag_word(&tcp)));
tcp.check = htons(csum_fold(csum));

try(sendto(sk, &tcp, sizeof(tcp), 0, (struct sockaddr*)&daddr, sizeof(daddr)), _("failed to send: %s"),
Expand Down
32 changes: 8 additions & 24 deletions src/shared/checksum.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,46 +11,30 @@

#include "util.h"

static inline __u16 csum_fold(__u32 csum) {
csum = (csum & 0xffff) + (csum >> 16);
csum = (csum & 0xffff) + (csum >> 16);
return (__u16)~csum;
}

static inline void update_csum(__u32* csum, __s32 delta) {
if (delta < 0) delta += 0xffff;
*csum += delta;
}

static inline void update_csum_ul(__u32* csum, __u32 new) {
__s32 value = (new >> 16) + (new & 0xffff);
update_csum(csum, value);
}

static inline void update_csum_ul_neg(__u32* csum, __u32 new) {
__s32 value = -(new >> 16) - (new & 0xffff);
update_csum(csum, value);
}
static inline __u32 u32_fold(__u32 num) { return (num & 0xffff) + (num >> 16); }
static inline __u16 csum_fold(__u32 csum) { return ~u32_fold(u32_fold(csum)); }

#ifdef _MIMIC_BPF

// HACK: make verifier happy; otherwise it will complain "32-bit arithmetic prohibited" on
// {skb,xdp}->{data,data_end} using the signature `void update_csum_data(__u32 data, __u32 data_end,
// __u32* csum, __u32 off)`.
//
// void update_csum_data(void* ctx, __u32* csum, __u32 off)
#define update_csum_data(_x, csum, off) \
// __u32 calc_csum_ctx(void* ctx, __u32 off)
#define calc_csum_ctx(_x, off) \
({ \
__u32 csum = 0; \
__u16* data = (void*)(__u64)_x->data + off; \
int i = 0; \
for (; i < MAX_PACKET_SIZE / sizeof(__u16); i++) { \
if ((__u64)(data + i + 1) > (__u64)_x->data_end) break; \
*csum += ntohs(data[i]); \
csum += ntohs(data[i]); \
} \
__u8* remainder = (__u8*)data + i * sizeof(__u16); \
if ((__u64)(remainder + 1) <= (__u64)_x->data_end) { \
*csum += (__u16)(*remainder << 8); \
csum += (__u16)(*remainder << 8); \
} \
csum; \
})

#else
Expand Down

0 comments on commit 813b373

Please sign in to comment.