# Copyright (c) Streamlit Inc. (2018-2022) Snowflake Inc. (2022-2025)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Discovery utilities for Component v2 manifests in installed packages.

The scanner searches installed distributions for a ``pyproject.toml`` with
``[tool.streamlit.component]`` configuration and extracts the component
manifests along with their package roots.

The implementation prioritizes efficiency and safety by filtering likely
candidates and avoiding excessive filesystem operations.
"""

from __future__ import annotations

import importlib.metadata
import importlib.util
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Final

import toml
from packaging import utils as packaging_utils

from streamlit.components.v2.component_path_utils import ComponentPathUtils
from streamlit.errors import StreamlitComponentRegistryError
from streamlit.logger import get_logger

_LOGGER: Final = get_logger(__name__)


def _normalize_package_name(dist_name: str) -> str:
    """Normalize a distribution name to an importable package name.

    This helper converts hyphens to underscores to derive a best-effort
    importable module/package name from a distribution name.

    Parameters
    ----------
    dist_name : str
        The distribution/project name (e.g., "my-awesome-component").

    Returns
    -------
    str
        The normalized package name suitable for import lookups
        (e.g., "my_awesome_component").
    """
    return dist_name.replace("-", "_")


@dataclass
class ComponentManifest:
    """Parsed component manifest data."""

    name: str
    version: str
    components: list[ComponentConfig]


@dataclass
class ComponentConfig:
    """Structured configuration for a single component entry.

    Parameters
    ----------
    name
        Component name as declared in ``pyproject.toml``.
    asset_dir
        Optional relative directory containing component assets.
    """

    name: str
    asset_dir: str | None = None

    @staticmethod
    def from_dict(config: dict[str, Any]) -> ComponentConfig:
        """Create a ComponentConfig from a raw dict.

        Parameters
        ----------
        config
            Raw component dictionary parsed from TOML.

        Returns
        -------
        ComponentConfig
            Parsed and validated component configuration.
        """
        name_value = config.get("name")
        if not isinstance(name_value, str) or not name_value:
            # Fail closed: invalid component entry
            raise ValueError("Component entry missing required 'name' field")

        asset_dir_value = config.get("asset_dir")
        if asset_dir_value is not None and not isinstance(asset_dir_value, str):
            # Fail closed: invalid asset_dir value
            raise ValueError("'asset_dir' must be a string")

        return ComponentConfig(
            name=name_value,
            asset_dir=asset_dir_value,
        )

    @staticmethod
    def parse_or_none(config: dict[str, Any]) -> ComponentConfig | None:
        """Best-effort parse without raising; returns None on malformed input."""
        try:
            return ComponentConfig.from_dict(config)
        except Exception as e:
            _LOGGER.debug("Skipping malformed component entry: %s", e)
            return None

    def resolve_asset_root(self, package_root: Path) -> Path | None:
        """Resolve and security-check the component's asset root directory.

        Parameters
        ----------
        package_root : Path
            The root directory of the installed component package.

        Returns
        -------
        Path | None
            Absolute, resolved path to the asset directory, or ``None`` if
            ``asset_dir`` is not declared.

        Raises
        ------
        StreamlitComponentRegistryError
            If the declared directory does not exist, is not a directory, or
            resolves outside of ``package_root``.
        """
        if self.asset_dir is None:
            return None

        # Validate the configured path string first
        ComponentPathUtils.validate_path_security(self.asset_dir)

        asset_root = (package_root / self.asset_dir).resolve()

        if not asset_root.exists() or not asset_root.is_dir():
            raise StreamlitComponentRegistryError(
                f"Declared asset_dir '{self.asset_dir}' for component '{self.name}' "
                f"does not exist or is not a directory under package root '{package_root}'."
            )

        # Ensure the resolved directory is within the package root after following symlinks
        ComponentPathUtils.ensure_within_root(
            abs_path=asset_root,
            root=package_root.resolve(),
            kind="asset_dir",
        )

        return asset_root


def _is_likely_streamlit_component_package(
    dist: importlib.metadata.Distribution,
) -> bool:
    """Check if a package is likely to contain streamlit components before
    expensive operations.

    This early filter reduces the number of packages that need file I/O
    operations from potentially hundreds down to just a few candidates.

    Parameters
    ----------
    dist : importlib.metadata.Distribution
        The package distribution to check.

    Returns
    -------
    bool
        True if the package might contain streamlit components, False otherwise.
    """
    # Get package metadata
    name = dist.name.lower()
    summary = dist.metadata["Summary"].lower() if "Summary" in dist.metadata else ""

    # Filter 1: Package name suggests streamlit component
    if "streamlit" in name:
        return True

    # Filter 2: Package description mentions streamlit
    if "streamlit" in summary:
        return True

    # Filter 3: Check if package depends on streamlit
    try:
        # Check requires_dist for streamlit dependency
        requires_dist = dist.metadata.get_all("Requires-Dist") or []
        for requirement in requires_dist:
            if requirement and "streamlit" in requirement.lower():
                return True
    except Exception as e:
        # Don't fail on metadata parsing issues, but log for debugging purposes
        _LOGGER.debug(
            "Failed to parse package metadata for streamlit component detection: %s", e
        )

    # Filter 4: Check if this is a known streamlit ecosystem package
    # Common patterns in streamlit component package names. Use anchored checks to
    # avoid matching unrelated packages like "test-utils".
    return name.startswith(("streamlit-", "streamlit_", "st-", "st_"))


def _find_package_pyproject_toml(dist: importlib.metadata.Distribution) -> Path | None:
    """Find ``pyproject.toml`` for a package.

    Handles both regular and editable installs. The function uses increasingly
    permissive strategies to locate the file while validating that the file
    belongs to the given distribution.

    Parameters
    ----------
    dist : importlib.metadata.Distribution
        The package distribution to find pyproject.toml for.

    Returns
    -------
    Path | None
        Path to the ``pyproject.toml`` file if found, otherwise ``None``.
    """
    package_name = _normalize_package_name(dist.name)

    # Try increasingly permissive strategies
    for finder in (
        _pyproject_via_read_text,
        _pyproject_via_dist_files,
        lambda d: _pyproject_via_import_spec(d, package_name),
    ):
        result = finder(dist)
        if result is not None:
            return result

    return None


def _pyproject_via_read_text(dist: importlib.metadata.Distribution) -> Path | None:
    """Locate pyproject.toml using the distribution's read_text + nearby files.

    This works for many types of installations including some editable ones.
    """
    package_name = _normalize_package_name(dist.name)
    try:
        if hasattr(dist, "read_text"):
            pyproject_content = dist.read_text("pyproject.toml")
            if pyproject_content and dist.files:
                # Found content, now find the actual file path
                # Look for a reasonable file to get the directory
                for file in dist.files:
                    if "__init__.py" in str(file) or ".py" in str(file):
                        try:
                            file_path = Path(str(dist.locate_file(file)))
                            # Check nearby directories for pyproject.toml
                            current_dir = file_path.parent
                            # Check current directory and parent
                            for search_dir in [current_dir, current_dir.parent]:
                                pyproject_path = search_dir / "pyproject.toml"
                                if (
                                    pyproject_path.exists()
                                    and _validate_pyproject_for_package(
                                        pyproject_path,
                                        dist.name,
                                        package_name,
                                    )
                                ):
                                    return pyproject_path
                            # Stop after first reasonable file
                            break
                        except Exception:  # noqa: S112
                            continue
    except Exception:
        return None
    return None


def _pyproject_via_dist_files(dist: importlib.metadata.Distribution) -> Path | None:
    """Locate pyproject.toml by scanning the distribution's file list."""
    package_name = _normalize_package_name(dist.name)
    files = getattr(dist, "files", None)
    if not files:
        return None
    for file in files:
        if getattr(file, "name", None) == "pyproject.toml" or str(file).endswith(
            "pyproject.toml"
        ):
            try:
                pyproject_path = Path(str(dist.locate_file(file)))
                if _validate_pyproject_for_package(
                    pyproject_path,
                    dist.name,
                    package_name,
                ):
                    return pyproject_path
            except Exception:  # noqa: S112
                continue
    return None


