Supplementary Figure 7: Effect of training dataset size

Neural Activity
Simulation
GNN Training
Authors: Cédric Allier, Michael Innerberger, Stephan Saalfeld

This notebook reproduces the panels of the paper's Supplementary Figure 7. Performance scales with the length of the training series. The notebook displays the connectivity matrix comparison and \(\phi^*\) plots for each dataset size.

All simulation parameters are held constant across experiments; the only variable is the training dataset size (n_frames):

| Config              | n_frames |
|---------------------|----------|
| signal_fig_2        | 100,000  |
| signal_fig_supp_7_1 | 50,000   |
| signal_fig_supp_7_2 | 40,000   |
| signal_fig_supp_7_3 | 30,000   |
| signal_fig_supp_7_4 | 20,000   |
| signal_fig_supp_7_5 | 10,000   |

Configuration

import glob
import os  # used below for os.path.exists and os.makedirs

# NeuralGraphConfig, set_device, add_pre_folder, data_generate, data_train,
# and data_plot are assumed to be imported in an earlier notebook cell.

print()
print("=" * 80)
print("Supplementary Figure 7: Effect of Training Dataset Size")
print("=" * 80)

# All configs to process (config_name, n_frames)
config_list = [
    ('signal_fig_2', 100000),
    ('signal_fig_supp_7_1', 50000),
    ('signal_fig_supp_7_2', 40000),
    ('signal_fig_supp_7_3', 30000),
    ('signal_fig_supp_7_4', 20000),
    ('signal_fig_supp_7_5', 10000),
]

device = []  # set lazily from the first config's training.device
best_model = ''
config_root = "./config"

Steps 1-3: Generate, Train, and Plot for all configs

Loop over all dataset sizes: generate data, train the GNN, and generate plots. Steps are skipped if the data or trained models already exist.

for config_file_, n_frames in config_list:
    print()
    print("=" * 80)
    print(f"Processing: {config_file_} (n_frames={n_frames:,})")
    print("=" * 80)

    config_file, pre_folder = add_pre_folder(config_file_)

    # Load config
    config = NeuralGraphConfig.from_yaml(f"{config_root}/{config_file}.yaml")
    config.config_file = config_file
    config.dataset = config_file

    if device == []:
        device = set_device(config.training.device)

    log_dir = f'./log/{config_file}'
    graphs_dir = f'./graphs_data/{config_file}'

    # STEP 1: GENERATE
    print()
    print("-" * 80)
    print("STEP 1: GENERATE - Simulating neural activity")
    print("-" * 80)

    data_file = f'{graphs_dir}/x_list_0.npy'
    if os.path.exists(data_file):
        print(f"data already exists at {graphs_dir}/")
        print("skipping simulation, regenerating figures...")
        data_generate(
            config,
            device=device,
            visualize=False,
            run_vizualized=0,
            style="color",
            alpha=1,
            erase=False,
            bSave=True,
            step=2,
            regenerate_plots_only=True,
        )
    else:
        print(f"simulating {config.simulation.n_neurons} neurons, {config.simulation.n_frames} frames")
        print(f"output: {graphs_dir}/")
        data_generate(
            config,
            device=device,
            visualize=False,
            run_vizualized=0,
            style="color",
            alpha=1,
            erase=False,
            bSave=True,
            step=2,
        )

    # STEP 2: TRAIN
    print()
    print("-" * 80)
    print("STEP 2: TRAIN - Training GNN")
    print("-" * 80)

    model_files = glob.glob(f'{log_dir}/models/*.pt')
    if model_files:
        print(f"trained model already exists at {log_dir}/models/")
        print("skipping training (delete models folder to retrain)")
    else:
        print(f"training for {config.training.n_epochs} epochs")
        print(f"n_frames: {config.simulation.n_frames}")
        data_train(
            config=config,
            erase=False,
            best_model=best_model,
            style='color',
            device=device
        )

    # STEP 3: PLOT
    print()
    print("-" * 80)
    print("STEP 3: PLOT - Generating figures")
    print("-" * 80)

    folder_name = f'{log_dir}/tmp_results/'
    os.makedirs(folder_name, exist_ok=True)

    data_plot(
        config=config,
        config_file=config_file,
        epoch_list=['best'],
        style='color',
        extended='plots',
        device=device,
        apply_weight_correction=True,
        plot_eigen_analysis=False
    )

Connectivity Matrix Comparison

Learned vs. true connectivity matrix \(W_{ij}\) after training. Each scatter plot reports the \(R^2\) and slope of the linear fit between learned and true weights.
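The fit metrics shown in these panels can be illustrated with a short sketch. This is not the paper's plotting code: `connectivity_fit_metrics` is a hypothetical helper, and the matrices below are synthetic stand-ins for the true and learned \(W_{ij}\).

```python
import numpy as np

def connectivity_fit_metrics(true_W, learned_W):
    """Slope and R^2 of the linear fit learned_W ~ slope * true_W + intercept."""
    x = np.asarray(true_W).ravel()
    y = np.asarray(learned_W).ravel()
    slope, intercept = np.polyfit(x, y, 1)
    y_pred = slope * x + intercept
    ss_res = np.sum((y - y_pred) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return slope, 1.0 - ss_res / ss_tot

# toy matrices: a "learned" W that slightly shrinks and perturbs the true one
rng = np.random.default_rng(0)
W = rng.normal(size=(100, 100))
W_hat = 0.95 * W + rng.normal(scale=0.05, size=W.shape)
slope, r2 = connectivity_fit_metrics(W, W_hat)
```

A perfect recovery would give slope 1 and \(R^2 = 1\); with shorter training series both metrics are expected to degrade.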

Connectivity comparison (n_frames=100,000)

Connectivity comparison (n_frames=50,000)

Connectivity comparison (n_frames=40,000)

Connectivity comparison (n_frames=30,000)

Connectivity comparison (n_frames=20,000)

Connectivity comparison (n_frames=10,000)

Update Function \(\phi^*(\mathbf{a}_i, x)\) (MLP0)

Learned update functions after training. Each curve represents one neuron, and colors indicate the true neuron types. The true functions are overlaid in light gray.
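The structure of these panels can be sketched as follows. The learned MLP0 is not reimplemented here: the per-neuron embeddings `a` and the functional form of `phi` below are toy stand-ins used only to show how one curve per neuron is evaluated on a shared grid of \(x\).

```python
import numpy as np

n_neurons, n_points = 8, 200
rng = np.random.default_rng(1)
a = rng.uniform(0.5, 1.5, size=n_neurons)  # toy per-neuron embeddings a_i
x = np.linspace(-5, 5, n_points)           # shared evaluation grid

# toy phi*(a_i, x): one row per neuron, one column per grid point
phi = -x[None, :] + a[:, None] * np.tanh(x[None, :])

# each row phi[i] would be drawn as one colored curve, with the ground-truth
# function for that neuron type overlaid in light gray
```

In the actual figures, each row would instead come from evaluating the trained MLP0 at the neuron's learned embedding \(\mathbf{a}_i\).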

Update functions \(\phi^*(a_i, x)\) (n_frames=100,000). True functions are overlaid in light gray.

Update functions \(\phi^*(a_i, x)\) (n_frames=50,000). True functions are overlaid in light gray.

Update functions \(\phi^*(a_i, x)\) (n_frames=40,000). True functions are overlaid in light gray.

Update functions \(\phi^*(a_i, x)\) (n_frames=30,000). True functions are overlaid in light gray.

Update functions \(\phi^*(a_i, x)\) (n_frames=20,000). True functions are overlaid in light gray.

Update functions \(\phi^*(a_i, x)\) (n_frames=10,000). True functions are overlaid in light gray.