import mlx.core as mx
from mlx import nn

from mflux.models.flux.model.flux_vae.decoder.decoder import Decoder
from mflux.models.flux.model.flux_vae.encoder.encoder import Encoder


class VAE(nn.Module):
    """Pairs an ``Encoder`` and a ``Decoder`` behind an affine latent rescaling.

    ``decode`` maps incoming latents with ``x / scaling_factor + shift_factor``
    before calling the decoder; ``encode`` applies the exact inverse,
    ``(m - shift_factor) * scaling_factor``, to the encoder's output.
    """

    # Constants of the affine latent transform. They are floats (the previous
    # ``int`` annotations contradicted the literal values).
    scaling_factor: float = 0.3611
    shift_factor: float = 0.1159

    def __init__(self):
        super().__init__()
        self.decoder = Decoder()
        self.encoder = Encoder()

    def decode(self, latents: mx.array) -> mx.array:
        """Rescale ``latents`` and run them through the decoder.

        Args:
            latents: Array in the scaled/shifted space produced by ``encode``.

        Returns:
            Whatever ``self.decoder`` returns for the rescaled latents.
        """
        # Undo the normalisation applied at the end of ``encode``.
        scaled_latents = (latents / self.scaling_factor) + self.shift_factor
        return self.decoder(scaled_latents)

    def encode(self, latents: mx.array) -> mx.array:
        """Run the encoder and return its first half, shifted and scaled.

        Args:
            latents: Input handed directly to ``self.encoder``.
                NOTE(review): despite the name this is the encoder's *input*
                (presumably an image tensor) -- confirm against callers.

        Returns:
            ``(mean - shift_factor) * scaling_factor``, where ``mean`` is the
            first of two equal parts of the encoder output along axis 1.
        """
        latents = self.encoder(latents)
        # Split the encoder output in two along axis 1 and keep only the first
        # part; the second part is discarded (presumably the log-variance of
        # the posterior -- TODO confirm), so no sampling happens here.
        mean, _ = mx.split(latents, 2, axis=1)
        return (mean - self.shift_factor) * self.scaling_factor