def _pyproject_via_import_spec(
    dist: importlib.metadata.Distribution, package_name: str
) -> Path | None:
    """Locate pyproject.toml by resolving the import spec and checking nearby.

    For editable installs, try the package directory and its parent only.
    """
    try:
        spec = importlib.util.find_spec(package_name)
        if spec and spec.origin:
            package_dir = Path(spec.origin).parent
            for search_dir in [package_dir, package_dir.parent]:
                pyproject_path = search_dir / "pyproject.toml"
                if pyproject_path.exists() and _validate_pyproject_for_package(
                    pyproject_path,
                    dist.name,
                    package_name,
                ):
                    return pyproject_path
    except Exception:
        return None
    return None


def _validate_pyproject_for_package(
    pyproject_path: Path, dist_name: str, package_name: str
) -> bool:
    """Validate that a ``pyproject.toml`` file belongs to the specified package.

    Parameters
    ----------
    pyproject_path : Path
        Path to the pyproject.toml file to validate.
    dist_name : str
        The distribution name (e.g., "streamlit-bokeh").
    package_name : str
        The package name (e.g., "streamlit_bokeh").

    Returns
    -------
    bool
        True if the file belongs to this package, False otherwise.
    """
    try:
        with open(pyproject_path, encoding="utf-8") as f:
            pyproject_data = toml.load(f)

        # Check if this pyproject.toml is for the package we're looking for
        project_name = None

        # Try to get the project name from [project] table
        if "project" in pyproject_data and "name" in pyproject_data["project"]:
            project_name = pyproject_data["project"]["name"]

        # Also try to get it from [tool.setuptools] or other build system configs
        if (
            not project_name
            and "tool" in pyproject_data
            and (
                "setuptools" in pyproject_data["tool"]
                and "package-name" in pyproject_data["tool"]["setuptools"]
            )
        ):
            project_name = pyproject_data["tool"]["setuptools"]["package-name"]

        # If we found a project name, check if it matches either the dist name or package name
        if project_name:
            # Normalize names for comparison using PEP 503 canonicalization
            # This handles hyphens, underscores, and dots consistently.
            canonical_project = packaging_utils.canonicalize_name(project_name)
            canonical_dist = packaging_utils.canonicalize_name(dist_name)
            canonical_package = packaging_utils.canonicalize_name(package_name)

            # Check if project name matches either the distribution name or the package name
            return canonical_project in (canonical_dist, canonical_package)

        # If we can't determine ownership, be conservative and reject it
        return False

    except Exception as e:
        _LOGGER.debug(
            "Error validating pyproject.toml at %s for %s: %s",
            pyproject_path,
            dist_name,
            e,
        )
        return False


