Tags: algorithm, scala, data-structures, ip-address, trie

What is the most efficient way to find out which subnet an IP belongs to


Suppose I have billions of records, each containing an IPv4 field, and I want to find out:

  1. Whether each record belongs to one of the subnets I am interested in
  2. Which subnet it belongs to, if it satisfies requirement 1.

Each subnet of interest is defined by several masks (e.g. 61.232.85.0/25, 61.232.86.0/27) and possibly some individual IPs, and I have a total of 1000 subnet definitions to deal with.

So, what is the most efficient algorithm and/or data structure for this job?

Also, I've found that a trie is a possible solution. Any suggestions?


Solution

  • You could represent the subnets as int values, put them all in an array, sort it, and then use a binary search on it. Longs might be better, though, since then you don't have to deal with negative numbers in the binary search. With 1000 entries, a binary search needs at most about 10 steps (log2 of 1000 is roughly 10) to find a subnet, whereas a trie might take e.g. 25 steps for a /25 subnet. Also remember that 1000 longs is only 8 KB, which easily fits into the CPU cache, making the lookup really fast. You can then use a second array that stores which subnet each of the masks belongs to.
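
    To illustrate why longs are easier to search than ints, here is a minimal sketch. The helper pack is hypothetical and only used for this illustration: any address with a first octet of 128 or more becomes negative as a signed Int, so it would sort before smaller addresses unless it is widened to a Long first.

    ```scala
    // Hypothetical helper for illustration: packs four octets into a signed Int.
    def pack(a: Int, b: Int, c: Int, d: Int): Int =
      (a << 24) | (b << 16) | (c << 8) | d

    val low  = pack(61, 232, 85, 0)  // 0x3DE85500, positive as an Int
    val high = pack(253, 2, 0, 0)    // 0xFD020000, negative as an Int

    // Signed comparison orders 253.2.0.0 before 61.232.85.0.
    println(high < low)                                  // true
    // Masking with a Long literal yields the unsigned value and the expected order.
    println((high & 0xFFFFFFFFL) > (low & 0xFFFFFFFFL))  // true
    ```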

    Here is a complete example in Scala. findMaskIdx finds the index of a given mask (the IP part of the subnet definition) using a binary search. If it can't find an exact match, it returns the index of the first mask that is larger than the one it searched for. findIpIdx takes an IP address and returns the index of the subnet definition it belongs to, or -1 if nothing is found.

    findIpIdx can be run around 100 to 200 million times per second, so it seems to be quite fast. There is only one problem with this approach: if two subnets of different sizes overlap, the code might find the wrong one. But I hope that shouldn't be too difficult to fix.

    def ipStringToInt(s: String): Int = {
      var ip = 0
      for(num <- s.split("\\.")) {
        ip = ip * 256 + num.toInt
      }
      ip
    }
    
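    def parseSubnet(s: String): (Long, Int) = {
      val mask_length = s.split("/")
      // A bare IP without "/n" is treated as a /32 subnet.
      val length = if(mask_length.size > 1) mask_length(1).toInt else 32
      val address = ipStringToInt(mask_length(0)) & 0xFFFFFFFFL
      // Clear the host bits so that e.g. 123.234.12.24/16 sorts as 123.234.0.0;
      // otherwise addresses below 123.234.12.24 in that subnet would be missed.
      val netmask = (0xFFFFFFFF00000000L >>> length) & 0xFFFFFFFFL
      (address & netmask, length)
    }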
    
    val subnetGroups = Vector(
      Vector("61.232.85.0/25", "61.232.86.0/27"),
      Vector("123.234.12.24/16", "1.2.3.4"),
      Vector("61.232.87.5", "253.2.0.0/16")
    )
    
    val subnetData = (for {
      (group, idx) <- subnetGroups.zipWithIndex
      maskString <- group
      (mask, length) = parseSubnet(maskString)
    } yield (mask, length, idx)).sortBy(_._1)
    
    val masks: Array[Long] = subnetData.map(_._1).toArray
    val maskLengths: Array[Int] = subnetData.map(_._2).toArray
    val groupNr: Array[Int] = subnetData.map(_._3).toArray
    
    def findMaskIdx(ip: Long): Int = {
      var low = 0
      var high = masks.size
      while(high > low) {
        val mid = (low + high)/2
        if(masks(mid) > ip) high = mid
        else if(masks(mid) < ip) low = mid + 1
        else return mid
      }
      low
    }
    
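    def findIpIdx(ip: Int): Int = {
      // Widen to an unsigned 32-bit value stored in a Long.
      val ipLong = ip & 0xFFFFFFFFL
      var idx = findMaskIdx(ipLong)
      // Exact hit: the address equals a stored mask address.
      if(idx < masks.size && masks(idx) == ipLong) return idx
      // Otherwise check the closest mask below the address.
      idx -= 1
      if(idx < 0) return -1
      // Netmask with the top maskLengths(idx) bits set.
      val m = (0xFFFFFFFF00000000L >>> maskLengths(idx)) & 0xFFFFFFFFL
      if((m & masks(idx)) == (m & ipLong)) return idx
      -1
    }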
    
    
    println("subnet data (mask, bit length of mask, index of subnet group):")
    println(subnetData.map {
      case (mask, length, idx) => (mask.toHexString, length, idx)
    })
    println()
    
    println("masks = " + masks.toVector.map(_.toHexString))
    println()
    
    def testIP(ipString: String): Unit = {
      println("ipString = " + ipString)
      val ip = ipStringToInt(ipString)
      val dataIdx = findIpIdx(ip)
      println("dataIdx = " + dataIdx)
      if(dataIdx >= 0) {
        val (mask, length, idx) = subnetData(dataIdx)
        println(s"data = ${(mask.toHexString, length, idx)}")
      }
      println()
    }
    
    testIP("61.232.86.12")
    testIP("253.2.100.253")
    testIP("253.3.0.0")
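
    One possible way to handle the overlap problem mentioned above is a longest-prefix match: keep one hash map per prefix length and probe them from the longest prefix to the shortest, so the most specific subnet wins. This is a different technique from the sorted-array search, shown only as a sketch. The names toLong, netmask, defs, tables and findGroup, as well as the overlapping 61.232.0.0/16 entry, are assumptions made for this illustration.

    ```scala
    // Parses a dotted IPv4 string into an unsigned value stored in a Long.
    def toLong(s: String): Long =
      s.split("\\.").foldLeft(0L)((acc, octet) => acc * 256 + octet.toInt)

    // Netmask with the top `length` bits set, e.g. 25 gives 0xFFFFFF80.
    def netmask(length: Int): Long =
      (0xFFFFFFFF00000000L >>> length) & 0xFFFFFFFFL

    // (subnet definition, group index); the /16 deliberately contains the /25.
    val defs = Vector("61.232.0.0/16" -> 0, "61.232.85.0/25" -> 1, "1.2.3.4" -> 2)

    // Parse each definition into (prefix length, network address, group index).
    val parsed = defs.map { case (s, group) =>
      val parts = s.split("/")
      val length = if (parts.length > 1) parts(1).toInt else 32
      (length, toLong(parts(0)) & netmask(length), group)
    }

    // One lookup table per prefix length, ordered from longest to shortest prefix.
    val tables: Vector[(Int, Map[Long, Int])] =
      parsed.groupBy(_._1).toVector
        .sortBy { case (length, _) => -length }
        .map { case (length, entries) => (length, entries.map(e => e._2 -> e._3).toMap) }

    // Returns the group of the most specific subnet containing ip, or -1 if none does.
    def findGroup(ip: Long): Int =
      tables.iterator
        .map { case (length, table) => table.get(ip & netmask(length)) }
        .collectFirst { case Some(group) => group }
        .getOrElse(-1)

    println(findGroup(toLong("61.232.85.10")))   // 1: inside the /25, which beats the /16
    println(findGroup(toLong("61.232.85.200")))  // 0: only the /16 contains it
    println(findGroup(toLong("8.8.8.8")))        // -1: no subnet matches
    ```

    Each lookup costs at most one hash probe per distinct prefix length (33 at most for IPv4), so it is likely slower than the single binary search, but the result does not depend on how the subnets overlap.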