# Learning with SNN

{% hint style="info" %}
Read Chapter 13
{% endhint %}

### Python / Nengo demonstration

Imports:

```
import nengo
import numpy as np
import matplotlib.pyplot as plt
from nengo.processes import WhiteSignal

from urllib.request import urlretrieve
import tensorflow as tf
import nengo_dl
```

### PES learning: communication channel

```
model = nengo.Network('Learn a Communication Channel')
with model:
    # Band-limited white-noise input signal to be transmitted.
    stim = nengo.Node(output=WhiteSignal(10, high=5, rms=0.5))

    pre = nengo.Ensemble(60, dimensions=1)
    post = nengo.Ensemble(60, dimensions=1)

    nengo.Connection(stim, pre)
    # Start from a deliberately wrong (random-valued) decoding so that the
    # pre -> post mapping is not the identity; PES has to learn it.
    conn = nengo.Connection(pre, post, function=lambda x: np.random.random())

    inp_p = nengo.Probe(stim)
    pre_p = nengo.Probe(pre, synapse=0.01)
    post_p = nengo.Probe(post, synapse=0.01)

    # Error population represents post - pre.
    error = nengo.Ensemble(60, dimensions=1)
    error_p = nengo.Probe(error, synapse=0.03)

    nengo.Connection(post, error)
    nengo.Connection(pre, error, transform=-1)  # learn a simple communication channel
    # PES modifies conn so as to drive the error (post - pre) toward zero.
    conn.learning_rule_type = nengo.PES()
    learn_conn = nengo.Connection(error, conn.learning_rule)

# The context manager closes the simulator when the run is done; the probe
# data in sim.data stays available for plotting afterwards.
with nengo.Simulator(model) as sim:
    sim.run(10.0)

t = sim.trange()

plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[inp_p].T[0], c='k', label='Input')
plt.plot(t, sim.data[pre_p].T[0], c='b', label='Pre')
plt.plot(t, sim.data[post_p].T[0], c='r', label='Post')
plt.ylabel("Value")
plt.legend(loc=1)
plt.show()

plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[error_p].T[0], c='k', label='Error')
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend(loc='best')
plt.show()
```

Result:

![](/files/-Mj5biUDd1oub41GMWPH)

### Pavlovian conditioning

