Search code examples
networkingudpchecksumnatip-fragmentation

Update UDP checksum in fragmented packets


I'm building a network appliance. I need to support NAT and IP packet fragmentation. When I change the source or destination address of a UDP packet, I have to correct the UDP checksum (and the IP checksum too, but that's trivial). When the packet is fragmented, I'd have to collect all the fragments to recalculate the checksum. I know the old address and the new address. I'd like to:

  1. Un-negate the checksum
  2. Subtract the old address
  3. Add the new address
  4. Re-reduce the sum and negate

This process doesn't always work. Is there any way to update the checksum versus having to recalculate it from scratch?

I've tried:

long CalcCheckSumAdd(unsigned char *pbHeader, int iSize, long lInitial){

    long lSum = lInitial;

    while (iSize > 1){

        lSum += *((unsigned short*)pbHeader);

        pbHeader += 2;

        iSize -= 2;

    }

    if (iSize > 0) lSum += *pbHeader;

    return lSum;

}

long CalcCheckSumSubract(unsigned char *pbHeader, int iSize, long lInitial){

    long lSum = lInitial;

    while (iSize > 1){

        lSum -= *((unsigned short*)pbHeader);

        pbHeader += 2;

        iSize -= 2;

    }

    if (iSize > 0) lSum -= *pbHeader;

    return lSum;

}

unsigned short CalcCheckSumFinish(long lSum){

    while (lSum >> 16){

        lSum = (lSum & 0xFFFF) + (lSum >> 16);

    }

    return (unsigned short)(~lSum);

}

long CalcCheckSumUnfinish(unsigned short usSum){

    // Can't totally undo lossy finish logic

    return ~usSum;

}

unsigned short CalcCheckSumUpdateAddress(unsigned short usOldSum, unsigned long ulOldAddress, unsigned long ulNewAddress){

    long lSumFixed = CalcCheckSumUnfinish(usOldSum);

    lSumFixed = CalcCheckSumSubract((unsigned char*)&ulOldAddress,sizeof(ulOldAddress),lSumFixed);

    lSumFixed = CalcCheckSumAdd((unsigned char*)&ulNewAddress,sizeof(ulNewAddress),lSumFixed);

    return CalcCheckSumFinish(lSumFixed);

}

Thanks!

EDIT: Added unit test code below

#include <time.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>

long CalcCheckSumAdd(unsigned char *pbHeader, int iSize, long lInitial){

    long lSum = lInitial;

    while (iSize > 1){

        lSum += *((unsigned short*)pbHeader);

        pbHeader += 2;

        iSize -= 2;

    }

    if (iSize > 0) lSum += *pbHeader;

    return lSum;

}

unsigned short CalcCheckSumFinish(long lSum){

    while (lSum >> 16){

        lSum = (lSum & 0xFFFF) + (lSum >> 16);

    }

    return (unsigned short)(~lSum);

}

void Randomize(unsigned char *pucPacket, unsigned long ulSize){

    for (unsigned long ulByte = 0; ulByte < ulSize; ulByte++){

        pucPacket[ulByte] = (unsigned char)(255 * rand() / RAND_MAX);

    }

}

unsigned short Calc(unsigned char *pucPacket, unsigned long ulSize){

    long lSum = CalcCheckSumAdd(pucPacket,ulSize,0);

    return CalcCheckSumFinish(lSum);

}

unsigned short Fix(unsigned short usOrig, unsigned int uiOld, unsigned int uiNew){

    // TODO: Replace this with something that makes main never fail
    usOrig -= uiOld & 0xffff;
    usOrig -= uiOld >> 16 & 0xffff;
    usOrig += uiNew & 0xffff;
    usOrig += uiNew >>16 & 0xffff;

    return usOrig;

}

void Break(unsigned char *pucPacket, unsigned int *puiOld, unsigned int *puiNew){

    unsigned int *puiChange = (unsigned int*)pucPacket;

    *puiOld = *puiChange;

    Randomize((unsigned char*)puiNew,sizeof(unsigned int));

    *puiChange = *puiNew;

}

void PrintBuffer(const char *szName, unsigned char *pucBuff, unsigned int uiSize){

    printf("%s: ",szName);

    for (unsigned int uiByte = 0; uiByte < uiSize; uiByte++){

        printf("%02X",(unsigned int)pucBuff[uiByte]);

    }

    printf("\n");

}

void PrintTestCase(unsigned char *pucOrig, unsigned char *pucChanged, unsigned int uiSize, unsigned short usOrig, unsigned short usChanged, unsigned short usFixed){

    PrintBuffer("Original Buffer",pucOrig,uiSize);
    PrintBuffer("Changed Buffer ",pucChanged,uiSize);

    printf("Orig    checksum: %04X\n",(unsigned int)usOrig);
    printf("Changed checksum: %04X\n",(unsigned int)usChanged);
    printf("Fixed   checksum: %04X\n",(unsigned int)usFixed);

}

int main(){

    srand((unsigned int)time(nullptr));

    unsigned char pucDataOrig[100];
    unsigned char pucDataChanged[100];

    bool bTestFailed = false;

    while (!bTestFailed){

        Randomize(pucDataOrig,sizeof(pucDataOrig));

        memcpy(pucDataChanged,pucDataOrig,sizeof(pucDataOrig));

        unsigned short usOrig = Calc(pucDataOrig,sizeof(pucDataOrig));

        unsigned int uiOld = 0,
                     uiNew = 0;

        Break(pucDataChanged,&uiOld,&uiNew);

        unsigned short usFixed = Fix(usOrig,uiOld,uiNew);

        unsigned short usChanged = Calc(pucDataChanged,sizeof(pucDataChanged));

        if (usChanged == usFixed){

            printf(".");

        }else{

            printf("\nTest case failed\n");
            PrintTestCase(pucDataOrig,pucDataChanged,sizeof(pucDataOrig),usOrig,usChanged,usFixed);

            bTestFailed = true;

        }

    }

    return 0;

}

Solution

  • You are right, the solution above works only on some cases, but I have a new implem that works for all kind of packet (fragmented or not, UDP, TCP, IP). Here is the implem:

    /* incremental checksum update */
    static inline void
    cksum_update(uint16_t *csum, uint32_t from, uint32_t to)
    {
        uint32_t sum, csum_c, from_c, res, res2, ret, ret2;
    
        csum_c = ~((uint32_t)*csum);
        from_c = ~from;
        res = csum_c + from_c;
        ret = res + (res < from_c);
    
       res2 = ret + to;
       ret2 = res2 + (res2 < to);
    
       sum = ret2;
       sum = (sum & 0xffff) + (sum >> 16);
       sum = (sum & 0xffff) + (sum >> 16);
       *csum = (uint16_t)~sum;
    
    }
    

    You can now use this function when you translated you packet address and before sending:

    /* Update L4 checksums on all packet a part from [2nd, n] fragment */
    switch (IS_FRAG(ipv4_hdr) ? 0 : ipv4_hdr->next_proto_id) {
    case IPPROTO_TCP:
    {
        struct tcp_hdr *tcp_hdr = tcp_header(pkt);
    
        /* Compute TCP checksum using incremental update */
        cksum_update(&tcp_hdr->cksum, old_ip_addr, *address);
        break;
    }
    case IPPROTO_UDPLITE:
    case IPPROTO_UDP:
    {
        struct udp_hdr *udp_hdr = udp_header(pkt);
    
        /* Compute UDP checksum using incremental update */
        cksum_update(&udp_hdr->dgram_cksum, old_ip_addr, *address);
        break;
    }
    default:
        break;
    }