Training API Reference
This module contains training utilities and trainers for quantum GANs.
QuantumGAN
The main training class for quantum generative adversarial networks.
Class Definition
class QuantumGAN:
    """
    Main trainer for Quantum GANs.
    Handles training of quantum generators and discriminators with
    support for various loss functions, optimizers, and training strategies.
    """
Constructor
def __init__(
self,
generator: nn.Module,
discriminator: nn.Module,
device: str = "cpu",
lr_g: float = 0.0002,
lr_d: float = 0.0002,
beta1: float = 0.5,
beta2: float = 0.999,
loss_type: str = "wgan-gp",
lambda_gp: float = 10.0,
**kwargs
):
Parameters:
generator: Quantum or classical generator model
discriminator: Quantum or classical discriminator model
device: Device to train on ("cpu", "cuda", or "mps")
lr_g: Generator learning rate
lr_d: Discriminator learning rate
beta1: Adam optimizer beta1 parameter
beta2: Adam optimizer beta2 parameter
loss_type: Loss function type ("wgan-gp", "lsgan", "vanilla", or "quantum-wgan")
lambda_gp: Gradient penalty coefficient
**kwargs: Additional training configuration (for example lambda_quantum, gradient_accumulation_steps, mixed_precision, and warmup_epochs in the advanced example below)
Usage Examples
Basic Training Setup
from qgans_pro import QuantumGAN, QuantumGenerator, QuantumDiscriminator
import torch
# Create models
generator = QuantumGenerator(n_qubits=8, n_layers=3, output_dim=784)
discriminator = QuantumDiscriminator(input_dim=784, n_qubits=8, n_layers=3)
# Initialize trainer
qgan = QuantumGAN(
generator=generator,
discriminator=discriminator,
device="cuda",
lr_g=0.0001,
lr_d=0.0002,
loss_type="wgan-gp"
)
# Load data
from qgans_pro.utils import get_data_loader
dataloader = get_data_loader("fashion-mnist", batch_size=64)
# Train the model
qgan.train(
dataloader=dataloader,
epochs=100,
save_interval=10,
sample_interval=5
)
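The `save_interval` and `sample_interval` arguments control how often checkpoints and samples are produced. The exact counting rule is not documented here; the sketch below assumes one common convention (1-based epochs, a hook fires when the epoch number is a multiple of the interval). `hook_epochs` is a hypothetical helper, not part of `qgans_pro`.

```python
def hook_epochs(epochs, interval):
    """Return the 1-based epochs at which an interval-based hook would fire.

    Hypothetical convention: the hook fires whenever the epoch number is a
    multiple of `interval`. This is not taken from the qgans_pro source.
    """
    if interval <= 0:
        raise ValueError("interval must be a positive integer")
    return [epoch for epoch in range(1, epochs + 1) if epoch % interval == 0]


# epochs=100, save_interval=10, sample_interval=5 as in the example above
save_epochs = hook_epochs(100, 10)   # 10, 20, ..., 100
sample_epochs = hook_epochs(100, 5)  # 5, 10, ..., 100
print(len(save_epochs), len(sample_epochs))  # 10 20
```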
Advanced Training Configuration
# Advanced training with custom configuration
qgan = QuantumGAN(
generator=generator,
discriminator=discriminator,
device="cuda",
lr_g=0.0001,
lr_d=0.0002,
loss_type="quantum-wgan",
lambda_gp=10.0,
lambda_quantum=1.0, # Quantum loss weight
gradient_accumulation_steps=4,
mixed_precision=True,
warmup_epochs=5
)
# Train with callbacks
from qgans_pro.training.callbacks import EarlyStopping, ModelCheckpoint
callbacks = [
EarlyStopping(monitor="fid_score", patience=20),
ModelCheckpoint(monitor="fid_score", save_best_only=True)
]
qgan.train(
dataloader=dataloader,
epochs=200,
callbacks=callbacks,
validation_data=val_dataloader  # held-out validation DataLoader, built like dataloader (not shown)
)
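The `gradient_accumulation_steps=4` option is not described further here. The standard technique it names sums scaled gradients from several micro-batches and then applies a single optimizer update, giving a larger effective batch without holding it in memory at once. The pure-Python sketch below shows that technique on a one-parameter least-squares model; `grad_mse` and `accumulated_sgd_step` are hypothetical helpers, not `qgans_pro` functions.

```python
def grad_mse(w, batch):
    """Mean gradient of (w*x - y)**2 with respect to w over a batch of (x, y) pairs."""
    return sum(2.0 * (w * x - y) * x for x, y in batch) / len(batch)


def accumulated_sgd_step(w, micro_batches, lr):
    """One SGD update from gradients accumulated over equally sized micro-batches.

    Each micro-batch gradient is divided by the number of accumulation steps,
    so the accumulated value equals the mean gradient over all samples.
    """
    accum_steps = len(micro_batches)
    grad = 0.0
    for batch in micro_batches:
        grad += grad_mse(w, batch) / accum_steps  # accumulate, no update yet
    return w - lr * grad  # single optimizer step after all micro-batches


data = [(1.0, 2.0), (2.0, 4.1), (3.0, 5.9), (4.0, 8.2)]
micro = [data[i:i + 1] for i in range(4)]  # 4 micro-batches of size 1
w_accum = accumulated_sgd_step(0.5, micro, lr=0.01)
w_full = 0.5 - 0.01 * grad_mse(0.5, data)  # same update computed on the full batch
print(abs(w_accum - w_full) < 1e-12)  # True
```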
Methods
train(dataloader, epochs, **kwargs)
Main training method for the quantum GAN.
Parameters:
dataloader: PyTorch DataLoader with training data
epochs: Number of training epochs
save_interval: Epochs between model saves (default: 10)
sample_interval: Epochs between sample generation (default: 5)
callbacks: List of training callbacks
validation_data: Validation DataLoader (optional)
Returns:
dict: Training history with losses and metrics
train_step(real_data)
Single training step for both generator and discriminator.
Parameters:
real_data: Batch of real training data
Returns:
dict: Dictionary with generator and discriminator losses
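Many WGAN-GP trainers update the discriminator (critic) several times per generator update, which is what the `n_critic` setting in the Best Practices section suggests. The sketch below shows that alternation with hypothetical callables `d_step` and `g_step`; it illustrates the pattern and is not the library's `train_step` implementation.

```python
def alternating_train_step(real_batches, d_step, g_step, n_critic=5):
    """Run n_critic discriminator (critic) updates, then one generator update.

    real_batches: iterator yielding one real batch per critic update.
    d_step(batch) -> float: hypothetical discriminator update returning its loss.
    g_step() -> float: hypothetical generator update returning its loss.
    """
    d_losses = [d_step(next(real_batches)) for _ in range(n_critic)]
    g_loss = g_step()
    return {"d_loss": sum(d_losses) / n_critic, "g_loss": g_loss}


calls = []


def toy_d_step(batch):
    calls.append("d")
    return 1.0


def toy_g_step():
    calls.append("g")
    return 0.5


losses = alternating_train_step(iter(range(10)), toy_d_step, toy_g_step)
print(losses)  # {'d_loss': 1.0, 'g_loss': 0.5}
print(calls)   # ['d', 'd', 'd', 'd', 'd', 'g']
```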
generate_samples(n_samples, **kwargs)
Generate synthetic samples using the trained generator.
Parameters:
n_samples: Number of samples to generate
**kwargs: Additional generation parameters
Returns:
torch.Tensor: Generated samples
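Generating many samples in one call can exhaust memory, so a trainer may produce them in chunks. Whether `generate_samples` does this is not documented; the sketch below only shows the chunking idea around a hypothetical `generator_fn(batch_size)` and a toy stand-in generator.

```python
import random


def generate_in_batches(generator_fn, n_samples, batch_size=64):
    """Collect n_samples outputs from generator_fn, requesting at most batch_size at a time."""
    samples = []
    while len(samples) < n_samples:
        current = min(batch_size, n_samples - len(samples))
        samples.extend(generator_fn(current))
    return samples


rng = random.Random(0)


def toy_generator(batch_size, latent_dim=4):
    """Hypothetical stand-in: maps Gaussian latent vectors to the sum of their entries."""
    return [sum(rng.gauss(0.0, 1.0) for _ in range(latent_dim)) for _ in range(batch_size)]


samples = generate_in_batches(toy_generator, n_samples=150, batch_size=64)
print(len(samples))  # 150 (chunks of 64, 64, 22)
```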
evaluate(test_loader, metrics=None)
Evaluate the trained model on test data.
Parameters:
test_loader: Test data loader
metrics: List of metrics to compute
Returns:
dict: Evaluation results
save_checkpoint(path, **metadata)
Save model checkpoint with metadata.
Parameters:
path: Path to save checkpoint
**metadata: Additional metadata to save
load_checkpoint(path)
Load model checkpoint.
Parameters:
path: Path to checkpoint file
Returns:
dict: Loaded metadata
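`save_checkpoint` stores model state together with arbitrary metadata, and `load_checkpoint` gives that metadata back. A minimal sketch of the round trip using the standard library's `pickle` follows. The file layout (`"state"` and `"metadata"` keys) and the helper functions are assumptions; unlike the documented method, this `load_checkpoint` returns the state as well so the round trip can be checked, and a PyTorch trainer would more likely call `torch.save`/`torch.load` on `state_dict()` objects.

```python
import os
import pickle
import tempfile


def save_checkpoint(path, state, **metadata):
    """Write model state plus free-form metadata to one file (hypothetical layout)."""
    with open(path, "wb") as f:
        pickle.dump({"state": state, "metadata": metadata}, f)


def load_checkpoint(path):
    """Read a checkpoint written by save_checkpoint and return (state, metadata)."""
    with open(path, "rb") as f:
        checkpoint = pickle.load(f)
    return checkpoint["state"], checkpoint["metadata"]


path = os.path.join(tempfile.mkdtemp(), "qgan_epoch10.pkl")
save_checkpoint(path, {"generator": [0.1, 0.2]}, epoch=10, loss_type="wgan-gp")
state, meta = load_checkpoint(path)
print(meta)  # {'epoch': 10, 'loss_type': 'wgan-gp'}
```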
Properties
training_history
Get the complete training history.
Returns:
dict: Training metrics over time
current_epoch
Get the current training epoch.
Returns:
int: Current epoch number
is_trained
Check if the model has been trained.
Returns:
bool: True if model is trained
HybridGAN
Trainer for hybrid quantum-classical GANs.
Usage Examples
from qgans_pro import HybridGAN, QuantumGenerator, ClassicalDiscriminator
# Hybrid configuration: Quantum generator + Classical discriminator
generator = QuantumGenerator(n_qubits=8, n_layers=3, output_dim=784)
discriminator = ClassicalDiscriminator(input_dim=784)
hybrid_gan = HybridGAN(
generator=generator,
discriminator=discriminator,
quantum_classical_balance=0.7, # 70% quantum, 30% classical
device="cuda"
)
# Train hybrid model
hybrid_gan.train(dataloader, epochs=100)
Training Utilities
Schedulers
Learning rate schedulers optimized for quantum GANs:
from qgans_pro.training.schedulers import QuantumAwareScheduler
# Quantum-aware learning rate scheduling
# `optimizer` is an existing optimizer (e.g. optimizer_g from the Optimizers section)
scheduler = QuantumAwareScheduler(
optimizer=optimizer,
mode="cosine_annealing",
quantum_epochs=50, # Epochs to focus on quantum training
classical_epochs=50 # Epochs to focus on classical training
)
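# Use in training loop
for epoch in range(epochs):
    for batch in dataloader:
        # Training step on one batch
        losses = qgan.train_step(batch)
    # Update learning rates once per epoch
    # (assumes train_step reports a 'quantum_loss' entry)
    scheduler.step(losses['quantum_loss'])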
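The `mode="cosine_annealing"` option presumably follows the standard cosine-annealing schedule, in which the learning rate falls from a maximum to a minimum along half a cosine period: lr_t = lr_min + 0.5 * (lr_max - lr_min) * (1 + cos(pi * t / T)). The sketch below computes that standard schedule; how `QuantumAwareScheduler` divides it between `quantum_epochs` and `classical_epochs` is not documented here, so that part is not modeled.

```python
import math


def cosine_annealing_lr(step, total_steps, lr_max, lr_min=0.0):
    """Standard cosine annealing: lr_max at step 0, lr_min at step total_steps."""
    step = min(step, total_steps)  # hold at lr_min once the schedule has finished
    cosine = 0.5 * (1.0 + math.cos(math.pi * step / total_steps))
    return lr_min + (lr_max - lr_min) * cosine


schedule = [cosine_annealing_lr(t, 50, lr_max=1e-4) for t in range(51)]
print(schedule[0], schedule[25], schedule[50])  # roughly 1e-4, 5e-5, 0.0
```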
Loss Functions
Specialized loss functions for quantum GANs:
from qgans_pro.training.losses import QuantumWassersteinLoss, QuantumHingeLoss
# Quantum Wasserstein loss with gradient penalty
loss_fn = QuantumWassersteinLoss(
lambda_gp=10.0,
lambda_quantum=1.0,
quantum_regularization=True
)
# Quantum hinge loss
hinge_loss = QuantumHingeLoss(
margin=1.0,
quantum_margin_scaling=True
)
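For reference, the classical losses these classes build on are the WGAN-GP critic loss, E[D(fake)] - E[D(real)] + lambda_gp * E[(||grad D(x_hat)|| - 1)^2], and the hinge discriminator loss, E[max(0, margin - D(real))] + E[max(0, margin + D(fake))]. The sketch below evaluates both from precomputed critic scores and gradient norms. The quantum-specific terms (`lambda_quantum`, `quantum_regularization`, `quantum_margin_scaling`) are not modeled because their definitions are not given here.

```python
def mean(values):
    return sum(values) / len(values)


def wgan_gp_critic_loss(real_scores, fake_scores, grad_norms, lambda_gp=10.0):
    """Classical WGAN-GP critic loss.

    grad_norms: L2 norms of the critic's input gradient at points interpolated
    between real and fake samples (computed elsewhere, e.g. with autograd).
    """
    wasserstein = mean(fake_scores) - mean(real_scores)
    penalty = mean([(g - 1.0) ** 2 for g in grad_norms])
    return wasserstein + lambda_gp * penalty


def hinge_discriminator_loss(real_scores, fake_scores, margin=1.0):
    """Classical hinge loss for the discriminator."""
    real_term = mean([max(0.0, margin - s) for s in real_scores])
    fake_term = mean([max(0.0, margin + s) for s in fake_scores])
    return real_term + fake_term


print(wgan_gp_critic_loss([1.0, 3.0], [0.0, 0.0], [1.0, 1.5]))  # -2.0 + 10 * 0.125 = -0.75
print(hinge_discriminator_loss([2.0, 0.5], [-2.0, 0.0]))       # 0.25 + 0.5 = 0.75
```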
Optimizers
Quantum-aware optimizers:
import numpy as np
from qgans_pro.training.optimizers import QuantumAdam, ParameterShiftOptim
# Quantum-aware Adam optimizer
optimizer_g = QuantumAdam(
generator.parameters(),
lr=0.0001,
quantum_lr_scaling=True,
parameter_shift_rule=True
)
# Parameter-shift rule optimizer for quantum circuits
optimizer_quantum = ParameterShiftOptim(
generator.quantum_parameters(),
lr=0.01,
shift_value=np.pi/2
)
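`shift_value=np.pi/2` matches the standard parameter-shift rule: for a gate exp(-i * theta * P / 2) with a Pauli generator P, the derivative of an expectation value f(theta) is exactly [f(theta + pi/2) - f(theta - pi/2)] / 2. The sketch below checks the rule on the single-qubit case f(theta) = cos(theta), the expectation of Z after RX(theta) acts on |0>. It illustrates the rule itself, not the internals of `ParameterShiftOptim`.

```python
import math


def expectation_z_after_rx(theta):
    """<Z> measured after RX(theta) acts on |0>; analytically equal to cos(theta)."""
    return math.cos(theta)


def parameter_shift_grad(f, theta, shift=math.pi / 2):
    """Parameter-shift derivative of f at theta.

    Exact for expectation values of gates exp(-i * theta * P / 2) with a Pauli
    generator P; with shift = pi/2 the denominator 2*sin(shift) equals 2.
    """
    return (f(theta + shift) - f(theta - shift)) / (2.0 * math.sin(shift))


theta = 0.7
print(parameter_shift_grad(expectation_z_after_rx, theta))  # matches -sin(0.7)
print(-math.sin(theta))
```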
Callbacks
Training callbacks for monitoring and control:
from qgans_pro.training.callbacks import (
EarlyStopping,
ModelCheckpoint,
QuantumMetricsLogger,
TensorBoardLogger
)
# Early stopping based on FID score
early_stopping = EarlyStopping(
monitor="fid_score",
patience=20,
min_delta=0.1,
mode="min"
)
# Save best model checkpoints
checkpoint = ModelCheckpoint(
filepath="best_model.pth",
monitor="fid_score",
save_best_only=True,
mode="min"
)
# Log quantum-specific metrics
quantum_logger = QuantumMetricsLogger(
log_interval=10,
metrics=["quantum_fidelity", "entanglement_measure"]
)
# TensorBoard integration
tensorboard = TensorBoardLogger(
log_dir="./logs",
log_quantum_circuits=True,
log_samples=True
)
# Use callbacks in training
qgan.train(
dataloader=dataloader,
epochs=100,
callbacks=[early_stopping, checkpoint, quantum_logger, tensorboard]
)
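The `EarlyStopping` arguments (`monitor`, `patience`, `min_delta`, `mode`) follow a widely used convention: training stops once the monitored metric has failed to improve by at least `min_delta` for `patience` consecutive checks. A minimal pure-Python sketch of that logic follows; `EarlyStoppingSketch` is hypothetical, the metric values are made up for illustration, and the library's exact counting rules may differ.

```python
class EarlyStoppingSketch:
    """Stop when the monitored metric has not improved for `patience` checks."""

    def __init__(self, patience=20, min_delta=0.1, mode="min"):
        if mode not in ("min", "max"):
            raise ValueError("mode must be 'min' or 'max'")
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.best = None
        self.wait = 0  # checks since the last sufficient improvement

    def _improved(self, value):
        if self.best is None:
            return True
        if self.mode == "min":
            return value < self.best - self.min_delta
        return value > self.best + self.min_delta

    def should_stop(self, value):
        """Record a new metric value; return True once patience is exhausted."""
        if self._improved(value):
            self.best = value
            self.wait = 0
        else:
            self.wait += 1
        return self.wait >= self.patience


stopper = EarlyStoppingSketch(patience=3, min_delta=0.1, mode="min")
fid_history = [50.0, 45.0, 44.95, 44.98, 44.97]  # illustrative values only
decisions = [stopper.should_stop(v) for v in fid_history]
print(decisions)  # [False, False, False, False, True]
```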
Advanced Training Features
Distributed Training
Scale training across multiple devices:
from qgans_pro.training.distributed import DistributedQuantumGAN
# Initialize distributed training
distributed_qgan = DistributedQuantumGAN(
generator=generator,
discriminator=discriminator,
backend="nccl", # or "gloo" for CPU
world_size=4, # Total number of processes (one per GPU)
rank=0 # This process's rank, 0 to world_size - 1 (set differently in each process)
)
# Train across multiple GPUs
distributed_qgan.train(dataloader, epochs=100)
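In data-parallel training, each of the `world_size` processes usually sees a disjoint shard of the data, and gradients are averaged across processes after each step. The sketch below shows round-robin sharding by `rank` (the strided pattern PyTorch's `DistributedSampler` uses, minus the padding it adds to equalize shard lengths) and the element-wise gradient average an all-reduce produces. Whether `DistributedQuantumGAN` shards exactly this way is an assumption.

```python
def shard_indices(n_items, rank, world_size):
    """Indices assigned to `rank` under round-robin sharding across world_size processes."""
    if not 0 <= rank < world_size:
        raise ValueError("rank must be in [0, world_size)")
    return list(range(rank, n_items, world_size))


def average_gradients(per_rank_grads):
    """Element-wise mean of gradient vectors, as an all-reduce sum divided by world_size gives."""
    world_size = len(per_rank_grads)
    return [sum(values) / world_size for values in zip(*per_rank_grads)]


shards = [shard_indices(10, rank, 4) for rank in range(4)]
print(shards)  # [[0, 4, 8], [1, 5, 9], [2, 6], [3, 7]]
print(average_gradients([[1.0, 2.0], [3.0, 4.0]]))  # [2.0, 3.0]
```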
Progressive Training
Gradually increase model complexity:
from qgans_pro.training.progressive import ProgressiveQuantumGAN
# Progressive training: start small, grow gradually
progressive_qgan = ProgressiveQuantumGAN(
initial_qubits=4,
max_qubits=12,
growth_schedule="linear", # or "exponential"
growth_epochs=20
)
progressive_qgan.train(dataloader, epochs=200)
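The exact growth rule of `ProgressiveQuantumGAN` is not specified here. One plausible reading of `growth_schedule` and `growth_epochs`, sketched below as a hypothetical helper, is that the qubit count changes every `growth_epochs` epochs: by one qubit per stage for `"linear"`, or doubling per stage for `"exponential"`, capped at `max_qubits`.

```python
def qubits_at_epoch(epoch, initial_qubits=4, max_qubits=12,
                    growth_schedule="linear", growth_epochs=20):
    """Hypothetical qubit count at a 0-based epoch under a staged growth schedule."""
    stage = epoch // growth_epochs  # number of completed growth periods
    if growth_schedule == "linear":
        n_qubits = initial_qubits + stage
    elif growth_schedule == "exponential":
        n_qubits = initial_qubits * 2 ** stage
    else:
        raise ValueError("growth_schedule must be 'linear' or 'exponential'")
    return min(n_qubits, max_qubits)


print([qubits_at_epoch(e) for e in (0, 20, 40, 199)])  # [4, 5, 6, 12]
print([qubits_at_epoch(e, growth_schedule="exponential") for e in (0, 20, 40)])  # [4, 8, 12]
```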
Transfer Learning
Transfer knowledge between quantum models:
from qgans_pro.training.transfer import QuantumTransferLearning
# Transfer learning from pre-trained model
transfer_trainer = QuantumTransferLearning(
source_model_path="pretrained_qgan.pth",
target_generator=new_generator,
freeze_quantum_layers=True,
fine_tune_epochs=50
)
transfer_trainer.train(dataloader)
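With `freeze_quantum_layers=True`, pre-trained parameters are loaded and then excluded from updates during fine-tuning. The dictionary-based sketch below shows that idea; the `quantum.` name prefix used to select frozen parameters is a hypothetical convention, and in PyTorch the same effect is normally achieved by setting `requires_grad = False` on the frozen parameters.

```python
def transfer_and_freeze(source_params, target_params, frozen_prefix="quantum."):
    """Copy matching parameters from source into target; return (params, frozen names)."""
    params = dict(target_params)
    for name, value in source_params.items():
        if name in params:
            params[name] = value  # reuse pre-trained values where names match
    frozen = {name for name in params if name.startswith(frozen_prefix)}
    return params, frozen


def sgd_update(params, grads, frozen, lr=0.1):
    """Apply gradient updates only to parameters that are not frozen."""
    return {
        name: value if name in frozen else value - lr * grads.get(name, 0.0)
        for name, value in params.items()
    }


source = {"quantum.theta": 0.5, "head.weight": 2.0}
target = {"quantum.theta": 0.0, "head.weight": 0.0, "head.bias": 0.0}
params, frozen = transfer_and_freeze(source, target)
updated = sgd_update(params, {"quantum.theta": 1.0, "head.weight": 1.0, "head.bias": 1.0}, frozen)
print(updated)  # {'quantum.theta': 0.5, 'head.weight': 1.9, 'head.bias': -0.1}
```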
Federated Learning
Train quantum GANs with federated learning:
from qgans_pro.training.federated import FederatedQuantumGAN
# Federated quantum GAN training
federated_qgan = FederatedQuantumGAN(
local_model=local_generator,
aggregation_method="fedavg",
privacy_budget=1.0, # Differential privacy
quantum_secure=True
)
# Train across federated clients
federated_qgan.federated_train(
client_data_loaders=client_loaders,
rounds=50,
local_epochs=5
)
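With `aggregation_method="fedavg"`, the server presumably combines client models by federated averaging: each parameter becomes the mean of the client values, weighted by how many training examples each client holds. The sketch below implements that weighted average on plain dictionaries of floats; the differential-privacy (`privacy_budget`) and `quantum_secure` options are not modeled.

```python
def fedavg(client_params, client_sizes):
    """Federated averaging: sample-size-weighted mean of client parameter dicts."""
    total = sum(client_sizes)
    if total <= 0:
        raise ValueError("client_sizes must sum to a positive number")
    averaged = {}
    for name in client_params[0]:
        averaged[name] = sum(
            params[name] * size for params, size in zip(client_params, client_sizes)
        ) / total
    return averaged


clients = [{"theta_0": 0.0, "theta_1": 1.0}, {"theta_0": 1.0, "theta_1": 3.0}]
print(fedavg(clients, client_sizes=[100, 300]))  # {'theta_0': 0.75, 'theta_1': 2.5}
```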
Best Practices
Training Configuration
Recommended settings for different scenarios:
# For beginners: stable training
config_beginner = {
"lr_g": 0.0001,
"lr_d": 0.0002,
"loss_type": "wgan-gp",
"lambda_gp": 10.0,
"n_critic": 5 # Train discriminator 5x per generator step
}
# For research: experimental features
config_research = {
"lr_g": 0.0002,
"lr_d": 0.0002,
"loss_type": "quantum-wgan",
"lambda_quantum": 1.0,
"quantum_regularization": True,
"spectral_normalization": True
}
# For production: robust and efficient
config_production = {
"lr_g": 0.0001,
"lr_d": 0.0001,
"loss_type": "lsgan",
"mixed_precision": True,
"gradient_checkpointing": True,
"memory_efficient": True
}
Monitoring and Debugging
Tools for training monitoring:
from qgans_pro.training.monitor import TrainingMonitor
# Comprehensive training monitoring
monitor = TrainingMonitor(
metrics=["loss", "fid", "quantum_fidelity"],
visualization=True,
alerts=True,
save_frequency=10
)
# Add to training loop
qgan.add_monitor(monitor)
Hyperparameter Optimization
Automated hyperparameter tuning:
from qgans_pro.training.optimization import HyperparameterOptimizer
# Optimize hyperparameters
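hp_optimizer = HyperparameterOptimizer(
search_space={
"lr_g": (1e-5, 1e-2),
"lr_d": (1e-5, 1e-2),
"n_layers": (2, 6),
"lambda_gp": (1.0, 50.0)
},
optimization_method="bayesian", # or "grid", "random"
n_trials=100
)
# create_qgan: user-defined function that builds the models and a QuantumGAN
# from a sampled config (e.g. passing n_layers to the generator and discriminator)
best_config = hp_optimizer.optimize(
model_fn=create_qgan,
dataloader=dataloader,
epochs=50,
metric="fid_score"
)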
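The `search_space` maps each hyperparameter to a (low, high) range. A simple baseline for the `"random"` method is sketched below: learning rates are sampled log-uniformly (their range spans three orders of magnitude), integer ranges uniformly, and the configuration with the lowest score is kept. `toy_objective` stands in for training a model and computing FID; everything here is a hypothetical illustration rather than `HyperparameterOptimizer`'s algorithm.

```python
import math
import random


def sample_config(search_space, rng, log_scale=("lr_g", "lr_d")):
    """Draw one configuration from a {name: (low, high)} search space."""
    config = {}
    for name, (low, high) in search_space.items():
        if isinstance(low, int) and isinstance(high, int):
            config[name] = rng.randint(low, high)  # inclusive integer range
        elif name in log_scale:
            config[name] = math.exp(rng.uniform(math.log(low), math.log(high)))
        else:
            config[name] = rng.uniform(low, high)
    return config


def random_search(search_space, evaluate, n_trials=100, seed=0):
    """Return (best_config, best_score), minimizing evaluate(config)."""
    rng = random.Random(seed)
    best_config, best_score = None, math.inf
    for _ in range(n_trials):
        config = sample_config(search_space, rng)
        score = evaluate(config)
        if score < best_score:
            best_config, best_score = config, score
    return best_config, best_score


def toy_objective(config):
    """Illustrative stand-in for 'train, then compute FID'; lower is better."""
    return abs(math.log10(config["lr_g"]) + 4.0) + abs(config["n_layers"] - 3)


space = {"lr_g": (1e-5, 1e-2), "lr_d": (1e-5, 1e-2), "n_layers": (2, 6), "lambda_gp": (1.0, 50.0)}
best, score = random_search(space, toy_objective, n_trials=50)
print(best, score)
```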