```
D = 3        # number of US/CS pairs, i.e. dimensions per stimulus ensemble
N = 100 * D  # neurons for each D-dimensional ensemble (100 per dimension)

def us_stim(t):
    """Unconditioned-stimulus schedule, repeating every 3 s.

    US #k is presented as a one-hot vector during the open window
    (0.9, 1), (1.9, 2) or (2.9, 3) of each cycle; otherwise all zeros.
    """
    phase = t % 3
    windows = (
        (0.9, 1, [1, 0, 0]),
        (1.9, 2, [0, 1, 0]),
        (2.9, 3, [0, 0, 1]),
    )
    for start, stop, pattern in windows:
        if start < phase < stop:
            return pattern
    return [0, 0, 0]

def cs_stim(t):
    """Conditioned-stimulus schedule, repeating every 3 s.

    Each of the three CS patterns is shown during the open window
    (0.7, 1), (1.7, 2) or (2.7, 3) of each cycle, i.e. it starts before
    and overlaps the matching US window; otherwise all zeros.
    """
    phase = t % 3
    windows = (
        (0.7, 1, [0.7, 0, 0.5]),
        (1.7, 2, [0.6, 0.7, 0.8]),
        (2.7, 3, [0, 1, 0]),
    )
    for start, stop, pattern in windows:
        if start < phase < stop:
            return pattern
    return [0, 0, 0]

def stop_learning(t):
    """Return 0 inside the open interval 2 < t < 8 and 1 everywhere else."""
    inside_learning_window = 2 < t < 8
    return 0 if inside_learning_window else 1

model = nengo.Network(label="Classical Conditioning")
with model:
    # Unconditioned stimulus (US) -> unconditioned response (UR): a fixed,
    # unlearned mapping. The Node gets its own name so that the us_stim
    # function is not shadowed (re-running this cell would otherwise try to
    # wrap a Node inside a Node).
    us_node = nengo.Node(us_stim)
    us_stim_p = nengo.Probe(us_node)

    us = nengo.Ensemble(N, D)
    ur = nengo.Ensemble(N, D)
    us_p = nengo.Probe(us, synapse=0.1)
    ur_p = nengo.Probe(ur, synapse=0.1)
    nengo.Connection(us, ur)
    nengo.Connection(us_node, us)

    # Conditioned stimulus (CS): the first D dimensions of `cs` carry the
    # current CS, the last D dimensions a slowly filtered (synapse=0.2) copy,
    # i.e. a decaying trace of recent CS input.
    cs_node = nengo.Node(cs_stim)
    cs_stim_p = nengo.Probe(cs_node)

    cs = nengo.Ensemble(N * 2, D * 2)
    cr = nengo.Ensemble(N, D)
    cs_p = nengo.Probe(cs, synapse=0.1)
    cr_p = nengo.Probe(cr, synapse=0.1)
    nengo.Connection(cs_node, cs[:D])
    nengo.Connection(cs[:D], cs[D:], synapse=0.2)

    # Learned CS -> conditioned response (CR) mapping, initialised to output zero.
    learn_conn = nengo.Connection(cs, cr, function=lambda x: [0] * D)
    learn_conn.learning_rule_type = nengo.PES(learning_rate=3e-4)

    # Error = CR - UR; PES adjusts learn_conn so that CR comes to match UR.
    error = nengo.Ensemble(N, D)
    error_p = nengo.Probe(error, synapse=0.01)
    nengo.Connection(error, learn_conn.learning_rule)
    nengo.Connection(ur, error, transform=-1)
    nengo.Connection(cr, error, synapse=0.1)

    # Learning gate: whenever stop_learning(t) outputs 1, every error neuron
    # receives a strong negative input (-10), intended to silence the error
    # population and thereby suspend learning.
    stop_learn = nengo.Node(stop_learning)
    stop_learn_p = nengo.Probe(stop_learn)
    nengo.Connection(stop_learn, error.neurons, transform=-10 * np.ones((N, 1)))
    
# The context manager closes the simulator after the run; the probe data in
# sim.data stays available for plotting afterwards.
with nengo.Simulator(model) as sim:
    sim.run(15)

t = sim.trange()

# One color per stimulus/response index (#1, #2, #3).
colors = ('blue', 'red', 'black')

# Each figure overlays a stimulus (solid) with a response (thick dotted):
# (stimulus probe, stimulus label, response probe, response label).
comparisons = (
    (us_stim_p, 'US', ur_p, 'UR'),
    (cs_stim_p, 'CS', ur_p, 'UR'),
    (cs_stim_p, 'CS', cr_p, 'CR'),
)
for stim_probe, stim_name, resp_probe, resp_name in comparisons:
    plt.figure(figsize=(12, 4))
    for k, color in enumerate(colors):
        plt.plot(t, sim.data[stim_probe].T[k], c=color, label=f'{stim_name} #{k + 1}')
    for k, color in enumerate(colors):
        plt.plot(t, sim.data[resp_probe].T[k], c=color, label=f'{resp_name} #{k + 1}',
                 linestyle=":", linewidth=3)
    plt.ylabel("Value")
    plt.xlabel("Time (sec)")
    plt.legend()
    plt.show()

plt.figure(figsize=(12, 4))
plt.plot(t, sim.data[error_p].T[0], c='black', label='error')
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend()
plt.show()

plt.figure(figsize=(12, 2))
plt.plot(t, sim.data[stop_learn_p].T[0], c='black', label='stop')
plt.ylabel("Value")
plt.xlabel("Time (sec)")
plt.legend()
plt.show()
```

Results:

![](/files/-Mj5cwj6D868sIGJgO5C)

![](/files/-Mj5d1GzhV320hvD85GB)

![](/files/-Mj5d6g7upiyajgbXM4y)

![](/files/-Mj5dJn65RvHOCwjxsWJ)

### Differentiable LIF tuning curve

