Spectral Clustering
Overview
This notebook demonstrates a simplified machine learning approach for observing the change in a lake's water extent across time. To identify the water, we use spectral clustering to classify each grid cell into a category based on the similarity of the combined set of pixels across the wavelength-bands in our image stacks.
Our example uses the spectral clustering implementation from dask_ml, a scalable equivalent of what is available in scikit-learn. We will begin with a single image stack and then conduct a direct comparison of the results from different time points.
This workflow uses data from Microsoft Planetary Computer, but it can be adapted to work with any data ingestion approach from this cookbook.
Prerequisites
| Concepts | Importance | Notes |
| --- | --- | --- |
| Data Ingestion - Planetary Computer | Necessary | |
| scikit-learn | Helpful | Spectral clustering |
| dask_ml | Helpful | Spectral clustering at scale |
Time to learn: 20 minutes.
Imports
# Data
import numpy as np
import odc.stac
import pandas as pd
import planetary_computer
import pystac_client
import xarray as xr
from dask.distributed import Client
from pystac.extensions.eo import EOExtension as eo
# Analysis
from dask_ml.cluster import SpectralClustering
# Viz
import hvplot.xarray
Loading Data
Let’s start by loading some Landsat data. These steps are covered in the Data Ingestion - Planetary Computer prerequisite.
Search the catalog
catalog = pystac_client.Client.open(
    "https://planetarycomputer.microsoft.com/api/stac/v1",
    modifier=planetary_computer.sign_inplace,
)
bbox = [-118.89, 38.54, -118.57, 38.84] # Region over a lake in Nevada, USA
datetime = "2017-06-01/2017-09-30" # Summer months of 2017
collection = "landsat-c2-l2"
platform = "landsat-8"
cloudy_less_than = 1 # percent
search = catalog.search(
    collections=[collection],
    bbox=bbox,
    datetime=datetime,
    query={"eo:cloud_cover": {"lt": cloudy_less_than}, "platform": {"in": [platform]}},
)

items = search.item_collection()  # get_all_items() is deprecated in pystac_client
print(f"Returned {len(items)} Items:")
[[i, item.id] for i, item in enumerate(items)]
Returned 3 Items:
[[0, 'LC08_L2SP_042033_20170718_02_T1'],
[1, 'LC08_L2SP_042033_20170702_02_T1'],
[2, 'LC08_L2SP_042033_20170616_02_T1']]
Load a dataset
item = items[1]  # select one of the results

assets = []
for _, asset in item.assets.items():
    try:
        assets.append(asset.extra_fields["eo:bands"][0])
    except (KeyError, IndexError):
        pass  # skip assets without wavelength-band metadata
cols_ordered = [
"common_name",
"description",
"name",
"center_wavelength",
"full_width_half_max",
]
bands = pd.DataFrame.from_dict(assets)[cols_ordered]
bands
| | common_name | description | name | center_wavelength | full_width_half_max |
| --- | --- | --- | --- | --- | --- |
| 0 | red | Visible red | OLI_B4 | 0.65 | 0.04 |
| 1 | blue | Visible blue | OLI_B2 | 0.48 | 0.06 |
| 2 | green | Visible green | OLI_B3 | 0.56 | 0.06 |
| 3 | nir08 | Near infrared | OLI_B5 | 0.87 | 0.03 |
| 4 | lwir11 | Long-wave infrared | TIRS_B10 | 10.90 | 0.59 |
| 5 | swir16 | Short-wave infrared | OLI_B6 | 1.61 | 0.09 |
| 6 | swir22 | Short-wave infrared | OLI_B7 | 2.20 | 0.19 |
| 7 | coastal | Coastal/Aerosol | OLI_B1 | 0.44 | 0.02 |
ds_2017 = odc.stac.stac_load(
    [item],
    bands=bands.common_name.values,
    bbox=bbox,
    chunks={},  # <-- use Dask
).isel(time=0)
Retain CRS Attribute
epsg = item.properties["proj:epsg"]
ds_2017.attrs["crs"] = f"epsg:{epsg}"
da_2017 = ds_2017.to_array(dim="band")
da_2017
<xarray.DataArray (band: 8, y: 1128, x: 950)>
dask.array<stack, shape=(8, 1128, 950), dtype=uint16, chunksize=(1, 1128, 950), chunktype=numpy.ndarray>
Coordinates:
  * y            (y) float64 4.301e+06 4.301e+06 ... 4.267e+06 4.267e+06
  * x            (x) float64 3.353e+05 3.353e+05 ... 3.637e+05 3.638e+05
    spatial_ref  int32 32611
    time         datetime64[ns] 2017-07-02T18:33:06.200763
  * band         (band) object 'red' 'blue' 'green' ... 'swir22' 'coastal'
Attributes:
    crs:      epsg:32611
Reshaping Data
The shape of our data is currently n_bands, n_y, n_x. In order for dask-ml / scikit-learn to consume our data, we'll need to reshape our image stacks into n_samples, n_features, where n_features is the number of wavelength-bands and n_samples is the total number of pixels in each wavelength-band image. Essentially, we'll be creating a vector of pixels out of each image, where each pixel has multiple features (bands), but the ordering of the pixels is no longer relevant to the computation.
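To see what this reshape does in plain NumPy terms, here is a minimal sketch using a synthetic array in place of our Landsat stack (the names arr and flattened are illustrative only):
arr = np.random.rand(8, 1128, 950)  # synthetic stand-in with shape (n_bands, n_y, n_x)
flattened = arr.reshape(arr.shape[0], -1).T  # flatten each band image, then make rows pixels
flattened.shape  # (1071600, 8), i.e. (n_samples, n_features)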
By using xarray methods to flatten the data, we can keep track of the coordinate labels ‘x’ and ‘y’ along the way. This means that we have the ability to reshape back to our original array at any time with no information loss!
flattened_xda = da_2017.stack(z=("x", "y")) # flatten each band
flattened_t_xda = flattened_xda.transpose("z", "band")
flattened_t_xda
<xarray.DataArray (z: 1071600, band: 8)>
dask.array<transpose, shape=(1071600, 8), dtype=uint16, chunksize=(1071600, 1), chunktype=numpy.ndarray>
Coordinates:
    spatial_ref  int32 32611
    time         datetime64[ns] 2017-07-02T18:33:06.200763
  * band         (band) object 'red' 'blue' 'green' ... 'swir22' 'coastal'
  * z            (z) object MultiIndex
  * x            (z) float64 3.353e+05 3.353e+05 ... 3.638e+05 3.638e+05
  * y            (z) float64 4.301e+06 4.301e+06 ... 4.267e+06 4.267e+06
Attributes:
    crs:      epsg:32611
Standardize Data
Now that we have the data in the correct shape, let's standardize (or rescale) the values of the data. We do this to get all the flattened image vectors onto a common scale while preserving the differences in their ranges of values. Again, we'll demonstrate doing this first in NumPy and then in xarray.
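As a rough NumPy sketch (using the synthetic flattened array from the reshaping sketch above), the standardization is just:
rescaled = (flattened - flattened.mean()) / flattened.std()  # zero mean, unit variance
The xarray equivalent below does the same thing, while also letting us preserve the array's attributes: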
with xr.set_options(keep_attrs=True):
    rescaled_xda = (flattened_t_xda - flattened_t_xda.mean()) / flattened_t_xda.std()
rescaled_xda
<xarray.DataArray (z: 1071600, band: 8)>
dask.array<truediv, shape=(1071600, 8), dtype=float64, chunksize=(1071600, 1), chunktype=numpy.ndarray>
Coordinates:
    spatial_ref  int32 32611
    time         datetime64[ns] 2017-07-02T18:33:06.200763
  * band         (band) object 'red' 'blue' 'green' ... 'swir22' 'coastal'
  * z            (z) object MultiIndex
  * x            (z) float64 3.353e+05 3.353e+05 ... 3.638e+05 3.638e+05
  * y            (z) float64 4.301e+06 4.301e+06 ... 4.267e+06 4.267e+06
Attributes:
    crs:      epsg:32611
Info
Above, we use the context manager with xr.set_options(keep_attrs=True): to retain the array's attributes through the operations. That is, we want metadata like 'crs' to stay with our result so we can use geo=True in our plotting.
As rescaled_xda is still a lazy Dask object, if we wanted to actually run the rescaling at this point (provided that all the data fits into memory), we would use rescaled_xda.compute().
ML pipeline
Now that our data is in the proper shape and value range, we are ready to conduct spectral clustering. We will use the spectral clustering implementation from dask_ml, a scalable equivalent of the one in scikit-learn, which clusters pixels based on their similarity across all wavelength-bands (making it spectral clustering by spectra!).
client = Client(processes=False)
client
Client: Client-6060ada8-f06c-11ee-8a9f-6045bdc85b55 (connection method: Cluster object)
Cluster: LocalCluster (1 worker, 4 threads, 15.61 GiB memory, processes=False)
Dashboard: http://10.1.0.22:8787/status
Now we will compute and persist the rescaled data to feed into the ML pipeline. Notice that our X matrix below has the shape n_samples, n_features, as discussed earlier.
X = client.persist(rescaled_xda)
X.shape
(1071600, 8)
First, we will set up the model with the number of clusters and other options.
clf = SpectralClustering(
    n_clusters=4,  # number of clusters to partition the pixels into
    random_state=0,  # seed for reproducible results
    gamma=None,
    kmeans_params={"init_max_iter": 5},
    persist_embedding=True,
)
This next step is the slow part. We'll fit the model to our matrix X. It could take anywhere from seconds to minutes to run, depending on your setup and the size of the data.
%time clf.fit(X)
CPU times: user 22 s, sys: 13.2 s, total: 35.2 s
Wall time: 29.2 s
SpectralClustering(gamma=None, kmeans_params={'init_max_iter': 5}, n_clusters=4, persist_embedding=True, random_state=0)
Let’s check the shape of the result:
labels = clf.assign_labels_.labels_.compute()
labels.shape
(1071600,)
labels
array([0, 2, 2, ..., 2, 2, 2], dtype=int32)
The result is a single vector of cluster labels.
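As a quick sanity check (a sketch, not part of the original pipeline), we can count how many pixels landed in each cluster:
counts = np.bincount(labels)  # number of pixels assigned to each of the 4 cluster labels
counts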
Un-flattening
Once the computation is done, we can use the coordinates of our input array to restack our output array back into an image. Again, one of the main benefits of using xarray for this stacking and unstacking is that it keeps track of the coordinate information for us.
Since the original array is n_samples by n_features (1071600, 8) and the cluster label output is (1071600,), we just need the coordinates from one of the original features in the shape of n_samples. We can simply copy the coordinates from the first input feature and populate it with our output data:
template = flattened_t_xda[:, 0]
output_array = template.copy(data=labels)
output_array
<xarray.DataArray (z: 1071600)>
array([0, 2, 2, ..., 2, 2, 2], dtype=int32)
Coordinates:
    spatial_ref  int32 32611
    time         datetime64[ns] 2017-07-02T18:33:06.200763
    band         <U3 'red'
  * z            (z) object MultiIndex
  * x            (z) float64 3.353e+05 3.353e+05 ... 3.638e+05 3.638e+05
  * y            (z) float64 4.301e+06 4.301e+06 ... 4.267e+06 4.267e+06
Attributes:
    crs:      epsg:32611
With this new output array with coordinates copied from the input array, we can unstack back to the original x and y image dimensions by just using .unstack().
unstacked_2017 = output_array.unstack()
unstacked_2017
<xarray.DataArray (x: 950, y: 1128)>
array([[0, 2, 2, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [2, 0, 2, ..., 0, 0, 0],
       ...,
       [2, 2, 2, ..., 2, 2, 2],
       [2, 2, 2, ..., 2, 2, 2],
       [2, 2, 0, ..., 2, 2, 2]], dtype=int32)
Coordinates:
  * x            (x) float64 3.353e+05 3.353e+05 ... 3.637e+05 3.638e+05
  * y            (y) float64 4.301e+06 4.301e+06 ... 4.267e+06 4.267e+06
    spatial_ref  int32 32611
    time         datetime64[ns] 2017-07-02T18:33:06.200763
    band         <U3 'red'
Attributes:
    crs:      epsg:32611
Finally, we can visualize the results! By hovering over the resulting image, we can see that the lake water has been clustered under a single label or 'value'.
raw_plot_2017 = da_2017.sel(band="red").hvplot.image(
    x="x", y="y", geo=True, xlabel="lon", ylabel="lat",
    datashade=True, cmap="greys", title="Raw Image 2017",
)

result_plot_2017 = unstacked_2017.hvplot(
    x="x", y="y", cmap="Set3", geo=True, xlabel="lon", ylabel="lat",
    colorbar=False, title="Spectral Clustering 2017",
)

raw_plot_2017 + result_plot_2017
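Since Landsat pixels are 30 m on a side, once we identify which cluster corresponds to the lake (for example, by hovering over the clustered plot), we can estimate the water's surface area. A minimal sketch, where water_label is a hypothetical value you would read off your own plot:
water_label = 3  # hypothetical: replace with the cluster value of the lake in your run
n_water = int((unstacked_2017 == water_label).sum())  # count matching pixels
area_km2 = n_water * 30 * 30 / 1e6  # 30 m x 30 m pixels, converted to km^2
print(f"Estimated lake area: {area_km2:.1f} km^2")
Repeating this for an image stack from a different year would give a direct, quantitative comparison of the lake's extent over time.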