Skip to content

Training Models: Complete Guide

Master the art of training quantum GANs with comprehensive strategies, advanced techniques, and practical examples for optimal performance.

🚀 Getting Started with Training

Basic Training Setup

import torch
from qgans_pro import (
    QuantumGenerator, QuantumDiscriminator, QuantumGAN,
    get_data_loader
)

# Configuration. `device` is the torch device for the classical side of
# training; CUDA is used whenever it is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
n_qubits, n_layers = 8, 3      # circuit width and depth shared by both models
batch_size, epochs = 64, 100

# Load the Fashion-MNIST training split (download=True).
train_loader = get_data_loader(
    dataset_name="fashion-mnist",
    train=True,
    download=True,
    batch_size=batch_size,
)

# Create models. Both run on the Qiskit backend; the `device="aer_simulator"`
# keyword names the quantum backend's device and is distinct from the torch
# `device` defined above.
generator = QuantumGenerator(
    backend="qiskit",
    device="aer_simulator",
    n_qubits=n_qubits,
    n_layers=n_layers,
    output_dim=784,  # flattened 28x28 image
)

discriminator = QuantumDiscriminator(
    backend="qiskit",
    device="aer_simulator",
    input_dim=784,  # same value as the generator's output_dim
    n_qubits=n_qubits,
    n_layers=n_layers,
)

# Wrap both models in the trainer, which runs on the torch device.
qgan = QuantumGAN(discriminator=discriminator, generator=generator, device=device)

# Train. save_interval / sample_interval are presumably counted in epochs
# (TODO confirm against QuantumGAN.train).
qgan.train(
    epochs=epochs,
    dataloader=train_loader,
    sample_interval=5,
    save_interval=10,
)

⚙️ Training Configuration

Learning Rate Strategies

Three example learning-rate configurations: cosine annealing with warmup, step decay, and exponential decay.

# Adam optimizer with a cosine learning-rate schedule and a warmup phase
training_config = dict(
    lr_g=0.0002,           # generator learning rate
    lr_d=0.0002,           # discriminator learning rate
    beta1=0.5,             # Adam beta1
    beta2=0.999,           # Adam beta2
    scheduler="cosine",    # learning-rate scheduler type
    warmup_epochs=10,      # length of the warmup phase
)

# Step decay: multiply the learning rate by `gamma` every `step_size` epochs
step_config = dict(
    lr_g=0.001,
    lr_d=0.001,
    scheduler="step",
    step_size=30,
    gamma=0.5,
)

# Exponential decay with factor `gamma`
exp_config = dict(
    lr_g=0.002,
    lr_d=0.002,
    scheduler="exponential",
    gamma=0.95,
)

Training Regimens

Standard Training

def standard_training(qgan, dataloader, epochs=100):
    """Alternate one discriminator step and one generator step per batch.

    Args:
        qgan: trainer exposing ``train_discriminator(real_data)`` and
            ``train_generator(batch_size=...)``, each returning a loss.
        dataloader: iterable of ``(real_data, label)`` pairs; labels are ignored.
        epochs: number of full passes over ``dataloader``.

    Both losses are printed for every 100th batch (batch 0, 100, ...).
    """
    log_every = 100
    for epoch in range(epochs):
        for step, (real_data, _) in enumerate(dataloader):
            d_loss = qgan.train_discriminator(real_data)
            # Generate as many fake samples as there are real ones in this batch.
            g_loss = qgan.train_generator(batch_size=len(real_data))

            if not step % log_every:
                print(f"Epoch {epoch}, Batch {step}: G_loss={g_loss:.4f}, D_loss={d_loss:.4f}")

Progressive Training

def progressive_training(qgan, dataloader, stages=((4, 2), (6, 3), (8, 4)), epochs_per_stage=50):
    """Progressive training with increasing model complexity.

    For each stage the architecture is resized, trained for
    ``epochs_per_stage`` epochs, and then evaluated.

    Args:
        qgan: trainer exposing ``update_architecture``, ``train_discriminator``,
            ``train_generator`` and ``evaluate_stage``.
        dataloader: re-iterable source of ``(real_data, label)`` batches;
            it is iterated once per epoch.
        stages: sequence of ``(n_qubits, n_layers)`` pairs, applied in order.
            The default is an immutable tuple, so calls cannot share
            mutable state through it.
        epochs_per_stage: training epochs run at each stage (default 50).
    """
    for stage, (n_qubits, n_layers) in enumerate(stages):
        print(f"Stage {stage+1}: {n_qubits} qubits, {n_layers} layers")

        # Update model architecture
        qgan.update_architecture(n_qubits=n_qubits, n_layers=n_layers)

        # Train for this stage: one discriminator step, then one generator
        # step sized to the real batch.
        for _epoch in range(epochs_per_stage):
            for real_data, _ in dataloader:
                qgan.train_discriminator(real_data)
                qgan.train_generator(batch_size=len(real_data))

        # Evaluate stage
        qgan.evaluate_stage(stage)

Curriculum Learning

def curriculum_training(qgan, dataloader, difficulty_schedule):
    """Curriculum learning with increasing data complexity.

    Args:
        qgan: trainer exposing ``train_discriminator``, ``train_generator`` and
            ``adjust_learning_rate(epoch, phase)``.
        dataloader: source of ``(real_data, label)`` batches, narrowed per
            phase by ``filter_by_difficulty``.
        difficulty_schedule: iterable of ``(difficulty, epochs)`` pairs, one
            per phase, run in order.

    Note:
        ``filter_by_difficulty`` is not defined in this guide and must be
        supplied. Its result is iterated once per epoch, so it should be
        re-iterable (e.g. a DataLoader or list). A one-shot generator would
        be empty after the first epoch.
    """
    for phase, (difficulty, epochs) in enumerate(difficulty_schedule):
        print(f"Phase {phase+1}: Difficulty {difficulty}")

        # Filter data by difficulty
        filtered_loader = filter_by_difficulty(dataloader, difficulty)

        for epoch in range(epochs):
            for real_data, _ in filtered_loader:
                qgan.train_discriminator(real_data)
                qgan.train_generator(batch_size=len(real_data))

            # Adjust the learning rate once per epoch, after all of its
            # batches. Its arguments only change per epoch, so calling it
            # per batch would repeat the same adjustment many times.
            qgan.adjust_learning_rate(epoch, phase)

🎯 Advanced Training Techniques

Spectral Normalization

Stabilize discriminator training:

from qgans_pro.training import SpectralNormalization

# Request spectral normalization through the constructor flags. backend/device
# are not passed here, so the library defaults apply (the setup example
# selects qiskit / aer_simulator explicitly).
discriminator = QuantumDiscriminator(
    spectral_norm=True,
    spectral_norm_target=1.0,
    input_dim=784,
    n_qubits=8,
    n_layers=3,
)

# Custom spectral normalization via a subclass hook
class SpectralNormalizedQGAN(QuantumGAN):
    """QuantumGAN that re-applies spectral normalization to the discriminator
    immediately before every discriminator update."""

    def train_discriminator(self, real_data):
        """Normalize the discriminator, then delegate to
        ``QuantumGAN.train_discriminator`` and return its loss unchanged."""
        # NOTE(review): the mode-collapse troubleshooting code calls
        # `apply_spectral_normalization()` instead; confirm which name the
        # discriminator API actually exposes.
        self.discriminator.apply_spectral_norm()
        return super().train_discriminator(real_data)

Gradient Penalty

Improve training stability with gradient penalty:

def gradient_penalty(discriminator, real_data, fake_data, device=None):
    """Compute the WGAN-GP gradient penalty.

    Points are sampled uniformly on straight lines between paired real and
    fake samples. The penalty is the batch mean of
    ``(||dD(x_hat)/dx_hat||_2 - 1) ** 2``.

    Args:
        discriminator: callable mapping a batch to per-sample scores.
        real_data: tensor of real samples with the batch dimension first and
            any number of trailing dimensions (flat vectors or N x C x H x W).
        fake_data: tensor with the same shape as ``real_data``. It is detached
            here, so the penalty never backpropagates into the generator.
        device: device for the interpolation coefficients. Defaults to
            ``real_data.device``.

    Returns:
        Scalar tensor. It is built with ``create_graph=True``, so it can be
        backpropagated into the discriminator's parameters.
    """
    if device is None:
        device = real_data.device
    batch_size = real_data.size(0)

    # One coefficient per sample, shaped (N, 1, ..., 1) so it broadcasts over
    # every trailing dimension. A fixed (N, 1) shape with expand_as only
    # works for 2-D inputs.
    alpha = torch.rand(batch_size, *([1] * (real_data.dim() - 1)), device=device)

    # Detaching both inputs makes `interpolated` a fresh leaf, so gradients
    # are taken w.r.t. the interpolates only, not the generator graph.
    interpolated = alpha * real_data.detach() + (1 - alpha) * fake_data.detach()
    interpolated.requires_grad_(True)

    # Get discriminator output
    d_interpolated = discriminator(interpolated)

    # Input gradients; create_graph=True keeps the penalty differentiable
    # w.r.t. the discriminator's parameters.
    gradients = torch.autograd.grad(
        outputs=d_interpolated,
        inputs=interpolated,
        grad_outputs=torch.ones_like(d_interpolated),
        create_graph=True,
        retain_graph=True
    )[0]

    # reshape (not view) also handles non-contiguous gradient tensors.
    gradients = gradients.reshape(batch_size, -1)
    gradient_norm = gradients.norm(2, dim=1)
    return ((gradient_norm - 1) ** 2).mean()

# Use in training
class WGANGPQuantum(QuantumGAN):
    """QuantumGAN variant whose discriminator (critic) uses the WGAN-GP loss.

    Args:
        *args, **kwargs: forwarded unchanged to ``QuantumGAN``.
        lambda_gp: weight of the gradient-penalty term (default 10).
    """

    def __init__(self, *args, lambda_gp=10, **kwargs):
        super().__init__(*args, **kwargs)
        self.lambda_gp = lambda_gp

    def train_discriminator(self, real_data):
        """Run one critic update on ``real_data`` and return the loss as a float.

        Loss = -mean(D(real)) + mean(D(fake)) + lambda_gp * gradient_penalty.
        """
        # gradient_penalty creates its interpolation coefficients on
        # self.device, so the real batch must live there as well.
        real_data = real_data.to(self.device)

        # Detach: the critic update must not backpropagate into, or
        # accumulate gradients on, the generator.
        fake_data = self.generate_samples(len(real_data)).detach().to(self.device)

        # Wasserstein critic losses
        real_loss = -self.discriminator(real_data).mean()
        fake_loss = self.discriminator(fake_data).mean()

        # Gradient penalty
        gp = gradient_penalty(self.discriminator, real_data, fake_data, self.device)

        # Total loss
        d_loss = real_loss + fake_loss + self.lambda_gp * gp

        # Backward pass
        self.optimizer_d.zero_grad()
        d_loss.backward()
        self.optimizer_d.step()

        return d_loss.item()

Self-Attention Mechanisms

Add self-attention to quantum circuits:

class SelfAttentionQuantumGAN(QuantumGAN):
    """QuantumGAN that adds attention layers to both models at chosen indices.

    Args:
        *args, **kwargs: forwarded unchanged to ``QuantumGAN``.
        attention_layers: layer indices passed to each model's
            ``add_attention_layer``. Defaults to ``[2, 4]``.
    """

    def __init__(self, *args, attention_layers=None, **kwargs):
        super().__init__(*args, **kwargs)
        # A fresh list per instance. A mutable default argument would be one
        # list object shared (and mutable) by every instance.
        self.attention_layers = [2, 4] if attention_layers is None else list(attention_layers)
        self.setup_attention()

    def setup_attention(self):
        """Add an attention layer at each index in ``self.attention_layers``
        to both the generator and the discriminator."""
        for layer in self.attention_layers:
            self.generator.add_attention_layer(layer)
            self.discriminator.add_attention_layer(layer)

    def train_with_attention(self, dataloader, epochs):
        """Train for ``epochs`` epochs with the attention-enhanced steps.

        Attention maps are visualized once at the end of every 10th epoch
        (epochs 0, 10, 20, ...).
        """
        for epoch in range(epochs):
            for real_data, _ in dataloader:
                self.train_discriminator_with_attention(real_data)
                self.train_generator_with_attention(len(real_data))

            # Once per qualifying epoch, after its batches. Inside the batch
            # loop this would redraw after every batch of that epoch.
            if epoch % 10 == 0:
                self.visualize_attention_maps(epoch)

📊 Training Monitoring and Debugging

Comprehensive Monitoring

class QuantumGANMonitor:
    """Collect per-step QuantumGAN training metrics and flag common failure modes.

    ``self.metrics`` maps each metric name to a list that grows as steps are
    logged. The helpers ``setup_logging``, ``compute_gradient_norm``,
    ``measure_entanglement``, ``compute_fid`` and ``compute_inception_score``
    are not defined in this snippet and must be provided (e.g. by a subclass).
    """

    def __init__(self, save_dir="./monitoring"):
        self.save_dir = save_dir
        self.metrics = {
            'generator_loss': [],
            'discriminator_loss': [],
            'gradient_norms': [],
            'entanglement_measures': [],
            'fid_scores': [],
            'inception_scores': []
        }
        self.setup_logging()

    def log_training_step(self, qgan, epoch, batch_idx, g_loss, d_loss):
        """Record metrics for one training step.

        Losses and gradient norms are logged every call. Entanglement is
        measured when ``batch_idx % 50 == 0``. FID and Inception Score are
        computed when ``batch_idx % 200 == 0``. ``epoch`` is currently unused.
        """
        # Basic losses
        self.metrics['generator_loss'].append(g_loss)
        self.metrics['discriminator_loss'].append(d_loss)

        # Gradient analysis
        g_grad_norm = self.compute_gradient_norm(qgan.generator)
        d_grad_norm = self.compute_gradient_norm(qgan.discriminator)
        self.metrics['gradient_norms'].append({
            'generator': g_grad_norm,
            'discriminator': d_grad_norm
        })

        # Quantum-specific metrics (expensive, so sampled less often)
        if batch_idx % 50 == 0:
            entanglement = self.measure_entanglement(qgan.generator)
            self.metrics['entanglement_measures'].append(entanglement)

        # Evaluation metrics
        if batch_idx % 200 == 0:
            fid = self.compute_fid(qgan)
            inception_score = self.compute_inception_score(qgan)
            self.metrics['fid_scores'].append(fid)
            self.metrics['inception_scores'].append(inception_score)

    def detect_issues(self):
        """Return human-readable warnings derived from recent metrics.

        Checks, in order:
        * barren plateau: mean generator gradient norm over the last 10
          steps is below 1e-6;
        * mode collapse: std of the last 5 Inception Scores is below 0.1;
        * imbalance: linear trends of the last 100 generator and
          discriminator losses have opposite signs with |slope| > 0.01.

        Returns:
            List of warning strings; empty when no check fires or there is
            not yet enough history.
        """
        import numpy as np  # used below; not imported anywhere else in this guide

        issues = []

        # Check for barren plateaus
        recent_grads = self.metrics['gradient_norms'][-10:]
        if len(recent_grads) >= 10:
            avg_g_grad = np.mean([g['generator'] for g in recent_grads])
            if avg_g_grad < 1e-6:
                issues.append("Potential barren plateau in generator")

        # Check for mode collapse.
        # NOTE(review): a flat Inception Score only shows the score stopped
        # changing, which can also mean convergence; consider checking its
        # level as well.
        if len(self.metrics['inception_scores']) >= 5:
            recent_is = self.metrics['inception_scores'][-5:]
            if np.std(recent_is) < 0.1:
                issues.append("Potential mode collapse detected")

        # Check for training imbalance (both histories must be long enough
        # for the fits to use the same window).
        window = 100
        recent_g_loss = self.metrics['generator_loss'][-window:]
        recent_d_loss = self.metrics['discriminator_loss'][-window:]
        if len(recent_g_loss) >= window and len(recent_d_loss) >= window:
            steps = np.arange(window)
            g_trend = np.polyfit(steps, recent_g_loss, 1)[0]
            d_trend = np.polyfit(steps, recent_d_loss, 1)[0]

            if g_trend > 0.01 and d_trend < -0.01:
                issues.append("Discriminator overpowering generator")
            elif g_trend < -0.01 and d_trend > 0.01:
                issues.append("Generator overpowering discriminator")

        return issues

Real-time Visualization

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation

class RealTimeVisualizer:
    """Live 2x3 matplotlib dashboard for a QuantumGAN training run.

    Panels: training losses, gradient norms, a grid of freshly generated
    samples, FID, Inception Score and entanglement. Curves are read from
    ``monitor.metrics`` (the QuantumGANMonitor layout). Samples come from
    ``qgan.generate_samples``.
    """

    def __init__(self, qgan, monitor):
        self.qgan = qgan
        self.monitor = monitor
        self.fig, self.axes = plt.subplots(2, 3, figsize=(15, 10))
        self.setup_plots()

    def setup_plots(self):
        """Setup real-time plotting: give each panel its title."""
        self.axes[0, 0].set_title("Training Losses")
        self.axes[0, 1].set_title("Gradient Norms")
        self.axes[0, 2].set_title("Generated Samples")
        self.axes[1, 0].set_title("FID Score")
        self.axes[1, 1].set_title("Inception Score")
        self.axes[1, 2].set_title("Entanglement")

    @staticmethod
    def _make_image_grid(images, nrow=4):
        """Tile an (N, C, H, W) batch into a (C, rows*H, cols*W) tensor.

        Images fill rows left to right, ``nrow`` per row. Unused cells stay
        zero. This is a torch-only stand-in for ``torchvision.utils.make_grid``
        without padding, so this module needs no torchvision import.
        """
        n, c, h, w = images.shape
        cols = min(nrow, n)
        rows = -(-n // nrow)  # ceiling division
        grid = images.new_zeros((c, rows * h, cols * w))
        for idx in range(n):
            r, col = divmod(idx, nrow)
            grid[:, r * h:(r + 1) * h, col * w:(col + 1) * w] = images[idx]
        return grid

    def update_plots(self, frame):
        """Redraw every panel. FuncAnimation callback; ``frame`` is unused."""
        # ax.clear() also drops titles, so each panel re-sets its title below.
        for ax_row in self.axes:
            for ax in ax_row:
                ax.clear()

        metrics = self.monitor.metrics

        # Plot losses (one point per logged training step)
        if metrics['generator_loss']:
            steps = range(len(metrics['generator_loss']))
            self.axes[0, 0].plot(steps, metrics['generator_loss'], label='Generator')
            self.axes[0, 0].plot(steps, metrics['discriminator_loss'], label='Discriminator')
            self.axes[0, 0].legend()
            self.axes[0, 0].set_title("Training Losses")

        # Plot gradient norms
        if metrics['gradient_norms']:
            g_norms = [g['generator'] for g in metrics['gradient_norms']]
            d_norms = [g['discriminator'] for g in metrics['gradient_norms']]
            self.axes[0, 1].plot(g_norms, label='Generator')
            self.axes[0, 1].plot(d_norms, label='Discriminator')
            self.axes[0, 1].legend()
            self.axes[0, 1].set_title("Gradient Norms")

        # Show latest generated samples. Sampling only feeds the display, so
        # skip autograd. .numpy() also refuses tensors that require grad.
        with torch.no_grad():
            samples = self.qgan.generate_samples(16)
        # NOTE(review): if generate_samples returns flat vectors (the setup
        # example uses output_dim=784), dim() == 2 and nothing is drawn here;
        # reshape to (N, 1, 28, 28) first if image display is wanted.
        if samples.dim() == 4 and len(samples) > 0:
            grid = self._make_image_grid(samples[:16].detach().cpu(), nrow=4)
            image = grid.permute(1, 2, 0).numpy()
            if image.shape[-1] == 1:
                # imshow cannot draw (H, W, 1) arrays; show single-channel
                # grids as 2-D grayscale images.
                self.axes[0, 2].imshow(image[..., 0], cmap="gray")
            else:
                self.axes[0, 2].imshow(image)
            self.axes[0, 2].set_title("Generated Samples")

        # Plot evaluation metrics
        if metrics['fid_scores']:
            self.axes[1, 0].plot(metrics['fid_scores'])
            self.axes[1, 0].set_title("FID Score")

        if metrics['inception_scores']:
            self.axes[1, 1].plot(metrics['inception_scores'])
            self.axes[1, 1].set_title("Inception Score")

        if metrics['entanglement_measures']:
            self.axes[1, 2].plot(metrics['entanglement_measures'])
            self.axes[1, 2].set_title("Entanglement")

    def start_animation(self, interval=1000):
        """Start the animation (redraw every ``interval`` ms) and show the figure.

        Returns the FuncAnimation so callers can keep a reference to it.
        """
        ani = FuncAnimation(self.fig, self.update_plots, interval=interval)
        plt.show()
        return ani

🔧 Hyperparameter Optimization

def quantum_grid_search(param_grid=None, evaluate=None):
    """Exhaustive grid search over quantum GAN hyperparameters.

    Every combination in the Cartesian product of ``param_grid`` is scored,
    and the lowest score wins.

    Args:
        param_grid: mapping from hyperparameter name to a list of candidate
            values. Defaults to the built-in grid of 768 combinations.
        evaluate: callable taking a config dict and returning a score, where
            lower is better. Defaults to ``train_and_evaluate_config``, which
            is looked up when the search starts.

    Returns:
        ``(best_params, best_score, results)``:
        * ``best_params`` is ``None`` if no combination succeeded;
        * ``best_score`` is ``float('inf')`` in that case;
        * ``results`` lists ``(params, score)`` for successful runs, in
          evaluation order.
    """
    import itertools
    import math

    if param_grid is None:
        param_grid = {
            'n_qubits': [4, 6, 8, 10],
            'n_layers': [2, 3, 4, 5],
            'learning_rate': [0.0001, 0.0002, 0.0005, 0.001],
            'batch_size': [32, 64, 128],
            'optimizer': ['adam', 'quantum_natural'],
            'backend': ['qiskit', 'pennylane']
        }
    if evaluate is None:
        evaluate = train_and_evaluate_config

    best_params = None
    best_score = float('inf')
    results = []

    names = list(param_grid)
    total_combinations = math.prod(len(v) for v in param_grid.values())
    print(f"Testing {total_combinations} combinations...")

    for values in itertools.product(*param_grid.values()):
        param_dict = dict(zip(names, values))

        try:
            # Train model with these parameters
            score = evaluate(param_dict)
        except Exception as e:
            # Deliberate best effort: one failing configuration (e.g. a
            # backend error) must not abort the whole sweep.
            print(f"Failed with params {param_dict}: {e}")
            continue

        results.append((param_dict.copy(), score))
        if score < best_score:
            best_score = score
            best_params = param_dict.copy()
            print(f"New best score: {best_score:.4f} with params: {best_params}")

    return best_params, best_score, results

Bayesian Optimization

from skopt import gp_minimize
from skopt.space import Real, Integer, Categorical

def bayesian_optimization(n_calls=50, n_initial_points=10, random_state=42, failure_score=1e6):
    """Bayesian optimization of quantum GAN hyperparameters via ``gp_minimize``.

    Each evaluation calls ``train_and_evaluate_config(config, epochs=50)``.
    Lower scores are better, because gp_minimize minimizes.

    Args:
        n_calls: total number of objective evaluations (default 50).
        n_initial_points: evaluations before the GP surrogate is used (default 10).
        random_state: seed forwarded to ``gp_minimize`` (default 42).
        failure_score: score reported for configurations whose training
            raises. It must be finite, because the Gaussian-process surrogate
            cannot be fitted on infinite values. Pick something worse than
            any real score.

    Returns:
        ``(best_params, best_score)``: the best configuration as a dict keyed
        by search-space dimension name, and its score.
    """
    # Define search space
    space = [
        Integer(4, 12, name='n_qubits'),
        Integer(2, 6, name='n_layers'),
        Real(1e-4, 1e-2, name='learning_rate', prior='log-uniform'),
        Integer(16, 128, name='batch_size'),
        Categorical(['qiskit', 'pennylane'], name='backend')
    ]
    # Single source of truth for the parameter order used by gp_minimize.
    names = [dim.name for dim in space]

    def objective(params):
        raw = dict(zip(names, params))
        config = {
            'n_qubits': int(raw['n_qubits']),
            'n_layers': int(raw['n_layers']),
            'learning_rate': raw['learning_rate'],
            'batch_size': int(raw['batch_size']),
            'backend': raw['backend']
        }

        try:
            return train_and_evaluate_config(config, epochs=50)
        except Exception as e:
            # Best effort: score the configuration as very poor instead of
            # aborting the whole search.
            print(f"Error with config {config}: {e}")
            return failure_score

    # Optimize
    result = gp_minimize(
        func=objective,
        dimensions=space,
        n_calls=n_calls,
        n_initial_points=n_initial_points,
        random_state=random_state,
        acq_func='EI'  # Expected Improvement
    )

    best_params = dict(zip(names, result.x))
    return best_params, result.fun

🚨 Troubleshooting Common Issues

Barren Plateaus

Symptoms: Vanishing gradients, no learning progress

Solutions:

def handle_barren_plateau(qgan, lr_boost=10):
    """Apply a bundle of barren-plateau mitigations to ``qgan`` in place.

    Steps:
    1. If the generator has more than 3 layers, reduce it to 2.
    2. Re-initialize both models' parameters with the 'xavier' method.
    3. Switch to a local cost function.
    4. Multiply the learning rate of every generator optimizer param group
       by ``lr_boost``.

    The boost is not undone here. Callers wanting a temporary boost must
    restore the learning rates themselves, and repeated calls compound.

    Args:
        qgan: trainer exposing ``generator``, ``discriminator``,
            ``optimizer_g`` and ``use_local_cost_function``.
        lr_boost: learning-rate multiplier (default 10).

    Returns:
        The same ``qgan`` object.
    """
    # 1. Reduce circuit depth
    if qgan.generator.n_layers > 3:
        qgan.generator.reduce_layers(target_layers=2)
        print("Reduced circuit depth")

    # 2. Better parameter initialization
    qgan.generator.initialize_parameters(method='xavier')
    qgan.discriminator.initialize_parameters(method='xavier')
    print("Reinitialized parameters")

    # 3. Use local cost functions
    qgan.use_local_cost_function()
    print("Switched to local cost function")

    # 4. Increase the learning rate in every param group, not only the first.
    for group in qgan.optimizer_g.param_groups:
        original_lr = group['lr']
        group['lr'] = original_lr * lr_boost
        print(f"Increased learning rate from {original_lr} to {group['lr']}")

    return qgan

Mode Collapse

Symptoms: Low sample diversity, repeating patterns

Solutions:

def handle_mode_collapse(qgan):
    """Apply a fixed sequence of anti-mode-collapse measures to ``qgan``.

    In order:
    1. add discriminator input noise (std 0.1);
    2. apply spectral normalization to the discriminator;
    3. add a generator layer;
    4. enable minibatch discrimination;
    5. set a 2:1 generator-to-discriminator step ratio.

    Returns the same ``qgan`` object, modified through its own methods.
    """
    # Each remedy is deferred in a lambda so attributes such as
    # qgan.discriminator are looked up only when that step runs.
    remedies = (
        lambda: qgan.add_discriminator_noise(std=0.1),
        lambda: qgan.discriminator.apply_spectral_normalization(),
        lambda: qgan.generator.add_layer(),
        lambda: qgan.enable_minibatch_discrimination(),
        lambda: qgan.set_training_ratio(generator_steps=2, discriminator_steps=1),
    )
    for remedy in remedies:
        remedy()
    return qgan

Training Instability

Symptoms: Oscillating losses, sudden loss spikes

Solutions:

def stabilize_training(qgan):
    """Apply a fixed sequence of stabilization measures to ``qgan``.

    In order:
    1. clip gradients to norm 1.0;
    2. halve both learning rates;
    3. enable a generator EMA with decay 0.999;
    4. smooth labels to 0.9 / 0.1;
    5. switch the optimizers to RMSprop.

    Returns the same ``qgan`` object, modified through its own methods.
    """
    # (method name on qgan, keyword arguments), applied in this order.
    remedies = (
        ("enable_gradient_clipping", {"max_norm": 1.0}),
        ("scale_learning_rates", {"factor": 0.5}),
        ("enable_exponential_moving_average", {"decay": 0.999}),
        ("enable_label_smoothing", {"real_label": 0.9, "fake_label": 0.1}),
        ("switch_to_rmsprop", {}),
    )
    for method_name, kwargs in remedies:
        getattr(qgan, method_name)(**kwargs)
    return qgan

💾 Checkpointing and Recovery

Advanced Checkpointing

class AdvancedCheckpointing:
    """Save and restore full QuantumGAN training state under ``save_dir``.

    ``is_best_model`` and ``cleanup_old_checkpoints`` are not defined in this
    snippet and must be provided (e.g. by a subclass).
    """

    def __init__(self, save_dir="./checkpoints"):
        import os  # not imported elsewhere in this guide

        self.save_dir = save_dir
        os.makedirs(save_dir, exist_ok=True)

    def save_checkpoint(self, qgan, epoch, metrics, optimizer_states):
        """Save a comprehensive checkpoint for ``epoch``.

        The checkpoint is written to ``checkpoint_epoch_{epoch}.pth``. It is
        also copied to ``best_model.pth`` when ``self.is_best_model(metrics)``
        is truthy. Afterwards ``cleanup_old_checkpoints(keep_last=5)`` is
        called.

        Args:
            qgan: trainer whose models, optimizers, quantum parameters and
                config are saved.
            epoch: epoch number stored in the checkpoint and the file name.
            metrics: metrics object stored as-is.
            optimizer_states: unused. Optimizer state is read directly from
                ``qgan.optimizer_g`` / ``qgan.optimizer_d``. The parameter is
                kept for backward compatibility.
        """
        import os

        checkpoint = {
            'epoch': epoch,
            'generator_state_dict': qgan.generator.state_dict(),
            'discriminator_state_dict': qgan.discriminator.state_dict(),
            'optimizer_g_state_dict': qgan.optimizer_g.state_dict(),
            'optimizer_d_state_dict': qgan.optimizer_d.state_dict(),
            'metrics': metrics,
            'quantum_circuit_params': qgan.generator.get_quantum_parameters(),
            'training_config': qgan.get_config(),
            # CPU torch RNG only; CUDA generator state is not captured.
            'random_state': torch.get_rng_state()
        }

        # Save main checkpoint
        checkpoint_path = os.path.join(self.save_dir, f"checkpoint_epoch_{epoch}.pth")
        torch.save(checkpoint, checkpoint_path)

        # Save best model separately
        if self.is_best_model(metrics):
            best_path = os.path.join(self.save_dir, "best_model.pth")
            torch.save(checkpoint, best_path)

        # Keep only last N checkpoints
        self.cleanup_old_checkpoints(keep_last=5)

    def load_checkpoint(self, qgan, checkpoint_path):
        """Load a checkpoint into ``qgan`` and return ``(epoch, metrics)``.

        Model and optimizer states, quantum parameters and the CPU torch RNG
        state are restored.
        """
        # NOTE(security): torch.load unpickles data; only load checkpoints
        # from trusted sources.
        checkpoint = torch.load(checkpoint_path, map_location=qgan.device)

        # Restore model states
        qgan.generator.load_state_dict(checkpoint['generator_state_dict'])
        qgan.discriminator.load_state_dict(checkpoint['discriminator_state_dict'])
        qgan.optimizer_g.load_state_dict(checkpoint['optimizer_g_state_dict'])
        qgan.optimizer_d.load_state_dict(checkpoint['optimizer_d_state_dict'])

        # Restore quantum parameters
        qgan.generator.set_quantum_parameters(checkpoint['quantum_circuit_params'])

        # map_location also moved the RNG state tensor to qgan.device, but
        # torch.set_rng_state only accepts a CPU ByteTensor.
        torch.set_rng_state(checkpoint['random_state'].cpu())

        return checkpoint['epoch'], checkpoint['metrics']

🎯 Training Best Practices

Do's

✅ Start with small circuits (4-6 qubits, 2-3 layers)
✅ Monitor gradient norms to detect barren plateaus early
✅ Use appropriate learning rates (typically 0.0001-0.001)
✅ Apply regularization to prevent overfitting
✅ Save checkpoints frequently for recovery
✅ Visualize training progress in real time
✅ Use quantum-aware optimizers when available

Don'ts

❌ Don't use overly deep circuits on NISQ devices
❌ Don't ignore quantum noise when training on real hardware
❌ Don't train both networks equally if one dominates
❌ Don't use high learning rates without careful tuning
❌ Don't forget to validate on held-out data
❌ Don't train without monitoring quantum-specific metrics

Training Schedule Template

def optimal_training_schedule():
    """Return the recommended four-phase quantum GAN training schedule.

    Phases, in order: warmup, growth, refinement, fine_tuning (120 epochs in
    total). Each phase maps to a dict with the keys epochs, lr_g, lr_d,
    n_qubits, n_layers and focus. Generator and discriminator share the same
    learning rate in every phase.
    """
    # (phase name, epochs, learning rate for both nets, qubits, layers, focus)
    phases = (
        ("warmup", 20, 0.0001, 4, 2, "stability"),      # epochs 1-20
        ("growth", 40, 0.0002, 6, 3, "capacity"),       # epochs 21-60
        ("refinement", 40, 0.0001, 8, 4, "quality"),    # epochs 61-100
        ("fine_tuning", 20, 0.00005, 8, 4, "polish"),   # epochs 101-120
    )
    return {
        name: {
            'epochs': n_epochs,
            'lr_g': lr,
            'lr_d': lr,
            'n_qubits': qubits,
            'n_layers': layers,
            'focus': focus,
        }
        for name, n_epochs, lr, qubits, layers, focus in phases
    }

Training Efficiency

Use mixed precision training and gradient accumulation for better memory efficiency with large quantum circuits.

Hardware Considerations

When training on real quantum hardware, factor in queue times and shot noise in your training schedule.

Scalability

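Classical simulation of quantum circuits becomes exponentially more expensive as qubits are added, because a full state vector for n qubits holds 2^n complex amplitudes. Plan your computational resources accordingly.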