"""Preprocessing pipeline for single MEG/EEG datasets.
This module provides the main PreprocessingPipeline class for processing
individual MEG/EEG datasets with filtering, bad channel detection,
line noise removal, and ICA artifact removal.
"""
from __future__ import annotations
from pathlib import Path
from typing import Literal
import mne
from loguru import logger
from mne.io import BaseRaw
from mne.preprocessing import ICA
from mne_bids import BIDSPath, read_raw_bids
# Modality tag for a recording; the pipeline branches on this value to choose
# between its EEG-specific and MEG-specific processing steps.
DType = Literal["eeg", "meg"]
[docs]
class PreprocessingPipeline:
"""Main preprocessing pipeline for MEG/EEG data.
This class provides a comprehensive preprocessing pipeline including:
- Flexible input path parsing (string, Path, BIDSPath)
- Data loading with automatic datatype inference
- Filtering and resampling
- Bad channel detection and interpolation
- Line noise removal (zapline)
- ICA artifact removal with automatic/manual labeling
Parameters
----------
input_path : str | Path | BIDSPath | BaseRaw
Path to the input data, or a Raw object directly.
Can be a plain string, pathlib.Path, mne_bids.BIDSPath object,
or an MNE Raw object (useful for testing).
output_dir : str | Path | None, optional
Directory for saving outputs. If None, uses BIDS derivatives.
n_jobs : int, optional
Number of parallel jobs. Default is 1.
use_cuda : bool, optional
Whether to use CUDA acceleration. Default is False.
random_state : int, optional
Random seed for reproducibility. Default is 42.
Attributes
----------
input_path : Path | BIDSPath | None
Parsed input path (None if Raw was provided directly).
output_dir : Path
Output directory path.
n_jobs : int | str
Number of parallel jobs or "cuda".
random_state : int
Random seed.
raw : BaseRaw | None
Loaded raw data.
datatype : DType | None
Inferred datatype ("eeg" or "meg").
"""
[docs]
def __init__(
    self,
    input_path: str | Path | BIDSPath | BaseRaw,
    output_dir: str | Path | None = None,
    n_jobs: int = 1,
    use_cuda: bool = False,
    random_state: int = 42,
) -> None:
    """Initialize preprocessing pipeline.

    Parameters
    ----------
    input_path : str | Path | BIDSPath | BaseRaw
        Location of the recording, or an already-loaded Raw object.
    output_dir : str | Path | None, optional
        Where outputs are written. If None, a ``derivatives/preproc``
        folder is used: under the BIDS root for a BIDSPath, next to the
        input file for a plain path, or under the current working
        directory when a Raw object was given.
    n_jobs : int, optional
        Number of parallel jobs. Default is 1.
    use_cuda : bool, optional
        Whether to request CUDA acceleration. Default is False.
    random_state : int, optional
        Random seed for reproducibility. Default is 42.
    """
    # State must exist before parsing: _parse_input_path fills in
    # self.raw / self.datatype when a Raw object is passed directly.
    self.raw: BaseRaw | None = None
    self.datatype: DType | None = None
    self._ica: ICA | None = None
    self._ic_labels: dict | None = None

    self.input_path = self._parse_input_path(input_path)

    # Resolve the output directory (explicit argument wins).
    if output_dir is not None:
        self.output_dir = Path(output_dir)
    elif isinstance(self.input_path, BIDSPath):
        self.output_dir = Path(self.input_path.root) / "derivatives" / "preproc"
    elif self.input_path is not None:
        self.output_dir = Path(self.input_path).parent / "derivatives" / "preproc"
    else:
        # Raw object was provided, use current directory
        self.output_dir = Path.cwd() / "derivatives" / "preproc"

    self.random_state = random_state

    # Start from the CPU setting; only switch to "cuda" once init succeeded.
    self.n_jobs: int | str = n_jobs
    if use_cuda:
        logger.trace("Using CUDA for computations where possible.")
        try:
            # Without ignore_config=True, init_cuda() is a no-op unless
            # MNE_USE_CUDA is enabled in the MNE config, so an explicit
            # use_cuda=True request would otherwise be silently ignored.
            mne.cuda.init_cuda(ignore_config=True)
        except Exception as e:
            logger.warning(f"Failed to initialize CUDA: {e}. Falling back to CPU.")
        else:
            self.n_jobs = "cuda"
    else:
        logger.trace(f"Using {n_jobs} CPU cores for computations where possible.")

    logger.info(f"Initialized PreprocessingPipeline for {self.input_path}")
