Data Aggregation in Spark DataFrames
Loading the Shipments Dataset
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('agg').getOrCreate()
df = spark.read.csv('/path/to/shipments.csv', header=True, inferSchema=True)
Why we need this?
To start a Spark session, load the shipments dataset, and let inferSchema=True detect the column types automatically.
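If the column layout is known up front, the schema can also be declared explicitly instead of relying on inferSchema (a minimal sketch; the column names and types are assumed from the preview below):

from pyspark.sql.types import StructType, StructField, StringType, IntegerType

shipments_schema = StructType([
    StructField("ShipmentID", StringType(), True),
    StructField("Company", StringType(), True),
    StructField("Product", StringType(), True),
    StructField("Units", IntegerType(), True),
    StructField("Sales", IntegerType(), True),
])
# An explicit schema avoids a second pass over the file and guarantees the column types
df = spark.read.csv('/path/to/shipments.csv', header=True, schema=shipments_schema)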
Preview & Schema
df.show()
Result
+----------+--------+----------+-----+-----+
|ShipmentID|Company | Product |Units|Sales|
+----------+--------+----------+-----+-----+
| S001 | FedEx | Soap | 100 |1500 |
| S002 | FedEx | Shampoo | 200 |3000 |
| S003 |BlueDart| Bread | 150 |1800 |
| S004 | DHL |Toothpaste| 120 |2400 |
| S005 | DHL | Rice | 300 |6000 |
| S006 |BlueDart|Chocolate | 180 |3600 |
| S007 | FedEx | Juice | 130 |2600 |
| S008 | DHL | Cereal | 220 |4400 |
| S009 |BlueDart| Soda | 110 |2200 |
| S010 | FedEx |Facewash | 140 |2800 |
+----------+--------+----------+-----+-----+
df.printSchema()
Result
root
|-- ShipmentID: string (nullable = true)
|-- Company: string (nullable = true)
|-- Product: string (nullable = true)
|-- Units: integer (nullable = true)
|-- Sales: integer (nullable = true)
Aggregation Examples
1. Group by Company (no aggregation yet)
df.groupBy("Company")
Returns a GroupedData object; an aggregation such as .count(), .mean(), or .agg() must be applied to get a DataFrame back.
Why we need this?
To logically group data by company before applying aggregate functions.
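As a quick check, grouping alone does not trigger any computation; it only returns a GroupedData handle that an aggregation turns back into a DataFrame (a small sketch using the df loaded above):

grouped = df.groupBy("Company")
print(type(grouped))          # pyspark.sql.group.GroupedData
# Only the aggregation call produces a DataFrame that can be shown or collected
print(type(grouped.count()))  # pyspark.sql.dataframe.DataFrame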
2. Average of all numeric columns per Company
df.groupBy("Company").mean().show()
Why we need this?
To calculate per-company average of numeric columns (Units, Sales).
Result
+--------+-----------+-----------+
|Company |avg(Units) |avg(Sales) |
+--------+-----------+-----------+
|BlueDart| 146.6667 | 2533.3333|
| FedEx| 142.5 | 2475.0 |
| DHL| 213.3333 | 4266.6667|
+--------+-----------+-----------+
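To average only one column instead of every numeric column, the column name can be passed to mean() directly (a sketch; avg() on GroupedData behaves the same way):

# Per-company average of Sales only, leaving Units out of the result
df.groupBy("Company").mean("Sales").show()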
3. Count of records per Company
df.groupBy("Company").count().show()
Why we need this?
To check how many shipments each company handled.
Result
+--------+-----+
|Company |count|
+--------+-----+
|BlueDart| 3 |
| FedEx| 4 |
| DHL| 3 |
+--------+-----+
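The counts can also be sorted so the company with the most shipments appears first (a sketch building on the same groupBy):

# 'count' is the column name produced by .count()
df.groupBy("Company").count().orderBy("count", ascending=False).show()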
4. Max value per Company
df.groupBy("Company").max().show()
Why we need this?
To find the highest values for Units and Sales per company.
Result
+--------+-----------+-----------+
|Company |max(Units) |max(Sales) |
+--------+-----------+-----------+
|BlueDart| 180 | 3600 |
| FedEx| 200 | 3000 |
| DHL| 300 | 6000 |
+--------+-----------+-----------+
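Several aggregates can be computed in a single pass by listing them in agg(); the sketch below pairs the maximum with the minimum (min/max are imported under aliases to avoid shadowing Python's built-ins):

from pyspark.sql.functions import max as max_, min as min_

# One pass over the data, two aggregates per company
df.groupBy("Company").agg(
    max_("Sales").alias("MaxSales"),
    min_("Sales").alias("MinSales"),
).show()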
5. Total sales (across all rows)
df.agg({'Sales': 'sum'}).show()
Why we need this?
To compute the total sales value across the entire dataset.
Result
+----------+
|sum(Sales)|
+----------+
| 30300 |
+----------+
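The dictionary form of agg() also accepts several columns at once, one aggregate per column (a sketch; the output columns are named sum(Units) and sum(Sales)):

# Total Units and total Sales in a single call
df.agg({'Sales': 'sum', 'Units': 'sum'}).show()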
6. Maximum sale value
df.agg({'Sales': 'max'}).show()
Why we need this?
To identify the single largest sales value among all shipments.
Result
+----------+
|max(Sales)|
+----------+
| 6000 |
+----------+
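If the goal is to see which shipment produced that maximum rather than just the value, sorting descending and taking one row is a simple alternative (a sketch using the same df):

# The full row behind the largest Sales value
df.orderBy(df['Sales'].desc()).limit(1).show()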
7. Group first, then aggregate (max sales per company)
group_data = df.groupBy("Company")
group_data.agg({'Sales': 'max'}).show()
Why we need this?
To find max sales for each company after grouping.
Result
+--------+----------+
|Company |max(Sales)|
+--------+----------+
|BlueDart| 3600 |
| FedEx| 3000 |
| DHL| 6000 |
+--------+----------+
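The dictionary form keeps the default column name max(Sales); for a friendlier header, the function form with an alias can be used instead (a sketch reusing group_data from above):

from pyspark.sql.functions import max as max_

# Same aggregation, but the result column is renamed to MaxSales
group_data.agg(max_('Sales').alias('MaxSales')).show()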
Built-in Functions
8. Average Sales using avg() function
from pyspark.sql.functions import avg
df.select(avg('Sales').alias('Average Sales')).show()
Why we need this?
To compute average sales more explicitly using built-in avg().
Result
+-------------+
|Average Sales|
+-------------+
| 3030.0 |
+-------------+
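For display purposes the average can be rounded; format_number() keeps two decimal places (a sketch; note that it returns the value as a formatted string):

from pyspark.sql.functions import avg, format_number

# 'Average Sales' shown with two decimals, e.g. 3,030.00
df.select(format_number(avg('Sales'), 2).alias('Average Sales')).show()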
9. Count distinct sales values
from pyspark.sql.functions import countDistinct
df.select(countDistinct('Sales')).show()
Why we need this?
To find how many unique sales values exist in the dataset.
Result
+---------------------+
|count(DISTINCT Sales)|
+---------------------+
| 10 |
+---------------------+
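countDistinct() can be combined with other aggregates in the same select; the sketch below adds stddev (imported alongside avg and countDistinct in the summary at the end) and aliases both columns for shorter headers:

from pyspark.sql.functions import countDistinct, stddev

# Distinct sales values and their standard deviation in one pass
df.select(
    countDistinct('Sales').alias('Distinct Sales'),
    stddev('Sales').alias('Sales Std Dev'),
).show()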
Sorting / Ordering
10. Sort by Sales (ascending)
df.orderBy('Sales').show()
Why we need this?
To view shipments sorted from lowest to highest sales.
Result
+----------+--------+----------+-----+-----+
|ShipmentID|Company | Product |Units|Sales|
+----------+--------+----------+-----+-----+
| S001 | FedEx | Soap | 100 |1500 |
| S003 |BlueDart| Bread | 150 |1800 |
| S009 |BlueDart| Soda | 110 |2200 |
... (and so on)
11. Sort by Sales (descending)
df.orderBy(df['Sales'].desc()).show()
Why we need this?
To view shipments sorted from highest to lowest sales.
Result
+----------+--------+----------+-----+-----+
|ShipmentID|Company | Product |Units|Sales|
+----------+--------+----------+-----+-----+
| S005 | DHL | Rice | 300 |6000 |
| S008 | DHL | Cereal | 220 |4400 |
| S006 |BlueDart|Chocolate | 180 |3600 |
... (and so on)
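Sorting can also use several keys; the sketch below orders by Company alphabetically and then by Sales from highest to lowest within each company:

# Multi-column ordering: Company ascending, Sales descending within it
df.orderBy(df['Company'].asc(), df['Sales'].desc()).show()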
🔑 1-Minute Summary — Data Aggregation in PySpark (Shipments Dataset)
| Code | What it Does |
|---|---|
| df.groupBy("Company") | Groups rows by Company column (no aggregation yet) |
| df.groupBy("Company").mean().show() | Computes average of all numeric columns for each company |
| df.groupBy("Company").count().show() | Counts number of rows (shipments) per company |
| df.groupBy("Company").max().show() | Returns the max value per numeric column per company |
| df.agg({'Sales': 'sum'}).show() | Calculates total sales across all rows |
| df.agg({'Sales': 'max'}).show() | Finds the maximum single sales value |
| df.groupBy("Company").agg({'Sales': 'max'}).show() | Max sales per company using grouped aggregation |
| from pyspark.sql.functions import avg, countDistinct, stddev | Imports useful built-in aggregation functions |
| df.select(avg('Sales')).show() | Computes average sales across all shipments |
| df.select(countDistinct('Sales')).show() | Counts number of unique sales values |
| df.orderBy('Sales').show() | Sorts all rows by Sales in ascending order |
| df.orderBy(df['Sales'].desc()).show() | Sorts all rows by Sales in descending order |