Skip to content

Commit 01af1cf

Browse files
rootastoycos
authored andcommitted
Update UDP bit logic + cleanup
Signed-off-by: astoycos <[email protected]>
1 parent f75c48e commit 01af1cf

File tree

7 files changed

+264
-28
lines changed

7 files changed

+264
-28
lines changed

.vscode/settings.json

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,5 @@
11
{
2+
"files.associations": {
3+
"bpf_helpers.h": "c"
4+
}
25
}

bpf/tc_udp.bpf.c

Lines changed: 217 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,217 @@
1+
// +build ignore
2+
3+
#include <linux/bpf_common.h>
4+
#include <linux/if_ether.h>
5+
#include <linux/in.h>
6+
#include <linux/ip.h>
7+
#include <linux/udp.h>
8+
#include <linux/bpf.h>
9+
#include <bpf_helpers.h>
10+
#include <bpf_endian.h>
11+
12+
char __license[] SEC("license") = "GPL";
13+
14+
#ifndef memcpy
15+
#define memcpy(dest, src, n) __builtin_memcpy((dest), (src), (n))
16+
#endif
17+
18+
#define MAX_BACKENDS 128
19+
#define MAX_UDP_LENGTH 1480
20+
21+
#define UDP_PAYLOAD_SIZE(x) (unsigned int)(((bpf_htons(x) - sizeof(struct udphdr)) * 8 ) / 4)
22+
23+
static __always_inline void ip_from_int(__u32 *buf, __be32 ip) {
24+
buf[0] = (ip >> 0 ) & 0xFF;
25+
buf[1] = (ip >> 8 ) & 0xFF;
26+
buf[2] = (ip >> 16 ) & 0xFF;
27+
buf[3] = (ip >> 24 ) & 0xFF;
28+
}
29+
30+
static __always_inline void bpf_printk_ip(__be32 ip) {
31+
__u32 ip_parts[4];
32+
ip_from_int((__u32 *)&ip_parts, ip);
33+
bpf_printk("%d.%d.%d.%d", ip_parts[0], ip_parts[1], ip_parts[2], ip_parts[3]);
34+
}
35+
36+
static __always_inline __u16 csum_fold_helper(__u64 csum) {
37+
int i;
38+
#pragma unroll
39+
for (i = 0; i < 4; i++)
40+
{
41+
if (csum >> 16)
42+
csum = (csum & 0xffff) + (csum >> 16);
43+
}
44+
return ~csum;
45+
}
46+
47+
static __always_inline __u16 iph_csum(struct iphdr *iph) {
48+
iph->check = 0;
49+
unsigned long long csum = bpf_csum_diff(0, 0, (unsigned int *)iph, sizeof(struct iphdr), 0);
50+
return csum_fold_helper(csum);
51+
}
52+
53+
// static __always_inline __u16 udp_checksum(struct iphdr *ip, struct udphdr * udp, void * data_end) {
54+
// udp->check = 0;
55+
56+
// // So we can overflow a bit make this __u32
57+
// __u32 csum_total = 0;
58+
// __u16 csum;
59+
// __u16 *buf = (void *)udp;
60+
61+
// csum_total += (__u16)ip->saddr;
62+
// csum_total += (__u16)(ip->saddr >> 16);
63+
// csum_total += (__u16)ip->daddr;
64+
// csum_total += (__u16)(ip->daddr >> 16);
65+
// csum_total += (__u16)(ip->protocol << 8);
66+
// csum_total += udp->len;
67+
68+
// // The number of nibbles in the UDP header + Payload
69+
// unsigned int udp_packet_nibbles = UDP_PAYLOAD_SIZE(udp->len);
70+
71+
// // Here we only want to iterate through payload
72+
// // NOT trailing bits
73+
// for (int i = 0; i <= MAX_UDP_LENGTH; i += 2) {
74+
// if (i > udp_packet_nibbles) {
75+
// break;
76+
// }
77+
78+
// if ((void *)(buf + 1) > data_end) {
79+
// break;
80+
// }
81+
// csum_total += *buf;
82+
// buf++;
83+
// }
84+
85+
// if ((void *)buf + 1 <= data_end) {
86+
// csum_total += (*(__u8 *)buf);
87+
// }
88+
89+
// // Add any cksum overflow back into __u16
90+
// csum = (__u16)csum_total + (__u16)(csum_total >> 16);
91+
92+
// csum = ~csum;
93+
// return csum;
94+
// }
95+
96+
struct backend {
97+
__u32 saddr;
98+
__u32 daddr;
99+
__u16 dport;
100+
__u16 ifindex;
101+
// Cksum isn't required for UDP see:
102+
// https://en.wikipedia.org/wiki/User_Datagram_Protocol
103+
__u8 nocksum;
104+
__u8 pad[3];
105+
};
106+
107+
struct vip_key {
108+
__u32 vip;
109+
__u16 port;
110+
__u8 pad[2];
111+
};
112+
113+
struct {
114+
__uint(type, BPF_MAP_TYPE_HASH);
115+
__uint(max_entries, MAX_BACKENDS);
116+
__type(key, struct vip_key);
117+
__type(value, struct backend);
118+
} backends SEC(".maps");
119+
120+
SEC("classifier")
121+
int tc_prog_func(struct xdp_md *ctx) {
122+
// ---------------------------------------------------------------------------
123+
// Initialize
124+
// ---------------------------------------------------------------------------
125+
126+
void *data = (void *)(long)ctx->data;
127+
void *data_end = (void *)(long)ctx->data_end;
128+
129+
struct ethhdr *eth = data;
130+
if (data + sizeof(struct ethhdr) > data_end) {
131+
bpf_printk("ABORTED: bad ethhdr!");
132+
return XDP_ABORTED;
133+
}
134+
135+
if (bpf_ntohs(eth->h_proto) != ETH_P_IP) {
136+
bpf_printk("PASS: not IP protocol!");
137+
return XDP_PASS;
138+
}
139+
140+
struct iphdr *ip = data + sizeof(struct ethhdr);
141+
if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) > data_end) {
142+
bpf_printk("ABORTED: bad iphdr!");
143+
return XDP_ABORTED;
144+
}
145+
146+
if (ip->protocol != IPPROTO_UDP)
147+
return XDP_PASS;
148+
149+
struct udphdr *udp = data + sizeof(struct ethhdr) + sizeof(struct iphdr);
150+
if (data + sizeof(struct ethhdr) + sizeof(struct iphdr) + sizeof(struct udphdr) > data_end) {
151+
bpf_printk("ABORTED: bad udphdr!");
152+
return XDP_ABORTED;
153+
}
154+
155+
bpf_printk("UDP packet received - daddr:%x, port:%d", ip->daddr, bpf_ntohs(udp->dest));
156+
157+
// ---------------------------------------------------------------------------
158+
// Routing
159+
// ---------------------------------------------------------------------------
160+
161+
struct vip_key key = {
162+
.vip = ip->daddr,
163+
.port = bpf_ntohs(udp->dest)
164+
};
165+
166+
struct backend *bk;
167+
bk = bpf_map_lookup_elem(&backends, &key);
168+
if (!bk) {
169+
bpf_printk("no backends for ip %x:%x", key.vip, key.port);
170+
return XDP_PASS;
171+
}
172+
173+
bpf_printk("got UDP traffic, source address:");
174+
bpf_printk_ip(ip->saddr);
175+
bpf_printk("destination address:");
176+
bpf_printk_ip(ip->daddr);
177+
178+
ip->saddr = bk->saddr;
179+
ip->daddr = bk->daddr;
180+
181+
bpf_printk("updated saddr to:");
182+
bpf_printk_ip(ip->saddr);
183+
bpf_printk("updated daddr to:");
184+
bpf_printk_ip(ip->daddr);
185+
186+
if (udp->dest != bpf_ntohs(bk->dport)) {
187+
udp->dest = bpf_ntohs(bk->dport);
188+
bpf_printk("updated dport to: %d", bk->dport);
189+
}
190+
191+
// memcpy(eth->h_source, bk->shwaddr, sizeof(eth->h_source));
192+
// bpf_printk("new source hwaddr %x:%x:%x:%x:%x:%x", eth->h_source[0], eth->h_source[1], eth->h_source[2], eth->h_source[3], eth->h_source[4], eth->h_source[5]);
193+
194+
// memcpy(eth->h_dest, bk->dhwaddr, sizeof(eth->h_dest));
195+
// bpf_printk("new dest hwaddr %x:%x:%x:%x:%x:%x", eth->h_dest[0], eth->h_dest[1], eth->h_dest[2], eth->h_dest[3], eth->h_dest[4], eth->h_dest[5]);
196+
197+
ip->check = iph_csum(ip);
198+
udp->check = 0;
199+
200+
if (!bk->nocksum){
201+
udp->check = udp_checksum(ip, udp, data_end);
202+
}
203+
204+
bpf_printk("destination interface index %d", bk->ifindex);
205+
206+
int action = bpf_redirect(bk->ifindex, 0);
207+
208+
bpf_printk("redirect action: %d", action);
209+
210+
return action;
211+
}
212+
213+
// SEC("xdp")
214+
// int bpf_redirect_placeholder(struct xdp_md *ctx) {
215+
// bpf_printk("received a packet on dest interface");
216+
// return XDP_PASS;
217+
// }

