# pyright: reportMissingImports=false
import pytest
import polars as pl
import yaml


@pytest.fixture(scope="session", autouse=True)
def set_polars_display_settings():
    """Apply polars table-display settings once for the whole test session.

    Declared ``autouse`` so the settings are in place without any test
    having to request this fixture.
    """
    config = pl.Config
    config.set_tbl_rows(20)
    config.set_ascii_tables(True)  # Turns ┌── into +--


# Load the schema at import time: the parametrize lists in this module are
# built from SCHEMA while pytest collects the tests.
# The encoding is pinned to UTF-8 so non-ASCII text in the schema is decoded
# the same way on every platform instead of using the locale default.
with open("schema.yaml", encoding="utf-8") as f:
    SCHEMA = yaml.safe_load(f)


@pytest.fixture(scope="session")
def df():
    """Provide the contents of ``sample_data.parquet`` as a polars DataFrame.

    Session-scoped, so the file is read once and shared by every test.
    Any error raised while reading is reported through ``pytest.fail``.
    """
    try:
        frame = pl.read_parquet("sample_data.parquet")
    except Exception as e:
        pytest.fail(f"test fail due to {e}")
    else:
        return frame


def get_pl_dtype(type_str):
    """Translate a schema type name into the polars dtypes accepted for it.

    Args:
        type_str: Type name from the schema ("int", "string", "float" or
            "bool").

    Returns:
        The list of polars dtypes accepted for ``type_str``, or ``None``
        when the name has no entry in the mapping.
    """
    accepted_by_name = {
        "int": [pl.Int8, pl.Int16, pl.Int32, pl.Int64, pl.Int128],
        "string": [pl.Categorical, pl.Utf8],
        "float": [pl.Float16, pl.Float32, pl.Float64],
        "bool": [pl.Boolean],
    }

    try:
        return accepted_by_name[type_str]
    except KeyError:
        # Unknown type names are signalled with None; the caller decides
        # how to report them.
        return None


# One (column name, column properties) pair per schema column; used to
# parametrize the presence/dtype test.
dtype_cases = list(SCHEMA["columns"].items())


@pytest.mark.parametrize("col, props", dtype_cases)
def test_column_presence_type(col, props, df):
    """Check that ``col`` exists in the dataframe with a dtype the schema allows.

    Raises:
        ValueError: if the schema's ``dtype`` for this column has no entry
            in the mapping used by ``get_pl_dtype``.
    """
    assert col in df.columns, f"column missing: {col}"

    # Resolve the schema side first so an unknown type name is reported
    # before the dataframe dtype is inspected.
    expected_dtype = get_pl_dtype(props["dtype"])
    if expected_dtype is None:
        raise ValueError(f"column {col} type in schema file is not found in mapping")

    actual_dtype = df[col].dtype
    dtype_is_accepted = actual_dtype in expected_dtype
    assert (
        dtype_is_accepted
    ), f"column type mismatch: {col}, should be {expected_dtype}, got {actual_dtype}"


def get_rules(rule_name):
    """Collect the value of one rule from every schema column that defines it.

    Args:
        rule_name: Key to look for in each column's properties.

    Returns:
        A list of ``(column name, rule value)`` tuples, in schema order,
        for the columns whose properties contain ``rule_name``.
    """
    pairs = []
    for column, spec in SCHEMA["columns"].items():
        if rule_name in spec:
            pairs.append((column, spec[rule_name]))
    return pairs


@pytest.mark.parametrize("col, min_val", get_rules("min"))
def test_min_val(df, col, min_val):
    """Check that the smallest value in ``col`` is at least the schema's ``min``.

    The test is skipped when the column yields no minimum (``Series.min()``
    returns ``None``, e.g. for an empty or all-null column). Comparing
    ``None`` with the bound would otherwise raise a ``TypeError`` rather
    than produce a readable test outcome.
    """
    actual = df[col].min()
    if actual is None:
        pytest.skip(f"Column '{col}' has no non-null values; min check not applicable")
    assert (
        actual >= min_val
    ), f"Column '{col}' failed min check. Found {actual}, expected >= {min_val}"


@pytest.mark.parametrize("col, max_val", get_rules("max"))
def test_max_val(df, col, max_val):
    """Check that the largest value in ``col`` is at most the schema's ``max``.

    The test is skipped when the column yields no maximum (``Series.max()``
    returns ``None``, e.g. for an empty or all-null column). Comparing
    ``None`` with the bound would otherwise raise a ``TypeError`` rather
    than produce a readable test outcome.
    """
    actual = df[col].max()
    if actual is None:
        pytest.skip(f"Column '{col}' has no non-null values; max check not applicable")
    assert (
        actual <= max_val
    ), f"Column '{col}' failed max check. Found {actual}, expected <= {max_val}"


# (column name, nullable flag) for every column the schema explicitly marks
# with ``nullable: false``; columns without the key are left out.
nullable_cases = [
    (name, spec["nullable"])
    for name, spec in SCHEMA["columns"].items()
    if "nullable" in spec and spec["nullable"] is False
]


@pytest.mark.parametrize("col, is_nullable", nullable_cases)
def test_no_nulls_allowed(df, col, is_nullable):
    """Check that a column marked non-nullable in the schema holds no nulls.

    ``is_nullable`` is the schema flag for the column; it is used only in
    the failure message.
    """
    column = df[col]
    null_count = column.null_count()
    has_no_nulls = null_count == 0
    assert (
        has_no_nulls
    ), f"Column '{col}' must not contain nulls (nullable:{is_nullable}). Found {null_count} nulls."


# Names of the columns whose schema entry sets ``unique`` to exactly True.
unique_cases = [
    name
    for name, spec in SCHEMA["columns"].items()
    if spec.get("unique") is True
]


@pytest.mark.parametrize("col", unique_cases)
def test_uniqueness(df, col):
    """Check that a column the schema marks as unique has no repeated values."""
    unique_mask = df[col].is_unique()
    all_unique = unique_mask.all()
    assert (
        all_unique
    ), f"Column '{col}' contains duplicates, but spec requires unique values."


@pytest.mark.parametrize("col, range", get_rules("range"))
def test_range(df, col, range):
    """Check that every value in ``col`` lies within the schema's ``range``.

    ``range`` is the schema rule value; its first two items are read as the
    inclusive lower and upper bounds. The parameter name shadows the
    builtin inside this function only and is kept so the parametrize
    signature stays unchanged.

    The test is skipped when the column yields no minimum or maximum
    (``Series.min()`` / ``Series.max()`` return ``None``, e.g. for an empty
    or all-null column). Comparing ``None`` with the bounds would otherwise
    raise a ``TypeError``.
    """
    expected_min = range[0]
    expected_max = range[1]
    actual_min = df[col].min()
    actual_max = df[col].max()
    if actual_min is None or actual_max is None:
        pytest.skip(f"Column '{col}' has no non-null values; range check not applicable")
    assert (
        actual_min >= expected_min and actual_max <= expected_max
    ), f"column: {col} out of range. expected range: [{expected_min},{expected_max}], got: [{actual_min},{actual_max}]"


@pytest.mark.parametrize("col, allowed_values", get_rules("allowed_values"))
def test_allowed_values(df, col, allowed_values):
    """Check that ``col`` contains only values listed in the schema's ``allowed_values``."""
    expected_values = set(allowed_values)
    actual_values = set(df[col].unique())

    # Passes exactly when no observed value falls outside the allowed set.
    only_allowed = actual_values <= expected_values
    assert (
        only_allowed
    ), f"column: {col}, has not allowed values. expected: {expected_values}, got: {actual_values}"