Skip to article frontmatterSkip to article content

Running model inference on an array


Imports

import torch
import xbatcher
import xarray as xr
import numpy as np
import pytest

from functions import _get_output_array_size, predict_on_array

Testing the output array size helpers

%%writefile test_get_array_size.py
import torch
import xbatcher
import xarray as xr
import numpy as np
import pytest

from functions import _get_output_array_size, _get_resample_factor
Overwriting test_get_array_size.py
%%writefile -a test_get_array_size.py

@pytest.fixture
def bgen_fixture() -> xbatcher.BatchGenerator:
    """Return a BatchGenerator over a random (x=100, y=100, t=10) DataArray.

    Each dimension gets an integer coordinate 0..N-1. Batches are 10x10
    windows in x/y that overlap by 5. ``t`` is not listed in ``input_dims``
    (presumably it is kept whole in every batch, which the "core dim" test
    cases rely on).
    """
    sizes = {"x": 100, "y": 100, "t": 10}
    source = xr.DataArray(
        np.random.rand(*sizes.values()),
        dims=tuple(sizes),
        coords={name: np.arange(length) for name, length in sizes.items()},
    )
    return xbatcher.BatchGenerator(
        source,
        input_dims={"x": 10, "y": 10},
        input_overlap={"x": 5, "y": 5},
    )

# Each case: (description, output_tensor_dim, new_dim, core_dim, resample_dim,
# expected_output). Source x/y have length 100 and batch windows are 10 wide,
# so a resampled dim's expected full size is its per-batch output size x 10.
_OUTPUT_SIZE_CASES = [
    (
        "Resampling only: Downsample x, Upsample y",
        {'x': 5, 'y': 20},
        [],
        [],
        ['x', 'y'],
        {'x': 50, 'y': 200},
    ),
    (
        "New dimensions only: Add a 'channel' dimension",
        {'channel': 3},
        ['channel'],
        [],
        [],
        {'channel': 3},
    ),
    (
        "Mixed: Resample x, add new channel dimension and keep t as core",
        {'x': 30, 'channel': 12, 't': 10},
        ['channel'],
        ['t'],
        ['x'],
        {'x': 300, 'channel': 12, 't': 10},
    ),
    (
        "Identity resampling (ratio=1)",
        {'x': 10, 'y': 10},
        [],
        [],
        ['x', 'y'],
        {'x': 100, 'y': 100},
    ),
    (
        "Core dims only: 't' is a core dim",
        {'t': 10},
        [],
        ['t'],
        [],
        {'t': 10},
    ),
]


@pytest.mark.parametrize(
    "case_description, output_tensor_dim, new_dim, core_dim, resample_dim, expected_output",
    _OUTPUT_SIZE_CASES,
    # Use the human-readable description as the test ID instead of pytest's
    # auto-generated "output_tensor_dim0-new_dim0-..." suffixes.
    ids=[case[0] for case in _OUTPUT_SIZE_CASES],
)
def test_get_output_array_size_scenarios(
    bgen_fixture,  # The fixture is passed as an argument
    case_description,
    output_tensor_dim,
    new_dim,
    core_dim,
    resample_dim,
    expected_output
):
    """
    Tests various valid scenarios for calculating the output array size.

    ``case_description`` is the pytest test ID for each case (via ``ids=``)
    and is repeated in the assertion message so failures are self-explaining.
    """
    # The `bgen_fixture` argument is the BatchGenerator instance created by our fixture
    result = _get_output_array_size(
        bgen=bgen_fixture,
        output_tensor_dim=output_tensor_dim,
        new_dim=new_dim,
        core_dim=core_dim,
        resample_dim=resample_dim
    )

    assert result == expected_output, f"Failed on case: {case_description}"
Appending to test_get_array_size.py
%%writefile -a test_get_array_size.py

