I've created the following graph:

from pyspark.sql import SparkSession
from graphframes import GraphFrame

spark = SparkSession.builder.appName('aggregate').getOrCreate()
vertices = spark.createDataFrame([('1', 'foo', 99), 
                                  ('2', 'bar', 10),
                                 ('3', 'baz', 25),
                                 ('4', 'spam', 7)],
                                 ['id', 'name', 'value'])

edges = spark.createDataFrame([('1', '2'), 
                               ('1', '3'),
                               ('3', '4')],
                              ['src', 'dst'])

g = GraphFrame(vertices, edges)

I would like to aggregate the messages so that each vertex ends up with a list of the values of all its descendant vertices, all the way down to the leaves. For example, vertex 1 has a child edge to vertex 3, which in turn has a child edge to vertex 4. Vertex 1 also has a child edge to vertex 2. That is:

(1) --> (3) --> (4)
  \
   \--> (2)

From 1 I'd like to collect all values along these paths: [99, 10, 25, 7], where 99 is the value of vertex 1, 10 is the value of child vertex 2, 25 is the value of vertex 3, and 7 is the value of vertex 4.

From 3 we'd have the values [25, 7], etc.

I can approximate this with aggregateMessages:

from pyspark.sql.functions import collect_list
from graphframes.lib import AggregateMessages as AM

agg = g.aggregateMessages(collect_list(AM.msg).alias('allValues'),
                          sendToSrc=AM.dst['value'],
                          sendToDst=None)

agg.show()

Which produces:

+---+---------+
| id|allValues|
+---+---------+
|  3|      [7]|
|  1| [25, 10]|
+---+---------+

At 1 we have [25, 10] which are the immediate child values, but we are missing 7 and the "self" value of 99.

Similarly, I'm missing 25 for vertex 3.

How can I aggregate messages "recursively", such that allValues from child vertices are aggregated at the parent?

  • You can find the connected components first, then apply a custom function to find the trees for each root within each connected component. Commented Jan 1, 2021 at 15:10
  • @Julio Did you find a solution? Commented Feb 28, 2022 at 9:07

1 Answer

I adapted this answer to your question and then wrangled its result to get your desired output. I admit it's a very ugly solution, but I hope it helps as a starting point towards a more efficient and elegant implementation.

from graphframes import GraphFrame
from graphframes.lib import Pregel
import pyspark.sql.functions as F
from pyspark.sql.window import Window
from pyspark.sql.types import ArrayType, DoubleType, StringType, StructType


vertices = spark.createDataFrame([('1', 'foo', 99), 
                                  ('2', 'bar', 10),
                                 ('3', 'baz', 25),
                                 ('4', 'spam', 7)],
                                 ['id', 'name', 'value'])

edges = spark.createDataFrame([('1', '2'), 
                               ('1', '3'),
                               ('3', '4')],
                              ['src', 'dst'])

g = GraphFrame(vertices, edges)

### Adapted from previous answer

vertColSchema = StructType()\
      .add("dist", DoubleType())\
      .add("node", StringType())\
      .add("path", ArrayType(StringType(), True))

def vertexProgram(vd, msg):
    # Keep the vertex's own state unless the message offers a shorter distance.
    if msg is None or vd[0] < msg[0]:
        return (vd[0], vd[1], vd[2])
    else:
        return (msg[0], vd[1], msg[2])

vertexProgramUdf = F.udf(vertexProgram, vertColSchema)

def sendMsgToDst(src, dst):
    srcDist = src[0]
    dstDist = dst[0]
    # Only send a message when going through src would shorten dst's distance.
    if srcDist < (dstDist - 1):
        return (srcDist + 1, src[1], src[2] + [dst[1]])
    else:
        return None

sendMsgToDstUdf = F.udf(sendMsgToDst, vertColSchema)

def aggMsgs(agg):
    # Pick the message with the smallest distance (field 0 of the struct).
    shortest_dist = min(agg, key=lambda tup: tup[0])
    return (shortest_dist[0], shortest_dist[1], shortest_dist[2])

aggMsgsUdf = F.udf(aggMsgs, vertColSchema)

result = (
    g.pregel.withVertexColumn(
        colName = "vertCol",

        initialExpr = F.when(
            F.col("id") == "1",    ## ids are strings; vertex 1 is the root
            F.struct(F.lit(0.0), F.col("id"), F.array(F.col("id")))
        ).otherwise(
            F.struct(F.lit(float("inf")), F.col("id"), F.array(F.lit("")))
        ).cast(vertColSchema),

        updateAfterAggMsgsExpr = vertexProgramUdf(F.col("vertCol"), Pregel.msg())
    )
    .sendMsgToDst(sendMsgToDstUdf(F.col("src.vertCol"), Pregel.dst("vertCol")))
    .aggMsgs(aggMsgsUdf(F.collect_list(Pregel.msg())))
    .setMaxIter(3)    ## This should be greater than the max depth of the graph
    .setCheckpointInterval(1)
    .run()
)

df = result.select("vertCol.node", "vertCol.path").repartition(1)
df.show()
+----+---------+
|node|     path|
+----+---------+
|   1|      [1]|
|   2|   [1, 2]|
|   3|   [1, 3]|
|   4|[1, 3, 4]|
+----+---------+

### Wrangling the dataframe to get desired output

final = df.select(
    'node',
    F.posexplode_outer('path')
).withColumn(
    'children', 
    F.collect_list('col').over(Window.partitionBy('node').orderBy(F.desc('pos')))
).groupBy('col').agg(
    F.array_distinct(F.flatten(F.collect_list('children'))).alias('children')
).alias('t1').repartition(1).join(
    vertices,
    F.array_contains(F.col('t1.children'), vertices.id)
).groupBy('col').agg(
    F.collect_list('value').alias('values')
).withColumnRenamed('col', 'id').orderBy('id')

final.show()
+---+---------------+
| id|         values|
+---+---------------+
|  1|[99, 10, 25, 7]|
|  2|           [10]|
|  3|        [25, 7]|
|  4|            [7]|
+---+---------------+
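
The wrangling step relies on one observation: every vertex on a root-to-node path has the rest of that path as its descendants. Below is a minimal plain-Python sketch of that idea, assuming the graph is a tree rooted at vertex 1 and using the paths and values shown above. The function name `descendant_values` is hypothetical and not part of GraphFrames. It can serve as a sanity check on small graphs.

```python
# Root-to-node paths as produced by the Pregel step, and the vertex values.
paths = {"1": ["1"], "2": ["1", "2"], "3": ["1", "3"], "4": ["1", "3", "4"]}
values = {"1": 99, "2": 10, "3": 25, "4": 7}


def descendant_values(paths, values):
    """For each vertex, collect its own value and those of its descendants.

    Hypothetical helper: each vertex on a path "owns" the suffix of the
    path starting at itself, and suffixes are merged across all paths.
    """
    reachable = {}
    for path in paths.values():
        for i, node in enumerate(path):
            seen = reachable.setdefault(node, [])
            # Add the suffix path[i:], skipping ids already recorded.
            for vertex_id in path[i:]:
                if vertex_id not in seen:
                    seen.append(vertex_id)
    # Swap vertex ids for their values.
    return {node: [values[v] for v in ids] for node, ids in reachable.items()}


result = descendant_values(paths, values)
print(result)
```

In the Spark version, the window ordered by descending `pos` builds the same suffixes, and `array_distinct(flatten(...))` performs the merge. Note that `collect_list` does not guarantee element order, so the order inside each list may differ from this sketch.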