
pivot()
The pivot() function is used to transform and reshape data in a DataFrame. It turns the unique values of one column into multiple columns in the output DataFrame, which makes the data easier to summarize and analyze.
Usage
- pivot() is commonly used after a groupBy() operation.
- It takes a column name as an argument and pivots that column so that its unique values become column headers in the resulting DataFrame.
Create Spark Session and sample DataFrame
from pyspark.sql import SparkSession

# Initialize Spark Session
spark = SparkSession.builder.appName("pivotExample").getOrCreate()

# Sample DataFrame
data = [("James", "Sales", 3000), ("Ana", "Sales", 4100),
        ("Robert", "IT", 5000), ("Maria", "IT", 3900)]
columns = ["Employee Name", "Department", "Salary"]
df = spark.createDataFrame(data, columns)
df.show()
Output:
+-------------+----------+------+
|Employee Name|Department|Salary|
+-------------+----------+------+
| James| Sales| 3000|
| Ana| Sales| 4100|
| Robert| IT| 5000|
| Maria| IT| 3900|
+-------------+----------+------+
Example: Use pivot() to expand a column
- groupBy("Employee Name"): groups the data by the Employee Name column.
- pivot("Department"): pivots the Department column, turning its unique values into separate columns.
- sum("Salary"): calculates the total salary of each employee in each pivoted department column.
- The resulting DataFrame shows each employee's salary under their department's column. The salary is NULL where the employee does not belong to that department.
pivot_df = df.groupBy("Employee Name").pivot("Department").sum("Salary")
pivot_df.show()
Output:
+-------------+----+-----+
|Employee Name| IT|Sales|
+-------------+----+-----+
| James|NULL| 3000|
| Ana|NULL| 4100|
| Maria|3900| NULL|
| Robert|5000| NULL|
+-------------+----+-----+
# Stop the Spark Session
spark.stop()