source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalFQTraining.java @ 52

Last change on this file since 52 was 41, checked in by sherbold, 9 years ago
  • formatted code and added copyrights
File size: 30.4 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.io.PrintStream;
18import java.util.ArrayList;
19import java.util.HashMap;
20import java.util.HashSet;
21import java.util.Iterator;
22import java.util.Random;
23import java.util.Set;
24import java.util.logging.Level;
25
26import org.apache.commons.io.output.NullOutputStream;
27
28import de.ugoe.cs.cpdp.training.QuadTree;
29import de.ugoe.cs.util.console.Console;
30import weka.classifiers.AbstractClassifier;
31import weka.classifiers.Classifier;
32import weka.core.DenseInstance;
33import weka.core.EuclideanDistance;
34import weka.core.Instance;
35import weka.core.Instances;
36import weka.filters.Filter;
37import weka.filters.unsupervised.attribute.Remove;
38
39/**
40 * Trainer with reimplementation of WHERE clustering algorithm from: Tim Menzies, Andrew Butcher,
41 * David Cok, Andrian Marcus, Lucas Layman, Forrest Shull, Burak Turhan, Thomas Zimmermann,
42 * "Local versus Global Lessons for Defect Prediction and Effort Estimation," IEEE Transactions on
43 * Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013
44 *
45 * With WekaLocalFQTraining we do the following: 1) Run the Fastmap algorithm on all training data,
46 * let it calculate the 2 most significant dimensions and projections of each instance to these
47 * dimensions 2) With these 2 dimensions we span a QuadTree which gets recursively split on
48 * median(x) and median(y) values. 3) We cluster the QuadTree nodes together if they have similar
49 * density (50%) 4) We save the clusters and their training data 5) We only use clusters with >
50 * ALPHA instances (currently Math.sqrt(SIZE)), rest is discarded with the training data of this
51 * cluster 6) We train a Weka classifier for each cluster with the clusters training data 7) We
52 * recalculate Fastmap distances for a single instance with the old pivots and then try to find a
53 * cluster containing the coords of the instance. 7.1.) If we can not find a cluster (due to coords
54 * outside of all clusters) we find the nearest cluster. 8) We classify the Instance with the
55 * classifier and traindata from the Cluster we found in 7.
56 */
57public class WekaLocalFQTraining extends WekaBaseTraining implements ITrainingStrategy {
58
59    private final TraindatasetCluster classifier = new TraindatasetCluster();
60
61    @Override
62    public Classifier getClassifier() {
63        return classifier;
64    }
65
66    @Override
67    public void apply(Instances traindata) {
68        PrintStream errStr = System.err;
69        System.setErr(new PrintStream(new NullOutputStream()));
70        try {
71            classifier.buildClassifier(traindata);
72        }
73        catch (Exception e) {
74            throw new RuntimeException(e);
75        }
76        finally {
77            System.setErr(errStr);
78        }
79    }
80
81    public class TraindatasetCluster extends AbstractClassifier {
82
83        private static final long serialVersionUID = 1L;
84
85        /* classifier per cluster */
86        private HashMap<Integer, Classifier> cclassifier;
87
88        /* instances per cluster */
89        private HashMap<Integer, Instances> ctraindata;
90
91        /*
92         * holds the instances and indices of the pivot objects of the Fastmap calculation in
93         * buildClassifier
94         */
95        private HashMap<Integer, Instance> cpivots;
96
97        /* holds the indices of the pivot objects for x,y and the dimension [x,y][dimension] */
98        private int[][] cpivotindices;
99
100        /* holds the sizes of the cluster multiple "boxes" per cluster */
101        private HashMap<Integer, ArrayList<Double[][]>> csize;
102
103        /* debug vars */
104        @SuppressWarnings("unused")
105        private boolean show_biggest = true;
106
107        @SuppressWarnings("unused")
108        private int CFOUND = 0;
109        @SuppressWarnings("unused")
110        private int CNOTFOUND = 0;
111
112        private Instance createInstance(Instances instances, Instance instance) {
113            // attributes for feeding instance to classifier
114            Set<String> attributeNames = new HashSet<>();
115            for (int j = 0; j < instances.numAttributes(); j++) {
116                attributeNames.add(instances.attribute(j).name());
117            }
118
119            double[] values = new double[instances.numAttributes()];
120            int index = 0;
121            for (int j = 0; j < instance.numAttributes(); j++) {
122                if (attributeNames.contains(instance.attribute(j).name())) {
123                    values[index] = instance.value(j);
124                    index++;
125                }
126            }
127
128            Instances tmp = new Instances(instances);
129            tmp.clear();
130            Instance instCopy = new DenseInstance(instance.weight(), values);
131            instCopy.setDataset(tmp);
132
133            return instCopy;
134        }
135
136        /**
137         * Because Fastmap saves only the image not the values of the attributes it used we can not
138         * use the old data directly to classify single instances to clusters.
139         *
140         * To classify a single instance we do a new fastmap computation with only the instance and
141         * the old pivot elements.
142         *
143         * After that we find the cluster with our fastmap result for x and y.
144         */
145        @Override
146        public double classifyInstance(Instance instance) {
147
148            double ret = 0;
149            try {
150                // classinstance gets passed to classifier
151                Instances traindata = ctraindata.get(0);
152                Instance classInstance = createInstance(traindata, instance);
153
154                // this one keeps the class attribute
155                Instances traindata2 = ctraindata.get(1);
156
157                // remove class attribute before clustering
158                Remove filter = new Remove();
159                filter.setAttributeIndices("" + (traindata.classIndex() + 1));
160                filter.setInputFormat(traindata);
161                traindata = Filter.useFilter(traindata, filter);
162                Instance clusterInstance = createInstance(traindata, instance);
163
164                Fastmap FMAP = new Fastmap(2);
165                EuclideanDistance dist = new EuclideanDistance(traindata);
166
167                // we set our pivot indices [x=0,y=1][dimension]
168                int[][] npivotindices = new int[2][2];
169                npivotindices[0][0] = 1;
170                npivotindices[1][0] = 2;
171                npivotindices[0][1] = 3;
172                npivotindices[1][1] = 4;
173
174                // build temp dist matrix (2 pivots per dimension + 1 instance we want to classify)
175                // the instance we want to classify comes first after that the pivot elements in the
176                // order defined above
177                double[][] distmat = new double[2 * FMAP.target_dims + 1][2 * FMAP.target_dims + 1];
178                distmat[0][0] = 0;
179                distmat[0][1] =
180                    dist.distance(clusterInstance,
181                                  this.cpivots.get((Integer) this.cpivotindices[0][0]));
182                distmat[0][2] =
183                    dist.distance(clusterInstance,
184                                  this.cpivots.get((Integer) this.cpivotindices[1][0]));
185                distmat[0][3] =
186                    dist.distance(clusterInstance,
187                                  this.cpivots.get((Integer) this.cpivotindices[0][1]));
188                distmat[0][4] =
189                    dist.distance(clusterInstance,
190                                  this.cpivots.get((Integer) this.cpivotindices[1][1]));
191
192                distmat[1][0] =
193                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
194                                  clusterInstance);
195                distmat[1][1] = 0;
196                distmat[1][2] =
197                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
198                                  this.cpivots.get((Integer) this.cpivotindices[1][0]));
199                distmat[1][3] =
200                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
201                                  this.cpivots.get((Integer) this.cpivotindices[0][1]));
202                distmat[1][4] =
203                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
204                                  this.cpivots.get((Integer) this.cpivotindices[1][1]));
205
206                distmat[2][0] =
207                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
208                                  clusterInstance);
209                distmat[2][1] =
210                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
211                                  this.cpivots.get((Integer) this.cpivotindices[0][0]));
212                distmat[2][2] = 0;
213                distmat[2][3] =
214                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
215                                  this.cpivots.get((Integer) this.cpivotindices[0][1]));
216                distmat[2][4] =
217                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
218                                  this.cpivots.get((Integer) this.cpivotindices[1][1]));
219
220                distmat[3][0] =
221                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
222                                  clusterInstance);
223                distmat[3][1] =
224                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
225                                  this.cpivots.get((Integer) this.cpivotindices[0][0]));
226                distmat[3][2] =
227                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
228                                  this.cpivots.get((Integer) this.cpivotindices[1][0]));
229                distmat[3][3] = 0;
230                distmat[3][4] =
231                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
232                                  this.cpivots.get((Integer) this.cpivotindices[1][1]));
233
234                distmat[4][0] =
235                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
236                                  clusterInstance);
237                distmat[4][1] =
238                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
239                                  this.cpivots.get((Integer) this.cpivotindices[0][0]));
240                distmat[4][2] =
241                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
242                                  this.cpivots.get((Integer) this.cpivotindices[1][0]));
243                distmat[4][3] =
244                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
245                                  this.cpivots.get((Integer) this.cpivotindices[0][1]));
246                distmat[4][4] = 0;
247
248                /*
249                 * debug output: show biggest distance found within the new distance matrix double
250                 * biggest = 0; for(int i=0; i < distmat.length; i++) { for(int j=0; j <
251                 * distmat[0].length; j++) { if(biggest < distmat[i][j]) { biggest = distmat[i][j];
252                 * } } } if(this.show_biggest) { Console.traceln(Level.INFO,
253                 * String.format(""+clusterInstance)); Console.traceln(Level.INFO,
254                 * String.format("biggest distances: "+ biggest)); this.show_biggest = false; }
255                 */
256
257                FMAP.setDistmat(distmat);
258                FMAP.setPivots(npivotindices);
259                FMAP.calculate();
260                double[][] x = FMAP.getX();
261                double[] proj = x[0];
262
263                // debug output: show the calculated distance matrix, our result vektor for the
264                // instance and the complete result matrix
265                /*
266                 * Console.traceln(Level.INFO, "distmat:"); for(int i=0; i<distmat.length; i++){
267                 * for(int j=0; j<distmat[0].length; j++){ Console.trace(Level.INFO,
268                 * String.format("%20s", distmat[i][j])); } Console.traceln(Level.INFO, ""); }
269                 *
270                 * Console.traceln(Level.INFO, "vector:"); for(int i=0; i < proj.length; i++) {
271                 * Console.trace(Level.INFO, String.format("%20s", proj[i])); }
272                 * Console.traceln(Level.INFO, "");
273                 *
274                 * Console.traceln(Level.INFO, "resultmat:"); for(int i=0; i<x.length; i++){ for(int
275                 * j=0; j<x[0].length; j++){ Console.trace(Level.INFO, String.format("%20s",
276                 * x[i][j])); } Console.traceln(Level.INFO, ""); }
277                 */
278
279                // now we iterate over all clusters (well, boxes of sizes per cluster really) and
280                // save the number of the
281                // cluster in which we are
282                int cnumber;
283                int found_cnumber = -1;
284                Iterator<Integer> clusternumber = this.csize.keySet().iterator();
285                while (clusternumber.hasNext() && found_cnumber == -1) {
286                    cnumber = clusternumber.next();
287
288                    // now iterate over the boxes of the cluster and hope we find one (cluster could
289                    // have been removed)
290                    // or we are too far away from any cluster because of the fastmap calculation
291                    // with the initial pivot objects
292                    for (int box = 0; box < this.csize.get(cnumber).size(); box++) {
293                        Double[][] current = this.csize.get(cnumber).get(box);
294
295                        if (proj[0] >= current[0][0] && proj[0] <= current[0][1] && // x
296                            proj[1] >= current[1][0] && proj[1] <= current[1][1])
297                        { // y
298                            found_cnumber = cnumber;
299                        }
300                    }
301                }
302
303                // we want to count how often we are really inside a cluster
304                // if ( found_cnumber == -1 ) {
305                // CNOTFOUND += 1;
306                // }else {
307                // CFOUND += 1;
308                // }
309
310                // now it can happen that we do not find a cluster because we deleted it previously
311                // (too few instances)
312                // or we get bigger distance measures from weka so that we are completely outside of
313                // our clusters.
314                // in these cases we just find the nearest cluster to our instance and use it for
315                // classification.
316                // to do that we use the EuclideanDistance again to compare our distance to all
317                // other Instances
318                // then we take the cluster of the closest weka instance
319                dist = new EuclideanDistance(traindata2);
320                if (!this.ctraindata.containsKey(found_cnumber)) {
321                    double min_distance = Double.MAX_VALUE;
322                    clusternumber = ctraindata.keySet().iterator();
323                    while (clusternumber.hasNext()) {
324                        cnumber = clusternumber.next();
325                        for (int i = 0; i < ctraindata.get(cnumber).size(); i++) {
326                            if (dist.distance(instance, ctraindata.get(cnumber).get(i)) <= min_distance)
327                            {
328                                found_cnumber = cnumber;
329                                min_distance =
330                                    dist.distance(instance, ctraindata.get(cnumber).get(i));
331                            }
332                        }
333                    }
334                }
335
336                // here we have the cluster where an instance has the minimum distance between
337                // itself and the
338                // instance we want to classify
339                // if we still have not found a cluster we exit because something is really wrong
340                if (found_cnumber == -1) {
341                    Console.traceln(Level.INFO, String
342                        .format("ERROR matching instance to cluster with full search!"));
343                    throw new RuntimeException("cluster not found with full search");
344                }
345
346                // classify the passed instance with the cluster we found and its training data
347                ret = cclassifier.get(found_cnumber).classifyInstance(classInstance);
348
349            }
350            catch (Exception e) {
351                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!"));
352                throw new RuntimeException(e);
353            }
354            return ret;
355        }
356
357        @Override
358        public void buildClassifier(Instances traindata) throws Exception {
359
360            // Console.traceln(Level.INFO, String.format("found: "+ CFOUND + ", notfound: " +
361            // CNOTFOUND));
362            this.show_biggest = true;
363
364            cclassifier = new HashMap<Integer, Classifier>();
365            ctraindata = new HashMap<Integer, Instances>();
366            cpivots = new HashMap<Integer, Instance>();
367            cpivotindices = new int[2][2];
368
369            // 1. copy traindata
370            Instances train = new Instances(traindata);
371            Instances train2 = new Instances(traindata); // this one keeps the class attribute
372
373            // 2. remove class attribute for clustering
374            Remove filter = new Remove();
375            filter.setAttributeIndices("" + (train.classIndex() + 1));
376            filter.setInputFormat(train);
377            train = Filter.useFilter(train, filter);
378
379            // 3. calculate distance matrix (needed for Fastmap because it starts at dimension 1)
380            double biggest = 0;
381            EuclideanDistance dist = new EuclideanDistance(train);
382            double[][] distmat = new double[train.size()][train.size()];
383            for (int i = 0; i < train.size(); i++) {
384                for (int j = 0; j < train.size(); j++) {
385                    distmat[i][j] = dist.distance(train.get(i), train.get(j));
386                    if (distmat[i][j] > biggest) {
387                        biggest = distmat[i][j];
388                    }
389                }
390            }
391            // Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest));
392
393            // 4. run fastmap for 2 dimensions on the distance matrix
394            Fastmap FMAP = new Fastmap(2);
395            FMAP.setDistmat(distmat);
396            FMAP.calculate();
397
398            cpivotindices = FMAP.getPivots();
399
400            double[][] X = FMAP.getX();
401            distmat = new double[0][0];
402            System.gc();
403
404            // quadtree payload generation
405            ArrayList<QuadTreePayload<Instance>> qtp = new ArrayList<QuadTreePayload<Instance>>();
406
407            // we need these for the sizes of the quadrants
408            double[] big =
409                { 0, 0 };
410            double[] small =
411                { Double.MAX_VALUE, Double.MAX_VALUE };
412
413            // set quadtree payload values and get max and min x and y values for size
414            for (int i = 0; i < X.length; i++) {
415                if (X[i][0] >= big[0]) {
416                    big[0] = X[i][0];
417                }
418                if (X[i][1] >= big[1]) {
419                    big[1] = X[i][1];
420                }
421                if (X[i][0] <= small[0]) {
422                    small[0] = X[i][0];
423                }
424                if (X[i][1] <= small[1]) {
425                    small[1] = X[i][1];
426                }
427                QuadTreePayload<Instance> tmp =
428                    new QuadTreePayload<Instance>(X[i][0], X[i][1], train2.get(i));
429                qtp.add(tmp);
430            }
431
432            // Console.traceln(Level.INFO,
433            // String.format("size for cluster ("+small[0]+","+small[1]+") - ("+big[0]+","+big[1]+")"));
434
435            // 5. generate quadtree
436            QuadTree TREE = new QuadTree(null, qtp);
437            QuadTree.size = train.size();
438            QuadTree.alpha = Math.sqrt(train.size());
439            QuadTree.ccluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>();
440            QuadTree.csize = new HashMap<Integer, ArrayList<Double[][]>>();
441
442            // Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ QuadTree.size +
443            // " size, Alpha: "+ QuadTree.alpha+ ""));
444
445            // set the size and then split the tree recursively at the median value for x, y
446            TREE.setSize(new double[]
447                { small[0], big[0] }, new double[]
448                { small[1], big[1] });
449
450            // recursive split und grid clustering eher static
451            TREE.recursiveSplit(TREE);
452
453            // generate list of nodes sorted by density (childs only)
454            ArrayList<QuadTree> l = new ArrayList<QuadTree>(TREE.getList(TREE));
455
456            // recursive grid clustering (tree pruning), the values are stored in ccluster
457            TREE.gridClustering(l);
458
459            // wir iterieren durch die cluster und sammeln uns die instanzen daraus
460            // ctraindata.clear();
461            for (int i = 0; i < QuadTree.ccluster.size(); i++) {
462                ArrayList<QuadTreePayload<Instance>> current = QuadTree.ccluster.get(i);
463
464                // i is the clusternumber
465                // we only allow clusters with Instances > ALPHA, other clusters are not considered!
466                // if(current.size() > QuadTree.alpha) {
467                if (current.size() > 4) {
468                    for (int j = 0; j < current.size(); j++) {
469                        if (!ctraindata.containsKey(i)) {
470                            ctraindata.put(i, new Instances(train2));
471                            ctraindata.get(i).delete();
472                        }
473                        ctraindata.get(i).add(current.get(j).getInst());
474                    }
475                }
476                else {
477                    Console.traceln(Level.INFO,
478                                    String.format("drop cluster, only: " + current.size() +
479                                        " instances"));
480                }
481            }
482
483            // here we keep things we need later on
484            // QuadTree sizes for later use (matching new instances)
485            this.csize = new HashMap<Integer, ArrayList<Double[][]>>(QuadTree.csize);
486
487            // pivot elements
488            // this.cpivots.clear();
489            for (int i = 0; i < FMAP.PA[0].length; i++) {
490                this.cpivots.put(FMAP.PA[0][i], (Instance) train.get(FMAP.PA[0][i]).copy());
491            }
492            for (int j = 0; j < FMAP.PA[0].length; j++) {
493                this.cpivots.put(FMAP.PA[1][j], (Instance) train.get(FMAP.PA[1][j]).copy());
494            }
495
496            /*
497             * debug output int pnumber; Iterator<Integer> pivotnumber =
498             * cpivots.keySet().iterator(); while ( pivotnumber.hasNext() ) { pnumber =
499             * pivotnumber.next(); Console.traceln(Level.INFO, String.format("pivot: "+pnumber+
500             * " inst: "+cpivots.get(pnumber))); }
501             */
502
503            // train one classifier per cluster, we get the cluster number from the traindata
504            int cnumber;
505            Iterator<Integer> clusternumber = ctraindata.keySet().iterator();
506            // cclassifier.clear();
507
508            // int traindata_count = 0;
509            while (clusternumber.hasNext()) {
510                cnumber = clusternumber.next();
511                cclassifier.put(cnumber, setupClassifier()); // this is the classifier used for the
512                                                             // cluster
513                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));
514                // Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber));
515                // traindata_count += ctraindata.get(cnumber).size();
516                // Console.traceln(Level.INFO,
517                // String.format("building classifier in cluster "+cnumber +"  with "+
518                // ctraindata.get(cnumber).size() +" traindata instances"));
519            }
520
521            // add all traindata
522            // Console.traceln(Level.INFO, String.format("traindata in all clusters: " +
523            // traindata_count));
524        }
525    }
526
527    /**
528     * Payload for the QuadTree. x and y are the calculated Fastmap values. T is a weka instance.
529     */
530    public class QuadTreePayload<T> {
531
532        public double x;
533        public double y;
534        private T inst;
535
536        public QuadTreePayload(double x, double y, T value) {
537            this.x = x;
538            this.y = y;
539            this.inst = value;
540        }
541
542        public T getInst() {
543            return this.inst;
544        }
545    }
546
547    /**
548     * Fastmap implementation
549     *
550     * Faloutsos, C., & Lin, K. I. (1995). FastMap: A fast algorithm for indexing, data-mining and
551     * visualization of traditional and multimedia datasets (Vol. 24, No. 2, pp. 163-174). ACM.
552     */
553    public class Fastmap {
554
555        /* N x k Array, at the end, the i-th row will be the image of the i-th object */
556        private double[][] X;
557
558        /* 2 x k pivot Array one pair per recursive call */
559        private int[][] PA;
560
561        /* Objects we got (distance matrix) */
562        private double[][] O;
563
564        /* column of X currently updated (also the dimension) */
565        private int col = 0;
566
567        /* number of dimensions we want */
568        private int target_dims = 0;
569
570        // if we already have the pivot elements
571        private boolean pivot_set = false;
572
573        public Fastmap(int k) {
574            this.target_dims = k;
575        }
576
577        /**
578         * Sets the distance matrix and params that depend on this
579         *
580         * @param O
581         */
582        public void setDistmat(double[][] O) {
583            this.O = O;
584            int N = O.length;
585            this.X = new double[N][this.target_dims];
586            this.PA = new int[2][this.target_dims];
587        }
588
589        /**
590         * Set pivot elements, we need that to classify instances after the calculation is complete
591         * (because we then want to reuse only the pivot elements).
592         *
593         * @param pi
594         */
595        public void setPivots(int[][] pi) {
596            this.pivot_set = true;
597            this.PA = pi;
598        }
599
600        /**
601         * Return the pivot elements that were chosen during the calculation
602         *
603         * @return
604         */
605        public int[][] getPivots() {
606            return this.PA;
607        }
608
609        /**
610         * The distance function for euclidean distance
611         *
612         * Acts according to equation 4 of the fastmap paper
613         *
614         * @param x
615         *            x index of x image (if k==0 x object)
616         * @param y
617         *            y index of y image (if k==0 y object)
618         * @param kdimensionality
619         * @return distance
620         */
621        private double dist(int x, int y, int k) {
622
623            // basis is object distance, we get this from our distance matrix
624            double tmp = this.O[x][y] * this.O[x][y];
625
626            // decrease by projections
627            for (int i = 0; i < k; i++) {
628                double tmp2 = (this.X[x][i] - this.X[y][i]);
629                tmp -= tmp2 * tmp2;
630            }
631
632            return Math.abs(tmp);
633        }
634
635        /**
636         * Find the object farthest from the given index This method is a helper Method for
637         * findDistandObjects
638         *
639         * @param index
640         *            of the object
641         * @return index of the farthest object from the given index
642         */
643        private int findFarthest(int index) {
644            double furthest = Double.MIN_VALUE;
645            int ret = 0;
646
647            for (int i = 0; i < O.length; i++) {
648                double dist = this.dist(i, index, this.col);
649                if (i != index && dist > furthest) {
650                    furthest = dist;
651                    ret = i;
652                }
653            }
654            return ret;
655        }
656
657        /**
658         * Finds the pivot objects
659         *
660         * This method is basically algorithm 1 of the fastmap paper.
661         *
662         * @return 2 indexes of the choosen pivot objects
663         */
664        private int[] findDistantObjects() {
665            // 1. choose object randomly
666            Random r = new Random();
667            int obj = r.nextInt(this.O.length);
668
669            // 2. find farthest object from randomly chosen object
670            int idx1 = this.findFarthest(obj);
671
672            // 3. find farthest object from previously farthest object
673            int idx2 = this.findFarthest(idx1);
674
675            return new int[]
676                { idx1, idx2 };
677        }
678
679        /**
680         * Calculates the new k-vector values (projections)
681         *
682         * This is basically algorithm 2 of the fastmap paper. We just added the possibility to
683         * pre-set the pivot elements because we need to classify single instances after the
684         * computation is already done.
685         *
686         * @param dims
687         *            dimensionality
688         */
689        public void calculate() {
690
691            for (int k = 0; k < this.target_dims; k++) {
692                // 2) choose pivot objects
693                if (!this.pivot_set) {
694                    int[] pivots = this.findDistantObjects();
695
696                    // 3) record ids of pivot objects
697                    this.PA[0][this.col] = pivots[0];
698                    this.PA[1][this.col] = pivots[1];
699                }
700
701                // 4) inter object distances are zero (this.X is initialized with 0 so we just
702                // continue)
703                if (this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col) == 0) {
704                    continue;
705                }
706
707                // 5) project the objects on the line between the pivots
708                double dxy = this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col);
709                for (int i = 0; i < this.O.length; i++) {
710
711                    double dix = this.dist(i, this.PA[0][this.col], this.col);
712                    double diy = this.dist(i, this.PA[1][this.col], this.col);
713
714                    double tmp = (dix + dxy - diy) / (2 * Math.sqrt(dxy));
715
716                    // save the projection
717                    this.X[i][this.col] = tmp;
718                }
719
720                this.col += 1;
721            }
722        }
723
724        /**
725         * returns the result matrix of the projections
726         *
727         * @return calculated result
728         */
729        public double[][] getX() {
730            return this.X;
731        }
732    }
733}
Note: See TracBrowser for help on using the repository browser.