Tags: python, pyspark, spark-graphx, graphframes

Group the related values into one group


I am trying to group the column values based on related records.

partColumns = ["partnumber", "colVal1", "colVal2", "colVal3", "colVal4", "colVal5"]

partrelations = [("part0", "part1", "", "", "", ""),
                 ("part1", "", "part2", "", "part4", ""),
                 ("part2", "part3", "", "part5", "part6", "part7"),
                 ("part10", "part11", "", "", "", ""),
                 ("part11", "part13", "part21", "", "", ""),
                 ("part13", "part21", "part18", "", "part20", "")]

df_part_groups = spark.createDataFrame(data=partrelations, schema=partColumns)

DataFrame output:
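Running df_part_groups.show() on the data above (assuming a running SparkSession named spark) gives:

df_part_groups.show()
+----------+-------+-------+-------+-------+-------+
|partnumber|colVal1|colVal2|colVal3|colVal4|colVal5|
+----------+-------+-------+-------+-------+-------+
|     part0|  part1|       |       |       |       |
|     part1|       |  part2|       |  part4|       |
|     part2|  part3|       |  part5|  part6|  part7|
|    part10| part11|       |       |       |       |
|    part11| part13| part21|       |       |       |
|    part13| part21| part18|       | part20|       |
+----------+-------+-------+-------+-------+-------+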

I am trying to get the output as below:

[sample output image]

edges = (df_part_groups
         .withColumnRenamed("partnumber", "src")
         .withColumnRenamed("colVal1", "dst")
        )

vertices = (edges.select("src").distinct()
            .union(edges.select("dst").distinct())
            .withColumnRenamed("src", "id"))
         
# create a graph and find all connected components
from graphframes import GraphFrame

spark.sparkContext.setCheckpointDir("/tmp/checkpoints")  # connectedComponents() requires a checkpoint dir
g = GraphFrame(vertices, edges)
cc = g.connectedComponents()

display(df_part_groups
        .join(cc.distinct(), df_part_groups.partnumber == cc.id)
        .orderBy("component", "partnumber", "colVal1"))

Above is what I am trying to put together.
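I suspect the edges should come from every colVal column rather than just colVal1. Here is a rough, untested sketch of what I mean (column names are the ones from my DataFrame above):

from functools import reduce
from pyspark.sql import functions as F

# one (src, dst) pair per non-empty colVal column
edge_parts = [df_part_groups
              .select(F.col("partnumber").alias("src"), F.col(c).alias("dst"))
              .where(F.col("dst") != "")
              for c in ["colVal1", "colVal2", "colVal3", "colVal4", "colVal5"]]
edges_all = reduce(lambda a, b: a.union(b), edge_parts)

# vertices are every part appearing on either side of an edge
vertices_all = (edges_all.select(F.col("src").alias("id"))
                .union(edges_all.select(F.col("dst").alias("id")))
                .distinct())
# these would then feed GraphFrame(vertices_all, edges_all) as above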

Thanks for the help!!


Solution

  • We can solve this with a simple set-intersection check. (Not aware of GraphFrames :()

    step 1: combine all parts into a single array for each row

    from pyspark.sql import functions as F
        
    df_part_groups1 = df_part_groups.withColumn('parts', F.array('partnumber', 'colVal1', 'colVal2', 'colVal3', 'colVal4', 'colVal5'))
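
    For the sample data, the new parts column looks roughly like this (empty strings are still present at this stage):

    df_part_groups1.select('parts').show(truncate=False)
    +-------------------------------------+
    |parts                                |
    +-------------------------------------+
    |[part0, part1, , , , ]               |
    |[part1, , part2, , part4, ]          |
    |[part2, part3, , part5, part6, part7]|
    |[part10, part11, , , , ]             |
    |[part11, part13, part21, , , ]       |
    |[part13, part21, part18, , part20, ] |
    +-------------------------------------+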
    

    step 2: collect all_parts, a list of lists of the combined parts, since the groups need to be determined across rows

    def clean_lists(plists):
      return [list(filter(None, pl)) for pl in plists]  # drop empty strings

    all_parts = clean_lists(
        df_part_groups1.groupBy(F.lit(1))
                       .agg(F.collect_list('parts').alias('parts'))
                       .collect()[0].parts)
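
    With the sample data, all_parts comes back as a plain Python list of lists (collect_list does not guarantee row order in general, though on this small single-partition input it matches the input order):

    all_parts
    # [['part0', 'part1'],
    #  ['part1', 'part2', 'part4'],
    #  ['part2', 'part3', 'part5', 'part6', 'part7'],
    #  ['part10', 'part11'],
    #  ['part11', 'part13', 'part21'],
    #  ['part13', 'part21', 'part18', 'part20']]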
    

    step 3: build the groups using the collected all_parts

    def part_of_existing_group(gps, pl):
      for key in gps.keys():
        if set(gps[key]) & set(pl):  # row shares at least one part with this group
          gps[key] = list(set(gps[key] + pl))  # merge the row's parts into the group
          return True
      return False

    def findGroups(plists):
      groups = {}
      index = 1
      for pl in plists:
        if len(groups.keys()) == 0 or (not part_of_existing_group(groups, pl)):
          groups[f'G{index}'] = pl  # no overlap with any existing group: start a new one
          index += 1
      return groups
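
    One thing to note: this is a single pass, so each row is merged into the first group it overlaps; a row that links two groups created separately would not merge them. That cannot happen with the sample data, where findGroups(all_parts) returns something like this (element order within each list varies because sets are used):

    findGroups(all_parts)
    # {'G1': ['part0', 'part1', 'part2', 'part3', 'part4', 'part5', 'part6', 'part7'],
    #  'G2': ['part10', 'part11', 'part13', 'part18', 'part20', 'part21']}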
    

    step 4: assign groups based on the groups map that you created

    groups = findGroups(all_parts)

    @F.udf
    def get_group_val(part):
      for key in groups.keys():
        if part in groups[key]:
          return key
      return None  # no matching group (the udf defaults to StringType, so return None rather than -1)

    df_part_groups2 = (df_part_groups1
                       .withColumn('part', F.explode('parts'))
                       .dropDuplicates(['part'])
                       .where(F.col('part') != '')
                       .select('part', 'parts')
                       .withColumn('Group', get_group_val('part')))

    df_part_groups2.show()
    +------+--------------------+-----+
    |  part|               parts|Group|
    +------+--------------------+-----+
    | part0|[part0, part1, , ...|   G1|
    | part1|[part0, part1, , ...|   G1|
    |part10|[part10, part11, ...|   G2|
    |part11|[part10, part11, ...|   G2|
    |part13|[part11, part13, ...|   G2|
    |part18|[part13, part21, ...|   G2|
    | part2|[part1, , part2, ...|   G1|
    |part20|[part13, part21, ...|   G2|
    |part21|[part11, part13, ...|   G2|
    | part3|[part2, part3, , ...|   G1|
    | part4|[part1, , part2, ...|   G1|
    | part5|[part2, part3, , ...|   G1|
    | part6|[part2, part3, , ...|   G1|
    | part7|[part2, part3, , ...|   G1|
    +------+--------------------+-----+
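
    If the group label is needed back on the original rows rather than per part, the part-to-Group mapping can be joined to the input. A small sketch reusing the names above:

    mapping = df_part_groups2.select('part', 'Group')
    df_with_groups = (df_part_groups
                      .join(mapping, df_part_groups.partnumber == mapping.part, 'left')
                      .drop('part')
                      .orderBy('Group', 'partnumber'))
    df_with_groups.show()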