Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 43 additions & 0 deletions DENOISING_DIFFUSION/configs/base_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Default project configuration. load_config() in src/training/config.py reads
# this file when it is called without a config path. Each top-level section
# below is passed as keyword arguments to the matching section dataclass there.
project:
name: exxa-denoising-diffusion
seed: 42

data:
train_noisy_path: data/dirty.npy
train_clean_path: data/clean.npy
# Must be > 0 (checked by ProjectConfig.validate).
image_size: 64
patch_size: 64
patches_per_image: 4
# Each split must lie in [0, 1) and val_split + test_split must be < 1
# (checked by ProjectConfig.validate).
val_split: 0.1
test_split: 0.1
num_workers: 2

model:
in_channels: 1
out_channels: 1
base_channels: 64
channel_multipliers: [1, 2, 4, 8]
dropout: 0.1

diffusion:
# Must be > 0 (checked by ProjectConfig.validate).
timesteps: 1000
# Accepted values: linear, cosine (checked by ProjectConfig.validate).
beta_schedule: linear
beta_start: 0.0001
beta_end: 0.02

training:
# batch_size and learning_rate must both be > 0 (checked by ProjectConfig.validate).
batch_size: 16
epochs: 100
learning_rate: 0.0002
weight_decay: 0.00001
grad_clip_norm: 1.0
use_amp: false
device: auto

