.. DO NOT EDIT.
.. THIS FILE WAS AUTOMATICALLY GENERATED BY SPHINX-GALLERY.
.. TO MAKE CHANGES, EDIT THE SOURCE PYTHON FILE:
.. "auto_examples/02_pure_jax_fitting.py"
.. LINE NUMBERS ARE GIVEN BELOW.

.. only:: html

    .. note::
        :class: sphx-glr-download-link-note

        :ref:`Go to the end <sphx_glr_download_auto_examples_02_pure_jax_fitting.py>`
        to download the full example code.

.. rst-class:: sphx-glr-example-title

.. _sphx_glr_auto_examples_02_pure_jax_fitting.py:


Pure-JAX AMICA on a NumPy array
===============================

This example shows how to configure :class:`amica.AmicaConfig` and fit AMICA
directly on a NumPy array. Use this interface when your data are already held
in an array, for example in a custom pipeline, a simulation, or a workflow that
does not use MNE. The input must have shape ``(n_channels, n_samples)``.

.. GENERATED FROM PYTHON SOURCE LINES 15-17

Imports
-------

.. GENERATED FROM PYTHON SOURCE LINES 17-24

.. code-block:: Python

    from __future__ import annotations

    import numpy as np

    from amica import Amica, AmicaConfig

.. GENERATED FROM PYTHON SOURCE LINES 25-30

Generate synthetic data
-----------------------

We create a simple ICA problem by mixing independent Laplacian sources with a
random square mixing matrix. Laplacian sources are non-Gaussian, as ICA
requires.

.. GENERATED FROM PYTHON SOURCE LINES 30-51

.. code-block:: Python

    def make_synthetic_mixture(
        n_sources: int = 6,
        n_samples: int = 20_000,
        seed: int = 0,
    ) -> np.ndarray:
        """Create a random linear mixture of independent sources."""
        rng = np.random.default_rng(seed)
        # Independent, heavy-tailed (Laplacian) source signals.
        sources = rng.laplace(size=(n_sources, n_samples))
        # Random square mixing matrix: one row per observed channel.
        mixing = rng.standard_normal((n_sources, n_sources))
        # Observed data with shape (n_channels, n_samples).
        return (mixing @ sources).astype(np.float64)


    X = make_synthetic_mixture()
    print(f"Data shape: {X.shape} (channels, samples)")

.. GENERATED FROM PYTHON SOURCE LINES 52-58

Configure AMICA
---------------

``num_mix_comps`` sets the number of mixture components used to model each
source density. ``do_newton=True`` enables Newton optimization after the
initial natural-gradient phase.

.. GENERATED FROM PYTHON SOURCE LINES 58-71

.. code-block:: Python

    config = AmicaConfig(
        max_iter=500,
        num_mix_comps=3,
        do_newton=True,
    )

    model = Amica(
        config,
        random_state=42,
    )

.. GENERATED FROM PYTHON SOURCE LINES 72-77

Fit the model
-------------

JAX runs on a GPU automatically when a GPU-enabled JAX build is installed;
otherwise it falls back to the CPU.

.. GENERATED FROM PYTHON SOURCE LINES 77-85

.. code-block:: Python

    result = model.fit(X)

    # Report the last recorded value of the log-likelihood history.
    final_ll = float(np.asarray(result.log_likelihood)[-1])
    print(
        f"Finished after {int(result.n_iter)} iterations; "
        f"final log-likelihood = {final_ll:.4f}"
    )

.. GENERATED FROM PYTHON SOURCE LINES 86-91

Recover sources
---------------

The fitted model maps the observed mixtures back to source activations. As
with any ICA method, the estimates match the true sources only up to
permutation, sign, and scale.

.. GENERATED FROM PYTHON SOURCE LINES 91-97

.. code-block:: Python

    estimated_sources = model.transform(X)
    print(f"Recovered sources shape: {estimated_sources.shape}")

.. GENERATED FROM PYTHON SOURCE LINES 98-104

Inspect outputs
---------------

The ``result`` object also exposes the fitted AMICA parameters: unmixing and
mixing matrices, source-density parameters, and, for multi-model AMICA,
per-model weights and posterior probabilities.

.. GENERATED FROM PYTHON SOURCE LINES 104-106

.. code-block:: Python

    print(type(result))


.. _sphx_glr_download_auto_examples_02_pure_jax_fitting.py:

.. only:: html

  .. container:: sphx-glr-footer sphx-glr-footer-example

    .. container:: sphx-glr-download sphx-glr-download-jupyter

      :download:`Download Jupyter notebook: 02_pure_jax_fitting.ipynb <02_pure_jax_fitting.ipynb>`

    .. container:: sphx-glr-download sphx-glr-download-python

      :download:`Download Python source code: 02_pure_jax_fitting.py <02_pure_jax_fitting.py>`

    .. container:: sphx-glr-download sphx-glr-download-zip

      :download:`Download zipped: 02_pure_jax_fitting.zip <02_pure_jax_fitting.zip>`


.. only:: html

 .. rst-class:: sphx-glr-signature

    `Gallery generated by Sphinx-Gallery `_