#include <atomic>
#include <variant>
#include "../cutint/cutintegral.hpp"
#include "../xfem/symboliccutbfi.hpp"
#include "../xfem/symboliccutlfi.hpp"

// Construct a cut integral from a coefficient function and a cut differential
// symbol. The level set integration domain is taken over from the symbol.
CutIntegral::CutIntegral (shared_ptr<CoefficientFunction> _cf,
                          shared_ptr<CutDifferentialSymbol> _dx)
  : Integral(_cf, *_dx),
    lsetintdom(_dx->lsetintdom)
{
}

// Construct a cut integral from a coefficient function, a plain differential
// symbol and an explicitly supplied level set integration domain.
CutIntegral::CutIntegral (shared_ptr<CoefficientFunction> _cf,
                          DifferentialSymbol _dx,
                          shared_ptr<LevelsetIntegrationDomain> _lsetintdom)
  : Integral(_cf, _dx),
    lsetintdom(_lsetintdom)
{
}

/// Create the bilinear form integrator for this cut integral.
///
/// Terms without Other() proxies and without skeleton=True are handled by
/// SymbolicCutBilinearFormIntegrator. All other terms (DG / facet terms) are
/// handled by SymbolicCutFacetBilinearFormIntegrator. The resulting integrator
/// is then configured with the definedon mask, deformation, bonus integration
/// order and element restriction carried by the differential symbol.
///
/// @return the configured integrator
/// @throws Exception if Other() proxies are used without skeleton=True or
///         element_boundary=True
shared_ptr<BilinearFormIntegrator> CutIntegral :: MakeBilinearFormIntegrator() const
{
  // Check for DG terms, i.e. proxies flagged as Other(). A single
  // dynamic_cast per node both tests and yields the proxy. The parameter is
  // named `node` so it does not shadow the member `cf`.
  bool has_other = false;
  cf->TraverseTree ([&has_other] (CoefficientFunction & node)
                    {
                      auto proxy = dynamic_cast<ProxyFunction*> (&node);
                      if (proxy && proxy->IsOther())
                        has_other = true;
                    });
  if (has_other && (dx.element_vb != BND) && !dx.skeleton)
    throw Exception("DG-facet terms need either skeleton=True or element_boundary=True");

  shared_ptr<BilinearFormIntegrator> bfi;
  if (!has_other && !dx.skeleton)
    bfi = make_shared<SymbolicCutBilinearFormIntegrator> (*lsetintdom, cf, dx.vb, dx.element_vb);
  else
  {
    //if (lsetintdom->GetTimeIntegrationOrder() >= 0)
    //  throw Exception("Symbolic cuts on facets and boundary not yet (implemented/tested) for time_order >= 0..");

    // Boundary facets use the same integrator as interior facets but are
    // considered experimental: print the warning only once.
    if (dx.vb == BND)
    {
      static bool warned = false;
      if (!warned)
        cout << "WARNING: Symbolic cuts on boundary facets are an experimental feature for now (not fully tested ..). \n" << endl;
      warned = true;
    }
    bfi = make_shared<SymbolicCutFacetBilinearFormIntegrator> (*lsetintdom, cf, dx.vb);
  }

  if (dx.definedon)
    {
      if (auto definedon_bitarray = get_if<BitArray> (&*dx.definedon); definedon_bitarray)
        bfi->SetDefinedOn(*definedon_bitarray);
      /*
        // can't do that without mesh
      if (auto definedon_string = get_if<string> (&*dx.definedon); definedon_string)
        {
          Region reg(self.GetFESpace()->GetMeshAccess(), dx.vb, *definedon_string);
          bfi->SetDefinedOn(reg.Mask());
        }
      */
    }
  bfi->SetDeformation(dx.deformation);
  bfi->SetBonusIntegrationOrder(dx.bonus_intorder);
  if (dx.definedonelements)
    bfi->SetDefinedOnElements(dx.definedonelements);
  // for (auto both : dx.userdefined_intrules)
  //   bfi->SetIntegrationRule(both.first, *both.second);

  return bfi;
}



