Transformer-Based Framework for Transferable Hourly Ozone Forecasting
TFT Pipeline — Training, Transfer Learning, Inference & Visualization
End-to-end Temporal Fusion Transformer (TFT) workflow for air-quality time series: dataset preparation, training from scratch, transfer learning, inference export, and a rich visualization suite (maps, TensorBoard-to-PNG, and PNG-to-GIF utilities) tailored to air-quality and environmental data.
---
Repository Layout
```
tft_forecasting_kit/
├── pyproject.toml
├── requirements.txt
├── requirements_jsc_booster.txt
├── requirements_jsc_cluster.txt
├── Config/ # YAML config files
├── HPC_Scripts/ # SLURM or HPC jobs
├── Plots/ # Default output for plots/GIFs # Folder provided in 10.5281/zenodo.19151740
├── src/
│ ├── Make_Datasets.py
│ ├── Train_TFT.py
│ ├── Finetune_transfer.py
│ ├── Optuna_Tune_Transfer.py
│ ├── Run_Inference.py
│ ├── dataset/
│ │ └── repository.py
│ ├── utils/
│ │ ├── config.py
│ │ └── logging.py
│ ├── custom_pytorch_forecasting/
│ │ └── ModifiedTimeseriesDataset.py
│ └── visualize/
│ ├── Plot_Inference.py
│ ├── Map_Plots.py
│ ├── Tensorboard_plots.py
│ ├── pngs_to_gif.py
│ └── README.md
└── README.md
```
> **Note:** All training/inference scripts alias the custom dataset module for backward compatibility:
> `sys.modules['ModifiedTimeseriesDataset'] = custom_pytorch_forecasting.ModifiedTimeseriesDataset`
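The effect of that alias can be sketched in isolation: pickles written when the class lived in a top-level module named `ModifiedTimeseriesDataset` can only be reloaded if that module name resolves at unpickling time. The class below is a minimal placeholder, not the real dataset:

```python
import pickle
import sys
import types

# Placeholder for the class shipped in
# src/custom_pytorch_forecasting/ModifiedTimeseriesDataset.py.
class PatchedTimeSeriesDataSet:
    pass

# Pretend the class was originally defined in the legacy module.
PatchedTimeSeriesDataSet.__module__ = "ModifiedTimeseriesDataset"

# The shim: register the module under its legacy name so old pickles resolve.
legacy = types.ModuleType("ModifiedTimeseriesDataset")
legacy.PatchedTimeSeriesDataSet = PatchedTimeSeriesDataSet
sys.modules["ModifiedTimeseriesDataset"] = legacy

# Round-trip: pickle stores the qualified name
# 'ModifiedTimeseriesDataset.PatchedTimeSeriesDataSet', and the alias
# lets pickle.loads import it again.
obj = pickle.loads(pickle.dumps(PatchedTimeSeriesDataSet()))
```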
---
Installation
```bash
# create and activate a virtual environment (conda/venv as you prefer)
python -m venv .venv && source .venv/bin/activate
# pick one:
pip install -r requirements.txt # local/dev
# or on HPC
pip install -r requirements_jsc_booster.txt # Booster env
pip install -r requirements_jsc_cluster.txt # Cluster env
```
If you want to use Cartopy (maps), you may need to set `CARTOPY_DATA_DIR` (offline usage) or install `cartopy-data` on your system.
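For offline use, the variable has to be set before Cartopy is imported; a minimal sketch (the directory path is a placeholder):

```python
import os

# Point Cartopy at a pre-staged Natural Earth data directory (offline use).
# Adjust the placeholder path to wherever the data lives on your system.
os.environ.setdefault("CARTOPY_DATA_DIR", "/path/to/cartopy_data")

# Import cartopy only after the variable is set so it is picked up:
# import cartopy
```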
CLI Entrypoints (from `pyproject.toml`)
```toml
[project.scripts]
make-datasets = "Make_Datasets:main"
train-tft = "Train_TFT:main"
run-inference = "Run_Inference:main"
```
So you can call:
```bash
make-datasets --help
train-tft --help
run-inference --help
```
---
Quick Start
```bash
# 1. Create datasets
make-datasets --config Config/make.yml
# 2. Train model
train-tft --config Config/train.yml
# 3. (Optional) Transfer learning / Optuna tuning
python src/Finetune_transfer.py --config Config/finetune.yml
python src/Optuna_Tune_Transfer.py --config Config/tune.yml
# 4. Run inference
run-inference --config Config/infer.yml
# 5. Visualize results
python src/visualize/Plot_Inference.py --config Config/plot.yml
```
---
Modules Overview
| Folder | Purpose |
|---------|----------|
| `src/dataset/repository.py` | `PickleTSRepository` — load/save pickled PyTorch-Forecasting datasets |
| `src/utils/config.py` | `TrainConfig`, `PlotConfig`, YAML loading |
| `src/utils/logging.py` | `setup_logging`, `rank_zero_*` wrappers |
| `src/custom_pytorch_forecasting/ModifiedTimeseriesDataset.py` | `PatchedTimeSeriesDataSet` — standalone-usable replacement for original class |
| `src/visualize/` | complete plotting, mapping, and export utilities |
---
Data & Artifacts
- **Datasets**: training/validation are stored as *pickled* `PatchedTimeSeriesDataSet` objects and loaded via `PickleTSRepository`.
- **Checkpoints**: PyTorch Lightning `.ckpt` files in `lightning_logs/version_*/checkpoints/`.
- **Inference outputs**: torch pickles — `predictions_y.pkl`, `predictions_out.pkl`, `raw_predictions_x.pkl`, `raw_predictions_out.pkl`.
- **External CAMS**: daily NetCDF files like `ENS_FORECAST_YYYY-MM-DD.nc` for optional overlays/skills.
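Loading the four exported pickles back can be sketched with a small helper; `load_inference_artifacts` is a hypothetical name, and for tensors saved with `torch.save` you would substitute `torch.load(path)` for the plain `pickle.load` used here:

```python
import pickle
from pathlib import Path

def load_inference_artifacts(out_dir, prefix=""):
    """Load the four inference pickles exported by Run_Inference.py.

    Illustrative sketch: swap pickle.load for torch.load when the files
    contain torch tensors.
    """
    names = ["predictions_y", "predictions_out",
             "raw_predictions_x", "raw_predictions_out"]
    artifacts = {}
    for name in names:
        path = Path(out_dir) / f"{prefix}{name}.pkl"
        with open(path, "rb") as fh:
            artifacts[name] = pickle.load(fh)
    return artifacts
```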
---
Using the Custom Dataset Standalone
```bash
cd src/custom_pytorch_forecasting
pip install . # or: pip install -e .
```
Then in any script:
```python
from custom_pytorch_forecasting.ModifiedTimeseriesDataset import PatchedTimeSeriesDataSet
```
It behaves exactly like the stock `TimeSeriesDataSet`, with the additional ability to yield only temporally contiguous (gap-free) samples.
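The contiguous-sample idea can be illustrated with a small standalone helper that finds gap-free runs in an hourly time index; this is a sketch of the concept, not the class's actual implementation:

```python
def contiguous_runs(time_idx, min_length):
    """Return half-open (start, end) pairs over `time_idx` where consecutive
    entries differ by exactly 1, keeping only runs long enough to hold at
    least one encoder+decoder window of `min_length` steps.
    """
    runs, start = [], 0
    for i in range(1, len(time_idx) + 1):
        # Close the current run at the end of the list or at a gap.
        if i == len(time_idx) or time_idx[i] != time_idx[i - 1] + 1:
            if i - start >= min_length:
                runs.append((start, i))
            start = i
    return runs

# A gap between 4 and 10 splits the series into two usable runs.
contiguous_runs([0, 1, 2, 3, 4, 10, 11, 12], min_length=3)  # [(0, 5), (5, 8)]
```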
---
1) Train (from scratch or continue)
```bash
python src/Train_TFT.py --config Config/train.yml --data-dir /path/to/pickles --log-dir ./lightning_logs --enc-len 336 --pred-len 96 --batch-size 128 --max-epochs 60 --accelerator gpu --devices 1 --strategy auto --hidden-size 64 --lstm-layers 2 --dropout 0.1 --attention-heads 4 --learning-rate 1e-3
```
**Highlights**
- Supports `--config` (YAML) where CLI overrides take precedence.
- Can resume via `--checkpoint <.ckpt>` or auto-pick latest with `--auto-checkpoint-picker --auto-checkpoint-path <parent_of_version_*>`.
- Uses `TrainConfig` defaults for shapes/devices if not overridden.
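The auto-picker behaviour can be sketched as below; this is an illustration of the plausible selection logic (newest `.ckpt` under `version_*/checkpoints/`), and the actual criterion in `Train_TFT.py` may differ, e.g. highest version number:

```python
from pathlib import Path

def latest_checkpoint(parent):
    """Return the most recently modified .ckpt under
    parent/version_*/checkpoints, or None if there is none."""
    ckpts = sorted(Path(parent).glob("version_*/checkpoints/*.ckpt"),
                   key=lambda p: p.stat().st_mtime)
    return ckpts[-1] if ckpts else None
```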
---
2) Transfer Learning (Germany → Korea)
```bash
python src/Finetune_transfer.py --config Config/finetune.yml --data-dir /path/to/korea_pickles --checkpoint /path/to/germany.ckpt --hidden-size 16 --attention-heads 4 --lstm-layers 1 --dropout 0.1 --learning-rate 5e-4 --freeze-policy default_meta --log-dir ./lightning_logs_transfer --devices 1 --accelerator gpu
```
**What it does**
- Instantiates a fresh TFT from the **target** dataset (so vocab sizes match), then **loosely loads** the source weights:
  - copies tensors with matching shapes,
  - **expands embeddings** when the vocabulary grew (e.g. new station codes),
  - skips incompatible tensors.
- Freezes all params, then **unfreezes** according to policy:
- `default_meta`: station embeddings (+ selected static real prescalers),
- `attn_output`: above **plus** attention and output layers.
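The loose warm-start can be sketched with NumPy arrays standing in for parameter tensors; `loose_load` is a hypothetical helper illustrating the three rules above, not the script's actual function:

```python
import numpy as np

def loose_load(target, source):
    """Warm-start `target` (name -> np.ndarray) from `source`.

    Identical shapes are copied; 2-D embedding-like weights whose row
    count grew (new station codes) get their overlapping rows copied;
    everything else is skipped and keeps its fresh initialization.
    """
    for name, src in source.items():
        if name not in target:
            continue
        dst = target[name]
        if dst.shape == src.shape:
            target[name] = src.copy()
        elif (dst.ndim == src.ndim == 2
              and dst.shape[1] == src.shape[1]
              and dst.shape[0] > src.shape[0]):
            dst[: src.shape[0]] = src  # new rows keep their init values
        # otherwise: incompatible tensor, leave target untouched
    return target
```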
---
3) LR Tuning on top of Transfer (Optuna)
```bash
python src/Optuna_Tune_Transfer.py --config Config/tune.yml --data-dir /path/to/korea_pickles --checkpoint /path/to/germany.ckpt --hidden-size 16 --attention-heads 4 --lstm-layers 1 --dropout 0.1 --n-trials 12 --lr-min 1e-5 --lr-max 3e-3 --strategy ddp_find_unused_parameters_true --devices 1 --accelerator gpu
```
- Fixes the model architecture to checkpoint-compatible values and tunes **only the learning rate**.
- Applies the same warm-start + freeze policy per trial via a callback.
- Persists `optuna_lr/best.json` and `study.pkl` under the checkpoint dir.
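The LR search space is the interesting part: candidates should be drawn log-uniformly between `--lr-min` and `--lr-max`. A stdlib sketch of that sampling, mirroring what Optuna's `trial.suggest_float("lr", lr_min, lr_max, log=True)` does (the real script delegates sampling, pruning, and persistence to Optuna):

```python
import math
import random

def sample_lr(lr_min, lr_max, rng=random):
    """Draw a learning rate log-uniformly from [lr_min, lr_max]."""
    return math.exp(rng.uniform(math.log(lr_min), math.log(lr_max)))

# Twelve candidate learning rates, as with --n-trials 12.
candidates = [sample_lr(1e-5, 3e-3) for _ in range(12)]
```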
---
4) Inference Export
```bash
python src/Run_Inference.py --config Config/infer.yml --data-dir /path/to/pickles --checkpoint /path/to/best.ckpt --out-dir ./inference_out --prefix pred_ --batch-size 128
```
Produces:
- `predictions_y.pkl`, `predictions_out.pkl`
- `raw_predictions_x.pkl`, `raw_predictions_out.pkl`
---
5) Visualization Toolkit (`src/visualize`)
A) Per-sample plots, aggregates & skill scores
```bash
python src/visualize/Plot_Inference.py --config Config/plot.yml --validation-pkl /path/validation.pkl --pred-y ./inference_out/pred_predictions_y.pkl --pred-out ./inference_out/pred_predictions_out.pkl --raw-x ./inference_out/pred_raw_predictions_x.pkl --raw-out ./inference_out/pred_raw_predictions_out.pkl --station-csv /path/station_meta.csv --nc-dir /path/to/CAMS_nc_dir --include-sample-plots --use-external
```
- Feature flags: `--use-external/--no-use-external`, `--include-sample-plots/--no-include-sample-plots`.
- `PlotConfig` fields can be overridden from CLI, e.g. `--output-dir`, `--base-date`, `--sample-limit`, etc.
B) Per-timestamp maps (Actual/Pred/CAMS)
```bash
python src/visualize/Map_Plots.py --config Config/maps.yml --val-pkl /path/validation.pkl --pred-y ./inference_out/predictions_y.pkl --pred-out ./inference_out/predictions_out.pkl --raw-x ./inference_out/raw_predictions_x.pkl --raw-out ./inference_out/raw_predictions_out.pkl --cams-nearest-csv /path/station_meta.csv --cams-forecast-dir /path/CAMS_nc_dir --output-dir ./Plots --records-pkl ./Plots/map_plot_records.pkl --use-external --external-scale 0.50115 --enc-len 336 --pred-len 96 --base-date "2020-01-01 00:00:00"
```
Tips:
- If `--records-pkl` already exists, it is **not** overwritten.
- To limit rendered timestamps: `--sample-limit N` (plots from `enc_len` to `enc_len + N`); otherwise timestamps run from `0` to `len - pred_len`.
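Timestamp indices map to wall-clock time via `--base-date` on an hourly grid; a small sketch of that convention (the helper name is illustrative, not part of the toolkit):

```python
from datetime import datetime, timedelta

def index_to_timestamp(base_date, idx, step_hours=1):
    """Map a sample index to a timestamp: base_date + idx hours."""
    base = datetime.strptime(base_date, "%Y-%m-%d %H:%M:%S")
    return base + timedelta(hours=idx * step_hours)

# With --enc-len 336, the first rendered map would fall 14 days in:
index_to_timestamp("2020-01-01 00:00:00", 336)  # 2020-01-15 00:00:00
```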
C) TensorBoard scalars → PNGs
```bash
python src/visualize/Tensorboard_plots.py --version_dir ./lightning_logs --plot_dir ./Plots/TensorBoard
```
D) PNGs → GIF (streaming, RGBA-safe)
```bash
python src/visualize/pngs_to_gif.py ./Plots/Maps/*.png -o ./Plots/Maps_anim.gif --fps 12 --optimize
```
Options: `--duration`, `--loop`, `--size 1024x1024`, `--reverse`.
---
YAML Examples
Create small YAMLs under `Config/` and override via CLI as needed.
`train.yml`
```yaml
data_dir: /path/to/pickles
enc_len: 336
pred_len: 96
batch_size: 128
max_epochs: 60
accelerator: gpu
devices: 1
strategy: auto
hidden_size: 64
lstm_layers: 2
dropout: 0.1
attention_heads: 4
learning_rate: 0.001
log_dir: ./lightning_logs
# checkpoint: /path/to/prev.ckpt
# auto_checkpoint_picker: true
# auto_checkpoint_path: ./lightning_logs
```
`finetune.yml`
```yaml
data_dir: /path/to/korea_pickles
checkpoint: /path/to/germany.ckpt
hidden_size: 16
attention_heads: 4
lstm_layers: 1
dropout: 0.1
learning_rate: 0.0005
freeze_policy: default_meta # or attn_output
log_dir: ./lightning_logs_transfer
devices: 1
accelerator: gpu
```
`infer.yml`
```yaml
data_dir: /path/to/pickles
checkpoint: /path/to/best.ckpt
out_dir: ./inference_out
prefix: pred_
batch_size: 128
```
`plot.yml`
```yaml
output_dir: ./Plots
base_date: "2020-01-01 00:00:00"
max_prediction_length: 96
max_encoder_length: 336
x_tick_gap_hours: 8
external_scale: 0.50115
use_external: true
include_sample_plots: true
sample_limit: 24
```
`maps.yml`
```yaml
val_pkl: /path/validation.pkl
pred_y: ./inference_out/predictions_y.pkl
pred_out: ./inference_out/predictions_out.pkl
raw_x: ./inference_out/raw_predictions_x.pkl
raw_out: ./inference_out/raw_predictions_out.pkl
cams_nearest_csv: /path/station_meta.csv
cams_forecast_dir: /path/CAMS_nc_dir
cams_file_pattern: "ENS_FORECAST_%Y-%m-%d.nc"
output_dir: ./Plots
records_pkl: ./Plots/map_plot_records.pkl
enc_len: 336
pred_len: 96
base_date: "2020-01-01 00:00:00"
extent: [5, 16, 47, 56]
cmap: viridis
dpi: 200
use_external: true
external_scale: 0.50115
```
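The `cams_file_pattern` is a strftime-style template, so resolving the daily CAMS file for a given timestamp reduces to one call; a sketch under that assumed convention:

```python
from datetime import datetime

def cams_file_for(ts, pattern="ENS_FORECAST_%Y-%m-%d.nc"):
    """Resolve the daily CAMS NetCDF filename for a timestamp using the
    strftime-style cams_file_pattern from maps.yml."""
    return ts.strftime(pattern)

cams_file_for(datetime(2020, 1, 15, 6))  # 'ENS_FORECAST_2020-01-15.nc'
```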
CLI Scripts and HELP info
`Make_Datasets.py`
**Usage**
```bash
make-datasets --config Config/make.yml --country DE --mode all --indir data/raw --outdir data/pickles
```
**--help**
```
--config PATH YAML config for source folders and target path
--country {DE,KR} Dataset region
--mode {all,summer} Filter by seasonal subset
--indir PATH Input CSV folder
--codes PATH Station codes CSV
--outcsv PATH Output merged CSV
--pivot_glob STR Pattern for pivoted CSVs
--rank INT MPI rank (default 0)
--world INT MPI world size (default 1)
```
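The `--rank`/`--world` flags suggest work is split across MPI processes; one common scheme is a round-robin shard per rank, sketched below (the real script may partition differently, e.g. contiguous blocks):

```python
def shard_for_rank(items, rank=0, world=1):
    """Round-robin split of work items across `world` MPI ranks."""
    return [item for i, item in enumerate(items) if i % world == rank]

# Example: rank 1 of 3 processes stations B and E.
shard_for_rank(["A", "B", "C", "D", "E"], rank=1, world=3)  # ['B', 'E']
```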
---
`Train_TFT.py`
**Usage**
```bash
train-tft --config Config/train.yml --enc-len 336 --pred-len 96
```
**--help**
```
--config PATH YAML configuration
--data-dir PATH PickleTSRepository path
--log-dir PATH Output Lightning logs
--checkpoint PATH Resume checkpoint
--auto-checkpoint-picker Use latest ckpt automatically
--auto-checkpoint-path PATH Path containing version_* folders
--hidden-size INT Hidden dimension size
--lstm-layers INT Number of LSTM layers
--dropout FLOAT Dropout fraction
--attention-heads INT Number of attention heads
--learning-rate FLOAT Learning rate
--max-epochs INT Training epochs
--accelerator STR 'gpu' or 'cpu'
--devices INT Device count
--strategy STR DDP or auto
```
---
`Finetune_transfer.py`
**Usage**
```bash
python src/Finetune_transfer.py --config Config/finetune.yml
```
**--help**
```
--config PATH YAML config
--data-dir PATH Target dataset (PickleTSRepository)
--checkpoint PATH Source pretrained model
--freeze-policy {default_meta,attn_output}
--hidden-size INT
--attention-heads INT
--lstm-layers INT
--dropout FLOAT
--learning-rate FLOAT
--log-dir PATH
--devices INT
--accelerator STR
```
---
`Optuna_Tune_Transfer.py`
**Usage**
```bash
python src/Optuna_Tune_Transfer.py --config Config/tune.yml --n-trials 20
```
**--help**
```
--config PATH YAML config
--checkpoint PATH Finetune base checkpoint
--n-trials INT Optuna trials
--lr-min FLOAT Lower bound for LR search
--lr-max FLOAT Upper bound
--timeout FLOAT Stop after (sec)
--strategy STR ddp_find_unused_parameters_true or auto
--devices INT
--accelerator STR
```
---
`Run_Inference.py`
**Usage**
```bash
run-inference --config Config/infer.yml --checkpoint /path/best.ckpt
```
**--help**
```
--config PATH YAML config
--data-dir PATH Dataset pickle directory
--checkpoint PATH Model checkpoint
--out-dir PATH Output directory
--prefix STR Output prefix
--batch-size INT Batch size for prediction
```
---
Visualization Suite
[All tools live under `src/visualize/` and have their own README.](src/visualize/README.md)
`Plot_Inference.py`
```
--config PATH
--validation-pkl PATH
--pred-y PATH
--pred-out PATH
--raw-x PATH
--raw-out PATH
--station-csv PATH
--nc-dir PATH
--output-dir PATH
--base-date STR
--enc-len INT
--pred-len INT
--sample-limit INT
--x-tick-gap-hours INT
--use-external / --no-use-external
--include-sample-plots / --no-include-sample-plots
```
`Map_Plots.py`
```
--config PATH
--val-pkl PATH
--pred-y PATH
--pred-out PATH
--raw-x PATH
--raw-out PATH
--cams-nearest-csv PATH
--cams-forecast-dir PATH
--output-dir PATH
--records-pkl PATH
--use-external
--external-scale FLOAT
--enc-len INT
--pred-len INT
--base-date STR
--extent L R B T
--sample-limit INT
```
`Tensorboard_plots.py`
```
--version_dir PATH Directory containing lightning_logs
--plot_dir PATH Destination for PNGs
```
`pngs_to_gif.py`
```
patterns... PNG glob patterns
-o, --out PATH Output GIF file
--fps INT Frames per second
--duration FLOAT Duration per frame
--loop INT Loop count
--size WxH Resize output
--optimize Optimize palette
--reverse Append reversed frames
```
---
Example End-to-End Workflow
```bash
make-datasets --config Config/make.yml
train-tft --config Config/train.yml
python src/Finetune_transfer.py --config Config/finetune.yml
python src/Optuna_Tune_Transfer.py --config Config/tune.yml
run-inference --config Config/infer.yml
python src/visualize/Plot_Inference.py --config Config/plot.yml
python src/visualize/Map_Plots.py --config Config/maps.yml
python src/visualize/pngs_to_gif.py "./Plots/maps/*.png" -o ./Plots/maps.gif
```
---
Notes & Conventions
- `TrainConfig`/`PlotConfig` centralize **shapes**, **devices**, and **plot** defaults; CLI and YAML can override.
- Scripts set `torch.set_float32_matmul_precision("high")` on CUDA for speed.
- Random seed fixed in `TrainConfig.seed` for reproducibility.
- All scripts import the custom dataset shim so legacy names keep working.
---
Files
- `TFT_Forecasting_Framework_v1.0.zip` (97.3 kB), md5: `83039d977b7986a555e647f8a922830f`
Additional details
Related works
- Continues: Software, 10.5281/zenodo.19151435 (DOI)
- Is source of: Other, 10.5281/zenodo.19151740 (DOI)
Software
- Repository URL: https://gitlab.jsc.fz-juelich.de/vasireddy1/tft/-/tree/main