Neuromorphic Engineering Book

Learning with SNN

Chapter 13


Last updated 3 years ago


Read Chapter 13

Python / Nengo demonstration

Imports:

import nengo
import numpy as np
import matplotlib.pyplot as plt
from nengo.processes import WhiteSignal

from urllib.request import urlretrieve
import tensorflow as tf
import nengo_dl

PES learning: communication channel

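model = nengo.Network(label="Learn a Communication Channel")
with model:
    # band-limited white noise input: 10 s period, 5 Hz cutoff, RMS 0.5
    stim = nengo.Node(output=WhiteSignal(10, high=5, rms=0.5))

    pre = nengo.Ensemble(60, dimensions=1)
    post = nengo.Ensemble(60, dimensions=1)

    nengo.Connection(stim, pre)
    # start from a random (wrong) function; PES will adjust its decoders
    conn = nengo.Connection(pre, post, function=lambda x: np.random.random())

    inp_p = nengo.Probe(stim)
    pre_p = nengo.Probe(pre, synapse=0.01)
    post_p = nengo.Probe(post, synapse=0.01)

    # error = post - pre (PES expects "actual minus target")
    error = nengo.Ensemble(60, dimensions=1)
    error_p = nengo.Probe(error, synapse=0.03)
    nengo.Connection(post, error)
    nengo.Connection(pre, error, transform=-1)

    # learn a simple communication channel: post should come to follow pre
    conn.learning_rule_type = nengo.PES()
    learn_conn = nengo.Connection(error, conn.learning_rule)

sim = nengo.Simulator(model)
sim.run(10.0)
t = sim.trange()

plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[inp_p].T[0], c='k', label='Input')
plt.plot(t, sim.data[pre_p].T[0], c='b', label='Pre')
plt.plot(t, sim.data[post_p].T[0], c='r', label='Post')
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend(loc=1)

plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[error_p].T[0], c='k', label='Error')
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend(loc='best')
plt.show()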

Result:
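For intuition about what `nengo.PES()` does to the decoders of `conn`, here is a minimal rate-based sketch in plain NumPy; it is not Nengo's implementation. It assumes 60 rectified-linear neurons with random encoders, gains, and biases as a hypothetical stand-in for `pre`, starts from zero decoders instead of a random function, and applies the PES update Δd = -(κ/n)·e·a, where e = (decoded output) - (input) plays the role of the `error` ensemble and a holds the presynaptic rates. The learning rate κ is an assumed value tuned for this toy model, with the timestep folded into it.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-in for the `pre` ensemble: 60 rectified-linear rate
# neurons encoding a scalar x with random encoders (+1/-1), gains and biases.
n_neurons = 60
encoders = rng.choice([-1.0, 1.0], size=n_neurons)
gains = rng.uniform(50.0, 150.0, size=n_neurons)
biases = rng.uniform(-100.0, 100.0, size=n_neurons)

def rates(x):
    """Firing rates (Hz) of all neurons for a scalar input x."""
    return np.maximum(gains * encoders * x + biases, 0.0)

def rmse(dec, xs):
    """Root-mean-square error of the decoded estimate of x over inputs xs."""
    return np.sqrt(np.mean([(dec @ rates(x) - x) ** 2 for x in xs]))

decoders = np.zeros(n_neurons)  # start with a connection that transmits nothing
kappa = 4.8e-6                  # assumed learning rate for this toy (dt folded in)
test_xs = np.linspace(-1.0, 1.0, 21)
initial_error = rmse(decoders, test_xs)

for _ in range(20000):
    x = rng.uniform(-1.0, 1.0)        # random input sample
    a = rates(x)                      # presynaptic activities
    err = decoders @ a - x            # error = actual - target (post - pre)
    decoders -= kappa / n_neurons * err * a  # PES: delta d = -(kappa / n) * err * a

final_error = rmse(decoders, test_xs)
print(f"RMSE before learning: {initial_error:.3f}, after: {final_error:.3f}")
```

Each update is scaled by the presynaptic activity, so only neurons that respond to the current input have their decoders changed.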

Pavlovian conditioning

D = 3        # number of US/CS pairings (dimensions)
N = D * 100  # neurons per D-dimensional ensemble

def us_input(t):
    # cycle through the three unconditioned stimuli (US), one per second
    t = t % 3
    if 0.9 < t < 1: return [1, 0, 0]
    if 1.9 < t < 2: return [0, 1, 0]
    if 2.9 < t < 3: return [0, 0, 1]
    return [0, 0, 0]

def cs_input(t):
    # cycle through the three conditioned stimuli (CS); each starts 0.2 s before its US
    t = t % 3
    if 0.7 < t < 1: return [0.7, 0, 0.5]
    if 1.7 < t < 2: return [0.6, 0.7, 0.8]
    if 2.7 < t < 3: return [0, 1, 0]
    return [0, 0, 0]

def stop_learning(t):
    # 0 = learning on (between 2 s and 8 s), 1 = learning off
    if 2 < t < 8:
        return 0
    return 1

model = nengo.Network(label="Classical Conditioning")
with model:
    # the US drives a hard-wired unconditioned response (UR)
    us_stim = nengo.Node(us_input)
    us_stim_p = nengo.Probe(us_stim)

    us = nengo.Ensemble(N, D)
    ur = nengo.Ensemble(N, D)
    us_p = nengo.Probe(us, synapse=0.1)
    ur_p = nengo.Probe(ur, synapse=0.1)
    nengo.Connection(us, ur)
    nengo.Connection(us_stim, us)

    # the CS plus a slowly filtered (lingering) copy of it, stored in cs[D:]
    cs_stim = nengo.Node(cs_input)
    cs_stim_p = nengo.Probe(cs_stim)

    cs = nengo.Ensemble(N * 2, D * 2)
    cr = nengo.Ensemble(N, D)
    cs_p = nengo.Probe(cs, synapse=0.1)
    cr_p = nengo.Probe(cr, synapse=0.1)
    nengo.Connection(cs_stim, cs[:D])
    nengo.Connection(cs[:D], cs[D:], synapse=0.2)

    # learned CS -> CR mapping, initially outputting zeros
    learn_conn = nengo.Connection(cs, cr, function=lambda x: [0] * D)
    learn_conn.learning_rule_type = nengo.PES(learning_rate=3e-4)

    # error = CR - UR drives the PES rule
    error = nengo.Ensemble(N, D)
    error_p = nengo.Probe(error, synapse=0.01)
    nengo.Connection(error, learn_conn.learning_rule)
    nengo.Connection(ur, error, transform=-1)
    nengo.Connection(cr, error, transform=1, synapse=0.1)

    # strongly inhibit the error neurons (silencing learning) while stop_learning is 1
    stop_learn = nengo.Node(stop_learning)
    stop_learn_p = nengo.Probe(stop_learn)
    nengo.Connection(stop_learn, error.neurons, transform=-10 * np.ones((N, 1)))

sim = nengo.Simulator(model)
sim.run(15)

t=sim.trange()

plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[us_stim_p].T[0], c='blue',  label='US #1')
plt.plot(t, sim.data[us_stim_p].T[1], c='red',   label='US #2')
plt.plot(t, sim.data[us_stim_p].T[2], c='black', label='US #3')
plt.plot(t, sim.data[ur_p].T[0], c='blue',  label='UR #1', linestyle=":", linewidth=3)
plt.plot(t, sim.data[ur_p].T[1], c='red',   label='UR #2', linestyle=":", linewidth=3)
plt.plot(t, sim.data[ur_p].T[2], c='black', label='UR #3', linestyle=":", linewidth=3)
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend()
plt.show()

plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[cs_stim_p].T[0], c='blue',  label='CS #1')
plt.plot(t, sim.data[cs_stim_p].T[1], c='red',   label='CS #2')
plt.plot(t, sim.data[cs_stim_p].T[2], c='black', label='CS #3')
plt.plot(t, sim.data[ur_p].T[0], c='blue',  label='UR #1', linestyle=":", linewidth=3)
plt.plot(t, sim.data[ur_p].T[1], c='red',   label='UR #2', linestyle=":", linewidth=3)
plt.plot(t, sim.data[ur_p].T[2], c='black', label='UR #3', linestyle=":", linewidth=3)
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend()
plt.show()

plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[cs_stim_p].T[0], c='blue',  label='CS #1')
plt.plot(t, sim.data[cs_stim_p].T[1], c='red',   label='CS #2')
plt.plot(t, sim.data[cs_stim_p].T[2], c='black', label='CS #3')
plt.plot(t, sim.data[cr_p].T[0],      c='blue',  label='CR #1', linestyle=":", linewidth=3)
plt.plot(t, sim.data[cr_p].T[1],      c='red',   label='CR #2', linestyle=":", linewidth=3)
plt.plot(t, sim.data[cr_p].T[2],      c='black', label='CR #3', linestyle=":", linewidth=3)
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend()
plt.show()

plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[error_p].T[0], c='black', label='error')
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend()
plt.show()

plt.figure(figsize=(12, 2))
plt.plot(t, sim.data[stop_learn_p].T[0], c='black',  label='stop')
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend()
plt.show()

Results:
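The network above uses PES to learn a mapping from the CS representation (the CS plus its lingering copy in `cs[D:]`) to the CR, driven by the error CR - UR, and only while `stop_learning` leaves the error ensemble uninhibited (2 s < t < 8 s). The same logic can be sketched without neurons as a delta rule on a 3×3 weight matrix. This is a simplified, hypothetical rate-level analogue: it ignores the lingering CS copy, synaptic filtering, and timing, treats each CS/US pairing as one "trial", and uses an assumed learning rate and a hypothetical trial-based learning window.

```python
import numpy as np

# CS patterns and the US each one predicts, copied from the stimulus functions above.
cs_patterns = np.array([[0.7, 0.0, 0.5],
                        [0.6, 0.7, 0.8],
                        [0.0, 1.0, 0.0]])
us_patterns = np.eye(3)

W = np.zeros((3, 3))   # CS -> CR weights: initially no conditioned response
learning_rate = 0.5    # assumed value for this toy model

def learning_enabled(trial):
    """Hypothetical gate mirroring stop_learning(): learn only inside a window."""
    return 30 <= trial < 5400

n_updates = 0
for trial in range(6000):
    k = trial % 3                   # cycle through the three CS/US pairings
    cs, ur = cs_patterns[k], us_patterns[k]
    cr = W @ cs                     # conditioned response to the CS
    err = cr - ur                   # same sign convention as the error ensemble
    if learning_enabled(trial):
        W -= learning_rate * np.outer(err, cs)  # delta (PES-like) update
        n_updates += 1

# Column k is the CR evoked by CS k alone; it should approach US k.
print(np.round(W @ cs_patterns.T, 3))
```

Because the three CS patterns are linearly independent, a weight matrix that maps each CS exactly onto its US exists, and the updates converge to it; outside the window the weights stay frozen, as learning does in the model when the error neurons are inhibited.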

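Differentiable LIF tuning curve

Rm      = 1000     # Membrane resistance      [MOhm]
Cm      = 5e-6     # Membrane capacitance     [uF]

t_ref = 10e-3      # refractory period [s]
tau = Rm * Cm      # membrane time constant: MOhm * uF = s, so 5e-3 s (5 ms)
v_th = 1           # firing threshold (normalized units)
la = 0.05          # softplus smoothing parameter lambda

# rho(x) = max(x, 0): hard rectification, not differentiable at x = 0
rho  = lambda x: np.maximum(x, 0)
# rho(x) = lambda * log(1 + exp(x / lambda)): smooth softplus (computed stably)
rho2 = lambda x: la * np.logaddexp(0, x / la)

# LIF steady-state firing rate for input current i
a3 = lambda i: 1 / (t_ref + tau * np.log(1 + v_th / rho(i - v_th)))
a4 = lambda i: 1 / (t_ref + tau * np.log(1 + v_th / rho2(i - v_th)))

I = np.linspace(0, 3, 10000)
with np.errstate(divide='ignore'):  # below threshold v_th / 0 = inf, so the rate is 0
    A3 = a3(I)
A4 = a4(I)

plt.plot(I, A3, linewidth=3, color='black', label=r'$\rho=\max(x,0)$')
plt.plot(I, A4, linewidth=3, color='red', label=r'$\rho=\lambda\log(1+e^{x/\lambda})$')
plt.ylim(0, 100)
plt.xlim(0, 3)
plt.legend()
plt.ylabel("Firing rate (Hz)")
plt.xlabel("Input current")
plt.show()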

Result:
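The reason for smoothing ρ is the gradient that training needs. A minimal sketch in plain NumPy, reusing the parameter values above, computes d(rate)/dI analytically with the chain rule for both choices of ρ. With ρ = max(x, 0) the gradient is zero below threshold and grows without bound as I approaches v_th from above; with the softplus it is finite and smooth everywhere, and matches the hard version well above threshold. The function names are illustrative.

```python
import numpy as np

# Same parameter values as the tuning-curve code above (tau = Rm * Cm = 5 ms).
t_ref, tau, v_th, la = 10e-3, 5e-3, 1.0, 0.05

def rate(x):
    """LIF rate 1 / (t_ref + tau * log(1 + v_th / x)) for drive x >= 0 (0 at x == 0)."""
    x = np.asarray(x, dtype=float)
    with np.errstate(divide="ignore"):
        return 1.0 / (t_ref + tau * np.log1p(v_th / x))

def drate_dx(x):
    """Analytic derivative: tau * v_th / (x * (x + v_th)) * rate(x)**2 for x > 0, else 0."""
    x = np.asarray(x, dtype=float)
    a = rate(x)
    with np.errstate(divide="ignore", invalid="ignore"):
        g = tau * v_th / (x * (x + v_th)) * a**2
    return np.where(x > 0, g, 0.0)

def grad_hard(i):
    """d(rate)/di with rho(x) = max(x, 0); d(rho)/dx is a step function."""
    x = i - v_th
    return drate_dx(np.maximum(x, 0.0)) * (x > 0)

def grad_soft(i):
    """d(rate)/di with rho(x) = la * log(1 + exp(x / la)); d(rho)/dx = sigmoid(x / la)."""
    z = (i - v_th) / la
    return drate_dx(la * np.logaddexp(0.0, z)) / (1.0 + np.exp(-z))

for i in (0.9, 1.0, 1.0 + 1e-6, 1.001, 1.5, 3.0):
    print(f"I = {i:<10.7g} hard: {float(grad_hard(i)):12.2f}   soft: {float(grad_soft(i)):8.2f}")
```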

MNIST classification

This demonstration was adapted from the NengoDL documentation example "Optimizing a spiking neural network" (NengoDL 3.4.5.dev0 docs).

Data loading:

(train_images, train_labels), (
    test_images,
    test_labels,
) = tf.keras.datasets.mnist.load_data()

# flatten images
train_images = train_images.reshape((train_images.shape[0], -1))
test_images = test_images.reshape((test_images.shape[0], -1))

Model definition:

with nengo.Network(seed=0) as net:

    net.config[nengo.Ensemble].max_rates = nengo.dists.Choice([100])
    net.config[nengo.Ensemble].intercepts = nengo.dists.Choice([0])
    net.config[nengo.Connection].synapse = None
    neuron_type = nengo.LIF(amplitude=0.01)

    # this is an optimization to improve the training speed,
    # since we won't require stateful behaviour in this example
    nengo_dl.configure_settings(stateful=False)

    # the input node 
    inp = nengo.Node(np.zeros(28 * 28))

    # add the first convolutional layer
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=32, kernel_size=3))(
        inp, shape_in=(28, 28, 1)
    )
    x = nengo_dl.Layer(neuron_type)(x)

    # add the second convolutional layer
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=64, strides=2, kernel_size=3))(
        x, shape_in=(26, 26, 32)
    )
    x = nengo_dl.Layer(neuron_type)(x)

    # add the third convolutional layer
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=128, strides=2, kernel_size=3))(
        x, shape_in=(12, 12, 64)
    )
    x = nengo_dl.Layer(neuron_type)(x)

    # linear readout
    out = nengo_dl.Layer(tf.keras.layers.Dense(units=10))(x)

    out_p = nengo.Probe(out, label="out_p")
    out_p_filt = nengo.Probe(out, synapse=0.1, label="out_p_filt")
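The hard-coded `shape_in` values must match the output of the previous layer. Keras `Conv2D` defaults to `padding="valid"`, so each layer shrinks the image to floor((n - k)/s) + 1 pixels per side. A small check in plain Python (the helper name is hypothetical, no NengoDL needed) reproduces those shapes and the size of the vector passed to the dense readout.

```python
def conv_output_size(n, kernel_size, stride=1):
    """Side length after an unpadded ('valid') convolution."""
    return (n - kernel_size) // stride + 1

# (filters, kernel_size, strides) of the three Conv2D layers above
conv_layers = [(32, 3, 1), (64, 3, 2), (128, 3, 2)]

side = 28
shapes = []
for filters, kernel_size, stride in conv_layers:
    side = conv_output_size(side, kernel_size, stride)
    shapes.append((side, side, filters))

print(shapes)                            # [(26, 26, 32), (12, 12, 64), (5, 5, 128)]
print(side * side * conv_layers[-1][0])  # 3200 inputs to the Dense readout
```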

Network building:

minibatch_size = 200
sim = nengo_dl.Simulator(net, minibatch_size=minibatch_size)

Preprocessing:

# add single timestep to training data
train_images = train_images[:, None, :]
train_labels = train_labels[:, None, None]

# when testing our network with spiking neurons we will need to run it
# over time, so we repeat the input/target data for a number of
# timesteps.
n_steps = 30
test_images = np.tile(test_images[:, None, :], (1, n_steps, 1))
test_labels = np.tile(test_labels[:, None, None], (1, n_steps, 1))
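NengoDL expects inputs and targets shaped (batch, timesteps, dimensions). A toy sketch with zero-filled arrays (hypothetical stand-ins for the MNIST data, so the real arrays are not overwritten) makes the resulting shapes explicit: training uses a single timestep, while testing with spiking neurons repeats each static image and its label for `n_steps` timesteps.

```python
import numpy as np

# Zero-filled stand-ins with the MNIST layout (flattened 28 x 28 = 784 pixels).
toy_train_images = np.zeros((4, 28 * 28))
toy_train_labels = np.arange(4)
toy_test_images = np.zeros((3, 28 * 28))
toy_test_labels = np.arange(3)

# Training data: one timestep -> (batch, 1, 784) inputs and (batch, 1, 1) targets.
train_x = toy_train_images[:, None, :]
train_y = toy_train_labels[:, None, None]

# Test data: the same image and label repeated for n_steps timesteps.
n_steps = 30
test_x = np.tile(toy_test_images[:, None, :], (1, n_steps, 1))
test_y = np.tile(toy_test_labels[:, None, None], (1, n_steps, 1))

print(train_x.shape, train_y.shape)  # (4, 1, 784) (4, 1, 1)
print(test_x.shape, test_y.shape)    # (3, 30, 784) (3, 30, 1)
```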

SNN compilation before training:

def classification_accuracy(y_true, y_pred):
    return tf.metrics.sparse_categorical_accuracy(y_true[:, -1], y_pred[:, -1])

# note that we use `out_p_filt` when testing (to reduce the spike noise)
sim.compile(loss={out_p_filt: classification_accuracy})
print(
    "Accuracy before training:",
    sim.evaluate(test_images, {out_p_filt: test_labels}, verbose=0)["loss"],
)

# PRINTED: Accuracy before training: 0.0934000015258789
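`classification_accuracy` looks only at the last timestep (`[:, -1]`) of targets and predictions. It is passed to `sim.compile` under `loss` so that `sim.evaluate` reports it under the key `"loss"`; the number is an accuracy, so higher is better. As a sketch of what that reported value amounts to, here is a NumPy equivalent under the shape assumptions in the docstring (the function name and the example values are made up):

```python
import numpy as np

def last_step_accuracy(y_true, y_pred):
    """Fraction of examples whose last-timestep argmax matches the label.

    y_true: integer labels, shape (batch, timesteps, 1)
    y_pred: class scores,   shape (batch, timesteps, n_classes)
    """
    predicted = np.argmax(y_pred[:, -1], axis=-1)
    return np.mean(predicted == y_true[:, -1, 0])

# Two examples, three timesteps, three classes.
y_true = np.array([[[2], [2], [2]],
                   [[0], [0], [0]]])
y_pred = np.array([
    [[0.9, 0.1, 0.0], [0.2, 0.3, 0.5], [0.1, 0.2, 0.7]],  # settles on class 2: correct
    [[0.1, 0.8, 0.1], [0.3, 0.6, 0.1], [0.4, 0.5, 0.1]],  # still class 1 at the end: wrong
])
print(last_step_accuracy(y_true, y_pred))  # 0.5
```

Only the final timestep counts, so an early wrong guess (example 1) is not penalised as long as the output has settled correctly by the end.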

Training:

do_training = False
if do_training:
    # run training
    sim.compile(
        optimizer=tf.optimizers.RMSprop(0.001),
        loss={out_p: tf.losses.SparseCategoricalCrossentropy(from_logits=True)},
    )
    sim.fit(train_images, {out_p: train_labels}, epochs=10)

    # save the parameters to file
    sim.save_params("./mnist_params")
else:
    # download pretrained weights
    urlretrieve(
        "https://drive.google.com/uc?export=download&"
        "id=1l5aivQljFoXzPP5JVccdFXbOYRv3BCJR",
        "mnist_params.npz",
    )

    # load parameters
    sim.load_params("./mnist_params")

Evaluating after training:

sim.compile(loss={out_p_filt: classification_accuracy})
print(
    "Accuracy after training:",
    sim.evaluate(test_images, {out_p_filt: test_labels}, verbose=0)["loss"],
)

# PRINTED: Accuracy after training: 0.9869999885559082
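Both evaluations read `out_p_filt`, the output probed through a lowpass synapse with a 0.1 s time constant, to reduce the spike noise. The effect can be seen on a toy regular spike train. This sketch assumes Nengo's default timestep of 1 ms and spikes of area `amplitude` (height `amplitude / dt`), so a neuron firing at 100 Hz with `amplitude=0.01` carries a mean value of 1.0. The filter is a standard first-order exponential lowpass; Nengo's own discretization may differ in detail.

```python
import numpy as np

dt = 0.001         # simulation timestep (s); Nengo's default
tau_syn = 0.1      # synaptic time constant (s), as in out_p_filt
amplitude = 0.01   # spike area, as in nengo.LIF(amplitude=0.01)
rate_hz = 100.0    # firing rate, matching max_rates = 100

n = 1000                                   # 1 s of simulated time
period = int(round(1.0 / (rate_hz * dt)))  # 10 timesteps between spikes
spikes = np.zeros(n)
spikes[::period] = amplitude / dt          # each spike is a pulse of height amplitude/dt

# First-order lowpass: y[k] = a * y[k-1] + (1 - a) * x[k], with a = exp(-dt / tau)
a = np.exp(-dt / tau_syn)
filtered = np.empty(n)
y = 0.0
for k in range(n):
    y = a * y + (1.0 - a) * spikes[k]
    filtered[k] = y

tail = filtered[-200:]                     # last 0.2 s, after the filter has settled
print(f"raw:      mean {spikes.mean():.3f}, std {spikes.std():.3f}")
print(f"filtered: mean {tail.mean():.3f}, std {tail.std():.3f}")
```

The filter keeps the mean (the rate-coded value) while shrinking the fluctuations by roughly two orders of magnitude, at the cost of a settling time on the order of `tau_syn`.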

Plotting:

data = sim.predict(test_images[:minibatch_size])

Results:
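`sim.predict` returns a dictionary mapping each probe to an array of shape (batch, timesteps, outputs); here `data[out_p_filt]` has shape (200, 30, 10). One way to inspect how the decision develops over time is to turn each timestep's readout into class probabilities with a softmax. The sketch below does this for a made-up readout (hypothetical numbers standing in for one row of `data[out_p_filt]`, not network output).

```python
import numpy as np

def softmax(z, axis=-1):
    """Numerically stable softmax along `axis`."""
    z = z - np.max(z, axis=axis, keepdims=True)
    e = np.exp(z)
    return e / np.sum(e, axis=axis, keepdims=True)

# Made-up filtered readout for one test image (30 timesteps x 10 classes)
# that ramps up towards class 7 as the filtered spikes accumulate.
n_timesteps, n_classes = 30, 10
readout = np.zeros((n_timesteps, n_classes))
readout[:, 7] = np.linspace(0.0, 8.0, n_timesteps)

probs = softmax(readout)               # class probabilities at every timestep
predicted = int(np.argmax(probs[-1]))  # decision at the final timestep
print(predicted, round(float(probs[0, 7]), 3), round(float(probs[-1, 7]), 3))
```

At the first timestep the readout is flat, so every class has probability 0.1; by the last timestep the probability mass has concentrated on class 7.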