def _load_pyproject(pyproject_path: Path) -> dict[str, Any] | None:
    """Load and parse a pyproject.toml, returning parsed data or None on failure."""
    try:
        with open(pyproject_path, encoding="utf-8") as f:
            return toml.load(f)
    except Exception as e:
        _LOGGER.debug("Failed to parse pyproject.toml at %s: %s", pyproject_path, e)
        return None


def _extract_components(pyproject_data: dict[str, Any]) -> list[dict[str, Any]] | None:
    """Extract raw component dicts from pyproject data; return None if absent."""
    streamlit_component = (
        pyproject_data.get("tool", {}).get("streamlit", {}).get("component")
    )
    if not streamlit_component:
        return None
    raw_components = streamlit_component.get("components")
    if not isinstance(raw_components, list):
        return None
    # Ensure a list of dicts for type safety
    result: list[dict[str, Any]] = [
        item for item in raw_components if isinstance(item, dict)
    ]
    if not result:
        return None
    return result


def _resolve_package_root(
    dist: importlib.metadata.Distribution, package_name: str, pyproject_path: Path
) -> Path:
    """Resolve the package root directory with fallbacks."""
    package_root: Path | None = None
    try:
        spec = importlib.util.find_spec(package_name)
        if spec and spec.origin:
            package_root = Path(spec.origin).parent
    except Exception as e:
        _LOGGER.debug(
            "Failed to resolve package root via import spec for %s: %s",
            package_name,
            e,
        )

    files = getattr(dist, "files", None)
    if not package_root and files:
        for file in files:
            if package_name in str(file) and "__init__.py" in str(file):
                try:
                    init_path = Path(str(dist.locate_file(file)))
                    package_root = init_path.parent
                    break
                except Exception as e:
                    _LOGGER.debug(
                        "Failed to resolve package root via dist files for %s: %s",
                        package_name,
                        e,
                    )

    if not package_root:
        package_root = pyproject_path.parent

    return package_root


