Learning with SNN

Chapter 13

Read Chapter 13

Python / Nengo demonstration

Imports:

import nengo
import numpy as np
import matplotlib.pyplot as plt
from nengo.processes import WhiteSignal

from urllib.request import urlretrieve
import tensorflow as tf
import nengo_dl

PES learning: communication channel

# Network in which the pre -> post connection is adapted online with the PES
# rule, driven by an error signal computed as (post - pre).
model = nengo.Network(label='Learn a Communication Channel')
with model:
    # White-noise input signal feeding the `pre` population.
    stim = nengo.Node(output=WhiteSignal(10, high=5, rms=0.5))

    pre = nengo.Ensemble(n_neurons=60, dimensions=1)
    post = nengo.Ensemble(n_neurons=60, dimensions=1)

    nengo.Connection(stim, pre)
    # The connection is initialised to compute a random value, so any
    # agreement between `pre` and `post` has to come from learning.
    # The PES learning rule is attached directly at construction time.
    conn = nengo.Connection(
        pre,
        post,
        function=lambda x: np.random.random(),
        learning_rule_type=nengo.PES(),
    )

    # Probes: raw input, plus filtered decoded values of both populations.
    inp_p = nengo.Probe(stim)
    pre_p = nengo.Probe(pre, synapse=0.01)
    post_p = nengo.Probe(post, synapse=0.01)

    # Population representing the error signal.
    error = nengo.Ensemble(n_neurons=60, dimensions=1)
    error_p = nengo.Probe(error, synapse=0.03)

    # error = post - pre, i.e. the target is a plain communication channel.
    nengo.Connection(post, error)
    nengo.Connection(pre, error, transform=-1)

    # Feed the error into the learning rule of the pre -> post connection.
    learn_conn = nengo.Connection(error, conn.learning_rule)

# Run the model for 10 simulated seconds. The context manager closes the
# simulator when the run is finished; the recorded probe data in `sim.data`
# remains available afterwards.
with nengo.Simulator(model) as sim:
    sim.run(10.0)

# Time stamps of the recorded samples, used as the x-axis of the plots.
t = sim.trange()

import matplotlib.pyplot as plt

# Figure 1: the input signal together with the decoded pre/post values.
plt.figure(figsize=(12, 4))
for probe, colour, name in (
    (inp_p, 'k', 'Input'),
    (pre_p, 'b', 'Pre'),
    (post_p, 'r', 'Post'),
):
    # Each probe is one-dimensional, so take its only column.
    plt.plot(t, sim.data[probe][:, 0], c=colour, label=name)
plt.ylabel("Value")
plt.legend(loc=1)

# Figure 2: the decoded error signal over time.
plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[error_p][:, 0], c='k', label='Error')
plt.xlabel("Time (sec)")
plt.ylabel("Value")
plt.legend(loc='best')

Result:

Pavlovian conditioning

# Dimensionality of the stimulus/response ensembles: the stimulus functions
# below return 3-element vectors.
D = 3
# Neuron count passed to the D-dimensional ensembles (100 neurons per dimension).
N = D*100

def us_stim(t):
    """Return the unconditioned-stimulus vector at time ``t``.

    The schedule repeats every 3 time units. During the open interval
    (0.9, 1), (1.9, 2) or (2.9, 3) of each cycle one of the three US
    channels is set to 1; at every other time the output is all zeros.
    """
    phase = t % 3
    # (window start, window end, output vector) for each US pulse
    pulses = (
        (0.9, 1, [1, 0, 0]),
        (1.9, 2, [0, 1, 0]),
        (2.9, 3, [0, 0, 1]),
    )
    for start, stop, vector in pulses:
        if start < phase < stop:
            return vector
    return [0, 0, 0]

def cs_stim(t):
    """Return the conditioned-stimulus vector at time ``t``.

    The schedule repeats every 3 time units. Each CS is presented during
    the open interval (0.7, 1), (1.7, 2) or (2.7, 3) of the cycle; outside
    those windows the output is all zeros.
    """
    phase = t % 3
    # (window start, window end, output vector) for each CS pattern
    patterns = (
        (0.7, 1, [0.7, 0,   0.5]),
        (1.7, 2, [0.6, 0.7, 0.8]),
        (2.7, 3, [0, 1, 0]),
    )
    for start, stop, vector in patterns:
        if start < phase < stop:
            return vector
    return [0, 0, 0]

def stop_learning(t):
    """Return the learning gate signal at time ``t``.

    The output is 0 strictly between t=2 and t=8 and 1 at all other times
    (including exactly t=2 and t=8).
    """
    return 0 if 2 < t < 8 else 1