def test_get_output_array_size_raises_error_on_mismatched_core_dim(bgen_fixture):
    """A core dim whose requested size differs from the source is rejected.

    The fixture's ``t`` dimension has length 10, so asking for ``t=99``
    while declaring ``t`` a core dim must raise ``ValueError``.
    """
    mismatched_sizes = {'t': 99}
    expected_message = "does not equal the source data array size"
    with pytest.raises(ValueError, match=expected_message):
        _get_output_array_size(
            bgen_fixture,
            output_tensor_dim=mismatched_sizes,
            new_dim=[],
            core_dim=['t'],
            resample_dim=[],
        )

def test_get_output_array_size_raises_error_on_unspecified_dim(bgen_fixture):
    """An output dim not listed as new, core, or resampled is rejected.

    ``x`` appears in ``output_tensor_dim`` but in none of the three dim
    lists, so ``ValueError`` is expected.
    """
    unclassified = dict(
        output_tensor_dim={'x': 10},
        new_dim=[],
        core_dim=[],
        resample_dim=[],
    )
    with pytest.raises(ValueError, match="must be specified in one of"):
        _get_output_array_size(bgen_fixture, **unclassified)

def test_get_resample_factor_raises_error_on_invalid_ratio(bgen_fixture):
    """A resample ratio that is neither an integer nor 1/integer is rejected.

    The batch window along ``x`` is 10, so an output size of 15 yields a
    ratio of 1.5. ``_get_resample_factor`` signals this with
    ``AssertionError``.

    NOTE(review): if that helper validates with a bare ``assert``, the check
    (and this test) disappears under ``python -O`` -- consider having it
    raise ``ValueError`` instead.
    """
    fractional_size = {'x': 15}  # 15 / 10 == 1.5
    with pytest.raises(AssertionError, match="must be an integer or its inverse"):
        _get_resample_factor(
            bgen_fixture, output_tensor_dim=fractional_size, resample_dim=['x']
        )
Appending to test_get_array_size.py
!pytest -v test_get_array_size.py
============================= test session starts ==============================
platform linux -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0 -- /home/runner/micromamba/envs/cookbook-dev/bin/python3.13
cachedir: .pytest_cache
rootdir: /home/runner/work/xbatcher-deep-learning/xbatcher-deep-learning/notebooks
plugins: anyio-4.10.0
collecting ... 
collected 8 items                                                              

test_get_array_size.py::test_get_output_array_size_scenarios[Resampling only: Downsample x, Upsample y-output_tensor_dim0-new_dim0-core_dim0-resample_dim0-expected_output0] PASSED [ 12%]
test_get_array_size.py::test_get_output_array_size_scenarios[New dimensions only: Add a 'channel' dimension-output_tensor_dim1-new_dim1-core_dim1-resample_dim1-expected_output1] PASSED [ 25%]
test_get_array_size.py::test_get_output_array_size_scenarios[Mixed: Resample x, add new channel dimension and keep t as core-output_tensor_dim2-new_dim2-core_dim2-resample_dim2-expected_output2] PASSED [ 37%]
test_get_array_size.py::test_get_output_array_size_scenarios[Identity resampling (ratio=1)-output_tensor_dim3-new_dim3-core_dim3-resample_dim3-expected_output3] PASSED [ 50%]
test_get_array_size.py::test_get_output_array_size_scenarios[Core dims only: 't' is a core dim-output_tensor_dim4-new_dim4-core_dim4-resample_dim4-expected_output4] PASSED [ 62%]
test_get_array_size.py::test_get_output_array_size_raises_error_on_mismatched_core_dim PASSED [ 75%]
test_get_array_size.py::test_get_output_array_size_raises_error_on_unspecified_dim PASSED [ 87%]
test_get_array_size.py::test_get_resample_factor_raises_error_on_invalid_ratio PASSED [100%]

============================== 8 passed in 1.86s ===============================

Testing the predict_on_array function

%%writefile test_predict_on_array.py
import xarray as xr
import numpy as np
import torch
import xbatcher
import pytest
from xbatcher.loaders.torch import MapDataset

from functions import _get_output_array_size, _resample_coordinate
from functions import predict_on_array, _get_resample_factor
from dummy_models import Identity, MeanAlongDim, SubsetAlongAxis, ExpandAlongAxis, AddAxis
Overwriting test_predict_on_array.py
import xarray as xr
import numpy as np
import torch
import xbatcher
import pytest
from xbatcher.loaders.torch import MapDataset

