{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "44dee133" }, "source": [ "# OpenHydroNet Finetuning Tutorial Notebook: Exploring Base Models and Targeted Fine-Tuning\n", "\n", "Welcome to this interactive tutorial designed to help you understand the concepts of base hydrological models and the powerful technique of fine-tuning, using the OpenHydroNet framework. This notebook is structured as a student exercise, aiming to provide a hands-on experience in evaluating and improving predictive models.\n", "\n", "## What You Will Learn:\n", "\n", "By working through this notebook, you will gain insights into:\n", "\n", "1. **Understanding Base Models:** How a general-purpose hydrological model, trained on a diverse set of basins, serves as a foundation for further specialization.\n", "2. **The Concept of Fine-Tuning:** The process of adapting a pre-trained model to improve its performance on a specific target or region that might be 'out-of-distribution' compared to the base model's training data.\n", "3. **Running Models:** Practical experience in generating and understanding the command-line arguments for training (`train`), fine-tuning (`finetune`), and performing inference (`infer`) with the framework. You will learn how to prepare config files and execute these operations, typically in a terminal environment.\n", "4. **Model Performance Analysis:** How to quantitatively evaluate model performance using various metrics (e.g., KGE, NSE) and qualitatively assess predictions through hydrograph comparisons.\n", "5. 
**Impact of Static Attributes:** The critical role of static basin attributes (like basin area) in model performance and how targeted fine-tuning of specific model layers (e.g., `static_attributes_fc`) can address discrepancies when a target basin's characteristics differ significantly from the base model's training distribution.\n", "\n", "For more detailed insights into fine-tuning for river modeling, refer to the paper: Ryd, Emil, and Grey Nearing. \"Fine Flood Forecasts: Incorporating local data into global models through fine-tuning.\" arXiv preprint arXiv:2504.12559 (2025). [https://arxiv.org/abs/2504.12559](https://arxiv.org/abs/2504.12559)\n", "\n", "## Notebook Workflow:\n", "\n", "This notebook follows a structured approach to guide you through the comparison of a **base model** and a corresponding **fine-tuned model**. The key steps include:\n", "\n", "1. **Configuration:** Setting up local paths for model runs, shapefiles, and attribute data.\n", "2. **Base Model Exploration:** Selecting a base model, visualizing its training and testing basins, generating its run commands, calculating performance metrics, and visualizing score distributions.\n", "3. **Fine-Tuning Process:** Choosing a specific target basin for fine-tuning, understanding how to configure the fine-tuning run, and generating the necessary fine-tuning and inference commands.\n", "4. **Fine-Tuned Model Analysis:** Selecting the fine-tuned model and loading its performance metrics.\n", "5. **Comparative Analysis:** Directly comparing the base and fine-tuned models' performance on the target basin using both metric-versus-lead-time plots and detailed hydrograph comparisons.\n", "6. 
**Basin Context:** Examining the basin area distribution of the training data relative to the fine-tuning target basin to understand the rationale behind targeted fine-tuning strategies, particularly concerning the `static_attributes_fc` layer.\n", "\n", "By the end of this tutorial, you should have a solid understanding of how to leverage pre-trained hydrological models and adapt them for specific regional challenges, ultimately contributing to more accurate flood forecasting.\n", "\n", "This notebook is an educational exercise, not a performance benchmark. To keep run times short in a standard environment (like Google Colab), the experiment is restricted to a \"toy\" dataset of only 5 training basins. State-of-the-art global models typically need data from hundreds or thousands of basins to learn general hydrologic behaviors and relationships, so this 5-basin model will **not** yield state-of-the-art results: it has not seen enough variety to represent basins in other climates or terrains. Expect performance metrics (NSE/KGE) on the 3 \"ungauged\" basins (basins not seen during training) to be significantly lower than on the training basins. This is expected behavior for a model trained on so few basins."
] }, { "cell_type": "markdown", "metadata": { "id": "rZFZoXFEzKic" }, "source": [ "## Imports" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "MdSkjG7XyFEB" }, "outputs": [], "source": [ "# # --- Install Dependencies ---\n", "# # These lines ensure that all required third-party libraries are installed.\n", "# # These are not needed if you use the supplied conda\n", "# # environment in `~/flood-forecasting/environments/conda.yml`\n", "\n", "# # Data visualization and interactive widgets\n", "# !pip install -q matplotlib seaborn ipywidgets\n", "\n", "# # Scientific computing and data structures\n", "# !pip install -q numpy pandas xarray netCDF4\n", "\n", "# # Google Cloud Storage (for MultiMet data access)\n", "# !pip install -q google-cloud-storage gcsfs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "oQAXWhpgzKid" }, "outputs": [], "source": [ "# Standard Library Imports\n", "import os\n", "import glob\n", "import yaml\n", "from typing import Dict, List, Optional, Set\n", "\n", "# Scientific Computing & Data Analysis\n", "import numpy as np\n", "import pandas as pd\n", "import xarray as xr\n", "\n", "# Data Visualization\n", "import matplotlib.pyplot as plt\n", "import seaborn as sns\n", "\n", "# Interactive Widgets for Notebook Environments\n", "import ipywidgets as widgets\n", "from ipywidgets import HBox, VBox, interactive\n", "\n", "# Local Tutorial Module\n", "import backend" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sXNcqhj5yWaj" }, "outputs": [], "source": [ "%matplotlib inline" ] }, { "cell_type": "markdown", "metadata": { "id": "Npt3Yb4cZj5U" }, "source": [ "## User-Defined Local Paths\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "mFzHRcwHStZJ" }, "outputs": [], "source": [ "# Define the path to the shapefile containing basin geometries.\n", "# This shapefile is used for plotting maps of basin locations.\n", "# Users may need to update this 
path if their data is stored elsewhere.\n", "SHAPEFILE_PATH = '~/flood-forecasting/tutorial/Caravan-nc/shapefiles/camels/camels_basin_shapes.shx'\n", "\n", "# Path to the attributes file that contains additional basin information, such as basin area.\n", "# This file is used for visualizing the distribution of basin areas.\n", "ATTRIBUTES_FILE_PATH = '~/flood-forecasting/tutorial/Caravan-nc/attributes/camels/attributes_other_camels.csv'\n", "\n", "# Path to a base directory containing one or more model run directories.\n", "# This directory should house all your trained models.\n", "# The interactive selection widgets in subsequent cells will scan this directory to allow you to choose which model run to evaluate.\n", "MODEL_RUN_DIR = '/home/gsnearing/flood-forecasting/tutorial/model-runs'\n", "\n", "# The directory where .yml model configuration files are located.\n", "# These files define parameters for model runs, including data paths and training settings.\n", "CONFIG_DIR = '/home/gsnearing/flood-forecasting/tutorial/configs'\n", "\n", "# --- OPTIONAL: Full Training Pipeline Trigger ---\n", "# By default, this tutorial uses pre-trained model outputs to save time.\n", "# Set this flag to 'True' ONLY if you want to perform a fresh training run\n", "# and generate new predictions from scratch.\n", "RUN_NEW_MODELS = False\n", "\n", "# A boolean flag to control whether performance statistics (metrics) are recalculated or loaded from pre-existing files.\n", "# Set this to 'True' for the initial run or if model outputs have changed and you need fresh calculations.\n", "# Set to 'False' if metrics have already been computed and saved, to save time during subsequent runs.\n", "# This is particularly useful for live demos or quick comparisons where recalculation is not necessary.\n", "CALCULATE_STATISTICS = False" ] }, { "cell_type": "markdown", "metadata": { "id": "3ZcLWaP3bZ1l" }, "source": [ "## Interactive Functions" ] }, { "cell_type": "code", "execution_count": 
null, "metadata": { "id": "Z2HqxIvobYS7" }, "outputs": [], "source": [ "def select_model(model_selection):\n", " \"\"\"\n", " Sets the global variables for the selected base model run directory and name.\n", "\n", " This function is designed to be used with an `ipywidgets.Dropdown` to allow\n", " interactive selection of a model. It updates two global variables:\n", " - `model_run_dir`: The absolute path to the directory of the selected model run.\n", " - `model_name`: A user-friendly display name for the selected model.\n", "\n", " Args:\n", " model_selection (str): The key (display name) of the selected model run\n", " from the available options in the dropdown.\n", " \"\"\"\n", " global model_run_dir, model_name\n", "\n", " # Get the absolute path for the selected run directory from the pre-populated dictionary\n", " model_run_dir = available_run_dirs.get(model_selection)\n", " # Create a display name for the selected model, including its selection identifier\n", " model_name = f'Model ({model_selection})'\n", "\n", " print(\"\\nConfiguration set:\")\n", " print(f\" Model: {model_name} ({model_run_dir})\")" ] }, { "cell_type": "markdown", "metadata": { "id": "e88885de" }, "source": [ "## Base Model" ] }, { "cell_type": "markdown", "metadata": { "id": "HMu8Mya-wjud" }, "source": [ "### Training & Inference Pipelines" ] }, { "cell_type": "markdown", "metadata": { "id": "IqmtDMY01xcK" }, "source": [ "#### Run Commands" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "52nueCXgwjJG" }, "outputs": [], "source": [ "# --- Base Model: Generate Run Commands ---\n", "\n", "# This cell generates the commands needed to train and run inference for the base model.\n", "# These commands are typically executed outside of this notebook in a terminal or a separate process\n", "# that has access to the `flood-forecasting/googlehydrology` project and its\n", "# required Python environment (e.g., a Conda environment).\n", "\n", "# Get the experiment name. 
We are assuming '5-basin-example' as the base experiment\n", "# because the base_config may not be loaded yet at this stage.\n", "training_config_file = 'train-config.yml'\n", "base_config_file = f\"{CONFIG_DIR}/{training_config_file}\"\n", "\n", "# Construct the base model training command.\n", "# This command initiates the training process for the base model, using its specific configuration file.\n", "base_train_command = f\"run train --config-file={base_config_file}\"\n", "\n", "# Construct the base model inference command.\n", "# After a base model has been trained, this command is used to run predictions\n", "# (inference) using the trained model. It specifies the `--run-dir`\n", "# which points to the output directory of the base model training run.\n", "# NOTE: At this stage, the actual run directory might not exist yet if you are training a new model.\n", "# '<RUN_DIR>' is a placeholder: replace it with the name of the run directory created by STEP 1.\n", "base_infer_command = f\"run infer --run-dir={MODEL_RUN_DIR}/<RUN_DIR>\"\n", "\n", "# Print the run instructions to the screen\n", "print(\"--------------------------------------------------------------------------------------------------\")\n", "print(\" BASE MODEL TRAINING AND INFERENCE INSTRUCTIONS \")\n", "print(\"--------------------------------------------------------------------------------------------------\")\n", "print(\"These commands need to be executed in your terminal or a separate process.\")\n", "print(\"Ensure you are in an environment with the necessary Python libraries (e.g., the Conda environment).\")\n", "print(\"\")\n", "print(\"IMPORTANT: Before running these commands, please examine the configuration file:\")\n", "print(f\" Config File Location: {base_config_file}\")\n", "print(\" This file defines crucial parameters for your model run, including data paths, model architecture, and training settings.\")\n", "print(\" Understanding its contents will help you customize and debug your experiments.\")\n", "print(\"\")\n", "print(\"STEP 1: Run the base model training command FIRST (if you need to train a new base model):\")\n", "print(f\" {base_train_command}\")\n", "print(\"\")\n", "print(\"STEP 2: After the base model training is complete, run the inference command:\")\n", "print(\" (IMPORTANT: Replace '<RUN_DIR>' with the name of the run directory created by STEP 1, e.g., '5-basin-example_YYYYMMDD_HHMMSS')\")\n", "print(f\" {base_infer_command}\")\n", "print(\"\")\n", "print(\"--------------------------------------------------------------------------------------------------\")" ] }, { "cell_type": "markdown", "metadata": { "id": "3xI3P_kh1xcM" }, "source": [ "#### Run Training & Inference" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "rCPCJ1621xcM" }, "outputs": [], "source": [ "if RUN_NEW_MODELS:\n", "    # 1. Execute Base Model Training:\n", "    # This calls the 'run train' command with your specified YAML config.\n", "    # WARNING: Full training can take a significant amount of time (minutes to hours)\n", "    # depending on your hardware (CPU/GPU) and the number of basins.\n", "    print(f\"Executing training: {base_train_command}\")\n", "    os.system(base_train_command)\n", "\n", "    # Load the base config file to get the actual experiment name\n", "    with open(base_config_file, 'r') as f:\n", "        base_config_data = yaml.safe_load(f)\n", "    actual_experiment_name = base_config_data.get('experiment_name', 'default_experiment')\n", "    print(f\"Experiment name from config file: {actual_experiment_name}\")\n", "\n", "    # Get all potential run directories for the base experiment\n", "    search_pattern = os.path.join(MODEL_RUN_DIR, f'{actual_experiment_name}_*_*/')\n", "    all_run_dirs = glob.glob(search_pattern)\n", "\n", "    if not all_run_dirs:\n", "        print(f\"Error: No run directories found matching pattern {search_pattern}\")\n", "    else:\n", "        # Sort the directory names lexicographically. This assumes the default\n", "        # `googlehydrology` naming `experiment_YYYYMMDD_HHMMSS`, for which the\n", "        # last entry after sorting is the most recent run.\n", "        all_run_dirs.sort()\n", "        base_model_actual_run_dir = all_run_dirs[-1].rstrip('/') # Get the last (most recent) and remove trailing slash\n", "        print(f\"Found most recent base model run directory: {base_model_actual_run_dir}\")\n", "\n", "        # 2. Execute Inference:\n", "        # After training completes, this command generates the actual streamflow\n", "        # predictions (.nc or .zarr files) used for the hydrographs later in this notebook.\n", "        infer_command = f\"run infer --run-dir={base_model_actual_run_dir}\"\n", "        print(f\"Executing inference: {infer_command}\")\n", "        os.system(infer_command)\n", "else:\n", "    print(\"Skipping re-training. You will select the base model in the next step.\")" ] }, { "cell_type": "markdown", "metadata": { "id": "OGNH-O-NaA7o" }, "source": [ "### Choose Base Model\n", "\n", "This section allows you to select a **base model** from your local `MODEL_RUN_DIR`. The notebook automatically scans this directory (including subdirectories) to find all available model run folders. Each folder represents a unique experiment or training run.\n", "\n", "An interactive dropdown widget is then populated with the names of these detected runs. 
Selecting an option from this dropdown will set the `model_run_dir` and `model_name` global variables, which are used throughout the rest of the notebook for loading data and calculating metrics for the chosen base model.\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ykzxoSmSzKif" }, "outputs": [], "source": [ "# Find available run directories using the path set in the cell above\n", "available_run_dirs = backend.find_model_run_dirs(MODEL_RUN_DIR)\n", "\n", "# Extract the name of the most recent model run for default selection\n", "if 'base_model_actual_run_dir' in globals():\n", " latest_run_name = os.path.basename(base_model_actual_run_dir)\n", "else:\n", " latest_run_name = None\n", "\n", "# Create a dropdown widget for selecting the base model\n", "run_dir_options = sorted(available_run_dirs.keys())\n", "base_dropdown = widgets.Dropdown(\n", " options=run_dir_options,\n", " description='Select Base Model:',\n", " disabled=False,\n", " style = {'description_width': 'initial'},\n", " value=latest_run_name if latest_run_name in run_dir_options else run_dir_options[0]\n", ")\n", "\n", "# Create an interactive widget linking the dropdown to the selection function\n", "interactive_selection = interactive(\n", " select_model,\n", " model_selection=base_dropdown,\n", ")\n", "\n", "# Display the interactive widget\n", "display(interactive_selection)" ] }, { "cell_type": "markdown", "metadata": { "id": "OA6sGtYclfT3" }, "source": [ "### Visualize Train and Test Basins" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wDImbSOQFRer" }, "outputs": [], "source": [ "# Load train and test basin lists\n", "# `backend.load_model_config_and_basins` reads the configuration file of the selected base model\n", "# (from the `model_run_dir`) to extract which basins were used for training and testing.\n", "# It returns the configuration dictionary, and sets of training and testing basin IDs.\n", "base_config, train_basin_ids, 
test_basin_ids = backend.load_model_config_and_basins(model_run_dir)\n", "\n", "# Plot train and test basins\n", "# `backend.plot_train_test_shapefile` visualizes the geographical locations of the\n", "# training and testing basins on a map. It uses the `SHAPEFILE_PATH` to draw the basin\n", "# boundaries, highlights `train_basin_ids` and `test_basin_ids` with different colors,\n", "# and uses `model_name` for the plot title.\n", "backend.plot_train_test_shapefile(\n", " shapefile_path=SHAPEFILE_PATH,\n", " train_basin_ids=train_basin_ids,\n", " test_basin_ids=test_basin_ids,\n", " model_name=model_name\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "bf7d54f2" }, "source": [ "### Calculate Metrics" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "XnCJ5QHIJSul" }, "outputs": [], "source": [ "# Load simulation data and calculate/load metrics for the base model.\n", "# `backend.load_data_and_metrics` handles reading the raw model outputs\n", "# and, if specified, computes performance metrics across different lead times.\n", "# It requires the `model_run_dir` to locate the data and `test_basin_ids` to filter for relevant basins.\n", "base_model_data, base_model_metrics = backend.load_data_and_metrics(\n", " model_run_dir=model_run_dir,\n", " test_basin_ids=test_basin_ids,\n", " # `calculate_statistics` controls whether metrics are recalculated or loaded\n", " # from a previously saved file. `True` forces recalculation, which is useful after\n", " # a new model run or if saved metrics are outdated. `False` loads existing metrics\n", " # if available, saving time. Here it follows the global CALCULATE_STATISTICS flag.\n", " calculate_statistics=CALCULATE_STATISTICS,\n", " model_name=model_name\n", ")\n", "\n", "# `base_model_data` will contain the raw simulation outputs (e.g., streamflow_sim, streamflow_obs)\n", "# typically as an xarray Dataset, allowing for detailed hydrograph analysis.\n", "# `base_model_metrics` will be a pandas DataFrame containing calculated performance scores\n", "# (e.g., KGE, NSE) for each basin and lead time, used for quantitative comparisons.\n", "display(base_model_metrics.head())" ] }, { "cell_type": "markdown", "metadata": { "id": "_uAz54aQi-ri" }, "source": [ "### Visualize Score Distributions" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "FAgVHeN_kCSy" }, "outputs": [], "source": [ "# Get the list of available metrics from the columns of the base model metrics DataFrame.\n", "# This ensures the dropdown only presents metrics that have actually been calculated and are present.\n", "available_metrics = list(base_model_metrics.columns)\n", "\n", "# Create a dropdown widget for selecting a metric to plot at lead time 0.\n", "# The default value is 'KGE' if available, otherwise it defaults to the first metric in the list.\n", "metric_widget = widgets.Dropdown(\n", " options=available_metrics,\n", " value='KGE' if 'KGE' in available_metrics else available_metrics[0],\n", " description='Select Metric for Lead Time 0 Plot:',\n", " disabled=False,\n", " style = {'description_width': 'initial'}\n", ")\n", "\n", "# Create an interactive widget that links the 'metric_widget' dropdown to the 'plot_lead_time_zero_scores' function.\n", "# `widgets.fixed` is used to pass static arguments (like dataframes and basin IDs) that do not change\n", "# when the dropdown value is altered, making the plot update dynamically only based on the selected metric.\n", "interactive_plot = interactive(\n", " 
backend.plot_lead_time_zero_scores,\n", " metrics_df=widgets.fixed(base_model_metrics),\n", " train_basin_ids=widgets.fixed(train_basin_ids),\n", " test_basin_ids=widgets.fixed(test_basin_ids),\n", " metric_name=metric_widget,\n", " model_name=widgets.fixed(model_name if 'model_name' in globals() else 'Base Model')\n", ")\n", "\n", "# Display the interactive widget, which includes the dropdown and the output plot.\n", "display(interactive_plot)" ] }, { "cell_type": "markdown", "metadata": { "id": "b0ATuuhR9a6P" }, "source": [ "## Fine Tuning" ] }, { "cell_type": "markdown", "metadata": { "id": "oyUKE4UMAW76" }, "source": [ "### Choose Basin" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "yhaN9WfT9djN" }, "outputs": [], "source": [ "# --- Fine-Tuning: Choose Basin for Fine-Tuning ---\n", "\n", "# This cell is dedicated to interactively selecting and visualizing the specific\n", "# basin that will be targeted for fine-tuning. The goal is to improve the\n", "# model's performance specifically for this chosen basin.\n", "\n", "# Initial hardcoded value for FINE_TUNING_BASIN. 
This will be the default\n", "# selection in the dropdown and will be used if no interactive selection occurs.\n", "# Example for CAMELS dataset: Basin with a low KGE score.\n", "# You can change this value to any valid basin ID from your dataset.\n", "DEFAULT_FINE_TUNING_BASIN_ID = 'camels_13235000' # Example for small model run in live demo.\n", "\n", "# Get the list of test basin IDs from the base model configuration.\n", "# These are the basins on which the base model was evaluated and from which\n", "# we can select a target for fine-tuning.\n", "# `test_basin_ids` is expected to be available globally from previous cells\n", "# (e.g., after `backend.load_model_config_and_basins` in the 'Base Model' section).\n", "fine_tuning_basin_options = sorted(list(test_basin_ids))\n", "\n", "# Create a dropdown widget for selecting the fine-tuning basin.\n", "# The options are derived from the `test_basin_ids` of the base model.\n", "# The default value is set to `DEFAULT_FINE_TUNING_BASIN_ID` if it exists\n", "# in the options, otherwise it defaults to the first available basin.\n", "basin_selector_dropdown = widgets.Dropdown(\n", " options=fine_tuning_basin_options,\n", " value=DEFAULT_FINE_TUNING_BASIN_ID if DEFAULT_FINE_TUNING_BASIN_ID in fine_tuning_basin_options else fine_tuning_basin_options[0],\n", " description='Select Fine-Tuning Basin:',\n", " disabled=False,\n", " style = {'description_width': 'initial'}\n", ")\n", "\n", "# Define a function to update the global FINE_TUNING_BASIN and plot\n", "# when a new basin is selected from the dropdown.\n", "def select_fine_tuning_basin(basin_id_selection):\n", " global FINE_TUNING_BASIN\n", " FINE_TUNING_BASIN = basin_id_selection\n", " print(f\"\\nSelected Fine-Tuning Basin: {FINE_TUNING_BASIN}\")\n", "\n", " # Plot the selected fine-tuning basin using the shapefile function.\n", " # This visualization helps confirm that the correct basin has been selected\n", " # and allows for a geographical context of the 
fine-tuning target.\n", " # `backend.plot_train_test_shapefile` is a utility function that draws basin boundaries\n", " # on a map, highlighting specific basins.\n", " backend.plot_train_test_shapefile(\n", " shapefile_path=SHAPEFILE_PATH, # Path to the shapefile containing basin geometries.\n", " train_basin_ids=set(), # An empty set because we are not highlighting training basins here.\n", " test_basin_ids={FINE_TUNING_BASIN}, # A set containing only the chosen fine-tuning basin ID to highlight it.\n", " model_name=f\"Fine-Tuning Basin: {FINE_TUNING_BASIN}\" # Updates the plot title to indicate the fine-tuning basin.\n", " )\n", "\n", "# Create an interactive widget linking the dropdown to the selection function.\n", "# This displays the dropdown and automatically calls `select_fine_tuning_basin`\n", "# whenever the dropdown value changes.\n", "interactive_basin_selection = interactive(\n", " select_fine_tuning_basin,\n", " basin_id_selection=basin_selector_dropdown,\n", ")\n", "\n", "# Display the interactive widget to the user.\n", "display(interactive_basin_selection)\n", "\n", "# Initialize FINE_TUNING_BASIN with the default or first value initially.\n", "# This ensures FINE_TUNING_BASIN is set even before the user interacts with the dropdown.\n", "if 'FINE_TUNING_BASIN' not in globals():\n", " FINE_TUNING_BASIN = basin_selector_dropdown.value\n", " print(f\"Initial Fine-Tuning Basin: {FINE_TUNING_BASIN}\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "urTMYkecGq_Q" }, "outputs": [], "source": [ "finetune_basin_kge_scores = base_model_metrics.loc[f'{FINE_TUNING_BASIN}', 'KGE']\n", "\n", "# Create a simple bar chart\n", "plt.figure(figsize=(10, 6))\n", "plt.bar(finetune_basin_kge_scores.index, finetune_basin_kge_scores.values)\n", "\n", "plt.title(f\"Base Model KGE vs. 
Lead Time for Basin {FINE_TUNING_BASIN}\")\n", "plt.xlabel(\"Lead Time (days)\")\n", "plt.ylabel(\"KGE Score\")\n", "plt.grid(axis='y', linestyle='--', alpha=0.7)\n", "plt.xticks(finetune_basin_kge_scores.index)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": { "id": "-QLfqnjF1xcS" }, "source": [ "### Finetuning & Inference Pipelines" ] }, { "cell_type": "markdown", "metadata": { "id": "6Kg-hqJe5uLW" }, "source": [ "#### Create Finetuning Config File" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "kcaKIGzK5zJK" }, "outputs": [], "source": [ "# --- Configuration File Paths ---\n", "finetune_config_template_path = f'{CONFIG_DIR}/finetune-config.yml'\n", "finetune_basin_config_file = f'{CONFIG_DIR}/finetune-config-{FINE_TUNING_BASIN}.yml'\n", "\n", "# --- Execute Config Generation ---\n", "\n", "# 1. Create the basin list file\n", "backend.create_basin_list_file(FINE_TUNING_BASIN, output_dir='basin-lists')\n", "\n", "# 2. Generate the config file\n", "# Construct the full path to the base model run directory\n", "base_model_run_dir_path = os.path.join(MODEL_RUN_DIR, os.path.basename(model_run_dir))\n", "\n", "backend.generate_basin_finetune_config(\n", " template_path=finetune_config_template_path,\n", " basin_id=FINE_TUNING_BASIN,\n", " base_model_dir=base_model_run_dir_path,\n", " output_path=finetune_basin_config_file\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "4FXfsXY4AUKn" }, "source": [ "#### Run Command" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "g4YzINjH9dl3" }, "outputs": [], "source": [ "# --- Fine-Tuning: Generate Run Command ---\n", "\n", "# This cell generates the commands needed to run the fine-tuning experiment\n", "# and subsequent inference for the selected basin.\n", "\n", "# Assumes FINE_TUNING_BASIN and finetune_basin_config_file are available globally.\n", "\n", "# Construct the fine-tuning training command.\n", "# This points to the basin-specific config file 
we just generated.\n", "finetune_train_command = f\"run finetune --config-file={finetune_basin_config_file}\"\n", "\n", "# Construct the fine-tuning inference command.\n", "# The run directory will be inside the base model directory, starting with the experiment name.\n", "# Note: The actual directory will have a timestamp suffix.\n", "finetune_experiment_name = f\"finetune-{FINE_TUNING_BASIN}\"\n", "finetune_infer_command_display = f\"run infer --run-dir={model_run_dir}/{finetune_experiment_name}_YYYYMMDD_HHMMSS\"\n", "\n", "\n", "# Print the run instructions to the screen\n", "print(\"--------------------------------------------------------------------------------------------------\")\n", "print(\" FINE-TUNING AND INFERENCE RUN INSTRUCTIONS \")\n", "print(\"--------------------------------------------------------------------------------------------------\")\n", "print(\"These commands need to be executed in your terminal or a separate process.\")\n", "print(\"Ensure you are in an environment with the necessary Python libraries (e.g., the Conda environment).\")\n", "print(\"\")\n", "print(\"IMPORTANT: The configuration file has been automatically generated for your selected basin:\")\n", "print(f\" Config File Location: {finetune_basin_config_file}\")\n", "print(\" This file contains the specific settings for fine-tuning on basin \" + str(FINE_TUNING_BASIN))\n", "print(\" It has been pre-configured with the correct base model paths and basin list.\")\n", "print(\"\")\n", "print(\"STEP 1: Run the fine-tuning training command:\")\n", "print(f\" {finetune_train_command}\")\n", "print(\"\")\n", "print(\"STEP 2: After training is complete, find the created run directory and run inference.\")\n", "print(f\" The output directory will be created inside: {model_run_dir}\")\n", "print(f\" It will be named like: {finetune_experiment_name}_YYYYMMDD_HHMMSS\")\n", "print(\"\")\n", "print(\" Run the inference command using that directory:\")\n", "print(f\" 
{finetune_infer_command_display}\")\n", "print(\"\")\n", "print(\"--------------------------------------------------------------------------------------------------\")" ] }, { "cell_type": "markdown", "metadata": { "id": "r_l64g1h1xcT" }, "source": [ "#### Run Finetuning & Inference" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "sQ2DY5NU1xcU" }, "outputs": [], "source": [ "if RUN_NEW_MODELS:\n", "    # 1. Execute Targeted Fine-Tuning:\n", "    # This calls 'run finetune' using the basin-specific YAML config.\n", "    # It freezes most of the model (LSTMs) and only updates the 'static_attributes_fc'\n", "    # and 'head' layers to better represent the unique characteristics of this basin.\n", "    print(f\"Executing fine-tuning: {finetune_train_command}\")\n", "    os.system(finetune_train_command)\n", "\n", "    # Load the fine-tune config file to get the actual experiment name\n", "    # We use the dynamically generated config file path from the previous steps\n", "    with open(finetune_basin_config_file, 'r') as f:\n", "        finetune_config_data = yaml.safe_load(f)\n", "    actual_finetune_experiment_name = finetune_config_data.get('experiment_name', 'default_finetune_experiment')\n", "    print(f\"Finetune experiment name from config file: {actual_finetune_experiment_name}\")\n", "\n", "    # Get all potential run directories for the fine-tune experiment\n", "    # These are expected to be within the base model's run directory structure.\n", "    finetune_search_pattern = os.path.join(model_run_dir, f'{actual_finetune_experiment_name}_*_*/')\n", "    all_finetune_run_dirs = glob.glob(finetune_search_pattern)\n", "\n", "    if not all_finetune_run_dirs:\n", "        print(f\"Error: No fine-tune run directories found matching pattern {finetune_search_pattern}\")\n", "    else:\n", "        all_finetune_run_dirs.sort()\n", "        finetune_model_actual_run_dir = all_finetune_run_dirs[-1].rstrip('/')\n", "        print(f\"Found most recent fine-tune model run directory: {finetune_model_actual_run_dir}\")\n", "\n", "        # 2. Execute Inference on Fine-Tuned Model:\n", "        # Once fine-tuning is complete, this command generates new predictions\n", "        # using the updated weights.\n", "        finetune_infer_command_actual = f\"run infer --run-dir={finetune_model_actual_run_dir}\"\n", "        print(f\"Executing fine-tuned inference: {finetune_infer_command_actual}\")\n", "        os.system(finetune_infer_command_actual)\n", "else:\n", "    print(\"Skipping fresh fine-tuning. Loading existing fine-tuned artifacts from the run directory.\")" ] }, { "cell_type": "markdown", "metadata": { "id": "k3J1RGtqFdDF" }, "source": [ "### Select Fine Tuning Model Run" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "hIhmul5TA-2I" }, "outputs": [], "source": [ "# --- Fine-Tuning: Select Fine-Tuning Model Run ---\n", "\n", "# Define the expected experiment name based on the selected basin\n", "finetune_experiment_name = f\"finetune-{FINE_TUNING_BASIN}\"\n", "\n", "# Try to find the run directory dynamically\n", "# This handles both cases:\n", "# 1. A new run was just created (with a timestamp).\n", "# 2. We are loading an existing run from disk.\n", "\n", "# Search for directories matching the experiment name pattern inside the base model directory\n", "search_path = os.path.join(model_run_dir, f\"{finetune_experiment_name}*\")\n", "found_dirs = glob.glob(search_path)\n", "\n", "if found_dirs:\n", "    # Sort to find the most recent one (lexicographical sort works for standard timestamps)\n", "    found_dirs.sort()\n", "    finetune_run_dir = found_dirs[-1]\n", "else:\n", "    # Fallback if no directory is found (e.g., if the model hasn't been run yet)\n", "    finetune_run_dir = os.path.join(model_run_dir, finetune_experiment_name)\n", "    print(f\"Warning: No run directory found matching '{finetune_experiment_name}*'. 
Defaulting to: {finetune_run_dir}\")\n", "\n", "# Update the global variables to point to the fine-tuned model.\n", "# This ensures that the next cell (metrics calculation) loads data from this new directory.\n", "model_name = f'Fine-Tuned Model ({FINE_TUNING_BASIN})'\n", "\n", "print(f\"Selected Fine-Tuned Model: {model_name}\")\n", "print(f\"Run Directory: {finetune_run_dir}\")" ] }, { "cell_type": "markdown", "metadata": { "id": "X852gLEPCDqQ" }, "source": [ "### Calculate Metrics" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "diDTDWFc9dpH" }, "outputs": [], "source": [ "# --- Fine-Tuning: Calculate Metrics ---\n", "\n", "# This cell is responsible for loading the configuration and test basin IDs for the\n", "# selected fine-tuned model, and then either calculating or loading its performance metrics.\n", "# It mirrors the process used for the base model but focuses on the fine-tuned experiment.\n", "\n", "# Load configuration and basin IDs for the fine-tuned model.\n", "# `backend.load_model_config_and_basins` reads the configuration file associated with\n", "# the `finetune_run_dir` (which was set in the cell above).\n", "# It extracts the experiment's configuration details, as well as the training and testing\n", "# basin IDs used for this specific fine-tuned run.\n", "finetune_config, finetune_train_basin_ids, finetune_test_basin_ids = backend.load_model_config_and_basins(finetune_run_dir)\n", "\n", "# Load simulation data and calculate/load metrics for the fine-tuned model.\n", "# `backend.load_data_and_metrics` performs the heavy lifting of reading raw model outputs\n", "# and, based on the `CALCULATE_STATISTICS` flag, computes or loads performance metrics.\n", "# - `model_run_dir`: The directory of the fine-tuned model run selected in the previous step.\n", "# - `finetune_test_basin_ids`: The set of basin IDs that were used for testing this fine-tuned model.\n", "# - `CALCULATE_STATISTICS`: A global boolean flag. 
If `True`, metrics are recalculated;\n", "# if `False`, pre-saved metrics are loaded to save computation time.\n", "# - `model_name`: A descriptive name for the fine-tuned model, used for display purposes.\n", "finetune_data, finetune_metrics = backend.load_data_and_metrics(\n", "    model_run_dir=finetune_run_dir,\n", "    test_basin_ids=finetune_test_basin_ids,\n", "    calculate_statistics=CALCULATE_STATISTICS,\n", "    model_name=model_name\n", ")\n", "\n", "# Display the head of the calculated/loaded fine-tuned model metrics DataFrame.\n", "# This provides a quick overview of the metrics, including basin IDs, lead times, and scores.\n", "display(finetune_metrics.head())" ] }, { "cell_type": "markdown", "metadata": { "id": "yURISNzJCrDN" }, "source": [ "### Visualize Metrics" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "Ii3MpfSvCZ9U" }, "outputs": [], "source": [ "# --- Fine-Tuning: Visualize Metrics ---\n", "\n", "# This cell provides an interactive visualization to compare the performance\n", "# of the base model against the fine-tuned model for a specific basin and metric.\n", "# This allows you to assess the impact of fine-tuning on the target basin's predictions\n", "# across different lead times.\n", "\n", "# Get the list of all unique basin IDs present in the base model metrics.\n", "# This ensures that the dropdown for basin selection contains only valid basin IDs\n", "# for which data is available.\n", "all_basin_ids = sorted(list(base_model_metrics.index.get_level_values('basin_id').unique()))\n", "\n", "# Get the list of available metrics from the base model metrics columns.\n", "# This ensures the dropdown only presents metrics that have actually been calculated and are present.\n", "available_metrics = list(base_model_metrics.columns)\n", "\n", "# Create a dropdown widget for selecting the basin.\n", "# The default value is `FINE_TUNING_BASIN` when it appears in `all_basin_ids`;\n", "# otherwise the first available basin is used.\n", "basin_widget = 
widgets.Dropdown(\n", "    options=all_basin_ids,\n", "    value=FINE_TUNING_BASIN if FINE_TUNING_BASIN in all_basin_ids else all_basin_ids[0],\n", "    description='Select Basin:',\n", "    disabled=False,\n", "    style={'description_width': 'initial'}\n", ")\n", "\n", "# Create a dropdown widget for selecting the metric.\n", "# The default metric is 'KGE' (Kling-Gupta Efficiency), a common hydrological performance metric.\n", "# If 'KGE' is not among the available metrics, the first available metric is used instead.\n", "metric_widget = widgets.Dropdown(\n", "    options=available_metrics,\n", "    value='KGE' if 'KGE' in available_metrics else available_metrics[0],\n", "    description='Select Metric:',\n", "    disabled=False,\n", "    style={'description_width': 'initial'}\n", ")\n", "\n", "# Create an interactive widget linking the dropdowns to the plotting function.\n", "# `plot_comparison_metrics_vs_lead_time` will dynamically update the plot\n", "# whenever a new basin or metric is selected from the dropdowns.\n", "# `widgets.fixed` is used to pass the metrics DataFrames as static arguments\n", "# because they do not change with dropdown selections.\n", "interactive_plot = interactive(\n", "    backend.plot_comparison_metrics_vs_lead_time,\n", "    base_metrics_df=widgets.fixed(base_model_metrics),  # Pass the base metrics DataFrame\n", "    finetune_metrics_df=widgets.fixed(finetune_metrics),  # Pass the fine-tuned metrics DataFrame\n", "    basin_id=basin_widget,\n", "    metric_name=metric_widget\n", ")\n", "\n", "# Display the interactive widget and the plot output.\n", "# The `interactive_plot` object contains the controls (dropdowns) and the output (plot).\n", "# We separate them to display controls above the plot.\n", "ui = interactive_plot.children[:-1]  # Controls are all children except the last one (the output)\n", "out = interactive_plot.children[-1]  # The output is the last child\n", "\n", "display(VBox(ui), out)  # Display the controls in a VBox, followed by the plot output" ] }, { "cell_type": "markdown", "metadata": { "id": "MfR8mYcgDzGH" }, "source": [ "### Visualize Hydrographs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "ewgCq6AmG4vr" }, "outputs": [], 
"source": [ "# --- Fine-Tuning: Visualize Hydrographs ---\n", "\n", "# This cell generates a hydrograph plot to visually compare the observed streamflow\n", "# against the simulated streamflow from both the base model and the fine-tuned model.\n", "# This visualization is critical for understanding the models' ability to reproduce\n", "# the hydrological behavior of the selected basin, especially at lead time zero (the current day's prediction).\n", "\n", "# Define the lead time for the hydrograph comparison. LEAD_TIME_ZERO typically refers\n", "# to the hindcast or current-day prediction, which is often the most accurate.\n", "LEAD_TIME_ZERO = 0\n", "\n", "# Use the basin currently selected in the interactive dropdown.\n", "# `basin_widget.value` holds the selection from the previous cell, so the hydrograph\n", "# is plotted for the user's chosen basin (by default, `FINE_TUNING_BASIN`).\n", "finetune_basin = basin_widget.value\n", "\n", "# Select observed streamflow data for the chosen basin at the specified lead time.\n", "# `finetune_data` contains the loaded observed data (streamflow_obs) for the fine-tuned run.\n", "obs_hydrograph = finetune_data['streamflow_obs'].sel(\n", "    basin=finetune_basin,\n", "    time_step=LEAD_TIME_ZERO\n", ")\n", "\n", "# Select base model simulated streamflow data for the chosen basin at lead time 0.\n", "# `base_model_data` contains the loaded simulation results (streamflow_sim) for the base model.\n", "base_model_hydrograph = base_model_data['streamflow_sim'].sel(\n", "    basin=finetune_basin,\n", "    time_step=LEAD_TIME_ZERO\n", ")\n", "\n", "# Select fine-tuned model simulated streamflow data for the chosen basin at lead time 0.\n", "# `finetune_data` also contains the simulation results for the fine-tuned model.\n", "finetune_model_hydrograph = finetune_data['streamflow_sim'].sel(\n", "    basin=finetune_basin,\n", "    time_step=LEAD_TIME_ZERO\n", ")\n", "\n", "# 
Create the plot figure.\n", "plt.figure(figsize=(12, 6))\n", "\n", "# Plot the hydrographs using different colors and line styles for easy differentiation.\n", "# Observed streamflow is shown in black, base model in blue (dashed), and fine-tuned model in red (dash-dot).\n", "obs_hydrograph.plot(label='Observed', color='black')\n", "base_model_hydrograph.plot(label='Base Model', color='blue', linestyle='--')\n", "finetune_model_hydrograph.plot(label='Fine-Tuned Model', color='red', linestyle='-.')\n", "\n", "# Add title, axis labels, and a grid for clarity.\n", "# The title uses `finetune_basin` (the basin actually plotted), which may differ from\n", "# `FINE_TUNING_BASIN` if a different basin was selected in the dropdown.\n", "plt.title(f\"Hydrograph Comparison for Basin {finetune_basin} (Lead Time {LEAD_TIME_ZERO})\")\n", "plt.xlabel(\"Date\")\n", "plt.ylabel(\"Streamflow\")\n", "plt.grid(True, linestyle='--', alpha=0.6)\n", "plt.legend()  # Display the legend to identify each line\n", "plt.tight_layout()  # Adjust layout to prevent labels from overlapping\n", "plt.show()  # Display the generated plot" ] }, { "cell_type": "markdown", "metadata": { "id": "Z8YXUwl4S_lw" }, "source": [ "## Look at Basin Area Distribution" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1RPFa1N-hbT4" }, "outputs": [], "source": [ "# Load attributes data from the specified CSV file.\n", "# This DataFrame contains various hydrological attributes for different basins, including their area.\n", "attributes_df = pd.read_csv(ATTRIBUTES_FILE_PATH)\n", "\n", "# Get the fine-tuning basin ID from the `basin_widget` (from the previous interactive cell).\n", "# This ensures that the highlighted basin in the plot is the one currently selected by the user.\n", "finetune_basin_id = basin_widget.value\n", "\n", "# Filter the `attributes_df` to include only the basins that were used for training.\n", "# We select both the original 'gauge_id' and the 'area' for these basins.\n", "training_basin_areas = attributes_df[\n", "    attributes_df['gauge_id'].isin(train_basin_ids)\n", "][['gauge_id', 'area']]\n", "\n", "# Get the area of the fine-tuning 
basin from `attributes_df`.\n", "# We check if `attributes_df` and its required columns exist to prevent errors.\n", "# If the fine-tuning basin's area is found, it's stored; otherwise, a warning is printed.\n", "if 'attributes_df' in globals() and 'gauge_id' in attributes_df.columns and 'area' in attributes_df.columns:\n", "    finetune_basin_area_row = attributes_df[attributes_df['gauge_id'] == finetune_basin_id]\n", "    if not finetune_basin_area_row.empty:\n", "        finetune_basin_area = finetune_basin_area_row['area'].iloc[0]\n", "    else:\n", "        finetune_basin_area = None\n", "        print(f\"Warning: Area not found for fine-tuning basin ID: {finetune_basin_id}\")\n", "else:\n", "    finetune_basin_area = None\n", "    print(\"Warning: attributes_df not available or missing required columns.\")\n", "\n", "\n", "# Plot a histogram of the training basin areas.\n", "# This visualizes the distribution of catchment areas for the basins the base model was trained on.\n", "plt.figure(figsize=(10, 6))\n", "plt.hist(training_basin_areas['area'], bins=50, alpha=0.7, label='Training Basin Areas')\n", "\n", "# Add a vertical line for the fine-tuning basin area if it was found.\n", "# This helps to visually compare the size of the fine-tuning basin relative to the training basins.\n", "if finetune_basin_area is not None:\n", "    plt.axvline(finetune_basin_area, color='red', linestyle='dashed', linewidth=2, label=f'Fine-Tuning Basin Area ({finetune_basin_area:.2f})')\n", "\n", "# Set plot labels and title for clarity.\n", "plt.xlabel(\"Catchment Area\")\n", "plt.ylabel(\"Frequency\")\n", "plt.title(\"Distribution of Training Basin Areas with Fine-Tuning Basin Area Highlighted\")\n", "plt.legend()  # Display the legend to differentiate between training areas and the fine-tuning basin\n", "plt.grid(axis='y', linestyle='--', alpha=0.7)  # Add a horizontal grid for readability\n", "plt.show()  # Display the generated plot" ] }, { "cell_type": "markdown", "metadata": { "id": "27eef590" }, 
"source": [ "### Why is Basin Area Distribution Important for Fine-Tuning?\n", "\n", "The basin area distribution plot (from the cell above) is central to understanding our fine-tuning strategy, especially concerning the `static_attributes_fc` layer. Here's why:\n", "\n", "When we fine-tune a pre-trained model, our goal is often to improve its performance on a specific, potentially 'out-of-distribution' target. The base model was trained on a set of basins, and its performance might be suboptimal for basins with characteristics significantly different from those it was trained on.\n", "\n", "1. **Basin Area as a Key Characteristic:** Basin area is a fundamental static attribute of a catchment. If the fine-tuning basin's area falls significantly outside the range or distribution of the training basin areas (as visualized in the histogram), it suggests that the base model might not have learned a robust representation for basins of that particular size.\n", "\n", "2. **Role of `static_attributes_fc` Layer:** The `static_attributes_fc` (static attributes fully connected) layer in the model is specifically designed to process these static basin characteristics (such as area, elevation, and geology). It learns to embed these attributes into a representation that the rest of the hydrological model can use.\n", "\n", "3. **Why Fine-Tune This Layer:** If our target fine-tuning basin has static attributes (like area) that are dissimilar to the bulk of the training data, the `static_attributes_fc` layer might be producing a suboptimal embedding for it. By fine-tuning this layer (together with the output `head`, as in the fine-tuning run above) while the LSTM layers stay frozen, we allow the model to learn a more accurate representation of the target basin's static characteristics without retraining the larger, more complex hydrological prediction part of the model. 
This targeted approach is more efficient and reduces the risk of 'catastrophic forgetting', where retraining the whole model might degrade performance on the basins it already handled well in the original training distribution.\n", "\n", "In essence, if the target basin is an outlier in terms of area, fine-tuning `static_attributes_fc` addresses this discrepancy directly by teaching the model to better interpret and use that characteristic in the target basin's predictions." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "wXjA_wkWQpBg" }, "outputs": [], "source": [] } ], "metadata": { "colab": { "provenance": [], "toc_visible": true }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.12.13" } }, "nbformat": 4, "nbformat_minor": 4 }