/*
ga2.cc
Selah Lynch July '04

Class GAProb version2 definitions
*/

#include <mpi.h>
#include <stdio.h>
#include <stdlib.h>
#include <vector>
#include "ga2.h"
#include "population.h"
#include "problem.h"

/*
  Build a GA driver around an existing problem and population.
  Both are borrowed: this object keeps pointers to them but does not take
  ownership of the population (deletepop is false).

  prob - problem whose GetNumPoints() sizes the solution order
  pop  - population that NextGeneration() evolves in place
*/
GAProb::GAProb(Problem& prob, Population& pop){
  // Borrowed collaborators; the caller keeps ownership of `pop`.
  gaprob = &prob;
  gapop = &pop;
  deletepop = false;

  // Sizes cached from the collaborators.
  popsize = pop.GetPopSize();
  numpoints = prob.GetNumPoints();
  generation = 0;

  // Where this process sits in MPI_COMM_WORLD.
  MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
  MPI_Comm_size(MPI_COMM_WORLD, &sizeofcluster);

  // Seed the all-time record with a heap-allocated copy of the current best.
  bestslnovertime = new Solution(pop.BestSln());

  // Default percentkill / percentmutate values (50 each).
  SetVariables(50, 50);

  slnout.open("solutions.data");
}

/*
  Build a GA driver that creates and owns its own population of
  `popsizearg` members for `prob`.

  NOTE: writing `GAProb(prob, *gapop);` inside a constructor body does NOT
  delegate to the other constructor; it constructs and immediately destroys
  an unnamed temporary, leaving this object's members uninitialized. The
  initialization is therefore spelled out here. Keep it in sync with
  GAProb(Problem&, Population&).
*/
GAProb::GAProb(Problem& prob, int popsizearg){
  gapop=new Population(popsizearg,prob);
  gaprob=&prob;
  popsize=gapop->GetPopSize();

  numpoints=prob.GetNumPoints();

  generation=0;

  MPI_Comm_rank(MPI_COMM_WORLD, &myrank);
  MPI_Comm_size(MPI_COMM_WORLD, &sizeofcluster);

  bestslnovertime = new Solution(gapop->BestSln());

  int killpercent=50, mutpercent=50;
  SetVariables(killpercent,mutpercent);

  slnout.open("solutions.data");

  // This constructor allocated the population, so the destructor frees it.
  deletepop=true;
}


/*
  Copying is not supported: check(false, ...) prints the message and
  terminates the program, so no copy is ever completed.
*/
GAProb::GAProb(GAProb& gap){
  check(3==2,"copy constructor should not be used, please pass by reference");

}


/*
  Write the final population to solutions.data (root rank only, via
  WriteSolutionsFile), close the file, then release heap objects:
  bestslnovertime is allocated by both constructors; the population is
  freed only when this object created it.
*/
GAProb::~GAProb(){

  WriteSolutionsFile(slnout);
  slnout.close();
  delete bestslnovertime;
  if (deletepop) delete gapop;
}


void GAProb::SetVariables(int percentkillarg,int percentmutatearg){
  percentkill=percentkillarg;
  percentmutate=percentmutatearg;
}


/*
  Evolve in rounds of `checkinterval` generations until a round fails to
  strictly reduce the cluster-wide best distance. At least one round is
  always run.

  BestOverCluster() performs MPI collectives on MPI_COMM_WORLD, so every
  rank must call this with the same checkinterval; the loop test uses the
  broadcast result, so all ranks stop on the same round.

  Previously olddist was seeded with `newdist + 1`. That sentinel never
  entered the loop when newdist was +inf, NaN, or so large that adding 1
  had no effect; a do/while states "run at least once" directly.
*/
void GAProb::RunTillDone(int checkinterval){
  double newdist = BestOverCluster(BestSlnOverTime()).GetDist(*gaprob);
  double olddist;

  do {
    RunFor(checkinterval);
    olddist = newdist;
    newdist = BestOverCluster(BestSlnOverTime()).GetDist(*gaprob);
  } while (newdist < olddist);
}



