Utils API Reference
This module contains utility functions, data loaders, metrics, and helper tools for quantum GANs.
Data Loading and Preprocessing
get_data_loader
Factory function for creating data loaders for common datasets.
def get_data_loader(
dataset_name: str,
batch_size: int = 64,
train: bool = True,
download: bool = True,
transform: Optional[callable] = None,
quantum_preprocessing: bool = False,
**kwargs
) -> DataLoader:
Parameters:
dataset_name: Name of the dataset ('mnist', 'fashion-mnist', 'cifar10', or 'celeba').
batch_size: Number of samples per batch.
train: If True, load the training split; otherwise, load the test split.
download: Download the dataset if it is not already available locally.
transform: Optional data transformation applied to each sample.
quantum_preprocessing: Apply quantum-aware preprocessing.
**kwargs: Additional dataset-specific parameters (e.g. amplitude_encoding, normalize_quantum).
Basic Data Loading
from qgans_pro.utils import get_data_loader
# Load Fashion-MNIST
train_loader = get_data_loader(
dataset_name="fashion-mnist",
batch_size=64,
train=True,
download=True
)
# Load CIFAR-10 with custom transforms
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.Resize(32),
transforms.ToTensor(),
transforms.Normalize((0.5,), (0.5,))
])
cifar_loader = get_data_loader(
dataset_name="cifar10",
batch_size=32,
transform=transform,
quantum_preprocessing=True
)
Quantum-Aware Preprocessing
# Enable quantum preprocessing for better quantum circuit training
quantum_loader = get_data_loader(
dataset_name="mnist",
batch_size=64,
quantum_preprocessing=True,
amplitude_encoding=True, # Prepare for amplitude encoding
normalize_quantum=True # Normalize for quantum states
)
prepare_quantum_data
Prepare classical data for quantum processing.
def prepare_quantum_data(
data: torch.Tensor,
encoding_type: str = "amplitude",
n_qubits: int = None,
normalization: str = "l2"
) -> torch.Tensor:
Data Preparation Examples
from qgans_pro.utils import prepare_quantum_data
# Prepare data for amplitude encoding
classical_data = torch.randn(32, 784) # Batch of MNIST images
quantum_data = prepare_quantum_data(
data=classical_data,
encoding_type="amplitude",
n_qubits=10,
normalization="l2"
)
# Prepare for angle encoding
angle_data = prepare_quantum_data(
data=classical_data,
encoding_type="angle",
n_qubits=10
)
Evaluation Metrics
FIDScore
Fréchet Inception Distance for generative model evaluation.
class FIDScore:
"""
Compute Fréchet Inception Distance between real and generated data.
Lower FID scores indicate better generation quality.
"""
FID Score Calculation
from qgans_pro.utils import FIDScore
import torch
# Initialize FID metric
fid_metric = FIDScore(
device='cuda',
batch_size=50,
dims=2048 # InceptionV3 feature dimension
)
# Calculate FID score
real_samples = torch.randn(1000, 3, 64, 64)
generated_samples = qgan.generate_samples(1000)
fid_score = fid_metric(real_samples, generated_samples)
print(f"FID Score: {fid_score:.2f}")
# Calculate with custom feature extractor
custom_fid = FIDScore(
feature_extractor='resnet50',
device='cuda'
)
fid_score = custom_fid(real_samples, generated_samples)
InceptionScore
Inception Score for evaluating generation quality and diversity.
class InceptionScore:
"""
Compute Inception Score for generated samples.
Higher scores indicate better quality and diversity.
"""
Inception Score Calculation
from qgans_pro.utils import InceptionScore
# Initialize IS metric
is_metric = InceptionScore(
device='cuda',
batch_size=32,
splits=10 # Number of splits for statistical robustness
)
# Calculate Inception Score
generated_samples = qgan.generate_samples(5000)
is_mean, is_std = is_metric(generated_samples)
print(f"Inception Score: {is_mean:.2f} ± {is_std:.2f}")
QuantumFidelity
Quantum-specific fidelity metrics for quantum GANs.
class QuantumFidelity:
"""
Compute quantum fidelity and related quantum metrics.
Measures how well quantum properties are preserved during generation.
"""
Quantum Fidelity Examples
from qgans_pro.utils import QuantumFidelity
# Initialize quantum fidelity metric
qf_metric = QuantumFidelity(
backend='qiskit',
device='aer_simulator',
shots=1024
)
# Calculate quantum metrics
quantum_metrics = qf_metric(
real_data=real_samples,
generated_data=generated_samples,
quantum_states=generator.get_quantum_states()
)
print(f"Quantum Fidelity: {quantum_metrics['fidelity']:.4f}")
print(f"Entanglement Measure: {quantum_metrics['entanglement']:.4f}")
print(f"Quantum Volume: {quantum_metrics['quantum_volume']}")
# Advanced quantum analysis
detailed_metrics = qf_metric.detailed_analysis(
quantum_circuit=generator.circuit,
measurement_basis='computational'
)
Visualization Tools
plot_generated_samples
Visualize generated samples in a grid layout.
def plot_generated_samples(
samples: torch.Tensor,
title: str = "Generated Samples",
nrow: int = 8,
padding: int = 2,
normalize: bool = True,
save_path: Optional[str] = None
) -> matplotlib.figure.Figure:
Sample Visualization
from qgans_pro.utils import plot_generated_samples
import matplotlib.pyplot as plt
# Generate and visualize samples
samples = qgan.generate_samples(64)
fig = plot_generated_samples(
samples=samples,
title="Quantum GAN Generated Fashion-MNIST",
nrow=8,
normalize=True,
save_path="generated_samples.png"
)
plt.show()
plot_training_curves
Visualize training progress and metrics.
def plot_training_curves(
training_history: Dict[str, List[float]],
metrics: List[str] = None,
save_path: Optional[str] = None,
smoothing: float = 0.9
) -> matplotlib.figure.Figure:
Training Visualization
from qgans_pro.utils import plot_training_curves
# Plot training history
training_history = qgan.get_training_history()
fig = plot_training_curves(
training_history=training_history,
metrics=['g_loss', 'd_loss', 'fid_score', 'quantum_fidelity'],
smoothing=0.9,
save_path="training_curves.png"
)
QuantumCircuitVisualizer
Specialized visualization for quantum circuits.
class QuantumCircuitVisualizer:
"""
Visualize quantum circuits used in quantum GANs.
Provides various visualization options for quantum circuits,
including circuit diagrams, state evolution, and quantum metrics.
"""
Quantum Circuit Visualization
from qgans_pro.utils import QuantumCircuitVisualizer
# Initialize circuit visualizer
circuit_viz = QuantumCircuitVisualizer(backend='qiskit')
# Visualize generator circuit
generator_circuit = quantum_generator.get_circuit()
circuit_viz.draw_circuit(
circuit=generator_circuit,
output='mpl',
title='Quantum Generator Circuit',
save_path='generator_circuit.png'
)
# Visualize quantum state evolution
state_evolution = quantum_generator.get_state_evolution()
circuit_viz.animate_state_evolution(
states=state_evolution,
save_path='state_evolution.gif'
)
# Create quantum metrics dashboard
circuit_viz.create_metrics_dashboard(
circuit=generator_circuit,
metrics=['depth', 'gate_count', 'entanglement'],
save_path='quantum_dashboard.html'
)
Model Utilities
ModelSummary
Generate detailed model summaries including quantum circuit information.
class ModelSummary:
"""
Generate comprehensive model summaries for quantum GANs.
Includes classical neural network parameters, quantum circuit
properties, and computational complexity analysis.
"""
Model Summary Examples
from qgans_pro.utils import ModelSummary
# Create model summary
summary = ModelSummary()
# Summarize quantum generator
gen_summary = summary.summarize_model(
model=quantum_generator,
input_shape=(64, 8), # Batch size, latent dim
include_quantum_info=True
)
print(gen_summary)
# Compare multiple models
comparison = summary.compare_models({
'Quantum Generator': quantum_generator,
'Classical Generator': classical_generator,
'Hybrid Generator': hybrid_generator
})
summary.export_comparison(comparison, 'model_comparison.html')
checkpoint_utils
Utilities for saving and loading model checkpoints.
def save_checkpoint(
model: nn.Module,
optimizer: torch.optim.Optimizer,
epoch: int,
loss: float,
path: str,
metadata: Dict = None
) -> None:
def load_checkpoint(
path: str,
model: nn.Module = None,
optimizer: torch.optim.Optimizer = None
) -> Dict:
Checkpoint Management
from qgans_pro.utils import save_checkpoint, load_checkpoint
# Save checkpoint with metadata
save_checkpoint(
model=qgan,
optimizer=optimizer,
epoch=100,
loss=0.0456,
path='qgan_checkpoint_epoch_100.pth',
metadata={
'dataset': 'fashion-mnist',
'n_qubits': 8,
'fid_score': 15.2,
'config': training_config
}
)
# Load checkpoint
checkpoint_data = load_checkpoint(
path='qgan_checkpoint_epoch_100.pth',
model=qgan,
optimizer=optimizer
)
print(f"Loaded model from epoch {checkpoint_data['epoch']}")
print(f"FID Score: {checkpoint_data['metadata']['fid_score']}")
Quantum Utilities
QuantumStateAnalyzer
Analyze quantum states and circuits.
class QuantumStateAnalyzer:
"""
Analyze quantum states, circuits, and quantum properties.
Provides tools for understanding quantum behavior in GANs.
"""
Quantum State Analysis
from qgans_pro.utils import QuantumStateAnalyzer
# Initialize analyzer
analyzer = QuantumStateAnalyzer(backend='qiskit')
# Analyze quantum circuit properties
circuit_analysis = analyzer.analyze_circuit(
circuit=quantum_generator.circuit,
metrics=['expressibility', 'entangling_capability', 'effective_dimension']
)
print(f"Expressibility: {circuit_analysis['expressibility']:.4f}")
print(f"Entangling Capability: {circuit_analysis['entangling_capability']:.4f}")
# Analyze quantum state properties
state_analysis = analyzer.analyze_quantum_state(
quantum_state=final_state,
analysis_type='full' # 'full', 'entanglement', 'coherence'
)
# Generate quantum circuit recommendations
recommendations = analyzer.suggest_circuit_improvements(
current_circuit=quantum_generator.circuit,
target_metrics={'expressibility': 0.8, 'depth': 10}
)
quantum_encoding
Utilities for encoding classical data into quantum states.
def amplitude_encoding(
data: torch.Tensor,
n_qubits: int,
normalization: str = 'l2'
) -> torch.Tensor:
def angle_encoding(
data: torch.Tensor,
n_qubits: int,
encoding_scheme: str = 'basic'
) -> torch.Tensor:
def iqp_encoding(
data: torch.Tensor,
n_qubits: int,
depth: int = 1
) -> torch.Tensor:
Quantum Encoding Examples
from qgans_pro.utils.quantum_encoding import (
amplitude_encoding,
angle_encoding,
iqp_encoding
)
# Amplitude encoding
classical_data = torch.randn(32, 256)
amplitude_encoded = amplitude_encoding(
data=classical_data,
n_qubits=8,
normalization='l2'
)
# Angle encoding
angle_encoded = angle_encoding(
data=classical_data,
n_qubits=8,
encoding_scheme='linear' # 'linear', 'polynomial'
)
# IQP encoding for entangled features
iqp_encoded = iqp_encoding(
data=classical_data,
n_qubits=8,
depth=2
)
Performance and Profiling
PerformanceProfiler
Profile quantum GAN training performance.
class PerformanceProfiler:
"""
Profile quantum GAN training for performance optimization.
Monitors computational resources, quantum circuit execution
times, and identifies bottlenecks.
"""
Performance Profiling
from qgans_pro.utils import PerformanceProfiler
# Initialize profiler
profiler = PerformanceProfiler(
profile_quantum=True,
profile_classical=True,
memory_tracking=True
)
# Profile training step
with profiler.profile_context('training_step'):
losses = qgan.train_step(batch_data)
# Get profiling results
results = profiler.get_results()
print(f"Quantum circuit time: {results['quantum_time']:.4f}s")
print(f"Classical computation time: {results['classical_time']:.4f}s")
print(f"Memory usage: {results['memory_peak']:.2f} MB")
# Generate performance report
profiler.generate_report('performance_report.html')
MemoryOptimizer
Optimize memory usage for large quantum circuits.
class MemoryOptimizer:
"""
Optimize memory usage in quantum GAN training.
Provides strategies for reducing memory footprint while
maintaining training effectiveness.
"""
Memory Optimization
from qgans_pro.utils import MemoryOptimizer
# Initialize memory optimizer
memory_opt = MemoryOptimizer()
# Apply memory optimizations
optimized_generator = memory_opt.optimize_model(
model=quantum_generator,
strategies=['gradient_checkpointing', 'mixed_precision', 'circuit_chunking']
)
# Monitor memory usage
memory_stats = memory_opt.monitor_memory(
model=optimized_generator,
dataloader=train_loader,
num_batches=10
)
print(f"Peak memory: {memory_stats['peak_memory']} MB")
print(f"Average memory: {memory_stats['avg_memory']} MB")
Configuration Management
Config
Configuration management for quantum GANs.
class Config:
"""
Configuration management system for quantum GANs.
Handles model configurations, training parameters,
and experiment settings with validation and serialization.
"""
Configuration Management
from qgans_pro.utils import Config
# Load configuration from file
config = Config.from_file('qgan_config.yaml')
# Create configuration programmatically
config = Config({
'model': {
'generator': {
'n_qubits': 8,
'n_layers': 3,
'backend': 'qiskit'
},
'discriminator': {
'hidden_dims': [512, 256, 128],
'activation': 'leaky_relu'
}
},
'training': {
'batch_size': 64,
'learning_rate': 0.0002,
'epochs': 100
}
})
# Validate configuration
validation_result = config.validate()
if not validation_result.is_valid:
print(f"Configuration errors: {validation_result.errors}")
# Save configuration
config.save('experiment_config.yaml')
# Use configuration to create models
generator = config.create_generator()
discriminator = config.create_discriminator()
Best Practices and Helpers
ExperimentTracker
Track experiments and compare results.
class ExperimentTracker:
"""
Track quantum GAN experiments and results.
Provides experiment logging, comparison tools,
and result visualization.
"""
Experiment Tracking
from qgans_pro.utils import ExperimentTracker
# Initialize experiment tracker
tracker = ExperimentTracker(
project_name='quantum_gan_research',
save_dir='./experiments'
)
# Start new experiment
experiment_id = tracker.start_experiment(
name='fashion_mnist_8qubit',
config=config,
description='8-qubit quantum generator on Fashion-MNIST'
)
# Log metrics during training
for epoch in range(epochs):
# Training step
losses = qgan.train_step(batch)
# Log metrics
tracker.log_metrics({
'epoch': epoch,
'g_loss': losses['g_loss'],
'd_loss': losses['d_loss'],
'fid_score': fid_score
})
# End experiment
tracker.end_experiment(
final_metrics={'best_fid': 12.5},
model_path='final_model.pth'
)
# Compare experiments
comparison = tracker.compare_experiments([
'fashion_mnist_8qubit',
'fashion_mnist_6qubit',
'fashion_mnist_classical'
])
tracker.visualize_comparison(comparison)
reproducibility
Utilities for reproducible experiments.
def set_seed(seed: int) -> None:
"""Set random seeds for reproducible results."""
def get_system_info() -> Dict:
"""Get system information for experiment reproducibility."""
def create_reproducible_environment(seed: int = 42) -> Dict:
"""Create reproducible environment for quantum GAN experiments."""
Reproducibility Setup
from qgans_pro.utils.reproducibility import (
set_seed,
get_system_info,
create_reproducible_environment
)
# Set up reproducible environment
env_info = create_reproducible_environment(seed=42)
print(f"Environment: {env_info}")
# Set random seeds
set_seed(42)
# Get system information for logging
system_info = get_system_info()
print(f"System: {system_info['platform']}")
print(f"Python: {system_info['python_version']}")
print(f"PyTorch: {system_info['torch_version']}")
print(f"Quantum backends: {system_info['quantum_backends']}")