
I'm working with some deeply nested data in a PySpark dataframe. As I try to flatten the structure into rows and columns, I've noticed that when I call withColumn, any row that contains a null in the source column is dropped from my result dataframe. Instead, I would like to find a way to retain the row and have null in the resulting column.

A sample dataframe to work with:

from pyspark.sql.functions import explode, first, col, monotonically_increasing_id
from pyspark.sql import Row

df = spark.createDataFrame([
  Row(dataCells=[Row(posx=0, posy=1, posz=.5, value=1.5, shape=[Row(_type='square', _len=1)]), 
                 Row(posx=1, posy=3, posz=.5, value=4.5, shape=[]), 
                 Row(posx=2, posy=5, posz=.5, value=7.5, shape=[Row(_type='circle', _len=.5)])
    ])
])

I also have a function I use to flatten structs:

def flatten_struct_cols(df):
    # Separate the columns that are already flat from the top-level struct columns
    flat_cols = [column[0] for column in df.dtypes if 'struct' not in column[1][:6]]
    struct_columns = [column[0] for column in df.dtypes if 'struct' in column[1][:6]]

    # Promote every field of each struct column to its own column named <struct>_<field>
    df = df.select(flat_cols +
                   [col(sc + '.' + c).alias(sc + '_' + c)
                    for sc in struct_columns
                    for c in df.select(sc + '.*').columns])

    return df
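
For reference, a single pass of this function on a small made-up dataframe (not my real data, just an illustration) turns a struct column into prefixed flat columns:

toy = spark.createDataFrame([Row(id=1, pos=Row(x=0, y=1))])
flatten_struct_cols(toy).show()

# +---+-----+-----+
# | id|pos_x|pos_y|
# +---+-----+-----+
# |  1|    0|    1|
# +---+-----+-----+
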

And the schema looks like this:

df.printSchema()

root
 |-- dataCells: array (nullable = true)
 |    |-- element: struct (containsNull = true)
 |    |    |-- posx: long (nullable = true)
 |    |    |-- posy: long (nullable = true)
 |    |    |-- posz: double (nullable = true)
 |    |    |-- shape: array (nullable = true)
 |    |    |    |-- element: struct (containsNull = true)
 |    |    |    |    |-- _len: long (nullable = true)
 |    |    |    |    |-- _type: string (nullable = true)
 |    |    |-- value: double (nullable = true)

The starting dataframe:

df.show(3)

+--------------------+
|           dataCells|
+--------------------+
|[[0,1,0.5,Wrapped...|
+--------------------+

I start by exploding the array, since I want to turn this array of structs (each containing a nested array of structs) into rows and columns. I then flatten the struct fields into new columns.

df = df.withColumn('dataCells', explode(col('dataCells')))
df = flatten_struct_cols(df)
df.show(3)

And my data looks like:

+--------------+--------------+--------------+---------------+---------------+
|dataCells_posx|dataCells_posy|dataCells_posz|dataCells_shape|dataCells_value|
+--------------+--------------+--------------+---------------+---------------+
|             0|             1|           0.5|   [[1,square]]|            1.5|
|             1|             3|           0.5|             []|            4.5|
|             2|             5|           0.5|[[null,circle]]|            7.5|
+--------------+--------------+--------------+---------------+---------------+

All is well and as expected until I try to explode the dataCells_shape column which has an empty/null value.

df = df.withColumn('dataCells_shape', explode(col('dataCells_shape')))
df.show(3)

Which drops the second row out of the dataframe:

+--------------+--------------+--------------+---------------+---------------+
|dataCells_posx|dataCells_posy|dataCells_posz|dataCells_shape|dataCells_value|
+--------------+--------------+--------------+---------------+---------------+
|             0|             1|           0.5|     [1,square]|            1.5|
|             2|             5|           0.5|  [null,circle]|            7.5|
+--------------+--------------+--------------+---------------+---------------+

Instead I would like to keep the row, retain the empty value for that column, and keep all of the values in the other columns. I've tried creating a new column instead of overwriting the old one when doing the .withColumn explode, and I get the same result either way.

I also tried creating a UDF that performs the explode function if the row is not empty/null, but I ran into JVM errors when handling null.

from pyspark.sql.functions import udf
from pyspark.sql.types import NullType, StructType

def explode_if_not_null(trow):
    if trow:
        return explode(trow)
    else:
        return NullType

func_udf = udf(explode_if_not_null, StructType())
df = df.withColumn('dataCells_shape_test', func_udf(df['dataCells_shape']))
df.show(3)

AttributeError: 'NoneType' object has no attribute '_jvm'

Can anybody suggest a way for me to explode or flatten ArrayType columns without losing rows when the column is null?

I am using PySpark 2.2.0

Edit:

Following the link provided as a possible duplicate, I tried to implement the suggested .isNotNull().otherwise() solution, providing the struct schema to .otherwise(), but the row still drops out of the result set.

from pyspark.sql.functions import when, array, lit

df.withColumn(
    "dataCells_shape_test",
    explode(
        when(col("dataCells_shape").isNotNull(), col("dataCells_shape"))
        .otherwise(array(lit(None).cast(
            df.select(col("dataCells_shape").getItem(0)).dtypes[0][1]
        )))
    )
).show()

+--------------+--------------+--------------+---------------+---------------+--------------------+
|dataCells_posx|dataCells_posy|dataCells_posz|dataCells_shape|dataCells_value|dataCells_shape_test|
+--------------+--------------+--------------+---------------+---------------+--------------------+
|             0|             1|           0.5|   [[1,square]]|            1.5|          [1,square]|
|             2|             5|           0.5|[[null,circle]]|            7.5|       [null,circle]|
+--------------+--------------+--------------+---------------+---------------+--------------------+
  • Instead of using a udf can you try using Spark's built-in when? It'll go something like: df = df.withColumn('dataCells', when(col('dataCells').isNotNull),explode(col('dataCells')))
  • Possible duplicate of Spark sql how to explode without losing null values. Though that post is not for pyspark, the technique is not language specific.
  • @Alexander you are missing the parentheses at the end of isNotNull()
  • @Alexander I can't test this, but explode_outer is a part of spark version 2.2 (but not available in pyspark until 2.3) - can you try the following: 1) explode_outer = sc._jvm.org.apache.spark.sql.functions.explode_outer and then df.withColumn("dataCells", explode_outer("dataCells")).show() or 2) df.createOrReplaceTempView("myTable") and then spark.sql("select *, explode_outer(dataCells) from myTable").show()
  • @Alexander related post on how to pull in java/scala functions: Spark: How to map Python with Scala or Java User Defined Functions?

2 Answers


Thanks to pault for pointing me to this question and this question about mapping Python to Java. I was able to get a working solution with:

from pyspark.sql.column import Column, _to_java_column

def explode_outer(col):
    # Grab the JVM explode_outer function (present in Spark 2.2) and wrap
    # its result back into a Python Column
    _explode_outer = sc._jvm.org.apache.spark.sql.functions.explode_outer
    return Column(_explode_outer(_to_java_column(col)))

new_df = df.withColumn("dataCells_shape", explode_outer(col("dataCells_shape")))

+--------------+--------------+--------------+---------------+---------------+
|dataCells_posx|dataCells_posy|dataCells_posz|dataCells_shape|dataCells_value|
+--------------+--------------+--------------+---------------+---------------+
|             0|             1|           0.5|     [1,square]|            1.5|
|             1|             3|           0.5|           null|            4.5|
|             2|             5|           0.5|  [null,circle]|            7.5|
+--------------+--------------+--------------+---------------+---------------+

root
 |-- dataCells_posx: long (nullable = true)
 |-- dataCells_posy: long (nullable = true)
 |-- dataCells_posz: double (nullable = true)
 |-- dataCells_shape: struct (nullable = true)
 |    |-- _len: long (nullable = true)
 |    |-- _type: string (nullable = true)
 |-- dataCells_value: double (nullable = true)

It's important to note that this works for pyspark version 2.2, because explode_outer is defined in Spark 2.2 (but for some reason the API wrapper was not implemented in pyspark until version 2.3). This solution creates a wrapper for the already implemented Java function.
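
If you are able to move to PySpark 2.3 or later, the wrapper is unnecessary because explode_outer is exposed directly in the Python API:

# PySpark >= 2.3 only
from pyspark.sql.functions import explode_outer

new_df = df.withColumn("dataCells_shape", explode_outer(col("dataCells_shape")))

One of the comments above also suggests calling explode_outer through Spark SQL on a temp view, which should work on 2.2 as well, though I went with the wrapper and haven't tested that route.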


For that complex structure it would be easier to write a map function and use it in the flatMap method of the RDD interface. As a result you will get a new, flattened RDD; then you have to create a dataframe again by applying a new schema.

from pyspark.sql.types import StructType, StructField, StringType

def flat_arr(row):
    rows = []
    # apply some logic to fill rows list with more "rows"
    return rows

rdd = df.rdd.flatMap(flat_arr)
schema = StructType([
    StructField('field1', StringType()),
    # define more fields
])
df = df.sql_ctx.createDataFrame(rdd, schema)
df.show()

This solution looks a bit longer than applying withColumn, but it could be a first iteration of your solution, and from there you can see how to convert it into withColumn statements. In my opinion a map function is appropriate here, just to keep things clear.
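
As a rough sketch of what this could look like for the dataframe in the question (after the first explode and flatten): the column names and output layout below are just my assumptions about how you might want the result, with one output row per shape element and nulls when the shape array is empty.

from pyspark.sql.types import (StructType, StructField, LongType,
                               DoubleType, StringType)

# Emit one tuple per shape element; emit a single tuple with nulls when the
# shape array is empty or null, so the row is not lost
def flat_arr(row):
    base = (row['dataCells_posx'], row['dataCells_posy'],
            row['dataCells_posz'], row['dataCells_value'])
    shapes = row['dataCells_shape'] or [None]
    return [base + ((s['_len'], s['_type']) if s is not None else (None, None))
            for s in shapes]

schema = StructType([
    StructField('posx', LongType()),
    StructField('posy', LongType()),
    StructField('posz', DoubleType()),
    StructField('value', DoubleType()),
    StructField('shape_len', LongType()),
    StructField('shape_type', StringType()),
])

flat_df = df.sql_ctx.createDataFrame(df.rdd.flatMap(flat_arr), schema)
flat_df.show()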

1 Comment

Wouldn't using the RDD prevent the operations from being optimized by the Catalyst optimizer?