/*
  Advance the GA by exactly `howmanygens` generations (none if the value is
  zero or negative). Requests above 43000 are rejected by check(), which
  terminates the program.
*/
void GAProb::RunFor(int howmanygens){
  check(howmanygens<=43000, "RunFor(), Too long to run.");

  int remaining = howmanygens;
  while (remaining > 0) {
    NextGeneration();
    --remaining;
  }
}


void GAProb::NextGeneration(){
  //takes the population and turns it into the next gen

  gapop->BreedAndReplace(percentkill);
  gapop->Mutate(percentmutate);
  gapop->MixPopulations();


  Solution bestslnnow(gapop->BestSln());

  double bestdist, thisdist;

  if(generation==0) bestslnovertime->SetAsEqual(bestslnnow);
  else{

    bestdist = bestslnovertime->GetDist(*gaprob);
    thisdist = bestslnnow.GetDist(*gaprob);
    bool isimproved = thisdist < bestdist;
 
    if(isimproved){ 
      bestslnovertime->SetAsEqual(bestslnnow);
    }
  }
  generation++;

  if(myrank==0 && generation%250==0 && generation!=0) 
    printf("%d generations down...\n",generation);

}


int GAProb::GetGeneration(){
  return generation;
}



/*
  Reference to the best solution this process has recorded across all
  generations (the object updated by NextGeneration()).
*/
Solution& GAProb::BestSlnOverTime(){
  Solution& record = *bestslnovertime;
  return record;
}


/*
  Find, across all ranks of MPI_COMM_WORLD, the solution with the smallest
  GetDist(), and return a copy of `slnarg` whose order is set to that winner.
  This is collective: every rank must call it, and all ranks get the same
  order back. Ties go to the lowest rank (MPI_MINLOC semantics).

  Fixes over the previous version:
  - MPI_Reduce was given count 2 of MPI_2INT (four ints) on int[2] buffers,
    which read and wrote past the end of both arrays.
  - Distances were scaled to int (x100000), which overflowed for
    dist > ~21474 and truncated close values; the raw double is now
    reduced with MPI_DOUBLE_INT.
  - The order buffer was a variable-length array (not standard C++).
*/
Solution GAProb::BestOverCluster(Solution& slnarg){

  // Layout must match MPI_DOUBLE_INT: a double followed by an int.
  struct DistRank { double dist; int rank; };
  DistRank mine, winner;
  mine.dist = slnarg.GetDist(*gaprob);
  mine.rank = myrank;

  // One (dist, rank) pair per process; Allreduce delivers the result to all ranks.
  MPI_Allreduce(&mine, &winner, 1, MPI_DOUBLE_INT, MPI_MINLOC, MPI_COMM_WORLD);
  const int minrank = winner.rank;

  // Broadcast the winning order from rank minrank. Size >= 1 so &order[0] is valid.
  std::vector<int> order(numpoints > 0 ? numpoints : 1);
  if (myrank == minrank) slnarg.SpitOutOrder(&order[0]);

  MPI_Bcast(&order[0], numpoints, MPI_INT, minrank, MPI_COMM_WORLD);

  Solution s(slnarg);
  s.SetAsString(&order[0]);

  return s;
}


/*
  On the ROOT rank, write the population size, the population (via
  Population::Display), and a "-999" terminator line to `out`.
  Other ranks write nothing.
*/
void GAProb::WriteSolutionsFile(ostream& out){
  if (myrank != ROOT) return;

  out << popsize << endl;
  gapop->Display(out);
  out << "-999" << endl;
}



void GAProb::check(bool b, char* mess){
  if(!b) {
    printf("\n\nERROR[GAProb] - %s \n\n\n", mess);
    exit(0);
  }
}



