gremlin, janusgraph

Sort cosine similarity scores before adding edges to graph


I am trying to adapt the Jaccard similarity example found here to compute cosine similarity instead, but I want to limit the number of created edges to the top 10 scores.

I reviewed https://gist.github.com/dkuppitz/79e0b009f0c9ae87db5a but couldn't figure out how to hold off on the edge creation so that I can sort the scores first and still get the same results as in the link above.

Based on the Jaccard example above, this is what I have come up with so far:

g.V().
    match(
     __.as('v1').outE('RECOMMENDS').values('amount').fold().as('v1rec'),
     __.as('v1').V().as('v2'),
     __.as('v2').outE('RECOMMENDS').values('amount').fold().as('v2rec'),
     __.as('v1').out().dedup().fold().as('v1n'),
     __.as('v2').out().dedup().fold().as('v2n')
    ).
    where('v1',lt('v2')).
         by(id).
    where('v1',neq('v2').and(without('v1n'))).
    where('v2',without('v1n')).
    project('v1','v2','n','d1','d2').
     by(select('v1')).
     by(select('v2')).
     by(
         select('v1rec','v2rec') // <-- this does not work; I can't get the dot product from this
     ).
     by(coalesce(
            select('v1rec').
                unfold().
                math('_ ^ 2').
                sum(),
             constant(0))).
     by(coalesce(
            select('v2rec').
                unfold().
                math('_ ^ 2').
                sum(),
            constant(0))).
    filter(select('d1').is(gt(0))).
    filter(select('d2').is(gt(0))).
    project('v1','v2','cosine').
         by(select('v1')).
         by(select('v2')).
         by(math('n/(sqrt(d1)*sqrt(d2))')).
    sort{-it.cosine}.
    toList()[0..9].
    each {
         r -> g.V(r['v2']).as('v2').
         V(r['v1']).
         addE('PREDICTED_COSINE').
         to('v2').
         property('score', r['cosine']).
         toList()
    }

but I can't figure out how to get the dot product in the third by() step with select('v1rec','v2rec'). Please help.

UPDATE:

I couldn't fit this in a comment, so I'm posting it here:

I tried another approach that gets me closer (I think), but I still can't work out how to iterate over each list of maps and extract the values from it:

g.V().
    match(
        __.as('v1').outE().as('e1'),
        __.as('v1').V().as('v2'),
        __.as('v2').outE().as('e2'),
        __.as('v1').out().dedup().fold().as('v1n'),
        __.as('v2').out().dedup().fold().as('v2n')).
    where('v1',neq('v2').
        and(without('v1n'))).
    where('v2',without('v1n')).
    project('v1','v2','a1','a2').
        by(select('v1')).
        by(select('v2')).
        by(select('e1').by('amount')).
        by(select('e2').by('amount')).
    project('v1','v2','n','d1','d2').
        by(select('v1')).
        by(select('v2')).
        by(math('a1 * a2')).
        by(math('a1 * a1')).
        by(math('a2 * a2')).
    group().
        by(select('v1','v2')).
        unfold()

One line of output:

==>{v1=v[4240], v2=v[8320]}=[{v1=v[4240], v2=v[8320], n=210.0, d1=196.0, d2=225.0}, {v1=v[4240], v2=v[8320], n=182.0, d1=196.0, d2=169.0}, {v1=v[4240], v2=v[8320], n=182.0, d1=196.0, d2=169.0}, {v1=v[4240], v2=v[8320], n=45.0, d1=9.0, d2=225.0}, {v1=v[4240], v2=v[8320], n=39.0, d1=9.0, d2=169.0}, {v1=v[4240], v2=v[8320], n=39.0, d1=9.0, d2=169.0}, {v1=v[4240], v2=v[8320], n=45.0, d1=9.0, d2=225.0}, {v1=v[4240], v2=v[8320], n=39.0, d1=9.0, d2=169.0}, {v1=v[4240], v2=v[8320], n=39.0, d1=9.0, d2=169.0}]

My goal is to sum all the "n", "d1", and "d2" values in each list of maps so that I can calculate the similarity as sum(n)/(sqrt(sum(d1))*sqrt(sum(d2))) for each key (the {v1=v[4240], v2=v[8320]} pair that sits outside the list in the example). For that key, n would be 210 + 182 + 182 + 45 + 39 + 39 + 45 + 39 + 39 = 820. I want to do this for a number of graphs, so I don't have one specific graph for this. Does that make sense now?
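
To make the arithmetic concrete, here is a minimal Gremlin Console sketch that feeds the nine maps from the output line above into the formula. The in-memory TinkerGraph is an assumption used only so that `g` exists, and `rows` is a hypothetical variable name; the same steps should behave the same way on any traversal source, but treat this as an illustration rather than a definitive recipe.

```groovy
// Stand-in traversal source (assumption): no graph data is needed here.
g = TinkerGraph.open().traversal()

// The nine maps from the sample output, reduced to their numeric fields.
rows = [
    [n: 210d, d1: 196d, d2: 225d],
    [n: 182d, d1: 196d, d2: 169d],
    [n: 182d, d1: 196d, d2: 169d],
    [n:  45d, d1:   9d, d2: 225d],
    [n:  39d, d1:   9d, d2: 169d],
    [n:  39d, d1:   9d, d2: 169d],
    [n:  45d, d1:   9d, d2: 225d],
    [n:  39d, d1:   9d, d2: 169d],
    [n:  39d, d1:   9d, d2: 169d]
]

// Inject the list as a single traverser, sum each field over the list,
// then apply sum(n) / (sqrt(sum(d1)) * sqrt(sum(d2))).
g.inject(rows).
    project('n','d1','d2').
        by(unfold().select('n').sum()).
        by(unfold().select('d1').sum()).
        by(unfold().select('d2').sum()).
    math('n / (sqrt(d1) * sqrt(d2))')
```

For this sample the three sums are 820, 642, and 1689, so the score works out to roughly 0.787.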


Solution

  • This is what I finally came up with:

    g.V().
        match(
            __.as('v1').outE().as('e1'),
            __.as('v1').V().as('v2'),
            __.as('v2').outE().as('e2'),
            __.as('v1').out().dedup().fold().as('v1n'),
            __.as('v2').out().dedup().fold().as('v2n')
        ).
        where('v1',neq('v2').
            and(without('v1n'))).
        where('v2',without('v1n')).
        project('v1','v2','a1','a2').
            by(select('v1')).
            by(select('v2')).
            by(select('e1').by('amount')).
            by(select('e2').by('amount')).
        project('v1','v2','n','d1','d2').
            by(select('v1')).
            by(select('v2')).
            by(math('a1 * a2')).
            by(math('a1 * a1')).
            by(math('a2 * a2')).
        group().
            by(select('v1','v2')).
            unfold().
        project('v1','v2','n','d1','d2').
            by(select(keys).select('v1')).
            by(select(keys).select('v2')).
            by(select(values).local(unfold().select('n').sum())).
            by(select(values).local(unfold().select('d1').sum())).
            by(select(values).local(unfold().select('d2').sum())).
        project('v1','v2','c').
            by(select('v1')).
            by(select('v2')).
            by(math('n / (sqrt(d1) * sqrt(d2))')).
        sort{ -it.c }.
        toList()[0..9]
    

    Thanks for all the help.
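
    The query above stops at the sorted list, while the question also asked for edges to be created for the top scores. A minimal sketch of that last step, reusing the edge-creation pattern from the question, might look like the following. The TinkerGraph, the 'item' label, the `results` variable, and the 0.5 score are all hypothetical stand-ins so that the snippet runs on its own in the Gremlin Console; in practice `g` would be the JanusGraph traversal source and `results` the list returned by the query.

```groovy
// Stand-in graph and result list (assumptions, not real data).
g = TinkerGraph.open().traversal()
a = g.addV('item').next()            // hypothetical vertex
b = g.addV('item').next()            // hypothetical vertex
results = [[v1: a, v2: b, c: 0.5d]]  // same shape as the query output: v1, v2, c

// Sort by descending score and keep at most ten entries. Unlike [0..9],
// take(10) does not throw when fewer than ten pairs were found.
results.sort { -it.c }.take(10).each { r ->
    g.V(r.v2).as('v2').
        V(r.v1).
        addE('PREDICTED_COSINE').
        to('v2').
        property('score', r.c).
        iterate()
}

// One PREDICTED_COSINE edge should now exist for the single sample pair.
g.E().hasLabel('PREDICTED_COSINE').valueMap('score')
```

    Depending on how the graph is configured, the new edges may also need an explicit commit before other sessions can see them.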