I have a Scala Spark DataFrame with the following schema:

    root
     |-- passengerId: string (nullable = true)
     |-- travelHist: array (nullable = true)
     |    |-- element: integer (containsNull = true)

I want to iterate through the array elements and find the maximum number of consecutive 0 values that occur between a 1 and a 2.

    passengerID  travelHist
    1            1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0
    2            0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0
    3            0, 0, 0, 2, 1, 0, 2, 1, 0

The output for the above records should look like this:

    passengerID  maxStreak
    1            7
    2            3
    3            1

What would be the most efficient way to find such an interval, assuming the number of elements in the array does not exceed 50?

  • The answer could even be in Python. I can convert it to Scala syntax later.

2 Answers

Let us do some pattern matching: join the digits into a single string, extract every run of zeros sitting between a 1 and a 2, and take the length of the longest run. (Note that regexp_extract_all requires Spark 3.1+.)

import pyspark.sql.functions as F

df1 = (
    df
    # join the digits into one string, e.g. "10000210000000201"
    .withColumn('matches', F.expr("array_join(travelHist, '')"))
    # capture every run of zeros enclosed by a 1 and a 2
    .withColumn('matches', F.expr("regexp_extract_all(matches, '1(0+)2', 1)"))
    # map each captured run to its length
    .withColumn('matches', F.expr("transform(matches, x -> length(x))"))
    # the longest run is the answer
    .withColumn('maxStreak', F.expr("array_max(matches)"))
)

df1.show()
+-----------+--------------------+-------+---------+
|passengerID|          travelHist|matches|maxStreak|
+-----------+--------------------+-------+---------+
|          1|[1, 0, 0, 0, 0, 2...| [4, 7]|        7|
|          2|[0, 0, 0, 0, 0, 0...|    [3]|        3|
|          3|[0, 0, 0, 2, 1, 0...|    [1]|        1|
+-----------+--------------------+-------+---------+

1 Comment

Your answer is correct, thanks for the help! However, I am on Spark version 2.4, which does not support the regexp_extract_all() function, so I am looking for an alternative way to get the same result in Scala.
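
For Spark 2.4, a minimal sketch of the same pattern-matching idea in Scala, using scala.util.matching.Regex inside a UDF in place of regexp_extract_all (df and travelHist are the names from the question; treat this as an untested sketch, not the answer's code):

import org.apache.spark.sql.functions._

// Same 1(0+)2 pattern as above, applied with Scala's Regex inside a UDF,
// since regexp_extract_all is unavailable before Spark 3.1.
val pattern = "1(0+)2".r

val maxStreakUdf = udf { (hist: Seq[Int]) =>
  val joined = Option(hist).getOrElse(Seq.empty).mkString("")
  val runs = pattern.findAllMatchIn(joined).map(_.group(1).length).toSeq
  if (runs.isEmpty) None else Some(runs.max)  // null when no 1..2 run exists
}

val result = df.withColumn("maxStreak", maxStreakUdf(col("travelHist")))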

Here's a solution using a Scala UDF called from PySpark. You can find the code for the UDF, and the release jar used in the PySpark script, in the following repository:

https://github.com/dineshdharme/pyspark-native-udfs

The code for the Scala UDF is as follows.

package com.help.udf

import org.apache.spark.sql.api.java.UDF1

class CountZeros extends UDF1[Array[Int], Int] {

  // Returns the length of the longest run of 0s that starts right after a 1
  // and ends right before a 2; returns -1 if no such run exists.
  override def call(given_array: Array[Int]): Int = {
    var maxCount = -1
    var runningCount = -1
    var insideLoop = false

    for (ele <- given_array) {
      if (ele == 1) {
        // a 1 opens a candidate run; reset the counter
        runningCount = 0
        insideLoop = true
      }
      if (ele == 0 && insideLoop) {
        runningCount += 1
      }
      if (ele == 2 && insideLoop) {
        // a 2 closes the run; keep it if it is the longest so far
        insideLoop = false
        maxCount = math.max(maxCount, runningCount)
      }
    }

    maxCount
  }
}

The following PySpark code uses the above UDF.

import pyspark.sql.functions as F
from pyspark import SQLContext
from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType

spark = SparkSession.builder \
    .appName("MyApp") \
    .config("spark.jars", "file:/path/to/pyspark-native-udfs/releases/pyspark-native-udfs-assembly-0.1.2.jar") \
    .getOrCreate()

sqlContext = SQLContext(spark.sparkContext)

data1 = [
    [1, [1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0]],
    [2, [0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0]],
    [3, [0, 0, 0, 2, 1, 0, 2, 1, 0]],
]

df1Columns = ["passengerID", "travelHist"]
df1 = sqlContext.createDataFrame(data=data1, schema=df1Columns)
# Python ints are inferred as bigint; cast so the column matches the UDF input
df1 = df1.withColumn("travelHist", F.col("travelHist").cast("array<int>"))

df1.show(n=100, truncate=False)
df1.printSchema()

# register the Scala UDF from the jar under a SQL-callable name
spark.udf.registerJavaFunction("count_zeros_udf", "com.help.udf.CountZeros", IntegerType())

df1.createOrReplaceTempView("given_table")

df1_array = sqlContext.sql("select *, count_zeros_udf(travelHist) as maxStreak from given_table")
print("Dataframe after applying SCALA NATIVE UDF")
df1_array.show(n=100, truncate=False)

Output:

+-----------+------------------------------------------------------+
|passengerID|travelHist                                            |
+-----------+------------------------------------------------------+
|1          |[1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0]   |
|2          |[0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0]|
|3          |[0, 0, 0, 2, 1, 0, 2, 1, 0]                           |
+-----------+------------------------------------------------------+

root
 |-- passengerID: long (nullable = true)
 |-- travelHist: array (nullable = true)
 |    |-- element: integer (containsNull = true)

Dataframe after applying SCALA NATIVE UDF
+-----------+------------------------------------------------------+---------+
|passengerID|travelHist                                            |maxStreak|
+-----------+------------------------------------------------------+---------+
|1          |[1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0]   |7        |
|2          |[0, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0, 0, 0, 2, 0, 0, 0, 0]|3        |
|3          |[0, 0, 0, 2, 1, 0, 2, 1, 0]                           |1        |
+-----------+------------------------------------------------------+---------+

2 Comments

Thank you! I am using Spark version 2.4, and it does not support regexp_extract_all(). The UDF mentioned above will work, but I am seeing if I can use findAllIn() instead. Does my syntax below look correct? val df1 = travel_hist.withColumn("matches", expr("array_join(travelHist, '')")); val pat = "1(0+)2".r; val df = travel_hist.select("matches").map { (line: Row) => (for { m <- pat.findAllIn(line("matches")).matchData; g <- m.subgroups } yield g).toList }
@Abishek: I would collect all the indices where 1 appears in one array, and similarly for 2. The gap between a 1 and the next 2 (minus one, and only when everything between them is 0) is a streak length, and the largest such gap is maxStreak. No need to do regex matching; see the sketch below.
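
A minimal, plain-Scala sketch of that index-based idea (illustrative names, no Spark; it could run inside a UDF just like the answer above): for each 2, look back to the nearest preceding 1 and count the gap only when everything between them is 0.

// Illustrative helper for the index-difference idea: for each index of a 2,
// find the closest preceding index of a 1 and measure the gap, counting it
// only when the span between them is all zeros.
def maxStreak(hist: Seq[Int]): Int = {
  val ones = hist.zipWithIndex.collect { case (1, i) => i }
  val twos = hist.zipWithIndex.collect { case (2, j) => j }
  val gaps = for {
    j <- twos
    i <- ones.filter(_ < j).lastOption   // nearest 1 before this 2
    if hist.slice(i + 1, j).forall(_ == 0)
  } yield j - i - 1                      // zeros strictly between the 1 and the 2
  if (gaps.isEmpty) -1 else gaps.max     // -1 when no valid 1..2 run exists
}

// maxStreak(Seq(1, 0, 0, 0, 0, 2, 1, 0, 0, 0, 0, 0, 0, 0, 2, 1, 0))  => 7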
