use pyo3::prelude::*;
use pyo3::types::PyModule;

use anyhow::Error as E;
use anyhow::Result;
use candle_core::{DType, Device, Tensor};
use candle_nn as nn;
use nn::{Module, VarBuilder, RNN};
use tokenizers::Tokenizer;

/// GRU-based binary text classifier exposed to Python as `TextClassifier`.
///
/// Pipeline (see `forward`): token-id embedding -> GRU unrolled over the sequence ->
/// linear head over the final hidden state, producing 2 logits.
#[pyclass]
#[derive(Clone)]
struct TextClassifier {
    /// Token embedding table (10000 x 256, as built in `new`).
    embedding: nn::Embedding,
    /// Recurrent encoder with 256-dim input and 256-dim hidden state.
    gru: nn::GRU,
    /// Output head mapping the final 256-dim hidden state to 2 logits.
    ln1: nn::Linear,
    /// Compute device that tensors fed to the model are created on.
    /// NOTE(review): should be the same device instance the weights were loaded on.
    device: Device,
}

impl TextClassifier {
    /// Builds the classifier from pretrained weights in `vs`.
    ///
    /// Expects tensors under the `embedding`, `gru` and `ln1` prefixes. They must be sized for
    /// a 10 000-token vocabulary, 256-dimensional embeddings and hidden state, and 2 output
    /// classes. The model runs on the device `vs` was created for.
    ///
    /// # Errors
    /// Fails if any expected weight is missing or has the wrong shape.
    pub fn new(vs: VarBuilder) -> Result<Self> {
        let embedding = nn::embedding(10000, 256, vs.pp("embedding"))?;
        let gru = nn::gru(256, 256, Default::default(), vs.pp("gru"))?;
        let ln1 = nn::linear(256, 2, vs.pp("ln1"))?;
        // Reuse the weights' device instead of opening a new handle. Candle compares CUDA/Metal
        // devices by instance, so a second `Device::cuda_if_available(0)` would make every op
        // that mixes weights and inputs fail with a device mismatch.
        let device = vs.device().clone();
        Ok(Self {
            embedding,
            gru,
            ln1,
            device,
        })
    }

    /// Runs the classifier and returns the raw (pre-softmax) logits.
    ///
    /// `xs` holds token ids shaped `(batch, seq_len)`; the result is shaped `(batch, 2)`.
    /// The GRU is unrolled one time step at a time from a zero state. Only the final hidden
    /// state feeds the linear head, so an empty sequence classifies the zero state.
    ///
    /// # Errors
    /// Fails if `xs` is not rank 2 or lives on a different device than the weights.
    pub fn forward(&self, xs: &Tensor) -> Result<Tensor> {
        // (batch, seq_len) -> (batch, seq_len, hidden)
        let xs = self.embedding.forward(xs)?;
        let (batch, seq_len, _) = xs.dims3()?;
        let mut state = self.gru.zero_state(batch)?;
        for t in 0..seq_len {
            // Slice the time step on-device rather than round-tripping through host memory.
            // `contiguous` is needed because the GRU's matmul rejects the strided view.
            let x_t = xs.narrow(1, t, 1)?.squeeze(1)?.contiguous()?;
            state = self.gru.step(&x_t, &state)?;
        }
        Ok(self.ln1.forward(state.h())?)
    }

    /// Classifies `text` and returns `(flagged, confidence_percent)`.
    ///
    /// `flagged` is `true` when the most probable class is not class 0. The confidence is
    /// that class's softmax probability, scaled to `0..=100`.
    ///
    /// # Errors
    /// Fails if tokenization fails or any tensor operation fails (for example when `device`
    /// is not the model's device).
    pub fn predict(
        &self,
        tokenizer: Tokenizer,
        device: &Device,
        text: String,
    ) -> Result<(bool, f32)> {
        let encoded = tokenizer.encode(text, false).map_err(E::msg)?;
        // Shape (1, seq_len): a batch containing a single sequence.
        let ids = Tensor::new(vec![encoded.get_ids()], device)?;
        let probs = nn::ops::softmax(&self.forward(&ids)?, 1)?;
        let class = probs.argmax(1)?.to_vec1::<u32>()?[0];
        let confidence = probs.to_vec2::<f32>()?[0][class as usize] * 100.0;
        Ok((class != 0, confidence))
    }
}

#[pymethods]
impl TextClassifier {
    #[new]
    fn new_py() -> PyResult<Self> {
        let device = Device::cuda_if_available(0)
            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
        let vs = nn::VarBuilder::from_buffered_safetensors(
            include_bytes!("../model.safetensors").to_vec(),
            DType::F32,
            &device,
        )
        .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
        Self::new(vs)
            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))
    }

    /// Detects if text contains inappropriate content.
    fn detect(&mut self, text: String) -> PyResult<(bool, f32)> {
        let tokenizer = Tokenizer::from_bytes(include_bytes!("../tokenizer.json"))
            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
        let device = Device::cuda_if_available(0)
            .map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;
        let (is_bad, prob) = self.predict(tokenizer, &device, text).map_err(|e| PyErr::new::<pyo3::exceptions::PyRuntimeError, _>(e.to_string()))?;

        Ok((is_bad, prob))
    }
}

/// Python extension module `badetector`: exposes the `TextClassifier` class for detecting
/// inappropriate text.
#[pymodule]
fn badetector(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // Registering the class is the module's only initialisation step, so its result is the
    // module's result.
    m.add_class::<TextClassifier>()
}
