Overview
This notebook addresses the process of reconstructing an `xarray.DataArray` from the output of a machine learning model. While the previous notebook focused on generating batches from xarray objects, this guide details the reverse process: assembling model outputs back into a coherent, labeled xarray object. This is a common requirement in scientific machine learning workflows, where model output needs to be analyzed in its original spatial or temporal context.
We will examine a function that reassembles model outputs, including a detailed look at how an internal API of xbatcher can be used to map batch outputs back to their original coordinates.
Prerequisites
| Concepts | Importance | Notes |
| --- | --- | --- |
| Intro to Xarray | Necessary | Array indexing |
| Loading Batches from Xarray | Helpful | PyTorch DataLoader API |
| PyTorch fundamentals | Helpful | Model training loop |
Imports
import xarray as xr
import numpy as np
import xbatcher
from xbatcher.loaders.torch import MapDataset
from dummy_models import ExpandAlongAxis
from functions import predict_on_array
Setup: Data, Batches, and a Dummy Model
We will begin by creating a sample `xarray.DataArray` and a `BatchGenerator`. We will also instantiate a dummy model that transforms the data, simulating a common machine learning scenario where the output dimensions differ from the input dimensions (e.g., super-resolution).
da = xr.DataArray(
data=np.random.rand(50, 40).astype(np.float32),
dims=("x", "y"),
coords={"x": np.arange(50), "y": np.arange(40)},
)
da
Next, we create the `BatchGenerator`.
bgen = xbatcher.BatchGenerator(da, input_dims={"x": 10, "y": 10})
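With `input_dims={"x": 10, "y": 10}` and no overlap, the generator tiles the 50×40 array into 5 × 4 = 20 non-overlapping patches. The pure-NumPy sketch below illustrates the equivalent tiling; it is an illustration of the concept, not xbatcher's actual implementation:

```python
import numpy as np

data = np.random.rand(50, 40).astype(np.float32)
size_x, size_y = 10, 10

# Slice the array into non-overlapping (10, 10) patches, stepping through
# x in the outer loop and y in the inner loop.
patches = [
    data[i : i + size_x, j : j + size_y]
    for i in range(0, data.shape[0], size_x)
    for j in range(0, data.shape[1], size_y)
]

print(len(patches))      # 20 patches
print(patches[0].shape)  # (10, 10)
```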
For the model, we will use `ExpandAlongAxis` from `dummy_models.py`. This model upsamples the input along a specified axis, changing the dimensions of the data.
# The model will expand the 'x' dimension by a factor of 2
# (axis 1 of the batched (batch, x, y) tensor corresponds to 'x')
model = ExpandAlongAxis(ax=1, n_repeats=2)
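The implementation of `ExpandAlongAxis` in `dummy_models.py` is not shown here. Assuming it simply repeats each element along the given axis, its behavior can be sketched with NumPy (the real class is a PyTorch model, but the array transformation is the same idea):

```python
import numpy as np


def expand_along_axis(arr: np.ndarray, ax: int, n_repeats: int) -> np.ndarray:
    """Upsample by repeating each element n_repeats times along axis ax."""
    return np.repeat(arr, n_repeats, axis=ax)


batch = np.random.rand(4, 10, 10)  # (batch, x, y)
out = expand_along_axis(batch, ax=1, n_repeats=2)
print(out.shape)  # (4, 20, 10)
```

With `ax=1` and a batch dimension in front, each (10, 10) patch becomes (20, 10), matching the `output_tensor_dim` used later in this notebook.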
Reconstructing the Dataset
We will now use the `predict_on_array` function to reconstruct the dataset. The most important part of using this function is correctly specifying the arguments `new_dim`, `core_dim`, and `resample_dim`. These lists partition the dimensions given in `output_tensor_dim` and describe how the model output relates to the input data. In general:

- `new_dim`: Tensor dimensions that do not appear at all in the original xarray object.
- `core_dim`: Tensor dimensions that are present in the original xarray object but are not used for batch generation. We assume that all elements of this dimension in the xarray object flow through the model to the output, so coordinates are simply copied from the original xarray object.
- `resample_dim`: Tensor dimensions that are present in the original xarray object and used for batch generation. These dimensions are allowed to change size (but see the note below), and coordinates are resampled accordingly in the reconstructed array.
Let’s apply these rules to our present example. The batch generator creates tensors of size `(x=10, y=10)` and the dummy model produces tensors of size `(x=20, y=10)`. In this case, all tensor dimensions are present in the original data array and are used for batch generation. Therefore, both `x` and `y` go in `resample_dim`. Now that all tensor dimensions are accounted for, we can simply leave `new_dim` and `core_dim` as empty lists.
map_dataset = MapDataset(bgen)
reconstructed_da = predict_on_array(
dataset=map_dataset,
model=model,
output_tensor_dim={"x": 20, "y": 10}, # The model doubles the x-dimension
new_dim=[],
core_dim=[],
resample_dim=["x", "y"],
batch_size=4,
progress_bar=False
)
reconstructed_da
The reconstructed `DataArray` has the upsampled `x` dimension. We can compare its shape to the original.
print(f"Original shape: {da.shape}")
print(f"Reconstructed shape: {reconstructed_da.shape}")
Original shape: (50, 40)
Reconstructed shape: (100, 40)
The reconstructed array has twice the number of elements in the `x` dimension, as expected.
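Why does stitching expanded patches back together give a sensible result? Assuming the dummy model performs element-wise repetition (as in `np.repeat`), expanding each patch and then stitching is identical to expanding the whole array at once. The self-contained NumPy sketch below demonstrates this equivalence on data of the same shape as `da`:

```python
import numpy as np

data = np.random.rand(50, 40).astype(np.float32)
sx, sy, n_repeats = 10, 10, 2

# Expand each (10, 10) patch along x, then stitch the patches back together
# in their original row/column order.
rows = []
for i in range(0, data.shape[0], sx):
    row = [
        np.repeat(data[i : i + sx, j : j + sy], n_repeats, axis=0)
        for j in range(0, data.shape[1], sy)
    ]
    rows.append(np.concatenate(row, axis=1))
reconstructed = np.concatenate(rows, axis=0)

# Because the repetition is element-wise, patch-wise expansion followed by
# stitching matches expanding the whole array in one step.
assert reconstructed.shape == (100, 40)
assert np.array_equal(reconstructed, np.repeat(data, n_repeats, axis=0))
```

This is the property that makes the `resample_dim` reconstruction well-defined here: each output patch occupies a predictable, non-overlapping region of the resampled coordinate grid.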