
An Aggregator is a class for creating your own custom, type-safe aggregate functions that operate on Spark Datasets.

To answer your main question directly: No, an Aggregator cannot be passed directly to a RelationalGroupedDataset (the result of a DataFrame groupBy()). It is designed for the type-safe Dataset API, and its primary use is with a KeyValueGroupedDataset.

A Spark custom aggregator lets you define and apply your own aggregation logic in Spark SQL, typically in groupBy or window operations. You create one by extending Spark's Aggregator class (or, in older Spark versions, by extending the now-deprecated org.apache.spark.sql.expressions.UserDefinedAggregateFunction, i.e. a UDAF).


## What is an Aggregator?

Think of an Aggregator as a blueprint for building your own aggregate function like SUM() or COUNT(), but one that works directly with your specific Java objects. It gives you complete control over the aggregation logic while ensuring type safety at compile time.

You’d use it when a standard built-in function isn’t sufficient. For example, you might want to calculate a geometrically weighted average, find the second-most-frequent item in a group, or combine complex objects into a single summary object.


## The Four Key Methods

An Aggregator is an abstract class you extend. Its power comes from four key methods that together describe the logic of a parallel aggregation (a sketch of the full class follows this list):

  1. zero(): This method defines the “starting” or “empty” value for your aggregation before any elements have been processed. For a sum, this would be 0. For a product, 1. For a list, an empty list.

  2. reduce(buffer, input): This defines how to incorporate one new input element into your current aggregated buffer. It’s the core logic that runs on each partition.

  3. merge(b1, b2): This defines how to combine two intermediate aggregated buffers from different partitions into one. This is what makes your aggregation work in parallel.

  4. finish(buffer): This method takes the final, merged buffer after all data has been processed and transforms it into the final output result.
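Beyond these four, you must also implement bufferEncoder() and outputEncoder(), which tell Spark how to serialize the intermediate buffer and the final output; both appear in the full example below. As a simplified sketch, the abstract class you extend looks like this:

```java
// Simplified shape of org.apache.spark.sql.expressions.Aggregator
// IN = input element type, BUF = intermediate buffer type, OUT = final result type
public abstract class Aggregator<IN, BUF, OUT> implements Serializable {
    public abstract BUF zero();                    // the "empty" buffer, before any input
    public abstract BUF reduce(BUF b, IN a);       // fold one input element into a buffer (per partition)
    public abstract BUF merge(BUF b1, BUF b2);     // combine two partial buffers (across partitions)
    public abstract OUT finish(BUF reduction);     // turn the final merged buffer into the result
    public abstract Encoder<BUF> bufferEncoder();  // how Spark serializes the buffer
    public abstract Encoder<OUT> outputEncoder();  // how Spark serializes the output
}
```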


## How to Use an Aggregator

You use an Aggregator by creating an instance of your custom class and passing it to the .agg() method of a KeyValueGroupedDataset.

Here’s a simple Java example of creating an Aggregator to calculate a custom average.

1. Define the Data and Buffer Classes

```java
import java.io.Serializable;

// Your input data object. It is a JavaBean (public no-arg constructor plus
// getters and setters) so that Encoders.bean can create an Encoder for it.
public class Sale implements Serializable {
    private String category;
    private double price;

    public Sale() {}
    public Sale(String category, double price) { this.category = category; this.price = price; }

    public String getCategory() { return category; }
    public void setCategory(String category) { this.category = category; }
    public double getPrice() { return price; }
    public void setPrice(double price) { this.price = price; }
}

// The intermediate buffer holding the running sum and count. Also a JavaBean,
// because bufferEncoder() below relies on Encoders.bean.
public class AverageBuffer implements Serializable {
    private double sum;
    private long count;

    public AverageBuffer() {}
    public AverageBuffer(double sum, long count) { this.sum = sum; this.count = count; }

    public double getSum() { return sum; }
    public void setSum(double sum) { this.sum = sum; }
    public long getCount() { return count; }
    public void setCount(long count) { this.count = count; }
}
```

2. Implement the Aggregator

```java

import org.apache.spark.sql.expressions.Aggregator;
import org.apache.spark.sql.Encoder;
import org.apache.spark.sql.Encoders;

public class MyAverageAggregator extends Aggregator<Sale, AverageBuffer, Double> {

    // 1. What is the initial value?
    @Override
    public AverageBuffer zero() {
        return new AverageBuffer(0.0, 0L);
    }

    // 2. How to add a new Sale to the buffer?
    @Override
    public AverageBuffer reduce(AverageBuffer buffer, Sale sale) {
        buffer.setSum(buffer.getSum() + sale.getPrice());
        buffer.setCount(buffer.getCount() + 1);
        return buffer;
    }

    // 3. How to combine two intermediate buffers?
    @Override
    public AverageBuffer merge(AverageBuffer b1, AverageBuffer b2) {
        b1.setSum(b1.getSum() + b2.getSum());
        b1.setCount(b1.getCount() + b2.getCount());
        return b1;
    }

    // 4. How to produce the final result from the final buffer?
    @Override
    public Double finish(AverageBuffer reduction) {
        return reduction.getCount() == 0 ? 0.0 : reduction.getSum() / reduction.getCount();
    }

    // The Encoder Spark uses to serialize the intermediate buffer
    @Override
    public Encoder<AverageBuffer> bufferEncoder() {
        return Encoders.bean(AverageBuffer.class);
    }

    // Define Encoder for the final output
    @Override
    public Encoder<Double> outputEncoder() {
        return Encoders.DOUBLE();
    }
}
```
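A detail worth noting: reduce and merge mutate the first buffer and return it rather than allocating a new object on each call; Spark's own documentation recommends this for performance.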

3. Apply it to a KeyValueGroupedDataset

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Encoders;
import scala.Tuple2;

// Assume 'salesDS' is a Dataset<Sale>
MyAverageAggregator myAvg = new MyAverageAggregator();

// Group by category and then apply the aggregator
Dataset<Tuple2<String, Double>> result = salesDS
    .groupByKey(Sale::getCategory, Encoders.STRING())
    .agg(myAvg.toColumn().name("average_price"));

result.show();
// +-----------+-----------------+
// |        key|    average_price|
// +-----------+-----------------+
// |Electronics|           355.95|
// |   Clothing|            45.55|
// +-----------+-----------------+
```
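The grouping key comes back under the generic column name key. If you want friendlier names, you can rename the tuple columns (a small sketch; the names are just illustrative):

```java
import org.apache.spark.sql.Row;

// Rename the typed result's columns (key, average_price) for downstream use
Dataset<Row> renamed = result.toDF("category", "average_price");
renamed.show();
```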

## Aggregator vs. RelationalGroupedDataset

You cannot pass an Aggregator to a RelationalGroupedDataset directly because the two belong to different Spark APIs with different design philosophies. (Since Spark 3.0 you can bridge them by wrapping an Aggregator with org.apache.spark.sql.functions.udaf, as sketched after the table below.)

| Feature | Aggregator (for Datasets) | UserDefinedAggregateFunction / UDAF (for DataFrames) |
| --- | --- | --- |
| API | Dataset API (groupByKey) | DataFrame API (groupBy) |
| Type Safety | Strongly typed; works with your Java objects | Untyped; works with generic Row objects and schemas |
| Error Checking | Compile-time | Runtime |
| Ease of Use | Generally easier to write and read complex logic | Can be more verbose and requires manual type casting |
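That said, since Spark 3.0 there is a bridge: org.apache.spark.sql.functions.udaf wraps a typed Aggregator as an untyped UserDefinedFunction that a DataFrame groupBy can consume. A minimal sketch, assuming a hypothetical PriceAverage variant of the aggregator above declared as Aggregator<Double, AverageBuffer, Double> (so its input is the bare price column), and a DataFrame salesDF with category and price columns:

```java
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.udaf;

import org.apache.spark.sql.Encoders;
import org.apache.spark.sql.expressions.UserDefinedFunction;

// Wrap the typed Aggregator as an untyped UDAF (Spark 3.0+)
UserDefinedFunction priceAvg = udaf(new PriceAverage(), Encoders.DOUBLE());

// Use it with a plain DataFrame groupBy, i.e. a RelationalGroupedDataset
salesDF.groupBy("category")
       .agg(priceAvg.apply(col("price")).name("average_price"))
       .show();
```

This keeps the type-safe implementation in one place while still making it callable from the untyped DataFrame and SQL APIs.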