
sum()
The sum() function is an aggregation function that computes the total of a numeric column in a DataFrame.
Usage
- sum() can be applied to a DataFrame to calculate the total of a specific column.
- When combined with groupBy(), it can be used to compute the sum for each group in a dataset.
Create Spark Session and sample DataFrame
from pyspark.sql import SparkSession
from pyspark.sql.functions import sum

# Initialize Spark Session
spark = SparkSession.builder.appName("sumExample").getOrCreate()

# Sample DataFrame
data = [
    ("group A", 45),
    ("group A", 30),
    ("group A", 55),
    ("group B", 10),
    ("group B", 20),
    ("group B", 60),
]
columns = ["Group", "Variable"]
df = spark.createDataFrame(data, columns)
df.show()
Output:
+-------+--------+
| Group|Variable|
+-------+--------+
|group A| 45|
|group A| 30|
|group A| 55|
|group B| 10|
|group B| 20|
|group B| 60|
+-------+--------+
Example: Use sum() to get the total sum of a numeric column
sum("Variable"): calculates the sum of the values in the Variable column of the DataFrame df.
.alias("Sum of Variable"): renames the resulting column to Sum of Variable.
total_df = df.select(sum("Variable").alias("Sum of Variable"))
total_df.show()
Output:
+---------------+
|Sum of Variable|
+---------------+
| 220|
+---------------+
Example: Use sum() and groupBy() to get the sum of different groups
groupBy("Group"): groups the df DataFrame by the Group column.
agg(): the agg() function applies aggregation functions to the grouped data.
sum("Variable"): sums up the Variable values of each unique group.
grouped_data = df.groupBy("Group").agg(sum("Variable").alias("Sum of Variable"))
grouped_data.show()
Output:
+-------+---------------+
| Group|Sum of Variable|
+-------+---------------+
|group A| 130|
|group B| 90|
+-------+---------------+
# Stop the Spark Session
spark.stop()