Working with large, multi-dimensional datasets, common in fields like climate science and oceanography, presents a significant challenge when preparing data for machine learning models. The xbatcher library is designed to simplify this crucial preprocessing step.
xbatcher is a Python package that facilitates the generation of data batches from xarray objects for machine learning. It serves as a bridge between the labeled, multi-dimensional data structures of xarray and the tensor-based inputs required by deep learning frameworks such as PyTorch and TensorFlow.
This guide provides an introduction to the fundamentals of xbatcher. We will cover how to create a BatchGenerator, customize it for specific needs, and prepare the resulting data for integration with a PyTorch model.
Imports¶
import xarray as xr
import numpy as np
import torch
import xbatcher
from xbatcher.loaders.torch import MapDataset, IterableDataset
Creating a Sample Dataset¶
To begin, we will create a sample xarray.Dataset. This allows us to focus on the mechanics of xbatcher without the overhead of a specific real-world dataset. The sample can be replaced by any xarray.Dataset loaded from a file (e.g., NetCDF, Zarr), as sketched below.
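For example, a real dataset could be opened with xarray's standard readers and used in place of the synthetic one created in the next cell (the paths here are placeholders):
ds = xr.open_dataset("path/to/data.nc")  # a hypothetical NetCDF file
# or, for a Zarr store:
ds = xr.open_zarr("path/to/store.zarr")  # a hypothetical Zarr store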
ds = xr.Dataset(
{
"temperature": (("x", "y", "time"), np.random.rand(100, 100, 50)),
"precipitation": (("x", "y", "time"), np.random.rand(100, 100, 50)),
},
coords={
"x": np.arange(100),
"y": np.arange(100),
"time": np.arange(50),
},
)
ds
The dataset contains two variables, temperature and precipitation, and three dimensions: x, y, and time. We will now use xbatcher to generate batches from this dataset.
The BatchGenerator¶
The BatchGenerator is the core component of xbatcher. It is a Python generator that yields batches of data from an xarray object.
bgen = xbatcher.BatchGenerator(ds, input_dims={"x": 10, "y": 10})
The BatchGenerator is initialized with the dataset and the input_dims parameter. input_dims specifies the size of the batches along each dimension. In this case, we are creating batches of size 10x10 along the x and y dimensions. The time dimension is not specified, so xbatcher will yield batches that include all time steps.
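If you wanted fixed-length windows along time as well, you could include it in input_dims too (a sketch; the window sizes here are arbitrary choices):
bgen_time = xbatcher.BatchGenerator(ds, input_dims={"x": 10, "y": 10, "time": 5})
# Each batch would be 10x10x5; with 50 time steps this gives 10 * 10 * 10 = 1000 batches.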
Returning to our original bgen, let’s inspect the first batch it generates.
first_batch = next(iter(bgen))
first_batch
The first batch has dimensions x=10, y=10, and time=50, as expected. The BatchGenerator will yield 100 batches in total (10 batches in the x-direction * 10 batches in the y-direction).
print(f"The BatchGenerator contains {len(bgen)} batches.")
The BatchGenerator contains 100 batches.
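Batches can also be retrieved by integer index, which the PyTorch MapDataset introduced later relies on:
bgen[99]  # the 100th (final) 10x10 patch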
Overlapping Batches with input_overlap¶
In many applications, it is useful to have overlapping batches to provide context from neighboring data points. The input_overlap parameter allows for this.
bgen_overlap = xbatcher.BatchGenerator(
ds,
input_dims={"x": 10, "y": 10},
input_overlap={"x": 2, "y": 2}
)
first_batch_overlap = next(iter(bgen_overlap))
first_batch_overlap
The input_overlap parameter specifies the number of elements to overlap between consecutive batches. The size of the batches themselves does not change. Let’s verify this by inspecting the coordinates of a few batches.
print(f"Batch 1 y-coords: {bgen_overlap[0].y.values}, Batch 1 x-coords: {bgen_overlap[0].x.values}")
print(f"Batch 2 y-coords: {bgen_overlap[1].y.values}, Batch 2 x-coords: {bgen_overlap[1].x.values}")
print(f"Batch 3 y-coords: {bgen_overlap[2].y.values}, Batch 3 x-coords: {bgen_overlap[2].x.values}")
print(f"Batch 13 y-coords: {bgen_overlap[12].y.values}, Batch 13 x-coords: {bgen_overlap[12].x.values}")
print(f"Batch 14 y-coords: {bgen_overlap[13].y.values}, Batch 11 x-coords: {bgen_overlap[13].x.values}")
Batch 1 y-coords: [0 1 2 3 4 5 6 7 8 9], Batch 1 x-coords: [0 1 2 3 4 5 6 7 8 9]
Batch 2 y-coords: [ 8 9 10 11 12 13 14 15 16 17], Batch 2 x-coords: [0 1 2 3 4 5 6 7 8 9]
Batch 3 y-coords: [16 17 18 19 20 21 22 23 24 25], Batch 3 x-coords: [0 1 2 3 4 5 6 7 8 9]
Batch 13 y-coords: [0 1 2 3 4 5 6 7 8 9], Batch 13 x-coords: [ 8 9 10 11 12 13 14 15 16 17]
Batch 14 y-coords: [ 8 9 10 11 12 13 14 15 16 17], Batch 14 x-coords: [ 8 9 10 11 12 13 14 15 16 17]
As you can see, the second batch starts at y=8, overlapping the last two elements (y=8 and y=9) of the first batch. The stride between consecutive batches is therefore the window size minus the overlap: 10 - 2 = 8.
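The stride also determines how many batches the generator produces. A quick back-of-the-envelope check (assuming windows that do not fit entirely within the domain are dropped, which matches the coordinates above):
stride = 10 - 2                     # window size minus overlap
per_dim = (100 - 10) // stride + 1  # 12 windows along each of x and y
per_dim ** 2                        # expect 144 == len(bgen_overlap)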
Integration with PyTorch¶
xbatcher provides MapDataset and IterableDataset to wrap the BatchGenerator for use with PyTorch.
MapDataset vs. IterableDataset¶
MapDataset: Implements __getitem__ and __len__, allowing for random access to data samples. This is the most common type of dataset in PyTorch.
IterableDataset: Implements __iter__, and is suitable for very large datasets that may not fit into memory, as it streams data (see the sketch after this list).
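For completeness, here is a minimal sketch of the streaming variant. It assumes the two-generator signature (an input generator paired with a target generator, yielding (X, y) tuples) used by recent xbatcher releases; reusing bgen for both is purely illustrative, so check your installed version's API:
iterable_ds = IterableDataset(bgen, bgen)
X_batch, y_batch = next(iter(iterable_ds))  # each is an xarray patch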
We will use MapDataset for this example. First, let’s check the shape of each variable in a single patch:
bgen[0].temperature.shape
bgen[0].precipitation.shape
(10, 10, 50)
def patch_to_tensor(patch):
    # Convert each variable's underlying array to a torch tensor.
    temp_patch = torch.tensor(patch.temperature.data)
    prcp_patch = torch.tensor(patch.precipitation.data)
    # Stack the two variables along a new leading "channel" dimension.
    stacked_patch = torch.stack((temp_patch, prcp_patch), dim=0)
    # Replace any NaNs with zeros and cast to float32 for the model.
    return torch.nan_to_num(stacked_patch).float()
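Applying the transform to a single patch should produce a tensor with the two variables stacked as channels:
patch_to_tensor(bgen[0]).shape
torch.Size([2, 10, 10, 50])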
map_ds = MapDataset(bgen, transform=patch_to_tensor)
The MapDataset can then be used with a PyTorch DataLoader, which provides utilities for shuffling, batching, and multiprocessing.
dataloader = torch.utils.data.DataLoader(map_ds, batch_size=4)
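For actual training you would typically enable shuffling, which works because MapDataset supports random access (a sketch; the num_workers value is an arbitrary choice, and we continue with the unshuffled loader below):
train_loader = torch.utils.data.DataLoader(map_ds, batch_size=4, shuffle=True, num_workers=2)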
Inspecting a batch from the DataLoader reveals a batch of PyTorch tensors.
torch_batch = next(iter(dataloader))
torch_batch.shape
torch.Size([4, 2, 10, 10, 50])
The DataLoader has stacked 4 of the xbatcher batches, creating a new batch dimension of size 4. The data is now ready for use in a PyTorch model.
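As a final illustration, here is how these batches might feed a small model. The Conv3d layer and the loop are illustrative assumptions, not part of xbatcher:
model = torch.nn.Conv3d(in_channels=2, out_channels=1, kernel_size=3, padding=1)
for batch in dataloader:
    output = model(batch)  # shape: (4, 1, 10, 10, 50)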
In the next notebook, we will explore how to reconstruct an xarray.Dataset from a model’s output.