# Classical (Pavlovian) conditioning network: the cs -> cr connection is
# trained with PES so that the conditioned response (CR) comes to match the
# unconditioned response (UR).
model = nengo.Network(label="Classical Conditioning")
with model:

    # Unconditioned stimulus pathway: us_stim() -> us -> ur.
    # The nodes get their own names so that the stimulus *functions*
    # `us_stim` / `cs_stim` are not overwritten by Node objects; otherwise
    # running this block a second time would pass a Node to nengo.Node.
    us_stim_node = nengo.Node(us_stim)
    us_stim_p = nengo.Probe(us_stim_node)

    us = nengo.Ensemble(N, D)
    ur = nengo.Ensemble(N, D)
    us_p = nengo.Probe(us, synapse=0.1)
    ur_p = nengo.Probe(ur, synapse=0.1)
    nengo.Connection(us, ur)
    nengo.Connection(us_stim_node, us[:D])

    # Conditioned stimulus pathway. `cs` has 2*D dimensions: the first D
    # receive the current CS, the last D receive a copy of the first D
    # passed through a slower synapse (0.2).
    cs_stim_node = nengo.Node(cs_stim)
    cs_stim_p = nengo.Probe(cs_stim_node)

    cs = nengo.Ensemble(N*2, D*2)
    cr = nengo.Ensemble(N, D)
    cs_p = nengo.Probe(cs, synapse=0.1)
    cr_p = nengo.Probe(cr, synapse=0.1)
    nengo.Connection(cs_stim_node, cs[:D])
    nengo.Connection(cs[:D], cs[D:], synapse=0.2)

    # Learned connection: initialised to output zeros, adapted by PES.
    learn_conn = nengo.Connection(cs, cr, function=lambda x: [0]*D)
    learn_conn.learning_rule_type = nengo.PES(learning_rate=3e-4)

    # Error signal fed to the learning rule: error = cr - ur.
    error   = nengo.Ensemble(N, D)
    error_p = nengo.Probe(error, synapse=0.01)
    nengo.Connection(error, learn_conn.learning_rule)
    nengo.Connection(ur, error, transform=-1)
    nengo.Connection(cr, error, transform=1, synapse=0.1)

    # Gate on learning: whenever stop_learning() outputs 1, every neuron of
    # the error population receives an input of -10 (strong inhibition
    # intended to silence the error signal).
    stop_learn = nengo.Node(stop_learning)
    stop_learn_p = nengo.Probe(stop_learn)
    nengo.Connection(stop_learn, error.neurons, transform=-10*np.ones((N, 1)))
    
# Run the conditioning model for 15 simulated seconds. The context manager
# closes the simulator once the run is done; the recorded probe data in
# `sim.data` remains available afterwards.
with nengo.Simulator(model) as sim:
    sim.run(15)

# Time stamps of the recorded samples, used as the x-axis of the plots.
t = sim.trange()

# Unconditioned stimuli (solid lines) and unconditioned responses (dotted),
# one colour per channel.
plt.figure(figsize=(12, 4))
for k, colour in enumerate(('blue', 'red', 'black')):
    plt.plot(t, sim.data[us_stim_p][:, k], c=colour, label=f'US #{k + 1}')
for k, colour in enumerate(('blue', 'red', 'black')):
    plt.plot(t, sim.data[ur_p][:, k], c=colour, label=f'UR #{k + 1}',
             linestyle=":", linewidth=3)
plt.xlabel("Time (sec)")
plt.ylabel("Value")
plt.legend()
plt.show()

# Conditioned stimuli (solid lines) and unconditioned responses (dotted),
# one colour per channel.
plt.figure(figsize=(12, 4))
for k, colour in enumerate(('blue', 'red', 'black')):
    plt.plot(t, sim.data[cs_stim_p][:, k], c=colour, label=f'CS #{k + 1}')
for k, colour in enumerate(('blue', 'red', 'black')):
    plt.plot(t, sim.data[ur_p][:, k], c=colour, label=f'UR #{k + 1}',
             linestyle=":", linewidth=3)
plt.xlabel("Time (sec)")
plt.ylabel("Value")
plt.legend()
plt.show()

# Conditioned stimuli (solid lines) and the learned conditioned responses
# (dotted), one colour per channel.
plt.figure(figsize=(12, 4))
for k, colour in enumerate(('blue', 'red', 'black')):
    plt.plot(t, sim.data[cs_stim_p][:, k], c=colour, label=f'CS #{k + 1}')
for k, colour in enumerate(('blue', 'red', 'black')):
    plt.plot(t, sim.data[cr_p][:, k], c=colour, label=f'CR #{k + 1}',
             linestyle=":", linewidth=3)
