# pyright: reportMissingImports=false
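"""Schema-driven harmonization (renaming, casting) and validation for polars DataFrames."""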
import yaml
import polars as pl
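# Expected layout of the schema YAML (a sketch with hypothetical column names;
# only the keys this module actually reads are shown):
#
# schema:
#   age:
#     type: integer        # one of: string, float, integer, date, boolean
#     aliases: [patient_age, age_years]
#     range: [0, 120]
#   sex:
#     type: string
#     allowed_value: [male, female]
#     not_allowed_value: [unknown]
#     regex: "^(male|female)$"
#     optional: true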
class DataHarmonizer:
    """Normalizes a DataFrame against the schema: renames columns via aliases and casts types."""
def __init__(self, schema_path: str):
self.schema = self._load_schema(schema_path).get("schema", {})
# map string types to actual polars types for casting
        # TODO: support more precise types (e.g. "int8") if the schema file needs them
self.type_map = {
"string": pl.String,
"float": pl.Float64,
"integer": pl.Int64,
"date": pl.Date,
"boolean": pl.Boolean,
}
def _load_schema(self, schema_path):
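        """Load and parse the YAML schema file, wrapping any failure in a ValueError."""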
try:
with open(schema_path, "r") as f:
return yaml.safe_load(f)
except Exception as e:
raise ValueError(f"Failed to load schema: {e}")
    def harmonize(self, df: pl.DataFrame) -> tuple[pl.DataFrame, pl.DataFrame]:
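        """Rename columns to their schema names via aliases, then cast them to schema types.

        Returns the transformed DataFrame and a report DataFrame of all actions taken.
        """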
report_logs = []
# 1. column renaming based on aliases
rename_map = {}
df_cols = set(df.columns)
for col, props in self.schema.items():
if col not in df_cols:
# check if any alias exists in the dataframe
aliases = props.get("aliases", [])
for alias in aliases:
if alias in df_cols:
rename_map[alias] = col
report_logs.append(
{
"action": "Rename",
"column": col,
"details": f"Renamed from '{alias}'",
}
)
break # stop after finding first match
if rename_map:
df = df.rename(rename_map)
        # 2. type casting, performed on the already-renamed df
for col, props in self.schema.items():
if col in df.columns:
target_type_str = props.get("type")
target_pl_type = self.type_map.get(target_type_str)
if target_pl_type:
current_type = df[col].dtype
if current_type != target_pl_type:
                        # track the null count; new nulls after the cast are failed conversions
nulls_before = df[col].null_count()
                        # strict=False turns unconvertible values into nulls instead of raising
df = df.with_columns(
pl.col(col).cast(target_pl_type, strict=False)
)
nulls_after = df[col].null_count()
failed_rows = nulls_after - nulls_before
if failed_rows > 0:
report_logs.append(
{
"action": "Cast Fail",
"column": col,
"details": f"Failed to cast {failed_rows} rows to {target_type_str}",
}
)
else:
report_logs.append(
{
"action": "Cast Success",
"column": col,
"details": f"Cast from {current_type} to {target_type_str}",
}
)
# return the transformed df and the report
return df, pl.DataFrame(
report_logs,
schema={"action": pl.String, "column": pl.String, "details": pl.String},
)
class DataValidator:
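    """Validates a DataFrame against the schema: existence, type, range, value, and regex checks."""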
def __init__(self, schema_path: str):
self.schema = self._load_schema(schema_path).get("schema", {})
self.polars_dtypes = {
"string": [pl.String, pl.Categorical, pl.Enum],
"float": [pl.Float32, pl.Float64, pl.Decimal],
"integer": [
pl.Int8,
pl.Int16,
pl.Int32,
pl.Int64,
pl.UInt8,
pl.UInt16,
pl.UInt32,
pl.UInt64,
],
"date": [pl.Date, pl.Datetime, pl.Duration, pl.Time],
"boolean": [pl.Boolean],
}
def _load_schema(self, schema_path):
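        """Load and parse the YAML schema file, wrapping any failure in a ValueError."""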
try:
with open(schema_path, "r") as f:
return yaml.safe_load(f)
except Exception as e:
raise ValueError(f"Failed to load schema: {e}")
def _get_example_msg(self, invalid_df: pl.DataFrame, col_name: str) -> str:
"""helper to get up to 3 unique examples."""
try:
invalid_vals = invalid_df[col_name].unique().to_list()
invalid_vals = [v for v in invalid_vals if v is not None]
count = len(invalid_vals)
examples = invalid_vals[:3]
example_str = ", ".join(map(str, examples))
if count > 3:
return f"{example_str}, ... (+{count - 3} more)"
return example_str
except Exception as e:
print(f"no example found: {e}")
return "N/A"
def validate(self, df: pl.DataFrame) -> pl.DataFrame:
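        """Run all schema checks against df and return a DataFrame with one row per failed check."""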
errors = []
for col, props in self.schema.items():
if col not in df.columns:
if not props.get("optional", False):
errors.append(
{
"variable": col,
"check": "Existence",
"description": "Column not found",
"examples": "N/A",
}
)
continue
col_expr = pl.col(col)
# type check
expected_key = props.get("type")
actual_dtype = df[col].dtype
if expected_key in self.polars_dtypes:
valid_types = self.polars_dtypes[expected_key]
# check if actual type is in the allowed list for that key
if actual_dtype not in valid_types:
errors.append(
{
"variable": col,
"check": "Type",
"description": f"Expected {expected_key}, got {actual_dtype}",
"examples": "N/A",
}
)
# range check
if "range" in props:
low, high = props["range"]
invalid = df.filter((col_expr < low) | (col_expr > high))
if not invalid.is_empty():
errors.append(
{
"variable": col,
"check": "Range",
"description": f"{len(invalid)} rows outside [{low}, {high}]",
"examples": self._get_example_msg(invalid, col),
}
)
# allowed values (whitelist)
if "allowed_value" in props:
allowed = props["allowed_value"]
if (
any(isinstance(x, str) for x in allowed)
and actual_dtype != pl.String
):
check_expr = col_expr.cast(pl.String)
else:
check_expr = col_expr
invalid = df.filter(~check_expr.is_in(allowed))
if not invalid.is_empty():
errors.append(
{
"variable": col,
"check": "Allowed Values",
"description": f"{len(invalid)} rows have invalid values",
"examples": self._get_example_msg(invalid, col),
}
)
# not allowed values (blacklist)
if "not_allowed_value" in props:
forbidden = props["not_allowed_value"]
if (
any(isinstance(x, str) for x in forbidden)
and actual_dtype != pl.String
):
check_expr = col_expr.cast(pl.String)
else:
check_expr = col_expr
invalid = df.filter(check_expr.is_in(forbidden))
if not invalid.is_empty():
errors.append(
{
"variable": col,
"check": "Forbidden Values",
"description": f"{len(invalid)} rows found in blacklist",
"examples": self._get_example_msg(invalid, col),
}
)
            # regex check
if "regex" in props:
pattern = props["regex"]
                # only apply regex checks to string columns
if actual_dtype == pl.String:
invalid = df.filter(~col_expr.str.contains(pattern))
if not invalid.is_empty():
errors.append(
{
"variable": col,
"check": "Regex",
"description": f"{len(invalid)} rows mismatch pattern '{pattern}'",
"examples": self._get_example_msg(invalid, col),
}
)
        # return all collected errors as a typed dataframe
schema = {
"variable": pl.String,
"check": pl.String,
"description": pl.String,
"examples": pl.String,
}
return pl.DataFrame(errors, schema=schema)
def print_report(title, df_report):
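    """Print a report DataFrame under a centered title banner, or a no-issues note."""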
width = 60
print("\n" + "=" * width)
print(f"{title:^60}")
print("=" * width + "\n")
if df_report.is_empty():
print("No issues found.")
else:
with pl.Config(
tbl_formatting="ASCII_MARKDOWN",
tbl_hide_column_data_types=True,
tbl_rows=-1,
fmt_str_lengths=100,
):
print(df_report)
print("\n")
def main():
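    """Run the pipeline: load data, harmonize, validate, and take a sample."""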
    # 1. paths to the data and the schema
data_path = "../data/train.parquet"
schema_path = "data_schema.yaml"
# 2. initialize
harmonizer = DataHarmonizer(schema_path)
validator = DataValidator(schema_path)
    # 3. load data
    print(f"Loading {data_path}...")
    df = pl.read_parquet(data_path)
    # lowercase column names so schema/alias lookups match
    df.columns = [col.lower() for col in df.columns]
    # 4. step 1: run harmonization (rename & cast)
df_clean, harm_report = harmonizer.harmonize(df)
print_report("HARMONIZATION REPORT (Renaming & Casting)", harm_report)
# 5. step 2: validate (ranges & logic)
error_df = validator.validate(df_clean)
print_report("VALIDATION REPORT (Logic & Quality)", error_df)
# 6. step 3: output sample
# default 10% sampling
sample_fraction = 0.1
    # handle small datasets gracefully: below 50 rows, keep everything instead of sampling
n_rows = len(df_clean)
if n_rows < 50:
df_sample = df_clean
else:
df_sample = df_clean.sample(fraction=sample_fraction, seed=42)
print(f"Original Row Count: {n_rows}")
print(f"Sampled Row Count: {len(df_sample)} (approx {sample_fraction * 100}%)")
# df_sample.write_parquet("clean_sample.parquet")
print("Sample dataset ready.")
if __name__ == "__main__":
main()