Skip to article frontmatterSkip to article content

Visualization of JEDI analysis with UXarray in the model space

JEDI

In this section, you’ll learn:

  • Using UXarray to compute analysis increments and to visualize them in horizontal and vertical cross sections

Prerequisites

ConceptsImportanceNotes
Atmospheric Data AssimilationHelpful

Time to learn: 10 minutes


Import packages

%%time 

# autoload external python modules if they changed
%load_ext autoreload
%autoreload 2

# add ../funcs to the current path
import sys, os
sys.path.append(os.path.join(os.getcwd(), "..")) 

# import modules
import warnings
import math

import cartopy.crs as ccrs
import geoviews as gv
import geoviews.feature as gf
import holoviews as hv
import hvplot.xarray
from holoviews import opts
import matplotlib as mpl
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

import s3fs

import geopandas as gp
import numpy as np
import uxarray as ux
import xarray as xr
Loading...

Configure visualization tools

# Select the Bokeh backend for HoloViews; the Matplotlib alternative is kept
# below, commented out, for reference.
hv.extension("bokeh")
# hv.extension("matplotlib")

# common border lines
# Coastline and state-border features, both at 50m scale in PlateCarree.
# They are overlaid on the horizontal maps drawn later in the notebook.
coast_lines = gf.coastline(projection=ccrs.PlateCarree(), line_width=1, scale="50m")
state_lines = gf.states(projection=ccrs.PlateCarree(), line_width=1, line_color='gray', scale="50m")
Loading...

Helper functions

The following functions are used for visualizing the data. The horizontal_contour function generates the contour map for a given slice of data.

# Generates a contour plot for a horizontal slice
def horizontal_contour(ux_hslice, title, cmin=None, cmax=None, width=800, height=500, clevs=20, cmap="coolwarm", symmetric_cmap=False):
    """Build a filled-contour plot for one horizontal slice of data.

    Parameters
    ----------
    ux_hslice
        The slice to draw. It must provide ``.min()`` / ``.max()`` (each
        returning an object with ``.item()``) and ``.plot()``; the notebook
        passes UXarray slices.
    title : str
        Caption prefix. The data minimum and maximum are appended to it.
    cmin, cmax : number, optional
        Lower / upper colour limit. When omitted, the floor of the data
        minimum and the ceiling of the data maximum are used.
    width, height : int
        Plot size forwarded to ``.opts``.
    clevs : int
        Number of contour levels, evenly spaced between ``cmin`` and ``cmax``.
    cmap : str or colormap object
        A name is resolved with ``plt.get_cmap``; anything else is passed
        through unchanged.
    symmetric_cmap : bool
        When true, the colour range becomes ``[-m, m]`` where ``m`` is the
        larger magnitude of the two limits.

    Returns
    -------
    The object returned by ``hv.operation.contours(...).opts(...)``.
    """
    # Actual data range; reported in the title and used for default limits.
    amin = ux_hslice.min().item()
    amax = ux_hslice.max().item()

    cmin = math.floor(amin) if cmin is None else cmin
    cmax = math.ceil(amax) if cmax is None else cmax

    if symmetric_cmap:
        # Use the larger magnitude of the two limits so the range is centred
        # on zero whatever the signs of cmin and cmax are.
        cmax = max(abs(cmin), abs(cmax))
        cmin = -cmax

    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    # Keep the caller's title (it used to be discarded) and append the range.
    title = f"{title} min={amin:.1f} max={amax:.1f}"

    # Generate contour plot
    contour_plot = hv.operation.contours(
        ux_hslice.plot(),
        levels=np.linspace(cmin, cmax, num=clevs),  # levels=np.arange(cmin, cmax, 0.5)
        filled=True
    ).opts(
        line_color=None,  # line_width=0.001
        width=width, height=height,
        cmap=cmap, clim=(cmin, cmax),
        colorbar=True, show_legend=False,
        tools=['hover'], title=title
    )

    return contour_plot