```
# Units are chosen so that tau = Rm * Cm comes out in seconds
# (1000 Ohm * 5e-6 F = 5e-3 s), consistent with t_ref and the Hz axis below.
Rm      = 1000     # Membrane Resistance      [Ohm]
Cm      = 5e-6     # Capacitance              [F]

t_ref = 10e-3      # refractory period [s]; caps the rate at 1/t_ref = 100 Hz
tau = Rm * Cm      # membrane time constant [s]
v_th = 1           # firing threshold
I0 = 1             # not used below
R=1000             # not used below
la = 0.05          # smoothing parameter lambda of the soft rectifier

# Hard rectifier max(x, 0). np.maximum is element-wise; np.max(x, 0) would
# instead be a reduction along axis 0, not a rectifier.
rho  = lambda x: np.maximum(x, 0)
# Soft rectifier lambda*log(1 + e^{x/lambda}), written as
# lambda*logaddexp(0, x/lambda) so large x/lambda cannot overflow exp().
rho2 = lambda x: la*np.logaddexp(0, x/la)

# Steady-state LIF firing rate for input current i, using each rectifier.
a3  = lambda i: 1/(t_ref+tau*np.log(1+(v_th/rho(i-v_th))))
a4  = lambda i: 1/(t_ref+tau*np.log(1+(v_th/rho2(i-v_th))))

I = np.linspace(0,3,10000)
# At or below threshold rho(...) == 0, so v_th/0 -> inf and the rate is
# 1/inf = 0 as intended; silence the expected division warning.
with np.errstate(divide='ignore'):
    A3 = a3(I)
A4 = a4(I)

plt.plot(I,A3, linewidth=3, color='black', label=r'$\rho=max(x,0)$')
plt.plot(I,A4, linewidth=3, color='red', label=r'$\rho=\lambda log(1+e^{x/\lambda})$')
plt.ylim(0,100)
plt.xlim(0,3)
plt.legend()
plt.ylabel("Firing rate (Hz)")
plt.xlabel("Input current")
plt.show()
```

Result:

![](/files/-Mj5e6BrTXlapwU9WUAr)

### MNIST classification

This demonstration was adapted from:

{% embed url="<https://www.nengo.ai/nengo-dl/examples/spiking-mnist.html>" %}

Data loading:

```
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

# Flatten every image into one row vector, matching the flat 28 * 28 input
# node of the network.
train_images = train_images.reshape(len(train_images), -1)
test_images = test_images.reshape(len(test_images), -1)
```

Model definition:

```
with nengo.Network(seed=0) as net:

    # Network-wide defaults for every Ensemble and Connection created below:
    # all neurons get the same max rate (100) and intercept (0).
    net.config[nengo.Ensemble].max_rates = nengo.dists.Choice([100])
    net.config[nengo.Ensemble].intercepts = nengo.dists.Choice([0])
    # No synaptic filtering on connections.
    net.config[nengo.Connection].synapse = None
    # Neuron type used by every nonlinearity layer below.
    # NOTE(review): amplitude presumably scales the emitted spikes — confirm in the nengo.LIF docs.
    neuron_type = nengo.LIF(amplitude=0.01)

    # this is an optimization to improve the training speed,
    # since we won't require stateful behaviour in this example
    nengo_dl.configure_settings(stateful=False)

    # the input node: one flat 28 * 28 vector per image
    inp = nengo.Node(np.zeros(28 * 28))

    # add the first convolutional layer; shape_in views the flat input as a
    # (28, 28, 1) image
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=32, kernel_size=3))(
        inp, shape_in=(28, 28, 1)
    )
    x = nengo_dl.Layer(neuron_type)(x)

    # add the second convolutional layer; shape_in must equal the previous
    # layer's output shape (3x3 kernel without padding: 28 -> 26, 32 filters)
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=64, strides=2, kernel_size=3))(
        x, shape_in=(26, 26, 32)
    )
    x = nengo_dl.Layer(neuron_type)(x)

    # add the third convolutional layer; shape_in again matches the previous
    # output (3x3 kernel, stride 2: 26 -> 12, 64 filters)
    x = nengo_dl.Layer(tf.keras.layers.Conv2D(filters=128, strides=2, kernel_size=3))(
        x, shape_in=(12, 12, 64)
    )
    x = nengo_dl.Layer(neuron_type)(x)

    # linear readout: 10 outputs, one per digit class
    out = nengo_dl.Layer(tf.keras.layers.Dense(units=10))(x)

    # out_p records the raw output (used for training); out_p_filt low-pass
    # filters it (synapse=0.1) to reduce spike noise when testing
    out_p = nengo.Probe(out, label="out_p")
    out_p_filt = nengo.Probe(out, synapse=0.1, label="out_p_filt")
```

Network building:

```
minibatch_size = 200  # number of examples simulated in parallel
sim = nengo_dl.Simulator(
    net,
    minibatch_size=minibatch_size,
)
```

