source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalFQTraining.java @ 87

Last change on this file since 87 was 86, checked in by sherbold, 9 years ago
  • switched workspace encoding to UTF-8 and fixed broken characters
File size: 30.3 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.io.PrintStream;
18import java.util.ArrayList;
19import java.util.HashMap;
20import java.util.HashSet;
21import java.util.Iterator;
22import java.util.Random;
23import java.util.Set;
24import java.util.logging.Level;
25
26import org.apache.commons.io.output.NullOutputStream;
27
28import de.ugoe.cs.cpdp.training.QuadTree;
29import de.ugoe.cs.util.console.Console;
30import weka.classifiers.AbstractClassifier;
31import weka.classifiers.Classifier;
32import weka.core.DenseInstance;
33import weka.core.EuclideanDistance;
34import weka.core.Instance;
35import weka.core.Instances;
36import weka.filters.Filter;
37import weka.filters.unsupervised.attribute.Remove;
38
/**
 * Trainer with a reimplementation of the WHERE clustering algorithm from: Tim Menzies, Andrew
 * Butcher, David Cok, Andrian Marcus, Lucas Layman, Forrest Shull, Burak Turhan, Thomas
 * Zimmermann, "Local versus Global Lessons for Defect Prediction and Effort Estimation," IEEE
 * Transactions on Software Engineering, vol. 39, no. 6, pp. 822-834, June 2013.
 *
 * WekaLocalFQTraining works as follows:
 * <ol>
 * <li>Run the Fastmap algorithm on all training data to compute the two most significant
 * dimensions and the projection of each instance onto them.</li>
 * <li>Span a QuadTree over these two dimensions and split it recursively at the median x and
 * median y values.</li>
 * <li>Merge QuadTree nodes into clusters if they have a similar density (50%).</li>
 * <li>Save the clusters and their training data.</li>
 * <li>Keep only clusters with more than four instances; smaller clusters are discarded together
 * with their training data. (The ALPHA value Math.sqrt(SIZE) proposed by WHERE is handed to the
 * QuadTree, but the final cluster filter uses this fixed threshold.)</li>
 * <li>Train one Weka classifier per remaining cluster on that cluster's training data.</li>
 * <li>To classify an instance, recompute its Fastmap coordinates using the stored pivot objects
 * and look for a cluster whose boxes contain these coordinates. If no remaining cluster contains
 * them (e.g., because they lie outside all clusters), use the cluster of the training instance
 * that is closest to the instance.</li>
 * <li>Classify the instance with the classifier of the cluster found in the previous step.</li>
 * </ol>
 */