/// Create the linear form integrator for this cut integral.
///
/// Terms without Other() proxies and without skeleton=True are handled by
/// SymbolicCutLinearFormIntegrator. Facet terms are only supported on
/// boundary facets (SymbolicCutFacetLinearFormIntegrator). The resulting
/// integrator is configured with the definedon mask, deformation, bonus
/// integration order and element restriction carried by the differential
/// symbol.
///
/// @return the configured integrator
/// @throws Exception if Other() proxies are used without skeleton=True or
///         element_boundary=True, or if a facet term is requested on
///         interior facets (not implemented)
shared_ptr<LinearFormIntegrator> CutIntegral :: MakeLinearFormIntegrator() const
{
  // Check for DG terms, i.e. proxies flagged as Other(). A single
  // dynamic_cast per node both tests and yields the proxy.
  bool has_other = false;
  cf->TraverseTree ([&has_other] (CoefficientFunction & node)
                    {
                      auto proxy = dynamic_cast<ProxyFunction*> (&node);
                      if (proxy && proxy->IsOther())
                        has_other = true;
                    });
  if (has_other && (dx.element_vb != BND) && !dx.skeleton)
    throw Exception("DG-facet terms need either skeleton=True or element_boundary=True");

  shared_ptr<LinearFormIntegrator> lfi;
  if (!has_other && !dx.skeleton)
    lfi = make_shared<SymbolicCutLinearFormIntegrator> (*lsetintdom, cf, dx.vb);
  else
  {
    // Only boundary facets are supported for facet linear forms.
    if (dx.vb != BND)
      throw Exception("SymbolicFacetCutLinearFormIntegrator not yet implemented for interior facets.");

    // Experimental feature: print the warning only once.
    static bool warned = false;
    if (!warned)
      cout << "WARNING: Symbolic cuts on boundary facets are an experimental feature for now (not fully tested ..). \n" << endl;
    warned = true;
    lfi = make_shared<SymbolicCutFacetLinearFormIntegrator> (*lsetintdom, cf, dx.vb);
  }

  if (dx.definedon)
    {
      if (auto definedon_bitarray = get_if<BitArray> (&*dx.definedon); definedon_bitarray)
        lfi->SetDefinedOn(*definedon_bitarray);
      /*
        // can't do that without mesh
      if (auto definedon_string = get_if<string> (&*dx.definedon); definedon_string)
        {
          Region reg(self.GetFESpace()->GetMeshAccess(), dx.vb, *definedon_string);
          lfi->SetDefinedOn(reg.Mask());
        }
      */
    }
  lfi->SetDeformation(dx.deformation);
  lfi->SetBonusIntegrationOrder(dx.bonus_intorder);
  if (dx.definedonelements)
    lfi->SetDefinedOnElements(dx.definedonelements);
  // for (auto both : dx.userdefined_intrules)
  //   lfi->SetIntegrationRule(both.first, *both.second);

  return lfi;
}


