source: trunk/CrossPare/src/de/ugoe/cs/cpdp/dataprocessing/TransferComponentAnalysis.java @ 139

Last change on this file since 139 was 135, checked in by sherbold, 8 years ago
  • code documentation and formatting
  • Property svn:mime-type set to text/plain
File size: 11.4 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.dataprocessing;
16
17import java.util.Arrays;
18import java.util.logging.Level;
19
20import org.ojalgo.matrix.PrimitiveMatrix;
21import org.ojalgo.matrix.jama.JamaEigenvalue;
22import org.ojalgo.matrix.jama.JamaEigenvalue.General;
23import org.ojalgo.scalar.ComplexNumber;
24import org.ojalgo.access.Access2D.Builder;
25import org.ojalgo.array.Array1D;
26
27import de.ugoe.cs.cpdp.util.SortUtils;
28import de.ugoe.cs.util.console.Console;
29import weka.core.Attribute;
30import weka.core.Instance;
31import weka.core.Instances;
32
33/**
34 * <p>
35 * TCA with a linear kernel after Pan et al. (Domain Adaptation via Transfer Component Analysis) and
36 * used for defect prediction by Nam et al. (Transfer Defect Learning)
37 * </p>
38 *
39 * @author Steffen Herbold
40 */
41public class TransferComponentAnalysis implements IProcessesingStrategy {
42
43    /**
44     * Dimension of the reduced data.
45     */
46    int reducedDimension = 5;
47
48    /*
49     * (non-Javadoc)
50     *
51     * @see de.ugoe.cs.cpdp.IParameterizable#setParameter(java.lang.String)
52     */
53    @Override
54    public void setParameter(String parameters) {
55        // dummy, paramters ignored
56    }
57
58    /*
59     * (non-Javadoc)
60     *
61     * @see de.ugoe.cs.cpdp.dataprocessing.IProcessesingStrategy#apply(weka.core.Instances,
62     * weka.core.Instances)
63     */
64    @Override
65    public void apply(Instances testdata, Instances traindata) {
66        applyTCA(testdata, traindata);
67    }
68
69    /**
70     * <p>
71     * calculates the linear kernel function between two instances
72     * </p>
73     *
74     * @param x1
75     *            first instance
76     * @param x2
77     *            second instance
78     * @return kernel value
79     */
80    private double linearKernel(Instance x1, Instance x2) {
81        double value = 0.0d;
82        for (int j = 0; j < x1.numAttributes(); j++) {
83            if (j != x1.classIndex()) {
84                value += x1.value(j) * x2.value(j);
85            }
86        }
87        return value;
88    }
89
90    /**
91     * <p>
92     * Applies TCA to the test and training data.
93     * </p>
94     *
95     * @param testdata
96     *            the test data
97     * @param traindata
98     *            the training data
99     */
100    private void applyTCA(Instances testdata, Instances traindata) {
101        final int sizeTest = testdata.numInstances();
102        final int sizeTrain = traindata.numInstances();
103        final PrimitiveMatrix kernelMatrix = buildKernel(testdata, traindata);
104        final PrimitiveMatrix kernelNormMatrix = buildKernelNormMatrix(sizeTest, sizeTrain); // L in
105                                                                                             // the
106        // paper
107        final PrimitiveMatrix centerMatrix = buildCenterMatrix(sizeTest, sizeTrain); // H in the
108                                                                                     // paper
109        final double mu = 1.0; // default from the MATLAB implementation
110        final PrimitiveMatrix muMatrix = buildMuMatrix(sizeTest, sizeTrain, mu);
111        PrimitiveMatrix.FACTORY.makeEye(sizeTest + sizeTrain, sizeTest + sizeTrain);
112
113        Console.traceln(Level.FINEST,
114                        "creating optimization matrix (dimension " + (sizeTest + sizeTrain) + ")");
115        final PrimitiveMatrix optimizationProblem = kernelMatrix.multiplyRight(kernelNormMatrix)
116            .multiplyRight(kernelMatrix).add(muMatrix).invert().multiplyRight(kernelMatrix)
117            .multiplyRight(centerMatrix).multiplyRight(kernelMatrix);
118        Console.traceln(Level.FINEST,
119                        "optimization matrix created, now solving eigenvalue problem");
120        General eigenvalueDecomposition = new JamaEigenvalue.General();
121        eigenvalueDecomposition.compute(optimizationProblem);
122        Console.traceln(Level.FINEST, "eigenvalue problem solved");
123
124        Array1D<ComplexNumber> eigenvaluesArray = eigenvalueDecomposition.getEigenvalues();
125        System.out.println(eigenvaluesArray.length);
126        final Double[] eigenvalues = new Double[(int) eigenvaluesArray.length];
127        final int[] index = new int[(int) eigenvaluesArray.length];
128        // create kernel transformation matrix from eigenvectors
129        for (int i = 0; i < eigenvaluesArray.length; i++) {
130            eigenvalues[i] = eigenvaluesArray.doubleValue(i);
131            index[i] = i;
132        }
133        SortUtils.quicksort(eigenvalues, index);
134
135        final PrimitiveMatrix transformedKernel = kernelMatrix.multiplyRight(eigenvalueDecomposition
136            .getV().selectColumns(Arrays.copyOfRange(index, 0, reducedDimension)));
137
138        // update testdata and traindata
139        for (int j = testdata.numAttributes() - 1; j >= 0; j--) {
140            if (j != testdata.classIndex()) {
141                testdata.deleteAttributeAt(j);
142                traindata.deleteAttributeAt(j);
143            }
144        }
145        for (int j = 0; j < reducedDimension; j++) {
146            testdata.insertAttributeAt(new Attribute("kerneldim" + j), 1);
147            traindata.insertAttributeAt(new Attribute("kerneldim" + j), 1);
148        }
149        for (int i = 0; i < sizeTrain; i++) {
150            for (int j = 0; j < reducedDimension; j++) {
151                traindata.instance(i).setValue(j + 1, transformedKernel.get(i, j));
152            }
153        }
154        for (int i = 0; i < sizeTest; i++) {
155            for (int j = 0; j < reducedDimension; j++) {
156                testdata.instance(i).setValue(j + 1, transformedKernel.get(i + sizeTrain, j));
157            }
158        }
159    }
160
161    /**
162     * <p>
163     * Creates the kernel matrix of the test and training data
164     * </p>
165     *
166     * @param testdata
167     *            the test data
168     * @param traindata
169     *            the training data
170     * @return kernel matrix
171     */
172    private PrimitiveMatrix buildKernel(Instances testdata, Instances traindata) {
173        final int kernelDim = traindata.numInstances() + testdata.numInstances();
174
175        Builder<PrimitiveMatrix> kernelBuilder = PrimitiveMatrix.getBuilder(kernelDim, kernelDim);
176        // built upper left quadrant (source, source)
177        for (int i = 0; i < traindata.numInstances(); i++) {
178            for (int j = 0; j < traindata.numInstances(); j++) {
179                kernelBuilder.set(i, j, linearKernel(traindata.get(i), traindata.get(j)));
180            }
181        }
182
183        // built upper right quadrant (source, target)
184        for (int i = 0; i < traindata.numInstances(); i++) {
185            for (int j = 0; j < testdata.numInstances(); j++) {
186                kernelBuilder.set(i, j + traindata.numInstances(),
187                                  linearKernel(traindata.get(i), testdata.get(j)));
188            }
189        }
190
191        // built lower left quadrant (target, source)
192        for (int i = 0; i < testdata.numInstances(); i++) {
193            for (int j = 0; j < traindata.numInstances(); j++) {
194                kernelBuilder.set(i + traindata.numInstances(), j,
195                                  linearKernel(testdata.get(i), traindata.get(j)));
196            }
197        }
198
199        // built lower right quadrant (target, target)
200        for (int i = 0; i < testdata.numInstances(); i++) {
201            for (int j = 0; j < testdata.numInstances(); j++) {
202                kernelBuilder.set(i + traindata.numInstances(), j + traindata.numInstances(),
203                                  linearKernel(testdata.get(i), testdata.get(j)));
204            }
205        }
206        return kernelBuilder.build();
207    }
208
209    /**
210     * <p>
211     * Calculates the kernel norm matrix, i.e., the matrix which is used for matrix multiplication
212     * to calculate the kernel norm.
213     * </p>
214     *
215     * @param dimTest
216     *            dimension of the test data
217     * @param sizeTrain
218     *            number of instances of the training data
219     * @return kernel norm matrix
220     */
221    private PrimitiveMatrix buildKernelNormMatrix(final int dimTest, final int sizeTrain) {
222        final double trainSquared = 1.0 / (sizeTrain * (double) sizeTrain);
223        final double testSquared = 1.0 / (dimTest * (double) dimTest);
224        final double trainTest = -1.0 / (sizeTrain * (double) dimTest);
225        Builder<PrimitiveMatrix> kernelNormBuilder =
226            PrimitiveMatrix.getBuilder(sizeTrain + dimTest, sizeTrain + dimTest);
227
228        // built upper left quadrant (source, source)
229        for (int i = 0; i < sizeTrain; i++) {
230            for (int j = 0; j < sizeTrain; j++) {
231                kernelNormBuilder.set(i, j, trainSquared);
232            }
233        }
234
235        // built upper right quadrant (source, target)
236        for (int i = 0; i < sizeTrain; i++) {
237            for (int j = 0; j < dimTest; j++) {
238                kernelNormBuilder.set(i, j + sizeTrain, trainTest);
239            }
240        }
241
242        // built lower left quadrant (target, source)
243        for (int i = 0; i < dimTest; i++) {
244            for (int j = 0; j < sizeTrain; j++) {
245                kernelNormBuilder.set(i + sizeTrain, j, trainTest);
246            }
247        }
248
249        // built lower right quadrant (target, target)
250        for (int i = 0; i < dimTest; i++) {
251            for (int j = 0; j < dimTest; j++) {
252                kernelNormBuilder.set(i + sizeTrain, j + sizeTrain, testSquared);
253            }
254        }
255        return kernelNormBuilder.build();
256    }
257
258    /**
259     * <p>
260     * Creates the center matrix
261     * </p>
262     *
263     * @param sizeTest
264     *            number of instances of the test data
265     * @param sizeTrain
266     *            number of instances of the training data
267     * @return center matrix
268     */
269    private PrimitiveMatrix buildCenterMatrix(final int sizeTest, final int sizeTrain) {
270        Builder<PrimitiveMatrix> centerMatrix =
271            PrimitiveMatrix.getBuilder(sizeTest + sizeTrain, sizeTest + sizeTrain);
272        for (int i = 0; i < centerMatrix.countRows(); i++) {
273            centerMatrix.set(i, i, -1.0 / (sizeTest + sizeTrain));
274        }
275        return centerMatrix.build();
276    }
277
278    /**
279     * <p>
280     * Builds the mu-Matrix for offsetting values.
281     * </p>
282     *
283     * @param sizeTest
284     *            number of instances of the test data
285     * @param sizeTrain
286     *            number of instances of the training data
287     * @param mu
288     *            mu parameter
289     * @return mu-Matrix
290     */
291    private PrimitiveMatrix buildMuMatrix(final int sizeTest,
292                                          final int sizeTrain,
293                                          final double mu)
294    {
295        Builder<PrimitiveMatrix> muMatrix =
296            PrimitiveMatrix.getBuilder(sizeTest + sizeTrain, sizeTest + sizeTrain);
297        for (int i = 0; i < muMatrix.countRows(); i++) {
298            muMatrix.set(i, i, mu);
299        }
300        return muMatrix.build();
301    }
302}
Note: See TracBrowser for help on using the repository browser.