import random
import threading
import time
import tkinter as tk
from tkinter import ttk

import matplotlib.pyplot as plt
import numpy as np
from matplotlib import cm
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
from mpl_toolkits.mplot3d import Axes3D
from scipy.optimize import differential_evolution
from scipy.optimize import differential_evolution

import qiskit
#from qiskit.algorithms.optimizers import COBYLA
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister, transpile
from qiskit.circuit.library import RealAmplitudes
from qiskit.circuit.library import ZGate
from qiskit.quantum_info import DensityMatrix, Statevector, partial_trace
from qiskit.visualization import plot_bloch_multivector, plot_bloch_vector
from qiskit_aer import Aer, AerSimulator
from qiskit_aer.primitives import Estimator
from qiskit_algorithms import QAOA, VQE
from qiskit_algorithms.optimizers import COBYLA

'''
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import Axes3D
import tkinter as tk
from tkinter import ttk
from matplotlib.backends.backend_tkagg import FigureCanvasTkAgg
import threading
from scipy.optimize import differential_evolution
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister, Aer, transpile
from qiskit.visualization import plot_bloch_multivector
from qiskit.quantum_info import DensityMatrix, Statevector
from qiskit_aer.primitives import Estimator
import time
'''

class OptimizationDemo:
    """Tk GUI comparing a classical and a quantum optimizer on a 2-D surface.

    The classical run uses SciPy's differential evolution. The quantum run
    trains a small parameterized circuit with SPSA and decodes the most
    frequent measured bitstring into an (x, y) point. Both paths are drawn
    live on a 3-D surface plot, and summary metrics go to a separate tab.
    """

    # Search domain shared by the surface plot, the classical bounds and the
    # bitstring decoder, so the three cannot disagree.
    DOMAIN = (-3.0, 3.0)
    # Qubits in the ansatz: first half of each bitstring encodes x, second half y.
    NUM_QUBITS = 4
    NUM_QUANTUM_ITERATIONS = 20
    SHOTS = 1000

    def __init__(self):
        self.root = tk.Tk()
        self.root.title("Advanced Optimization Comparison")

        # Initialize quantum backend and estimator
        # (the estimator is not used by the code in this class).
        self.simulator = Aer.get_backend('aer_simulator')
        self.estimator = Estimator()

        # Parameters
        self.amplitude = tk.DoubleVar(value=1.0)
        self.frequency = tk.DoubleVar(value=2.0)

        # Optimization state. This must exist before the first update_plot() call.
        self.classical_route = []
        self.quantum_route = []
        self.best_point = None
        # Guards against re-entry: root.update() inside a run still processes
        # clicks on the "Run" button.
        self._running = False

        # Optimization metrics
        self.classical_metrics = {'iterations': 0, 'time': 0, 'path_length': 0}
        self.quantum_metrics = {'iterations': 0, 'time': 0, 'path_length': 0}

        # Setup GUI with additional tabs
        self.setup_gui()
        self.update_plot()

    def run(self):
        """Start the Tk event loop. Returns when the window is closed."""
        self.root.mainloop()

    def setup_gui(self):
        """Create the notebook with the Optimization, Quantum State and Metrics tabs."""
        # Give the notebook's grid cell weight; without it sticky='nsew'
        # cannot make the notebook grow with the window.
        self.root.columnconfigure(0, weight=1)
        self.root.rowconfigure(0, weight=1)

        self.notebook = ttk.Notebook(self.root)
        self.notebook.grid(row=0, column=0, sticky='nsew')

        # Main optimization tab
        self.main_frame = ttk.Frame(self.notebook)
        self.notebook.add(self.main_frame, text='Optimization')

        # Quantum state tab
        self.quantum_frame = ttk.Frame(self.notebook)
        self.notebook.add(self.quantum_frame, text='Quantum State')

        # Metrics tab
        self.metrics_frame = ttk.Frame(self.notebook)
        self.notebook.add(self.metrics_frame, text='Comparison Metrics')

        self.setup_main_tab()
        self.setup_quantum_tab()
        self.setup_metrics_tab()

    def setup_main_tab(self):
        """Build the parameter sliders, action buttons, status line and 3-D plot."""
        controls = ttk.Frame(self.main_frame)
        controls.grid(row=0, column=0, padx=10, pady=5)

        # Parameter controls: moving a slider redraws the surface immediately.
        param_frame = ttk.LabelFrame(controls, text="Parameters")
        param_frame.pack(fill='x', padx=5, pady=5)

        ttk.Label(param_frame, text="Amplitude:").pack(side=tk.LEFT)
        amplitude_slider = ttk.Scale(
            param_frame, from_=0.1, to=2.0,
            variable=self.amplitude,
            orient=tk.HORIZONTAL,
            length=200,
            command=lambda _: self.update_plot()
        )
        amplitude_slider.pack(side=tk.LEFT, padx=5)

        ttk.Label(param_frame, text="Frequency:").pack(side=tk.LEFT)
        frequency_slider = ttk.Scale(
            param_frame, from_=0.5, to=4.0,
            variable=self.frequency,
            orient=tk.HORIZONTAL,
            length=200,
            command=lambda _: self.update_plot()
        )
        frequency_slider.pack(side=tk.LEFT, padx=5)

        button_frame = ttk.Frame(controls)
        button_frame.pack(fill='x', padx=5, pady=5)

        ttk.Button(button_frame, text="Run Both Optimizations",
                   command=self.run_both_optimizations).pack(side=tk.LEFT, padx=5)
        ttk.Button(button_frame, text="Reset",
                   command=self.reset_optimization).pack(side=tk.LEFT, padx=5)

        self.status_label = ttk.Label(controls, text="")
        self.status_label.pack(fill='x', padx=5)

        self.fig = plt.Figure(figsize=(10, 8))
        self.ax = self.fig.add_subplot(111, projection='3d')
        self.canvas = FigureCanvasTkAgg(self.fig, master=self.main_frame)
        self.canvas.get_tk_widget().grid(row=1, column=0, padx=10, pady=10)

    def setup_quantum_tab(self):
        """Create the figure and canvas that show the qubit-0 Bloch sphere."""
        self.quantum_fig = plt.Figure(figsize=(8, 6))
        self.quantum_canvas = FigureCanvasTkAgg(self.quantum_fig, master=self.quantum_frame)
        self.quantum_canvas.get_tk_widget().pack(fill=tk.BOTH, expand=True)

    def setup_metrics_tab(self):
        """Create the text widget that holds the comparison report."""
        self.metrics_text = tk.Text(self.metrics_frame, height=20, width=60)
        self.metrics_text.pack(padx=10, pady=10)

    @staticmethod
    def _format_value(value, spec='.4f'):
        """Format ``value`` with ``spec``, or return 'N/A' when it is None."""
        return 'N/A' if value is None else format(value, spec)

    @staticmethod
    def _ratio(numerator, denominator):
        """Return ``numerator / denominator``, or None when the denominator is not positive."""
        return numerator / denominator if denominator > 0 else None

    def update_metrics(self):
        """Rewrite the Comparison Metrics tab from the two metrics dicts."""
        classical, quantum = self.classical_metrics, self.quantum_metrics
        time_ratio = self._ratio(quantum['time'], classical['time'])
        path_ratio = self._ratio(quantum['path_length'], classical['path_length'])

        self.metrics_text.delete('1.0', tk.END)
        metrics_str = f"""Optimization Comparison Metrics:

Classical Optimization:
- Iterations: {classical['iterations']}
- Time: {classical['time']:.2f} seconds
- Path Length: {classical['path_length']:.2f}
- Final Value: {self._format_value(classical.get('final_value'))}

Quantum Optimization:
- Iterations: {quantum['iterations']}
- Time: {quantum['time']:.2f} seconds
- Path Length: {quantum['path_length']:.2f}
- Final Value: {self._format_value(quantum.get('final_value'))}

Comparison:
- Time Ratio (Quantum/Classical): {self._format_value(time_ratio, '.2f')}
- Path Length Ratio: {self._format_value(path_ratio, '.2f')}
"""
        self.metrics_text.insert('1.0', metrics_str)

    def update_quantum_state(self, qc):
        """Draw the Bloch vector of qubit 0 of ``qc`` on the Quantum State tab.

        Final measurements are stripped first, because a density matrix can
        only be built from a measurement-free circuit. The other qubits are
        traced out, and the 2x2 reduced state is converted to its Bloch vector.
        """
        self.quantum_fig.clear()

        unitary_part = qc.remove_final_measurements(inplace=False)
        state = DensityMatrix.from_instruction(unitary_part)
        traced_out = list(range(1, unitary_part.num_qubits))
        rho = (partial_trace(state, traced_out) if traced_out else state).data

        # rho = (I + x X + y Y + z Z) / 2  =>  x = 2 Re rho01, y = 2 Im rho10, z = rho00 - rho11
        bloch = [
            2.0 * float(np.real(rho[0, 1])),
            2.0 * float(np.imag(rho[1, 0])),
            float(np.real(rho[0, 0] - rho[1, 1])),
        ]

        axes = self.quantum_fig.add_subplot(111, projection='3d')
        plot_bloch_vector(bloch, title='Qubit 0', ax=axes)

        self.quantum_canvas.draw()

    def calculate_path_length(self, route):
        """Return the total Euclidean length of the polyline through ``route``.

        Routes with fewer than two points have length 0.
        """
        if len(route) < 2:
            return 0
        steps = np.diff(np.asarray(route, dtype=float), axis=0)
        return float(np.linalg.norm(steps, axis=1).sum())

    def run_both_optimizations(self):
        """Run the classical and then the quantum optimization.

        Both runs execute on the Tk main thread, one after the other. Tkinter
        widgets must not be touched from worker threads, and each run pumps the
        event loop with ``root.update()``, so the window stays responsive.
        """
        if self._running:
            return
        self.reset_optimization()
        self._running = True
        try:
            self.status_label.config(text="Running classical optimization...")
            self.run_classical_optimization()
            self.status_label.config(text="Running quantum optimization...")
            self.run_quantum_optimization()
            self.status_label.config(text="Both optimizations finished.")
        finally:
            self._running = False

    def reset_optimization(self):
        """Reset all optimization data and visualizations"""
        self.classical_route = []
        self.quantum_route = []
        self.best_point = None
        self.classical_metrics = {'iterations': 0, 'time': 0, 'path_length': 0}
        self.quantum_metrics = {'iterations': 0, 'time': 0, 'path_length': 0}
        self.update_plot()
        self.update_metrics()
        self.status_label.config(text="")

    def _consider_best(self, point):
        """Replace ``best_point`` with ``point`` if ``point`` has a lower objective value."""
        if (self.best_point is None
                or self.objective_function(point) < self.objective_function(self.best_point)):
            self.best_point = point

    def run_classical_optimization(self):
        """Minimize the objective with SciPy's differential evolution.

        The best candidate of each generation is appended to
        ``classical_route`` and drawn live. Metrics are filled in when the run
        finishes.
        """
        start_time = time.perf_counter()
        self.classical_metrics['iterations'] = 0
        self.classical_route = []

        def callback(xk, convergence=None):
            point = np.array(xk, dtype=float)
            self.classical_route.append(point)
            self._consider_best(point)
            self.classical_metrics['iterations'] += 1
            self.update_plot()
            self.root.update()

        low, high = self.DOMAIN
        result = differential_evolution(
            self.objective_function,
            [(low, high), (low, high)],
            callback=callback,
            polish=True
        )
        # Polishing can move the answer beyond the last generation's best.
        self._consider_best(np.asarray(result.x, dtype=float))

        self.classical_metrics['time'] = time.perf_counter() - start_time
        self.classical_metrics['path_length'] = self.calculate_path_length(self.classical_route)
        self.classical_metrics['final_value'] = float(result.fun)

        self.update_plot()
        self.update_metrics()

    def update_plot(self):
        """Redraw the objective surface, both optimization paths and the best point."""
        self.ax.clear()

        low, high = self.DOMAIN
        grid = np.linspace(low, high, 100)
        X, Y = np.meshgrid(grid, grid)
        Z = self.objective_function(X, Y)
        self.ax.plot_surface(X, Y, Z, cmap=cm.viridis, alpha=0.8)

        has_legend_items = False
        for route, style, label in (
            (self.classical_route, 'r.-', 'Classical path'),
            (self.quantum_route, 'm.-', 'Quantum path'),
        ):
            if route:
                points = np.asarray(route, dtype=float)
                self.ax.plot(points[:, 0], points[:, 1],
                             self.objective_function(points[:, 0], points[:, 1]),
                             style, linewidth=2, label=label)
                has_legend_items = True

        # Compare with None explicitly: the truth value of a 2-element ndarray is ambiguous.
        if self.best_point is not None:
            self.ax.scatter([self.best_point[0]], [self.best_point[1]],
                            [self.objective_function(self.best_point)],
                            color='red', s=100, marker='*', label='Best point found')
            has_legend_items = True

        self.ax.set_xlabel('X')
        self.ax.set_ylabel('Y')
        self.ax.set_zlabel('Z')
        self.ax.set_title('Non-Convex Function Optimization')

        if has_legend_items:
            self.ax.legend()

        self.canvas.draw()

    def objective_function(self, x, y=None):
        """Evaluate the non-convex test surface.

        Call it either as ``objective_function(point)`` with a length-2
        ``(x, y)`` sequence (the form SciPy and the best-point code use), or as
        ``objective_function(X, Y)`` with scalars or equally shaped arrays (the
        form the surface and path plots use).

        f(x, y) = A * sin(w x) * cos(w y) + 0.1 * (x**2 + y**2), where A and w
        are read from the Amplitude and Frequency sliders.

        NOTE(review): this method was missing from the file; only a comment
        said it "remains the same". This stand-in fits every call site's
        shape. Replace it with the original surface if that was different.
        """
        if y is None:
            x, y = x[0], x[1]
        amplitude = self.amplitude.get()
        frequency = self.frequency.get()
        return amplitude * np.sin(frequency * x) * np.cos(frequency * y) + 0.1 * (x ** 2 + y ** 2)

    @classmethod
    def _decode_bitstring(cls, bitstring):
        """Map a measured bitstring to an (x, y) point on a grid over DOMAIN.

        Spaces (register separators in Qiskit counts keys) are ignored. The
        first half of the bits encodes x and the second half encodes y. Each
        half's integer k in [0, 2**m - 1] is mapped linearly onto [low, high],
        so the extreme codes land exactly on the domain edges.

        Raises:
            ValueError: if the bitstring is empty or has an odd number of bits.
        """
        bits = bitstring.replace(' ', '')
        if not bits or len(bits) % 2:
            raise ValueError(f"expected a non-empty, even-length bitstring, got {bitstring!r}")
        half = len(bits) // 2
        low, high = cls.DOMAIN
        step = (high - low) / (2 ** half - 1)
        return low + int(bits[:half], 2) * step, low + int(bits[half:], 2) * step

    def create_quantum_circuit(self, parameters):
        """Build the ansatz: an RY layer, a CX chain, a second RY layer, then measure every qubit.

        ``parameters`` must hold ``2 * NUM_QUBITS`` angles: the first layer's
        angles followed by the second layer's. Qubit i is measured into clbit i.
        """
        n = self.NUM_QUBITS
        angles = np.asarray(parameters, dtype=float)
        if angles.shape != (2 * n,):
            raise ValueError(f"expected {2 * n} parameters, got shape {angles.shape}")
        qc = QuantumCircuit(n, n)
        for q in range(n):
            qc.ry(float(angles[q]), q)
        for q in range(n - 1):
            qc.cx(q, q + 1)
        for q in range(n):
            qc.ry(float(angles[n + q]), q)
        qc.measure(range(n), range(n))
        return qc

    def quantum_expectation(self, params):
        """Return the expected objective value under the circuit's output distribution.

        The measurement-free circuit is simulated exactly as a statevector.
        The objective at each basis state's decoded point is weighted by that
        state's probability. This is the cost that SPSA minimizes.
        """
        circuit = self.create_quantum_circuit(params).remove_final_measurements(inplace=False)
        probabilities = Statevector.from_instruction(circuit).probabilities_dict()
        return float(sum(
            probability * self.objective_function(np.array(self._decode_bitstring(bits)))
            for bits, probability in probabilities.items()
        ))

    def run_quantum_optimization(self):
        """Train the circuit with SPSA and track the most frequent measured point.

        Each iteration samples the circuit on the Aer simulator, appends the
        decoded most-likely point to ``quantum_route``, and updates the
        parameters with one SPSA step on ``quantum_expectation``.
        """
        start_time = time.perf_counter()
        self.quantum_metrics['iterations'] = 0
        self.quantum_route = []
        quantum_best = None

        def spsa_update(params, iteration):
            # Decaying step size (a) and perturbation size (c); 0.602 and 0.101
            # are the usual SPSA decay exponents.
            a = 0.1 * (1 + iteration / 10) ** (-0.602)
            c = 0.1 * (1 + iteration) ** (-0.101)
            delta = 2 * np.random.randint(2, size=len(params)) - 1  # random +/-1 per parameter
            plus_value = self.quantum_expectation(params + c * delta)
            minus_value = self.quantum_expectation(params - c * delta)
            gradient = (plus_value - minus_value) / (2 * c * delta)
            return params - a * gradient

        parameters = np.random.rand(2 * self.NUM_QUBITS) * 2 * np.pi

        for iteration in range(self.NUM_QUANTUM_ITERATIONS):
            self.quantum_metrics['iterations'] += 1

            qc = self.create_quantum_circuit(parameters)
            self.update_quantum_state(qc)

            compiled_circuit = transpile(qc, self.simulator)
            counts = self.simulator.run(compiled_circuit, shots=self.SHOTS).result().get_counts()

            max_bitstring = max(counts, key=counts.get)
            current_point = np.array(self._decode_bitstring(max_bitstring))
            self.quantum_route.append(current_point)

            if (quantum_best is None
                    or self.objective_function(current_point) < self.objective_function(quantum_best)):
                quantum_best = current_point
            self._consider_best(current_point)

            parameters = spsa_update(parameters, iteration)

            self.update_plot()
            self.root.update()
            time.sleep(0.1)

        self.quantum_metrics['time'] = time.perf_counter() - start_time
        self.quantum_metrics['path_length'] = self.calculate_path_length(self.quantum_route)
        # Report the quantum run's own best point, not the global best shared with the classical run.
        self.quantum_metrics['final_value'] = (
            None if quantum_best is None else float(self.objective_function(quantum_best))
        )

        self.update_metrics()

if __name__ == "__main__":
    # Build the GUI, then hand control to its event loop until the window closes.
    app = OptimizationDemo()
    app.run()