
xbatcher for Machine Learning Part 1

Introduction

Here, we cover how to use xbatcher with Keras/TensorFlow convolutional neural network (CNN) models.

Prerequisites:

Concepts | Importance | Notes
Xarray | Necessary |
Keras/TensorFlow | Strongly recommended | Not strictly needed to understand this tutorial

This notebook replicates the work of Sinha and Abernathey (2021), where the goal is to use a CNN to learn ocean surface currents (which are usually inferred diagnostically or modelled) from variables that can be observed directly, such as sea surface temperature (SST) or wind stress.

Can we learn to predict ocean currents with just one snapshot of data?


Imports

To start, let’s import some libraries we’ll need. The important libraries here are numpy, xarray, xbatcher and tensorflow, while most of the others aren’t strictly necessary.

import numpy as np
import xarray as xr

from dataclasses import dataclass
from typing import Iterable

from matplotlib import pyplot as plt
from IPython.display import clear_output
import tensorflow as tf
import xbatcher as xb

Designing Scenarios

We want to experiment with different neural network models by providing different inputs, and perhaps by playing with whether or not we run them through a convolutional layer. There are a lot of possibilities here, and if we approach it haphazardly, we’ll end up with a mess of scattered experiments and results mixed in with other code.

Instead, we can be more systematic about it. We know we want to define an individual scenario once, and then have it stay constant through the workflow. This way, there will be no complexities later on about whether we’re referring to the right dataset, etc. With that in mind, we should use a dataclass. We want something minimal here, just enough to store the names of variables we’re interested in.

What is the structure of each experiment? We want some input variables to be run through a 2D convolutional layer, while some other inputs will be passed through directly to the dense part of the neural network. Both of these can be lists of strings, so we define conv_var and input_var as Iterable[str].

Likewise, we have more than one target, so we define the target item as Iterable[str] as well. Outside of the Scenario dataclass, we define target as a list: ['U', 'V']. Since we’re only interested in learning the currents, this won’t change.

Finally, we need to name each scenario something distinct, so when we create data subsets for training, testing, and prediction, we can recover them later.

@dataclass
class Scenario:
    conv_var: Iterable[str]
    input_var: Iterable[str]
    target: Iterable[str]
    name: str

target = ['U', 'V']
sc1 = Scenario(['SSH'],             ['TAUX', 'TAUY'], target, name="derp")
sc5 = Scenario(['SSH', 'SST'], ['X', 'TAUX', 'TAUY'], target, name="herp")
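The goal stated above is that a scenario stays constant through the workflow, but a plain @dataclass does not enforce that. As a sketch (not what this notebook does), frozen=True makes Python reject field reassignment, and storing tuples instead of lists keeps the variable lists from being mutated in place. FrozenScenario and sc5_frozen are hypothetical names used only for this illustration.

```python
from dataclasses import FrozenInstanceError, dataclass
from typing import Tuple


@dataclass(frozen=True)
class FrozenScenario:
    """Hypothetical immutable variant of Scenario (illustration only)."""
    conv_var: Tuple[str, ...]
    input_var: Tuple[str, ...]
    target: Tuple[str, ...]
    name: str


sc5_frozen = FrozenScenario(("SSH", "SST"), ("X", "TAUX", "TAUY"), ("U", "V"), name="herp")

# Reassigning a field raises instead of silently changing the experiment.
reassignment_blocked = False
try:
    sc5_frozen.name = "something_else"
except FrozenInstanceError:
    reassignment_blocked = True

print(reassignment_blocked, sc5_frozen.name)
```

A side effect of frozen=True (with the default eq=True) is that instances are hashable, so a scenario could also serve as a dictionary key, for example when collecting results per experiment.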

Data and Preprocessing

For our dataset, we will be using ocean data from a high-resolution CESM POP model.

We have some necessary I/O routines, but they aren't central to our problem, aside from the addition of the new variables X, Y, Z, dx and dy, which represent the Euclidean positions of the grid points and the distances between them.

You can have a look in the notebook below if you’re curious about it.

%run ./surface_currents_prep.ipynb

From this notebook, we get a few new functions.

  • prepare_data takes a scenario, as well as the time slices for training, testing, and prediction we are interested in, and the time slice we’ll use for the NaN mask. It adds the new grid variables, and then stores each slice in a new zarr store that we can access later. This speeds up future I/O, which is helpful when modifying the model. Each scenario is stored separately.
  • load_training_data loads the training data created for the scenario passed to it.
  • load_test_data loads the testing data created for the scenario passed to it.
  • load_predict_data loads the prediction input data created for the scenario passed to it.

You can comment out prepare_data after you've run it once; this saves time if you rerun the whole notebook.
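Instead of commenting prepare_data out by hand, one option is to guard the call on whether its output already exists. This is a sketch under an assumption: you need to know a path that prepare_data writes to, and the actual zarr locations are defined in surface_currents_prep.ipynb, not here. The run_once helper and the store path in the usage comment are hypothetical.

```python
from pathlib import Path
from typing import Callable


def run_once(output_path: str, fn: Callable[[], None]) -> bool:
    """Call fn() only if output_path does not exist yet (hypothetical helper).

    Returns True if fn was called, False if the existing output was reused.
    """
    if Path(output_path).exists():
        return False
    fn()
    return True


# Hypothetical usage; the store path is an assumption, not taken from the prep notebook:
# run_once("data/herp_training.zarr", lambda: prepare_data(sc5, 200, 1000, 1000, 200))
```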

prepare_data(sc5, 200, 1000, 1000, 200)
When this page was built, this cell failed while opening the CESM_POP_hires_control entry of the Pangeo CESM_POP intake catalog from Google Cloud Storage, ending with:

ValueError: Bad Request: https://storage.googleapis.com/download/storage/v1/b/pangeo-cesm-pop/o/control%2F.zattrs?alt=media
User project specified in the request is invalid.

Next, we’ll load our training data and pick out the part we want to train with.

NOTE: Coordinates and attributes are dropped for speed; this shouldn't be necessary in future (optimized) versions of xarray/xbatcher.

ds_training = load_training_data(sc5)
ds_training = just_the_data(ds_training)
ds_training

Looking inside ds_training, we see only the variables we would expect from sc5.

ds_training = select_from(ds_training)
xr.plot.contourf(ds_training['SST'])

Model Setup

We already have a model architecture we're happy with, so in this tutorial we'll focus on how to use xbatcher to generate training sets for the model. From the notebook below, we receive:

  • get_model() Creates a mixed neural network based on some parameters. The architecture is intentionally a little arbitrary in terms of the depth of the dense part of the network, the depth of the convolutional part of the network, and the convolution kernel size. Returns a compiled Keras model.
  • LossHistory() Only needed here because it has to be passed to model.fit().
  • train() We will walk through this routine below.

Have a look inside for more details!

%run ./surface_currents_model.ipynb

Now the fun part: we define the train function to deal with high-level aspects of training the model, which means this is a good place to use xbatcher. Let’s walk through it...

The arguments to train are

  • ds: xr.Dataset The dataset you want to work with.
  • sc: Scenario The scenario you want to work with.
  • conv_dims: List[int] This is the shape of the stencil that will be passed to the first convolutional layer. We are only interested in 2D convolutions here, so it will need to be a list of two integers. Note that this is distinct from the convolutional kernel.
  • nfilters: int How many filters do we want to map the first convolution layer to?
  • conv_kernels: List[int] Each entry denotes the convolution kernel of a new convolution layer. train works best for odd-numbered convolution kernels.
  • dense_layers: int The number of dense layers in the model.

For this example, we only use one convolution layer, which makes some things simpler. Feel free to experiment with these parameters to use different data sets and create new CNN models.

sc = sc5
conv_dims = [5,5]
nfilters = 80
conv_kernels = [5]
dense_layers = 3

We'll need some info about how to reconcile the output of the convolution layers with the raw input from other variables (see the surface_currents_model.ipynb notebook for more info). Based on the convolution kernel, we know how the output of a convolution layer will be shaped compared to its input: a halo of fixed width is removed from each edge. For an odd kernel size $n$, the halo thickness is $\frac{n - 1}{2}$, and with several convolution layers the halos add up.

halo_size = int((np.sum(conv_kernels) - len(conv_kernels))/2)
halo_size
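To check the halo formula, we can mimic an unpadded ("valid") convolution with NumPy's sliding_window_view: each output point needs a full n x n window, so n - 1 rows and columns are lost in total, (n - 1)/2 on each side, and stacked layers add their halos. This is a NumPy sketch for intuition only; averaging over the window stands in for a learned kernel, and it does not reproduce the Keras layers built by get_model().

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def halo_from_kernels(conv_kernels):
    """Total halo (per side) removed by a stack of odd, unpadded 2D convolutions."""
    return int((np.sum(conv_kernels) - len(conv_kernels)) / 2)


def valid_output_shape(shape, conv_kernels):
    """Apply successive 'valid' k x k windows and return the remaining spatial shape."""
    x = np.zeros(shape)
    for k in conv_kernels:
        # Each output point sees a full k x k window; the mean stands in for a convolution.
        x = sliding_window_view(x, (k, k)).mean(axis=(-2, -1))
    return x.shape


for kernels in ([5], [5, 3], [3, 3, 3]):
    halo = halo_from_kernels(kernels)
    out = valid_output_shape((11, 11), kernels)
    print(kernels, "halo per side:", halo, "output shape:", out)
```

With the notebook's settings (conv_dims = [5, 5], conv_kernels = [5]) the halo is 2, and a 5 x 5 sample shrinks to a single 1 x 1 output point.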

Training a Model with xbatcher

Since we are trying to learn from a single 2D snapshot, it makes sense to iterate in both latitude and longitude. What we want are individual samples of the size given by conv_dims, but batched in a way that we can pass the correct number of samples to the model as a single tensor. So, input_dims will contain entries for both nlon and nlat. To take full advantage of the available data, we can add an overlap to make sure halo points are fully included in the neighboring samples.

NOTE: xbatcher currently runs slowly with concat_input_dims=True, and running without it results in batches of size one. Therefore, we use an implementation based on xarray's rolling to mimic what xbatcher does. This is not a good strategy for large datasets, but for this example the differences are minimal. We anticipate that fixed-size batches and further optimizations will be implemented in xbatcher in the future.

nlons, nlats = conv_dims
# bgen = xb.BatchGenerator(
#     ds_training,
#     {'nlon':nlons,       'nlat':nlats},
#     {'nlon':2*halo_size, 'nlat':2*halo_size}
# )
latlen = len(ds_training['nlat'])
lonlen = len(ds_training['nlon'])
nlon_range = range(nlons,lonlen,nlons - 2*halo_size)
nlat_range = range(nlats,latlen,nlats - 2*halo_size)

batch = (
    ds_training
    .rolling({"nlat": nlats, "nlon": nlons})
    .construct({"nlat": "nlat_input", "nlon": "nlon_input"})[{'nlat':nlat_range, 'nlon':nlon_range}]
    .stack({"input_batch": ("nlat", "nlon")}, create_index=False)
    .rename_dims({'nlat_input':'nlat', 'nlon_input':'nlon'})
    .transpose('input_batch',...)
    # .chunk({'input_batch':32, 'nlat':nlats, 'nlon':nlons})
    .dropna('input_batch')
)
rnds = list(range(len(batch['input_batch'])))
np.random.shuffle(rnds)
batch = batch[{'input_batch':(rnds)}]
batch
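# use with rolling
def batch_generator(batch_set, batch_size):
    """Yield consecutive fixed-size batches along 'input_batch'; a trailing partial batch is dropped."""
    n_samples = len(batch_set['input_batch'])
    n = 0
    # `<=` so that the last full batch is not skipped
    while n + batch_size <= n_samples:
        yield batch_set.isel({'input_batch': slice(n, n + batch_size)})
        n += batch_size
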
# # use with xbatcher
# def batch_generator(bgen, batch_size):
#     b = (batch for batch in bgen)
#     n = 0
#     while n < 400:
#         batch_stack = [ next(b) for i in range(batch_size) ]
#         yield xr.concat(batch_stack, 'sample')
#         n += 1
bgen = batch_generator(batch, 4096)
# bgen = batch_generator(bgen, 32)
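For intuition about what the rolling/construct/stack/dropna chain produces, here is a NumPy analogue built on sliding_window_view. It is a sketch under simplifying assumptions: a single 2D field, windows taken from the first full window onward (which may be offset by one row or column from the rolling indexing above), and NaN-containing windows dropped like dropna('input_batch'). make_samples and iter_batches are hypothetical helpers, not part of xarray or xbatcher; iter_batches yields only full batches of batch_size samples and drops the remainder.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def make_samples(field, n_lat, n_lon, stride=1, rng=None):
    """Cut a 2D field into (n_lat, n_lon) patches, drop patches with NaNs, optionally shuffle.

    Returns an array of shape (n_samples, n_lat, n_lon).
    """
    windows = sliding_window_view(field, (n_lat, n_lon))[::stride, ::stride]
    # Flatten the two window-position axes into one sample axis (like .stack()).
    samples = windows.reshape(-1, n_lat, n_lon)
    # Keep only patches without missing values (like .dropna('input_batch')).
    samples = samples[~np.isnan(samples).any(axis=(1, 2))]
    if rng is not None:
        samples = samples[rng.permutation(len(samples))]
    return samples


def iter_batches(samples, batch_size):
    """Yield consecutive full batches; a trailing partial batch is dropped."""
    for start in range(0, len(samples) - batch_size + 1, batch_size):
        yield samples[start:start + batch_size]


field = np.arange(8 * 7, dtype=float).reshape(8, 7)
field[0, 0] = np.nan  # a single land point removes every window that covers it
samples = make_samples(field, 5, 5, rng=np.random.default_rng(0))
batches = list(iter_batches(samples, 2))
print(samples.shape, len(batches))
```

Here the 8 x 7 field yields 4 x 3 = 12 windows, one of which touches the NaN, so 11 samples remain and 5 full batches of 2 are produced.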

We need a subsetting stencil (sub) to compensate for the halo removed by each convolution layer. Without it, the input_var variables would be the wrong size at the concat layer, so we strip the halo from them, and likewise from the targets, which must match the size of the model output.

sub = {'nlon':range(halo_size,nlons-halo_size),
       'nlat':range(halo_size,nlats-halo_size)}
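The sub stencil keeps the central part of each sample that survives the convolution. A NumPy check of that alignment, assuming the notebook's values (nlats = nlons = 5, conv_kernels = [5], halo_size = 2) and using a plain sum over the window as a stand-in for a trained kernel:

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view

nlats, nlons, kernel, halo_size = 5, 5, 5, 2
patch = np.arange(nlats * nlons, dtype=float).reshape(nlats, nlons)

# Output of one unpadded kernel x kernel convolution: a single value for this sample.
conv_out = sliding_window_view(patch, (kernel, kernel)).sum(axis=(-2, -1))

# The same cropping as the `sub` dictionary, written with NumPy slices.
cropped = patch[halo_size:nlats - halo_size, halo_size:nlons - halo_size]

print(conv_out.shape, cropped.shape, cropped)
```

Both have shape (1, 1), and the cropped value is the center point patch[2, 2], so each raw input and target value lines up with the center of the stencil the convolution saw.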

Here, we generate our model and our history callback.

model = get_model(halo_size, ds_training, sc, conv_dims, nfilters, conv_kernels, dense_layers)
history = LossHistory()

And now, we can construct our training loop. Most use cases of the xb.BatchGenerator will take the form of a for-loop with the construct for batch in bgen.

Once we have a batch, we still have some things to do before we can pass the data to the model.

So when we look at the contents of each batch, we see

# a = []
# for batch in bgen:
#     a = batch
#     break
# a

...but our model expects tensors where the different variables are stacked in a new dimension we will call var.
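The xr.merge(...).to_array('var').transpose(..., 'var') chain in the training loop puts the variables in a trailing channel axis, which is the channels-last layout that Keras Conv2D layers use by default. In plain NumPy the equivalent step is np.stack(..., axis=-1); the array names and sizes below are illustrative, not read from the dataset.

```python
import numpy as np

n_samples, nlat, nlon = 4096, 5, 5
rng = np.random.default_rng(0)
ssh = rng.random((n_samples, nlat, nlon))
sst = rng.random((n_samples, nlat, nlon))

# One array per variable -> one array with the variables along the last ("var") axis.
conv_input = np.stack([ssh, sst], axis=-1)
print(conv_input.shape)
```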

Looking at model.fit(), we have two separate inputs because of the distinction between convolved inputs and raw inputs. Therefore, the model expects these inputs to be given as a list of the two. The training target is relatively straightforward. On the next line, we have a couple of parameters we can experiment with. The important thing to note is the batch_size parameter; you may need to check that the sample dimension is compatible with the dimensions that xb.BatchGenerator returned. And finally, we pass our history class as a callback so we can see how the model training is progressing.
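Before calling model.fit, it can help to confirm that the three arrays agree with each other. The helper below is a hypothetical sketch, not part of Keras. It assumes channels-last arrays shaped like the ones built in the loop (convolved input (N, nlat, nlon, n_conv_vars), raw input and target cropped by the halo on each side) and returns how many mini-batch steps one model.fit call (one epoch) takes over the chunk.

```python
import numpy as np


def check_fit_shapes(conv, raw, target, halo_size, batch_size):
    """Raise ValueError if the model inputs and target are inconsistent (hypothetical helper)."""
    n = conv.shape[0]
    if raw.shape[0] != n or target.shape[0] != n:
        raise ValueError("inputs and target must share the sample (first) axis")
    expected = (conv.shape[1] - 2 * halo_size, conv.shape[2] - 2 * halo_size)
    for name, arr in (("raw input", raw), ("target", target)):
        if arr.shape[1:3] != expected:
            raise ValueError(f"{name} spatial shape {arr.shape[1:3]} != {expected}")
    # Mini-batch steps per epoch when model.fit splits n samples into batch_size chunks.
    return int(np.ceil(n / batch_size))


conv = np.zeros((4096, 5, 5, 2))
raw = np.zeros((4096, 1, 1, 3))
target = np.zeros((4096, 1, 1, 2))
print(check_fit_shapes(conv, raw, target, halo_size=2, batch_size=32))  # 128
```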

for batch in bgen:
    
    batch_conv   = [batch[x] for x in sc.conv_var]
    batch_input  = [batch[x][sub] for x in sc.input_var]
    batch_target = [batch[x][sub] for x in sc.target]
    batch_conv   = xr.merge(batch_conv).to_array('var').transpose(...,'var')
    batch_input  = xr.merge(batch_input).to_array('var').transpose(...,'var')
    batch_target = xr.merge(batch_target).to_array('var').transpose(...,'var')

    #clear_output(wait=True)
    model.fit([batch_conv, batch_input],
              batch_target,
              batch_size=32, verbose=0,  # epochs=4,
              callbacks=[history])

And now that we have our model trained, we can save it for future use. Note that once this model is saved, we don’t need to rerun much from above to continue with testing or prediction.

model.save('models/'+ sc.name)
np.savez('models/history_'+sc.name, losses=history.mae, mse=history.mse, accuracy=history.accuracy)
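np.savez stores each keyword argument as a named array and appends .npz to the file name when it is missing, so the history saved above would be read back from 'models/history_' + sc.name + '.npz'. It does not create the models/ directory, which must already exist. A small round trip, with made-up loss values and a temporary directory standing in for models/:

```python
import os
import tempfile

import numpy as np

# Stand-ins for history.mae / history.mse / history.accuracy (made-up values).
mae, mse, accuracy = [0.9, 0.5, 0.3], [1.2, 0.6, 0.2], [0.1, 0.4, 0.7]

with tempfile.TemporaryDirectory() as models_dir:
    path = os.path.join(models_dir, "history_herp")
    np.savez(path, losses=mae, mse=mse, accuracy=accuracy)
    saved_name = os.listdir(models_dir)[0]
    with np.load(path + ".npz") as hist:
        keys = sorted(hist.files)
        losses = hist["losses"]

print(saved_name, keys, losses)
```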

Training Function

#train(ds_training, sc5, conv_dims, conv_kernels)

Testing the Model

ds_test = load_test_data(sc5)
ds_test = just_the_data(ds_test)
ds_test = select_from(ds_test)
ds_test
latlen = len(ds_test['nlat'])
lonlen = len(ds_test['nlon'])
nlon_range = range(nlons,lonlen,nlons - 2*halo_size)
nlat_range = range(nlats,latlen,nlats - 2*halo_size)

batch_test = (
    ds_test
    .rolling({"nlat": nlats, "nlon": nlons})
    .construct({"nlat": "nlat_input", "nlon": "nlon_input"})[{'nlat':nlat_range, 'nlon':nlon_range}]
    .stack({"input_batch": ("nlat", "nlon")}, create_index=False)
    .rename_dims({'nlat_input':'nlat', 'nlon_input':'nlon'})
    .transpose('input_batch',...)
    # .chunk({'input_batch':32, 'nlat':nlats, 'nlon':nlons})
    .dropna('input_batch')
)

Let’s load the trained model from before:

model = tf.keras.models.load_model('models/'+ sc.name, custom_objects={'Grid_MAE':Grid_MAE})
test_conv   = [batch_test[x]      for x in sc.conv_var]
test_input  = [batch_test[x][sub] for x in sc.input_var]
test_target = [batch_test[x][sub] for x in sc.target]
test_conv   = xr.merge(test_conv  ).to_array('var').transpose(...,'var')
test_input  = xr.merge(test_input ).to_array('var').transpose(...,'var')
test_target = xr.merge(test_target).to_array('var').transpose(...,'var')
model.evaluate([test_conv, test_input], test_target)

Testing Function

#test(ds_test, sc5, conv_dims, conv_kernels)

Making Predictions

ds_predict = load_predict_data(sc5)
ds_predict = just_the_data(ds_predict)
ds_predict = select_from(ds_predict)
ds_predict
latlen = len(ds_predict['nlat'])
lonlen = len(ds_predict['nlon'])
nlon_range = range(nlons,lonlen,nlons - 2*halo_size)
nlat_range = range(nlats,latlen,nlats - 2*halo_size)

batch_predict = (
    ds_predict
    .rolling({"nlat": nlats, "nlon": nlons})
    .construct({"nlat": "nlat_input", "nlon": "nlon_input"})[{'nlat':nlat_range, 'nlon':nlon_range}]
    .stack({"input_batch": ("nlat", "nlon")}, create_index=False)
    .rename_dims({'nlat_input':'nlat', 'nlon_input':'nlon'})
    .transpose('input_batch',...)
    # .chunk({'input_batch':32, 'nlat':nlats, 'nlon':nlons})
    .dropna('input_batch')
)
model = tf.keras.models.load_model('models/'+ sc.name, custom_objects={'Grid_MAE':Grid_MAE})
predict_conv  = [batch_predict[x]      for x in sc.conv_var]
predict_input = [batch_predict[x][sub] for x in sc.input_var]
predict_conv  = xr.merge(predict_conv ).to_array('var').transpose(...,'var')
predict_input = xr.merge(predict_input).to_array('var').transpose(...,'var')
predict_target = model.predict([predict_conv, predict_input])

Prediction Function

#predict_target = predict(ds_predict, sc5, conv_dims, conv_kernels)

Prediction Results

Now, let’s take a look at the predicted surface currents and see how the model did. Notice that the predicted data can be retrieved fairly easily from our default setup. We only have to reshape them, with respect to the original dimensions and a halo that will be stripped off. This is because we chose to make the convolution kernel equal to the dimensions of the samples, which means the model will give results at individual points.

However, the convolution kernel can differ from the sample size; we would then need a more complex process to map the model output back onto the grid.

Note also that if any samples containing NaNs were removed, we would have to keep track of how the unstructured model inputs map back to the original grid and insert NaNs in the correct positions.
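One way to do that bookkeeping is to keep a boolean mask of which window positions survived the NaN filter and scatter the predictions back into a NaN-filled grid. The sketch below assumes the samples were not shuffled after filtering (true for batch_predict, which is never shuffled) and that there is one prediction per window position, as with the notebook's 5 x 5 samples and 5 x 5 kernel. unflatten_predictions is a hypothetical helper.

```python
import numpy as np


def unflatten_predictions(pred, valid, fill=np.nan):
    """Place per-sample predictions back onto the 2D grid of window positions.

    pred:  (n_valid, ...) predictions, in row-major order of the valid windows
    valid: boolean (n_rows, n_cols) mask, True where a window had no NaNs
    """
    out = np.full(valid.shape + pred.shape[1:], fill, dtype=float)
    out[valid] = pred  # boolean indexing fills the True positions in row-major (C) order
    return out


valid = np.array([[True, False, True],
                  [True, True, False]])
pred = np.array([[1.0, -1.0], [2.0, -2.0], [3.0, -3.0], [4.0, -4.0]])  # (U, V) per sample
grid = unflatten_predictions(pred, valid)
print(grid[..., 0])
```

In the notebook, pred would correspond to predict_target[:, 0, 0, :], and the mask would have to be computed from the windows before dropna discards them.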

U = ds_predict['U']
V = ds_predict['V']
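# One prediction per sample, ordered like the stacked (nlat, nlon) windows; assumes dropna removed nothing
U_pred = predict_target[:,0,0,0].reshape(len(nlat_range), len(nlon_range))
V_pred = predict_target[:,0,0,1].reshape(len(nlat_range), len(nlon_range))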
plt.figure(figsize=(10, 6))
plt.pcolormesh(U, cmap='RdBu_r')
plt.clim([-100, 100])
plt.colorbar()
plt.figure(figsize=(10, 6))
plt.pcolormesh(U_pred, cmap='RdBu_r')
plt.clim([-100, 100])
plt.colorbar()

We can see that they look very similar, but to get a better idea of what our errors look like, we can subtract them.

plt.figure(figsize=(10, 6))
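# Rolling windows end at their own index and start at index nlats = 5, so the sample
# centers run from index 3 to index -3 in each direction, i.e. U[3:-2, 3:-2]
plt.pcolormesh(U_pred - U[3:-2, 3:-2], cmap='RdBu_r')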
plt.clim([-100, 100])
plt.colorbar()

References
  1. Sinha, A., & Abernathey, R. (2021). Estimating Ocean Surface Currents With Machine Learning. Frontiers in Marine Science, 8. doi:10.3389/fmars.2021.672477