"""
Pure-JAX AMICA on a NumPy array
===============================

This example shows how to configure :class:`amica.AmicaConfig` and fit AMICA
directly on a NumPy array.

Use this interface when your data is already represented as an array, for
example in a custom pipeline, simulation, or non-MNE workflow.

The expected input shape is ``(n_channels, n_samples)``.
"""

# %%
# Imports
# -------

from __future__ import annotations

import numpy as np

from amica import Amica, AmicaConfig

# %%
# Generate synthetic data
# -----------------------
#
# We create a simple ICA problem by mixing independent Laplacian sources with a
# random linear mixing matrix.


def make_synthetic_mixture(
    n_sources: int = 6,
    n_samples: int = 20_000,
    seed: int = 0,
) -> np.ndarray:
    """Create a random linear mixture of independent sources."""
    rng = np.random.default_rng(seed)

    sources = rng.laplace(size=(n_sources, n_samples))
    mixing = rng.standard_normal((n_sources, n_sources))

    return (mixing @ sources).astype(np.float64)


X = make_synthetic_mixture()

print(f"Data shape: {X.shape}  (channels, samples)")


# %%
# Configure AMICA
# ---------------
#
# ``num_mix_comps`` controls the number of mixture components per source.
# ``do_newton=True`` enables Newton optimization after the initial natural
# gradient phase.

config = AmicaConfig(
    max_iter=500,
    num_mix_comps=3,
    do_newton=True,
)

model = Amica(
    config,
    random_state=42,
)


# %%
# Fit the model
# -------------
#
# JAX will use an available GPU automatically when the JAX GPU package is
# installed; otherwise it runs on CPU.

result = model.fit(X)

final_ll = float(np.asarray(result.log_likelihood)[-1])

print(f"Converged in {int(result.n_iter)} iterations; final log-likelihood = {final_ll:.4f}")


# %%
# Recover sources
# ---------------
#
# The fitted model can transform the observed mixtures back into source
# activations.

estimated_sources = model.transform(X)

print(f"Recovered sources shape: {estimated_sources.shape}")


# %%
# Inspect outputs
# ---------------
#
# The ``result`` object also exposes AMICA parameters such as unmixing matrices,
# mixing matrices, source-density parameters, and, for multi-model AMICA,
# per-model weights and posterior probabilities.

print(type(result))
