# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from typing import TYPE_CHECKING, Any, TypedDict, cast

from typing_extensions import Self

from streamlit.errors import StreamlitAPIException
from streamlit.logger import get_logger
from streamlit.runtime.scriptrunner import get_script_run_ctx

if TYPE_CHECKING:
    from streamlit.runtime.state import SessionState
    from streamlit.runtime.state.common import WidgetValuePresenter


_LOGGER = get_logger(__name__)


class _TriggerPayload(TypedDict, total=False):
    """Shape of a single trigger payload read from the trigger aggregator.

    Declared with ``total=False``, so either key may be absent.
    """

    # Event name; used as the key under which the payload's value is merged
    # into the presented component state.
    event: str
    # Value presented under the event name.
    value: object


def make_bidi_component_presenter(
    aggregator_id: str,
    component_id: str | None = None,
    allowed_state_keys: set[str] | None = None,
) -> WidgetValuePresenter:
    """Return a presenter that merges trigger events into CCv2 state.

    This function returns a callable that takes a component's persistent state
    value and the current `SessionState` instance, and returns the user-visible
    value that should appear in `st.session_state`.

    Calling the presenter does not mutate stored state or callback behavior.
    The mapping it returns, however, is a write-through view: assigning or
    deleting an item on it updates the component's stored state. It is
    intended to be attached to the persistent state widget via the generic
    `presenter` hook.

    Parameters
    ----------
    aggregator_id
        The ID of the trigger aggregator widget that holds the event payloads.
    component_id
        The ID of the component whose state is presented. It is used to
        persist write-through updates and to reject modifications made after
        the component was instantiated in the current run. When ``None``,
        writes only affect the returned mapping and no modification check is
        performed.
    allowed_state_keys
        When not ``None``, item assignments to keys outside this set are
        silently ignored. Deletions are not restricted by this set.

    Returns
    -------
    WidgetValuePresenter
        A callable that merges the trigger event values into the component's
        base state for presentation in `st.session_state`.

    """

    def _present(base_value: object, session_state: SessionState) -> object:
        def _check_modification(k: str) -> None:
            """Raise if the component was already instantiated this run."""
            ctx = get_script_run_ctx()
            if ctx is None or component_id is None:
                return
            user_key = session_state._key_id_mapper.get_key_from_id(component_id)
            if (
                component_id in ctx.widget_ids_this_run
                or user_key in ctx.form_ids_this_run
            ):
                raise StreamlitAPIException(
                    f"`st.session_state.{user_key}.{k}` cannot be modified after the component"
                    f" with key `{user_key}` is instantiated."
                )

        def _persist(current: dict[str, object], failure_message: str) -> None:
            """Store a flat copy of ``current`` as the component's state.

            Persistence is best-effort: failures are logged at debug level
            and never propagate to the caller.
            """
            try:
                if component_id is not None:
                    session_state._new_widget_state.set_from_value(
                        component_id, dict(current)
                    )
            except Exception as e:
                _LOGGER.debug(failure_message, e)

        # Base state must be a flat mapping; otherwise, present as-is.
        if not isinstance(base_value, dict):
            return base_value
        base_map = cast("dict[str, object]", base_value)

        try:
            widget_state = session_state._new_widget_state

            # Only merge when the aggregator is registered as a JSON trigger.
            agg_meta = widget_state.widget_metadata.get(aggregator_id)
            if agg_meta is None or agg_meta.value_type != "json_trigger_value":
                return base_value

            try:
                agg_payloads_obj = widget_state[aggregator_id]
            except KeyError:
                agg_payloads_obj = None

            # Normalize the aggregator value to a list of dict payloads. A
            # list keeps only its dict entries, a single dict is wrapped, and
            # anything else (including None) contributes no events.
            payloads_list: list[_TriggerPayload]
            if isinstance(agg_payloads_obj, list):
                payloads_list = [
                    cast("_TriggerPayload", p)
                    for p in agg_payloads_obj
                    if isinstance(p, dict)
                ]
            elif isinstance(agg_payloads_obj, dict):
                payloads_list = [cast("_TriggerPayload", agg_payloads_obj)]
            else:
                payloads_list = []

            # Later payloads for the same event name win.
            event_to_val: dict[str, object] = {}
            for payload in payloads_list:
                ev = payload.get("event")
                if isinstance(ev, str):
                    event_to_val[ev] = payload.get("value")

            # Merge triggers into a flat view: triggers first, then base, so
            # base state values take precedence on key collisions.
            flat: dict[str, object] = {**event_to_val, **base_map}

            # Return a write-through dict that updates the underlying
            # component state when users assign nested keys via
            # st.session_state[component_user_key][name] = value. Using a
            # dict subclass ensures pretty-printing and JSON serialization
            # behave as expected for st.write and logs.
            # Only __setitem__ and __delitem__ are overridden; bulk dict
            # methods such as update() or pop() do not go through them.
            class _WriteThrough(dict[str, object]):
                def __init__(self, data: dict[str, object]) -> None:
                    super().__init__(data)

                def __getattr__(self, name: str) -> object:
                    # Only reached when normal attribute lookup fails.
                    # Underscore-prefixed names are treated as real
                    # attributes (mirroring __setattr__), so a missing one
                    # must raise; otherwise hasattr()/getattr() protocol
                    # probes (e.g. `_repr_html_`) would always succeed and
                    # yield None.
                    if name.startswith("_") and name not in self:
                        raise AttributeError(name)
                    return self.get(name)

                def __setattr__(self, name: str, value: object) -> None:
                    if name.startswith("_"):
                        return super().__setattr__(name, value)
                    self[name] = value
                    return None

                def __deepcopy__(self, memo: dict[int, Any]) -> Self:
                    # This object is a proxy to the real state. Don't copy it.
                    memo[id(self)] = self
                    return self

                def __setitem__(self, k: str, v: object) -> None:
                    _check_modification(k)

                    if allowed_state_keys is not None and k not in allowed_state_keys:
                        # Silently ignore invalid keys to match permissive session_state semantics
                        return

                    # Update this dict, then store the flat mapping back to
                    # the session state's widget store.
                    super().__setitem__(k, v)
                    _persist(self, "Failed to persist CCv2 state update: %s")

                def __delitem__(self, k: str) -> None:
                    _check_modification(k)

                    super().__delitem__(k)
                    _persist(self, "Failed to persist CCv2 state deletion: %s")

            return _WriteThrough(flat)
        except Exception as e:
            # On any error, fall back to the base value
            _LOGGER.debug(
                "Failed to merge trigger events into component state: %s",
                e,
                exc_info=e,
            )
            return base_value

    return _present
