This notebook shows how to reconstruct an xarray.DataArray from the output of a machine learning model. While the previous notebook focused on generating batches from xarray objects, this guide covers the reverse process: assembling model outputs back into a coherent, labeled xarray object. This is a common requirement in scientific machine learning workflows, where model output needs to be analyzed in its original spatial or temporal context.
We will examine a function that reassembles model outputs, including a detailed look at how an internal xbatcher API can be used to map batch outputs back to their original positions.
Imports
```python
import xarray as xr
import numpy as np
import torch
import xbatcher
from xbatcher.loaders.torch import MapDataset
from typing import Literal
from dummy_models import ExpandAlongAxis
```
Setup: Data, Batches, and a Dummy Model
We will begin by creating a sample xarray.DataArray and a BatchGenerator. We will also instantiate a dummy model that transforms the data, simulating a common machine learning scenario in which the output dimensions differ from the input dimensions (e.g., super-resolution).
```python
da = xr.DataArray(
    data=np.random.rand(50, 40).astype(np.float32),
    dims=("x", "y"),
    coords={"x": np.arange(50), "y": np.arange(40)},
)
da
```
Next, we create the BatchGenerator, which yields non-overlapping 10-by-10 patches of the array.
```python
bgen = xbatcher.BatchGenerator(da, input_dims={"x": 10, "y": 10})
```
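With input_dims={"x": 10, "y": 10} and no input overlap, the 50 x 40 array is cut into 5 x 4 = 20 patches. The sketch below enumerates the positional slices of such patches in plain Python to make that mapping concrete. It is an illustration, not xbatcher code: tile_slices is a hypothetical helper, and the row-major order it produces is an assumption that may not match the order in which xbatcher itself visits windows.

```python
from itertools import product


def tile_slices(sizes: dict[str, int], window: dict[str, int]) -> list[dict[str, slice]]:
    """Hypothetical helper: positional slices of non-overlapping windows.

    Windows that would run past the end of a dimension are dropped.
    """
    # Start index of every full window along each dimension
    starts = {
        dim: range(0, sizes[dim] - window[dim] + 1, window[dim]) for dim in window
    }
    # One selector dict per combination of window starts
    return [
        {dim: slice(s, s + window[dim]) for dim, s in zip(window, combo)}
        for combo in product(*starts.values())
    ]


selectors = tile_slices({"x": 50, "y": 40}, {"x": 10, "y": 10})
print(len(selectors))  # 20
print(selectors[0])    # {'x': slice(0, 10, None), 'y': slice(0, 10, None)}
```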
For the model, we will use ExpandAlongAxis from dummy_models.py. This model upsamples the input along a specified axis, changing the size of that dimension in the output.
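dummy_models.py is not reproduced in this notebook, so the exact implementation of ExpandAlongAxis is not shown. Assuming it repeats each element n_repeats times along axis ax, with axis 0 being the batch axis, its effect on the shape of a batch can be sketched with np.repeat. The function expand_along_axis below is a hypothetical stand-in written in NumPy rather than PyTorch.

```python
import numpy as np


def expand_along_axis(batch: np.ndarray, ax: int, n_repeats: int) -> np.ndarray:
    """Hypothetical stand-in for ExpandAlongAxis: repeat every element
    n_repeats times along axis `ax` (axis 0 is the batch axis)."""
    return np.repeat(batch, n_repeats, axis=ax)


# A batch of 4 patches, each 10 (x) by 10 (y)
batch = np.random.rand(4, 10, 10).astype(np.float32)
out = expand_along_axis(batch, ax=1, n_repeats=2)
print(out.shape)  # (4, 20, 10)
```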
```python
# ax=1 is the 'x' dimension (axis 0 of each model input is the batch dimension);
# the model expands it by a factor of 2
model = ExpandAlongAxis(ax=1, n_repeats=2)
```
The predict_on_array Function
The predict_on_array function (from functions.py) takes batches from a BatchGenerator (wrapped in a MapDataset), passes them through a model, and reassembles the outputs. The following sections break down this function and its helpers.
```python
def _get_resample_factor(
    bgen: xbatcher.BatchGenerator,
    output_tensor_dim: dict[str, int],
    resample_dim: list[str],
) -> dict[str, float]:
    resample_factor = {}
    for dim in resample_dim:
        # Ratio of the model's output size to the input batch size
        r = output_tensor_dim[dim] / bgen.input_dims[dim]
        # Allow integer upsampling (2, 3, ...) or integer downsampling (1/2, 1/3, ...)
        is_int = r == int(r)
        is_inv_int = (1 / r == int(1 / r)) if r != 0 else False
        assert is_int or is_inv_int, f"Resample ratio for dim '{dim}' must be an integer or its inverse."
        resample_factor[dim] = r
    return resample_factor
```
_get_resample_factor
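This helper function calculates the resampling factor for each dimension in resample_dim: the model's output size divided by the input batch size. For example, if input batches have x=10 and the model outputs tensors with x=20, the resampling factor for x is 2. The factor must be an integer (upsampling) or the inverse of an integer (downsampling). It is used to determine the dimensions of the final reconstructed array.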
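The ratio check can be illustrated without a BatchGenerator by passing plain dictionaries of input and output sizes. resample_ratios below is a hypothetical standalone variant of _get_resample_factor, not part of xbatcher.

```python
def resample_ratios(
    input_dims: dict[str, int],
    output_dims: dict[str, int],
    resample_dim: list[str],
) -> dict[str, float]:
    """Hypothetical standalone version of the ratio check in _get_resample_factor."""
    ratios = {}
    for dim in resample_dim:
        n_in, n_out = input_dims[dim], output_dims[dim]
        # Integer upsampling (n_out a multiple of n_in) or integer
        # downsampling (n_in a multiple of n_out) are the only allowed cases
        if n_out % n_in != 0 and n_in % n_out != 0:
            raise ValueError(f"Resample ratio for dim '{dim}' must be an integer or its inverse.")
        ratios[dim] = n_out / n_in
    return ratios


print(resample_ratios({"x": 10, "y": 10}, {"x": 20, "y": 10}, ["x", "y"]))  # {'x': 2.0, 'y': 1.0}
print(resample_ratios({"x": 10}, {"x": 5}, ["x"]))  # {'x': 0.5}
# resample_ratios({"x": 10}, {"x": 15}, ["x"]) raises ValueError (ratio 1.5)
```

Checking divisibility on the integer sizes, rather than comparing floats, keeps the test exact; the resulting ratio is still returned as a float, as in the original helper.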
```python
def _get_output_array_size(
    bgen: xbatcher.BatchGenerator,
    output_tensor_dim: dict[str, int],
    new_dim: list[str],
    core_dim: list[str],
    resample_dim: list[str],
) -> dict[str, int]:
    resample_factor = _get_resample_factor(bgen, output_tensor_dim, resample_dim)
    output_size = {}
    for key, size in output_tensor_dim.items():
        if key in new_dim:
            # New dimension: take its size directly from the model output
            output_size[key] = size
        elif key in core_dim:
            # Core dimension: not batched over, so the tensor must span the full source dimension
            if size != bgen.ds.sizes[key]:
                raise ValueError(
                    f"Axis {key} is a core dim, but the tensor size "
                    f"({size}) does not equal the "
                    f"source data array size ({bgen.ds.sizes[key]})."
                )
            output_size[key] = bgen.ds.sizes[key]
        elif key in resample_dim:
            # Resampled dimension: scale the full source size by the resampling factor
            temp_output_size = bgen.ds.sizes[key] * resample_factor[key]
            assert temp_output_size.is_integer(), f"Resampling for dim '{key}' results in non-integer size."
            output_size[key] = int(temp_output_size)
        else:
            raise ValueError(f"Axis {key} must be specified in one of new_dim, core_dim, or resample_dim")
    return output_size
```
_get_output_array_size
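This function determines the final size of the reconstructed array. Every dimension of the model output must be listed in exactly one of three groups: new_dim (dimensions that are new in the output and keep the tensor's size), core_dim (dimensions that are not batched over and remain unchanged, so the tensor must span the full source dimension), or resample_dim (dimensions whose full source size is scaled by the resampling factor).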
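For the example in this notebook the arithmetic is simple: the source array is 50 x 40 and the factors are {"x": 2.0, "y": 1.0}, giving a 100 x 40 output. The sketch below is a hypothetical pure-Python counterpart of _get_output_array_size that takes the source sizes and factors as plain dictionaries; the "channel" dimension is invented here only to show how a new_dim keeps the tensor's own size.

```python
def output_sizes(
    source_sizes: dict[str, int],
    output_tensor_dim: dict[str, int],
    factors: dict[str, float],
    new_dim: list[str],
    core_dim: list[str],
) -> dict[str, int]:
    """Hypothetical standalone version of _get_output_array_size."""
    sizes = {}
    for dim, n in output_tensor_dim.items():
        if dim in new_dim:
            sizes[dim] = n  # new dims keep the tensor's size
        elif dim in core_dim:
            if n != source_sizes[dim]:
                raise ValueError(f"core dim {dim!r} must span the full source size")
            sizes[dim] = n
        elif dim in factors:
            full = source_sizes[dim] * factors[dim]  # scale the whole source dimension
            if not float(full).is_integer():
                raise ValueError(f"resampling {dim!r} gives a non-integer size")
            sizes[dim] = int(full)
        else:
            raise ValueError(f"dim {dim!r} is not assigned to any group")
    return sizes


sizes = output_sizes(
    {"x": 50, "y": 40},
    {"channel": 3, "x": 20, "y": 10},
    {"x": 2.0, "y": 1.0},
    new_dim=["channel"],
    core_dim=[],
)
print(sizes)  # {'channel': 3, 'x': 100, 'y': 40}
```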
```python
def _resample_coordinate(
    coord: xr.DataArray,
    factor: float,
    mode: Literal["centers", "edges"] = "edges",
) -> np.ndarray:
    assert len(coord.shape) == 1 and coord.shape[0] > 1
    assert (coord.shape[0] * factor).is_integer()
    old_step = coord.data[1] - coord.data[0]
    new_step = old_step / factor
    # In "centers" mode, shift the coordinates to the left cell edges first
    old_offset = 0 if mode == "edges" else old_step / 2
    new_offset = 0 if mode == "edges" else new_step / 2
    coord = coord - old_offset
    new_coord_end = coord.max().item() + old_step
    # Build the new edges, then shift by half of the *new* step to get the new centers
    return np.arange(coord.min().item(), new_coord_end, step=new_step) + new_offset
```
_resample_coordinate
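If the size of a dimension changes, its coordinates must also be updated. This function resamples an evenly spaced 1-D coordinate by the given factor. In "edges" mode the coordinate values are treated as the left edges of cells; in "centers" mode they are treated as cell midpoints.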
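As a concrete check of what resampled coordinates should look like under each convention, the NumPy sketch below builds the new left edges from the first edge and the new step, then, in "centers" mode, places each new midpoint half of the new step past its edge. resample_coord is a hypothetical standalone helper; it builds the array from an integer count instead of np.arange with a float step so that the output length is exact.

```python
import numpy as np


def resample_coord(coord: np.ndarray, factor: float, mode: str = "edges") -> np.ndarray:
    """Hypothetical helper: resample an evenly spaced 1-D coordinate by `factor`."""
    old_step = coord[1] - coord[0]
    new_step = old_step / factor
    n_new = len(coord) * factor
    if not float(n_new).is_integer():
        raise ValueError("len(coord) * factor must be an integer")
    # Left edge of the first cell
    first_edge = coord[0] if mode == "edges" else coord[0] - old_step / 2
    new_edges = first_edge + np.arange(int(n_new)) * new_step
    # Midpoints sit half a *new* step past each new left edge
    return new_edges if mode == "edges" else new_edges + new_step / 2


x = np.arange(4.0)  # [0, 1, 2, 3]
print(resample_coord(x, 2).tolist())               # [0.0, 0.5, 1.0, 1.5, 2.0, 2.5, 3.0, 3.5]
print(resample_coord(x + 0.5, 2, "centers").tolist())  # [0.25, 0.75, 1.25, 1.75, 2.25, 2.75, 3.25, 3.75]
print(resample_coord(x, 0.5).tolist())             # [0.0, 2.0]
```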
```python
def _get_output_array_coordinates(
    src_da: xr.DataArray,
    output_array_dim: list[str],
    resample_factor: dict[str, float],
    resample_mode: Literal["centers", "edges"] = "edges",
) -> dict[str, np.ndarray]:
    output_coords = {}
    for dim in output_array_dim:
        if dim in src_da.coords and dim in resample_factor:
            # Resampled dimension: generate a finer or coarser coordinate
            output_coords[dim] = _resample_coordinate(src_da[dim], resample_factor[dim], resample_mode)
        elif dim in src_da.coords:
            # Unchanged dimension: copy the source coordinate
            output_coords[dim] = src_da[dim].copy(deep=True).data
        # Dimensions without a source coordinate (e.g. new dims) are left unlabeled
    return output_coords
```
_get_output_array_coordinates
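This function builds a dictionary of coordinates for the output array: resampled dimensions get new coordinates from _resample_coordinate, other dimensions that have a coordinate in the source array keep a copy of it, and dimensions without a source coordinate are skipped.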
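The three branches can be illustrated with plain NumPy arrays standing in for xarray coordinates. The sketch below is hypothetical: output_coords only handles the "edges" convention with evenly spaced coordinates, and the "t" and "channel" dimensions are invented to exercise the copy and skip branches.

```python
import numpy as np


def output_coords(
    src_coords: dict[str, np.ndarray],
    output_dims: list[str],
    factors: dict[str, float],
) -> dict[str, np.ndarray]:
    """Hypothetical sketch of the branching in _get_output_array_coordinates."""
    coords = {}
    for dim in output_dims:
        if dim not in src_coords:
            continue  # no source coordinate: leave the dimension unlabeled
        c = src_coords[dim]
        if dim in factors:
            step = (c[1] - c[0]) / factors[dim]
            n_new = int(round(len(c) * factors[dim]))
            coords[dim] = c[0] + np.arange(n_new) * step  # "edges" convention
        else:
            coords[dim] = c.copy()  # unchanged dimension keeps a copy
    return coords


src = {"x": np.arange(50), "y": np.arange(40), "t": np.arange(3)}
coords = output_coords(src, ["channel", "x", "t"], {"x": 2.0})
print(sorted(coords))                              # ['t', 'x']
print(len(coords["x"]), coords["x"][:3].tolist())  # 100 [0.0, 0.5, 1.0]
```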
```python
def predict_on_array(
    dataset: MapDataset,
    model: torch.nn.Module,
    output_tensor_dim: dict[str, int],
    new_dim: list[str],
    core_dim: list[str],
    resample_dim: list[str],
    resample_mode: Literal["centers", "edges"] = "edges",
    batch_size: int = 16,
) -> xr.DataArray:
    # Every output dimension must belong to exactly one group
    s_new = set(new_dim)
    s_core = set(core_dim)
    s_resample = set(resample_dim)
    if s_new & s_core or s_new & s_resample or s_core & s_resample:
        raise ValueError("new_dim, core_dim, and resample_dim must be disjoint sets.")

    bgen = dataset.X_generator
    resample_factor = _get_resample_factor(bgen, output_tensor_dim, resample_dim)
    output_size = _get_output_array_size(bgen, output_tensor_dim, new_dim, core_dim, resample_dim)

    # Accumulator for summed predictions and a counter for averaging overlaps
    output_da = xr.DataArray(
        data=np.zeros(tuple(output_size.values())),
        dims=tuple(output_size.keys()),
    )
    output_n = xr.full_like(output_da, 0)

    # shuffle must stay False (the default) so batch order matches the selector order
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)
    for i, batch in enumerate(loader):
        input_tensor = batch[0] if isinstance(batch, (list, tuple)) else batch
        out_batch = model(input_tensor).detach().cpu().numpy()
        for ib in range(out_batch.shape[0]):
            global_index = (i * batch_size) + ib
            # Internal xbatcher API: positional slices of this batch in the source array
            old_indexer = bgen._batch_selectors.selectors[global_index][0]
            new_indexer = {}
            for key in old_indexer:
                if key in resample_dim:
                    # Scale input-space slices into output space
                    new_indexer[key] = slice(
                        int(old_indexer[key].start * resample_factor[key]),
                        int(old_indexer[key].stop * resample_factor[key]),
                    )
            # output_da has no coordinates yet, so .loc falls back to positional selection
            output_da.loc[new_indexer] += out_batch[ib, ...]
            output_n.loc[new_indexer] += 1

    # Average overlapping predictions, then attach the new coordinates
    output_da = output_da / output_n
    output_da = output_da.assign_coords(
        _get_output_array_coordinates(
            bgen.ds,
            list(output_tensor_dim.keys()),
            resample_factor,
            resample_mode,
        )
    )
    return output_da
```
predict_on_array Internals
The key steps of this function are as follows:
- Initialization: An empty DataArray (output_da) is created with the final dimensions, along with a corresponding DataArray (output_n) that counts how many predictions land on each element, so that overlapping predictions can be averaged.
- Iteration: The function iterates through the DataLoader. Because the loader does not shuffle, the running sample index global_index = i * batch_size + ib equals the index of the corresponding batch in the BatchGenerator.
- The internal API: The core of the reconstruction is bgen._batch_selectors.selectors[global_index]. This internal attribute of the BatchGenerator stores the slice objects for each batch, mapping each batch back to positions in the original DataArray.
- Disclaimer: Accessing internal attributes such as _batch_selectors is not part of the public API and may change in future versions of xbatcher.
- Rescaling and placing: The resampling factor is used to scale the slices, and .loc is used to add the model's output at the correct location in output_da. Because output_da has no coordinates at this stage, .loc selects by position.
- Averaging and coordinates: Finally, output_da is divided by output_n to average any overlapping predictions, and the new coordinates are assigned. Elements not covered by any batch end up as NaN (0 / 0).
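The scatter-and-average pattern behind output_da and output_n can be shown in one dimension with NumPy. In the sketch below, two overlapping windows of length 4 from a length-6 input are each upsampled by a factor of 2 (by a stand-in "model" that simply repeats values), their slices are rescaled by the same factor, and each output element is divided by the number of windows that covered it. The window positions and the repeat-based model are illustrative assumptions; the windows in this notebook's example do not overlap.

```python
import numpy as np

factor = 2
source = np.arange(6, dtype=float)    # [0, 1, 2, 3, 4, 5]
windows = [slice(0, 4), slice(2, 6)]  # overlapping positional windows

total = np.zeros(len(source) * factor)  # plays the role of output_da
count = np.zeros_like(total)            # plays the role of output_n

for w in windows:
    pred = np.repeat(source[w], factor)  # stand-in model: 2x upsampling
    # Rescale the input-space slice into output space
    out = slice(w.start * factor, w.stop * factor)
    total[out] += pred
    count[out] += 1

result = total / count
print(count.tolist())   # [1.0, 1.0, 1.0, 1.0, 2.0, 2.0, 2.0, 2.0, 1.0, 1.0, 1.0, 1.0]
print(result.tolist())  # [0.0, 0.0, 1.0, 1.0, 2.0, 2.0, 3.0, 3.0, 4.0, 4.0, 5.0, 5.0]
```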
Reconstructing the Dataset
We will now use the predict_on_array function to reconstruct the dataset.
```python
map_dataset = MapDataset(bgen)

reconstructed_da = predict_on_array(
    dataset=map_dataset,
    model=model,
    output_tensor_dim={"x": 20, "y": 10},  # The model doubles the x dimension
    new_dim=[],
    core_dim=[],
    resample_dim=["x", "y"],  # "y" is resampled with a factor of 1
    batch_size=4,
)
reconstructed_da
```
The reconstructed DataArray has the upsampled x dimension. We can compare its shape to the original.
```python
print(f"Original shape: {da.shape}")
print(f"Reconstructed shape: {reconstructed_da.shape}")
```
Original shape: (50, 40)
Reconstructed shape: (100, 40)
The reconstructed array has twice as many elements in the x dimension, as expected. This concludes the demonstration of reconstructing an xarray.DataArray from model outputs using xbatcher.