Learning with SNN
Reading: Chapter 13
Python / Nengo demonstration
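The demonstrations below use Nengo, NengoDL, and TensorFlow. A typical setup (assuming a standard Python environment with pip):

pip install nengo nengo-dl tensorflow matplotlib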
Imports:
import nengo
import numpy as np
import matplotlib.pyplot as plt
from nengo.processes import WhiteSignal
from urllib.request import urlretrieve
import tensorflow as tf
import nengo_dl
PES learning: communication channel
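The PES (Prescribed Error Sensitivity) rule adapts the decoders of the pre ensemble to minimize an externally supplied error signal. Schematically (a simplified form; Nengo additionally scales by the number of neurons), for neuron $i$ with activity $a_i$ and learning rate $\kappa$, the decoder update is $\Delta d_i = -\kappa E a_i$. In the network below the error ensemble computes $E = \mathrm{post} - \mathrm{pre}$, so learning pushes the initially random connection toward the identity function, i.e. a communication channel.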
model = nengo.Network('Learn a Communication Channel')
with model:
    stim = nengo.Node(output=WhiteSignal(10, high=5, rms=0.5))
    pre = nengo.Ensemble(60, dimensions=1)
    post = nengo.Ensemble(60, dimensions=1)
    nengo.Connection(stim, pre)
    # start from a random function so the effect of learning is visible
    conn = nengo.Connection(pre, post, function=lambda x: np.random.random())
    inp_p = nengo.Probe(stim)
    pre_p = nengo.Probe(pre, synapse=0.01)
    post_p = nengo.Probe(post, synapse=0.01)
    # error population computes post - pre
    error = nengo.Ensemble(60, dimensions=1)
    error_p = nengo.Probe(error, synapse=0.03)
    nengo.Connection(post, error)
    nengo.Connection(pre, error, transform=-1)
    # learn the communication channel with PES, driven by the error signal
    conn.learning_rule_type = nengo.PES()
    learn_conn = nengo.Connection(error, conn.learning_rule)
sim = nengo.Simulator(model)
sim.run(10.0)
t = sim.trange()

plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[inp_p].T[0], c='k', label='Input')
plt.plot(t, sim.data[pre_p].T[0], c='b', label='Pre')
plt.plot(t, sim.data[post_p].T[0], c='r', label='Post')
plt.ylabel("Value")
plt.legend(loc=1)

plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[error_p].T[0], c='k', label='Error')
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend(loc='best')
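As a quick sanity check (not part of the original listing), we can compare the root-mean-square error between the input and the post decoding early versus late in the run; if PES learning worked, the late-window error should be much smaller:

rmse = lambda a, b: np.sqrt(np.mean((a - b) ** 2))
early, late = t < 2.0, t > 8.0
print("RMSE, first 2 s:", rmse(sim.data[inp_p][early, 0], sim.data[post_p][early, 0]))
print("RMSE, last 2 s:", rmse(sim.data[inp_p][late, 0], sim.data[post_p][late, 0]))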
Result: the input, pre, and post traces, plus the decoded error. As learning proceeds, post converges to the input signal and the error decays toward zero.
Pavlovian conditioning
D = 3
N = D * 100

def us_stim(t):
    # cycle through the three US
    t = t % 3
    if 0.9 < t < 1: return [1, 0, 0]
    if 1.9 < t < 2: return [0, 1, 0]
    if 2.9 < t < 3: return [0, 0, 1]
    return [0, 0, 0]

def cs_stim(t):
    # cycle through the three CS; each CS onset precedes the
    # corresponding US by 0.2 s
    t = t % 3
    if 0.7 < t < 1: return [0.7, 0, 0.5]
    if 1.7 < t < 2: return [0.6, 0.7, 0.8]
    if 2.7 < t < 3: return [0, 1, 0]
    return [0, 0, 0]

def stop_learning(t):
    # 0 (learning enabled) for 2 s < t < 8 s, 1 (learning inhibited) otherwise
    if 2 < t < 8:
        return 0
    return 1
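The stop_learning node outputs 1 outside the 2-8 s window. In the model below it projects to error.neurons with a strongly negative transform, silencing the error population and thereby freezing PES learning; learning is therefore active only for 2 s < t < 8 s, and the conditioned response afterwards has to persist without further updates.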
model = nengo.Network(label="Classical Conditioning")
with model:
    # unconditioned stimulus (US) and response (UR)
    us_stim = nengo.Node(us_stim)
    us_stim_p = nengo.Probe(us_stim)
    us = nengo.Ensemble(N, D)
    ur = nengo.Ensemble(N, D)
    us_p = nengo.Probe(us, synapse=0.1)
    ur_p = nengo.Probe(ur, synapse=0.1)
    nengo.Connection(us, ur)
    nengo.Connection(us_stim, us[:D])

    # conditioned stimulus (CS) and response (CR); cs holds the current
    # CS in its first D dimensions and a delayed copy in the last D
    cs_stim = nengo.Node(cs_stim)
    cs_stim_p = nengo.Probe(cs_stim)
    cs = nengo.Ensemble(N * 2, D * 2)
    cr = nengo.Ensemble(N, D)
    cs_p = nengo.Probe(cs, synapse=0.1)
    cr_p = nengo.Probe(cr, synapse=0.1)
    nengo.Connection(cs_stim, cs[:D])
    nengo.Connection(cs[:D], cs[D:], synapse=0.2)

    # start from the zero function; PES learns the CS -> CR mapping
    learn_conn = nengo.Connection(cs, cr, function=lambda x: [0] * D)
    learn_conn.learning_rule_type = nengo.PES(learning_rate=3e-4)

    # error population computes cr - ur
    error = nengo.Ensemble(N, D)
    error_p = nengo.Probe(error, synapse=0.01)
    nengo.Connection(error, learn_conn.learning_rule)
    nengo.Connection(ur, error, transform=-1)
    nengo.Connection(cr, error, transform=1, synapse=0.1)

    # inhibit the error neurons (freezing learning) whenever stop_learning outputs 1
    stop_learn = nengo.Node(stop_learning)
    stop_learn_p = nengo.Probe(stop_learn)
    nengo.Connection(stop_learn, error.neurons, transform=-10 * np.ones((N, 1)))
sim = nengo.Simulator(model)
sim.run(15)
t = sim.trange()
plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[us_stim_p].T[0], c='blue', label='US #1')
plt.plot(t, sim.data[us_stim_p].T[1], c='red', label='US #2')
plt.plot(t, sim.data[us_stim_p].T[2], c='black', label='US #3')
plt.plot(t, sim.data[ur_p].T[0], c='blue', label='UR #1', linestyle=":", linewidth=3)
plt.plot(t, sim.data[ur_p].T[1], c='red', label='UR #2', linestyle=":", linewidth=3)
plt.plot(t, sim.data[ur_p].T[2], c='black', label='UR #3', linestyle=":", linewidth=3)
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend()
plt.show()
plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[cs_stim_p].T[0], c='blue', label='CS #1')
plt.plot(t, sim.data[cs_stim_p].T[1], c='red', label='CS #2')
plt.plot(t, sim.data[cs_stim_p].T[2], c='black', label='CS #3')
plt.plot(t, sim.data[ur_p].T[0], c='blue', label='UR #1', linestyle=":", linewidth=3)
plt.plot(t, sim.data[ur_p].T[1], c='red', label='UR #2', linestyle=":", linewidth=3)
plt.plot(t, sim.data[ur_p].T[2], c='black', label='UR #3', linestyle=":", linewidth=3)
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend()
plt.show()
plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[cs_stim_p].T[0], c='blue', label='CS #1')
plt.plot(t, sim.data[cs_stim_p].T[1], c='red', label='CS #2')
plt.plot(t, sim.data[cs_stim_p].T[2], c='black', label='CS #3')
plt.plot(t, sim.data[cr_p].T[0], c='blue', label='CR #1', linestyle=":", linewidth=3)
plt.plot(t, sim.data[cr_p].T[1], c='red', label='CR #2', linestyle=":", linewidth=3)
plt.plot(t, sim.data[cr_p].T[2], c='black', label='CR #3', linestyle=":", linewidth=3)
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend()
plt.show()
plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[error_p].T[0], c='black', label='error')
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend()
plt.show()
plt.figure(figsize=(12, 2))
plt.plot(t, sim.data[stop_learn_p].T[0], c='black', label='stop')
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend()
plt.show()
Results: the US/UR traces, the CS traces plotted against the UR and the CR, the decoded error, and the stop-learning gate. During the learning window (2 s < t < 8 s) the CR comes to anticipate the UR in response to the CS alone.
Differentiable LIF tuning curve
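Both curves below come from the steady-state LIF rate

$$a(I) = \frac{1}{t_{\mathrm{ref}} + \tau \ln\!\left(1 + \frac{v_{\mathrm{th}}}{\rho(I - v_{\mathrm{th}})}\right)},$$

where the rectifier $\rho$ is either the hard $\rho(x) = \max(x, 0)$, which leaves $a(I)$ non-differentiable at threshold, or the softplus $\rho(x) = \lambda \log(1 + e^{x/\lambda})$, which smooths the curve so that gradients exist everywhere; a differentiable rate curve of this kind is what allows training spiking networks with backpropagation.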
Rm = 1000       # membrane resistance [kOhm]
Cm = 5e-6       # membrane capacitance [uF]
t_ref = 10e-3   # refractory period [s]
tau = Rm * Cm   # membrane time constant [s]
v_th = 1        # threshold voltage
la = 0.05       # softplus smoothing parameter (lambda)

# hard rectification (np.maximum, not np.max, for an elementwise max with 0)
rho = lambda x: np.maximum(x, 0)
# softplus: a differentiable approximation of the rectifier
rho2 = lambda x: la * np.log(1 + np.exp(x / la))

# steady-state LIF rate; below threshold, the division by zero produces
# inf inside the log, so the rate evaluates to 0 (numpy emits a warning)
a3 = lambda i: 1 / (t_ref + tau * np.log(1 + (v_th / rho(i - v_th))))
a4 = lambda i: 1 / (t_ref + tau * np.log(1 + (v_th / rho2(i - v_th))))

I = np.linspace(0, 3, 10000)
A3 = [a3(i) for i in I]
A4 = [a4(i) for i in I]
plt.plot(I, A3, linewidth=3, color='black', label=r'$\rho(x)=\max(x,0)$')
plt.plot(I, A4, linewidth=3, color='red', label=r'$\rho(x)=\lambda\log(1+e^{x/\lambda})$')
plt.ylim(0,100)
plt.xlim(0,3)
plt.legend()
plt.ylabel("Firing rate (Hz)")
plt.xlabel("Input current")
plt.show()
Result: the two rate curves; they coincide away from threshold, but the softplus version rises smoothly and is differentiable at the firing threshold.
MNIST classification
This demonstration was adapted from the NengoDL spiking MNIST example in the NengoDL documentation.
Data loading:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
# flatten images
train_images = train_images.reshape((train_images.shape[0], -1))
test_images = test_images.reshape((test_images.shape[0], -1))
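After flattening, each image is a 784-dimensional vector (assuming the standard Keras MNIST shapes):

print(train_images.shape)  # (60000, 784)
print(test_images.shape)   # (10000, 784)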
Model definition:
with nengo.Network(seed=0) as net:
    # set some default parameters for the neurons to make training smoother
    net.config[nengo.Ensemble].max_rates = nengo.dists.Choice([100])
    net.config[nengo.Ensemble].intercepts = nengo.dists.Choice([0])
    net.config[nengo.Connection].synapse = None
    neuron_type = nengo.LIF(amplitude=0.01)

    # this is an optimization to improve the training speed,
    # since we won't require stateful behaviour in this example
    nengo_dl.configure_settings(stateful=False)

    # the input node that will feed the flattened images
    inp = nengo.Node(np.zeros(28 * 28))

    # add the first convolutional layer
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=32, kernel_size=3))(
        inp, shape_in=(28, 28, 1)
    )
    x = nengo_dl.Layer(neuron_type)(x)

    # add the second convolutional layer
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=64, strides=2, kernel_size=3))(
        x, shape_in=(26, 26, 32)
    )
    x = nengo_dl.Layer(neuron_type)(x)

    # add the third convolutional layer
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=128, strides=2, kernel_size=3))(
        x, shape_in=(12, 12, 64)
    )
    x = nengo_dl.Layer(neuron_type)(x)

    # linear readout
    out = nengo_dl.Layer(tf.keras.layers.Dense(units=10))(x)

    # two output probes: one unfiltered (for training) and one filtered
    # (for evaluating the spiking network over time)
    out_p = nengo.Probe(out, label="out_p")
    out_p_filt = nengo.Probe(out, synapse=0.1, label="out_p_filt")
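A few notes on these settings: synapse=None removes synaptic filtering from every connection, so during training the network behaves like a standard feedforward net; fixing max_rates and intercepts gives all neurons the same rate response, simplifying the rate-based training approximation; and LIF(amplitude=0.01) scales each spike's contribution so the Dense readout receives values of reasonable magnitude. The unfiltered out_p probe is used for training, while out_p_filt applies a 0.1 s synaptic filter to smooth the spiking output at test time.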
Network building:
minibatch_size = 200
sim = nengo_dl.Simulator(net, minibatch_size=minibatch_size)
Preprocessing:
# add single timestep to training data
train_images = train_images[:, None, :]
train_labels = train_labels[:, None, None]
# when testing our network with spiking neurons we will need to run it
# over time, so we repeat the input/target data for a number of
# timesteps.
n_steps = 30
test_images = np.tile(test_images[:, None, :], (1, n_steps, 1))
test_labels = np.tile(test_labels[:, None, None], (1, n_steps, 1))
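NengoDL expects data shaped as (batch, timesteps, dimensions): training data needs only a single timestep because the rate-based approximation used during training is stateless, while the test data is tiled over 30 timesteps so that spikes can accumulate. A quick check of the resulting shapes:

print(train_images.shape)  # (60000, 1, 784)
print(test_images.shape)   # (10000, 30, 784)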
SNN compilation before training:
def classification_accuracy(y_true, y_pred):
    # score only the output at the final timestep, after spikes have accumulated
    return tf.metrics.sparse_categorical_accuracy(y_true[:, -1], y_pred[:, -1])
# note that we use `out_p_filt` when testing (to reduce the spike noise)
sim.compile(loss={out_p_filt: classification_accuracy})
print(
    "Accuracy before training:",
    sim.evaluate(test_images, {out_p_filt: test_labels}, verbose=0)["loss"],
)
# PRINTED: Accuracy before training: 0.0934 (roughly chance level for 10 classes)
Training:
do_training = False
if do_training:
    # run training
    sim.compile(
        optimizer=tf.optimizers.RMSprop(0.001),
        loss={out_p: tf.losses.SparseCategoricalCrossentropy(from_logits=True)},
    )
    sim.fit(train_images, {out_p: train_labels}, epochs=10)

    # save the parameters to file
    sim.save_params("./mnist_params")
else:
    # download pretrained weights
    urlretrieve(
        "https://drive.google.com/uc?export=download&"
        "id=1l5aivQljFoXzPP5JVccdFXbOYRv3BCJR",
        "mnist_params.npz",
    )

    # load parameters
    sim.load_params("./mnist_params")
Evaluating after training:
sim.compile(loss={out_p_filt: classification_accuracy})
print(
    "Accuracy after training:",
    sim.evaluate(test_images, {out_p_filt: test_labels}, verbose=0)["loss"],
)
# PRINTED: Accuracy after training: 0.987
Plotting:
data = sim.predict(test_images[:minibatch_size])
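A minimal plotting sketch (an assumption, mirroring the usual NengoDL example rather than the original listing): show a few test digits next to the softmaxed, filtered output over the 30 simulated timesteps.

for i in range(3):
    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(test_images[i, 0].reshape((28, 28)), cmap="gray")
    plt.axis("off")
    plt.subplot(1, 2, 2)
    # softmax turns the 10 readout values at each timestep into class probabilities
    plt.plot(sim.trange(), tf.nn.softmax(data[out_p_filt][i]))
    plt.legend([str(j) for j in range(10)], loc="upper left")
    plt.xlabel("Time (s)")
    plt.ylabel("Probability")
plt.show()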
Results: the network's filtered output over time for the sampled test digits.