"""Tests the SARIMAX model."""

__author__ = ["TNTran92", "yarnabrina"]

import pytest
from numpy.testing import assert_allclose
from pandas.testing import assert_frame_equal

from sktime.forecasting.sarimax import SARIMAX
from sktime.tests.test_switch import run_test_for_class
from sktime.utils._testing.forecasting import make_forecasting_problem


@pytest.mark.skipif(
    not run_test_for_class(SARIMAX),
    reason="run test only if softdeps are present and incrementally (if requested)",
)
def test_SARIMAX_against_statsmodels():
    """Check the wrapper's point predictions against statsmodels' SARIMAX.

    Notes
    -----
    * Fits the sktime wrapper and the statsmodels estimator with identical
      hyper-parameters on the same generated data.
    * Predictions are requested over the index of the data used for fitting.
    """
    from statsmodels.tsa.api import SARIMAX as _SARIMAX

    # One set of hyper-parameters shared by both models keeps them in sync.
    model_params = {
        "order": (1, 0, 0),
        "trend": "t",
        "seasonal_order": (1, 0, 0, 6),
    }
    y = make_forecasting_problem()

    # sktime wrapper under test.
    forecaster = SARIMAX(**model_params)
    forecaster.fit(y)
    actual = forecaster.predict(y.index)

    # Reference: the statsmodels estimator used directly, predicting from the
    # first observation onwards.
    reference_results = _SARIMAX(endog=y, **model_params).fit()
    expected = reference_results.predict(y.index[0])

    assert_allclose(actual.tolist(), expected.tolist())


@pytest.mark.skipif(
    not run_test_for_class(SARIMAX),
    reason="run test only if softdeps are present and incrementally (if requested)",
)
def test_SARIMAX_single_interval_against_statsmodels():
    """Check a single prediction interval against statsmodels' SARIMAX.

    Notes
    -----
    * Intervals are computed with the wrapper's ``predict_interval`` and with
      ``get_prediction(...).conf_int`` of the underlying estimator.
    * Only one coverage value is requested.
    * A non-default value of 97.5% is used to check that the input is actually
      being respected.
    """
    from statsmodels.tsa.api import SARIMAX as _SARIMAX

    coverage = 0.975
    alpha = 0.025  # statsmodels takes the significance level, not the coverage
    model_params = {
        "order": (1, 0, 0),
        "trend": "t",
        "seasonal_order": (1, 0, 0, 6),
    }
    y = make_forecasting_problem()

    # sktime wrapper under test.
    forecaster = SARIMAX(**model_params)
    forecaster.fit(y)
    interval_frame = forecaster.predict_interval(y.index, coverage=coverage)
    # Drop the outer column levels by selecting the (0, coverage) group.
    # NOTE(review): 0 is presumably the variable name of the unnamed input.
    actual = interval_frame.xs((0, coverage), axis="columns")

    # Reference: the statsmodels estimator used directly.
    reference_results = _SARIMAX(endog=y, **model_params).fit()
    expected = reference_results.get_prediction(y.index[0]).conf_int(alpha=alpha)
    # Rename so that the column labels match the wrapper's output.
    expected.columns = ["lower", "upper"]

    assert_frame_equal(actual, expected)


@pytest.mark.skipif(
    not run_test_for_class(SARIMAX),
    reason="run test only if softdeps are present and incrementally (if requested)",
)
def test_SARIMAX_multiple_intervals_against_statsmodels():
    """Check several prediction intervals against statsmodels' SARIMAX.

    Notes
    -----
    * Intervals are computed with the wrapper's ``predict_interval`` and with
      ``get_prediction(...).conf_int`` of the underlying estimator.
    * Multiple coverage values are requested in one call, viz. 70% and 80%.
    """
    from statsmodels.tsa.api import SARIMAX as _SARIMAX

    # (coverage for sktime, matching significance level for statsmodels).
    # The alphas are spelled out rather than computed as ``1 - coverage`` so
    # that no floating point rounding enters the comparison.
    coverage_alpha_pairs = ((0.70, 0.30), (0.80, 0.20))
    model_params = {
        "order": (1, 0, 0),
        "trend": "t",
        "seasonal_order": (1, 0, 0, 6),
    }
    y = make_forecasting_problem()

    # sktime wrapper under test: all coverages are requested at once.
    forecaster = SARIMAX(**model_params)
    forecaster.fit(y)
    interval_frame = forecaster.predict_interval(
        y.index, coverage=[coverage for coverage, _ in coverage_alpha_pairs]
    )

    # Reference: the statsmodels estimator used directly.
    reference_results = _SARIMAX(endog=y, **model_params).fit()

    for coverage, alpha in coverage_alpha_pairs:
        actual = interval_frame.xs((0, coverage), axis="columns")

        expected = reference_results.get_prediction(y.index[0]).conf_int(alpha=alpha)
        # Rename so that the column labels match the wrapper's output.
        expected.columns = ["lower", "upper"]

        assert_frame_equal(actual, expected)


@pytest.mark.skipif(
    not run_test_for_class(SARIMAX),
    reason="run test only if softdeps are present and incrementally (if requested)",
)
def test_SARIMAX_for_exogeneous_features():
    """Check that ``predict`` accepts ``X`` when ``fit`` was called without it.

    Notes
    -----
    * The forecaster is fitted on the target only.
    * Exogenous data is then passed to ``predict``; the call must not raise
      and must return one forecast per requested horizon step.
    """
    from sktime.datasets import load_longley
    from sktime.split import temporal_train_test_split

    fh = [1, 2, 3, 4]

    y, X = load_longley()
    y_train, _, _, X_test = temporal_train_test_split(y, X)
    forecaster = SARIMAX()
    forecaster.fit(y_train)
    y_pred = forecaster.predict(fh=fh, X=X_test)

    # The result used to be discarded, so only exceptions were detected;
    # also verify that a forecast is produced for every step of the horizon.
    assert len(y_pred) == len(fh)


@pytest.mark.skipif(
    not run_test_for_class(SARIMAX),
    reason="run test only if softdeps are present and incrementally (if requested)",
)
def test_SARIMAX_update_with_exogenous_variables():
    """Test update method with exogenous variables (targets PR #8626 bug).

    Notes
    -----
    * A fresh forecaster is fitted with exogenous data for each scenario.
    * ``update`` is exercised first with ``update_params=True`` and then with
      ``update_params=False``.
    """
    from sktime.datasets import load_longley
    from sktime.split import temporal_train_test_split

    y, X = load_longley()
    y_train, y_test, X_train, X_test = temporal_train_test_split(y, X)
    expected_cutoff = y_test.index[-1]

    for update_params in (True, False):
        forecaster = SARIMAX(order=(1, 0, 0))
        forecaster.fit(y_train, X=X_train, fh=1)
        forecaster.update(y_test, X=X_test, update_params=update_params)

        # After the update the cutoff must sit at the last updated time point
        # and the forecaster must still report itself as fitted.
        assert forecaster.cutoff == expected_cutoff
        assert forecaster._is_fitted