def _parse_input_path(
    self, input_path: str | Path | BIDSPath | BaseRaw
) -> Path | BIDSPath | None:
    """Parse and validate input path.

    Parameters
    ----------
    input_path : str | Path | BIDSPath | BaseRaw
        Input path in various formats, or a Raw object for testing.

    Returns
    -------
    Path | BIDSPath | None
        Validated path object, or None if Raw was provided.

    Raises
    ------
    TypeError
        If input_path is not a valid type.
    FileNotFoundError
        If the path does not exist.
    ValueError
        If a Raw object is given and its datatype cannot be inferred.
    """
    # Special case: allow passing Raw object directly (for testing).
    # The object is stored as the loaded data, so no path is kept.
    if isinstance(input_path, BaseRaw):
        self.raw = input_path
        self.datatype = self._infer_datatype(input_path)
        return None

    if isinstance(input_path, BIDSPath):
        fpath = input_path.fpath
        if fpath is None or not Path(fpath).exists():
            raise FileNotFoundError(f"BIDSPath does not exist: {input_path}")
        # Take the datatype from the BIDSPath when it names a supported one.
        if input_path.datatype in ("eeg", "meg"):
            self.datatype = input_path.datatype
        return input_path

    if isinstance(input_path, (str, Path)):
        path = Path(input_path)
        if not path.exists():
            raise FileNotFoundError(f"Path does not exist: {path}")
        return path

    # Anything else is an unsupported type; no further branch is reachable.
    raise TypeError(
        f"input_path must be str, Path, BIDSPath, or BaseRaw, got {type(input_path)}"
    )
def _infer_datatype(self, raw: BaseRaw) -> DType:
    """Decide whether a recording is EEG or MEG from its channel types.

    EEG is checked first, so a recording that contains both EEG and MEG
    channels is reported as "eeg".

    Parameters
    ----------
    raw : BaseRaw
        Raw data object; only ``get_channel_types()`` is used.

    Returns
    -------
    DType
        "eeg" or "meg".

    Raises
    ------
    ValueError
        If neither an EEG nor an MEG-related channel type is present.
    """
    present = set(raw.get_channel_types())
    meg_kinds = ("mag", "grad", "planar1", "planar2")

    if "eeg" in present:
        return "eeg"
    if not present.isdisjoint(meg_kinds):
        return "meg"

    raise ValueError(
        f"Cannot infer datatype from channel types: {present}. "
        "Must contain 'eeg' or MEG-related types."
    )
[docs]
def load_data(self) -> BaseRaw:
    """Load data from input path.

    Data that is already loaded (including a Raw object handed to the
    constructor) is returned as-is without touching the disk.

    Returns
    -------
    BaseRaw
        Loaded raw data.

    Raises
    ------
    ValueError
        If data cannot be loaded, or the datatype cannot be inferred.
    """
    # Return cached data if already loaded
    if self.raw is not None:
        logger.trace("Returning cached raw data.")
        return self.raw

    logger.info("Loading data...")
    if isinstance(self.input_path, BIDSPath):
        self.raw = read_raw_bids(self.input_path, verbose=False)
        self.raw.load_data()
    else:
        # Try to read as standard MNE format
        try:
            self.raw = mne.io.read_raw(self.input_path, preload=True, verbose=False)
        except Exception as e:
            # Keep the original exception as the cause for debugging.
            raise ValueError(f"Failed to load data from {self.input_path}: {e}") from e

    # A datatype taken from the BIDSPath is explicit and must win: channel-based
    # inference reports "eeg" for any MEG recording that also has EEG channels.
    if self.datatype is None:
        self.datatype = self._infer_datatype(self.raw)

    logger.info(f"Loaded {self.datatype.upper()} data with {len(self.raw.ch_names)} channels.")
    return self.raw
