/* ======================================================================
     "numlibBls" - Basic linear least squares library
             Copyright (c) 2002-2009  by A. Miyoshi, Univ. Tokyo
                         created: Dec.   7, 2009 from libnumLls (BEx1D)
                     last edited: Dec.  13, 2009
 ====================================================================== */

#include "numlibBls.h"
#include "lapack_dposv/wcpp_dposv.h"

/* ======================================================================
     basic Linear Least Squares Class
 ====================================================================== */

                                                 // --- Constructors ---
// Default constructor: the object starts uninitialized (nPars == 0), so
// init() has to be called before any data can be added or solved.
basicLinLsSq::basicLinLsSq() {
  this->clear();
}

// Construct and size the accumulators for `npar` parameters.
// init()'s status code is discarded here: for an out-of-range npar it has
// already printed a diagnostic and left the object uninitialized (nPars == 0).
basicLinLsSq::basicLinLsSq(int npar) {
  (void)init(npar);
}

                                                 // --- Initializers ---
void basicLinLsSq::clear()
 { nPars = 0; nMeas = 0; sumXX.clear(); sumYX.clear(); }

/* (Re)initialize for a model with npar parameters, 1 <= npar <= MXNPAR.
 * Any previously accumulated data is discarded (clear() runs first), then
 * zeroed accumulators are allocated:
 *   sumYX : npar entries
 *   sumXX : lower triangle, row i holding i+1 entries
 * Returns NORM, or ERRE after printing a diagnostic when npar is out of
 * range; in that case the object stays uninitialized (nPars == 0). */
int basicLinLsSq::init(int npar) {
  clear();
  if (npar > MXNPAR) { errmsg(INIT, LRG_nPars); return ERRE; }
  if (npar < 1)      { errmsg(INIT, SML_nPars); return ERRE; }

  nPars = npar;
  for (int row = 0; row < nPars; ++row) {
    sumYX.push_back(0.);     // grow the Y.X accumulator by one zero ...
    sumXX.push_back(sumYX);  // ... and copy it as triangle row `row` (row+1 zeros)
  }
  sumYY = 0.;
  return NORM;
}

                                                // --- Data addition ---
int basicLinLsSq::addDatumW(double y, vector<double> &x, double w) {
  int ip, iq;
  if (nPars == 0) { errmsg(INIT, NOT_init); return ERRE; }
  if (x.size() < nPars) { errmsg(DADD, SMLN_x); return ERRE; }
  for (ip = 0; ip < nPars; ip++) {
    sumYX[ip] += w * x[ip] * y;
    for (iq = 0; iq <= ip; iq++) { sumXX[ip][iq] += w * x[ip] * x[iq]; }
  }
  sumYY += w * y * y;
  nMeas++;
  return NORM;
}

/* Accumulate one observation with unit weight; identical to
 * addDatumW(y, x, 1.), including its return codes. */
int basicLinLsSq::addDatum(double y, vector<double> &x) {
  return addDatumW(y, x, 1.);
}

/* Accumulate the first n weighted observations (y[i], x[i], w[i]).
 * If any of y, x or w is shorter than n, nothing is added and ERRE is
 * returned. If addDatumW rejects an observation part-way through, the
 * observations before it stay accumulated and ERRE is returned.
 * Returns NORM when all n observations were added. */
int basicLinLsSq::putDataW(int n, vector<double> &y,
 vector<vector<double> > &x, vector<double> &w) {
  const bool tooShort = (y.size() < n) || (x.size() < n) || (w.size() < n);
  if (tooShort) { errmsg(DADD, SMLN_meas); return ERRE; }

  for (int i = 0; i < n; ++i) {
    const int status = addDatumW(y[i], x[i], w[i]);
    if (status == ERRE) { return ERRE; }
  }
  return NORM;
}

/* Accumulate the first n unit-weight observations (y[i], x[i]).
 * If y or x is shorter than n, nothing is added and ERRE is returned.
 * If addDatum rejects an observation part-way through, the observations
 * before it stay accumulated and ERRE is returned.
 * Returns NORM when all n observations were added. */
int basicLinLsSq::putData(int n, vector<double> &y,
 vector<vector<double> > &x) {
  const bool tooShort = (y.size() < n) || (x.size() < n);
  if (tooShort) { errmsg(DADD, SMLN_meas); return ERRE; }

  for (int i = 0; i < n; ++i) {
    const int status = addDatum(y[i], x[i]);
    if (status == ERRE) { return ERRE; }
  }
  return NORM;
}

                                               // --- Solve basicLLS ---
/* Solve the accumulated normal equations  XX p = XY,
 * where XX = sum(w x x^T) and XY = sum(w y x).
 *   mlp : [out] fitted parameters p (nPars entries)
 *   sdp : [out] standard deviations sqrt(C(i,i)) * sdy, with C = XX^-1
 *   pcc : [out] lower-triangular correlation matrix; pcc[i] holds
 *         C(i,j) / sqrt(C(i,i) C(j,j)) for j < i, so pcc[0] is empty
 *   sdy : [out] sqrt(RSS / (nMeas - nPars)), where RSS is the weighted
 *         residual sum of squares; 0 when nMeas == nPars (sdp is then 0 too)
 * Returns NORM on success. Returns ERRE, after printing a diagnostic, when
 * the object is uninitialized, there are fewer measurements than
 * parameters, or either LAPACK call reports an error. On error the output
 * arguments are left untouched. */
