
collect_set()
collect_set() is an aggregate function in Apache Spark that gathers the values of a column within each group into a single result. Unlike collect_list(), which collects values into a list and keeps duplicates, collect_set() collects only the unique values into a set (returned as an array column). Each distinct value within a group therefore appears only once in the result.
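To make the difference concrete before involving Spark, here is a minimal plain-Python sketch of what the two functions do per group. It only mimics the semantics and is not how Spark implements them; the names as_list and as_set are made up for this illustration, and the rows match the sample data used in the rest of this article.

```python
from collections import defaultdict

# (Name, Fruit) rows, same values as the sample DataFrame in this article
rows = [("James", "Apple"), ("Michael", "Banana"), ("James", "Apple"), ("Robert", "Cherry")]

as_list = defaultdict(list)  # mimics collect_list(): duplicates are kept
as_set = defaultdict(set)    # mimics collect_set(): only distinct values are kept

for name, fruit in rows:
    as_list[name].append(fruit)
    as_set[name].add(fruit)

print(dict(as_list))  # James has Apple twice
print(dict(as_set))   # James has Apple once
```

James appears twice with the same fruit, so the list-style result holds two entries for him while the set-style result holds one.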
Create Spark Session and sample DataFrame
from pyspark.sql import SparkSession
from pyspark.sql.functions import collect_list, collect_set

# Initialize Spark Session
spark = SparkSession.builder.appName("collectListSetExample").getOrCreate()

# Sample DataFrame
data = [("James", "Apple"), ("Michael", "Banana"), ("James", "Apple"), ("Robert", "Cherry")]
columns = ["Name", "Fruit"]
df = spark.createDataFrame(data, columns)
df.show()
Output:
+-------+------+
|   Name| Fruit|
+-------+------+
|  James| Apple|
|Michael|Banana|
|  James| Apple|
| Robert|Cherry|
+-------+------+
Example: Use collect_set()
set_df = df.groupBy("Name").agg(collect_set("Fruit").alias("Unique Fruits"))
set_df.show(truncate=False)
Output:
+-------+-------------+
|Name   |Unique Fruits|
+-------+-------------+
|James  |[Apple]      |
|Michael|[Banana]     |
|Robert |[Cherry]     |
+-------+-------------+
df.groupBy("Name"): groups the rows of the DataFrame by the Name column.
agg(): applies the aggregate expression, here collect_set("Fruit") with its alias(), to each group.
collect_set("Fruit"): takes the Fruit column and merges only the distinct fruit values into one row per value of the grouping column Name.
alias("Unique Fruits"): renames the returned column to Unique Fruits.
truncate=False: passed to the show() method, it ensures that the entire content of each collected array is displayed without truncation.
# Stop the Spark Session
spark.stop()