
mean()

The mean() function in Apache Spark is an aggregation function that computes the average value of a numeric column in a DataFrame.
Usage

- mean() can be used on its own to compute the average of a column.
- It's often used with groupBy() to calculate the average for each group in a dataset.
Create Spark Session and sample DataFrame
from pyspark.sql import SparkSession
from pyspark.sql.functions import mean

# Initialize Spark Session
spark = SparkSession.builder.appName("meanExample").getOrCreate()

# Sample DataFrame
data = [("group A", 45), ("group A", 30), ("group A", 55),
        ("group B", 10), ("group B", 20), ("group B", 60)]
columns = ["Group", "Variable"]
df = spark.createDataFrame(data, columns)
df.show()
Output:
+-------+--------+
| Group|Variable|
+-------+--------+
|group A| 45|
|group A| 30|
|group A| 55|
|group B| 10|
|group B| 20|
|group B| 60|
+-------+--------+
Example: Use mean() to compute the mean value of a column

- mean("Variable"): computes the mean of the Variable column.
- .alias("Mean Value"): renames the resulting column to Mean Value.
df.select(mean("Variable").alias("Mean Value")).show()
Output:
+------------------+
| Mean Value|
+------------------+
|36.666666666666664|
+------------------+
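As a quick sanity check (plain Python, independent of Spark), the reported value matches the arithmetic mean of the six sample values:

```python
# Average the six sample values directly to confirm the Spark output above
values = [45, 30, 55, 10, 20, 60]
mean_value = sum(values) / len(values)
print(mean_value)  # 36.666666666666664
```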
Example: Use mean() with groupBy() to compute the mean value of each group

- groupBy("Group"): groups the data by the Group column.
- .agg(mean("Variable").alias("Mean Value")): computes the mean value of each group and renames the resulting column to Mean Value.
grouped_data = df.groupBy("Group").agg(mean("Variable").alias("Mean Value"))
grouped_data.show()
Output:
+-------+------------------+
| Group| Mean Value|
+-------+------------------+
|group A|43.333333333333336|
|group B| 30.0|
+-------+------------------+
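The per-group figures can likewise be reproduced in plain Python (a sketch for checking the result, not part of the Spark API):

```python
# Group the sample rows by key and average each group, mirroring the
# groupBy()/agg() result above without a Spark session
data = [("group A", 45), ("group A", 30), ("group A", 55),
        ("group B", 10), ("group B", 20), ("group B", 60)]

groups = {}
for group, value in data:
    groups.setdefault(group, []).append(value)

for group, values in groups.items():
    print(group, sum(values) / len(values))
# group A 43.333333333333336
# group B 30.0
```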
# Stop the Spark Session
spark.stop()