import numpy as np
from qiskit import QuantumCircuit, QuantumRegister, ClassicalRegister
from qiskit_aer import Aer, AerSimulator
from qiskit.circuit.library import ZGate
import pygame
import random
import time
import argparse
import numpy as np
from typing import List, Tuple
from collections import deque
from maze3 import ImprovedMazeGame
from qiskit import transpile
from qiskit.circuit.library import ZGate

class QuantumMazeSolver(ImprovedMazeGame):
    """Maze game that adds a Grover-style quantum "solver" on top of ImprovedMazeGame.

    Positions are encoded as bitstrings (x bits followed by y bits). A small
    amplitude-amplification circuit is simulated with Qiskit Aer, and the
    most frequently measured positions are used as waypoints for a path from
    ``self.start`` towards ``self.end``. BFS over open cells fills any gaps,
    so consecutive path cells are always one step apart.

    NOTE(review): ``self.size``, ``self.maze``, ``self.start``, ``self.end``,
    ``self.screen``, ``self.cell_size``, ``self.padding``,
    ``self.wall_thickness`` (and optionally ``self.path`` / ``self.explored``)
    are presumably set up by ``ImprovedMazeGame``, which is not visible here.
    Cells with ``self.maze[x][y] == 0`` are treated as open.
    """

    # Upper bound on the simulated register width (2**6 = 64 basis states).
    MAX_QUBITS = 6
    # Number of simulator shots per quantum run.
    SHOTS = 1000
    # 4-neighbourhood used by every BFS in this class (order fixes tie-breaking).
    _NEIGHBOUR_OFFSETS = ((0, 1), (1, 0), (0, -1), (-1, 0))

    def __init__(self, size: int = 30, complexity: float = 0.7):
        super().__init__(size, complexity)
        self.quantum_path = []          # cells of the last quantum-guided path
        self.quantum_explored = set()   # distinct open cells seen in measurements
        self.QUANTUM_PATH_COLOR = (128, 0, 128)  # Purple for quantum path
        self.simulator = AerSimulator()
        self._compiled_circuit = None   # transpiled circuit of the last successful run

    # ------------------------------------------------------------------ #
    # Circuit construction
    # ------------------------------------------------------------------ #
    def create_phase_oracle(self, qc, qr_pos, qr_aux):
        """Append the fixed phase "oracle" used by each Grover iteration.

        Applies a CZ from every qubit except the last onto the last qubit,
        then a P(pi) phase on the last qubit. ``qr_aux`` is accepted for API
        compatibility but is not used.

        NOTE(review): this does not depend on the maze and does not mark any
        specific target state, so it is not a textbook Grover oracle.
        """
        for i in range(len(qr_pos) - 1):
            qc.cz(qr_pos[i], qr_pos[-1])
        qc.p(np.pi, qr_pos[-1])

    def _apply_diffuser(self, qc, qr):
        """Append the diffusion step used after each oracle call.

        NOTE(review): a textbook diffuser applies a multi-controlled Z between
        the X layers; the chain of pairwise CZs used here is only equivalent
        for two qubits. Kept unchanged to preserve the simulated behaviour.
        """
        qc.h(qr)
        qc.x(qr)
        for j in range(len(qr) - 1):
            qc.cz(qr[j], qr[-1])
        qc.x(qr)
        qc.h(qr)

    def _register_layout(self) -> Tuple[int, int, int]:
        """Return ``(bits_per_coordinate, qubits_used, grover_iterations)``.

        Shared by ``quantum_solve`` and ``analyze_quantum_solving`` so the
        reported resources always match the circuit that is actually built.
        NOTE(review): textbook Grover uses roughly (pi/4)*sqrt(N) iterations;
        int(sqrt(N)) is kept unchanged here.
        """
        n_bits = self.bits_needed()
        n_qubits = min(2 * n_bits, self.MAX_QUBITS)
        iterations = int(np.sqrt(2 ** n_qubits))
        return n_bits, n_qubits, iterations

    # ------------------------------------------------------------------ #
    # Position encoding
    # ------------------------------------------------------------------ #
    def encode_maze_state(self, pos: Tuple[int, int]) -> str:
        """Convert position to binary string (alias of ``encode_position``)."""
        return self.encode_position(pos)

    def bits_needed(self) -> int:
        """Return the number of bits needed to encode one coordinate in ``0..size-1``.

        Uses exact integer arithmetic (equal to ``ceil(log2(size))`` for
        ``size >= 2``) and never returns less than 1, so a 1x1 maze still gets
        a usable encoding.
        """
        return max(1, (int(self.size) - 1).bit_length())

    def encode_position(self, pos: tuple) -> str:
        """Encode ``(x, y)`` as x bits followed by y bits, each ``bits_needed()`` wide."""
        x, y = pos
        n_bits = self.bits_needed()
        return format(x, f'0{n_bits}b') + format(y, f'0{n_bits}b')

    def decode_position(self, bitstring: str) -> tuple:
        """Decode a bitstring into ``(x, y)``: first half is x, second half is y.

        Splitting on the bitstring's own length (instead of ``bits_needed()``)
        keeps decoding consistent with the register that was actually
        measured. ``quantum_solve`` caps that register at ``MAX_QUBITS``
        qubits, so for large mazes only low coordinates are representable.
        For full-width bitstrings the result matches ``encode_position``.
        Spaces (register separators in Qiskit count keys) are ignored.

        Raises:
            ValueError: if the bitstring is empty, has odd length, or is not binary.
        """
        bits = bitstring.replace(" ", "")
        if not bits or len(bits) % 2:
            raise ValueError(f"Cannot split bitstring {bitstring!r} into two coordinates")
        half = len(bits) // 2
        return (int(bits[:half], 2), int(bits[half:], 2))

    # ------------------------------------------------------------------ #
    # Solving
    # ------------------------------------------------------------------ #
    def quantum_solve(self):
        """Build and simulate the circuit, then derive ``self.quantum_path``.

        Records ``self.start_time`` / ``self.end_time``. The end time is set
        even if circuit construction or simulation fails; in that case no
        counts are processed and the path is empty.
        NOTE(review): these timing attribute names may be shared with the
        base class's classical solver.
        """
        print("\n=== Starting Quantum Maze Solving ===")
        print(f"Maze size: {self.size}x{self.size}")

        self.start_time = time.time()
        self.end_time = None
        self._compiled_circuit = None
        counts = {}

        try:
            n_bits, n_qubits, iterations = self._register_layout()

            print("\nQuantum Circuit Setup:")
            print(f"├── Required bits per coordinate: {n_bits}")
            print(f"├── Total qubits used: {n_qubits}")
            print(f"└── Maximum positions representable: {2**n_qubits}")

            qr = QuantumRegister(n_qubits, 'q')
            cr = ClassicalRegister(n_qubits, 'c')
            qc = QuantumCircuit(qr, cr)

            # Uniform superposition over every encodable position.
            qc.h(qr)

            print("\nGrover's Algorithm:")
            print(f"├── Starting iterations: {iterations}")
            for i in range(iterations):
                print(f"├── Iteration {i+1}/{iterations} in progress...")
                self.create_phase_oracle(qc, qr, None)
                self._apply_diffuser(qc, qr)
                # Redraw so the window stays updated while the circuit is built.
                self.draw_maze()
                pygame.time.delay(50)

            print("\nQuantum Measurement:")
            print("├── Preparing measurement...")
            qc.measure(qr, cr)

            print("├── Transpiling quantum circuit...")
            compiled_circuit = transpile(qc, self.simulator)
            print(f"├── Running quantum simulation ({self.SHOTS} shots)...")
            job = self.simulator.run(compiled_circuit, shots=self.SHOTS)
            counts = job.result().get_counts()
            self._compiled_circuit = compiled_circuit

            print("├── Quantum execution successful!")
            print(f"├── Circuit depth: {compiled_circuit.depth()}")
            print(f"└── Final qubit count: {compiled_circuit.num_qubits}")

        except Exception as e:
            # Boundary for the interactive UI: report and continue with no counts.
            print(f"\n❌ Quantum execution error: {e}")
            counts = {}
        finally:
            self.end_time = time.time()

        print("\nProcessing quantum results...")
        self.quantum_path = self._process_quantum_results(counts)
        # Record the open cells that actually appeared in the measurements.
        self.quantum_explored.update(self._measured_positions(counts))

        duration = self.end_time - self.start_time
        print(f"\nTotal execution time: {duration:.3f} seconds")

    def _measured_positions(self, counts) -> set:
        """Return the set of open maze cells decoded from the measured bitstrings."""
        cells = set()
        for bitstring in counts:
            try:
                pos = self.decode_position(bitstring)
            except ValueError:
                continue
            if self._is_valid_position(*pos):
                cells.add(pos)
        return cells

    def _process_quantum_results(self, counts):
        """Turn measurement counts into a maze path starting at ``self.start``.

        Stores normalised probabilities in ``self.state_probabilities``,
        reports the five most frequent states, and builds a path guided by
        them. Returns ``[]`` when ``counts`` is empty.
        """
        if not counts:
            return []

        print("\n=== Processing Quantum Results ===")
        print("Top 5 measured states:")

        sorted_results = sorted(counts.items(), key=lambda item: item[1], reverse=True)
        total_shots = sum(counts.values())

        # Also read by draw_maze (heat-map) and analyze_quantum_solving.
        self.state_probabilities = {
            bitstring: count / total_shots for bitstring, count in counts.items()
        }

        measured = []  # (position, probability) for each decodable top-5 state
        for i, (bitstring, count) in enumerate(sorted_results[:5]):
            prob = count / total_shots
            try:
                pos = self.decode_position(bitstring)
            except ValueError as e:
                print(f"│   └── Invalid state: {e}")
                continue
            valid = self._is_valid_position(*pos)
            measured.append((pos, prob))
            print(f"├── State {i+1}: {bitstring}")
            print(f"│   ├── Position: {pos}")
            print(f"│   ├── Probability: {prob:.3%}")
            print(f"│   └── Valid position: {'✓' if valid else '✗'}")

        return self._build_quantum_path(measured)

    def _build_quantum_path(self, measured):
        """Build a contiguous path from ``self.start`` guided by measured states.

        ``measured`` is a list of ``(position, probability)`` pairs. Open
        cells other than the start are used as waypoints.
        - Step to the most probable waypoint adjacent to the current cell.
        - If none is adjacent, follow a BFS route to the nearest remaining
          waypoint, appending every intermediate cell so the path never
          jumps across walls.
        - If the walk does not end on ``self.end``, append a BFS route to
          it when one exists.
        """
        print("\nPath Construction:")
        path = [self.start]
        current = self.start
        print(f"├── Starting from: {self.start}")

        # Wall cells can never be reached and the start is already on the path.
        waypoints = [(pos, prob) for pos, prob in measured
                     if pos != self.start and self._is_valid_position(*pos)]

        step = 1
        while current != self.end and waypoints:
            target = None
            max_prob = 0
            for pos, prob in waypoints:
                if self._is_valid_move(current, pos) and prob > max_prob:
                    target, max_prob = pos, prob

            if target is not None:
                print(f"├── Step {step}: Moving to {target} (prob: {max_prob:.3%})")
                route = [current, target]
            else:
                print(f"├── Step {step}: No direct valid move found")
                targets = {pos for pos, _ in waypoints}
                route = self._bfs_route(current, targets.__contains__)
                if route is None:
                    print("│   └── Failed to find valid path")
                    print("└── Path construction terminated - no valid moves")
                    break
                target = route[-1]
                print(f"│   └── Found alternate path to: {target}")

            if self.end in route:
                # Stop as soon as the route passes through the exit.
                route = route[:route.index(self.end) + 1]
            path.extend(route[1:])
            current = route[-1]
            waypoints = [(p, pr) for p, pr in waypoints if p != target]
            step += 1

        if current != self.end:
            print("\nFinding path to end:")
            end_path = self._find_path_to_end(current)
            if end_path:
                print(f"├── Found path to end: {len(end_path) - 1} steps")
                path.extend(end_path[1:])
            else:
                print("└── Failed to find path to end")

        print(f"\nFinal path length: {len(path)}")
        return path

    # ------------------------------------------------------------------ #
    # Grid helpers
    # ------------------------------------------------------------------ #
    def _is_valid_move(self, current: tuple, next_pos: tuple) -> bool:
        """Return True if ``next_pos`` is one orthogonal step from ``current``
        and both cells are open according to ``_is_valid_position``.
        """
        x1, y1 = current
        x2, y2 = next_pos
        if abs(x1 - x2) + abs(y1 - y2) != 1:
            return False
        return self._is_valid_position(x1, y1) and self._is_valid_position(x2, y2)

    def _bfs_route(self, start, is_goal):
        """Breadth-first search over open cells beginning at ``start``.

        Returns the shortest list ``[start, ..., goal]`` to the first cell for
        which ``is_goal`` is true (``start`` itself is tested first), or
        ``None`` if no such cell is reachable. Parents are tracked in a dict
        instead of copying a path list for every queued cell.
        """
        parents = {start: None}
        queue = deque([start])
        while queue:
            pos = queue.popleft()
            if is_goal(pos):
                route = []
                while pos is not None:
                    route.append(pos)
                    pos = parents[pos]
                route.reverse()
                return route
            x, y = pos
            for dx, dy in self._NEIGHBOUR_OFFSETS:
                neighbour = (x + dx, y + dy)
                if neighbour not in parents and self._is_valid_position(*neighbour):
                    parents[neighbour] = pos
                    queue.append(neighbour)
        return None

    def _find_path_to_nearest_valid(self, current, valid_positions):
        """Return the reachable position from ``valid_positions`` nearest to ``current``.

        ``valid_positions`` is a list of ``(position, probability)`` pairs.
        Returns ``None`` when the list is empty or no position is reachable.
        """
        if not valid_positions:
            return None
        targets = {pos for pos, _ in valid_positions}
        route = self._bfs_route(current, targets.__contains__)
        return route[-1] if route else None

    def _find_path_to_end(self, start_pos):
        """Return a shortest path ``[start_pos, ..., self.end]`` through open cells, or ``None``."""
        return self._bfs_route(start_pos, lambda pos: pos == self.end)

    def _is_valid_position(self, x: int, y: int) -> bool:
        """Return True if ``(x, y)`` lies inside the maze and is an open cell."""
        return (0 <= x < self.size and
                0 <= y < self.size and
                self.maze[x][y] == 0)

    # ------------------------------------------------------------------ #
    # Rendering
    # ------------------------------------------------------------------ #
    def draw_maze(self):
        """Draw the base maze, the measured-state heat-map and the quantum path, then flip."""
        super().draw_maze()
        self._draw_state_probabilities()
        self._draw_quantum_path()
        pygame.display.flip()

    def _draw_state_probabilities(self):
        """Overlay each measured open cell in purple, brighter for higher probability."""
        probabilities = getattr(self, 'state_probabilities', None)
        if not probabilities:
            return
        max_prob = max(probabilities.values())
        if max_prob <= 0:
            return
        for bitstring, prob in probabilities.items():
            try:
                x, y = self.decode_position(bitstring)
            except ValueError:
                continue
            if not self._is_valid_position(x, y):
                continue
            intensity = int(255 * (prob / max_prob))
            color = (intensity, 0, intensity, 100)  # Purple with alpha
            overlay = pygame.Surface((self.cell_size, self.cell_size), pygame.SRCALPHA)
            pygame.draw.rect(overlay, color, overlay.get_rect())
            # Coordinate x selects the screen row, y the screen column.
            self.screen.blit(overlay,
                             (y * self.cell_size + self.padding,
                              x * self.cell_size + self.padding))

    def _draw_quantum_path(self, dash_length: int = 5):
        """Draw ``self.quantum_path`` as a dashed line, skipping non-adjacent segments."""
        offset = self.padding + self.cell_size // 2
        for a, b in zip(self.quantum_path, self.quantum_path[1:]):
            if not self._is_valid_move(a, b):
                continue
            start_x = a[1] * self.cell_size + offset
            start_y = a[0] * self.cell_size + offset
            dx = b[1] * self.cell_size + offset - start_x
            dy = b[0] * self.cell_size + offset - start_y
            # At least one dash, so very small cells still show the path.
            steps = max(1, int(np.sqrt(dx**2 + dy**2) / dash_length))
            for step in range(0, steps, 2):
                pygame.draw.line(self.screen,
                                 self.QUANTUM_PATH_COLOR,
                                 (start_x + dx * step / steps, start_y + dy * step / steps),
                                 (start_x + dx * (step + 1) / steps,
                                  start_y + dy * (step + 1) / steps),
                                 self.wall_thickness + 1)

    # ------------------------------------------------------------------ #
    # Analysis
    # ------------------------------------------------------------------ #
    def analyze_quantum_solving(self):
        """Print a report about the most recent ``quantum_solve`` run.

        Covers timing, the resources of the circuit that was actually built,
        a comparison with the classical ``self.path`` / ``self.explored`` (when
        present) and the top measured states. Unexpected errors are reported
        instead of propagating so the interactive loop keeps running.
        """
        try:
            print("\n=== Quantum Solution Analysis ===")

            start_time = getattr(self, 'start_time', None)
            end_time = getattr(self, 'end_time', None)
            if start_time and end_time:
                print(f"Time taken: {end_time - start_time:.3f} seconds")
            else:
                print("Time taken: Not available")

            print(f"Path length: {len(self.quantum_path)}")
            print(f"Quantum states explored: {len(self.quantum_explored)}")

            n_bits, n_qubits, iterations = self._register_layout()
            print("\nQuantum Resources:")
            print(f"├── Required qubits: {2 * n_bits}")
            print(f"├── Qubits used: {n_qubits}")
            print(f"└── Circuit iterations: {iterations}")

            print("\nCircuit Statistics:")
            compiled = getattr(self, '_compiled_circuit', None)
            if compiled is None:
                print("└── Not available (no successful quantum run)")
            else:
                print(f"├── Optimized circuit depth: {compiled.depth()}")
                print(f"└── Number of operations: {len(compiled.data)}")

            self._print_path_comparison()

            if hasattr(self, 'state_probabilities'):
                print("\nTop 5 Quantum States:")
                sorted_probs = sorted(self.state_probabilities.items(),
                                      key=lambda item: item[1],
                                      reverse=True)[:5]
                for i, (state, prob) in enumerate(sorted_probs, 1):
                    print(f"├── State {i}: {state}")
                    print(f"│   └── Probability: {prob:.3f}")
                print("└── End of analysis")

        except Exception as e:
            print(f"\n❌ Analysis error: {e}")

    def _print_path_comparison(self):
        """Compare the quantum path with the classical ``self.path`` when both exist."""
        classical_path = getattr(self, 'path', None)
        if not (classical_path and self.quantum_path):
            return

        print("\nPath Comparison:")
        classical_length = len(classical_path)
        quantum_length = len(self.quantum_path)
        print(f"├── Classical path length: {classical_length}")
        print(f"├── Quantum path length: {quantum_length}")
        print(f"└── Length difference: {abs(classical_length - quantum_length)}")

        # Jaccard similarity of the visited cell sets.
        classical_set = set(classical_path)
        quantum_set = set(self.quantum_path)
        union = len(classical_set | quantum_set)
        similarity = len(classical_set & quantum_set) / union if union > 0 else 0
        print(f"\nPath Similarity: {similarity:.2%}")

        print("\nPath Optimality:")
        print(f"├── Classical path optimal: {'✓' if self._is_shortest_path(classical_path) else '✗'}")
        print(f"└── Quantum path optimal: {'✓' if self._is_shortest_path(self.quantum_path) else '✗'}")

        explored = getattr(self, 'explored', None)
        n_classical = len(explored) if explored is not None else 0
        n_quantum = len(self.quantum_explored)
        print("\nExploration Efficiency:")
        print(f"├── Classical states explored: {n_classical}")
        print(f"├── Quantum states explored: {n_quantum}")
        if n_classical:
            print(f"└── Exploration ratio (Q/C): {n_quantum / n_classical:.2f}")
        else:
            print("└── Exploration ratio (Q/C): n/a (no classical states recorded)")

    def _is_shortest_path(self, path) -> bool:
        """Return True if ``path`` is a contiguous start-to-end walk of minimal length.

        The path may be listed in either direction. Optimality is checked
        against a BFS shortest path from ``self.start`` to ``self.end``.
        NOTE(review): assumes the classical ``self.path`` uses the same
        ``(x, y)`` cell coordinates as this class.
        """
        if not path or {path[0], path[-1]} != {self.start, self.end}:
            return False
        if not all(self._is_valid_move(a, b) for a, b in zip(path, path[1:])):
            return False
        shortest = self._find_path_to_end(self.start)
        return shortest is not None and len(path) == len(shortest)
   

   

