Supplementary Figure 12: Many Types (32 Neuron Types)

Categories: Neural Activity, Simulation, GNN Training, Many Types

Authors: Cédric Allier, Michael Innerberger, Stephan Saalfeld

This script reproduces the panels of the paper’s Supplementary Figure 12. It tests the GNN with 32 different neuron types (update functions).

Simulation parameters:

The simulation follows Equation 2 from the paper:

\[\frac{dx_i}{dt} = -\frac{x_i}{\tau_i} + s_i \cdot \tanh(x_i) + g_i \cdot \sum_j W_{ij} \cdot \psi(x_j)\]
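
For intuition, a minimal forward-Euler sketch of Equation 2 (this is not the repository’s data_generate; the time step, parameter ranges, and the form of \(\psi\) are illustrative assumptions):

import numpy as np

rng = np.random.default_rng(0)
n, dt = 1000, 0.1                               # neurons, Euler step (illustrative)
tau = rng.uniform(0.5, 2.0, n)                  # time constants tau_i (assumed range)
s = rng.uniform(0.5, 2.0, n)                    # self-interaction gains s_i (assumed range)
g = 10.0                                        # coupling gain, g_i = 10 as in the evaluation below
W = rng.normal(0.0, 1.0 / np.sqrt(n), (n, n))   # dense random connectivity
psi = np.tanh                                   # transfer function (assumed form)

x = rng.normal(0.0, 1.0, n)                     # initial activity
trace = np.empty((1000, n))
for t in range(1000):
    dx = -x / tau + s * np.tanh(x) + g * (W @ psi(x))
    x = x + dt * dx
    trace[t] = x                                # store one frame per step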

Expected classification accuracy: 0.99

Configuration and Setup

import os
import glob

# project helpers assumed importable from the repository's modules:
# add_pre_folder, NeuralGraphConfig, set_device, data_generate, data_train, data_plot

print()
print("=" * 80)
print("Supplementary Figure 12: 1000 neurons, 32 types, dense connectivity")
print("=" * 80)

device = None           # resolved from the config below
best_model = ''         # no pre-trained checkpoint selected
config_file_ = 'signal_fig_supp_12'

print()
config_root = "./config"
config_file, pre_folder = add_pre_folder(config_file_)

# load config
config = NeuralGraphConfig.from_yaml(f"{config_root}/{config_file}.yaml")
config.config_file = config_file
config.dataset = config_file

if device is None:
    device = set_device(config.training.device)

log_dir = f'./log/{config_file}'
graphs_dir = f'./graphs_data/{config_file}'

Step 1: Generate Data

Generate synthetic neural activity data using the PDE_N2 model with 32 neuron types. This tests the GNN’s ability to learn many distinct update functions.

Outputs:

  • Sample time series
  • True connectivity matrix \(W_{ij}\)
# STEP 1: GENERATE
print()
print("-" * 80)
print("STEP 1: GENERATE - Simulating neural activity (32 neuron types)")
print("-" * 80)

# Check if data already exists
data_file = f'{graphs_dir}/x_list_0.npy'
if os.path.exists(data_file):
    print(f"data already exists at {graphs_dir}/")
    print("skipping simulation, regenerating figures...")
    data_generate(
        config,
        device=device,
        visualize=False,
        run_vizualized=0,
        style="color",
        alpha=1,
        erase=False,
        bSave=True,
        step=2,
        regenerate_plots_only=True,
    )
else:
    print(f"simulating {config.simulation.n_neurons} neurons, {config.simulation.n_neuron_types} types")
    print(f"generating {config.simulation.n_frames} time frames")
    print(f"output: {graphs_dir}/")
    print()
    data_generate(
        config,
        device=device,
        visualize=False,
        run_vizualized=0,
        style="color",
        alpha=1,
        erase=False,
        bSave=True,
        step=2,
    )

Sample time series taken from the activity data (32 neuron types).

True connectivity \(W_{ij}\). The inset shows a 20×20 subset of the weights.
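
To inspect the generated data directly, a minimal sketch (the array layout of x_list_0.npy is an assumption; adjust the indexing to the actual shape printed):

import numpy as np
import matplotlib.pyplot as plt

x_list = np.load(f'{graphs_dir}/x_list_0.npy', allow_pickle=True)
print('activity array shape:', x_list.shape)

# plot a few example traces, assuming frames along the first axis
traces = np.asarray(x_list, dtype=float).reshape(x_list.shape[0], -1)
for i in range(5):
    plt.plot(traces[:, i], lw=0.5)
plt.xlabel('frame')
plt.ylabel('activity')
plt.show()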

Step 2: Train GNN

Train the GNN to learn connectivity \(W\), latent embeddings \(\mathbf{a}_i\), and functions \(\phi^*, \psi^*\). The GNN must learn to distinguish 32 different update functions.

The GNN optimizes the update rule (Equation 3 from the paper):

\[\hat{\dot{x}}_i = \phi^*(\mathbf{a}_i, x_i) + \sum_j W_{ij} \psi^*(x_j)\]
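
A minimal PyTorch sketch of this update rule (this is not the repository’s GNN class; the module structure, MLP sizes, and embedding dimension are illustrative assumptions):

import torch
import torch.nn as nn

class UpdateRule(nn.Module):
    """Sketch of Equation 3: x_dot_i = phi*(a_i, x_i) + sum_j W_ij psi*(x_j)."""
    def __init__(self, n_neurons, embed_dim=2, hidden=64):
        super().__init__()
        self.a = nn.Parameter(torch.randn(n_neurons, embed_dim))                     # latent embeddings a_i
        self.W = nn.Parameter(torch.randn(n_neurons, n_neurons) / n_neurons ** 0.5)  # learned connectivity
        self.phi = nn.Sequential(nn.Linear(embed_dim + 1, hidden), nn.Tanh(), nn.Linear(hidden, 1))
        self.psi = nn.Sequential(nn.Linear(1, hidden), nn.Tanh(), nn.Linear(hidden, 1))

    def forward(self, x):                                   # x: (n_neurons,)
        xi = x.unsqueeze(-1)                                # (n, 1)
        local = self.phi(torch.cat([self.a, xi], dim=-1))   # phi*(a_i, x_i), (n, 1)
        coupling = self.W @ self.psi(xi)                    # sum_j W_ij psi*(x_j), (n, 1)
        return (local + coupling).squeeze(-1)               # predicted x_dot, (n,)

Training would regress the predicted \(\hat{\dot{x}}_i\) against finite differences of the simulated activity, so that \(W\), \(\mathbf{a}_i\), \(\phi^*\), and \(\psi^*\) are learned jointly.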

# STEP 2: TRAIN
print()
print("-" * 80)
print("STEP 2: TRAIN - Training GNN to learn 32 neuron types")
print("-" * 80)

# Check if a trained model already exists (any .pt file in the models folder)
model_files = glob.glob(f'{log_dir}/models/*.pt')
if model_files:
    print(f"trained model already exists at {log_dir}/models/")
    print("skipping training (delete models folder to retrain)")
else:
    print(f"training for {config.training.n_epochs} epochs, {config.training.n_runs} run(s)")
    print(f"learning: connectivity W, latent vectors a_i for 32 types, functions phi* and psi*")
    print(f"models: {log_dir}/models/")
    print(f"training plots: {log_dir}/tmp_training")
    print(f"tensorboard: tensorboard --logdir {log_dir}/")
    print()
    data_train(
        config=config,
        erase=False,
        best_model=best_model,
        style='color',
        device=device
    )

Step 3: GNN Evaluation

This step generates the figure panels matching Supplementary Figure 12 from the paper.

Figure panels:

  • Learned connectivity matrix
  • Comparison of learned vs true connectivity
  • Learned latent vectors \(\mathbf{a}_i\) (32 clusters expected)
  • Learned update functions \(\phi^*(\mathbf{a}_i, x)\) (32 distinct functions)
  • Learned transfer function \(\psi^*(x)\)
# STEP 3: GNN EVALUATION
print()
print("-" * 80)
print("STEP 3: GNN EVALUATION - Generating Supplementary Figure 12 panels")
print("-" * 80)
print(f"learned connectivity matrix")
print(f"W learned vs true (R^2, slope)")
print(f"latent vectors a_i (32 clusters)")
print(f"update functions phi*(a_i, x) - 32 distinct functions")
print(f"transfer function psi*(x)")
print(f"output: {log_dir}/results/")
print()
folder_name = './log/' + pre_folder + '/tmp_results/'
os.makedirs(folder_name, exist_ok=True)
data_plot(config=config, config_file=config_file, epoch_list=['best'], style='color', extended='plots', device=device, apply_weight_correction=True, plot_eigen_analysis=False)

Supplementary Figure 12: GNN Evaluation Results

Learned connectivity.

Comparison of learned and true connectivity (given \(g_i = 10\)).
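
The reported \(R^2\) and slope can come from a linear fit of learned versus true weights; a minimal sketch (Wtrue and Wlearned below are hypothetical stand-ins for the matrices loaded from graphs_dir and log_dir):

import numpy as np
from scipy import stats

# hypothetical stand-ins for the true and learned connectivity matrices
rng = np.random.default_rng(0)
Wtrue = rng.normal(size=(1000, 1000))
Wlearned = 0.1 * Wtrue + rng.normal(scale=0.01, size=(1000, 1000))  # e.g., scaled by the gain g_i

slope, intercept, r, _, _ = stats.linregress(Wtrue.ravel(), Wlearned.ravel())
print(f"slope = {slope:.3f}, R^2 = {r ** 2:.3f}")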

Learned latent vectors \(\mathbf{a}_i\) of all neurons. 32 clusters are expected (one per neuron type).
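
The expected classification accuracy of 0.99 quoted above can be estimated by clustering the learned embeddings and comparing cluster assignments to the true types; a minimal sketch using k-means (the embeddings and labels below are hypothetical stand-ins, and the adjusted Rand index is used because k-means labels are only defined up to permutation):

import numpy as np
from sklearn.cluster import KMeans
from sklearn.metrics import adjusted_rand_score

# hypothetical stand-ins: 1000 two-dimensional embeddings in 32 well-separated clusters
rng = np.random.default_rng(0)
true_types = rng.integers(0, 32, 1000)
centers = rng.normal(scale=10.0, size=(32, 2))
a_i = centers[true_types] + rng.normal(scale=0.3, size=(1000, 2))

pred = KMeans(n_clusters=32, n_init=10, random_state=0).fit_predict(a_i)
print('adjusted Rand index:', adjusted_rand_score(true_types, pred))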

Learned update functions \(\phi^*(\mathbf{a}_i, x)\). The plot shows 1000 overlaid curves representing 32 distinct update functions. Colors indicate true neuron types. The true functions are overlaid in light gray.

Learned transfer function \(\psi^*(x)\), normalized to a maximum value of 1. The true function is overlaid in light gray.