Training Models: Complete Guide
Master the art of training quantum GANs with comprehensive strategies, advanced techniques, and practical examples for optimal performance.
🚀 Getting Started with Training
Basic Training Setup
import torch
from qgans_pro import (
QuantumGenerator, QuantumDiscriminator, QuantumGAN,
get_data_loader
)
# Run settings: pick the GPU when PyTorch can see one, otherwise the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_qubits = 8
n_layers = 3
batch_size = 64
epochs = 100

# Data loader for the "fashion-mnist" dataset (train=True, download=True)
train_loader = get_data_loader(
    dataset_name="fashion-mnist",
    batch_size=batch_size,
    train=True,
    download=True,
)

# Generator and discriminator share the qubit/layer counts defined above and
# both use backend="qiskit" with device="aer_simulator"
generator = QuantumGenerator(
    n_qubits=n_qubits,
    n_layers=n_layers,
    output_dim=784,  # 28x28 images
    backend="qiskit",
    device="aer_simulator",
)
discriminator = QuantumDiscriminator(
    input_dim=784,
    n_qubits=n_qubits,
    n_layers=n_layers,
    backend="qiskit",
    device="aer_simulator",
)

# Trainer wrapping both networks
qgan = QuantumGAN(
    generator=generator,
    discriminator=discriminator,
    device=device,
)

# Launch training with save_interval=10 and sample_interval=5
qgan.train(
    dataloader=train_loader,
    epochs=epochs,
    save_interval=10,
    sample_interval=5,
)
⚙️ Training Configuration
Learning Rate Strategies
Different learning rate strategies for optimal convergence:
# Cosine schedule with warmup; Adam betas are set explicitly
training_config = {
    "lr_g": 0.0002,  # generator learning rate
    "lr_d": 0.0002,  # discriminator learning rate
    "beta1": 0.5,  # Adam beta1
    "beta2": 0.999,  # Adam beta2
    "scheduler": "cosine",  # learning rate scheduler
    "warmup_epochs": 10,
}

# Step schedule: step_size=30, gamma=0.5
step_config = {
    "lr_g": 0.001,
    "lr_d": 0.001,
    "scheduler": "step",
    "step_size": 30,
    "gamma": 0.5,
}

# Exponential schedule: gamma=0.95
exp_config = {
    "lr_g": 0.002,
    "lr_d": 0.002,
    "scheduler": "exponential",
    "gamma": 0.95,
}
Training Regimens
Standard Training
def standard_training(qgan, dataloader, epochs=100):
    """Run alternating discriminator/generator updates over the data.

    Args:
        qgan: Object exposing ``train_discriminator(real_data)`` and
            ``train_generator(batch_size=...)``; both return a loss that can
            be formatted with ``:.4f``.
        dataloader: Iterable yielding ``(real_data, label)`` pairs; the
            second item is ignored.
        epochs: Number of full passes over ``dataloader``.
    """
    for epoch in range(epochs):
        for batch_idx, batch in enumerate(dataloader):
            real_data, _ = batch

            # Discriminator first, then a generator step of matching batch size
            d_loss = qgan.train_discriminator(real_data)
            g_loss = qgan.train_generator(batch_size=len(real_data))

            # Report only on every 100th batch
            if batch_idx % 100 != 0:
                continue
            print(f"Epoch {epoch}, Batch {batch_idx}: G_loss={g_loss:.4f}, D_loss={d_loss:.4f}")
