Tags: c, linux-kernel, ebpf, xdp-bpf

How to write an eBPF/XDP program to drop TCP packets containing a certain byte pattern?


I'm trying to force the retransmission of certain TCP packets to test a feature.

I would like to drop RX TCP packets (dst port 4420) whose payload is all zeros except for the last 4 bytes, which should contain at least one non-zero byte.

Unfortunately I cannot get past the eBPF verifier :( As a compromise I'm trying "starts with at least 50 zero bytes followed by a non-zero byte", but still no luck.
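
In plain C (ignoring the verifier for a moment), the check I'm after would look roughly like the helper below; payload_matches is only illustrative and not part of the program:

#include <stdbool.h>
#include <stddef.h>
#include <stdint.h>

/* true when every byte except the last 4 is zero and at least one of
 * the last 4 bytes is non-zero */
static bool payload_matches(const uint8_t *payload, size_t len)
{
        size_t i;

        if (len < 4)
                return false;
        for (i = 0; i + 4 < len; i++)
                if (payload[i] != 0)
                        return false;
        for (i = len - 4; i < len; i++)
                if (payload[i] != 0)
                        return true;
        return false;
}

Here is my current attempt: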

#define KBUILD_MODNAME "xdp_nvme_drop"
#include <linux/bpf.h>
#include <linux/in.h>
#include <linux/if_ether.h>
#include <linux/if_packet.h>
#include <linux/if_vlan.h>
#include <linux/ip.h>
#include <linux/ipv6.h>
#include <stdint.h>
#include <stdbool.h>
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_endian.h>
#include "../common/parsing_helpers.h"

static inline int parse_ipv4(void *data, uint64_t nh_off, void *data_end) {
        struct iphdr *iph = data + nh_off;

        if (data + nh_off + sizeof(struct iphdr) > data_end)
                return 0;
        return iph->protocol;
}

static inline int parse_ipv6(void *data, uint64_t nh_off, void *data_end) {
        struct ipv6hdr *ip6h = data + nh_off;

        if (data + nh_off + sizeof(struct ipv6hdr) > data_end)
                return 0;
        return ip6h->nexthdr;
}

//static uint64_t zeroes[1024];

SEC("xdp")
int nvme_drop(struct xdp_md *ctx) {
        void* data_end = (void*)(long)ctx->data_end;
        void* data = (void*)(long)ctx->data;
        uint64_t total = data_end-data;
        struct ethhdr *eth = data;
        uint16_t h_proto;
        uint32_t cur = 0;
        struct tcphdr *tcph;
        uint32_t i;
        int nbzeros = 0;
        bool found = 0;
        cur = sizeof(*eth);

        if (data + cur  > data_end)
                return XDP_PASS;

        h_proto = eth->h_proto;
        if (h_proto == bpf_htons(ETH_P_IP)) {
                h_proto = parse_ipv4(data, cur, data_end);
                cur += sizeof(struct iphdr);
        } else if (h_proto == bpf_htons(ETH_P_IPV6)) {
                h_proto = parse_ipv6(data, cur, data_end);
                cur += sizeof(struct ipv6hdr);
        } else {
                return XDP_PASS;
        }

        if (cur > 100)
                return XDP_PASS;

        if (h_proto != IPPROTO_TCP)
                return XDP_PASS;

        if (data + cur + sizeof(*tcph) > data_end)
                return XDP_PASS;

        tcph = data + cur;
        if (tcph->doff > 10)
                return XDP_PASS;

        if (data + cur + tcph->doff * 4 > data_end)
                return XDP_PASS;

        cur += tcph->doff * 4;
        
        if (tcph->dest != 4420)
                return XDP_PASS;

        if (cur > total || cur > 100)
                return XDP_PASS;

        nbzeros = 0;
        for (i = cur; data+i < data_end; i++) {
                if (*((uint8_t*)(data+i)) == 0 && !found) {
                        nbzeros++;
                } else {
                        found = true;
                        break;
                }

        }

        if (found && nbzeros > 50) {
                bpf_printk("found nvme pdu tail seq=%u\n", bpf_ntohs(tcph->seq));
        }
        return XDP_PASS;
}

char _license[] SEC("license") = "GPL";

Verifier output:

Validating nvme_drop() func#0...
0: R1=ctx(off=0,imm=0) R10=fp0
; void* data = (void*)(long)ctx->data;
0: (61) r2 = *(u32 *)(r1 +0)          ; R1=ctx(off=0,imm=0) R2_w=pkt(off=0,r=0,imm=0)
; void* data_end = (void*)(long)ctx->data_end;
1: (61) r1 = *(u32 *)(r1 +4)          ; R1_w=pkt_end(off=0,imm=0)
; if (data + cur  > data_end)
2: (bf) r3 = r2                       ; R2_w=pkt(off=0,r=0,imm=0) R3_w=pkt(off=0,r=0,imm=0)
3: (07) r3 += 14                      ; R3_w=pkt(off=14,r=0,imm=0)
; if (data + cur  > data_end)
4: (2d) if r3 > r1 goto pc+12         ; R1_w=pkt_end(off=0,imm=0) R3_w=pkt(off=14,r=14,imm=0)
; h_proto = eth->h_proto;
5: (71) r4 = *(u8 *)(r2 +12)          ; R2_w=pkt(off=0,r=14,imm=0) R4_w=scalar(umax=255,var_off=(0x0; 0xff))
6: (71) r3 = *(u8 *)(r2 +13)          ; R2_w=pkt(off=0,r=14,imm=0) R3_w=scalar(umax=255,var_off=(0x0; 0xff))
7: (67) r3 <<= 8                      ; R3_w=scalar(umax=65280,var_off=(0x0; 0xff00))
8: (4f) r3 |= r4                      ; R3_w=scalar() R4_w=scalar(umax=255,var_off=(0x0; 0xff))
; if (h_proto == bpf_htons(ETH_P_IP)) {
9: (15) if r3 == 0xdd86 goto pc+9     ; R3_w=scalar()
10: (55) if r3 != 0x8 goto pc+6       ; R3=8
11: (b7) r3 = 34                      ; R3_w=34
12: (b7) r4 = 23                      ; R4_w=23
; if (data + nh_off + sizeof(struct iphdr) > data_end)
13: (bf) r5 = r2                      ; R2=pkt(off=0,r=14,imm=0) R5_w=pkt(off=0,r=14,imm=0)
14: (07) r5 += 34                     ; R5_w=pkt(off=34,r=14,imm=0)
; if (data + nh_off + sizeof(struct iphdr) > data_end)
15: (2d) if r5 > r1 goto pc+1         ; R1=pkt_end(off=0,imm=0) R5_w=pkt(off=34,r=34,imm=0)
16: (05) goto pc+7
;
24: (bf) r5 = r2                      ; R2=pkt(off=0,r=34,imm=0) R5_w=pkt(off=0,r=34,imm=0)
25: (0f) r5 += r4                     ; R4_w=P23 R5_w=pkt(off=23,r=34,imm=0)
26: (71) r4 = *(u8 *)(r5 +0)          ; R4=scalar(umax=255,var_off=(0x0; 0xff)) R5=pkt(off=23,r=34,imm=0)
; if (h_proto != IPPROTO_TCP)
27: (55) if r4 != 0x6 goto pc-11      ; R4=6
; if (data + cur + sizeof(*tcph) > data_end)
28: (bf) r4 = r2                      ; R2=pkt(off=0,r=34,imm=0) R4_w=pkt(off=0,r=34,imm=0)
29: (0f) r4 += r3                     ; R3=P34 R4_w=pkt(off=34,r=34,imm=0)
; if (data + cur + sizeof(*tcph) > data_end)
30: (bf) r5 = r4                      ; R4_w=pkt(off=34,r=34,imm=0) R5_w=pkt(off=34,r=34,imm=0)
31: (07) r5 += 20                     ; R5_w=pkt(off=54,r=34,imm=0)
; if (data + cur + sizeof(*tcph) > data_end)
32: (2d) if r5 > r1 goto pc-16        ; R1=pkt_end(off=0,imm=0) R5_w=pkt(off=54,r=54,imm=0)
; if (tcph->doff > 10)
33: (69) r5 = *(u16 *)(r4 +12)        ; R4_w=pkt(off=34,r=54,imm=0) R5_w=scalar(umax=65535,var_off=(0x0; 0xffff))
34: (77) r5 >>= 4                     ; R5_w=scalar(umax=4095,var_off=(0x0; 0xfff))
35: (57) r5 &= 15                     ; R5=scalar(umax=15,var_off=(0x0; 0xf))
; if (tcph->doff > 10)
36: (25) if r5 > 0xa goto pc-20       ; R5=scalar(umax=10,var_off=(0x0; 0xf))
; if (data + cur + tcph->doff * 4 > data_end)
37: (67) r5 <<= 2                     ; R5_w=scalar(umax=40,var_off=(0x0; 0x3c))
; if (data + cur + tcph->doff * 4 > data_end)
38: (bf) r0 = r4                      ; R0_w=pkt(off=34,r=54,imm=0) R4=pkt(off=34,r=54,imm=0)
39: (0f) r0 += r5                     ; R0_w=pkt(id=1,off=34,r=0,umax=40,var_off=(0x0; 0x3c),s32_max=60,u32_max=60) R5_w=Pscalar(umax=40,var_off=(0x0; 0x3c))
; if (data + cur + tcph->doff * 4 > data_end)
40: (2d) if r0 > r1 goto pc-24        ; R0_w=pkt(id=1,off=34,r=34,umax=40,var_off=(0x0; 0x3c),s32_max=60,u32_max=60) R1=pkt_end(off=0,imm=0)
; if (tcph->dest != 4420)
41: (69) r0 = *(u16 *)(r4 +2)         ; R0_w=scalar(umax=65535,var_off=(0x0; 0xffff)) R4=pkt(off=34,r=54,imm=0)
; if (tcph->dest != 4420)
42: (55) if r0 != 0x1144 goto pc-26   ; R0_w=4420
43: (bf) r0 = r1                      ; R0_w=pkt_end(off=0,imm=0) R1=pkt_end(off=0,imm=0)
44: (1f) r0 -= r2                     ; R0_w=scalar() R2=pkt(off=0,r=54,imm=0)
45: (bf) r6 = r3                      ; R3=P34 R6_w=P34
46: (0f) r6 += r5                     ; R5=Pscalar(umax=40,var_off=(0x0; 0x3c)) R6=Pscalar(umin=34,umax=74,var_off=(0x2; 0x7c))
47: (2d) if r6 > r0 goto pc-31        ; R0=scalar() R6=Pscalar(umin=34,umax=74,var_off=(0x2; 0x7c))
48: (bf) r0 = r2                      ; R0_w=pkt(off=0,r=54,imm=0) R2=pkt(off=0,r=54,imm=0)
49: (0f) r0 += r6                     ; R0_w=pkt(id=3,off=0,r=0,umin=34,umax=74,var_off=(0x2; 0x7c),s32_min=2,s32_max=126,u32_min=2,u32_max=126) R6=Pscalar(umin=34,umax=74,var_off=(0x2; 0x7c))
50: (3d) if r0 >= r1 goto pc-34       ; R0_w=pkt(id=3,off=0,r=0,umin=34,umax=74,var_off=(0x2; 0x7c),s32_min=2,s32_max=126,u32_min=2,u32_max=126) R1=pkt_end(off=0,imm=0)
; for (i = cur; data+i < data_end; i++) {
51: (0f) r3 += r5                     ; R3_w=Pscalar(umin=34,umax=74,var_off=(0x2; 0x7c)) R5=Pscalar(umax=40,var_off=(0x0; 0x3c))
52: (b7) r5 = 0                       ; R5_w=0
53: (07) r3 += 1                      ; R3_w=Pscalar(umin=35,umax=75,var_off=(0x3; 0x7c))
; if (*((uint8_t*)(data+i)) == 0 && !found) {
54: (71) r0 = *(u8 *)(r0 +0)
invalid access to packet, off=0 size=1, R0(id=3,off=0,r=0)
R0 offset is outside of the packet
processed 48 insns (limit 1000000) max_states_per_insn 0 total_states 4 peak_states 4 mark_read 1

Solution

  • There are two issues here. The first is an off-by-one error in the loop: you need to account for the width of the read in the for-loop condition:

    #define KBUILD_MODNAME "xdp_nvme_drop"
    #include <linux/bpf.h>
    #include <linux/in.h>
    #include <linux/if_ether.h>
    #include <linux/if_packet.h>
    #include <linux/if_vlan.h>
    #include <linux/ip.h>
    #include <linux/ipv6.h>
    #include <stdint.h>
    #include <stdbool.h>
    #include <bpf/bpf_helpers.h>
    #include <bpf/bpf_endian.h>
    #include "../common/parsing_helpers.h"
    
    static inline int parse_ipv4(void *data, uint64_t nh_off, void *data_end) {
            struct iphdr *iph = data + nh_off;
    
            if (data + nh_off + sizeof(struct iphdr) > data_end)
                    return 0;
            return iph->protocol;
    }
    
    static inline int parse_ipv6(void *data, uint64_t nh_off, void *data_end) {
            struct ipv6hdr *ip6h = data + nh_off;
    
            if (data + nh_off + sizeof(struct ipv6hdr) > data_end)
                    return 0;
            return ip6h->nexthdr;
    }
    
    //static uint64_t zeroes[1024];
    
    SEC("xdp")
    int nvme_drop(struct xdp_md *ctx) {
            void* data_end = (void*)(long)ctx->data_end;
            void* data = (void*)(long)ctx->data;
            uint64_t total = data_end-data;
            struct ethhdr *eth = data;
            uint16_t h_proto;
            uint32_t cur = 0;
            struct tcphdr *tcph;
            uint32_t i;
            int nbzeros = 0;
            bool found = 0;
            cur = sizeof(*eth);
    
            if (data + cur  > data_end)
                    return XDP_PASS;
    
            h_proto = eth->h_proto;
            if (h_proto == bpf_htons(ETH_P_IP)) {
                    h_proto = parse_ipv4(data, cur, data_end);
                    cur += sizeof(struct iphdr);
            } else if (h_proto == bpf_htons(ETH_P_IPV6)) {
                    h_proto = parse_ipv6(data, cur, data_end);
                    cur += sizeof(struct ipv6hdr);
            } else {
                    return XDP_PASS;
            }
    
            if (cur > 100)
                    return XDP_PASS;
    
            if (h_proto != IPPROTO_TCP)
                    return XDP_PASS;
    
            if (data + cur + sizeof(*tcph) > data_end)
                    return XDP_PASS;
    
            tcph = data + cur;
            if (tcph->doff > 10)
                    return XDP_PASS;
    
            if (data + cur + tcph->doff * 4 > data_end)
                    return XDP_PASS;
    
            cur += tcph->doff * 4;
            
            if (tcph->dest != 4420)
                    return XDP_PASS;
    
            if (cur > total || cur > 100)
                    return XDP_PASS;
    
            nbzeros = 0;
            for (i = cur; data+i+sizeof(uint8_t) < data_end; i++) {
                    if (*((uint8_t*)(data+i)) == 0 && !found) {
                            nbzeros++;
                    } else {
                            found = true;
                            break;
                    }
    
            }
    
            if (found && nbzeros > 50) {
                    bpf_printk("found nvme pdu tail seq=%u\n", bpf_ntohs(tcph->seq));
            }
            return XDP_PASS;
    }
    
    char _license[] SEC("license") = "GPL";
    

    Once you compile and try the above, you will run into the second issue, which is the complexity of the code:

    The sequence of 8193 jumps is too complex

    That is because the verifier has to check, for every possible iteration, that the body of the loop is valid, and since there is no bound on how far data_end can be from data, that becomes too much work. What we can do to fix that is to set an upper limit on the number of iterations, like:

    for (int i = 0; i < 100; i++) {
        if (data + i + sizeof(uint8_t) >= data_end)
            break;
    
        if (*((uint8_t*)(data+i)) == 0 && !found) {
            nbzeros++;
        } else {
            found = true;
            break;
        }
    }
    

    However, due to the way the current program is written, I had a hard time modifying it. The verifier tracks which pointers have been bounds-checked against the end of the packet and which haven't, and the current data + cur way of tracking the offset generates code that confuses the verifier. So I took the liberty of rewriting it in such a way that everything passes the verifier:

    SEC("xdp")
    int nvme_drop(struct xdp_md *ctx)
    {
        void *data_end = (void *)(long)ctx->data_end;
        void *data = (void *)(long)ctx->data;
        void *head = data;
        struct ethhdr *eth;
        struct iphdr *iph;
        struct ipv6hdr *ip6h;
        struct tcphdr *tcph;
        uint16_t h_proto;
        uint8_t *tcp_data;
        int nbzeros = 0;
        int i = 0;
        bool found = false;
    
        eth = head;
        if ((void *)eth + sizeof(struct ethhdr) >= data_end)
            return XDP_PASS;
        head += sizeof(struct ethhdr);
    
        h_proto = eth->h_proto;
        switch (h_proto)
        {
        case bpf_htons(ETH_P_IP):
            iph = head;
            if ((void *)iph + sizeof(struct iphdr) >= data_end)
                return XDP_PASS;
    
            h_proto = iph->protocol;
    
            head += iph->ihl * 4;
    
            break;
    
        case bpf_htons(ETH_P_IPV6):
            ip6h = head;
            if ((void *)ip6h + sizeof(struct ipv6hdr) >= data_end)
                return XDP_PASS;
    
            h_proto = ip6h->nexthdr;
    
            head += sizeof(struct ipv6hdr);
    
            break;
    
        default:
            return XDP_PASS;
        }
    
        if (h_proto != IPPROTO_TCP)
            return XDP_PASS;
    
        tcph = head;
        if ((void *)tcph + sizeof(*tcph) > data_end)
            return XDP_PASS;

        if (head + tcph->doff * 4 > data_end)
            return XDP_PASS;
        head += tcph->doff * 4;
    
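        // note: tcph->dest is in network byte order, so matching dst port
        // 4420 probably needs bpf_htons(4420); kept as-is to match the question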
        if (tcph->dest != 4420)
            return XDP_PASS;
    
        tcp_data = head;
    
        // 1500 is the typical MTU size
        #define MAX_ITER 1500
    
        for (i = 0; i < MAX_ITER; i++)
        {
            if ((void *)tcp_data + i + 1 >= data_end)
                return XDP_PASS;
    
            if (tcp_data[i] == 0)
            {
                nbzeros++;
                continue;
            }
    
            found = true;
            break;
        }
    
        if (found && nbzeros > 50)
        {
            bpf_printk("found nvme pdu tail seq=%u\n", bpf_ntohs(tcph->seq));
        }
    
        return XDP_PASS;
    }
    

    This should be functionally the same. I chose 1500 as the maximum iteration count since most packets will never exceed that size, though you might need to tune it.
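
    Note that both versions only log the match and still return XDP_PASS. Since the goal in the question is to drop the matching packets, the final branch would return XDP_DROP instead, roughly like this (assuming the same found/nbzeros variables as in the rewrite above):

        if (found && nbzeros > 50)
        {
            bpf_printk("found nvme pdu tail seq=%u\n", bpf_ntohs(tcph->seq));
            return XDP_DROP;
        }

        return XDP_PASS;

    While the program is attached, the bpf_printk() output can be read from /sys/kernel/debug/tracing/trace_pipe.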