Note
Go to the end to download the full example code
Weight adaptation according to the Urbanczik-Senn plasticity¶
For details and troubleshooting see How to run Jupyter notebooks.
This script demonstrates the learning in a compartmental neuron where the dendritic synapses adapt their weight according to the plasticity rule by Urbanczik and Senn [1]. In this simple setup, a spike pattern of 200 poisson spike trains is repeatedly presented to a neuron that is composed of one somatic and one dendritic compartment. At the same time, the somatic conductances are activated to produce a time-varying matching potential. After the learning, this signal is then reproduced by the membrane potential of the neuron. This script produces Fig. 1B in [1] but uses standard units instead of the unitless quantities used in the paper.
References¶
import numpy as np
from matplotlib import pyplot as plt
import nest
def g_inh(amplitude, t_start, t_end):
"""
returns weights for the spike generator that drives the inhibitory
somatic conductance.
"""
return lambda t: np.piecewise(t, [(t >= t_start) & (t < t_end)],
[amplitude, 0.0])
def g_exc(amplitude, freq, offset, t_start, t_end):
"""
returns weights for the spike generator that drives the excitatory
somatic conductance.
"""
return lambda t: np.piecewise(t, [(t >= t_start) & (t < t_end)],
[lambda t: amplitude * np.sin(freq * t) + offset, 0.0])
def matching_potential(g_E, g_I, nrn_params):
"""
returns the matching potential as a function of the somatic conductances.
"""
E_E = nrn_params['soma']['E_ex']
E_I = nrn_params['soma']['E_in']
return (g_E * E_E + g_I * E_I) / (g_E + g_I)
def V_w_star(V_w, nrn_params):
"""
returns the dendritic prediction of the somatic membrane potential.
"""
g_D = nrn_params['g_sp']
g_L = nrn_params['soma']['g_L']
E_L = nrn_params['soma']['E_L']
return (g_L * E_L + g_D * V_w) / (g_L + g_D)
def phi(U, nrn_params):
"""
rate function of the soma
"""
phi_max = nrn_params['phi_max']
k = nrn_params['rate_slope']
beta = nrn_params['beta']
theta = nrn_params['theta']
return phi_max / (1.0 + k * np.exp(beta * (theta - U)))
def h(U, nrn_params):
"""
derivative of the rate function phi
"""
k = nrn_params['rate_slope']
beta = nrn_params['beta']
theta = nrn_params['theta']
return 15.0 * beta / (1.0 + np.exp(-beta * (theta - U)) / k)
# simulation params
n_pattern_rep = 100 # number of repetitions of the spike pattern
pattern_duration = 200.0
t_start = 2.0 * pattern_duration
t_end = n_pattern_rep * pattern_duration + t_start
simulation_time = t_end + 2.0 * pattern_duration
n_rep_total = int(np.around(simulation_time / pattern_duration))
resolution = 0.1
nest.resolution = resolution
# neuron parameters
nrn_model = 'pp_cond_exp_mc_urbanczik'
nrn_params = {
't_ref': 3.0, # refractory period
'g_sp': 600.0, # soma-to-dendritic coupling conductance
'soma': {
'V_m': -70.0, # initial value of V_m
'C_m': 300.0, # capacitance of membrane
'E_L': -70.0, # resting potential
'g_L': 30.0, # somatic leak conductance
'E_ex': 0.0, # resting potential for exc input
'E_in': -75.0, # resting potential for inh input
'tau_syn_ex': 3.0, # time constant of exc conductance
'tau_syn_in': 3.0, # time constant of inh conductance
},
'dendritic': {
'V_m': -70.0, # initial value of V_m
'C_m': 300.0, # capacitance of membrane
'E_L': -70.0, # resting potential
'g_L': 30.0, # dendritic leak conductance
'tau_syn_ex': 3.0, # time constant of exc input current
'tau_syn_in': 3.0, # time constant of inh input current
},
# parameters of rate function
'phi_max': 0.15, # max rate
'rate_slope': 0.5, # called 'k' in the paper
'beta': 1.0 / 3.0,
'theta': -55.0,
}
# synapse params
syns = nest.GetDefaults(nrn_model)['receptor_types']
init_w = 0.3 * nrn_params['dendritic']['C_m']
syn_params = {
'synapse_model': 'urbanczik_synapse_wr',
'receptor_type': syns['dendritic_exc'],
'tau_Delta': 100.0, # time constant of low pass filtering of the weight change
'eta': 0.17, # learning rate
'weight': init_w,
'Wmax': 4.5 * nrn_params['dendritic']['C_m'],
'delay': resolution,
}
"""
# in case you want to use the unitless quantities as in [1]:
# neuron params:
nrn_model = 'pp_cond_exp_mc_urbanczik'
nrn_params = {
't_ref': 3.0,
'g_sp': 2.0,
'soma': {
'V_m': 0.0,
'C_m': 1.0,
'E_L': 0.0,
'g_L': 0.1,
'E_ex': 14.0 / 3.0,
'E_in': -1.0 / 3.0,
'tau_syn_ex': 3.0,
'tau_syn_in': 3.0,
},
'dendritic': {
'V_m': 0.0,
'C_m': 1.0,
'E_L': 0.0,
'g_L': 0.1,
'tau_syn_ex': 3.0,
'tau_syn_in': 3.0,
},
# parameters of rate function
'phi_max': 0.15,
'rate_slope': 0.5,
'beta': 5.0,
'theta': 1.0,
}
# synapse params:
syns = nest.GetDefaults(nrn_model)['receptor_types']
init_w = 0.2*nrn_params['dendritic']['g_L']
syn_params = {
'synapse_model': 'urbanczik_synapse_wr',
'receptor_type': syns['dendritic_exc'],
'tau_Delta': 100.0,
'eta': 0.0003 / (15.0*15.0*nrn_params['dendritic']['C_m']),
'weight': init_w,
'Wmax': 3.0*nrn_params['dendritic']['g_L'],
'delay': resolution,
}
"""
# somatic input
ampl_exc = 0.016 * nrn_params['dendritic']['C_m']
offset = 0.018 * nrn_params['dendritic']['C_m']
ampl_inh = 0.06 * nrn_params['dendritic']['C_m']
freq = 2.0 / pattern_duration
soma_exc_inp = g_exc(ampl_exc, 2.0 * np.pi * freq, offset, t_start, t_end)
soma_inh_inp = g_inh(ampl_inh, t_start, t_end)
# dendritic input
# create spike pattern by recording the spikes of a simulation of n_pg
# poisson generators. The recorded spike times are then given to spike
# generators.
n_pg = 200 # number of poisson generators
p_rate = 10.0 # rate in Hz
pgs = nest.Create('poisson_generator', n=n_pg, params={'rate': p_rate})
prrt_nrns_pg = nest.Create('parrot_neuron', n_pg)
nest.Connect(pgs, prrt_nrns_pg, {'rule': 'one_to_one'})
sr = nest.Create('spike_recorder', n_pg)
nest.Connect(prrt_nrns_pg, sr, {'rule': 'one_to_one'})
nest.Simulate(pattern_duration)
t_srs = [ssr.get('events', 'times') for ssr in sr]
nest.ResetKernel()
nest.resolution = resolution
"""
neuron and devices
"""
nrn = nest.Create(nrn_model, params=nrn_params)
# poisson generators are connected to parrot neurons which are
# connected to the mc neuron
prrt_nrns = nest.Create('parrot_neuron', n_pg)
# excitatory input to the soma
spike_times_soma_inp = np.arange(resolution, simulation_time, resolution)
sg_soma_exc = nest.Create('spike_generator',
params={'spike_times': spike_times_soma_inp,
'spike_weights': soma_exc_inp(spike_times_soma_inp)})
# inhibitory input to the soma
sg_soma_inh = nest.Create('spike_generator',
params={'spike_times': spike_times_soma_inp,
'spike_weights': soma_inh_inp(spike_times_soma_inp)})
# excitatory input to the dendrite
sg_prox = nest.Create('spike_generator', n=n_pg)
# for recording all parameters of the Urbanczik neuron
rqs = nest.GetDefaults(nrn_model)['recordables']
mm = nest.Create('multimeter', params={'record_from': rqs, 'interval': 0.1})
# for recoding the synaptic weights of the Urbanczik synapses
wr = nest.Create('weight_recorder')
# for recording the spiking of the soma
sr_soma = nest.Create('spike_recorder')
# create connections
nest.Connect(sg_prox, prrt_nrns, {'rule': 'one_to_one'})
nest.CopyModel('urbanczik_synapse', 'urbanczik_synapse_wr',
{'weight_recorder': wr[0]})
nest.Connect(prrt_nrns, nrn, syn_spec=syn_params)
nest.Connect(mm, nrn, syn_spec={'delay': 0.1})
nest.Connect(sg_soma_exc, nrn,
syn_spec={'receptor_type': syns['soma_exc'], 'weight': 10.0 * resolution, 'delay': resolution})
nest.Connect(sg_soma_inh, nrn,
syn_spec={'receptor_type': syns['soma_inh'], 'weight': 10.0 * resolution, 'delay': resolution})
nest.Connect(nrn, sr_soma)
# simulation divided into intervals of the pattern duration
for i in np.arange(n_rep_total):
# Set the spike times of the pattern for each spike generator
for (sg, t_sp) in zip(sg_prox, t_srs):
nest.SetStatus(
sg, {'spike_times': np.array(t_sp) + i * pattern_duration})
nest.Simulate(pattern_duration)
# read out devices
# multimeter
mm_events = mm.events
t = mm_events['times']
V_s = mm_events['V_m.s']
V_d = mm_events['V_m.p']
V_d_star = V_w_star(V_d, nrn_params)
g_in = mm_events['g_in.s']
g_ex = mm_events['g_ex.s']
I_ex = mm_events['I_ex.p']
I_in = mm_events['I_in.p']
U_M = matching_potential(g_ex, g_in, nrn_params)
# weight recorder
wr_events = wr.events
senders = wr_events['senders']
targets = wr_events['targets']
weights = wr_events['weights']
times = wr_events['times']
# spike recorder
spike_times_soma = sr_soma.get('events', 'times')
# plot results
fs = 22
lw = 2.5
fig1, (axA, axB, axC, axD) = plt.subplots(4, 1, sharex=True)
# membrane potentials and matching potential
axA.plot(t, V_s, lw=lw, label=r'$U$ (soma)', color='darkblue')
axA.plot(t, V_d, lw=lw, label=r'$V_W$ (dendrit)', color='deepskyblue')
axA.plot(t, V_d_star, lw=lw, label=r'$V_W^\ast$ (dendrit)', color='b', ls='--')
axA.plot(t, U_M, lw=lw, label=r'$U_M$ (soma)', color='r', ls='-')
axA.set_ylabel('membrane pot [mV]', fontsize=fs)
axA.legend(fontsize=fs)
# somatic conductances
axB.plot(t, g_in, lw=lw, label=r'$g_I$', color='r')
axB.plot(t, g_ex, lw=lw, label=r'$g_E$', color='coral')
axB.set_ylabel('somatic cond', fontsize=fs)
axB.legend(fontsize=fs)
# dendritic currents
axC.plot(t, I_ex, lw=lw, label=r'$I_ex$', color='r')
axC.plot(t, I_in, lw=lw, label=r'$I_in$', color='coral')
axC.set_ylabel('dend current', fontsize=fs)
axC.legend(fontsize=fs)
# rates
axD.plot(t, phi(V_s, nrn_params), lw=lw, label=r'$\phi(U)$', color='darkblue')
axD.plot(t, phi(V_d, nrn_params), lw=lw,
label=r'$\phi(V_W)$', color='deepskyblue')
axD.plot(t, phi(V_d_star, nrn_params), lw=lw,
label=r'$\phi(V_W^\ast)$', color='b', ls='--')
axD.plot(t, h(V_d_star, nrn_params), lw=lw,
label=r'$h(V_W^\ast)$', color='g', ls='--')
axD.plot(t, phi(V_s, nrn_params) - phi(V_d_star, nrn_params), lw=lw,
label=r'$\phi(U) - \phi(V_W^\ast)$', color='r', ls='-')
axD.plot(spike_times_soma, 0.0 * np.ones(len(spike_times_soma)),
's', color='k', markersize=2)
axD.legend(fontsize=fs)
# synaptic weights
fig2, axA = plt.subplots(1, 1)
for i in np.arange(2, 200, 10):
index = np.intersect1d(np.where(senders == i), np.where(targets == 1))
if not len(index) == 0:
axA.step(times[index], weights[index], label='pg_{}'.format(i - 2),
lw=lw)
axA.set_title('Synaptic weights of Urbanczik synapses')
axA.set_xlabel('time [ms]', fontsize=fs)
axA.set_ylabel('weight', fontsize=fs)
axA.legend(fontsize=fs - 4)
plt.show()
Total running time of the script: ( 0 minutes 0.000 seconds)