Progressive Training
def progressive_training(qgan, dataloader, stages=((4, 2), (6, 3), (8, 4)), epochs_per_stage=50):
    """Train in stages, switching to a new circuit size before each stage.

    Args:
        qgan: Object exposing ``update_architecture(n_qubits=, n_layers=)``,
            ``train_discriminator(real_data)``,
            ``train_generator(batch_size=)`` and ``evaluate_stage(stage)``.
        dataloader: Iterable yielding ``(real_data, label)`` pairs; the
            label is ignored. It is iterated once per epoch, so it must be
            re-iterable.
        stages: Sequence of ``(n_qubits, n_layers)`` pairs, applied in order.
            The default is an immutable tuple so it cannot be altered
            between calls.
        epochs_per_stage: Number of passes over ``dataloader`` per stage.
    """
    for stage, (n_qubits, n_layers) in enumerate(stages):
        print(f"Stage {stage+1}: {n_qubits} qubits, {n_layers} layers")

        # Switch the model to this stage's circuit size
        qgan.update_architecture(n_qubits=n_qubits, n_layers=n_layers)

        # Alternating updates for the whole stage
        for _epoch in range(epochs_per_stage):
            for real_data, _label in dataloader:
                qgan.train_discriminator(real_data)
                qgan.train_generator(batch_size=len(real_data))

        # Evaluate once the stage has finished training
        qgan.evaluate_stage(stage)
Curriculum Learning
def curriculum_training(qgan, dataloader, difficulty_schedule):
    """Train in phases, each phase using data filtered to one difficulty.

    Args:
        qgan: Object exposing ``train_discriminator(real_data)``,
            ``train_generator(batch_size=)`` and
            ``adjust_learning_rate(epoch, phase)``.
        dataloader: Source loader passed to ``filter_by_difficulty``.
        difficulty_schedule: Sequence of ``(difficulty, epochs)`` pairs,
            processed in order.
    """
    for phase, (difficulty, epochs) in enumerate(difficulty_schedule):
        print(f"Phase {phase+1}: Difficulty {difficulty}")

        # Restrict the data to the current difficulty
        phase_loader = filter_by_difficulty(dataloader, difficulty)

        for epoch in range(epochs):
            for real_data, _label in phase_loader:
                qgan.train_discriminator(real_data)
                qgan.train_generator(batch_size=len(real_data))

            # Let the trainer adapt its learning rate after each epoch
            qgan.adjust_learning_rate(epoch, phase)
🎯 Advanced Training Techniques
Spectral Normalization
Stabilize discriminator training:
from qgans_pro.training import SpectralNormalization
# Discriminator constructed with spectral_norm=True and a target norm of 1.0
discriminator = QuantumDiscriminator(
    input_dim=784,
    n_qubits=8,
    n_layers=3,
    spectral_norm=True,
    spectral_norm_target=1.0,
)
# Custom variant: re-apply spectral normalization before every discriminator step
class SpectralNormalizedQGAN(QuantumGAN):
    """QuantumGAN whose discriminator is spectrally normalized on each update."""

    def train_discriminator(self, real_data):
        """Apply spectral normalization, then run the inherited update.

        Returns:
            Whatever ``QuantumGAN.train_discriminator`` returns for
            ``real_data``.
        """
        self.discriminator.apply_spectral_norm()
        return super().train_discriminator(real_data)
Gradient Penalty
Improve training stability with gradient penalty:
def gradient_penalty(discriminator, real_data, fake_data, device):
    """Compute the WGAN-GP gradient penalty on random real/fake interpolates.

    Args:
        discriminator: Callable mapping a batch tensor to critic outputs.
        real_data: Tensor of real samples; dimension 0 is the batch.
        fake_data: Tensor of generated samples, same shape as ``real_data``.
        device: Device on which the interpolation coefficients are created.

    Returns:
        Scalar tensor equal to the batch mean of ``(||grad||_2 - 1) ** 2``,
        where ``grad`` is the gradient of the discriminator output with
        respect to each interpolated sample. The result stays attached to the
        autograd graph (``create_graph=True``) so it can be back-propagated
        into the discriminator parameters.
    """
    batch = real_data.size(0)

    # One coefficient per sample, shaped (batch, 1, ..., 1) so it broadcasts
    # over any number of trailing dimensions (flat vectors and image batches).
    alpha = torch.rand(batch, *([1] * (real_data.dim() - 1)), device=device)

    # Detach to make the interpolate a fresh leaf: the penalty must only
    # differentiate through the discriminator, not back into the generator.
    interpolated = (alpha * real_data + (1 - alpha) * fake_data).detach()
    interpolated.requires_grad_(True)

    # Get discriminator output
    d_interpolated = discriminator(interpolated)

    # Gradient of the output w.r.t. the interpolated inputs
    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,
        retain_graph=True,
    )[0]

    # reshape (not view) so non-contiguous gradients are handled too
    gradients = gradients.reshape(batch, -1)
    gradient_norm = gradients.norm(2, dim=1)
    return ((gradient_norm - 1) ** 2).mean()
