
Dataloading from Xarray Datasets

Working with large, multi-dimensional datasets, common in fields like climate science and oceanography, presents a significant challenge when preparing data for machine learning models. The xbatcher library is designed to simplify this crucial preprocessing step.

xbatcher is a Python package that facilitates the generation of data batches from xarray objects for machine learning. It serves as a bridge between the labeled, multi-dimensional data structures of xarray and the tensor-based inputs required by deep learning frameworks such as PyTorch and TensorFlow.

This guide provides an introduction to the fundamentals of xbatcher. We will cover how to create a BatchGenerator, customize it for specific needs, and prepare the resulting data for integration with a PyTorch model.

Imports

import xarray as xr
import numpy as np
import torch
import xbatcher
from xbatcher.loaders.torch import MapDataset, IterableDataset

Creating a Sample Dataset

To begin, we will create a sample xarray.Dataset. This allows us to focus on the mechanics of xbatcher without the overhead of a specific real-world dataset. This sample can be replaced by any xarray.Dataset loaded from a file (e.g., NetCDF, Zarr).

ds = xr.Dataset(
    {
        "temperature": (("x", "y", "time"), np.random.rand(100, 100, 50)),
        "precipitation": (("x", "y", "time"), np.random.rand(100, 100, 50)),
    },
    coords={
        "x": np.arange(100),
        "y": np.arange(100),
        "time": np.arange(50),
    },
)
ds

The dataset contains two variables, temperature and precipitation, and three dimensions: x, y, and time. We will now use xbatcher to generate batches from this dataset.
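If your data instead lives on disk or in cloud storage, the same workflow applies; for example (a minimal sketch, with placeholder paths):

# Hypothetical paths -- replace with your own NetCDF file or Zarr store
ds_nc = xr.open_dataset("path/to/data.nc")
ds_zarr = xr.open_zarr("path/to/store.zarr")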

The BatchGenerator

The BatchGenerator is the core component of xbatcher. It is a Python generator that yields batches of data from an xarray object.

bgen = xbatcher.BatchGenerator(ds, input_dims={"x": 10, "y": 10})

The BatchGenerator is initialized with the dataset and the input_dims parameter. input_dims specifies the size of the batches along each dimension. In this case, we are creating batches of size 10x10 along the x and y dimensions. The time dimension is not specified, so xbatcher will yield batches that include all time steps.
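If you also want fixed-length windows along time, include it in input_dims; a minimal sketch:

# 10x10 spatial patches of 10 consecutive time steps each;
# for this 100x100x50 dataset that yields 10 * 10 * 5 = 500 batches
bgen_time = xbatcher.BatchGenerator(ds, input_dims={"x": 10, "y": 10, "time": 10})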

Let’s inspect the first batch generated.

first_batch = next(iter(bgen))
first_batch

The first batch has dimensions x=10, y=10, and time=50, as expected. The BatchGenerator will yield 100 batches in total: each 100-element dimension divides into ten non-overlapping blocks of 10, giving 10 × 10 positions.

print(f"The BatchGenerator contains {len(bgen)} batches.")
The BatchGenerator contains 100 batches.

Overlapping Batches with input_overlap

In many applications, it is useful to have overlapping batches to provide context from neighboring data points. The input_overlap parameter allows for this.

bgen_overlap = xbatcher.BatchGenerator(
    ds, 
    input_dims={"x": 10, "y": 10}, 
    input_overlap={"x": 2, "y": 2}
)
first_batch_overlap = next(iter(bgen_overlap))
first_batch_overlap

The input_overlap parameter specifies the number of elements shared between consecutive batches. The batches themselves keep their 10×10 size; the window simply advances with a stride of input_dims - input_overlap = 8 along each dimension. Let’s verify this by inspecting the coordinates of a few batches.

print(f"Batch 1 y-coords: {bgen_overlap[0].y.values}, Batch 1 x-coords: {bgen_overlap[0].x.values}")
print(f"Batch 2 y-coords: {bgen_overlap[1].y.values}, Batch 2 x-coords: {bgen_overlap[1].x.values}")
print(f"Batch 3 y-coords: {bgen_overlap[2].y.values}, Batch 3 x-coords: {bgen_overlap[2].x.values}")
print(f"Batch 13 y-coords: {bgen_overlap[12].y.values}, Batch 13 x-coords: {bgen_overlap[12].x.values}")
print(f"Batch 14 y-coords: {bgen_overlap[13].y.values}, Batch 11 x-coords: {bgen_overlap[13].x.values}")
Batch 1 y-coords: [0 1 2 3 4 5 6 7 8 9], Batch 1 x-coords: [0 1 2 3 4 5 6 7 8 9]
Batch 2 y-coords: [ 8  9 10 11 12 13 14 15 16 17], Batch 2 x-coords: [0 1 2 3 4 5 6 7 8 9]
Batch 3 y-coords: [16 17 18 19 20 21 22 23 24 25], Batch 3 x-coords: [0 1 2 3 4 5 6 7 8 9]
Batch 13 y-coords: [0 1 2 3 4 5 6 7 8 9], Batch 13 x-coords: [ 8  9 10 11 12 13 14 15 16 17]
Batch 14 y-coords: [ 8  9 10 11 12 13 14 15 16 17], Batch 14 x-coords: [ 8  9 10 11 12 13 14 15 16 17]

As you can see, the second batch starts at y=8, overlapping the last two elements (y=8 and y=9) of the first batch. Batch 13 wraps back to y=0 while advancing x by the same stride of 8, showing that each dimension now contains 12 window positions instead of 10.
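The overlap also changes the batch count: with a stride of 8, each 100-element dimension fits (100 - 10) // 8 + 1 = 12 windows, so the generator should report 12 × 12 = 144 batches:

print(f"The overlapping BatchGenerator contains {len(bgen_overlap)} batches.")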

Integration with PyTorch

xbatcher provides MapDataset and IterableDataset to wrap the BatchGenerator for use with PyTorch.

MapDataset vs. IterableDataset

  • MapDataset: Implements __getitem__ and __len__, allowing for random access to data samples. This is the most common type of dataset in PyTorch.
  • IterableDataset: Implements __iter__, and is suitable for very large datasets that may not fit into memory, as it streams data.

We will use MapDataset for this example; a brief IterableDataset sketch follows for reference.
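The streaming alternative looks roughly like this (a minimal sketch; we reuse bgen as both the input and target generator purely for illustration, and the exact constructor signature may vary between xbatcher versions):

# Patches are yielded in generator order; no random access or len()
iterable_ds = IterableDataset(bgen, bgen)
for X_patch, y_patch in iterable_ds:
    print(X_patch.sizes)
    break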

Each patch yielded by the generator carries both variables with shape (10, 10, 50):

print(bgen[0].temperature.shape)
print(bgen[0].precipitation.shape)
(10, 10, 50)
(10, 10, 50)

The transform below stacks the two variables along a new channel dimension and converts the result to a float tensor suitable for PyTorch:
def patch_to_tensor(patch):
    """Convert an xarray patch into a (channel, x, y, time) float tensor."""
    temp_patch = torch.tensor(patch.temperature.data)
    prcp_patch = torch.tensor(patch.precipitation.data)
    # Stack the two variables along a new leading channel dimension: (2, 10, 10, 50)
    stacked = torch.stack((temp_patch, prcp_patch), dim=0)
    # Replace any NaNs and cast to float32 before handing off to the model
    return torch.nan_to_num(stacked).float()

map_ds = MapDataset(bgen, transform=patch_to_tensor)
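Indexing the MapDataset applies the transform to a single patch:

sample = map_ds[0]
sample.shape
torch.Size([2, 10, 10, 50])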

The MapDataset can then be used with a PyTorch DataLoader, which provides utilities for shuffling, batching, and multiprocessing.

dataloader = torch.utils.data.DataLoader(map_ds, batch_size=4)
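For training you will typically also shuffle and parallelize loading; a minimal sketch (the num_workers value is an arbitrary choice):

train_loader = torch.utils.data.DataLoader(
    map_ds,
    batch_size=4,
    shuffle=True,     # MapDataset supports random access, so shuffling works
    num_workers=2,    # load patches in parallel worker processes
)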

Inspecting one batch from the DataLoader confirms that it now yields stacked PyTorch tensors.

torch_batch = next(iter(dataloader))
torch_batch.shape
torch.Size([4, 2, 10, 10, 50])

The DataLoader has stacked 4 of the xbatcher batches, creating a new batch dimension of size 4. The data is now ready for use in a PyTorch model.
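As a quick sanity check, here is a hedged sketch of passing such a batch through a small model (the Conv3d layer is an arbitrary choice, not part of xbatcher; in_channels=2 matches the two stacked variables):

import torch.nn as nn

# Treat (x, y, time) as the three spatial dims of a 3D convolution
model = nn.Conv3d(in_channels=2, out_channels=8, kernel_size=3)
output = model(torch_batch)
print(output.shape)  # torch.Size([4, 8, 8, 8, 48])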

In the next notebook, we will explore how to reconstruct an xarray.Dataset from a model’s output.