import glob
import os

# Project-specific helpers used below (exact import paths depend on the repository layout):
# NeuralGraphConfig, set_device, add_pre_folder, data_generate, data_train, data_plot
print()
print("=" * 80)
print("Supplementary Figure 7: Effect of Training Dataset Size")
print("=" * 80)
# All configs to process (config_name, n_frames)
config_list = [
('signal_fig_2', 100000),
('signal_fig_supp_7_1', 50000),
('signal_fig_supp_7_2', 40000),
('signal_fig_supp_7_3', 30000),
('signal_fig_supp_7_4', 20000),
('signal_fig_supp_7_5', 10000),
]
device = []
best_model = ''
config_root = "./config"
Supplementary Figure 7: Effect of training dataset size
This script reproduces the panels of the paper’s Supplementary Figure 7: performance scales with the length of the training time series. The notebook displays the connectivity-matrix comparison and the \(\phi^*\) plots for each dataset size.
Simulation parameters (constant across all experiments):
- N_neurons: 1000
- N_types: 4, parameterized by \(\tau_i \in \{0.5, 1\}\), \(s_i \in \{1, 2\}\), and \(g_i = 10\)
- Connectivity: 100% (dense), Lorentz distribution
Variable: Training dataset size (n_frames)
| Config | n_frames |
|---|---|
| signal_fig_2 | 100,000 |
| signal_fig_supp_7_1 | 50,000 |
| signal_fig_supp_7_2 | 40,000 |
| signal_fig_supp_7_3 | 30,000 |
| signal_fig_supp_7_4 | 20,000 |
| signal_fig_supp_7_5 | 10,000 |
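The parameters above live in per-config YAML files under `./config/`. The script only reads `simulation.n_neurons`, `simulation.n_frames`, `training.n_epochs`, and `training.device`, so a minimal sketch of what e.g. `signal_fig_2.yaml` might contain (key names inferred from the attribute accesses in the code; the `n_epochs` and `device` values are placeholders, and other fields are omitted):

```yaml
simulation:
  n_neurons: 1000    # constant across all experiments
  n_frames: 100000   # the variable under study (see table above)
training:
  n_epochs: 20       # placeholder value; set per config in the repository
  device: auto       # placeholder; resolved by set_device()
```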
Configuration
Steps 1-3: Generate, Train, and Plot for all configs
Loops over all dataset sizes: generates data, trains the GNN, and produces plots. Each step is skipped if its data/models already exist.
for config_file_, n_frames in config_list:
print()
print("=" * 80)
print(f"Processing: {config_file_} (n_frames={n_frames:,})")
print("=" * 80)
config_file, pre_folder = add_pre_folder(config_file_)
# Load config
config = NeuralGraphConfig.from_yaml(f"{config_root}/{config_file}.yaml")
config.config_file = config_file
config.dataset = config_file
if device == []:
device = set_device(config.training.device)
log_dir = f'./log/{config_file}'
graphs_dir = f'./graphs_data/{config_file}'
# STEP 1: GENERATE
print()
print("-" * 80)
print("STEP 1: GENERATE - Simulating neural activity")
print("-" * 80)
data_file = f'{graphs_dir}/x_list_0.npy'
if os.path.exists(data_file):
print(f"data already exists at {graphs_dir}/")
print("skipping simulation, regenerating figures...")
data_generate(
config,
device=device,
visualize=False,
run_vizualized=0,
style="color",
alpha=1,
erase=False,
bSave=True,
step=2,
regenerate_plots_only=True,
)
else:
print(f"simulating {config.simulation.n_neurons} neurons, {config.simulation.n_frames} frames")
print(f"output: {graphs_dir}/")
data_generate(
config,
device=device,
visualize=False,
run_vizualized=0,
style="color",
alpha=1,
erase=False,
bSave=True,
step=2,
)
# STEP 2: TRAIN
print()
print("-" * 80)
print("STEP 2: TRAIN - Training GNN")
print("-" * 80)
model_files = glob.glob(f'{log_dir}/models/*.pt')
if model_files:
print(f"trained model already exists at {log_dir}/models/")
print("skipping training (delete models folder to retrain)")
else:
print(f"training for {config.training.n_epochs} epochs")
print(f"n_frames: {config.simulation.n_frames}")
data_train(
config=config,
erase=False,
best_model=best_model,
style='color',
device=device
)
# STEP 3: PLOT
print()
print("-" * 80)
print("STEP 3: PLOT - Generating figures")
print("-" * 80)
folder_name = f'{log_dir}/tmp_results/'
os.makedirs(folder_name, exist_ok=True)
data_plot(
config=config,
config_file=config_file,
epoch_list=['best'],
style='color',
extended='plots',
device=device,
apply_weight_correction=True,
plot_eigen_analysis=False
)
Connectivity Matrix Comparison
Learned vs. true connectivity matrix \(W_{ij}\) after training. The scatter plot reports the slope and \(R^2\) of the linear fit between learned and true weights.
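As a rough sketch of how those metrics can be computed (this is not the repository's own plotting code; `w_true` and `w_learned` are hypothetical array names), fit a line to the flattened true-vs-learned weights:

```python
import numpy as np

def connectivity_fit_metrics(w_true, w_learned):
    """Slope and R^2 of the least-squares fit of learned W_ij against true W_ij."""
    x = np.asarray(w_true).ravel()
    y = np.asarray(w_learned).ravel()
    slope, intercept = np.polyfit(x, y, 1)            # linear fit: y ~ slope*x + intercept
    ss_res = np.sum((y - (slope * x + intercept)) ** 2)
    ss_tot = np.sum((y - y.mean()) ** 2)
    return slope, 1.0 - ss_res / ss_tot               # (slope, R^2)
```

A slope near 1 and \(R^2\) near 1 indicate that the learned matrix recovers the true one up to small errors.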
Update Function \(\phi^*(\mathbf{a}_i, x)\) (MLP0)
Learned update functions after training. Each curve corresponds to one neuron; colors indicate true neuron types, with the true functions overlaid in gray.
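A minimal sketch of how such per-neuron curves can be produced (the actual model interface in the repository may differ; `phi` here stands in for the trained MLP0 and `embeddings` for the learned \(\mathbf{a}_i\)):

```python
import numpy as np

def phi_curves(phi, embeddings, x_grid):
    """Evaluate phi(a_i, x) on a common x grid for every neuron embedding a_i.

    Returns an array of shape (n_neurons, len(x_grid)), one curve per neuron.
    """
    return np.stack([np.array([phi(a, x) for x in x_grid]) for a in embeddings])
```

Plotting one row per neuron, colored by its true type, reproduces the overlay described above.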