# Use in training
class WGANGPQuantum(QuantumGAN):
    """QuantumGAN whose discriminator is trained with the WGAN-GP objective."""

    def __init__(self, *args, lambda_gp=10, **kwargs):
        """Forward all arguments to ``QuantumGAN`` and store the penalty weight.

        Args:
            lambda_gp: Multiplier applied to the gradient penalty term.
        """
        super().__init__(*args, **kwargs)
        self.lambda_gp = lambda_gp

    def train_discriminator(self, real_data):
        """Run one discriminator update and return its loss as a float.

        Args:
            real_data: Batch of real samples; ``len(real_data)`` sets the
                number of generated samples.
        """
        # Detach the generated batch: this step updates only the
        # discriminator, so there is no reason to back-propagate through
        # (and accumulate gradients in) the generator.
        fake_data = self.generate_samples(len(real_data)).detach()

        # Wasserstein critic terms
        real_loss = -self.discriminator(real_data).mean()
        fake_loss = self.discriminator(fake_data).mean()

        # Gradient penalty
        gp = gradient_penalty(self.discriminator, real_data, fake_data, self.device)

        # Total loss
        d_loss = real_loss + fake_loss + self.lambda_gp * gp

        # Backward pass
        self.optimizer_d.zero_grad()
        d_loss.backward()
        self.optimizer_d.step()
        return d_loss.item()
Self-Attention Mechanisms
Add self-attention to quantum circuits:
class SelfAttentionQuantumGAN(QuantumGAN):
    """QuantumGAN with attention layers added to both networks.

    NOTE(review): ``train_discriminator_with_attention``,
    ``train_generator_with_attention`` and ``visualize_attention_maps`` are
    called below but are not defined in this snippet; confirm they are
    provided by ``QuantumGAN`` or a mixin.
    """

    def __init__(self, *args, attention_layers=(2, 4), **kwargs):
        """Build the base GAN, then attach the attention layers.

        Args:
            attention_layers: Values passed one by one to
                ``add_attention_layer`` on both networks.
        """
        super().__init__(*args, **kwargs)
        # Copy into a per-instance list so instances never share (or mutate)
        # the default value.
        self.attention_layers = list(attention_layers)
        self.setup_attention()

    def setup_attention(self):
        """Setup self-attention mechanisms."""
        for layer in self.attention_layers:
            self.generator.add_attention_layer(layer)
            self.discriminator.add_attention_layer(layer)

    def train_with_attention(self, dataloader, epochs):
        """Training with attention mechanisms.

        Args:
            dataloader: Iterable yielding ``(real_data, label)`` pairs; the
                label is ignored.
            epochs: Number of passes over ``dataloader``.
        """
        for epoch in range(epochs):
            for real_data, _label in dataloader:
                # Train with attention-enhanced features
                self.train_discriminator_with_attention(real_data)
                self.train_generator_with_attention(len(real_data))

            # Visualize attention maps on every 10th epoch
            if epoch % 10 == 0:
                self.visualize_attention_maps(epoch)
📊 Training Monitoring and Debugging
Comprehensive Monitoring
import numpy as np


