
How to group redundant entries of a dictionary


I have a dictionary of entries like this:

{
    'A': {
        'HUE_SAT': 1,
        'GROUP_INPUT': 1,
        'GROUP_OUTPUT': 1
    },
    'D': {
        'HUE_SAT': 1,
        'GROUP_INPUT': 1,
        'GROUP_OUTPUT': 1
    },
    'T': {
        'HUE_SAT': 1,
        'GROUP_INPUT': 1,
        'GROUP_OUTPUT': 1
    },
    'O': {
        'GROUP_INPUT': 3,
        'MAPPING': 2,
        'TEX_NOISE': 2,
        'UVMAP': 2,
        'VALTORGB': 3,
        'GROUP_OUTPUT': 1,
        'AMBIENT_OCCLUSION': 1,
        'MIX': 4,
        'REROUTE': 1,
        'NEW_GEOMETRY': 1,
        'VECT_MATH': 1
    },
    ...
}

Each entry is compared with every other entry and given a similarity score. The scores are stored in a new dictionary whose keys are tuples of the two compared entries and whose values are their similarity scores. The problem is that this produces a lot of redundant entries, like this:

{
    ('A', 'D'): 1.0,
    ('A', 'C'): 1.0,
    ('D', 'A'): 1.0,
    ('D', 'C'): 1.0,
    ('C', 'A'): 1.0,
    ('C', 'D'): 1.0,
    ...
}

I want to merge keys into a single group whenever all of them give the same similarity score when compared with each other, like so:

{
    ('A', 'D', 'C'): 1.0,
    ('O', 'L', 'S', 'N', 'P'): 0.412,
    ...
}
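Stated precisely, a set of keys forms a group when every unordered pair inside it has the same score. As a minimal sketch, assuming the scores are stored in a dict keyed by `(key, key)` tuples in either order (as above), that condition can be checked like this. `pair_score` and `is_uniform_group` are hypothetical helper names, and the 0.323 scores are toy values:

```python
from itertools import combinations

def pair_score(scores, p, q):
    # Look up a pair in whichever order it was stored
    return scores[(p, q)] if (p, q) in scores else scores[(q, p)]

def is_uniform_group(group, scores):
    # A group of two or more keys qualifies when all of its pairs share one score
    values = {pair_score(scores, p, q) for p, q in combinations(group, 2)}
    return len(values) == 1

# Toy scores for illustration only
scores = {
    ('A', 'D'): 1.0, ('A', 'C'): 1.0, ('D', 'C'): 1.0,
    ('A', 'O'): 0.323, ('D', 'O'): 0.323,
}

print(is_uniform_group(('A', 'D', 'C'), scores))  # True
print(is_uniform_group(('A', 'D', 'O'), scores))  # False
```

Testing every possible subset this way grows exponentially with the number of keys, so in practice the groups need to be enumerated directly rather than checked one candidate at a time.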

Code:

from math import sqrt

# Cosine similarity function adapted from here:
# https://stackoverflow.com/a/35981085/22855942

def square_root(x):
    return sqrt(sum(a * a for a in x))

def cosine_similarity(a, b):
    # Build both vectors over the union of keys, so a key present in only
    # one dictionary still counts toward that dictionary's norm
    keys = sorted(set(a) | set(b))
    vector1 = [float(a.get(k, 0)) for k in keys]
    vector2 = [float(b.get(k, 0)) for k in keys]

    numerator = sum(x * y for x, y in zip(vector1, vector2))
    denominator = square_root(vector1) * square_root(vector2)
    if denominator == 0:
        return 0.0  # avoid ZeroDivisionError for an empty entry
    # Round so that equal scores compare equal later on
    return round(numerator / denominator, 3)


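# my_dict is the dictionary of entries shown above
keys = tuple(my_dict.keys())
results = {}

for k in keys:
    for l in keys:
        if l != k:
            # Reuse the mirrored score instead of recomputing it; a default
            # argument to results.get() is evaluated on every call
            if (l, k) in results:
                results[(k, l)] = results[(l, k)]
            else:
                results[(k, l)] = cosine_similarity(my_dict[k], my_dict[l])

# Sort pairs by score, highest first
results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))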
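The mirrored pairs appear because the nested loops visit both `(k, l)` and `(l, k)`. As a sketch, `itertools.combinations` yields each unordered pair exactly once, so the duplicates are never stored in the first place. This assumes a compact cosine similarity computed over the union of both dictionaries' keys (named `cosine` here for illustration) and uses only the A, D, T and O entries shown in the question:

```python
from itertools import combinations
from math import sqrt

def cosine(a, b):
    # Cosine similarity of two sparse count dicts over the union of their keys
    keys = set(a) | set(b)
    dot = sum(a.get(k, 0) * b.get(k, 0) for k in keys)
    norm_a = sqrt(sum(v * v for v in a.values()))
    norm_b = sqrt(sum(v * v for v in b.values()))
    if norm_a == 0 or norm_b == 0:
        return 0.0
    return round(dot / (norm_a * norm_b), 3)

my_dict = {
    'A': {'HUE_SAT': 1, 'GROUP_INPUT': 1, 'GROUP_OUTPUT': 1},
    'D': {'HUE_SAT': 1, 'GROUP_INPUT': 1, 'GROUP_OUTPUT': 1},
    'T': {'HUE_SAT': 1, 'GROUP_INPUT': 1, 'GROUP_OUTPUT': 1},
    'O': {'GROUP_INPUT': 3, 'MAPPING': 2, 'TEX_NOISE': 2, 'UVMAP': 2,
          'VALTORGB': 3, 'GROUP_OUTPUT': 1, 'AMBIENT_OCCLUSION': 1,
          'MIX': 4, 'REROUTE': 1, 'NEW_GEOMETRY': 1, 'VECT_MATH': 1},
}

# Each unordered pair is visited once: ('A', 'D') is stored, ('D', 'A') is not
results = {(k, l): cosine(my_dict[k], my_dict[l]) for k, l in combinations(my_dict, 2)}
results = dict(sorted(results.items(), key=lambda item: item[1], reverse=True))
print(results)
# {('A', 'D'): 1.0, ('A', 'T'): 1.0, ('D', 'T'): 1.0, ('A', 'O'): 0.323, ('D', 'O'): 0.323, ('T', 'O'): 0.323}
```

If a pair has to be looked up later without knowing the order it was stored in, either check both `(k, l)` and `(l, k)` or key the dict by `frozenset({k, l})` instead of a tuple.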

I tried building a sort of 'buffer' of compared pairs: compare the current item with the last item in the buffer list, then append it or not depending on the similarity score. That quickly spiralled into a mess of nested for loops, conditionals and sublists, and I feel like there must be a much more elegant solution. Is there one?


Solution

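  • Your problem can be formulated as a clique problem if you build one graph per distinct similarity value. In the graph for a score s, an edge joins every pair of keys whose similarity is s, so a set of keys that all have score s with one another is a clique, and the groups you want are the maximal cliques of each graph.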

    Using networkx:

    from collections import defaultdict
    from itertools import combinations
    import networkx as nx
    
    # MAKING TOY DATA - USE YOUR DATA INSTEAD
    points = {c: f'{i:04b}' for i,c in enumerate('ABCDEFGHIJ')} # 10 of the 16 vertices of a 4-dimensional unit hypercube
    similarities = {(p,q): sum(x==y for x,y in zip(s,t))**0.5 for (p,s),(q,t) in combinations(points.items(),2)}
    
    # BUILDING ONE GRAPH PER DISTINCT SIMILARITY VALUE
    graphs = defaultdict(nx.Graph)
    for (p,q),s in similarities.items():
        graphs[s].add_edge(p, q)   # optionally round s to avoid floating-point issues: graphs[int(1000*s+0.5)].add_edge(p, q)
    
    # FIND CLIQUES
    cliques = {}
    for s,G in graphs.items():
        cliques.update({tuple(clique): s for clique in nx.find_cliques(G)})
    
    print(cliques)
    # {('H', 'G'): 1.7320508075688772, ('H', 'F'): 1.7320508075688772, ('H', 'D'): 1.7320508075688772, ('B', 'A'): 1.7320508075688772, ('B', 'J'): 1.7320508075688772, ('B', 'F'): 1.7320508075688772, ('B', 'D'): 1.7320508075688772, ('I', 'A'): 1.7320508075688772, ('I', 'J'): 1.7320508075688772, ('E', 'A'): 1.7320508075688772, ('E', 'F'): 1.7320508075688772, ('E', 'G'): 1.7320508075688772, ('G', 'C'): 1.7320508075688772, ('C', 'A'): 1.7320508075688772, ('C', 'D'): 1.7320508075688772,
    #  ('J', 'E'): 1.0, ('J', 'C'): 1.0, ('J', 'H'): 1.0, ('B', 'G'): 1.0, ('I', 'G'): 1.0, ('I', 'F'): 1.0, ('I', 'D'): 1.0, ('D', 'E'): 1.0, ('A', 'H'): 1.0, ('F', 'C'): 1.0,
    #  ('H', 'I'): 0.0, ('G', 'J'): 0.0,
    #  ('J', 'A', 'F', 'D'): 1.4142135623730951, ('B', 'E', 'C', 'I'): 1.4142135623730951, ('B', 'E', 'C', 'H'): 1.4142135623730951, ('D', 'A', 'G', 'F'): 1.4142135623730951}
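If adding networkx as a dependency is not an option, the same one-graph-per-score idea can be sketched in plain Python. This is a swapped-in technique: a basic Bron-Kerbosch enumeration of maximal cliques with pivoting, not networkx's own implementation of `find_cliques`. The names `maximal_cliques` and `group_by_score`, the `ndigits` parameter and the sample scores are all illustrative. Scores are rounded before grouping so that tiny floating-point differences do not split one group across two graphs. Like the networkx version, it returns maximal cliques, which can overlap, as ('B', 'E', 'C', 'I') and ('B', 'E', 'C', 'H') do in the output above.

```python
from collections import defaultdict

def maximal_cliques(adj):
    # Basic Bron-Kerbosch with pivoting; adj maps node -> set of neighbours
    cliques = []

    def expand(r, p, x):
        if not p and not x:
            cliques.append(r)  # r cannot be extended: it is a maximal clique
            return
        # Pivot on a high-degree node to skip branches that cannot be maximal
        pivot = max(p | x, key=lambda u: len(adj[u]))
        for v in list(p - adj[pivot]):
            expand(r | {v}, p & adj[v], x & adj[v])
            p = p - {v}
            x = x | {v}

    expand(set(), set(adj), set())
    return cliques

def group_by_score(scores, ndigits=3):
    # One adjacency map per rounded score; an edge joins each pair with that score
    graphs = defaultdict(lambda: defaultdict(set))
    for (p, q), s in scores.items():
        adj = graphs[round(s, ndigits)]
        adj[p].add(q)
        adj[q].add(p)
    groups = {}
    for s, adj in graphs.items():
        for clique in maximal_cliques(adj):
            groups[tuple(sorted(clique))] = s
    return groups

# Toy scores; the mirrored duplicates are kept on purpose
scores = {
    ('A', 'D'): 1.0, ('A', 'C'): 1.0, ('D', 'C'): 1.0,
    ('D', 'A'): 1.0, ('C', 'A'): 1.0, ('C', 'D'): 1.0,
    ('A', 'O'): 0.323, ('D', 'O'): 0.323, ('C', 'O'): 0.5,
}
print(group_by_score(scores))
```

Mirrored pairs such as ('A', 'D') and ('D', 'A') only add the same edge twice, so a results dictionary built with both orderings can be passed in as-is.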