from functions import *
from dummy_models import *
# A 5x5x5 float32 tensor holding 0..124 in row-major order, used below to
# probe the dummy models' output shapes; display its first fibre.
input_tensor = torch.arange(125, dtype=torch.float32).reshape(5, 5, 5)
input_tensor[0, 0, :]
tensor([0., 1., 2., 3., 4.])
# Dummy model that repeats entries along axis 1; the displayed shape
# confirms axis 1 of the (5, 5, 5) probe grows from 5 to 10.
model = ExpandAlongAxis(ax=1, n_repeats=2)
model(input_tensor).shape
torch.Size([5, 10, 5])
%%writefile -a test_predict_on_array.py

@pytest.fixture
def map_dataset_fixture() -> MapDataset:
    """Wrap a small deterministic 20x10 float32 array in a MapDataset.

    Values are 0..199 laid out row-major over (x, y), with float
    coordinates 0..N-1. Batches are 10x5 windows overlapping by 2 in both
    dimensions.
    """
    n_x, n_y = 20, 10
    values = np.arange(n_x * n_y, dtype=np.float32).reshape(n_x, n_y)
    grid = xr.DataArray(
        data=values,
        dims=("x", "y"),
        coords={
            "x": np.arange(n_x, dtype=float),
            "y": np.arange(n_y, dtype=float),
        },
    )
    generator = xbatcher.BatchGenerator(
        grid,
        input_dims={"x": 10, "y": 5},
        input_overlap={"x": 2, "y": 2},
    )
    return MapDataset(generator)
Appending to test_predict_on_array.py
# Interactive counterpart of map_dataset_fixture: the same 20x10 grid of
# 0..199 and the same 10x5 windows overlapping by 2, but cast to float64.
data = xr.DataArray(
    np.arange(200).reshape(20, 10),
    coords={"x": np.arange(20), "y": np.arange(10)},
    dims=("x", "y"),
).astype(float)
bgen = xbatcher.BatchGenerator(
    data, input_dims={"x": 10, "y": 5}, input_overlap={"x": 2, "y": 2}
)
ds = MapDataset(bgen)
data
Loading...
# Peek at batch 1; the printed values show it spans x rows 0-9 and y columns 3-7.
ds[1]
tensor([[ 3., 4., 5., 6., 7.], [13., 14., 15., 16., 17.], [23., 24., 25., 26., 27.], [33., 34., 35., 36., 37.], [43., 44., 45., 46., 47.], [53., 54., 55., 56., 57.], [63., 64., 65., 66., 67.], [73., 74., 75., 76., 77.], [83., 84., 85., 86., 87.], [93., 94., 95., 96., 97.]], dtype=torch.float64)
# How per-batch model output maps back onto the array: batch windows are
# x=10, y=5, and the output is x=20 (upsampled 2x), y=5 (ratio 1). Both dims
# are therefore "resampled"; no dims are added or passed through as core.
new_dim = []
core_dim = []
resample_dim = ['x', 'y']
output_tensor_dim = {'x': 20, 'y': 5}
ds[0].shape
torch.Size([10, 5])
# On a single unbatched (10, 5) sample, ax=1 is y, so y doubles to 10.
# Inside predict_on_array the same ax=1 presumably hits x, because batches
# there carry a leading batch axis -- hence output_tensor_dim's x=20.
model(ds[0]).shape
torch.Size([10, 10])
# Reload functions.py so edits made after the first import take effect, then
# run batched prediction over the whole dataset.
import functions
from importlib import reload
reload(functions)

result = functions.predict_on_array(
    dataset=ds,
    model=model,
    batch_size=4,
    output_tensor_dim=output_tensor_dim,
    resample_dim=resample_dim,
    core_dim=core_dim,
    new_dim=new_dim,
)
  0%|          | 0/1 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 179.82it/s]

%%writefile -a test_predict_on_array.py