class QuantumGANMonitor:
    """Collect per-step training metrics and flag common failure patterns.

    NOTE(review): ``setup_logging``, ``compute_gradient_norm``,
    ``measure_entanglement``, ``compute_fid`` and ``compute_inception_score``
    are called below but are not defined in this snippet; they must be
    supplied elsewhere (for example by a subclass).
    """

    def __init__(self, save_dir="./monitoring"):
        """Create empty metric histories and initialize logging.

        Args:
            save_dir: Stored on the instance as ``self.save_dir``.
        """
        self.save_dir = save_dir
        self.metrics = {
            'generator_loss': [],
            'discriminator_loss': [],
            'gradient_norms': [],
            'entanglement_measures': [],
            'fid_scores': [],
            'inception_scores': []
        }
        self.setup_logging()

    def log_training_step(self, qgan, epoch, batch_idx, g_loss, d_loss):
        """Log metrics for each training step.

        Args:
            qgan: Trainer exposing ``generator`` and ``discriminator``.
            epoch: Current epoch (not used by this method).
            batch_idx: Batch index; controls how often the costly metrics run.
            g_loss: Generator loss for this step.
            d_loss: Discriminator loss for this step.
        """
        # Basic losses
        self.metrics['generator_loss'].append(g_loss)
        self.metrics['discriminator_loss'].append(d_loss)

        # Gradient analysis
        g_grad_norm = self.compute_gradient_norm(qgan.generator)
        d_grad_norm = self.compute_gradient_norm(qgan.discriminator)
        self.metrics['gradient_norms'].append({
            'generator': g_grad_norm,
            'discriminator': d_grad_norm
        })

        # Quantum-specific metrics
        if batch_idx % 50 == 0:  # Compute expensive metrics less frequently
            entanglement = self.measure_entanglement(qgan.generator)
            self.metrics['entanglement_measures'].append(entanglement)

        # Evaluation metrics
        if batch_idx % 200 == 0:
            fid = self.compute_fid(qgan)
            inception_score = self.compute_inception_score(qgan)
            self.metrics['fid_scores'].append(fid)
            self.metrics['inception_scores'].append(inception_score)

    def detect_issues(self):
        """Detect common training issues.

        Returns:
            List of human-readable warnings; empty when nothing is flagged or
            when there is not yet enough history for a given check.
        """
        issues = []

        # Barren plateau: mean generator gradient norm over the last 10
        # logged steps is effectively zero.
        recent_grads = self.metrics['gradient_norms'][-10:]
        if len(recent_grads) >= 10:
            avg_g_grad = np.mean([g['generator'] for g in recent_grads])
            if avg_g_grad < 1e-6:
                issues.append("Potential barren plateau in generator")

        # Mode collapse: the last 5 inception scores barely vary.
        recent_is = self.metrics['inception_scores'][-5:]
        if len(recent_is) >= 5 and np.std(recent_is) < 0.1:
            issues.append("Potential mode collapse detected")

        # Training imbalance: compare the linear trend (slope) of both losses
        # over the same window. Both histories must fill the window, otherwise
        # the fit would be given x and y of different lengths.
        window = 100
        recent_g_loss = self.metrics['generator_loss'][-window:]
        recent_d_loss = self.metrics['discriminator_loss'][-window:]
        if len(recent_g_loss) >= window and len(recent_d_loss) >= window:
            steps = range(window)
            g_trend = np.polyfit(steps, recent_g_loss, 1)[0]
            d_trend = np.polyfit(steps, recent_d_loss, 1)[0]
            if g_trend > 0.01 and d_trend < -0.01:
                issues.append("Discriminator overpowering generator")
            elif g_trend < -0.01 and d_trend > 0.01:
                issues.append("Generator overpowering discriminator")

        return issues
