/**
 * Copyright (c) 2016-present, Facebook, Inc.
 * All rights reserved.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include "quantmatrix.h"

#include <assert.h>
#include <iostream>
#include <stdexcept>

namespace fasttext {

QuantMatrix::QuantMatrix() : Matrix(), qnorm_(false), codesize_(0) {}

/**
 * Evaluates row `i` of this quantized matrix against `vec` by delegating to
 * `pq_->mulcode` (presumably the dot product of `vec` with the decoded row).
 *
 * The value handed to `mulcode` as its last argument is 1 unless `qnorm_` is
 * set; in that case it is the first element of the centroid that `npq_`
 * returns for `norm_codes_[i]`.
 *
 * Preconditions (checked with assert): 0 <= i < m_ and vec.size() == n_.
 *
 * @param vec  vector whose size must equal n_
 * @param i    row index
 * @return     whatever `pq_->mulcode` computes for that row and scale
 */
real QuantMatrix::dotRow(const Vector& vec, int64_t i) const {
  assert(i >= 0);
  assert(i < m_);
  assert(vec.size() == n_);

  // Pick the per-row scale up front so the delegation below is a one-liner.
  const real rowScale = qnorm_
      ? static_cast<real>(npq_->get_centroids(0, norm_codes_[i])[0])
      : static_cast<real>(1);

  return pq_->mulcode(vec, codes_.data(), i, rowScale);
}

/**
 * Applies row `i` of this quantized matrix to `x`, weighted by `a`, by
 * delegating to `pq_->addcode` (presumably x += a * decoded row — the
 * accumulation itself happens inside ProductQuantizer).
 *
 * The weight passed to `addcode` is `a * norm`, where `norm` is 1 unless
 * `qnorm_` is set; in that case it is the first element of the centroid that
 * `npq_` returns for `norm_codes_[i]`.
 *
 * Precondition (checked with assert, as in dotRow()): 0 <= i < m_.
 *
 * @param x  vector updated in place by `pq_->addcode`
 * @param i  row index
 * @param a  extra multiplier applied on top of the row norm
 */
void QuantMatrix::addRowToVector(Vector& x, int32_t i, real a) const {
  // norm_codes_[i] is read unchecked below, so enforce the same row-index
  // bounds that dotRow() asserts.
  assert(i >= 0);
  assert(i < m_);
  real norm = 1;
  if (qnorm_) {
    norm = npq_->get_centroids(0, norm_codes_[i])[0];
  }
  pq_->addcode(x, codes_.data(), i, a * norm);
}

/**
 * Applies row `i` of this quantized matrix to `x` by delegating to
 * `pq_->addcode` (presumably x += decoded row — the accumulation itself
 * happens inside ProductQuantizer).
 *
 * The weight passed to `addcode` is 1 unless `qnorm_` is set; in that case it
 * is the first element of the centroid that `npq_` returns for
 * `norm_codes_[i]`.
 *
 * Precondition (checked with assert, as in dotRow()): 0 <= i < m_.
 *
 * @param x  vector updated in place by `pq_->addcode`
 * @param i  row index
 */
void QuantMatrix::addRowToVector(Vector& x, int32_t i) const {
  // norm_codes_[i] is read unchecked below, so enforce the same row-index
  // bounds that dotRow() asserts.
  assert(i >= 0);
  assert(i < m_);
  real norm = 1;
  if (qnorm_) {
    norm = npq_->get_centroids(0, norm_codes_[i])[0];
  }
  pq_->addcode(x, codes_.data(), i, norm);
}

/**
 * Restores this matrix from a binary stream.
 *
 * Fields are read in this order, each as its raw in-memory representation:
 *   1. qnorm_, m_, n_, codesize_
 *   2. codesize_ bytes into codes_
 *   3. a ProductQuantizer (pq_) via ProductQuantizer::load
 *   4. only when qnorm_ is set: m_ bytes into norm_codes_, followed by a
 *      second ProductQuantizer (npq_)
 *
 * Unlike a bare sequence of reads, the stream state and the size fields are
 * validated so that a truncated or corrupt input is reported instead of
 * silently yielding a matrix filled with garbage.
 *
 * @param in  input stream positioned at the start of the serialized matrix
 * @throws std::runtime_error if the stream fails or a size field is negative.
 *         Members may already have been partially overwritten at that point.
 */
void QuantMatrix::load(std::istream& in) {
  in.read(reinterpret_cast<char*>(&qnorm_), sizeof(qnorm_));
  in.read(reinterpret_cast<char*>(&m_), sizeof(m_));
  in.read(reinterpret_cast<char*>(&n_), sizeof(n_));
  in.read(reinterpret_cast<char*>(&codesize_), sizeof(codesize_));
  // Validate before the sizes are used: a negative value converted to the
  // vector's unsigned size type would request an enormous allocation.
  if (!in || m_ < 0 || n_ < 0 || codesize_ < 0) {
    throw std::runtime_error(
        "QuantMatrix::load: truncated or invalid matrix header");
  }

  codes_ = std::vector<uint8_t>(codesize_);
  in.read(reinterpret_cast<char*>(codes_.data()),
          codesize_ * sizeof(uint8_t));
  if (!in) {
    throw std::runtime_error("QuantMatrix::load: truncated code table");
  }

  pq_ = std::unique_ptr<ProductQuantizer>(new ProductQuantizer());
  pq_->load(in);

  if (qnorm_) {
    norm_codes_ = std::vector<uint8_t>(m_);
    in.read(reinterpret_cast<char*>(norm_codes_.data()),
            m_ * sizeof(uint8_t));
    if (!in) {
      throw std::runtime_error("QuantMatrix::load: truncated norm codes");
    }
    npq_ = std::unique_ptr<ProductQuantizer>(new ProductQuantizer());
    npq_->load(in);
  }

  // Catches failures that occurred inside the ProductQuantizer loads.
  if (!in) {
    throw std::runtime_error("QuantMatrix::load: truncated quantizer data");
  }
}

} // namespace fasttext