logging:
log_every_n_steps: 100
val_every_n_epochs: 1
save_every_n_epochs: 5
output_dir: experiments
use_wandb: true
wandb_project: exxa-denoising
25 changes: 25 additions & 0 deletions DENOISING_DIFFUSION/src/training/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Training utilities for EXXA denoising diffusion models."""

from .config import (
ProjectConfig,
DataConfig,
ModelConfig,
DiffusionConfig,
TrainingConfig,
LoggingConfig,
ProjectMetadata,
load_config,
save_config,
)

# Names exported by `from training import *`; this list mirrors, one for one,
# the names imported from `.config` above.
__all__ = [
"ProjectConfig",
"DataConfig",
"ModelConfig",
"DiffusionConfig",
"TrainingConfig",
"LoggingConfig",
"ProjectMetadata",
"load_config",
"save_config",
]
164 changes: 164 additions & 0 deletions DENOISING_DIFFUSION/src/training/config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
"""Configuration schema and IO helpers for training and inference."""

from __future__ import annotations

from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Any, Mapping
import copy
import json

import yaml


@dataclass
class ProjectMetadata:
    """Values of the ``project`` config section: a project name and a seed."""

    # Defaults match the `project` section of configs/base_config.yaml.
    name: str = "exxa-denoising-diffusion"
    seed: int = 42


@dataclass
class DataConfig:
    """Values of the ``data`` config section.

    Holds the noisy/clean array paths, image and patch sizing, the
    validation/test split fractions and a worker count.
    """

    # --- input files ---
    train_noisy_path: str = "data/dirty.npy"
    train_clean_path: str = "data/clean.npy"

    # --- sizing ---
    image_size: int = 64
    patch_size: int = 64
    patches_per_image: int = 4

    # --- splits (fractions; range-checked by ProjectConfig.validate) ---
    val_split: float = 0.1
    test_split: float = 0.1

    num_workers: int = 2


@dataclass
class ModelConfig:
in_channels: int = 1
out_channels: int = 1
base_channels: int = 64
channel_multipliers: list[int] | tuple[int, ...] = (1, 2, 4, 8)
dropout: float = 0.1


@dataclass
class DiffusionConfig:
    """Values of the ``diffusion`` config section.

    ``timesteps`` and ``beta_schedule`` are range-checked by
    ``ProjectConfig.validate``; the beta bounds are stored as given.
    """

    timesteps: int = 1000
    beta_schedule: str = "linear"
    beta_start: float = 0.0001
    beta_end: float = 0.02


@dataclass
class TrainingConfig:
    """Values of the ``training`` config section.

    ``batch_size`` and ``learning_rate`` are required to be positive by
    ``ProjectConfig.validate``.
    """

    # --- schedule ---
    batch_size: int = 16
    epochs: int = 100

    # --- optimisation ---
    learning_rate: float = 0.0002
    weight_decay: float = 0.00001
    grad_clip_norm: float = 1.0

    # --- runtime ---
    use_amp: bool = False
    device: str = "auto"


@dataclass
class LoggingConfig:
    """Values of the ``logging`` config section.

    Holds the step/epoch intervals, the output directory and the
    Weights & Biases toggle and project name.
    """

    # --- intervals ---
    log_every_n_steps: int = 100
    val_every_n_epochs: int = 1
    save_every_n_epochs: int = 5

    # --- destinations ---
    output_dir: str = "experiments"
    use_wandb: bool = True
    wandb_project: str = "exxa-denoising"


@dataclass
class ProjectConfig:
    """Top-level configuration grouping one dataclass per config section.

    Every section attribute defaults to a freshly constructed instance of its
    section dataclass, so a bare ``ProjectConfig()`` carries all defaults.
    """

    project: ProjectMetadata = field(default_factory=ProjectMetadata)
    data: DataConfig = field(default_factory=DataConfig)
    model: ModelConfig = field(default_factory=ModelConfig)
    diffusion: DiffusionConfig = field(default_factory=DiffusionConfig)
    training: TrainingConfig = field(default_factory=TrainingConfig)
    logging: LoggingConfig = field(default_factory=LoggingConfig)

    def to_dict(self) -> dict[str, Any]:
        """Return the config as a nested dictionary via ``dataclasses.asdict``."""
        return asdict(self)

    def validate(self) -> None:
        """Check the value ranges this schema enforces.

        Raises:
            ValueError: If ``data.image_size``, ``training.batch_size``,
                ``training.learning_rate`` or ``diffusion.timesteps`` is not
                positive, if ``diffusion.beta_schedule`` is neither
                ``"linear"`` nor ``"cosine"``, if either split lies outside
                ``[0, 1)``, or if the two splits sum to 1 or more.
        """
        if self.data.image_size <= 0:
            raise ValueError("data.image_size must be > 0")
        if self.training.batch_size <= 0:
            raise ValueError("training.batch_size must be > 0")
        if self.training.learning_rate <= 0:
            raise ValueError("training.learning_rate must be > 0")
        if self.diffusion.timesteps <= 0:
            raise ValueError("diffusion.timesteps must be > 0")
        if self.diffusion.beta_schedule not in {"linear", "cosine"}:
            raise ValueError("diffusion.beta_schedule must be one of: linear, cosine")
        split_total = self.data.val_split + self.data.test_split
        if not (0.0 <= self.data.val_split < 1.0 and 0.0 <= self.data.test_split < 1.0):
            raise ValueError("data.val_split and data.test_split must be in [0, 1)")
        # Checked after the per-split ranges so an out-of-range split reports
        # the more specific message above.
        if split_total >= 1.0:
            raise ValueError("data.val_split + data.test_split must be < 1")

    @classmethod
    def from_dict(cls, payload: Mapping[str, Any]) -> "ProjectConfig":
        """Build a validated config from a nested mapping of sections.

        Args:
            payload: Mapping whose optional keys ``project``, ``data``,
                ``model``, ``diffusion``, ``training`` and ``logging`` each
                hold the keyword arguments for that section's dataclass.
                A section that is missing or ``None`` uses the defaults.

        Returns:
            The constructed config, after ``validate()`` has passed.

        Raises:
            ValueError: If ``validate()`` rejects a value.
            TypeError: If a section contains a key its dataclass does not
                accept, or is not a mapping.
        """
        # Work on a copy so the section dataclasses never share mutable
        # values (e.g. lists) with the caller's payload.
        data = copy.deepcopy(payload)

        def section(name: str) -> Any:
            # A YAML section header with nothing under it (``logging:``)
            # parses to None; treat it as omitted rather than crash on
            # ``**None``.
            value = data.get(name)
            return {} if value is None else value

        cfg = cls(
            project=ProjectMetadata(**section("project")),
            data=DataConfig(**section("data")),
            model=ModelConfig(**section("model")),
            diffusion=DiffusionConfig(**section("diffusion")),
            training=TrainingConfig(**section("training")),
            logging=LoggingConfig(**section("logging")),
        )
        cfg.validate()
        return cfg


def _deep_merge(base: dict[str, Any], updates: Mapping[str, Any]) -> dict[str, Any]:
merged = copy.deepcopy(base)
for key, value in updates.items():
if key in merged and isinstance(merged[key], dict) and isinstance(value, Mapping):
merged[key] = _deep_merge(merged[key], value)
else:
merged[key] = value
return merged


def _read_config_file(path: Path) -> dict[str, Any]:
suffix = path.suffix.lower()
with path.open("r", encoding="utf-8") as f:
if suffix in {".yaml", ".yml"}:
data = yaml.safe_load(f) or {}
elif suffix == ".json":
data = json.load(f)
else:
raise ValueError(f"Unsupported config extension: {suffix}")
if not isinstance(data, dict):
raise ValueError("Config root must be a mapping")
return data


def load_config(
    config_path: str | Path | None = None,
    overrides: Mapping[str, Any] | None = None,
) -> ProjectConfig:
    """Load a project config from disk, optionally applying nested overrides.

    Args:
        config_path: YAML or JSON file to read. When ``None``, the file
            ``configs/base_config.yaml`` located two directories above this
            module's directory is used.
        overrides: Nested mapping merged over the file contents before the
            config is built; ignored when ``None`` or empty.

    Returns:
        The validated ``ProjectConfig``.
    """
    if config_path is None:
        project_dir = Path(__file__).resolve().parents[2]
        config_path = project_dir / "configs" / "base_config.yaml"

    raw = _read_config_file(Path(config_path))
    merged = _deep_merge(raw, overrides) if overrides else raw
    return ProjectConfig.from_dict(merged)


def save_config(config: ProjectConfig, config_path: str | Path) -> None:
    """Save a config to YAML or JSON, chosen by the destination extension.

    The extension is checked before anything is written, so an unsupported
    destination neither creates directories nor creates or truncates a file.

    Args:
        config: Configuration to serialise (its ``to_dict()`` output is written).
        config_path: Destination ending in ``.yaml``, ``.yml`` or ``.json``
            (case-insensitive). Missing parent directories are created.

    Raises:
        ValueError: If the destination has any other extension.
    """
    path = Path(config_path)
    suffix = path.suffix.lower()
    if suffix not in {".yaml", ".yml", ".json"}:
        raise ValueError(f"Unsupported config extension: {path.suffix}")

    data = config.to_dict()
    path.parent.mkdir(parents=True, exist_ok=True)
    with path.open("w", encoding="utf-8") as f:
        if suffix == ".json":
            json.dump(data, f, indent=2)
        else:
            # sort_keys=False keeps the section order produced by to_dict().
            yaml.safe_dump(data, f, sort_keys=False)
68 changes: 68 additions & 0 deletions DENOISING_DIFFUSION/tests/test_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
"""Unit tests for training configuration schema and IO."""

from pathlib import Path
import json
import sys

import pytest

# Put the sibling `src` directory first on the import path so the
# `training.config` import below can be resolved from the source tree.
sys.path.insert(0, str(Path(__file__).parent.parent / "src"))

from training.config import ProjectConfig, load_config, save_config


def test_load_default_config() -> None:
    """Calling ``load_config()`` with no arguments returns the base config."""
    default_cfg = load_config()

    assert isinstance(default_cfg, ProjectConfig)
    project_name = default_cfg.project.name
    assert project_name == "exxa-denoising-diffusion"
    step_count = default_cfg.diffusion.timesteps
    assert step_count == 1000


def test_load_with_overrides() -> None:
    """Nested overrides replace the matching values from the base config."""
    nested_overrides = {
        "training": {"batch_size": 8},
        "diffusion": {"beta_schedule": "cosine"},
    }

    overridden = load_config(overrides=nested_overrides)

    assert overridden.training.batch_size == 8
    assert overridden.diffusion.beta_schedule == "cosine"


def test_load_json_config(tmp_path: Path) -> None:
    """A JSON file on disk is parsed into the corresponding config sections."""
    json_file = tmp_path / "custom_config.json"
    sections = {
        "project": {"name": "json-test", "seed": 7},
        "data": {"image_size": 128},
        "model": {"base_channels": 32},
        "diffusion": {"timesteps": 200, "beta_schedule": "linear"},
        "training": {"batch_size": 4, "learning_rate": 1e-4},
        "logging": {"use_wandb": False},
    }
    serialized = json.dumps(sections)
    json_file.write_text(serialized, encoding="utf-8")

    loaded = load_config(json_file)

    assert loaded.project.name == "json-test"
    assert loaded.data.image_size == 128
    assert loaded.training.batch_size == 4


def test_save_and_reload_yaml(tmp_path: Path) -> None:
    """A config saved as YAML reloads with the same overridden and default values."""
    original = load_config(overrides={"project": {"name": "saved-test"}})
    yaml_file = tmp_path / "saved_config.yaml"

    save_config(original, yaml_file)
    reloaded = load_config(yaml_file)

    assert reloaded.project.name == "saved-test"
    assert reloaded.model.base_channels == original.model.base_channels


def test_validation_rejects_bad_split() -> None:
    """Splits that are individually valid but sum past 1 are rejected."""
    oversized_splits = {"data": {"val_split": 0.7, "test_split": 0.4}}
    with pytest.raises(ValueError, match=r"val_split \+ data.test_split"):
        load_config(overrides=oversized_splits)


def test_validation_rejects_bad_schedule() -> None:
    """An unrecognised ``beta_schedule`` name fails validation."""
    unknown_schedule = {"diffusion": {"beta_schedule": "invalid"}}
    with pytest.raises(ValueError, match="beta_schedule"):
        load_config(overrides=unknown_schedule)