Real-time Visualization
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
class RealTimeVisualizer:
    """Live 2x3 matplotlib dashboard fed by a monitor's metric histories."""

    def __init__(self, qgan, monitor):
        """Create the figure and title every panel.

        Args:
            qgan: Trainer exposing ``generate_samples(n)``.
            monitor: Object exposing a ``metrics`` dict with the keys read in
                ``update_plots``.
        """
        self.qgan = qgan
        self.monitor = monitor
        self.fig, self.axes = plt.subplots(2, 3, figsize=(15, 10))
        self.setup_plots()

    def setup_plots(self):
        """Setup real-time plotting (assign a title to each panel)."""
        self.axes[0, 0].set_title("Training Losses")
        self.axes[0, 1].set_title("Gradient Norms")
        self.axes[0, 2].set_title("Generated Samples")
        self.axes[1, 0].set_title("FID Score")
        self.axes[1, 1].set_title("Inception Score")
        self.axes[1, 2].set_title("Entanglement")

    def update_plots(self, frame):
        """Update plots in real-time.

        Args:
            frame: Frame value supplied by ``FuncAnimation`` (not used).
        """
        # Clear axes
        for ax_row in self.axes:
            for ax in ax_row:
                ax.clear()
        # clear() also removes titles, so restore them for every panel,
        # including panels that have no data yet.
        self.setup_plots()

        metrics = self.monitor.metrics

        # Plot losses. Each series is plotted against its own index so the
        # two histories are not required to have equal length.
        if metrics['generator_loss']:
            self.axes[0, 0].plot(metrics['generator_loss'], label='Generator')
            self.axes[0, 0].plot(metrics['discriminator_loss'], label='Discriminator')
            self.axes[0, 0].legend()

        # Plot gradient norms
        if metrics['gradient_norms']:
            g_norms = [g['generator'] for g in metrics['gradient_norms']]
            d_norms = [g['discriminator'] for g in metrics['gradient_norms']]
            self.axes[0, 1].plot(g_norms, label='Generator')
            self.axes[0, 1].plot(d_norms, label='Discriminator')
            self.axes[0, 1].legend()

        # Show latest generated samples
        samples = self.qgan.generate_samples(16)
        if samples.dim() == 4:  # Image data
            # Imported here because only this branch needs torchvision
            import torchvision

            # detach + cpu before numpy(): numpy() fails on tensors that
            # require grad or live on an accelerator.
            grid = torchvision.utils.make_grid(samples[:16].detach().cpu(), nrow=4)
            self.axes[0, 2].imshow(grid.permute(1, 2, 0).numpy())

        # Plot evaluation metrics
        if metrics['fid_scores']:
            self.axes[1, 0].plot(metrics['fid_scores'])
        if metrics['inception_scores']:
            self.axes[1, 1].plot(metrics['inception_scores'])
        if metrics['entanglement_measures']:
            self.axes[1, 2].plot(metrics['entanglement_measures'])

    def start_animation(self, interval=1000):
        """Start real-time animation.

        Args:
            interval: Passed to ``FuncAnimation`` as ``interval``.

        Returns:
            The ``FuncAnimation`` object; keep a reference to it so it is not
            garbage-collected while the figure is open.
        """
        ani = FuncAnimation(self.fig, self.update_plots, interval=interval)
        plt.show()
        return ani
