Imports
import torch
import xbatcher
import xarray as xr
import numpy as np
import pytest
from functions import _get_output_array_size, predict_on_array
Testing the array size function
%%writefile test_get_array_size.py
import torch
import xbatcher
import xarray as xr
import numpy as np
import pytest
from functions import _get_output_array_size, _get_resample_factor
Overwriting test_get_array_size.py
%%writefile -a test_get_array_size.py
@pytest.fixture
def bgen_fixture() -> xbatcher.BatchGenerator:
    """Provide a BatchGenerator over a random 100 x 100 x 10 (x, y, t) array.

    Patches are 10 x 10 in (x, y) with an overlap of 5 along both of those
    dimensions; ``t`` is not listed in ``input_dims``.
    """
    # One mapping drives the data shape, the dimension names and the coordinates.
    dim_sizes = {"x": 100, "y": 100, "t": 10}
    source = xr.DataArray(
        data=np.random.rand(*dim_sizes.values()),
        dims=tuple(dim_sizes),
        coords={name: np.arange(length) for name, length in dim_sizes.items()},
    )
    patch_shape = {"x": 10, "y": 10}
    patch_overlap = {"x": 5, "y": 5}
    return xbatcher.BatchGenerator(
        source,
        input_dims=patch_shape,
        input_overlap=patch_overlap,
    )
@pytest.mark.parametrize(
    "case_description, output_tensor_dim, new_dim, core_dim, resample_dim, expected_output",
    [
        (
            "Resampling only: Downsample x, Upsample y",
            {"x": 5, "y": 20},
            [],
            [],
            ["x", "y"],
            {"x": 50, "y": 200},
        ),
        (
            "New dimensions only: Add a 'channel' dimension",
            {"channel": 3},
            ["channel"],
            [],
            [],
            {"channel": 3},
        ),
        (
            "Mixed: Resample x, add new channel dimension and keep t as core",
            {"x": 30, "channel": 12, "t": 10},
            ["channel"],
            ["t"],
            ["x"],
            {"x": 300, "channel": 12, "t": 10},
        ),
        (
            "Identity resampling (ratio=1)",
            {"x": 10, "y": 10},
            [],
            [],
            ["x", "y"],
            {"x": 100, "y": 100},
        ),
        (
            "Core dims only: 't' is a core dim",
            {"t": 10},
            [],
            ["t"],
            [],
            {"t": 10},
        ),
    ],
)
def test_get_output_array_size_scenarios(
    bgen_fixture,
    case_description,
    output_tensor_dim,
    new_dim,
    core_dim,
    resample_dim,
    expected_output,
):
    """Check the computed output array size for each valid dimension layout.

    ``bgen_fixture`` is injected by pytest. ``case_description`` does not
    influence the computation; it only labels the case in the assertion
    message so a failure is easy to identify.
    """
    dimension_spec = {
        "output_tensor_dim": output_tensor_dim,
        "new_dim": new_dim,
        "core_dim": core_dim,
        "resample_dim": resample_dim,
    }
    actual_size = _get_output_array_size(bgen=bgen_fixture, **dimension_spec)
    assert actual_size == expected_output, f"Failed on case: {case_description}"
Appending to test_get_array_size.py
%%writefile -a test_get_array_size.py
def test_get_output_array_size_raises_error_on_mismatched_core_dim(bgen_fixture):
    """A core dim whose requested size differs from the source raises ValueError."""
    # The fixture's source array has t=10, so requesting t=99 is inconsistent.
    mismatched_spec = {
        "output_tensor_dim": {"t": 99},
        "new_dim": [],
        "core_dim": ["t"],
        "resample_dim": [],
    }
    with pytest.raises(ValueError, match="does not equal the source data array size"):
        _get_output_array_size(bgen_fixture, **mismatched_spec)
