
Reconstructing Xarray Datasets from Model Outputs

This notebook shows how to reconstruct an xarray.DataArray from the output of a machine learning model. The previous notebook generated batches from xarray objects; this one covers the reverse step: assembling model outputs back into a coherent, labeled xarray object. Scientific machine learning workflows often need this, because model output is usually analyzed in its original spatial or temporal context.

We will examine a function that reassembles model outputs, including a detailed look at how an internal API of xbatcher can be used to map batch outputs back to their original coordinates.

Imports

import xarray as xr
import numpy as np
import torch
import xbatcher
from xbatcher.loaders.torch import MapDataset
from typing import Literal

from dummy_models import ExpandAlongAxis

Setup: Data, Batches, and a Dummy Model

We will begin by creating a sample xarray.DataArray and a BatchGenerator. We will also instantiate a dummy model that transforms the data, simulating a common machine learning scenario where the output dimensions differ from the input dimensions (e.g., super-resolution).

da = xr.DataArray(
    data=np.random.rand(50, 40).astype(np.float32),
    dims=("x", "y"),
    coords={"x": np.arange(50), "y": np.arange(40)},
)
da

Next, we create the BatchGenerator.

bgen = xbatcher.BatchGenerator(da, input_dims={"x": 10, "y": 10})
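To make the batching concrete, here is a minimal sketch of how a 50×40 array splits into non-overlapping 10×10 windows. It is plain Python and does not call xbatcher; the helper name tile_slices is hypothetical, and the order of the windows is an assumption that may not match the order xbatcher uses.

```python
from itertools import product


def tile_slices(sizes: dict[str, int], window: dict[str, int]) -> list[dict[str, slice]]:
    """Return one {dim: slice} mapping per non-overlapping window (hypothetical helper)."""
    per_dim = {
        dim: [
            slice(start, start + window[dim])
            for start in range(0, sizes[dim] - window[dim] + 1, window[dim])
        ]
        for dim in window
    }
    # Cartesian product of the per-dimension slices gives every window
    return [dict(zip(per_dim, combo)) for combo in product(*per_dim.values())]


patches = tile_slices({"x": 50, "y": 40}, {"x": 10, "y": 10})
print(len(patches))  # 20: five windows along x times four along y
print(patches[0])
```

Each mapping plays the same role as the per-batch slices that the reconstruction step later reads back to place model outputs.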

For the model, we will use ExpandAlongAxis from dummy_models.py. This model upsamples the input along a specified axis, changing the dimensions of the data.

# Batches have shape (batch, x, y), so ax=1 selects 'x', which is expanded by a factor of 2
model = ExpandAlongAxis(ax=1, n_repeats=2)
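The source of dummy_models.py is not shown here, so the following is only a hypothetical NumPy stand-in for the shape change. It assumes the model repeats every element along the chosen axis; the real ExpandAlongAxis is a torch module and may upsample differently.

```python
import numpy as np


def expand_along_axis(batch: np.ndarray, ax: int, n_repeats: int) -> np.ndarray:
    """Hypothetical stand-in: repeat every element n_repeats times along axis ax."""
    return np.repeat(batch, n_repeats, axis=ax)


batch = np.zeros((4, 10, 10), dtype=np.float32)  # (batch, x, y)
out = expand_along_axis(batch, ax=1, n_repeats=2)
print(out.shape)  # (4, 20, 10)
```

Axis 0 is the batch axis, which is why ax=1 corresponds to the x dimension of each 10×10 patch.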

The predict_on_array Function

The predict_on_array function (from functions.py) is designed to take batches from a BatchGenerator, pass them through a model, and reassemble the outputs. The following sections will break down this function and its helpers.

def _get_resample_factor(
    bgen: xbatcher.BatchGenerator,
    output_tensor_dim: dict[str, int],
    resample_dim: list[str]
):
    resample_factor = {}
    for dim in resample_dim:
        r = output_tensor_dim[dim] / bgen.input_dims[dim]
        is_int = (r == int(r))
        is_inv_int = (1/r == int(1/r)) if r != 0 else False
        assert is_int or is_inv_int, f"Resample ratio for dim '{dim}' must be an integer or its inverse."
        resample_factor[dim] = r

    return resample_factor

_get_resample_factor

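This helper function calculates the resampling factor for each dimension as the ratio of the output tensor size to the input batch size. For example, if input batches have x=10 and the model outputs tensors with x=20, the resampling factor for x is 2. The ratio must be an integer (upsampling) or the inverse of an integer (downsampling). The factors are later used to size the reconstructed array and to scale each batch's slices.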
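As a quick illustration of the check, the sketch below applies the same ratio test to plain integers instead of a BatchGenerator. The function name resample_ratio is hypothetical and exists only for this example.

```python
def resample_ratio(input_size: int, output_size: int) -> float:
    """Hypothetical standalone version of the per-dimension ratio check."""
    if input_size <= 0 or output_size <= 0:
        raise ValueError("sizes must be positive")
    r = output_size / input_size
    # Accept integer ratios (upsampling) and inverse-integer ratios (downsampling)
    if not (r.is_integer() or (1 / r).is_integer()):
        raise ValueError(f"ratio {r} is neither an integer nor the inverse of one")
    return r


print(resample_ratio(10, 20))  # 2.0 (upsampling)
print(resample_ratio(10, 5))   # 0.5 (downsampling)
```

A ratio such as 10 to 15 (1.5) is rejected, because patch boundaries would no longer land on whole output indices.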

def _get_output_array_size(
    bgen: xbatcher.BatchGenerator,
    output_tensor_dim: dict[str, int],
    new_dim: list[str],
    core_dim: list[str],
    resample_dim: list[str]
):
    resample_factor = _get_resample_factor(bgen, output_tensor_dim, resample_dim)
    output_size = {}
    for key, size in output_tensor_dim.items():
        if key in new_dim:
            output_size[key] = output_tensor_dim[key]
        elif key in core_dim:
            if output_tensor_dim[key] != bgen.ds.sizes[key]:
                raise ValueError(
                    f"Axis {key} is a core dim, but the tensor size "
                    f"({output_tensor_dim[key]}) does not equal the "
                    f"source data array size ({bgen.ds.sizes[key]})."
                )
            output_size[key] = bgen.ds.sizes[key]
        elif key in resample_dim:
            temp_output_size = bgen.ds.sizes[key] * resample_factor[key]
            assert temp_output_size.is_integer(), f"Resampling for dim '{key}' results in non-integer size."
            output_size[key] = int(temp_output_size)
        else:
            raise ValueError(f"Axis {key} must be specified in one of new_dim, core_dim, or resample_dim") 
    return output_size

_get_output_array_size

This function determines the final size of the reconstructed array. It uses the resampling factor and also considers new_dim (dimensions that are new in the output) and core_dim (dimensions that are not batched over and remain unchanged).
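The arithmetic can be traced by hand with plain dictionaries. The sketch below assumes a hypothetical model output that adds a new "channel" dimension of size 3. The sizes and factors are written out directly rather than read from a BatchGenerator.

```python
src_sizes = {"x": 50, "y": 40}   # sizes of the source array
factors = {"x": 2.0, "y": 1.0}   # per-dimension resample factors
# Hypothetical per-sample output shape: a new "channel" dim plus resampled x and y
output_tensor_dim = {"channel": 3, "x": 20, "y": 10}

output_size = {}
for dim, tensor_size in output_tensor_dim.items():
    if dim in factors:
        # Resampled dims scale the full source size, not the patch size
        scaled = src_sizes[dim] * factors[dim]
        assert scaled.is_integer()
        output_size[dim] = int(scaled)
    else:
        # New dims take their size directly from the model output
        output_size[dim] = tensor_size

print(output_size)  # {'channel': 3, 'x': 100, 'y': 40}
```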

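def _resample_coordinate(
    coord: xr.DataArray,
    factor: float,
    mode: Literal["centers", "edges"]="edges"
) -> np.ndarray:
    assert len(coord.shape) == 1 and coord.shape[0] > 1
    assert (coord.shape[0] * factor).is_integer()
    # Assumes a regularly spaced coordinate
    old_step = (coord.data[1] - coord.data[0])
    new_step = old_step / factor
    n_new = int(coord.shape[0] * factor)
    # In "centers" mode, shift to the leading cell edge before resampling
    # and back to the center of the new, rescaled cell afterwards
    start = coord.data[0] - (0 if mode == "edges" else old_step / 2)
    offset = 0 if mode == "edges" else new_step / 2
    # Building from an integer count avoids the off-by-one length that
    # np.arange can produce with a floating-point step
    return start + new_step * np.arange(n_new) + offset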

_resample_coordinate

If the size of a dimension is changed, its coordinates must also be updated. This function handles the resampling of coordinates.
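The sketch below illustrates one interpretation of the two modes on a small NumPy array. It is a simplified, hypothetical stand-in (resample_coord), not a transcript of the function above. It assumes that "edges" treats each coordinate value as the leading edge of its cell, while "centers" treats it as the cell midpoint.

```python
import numpy as np


def resample_coord(coord: np.ndarray, factor: float, mode: str = "edges") -> np.ndarray:
    """Hypothetical sketch: resample a regularly spaced 1-D coordinate."""
    old_step = coord[1] - coord[0]
    new_step = old_step / factor
    n_new = round(coord.size * factor)
    # Leading edge of the first cell
    start = coord[0] if mode == "edges" else coord[0] - old_step / 2
    edges = start + new_step * np.arange(n_new)
    # For centers, move from each new cell edge to its midpoint
    return edges if mode == "edges" else edges + new_step / 2


x = np.arange(4.0)  # 0, 1, 2, 3
print(resample_coord(x, 2, "edges"))    # 0, 0.5, 1, ..., 3.5
print(resample_coord(x, 2, "centers"))  # -0.25, 0.25, 0.75, ..., 3.25
```

Under this interpretation, the two modes cover the same extent but differ by half a new step, which matters when the output is later aligned with other gridded data.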

def _get_output_array_coordinates(
    src_da: xr.DataArray,
    output_array_dim: list[str],
    resample_factor: dict[str, float],
    resample_mode: Literal["centers", "edges"]="edges"
) -> dict[str, np.ndarray]:
    output_coords = {}
    for dim in output_array_dim:
        if dim in src_da.coords and dim in resample_factor:
            output_coords[dim] = _resample_coordinate(src_da[dim], resample_factor[dim], resample_mode)
        elif dim in src_da.coords:
            output_coords[dim] = src_da[dim].copy(deep=True).data
        else:
            continue
    return output_coords

_get_output_array_coordinates

This function generates a dictionary of the new coordinates for the output array.

def predict_on_array(
    dataset: MapDataset,
    model: torch.nn.Module,
    output_tensor_dim: dict[str, int],
    new_dim: list[str],
    core_dim: list[str],
    resample_dim: list[str],
    resample_mode: Literal["centers", "edges"]="edges",
    batch_size: int=16
) -> xr.DataArray:
    s_new = set(new_dim)
    s_core = set(core_dim)
    s_resample = set(resample_dim)

    if s_new & s_core or s_new & s_resample or s_core & s_resample:
        raise ValueError("new_dim, core_dim, and resample_dim must be disjoint sets.")

    bgen = dataset.X_generator

    resample_factor = _get_resample_factor(
        bgen,
        output_tensor_dim, 
        resample_dim
    )
    
    output_size = _get_output_array_size(
        bgen,
        output_tensor_dim,
        new_dim,
        core_dim,
        resample_dim
    )
            
    output_da = xr.DataArray(
        data=np.zeros(tuple(output_size.values())),
        dims=tuple(output_size.keys()),
    )
    output_n = xr.full_like(output_da, 0)
    
    loader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)

    for i, batch in enumerate(loader):
        input_tensor = batch[0] if isinstance(batch, (list, tuple)) else batch
        out_batch = model(input_tensor).detach().numpy()

        for ib in range(out_batch.shape[0]):
            global_index = (i * batch_size) + ib
            old_indexer = bgen._batch_selectors.selectors[global_index][0]
            new_indexer = {}
            for key in old_indexer:
                if key in resample_dim:
                    new_indexer[key] = slice(
                        int(old_indexer[key].start * resample_factor[key]),
                        int(old_indexer[key].stop * resample_factor[key])
                    )

            output_da.loc[new_indexer] += out_batch[ib, ...]
            output_n.loc[new_indexer] += 1

    output_da = output_da / output_n

    output_da = output_da.assign_coords(
        _get_output_array_coordinates(
            dataset.X_generator.ds, 
            list(output_tensor_dim.keys()), 
            resample_factor, 
            resample_mode
        )
    )

    return output_da

predict_on_array Internals

The key steps of this function are as follows:

  1. Initialization: A zero-filled DataArray (output_da) is created with the final dimensions, along with a matching DataArray (output_n) that counts how many predictions land on each element, so that overlapping predictions can be averaged.
  2. Iteration: The function iterates through the DataLoader. The position of each sample is computed as i * batch_size + ib, which assumes the loader yields samples in order (no shuffling).
  3. The Internal API: The core of the reconstruction is bgen._batch_selectors.selectors[global_index][0]. This internal attribute of the BatchGenerator stores the slice objects for each batch, providing a map from the batch back to the original DataArray's index space. Because _batch_selectors is not part of the public API, it may change in future versions of xbatcher.
  4. Rescaling and Placing: The start and stop of each slice are multiplied by the resampling factor, and .loc is used to add the model's output to the matching region of output_da. Since output_da has no coordinates at this point, .loc falls back to positional indexing.
  5. Averaging and Coordinates: Finally, the accumulated sums are divided by the counts, which averages any overlapping predictions, and the new coordinates are assigned.
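The accumulate-and-average pattern in these steps can be sketched in one dimension with NumPy alone. The windows below are hand-written stand-ins for the batch selectors, and the constant-valued patches stand in for model output; none of this calls xbatcher.

```python
import numpy as np

factor = 2    # upsampling factor along the single axis
src_len = 6   # length of the (imaginary) source array
total = np.zeros(src_len * factor)  # running sum of predictions
count = np.zeros(src_len * factor)  # number of predictions per element

# Two overlapping input windows and the constant value each fake patch holds
windows = [(slice(0, 4), 1.0), (slice(2, 6), 3.0)]

for win, value in windows:
    # Pretend model output: the window length scaled by the factor
    patch = np.full((win.stop - win.start) * factor, value)
    # Scale the input slice into output index space
    out_win = slice(win.start * factor, win.stop * factor)
    total[out_win] += patch
    count[out_win] += 1

mean = total / count
print(count)  # overlap region is counted twice
print(mean)   # overlap region holds the average of both patches
```

Elements covered by both windows end up with the mean of the two predictions (2.0 here), while elements covered once keep their single prediction. An element covered by no window would produce 0/0, so full coverage is assumed.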

Reconstructing the Dataset

We will now use the predict_on_array function to reconstruct the dataset.

map_dataset = MapDataset(bgen)
reconstructed_da = predict_on_array(
    dataset=map_dataset,
    model=model,
    output_tensor_dim={"x": 20, "y": 10}, # The model doubles the x-dimension
    new_dim=[],
    core_dim=[],
    resample_dim=["x", "y"],
    batch_size=4
)
reconstructed_da

The reconstructed DataArray has the upsampled x dimension. We can compare its shape to the original.

print(f"Original shape: {da.shape}")
print(f"Reconstructed shape: {reconstructed_da.shape}")
Original shape: (50, 40)
Reconstructed shape: (100, 40)

The reconstructed array has twice as many elements along the x dimension as the original, as expected. This concludes the demonstration of reconstructing an xarray.DataArray from model outputs using xbatcher.