🔧 Hyperparameter Optimization
Grid Search
def quantum_grid_search(param_grid=None, evaluate_fn=None):
    """Systematic hyperparameter search for quantum GANs.

    Every combination in the grid is scored; lower scores are better.
    Configurations whose evaluation raises are reported and skipped.

    Args:
        param_grid: Mapping of parameter name to the list of values to try.
            Defaults to the built-in grid below.
        evaluate_fn: Callable taking one parameter dict and returning a
            score. Defaults to ``train_and_evaluate_config``, which must be
            defined by the caller's module.

    Returns:
        ``(best_params, best_score, results)`` where ``results`` is a list of
        ``(param_dict, score)`` pairs for each configuration that was scored.
        ``best_params`` is ``None`` and ``best_score`` is ``inf`` when no
        configuration produced a score below infinity.
    """
    import itertools

    if param_grid is None:
        param_grid = {
            'n_qubits': [4, 6, 8, 10],
            'n_layers': [2, 3, 4, 5],
            'learning_rate': [0.0001, 0.0002, 0.0005, 0.001],
            'batch_size': [32, 64, 128],
            'optimizer': ['adam', 'quantum_natural'],
            'backend': ['qiskit', 'pennylane']
        }
    if evaluate_fn is None:
        # Looked up at call time so the helper may be defined after this function
        evaluate_fn = train_and_evaluate_config

    best_params = None
    best_score = float('inf')
    results = []

    total_combinations = 1
    for values in param_grid.values():
        total_combinations *= len(values)
    print(f"Testing {total_combinations} combinations...")

    names = list(param_grid)
    for params in itertools.product(*param_grid.values()):
        param_dict = dict(zip(names, params))
        # Best-effort search: one failing configuration must not abort the
        # sweep, so any error is reported and the search moves on.
        try:
            # Train model with these parameters
            score = evaluate_fn(param_dict)
            results.append((param_dict.copy(), score))
            if score < best_score:
                best_score = score
                best_params = param_dict.copy()
                print(f"New best score: {best_score:.4f} with params: {best_params}")
        except Exception as e:
            print(f"Failed with params {param_dict}: {e}")

    return best_params, best_score, results
Bayesian Optimization
from skopt import gp_minimize
from skopt.space import Real, Integer, Categorical
def bayesian_optimization(failure_penalty=1e6):
    """Bayesian optimization for quantum GAN hyperparameters.

    Args:
        failure_penalty: Finite score reported to the optimizer when a
            configuration raises or yields a non-finite score. The
            Gaussian-process surrogate is fitted on every observed score, so
            reporting ``inf``/``nan`` would make that fit fail.

    Returns:
        ``(best_params, best_score)``: the best configuration found as a dict
        and the objective value it achieved.
    """
    import math

    def objective(params):
        n_qubits, n_layers, lr, batch_size, backend = params
        config = {
            'n_qubits': int(n_qubits),
            'n_layers': int(n_layers),
            'learning_rate': lr,
            'batch_size': int(batch_size),
            'backend': backend
        }
        # Best-effort: a failing configuration is penalized, not fatal
        try:
            score = train_and_evaluate_config(config, epochs=50)
        except Exception as e:
            print(f"Error with config {config}: {e}")
            return failure_penalty
        # A diverged run can return inf/nan; keep the surrogate's data finite
        if not math.isfinite(score):
            return failure_penalty
        return score

    # Define search space
    space = [
        Integer(4, 12, name='n_qubits'),
        Integer(2, 6, name='n_layers'),
        Real(1e-4, 1e-2, name='learning_rate', prior='log-uniform'),
        Integer(16, 128, name='batch_size'),
        Categorical(['qiskit', 'pennylane'], name='backend')
    ]

    # Optimize
    result = gp_minimize(
        func=objective,
        dimensions=space,
        n_calls=50,
        n_initial_points=10,
        random_state=42,
        acq_func='EI'  # Expected Improvement
    )

    # result.x lists the best values in the same order as ``space``
    best_params = dict(zip(['n_qubits', 'n_layers', 'learning_rate', 'batch_size', 'backend'], result.x))
    return best_params, result.fun
