source: trunk/CrossPare/src/de/ugoe/cs/cpdp/dataprocessing/TransferComponentAnalysis.java @ 55

Last change on this file since 55 was 55, checked in by sherbold, 9 years ago
  • added TCA after Pan et al.
  • Property svn:mime-type set to text/plain
File size: 10.9 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.dataprocessing;
16
17import java.util.Arrays;
18import java.util.logging.Level;
19
20import org.ojalgo.matrix.PrimitiveMatrix;
21import org.ojalgo.matrix.jama.JamaEigenvalue;
22import org.ojalgo.matrix.jama.JamaEigenvalue.General;
23import org.ojalgo.scalar.ComplexNumber;
24import org.ojalgo.access.Access2D.Builder;
25import org.ojalgo.array.Array1D;
26
27import de.ugoe.cs.util.console.Console;
28import weka.core.Attribute;
29import weka.core.Instance;
30import weka.core.Instances;
31
32/**
33 * <p>
34 * TCA with a linear kernel after Pan et al. (Domain Adaptation via Transfer Component Analysis) and
35 * used for defect prediction by Nam et al. (Transfer Defect Learning)
36 * </p>
37 *
38 * TODO comment class
39 * @author Steffen Herbold
40 */
41public class TransferComponentAnalysis implements IProcessesingStrategy {
42
43    int reducedDimension = 5;
44
45    @Override
46    public void setParameter(String parameters) {
47       
48    }
49
50    @Override
51    public void apply(Instances testdata, Instances traindata) {
52        applyTCA(testdata, traindata);
53    }
54
55    private double linearKernel(Instance x1, Instance x2) {
56        double value = 0.0d;
57        for (int j = 0; j < x1.numAttributes(); j++) {
58            if (j != x1.classIndex()) {
59                value += x1.value(j) * x2.value(j);
60            }
61        }
62        return value;
63    }
64
65    private void applyTCA(Instances testdata, Instances traindata) {
66        final int sizeTest = testdata.numInstances();
67        final int sizeTrain = traindata.numInstances();
68        final PrimitiveMatrix kernelMatrix = buildKernel(testdata, traindata);
69        final PrimitiveMatrix kernelNormMatrix = buildKernelNormMatrix(sizeTest, sizeTrain); // L in
70                                                                                             // the
71        // paper
72        final PrimitiveMatrix centerMatrix = buildCenterMatrix(sizeTest, sizeTrain); // H in the
73                                                                                     // paper
74        final double mu = 1.0; // default from the MATLAB implementation
75        final PrimitiveMatrix muMatrix = buildMuMatrix(sizeTest, sizeTrain, mu);
76        PrimitiveMatrix.FACTORY.makeEye(sizeTest + sizeTrain, sizeTest + sizeTrain);
77
78        Console.traceln(Level.FINEST,
79                        "creating optimization matrix (dimension " + (sizeTest + sizeTrain) + ")");
80        final PrimitiveMatrix optimizationProblem = kernelMatrix.multiplyRight(kernelNormMatrix)
81            .multiplyRight(kernelMatrix).add(muMatrix).invert().multiplyRight(kernelMatrix)
82            .multiplyRight(centerMatrix).multiplyRight(kernelMatrix);
83        Console.traceln(Level.FINEST,
84                        "optimization matrix created, now solving eigenvalue problem");
85        General eigenvalueDecomposition = new JamaEigenvalue.General();
86        eigenvalueDecomposition.compute(optimizationProblem);
87        Console.traceln(Level.FINEST, "eigenvalue problem solved");
88
89        Array1D<ComplexNumber> eigenvaluesArray = eigenvalueDecomposition.getEigenvalues();
90        System.out.println(eigenvaluesArray.length);
91        final double[] eigenvalues = new double[(int) eigenvaluesArray.length];
92        final int[] index = new int[(int) eigenvaluesArray.length];
93        // create kernel transformation matrix from eigenvectors
94        for (int i = 0; i < eigenvaluesArray.length; i++) {
95            eigenvalues[i] = eigenvaluesArray.doubleValue(i);
96            index[i] = i;
97        }
98        quicksort(eigenvalues, index);
99
100        final PrimitiveMatrix transformedKernel = kernelMatrix.multiplyRight(eigenvalueDecomposition
101            .getV().selectColumns(Arrays.copyOfRange(index, 0, reducedDimension)));
102
103        // update testdata and traindata
104        for (int j = testdata.numAttributes() - 1; j >= 0; j--) {
105            if (j != testdata.classIndex()) {
106                testdata.deleteAttributeAt(j);
107                traindata.deleteAttributeAt(j);
108            }
109        }
110        for (int j = 0; j < reducedDimension; j++) {
111            testdata.insertAttributeAt(new Attribute("kerneldim" + j), 1);
112            traindata.insertAttributeAt(new Attribute("kerneldim" + j), 1);
113        }
114        for (int i = 0; i < sizeTrain; i++) {
115            for (int j = 0; j < reducedDimension; j++) {
116                traindata.instance(i).setValue(j + 1, transformedKernel.get(i, j));
117            }
118        }
119        for (int i = 0; i < sizeTest; i++) {
120            for (int j = 0; j < reducedDimension; j++) {
121                testdata.instance(i).setValue(j + 1, transformedKernel.get(i + sizeTrain, j));
122            }
123        }
124    }
125
126    private PrimitiveMatrix buildKernel(Instances testdata, Instances traindata) {
127        final int kernelDim = traindata.numInstances() + testdata.numInstances();
128
129        Builder<PrimitiveMatrix> kernelBuilder = PrimitiveMatrix.getBuilder(kernelDim, kernelDim);
130        // built upper left quadrant (source, source)
131        for (int i = 0; i < traindata.numInstances(); i++) {
132            for (int j = 0; j < traindata.numInstances(); j++) {
133                kernelBuilder.set(i, j, linearKernel(traindata.get(i), traindata.get(j)));
134            }
135        }
136
137        // built upper right quadrant (source, target)
138        for (int i = 0; i < traindata.numInstances(); i++) {
139            for (int j = 0; j < testdata.numInstances(); j++) {
140                kernelBuilder.set(i, j + traindata.numInstances(),
141                                  linearKernel(traindata.get(i), testdata.get(j)));
142            }
143        }
144
145        // built lower left quadrant (target, source)
146        for (int i = 0; i < testdata.numInstances(); i++) {
147            for (int j = 0; j < traindata.numInstances(); j++) {
148                kernelBuilder.set(i + traindata.numInstances(), j,
149                                  linearKernel(testdata.get(i), traindata.get(j)));
150            }
151        }
152
153        // built lower right quadrant (target, target)
154        for (int i = 0; i < testdata.numInstances(); i++) {
155            for (int j = 0; j < testdata.numInstances(); j++) {
156                kernelBuilder.set(i + traindata.numInstances(), j + traindata.numInstances(),
157                                  linearKernel(testdata.get(i), testdata.get(j)));
158            }
159        }
160        return kernelBuilder.build();
161    }
162
163    private PrimitiveMatrix buildKernelNormMatrix(final int dimTest, final int sizeTrain) {
164        final double trainSquared = 1.0 / (sizeTrain * (double) sizeTrain);
165        final double testSquared = 1.0 / (dimTest * (double) dimTest);
166        final double trainTest = -1.0 / (sizeTrain * (double) dimTest);
167        Builder<PrimitiveMatrix> kernelNormBuilder =
168            PrimitiveMatrix.getBuilder(sizeTrain + dimTest, sizeTrain + dimTest);
169
170        // built upper left quadrant (source, source)
171        for (int i = 0; i < sizeTrain; i++) {
172            for (int j = 0; j < sizeTrain; j++) {
173                kernelNormBuilder.set(i, j, trainSquared);
174            }
175        }
176
177        // built upper right quadrant (source, target)
178        for (int i = 0; i < sizeTrain; i++) {
179            for (int j = 0; j < dimTest; j++) {
180                kernelNormBuilder.set(i, j + sizeTrain, trainTest);
181            }
182        }
183
184        // built lower left quadrant (target, source)
185        for (int i = 0; i < dimTest; i++) {
186            for (int j = 0; j < sizeTrain; j++) {
187                kernelNormBuilder.set(i + sizeTrain, j, trainTest);
188            }
189        }
190
191        // built lower right quadrant (target, target)
192        for (int i = 0; i < dimTest; i++) {
193            for (int j = 0; j < dimTest; j++) {
194                kernelNormBuilder.set(i + sizeTrain, j + sizeTrain, testSquared);
195            }
196        }
197        return kernelNormBuilder.build();
198    }
199
200    private PrimitiveMatrix buildCenterMatrix(final int sizeTest, final int sizeTrain) {
201        Builder<PrimitiveMatrix> centerMatrix =
202            PrimitiveMatrix.getBuilder(sizeTest + sizeTrain, sizeTest + sizeTrain);
203        for (int i = 0; i < centerMatrix.countRows(); i++) {
204            centerMatrix.set(i, i, -1.0 / (sizeTest + sizeTrain));
205        }
206        return centerMatrix.build();
207    }
208
209    private PrimitiveMatrix buildMuMatrix(final int sizeTest,
210                                          final int sizeTrain,
211                                          final double mu)
212    {
213        Builder<PrimitiveMatrix> muMatrix =
214            PrimitiveMatrix.getBuilder(sizeTest + sizeTrain, sizeTest + sizeTrain);
215        for (int i = 0; i < muMatrix.countRows(); i++) {
216            muMatrix.set(i, i, mu);
217        }
218        return muMatrix.build();
219    }
220
221    // below is from http://stackoverflow.com/a/1040503
222    private static void quicksort(double[] main, int[] index) {
223        quicksort(main, index, 0, index.length - 1);
224    }
225
226    // quicksort a[left] to a[right]
227    private static void quicksort(double[] a, int[] index, int left, int right) {
228        if (right <= left)
229            return;
230        int i = partition(a, index, left, right);
231        quicksort(a, index, left, i - 1);
232        quicksort(a, index, i + 1, right);
233    }
234
235    // partition a[left] to a[right], assumes left < right
236    private static int partition(double[] a, int[] index, int left, int right) {
237        int i = left - 1;
238        int j = right;
239        while (true) {
240            while (less(a[++i], a[right])) // find item on left to swap
241            ; // a[right] acts as sentinel
242            while (less(a[right], a[--j])) // find item on right to swap
243                if (j == left)
244                    break; // don't go out-of-bounds
245            if (i >= j)
246                break; // check if pointers cross
247            exch(a, index, i, j); // swap two elements into place
248        }
249        exch(a, index, i, right); // swap with partition element
250        return i;
251    }
252
253    // is x < y ?
254    private static boolean less(double x, double y) {
255        return (x < y);
256    }
257
258    // exchange a[i] and a[j]
259    private static void exch(double[] a, int[] index, int i, int j) {
260        double swap = a[i];
261        a[i] = a[j];
262        a[j] = swap;
263        int b = index[i];
264        index[i] = index[j];
265        index[j] = b;
266    }
267}
Note: See TracBrowser for help on using the repository browser.