Reduce (fold)

In previous sections we discussed many methods for iterating over data and transforming it. In this section we'll discuss another higher-order function that is arguably one of the most powerful. It is a concept recognized across enough programming languages to get its own Wikipedia article. Most popular languages call it reduce, although some call it fold or inject. Here are the parameters it takes. The order of the parameters may differ in other languages, but the functionality and output will be the same.

reduce(list, fn, starting_value)

Like map() and filter(), it takes a list you want to transform and a function (fn) to do the transformation. The transformation function is called once per item and carries a running value from one call to the next, much like the recursive loops seen in the last section. Here's a function that takes a list of numbers and gives you the total sum of those numbers.

local list = {23, 63, 12, 48, 3}

local sum_fn = function(accumulator, current_number)
  return accumulator + current_number
end


local total_sum = reduce(list, sum_fn, 0)

We pass reduce a starting value of 0. On the first pass, sum_fn is invoked with accumulator set to that starting value and current_number set to the first number in the list. Whatever value the function returns becomes the accumulator for the next pass, and the value returned from the final pass is what reduce itself returns.
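To make the accumulator pattern concrete, here is a rough sketch of the same sum written as a plain for-loop, with no reduce involved. It only illustrates how the value is carried from one pass to the next; it is not the reduce implementation used later in this section.

```lua
local list = {23, 63, 12, 48, 3}

local sum_fn = function(accumulator, current_number)
  return accumulator + current_number
end

-- Start from the same starting value we would pass to reduce
local accumulator = 0
for _, current_number in ipairs(list) do
  -- Whatever sum_fn returns becomes the accumulator for the next pass
  accumulator = sum_fn(accumulator, current_number)
end

print(accumulator) -- 149
```

A reduce function wraps up exactly this loop so that only the list, the function and the starting value need to be supplied.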

Lua doesn't have a reduce function built in so we'll implement our own here with a detailed description of all the parameters. Try not to get too hung up on the actual reduce function's implementation at the top, but rather focus below that on how it works. There will be several more examples. Once you understand how to use it, go back to the top and look at the actual reduce function's implementation. Copy all this code into the text editor window on the REPL and run it:

-- Applies fn cumulatively to the items of the array t, from left to
-- right, so as to reduce the array to a single value. If a first value
-- is specified the accumulator is initialized to it, otherwise the
-- first value in the array is used.
-- @param {table} t - a table to reduce
-- @param {function} fn - the reducer for combining the two values
--   @param {*} acc - The accumulator accumulates the callback's return
--     values; It is the accumulated value previously returned in the
--     last invocation of the callback, or `first_value`, if supplied.
--   @param {*} current_value - The current element being processed in the list.
--   @param {number} current_index - The index of the current element
--     being processed in the list, starting at 1.
--   @param {table} t - The table being reduced.
-- @param {*} first_value - The initial value of the accumulation. If the array is
--   empty, the first_value will also be the returned value. If the array is empty
--   and no first value is specified an error is raised.
-- @example
--   -- returns 'zxy'
--   reduce(
--     { 'x', 'y' },
--     function(a, b) return a .. b end,
--     'z'
--   )
local function reduce(t, fn, first)
  local acc = first
  local starting_value = first ~= nil
  for i, v in ipairs(t) do
    -- No starting value, start on
    -- the first element in the list
    if starting_value then
      acc = fn(acc, v, i, t)
    else
      acc = v
      starting_value = true
    end
  end
  assert(
    starting_value,
    'Attempted to reduce an empty table with no first value.'
  )
  return acc
end

local list = {23, 63, 12, 48, 3}
local sum_fn = function(accumulator, current_number)
  print(accumulator)
  return accumulator + current_number
end

local total_sum = reduce(list, sum_fn, 0)
print('The total sum is:', total_sum)

Following the print statement inside of sum_fn, we can see that the accumulator starts out with the 0 we pass in. We add current_number to accumulator and it begins to accumulate all the values as it goes.

0
23
86
98
146
The total sum is:   149

If we don't pass in a starting number, the accumulator will begin right away with the first number in the list:

local sum_fn = function(accumulator, current_number)
  print(accumulator)
  return accumulator + current_number
end

local total_sum = reduce(list, sum_fn)
print('The total sum is:', total_sum)

23
86
98
146
The total sum is:   149
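The starting value also decides what happens at the edges. The sketch below repeats the reduce function from above so it can run on its own, then tries an empty list with and without a starting value, and a one-item list without one. The use of pcall to catch the error is an addition for this illustration.

```lua
-- Same reduce as above, repeated so this snippet runs on its own
local function reduce(t, fn, first)
  local acc = first
  local starting_value = first ~= nil
  for i, v in ipairs(t) do
    if starting_value then
      acc = fn(acc, v, i, t)
    else
      acc = v
      starting_value = true
    end
  end
  assert(
    starting_value,
    'Attempted to reduce an empty table with no first value.'
  )
  return acc
end

local sum_fn = function(accumulator, current_number)
  return accumulator + current_number
end

-- Empty list with a starting value: the starting value comes straight back
local empty_sum = reduce({}, sum_fn, 0)
print(empty_sum) -- 0

-- Empty list without a starting value: the assert inside reduce fails.
-- pcall catches the error so the script keeps running.
local ok, err = pcall(reduce, {}, sum_fn)
print(ok)  -- false
print(err) -- Attempted to reduce an empty table with no first value.

-- A single item with no starting value is returned without calling sum_fn
local single = reduce({42}, sum_fn)
print(single) -- 42
```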

If you've used JavaScript, you may be starting to see the resemblance to JavaScript's reduce function. The two languages are syntactically similar, and given the ubiquity of JavaScript, this Lua implementation deliberately follows much of the same behavior.

Let's look at some more examples to better understand how to reduce and what situations doing so could prove useful. The reduce function is omitted in the following examples, but you can copy and paste the function in the REPL alongside the examples to run the code yourself.

-- Concatenate a list of words
local list = {'this', 'is', 'a', 'sentence'}

local sentence = reduce(list, function(acc, word, index, list)
  -- Add a period if this is the last word
  if index == #list then
    word = word .. '.'
  end
  -- Join the word onto the sentence so far, separated by a space
  return acc .. ' ' .. word
end)

print(sentence)

this is a sentence.

-- Only keep odd numbers
local list = {23, 63, 12, 48, 3}
local odd_numbers = reduce(list, function(acc, current_number)
  if current_number % 2 == 0 then
    return acc
  end
  acc[#acc + 1] = current_number
  return acc
end, {})

for key, value in ipairs(odd_numbers) do
  print(value)
end

23
63
3

This looks similar to what we might do with the filter function covered previously in 3.3 - Map and filter. In fact, we can build filter and map on top of reduce. Take a look at the same code with the filtering logic pulled out into its own function:

local filter = function(list, predicate_fn)
  return reduce(list, function(acc, val, i, t)
    -- Keep the value only when the predicate approves of it
    if predicate_fn(val, i, t) then
      acc[#acc + 1] = val
    end
    return acc
  end, {})
end


-- Only keep odd numbers
local list = {23, 63, 12, 48, 3}
local odd_numbers = filter(list, function(current_number)
  return current_number % 2 ~= 0
end)

for key, value in ipairs(odd_numbers) do
  print(value)
end

An example of wrapping reduce with a new map function won't be explained here, but rather left up to the reader as an exercise at the end of this section.

Here's one more example that is a bit more complex, a function called compose that creates a pipeline for passing data through. It accomplishes this by passing any functions you give it through to reduce as a list:

-- Function that allows you to compose other functions
-- together to form a pipeline. The resulting pipeline
-- is a function that you can pass your intended data through.
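local compose = function(...)
  -- "..." holds the variable arguments passed to this function.
  -- Packing them into a table with {...} gives us a list to reduce.
  -- Older Lua versions also exposed them as an implicit "arg" table,
  -- which newer versions no longer provide.
  -- See: https://www.lua.org/pil/5.2.html
  local fns = {...}
  return function(x)
    -- Feed x through each function in turn, left to right
    return reduce(fns, function(acc, v)
      return v(acc)
    end, x)
  end
end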

-- Some example composable functions
local add = function(x)
  return function(y)
    return y + x
  end
end
local multiply = function(x)
  return function(y)
    return y * x
  end
end
local subtract = function(x)
  return function(y)
    return y - x
  end
end


local number_pipeline = compose(add(12), multiply(2), subtract(9))
print(number_pipeline(3))
print(number_pipeline(2))
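To see what the pipeline does at each stage, here is a sketch of a tracing variant. The name compose_traced and the steps table are hypothetical additions for illustration; the reduce function is a compact copy of the one above so the snippet runs on its own.

```lua
-- Compact copy of reduce so this snippet runs on its own
local function reduce(t, fn, first)
  local acc = first
  local starting_value = first ~= nil
  for i, v in ipairs(t) do
    if starting_value then
      acc = fn(acc, v, i, t)
    else
      acc = v
      starting_value = true
    end
  end
  assert(starting_value, 'Attempted to reduce an empty table with no first value.')
  return acc
end

-- Hypothetical variant of compose that records each intermediate value
local compose_traced = function(...)
  local fns = {...}
  return function(x)
    local steps = {x}
    local result = reduce(fns, function(acc, fn)
      local next_value = fn(acc)
      steps[#steps + 1] = next_value
      return next_value
    end, x)
    return result, steps
  end
end

local add = function(x) return function(y) return y + x end end
local multiply = function(x) return function(y) return y * x end end
local subtract = function(x) return function(y) return y - x end end

local number_pipeline = compose_traced(add(12), multiply(2), subtract(9))
local result, steps = number_pipeline(3)
print(table.concat(steps, ' -> ')) -- 3 -> 15 -> 30 -> 21
print(result) -- 21
```

The functions run left to right in the order they were given: 3 + 12 is 15, 15 * 2 is 30, and 30 - 9 is 21. Passing 2 through the same pipeline gives 19.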

Alternative reduce implementations

Iterating tables

Let's go back to the implementation of reduce for a moment. Take a look at the implementation of it given above. Notice the iteration inside is using ipairs which expects an array/list-type table. If we wanted to reduce a non-list table we could modify reduce to first check if the table is an array and do appropriate iteration over the table whether or not it is. Let's test that:

local function reduce(t, fn, first)
  local get_iterator = function(t)
    if type(t) == 'table' then
      -- If property of 1 is empty then
      -- iterate as a regular keyed table
      if t[1] == nil then
        return pairs(t)
      end
      return ipairs(t)
    end
    error('Expected table, got ' .. tostring(t))
  end
  local acc = first
  local starting_value = first ~= nil
  -- Whether we do ipairs or pairs is conditional
  for i, v in get_iterator(t) do
    -- No starting value, start on
    -- the first element in the list
    if starting_value then
      acc = fn(acc, v, i, t)
    else
      acc = v
      starting_value = true
    end
  end
  assert(
    starting_value,
    'Attempted to reduce an empty table with no first value.'
  )
  return acc
end


local list = {
  monday = 23,
  tuesday = 63,
  wednesday = 12,
  thursday = 48,
  friday = 3
}
local total_sum = reduce(list, function(acc, current_number, key)
  print(key .. ': ' .. current_number)
  return acc + current_number
end)

print('total sum: ' .. total_sum)

This should print something like this:

wednesday: 12
friday: 3
thursday: 48
monday: 23
total sum: 149

Note that the order in which the keys are iterated is not guaranteed. Also, "tuesday" wasn't printed in this run because it happened to be visited first and so became the starting value, but it was still included in the total. Passing an extra argument of 0 to reduce would have caused all the days to be passed through our reducer function and printed out.
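If a predictable order matters, one option is to reduce over a sorted list of the keys rather than over the keyed table itself. This is a sketch under that assumption, using the original list-only reduce; the keys and visited tables are names introduced here for illustration.

```lua
-- Compact copy of the list-only reduce so this snippet runs on its own
local function reduce(t, fn, first)
  local acc = first
  local starting_value = first ~= nil
  for i, v in ipairs(t) do
    if starting_value then
      acc = fn(acc, v, i, t)
    else
      acc = v
      starting_value = true
    end
  end
  assert(starting_value, 'Attempted to reduce an empty table with no first value.')
  return acc
end

local list = {
  monday = 23,
  tuesday = 63,
  wednesday = 12,
  thursday = 48,
  friday = 3
}

-- Collect the keys into a list and sort them alphabetically
local keys = {}
for key in pairs(list) do
  keys[#keys + 1] = key
end
table.sort(keys)

-- Reduce over the sorted keys, looking each value up in the original table
local visited = {}
local total_sum = reduce(keys, function(acc, key)
  visited[#visited + 1] = key
  return acc + list[key]
end, 0)

print(table.concat(visited, ', ')) -- friday, monday, thursday, tuesday, wednesday
print('total sum: ' .. total_sum)  -- total sum: 149
```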

Break early

OK, here's another example that seems tricky at first glance. Let's say you implemented some search functionality on top of reduce like this:

local list = {23, 63, 12, 48, 3}

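local find = function(list, predicate_fn)
  return reduce(list, function(acc, v, i, t)
    if predicate_fn(v, i, t) then
      return v
    end
    return acc
  -- Start from false so the first item is also tested
  -- and false is returned when nothing matches
  end, false)
end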

print(find(list, function(val)
  return val > 50
end))
print(find(list, function(val)
  return val % 8 == 0
end))

Which prints out the expected results:

63
48

But do you see what's problematic about this? Even after we find the result we want, the reduce function keeps running through the entire list unnecessarily. Typically when doing a search you only want the first item you find anyway, but the above implementation returns the last item found if more than one match is made. Do you remember how the reduce function passes in the table as the last argument to the reducer function? We can take control of the iterator via the table and kill the iteration prematurely. This involves mutating the table:

local list = {23, 63, 12, 48, 3}

local find = function(list, predicate_fn)
  return reduce(list, function(acc, v, i, t)
    if predicate_fn(v, i, t) then
      -- If a result was found, destroy the next item in the list
      -- to prevent the iteration from going any further.
      t[i + 1] = nil
      return v
    end
    return acc
  -- Start from false so the first item is tested too
  end, false)
end

print(find(list, function(val)
  return val > 1
end))

This returns the correct result:

23

But if we loop over the table afterwards, we can see we've messed with the original data, which can lead to unexpected consequences in a real application. If your data comes from an immutable source, meaning something generates a new copy each time you use it, then this wouldn't be a problem:

local generate_list = function()
  return {23, 63, 12, 48, 3}
end

reduce(generate_list(), function()
  ...
  ...

However, we could fix all of this if we are willing to let the reducer return a second value and have our reduce implementation check it.

local function reduce(t, fn, first)
  local get_iterator = function(t)
    if type(t) == 'table' then
      -- If property of 1 is empty then
      -- iterate as a regular keyed table
      if t[1] == nil then
        return pairs(t)
      end
      return ipairs(t)
    end
    error('Expected table, got ' .. tostring(t))
  end
  local acc = first
  local starting_value = first ~= nil
  for i, v in get_iterator(t) do
    -- Exit the loop when true
    local should_break = false
    -- No starting value, start on
    -- the first element in the list
    if starting_value then
      acc, should_break = fn(acc, v, i, t)
      if should_break then
        break
      end
    else
      acc = v
      starting_value = true
    end
  end
  assert(
    starting_value,
    'Attempted to reduce an empty table with no first value.'
  )
  return acc
end

Now if the reducer returns true as a second return value, we get the first number we are looking for instead of the last. Loop through and print out the list afterward to make sure we haven't mutated it unexpectedly.

local list = {23, 63, 12, 48, 3}

local find = function(list, predicate_fn)
  return reduce(list, function(acc, v, i, t)
    if predicate_fn(v, i, t) then
      -- The second return value tells reduce to stop iterating
      return v, true
    end
    return acc
  end, false)
end

print(find(list, function(val)
  return val > 1
end))

for idx, val in ipairs(list) do
  print(idx, val)
end
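One way to confirm that the loop really stops is to count how often the reducer runs. This sketch uses a trimmed copy of the break-early reduce (ipairs only, without the keyed-table branch) and a calls counter that is added purely for illustration.

```lua
-- Trimmed copy of the break-early reduce (list tables only)
local function reduce(t, fn, first)
  local acc = first
  local starting_value = first ~= nil
  for i, v in ipairs(t) do
    if starting_value then
      local should_break
      acc, should_break = fn(acc, v, i, t)
      if should_break then
        break
      end
    else
      acc = v
      starting_value = true
    end
  end
  assert(starting_value, 'Attempted to reduce an empty table with no first value.')
  return acc
end

-- Counts how many times the reducer is invoked
local calls = 0
local find = function(list, predicate_fn)
  return reduce(list, function(acc, v, i, t)
    calls = calls + 1
    if predicate_fn(v, i, t) then
      return v, true
    end
    return acc
  end, false)
end

local list = {23, 63, 12, 48, 3}

local found = find(list, function(val) return val > 50 end)
local calls_when_found = calls
print(found)            -- 63
print(calls_when_found) -- 2

calls = 0
local missing = find(list, function(val) return val > 100 end)
local calls_when_missing = calls
print(missing)            -- false
print(calls_when_missing) -- 5
```

A match on the second item ends the loop after two calls, while a search with no match still visits all five items and returns the starting value of false.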

reduce_right

Another change you might want to make is to replace the iterator with a custom one that processes the data in a specific order or pattern. This reverse-ipairs (ripairs) implementation, taken from lua-users.org's Iteration Tutorial, lets you iterate over a table from right to left. A reduce modified this way is typically called reduce_right.

local function reduce_right(t, fn, first)
  local ripairs = function(t)
    local max = 1
    while t[max] ~= nil do
      max = max + 1
    end
    local function ripairs_it(t, i)
      i = i-1
      local v = t[i]
      if v ~= nil then
        return i,v
      else
        return nil
      end
    end
    return ripairs_it, t, max
  end
  local acc = first
  local starting_value = first ~= nil
  for i, v in ripairs(t) do
    -- Exit the loop when true
    local should_break = false
    -- No starting value, start on
    -- the first element in the list
    if starting_value then
      acc, should_break = fn(acc, v, i, t)
      if should_break then
        break
      end
    else
      acc = v
      starting_value = true
    end
  end
  assert(
    starting_value,
    'Attempted to reduce an empty table with no first value.'
  )
  return acc
end

Then swap out reduce for reduce_right in the places you want to use it:

local list = {23, 63, 12, 48, 3}

local find = function(list, predicate_fn)
  return reduce_right(list, function(acc, v, i, t)
    if predicate_fn(v, i, t) then
      -- Stop at the first match, counting from the right
      return v, true
    end
    return acc
  end, false)
end

print(find(list, function(val)
  return val > 1
end))
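The direction only changes the result when the reducer cares about order. Below is a sketch with string concatenation, using a simplified reduce_right written as a counting-down numeric for-loop. That simplification is an assumption of this example: it expects a list without holes and leaves out break-early support.

```lua
-- Simplified reduce_right: a numeric for-loop counting down from #t.
-- Assumes t is a list without holes; no break-early support.
local function reduce_right(t, fn, first)
  local acc = first
  local starting_value = first ~= nil
  for i = #t, 1, -1 do
    if starting_value then
      acc = fn(acc, t[i], i, t)
    else
      -- No starting value, start on the last element in the list
      acc = t[i]
      starting_value = true
    end
  end
  assert(starting_value, 'Attempted to reduce an empty table with no first value.')
  return acc
end

local letters = {'a', 'b', 'c', 'd'}
local join = function(acc, letter)
  return acc .. letter
end

local with_start = reduce_right(letters, join, '')
local without_start = reduce_right(letters, join)
print(with_start)    -- dcba
print(without_start) -- dcba
```

A left-to-right reduce with the same join function would produce abcd instead. For the find example above, the first match from the right is 3.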

Recursive

Since we talked about recursion in the last section, let's try a recursive implementation of reduce. With Lua there's no practical reason to choose a recursive implementation over a for-loop or while-loop implementation, but recursion is fun.

local function reduce(t, fn, acc, key)
  -- Check for starting value
  if key == nil and acc == nil then
    key = next(t, key)
    acc = t[key]
  end
  -- Begin next iteration. Next is a Lua built-in function
  -- that fetches the next key in a table after the given key.
  -- See: https://www.lua.org/pil/7.3.html
  key = next(t, key)
  -- Return acc if we've iterated all keys
  if key == nil then
    return acc
  end
  local break_early = false
  -- Collect new accumulator from the reducer function
  acc, break_early = fn(acc, t[key], key, t)
  -- Check to see if the reducer wants to end early
  if break_early then
    return acc
  end
  -- Recur
  return reduce(t, fn, acc, key)
end


-- Test it by getting the total sum from a table like before
local list = {
  monday = 23,
  tuesday = 63,
  wednesday = 12,
  thursday = 48,
  friday = 3
}
local total_sum = reduce(list, function(acc, current_number, key)
  print(key .. ': ' .. current_number)
  return acc + current_number
end, 0)

print('total sum: ' .. total_sum)

This supports breaking early like the two previous implementations.
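Here is a sketch of breaking early with the recursive version. Because next does not promise any particular key order, the example asks a question whose answer does not depend on order (is any value above 50?), and the calls counter is an addition for illustration.

```lua
-- Recursive reduce as described above, repeated so this snippet runs on its own
local function reduce(t, fn, acc, key)
  -- Check for starting value
  if key == nil and acc == nil then
    key = next(t, key)
    acc = t[key]
  end
  key = next(t, key)
  -- Return acc if we've iterated all keys
  if key == nil then
    return acc
  end
  local break_early = false
  acc, break_early = fn(acc, t[key], key, t)
  if break_early then
    return acc
  end
  return reduce(t, fn, acc, key)
end

local list = {
  monday = 23,
  tuesday = 63,
  wednesday = 12,
  thursday = 48,
  friday = 3
}

local calls = 0
local has_big_day = reduce(list, function(acc, current_number)
  calls = calls + 1
  if current_number > 50 then
    -- Found one: return the answer and ask reduce to stop
    return true, true
  end
  return acc
end, false)

print(has_big_day) -- true
print(calls <= 5)  -- true
```

How many calls it actually takes depends on when tuesday is visited, so the only safe claim is that it never needs more than five.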

Exercises

  • Create a count function that counts up the number of items in a list that match the predicate and returns the total. It should work like this:

    local count = function(list, predicate_fn)
      ????
    end
    
    local list = {23, 63, 12, 48, 3}
    -- Print number of items evenly divisible by 3 (should return 4)
    print(count(list, function(v)
      return v % 3 == 0
    end))
    
  • Go back to the map section in 3.3 and see if you can reimplement the map function on top of reduce.