def test_get_output_array_size_raises_error_on_unspecified_dim(bgen_fixture):
    """An output dimension assigned to no category raises ValueError."""
    # 'x' appears in the output but is in none of new_dim / core_dim / resample_dim.
    uncategorised_spec = {
        "output_tensor_dim": {"x": 10},
        "new_dim": [],
        "core_dim": [],
        "resample_dim": [],
    }
    with pytest.raises(ValueError, match="must be specified in one of"):
        _get_output_array_size(bgen_fixture, **uncategorised_spec)
def test_get_resample_factor_raises_error_on_invalid_ratio(bgen_fixture):
    """A resample ratio that is neither an integer nor its inverse raises AssertionError."""
    # Output 15 over an input patch of 10 gives 1.5, which is not a valid ratio.
    invalid_output = {"x": 15}
    with pytest.raises(AssertionError, match="must be an integer or its inverse"):
        _get_resample_factor(
            bgen_fixture,
            output_tensor_dim=invalid_output,
            resample_dim=["x"],
        )
Appending to test_get_array_size.py
!pytest -v test_get_array_size.py
============================= test session starts ==============================
platform linux -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0 -- /home/runner/micromamba/envs/cookbook-dev/bin/python3.13
cachedir: .pytest_cache
rootdir: /home/runner/work/xbatcher-deep-learning/xbatcher-deep-learning/notebooks
plugins: anyio-4.10.0
collecting ...
collected 8 items
test_get_array_size.py::test_get_output_array_size_scenarios[Resampling only: Downsample x, Upsample y-output_tensor_dim0-new_dim0-core_dim0-resample_dim0-expected_output0] PASSED [ 12%]
test_get_array_size.py::test_get_output_array_size_scenarios[New dimensions only: Add a 'channel' dimension-output_tensor_dim1-new_dim1-core_dim1-resample_dim1-expected_output1] PASSED [ 25%]
test_get_array_size.py::test_get_output_array_size_scenarios[Mixed: Resample x, add new channel dimension and keep t as core-output_tensor_dim2-new_dim2-core_dim2-resample_dim2-expected_output2] PASSED [ 37%]
test_get_array_size.py::test_get_output_array_size_scenarios[Identity resampling (ratio=1)-output_tensor_dim3-new_dim3-core_dim3-resample_dim3-expected_output3] PASSED [ 50%]
test_get_array_size.py::test_get_output_array_size_scenarios[Core dims only: 't' is a core dim-output_tensor_dim4-new_dim4-core_dim4-resample_dim4-expected_output4] PASSED [ 62%]
test_get_array_size.py::test_get_output_array_size_raises_error_on_mismatched_core_dim PASSED [ 75%]
test_get_array_size.py::test_get_output_array_size_raises_error_on_unspecified_dim PASSED [ 87%]
test_get_array_size.py::test_get_resample_factor_raises_error_on_invalid_ratio PASSED [100%]
============================== 8 passed in 1.86s ===============================
Testing the predict_on_array function
%%writefile test_predict_on_array.py
import xarray as xr
import numpy as np
import torch
import xbatcher
import pytest
from xbatcher.loaders.torch import MapDataset
from functions import _get_output_array_size, _resample_coordinate
from functions import predict_on_array, _get_resample_factor
from dummy_models import Identity, MeanAlongDim, SubsetAlongAxis, ExpandAlongAxis, AddAxis
Overwriting test_predict_on_array.py
import xarray as xr
import numpy as np
import torch
import xbatcher
import pytest
from xbatcher.loaders.torch import MapDataset
from functions import *
from dummy_models import *
# 5 x 5 x 5 float32 tensor holding 0..124 in row-major order; used below to
# try out the dummy models by hand.
input_tensor = torch.arange(125, dtype=torch.float32).reshape(5, 5, 5)
# Peek at the first run along the last axis (expected: 0., 1., 2., 3., 4.).
input_tensor[0, 0, :]
tensor([0., 1., 2., 3., 4.])
# Dummy model constructed with positional arguments (1, 2). Judging by the shape
# printed for the next expression, it doubles the length of axis 1 (5 -> 10) and
# leaves the other axes untouched.
model = ExpandAlongAxis(1, 2)
# Apply it to the 5 x 5 x 5 tensor and inspect the resulting shape.
model(input_tensor).shape
torch.Size([5, 10, 5])
%%writefile -a test_predict_on_array.py
@pytest.fixture
def map_dataset_fixture() -> MapDataset:
    """Provide a MapDataset over a 20 x 10 (x, y) float32 ramp of 0..199.

    Patches are 10 (x) by 5 (y) with an overlap of 2 along each dimension.
    Both coordinates are float-valued index ranges.
    """
    x_size, y_size = 20, 10
    ramp = np.arange(x_size * y_size).reshape(x_size, y_size).astype(np.float32)
    source = xr.DataArray(
        data=ramp,
        dims=("x", "y"),
        coords={
            "x": np.arange(x_size, dtype=float),
            "y": np.arange(y_size, dtype=float),
        },
    )
    patches = xbatcher.BatchGenerator(
        source,
        input_dims={"x": 10, "y": 5},
        input_overlap={"x": 2, "y": 2},
    )
    return MapDataset(patches)