plt.xlabel("Time (sec)")
plt.ylabel("Value")
plt.legend()
plt.show()

# Two single-trace figures: the first component of the error signal, then
# the learning gate produced by the stop node (drawn in a shorter figure).
for probe, trace_label, height in (
    (error_p, 'error', 4),
    (stop_learn_p, 'stop', 2),
):
    plt.figure(figsize=(12, height))
    plt.plot(t, sim.data[probe][:, 0], c='black', label=trace_label)
    plt.xlabel("Time (sec)")
    plt.ylabel("Value")
    plt.legend()
    plt.show()

Results:

Differentiable LIF tuning curve

# LIF rate equation evaluated with two rectifiers: the hard max(x, 0) and a
# smooth softplus approximation that is differentiable at the threshold.
Rm      = 1000     # Membrane Resistance      [kOhm]
Cm      = 5e-6     # Capacitance              [uF]
# NOTE(review): tau is added directly to t_ref in the rate equation below, so
# Rm*Cm must be in the same time unit as t_ref; confirm the unit labels above.

t_ref = 10e-3      # refractory term of the rate equation
tau = Rm * Cm      # membrane time constant (RC product)
v_th = 1           # threshold subtracted from the input current
I0 = 1             # not used in this section
R=1000             # not used in this section
la = 0.05          # smoothing parameter (lambda) of the softplus rectifier


def rho(x):
    """Hard rectifier: element-wise ``max(x, 0)``.

    ``np.maximum`` is required here. ``np.max(x, 0)`` would interpret the 0
    as an *axis* and return ``x`` unchanged, leaving negative values in place.
    """
    return np.maximum(x, 0)


def rho2(x):
    """Softplus rectifier ``la * log(1 + exp(x / la))``.

    Written with ``np.logaddexp`` so that large ``x / la`` does not overflow.
    """
    return la * np.logaddexp(0, x / la)


def a3(i):
    """LIF firing rate for input current ``i`` using the hard rectifier.

    At or below threshold the rectified drive is 0; the division then yields
    +inf, which propagates to a firing rate of exactly 0, so the resulting
    divide-by-zero warning is suppressed.
    """
    with np.errstate(divide="ignore"):
        return 1 / (t_ref + tau * np.log(1 + (v_th / rho(i - v_th))))


def a4(i):
    """LIF firing rate for input current ``i`` using the softplus rectifier."""
    with np.errstate(divide="ignore"):
        return 1 / (t_ref + tau * np.log(1 + (v_th / rho2(i - v_th))))


# Input currents to sweep and the corresponding rates (computed vectorised,
# so A3 and A4 are arrays aligned with I).
I = np.linspace(0,3,10000)
A3 = a3(I)
A4 = a4(I)

# Overlay the two tuning curves: hard rectifier in black, softplus in red.
curves = (
    (A3, 'black', r'$\rho=max(x,0)$'),
    (A4, 'red', r'$\rho=\lambda log(1+e^{x/\lambda})$'),
)
for rates, colour, curve_label in curves:
    plt.plot(I, rates, linewidth=3, color=colour, label=curve_label)

# Fix the visible window to the swept current range and 0-100 on the y-axis.
plt.xlim(0, 3)
plt.ylim(0, 100)
plt.xlabel("Input current")
plt.ylabel("Firing rate (Hz)")
plt.legend()
plt.show()

Result:

MNIST classification

This demonstration was adapted from:

Data loading:

# Fetch the MNIST train/test splits through Keras.
train_split, test_split = tf.keras.datasets.mnist.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

# Flatten every image into a single vector (one row per example).
train_images = train_images.reshape(train_images.shape[0], -1)
test_images = test_images.reshape(test_images.shape[0], -1)

Model definition:

