关于 Java8 的 stream 中的 reduce 操作

在之前再次学习折叠操作的时候,我曾做了一些笔记,并且使用 js 编写了一些示例。当时本打算同时也介绍一下 Java8 的 stream 中提供的 reduce 方法(以下简称 java8-reduce),但发现其行为和 js 中的相去甚远——为了高性能和并发支持,它的 reduce 方法是经过大量优化的,也引入了自己独有的所谓 Combiner,可谓是“Java 特色 reduce 操作”(哈哈!)。

最近学习 Hadoop 的 MapReduce 的时候,它的 Combiner 让我又回想起 java8-reduce 中的 Combiner(虽然其行为完全不是一回事就是了,MapReduce 中的 Combiner 的目的是预先进行一次本地的聚集,减少网络传输成本;java8-reduce 的 Combiner 的目的是归并各部分 reduce 的结果,得到最终值),因此这里又回过头来学习一波。

咳咳,学过 MR 的 Combiner 后回过头来一看,发现它的行为(至少是接口)确实和 java8-reduce 的非常像,它们都需要保证输入和输出的签名是相同的以进行反复的 Combine。

咳咳,学到 Spark 又回过头来看,发现 RDD 的 Aggregate 和 java8-stream 的并发版的 reduce 的签名是一致的,令人感叹。

咳咳,学习函数式编程时,编写一个二叉树的 fold 方法,能发现它也需要一个 combiner方法,将左右子树得到的结果进行拼接。combiner 这个角色显然使用的地方是更加普遍的,它绝非仅是某种 Java 特色的产物,而是函数式编程在工程实践中发现出的一种模式。

如果要说 java8-reduce 的最明显的特点,我认为有两个——

首先,它非常合适同 Java 的可变集合一起工作(collect 方法),这在其他语言里是见不到的,这样既符合 java 传统的编程模式,也能保证足够的性能,甚至在并发情况下也能够应付。

然后,它允许并发的 reduce 操作;java8-reduce 能把集合切分成多个部分,对每个部分并行地执行 reduce 操作,并通过所谓的 combiner 函数两两组合,得到最终结果;但显然使用并行的 reduce 操作时必须符合特定约束,以保证其能正确执行,无关切分情况或处理器核心数量。

挨个点名!

The easiest way

最简单的 reduce 操作就是结果值和集合内的值同类型的 reduce 了,比如对集合求积,求和,Java 中干这事一样简单——

1
2
3
// 不给定初始值也可,此时返回类型会变成 Optional<Integer>,这点还挺酷的,很“纯”。
int sum = IntStream.of(1, 2, 3).reduce(0, (acc, x) -> acc + x); // or Integer::sum
System.out.println(sum);

但 Java 未提供 reduceRight 方法,想必是认为其应用范围不广,事实也确实如此。

Parallel: Where thing goes wrong

上面是串行流,它的行为同其它语言一致,但若是并行流呢?试试下面的代码——

1
2
int sum = IntStream.of(1, 2, 3).parallel().reduce(1, Integer::sum);
System.out.println(sum);

结果是什么?如果是串行流的话,结果是肯定的——7,但这是并行流,我的电脑的结果是 9。也可以去尝试一下其它结果,最后会发现,只要 zero 参数(java 叫做 identity,这个名字兴许能给我们启发)不为 0,最后并行得到的结果和串行必然不一样。

那么,并行流的 reduce 究竟是怎么跑的?

我们知道,串行流的 reduce 可以描述成把集合中各元素通过一个二元操作符相连接,比如上面的求集合的和可以写成——

1
0 >=> 1 >=> 2 >=> 3

但是并行流显然不是这样干的,总的来说,并行流的 reduce 操作,会先把流切分成多个部分(具体切分数目由处理器核心数量决定),然后对每个部分各自并行执行 reduce 操作,然后再两两进行 Combine 操作,得到最终结果,这里的每个切分的部分,可以把它称作ReduceTask,比如对[1,2,3,4,5].reduceParr(1, plus),它的计算过程可能是这样的——

  1. [1, 2, 3, 4, 5]切分成四段,[1][2][3][4, 5]
  2. 对每一段,并行调用 reduce 方法,即[1].reduce(1, plus)[2].reduce(1, plus)……得到23410
  3. 对结果两两执行 combine 操作,这里的 combine 操作即 acc 操作,即 plus,combine(combine(2, 3), combine(4, 10))
  4. 得到结果——19。(实践好像是分了 5 段,因此结果是 20)