[docs]
def filter_and_resample(
    self,
    highpass: float = 0.1,
    lowpass: float = 100.0,
    sfreq: float = 250.0,
) -> BaseRaw:
    """Apply bandpass filter and resample data.

    The data are filtered first and resampled afterwards; both steps
    modify ``self.raw``.

    Parameters
    ----------
    highpass : float, optional
        High-pass filter frequency in Hz. Default is 0.1.
    lowpass : float, optional
        Low-pass filter frequency in Hz. Default is 100.0.
    sfreq : float, optional
        Target sampling frequency in Hz. Default is 250.0.

    Returns
    -------
    BaseRaw
        Filtered and resampled raw data.

    Raises
    ------
    AssertionError
        If filter parameters are invalid.
    ValueError
        If data is not loaded.
    """
    if self.raw is None:
        raise ValueError("Data not loaded. Call load_data() first.")

    # Validate against the Nyquist frequency of the *target* sampling rate.
    # Raised explicitly (not via `assert`) so the checks survive `python -O`;
    # the documented AssertionError type is kept for existing callers.
    nyquist = sfreq / 2
    if not lowpass < nyquist:
        raise AssertionError(
            f"Lowpass frequency ({lowpass} Hz) must be less than "
            f"Nyquist frequency ({nyquist} Hz)."
        )
    if not highpass < lowpass:
        raise AssertionError(
            f"Highpass frequency ({highpass} Hz) must be less than "
            f"lowpass frequency ({lowpass} Hz)."
        )

    logger.info(
        f"Filtering: highpass={highpass} Hz, lowpass={lowpass} Hz, resampling to {sfreq} Hz"
    )
    self.raw.filter(l_freq=highpass, h_freq=lowpass, n_jobs=self.n_jobs, verbose=False)
    self.raw.resample(sfreq, n_jobs=self.n_jobs, verbose=False)
    logger.success("Filtering and resampling completed.")
    return self.raw
[docs]
def detect_and_fix_bad_channels(
    self,
    fix: bool = True,
    reset_bads: bool = True,
    origin: tuple[float, float, float] = (0.0, 0.0, 0.04),
) -> BaseRaw:
    """Detect and optionally interpolate bad channels.

    Detected channels are added to ``raw.info["bads"]``, optionally
    interpolated, and a TSV/JSON report is written via
    ``_save_bad_channels_tsv``.

    Parameters
    ----------
    fix : bool, optional
        Whether to interpolate bad channels. Default is True.
    reset_bads : bool, optional
        Whether to reset bads list after interpolation. Default is True.
    origin : tuple, optional
        Origin for MEG detection and interpolation.
        Default is (0.0, 0.0, 0.04).

    Returns
    -------
    BaseRaw
        Raw data with bad channels marked/fixed.

    Raises
    ------
    ValueError
        If data not loaded, the datatype cannot be inferred, or the
        datatype is unsupported.
    """
    if self.raw is None:
        raise ValueError("Data not loaded. Call load_data() first.")
    if self.datatype is None:
        self.datatype = self._infer_datatype(self.raw)

    logger.info("Detecting bad channels...")
    # Imported lazily, as in the rest of this class's step methods.
    from .bad_channels import detect_bad_channels_eeg, detect_bad_channels_meg

    if self.datatype == "eeg":
        bad_channels = detect_bad_channels_eeg(self.raw, random_state=self.random_state)
    elif self.datatype == "meg":
        bad_channels = detect_bad_channels_meg(self.raw, origin=origin)
    else:
        raise ValueError(f"Unsupported datatype: {self.datatype}")

    # Mark bad channels, skipping names that are already marked (or repeated
    # in the detector output) so info["bads"] never holds duplicates.
    already_marked = set(self.raw.info["bads"])
    newly_marked = [
        ch for ch in dict.fromkeys(bad_channels) if ch not in already_marked
    ]
    self.raw.info["bads"].extend(newly_marked)
    logger.info(f"Detected {len(bad_channels)} bad channels: {bad_channels}")

    # Interpolate if requested
    if fix and len(bad_channels) > 0:
        logger.info("Interpolating bad channels...")
        self.raw.interpolate_bads(
            reset_bads=reset_bads,
            method=dict(meg="MNE", eeg="spline"),
            origin=origin if self.datatype == "meg" else "auto",
            verbose=False,
        )
        logger.success("Bad channels interpolated.")

    # Save derivative
    self._save_bad_channels_tsv(bad_channels)
    return self.raw
