Published March 21, 2026 | Version v1.0
Software · Open Access

Transformer-Based Framework for Transferable Hourly Ozone Forecasting


Description

TFT Pipeline — Training, Transfer Learning, Inference & Visualization

End-to-end Temporal Fusion Transformer (TFT) workflow for air-quality time series: dataset preparation, training from scratch, transfer learning, inference export, and visualization. The suite comprises a dataset builder, trainer, transfer-learning and inference scripts, and a visualization toolkit (station maps, TensorBoard-to-PNG export, and PNG-to-GIF animation) optimized for air-quality and environmental time series.

---

Repository Layout

```
tft_forecasting_kit/
├── pyproject.toml
├── requirements.txt
├── requirements_jsc_booster.txt
├── requirements_jsc_cluster.txt
├── Config/                         # YAML config files
├── HPC_Scripts/                    # SLURM or HPC jobs
├── Plots/                          # Default plot/GIF output (folder provided in 10.5281/zenodo.19151740)
├── src/
│   ├── Make_Datasets.py
│   ├── Train_TFT.py
│   ├── Finetune_transfer.py
│   ├── Optuna_Tune_Transfer.py
│   ├── Run_Inference.py
│   ├── dataset/
│   │   └── repository.py
│   ├── utils/
│   │   ├── config.py
│   │   └── logging.py
│   ├── custom_pytorch_forecasting/
│   │   └── ModifiedTimeseriesDataset.py
│   └── visualize/
│       ├── Plot_Inference.py
│       ├── Map_Plots.py
│       ├── Tensorboard_plots.py
│       ├── pngs_to_gif.py
│       └── README.md
└── README.md
```

> **Note:** All training/inference scripts alias the custom dataset for backward compatibility:
`sys.modules['ModifiedTimeseriesDataset'] = custom_pytorch_forecasting.ModifiedTimeseriesDataset`.
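
The aliasing trick can be sketched as follows. This is a minimal, self-contained illustration of why the shim matters for unpickling (the placeholder class here stands in for the real `PatchedTimeSeriesDataSet`; in the actual scripts the imported module object is registered instead):

```python
import sys
import types

# Stand-in module mimicking custom_pytorch_forecasting.ModifiedTimeseriesDataset
shim = types.ModuleType("ModifiedTimeseriesDataset")

class PatchedTimeSeriesDataSet:  # placeholder for the real dataset class
    pass

shim.PatchedTimeSeriesDataSet = PatchedTimeSeriesDataSet

# Registering under the legacy top-level name lets pickle resolve
# "ModifiedTimeseriesDataset.PatchedTimeSeriesDataSet" in old pickles.
sys.modules["ModifiedTimeseriesDataset"] = shim

import ModifiedTimeseriesDataset  # resolves to the shim registered above
assert ModifiedTimeseriesDataset.PatchedTimeSeriesDataSet is PatchedTimeSeriesDataSet
```

Without the alias, unpickling datasets that were saved when the class lived at the top-level module path would raise `ModuleNotFoundError`.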

---

Installation

```bash
# create and activate a virtual environment (conda/venv as you prefer)
python -m venv .venv && source .venv/bin/activate

# pick one:
pip install -r requirements.txt          # local/dev
# or on HPC
pip install -r requirements_jsc_booster.txt    # Booster env
pip install -r requirements_jsc_cluster.txt    # Cluster env
```

If you want to use Cartopy (maps), you may need to set `CARTOPY_DATA_DIR` (offline usage) or install `cartopy-data` on your system.

CLI Entrypoints (from `pyproject.toml`)


```toml
[project.scripts]
make-datasets = "Make_Datasets:main"
train-tft     = "Train_TFT:main"
run-inference = "Run_Inference:main"
```

So you can call:
```bash
make-datasets --help
train-tft --help
run-inference --help
```

---

Quick Start

```bash
# 1. Create datasets
make-datasets --config Config/make.yml

# 2. Train model
train-tft --config Config/train.yml

# 3. (Optional) Transfer learning / Optuna tuning
python src/Finetune_transfer.py --config Config/finetune.yml
python src/Optuna_Tune_Transfer.py --config Config/tune.yml

# 4. Run inference
run-inference --config Config/infer.yml

# 5. Visualize results
python src/visualize/Plot_Inference.py --config Config/plot.yml
```

---

Modules Overview

| Folder | Purpose |
|---------|----------|
| `src/dataset/repository.py` | `PickleTSRepository` — load/save pickled PyTorch-Forecasting datasets |
| `src/utils/config.py` | `TrainConfig`, `PlotConfig`, YAML loading |
| `src/utils/logging.py` | `setup_logging`, `rank_zero_*` wrappers |
| `src/custom_pytorch_forecasting/ModifiedTimeseriesDataset.py` | `PatchedTimeSeriesDataSet` — standalone-usable replacement for the original class |
| `src/visualize/` | complete plotting, mapping, and export utilities |

---

Data & Artifacts

- **Datasets**: training/validation are stored as *pickled* `PatchedTimeSeriesDataSet` objects and loaded via `PickleTSRepository`.
- **Checkpoints**: PyTorch Lightning `.ckpt` files in `lightning_logs/version_*/checkpoints/`.
- **Inference outputs**: torch pickles — `predictions_y.pkl`, `predictions_out.pkl`, `raw_predictions_x.pkl`, `raw_predictions_out.pkl`.
- **External CAMS**: daily NetCDF files like `ENS_FORECAST_YYYY-MM-DD.nc` for optional overlays and skill scores.

---

Using the Custom Dataset Standalone

```bash
cd src/custom_pytorch_forecasting
pip install .          # or: pip install -e .
```

Then in any script:
```python
from custom_pytorch_forecasting.ModifiedTimeseriesDataset import PatchedTimeSeriesDataSet
```

It behaves exactly like the stock `TimeSeriesDataSet`, with the additional ability to draw only contiguous (gap-free) samples.

---

1) Train (from scratch or continue)

```bash
python src/Train_TFT.py \
  --config Config/train.yml \
  --data-dir /path/to/pickles \
  --log-dir ./lightning_logs \
  --enc-len 336 --pred-len 96 --batch-size 128 --max-epochs 60 \
  --accelerator gpu --devices 1 --strategy auto \
  --hidden-size 64 --lstm-layers 2 --dropout 0.1 --attention-heads 4 --learning-rate 1e-3
```

**Highlights**

- Supports `--config` (YAML) where CLI overrides take precedence.
- Can resume via `--checkpoint <.ckpt>` or auto-pick latest with `--auto-checkpoint-picker --auto-checkpoint-path <parent_of_version_*>`.
- Uses `TrainConfig` defaults for shapes/devices if not overridden.

---

2) Transfer Learning (Germany → Korea)

```bash
python src/Finetune_transfer.py \
  --config Config/finetune.yml \
  --data-dir /path/to/korea_pickles \
  --checkpoint /path/to/germany.ckpt \
  --hidden-size 16 --attention-heads 4 --lstm-layers 1 --dropout 0.1 \
  --learning-rate 5e-4 \
  --freeze-policy default_meta \
  --log-dir ./lightning_logs_transfer --devices 1 --accelerator gpu
```

**What it does**

- Instantiates a fresh TFT from the **target** dataset (so vocab sizes match), then **loosely loads** source weights:
  - copies matching shapes,
  - smartly **expands embeddings** if vocab grew (new station codes),
  - skips incompatible tensors.
- Freezes all params, then **unfreezes** according to policy:
  - `default_meta`: station embeddings (+ selected static real prescalers),
  - `attn_output`: above **plus** attention and output layers.

---

3) LR Tuning on top of Transfer (Optuna)

```bash
python src/Optuna_Tune_Transfer.py \
  --config Config/tune.yml \
  --data-dir /path/to/korea_pickles \
  --checkpoint /path/to/germany.ckpt \
  --hidden-size 16 --attention-heads 4 --lstm-layers 1 --dropout 0.1 \
  --n-trials 12 --lr-min 1e-5 --lr-max 3e-3 \
  --strategy ddp_find_unused_parameters_true \
  --devices 1 --accelerator gpu
```

- Fixes the model architecture to the checkpoint-compatible values and **only tunes the learning rate**.
- Applies the same warm-start + freeze policy per trial via a callback.
- Persists `optuna_lr/best.json` and `study.pkl` under the checkpoint dir.
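
The LR-only search amounts to sampling log-uniformly between `--lr-min` and `--lr-max` and keeping the trial with the lowest validation loss. A dependency-free sketch (the real script uses Optuna's `suggest_float(..., log=True)`; `val_loss` here is a hypothetical proxy objective):

```python
import json
import math
import random

def sample_lr(lr_min, lr_max, rng):
    # Log-uniform sampling, mirroring a log-scale LR search space
    return math.exp(rng.uniform(math.log(lr_min), math.log(lr_max)))

def val_loss(lr):
    # Placeholder objective with a pretend optimum near 3e-4
    return (math.log10(lr) + 3.5) ** 2

rng = random.Random(0)
trials = [(val_loss(lr), lr) for lr in (sample_lr(1e-5, 3e-3, rng) for _ in range(12))]
best_loss, best_lr = min(trials)

# Mirrors the persisted best.json shape (keys are illustrative)
best = {"learning_rate": best_lr, "val_loss": best_loss}
print(json.dumps(best))
```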

---

4) Inference Export

```bash
python src/Run_Inference.py \
  --config Config/infer.yml \
  --data-dir /path/to/pickles \
  --checkpoint /path/to/best.ckpt \
  --out-dir ./inference_out \
  --prefix pred_ \
  --batch-size 128
```

Produces:
- `predictions_y.pkl`, `predictions_out.pkl`
- `raw_predictions_x.pkl`, `raw_predictions_out.pkl`
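
The exports follow a `<prefix><name>.pkl` naming scheme. A plain-`pickle` sketch of writing and reading them back (the actual artifacts are torch pickles, so `torch.save`/`torch.load` would be used in practice; payloads here are placeholders):

```python
import os
import pickle
import tempfile

prefix, out_dir = "pred_", tempfile.mkdtemp()
artifacts = {
    "predictions_y.pkl":   [41.2, 38.7],  # placeholder target values
    "predictions_out.pkl": [40.9, 39.1],  # placeholder model outputs
}
for name, payload in artifacts.items():
    with open(os.path.join(out_dir, prefix + name), "wb") as fh:
        pickle.dump(payload, fh)

# Downstream plotting scripts read the same prefixed paths back
with open(os.path.join(out_dir, "pred_predictions_y.pkl"), "rb") as fh:
    y = pickle.load(fh)
assert y == [41.2, 38.7]
```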

---

5) Visualization Toolkit (`src/visualize`)

A) Per-sample plots, aggregates & skill scores

```bash
python src/visualize/Plot_Inference.py \
  --config Config/plot.yml \
  --validation-pkl /path/validation.pkl \
  --pred-y ./inference_out/pred_predictions_y.pkl \
  --pred-out ./inference_out/pred_predictions_out.pkl \
  --raw-x ./inference_out/pred_raw_predictions_x.pkl \
  --raw-out ./inference_out/pred_raw_predictions_out.pkl \
  --station-csv /path/station_meta.csv \
  --nc-dir /path/to/CAMS_nc_dir \
  --include-sample-plots \
  --use-external
```

- Feature flags: `--use-external/--no-use-external`, `--include-sample-plots/--no-include-sample-plots`.
- `PlotConfig` fields can be overridden from CLI, e.g. `--output-dir`, `--base-date`, `--sample-limit`, etc.

B) Per-timestamp maps (Actual/Pred/CAMS)

```bash
python src/visualize/Map_Plots.py \
  --config Config/maps.yml \
  --val-pkl /path/validation.pkl \
  --pred-y ./inference_out/predictions_y.pkl \
  --pred-out ./inference_out/predictions_out.pkl \
  --raw-x ./inference_out/raw_predictions_x.pkl \
  --raw-out ./inference_out/raw_predictions_out.pkl \
  --cams-nearest-csv /path/station_meta.csv \
  --cams-forecast-dir /path/CAMS_nc_dir \
  --output-dir ./Plots \
  --records-pkl ./Plots/map_plot_records.pkl \
  --use-external --external-scale 0.50115 \
  --enc-len 336 --pred-len 96 --base-date "2020-01-01 00:00:00"
```

Tips:
- If `--records-pkl` already exists, it is **not** overwritten.
- To limit rendered timestamps: `--sample-limit N` (plots from `enc_len` to `enc_len + N`); otherwise the full range from `0` to `len - pred_len` is rendered.
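
The timestamp-window rule can be expressed compactly (function name is hypothetical; the real logic lives inside `Map_Plots.py`):

```python
# Which timestamps get rendered, given dataset length, window sizes,
# and an optional --sample-limit N
def render_range(n_timestamps, enc_len, pred_len, sample_limit=None):
    if sample_limit is not None:
        # limited run: N frames starting at the end of the encoder window
        return range(enc_len, enc_len + sample_limit)
    # full run: everything that still has a complete prediction window
    return range(0, n_timestamps - pred_len)

assert list(render_range(500, 336, 96, sample_limit=24)) == list(range(336, 360))
assert render_range(500, 336, 96).stop == 404  # 500 - 96
```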

C) TensorBoard scalars → PNGs

```bash
python src/visualize/Tensorboard_plots.py \
  --version_dir ./lightning_logs \
  --plot_dir ./Plots/TensorBoard
```

D) PNGs → GIF (streaming, RGBA-safe)

```bash
python src/visualize/pngs_to_gif.py \
  ./Plots/Maps/*.png \
  -o ./Plots/Maps_anim.gif \
  --fps 12 --optimize
```

Options: `--duration`, `--loop`, `--size 1024x1024`, `--reverse`.

---

YAML Examples

Create small YAMLs under `Config/` and override via CLI as needed.
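
The "CLI overrides YAML" precedence can be sketched as below. The keys follow `train.yml`; the `yaml_cfg` dict stands in for the output of PyYAML's `yaml.safe_load`, so the sketch stays dependency-free:

```python
import argparse

# Parsed YAML config (stand-in for yaml.safe_load(open("Config/train.yml")))
yaml_cfg = {"enc_len": 336, "pred_len": 96, "batch_size": 128}

parser = argparse.ArgumentParser()
parser.add_argument("--enc-len", type=int, default=None)
parser.add_argument("--batch-size", type=int, default=None)
args = parser.parse_args(["--batch-size", "64"])  # example CLI invocation

# YAML supplies defaults; any flag explicitly given on the CLI wins
cfg = dict(yaml_cfg)
cfg.update({k: v for k, v in vars(args).items() if v is not None})
assert cfg == {"enc_len": 336, "pred_len": 96, "batch_size": 64}
```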

`train.yml`


```yaml
data_dir: /path/to/pickles
enc_len: 336
pred_len: 96
batch_size: 128
max_epochs: 60
accelerator: gpu
devices: 1
strategy: auto
hidden_size: 64
lstm_layers: 2
dropout: 0.1
attention_heads: 4
learning_rate: 0.001
log_dir: ./lightning_logs
# checkpoint: /path/to/prev.ckpt
# auto_checkpoint_picker: true
# auto_checkpoint_path: ./lightning_logs
```

`finetune.yml`


```yaml
data_dir: /path/to/korea_pickles
checkpoint: /path/to/germany.ckpt
hidden_size: 16
attention_heads: 4
lstm_layers: 1
dropout: 0.1
learning_rate: 0.0005
freeze_policy: default_meta  # or attn_output
log_dir: ./lightning_logs_transfer
devices: 1
accelerator: gpu
```

`infer.yml`

```yaml
data_dir: /path/to/pickles
checkpoint: /path/to/best.ckpt
out_dir: ./inference_out
prefix: pred_
batch_size: 128
```

`plot.yml`

```yaml
output_dir: ./Plots
base_date: "2020-01-01 00:00:00"
max_prediction_length: 96
max_encoder_length: 336
x_tick_gap_hours: 8
external_scale: 0.50115
use_external: true
include_sample_plots: true
sample_limit: 24
```

`maps.yml`

```yaml
val_pkl: /path/validation.pkl
pred_y: ./inference_out/predictions_y.pkl
pred_out: ./inference_out/predictions_out.pkl
raw_x: ./inference_out/raw_predictions_x.pkl
raw_out: ./inference_out/raw_predictions_out.pkl
cams_nearest_csv: /path/station_meta.csv
cams_forecast_dir: /path/CAMS_nc_dir
cams_file_pattern: "ENS_FORECAST_%Y-%m-%d.nc"
output_dir: ./Plots
records_pkl: ./Plots/map_plot_records.pkl
enc_len: 336
pred_len: 96
base_date: "2020-01-01 00:00:00"
extent: [5, 16, 47, 56]
cmap: viridis
dpi: 200
use_external: true
external_scale: 0.50115
```

CLI Scripts and Help Info

`Make_Datasets.py`

**Usage**
```bash
make-datasets --config Config/make.yml --country DE --mode all --indir data/raw --outdir data/pickles
```

**--help**
```
--config PATH              YAML config for source folders and target path
--country {DE,KR}          Dataset region
--mode {all,summer}        Filter by seasonal subset
--indir PATH               Input CSV folder
--codes PATH               Station codes CSV
--outcsv PATH              Output merged CSV
--pivot_glob STR           Pattern for pivoted CSVs
--rank INT                 MPI rank (default 0)
--world INT                MPI world size (default 1)
```

---

`Train_TFT.py`

**Usage**
```bash
train-tft --config Config/train.yml --enc-len 336 --pred-len 96
```

**--help**
```
--config PATH              YAML configuration
--data-dir PATH            PickleTSRepository path
--log-dir PATH             Output Lightning logs
--checkpoint PATH          Resume checkpoint
--auto-checkpoint-picker   Use latest ckpt automatically
--auto-checkpoint-path PATH  Path containing version_* folders
--hidden-size INT          Hidden dimension size
--lstm-layers INT          Number of LSTM layers
--dropout FLOAT            Dropout fraction
--attention-heads INT      Number of attention heads
--learning-rate FLOAT      Learning rate
--max-epochs INT           Training epochs
--accelerator STR          'gpu' or 'cpu'
--devices INT              Device count
--strategy STR             DDP or auto
```

---

 `Finetune_transfer.py`

**Usage**
```bash
python src/Finetune_transfer.py --config Config/finetune.yml
```

**--help**
```
--config PATH              YAML config
--data-dir PATH            Target dataset (PickleTSRepository)
--checkpoint PATH          Source pretrained model
--freeze-policy {default_meta,attn_output}
--hidden-size INT
--attention-heads INT
--lstm-layers INT
--dropout FLOAT
--learning-rate FLOAT
--log-dir PATH
--devices INT
--accelerator STR
```

---

`Optuna_Tune_Transfer.py`

**Usage**
```bash
python src/Optuna_Tune_Transfer.py --config Config/tune.yml --n-trials 20
```

**--help**
```
--config PATH              YAML config
--checkpoint PATH          Finetune base checkpoint
--n-trials INT             Optuna trials
--lr-min FLOAT             Lower bound for LR search
--lr-max FLOAT             Upper bound
--timeout FLOAT            Stop after (sec)
--strategy STR             ddp_find_unused_parameters_true or auto
--devices INT
--accelerator STR
```

---

`Run_Inference.py`

**Usage**
```bash
run-inference --config Config/infer.yml --checkpoint /path/best.ckpt
```

**--help**
```
--config PATH              YAML config
--data-dir PATH            Dataset pickle directory
--checkpoint PATH          Model checkpoint
--out-dir PATH             Output directory
--prefix STR               Output prefix
--batch-size INT           Batch size for prediction
```

---

Visualization Suite

[All tools live under `src/visualize/` and have their own README.](src/visualize/README.md)

`Plot_Inference.py`

```
--config PATH
--validation-pkl PATH
--pred-y PATH
--pred-out PATH
--raw-x PATH
--raw-out PATH
--station-csv PATH
--nc-dir PATH
--output-dir PATH
--base-date STR
--enc-len INT
--pred-len INT
--sample-limit INT
--x-tick-gap-hours INT
--use-external / --no-use-external
--include-sample-plots / --no-include-sample-plots
```

 `Map_Plots.py`

```
--config PATH
--val-pkl PATH
--pred-y PATH
--pred-out PATH
--raw-x PATH
--raw-out PATH
--cams-nearest-csv PATH
--cams-forecast-dir PATH
--output-dir PATH
--records-pkl PATH
--use-external
--external-scale FLOAT
--enc-len INT
--pred-len INT
--base-date STR
--extent L R B T
--sample-limit INT
```

`Tensorboard_plots.py`

```
--version_dir PATH    Directory containing lightning_logs
--plot_dir PATH       Destination for PNGs
```

`pngs_to_gif.py`

```
patterns...           PNG glob patterns
-o, --out PATH        Output GIF file
--fps INT             Frames per second
--duration FLOAT      Duration per frame
--loop INT            Loop count
--size WxH            Resize output
--optimize            Optimize palette
--reverse             Append reversed frames
```

---

Example End-to-End Workflow

```bash
make-datasets --config Config/make.yml
train-tft --config Config/train.yml
python src/Finetune_transfer.py --config Config/finetune.yml
python src/Optuna_Tune_Transfer.py --config Config/tune.yml
run-inference --config Config/infer.yml
python src/visualize/Plot_Inference.py --config Config/plot.yml
python src/visualize/Map_Plots.py --config Config/maps.yml
python src/visualize/pngs_to_gif.py "./Plots/maps/*.png" -o ./Plots/maps.gif
```

---

Notes & Conventions

- `TrainConfig`/`PlotConfig` centralize **shapes**, **devices**, and **plot** defaults; CLI and YAML can override.
- Scripts set `torch.set_float32_matmul_precision("high")` on CUDA for speed.
- Random seed fixed in `TrainConfig.seed` for reproducibility.
- All scripts import the custom dataset shim so legacy names keep working.

---

Files

TFT_Forecasting_Framework_v1.0.zip — 97.3 kB
md5:83039d977b7986a555e647f8a922830f

Additional details

Related works

- Continues — Software: 10.5281/zenodo.19151435 (DOI)
- Is source of — Other: 10.5281/zenodo.19151740 (DOI)

Funding

European Commission — AQplus4: Deep Learning Air Quality Forecasts for Four Days (grant 101113400)