Retrieve/load MPAS/JEDI data

The example MPAS/JEDI data are stored on JetStream2. We need to retrieve those data first.
There are two ways to retrieve MPAS data:

    1. Download all example data from JetStream2 to the local machine and then load them locally. This approach allows downloading the data once per machine and reusing it across notebooks.
    2. Stream the JetStream2 S3 objects on demand. In this case, each notebook (including a restarted notebook) will retrieve the required data separately as needed.
# Choose the data_load_method (see the description above):
#   1 = download the example data once and reuse it in multiple notebooks
#   2 = stream the S3 objects on demand
# The value set here is 2.
data_load_method = 2  # 1 or 2

Method 1: Download all example data once and reuse it in multiple notebooks

%%time
local_dir = "/tmp"

# The presence of this file is what decides whether the data is already local.
marker_file = local_dir + "/conus12km/bkg/mpasout.2024-05-06_01.00.00.nc"

if data_load_method != 1 or os.path.exists(marker_file):
    print("Skip..., either data is available in local or data_load_method is NOT 1")
else:
    # Anonymous S3 access against the JetStream2 endpoint.
    jetstream_url = 'https://js2.jetstream-cloud.org:8001/'
    fs = s3fs.S3FileSystem(
        anon=True,
        asynchronous=False,
        client_kwargs={"endpoint_url": jetstream_url},
    )
    conus12_path = 's3://pythia/mpas/conus12km'
    # Recursively copy the whole conus12km tree under local_dir.
    fs.get(conus12_path, local_dir, recursive=True)
    print("Data downloading completed")
Skip..., either data is available in local or data_load_method is NOT 1
CPU times: user 442 μs, sys: 0 ns, total: 442 μs
Wall time: 447 μs
# Set file path
# Local copies of the grid, background (bkg/) and analysis (ana/) files.
# The bkg/ana directories must match the variable names (as they do in the
# streaming cell), otherwise "analysis - background" comes out with the
# wrong sign.
if data_load_method == 1:
    grid_file = local_dir + "/conus12km/conus12km.invariant.nc_L60_GFS"
    bkg_file = local_dir + "/conus12km/bkg/mpasout.2024-05-06_01.00.00.nc"
    ana_file = local_dir + "/conus12km/ana/mpasout.2024-05-06_01.00.00.nc"

Method 2: Stream the JetStream2 S3 objects on demand

%%time
if data_load_method != 2:
    print("Skip..., data_load_method is NOT 2")
else:
    # Anonymous S3 access against the JetStream2 endpoint.
    jetstream_url = 'https://js2.jetstream-cloud.org:8001/'
    fs = s3fs.S3FileSystem(
        anon=True,
        asynchronous=False,
        client_kwargs={"endpoint_url": jetstream_url},
    )
    conus12_path = 's3://pythia/mpas/conus12km'

    # Object keys of the grid, background and analysis files.
    grid_url = f"{conus12_path}/conus12km.invariant.nc_L60_GFS"
    bkg_url = f"{conus12_path}/bkg/mpasout.2024-05-06_01.00.00.nc"
    ana_url = f"{conus12_path}/ana/mpasout.2024-05-06_01.00.00.nc"

    # Open file-like handles; they are handed to ux.open_dataset later on.
    grid_file, ana_file, bkg_file = fs.open(grid_url), fs.open(ana_url), fs.open(bkg_url)
CPU times: user 56.9 ms, sys: 16 ms, total: 72.8 ms
Wall time: 122 ms

Loading the data into UXarray datasets

We use the UXarray data structures for working with the data. This package supports data defined over unstructured grids and provides utilities for modifying and visualizing it. The available functionality is discussed in the UxDataset documentation.

# Pair the grid file with each data file:
# uxds_a is built from ana_file (the analysis), uxds_b from bkg_file (the background).
uxds_a = ux.open_dataset(grid_file, ana_file)
uxds_b = ux.open_dataset(grid_file, bkg_file)

