Skip to content

Viz

Visualization utilities.

Visualization utilities for visuneu data.

plot_data(neuro, visual, sfreq, trial=None, trial_window=None, figsize=(6, 3), window=(0.0, 5.0), cmap_neuro='Greys', cmap_ontime='summer', color_offtime='black', marker_size=40)

Plot neural activity alongside stimulus labels.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `neuro` | `ndarray` | Neural data, shape `(n_samples, n_channels)`. | *required* |
| `visual` | `ndarray` | Label vector, shape `(n_samples,)`. | *required* |
| `sfreq` | `float` | Sampling frequency in Hz. | *required* |
| `trial` | `ndarray` or `None` | Trial-ID vector, shape `(n_samples,)`. When provided together with `trial_window`, every in-trial timepoint is scatter-plotted and coloured by its position inside the trial window. | `None` |
| `trial_window` | list of float or int, or `None` | Two-element `[start, end]` relative to stimulus onset. float values are interpreted as seconds, int values as samples (same convention as `BaseData.configure`). | `None` |
| `figsize` | tuple of float | Figure size `(width, height)` in inches. | `(6, 3)` |
| `window` | tuple of float \| int | Display window `(start, end)`. float values are interpreted as seconds, int values as samples. | `(0.0, 5.0)` |
| `cmap_neuro` | `str` | Colormap for the neural activity heatmap. | `'Greys'` |
| `cmap_ontime` | `str` | Colormap for in-trial time of stimulus labels. | `'summer'` |
| `color_offtime` | `str` | Color for non-stimulus time points. | `'black'` |
| `marker_size` | `float` | Marker size for scatter points. | `40` |

Returns:

| Type | Description |
|---|---|
| `Figure` | The generated figure. |

