Supplementary Figure 8: Sparse connectivity (5% to 100%)

Categories: Neural Activity, Simulation, GNN Training

Authors: Cédric Allier, Michael Innerberger, Stephan Saalfeld

This notebook reproduces the panels of the paper's Supplementary Figure 8, which evaluates the GNN on connectivity matrices with varying sparsity levels. For each sparsity level it displays activity time series, the true connectivity matrix, the comparison between learned and true connectivity matrices, and the learned update functions \(\phi^*\).

Simulation parameters (constant across all experiments):

Variable: Connectivity sparsity

Config                 Connectivity (% nonzero weights)
signal_fig_supp_8      5%
signal_fig_supp_8_3    10%
signal_fig_supp_8_2    20%
signal_fig_supp_8_1    50%
signal_fig_2           100%
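To make the percentages concrete, here is a minimal sketch. It assumes that a level of p% means each entry of \(W_{ij}\) is kept independently with probability p and set to zero otherwise. This is not necessarily how the simulator builds its matrices. The Gaussian weight distribution, the \(1/\sqrt{N}\) scaling, and the helper names `sparse_connectivity` and `connectivity_fraction` are hypothetical.

```python
import numpy as np


def sparse_connectivity(n_neurons: int, fraction: float, seed: int = 0) -> np.ndarray:
    """Hypothetical helper: random W with approximately `fraction` nonzero entries."""
    rng = np.random.default_rng(seed)
    # Gaussian weights scaled by 1/sqrt(N) (an assumption, common for random networks)
    W = rng.normal(0.0, 1.0 / np.sqrt(n_neurons), size=(n_neurons, n_neurons))
    # Keep each entry independently with probability `fraction`
    # (rng.random draws from [0, 1), so fraction=1.0 keeps every entry)
    mask = rng.random((n_neurons, n_neurons)) < fraction
    return W * mask


def connectivity_fraction(W: np.ndarray) -> float:
    """Fraction of nonzero entries, i.e. the connectivity level of W."""
    return np.count_nonzero(W) / W.size


for label, frac in [("5%", 0.05), ("10%", 0.10), ("20%", 0.20), ("50%", 0.50), ("100%", 1.0)]:
    W = sparse_connectivity(1000, frac)
    print(f"{label:>5}: measured fraction of nonzero weights = {connectivity_fraction(W):.3f}")
```

The same `connectivity_fraction` check could be applied to a simulated connectivity matrix to confirm that it has the intended level.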

Configuration

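import glob
import os

# add_pre_folder, NeuralGraphConfig, set_device, data_generate, data_train and
# data_plot come from the project's codebase and are assumed to be imported elsewhere.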

print()
print("=" * 80)
print("Supplementary Figure 8: Effect of Connectivity Sparsity")
print("=" * 80)

# All configs to process (config_name, sparsity)
config_list = [
    ('signal_fig_supp_8', '5%'),
    ('signal_fig_supp_8_3', '10%'),
    ('signal_fig_supp_8_2', '20%'),
    ('signal_fig_supp_8_1', '50%'),
    ('signal_fig_2', '100%'),
]

device = []
best_model = ''
config_root = "./config"

Steps 1-3: Generate, Train, and Plot for all configs

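Loop over all sparsity levels: generate data, train the GNN, and generate plots. If simulated data already exist, the simulation is skipped and only its figures are regenerated. If a trained model is found, training is skipped. Plots are always produced.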

for config_file_, sparsity in config_list:
    print()
    print("=" * 80)
    print(f"Processing: {config_file_} ({sparsity} sparsity)")
    print("=" * 80)

    config_file, pre_folder = add_pre_folder(config_file_)

    # Load config
    config = NeuralGraphConfig.from_yaml(f"{config_root}/{config_file}.yaml")
    config.config_file = config_file
    config.dataset = config_file

    if device == []:
        device = set_device(config.training.device)

    log_dir = f'./log/{config_file}'
    graphs_dir = f'./graphs_data/{config_file}'

    # STEP 1: GENERATE
    print()
    print("-" * 80)
    print("STEP 1: GENERATE - Simulating neural activity")
    print("-" * 80)

    data_file = f'{graphs_dir}/x_list_0.npy'
    if os.path.exists(data_file):
        print(f"data already exists at {graphs_dir}/")
        print("skipping simulation, regenerating figures...")
        data_generate(
            config,
            device=device,
            visualize=False,
            run_vizualized=0,
            style="color",
            alpha=1,
            erase=False,
            bSave=True,
            step=2,
            regenerate_plots_only=True,
        )
    else:
        print(f"simulating {config.simulation.n_neurons} neurons, {config.simulation.n_frames} frames")
        print(f"output: {graphs_dir}/")
        data_generate(
            config,
            device=device,
            visualize=False,
            run_vizualized=0,
            style="color",
            alpha=1,
            erase=False,
            bSave=True,
            step=2,
        )

    # STEP 2: TRAIN
    print()
    print("-" * 80)
    print("STEP 2: TRAIN - Training GNN")
    print("-" * 80)

    model_files = glob.glob(f'{log_dir}/models/*.pt')
    if model_files:
        print(f"trained model already exists at {log_dir}/models/")
        print("skipping training (delete models folder to retrain)")
    else:
        print(f"training for {config.training.n_epochs} epochs")
        print(f"sparsity: {sparsity}")
        data_train(
            config=config,
            erase=False,
            best_model=best_model,
            style='color',
            device=device
        )

    # STEP 3: PLOT
    print()
    print("-" * 80)
    print("STEP 3: PLOT - Generating figures")
    print("-" * 80)

    folder_name = f'{log_dir}/tmp_results/'
    os.makedirs(folder_name, exist_ok=True)

    data_plot(
        config=config,
        config_file=config_file,
        epoch_list=['best'],
        style='color',
        extended='plots',
        device=device,
        apply_weight_correction=True,
        plot_eigen_analysis=False
    )

Activity Time Series

Sample of 100 time series for each sparsity level.

Sample of 100 time series (5% sparsity)

Sample of 100 time series (10% sparsity)

Sample of 100 time series (20% sparsity)

Sample of 100 time series (50% sparsity)

Sample of 100 time series (100% connectivity)

True Connectivity Matrix \(W_{ij}\)

True connectivity matrix for each sparsity level.

True connectivity \(W_{ij}\) (5% sparsity)

True connectivity \(W_{ij}\) (10% sparsity)

True connectivity \(W_{ij}\) (20% sparsity)

True connectivity \(W_{ij}\) (50% sparsity)

True connectivity \(W_{ij}\) (100% connectivity)

Connectivity Matrix Comparison

Learned versus true connectivity matrix \(W_{ij}\) after training. Each scatter plot reports the coefficient of determination \(R^2\) and the slope of the fit between learned and true weights.
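The two metrics can be sketched as follows, assuming the slope comes from an ordinary least-squares line of learned against true weights and \(R^2\) is computed for that line. The source does not specify whether all entries or only nonzero ones are used, or how `apply_weight_correction=True` rescales the learned weights. The helper name `weight_fit_metrics` and the synthetic data are hypothetical.

```python
import numpy as np


def weight_fit_metrics(w_true, w_learned):
    """Hypothetical helper: R^2 and slope of the least-squares fit w_learned ~ slope * w_true + b."""
    x = np.asarray(w_true, dtype=float).ravel()
    y = np.asarray(w_learned, dtype=float).ravel()
    # np.polyfit with deg=1 returns [slope, intercept]
    slope, intercept = np.polyfit(x, y, 1)
    residuals = y - (slope * x + intercept)
    ss_res = np.sum(residuals ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    r2 = 1.0 - ss_res / ss_tot
    return r2, slope


# Synthetic example: learned weights are a slightly shrunken, noisy copy of the true ones
rng = np.random.default_rng(1)
w_true = rng.normal(size=(100, 100))
w_learned = 0.9 * w_true + 0.05 * rng.normal(size=w_true.shape)
r2, slope = weight_fit_metrics(w_true, w_learned)
print(f"R^2 = {r2:.3f}, slope = {slope:.3f}")
```

A slope near 1 indicates that the learned weights have the correct scale. An \(R^2\) near 1 indicates that they are linearly consistent with the true ones.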

Connectivity comparison (5% sparsity)

Connectivity comparison (10% sparsity)

Connectivity comparison (20% sparsity)

Connectivity comparison (50% sparsity)

Connectivity comparison (100% connectivity)

Update Function \(\phi^*(\mathbf{a}_i, x)\) (MLP0)

Learned update functions after training. Each curve corresponds to one neuron and is colored by that neuron's true type. The true functions are overlaid in gray.
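One way such per-neuron curves can be produced is to pair each learned embedding \(\mathbf{a}_i\) with every point of a grid of \(x\) values and evaluate the network on the pairs. In the minimal sketch below, a randomly initialised two-layer network stands in for MLP0. The architecture, the input ordering \([\mathbf{a}_i, x]\), the embedding dimension, and the grid range are all assumptions, and `mlp_forward` and `update_function_curves` are hypothetical names.

```python
import numpy as np


def mlp_forward(inputs, params):
    """Hypothetical stand-in for MLP0: Linear -> tanh -> Linear."""
    W1, b1, W2, b2 = params
    h = np.tanh(inputs @ W1 + b1)
    return h @ W2 + b2  # shape (..., 1)


def update_function_curves(embeddings, x_grid, params):
    """Evaluate phi*(a_i, x) on x_grid for every neuron i; returns (n_neurons, n_points)."""
    n_neurons, emb_dim = embeddings.shape
    n_points = x_grid.shape[0]
    # Pair each embedding a_i with every x: inputs have shape (n_neurons, n_points, emb_dim + 1)
    a = np.broadcast_to(embeddings[:, None, :], (n_neurons, n_points, emb_dim))
    x = np.broadcast_to(x_grid[None, :, None], (n_neurons, n_points, 1))
    inputs = np.concatenate([a, x], axis=-1)
    return mlp_forward(inputs, params)[..., 0]


rng = np.random.default_rng(0)
emb_dim, hidden = 2, 16
params = (rng.normal(size=(emb_dim + 1, hidden)), np.zeros(hidden),
          rng.normal(size=(hidden, 1)), np.zeros(1))
embeddings = rng.normal(size=(5, emb_dim))  # 5 hypothetical neurons
x_grid = np.linspace(-5.0, 5.0, 200)
curves = update_function_curves(embeddings, x_grid, params)
print(curves.shape)  # (5, 200)
```

Each row of `curves` would be drawn as one line against `x_grid`, colored by the neuron's true type. Neurons whose embeddings coincide yield identical curves, which is why the curves cluster by type when the embeddings separate the types.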

Update functions \(\phi^*(\mathbf{a}_i, x)\) (5% sparsity). True functions are overlaid in light gray.

Update functions \(\phi^*(\mathbf{a}_i, x)\) (10% sparsity). True functions are overlaid in light gray.

Update functions \(\phi^*(\mathbf{a}_i, x)\) (20% sparsity). True functions are overlaid in light gray.

Update functions \(\phi^*(\mathbf{a}_i, x)\) (50% sparsity). True functions are overlaid in light gray.

Update functions \(\phi^*(\mathbf{a}_i, x)\) (100% connectivity). True functions are overlaid in light gray.