Part 3. Data Validation with Pytest
Continuing the data validation from the prior part, this time using pytest. Before reaching for pytest, I had written a standard Python script to check my data. It works well, actually, but I wanted to see how the same checks feel in pytest.
Here are some of my learnings.
Key Syntax & Concepts
Fixtures (@pytest.fixture)
Fixtures are setup functions; I use them to avoid in-place data updates, which happen a lot in DS work.
- scope="session": This is critical for performance. It tells pytest to load the dataframe once and keep it in memory for the entire test run. Without this, pytest would reload the Parquet file 50 times if I have 50 tests.
- autouse=True: Automatically applies the fixture to every test without needing to request it as an argument. I used this for the Polars display config so tables always look nice in logs (sketch below).
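The display fixture is tiny; the full annotated version appears later in this post, but in isolation it is just:

```python
import polars as pl
import pytest

@pytest.fixture(scope="session", autouse=True)
def set_polars_display_settings():
    # applied to every test automatically; no test needs to request it
    pl.Config.set_ascii_tables(True)
    pl.Config.set_tbl_rows(20)
```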
```python
@pytest.fixture(scope="session")
def df():
    # only reads disk once!
    return pl.read_parquet("sample_data.parquet")
```
Parametrization (@pytest.mark.parametrize)
This is the engine of the validation suite. It allows us to write one test function but run it multiple times with different inputs.
How I used it:
- Generate the list of cases: I used list comprehensions to extract rules from the YAML schema.
```python
# creates a list like: [('age', 0), ('salary', 20000)]
min_val_rules = get_rules("min")
```
- Decorate the test:
@pytest.mark.parametrize("col, min_val", min_val_rules)def test_min_val(df, col, min_val): # this function runs once for every item in min_val_rules ...Syntax Deep Dive: @pytest.mark.parametrize
The syntax can be confusing because it links strings to function arguments.
There are two main arguments you must provide:
- argnames (string): A comma-separated string identifying the variable names.
- argvalues (list): A list of data.
```python
import pytest

#                       (1) THE NAMES             (2) THE DATA
#                             │                        │
#                             ▼                        ▼
@pytest.mark.parametrize("param_name, another_param", [
    (value1_a, value1_b),
    (value2_a, value2_b),
    # ... more test cases
])
def test_function(param_name, another_param):
    #                 ▲              ▲
    #                 │              │
    #          (3) MUST MATCH EXACTLY
    assert param_name + another_param > 0
```
Best Practices for Data Validation
- Fail Fast vs. Collect Errors: pytest runs every test and collects all failures by default, which is what you want for data validation (you see every broken column in one run); pass -x (--exitfirst) when you'd rather stop at the first failure.
- Schema as Code: Keeping rules in schema.yaml makes them readable for non-coders (like PMs or stakeholders); a sketch of the file's shape follows this list.
- Sanity Check the Fixture: My df fixture includes a try/except block. If the input file doesn't exist, pytest.fail stops the whole suite immediately, saving time.
- Display Settings: Setting pl.Config.set_tbl_rows(20) in the autouse fixture ensures that if I print a dataframe in a failed test for debugging, I can actually see the data in the CI/CD logs.
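For reference, here is a plausible shape for schema.yaml, reverse-engineered from how the tests below read it. The age/salary min values come from the earlier example; everything else (dtype names, the status column) is an illustrative guess, not the real file:

```yaml
# Assumed shape of schema.yaml, inferred from how the tests consume it.
columns:
  age:
    dtype: int
    min: 0
    nullable: false
  salary:
    dtype: float
    min: 20000
  status:
    dtype: str
    allowed_values: [active, inactive]
```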
Next Step: “Source of Truth” Comparison (The Anti-Join Pattern)
When validating a dataset against a “Gold Source” (or previous production’s run), we want to avoid slow Python loops. In Polars, the efficient way to do “record-by-record” comparison is using Anti-Joins.
Concept: An anti-join returns the rows from the left dataframe that have no matching row in the right dataframe.
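A toy example of the behavior (the dataframes here are made up purely for illustration):

```python
import polars as pl

left = pl.DataFrame({"id": [1, 2, 3]})
right = pl.DataFrame({"id": [1, 2]})

# keeps only the left rows whose "id" has no match in right: the row with id 3
print(left.join(right, on="id", how="anti"))
```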
The Strategy
- Check 1 (completeness): Do I have keys that shouldn't be there? (Anti-join on the primary key.)
  - Sometimes this flips into: "I only want to test these 10k observations, identified by a common key"; see the semi-join sketch after this list.
- Check 2 (correctness): Do I have rows where the data doesn't match the source? (Anti-join on all columns, or sometimes a subset of them.)
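For the "10k observations" case, a semi-join (the keep-matches counterpart of the anti-join) is one way to cut the data down first. This is only a sketch; keys_to_test and its file name are hypothetical:

```python
import polars as pl

df = pl.read_parquet("sample_data.parquet")

# hypothetical file holding the ~10k keys we actually want to validate
keys_to_test = pl.read_parquet("keys_to_test.parquet")

# semi-join keeps the df rows whose key appears in keys_to_test, drops the rest
df_subset = df.join(keys_to_test, on=["id", "date"], how="semi")
```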
Code Implementation
```python
@pytest.fixture(scope="session")
def source_df():
    # load the "gold source" or "truth" file
    return pl.read_parquet("source_of_truth.parquet")

def test_reconciliation_exact_match(df, source_df):
    """
    Validates that rows in 'df' match 'source_df' exactly for common keys.
    """
    # define your primary keys (what makes a row unique?)
    primary_keys = ["id", "date"]

    # define value columns (what data are we comparing?)
    # taking the intersection of columns ensures we only compare what exists in both
    value_cols = [
        c for c in df.columns
        if c in source_df.columns and c not in primary_keys
    ]

    # ---------------------------------------------------------
    # CHECK 1: Unexpected Records (Phantom Keys)
    # Are there IDs in my new data that don't exist in the source?
    # ---------------------------------------------------------
    unexpected_rows = df.join(source_df, on=primary_keys, how="anti")

    assert unexpected_rows.height == 0, \
        f"Found {unexpected_rows.height} records with IDs not present in Source of Truth.\n{unexpected_rows.head()}"

    # ---------------------------------------------------------
    # CHECK 2: Value Mismatches
    # Join on EVERYTHING. If a row remains, it means the combination
    # of (Key + Values) in 'df' was not found in 'source_df'.
    # ---------------------------------------------------------
    # Note: this finds rows that are "different", but implies the key exists
    # (since we passed Check 1).
    comparison_cols = primary_keys + value_cols

    mismatched_rows = df.join(source_df, on=comparison_cols, how="anti")

    # if this fails, it prints the rows from 'df' that are wrong
    assert mismatched_rows.height == 0, \
        f"Found {mismatched_rows.height} records where values differ from Source.\n{mismatched_rows.head()}"
```
Full Code Annotation & Design Patterns
Summary of Approach
- Config-Driven Testing: The logic is decoupled from the data. Adding a new column check requires editing schema.yaml, not the Python code.
- Performance First: The df fixture uses scope="session" to load the Parquet file exactly once, rather than reloading it for every single test case (which would be slow for large datasets). To be fair, this is partly just the pytest way of doing things; in practice, a module-level global variable often works fine too.
- Dynamic Parametrization: We generate test cases programmatically. Pytest sees the list of rules before it even runs the tests, allowing it to report "50 tests passed" rather than "1 loop passed".
Annotated Implementation
```python
import pytest
import polars as pl
import yaml

# NOTE: Good UX tweak.
# 'autouse=True' means I don't need to pass this into every test function.
# Setting ASCII tables ensures that if a test fails in a CI/CD pipeline (like
# GitHub Actions), the dataframe printout remains readable in the text logs.
@pytest.fixture(scope="session", autouse=True)
def set_polars_display_settings():
    pl.Config.set_ascii_tables(True)
    pl.Config.set_tbl_rows(20)

# NOTE: Global load.
# We load the schema outside fixtures so 'parametrize' decorators can access it
# during the "collection phase" (before tests run).
with open("schema.yaml") as f:
    SCHEMA = yaml.safe_load(f)

@pytest.fixture(scope="session")
def df():
    try:
        return pl.read_parquet("sample_data.parquet")
    except Exception as e:
        # TRICK: Fail fast.
        # If the file is missing, don't just error out; explicitly fail the suite
        # with a clear message. This stops pytest from running 100 tests on a NoneType.
        pytest.fail(f"test fail due to {e}")

# ... (get_pl_dtype helper) ...
```
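The helper itself is elided above; a minimal sketch of what it might look like, assuming the schema stores dtype names as short strings (the mapping below is my guess, not the original code):

```python
import polars as pl

# hypothetical mapping from schema dtype names to acceptable Polars dtypes;
# each entry is a list so a test can accept several concrete widths
DTYPE_MAP = {
    "int": [pl.Int32, pl.Int64],
    "float": [pl.Float32, pl.Float64],
    "str": [pl.Utf8],
}

def get_pl_dtype(name: str) -> list:
    return DTYPE_MAP[name]
```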
```python
# NOTE: Generator Pattern.
# We create a list of tuples [(col, props), ...] here.
# Pytest uses this list to generate individual test cases.
dtype_cases = [(col, props) for col, props in SCHEMA["columns"].items()]

@pytest.mark.parametrize("col, props", dtype_cases)
def test_column_presence_type(col, props, df):
    # Check 1: Existence
    assert col in df.columns, f"column missing: {col}"

    # Check 2: Type Safety
    # We use a list for expected_dtype (e.g., [Int32, Int64]) to allow flexibility,
    # because Polars might infer Int32 while we are okay with Int64.
    expected_dtype = get_pl_dtype(props["dtype"])
    actual_dtype = df[col].dtype

    assert actual_dtype in expected_dtype, \
        f"Type mismatch: {col}. Got {actual_dtype}, expected one of {expected_dtype}"
```
```python
# ... (get_rules helper) ...
```
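This one is also elided; assuming the rule values live directly under each column entry in SCHEMA, get_rules could plausibly be:

```python
def get_rules(rule: str) -> list[tuple]:
    # hypothetical reconstruction: collect (column, value) pairs for every
    # column whose schema entry defines this rule,
    # e.g. get_rules("min") -> [('age', 0), ('salary', 20000)]
    return [
        (col, props[rule])
        for col, props in SCHEMA["columns"].items()
        if rule in props
    ]
```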
```python
@pytest.mark.parametrize("col, min_val", get_rules("min"))
def test_min_val(df, col, min_val):
    # PERFORMANCE TIP:
    # use Polars' native .min() (Rust engine) instead of iterating rows in Python
    actual = df[col].min()
    assert actual >= min_val, f"Column '{col}' min check failed."
```
```python
# ... (max check) ...
```
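The max check is elided in the original; an assumed mirror of the min test would be:

```python
@pytest.mark.parametrize("col, max_val", get_rules("max"))
def test_max_val(df, col, max_val):
    # assumed symmetric counterpart to test_min_val above
    actual = df[col].max()
    assert actual <= max_val, f"Column '{col}' max check failed."
```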
```python
# NOTE: Explicit is better than implicit.
# We explicitly look for 'nullable: False' rather than assuming the default is True.
nullable_cases = [
    (col, props["nullable"])
    for col, props in SCHEMA["columns"].items()
    if "nullable" in props and props["nullable"] is False
]

@pytest.mark.parametrize("col, is_nullable", nullable_cases)
def test_no_nulls_allowed(df, col, is_nullable):
    # Polars' .null_count() is essentially instant vs. looping in Python
    null_count = df[col].null_count()
    assert null_count == 0, f"Column '{col}' has {null_count} nulls."
```
```python
@pytest.mark.parametrize("col, allowed_values", get_rules("allowed_values"))
def test_allowed_values(df, col, allowed_values):
    # ALGORITHM: Set Difference.
    # Instead of checking "if x in allowed", we do Set(Actual) - Set(Allowed).
    # Any remainder implies an illegal value.
    expected_values = set(allowed_values)
    actual_values = set(df[col].unique())
    illegal = actual_values - expected_values

    assert len(illegal) == 0, \
        f"Column '{col}' has illegal values: {illegal}"
```
Attachment
Files used: