'''
Dynex SDK Neuromorphic Computing Library
Copyright (c) 2021-2024, Dynex Developers

All rights reserved.

AUTHOR :: Samer Rahmeh | Dynex | Global Head of Quantum Solutions Architecture
AUTHOR :: "This code is CONFIDENTIAL and cannot be shared without Sumitomo's AUTHORIZATION"

1. Redistributions of source code must retain the above copyright notice, this list of
    conditions and the following disclaimer.
 
2. Redistributions in binary form must reproduce the above copyright notice, this list
   of conditions and the following disclaimer in the documentation and/or other
   materials provided with the distribution.
 
3. Neither the name of the copyright holder nor the names of its contributors may be
   used to endorse or promote products derived from this software without specific
   prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF
THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
'''
try:
    import tensorflow as tf
    HAS_TENSORFLOW = True
except ImportError:
    tf = None  # tensorflow is optional
    HAS_TENSORFLOW = False
import pennylane as qml
from pennylane import numpy as np
from pyqubo import Binary, Constraint, Placeholder, And, Xor
# import dynex  # Removed to avoid circular import - not needed in this module
from collections import Counter
import inspect
import warnings
from collections.abc import Iterable
from typing import Optional, Text, Dict, Union, Tuple, List, Any
import time
try:
    from sklearn.metrics.pairwise import euclidean_distances
except ImportError:
    euclidean_distances = None  # sklearn is optional

class ComplexQubit:
    """Complex-valued quantity whose real and imaginary parts are pyqubo expressions.

    A freshly constructed instance owns two binary variables labelled
    ``<name>_real`` and ``<name>_imag``.  Every arithmetic operation returns a
    new ``ComplexQubit`` whose parts are the combined expressions and whose
    ``name`` records how the value was derived (``a_b_mul``, ``a_scaled``, ...).

    Attributes:
        name: Base label of this value, without the ``_real``/``_imag`` suffix.
        real: Expression for the real part.
        imag: Expression for the imaginary part.
    """

    def __init__(self, name):
        # The base name is stored explicitly: once a value is the result of an
        # operation, ``real`` holds a compound expression, so the name can no
        # longer be recovered by slicing ``self.real.label``.
        self.name = name
        self.real = Binary(f'{name}_real')
        self.imag = Binary(f'{name}_imag')

    @classmethod
    def _from_parts(cls, name, real, imag):
        """Build an instance from existing parts without allocating new binaries."""
        obj = cls.__new__(cls)
        obj.name = name
        obj.real = real
        obj.imag = imag
        return obj

    def conjugate(self):
        """Return the complex conjugate (imaginary part negated)."""
        return ComplexQubit._from_parts(f'{self.name}_conj', self.real, -self.imag)

    def __mul__(self, other):
        """Multiply by a real scalar or by another ``ComplexQubit``.

        Raises:
            TypeError: If ``other`` is neither a number nor a ``ComplexQubit``.
        """
        if isinstance(other, (int, float)):
            return ComplexQubit._from_parts(
                f'{self.name}_scaled', self.real * other, self.imag * other)
        if isinstance(other, ComplexQubit):
            # (a + bi)(c + di) = (ac - bd) + (ad + bc)i
            return ComplexQubit._from_parts(
                f'{self.name}_{other.name}_mul',
                self.real * other.real - self.imag * other.imag,
                self.real * other.imag + self.imag * other.real)
        # Previously fell through and returned None; fail loudly like __sub__.
        raise TypeError(f"Unsupported operand type for *: '{type(other)}'")

    def __add__(self, other):
        """Add a real scalar (to the real part only) or another ``ComplexQubit``.

        Raises:
            TypeError: If ``other`` is neither a number nor a ``ComplexQubit``.
        """
        if isinstance(other, ComplexQubit):
            return ComplexQubit._from_parts(
                f'{self.name}_{other.name}_add',
                self.real + other.real,
                self.imag + other.imag)
        if isinstance(other, (int, float)):
            return ComplexQubit._from_parts(
                f'{self.name}_scaled_add', self.real + other, self.imag)
        # Previously fell through and returned None; fail loudly like __sub__.
        raise TypeError(f"Unsupported operand type for +: '{type(other)}'")

    def __sub__(self, other):
        """Subtract a real scalar (from the real part only) or another ``ComplexQubit``.

        Raises:
            TypeError: If ``other`` is neither a number nor a ``ComplexQubit``.
        """
        if isinstance(other, ComplexQubit):
            return ComplexQubit._from_parts(
                f'{self.name}_{other.name}_sub',
                self.real - other.real,
                self.imag - other.imag)
        if isinstance(other, (int, float)):
            return ComplexQubit._from_parts(
                f'{self.name}_scaled_sub', self.real - other, self.imag)
        raise TypeError(f"Unsupported operand type for -: '{type(other)}'")

    def __truediv__(self, other):
        """Divide both parts by a real scalar.

        Raises:
            TypeError: If ``other`` is not a number.
        """
        if isinstance(other, (int, float)):
            return ComplexQubit._from_parts(
                f'{self.name}_div', self.real / other, self.imag / other)
        raise TypeError(f"Unsupported operand type for /: '{type(other)}'")

    def __radd__(self, other):
        """Support ``scalar + qubit``; addition is commutative."""
        return self.__add__(other)

    def __rsub__(self, other):
        """Support ``scalar - qubit``.

        Raises:
            TypeError: If ``other`` is not a number.
        """
        if isinstance(other, (int, float)):
            return ComplexQubit._from_parts(
                f'{self.name}_rscaled_sub', other - self.real, -self.imag)
        raise TypeError(f"Unsupported operand type for -: '{type(other)}'")

    def __rmul__(self, other):
        """Support ``scalar * qubit``; multiplication is commutative."""
        return self.__mul__(other)

    def __neg__(self):
        """Return the value with both parts negated."""
        return ComplexQubit._from_parts(f'{self.name}_neg', -self.real, -self.imag)

    @property
    def conj(self):
        """Property alias for :meth:`conjugate`."""
        return self.conjugate()

class Adjoint:
    """Wrapper representing the adjoint of a base operation.

    Attributes:
        bOP: The wrapped operation. It is expected to provide ``matrix()`` and
            may optionally provide ``decomposition()``.
    """

    def __init__(self, bOP):
        self.bOP = bOP

    def decomposition(self):
        """Return the adjoint as a list of ``Adjoint`` operations.

        When the base operation can be decomposed, each piece is wrapped in
        ``Adjoint`` and the order is reversed; otherwise the list contains
        only this object.
        """
        base = self.bOP
        if not hasattr(base, 'decomposition'):
            return [self]
        return list(map(Adjoint, reversed(base.decomposition())))

    def matrix(self):
        """Return the conjugate transpose of the base operation's matrix."""
        return np.conj(self.bOP.matrix().T)

if HAS_TENSORFLOW:
    class QuantumCallback(tf.keras.callbacks.Callback):
        """Keras callback that relays epoch events to every ``QKerasLayer``.

        It also keeps short rolling histories (at most ``MhistoryS`` entries)
        of the reported loss, the reported accuracy and each quantum layer's
        ``eQuantumC`` flag.

        Args:
            qScheduleF: Stored as ``qSF``; not read anywhere in this class.
            MhistoryS: Maximum length of each rolling history list.
        """

        def __init__(self, qScheduleF=1, MhistoryS=5):
            super().__init__()
            self.qSF = qScheduleF
            self.mHS = MhistoryS
            self.eCounter = 0
            self.lossH = []
            self.accH = []
            self.qUsageH = []

        def _pushBounded(self, history, value):
            """Append ``value`` and drop the oldest entry once ``mHS`` is exceeded."""
            history.append(value)
            if len(history) > self.mHS:
                history.pop(0)

        def on_epoch_begin(self, e, logs=None):
            """Forward the start of epoch ``e`` to every quantum layer."""
            for qlayer in self._getQLayers():
                qlayer._ONEPBegin(e, logs)

        def on_epoch_end(self, e, logs=None):
            """Record metrics for epoch ``e`` and forward the event to quantum layers."""
            qlayers = self._getQLayers()
            if logs:
                tracked = (('loss', self.lossH), ('accuracy', self.accH))
                for key, history in tracked:
                    if key in logs:
                        self._pushBounded(history, logs[key])
            self.eCounter += 1
            for qlayer in qlayers:
                qlayer._ONEPEnd(e, logs)
                self._pushBounded(self.qUsageH, qlayer.eQuantumC)
            if self.qUsageH:
                # Percentage of recent entries that used the quantum path; the
                # value is computed here but is not stored or reported.
                qUsageP = sum(self.qUsageH) / len(self.qUsageH) * 100

        def _getQLayers(self):
            """Collect every ``QKerasLayer`` reachable from ``self.model``.

            Nested containers (objects exposing ``layers``) are searched
            depth-first, in pre-order.
            """
            found = []
            pending = [self.model]
            while pending:
                node = pending.pop()
                if isinstance(node, QKerasLayer):
                    found.append(node)
                if hasattr(node, 'layers'):
                    # Push children reversed so they are popped in original order.
                    pending.extend(list(node.layers)[::-1])
            return found

        def _ONTREnd(self, batch, logs=None):
            """Placeholder reserved for later versions; intentionally does nothing."""
            pass

    class QKerasLayer(tf.keras.layers.Layer):
        def __init__(self, bridge, wShapes: Dict[str, Union[int, Tuple, List]], 
                     outDim: Union[int, Tuple, List], wSpecs: Optional[Dict[str, Dict]] = None, 
                     num_reads: int = 512, annealing_time: int = 256, mainnet: bool = False, 
                     shots: int = 1, decMethod: str = 'measure', 
                     QNNun: int = 128, QNNlayers: int = 3,
                     QNNact: str = 'swish', QNNlr: float = 0.001,
                     QNNe: int = 50, QNNbs: int = 16,
                     qDivergenceTHR: float = 0.05,
                     imSampling: bool = True,
                     qSchedulingF: float = 0.1,
                     eBlending: bool = True,
                     qRegWEIGHT: float = 0.01,
                     QNNarch: bool = True,
                     corrST: float = 0.3, 
                     QNNGradient: str = 'default', 
                     gEps: float = 0.01, 
                     ptQNN: bool = True,
                     **kwargs):
            """Configure the layer, validate the circuit signature and create the weights.

            Args:
                bridge: Object exposing ``circuit_func``; other code in this file
                    also uses its ``params``, ``ExtractHamiltonian``, ``to_bqm``,
                    ``DynexCompute`` and ``DecodeSolution`` members.
                wShapes: Maps each weight name to its shape. An iterable becomes
                    a tuple, an int greater than 1 becomes ``(size,)`` and any
                    other int becomes the scalar shape ``()``.
                outDim: Output size. An iterable with more than one entry is
                    stored as a tuple; a single-entry iterable is unwrapped.
                wSpecs: Optional per-weight keyword arguments forwarded to
                    ``add_weight`` in ``build``.
                num_reads: Baseline ``num_reads`` for the sampler call.
                annealing_time: Baseline ``annealing_time`` for the sampler call.
                mainnet: Forwarded to ``bridge.DynexCompute``.
                shots: Stored on the instance.
                decMethod: When not ``'all'``, circuit results are reshaped to
                    ``outDim`` in ``_EvaluateQCircuit``.
                QNNun: Units per dense layer of the surrogate network.
                QNNlayers: Number of dense layers of the surrogate network.
                QNNact: Activation of the surrogate network's dense layers.
                QNNlr: Learning rate of the surrogate network's Adam optimizer.
                QNNe, QNNbs, qDivergenceTHR, imSampling, qSchedulingF, eBlending,
                qRegWEIGHT: Stored on the instance; their use is not visible in
                    this part of the file.
                QNNarch: When true, ``_BuildQNNmodel`` overrides ``QNNun``,
                    ``QNNlayers`` and ``QNNact`` from the estimated complexity.
                corrST: Blend strength used by ``_CorrectQuantumSol``.
                QNNGradient: Selects the backward branch in ``_QforwardTraining``
                    (``'default'``, ``'parameter_shift'`` or anything else).
                gEps: Finite-difference step used for gradient estimation.
                ptQNN: Enables surrogate pre-training in ``__initQNN__``.
                **kwargs: Forwarded to ``tf.keras.layers.Layer.__init__``.
            """
            # NOTE(review): these statements used to sit at class-body level,
            # where ``self`` is undefined, so defining the class raised
            # NameError. The definitions following this constructor
            # (``_SIGNATUREVal``, ``build``, ...) are still one indentation
            # level shallower than the class body and need the same re-indent
            # before they are bound as methods.
            self.weightShapes = {
                weight: (tuple(size) if isinstance(size, Iterable) else (size,) if size > 1 else ())
                for weight, size in wShapes.items()
            }
            self._SIGNATUREVal(bridge, wShapes)
            self.bridge = bridge
            self.num_reads = num_reads
            self.annealing_time = annealing_time
            self.mainnet = mainnet
            self.shots = shots
            self.decMethod = decMethod
            if isinstance(outDim, Iterable) and len(outDim) > 1:
                self.outDim = tuple(outDim)
            else:
                self.outDim = outDim[0] if isinstance(outDim, Iterable) else outDim
            self.wSpecs = wSpecs if wSpecs is not None else {}
            self.QnW = {}  # weight name -> variable, filled in build()
            self.QNNun = QNNun
            self.QNNlayers = QNNlayers
            self.QNNact = QNNact
            self.QNNlr = QNNlr
            self.QNNe = QNNe
            self.QNNbs = QNNbs
            self.qDivergenceTHR = qDivergenceTHR
            self.imSampling = imSampling
            self.qSchedulingF = qSchedulingF
            self.eBlending = eBlending
            self.qRegWEIGHT = qRegWEIGHT
            self.QNNarch = QNNarch
            self.corrST = corrST
            self.QNNGradient = QNNGradient
            self.gEps = gEps
            self.ptQNN = ptQNN
            # Mutable training state.
            self.cEPOCH = 0
            self.eQuantumC = False  # set once the circuit has been evaluated
            self.cBATCHidx = 0
            self.QNN = None  # surrogate network, built lazily
            self.QNNOpt = None
            self.qACCh = []
            self.QNNACCh = []
            self.qLossh = []
            self.QNNLossh = []
            self.qBLENDw = 0.7  # share given to the quantum result when blending
            self.pComp = "medium"
            self.qINs = []
            self.qOUTs = []
            self.qWs = []
            self.sImp = []
            self.succP = []
            self.lastEPvalL = float('inf')
            self.lastEPvalA = 0.0
            self.solQH = []
            self.epochQcache = {}
            self.gradCache = {}
            self.qQNNdiv = tf.keras.metrics.Mean(name='QuantumQNNDiv')
            self.QNNacc = tf.keras.metrics.Mean(name='QNNAccuracy')
            self.solQM = tf.keras.metrics.Mean(name='SolutionQualityMetrics')
            self.aQuantumS = []
            self.corrA = False
            self.archADP = False
            self.UNCERTAINTY = False
            self.QNNinit = False

            super().__init__(dynamic=True, **kwargs)
            self.build(None)
            self._initialized = True

    def _SIGNATUREVal(self, bridge, weight_shapes):
        circFunc = bridge.circuit_func
        sig = inspect.signature(circFunc).parameters
        paraNames = list(sig.keys())
        if len(paraNames) == 1 and paraNames[0] == 'params':
            return
            
        if self.input_arg not in sig:
            raise TypeError(
                f"Circuit function must include an argument with name '{self.input_arg}' for inputting data, "
                f"or use a single 'params' parameter that combines inputs and weights."
            )
        
        if self.input_arg in set(weight_shapes.keys()):
            raise ValueError(
                f"{self.input_arg} argument should not have its dimension specified in "
                f"weight_shapes"
            )
        
        paramK = [p.kind for p in sig.values()]
        if inspect.Parameter.VAR_POSITIONAL in paramK:
            raise TypeError("Cannot have a variable number of positional arguments")
        if inspect.Parameter.VAR_KEYWORD not in paramK:
            if set(weight_shapes.keys()) | {self.input_arg} != set(sig.keys()):
                missPARA = set(sig.keys()) - {self.input_arg} - set(weight_shapes.keys())
                exPARA = set(weight_shapes.keys()) - (set(sig.keys()) - {self.input_arg})
                errMSG = "Must specify a shape for every non-input parameter in the circuit function."
                if missPARA:
                    errMSG += f" Missing weight shapes for: {missPARA}"
                if exPARA:
                    errMSG += f" Extra weight shapes provided for: {exPARA}"
                
                raise ValueError(errMSG)

    def build(self, inShape):
        """Create the layer weights and the loss/MAE tracking metrics.

        One weight is added per entry of ``self.weightShapes``; any keyword
        arguments registered for that name in ``self.wSpecs`` are passed to
        ``add_weight``. The created weights are stored in ``self.QnW``.
        """
        self.QnW.update(
            (wName, self.add_weight(name=wName, shape=wShape, **self.wSpecs.get(wName, {})))
            for wName, wShape in self.weightShapes.items()
        )
        super().build(inShape)
        self.lossTr = tf.keras.metrics.Mean(name="loss")
        self.maeTr = tf.keras.metrics.Mean(name="mae")

    def _GetProblemComp(self, inDim, modelS=None):
        if inDim <= 10:
            comp = "simple"
        elif inDim <= 50:
            comp = "medium"
        else:
            comp = "complex"            
        if hasattr(self.bridge, 'circuit_func'):
            try:
                with qml.tape.QuantumTape() as tape:
                    self.bridge.circuit_func(np.random.random(len(self.bridge.params)))
                circOPs = len(tape.operations)
                if circOPs > 30:
                    comp = "complex"
                elif circOPs > 15 and comp == "simple":
                    comp = "medium"
            except Exception as e:
                print(f"Error analyzing circuit complexity: {e}")        
        if modelS and hasattr(modelS, 'layers') and len(modelS.layers) > 5:
            if comp != "complex":
                comp = "medium"        
        return comp

    def _BuildQNNmodel(self, input_dim):
        """Build the dense surrogate network.

        The network input is the layer input concatenated with all flattened
        quantum weights. When ``QNNarch`` is set and no adaptation has happened
        yet, width, depth and activation are first chosen from the estimated
        problem complexity. For the ``"complex"`` level the model gets a second
        sigmoid output head and ``self.UNCERTAINTY`` is set to ``True``;
        otherwise it has a single output and the flag is ``False``.

        Args:
            input_dim: Number of input features (without the weights).

        Returns:
            The constructed ``tf.keras.Model``.
        """
        wDim = sum(np.prod(w.shape) for w in self.QnW.values())
        combDim = input_dim + wDim

        if self.QNNarch and not self.archADP:
            self.pComp = self._GetProblemComp(input_dim)
            self.archADP = True
            # level -> (minimum width, width per input feature, depth, activation)
            presets = {
                "simple": (32, 2, 2, "relu"),
                "medium": (64, 4, 3, "swish"),
            }
            floor, perFeature, depth, act = presets.get(self.pComp, (128, 8, 4, "swish"))
            self.QNNun = max(floor, input_dim * perFeature)
            self.QNNlayers = depth
            self.QNNact = act

        isComplex = self.pComp == "complex"

        def hiddenDense(units, tag):
            # Fresh initializer and regularizer objects for every layer.
            return tf.keras.layers.Dense(
                units,
                activation=self.QNNact,
                kernel_initializer=tf.keras.initializers.GlorotUniform(),
                kernel_regularizer=tf.keras.regularizers.l2(0.0001),
                name=tag)

        inputs = tf.keras.Input(shape=(combDim,))
        x = hiddenDense(self.QNNun, 'QNN_dense_0')(inputs)
        x = tf.keras.layers.BatchNormalization()(x)
        for i in range(1, self.QNNlayers):
            res = x
            if isComplex and i > 1:
                # Deeper layers of a complex model get progressively narrower.
                lUnits = max(32, self.QNNun // (i))
            else:
                lUnits = self.QNNun
            x = hiddenDense(lUnits, f'QNN_dense_{i}')(x)
            x = tf.keras.layers.BatchNormalization()(x)
            if isComplex:
                x = tf.keras.layers.Dropout(0.1)(x)
            # Residual connection only when the widths line up.
            if res.shape[-1] == x.shape[-1]:
                x = tf.keras.layers.add([x, res])

        oSize = self.outDim if isinstance(self.outDim, int) else np.prod(self.outDim)
        outputs = tf.keras.layers.Dense(
            oSize,
            activation=None,
            kernel_initializer=tf.keras.initializers.GlorotUniform(),
            name='QNN_output')(x)

        if not isComplex:
            model = tf.keras.Model(inputs=inputs, outputs=outputs, name='quantum_QNN')
            self.UNCERTAINTY = False
            return model

        uncert = tf.keras.layers.Dense(
            32, activation=self.QNNact,
            name='uncertainty_dense')(x)
        uncert = tf.keras.layers.Dense(
            oSize, activation='sigmoid',
            name='uncertainty_output')(uncert)
        model = tf.keras.Model(
            inputs=inputs,
            outputs=[outputs, uncert],
            name='QNN_with_uncertainty')
        self.UNCERTAINTY = True
        return model

    @tf.function
    def _train_step(self, x, y, sWEIGHTS=None):
        """Run one optimisation step of the surrogate network.

        The loss is the mean squared error between prediction and ``y``,
        optionally multiplied element-wise by ``sWEIGHTS``. With the
        uncertainty head active, the error is additionally scaled by
        ``1 - uncertainty`` and ``0.1 * mean(uncertainty)`` is added. The
        model's regularisation losses are added on top. Gradients are clipped
        to a global norm of 1.0 before being applied.

        Args:
            x: Network inputs.
            y: Regression targets.
            sWEIGHTS: Optional per-sample weights; reshaped to a column when
                its rank differs from that of the predictions.

        Returns:
            Tuple ``(loss, preds)``.
        """
        with tf.GradientTape() as tape:
            netOut = self.QNN(x, training=True)
            if self.UNCERTAINTY:
                preds, uncertainty = netOut
            else:
                preds = netOut

            if sWEIGHTS is not None and len(sWEIGHTS.shape) != len(preds.shape):
                sWEIGHTS = tf.reshape(sWEIGHTS, [-1, 1])

            elemLoss = tf.square(preds - y)
            if sWEIGHTS is not None:
                elemLoss = elemLoss * sWEIGHTS

            if self.UNCERTAINTY:
                # Down-weight errors where the model reports low confidence and
                # penalise uncertainty itself.
                elemLoss = elemLoss * (1.0 - uncertainty)
                loss = tf.reduce_mean(elemLoss)
                loss += 0.1 * tf.reduce_mean(uncertainty)
            else:
                loss = tf.reduce_mean(elemLoss)

            if self.QNN.losses:
                loss += tf.reduce_sum(self.QNN.losses)

        trVars = self.QNN.trainable_variables
        clipped, _ = tf.clip_by_global_norm(tape.gradient(loss, trVars), 1.0)
        self.QNNOpt.apply_gradients(zip(clipped, trVars))
        return loss, preds

    def __initQNN__(self, xSamples):
        """Pre-train the surrogate network once on synthetic data.

        Does nothing when pre-training is disabled (``ptQNN``) or has already
        run (``QNNinit``). Inputs are random normal samples, plus five real
        rows from ``xSamples`` when more than five are available; targets are
        random values drawn around 0.5. Training runs for 20 epochs of plain
        mean-squared-error updates.

        Args:
            xSamples: 2-D batch of layer inputs used for sizing and sampling.
        """
        if self.QNNinit or not self.ptQNN:
            return

        inDim = xSamples.shape[-1]
        nAvail = xSamples.shape[0]
        if self.QNN is None:
            self.QNN = self._BuildQNNmodel(inDim)
            self.QNNOpt = tf.keras.optimizers.Adam(learning_rate=self.QNNlr)

        nSamples = min(100, max(30, nAvail * 2))
        mINs = tf.random.normal([nSamples, inDim])
        if nAvail > 5:
            aIdx = np.random.choice(nAvail, size=min(5, nAvail), replace=False)
            mINs = tf.concat([mINs, tf.gather(xSamples, aIdx)], axis=0)

        combINs = self.__initQNNins(mINs)
        outS = self.outDim if isinstance(self.outDim, int) else np.prod(self.outDim)
        sOuts = tf.random.normal([combINs.shape[0], outS], mean=0.5, stddev=0.1)
        dataset = tf.data.Dataset.from_tensor_slices((combINs, sOuts))
        dataset = dataset.batch(min(16, combINs.shape[0]))

        for _ in range(20):
            # Running totals per epoch; they are not consumed anywhere here.
            eLoss = 0.0
            nBatch = 0
            for xB, yB in dataset:
                with tf.GradientTape() as tape:
                    netOut = self.QNN(xB, training=True)
                    preds = netOut[0] if self.UNCERTAINTY else netOut
                    loss = tf.reduce_mean(tf.square(preds - yB))
                trVars = self.QNN.trainable_variables
                self.QNNOpt.apply_gradients(zip(tape.gradient(loss, trVars), trVars))
                eLoss += loss.numpy()
                nBatch += 1
        self.QNNinit = True

    def _QforwardTraining(self, x):
        """Evaluate the quantum circuit on ``x`` with a hand-written backward pass.

        The forward value is ``self._EvaluateQCircuit(x)``. The backward pass
        depends on ``self.QNNGradient``:

        * ``'default'``: gradient of the surrogate network output with respect
          to its combined input, restricted to the first ``x.shape[-1]``
          columns and scaled by the upstream gradient and ``self.qBLENDw``.
        * ``'parameter_shift'`` (WIP): ``self._QuantumGradEST(x)`` times the
          upstream gradient.
        * anything else (WIP): 0.7 x the surrogate gradient plus 0.3 x a
          finite-difference estimate.

        In every branch the gradients returned for the weights are zeros.
        """
        # NOTE(review): `vars` shadows the builtin of the same name.
        vars = list(self.QnW.values())
        @tf.custom_gradient
        def _forward(x, *weights):
            qRes = self._EvaluateQCircuit(x)
            # NOTE(review): `variables` is accepted but never used below.
            def grad(upsGrads, variables=None):
                if self.QNNGradient == 'default':
                    combINs = self.__initQNNins(x)
                    with tf.GradientTape() as tape:
                        tape.watch(combINs)
                        QNNouts = self.QNN(combINs, training=False)
                        if self.UNCERTAINTY:
                            # Keep only the prediction head, drop the uncertainty head.
                            QNNouts = QNNouts[0]
                    QNNgrads = tape.gradient(QNNouts, combINs)
                    # The leading columns of the combined input are the layer inputs.
                    INgrads = QNNgrads[:, :x.shape[-1]]
                    wGrads = [tf.zeros_like(w) for w in weights]
                    return [INgrads * upsGrads * self.qBLENDw] + wGrads

                elif self.QNNGradient == 'parameter_shift': # WIP
                    INgrads = self._QuantumGradEST(x) * upsGrads
                    wGrads = [tf.zeros_like(w) for w in weights]
                    return [INgrads] + wGrads

                else: # WIP
                    QNNcombINS = self.__initQNNins(x)
                    with tf.GradientTape() as tape:
                        tape.watch(QNNcombINS)
                        QNNouts = self.QNN(QNNcombINS, training=False)
                        if self.UNCERTAINTY:
                            QNNouts = QNNouts[0]
                    QNNgrads = tape.gradient(QNNouts, QNNcombINS)
                    QNNinGrad = QNNgrads[:, :x.shape[-1]]
                    PARAshGrad = tf.zeros_like(x)
                    if x.shape[-1] > 5:
                        # Finite differences on five randomly chosen input columns.
                        # NOTE(review): `id` shadows the builtin of the same name.
                        id = np.random.choice(x.shape[-1], size=min(5, x.shape[-1]), replace=False)
                        for idx in id:
                            # NOTE(review): xP, xm and PARAshGrad come from tf.identity /
                            # tf.zeros_like, i.e. they are tf tensors; slice assignment
                            # such as `xP[:, idx] += ...` is not supported on tf tensors.
                            # Confirm whether this branch has ever been executed.
                            xP = tf.identity(x)
                            xP[:, idx] += self.gEps
                            xm = tf.identity(x)
                            xm[:, idx] -= self.gEps
                            yP = self._EvaluateQCircuit(xP)
                            ym = self._EvaluateQCircuit(xm)
                            PARAshGrad[:, idx] = (yP - ym) / (2 * self.gEps)
                    else:
                        PARAshGrad = self._QuantumGradEST(x)
                    combGrad = 0.7 * QNNinGrad + 0.3 * PARAshGrad
                    INgrads = combGrad * upsGrads
                    wGrads = [tf.zeros_like(w) for w in weights]
                    return [INgrads] + wGrads
            return qRes, grad

        return _forward(x, *vars)

    def _QuantumGradEST(self, x, epsilon=None):
        """Estimate the input gradient of the circuit by central differences.

        Only the first row of ``x`` is perturbed. For each input column ``i``
        the circuit is evaluated at ``+epsilon`` and ``-epsilon`` and the
        resulting slope is written into column ``i`` of every row of the
        returned gradient.

        Args:
            x: 2-D batch of inputs (rows are samples).
            epsilon: Finite-difference step; defaults to ``self.gEps``.

        Returns:
            Tensor with the shape of ``x`` holding the estimated gradient.
        """
        if epsilon is None:
            epsilon = self.gEps
        grad = tf.zeros_like(x)
        xs = x[:1] if x.shape[0] > 1 else x
        nFeat = xs.shape[-1]
        nRows = x.shape[0]
        step = tf.cast(epsilon, xs.dtype)
        for i in range(nFeat):
            # tf tensors do not support item assignment, so the shifted inputs
            # are built by adding/subtracting a one-hot offset instead of
            # mutating a copy in place.
            offset = tf.one_hot(i, nFeat, dtype=xs.dtype) * step
            xP = xs + offset
            xm = xs - offset
            yP = self._EvaluateQCircuit(xP)
            ym = self._EvaluateQCircuit(xm)
            ParaGrad = (yP - ym) / (2 * epsilon)
            # NOTE(review): a single slope value per input column is written,
            # which requires the circuit output for one sample to be a single
            # value; the reshape fails otherwise. Confirm the intended handling
            # of multi-valued outputs.
            slope = tf.reshape(tf.cast(ParaGrad[0], grad.dtype), [1])
            grad = tf.tensor_scatter_nd_update(
                grad,
                indices=[[b, i] for b in range(nRows)],
                updates=tf.repeat(slope, nRows))
        return grad

    def _EvalSolQ(self, qOUTs):
        outs = qOUTs.numpy() if hasattr(qOUTs, 'numpy') else qOUTs        
        zR = np.mean(np.abs(outs) < 1e-6)        
        if outs.size > 1:
            stdDev = np.std(outs)
            uniformity = np.exp(-10 * stdDev)  
        else:
            uniformity = 0.0        
        biR = np.mean(np.logical_or(np.abs(outs) < 0.1, np.abs(outs - 1.0) < 0.1))        
        qualityS = 1.0 - (0.5 * zR + 0.3 * uniformity + 0.2 * (1.0 - biR))        
        qualityS = max(0.0, min(1.0, qualityS))
        return qualityS

    def _CorrectQuantumSol(self, qOUTs, trMetrics=None):
        """Pull low-quality circuit outputs toward a reference pattern.

        The quality score from ``_EvalSolQ`` is recorded in ``solQM`` and
        ``solQH``. Outputs scoring at least 0.6 are returned untouched. Lower
        scores are blended with the closest stored pattern from ``succP``
        (smallest mean absolute difference), or with a constant 0.5 when no
        pattern is stored, and clipped to ``[0, 1]``. The blend strength is
        ``corrST``, multiplied by 1.5 for scores below 0.3.

        Args:
            qOUTs: Circuit outputs to check.
            trMetrics: Accepted for callers; not used in this method.

        Returns:
            Tuple ``(outputs, corrected)`` where ``corrected`` tells whether a
            correction was applied.
        """
        solQua = self._EvalSolQ(qOUTs)
        self.solQM(solQua)
        self.solQH.append(solQua)
        if not (solQua < 0.6):
            return qOUTs, False

        oOUTs = tf.convert_to_tensor(qOUTs, dtype=tf.float32)
        corrF = self.corrST * 1.5 if solQua < 0.3 else self.corrST

        target = None
        bestSim = -float('inf')
        for entry in self.succP:
            candidate = tf.convert_to_tensor(entry['pattern'], dtype=tf.float32)
            # Similarity is the negated mean absolute difference.
            similarity = -tf.reduce_mean(tf.abs(oOUTs - candidate))
            if similarity > bestSim:
                bestSim, target = similarity, candidate
        if target is None:
            # No usable pattern: fall back to the midpoint value.
            target = tf.ones_like(oOUTs) * 0.5

        blended = (1.0 - corrF) * oOUTs + corrF * target
        return tf.clip_by_value(blended, 0.0, 1.0), True

    def call(self, ins, tr=None):
        """Forward pass of the layer.

        In training mode (``tr`` is ``None`` or truthy) the quantum circuit is
        evaluated only while ``self.eQuantumC`` is false; its (possibly
        corrected) result is cached and blended with the surrogate network's
        prediction. All other calls use the surrogate network alone. The
        result is passed through ``sigmoid(2 * res)`` and, for inputs with
        more than one dimension, reshaped back to the leading input dimensions.

        Args:
            ins: Input tensor; inputs with more than one dimension are
                flattened to 2-D.
            tr: Training flag.

        Returns:
            The layer output tensor.
        """
        iTRAIN = tr is None or tr
        batchDim = len(ins.shape) > 1
        if batchDim:
            batchDims = tf.shape(ins)[:-1]
            ins = tf.reshape(ins, (-1, ins.shape[-1]))

        if self.QNN is None:
            self.QNN = self._BuildQNNmodel(ins.shape[-1])
            self.QNNOpt = tf.keras.optimizers.Adam(
                learning_rate=self.QNNlr)
            self.__initQNN__(ins)

        def _surrogate(training):
            # Returns (prediction, uncertainty); uncertainty is None when the
            # network has no uncertainty head.
            netOut = self.QNN(self.__initQNNins(ins), training=training)
            return netOut if self.UNCERTAINTY else (netOut, None)

        if not iTRAIN:
            res, _ = _surrogate(False)
        else:
            self.cBATCHidx += 1
            if self.eQuantumC:
                res, _ = _surrogate(True)
            else:
                res = self._QforwardTraining(ins)
                corrRes, corrApp = self._CorrectQuantumSol(
                    res,
                    trMetrics={
                        'epoch': self.cEPOCH,
                        'batch': self.cBATCHidx
                    })
                if corrApp:
                    self.corrA = True
                    with tf.GradientTape() as tape:
                        results_var = tf.Variable(corrRes, trainable=True)
                        tape.watch(results_var)
                        res = tf.identity(results_var)
                else:
                    res = corrRes

                # Remember the raw samples and cache the result per input row.
                insNp = ins.numpy()
                self.qINs.append(insNp)
                self.qOUTs.append(res.numpy())
                for row, sample in enumerate(insNp):
                    self.epochQcache[self._HASHins(sample)] = res[row].numpy()
                self.eQuantumC = True

                QNNres, uncertainty = _surrogate(True)
                if self.UNCERTAINTY:
                    conf = 1.0 - tf.reduce_mean(uncertainty)
                    bW = tf.minimum(0.8, tf.maximum(0.2, conf))
                    res = bW * res + (1.0 - bW) * QNNres
                else:
                    res = self.qBLENDw * res + (1.0 - self.qBLENDw) * QNNres

        res = tf.nn.sigmoid(res * 2.0)  # scale and squash to 0,1
        if batchDim:
            res = tf.reshape(res, tf.concat([batchDims, tf.shape(res)[1:]], axis=0))
        return res

    def _BOUTS(self, qRes, QNNRes, uncertainty=None):
        if uncertainty is not None:
            confidence = 1.0 - tf.reduce_mean(uncertainty)
            adpW = tf.minimum(0.9, tf.maximum(0.1, confidence))
            qW = 1.0 - adpW
        else:
            qW = self.qBLENDw
            
        bl = qW * qRes + (1.0 - qW) * QNNRes
        return bl

    def _EvaluateQCircuit(self, x):
        """Run the circuit on the sampler and return decoded results as a tensor.

        For circuits taking a single ``params`` argument, the first input row
        and all flattened weights are written into a copy of
        ``bridge.params``, which then replaces ``bridge.params``. The
        Hamiltonian is extracted, converted to a BQM and sampled with one shot
        per batch row. Raw results are logged in ``self.aQuantumS``, padded
        with the last entry or truncated to the batch size and, unless
        ``decMethod`` is ``'all'``, reshaped to ``(batch, outDim)``.

        Args:
            x: 2-D tensor of inputs exposing ``numpy()``.

        Returns:
            ``tf.float32`` tensor with the decoded results.
        """
        batchSize = x.shape[0]
        xN = x.numpy()
        weights = {name: var.numpy() for name, var in self.QnW.items()}
        # Only the first sample of the batch is encoded into the parameters.
        sX = xN[0].flatten()

        if list(inspect.signature(self.bridge.circuit_func).parameters) == ['params']:
            nParams = np.copy(self.bridge.params)
            cursor = len(sX)
            nParams[:cursor] = sX
            for wVal in weights.values():
                wSize = np.prod(wVal.shape)
                if wSize == 1:
                    nParams[cursor] = wVal.item()
                    cursor += 1
                else:
                    nParams[cursor:cursor + wSize] = wVal.flatten()
                    cursor += wSize
            self.bridge.params = nParams

        adpNumREADS, adpANNEALINGtime = self.__ALLOC_QR()
        self.bridge.ExtractHamiltonian()
        sampleset = self.bridge.DynexCompute(
            self.bridge.to_bqm(),
            num_reads=adpNumREADS,
            annealing_time=adpANNEALINGtime,
            mainnet=self.mainnet,
            description='Dynex SDK Job',
            printSolution=False,
            debugging=False,
            is_cluster=True,
            shots=batchSize,)
        aRes = self.bridge.DecodeSolution(sampleset, method='all')
        pRes = self.bridge.DecodeSolution(sampleset, method='probs')
        self.aQuantumS.append({
            'all_results': aRes,
            'prob_results': pRes,
            'inputs': xN,
            'weights': weights,
            'num_reads': adpNumREADS,
            'annealing_time': adpANNEALINGtime,
            'epoch': self.cEPOCH
        })

        # Make the number of decoded results match the batch size.
        shortfall = batchSize - len(aRes)
        if shortfall > 0:
            aRes = aRes + [aRes[-1]] * shortfall
        elif shortfall < 0:
            aRes = aRes[:batchSize]

        resTensor = tf.convert_to_tensor(aRes, dtype=tf.float32)
        if self.decMethod != 'all':
            if isinstance(self.outDim, tuple):
                outShape = (batchSize,) + self.outDim
            else:
                outShape = (batchSize, self.outDim)
            resTensor = tf.reshape(resTensor, outShape)
        return resTensor

    def __ALLOC_QR(self):
        if self.cEPOCH < 3:
            return self.num_reads, self.annealing_time        
        if len(self.qLossh) > 1:
            lossCH = self.qLossh[-1] - self.qLossh[-2]            
            if lossCH < -0.01:
                progF = min(1.0, abs(lossCH) / (self.qLossh[-2] + 1e-10))
                adpNumReads = max(int(self.num_reads * (1.0 - progF * 0.3)), 256)
                adpAnnealingTime = max(int(self.annealing_time * (1.0 - progF * 0.3)), 128)
                #print(f"Good progress BRO :D - reducing quantum resources (num_reads={adaptive_num_reads}, time={adaptive_annealing_time})")
                
            elif lossCH > 0.005:  
                stagF = min(1.0, abs(lossCH) / (self.qLossh[-2] + 1e-10))
                adpNumReads = min(int(self.num_reads * (1.0 + stagF * 0.5)), 1024)
                adpAnnealingTime = min(int(self.annealing_time * (1.0 + stagF * 0.5)), 512)
                #print(f"Bad progress :O - increasing quantum resources (num_reads={adaptive_num_reads}, time={adaptive_annealing_time})")
            
            else:
                adpNumReads = self.num_reads
                adpAnnealingTime = self.annealing_time
        else:
            adpNumReads = self.num_reads
            adpAnnealingTime = self.annealing_time
        return adpNumReads, adpAnnealingTime

    def __initQNNins(self, x):
        """Append the flattened layer weights to every row of ``x``.

        Each tensor in ``self.QnW`` is reshaped to 1-D and the pieces are
        concatenated in the dict's iteration order.  The resulting vector is
        repeated once per row of ``x`` (``tf.shape(x)[0]`` times) and joined to
        ``x`` along axis 1.
        """
        flatW = tf.concat([tf.reshape(w, [-1]) for w in self.QnW.values()], axis=0)
        nRows = tf.shape(x)[0]
        tiledW = tf.repeat(tf.expand_dims(flatW, axis=0), nRows, axis=0)
        return tf.concat([x, tiledW], axis=1)

    def _HASHins(self, x):
        x_tuple = tuple(x.flatten().astype(np.float32))
        return hash(x_tuple)

    def _CalcSampleIMP(self, INs, OUTs, modelLoss=None):
        tSamples = sum(len(x) for x in INs) if isinstance(INs, list) else len(INs)        
        impW = np.ones(tSamples)        
        try:
            if isinstance(INs, list):
                aINs = np.vstack(INs)
            else:
                aINs = INs
        except Exception as e:
            return impW        
        if len(aINs) > 1:
            if aINs.shape[0] <= 1000:
                dist = euclidean_distances(aINs)
                meanDist = np.sum(dist, axis=1) / (len(aINs) - 1)
                if np.max(meanDist) > 0:
                    divW = 0.5 + meanDist / np.max(meanDist)
                    impW *= divW
            else:
                fSTDs = np.std(aINs, axis=0)
                avgSTD = np.mean(fSTDs)
                divF = np.clip(avgSTD, 0.5, 1.5)
                impW *= divF
        
        if modelLoss is not None:
            try:
                if hasattr(modelLoss, '__len__') and len(modelLoss) == len(impW):
                    normalized_loss = modelLoss / np.max(modelLoss) if np.max(modelLoss) > 0 else modelLoss
                    impW *= (0.5 + 0.5 * normalized_loss)
            except Exception as e:
                print(f"Error processing loss values: {e}")        
        if np.mean(impW) > 0:
            impW = impW / np.mean(impW)
        return impW

    def _AdaptQNNTraining(self, eml=None, emc=None):
        epochs = self.QNNe
        bSize = self.QNNbs
        lr = self.QNNlr        
        if len(self.solQH) > 0:
            avgQ = np.mean(self.solQH[-min(3, len(self.solQH)):])            
            if avgQ < 0.4:  
                epochs = int(self.QNNe * 1.3)  
                lr = self.QNNlr * 1.5 
                bSize = max(8, self.QNNbs // 2)  
            
            else:
                epochs = int(self.QNNe * 0.8) 
                lr = self.QNNlr * 0.8 
                bSize = min(self.QNNbs * 2, 32)  
                
        if emc is not None and len(self.qACCh) > 1:
            accIMP = emc - self.qACCh[-1]            
            if accIMP > 0.05:
                epochs = max(int(epochs * 0.8), 20)
            
            elif accIMP < 0.01:
                epochs = min(int(epochs * 1.2), 100)                
        return epochs, bSize, lr

    def _UpdatePatterns(self, qOUTs, modelImprov):
        if modelImprov > 0.01: 
            self.succP.append({
                'pattern': qOUTs,
                'improvement': modelImprov,
                'epoch': self.cEPOCH
            })
            
            if len(self.succP) > 10:
                self.succP = sorted(
                    self.succP, 
                    key=lambda x: x['improvement'], 
                    reverse=True)[:10]
                
    def _TrainQNN(self, eml=None, emc=None):
        if not self.qINs or not self.qOUTs:
            print("No quantum data available for QNN training")
            return False

        adptE, adptbS, adptLR = self._AdaptQNNTraining(
            eml, emc)
        
        if adptLR != self.QNNlr:
            self.QNNlr = adptLR
            self.QNNOpt.learning_rate.assign(adptLR)
        try:
            aINs = np.vstack(self.qINs)
            aOUTs = np.vstack(self.qOUTs)
        except Exception as e:
            print(f"Error combining quantum data: {e}")
            return False
        
        xT = tf.convert_to_tensor(aINs, dtype=tf.float32)
        combIN = self.__initQNNins(xT)
        yT = tf.convert_to_tensor(aOUTs, dtype=tf.float32)        
        sampleW = None
        if self.imSampling:
            try:
                sampleW = self._CalcSampleIMP(
                    self.qINs, 
                    self.qOUTs,
                    eml)
                
                if sampleW is not None and len(sampleW) == len(aINs):
                    pass
                else:
                    sampleW = np.ones(len(aINs))
            except Exception as e:
                sampleW = np.ones(len(aINs))
        
        try:
            if sampleW is not None:
                wTensor = tf.convert_to_tensor(sampleW, dtype=tf.float32)
                if len(wTensor.shape) == 1 and wTensor.shape[0] == combIN.shape[0]:
                    dataset = tf.data.Dataset.from_tensor_slices((combIN, yT, wTensor))
                    dataset = dataset.shuffle(buffer_size=min(1000, len(aINs))).batch(adptbS)
                else:
                    dataset = tf.data.Dataset.from_tensor_slices((combIN, yT))
                    dataset = dataset.shuffle(buffer_size=min(1000, len(aINs))).batch(adptbS)
            else:
                dataset = tf.data.Dataset.from_tensor_slices((combIN, yT))
                dataset = dataset.shuffle(buffer_size=min(1000, len(aINs))).batch(adptbS)
        except Exception as e:
            return False
        
        tLoss = 0
        try:
            for e in range(adptE):
                eLoss = 0
                nBatch = 0
                for bd in dataset:
                    if len(bd) == 3:
                        xB, yB, wB = bd
                        loss, _ = self._train_step(xB, yB, wB)
                    else:
                        xB, yB = bd
                        loss, _ = self._train_step(xB, yB)
                    eLoss += loss
                    nBatch += 1  
                avgELoss = eLoss / nBatch if nBatch > 0 else 0
                tLoss = avgELoss
        except Exception as e:
            return False
        
        try:
            latINs = self.qINs[-1]
            latOUTs = self.qOUTs[-1]
            
            xEVAL = tf.convert_to_tensor(latINs, dtype=tf.float32)
            combEVAL = self.__initQNNins(xEVAL)
            yEVAL = tf.convert_to_tensor(latOUTs, dtype=tf.float32)
            
            # predictions (with or without uncertainty)
            if self.UNCERTAINTY:
                preds, uncertainty = self.QNN(combEVAL, training=False)
                mae = tf.reduce_mean(tf.abs(preds - yEVAL)).numpy()
            else:
                preds = self.QNN(combEVAL, training=False)
                mae = tf.reduce_mean(tf.abs(preds - yEVAL)).numpy()
            
            if self.eBlending:
                QNNconf = max(0.1, min(0.9, 1.0 - mae))
                self.qBLENDw = 1.0 - QNNconf
            self.QNNacc(1.0 - mae)
            self.qQNNdiv(mae)
            
            return mae <= self.qDivergenceTHR
        except Exception as e:
            return False

    def _UpdateLM(self, eLoss, eAcc):
        preAcc = self.qACCh[-1] if len(self.qACCh) > 0 else 0        
        self.qLossh.append(eLoss)
        self.qACCh.append(eAcc)        
        accIMPROV = eAcc - preAcc        
        if len(self.qOUTs) > 0:
            self._UpdatePatterns(self.qOUTs[-1], accIMPROV)

    def _ONEPBegin(self, e, eLogs=None):
        self.cEPOCH = e
        self.cBATCHidx = 0        
        self.eQuantumC = False  
        self.epochQcache = {}
        self.corrA = False                
        nQuantum = True        
        if eLogs and 'val_loss' in eLogs:
            valLossCH = self.lastEPvalL - eLogs['val_loss']
            valAccCH = eLogs['val_accuracy'] - self.lastEPvalA            
            if (valLossCH > 0 and valAccCH > 0 and 
                self.qQNNdiv.result() < self.qDivergenceTHR and 
                np.random.random() > self.qSchedulingF):
                nQuantum = False
        
        if not nQuantum:
            self.eQuantumC = True

    def _ONEPEnd(self, e, eLogs=None):
        eLoss = None
        eAcc = None
        eValLoss = None
        eValAcc = None
        if eLogs:
            if 'loss' in eLogs:
                eLoss = eLogs['loss']
                
            if 'accuracy' in eLogs:
                eAcc = eLogs['accuracy']
                
            if 'val_loss' in eLogs:
                eValLoss = eLogs['val_loss']
                self.lastEPvalL = eValLoss
                
            if 'val_accuracy' in eLogs:
                eValAcc = eLogs['val_accuracy']
                self.lastEPvalA = eValAcc
        
        if eLoss is not None and eAcc is not None:
            self._UpdateLM(eLoss, eAcc)
        
        QNNAcc = self._TrainQNN(eLoss, eAcc)        
        if self.eBlending and eLogs and 'val_accuracy' in eLogs:
            val_acc = eLogs['val_accuracy']
            self.qBLENDw = max(0.1, min(0.9, 1.0 - val_acc))

        self.QNNacc.reset_states()
        self.qQNNdiv.reset_states()
        self.solQM.reset_states()
            
        if len(self.qLossh) > 2:
            recLOSS = self.qLossh[-3:]
            if max(recLOSS) - min(recLOSS) < 0.01:
                self.QNNlr *= 0.7
                self.QNNOpt.learning_rate.assign(self.QNNlr)

    def compute_output_shape(self, input_shape):
        """Return the shape made of ``input_shape[0]`` followed by ``self.outDim``."""
        leadingDim = tf.TensorShape([input_shape[0]])
        return leadingDim.concatenate(self.outDim)
    
    def get_config(self):
        """Return the parent config extended with this layer's own settings.

        The keys added are the constructor-style names (``output_dim``,
        ``weight_specs``, ``QNN_units`` ...) mapped to the matching instance
        attributes.
        """
        config = super().get_config()
        layerSettings = dict(
            output_dim=self.outDim,
            weight_specs=self.wSpecs,
            weight_shapes=self.weightShapes,
            num_reads=self.num_reads,
            annealing_time=self.annealing_time,
            mainnet=self.mainnet,
            shots=self.shots,
            decode_method=self.decMethod,
            QNN_units=self.QNNun,
            QNN_layers=self.QNNlayers,
            QNN_activation=self.QNNact,
            QNN_lr=self.QNNlr,
            importance_sampling=self.imSampling,
            ensemble_blending=self.eBlending,
            quantum_regularization_weight=self.qRegWEIGHT,
            auto_architecture=self.QNNarch,
            correction_strength=self.corrST,
            gradient_estimation=self.QNNGradient,
            gradient_epsilon=self.gEps,
            pre_train_QNN=self.ptQNN,
        )
        config.update(layerSettings)
        return config
        
    def __str__(self):
        detail = "<Dynex Quantum Layer: func={}>"
        return detail.format(self.bridge.circuit_func.__name__)

    # repr() of the layer produces the same text as str()
    __repr__ = __str__
    # Class-level default returned by the ``input_arg`` property; ``_SetInputARG`` overwrites it
    _input_arg = "inputs"
    # NOTE(review): not read or written in the part of the class reviewed here - confirm where it is used
    _initialized = False

    @property
    def input_arg(self):
        """Return ``self._input_arg``, whose class-level default is ``"inputs"``."""
        return self._input_arg

    @staticmethod
    def _SetInputARG(input_name: Text = "inputs") -> None:
        """Set the class-wide ``QKerasLayer._input_arg``; does nothing when ``HAS_TENSORFLOW`` is false."""
        if not HAS_TENSORFLOW:
            return
        QKerasLayer._input_arg = input_name
else:
    # Stub classes when tensorflow is not available
    QuantumCallback = None
    QKerasLayer = None


class PennylaneBridge:
    def __init__(self, circuit, params, wires, draw=True, type='PennyLane', QKerasLayer=False):
        """Record the circuit once and translate it into a penalty Hamiltonian.

        Args:
            circuit: callable invoked as ``circuit(params)`` inside a
                ``qml.tape.QuantumTape`` to record its operations.
            params: argument passed to ``circuit``.
            wires: an int or a sized collection of wires; forwarded to
                ``qml.device`` and stored as a count in ``self.wires``.
            draw: when true, print the circuit drawing.
            type: not used by this constructor.
            QKerasLayer: not used by this constructor.

        One ``ComplexQubit`` named ``q_<i>`` is created for every index from 0
        up to the largest wire touched by a recorded operation.  The
        constructor finishes by calling ``ExtractHamiltonian()`` and
        ``to_bqm()`` and storing their results.
        """
        self.circuit_func = circuit
        self.circuit = circuit
        self.params = params
        # taking full measure approach :)
        self.wires = wires if isinstance(wires, int) else len(wires)
        self.device = qml.device('default.qubit', wires=wires)

        with qml.tape.QuantumTape() as tape:
            self.circuit(self.params)
        max_wire = max(w for recorded in tape.operations for w in recorded.wires)
        self.qubits = {idx: ComplexQubit(f'q_{idx}') for idx in range(max_wire + 1)}

        self.hamiltonian = 0
        self.isGrover = False
        self.isQPE = False
        self.isCQU = False
        self.isQU = False  # Since it has sort of similar behaviour like CQU
        self.isAE = False
        self.globalConst = []
        self.placeholders = {}

        # Operation name (``op.name``) -> method that adds the matching penalty term
        self.gHandlers = {
            'PauliX': self._PAULIX, 'PauliY': self._PAULIY, 'PauliZ': self._PAULIZ,
            'RX': self._RX, 'RY': self._RY, 'RZ': self._RZ,
            'CNOT': self._CNOT, 'Hadamard': self._H, 'CZ': self._CZ, 'SWAP': self._SWAP,
            'CRX': self._CRX, 'CRY': self._CRY, 'CRZ': self._CRZ,
            'Toffoli': self._TOFFOLI, 'T': self._T, 'QFT': self._QFT,
            'BasisEmbedding': self._BASISEMB, 'BasisState': self._BASISSTATE,
            'FlipSign': self._FLIPSIGN, 'GroverOperator': self._GROVER,
            'QuantumPhaseEstimation': self._QPE,
            'QubitUnitary': self._QUBITUNITARY,
            'ControlledQubitUnitary': self._CONTROLLEDQUBITUNITARY,
            'ControlledPhaseShift': self._CONTROLLEDPHASESHIFT,
            'S': self._S, 'SX': self._SX,
            'AngleEmbedding': self._ANGLEEMBEDDING,
            'AmplitudeEmbedding': self._AMPLITUDEEMBEDDING,
            'Rot': self._ROT,
            'StronglyEntanglingLayers': self._STRONGLYENTANGLING,
        }

        # Every penalty weight starts at 1.0; SetPenaltyScale changes single entries.
        # NOTE(review): the key here is 'QPE' while the handler table uses
        # 'QuantumPhaseEstimation' - confirm which name the QPE handler reads.
        self.penaltyScale = dict.fromkeys((
            'PauliX', 'PauliY', 'PauliZ',
            'RX', 'RY', 'RZ',
            'CNOT', 'Hadamard',
            'CZ', 'SWAP', 'CRX', 'CRY', 'CRZ',
            'Toffoli', 'T', 'QFT',
            'BasisEmbedding', 'BasisState', 'GlobalConstraint',
            'FlipSign',
            'Ctrl',
            'GroverOperator',
            'QPE',
            'QubitUnitary',
            'ControlledQubitUnitary',
            'ControlledPhaseShift',
            'S',
            'SX',
            'AngleEmbedding',
            'AmplitudeEmbedding',
            'Rot',
            'StronglyEntanglingLayers',
        ), 1.0)

        if draw:
            print("-----------/ Visualize Circuit /-----------")
            print(qml.draw(circuit)(params))
            print("-----------/ ***************** /-----------")

        self.hamiltonian = self.ExtractHamiltonian()
        self.bqm = self.to_bqm()

    def SetPenaltyScale(self, gate_type, scale):
        if gate_type in self.penaltyScale:
            self.penaltyScale[gate_type] = scale
        else:
            raise ValueError(f"Unknown gate type: {gate_type}")

    def AddGlobalConstraint(self, constraint_func, label, scale=1.0):
        self.globalConst.append((constraint_func, label, scale))

    def ExtractHamiltonian(self):
        """Rebuild ``self.hamiltonian`` from the circuit and the global constraints.

        The circuit is recorded on a fresh ``qml.tape.QuantumTape`` with the
        current ``self.params``; each recorded operation goes through
        ``_HandleGate``; then every queued global constraint is evaluated on
        ``self.qubits`` and added with its scale times
        ``self.penaltyScale['GlobalConstraint']``.

        The Hamiltonian and the per-wire qubits are reset first.  Without the
        reset a second call (for example after ``self.params`` was changed)
        added a second copy of every term on top of the old ones and chained
        from qubits already replaced by the previous pass.

        Returns:
            The rebuilt Hamiltonian (also stored in ``self.hamiltonian``).
        """
        # Fresh start: same wire keys, initial ``q_<wire>`` qubits, empty Hamiltonian
        self.hamiltonian = 0
        self.qubits = {wire: ComplexQubit(f'q_{wire}') for wire in self.qubits}

        with qml.tape.QuantumTape() as tape:
            self.circuit(self.params)
        for op in tape.operations:
            self._HandleGate(op)

        globalScale = self.penaltyScale['GlobalConstraint']
        for constraint_func, label, scale in self.globalConst:
            constraint = constraint_func(self.qubits)
            self.hamiltonian += scale * globalScale * Constraint(constraint, label=label)
        return self.hamiltonian
             
    def _HandleGate(self, op):
        gType = op.name
        #print(gType)
        scale = self.penaltyScale.get(gType, 1.0)
        try:
            if gType.startswith('Adjoint'):
                adjointNum = 0 # I added this counter for the nested Adjoint(dagger) gates (technically this is just FlipFlop method :))
                cOP = op
                while hasattr(cOP, 'base'):
                    if hasattr(cOP, 'adjoint') and cOP.adjoint:
                        adjointNum += 1
                    cOP = cOP.base
                bGate = cOP.name
                if bGate in self.gHandlers:
                    isADJOINT = adjointNum % 2 == 1
                    self.gHandlers[bGate](cOP, scale, adjoint=isADJOINT)
                else:
                    raise ValueError(f"Unknown adjoint operation: {bGate}")
            elif gType.startswith('C('):
                cOP = op.base
                cgType = cOP.name
                scale = self.penaltyScale.get(cgType, 1.0)
                self._CTRL(op, cOP, scale)
            elif gType in self.gHandlers:
                self.gHandlers[gType](op, scale)
            else:
                raise ValueError(f"Unknown gate type: {gType}")
        except Exception as e:
            print(f"Error handling gate {gType}: {str(e)}")
            raise

    def _PAULIX(self, op, scale):
        """Add the penalty term used for a Pauli-X on ``op.wires[0]``.

        A new ``ComplexQubit`` ``q_<wire>_paulix`` is tied to the wire's
        current qubit with real and imaginary parts exchanged.  The squared
        mismatch, scaled by ``scale`` and labelled ``PauliX_<wire>``, is added
        to ``self.hamiltonian`` and the new qubit becomes the wire's qubit.
        """
        wire = op.wires[0]
        before = self.qubits[wire]
        after = ComplexQubit(f'q_{wire}_paulix')
        realGap = after.real - before.imag
        imagGap = after.imag - before.real
        self.hamiltonian += scale * Constraint(realGap**2 + imagGap**2, label=f'PauliX_{wire}')
        self.qubits[wire] = after
    
    def _PAULIY(self, op, scale):
        """Add the penalty term used for a Pauli-Y on ``op.wires[0]``.

        A new ``ComplexQubit`` ``q_<wire>_pauliy`` is constrained so that its
        real part equals the old imaginary part and its imaginary part equals
        the negated old real part.  The term is labelled ``PauliY_<wire>`` and
        the new qubit replaces the wire's qubit.
        """
        wire = op.wires[0]
        before = self.qubits[wire]
        after = ComplexQubit(f'q_{wire}_pauliy')
        realGap = after.real - before.imag
        imagGap = after.imag + before.real
        self.hamiltonian += scale * Constraint(realGap**2 + imagGap**2, label=f'PauliY_{wire}')
        self.qubits[wire] = after
    
    def _PAULIZ(self, op, scale):
        """Add the penalty term used for a Pauli-Z on ``op.wires[0]``.

        A new ``ComplexQubit`` ``q_<wire>_pauliz`` is constrained to keep the
        old real part and to carry the negated old imaginary part.  The term is
        labelled ``PauliZ_<wire>`` and the new qubit replaces the wire's qubit.
        """
        wire = op.wires[0]
        before = self.qubits[wire]
        after = ComplexQubit(f'q_{wire}_pauliz')
        realGap = after.real - before.real
        imagGap = after.imag + before.imag
        self.hamiltonian += scale * Constraint(realGap**2 + imagGap**2, label=f'PauliZ_{wire}')
        self.qubits[wire] = after
    
    def _RX(self, op, scale, adjoint=False):
        """Add the penalty term used for an RX rotation on ``op.wires[0]``.

        The angle is ``op.parameters[0]``, negated when ``adjoint`` is true.
        A new ``ComplexQubit`` ``q_<wire>_rx`` is constrained to
        ``(c*re - s*im, c*im + s*re)`` of the current qubit, where ``c`` / ``s``
        are the placeholders ``cos_rx_<wire>`` / ``sin_rx_<wire>``.  Their
        values, the cosine and sine of half the angle, are stored in
        ``self.placeholders``.  All names depend on the wire only.
        """
        wire = op.wires[0]
        sign = -1 if adjoint else 1
        theta = op.parameters[0] * sign
        cosKey, sinKey = f'cos_rx_{wire}', f'sin_rx_{wire}'

        before = self.qubits[wire]
        after = ComplexQubit(f'q_{wire}_rx')
        cosP = Placeholder(cosKey)
        sinP = Placeholder(sinKey)
        realGap = after.real - (cosP * before.real - sinP * before.imag)
        imagGap = after.imag - (cosP * before.imag + sinP * before.real)
        self.hamiltonian += scale * Constraint(realGap**2 + imagGap**2, label=f'RX_{wire}')
        self.qubits[wire] = after

        self.placeholders[cosKey] = float(np.cos(theta/2))
        self.placeholders[sinKey] = float(np.sin(theta/2))

    def _RY(self, op, scale, adjoint=False):
        """Add the penalty term used for an RY rotation on ``op.wires[0]``.

        The angle is ``op.parameters[0]``, negated when ``adjoint`` is true.
        A new ``ComplexQubit`` ``q_<wire>_ry`` is constrained to
        ``(c*re - s*(1 - re), c*im + s*im)`` of the current qubit, with the
        placeholders ``cos_ry_<wire>`` / ``sin_ry_<wire>`` set to the cosine
        and sine of half the angle in ``self.placeholders``.
        """
        wire = op.wires[0]
        sign = -1 if adjoint else 1
        theta = op.parameters[0] * sign
        cosKey, sinKey = f'cos_ry_{wire}', f'sin_ry_{wire}'

        before = self.qubits[wire]
        after = ComplexQubit(f'q_{wire}_ry')
        cosP = Placeholder(cosKey)
        sinP = Placeholder(sinKey)
        realGap = after.real - (cosP * before.real - sinP * (1 - before.real))
        imagGap = after.imag - (cosP * before.imag + sinP * before.imag)
        self.hamiltonian += scale * Constraint(realGap**2 + imagGap**2, label=f'RY_{wire}')
        self.qubits[wire] = after

        self.placeholders[cosKey] = float(np.cos(theta/2))
        self.placeholders[sinKey] = float(np.sin(theta/2))

    def _RZ(self, op, scale, adjoint=False):
        """Add the penalty term used for an RZ rotation on ``op.wires[0]``.

        The angle is ``op.parameters[0]``, negated when ``adjoint`` is true.
        A new ``ComplexQubit`` ``q_<wire>_rz`` is constrained to
        ``(c*re + s*im, c*im - s*re)`` of the current qubit, with the
        placeholders ``cos_rz_<wire>`` / ``sin_rz_<wire>`` set to the cosine
        and sine of half the angle in ``self.placeholders``.
        """
        wire = op.wires[0]
        sign = -1 if adjoint else 1
        theta = op.parameters[0] * sign
        cosKey, sinKey = f'cos_rz_{wire}', f'sin_rz_{wire}'

        before = self.qubits[wire]
        after = ComplexQubit(f'q_{wire}_rz')
        cosP = Placeholder(cosKey)
        sinP = Placeholder(sinKey)
        realGap = after.real - (cosP * before.real + sinP * before.imag)
        imagGap = after.imag - (cosP * before.imag - sinP * before.real)
        self.hamiltonian += scale * Constraint(realGap**2 + imagGap**2, label=f'RZ_{wire}')
        self.qubits[wire] = after

        self.placeholders[cosKey] = float(np.cos(theta/2))
        self.placeholders[sinKey] = float(np.sin(theta/2))

    def _CNOT(self, op, scale):
        """Add the penalty term used for a CNOT on ``op.wires`` (control, target).

        A new target qubit ``q_<target>_cnot`` is constrained so that its real
        part equals ``Xor(control.real, target.real)`` and its imaginary part
        equals the old target's.  The term is labelled
        ``CNOT_<control>_<target>``; only the target's qubit is replaced.
        """
        ctrlWire, tgtWire = op.wires
        ctrlQ, tgtQ = self.qubits[ctrlWire], self.qubits[tgtWire]
        nextQ = ComplexQubit(f'q_{tgtWire}_cnot')
        realGap = Xor(ctrlQ.real, tgtQ.real) - nextQ.real
        imagGap = tgtQ.imag - nextQ.imag
        self.hamiltonian += scale * Constraint(
            realGap**2 + imagGap**2, label=f'CNOT_{ctrlWire}_{tgtWire}')
        self.qubits[tgtWire] = nextQ

    def _H(self, op, scale):
        """Add the penalty term used for a Hadamard on ``op.wires[0]``.

        A new ``ComplexQubit`` ``q_<wire>_h`` is constrained to
        ``((re + im)/sqrt(2), (im - re)/sqrt(2))`` of the current qubit.  The
        term is labelled ``Hadamard_<wire>`` and the new qubit replaces the
        wire's qubit.
        """
        wire = op.wires[0]
        before = self.qubits[wire]
        after = ComplexQubit(f'q_{wire}_h')
        realGap = after.real - (before.real + before.imag) / np.sqrt(2)
        imagGap = after.imag - (before.imag - before.real) / np.sqrt(2)
        self.hamiltonian += scale * Constraint(realGap**2 + imagGap**2, label=f'Hadamard_{wire}')
        self.qubits[wire] = after

    def _CZ(self, op, scale):
        """Add the penalty term used for a CZ on ``op.wires`` (control, target).

        A new target qubit ``q_<target>_cz`` keeps the old target's real part,
        and its imaginary part is constrained to
        ``1 - 2 * control.real * target.imag``.  The term is labelled
        ``CZ_<control>_<target>``; only the target's qubit is replaced.
        """
        ctrlWire, tgtWire = op.wires
        ctrlQ, tgtQ = self.qubits[ctrlWire], self.qubits[tgtWire]
        nextQ = ComplexQubit(f'q_{tgtWire}_cz')
        realGap = nextQ.real - tgtQ.real
        imagGap = nextQ.imag - (1 - 2 * ctrlQ.real * tgtQ.imag)
        self.hamiltonian += scale * Constraint(
            realGap**2 + imagGap**2, label=f'CZ_{ctrlWire}_{tgtWire}')
        self.qubits[tgtWire] = nextQ


    def _ROT(self, op, scale, adjoint=False):
        """Add the penalty term used for a ``Rot(phi, theta, omega)`` gate.

        Matrix representation noted by the original author:
        [
            [e^(-i(phi+omega)/2)cos(theta/2), -e^(i(phi-omega)/2)sin(theta/2)],
            [e^(-i(phi-omega)/2)sin(theta/2), e^(i(phi+omega)/2)cos(theta/2)]
        ]

        For the adjoint the three parameters are reversed and negated.  Six
        placeholders per wire hold the cosine/sine of ``theta/2``,
        ``(phi+omega)/2`` and ``(phi-omega)/2``.  The new qubit
        ``q_<wire>_rot`` is constrained to a linear combination of the old real
        and imaginary parts built from them, labelled ``Rot_<wire>``.
        """
        wire = op.wires[0]
        angles = [-p for p in reversed(op.parameters)] if adjoint else op.parameters
        phi, theta, omega = angles

        kCosT = f'cos_rot_half_theta_{wire}'
        kSinT = f'sin_rot_half_theta_{wire}'
        kCosPlus = f'cos_phase_phi_plus_omega_{wire}'
        kSinPlus = f'sin_phase_phi_plus_omega_{wire}'
        kCosMinus = f'cos_phase_phi_minus_omega_{wire}'
        kSinMinus = f'sin_phase_phi_minus_omega_{wire}'

        before = self.qubits[wire]
        after = ComplexQubit(f'q_{wire}_rot')
        cT, sT = Placeholder(kCosT), Placeholder(kSinT)
        cPlus, sPlus = Placeholder(kCosPlus), Placeholder(kSinPlus)
        cMinus, sMinus = Placeholder(kCosMinus), Placeholder(kSinMinus)

        real_part = (cT * cPlus * before.real
                     + sT * sPlus * before.imag
                     - sT * cMinus * before.imag
                     - cT * sPlus * before.real)
        imag_part = (cT * sPlus * before.imag
                     + sT * cMinus * before.real
                     + sT * sMinus * before.imag
                     + cT * cPlus * before.real)

        realGap = after.real - real_part
        imagGap = after.imag - imag_part
        self.hamiltonian += scale * Constraint(realGap**2 + imagGap**2, label=f'Rot_{wire}')
        self.qubits[wire] = after

        self.placeholders[kCosT] = float(np.cos(theta/2))
        self.placeholders[kSinT] = float(np.sin(theta/2))
        halfPlus = (phi + omega)/2
        halfMinus = (phi - omega)/2
        self.placeholders[kCosPlus] = float(np.cos(halfPlus))
        self.placeholders[kSinPlus] = float(np.sin(halfPlus))
        self.placeholders[kCosMinus] = float(np.cos(halfMinus))
        self.placeholders[kSinMinus] = float(np.sin(halfMinus))

    
    def _SWAP(self, op, scale):
        """Add the penalty term used for a SWAP of the two wires in ``op.wires``.

        Two new qubits ``q_<wire>_swap`` are created.  The one for the first
        wire is constrained to equal the second wire's current qubit and vice
        versa.  One combined term labelled ``SWAP_<w0>_<w1>`` is added and both
        wires get their new qubit.
        """
        wireA, wireB = op.wires
        oldA, oldB = self.qubits[wireA], self.qubits[wireB]
        newA, newB = ComplexQubit(f'q_{wireA}_swap'), ComplexQubit(f'q_{wireB}_swap')
        gapA = (newA.real - oldB.real)**2 + (newA.imag - oldB.imag)**2
        gapB = (newB.real - oldA.real)**2 + (newB.imag - oldA.imag)**2
        self.hamiltonian += scale * Constraint(gapA + gapB, label=f'SWAP_{wireA}_{wireB}')
        self.qubits[wireA], self.qubits[wireB] = newA, newB

    def _CRX(self, op, scale, adjoint=False):
        """Add the penalty term used for a controlled RX on (control, target).

        The new target qubit ``q_<target>_crx`` is constrained to
        ``control.real * rotated + (1 - control.real) * unchanged`` per
        component, where ``rotated`` is ``(c*re - s*im, s*re + c*im)`` of the
        old target.  ``c`` / ``s`` are placeholders
        ``cos_crx_<control>_<target>`` / ``sin_crx_<control>_<target>`` set to
        the cosine and sine of half the (optionally negated) angle.
        """
        ctrlWire, tgtWire = op.wires
        sign = -1 if adjoint else 1
        theta = op.parameters[0] * sign
        cosKey = f'cos_crx_{ctrlWire}_{tgtWire}'
        sinKey = f'sin_crx_{ctrlWire}_{tgtWire}'

        ctrlQ, tgtQ = self.qubits[ctrlWire], self.qubits[tgtWire]
        nextQ = ComplexQubit(f'q_{tgtWire}_crx')
        cosP = Placeholder(cosKey)
        sinP = Placeholder(sinKey)
        realGap = nextQ.real - (ctrlQ.real * (cosP * tgtQ.real - sinP * tgtQ.imag)
                                + (1 - ctrlQ.real) * tgtQ.real)
        imagGap = nextQ.imag - (ctrlQ.real * (sinP * tgtQ.real + cosP * tgtQ.imag)
                                + (1 - ctrlQ.real) * tgtQ.imag)
        self.hamiltonian += scale * Constraint(
            realGap**2 + imagGap**2, label=f'CRX_{ctrlWire}_{tgtWire}')
        self.qubits[tgtWire] = nextQ

        self.placeholders[cosKey] = float(np.cos(theta/2))
        self.placeholders[sinKey] = float(np.sin(theta/2))

    def _CRY(self, op, scale, adjoint=False):
        """Add the penalty term used for a controlled RY on (control, target).

        The new target qubit ``q_<target>_cry`` is constrained to
        ``control.real * rotated + (1 - control.real) * unchanged`` per
        component, where ``rotated`` is ``(c*re - s*(1 - re), c*im + s*im)`` of
        the old target.  ``c`` / ``s`` are placeholders
        ``cos_cry_<control>_<target>`` / ``sin_cry_<control>_<target>`` set to
        the cosine and sine of half the (optionally negated) angle.
        """
        ctrlWire, tgtWire = op.wires
        sign = -1 if adjoint else 1
        theta = op.parameters[0] * sign
        cosKey = f'cos_cry_{ctrlWire}_{tgtWire}'
        sinKey = f'sin_cry_{ctrlWire}_{tgtWire}'

        ctrlQ, tgtQ = self.qubits[ctrlWire], self.qubits[tgtWire]
        nextQ = ComplexQubit(f'q_{tgtWire}_cry')
        cosP = Placeholder(cosKey)
        sinP = Placeholder(sinKey)
        realGap = nextQ.real - (ctrlQ.real * (cosP * tgtQ.real - sinP * (1 - tgtQ.real))
                                + (1 - ctrlQ.real) * tgtQ.real)
        imagGap = nextQ.imag - (ctrlQ.real * (cosP * tgtQ.imag + sinP * tgtQ.imag)
                                + (1 - ctrlQ.real) * tgtQ.imag)
        self.hamiltonian += scale * Constraint(
            realGap**2 + imagGap**2, label=f'CRY_{ctrlWire}_{tgtWire}')
        self.qubits[tgtWire] = nextQ

        self.placeholders[cosKey] = float(np.cos(theta/2))
        self.placeholders[sinKey] = float(np.sin(theta/2))

    def _CRZ(self, op, scale, adjoint=False):
        """Add the penalty term used for a controlled RZ on (control, target).

        The new target qubit ``q_<target>_crz`` is constrained to
        ``control.real * rotated + (1 - control.real) * unchanged`` per
        component, where ``rotated`` is ``(c*re + s*im, c*im - s*re)`` of the
        old target.  ``c`` / ``s`` are placeholders
        ``cos_crz_<control>_<target>`` / ``sin_crz_<control>_<target>`` set to
        the cosine and sine of half the (optionally negated) angle.
        """
        ctrlWire, tgtWire = op.wires
        sign = -1 if adjoint else 1
        theta = op.parameters[0] * sign
        cosKey = f'cos_crz_{ctrlWire}_{tgtWire}'
        sinKey = f'sin_crz_{ctrlWire}_{tgtWire}'

        ctrlQ, tgtQ = self.qubits[ctrlWire], self.qubits[tgtWire]
        nextQ = ComplexQubit(f'q_{tgtWire}_crz')
        cosP = Placeholder(cosKey)
        sinP = Placeholder(sinKey)
        realGap = nextQ.real - (ctrlQ.real * (cosP * tgtQ.real + sinP * tgtQ.imag)
                                + (1 - ctrlQ.real) * tgtQ.real)
        imagGap = nextQ.imag - (ctrlQ.real * (cosP * tgtQ.imag - sinP * tgtQ.real)
                                + (1 - ctrlQ.real) * tgtQ.imag)
        self.hamiltonian += scale * Constraint(
            realGap**2 + imagGap**2, label=f'CRZ_{ctrlWire}_{tgtWire}')
        self.qubits[tgtWire] = nextQ

        self.placeholders[cosKey] = float(np.cos(theta/2))
        self.placeholders[sinKey] = float(np.sin(theta/2))

    def _ANGLEEMBEDDING(self, op, scale):
        """Expand an AngleEmbedding into one single-qubit rotation per feature.

        The rotation kind comes from ``op.hyperparameters['rotation']`` (its
        ``__name__``, the string itself, or ``str()`` of the value).  For each
        wire that has a matching entry in ``op.parameters[0]`` a
        ``qml.RX`` / ``qml.RY`` / ``qml.RZ`` is built and passed to the
        corresponding handler with the same ``scale``.  ``self.isAE`` is set
        afterwards.

        Raises:
            ValueError: if the rotation name matches none of X, Y or Z.
        """
        features = op.parameters[0]
        rotation_type = op.hyperparameters.get('rotation')
        if hasattr(rotation_type, '__name__'):
            rotation_name = rotation_type.__name__
        elif isinstance(rotation_type, str):
            rotation_name = rotation_type
        else:
            rotation_name = str(rotation_type)

        # (substring to look for, exact short name, handler, gate constructor)
        candidates = (
            ('RX', 'X', self._RX, qml.RX),
            ('RY', 'Y', self._RY, qml.RY),
            ('RZ', 'Z', self._RZ, qml.RZ),
        )
        for token, axis, rotation_handler, rotation_constructor in candidates:
            if token in rotation_name or rotation_name == axis:
                break
        else:
            raise ValueError(f"Unknown rotation type in AngleEmbedding: {rotation_name}")

        # zip stops at the shorter of (features, wires), so surplus wires get no rotation
        for i, wire in zip(range(len(features)), op.wires):
            rotation_handler(rotation_constructor(features[i], wires=wire), scale)
        self.isAE = True


    def _AMPLITUDEEMBEDDING(self, op, scale):
        """Add the penalty terms used for an AmplitudeEmbedding on ``op.wires``.

        The feature vector ``op.parameters[0]`` is optionally padded with
        ``pad_with`` up to ``2**len(op.wires)`` entries, optionally normalised,
        and (by default) checked to have unit norm.  Every wire then gets a new
        qubit constrained to ``real = 1, imag = 0``, and one combined
        constraint compares a per-basis-state product of the new qubits'
        components with the matching target amplitude.

        Raises:
            ValueError: if the (padded) feature length differs from
                ``2**len(op.wires)``, or if ``validate_norm`` is set and the
                norm is not within 1e-10 of 1.
        """
        features = op.parameters[0]
        # Hyperparameter defaults used when the operation does not carry them
        pad_with = op.hyperparameters.get('pad_with', None)
        normalize = op.hyperparameters.get('normalize', False)
        validate_norm = op.hyperparameters.get('validate_norm', True)
        n_qubits = len(op.wires)
        required_dim = 2**n_qubits
        # Pad only when a pad value is given and the vector is too short
        if pad_with is not None and len(features) < required_dim:
            padding_size = required_dim - len(features)
            if hasattr(features, 'shape'):  # For array-like objects
                padding_shape = list(features.shape)
                padding_shape[0] = padding_size
                padding = np.full(padding_shape, pad_with, dtype=complex if np.iscomplexobj(features) else float)
                features = np.concatenate([features, padding])
            else:  
                features = list(features) + [pad_with] * padding_size
        if len(features) != required_dim:
            raise ValueError(f"Feature vector must be of length {required_dim} or smaller with padding; got {len(features)}.")
        if normalize:
            # Scale to unit L2 norm; an all-zero vector is left untouched
            norm = np.sqrt(np.sum(np.abs(features)**2))
            if norm > 0:
                features = features / norm
        
        if validate_norm:
            norm = np.sqrt(np.sum(np.abs(features)**2))
            if not np.isclose(norm, 1.0, atol=1e-10):
                raise ValueError(f"Features must be normalized (norm = 1); got norm = {norm}.")
        # Replace every wire's qubit by a fresh one pinned to real = 1, imag = 0.
        # NOTE(review): ``q`` is assigned in this loop but never used.
        for wire in op.wires:
            q = self.qubits[wire]
            qN = ComplexQubit(f'q_{wire}_amplitude_init')
            constREAL = (qN.real - 1)**2
            constIMAG = qN.imag**2
            hamiltonian = scale * Constraint(constREAL + constIMAG, label=f'AmplitudeInit_{wire}')
            self.hamiltonian += hamiltonian
            self.qubits[wire] = qN
        constraints = []
        
        # One squared-error term per basis index; character j of the binary
        # string (most significant bit first) belongs to op.wires[j].
        for i in range(required_dim):
            binary_rep = format(i, f'0{n_qubits}b')
            target_amplitude = features[i]
            target_real = np.real(target_amplitude)
            target_imag = np.imag(target_amplitude)
            state_term_real = 1.0
            # NOTE(review): starts at 0.0 and is only ever multiplied below, so
            # this product stays zero-valued - confirm that is intended.
            state_term_imag = 0.0
            
            for j, bit in enumerate(binary_rep):
                wire = op.wires[j]
                q = self.qubits[wire]
                if bit == '0':
                    state_term_real *= (1 - q.real)
                    state_term_imag *= q.imag
                else:
                    state_term_real *= q.real
                    state_term_imag *= (1 - q.imag)
            amplitude_constraint_real = (state_term_real - target_real)**2
            amplitude_constraint_imag = (state_term_imag - target_imag)**2
            constraints.append(amplitude_constraint_real + amplitude_constraint_imag)
        # All per-state terms share a single labelled constraint
        hamiltonian = scale * Constraint(
            sum(constraints), label=f'AmplitudeEmbedding_{op.wires}')
        self.hamiltonian += hamiltonian
        #print(f"[SamDEBUG] :: AmplitudeEmbedding with {len(features)} features on {n_qubits} qubits")


    def _TOFFOLI(self, op, scale):
        """Add the penalty term for a Toffoli gate and rebind its target wire.

        ``op.wires`` is unpacked as ``(control1, control2, target)``. A fresh
        ``ComplexQubit`` is created for the target; its real part is tied to
        the AND of both control real parts and the old target real part, and
        its imaginary part is tied to the old target imaginary part.
        """
        ctrl_a, ctrl_b, tgt = op.wires
        qa, qb, old_t = self.qubits[ctrl_a], self.qubits[ctrl_b], self.qubits[tgt]
        new_t = ComplexQubit(f'q_{tgt}_toffoli')
        # NOTE(review): this is a three-way AND; a textbook Toffoli would be
        # target XOR AND(controls) -- confirm the encoding is intended.
        triple_and = And(And(qa.real, qb.real), old_t.real)
        real_penalty = (triple_and - new_t.real)**2
        imag_penalty = (old_t.imag - new_t.imag)**2
        self.hamiltonian += scale * Constraint(
            real_penalty + imag_penalty,
            label=f'Toffoli_{ctrl_a}_{ctrl_b}_{tgt}')
        self.qubits[tgt] = new_t

    def _T(self, op, scale, adjoint=False):
        """Add the penalty term for a T gate (or its adjoint) on ``op.wires[0]``.

        The new qubit is constrained to the old one rotated in the
        (real, imag) plane by +pi/4, or -pi/4 when ``adjoint`` is true. The
        cosine and sine enter the Hamiltonian as placeholders whose values
        are stored in ``self.placeholders``.
        """
        wire = op.wires[0]
        cos_key, sin_key = f'cos_t_{wire}', f'sin_t_{wire}'
        prev = self.qubits[wire]
        nxt = ComplexQubit(f'q_{wire}_t')
        cos_p = Placeholder(cos_key)
        sin_p = Placeholder(sin_key)
        rotated_real = cos_p * prev.real - sin_p * prev.imag
        rotated_imag = sin_p * prev.real + cos_p * prev.imag
        penalty = (nxt.real - rotated_real)**2 + (nxt.imag - rotated_imag)**2
        self.hamiltonian += scale * Constraint(penalty, label=f'T_{wire}')
        self.qubits[wire] = nxt
        # The sign of the angle is the only difference between T and T-dagger.
        angle = -np.pi / 4 if adjoint else np.pi / 4
        self.placeholders[cos_key] = float(np.cos(angle))
        self.placeholders[sin_key] = float(np.sin(angle))

    def __update__(self, wire, nSTATE):
        self.qubits[wire] = nSTATE
        
    def _BASISSTATE(self, op, scale):
        """Add the penalty term that prepares a computational basis state.

        ``op.parameters[0]`` is the bit sequence and ``op.wires`` the wires it
        is applied to (paired positionally with ``zip``). For a ``'1'`` bit the
        new qubit is pushed towards ``real = 0, imag = 1``; for any other bit
        towards ``real = 1, imag = 0``.

        Every prepared wire is rebound to its new ``ComplexQubit``. The
        previous version rebound only the last wire of the loop, so later
        gates on the other wires kept referring to the stale qubits.
        """
        bSTATE = op.parameters[0]
        wires = op.wires
        biSTATE = ''.join(map(str, bSTATE))
        constraints = []
        for wire, bit in zip(wires, biSTATE):
            qN = ComplexQubit(f'q_{wire}_basis')
            if bit == '1':
                constraints.append(qN.real**2 + (qN.imag - 1)**2)
            else:  # bit == '0'
                constraints.append((qN.real - 1)**2 + qN.imag**2)
            # Rebind inside the loop so that each wire tracks its own new qubit.
            self.qubits[wire] = qN
        hamiltonian = scale * Constraint(sum(constraints), label=f'BasisState_{wires}')
        self.hamiltonian += hamiltonian
        print(f"[SamDEBUG] :: Basis State preparation on wires {wires} with state {biSTATE}")
        
    def _QFT(self, op, scale, adjoint=False):
        """Add the penalty terms for a QFT (or inverse QFT) on ``op.wires``.

        The transform is expanded into Hadamard, controlled-phase and swap
        steps; each step contributes constraints through the private
        ``__hQFT`` / ``__cpQFT`` / ``__swapQFT`` helpers, which also rebind the
        affected wires. For the inverse transform the step list is reversed
        and every step is routed through ``__daggerQFT``.

        Steps are expressed in terms of the wire *labels* of ``op.wires``.
        The previous version passed positional indices ``0..n-1`` to the
        helpers, which index ``self.qubits`` by wire, so a QFT on any wire set
        other than ``0..n-1`` acted on the wrong qubits.

        Args:
            op: the QFT operation, or an ``Adjoint`` wrapper around one.
            scale: multiplier applied to the resulting constraint.
            adjoint: build the inverse transform when true.
        """
        if isinstance(op, Adjoint):
            # Keep honouring a ``bOP`` attribute if the wrapper has one, and
            # fall back to ``base`` (the attribute to_qasm reads on wrappers).
            inner = getattr(op, 'bOP', None)
            if inner is None:
                inner = op.base
            return self._QFT(inner, scale, adjoint=True)
        wires = op.wires
        n = len(wires)
        constraints = []
        steps = []
        for i in range(n):
            steps.append(('H', wires[i]))
            for j in range(i + 1, n):
                angle = 2 * np.pi / 2**(j - i + 1)
                steps.append(('CP', wires[i], wires[j], angle))
        for i in range(n // 2):
            steps.append(('SWAP', wires[i], wires[n - 1 - i]))
        if adjoint:
            steps = [('Adjoint', step) for step in reversed(steps)]
        for step in steps:
            kind = step[0]
            if kind == 'H':
                constraints.extend(self.__hQFT(step[1]))
            elif kind == 'CP':
                constraints.extend(self.__cpQFT(step[1], step[2], step[3]))
            elif kind == 'SWAP':
                constraints.extend(self.__swapQFT(step[1], step[2]))
            elif kind == 'Adjoint':
                constraints.extend(self.__daggerQFT(step[1]))

        hamiltonian = scale * Constraint(sum(constraints), label=f'QFT_{wires}')
        self.hamiltonian += hamiltonian
        print(f"[SamDEBUG] :: QFT{'+' if adjoint else ''} on wires {wires}")
        
    
    def __hQFT(self, wire):
        """Return the Hadamard-step constraints for ``wire`` and rebind it.

        The new real part is tied to ``(real + imag) / sqrt(2)`` and the new
        imaginary part to ``(imag - real) / sqrt(2)`` of the current qubit.
        """
        src = self.qubits[wire]
        dst = ComplexQubit(f'q_{wire}_h')
        real_gap = dst.real - (src.real + src.imag) / np.sqrt(2)
        imag_gap = dst.imag - (src.imag - src.real) / np.sqrt(2)
        squared = [real_gap**2, imag_gap**2]
        self.__update__(wire, dst)
        return squared
    
    def __cpQFT(self, control, target, angle):
        """Return the controlled-phase-step constraints and rebind ``target``.

        The target qubit is rotated by ``angle`` in the (real, imag) plane.
        The control qubit is looked up but does not enter the constraints.
        """
        self.qubits[control]  # lookup only; the value is not used below
        old_t = self.qubits[target]
        new_t = ComplexQubit(f'q_{target}_cp')
        cos_a = np.cos(angle)
        sin_a = np.sin(angle)
        real_gap = new_t.real - (old_t.real * cos_a - old_t.imag * sin_a)
        imag_gap = new_t.imag - (old_t.real * sin_a + old_t.imag * cos_a)
        squared = [real_gap**2, imag_gap**2]
        self.__update__(target, new_t)
        return squared
    
    def __swapQFT(self, wire1, wire2):
        """Return the swap-step constraints for two wires and rebind both.

        Each new qubit is tied, component by component, to the current qubit
        of the other wire.
        """
        first_old = self.qubits[wire1]
        second_old = self.qubits[wire2]
        first_new = ComplexQubit(f'q_{wire1}_swap')
        second_new = ComplexQubit(f'q_{wire2}_swap')
        squared = []
        for fresh, source in ((first_new, second_old), (second_new, first_old)):
            squared.append((fresh.real - source.real)**2)
            squared.append((fresh.imag - source.imag)**2)
        self.__update__(wire1, first_new)
        self.__update__(wire2, second_new)
        return squared
    
    def __daggerQFT(self, op):
        if op[0] == 'H':
            return self.__hQFT(op[1])
        elif op[0] == 'CP':
            return self.__cpQFT(op[1], op[2], -op[3])
        elif op[0] == 'SWAP':
            return self.__swapQFT(op[1], op[2])

    def _BASISEMB(self, op, scale):
        """Add the penalty term for a BasisEmbedding on ``op.wires``.

        The bit pattern in ``op.parameters[0]`` is matched against the binary
        encodings of ``self.params``; the *other* entry of ``self.params`` is
        then added to the pattern bit by bit with carry, and the resulting
        bits decide, per wire, whether the new qubit is tied to the flipped or
        to the unchanged current qubit. Each wire is rebound to its new
        ``ComplexQubit``.
        """
        hyperP = op.parameters[0]
        wires = op.wires
        n = len(wires)
        params = self.params
        constraints = []
        # Zero-padded n-bit binary string for every entry of self.params.
        enParams = [format(param, f'0{n}b') for param in params]
        hyperPstr = ''.join(map(str, hyperP))
        # Index of the entry of self.params whose encoding equals the op's bit pattern.
        paramIdx = -1
        for idx, biPara in enumerate(enParams):
            if biPara == hyperPstr:
                paramIdx = idx
                break
            elif idx == (len(enParams) - 1) and paramIdx == -1: # Please use this carefully (This is an "overclock" for the hyperparameters as Pennylane limits the BE higher than 100 qubits,
                hyperPstr = biPara                              # this is a manual translation for the basis state up to 5000 qubits)
                paramIdx = idx                                  # PLEASE be careful when you release this to public (tag the whole `elif`)
        # NOTE(review): `1 - paramIdx` selects "the other" entry only when
        # self.params has exactly two entries -- confirm against callers.
        rest = params[1 - paramIdx]
        oRestPara = list(map(int, format(rest, f'0{n}b')))
        # Per-position sums (0, 1 or 2); carries are resolved in the loop below.
        restBi = [hyperP[i] + oRestPara[i] for i in range(n)]
        phase = 0  # running carry
        fRest = []
        # Walks from the last position to the first, so fRest is filled
        # starting with the last position; it is not reversed back afterwards
        # and a final carry is dropped.
        for bit in reversed(restBi):
            total = bit + phase
            fRest.append(total % 2)
            phase = total // 2
        basisState = list(fRest)
        for wire, bit in zip(wires, basisState):
            qN = ComplexQubit(f'q_{wire}_basis')
            if bit == 1:
                # Tie the new qubit to (1 - real, -imag) of the current one.
                constREAL = (qN.real - (1 - self.qubits[wire].real))**2
                constIMAG = (qN.imag - (-self.qubits[wire].imag))**2
            else:
                # Tie the new qubit to the current one unchanged.
                constREAL = (qN.real - self.qubits[wire].real)**2
                constIMAG = (qN.imag - self.qubits[wire].imag)**2
            constraints.append(constREAL + constIMAG)
            self.__update__(wire, qN)
        hamiltonian = scale * Constraint(sum(constraints), label=f'BasisEmbedding_{wires}')
        self.hamiltonian += hamiltonian
        print(f"[SamDEBUG] :: Basis Embedding on wires {wires}")

    def _FLIPSIGN(self, op, scale):
        """Add the penalty term for a FlipSign operation on ``op.wires[0]``.

        The new real part is tied to the old real part; the new imaginary
        part is tied to ``(1 - 2 * real) * imag`` of the old qubit. The term
        is weighted by ``scale * self.penaltyScale['FlipSign']``.
        """
        wire = op.wires[0]
        before = self.qubits[wire]
        after = ComplexQubit(f'q_{wire}_flipsign')
        keep_real = (after.real - before.real)**2
        signed_imag = (after.imag - (1 - 2 * before.real) * before.imag)**2
        weight = scale * self.penaltyScale['FlipSign']
        self.hamiltonian += weight * Constraint(
            keep_real + signed_imag,
            label=f'FlipSign_{wire}')
        self.qubits[wire] = after

    def _CTRL(self, op, controlled_op, scale):
        """Add the penalty term for a generic controlled operation.

        The control condition is the AND of the real parts of all control
        qubits. For each target wire (``op.wires`` after the control wires)
        a new ``ComplexQubit`` is created; when the condition is off its real
        part is tied to the old real part, and when it is on it is tied to the
        old qubit multiplied by entries of ``controlled_op.matrix()``.

        The AND is folded pairwise because PyQUBO's ``And`` takes exactly two
        operands; the previous ``And(*controls)`` call only worked for
        exactly two control wires and failed for the single-control case.

        Args:
            op: the controlled operation (provides ``control_wires``/``wires``).
            controlled_op: the base operation; must expose ``matrix()``.
            scale: multiplier applied together with ``penaltyScale['Ctrl']``.

        Raises:
            ValueError: if there is no control wire, or ``controlled_op`` has
                no ``matrix`` attribute.
        """
        cW = op.control_wires
        tW = op.wires[len(cW):]
        controlTERMS = [self.qubits[wire].real for wire in cW]
        if not controlTERMS:
            raise ValueError("Controlled operation must have at least one control wire")
        contPRODUCT = controlTERMS[0]
        for term in controlTERMS[1:]:
            contPRODUCT = And(contPRODUCT, term)
        if not hasattr(controlled_op, 'matrix'):
            raise ValueError(f"Controlled operation {controlled_op.name} does not have a matrix representation")
        # The matrix is taken from the operation instead of hand-coding each gate.
        U = controlled_op.matrix()
        constraints = []
        for i, target in enumerate(tW):
            q = self.qubits[target]
            qN = ComplexQubit(f'q_{target}_ctrl')
            for j in range(U.shape[1]):
                # Control off: keep the real part.
                constraints.append(((1 - contPRODUCT) * (qN.real - q.real))**2)
                uR, uI = U[i, j].real, U[i, j].imag
                # Control on: complex multiplication of the old qubit by U[i, j].
                constraints.append((contPRODUCT * (qN.real - (uR * q.real - uI * q.imag)))**2)
                constraints.append((contPRODUCT * (qN.imag - (uR * q.imag + uI * q.real)))**2)
            self.qubits[target] = qN
        cLABEL = '_'.join(map(str, cW))
        tLABEL = '_'.join(map(str, tW))
        hamiltonian = scale * self.penaltyScale['Ctrl'] * Constraint(sum(constraints), label=f'CTRL_{controlled_op.name}_{cLABEL}_{tLABEL}')
        self.hamiltonian += hamiltonian

    def _GROVER(self, op, scale):
        """Expand a Grover operator into the gate handlers of this class.

        Sequence: Hadamard on every wire but the last, PauliZ on the last
        wire, a multi-controlled X over all wires, PauliZ on the last wire,
        the Hadamard layer again, and a global phase of pi. Finally
        ``self.isGrover`` is set, which switches the decoder to the
        Grover-specific read-out.
        """
        register = op.wires
        ancillas = op.hyperparameters['work_wires']

        def hadamard_layer():
            for w in register[:-1]:
                self._H(qml.Hadamard(wires=w), scale)

        def z_on_last():
            self._PAULIZ(qml.PauliZ(wires=register[-1]), scale)

        hadamard_layer()
        z_on_last()
        self._MULTICONTROLLEDX(register, ancillas, scale)
        z_on_last()
        hadamard_layer()
        self._GLOBALPHASE(np.pi, register, scale)
        self.isGrover = True  # Activate New Decoding (Force Flipping)
    
    def _MULTICONTROLLEDX(self, wires, work_wires, scale):
        """Add the penalty term for a multi-controlled X gate.

        All wires but the last are controls, the last one is the target.
        Each control contributes ``1 - real``; these terms are AND-ed
        pairwise, and the new target real part is tied to the XOR of that
        product with the old target real part. The imaginary part is carried
        over. ``work_wires`` is accepted but not used here.

        Raises:
            ValueError: if there is no control wire.
        """
        ctrl_wires = wires[:-1]
        tgt_wire = wires[-1]
        active = [(1 - self.qubits[w].real) for w in ctrl_wires]
        if not active:
            raise ValueError("MultiControlledX gate must have at least one control qubit")
        all_on = active[0]
        for extra in active[1:]:
            all_on = And(all_on, extra)
        old_t = self.qubits[tgt_wire]
        new_t = ComplexQubit(f'q_{tgt_wire}_mcx')
        flipped = Xor(all_on, old_t.real)
        real_penalty = (new_t.real - flipped)**2
        imag_penalty = (new_t.imag - old_t.imag)**2
        weight = scale * self.penaltyScale['GroverOperator']
        self.hamiltonian += weight * Constraint(
            real_penalty + imag_penalty,
            label=f'MultiControlledX_{",".join(map(str, wires))}')
        self.qubits[tgt_wire] = new_t
    
    def _GLOBALPHASE(self, phase, wires, scale):
        """Add one phase-rotation penalty term per wire and rebind each wire.

        Every qubit in ``wires`` is rotated by ``phase`` in the (real, imag)
        plane; each term is weighted by
        ``scale * self.penaltyScale['GroverOperator']``.
        """
        cos_p = np.cos(phase)
        sin_p = np.sin(phase)
        for wire in wires:
            old = self.qubits[wire]
            new = ComplexQubit(f'q_{wire}_globalphase')
            real_penalty = (new.real - (cos_p * old.real - sin_p * old.imag))**2
            imag_penalty = (new.imag - (sin_p * old.real + cos_p * old.imag))**2
            self.hamiltonian += scale * self.penaltyScale['GroverOperator'] * Constraint(
                real_penalty + imag_penalty,
                label=f'GlobalPhase_{wire}')
            self.qubits[wire] = new


    def _STRONGLYENTANGLING(self, op, scale):
        """Expand StronglyEntanglingLayers into the gate handlers of this class.

        ``op.parameters[0]`` is indexed as ``w[layer, wire_position, k]`` with
        ``k`` in 0..2 for the three ``Rot`` angles. Each layer applies a
        ``Rot`` to every wire and then, when there is more than one wire, an
        entangling gate between position ``i`` and position
        ``(i + ranges[layer]) % n``. The entangler comes from the
        ``imprimitive`` hyperparameter (default ``qml.CNOT``); CRX/CRY/CRZ are
        built with a fixed angle of pi/4.

        Raises:
            ValueError: if the imprimitive is not handled by name and its
                ``name`` is not a key of ``self.gHandlers``.
        """
        w = op.parameters[0]
        wires = op.wires
        imp = op.hyperparameters.get('imprimitive', qml.CNOT)
        ranges = op.hyperparameters.get('ranges', None)
        nL = len(w)
        nW = len(wires)
        if ranges is None:
            if nW > 1:
                ranges = tuple((l % (nW - 1)) + 1 for l in range(nL))
            else:
                ranges = (0,) * nL

        # Handlers are wrapped in lambdas so that each one is only looked up
        # on `self` when its branch is actually taken.
        by_name = {
            'CNOT': lambda pair: self._CNOT(qml.CNOT(wires=pair), scale),
            'CZ': lambda pair: self._CZ(qml.CZ(wires=pair), scale),
            'CRX': lambda pair: self._CRX(qml.CRX(np.pi/4, wires=pair), scale),
            'CRY': lambda pair: self._CRY(qml.CRY(np.pi/4, wires=pair), scale),
            'CRZ': lambda pair: self._CRZ(qml.CRZ(np.pi/4, wires=pair), scale),
            'SWAP': lambda pair: self._SWAP(qml.SWAP(wires=pair), scale),
        }

        def entangle(pair):
            known = by_name.get(imp.__name__)
            if known is not None:
                known(pair)
                return
            gOP = imp(wires=pair)
            gType = gOP.name
            if gType not in self.gHandlers:
                raise ValueError(f"Unsupported imprimitive gate: {gType}")
            self.gHandlers[gType](gOP, scale)

        for l in range(nL):
            for i, wire in enumerate(wires):
                phi, theta, omega = w[l, i, 0], w[l, i, 1], w[l, i, 2]
                self._ROT(qml.Rot(phi, theta, omega, wires=wire), scale)
            if nW > 1:
                for i in range(nW):
                    j = (i + ranges[l]) % nW
                    entangle([wires[i], wires[j]])

    # ---------------------------- DEPRECATED ------------------------------
    #def _QUBITUNITARY(self, op, scale):
    #    wires = op.wires
    #    U = op.parameters[0]
    #    constraints = []
    #    for i, wire in enumerate(wires):
    #        q = self.qubits[wire]
    #        qN = ComplexQubit(f'q_{wire}_unitary')
    #        for j in range(U.shape[1]):
    #            uR, uI = U[i,j].real, U[i,j].imag
    #            constraints.append((qN.real - (uR * q.real - uI * q.imag))**2)
    #            constraints.append((qN.imag - (uR * q.imag + uI * q.real))**2)
    #        self.qubits[wire] = qN
    #    hamiltonian = scale * Constraint(sum(constraints), label=f'QubitUnitary_{wires}')
    #    self.hamiltonian += hamiltonian
    # ---------------------------- DEPRECATED ------------------------------

    def _QUBITUNITARY(self, op, scale):
        """Add the penalty term for a QubitUnitary on ``op.wires``.

        ``op.parameters[0]`` is the matrix ``U``; a 3-D ``U`` is treated as a
        batch ``(BATCH, dim, dim)``, anything else as a single matrix. Each
        wire gets a new ``ComplexQubit`` constrained against entries of row
        ``i`` of ``U`` (``i`` is the position of the wire in ``op.wires``),
        plus several heuristic terms. Sets ``self.isQU`` so that the decoder
        switches to the imaginary-part read-out.
        """
        wires = op.wires
        U = op.parameters[0]
        if len(qml.math.shape(U)) == 3:  # batched input: (BATCH, dim, dim)
            BATCH, dim, _ = qml.math.shape(U)
        else:  # non-batch mode
            BATCH, dim, _ = 1, *qml.math.shape(U)
        constraints = []
        for i, wire in enumerate(wires):
            q = self.qubits[wire]
            qN = ComplexQubit(f'q_{wire}_unitary')
            for j in range(dim): # one pair of constraints per column of U
                if BATCH > 1:
                    # Matrix entries become placeholders; the stored values are
                    # slices across the batch axis.
                    # NOTE(review): the placeholder names carry no wire label, and
                    # the stored values are arrays rather than floats -- confirm
                    # this is what to_bqm expects.
                    uR = Placeholder(f'U_{i}_{j}_real')
                    uI = Placeholder(f'U_{i}_{j}_imag')
                    self.placeholders[f'U_{i}_{j}_real'] = U[:, i, j].real
                    self.placeholders[f'U_{i}_{j}_imag'] = U[:, i, j].imag
                else:
                    # NOTE(review): a 3-D U whose leading dimension is 1 also
                    # lands here and is indexed with two indices.
                    uR, uI = U[i, j].real, U[i, j].imag
                constraints.append(((qN.real - uR * q.real) + (uI * q.imag))**2) # real part of U[i, j] * q
                constraints.append(((qN.imag - uR * q.imag) - (uI * q.real))**2) # imaginary part of U[i, j] * q
            constraints.append((qN.real**2 + qN.imag**2 - 0.5)**2) # non-trivial states
            # Also penalise any difference between the new and the old qubit.
            constraints.append((qN.real - q.real)**2 + (qN.imag - q.imag)**2)
            self.qubits[wire] = qN
        if len(wires) > 1: # enforce entanglement between adjacent qubits
            for i in range(len(wires) - 1):
                q1, q2 = self.qubits[wires[i]], self.qubits[wires[i+1]]
                constraints.append((q1.real * q2.imag - q1.imag * q2.real)**2)
        constraints.append((qml.math.prod([1 - self.qubits[wire].real for wire in wires]) + np.exp(len(qml.math.shape(U))))**2) # offset grows with the number of dimensions of U (due to Dynex nature)
        hamiltonian = scale * Constraint(sum(constraints), label=f'QubitUnitary_{wires}')
        self.hamiltonian += hamiltonian
        self.isQU = True  # enforce the qubit flipping in the decoder

    def _CONTROLLEDQUBITUNITARY(self, op, scale):
        """Add the penalty term for a ControlledQubitUnitary.

        The control condition is the product of the real parts of all control
        qubits. For every target wire a new ``ComplexQubit`` is created: with
        the condition off it is tied to the old qubit, with the condition on
        it is tied to the old qubit multiplied by entries of row ``i`` of
        ``op.parameters[0]``. Matrix entries enter as placeholders whose
        values are stored in ``self.placeholders``.

        Targets are ``op.wires`` with the leading control wires removed, the
        same convention ``_CTRL`` uses. The previous version iterated over all
        of ``op.wires``, so the control qubits themselves were rewritten as if
        they were targets.
        """
        cW = op.control_wires
        tW = op.wires[len(cW):]
        U = op.parameters[0]
        contPRODUCT = self.qubits[cW[0]].real
        for wire in cW[1:]:
            contPRODUCT = contPRODUCT * self.qubits[wire].real
        constraints = []
        for i, wire in enumerate(tW):
            q = self.qubits[wire]
            qN = ComplexQubit(f'q_{wire}_controlled_unitary')
            for j in range(U.shape[1]):
                # NOTE(review): these placeholder names carry no wire label, so
                # two controlled unitaries in one circuit share (and overwrite)
                # the same placeholders -- confirm whether that is acceptable.
                realKEY = f'U_{i}_{j}_real'
                imagKEY = f'U_{i}_{j}_imag'
                uR = Placeholder(realKEY)
                uI = Placeholder(imagKEY)
                self.placeholders[realKEY] = float(U[i, j].real)
                self.placeholders[imagKEY] = float(U[i, j].imag)
                # Control off: the target is left unchanged.
                constraints.append(((1 - contPRODUCT) * (qN.real - q.real))**2)
                constraints.append(((1 - contPRODUCT) * (qN.imag - q.imag))**2)
                # Control on: complex multiplication of the old qubit by U[i, j].
                constraints.append((contPRODUCT * (qN.real - (uR * q.real - uI * q.imag)))**2)
                constraints.append((contPRODUCT * (qN.imag - (uR * q.imag + uI * q.real)))**2)
            self.qubits[wire] = qN
        hamiltonian = scale * Constraint(sum(constraints), label=f'ControlledQubitUnitary_{cW}_{tW}')
        self.hamiltonian += hamiltonian

    def _QPE(self, op, scale):
        """Expand quantum phase estimation into penalty terms.

        Steps: Hadamard on every estimation wire; for estimation wire ``i`` a
        controlled application of the unitary repeated
        ``2 ** (n_est - i - 1)`` times; an inverse QFT on the estimation
        wires; and a final term tying a ``phase_bit_<i>`` qubit to each
        estimation qubit. Sets ``self.isQPE``, which the decoder reads.
        """
        targets = op.target_wires
        est_wires = op.estimation_wires
        n_est = len(est_wires)
        unitary = op.hyperparameters["unitary"]

        def controlled_step(ctrl_wire):
            # One controlled-U application: adds its own constraint and
            # rebinds every target wire.
            gate_on = self.qubits[ctrl_wire].real
            U = qml.matrix(unitary)
            terms = []
            for row, tgt in enumerate(targets):
                prev = self.qubits[tgt]
                nxt = ComplexQubit(f'q_{tgt}_ctrl_qpe')
                for col in range(U.shape[1]):
                    entry = U[row, col]
                    re, im = entry.real, entry.imag
                    terms.append(((1 - gate_on) * (nxt.real - prev.real))**2)
                    terms.append((gate_on * (nxt.real - (re * prev.real - im * prev.imag)))**2)
                    terms.append((gate_on * (nxt.imag - (re * prev.imag + im * prev.real)))**2)
                self.qubits[tgt] = nxt
            self.hamiltonian += scale * Constraint(
                sum(terms), label=f'CTRL_QPE_{ctrl_wire}_{targets}')

        for wire in est_wires:
            self._H(qml.Hadamard(wires=wire), scale)
        for position, est_wire in enumerate(est_wires):
            repeats = 2 ** (n_est - position - 1)
            for _ in range(repeats):
                controlled_step(est_wire)
        self._QFT(qml.QFT(wires=est_wires), scale, adjoint=True)

        readout = []
        for position, wire in enumerate(est_wires):
            current = self.qubits[wire]
            phase_bit = ComplexQubit(f'phase_bit_{position}')
            readout.append((phase_bit.real - current.real)**2)
            readout.append((phase_bit.imag - current.imag)**2)
        self.hamiltonian += scale * Constraint(
            sum(readout), label=f'QPE_{targets}_{est_wires}')
        self.isQPE = True

    def _CONTROLLEDPHASESHIFT(self, op, scale):
        """Add the penalty term for a controlled phase shift.

        ``op.wires`` is ``(control, target)`` and ``op.parameters[0]`` the
        angle. The new target is a blend, selected by the control's real
        part, of the unchanged target and the rotated target. Cosine and sine
        of the angle are placeholders stored in ``self.placeholders``. The
        term is weighted by ``scale * self.penaltyScale['ControlledPhaseShift']``.
        """
        control, target = op.wires
        phi = op.parameters[0]
        cos_key = f'cos_cps_{control}_{target}'
        sin_key = f'sin_cps_{control}_{target}'
        ctrl_q, old_t = self.qubits[control], self.qubits[target]
        new_t = ComplexQubit(f'q_{target}_cps')
        cos_p = Placeholder(cos_key)
        sin_p = Placeholder(sin_key)
        # NOTE(review): the rotated branch uses (cos*re + sin*im, cos*im - sin*re),
        # the opposite sign of the sine terms compared with _T and _S -- confirm.
        blended_real = old_t.real * (1 - ctrl_q.real) + (cos_p * old_t.real + sin_p * old_t.imag) * ctrl_q.real
        blended_imag = old_t.imag * (1 - ctrl_q.real) + (cos_p * old_t.imag - sin_p * old_t.real) * ctrl_q.real
        real_penalty = (new_t.real - blended_real)**2
        imag_penalty = (new_t.imag - blended_imag)**2
        weight = scale * self.penaltyScale['ControlledPhaseShift']
        self.hamiltonian += weight * Constraint(
            real_penalty + imag_penalty,
            label=f'ControlledPhaseShift_{control}_{target}')
        self.qubits[target] = new_t
        self.placeholders[cos_key] = float(np.cos(phi))
        self.placeholders[sin_key] = float(np.sin(phi))

    def _S(self, op, scale, adjoint=False):
        """Add the penalty term for an S gate (or its adjoint) on ``op.wires[0]``.

        The new qubit is tied to the old one rotated by +pi/2 (or -pi/2 when
        ``adjoint`` is true) in the (real, imag) plane, plus a term that ties
        the squared norm of the new qubit to that of the old one. Cosine and
        sine are placeholders stored in ``self.placeholders``.
        """
        wire = op.wires[0]
        cos_key, sin_key = f'cos_s_{wire}', f'sin_s_{wire}'
        old = self.qubits[wire]
        new = ComplexQubit(f'q_{wire}_s')
        angle = -np.pi/2 if adjoint else np.pi/2
        cos_p = Placeholder(cos_key)
        sin_p = Placeholder(sin_key)
        real_penalty = (new.real - (cos_p * old.real - sin_p * old.imag))**2
        imag_penalty = (new.imag - (sin_p * old.real + cos_p * old.imag))**2
        norm_penalty = (new.real**2 + new.imag**2 - (old.real**2 + old.imag**2))**2
        self.hamiltonian += scale * Constraint(
            real_penalty + imag_penalty + norm_penalty,
            label=f'S_{wire}'
        )
        self.qubits[wire] = new
        self.placeholders[cos_key] = float(np.cos(angle))
        self.placeholders[sin_key] = float(np.sin(angle))

    def _SX(self, op, scale, adjoint=False):
        """Add the penalty term for an SX gate (or its adjoint) on ``op.wires[0]``.

        Only the ``a_real`` and ``b_real`` placeholders enter the constraint;
        the two ``*_imag`` placeholders are created and given values but are
        not used in any expression. Besides the linear map, the term includes
        a norm-preservation penalty and a cross-product penalty between the
        new and the old qubit.
        """
        wire = op.wires[0]
        old = self.qubits[wire]
        new = ComplexQubit(f'q_{wire}_sx')
        a_re = Placeholder(f'sx_a_real_{wire}')
        Placeholder(f'sx_a_imag_{wire}')  # reserved, not used in the constraint
        b_re = Placeholder(f'sx_b_real_{wire}')
        Placeholder(f'sx_b_imag_{wire}')  # reserved, not used in the constraint
        if adjoint:
            # SX-dagger: [[a, -b], [-b, a]] applied to (real, imag)
            real_penalty = (new.real - (a_re * old.real - b_re * old.imag))**2
            imag_penalty = (new.imag - (-b_re * old.real + a_re * old.imag))**2
        else:
            # SX: [[a, b], [b, a]] applied to (real, imag)
            real_penalty = (new.real - (a_re * old.real + b_re * old.imag))**2
            imag_penalty = (new.imag - (b_re * old.real + a_re * old.imag))**2
        norm_penalty = (new.real**2 + new.imag**2 - (old.real**2 + old.imag**2))**2
        cross_penalty = (new.real * old.imag - new.imag * old.real)**2
        self.hamiltonian += scale * Constraint(
            real_penalty + imag_penalty + norm_penalty + cross_penalty,
            label=f'SX_{wire}')
        self.qubits[wire] = new
        # a = (1 + i)/2 and b = (1 - i)/2 for SX; the imaginary parts swap sign
        # for the adjoint while the real parts stay at 0.5.
        imag_sign = -1.0 if adjoint else 1.0
        for key, value in (
            (f'sx_a_real_{wire}', 0.5),
            (f'sx_a_imag_{wire}', 0.5 * imag_sign),
            (f'sx_b_real_{wire}', 0.5),
            (f'sx_b_imag_{wire}', -0.5 * imag_sign),
        ):
            self.placeholders[key] = value

    def to_qasm(self):
        """Return the circuit as an OpenQASM 2.0 style program string.

        The circuit is re-traced with ``self.circuit(self.params)`` and each
        recorded operation is translated. Operations without a translation
        are emitted as ``//`` comments. QubitUnitary and
        QuantumPhaseEstimation are emitted in a pseudo-QASM form
        (``unitary(...)``, ``cu <power>``, ``qft_dagger``) that is not part of
        ``qelib1.inc``.

        Fixes relative to the previous version:
          * the matrix of a QubitUnitary is obtained by calling
            ``op.matrix()`` (it used to take ``.tolist()`` of the method);
          * controlled gates list the control qubits before the targets and
            get one ``c`` prefix per control wire;
          * the adjoint QFT is the exact inverse of the forward one (reversed
            gate order, negated angles, swaps included);
          * an unreachable ``FlipSign`` branch was removed (``FlipSign`` is
            already handled through the gate table).
        """
        n = self.wires
        out = [
            "OPENQASM 2.0;",
            "include \"qelib1.inc\";",
            "",
            f"qreg q[{n}];",
            f"creg c[{n}];",
            "",
        ]
        g2q = {
            'PauliX': lambda op: f"x q[{op.wires[0]}];",
            'PauliY': lambda op: f"y q[{op.wires[0]}];",
            'PauliZ': lambda op: f"z q[{op.wires[0]}];",
            'Hadamard': lambda op: f"h q[{op.wires[0]}];",
            'RX': lambda op: f"rx({op.parameters[0]}) q[{op.wires[0]}];",
            'RY': lambda op: f"ry({op.parameters[0]}) q[{op.wires[0]}];",
            'RZ': lambda op: f"rz({op.parameters[0]}) q[{op.wires[0]}];",
            'CNOT': lambda op: f"cx q[{op.wires[0]}],q[{op.wires[1]}];",
            'CZ': lambda op: f"cz q[{op.wires[0]}],q[{op.wires[1]}];",
            'SWAP': lambda op: f"swap q[{op.wires[0]}],q[{op.wires[1]}];",
            'T': lambda op: f"t q[{op.wires[0]}];",
            'S': lambda op: f"s q[{op.wires[0]}];",
            'CRX': lambda op: f"crx({op.parameters[0]}) q[{op.wires[0]}],q[{op.wires[1]}];",
            'CRY': lambda op: f"cry({op.parameters[0]}) q[{op.wires[0]}],q[{op.wires[1]}];",
            'CRZ': lambda op: f"crz({op.parameters[0]}) q[{op.wires[0]}],q[{op.wires[1]}];",
            'Toffoli': lambda op: f"ccx q[{op.wires[0]}],q[{op.wires[1]}],q[{op.wires[2]}];",
            'FlipSign': lambda op: f"z q[{op.wires[0]}];",
        }

        def _CONTROLLED_(op):
            # Translate the base gate, then prefix one 'c' per control wire and
            # put the control qubits in front of the base gate's operands.
            bOP = op.base
            cW = op.control_wires
            if bOP.name not in g2q:
                return f"// Unsupported controlled operation at the moment (TBD): {op.name}"
            bQ = g2q[bOP.name](bOP).rstrip(';')
            # Split "<gate>[(<param>)] q[..],q[..]" at the first operand.
            gate, _, operands = bQ.partition(' q[')
            cont = ','.join([f"q[{w}]" for w in cW])
            return f"{'c' * len(cW)}{gate} {cont},q[{operands};"

        def _QFT_(op, adj=False):
            # Build the forward gate list once; the adjoint is that list
            # reversed with every controlled-phase angle negated.
            w = list(op.wires)
            n = len(w)
            gates = []
            for i in range(n):
                gates.append(('h', w[i]))
                for j in range(i + 1, n):
                    gates.append(('cu1', np.pi / 2**(j - i), w[j], w[i]))
            for i in range(n // 2):
                gates.append(('swap', w[i], w[n - 1 - i]))
            if adj:
                gates = [
                    ('cu1', -g[1], g[2], g[3]) if g[0] == 'cu1' else g
                    for g in reversed(gates)
                ]
            qftQASM = []
            for g in gates:
                if g[0] == 'h':
                    qftQASM.append(f"h q[{g[1]}];")
                elif g[0] == 'cu1':
                    qftQASM.append(f"cu1({g[1]}) q[{g[2]}],q[{g[3]}];")
                else:
                    qftQASM.append(f"swap q[{g[1]}],q[{g[2]}];")
            return '\n'.join(qftQASM)

        with qml.tape.QuantumTape() as tape:
            self.circuit(self.params)
        for op in tape.operations:
            if op.name in g2q:
                out.append(g2q[op.name](op))
            elif isinstance(op, qml.ops.QubitUnitary):
                m = op.matrix()
                w = op.wires
                out.append(f"// Custom unitary on wires {w}")
                out.append(f"unitary({m.tolist()}) q[{','.join(map(str, w))}];")
            elif op.name == 'QFT':
                out.append(_QFT_(op))
            elif op.name == 'Adjoint(QFT)':
                out.append(_QFT_(op.base, adj=True))
            elif op.name.startswith('C('):
                out.append(_CONTROLLED_(op))
            elif op.name == 'BasisState':
                state = ''.join(map(str, op.parameters[0]))
                out.append(f"// Prepare basis state |{state}>")
                for i, bit in enumerate(state):
                    if bit == '1':
                        out.append(f"x q[{op.wires[i]}];")
            elif op.name == 'BasisEmbedding':
                out.append("// BasisEmbedding not directly supported at the moment (TBD)")
                out.append("// Implemented as a series of X gates")
                state = op.hyperparameters['basis_state']
                for i, bit in enumerate(state):
                    if bit == 1:
                        out.append(f"x q[{op.wires[i]}];")
            elif op.name == 'GroverOperator':
                out.append("// GroverOperator implementation")
                wW = op.hyperparameters['work_wires']
                out.append(f"// Work wires: {wW}")
                for wire in op.wires[:-1]:
                    out.append(f"h q[{wire}];")
                out.append(f"z q[{op.wires[-1]}];")
                controls = ','.join([f"q[{w}]" for w in op.wires[:-1]])
                target = op.wires[-1]
                out.append(f"mcx {controls},q[{target}];")
                out.append(f"z q[{op.wires[-1]}];")
                for wire in op.wires[:-1]:
                    out.append(f"h q[{wire}];")
                out.append("// Global phase of pi (NONE) at the moment (TBD)")
            elif op.name == 'QuantumPhaseEstimation':
                target_wire = op.wires[0]
                estimation_wires = op.wires[1:]
                out.append("// Quantum Phase Estimation")
                for wire in estimation_wires:
                    out.append(f"h q[{wire}];")
                for i, wire in enumerate(estimation_wires):
                    power = 2 ** (len(estimation_wires) - i - 1)
                    out.append(f"// Controlled-U^{power} operation")
                    out.append(f"cu {power} q[{wire}],q[{target_wire}];")
                out.append("// Inverse QFT on estimation wires")
                out.append(f"qft_dagger q[{','.join(map(str, estimation_wires))}];")
            else:
                out.append(f"// Unsupported operation at the moment (TBD): {op.name}")
        out.append("")
        out.append("measure q -> c;")
        return '\n'.join(out)
    
    def to_bqm(self):
        """Compile the Hamiltonian to a BQM, cache it on ``self.bqm`` and return it.

        Placeholder values are taken from ``self.placeholders`` with every key
        converted to ``str``.
        """
        compiled = self.hamiltonian.compile()
        values = {}
        for key, value in self.placeholders.items():
            values[str(key)] = value
        self.bqm = compiled.to_bqm(feed_dict=values)
        return self.bqm
        
    def DynexCompute(self, bqm, num_reads=512, annealing_time=256, mainnet=False, description='Dynex SDK Job', printSolution=False, 
                    debugging=False, is_cluster=True, bnb=False, shots=1):
        """Sample a BQM on the Dynex platform and return the sampleset.

        Args:
            bqm: the binary quadratic model to sample. When ``None``, the model
                cached on ``self.bqm`` by ``to_bqm`` is used. (The previous
                version ignored this argument and always used ``self.bqm``.)
            num_reads, annealing_time, debugging, is_cluster, shots: forwarded
                to ``DynexSampler.sample``.
            mainnet, description, bnb: forwarded to ``DynexSampler``.
            printSolution: print the first sample of the result when true.

        Returns:
            The sampleset returned by the sampler.
        """
        # Imported here because the module-level import was removed to avoid
        # a circular import; without it the name `dynex` is undefined.
        import dynex
        if bqm is None:
            bqm = self.bqm
        model = dynex.BQM(bqm)
        sampler = dynex.DynexSampler(model, mainnet=mainnet, description=description, bnb=bnb)
        sampleset = sampler.sample(num_reads=num_reads, annealing_time=annealing_time, debugging=debugging, is_cluster=is_cluster, shots=shots)
        if printSolution:
            solution = sampleset.first.sample
            print("-----------/ Dynex Solution /-----------")
            print("DYNEX Simulation Output:", solution)
            print("-----------/ ************** /-----------")
        return sampleset

    def DecodeSolution(self, sampleset, method='measure'):
        """Decode a sampleset into measurement results.

        Args:
            sampleset: the sampleset returned by the sampler.
            method: ``'measure'`` returns the first decoded state as an array
                (reversed, unless ``self.isQPE`` is set); ``'all'`` returns a
                list with every decoded state reversed; ``'probs'`` returns
                the per-qubit probabilities from ``__getProbs__``.

        Raises:
            ValueError: if ``method`` is not one of the three names above.
                This is checked first, before the model is compiled, the
                sampleset decoded or ``self.isGrover`` modified.
        """
        if method not in ('measure', 'probs', 'all'):
            raise ValueError("Method must be either 'measure', 'probs', 'all'")
        feed_dict = {str(k): v for k, v in self.placeholders.items()}
        model = self.hamiltonian.compile()
        decoded = model.decode_sampleset(sampleset, feed_dict=feed_dict)
        # Re-trace the circuit to find out whether it contains a Grover operator.
        with qml.tape.QuantumTape() as tape:
            self.circuit(self.params)
        self.isGrover = any(op.name == 'GroverOperator' for op in tape.operations)
        if method == 'probs':
            return self.__getProbs__(sampleset, decoded)
        samples = self.__getSamples__(decoded, sampleset)
        if method == 'all':
            return [np.array(sample[::-1]) for sample in samples]
        first = np.array(samples[0])
        if self.isQPE:
            return first
        return first[::-1]


    def __getSamples__(self, decoded, sampleset):
        samples = []
        for solution, occurrence in zip(decoded, sampleset.record.num_occurrences):
            sample = self.__SOL2STATE(solution.sample)
            samples.extend([sample] * occurrence)
        return samples
    
    def __getProbs__(self, sampleset, decoded):
        state_counts = Counter()
        total_samples = sum(sampleset.record.num_occurrences)
        for solution, occurrence in zip(decoded, sampleset.record.num_occurrences):
            state = self.__SOL2STATE(solution.sample)
            state_counts[tuple(state)] += occurrence
        qubit_probs = np.zeros(self.wires)
        for state, count in state_counts.items():
            for i, bit in enumerate(state):
                if bit == 1:
                    qubit_probs[i] += count / total_samples
        return qubit_probs[::-1]
    
    def __SOL2STATE(self, sample):
        state = [0] * self.wires
        for wire in range(self.wires):
            rKEY = f'q_{wire}_real'
            iKEY = f'q_{wire}_imag'
            qpeKEY = f'q_{wire}_ctrl_qpe_imag'
            if self.isQPE and qpeKEY in sample:
                state[wire] = 1 if sample[qpeKEY] > sample[rKEY] else 0
            elif rKEY in sample and iKEY in sample:
                if self.isGrover or self.isCQU or self.isQU:
                    state[wire] = 1 if sample[iKEY] > 0.5 else 0
                else:
                    state[wire] = 1 if sample[rKEY] > 0.5 else 0
            else:
                print(f"Warning: No final state found for wire {wire}")
        return state