57public class WekaLocalFQTraining extends WekaBaseTraining implements ITrainingStrategy {
58
59    private final TraindatasetCluster classifier = new TraindatasetCluster();
60
61    @Override
62    public Classifier getClassifier() {
63        return classifier;
64    }
65
66    @Override
67    public void apply(Instances traindata) {
68        PrintStream errStr = System.err;
69        System.setErr(new PrintStream(new NullOutputStream()));
70        try {
71            classifier.buildClassifier(traindata);
72        }
73        catch (Exception e) {
74            throw new RuntimeException(e);
75        }
76        finally {
77            System.setErr(errStr);
78        }
79    }
80
81    public class TraindatasetCluster extends AbstractClassifier {
82
83        private static final long serialVersionUID = 1L;
84
85        /* classifier per cluster */
86        private HashMap<Integer, Classifier> cclassifier;
87
88        /* instances per cluster */
89        private HashMap<Integer, Instances> ctraindata;
90
91        /*
92         * holds the instances and indices of the pivot objects of the Fastmap calculation in
93         * buildClassifier
94         */
95        private HashMap<Integer, Instance> cpivots;
96
97        /* holds the indices of the pivot objects for x,y and the dimension [x,y][dimension] */
98        private int[][] cpivotindices;
99
100        /* holds the sizes of the cluster multiple "boxes" per cluster */
101        private HashMap<Integer, ArrayList<Double[][]>> csize;
102
103        /* debug vars */
104        @SuppressWarnings("unused")
105        private boolean show_biggest = true;
106
107        @SuppressWarnings("unused")
108        private int CFOUND = 0;
109        @SuppressWarnings("unused")
110        private int CNOTFOUND = 0;
111
112        private Instance createInstance(Instances instances, Instance instance) {
113            // attributes for feeding instance to classifier
114            Set<String> attributeNames = new HashSet<>();
115            for (int j = 0; j < instances.numAttributes(); j++) {
116                attributeNames.add(instances.attribute(j).name());
117            }
118
119            double[] values = new double[instances.numAttributes()];
120            int index = 0;
121            for (int j = 0; j < instance.numAttributes(); j++) {
122                if (attributeNames.contains(instance.attribute(j).name())) {
123                    values[index] = instance.value(j);
124                    index++;
125                }
126            }
127
128            Instances tmp = new Instances(instances);
129            tmp.clear();
130            Instance instCopy = new DenseInstance(instance.weight(), values);
131            instCopy.setDataset(tmp);
132
133            return instCopy;
134        }
135
136        /**
137         * Because Fastmap saves only the image not the values of the attributes it used we can not
138         * use the old data directly to classify single instances to clusters.
139         *
140         * To classify a single instance we do a new fastmap computation with only the instance and
141         * the old pivot elements.
142         *
143         * After that we find the cluster with our fastmap result for x and y.
144         */
145        @Override
146        public double classifyInstance(Instance instance) {
147
148            double ret = 0;
149            try {
150                // classinstance gets passed to classifier
151                Instances traindata = ctraindata.get(0);
152                Instance classInstance = createInstance(traindata, instance);
153
154                // this one keeps the class attribute
155                Instances traindata2 = ctraindata.get(1);
156
157                // remove class attribute before clustering
158                Remove filter = new Remove();
159                filter.setAttributeIndices("" + (traindata.classIndex() + 1));
160                filter.setInputFormat(traindata);
161                traindata = Filter.useFilter(traindata, filter);
162                Instance clusterInstance = createInstance(traindata, instance);
163
164                Fastmap FMAP = new Fastmap(2);
165                EuclideanDistance dist = new EuclideanDistance(traindata);
166
167                // we set our pivot indices [x=0,y=1][dimension]
168                int[][] npivotindices = new int[2][2];
169                npivotindices[0][0] = 1;
170                npivotindices[1][0] = 2;
171                npivotindices[0][1] = 3;
172                npivotindices[1][1] = 4;
173
174                // build temp dist matrix (2 pivots per dimension + 1 instance we want to classify)
175                // the instance we want to classify comes first after that the pivot elements in the
176                // order defined above
177                double[][] distmat = new double[2 * FMAP.target_dims + 1][2 * FMAP.target_dims + 1];
178                distmat[0][0] = 0;
179                distmat[0][1] =
180                    dist.distance(clusterInstance,
181                                  this.cpivots.get((Integer) this.cpivotindices[0][0]));
182                distmat[0][2] =
183                    dist.distance(clusterInstance,
184                                  this.cpivots.get((Integer) this.cpivotindices[1][0]));
185                distmat[0][3] =
186                    dist.distance(clusterInstance,
187                                  this.cpivots.get((Integer) this.cpivotindices[0][1]));
188                distmat[0][4] =
189                    dist.distance(clusterInstance,
190                                  this.cpivots.get((Integer) this.cpivotindices[1][1]));
191
192                distmat[1][0] =
193                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
194                                  clusterInstance);
195                distmat[1][1] = 0;
196                distmat[1][2] =
197                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
198                                  this.cpivots.get((Integer) this.cpivotindices[1][0]));
199                distmat[1][3] =
200                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
201                                  this.cpivots.get((Integer) this.cpivotindices[0][1]));
202                distmat[1][4] =
203                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
204                                  this.cpivots.get((Integer) this.cpivotindices[1][1]));
205
206                distmat[2][0] =
207                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
208                                  clusterInstance);
209                distmat[2][1] =
210                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
211                                  this.cpivots.get((Integer) this.cpivotindices[0][0]));
212                distmat[2][2] = 0;
213                distmat[2][3] =
214                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
215                                  this.cpivots.get((Integer) this.cpivotindices[0][1]));
216                distmat[2][4] =
217                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
218                                  this.cpivots.get((Integer) this.cpivotindices[1][1]));
219
220                distmat[3][0] =
221                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
222                                  clusterInstance);
223                distmat[3][1] =
224                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
225                                  this.cpivots.get((Integer) this.cpivotindices[0][0]));
226                distmat[3][2] =
227                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
228                                  this.cpivots.get((Integer) this.cpivotindices[1][0]));
229                distmat[3][3] = 0;
230                distmat[3][4] =
231                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
232                                  this.cpivots.get((Integer) this.cpivotindices[1][1]));
233
234                distmat[4][0] =
235                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
236                                  clusterInstance);
237                distmat[4][1] =
238                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
239                                  this.cpivots.get((Integer) this.cpivotindices[0][0]));
240                distmat[4][2] =
241                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
242                                  this.cpivots.get((Integer) this.cpivotindices[1][0]));
243                distmat[4][3] =
244                    dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
245                                  this.cpivots.get((Integer) this.cpivotindices[0][1]));
246                distmat[4][4] = 0;
247
248                /*
249                 * debug output: show biggest distance found within the new distance matrix double
250                 * biggest = 0; for(int i=0; i < distmat.length; i++) { for(int j=0; j <
251                 * distmat[0].length; j++) { if(biggest < distmat[i][j]) { biggest = distmat[i][j];
252                 * } } } if(this.show_biggest) { Console.traceln(Level.INFO,
253                 * String.format(""+clusterInstance)); Console.traceln(Level.INFO,
254                 * String.format("biggest distances: "+ biggest)); this.show_biggest = false; }
255                 */
256
257                FMAP.setDistmat(distmat);
258                FMAP.setPivots(npivotindices);
259                FMAP.calculate();
260                double[][] x = FMAP.getX();
261                double[] proj = x[0];
262
263                // debug output: show the calculated distance matrix, our result vektor for the
264                // instance and the complete result matrix
265                /*
266                 * Console.traceln(Level.INFO, "distmat:"); for(int i=0; i<distmat.length; i++){
267                 * for(int j=0; j<distmat[0].length; j++){ Console.trace(Level.INFO,
268                 * String.format("%20s", distmat[i][j])); } Console.traceln(Level.INFO, ""); }
269                 *
270                 * Console.traceln(Level.INFO, "vector:"); for(int i=0; i < proj.length; i++) {
271                 * Console.trace(Level.INFO, String.format("%20s", proj[i])); }
272                 * Console.traceln(Level.INFO, "");
273                 *
274                 * Console.traceln(Level.INFO, "resultmat:"); for(int i=0; i<x.length; i++){ for(int
275                 * j=0; j<x[0].length; j++){ Console.trace(Level.INFO, String.format("%20s",
276                 * x[i][j])); } Console.traceln(Level.INFO, ""); }
277                 */
278
279                // now we iterate over all clusters (well, boxes of sizes per cluster really) and
280                // save the number of the
281                // cluster in which we are
282                int cnumber;
283                int found_cnumber = -1;
284                Iterator<Integer> clusternumber = this.csize.keySet().iterator();
285                while (clusternumber.hasNext() && found_cnumber == -1) {
286                    cnumber = clusternumber.next();
287
288                    // now iterate over the boxes of the cluster and hope we find one (cluster could
289                    // have been removed)
290                    // or we are too far away from any cluster because of the fastmap calculation
291                    // with the initial pivot objects
292                    for (int box = 0; box < this.csize.get(cnumber).size(); box++) {
293                        Double[][] current = this.csize.get(cnumber).get(box);
294
295                        if (proj[0] >= current[0][0] && proj[0] <= current[0][1] && // x
296                            proj[1] >= current[1][0] && proj[1] <= current[1][1])
297                        { // y
298                            found_cnumber = cnumber;
299                        }
300                    }
301                }
302
303                // we want to count how often we are really inside a cluster
304                // if ( found_cnumber == -1 ) {
305                // CNOTFOUND += 1;
306                // }else {
307                // CFOUND += 1;
308                // }
309
310                // now it can happen that we do not find a cluster because we deleted it previously
311                // (too few instances)
312                // or we get bigger distance measures from weka so that we are completely outside of
313                // our clusters.
314                // in these cases we just find the nearest cluster to our instance and use it for
315                // classification.
316                // to do that we use the EuclideanDistance again to compare our distance to all
317                // other Instances
318                // then we take the cluster of the closest weka instance
319                dist = new EuclideanDistance(traindata2);
320                if (!this.ctraindata.containsKey(found_cnumber)) {
321                    double min_distance = Double.MAX_VALUE;
322                    clusternumber = ctraindata.keySet().iterator();
323                    while (clusternumber.hasNext()) {
324                        cnumber = clusternumber.next();
325                        for (int i = 0; i < ctraindata.get(cnumber).size(); i++) {
326                            if (dist.distance(instance, ctraindata.get(cnumber).get(i)) <= min_distance)
327                            {
328                                found_cnumber = cnumber;
329                                min_distance =
330                                    dist.distance(instance, ctraindata.get(cnumber).get(i));
331                            }
332                        }
333                    }
334                }
335
336                // here we have the cluster where an instance has the minimum distance between
337                // itself and the
338                // instance we want to classify
339                // if we still have not found a cluster we exit because something is really wrong
340                if (found_cnumber == -1) {
341                    Console.traceln(Level.INFO, String
342                        .format("ERROR matching instance to cluster with full search!"));
343                    throw new RuntimeException("cluster not found with full search");
344                }
345
346                // classify the passed instance with the cluster we found and its training data
347                ret = cclassifier.get(found_cnumber).classifyInstance(classInstance);
348
349            }
350            catch (Exception e) {
351                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!"));
352                throw new RuntimeException(e);
353            }
354            return ret;
355        }
356
357        @Override
358        public void buildClassifier(Instances traindata) throws Exception {
359
360            // Console.traceln(Level.INFO, String.format("found: "+ CFOUND + ", notfound: " +
361            // CNOTFOUND));
362            this.show_biggest = true;
363
364            cclassifier = new HashMap<Integer, Classifier>();
365            ctraindata = new HashMap<Integer, Instances>();
366            cpivots = new HashMap<Integer, Instance>();
367            cpivotindices = new int[2][2];
368
369            // 1. copy traindata
370            Instances train = new Instances(traindata);
371            Instances train2 = new Instances(traindata); // this one keeps the class attribute
372
373            // 2. remove class attribute for clustering
374            Remove filter = new Remove();
375            filter.setAttributeIndices("" + (train.classIndex() + 1));
376            filter.setInputFormat(train);
377            train = Filter.useFilter(train, filter);
378
379            // 3. calculate distance matrix (needed for Fastmap because it starts at dimension 1)
380            double biggest = 0;
381            EuclideanDistance dist = new EuclideanDistance(train);
382            double[][] distmat = new double[train.size()][train.size()];
383            for (int i = 0; i < train.size(); i++) {
384                for (int j = 0; j < train.size(); j++) {
385                    distmat[i][j] = dist.distance(train.get(i), train.get(j));
386                    if (distmat[i][j] > biggest) {
387                        biggest = distmat[i][j];
388                    }
389                }
390            }
391            // Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest));
392
393            // 4. run fastmap for 2 dimensions on the distance matrix
394            Fastmap FMAP = new Fastmap(2);
395            FMAP.setDistmat(distmat);
396            FMAP.calculate();
397
398            cpivotindices = FMAP.getPivots();
399
400            double[][] X = FMAP.getX();
401
402            // quadtree payload generation
403            ArrayList<QuadTreePayload<Instance>> qtp = new ArrayList<QuadTreePayload<Instance>>();
404
405            // we need these for the sizes of the quadrants
406            double[] big =
407                { 0, 0 };
408            double[] small =
409                { Double.MAX_VALUE, Double.MAX_VALUE };
410
411            // set quadtree payload values and get max and min x and y values for size
412            for (int i = 0; i < X.length; i++) {
413                if (X[i][0] >= big[0]) {
414                    big[0] = X[i][0];
415                }
416                if (X[i][1] >= big[1]) {
417                    big[1] = X[i][1];
418                }
419                if (X[i][0] <= small[0]) {
420                    small[0] = X[i][0];
421                }
422                if (X[i][1] <= small[1]) {
423                    small[1] = X[i][1];
424                }
425                QuadTreePayload<Instance> tmp =
426                    new QuadTreePayload<Instance>(X[i][0], X[i][1], train2.get(i));
427                qtp.add(tmp);
428            }
429
430            // Console.traceln(Level.INFO,
431            // String.format("size for cluster ("+small[0]+","+small[1]+") - ("+big[0]+","+big[1]+")"));
432
433            // 5. generate quadtree
434            QuadTree TREE = new QuadTree(null, qtp);
435            QuadTree.size = train.size();
436            QuadTree.alpha = Math.sqrt(train.size());
437            QuadTree.ccluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>();
438            QuadTree.csize = new HashMap<Integer, ArrayList<Double[][]>>();
439
440            // Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ QuadTree.size +
441            // " size, Alpha: "+ QuadTree.alpha+ ""));
442
443            // set the size and then split the tree recursively at the median value for x, y
444            TREE.setSize(new double[]
445                { small[0], big[0] }, new double[]
446                { small[1], big[1] });
447
448            // recursive split und grid clustering eher static
449            TREE.recursiveSplit(TREE);
450
451            // generate list of nodes sorted by density (childs only)
452            ArrayList<QuadTree> l = new ArrayList<QuadTree>(TREE.getList(TREE));
453
454            // recursive grid clustering (tree pruning), the values are stored in ccluster
455            TREE.gridClustering(l);
456
457            // wir iterieren durch die cluster und sammeln uns die instanzen daraus
458            // ctraindata.clear();
459            for (int i = 0; i < QuadTree.ccluster.size(); i++) {
460                ArrayList<QuadTreePayload<Instance>> current = QuadTree.ccluster.get(i);
461
462                // i is the clusternumber
463                // we only allow clusters with Instances > ALPHA, other clusters are not considered!
464                // if(current.size() > QuadTree.alpha) {
465                if (current.size() > 4) {
466                    for (int j = 0; j < current.size(); j++) {
467                        if (!ctraindata.containsKey(i)) {
468                            ctraindata.put(i, new Instances(train2));
469                            ctraindata.get(i).delete();
470                        }
471                        ctraindata.get(i).add(current.get(j).getInst());
472                    }
473                }
474                else {
475                    Console.traceln(Level.INFO,
476                                    String.format("drop cluster, only: " + current.size() +
477                                        " instances"));
478                }
479            }
480
481            // here we keep things we need later on
482            // QuadTree sizes for later use (matching new instances)
483            this.csize = new HashMap<Integer, ArrayList<Double[][]>>(QuadTree.csize);
484
485            // pivot elements
486            // this.cpivots.clear();
487            for (int i = 0; i < FMAP.PA[0].length; i++) {
488                this.cpivots.put(FMAP.PA[0][i], (Instance) train.get(FMAP.PA[0][i]).copy());
489            }
490            for (int j = 0; j < FMAP.PA[0].length; j++) {
491                this.cpivots.put(FMAP.PA[1][j], (Instance) train.get(FMAP.PA[1][j]).copy());
492            }
493
494            /*
495             * debug output int pnumber; Iterator<Integer> pivotnumber =
496             * cpivots.keySet().iterator(); while ( pivotnumber.hasNext() ) { pnumber =
497             * pivotnumber.next(); Console.traceln(Level.INFO, String.format("pivot: "+pnumber+
498             * " inst: "+cpivots.get(pnumber))); }
499             */
500
501            // train one classifier per cluster, we get the cluster number from the traindata
502            int cnumber;
503            Iterator<Integer> clusternumber = ctraindata.keySet().iterator();
504            // cclassifier.clear();
505
506            // int traindata_count = 0;
507            while (clusternumber.hasNext()) {
508                cnumber = clusternumber.next();
509                cclassifier.put(cnumber, setupClassifier()); // this is the classifier used for the
510                                                             // cluster
511                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));
512                // Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber));
513                // traindata_count += ctraindata.get(cnumber).size();
514                // Console.traceln(Level.INFO,
515                // String.format("building classifier in cluster "+cnumber +"  with "+
516                // ctraindata.get(cnumber).size() +" traindata instances"));
517            }
518
519            // add all traindata
520            // Console.traceln(Level.INFO, String.format("traindata in all clusters: " +
521            // traindata_count));
522        }
523    }
524
525    /**
526     * Payload for the QuadTree. x and y are the calculated Fastmap values. T is a weka instance.
527     */
528    public class QuadTreePayload<T> {
529
530        public double x;
531        public double y;
532        private T inst;
533
534        public QuadTreePayload(double x, double y, T value) {
535            this.x = x;
536            this.y = y;
537            this.inst = value;
538        }
539
540        public T getInst() {
541            return this.inst;
542        }
543    }
544
545    /**
546     * Fastmap implementation
547     *
548     * Faloutsos, C., & Lin, K. I. (1995). FastMap: A fast algorithm for indexing, data-mining and
549     * visualization of traditional and multimedia datasets (Vol. 24, No. 2, pp. 163-174). ACM.
550     */
551    public class Fastmap {
552
553        /* N x k Array, at the end, the i-th row will be the image of the i-th object */
554        private double[][] X;
555
556        /* 2 x k pivot Array one pair per recursive call */
557        private int[][] PA;
558
559        /* Objects we got (distance matrix) */
560        private double[][] O;
561
562        /* column of X currently updated (also the dimension) */
563        private int col = 0;
564
565        /* number of dimensions we want */
566        private int target_dims = 0;
567
568        // if we already have the pivot elements
569        private boolean pivot_set = false;
570
571        public Fastmap(int k) {
572            this.target_dims = k;
573        }
574
575        /**
576         * Sets the distance matrix and params that depend on this
577         *
578         * @param O
579         */
580        public void setDistmat(double[][] O) {
581            this.O = O;
582            int N = O.length;
583            this.X = new double[N][this.target_dims];
584            this.PA = new int[2][this.target_dims];
585        }
586
587        /**
588         * Set pivot elements, we need that to classify instances after the calculation is complete
589         * (because we then want to reuse only the pivot elements).
590         *
591         * @param pi
592         */
593        public void setPivots(int[][] pi) {
594            this.pivot_set = true;
595            this.PA = pi;
596        }
597
598        /**
599         * Return the pivot elements that were chosen during the calculation
600         *
601         * @return
602         */
603        public int[][] getPivots() {
604            return this.PA;
605        }
606
607        /**
608         * The distance function for euclidean distance
609         *
610         * Acts according to equation 4 of the fastmap paper
611         *
612         * @param x
613         *            x index of x image (if k==0 x object)
614         * @param y
615         *            y index of y image (if k==0 y object)
616         * @param kdimensionality
617         * @return distance
618         */
619        private double dist(int x, int y, int k) {
620
621            // basis is object distance, we get this from our distance matrix
622            double tmp = this.O[x][y] * this.O[x][y];
623
624            // decrease by projections
625            for (int i = 0; i < k; i++) {
626                double tmp2 = (this.X[x][i] - this.X[y][i]);
627                tmp -= tmp2 * tmp2;
628            }
629
630            return Math.abs(tmp);
631        }
632
633        /**
634         * Find the object farthest from the given index This method is a helper Method for
635         * findDistandObjects
636         *
637         * @param index
638         *            of the object
639         * @return index of the farthest object from the given index
640         */
641        private int findFarthest(int index) {
642            double furthest = Double.MIN_VALUE;
643            int ret = 0;
644
645            for (int i = 0; i < O.length; i++) {
646                double dist = this.dist(i, index, this.col);
647                if (i != index && dist > furthest) {
648                    furthest = dist;
649                    ret = i;
650                }
651            }
652            return ret;
653        }
654
655        /**
656         * Finds the pivot objects
657         *
658         * This method is basically algorithm 1 of the fastmap paper.
659         *
660         * @return 2 indexes of the choosen pivot objects
661         */
662        private int[] findDistantObjects() {
663            // 1. choose object randomly
664            Random r = new Random();
665            int obj = r.nextInt(this.O.length);
666
667            // 2. find farthest object from randomly chosen object
668            int idx1 = this.findFarthest(obj);
669
670            // 3. find farthest object from previously farthest object
671            int idx2 = this.findFarthest(idx1);
672
673            return new int[]
674                { idx1, idx2 };
675        }
676
677        /**
678         * Calculates the new k-vector values (projections)
679         *
680         * This is basically algorithm 2 of the fastmap paper. We just added the possibility to
681         * pre-set the pivot elements because we need to classify single instances after the
682         * computation is already done.
683         *
684         * @param dims
685         *            dimensionality
686         */
687        public void calculate() {
688
689            for (int k = 0; k < this.target_dims; k++) {
690                // 2) choose pivot objects
691                if (!this.pivot_set) {
692                    int[] pivots = this.findDistantObjects();
693
694                    // 3) record ids of pivot objects
695                    this.PA[0][this.col] = pivots[0];
696                    this.PA[1][this.col] = pivots[1];
697                }
698
699                // 4) inter object distances are zero (this.X is initialized with 0 so we just
700                // continue)
701                if (this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col) == 0) {
702                    continue;
703                }
704
705                // 5) project the objects on the line between the pivots
706                double dxy = this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col);
707                for (int i = 0; i < this.O.length; i++) {
708
709                    double dix = this.dist(i, this.PA[0][this.col], this.col);
710                    double diy = this.dist(i, this.PA[1][this.col], this.col);
711
712                    double tmp = (dix + dxy - diy) / (2 * Math.sqrt(dxy));
713
714                    // save the projection
715                    this.X[i][this.col] = tmp;
716                }
717
718                this.col += 1;
719            }
720        }
721
722        /**
723         * returns the result matrix of the projections
724         *
725         * @return calculated result
726         */
727        public double[][] getX() {
728            return this.X;
729        }
730    }
731}
Note: See TracBrowser for help on using the repository browser.