//
//  PrdObsStatistics.cpp
//  SimpleCrowdingModel
//
//  Created by David Kieras on 4/17/18.
//  Copyright © 2018 University of Michigan. All rights reserved.
//

#include "PrdObsStatistics.h"
#include "Program_constants.h"
#include <string>
#include <sstream>
#include <cassert>
#include <iostream>
#include <fstream>


using namespace std;

// Default constructor: performs no initialization of its own; every member is
// default-initialized, exactly as with an empty constructor body.
PrdObsStatistics::PrdObsStatistics() = default;

void PrdObsStatistics::load_obs_data(ifstream& infile, bool echo_print)
{
    for(int cond = 0; cond < 3; cond++) {
        for(int pol = 0; pol < 2; pol++) {
            for(int ssize = 0; ssize < 4; ssize++) {
                double rt;
                infile >> rt;
                assert(infile.good());
                obsrts[cond][pol][ssize] = rt;
                }
            }
        }
    for(int cond = 0; cond < 3; cond++) {
        for(int pol = 0; pol < 2; pol++) {
            for(int ssize = 0; ssize < 4; ssize++) {
                double er;
                infile >> er;
                assert(infile.good());
                obsproperrs[cond][pol][ssize] = er;
                }
            }
        }
    // echo print for check
    if(!echo_print)
    	return;
    for(int cond = 0; cond < 3; cond++) {
        for(int pol = 0; pol < 2; pol++) {
            for(int ssize = 0; ssize < 4; ssize++) {
                cout << obsrts[cond][pol][ssize] << ' ';
                cout << obsproperrs[cond][pol][ssize] << endl;
                }
            }
        }
}

// Add one predicted data point to the fit statistics.
// The observed RT and error value for the same (condition, polarity, set size) cell
// are looked up in the tables filled by load_obs_data(). Each pair is then passed to
// po_rts and po_prop_errors, with the predicted value first.
// set_size is expected to be one of 3, 6, 12, or 18.
void PrdObsStatistics::update(Condition_e cond, Polarity_e polarity, int set_size, double pred_rt, double pred_err)
{
    // Table dimensions match the loops in load_obs_data(): 3 conditions x 2 polarities.
    const int icond = static_cast<int>(cond);
    const int ipolarity = static_cast<int>(polarity);
    assert(icond >= 0 && icond < 3);
    assert(ipolarity >= 0 && ipolarity < 2);

    // set_size 3/6/12/18 has to be transposed to index 0/1/2/3.
    // Reject other sizes in debug builds: the arithmetic below would otherwise
    // silently map them onto a neighbouring cell (e.g. 9 -> 1, 5 -> 0).
    assert(set_size == 3 || set_size == 6 || set_size == 12 || set_size == 18);
    int iset_size = 0; // value for 3
    if(set_size > 3)
        iset_size = set_size/6; // value for 6,12,18
    assert(iset_size >= 0 && iset_size <= 3);

    double obs_rt = obsrts[icond][ipolarity][iset_size];
    double obs_err = obsproperrs[icond][ipolarity][iset_size];
    
//    cout << int(cond) << ' ' << int(polarity) << ' ' << set_size << ' ' << iset_size << ' ' << obs_rt << ' ' << pred_rt << ' ' << obs_err << ' ' << pred_err << endl;
    // predicted then observed, because the statistics are computed in that order
    po_rts.update(pred_rt, obs_rt);
    po_prop_errors.update(pred_err, obs_err);
}

// Write one tab-separated row of fit statistics to os, with no trailing newline.
// The columns are, in order:
//   RT r-squared, RT avg abs error, RT avg abs rel error, RT FoM (rel error / rsq),
//   error r-squared, error avg abs error, error avg abs rel error,
//   error FoM (abs error / rsq), error FoM (rel error / rsq),
//   weighted FoM (RT_FoM_weight_c * RT FoM + (1 - RT_FoM_weight_c) * error rel FoM),
//   and the number of data points.
// Note: a zero r-squared is not guarded against here.
void PrdObsStatistics::output(ostream& os) const
{
    // update() feeds both accumulators together, so their counts must agree.
    assert(po_rts.get_n() == po_prop_errors.get_n());

    // RT fit statistics
    const auto rt_rsq = po_rts.get_rsq();
    const auto rt_abs_err = po_rts.get_avg_abs_error();
    const auto rt_rel_err = po_rts.get_avg_abs_rel_error();
    const double rt_fom = rt_rel_err / rt_rsq;

    // error-value fit statistics
    const auto err_rsq = po_prop_errors.get_rsq();
    const auto err_abs_err = po_prop_errors.get_avg_abs_error();
    const auto err_rel_err = po_prop_errors.get_avg_abs_rel_error();
    const double er_foma = err_abs_err / err_rsq;
    const double er_fomr = err_rel_err / err_rsq;

    // weighted blend of the two relative-error figures of merit
    const double wa_fomr = (RT_FoM_weight_c * rt_fom + (1. - RT_FoM_weight_c) * er_fomr);

    os << rt_rsq << "\t" << rt_abs_err << "\t" << rt_rel_err << "\t"
       << rt_fom << "\t"
       << err_rsq << "\t" << err_abs_err << "\t" << err_rel_err << "\t"
       << er_foma << "\t"
       << er_fomr << "\t"
       << wa_fomr << "\t" << po_prop_errors.get_n();
}



