Spark’s reduce() and reduceByKey() functions

A couple of weeks ago, I wrote about Spark’s map() and flatMap() transformations. Expanding on that, here is another series of code snippets that illustrate the reduce() and reduceByKey() methods.

As in the previous example, we shall start by understanding the reduce() function in Python before diving into Spark. The map() operation in Python applies the same function to every element of a collection, and it is often more concise than an explicit for loop. However, map() alone might not solve the problem at hand: we often need to combine the results of the map operation into a single value.

For example, consider the case where we need to compute the sum of the cubes of the first n natural numbers. The map() operation gives us the cubes. To these, we can apply the reduce() operation, which produces a single result: in this case, the sum of the numbers. The arguments passed to reduce() are similar to those of map(): it takes a function and an input list. The difference is that the function passed to reduce() takes two arguments, the result accumulated so far and the next element. (In Python 3, reduce() is no longer a built-in; it must be imported from the functools module.)
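A minimal pure-Python sketch of this map-then-reduce pattern follows. The function name sum_of_cubes is illustrative, not taken from the original post, and it assumes n is at least 1 (reduce() without an initial value raises an error on an empty sequence).

```python
from functools import reduce  # in Python 3, reduce() lives in functools


def sum_of_cubes(n):
    """Sum of the cubes of the first n natural numbers (n >= 1)."""
    # map() applies the same function to every element: 1, 8, 27, ...
    cubes = map(lambda x: x ** 3, range(1, n + 1))
    # reduce() folds the cubes pairwise into one value: ((1 + 8) + 27) + ...
    return reduce(lambda a, b: a + b, cubes)


print(sum_of_cubes(5))  # 225
```

The lambda passed to reduce() receives the running total as a and the next cube as b, which is why it takes two arguments where the map() lambda takes one.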

If you can grok this concept, it will be easy to understand how this works in Spark. The main difference between the reduce() function in Python and Spark is that, as with map(), Spark’s reduce() is a member method of the RDD class. Because Spark reduces each partition in parallel and then combines the partial results, the function you pass to it should be commutative and associative, as addition is. The code snippet below shows the similarity between the operations in Python and Spark.
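Here is a sketch of the same computation in PySpark, assuming you already have a SparkContext named sc (the pyspark shell creates one for you). The function name sum_of_cubes_spark is hypothetical; the RDD methods it calls (parallelize(), map(), reduce()) are standard PySpark.

```python
def sum_of_cubes_spark(sc, n):
    """Sum of the cubes of 1..n using an existing SparkContext `sc`."""
    # Distribute the numbers 1..n across the cluster as an RDD
    numbers = sc.parallelize(range(1, n + 1))
    # Transformation: cube each element (evaluated lazily)
    cubes = numbers.map(lambda x: x ** 3)
    # Action: reduce() is a method on the RDD itself. Each partition is
    # reduced in parallel and the partial sums are then combined, which
    # is why the function must be commutative and associative.
    return cubes.reduce(lambda a, b: a + b)
```

In the pyspark shell, sum_of_cubes_spark(sc, 5) returns 225 (1 + 8 + 27 + 64 + 125). Note that the lambdas are identical to what you would pass to Python's map() and functools.reduce(); only the call style changes from free functions to RDD methods.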

reduceByKey()

Computing the sum of cubes is a useful start, but as a use case it is too simple. Let us instead consider a use case that is more germane to Spark: word counts. Given an input file, we need to count the number of occurrences of each word in it.

Before writing the code, let us get the logic right; this will make the code quick to write.

  1. Read the file
  2. Hold its contents as a list of strings, where each item of the list corresponds to a line from the file
  3. Split each line into words and flatten the result into a single list of words
  4. Assign each word an initial count of 1, as if it occurred only once in the text
  5. Process the word counts: each time you encounter a word you have already seen, increment its count
  6. Display each unique word and the number of times it appears in the text
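The steps above can be sketched in PySpark as follows. This is a minimal illustration, not the original post's snippet: it assumes an existing SparkContext sc, and the helper names word_counts and show_counts are hypothetical. The RDD methods used (textFile(), flatMap(), map(), reduceByKey(), collect()) are standard PySpark.

```python
def word_counts(sc, path):
    """Return a pair RDD of (word, count) for the text file at `path`.

    `sc` is an existing SparkContext, such as the one the pyspark shell provides.
    """
    # Steps 1-2: read the file as an RDD with one element per line
    lines = sc.textFile(path)
    # Step 3: split each line on whitespace and flatten into one RDD of words
    words = lines.flatMap(lambda line: line.split())
    # Step 4: build a pair RDD in which every word starts with a count of 1
    pairs = words.map(lambda word: (word, 1))
    # Step 5: add up the counts of all pairs that share the same word (key)
    return pairs.reduceByKey(lambda a, b: a + b)


def show_counts(counts):
    # Step 6: collect() brings the results back to the driver; sort for stable output
    for word, count in sorted(counts.collect()):
        print(word, count)
```

Usage, with a hypothetical input path: show_counts(word_counts(sc, "input.txt")). Because split() only splits on whitespace, "Spark" and "spark," are counted as different words; lowercasing and stripping punctuation inside the flatMap() lambda would change that. Also, collect() pulls every result into the driver's memory, which is fine for a demo but not for very large vocabularies.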

Let us understand how our little algorithm above translates to code. From our prior encounter with flatMap(), we know that it is the natural way to flatten a list of lines into a list of words. Reading the file and applying flatMap() to its lines takes care of steps 1, 2, and 3 of our algorithm.

The next step is critical to further processing. Here, we create a pair RDD, which is simply an RDD composed of key-value pairs. In this case each key-value pair is a Python tuple, where the key is a word and the value is its count, which is initially set to 1 for every word in the file (step 4).

If we only wanted the total number of words in the file, we could call the reduce() function to add up all the 1s. However, we need the count of each word. If you are familiar with SQL, this is the equivalent of using a GROUP BY clause with the COUNT() function, and it is exactly what step 5 of our algorithm does. Here, we pass reduceByKey() a function that adds two counts together, and Spark applies it to all the values that share the same key. The result is an RDD (a pair RDD, to be more accurate) that contains key-value pairs for all the words and their counts.
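To make the semantics of reduceByKey() concrete, here is a hypothetical pure-Python helper, reduce_by_key(), that produces the same result on a single machine. It is only an illustration of what is computed, not of how Spark does it: Spark first combines values within each partition and then shuffles the partial results so that all values for a given key meet on the same partition. When every value is 1 and the function is addition, the output matches SELECT word, COUNT(*) ... GROUP BY word.

```python
def reduce_by_key(func, pairs):
    """Single-machine sketch of what reduceByKey() computes.

    Values that share a key are folded together with func, a two-argument
    function that returns one value (for word counts, a + b).
    """
    merged = {}
    for key, value in pairs:
        if key in merged:
            # Combine this value with the running result for the key
            merged[key] = func(merged[key], value)
        else:
            # First occurrence of the key: its value starts the fold
            merged[key] = value
    return list(merged.items())


pairs = [("spark", 1), ("python", 1), ("spark", 1)]
print(reduce_by_key(lambda a, b: a + b, pairs))  # [('spark', 2), ('python', 1)]
```

As with reduce(), the function only ever sees two values at a time, which is what allows Spark to apply it independently on each partition before merging.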