CoordinatedBolt and LinearDRPCTopologyBuilder


package backtype.storm.drpc;

import backtype.storm.generated.GlobalStreamId;
import backtype.storm.Config;
import java.util.Collection;
import backtype.storm.Constants;
import backtype.storm.generated.Grouping;
import backtype.storm.task.IOutputCollector;
import backtype.storm.task.OutputCollector;
import backtype.storm.task.TopologyContext;
import backtype.storm.topology.IRichBolt;
import backtype.storm.topology.OutputFieldsDeclarer;
import backtype.storm.tuple.Fields;
import backtype.storm.tuple.Tuple;
import backtype.storm.utils.TimeCacheMap;
import backtype.storm.utils.Utils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.apache.log4j.Logger;
import static backtype.storm.utils.Utils.get;
import static backtype.storm.utils.Utils.tuple;

/**
 * Bolt that wraps another {@link IRichBolt} (the delegate) and keeps, per
 * request id, a {@code TrackingInfo} record with report/tuple counters and
 * the number of tuples emitted to each downstream task. The id is always
 * read from field 0 of a tuple.
 */
public class CoordinatedBolt implements IRichBolt {
    // Shared log4j logger for this class.
    public static Logger LOG = Logger.getLogger(CoordinatedBolt.class);

    /**
     * Callback interface a delegate bolt may implement. checkFinishId tests
     * the delegate with {@code instanceof FinishedCallback} once the
     * completion condition for an id holds.
     */
    public static interface FinishedCallback {
        // id: the request id whose processing is considered finished.
        // NOTE(review): presumably invoked from checkFinishId; the call site is not visible -- confirm.
        void finishedId(Object id);

    /**
     * Serializable setting that selects how many source reports the bolt
     * waits for. In prepare, {@code singleCount == true} sets the expected
     * number of reports to 1; otherwise it is taken from the task count of a
     * source component.
     */
    public static class SourceArgs implements Serializable {
        // true: expect exactly one report; false: expect one per source task.
        public boolean singleCount;

