Search code examples
java checksum ip-protocol

How to calculate the internet checksum from a byte[] in Java


I'm trying to figure out how to calculate the Internet Checksum in Java and it's causing me no end of pain. (I'm horrible at bit manipulation.) I found a version in C# ("Calculate an Internet (aka IP, aka RFC791) checksum in C#"). However, my attempt at converting it to Java does not seem to produce the correct results. Can anyone see what I'm doing wrong? I suspect a data type issue.

/**
 * Attempts to compute the 16-bit Internet checksum of a hard-coded six-byte buffer:
 * sum the bytes as big-endian 16-bit words, add any carry out of bit 15 back in,
 * then return the low 16 bits of the bitwise complement.
 *
 * NOTE(review): this is the version being asked about. Java's {@code byte} is signed,
 * so the unmasked {@code buf[i] << 8} below sign-extends bytes whose high bit is set
 * (such as the first byte, 0xed). The word then becomes negative and the carry test,
 * which only checks for {@code > 0}, no longer triggers.
 *
 * @return the low 16 bits of the complemented running sum
 */
public long getValue() {
    byte[] buf = { (byte) 0xed, 0x2A, 0x44, 0x10, 0x03, 0x30};
    int length = buf.length; // number of bytes still to be consumed
    int i = 0;               // index of the next byte to read

    long sum = 0;
    long data = 0;
    // Consume the buffer two bytes at a time.
    while (length > 1) {
        data = 0;
        // High byte shifted into bits 8-15, low byte masked into bits 0-7.
        // NOTE(review): the high byte is NOT masked; buf[i] is promoted to a signed int
        // before the shift, so a byte >= 0x80 leaves bits 16-31 set and data negative.
        data = (((buf[i]) << 8) | ((buf[i + 1]) & 0xFF));

        sum += data;
        // Carry fold: if anything is set above bit 15, keep the low 16 bits and add 1.
        // NOTE(review): 0xFFFF0000 is an int literal, so it is sign-extended when widened
        // to long; for a negative sum the masked value is negative and "> 0" is false.
        if ((sum & 0xFFFF0000) > 0) {
            sum = sum & 0xFFFF;
            sum += 1;
        }

        i += 2;
        length -= 2;
    }

    // Odd-length buffers: the last byte becomes the high half of a final word.
    // (Not reached for the six-byte buffer above.)
    if (length > 0) {
        // NOTE(review): same unmasked shift as in the loop, so the same sign-extension applies.
        sum += (buf[i] << 8);
        // sum += buffer[i];
        if ((sum & 0xFFFF0000) > 0) {
            sum = sum & 0xFFFF;
            sum += 1;
        }
    }
    // Complement the sum and keep only the low 16 bits.
    sum = ~sum;
    sum = sum & 0xFFFF;
    return sum;
}

Solution

  • Edited to apply comments from @Andy, @EJP, @RD et al. and to add extra test cases just to be sure.

    I've used a combination of @Andy's answer (which correctly identified the location of the problem) and updated the code to include the unit tests provided in the linked answer, along with an additional test case that uses a verified message checksum.

    First the implementation

    package org.example.checksum;
    
    public class InternetChecksum {
    
      /**
       * Computes the Internet Checksum of a message (RFC 1071 - http://www.faqs.org/rfcs/rfc1071.html).
       *
       * <p>The message is read as a sequence of big-endian 16-bit words; a trailing single
       * byte is treated as the high half of a word whose low half is zero. The words are
       * added with end-around carry (1's complement addition) and the low 16 bits of the
       * bitwise complement of that total are returned.
       *
       * @param buf The message
       * @return The checksum, in the range 0x0000 to 0xFFFF
       */
      public long calculateChecksum(byte[] buf) {
        // Largest even number of bytes that can be consumed as complete words.
        final int pairedLength = buf.length & ~1;
    
        long acc = 0;
        for (int pos = 0; pos < pairedLength; pos += 2) {
          // Mask each byte before combining so that bytes >= 0x80 are read as unsigned.
          final int high = buf[pos] & 0xFF;
          final int low = buf[pos + 1] & 0xFF;
          acc = foldCarry(acc + ((high << 8) | low));
        }
    
        // Odd-length message: the last byte forms the word [byte, 0x00].
        if (pairedLength < buf.length) {
          acc = foldCarry(acc + ((buf[pairedLength] & 0xFF) << 8));
        }
    
        // 1's complement of the total, limited to 16 bits.
        return ~acc & 0xFFFFL;
      }
    
      /**
       * Applies the end-around carry: when the running total has outgrown 16 bits, the
       * low 16 bits are kept and 1 is added; otherwise the total is returned unchanged.
       */
      private static long foldCarry(long total) {
        return total > 0xFFFFL ? (total & 0xFFFFL) + 1 : total;
      }
    
    }
    

    Then the unit test in JUnit4

    package org.example.checksum;
    
    import org.junit.Test;
    
    import static junit.framework.Assert.assertEquals;
    
    public class InternetChecksumTest {
    
      /**
       * Runs a fresh {@link InternetChecksum} over {@code message} and asserts that the
       * result equals {@code expected}.
       */
      private static void assertChecksum(long expected, byte[] message) {
        long actual = new InternetChecksum().calculateChecksum(message);
        assertEquals(expected, actual);
      }
    
      /** A buffer containing only zeros sums to zero, so the complement is all ones. */
      @Test
      public void simplestValidValue() {
        // should work for any-length array of zeros
        assertChecksum(0xFFFF, new byte[1]);
      }
    
      /** A lone 0xFF byte is the high half of the word 0xFF00, whose complement is 0x00FF. */
      @Test
      public void validSingleByteExtreme() {
        assertChecksum(0xFF, new byte[]{(byte) 0xFF});
      }
    
      /** The pair [0x00, 0xFF] is the word 0x00FF, whose complement is 0xFF00. */
      @Test
      public void validMultiByteExtrema() {
        assertChecksum(0xFF00, new byte[]{0x00, (byte) 0xFF});
      }
    
      /** Even-length example message with a published checksum. */
      @Test
      public void validExampleMessage() {
        // Berkeley example http://www.cs.berkeley.edu/~kfall/EE122/lec06/tsld023.htm
        // e3 4f 23 96 44 27 99 f3
        byte[] message = {
            (byte) 0xe3, 0x4f, 0x23, (byte) 0x96, 0x44, 0x27, (byte) 0x99, (byte) 0xf3
        };
    
        assertChecksum(0x1aff, message);
      }
    
      /** Even-length example from RFC 1071 whose sum carries out of 16 bits. */
      @Test
      public void validExampleEvenMessageWithCarryFromRFC1071() {
        // RFC1071 example http://www.ietf.org/rfc/rfc1071.txt
        // 00 01 f2 03 f4 f5 f6 f7
        byte[] message = {
            (byte) 0x00, 0x01, (byte) 0xf2, (byte) 0x03,
            (byte) 0xf4, (byte) 0xf5, (byte) 0xf6, (byte) 0xf7
        };
    
        assertChecksum(0x220d, message);
      }
    
    }