source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalFQTraining.java @ 112

Last change on this file since 112 was 99, checked in by sherbold, 9 years ago
  • improved error reporting
File size: 30.0 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.util.ArrayList;
18import java.util.HashMap;
19import java.util.HashSet;
20import java.util.Iterator;
21import java.util.Random;
22import java.util.Set;
23import java.util.logging.Level;
24
25import de.ugoe.cs.cpdp.training.QuadTree;
26import de.ugoe.cs.util.console.Console;
27import weka.classifiers.AbstractClassifier;
28import weka.classifiers.Classifier;
29import weka.core.DenseInstance;
30import weka.core.EuclideanDistance;
31import weka.core.Instance;
32import weka.core.Instances;
33import weka.filters.Filter;
34import weka.filters.unsupervised.attribute.Remove;
35
/**
 * Trainer with reimplementation of WHERE clustering algorithm from: Tim Menzies, Andrew Butcher,
 * David Cok, Andrian Marcus, Lucas Layman, Forrest Shull, Burak Turhan, Thomas Zimmermann,
 * "Local versus Global Lessons for Defect Prediction and Effort Estimation," IEEE Transactions on
 * Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013
 *
 * With WekaLocalFQTraining we do the following:
 * <ol>
 * <li>Run the Fastmap algorithm on all training data and let it calculate the 2 most significant
 * dimensions and the projection of each instance onto these dimensions.</li>
 * <li>With these 2 dimensions we span a QuadTree which gets recursively split on median(x) and
 * median(y) values.</li>
 * <li>We cluster the QuadTree nodes together if they have similar density (50%).</li>
 * <li>We save the clusters and their training data.</li>
 * <li>We only use clusters with more than 4 instances; smaller clusters are discarded together with
 * their training data. (A threshold of ALPHA = Math.sqrt(SIZE) instances is present in the code but
 * currently commented out.)</li>
 * <li>We train a Weka classifier for each cluster with the cluster's training data.</li>
 * <li>We recalculate the Fastmap projection of a single instance with the old pivots and then try to
 * find a cluster containing the coordinates of the instance. If we cannot find such a cluster
 * (because the coordinates are outside of all clusters), we use the cluster that contains the
 * nearest training instance.</li>
 * <li>We classify the instance with the classifier of the cluster found in the previous step.</li>
 * </ol>
 */