🚨 Troubleshooting Common Issues
Barren Plateaus
Symptoms: Vanishing gradients, no learning progress
Solutions:
def handle_barren_plateau(qgan):
    """Apply four countermeasures against barren plateaus, in order.

    Shrinks the generator when it has more than 3 layers, re-initializes both
    networks, switches to a local cost function, and multiplies the learning
    rate of the generator optimizer's first parameter group by 10.

    Returns:
        The same ``qgan`` object, modified in place.
    """
    gen = qgan.generator

    # 1. Shallower circuit (only when deeper than 3 layers)
    if gen.n_layers > 3:
        gen.reduce_layers(target_layers=2)
        print("Reduced circuit depth")

    # 2. Fresh parameters for both networks
    gen.initialize_parameters(method='xavier')
    qgan.discriminator.initialize_parameters(method='xavier')
    print("Reinitialized parameters")

    # 3. Local cost function
    qgan.use_local_cost_function()
    print("Switched to local cost function")

    # 4. Tenfold learning-rate boost on the first parameter group
    first_group = qgan.optimizer_g.param_groups[0]
    original_lr = first_group['lr']
    first_group['lr'] *= 10
    print(f"Increased learning rate from {original_lr} to {qgan.optimizer_g.param_groups[0]['lr']}")

    return qgan
Mode Collapse
Symptoms: Low sample diversity, repeating patterns
Solutions:
def handle_mode_collapse(qgan):
    """Apply five countermeasures against mode collapse, in order.

    Returns:
        The same ``qgan`` object, modified in place.
    """
    disc = qgan.discriminator
    gen = qgan.generator

    # 1. Noisy discriminator inputs
    qgan.add_discriminator_noise(std=0.1)

    # 2. Spectral normalization on the discriminator
    disc.apply_spectral_normalization()

    # 3. One more generator layer
    gen.add_layer()

    # 4. Minibatch discrimination
    qgan.enable_minibatch_discrimination()

    # 5. Two generator steps per discriminator step
    qgan.set_training_ratio(generator_steps=2, discriminator_steps=1)

    return qgan
Training Instability
Symptoms: Oscillating losses, sudden loss spikes
Solutions:
def stabilize_training(qgan):
    """Apply five stabilization measures to ``qgan``, in order.

    Returns:
        The same ``qgan`` object, modified in place.
    """
    # 1. Clip gradients at norm 1.0
    qgan.enable_gradient_clipping(max_norm=1.0)

    # 2. Scale learning rates by 0.5
    qgan.scale_learning_rates(factor=0.5)

    # 3. Exponential moving average with decay 0.999
    qgan.enable_exponential_moving_average(decay=0.999)

    # 4. Smoothed labels: 0.9 for real, 0.1 for fake
    qgan.enable_label_smoothing(real_label=0.9, fake_label=0.1)

    # 5. Switch the optimizer to RMSprop
    qgan.switch_to_rmsprop()

    return qgan
💾 Checkpointing and Recovery
Advanced Checkpointing
import os


