import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import tkinter as tk
from tkinter import ttk
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import threading
from scipy.optimize import differential_evolution
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister, transpile
from qiskit.visualization import plot_bloch_multivector
from qiskit.quantum_info import DensityMatrix, Statevector
from qiskit_aer.primitives import Estimator
from qiskit_aer import Aer, AerSimulator
import time


class OptimizationDemo:
    """Tkinter demo comparing classical and quantum-sampled minimization of a 2D function.

    The window has three tabs:
    * "Optimization": 3D surface of ``objective_function`` with the latest route
      and best point, plus amplitude/frequency sliders and run/reset buttons.
    * "Quantum State": views of the most recently evaluated quantum circuit.
    * "Comparison Metrics": iterations, run time, path length and final value
      of each optimizer.
    """

    # Search domain used on both axes by the surface plot, the classical
    # optimizer bounds and the decoding of measured bitstrings.
    BOUNDS = (-3.0, 3.0)

    def __init__(self):
        self.root = tk.Tk()
        self.root.title("Advanced Optimization Comparison")

        # Shot-based simulator used to sample the parameterized circuit.
        self.simulator = Aer.get_backend('aer_simulator')
        # NOTE(review): the estimator is not used by any method of this class.
        self.estimator = Estimator()

        # Slider-controlled shape parameters of the objective function.
        self.amplitude = tk.DoubleVar(value=1.0)
        self.frequency = tk.DoubleVar(value=2.0)

        # Per-optimizer metrics shown on the metrics tab.
        self.classical_metrics = {'iterations': 0, 'time': 0, 'path_length': 0}
        self.quantum_metrics = {'iterations': 0, 'time': 0, 'path_length': 0}

        # Route and best point currently drawn on the surface plot.
        self.current_route = []
        self.best_point = None

        self.setup_gui()
        self.create_quantum_state_widgets()
        self.update_plot()

    def setup_gui(self):
        """Create the notebook with the optimization, quantum-state and metrics tabs."""
        # Without row/column weights the 'nsew' sticky below never resizes the notebook.
        self.root.rowconfigure(0, weight=1)
        self.root.columnconfigure(0, weight=1)

        self.notebook = ttk.Notebook(self.root)
        self.notebook.grid(row=0, column=0, sticky='nsew')

        self.main_frame = ttk.Frame(self.notebook)
        self.notebook.add(self.main_frame, text='Optimization')

        self.quantum_frame = ttk.Frame(self.notebook)
        self.notebook.add(self.quantum_frame, text='Quantum State')

        self.metrics_frame = ttk.Frame(self.notebook)
        self.notebook.add(self.metrics_frame, text='Comparison Metrics')

        self.setup_main_tab()
        self.setup_metrics_tab()

    def create_quantum_state_widgets(self):
        """Create the figure canvas and info label of the Quantum State tab."""
        self.quantum_fig = plt.Figure(figsize=(8, 8))
        self.quantum_canvas = FigureCanvasTkAgg(self.quantum_fig, master=self.quantum_frame)
        self.quantum_canvas.get_tk_widget().pack(expand=True, fill='both')

        self.quantum_info_label = ttk.Label(self.quantum_frame, text="")
        self.quantum_info_label.pack(pady=10)

    def update_quantum_state(self, qc):
        """Redraw the Quantum State tab for circuit ``qc``.

        Draws four panels: amplitude phase vs. magnitude, the full circuit
        diagram (measurements included), basis-state probabilities and the
        numeric gate parameters. It also writes a text summary below the figure.
        Any failure is printed and shown inside the figure instead of being
        raised to the caller.
        """
        import numbers

        self.quantum_fig.clear()
        gs = self.quantum_fig.add_gridspec(2, 2)

        try:
            # Measurements are not unitary, so simulate the circuit with its
            # final measurements stripped to get the pre-measurement state.
            unmeasured = qc.remove_final_measurements(inplace=False)
            amplitudes = Statevector(unmeasured).data
            magnitudes = np.abs(amplitudes)
            probabilities = magnitudes ** 2

            # 1. Phase vs magnitude of every amplitude
            ax_state = self.quantum_fig.add_subplot(gs[0, 0])
            ax_state.scatter(np.angle(amplitudes), magnitudes, c=magnitudes, cmap='viridis')
            ax_state.set_title('Quantum State (Phase vs Magnitude)', fontsize=8)
            ax_state.set_xlabel('Phase')
            ax_state.set_ylabel('Magnitude')
            ax_state.grid(True)

            # 2. Circuit diagram, including the measurements
            ax_circuit = self.quantum_fig.add_subplot(gs[0, 1])
            qc.draw('mpl', ax=ax_circuit)
            ax_circuit.set_title('Quantum Circuit', fontsize=8)

            # 3. Probability of each computational basis state
            ax_probs = self.quantum_fig.add_subplot(gs[1, 0])
            ax_probs.bar(range(len(probabilities)), probabilities)
            ax_probs.set_title('State Probabilities', fontsize=8)
            ax_probs.set_xlabel('Basis State')
            ax_probs.set_ylabel('Probability')

            # 4. Numeric gate parameters, in circuit order. Read them through
            # CircuitInstruction.operation rather than legacy tuple unpacking.
            ax_params = self.quantum_fig.add_subplot(gs[1, 1])
            current_params = [
                param
                for instruction in qc.data
                for param in instruction.operation.params
                if isinstance(param, numbers.Real)
            ]
            if current_params:
                ax_params.bar(np.arange(len(current_params)), current_params)
                ax_params.set_title('Circuit Parameters', fontsize=8)
                ax_params.set_xlabel('Parameter Index')
                ax_params.set_ylabel('Parameter Value')
            else:
                ax_params.text(0.5, 0.5, 'No parameters available',
                               ha='center', va='center')

            self.quantum_fig.tight_layout()

            info_text = "\n".join([
                "Quantum Circuit Information:",
                f"- Number of qubits: {qc.num_qubits}",
                f"- Circuit depth: {qc.depth()}",
                f"- Number of operations: {len(qc.data)}",
                f"- Number of parameters: {len(current_params)}",
                f"- Max probability state: {np.argmax(probabilities)}",
                f"- Max probability: {np.max(probabilities):.3f}",
            ])
            self.quantum_info_label.config(text=info_text)

        except Exception as e:  # GUI boundary: report the problem, don't abort the optimizer
            print(f"Error updating quantum visualization: {e}")
            self.quantum_fig.text(0.5, 0.5, f"Error updating visualization:\n{str(e)}",
                                  ha='center', va='center', color='red')

        self.quantum_canvas.draw()

    def setup_main_tab(self):
        """Build the sliders, buttons, status label and 3D plot of the main tab."""
        controls = ttk.Frame(self.main_frame)
        controls.grid(row=0, column=0, padx=10, pady=5)

        param_frame = ttk.LabelFrame(controls, text="Parameters")
        param_frame.pack(fill='x', padx=5, pady=5)

        # Moving either slider reshapes the objective, so redraw the surface.
        ttk.Label(param_frame, text="Amplitude:").pack(side=tk.LEFT)
        amplitude_slider = ttk.Scale(
            param_frame, from_=0.1, to=2.0,
            variable=self.amplitude,
            orient=tk.HORIZONTAL,
            length=200,
            command=lambda _: self.update_plot()
        )
        amplitude_slider.pack(side=tk.LEFT, padx=5)

        ttk.Label(param_frame, text="Frequency:").pack(side=tk.LEFT)
        frequency_slider = ttk.Scale(
            param_frame, from_=0.5, to=4.0,
            variable=self.frequency,
            orient=tk.HORIZONTAL,
            length=200,
            command=lambda _: self.update_plot()
        )
        frequency_slider.pack(side=tk.LEFT, padx=5)

        button_frame = ttk.Frame(controls)
        button_frame.pack(fill='x', padx=5, pady=5)

        ttk.Button(button_frame, text="Run Both Optimizations",
                   command=self.run_both_optimizations).pack(side=tk.LEFT, padx=5)
        ttk.Button(button_frame, text="Reset",
                   command=self.reset_optimization).pack(side=tk.LEFT, padx=5)

        self.status_label = ttk.Label(controls, text="")
        self.status_label.pack(fill='x', padx=5)

        self.fig = plt.Figure(figsize=(10, 8))
        self.ax = self.fig.add_subplot(111, projection='3d')
        self.canvas = FigureCanvasTkAgg(self.fig, master=self.main_frame)
        self.canvas.get_tk_widget().grid(row=1, column=0, padx=10, pady=10)

    def setup_metrics_tab(self):
        """Create the text widget that shows the comparison metrics."""
        self.metrics_text = tk.Text(self.metrics_frame, height=20, width=60)
        self.metrics_text.pack(padx=10, pady=10)

    def _set_status(self, text):
        """Show ``text`` in the status label and repaint pending widgets now.

        The optimizers run inside a Tk button callback. Without
        ``update_idletasks`` the label and plots would repaint only after the
        callback returns.
        """
        self.status_label.config(text=text)
        self.root.update_idletasks()

    def run_both_optimizations(self):
        """Run the classical and then the quantum optimization, and plot the combined result.

        The plot keeps the classical route. The highlighted point is whichever
        of the two best points has the lower objective value.
        """
        self.reset_optimization()
        self._set_status("Starting optimizations...")

        self.run_classical_optimization()
        classical_route = list(self.current_route)
        classical_best = np.array(self.best_point, copy=True)

        self.run_quantum_optimization()
        quantum_best = self.best_point

        self.current_route = classical_route
        if quantum_best is None or (
                self.objective_function(classical_best) < self.objective_function(quantum_best)):
            self.best_point = classical_best
        else:
            self.best_point = quantum_best
        self.update_plot()
        self._set_status("Both optimizations completed")

    def objective_function(self, x, y=None):
        """Evaluate the non-convex test function at (x, y).

        The value is the sum of a sine/cosine product, a Gaussian bump centred
        at the origin and a cubic/quadratic term, all shaped by the current
        amplitude and frequency slider values.

        Args:
            x: x coordinate (scalar or array), or a length-2 sequence ``[x, y]``
                when ``y`` is omitted (the form SciPy optimizers pass).
            y: y coordinate (scalar or array of the same shape as ``x``).

        Returns:
            The function value, elementwise for array inputs.
        """
        if y is None:  # For optimization algorithms that pass a single array
            y = x[1]
            x = x[0]

        amp = self.amplitude.get()
        freq = self.frequency.get()

        sine_component = amp * np.sin(freq * x) * np.cos(freq * y)
        gaussian_component = 2 * np.exp(-(x**2 + y**2) / (4 * amp))
        poly_component = (x**3 / 10 - y**2 / 5) * amp / 2

        return sine_component + gaussian_component + poly_component

    def update_plot(self):
        """Redraw the objective surface together with the current route and best point."""
        self.ax.clear()

        low, high = self.BOUNDS
        axis = np.linspace(low, high, 100)
        X, Y = np.meshgrid(axis, axis)
        self.ax.plot_surface(X, Y, self.objective_function(X, Y), cmap=cm.viridis, alpha=0.8)

        has_route = len(self.current_route) > 0
        # best_point is a NumPy array, whose truth value is ambiguous, so test it against None.
        has_best = self.best_point is not None

        if has_route:
            route = np.asarray(self.current_route)
            self.ax.plot(route[:, 0], route[:, 1],
                         [self.objective_function(p) for p in route],
                         'r.-', linewidth=2, label='Optimization path')

        if has_best:
            self.ax.scatter([self.best_point[0]], [self.best_point[1]],
                            [self.objective_function(self.best_point)],
                            color='red', s=100, marker='*', label='Global minimum')

        self.ax.set_xlabel('X')
        self.ax.set_ylabel('Y')
        self.ax.set_zlabel('Z')
        self.ax.set_title('Non-Convex Function Optimization')

        if has_route or has_best:
            self.ax.legend()

        self.canvas.draw()

    def create_quantum_circuit(self, params):
        """Build the 4-qubit parameterized circuit sampled by the quantum optimizer.

        Layout: Hadamard on every qubit, then RY and RZ on each qubit using the
        same angle ``params[i]`` for i < 4, then a CX chain q0→q1→q2→q3, then
        RY(``params[4 + i]``) on qubit i. Finally every qubit is measured into
        its own classical bit. Entries beyond the first 8 are ignored.
        """
        num_qubits = 4
        qr = QuantumRegister(num_qubits, 'q')
        cr = ClassicalRegister(num_qubits, 'c')
        qc = QuantumCircuit(qr, cr)

        for qubit in qr:
            qc.h(qubit)

        for qubit, angle in zip(qr, params[:num_qubits]):
            qc.ry(angle, qubit)
            qc.rz(angle, qubit)

        for i in range(num_qubits - 1):
            qc.cx(qr[i], qr[i + 1])

        # Slice to at most one angle per qubit so extra entries cannot index past the register.
        for qubit, angle in zip(qr, params[num_qubits:2 * num_qubits]):
            qc.ry(angle, qubit)

        qc.measure(qr, cr)
        return qc

    @staticmethod
    def _decode_bitstring(bitstring, bounds=BOUNDS):
        """Map a measured bitstring to an ``(x, y)`` point covering ``bounds`` on both axes.

        Spaces (register separators in Qiskit count keys) are ignored. The first
        half of the bits, as printed, encodes x and the second half encodes y.
        Each half is read as an unsigned integer k and mapped linearly, so k = 0
        gives the lower bound and the largest k gives the upper bound. This
        assumes an even number of bits, at least 2.
        """
        bits = bitstring.replace(' ', '')
        half = len(bits) // 2
        low, high = bounds
        step = (high - low) / (2 ** half - 1)
        x = low + int(bits[:half], 2) * step
        y = low + int(bits[half:], 2) * step
        return np.array([x, y])

    def _run_counts(self, qc, shots=1000):
        """Transpile ``qc`` for the Aer simulator, run it and return the measurement counts."""
        compiled_circuit = transpile(qc, self.simulator)
        return self.simulator.run(compiled_circuit, shots=shots).result().get_counts()

    def quantum_expectation(self, params):
        """Return the shot-weighted mean objective value over the circuit's measured points."""
        counts = self._run_counts(self.create_quantum_circuit(params))
        total_shots = sum(counts.values())
        weighted = sum(
            self.objective_function(self._decode_bitstring(bitstring)) * count
            for bitstring, count in counts.items()
        )
        return weighted / total_shots

    def _spsa_gradient(self, params, perturbation):
        """Estimate the gradient of ``quantum_expectation`` at ``params`` with one SPSA probe.

        Both evaluations perturb every parameter at once along a random ±1
        direction, so the estimate needs only two circuit runs regardless of
        the number of parameters.
        """
        delta = np.random.choice([-1.0, 1.0], size=len(params))
        f_plus = self.quantum_expectation(params + perturbation * delta)
        f_minus = self.quantum_expectation(params - perturbation * delta)
        # For ±1 entries 1/delta_i == delta_i, so multiply instead of dividing.
        return (f_plus - f_minus) / (2 * perturbation) * delta

    def reset_optimization(self):
        """Clear routes, best point, metrics, status and the quantum-state view."""
        self.current_route = []
        self.best_point = None
        self.classical_metrics = {'iterations': 0, 'time': 0, 'path_length': 0}
        self.quantum_metrics = {'iterations': 0, 'time': 0, 'path_length': 0}
        self.update_plot()
        self.update_metrics()
        self._set_status("")

        self.quantum_fig.clear()
        self.quantum_canvas.draw()
        # Also drop the text summary that describes the previous circuit.
        self.quantum_info_label.config(text="")

    def _format_metrics(self):
        """Build the text for the metrics tab from the classical and quantum metric dicts."""
        classical, quantum = self.classical_metrics, self.quantum_metrics

        def ratio(numerator, denominator):
            # Ratios are undefined until the classical run has produced non-zero values.
            return f"{numerator / denominator:.2f}" if denominator > 0 else 'N/A'

        def section(title, metrics):
            final_value = metrics.get('final_value')
            final_text = 'N/A' if final_value is None else f"{final_value:.4f}"
            return [
                f"{title}:",
                f"- Iterations: {metrics['iterations']}",
                f"- Time: {metrics['time']:.2f} seconds",
                f"- Path Length: {metrics['path_length']:.2f}",
                f"- Final Value: {final_text}",
            ]

        lines = ["Optimization Comparison Metrics:"]
        lines += section("Classical Optimization", classical)
        lines.append("")
        lines += section("Quantum Optimization", quantum)
        lines += [
            "",
            "Comparison:",
            f"- Time Ratio (Quantum/Classical): {ratio(quantum['time'], classical['time'])}",
            f"- Path Length Ratio: {ratio(quantum['path_length'], classical['path_length'])}",
        ]
        return "\n".join(lines) + "\n"

    def update_metrics(self):
        """Replace the contents of the metrics tab with the current comparison text."""
        self.metrics_text.delete('1.0', tk.END)
        self.metrics_text.insert('1.0', self._format_metrics())

    def calculate_path_length(self, route):
        """Return the total Euclidean length of the polyline through ``route``'s points.

        Returns 0 for routes with fewer than two points.
        """
        if len(route) < 2:
            return 0
        steps = np.diff(np.asarray(route, dtype=float), axis=0)
        return float(np.linalg.norm(steps, axis=1).sum())

    def run_classical_optimization(self):
        """Minimize the objective over ``BOUNDS`` on both axes with Differential Evolution.

        The best vector of each generation is recorded as the route. Metrics
        and the plot are updated when the run finishes.
        """
        start_time = time.perf_counter()
        self.classical_metrics['iterations'] = 0
        classical_route = []

        def callback(xk, convergence):
            # Copy so later in-place changes by the optimizer can't alter the recorded route.
            classical_route.append(np.array(xk, copy=True))
            self.classical_metrics['iterations'] += 1
            if len(classical_route) > 1:
                self._set_status(f"Classical optimization iteration: {len(classical_route)}")

        low, high = self.BOUNDS
        result = differential_evolution(
            self.objective_function,
            [(low, high), (low, high)],
            callback=callback,
            polish=True
        )

        self.classical_metrics['time'] = time.perf_counter() - start_time
        self.classical_metrics['path_length'] = self.calculate_path_length(classical_route)
        self.classical_metrics['final_value'] = result.fun

        self.current_route = classical_route
        self.best_point = result.x
        self.update_plot()
        self.update_metrics()
        self._set_status("Classical optimization completed")

    def run_quantum_optimization(self):
        """Optimize the circuit parameters with SPSA and record the sampled route.

        Each of the 20 iterations does the following:
        * runs the current circuit;
        * decodes its most frequent outcome into an ``(x, y)`` route point;
        * keeps the best point seen so far;
        * moves the parameters along an SPSA estimate of the gradient of
          ``quantum_expectation``.

        The reported time excludes the per-iteration quantum-state
        visualization, so it reflects optimization work only.
        """
        start_time = time.perf_counter()
        visualization_time = 0.0
        self.quantum_metrics['iterations'] = 0
        quantum_route = []

        num_params = 8
        num_iterations = 20
        parameters = np.random.rand(num_params) * 2 * np.pi
        best_value = float('inf')
        best_point = None

        for iteration in range(num_iterations):
            self.quantum_metrics['iterations'] += 1

            qc = self.create_quantum_circuit(parameters)
            viz_start = time.perf_counter()
            self.update_quantum_state(qc)
            visualization_time += time.perf_counter() - viz_start

            # The most frequently measured outcome is this iteration's route point.
            counts = self._run_counts(qc)
            current_point = self._decode_bitstring(max(counts, key=counts.get))
            quantum_route.append(current_point)

            current_value = self.objective_function(current_point)
            if current_value < best_value:
                best_value = current_value
                best_point = current_point

            # SPSA step: slowly shrinking probe size (gamma = 0.101) and a decaying learning rate.
            perturbation = 0.2 / (1 + iteration) ** 0.101
            learning_rate = 0.1 / (1 + iteration / 10)
            gradient = self._spsa_gradient(parameters, perturbation)
            parameters = np.clip(parameters - learning_rate * gradient, 0, 2 * np.pi)

            self._set_status(f"Quantum optimization iteration: {iteration + 1}/{num_iterations}")

        self.quantum_metrics['time'] = time.perf_counter() - start_time - visualization_time
        self.quantum_metrics['path_length'] = self.calculate_path_length(quantum_route)
        self.quantum_metrics['final_value'] = best_value

        self.current_route = quantum_route
        self.best_point = best_point
        self.update_plot()
        self.update_metrics()
        self._set_status("Quantum optimization completed")

    def run(self):
        """Enter the Tk main loop (blocks until the window is closed)."""
        self.root.mainloop()

if __name__ == "__main__":
    # Build the GUI and hand control to the Tk event loop until the window closes.
    OptimizationDemo().run()