1+ // +build ignore
2+
3+ #include <linux/bpf_common.h>
4+ #include <linux/if_ether.h>
5+ #include <linux/in.h>
6+ #include <linux/ip.h>
7+ #include <linux/udp.h>
8+ #include <linux/bpf.h>
9+ #include <bpf_helpers.h>
10+ #include <bpf_endian.h>
11+
12+ char __license [] SEC ("license" ) = "GPL" ;
13+
14+ #ifndef memcpy
15+ #define memcpy (dest , src , n ) __builtin_memcpy((dest), (src), (n))
16+ #endif
17+
18+ #define MAX_BACKENDS 128
19+ #define MAX_UDP_LENGTH 1480
20+
21+ #define UDP_PAYLOAD_SIZE (x ) (unsigned int)(((bpf_htons(x) - sizeof(struct udphdr)) * 8 ) / 4)
22+
23+ static __always_inline void ip_from_int (__u32 * buf , __be32 ip ) {
24+ buf [0 ] = (ip >> 0 ) & 0xFF ;
25+ buf [1 ] = (ip >> 8 ) & 0xFF ;
26+ buf [2 ] = (ip >> 16 ) & 0xFF ;
27+ buf [3 ] = (ip >> 24 ) & 0xFF ;
28+ }
29+
30+ static __always_inline void bpf_printk_ip (__be32 ip ) {
31+ __u32 ip_parts [4 ];
32+ ip_from_int ((__u32 * )& ip_parts , ip );
33+ bpf_printk ("%d.%d.%d.%d" , ip_parts [0 ], ip_parts [1 ], ip_parts [2 ], ip_parts [3 ]);
34+ }
35+
36+ static __always_inline __u16 csum_fold_helper (__u64 csum ) {
37+ int i ;
38+ #pragma unroll
39+ for (i = 0 ; i < 4 ; i ++ )
40+ {
41+ if (csum >> 16 )
42+ csum = (csum & 0xffff ) + (csum >> 16 );
43+ }
44+ return ~csum ;
45+ }
46+
47+ static __always_inline __u16 iph_csum (struct iphdr * iph ) {
48+ iph -> check = 0 ;
49+ unsigned long long csum = bpf_csum_diff (0 , 0 , (unsigned int * )iph , sizeof (struct iphdr ), 0 );
50+ return csum_fold_helper (csum );
51+ }
52+
53+ // static __always_inline __u16 udp_checksum(struct iphdr *ip, struct udphdr * udp, void * data_end) {
54+ // udp->check = 0;
55+
56+ // // So we can overflow a bit make this __u32
57+ // __u32 csum_total = 0;
58+ // __u16 csum;
59+ // __u16 *buf = (void *)udp;
60+
61+ // csum_total += (__u16)ip->saddr;
62+ // csum_total += (__u16)(ip->saddr >> 16);
63+ // csum_total += (__u16)ip->daddr;
64+ // csum_total += (__u16)(ip->daddr >> 16);
65+ // csum_total += (__u16)(ip->protocol << 8);
66+ // csum_total += udp->len;
67+
68+ // // The number of nibbles in the UDP header + Payload
69+ // unsigned int udp_packet_nibbles = UDP_PAYLOAD_SIZE(udp->len);
70+
71+ // // Here we only want to iterate through payload
72+ // // NOT trailing bits
73+ // for (int i = 0; i <= MAX_UDP_LENGTH; i += 2) {
74+ // if (i > udp_packet_nibbles) {
75+ // break;
76+ // }
77+
78+ // if ((void *)(buf + 1) > data_end) {
79+ // break;
80+ // }
81+ // csum_total += *buf;
82+ // buf++;
83+ // }
84+
85+ // if ((void *)buf + 1 <= data_end) {
86+ // csum_total += (*(__u8 *)buf);
87+ // }
88+
89+ // // Add any cksum overflow back into __u16
90+ // csum = (__u16)csum_total + (__u16)(csum_total >> 16);
91+
92+ // csum = ~csum;
93+ // return csum;
94+ // }
95+
96+ struct backend {
97+ __u32 saddr ;
98+ __u32 daddr ;
99+ __u16 dport ;
100+ __u16 ifindex ;
101+ // Cksum isn't required for UDP see:
102+ // https://en.wikipedia.org/wiki/User_Datagram_Protocol
103+ __u8 nocksum ;
104+ __u8 pad [3 ];
105+ };
106+
107+ struct vip_key {
108+ __u32 vip ;
109+ __u16 port ;
110+ __u8 pad [2 ];
111+ };
112+
113+ struct {
114+ __uint (type , BPF_MAP_TYPE_HASH );
115+ __uint (max_entries , MAX_BACKENDS );
116+ __type (key , struct vip_key );
117+ __type (value , struct backend );
118+ } backends SEC (".maps" );
119+
120+ SEC ("classifier" )
121+ int tc_prog_func (struct xdp_md * ctx ) {
122+ // ---------------------------------------------------------------------------
123+ // Initialize
124+ // ---------------------------------------------------------------------------
125+
126+ void * data = (void * )(long )ctx -> data ;
127+ void * data_end = (void * )(long )ctx -> data_end ;
128+
129+ struct ethhdr * eth = data ;
130+ if (data + sizeof (struct ethhdr ) > data_end ) {
131+ bpf_printk ("ABORTED: bad ethhdr!" );
132+ return XDP_ABORTED ;
133+ }
134+
135+ if (bpf_ntohs (eth -> h_proto ) != ETH_P_IP ) {
136+ bpf_printk ("PASS: not IP protocol!" );
137+ return XDP_PASS ;
138+ }
139+
140+ struct iphdr * ip = data + sizeof (struct ethhdr );
141+ if (data + sizeof (struct ethhdr ) + sizeof (struct iphdr ) > data_end ) {
142+ bpf_printk ("ABORTED: bad iphdr!" );
143+ return XDP_ABORTED ;
144+ }
145+
146+ if (ip -> protocol != IPPROTO_UDP )
147+ return XDP_PASS ;
148+
149+ struct udphdr * udp = data + sizeof (struct ethhdr ) + sizeof (struct iphdr );
150+ if (data + sizeof (struct ethhdr ) + sizeof (struct iphdr ) + sizeof (struct udphdr ) > data_end ) {
151+ bpf_printk ("ABORTED: bad udphdr!" );
152+ return XDP_ABORTED ;
153+ }
154+
155+ bpf_printk ("UDP packet received - daddr:%x, port:%d" , ip -> daddr , bpf_ntohs (udp -> dest ));
156+
157+ // ---------------------------------------------------------------------------
158+ // Routing
159+ // ---------------------------------------------------------------------------
160+
161+ struct vip_key key = {
162+ .vip = ip -> daddr ,
163+ .port = bpf_ntohs (udp -> dest )
164+ };
165+
166+ struct backend * bk ;
167+ bk = bpf_map_lookup_elem (& backends , & key );
168+ if (!bk ) {
169+ bpf_printk ("no backends for ip %x:%x" , key .vip , key .port );
170+ return XDP_PASS ;
171+ }
172+
173+ bpf_printk ("got UDP traffic, source address:" );
174+ bpf_printk_ip (ip -> saddr );
175+ bpf_printk ("destination address:" );
176+ bpf_printk_ip (ip -> daddr );
177+
178+ ip -> saddr = bk -> saddr ;
179+ ip -> daddr = bk -> daddr ;
180+
181+ bpf_printk ("updated saddr to:" );
182+ bpf_printk_ip (ip -> saddr );
183+ bpf_printk ("updated daddr to:" );
184+ bpf_printk_ip (ip -> daddr );
185+
186+ if (udp -> dest != bpf_ntohs (bk -> dport )) {
187+ udp -> dest = bpf_ntohs (bk -> dport );
188+ bpf_printk ("updated dport to: %d" , bk -> dport );
189+ }
190+
191+ // memcpy(eth->h_source, bk->shwaddr, sizeof(eth->h_source));
192+ // bpf_printk("new source hwaddr %x:%x:%x:%x:%x:%x", eth->h_source[0], eth->h_source[1], eth->h_source[2], eth->h_source[3], eth->h_source[4], eth->h_source[5]);
193+
194+ // memcpy(eth->h_dest, bk->dhwaddr, sizeof(eth->h_dest));
195+ // bpf_printk("new dest hwaddr %x:%x:%x:%x:%x:%x", eth->h_dest[0], eth->h_dest[1], eth->h_dest[2], eth->h_dest[3], eth->h_dest[4], eth->h_dest[5]);
196+
197+ ip -> check = iph_csum (ip );
198+ udp -> check = 0 ;
199+
200+ if (!bk -> nocksum ){
201+ udp -> check = udp_checksum (ip , udp , data_end );
202+ }
203+
204+ bpf_printk ("destination interface index %d" , bk -> ifindex );
205+
206+ int action = bpf_redirect (bk -> ifindex , 0 );
207+
208+ bpf_printk ("redirect action: %d" , action );
209+
210+ return action ;
211+ }
212+
213+ // SEC("xdp")
214+ // int bpf_redirect_placeholder(struct xdp_md *ctx) {
215+ // bpf_printk("received a packet on dest interface");
216+ // return XDP_PASS;
217+ // }
0 commit comments