54public class WekaLocalFQTraining extends WekaBaseTraining implements ITrainingStrategy {
55
56    private final TraindatasetCluster classifier = new TraindatasetCluster();
57
58    @Override
59    public Classifier getClassifier() {
60        return classifier;
61    }
62
63    @Override
64    public void apply(Instances traindata) {
65        try {
66            classifier.buildClassifier(traindata);
67        }
68        catch (Exception e) {
69            throw new RuntimeException(e);
70        }
71    }
72
73    public class TraindatasetCluster extends AbstractClassifier {
74
75        private static final long serialVersionUID = 1L;
76
77        /* classifier per cluster */
78        private HashMap<Integer, Classifier> cclassifier;
79
80        /* instances per cluster */
81        private HashMap<Integer, Instances> ctraindata;
82
83        /*
84         * holds the instances and indices of the pivot objects of the Fastmap calculation in
85         * buildClassifier
86         */
87        private HashMap<Integer, Instance> cpivots;
88
89        /* holds the indices of the pivot objects for x,y and the dimension [x,y][dimension] */
90        private int[][] cpivotindices;
91
92        /* holds the sizes of the cluster multiple "boxes" per cluster */
93        private HashMap<Integer, ArrayList<Double[][]>> csize;
94
95        /* debug vars */
96        @SuppressWarnings("unused")
97        private boolean show_biggest = true;
98
99        @SuppressWarnings("unused")
100        private int CFOUND = 0;
101        @SuppressWarnings("unused")
102        private int CNOTFOUND = 0;
103
104        private Instance createInstance(Instances instances, Instance instance) {
105            // attributes for feeding instance to classifier
106            Set<String> attributeNames = new HashSet<>();
107            for (int j = 0; j < instances.numAttributes(); j++) {
108                attributeNames.add(instances.attribute(j).name());
109            }
110
111            double[] values = new double[instances.numAttributes()];
112            int index = 0;
113            for (int j = 0; j < instance.numAttributes(); j++) {
114                if (attributeNames.contains(instance.attribute(j).name())) {
115                    values[index] = instance.value(j);
116                    index++;
117                }
118            }
119
120            Instances tmp = new Instances(instances);
121            tmp.clear();
122            Instance instCopy = new DenseInstance(instance.weight(), values);
123            instCopy.setDataset(tmp);
124
125            return instCopy;
126        }
127
128        /**
129         * Because Fastmap saves only the image not the values of the attributes it used we can not
130         * use the old data directly to classify single instances to clusters.
131         *
132         * To classify a single instance we do a new fastmap computation with only the instance and
133         * the old pivot elements.
134         *
135         * After that we find the cluster with our fastmap result for x and y.
136         */
137        @Override
138        public double classifyInstance(Instance instance) {
139
140            double ret = 0;
141            try {
142                // classinstance gets passed to classifier
143                Instances traindata = ctraindata.get(0);
144                Instance classInstance = createInstance(traindata, instance);
145
146                // this one keeps the class attribute
147                Instances traindata2 = ctraindata.get(1);
148
149                // remove class attribute before clustering
150                Remove filter = new Remove();
151                filter.setAttributeIndices("" + (traindata.classIndex() + 1));
152                filter.setInputFormat(traindata);
153                traindata = Filter.useFilter(traindata, filter);
154                Instance clusterInstance = createInstance(traindata, instance);
155
156                Fastmap FMAP = new Fastmap(2);
157                EuclideanDistance dist = new EuclideanDistance(traindata);
158
159                // we set our pivot indices [x=0,y=1][dimension]
160                int[][] npivotindices = new int[2][2];
161                npivotindices[0][0] = 1;
162                npivotindices[1][0] = 2;
163                npivotindices[0][1] = 3;
164                npivotindices[1][1] = 4;
165
166                // build temp dist matrix (2 pivots per dimension + 1 instance we want to classify)
167                // the instance we want to classify comes first after that the pivot elements in the
168                // order defined above
169                double[][] distmat = new double[2 * FMAP.target_dims + 1][2 * FMAP.target_dims + 1];
170                distmat[0][0] = 0;
171                distmat[0][1] =
172                    dist.distance(clusterInstance,
173                                  this.cpivots.get((Integer) this.cpivotindices[0][0]));
174                distmat[0][2] =
175                    dist.distance(clusterInstance,
176                                  this.cpivots.get((Integer) this.cpivotindices[1][0]));
177                distmat[0][3] =
178                    dist.distance(clusterInstance,
179                                  this.cpivots.get((Integer) this.cpivotindices[0][1]));
180                distmat[0][4] =
181                    dist.distance(clusterInstance,
182                                  this.cpivots.get((Integer) this.cpivotindices[1][1]));
183
184                distmat[1][0] =
185                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
186                                  clusterInstance);
187                distmat[1][1] = 0;
188                distmat[1][2] =
189                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
190                                  this.cpivots.get((Integer) this.cpivotindices[1][0]));
191                distmat[1][3] =
192                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
193                                  this.cpivots.get((Integer) this.cpivotindices[0][1]));
194                distmat[1][4] =
195                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
196                                  this.cpivots.get((Integer) this.cpivotindices[1][1]));
197
198                distmat[2][0] =
199                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
200                                  clusterInstance);
201                distmat[2][1] =
202                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
203                                  this.cpivots.get((Integer) this.cpivotindices[0][0]));
204                distmat[2][2] = 0;
205                distmat[2][3] =
206                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
207                                  this.cpivots.get((Integer) this.cpivotindices[0][1]));
208                distmat[2][4] =
209                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
210                                  this.cpivots.get((Integer) this.cpivotindices[1][1]));
211
212                distmat[3][0] =
213                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
214                                  clusterInstance);
215                distmat[3][1] =
216                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
217                                  this.cpivots.get((Integer) this.cpivotindices[0][0]));
218                distmat[3][2] =
219                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
220                                  this.cpivots.get((Integer) this.cpivotindices[1][0]));
221                distmat[3][3] = 0;
222                distmat[3][4] =
223                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
224                                  this.cpivots.get((Integer) this.cpivotindices[1][1]));
225
226                distmat[4][0] =
227                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
228                                  clusterInstance);
229                distmat[4][1] =
230                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
231                                  this.cpivots.get((Integer) this.cpivotindices[0][0]));
232                distmat[4][2] =
233                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
234                                  this.cpivots.get((Integer) this.cpivotindices[1][0]));
235                distmat[4][3] =
236                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
237                                  this.cpivots.get((Integer) this.cpivotindices[0][1]));
238                distmat[4][4] = 0;
239
240                /*
241                 * debug output: show biggest distance found within the new distance matrix double
242                 * biggest = 0; for(int i=0; i < distmat.length; i++) { for(int j=0; j <
243                 * distmat[0].length; j++) { if(biggest < distmat[i][j]) { biggest = distmat[i][j];
244                 * } } } if(this.show_biggest) { Console.traceln(Level.INFO,
245                 * String.format(""+clusterInstance)); Console.traceln(Level.INFO,
246                 * String.format("biggest distances: "+ biggest)); this.show_biggest = false; }
247                 */
248
249                FMAP.setDistmat(distmat);
250                FMAP.setPivots(npivotindices);
251                FMAP.calculate();
252                double[][] x = FMAP.getX();
253                double[] proj = x[0];
254
255                // debug output: show the calculated distance matrix, our result vektor for the
256                // instance and the complete result matrix
257                /*
258                 * Console.traceln(Level.INFO, "distmat:"); for(int i=0; i<distmat.length; i++){
259                 * for(int j=0; j<distmat[0].length; j++){ Console.trace(Level.INFO,
260                 * String.format("%20s", distmat[i][j])); } Console.traceln(Level.INFO, ""); }
261                 *
262                 * Console.traceln(Level.INFO, "vector:"); for(int i=0; i < proj.length; i++) {
263                 * Console.trace(Level.INFO, String.format("%20s", proj[i])); }
264                 * Console.traceln(Level.INFO, "");
265                 *
266                 * Console.traceln(Level.INFO, "resultmat:"); for(int i=0; i<x.length; i++){ for(int
267                 * j=0; j<x[0].length; j++){ Console.trace(Level.INFO, String.format("%20s",
268                 * x[i][j])); } Console.traceln(Level.INFO, ""); }
269                 */
270
271                // now we iterate over all clusters (well, boxes of sizes per cluster really) and
272                // save the number of the
273                // cluster in which we are
274                int cnumber;
275                int found_cnumber = -1;
276                Iterator<Integer> clusternumber = this.csize.keySet().iterator();
277                while (clusternumber.hasNext() && found_cnumber == -1) {
278                    cnumber = clusternumber.next();
279
280                    // now iterate over the boxes of the cluster and hope we find one (cluster could
281                    // have been removed)
282                    // or we are too far away from any cluster because of the fastmap calculation
283                    // with the initial pivot objects
284                    for (int box = 0; box < this.csize.get(cnumber).size(); box++) {
285                        Double[][] current = this.csize.get(cnumber).get(box);
286
287                        if (proj[0] >= current[0][0] && proj[0] <= current[0][1] && // x
288                            proj[1] >= current[1][0] && proj[1] <= current[1][1])
289                        { // y
290                            found_cnumber = cnumber;
291                        }
292                    }
293                }
294
295                // we want to count how often we are really inside a cluster
296                // if ( found_cnumber == -1 ) {
297                // CNOTFOUND += 1;
298                // }else {
299                // CFOUND += 1;
300                // }
301
302                // now it can happen that we do not find a cluster because we deleted it previously
303                // (too few instances)
304                // or we get bigger distance measures from weka so that we are completely outside of
305                // our clusters.
306                // in these cases we just find the nearest cluster to our instance and use it for
307                // classification.
308                // to do that we use the EuclideanDistance again to compare our distance to all
309                // other Instances
310                // then we take the cluster of the closest weka instance
311                dist = new EuclideanDistance(traindata2);
312                if (!this.ctraindata.containsKey(found_cnumber)) {
313                    double min_distance = Double.MAX_VALUE;
314                    clusternumber = ctraindata.keySet().iterator();
315                    while (clusternumber.hasNext()) {
316                        cnumber = clusternumber.next();
317                        for (int i = 0; i < ctraindata.get(cnumber).size(); i++) {
318                            if (dist.distance(instance, ctraindata.get(cnumber).get(i)) <= min_distance)
319                            {
320                                found_cnumber = cnumber;
321                                min_distance =
322                                    dist.distance(instance, ctraindata.get(cnumber).get(i));
323                            }
324                        }
325                    }
326                }
327
328                // here we have the cluster where an instance has the minimum distance between
329                // itself and the
330                // instance we want to classify
331                // if we still have not found a cluster we exit because something is really wrong
332                if (found_cnumber == -1) {
333                    Console.traceln(Level.INFO, String
334                        .format("ERROR matching instance to cluster with full search!"));
335                    throw new RuntimeException("cluster not found with full search");
336                }
337
338                // classify the passed instance with the cluster we found and its training data
339                ret = cclassifier.get(found_cnumber).classifyInstance(classInstance);
340
341            }
342            catch (Exception e) {
343                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!"));
344                throw new RuntimeException(e);
345            }
346            return ret;
347        }
348
349        @Override
350        public void buildClassifier(Instances traindata) throws Exception {
351
352            // Console.traceln(Level.INFO, String.format("found: "+ CFOUND + ", notfound: " +
353            // CNOTFOUND));
354            this.show_biggest = true;
355
356            cclassifier = new HashMap<Integer, Classifier>();
357            ctraindata = new HashMap<Integer, Instances>();
358            cpivots = new HashMap<Integer, Instance>();
359            cpivotindices = new int[2][2];
360
361            // 1. copy traindata
362            Instances train = new Instances(traindata);
363            Instances train2 = new Instances(traindata); // this one keeps the class attribute
364
365            // 2. remove class attribute for clustering
366            Remove filter = new Remove();
367            filter.setAttributeIndices("" + (train.classIndex() + 1));
368            filter.setInputFormat(train);
369            train = Filter.useFilter(train, filter);
370
371            // 3. calculate distance matrix (needed for Fastmap because it starts at dimension 1)
372            double biggest = 0;
373            EuclideanDistance dist = new EuclideanDistance(train);
374            double[][] distmat = new double[train.size()][train.size()];
375            for (int i = 0; i < train.size(); i++) {
376                for (int j = 0; j < train.size(); j++) {
377                    distmat[i][j] = dist.distance(train.get(i), train.get(j));
378                    if (distmat[i][j] > biggest) {
379                        biggest = distmat[i][j];
380                    }
381                }
382            }
383            // Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest));
384
385            // 4. run fastmap for 2 dimensions on the distance matrix
386            Fastmap FMAP = new Fastmap(2);
387            FMAP.setDistmat(distmat);
388            FMAP.calculate();
389
390            cpivotindices = FMAP.getPivots();
391
392            double[][] X = FMAP.getX();
393
394            // quadtree payload generation
395            ArrayList<QuadTreePayload<Instance>> qtp = new ArrayList<QuadTreePayload<Instance>>();
396
397            // we need these for the sizes of the quadrants
398            double[] big =
399                { 0, 0 };
400            double[] small =
401                { Double.MAX_VALUE, Double.MAX_VALUE };
402
403            // set quadtree payload values and get max and min x and y values for size
404            for (int i = 0; i < X.length; i++) {
405                if (X[i][0] >= big[0]) {
406                    big[0] = X[i][0];
407                }
408                if (X[i][1] >= big[1]) {
409                    big[1] = X[i][1];
410                }
411                if (X[i][0] <= small[0]) {
412                    small[0] = X[i][0];
413                }
414                if (X[i][1] <= small[1]) {
415                    small[1] = X[i][1];
416                }
417                QuadTreePayload<Instance> tmp =
418                    new QuadTreePayload<Instance>(X[i][0], X[i][1], train2.get(i));
419                qtp.add(tmp);
420            }
421
422            // Console.traceln(Level.INFO,
423            // String.format("size for cluster ("+small[0]+","+small[1]+") - ("+big[0]+","+big[1]+")"));
424
425            // 5. generate quadtree
426            QuadTree TREE = new QuadTree(null, qtp);
427            QuadTree.size = train.size();
428            QuadTree.alpha = Math.sqrt(train.size());
429            QuadTree.ccluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>();
430            QuadTree.csize = new HashMap<Integer, ArrayList<Double[][]>>();
431
432            // Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ QuadTree.size +
433            // " size, Alpha: "+ QuadTree.alpha+ ""));
434
435            // set the size and then split the tree recursively at the median value for x, y
436            TREE.setSize(new double[]
437                { small[0], big[0] }, new double[]
438                { small[1], big[1] });
439
440            // recursive split und grid clustering eher static
441            TREE.recursiveSplit(TREE);
442
443            // generate list of nodes sorted by density (childs only)
444            ArrayList<QuadTree> l = new ArrayList<QuadTree>(TREE.getList(TREE));
445
446            // recursive grid clustering (tree pruning), the values are stored in ccluster
447            TREE.gridClustering(l);
448
449            // wir iterieren durch die cluster und sammeln uns die instanzen daraus
450            // ctraindata.clear();
451            for (int i = 0; i < QuadTree.ccluster.size(); i++) {
452                ArrayList<QuadTreePayload<Instance>> current = QuadTree.ccluster.get(i);
453
454                // i is the clusternumber
455                // we only allow clusters with Instances > ALPHA, other clusters are not considered!
456                // if(current.size() > QuadTree.alpha) {
457                if (current.size() > 4) {
458                    for (int j = 0; j < current.size(); j++) {
459                        if (!ctraindata.containsKey(i)) {
460                            ctraindata.put(i, new Instances(train2));
461                            ctraindata.get(i).delete();
462                        }
463                        ctraindata.get(i).add(current.get(j).getInst());
464                    }
465                }
466                else {
467                    Console.traceln(Level.INFO,
468                                    String.format("drop cluster, only: " + current.size() +
469                                        " instances"));
470                }
471            }
472
473            // here we keep things we need later on
474            // QuadTree sizes for later use (matching new instances)
475            this.csize = new HashMap<Integer, ArrayList<Double[][]>>(QuadTree.csize);
476
477            // pivot elements
478            // this.cpivots.clear();
479            for (int i = 0; i < FMAP.PA[0].length; i++) {
480                this.cpivots.put(FMAP.PA[0][i], (Instance) train.get(FMAP.PA[0][i]).copy());
481            }
482            for (int j = 0; j < FMAP.PA[0].length; j++) {
483                this.cpivots.put(FMAP.PA[1][j], (Instance) train.get(FMAP.PA[1][j]).copy());
484            }
485
486            /*
487             * debug output int pnumber; Iterator<Integer> pivotnumber =
488             * cpivots.keySet().iterator(); while ( pivotnumber.hasNext() ) { pnumber =
489             * pivotnumber.next(); Console.traceln(Level.INFO, String.format("pivot: "+pnumber+
490             * " inst: "+cpivots.get(pnumber))); }
491             */
492
493            // train one classifier per cluster, we get the cluster number from the traindata
494            int cnumber;
495            Iterator<Integer> clusternumber = ctraindata.keySet().iterator();
496            // cclassifier.clear();
497
498            // int traindata_count = 0;
499            while (clusternumber.hasNext()) {
500                cnumber = clusternumber.next();
501                cclassifier.put(cnumber, setupClassifier()); // this is the classifier used for the
502                                                             // cluster
503                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));
504                // Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber));
505                // traindata_count += ctraindata.get(cnumber).size();
506                // Console.traceln(Level.INFO,
507                // String.format("building classifier in cluster "+cnumber +"  with "+
508                // ctraindata.get(cnumber).size() +" traindata instances"));
509            }
510
511            // add all traindata
512            // Console.traceln(Level.INFO, String.format("traindata in all clusters: " +
513            // traindata_count));
514        }
515    }
516
517    /**
518     * Payload for the QuadTree. x and y are the calculated Fastmap values. T is a weka instance.
519     */
520    public class QuadTreePayload<T> {
521
522        public double x;
523        public double y;
524        private T inst;
525
526        public QuadTreePayload(double x, double y, T value) {
527            this.x = x;
528            this.y = y;
529            this.inst = value;
530        }
531
532        public T getInst() {
533            return this.inst;
534        }
535    }
536
537    /**
538     * Fastmap implementation
539     *
540     * Faloutsos, C., & Lin, K. I. (1995). FastMap: A fast algorithm for indexing, data-mining and
541     * visualization of traditional and multimedia datasets (Vol. 24, No. 2, pp. 163-174). ACM.
542     */
543    public class Fastmap {
544
545        /* N x k Array, at the end, the i-th row will be the image of the i-th object */
546        private double[][] X;
547
548        /* 2 x k pivot Array one pair per recursive call */
549        private int[][] PA;
550
551        /* Objects we got (distance matrix) */
552        private double[][] O;
553
554        /* column of X currently updated (also the dimension) */
555        private int col = 0;
556
557        /* number of dimensions we want */
558        private int target_dims = 0;
559
560        // if we already have the pivot elements
561        private boolean pivot_set = false;
562
563        public Fastmap(int k) {
564            this.target_dims = k;
565        }
566
567        /**
568         * Sets the distance matrix and params that depend on this
569         *
570         * @param O
571         */
572        public void setDistmat(double[][] O) {
573            this.O = O;
574            int N = O.length;
575            this.X = new double[N][this.target_dims];
576            this.PA = new int[2][this.target_dims];
577        }
578
579        /**
580         * Set pivot elements, we need that to classify instances after the calculation is complete
581         * (because we then want to reuse only the pivot elements).
582         *
583         * @param pi
584         */
585        public void setPivots(int[][] pi) {
586            this.pivot_set = true;
587            this.PA = pi;
588        }
589
590        /**
591         * Return the pivot elements that were chosen during the calculation
592         *
593         * @return
594         */
595        public int[][] getPivots() {
596            return this.PA;
597        }
598
599        /**
600         * The distance function for euclidean distance
601         *
602         * Acts according to equation 4 of the fastmap paper
603         *
604         * @param x
605         *            x index of x image (if k==0 x object)
606         * @param y
607         *            y index of y image (if k==0 y object)
608         * @param kdimensionality
609         * @return distance
610         */
611        private double dist(int x, int y, int k) {
612
613            // basis is object distance, we get this from our distance matrix
614            double tmp = this.O[x][y] * this.O[x][y];
615
616            // decrease by projections
617            for (int i = 0; i < k; i++) {
618                double tmp2 = (this.X[x][i] - this.X[y][i]);
619                tmp -= tmp2 * tmp2;
620            }
621
622            return Math.abs(tmp);
623        }
624
625        /**
626         * Find the object farthest from the given index This method is a helper Method for
627         * findDistandObjects
628         *
629         * @param index
630         *            of the object
631         * @return index of the farthest object from the given index
632         */
633        private int findFarthest(int index) {
634            double furthest = Double.MIN_VALUE;
635            int ret = 0;
636
637            for (int i = 0; i < O.length; i++) {
638                double dist = this.dist(i, index, this.col);
639                if (i != index && dist > furthest) {
640                    furthest = dist;
641                    ret = i;
642                }
643            }
644            return ret;
645        }
646
647        /**
648         * Finds the pivot objects
649         *
650         * This method is basically algorithm 1 of the fastmap paper.
651         *
652         * @return 2 indexes of the choosen pivot objects
653         */
654        private int[] findDistantObjects() {
655            // 1. choose object randomly
656            Random r = new Random();
657            int obj = r.nextInt(this.O.length);
658
659            // 2. find farthest object from randomly chosen object
660            int idx1 = this.findFarthest(obj);
661
662            // 3. find farthest object from previously farthest object
663            int idx2 = this.findFarthest(idx1);
664
665            return new int[]
666                { idx1, idx2 };
667        }
668
669        /**
670         * Calculates the new k-vector values (projections)
671         *
672         * This is basically algorithm 2 of the fastmap paper. We just added the possibility to
673         * pre-set the pivot elements because we need to classify single instances after the
674         * computation is already done.
675         *
676         * @param dims
677         *            dimensionality
678         */
679        public void calculate() {
680
681            for (int k = 0; k < this.target_dims; k++) {
682                // 2) choose pivot objects
683                if (!this.pivot_set) {
684                    int[] pivots = this.findDistantObjects();
685
686                    // 3) record ids of pivot objects
687                    this.PA[0][this.col] = pivots[0];
688                    this.PA[1][this.col] = pivots[1];
689                }
690
691                // 4) inter object distances are zero (this.X is initialized with 0 so we just
692                // continue)
693                if (this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col) == 0) {
694                    continue;
695                }
696
697                // 5) project the objects on the line between the pivots
698                double dxy = this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col);
699                for (int i = 0; i < this.O.length; i++) {
700
701                    double dix = this.dist(i, this.PA[0][this.col], this.col);
702                    double diy = this.dist(i, this.PA[1][this.col], this.col);
703
704                    double tmp = (dix + dxy - diy) / (2 * Math.sqrt(dxy));
705
706                    // save the projection
707                    this.X[i][this.col] = tmp;
708                }
709
710                this.col += 1;
711            }
712        }
713
714        /**
715         * returns the result matrix of the projections
716         *
717         * @return calculated result
718         */
719        public double[][] getX() {
720            return this.X;
721        }
722    }
723}
Note: See TracBrowser for help on using the repository browser.