- Timestamp:
- 05/26/16 18:24:02 (8 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
trunk/CrossPare/src/de/ugoe/cs/cpdp/training/GPTraining.java
r104 r106 44 44 /** 45 45 * Genetic Programming Trainer 46 * 46 47 * 48 * - GPRun is a Run of a complete Genetic Programm Evolution, we want several complete runs. 49 * - GPVClassifier is the Validation Classifier 50 * - GPVVClassifier is the Validation-Voting Classifier 51 * 52 * config: <setwisetrainer name="GPTraining" param="GPVVClassifier" /> 47 53 */ 48 54 public class GPTraining implements ISetWiseTrainingStrategy, IWekaCompatibleTrainer { 49 55 50 private GPV VClassifier classifier = new GPVVClassifier();56 private GPVClassifier classifier = null; 51 57 52 58 private int populationSize = 1000; … … 54 60 private int initMaxDepth = 6; 55 61 private int tournamentSize = 7; 56 62 private int maxGenerations = 50; 63 private double errorType2Weight = 1; 64 private int numberRuns = 200; // 200 in the paper 65 private int maxDepth = 20; // max depth within one program 66 private int maxNodes = 100; // max nodes within one program 67 57 68 @Override 58 69 public void setParameter(String parameters) { 59 // todo, which type of classifier? GPV, GPVV? 60 // more config population size, etc. 61 // todo: voting for gpvv only 3 votes necessary? 70 if(parameters.equals("GPVVClassifier")) { 71 this.classifier = new GPVVClassifier(); 72 ((GPVVClassifier)this.classifier).configure(populationSize, initMinDepth, initMaxDepth, tournamentSize, maxGenerations, errorType2Weight, numberRuns, maxDepth, maxNodes); 73 }else if(parameters.equals("GPVClassifier")) { 74 this.classifier = new GPVClassifier(); 75 ((GPVClassifier)this.classifier).configure(populationSize, initMinDepth, initMaxDepth, tournamentSize, maxGenerations, errorType2Weight, numberRuns, maxDepth, maxNodes); 76 }else { 77 // default 78 this.classifier = new GPVVClassifier(); 79 ((GPVVClassifier)this.classifier).configure(populationSize, initMinDepth, initMaxDepth, tournamentSize, maxGenerations, errorType2Weight, numberRuns, maxDepth, maxNodes); 80 } 62 81 } 63 82 … … 105 124 } 106 125 107 // one gprun, we want several for voting 126 /** 127 * One Run of a GP Classifier 128 * we want several runs to mitigate problems with local maxima/minima 129 */ 108 130 public class GPRun extends AbstractClassifier { 109 131 private static final long serialVersionUID = -4250422550107888789L; 110 132 111 private int populationSize = 1000; 112 private int initMinDepth = 2; 113 private int initMaxDepth = 6; 114 private int tournamentSize = 7; 115 private int maxGenerations = 50; 133 private int populationSize; 134 private int initMinDepth; 135 private int initMaxDepth; 136 private int tournamentSize; 137 private int maxGenerations; 138 private double errorType2Weight; 139 private int maxDepth; 140 private int maxNodes; 116 141 117 142 private GPGenotype gp; 118 143 private GPProblem problem; 119 144 120 public void configure(int populationSize, int initMinDepth, int initMaxDepth, int tournamentSize, int maxGenerations ) {145 public void configure(int populationSize, int initMinDepth, int initMaxDepth, int tournamentSize, int maxGenerations, double errorType2Weight, int maxDepth, int maxNodes) { 121 146 this.populationSize = populationSize; 122 147 this.initMinDepth = initMinDepth; … … 124 149 this.tournamentSize = tournamentSize; 125 150 this.maxGenerations = maxGenerations; 151 this.errorType2Weight = errorType2Weight; 152 this.maxDepth = maxDepth; 153 this.maxNodes = maxNodes; 126 154 } 127 155 … … 133 161 return ((CrossPareGP)this.problem).getVariables(); 134 162 } 135 136 public void setEvaldata(Instances testdata) { 137 163 164 @Override 165 public void buildClassifier(Instances traindata) throws Exception { 166 InstanceData train = new InstanceData(traindata); 167 this.problem = new CrossPareGP(train.getX(), train.getY(), this.populationSize, this.initMinDepth, this.initMaxDepth, this.tournamentSize, this.errorType2Weight, this.maxDepth, this.maxNodes); 168 this.gp = problem.create(); 169 this.gp.evolve(this.maxGenerations); 138 170 } 139 171 … … 142 174 */ 143 175 class CrossPareGP extends GPProblem { 144 145 //private static final long serialVersionUID = 7526472295622776147L;146 147 176 private double[][] instances; 148 177 private boolean[] output; 149 178 179 private int maxDepth; 180 private int maxNodes; 181 150 182 private Variable[] x; 151 183 152 public CrossPareGP(double[][] instances, boolean[] output, int populationSize, int minInitDept, int maxInitDepth, int tournamentSize ) throws InvalidConfigurationException {184 public CrossPareGP(double[][] instances, boolean[] output, int populationSize, int minInitDept, int maxInitDepth, int tournamentSize, double errorType2Weight, int maxDepth, int maxNodes) throws InvalidConfigurationException { 153 185 super(new GPConfiguration()); 154 186 155 187 this.instances = instances; 156 188 this.output = output; 189 this.maxDepth = maxDepth; 190 this.maxNodes = maxNodes; 157 191 158 192 Configuration.reset(); 159 193 GPConfiguration config = this.getGPConfiguration(); 160 //config.reset();161 194 162 195 this.x = new Variable[this.instances[0].length]; 163 164 196 165 197 for(int j=0; j < this.x.length; j++) { … … 170 202 //config.setGPFitnessEvaluator(new DefaultGPFitnessEvaluator()); // bigger fitness is better 171 203 172 // from paper: 2-6173 204 config.setMinInitDepth(minInitDept); 174 205 config.setMaxInitDepth(maxInitDepth); 175 176 // missing from paper 177 // config.setMaxDepth(20); 178 206 179 207 config.setCrossoverProb((float)0.60); 180 208 config.setReproductionProb((float)0.10); … … 183 211 config.setSelectionMethod(new TournamentSelector(tournamentSize)); 184 212 185 // from paper 1000186 213 config.setPopulationSize(populationSize); 187 214 188 // BranchTypingCross189 215 config.setMaxCrossoverDepth(4); 190 config.setFitnessFunction(new CrossPareFitness(this.x, this.instances, this.output ));216 config.setFitnessFunction(new CrossPareFitness(this.x, this.instances, this.output, errorType2Weight)); 191 217 config.setStrictProgramCreation(true); 192 218 } … … 230 256 comb, 231 257 }; 232 233 GPGenotype result = GPGenotype.randomInitialGenotype(config, types, argTypes, nodeSets, 20, true); // 20 = maxNodes, true = verbose output 258 259 // we only have one chromosome so this suffices 260 int minDepths[] = {config.getMinInitDepth()}; 261 int maxDepths[] = {this.maxDepth}; 262 GPGenotype result = GPGenotype.randomInitialGenotype(config, types, argTypes, nodeSets, minDepths, maxDepths, this.maxNodes, false); // 40 = maxNodes, true = verbose output 234 263 235 264 return result; … … 250 279 private boolean[] output; 251 280 252 private double error _type2_weight = 1.0;281 private double errorType2Weight = 1.0; 253 282 254 283 // needed in evaluate 255 private Object[] NO_ARGS = new Object[0];284 //private Object[] NO_ARGS = new Object[0]; 256 285 257 286 private double sfitness = 0.0f; 258 private int error _type1 = 0;259 private int error _type2 = 0;260 261 public CrossPareFitness(Variable[] x, double[][] instances, boolean[] output ) {287 private int errorType1 = 0; 288 private int errorType2 = 0; 289 290 public CrossPareFitness(Variable[] x, double[][] instances, boolean[] output, double errorType2Weight) { 262 291 this.x = x; 263 292 this.instances = instances; 264 293 this.output = output; 294 this.errorType2Weight = errorType2Weight; 265 295 } 266 296 267 297 public int getErrorType1() { 268 return this.error _type1;298 return this.errorType1; 269 299 } 270 300 271 301 public int getErrorType2() { 272 return this.error _type2;302 return this.errorType2; 273 303 } 274 304 … … 288 318 289 319 // count classification errors 290 this.error _type1 = 0;291 this.error _type2 = 0;320 this.errorType1 = 0; 321 this.errorType2 = 0; 292 322 293 323 for(int i=0; i < this.instances.length; i++) { … … 299 329 300 330 // value gives us a double, if < 0.5 we set this instance as faulty 301 value = program.execute_double(0, NO_ARGS); // todo: test with this.x331 value = program.execute_double(0, this.x); // todo: test with this.x 302 332 303 333 if(value < 0.5) { 304 334 if(this.output[i] != true) { 305 this.error _type1 += 1;335 this.errorType1 += 1; 306 336 } 307 337 }else { 308 338 if(this.output[i] == true) { 309 this.error _type2 += 1;339 this.errorType2 += 1; 310 340 } 311 341 } … … 313 343 314 344 // now calc pfitness 315 pfitness = (this.error _type1 + this.error_type2_weight * this.error_type2) / this.instances.length;345 pfitness = (this.errorType1 + this.errorType2Weight * this.errorType2) / this.instances.length; 316 346 317 347 //System.out.println("pfitness: " + pfitness); … … 321 351 if(program.getChromosome(0).getSize(0) < 10) { 322 352 program.setApplicationData(10.0f); 323 this.sfitness = 10.0f; 324 //System.out.println("wenige nodes: "+program.getChromosome(0).getSize(0)); 325 //System.out.println(program.toStringNorm(0)); 326 } 327 328 // sfitness counts the number of nodes in the tree, if it is lower than 10 fitness is increased by 10 353 } 329 354 330 355 return pfitness; 331 356 } 332 357 } 333 334 @Override 335 public void buildClassifier(Instances traindata) throws Exception { 336 InstanceData train = new InstanceData(traindata); 337 this.problem = new CrossPareGP(train.getX(), train.getY(), this.populationSize, this.initMinDepth, this.initMaxDepth, this.tournamentSize); 338 this.gp = problem.create(); 339 this.gp.evolve(this.maxGenerations); 340 } 358 359 /** 360 * Custom GT implementation used in the GP Algorithm. 361 */ 362 public class GT extends MathCommand implements ICloneable { 363 364 private static final long serialVersionUID = 113454184817L; 365 366 public GT(final GPConfiguration a_conf, java.lang.Class a_returnType) throws InvalidConfigurationException { 367 super(a_conf, 2, a_returnType); 368 } 369 370 public String toString() { 371 return "GT(&1, &2)"; 372 } 373 374 public String getName() { 375 return "GT"; 376 } 377 378 public float execute_float(ProgramChromosome c, int n, Object[] args) { 379 float f1 = c.execute_float(n, 0, args); 380 float f2 = c.execute_float(n, 1, args); 381 382 float ret = 1.0f; 383 if(f1 > f2) { 384 ret = 0.0f; 385 } 386 387 return ret; 388 } 389 390 public double execute_double(ProgramChromosome c, int n, Object[] args) { 391 double f1 = c.execute_double(n, 0, args); 392 double f2 = c.execute_double(n, 1, args); 393 394 double ret = 1; 395 if(f1 > f2) { 396 ret = 0; 397 } 398 return ret; 399 } 400 401 public Object clone() { 402 try { 403 GT result = new GT(getGPConfiguration(), getReturnType()); 404 return result; 405 }catch(Exception ex) { 406 throw new CloneException(ex); 407 } 408 } 409 } 341 410 } 342 411 … … 349 418 */ 350 419 public class GPVVClassifier extends GPVClassifier { 351 420 421 private static final long serialVersionUID = -654710583852839901L; 352 422 private List<Classifier> classifiers = null; 353 423 … … 362 432 // each classifier is trained with one project from the set 363 433 // then is evaluated on the rest 434 classifiers = new LinkedList<>(); 364 435 for(int i=0; i < traindataSet.size(); i++) { 365 436 … … 367 438 LinkedList<Classifier> candidates = new LinkedList<>(); 368 439 369 // 200 runs 370 371 for(int k=0; k < 200; k++) { 440 // number of runs 441 for(int k=0; k < this.numberRuns; k++) { 372 442 Classifier classifier = new GPRun(); 443 ((GPRun)classifier).configure(this.populationSize, this.initMinDepth, this.initMaxDepth, this.tournamentSize, this.maxGenerations, this.errorType2Weight, this.maxDepth, this.maxNodes); 373 444 374 445 // one project is training data … … 381 452 // if type1 and type2 errors are < 0.5 we allow the model in the final voting 382 453 errors = this.evaluate((GPRun)classifier, traindataSet.get(j)); 383 if((errors[0] / traindataSet.get(j).numInstances()) < 0.5 && (errors[0] / traindataSet.get(j).numInstances()) < 0.5) {384 candidates.add(classifier); 454 if((errors[0] < 0.5) && (errors[0] < 0.5)) { 455 candidates.add(classifier); 385 456 } 386 457 } … … 406 477 // now we have the best classifier for this training data 407 478 classifiers.add(best); 408 409 479 } 410 480 } … … 417 487 418 488 int vote_positive = 0; 419 int vote_negative = 0;420 489 421 490 for (int i = 0; i < classifiers.size(); i++) { … … 432 501 if(fitest.execute_double(0, vars) < 0.5) { 433 502 vote_positive += 1; 434 }else {435 vote_negative += 1;436 503 } 437 504 } … … 450 517 * 451 518 * for one test data set: 452 * for one in 6possible training data sets:453 * For 200GP Runs:519 * for one in X possible training data sets: 520 * For Y GP Runs: 454 521 * train one Classifier with this training data 455 522 * then evaluate the classifier with the remaining project … … 465 532 private static final long serialVersionUID = 3708714057579101522L; 466 533 467 534 protected int populationSize; 535 protected int initMinDepth; 536 protected int initMaxDepth; 537 protected int tournamentSize; 538 protected int maxGenerations; 539 protected double errorType2Weight; 540 protected int numberRuns; 541 protected int maxDepth; 542 protected int maxNodes; 543 544 /** 545 * Configure the GP Params and number of Runs 546 * 547 * @param populationSize 548 * @param initMinDepth 549 * @param initMaxDepth 550 * @param tournamentSize 551 * @param maxGenerations 552 * @param errorType2Weight 553 */ 554 public void configure(int populationSize, int initMinDepth, int initMaxDepth, int tournamentSize, int maxGenerations, double errorType2Weight, int numberRuns, int maxDepth, int maxNodes) { 555 this.populationSize = populationSize; 556 this.initMinDepth = initMinDepth; 557 this.initMaxDepth = initMaxDepth; 558 this.tournamentSize = tournamentSize; 559 this.maxGenerations = maxGenerations; 560 this.errorType2Weight = errorType2Weight; 561 this.numberRuns = numberRuns; 562 this.maxDepth = maxDepth; 563 this.maxNodes = maxNodes; 564 } 565 468 566 /** Build the GP Multiple Data Sets Validation Classifier 469 567 * … … 485 583 486 584 // 200 runs 487 for(int k=0; k < 200; k++) {585 for(int k=0; k < this.numberRuns; k++) { 488 586 Classifier classifier = new GPRun(); 587 ((GPRun)classifier).configure(this.populationSize, this.initMinDepth, this.initMaxDepth, this.tournamentSize, this.maxGenerations, this.errorType2Weight, this.maxDepth, this.maxNodes); 489 588 490 589 // one project is training data … … 498 597 // if type1 and type2 errors are < 0.5 we allow the model in the final voting 499 598 errors = this.evaluate((GPRun)classifier, traindataSet.get(j)); 500 if((errors[0] / traindataSet.get(j).numInstances()) < 0.5 && (errors[0] / traindataSet.get(j).numInstances()) < 0.5) {599 if((errors[0] < 0.5) && (errors[0] < 0.5)) { 501 600 candidates.add(classifier); 502 601 } … … 549 648 public void buildClassifier(Instances traindata) throws Exception { 550 649 final Classifier classifier = new GPRun(); 650 ((GPRun)classifier).configure(populationSize, initMinDepth, initMaxDepth, tournamentSize, maxGenerations, errorType2Weight, this.maxDepth, this.maxNodes); 551 651 classifier.buildClassifier(traindata); 552 652 classifiers.add(classifier); … … 562 662 int error_type1 = 0; 563 663 int error_type2 = 0; 564 int number_instances = evalData.numInstances(); 664 int positive = 0; 665 int negative = 0; 565 666 566 667 for(Instance instance: evalData) { … … 571 672 572 673 classification = fitest.execute_double(0, vars); 674 675 // we need to count the absolutes of positives for percentage 676 if(instance.classValue() == 1.0) { 677 positive +=1; 678 }else { 679 negative +=1; 680 } 573 681 574 682 // classification < 0.5 we say defective … … 584 692 } 585 693 586 double et1_per = error_type1 / number_instances; 587 double et2_per = error_type2 / number_instances; 588 589 // return some kind of fehlerquote? 590 //return (error_type1 + error_type2) / number_instances; 591 return new double[]{error_type1, error_type2}; 694 // return error types percentages for the types 695 double et1_per = error_type1 / negative; 696 double et2_per = error_type2 / positive; 697 return new double[]{et1_per, et2_per}; 592 698 } 593 699 … … 614 720 } 615 721 } 616 617 618 /**619 * Custom GT implementation from the paper620 */621 public class GT extends MathCommand implements ICloneable {622 623 private static final long serialVersionUID = 113454184817L;624 625 public GT(final GPConfiguration a_conf, java.lang.Class a_returnType) throws InvalidConfigurationException {626 super(a_conf, 2, a_returnType);627 }628 629 public String toString() {630 return "GT(&1, &2)";631 }632 633 public String getName() {634 return "GT";635 }636 637 public float execute_float(ProgramChromosome c, int n, Object[] args) {638 float f1 = c.execute_float(n, 0, args);639 float f2 = c.execute_float(n, 1, args);640 641 float ret = 1.0f;642 if(f1 > f2) {643 ret = 0.0f;644 }645 646 return ret;647 }648 649 public double execute_double(ProgramChromosome c, int n, Object[] args) {650 double f1 = c.execute_double(n, 0, args);651 double f2 = c.execute_double(n, 1, args);652 653 double ret = 1;654 if(f1 > f2) {655 ret = 0;656 }657 return ret;658 }659 660 public Object clone() {661 try {662 GT result = new GT(getGPConfiguration(), getReturnType());663 return result;664 }catch(Exception ex) {665 throw new CloneException(ex);666 }667 }668 }669 722 }
Note: See TracChangeset
for help on using the changeset viewer.