int basicLinLsSq::solve_blls(vector<double> &mlp, vector<double> &sdp,
 vector<vector<double> > &pcc, double &sdy) {
  int ir, ic, info;
  double Vy;
  vector<double> wcc;

  if (nPars == 0) { errmsg(INIT, NOT_init); return ERRE; }
  if (nMeas < nPars) { errmsg(SOLV, meas_SMLTp); return ERRE; }

  // Copy the lower triangle of XX into `a` and XY into `b`.
  // NOTE(review): the a[col][row] / b[0][row] indexing presumes the wrapper
  // treats these member arrays as column-major with leading dimension
  // MXNPAR, as LAPACK does -- confirm against wcpp_dposv.h.
  for (ir = 0; ir < nPars; ir++) {
    b[0][ir] = sumYX[ir];
    for (ic = 0; ic <= ir; ic++) { a[ic][ir] = sumXX[ir][ic]; }
  }
  // Cholesky solve using the lower triangle ('L'). Assuming the wrapper
  // mirrors LAPACK DPOSV, b is overwritten by the solution and a by the
  // Cholesky factor; info > 0 means XX is not positive definite.
  info = lapack_dposv('L', nPars, 1, (double **)a, MXNPAR,
   (double **)b, MXNPAR);
  if (info < 0) { errmsg(SOLV, POSVpar); return ERRE; }
  else if (info > 0) { errmsg(SOLV, POSVfail); return ERRE; }
  // Turn the factor into the inverse C = XX^-1 (lower triangle, as DPOTRI).
  info = lapack_dpotri('L', nPars, (double **)a, MXNPAR);
  if (info < 0) { errmsg(SOLV, POTRIpar); return ERRE; }
  else if (info > 0) { errmsg(SOLV, POTRIfail); return ERRE; }

  // RSS = sumYY - 2 p.XY + p^T XX p, expanded over the stored lower triangle.
  Vy = sumYY;
  mlp.clear(); sdp.clear(); pcc.clear();
  for (ir = 0; ir < nPars; ir++) {
    mlp.push_back(b[0][ir]);
    Vy += mlp[ir] * mlp[ir] * sumXX[ir][ir] - 2. * mlp[ir] * sumYX[ir];
    sdp.push_back(sqrt(a[ir][ir]));
    wcc.clear();
    for (ic = 0; ic < ir; ic++) {
      wcc.push_back(a[ic][ir] / sqrt(a[ir][ir] * a[ic][ic]));
      Vy += 2. * mlp[ir] * mlp[ic] * sumXX[ir][ic];   // off-diagonal, counted twice
    }
    pcc.push_back(wcc);
  }
  // The expansion above cancels catastrophically for an exact (or nearly
  // exact) fit and can leave Vy slightly negative through round-off.
  // Clamp it so sqrt() below cannot yield NaN in sdy and every sdp[i].
  if (Vy < 0.) { Vy = 0.; }

  if (nMeas > nPars) { sdy = sqrt(Vy / (nMeas - nPars)); }
  else { sdy = 0.; }
  for (ir = 0; ir < nPars; ir++) { sdp[ir] *= sdy; }
  return NORM;
}

                                                // --- Error message ---
/* Print a one-line diagnostic to standard output (cout, flushed by endl):
 *   "basicLinLsSq-error: " + <stage> + <detail>
 * where <stage> is derived from `era` and <detail> from `erc`. A code not
 * listed in a switch contributes an empty piece. */
void basicLinLsSq::errmsg(erract era, errcode erc) {
  const char *stage = "";
  switch (era) {
    case INIT: stage = "(initialization) "; break;
    case DADD: stage = "(data addition) "; break;
    case SOLV: stage = "(solver) "; break;
  }

  const char *detail = "";
  switch (erc) {
    case NOT_init:   detail = "object not initialized"; break;
    case LRG_nPars:  detail = "nPars larger than MXNPAR"; break;
    case SML_nPars:  detail = "nPars smaller than 1"; break;
    case SMLN_x:     detail = "x-vec smaller than # of param"; break;
    case SMLN_meas:  detail = "meas-vec smaller than # of meas"; break;
    case meas_SMLTp: detail = "# of meas smaller than # of param"; break;
    case POSVpar:    detail = "invalid input to DPOSV"; break;
    case POSVfail:   detail = "DPOSV numerical error"; break;
    case POTRIpar:   detail = "invalid input to DPOTRI"; break;
    case POTRIfail:  detail = "DPOTRI numerical error"; break;
  }

  cout << "basicLinLsSq-error: " << stage << detail << endl;
}

