|
@@ -87,17 +87,17 @@ class KMeansLocal {
|
|
|
|
|
|
for (int i = 0; i < sampleSize; i++) {
|
|
|
float[] vector = vectors.vectorValue(i);
|
|
|
- int[] neighborOffsets = null;
|
|
|
- int centroidIdx = -1;
|
|
|
+ final int assignment = assignments[i];
|
|
|
+ final int bestCentroidOffset;
|
|
|
if (neighborhoods != null) {
|
|
|
- neighborOffsets = neighborhoods.get(assignments[i]);
|
|
|
- centroidIdx = assignments[i];
|
|
|
+ bestCentroidOffset = getBestCentroidFromNeighbours(centroids, vector, assignment, neighborhoods.get(assignment));
|
|
|
+ } else {
|
|
|
+ bestCentroidOffset = getBestCentroid(centroids, vector);
|
|
|
}
|
|
|
- int bestCentroidOffset = getBestCentroidOffset(centroids, vector, centroidIdx, neighborOffsets);
|
|
|
- if (assignments[i] != bestCentroidOffset) {
|
|
|
+ if (assignment != bestCentroidOffset) {
|
|
|
+ assignments[i] = bestCentroidOffset;
|
|
|
changed = true;
|
|
|
}
|
|
|
- assignments[i] = bestCentroidOffset;
|
|
|
centroidCounts[bestCentroidOffset]++;
|
|
|
for (int d = 0; d < dim; d++) {
|
|
|
nextCentroids[bestCentroidOffset][d] += vector[d];
|
|
@@ -116,23 +116,28 @@ class KMeansLocal {
|
|
|
return changed;
|
|
|
}
|
|
|
|
|
|
- int getBestCentroidOffset(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
|
|
|
+ int getBestCentroidFromNeighbours(float[][] centroids, float[] vector, int centroidIdx, int[] centroidOffsets) {
|
|
|
int bestCentroidOffset = centroidIdx;
|
|
|
- float minDsq;
|
|
|
- if (centroidIdx > 0 && centroidIdx < centroids.length) {
|
|
|
- minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
|
|
|
- } else {
|
|
|
- minDsq = Float.MAX_VALUE;
|
|
|
+ assert centroidIdx >= 0 && centroidIdx < centroids.length;
|
|
|
+ float minDsq = VectorUtil.squareDistance(vector, centroids[centroidIdx]);
|
|
|
+ for (int offset : centroidOffsets) {
|
|
|
+ float dsq = VectorUtil.squareDistance(vector, centroids[offset]);
|
|
|
+ if (dsq < minDsq) {
|
|
|
+ minDsq = dsq;
|
|
|
+ bestCentroidOffset = offset;
|
|
|
+ }
|
|
|
}
|
|
|
+ return bestCentroidOffset;
|
|
|
+ }
|
|
|
|
|
|
- int k = 0;
|
|
|
- for (int j = 0; j < centroids.length; j++) {
|
|
|
- if (centroidOffsets == null || j == centroidOffsets[k]) {
|
|
|
- float dsq = VectorUtil.squareDistance(vector, centroids[j]);
|
|
|
- if (dsq < minDsq) {
|
|
|
- minDsq = dsq;
|
|
|
- bestCentroidOffset = j;
|
|
|
- }
|
|
|
+ int getBestCentroid(float[][] centroids, float[] vector) {
|
|
|
+ int bestCentroidOffset = 0;
|
|
|
+ float minDsq = Float.MAX_VALUE;
|
|
|
+ for (int i = 0; i < centroids.length; i++) {
|
|
|
+ float dsq = VectorUtil.squareDistance(vector, centroids[i]);
|
|
|
+ if (dsq < minDsq) {
|
|
|
+ minDsq = dsq;
|
|
|
+ bestCentroidOffset = i;
|
|
|
}
|
|
|
}
|
|
|
return bestCentroidOffset;
|
|
@@ -271,7 +276,8 @@ class KMeansLocal {
|
|
|
return;
|
|
|
}
|
|
|
|
|
|
- int[] assignments = new int[n];
|
|
|
+ int[] assignments = kMeansIntermediate.assignments();
|
|
|
+ assert assignments.length == n;
|
|
|
float[][] nextCentroids = new float[centroids.length][vectors.dimension()];
|
|
|
for (int i = 0; i < maxIterations; i++) {
|
|
|
if (stepLloyd(vectors, centroids, nextCentroids, assignments, sampleSize, neighborhoods) == false) {
|
|
@@ -291,7 +297,7 @@ class KMeansLocal {
|
|
|
* @param maxIterations the max iterations to shift centroids
|
|
|
*/
|
|
|
public static void cluster(FloatVectorValues vectors, float[][] centroids, int sampleSize, int maxIterations) throws IOException {
|
|
|
- KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids);
|
|
|
+ KMeansIntermediate kMeansIntermediate = new KMeansIntermediate(centroids, new int[vectors.size()], vectors::ordToDoc);
|
|
|
KMeansLocal kMeans = new KMeansLocal(sampleSize, maxIterations);
|
|
|
kMeans.cluster(vectors, kMeansIntermediate);
|
|
|
}
|