compute the analysis increments from the JEDI data assimilation

JEDI updates the background atmospheric state (uxds_b) with observation innovations and gets a new atmospheric state called analysis (uxds_a).
The difference uxds_a - uxds_b is called the “analysis increment”.

# Name of the variable to difference.
var_name = "theta"
# Analysis minus background for that variable.
uxdiff0 = uxds_a[var_name] - uxds_b[var_name]
# uxvar is the field the plotting cells below read from.
uxvar = uxdiff0

plot temperature analysis increments at different levels

## Utility to create ncl style color bars.
# def make_discrete_cmap(n_colors, base="coolwarm", *, center_white=False):
#     if center_white and n_colors % 2 == 0:
#         raise ValueError("center_white=True requires an odd n_colors")
#     base_cmap  = plt.get_cmap(base)
#     positions  = np.linspace(0, 1, n_colors)
#     colours    = base_cmap(positions)
#     if center_white:
#         colours[n_colors // 2] = (1.0, 1.0, 1.0, 1.0)

#     suffix = "_cw" if center_white else ""
#     return ListedColormap(colours, name=f"{base}_{n_colors}{suffix}")
# # levels = np.arange(-4.0, 4.0 + 0.5, 0.5)
# cmap   = make_discrete_cmap(base="coolwarm", n_colors=16, center_white=False)
from matplotlib.colors import ListedColormap, BoundaryNorm, to_rgba

def make_interval_cmap(
    edges,
    colors=None,
    *,
    base_cmap="viridis",
    bad="none",
    under=None,
    over=None
):
    """Build a discrete colormap with one colour per interval, plus its norm.

    Parameters
    ----------
    edges : sequence of float
        Bin boundaries; must be 1D, strictly increasing, with at least two
        values. ``len(edges) - 1`` bins are created.
    colors : sequence, optional
        One colour per bin. When omitted, colours are sampled from
        ``base_cmap`` at the bin midpoints.
    base_cmap : str
        Name of the colormap to sample when ``colors`` is omitted.
    bad : colour, "none" or None
        Colour installed with ``set_bad``. ``"none"`` (the default) and
        ``None`` leave the colormap's bad colour untouched.
    under, over : colour, optional
        Colours installed with ``set_under`` / ``set_over``.

    Returns
    -------
    (cmap, norm)
        A ``ListedColormap`` and a ``BoundaryNorm`` over ``edges``.

    Raises
    ------
    ValueError
        If ``edges`` is malformed or ``colors`` has the wrong length.
    """
    edges = np.asarray(edges, dtype=float)
    if edges.ndim != 1 or len(edges) < 2:
        raise ValueError("`edges` must be 1D with at least two values.")
    if not np.all(np.diff(edges) > 0):
        raise ValueError("`edges` must be strictly increasing.")

    n_bins = len(edges) - 1

    if colors is None:
        # sample from a base cmap at bin midpoints
        cm = plt.get_cmap(base_cmap)
        mids = 0.5 * (edges[:-1] + edges[1:])
        # normalize mids to 0..1 based on full range of edges
        t = (mids - edges[0]) / (edges[-1] - edges[0])
        cols = cm(t)
    else:
        if len(colors) != n_bins:
            raise ValueError(f"`colors` must have length {n_bins}.")
        cols = [to_rgba(c) for c in colors]

    cmap = ListedColormap(cols, name="interval_cmap")

    # Treat None like the "none" sentinel instead of passing it to set_bad.
    keep_default_bad = bad is None or (isinstance(bad, str) and bad == "none")
    if not keep_default_bad:
        cmap.set_bad(bad)
    if under is not None:
        cmap.set_under(under)
    if over is not None:
        cmap.set_over(over)

    # No `extend` on the norm: BoundaryNorm's extend takes its extra colour
    # slots from the colormap, which has exactly one colour per bin, so
    # extend="both" raised ValueError whenever under/over was given.
    # With the default norm, out-of-range values fall through to the
    # colormap's under/over colours set above.
    norm = BoundaryNorm(edges, cmap.N)
    return cmap, norm