        // Protected: instances are obtained through single() / all().
        protected SourceArgs(boolean singleCount) {
            this.singleCount = singleCount;

        /** @return a SourceArgs with singleCount set to true. */
        public static SourceArgs single() {
            return new SourceArgs(true);

        /** @return a SourceArgs with singleCount set to false. */
        public static SourceArgs all() {
            return new SourceArgs(false);
        /** @return a debug string of the form {@code <Single: true|false>}. */
        public String toString() {
            return "<Single: " + singleCount + ">";

    /**
     * Collector handed to the delegate bolt in prepare. It forwards emits to
     * the real collector and records, per request id (field 0 of the emitted
     * values), how many tuples went to each target task.
     * Non-static inner class: it reads the enclosing bolt's _tracked map.
     */
    public class CoordinatedOutputCollector extends OutputCollector {
        // The real collector that emits are forwarded to.
        IOutputCollector _delegate;

        /** @param delegate collector that actually performs the emits. */
        public CoordinatedOutputCollector(IOutputCollector delegate) {
            _delegate = delegate;

        /**
         * Emits through the delegate, then counts one tuple for every task
         * id the delegate reports as a recipient.
         *
         * @return the task ids returned by the delegate's emit.
         */
        public List<Integer> emit(String stream, Collection<Tuple> anchors, List<Object> tuple) {
            List<Integer> tasks = _delegate.emit(stream, anchors, tuple);
            updateTaskCounts(tuple.get(0), tasks);
            return tasks;

        /**
         * Counts one tuple for the given task, then forwards the direct emit
         * to the delegate.
         */
        public void emitDirect(int task, String stream, Collection<Tuple> anchors, List<Object> tuple) {
            updateTaskCounts(tuple.get(0), Arrays.asList(task));
            _delegate.emitDirect(task, stream, anchors, tuple);

        /**
         * Reads the request id from field 0 and enters a block synchronized
         * on _tracked.
         * NOTE(review): the body of the synchronized block is not visible here.
         */
        public void ack(Tuple tuple) {
            Object id = tuple.getValue(0);
            synchronized(_tracked) {

        /**
         * Reads the request id from field 0 and enters a block synchronized
         * on _tracked.
         * NOTE(review): the body of the synchronized block is not visible here.
         */
        public void fail(Tuple tuple) {
            Object id = tuple.getValue(0);
            synchronized(_tracked) {
        // NOTE(review): body not visible; presumably forwards the error to the delegate -- confirm.
        public void reportError(Throwable error) {

        /**
         * Increments, for each task in {@code tasks}, the emitted-tuple
         * counter stored in the TrackingInfo of {@code id}.
         * NOTE(review): assumes _tracked already holds an entry for id
         * (a missing entry would make _tracked.get(id) null); also, unlike
         * other visible accesses, this one is not inside synchronized(_tracked)
         * -- confirm both are intended.
         */
        private void updateTaskCounts(Object id, List<Integer> tasks) {
            Map<Integer, Integer> taskEmittedTuples = _tracked.get(id).taskEmittedTuples;
            for(Integer task: tasks) {
                // Missing tasks start from 0 (default passed to Utils.get).
                int newCount = get(taskEmittedTuples, task, 0) + 1;
                taskEmittedTuples.put(task, newCount);

    // How many source reports to wait for; null disables that part of the
    // completion check (see the _sourceArgs==null test in checkFinishId).
    private SourceArgs _sourceArgs;
    // Component whose PrepareRequest.ID_STREAM tuple marks an id as received;
    // when null, execute marks every new id as received immediately.
    private String _idComponent;
    // The wrapped bolt.
    private IRichBolt _delegate;
    // Expected report count per id; assigned in prepare only when _sourceArgs is non-null.
    private Integer _numSourceReports;
    // Tasks that checkFinishId sends a direct (id, count) tuple to.
    // NOTE(review): presumably filled in prepare; the add call is not visible -- confirm.
    private List<Integer> _countOutTasks = new ArrayList<Integer>();;
    // The real collector passed to prepare; used for the coordination emits.
    private OutputCollector _collector;
    // Per-id tracking state, created in prepare with the topology message
    // timeout; the map object also serves as the lock for tracking updates.
    private TimeCacheMap<Object, TrackingInfo> _tracked;

    /** Mutable per-request-id bookkeeping stored in _tracked. */
    public static class TrackingInfo {
        // Compared against _numSourceReports in checkFinishId.
        int reportCount = 0;
        // Compared against receivedTuples in checkFinishId.
        int expectedTupleCount = 0;
        int receivedTuples = 0;
        // Target task id -> number of tuples emitted to it for this request id.
        Map<Integer, Integer> taskEmittedTuples = new HashMap<Integer, Integer>();
        // Set to true in execute: at creation when _idComponent is null,
        // otherwise when the id tuple from _idComponent arrives.
        boolean receivedId = false;
        // Multi-line debug dump of the counters, one "name: value" per line.
        public String toString() {
            return "reportCount: " + reportCount + "\n" +
                   "expectedTupleCount: " + expectedTupleCount + "\n" +
                   "receivedTuples: " + receivedTuples + "\n" +

    /**
     * Wraps {@code delegate} with no source arguments and no id component.
     *
     * @param delegate the bolt to wrap.
     */
    public CoordinatedBolt(IRichBolt delegate) {
        this(delegate, null, null);

    /**
     * @param delegate    the bolt to wrap.
     * @param sourceArgs  how many source reports to expect; may be null.
     * @param idComponent component that supplies the id tuple; may be null.
     */
    public CoordinatedBolt(IRichBolt delegate, SourceArgs sourceArgs, String idComponent) {
        _sourceArgs = sourceArgs;
        _delegate = delegate;
        _idComponent = idComponent;

    /**
     * Creates the tracking map, prepares the delegate with a
     * CoordinatedOutputCollector wrapped around {@code collector}, walks the
     * tasks of the target components, and works out how many source reports
     * to expect per id.
     *
     * @param config    topology configuration; TOPOLOGY_MESSAGE_TIMEOUT_SECS is read from it.
     * @param context   used to look up this bolt's targets, sources and component tasks.
     * @param collector the real collector; kept for the coordination emits.
     */
    public void prepare(Map config, TopologyContext context, OutputCollector collector) {
        // Tracking map is constructed with the message timeout (in seconds).
        _tracked = new TimeCacheMap<Object, TrackingInfo>(Utils.toInteger(config.get(Config.TOPOLOGY_MESSAGE_TIMEOUT_SECS)));
        _collector = collector;
        // The delegate only ever sees the counting wrapper, never the raw collector.
        _delegate.prepare(config, context, new CoordinatedOutputCollector(collector));
        // NOTE(review): the inner loop body is not visible; presumably each
        // task is added to _countOutTasks -- confirm.
        for(String component: Utils.get(context.getThisTargets(),
                                        new HashMap<String, Grouping>())
                                        .keySet()) {
            for(Integer task: context.getComponentTasks(component)) {
        if(_sourceArgs!=null) {
            if(_sourceArgs.singleCount) {
                _numSourceReports = 1;
            } else {
                // Use the task count of a source component other than _idComponent.
                Iterator<GlobalStreamId> it = context.getThisSources().keySet().iterator();
                while(it.hasNext()) {
                    // NOTE(review): the initializer is missing on the next line;
                    // presumably the component id of it.next() -- confirm.
                    String sourceComponent =;
                    if(_idComponent==null || !sourceComponent.equals(_idComponent)) {
                        _numSourceReports = context.getComponentTasks(sourceComponent).size();

    /**
     * Under the _tracked lock, tests whether the request {@code id} is
     * complete and, if so, sends every task in _countOutTasks a direct
     * (id, count) tuple on COORDINATED_STREAM_ID, where count is the number
     * of tuples this bolt emitted to that task for the id.
     *
     * The visible part of the condition requires the id to have been
     * received and either _sourceArgs to be null or both the report count to
     * match _numSourceReports and the expected and received tuple counts to
     * be equal.
     * NOTE(review): the start of the condition, the operator joining its two
     * inner alternatives, and the body of the instanceof branch are not
     * visible here -- confirm against the full source.
     */
    private void checkFinishId(Object id) {
        synchronized(_tracked) {
            TrackingInfo track = _tracked.get(id);
                    && track.receivedId 
                    && (_sourceArgs==null
                       track.reportCount==_numSourceReports &&
                       track.expectedTupleCount == track.receivedTuples)) {
                if(_delegate instanceof FinishedCallback) {
                Iterator<Integer> outTasks = _countOutTasks.iterator();
                while(outTasks.hasNext()) {
                    // NOTE(review): the initializer is missing on the next line;
                    // presumably outTasks.next() -- confirm.
                    int task =;
                    // Tasks that received nothing for this id are reported with a count of 0.
                    int numTuples = get(track.taskEmittedTuples, task, 0);
                    _collector.emitDirect(task, Constants.COORDINATED_STREAM_ID, tuple(id, numTuples));

    /**
     * Handles one incoming tuple. The request id is field 0. A TrackingInfo
     * is created for ids not seen before, then the tuple is classified:
     * <ul>
     *   <li>an id tuple from _idComponent on PrepareRequest.ID_STREAM marks
     *       the id as received;</li>
     *   <li>a tuple on COORDINATED_STREAM_ID (only considered when
     *       _sourceArgs is non-null) carries an Integer count in field 1;</li>
     *   <li>anything else takes the final else branch.</li>
     * </ul>
     * The first two cases set checkFinish.
     * NOTE(review): the start of the first condition, the body of the second
     * synchronized block, the else branch and the checkFinish branch are not
     * visible here -- confirm against the full source.
     *
     * @param tuple the incoming tuple.
     */
    public void execute(Tuple tuple) {
        Object id = tuple.getValue(0);
        TrackingInfo track;
        synchronized(_tracked) {
            track = _tracked.get(id);
            if(track==null) {
                track = new TrackingInfo();
                // With no id component there is no id tuple to wait for.
                if(_idComponent==null) track.receivedId = true;
                _tracked.put(id, track);
        boolean checkFinish = false;
                && tuple.getSourceComponent().equals(_idComponent)
                && tuple.getSourceStreamId().equals(PrepareRequest.ID_STREAM)) {
            synchronized(_tracked) {
                track.receivedId = true;
            checkFinish = true;
        } else if(_sourceArgs!=null
                && tuple.getSourceStreamId().equals(Constants.COORDINATED_STREAM_ID)) {
            int count = (Integer) tuple.getValue(1);
            synchronized(_tracked) {
            checkFinish = true;
        } else {            
        if(checkFinish) {

    // Bolt shutdown hook.
    // NOTE(review): body not visible; presumably forwards to the delegate's cleanup -- confirm.
    public void cleanup() {

    /**
     * Declares the coordination stream COORDINATED_STREAM_ID with the fields
     * ("id", "count"); these match the (id, numTuples) tuples emitted by
     * checkFinishId via emitDirect.
     * NOTE(review): the {@code true} argument is presumably the "direct
     * stream" flag, and the delegate's own streams are presumably declared
     * here too, but neither is visible -- confirm.
     */
    public void declareOutputFields(OutputFieldsDeclarer declarer) {
        declarer.declareStream(Constants.COORDINATED_STREAM_ID, true, new Fields("id", "count"));



package backtype.storm.drpc;

import backtype.storm.Constants;
import backtype.storm.ILocalDRPC;
import backtype.storm.drpc.CoordinatedBolt.FinishedCallback;
import backtype.storm.drpc.CoordinatedBolt.SourceArgs;
import backtype.storm.generated.StormTopology;
import backtype.storm.generated.StreamInfo;
import backtype.storm.topology.BasicBoltExecutor;
import backtype.storm.topology.IBasicBolt;
import backtype.storm.topology.IRichBolt;
import backtype.storm.topology.InputDeclarer;
import backtype.storm.topology.OutputFieldsGetter;
import backtype.storm.topology.TopologyBuilder;
import backtype.storm.tuple.Fields;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;

// Need a "final bolt" method that does a fields grouping on the first field of the previous stream.
// Should PrepareRequest emit to a special stream to indicate which task in the last bolt is responsible for a given id?
// -- What if it's shuffle grouping all the way through? We need to enforce that the last bolt does a fields grouping on the id.
/**
 * Builds a DRPC topology as a linear chain: a DRPCSpout, a PrepareRequest
 * bolt, the user bolts in the order they were added (each wrapped in a
 * CoordinatedBolt), then JoinResult and ReturnResults bolts.
 */
public class LinearDRPCTopologyBuilder {    
    // DRPC function name passed to the DRPCSpout.
    String _function;
    // User bolts of the chain, in order.
    List<Component> _components = new ArrayList<Component>();
    /** @param function the DRPC function name handed to the DRPCSpout. */
    public LinearDRPCTopologyBuilder(String function) {
        _function = function;
    /**
     * Creates a Component for {@code bolt} and returns a declarer through
     * which the caller states how it consumes the previous component's output.
     * NOTE(review): the line that appends the component to _components is
     * not visible here -- confirm.
     *
     * @param bolt        the bolt to add.
     * @param parallelism parallelism stored on the Component.
     * @return a declarer bound to the new component.
     */
    public LinearDRPCInputDeclarer addBolt(IRichBolt bolt, int parallelism) {
        Component component = new Component(bolt, parallelism);
        return new InputDeclarerImpl(component);
    /** Same as {@code addBolt(bolt, 1)}. */
    public LinearDRPCInputDeclarer addBolt(IRichBolt bolt) {
        return addBolt(bolt, 1);
    /** Wraps the basic bolt in a BasicBoltExecutor and adds it with the given parallelism. */
    public LinearDRPCInputDeclarer addBolt(IBasicBolt bolt, int parallelism) {
        return addBolt(new BasicBoltExecutor(bolt), parallelism);

    /** Same as {@code addBolt(bolt, 1)} for a basic bolt. */
    public LinearDRPCInputDeclarer addBolt(IBasicBolt bolt) {
        return addBolt(bolt, 1);
    /**
     * Builds the topology with a DRPCSpout constructed from the function
     * name and the supplied local DRPC instance.
     *
     * @param drpc the ILocalDRPC passed to the spout.
     */
    public StormTopology createLocalTopology(ILocalDRPC drpc) {
        return createTopology(new DRPCSpout(_function, drpc));
    /** Builds the topology with a DRPCSpout constructed from the function name only. */
    public StormTopology createRemoteTopology() {
        return createTopology(new DRPCSpout(_function));
    /**
     * Assembles the whole chain around the given spout:
     * spout -> PrepareRequest -> user bolts (each wrapped in a
     * CoordinatedBolt) -> JoinResult -> ReturnResults.
     *
     * The last user bolt must declare exactly one output stream with exactly
     * two fields, otherwise a RuntimeException is thrown.
     * NOTE(review): several statements are incomplete in this view (the
     * continuations of the setBolt calls, closing braces) -- confirm the
     * details flagged below against the full source.
     *
     * @param spout the DRPC spout feeding the chain.
     * @return the topology produced by the TopologyBuilder.
     */
    private StormTopology createTopology(DRPCSpout spout) {
        final String SPOUT_ID = "spout";
        final String PREPARE_ID = "prepare-request";
        TopologyBuilder builder = new TopologyBuilder();
        builder.setSpout(SPOUT_ID, spout);
        // NOTE(review): the input declaration chained onto this call is not visible.
        builder.setBolt(PREPARE_ID, new PrepareRequest())
        // i is declared outside the loop because it is reused for the ids of the trailing bolts.
        int i=0;
        for(; i<_components.size();i++) {
            Component component = _components.get(i);
            // First bolt: no source args; second bolt: a single report;
            // later bolts: reports from all source tasks.
            SourceArgs source;
            if(i==0) {
                source = null;
            } else if (i==1) {
                source = SourceArgs.single();
            } else {
                source = SourceArgs.all();
            // Only a last bolt that implements FinishedCallback subscribes to the id stream.
            String idComponent = null;
            if(i==_components.size()-1 && component.bolt instanceof FinishedCallback) {
                idComponent = PREPARE_ID;
            // NOTE(review): the id and parallelism arguments of setBolt are not visible.
            InputDeclarer declarer = builder.setBolt(
                    new CoordinatedBolt(component.bolt, source, idComponent),
            if(idComponent!=null) {
                declarer.fieldsGrouping(idComponent, PrepareRequest.ID_STREAM, new Fields("request"));
            // A first bolt without explicit declarations reads the args stream with a none grouping.
            if(i==0 && component.declarations.size()==0) {
                declarer.noneGrouping(PREPARE_ID, PrepareRequest.ARGS_STREAM);
            } else {
                // Apply the user's declarations against the previous component in the chain.
                String prevId;
                if(i==0) {
                    prevId = PREPARE_ID;
                } else {
                    prevId = boltId(i-1);
                for(InputDeclaration declaration: component.declarations) {
                    declaration.declare(prevId, declarer);
            // Every bolt after the first also receives the previous bolt's coordination stream.
            if(i>0) {
                declarer.directGrouping(boltId(i-1), Constants.COORDINATED_STREAM_ID); 
        // NOTE(review): with no components added, get(size()-1) is get(-1) and will throw.
        IRichBolt lastBolt = _components.get(_components.size()-1).bolt;
        // NOTE(review): lastBolt is not passed to the getter in the visible
        // code; presumably lastBolt.declareOutputFields(getter) belongs here -- confirm.
        OutputFieldsGetter getter = new OutputFieldsGetter();
        Map<String, StreamInfo> streams = getter.getFieldsDeclaration();
        if(streams.size()!=1) {
            throw new RuntimeException("Must declare exactly one stream from last bolt in LinearDRPCTopology");
        String outputStream = streams.keySet().iterator().next();
        List<String> fields = streams.get(outputStream).get_output_fields();
        if(fields.size()!=2) {
            throw new RuntimeException("Output stream of last component in LinearDRPCTopology must contain exactly two fields. The first should be the request id, and the second should be the result.");

        // After the loop i equals _components.size(), so boltId(i-1) is the last user bolt.
        // JoinResult is grouped on the first output field of the last bolt and
        // on the "request" field of PrepareRequest's return stream.
        builder.setBolt(boltId(i), new JoinResult(PREPARE_ID))
                .fieldsGrouping(boltId(i-1), outputStream, new Fields(fields.get(0)))
                .fieldsGrouping(PREPARE_ID, PrepareRequest.RETURN_STREAM, new Fields("request"));
        // NOTE(review): as visible, this reuses the same boltId(i) as JoinResult
        // and its input declaration is missing; presumably i is incremented in between -- confirm.
        builder.setBolt(boltId(i), new ReturnResults())
        return builder.createTopology();
    /**
     * @param index zero-based position in the chain.
     * @return the component id for that position: "bolt" followed by the index.
     */
    private static String boltId(int index) {
        return "bolt" + index;
    /** One user bolt of the chain with its parallelism and queued input declarations. */
    private static class Component {
        public IRichBolt bolt;
        public int parallelism;
        // Declarations applied in createTopology once the previous component's id is known.
        public List<InputDeclaration> declarations = new ArrayList<InputDeclaration>();
        public Component(IRichBolt bolt, int parallelism) {
            this.bolt = bolt;
            this.parallelism = parallelism;
    /**
     * A deferred grouping declaration: the grouping is chosen when the bolt
     * is added, and applied when createTopology supplies the previous
     * component's id.
     */
    private static interface InputDeclaration {
        // prevComponent: id of the component to subscribe to; declarer: target of the grouping call.
        public void declare(String prevComponent, InputDeclarer declarer);
    /**
     * LinearDRPCInputDeclarer bound to one Component. Each grouping method
     * passes an InputDeclaration to addDeclaration and returns {@code this}
     * so calls can be chained. The variants that take a streamId forward the
     * named stream of the previous component to the matching InputDeclarer
     * method.
     * NOTE(review): the declare bodies of the no-argument variants
     * (globalGrouping, shuffleGrouping, noneGrouping, allGrouping,
     * directGrouping) are not visible here; presumably they call the
     * one-argument InputDeclarer method on prevComponent -- confirm.
     */
    private class InputDeclarerImpl implements LinearDRPCInputDeclarer {
        // The component that declarations are recorded for.
        Component _component;
        public InputDeclarerImpl(Component component) {
            _component = component;
        /** Queues a fields grouping on the previous component using {@code fields}. */
        public LinearDRPCInputDeclarer fieldsGrouping(final Fields fields) {
            addDeclaration(new InputDeclaration() {
                public void declare(String prevComponent, InputDeclarer declarer) {
                    declarer.fieldsGrouping(prevComponent, fields);
            return this;

        /** Queues a fields grouping on stream {@code streamId} of the previous component. */
        public LinearDRPCInputDeclarer fieldsGrouping(final String streamId, final Fields fields) {
            addDeclaration(new InputDeclaration() {
                public void declare(String prevComponent, InputDeclarer declarer) {
                    declarer.fieldsGrouping(prevComponent, streamId, fields);
            return this;

        /** Queues a global grouping (declare body not visible). */
        public LinearDRPCInputDeclarer globalGrouping() {
            addDeclaration(new InputDeclaration() {
                public void declare(String prevComponent, InputDeclarer declarer) {
            return this;

        /** Queues a global grouping on stream {@code streamId} of the previous component. */
        public LinearDRPCInputDeclarer globalGrouping(final String streamId) {
            addDeclaration(new InputDeclaration() {
                public void declare(String prevComponent, InputDeclarer declarer) {
                    declarer.globalGrouping(prevComponent, streamId);
            return this;

        /** Queues a shuffle grouping (declare body not visible). */
        public LinearDRPCInputDeclarer shuffleGrouping() {
            addDeclaration(new InputDeclaration() {
                public void declare(String prevComponent, InputDeclarer declarer) {
            return this;

        /** Queues a shuffle grouping on stream {@code streamId} of the previous component. */
        public LinearDRPCInputDeclarer shuffleGrouping(final String streamId) {
            addDeclaration(new InputDeclaration() {
                public void declare(String prevComponent, InputDeclarer declarer) {
                    declarer.shuffleGrouping(prevComponent, streamId);
            return this;

        /** Queues a none grouping (declare body not visible). */
        public LinearDRPCInputDeclarer noneGrouping() {
            addDeclaration(new InputDeclaration() {
                public void declare(String prevComponent, InputDeclarer declarer) {
            return this;

        /** Queues a none grouping on stream {@code streamId} of the previous component. */
        public LinearDRPCInputDeclarer noneGrouping(final String streamId) {
            addDeclaration(new InputDeclaration() {
                public void declare(String prevComponent, InputDeclarer declarer) {
                    declarer.noneGrouping(prevComponent, streamId);
            return this;

        /** Queues an all grouping (declare body not visible). */
        public LinearDRPCInputDeclarer allGrouping() {
            addDeclaration(new InputDeclaration() {
                public void declare(String prevComponent, InputDeclarer declarer) {
            return this;

        /** Queues an all grouping on stream {@code streamId} of the previous component. */
        public LinearDRPCInputDeclarer allGrouping(final String streamId) {
            addDeclaration(new InputDeclaration() {
                public void declare(String prevComponent, InputDeclarer declarer) {
                    declarer.allGrouping(prevComponent, streamId);
            return this;

        /** Queues a direct grouping (declare body not visible). */
        public LinearDRPCInputDeclarer directGrouping() {
            addDeclaration(new InputDeclaration() {
                public void declare(String prevComponent, InputDeclarer declarer) {
            return this;

        /** Queues a direct grouping on stream {@code streamId} of the previous component. */
        public LinearDRPCInputDeclarer directGrouping(final String streamId) {
            addDeclaration(new InputDeclaration() {
                public void declare(String prevComponent, InputDeclarer declarer) {
                    declarer.directGrouping(prevComponent, streamId);
            return this;
        // NOTE(review): body not visible; presumably appends to _component.declarations -- confirm.
        private void addDeclaration(InputDeclaration declaration) {
