source: trunk/CrossPare/src/de/ugoe/cs/cpdp/dataprocessing/TransferComponentAnalysis.java @ 87

Last change on this file since 87 was 86, checked in by sherbold, 9 years ago
  • switched workspace encoding to UTF-8 and fixed broken characters
  • Property svn:mime-type set to text/plain
File size: 9.3 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.dataprocessing;
16
17import java.util.Arrays;
18import java.util.logging.Level;
19
20import org.ojalgo.matrix.PrimitiveMatrix;
21import org.ojalgo.matrix.jama.JamaEigenvalue;
22import org.ojalgo.matrix.jama.JamaEigenvalue.General;
23import org.ojalgo.scalar.ComplexNumber;
24import org.ojalgo.access.Access2D.Builder;
25import org.ojalgo.array.Array1D;
26
27import de.ugoe.cs.cpdp.util.SortUtils;
28import de.ugoe.cs.util.console.Console;
29import weka.core.Attribute;
30import weka.core.Instance;
31import weka.core.Instances;
32
33/**
34 * <p>
35 * TCA with a linear kernel after Pan et al. (Domain Adaptation via Transfer Component Analysis) and
36 * used for defect prediction by Nam et al. (Transfer Defect Learning)
37 * </p>
38 *
39 * TODO comment class
40 * @author Steffen Herbold
41 */
42public class TransferComponentAnalysis implements IProcessesingStrategy {
43
44    int reducedDimension = 5;
45
46    @Override
47    public void setParameter(String parameters) {
48       
49    }
50
51    @Override
52    public void apply(Instances testdata, Instances traindata) {
53        applyTCA(testdata, traindata);
54    }
55
56    private double linearKernel(Instance x1, Instance x2) {
57        double value = 0.0d;
58        for (int j = 0; j < x1.numAttributes(); j++) {
59            if (j != x1.classIndex()) {
60                value += x1.value(j) * x2.value(j);
61            }
62        }
63        return value;
64    }
65
66    private void applyTCA(Instances testdata, Instances traindata) {
67        final int sizeTest = testdata.numInstances();
68        final int sizeTrain = traindata.numInstances();
69        final PrimitiveMatrix kernelMatrix = buildKernel(testdata, traindata);
70        final PrimitiveMatrix kernelNormMatrix = buildKernelNormMatrix(sizeTest, sizeTrain); // L in
71                                                                                             // the
72        // paper
73        final PrimitiveMatrix centerMatrix = buildCenterMatrix(sizeTest, sizeTrain); // H in the
74                                                                                     // paper
75        final double mu = 1.0; // default from the MATLAB implementation
76        final PrimitiveMatrix muMatrix = buildMuMatrix(sizeTest, sizeTrain, mu);
77        PrimitiveMatrix.FACTORY.makeEye(sizeTest + sizeTrain, sizeTest + sizeTrain);
78
79        Console.traceln(Level.FINEST,
80                        "creating optimization matrix (dimension " + (sizeTest + sizeTrain) + ")");
81        final PrimitiveMatrix optimizationProblem = kernelMatrix.multiplyRight(kernelNormMatrix)
82            .multiplyRight(kernelMatrix).add(muMatrix).invert().multiplyRight(kernelMatrix)
83            .multiplyRight(centerMatrix).multiplyRight(kernelMatrix);
84        Console.traceln(Level.FINEST,
85                        "optimization matrix created, now solving eigenvalue problem");
86        General eigenvalueDecomposition = new JamaEigenvalue.General();
87        eigenvalueDecomposition.compute(optimizationProblem);
88        Console.traceln(Level.FINEST, "eigenvalue problem solved");
89
90        Array1D<ComplexNumber> eigenvaluesArray = eigenvalueDecomposition.getEigenvalues();
91        System.out.println(eigenvaluesArray.length);
92        final Double[] eigenvalues = new Double[(int) eigenvaluesArray.length];
93        final int[] index = new int[(int) eigenvaluesArray.length];
94        // create kernel transformation matrix from eigenvectors
95        for (int i = 0; i < eigenvaluesArray.length; i++) {
96            eigenvalues[i] = eigenvaluesArray.doubleValue(i);
97            index[i] = i;
98        }
99        SortUtils.quicksort(eigenvalues, index);
100
101        final PrimitiveMatrix transformedKernel = kernelMatrix.multiplyRight(eigenvalueDecomposition
102            .getV().selectColumns(Arrays.copyOfRange(index, 0, reducedDimension)));
103
104        // update testdata and traindata
105        for (int j = testdata.numAttributes() - 1; j >= 0; j--) {
106            if (j != testdata.classIndex()) {
107                testdata.deleteAttributeAt(j);
108                traindata.deleteAttributeAt(j);
109            }
110        }
111        for (int j = 0; j < reducedDimension; j++) {
112            testdata.insertAttributeAt(new Attribute("kerneldim" + j), 1);
113            traindata.insertAttributeAt(new Attribute("kerneldim" + j), 1);
114        }
115        for (int i = 0; i < sizeTrain; i++) {
116            for (int j = 0; j < reducedDimension; j++) {
117                traindata.instance(i).setValue(j + 1, transformedKernel.get(i, j));
118            }
119        }
120        for (int i = 0; i < sizeTest; i++) {
121            for (int j = 0; j < reducedDimension; j++) {
122                testdata.instance(i).setValue(j + 1, transformedKernel.get(i + sizeTrain, j));
123            }
124        }
125    }
126
127    private PrimitiveMatrix buildKernel(Instances testdata, Instances traindata) {
128        final int kernelDim = traindata.numInstances() + testdata.numInstances();
129
130        Builder<PrimitiveMatrix> kernelBuilder = PrimitiveMatrix.getBuilder(kernelDim, kernelDim);
131        // built upper left quadrant (source, source)
132        for (int i = 0; i < traindata.numInstances(); i++) {
133            for (int j = 0; j < traindata.numInstances(); j++) {
134                kernelBuilder.set(i, j, linearKernel(traindata.get(i), traindata.get(j)));
135            }
136        }
137
138        // built upper right quadrant (source, target)
139        for (int i = 0; i < traindata.numInstances(); i++) {
140            for (int j = 0; j < testdata.numInstances(); j++) {
141                kernelBuilder.set(i, j + traindata.numInstances(),
142                                  linearKernel(traindata.get(i), testdata.get(j)));
143            }
144        }
145
146        // built lower left quadrant (target, source)
147        for (int i = 0; i < testdata.numInstances(); i++) {
148            for (int j = 0; j < traindata.numInstances(); j++) {
149                kernelBuilder.set(i + traindata.numInstances(), j,
150                                  linearKernel(testdata.get(i), traindata.get(j)));
151            }
152        }
153
154        // built lower right quadrant (target, target)
155        for (int i = 0; i < testdata.numInstances(); i++) {
156            for (int j = 0; j < testdata.numInstances(); j++) {
157                kernelBuilder.set(i + traindata.numInstances(), j + traindata.numInstances(),
158                                  linearKernel(testdata.get(i), testdata.get(j)));
159            }
160        }
161        return kernelBuilder.build();
162    }
163
164    private PrimitiveMatrix buildKernelNormMatrix(final int dimTest, final int sizeTrain) {
165        final double trainSquared = 1.0 / (sizeTrain * (double) sizeTrain);
166        final double testSquared = 1.0 / (dimTest * (double) dimTest);
167        final double trainTest = -1.0 / (sizeTrain * (double) dimTest);
168        Builder<PrimitiveMatrix> kernelNormBuilder =
169            PrimitiveMatrix.getBuilder(sizeTrain + dimTest, sizeTrain + dimTest);
170
171        // built upper left quadrant (source, source)
172        for (int i = 0; i < sizeTrain; i++) {
173            for (int j = 0; j < sizeTrain; j++) {
174                kernelNormBuilder.set(i, j, trainSquared);
175            }
176        }
177
178        // built upper right quadrant (source, target)
179        for (int i = 0; i < sizeTrain; i++) {
180            for (int j = 0; j < dimTest; j++) {
181                kernelNormBuilder.set(i, j + sizeTrain, trainTest);
182            }
183        }
184
185        // built lower left quadrant (target, source)
186        for (int i = 0; i < dimTest; i++) {
187            for (int j = 0; j < sizeTrain; j++) {
188                kernelNormBuilder.set(i + sizeTrain, j, trainTest);
189            }
190        }
191
192        // built lower right quadrant (target, target)
193        for (int i = 0; i < dimTest; i++) {
194            for (int j = 0; j < dimTest; j++) {
195                kernelNormBuilder.set(i + sizeTrain, j + sizeTrain, testSquared);
196            }
197        }
198        return kernelNormBuilder.build();
199    }
200
201    private PrimitiveMatrix buildCenterMatrix(final int sizeTest, final int sizeTrain) {
202        Builder<PrimitiveMatrix> centerMatrix =
203            PrimitiveMatrix.getBuilder(sizeTest + sizeTrain, sizeTest + sizeTrain);
204        for (int i = 0; i < centerMatrix.countRows(); i++) {
205            centerMatrix.set(i, i, -1.0 / (sizeTest + sizeTrain));
206        }
207        return centerMatrix.build();
208    }
209
210    private PrimitiveMatrix buildMuMatrix(final int sizeTest,
211                                          final int sizeTrain,
212                                          final double mu)
213    {
214        Builder<PrimitiveMatrix> muMatrix =
215            PrimitiveMatrix.getBuilder(sizeTest + sizeTrain, sizeTest + sizeTrain);
216        for (int i = 0; i < muMatrix.countRows(); i++) {
217            muMatrix.set(i, i, mu);
218        }
219        return muMatrix.build();
220    }
221}
Note: See TracBrowser for help on using the repository browser.