
collect_list()
In Apache Spark, the collect_list() function aggregates values from multiple rows into a list within a single row. It is particularly useful for building arrays of values grouped by a key.
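To make the semantics concrete, here is a plain-Python model of what a grouped collect_list() computes. It is an illustration only, not Spark's implementation, and the helper collect_list_by_key and its parameters are hypothetical. It mirrors two behaviors of the Spark function: duplicate values are kept, and null values (None here) are skipped, so a group whose values are all null ends up with an empty list.

```python
from collections import defaultdict


def collect_list_by_key(rows, key_index, value_index):
    """Hypothetical model of GROUP BY key + collect_list(value) over plain tuples.

    Duplicates are kept and None values are skipped, as Spark skips nulls.
    Python preserves insertion order here; Spark makes no such ordering guarantee.
    """
    groups = defaultdict(list)
    for row in rows:
        bucket = groups[row[key_index]]  # creates an empty list for a new key
        value = row[value_index]
        if value is not None:
            bucket.append(value)
    return dict(groups)


rows = [("James", "Apple"), ("Michael", "Banana"), ("James", "Apple"), ("Robert", None)]
print(collect_list_by_key(rows, key_index=0, value_index=1))
# {'James': ['Apple', 'Apple'], 'Michael': ['Banana'], 'Robert': []}
```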
Usage
- collect_list() gathers values from a column across multiple rows into a single list.
- It is often used with groupBy() to aggregate values per group.
Create Spark Session and sample DataFrame
from pyspark.sql import SparkSession
from pyspark.sql.functions import collect_list

# Initialize Spark Session
spark = SparkSession.builder.appName("collectListExample").getOrCreate()

# Sample DataFrame
data = [("James", "Apple"), ("Michael", "Banana"), ("Robert", "Cherry"), ("James", "Dragonfruit"), ("Michael", "Elderberry")]
columns = ["Name", "Fruit"]
df = spark.createDataFrame(data, columns)
df.show()
Output:
+-------+-----------+
|   Name|      Fruit|
+-------+-----------+
|  James|      Apple|
|Michael|     Banana|
| Robert|     Cherry|
|  James|Dragonfruit|
|Michael| Elderberry|
+-------+-----------+
Example: Use collect_list to merge multiple rows into a list within a single row
- df.groupBy("Name"): groups the dataset by the Name column.
- agg(): applies aggregate expressions to each group; here the expression is collect_list("Fruit"), renamed with alias().
- collect_list("Fruit"): gathers the Fruit values of each group into a single list, producing one row per distinct Name.
- alias("Fruits"): renames the resulting column to Fruits.
- truncate=False in the show() method ensures that the entire content of each list is displayed without truncation.
grouped_fruits = df.groupBy("Name").agg(collect_list("Fruit").alias("Fruits"))
grouped_fruits.show(truncate=False)
Output:
+-------+--------------------+
|Name   |Fruits              |
+-------+--------------------+
|James  |[Apple, Dragonfruit]|
|Michael|[Banana, Elderberry]|
|Robert |[Cherry]            |
+-------+--------------------+
# Stop the Spark Session
spark.stop()