/// Integrate the (scalar) coefficient function `cf` over the level set
/// integration domain `lsetintdom` of this integral.
///
/// Every VOL element of `ma` that passes the definedon / definedonelements
/// filters gets a cut integration rule from CreateCutIntegrationRule. The
/// weighted values of `cf` at those points are summed per element and over
/// all elements. SIMD evaluation is used while enabled and abandoned as soon
/// as `cf` reports ExceptionNOSIMD.
///
/// @param ma            mesh whose VOL elements are visited
/// @param element_wise  if non-empty, each element's contribution is added to
///                      element_wise(el.Nr())
/// @return the sum of all element contributions, AllReduce'd (NG_MPI_SUM)
///         over the mesh communicator
/// @throws Exception if dx.element_vb == BND or if cf is not 1-dimensional
template <typename TSCAL>
TSCAL CutIntegral :: T_CutIntegrate (const ngcomp::MeshAccess & ma,
                                  FlatVector<TSCAL> element_wise)
{
  static Timer timer("CutIntegral::T_CutIntegrate");
  RegionTimer reg (timer);
  LocalHeap glh(1000000000, "lh-T_CutIntegrate");
  // bool space_time = lsetintdom->GetTimeIntegrationOrder() >= 0;
  if (dx.element_vb == BND)
    throw Exception("CutIntegrate can only deal with VOL a.t.m..");

  // Optional restriction to a set of element indices (regions).
  BitArray defon;
  if (dx.definedon)
    {
      if (auto definedon_bitarray = get_if<BitArray> (&*dx.definedon))
        defon = *definedon_bitarray;
      if (auto definedon_string = get_if<string> (&*dx.definedon))
        {
          // Region needs a shared_ptr; wrap the borrowed mesh without ownership.
          shared_ptr<MeshAccess> spma(const_cast<MeshAccess*>(&ma), NOOP_Deleter);
          Region region(spma, dx.vb, *definedon_string);
          defon = region.Mask();
        }
    }

  // IterateElements may run the element lambda on several threads (the sum is
  // accumulated with AtomicAdd), and any of them can switch SIMD evaluation
  // off. The flag must therefore be atomic to avoid a data race.
  std::atomic<bool> simd_eval { globxvar.SIMD_EVAL };
  int cfdim = cf->Dimension();
  if (cfdim != 1)
    throw Exception("only implemented for 1 dimensional coefficientfunctions");

  TSCAL sum = 0.0;
  ma.IterateElements(VOL, glh, [&] (Ngs_Element el, LocalHeap & lh)
  {
    if (defon.Size() && !defon.Test(el.GetIndex()))
      return;
    if (dx.definedonelements && !dx.definedonelements->Test(el.Nr()))
      return;

    auto & trafo1 = ma.GetTrafo (el, lh);
    auto & trafo = trafo1.AddDeformation(this->dx.deformation.get(), lh);
    const IntegrationRule *ns_ir;
    Array<double> ns_wei_arr;
    tie (ns_ir, ns_wei_arr) = CreateCutIntegrationRule(*lsetintdom,trafo,lh);
    // No cut rule for this element (presumably the element does not meet the
    // integration domain): it contributes nothing.
    if (ns_ir == nullptr)
      return;

    if (simd_eval)
    {
      try
      {
        SIMD_IntegrationRule simd_ir(*ns_ir, lh);
        FlatArray<SIMD<double>> simd_wei_arr = CreateSIMD_FlatArray(ns_wei_arr, lh);
        SIMD_BaseMappedIntegrationRule & simd_mir = trafo(simd_ir, lh);
        FlatMatrix<SIMD<TSCAL>> val(simd_mir.Size(), 1, lh);
        cf -> Evaluate (simd_mir, val);
        SIMD<TSCAL> lsum(0.0);
        for (size_t i = 0; i < simd_mir.Size(); i++)
          lsum += simd_mir[i].GetMeasure()*simd_wei_arr[i]*val(i,0);
        TSCAL element_sum = HSum(lsum);
        if (element_wise.Size())
          element_wise(el.Nr()) += element_sum;
        AtomicAdd(sum, element_sum);
        return;
      }
      catch (const ExceptionNOSIMD & e)
      {
        // cf cannot be evaluated with SIMD: fall back to scalar evaluation for
        // this element and every following one.
        cout << IM(6) << e.What()
             << "switching to non-SIMD evaluation" << endl;
        simd_eval = false;
      }
    }

    // scalar (non-SIMD) evaluation
    BaseMappedIntegrationRule & mir = trafo(*ns_ir, lh);
    FlatMatrix<TSCAL> val(mir.Size(), 1, lh);
    cf -> Evaluate (mir, val);
    TSCAL lsum(0.0);
    for (size_t i = 0; i < mir.Size(); i++)
      lsum += mir[i].GetMeasure()*ns_wei_arr[i]*val(i,0);

    if (element_wise.Size())
      element_wise(el.Nr()) += lsum;

    AtomicAdd(sum,lsum);
  });
  return ma.GetCommunicator().AllReduce(sum, NG_MPI_SUM);
}


