Tags: java, filter, k-means, apache-flink

Apache Flink - filter as termination condition


I have defined a filter for the termination condition of my k-means job. If I run my app, it always computes only one iteration.

I think the problem is here:

DataSet<GeoTimeDataCenter> finalCentroids = loop.closeWith(newCentroids, newCentroids.join(loop).where("*").equalTo("*").filter(new MyFilter()));

or maybe the filter function:

public static final class MyFilter implements FilterFunction<Tuple2<GeoTimeDataCenter, GeoTimeDataCenter>> {

    private static final long serialVersionUID = 5868635346889117617L;

    @Override
    public boolean filter(Tuple2<GeoTimeDataCenter, GeoTimeDataCenter> tuple) throws Exception {
        // keep only the centroids that are identical before and after the iteration
        return tuple.f0.equals(tuple.f1);
    }
}

Best regards, Paul

My full code (MyFilter is the class shown above):

public void run() {   
    //load properties
    Properties pro = new Properties();
    FileSystem fs = null;
    try {
        pro.load(FlinkMain.class.getResourceAsStream("/config.properties"));
        fs = FileSystem.get(new URI(pro.getProperty("hdfs.namenode")),new org.apache.hadoop.conf.Configuration());
    } catch (Exception e) {
        e.printStackTrace();
    }

    int maxIteration = Integer.parseInt(pro.getProperty("maxiterations"));
    String outputPath = fs.getHomeDirectory()+pro.getProperty("flink.output");
    // set up execution environment
    ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    // get input points
    DataSet<GeoTimeDataTupel> points = getPointDataSet(env);
    DataSet<GeoTimeDataCenter> centroids = null;
    try {
        centroids = getCentroidDataSet(env);
    } catch (Exception e1) {
        e1.printStackTrace();
    }
    // set number of bulk iterations for KMeans algorithm
    IterativeDataSet<GeoTimeDataCenter> loop = centroids.iterate(maxIteration);
    DataSet<GeoTimeDataCenter> newCentroids = points
        // compute closest centroid for each point
        .map(new SelectNearestCenter(this.getBenchmarkCounter())).withBroadcastSet(loop, "centroids")
        // count and sum point coordinates for each centroid
        .groupBy(0).reduceGroup(new CentroidAccumulator())
        // compute new centroids from point counts and coordinate sums
        .map(new CentroidAverager(this.getBenchmarkCounter()));
    // feed new centroids back into next iteration with termination condition
    DataSet<GeoTimeDataCenter> finalCentroids = loop.closeWith(newCentroids, newCentroids.join(loop).where("*").equalTo("*").filter(new MyFilter()));
    DataSet<Tuple2<Integer, GeoTimeDataTupel>> clusteredPoints = points
        // assign points to final clusters
        .map(new SelectNearestCenter(-1)).withBroadcastSet(finalCentroids, "centroids");
    // emit result
    clusteredPoints.writeAsCsv(outputPath+"/points", "\n", " ");
    finalCentroids.writeAsText(outputPath+"/centers");
    // execute program
    try {
        env.execute("KMeans Flink");
    } catch (Exception e) {
        e.printStackTrace();
    }
}


Solution

  • I think the problem is the filter function (modulo the code you haven't posted). Flink's termination criterion works as follows: the iteration stops if the provided termination DataSet is empty; otherwise the next iteration is started, provided the maximum number of iterations has not been exceeded.
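
    For illustration, here is a minimal, self-contained sketch of that mechanism (the class name, values, and iteration count are mine, not from your job): each iteration halves every value, and the loop stops as soon as the termination DataSet, here the set of non-zero values, becomes empty.

    import org.apache.flink.api.common.functions.FilterFunction;
    import org.apache.flink.api.common.functions.MapFunction;
    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.ExecutionEnvironment;
    import org.apache.flink.api.java.operators.IterativeDataSet;

    public class CloseWithSketch {
        public static void main(String[] args) throws Exception {
            ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

            // run at most 100 bulk iterations over the values 1..16
            IterativeDataSet<Long> loop = env.generateSequence(1L, 16L).iterate(100);

            // step function: halve every value
            DataSet<Long> halved = loop.map(new MapFunction<Long, Long>() {
                @Override
                public Long map(Long value) {
                    return value / 2;
                }
            });

            // termination DataSet: iterate while at least one value is non-zero
            DataSet<Long> nonZero = halved.filter(new FilterFunction<Long>() {
                @Override
                public boolean filter(Long value) {
                    return value != 0L;
                }
            });

            // stops after 5 iterations, once every value has reached 0
            loop.closeWith(halved, nonZero).print();
        }
    }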

    Flink's filter function keeps only those elements for which the FilterFunction returns true. Thus, with your MyFilter implementation you keep only the centroids that are identical before and after an iteration. Consequently, you obtain an empty DataSet as soon as every centroid has changed, and the iteration terminates. This is exactly the inverse of the actual termination criterion, which should be: continue with k-means as long as at least one centroid has changed.

    Note that simply negating your filter would not help: the inner join on all fields only ever emits pairs of identical centroids, so the negated filter would always produce an empty DataSet and the loop would still stop after one iteration. Instead, you can use a coGroup function in which you emit an element only if there is no matching centroid from the preceding centroid DataSet. This is similar to a left outer join where you discard the non-null matches.

    import java.io.IOException;
    import java.util.ArrayList;
    import java.util.List;

    import org.apache.flink.api.common.functions.CoGroupFunction;
    import org.apache.flink.api.java.DataSet;
    import org.apache.flink.api.java.ExecutionEnvironment;
    import org.apache.flink.core.memory.DataInputView;
    import org.apache.flink.core.memory.DataOutputView;
    import org.apache.flink.types.Key;
    import org.apache.flink.util.Collector;

    public static void main(String[] args) throws Exception {
        // set up the execution environment
        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
    
        DataSet<Element> oldDS = env.fromElements(new Element(1, "test"), new Element(2, "test"), new Element(3, "foobar"));
        DataSet<Element> newDS = env.fromElements(new Element(1, "test"), new Element(3, "foobar"), new Element(4, "test"));
    
        DataSet<Element> filtered = newDS.coGroup(oldDS).where("*").equalTo("*").with(new FilterCoGroup());
    
        filtered.print();
    }
    
    public static class FilterCoGroup implements CoGroupFunction<Element, Element, Element> {

        @Override
        public void coGroup(
                Iterable<Element> newElements,
                Iterable<Element> oldElements,
                Collector<Element> collector) throws Exception {

            // materialize the elements from the previous iteration
            List<Element> persistedElements = new ArrayList<Element>();

            for (Element element : oldElements) {
                persistedElements.add(element);
            }

            // emit only those new elements that have no old counterpart
            for (Element newElement : newElements) {
                boolean contained = false;

                for (Element oldElement : persistedElements) {
                    if (newElement.equals(oldElement)) {
                        contained = true;
                        break;
                    }
                }

                if (!contained) {
                    collector.collect(newElement);
                }
            }
        }
    }
    
    public static class Element implements Key {
        private int id;
        private String name;
    
        public Element(int id, String name) {
            this.id = id;
            this.name = name;
        }
    
        public Element() {
            this(-1, "");
        }
    
        @Override
        public int hashCode() {
            return 31 + 7 * name.hashCode() + 11 * id;
        }
    
        @Override
        public boolean equals(Object obj) {
            if(obj instanceof Element) {
                Element element = (Element) obj;
    
                return id == element.id && name.equals(element.name);
            } else {
                return false;
            }
        }
    
        @Override
        public int compareTo(Object o) {
            if(o instanceof Element) {
                Element element = (Element) o;

                if(id == element.id) {
                    return name.compareTo(element.name);
                } else {
                    return id - element.id;
                }
            } else {
                throw new RuntimeException("Comparing incompatible types.");
            }
        }
    
        @Override
        public void write(DataOutputView dataOutputView) throws IOException {
            dataOutputView.writeInt(id);
            dataOutputView.writeUTF(name);
        }
    
        @Override
        public void read(DataInputView dataInputView) throws IOException {
            id = dataInputView.readInt();
            name = dataInputView.readUTF();
        }
    
        @Override
        public String toString() {
            return "(" + id + "; " + name + ")";
        }
    }
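
    Applied to your k-means job, the termination criterion would then look roughly as follows. This is only a sketch: it assumes that GeoTimeDataCenter can be used as an expression key via where("*") in the same way as Element (e.g. by implementing Key), and ChangedCentroids is a hypothetical helper name.

    public static class ChangedCentroids
            implements CoGroupFunction<GeoTimeDataCenter, GeoTimeDataCenter, GeoTimeDataCenter> {

        @Override
        public void coGroup(
                Iterable<GeoTimeDataCenter> newCentroids,
                Iterable<GeoTimeDataCenter> oldCentroids,
                Collector<GeoTimeDataCenter> out) throws Exception {
            // the key is the whole element, so a missing old partner
            // means this centroid has changed in the current iteration
            if (!oldCentroids.iterator().hasNext()) {
                for (GeoTimeDataCenter centroid : newCentroids) {
                    out.collect(centroid);
                }
            }
        }
    }

    The closeWith call in run() then becomes:

    // iterate while at least one centroid has changed
    DataSet<GeoTimeDataCenter> changedCentroids = newCentroids
        .coGroup(loop).where("*").equalTo("*")
        .with(new ChangedCentroids());
    DataSet<GeoTimeDataCenter> finalCentroids =
        loop.closeWith(newCentroids, changedCentroids);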