class AdvancedCheckpointing:
    """Save and restore full training state for a quantum GAN.

    NOTE(review): ``is_best_model`` and ``cleanup_old_checkpoints`` are called
    by ``save_checkpoint`` but are not defined in this snippet; they must be
    supplied elsewhere (for example by a subclass).
    """

    def __init__(self, save_dir="./checkpoints"):
        """Remember ``save_dir`` and create it if it does not exist."""
        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

    def save_checkpoint(self, qgan, epoch, metrics, optimizer_states):
        """Save comprehensive checkpoint.

        Args:
            qgan: Trainer whose networks, optimizers and config are saved.
            epoch: Epoch number; also used in the checkpoint file name.
            metrics: Metrics stored in the checkpoint and passed to
                ``is_best_model``.
            optimizer_states: Not used; optimizer state is read directly from
                ``qgan.optimizer_g`` and ``qgan.optimizer_d``. Kept so
                existing callers keep working.
        """
        checkpoint = {
            'epoch': epoch,
            'generator_state_dict': qgan.generator.state_dict(),
            'discriminator_state_dict': qgan.discriminator.state_dict(),
            'optimizer_g_state_dict': qgan.optimizer_g.state_dict(),
            'optimizer_d_state_dict': qgan.optimizer_d.state_dict(),
            'metrics': metrics,
            'quantum_circuit_params': qgan.generator.get_quantum_parameters(),
            'training_config': qgan.get_config(),
            'random_state': torch.get_rng_state()
        }

        # Save main checkpoint
        checkpoint_path = os.path.join(self.save_dir, f"checkpoint_epoch_{epoch}.pth")
        torch.save(checkpoint, checkpoint_path)

        # Save best model separately
        if self.is_best_model(metrics):
            best_path = os.path.join(self.save_dir, "best_model.pth")
            torch.save(checkpoint, best_path)

        # Keep only last N checkpoints
        self.cleanup_old_checkpoints(keep_last=5)

    def load_checkpoint(self, qgan, checkpoint_path):
        """Load checkpoint and resume training.

        Args:
            qgan: Trainer whose networks and optimizers are restored in place.
            checkpoint_path: Path of a file written by ``save_checkpoint``.

        Returns:
            ``(epoch, metrics)`` as stored in the checkpoint.
        """
        # SECURITY: torch.load unpickles the file; only load checkpoints that
        # come from a trusted source.
        checkpoint = torch.load(checkpoint_path, map_location=qgan.device)

        # Restore model states
        qgan.generator.load_state_dict(checkpoint['generator_state_dict'])
        qgan.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        qgan.optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
        qgan.optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])

        # Restore quantum parameters
        qgan.generator.set_quantum_parameters(checkpoint['quantum_circuit_params'])

        # Restore random state. map_location may have moved the saved state
        # to an accelerator, but torch.set_rng_state needs a CPU ByteTensor.
        torch.set_rng_state(checkpoint['random_state'].cpu())

        return checkpoint['epoch'], checkpoint['metrics']
🎯 Training Best Practices
Do's
- ✅ Start with small circuits (4-6 qubits, 2-3 layers)
- ✅ Monitor gradient norms to detect barren plateaus early
- ✅ Use appropriate learning rates (typically 0.0001-0.001)
- ✅ Apply regularization to prevent overfitting
- ✅ Save checkpoints frequently for recovery
- ✅ Visualize training progress in real-time
- ✅ Use quantum-aware optimizers when available
Don'ts
- ❌ Don't use overly deep circuits on NISQ devices
- ❌ Don't ignore quantum noise when training on real hardware
- ❌ Don't train both networks equally if one dominates
- ❌ Don't use high learning rates without careful tuning
- ❌ Don't forget to validate on held-out data
- ❌ Don't train without monitoring quantum-specific metrics
Training Schedule Template
def optimal_training_schedule():
    """Recommended training schedule for quantum GANs.

    Returns:
        Dict mapping each phase name (``'warmup'``, ``'growth'``,
        ``'refinement'``, ``'fine_tuning'``, in that order) to its settings:
        ``epochs``, ``lr_g``, ``lr_d``, ``n_qubits``, ``n_layers``, ``focus``.
    """
    fields = ('epochs', 'lr_g', 'lr_d', 'n_qubits', 'n_layers', 'focus')
    phases = (
        # Phase 1: Warmup (Epochs 1-20)
        ('warmup', (20, 0.0001, 0.0001, 4, 2, 'stability')),
        # Phase 2: Growth (Epochs 21-60)
        ('growth', (40, 0.0002, 0.0002, 6, 3, 'capacity')),
        # Phase 3: Refinement (Epochs 61-100)
        ('refinement', (40, 0.0001, 0.0001, 8, 4, 'quality')),
        # Phase 4: Fine-tuning (Epochs 101-120)
        ('fine_tuning', (20, 0.00005, 0.00005, 8, 4, 'polish')),
    )
    return {name: dict(zip(fields, values)) for name, values in phases}
Training Efficiency
Use mixed precision training and gradient accumulation for better memory efficiency with large quantum circuits.
Hardware Considerations
When training on real quantum hardware, factor in queue times and shot noise in your training schedule.
Scalability
The cost of simulating a quantum circuit on classical hardware grows exponentially with the number of qubits. Plan your computational resources accordingly.