// Real-valued integration over the cut domain; forwards to the shared
// template kernel instantiated for double.
double CutIntegral::Integrate (const ngcomp::MeshAccess & ma,
                               FlatVector<double> element_wise)
{
  return T_CutIntegrate<double>(ma, element_wise);
}

// Complex-valued integration over the cut domain; forwards to the shared
// template kernel instantiated for Complex.
Complex CutIntegral::Integrate (const ngcomp::MeshAccess & ma,
                                FlatVector<Complex> element_wise)
{
  return T_CutIntegrate<Complex>(ma, element_wise);
}

// Construct a facet-patch integral from a FacetPatchDifferentialSymbol,
// taking over its time order, reference time and downscale option.
FacetPatchIntegral::FacetPatchIntegral (shared_ptr<CoefficientFunction> _cf,
                                        shared_ptr<FacetPatchDifferentialSymbol> _dx)
  : Integral(_cf, *_dx),
    time_order(_dx->time_order),
    tref(_dx->tref),
    downscale(_dx->downscale)
{
}

// Construct a facet-patch integral from a plain differential symbol and
// explicitly given time order, reference time and downscale option.
FacetPatchIntegral::FacetPatchIntegral (shared_ptr<CoefficientFunction> _cf,
                                        DifferentialSymbol _dx,
                                        int _time_order,
                                        optional<double> _tref,
                                        optional<double> _downscale)
  : Integral(_cf, _dx),
    time_order(_time_order),
    tref(_tref),
    downscale(_downscale)
{
}

/// Create the SymbolicFacetPatchBilinearFormIntegrator for this integral.
///
/// Facet-patch terms are expected to use Other() proxies; if none is found
/// only an informational message is printed. The integrator is configured
/// with time integration order, optional reference time, definedon mask,
/// deformation, bonus integration order, element restriction and the
/// optional downscale factor (passed to SetIRScaling).
///
/// @return the configured integrator
/// @throws Exception if a reference time is given together with a space-time
///         integration order (time_order > -1)
shared_ptr<BilinearFormIntegrator> FacetPatchIntegral :: MakeBilinearFormIntegrator() const
{
  // Check for Other() proxies. A single dynamic_cast per node both tests and
  // yields the proxy; `node` avoids shadowing the member `cf`.
  bool has_other = false;
  cf->TraverseTree ([&has_other] (CoefficientFunction & node)
                    {
                      auto proxy = dynamic_cast<ProxyFunction*> (&node);
                      if (proxy && proxy->IsOther())
                        has_other = true;
                    });
  if (!has_other)
    cout << IM(3) << " no Other() used?!" << endl;

  // Validate the option combination before allocating the integrator.
  if ((tref) && (time_order > -1))
    throw Exception("not reference time fixing for space-time integration domain");

  auto bfi = make_shared<SymbolicFacetPatchBilinearFormIntegrator> (cf);
  bfi->SetTimeIntegrationOrder(time_order);
  if (tref)
    bfi->SetReferenceTime(*tref);

  if (dx.definedon)
  {
    if (auto definedon_bitarray = get_if<BitArray> (&*dx.definedon); definedon_bitarray)
      bfi->SetDefinedOn(*definedon_bitarray);
  }
  bfi->SetDeformation(dx.deformation);
  bfi->SetBonusIntegrationOrder(dx.bonus_intorder);
  if (dx.definedonelements)
    bfi->SetDefinedOnElements(dx.definedonelements);
  if (downscale)
    bfi->SetIRScaling(*downscale);
  return bfi;
}



// Explicit instantiations of the cut integration kernel for the two scalar
// types used by CutIntegral::Integrate (double and Complex).
template double  CutIntegral::T_CutIntegrate<double>  (const ngcomp::MeshAccess &,
                                                       FlatVector<double>);
template Complex CutIntegral::T_CutIntegrate<Complex> (const ngcomp::MeshAccess &,
                                                       FlatVector<Complex>);
