
Training API Reference

This module contains training utilities and trainers for quantum GANs.

QuantumGAN

The main training class for quantum generative adversarial networks.

Class Definition

class QuantumGAN:
    """
    Main trainer for Quantum GANs.

    Handles training of quantum generators and discriminators with
    support for various loss functions, optimizers, and training strategies.
    """

Constructor

def __init__(
    self,
    generator: nn.Module,
    discriminator: nn.Module,
    device: str = "cpu",
    lr_g: float = 0.0002,
    lr_d: float = 0.0002,
    beta1: float = 0.5,
    beta2: float = 0.999,
    loss_type: str = "wgan-gp",
    lambda_gp: float = 10.0,
    **kwargs
):

Parameters:

  • generator: Quantum or classical generator model
  • discriminator: Quantum or classical discriminator model
  • device: Device to train on ("cpu", "cuda", "mps")
  • lr_g: Generator learning rate
  • lr_d: Discriminator learning rate
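  • beta1: Adam exponential decay rate for the first-moment (mean) gradient estimate
  • beta2: Adam exponential decay rate for the second-moment (uncentered variance) gradient estimate
  • loss_type: Loss function type ("wgan-gp", "lsgan", "vanilla", "quantum-wgan")
  • lambda_gp: Gradient penalty coefficient (λ), used by the gradient-penalty losses such as "wgan-gp"
  • **kwargs: Additional training configuration (for example lambda_quantum, gradient_accumulation_steps, mixed_precision and warmup_epochs, as used in Advanced Training Configuration below)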
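The loss_type options map onto well-known GAN objectives. The sketch below computes the textbook discriminator (critic) losses for "vanilla" and "lsgan", plus the plain Wasserstein critic term (called "wgan" here) to which "wgan-gp" adds a gradient penalty, all from raw critic scores in plain Python. It illustrates the standard formulas only, not qgans_pro's internal code; the gradient penalty itself and the library-specific "quantum-wgan" variant are not shown, and discriminator_loss is a hypothetical name.

```python
import math


def softplus(x: float) -> float:
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def discriminator_loss(real_scores, fake_scores, loss_type):
    """Textbook discriminator/critic losses from raw (pre-sigmoid) scores."""
    n_real, n_fake = len(real_scores), len(fake_scores)
    if loss_type == "vanilla":
        # Binary cross-entropy: -log sigmoid(D(x)) - log(1 - sigmoid(D(G(z))))
        return (sum(softplus(-s) for s in real_scores) / n_real
                + sum(softplus(s) for s in fake_scores) / n_fake)
    if loss_type == "lsgan":
        # Least squares: push real scores toward 1 and fake scores toward 0
        return 0.5 * (sum((s - 1.0) ** 2 for s in real_scores) / n_real
                      + sum(s ** 2 for s in fake_scores) / n_fake)
    if loss_type == "wgan":
        # Wasserstein critic maximizes E[D(real)] - E[D(fake)]; minimize the negative
        return sum(fake_scores) / n_fake - sum(real_scores) / n_real
    raise ValueError(f"unknown loss_type: {loss_type}")


print(discriminator_loss([1.0, 1.0], [0.0, 0.0], "wgan"))  # -1.0
```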

Usage Examples

Basic Training Setup

from qgans_pro import QuantumGAN, QuantumGenerator, QuantumDiscriminator
import torch

# Create models
generator = QuantumGenerator(n_qubits=8, n_layers=3, output_dim=784)
discriminator = QuantumDiscriminator(input_dim=784, n_qubits=8, n_layers=3)

# Initialize trainer
qgan = QuantumGAN(
    generator=generator,
    discriminator=discriminator,
    device="cuda",
    lr_g=0.0001,
    lr_d=0.0002,
    loss_type="wgan-gp"
)

# Load data
from qgans_pro.utils import get_data_loader
dataloader = get_data_loader("fashion-mnist", batch_size=64)

# Train the model
qgan.train(
    dataloader=dataloader,
    epochs=100,
    save_interval=10,
    sample_interval=5
)

Advanced Training Configuration

# Advanced training with custom configuration
qgan = QuantumGAN(
    generator=generator,
    discriminator=discriminator,
    device="cuda",
    lr_g=0.0001,
    lr_d=0.0002,
    loss_type="quantum-wgan",
    lambda_gp=10.0,
    lambda_quantum=1.0,  # Quantum loss weight
    gradient_accumulation_steps=4,
    mixed_precision=True,
    warmup_epochs=5
)

# Train with callbacks
from qgans_pro.training.callbacks import EarlyStopping, ModelCheckpoint

callbacks = [
    EarlyStopping(monitor="fid_score", patience=20),
    ModelCheckpoint(monitor="fid_score", save_best_only=True)
]

qgan.train(
    dataloader=dataloader,
    epochs=200,
    callbacks=callbacks,
    validation_data=val_dataloader  # a validation DataLoader, built like `dataloader`
)

Methods

train(dataloader, epochs, **kwargs)

Main training method for the quantum GAN.

Parameters:

  • dataloader: PyTorch DataLoader with training data
  • epochs: Number of training epochs
  • save_interval: Epochs between model saves (default: 10)
  • sample_interval: Epochs between sample generation (default: 5)
  • callbacks: List of training callbacks
  • validation_data: Validation DataLoader (optional)

Returns:

  • dict: Training history with losses and metrics
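To make the interplay of epochs, save_interval, sample_interval and the returned history concrete, here is a minimal, framework-free epoch loop. It assumes (none of this is documented above) that the intervals count 1-based epochs, that the history stores the per-epoch mean of each loss returned by the step function, and that saving and sampling are delegated to caller-supplied hooks; train_loop, step_fn, save_fn and sample_fn are hypothetical names.

```python
def train_loop(step_fn, batches, epochs, save_interval=10, sample_interval=5,
               save_fn=None, sample_fn=None):
    """Run `epochs` passes over `batches`; step_fn(batch) returns a dict of losses."""
    history = {}
    for epoch in range(1, epochs + 1):  # 1-based epoch counter (assumption)
        epoch_losses = {}
        for batch in batches:
            for name, value in step_fn(batch).items():
                epoch_losses.setdefault(name, []).append(value)
        # Store the mean of each loss over this epoch's batches
        for name, values in epoch_losses.items():
            history.setdefault(name, []).append(sum(values) / len(values))
        if sample_fn is not None and epoch % sample_interval == 0:
            sample_fn(epoch)
        if save_fn is not None and epoch % save_interval == 0:
            save_fn(epoch)
    return history


# Usage with a dummy step function and three "batches"
saved, sampled = [], []
history = train_loop(
    lambda batch: {"g_loss": 1.0, "d_loss": float(batch)},
    batches=[1, 2, 3],
    epochs=10,
    save_fn=saved.append,
    sample_fn=sampled.append,
)
print(saved, sampled)  # [10] [5, 10]
```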

train_step(real_data)

Single training step for both generator and discriminator.

Parameters:

  • real_data: Batch of real training data

Returns:

  • dict: Dictionary with generator and discriminator losses

generate_samples(n_samples, **kwargs)

Generate synthetic samples using the trained generator.

Parameters:

  • n_samples: Number of samples to generate
  • **kwargs: Additional generation parameters

Returns:

  • torch.Tensor: Generated samples
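Simulating quantum circuits can be memory-hungry, so a common way to serve a large n_samples is to generate in fixed-size chunks and concatenate the results. The sketch below shows that pattern in plain Python; generate_in_chunks, sample_batch and batch_size are hypothetical, and whether generate_samples accepts a batch size through **kwargs is not stated here.

```python
def generate_in_chunks(sample_batch, n_samples, batch_size=64):
    """Call sample_batch(k) repeatedly (k <= batch_size) until n_samples are collected."""
    samples = []
    remaining = n_samples
    while remaining > 0:
        k = min(batch_size, remaining)
        samples.extend(sample_batch(k))
        remaining -= k
    return samples


# Usage: a stand-in generator that records the chunk sizes it was asked for
requested = []


def fake_generator(k):
    requested.append(k)
    return [0.0] * k


samples = generate_in_chunks(fake_generator, n_samples=150, batch_size=64)
print(len(samples), requested)  # 150 [64, 64, 22]
```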

evaluate(test_loader, metrics=None)

Evaluate the trained model on test data.

Parameters:

  • test_loader: Test data loader
  • metrics: List of metrics to compute

Returns:

  • dict: Evaluation results
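Several examples in this reference monitor "fid_score". FID is the Fréchet distance between two Gaussians fitted to feature embeddings of real and generated images, ||μ_r - μ_g||² + Tr(Σ_r + Σ_g - 2(Σ_r Σ_g)^(1/2)). The one-dimensional sketch below applies the same formula to raw scalar samples, where it reduces to (μ_r - μ_g)² + (σ_r - σ_g)². Real FID uses Inception-network features and full covariance matrices; this only illustrates the metric and is not how evaluate computes it.

```python
import statistics


def frechet_distance_1d(real, generated):
    """Fréchet distance between 1-D Gaussians fitted to two samples."""
    mu_r, mu_g = statistics.fmean(real), statistics.fmean(generated)
    sd_r, sd_g = statistics.pstdev(real), statistics.pstdev(generated)
    # In 1-D, Tr(S_r + S_g - 2 (S_r S_g)^(1/2)) = sd_r^2 + sd_g^2 - 2 sd_r sd_g = (sd_r - sd_g)^2
    return (mu_r - mu_g) ** 2 + (sd_r - sd_g) ** 2


print(frechet_distance_1d([0.0, 2.0], [3.0, 5.0]))  # means 1 and 4, both std 1 -> 9.0
```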

save_checkpoint(path, **metadata)

Save model checkpoint with metadata.

Parameters:

  • path: Path to save checkpoint
  • **metadata: Additional metadata to save

load_checkpoint(path)

Load model checkpoint.

Parameters:

  • path: Path to checkpoint file

Returns:

  • dict: Loaded metadata
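The checkpoint methods bundle model state with free-form **metadata and hand the metadata back on load. The sketch below shows that round trip with the standard library's pickle module, with plain dictionaries standing in for model state; in a PyTorch code base the equivalent would normally be torch.save / torch.load applied to each model's state_dict(). The class and its attributes are hypothetical, and the on-disk layout of qgans_pro checkpoints is not documented here. As with any pickle-based format (torch.save included), only load checkpoints from trusted sources.

```python
import os
import pickle
import tempfile


class CheckpointingTrainer:
    """Hypothetical stand-in: model state is a plain dict of parameter values."""

    def __init__(self):
        self.generator_state = {"theta": [0.1, 0.2]}
        self.discriminator_state = {"w": [1.0]}

    def save_checkpoint(self, path, **metadata):
        checkpoint = {
            "generator": self.generator_state,
            "discriminator": self.discriminator_state,
            "metadata": metadata,  # e.g. epoch number, config, best score
        }
        with open(path, "wb") as f:
            pickle.dump(checkpoint, f)

    def load_checkpoint(self, path):
        with open(path, "rb") as f:
            checkpoint = pickle.load(f)
        # Restore model state, then return only the metadata
        self.generator_state = checkpoint["generator"]
        self.discriminator_state = checkpoint["discriminator"]
        return checkpoint["metadata"]


# Usage: save from one trainer, restore into a fresh one
path = os.path.join(tempfile.mkdtemp(), "qgan.pkl")
trainer = CheckpointingTrainer()
trainer.generator_state["theta"] = [0.5, 0.7]
trainer.save_checkpoint(path, epoch=42, fid_score=12.3)

restored = CheckpointingTrainer()
metadata = restored.load_checkpoint(path)
print(metadata)  # {'epoch': 42, 'fid_score': 12.3}
```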

Properties

training_history

Get the complete training history.

Returns:

  • dict: Training metrics over time

current_epoch

Get the current training epoch.

Returns:

  • int: Current epoch number

is_trained

Check if the model has been trained.

Returns:

  • bool: True if model is trained
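The three properties are read-only views of trainer state. One way to expose such values is Python's @property decorator, as in the sketch below; the class, the record_epoch helper and the rule that "trained" means "at least one epoch completed" are all assumptions made for illustration.

```python
class TrainerState:
    """Hypothetical holder for the state behind the three properties."""

    def __init__(self):
        self._history = {"g_loss": [], "d_loss": []}
        self._epoch = 0

    @property
    def training_history(self):
        # Return copies so callers cannot mutate the trainer's internal lists
        return {name: list(values) for name, values in self._history.items()}

    @property
    def current_epoch(self):
        return self._epoch

    @property
    def is_trained(self):
        # Assumption: "trained" means at least one epoch has completed
        return self._epoch > 0

    def record_epoch(self, g_loss, d_loss):
        """Called by the training loop at the end of each epoch."""
        self._history["g_loss"].append(g_loss)
        self._history["d_loss"].append(d_loss)
        self._epoch += 1


state = TrainerState()
print(state.is_trained)  # False
state.record_epoch(g_loss=1.2, d_loss=0.8)
print(state.current_epoch, state.is_trained)  # 1 True
```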

HybridGAN

Trainer for hybrid quantum-classical GANs.

Usage Examples

from qgans_pro import HybridGAN, QuantumGenerator, ClassicalDiscriminator

# Hybrid configuration: Quantum generator + Classical discriminator
generator = QuantumGenerator(n_qubits=8, n_layers=3, output_dim=784)
discriminator = ClassicalDiscriminator(input_dim=784)

hybrid_gan = HybridGAN(
    generator=generator,
    discriminator=discriminator,
    quantum_classical_balance=0.7,  # 70% quantum, 30% classical
    device="cuda"
)

# Train hybrid model
hybrid_gan.train(dataloader, epochs=100)

Training Utilities

Schedulers

Learning rate schedulers optimized for quantum GANs:

from qgans_pro.training.schedulers import QuantumAwareScheduler

# Quantum-aware learning rate scheduling
# (optimizer is an existing optimizer, e.g. the generator's)
scheduler = QuantumAwareScheduler(
    optimizer=optimizer,
    mode="cosine_annealing",
    quantum_epochs=50,  # Epochs to focus on quantum training
    classical_epochs=50  # Epochs to focus on classical training
)

# Use in a manual training loop
for epoch in range(epochs):
    for batch in dataloader:
        # One generator/discriminator update per batch
        losses = qgan.train_step(batch)

    # Update learning rates once per epoch
    scheduler.step(losses['quantum_loss'])
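mode="cosine_annealing" presumably refers to the standard cosine annealing schedule, in which the learning rate falls from its initial value η_max to a floor η_min along half a cosine period: η_t = η_min + ½(η_max - η_min)(1 + cos(πt/T)). This is the closed form followed by PyTorch's torch.optim.lr_scheduler.CosineAnnealingLR. The sketch below computes it directly; how QuantumAwareScheduler divides the schedule between quantum_epochs and classical_epochs is not documented here and is not modeled.

```python
import math


def cosine_annealing_lr(epoch, total_epochs, base_lr, min_lr=0.0):
    """Learning rate at `epoch` under cosine annealing from base_lr down to min_lr."""
    progress = min(epoch, total_epochs) / total_epochs  # clamp after the schedule ends
    return min_lr + 0.5 * (base_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Usage: a 100-epoch schedule starting from lr = 1e-4
schedule = [cosine_annealing_lr(e, 100, base_lr=1e-4) for e in range(101)]
print(schedule[0], schedule[100])  # 0.0001 0.0
```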

Loss Functions

Specialized loss functions for quantum GANs:

from qgans_pro.training.losses import QuantumWassersteinLoss, QuantumHingeLoss

# Quantum Wasserstein loss with gradient penalty
loss_fn = QuantumWassersteinLoss(
    lambda_gp=10.0,
    lambda_quantum=1.0,
    quantum_regularization=True
)

# Quantum hinge loss
hinge_loss = QuantumHingeLoss(
    margin=1.0,
    quantum_margin_scaling=True
)
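Both loss families above build on well-known recipes. In WGAN-GP, the critic D is penalized at a point x̂ = εx_real + (1 - ε)x_fake on the line between a real and a generated sample by λ(||∇D(x̂)||₂ - 1)², which nudges D toward being 1-Lipschitz. The hinge loss trains D with mean(max(0, m - D(real))) + mean(max(0, m + D(fake))) and the generator with -mean(D(fake)). The plain-Python sketch below uses central finite differences instead of autograd so it runs without PyTorch, treats the margin m as a plain constant, and does not model the quantum_regularization or quantum_margin_scaling options, whose behavior is not described here. In practice ε is drawn uniformly per sample and the penalty is averaged over the batch.

```python
import math


def finite_diff_grad(f, x, h=1e-5):
    """Central-difference gradient of a scalar function f at point x (list of floats)."""
    grad = []
    for i in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[i] += h
        x_minus[i] -= h
        grad.append((f(x_plus) - f(x_minus)) / (2 * h))
    return grad


def gradient_penalty(critic, real, fake, eps, lambda_gp=10.0):
    """WGAN-GP penalty for one (real, fake) pair and interpolation weight eps in [0, 1]."""
    x_hat = [eps * r + (1.0 - eps) * g for r, g in zip(real, fake)]
    grad_norm = math.sqrt(sum(g * g for g in finite_diff_grad(critic, x_hat)))
    return lambda_gp * (grad_norm - 1.0) ** 2


def hinge_d_loss(real_scores, fake_scores, margin=1.0):
    """Discriminator hinge loss: real scores pushed above +margin, fake below -margin."""
    real_term = sum(max(0.0, margin - s) for s in real_scores) / len(real_scores)
    fake_term = sum(max(0.0, margin + s) for s in fake_scores) / len(fake_scores)
    return real_term + fake_term


def hinge_g_loss(fake_scores):
    """Generator hinge loss: raise the critic's score on generated samples."""
    return -sum(fake_scores) / len(fake_scores)


def linear_critic(x):
    """D(x) = 3*x0 + 4*x1, whose gradient (3, 4) has norm 5 everywhere."""
    return 3.0 * x[0] + 4.0 * x[1]


# Penalty = 10 * (5 - 1)^2 = 160, independent of eps for a linear critic
penalty = gradient_penalty(linear_critic, real=[1.0, 0.0], fake=[0.0, 1.0], eps=0.3)
```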

Optimizers

Quantum-aware optimizers:

import numpy as np
from qgans_pro.training.optimizers import QuantumAdam, ParameterShiftOptim

# Quantum-aware Adam optimizer
optimizer_g = QuantumAdam(
    generator.parameters(),
    lr=0.0001,
    quantum_lr_scaling=True,
    parameter_shift_rule=True
)

# Parameter-shift rule optimizer for quantum circuits
optimizer_quantum = ParameterShiftOptim(
    generator.quantum_parameters(),
    lr=0.01,
    shift_value=np.pi/2
)
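The shift_value=np.pi/2 passed to ParameterShiftOptim matches the standard two-term parameter-shift rule: for a gate exp(-iθP/2) generated by a Pauli operator P, the derivative of an expectation value f(θ) is exactly [f(θ + s) - f(θ - s)] / (2 sin s), which for s = π/2 becomes [f(θ + π/2) - f(θ - π/2)] / 2. The sketch below checks this on the closed-form expectation ⟨Z⟩ = cos θ of an RX(θ) rotation applied to |0⟩, whose derivative is -sin θ. On hardware each f(·) would be estimated from circuit executions; the functions here are plain Python and not part of qgans_pro.

```python
import math


def parameter_shift_grad(expectation, theta, shift=math.pi / 2):
    """Exact gradient of expectation(theta) for a Pauli-generated rotation gate."""
    return (expectation(theta + shift) - expectation(theta - shift)) / (2 * math.sin(shift))


def rx_z_expectation(theta):
    """<Z> after applying RX(theta) to |0>: cos(theta)."""
    return math.cos(theta)


theta = 0.3
grad = parameter_shift_grad(rx_z_expectation, theta)
print(grad, -math.sin(theta))  # the two values agree up to floating-point rounding
```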

Callbacks

Training callbacks for monitoring and control:

from qgans_pro.training.callbacks import (
    EarlyStopping,
    ModelCheckpoint,
    QuantumMetricsLogger,
    TensorBoardLogger
)

# Early stopping based on FID score
early_stopping = EarlyStopping(
    monitor="fid_score",
    patience=20,
    min_delta=0.1,
    mode="min"
)

# Save best model checkpoints
checkpoint = ModelCheckpoint(
    filepath="best_model.pth",
    monitor="fid_score",
    save_best_only=True,
    mode="min"
)

# Log quantum-specific metrics
quantum_logger = QuantumMetricsLogger(
    log_interval=10,
    metrics=["quantum_fidelity", "entanglement_measure"]
)

# TensorBoard integration
tensorboard = TensorBoardLogger(
    log_dir="./logs",
    log_quantum_circuits=True,
    log_samples=True
)

# Use callbacks in training
qgan.train(
    dataloader=dataloader,
    epochs=100,
    callbacks=[early_stopping, checkpoint, quantum_logger, tensorboard]
)
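The EarlyStopping parameters monitor, patience, min_delta and mode follow a common pattern (Keras uses the same names): a new value counts as an improvement only if it beats the best value so far by more than min_delta in the direction given by mode, and training stops once patience consecutive checks pass without improvement. The sketch below implements that rule; it is an assumption about qgans_pro's semantics, and EarlyStoppingSketch is a hypothetical name.

```python
class EarlyStoppingSketch:
    """Stop when the monitored value fails to improve for `patience` checks."""

    def __init__(self, patience=20, min_delta=0.1, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best = None
        self.wait = 0

    def _improved(self, value):
        if self.best is None:
            return True
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def update(self, value):
        """Record a new metric value; return True when training should stop."""
        if self._improved(value):
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStoppingSketch(patience=2, min_delta=0.1, mode="min")
# 9.0 beats 10.0 by more than 0.1; 8.95 does not beat 9.0 by more than 0.1
decisions = [stopper.update(fid) for fid in [10.0, 9.0, 8.95, 8.95]]
print(decisions)  # [False, False, False, True]
```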

Advanced Training Features

Distributed Training

Scale training across multiple devices:

from qgans_pro.training.distributed import DistributedQuantumGAN

# Initialize distributed training
distributed_qgan = DistributedQuantumGAN(
    generator=generator,
    discriminator=discriminator,
    backend="nccl",  # or "gloo" for CPU
    world_size=4,  # Total number of processes (typically one per GPU)
    rank=0  # Rank of this process, from 0 to world_size - 1
)

# Train across multiple GPUs
distributed_qgan.train(dataloader, epochs=100)
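In data-parallel training each of the world_size processes usually sees a disjoint slice of the dataset. The sketch below reproduces the round-robin split used by PyTorch's DistributedSampler without shuffling: indices are padded by wrapping around until their count divides evenly, and rank r takes every world_size-th index starting at r. Whether DistributedQuantumGAN partitions data this way is not stated here; shard_indices is a hypothetical helper.

```python
def shard_indices(n_items, world_size, rank):
    """Indices of the dataset assigned to `rank` out of `world_size` processes."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    total = -(-n_items // world_size) * world_size  # round up to a multiple of world_size
    # Pad by wrapping around so every rank receives the same number of items
    padded = [i % n_items for i in range(total)]
    return padded[rank::world_size]


shards = [shard_indices(10, world_size=4, rank=r) for r in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6, 0], [3, 7, 1]]
```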

Progressive Training

Gradually increase model complexity:

from qgans_pro.training.progressive import ProgressiveQuantumGAN

# Progressive training: start small, grow gradually
progressive_qgan = ProgressiveQuantumGAN(
    initial_qubits=4,
    max_qubits=12,
    growth_schedule="linear",  # or "exponential"
    growth_epochs=20
)

progressive_qgan.train(dataloader, epochs=200)
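The progressive-training parameters suggest a schedule that maps the current epoch to a circuit width between initial_qubits and max_qubits, advancing every growth_epochs epochs. How ProgressiveQuantumGAN actually interprets "linear" and "exponential" is not documented here; the sketch below assumes "linear" adds one qubit per stage and "exponential" doubles the qubit count per stage, both capped at max_qubits. qubits_at_epoch is a hypothetical helper.

```python
def qubits_at_epoch(epoch, initial_qubits=4, max_qubits=12,
                    growth_epochs=20, growth_schedule="linear"):
    """Circuit width at `epoch` (hypothetical interpretation of the schedule names)."""
    stage = epoch // growth_epochs  # number of completed growth stages
    if growth_schedule == "linear":
        n_qubits = initial_qubits + stage       # assumption: +1 qubit per stage
    elif growth_schedule == "exponential":
        n_qubits = initial_qubits * 2 ** stage  # assumption: double per stage
    else:
        raise ValueError(f"unknown growth_schedule: {growth_schedule}")
    return min(n_qubits, max_qubits)


linear = [qubits_at_epoch(e) for e in (0, 19, 20, 100, 199)]
exponential = [qubits_at_epoch(e, growth_schedule="exponential") for e in (0, 20, 40)]
print(linear, exponential)  # [4, 4, 5, 9, 12] [4, 8, 12]
```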

Transfer Learning

Transfer knowledge between quantum models:

from qgans_pro.training.transfer import QuantumTransferLearning

# Transfer learning from pre-trained model
transfer_trainer = QuantumTransferLearning(
    source_model_path="pretrained_qgan.pth",
    target_generator=new_generator,
    freeze_quantum_layers=True,
    fine_tune_epochs=50
)

transfer_trainer.train(dataloader)

Federated Learning

Train quantum GANs with federated learning:

from qgans_pro.training.federated import FederatedQuantumGAN

# Federated quantum GAN training
federated_qgan = FederatedQuantumGAN(
    local_model=local_generator,
    aggregation_method="fedavg",
    privacy_budget=1.0,  # Differential privacy budget
    quantum_secure=True
)

# Train across federated clients
federated_qgan.federated_train(
    client_data_loaders=client_loaders,
    rounds=50,
    local_epochs=5
)
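With aggregation_method="fedavg", the server presumably combines client models with Federated Averaging: each parameter becomes the average of the clients' values, weighted by how many local training examples each client holds. The sketch below shows that aggregation step on flat parameter lists; the differential-privacy noise implied by privacy_budget and the quantum_secure option are not modeled, and fedavg is a hypothetical helper name.

```python
def fedavg(client_params, client_sizes):
    """Average parameter vectors, weighting each client by its number of examples."""
    if len(client_params) != len(client_sizes) or not client_params:
        raise ValueError("need one non-empty parameter list per client")
    total = sum(client_sizes)
    n_params = len(client_params[0])
    return [
        sum(params[i] * size for params, size in zip(client_params, client_sizes)) / total
        for i in range(n_params)
    ]


# Two clients: the second holds 3x as much data, so it dominates the average
global_params = fedavg([[1.0, 2.0], [3.0, 4.0]], client_sizes=[100, 300])
print(global_params)  # [2.5, 3.5]
```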

Best Practices

Training Configuration

Recommended settings for different scenarios:

# For beginners: stable training
config_beginner = {
    "lr_g": 0.0001,
    "lr_d": 0.0002,
    "loss_type": "wgan-gp",
    "lambda_gp": 10.0,
    "n_critic": 5  # Train discriminator 5x per generator step
}

# For research: experimental features
config_research = {
    "lr_g": 0.0002,
    "lr_d": 0.0002,
    "loss_type": "quantum-wgan",
    "lambda_quantum": 1.0,
    "quantum_regularization": True,
    "spectral_normalization": True
}

# For production: robust and efficient
config_production = {
    "lr_g": 0.0001,
    "lr_d": 0.0001,
    "loss_type": "lsgan",
    "mixed_precision": True,
    "gradient_checkpointing": True,
    "memory_efficient": True
}
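The "n_critic": 5 entry in the beginner configuration follows the original WGAN-GP recipe, which trains the critic five times for every generator update. The sketch below lays out the resulting update order over a run of discriminator steps; it illustrates the convention rather than qgans_pro's training loop, and update_order is a hypothetical helper.

```python
def update_order(n_discriminator_steps, n_critic=5):
    """'D' for each discriminator update, 'G' after every n_critic of them."""
    order = []
    for step in range(1, n_discriminator_steps + 1):
        order.append("D")
        if step % n_critic == 0:
            order.append("G")  # generator update after n_critic critic updates
    return order


order = update_order(10, n_critic=5)
print("".join(order))  # DDDDDGDDDDDG
```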

Monitoring and Debugging

Tools for training monitoring:

from qgans_pro.training.monitor import TrainingMonitor

# Comprehensive training monitoring
monitor = TrainingMonitor(
    metrics=["loss", "fid", "quantum_fidelity"],
    visualization=True,
    alerts=True,
    save_frequency=10
)

# Attach the monitor to the trainer
qgan.add_monitor(monitor)

Hyperparameter Optimization

Automated hyperparameter tuning:

from qgans_pro.training.optimization import HyperparameterOptimizer

# Optimize hyperparameters
hp_optimizer = HyperparameterOptimizer(
    search_space={
        "lr_g": (1e-5, 1e-2),
        "lr_d": (1e-5, 1e-2),
        "n_layers": (2, 6),
        "lambda_gp": (1.0, 50.0)
    },
    optimization_method="bayesian",  # or "grid", "random"
    n_trials=100
)

# create_qgan: your own function that builds a QuantumGAN from a sampled config
best_config = hp_optimizer.optimize(
    model_fn=create_qgan,
    dataloader=dataloader,
    epochs=50,
    metric="fid_score"
)
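The search_space passed to HyperparameterOptimizer pairs each name with a (low, high) range. The sketch below shows the simplest of the listed methods, random search, under stated assumptions: integer bounds mean an inclusive integer range, learning rates are sampled log-uniformly because their plausible values span orders of magnitude, and a lower metric (like FID) is better. It does not implement the Bayesian method, and random_search, sample_config and toy_objective are hypothetical stand-ins for training and scoring a real model.

```python
import math
import random


def sample_config(search_space, rng, log_scale=("lr_g", "lr_d")):
    """Draw one configuration from a {name: (low, high)} search space."""
    config = {}
    for name, (low, high) in search_space.items():
        if isinstance(low, int) and isinstance(high, int):
            config[name] = rng.randint(low, high)  # inclusive integer range
        elif name in log_scale:
            config[name] = 10 ** rng.uniform(math.log10(low), math.log10(high))
        else:
            config[name] = rng.uniform(low, high)
    return config


def random_search(objective, search_space, n_trials=100, seed=0):
    """Return the (config, score) pair with the lowest objective value."""
    rng = random.Random(seed)
    best_config, best_score = None, math.inf
    for _ in range(n_trials):
        config = sample_config(search_space, rng)
        score = objective(config)  # stand-in for "train, then measure FID"
        if score < best_score:
            best_config, best_score = config, score
    return best_config, best_score


search_space = {
    "lr_g": (1e-5, 1e-2),
    "lr_d": (1e-5, 1e-2),
    "n_layers": (2, 6),
    "lambda_gp": (1.0, 50.0),
}


def toy_objective(config):
    """Purely illustrative: prefers lr_g near 1e-4 and 3 layers."""
    return abs(math.log10(config["lr_g"]) + 4) + abs(config["n_layers"] - 3)


best_config, best_score = random_search(toy_objective, search_space, n_trials=50)
```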