[docs]
def remove_line_noise(
    self,
    fline: float = 50.0,
) -> BaseRaw:
    """Suppress power line interference with the datatype-specific routine.

    Parameters
    ----------
    fline : float, optional
        Power line frequency in Hz. Default is 50.0.

    Returns
    -------
    BaseRaw
        Raw data returned by the line-noise routine (also stored on
        ``self.raw``).

    Raises
    ------
    ValueError
        If data is not loaded, the datatype cannot be inferred, or the
        datatype is unsupported.
    """
    if self.raw is None:
        raise ValueError("Data not loaded. Call load_data() first.")
    if self.datatype is None:
        self.datatype = self._infer_datatype(self.raw)

    logger.info(f"Removing {fline} Hz line noise...")
    from .line_noise import remove_line_noise_eeg, remove_line_noise_meg

    # One cleaner per supported datatype; anything else is rejected.
    cleaners = {"eeg": remove_line_noise_eeg, "meg": remove_line_noise_meg}
    cleaner = cleaners.get(self.datatype)
    if cleaner is None:
        raise ValueError(f"Unsupported datatype: {self.datatype}")

    self.raw = cleaner(self.raw, fline=fline)
    logger.success("Line noise removal completed.")
    return self.raw
[docs]
def apply_ica(
    self,
    n_components: int | None = None,
    method: str = "infomax",
    regress: bool = True,
    manual_labels: list[str] | None = None,
) -> BaseRaw:
    """Run the ICA step for artifact removal.

    Parameters
    ----------
    n_components : int | None, optional
        Number of ICA components. If None, 20 is used for EEG and 40
        otherwise.
    method : str, optional
        ICA method. Default is "infomax".
    regress : bool, optional
        Whether to regress out artifact components. Default is True.
        When False, the result of the ICA routine is not stored and
        ``self.raw`` is left as it was.
    manual_labels : list[str] | None, optional
        Manual labels for ICA components, forwarded to the ICA routine.

    Returns
    -------
    BaseRaw
        ``self.raw`` — replaced by the ICA routine's result only when
        ``regress`` is True.

    Raises
    ------
    ValueError
        If data is not loaded or the datatype cannot be inferred.
    """
    if self.raw is None:
        raise ValueError("Data not loaded. Call load_data() first.")
    if self.datatype is None:
        self.datatype = self._infer_datatype(self.raw)

    logger.info("Applying ICA...")
    from .ica import apply_ica_pipeline

    # Fall back to the datatype-dependent component count.
    if n_components is None:
        if self.datatype == "eeg":
            n_components = 20
        else:
            n_components = 40

    ica_kwargs = {
        "raw": self.raw,
        "datatype": self.datatype,
        "n_components": n_components,
        "method": method,
        "regress": regress,
        "manual_labels": manual_labels,
        "random_state": self.random_state,
    }
    cleaned = apply_ica_pipeline(**ica_kwargs)

    if not regress:
        logger.success("ICA decomposition completed (no regression).")
        return self.raw

    self.raw = cleaned
    logger.success("ICA artifact regression completed.")
    return self.raw