具体流程是我们不应该关心的,我们只需要知道,各个部分会分别 reduce,然后会两两 combine,最终得到最终结果。这里的“两两 combine”不是说从左往右依次 combine(这不就又是一个串行的 reduce 嘛),想象一下归并排序的迭代版,它自底向上,每次 combine 的都是更“大”的值。

对于同类型的 reduce 操作,combine 函数同 acc 函数相同,所以[1,2,3].reduce(0, Integer::sum)也可以描述成[1,2,3].reduce(0, Integer::sum, Integer::sum),表示它的合并和积累函数都是 sum。

非同类型的 reduce

只有在非同类型的 reduce 操作中,combine 才会明确显示出来——它需要用户主动去定义,但我们先来看看 Java 特色的 reduce 操作。

非同类型的 reduce 操作有两个方法可用——

  • reduce(identity : U, acc : (U, T) => U, combiner : (U, U) => U)

  • collect(supplier : () => U, acc : (U, T) => (), combiner : (U, U) => ())

前者(reduce)用于累积的类型是不可变值的情况;后者(collect)用于累积的类型是可变值的情况。

reduce 方法

我们先来看看 reduce,下面的操作将集合反转——

1
2
3
4
5
6
7
List<Integer> res = IntStream.of(1, 2, 3, 4, 5).boxed()
.reduce(new LinkedList<>(), (acc, x) -> {
acc.addFirst(x);
return acc;
}, (a, b) -> {
return null;
});

在这里,第三个参数就是所谓的 combiner,这里因为是串行流,所以 combiner 不会被调用,可以直接返回 null。但 combiner 本身不能为 null,否则会抛空指针异常。

但若是并行流呢?我们凭第一印象,大概会这么写——

1
2
3
4
5
6
7
8
List<Integer> res = IntStream.of(1, 2, 3, 4, 5).boxed().parallel()
.reduce(new LinkedList<>(), (acc, x) -> {
acc.addFirst(x);
return acc;
}, (a, b) -> {
a.addAll(b);
return a;
});

但是这样是无法生效的,如果试图对该集合进行输出,它在输出时会抛出空指针异常!这说明它的结构被破坏了,我们遇到了线程安全问题!

但为什么会这样呢?原来,reduce 方法中所使用的 identity,会被每一个 ReduceTask 都共用!并且我们在 acc 函数中原样返回了累积值,因此它会被持续使用下去,如果在 combiner 中试图判断a == b,它也将是 true,因为 a 和 b 是同一个对象!

因此,reduce 方法只适合不可变对象,或者我们可以每次返回值都不改变原值,而是返回一个新的值,新的引用,但这对性能是极大的损耗,只有 string 这样的不可变类在 reduce 方法上才能得到运用。

collect 方法

这时,使用collect方法就是一个更好的选择,我们再看一看 collect 的签名——

1
U collect(supplier : () => U, acc : (U, T) => (), combiner : (U, U) => ())

这签名实际上把所有细节都描述出来了——我们通过一个 supplier,来让每一个 ReduceTask 都能拿到不一样的引用,从而避免共享数据问题;acc 和 combiner 都是没有返回值的,因此显然我们需要通过修改累积值来完成积累和合并操作。下面使用 collect 来进行一个 count 操作——

1
2
3
4
5
6
7
8
9
10
11
12
13
14
Map<String, Integer> count = Stream.of("hello", "world", "hello", "me", "hello", "yukina", "hello", "happy", "world")
.parallel().collect(() -> new HashMap<>(),
(acc, x) -> {
if (!acc.containsKey(x))
acc.put(x, 0);
acc.compute(x, (k, v) -> v + 1);
}, (a, b) -> {
// combine 时必须把第二个值合并到第一个值
b.forEach((kb, vb) -> {
if (!a.containsKey(kb))
a.put(kb, 0);
a.compute(kb, (ka, va) -> va + vb);
});
});

That’s it!了解这么多就足够了。使用原则是,当使用值类型,或不可变类型,如 String,Scala,Kotlin 的不可变集合,基本类型等的时候,使用 reduce 方法;使用引用类型,可变类型的时候,使用 collect 方法。虽然 combiner,acc,初始值的设置要遵循的一定的规律,但给出统一和容易理解的表述并不容易,具体问题具体分析吧。


本博客所有文章除特别声明外,均采用 CC BY-NC-SA 4.0 协议 ,转载请注明出处!