# Convolutional spiking network for MNIST: three Conv2D layers, each followed
# by a layer of LIF neurons, and a dense 10-unit readout. The network seed is
# fixed (seed=0).
with nengo.Network(seed=0) as net:

    # Network-wide defaults: every ensemble uses a max rate of 100 and an
    # intercept of 0, and connections are unfiltered (synapse=None).
    net.config[nengo.Ensemble].max_rates = nengo.dists.Choice([100])
    net.config[nengo.Ensemble].intercepts = nengo.dists.Choice([0])
    net.config[nengo.Connection].synapse = None
    # NOTE(review): amplitude=0.01 together with the 100 max rate presumably
    # keeps the neuron output around 1 - confirm.
    neuron_type = nengo.LIF(amplitude=0.01)

    # this is an optimization to improve the training speed,
    # since we won't require stateful behaviour in this example
    nengo_dl.configure_settings(stateful=False)

    # the input node: a vector of 28 * 28 = 784 zeros (presumably a
    # placeholder that is replaced by the image data passed to the simulator)
    inp = nengo.Node(np.zeros(28 * 28))

    # add the first convolutional layer
    # (the flat input is interpreted as a 28x28 image with 1 channel)
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=32, kernel_size=3))(
        inp, shape_in=(28, 28, 1)
    )
    x = nengo_dl.Layer(neuron_type)(x)

    # add the second convolutional layer
    # (shape_in 26x26x32: 28 - 3 + 1 = 26, assuming Keras' default padding)
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=64, strides=2, kernel_size=3))(
        x, shape_in=(26, 26, 32)
    )
    x = nengo_dl.Layer(neuron_type)(x)

    # add the third convolutional layer
    # (shape_in 12x12x64: (26 - 3) // 2 + 1 = 12, assuming default padding)
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=128, strides=2, kernel_size=3))(
        x, shape_in=(12, 12, 64)
    )
    x = nengo_dl.Layer(neuron_type)(x)

    # linear readout: one output unit per digit class
    out = nengo_dl.Layer(tf.keras.layers.Dense(units=10))(x)

    # two probes on the readout: an unfiltered one and one filtered with a
    # 0.1 synapse
    out_p = nengo.Probe(out, label="out_p")
    out_p_filt = nengo.Probe(out, synapse=0.1, label="out_p_filt")

Network building:

# Minibatch size handed to the simulator; also used later to slice the test
# images passed to sim.predict.
minibatch_size = 200
# Build the NengoDL simulator for the network defined above.
sim = nengo_dl.Simulator(net, minibatch_size=minibatch_size)

Preprocessing:

# add single timestep to training data
train_images = train_images[:, None, :]
train_labels = train_labels[:, None, None]

# when testing our network with spiking neurons we will need to run it
# over time, so we repeat the input/target data for a number of
# timesteps.
n_steps = 30
test_images = np.tile(test_images[:, None, :], (1, n_steps, 1))
test_labels = np.tile(test_labels[:, None, None], (1, n_steps, 1))

SNN compilation before training:

def classification_accuracy(y_true, y_pred):
    """Sparse categorical accuracy computed on the final timestep only.

    Both arguments are indexed as ``[:, -1]``, i.e. the last entry along
    the second axis, before being passed to TensorFlow's metric.
    """
    final_targets = y_true[:, -1]
    final_outputs = y_pred[:, -1]
    return tf.metrics.sparse_categorical_accuracy(final_targets, final_outputs)

# The filtered probe `out_p_filt` is the one scored at test time, which
# reduces the effect of spike noise on the measured accuracy.
sim.compile(loss={out_p_filt: classification_accuracy})
accuracy_before = sim.evaluate(
    test_images, {out_p_filt: test_labels}, verbose=0
)["loss"]
print("Accuracy before training:", accuracy_before)

# PRINTED: Accuracy before training: 0.0934000015258789

Training:

# Set to True to train from scratch; otherwise pretrained weights are used.
do_training = False
if not do_training:
    # download pretrained weights
    pretrained_url = (
        "https://drive.google.com/uc?export=download&"
        "id=1l5aivQljFoXzPP5JVccdFXbOYRv3BCJR"
    )
    urlretrieve(pretrained_url, "mnist_params.npz")

    # load parameters
    sim.load_params("./mnist_params")
else:
    # run training: cross-entropy on the unfiltered probe, RMSprop optimizer
    rmsprop = tf.optimizers.RMSprop(0.001)
    cross_entropy = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
    sim.compile(optimizer=rmsprop, loss={out_p: cross_entropy})
    sim.fit(train_images, {out_p: train_labels}, epochs=10)

    # save the parameters to file
    sim.save_params("./mnist_params")

Evaluating after training:

# Re-compile with the accuracy metric on the filtered probe and score the
# network again, now with the trained/loaded parameters.
sim.compile(loss={out_p_filt: classification_accuracy})
accuracy_after = sim.evaluate(
    test_images, {out_p_filt: test_labels}, verbose=0
)["loss"]
print("Accuracy after training:", accuracy_after)

# PRINTED: Accuracy after training: 0.9869999885559082

Plotting:

# Run the network on the first minibatch of test images and keep the
# outputs for plotting.
data = sim.predict(test_images[:minibatch_size])

# The simulator is not used after this point: close it explicitly so its
# resources are released (the values already stored in `data` are unaffected).
sim.close()

Results:

Last updated