@pytest.mark.parametrize(
    "factor, mode, expected",
    [
        # Upsampling by 2 halves the spacing: 0.0, 0.5, ..., 9.5 (20 points).
        (2.0, "edges", np.arange(0, 10, 0.5)),
        # Downsampling by 2 doubles the spacing: 0, 2, 4, 6, 8 (5 points).
        (0.5, "edges", np.arange(0, 10, 2.0)),
    ],
)
def test_resample_coordinate(factor, mode, expected):
    """Check ``_resample_coordinate`` on the unit-spaced coordinate 0..9."""
    unit_coord = xr.DataArray(np.arange(10, dtype=float), dims="x")
    np.testing.assert_allclose(
        _resample_coordinate(unit_coord, factor, mode), expected
    )
Appending to test_predict_on_array.py
%%writefile -a test_predict_on_array.py

@pytest.mark.parametrize(
    "model, output_tensor_dim, new_dim, core_dim, resample_dim, manual_transform",
    [
        # Case 1: Identity - No change
        (
            Identity(),
            {'x': 10, 'y': 5},
            [], [], ['x', 'y'],
            lambda da: da.data
        ),
        # Case 2: ExpandAlongAxis - Upsampling
        (
            ExpandAlongAxis(ax=1, n_repeats=2), # ax=1 is 'x'
            {'x': 20, 'y': 5},
            [], [], ['x', 'y'],
            lambda da: da.data.repeat(2, axis=0) # axis=0 in the 2D numpy array
        ),
        # Case 3: SubsetAlongAxis - Coarsening
        (
            SubsetAlongAxis(ax=1, n=5), # ax=1 is 'x'
            {'x': 5, 'y': 5},
            [], [], ['x', 'y'],
            lambda da: da.isel(x=slice(0, 5)).data
        ),
        # Case 4: MeanAlongDim - Dimension reduction
        (
            MeanAlongDim(ax=2), # ax=2 is 'y'
            {'x': 10},
            [], [], ['x'],
            lambda da: da.mean(dim='y').data
        ),
        # Case 5: AddAxis - Add a new dimension
        (
            AddAxis(ax=1), # Add new dim at axis 1
            {'channel': 1, 'x': 10, 'y': 5},
            ['channel'], [], ['x', 'y'],
            lambda da: np.expand_dims(da.data, axis=0)
        ),
    ],
    # Readable test IDs instead of "model0-output_tensor_dim0-...".
    ids=["identity", "expand_x", "subset_x", "mean_over_y", "add_channel"],
)
def test_predict_on_array_all_models(
    map_dataset_fixture, model, output_tensor_dim, new_dim, core_dim, resample_dim, manual_transform
):
    """
    Compare ``predict_on_array`` against a brute-force reference reassembly.

    For every batch, ``manual_transform`` reproduces the model's effect with
    NumPy/xarray. The result is added into the matching (possibly rescaled)
    window of a sum accumulator, and a counter records how many windows
    covered each cell. The reference is the per-cell mean. Cells that no
    window covers are NaN, and the final comparison treats NaN as equal.
    """
    dataset = map_dataset_fixture
    bgen = dataset.X_generator
    resample_factor = _get_resample_factor(bgen, output_tensor_dim, resample_dim)

    # --- Run the function under test ---
    result_da = predict_on_array(
        dataset=dataset, model=model, output_tensor_dim=output_tensor_dim,
        new_dim=new_dim, core_dim=core_dim, resample_dim=resample_dim, batch_size=4
    )

    # --- Manually calculate the expected result ---
    expected_size = _get_output_array_size(bgen, output_tensor_dim, new_dim, core_dim, resample_dim)
    expected_sum = xr.DataArray(np.zeros(list(expected_size.values())), dims=list(expected_size.keys()))
    expected_count = xr.full_like(expected_sum, 0, dtype=int)

    for i in range(len(dataset)):
        batch_da = bgen[i]
        # NOTE: private xbatcher API -- the selectors that cut batch ``i`` out
        # of the source array. This may break on xbatcher upgrades.
        old_indexer = bgen._batch_selectors.selectors[i][0]

        # Map the source window onto the output grid:
        # - resampled dims scale their bounds by the resample factor;
        # - core dims keep their bounds unchanged;
        # - any other source dim (e.g. 'y' after MeanAlongDim) is absent from
        #   the output, so it is not indexed.
        new_indexer = {}
        for key in old_indexer:
            if key in resample_dim:
                factor = resample_factor.get(key, 1)
                window = old_indexer[key]
                new_indexer[key] = slice(int(window.start * factor), int(window.stop * factor))
            elif key in core_dim:
                new_indexer[key] = old_indexer[key]

        model_output_on_batch = manual_transform(batch_da)
        # The accumulators have no coordinate labels, so index them
        # positionally (dict indexing is isel). Do not rely on ``.loc``'s
        # implicit positional fallback for coordinate-less dims.
        expected_sum[new_indexer] += model_output_on_batch
        expected_count[new_indexer] += 1

    # Cells no batch window covered have a count of 0. Leave them NaN
    # explicitly instead of evaluating 0/0, which raised a RuntimeWarning.
    counts = expected_count.data
    expected_avg_data = np.divide(
        expected_sum.data,
        counts,
        out=np.full(expected_sum.shape, np.nan),
        where=counts > 0,
    )

    # --- Assert correctness ---
    np.testing.assert_allclose(result_da.values, expected_avg_data, equal_nan=True)
