
Determining the validity of a multi-hot encoding using bit manipulation


Suppose I have N items and a binary number whose bits mark which of these items are included in a result (items are numbered from 0, starting at the leftmost of the N bits):

N = 4

# items 1 and 3 will be included in the result
vector = 0b0101

# item 2 will be included in the result
vector = 0b0010

I'm also given a list, conflicts, that indicates which items cannot be included in the result at the same time:

conflicts = [
  0b0110, # any result that contains items 1 AND 2 is invalid
  0b0111, # any result that contains AT LEAST 2 items from {1, 2, 3} is invalid
]

Given this list of conflicts, we can determine the validity of the earlier vectors:

# invalid as it triggers conflicts[1] (0b0111, i.e. [0, 1, 1, 1]): items 1 and 3 are both included
vector = 0b0101

# valid as it triggers no conflicts
vector = 0b0010
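
To make the rule concrete, here is a small sketch that reproduces the two verdicts above. It assumes each conflict mask means "at most one of these items may appear", which is an interpretation inferred from the comments in the conflicts list. The helper name overlap_counts is hypothetical.

```python
def overlap_counts(vector, conflicts):
    """For each conflict mask, count the items the vector shares with it."""
    # vector & mask keeps only the bits set in both; counting "1" digits
    # in the binary string gives the number of shared items.
    return [bin(vector & mask).count("1") for mask in conflicts]


conflicts = [0b0110, 0b0111]

# A count of 2 or more against any mask marks the vector as invalid.
print(overlap_counts(0b0101, conflicts))  # [1, 2] -> invalid (second mask)
print(overlap_counts(0b0010, conflicts))  # [1, 1] -> valid
```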

How can bit manipulation be used to determine validity of a vector or large list of vectors against a list of conflicts in this context?

The solution provided here already gets us most of the way there, but I'm unsure how to adapt it to plain Python integers (avoiding numpy arrays and numba entirely).


Solution

    N = 4
    
    # items 1 and 3 will be included in the result
    vector = 0b0101
    
    # item 2 will be included in the result
    vector = 0b0010
    
    conflicts = [
      0b0110, # any result that contains items 1 AND 2 is invalid
      0b0111, # any result that contains AT LEAST 2 items from {1, 2, 3} is invalid
    ]
    
    def find_conflict(vector, conflicts):
        """Report every conflict mask that shares two or more items with vector.

        Returns True if at least one conflict is triggered, otherwise False.
        """
        found_conflict = False
        for v in conflicts:
            result = vector & v  # bitwise AND keeps only the items present in both
            number_of_bits_set = bin(result).count("1")  # number of shared items
            if number_of_bits_set >= 2:  # two or more shared items make the vector invalid
                found_conflict = True
                print(f"..Conflict between {bin(vector)} and {bin(v)}: {bin(result)}")
        if found_conflict:
            print(f"Conflict found for {bin(vector)}.")
        else:
            print(f"No conflict found for {bin(vector)}.")
        return found_conflict
    
    # invalid as it triggers conflicts[1] (0b0111, i.e. [0, 1, 1, 1]): items 1 and 3 are both included
    vector = 0b0101
    find_conflict(vector, conflicts)
    
    # valid as it triggers no conflicts
    vector = 0b0010
    find_conflict(vector, conflicts)
    
    $ python3 pythontest.py
    ..Conflict between 0b101 and 0b111: 0b101
    Conflict found for 0b101.
    No conflict found for 0b10.
    $
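
The question also asks about checking a large list of vectors with plain integers. Below is a minimal sketch of one way to do that, not a definitive implementation. It assumes the same rule as above (a vector is invalid when it shares two or more items with any conflict mask). The names has_two_or_more_bits, is_valid, and filter_valid are hypothetical helpers introduced here for illustration. Instead of counting bits, it uses the identity that x & (x - 1) clears the lowest set bit of x, so the result is nonzero exactly when x has at least two bits set.

```python
from typing import Iterable, List


def has_two_or_more_bits(x: int) -> bool:
    """True if the non-negative integer x has at least two bits set."""
    # x & (x - 1) removes the lowest set bit; anything left over
    # means there was a second set bit.
    return (x & (x - 1)) != 0


def is_valid(vector: int, conflicts: Iterable[int]) -> bool:
    """True if the vector shares at most one item with every conflict mask."""
    return not any(has_two_or_more_bits(vector & mask) for mask in conflicts)


def filter_valid(vectors: Iterable[int], conflicts: List[int]) -> List[int]:
    """Keep only the vectors that trigger no conflict."""
    return [v for v in vectors if is_valid(v, conflicts)]


conflicts = [0b0110, 0b0111]
vectors = [0b0101, 0b0010, 0b1000, 0b0110]

print(is_valid(0b0101, conflicts))  # False
print(is_valid(0b0010, conflicts))  # True
print([bin(v) for v in filter_valid(vectors, conflicts)])  # ['0b10', '0b1000']
```

Because any() stops at the first triggered conflict, invalid vectors are rejected without scanning the whole conflict list. If an actual overlap count is needed rather than a yes/no answer, bin(x).count("1") works on any Python 3 version, and int.bit_count() is available from Python 3.10.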