python, dictionary, hashtable

How can I reduce the time complexity of this triplet count function?


I'm wondering how I can reduce the time complexity of this triplet count function.

This solution works, but it is too slow to pass on HackerRank:

def count_triplets(arr, r):
    left_map = {}
    right_map = {}
    for item in arr:
        right_map[item] = arr.count(item)  # store frequency of occurrences of each item.

    count = 0
    for i in range(0, len(arr)):
        mid_term = arr[i]
        c1 = 0
        c3 = 0
        right_map[mid_term] -= 1
        left_term = mid_term / r
        if left_term in left_map and mid_term % r == 0:
            c1 = left_map.get(left_term)

        right_term = mid_term * r
        if right_term in right_map:
            c3 = right_map.get(right_term)

        count += c1 * c3

        if mid_term in left_map:
            left_map[mid_term] += 1
        else: 
            left_map[mid_term] = 1

    return count


if __name__ == '__main__':
    r = 3
    arr = [1,3,9,9,9,27,81]  # returns 9 as solution
    # arr = [1,3,9,9,27,81] # returns 6 as solution
    result = count_triplets(arr, r)
    print("result = ", result)

Solution

  • This code:

        right_map = {}
        for item in arr:
            right_map[item] = arr.count(item)  # store frequency of occurrences of each item.
    

    is O(n^2), because arr.count is itself O(n) and it is called once for every element of arr. That's probably your problem. You can count all the items in a single O(n) pass using a Counter, which also helps simplify the rest of your code:

    from collections import Counter
    
    
    def count_triplets(arr, r):
        left_map = Counter()
        right_map = Counter(arr)
    
        count = 0
        for mid_term in arr:
            right_map[mid_term] -= 1
            if mid_term % r == 0:
                left_term = mid_term // r  # integer division keeps the key an exact int
                c1 = left_map[left_term]
            else:
                c1 = 0
    
            right_term = mid_term * r
            c3 = right_map[right_term]
    
            count += c1 * c3
            left_map[mid_term] += 1
    
        return count
    
    
    def test():
        r = 3
        assert count_triplets([1, 3, 9, 9, 9, 27, 81], r) == 9
        assert count_triplets([1, 3, 9, 9, 27, 81], r) == 6
    
    
    if __name__ == '__main__':
        test()
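
As a small illustration of why the Counter version can drop the membership checks, here is a minimal sketch of the two Counter behaviours it relies on. The names slow_counts and fast_counts and the lookup value 243 are made up for this example and do not appear in the code above.

```python
from collections import Counter

arr = [1, 3, 9, 9, 9, 27, 81]

# The approach from the question: one O(n) scan per element, O(n^2) overall.
slow_counts = {item: arr.count(item) for item in arr}

# Counter builds the same frequencies in a single O(n) pass.
fast_counts = Counter(arr)
assert dict(fast_counts) == slow_counts

# Reading a missing key returns 0 and does not insert the key,
# so no "if key in map" guard is needed before reading a count.
assert fast_counts[243] == 0
assert 243 not in fast_counts

# Counts can be decremented in place, as the loop does with right_map.
fast_counts[9] -= 1
assert fast_counts[9] == 2
```

Because a missing key reads as 0, c1 * c3 is simply 0 whenever either neighbouring term is absent, which is what the explicit checks in the original code were guarding against.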