[docs]
def run(
    self,
    filter_params: dict | None = None,
    detect_bad_channels: bool = True,
    remove_line_noise: bool = True,
    apply_ica: bool = True,
    ica_params: dict | None = None,
    save_intermediate: bool = False,
) -> BaseRaw:
    """Run complete preprocessing pipeline.

    Steps, in order: load, filter and resample, bad channel handling,
    line noise removal, referencing, ICA, and (for EEG with ICA
    regression) a final re-reference.

    Parameters
    ----------
    filter_params : dict | None, optional
        Filtering parameters. Keys: highpass, lowpass, sfreq.
        If None, uses highpass=0.1, lowpass=100.0, sfreq=250.0.
    detect_bad_channels : bool, optional
        Whether to detect and fix bad channels. Default is True.
    remove_line_noise : bool, optional
        Whether to remove line noise. Default is True.
    apply_ica : bool, optional
        Whether to apply ICA. Default is True.
    ica_params : dict | None, optional
        Keyword arguments forwarded to ``apply_ica``
        (e.g. n_components, method, regress).
    save_intermediate : bool, optional
        Whether to save intermediate files. Default is False.
        This method does not currently act on the flag.

    Returns
    -------
    BaseRaw
        Preprocessed raw data.
    """
    logger.info("Starting preprocessing pipeline...")
    self.load_data()

    if filter_params is None:
        filter_params = {"highpass": 0.1, "lowpass": 100.0, "sfreq": 250.0}
    self.filter_and_resample(**filter_params)

    if detect_bad_channels:
        self.detect_and_fix_bad_channels()

    if remove_line_noise:
        self.remove_line_noise()

    self._apply_reference()

    if apply_ica:
        ica_kwargs = {} if ica_params is None else ica_params
        self.apply_ica(**ica_kwargs)

        # Re-reference EEG after ICA regression. The fallback must mirror
        # apply_ica's own default (regress=True); otherwise a default run
        # regresses components but skips the re-reference.
        regressed = ica_kwargs.get("regress", True)
        if regressed and self.datatype == "eeg" and self.raw is not None:
            logger.info("Re-referencing EEG after ICA regression.")
            self.raw.set_eeg_reference("average", verbose=False)

    logger.success("Preprocessing pipeline completed!")
    return self.raw
def _apply_reference(self) -> None:
    """Apply appropriate reference for the datatype.

    EEG data get an average reference. MEG data get gradient
    compensation grade 3, but only when the recording carries
    compensation matrices. Any other datatype is left untouched.

    Raises
    ------
    ValueError
        If data is not loaded.
    """
    if self.raw is None:
        raise ValueError("Data not loaded. Call load_data() first.")

    if self.datatype == "eeg":
        logger.info("Applying average reference for EEG.")
        self.raw.set_eeg_reference("average", verbose=False)
    elif self.datatype == "meg":
        # Recordings without compensation matrices cannot be switched to
        # grade 3; skip instead of letting the whole pipeline fail.
        if not self.raw.info.get("comps"):
            logger.warning(
                "No gradient compensation matrices in recording; "
                "skipping MEG gradient compensation."
            )
            return
        logger.info("Applying gradient compensation 3 for MEG.")
        self.raw.apply_gradient_compensation(3, verbose=False)