bpf/xdp_udp.bpf.c

Lines changed: 29 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@ char __license[] SEC("license") = "GPL";
1818
#define MAX_BACKENDS 128
1919
#define MAX_UDP_LENGTH 1480
2020

21-
#define UDP_PAYLOAD_SIZE(x) (unsigned int)(((bpf_htons(x) - sizeof(struct udphdr)) * 8 ) / 4)
22-
2321
static __always_inline void ip_from_int(__u32 *buf, __be32 ip) {
2422
buf[0] = (ip >> 0 ) & 0xFF;
2523
buf[1] = (ip >> 8 ) & 0xFF;
@@ -50,12 +48,16 @@ static __always_inline __u16 iph_csum(struct iphdr *iph) {
5048
return csum_fold_helper(csum);
5149
}
5250

53-
static __always_inline __u16 udp_checksum(struct iphdr *ip, struct udphdr * udp, void * data_end) {
51+
static __always_inline __u16 udp_csum_diff(struct udphdr *udp) {
5452
udp->check = 0;
53+
unsigned long long csum = bpf_csum_diff(0, 0, (unsigned int *)udp, sizeof(struct udphdr), 0);
54+
return csum_fold_helper(csum);
55+
}
5556

57+
static __always_inline __u16 udp_checksum(struct iphdr *ip, struct udphdr * udp, void * data_end) {
5658
// So we can overflow a bit make this __u32
57-
__u32 csum_total = 0;
58-
__u16 csum;
59+
__u64 csum_total = 0;
60+
5961
__u16 *buf = (void *)udp;
6062

6163
csum_total += (__u16)ip->saddr;
@@ -65,32 +67,33 @@ static __always_inline __u16 udp_checksum(struct iphdr *ip, struct udphdr * udp,
6567
csum_total += (__u16)(ip->protocol << 8);
6668
csum_total += udp->len;
6769

68-
// The number of nibbles in the UDP header + Payload
69-
unsigned int udp_packet_nibbles = UDP_PAYLOAD_SIZE(udp->len);
70+
int udp_len = bpf_ntohs(udp->len);
71+
72+
// Verifier fails without this check
73+
if (udp_len >= MAX_UDP_LENGTH) {
74+
return 1;
75+
}
7076

7177
// Here we only want to iterate through payload
7278
// NOT trailing bits
73-
for (int i = 0; i <= MAX_UDP_LENGTH; i += 2) {
74-
if (i > udp_packet_nibbles) {
75-
break;
76-
}
77-
79+
for (int i = 0; i < udp_len; i += 2) {
80+
// Verifier Fails without this check
7881
if ((void *)(buf + 1) > data_end) {
7982
break;
8083
}
81-
csum_total += *buf;
82-
buf++;
83-
}
8484

85-
if ((void *)buf + 1 <= data_end) {
86-
csum_total += (*(__u8 *)buf);
85+
// Last byte
86+
if (i + 1 == udp_len) {
87+
csum_total += (*(__u8 *)buf);
88+
// Verifier fails without this print statement, I have no Idea why :/
89+
bpf_printk("Adding last byte %X to csum", (*(__u8 *)buf));
90+
} else {
91+
csum_total += *buf;
92+
}
93+
buf+=1;
8794
}
8895

89-
// Add any cksum overflow back into __u16
90-
csum = (__u16)csum_total + (__u16)(csum_total >> 16);
91-
92-
csum = ~csum;
93-
return csum;
96+
return csum_fold_helper(csum_total);
9497
}
9598

9699
struct backend {
@@ -198,10 +201,12 @@ int xdp_prog_func(struct xdp_md *ctx) {
198201
bpf_printk("new dest hwaddr %x:%x:%x:%x:%x:%x", eth->h_dest[0], eth->h_dest[1], eth->h_dest[2], eth->h_dest[3], eth->h_dest[4], eth->h_dest[5]);
199202

200203
ip->check = iph_csum(ip);
204+
201205
udp->check = 0;
202-
203206
if (!bk->nocksum){
204-
udp->check = udp_checksum(ip, udp, data_end);
207+
int tmp_check = udp_checksum(ip, udp, data_end);
208+
bpf_printk("Manual Cksum: %X Diff Cksum %X", tmp_check, udp_csum_diff(udp));
209+
udp->check = tmp_check;
205210
}
206211

207212
bpf_printk("destination interface index %d", bk->ifindex);

userspace-go/bpf_bpfeb.o

1.8 KB
Binary file not shown.

userspace-go/bpf_bpfel.o

1.88 KB
Binary file not shown.

userspace-go/userspace-go

34.1 KB
Binary file not shown.

userspace-go/xdp_udp.go

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"bytes"
55
"encoding/binary"
66
"encoding/hex"
7+
"errors"
78
"fmt"
89
"log"
910
"net"
@@ -32,10 +33,15 @@ func main() {
3233
if err != nil {
3334
log.Fatalf("lookup network iface %s: %s", ifaceName, err)
3435
}
35-
36+
37+
var ve *ebpf.VerifierError
3638
objs := bpfObjects{}
3739
if err := loadBpfObjects(&objs, nil); err != nil {
38-
log.Fatalf("loading objects: %s", err)
40+
if errors.As(err, &ve) {
41+
// Using %+v will print the whole verifier error, not just the last
42+
// few lines.
43+
fmt.Printf("Verifier error: %+v\n", ve)
44+
}
3945
}
4046
defer objs.Close()
4147

@@ -61,20 +67,25 @@ func main() {
6167
log.Printf("Press Ctrl-C to exit and remove the program")
6268

6369
b := bpfBackend{
70+
// Hardcoded Src IP (main Nic)
6471
Saddr: ip2int("10.8.125.12"),
72+
// Hardcoded Dst IP (container)
6573
Daddr: ip2int("192.168.10.2"),
74+
// Hardcoded Dst Port (UDP echo server)
6675
Dport: 9875,
6776
// Host-Side Veth Mac
6877
Shwaddr: hwaddr2bytes("06:56:87:ec:fd:1f"),
6978
// Container-Side Veth Mac
7079
Dhwaddr: hwaddr2bytes("86:ad:33:29:ff:5e"),
71-
Nocksum: 1,
80+
Nocksum: 0,
81+
// Hardcoded Host side Veth index
7282
Ifindex: 8,
7383
}
7484

7585
key := bpfVipKey{
86+
// Hardcoded main NIC IP
7687
Vip: ip2int("10.8.125.12"),
77-
//Vip: ip2int("192.168.10.1"),
88+
// Hardcoded main NIC port
7889
Port: 8888,
7990
}
8091

0 commit comments

Comments
 (0)