Search code examples
ebpf

Network packet counting: Failure to read packet data from a BPF socket filter


I'd like to count incoming network packets and their total length in bytes for each TOS value. I created two maps with 256 entries each: the first holds the packet count for each TOS value and the second holds the byte count. So I've written the following eBPF socket filter:

// eBPF socket filter: for IPv4 packets, looks up the packet's TOS byte as a
// key in two maps and atomically adds 1 to the first map's value (packet
// count) and a length value to the second map's value (byte count).
// The program always returns 0.
struct bpf_insn prog[]{
  BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),            // r6 = r1 (save the context passed in r1 at entry)

  // We use a dgram socket, so the packet starts directly at the IP header;
  // the Ethernet protocol check below is therefore disabled.
  // BPF_LD_ABS(BPF_H, offsetof(struct ethhdr, h_proto)), // r0 = header type
  // BPF_JMP_IMM(BPF_JEQ, BPF_REG_0, ETH_P_IP, 2),        // if (r0 == IPv4) skip 2
  // BPF_MOV64_IMM(BPF_REG_0, 0),                         // r0 = 0
  // BPF_EXIT_INSN(),                                     // return

  // Check the IP version; we are only interested in v4.
  BPF_LD_ABS(BPF_B, 0),                           // r0 = first byte of the packet (holds the IP version in its high nibble)
  BPF_ALU64_IMM(BPF_AND, BPF_REG_0, 0xF0),        // r0 = r0 & 0xF0 (keep the high nibble only)
  BPF_JMP_IMM(BPF_JEQ, BPF_REG_0, 0x40, 2),       // if (r0 == 0x40) goto pc+2 (skip the early return)
  BPF_MOV64_IMM(BPF_REG_0, 0),                    // r0 = 0
  BPF_EXIT_INSN(),                                // return 0 for non-IPv4 packets

  // Load the packet TOS value and store it on the stack as the map key.
  BPF_LD_ABS(BPF_B, offsetof(struct iphdr, tos)), // r0 = ip->tos
  BPF_STX_MEM(BPF_W, BPF_REG_10, BPF_REG_0, -4),  // *(u32 *)(fp - 4) = r0
  BPF_MOV64_REG(BPF_REG_2, BPF_REG_10),           // r2 = fp
  BPF_ALU64_IMM(BPF_ADD, BPF_REG_2, -4),          // r2 = fp - 4 (pointer to the key)
  // First map: packet counters.
  BPF_LD_MAP_FD(BPF_REG_1, map_cnt_fd),           // r1 = map_fd
  BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0,
               BPF_FUNC_map_lookup_elem),         // r0 = map_lookup(r1, r2)
  BPF_JMP_IMM(BPF_JEQ, BPF_REG_0, 0, 2),          // if (r0 == 0) goto pc+2 (skip the two-instruction increment)
  BPF_MOV64_IMM(BPF_REG_1, 1),                    // r1 = 1
  BPF_RAW_INSN(BPF_STX | BPF_XADD | BPF_DW,
               BPF_REG_0, BPF_REG_1, 0, 0),       // xadd: *(u64 *)(r0 + 0) += r1

  // Reload the TOS value and rebuild the key for the second lookup.
  BPF_LD_ABS(BPF_B, offsetof(struct iphdr, tos)), // r0 = ip->tos
  BPF_STX_MEM(BPF_W, BPF_REG_10, BPF_REG_0, -4),  // *(u32 *)(fp - 4) = r0
  BPF_MOV64_REG(BPF_REG_2, BPF_REG_10),           // r2 = fp
  BPF_ALU64_IMM(BPF_ADD, BPF_REG_2, -4),          // r2 = fp - 4 (pointer to the key)
  // Second map: packet bytes.
  BPF_LD_MAP_FD(BPF_REG_1, map_bytes_fd),         // r1 = map_fd
  BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0,
               BPF_FUNC_map_lookup_elem),         // r0 = map_lookup(r1, r2)
  BPF_JMP_IMM(BPF_JEQ, BPF_REG_0, 0, 2),          // if (r0 == 0) goto pc+2 (skip the load and the add)
  // FIXME: big endian
  // Intended as r1 = tot_len, but the base register r6 is the context saved
  // from r1 at entry, so this is a 16-bit load at that offset from the context.
  BPF_LDX_MEM(BPF_H, BPF_REG_1, BPF_REG_6,
              offsetof(struct iphdr, tot_len)),   // r1 = *(u16 *)(r6 + offsetof(struct iphdr, tot_len))
  BPF_RAW_INSN(BPF_STX | BPF_XADD | BPF_DW,
               BPF_REG_0, BPF_REG_1, 0, 0),       // xadd: *(u64 *)(r0 + 0) += r1

  BPF_MOV64_IMM(BPF_REG_0, 0),                    // r0 = 0
  BPF_EXIT_INSN(),                                // return r0
};

The maps are created without errors, the socket filter program loads fine too, and the packet counter part works as it should. But the bytes counter is always 0. What is the problem with that code?

I tried to write a simple example. To compile you just need to include bpf_insn.h.


Solution

  • Problem: Reading from the Socket Buffer

    The context placed in BPF_REG_1 before the program starts is not a pointer to the beginning of the data. Instead, it is a pointer to a struct __sk_buff defined in the UAPI headers as follows:

    /* Excerpt: only the first member is shown; the remaining fields are elided. */
    struct __sk_buff {
        __u32 len;   /* first member, i.e. at offset 0 of the context */
        ...
    }
    

    So when you attempt to read data from your IP header:

      BPF_LDX_MEM(BPF_H, BPF_REG_1, BPF_REG_6, offsetof(struct iphdr, tot_len)),
    

    You are in fact reading two bytes at offset 2 from the struct __sk_buff (let's call its pointer skb). Because your system is little-endian, these two bytes are the upper 16 bits of skb->len, which are 0 unless you have packets of 2^16 bytes or more (unlikely).

    We have two possible solutions here.

    Solution 1: Use Absolute Load

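    We can update your program to read the IP length at the correct location. I believe this is not possible with a BPF_LDX_MEM(), because socket filters do not permit direct packet access. The workaround would be to use an absolute load instead. Note that BPF_LD_ABS() returns half-words and words already converted to host byte order, so no extra byte swap is needed for tot_len. Your program would become: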

    // Sketch of the fixed program: the IP total length is read with an
    // absolute load (packet data) instead of a load through the context
    // pointer, kept in r7 across the map lookup, and then added to the
    // bytes counter. Parts unchanged from the original are elided.
    struct bpf_insn prog[]{
      BPF_MOV64_REG(BPF_REG_6, BPF_REG_1),            // r6 = r1 (context passed at entry)
    
      // ... packet number counter, skipped for brevity
    
      // Read IP length and store to r7 (preserved during helper calls)
      BPF_LD_ABS(BPF_H,
                 offsetof(struct iphdr, tot_len)),    // r0 = tot_len
      BPF_MOV64_REG(BPF_REG_7, BPF_REG_0),            // r7 = r0
    
      // No need to parse ToS a second time here, skipped.
      // NOTE(review): assumes the skipped counter part left the ToS key
      // at fp - 4, as the original program does.
      BPF_MOV64_REG(BPF_REG_2, BPF_REG_10),           // r2 = fp
      BPF_ALU64_IMM(BPF_ADD, BPF_REG_2, -4),          // r2 = fp - 4
      //second map with packet bytes
      BPF_LD_MAP_FD(BPF_REG_1, map_bytes_fd),         // r1 = map_fd
      BPF_RAW_INSN(BPF_JMP | BPF_CALL, 0, 0, 0,
                   BPF_FUNC_map_lookup_elem),         // r0 = map_lookup(r1, r2)
      // Only one instruction (the xadd) follows the null check now, so the
      // jump skips exactly one instruction and lands on "r0 = 0".
      BPF_JMP_IMM(BPF_JEQ, BPF_REG_0, 0, 1),          // if (r0 == 0) goto pc+1
    
      // Now add length to the counter
      BPF_STX_XADD(BPF_DW, BPF_REG_0, BPF_REG_7, 0),  // xadd: *(u64 *)(r0 + 0) += r7
    
      BPF_MOV64_IMM(BPF_REG_0, 0),                    // r0 = 0
      BPF_EXIT_INSN(),                                // return r0
    };
    

    Solution 2: Just use skb->len

    The other solution is to get the length from skb, since the kernel has already computed it for us. This is just a matter of fixing the offset and size of the load you had, so your BPF_LDX_MEM() and XADD instructions would become:

      // Load the 32-bit len field through the context pointer held in r6,
      // then add it to the 64-bit counter that r0 points to.
      BPF_LDX_MEM(BPF_W, BPF_REG_1, BPF_REG_6,
                  offsetof(struct __sk_buff, len)),         // r1 = skb->len
      BPF_STX_XADD(BPF_DW, BPF_REG_0, BPF_REG_1, 0),        // xadd: *(u64 *)(r0 + 0) += r1