# Bin boundaries: 17 edges from -4 to 4 in steps of 0.5, i.e. 16 bins.
edges = [-4, -3.5, -3, -2.5, -2, -1.5, -1, -0.5, 0, 0.5, 1.0, 1.5, 2, 2.5 ,3, 3.5, 4]
# edges = [0, 5, 10, 15, 20, 25, 30, 35, 40, 45, 50]
# One colour per bin, listed from the lowest bin to the highest (16 entries).
# NOTE(review): "white" is the 9th entry, so it colours the 0..0.5 bin rather
# than a bin straddling zero — confirm that this asymmetry is intended.
colors = [
    "green",
    "blue",
    "#313695",  # deep navy
    "#4575b4",  # medium blue
    "#74add1",  # light blue
    "#abd9e9",  # pale blue
    "#d0f0c0",  # light green
    "#ffffbf",  # pale yellow
    "white",    # pure white (center)
    "#fee090",  # pale yellow (mirror)
    "#f46d43",  # orange-red
    "#d73027",  # red
    "#a50026",  # deep red
    "#800026",  # darker maroon
    "#4d0013",  # deep maroon
    "black"     # extreme end
]


# Build the discrete colormap and its norm; missing values are drawn in
# light gray, and no under/over colours are set.
cmap, norm = make_interval_cmap(edges, colors, bad="lightgray", under=None, over=None)
%%time

nt = 0  # time dimension
# Vertical level indices to plot, one map per level.
plot_levels = [0, 19, 29, 39, 42, 49, 58]  # [0, 29, 42]  # [0, 19, 29, 39, 49, 58]

# Not referenced by any code in this cell.
zero_shift = 0.0

plots = []
for lev in plot_levels:
    tmp = horizontal_contour(
        uxvar.isel(Time=nt, nVertLevels=lev),
        title=f'lev={lev}',
        symmetric_cmap=True,
        cmap=cmap,
        # Colour range matching the -4..4 colormap edges. cmin used to be
        # +4, which only worked because symmetric_cmap recomputed it.
        cmin=-4,
        cmax=4,
        #clevs=20,
    )  # for the whole domain

    # Overlay coastlines and state borders on each map.
    plots.append(tmp * coast_lines * state_lines)

# Show the maps one after another, each with its own toolbar.
for p in plots:
    display(p)

# dat.where((dat > 0.1) | (dat < -0.1)),
Loading...

Zoomed into Colorado using the subset capability

%%time

# Centre of the zoom window and its half-widths, all in degrees.
lon_center = -105.03
lat_center = 39.0
lon_incr = 5  # degree
lat_incr = 3  # degree
lon_bounds = (lon_center - lon_incr, lon_center + lon_incr)
lat_bounds = (lat_center - lat_incr, lat_center + lat_incr)

# Restrict the increment field to the bounding box above.
uxdiff1 = uxdiff0.subset.bounding_box(lon_bounds, lat_bounds)
uxvar = uxdiff1

nt = 0  # time dimension
plot_levels = [0, 29, 42]  # [0, 19, 29, 39, 49, 58]

plots = []
for lev in plot_levels:
    level_slice = uxvar.isel(Time=nt, nVertLevels=lev)
    tmp = horizontal_contour(level_slice, title=f'lev={lev}', width=700, height=500)  # for the subdomain
    # The axis limits are set on the state-border layer, as in the original expression.
    cropped_states = state_lines.opts(xlim=lon_bounds, ylim=lat_bounds)
    plots.append(tmp * coast_lines * cropped_states)

# A shared toolbar (synchronised zoom) would be:
# hv.Layout(plots).cols(1)

# Here each plot gets its own toolbar, so each can be controlled individually.
for p in plots:
    display(p)
Loading...

Random Great Circle Arc (GCA)

# lat=43.3
# step_between_points = 100

