Skip to content

Utils API Reference

This module contains utility functions, data loaders, metrics, and helper tools for quantum GANs.

Data Loading and Preprocessing

get_data_loader

Factory function for creating data loaders for common datasets.

def get_data_loader(
    dataset_name: str,
    batch_size: int = 64,
    train: bool = True,
    download: bool = True,
    transform: Optional[callable] = None,
    quantum_preprocessing: bool = False,
    **kwargs
) -> DataLoader:

Parameters:

  • dataset_name: Name of dataset ('mnist', 'fashion-mnist', 'cifar10', 'celeba')
  • batch_size: Batch size for training
  • train: If True, load the training split; if False, load the test split
  • download: Download the dataset if it is not already available locally
  • transform: Optional callable applied to each sample (for example, a torchvision transforms.Compose pipeline)
  • quantum_preprocessing: Apply quantum-aware preprocessing so samples are ready for quantum encoding (see the Quantum-Aware Preprocessing example)
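  • **kwargs: Additional dataset-specific parameters and preprocessing options, such as amplitude_encoding and normalize_quantum when quantum_preprocessing=True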

Basic Data Loading

from qgans_pro.utils import get_data_loader

# Load Fashion-MNIST
train_loader = get_data_loader(
    dataset_name="fashion-mnist",
    batch_size=64,
    train=True,
    download=True
)

# Load CIFAR-10 with custom transforms
import torchvision.transforms as transforms
transform = transforms.Compose([
    transforms.Resize(32),
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,))
])

cifar_loader = get_data_loader(
    dataset_name="cifar10",
    batch_size=32,
    transform=transform,
    quantum_preprocessing=True
)

Quantum-Aware Preprocessing

# Enable quantum preprocessing for better quantum circuit training
quantum_loader = get_data_loader(
    dataset_name="mnist",
    batch_size=64,
    quantum_preprocessing=True,
    amplitude_encoding=True,  # Prepare for amplitude encoding
    normalize_quantum=True    # Normalize for quantum states
)

prepare_quantum_data

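Prepare classical data for quantum processing. encoding_type selects the encoding ("amplitude" by default, or "angle"), n_qubits sets the number of qubits, and normalization selects how the data is normalized (default "l2"). Amplitude encoding stores at most 2**n_qubits values per sample, so the 784-pixel MNIST vectors below need at least 10 qubits (1024 amplitudes).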

def prepare_quantum_data(
    data: torch.Tensor,
    encoding_type: str = "amplitude",
    n_qubits: int = None,
    normalization: str = "l2"
) -> torch.Tensor:

Data Preparation Examples

from qgans_pro.utils import prepare_quantum_data

# Prepare data for amplitude encoding
classical_data = torch.randn(32, 784)  # Batch of MNIST images
quantum_data = prepare_quantum_data(
    data=classical_data,
    encoding_type="amplitude",
    n_qubits=10,
    normalization="l2"
)

# Prepare for angle encoding
angle_data = prepare_quantum_data(
    data=classical_data,
    encoding_type="angle",
    n_qubits=10
)

Evaluation Metrics

FIDScore

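Fréchet Inception Distance (FID) for generative model evaluation. FID compares the mean and covariance of feature embeddings of real and generated samples; InceptionV3 features are used unless a different feature_extractor is given.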

class FIDScore:
    """
    Compute Fréchet Inception Distance between real and generated data.

    Lower FID scores indicate better generation quality.
    """

FID Score Calculation

from qgans_pro.utils import FIDScore
import torch

# Initialize FID metric
fid_metric = FIDScore(
    device='cuda',
    batch_size=50,
    dims=2048  # InceptionV3 feature dimension
)

# Calculate FID score
real_samples = torch.randn(1000, 3, 64, 64)
generated_samples = qgan.generate_samples(1000)

fid_score = fid_metric(real_samples, generated_samples)
print(f"FID Score: {fid_score:.2f}")

# Calculate with custom feature extractor
custom_fid = FIDScore(
    feature_extractor='resnet50',
    device='cuda'
)
fid_score = custom_fid(real_samples, generated_samples)

InceptionScore

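Inception Score for evaluating generation quality and diversity. Calling the metric returns the mean and standard deviation of the score computed over `splits` subsets of the generated samples.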

class InceptionScore:
    """
    Compute Inception Score for generated samples.

    Higher scores indicate better quality and diversity.
    """

Inception Score Calculation

from qgans_pro.utils import InceptionScore

# Initialize IS metric
is_metric = InceptionScore(
    device='cuda',
    batch_size=32,
    splits=10  # Number of splits for statistical robustness
)

# Calculate Inception Score
generated_samples = qgan.generate_samples(5000)
is_mean, is_std = is_metric(generated_samples)

print(f"Inception Score: {is_mean:.2f} ± {is_std:.2f}")

QuantumFidelity

Quantum-specific fidelity metrics for quantum GANs.

class QuantumFidelity:
    """
    Compute quantum fidelity and related quantum metrics.

    Measures how well quantum properties are preserved during generation.
    """

Quantum Fidelity Examples

from qgans_pro.utils import QuantumFidelity

# Initialize quantum fidelity metric
qf_metric = QuantumFidelity(
    backend='qiskit',
    device='aer_simulator',
    shots=1024
)

# Calculate quantum metrics
quantum_metrics = qf_metric(
    real_data=real_samples,
    generated_data=generated_samples,
    quantum_states=generator.get_quantum_states()
)

print(f"Quantum Fidelity: {quantum_metrics['fidelity']:.4f}")
print(f"Entanglement Measure: {quantum_metrics['entanglement']:.4f}")
print(f"Quantum Volume: {quantum_metrics['quantum_volume']}")

# Advanced quantum analysis
detailed_metrics = qf_metric.detailed_analysis(
    quantum_circuit=generator.circuit,
    measurement_basis='computational'
)

Visualization Tools

plot_generated_samples

Visualize generated samples in a grid layout.

def plot_generated_samples(
    samples: torch.Tensor,
    title: str = "Generated Samples",
    nrow: int = 8,
    padding: int = 2,
    normalize: bool = True,
    save_path: Optional[str] = None
) -> matplotlib.figure.Figure:

Sample Visualization

from qgans_pro.utils import plot_generated_samples
import matplotlib.pyplot as plt

# Generate and visualize samples
samples = qgan.generate_samples(64)
fig = plot_generated_samples(
    samples=samples,
    title="Quantum GAN Generated Fashion-MNIST",
    nrow=8,
    normalize=True,
    save_path="generated_samples.png"
)

plt.show()

plot_training_curves

Visualize training progress and metrics.

def plot_training_curves(
    training_history: Dict[str, List[float]],
    metrics: List[str] = None,
    save_path: Optional[str] = None,
    smoothing: float = 0.9
) -> matplotlib.figure.Figure:

Training Visualization

from qgans_pro.utils import plot_training_curves

# Plot training history
training_history = qgan.get_training_history()
fig = plot_training_curves(
    training_history=training_history,
    metrics=['g_loss', 'd_loss', 'fid_score', 'quantum_fidelity'],
    smoothing=0.9,
    save_path="training_curves.png"
)

QuantumCircuitVisualizer

Specialized visualization for quantum circuits.

class QuantumCircuitVisualizer:
    """
    Visualize quantum circuits used in quantum GANs.

    Provides various visualization options for quantum circuits,
    including circuit diagrams, state evolution, and quantum metrics.
    """

Quantum Circuit Visualization

from qgans_pro.utils import QuantumCircuitVisualizer

# Initialize circuit visualizer
circuit_viz = QuantumCircuitVisualizer(backend='qiskit')

# Visualize generator circuit
generator_circuit = quantum_generator.get_circuit()
circuit_viz.draw_circuit(
    circuit=generator_circuit,
    output='mpl',
    title='Quantum Generator Circuit',
    save_path='generator_circuit.png'
)

# Visualize quantum state evolution
state_evolution = quantum_generator.get_state_evolution()
circuit_viz.animate_state_evolution(
    states=state_evolution,
    save_path='state_evolution.gif'
)

# Create quantum metrics dashboard
circuit_viz.create_metrics_dashboard(
    circuit=generator_circuit,
    metrics=['depth', 'gate_count', 'entanglement'],
    save_path='quantum_dashboard.html'
)

Model Utilities

ModelSummary

Generate detailed model summaries including quantum circuit information.

class ModelSummary:
    """
    Generate comprehensive model summaries for quantum GANs.

    Includes classical neural network parameters, quantum circuit
    properties, and computational complexity analysis.
    """

Model Summary Examples

from qgans_pro.utils import ModelSummary

# Create model summary
summary = ModelSummary()

# Summarize quantum generator
gen_summary = summary.summarize_model(
    model=quantum_generator,
    input_shape=(64, 8),  # Batch size, latent dim
    include_quantum_info=True
)

print(gen_summary)

# Compare multiple models
comparison = summary.compare_models({
    'Quantum Generator': quantum_generator,
    'Classical Generator': classical_generator,
    'Hybrid Generator': hybrid_generator
})

summary.export_comparison(comparison, 'model_comparison.html')

checkpoint_utils

Utilities for saving and loading model checkpoints.

def save_checkpoint(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    epoch: int,
    loss: float,
    path: str,
    metadata: Dict = None
) -> None:

def load_checkpoint(
    path: str,
    model: nn.Module = None,
    optimizer: torch.optim.Optimizer = None
) -> Dict:

Checkpoint Management

from qgans_pro.utils import save_checkpoint, load_checkpoint

# Save checkpoint with metadata
save_checkpoint(
    model=qgan,
    optimizer=optimizer,
    epoch=100,
    loss=0.0456,
    path='qgan_checkpoint_epoch_100.pth',
    metadata={
        'dataset': 'fashion-mnist',
        'n_qubits': 8,
        'fid_score': 15.2,
        'config': training_config
    }
)

# Load checkpoint
checkpoint_data = load_checkpoint(
    path='qgan_checkpoint_epoch_100.pth',
    model=qgan,
    optimizer=optimizer
)

print(f"Loaded model from epoch {checkpoint_data['epoch']}")
print(f"FID Score: {checkpoint_data['metadata']['fid_score']}")

Quantum Utilities

QuantumStateAnalyzer

Analyze quantum states and circuits.

class QuantumStateAnalyzer:
    """
    Analyze quantum states, circuits, and quantum properties.

    Provides tools for understanding quantum behavior in GANs.
    """

Quantum State Analysis

from qgans_pro.utils import QuantumStateAnalyzer

# Initialize analyzer
analyzer = QuantumStateAnalyzer(backend='qiskit')

# Analyze quantum circuit properties
circuit_analysis = analyzer.analyze_circuit(
    circuit=quantum_generator.circuit,
    metrics=['expressibility', 'entangling_capability', 'effective_dimension']
)

print(f"Expressibility: {circuit_analysis['expressibility']:.4f}")
print(f"Entangling Capability: {circuit_analysis['entangling_capability']:.4f}")

# Analyze quantum state properties
state_analysis = analyzer.analyze_quantum_state(
    quantum_state=final_state,
    analysis_type='full'  # 'full', 'entanglement', 'coherence'
)

# Generate quantum circuit recommendations
recommendations = analyzer.suggest_circuit_improvements(
    current_circuit=quantum_generator.circuit,
    target_metrics={'expressibility': 0.8, 'depth': 10}
)

quantum_encoding

Utilities for encoding classical data into quantum states.

def amplitude_encoding(
    data: torch.Tensor,
    n_qubits: int,
    normalization: str = 'l2'
) -> torch.Tensor:

def angle_encoding(
    data: torch.Tensor,
    n_qubits: int,
    encoding_scheme: str = 'basic'
) -> torch.Tensor:

def iqp_encoding(
    data: torch.Tensor,
    n_qubits: int,
    depth: int = 1
) -> torch.Tensor:

Quantum Encoding Examples

from qgans_pro.utils.quantum_encoding import (
    amplitude_encoding, 
    angle_encoding, 
    iqp_encoding
)

# Amplitude encoding
classical_data = torch.randn(32, 256)
amplitude_encoded = amplitude_encoding(
    data=classical_data,
    n_qubits=8,
    normalization='l2'
)

# Angle encoding
angle_encoded = angle_encoding(
    data=classical_data,
    n_qubits=8,
    encoding_scheme='linear'  # 'linear', 'polynomial'
)

# IQP encoding for entangled features
iqp_encoded = iqp_encoding(
    data=classical_data,
    n_qubits=8,
    depth=2
)

Performance and Profiling

PerformanceProfiler

Profile quantum GAN training performance.

class PerformanceProfiler:
    """
    Profile quantum GAN training for performance optimization.

    Monitors computational resources, quantum circuit execution
    times, and identifies bottlenecks.
    """

Performance Profiling

from qgans_pro.utils import PerformanceProfiler

# Initialize profiler
profiler = PerformanceProfiler(
    profile_quantum=True,
    profile_classical=True,
    memory_tracking=True
)

# Profile training step
with profiler.profile_context('training_step'):
    losses = qgan.train_step(batch_data)

# Get profiling results
results = profiler.get_results()
print(f"Quantum circuit time: {results['quantum_time']:.4f}s")
print(f"Classical computation time: {results['classical_time']:.4f}s")
print(f"Memory usage: {results['memory_peak']:.2f} MB")

# Generate performance report
profiler.generate_report('performance_report.html')

MemoryOptimizer

Optimize memory usage for large quantum circuits.

class MemoryOptimizer:
    """
    Optimize memory usage in quantum GAN training.

    Provides strategies for reducing memory footprint while
    maintaining training effectiveness.
    """

Memory Optimization

from qgans_pro.utils import MemoryOptimizer

# Initialize memory optimizer
memory_opt = MemoryOptimizer()

# Apply memory optimizations
optimized_generator = memory_opt.optimize_model(
    model=quantum_generator,
    strategies=['gradient_checkpointing', 'mixed_precision', 'circuit_chunking']
)

# Monitor memory usage
memory_stats = memory_opt.monitor_memory(
    model=optimized_generator,
    dataloader=train_loader,
    num_batches=10
)

print(f"Peak memory: {memory_stats['peak_memory']} MB")
print(f"Average memory: {memory_stats['avg_memory']} MB")

Configuration Management

Config

Configuration management for quantum GANs.

class Config:
    """
    Configuration management system for quantum GANs.

    Handles model configurations, training parameters,
    and experiment settings with validation and serialization.
    """

Configuration Management

from qgans_pro.utils import Config

# Load configuration from file
config = Config.from_file('qgan_config.yaml')

# Create configuration programmatically
config = Config({
    'model': {
        'generator': {
            'n_qubits': 8,
            'n_layers': 3,
            'backend': 'qiskit'
        },
        'discriminator': {
            'hidden_dims': [512, 256, 128],
            'activation': 'leaky_relu'
        }
    },
    'training': {
        'batch_size': 64,
        'learning_rate': 0.0002,
        'epochs': 100
    }
})

# Validate configuration
validation_result = config.validate()
if not validation_result.is_valid:
    print(f"Configuration errors: {validation_result.errors}")

# Save configuration
config.save('experiment_config.yaml')

# Use configuration to create models
generator = config.create_generator()
discriminator = config.create_discriminator()

Best Practices and Helpers

ExperimentTracker

Track experiments and compare results.

class ExperimentTracker:
    """
    Track quantum GAN experiments and results.

    Provides experiment logging, comparison tools,
    and result visualization.
    """

Experiment Tracking

from qgans_pro.utils import ExperimentTracker

# Initialize experiment tracker
tracker = ExperimentTracker(
    project_name='quantum_gan_research',
    save_dir='./experiments'
)

# Start new experiment
experiment_id = tracker.start_experiment(
    name='fashion_mnist_8qubit',
    config=config,
    description='8-qubit quantum generator on Fashion-MNIST'
)

# Log metrics during training
for epoch in range(epochs):
    # Training step
    losses = qgan.train_step(batch)

    # Log metrics
    tracker.log_metrics({
        'epoch': epoch,
        'g_loss': losses['g_loss'],
        'd_loss': losses['d_loss'],
        'fid_score': fid_score
    })

# End experiment
tracker.end_experiment(
    final_metrics={'best_fid': 12.5},
    model_path='final_model.pth'
)

# Compare experiments
comparison = tracker.compare_experiments([
    'fashion_mnist_8qubit',
    'fashion_mnist_6qubit',
    'fashion_mnist_classical'
])

tracker.visualize_comparison(comparison)

reproducibility

Utilities for reproducible experiments.

def set_seed(seed: int) -> None:
    """Set random seeds for reproducible results."""

def get_system_info() -> Dict:
    """Get system information for experiment reproducibility."""

def create_reproducible_environment(seed: int = 42) -> Dict:
    """Create reproducible environment for quantum GAN experiments."""

Reproducibility Setup

from qgans_pro.utils.reproducibility import (
    set_seed, 
    get_system_info, 
    create_reproducible_environment
)

# Set up reproducible environment
env_info = create_reproducible_environment(seed=42)
print(f"Environment: {env_info}")

# Set random seeds
set_seed(42)

# Get system information for logging
system_info = get_system_info()
print(f"System: {system_info['platform']}")
print(f"Python: {system_info['python_version']}")
print(f"PyTorch: {system_info['torch_version']}")
print(f"Quantum backends: {system_info['quantum_backends']}")