Appending to test_predict_on_array.py
# Interactive copy of the test data: a 20 x 10 (x, y) ramp of 0..199 with
# integer coordinates, with the whole DataArray cast to float.
data = xr.DataArray(
    data=np.arange(20 * 10).reshape(20, 10),
    dims=("x", "y"),
    coords={"x": np.arange(20), "y": np.arange(10)},
).astype(float)

# Patches of 10 (x) by 5 (y), overlapping by 2 along each dimension.
bgen = xbatcher.BatchGenerator(
    data,
    input_dims={"x": 10, "y": 5},
    input_overlap={"x": 2, "y": 2},
)

# Index-addressable dataset over the generator's patches.
ds = MapDataset(bgen)

# Display the source array.
data
Loading...
# Second patch. In the printed values the rows run 0..9 along x and the columns
# cover y = 3..7, i.e. the y window has advanced by 5 - 2 = 3 from the first patch.
ds[1]
tensor([[ 3., 4., 5., 6., 7.],
[13., 14., 15., 16., 17.],
[23., 24., 25., 26., 27.],
[33., 34., 35., 36., 37.],
[43., 44., 45., 46., 47.],
[53., 54., 55., 56., 57.],
[63., 64., 65., 66., 67.],
[73., 74., 75., 76., 77.],
[83., 84., 85., 86., 87.],
[93., 94., 95., 96., 97.]], dtype=torch.float64)
# Per-dimension size of one model output patch: the 10 x 5 input patch is
# expected to come back as 20 x 5 (x doubled, y unchanged).
output_tensor_dim = {'x': 20, 'y': 5}
# Both output dimensions are declared as resampled versions of input dimensions;
# nothing is kept as a core dimension and no new dimension is added.
resample_dim = ['x', 'y']
core_dim = []
new_dim = []
# Sanity check: shape of a single input patch served by the dataset.
ds[0].shape
torch.Size([10, 5])
# Applied to an un-batched 2-D patch, axis 1 is y, so here it is y that doubles
# (10 x 5 -> 10 x 10). NOTE(review): predict_on_array presumably feeds the model
# batched tensors, where axis 1 would be x as output_tensor_dim expects - confirm.
model(ds[0]).shape
torch.Size([10, 10])
# Re-import the local `functions` module so that edits made to it on disk are
# picked up without restarting the kernel.
import functions
from importlib import reload
reload(functions)
# Run the model over every patch of `ds` (4 patches per batch) and reassemble the
# outputs according to the dimension spec defined above.
result = functions.predict_on_array(
ds,
model,
output_tensor_dim=output_tensor_dim,
new_dim=new_dim,
core_dim=core_dim,
resample_dim=resample_dim,
batch_size=4
)
0%| | 0/1 [00:00<?, ?it/s]
100%|██████████| 1/1 [00:00<00:00, 179.82it/s]
%%writefile -a test_predict_on_array.py
@pytest.mark.parametrize(
    "factor, mode, expected",
    [
        (2.0, "edges", np.arange(0, 10, 0.5)),
        (0.5, "edges", np.arange(0, 10, 2.0)),
    ],
)
def test_resample_coordinate(factor, mode, expected):
    """Resampling the coordinate 0..9 by ``factor`` yields the expected values.

    Factor 2.0 is expected to give a step of 0.5 and factor 0.5 a step of 2.0,
    both starting at 0 and staying below 10.
    """
    source_coord = xr.DataArray(np.arange(10, dtype=float), dims="x")
    np.testing.assert_allclose(
        _resample_coordinate(source_coord, factor, mode),
        expected,
    )
