An Aggregator is a class for creating your own custom, type-safe aggregate functions that operate on Spark Datasets.
To answer your main question up front: no, you cannot pass an Aggregator directly to a RelationalGroupedDataset (the result of a DataFrame groupBy). It is designed for the type-safe Dataset API, and its primary use is with a KeyValueGroupedDataset.
A Spark custom aggregator lets you define and apply your own aggregation logic in Spark SQL, for example inside a groupBy or a window expression. You do this by extending Spark's Aggregator class (or, in older Spark versions, the now-deprecated org.apache.spark.sql.expressions.UserDefinedAggregateFunction, i.e. UDAF).
Think of an Aggregator as a blueprint for building your own aggregate function like SUM() or COUNT(), but one that works directly with your specific Java objects. It gives you complete control over the aggregation logic while ensuring type safety at compile time.
You’d use it when a standard built-in function isn’t sufficient. For example, you might want to calculate a geometrically weighted average, find the second-most-frequent item in a group, or combine complex objects into a single summary object.
An Aggregator is an abstract class you extend. Its power comes from the four key methods you must implement, which together describe the logic of a parallel aggregation (a conceptual sketch follows the list):
zero(): This method defines the “starting” or “empty” value for your aggregation before any elements have been processed. For a sum, this would be 0. For a product, 1. For a list, an empty list.
reduce(buffer, input): This defines how to incorporate one new input element into your current aggregated buffer. It’s the core logic that runs on each partition.
merge(b1, b2): This defines how to combine two intermediate aggregated buffers from different partitions into one. This is what makes your aggregation work in parallel.
finish(buffer): This method takes the final, merged buffer after all data has been processed and transforms it into the final output result.
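Conceptually, Spark composes these four calls as in the sketch below when it aggregates a single group that is spread across several partitions. This is illustration only: the helper class and its run method are made up here, and Spark's real execution is not literally this loop, but the call pattern (zero per partition, reduce per element, merge across partitions, finish once) is the contract your implementation must satisfy.

```java
import java.util.ArrayList;
import java.util.List;
import org.apache.spark.sql.expressions.Aggregator;

// Illustrative stand-in for how Spark exercises an Aggregator on one group.
// "partitions" is just an in-memory substitute for the distributed data.
final class AggregationWalkthrough {
    static <IN, BUF, OUT> OUT run(Aggregator<IN, BUF, OUT> agg, List<List<IN>> partitions) {
        List<BUF> partials = new ArrayList<>();
        for (List<IN> partition : partitions) {
            BUF buffer = agg.zero();                  // fresh "empty" buffer per partition
            for (IN element : partition) {
                buffer = agg.reduce(buffer, element); // fold each element into the buffer
            }
            partials.add(buffer);
        }
        BUF merged = agg.zero();
        for (BUF partial : partials) {
            merged = agg.merge(merged, partial);      // combine the per-partition buffers
        }
        return agg.finish(merged);                    // turn the final buffer into the result
    }
}
```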
Using an Aggregator

You use an Aggregator by creating an instance of your custom class and passing it to the .agg() method of a KeyValueGroupedDataset.
Here's a simple Java example of creating an Aggregator to calculate a custom average. First, the input object and the intermediate buffer:
```java
import java.io.Serializable;

// Your input data object
public class Sale implements Serializable {
    private String category;
    private double price;
    // ... constructor(s), getters, and setters (Encoders.bean also needs a public no-arg constructor)
}

// The intermediate buffer to hold the sum and count
public class AverageBuffer implements Serializable {
    private double sum;
    private long count;
    // ... a (sum, count) constructor, plus the public no-arg constructor, getters,
    // and setters that Encoders.bean requires
}
```

Next, the Aggregator itself:

```java
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
public class MyAverageAggregator extends Aggregator<Sale, AverageBuffer, Double> {

    // 1. What is the initial value?
    @Override
    public AverageBuffer zero() {
        return new AverageBuffer(0.0, 0L);
    }

    // 2. How to add a new Sale to the buffer?
    @Override
    public AverageBuffer reduce(AverageBuffer buffer, Sale sale) {
        buffer.setSum(buffer.getSum() + sale.getPrice());
        buffer.setCount(buffer.getCount() + 1);
        return buffer;
    }

    // 3. How to combine two intermediate buffers?
    @Override
    public AverageBuffer merge(AverageBuffer b1, AverageBuffer b2) {
        b1.setSum(b1.getSum() + b2.getSum());
        b1.setCount(b1.getCount() + b2.getCount());
        return b1;
    }

    // 4. How to produce the final result from the final buffer?
    @Override
    public Double finish(AverageBuffer reduction) {
        return reduction.getCount() == 0 ? 0.0 : reduction.getSum() / reduction.getCount();
    }

    // Define the Encoder Spark uses to serialize the intermediate buffer
    @Override
    public Encoder<AverageBuffer> bufferEncoder() {
        return Encoders.bean(AverageBuffer.class);
    }

    // Define the Encoder for the final output
    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}
```

Finally, group the Dataset with groupByKey and apply the aggregator to the resulting KeyValueGroupedDataset:

```java
import org.apache.spark.api.java.function.MapFunction;
import scala.Tuple2;

// Assume 'salesDS' is a Dataset<Sale>
MyAverageAggregator myAvg = new MyAverageAggregator();
// Group by category and then apply the aggregator
Dataset<Tuple2<String, Double>> result = salesDS
        // the MapFunction cast resolves Java's ambiguity between the
        // MapFunction and Scala Function1 overloads of groupByKey
        .groupByKey((MapFunction<Sale, String>) Sale::getCategory, Encoders.STRING())
        .agg(myAvg.toColumn().name("average_price"));
result.show();
// +-----------+-----------------+
// | _1| average_price|
// +-----------+-----------------+
// |Electronics| 355.95|
// | Clothing| 45.55|
// +-----------+-----------------+
```
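The same aggregator instance isn't limited to grouped data. Because toColumn() produces a TypedColumn, you can also run it over the whole Dataset with select() to get a single overall value (the column name below is just illustrative):

```java
// Aggregate the entire Dataset (no grouping) into a single-row Dataset<Double>
Dataset<Double> overallAvg = salesDS.select(myAvg.toColumn().name("overall_average_price"));
overallAvg.show();
```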
Aggregator vs. RelationalGroupedDataset

You cannot use an Aggregator directly with a RelationalGroupedDataset because the two belong to different Spark APIs with different design philosophies.
RelationalGroupedDataset is part of the untyped DataFrame API. It operates on generic Row objects and column names. To build a custom aggregation for it you traditionally used a UserDefinedAggregateFunction (UDAF), the older, untyped equivalent; since Spark 3.0 that class is deprecated, and the recommended route is to register an Aggregator as an untyped function via functions.udaf (see the sketch after the table).
KeyValueGroupedDataset is part of the type-safe Dataset API. It operates on specific Java/Scala objects. The Aggregator was designed specifically for this API to provide compile-time type safety.
| Feature | Aggregator (For Datasets) | UserDefinedAggregateFunction (UDAF) (For DataFrames) |
|---|---|---|
| API | Dataset API (groupByKey) | DataFrame API (groupBy) |
| Type Safety | Strongly typed. Works with your Java objects. | Untyped. Works with generic Row objects and schemas. |
| Error Checking | Compile-time | Runtime |
| Ease of Use | Generally easier to write and read complex logic. | Can be more verbose and requires manual type casting. |
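For completeness: since Spark 3.0 you can bridge the two worlds by wrapping an Aggregator with functions.udaf, which turns it into an untyped function usable from the DataFrame API. The sketch below assumes a hypothetical DoubleAverage variant of the aggregator above whose input is the raw price (a Double) rather than a Sale bean, reuses the AverageBuffer class, and assumes a DataFrame named salesDF with category and price columns; treat it as an outline rather than the one canonical approach.

```java
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.udaf;

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.expressions.UserDefinedFunction;

// Hypothetical variant whose input is the price column value itself
public class DoubleAverage extends Aggregator<Double, AverageBuffer, Double> {
    @Override public AverageBuffer zero() { return new AverageBuffer(0.0, 0L); }
    @Override public AverageBuffer reduce(AverageBuffer b, Double price) {
        b.setSum(b.getSum() + price);
        b.setCount(b.getCount() + 1);
        return b;
    }
    @Override public AverageBuffer merge(AverageBuffer b1, AverageBuffer b2) {
        b1.setSum(b1.getSum() + b2.getSum());
        b1.setCount(b1.getCount() + b2.getCount());
        return b1;
    }
    @Override public Double finish(AverageBuffer b) {
        return b.getCount() == 0 ? 0.0 : b.getSum() / b.getCount();
    }
    @Override public Encoder<AverageBuffer> bufferEncoder() { return Encoders.bean(AverageBuffer.class); }
    @Override public Encoder<Double> outputEncoder() { return Encoders.DOUBLE(); }
}

// Wrap it as an untyped function and use it with a DataFrame groupBy
UserDefinedFunction avgPrice = udaf(new DoubleAverage(), Encoders.DOUBLE());
Dataset<Row> byCategory = salesDF
        .groupBy("category")
        .agg(avgPrice.apply(col("price")).as("average_price"));
```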