Higher Order Functions in Java

In order to understand future examples, we first have to discuss how to use higher-order functions in Java, and how to write anonymous inner classes. This post will have nothing to do with multi-stage programming.

Let’s write a program that can print out data tables for different mathematical functions. For example, for a function that multiplies by two, f(x) = 2x, we want to print something like this:

                   x                 f(x)
       -5.0000000000       -10.0000000000
       -4.0000000000        -8.0000000000
       -3.0000000000        -6.0000000000
       -2.0000000000        -4.0000000000
       -1.0000000000        -2.0000000000
        0.0000000000         0.0000000000
        1.0000000000         2.0000000000
        2.0000000000         4.0000000000
        3.0000000000         6.0000000000
        4.0000000000         8.0000000000
        5.0000000000        10.0000000000

We can write a function like this:

public static void printTableTimesTwo(double x1,
                                      double x2,
                                      int n) {
    assert n > 1; // only checked when assertions are enabled (java -ea)

    double x = x1;
    double delta = (x2 - x1) / (double) (n - 1);
    System.out.printf("%20s %20s%n", "x", "f(x)");
    System.out.printf("%20.10f %20.10f%n", x, x * 2);
    for (int i = 0; i < (n - 1); ++i) {
        x += delta;
        System.out.printf("%20.10f %20.10f%n", x, x * 2);
    }
}

The parameter x1 determines the lower end of the interval, x2 the upper end, and n how many values are printed; consecutive x-values are delta = (x2 - x1)/(n - 1) apart. The parameter n must be at least 2 so that both x1 and x2 are printed. We can generate the table above with this call:

printTableTimesTwo(-5, 5, 11);
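To see which x-values the loop visits, here is a small sketch that collects them into an array. The class SamplePoints and its helper samplePoints are hypothetical names for this example, not part of the original code. With x1 = -5, x2 = 5 and n = 11, the step is delta = (5 - (-5)) / (11 - 1) = 1.0, so the points are -5, -4, ..., 5. Unlike the loop above, this sketch computes each point as x1 + i * delta instead of adding delta repeatedly, which keeps rounding error from building up across steps (for this particular example both approaches give identical values).

```java
import java.util.Arrays;

public class SamplePoints {
    // Hypothetical helper: the n evenly spaced x-values from x1 to x2 (inclusive)
    // that the printTable-style loops visit.
    public static double[] samplePoints(double x1, double x2, int n) {
        if (n < 2) {
            throw new IllegalArgumentException("n must be at least 2");
        }
        double delta = (x2 - x1) / (double) (n - 1); // spacing between neighbors
        double[] xs = new double[n];
        for (int i = 0; i < n; ++i) {
            // x1 + i * delta rather than repeated "x += delta",
            // so rounding error does not accumulate from step to step
            xs[i] = x1 + i * delta;
        }
        return xs;
    }

    public static void main(String[] args) {
        // prints [-5.0, -4.0, -3.0, -2.0, -1.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0]
        System.out.println(Arrays.toString(samplePoints(-5, 5, 11)));
    }
}
```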

What if we want to print out the values of a different function, for example f(x) = x + 4? We can write a new function:

public static void printTablePlusFour(double x1,
                                      double x2,
                                      int n) {
    assert n > 1; // only checked when assertions are enabled (java -ea)

    double x = x1;
    double delta = (x2 - x1) / (double) (n - 1);
    System.out.printf("%20s %20s%n", "x", "f(x)");
    System.out.printf("%20.10f %20.10f%n", x, x + 4);
    for (int i = 0; i < (n - 1); ++i) {
        x += delta;
        System.out.printf("%20.10f %20.10f%n", x, x + 4);
    }
}

This involves a lot of code duplication, though: the two methods are identical except for the expressions x*2 and x+4, each of which appears twice. How can we factor that difference out?

Let's write an interface that we can use for any function that takes in one parameter and returns one value; f(x) = y is an example of such a function.

public interface ILambda<R, P> {
    public R apply(P param);
}

This interface is called ILambda and it has one method, apply. We used Java generics and didn't specify the return type and the type of the parameter; instead, we just called them R and P, respectively. A function that takes in a Double and returns a Double, like f(x) = y, can be expressed using an ILambda<Double, Double>. A function taking a String and returning an Integer would use ILambda<Integer, String> (the return type comes first).
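As a quick illustration of the String-to-Integer case, here is a sketch of an ILambda<Integer, String> that maps a String to its length. The names StringLengthExample and StringLength are made up for this example; ILambda is repeated as a nested interface only so that the snippet compiles on its own.

```java
public class StringLengthExample {
    // Same shape as the ILambda interface above: R is the return type, P the parameter type.
    public interface ILambda<R, P> {
        R apply(P param);
    }

    // Hypothetical example: a function from String to Integer.
    public static class StringLength implements ILambda<Integer, String> {
        public Integer apply(String param) {
            return param.length(); // int is auto-boxed to Integer
        }
    }

    public static void main(String[] args) {
        ILambda<Integer, String> len = new StringLength();
        System.out.println(len.apply("lambda")); // prints 6
    }
}
```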

Now we can write our f(x) = 2x and f(x) = x + 4 functions using ILambda:

// each public class goes in its own .java file
public class TimesTwo implements ILambda<Double, Double> {
    public Double apply(Double param) { return param * 2; }
}
public class PlusFour implements ILambda<Double, Double> {
    public Double apply(Double param) { return param + 4; }
}

Now we can write a single printTable method that takes in an ILambda<Double, Double> called f representing the function, in addition to the parameters x1, x2 and n, as before:

public static void printTable(ILambda<Double, Double> f,
                              double x1,
                              double x2,
                              int n) {
    assert n > 1; // only checked when assertions are enabled (java -ea)

    double x = x1;
    double delta = (x2 - x1) / (double) (n - 1);

    // f.apply(x) just means what f(x) means in math!
    double y = f.apply(x);
    System.out.printf("%20s %20s%n", "x", "f(x)");
    System.out.printf("%20.10f %20.10f%n", x, y);
    for (int i = 0; i < (n - 1); ++i) {
        x += delta;
        y = f.apply(x);
        System.out.printf("%20.10f %20.10f%n", x, y);
    }
}

Note that to print the y-value we just write f.apply(x), which looks very similar to f(x) in mathematics and means exactly the same thing.

We can print out the tables for our two functions using:

printTable(new TimesTwo(), -5, 5, 11);
printTable(new PlusFour(), -5, 5, 11);

We have to create new objects for the functions: The first time we call printTable we pass a new TimesTwo object; the second time, we pass a new PlusFour object.
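Putting the pieces together, here is a sketch of a single self-contained program, assuming Java 6 or later. To keep it in one file, ILambda, TimesTwo and PlusFour are nested inside a class named HigherOrderDemo (a name chosen for this example), whereas the post keeps them as separate top-level types. It also differs from the code above in two small ways: it passes Locale.ROOT to printf so the decimal separator is always a period regardless of the default locale, and it throws IllegalArgumentException instead of using assert, since assertions are off unless the JVM runs with -ea.

```java
import java.util.Locale;

public class HigherOrderDemo {
    // A function from P to R; apply(p) plays the role of f(p).
    public interface ILambda<R, P> {
        R apply(P param);
    }

    public static class TimesTwo implements ILambda<Double, Double> {
        public Double apply(Double param) { return param * 2; }
    }

    public static class PlusFour implements ILambda<Double, Double> {
        public Double apply(Double param) { return param + 4; }
    }

    // Prints a header plus n rows of x and f(x), evenly spaced from x1 to x2.
    public static void printTable(ILambda<Double, Double> f, double x1, double x2, int n) {
        if (n < 2) {
            throw new IllegalArgumentException("n must be at least 2");
        }
        double x = x1;
        double delta = (x2 - x1) / (double) (n - 1);
        System.out.printf(Locale.ROOT, "%20s %20s%n", "x", "f(x)");
        System.out.printf(Locale.ROOT, "%20.10f %20.10f%n", x, f.apply(x));
        for (int i = 0; i < n - 1; ++i) {
            x += delta;
            System.out.printf(Locale.ROOT, "%20.10f %20.10f%n", x, f.apply(x));
        }
    }

    public static void main(String[] args) {
        printTable(new TimesTwo(), -5, 5, 11);
        printTable(new PlusFour(), -5, 5, 11);
    }
}
```

Note that printTable itself never mentions TimesTwo or PlusFour; it only relies on the ILambda<Double, Double> interface, which is what lets the same method serve both functions.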

We can now define as many functions as we like without having to rewrite the printTable method. For example, here is a square root function, which we can pass to printTable just like the others:

public class SquareRoot implements ILambda<Double, Double> {
    public Double apply(Double param) {
        return Math.sqrt(param);
    }
}

// ...

printTable(new SquareRoot(), -5, 5, 11);
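One thing to keep in mind with this particular call: the interval starts at -5, and Math.sqrt returns NaN (not a number) for negative arguments, so the rows for x < 0 show NaN rather than causing an error. Here is a quick check, with SquareRoot and ILambda repeated so the snippet compiles by itself (the class name SquareRootCheck is just for this example):

```java
public class SquareRootCheck {
    public interface ILambda<R, P> {
        R apply(P param);
    }

    public static class SquareRoot implements ILambda<Double, Double> {
        public Double apply(Double param) {
            return Math.sqrt(param);
        }
    }

    public static void main(String[] args) {
        ILambda<Double, Double> f = new SquareRoot();
        System.out.println(f.apply(4.0));  // prints 2.0
        System.out.println(f.apply(-1.0)); // prints NaN
        // %f formats NaN as the literal string "NaN", so printTable prints it too
        System.out.println(String.format("%20.10f", f.apply(-1.0)).trim()); // prints NaN
    }
}
```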

The really neat thing is that we can even define a new function on the fly, without having to give it a name. In Java, we do that using an anonymous inner class. Here, we call printTable and pass it a new object that implements ILambda<Double, Double>:

printTable(new ILambda<Double, Double>() {
    public Double apply(Double param) {
        return param * param;
    }
}, -5, 5, 11);

We define a new ILambda from Double to Double without giving it a name. When we use an anonymous inner class, we must implement every method that is still abstract; here, that is just the apply method.
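This post uses anonymous inner classes, which work in any Java version with generics. On Java 8 or later there is a shorter option: because ILambda has exactly one abstract method, it is a functional interface, so a lambda expression can stand in for the anonymous class. The standard library's java.util.function.Function<T, R> has the same shape, but its type arguments are in the opposite order (parameter type first, result type second). A sketch, where LambdaSyntaxDemo and the field names are just for this example:

```java
import java.util.function.Function;

public class LambdaSyntaxDemo {
    // Same ILambda as in the post: one abstract method, so it is a functional interface.
    public interface ILambda<R, P> {
        R apply(P param);
    }

    // The anonymous inner class from the post (works before Java 8 too).
    public static final ILambda<Double, Double> ANON_CLASS = new ILambda<Double, Double>() {
        public Double apply(Double param) {
            return param * param;
        }
    };

    // The same function as a Java 8 lambda expression targeting ILambda.
    public static final ILambda<Double, Double> LAMBDA_EXPR = x -> x * x;

    // The same function using the standard Function<T, R> (parameter type first).
    public static final Function<Double, Double> STD_FUNCTION = x -> x * x;

    public static void main(String[] args) {
        // all three print the same value
        System.out.println(ANON_CLASS.apply(3.0));
        System.out.println(LAMBDA_EXPR.apply(3.0));
        System.out.println(STD_FUNCTION.apply(3.0));
    }
}
```

Since printTable in this post expects an ILambda<Double, Double>, a lambda like LAMBDA_EXPR can be passed to it directly, while a Function would need printTable's parameter type to change.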

The method printTable is now a "higher-order function", because conceptually it is a function that takes another function as input.

Questions:

  1. What does the anonymous ILambda in the example above compute? What's the mathematical function it represents?
  2. How would you print a table for the function f(x) = x^2 + 2x?

You can download the complete source code for the examples here:

(Re-posted from The Java Mint Blog)


About Mathias

Software development engineer. Principal developer of DrJava. Recent Ph.D. graduate from the Department of Computer Science at Rice University.
This entry was posted in Mint.