Appending to test_predict_on_array.py
%%writefile -a test_predict_on_array.py
@pytest.mark.parametrize(
    "model, output_tensor_dim, new_dim, core_dim, resample_dim, manual_transform",
    [
        # Case 1: Identity - No change
        (
            Identity(),
            {'x': 10, 'y': 5},
            [], [], ['x', 'y'],
            lambda da: da.data
        ),
        # Case 2: ExpandAlongAxis - Upsampling
        (
            ExpandAlongAxis(ax=1, n_repeats=2),  # ax=1 is 'x'
            {'x': 20, 'y': 5},
            [], [], ['x', 'y'],
            lambda da: da.data.repeat(2, axis=0)  # axis=0 in the 2D numpy array
        ),
        # Case 3: SubsetAlongAxis - Coarsening
        (
            SubsetAlongAxis(ax=1, n=5),  # ax=1 is 'x'
            {'x': 5, 'y': 5},
            [], [], ['x', 'y'],
            lambda da: da.isel(x=slice(0, 5)).data
        ),
        # Case 4: MeanAlongDim - Dimension reduction
        (
            MeanAlongDim(ax=2),  # ax=2 is 'y'
            {'x': 10},
            [], [], ['x'],
            lambda da: da.mean(dim='y').data
        ),
        # Case 5: AddAxis - Add a new dimension
        (
            AddAxis(ax=1),  # Add new dim at axis 1
            {'channel': 1, 'x': 10, 'y': 5},
            ['channel'], [], ['x', 'y'],
            lambda da: np.expand_dims(da.data, axis=0)
        ),
    ]
)
def test_predict_on_array_all_models(
    map_dataset_fixture, model, output_tensor_dim, new_dim, core_dim, resample_dim, manual_transform
):
    """
    Tests reassembly, averaging, and coordinate assignment using a variety of models.

    For each parametrized model, ``manual_transform`` is the NumPy/xarray
    equivalent of what the model does to a single patch. The test rebuilds the
    expected output by hand: every transformed patch is added into a sum array
    at its (rescaled) location, a count array records how many patches touched
    each cell, and their ratio is compared against ``predict_on_array``.
    Cells that no patch covers are 0 / 0 = NaN in the reference, so the final
    comparison uses ``equal_nan=True``.
    """
    dataset = map_dataset_fixture
    bgen = dataset.X_generator
    resample_factor = _get_resample_factor(bgen, output_tensor_dim, resample_dim)

    # --- Run the function under test ---
    result_da = predict_on_array(
        dataset=dataset, model=model, output_tensor_dim=output_tensor_dim,
        new_dim=new_dim, core_dim=core_dim, resample_dim=resample_dim, batch_size=4
    )

    # --- Manually calculate the expected result ---
    expected_size = _get_output_array_size(bgen, output_tensor_dim, new_dim, core_dim, resample_dim)
    expected_sum = xr.DataArray(np.zeros(list(expected_size.values())), dims=list(expected_size.keys()))
    expected_count = xr.full_like(expected_sum, 0, dtype=int)

    for i in range(len(dataset)):
        batch_da = bgen[i]
        # Slices locating patch i in the source array (private xbatcher attribute).
        old_indexer = bgen._batch_selectors.selectors[i][0]

        # Translate the source-array slices into output-array slices.
        new_indexer = {}
        for key in old_indexer:
            window = old_indexer[key]
            if key in resample_dim:
                # Resampled dimensions: scale both slice bounds by the factor.
                factor = resample_factor.get(key, 1)
                new_indexer[key] = slice(int(window.start * factor), int(window.stop * factor))
            elif key in core_dim:
                # Core dimensions keep their original extent.
                new_indexer[key] = window
            # Any other dimension is absent from the output, so it gets no indexer.

        expected_sum.loc[new_indexer] += manual_transform(batch_da)
        expected_count.loc[new_indexer] += 1

    # Uncovered cells are 0 / 0. That NaN is intended, so silence only the
    # "invalid value" warning; a genuine divide-by-zero would still be reported.
    with np.errstate(invalid="ignore"):
        expected_avg_data = expected_sum.data / expected_count.data

    # --- Assert correctness ---
    np.testing.assert_allclose(result_da.values, expected_avg_data, equal_nan=True)
