Supplementary Figure 14: neuron-dependent transfer functions

Authors: Cédric Allier, Michael Innerberger, Stephan Saalfeld

This script reproduces the panels of the paper's Supplementary Figure 14: training with neuron-neuron-dependent transfer functions of the form \(\psi(a_i, a_j, x_j)\).

Simulation parameters:

The simulation follows an extended version of Equation 2:

\[\frac{dx_i}{dt} = -\frac{x_i}{\tau_i} + s_i \cdot \tanh(x_i) + g_i \cdot \sum_j W_{ij} \cdot \psi_{ij}(x_j)\]

where the transfer function depends on both sender \(j\) and receiver \(i\):

\[\psi_{ij}(x_j) = \tanh\left(\frac{x_j}{\gamma_i}\right) - \theta_j \cdot x_j\]
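
For concreteness, here is a minimal NumPy sketch of one Euler step of this system. All parameter values, shapes, and the step size are illustrative assumptions rather than the package's actual implementation; only \(\gamma_i \in \{1, 2, 4, 8\}\) and \(g_i = 10\) are taken from the setup described later in this script.

import numpy as np

rng = np.random.default_rng(0)
n = 1000                                         # number of neurons (as in the figure)
tau = np.full(n, 1.0)                            # illustrative time constants tau_i
s = np.full(n, 1.0)                              # illustrative self-coupling s_i
g = np.full(n, 10.0)                             # gain g_i = 10 (see Step 3 caption)
gamma = rng.choice([1.0, 2.0, 4.0, 8.0], n)      # gamma_i per neuron type (Step 1)
theta = rng.uniform(0.0, 0.1, n)                 # illustrative sender terms theta_j
W = rng.normal(0.0, 1.0 / np.sqrt(n), (n, n))    # illustrative connectivity W_ij

def psi(x):
    # neuron-neuron dependent transfer function psi_ij(x_j):
    # rows index the receiver i (via gamma_i), columns the sender j (via theta_j)
    return np.tanh(x[None, :] / gamma[:, None]) - theta[None, :] * x[None, :]

def euler_step(x, dt=0.01):
    # dx_i/dt = -x_i/tau_i + s_i*tanh(x_i) + g_i * sum_j W_ij * psi_ij(x_j)
    dx = -x / tau + s * np.tanh(x) + g * (W * psi(x)).sum(axis=1)
    return x + dt * dx

x = rng.normal(0.0, 0.1, n)
for _ in range(100):
    x = euler_step(x)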

The GNN jointly optimizes the shared MLPs \(\phi^*\) and \(\psi^*\) and the latent vectors \(a_i\) to identify the neuron-neuron-dependent transfer functions:

\[\hat{\dot{x}}_i = \phi^*(a_i, x_i) + \sum_j W_{ij} \cdot \psi^*(a_i, a_j, x_j)\]
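
A compact PyTorch sketch of this predictor follows; the embedding dimension and MLP sizes are hypothetical, and the dense pairwise evaluation is for clarity only (the actual GNN restricts the sum to edges of the graph via message passing).

import torch
import torch.nn as nn

n, d = 1000, 2                                   # neurons, latent dimension (assumed)
a = nn.Parameter(torch.randn(n, d))              # learnable latent vectors a_i
W = nn.Parameter(torch.randn(n, n) / n**0.5)     # learnable connectivity W_ij
phi = nn.Sequential(nn.Linear(d + 1, 64), nn.ReLU(), nn.Linear(64, 1))      # phi*(a_i, x_i)
psi = nn.Sequential(nn.Linear(2 * d + 1, 64), nn.ReLU(), nn.Linear(64, 1))  # psi*(a_i, a_j, x_j)

def predict_xdot(x):
    # update term phi*(a_i, x_i)
    update = phi(torch.cat([a, x[:, None]], dim=1)).squeeze(1)
    # pairwise messages psi*(a_i, a_j, x_j) for every receiver i / sender j pair
    ai = a[:, None, :].expand(n, n, d)           # receiver embedding, rows
    aj = a[None, :, :].expand(n, n, d)           # sender embedding, columns
    xj = x[None, :, None].expand(n, n, 1)        # sender activity
    msg = psi(torch.cat([ai, aj, xj], dim=2)).squeeze(2)
    return update + (W * msg).sum(dim=1)         # predicted dx_i/dt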

Configuration and Setup

# os is needed below for path checks; the remaining helpers (add_pre_folder,
# NeuralGraphConfig, set_device, data_generate, data_train, data_plot) are
# assumed to be imported from the NeuralGraph package
import os

print()
print("=" * 80)
print("Supplementary Figure 14: 1000 neurons, 4 types, neuron-dependent transfer functions")
print("=" * 80)

device = None
best_model = ''
config_file_ = 'signal_fig_supp_14'

print()
config_root = "./config"
config_file, pre_folder = add_pre_folder(config_file_)

# load the config that defines the simulation and training parameters
config = NeuralGraphConfig.from_yaml(f"{config_root}/{config_file}.yaml")
config.config_file = config_file
config.dataset = config_file

if device is None:
    device = set_device(config.training.device)

log_dir = f'./log/{config_file}'
graphs_dir = f'./graphs_data/{config_file}'

Step 1: Generate Data

Generate synthetic neural activity data using the PDE_N5 model with neuron-dependent transfer functions. Each pair of neuron types has a different transfer function, determined by both the sender (\(a_j\)) and receiver (\(a_i\)) embeddings.

Outputs:

  • Sample time series
  • True connectivity matrix \(W_{ij}\)
# STEP 1: GENERATE
print()
print("-" * 80)
print("STEP 1: GENERATE - Simulating neural activity (neuron-dependent transfer functions)")
print("-" * 80)

# Check if data already exists
data_file = f'{graphs_dir}/x_list_0.npy'
if os.path.exists(data_file):
    print(f"data already exists at {graphs_dir}/")
    print("skipping simulation, regenerating figures...")
    data_generate(
        config,
        device=device,
        visualize=False,
        run_vizualized=0,
        style="color",
        alpha=1,
        erase=False,
        bSave=True,
        step=2,
        regenerate_plots_only=True,
    )
else:
    print(f"simulating {config.simulation.n_neurons} neurons, {config.simulation.n_neuron_types} types")
    print(f"generating {config.simulation.n_frames} time frames")
    print(f"transfer function gamma_i = [1, 2, 4, 8]")
    print(f"output: {graphs_dir}/")
    print()
    data_generate(
        config,
        device=device,
        visualize=False,
        run_vizualized=0,
        style="color",
        alpha=1,
        erase=False,
        bSave=True,
        step=2,
    )

Sample time series taken from the activity data (neuron-dependent transfer functions).

True connectivity \(W_{ij}\). The inset shows a 20×20 subset of the weights.
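
To inspect the generated data directly, something like the following works. It assumes the array saved at {graphs_dir}/x_list_0.npy (the path checked in Step 1) is laid out as frames × neurons × features; the feature column holding the activity is a guess, so adjust the indexing to the actual format.

import numpy as np
import matplotlib.pyplot as plt

x_list = np.load(f'{graphs_dir}/x_list_0.npy')   # path checked in Step 1
print('array shape:', x_list.shape)

# plot a few activity traces over time (axis order and feature column assumed)
for i in range(10):
    plt.plot(x_list[:, i, -1], lw=0.5)
plt.xlabel('frame')
plt.ylabel('activity')
plt.show()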

Step 2: Train GNN

Train the GNN to learn connectivity \(W\), latent embeddings \(\mathbf{a}_i\), and functions \(\phi^*, \psi^*\). The GNN must learn neuron-neuron dependent transfer functions \(\psi^*(\mathbf{a}_i, \mathbf{a}_j, x_j)\).
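
Conceptually, the joint optimization resembles the following schematic, which reuses the names from the predictor sketch above; the finite-difference targets, epoch count, and optimizer settings are illustrative assumptions, not the package's training loop.

import torch

# illustrative targets: finite-difference estimates of dx/dt from the traces,
# assuming x_t holds the simulated activity with shape (n_frames, n_neurons)
x_t = torch.randn(1000, n)                       # stand-in for the real traces
dt = 0.01                                        # assumed integration step
xdot_t = (x_t[1:] - x_t[:-1]) / dt

params = [a, W] + list(phi.parameters()) + list(psi.parameters())
opt = torch.optim.Adam(params, lr=1e-3)

for epoch in range(10):                          # epoch count illustrative
    for t in torch.randperm(len(xdot_t)):
        opt.zero_grad()
        pred = predict_xdot(x_t[t])              # from the predictor sketch above
        loss = ((pred - xdot_t[t]) ** 2).mean()  # MSE on the time derivative
        loss.backward()
        opt.step()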

# STEP 2: TRAIN
print()
print("-" * 80)
print("STEP 2: TRAIN - Training GNN to learn neuron-dependent transfer functions")
print("-" * 80)

# Check if trained model already exists (any .pt file in models folder)
import glob
model_files = glob.glob(f'{log_dir}/models/*.pt')
if model_files:
    print(f"trained model already exists at {log_dir}/models/")
    print("skipping training (delete models folder to retrain)")
else:
    print(f"training for {config.training.n_epochs} epochs, {config.training.n_runs} run(s)")
    print(f"learning: connectivity W, latent vectors a_i, neuron-dependent psi*(a_i, a_j, x_j)")
    print(f"models: {log_dir}/models/")
    print(f"training plots: {log_dir}/tmp_training")
    print(f"tensorboard: tensorboard --logdir {log_dir}/")
    print()
    data_train(
        config=config,
        erase=False,
        best_model=best_model,
        style='color',
        device=device
    )

Step 3: GNN Evaluation

Figures matching Supplementary Figure 14 from the paper.

Figure panels:

    1. Activity time series used for GNN training (\(10^5\) time points)
    2. Sample of 10 time series taken from (a)
    3. True connectivity \(W_{ij}\)
    4. Learned connectivity
    5. Comparison between learned and true connectivity
    6. Learned latent vectors \(a_i\)
    7. Learned update functions \(\phi^*(\mathbf{a}, x)\)
    8. Learned transfer functions \(\psi^*(a_i, a_j, x)\) (colors indicate true neuron types; true functions overlaid in light gray)
# STEP 3: GNN EVALUATION
print()
print("-" * 80)
print("STEP 3: GNN EVALUATION - Generating Supplementary Figure 14 panels")
print("-" * 80)
print(f"learned connectivity matrix")
print(f"W learned vs true (R^2, slope)")
print(f"latent vectors a_i (4 clusters)")
print(f"update functions phi*(a_i, x)")
print(f"transfer functions psi*(a_i, a_j, x) - neuron-neuron dependent")
print(f"output: {log_dir}/results/")
print()
folder_name = './log/' + pre_folder + '/tmp_results/'
os.makedirs(folder_name, exist_ok=True)
data_plot(config=config, config_file=config_file, epoch_list=['best'], style='color', extended='plots', device=device, apply_weight_correction=True, plot_eigen_analysis=False)

Supplementary Figure 14: GNN Evaluation Results

Learned connectivity.

Comparison of learned and true connectivity (given \(g_i = 10\)). Expected: \(R^2 = 0.99\), slope = 0.99.
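
A minimal sketch of how these numbers can be computed, assuming the true and learned matrices are available as NumPy arrays; the names W_true and W_learned are hypothetical.

import numpy as np
from scipy.stats import linregress

# W_true, W_learned: (n, n) connectivity matrices (hypothetical names)
fit = linregress(W_true.flatten(), W_learned.flatten())
print(f"R^2 = {fit.rvalue ** 2:.2f}, slope = {fit.slope:.2f}")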

Learned latent vectors \(a_i\) of all neurons.

Learned update functions \(\phi^*(a_i, x)\). The plot shows 1000 overlaid curves. Colors indicate true neuron types. True functions are overlaid in light gray.

Learned transfer functions \(\psi^*(a_i, a_j, x_j)\), shown as a 2×2 montage: each panel corresponds to a receiving neuron type (border color) and shows curves for all sending neuron types (line colors). True functions are overlaid in gray.