Source code in src/vneurotk/viz/data.py
def plot_data(
    neuro: np.ndarray,
    visual: np.ndarray,
    sfreq: float,
    trial: np.ndarray | None = None,
    trial_window: list | None = None,
    figsize: tuple[float, float] = (6, 3),
    window: tuple[float, float] = (0.0, 5.0),
    cmap_neuro: str = "Greys",
    cmap_ontime: str = "summer",
    color_offtime: str = "black",
    marker_size: float = 40,
) -> plt.Figure:
    """Plot neural activity alongside stimulus labels.

    Parameters
    ----------
    neuro : np.ndarray
        Neural data, shape ``(n_samples, n_channels)``.
    visual : np.ndarray
        Label vector, shape ``(n_samples,)``.
    sfreq : float
        Sampling frequency in Hz.
    trial : np.ndarray or None
        Trial-ID vector, shape ``(n_samples,)``.  When provided together
        with *trial_window*, every in-trial timepoint is scatter-plotted
        and coloured by its position inside the trial window.
    trial_window : list of float or int, or None
        Two-element ``[start, end]`` relative to stimulus onset.
        *float* values (including NumPy floating scalars) are interpreted
        as **seconds**, *int* values as **samples** (same convention as
        ``BaseData.configure``).
    figsize : tuple of float
        Figure size ``(width, height)`` in inches.
    window : tuple of float | int
        Display window ``(start, end)``.  *float* values (including NumPy
        floating scalars) are interpreted as **seconds**, *int* values as
        **samples**.  The resolved window is clipped to
        ``[0, n_samples]``.
    cmap_neuro : str
        Colormap for the neural activity heatmap.
    cmap_ontime : str
        Colormap for in-trial time of stimulus labels.
    color_offtime : str
        Color for non-stimulus time points.
    marker_size : float
        Marker size for scatter points.

    Returns
    -------
    matplotlib.figure.Figure
        The generated figure.

    Raises
    ------
    ValueError
        If the clipped *window* contains no samples, or if *trial* is
        used and its length differs from that of *visual*.
    """

    def _is_seconds(value) -> bool:
        # Floats (Python or NumPy) mean seconds; everything else means samples.
        return isinstance(value, (float, np.floating))

    def _to_samples(value: float | int) -> int:
        # Shared float=seconds / int=samples conversion for window and trial_window.
        if _is_seconds(value):
            return int(round(float(value) * sfreq))
        return int(value)

    # -- Resolve the display window to a clipped sample range --
    n_samples = neuro.shape[0]
    s_start = max(_to_samples(window[0]), 0)
    s_end = min(_to_samples(window[1]), n_samples)
    if s_end <= s_start:
        # Without this guard, times[0] below fails with an opaque IndexError.
        raise ValueError(
            f"window {window!r} selects no samples "
            f"(resolved to samples {s_start}-{s_end} of {n_samples})."
        )
    # Only label the window in seconds when both endpoints were given in seconds.
    if _is_seconds(window[0]) and _is_seconds(window[1]):
        logger.info("plot window: {}-{} s (samples {}-{}).", window[0], window[1], s_start, s_end)
    else:
        logger.info("plot window: {}-{} samples.", s_start, s_end)
    X_win = neuro[s_start:s_end]
    y_win = visual[s_start:s_end]
    times = np.arange(s_start, s_end) / sfreq
    t_min, t_max = times[0], times[-1]

    # -- Parse labels --
    if trial is not None and trial_window is not None:
        if len(trial) != len(visual):
            raise ValueError(
                f"trial and visual must have the same length, got {len(trial)} and {len(visual)}."
            )
        tw_samples = [_to_samples(v) for v in trial_window]
        # Build stim_map from FULL arrays so edge-of-window trials are covered.
        # Later samples overwrite earlier ones for the same trial ID.
        full_stim_map: dict[int, str] = {
            int(t): str(v)
            for v, t in zip(visual, trial)
            if not _is_null(v) and not _is_null(t)
        }

        trial_win = trial[s_start:s_end]
        is_stim, y_cat, intrial_time, tick_labels = _parse_labels_with_trial(
            y_win,
            trial_win,
            tw_samples,
            sfreq,
            full_stim_map,
        )
    else:
        is_stim, y_cat, intrial_time, tick_labels = _parse_labels(y_win)
        intrial_time = intrial_time / sfreq  # samples -> seconds

    # -- Layout --
    fig = plt.figure(figsize=figsize, constrained_layout=True)
    gs = fig.add_gridspec(
        2,
        2,
        width_ratios=[1, 0.015],
        height_ratios=[0.8, 1],
        hspace=0.12,
        wspace=0.02,
    )
    ax_y = fig.add_subplot(gs[0, 0])
    cax_y = fig.add_subplot(gs[0, 1])
    ax_x = fig.add_subplot(gs[1, 0], sharex=ax_y)
    cax_x = fig.add_subplot(gs[1, 1])

    # -- Upper panel: stimulus labels --
    ax_y.set_title("Trial setting", fontsize=10, loc="left")

    # Non-trial points, subsampled to at most ~5000 points to keep rendering cheap
    if np.any(~is_stim):
        stride = max(1, len(y_win) // 5000)
        idx = np.where(~is_stim)[0][::stride]
        ax_y.scatter(
            times[idx],
            y_cat[idx],
            c=color_offtime,
            s=marker_size,
            marker=".",
            rasterized=True,
            linewidths=0,
            alpha=0.5,
        )

    # In-trial points with combined baseline+active colormap
    if np.any(is_stim):
        vmin = np.nanmin(intrial_time)
        vmax = np.nanmax(intrial_time)
        combined_cmap = _build_trial_cmap(
            vmin,
            vmax,
            cmap_ontime,
            color_offtime,
        )
        sc = ax_y.scatter(
            times[is_stim],
            y_cat[is_stim],
            c=intrial_time[is_stim],
            cmap=combined_cmap,
            vmin=vmin,
            vmax=vmax,
            s=marker_size,
            marker=".",
            rasterized=True,
            linewidths=0,
        )
        cbar_y = fig.colorbar(sc, cax=cax_y)
        cbar_y.set_label("In-trial Time (s)", fontsize=10)
        _apply_ticks(cbar_y, vmin, vmax, is_cbar=True)
    else:
        cax_y.axis("off")

    ax_y.set_yticks(range(len(tick_labels)))
    ax_y.set_yticklabels(tick_labels, fontsize=8)
    ax_y.set_ylabel("Label", fontsize=10)
    ax_y.grid(True, alpha=0.3, axis="y")
    ax_y.tick_params(axis="x", which="both", bottom=False, labelbottom=False)

    # -- Lower panel: neural activity --
    ax_x.set_title("Neural Activity", fontsize=10, loc="left")
    # Pad the x-extent by half a sample on each side so every pixel is
    # centred on its sample time.  This lines the heatmap up with the
    # scatter points on the shared x-axis, and a one-sample window no longer
    # gets a zero-width extent.  The y-extent already gives each channel a
    # [i, i+1] cell.
    half_sample = 0.5 / sfreq
    im = ax_x.imshow(
        X_win.T,
        aspect="auto",
        origin="lower",
        extent=(t_min - half_sample, t_max + half_sample, 0, X_win.shape[1]),
        cmap=cmap_neuro,
        interpolation="nearest",
    )
    cbar_x = fig.colorbar(im, cax=cax_x)
    cbar_x.set_label("Amplitude", fontsize=10)

    # NaN-aware range: imshow skips NaNs, but np.min/np.max would return NaN.
    _apply_ticks(cbar_x, np.nanmin(X_win), np.nanmax(X_win), is_cbar=True)
    _apply_ticks(ax_x, t_min, t_max, axis="x")
    _apply_ticks(ax_x, 0, X_win.shape[1], axis="y", force_int=True)

    ax_x.set_xlabel("In-sample Time (s)", fontsize=10)
    ax_x.set_ylabel("Channel", fontsize=10)

    # -- Cleanup --
    for ax in (ax_y, ax_x):
        ax.spines["top"].set_visible(False)
        ax.spines["right"].set_visible(False)
    ax_y.spines["bottom"].set_visible(False)
    fig.align_ylabels([ax_y, ax_x])

    return fig