# End points of the arc as (lon, lat) pairs; the first tick label rendered
# for this section is 20.0°N / 110.0°W.
start_point = (-110, 20)
end_point = (-70, 50)
var_name = "theta"
# Recompute the increment with Time index 0 selected on both datasets.
uxdiff0 = uxds_a[var_name].isel(Time=0) - uxds_b[var_name].isel(Time=0)
uxvar = uxdiff0
# Sample the field at 100 steps between the two end points.
cross_section_gca = uxvar.cross_section(start=start_point, end=end_point, steps=100)
# One two-line label per sample: latitude with N/S, then longitude with E/W.
hlabelticks = [
    f"{abs(lat):.1f}°{'N' if lat >= 0 else 'S'}\n{abs(lon):.1f}°{'E' if lon >= 0 else 'W'}"
    for lat, lon in zip(cross_section_gca['lat'], cross_section_gca['lon'])
]
# cross_section_gca.isel(Time=0).transpose().plot.contourf()
%matplotlib inline


# One 8x3 inch figure holding a single axes.
fig= plt.figure(figsize=(8,3))
gs= fig.add_gridspec(1,1)
ax = fig.add_subplot(gs[0,0])
# Filled contours of the transposed cross section; the x ticks below are
# taken from its 'steps' coordinate.
cf=ax.contourf(cross_section_gca.transpose(),cmap='Reds',extend='both')
# Label every tick_stride-th step with its lat/lon label.
tick_stride = 10
ax.set_xticks(cross_section_gca['steps'][::tick_stride])
ax.set_xticklabels(hlabelticks[::tick_stride])
[Text(0, 0, '20.0°N\n110.0°W'), Text(10, 0, '23.5°N\n107.1°W'), Text(20, 0, '27.0°N\n104.1°W'), Text(30, 0, '30.3°N\n100.9°W'), Text(40, 0, '33.6°N\n97.4°W'), Text(50, 0, '36.8°N\n93.7°W'), Text(60, 0, '39.9°N\n89.7°W'), Text(70, 0, '42.8°N\n85.3°W'), Text(80, 0, '45.5°N\n80.5°W'), Text(90, 0, '48.0°N\n75.2°W')]
<Figure size 800x300 with 1 Axes>
# Cross section along a constant longitude, sampled at 100 steps.
lon=-83.3
cross_section_lon = uxvar.cross_section(lon=lon, steps=100)

# One latitude label (with N/S suffix) per sample.
hlabelticks = [
    f"{abs(lat):.1f}°{'N' if lat >= 0 else 'S'}" for lat in cross_section_lon['lat']
]

%matplotlib inline
# One 8x3 inch figure holding a single axes.
fig= plt.figure(figsize=(8,3))
gs= fig.add_gridspec(1,1)
ax = fig.add_subplot(gs[0,0])
# Filled contours of the transposed constant-longitude section.
cf=ax.contourf(cross_section_lon.transpose(),cmap='Reds',extend='both')

# tick_stride is not set in this cell; it is reused from the earlier
# great-circle figure cell.
ax.set_xticks(cross_section_lon['steps'][::tick_stride])
ax.set_xticklabels(hlabelticks[::tick_stride])
[Text(0, 0, '90.0°S'), Text(10, 0, '71.8°S'), Text(20, 0, '53.6°S'), Text(30, 0, '35.5°S'), Text(40, 0, '17.3°S'), Text(50, 0, '0.9°N'), Text(60, 0, '19.1°N'), Text(70, 0, '37.3°N'), Text(80, 0, '55.5°N'), Text(90, 0, '73.6°N')]
<Figure size 800x300 with 1 Axes>
# Rebuild the constant-longitude section at lon=-100 with 100 steps.
cross_section_lon = uxvar.cross_section(lon=-100., steps=100)
# A bare name on the last line makes the notebook display the object.
cross_section_lon
Loading...

save plots to files

# hv.save(tmp, 'vslice.png')