
Reconstructing Xarray Datasets from Model Outputs


Overview

This notebook shows how to reconstruct an xarray.DataArray from the output of a machine learning model. While the previous notebook focused on generating batches from xarray objects, this guide covers the reverse process: assembling model outputs back into a coherent, labeled xarray object. This is a common requirement in scientific machine learning workflows, where model outputs need to be analyzed in their original spatial or temporal context.

We will examine a function that reassembles model outputs, including a detailed look at how an internal API of xbatcher can be used to map batch outputs back to their original coordinates.

Prerequisites

Concepts                     Importance   Notes
Intro to Xarray              Necessary    Array indexing
Loading Batches from Xarray  Helpful      PyTorch DataLoader API
PyTorch fundamentals         Helpful      Model training loop

Imports

import xarray as xr
import numpy as np
import xbatcher
from xbatcher.loaders.torch import MapDataset

from dummy_models import ExpandAlongAxis

from functions import predict_on_array

Setup: Data, Batches, and a Dummy Model

We will begin by creating a sample xarray.DataArray and a BatchGenerator. We will also instantiate a dummy model that transforms the data, simulating a common machine learning scenario where the output dimensions differ from the input dimensions (e.g., super-resolution).

da = xr.DataArray(
    data=np.random.rand(50, 40).astype(np.float32),
    dims=("x", "y"),
    coords={"x": np.arange(50), "y": np.arange(40)},
)
da

Next, we create the BatchGenerator.

bgen = xbatcher.BatchGenerator(da, input_dims={"x": 10, "y": 10})
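Assuming BatchGenerator's default of no input_overlap, the generator should tile the 50 × 40 array into non-overlapping 10 × 10 windows, giving (50 / 10) × (40 / 10) = 20 samples. The NumPy sketch below reproduces that tiling by hand to make the window layout concrete. The helper window_slices is hypothetical (it is not part of xbatcher), and its x-outer, y-inner ordering is an illustrative choice, not a statement about the order xbatcher uses.

```python
import numpy as np


def window_slices(shape, window):
    """Hypothetical helper: slices for non-overlapping 2-D windows.

    Assumes the stride equals the window size (no overlap) and drops any
    partial window left over at the end of a dimension.
    """
    x_starts = range(0, shape[0] - window[0] + 1, window[0])
    y_starts = range(0, shape[1] - window[1] + 1, window[1])
    return [
        (slice(i, i + window[0]), slice(j, j + window[1]))
        for i in x_starts  # outer loop over x
        for j in y_starts  # inner loop over y
    ]


data = np.random.rand(50, 40).astype(np.float32)
slices = window_slices(data.shape, (10, 10))
patches = np.stack([data[s] for s in slices])
print(len(slices), patches.shape)  # 20 (20, 10, 10)
```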

For the model, we will use ExpandAlongAxis from dummy_models.py. This model upsamples the input along a specified axis, changing the shape of the data.

# Axis 0 is the batch dimension, so ax=1 is 'x'; the model doubles its length
model = ExpandAlongAxis(ax=1, n_repeats=2)
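The contents of dummy_models.py are not shown here, but the behavior described above can be approximated with np.repeat, assuming ExpandAlongAxis repeats each element n_repeats times along axis ax and that axis 0 of the model input is the batch dimension. The function expand_along_axis below is a hypothetical NumPy stand-in, not the actual PyTorch module.

```python
import numpy as np


def expand_along_axis(batch, ax, n_repeats):
    """Hypothetical NumPy stand-in for the dummy ExpandAlongAxis model.

    Repeats every element `n_repeats` times along axis `ax`.
    """
    return np.repeat(batch, n_repeats, axis=ax)


# A batch of 4 samples laid out as (batch, x, y)
batch = np.zeros((4, 10, 10), dtype=np.float32)
out = expand_along_axis(batch, ax=1, n_repeats=2)
print(out.shape)  # (4, 20, 10)

# On a tiny input the repetition pattern is easy to see
tiny = np.arange(6).reshape(1, 3, 2)
print(expand_along_axis(tiny, ax=1, n_repeats=2)[0])
# [[0 1]
#  [0 1]
#  [2 3]
#  [2 3]
#  [4 5]
#  [4 5]]
```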

Reconstructing the Dataset

We will now use the predict_on_array function to reconstruct the dataset. The most important part of using this function is correctly specifying the arguments new_dim, core_dim, and resample_dim. Each of these lists holds names of dimensions from output_tensor_dim, and together they tell predict_on_array how each dimension of the model output relates to the input data. In general:

  • new_dim: Tensor dimensions that do not appear at all in the original xarray object.
  • core_dim: Tensor dimensions that are present in the original xarray object, but are not used for batch generation. We assume that all elements of this dimension in the xarray object flow through the model to the output. Coordinates are simply copied from the original xarray object.
  • resample_dim: Tensor dimensions that are present in the original xarray object and used for batch generation. These dimensions are allowed to change size (but see below note) and coordinates are resampled accordingly in the reconstructed array.
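The three rules above can be written as a small decision procedure. The helper classify_dims below is hypothetical (it is not part of xbatcher or functions.py); it takes the model's output_tensor_dim, the dimension names of the original array, and the input_dims passed to BatchGenerator, and sorts each output dimension into one of the three lists.

```python
def classify_dims(output_tensor_dim, data_dims, input_dims):
    """Hypothetical helper: sort output dims into new/core/resample lists."""
    new_dim, core_dim, resample_dim = [], [], []
    for dim in output_tensor_dim:
        if dim not in data_dims:
            new_dim.append(dim)        # absent from the original array
        elif dim in input_dims:
            resample_dim.append(dim)   # used for batch generation
        else:
            core_dim.append(dim)       # present but not batched over
    return new_dim, core_dim, resample_dim


# A made-up case exercising all three categories
print(classify_dims(
    output_tensor_dim={"time": 5, "x": 20, "channel": 3},
    data_dims=("time", "x", "y"),
    input_dims={"x": 10, "y": 10},
))
# (['channel'], ['time'], ['x'])
```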

Let’s apply these rules to our present example. The batch generator creates tensors of size (x=10, y=10) and the dummy model makes tensors of size (x=20, y=10). In this case, all tensor dimensions are present in the original data array and are used for batch generation. Therefore, both x and y go in resample_dim. Now that all tensor dimensions are accounted for, we can simply leave new_dim and core_dim as empty lists.
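For the resample dimensions, a useful sanity check is the ratio between the model's output size and the batch input size. Assuming the windows tile the array exactly (as they do here), multiplying each original dimension length by that ratio predicts the size of the reconstructed array. The variable names below are illustrative.

```python
input_dims = {"x": 10, "y": 10}           # passed to BatchGenerator
output_tensor_dim = {"x": 20, "y": 10}    # produced by the model
original_sizes = {"x": 50, "y": 40}       # same as da.sizes

# Ratio of output to input size along each resampled dimension
scale = {d: output_tensor_dim[d] / input_dims[d] for d in ["x", "y"]}

# Expected length of each dimension after reconstruction
expected = {d: int(original_sizes[d] * scale[d]) for d in original_sizes}
print(scale)     # {'x': 2.0, 'y': 1.0}
print(expected)  # {'x': 100, 'y': 40}
```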

map_dataset = MapDataset(bgen)
reconstructed_da = predict_on_array(
    dataset=map_dataset,
    model=model,
    output_tensor_dim={"x": 20, "y": 10},  # The model doubles the x-dimension
    new_dim=[],
    core_dim=[],
    resample_dim=["x", "y"],
    batch_size=4,
    progress_bar=False,
)
reconstructed_da
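Conceptually, reconstruction places each model output back at its window's position, with the offset along every resample dimension multiplied by that dimension's scale factor. The NumPy sketch below illustrates this idea for the present example. It is an assumption about the general approach, not the actual implementation of predict_on_array, and it reconstructs only the values (coordinate handling is left out). The helper reassemble is hypothetical, and np.repeat stands in for the dummy model.

```python
import numpy as np


def reassemble(outputs, starts, out_shape, scale):
    """Hypothetical helper: write scaled patches into a full-size array.

    outputs: (n_patches, px, py) model outputs
    starts:  (i, j) start index of each patch in the ORIGINAL array
    scale:   integer upsampling factor along (x, y)
    """
    result = np.full(out_shape, np.nan, dtype=outputs.dtype)
    px, py = outputs.shape[1:]
    for patch, (i, j) in zip(outputs, starts):
        xi, yj = i * scale[0], j * scale[1]  # offsets in the output grid
        result[xi:xi + px, yj:yj + py] = patch
    return result


data = np.random.rand(50, 40).astype(np.float32)
starts = [(i, j) for i in range(0, 50, 10) for j in range(0, 40, 10)]
patches = np.stack([data[i:i + 10, j:j + 10] for i, j in starts])
outputs = np.repeat(patches, 2, axis=1)  # stand-in for the model: (20, 20, 10)

recon = reassemble(outputs, starts, out_shape=(100, 40), scale=(2, 1))
print(recon.shape)  # (100, 40)
# Upsampling each patch then stitching equals upsampling the whole array
print(np.array_equal(recon, np.repeat(data, 2, axis=0)))  # True
```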

The reconstructed DataArray has the upsampled x dimension. We can compare its shape to the original.

print(f"Original shape: {da.shape}")
print(f"Reconstructed shape: {reconstructed_da.shape}")
Original shape: (50, 40)
Reconstructed shape: (100, 40)

The reconstructed array has twice the number of elements in the x dimension, as expected.