Solving the 0/1 Knapsack Problem with Dynamic Programming

In the last section we saw that our initial naive approach didn't lend itself to further optimization. Let's start from the beginning and think more carefully about the problem to come up with a better approach that is more amenable to optimization.

Thinking About the Problem

It's always a good idea to ask some questions that will allow us to understand the problem better.

  • Do we need to iterate over every item? Put another way, would a greedy algorithm work here? It wouldn't, so we do need to consider every item. To show this, all we need is a counter-example. A simple one is values = [60, 100, 120] and weights = [10, 20, 30]. If the knapsack has a capacity of 50, a greedy algorithm that picks items by highest value per unit weight (6, 5, and 4 respectively) would take the first two items for a total value of 160, but the optimal solution takes the last two items for a total value of 220.

  • Do we care about the items themselves? Recall that in our naive solution we generated all the combinations of items and carried them around until the very end. But we don't actually care which items they are. We only care about the total value and the remaining capacity of the knapsack as we go through each item.

  • What is the decision process as we iterate through each item? Now that we've established that we will visit each item at least once, what do we actually do when we visit one? It's simple: we have two choices. We either take the item (put it in the knapsack) or we don't.

  • When do we stop? We stop when we've either run out of items or the capacity of our knapsack is at or below 0.
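To check the counter-example, here is a minimal sketch comparing a greedy pick by value per unit weight against simply trying every subset. Both helpers, greedy_by_ratio and exhaustive, are hypothetical names used only for illustration, and the greedy version assumes every weight is positive.

```python
from itertools import combinations


def greedy_by_ratio(values, weights, capacity):
    """Hypothetical greedy: take items by decreasing value/weight while they fit."""
    order = sorted(range(len(values)), key=lambda i: values[i] / weights[i], reverse=True)
    total, remaining = 0, capacity
    for i in order:
        if weights[i] <= remaining:
            total += values[i]
            remaining -= weights[i]
    return total


def exhaustive(values, weights, capacity):
    """Hypothetical reference: try every subset (exponential, fine for 3 items)."""
    best = 0
    n = len(values)
    for r in range(n + 1):
        for subset in combinations(range(n), r):
            if sum(weights[i] for i in subset) <= capacity:
                best = max(best, sum(values[i] for i in subset))
    return best


values, weights, capacity = [60, 100, 120], [10, 20, 30], 50
print(greedy_by_ratio(values, weights, capacity))  # 160
print(exhaustive(values, weights, capacity))       # 220
```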

So we have a terminating condition and a decision process for getting to the next step. Does this approach sound familiar? I hope so, it's recursion.

Defining the Recursion Relation

Let's define m[i, c] to be the maximum value we can obtain from items 0 through i with a knapsack of capacity c. Let's also denote the value and weight of item i by v[i] and w[i], respectively. The base case is m[-1, c] = 0: if we don't have any items to consider, our max value is 0. If w[i] > c, then m[i, c] = m[i-1, c]. That is, item i does not fit, so we cannot include it in the knapsack, and our max value is the same as with items 0 through i-1. On the other hand, if w[i] ≤ c, then m[i, c] = max(m[i-1, c], m[i-1, c-w[i]] + v[i]). Put another way, we compare the best value we can get without item i against the best value we can get from the earlier items with the reduced capacity c-w[i], plus v[i]. The greater of those two values determines whether we take item i or not.
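Written compactly, and folding in the capacity check c ≤ 0 from our stopping condition, the recursion relation reads:

$$
m[i, c] =
\begin{cases}
0 & \text{if } i < 0 \text{ or } c \le 0, \\
m[i-1, c] & \text{if } w[i] > c, \\
\max\bigl(m[i-1, c],\ m[i-1, c - w[i]] + v[i]\bigr) & \text{otherwise.}
\end{cases}
$$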


A First Solution

Now that we have defined the recursion relation, let's translate that to code. Here is a candidate solution:

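def knapsack_01_brute_recursive(the_values, the_weights, the_capacity):
    def aux(n, values, weights, capacity):
        # Out of capacity or out of items: no more value to gain.
        if capacity <= 0 or n < 0:
            return 0
        # Best value if we skip item n.
        q = aux(n - 1, values, weights, capacity)
        # If item n fits, compare with the best value if we take it.
        if weights[n] <= capacity:
            q = max(q, aux(n - 1, values, weights, capacity - weights[n]) + values[n])
        return q

    # Start from the last item and work backwards to item 0.
    return aux(len(the_values) - 1, the_values, the_weights, the_capacity)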

Here I defined the function that gets called recursively inside another function. It's simply a way of keeping things organized in my IDE. If you decide to go this route, you might get asked about it in an interview: the nested function is re-created every time the outer function is called, which adds a small overhead, so be prepared to talk about it. Next, it's generally not a good idea to shadow outer variables in a nested function, which is why I prefixed the outer variables with the_. Finally, I iterate over the items in reverse instead of from 0 forward. Doing so helps me visualize the recursion from a top-down perspective, but it's trivial to switch around: just start at n = 0, recurse with n + 1 instead of n - 1, and change the n < 0 check to n > len(values) - 1.
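For comparison, here is a sketch of the forward-iterating variant described above. The name knapsack_01_forward_recursive is hypothetical, and unlike the original, this version reads values and weights from the enclosing scope instead of passing them down, which also sidesteps the shadowing issue.

```python
def knapsack_01_forward_recursive(values, weights, capacity):
    # Same recursion as before, but item indices run 0, 1, ..., len(values) - 1.
    def aux(n, cap):
        if cap <= 0 or n > len(values) - 1:  # out of capacity or out of items
            return 0
        q = aux(n + 1, cap)                   # skip item n
        if weights[n] <= cap:                 # item n fits: maybe take it
            q = max(q, aux(n + 1, cap - weights[n]) + values[n])
        return q

    return aux(0, capacity)


print(knapsack_01_forward_recursive([60, 100, 120], [10, 20, 30], 50))  # 220
```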

So let's see how knapsack_01_brute_recursive implements our recursion relation. First, we check whether we have run out of capacity or items; if so, we cannot gain any more value, so we return 0. Next, we need the total value without including the item we are currently visiting, item n. This is aux(n - 1, values, weights, capacity), which we store in q. We compute this first because if item n does not fit, we simply return q. If it does fit, we need to check which gives us more value: q (we don't take the item) or aux(n - 1, values, weights, capacity - weights[n]) + values[n] (we take the item). When we take the item, our capacity goes from capacity to capacity - weights[n]. In a certain sense, every time a decision is made for a given item, we have to backtrack to see how that decision affects the decisions made for the other items.

The time complexity of this solution is also O(2^n), because each call can make up to two recursive calls and the recursion can go n levels deep. The space complexity is O(n): the recursion goes depth first, so the call stack never holds more than n + 1 calls at once.
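One way to see the exponential growth is to count calls. The sketch below assumes a hypothetical worst case where every weight is 1 and the capacity exceeds the number of items, so the capacity check never stops the recursion early and every call on an item branches twice. The helper count_calls is a hypothetical name.

```python
def count_calls(num_items):
    # Hypothetical worst case: every weight is 1 and the capacity exceeds the
    # number of items, so neither the capacity check nor the fit check ever
    # prunes a branch. Values don't affect the call tree, so they are omitted.
    weights = [1] * num_items
    capacity = num_items + 1
    calls = 0

    def aux(n, cap):
        nonlocal calls
        calls += 1
        if cap <= 0 or n < 0:
            return
        aux(n - 1, cap)                   # skip item n
        if weights[n] <= cap:
            aux(n - 1, cap - weights[n])  # take item n

    aux(num_items - 1, capacity)
    return calls


print([count_calls(k) for k in range(1, 6)])  # [3, 7, 15, 31, 63]
```

With k items this comes out to 2^(k+1) - 1 calls.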

You should ingrain this recursive solution into your head. Just about every dynamic programming problem can be approached in this manner: write down the recursion first, then optimize it.

A First Pass at Optimization

I am going to assume that you know what memoization, top-down, and bottom-up programming mean. If you don't, you should read up on those concepts either from the internet, or one of the linked textbooks at the beginning of this book.

An easy way to optimize our simple recursive solution is to use memoization. Note that this is different from memorization. Both derive from Latin and are cognates, but memoization was derived from "memorandum" which is shortened to "memo" in English, hence memoization.

The whole point of memoization is to cache the results of subproblems along the way and then look them up when they are about to be recomputed. Let's consider a quick example where weights = [10, 20, 30] and capacity = 50. We don't care about values right now as I just want to show what the tree of recursive calls looks like.

                                          c:50, n:2
                                          +       +
                                          |       |
                                          |       |
                                          |       |
   +------------+ c:20, n:1  <------------+       +----------->  c:50, n:1 +------------------+
   |                    +                                         +                           |
   |                    |                                         |                           |
   |                    |                                      +--+                           |
   v                    v                                      v                              v
c:0, n:0         ++ c:20, n:0 ++                      +-+ c:30, n:0 +----+               c:50, n:0 +--------+
                 |             |                      |                  |                 +                |
                 |             |                      |                  |                 |                |
                 v             v                      v                  v                 v                v
             c:10, n:-1     c:20, n:-1               c:20, n:-1     c:30, n:-1        c:40, n:-1         c:50, n:-1

Note that this example is deliberately small, since drawing the tree for a larger one quickly gets tedious. We can see that c:20, n:-1 gets called twice. We have n:-1 because we start at n=2 and proceed in reverse. Granted, this isn't the greatest example of the time we save, because n:-1 is a terminating condition, so the actual work done in those calls is minimal.
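We can confirm the repeated call by instrumenting the recursion. This is a minimal sketch with a hypothetical helper, count_states, that mirrors aux but drops values (they don't affect which calls are made) and records every (n, capacity) pair it visits.

```python
from collections import Counter


def count_states(weights, capacity):
    """Count visits to each (n, capacity) state in the plain recursion."""
    visits = Counter()

    def aux(n, cap):
        visits[(n, cap)] += 1
        if cap <= 0 or n < 0:
            return
        aux(n - 1, cap)                   # skip item n
        if weights[n] <= cap:
            aux(n - 1, cap - weights[n])  # take item n

    aux(len(weights) - 1, capacity)
    return visits


visits = count_states([10, 20, 30], 50)
print({state: k for state, k in visits.items() if k > 1})  # {(-1, 20): 2}
print(sum(visits.values()))                                # 13 calls in total
```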

Memoization is essentially caching the output of a function based on its input, so we need to know which inputs vary as we go through the items. The changing inputs are n and capacity; values and weights never change. If we didn't know the types of the inputs, we would store the results in a dictionary (a hash map in other languages) keyed on those inputs and be done with it. Here, though, we know that n and capacity are both non-negative integers once we are past the termination check, so we can use a list of lists (a 2D array) indexed by them.

The strategy is simple: after we check the termination condition, we look up memo[n][capacity] to see if we have already computed that combination. If we have, we return the cached value; otherwise, we proceed as before. Then, just before returning, we store the value in the memo for future lookups.

def knapsack_01_brute_recursive_memoized(the_values, the_weights, the_capacity):
    # memo[n][c] caches the best value for items 0..n with capacity c (None = not computed yet).
    the_memo = [[None] * (the_capacity + 1) for _ in range(len(the_values))]

    def aux(n, values, weights, capacity, memo):
        if capacity <= 0 or n < 0:
            return 0
        # Return the cached result if this (n, capacity) state was already solved.
        val_get = memo[n][capacity]
        if val_get is not None:
            return val_get
        q = aux(n - 1, values, weights, capacity, memo)
        if weights[n] <= capacity:
            q = max(q, aux(n - 1, values, weights, capacity - weights[n], memo) + values[n])
        # Cache the result before returning it.
        memo[n][capacity] = q
        return q

    return aux(len(the_values) - 1, the_values, the_weights, the_capacity, the_memo)

In Python, the is operator tests identity, as opposed to equality. Since None is a singleton object, it is idiomatic to test for it by identity rather than equality. Also, a class can implement __eq__ in a way that makes a comparison like obj == None evaluate to True even when obj is not None.
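A tiny, deliberately pathological class shows the __eq__ pitfall; AlwaysEqual is a hypothetical name used only for this demonstration.

```python
class AlwaysEqual:
    # Deliberately pathological: claims to be equal to anything, including None.
    def __eq__(self, other):
        return True


x = AlwaysEqual()
print(x == None)  # True, because __eq__ says so
print(x is None)  # False, x is not the None singleton
```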

As we can see, adding memoization is quite straightforward and doesn't change the solution much. In fact, we could use functools.lru_cache from the standard library to memoize automatically with a decorator.
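A minimal sketch of the lru_cache version follows, assuming we keep the same recursion; the function name is hypothetical. Because lists are not hashable, values and weights are captured by the closure, and only n and cap, the state that changes, are passed as arguments and used as the cache key.

```python
from functools import lru_cache


def knapsack_01_lru_cache(values, weights, capacity):
    # Cache keyed on (n, cap) only; values/weights come from the closure.
    @lru_cache(maxsize=None)
    def aux(n, cap):
        if cap <= 0 or n < 0:
            return 0
        q = aux(n - 1, cap)  # skip item n
        if weights[n] <= cap:
            q = max(q, aux(n - 1, cap - weights[n]) + values[n])  # take item n
        return q

    return aux(len(values) - 1, capacity)


print(knapsack_01_lru_cache([60, 100, 120], [10, 20, 30], 50))  # 220
```

Defining the cached function inside the outer function gives each top-level call a fresh cache, so results computed for one set of items never leak into another.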

Converting to Bottom-Up

Why is it even called top-down or bottom-up? In the bottom-up approach, which usually involves tabulation, we start at the bottom state n=0 and solve subproblems until we reach the desired n. Tabulation is exactly what it sounds like: literally filling out a table. In our case, we will represent that table as a list of lists. In the top-down approach, which usually involves memoization, we start at the desired state n and solve only the subproblems needed to reach it.

So how do we go from the top-down to the bottom-up approach? Instead of a memo cache, we will have a table, which we'll call dp[n][c]; dp seems to have settled in as the accepted name for the bottom-up table on the internet. I call the top-down cache memo[n][c] to differentiate the two. An entry dp[n][c] holds the maximum value achievable with capacity c using items 0 through n, exactly as memo[n][c] does in the top-down approach. We start at n=0 and visit every item, but we also have to start at capacity 0: since we don't know in advance which states lie on the path to the final answer, we have to fill in the table for every capacity. In the recursive solution, if we didn't take the item, our value was aux(n-1, values, weights, capacity); if we did take it, the value was aux(n-1, values, weights, capacity - weights[n]) + values[n]. In the bottom-up approach, those values are already in dp, so the analogs are dp[n-1][capacity] and dp[n-1][capacity - weights[n]] + values[n].

def knapsack_01_bottomup_naive(values, weights, capacity):
    val_len = len(values)
    dp = [[values[0] if (item_idx == 0 and weights[0] <= cap) else 0 for cap in range(capacity + 1)] for item_idx in
          range(val_len)]
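    for n in range(1, val_len):
        for cap in range(1, capacity + 1):
            val_take, val_no_take = 0, 0
            # Take item n only if it fits: best value at the reduced capacity plus its value.
            if weights[n] <= cap:
                val_take = dp[n - 1][cap - weights[n]] + values[n]
            # Skip item n: best value at the same capacity using the previous items.
            val_no_take = dp[n - 1][cap]
            dp[n][cap] = max(val_take, val_no_take)
    return dp[val_len - 1][capacity]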

This is fairly straightforward, but let's quickly go over this line:

dp = [[values[0] if (item_idx == 0 and weights[0] <= cap) else 0 for cap in range(capacity + 1)] for item_idx in
      range(val_len)]

Normally, you would initialize every value in the table to 0. Here, though, the loop starts at n = 1, so we have to fill in the first row ahead of time. For the first row, i.e., item 0, if the capacity can hold the item's weight, we always take it; otherwise the value is 0. Similarly, every entry in the first column, where capacity = 0, is 0, because with no capacity we can't take any item. Since all the other entries are initialized to 0, we don't need to do anything special for that column. In other problems the first row and column may not be zero, so don't forget to check them explicitly; getting these base cases wrong is a common source of errors in these kinds of problems.
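An alternative way to handle the base case, sketched below under a hypothetical name, is to add an extra row 0 that stands for "no items considered yet", the table analog of m[-1, c] = 0. Row n + 1 then corresponds to item n, every entry can start at 0, and an empty list of items needs no special handling.

```python
def knapsack_01_bottomup_sentinel(values, weights, capacity):
    num_items = len(values)
    # Row 0 means "no items considered yet" and stays all zeros; row n + 1 holds item n.
    dp = [[0] * (capacity + 1) for _ in range(num_items + 1)]
    for n in range(num_items):
        for cap in range(capacity + 1):
            dp[n + 1][cap] = dp[n][cap]  # skip item n
            if weights[n] <= cap:        # item n fits: maybe take it
                dp[n + 1][cap] = max(dp[n + 1][cap], dp[n][cap - weights[n]] + values[n])
    return dp[num_items][capacity]


print(knapsack_01_bottomup_sentinel([60, 100, 120], [10, 20, 30], 50))  # 220
```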

Let N be the number of items and W = capacity. Then both the space and the time complexity of the bottom-up solution are O(NW). This is referred to as pseudo-polynomial time. For a great explanation see this SO post. The short explanation is that time complexity is formally measured in terms of the number of bits needed to encode the input. If you are sorting an array of n 32-bit integers, then the size of the input is 32n bits. Here, though, a number, W, is itself an input, and it takes only about log₂ W bits to store it. Adding a single bit roughly doubles the largest number those bits can represent. For example, reading the least significant bit first, 101 is 5, and adding a more significant bit gives us 1011, which is 13. If we let s = log₂ W, then our time complexity is actually O(N·2^s), since 2^s = 2^(log₂ W) = W. That is exponential in the size of the input, even though it is polynomial in the numeric value of W.
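To make the doubling concrete, the sketch below uses a hypothetical helper, table_cells, that counts the cells of an N × (W + 1) table, and grows W one bit at a time with N fixed at 10. It also checks the least-significant-bit-first example by reversing each string before parsing it with int(..., 2).

```python
def table_cells(num_items, capacity):
    # Hypothetical helper: cells in a table with one row per item and
    # capacity + 1 columns, i.e. the O(NW) work of the bottom-up solution.
    return num_items * (capacity + 1)


for bits in range(1, 6):
    W = 2 ** bits - 1  # largest capacity representable in `bits` bits
    print(bits, W.bit_length(), table_cells(10, W))
# bits 1..5 -> 20, 40, 80, 160, 320 cells: one extra bit doubles the work

# The least-significant-bit-first example from the text:
print(int("101"[::-1], 2), int("1011"[::-1], 2))  # 5 13
```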

Optimizing the Bottom-Up solution

The first immediate optimization is the following: to fill in row n, we only need row n-1 of the table. Keeping just those two rows reduces our space complexity to O(W). As it turns out, we can go further and use a single row, thanks to the following observation: the only entries of the previous row we read are the one directly above, dp[n-1][cap], and one to its left, dp[n-1][cap - weights[n]].
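The two-row version described above might look like the following sketch; the name knapsack_01_bottomup_two_rows is hypothetical. Only the previous row, prev, and the row being filled, curr, are kept in memory.

```python
def knapsack_01_bottomup_two_rows(values, weights, capacity):
    prev = [0] * (capacity + 1)      # best values using the items before item n
    for n in range(len(values)):
        curr = [0] * (capacity + 1)  # best values using items 0..n
        for cap in range(capacity + 1):
            curr[cap] = prev[cap]    # skip item n
            if weights[n] <= cap:    # item n fits: maybe take it
                curr[cap] = max(curr[cap], prev[cap - weights[n]] + values[n])
        prev = curr                  # row n becomes the "previous" row
    return prev[capacity]


print(knapsack_01_bottomup_two_rows([60, 100, 120], [10, 20, 30], 50))  # 220
```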

How do we get away with a single row? We cannot iterate from left to right, because we would overwrite entries on the left that we still need to read later in the same pass. If we iterate from right to left, every entry we read at cap - weights[n] still holds the previous row's value. This leads us to this very succinct solution:

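def knapsack_01_bottomup_optimized_succinct(values, weights, capacity):
    # vals[cap] is the best value achievable with capacity cap using the items seen so far.
    vals = [0] * (capacity + 1)
    for n in range(len(values)):
        # Go right to left so vals[cap - weights[n]] still holds the previous row's value.
        for cap in range(capacity, weights[n] - 1, -1):
            vals[cap] = max(vals[cap], vals[cap - weights[n]] + values[n])
    return vals[capacity]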

The expression range(capacity, weights[n] - 1, -1) means: start at capacity (inclusive) and count down to weights[n] - 1 (exclusive) in steps of -1. Let's look at a concrete example in the Python console:

>>> list(range(5,0,-1))
[5, 4, 3, 2, 1]

which is exactly what we want. We don't need to consider cap < weights[n], because item n doesn't fit, so vals[cap] simply keeps the value it had for the previous items.
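To see why the iteration direction matters, the sketch below (with a hypothetical one_row helper) runs the single-row update both ways on a knapsack with one item of value 60 and weight 10. Going left to right lets the item be counted again at cap = 20, 30, and so on, which effectively solves the unbounded knapsack, where items can be reused, instead of the 0/1 problem.

```python
def one_row(values, weights, capacity, right_to_left):
    vals = [0] * (capacity + 1)
    for n in range(len(values)):
        caps = range(weights[n], capacity + 1)
        if right_to_left:
            caps = reversed(caps)  # capacity down to weights[n], as in the text
        for cap in caps:
            vals[cap] = max(vals[cap], vals[cap - weights[n]] + values[n])
    return vals[capacity]


print(one_row([60], [10], 50, right_to_left=True))   # 60: the item is used once
print(one_row([60], [10], 50, right_to_left=False))  # 300: the item is reused 5 times
```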

Top-down vs Bottom-up

We've seen both approaches, but which one is "better"? Like a lot of things, there is no clear answer, but we can list some general advice.

  • Strive for the bottom-up approach in an interview. If you can't get there, apply memoization to the recursive solution. If you can't get the recursive solution, then a brute-force approach, like the one at the beginning of the chapter, is the bare minimum.
  • Top-down + memoization can lead to fewer subproblems being solved because it solves only those that are needed to reach the nth state.
  • If all subproblems must be solved at least once then bottom-up will usually be faster.
  • In bottom-up approaches, the table is filled in one by one, whereas with top-down approaches, the memo is filled in as needed.
  • Bottom-up is often perceived as faster than top-down because of the function-call overhead of the latter.
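To illustrate the point about fewer subproblems, this sketch reuses the memoized recursion under a hypothetical name and counts how many memo entries actually get filled for the earlier example with weights = [10, 20, 30] and capacity = 50, compared with the size of the full table.

```python
def memo_entries_filled(values, weights, capacity):
    memo = [[None] * (capacity + 1) for _ in range(len(values))]

    def aux(n, cap):
        if cap <= 0 or n < 0:
            return 0
        if memo[n][cap] is not None:
            return memo[n][cap]
        q = aux(n - 1, cap)  # skip item n
        if weights[n] <= cap:
            q = max(q, aux(n - 1, cap - weights[n]) + values[n])  # take item n
        memo[n][cap] = q
        return q

    aux(len(values) - 1, capacity)
    filled = sum(entry is not None for row in memo for entry in row)
    return filled, len(memo) * (capacity + 1)  # (entries solved, table size)


print(memo_entries_filled([60, 100, 120], [10, 20, 30], 50))  # (6, 153)
```

Top-down touches only 6 states here, while a full bottom-up table has 3 × 51 = 153 cells.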

Discussion

In this section we solved the 0/1 knapsack problem in the canonical way. We started with a recursive solution, which is usually referred to as the top-down approach. We applied memoization so that repeated subproblems are solved only once. We then converted to a bottom-up approach and optimized it as far as we reasonably could.

In the next section we will take what we learned and look at what kinds of problems are suited to dynamic programming, and how to tell from a problem description whether dynamic programming is the right tool.