Data Aggregation in Spark DataFrames

Loading the Shipments Dataset

from pyspark.sql import SparkSession

spark = SparkSession.builder.appName('agg').getOrCreate()

df = spark.read.csv('/path/to/shipments.csv', header=True, inferSchema=True)

Why do we need this?
To start a Spark session, load the shipments dataset, and let Spark automatically detect the column types.
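
If the column types are already known, the schema can also be declared up front instead of using inferSchema, which saves Spark an extra pass over the file. A minimal sketch, assuming the same placeholder path and the column names shown below:

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

schema = StructType([
    StructField('ShipmentID', StringType(), True),
    StructField('Company', StringType(), True),
    StructField('Product', StringType(), True),
    StructField('Units', IntegerType(), True),
    StructField('Sales', IntegerType(), True),
])

# Read with an explicit schema instead of inferSchema
df = spark.read.csv('/path/to/shipments.csv', header=True, schema=schema)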

Preview & Schema

df.show()

Result

+----------+--------+----------+-----+-----+
|ShipmentID| Company|   Product|Units|Sales|
+----------+--------+----------+-----+-----+
|      S001|   FedEx|      Soap|  100| 1500|
|      S002|   FedEx|   Shampoo|  200| 3000|
|      S003|BlueDart|     Bread|  150| 1800|
|      S004|     DHL|Toothpaste|  120| 2400|
|      S005|     DHL|      Rice|  300| 6000|
|      S006|BlueDart| Chocolate|  180| 3600|
|      S007|   FedEx|     Juice|  130| 2600|
|      S008|     DHL|    Cereal|  220| 4400|
|      S009|BlueDart|      Soda|  110| 2200|
|      S010|   FedEx|  Facewash|  140| 2800|
+----------+--------+----------+-----+-----+

df.printSchema()

Result

root
 |-- ShipmentID: string (nullable = true)
 |-- Company: string (nullable = true)
 |-- Product: string (nullable = true)
 |-- Units: integer (nullable = true)
 |-- Sales: integer (nullable = true)

Aggregation Examples

1. Group by Company (no aggregation yet)

df.groupBy("Company")

This returns a GroupedData object; an aggregation such as .count(), .mean(), or .max() must be applied before results can be shown.

Why do we need this?
To logically group data by company before applying aggregate functions.
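
A quick way to see that groupBy() alone does not return a DataFrame is to inspect the object it produces; a small sketch:

grouped = df.groupBy("Company")
print(type(grouped))    # <class 'pyspark.sql.group.GroupedData'>

# Only after an aggregation do we get a DataFrame back
grouped.count().show()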

2. Average of all numeric columns per Company

df.groupBy("Company").mean().show()

Why do we need this?
To calculate the per-company average of the numeric columns (Units, Sales).

Result

+--------+----------+----------+
| Company|avg(Units)|avg(Sales)|
+--------+----------+----------+
|BlueDart|  146.6667| 2533.3333|
|   FedEx|     142.5|    2475.0|
|     DHL|  213.3333| 4266.6667|
+--------+----------+----------+
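
If only one column's average is needed, mean() (or its alias avg()) also accepts column names, so the extra avg(Units) column can be left out; a small sketch:

# Average of Sales only, per company
df.groupBy("Company").mean("Sales").show()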

3. Count of records per Company

df.groupBy("Company").count().show()

Why do we need this?
To check how many shipments each company handled.

Result

+--------+-----+
| Company|count|
+--------+-----+
|BlueDart|    3|
|   FedEx|    4|
|     DHL|    3|
+--------+-----+
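
Counts are often easier to read when sorted; combining count() with orderBy (covered later in this section) puts the busiest company first. A small sketch:

# Shipment counts per company, largest first
df.groupBy("Company").count().orderBy("count", ascending=False).show()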

4. Max value per Company

df.groupBy("Company").max().show()

Why do we need this?
To find the highest values of Units and Sales per company.

Result

+--------+----------+----------+
| Company|max(Units)|max(Sales)|
+--------+----------+----------+
|BlueDart|       180|      3600|
|   FedEx|       200|      3000|
|     DHL|       300|      6000|
+--------+----------+----------+
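
The same grouped pattern works for the other GroupedData helpers such as min() and sum(); for example, total Units and Sales per company (a sketch):

# Totals per company; result columns come back as sum(Units) and sum(Sales)
df.groupBy("Company").sum("Units", "Sales").show()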

5. Total sales (across all rows)

df.agg({'Sales': 'sum'}).show()

Why do we need this?
To compute the total sales value across the entire dataset.

Result

+----------+
|sum(Sales)|
+----------+
|     30300|
+----------+
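
The dictionary form of agg() takes one entry per column, so totals for several columns can be computed in a single call; a small sketch:

# Total Units and total Sales across all rows
df.agg({'Units': 'sum', 'Sales': 'sum'}).show()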

6. Maximum sale value

df.agg({'Sales': 'max'}).show()

Why do we need this?
To identify the single largest sales value among all shipments.

Result

+----------+
|max(Sales)|
+----------+
|      6000|
+----------+
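
The dictionary form allows only one function per column, so to get both the smallest and largest sale in one pass, the function-style API can be used instead. A sketch (max and min are imported under aliases to avoid shadowing the Python built-ins):

from pyspark.sql.functions import max as max_, min as min_

df.agg(max_('Sales').alias('max_sales'), min_('Sales').alias('min_sales')).show()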

7. Group first, then aggregate (max sales per company)

group_data = df.groupBy("Company")
group_data.agg({'Sales': 'max'}).show()

Why do we need this?
To find max sales for each company after grouping.

Result

+--------+----------+
| Company|max(Sales)|
+--------+----------+
|BlueDart|      3600|
|   FedEx|      3000|
|     DHL|      6000|
+--------+----------+
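
The same grouped aggregation can be written with the function-style API, which also makes it easy to rename the max(Sales) output column; a minimal sketch:

from pyspark.sql.functions import max as max_

df.groupBy("Company").agg(max_('Sales').alias('Max Sales')).show()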

Built-in Functions

8. Average Sales using avg() function

from pyspark.sql.functions import avg

df.select(avg('Sales').alias('Average Sales')).show()

Why do we need this?
To compute the average sales explicitly using the built-in avg() function.

Result

+-------------+
|Average Sales|
+-------------+
|       3030.0|
+-------------+
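
To control how many decimal places appear in the output, the aggregated column can be passed through format_number(); a small sketch (the alias names are arbitrary):

from pyspark.sql.functions import avg, format_number

avg_sales = df.select(avg('Sales').alias('avg_sales'))
# Format the average to 2 decimal places for display
avg_sales.select(format_number('avg_sales', 2).alias('Average Sales')).show()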

9. Count distinct sales values

from pyspark.sql.functions import countDistinct

df.select(countDistinct('Sales')).show()

Why do we need this?
To find how many unique sales values exist in the dataset.

Result

+---------------------+
|count(DISTINCT Sales)|
+---------------------+
|                   10|
+---------------------+
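
The stddev() function imported in the summary below follows the same pattern; for example, the standard deviation of Sales across all shipments (a sketch):

from pyspark.sql.functions import stddev

df.select(stddev('Sales').alias('Sales StdDev')).show()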

Sorting / Ordering

10. Sort by Sales (ascending)

df.orderBy('Sales').show()

Why do we need this?
To view shipments sorted from lowest to highest sales.

Result

+----------+--------+----------+-----+-----+
|ShipmentID| Company|   Product|Units|Sales|
+----------+--------+----------+-----+-----+
|      S001|   FedEx|      Soap|  100| 1500|
|      S003|BlueDart|     Bread|  150| 1800|
|      S004|     DHL|Toothpaste|  120| 2400|
... (and so on)

11. Sort by Sales (descending)

df.orderBy(df['Sales'].desc()).show()

Why do we need this?
To view shipments sorted from highest to lowest sales.

Result

+----------+--------+----------+-----+-----+
|ShipmentID| Company|   Product|Units|Sales|
+----------+--------+----------+-----+-----+
|      S005|     DHL|      Rice|  300| 6000|
|      S008|     DHL|    Cereal|  220| 4400|
|      S006|BlueDart| Chocolate|  180| 3600|
... (and so on)
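
An equivalent way to sort in descending order is the desc() helper from pyspark.sql.functions, instead of the column's .desc() method; a small sketch:

from pyspark.sql.functions import desc

df.orderBy(desc('Sales')).show()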

🔑 1-Minute Summary — Data Aggregation in PySpark (Shipments Dataset)

Code | What it Does
df.groupBy("Company") | Groups rows by the Company column (no aggregation yet)
df.groupBy("Company").mean().show() | Computes the average of all numeric columns for each company
df.groupBy("Company").count().show() | Counts the number of rows (shipments) per company
df.groupBy("Company").max().show() | Returns the max value of each numeric column per company
df.agg({'Sales': 'sum'}).show() | Calculates total sales across all rows
df.agg({'Sales': 'max'}).show() | Finds the single largest sales value
df.groupBy("Company").agg({'Sales': 'max'}).show() | Max sales per company using grouped aggregation
from pyspark.sql.functions import avg, countDistinct, stddev | Imports useful built-in aggregation functions
df.select(avg('Sales')).show() | Computes the average sales across all shipments
df.select(countDistinct('Sales')).show() | Counts the number of unique sales values
df.orderBy('Sales').show() | Sorts all rows by Sales in ascending order
df.orderBy(df['Sales'].desc()).show() | Sorts all rows by Sales in descending order