flatMapGroupsWithState is the unsung hero of stateful stream processing in Spark, allowing you to maintain and update arbitrary state across incoming data groups, not just simple aggregations.
Imagine you’re tracking user sessions, and for each user, you want to know their last 10 visited pages, the total time spent on the site, and whether they’ve viewed a specific "checkout" page. A standard groupByKey followed by mapGroups would give you all the events for a user, but you’d have to re-process the entire history for each new event to update your session state. flatMapGroupsWithState lets you keep that session state (the last 10 pages, total time, checkout flag) alive and update it incrementally with each new event for that user.
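The incremental idea can be sketched without Spark at all: the session state is just a value you fold each new event into, rather than something recomputed from the full history. (The case classes and field names below are invented for illustration.)

```scala
// Session state: last pages visited (capped at 10), total time, checkout flag.
case class PageEvent(page: String, durationMs: Long)
case class Session(lastPages: List[String], totalTimeMs: Long, sawCheckout: Boolean)

// One incremental update: O(1) per event, no reprocessing of history.
def updated(s: Session, e: PageEvent): Session =
  Session(
    lastPages   = (e.page :: s.lastPages).take(10),
    totalTimeMs = s.totalTimeMs + e.durationMs,
    sawCheckout = s.sawCheckout || e.page == "checkout"
  )

val empty = Session(Nil, 0L, false)
val session = Seq(
  PageEvent("home", 3000L),
  PageEvent("product", 7000L),
  PageEvent("checkout", 2000L)
).foldLeft(empty)(updated)
// session.totalTimeMs == 12000, session.sawCheckout == true
```

This is exactly the shape of computation flatMapGroupsWithState lets you run per key on a stream: Spark keeps the accumulated `Session` alive between micro-batches so you only ever fold in the new events.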
Here’s a simplified example. Say we have a stream of (userId, eventType, timestamp) records, where an event can be a page view or a purchase. We want to maintain, for each userId, a state that tracks the last event type seen and whether a purchase has occurred.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.streaming.{GroupState, GroupStateTimeout, OutputMode}

val spark = SparkSession.builder.appName("flatMapGroupsWithStateExample").getOrCreate()
import spark.implicits._

// Input event, state, and output types
case class Event(userId: Int, eventType: String, timestamp: Long)
case class SessionState(lastEventType: String, hasPurchased: Boolean)
case class SessionOutput(userId: Int, lastEventType: String, purchased: Boolean)

// Sample data: (userId, eventType, timestamp). In production this would come
// from a streaming source, e.g. spark.readStream.format("kafka")...
val events = Seq(
  Event(1, "page_view", 1678886400000L),
  Event(2, "page_view", 1678886410000L),
  Event(1, "page_view", 1678886420000L),
  Event(1, "purchase",  1678886430000L),
  Event(2, "page_view", 1678886440000L),
  Event(1, "page_view", 1678886450000L)
).toDS()

// The processing function: (key, new events for that key, persistent state) => output rows
def updateSessionState(
    userId: Int,
    events: Iterator[Event],
    state: GroupState[SessionState]): Iterator[SessionOutput] = {
  if (state.hasTimedOut) {
    // No new data arrived within the timeout: drop the state and emit nothing
    state.remove()
    Iterator.empty
  } else {
    // Get the current state, or initialize it on the first event for this user
    var current = state.getOption.getOrElse(SessionState("", false))
    // Process each new event for this user
    events.foreach { e =>
      e.eventType match {
        case "page_view" => current = current.copy(lastEventType = "page_view")
        case "purchase"  => current = current.copy(hasPurchased = true)
        case _           => // Ignore other event types
      }
    }
    // Update the state and arm a timeout for this key
    state.update(current)
    state.setTimeoutDuration("30 seconds") // Expire this key's state if no data arrives for 30s
    // Output a row representing the updated state for this user
    Iterator(SessionOutput(userId, current.lastEventType, current.hasPurchased))
  }
}

// Apply flatMapGroupsWithState on a keyed Dataset
val result = events
  .groupByKey(_.userId)
  .flatMapGroupsWithState(
    OutputMode.Update(),                     // Append is the only other supported mode here
    GroupStateTimeout.ProcessingTimeTimeout  // Enables state.setTimeoutDuration
  )(updateSessionState)

// You would typically write this to a sink, e.g., the console or Kafka
// result.writeStream.format("console").outputMode("update").start().awaitTermination()
The core of flatMapGroupsWithState is the user-defined function you pass to it (updateSessionState above). It receives the key of the group (e.g., userId), an Iterator of all the data for that key that arrived in the current micro-batch, and a GroupState object. The GroupState is your handle to the persistent state for that specific key: you can read the current state, update it with new values, remove it if necessary, or check whether it exists. (Since Spark 3.2 there is also an overload that accepts an initial state Dataset, if you need to seed state before the first event arrives.)
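To build intuition for those GroupState operations, here is a deliberately tiny in-memory stand-in for the interface — a toy model for illustration only, not Spark's actual implementation, which is backed by a fault-tolerant state store:

```scala
// Toy stand-in for Spark's GroupState, illustrating the exists/get/update/remove
// semantics your function interacts with. Spark's real GroupState additionally
// persists the value across micro-batches and supports timeouts.
class ToyGroupState[S](private var value: Option[S] = None) {
  def exists: Boolean = value.isDefined
  def get: S = value.getOrElse(throw new NoSuchElementException("State is not set"))
  def getOption: Option[S] = value
  def update(newState: S): Unit = { value = Some(newState) }
  def remove(): Unit = { value = None }
}

val state = new ToyGroupState[Int]()
assert(!state.exists)        // No state yet for this key
state.update(42)             // First update creates the state
assert(state.get == 42)
state.remove()               // Discard the key's state entirely
assert(state.getOption.isEmpty)
```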
The timeoutConf is critical, but note that it only selects the timeout mechanism (ProcessingTimeTimeout, EventTimeTimeout, or NoTimeout); the actual deadline is set per key inside your function, via state.setTimeoutDuration (processing time) or state.setTimeoutTimestamp (event time). When a key's timeout fires, Spark invokes your function once more with an empty iterator and state.hasTimedOut set to true, giving you the chance to emit a final result and call state.remove(). This is what prevents state from accumulating indefinitely for inactive keys. A withWatermark on the input is required when you use event-time timeouts (and matters for late data generally), since the watermark is what drives event-time progress and hence when those timeouts fire.
The outputMode declares the semantics of the rows your function emits. Only two modes are supported for flatMapGroupsWithState: Update (used above), which treats emitted rows as updates that may later be superseded for the same key, and Append, which promises that emitted rows are final and will never change. Complete mode is not supported for this operator.
The most surprising thing about flatMapGroupsWithState is that you don’t describe the state’s schema to Spark by hand. You can use arbitrary Scala case classes, maps, or even complex nested structures; Spark serializes and deserializes the state through its Encoder machinery (an implicit Encoder for the state type, derived automatically for case classes, with Encoders.kryo available as a fallback for types it cannot derive). This makes it remarkably flexible for stateful logic that goes beyond simple counts or sums.
The GroupState object itself is more than just a mutable container. It’s the interface Spark provides for managing state across micro-batches. When you call state.update(newState), Spark doesn’t just modify an in-memory object; it serializes newState into its state store, which lives in executor memory and is durably backed by the query’s checkpoint location (the directory you configure via the checkpointLocation option or spark.sql.streaming.checkpointLocation). When the next micro-batch arrives for the same userId, Spark retrieves the stored state, deserializes it, and hands it back to your function as the GroupState object. This process ensures that your state survives failures and restarts.
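Configuring that checkpoint location is a one-liner on the streaming query. A sketch, assuming `result` is the Dataset produced by flatMapGroupsWithState above; the path is a placeholder, and in production it should point at a reliable distributed filesystem such as HDFS or S3:

```scala
// The checkpoint directory holds source offsets and the serialized per-key
// state, so the query can recover exactly where it left off after a restart.
val query = result.writeStream
  .format("console")
  .outputMode("update")
  .option("checkpointLocation", "/tmp/checkpoints/session-tracker") // placeholder path
  .start()
```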
The primary lever you control is the logic inside your state-update function: it defines how incoming events modify the existing state and what output rows are emitted. Beyond that, you tune the timeout configuration and, when using event-time timeouts or handling late data, define a watermark on the input so that state memory stays bounded.
The next concept you’ll likely encounter is managing complex state evolution, such as transitioning from one state to another based on sequences of events, and how to handle state schema changes over time if your application evolves.
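As a taste of that, event-sequence transitions are often modeled as an explicit state machine. Here is a minimal pure-Scala sketch (the states and event names are invented for illustration) of the kind of transition function you would embed in the state-update callback:

```scala
// A minimal session state machine: Browsing -> CartActive -> CheckedOut.
// The transition function is pure, which makes it easy to unit-test
// independently of Spark before wiring it into flatMapGroupsWithState.
sealed trait Session
case object Browsing   extends Session
case object CartActive extends Session
case object CheckedOut extends Session

def transition(current: Session, eventType: String): Session =
  (current, eventType) match {
    case (Browsing,   "add_to_cart") => CartActive
    case (CartActive, "checkout")    => CheckedOut
    case (CheckedOut, _)             => CheckedOut // Terminal state
    case (s, _)                      => s          // Ignore irrelevant events
  }

// Fold a sequence of events through the machine
val finalState = Seq("page_view", "add_to_cart", "page_view", "checkout")
  .foldLeft(Browsing: Session)(transition)
// finalState == CheckedOut
```

Keeping the transition logic pure like this also eases the schema-evolution problem: the serialized state stays a small, stable case class while the behavior lives in ordinary, versionable code.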