Preprocessing:

```
# Training uses a single timestep per example: insert a time axis of length 1.
train_images = train_images[:, np.newaxis, :]
train_labels = train_labels[:, np.newaxis, np.newaxis]

# A spiking network has to be simulated over time, so every test input and
# target is repeated n_steps times along the (new) time axis.
n_steps = 30
test_images = np.repeat(test_images[:, np.newaxis, :], n_steps, axis=1)
test_labels = np.repeat(test_labels[:, np.newaxis, np.newaxis], n_steps, axis=1)
```

SNN compilation before training:

```
def classification_accuracy(y_true, y_pred):
    """Sparse categorical accuracy of the output at the final timestep.

    Only index -1 along axis 1 (the last simulated timestep) is scored.
    NOTE(review): assumes y_true / y_pred are (batch, n_steps, ...) — confirm.
    """
    return tf.metrics.sparse_categorical_accuracy(y_true[:, -1], y_pred[:, -1])

# The accuracy metric is registered as the "loss" so that sim.evaluate
# reports it under the "loss" key (higher is better here).
# Note that we use `out_p_filt` when testing (to reduce the spike noise).
sim.compile(loss={out_p_filt: classification_accuracy})
print(
    "Accuracy before training:",
    sim.evaluate(test_images, {out_p_filt: test_labels}, verbose=0)["loss"],
)

# PRINTED: Accuracy before training: 0.0934000015258789
```

Training:

```
do_training = False
if not do_training:
    # Fetch pretrained weights instead of training from scratch.
    urlretrieve(
        "https://drive.google.com/uc?export=download&"
        "id=1l5aivQljFoXzPP5JVccdFXbOYRv3BCJR",
        "mnist_params.npz",
    )

    # Load them; the path is given without the ".npz" suffix used above.
    # NOTE(review): presumably nengo_dl appends ".npz" itself — confirm.
    sim.load_params("./mnist_params")
else:
    # Train on the unfiltered output probe with cross-entropy on the logits.
    sim.compile(
        optimizer=tf.optimizers.RMSprop(0.001),
        loss={out_p: tf.losses.SparseCategoricalCrossentropy(from_logits=True)},
    )
    sim.fit(train_images, {out_p: train_labels}, epochs=10)

    # Persist the trained parameters for later reuse.
    sim.save_params("./mnist_params")
```

Evaluating after training:

```
# Re-compile with the accuracy metric (registered as the "loss"), since the
# training branch may have compiled the simulator with a cross-entropy loss.
sim.compile(loss={out_p_filt: classification_accuracy})
print(
    "Accuracy after training:",
    sim.evaluate(test_images, {out_p_filt: test_labels}, verbose=0)["loss"],
)

# PRINTED: Accuracy after training: 0.9869999885559082
```

Plotting:

```
# Run the network on the first minibatch of (time-repeated) test images;
# `data` maps each probe to its recorded output.
data = sim.predict(test_images[:minibatch_size])

# For a few test examples, show the input digit next to the softmax of the
# filtered network output over the simulated timesteps.
for i in range(5):
    plt.figure(figsize=(8, 4))
    plt.subplot(1, 2, 1)
    plt.imshow(test_images[i, 0].reshape((28, 28)), cmap="gray")
    plt.axis("off")

    plt.subplot(1, 2, 2)
    plt.plot(tf.nn.softmax(data[out_p_filt][i]))
    plt.legend([str(j) for j in range(10)], loc="upper left")
    plt.xlabel("timesteps")
    plt.ylabel("probability")
    plt.tight_layout()
    plt.show()
```

Results:

![](/files/-Mj5gGaE9A4LOc-LnD33)

###


---

# Agent Instructions: Querying This Documentation

If you need additional information that is not directly available in this page, you can query the documentation dynamically by asking a question.

Perform an HTTP GET request on the current page URL with the `ask` query parameter:

```
GET https://elishai.gitbook.io/neuromorphic-engineering/algorithm-designer-perspective/13.-learning-spiking-neural-networks/learning-with-snn.md?ask=<question>
```

The question should be specific, self-contained, and written in natural language.
The response will contain a direct answer to the question and relevant excerpts and sources from the documentation.

Use this mechanism when the answer is not explicitly present in the current page, you need clarification or additional context, or you want to retrieve related documentation sections.