Appending to test_predict_on_array.py
!pytest -v test_predict_on_array.py
============================= test session starts ==============================
platform linux -- Python 3.13.5, pytest-8.4.1, pluggy-1.6.0 -- /home/runner/micromamba/envs/cookbook-dev/bin/python3.13
cachedir: .pytest_cache
rootdir: /home/runner/work/xbatcher-deep-learning/xbatcher-deep-learning/notebooks
plugins: anyio-4.10.0
collecting ...
collected 7 items
test_predict_on_array.py::test_resample_coordinate[2.0-edges-expected0] PASSED [ 14%]
test_predict_on_array.py::test_resample_coordinate[0.5-edges-expected1] PASSED [ 28%]
test_predict_on_array.py::test_predict_on_array_all_models[model0-output_tensor_dim0-new_dim0-core_dim0-resample_dim0-<lambda>] PASSED [ 42%]
test_predict_on_array.py::test_predict_on_array_all_models[model1-output_tensor_dim1-new_dim1-core_dim1-resample_dim1-<lambda>]
PASSED [ 57%]
test_predict_on_array.py::test_predict_on_array_all_models[model2-output_tensor_dim2-new_dim2-core_dim2-resample_dim2-<lambda>] PASSED [ 71%]
test_predict_on_array.py::test_predict_on_array_all_models[model3-output_tensor_dim3-new_dim3-core_dim3-resample_dim3-<lambda>] PASSED [ 85%]
test_predict_on_array.py::test_predict_on_array_all_models[model4-output_tensor_dim4-new_dim4-core_dim4-resample_dim4-<lambda>] PASSED [100%]
=============================== warnings summary ===============================
test_predict_on_array.py::test_predict_on_array_all_models[model0-output_tensor_dim0-new_dim0-core_dim0-resample_dim0-<lambda>]
test_predict_on_array.py::test_predict_on_array_all_models[model1-output_tensor_dim1-new_dim1-core_dim1-resample_dim1-<lambda>]
test_predict_on_array.py::test_predict_on_array_all_models[model2-output_tensor_dim2-new_dim2-core_dim2-resample_dim2-<lambda>]
test_predict_on_array.py::test_predict_on_array_all_models[model3-output_tensor_dim3-new_dim3-core_dim3-resample_dim3-<lambda>]
test_predict_on_array.py::test_predict_on_array_all_models[model4-output_tensor_dim4-new_dim4-core_dim4-resample_dim4-<lambda>]
/home/runner/work/xbatcher-deep-learning/xbatcher-deep-learning/notebooks/test_predict_on_array.py:108: RuntimeWarning: invalid value encountered in divide
expected_avg_data = expected_sum.data / expected_count.data
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
======================== 7 passed, 5 warnings in 2.28s =========================