[docs]
def save(self, filename: str | Path | None = None) -> None:
    """Save preprocessed data.

    Parameters
    ----------
    filename : str | Path | None, optional
        Output filename. If None, a name is generated inside
        ``output_dir``: a ``sub-*/[ses-*/]<datatype>`` folder for BIDS
        inputs, the input file stem for plain paths, or
        ``preprocessed_<datatype>.fif`` when no input path exists.

    Raises
    ------
    ValueError
        If no data is loaded or the datatype is not set.
    """
    if self.raw is None:
        raise ValueError("No data to save. Run pipeline first.")
    if self.datatype is None:
        raise ValueError("Datatype not set. Load data first.")

    self.output_dir.mkdir(parents=True, exist_ok=True)

    if filename is None:
        if isinstance(self.input_path, BIDSPath):
            basename = self.input_path.basename.split(".")[0]
            # Only add entity folders that are actually set; a missing
            # session must not become a literal "ses-None" directory.
            subdir = self.output_dir
            entities = (
                ("sub", self.input_path.subject),
                ("ses", self.input_path.session),
            )
            for prefix, value in entities:
                if value is not None:
                    subdir = subdir / f"{prefix}-{value}"
            subdir = subdir / self.datatype
            subdir.mkdir(parents=True, exist_ok=True)
            filename = subdir / f"{basename}_preproc_{self.datatype}.fif"
        elif self.input_path is not None:
            filename = self.output_dir / f"{self.input_path.stem}_preproc_{self.datatype}.fif"
        else:
            filename = self.output_dir / f"preprocessed_{self.datatype}.fif"
    else:
        filename = Path(filename)
        filename.parent.mkdir(parents=True, exist_ok=True)

    logger.info(f"Saving preprocessed data to {filename}")
    self.raw.save(filename, overwrite=True, verbose=False)
    logger.success(f"Saved to {filename}")
def _save_bad_channels_tsv(self, bad_channels: list[str]) -> None:
    """Save bad channels information to TSV file.

    Writes ``<basename>_desc-badchs_<datatype>.tsv`` with one row per
    channel, plus a JSON sidecar describing the columns. Does nothing
    when the output directory, raw data, or datatype is not available.

    Parameters
    ----------
    bad_channels : list[str]
        Names of the channels that were detected as bad.
    """
    if not hasattr(self, "output_dir"):
        return
    if self.raw is None or self.datatype is None:
        return

    import json

    import pandas as pd  # type: ignore[import-untyped]

    # Work out the target folder and file stem.
    if self.input_path is None:
        # Use default naming when Raw object was provided directly
        subdir = self.output_dir
        basename = "preprocessed"
    elif isinstance(self.input_path, BIDSPath):
        # Drop the extension, matching how save() builds its file name.
        basename = self.input_path.basename.split(".")[0]
        # Skip unset entities so no "ses-None" folder is created.
        subdir = self.output_dir
        entities = (
            ("sub", self.input_path.subject),
            ("ses", self.input_path.session),
        )
        for prefix, value in entities:
            if value is not None:
                subdir = subdir / f"{prefix}-{value}"
        subdir = subdir / self.datatype
    else:
        subdir = self.output_dir
        basename = self.input_path.stem

    subdir.mkdir(parents=True, exist_ok=True)
    fname = subdir / f"{basename}_desc-badchs_{self.datatype}.tsv"

    # NOTE(review): every channel is written as "good" and detected ones as
    # "fixed", even if the caller did not interpolate — confirm intended.
    chs = self.raw.ch_names
    flagged = set(bad_channels)
    df = pd.DataFrame(
        {
            "name": chs,
            "type": [self.datatype] * len(chs),
            "status": ["good"] * len(chs),
            "status_description": ["fixed" if ch in flagged else "n/a" for ch in chs],
        }
    )
    df.to_csv(fname, sep="\t", index=False, encoding="utf-8", na_rep="n/a")

    # JSON sidecar describing the TSV columns.
    fname_json = fname.with_suffix(".json")
    meta = {
        "name": "Channels' name",
        "type": "Channel type, e.g., EEG, MEG",
        "status": "Channel status, good or bad",
        "status_description": "Description of the channel status, e.g., fixed if interpolated",
    }
    with open(fname_json, "w", encoding="utf-8") as f:
        json.dump(meta, f, indent=4)

    logger.trace(f"Saved bad channels info to {fname}")