Learning to play Atari games

In 2015 several researchers from DeepMind published a paper in Nature in which they described building a reinforcement learning system with a neural network that could spontaneously learn to play Atari video games expertly. In these notes I will explain how they did this, and show some example code to implement their system in keras.

Here is a link to a YouTube video showing the system learning to play Breakout.

The Q table

In reinforcement learning, the Q table is a table that records the expected future reward Q(s,a) for each combination of a state s and an action a that the agent can take from that state. In cases where the available states in a system are finite and discrete, these values truly can be stored in a table whose rows are the possible states s of the system and whose columns are the possible actions a that an agent can perform.

The entries in the Q table satisfy the Bellman equation:

$$Q(s,a) = r(s,a) + \gamma V(s')$$

where r(s,a) is the immediate reward for executing action a from state s, s' is the state that action leads to, and V(s') is the maximum future reward obtainable from state s':

$$V(s') = \max_{a} Q(s',a)$$
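To make the two equations concrete, here is a minimal sketch that stores a Q table as a list of rows (one row per state, one column per action) and evaluates the right-hand side of the Bellman equation for one entry. The table values and the helper names `state_value` and `bellman_target` are made up for illustration.

```python
# A tiny, made-up Q table: rows are states, columns are actions.
Q = [
    [0.0, 0.5],  # state 0
    [0.2, 0.9],  # state 1
    [0.0, 0.0],  # state 2
]

def state_value(Q, s_next):
    """V(s') = max over actions a of Q(s', a)."""
    return max(Q[s_next])

def bellman_target(Q, reward, s_next, gamma):
    """Right-hand side of the Bellman equation: r(s,a) + gamma * V(s')."""
    return reward + gamma * state_value(Q, s_next)

# Suppose some action taken in state 0 earns reward 1.0 and leads to state 1.
print(bellman_target(Q, reward=1.0, s_next=1, gamma=0.9))
```

Here V(1) is 0.9 (the larger entry in row 1), so the target works out to 1.0 + 0.9 × 0.9 = 1.81.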

The remaining ingredient in these equations is the discount factor γ, a number between 0 and 1. Each time we take a step back in time from a point where we can receive a reward r, we multiply that future reward by γ, so a reward k steps in the future is worth γ^k r now. This biases the system toward gaining rewards sooner rather than later.
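The effect of γ is easy to check numerically. This sketch (the function name `discounted_return` is just for illustration) adds up a sequence of rewards, scaling the reward that arrives k steps from now by γ^k, which is what repeated one-step discounting produces.

```python
def discounted_return(rewards, gamma):
    """Sum of rewards, where the reward received k steps from now is scaled by gamma**k."""
    total = 0.0
    for k, r in enumerate(rewards):
        total += (gamma ** k) * r
    return total

# The same reward of 1.0 is worth less the later it arrives.
soon = discounted_return([1.0, 0.0, 0.0], gamma=0.9)  # reward right away
late = discounted_return([0.0, 0.0, 1.0], gamma=0.9)  # reward two steps later
print(soon, late)
```

With γ = 0.9, a reward two steps away is worth about 0.81 of the same reward received immediately, so an agent maximizing this sum prefers the earlier reward.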

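These are the equations that the entries of the Q table have to satisfy, but unfortunately they do not tell us how to actually compute the values in the table. For this we typically use an iterative update strategy: we start with an initial guess for every entry (for example, all zeros), and then, round after round, we take an action a from a state s, observe the reward r(s,a) and the next state s', and replace Q(s,a) with r(s,a) + γ V(s'), where V(s') is computed from the current values in the table.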

This plan has theoretical support: provided γ is less than 1 and every state-action pair keeps being updated, carrying out these updates for enough rounds makes the values across the whole table converge to the correct values.
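Here is a minimal sketch of the iterative update on a hypothetical four-state corridor (states 0 to 3, actions left and right, reward 1 for stepping into state 3, which ends the episode). The environment and all names are invented for illustration. This version sweeps over every state-action pair each round, which assumes we can query the environment for any state and action; an agent that only learns from the transitions it actually experiences applies the same update one step at a time.

```python
# Hypothetical toy environment: states 0..3 in a row, state 3 is the goal.
N_STATES = 4
GOAL = 3
LEFT, RIGHT = 0, 1

def step(s, a):
    """Return (reward, next_state) for taking action a in state s."""
    s_next = max(s - 1, 0) if a == LEFT else min(s + 1, GOAL)
    reward = 1.0 if s_next == GOAL else 0.0
    return reward, s_next

def iterate_q_table(gamma=0.9, rounds=50):
    # Initial guess: every entry of the Q table is zero.
    Q = [[0.0, 0.0] for _ in range(N_STATES)]
    for _ in range(rounds):
        for s in range(N_STATES):
            if s == GOAL:
                continue  # the goal ends the episode: no actions, no future reward
            for a in (LEFT, RIGHT):
                r, s_next = step(s, a)
                # V(s') is the best current Q value in the next state (0 at the goal).
                v_next = 0.0 if s_next == GOAL else max(Q[s_next])
                Q[s][a] = r + gamma * v_next
    return Q

Q = iterate_q_table()
for s, row in enumerate(Q):
    print(s, [round(q, 3) for q in row])
```

After a few rounds the table stops changing: moving right is worth 1, 0.9 and 0.81 from states 2, 1 and 0, which is the goal reward discounted once per extra step.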

The Q function and the Q network

The problem with the theory I set out in the previous section is that it rests on an assumption that is sometimes false: that the system we are working with has a finite set of discrete states. Some of the environments we would like to operate in violate that assumption, either because the environment does not have discrete states or because it has so many discrete states that the Q table becomes unworkably large. In those cases we can instead try to construct a Q function that computes Q(s,a) directly from s and a. A Q function still has to satisfy the Bellman equation.

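If we decide to replace the Q table with a Q function, it is only a small jump from there to building a Q network: a neural network whose task is to compute that function. (This sort of network is usually referred to as a deep-Q network.) We can even train that network to learn the Q function by using an update scheme very similar to the one for the table: after taking action a from state s and observing r(s,a) and s', we use the network itself to estimate V(s'), compute the target r(s,a) + γ V(s'), and train the network so that its output for Q(s,a) moves toward that target.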
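The following sketch shows that update with a parameterized Q function. To keep it dependency-free, a linear function of hand-made state features stands in for the neural network (an assumption for illustration; the Atari system uses a deep convolutional network built in Keras). The corridor environment and every name here are hypothetical. Each update computes r(s,a) + γ V(s') from the model's own current estimates and takes one gradient step that moves Q(s,a) toward that target.

```python
import random

# Hypothetical four-state corridor; state 3 is the goal and ends the episode.
N_STATES, GOAL = 4, 3
LEFT, RIGHT = 0, 1
ACTIONS = (LEFT, RIGHT)

def features(s):
    """One-hot encoding of the state (a stand-in for screen images)."""
    return [1.0 if i == s else 0.0 for i in range(N_STATES)]

def step(s, a):
    s_next = max(s - 1, 0) if a == LEFT else min(s + 1, GOAL)
    return (1.0 if s_next == GOAL else 0.0), s_next

class LinearQ:
    """Q(s, a) = w[a] . features(s): a linear stand-in for a deep-Q network."""
    def __init__(self, n_features, n_actions):
        self.w = [[0.0] * n_features for _ in range(n_actions)]

    def q_values(self, s):
        x = features(s)
        return [sum(wi * xi for wi, xi in zip(w_a, x)) for w_a in self.w]

    def train_step(self, s, a, target, lr):
        # One gradient step on (target - Q(s,a))**2 / 2, touching only action a's weights.
        error = target - self.q_values(s)[a]
        x = features(s)
        self.w[a] = [wi + lr * error * xi for wi, xi in zip(self.w[a], x)]

def train(gamma=0.9, lr=0.1, updates=5000, seed=0):
    rng = random.Random(seed)
    model = LinearQ(N_STATES, len(ACTIONS))
    for _ in range(updates):
        s = rng.randrange(GOAL)  # any non-goal state
        a = rng.choice(ACTIONS)
        r, s_next = step(s, a)
        # The target uses the model's own current estimate of V(s').
        v_next = 0.0 if s_next == GOAL else max(model.q_values(s_next))
        model.train_step(s, a, r + gamma * v_next, lr)
    return model

model = train()
for s in range(GOAL):
    print(s, [round(q, 3) for q in model.q_values(s)])
```

With one-hot features this reduces to the table update with a learning rate; richer features (or a network) let nearby states share what is learned.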

The next thing we need to think about is choosing actions: until we have a strategy for this, we can't operate in the environment. Fortunately, this is where the Q network can help us. When we arrive at a state s and have to select the next action, we can use the Q network to compute Q(s,a) for every possible action a and then select the action with the largest Q value.
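Picking the action with the largest Q value is an argmax over the network's outputs. In this small sketch the list of Q values is made up; in practice it would come from one forward pass of the Q network on state s.

```python
def greedy_action(q_values):
    """Return the index of the action with the largest Q value (first one on ties)."""
    return max(range(len(q_values)), key=lambda a: q_values[a])

# Q(s, a) for four possible actions in some state s (made-up numbers).
print(greedy_action([0.1, 0.7, -0.2, 0.4]))  # prints 1
```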

Another strategy that we use to supplement this plan is the ε-greedy strategy. In this strategy we set up a parameter ε that starts out close to 1. On each round we pick a real number at random from the interval (0,1): if that number is larger than ε we use the Q network to pick our next action, otherwise we select our next action at random. After each round we decrease ε slightly, so that over time the Q network determines our actions more often as it learns the correct Q values. This gives the system the opportunity to explore the state space early on and eventually lean more and more heavily on the network to control our actions.
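A minimal sketch of ε-greedy selection with a decaying ε follows. The linear decay, the step size 0.001 and the floor of 0.1 are illustrative assumptions, not values taken from the DeepMind paper or the Keras example.

```python
import random

def epsilon_greedy(q_values, epsilon, rng=random):
    """Pick a random action with probability epsilon, else the greedy action."""
    if rng.random() < epsilon:
        # Explore: any action, chosen uniformly at random.
        return rng.randrange(len(q_values))
    # Exploit: the action with the largest Q value.
    return max(range(len(q_values)), key=lambda a: q_values[a])

def decay_epsilon(epsilon, decay=0.001, epsilon_min=0.1):
    """Lower epsilon slightly after each round, but never below epsilon_min."""
    return max(epsilon_min, epsilon - decay)

# Usage: epsilon starts at 1.0 and shrinks toward its floor as play continues.
rng = random.Random(0)
epsilon = 1.0
for _ in range(2000):
    action = epsilon_greedy([0.1, 0.7, -0.2, 0.4], epsilon, rng)
    epsilon = decay_epsilon(epsilon)
print(epsilon)
```

Keeping a small floor on ε means the agent never stops exploring entirely, which helps if the network's early Q estimates steer it away from useful states.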

Learning to play Atari games

The main example we are going to look at today uses a deep-Q network to train an agent to play the Atari game Breakout.

Here is some example code on the keras.io web site that illustrates how to construct and train a deep-Q network to play this game.

The state information that the system learns from is built from images of the video game screen. Specifically, each state is a stack of four consecutive frames. We need multiple frames because a single frame shows only where the ball and the paddle are located; comparing several frames lets the system determine which way, and how fast, the ball is moving.
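Frame stacking can be sketched with a fixed-length queue that always holds the most recent four frames. In this sketch, strings stand in for preprocessed screen images, and filling the stack with copies of the first frame at the start of an episode is an assumed convention; the class name `FrameStack` is hypothetical.

```python
from collections import deque

class FrameStack:
    """Keep the most recent k frames and expose them together as one state."""
    def __init__(self, k=4):
        self.frames = deque(maxlen=k)

    def reset(self, first_frame):
        # Start of an episode: fill the stack with copies of the first frame.
        for _ in range(self.frames.maxlen):
            self.frames.append(first_frame)
        return self.state()

    def push(self, frame):
        # Appending to a full deque automatically drops the oldest frame.
        self.frames.append(frame)
        return self.state()

    def state(self):
        return tuple(self.frames)

stack = FrameStack(k=4)
state = stack.reset("frame0")
for t in range(1, 6):
    state = stack.push(f"frame{t}")
print(state)  # ('frame2', 'frame3', 'frame4', 'frame5')
```

With real images, the four frames would typically be stacked along a channel axis so the network sees them as a single input.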

The code runs a series of episodes, where each episode is a complete game of Breakout. In each episode we choose moves with the ε-greedy strategy: some moves are random, and the rest come from using the network to estimate Q(s,a) for each possible action and selecting the action with the largest Q value. As we play the game, we store a record of the states we visited, the actions we selected, the rewards we obtained, and the states that followed. (This record is known as a replay buffer.)
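A replay buffer can be as simple as a bounded queue of transitions. This sketch assumes each record is a (state, action, reward, next_state, done) tuple, where `done` marks the end of a game; the class name and the default capacity are illustrative choices, not values from the Keras example.

```python
import random
from collections import deque

class ReplayBuffer:
    """Fixed-size store of (state, action, reward, next_state, done) transitions."""
    def __init__(self, capacity=100_000):
        # Once full, the oldest transitions are discarded first.
        self.buffer = deque(maxlen=capacity)

    def add(self, state, action, reward, next_state, done):
        self.buffer.append((state, action, reward, next_state, done))

    def sample(self, batch_size, rng=random):
        # Draw a random minibatch without replacement.
        return rng.sample(list(self.buffer), batch_size)

    def __len__(self):
        return len(self.buffer)

buffer = ReplayBuffer(capacity=3)
for t in range(5):
    buffer.add(state=t, action=0, reward=float(t), next_state=t + 1, done=(t == 4))
batch = buffer.sample(2, rng=random.Random(0))
print(len(buffer), batch)
```

Sampling at random, rather than training on consecutive frames, breaks up the strong correlation between successive states of a game. Copying the deque into a list on every call is simple but linear in the buffer size; a production buffer would index into preallocated arrays instead.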

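To train the Q network we draw random samples from the replay buffer and compute a target q = r(s,a) + γ V(s') for each sample (V(s') is also estimated with a network; which network we use is discussed next). We then train the Q network on these samples. In the example code the network takes only the state s as input and outputs one Q value per action, so the training loss compares each target q with the network's output for the action a that was actually taken in that sample.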
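The target computation and the per-action loss can be sketched as follows. Here `target_q_values` is assumed to be any function that returns a list of Q values for a state, and a finished game (`done`) is assumed to contribute no future reward. Squared error is used for simplicity; a Huber loss is a common, more outlier-tolerant alternative. The function names and the numbers are illustrative.

```python
def compute_targets(batch, target_q_values, gamma=0.99):
    """q = r + gamma * max_a' Q(s', a') for each (s, a, r, s', done) transition."""
    targets = []
    for state, action, reward, next_state, done in batch:
        # A finished game has no future reward, so V(s') is taken as 0.
        v_next = 0.0 if done else max(target_q_values(next_state))
        targets.append(reward + gamma * v_next)
    return targets

def masked_squared_error(predicted_q, actions, targets):
    """Mean squared error using only the Q value of the action actually taken."""
    errors = [(t - q_row[a]) ** 2 for q_row, a, t in zip(predicted_q, actions, targets)]
    return sum(errors) / len(errors)

# Made-up network outputs for two next states.
fake_target_q = {"s1": [0.2, 0.4], "s2": [1.0, 3.0]}
batch = [
    ("s0", 1, 1.0, "s1", False),  # game continues: bootstrap from V(s1)
    ("s1", 0, 0.5, "s2", True),   # game over: target is just the reward
]
targets = compute_targets(batch, lambda s: fake_target_q[s], gamma=0.5)
print(targets)

# Outputs for the other action in each row do not affect the loss.
loss = masked_squared_error([[0.0, 1.0], [2.0, 0.5]], actions=[1, 0], targets=[1.5, 2.0])
print(loss)
```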

Another significant feature of the example is its use of two networks. The second network, which we call the target network, is used only to compute the V(s') estimates. The first network does everything else: it chooses actions and is periodically trained on new samples from the replay buffer. Every 10000 frames of game play we copy the weights from the Q network we have been training to the target network. The reason for this unusual two-network strategy is to make Q learning more stable. Since Q(s,a) and V(s') are related recursively, an update to Q(s,a) can shift the very V(s') estimates used to compute its own targets, creating an unwelcome feedback loop between the two. We break that loop by training only the Q(s,a) network while keeping the network that computes V(s') frozen. The updates to Q(s,a) eventually carry over to V(s') when we copy the weights from the Q network to the target network.
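The bookkeeping for the target network amounts to a frame counter and a periodic copy. In this sketch a plain dictionary of lists stands in for network weights, and the class name `TwoNetworkWeights` is hypothetical; with two Keras models the copy step would be `target_model.set_weights(model.get_weights())`.

```python
import copy

class TwoNetworkWeights:
    """Online weights that get trained, plus a frozen copy used for V(s') estimates."""
    def __init__(self, online_weights, sync_every=10_000):
        self.online_weights = online_weights
        self.target_weights = copy.deepcopy(online_weights)
        self.sync_every = sync_every
        self.frame_count = 0

    def on_frame(self):
        """Call once per frame of play; returns True when the target copy is refreshed."""
        self.frame_count += 1
        if self.frame_count % self.sync_every == 0:
            # Deep copy so later training of the online weights leaves the target untouched.
            self.target_weights = copy.deepcopy(self.online_weights)
            return True
        return False

# Usage with a tiny sync interval so the effect is visible.
nets = TwoNetworkWeights({"w": [0.0]}, sync_every=3)
for _ in range(7):
    nets.online_weights["w"][0] += 1.0  # stand-in for a training step
    nets.on_frame()
print(nets.online_weights, nets.target_weights)
```

Between copies the target values stay fixed, so the training targets change only every `sync_every` frames instead of after every update.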