def _derive_project_metadata(
    pyproject_data: dict[str, Any], dist: importlib.metadata.Distribution
) -> tuple[str, str]:
    """Derive project name and version with safe fallbacks."""
    project_table = pyproject_data.get("project", {})
    derived_name = project_table.get("name") or dist.name
    derived_version = project_table.get("version") or dist.version or "0.0.0"
    return derived_name, derived_version


def _process_single_package(
    dist: importlib.metadata.Distribution,
) -> tuple[ComponentManifest, Path] | None:
    """Process a single package to extract component manifest.

    This function is designed to be called from a thread pool for parallel processing.

    Parameters
    ----------
    dist : importlib.metadata.Distribution
        The package distribution to process.

    Returns
    -------
    tuple[ComponentManifest, Path] | None
        The manifest and package root if found, otherwise ``None``.
    """
    try:
        pyproject_path = _find_package_pyproject_toml(dist)
        if not pyproject_path:
            return None

        pyproject_data = _load_pyproject(pyproject_path)
        if pyproject_data is None:
            return None

        raw_components = _extract_components(pyproject_data)
        if not raw_components:
            return None

        package_name = _normalize_package_name(dist.name)
        package_root = _resolve_package_root(dist, package_name, pyproject_path)

        derived_name, derived_version = _derive_project_metadata(pyproject_data, dist)

        parsed_components: list[ComponentConfig] = [
            parsed
            for comp in raw_components
            if (parsed := ComponentConfig.parse_or_none(comp)) is not None
        ]

        if not parsed_components:
            return None

        manifest = ComponentManifest(
            name=derived_name,
            version=derived_version,
            components=parsed_components,
        )

        return (manifest, package_root)

    except Exception as e:
        _LOGGER.debug(
            "Unexpected error processing distribution %s: %s",
            getattr(dist, "name", "<unknown>"),
            e,
        )
        return None


def scan_component_manifests(
    max_workers: int | None = None,
) -> list[tuple[ComponentManifest, Path]]:
    """Scan installed packages for Streamlit component metadata.

    Uses parallel processing to improve performance in environments with many
    installed packages. Applies early filtering to only check packages likely to
    contain streamlit components.

    Parameters
    ----------
    max_workers : int or None
        Maximum number of worker threads. If None, uses min(32, (os.cpu_count()
        or 1) + 4).

    Returns
    -------
    list[tuple[ComponentManifest, Path]]
        List of tuples of manifests and their package root paths.
    """
    manifests: list[tuple[ComponentManifest, Path]] = []

    # Get all distributions first (this is fast)
    all_distributions = list(importlib.metadata.distributions())

    if not all_distributions:
        return manifests

    # Apply early filtering to reduce expensive file operations
    candidate_distributions = [
        dist
        for dist in all_distributions
        if _is_likely_streamlit_component_package(dist)
    ]

    _LOGGER.debug(
        "Filtered %d packages down to %d candidates for component scanning",
        len(all_distributions),
        len(candidate_distributions),
    )

    if not candidate_distributions:
        return manifests

    # Default max_workers follows ThreadPoolExecutor's default logic
    if max_workers is None:
        max_workers = min(32, (os.cpu_count() or 1) + 4)

    # Clamp max_workers to reasonable bounds for this task
    max_workers = min(
        max_workers, len(candidate_distributions), 16
    )  # Don't use more threads than packages or 16

    _LOGGER.debug(
        "Scanning %d candidate packages for component manifests using %d worker threads",
        len(candidate_distributions),
        max_workers,
    )

    # Process packages in parallel
    with ThreadPoolExecutor(max_workers=max_workers) as executor:
        # Submit all tasks
        future_to_dist = {
            executor.submit(_process_single_package, dist): dist.name
            for dist in candidate_distributions
        }

        # Collect results as they complete
        for future in as_completed(future_to_dist):
            result = future.result()
            if result:
                manifests.append(result)

    _LOGGER.debug("Found %d component manifests total", len(manifests))
    return manifests