Appending to test_predict_on_array.py
!pytest -v test_predict_on_array.py
============================= test session starts ==============================
platform linux -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0 -- /home/runner/micromamba/envs/cookbook-dev/bin/python3.13
cachedir: .pytest_cache
rootdir: /home/runner/work/xbatcher-deep-learning/xbatcher-deep-learning/notebooks
plugins: anyio-4.10.0
collecting ... 
collected 7 items                                                              

test_predict_on_array.py::test_resample_coordinate[2.0-edges-expected0] PASSED [ 14%]
test_predict_on_array.py::test_resample_coordinate[0.5-edges-expected1] PASSED [ 28%]
test_predict_on_array.py::test_predict_on_array_all_models[model0-output_tensor_dim0-new_dim0-core_dim0-resample_dim0-<lambda>] PASSED [ 42%]
test_predict_on_array.py::test_predict_on_array_all_models[model1-output_tensor_dim1-new_dim1-core_dim1-resample_dim1-<lambda>] 
PASSED [ 57%]
test_predict_on_array.py::test_predict_on_array_all_models[model2-output_tensor_dim2-new_dim2-core_dim2-resample_dim2-<lambda>] PASSED [ 71%]
test_predict_on_array.py::test_predict_on_array_all_models[model3-output_tensor_dim3-new_dim3-core_dim3-resample_dim3-<lambda>] PASSED [ 85%]
test_predict_on_array.py::test_predict_on_array_all_models[model4-output_tensor_dim4-new_dim4-core_dim4-resample_dim4-<lambda>] PASSED [100%]

=============================== warnings summary ===============================
test_predict_on_array.py::test_predict_on_array_all_models[model0-output_tensor_dim0-new_dim0-core_dim0-resample_dim0-<lambda>]
test_predict_on_array.py::test_predict_on_array_all_models[model1-output_tensor_dim1-new_dim1-core_dim1-resample_dim1-<lambda>]
test_predict_on_array.py::test_predict_on_array_all_models[model2-output_tensor_dim2-new_dim2-core_dim2-resample_dim2-<lambda>]
test_predict_on_array.py::test_predict_on_array_all_models[model3-output_tensor_dim3-new_dim3-core_dim3-resample_dim3-<lambda>]
test_predict_on_array.py::test_predict_on_array_all_models[model4-output_tensor_dim4-new_dim4-core_dim4-resample_dim4-<lambda>]
  /home/runner/work/xbatcher-deep-learning/xbatcher-deep-learning/notebooks/test_predict_on_array.py:108: RuntimeWarning: invalid value encountered in divide
    expected_avg_data = expected_sum.data / expected_count.data

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 7 passed, 5 warnings in 2.28s =========================