def main():
    """Parse CLI options and run the interactive pygame loop.

    Keys: ``C`` solves classically and ``Q`` solves with the quantum solver
    (each at most once per maze); ``R`` creates a fresh maze. The pygame
    display is always shut down on exit, even if solving or drawing raises.
    """
    parser = argparse.ArgumentParser(description='Quantum Maze Solver')
    parser.add_argument('--size', type=int, default=8,  # Reduced default size
                        help='Size of the maze (default: 8)')
    parser.add_argument('--complexity', type=float, default=0.7,
                        help='Complexity of the maze (0.0-1.0, default: 0.7)')
    parser.add_argument('--show-analysis', action='store_true',
                        help='Show detailed analysis after solving')
    args = parser.parse_args()

    # Reject values outside the documented ranges before building any maze.
    if args.size <= 0:
        parser.error('--size must be a positive integer')
    if not 0.0 <= args.complexity <= 1.0:
        parser.error('--complexity must be between 0.0 and 1.0')

    try:
        game = QuantumMazeSolver(args.size, args.complexity)
        running = True
        classical_solved = False
        quantum_solved = False

        while running:
            for event in pygame.event.get():
                if event.type == pygame.QUIT:
                    running = False
                elif event.type == pygame.KEYDOWN:
                    if event.key == pygame.K_c and not classical_solved:
                        print("\nSolving classically...")
                        game.solve_maze()
                        if args.show_analysis:
                            game.analyze_solving_process()
                        classical_solved = True
                    elif event.key == pygame.K_q and not quantum_solved:
                        print("\nSolving quantum...")
                        game.quantum_solve()
                        if args.show_analysis:
                            game.analyze_quantum_solving()
                        quantum_solved = True
                    elif event.key == pygame.K_r:
                        game = QuantumMazeSolver(args.size, args.complexity)
                        classical_solved = False
                        quantum_solved = False
                        print("\nMaze reset")

            game.draw_maze()
            pygame.time.delay(30)
    finally:
        # Release the display even when an exception escapes the loop.
        pygame.quit()


if __name__ == "__main__":
    main()