<?xml version="1.0" encoding="utf-8"?>
<feed xmlns="http://www.w3.org/2005/Atom">
  <generator uri="http://jekyllrb.com" version="3.9.5">Jekyll</generator>
  
  
  <link href="https://brandonlmorris.github.io/feed.xml" rel="self" type="application/atom+xml" />
  <link href="https://brandonlmorris.github.io/" rel="alternate" type="text/html" />
  <updated>2024-05-19T22:20:38+00:00</updated>
  <id>https://brandonlmorris.github.io//</id>

  
    <title type="html">Brandon Morris</title>
  

  
    <subtitle>The personal and research blog of Brandon L. Morris
</subtitle>
  

  
    <author>
        <name>Brandon L. Morris</name>
      
      
    </author>
  

  
  
    <entry>
      
      <title type="html">Reinforcement Learning: Playing Doom with PyTorch</title>
      
      <link href="https://brandonlmorris.github.io/2018/10/09/dql-vizdoom/" rel="alternate" type="text/html" title="Reinforcement Learning: Playing Doom with PyTorch" />
      <published>2018-10-09T00:00:00+00:00</published>
      <updated>2018-10-09T00:00:00+00:00</updated>
      <id>https://brandonlmorris.github.io/2018/10/09/dql-vizdoom</id>
      <content type="html" xml:base="https://brandonlmorris.github.io/2018/10/09/dql-vizdoom/">&lt;!-- Outline
* What is Reinforcement learning
  - Analogy to operant conditioning
  - Performing a task
  - Learn both the environment and the optimal policy
* Deep Q Learning
  - The markov decision process
  - The Q-function
  - Learning the q-function
* Implementation in vizdoom
--&gt;

&lt;blockquote&gt;
  &lt;p&gt;This tutorial is adapted from the one on &lt;a href=&quot;http://vizdoom.cs.put.edu.pl/tutorial&quot;&gt;ViZDoom’s website&lt;/a&gt;.
Additionally, the code used here is adapted from &lt;a href=&quot;https://github.com/mwydmuch/ViZDoom/blob/master/examples/python/learning_pytorch.py&quot;&gt;this
tutorial&lt;/a&gt;, with substantial modification.&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;Machine learning allows us to &lt;a href=&quot;https://medium.com/@karpathy/software-2-0-a64152b37c35&quot;&gt;program by example&lt;/a&gt;. We can present
the algorithm with some data, potentially provide it some feedback, and then
glean the results of our system. For &lt;a href=&quot;https://brandonlmorris.github.io/2018/06/30/wide-resnet-pytorch&quot;&gt;image classification&lt;/a&gt;, we give the
model some images and it learns to identify what object(s) are in each image.
Tasks like this where the model only needs to find the “right answer” (i.e.
supervised learning) have seen a lot of success, and have huge potential to
automate mundane manual tasks.  But is that all machine learning can do?&lt;/p&gt;

&lt;p&gt;In this post, I’ll introduce some of the ideas fundamental to reinforcement
learning, and how it differs from typical supervised learning. We will then
examine up close one algorithm for solving reinforcement learning problems,
known as Deep Q-learning. Then, we’ll implement Deep Q-learning to teach a
neural network how to play a simple game of Doom using the &lt;a href=&quot;http://vizdoom.cs.put.edu.pl/&quot;&gt;ViZDoom&lt;/a&gt; environment
and PyTorch.&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;div style=&quot;display:block; margin:0 auto; text-align:center&quot;&gt; &lt;img style=&quot;&quot; src=&quot;https://brandonlmorris.github.io/images/rl-intro/episode-9.gif&quot; alt=&quot;The trained RL agent
shooting a monster&quot; /&gt; &lt;/div&gt; &lt;figcaption style=&quot;display:block;margin:0
auto;text-align:center&quot;&gt;Slow-motion capture of the reinforcement learning agent
shooting a monster in Doom&lt;/figcaption&gt; &lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;h2 id=&quot;what-is-reinforcement-learning&quot;&gt;What is Reinforcement Learning?&lt;/h2&gt;

&lt;p&gt;Reinforcement learning is a branch of machine learning where we try to teach the
model to actually &lt;strong&gt;do&lt;/strong&gt; something. The most famous example of reinforcement
learning is the success of &lt;a href=&quot;https://deepmind.com/research/alphago/&quot;&gt;DeepMind’s AlphaGo&lt;/a&gt; and its variants.
Rather than just predicting an answer, AlphaGo is a reinforcement learning
agent that learns to masterfully play the game of Go. It can’t just classify; it
needs to sequentially interact with its environment – making moves and
receiving its opponent’s moves – in such a way that it will be most likely to
achieve its long-term goal of winning the game.&lt;/p&gt;

&lt;p&gt;Even though Go is just a board game, programming a competent player is
exceedingly difficult. And interestingly, the same framework used to design
AlphaGo can be applied to nearly any other domain. This framework is known as
the &lt;strong&gt;Markov Decision Process&lt;/strong&gt; (MDP), and it allows us to rigorously and
mathematically characterize a system for reinforcement learning (as well as
other situations).&lt;/p&gt;

&lt;h3 id=&quot;the-markov-decision-process&quot;&gt;The Markov Decision Process&lt;/h3&gt;

&lt;p&gt;MDPs have several different formulations and variants. However, there are two
critical components that are tacitly understood: the &lt;strong&gt;agent&lt;/strong&gt; and the
&lt;strong&gt;environment&lt;/strong&gt;. The agent is the person or thing that is actually trying to
perform the task. They make the decisions and carry out the actions. The
environment is essentially everything else: the world around the agent, the
rules of that world, and even other players can be abstracted out to the
environment.&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;div style=&quot;display:block; margin:0 auto; text-align:center&quot;&gt;
&lt;img style=&quot;&quot; src=&quot;https://brandonlmorris.github.io/images/rl-intro/rl-loop.png&quot; alt=&quot;The RL
loop&quot; /&gt;
&lt;/div&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;The loop for
reinforcement learning: The agent perceives the environment and decides an
action, which changes the environment.&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;p&gt;In addition to the agent and the environment, MDPs are made up of several
pieces. First is the set of &lt;strong&gt;states&lt;/strong&gt;, which is just the potential
configurations of the agent/environment at a given point in time. For something
like a board game, the state is just the current board and whether it’s the
agent’s turn or not. Next are the &lt;strong&gt;actions&lt;/strong&gt;: all the things that the
agent can actually do. Note that the actions are dependent on the state, since
not all actions are valid in every state. The last main piece is the &lt;strong&gt;reward
function&lt;/strong&gt;, which tells us how well our agent is doing at a task. This is what
our reinforcement learning algorithm is going to focus on. Ultimately, we want
to train the model to know how to act such that it will maximize the overall
reward, also known as the return. The overall reward is calculated as just the
sum of the rewards at each step in the process, but maximizing it is difficult,
since we may have to make strategic decisions that are initially low reward to
boost the final return. &lt;!--We usually also include a __discount factor__, which
determines how to balance short-term and long-term rewards. A high discount
means that I care a lot about future rewards, and a very low discount factor
prioritizes maximizing the immediate reward.--&gt;&lt;/p&gt;
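
&lt;p&gt;To make the trade-off concrete, here is a minimal sketch in plain Python. The two reward sequences are made-up numbers for illustration (they do not come from Doom or any real environment): one plan grabs a small reward immediately, while the other accepts early penalties to reach a larger payoff.&lt;/p&gt;

```python
# Hypothetical per-step rewards for two plans over the same four steps.
greedy_rewards = [1.0, 0.0, 0.0, 0.0]        # small reward right away, nothing after
strategic_rewards = [-1.0, -1.0, 0.0, 10.0]  # early penalties, large payoff at the end


def episode_return(rewards):
    # The return is the plain (undiscounted) sum of the per-step rewards.
    return sum(rewards)


greedy_return = episode_return(greedy_rewards)
strategic_return = episode_return(strategic_rewards)
print(greedy_return, strategic_return)
```

&lt;p&gt;The strategic plan looks worse after the first step, yet it ends with the higher return, which is why an agent cannot simply chase the largest immediate reward.&lt;/p&gt;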

&lt;p&gt;Depending on the situation, the MDP may also include transition probabilities.
These tell us how likely we are to transition to a new state \(s_2\) if we’re
currently in a state \(s_1\) and we take some action \(a\). However, in complex
problems we often don’t know what the transition probabilities will be. In a
board game like Go, how can I effectively predict how my opponent will move?
Additionally, in domains like Doom, the state space is so large that enumerating
the transitions is impractical. So instead we let the reinforcement algorithm
&lt;em&gt;learn these transitions&lt;/em&gt; as well. This is how reinforcement learning differs
from planning systems: we don’t assume to know the world dynamics, and instead
try to learn those dynamics along with good actions.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Aside:&lt;/strong&gt; You may be wondering why we call this framework a &lt;em&gt;Markov&lt;/em&gt; decision
process. The Markov property states that we can reason about all future states
given &lt;strong&gt;only&lt;/strong&gt; the current state. That is, we only need to know where we are
right now, not necessarily how we got here. This is usually the case, and is
crucial for reinforcement learning algorithms to be tractable. Even in cases
where the history is significant, there are ways we can encode that history into
the current state to maintain the Markov property.&lt;/p&gt;

&lt;!--To make this a bit more concrete, imagine trying to write a program to play
chess. How would you approach it? The best chess engines use enormous sets of
heuristics (i.e. rules-of-thumb) and simulate future game states to try and
derive the likely outcome from a particular move--&gt;

&lt;h2 id=&quot;deep-q-learning&quot;&gt;Deep Q-Learning&lt;/h2&gt;

&lt;p&gt;Up to this point, we’ve only described the reinforcement learning problem: given
an MDP, we want to figure out good actions that will maximize the sum of our
rewards (i.e. the return). The process of deciding an action from a state is
known as a &lt;strong&gt;policy&lt;/strong&gt;, so in other words, we want to learn the best policy for a
given task. There are several different algorithms that do this; the one we’ll
look at here, and one of the most straightforward, is known as &lt;strong&gt;Q-learning&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;Before we can discuss how Q-learning actually works, we need some more
terminology. Recall that the policy involves selecting an action from a state,
and that the return is the sum over all our rewards at each state. Then the
&lt;strong&gt;value&lt;/strong&gt; of a state \(V_\pi(s)\) is the expected return if we start in state
\(s\) and follow policy \(\pi\). Essentially, \(V_\pi\) tries to predict our
final score using just the current state and the action-selection process.&lt;/p&gt;

&lt;p&gt;We can take this a step further. Instead of taking just a state and trying to
predict the final score, we can take the current state &lt;em&gt;and an action&lt;/em&gt; and try
to predict the return. This is known as the Q-value: \(Q_\pi(s, a)\). If our
Q-values are accurate, then playing optimally just boils down to picking the
action with the highest Q-value in our state. However, since we don’t know the
transition probabilities, we have to estimate these Q-values, and try to improve
them. This is where Q-learning comes in. Additionally, we can use deep neural
networks to approximate the Q-functions, hence Deep Q-Learning.&lt;/p&gt;
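
&lt;p&gt;As a small illustration of “picking the action with the highest Q-value”, the sketch below uses made-up Q-value estimates for a single state. The action names are hypothetical placeholders, not ViZDoom’s actual action encoding.&lt;/p&gt;

```python
# Hypothetical Q-value estimates for one state, keyed by action name.
q_values = {'move_left': 0.4, 'move_right': 1.7, 'shoot': 1.2}


def greedy_action(q_values):
    # Return the action whose estimated Q-value is largest.
    return max(q_values, key=q_values.get)


best = greedy_action(q_values)
print(best)
```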

&lt;p&gt;You may have noticed that Q-functions are inherently recursive. That is, we can
decompose the value of a Q-function by putting it in terms of the Q-function in
the next state:&lt;/p&gt;

\[Q(s_t, a_t) = r_{t+1} + \gamma \cdot \max_a Q(s_{t+1}, a)\]

&lt;p&gt;where \(r_{t+1}\) is the reward we got for taking action \(a_t\), and \(\gamma\)
is our discount factor that trades off immediate vs. long-term rewards. All this
equation says is that Q-functions build off each other over time, and we can
leverage that fact to efficiently estimate them.&lt;/p&gt;
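
&lt;p&gt;As a quick worked example with made-up numbers, suppose the reward was \(r_{t+1} = 1\), the discount factor is \(\gamma = 0.9\), and the Q-values of the three actions available in \(s_{t+1}\) are \(0.5\), \(2.0\), and \(1.0\). Then:&lt;/p&gt;

\[Q(s_t, a_t) = 1 + 0.9 \cdot \max(0.5,\ 2.0,\ 1.0) = 1 + 0.9 \cdot 2.0 = 2.8\]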

&lt;p&gt;To learn the Q-functions, we’ll utilize a deep neural network. The network will
take a state as input, and output a vector of Q-values, one for each action. We
will train it by presenting it with sets of &lt;em&gt;transitions&lt;/em&gt; (a first state,
action, reward, and second state). For each transition, the network’s output for
the first state, at the index of the selected action, should match the right-hand
side of the above equation. The squared difference between the two will be our
loss, which we backpropagate through our network.&lt;/p&gt;
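
&lt;p&gt;The sketch below walks through that loss for a single transition in plain Python. The Q-value lists stand in for the network’s outputs and, like the reward and discount factor, are made-up numbers. The terminal flag is an added assumption that covers the case where the episode ends and there is no next state to bootstrap from.&lt;/p&gt;

```python
def td_target(reward, next_q_values, gamma, terminal):
    # Right-hand side of the recursive equation: r + gamma * max_a Q(s', a).
    # For a terminal transition there is no next state, so only the reward remains.
    if terminal:
        return reward
    return reward + gamma * max(next_q_values)


def squared_td_error(q_values, action, target):
    # Squared difference between the predicted Q-value of the chosen action
    # and the target; only the chosen action's entry contributes to the loss.
    return (q_values[action] - target) ** 2


# Made-up numbers standing in for network outputs on one transition.
q_s1 = [0.2, 1.5, 0.7]  # predicted Q-values for the first state
q_s2 = [0.5, 2.0, 1.0]  # predicted Q-values for the second state
action, reward, gamma = 1, 1.0, 0.9

target = td_target(reward, q_s2, gamma, terminal=False)
loss = squared_td_error(q_s1, action, target)
print(target, loss)
```

&lt;p&gt;A real implementation would average this quantity over a batch of transitions and let the framework compute the gradients; the arithmetic per transition is the same.&lt;/p&gt;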

&lt;p&gt;That’s really all there is to Deep Q-learning. We try to approximate a function
that estimates our overall return after taking a particular action in a
particular state. We turn this global problem into a much more localized variant
by trying to optimize our Q-function estimations over individual transitions.
Provided we have enough of these transitions and they are adequately diverse,
the Q-function will converge to reasonably correct values that let us derive an
optimal policy by repeatedly selecting the action with the highest Q-value.&lt;/p&gt;

&lt;h2 id=&quot;putting-it-to-practice-vizdoom&quot;&gt;Putting it to Practice: ViZDoom&lt;/h2&gt;

&lt;p&gt;The ViZDoom environment is a fantastic tool for playing with reinforcement
learning. It provides a nice programming interface for the classic video game
Doom, and was designed with reinforcement learning in mind. It comes with
several scenarios out of the box, such as the one we will use that involves
shooting a monster across the room. However, these scenarios can actually be
custom-built using existing free tools like &lt;a href=&quot;http://www.doombuilder.com/&quot;&gt;Doom Builder&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;For the sake of brevity, I’m only going to walk through the particularly
important parts of the Q-learning implementation. You can see the full script to
train and run the ViZDoom agent &lt;a href=&quot;https://gist.github.com/BrandonLMorris/dc75086b844d65c51ab92b956494ecbd&quot;&gt;at this GitHub gist&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;First, we’re going to define a class for our &lt;strong&gt;replay memory&lt;/strong&gt;. The replay
memory will serve as a bank of the recent transitions (i.e. first state, action
taken, second state, and reward). Additionally, we need to keep track of
whether the action terminated the episode, since that will mean there is no
second state to process.&lt;/p&gt;

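&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;c1&quot;&gt;# torch, sample (assumed to be random.sample), resolution, and device are
&lt;/span&gt;&lt;span class=&quot;c1&quot;&gt;# not defined in this excerpt; see the full script for their definitions.
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;ReplayMemory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;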
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capacity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;channels&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;state_shape&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;capacity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;channels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;resolution&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;state_shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;state_shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;capacity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;long&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;capacity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;isterminal&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zeros&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;capacity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dtype&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;capacity&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;capacity&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pos&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;add_transition&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;action&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;isterminal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;reward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pos&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,:,:]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s1&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;action&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;not&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;isterminal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,:,:]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s2&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;isterminal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;isterminal&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;reward&lt;/span&gt;

        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pos&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;pos&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;%&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;capacity&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;min&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;capacity&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;get_sample&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sample&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;isterminal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;The replay memory mostly just stores huge batches of the transitions, though we
included some useful methods for adding transitions to the memory and gathering
a random sample from the entries filled so far. The replay memory will be critical
to training our network, since it allows us to efficiently gather numerous and
diverse inputs from the agent’s experience. In fact, &lt;strong&gt;the model will only learn
from the replay memory directly.&lt;/strong&gt; As the agent learns during training, it will
leverage the Q-network to determine its actions, add its experience to the
replay memory, and then update its parameters from a sample of transitions that
come from the replay memory.&lt;/p&gt;
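
&lt;p&gt;The wrap-around bookkeeping is easy to miss in the class above, so here is a stripped-down sketch of the same idea using a plain Python list in place of the tensors. The class name and the tiny capacity are chosen purely for illustration.&lt;/p&gt;

```python
class TinyRing:
    # Illustrative stand-in for the replay memory: same pos/size logic, no tensors.
    def __init__(self, capacity):
        self.slots = [None] * capacity
        self.capacity = capacity
        self.size = 0  # number of filled slots, capped at capacity
        self.pos = 0   # index of the slot that the next item will overwrite

    def add(self, item):
        self.slots[self.pos] = item
        self.pos = (self.pos + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)


ring = TinyRing(capacity=3)
for transition_id in range(5):
    ring.add(transition_id)

print(ring.slots, ring.size, ring.pos)
```

&lt;p&gt;After five insertions into three slots, the two oldest items have been overwritten, which is how the memory ends up holding only recent experience.&lt;/p&gt;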

&lt;p&gt;Next, we’ll build out our actual Q-function model. Recall that we are using a
deep neural network to approximate our Q-function. Since ViZDoom will give us
raw pixels as our inputs, we’ll leverage a convolutional neural net that can
effectively learn the visual features.&lt;/p&gt;
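
&lt;p&gt;Because the network consumes raw frames, it helps to know how much each convolution shrinks the image: with no padding, an input of length \(n\) becomes \(\lfloor (n - k) / s \rfloor + 1\) for kernel size \(k\) and stride \(s\). The sketch below applies that formula to two stacked layers with kernel sizes 6 and 3 and strides 3 and 2. The 30x45 input resolution is an assumption made for this example, as is the count of 8 output channels.&lt;/p&gt;

```python
def conv_out(length, kernel_size, stride):
    # Output length of an unpadded convolution along one dimension.
    return (length - kernel_size) // stride + 1


# Assumed input resolution (height, width); not stated in the text above.
height, width = 30, 45

# First layer: kernel 6, stride 3.
h1, w1 = conv_out(height, 6, 3), conv_out(width, 6, 3)
# Second layer: kernel 3, stride 2.
h2, w2 = conv_out(h1, 3, 2), conv_out(w1, 3, 2)

channels = 8  # assumed number of output channels in both layers
flattened = channels * h2 * w2
print((h1, w1), (h2, w2), flattened)
```

&lt;p&gt;Under those assumptions the flattened feature vector has 192 entries, which is the input size the first fully connected layer has to accept.&lt;/p&gt;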

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;QNet&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;available_actions_count&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;QNet&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;6&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;stride&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# 8x9x14
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;stride&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# 8x4x6 = 192
&lt;/span&gt;        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fc1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;192&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fc2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;available_actions_count&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;criterion&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MSELoss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optim&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;SGD&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;parameters&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;FLAGS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;learning_rate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;memory&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ReplayMemory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;capacity&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;FLAGS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;replay_memory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;view&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;192&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fc1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fc2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;get_best_action&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;index&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;index&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;train_step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target_q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;output&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;criterion&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;output&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target_q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;zero_grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;backward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;learn_from_memory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;memory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;size&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;FLAGS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;s1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;isterminal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;r&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;memory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_sample&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;FLAGS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;q&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;detach&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;q2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;_&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;max&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dim&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;target_q&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;detach&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;idxs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;arange&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;target_q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;target_q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;idxs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;r&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;FLAGS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;discount&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;isterminal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;q2&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;target_q&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;Here we define the basic architecture and some useful methods for training. Note
that the network isn’t particularly large: only 4 layers (two convolutional and two
fully connected), each with relatively few parameters. The particular task isn’t very
complex, and we’re restricting our inputs to small grayscale images of 30x45 pixels.&lt;/p&gt;
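
&lt;p&gt;The shape comments next to the two convolutional layers can be checked by hand. The
sketch below is not part of the original script: the helper names are made up for
illustration, and it assumes unpadded convolutions with no dilation, which is what
the layer definitions above use.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;
```python
def conv_output_size(size, kernel_size, stride):
    # Output length of an unpadded convolution along one dimension.
    return (size - kernel_size) // stride + 1


def conv_output_shape(height, width, kernel_size, stride):
    # Apply the same formula to both spatial dimensions.
    return (conv_output_size(height, kernel_size, stride),
            conv_output_size(width, kernel_size, stride))


h1, w1 = conv_output_shape(30, 45, kernel_size=6, stride=3)  # after conv1
h2, w2 = conv_output_shape(h1, w1, kernel_size=3, stride=2)  # after conv2
flattened = 8 * h2 * w2  # conv2 has 8 output channels
print((h1, w1), (h2, w2), flattened)  # (9, 14) (4, 6) 192
```
&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;The final value is the input size of the first fully connected layer, so changing
the input resolution means recomputing it.&lt;/p&gt;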

&lt;p&gt;Pay particular attention to the second-to-last line in the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;learn_from_memory()&lt;/code&gt;
method. We want the Q-values for &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;s1&lt;/code&gt; to match the recursive equation above (but
only at the action that was actually taken during that transition). By overwriting
just those entries with the “true” value and leaving the rest as the network’s own
predictions, the squared difference we use as our network loss is nonzero only at
the actions that were taken.&lt;/p&gt;
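
&lt;p&gt;To make the target construction concrete, here is a minimal sketch in plain Python
for a single transition. The function names and the numbers are invented for
illustration; the real code does the same thing on batched tensors.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;
```python
def td_target(reward, next_q_values, is_terminal, discount):
    # r + discount * max Q(s2, .), with the bootstrap term dropped
    # when the transition ended the episode.
    return reward + discount * (1.0 - is_terminal) * max(next_q_values)


def build_target_row(current_q_values, action, target_value):
    # Copy the network's own predictions, then overwrite only the entry
    # for the action that was actually taken.
    row = list(current_q_values)
    row[action] = target_value
    return row


q_s1 = [0.5, 1.0, 0.25]  # made-up Q(s1, .) for three actions
q_s2 = [2.0, 4.0, 1.0]   # made-up Q(s2, .)
target = td_target(reward=1.0, next_q_values=q_s2, is_terminal=0.0, discount=0.5)
row = build_target_row(q_s1, action=2, target_value=target)
print(row)  # [0.5, 1.0, 3.0]
```
&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;Only the third entry differs from the prediction, so it is the only one that
contributes to the squared error.&lt;/p&gt;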

&lt;p&gt;Now that we have our replay memory and model, we can flesh out a single step of
our training loop. As I mentioned before, the basic formula is to first experience a
transition, then record that transition and learn from the replay memory. Here’s
the code:&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;perform_learning_step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;epoch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;game&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;actions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;s1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;game_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;game&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;random&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;find_eps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;epoch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;a&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;torch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tensor&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randint&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;actions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)).&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;long&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;s1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;resolution&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;a&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;get_best_action&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;to&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;device&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;reward&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;game&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;make_action&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;actions&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;frame_repeat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;game&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;is_episode_finished&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;isterminal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;1.&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;isterminal&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;s2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;game_state&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;game&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;memory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_transition&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;s1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;a&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;s2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;isterminal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;reward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;learn_from_memory&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;Note the action selection process. Initially, our agent has no idea what good
actions are. As such, we want it to explore very broadly, so that it can gather a
diverse range of experience to build on. The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;find_eps&lt;/code&gt; function
determines an exploration rate based on how far into training we are. As
training progresses, the agent takes random actions (i.e. explores) less often
and takes the best available action more often. This is known as an
“epsilon-greedy” policy. When we’re done training, or when we’re evaluating our
model, we always select the best action and no longer explore.&lt;/p&gt;
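
&lt;p&gt;The exploration schedule itself isn’t shown in this post. As a rough sketch of the
idea, the code below uses a simple linear decay; the function names, the start and
end rates, and the linear shape are all assumptions made for illustration, not the
schedule used by &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;find_eps&lt;/code&gt; in the full script.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;
```python
import random


def linear_eps(epoch, total_epochs, start_eps=1.0, end_eps=0.1):
    # Hypothetical schedule: blend linearly from start_eps to end_eps.
    progress = min(max(epoch / total_epochs, 0.0), 1.0)
    return end_eps * progress + start_eps * (1.0 - progress)


def choose_action(q_values, eps, rng=random):
    # Explore with probability eps, otherwise act greedily.
    explore = rng.choices([True, False], weights=[eps, 1.0 - eps])[0]
    if explore:
        return rng.randrange(len(q_values))
    # Index of the largest Q-value.
    return max(range(len(q_values)), key=lambda i: q_values[i])


q_values = [0.1, 0.9, 0.3]  # made-up Q-values for three actions
for epoch in (0, 10, 20):
    eps = linear_eps(epoch, total_epochs=20)
    print(epoch, round(eps, 2), choose_action(q_values, eps))
```
&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;Passing an exploration rate of zero gives the purely greedy behaviour used at
evaluation time.&lt;/p&gt;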

&lt;p&gt;The rest of the code involves setting up the ViZDoom game, command line flags,
and training epoch loops. All of it is pretty standard, and has thus been
omitted. You can see and run the full script &lt;a href=&quot;https://gist.github.com/BrandonLMorris/dc75086b844d65c51ab92b956494ecbd&quot;&gt;here&lt;/a&gt;. Since the
network and inputs are pretty small, you should be able to run this on your
personal computer, even if you don’t have a GPU.&lt;/p&gt;

&lt;p&gt;I trained this model on my machine for 20 epochs at 2,000 iterations per epoch.
Pretty quickly the agent learned a reasonable policy, and the whole thing
converged in a little less than 20 minutes. You can see one of the test episodes
in gif form below.&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;div style=&quot;display:block; margin:0 auto; text-align:center&quot;&gt; &lt;img style=&quot;&quot; src=&quot;https://brandonlmorris.github.io/images/rl-intro/episode-0.gif&quot; alt=&quot;The trained RL agent
shooting a monster&quot; /&gt; &lt;/div&gt; &lt;figcaption style=&quot;display:block;margin:0
auto;text-align:center&quot;&gt;Slow-motion capture of the trained agent. The agent
overshoots initially, but moves back in front of the monster to get the full
reward of the kill.
&lt;/figcaption&gt; &lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;h2 id=&quot;additional-resources&quot;&gt;Additional Resources&lt;/h2&gt;

&lt;p&gt;Deep reinforcement learning is a burgeoning field with lots of exciting new
advancements. This tutorial barely scratches the surface of Deep RL, but should
provide you with everything to get started. If you’re interested in learning
more, here are several resources that I found particularly interesting and/or
useful.&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;&lt;strong&gt;“Reinforcement Learning: An Introduction” by Sutton and Barto&lt;/strong&gt; This is
&lt;strong&gt;the&lt;/strong&gt; textbook on reinforcement learning broadly. A classic in the field,
and a free draft is available &lt;a href=&quot;http://incompleteideas.net/book/bookdraft2017nov5.pdf&quot;&gt;here&lt;/a&gt;.&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;Reinforcement Learning Crash Course&lt;/strong&gt; Lectures from David Silver’s course on
reinforcement learning. &lt;a href=&quot;https://www.youtube.com/watch?v=2pWv7GOvuf0&amp;amp;list=PLqYmG7hTraZDM-OYHWgPebj2MfCFzFObQ&quot;&gt;Link&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;“Deep Reinforcement Learning Doesn’t Work Yet”&lt;/strong&gt; A critical yet honest appraisal
of the current state of reinforcement learning and how it often falls short of
our press releases. &lt;a href=&quot;https://www.alexirpan.com/2018/02/14/rl-hard.html&quot;&gt;Link&lt;/a&gt;&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;Deep Reinforcement Learning NIPS 2018 Workshop&lt;/strong&gt; Collection of talks and
papers literally on the cutting edge of Deep RL research (to be published with
the conference in December). &lt;a href=&quot;https://sites.google.com/view/deep-rl-workshop-nips-2018/home&quot;&gt;Link&lt;/a&gt;&lt;/li&gt;
&lt;/ul&gt;

&lt;!--blockquote class=&quot;twitter-tweet&quot; data-lang=&quot;en&quot;&gt;&lt;p lang=&quot;en&quot; dir=&quot;ltr&quot;&gt;Deep RL
is popular because it&amp;#39;s the only area in ML where it&amp;#39;s socially
acceptable to train on the test set.&lt;/p&gt;&amp;mdash; Jacob Andreas (@jacobandreas) &lt;a
href=&quot;https://twitter.com/jacobandreas/status/924356906344267776?ref_src=twsrc%5Etfw&quot;&gt;October
28, 2017&lt;/a&gt;&lt;/blockquote&gt; &lt;script async
src=&quot;https://platform.twitter.com/widgets.js&quot; charset=&quot;utf-8&quot;&gt;&lt;/script--&gt;</content>

      
      
      
      
      

      
        <author>
            <name>Brandon L. Morris</name>
          
          
        </author>
      

      
        <category term="ai" />
      

      

      
        <summary type="html"></summary>
      

      
      
    </entry>
  
  
  
    <entry>
      
      <title type="html">Building a World-Class CIFAR-10 Model From Scratch</title>
      
      <link href="https://brandonlmorris.github.io/2018/06/30/wide-resnet-pytorch/" rel="alternate" type="text/html" title="Building a World-Class CIFAR-10 Model From Scratch" />
      <published>2018-06-30T00:00:00+00:00</published>
      <updated>2018-06-30T00:00:00+00:00</updated>
      <id>https://brandonlmorris.github.io/2018/06/30/wide-resnet-pytorch</id>
      <content type="html" xml:base="https://brandonlmorris.github.io/2018/06/30/wide-resnet-pytorch/">&lt;!-- Outline
Short intro:
  - Introduce the problem: image classification in CIFAR10
  - Introduce the solution: Wide ResNet
The ResNet architecture:
  - Introduced in 2015 to win ImageNet
  - Allows for crazy deep neural networks
  - Build from residual &quot;blocks&quot; and skip connections
The Wide ResNet:
  - Paper shortly after original resnet (double check)
  - Looked at different types of residual blocks
Our Implementation:
  - The BasicBlock class
  - The WideResNet class
  - Data loading and processing
  - The training schedule
Conclusion:
  - Really good results (4.16 error)
  - SoTA is ~2.87, (Pyramidal resnet w/ shake-drop), but more complex and
    required 6x more training epochs
--&gt;

&lt;p&gt;In this post, I walk through how to build and train a world-class deep learning
image recognition model. Deep learning models tout amazing results in
competitions, but it can be difficult to go from a dense, technical research
paper to actually working code. Here I take one of those papers, break down the
important steps, and translate the words on the page into code you can run to get
near state-of-the-art results on a popular image recognition benchmark.&lt;/p&gt;

&lt;p&gt;The problem we will be solving is one of the most common in deep learning: image
recognition. Here, our model is presented with an image (typically raw pixel
values) and is tasked with outputting the object inside that image from a set of
possible classes.&lt;/p&gt;

&lt;p&gt;The dataset we will be using is &lt;a href=&quot;https://www.cs.toronto.edu/~kriz/cifar.html&quot;&gt;CIFAR-10&lt;/a&gt;, which is one of the most popular
datasets in current deep learning research. CIFAR-10 is a collection of 60,000
images, each belonging to one of 10 classes. These images are tiny:
just 32x32 pixels (for reference, an HDTV will have over a thousand pixels in
width and height). This means the images are grainy, and it can be
difficult to determine exactly what’s in them, even for a human. A
few examples are depicted below.&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;div style=&quot;display:block; margin:0 auto; text-align:center&quot;&gt;
&lt;img style=&quot;&quot; src=&quot;https://brandonlmorris.github.io/images/wideresnet/cifar-boat.png&quot; alt=&quot;A very
pixelated boat&quot; /&gt;
&lt;img style=&quot;&quot; src=&quot;https://brandonlmorris.github.io/images/wideresnet/cifar-frog.png&quot; alt=&quot;A very
pixelated frog&quot; /&gt;
&lt;/div&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;Images of a
boat and frog from the CIFAR-10 dataset&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;p&gt;The training set consists of 50,000 images, and the remaining 10,000 are used
for evaluating models. At the time of this writing, the best reported model is
&lt;a href=&quot;https://arxiv.org/abs/1802.02375&quot;&gt;97.69% accurate&lt;/a&gt; on the test set. The model we will create here
won’t be quite as accurate, but still very impressive.&lt;/p&gt;

&lt;p&gt;The architecture we will use is a variation of residual networks known as a
&lt;a href=&quot;https://arxiv.org/abs/1605.07146&quot;&gt;&lt;em&gt;wide&lt;/em&gt; residual network&lt;/a&gt;. We’ll use PyTorch as our deep learning
library, and automate some of the data loading and processing with the &lt;a href=&quot;https://github.com/fastai/fastai&quot;&gt;Fast.ai
library&lt;/a&gt;. But first, let’s dig into the architecture of ResNets and the
particular variant we’re interested in.&lt;/p&gt;

&lt;h2 id=&quot;the-residual-network&quot;&gt;The Residual Network&lt;/h2&gt;

&lt;p&gt;Deep neural networks function as a stack of layers. The input moves from one
layer to the next, with some kind of transformation (e.g. convolution) followed
by a non-linear activation function (e.g. ReLU). With the exception of RNNs,
this process of pushing inputs directly through the network one layer
at a time had been standard practice in top-performing deep neural networks.&lt;/p&gt;

&lt;p&gt;Then, in 2015, Kaiming He and his colleagues at Microsoft Research introduced
the &lt;a href=&quot;https://arxiv.org/abs/1512.03385v1&quot;&gt;Residual Network&lt;/a&gt; architecture. In a residual network (resnet,
for short), activations are able to “skip” past layers at certain points and be
summed up with the activations of the layers they skipped. A group of layers
together with its skip connection is typically referred to as a &lt;strong&gt;residual block&lt;/strong&gt;. The image below
depicts one block in a resnet.&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block; margin:0 auto;&quot; src=&quot;https://brandonlmorris.github.io/images/wideresnet/resnet-block.png&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;The structure
of a resnet block. Inputs are allowed to skip past layers and be summed up
with the activations of the layers they skipped.&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;p&gt;Architectures built by stacking together residual blocks (i.e. resnets)
train much more efficiently and to lower error. The original paper explores
various depths, and its authors were able to train networks of over 1,200 layers. Before, it
was difficult to train networks with just 19 layers. One potential reason
resnets allow for deeper networks is that they allow the gradient
signal from backpropagation to travel further back up through the network, using
the skip connections like a highway to get closer to the input layer. In 2015, a
residual network won the &lt;a href=&quot;http://www.image-net.org&quot;&gt;ImageNet&lt;/a&gt; classification challenge with 3.57% top-5 test error.&lt;/p&gt;

&lt;p&gt;The authors explain the intuition (and the name) of the residual block as a
recharacterization of the learning process. Consider just a few layers, like
those that make up a single residual block. Now, there should be some ideal
mapping from the block’s inputs to its output. Let’s call this mapping
\(H(x)\). Typical learning tries to derive this mapping directly: that is, find
an \(F(x, W)\) similar to our ideal \(H(x)\). But we can change this, and
instead allow \(F\) to approximate the &lt;em&gt;residual&lt;/em&gt;, or the difference, between
\(H(x)\) and \(x\). That is,&lt;/p&gt;

\[F(x, W) := H(x) - x\]

&lt;p&gt;which is equivalent to&lt;/p&gt;

\[H(x) = F(x, W) + x\]

&lt;p&gt;which is the definition of our residual block.&lt;/p&gt;
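
&lt;p&gt;As a minimal sketch of this idea in plain Python (not the PyTorch code used later in this post), the hypothetical helper &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;make_residual&lt;/code&gt; below wraps any function \(F\) so that its output is added back to its input, giving \(H(x) = F(x) + x\). The names and the toy list-based inputs are assumptions made purely for illustration.&lt;/p&gt;

```python
def make_residual(f):
    """Wrap f so the returned block computes H(x) = f(x) + x, element-wise."""
    def block(x):
        fx = f(x)
        return [fi + xi for fi, xi in zip(fx, x)]
    return block


# If the learned residual F is zero everywhere, the block is the identity.
zero_residual = make_residual(lambda x: [0.0 for _ in x])
# A residual that applies a small correction to every value.
small_residual = make_residual(lambda x: [0.1 * xi for xi in x])

x = [1.0, -2.0, 3.0]
identity_out = zero_residual(x)
nudged_out = small_residual(x)
print(identity_out)
print(nudged_out)
```

&lt;p&gt;One way to read this: when the ideal mapping is close to the identity, the layers inside the block only have to learn a small correction rather than reproduce their input from scratch.&lt;/p&gt;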

&lt;h2 id=&quot;the-wide-resnet&quot;&gt;The Wide ResNet&lt;/h2&gt;

&lt;p&gt;Since their introduction, resnets have become a standard choice for deep
learning architectures dealing with computer vision. Several variations of the
residual blocks and architectures presented in the original paper have been
explored, &lt;a href=&quot;https://arxiv.org/abs/1802.02375&quot;&gt;one of which&lt;/a&gt; currently holds the state-of-the-art test
accuracy for CIFAR-10.&lt;/p&gt;

&lt;p&gt;The variation we are going to implement here is the &lt;a href=&quot;https://arxiv.org/abs/1605.07146&quot;&gt;&lt;strong&gt;wide residual
network&lt;/strong&gt;&lt;/a&gt;. Here, the authors point out that the &lt;em&gt;depth&lt;/em&gt; of resnets
was the focal point in their introduction, rather than the &lt;em&gt;width&lt;/em&gt; (that is, the
number of convolutional filters in the layers). They explore some different
kinds of resnet blocks, and show that shallow and wide networks can be faster and more
accurate than the original deep and thin ones.&lt;/p&gt;
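
&lt;p&gt;To get a rough feel for what “width” costs, the sketch below counts the weights in a single bias-free 3x3 convolution. The helper name &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;conv3x3_weights&lt;/code&gt; and the widening factor of 4 are assumptions chosen only for illustration.&lt;/p&gt;

```python
def conv3x3_weights(in_filters, out_filters):
    """Weight count of a bias-free 3x3 convolution: 3 * 3 * in_filters * out_filters."""
    return 3 * 3 * in_filters * out_filters


thin = conv3x3_weights(16, 16)
# Widen both the input and output of the layer by an assumed factor of 4.
wide = conv3x3_weights(16 * 4, 16 * 4)
ratio = wide // thin
print(thin, wide, ratio)
```

&lt;p&gt;Widening a layer’s input and output by the same factor multiplies its weight count by the square of that factor, so a wide network can hold many parameters in comparatively few layers.&lt;/p&gt;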

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block; margin:0 auto;&quot; src=&quot;https://brandonlmorris.github.io/images/wideresnet/block-comparison.png&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;Comparison of
the different block structures in vanilla and wide resnets. The two on the left
are those found in a traditional resnet: a basic block of two thin 3x3
convolutions and a &quot;bottleneck&quot; block. On the right, the wide resnet uses blocks
similar to the original basic block, but much wider convolutions (i.e. more
filters). There may or may not be dropout between the convolutions to regularize
the model.&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;h3 id=&quot;the-structure-of-a-wide-resnet&quot;&gt;The Structure of a Wide ResNet&lt;/h3&gt;

&lt;p&gt;The wide resnet consists of three main architectural components:&lt;/p&gt;

&lt;ol&gt;
  &lt;li&gt;&lt;strong&gt;An initial convolution.&lt;/strong&gt; This is done to pull out low-level features
and expand our initial image from only three channels to a
higher-dimensional convolutional activation.&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;A number of “groups”.&lt;/strong&gt; Each group will consist of a set of \(N\) residual
blocks. More on this in a moment.&lt;/li&gt;
  &lt;li&gt;&lt;strong&gt;A pooling and linear layer.&lt;/strong&gt; This will downsample our convolutions and
convert them into class predictions.&lt;/li&gt;
&lt;/ol&gt;

&lt;p&gt;The real meat of the wide resnet will lie in the groups: that’s where all
of our residual blocks will live. The authors of the original paper always used three groups in
their experiments, but we will write our code so that the number of groups is a
parameter.&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block; margin:0 auto;width:75%&quot; src=&quot;https://brandonlmorris.github.io/images/wideresnet/wide-resnet-arch.png&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;Outline of the
wide resnet architecture. &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;conv1&lt;/code&gt; is the initial convolution and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;conv2&lt;/code&gt; through
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;conv4&lt;/code&gt; make up the three groups, each consisting of \(N\) blocks. In this case,
the blocks are the wide 3x3 basic blocks, where the width is initially \(16 \cdot
k\) and doubles after each group. Every group after the first also downsamples
to reduce the width and height of the convolutional activations.&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;p&gt;There are a few considerations that will become key to implementing the blocks
in each of our groups:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;Each group after the first will downsample the size of the activations. This
means that, for example, a 32x32 activation will shrink to 16x16. We’ll do this by
setting the stride of the first convolution in the group’s first block to 2.&lt;/li&gt;
  &lt;li&gt;Each group will double the number of filters from the previous group.&lt;/li&gt;
  &lt;li&gt;The first block of each group will need to have a convolution in its shortcut
to get it to the right dimensions for the addition operation.&lt;/li&gt;
&lt;/ul&gt;

&lt;p&gt;So how wide will these convolutions be? Our initial convolution will turn our
three channels into 16. The first group will multiply the number of channels by
the &lt;strong&gt;widening factor&lt;/strong&gt; \(k\), and every subsequent group will double the width
of the convolutions. Essentially, the \(i\)th group will have \((16 \cdot
k)\cdot 2^i\) filters in its convolutions (where \(i\) starts from 0).&lt;/p&gt;
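
&lt;p&gt;The width and downsampling rules above can be sketched in a few lines of plain Python. The helper names (&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;group_widths&lt;/code&gt;, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;group_strides&lt;/code&gt;, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;activation_sizes&lt;/code&gt;) are hypothetical and are not part of the implementation that follows; they only restate the formula \((16 \cdot k) \cdot 2^i\) and the stride-2 rule for a 32x32 input.&lt;/p&gt;

```python
def group_widths(k, n_groups=3, first_width=16):
    """Number of filters in group i: (first_width * k) * 2**i, with i starting at 0."""
    return [first_width * k * 2 ** i for i in range(n_groups)]


def group_strides(n_groups=3):
    """Stride of the first block in each group: only groups after the first downsample."""
    return [1 if i == 0 else 2 for i in range(n_groups)]


def activation_sizes(input_size=32, n_groups=3):
    """Height/width of the activations leaving each group, halved by every stride-2 group."""
    sizes, size = [], input_size
    for stride in group_strides(n_groups):
        size //= stride
        sizes.append(size)
    return sizes


for k in (1, 4, 10):
    print(k, group_widths(k))
print(group_strides(), activation_sizes())
```

&lt;p&gt;For example, with a widening factor of \(k = 10\) and three groups, the groups use 160, 320, and 640 filters, while the activations go from 32x32 to 16x16 to 8x8.&lt;/p&gt;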

&lt;h2 id=&quot;implementing-the-wide-resnet&quot;&gt;Implementing the Wide ResNet&lt;/h2&gt;

&lt;p&gt;Now that the architecture is all settled, it’s time to write some code. I’m
going to implement this in PyTorch, with a little help from the &lt;a href=&quot;https://github.com/fastai/fastai&quot;&gt;fastai
library&lt;/a&gt;. Fastai is a fantastic library for
quickly building high quality models. It’s also really helpful for automating
the more mundane aspects of writing deep learning code, like building data
loaders and training loops, which is what I’ll use it for here.&lt;/p&gt;

&lt;p&gt;The implementation will be done piece by piece: starting with the basic block,
then fleshing out the whole network, and finally building our data pipeline and
training loop. You can find the complete implementation
&lt;a href=&quot;https://github.com/BrandonLMorris/image-classification/blob/master/wide-resnet/wideresnet.py&quot;&gt;here&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Note:&lt;/strong&gt; Some of this code is not going to be as tidy as it could be. In this
article, I’m optimizing for understanding, not necessarily style or cleanliness.&lt;/p&gt;

&lt;h3 id=&quot;the-basicblock-class&quot;&gt;The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;BasicBlock&lt;/code&gt; Class&lt;/h3&gt;

&lt;p&gt;Since the majority of the model will consist of basic residual blocks, it makes
sense to define a reusable component that we can fill our model with.
Fortunately, PyTorch makes this really easy by allowing us to subclass the
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;nn.Module&lt;/code&gt; class.&lt;/p&gt;

&lt;p&gt;The full implementation of the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;BasicBlock&lt;/code&gt; class can be seen below:&lt;/p&gt;

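&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.nn&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;torch.nn.functional&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;


&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;BasicBlock&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;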
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;inf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;outf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;stride&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;drop&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bn1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BatchNorm2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;outf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;padding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                               &lt;span class=&quot;n&quot;&gt;stride&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;stride&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;drop&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Dropout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;drop&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;inplace&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bn2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BatchNorm2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;outf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;outf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;outf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;padding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                               &lt;span class=&quot;n&quot;&gt;stride&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;inf&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;outf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shortcut&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;lambda&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shortcut&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Sequential&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
                    &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BatchNorm2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ReLU&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inplace&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
                    &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;outf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;padding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;stride&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;stride&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bn1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;drop&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;F&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bn2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;r&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shortcut&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;r&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;A few things to note:&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;After the first group, the first block of each group will need to downsample
the height and width of the convolutional activation. This can be done by
passing in a 2 to the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;stride&lt;/code&gt; parameter when instantiating the first
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;BasicBlock&lt;/code&gt; of the group.&lt;/li&gt;
  &lt;li&gt;With the exception mentioned above, each convolution should preserve the width
and height of the convolutional activation. We achieve this by always using a
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;kernel_size&lt;/code&gt; of 3 and a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;padding&lt;/code&gt; of 1. Additionally, since we’re using
batchnorm, our convolutions don’t need a bias parameter, hence &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;bias=False&lt;/code&gt;.&lt;/li&gt;
  &lt;li&gt;We follow the order of batchnorm -&amp;gt; relu -&amp;gt; convolution. Although the
original batchnorm paper used a different order, this has since been shown to
be more effective during training.&lt;/li&gt;
  &lt;li&gt;If this is the first block in a group, we’re going to double the width via our
convolutions. In that case, the dimensions won’t match for the shortcut
connection, so &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;shortcut&lt;/code&gt; will need its own convolution (preceded by
batchnorm and relu) to widen it to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;outf&lt;/code&gt; filters. Also, since we only
double on the first block in a group, and we may be downsampling then too,
we’ll need to use the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;stride&lt;/code&gt; parameter in this convolution as well.&lt;/li&gt;
&lt;/ul&gt;
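
&lt;p&gt;The claims about &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;kernel_size&lt;/code&gt;, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;padding&lt;/code&gt;, and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;stride&lt;/code&gt; can be checked with the standard output-size arithmetic for a convolution (assuming no dilation). The helper &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;conv_output_size&lt;/code&gt; below is a hypothetical name used only for this sketch.&lt;/p&gt;

```python
def conv_output_size(size, kernel_size=3, padding=1, stride=1):
    """Spatial size after a convolution: floor((size + 2*padding - kernel_size) / stride) + 1."""
    return (size + 2 * padding - kernel_size) // stride + 1


# A 3x3 kernel with padding 1 and stride 1 preserves height and width.
same = conv_output_size(32, stride=1)
# The same kernel with stride 2 halves them, as in the first block of later groups.
halved = conv_output_size(32, stride=2)
halved_again = conv_output_size(16, stride=2)
print(same, halved, halved_again)
```

&lt;p&gt;Because the shortcut convolution uses the same kernel size, padding, and stride as the first convolution in the block, both paths arrive at the addition with matching height and width.&lt;/p&gt;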

&lt;h3 id=&quot;the-wideresnet-class&quot;&gt;The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;WideResNet&lt;/code&gt; Class&lt;/h3&gt;

&lt;p&gt;Now that we have our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;BasicBlock&lt;/code&gt; implementation, we can flesh out the rest of
the wide resnet architecture.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;class&lt;/span&gt; &lt;span class=&quot;nc&quot;&gt;WideResNet&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Module&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n_grps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;N&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;drop&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;first_width&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;16&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;nb&quot;&gt;super&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;().&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;__init__&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;layers&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;first_width&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;kernel_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;padding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Double feature depth at each group, after the first
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;widths&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;first_width&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grp&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_grps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;widths&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;first_width&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grp&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_grps&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;layers&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_make_group&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;N&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;widths&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;widths&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grp&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;+&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
                                       &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grp&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;drop&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;layers&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;BatchNorm2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;widths&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;ReLU&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inplace&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
                   &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;AdaptiveAvgPool2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Flatten&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt;
                   &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Linear&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;widths&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;
        &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;features&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Sequential&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;layers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;_make_group&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;N&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;inf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;outf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;stride&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;drop&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;group&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;N&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;blk&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;BasicBlock&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inf&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;inf&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;outf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;outf&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;outf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                             &lt;span class=&quot;n&quot;&gt;stride&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;stride&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;drop&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;drop&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;group&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;blk&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;group&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;forward&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;self&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;features&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;You can see the outline of the architecture in the code. Right after we call the
super constructor, we initialize the first convolutional layer (&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;conv1&lt;/code&gt; in the
architecture table).&lt;/p&gt;

&lt;p&gt;After the initial convolution, we calculate the widths (i.e. number of filters) in
each block, creating a list that will become our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;inf&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;outf&lt;/code&gt; parameters
during block construction. Then we construct each group in a for loop. If this
is the first group, we use &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;stride=1&lt;/code&gt; since this is the only time we don’t want
to decrease the width and height of the convolutional activations. Making a
group involves calling our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;_make_group&lt;/code&gt; helper function, which will construct
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;N&lt;/code&gt; instances of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;BasicBlock&lt;/code&gt; with the appropriate &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;inf&lt;/code&gt;, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;outf&lt;/code&gt;, and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;stride&lt;/code&gt;
parameters.&lt;/p&gt;
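
&lt;p&gt;As a rough sketch of that bookkeeping, the snippet below computes the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;inf&lt;/code&gt;, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;outf&lt;/code&gt;, and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;stride&lt;/code&gt;
values for each group. It is only an illustration: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;plan_groups&lt;/code&gt; is a hypothetical helper, and the base
width of 16, the doubling of the width per group, and the stride of 2 for later groups are assumptions
taken from the architecture table rather than copied from the constructor.&lt;/p&gt;

```python
def plan_groups(n_grps=3, k=10, base=16):
    """Return (inf, outf, stride) for each group of a wide resnet.

    Assumes widths of base*k, 2*base*k, 4*base*k, ... after a first
    convolution with `base` filters (an assumption, see the text above).
    """
    widths = [base] + [base * k * 2 ** i for i in range(n_grps)]
    plan = []
    for i in range(n_grps):
        # Only the first group keeps the spatial size (stride 1).
        stride = 1 if i == 0 else 2
        plan.append((widths[i], widths[i + 1], stride))
    return plan


print(plan_groups())
```

&lt;p&gt;With a widening factor of 10, the last group ends with \(64 \cdot 10 = 640\) filters.&lt;/p&gt;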

&lt;p&gt;Finally, we average pool our activations, turning each of the \(64 \cdot k\)
convolutional feature maps into a single value. These values are the input to our
last linear layer, which is used for classification.&lt;/p&gt;
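
&lt;p&gt;To make the pooling step concrete, here is a minimal pure-Python sketch of global average pooling.
It is not the PyTorch layer used in the model, and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;global_avg_pool&lt;/code&gt; is a hypothetical name; it
only shows that each feature map collapses to one number.&lt;/p&gt;

```python
def global_avg_pool(feature_maps):
    """Average every 2-D feature map down to a single number."""
    pooled = []
    for fmap in feature_maps:
        values = [v for row in fmap for v in row]
        pooled.append(sum(values) / len(values))
    return pooled


# Two toy 2x2 feature maps stand in for the 64*k maps of the real network.
maps = [[[1.0, 2.0], [3.0, 4.0]],
        [[0.0, 0.0], [0.0, 8.0]]]
print(global_avg_pool(maps))  # one value per feature map
```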

&lt;h3 id=&quot;data-loading-and-training&quot;&gt;Data Loading and Training&lt;/h3&gt;

&lt;p&gt;Our model is locked and loaded; now we just need some data to feed it and a
training loop to optimize it. Since this is the least interesting part of
building a model, I’m going to rely heavily on the fastai library. Note that for
this code to run, the library will need to be importable, which is most simply
done by cloning the repository and then symlinking the library directory into
the same directory that the model is in.&lt;/p&gt;

&lt;p&gt;To start, we’ll set up our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;data&lt;/code&gt; folder and download our dataset via
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;torchvision.datasets&lt;/code&gt;. We’ll also convert the dataset to numpy arrays of
floating point values, and scale the inputs to lie between 0 and 1.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;n&quot;&gt;os&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;makedirs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;PATH&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;exist_ok&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;trn_ds&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;CIFAR10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;PATH&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;download&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;tst_ds&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;CIFAR10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;PATH&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;train&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;False&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;download&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;trn&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;trn_ds&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'float32'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;255&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;trn_ds&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train_labels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;tst&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tst_ds&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;test_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'float32'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;/&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;255&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tst_ds&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;test_labels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;Now &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;trn&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tst&lt;/code&gt; are tuples containing our training and test inputs/outputs,
respectively. Next we’ll set up our preprocessing transformations using fastai.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;n&quot;&gt;sz&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;128&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;stats&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.4914&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;mf&quot;&gt;0.48216&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;mf&quot;&gt;0.44653&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]),&lt;/span&gt;
         &lt;span class=&quot;n&quot;&gt;np&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.24703&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;mf&quot;&gt;0.24349&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;  &lt;span class=&quot;mf&quot;&gt;0.26159&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;aug_tfms&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;RandomFlip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Cutout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;16&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;tfms&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tfms_from_stats&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;stats&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sz&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;aug_tfms&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;aug_tfms&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;pad&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;Our inputs will be size 32x32, with batch size 128 (you may need to decrease
this depending on your hardware; this is the value used in the original paper).
We set up our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tfms&lt;/code&gt; object to be a list of transformations for our inputs: we
normalize based on the known means and standard deviations, take a random crop
after padding each side by 4 pixels, and randomly flip the image 50% of the time.
We also use &lt;a href=&quot;https://arxiv.org/abs/1708.04552&quot;&gt;cutout&lt;/a&gt;, which will
randomly zero out a square in our input image. Here we set cutout to use 1
square of length 16.&lt;/p&gt;
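
&lt;p&gt;The idea behind cutout is simple enough to sketch by hand. The version below is an illustration
on a single-channel image stored as a list of rows, not fastai’s implementation; the function name
and the choice to clip squares that hang over the border are assumptions.&lt;/p&gt;

```python
import random


def cutout(image, length=16, rng=random):
    """Zero out one square of side `length` at a random centre.

    `image` is a list of rows (a single channel). Squares that hang over
    the border are clipped, which is one common way to implement cutout.
    """
    height, width = len(image), len(image[0])
    cy, cx = rng.randrange(height), rng.randrange(width)
    top, bottom = max(0, cy - length // 2), min(height, cy + length // 2)
    left, right = max(0, cx - length // 2), min(width, cx + length // 2)
    out = [row[:] for row in image]  # copy so the input is untouched
    for y in range(top, bottom):
        for x in range(left, right):
            out[y][x] = 0.0
    return out


image = [[1.0] * 32 for _ in range(32)]
masked = cutout(image, length=16, rng=random.Random(0))
print(sum(v == 0.0 for row in masked for v in row), 'pixels zeroed')
```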

&lt;p&gt;Finally, we’ll put everything together by creating a dataset object,
instantiating our model, and creating a learner object.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;n&quot;&gt;data&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ImageClassifierData&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;from_arrays&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'data'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;trn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tst&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bs&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;bs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tfms&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tfms&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;wrn&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;WideResNet&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n_grps&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;N&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;k&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;learn&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;ConvLearner&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;from_model_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;wrn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;Here we’re using a wide resnet with 3 groups of four blocks each and a
widening factor of 10. We’ll also let dropout be 0.3, which is the default we
picked when we defined the class. This results in a 28-layer network and
produced the best results for our dataset.&lt;/p&gt;
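
&lt;p&gt;The depth of 28 follows the paper’s convention of \(6N + 4\) layers for three groups of
two-convolution blocks. A tiny sketch of that arithmetic (the helper name and the
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;convs_per_block&lt;/code&gt; parameter are hypothetical):&lt;/p&gt;

```python
def wrn_depth(n_grps=3, N=4, convs_per_block=2):
    """Depth by the paper's convention: conv layers in the blocks plus 4."""
    return n_grps * N * convs_per_block + 4


print(wrn_depth())  # 3 * 4 * 2 + 4
```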

&lt;p&gt;To train, we will follow the same training procedure outlined in the original
paper.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;0.01&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;wds&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;5e-4&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;epochs&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;enumerate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;60&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;60&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;40&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;40&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;learn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fit&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;epochs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;wds&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;wds&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;best_save_name&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;sa&quot;&gt;f&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'wrl-10-28-p&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt;&lt;span class=&quot;si&quot;&gt;}&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;We train for 200 epochs in four phases, dividing the learning rate by five
after each one. Fastai will automatically save the best performing model of each
phase in our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;data&lt;/code&gt; directory since we set the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;best_save_name&lt;/code&gt; parameter.&lt;/p&gt;
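
&lt;p&gt;To see what schedule this loop produces, the sketch below lists the epochs and learning rate of
each phase. It only mirrors the arithmetic of the loop; &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;phase_schedule&lt;/code&gt; is a hypothetical helper
and not part of fastai.&lt;/p&gt;

```python
def phase_schedule(lr=0.01, phases=(60, 60, 40, 40), factor=5):
    """List (first_epoch, last_epoch, lr) for each training phase."""
    schedule, start = [], 0
    for epochs in phases:
        schedule.append((start, start + epochs - 1, lr))
        start += epochs
        lr /= factor  # same update as `lr /= 5` in the training loop
    return schedule


for first, last, lr in phase_schedule():
    print(f'epochs {first:3d}-{last:3d}: lr = {lr:.5f}')
```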

&lt;p&gt;In my own tests, this model achieved a final test time accuracy of 95.84%. The
current state of the art for CIFAR-10 is about 98% (though they also trained for
9 times as long). Not bad for less than 100 lines of code!&lt;/p&gt;

&lt;h2 id=&quot;conclusion&quot;&gt;Conclusion&lt;/h2&gt;

&lt;p&gt;In this post, I walked through implementing the wide residual network.
Leveraging PyTorch’s modular API, we were able to construct the model with just
a few dozen lines of code. We also were able to skip past the mundane image
processing and training loop using the fastai library.&lt;/p&gt;

&lt;p&gt;Our final results got us almost 96% accuracy on a rather challenging dataset. We
are within 2% of the best that anybody has ever done. While deep learning moves
at breakneck speeds, often times papers will present ideas that are fairly
straightforward to reimplement yourself. This isn’t always the case, like in
some experiments that require absolutely &lt;a href=&quot;https://en.wikipedia.org/wiki/AlphaZero&quot;&gt;enormous computational
power&lt;/a&gt;. But in cases like the wide
resnet, it can be really fun and extremely rewarding to recreate a paper’s
experiments from scratch.&lt;/p&gt;</content>

      
      
      
      
      

      
        <author>
            <name>Brandon L. Morris</name>
          
          
        </author>
      

      
        <category term="ai" />
      

      

      
        <summary type="html"></summary>
      

      
      
    </entry>
  
  
  
    <entry>
      
      <title type="html">Mastering the Learning Rate to Speed Up Deep Learning</title>
      
      <link href="https://brandonlmorris.github.io/2018/06/24/mastering-the-learning-rate/" rel="alternate" type="text/html" title="Mastering the Learning Rate to Speed Up Deep Learning" />
      <published>2018-06-24T00:00:00+00:00</published>
      <updated>2018-06-24T00:00:00+00:00</updated>
      <id>https://brandonlmorris.github.io/2018/06/24/mastering-the-learning-rate</id>
      <content type="html" xml:base="https://brandonlmorris.github.io/2018/06/24/mastering-the-learning-rate/">&lt;!--
Outline:
  - Training neural nets is hard
    - A lot of time, a lot of computational power
  - Picking the right hyperparameters is hard
    - Can dramatically impact training time and performance
    - Can't be optimized like network parameters
    - Interdependent relationships
    - Come with experience, not much guidelines
  - Finding the Right Learning Rate
    - LR is the single most important hyperparameter
      - Too small: network takes forever to train
      - Too big: network won't be able to converge
    - Typical method: piecewise annealing
    - New method: Cyclical learning rate (with restarts)
    - LR range test
  - Pedal to the metal: Superconvergence and the 1cycle policy
  - Conclusion
    - Practical, insanely useful advice
    - Hyperparameter search shouldn't be stumbling in the dark or hidden
      knowledge

 Pictures? Graphs
 --&gt;

&lt;p&gt;Efficiently training deep neural networks can often be an art as much as a
science. Industry-grade libraries like &lt;a href=&quot;https://pytorch.org&quot;&gt;PyTorch&lt;/a&gt; and
&lt;a href=&quot;https://www.tensorflow.org&quot;&gt;TensorFlow&lt;/a&gt; have rapidly increased the speed with which efficient
deep learning code can be written, but there is still a lot of work required to
create a performant model.&lt;/p&gt;

&lt;p&gt;Let’s say, for example, you want to build an image classifier model. A
convolutional neural network would be the proper approach to utilize deep
learning. But how many layers go in your network? How much momentum and weight
decay should you use? What’s the best dropout probability?&lt;/p&gt;

&lt;p&gt;The reality is that these questions don’t have definitive answers. What works
great on one dataset might not work nearly as well on another. There are
sensible defaults and good rules of thumb, but finding the best combination is
nontrivial. These kinds of decisions are known as &lt;strong&gt;hyperparameters&lt;/strong&gt;: values
that are determined prior to actually executing the training algorithm. Figuring
out the optimal set of hyperparameters can be one of the most time consuming
portions of creating a machine learning model, and that’s particularly true in
deep learning.&lt;/p&gt;

&lt;h2 id=&quot;difficulties-in-finding-the-right-hyperparameters&quot;&gt;Difficulties in Finding the Right Hyperparameters&lt;/h2&gt;

&lt;p&gt;Unlike the parameters inside the model, the hyperparameters are difficult to
optimize. While it’s possible to optimize hyperparameters with &lt;a href=&quot;https://en.wikipedia.org/wiki/Hyperparameter_optimization#Bayesian_optimization&quot;&gt;Bayesian
methods&lt;/a&gt;, this is almost never done in practice. Instead,
the best set of hyperparameters is typically sought through a brute force
search.&lt;/p&gt;

&lt;p&gt;Part of the difficulty of finding the right hyperparameter values is the
complex interplay between them. One value of weight decay may work well
for a particular learning rate and poorly for another. Changing one value
 impacts many others in ways that are difficult to control.&lt;/p&gt;

&lt;p&gt;A tempting, naive method is to set up reasonable steps for each hyperparameter,
and loop over a range, trying different values for each one. This is known as &lt;strong&gt;grid
search&lt;/strong&gt;, and it’s generally a bad idea for two reasons. First, the model has to
be completely retrained for each set of hyperparameters, and the number of sets
grows exponentially with the number of hyperparameters. Most of these values
will be suboptimal, meaning that we’re wasting a great deal of time and energy
unnecessarily retraining our model. The second reason is a little more subtle.
Our steps will need to have a reasonable size to reduce the number of times we
need to retrain the model, meaning we’re jumping over a decent bit of the search
space with each iteration. There’s no reason our particular intervals are likely
to contain good values, so it is very possible we will entirely skip over them.
In fact, just doing a random search will usually yield better results
than stepping across a fixed interval. The picture below depicts this visually.&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block; margin: 0 auto; width: 75%&quot; src=&quot;https://brandonlmorris.github.io/images/learning-rate/gridsearchbad.jpeg&quot; alt=&quot;Grid search is not a good choice&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;Somewhat
counterintuitively, randomly searching for hyperparameter values can give better
results than a systematic grid search. Taken from &lt;a href=&quot;http://www.jmlr.org/papers/volume13/bergstra12a/bergstra12a.pdf&quot;&gt;
Random Search for Hyper-Parameter Optimization&lt;/a&gt;&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;
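
&lt;p&gt;One way to see the problem with grid search is to count how many distinct values each method
tries for a single hyperparameter under the same training budget. The sketch below is a toy
illustration with two hyperparameters on the unit square; the function names and the budget of 9
runs are assumptions chosen for the example.&lt;/p&gt;

```python
import random


def grid_points(steps=3):
    """All combinations of `steps` evenly spaced values per hyperparameter."""
    values = [i / (steps - 1) for i in range(steps)]
    return [(a, b) for a in values for b in values]


def random_points(n, rng):
    """`n` points drawn uniformly from the unit square."""
    return [(rng.random(), rng.random()) for _ in range(n)]


grid = grid_points(steps=3)
rand = random_points(len(grid), random.Random(0))

# Same budget of 9 trainings, but count the distinct values tried for the
# first hyperparameter (imagine it is the only one that really matters).
print('grid  :', len({a for a, _ in grid}), 'distinct values')
print('random:', len({a for a, _ in rand}), 'distinct values')
```

&lt;p&gt;If only one of the two hyperparameters matters much, the grid spends its 9 runs on just 3 settings
of it, while the random search explores a new setting with every run.&lt;/p&gt;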

&lt;p&gt;Unfortunately, the state of the art in hyperparameter selection is little more
than a random search. Most values have sensible defaults, but picking the best
possible set can have a significant impact on the model’s final performance.
Many machine learning researchers and practitioners develop intuitions about
good values and how hyperparameters interact, but it takes a lot of time and
practice. However, some recent and exciting research has outlined techniques
for finding arguably the most important hyperparameter: the &lt;strong&gt;learning rate&lt;/strong&gt;.&lt;/p&gt;

&lt;h2 id=&quot;what-is-the-learning-rate&quot;&gt;What is the Learning Rate?&lt;/h2&gt;

&lt;p&gt;Neural network training is typically performed as stochastic optimization. We
start out with a random set of network parameters, find out which direction they
should move to be improved, then take a step in that direction. This process is
known as &lt;strong&gt;gradient descent&lt;/strong&gt; (the stochastic portion comes from the fact that
we find our improvement direction on a random subset of the training data). The
learning rate determines how big of a step we take in updating the parameters.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;c1&quot;&gt;# w is our weight, and dw is the derivative
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;w&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;learning_rate&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dw&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;The above parameter update occurs every iteration of the training process
(though modern networks almost always use a more sophisticated update that adds
extra terms). Without a doubt, &lt;strong&gt;the learning rate is the single most important
hyperparameter for a deep neural network&lt;/strong&gt;. If the learning rate is too small,
the parameters will only change in tiny ways, and the model will take too long
to converge. On the other hand, if the learning rate is too large, the
parameters could jump over low spaces of the loss function, and the network may
never converge.&lt;/p&gt;
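
&lt;p&gt;Both failure modes show up even on a toy problem. The sketch below runs the update rule from above
on \(f(w) = w^2\), whose derivative is \(2w\); the function name, the starting point, and the three
learning rates are assumptions picked for illustration.&lt;/p&gt;

```python
def descend(learning_rate, steps=20, w=1.0):
    """Minimise f(w) = w**2 with plain gradient descent."""
    for _ in range(steps):
        dw = 2 * w  # derivative of w**2
        w += -learning_rate * dw
    return w


for lr in (0.001, 0.1, 1.1):
    print(f'lr={lr}: w after 20 steps = {descend(lr):.4g}')
```

&lt;p&gt;The smallest rate barely moves away from the starting point, the middle one gets close to the
minimum at 0, and the largest one overshoots further on every step and diverges.&lt;/p&gt;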

&lt;blockquote class=&quot;twitter-tweet&quot; data-lang=&quot;en&quot;&gt;&lt;p lang=&quot;en&quot; dir=&quot;ltr&quot;&gt;3e-4 is
the best learning rate for Adam, hands down.&lt;/p&gt;&amp;mdash; Andrej Karpathy
(@karpathy) &lt;a href=&quot;https://twitter.com/karpathy/status/801621764144971776?ref_src=twsrc%5Etfw&quot;&gt;November
24, 2016&lt;/a&gt;&lt;/blockquote&gt;
&lt;script async=&quot;&quot; src=&quot;https://platform.twitter.com/widgets.js&quot; charset=&quot;utf-8&quot;&gt;&lt;/script&gt;

&lt;p&gt;Picking the learning rate is pretty arbitrary. There is a range of reasonable
values, but that range and the optimal value will vary with the architecture and
dataset. As Andrej Karpathy joked in a tweet seen above, saying that one
learning rate is “the best” is pretty preposterous.&lt;/p&gt;

&lt;p&gt;Commonly, the ideal learning rate will change during training. Most world-class
deep architectures are trained with a piecewise annealing strategy: train the
network for a while with one learning rate, and when the model stops improving,
decrease the learning rate by some factor and keep going. Intuitively, this
makes some sense: if the model gets to a low spot in the loss space, the
steps we take may be too big to keep from jumping across deeper valleys.
Decreasing the learning rate allows for a more fine-grained training.&lt;/p&gt;
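
&lt;p&gt;A minimal sketch of this strategy, assuming we decrease the rate whenever the loss has failed to
improve for a fixed number of epochs. The function name, the factor of 10, and the patience of 2 are
hypothetical choices, and the list of losses is made up for the example.&lt;/p&gt;

```python
def anneal_on_plateau(losses, lr=0.1, factor=10, patience=2):
    """Return the learning rate used at each epoch.

    The rate is divided by `factor` whenever the loss has not improved
    for `patience` consecutive epochs.
    """
    best, stale, used = float('inf'), 0, []
    for loss in losses:
        used.append(lr)
        if best - loss > 0:  # strictly better than anything seen so far
            best, stale = loss, 0
        else:
            stale += 1
            if stale == patience:
                lr, stale = lr / factor, 0
    return used


losses = [1.0, 0.8, 0.7, 0.7, 0.7, 0.5, 0.5, 0.5]
print(anneal_on_plateau(losses))
```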

&lt;p&gt;While piecewise annealing works in practice, we’ll soon see that it is
suboptimal. There are better ways that we can (1) systematically find
appropriate learning rate(s) for our particular problem, and (2) schedule the
learning rate to automatically vary for faster training and improved
performance.&lt;/p&gt;

&lt;!--
  - Finding the Right Learning Rate
    - LR is the single most important hyperparameter
      - Too small: network takes forever to train
      - Too big: network won't be able to converge
    - Typical method: piecewise annealing
    - New method: Cyclical learning rate (with restarts)
    - LR range test
--&gt;

&lt;h2 id=&quot;cyclical-learning-rates&quot;&gt;Cyclical Learning Rates&lt;/h2&gt;

&lt;p&gt;Picking the perfect learning rate is hard. In fact, it’s probably too hard to
find the singular best value. Instead, we can pick a &lt;em&gt;range of learning rates&lt;/em&gt;
and move through them during training. Kind of surprisingly, this method of
&lt;strong&gt;cyclical learning rates&lt;/strong&gt; &lt;a href=&quot;https://arxiv.org/abs/1506.01186&quot;&gt;works pretty well&lt;/a&gt;.&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block; margin: 0 auto; width: 66%&quot; src=&quot;https://brandonlmorris.github.io/images/learning-rate/cycle-lr.png&quot; alt=&quot;Cycling the learning rate during training&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;Plot of the
learning rate during training. Each cycle starts with a high learning rate that
decreases during training. In this case, we're using cosine annealing. During
training, we'll spend most of our time at least close to the optimal learning
rate, and automatically decrease it for better fine-grained training.&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;p&gt;Cycling through values for the learning rate during training alleviates two of
the problems with picking the learning rate. First, we don’t need to find an
exactly perfect value, just a range of potentially good values. If we pick our
range well (more on that shortly), we will be close to the optimal value for
most of the training cycle, which is much better than randomly searching for the
perfect learning rate. Additionally, we no longer need to manually schedule the
learning rate to decrease during training, since the cycle does it for us. Just
be sure to start the cycle with a high rate, and decrease it to a low rate.&lt;/p&gt;

&lt;p&gt;You might be asking yourself why we immediately reset the learning rate to a high
value instead of allowing it to gradually climb back up. Resetting to a high
learning rate gives us the benefit of a &lt;a href=&quot;http://arxiv.org/abs/1608.03983&quot;&gt;warm restart&lt;/a&gt; in our
optimization and can improve our generalization. Remember that in machine
learning, our primary goal is not to create a model that works well on the
training data, but has high performance on the &lt;em&gt;test&lt;/em&gt; data. This means that we
not only want to find a low spot in the loss space during training, but we also
want a very &lt;em&gt;wide&lt;/em&gt; space. That way, even when our model is presented with new
data that moves it around in loss space, it’s still likely to be at a very low
spot, and hence very accurate.  Warm restarts help us find those low and wide
spaces that we’re looking for.  Even if we find a cozy low valley with our low
learning rate, restarting it to a high value at the start of a new cycle will
pop us right out of that space if it’s not wide enough.&lt;/p&gt;

&lt;p&gt;There are several ways to tinker with cyclical learning rates that might improve
the final results. The length of a cycle is usually about an epoch, but longer
cycles are possible. We can even increase the length of the cycle after each
cycle.  For instance, start with a cycle length of one epoch, then two epochs,
then four, and so on. Schedules like this often give good results in part
because they spend more time at lower rates during the end of training, allowing
the model to home in on an optimal space of the loss.&lt;/p&gt;
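
&lt;p&gt;To make the schedule concrete, here is a minimal sketch of cosine annealing with warm restarts and cycle lengths that grow after each restart. The function name, its parameters, and the choice to measure cycles in optimizer steps are assumptions made for illustration, not something prescribed by the papers linked above.&lt;/p&gt;

```python
import math


def cosine_restart_lr(step, lr_min, lr_max, first_cycle_steps, cycle_mult=2):
    """Learning rate at `step` for cosine annealing with warm restarts.

    Hypothetical helper: each cycle starts at lr_max, decays to lr_min along
    a cosine curve, then jumps back to lr_max. Each cycle is `cycle_mult`
    times longer than the previous one.
    """
    cycle_len = first_cycle_steps
    # Walk forward through the cycles until we reach the one containing `step`.
    while step >= cycle_len:
        step -= cycle_len
        cycle_len *= cycle_mult
    progress = step / cycle_len  # 0 at the start of a cycle, close to 1 at its end
    return lr_min + 0.5 * (lr_max - lr_min) * (1 + math.cos(math.pi * progress))
```

&lt;p&gt;With a first cycle of 100 steps and the default multiplier, restarts would happen at steps 100 and 300, so later cycles spend progressively longer at low learning rates.&lt;/p&gt;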

&lt;p&gt;Cyclical learning rates allow us to circumvent the difficulty of picking a good
learning rate. All we need are approximate bounds, and we can spend the majority
of our training time being close to the optimal value, even as that optimal
value changes during training. Additionally, we get the added benefit of
restarts that will help us find wide areas in the loss space, improving our
generalization ability. Now all we need is a method to find the approximate
bounds to cycle through.&lt;/p&gt;

&lt;h2 id=&quot;the-lr-range-test&quot;&gt;The LR Range Test&lt;/h2&gt;

&lt;p&gt;Cyclical learning rates free us from needing to find an optimal learning
rate, but we still need an upper and lower bound for our cycles. Luckily, we
don’t need to resort to the guessing game and random search that plagued our
initial hyperparameter search. Instead, the &lt;a href=&quot;https://arxiv.org/abs/1506.01186&quot;&gt;paper&lt;/a&gt; that described
the cyclical learning rate method also introduced a systematic method for
finding good boundaries: the &lt;strong&gt;LR Range Test&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;The LR Range Test is simple to understand and cheap to execute. Start with your
initialized network, and pick a very small learning rate (much smaller than you
would ever likely use). As you train, exponentially increase the learning rate.
Keep track of the loss function for each value of the learning rate. If you’re
in the right range, the loss should drop, then increase as the learning rate gets
too high. Below is a graph of the loss value as a function of the learning rate.&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block; margin: 0 auto; width: 66%&quot; src=&quot;https://brandonlmorris.github.io/images/learning-rate/lr-range.png&quot; alt=&quot;Loss as a function of the learning rate in the LR Range Test&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;Loss as a
function of the learning rate for the LR Range Test. Starting from scratch,
exponentially increase the learning rate until the loss begins increasing. The
point just before the loss starts to increase is the upper bound to use for
cyclical learning rate. A good lower bound is one tenth of the upper bound.&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;p&gt;Looking at a plot of the loss vs. the learning rate, we can find our boundaries
to use for our cycles. The place to look for is the learning rate where the loss
stops decreasing: the minimum value. In the graph above it’s roughly
\(10^{-1}\). This value is probably too high to use as our boundary, which is
why the loss stopped decreasing here. We need to go back just a bit to a smaller
value for our maximum boundary to use in our cycles. A good one to use here
would be \(10^{-2}\).  At that point the loss is still decreasing with some
gusto. We wouldn’t want to pick the value with the steepest slope, since this
will be the maximum, and the cycle will only spend a little while at that point.
For the minimum, we can use any value that is smaller; typically we can divide
the maximum by a factor such as 3 or 10.&lt;/p&gt;
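
&lt;p&gt;The procedure above can be sketched in a few lines. The first helper produces the exponentially increasing learning rates to train with, and the second applies the rule of thumb from this section to the recorded losses. Both function names and the default factors of 10 are illustrative assumptions; in practice you would still look at the plot rather than trust the arithmetic blindly.&lt;/p&gt;

```python
def lr_range_schedule(lr_start, lr_end, num_steps):
    """Exponentially spaced learning rates from lr_start up to lr_end."""
    ratio = (lr_end / lr_start) ** (1 / (num_steps - 1))
    return [lr_start * ratio ** i for i in range(num_steps)]


def suggest_bounds(lrs, losses, back_off=10, lower_factor=10):
    """Pick cycle bounds from an LR Range Test run (hypothetical helper).

    Finds the learning rate where the loss was lowest, backs off by
    `back_off` to get the maximum, and divides by `lower_factor` again
    to get the minimum.
    """
    best = min(range(len(losses)), key=lambda i: losses[i])
    lr_max = lrs[best] / back_off
    lr_min = lr_max / lower_factor
    return lr_min, lr_max
```

&lt;p&gt;For a run whose loss bottoms out near \(10^{-1}\), this would suggest cycling between roughly \(10^{-3}\) and \(10^{-2}\), matching the reasoning above.&lt;/p&gt;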

&lt;h2 id=&quot;pedal-to-the-metal-super-convergence&quot;&gt;Pedal to the Metal: Super-Convergence&lt;/h2&gt;

&lt;p&gt;Cyclical learning rates work well in practice, but there’s actually a way to
take it a step further. The technique was introduced by Leslie Smith again and
dubbed &lt;a href=&quot;http://arxiv.org/abs/1708.07120&quot;&gt;super-convergence&lt;/a&gt;. This strategy is a
modification of the cyclical learning rate, and allows for training to converge
substantially faster, hence the name.&lt;/p&gt;

&lt;p&gt;To exploit super-convergence, instead of iterating over cycles of the learning
rate, we use a single “1cycle” policy. We derive the maximum and minimum
learning rate from the LR Range Test as before. Now, we take one long cycle,
moving up from the minimum to the maximum, and back down again. Then we continue
training and decreasing the learning rate. We also cycle the momentum inversely,
going from high to low as the learning rate rises and back to high as it falls.
A plot of the learning rate and momentum schedules is shown below.&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block; margin: 0 auto; width: 75%&quot; src=&quot;https://brandonlmorris.github.io/images/learning-rate/lr_plot.png&quot; alt=&quot;The 1cycle learning rate and momentum schedule&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;The 1cycle
learning rate and momentum schedule. Following this policy during training leads
to very fast train times and the phenomenon known as
&quot;super-convergence&quot;.&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;
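
&lt;p&gt;As a rough sketch of a schedule with this shape, the function below uses straight-line segments for the cycle and for the final decay. The momentum bounds of 0.85 and 0.95, the fraction of training spent in the cycle, and the final divisor are all assumed values chosen for illustration; they are not taken from the plot above.&lt;/p&gt;

```python
def one_cycle(step, total_steps, lr_min, lr_max,
              mom_min=0.85, mom_max=0.95, cycle_frac=0.9, final_div=100):
    """Return (learning_rate, momentum) for a 1cycle-style schedule.

    Hypothetical helper with linear segments: the learning rate rises from
    lr_min to lr_max and falls back to lr_min during the cycle, while the
    momentum does the opposite. After the cycle, the learning rate keeps
    decaying towards lr_min / final_div and the momentum stays at mom_max.
    """
    cycle_steps = int(total_steps * cycle_frac)
    if step >= cycle_steps:
        # Final phase: continue lowering the learning rate below lr_min.
        tail = (step - cycle_steps) / max(1, total_steps - cycle_steps)
        lr = lr_min + tail * (lr_min / final_div - lr_min)
        return lr, mom_max
    half = cycle_steps / 2
    frac = 1 - abs(step - half) / half  # 0 at both ends of the cycle, 1 at the peak
    lr = lr_min + frac * (lr_max - lr_min)
    mom = mom_max - frac * (mom_max - mom_min)
    return lr, mom
```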

&lt;p&gt;Amazingly, adopting the 1cycle policy permits incredibly fast training. The
original authors reported training deep networks on large datasets in a fraction
of the epochs required by other training regimes. Recently, fast.ai &lt;a href=&quot;https://dawn.cs.stanford.edu/benchmark/&quot;&gt;leveraged
super-convergence&lt;/a&gt; to train an ImageNet model in less than three hours,
and a CIFAR10 model &lt;strong&gt;in less than three minutes&lt;/strong&gt;.&lt;/p&gt;

&lt;h2 id=&quot;conclusion&quot;&gt;Conclusion&lt;/h2&gt;

&lt;p&gt;There are many difficulties in training deep neural networks. The best
practitioners have spent a long time cutting their teeth and developing
intuitions about the best values for hyperparameters. Fortunately, research has
shown us better ways to pick the learning rate than wasting time and computing
power fumbling around in the dark. The LR Range Test provides a quick way to
find suitable boundaries for the learning rate, which we can cycle through
during training to completely avoid having to find an optimal value. This means
more time can be spent training more networks, and less time searching for
hyperparameters. Additionally, the 1cycle policy lets us train neural nets at
breakneck speeds, creating performant models in a fraction of the training time.&lt;/p&gt;

&lt;blockquote&gt;
  &lt;p&gt;Special thanks to the &lt;a href=&quot;https://course.fast.ai&quot;&gt;fast.ai&lt;/a&gt; course for providing the inspiration and
instruction for this blog post.&lt;/p&gt;
&lt;/blockquote&gt;</content>

      
      
      
      
      

      
        <author>
            <name>Brandon L. Morris</name>
          
          
        </author>
      

      
        <category term="ai" />
      

      

      
        <summary type="html">&amp;lt;!– Outline: Training neural nets is hard A lot of time, a lot of computational power Picking the right hyperparameters is hard Can dramatically impact training time and performance Can’t be optimized like network parameters Interdependent relationships Come with experience, not much guidelines Finding the Right Learning Rate LR is the single most important hyperparameter Too small: network takes forever to train Too big: network won’t be able to converge Typical method: piecewise annealing New method: Cyclical learning rate (with restarts) LR range test Pedal to the metal: Superconvergence and the 1cycle policy Conclusion Practical, insanely useful advice Hyperparameter search shouldn’t be stumbling in the dark or hidden knowledge</summary>
      

      
      
    </entry>
  
  
  
    <entry>
      
      <title type="html">Breaking Neural Nets with Adversarial Examples</title>
      
      <link href="https://brandonlmorris.github.io/2018/02/11/breaking-neural-nets/" rel="alternate" type="text/html" title="Breaking Neural Nets with Adversarial Examples" />
      <published>2018-02-11T00:00:00+00:00</published>
      <updated>2018-02-11T00:00:00+00:00</updated>
      <id>https://brandonlmorris.github.io/2018/02/11/breaking-neural-nets</id>
      <content type="html" xml:base="https://brandonlmorris.github.io/2018/02/11/breaking-neural-nets/">&lt;p&gt;&lt;a href=&quot;https://brandonlmorris.github.io/2017/09/09/what-is-deep-learning&quot;&gt;Deep learning&lt;/a&gt; has asserted itself as the king of machine learning. No
other method produced thus far has had such excellent success at machine
learning tasks that are increasingly complex. In some cases, deep neural
networks trained by backpropagation and stochastic gradient descent (i.e. deep
learning) have been able to dramatically outperform humans merely by being
presented examples of a particular task. The model cultivates “knowledge” by
discerning which signals (called “features”) are significant and how their
existence or absence, in conjunction with other signals, contributes to an
overall result.&lt;/p&gt;

&lt;p&gt;It is beyond doubt that deep neural networks are incredibly sophisticated and
versatile machine learning models. They can derive meaning in clever and
sometimes unexpected ways, all under the loose guide of human-defined
architecture and algorithm. The fact of the matter is that very little human
knowledge is explicitly implanted in these performant models: neural networks
achieve a bottom-up understanding of their task from the training data. Indeed,
understanding how neural networks actually operate is an active area of research
that deserves more attention, as &lt;a href=&quot;https://www.youtube.com/watch?v=Qi1Yry33TQE&quot;&gt;Ali Rahimi points out in his NIPS 2017
talk&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;But deep neural networks, for all their successes in the recent past, have a
significant weakness. It turns out that these very sophisticated models can be
fooled into producing dramatically incorrect results. Consider the figure below: to
any normal human being it seems like the image is simply random noise. And
indeed, the image below was created in part by selecting random values for each
of the pixels.&lt;/p&gt;

&lt;p&gt;&lt;br /&gt;&lt;/p&gt;
&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block;margin:0 auto;&quot; src=&quot;https://brandonlmorris.github.io/images/axs/adversarial.jpeg&quot; alt=&quot;Seemingly-random noise&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;Just some
harmless random noise; that is, unless you're a neural network.&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;p&gt;A deep neural network, trained on the &lt;a href=&quot;http://image-net.org&quot;&gt;ImageNet&lt;/a&gt; dataset of over a
million images across one thousand categories, might categorize this image as
something innocuous like a prayer rug. Furthermore, it would likely give it a
low confidence rating, suggesting that the neural network isn’t exactly sure
what the image contains, but it &lt;em&gt;might&lt;/em&gt; be a prayer rug.&lt;/p&gt;

&lt;p&gt;However, the &lt;a href=&quot;https://arxiv.org/abs/1409.4842&quot;&gt;Inception network&lt;/a&gt; pretrained on ImageNet &lt;strong&gt;classified
this image as a school bus with 93% confidence.&lt;/strong&gt; Inception is normally about
70% accurate on normal images, so what went wrong? Is this just a fluke? It is
not: results like this are consistently reproducible for even the
most advanced and accurate deep neural networks, and they form the basis of a
persistent weakness in deep learning, known as &lt;strong&gt;adversarial examples&lt;/strong&gt;.&lt;/p&gt;

&lt;h2 id=&quot;what-are-adversarial-examples&quot;&gt;What are adversarial examples?&lt;/h2&gt;

&lt;p&gt;Adversarial examples lack a formal definition, but they can broadly be
considered as inputs that cause otherwise performant machine learning models to
produce very inaccurate results. Note that normal inputs can be incorrectly
interpreted by a model without necessarily being adversarial.&lt;/p&gt;

&lt;p&gt;Some of the most interesting adversarial examples are those that come from
otherwise normal (and correctly handled) inputs. Take for instance the two
pictures below. On top, the image is perfectly normal, and the Inception model
classifies it accurately as a school bus with 95% confidence. However, we can
manipulate the image ever so slightly by introducing minute perturbations to the
image and create the picture on the bottom. The two are almost
indistinguishable, but not to our Inception model. The new image is &lt;strong&gt;classified
as an ostrich with over 98% confidence&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;&lt;br /&gt;&lt;/p&gt;
&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block; margin: 2% auto&quot; src=&quot;https://brandonlmorris.github.io/images/axs/normal-bus.jpeg&quot; alt=&quot;An unmodified image of a school bus&quot; /&gt;
&lt;img style=&quot;display:block; margin: 2% auto&quot; src=&quot;https://brandonlmorris.github.io/images/axs/adversarial-bus.jpeg&quot; alt=&quot;The same school bus image with adversarial perturbations&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;A school bus
or an ostrich? On top, the normal image is correctly classified, with high
confidence. But, by slightly changing the pixels in the image here and there
(bottom), we can trick the model into thinking this is an ostrich with almost
complete confidence.&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;p&gt;And these kinds of results can be replicated with nearly any image. So while
deep neural networks are ostensibly very well equipped to manage normal data,
when that input can be manipulated, even minutely, the classifier can be
dramatically fooled.&lt;/p&gt;

&lt;h2 id=&quot;how-adversarial-examples-are-created&quot;&gt;How adversarial examples are created&lt;/h2&gt;

&lt;p&gt;When they were &lt;a href=&quot;https://arxiv.org/abs/1312.6199&quot;&gt;first presented&lt;/a&gt;, “crafting” (the process
of taking a normal input and transforming it to become adversarial) was
construed as an optimization problem. Given an input and a classifier, find a
perturbation (subject to constraints) that maximizes the error when combined with
the original input and fed through the neural network. From this loose
framework, we can apply known non-convex optimizers to solve the problem, such
as &lt;a href=&quot;https://en.wikipedia.org/wiki/Limited-memory_BFGS&quot;&gt;L-BFGS&lt;/a&gt; which was done in the paper.&lt;/p&gt;
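
&lt;p&gt;One way to write this optimization down, as a sketch rather than the paper’s exact formulation, is to let \(J(x, y; \theta)\) be the model’s loss on input \(x\) with label \(y\) and parameters \(\theta\), and to look for a perturbation \(\delta\) whose size is bounded by some small \(\epsilon\). The choice of the max-norm for the constraint is an assumption made here for concreteness:&lt;/p&gt;

\[\max_{\delta} \; J(x + \delta, y; \theta) \quad \text{subject to} \quad \lVert \delta \rVert_\infty \le \epsilon\]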

&lt;p&gt;While this technique is effective, it is also somewhat slow. L-BFGS can have a
hard time converging with deep neural networks, since it’s a second-order
optimization method. In recent years, a number of alternative attack methods
have been proposed in the literature, but in this post we will only examine one
in detail.&lt;/p&gt;

&lt;p&gt;One of the simplest and fastest methods for crafting adversarial examples relies
on utilizing the gradient that is crucial for training the network. With a
single backwards pass through the network, we can collect all the information
necessary to strike a remarkably effective attack. This method is known as the
&lt;a href=&quot;https://arxiv.org/abs/1412.6572&quot;&gt;Fast Gradient Sign Method&lt;/a&gt;, or FGSM.&lt;/p&gt;

&lt;p&gt;&lt;br /&gt;&lt;/p&gt;
&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block; margin: 2% auto&quot; src=&quot;https://brandonlmorris.github.io/images/axs/fgsm-panda.png&quot; alt=&quot;Crafting an adversarial example with the Fast Gradient Sign Method&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;An example of
crafting an adversarial input using the Fast Gradient Sign Method (FGSM). By
backpropagating to the original image, we can determine which direction each
pixel should move to increase the error in the prediction. We then take a small
step in the sign of that direction, to produce an image that completely fools
the target model.&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;p&gt;The FGSM attack can be characterized by the following equation:&lt;/p&gt;

\[x_{adversary} = x + \epsilon \cdot \text{sign}(\nabla_x J(x, y; \theta))\]

&lt;p&gt;where \(\epsilon\) is the attack strength and \(J\) is our cost function for
training the model. Since \(J\) is a function that determines how “wrong” our
model is after making a prediction, taking the gradient with respect to our
input (\(\nabla_x J\)) tells us how modifying \(x\) will change how correct our
model’s prediction is. By taking the sign we only concern ourselves with
direction and not magnitude. Then we multiply these gradients by our attack
strength \(\epsilon\), which is constrained to be small so that the resulting
adversarial image is similar to the original input. Finally, we combine the
calculated perturbation with the original image by simple addition.&lt;/p&gt;

&lt;p&gt;While the math may be somewhat intimidating, the reality is that the FGSM attack
is both very simple to program, and extremely efficient to execute. Here’s a
simple Python function that calculates adversarial examples for a TensorFlow
model:&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;tensorflow&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;fgsm&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# Gradient of the loss with respect to the input (one backward pass)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;gradients&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;grad_sign&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sign&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;grad&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# FLAGS.eps is the attack strength, assumed to be defined elsewhere&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;adv_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;FLAGS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;eps&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;grad_sign&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;adv_x&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;The function takes a classifier, which can be any object that has an input
tensor &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;model.x&lt;/code&gt; and a loss function &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;model.loss&lt;/code&gt;, and returns the adversarial
example crafted from the initial input. The process follows the description
above almost exactly, and it is easy to see how cheap the computation is (the most
expensive portion is the backwards pass in &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;tf.gradients()&lt;/code&gt;).&lt;/p&gt;
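
&lt;p&gt;To see the same steps without any framework, here is a small NumPy sketch of FGSM against a toy logistic regression classifier, where the gradient of the cross-entropy loss with respect to the input can be written by hand as \((p - y)\,w\). The model, its weights, and the input are made up for illustration and are unrelated to the networks discussed in this post.&lt;/p&gt;

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def fgsm_logistic(x, y, w, b, eps):
    """FGSM for a logistic regression classifier with cross-entropy loss."""
    p = sigmoid(np.dot(w, x) + b)     # predicted probability of class 1
    grad_x = (p - y) * w              # gradient of the loss w.r.t. the input x
    return x + eps * np.sign(grad_x)  # step in the direction that increases the loss


# Toy model and input (hypothetical values chosen for illustration).
w = np.array([2.0, -3.0, 1.0])
b = 0.0
x = np.array([0.2, -0.1, 0.1])
y = 1.0

x_adv = fgsm_logistic(x, y, w, b, eps=0.3)
p_clean = sigmoid(np.dot(w, x) + b)    # above 0.5, so the clean input is class 1
p_adv = sigmoid(np.dot(w, x_adv) + b)  # below 0.5, so the prediction flips
```

&lt;p&gt;Every coordinate of the input moves by exactly the attack strength, and that is enough to push this toy model’s prediction across the decision boundary.&lt;/p&gt;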

&lt;p&gt;FGSM is simple and fast, but is it effective? I trained a simple, three-layer
neural network on the &lt;a href=&quot;http://yann.lecun.com/exdb/mnist/&quot;&gt;MNIST dataset&lt;/a&gt; of handwritten digits. Even without
convolutional layers, I was able to achieve a test accuracy of 97%. Using the
above code, &lt;strong&gt;the adversarial example accuracy of the network was less than one
percent&lt;/strong&gt;. This was possible with an attack strength of just 0.1, which meant
that no pixel was modified by more than 10%. Below is a sample of an adversarial
example the above method produced, which my network classified as a 3.&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block; margin: 0 auto&quot; src=&quot;https://brandonlmorris.github.io/images/axs/mnist-ax.png&quot; alt=&quot;An adversarial MNIST digit produced by FGSM&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;Result of an
FGSM attack on an MNIST handwritten digit example. While clearly still a 5, my
network (which has normal test accuracy of 97%) thought this image was a 3. Can
you see where FGSM manipulated the image?&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;h2 id=&quot;adversarial-examples-outside-of-image-classification&quot;&gt;Adversarial examples outside of image classification&lt;/h2&gt;

&lt;p&gt;Adversarial examples are typically studied in the domain of computer vision, and
typically with the task of object classification (given an image, output what is
inside the image). However, researchers have demonstrated that this weakness of
deep neural networks infects other kinds of applications as well.
&lt;strong&gt;Reinforcement learning&lt;/strong&gt; involves teaching “agents” or programs how to perform
tasks or play games intelligently. RL algorithms often employ deep neural
networks to manage the huge number of possible environment states and action
strategies. As such, they can be fooled by manipulating the data that is fed to
the agent, causing the agent to take incorrect and disadvantageous actions.&lt;/p&gt;

&lt;p&gt;Systems that try to comprehend passages of text can also be fooled. Since text
comprehension is less well understood than image recognition, these systems
don’t necessarily need very sophisticated manipulations to be fooled. Models
that can answer basic questions about a paragraph can be completely thrown off
by the addition of a single, irrelevant sentence. The image below depicts one
such situation.&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block; margin: 0 auto&quot; src=&quot;https://brandonlmorris.github.io/images/axs/comprehension-adversary.jpeg&quot; alt=&quot;A reading comprehension system fooled by an added sentence&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;By adding a
single irrelevant sentence, reading comprehension systems can give very
incorrect answers to the contents of a paragraph. Taken from &quot;Adversarial
Examples for Evaluating Reading Comprehension Systems&quot; (https://arxiv.org/abs/1707.07328)&lt;/figcaption&gt;
&lt;/figure&gt;
&lt;p&gt;&lt;br /&gt;&lt;/p&gt;

&lt;p&gt;Another susceptible application within computer vision is image segmentation. In
some cases, it is insufficient to simply describe &lt;em&gt;what&lt;/em&gt; is in an image: we need
to know &lt;em&gt;where&lt;/em&gt; it is in the image (for instance, self-driving cars). Here
again, adversarial examples can cause devastating effects, such as &lt;a href=&quot;https://arxiv.org/abs/1707.05373&quot;&gt;drawing
pictures of minions in the segmentation map&lt;/a&gt;.&lt;/p&gt;

&lt;h2 id=&quot;conclusions-and-implications&quot;&gt;Conclusions and implications&lt;/h2&gt;

&lt;p&gt;Adversarial examples are a humbling weakness at a time when deep learning seems
capable of solving all the world’s problems and paving the road for general
intelligence. Their existence forces researchers to reevaluate the precise
workings of deep neural networks, and impose a new fundamental understanding
about the robustness of deep learning.&lt;/p&gt;

&lt;p&gt;From a practical standpoint, adversarial examples all but prevent deep learning
from being legitimately used in sensitive applications. Indeed, adversarial
examples pose extremely serious threats to systems that rely on deep learning.
How can anyone responsibly employ facial recognition for authentication, when
&lt;a href=&quot;https://www.cs.cmu.edu/~sbhagava/papers/face-rec-ccs16.pdf&quot;&gt;3D printed glasses can make you appear like someone else&lt;/a&gt;? The same
goes for financial trading, self-driving cars, and defense applications.&lt;/p&gt;

&lt;p&gt;Although a &lt;a href=&quot;https://scholar.google.com/scholar?cites=2835128024326609853&amp;amp;as_sdt=205&amp;amp;sciodt=0,1&amp;amp;hl=en&quot;&gt;flood of research&lt;/a&gt; has been produced into understanding,
exploiting, and defending against adversarial examples, &lt;strong&gt;no known technique for
properly defending against adversarial examples exists&lt;/strong&gt;. Even some of the most
recently proposed defense methods accepted to ICLR 2018 were &lt;a href=&quot;https://github.com/anishathalye/obfuscated-gradients&quot;&gt;able to be
bypassed by smarter attacks&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;The result is a looming cloud over all progress in deep learning, which
accounts for a large portion of recent progress in artificial intelligence
generally. So far, no matter how advanced our models become, they can always be
fooled by relatively cheap methods that would never confuse a human. Does a
simple defensive scheme exist that will solve the issue? Or are adversarial
examples inherent to the current deep learning paradigm, and will persist until
more sophisticated learning techniques that go beyond deep learning are
discovered?&lt;/p&gt;</content>

      
      
      
      
      

      
        <author>
            <name>Brandon L. Morris</name>
          
          
        </author>
      

      
        <category term="ai" />
      

      

      
        <summary type="html">Deep learning has asserted itself as the king of machine learning. No other method produced thus far has had such excellent success at machine learning tasks that are increasingly complex. In some cases, deep neural networks trained by backpropagation and stochastic gradient descent (i.e. deep learning) have been able to dramatically outperform humans merely by being presented examples of a particular task. The model cultivates “knowledge” by discerning which signals (called “features”) are significant and how their existence or absence, in conjunction with other signals, contributes to an overall result.</summary>
      

      
      
    </entry>
  
  
  
    <entry>
      
      <title type="html">Dynamic Routing Between Capsules</title>
      
      <link href="https://brandonlmorris.github.io/2017/11/16/dynamic-routing-between-capsules/" rel="alternate" type="text/html" title="Dynamic Routing Between Capsules" />
      <published>2017-11-16T00:00:00+00:00</published>
      <updated>2017-11-16T00:00:00+00:00</updated>
      <id>https://brandonlmorris.github.io/2017/11/16/dynamic-routing-between-capsules</id>
      <content type="html" xml:base="https://brandonlmorris.github.io/2017/11/16/dynamic-routing-between-capsules/">&lt;p&gt;Convolutional neural networks have dominated the computer vision landscape ever
since &lt;a href=&quot;http://papers.nips.cc/paper/4824-imagenet-classification-with-deep-convolutional-neural-networks&quot;&gt;AlexNet won the ImageNet challenge in 2012&lt;/a&gt;, and for good
reason. Convolutions create a spatial dependency inside the network that
functions as an effective prior for image classification and segmentation.
Weight sharing reduces the number of parameters, and efficient accelerated
implementations are readily available. These days, convolutional neural networks
(or “convnets” for short) are the de facto architecture for almost any computer
vision task.&lt;/p&gt;

&lt;p&gt;However, despite the enormous success of convnets in recent years, the question
begs to be asked, &lt;em&gt;“Can we do better?”&lt;/em&gt; Are there underlying assumptions built
into the fundamentals of convnets that make them in some ways deficient?&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;Capsule networks&lt;/strong&gt; are one novel architecture that attempts to supersede
traditional convnets. Geoffrey Hinton, who helped develop the &lt;a href=&quot;https://en.wikipedia.org/wiki/Backpropagation&quot;&gt;backpropagation
algorithm&lt;/a&gt; and has pioneered neural networks and deep learning, has
&lt;a href=&quot;https://www.youtube.com/watch?v=rTawFwUvnLE&amp;amp;t=602s&quot;&gt;talked&lt;/a&gt; about capsule networks for some time, but until very
recently no work on the idea had been publicly published. Just a few weeks ago,
&lt;a href=&quot;https://arxiv.org/abs/1710.09829&quot;&gt;&lt;em&gt;Dynamic Routing Between Capsules&lt;/em&gt;&lt;/a&gt; by Sara Sabour, Nicholas
Frosst and Geoffrey Hinton was made available, explaining what capsule networks
are and the details of their functionality. Here, I’ll walk through the paper
and give a high level review as to how capsules function, the routing algorithm
described in the paper, and the results of using a capsule network for image
classification and segmentation.&lt;/p&gt;

&lt;h2 id=&quot;why-typical-convnets-are-doomed&quot;&gt;Why Typical Convnets are Doomed&lt;/h2&gt;

&lt;p&gt;Before we dive into how capsules solve the problems of convnets, we first need
to establish what capsules are trying to solve. After all, modern convnets can
be trained for near-perfect accuracy over a million images with a thousand
classes, so how bad can they be? While traditional convnets are great at
classifying images like ImageNet, they fall short of perfect in some key ways.&lt;/p&gt;

&lt;h3 id=&quot;sub-sampling-loses-precise-spatial-relationships&quot;&gt;Sub-sampling loses precise spatial relationships&lt;/h3&gt;

&lt;p&gt;Convolutions are great because they create a spatial dependency in our models
(see &lt;a href=&quot;https://brandonlmorris.github.io/2017/09/09/what-is-deep-learning/&quot;&gt;my previous post&lt;/a&gt; for a high level overview of convolutions, or &lt;a href=&quot;https://cs231n.github.io/convolutional-networks/&quot;&gt;these
lecture notes&lt;/a&gt; for an in-depth explanation of convnets), but they
have a key failure. Commonly, a convolutional layer is followed by a (max)
&lt;strong&gt;pooling&lt;/strong&gt; layer. The pooling layer sub-samples the extracted features by
sliding over patches and pulling out the maximum or average value. This has the
benefit of reducing the dimensionality (making it easier for other parts of our
network to work with), but also &lt;strong&gt;loses precise spatial relationships&lt;/strong&gt;.&lt;/p&gt;
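
&lt;p&gt;To make that loss concrete, here is a minimal sketch in plain Python. The helper
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;max_pool_2x2&lt;/code&gt; and the two activation maps are hypothetical and made up
for illustration; the point is only that two different spatial arrangements can
pool to exactly the same output.&lt;/p&gt;

```python
def max_pool_2x2(feature_map):
    """Apply 2x2 max pooling with stride 2 to a list-of-lists feature map."""
    pooled = []
    for row in range(0, len(feature_map), 2):
        pooled_row = []
        for col in range(0, len(feature_map[0]), 2):
            window = [
                feature_map[row][col], feature_map[row][col + 1],
                feature_map[row + 1][col], feature_map[row + 1][col + 1],
            ]
            pooled_row.append(max(window))
        pooled.append(pooled_row)
    return pooled


# Two hypothetical 4x4 activation maps: the same strong responses (9s),
# but at different positions inside each 2x2 window.
map_a = [
    [9, 0, 0, 9],
    [0, 0, 0, 0],
    [0, 0, 0, 0],
    [9, 0, 0, 9],
]
map_b = [
    [0, 0, 0, 0],
    [0, 9, 9, 0],
    [0, 9, 9, 0],
    [0, 0, 0, 0],
]

pooled_a = max_pool_2x2(map_a)
pooled_b = max_pool_2x2(map_b)
print(pooled_a)  # [[9, 9], [9, 9]]
print(pooled_b)  # [[9, 9], [9, 9]]
```

&lt;p&gt;In the first map the features sit in the corners; in the second they are
clustered in the center. After pooling, the layers that follow cannot tell the
two apart.&lt;/p&gt;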

&lt;p&gt;By precise spatial relationships, I mean the exact ways that the extracted
features relate to one another. For instance, consider the following image of a
kitten:&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block;margin:0 auto;&quot; src=&quot;https://brandonlmorris.github.io/images/capsules/kitten.jpeg&quot; alt=&quot;A picture of a normal looking kitten&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;Just an
ordinary, though cute, kitten&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;A pretrained &lt;a href=&quot;https://arxiv.org/abs/1512.03385&quot;&gt;ResNet50&lt;/a&gt; convnet classifies it as a tabby cat, which is
obviously correct. Convnets are excellent at detecting specific features within
an image, such as the cat’s ears, nose, eyes, paws, etc. and combining them to
form a classification. However, sub-sampling via pooling loses the exact
relationship that those features share with each other: e.g. the eyes should be
level and the mouth should be underneath them. Consider what happens when I edit
some of the spatial relationships, and create a kitten image more in the style
of Pablo Picasso:&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block;margin:0 auto;&quot; src=&quot;https://brandonlmorris.github.io/images/capsules/kitten-picasso.jpg&quot; alt=&quot;A picture of a normal looking kitten&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;A slightly
less ordinary kitten&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;When this image is fed to our convnet, we &lt;strong&gt;still&lt;/strong&gt; get a tabby classification
with similar confidence. That’s completely incorrect! Any person can
immediately tell by looking at the image something isn’t right, but the convnet
plugs along as if the two images are almost identical.&lt;/p&gt;

&lt;h3 id=&quot;convnets-are-invariant-not-equivariant&quot;&gt;Convnets are &lt;em&gt;invariant&lt;/em&gt;, not &lt;em&gt;equivariant&lt;/em&gt;&lt;/h3&gt;

&lt;p&gt;Another shortcoming of typical convnets is that they explicitly strive to be
invariant to change. By invariant, I mean that the results of the entire classification
procedure (the hidden layer activations and the final prediction) are nearly
unchanged by small changes in the input (such as shift, tilt, zoom). This is
effective for the classification task, but it ultimately limits our convnets.
Consider what happens when I flip the previous image of the kitten upside-down:&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block;margin:0 auto;&quot; src=&quot;https://brandonlmorris.github.io/images/capsules/kitten-rotated-180.jpg&quot; alt=&quot;A picture of a normal looking kitten&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;The view of a
kitten if you were hanging from the ceiling&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;This time, our ResNet &lt;strong&gt;thinks our kitten is a guinea pig&lt;/strong&gt;! The problem is that
while convnets are invariant to small changes, they don’t react well to large
changes. Even though all the features are still in the image, the lack of
spatial knowledge within the convnet means it can’t make head or tail of such a
transformation.&lt;/p&gt;

&lt;p&gt;Rather than invariance that’s built into traditional convnets by design in
pooling layers, what we should really strive for is &lt;strong&gt;equivariance&lt;/strong&gt;: the model
will still produce a similar classification, but the internal activations
transform along with the image transformations. Instead of ignoring
transformations, we should adjust alongside them.&lt;/p&gt;
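
&lt;p&gt;The difference can be sketched with a toy one-dimensional example in plain
Python. The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;correlate_1d&lt;/code&gt; helper, the kernel, and the signals are all
hypothetical. Translation is used here only because it is the easiest
transformation to show; the sliding filter response is equivariant to the shift,
while taking a global maximum over it is invariant.&lt;/p&gt;

```python
def correlate_1d(signal, kernel):
    """Slide the kernel over the signal (no padding, stride 1)."""
    width = len(kernel)
    return [
        sum(signal[i + k] * kernel[k] for k in range(width))
        for i in range(len(signal) - width + 1)
    ]


kernel = [1, 2, 1]
signal = [0, 0, 1, 2, 1, 0, 0, 0, 0]
shifted = [0, 0, 0, 0, 1, 2, 1, 0, 0]  # same pattern, moved two steps right

response = correlate_1d(signal, kernel)
shifted_response = correlate_1d(shifted, kernel)

# Equivariant: the peak of the response moves along with the input.
peak = response.index(max(response))
shifted_peak = shifted_response.index(max(shifted_response))
print(peak, shifted_peak)  # 2 4

# Invariant: a global max over the response discards where the pattern was.
print(max(response), max(shifted_response))  # 6 6
```

&lt;p&gt;The pooled value tells us the pattern is present, but only the un-pooled
response still knows where it is.&lt;/p&gt;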

&lt;h3 id=&quot;a-note-on-sub-sampling&quot;&gt;A note on sub-sampling&lt;/h3&gt;

&lt;p&gt;Before we proceed, I feel it necessary to point out that Hinton identifies these
problems with sub-sampling in convnets, &lt;strong&gt;not the convolution operation
itself&lt;/strong&gt;. The sub-sampling in typical convnets (usually max-pooling) is largely
to blame for these deficiencies in convnets. The convolution operation itself is
quite useful, and is even utilized in the capsule networks presented.&lt;/p&gt;

&lt;h2 id=&quot;capsules-to-the-rescue&quot;&gt;Capsules to the Rescue&lt;/h2&gt;

&lt;p&gt;These kinds of problems (lack of precise spatial knowledge and invariance to
transformations) are exactly what capsules try to solve. Most simply, &lt;strong&gt;a
capsule is just a group of neurons&lt;/strong&gt;. A typical neural network layer has some
number of neurons, each of which is a floating point number. A &lt;em&gt;capsule layer&lt;/em&gt;,
on the other hand, is a layer that has some number of capsules, each of which is
a grouping of floating point neurons. In this work, a capsule is a single
vector, though &lt;a href=&quot;https://openreview.net/pdf?id=HJWLfGWRb&quot;&gt;later work&lt;/a&gt; utilizes matrices for their
capsules.&lt;/p&gt;

&lt;p&gt;The key idea is that by grouping neurons into capsules, we can encode more
information about the entity (i.e. feature or object) that we’re detecting. This
extra information could be size, shape, position, or a host of other things. The
framework of capsules leaves this open, and it’s up to the implementation to
define and enforce these encoding principles.&lt;/p&gt;

&lt;p&gt;There are a few things we need to be careful about before we can get started using
capsules. First, since capsules contain extra information, we need to be a
little more nuanced about how we connect capsule layers and utilize them in our
network. Typical convnets only care about the existence of a feature/object, so
their layers can be fully connected without problem, but we don’t get that
luxury when we start encoding extra properties with capsules. We also have to be
smart about how we connect capsules so that we can appropriately manage the
dimensionality without having to resort to pooling.&lt;/p&gt;

&lt;h2 id=&quot;connecting-capsule-layers-dynamic-routing-by-agreement&quot;&gt;Connecting Capsule Layers: Dynamic Routing by Agreement&lt;/h2&gt;

&lt;p&gt;The central algorithm presented in the &lt;a href=&quot;https://arxiv.org/abs/1710.09829&quot;&gt;capsules paper&lt;/a&gt; that came
out recently is one that describes how capsule layers can be connected to one
another. The authors chose an algorithm that encourages “routing by agreement”:
capsules in an earlier layer that cause a greater output in the subsequent layer
should be encouraged to send a greater portion of their output to that capsule
in the subsequent layer.&lt;/p&gt;

&lt;p&gt;The routing procedure happens for every forward pass through the network, both
during testing and training. The image below visually describes the effect of
the routing procedure.&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block;margin:0 auto;&quot; src=&quot;https://brandonlmorris.github.io/images/capsules/routing-visualized.jpeg&quot; alt=&quot;A visual explanation of the effect of the routing procedure&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;A visual
description of the before and after of the routing procedure. Arrow widths
correspond to the strength of the output. The lines coming out of the capsules
in the second layer correspond to the portions of the output that come from the
capsule in the previous layer.&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;Before the routing procedure, every capsule in the earlier layer spreads its
output &lt;em&gt;evenly&lt;/em&gt; to every capsule in the subsequent layer (initial couplings can
be learned like weights, but this isn’t done in the paper). During each
iteration of the dynamic routing algorithm, strong outputs from capsules in the
subsequent layer are used to encourage capsules in the previous layer to send a
greater portion of their output. Note how &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;caps_21&lt;/code&gt; has a large portion of its
output influenced by &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;caps_11&lt;/code&gt; (denoted by the thick arrow coming out on top).
After the routing procedure, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;caps_11&lt;/code&gt; sends much more of its output toward
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;caps_21&lt;/code&gt; than any of the other capsules in the second layer.&lt;/p&gt;

&lt;p&gt;The mathematical details of this procedure are excellently explained in the
&lt;a href=&quot;https://arxiv.org/abs/1710.09829&quot;&gt;original paper&lt;/a&gt;, but for brevity I will omit a complete
explanation.&lt;/p&gt;
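
&lt;p&gt;For a rough picture of the loop, here is a minimal plain-Python sketch of routing
by agreement as I read it from the paper: coupling coefficients come from a
softmax over routing logits, each upper capsule squashes its weighted sum of
predictions, and the logits grow by the dot product between a prediction and the
resulting output. The function names, the toy &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;predictions&lt;/code&gt;, and the list-based
layout are illustrative assumptions, not the authors’ code. In the real network
the predictions are produced by learned weight matrices.&lt;/p&gt;

```python
import math


def squash(vector):
    """Shrink a vector so its length lies in [0, 1) while keeping its direction."""
    squared_norm = sum(x * x for x in vector)
    if squared_norm == 0.0:
        return [0.0 for _ in vector]
    scale = squared_norm / (1.0 + squared_norm) / math.sqrt(squared_norm)
    return [scale * x for x in vector]


def softmax(logits):
    highest = max(logits)
    exps = [math.exp(x - highest) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def route(predictions, iterations=3):
    """predictions[i][j] is the vector lower capsule i predicts for upper capsule j."""
    num_lower = len(predictions)
    num_upper = len(predictions[0])
    dim = len(predictions[0][0])
    logits = [[0.0] * num_upper for _ in range(num_lower)]  # start with an even split
    for _ in range(iterations):
        # Each lower capsule divides its output among the upper capsules.
        couplings = [softmax(row) for row in logits]
        outputs = []
        for j in range(num_upper):
            weighted_sum = [
                sum(couplings[i][j] * predictions[i][j][d] for i in range(num_lower))
                for d in range(dim)
            ]
            outputs.append(squash(weighted_sum))
        # Agreement: predictions that point the same way as the output gain weight.
        for i in range(num_lower):
            for j in range(num_upper):
                logits[i][j] += sum(
                    p * v for p, v in zip(predictions[i][j], outputs[j])
                )
    return couplings, outputs


# Toy input: three lower capsules, two upper capsules, 2-D vectors.
# The lower capsules roughly agree about upper capsule 0 and disagree about 1.
predictions = [
    [[1.0, 0.0], [0.0, 1.0]],
    [[1.0, 0.0], [0.0, -1.0]],
    [[0.9, 0.1], [1.0, 0.0]],
]
couplings, outputs = route(predictions)
print(couplings[0])  # most of capsule 0's output now goes to upper capsule 0
```

&lt;p&gt;Because the first two lower capsules cancel each other out on upper capsule 1
but reinforce each other on upper capsule 0, routing shifts their coupling toward
capsule 0.&lt;/p&gt;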

&lt;h2 id=&quot;capsnet-a-shallow-network-with-deep-results&quot;&gt;CapsNet: A Shallow Network with Deep Results&lt;/h2&gt;

&lt;p&gt;Now let’s look at the actual capsule network utilized in the
&lt;a href=&quot;https://arxiv.org/abs/1710.09829&quot;&gt;paper&lt;/a&gt;, known as &lt;strong&gt;CapsNet&lt;/strong&gt;.&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block;margin:0 auto;&quot; src=&quot;https://brandonlmorris.github.io/images/capsules/capsnet-arch.jpeg&quot; alt=&quot;The CapsNet Architecture&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;The CapsNet
architecture&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;Our input images are the MNIST data set: 28x28 grayscale pictures of handwritten
digits (later we’ll see how capsules perform on other, more complex data sets).
These get fed into a convolutional layer (256 9x9 filters, stride of 1 and no
padding). That layer passes through another set of convolutions to become the
&lt;strong&gt;PrimaryCaps&lt;/strong&gt; layer, the first capsule layer. Each capsule inside PrimaryCaps
is a size 8 vector. There are 32x6x6 = 1152 capsules in the PrimaryCaps
layer. In the paper, they construct their capsule architecture such that the
&lt;em&gt;length&lt;/em&gt; of a capsule (i.e. the value after putting the vector through the
Euclidean norm) represents the likelihood an entity exists, and the
&lt;em&gt;orientation&lt;/em&gt; (i.e. how the values are distributed within a capsule) represents
all other spatial properties of the entity.&lt;/p&gt;
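
&lt;p&gt;The 6x6 grid follows from the usual output-size arithmetic for unpadded
convolutions. The sketch below assumes the second set of convolutions uses 9x9
kernels with a stride of 2, which is what the paper describes but is not stated
above; &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;conv_output_size&lt;/code&gt; is a hypothetical helper name.&lt;/p&gt;

```python
def conv_output_size(input_size, kernel_size, stride):
    """Spatial size after a convolution with no padding."""
    return (input_size - kernel_size) // stride + 1


# First convolution: 28x28 input, 9x9 filters, stride 1.
conv1_size = conv_output_size(28, kernel_size=9, stride=1)
# PrimaryCaps convolution (assumed from the paper): 9x9 filters, stride 2.
primary_size = conv_output_size(conv1_size, kernel_size=9, stride=2)
# 32 capsule channels at every position of the resulting grid.
num_capsules = 32 * primary_size * primary_size

print(conv1_size, primary_size, num_capsules)  # 20 6 1152
```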

&lt;p&gt;The PrimaryCaps layer is connected to the DigitCaps layer. &lt;em&gt;This is the only
place in the CapsNet architecture where routing takes place&lt;/em&gt;. DigitCaps is a
layer of 10 capsules (one corresponding to each potential digit), each a size
16 vector. Since the length of the vector corresponds to the likelihood that its digit is present, all we
need to do is take the norm of the capsules in DigitCaps to get our logits, which can be fed
into a softmax layer to get prediction probabilities.&lt;/p&gt;

&lt;p&gt;Since each digit could appear independently of the others (and some tests contain multiple digits),
the authors used a custom loss function that penalized each digit independently
during training. Additionally, they wanted to ensure that each DigitCaps was
learning a good representation within the capsule. To do this they appended a
fully connected network to function as a decoder and attempt to reconstruct the
original input from the DigitCaps layer.&lt;/p&gt;
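
&lt;p&gt;The per-digit loss in the paper is a margin loss on the capsule lengths. Here is
a minimal plain-Python sketch of it; the constants (0.9, 0.1, and a 0.5
down-weighting of absent digits) are the values I recall from the paper, while
the function name, argument names, and example lengths are my own.&lt;/p&gt;

```python
def margin_loss(lengths, target_digits, m_plus=0.9, m_minus=0.1, down_weight=0.5):
    """Sum of independent per-digit penalties on DigitCaps vector lengths."""
    total = 0.0
    for digit, length in enumerate(lengths):
        if digit in target_digits:
            # A digit that is present should have a long capsule vector.
            total += max(0.0, m_plus - length) ** 2
        else:
            # A digit that is absent should have a short one.
            total += down_weight * max(0.0, length - m_minus) ** 2
    return total


# Hypothetical capsule lengths for an image of a 3.
confident = [0.05] * 10
confident[3] = 0.95

mistaken = [0.05] * 10
mistaken[3] = 0.2
mistaken[7] = 0.8

confident_loss = margin_loss(confident, target_digits={3})
mistaken_loss = margin_loss(mistaken, target_digits={3})
print(round(confident_loss, 3))  # 0.0
print(round(mistaken_loss, 3))   # 0.735
```

&lt;p&gt;Because every digit contributes its own term, passing two targets (for example
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;target_digits={3, 7}&lt;/code&gt;) works without any change, which is what makes this
kind of loss suitable for images containing more than one digit.&lt;/p&gt;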

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block;margin:0 auto;&quot; src=&quot;https://brandonlmorris.github.io/images/capsules/capsnet-reconstruction.jpeg&quot; alt=&quot;The CapsNet Architecture&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;Reconstructing
the original input from the DigitCaps layer&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;Only the capsule value from the correct label was used to reconstruct (the
others were masked out). The reconstruction error was used to regularize CapsNet
during training, forcing the DigitCaps layer to learn more about how the digit
appeared in the image.&lt;/p&gt;

&lt;h2 id=&quot;results&quot;&gt;Results&lt;/h2&gt;

&lt;p&gt;With only three layers, the CapsNet architecture performed remarkably well. The
authors report a &lt;strong&gt;0.25%&lt;/strong&gt; test error rate on MNIST, which is close to state of
the art and not possible with a similarly shallow convnet.&lt;/p&gt;

&lt;p&gt;They also performed some experiments on a MultiMNIST data set: two images from
MNIST overlapping each other by up to 80%. Since CapsNet understands more about
size, shape, and position, it should be able to use that knowledge to untangle
the overlapping digits.&lt;/p&gt;

&lt;figure class=&quot;image&quot;&gt;
&lt;img style=&quot;display:block;margin:0 auto;&quot; src=&quot;https://brandonlmorris.github.io/images/capsules/multi-mnist.jpeg&quot; alt=&quot;The CapsNet Architecture&quot; /&gt;
&lt;figcaption style=&quot;display:block;margin:0 auto;text-align:center&quot;&gt;Reconstructing
from severely overlapping digits&lt;/figcaption&gt;
&lt;/figure&gt;

&lt;p&gt;The image above represents CapsNet reconstructing its predictions twice, one for
each digit in the image. The red corresponds to one digit, and green for another
(yellow is where they overlap). CapsNet is extremely good at this kind of
segmentation task, and the authors suggest this is in part because the routing
mechanism serves as a form of attention.&lt;/p&gt;

&lt;p&gt;CapsNet is also performant on several other data sets. On CIFAR-10, it has a
10.6% error rate (with an ensemble and some minor architecture modifications),
which is roughly the same as when convnets were first used on the data set.
CapsNet attains 2.7% error on the smallNORB data set, and 4.3% error on a subset
of Street View House Numbers (SVHN).&lt;/p&gt;

&lt;h2 id=&quot;conclusion&quot;&gt;Conclusion&lt;/h2&gt;

&lt;p&gt;Convnets are extremely performant architectures for computer vision tasks. Their
resurgence has marked the recent AI renaissance currently unfolding with the
advent of deep learning. However, they suffer from some serious flaws that make
them unlikely to take us all the way to general intelligence.&lt;/p&gt;

&lt;p&gt;Capsules are a novel enhancement that goes beyond typical convnets by encoding
extra information about detected objects and retaining precise spatial
relationships by avoiding sub-sampling. The simple capsule architecture
presented in this paper, CapsNet, is able to get incredible results considering
its small size. Additionally, CapsNet understands more about the images it’s
classifying, like their position and size.&lt;/p&gt;

&lt;p&gt;Although CapsNet doesn’t necessarily outperform convnets, it is able to match
their accuracy out-of-the-box. This is really promising for the future role of
capsules in computer vision. There’s still a huge amount of research to be done
into improving capsules and scaling them to larger data sets. As a computer
vision researcher, this is an extremely exciting time to be working in the
field!&lt;/p&gt;</content>

      
      
      
      
      

      
        <author>
            <name>Brandon L. Morris</name>
          
          
        </author>
      

      
        <category term="ai" />
      

      

      
        <summary type="html">Convolutional neural networks have dominated the computer vision landscape ever since AlexNet won the ImageNet challenge in 2012, and for good reason. Convolutions create a spatial dependency inside the network that functions as an effective prior for image classification and segmentation. Weight sharing reduces the number of parameters, and efficient accelerated implementations are readily available. These days, convolutional neural networks (or “convnets” for short) are the de facto architecture for almost any computer vision task.</summary>
      

      
      
    </entry>
  
  
  
    <entry>
      
      <title type="html">Simplifying Deep Learning Programming with Keras</title>
      
      <link href="https://brandonlmorris.github.io/2017/11/01/keras-tutorial/" rel="alternate" type="text/html" title="Simplifying Deep Learning Programming with Keras" />
      <published>2017-11-01T00:00:00+00:00</published>
      <updated>2017-11-01T00:00:00+00:00</updated>
      <id>https://brandonlmorris.github.io/2017/11/01/keras-tutorial</id>
      <content type="html" xml:base="https://brandonlmorris.github.io/2017/11/01/keras-tutorial/">&lt;p&gt;&lt;a href=&quot;https://brandonlmorris.github.io/2017/09/12/intro-to-tensorflow&quot;&gt;Last post,&lt;/a&gt; I gave an introduction
to programming a deep neural network with TensorFlow. The model worked quite
well (98% accuracy on the test set) with only 150 lines of code, but it was
arguably a bit complex.&lt;/p&gt;

&lt;p&gt;The problem was we had to really dig into the nitty-gritty details of how we
wanted our model to work. But a lot of times, we do not need to deal with that
level of detail and the complexity that comes with it. This kind of problem
occurs often in software engineering, and it is generally solved with a
convenient library.&lt;/p&gt;

&lt;p&gt;&lt;a href=&quot;https://keras.io&quot;&gt;Keras&lt;/a&gt; is such a library: it does a great job of taking the
complexity out of building a neural network, so you can focus on the interesting
parts of training and utilizing the model. In this post I’ll walk through some
of the basics of Keras and we will rebuild our MNIST handwritten-digit
classifier in a much simpler program.&lt;/p&gt;

&lt;h2 id=&quot;keras-the-model-abstraction&quot;&gt;Keras: The Model Abstraction&lt;/h2&gt;

&lt;p&gt;In Keras, the fundamental abstraction is the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Model&lt;/code&gt; object. We can design,
train, and evaluate the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Model&lt;/code&gt; without necessarily knowing the exact details.
In this example, TensorFlow will be the backend that Keras will utilize behind
the scenes, but Keras can actually function agnostic of its specific backend and
run with &lt;a href=&quot;https://www.tensorflow.org&quot;&gt;TensorFlow&lt;/a&gt;, &lt;a href=&quot;http://www.deeplearning.net/software/theano/&quot;&gt;Theano&lt;/a&gt;, or &lt;a href=&quot;https://www.microsoft.com/en-us/cognitive-toolkit/&quot;&gt;CNTK&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;There are a few different model types, but the one we will utilize is the
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Sequential&lt;/code&gt; model. The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Sequential&lt;/code&gt; model views the network architecture as a
sequence of layers strung together, one after another. This is exactly the
architecture we used in our previous convolutional neural network. With Keras,
we can stack our network layers like individual building blocks to create our
overall model.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;keras.models&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Sequential&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Sequential&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;h2 id=&quot;adding-layers&quot;&gt;Adding Layers&lt;/h2&gt;

&lt;p&gt;To add new layers to a Keras model, we simply call the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;add()&lt;/code&gt; function and pass
in the layer we want to use. To recreate our previous convnet, we’ll need two main
kinds of layers: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Dense&lt;/code&gt; for our fully connected layers, and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Conv2D&lt;/code&gt; for our
two dimensional convolutional layers. We’ll also need &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;MaxPooling2D&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Dropout&lt;/code&gt;
layers to utilize max pooling and dropout. Finally, a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Flatten&lt;/code&gt; layer will be
used to convert between our convolutional and fully connected layers.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;keras.layers&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Dense&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Conv2D&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Dropout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Flatten&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;MaxPooling2D&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2D&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;padding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'same'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;input_shape&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;28&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;28&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MaxPooling2D&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;padding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'same'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Conv2D&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;64&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;padding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'same'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MaxPooling2D&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;padding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'same'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Flatten&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Dense&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;activation&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'relu'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Dropout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Dense&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;256&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;activation&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'relu'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Dropout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;add&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Dense&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;activation&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'softmax'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;There are a few things we should note here. The first layer we &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;add()&lt;/code&gt; needs to
take an additional argument: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;input_shape&lt;/code&gt;. This tells Keras the size of the
inputs that we will feed into our model (in our case, a 28x28 pixel image for
MNIST). For the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Conv2D&lt;/code&gt; layers, the first argument represents the number of
filters, followed by the dimensions of our convolution. The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Dense&lt;/code&gt; layers take
an argument that represents the number of neurons in that layer. We can also
specify the activation function we want to use by a keyword argument, as we did
here. Alternatively, we could have added an &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Activation&lt;/code&gt; layer.&lt;/p&gt;
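
&lt;p&gt;It can help to trace how the tensor shape changes through the stack. The sketch
below is plain Python, not a Keras call: &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;same_padding_size&lt;/code&gt; is a hypothetical
helper, and it assumes the Keras defaults of a stride of 1 for the convolutions
and a 2x2 window with a stride of 2 for the pooling layers.&lt;/p&gt;

```python
import math


def same_padding_size(input_size, stride):
    """Spatial size of a layer's output when padding='same'."""
    return math.ceil(input_size / stride)


size = 28                                   # MNIST images are 28x28x1
size = same_padding_size(size, stride=1)    # Conv2D, 32 filters: 28x28x32
size = same_padding_size(size, stride=2)    # MaxPooling2D:       14x14x32
size = same_padding_size(size, stride=1)    # Conv2D, 64 filters: 14x14x64
size = same_padding_size(size, stride=2)    # MaxPooling2D:       7x7x64

flattened_size = size * size * 64           # what Flatten hands to Dense
dense_params = flattened_size * 1024 + 1024  # weights plus biases of Dense(1024)

print(flattened_size)  # 3136
print(dense_params)    # 3212288
```

&lt;p&gt;Most of the parameters end up in that first fully connected layer, which is
typical for small convnets of this shape.&lt;/p&gt;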

&lt;h2 id=&quot;compiling-and-training-the-model&quot;&gt;Compiling and Training the Model&lt;/h2&gt;

&lt;p&gt;Now that we have defined what our convnet will look like by stacking all of our
layers into our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;model&lt;/code&gt;, we can get ready to start training our model on the
data set. However, first we need to &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;compile&lt;/code&gt; the model. Since Keras serves as a
high-level wrapper of other machine learning libraries, it needs to convert our
Keras-defined model into a model of our backend. Additionally, we will need to
specify some other attributes of our training procedure.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;keras.optimizers&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Adadelta&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;compile&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;optimizer&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Adadelta&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(),&lt;/span&gt;
              &lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'categorical_crossentropy'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
              &lt;span class=&quot;n&quot;&gt;metrics&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'accuracy'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;Here we specify that training will use the &lt;a href=&quot;https://arxiv.org/abs/1212.5701&quot;&gt;Adadelta optimizer&lt;/a&gt;, our
loss function is defined by the cross entropy of the output (since this is a
classification task), and we want to track the accuracy of the model during training.&lt;/p&gt;
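
&lt;p&gt;For intuition about that loss, here is a simplified plain-Python version of
categorical cross entropy for a single example. It is only a sketch: the function
and the example probabilities are made up for illustration, and Keras’s own
implementation works on batches and guards against taking the log of zero.&lt;/p&gt;

```python
import math


def categorical_crossentropy(one_hot_label, predicted_probabilities):
    """Negative log of the probability assigned to the true class."""
    return -sum(
        target * math.log(probability)
        for target, probability in zip(one_hot_label, predicted_probabilities)
        if target
    )


label = [0, 0, 0, 1, 0, 0, 0, 0, 0, 0]   # one-hot encoding of the digit 3

confident = [0.01] * 10
confident[3] = 0.91                       # most of the mass on the right digit
unsure = [0.1] * 10                       # a uniform guess over all ten digits

confident_loss = categorical_crossentropy(label, confident)
unsure_loss = categorical_crossentropy(label, unsure)
print(round(confident_loss, 4))  # 0.0943
print(round(unsure_loss, 4))     # 2.3026
```

&lt;p&gt;The loss falls toward zero as the model puts more probability on the correct
class, which is the quantity the optimizer actually minimizes.&lt;/p&gt;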

&lt;p&gt;Next, we can get our training data ready. Luckily, Keras even has some common
data sets built in.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;keras.datasets&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mnist&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;keras.utils&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;to_categorical&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mnist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;load_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;c1&quot;&gt;# Convert to values between 0. and 1.
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_train&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;28&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;28&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'float32'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;255.&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;x_test&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;28&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;28&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;astype&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'float32'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;mf&quot;&gt;255.&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y_train&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;to_categorical&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_classes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y_test&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;to_categorical&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;num_classes&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;For the inputs, we need to convert the arrays into the right shape and scale
the values into the range \([0, 1]\). The outputs get converted to binary one-hot
vectors by Keras’s &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;to_categorical&lt;/code&gt; utility function.&lt;/p&gt;
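
&lt;p&gt;As a rough illustration of those two steps, the sketch below scales a few pixel values and one-hot encodes a label using plain Python lists. The helper names are hypothetical and only mimic the behaviour described above; the real code works on NumPy arrays.&lt;/p&gt;

```python
def scale_pixels(pixels):
    """Map 8-bit pixel intensities (0..255) into the range [0, 1]."""
    return [p / 255.0 for p in pixels]

def one_hot(label, num_classes):
    """Return a list with 1.0 at index `label` and 0.0 everywhere else."""
    return [1.0 if i == label else 0.0 for i in range(num_classes)]

scaled = scale_pixels([0, 51, 255])
encoded = one_hot(3, 10)
print(scaled)
print(encoded)
```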

&lt;p&gt;Now we can finally train our model. This is done by the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Model&lt;/code&gt;’s &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;fit&lt;/code&gt; method:&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fit&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;epochs&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;batch_size&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;50&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;epochs&lt;/code&gt; argument will determine how many passes through the data training
will make. The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;batch_size&lt;/code&gt; determines how many samples to train with for each
weight update. Keras will output its progress as it works, updating you on which
epoch is running, approximately how long it will take, and the current loss in
the model.&lt;/p&gt;
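
&lt;p&gt;To get a feel for what these two arguments imply, here is a small back-of-the-envelope calculation. It assumes the standard MNIST training set of 60,000 images; the variable names are just for illustration.&lt;/p&gt;

```python
import math

num_samples = 60000   # size of the MNIST training set
epochs = 10
batch_size = 50

# One weight update per batch; a final partial batch still counts as one.
updates_per_epoch = math.ceil(num_samples / batch_size)
total_updates = updates_per_epoch * epochs
print(updates_per_epoch, total_updates)  # 1200 12000
```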

&lt;h2 id=&quot;evaluating-the-results&quot;&gt;Evaluating the Results&lt;/h2&gt;

&lt;p&gt;Once our model is trained, we can see how accurate it is at predicting on novel
data. To see how our model stacks up against the test set, use the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;evaluate&lt;/code&gt;
method:&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;n&quot;&gt;loss&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;accuracy&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;model&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;evaluate&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'Accuracy of {:.2f}%'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;format&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;accuracy&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;So in just a few lines of Python, we were able to create a high-performing MNIST
classifier! Using Keras is really straightforward, and allows us to avoid the
nitty-gritty details of programming complex deep neural networks. Instead, we
can work on other interesting aspects of our models and keep the implementation
from hindering our ideas. And when Keras is too high level, we can even use it
as a &lt;a href=&quot;https://blog.keras.io/keras-as-a-simplified-interface-to-tensorflow-tutorial.html&quot;&gt;simplified interface to TensorFlow&lt;/a&gt;. For a deep
learning researcher, Keras takes a lot of the hassle out of programming deep
neural networks.&lt;/p&gt;</content>

      
      
      
      
      

      
        <author>
            <name>Brandon L. Morris</name>
          
          
        </author>
      

      
        <category term="ai" />
      
        <category term="tutorial" />
      

      

      
        <summary type="html">Last post, I gave an introduction into programming a deep neural network with TensorFlow. The model worked quite well (98% accuracy on the test set) with only 150 lines of code, but it was arguably a bit complex.</summary>
      

      
      
    </entry>
  
  
  
    <entry>
      
      <title type="html">Getting Started with Tensorflow</title>
      
      <link href="https://brandonlmorris.github.io/2017/09/12/intro-to-tensorflow/" rel="alternate" type="text/html" title="Getting Started with Tensorflow" />
      <published>2017-09-12T00:00:00+00:00</published>
      <updated>2017-09-12T00:00:00+00:00</updated>
      <id>https://brandonlmorris.github.io/2017/09/12/intro-to-tensorflow</id>
      <content type="html" xml:base="https://brandonlmorris.github.io/2017/09/12/intro-to-tensorflow/">&lt;blockquote&gt;
  &lt;p&gt;Note: All the code from this post can be found
&lt;a href=&quot;https://gist.github.com/BrandonLMorris/29752cf710603fc34f22953ff491f8b5&quot;&gt;here&lt;/a&gt;.
This tutorial is adapted from Google’s own TensorFlow tutorial, &lt;a href=&quot;https://www.tensorflow.org/get_started/mnist/pros&quot;&gt;“Deep MNIST
for ML Experts”&lt;/a&gt;.&lt;/p&gt;
&lt;/blockquote&gt;

&lt;p&gt;My &lt;a href=&quot;https://brandonlmorris.github.io/2017/09/09/what-is-deep-learning&quot;&gt;last post&lt;/a&gt; talked about deep learning very generally, describing the
fundamentals of how deep neural networks work and are used. In this post, we’ll
look more concretely at actually building a convolutional neural network to
classify handwritten digits from the MNIST data set. Using Google’s popular
machine learning library TensorFlow, we’ll have a model that gets over 98%
accuracy with about 150 lines of code and 10 minutes of training time on my
laptop.&lt;/p&gt;

&lt;h1 id=&quot;prerequisites&quot;&gt;Prerequisites&lt;/h1&gt;

&lt;p&gt;This tutorial assumes that the reader has a basic familiarity with programming in
Python. If you’ve never seen the syntax before, it is pretty easy to pick up. I
also assume that you have &lt;a href=&quot;https://www.tensorflow.org/install/&quot;&gt;TensorFlow
installed&lt;/a&gt; on your machine. For the sake of
brevity, many of the fundamentals of deep learning are omitted, but you can
learn about those in my &lt;a href=&quot;https://brandonlmorris.github.io/2017/09/09/what-is-deep-learning&quot;&gt;previous blog
post&lt;/a&gt;.&lt;/p&gt;

&lt;h1 id=&quot;what-is-tensorflow&quot;&gt;What is TensorFlow?&lt;/h1&gt;

&lt;p&gt;TensorFlow is a high performance numerical computing library developed by
Google. It generally supports any kind of scientific computation, but was
developed specifically with machine learning in mind. TensorFlow has lots of
packages that make building and running deep and even distributed neural
networks much simpler than before. It was open sourced by Google in 2015.&lt;/p&gt;

&lt;p&gt;Before TensorFlow’s arrival, many of the machine learning libraries in use were
developed by research labs to support their own needs. While these libraries were
great, they were often built without strong software engineering practices and
failed to meet enterprise-scale needs. TensorFlow, by contrast, was developed from
the ground up by experts in both of these domains. It is the library that
Google, a machine learning leader, uses in many of its own products.&lt;/p&gt;

&lt;p&gt;Another great aspect of TensorFlow is that the models are portable. A TensorFlow
program trained to run on a rack of servers can be deployed to execute on a
smartphone. As we will discuss in a moment, this is because TensorFlow builds
computational graphs that can be stored independently of the program that
developed them. The parameters can be trained, and then the model shipped off to
run in production. TensorFlow can also utilize specialized hardware like
graphics processors without any explicit programming by the end user.&lt;/p&gt;

&lt;h1 id=&quot;programming-in-tensorflow&quot;&gt;Programming in TensorFlow&lt;/h1&gt;

&lt;p&gt;TensorFlow was originally developed in C++, but language bindings for Python are
the most common way people program their machine learning applications.
TensorFlow operates by building a &lt;em&gt;computational graph&lt;/em&gt; that can be executed.
This differs slightly from typical imperative programming, where each statement
is explicitly executed line by line. Instead, TensorFlow has us describe &lt;em&gt;how&lt;/em&gt;
to make certain calculations, and then, when we want, we can evaluate them
in a session, feeding in any required inputs. Let’s look at an example:&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;tensorflow&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;sess&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Session&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;placeholder&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;z&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;sess&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;run&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;feed_dict&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;42&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;})&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# &amp;gt;&amp;gt;&amp;gt; 1864.0&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;First, we defined &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt; as a &lt;em&gt;placeholder&lt;/em&gt;: this is essentially an input to our
computational graph. Then we told TensorFlow how to calculate &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;y&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;z&lt;/code&gt;.
Finally, we told TensorFlow to actually calculate the value of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;z&lt;/code&gt;, populating
the value for our placeholder &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;x&lt;/code&gt;.&lt;/p&gt;
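
&lt;p&gt;If the split between describing a computation and running it feels abstract, the toy sketch below mimics the idea in plain Python. The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Placeholder&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Op&lt;/code&gt; classes are invented for this illustration and are only an analogy; TensorFlow’s real graph machinery is far more involved.&lt;/p&gt;

```python
class Placeholder:
    """An input node whose value is supplied only when the graph is run."""
    def evaluate(self, feed):
        return feed[self]

class Op:
    """A node that combines the values of its inputs with a function."""
    def __init__(self, fn, *inputs):
        self.fn = fn
        self.inputs = inputs

    def evaluate(self, feed):
        # Evaluate child nodes first; plain numbers are used as they are.
        values = [
            node.evaluate(feed) if hasattr(node, "evaluate") else node
            for node in self.inputs
        ]
        return self.fn(*values)

# Building the graph performs no arithmetic yet.
x = Placeholder()
y = Op(lambda a, b: a * b, x, x)
z = Op(lambda a, b: a + b, y, 100)

# Only now is anything computed, with 42 fed in for x.
result = z.evaluate({x: 42.0})
print(result)  # 1864.0
```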

&lt;p&gt;This was a pretty simple example, but it illustrates the basic mechanics of
programming in TensorFlow. However, let’s add some complexity. TensorFlow
derives its name from the tensor: a mathematical generalization of a matrix.
Tensors are kind of like arbitrarily high-dimensional arrays, and TensorFlow
excels at working with these constructs. For instance, if we have 100 images,
each 28 by 28 pixels in size, with three values for the red, green, and blue
channels, we can represent all that data as a single 100x28x28x3 tensor. Let’s try an example
of programming with tensors through matrix-vector multiplication.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;c1&quot;&gt;# Matrix-vector multiplication
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;placeholder&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                   &lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;b&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;placeholder&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                    &lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;z&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;matmul&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;b&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;n&quot;&gt;sess&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;run&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;feed_dict&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;A&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:[[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;6&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;7&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;8&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;9&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]],&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;b&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:[[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]]&lt;/span&gt;
&lt;span class=&quot;p&quot;&gt;})&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# &amp;gt;&amp;gt;&amp;gt; [[14.],[32.],[50.]]&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;TensorFlow has lots of optimized implementations of common operations like
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;matmul&lt;/code&gt;. These are really helpful when building deep neural networks that are
blazingly fast.&lt;/p&gt;
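
&lt;p&gt;For comparison, here is what the same matrix-vector product looks like written out by hand on nested Python lists. This naive version is only a sketch to show the arithmetic; the function is hypothetical and nowhere near as fast as an optimized implementation.&lt;/p&gt;

```python
def matmul(a, b):
    """Multiply an (n x k) matrix by a (k x m) matrix, both as nested lists."""
    inner = len(b)
    if any(len(row) != inner for row in a):
        raise ValueError("inner dimensions do not match")
    cols = len(b[0])
    # Each output entry is the dot product of a row of a and a column of b.
    return [
        [sum(row[k] * b[k][j] for k in range(inner)) for j in range(cols)]
        for row in a
    ]

A = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
b = [[1], [2], [3]]
product = matmul(A, b)
print(product)  # [[14], [32], [50]]
```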

&lt;p&gt;Before we start building a deep neural network, we need to introduce the idea of
TensorFlow &lt;strong&gt;variables&lt;/strong&gt;. Variables are dynamic values that are global to a
TensorFlow session. Generally these are used as the parameters in models that
are tuned during training. Although we won’t directly manipulate them, the
optimization procedure that we utilize will. Before we can start using our
variables in a session, though, we will need to execute
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;sess.run(tf.global_variables_initializer())&lt;/code&gt; in our program.&lt;/p&gt;

&lt;h1 id=&quot;building-an-mnist-classifier&quot;&gt;Building an MNIST classifier&lt;/h1&gt;

&lt;p&gt;Now let’s get started building an MNIST image classifier. MNIST is a common
data set of 28x28 pixel images of handwritten digits. The digits come from
handwriting samples collected by NIST (written by Census Bureau employees and
high school students), then size-normalized and centered in the image.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://brandonlmorris.github.io/images/tf/mnist.png&quot; style=&quot;display:block;margin:0 auto;width:40%;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;The goal of our model is to input these images as arrays of pixels and learn how
to derive which digit is displayed in the image. This is called classification,
since each image has to fall into one of 10 categories (the digits 0 to 9).&lt;/p&gt;

&lt;p&gt;To accomplish this feat, we’re going to utilize a deep convolutional neural
network. We’ll have a total of 4 layers: the first two convolutional, the last
two fully connected. To prevent overfitting, we’ll utilize dropout. And finally,
our outputs will be converted into a probability distribution via the softmax
function. It’s not necessary that you fully understand all the details; these
are just common practices within deep learning.&lt;/p&gt;
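
&lt;p&gt;For the curious, the softmax step can be sketched in a few lines of plain Python. The function and the example scores below are illustrative assumptions, not the code used in the model.&lt;/p&gt;

```python
import math

def softmax(logits):
    """Turn a list of raw scores into probabilities that sum to 1."""
    # Subtracting the maximum leaves the result unchanged but avoids overflow.
    shift = max(logits)
    exps = [math.exp(v - shift) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax([2.0, 1.0, 0.1])
print(probs)
print(sum(probs))
```

&lt;p&gt;The largest raw score always ends up with the largest probability, so the predicted digit is simply the index of the biggest output.&lt;/p&gt;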

&lt;p&gt;To get started, we’re going to build some helper methods that will make
constructing our model a little less tedious. The code presented in this
tutorial will be somewhat out of order, but it should run fine when combined
(remember, all of the code in this tutorial can be found &lt;a href=&quot;https://gist.github.com/BrandonLMorris/29752cf710603fc34f22953ff491f8b5&quot;&gt;here&lt;/a&gt;, with some
additional features).&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;c1&quot;&gt;# Create some random variable weights
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;truncated_normal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;stddev&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Create a “constant” bias variable
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Variable&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;constant&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
      &lt;span class=&quot;mf&quot;&gt;0.1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;shape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Hardcode our convolution parameters
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;strides&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;padding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'SAME'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Same with our maxpooling
&lt;/span&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;max_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;max_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;ksize&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;strides&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;padding&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'SAME'&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;The first two helpers define the initialization for the weights and biases that
we’ll use in our model. The last two hardcode the strides, window size, and
padding for our convolution and max-pooling operations.&lt;/p&gt;
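
&lt;p&gt;One consequence of these hardcoded settings is how the image size changes as it flows through the network. With &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;'SAME'&lt;/code&gt; padding, the spatial output size is the input size divided by the stride, rounded up. The sketch below traces a 28x28 image under the assumption that each of the two convolutional layers is followed by one max-pooling step; the helper name is made up for illustration.&lt;/p&gt;

```python
import math

def same_padding_output(size, stride):
    """Spatial output size of a 'SAME'-padded op: ceil(size / stride)."""
    return math.ceil(size / stride)

size = 28                                          # MNIST images are 28x28
after_conv1 = same_padding_output(size, 1)         # convolution, stride 1
after_pool1 = same_padding_output(after_conv1, 2)  # max pooling, stride 2
after_conv2 = same_padding_output(after_pool1, 1)  # assumed second conv
after_pool2 = same_padding_output(after_conv2, 2)  # assumed second pooling
print(after_conv1, after_pool1, after_conv2, after_pool2)  # 28 14 14 7
```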

&lt;p&gt;Next, we’ll actually define our model. We will encapsulate it as a method so as
to keep our main method a bit cleaner. The model will take a tensor input as an
argument and return its output predictions (as well as the dropout probability
used, though that’s not significant).&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;cnn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;keep_prob&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;placeholder&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;imgs&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;28&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;28&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

  &lt;span class=&quot;c1&quot;&gt;# Layer 1
&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;W_conv1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;b_conv1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

  &lt;span class=&quot;n&quot;&gt;z_conv1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;imgs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_conv1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;b_conv1&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;h_conv1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z_conv1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;h_pool1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;max_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;h_conv1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

  &lt;span class=&quot;c1&quot;&gt;# Layer 2 (convolutional)
&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;W_conv2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;5&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;64&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;b_conv2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;64&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

  &lt;span class=&quot;n&quot;&gt;z_conv2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;conv2d&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;h_pool1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_conv2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;b_conv2&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;h_conv2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z_conv2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;h_pool2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;max_pool&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;h_conv2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

  &lt;span class=&quot;c1&quot;&gt;# Layer 3 (fully connected)
&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;W_fc1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;7&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;7&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;64&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;b_fc1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

  &lt;span class=&quot;n&quot;&gt;h_pool2_flat&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;h_pool2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;7&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;7&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;*&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;64&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

  &lt;span class=&quot;n&quot;&gt;z_fc1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;matmul&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;h_pool2_flat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_fc1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;b_fc1&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;h_fc1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z_fc1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;h_fc1_drop&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dropout&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;h_fc1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;keep_prob&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

  &lt;span class=&quot;c1&quot;&gt;# Layer 4 (fully connected, last hidden layer)
&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;W_fc2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;weights&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1024&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;b_fc2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;bias&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;

  &lt;span class=&quot;n&quot;&gt;z_fc2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;matmul&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;h_fc1_drop&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;W_fc2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;b_fc2&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;h_fc2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;relu&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;z_fc2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

  &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;h_fc2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;keep_prob&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;For each layer, we define our weights and biases (TensorFlow won’t reinitialize
these during execution), and perform our operation before moving on to the next
layer. For the convolutional layers, this involves applying the actual
convolution followed by max pooling to decrease the dimensionality. For the
fully connected layers, the operation is a matrix multiplication with the
weights, followed by a ReLU activation (and dropout in the second-to-last
layer). When switching from convolutional to fully connected layers, we need
to reshape our data.&lt;/p&gt;
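
&lt;p&gt;The 7*7*64 in the reshape comes from tracking the image size through the two convolution and pooling stages. Here is a minimal sketch of that arithmetic. It assumes the conv2d helper pads with SAME and stride 1 and that max_pool uses a 2x2 window with stride 2; the functions same_conv and pool_2x2 are hypothetical stand-ins for illustration, not part of TensorFlow.&lt;/p&gt;

```python
import math

# Assumption: conv2d pads with "SAME" and stride 1, so height and width
# are unchanged; max_pool uses a 2x2 window with stride 2.
def same_conv(size):
    # "SAME" padding with stride 1 leaves the spatial size unchanged
    return size

def pool_2x2(size):
    # 2x2 pooling with stride 2 halves each spatial dimension (rounded up)
    return math.ceil(size / 2)

size = 28                            # MNIST images are 28x28
size = pool_2x2(same_conv(size))     # after layer 1: 14
size = pool_2x2(same_conv(size))     # after layer 2: 7
channels = 64                        # output channels of the second conv layer
flat_features = size * size * channels
print(size, flat_features)
```

&lt;p&gt;If the helpers used different padding or strides, the flattened size would change and the first fully connected weight matrix would need to change with it.&lt;/p&gt;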

&lt;p&gt;Now that our model has been established, we can program our training procedure.
When training a deep neural network, we typically feed in some data, measure the
error, and adjust the parameters so as to decrease the error. Our input data and
labels will be placeholders, and we can measure the error using TensorFlow’s
built-in cross entropy operation on the softmax of our model outputs. Then we
specify that our optimization step (i.e. weight adjustment) should use the
Adam optimizer, an enhancement of stochastic gradient descent. While
we’re at it, we’ll also tell TensorFlow how to measure our accuracy.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;main&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;mnist&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;input_data&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;read_data_sets&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;FLAGS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;data_dir&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;one_hot&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
  
  &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;placeholder&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                     &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;784&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;y_true&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;placeholder&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                               &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;bp&quot;&gt;None&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;10&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;y_hat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;keep_prob&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cnn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

  &lt;span class=&quot;c1&quot;&gt;# Cross entropy measures the error in our predictions
&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;cross_entropy&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduce_mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;nn&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;softmax_cross_entropy_with_logits&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;logits&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_hat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;labels&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_true&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
  &lt;span class=&quot;c1&quot;&gt;# Once we have error, we can optimize
&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;training_step&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;AdamOptimizer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1e-4&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;minimize&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cross_entropy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

  &lt;span class=&quot;c1&quot;&gt;# Define a “correct prediction” to calc accuracy
&lt;/span&gt;  &lt;span class=&quot;n&quot;&gt;correct_prediction&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;equal&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;argmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_hat&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;argmax&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_true&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;accuracy&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reduce_mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cast&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;correct_prediction&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;float32&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
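
&lt;p&gt;To get a feel for what the softmax cross entropy measures, here is a small sketch in plain Python. It is not how TensorFlow implements the operation internally; the function names and the example logits are made up for illustration.&lt;/p&gt;

```python
import math

def softmax(logits):
    # Subtract the largest logit for numerical stability before exponentiating
    largest = max(logits)
    exps = [math.exp(v - largest) for v in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(logits, one_hot_label):
    # Negative log probability that the model assigns to the true class
    probs = softmax(logits)
    return -sum(t * math.log(p) for t, p in zip(one_hot_label, probs))

label = [1, 0, 0]                                    # the true class is class 0
confident_right = cross_entropy([5.0, 0.0, 0.0], label)
confident_wrong = cross_entropy([0.0, 5.0, 0.0], label)
print(confident_right, confident_wrong)
```

&lt;p&gt;The loss is small when the largest output lines up with the label and grows when the model is confidently wrong, which is what the optimizer pushes against.&lt;/p&gt;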

&lt;p&gt;Now we can write the actual loop that will run the training procedure.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;  &lt;span class=&quot;c1&quot;&gt;# inside main()
&lt;/span&gt;  &lt;span class=&quot;k&quot;&gt;with&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Session&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;sess&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;sess&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;run&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;global_variables_initializer&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;())&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;EPOCHS&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mnist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;next_batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;MINIBATCH_SIZE&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;training_step&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;run&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;feed_dict&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;y_true&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;batch&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;],&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;keep_prob&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;0.5&lt;/span&gt;
      &lt;span class=&quot;p&quot;&gt;})&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;test_acc&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;accuracy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;eval&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;feed_dict&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;{&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mnist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;images&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;y_true&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;mnist&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;labels&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
      &lt;span class=&quot;n&quot;&gt;keep_prob&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;&lt;span class=&quot;mf&quot;&gt;1.0&lt;/span&gt;
    &lt;span class=&quot;p&quot;&gt;})&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;print&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;s&quot;&gt;'Test accuracy is {:.2f}%'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;format&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;test_acc&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;Finally, we can put the finishing touches so that our program will run. The
imports belong at the top of the file; the rest goes outside of any function, at
the bottom of the file:&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;tensorflow.examples.tutorials.mnist&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;input_data&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;tensorflow&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;as&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;__name__&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'__main__'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;EPOCHS&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;5000&lt;/span&gt;
  &lt;span class=&quot;n&quot;&gt;MINIBATCH_SIZE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;50&lt;/span&gt;

  &lt;span class=&quot;n&quot;&gt;tf&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;app&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;run&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;main&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;=&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;main&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;On my 2013 MacBook Pro, I was able to run this program in about 5-10 minutes,
and achieved 98% accuracy. That’s pretty amazing! All of this code, including
some enhancements like model saving, can be found at &lt;a href=&quot;https://goo.gl/Pf4sDA&quot;&gt;this
link&lt;/a&gt;.&lt;/p&gt;</content>

      
      
      
      
      

      
        <author>
            <name>Brandon L. Morris</name>
          
          
        </author>
      

      
        <category term="ai" />
      
        <category term="tutorial" />
      

      

      
        <summary type="html">Note: All the code from this post can be found here. This tutorial is adapted from Google’s own TensorFlow tutorial, “Deep MNIST for ML Experts”</summary>
      

      
      
    </entry>
  
  
  
    <entry>
      
      <title type="html">What is deep learning?</title>
      
      <link href="https://brandonlmorris.github.io/2017/09/09/what-is-deep-learning/" rel="alternate" type="text/html" title="What is deep learning?" />
      <published>2017-09-09T00:00:00+00:00</published>
      <updated>2017-09-09T00:00:00+00:00</updated>
      <id>https://brandonlmorris.github.io/2017/09/09/what-is-deep-learning</id>
      <content type="html" xml:base="https://brandonlmorris.github.io/2017/09/09/what-is-deep-learning/">&lt;p&gt;Deep learning is a subfield of machine learning that has had remarkable research
success in the past decade. Huge numbers of research groups and top software
companies are pushing the boundaries of what was previously thought possible
through computation with these advancements.&lt;/p&gt;

&lt;p&gt;One of my favorite things about deep learning is that, despite the hype and
number of PhDs who work in the field, it is actually relatively simple to get
started. Most common laptops are powerful enough to train and run simple deep
models. This is in part due to the growing number of open source machine
learning libraries that make building deep models significantly simpler than
rolling them by hand. In this post we’ll walk through some of the fundamentals
of deep learning and the historical background. This foundation should provide
enough context to start digging into deep learning and building your very own
models.&lt;/p&gt;

&lt;h1 id=&quot;getting-started&quot;&gt;Getting started&lt;/h1&gt;

&lt;p&gt;Deep learning is one specific category within the general field of artificial
intelligence. At its core, it leverages learning models that are artificial
neural networks with lots of layers. These models are general purpose, and
similar architectures can be trained to run a variety of tasks. They “learn” by
adjusting parameters, of which deep neural nets can have billions. We can think
of these parameters as lots of dials on a big, complicated machine. By
discovering the right combination of dial settings, we can configure the machine
to properly accomplish a particular task.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://brandonlmorris.github.io/images/tf/ann.jpeg&quot; style=&quot;display:block;margin:0 auto;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Neural networks are layers of groups of neurons that feed into each other. We’ll
discuss more about the neurons in a moment, but suffice it to say that they
perform a relatively simple mathematical transformation. Every neuron in a layer
receives its input from every neuron in the layer before, and sends its output
to every neuron in the next layer. There are three different kinds of layers in
a typical network: input, hidden, and output. The input layer is the first layer
in our network, and it is just the input for the task we’re trying to solve
(e.g. each neuron corresponds to a pixel within an image). The output layer is
what our model ultimately produces for a given input. It may just be a
single number or, if our model is dealing with multiple categories, it may be
multiple outputs that we can combine in some meaningful way. Hidden layers are
where all of the interesting things happen. They contain all the parameters in
our model that are used during the mathematical transformation. By tuning the
parameters in the hidden layer neurons, we train our network. Deep neural
networks get their name from generally having lots of hidden layers.&lt;/p&gt;

&lt;h1 id=&quot;a-quick-history-of-deep-learning&quot;&gt;A quick history of deep learning&lt;/h1&gt;

&lt;p&gt;Artificial neural networks were originally theorized in the late 1950’s and
early 1960’s, and derive their name from the loose inspiration of how actual
neurons within our brain work. Although the concept is almost as old as
artificial intelligence itself, neural networks did not receive much attention
until the 1980’s. By then, researchers had discovered how to increase the size of
these networks without dramatically increasing the training complexity.
However, they again fell out of vogue due to their inability (at the time) to
outperform other machine learning algorithms.&lt;/p&gt;

&lt;p&gt;All of this changed in 2006, when deep learning was born. Huge advancements in
computer hardware and some algorithmic improvements permitted researchers to
build neural networks with huge numbers of layers (i.e. “deep”) and train them
to beat other kinds of machine learning models. This marked the start of the era
of deep learning, where these models have been adapted and modified to perform
amazingly well in a wide variety of complex tasks. Many times, they can even
outperform humans.&lt;/p&gt;

&lt;h1 id=&quot;going-deeper-whats-in-a-neuron&quot;&gt;Going deeper: What’s in a neuron?&lt;/h1&gt;

&lt;p&gt;Neurons within a deep learning network perform a mathematical transformation of
their input that is determined by some parameters, or &lt;em&gt;weights&lt;/em&gt;. There’s a lot
of variety within the general model, but most models do a linear combination of
their weights and the input (easily calculated by a matrix-vector
multiplication), followed by some nonlinear “activation function”. The
activation function is what lets our models learn nonlinear relationships in
the data, and several kinds exist. In equation form, we can describe what
happens within a neuron by&lt;/p&gt;

\[h = \sigma(x_1\theta_1 + x_2\theta_2 + \cdots + x_n\theta_n)\]

&lt;p&gt;where each \(x_i\) is an input, each \(\theta_i\) is the corresponding weight,
and \(\sigma\) is the activation function. The most popular activation
function, called the rectified linear unit (ReLU), is defined as \(\sigma(z) =
\text{max}(0, z)\).&lt;/p&gt;
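
&lt;p&gt;As a rough illustration, the equation above can be written in a few lines of plain Python. The function names and the example numbers are made up for this sketch; real libraries compute the same thing for whole layers at once using matrix operations.&lt;/p&gt;

```python
def relu(z):
    # Rectified linear unit: negative values are clipped to zero
    return max(0.0, z)

def neuron(inputs, weights):
    # Linear combination of inputs and weights, then the activation function
    z = sum(x * theta for x, theta in zip(inputs, weights))
    return relu(z)

x = [1.0, 2.0, 3.0]
active = neuron(x, [0.5, -1.0, 1.0])     # z = 0.5 - 2.0 + 3.0 = 1.5, passes through
clipped = neuron(x, [-1.0, -1.0, 0.5])   # z = -1.0 - 2.0 + 1.5 = -1.5, clipped to 0
print(active, clipped)
```

&lt;p&gt;The same inputs give very different outputs depending on the weights, which is why tuning the weights changes what the network computes.&lt;/p&gt;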

&lt;h1 id=&quot;training-a-neural-network-optimization-and-backpropagation&quot;&gt;Training a neural network: Optimization and backpropagation&lt;/h1&gt;

&lt;p&gt;By tuning the model parameters, we can teach our neural network to perform
a task. But how do we know how to adjust those parameters? This is particularly
difficult in deep models that are complex and have parameters on the order of
billions.&lt;/p&gt;

&lt;p&gt;The way most deep models are trained follows an optimization procedure. When a
model produces a result, we can define an &lt;em&gt;error function&lt;/em&gt; (also known as a cost
function) that measures how incorrect our model is. Then we can define the
training procedure as an optimization problem: we want to find the model
parameters such that the error on our training set is minimized.&lt;/p&gt;

&lt;p&gt;To actually perform this minimization, most deep learning models make use of an
optimization algorithm known as &lt;em&gt;stochastic gradient descent&lt;/em&gt; (SGD). This
algorithm repeatedly approximates the gradient of the error function, and then
slightly moves the parameters in the direction that will decrease the error. The
gradient can be calculated through a procedure known as &lt;em&gt;backpropagation&lt;/em&gt;.
Essentially, the gradient gives us an idea of how much blame to assign to every
parameter in our network for a prediction. By calculating the blame for a number
of inputs, we can approximate how the parameters are affecting the model’s
accuracy generally. Then, we can adjust the parameters such that they will
hopefully make our model more accurate. By repeating this enough times, we can
train our neural network by tuning all the parameters.&lt;/p&gt;
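
&lt;p&gt;To make the loop concrete, here is a rough sketch of stochastic gradient
descent for a model with a single weight. Everything in it is an assumption made
for illustration: the model \(y = \theta x\), the squared error, the learning
rate, and the toy data. With only one parameter the gradient can be written out
by hand, so this sketch stands in for what backpropagation computes in a real
network.&lt;/p&gt;

```python
import random


def sgd(data, lr=0.01, steps=2000, seed=0):
    # Fit y = theta * x by stochastic gradient descent on the squared error
    rng = random.Random(seed)
    theta = 0.0
    for _ in range(steps):
        x, y = rng.choice(data)    # one randomly chosen training example
        error = theta * x - y      # how wrong the current prediction is
        gradient = 2 * error * x   # derivative of error**2 with respect to theta
        theta -= lr * gradient     # small step in the direction that lowers the error
    return theta


# Toy data generated from y = 3x, so the ideal weight is 3
data = [(x, 3.0 * x) for x in [1.0, 2.0, 3.0, 4.0]]
theta = sgd(data)
print(round(theta, 3))
```

&lt;p&gt;Each step looks at a single example rather than the whole training set, which
is what makes the procedure stochastic.&lt;/p&gt;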

&lt;h1 id=&quot;deep-learning-variations&quot;&gt;Deep learning variations&lt;/h1&gt;

&lt;p&gt;Vanilla neural networks can be useful, but we can get even better results if we
modify the model to utilize some inherent characteristics of our objective task.
For instance, within images, the pixel data often has a high &lt;em&gt;spatial
dependence&lt;/em&gt;, meaning that pixels close together often work together to make up
specific attributes within an image. We can exploit this fact by utilizing a
&lt;em&gt;convolutional neural network&lt;/em&gt;. Without getting into the details, convolutional
neural networks (also known as convnets) replace some of the early layers with
neurons that perform a convolution operation. Each such neuron combines a small
neighborhood of pixels into a single value, which lets us extract spatial
information from our inputs.&lt;/p&gt;
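
&lt;p&gt;As a small illustration of that operation, the sketch below slides a 2x2 grid
of weights (a kernel) over a tiny made-up image and takes a weighted sum at each
position. The image, the kernel, and the function name are assumptions chosen for
the example. Like most deep learning libraries, it does not flip the kernel, so
strictly speaking it computes a cross-correlation.&lt;/p&gt;

```python
def convolve2d(image, kernel):
    # Slide the kernel over every position where it fits entirely inside the image
    kh, kw = len(kernel), len(kernel[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    output = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # Weighted sum of the patch of pixels currently under the kernel
            total = sum(image[i + a][j + b] * kernel[a][b]
                        for a in range(kh) for b in range(kw))
            row.append(total)
        output.append(row)
    return output


# A dark region (0) on the left and a bright region (1) on the right
image = [[0, 0, 1, 1],
         [0, 0, 1, 1],
         [0, 0, 1, 1]]
# This kernel responds where brightness increases from left to right
kernel = [[-1, 1],
          [-1, 1]]
print(convolve2d(image, kernel))  # [[0, 2, 0], [0, 2, 0]]
```

&lt;p&gt;The output is nonzero only in the column where the dark and bright regions
meet, which is the kind of spatial feature a convnet learns to detect.&lt;/p&gt;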

&lt;p&gt;&lt;img src=&quot;https://brandonlmorris.github.io/images/tf/convnet.png&quot; style=&quot;display:block;margin:0 auto;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Another variation of the typical (or &lt;em&gt;feedforward&lt;/em&gt;) neural network makes
use of temporal dependencies, where the correct output depends on
multiple inputs spaced out over time. Temporal dependencies are really common,
arising for instance in speech recognition and text comprehension. To exploit these
dependencies, models known as &lt;em&gt;recurrent neural networks&lt;/em&gt; (RNNs) are typically
used. These are neural networks where some of the neurons feed back into
previous layers, creating a cycle. When inputs are sequentially fed into the
network, part of the calculations will depend on not only what the network
currently sees, but what it has previously seen. The mechanisms for when and how
to feed back into the network can get pretty advanced, and can even mimic our
understanding of how our own brains work, such as the long short-term memory
(LSTM) neuron.&lt;/p&gt;
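
&lt;p&gt;A heavily simplified sketch of the recurrent idea is shown below. It keeps a
single number as the network’s memory and mixes it with each new input. The
scalar weights, the tanh activation, and the function name are all assumptions
made for illustration; this is far simpler than an LSTM.&lt;/p&gt;

```python
import math


def run_rnn(inputs, w_x=0.5, w_h=0.8):
    # h is the hidden state: a running summary of everything seen so far
    h = 0.0
    states = []
    for x in inputs:
        # The new state depends on the current input and the previous state
        h = math.tanh(w_x * x + w_h * h)
        states.append(h)
    return states


with_signal = run_rnn([1.0, 0.0, 0.0])
without_signal = run_rnn([0.0, 0.0, 0.0])
print(with_signal)
print(without_signal)
```

&lt;p&gt;Both sequences end with the same two inputs, yet their final states differ,
because the first one still carries a trace of the earlier input.&lt;/p&gt;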

&lt;p&gt;&lt;img src=&quot;https://brandonlmorris.github.io/images/tf/rnn.gif&quot; style=&quot;display:block;margin:0 auto;&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Deep learning has also significantly improved other areas of machine learning. A
great example of this is within &lt;em&gt;reinforcement learning&lt;/em&gt;, where models are
“agents” that try to perform a specific task like win a game against an opponent
or traverse a robot through a maze. Although the algorithms required to perform
reinforcement learning don’t necessitate deep neural networks, they have
benefited from the accuracy and generalization abilities of these models. For
instance, most of the interesting problems in reinforcement learning have a huge
search space, meaning that the number of possible states (e.g. chess positions,
robotic sensor data) the agent could be in is enormous and impossible to
completely exhaust during training. However, the agent will still need to take
actions in these states, even though it has not seen them before. To solve this,
we can represent the state space with a deep neural network, which can learn
what “kind of” state it’s in, and approximate it feasibly for the agent. Other
calculations the agent performs can be similarly approximated. Because deep
networks generalize so well (meaning they give similar outputs for similar,
though distinct, inputs), the agent can encounter entirely novel situations and
still act appropriately because it has learned what to do in similar states.&lt;/p&gt;

&lt;h1 id=&quot;conclusion&quot;&gt;Conclusion&lt;/h1&gt;

&lt;p&gt;Deep learning has moved the field of machine learning and artificial
intelligence dramatically forward and into the forefront of our technological
future. However, despite all the hype, it turns out that deep learning is really
some manageable mathematics and algorithms, refined and improved over a half
century or so. Ultimately, the models will keep improving and expanding their
capabilities beyond even what we’ve seen so far. Now is an amazing time to jump
in on deep learning, to both discover more about the nature of intelligence and
leverage our existing knowledge to improve the world we live in.&lt;/p&gt;</content>

      
      
      
      
      

      
        <author>
            <name>Brandon L. Morris</name>
          
          
        </author>
      

      
        <category term="ai" />
      

      

      
        <summary type="html">Deep learning is a subfield of machine learning that has had remarkable research success in the past decade. Huge numbers of research groups and top software companies are pushing the boundaries of what was previously thought possible through computation with these advancements.</summary>
      

      
      
    </entry>
  
  
  
    <entry>
      
      <title type="html">Clustering with K-Means</title>
      
      <link href="https://brandonlmorris.github.io/2016/12/06/clustering-with-k-means/" rel="alternate" type="text/html" title="Clustering with K-Means" />
      <published>2016-12-06T00:00:00+00:00</published>
      <updated>2016-12-06T00:00:00+00:00</updated>
      <id>https://brandonlmorris.github.io/2016/12/06/clustering-with-k-means</id>
      <content type="html" xml:base="https://brandonlmorris.github.io/2016/12/06/clustering-with-k-means/">&lt;p&gt;K-Means is a classic machine learning algorithm for discovering clusters
within data sets. It is a form of &lt;strong&gt;unsupervised learning&lt;/strong&gt;, where the
algorithm is not privy to the structure of the data, but is in fact trying to
learn said structure. My last blog post, concerning &lt;a href=&quot;https://brandonlmorris.github.io/ai/2016/11/10/an-introduction-to-machine-learning.html&quot;&gt;linear regression&lt;/a&gt;,
was an example of supervised learning, where the algorithm attempts to predict
values with a clear right or wrong answer. Here, we will examine the
K-Means algorithm that identifies data points that have grouped together,
or &lt;strong&gt;clustered&lt;/strong&gt;, within the data set.&lt;/p&gt;

&lt;p&gt;Clustering occurs frequently in natural data, and discovering these groupings
can be very advantageous to a data scientist. Although searching for
clusters is typically an unsupervised task, their existence permits us to seek
new information about the data set and attempt to make predictions about
novel data. This kind of application is commonly employed by companies seeking
to recommend new products or services for their customers. If Netflix knows
that people who love movie A also love movie B, then they will likely suggest
movie B to you after you view movie A.&lt;/p&gt;

&lt;h2 id=&quot;the-setup&quot;&gt;The Setup&lt;/h2&gt;

&lt;p&gt;We will begin with the following scenario: imagine that we collected a
survey from a number of participants, asking them to rate their preference
of Cola C and Cola P. If we plot these ratings on a two-dimensional coordinate
plane, the results might look something like this:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://brandonlmorris.github.io/images/ml/kmeans/k-means-initial.png&quot; alt=&quot;init-data&quot; /&gt;&lt;/p&gt;

&lt;p&gt;It is clear that the results have clustered around certain areas of the graph.
Some people mostly prefer Cola C, others just Cola P, and some do not
particularly care for either.&lt;/p&gt;

&lt;p&gt;Our goal will be to identify these groups mathematically. We will do this by
assuming that each cluster has some “center” point that the data points in
that cluster gather around. The K-Means algorithm will give us a way to
identify where these centers are for each of the clusters in our data set.&lt;/p&gt;

&lt;p&gt;Note that estimating these centers is inherently an unsupervised task. We do
not know where the centers are in the data; we want the algorithm to discover
them. A supervised task would involve trying to match data to a known
structure. Here, we are trying to learn the structure of the data itself.&lt;/p&gt;

&lt;h2 id=&quot;the-algorithm&quot;&gt;The Algorithm&lt;/h2&gt;

&lt;p&gt;The K-Means algorithm is relatively straightforward and simple mathematically.
Unlike gradient descent, no partial derivatives are necessary; just basic
addition and division.&lt;/p&gt;

&lt;p&gt;We begin by assuming there are a specific number of clusters in the data set:
\(k\). This hyperparameter may need to be tuned if the data is opaque, but
in our case a value of 3 is obviously correct.&lt;/p&gt;

&lt;p&gt;Each of these assumed clusters must have a center, as discussed earlier. We
can initialize these cluster centers randomly. Although incorrect at
the outset, their correct positions will be learned as the algorithm
progresses.&lt;/p&gt;

&lt;p&gt;The algorithm will then proceed as an iteration over two sequential phases,
referred to here as &lt;strong&gt;assign&lt;/strong&gt; and &lt;strong&gt;adjust&lt;/strong&gt;.&lt;/p&gt;

&lt;p&gt;The &lt;strong&gt;assign&lt;/strong&gt; phase will iterate over every data point, calculate its distance
from each of the current cluster centers, and “assign” it to the center it
is closest to. Here, distance can be measured in a number of ways (the example
below will use the common Euclidean distance). The result will be that every
data point will become associated with the cluster center that it best belongs
with.&lt;/p&gt;

&lt;p&gt;The next phase, &lt;strong&gt;adjust&lt;/strong&gt;, will recalculate the cluster center positions
based on which data points were assigned to it in the previous phase.
Specifically, the new center position will be the average position of
its assigned data points. This average, or “mean”, is where the algorithm
receives its name, K-Means, since each of the K clusters will be centered
around their average positions.&lt;/p&gt;

&lt;p&gt;As these two phases are repeated, the cluster centers will eventually settle
in their right position. That is because cluster centers will consistently
move towards the data points closest to them, thus making them even closer.
Eventually, the centers will no longer move, when they are in their final
position, at which point the algorithm can stop.&lt;/p&gt;

&lt;p&gt;So long as the number of cluster centers (\(k\)) matches the number of clusters
in the actual data, the algorithm will converge fairly quickly. It should be
noted, however, that it is possible for the cluster centers to settle in
incorrect positions, even when the value of \(k\) is correct. This can occur
if clusters are located close to one another, or when the cluster centers are
initialized poorly. To guard against poor initialization, it
is common to run the algorithm multiple times, and compare the results from
each execution for consistency.&lt;/p&gt;
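
&lt;p&gt;One way to compare runs, sketched below, is to score each result by the total
squared distance from every point to its assigned center and prefer the run with
the lowest score. This scoring rule is a common convention rather than something
the code in this post does, and the function name &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;inertia&lt;/code&gt; and the sample
points are made up for illustration.&lt;/p&gt;

```python
def inertia(centers, clusters):
    # Sum of squared distances from every point to its assigned center
    total = 0.0
    for (cx, cy), points in zip(centers, clusters):
        for (x, y) in points:
            total += (x - cx) ** 2 + (y - cy) ** 2
    return total


# The same two clusters of points, scored against two candidate sets of centers
clusters = [[(0, 1), (2, 1)], [(9, 8), (9, 10)]]
good = inertia([(1.0, 1.0), (9.0, 9.0)], clusters)
poor = inertia([(1.0, 5.0), (9.0, 5.0)], clusters)
print(good, poor)  # 4.0 68.0
```

&lt;p&gt;Centers that sit in the middle of their points give a much lower score than
centers that have settled somewhere off to the side.&lt;/p&gt;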

&lt;h2 id=&quot;the-code&quot;&gt;The Code&lt;/h2&gt;

&lt;p&gt;For this example, we will use the data set from our Cola C vs.
Cola P survey. We will assume that these data points come as lists, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;X&lt;/code&gt; and
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Y&lt;/code&gt;, corresponding to the \(x\) and \(y\) values, respectively.&lt;/p&gt;

&lt;p&gt;The assign phase of the algorithm is implemented below. It takes the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;X&lt;/code&gt; and
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Y&lt;/code&gt; values as parameters, as well as a list of the locations of the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;centers&lt;/code&gt;
(each center is a pair, i.e. &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;centers[0]&lt;/code&gt; might be &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;(22, 19)&lt;/code&gt;). The function
will return a list, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;nearests&lt;/code&gt;, which will contain a list of the assigned
data points for each center.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;assign&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;centers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;nearests&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;centers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;zip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;dists&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;math&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;sqrt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;**&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
                 &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;cx&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;centers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;index&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;dists&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;index&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;min&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;dists&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;nearests&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;index&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;].&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;append&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;((&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;nearests&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;The next function, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;adjust&lt;/code&gt;, takes a cluster center position and a list of
positions assigned to that center, and returns the new position for the cluster
center. It calculates the new position as the average \(x\) and \(y\) values
of the assigned positions.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;adjust&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;center&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;neighbors&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;neighbors&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;==&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# Avoid dividing by zero
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;center&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;avg_x&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;neighbors&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;neighbors&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;avg_y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;neighbors&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;neighbors&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;avg_x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;avg_y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;main&lt;/code&gt; function, which will serve as the program’s entry point, will
initialize the cluster center positions, and iterate over the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;assign&lt;/code&gt; and
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;adjust&lt;/code&gt; functions until the cluster centers no longer move by any significant
amount (the value of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;EPSILON&lt;/code&gt;) in any direction.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;main&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# ...omitting data collection...
&lt;/span&gt;    &lt;span class=&quot;c1&quot;&gt;# Initialize center points randomly
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;centers&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;randint&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;),&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;randint&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;100&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;range&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;3&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;iterations&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;while&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;c1&quot;&gt;# Move each center to the mean position of its assigned points
&lt;/span&gt;        &lt;span class=&quot;n&quot;&gt;neighbors&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;assign&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;centers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;new_centers&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;adjust&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;c&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;n&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;zip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;centers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;neighbors&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# Stop if we've converged
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;all&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;abs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;EPSILON&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;and&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;abs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;EPSILON&lt;/span&gt;
                &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;n&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;c&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;zip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;centers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;new_centers&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]):&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;break&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;centers&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;new_centers&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;iterations&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;
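
&lt;p&gt;The snippets above leave out the imports, the value of &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;EPSILON&lt;/code&gt;, and the
data collection. For anyone who wants to run the whole thing, here is a
self-contained sketch that fills those gaps with assumptions: an arbitrary
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;EPSILON&lt;/code&gt; of 0.01, a handful of made-up ratings, a seeded random generator,
and a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;kmeans&lt;/code&gt; function standing in for &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;main&lt;/code&gt;. It also uses
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;math.isclose&lt;/code&gt; for the convergence test, which treats a move of exactly
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;EPSILON&lt;/code&gt; as converged.&lt;/p&gt;

```python
import math
import random

EPSILON = 0.01  # assumed convergence threshold; the post does not state its value


def assign(X, Y, centers):
    # Group every point with the center it is closest to (Euclidean distance)
    nearests = [[] for _ in centers]
    for x, y in zip(X, Y):
        dists = [math.hypot(x - cx, y - cy) for cx, cy in centers]
        nearests[dists.index(min(dists))].append((x, y))
    return nearests


def adjust(center, neighbors):
    # Move a center to the mean of its points; leave it alone if it has none
    if not neighbors:
        return center
    n = len(neighbors)
    return (sum(p[0] for p in neighbors) / n, sum(p[1] for p in neighbors) / n)


def kmeans(X, Y, k, rng):
    # Random starting centers, then repeat assign and adjust until nothing moves
    centers = [(rng.randint(0, 100), rng.randint(0, 100)) for _ in range(k)]
    iterations = 0
    while True:
        neighbors = assign(X, Y, centers)
        new_centers = [adjust(c, n) for c, n in zip(centers, neighbors)]
        converged = all(math.isclose(c[0], n[0], abs_tol=EPSILON) and
                        math.isclose(c[1], n[1], abs_tol=EPSILON)
                        for c, n in zip(centers, new_centers))
        centers = new_centers
        if converged:
            return centers, iterations
        iterations += 1


# Made-up survey ratings that form three loose groups
X = [10, 12, 14, 80, 82, 84, 15, 18, 20]
Y = [80, 84, 82, 12, 10, 15, 15, 18, 12]
centers, iterations = kmeans(X, Y, 3, random.Random(0))
print(centers, iterations)
```

&lt;p&gt;Changing the seed changes the starting centers, which is a quick way to see
how much the final result depends on initialization.&lt;/p&gt;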

&lt;p&gt;The following graphs display the algorithm in action. Initially, the cluster
centers are placed at random on the graph.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://brandonlmorris.github.io/images/ml/kmeans/k-means-st0.png&quot; alt=&quot;rand-init&quot; /&gt;&lt;/p&gt;

&lt;p&gt;After the first iteration:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://brandonlmorris.github.io/images/ml/kmeans/k-means-st1.png&quot; alt=&quot;first-iter&quot; /&gt;&lt;/p&gt;

&lt;p&gt;And after the second iteration, the cluster centers settle, or converge,
on their correct positions.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://brandonlmorris.github.io/images/ml/kmeans/k-means-st2.png&quot; alt=&quot;last-iter&quot; /&gt;&lt;/p&gt;

&lt;h2 id=&quot;conclusion&quot;&gt;Conclusion&lt;/h2&gt;

&lt;p&gt;K-Means is an extremely useful algorithm for discovering clusters within
a data set. Even better, the algorithm is efficient and simple to implement
(though much better implementations exist in open source machine learning
libraries).&lt;/p&gt;</content>

      
      
      
      
      

      
        <author>
            <name>Brandon L. Morris</name>
          
          
        </author>
      

      
        <category term="ai" />
      
        <category term="tutorial" />
      

      

      
        <summary type="html">K-Means is a classic machine learning algorithm for discovering clusters within data sets</summary>
      

      
      
    </entry>
  
  
  
    <entry>
      
      <title type="html">An Introduction to Machine Learning</title>
      
      <link href="https://brandonlmorris.github.io/2016/11/10/an-introduction-to-machine-learning/" rel="alternate" type="text/html" title="An Introduction to Machine Learning" />
      <published>2016-11-10T00:00:00+00:00</published>
      <updated>2016-11-10T00:00:00+00:00</updated>
      <id>https://brandonlmorris.github.io/2016/11/10/an-introduction-to-machine-learning</id>
      <content type="html" xml:base="https://brandonlmorris.github.io/2016/11/10/an-introduction-to-machine-learning/">&lt;p&gt;If you’ve seen any technology news in the past couple of years, it was probably
a headline having something to do with computers able to achieve feats
previously thought impossible. Driving cars, diagnosing patients, and
beating world champions at complex strategy board games have all been done. All
of these tasks are being done with an exploding branch of computer science
called &lt;strong&gt;artificial intelligence&lt;/strong&gt;. In this post I’ll look at a subset of AI
called &lt;strong&gt;machine learning&lt;/strong&gt; and break down one of the simpler algorithms called
linear regression.&lt;/p&gt;

&lt;p&gt;Machine learning is a category of algorithms that can “learn” from data. In
other words, if they’re given a particular task, \(T\), machine learning
algorithms can improve at that task (to some limit) given experiences, \(E\).&lt;/p&gt;

&lt;h3 id=&quot;linear-regression&quot;&gt;Linear Regression&lt;/h3&gt;

&lt;p&gt;The algorithm that we’ll look at is called &lt;strong&gt;linear regression&lt;/strong&gt;. At its heart,
linear regression is a “line-fitting” algorithm: given some data that seems
to form some kind of line on a graph, we want to derive an equation that
describes that line. With such an equation, we can make predictions about new
values that are outside our initial dataset.&lt;/p&gt;

&lt;p&gt;To put it in more concrete terms, let’s say we wanted to build a model that
could predict house prices in a certain area. If we plotted the size of the
house (in square feet) against the price that the house sold for, we
would likely see some kind of linear distribution of the data points. Using
linear regression, we could find the equation of that line, and then predict
how much money we think a new house will sell for given its square footage.&lt;/p&gt;

&lt;p&gt;In reality, most problems like this are much more complex than this simple
example. For instance, houses can have dozens of factors (or “features”) that
contribute to their selling price, like the number of bedrooms, school
district, and so on. Despite this complexity in the real world, we can look
past the specifics and learn the principles that form the foundation of these
complicated models.&lt;/p&gt;

&lt;h3 id=&quot;case-study&quot;&gt;Case Study&lt;/h3&gt;

&lt;p&gt;For our example, let’s say we have the following data:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://brandonlmorris.github.io/images/ml/initial-data.png&quot; alt=&quot;initial-dataset&quot; /&gt;&lt;/p&gt;

&lt;p&gt;I generated this data myself with a linear equation. I also added some random
noise to make it a little bit more realistic. Our goal will be to find a
way to predict where a new data point will fall vertically given where it sits
horizontally (in other words, the \(y\) value given the \(x\) value). We can
state this goal as such:&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;&lt;em&gt;Goal: Create a model that can reasonably predict new values based on the
existing data&lt;/em&gt;&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;This is a pretty vague goal, and we’ll refine it as we progress.&lt;/p&gt;

&lt;p&gt;You may remember from school something called the &lt;strong&gt;slope-intercept equation&lt;/strong&gt; of
a line. If you don’t, here’s what it looks like:&lt;/p&gt;

\[y = b + mx\]

&lt;p&gt;Where \(x\) and \(y\) are the point coordinates, \(b\) is the \(y\)-intercept
(where the line hits the \(y\)-axis), and \(m\) is the slope of the line.&lt;/p&gt;

&lt;p&gt;We can rewrite this equation with some different symbols that are commonly
used in machine learning like so:&lt;/p&gt;

\[h_\theta(x) = \theta_0 + \theta_1x\]

&lt;p&gt;The function \(h_\theta\) is called the &lt;strong&gt;hypothesis&lt;/strong&gt; function, and it serves
as our predictor. Given an \(x\) value, \(h_\theta(x)\) will produce what
our model thinks \(y\) should be.&lt;/p&gt;

&lt;p&gt;Using what we already have said about the slope-intercept equation, we can consider
how we might be able to achieve our goal. We said that the \(\theta_0\) value
determines the \(y\)-intercept of the line, or where it sits vertically. By
increasing or decreasing \(\theta_0\), we can shift the line up or down,
respectively. Additionally, \(\theta_1\) determines the slope of the line, or
its angle. By modifying \(\theta_1\), we can tilt the line more or less. By
combining these two mechanisms, shifting and tilting, we can create the line
that resembles the data set.&lt;/p&gt;

&lt;p&gt;\(\theta_0\) and \(\theta_1\) serve as the &lt;strong&gt;parameters&lt;/strong&gt; of our model, since
their values will directly determine how well our model can predict existing
and new values.&lt;/p&gt;

&lt;p&gt;(It should be noted that this is a relatively contrived and simple example
with only two parameters, but the same principles apply even if we have more
features or higher-order terms like \(x^2\) or \(x^3\).)&lt;/p&gt;

&lt;p&gt;Now we can modify our goal slightly:&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;&lt;em&gt;Goal: Find the values of the parameters, \(\theta_0\) and \(\theta_1\), that
form a line that best “fits” our data.&lt;/em&gt;&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;But what do I mean when I say a line that “fits” the data? If we just look at
a plot of our line on top of the data, we can get an intuitive feel for whether
the line is “good” or not. But what if our data has many dimensions and can’t
easily be visualized? And how can we objectively determine which of two lines
is better if they both seem pretty close to the data set?&lt;/p&gt;

&lt;p&gt;To solve these problems, we’ll need a tool called the &lt;strong&gt;cost function&lt;/strong&gt;. You
may also see it called the error or loss function, but these all refer to the
same concept. The cost function will let us “score” our model by determining
mathematically how close or far it is from the original data. Many different
cost functions exist, but most of them follow some scheme of measuring the
distance between a model prediction (\(h_\theta(x)\)) and the actual value
(\(y\)) for each data point.&lt;/p&gt;

&lt;p&gt;For our example, we’ll use the &lt;strong&gt;sum of squares&lt;/strong&gt; cost function, which is a
common and effective cost function. To calculate the cost, we will loop over
every data point, find the difference between the predicted and actual value,
square it, then add up the result for all the data points. Mathematically, it
looks like&lt;/p&gt;

\[J(\theta_0, \theta_1) = \frac{1}{2N}\sum_{n=1}^{N}(h_\theta(x_n) - y_n)^2\]

&lt;p&gt;We’ll also divide the sum by the number of data points, \(N\), which is called
&lt;em&gt;normalization&lt;/em&gt;. It makes our model’s cost independent of the number of
data points it was calculated against (i.e. adding more data doesn’t
necessarily increase the cost). The division by 2 is somewhat arbitrary and
not strictly necessary. We do it because it cancels the factor of 2 that
appears when we differentiate the squared term, which makes some of the math
in the future slightly cleaner.&lt;/p&gt;
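
&lt;p&gt;As a quick sanity check, here is a minimal sketch that evaluates this cost
by hand on three made-up points. The data and the helper name
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;half_mean_squared_error&lt;/code&gt; are assumptions for illustration only,
not part of the case study.&lt;/p&gt;

```python
def half_mean_squared_error(theta0, theta1, xs, ys):
    """J(theta0, theta1) = 1/(2N) * sum((h(x_n) - y_n)^2)."""
    n = len(xs)
    total = 0.0
    for x, y in zip(xs, ys):
        prediction = theta0 + theta1 * x   # the hypothesis h(x)
        total += (prediction - y) ** 2
    return total / (2 * n)

# Made-up points lying exactly on the line y = 1 + 2x
xs = [0.0, 1.0, 2.0]
ys = [1.0, 3.0, 5.0]

perfect = half_mean_squared_error(1.0, 2.0, xs, ys)  # line through every point
flat = half_mean_squared_error(0.0, 0.0, xs, ys)     # predicts 0 everywhere
print(perfect, flat)
```

&lt;p&gt;The line that passes through every point scores a cost of zero, while the
flat line at zero scores \((1 + 9 + 25) / 6\), so a lower cost does correspond
to a better fit here.&lt;/p&gt;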

&lt;p&gt;With the cost function, we now have a way to actually evaluate how good (or
rather, bad) our model is at predicting values. Ideally, we would want our
parameters (\(\theta_0\) and \(\theta_1\)) to be such that the cost function
is equal to 0, but that’s often impossible or even undesirable. So we’ll
settle with minimizing the cost function to the lowest value we reasonably can,
which means finding the parameter values that produce a lower cost than any
other set of values. So let’s update our goal to reflect this new idea:&lt;/p&gt;

&lt;p&gt;&lt;strong&gt;&lt;em&gt;Goal: Find the values of the parameters that form the equation of a line
that minimizes our cost function&lt;/em&gt;&lt;/strong&gt;&lt;/p&gt;

&lt;p&gt;Mathematically, this is called an &lt;strong&gt;optimization&lt;/strong&gt; problem and can be written
as such:&lt;/p&gt;

\[\min\limits_{\theta_0, \theta_1}J(\theta_0, \theta_1)\]

&lt;h3 id=&quot;optimization-via-gradient-descent&quot;&gt;Optimization via Gradient Descent&lt;/h3&gt;

&lt;p&gt;Gradient descent is an algorithm that will allow us to perform the
aforementioned optimization. The key insight to gradient descent comes from
the &lt;em&gt;shape&lt;/em&gt; of our cost function. If we plot out the cost function with some
different values, we’ll likely get a “U” shaped curve like below.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://brandonlmorris.github.io/images/ml/quadratic-cost.png&quot; alt=&quot;cost-curve&quot; /&gt;&lt;/p&gt;

&lt;p&gt;This comes from the square term in our sum of squares function. Note that this
graph is in two dimensions, where our actual graph of \(J(\theta_0, \theta_1)\)
would be three-dimensional, but the same idea still applies.&lt;/p&gt;

&lt;p&gt;So how can we determine a minimum from this curve? With gradient descent, we’ll
pick an arbitrary set of starting values for our initial parameters (usually 0).
The cost of those parameters will give us some point on the U curve of the cost
function we were just looking at. Since we are aiming for the minimum, we’ll
want to adjust our parameters such that the cost function will be less as a
result. To figure out how to tweak the parameters to achieve this, we’ll take
the partial derivative of the cost function with respect to each parameter,
and subtract that partial derivative (multiplied by a constant) from the
current value of the parameter. We’ll then
repeat this process over and over again, moving closer and closer to the
minimum. Mathematically this can be described as&lt;/p&gt;

&lt;p&gt;Repeat until convergence:&lt;br /&gt;
 for \(j = 0\dots m\):&lt;br /&gt;
  \(\theta_j := \theta_j - \alpha \frac{\partial}{\partial\theta_j}J(\theta_0, \theta_1)\)&lt;br /&gt;&lt;/p&gt;

&lt;p&gt;A few things to note about this pseudocode&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;“Until convergence” is user defined, but it usually means the point at which the cost function
stops decreasing by a significant amount after an iteration of the outer loop.&lt;/li&gt;
  &lt;li&gt;\(m\) is the number of features, so the loop covers all \(m + 1\) parameters
(in our example, \(m = 1\)).&lt;/li&gt;
  &lt;li&gt;\(\alpha\) is the &lt;strong&gt;learning rate&lt;/strong&gt;, which determines how large a change
we make to the parameters per iteration.&lt;/li&gt;
  &lt;li&gt;Simultaneous update: we’ll have to calculate the new values of the
parameters one at a time, but we want to update them all at once. We don’t
want to mix new parameter values and old ones during the inner loop.&lt;/li&gt;
&lt;/ul&gt;
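
&lt;p&gt;The simultaneous update can be sketched as follows. The function names and
the bowl-shaped toy cost are hypothetical, chosen only because each of its
partial derivatives depends on both parameters.&lt;/p&gt;

```python
def step(theta, gradient, alpha):
    """One gradient descent step with a simultaneous update.

    theta and gradient are equal-length lists; every entry of gradient
    must have been computed from the old theta values only.
    """
    return [t - alpha * g for t, g in zip(theta, gradient)]

def gradient_at(theta):
    # Hypothetical toy cost J = theta0^2 + theta0*theta1 + theta1^2
    t0, t1 = theta
    return [2 * t0 + t1, t0 + 2 * t1]

theta = [1.0, 1.0]
alpha = 0.1
# The whole gradient is evaluated first, then all parameters move at once
new_theta = step(theta, gradient_at(theta), alpha)
print(new_theta)
```

&lt;p&gt;Both partial derivatives are evaluated at the old point \((1, 1)\) before
either parameter changes. Updating \(\theta_0\) in place and then computing the
partial for \(\theta_1\) would move \(\theta_1\) by a different amount.&lt;/p&gt;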

&lt;p&gt;Intuitively, we can think of gradient descent as if we were trying to walk
down a hill into a valley. The partial derivatives will tell us the “gradient”,
or which direction we need to move in order to progress to the bottom. Each
iteration of the outer loop is like taking a small step. Given enough steps,
we’ll eventually make it to the valley floor, or the minimum of our cost
function.&lt;/p&gt;

&lt;p&gt;A nice debugging feature of gradient descent is that, when everything is
working properly and the learning rate is small enough, &lt;strong&gt;the cost function
should &lt;em&gt;always&lt;/em&gt; decrease (or at least not increase) with every
iteration&lt;/strong&gt;. If the cost function stays the same, or only decreases by a
negligible amount, then we’ve likely converged at the minimum. If it
ever increases, something has probably gone wrong and we should check our
code (or adjust our learning rate to be smaller).&lt;/p&gt;
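
&lt;p&gt;To see this check in action, here is a small sketch on a one-parameter toy
cost, \(J(\theta) = \theta^2\). The toy cost, the two learning rates, and the
helper names are all assumptions made for illustration.&lt;/p&gt;

```python
def cost_history(alpha, steps=10, theta=1.0):
    """Run gradient descent on the toy cost J(theta) = theta^2 and record J."""
    history = [theta ** 2]
    for _ in range(steps):
        theta = theta - alpha * 2 * theta   # dJ/dtheta = 2 * theta
        history.append(theta ** 2)
    return history

def never_increases(history):
    # A non-increasing sequence equals itself sorted in descending order
    return history == sorted(history, reverse=True)

small = cost_history(alpha=0.1)   # each step shrinks theta by a factor of 0.8
large = cost_history(alpha=1.1)   # each step overshoots: theta is scaled by -1.2
print(never_increases(small), never_increases(large))
```

&lt;p&gt;With the smaller learning rate the recorded cost falls on every iteration.
With the larger one every step overshoots the minimum and the cost grows,
which is the symptom described above.&lt;/p&gt;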

&lt;p&gt;There’s one last thing we need to define before we can start implementing this
algorithm in code, and that’s the partial derivative of the cost function. If
you know multivariate calculus, feel free to try and derive this yourself, but
for the sake of brevity I’ll simply give the answer below:&lt;/p&gt;

\[\frac{\partial}{\partial\theta_j}J(\theta_0,\theta_1) = \frac{1}{N}\sum_{n=1}^{N}[(h_\theta(x_n) - y_n) * x_{n,j}]\]

&lt;p&gt;Where \(x_{n, j}\) is the \(j\)th feature value of the \(n\)th training example.
Although we didn’t explicitly state it, we can think of each \(x\) example
as a pair, \((1, x)\), whose entries are multiplied by the respective
parameters, \((\theta_0, \theta_1)\), and then summed.&lt;/p&gt;
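
&lt;p&gt;For readers who want the missing step, here is a sketch of the derivation
using the chain rule, writing \(x_{n,0} = 1\) and \(x_{n,1} = x_n\):&lt;/p&gt;

\[\frac{\partial}{\partial\theta_j}J(\theta_0,\theta_1) = \frac{1}{2N}\sum_{n=1}^{N}2\,(h_\theta(x_n) - y_n)\,\frac{\partial}{\partial\theta_j}h_\theta(x_n)\]

\[\frac{\partial}{\partial\theta_j}h_\theta(x_n) = \frac{\partial}{\partial\theta_j}(\theta_0 x_{n,0} + \theta_1 x_{n,1}) = x_{n,j}\]

&lt;p&gt;The factor of 2 from the squared term cancels the 2 in \(\frac{1}{2N}\),
leaving the \(\frac{1}{N}\) in the final expression. For \(j = 0\) the sum is
just the average error, and for \(j = 1\) each error is weighted by \(x_n\).&lt;/p&gt;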

&lt;h2 id=&quot;naive-implementation-in-python&quot;&gt;Naive Implementation in Python&lt;/h2&gt;

&lt;p&gt;Math is fun, but what does this all look like in code? For this example, I’m
going to skip over the data collection and cleaning, and focus on the
interesting parts. I’m going to assume that I have two (Python) lists, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;X&lt;/code&gt; and
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Y&lt;/code&gt; that hold the \(x\) and \(y\) values of our data set, respectively.&lt;/p&gt;

&lt;p&gt;We’ll first break our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;X&lt;/code&gt; and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Y&lt;/code&gt; lists into two separate sets of lists:
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;X_train&lt;/code&gt;, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Y_train&lt;/code&gt;, &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;X_test&lt;/code&gt;, and &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;Y_test&lt;/code&gt;. The rule of thumb for this split
is about 70% of the examples will go to the training lists, and the other 30%
will be reserved for testing.&lt;/p&gt;

&lt;p&gt;Before we move any further, I want to stress how critical it is that we split
up our data set into training and test sets. The training set will be used
for the “learning” portion (determining the values of \(\theta_0\) and
\(\theta_1\)), while the test set will be used to evaluate how good of a model
we have once the parameter values are set. The point here is that our ultimate
goal is to have a model that’s &lt;strong&gt;general&lt;/strong&gt;: it is effective at predicting
&lt;strong&gt;new&lt;/strong&gt; values that it hasn’t seen before. So if we try to test with the same
data that we trained with, we have no idea if we’re achieving our goal, since
all the test points will have been seen during training. With the segregated
sets, we can determine if our model can be effective when it encounters new
data.&lt;/p&gt;
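
&lt;p&gt;One way to produce the 70/30 split is sketched below. The helper name
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;split_dataset&lt;/code&gt;, the fixed seed, and the made-up data are assumptions for
illustration. Shuffling before cutting is also an assumption, made so that any
ordering in the original lists does not leak into the split.&lt;/p&gt;

```python
import random

def split_dataset(X, Y, train_fraction=0.7, seed=0):
    """Shuffle paired examples, then split them into training and test lists."""
    pairs = list(zip(X, Y))
    random.Random(seed).shuffle(pairs)   # fixed seed keeps the split reproducible
    cut = round(len(pairs) * train_fraction)
    train, test = pairs[:cut], pairs[cut:]
    X_train = [x for x, _ in train]
    Y_train = [y for _, y in train]
    X_test = [x for x, _ in test]
    Y_test = [y for _, y in test]
    return X_train, Y_train, X_test, Y_test

# Made-up data on the line y = 3x, just to exercise the function
X = list(range(10))
Y = [3 * x for x in X]
X_train, Y_train, X_test, Y_test = split_dataset(X, Y)
print(len(X_train), len(X_test))
```

&lt;p&gt;Each \(x\) stays paired with its own \(y\) because the two lists are zipped
together before shuffling.&lt;/p&gt;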

&lt;p&gt;The first thing we’ll need to do is define our hypothesis function. This is
trivial since our simple model only has two parameters.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;hypothesis&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;theta0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;theta1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;theta0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;+&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;theta1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;Next, we’ll need to define our cost function&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;cost&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;theta0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;theta1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;errors&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hypothesis&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;theta0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;theta1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;zip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;squared&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;i&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;errors&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;squared&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;Now we’ll define a &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;descend&lt;/code&gt; function that will serve as one iteration of the
gradient descent loop&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;descend&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
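    &lt;span class=&quot;c1&quot;&gt;# Materialize the pairs: in Python 3, zip() can only be iterated once&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;xy&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;list&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;nb&quot;&gt;zip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;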
    &lt;span class=&quot;n&quot;&gt;N&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Calculate partial wrt theta0
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;p0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hypothesis&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;xy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;p0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;N&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;p0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;c1&quot;&gt;# Calculate partial wrt theta1
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;p1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hypothesis&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;xy&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;p1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;N&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;p1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;n&quot;&gt;new_t0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t0&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;LEARNING_RATE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;p0&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;new_t1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;LEARNING_RATE&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;*&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;p1&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;new_t0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;new_t1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;Finally, we’ll define our &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;main&lt;/code&gt; function that will serve as the driver of
the program&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;main&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;():&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# ...data generation omitted...
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;theta0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;theta1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;0&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;old_cost&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cost&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;theta0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;theta1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

    &lt;span class=&quot;k&quot;&gt;while&lt;/span&gt; &lt;span class=&quot;bp&quot;&gt;True&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;theta0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;theta1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;descend&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;theta0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;theta1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt;
                                 &lt;span class=&quot;n&quot;&gt;X_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
        &lt;span class=&quot;n&quot;&gt;new_cost&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;cost&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;theta0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;theta1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

        &lt;span class=&quot;c1&quot;&gt;# Check for convergence
&lt;/span&gt;        &lt;span class=&quot;k&quot;&gt;if&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;abs&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;new_cost&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;old_cost&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;&amp;lt;&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;EPSILON&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;k&quot;&gt;break&lt;/span&gt;
        &lt;span class=&quot;k&quot;&gt;else&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;:&lt;/span&gt;
            &lt;span class=&quot;n&quot;&gt;old_cost&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;new_cost&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;And that’s all we really need to implement linear regression with gradient
descent. However, it may be useful to visualize the results and evaluate
how good of a job we did. We’ll create some extra functions that let us do
that.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;graph&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;c1&quot;&gt;# using matplotlib.pyplot
&lt;/span&gt;    &lt;span class=&quot;n&quot;&gt;pyplot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;plot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;s&quot;&gt;'r.'&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;pyplot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;plot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hypothesis&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;pyplot&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;show&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;

&lt;span class=&quot;k&quot;&gt;def&lt;/span&gt; &lt;span class=&quot;nf&quot;&gt;r_squared&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;):&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;Yp&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;[&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;hypothesis&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;t0&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;t1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;x&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;X&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;u&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;yp&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;yt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;yp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;yt&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;zip&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Yp&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)])&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;len&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
    &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;nb&quot;&gt;sum&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;mean&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;**&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;2&lt;/span&gt; &lt;span class=&quot;k&quot;&gt;for&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;y&lt;/span&gt; &lt;span class=&quot;ow&quot;&gt;in&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;])&lt;/span&gt;
    &lt;span class=&quot;k&quot;&gt;return&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;-&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;u&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;/&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;v&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;))&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

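&lt;p&gt;The &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;r_squared&lt;/code&gt; function implements a common statistic called the
coefficient of determination, \(R^2\). It compares the squared error of our
predictions (&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;u&lt;/code&gt;) with the squared error of always guessing the mean
(&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;v&lt;/code&gt;). The details aren’t important; just know that a score closer to 1.0
is better.&lt;/p&gt;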
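
&lt;p&gt;To make that number less mysterious, here is a minimal sketch that applies
the same formula to two extreme cases. The toy data and the name &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;r2&lt;/code&gt; are made
up for illustration: a perfect prediction should score 1.0, while always
guessing the mean should score 0.0.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;def r2(Y_pred, Y):
    # u: squared error of our predictions
    u = sum((yp - yt) ** 2 for yp, yt in zip(Y_pred, Y))
    # v: squared error of always guessing the mean
    mean = sum(Y) / len(Y)
    v = sum((y - mean) ** 2 for y in Y)
    return 1 - (u / v)

Y = [2.0, 4.0, 6.0, 8.0]               # toy targets (hypothetical)
perfect = list(Y)                      # predictions that match exactly
baseline = [sum(Y) / len(Y)] * len(Y)  # always guess the mean

print(r2(perfect, Y))   # 1.0
print(r2(baseline, Y))  # 0.0&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;

&lt;p&gt;Anything worse than guessing the mean would push the score below zero.&lt;/p&gt;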

&lt;h2 id=&quot;scikit-learn-implementation&quot;&gt;scikit-learn Implementation&lt;/h2&gt;

&lt;p&gt;As painful as that might have been, luckily other people have done it before
and given away their code for free. As a plus, they probably did it better.&lt;/p&gt;

&lt;p&gt;scikit-learn (sklearn) is a popular Python library that’s full of handy
algorithms for machine learning. Let’s take a look at how we can leverage this
awesome resource for our example problem.&lt;/p&gt;

&lt;figure class=&quot;highlight&quot;&gt;&lt;pre&gt;&lt;code class=&quot;language-python&quot; data-lang=&quot;python&quot;&gt;&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;sklearn.linear_model&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;LinearRegression&lt;/span&gt;
&lt;span class=&quot;kn&quot;&gt;from&lt;/span&gt; &lt;span class=&quot;nn&quot;&gt;numpy&lt;/span&gt; &lt;span class=&quot;kn&quot;&gt;import&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Transform X and Y to numpy vectors
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_train&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;Y_train&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;array&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;Y_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;).&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;reshape&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;o&quot;&gt;-&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;mi&quot;&gt;1&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Perform the actual training
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;LinearRegression&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;()&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;fit&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y_train&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt;

&lt;span class=&quot;c1&quot;&gt;# Predict and evaluate
&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;y_15&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;predict&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;([[&lt;/span&gt;&lt;span class=&quot;mi&quot;&gt;15&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;]])&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# predict expects a 2-D array&lt;/span&gt;
&lt;span class=&quot;n&quot;&gt;r_squared&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;score&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;X_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;Y_test&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# must be vectors
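&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;params&lt;/span&gt; &lt;span class=&quot;o&quot;&gt;=&lt;/span&gt; &lt;span class=&quot;p&quot;&gt;(&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;coef_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;,&lt;/span&gt; &lt;span class=&quot;n&quot;&gt;lr&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;.&lt;/span&gt;&lt;span class=&quot;n&quot;&gt;intercept_&lt;/span&gt;&lt;span class=&quot;p&quot;&gt;)&lt;/span&gt; &lt;span class=&quot;c1&quot;&gt;# learned slope and intercept&lt;/span&gt;&lt;/code&gt;&lt;/pre&gt;&lt;/figure&gt;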

&lt;p&gt;And that’s it: the whole algorithm boils down to essentially two lines. You’ll
note that we had to turn our lists into &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;numpy&lt;/code&gt; arrays and reshape them into a
single column. scikit-learn expects its inputs as 2-D arrays with one row per
sample and one column per feature, and working with arrays also lets it run
faster than it could with plain Python lists.&lt;/p&gt;

&lt;h2 id=&quot;results&quot;&gt;Results&lt;/h2&gt;

&lt;p&gt;When I ran our implementation, I found that a learning rate of about 0.0003
worked well. Gradient descent took about 25,000 iterations to converge, but
only about 2-3 seconds. We also got an \(R^2\) score of 0.9980. The results
are plotted below.&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://brandonlmorris.github.io/images/ml/results.png&quot; alt=&quot;lrgd-results&quot; /&gt;&lt;/p&gt;

&lt;p&gt;So even though our example was pretty simple, we did a pretty nice job fitting
our data. Not bad for fewer than 50 lines of code!&lt;/p&gt;

&lt;h2 id=&quot;conclusion&quot;&gt;Conclusion&lt;/h2&gt;

&lt;p&gt;Although this was a relatively simple example, these principles are
foundational and used throughout machine learning. Indeed, machine learning
can often seem magical (especially considering its powerful applications), but
at the end of the day it boils down to some algorithms leveraging statistics
quite nicely.&lt;/p&gt;

&lt;p&gt;Here, we looked at one specific example of machine learning: linear regression
with gradient descent. But tons of algorithms exist, like support vector
machines, logistic regression, and artificial neural networks, to name a few.
Each has its own appeals and drawbacks that make it more or less suited
to particular problems.&lt;/p&gt;

&lt;p&gt;Machine learning, and more broadly, artificial intelligence, can’t solve every
problem. But they are an extremely powerful way of tackling problems that
often seem impossible to accomplish through computation. We are still
discovering new applications and methods for machine learning, and I have no
doubt that it will have an even more dramatic impact on our daily lives in the
near future.&lt;/p&gt;</content>

      
      
      
      
      

      
        <author>
            <name>Brandon L. Morris</name>
          
          
        </author>
      

      
        <category term="ai" />
      
        <category term="tutorial" />
      

      

      
        <summary type="html">Computers are smart, but how can we teach them new things?</summary>
      

      
      
    </entry>
  
  
  
    <entry>
      
      <title type="html">Docker for Beginners</title>
      
      <link href="https://brandonlmorris.github.io/2016/10/13/docker-for-beginners/" rel="alternate" type="text/html" title="Docker for Beginners" />
      <published>2016-10-13T00:00:00+00:00</published>
      <updated>2016-10-13T00:00:00+00:00</updated>
      <id>https://brandonlmorris.github.io/2016/10/13/docker-for-beginners</id>
      <content type="html" xml:base="https://brandonlmorris.github.io/2016/10/13/docker-for-beginners/">&lt;p&gt;Docker and containerization are all the rage nowadays, but what are they, and
what makes them so technologically appealing?&lt;/p&gt;

&lt;p&gt;We should first consider the problems that they are trying to solve. Let’s say
that you have an application you want to deploy: some code that will have to
run on another machine, either your own server or in the mystical cloud. How
can you be sure that your app will work after you deploy it? Of course, you
have tests, and of course, they are passing before you even consider releasing,
so you know that it runs fine on your computer, but that’s not where we’re
deploying. On a different computer or in a different environment, there’s no
real way for us to ensure that &lt;em&gt;something&lt;/em&gt; won’t break &lt;em&gt;somewhere&lt;/em&gt; because
the configurations don’t perfectly match up.&lt;/p&gt;

&lt;p&gt;Now let’s consider a slightly different problem. Perhaps your application has
become incredibly popular. Congratulations. But the immense attention it’s
now receiving is causing a serious degradation in performance. How can you
scale to meet your users’ needs? You could buy a nicer server or upgrade your
cloud instance, but that’s pricey and can only get us so far. What we need is
an effective way to scale outward, or &lt;em&gt;horizontally&lt;/em&gt;.&lt;/p&gt;

&lt;h2 id=&quot;the-old-solution-virtual-machines&quot;&gt;The Old Solution: Virtual Machines&lt;/h2&gt;

&lt;p&gt;Virtual machines can be a viable solution to these problems. They serve as
full-fledged computers implemented solely in software. It’s like having a
computer inside your computer. They’re entirely self-contained, which solves
the consistency issue: instead of shipping code, you can ship a VM that has
your code on it. So long as you deploy with the same machine you test with,
you can rest assured that the two environments will remain consistent.&lt;/p&gt;

&lt;p&gt;Virtual machines can also aid in terms of scalability, though in a limited
fashion. If you design your application appropriately, you can scale by adding
new instances of your VM to a pool in production. When a user hits your
service, they will utilize only part of your total deployment, leading to a
distributed workload. Of course, developing an application for this kind of
architecture is no small feat, since distribution introduces problems of
consistency and fault-tolerance, to name a few.&lt;/p&gt;
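
&lt;p&gt;As a rough illustration of what a distributed workload means, the sketch
below hands incoming requests to a pool of identical instances in round-robin
order. It is written in Python only to keep it short; the instance names and
the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;route&lt;/code&gt; helper are hypothetical, and a real deployment would put a load
balancer in front of the pool rather than a loop like this.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;from collections import Counter
from itertools import cycle

# Hypothetical pool of identical VM instances running the same app
pool = ['vm-1', 'vm-2', 'vm-3']
next_instance = cycle(pool)

def route(request_id):
    # Hand each incoming request to the next instance in the pool
    return next(next_instance)

# Six requests end up spread evenly across the three instances
assignments = [route(i) for i in range(6)]
load = Counter(assignments)

print(assignments)
print(dict(load))
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;In this picture, scaling out amounts to adding another entry to the pool.&lt;/p&gt;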

&lt;p&gt;So why even bother with containerization when we already have virtual machines?
A large drawback to VMs is that they are bulky. A typical virtual machine will
be gigabytes in size, since it has to exhibit all the characteristics of a
legitimate computer. All of that virtualization creates additional overhead,
even with “bare-metal” VMs that interact more closely with the underlying
hardware.&lt;/p&gt;

&lt;p&gt;In addition, VMs aren’t very flexible. Each one has to be configured prior to
starting with the exact amount of physical hardware (CPUs, memory, etc.) that
it, and it alone, has access to. Virtual machines generally can’t share
resources, and the amount allotted to them typically can’t be modified at
runtime. So although a physical machine can run multiple VMs, it’s limited in
terms of its total resources.&lt;/p&gt;

&lt;h2 id=&quot;containers-the-lean-mean-virtual-machine&quot;&gt;Containers: The Lean, Mean, Virtual Machine&lt;/h2&gt;

&lt;p&gt;Containers leverage a lot of the benefits of virtual machines while avoiding
their detriments, and they do it in a very clever way. Recall that virtual
machines created a significant amount of overhead by recreating every aspect
of the physical computer in software. Containers skip this and leverage the
underlying kernel of the host, effectively &lt;em&gt;sharing&lt;/em&gt; the low-level code and
resources with the host and other containers. In addition, they also share
common binaries and packages. In reality, containers aren’t so much a
virtualized computer as they are a shim: the real work is keeping track of the
ways that a particular container differs from the actual machine it’s running
on. By only keeping up with the differences, containers are dramatically
smaller and faster.&lt;/p&gt;

&lt;p&gt;Here’s what I mean concretely: the ISO for Ubuntu 14.04 that I have downloaded
is 649 megabytes. If I download the same version as a Docker container image,
it’s only 188 megabytes.&lt;/p&gt;

&lt;p&gt;Similarly, starting an Ubuntu VM takes around 30 seconds on my machine
(MacBook Pro with SSD), even if it is headless. A container, however, will
take less than a second.&lt;/p&gt;

&lt;p&gt;But with all this sharing going on, doesn’t that negate one of our primary
reasons for choosing VMs: that we would have completely isolated environments?
Actually, no. Containers constantly maintain all their “differences”, the
things that make them distinct from the host and other containers. The effect
is that they are in fact completely isolated on a logical level. They may share
a file or a library under the hood, but it’s completely transparent to you, the
user.&lt;/p&gt;

&lt;p&gt;The fact that containers are still isolated while also being lightweight means
that we can use them similarly to the way we can use VMs with much less
overhead. We can be confident that our apps will operate the way we expect
them to because we can literally &lt;strong&gt;ship the environment with the app&lt;/strong&gt; all
bundled up in neat containers.&lt;/p&gt;

&lt;h2 id=&quot;containers-are-not-vms&quot;&gt;Containers are not VMs&lt;/h2&gt;

&lt;p&gt;It can be tempting, especially at the outset, to view containers simply as
virtual machines, particularly since the analogy is so elegant. However, the
two differ in some substantial ways, not only technologically but
functionally.&lt;/p&gt;

&lt;p&gt;First of all, &lt;strong&gt;containers should be ephemeral&lt;/strong&gt;. They should be logically
small and capable of quickly being started and stopped. They should &lt;strong&gt;not&lt;/strong&gt;
serve as a location for data that needs to be maintained over time (for that,
you need volumes).&lt;/p&gt;

&lt;p&gt;Additionally, &lt;strong&gt;containers should only run one main process&lt;/strong&gt;. This is
important. Since containers are lightweight, we should use them as such. Since
containers are ephemeral, we shouldn’t rely on them sticking around. By
breaking up our app into multiple containers, we can achieve greater modularity
and scalability, though not without some redesign.&lt;/p&gt;

&lt;h2 id=&quot;working-with-docker&quot;&gt;Working with Docker&lt;/h2&gt;

&lt;p&gt;Enough theory, let’s actually get our hands dirty.&lt;/p&gt;

&lt;p&gt;To utilize containers yourself, you’ll need a containerization engine. The
most popular one out there is Docker. Instructions for installing on your
platform can be found &lt;a href=&quot;https://docs.docker.com/engine/installation/&quot;&gt;here&lt;/a&gt;.&lt;/p&gt;

&lt;p&gt;If everything went smoothly, you should be able to run &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;docker info&lt;/code&gt; and get
some reasonable output.&lt;/p&gt;

&lt;p&gt;Once you have Docker running, the real fun begins. Let’s print a message
from the Docker whale using cowsay:&lt;/p&gt;

&lt;div class=&quot;language-bash highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;&lt;span class=&quot;nv&quot;&gt;$ &lt;/span&gt;docker run docker/whalesay cowsay &lt;span class=&quot;s2&quot;&gt;&quot;Hello, docker&quot;&lt;/span&gt;
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;The command should take a few seconds to run fully. If this is your first time
running this command, Docker has to pull (download) the image first. An
image in Docker is analogous to a snapshot of a VM. It serves as a template
from which we build all of our containers. However, each container from
this image won’t cause the image to be reconstructed, since the container will
simply keep track of how it differs from that base image. Concretely, this
means that no matter how many containers of an image we spin up, we only need
to have &lt;strong&gt;one&lt;/strong&gt; copy of the image on disk.&lt;/p&gt;
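
&lt;p&gt;One way to picture this is as a stack of dictionaries. The sketch below is
only a toy model in Python (Docker itself relies on layered filesystems, and
the file names here are invented): each container gets its own empty, writable
layer on top of a single shared image, so writes never touch the image itself.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;from collections import ChainMap

# One read-only image, stored once (file names are made up)
image = {'/etc/motd': 'hello', '/bin/cowsay': 'binary'}

# Each container is an empty writable layer stacked on the shared image
container_a = ChainMap({}, image)
container_b = ChainMap({}, image)

# Writing inside container A only changes A's own layer
container_a['/etc/motd'] = 'changed in A'

print(container_a['/etc/motd'])  # changed in A
print(container_b['/etc/motd'])  # hello
print(image['/etc/motd'])        # hello
print(container_a.maps[0])       # only the difference is stored
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;Both containers read through to the same image object, which is why a
second or third container costs almost nothing extra on disk.&lt;/p&gt;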

&lt;p&gt;Images are almost always layered, which is why you probably saw multiple lines
downloading after issuing the command. Even these layers can be reused by
Docker. So if in the future you use a different, but similar, image, the
download time (and disk space) will be decreased.&lt;/p&gt;
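
&lt;p&gt;The saving from shared layers can be sketched with a little set arithmetic.
The layer IDs and sizes below are invented for illustration, and this is a toy
model rather than how Docker tracks layers internally: only the layers that are
not already on disk need to be downloaded.&lt;/p&gt;

&lt;div class=&quot;language-python highlighter-rouge&quot;&gt;&lt;div class=&quot;highlight&quot;&gt;&lt;pre class=&quot;highlight&quot;&gt;&lt;code&gt;# Hypothetical layers (ID mapped to size in MB) for two similar images
whalesay = {'base-ubuntu': 180, 'cowsay': 5, 'whale-art': 1}
other_app = {'base-ubuntu': 180, 'python': 40, 'app-code': 2}

# Layers already on disk after pulling the first image
local_cache = set(whalesay)

# Only layers missing from the cache have to be downloaded
to_download = {layer: size for layer, size in other_app.items()
               if layer not in local_cache}

print(sorted(to_download))        # ['app-code', 'python']
print(sum(to_download.values()))  # 42
print(sum(other_app.values()))    # 222
&lt;/code&gt;&lt;/pre&gt;&lt;/div&gt;&lt;/div&gt;

&lt;p&gt;In this made-up case the second image needs 42 MB of new layers instead of
the full 222 MB, because the base layer is reused.&lt;/p&gt;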

&lt;p&gt;Docker hosts tons of images for lots of different applications, which can be
found at &lt;a href=&quot;https://hub.docker.com&quot;&gt;Docker Hub&lt;/a&gt;. Docker will pull from Docker Hub
automatically if it can’t find the image locally. You can download an image
with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;docker pull &amp;lt;&amp;lt;image name&amp;gt;&amp;gt;&lt;/code&gt; and search images with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;docker search
&amp;lt;&amp;lt;image name&amp;gt;&amp;gt;&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;Once the image downloads, the container will start. In our command, we
specified that the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;docker/whalesay&lt;/code&gt; image should execute the command
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;cowsay&lt;/code&gt;. You should see the following output:&lt;/p&gt;

&lt;p&gt;&lt;img src=&quot;https://brandonlmorris.github.io/images/whalesay.jpeg&quot; alt=&quot;whalesay&quot; /&gt;&lt;/p&gt;

&lt;p&gt;Once the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;cowsay&lt;/code&gt; command ends, the container will promptly stop. However, it
will stick around should you want to run it again. You can view your running
containers with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;docker ps&lt;/code&gt; and view &lt;em&gt;all&lt;/em&gt; containers with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;docker ps -a&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;Docker will automatically generate a silly name for your container, because
whimsy is an important aspect of software engineering. To delete the container,
you can run &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;docker rm &amp;lt;&amp;lt;container name/id&amp;gt;&amp;gt;&lt;/code&gt;.&lt;/p&gt;

&lt;h2 id=&quot;playing-in-a-sandbox&quot;&gt;Playing in a Sandbox&lt;/h2&gt;

&lt;p&gt;Let’s toy around with this some more. If you ran the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;docker/whalesay&lt;/code&gt;
container from the previous section, you should have the image saved on your
computer. You can verify this with &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;docker images&lt;/code&gt;.&lt;/p&gt;

&lt;p&gt;Earlier, we ran the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;cowsay&lt;/code&gt; command on this image, but there’s nothing
stopping us from running other commands. Try&lt;/p&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;$ docker run docker/whalesay date&lt;/code&gt;&lt;/p&gt;

&lt;p&gt;It should print out the current time. That’s nice, but what if we want to
keep the container up and run multiple commands? This is called attaching to
a container. To do this, we will need two things. We need to tell Docker
to give our container a tty interface, so we can issue commands, as well as to
read input from &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;stdin&lt;/code&gt; (our keyboard, in this case). Additionally, we will
need to execute a &lt;strong&gt;long-running process&lt;/strong&gt; that won’t immediately end and kill
our container. We can do all this with the following command:&lt;/p&gt;

&lt;p&gt;&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;$ docker run -it docker/whalesay bash&lt;/code&gt;&lt;/p&gt;

&lt;p&gt;Here, the &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;-it&lt;/code&gt; flags give us an “interactive” and “tty” run on our container.
We also execute &lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;bash&lt;/code&gt;, which will serve as our long-running process. After
executing this command, you should see a different terminal prompt, coming
from inside your container.&lt;/p&gt;

&lt;p&gt;Feel free to mess around and issue any commands you normally could on an
Ubuntu machine. But be warned: any changes you make live only inside that
container and are lost once it is removed. &lt;strong&gt;Do not store data inside a container&lt;/strong&gt;. Instead, you should use
&lt;a href=&quot;https://docs.docker.com/engine/tutorials/dockervolumes/&quot;&gt;volumes&lt;/a&gt; to export the data to a more persistent location.
When you’re done, you can exit the container by exiting the bash shell:
&lt;code class=&quot;language-plaintext highlighter-rouge&quot;&gt;&amp;lt;Ctrl-d&amp;gt;&lt;/code&gt;.&lt;/p&gt;

&lt;h2 id=&quot;go-forth-and-dockerize&quot;&gt;Go forth, and Dockerize&lt;/h2&gt;

&lt;p&gt;There’s a load more to say about Docker and containers in general. Some
interesting points that I didn’t get to in this post include (but are not
limited to):&lt;/p&gt;

&lt;ul&gt;
  &lt;li&gt;Volumes and persistent storage&lt;/li&gt;
  &lt;li&gt;Building your own images with Dockerfiles&lt;/li&gt;
  &lt;li&gt;Creating a full-stack app with containers&lt;/li&gt;
  &lt;li&gt;Networking to and between containers&lt;/li&gt;
&lt;/ul&gt;

&lt;h2 id=&quot;security-warning&quot;&gt;Security Warning&lt;/h2&gt;

&lt;p&gt;I would be remiss if I did not specifically clarify a security concern
regarding containers: &lt;strong&gt;containers are not a securely isolated solution&lt;/strong&gt;.
Although we spoke of the isolation that containers offer, they fundamentally
share the kernel with the host operating system, and are therefore not
secure in their own right. A common mitigation is to run Docker from within
a (single) virtual machine. Since the VM &lt;em&gt;is&lt;/em&gt; much more strongly isolated, the
host machine is far less exposed.&lt;/p&gt;</content>

      
      
      
      
      

      
        <author>
            <name>Brandon L. Morris</name>
          
          
        </author>
      

      
        <category term="docker" />
      
        <category term="tutorial" />
      

      

      
        <summary type="html">Docker and containerization are all the rage nowadays, but what are they, and what makes them so technologically appealing?</summary>
      

      
      
    </entry>
  
  
</feed>
