#include <ostream>
#include <cstdlib>
#include <divine/algorithm/common.h>

#include <divine/ltlce.h>
#include <divine/visitor.h>
#include <divine/parallel.h>
#include <divine/report.h>
#include <divine/porcp.h>

#ifndef DIVINE_EXPLAIN_H
#define DIVINE_EXPLAIN_H

namespace divine {
namespace algorithm {

//          dumper
//      ----------------

//! Prints a single value to std::cout, one per line.  The bool parameter is
//! accepted (and ignored here) so that dumper< T >( true ) is valid for any T;
//! the dumper< DveTransition > specialisation uses it to select extended output.
template< typename T >
struct dumper {
    dumper( bool = false ) {}

    void operator()( const T &x ) {
        std::cout << x << std::endl;
    }
};

//! Prints a separator line, then every element of the vector with dumper< E >.
template< typename E > struct dumper< std::vector< E > > {
    void operator()( const std::vector< E > &v ) {
        std::cout << " ----------------------------------" << std::endl;
        // Qualified call: do not depend on ADL picking up std::for_each.
        std::for_each( v.begin(), v.end(), dumper< E >() );
    }
};

//! Prints a separator line, then every element of the set with m_minor,
//! which callers may configure (e.g. for extended output) before use.
template< typename E > struct dumper< std::set< E > > {
    dumper< E > m_minor;

    void operator()( const std::set< E > &v ) {
        std::cout << " ----------------------------------" << std::endl;
        std::for_each( v.begin(), v.end(), m_minor );
    }
};

//! Sink with the push_back / push_front interface of a sequence container.
//! Everything pushed is discarded; extract_traces uses it when only the side
//! effects of fast_follow_parents are wanted.
struct NullContainer {
    template< typename T >
    void push_front( T ) {}

    template< typename T >
    void push_back( T ) {}
};

//          sameLocations
//      ----------------------
typedef generator::NDve _DVE;

//! Reports on std::cerr that Explain needs a DVE model, then aborts via assert_die().
// Defined in a header, so `inline` is required to avoid multiple-definition
// link errors when the header is included by more than one translation unit.
inline void notDVE() {
    std::cerr <<  "This binary only supports Explain on DVE models." << std::endl;
    assert_die();
}

//! Returns the legacy DVE system behind \p g.  This generic overload is used
//! for graph types without a dedicated overload; it aborts via notDVE().
template< typename GRAPH >
dve_explicit_system_t *get_dve( GRAPH & ) {
    notDVE();
    return NULL;
}

// The non-template overloads live in a header, so they must be `inline` to
// avoid multiple-definition errors when several translation units include it.

//! DVE graph: the legacy system is the graph's own legacy_system().
inline dve_explicit_system_t *get_dve( _DVE &g ) {
    return static_cast< dve_explicit_system_t* >( g.legacy_system() );
}

//! NonPORGraph adaptor: unwrap and ask the wrapped DVE graph.
inline dve_explicit_system_t *get_dve( divine::algorithm::NonPORGraph< _DVE > &g ) {
    return get_dve( g.g() );
}

template< typename GRAPH >
bool sameLocations( GRAPH &g, typename GRAPH::Node a, typename GRAPH::Node b ) {
    dve_explicit_system_t *dvesys = get_dve( g );
    assert( !dvesys->get_with_property() );
    size_t count = dvesys->get_process_count();

    state_t old_a = g.allocator().legacy_state( a );
    state_t old_b = g.allocator().legacy_state( b );

    for ( size_t proc_id = 0; proc_id < count; proc_id ++ ) {
        if ( dvesys->get_state_of_process( old_a, proc_id ) !=
             dvesys->get_state_of_process( old_b, proc_id ) ) return false;

    }
    return true;
}

//          transitions
//      ----------------------

//! Transition type used for non-DVE graphs, where transitions cannot be
//! recovered.  Comparing two of them for equality aborts.
struct NoTransition {
    bool operator== ( const NoTransition &other ) const {
        assert_die();
        return false;
    }

    //! All NoTransitions are equivalent.  operator< must be irreflexive
    //! (!(a < a)) for std::set / std::sort; returning true broke that.
    bool operator< ( const NoTransition & ) const {
        return false;
    }
};

//! Prints nothing: a NoTransition carries no information.
// `inline`: non-template function defined in a header.
inline std::ostream &operator<<( std::ostream &o, const NoTransition & ) {
    return o;
}


//! Wraps a DVE enabled transition so it can be compared, ordered and printed.
struct DveTransition {
    dve_enabled_trans_t trans;

    DveTransition() {}
    DveTransition( const enabled_trans_t &_trans ): trans( static_cast< const dve_enabled_trans_t& >( _trans ) ) {
        assert( !trans.get_erroneous() );
    }

    DveTransition &operator= ( const DveTransition &other ) {
        if ( this == &other ) return *this; // dve_transition_t screws this up and we need it in nested domain().parallel().run()
        trans = other.trans;
        return *this;
    }

    //! Equal when both have the same number of component transitions and the
    //! components have equal gids, position by position.
    bool operator== ( const DveTransition &other ) const {
        if ( trans.get_count() != other.trans.get_count() ) return false;

        for ( size_int_t i = 0; i < trans.get_count(); i ++ ) {
            if ( trans[ i ]->get_gid() != other.trans[ i ]->get_gid() ) return false;
        }
        return true;
    }

    //! Orders by component count first, then lexicographically by component gid.
    bool operator< ( const DveTransition &other ) const {
        if ( trans.get_count() < other.trans.get_count() ) return true;
        if ( trans.get_count() > other.trans.get_count() ) return false;

        for ( size_int_t i = 0; i < trans.get_count(); i ++ ) {
            if ( trans[ i ]->get_gid() < other.trans[ i ]->get_gid() ) return true;
            if ( trans[ i ]->get_gid() > other.trans[ i ]->get_gid() ) return false;
        }

        return false;
    }

    //! Writes the transition, " / ", the components joined by " & ", and a newline.
    void extPrint( std::ostream &out ) const {
        trans.write( out );
        out << " / ";
        for ( size_int_t i = 0; i < trans.get_count(); i++ ) {
            if ( i != 0 ) out << " & ";
            trans[ i ]->write( out );
        }

        out << std::endl;
    }
};

//! Writes the transition using dve_enabled_trans_t::write.
// `inline`: non-template function defined in a header.
inline std::ostream &operator<<( std::ostream &o, const DveTransition &t ) {
    t.trans.write( o );
    return o;
}

//! Prints a DveTransition: the short form followed by a newline by default,
//! or the extended form (extPrint, which ends the line itself) when ext is set.
template<> struct dumper< DveTransition > {
    bool m_ext;
    dumper( bool ext = false ): m_ext( ext ) {}

    void operator()( const DveTransition &t ) {
        if ( m_ext ) {
            t.extPrint( std::cout );
            return;
        }
        std::cout << t << std::endl;
    }
};


//! Writes the Daikon name of \p symb: "<process>.<name>" when the symbol
//! belongs to a process (process gid != NO_ID), plain "<name>" otherwise.
// `inline`: non-template function defined in a header.
inline void writeVarName( std::ostream &out, dve_symbol_table_t *symbols, dve_symbol_t *symb ) {
    if ( symb->get_process_gid() != NO_ID ) {
        out << symbols->get_process( symb->get_process_gid() )->get_name() << ".";
    }
    out << symb->get_name();
}

//! Writes Daikon declarations for given system.
template< typename GRAPH >
void writeVarDecls( GRAPH &g, std::ostream &out, int ins_count ) {
    dve_symbol_table_t *symbols = get_dve( g )->get_symbol_table();

    std::string posneg[] = { "pos", "neg" };
    for ( int pn = 0; pn < 2; pn++ ) {
        for ( int ins = 0; ins < ins_count; ins ++ ) {
            out << "DECLARE" << std::endl;
            out << "instrument" << ins << "." << posneg[ pn ] << ":::POINT" << std::endl;

            for ( size_t vi = 0; vi < symbols->get_variable_count(); vi ++ ) {
                dve_symbol_t *symb = symbols->get_variable( vi );
                writeVarName( out, symbols, symb );
                if ( symb->is_vector() ) {
                    out << "[]" << std::endl;
                    out << "int[]" << std::endl << "int[]" << std::endl;
                    out << "1[1]" << std::endl;

                } else {
                    out << std::endl;
                    out << "int" << std::endl << "int" << std::endl;
                    out << "1" << std::endl;
                }
            }

            for ( size_int_t ch = 0; ch < symbols->get_channel_count(); ch ++ ) {
                dve_symbol_t *symb = symbols->get_channel( ch );
                if ( symb->get_channel_buffer_size() > 0 ) {
                    if ( symb->get_channel_item_count() > 1 ) {
                        std::cerr << "Warning: Channel " << symb->get_name() << " has more than one value, which is not supported by Daikon. Ignoring it. " << std::endl;
                        continue;
                    }
                    out << symb->get_name() << "[]" << std::endl;
                    out << "int[]" << std::endl << "int[]" << std::endl;
                    out << "1[1]" << std::endl;
                }
            }

            out << std::endl;
        }
    }
}

template< typename GRAPH >
void writeStateVars( GRAPH &g, typename GRAPH::Node n, std::ostream &out ) {
    dve_explicit_system_t *dvesys = get_dve( g );
    dve_symbol_table_t *symbols = dvesys->get_symbol_table();

    for ( size_t vi = 0; vi < symbols->get_variable_count(); vi ++ ) {
        dve_symbol_t *symb = symbols->get_variable( vi );
        writeVarName( out, symbols, symb );

        if ( symb->is_vector() ) {
            out << "[]" << std::endl << "[";
            for ( size_t ix = 0; ix < symb->get_vector_size(); ix ++ ) {
                out << dvesys->get_var_value( g.allocator().legacy_state( n ), symb->get_gid(), ix ) << " ";
            }
            out << "]";

        } else {
            out << std::endl << dvesys->get_var_value( g.allocator().legacy_state( n ), symb->get_gid() );
        }

        out << std::endl << "1" << std::endl;
    }

    for ( size_int_t ch = 0; ch < symbols->get_channel_count(); ch ++ ) {
        dve_symbol_t *symb = symbols->get_channel( ch );
        if ( symb->get_channel_buffer_size() > 0 ) {
            if ( symb->get_channel_item_count() > 1 ) {
                continue;
            }

            out << symb->get_name() << "[]" << std::endl;
            out << "[";
            std::vector< std::vector< int > > chan = dvesys->read_whole_channel( g.allocator().legacy_state( n ), symb->get_gid() );
            for ( std::vector< std::vector< int > >::iterator i = chan.begin(); i != chan.end(); i++ ) {
                out << i->front() << " ";
            }
            out << "]" << std::endl << "1" << std::endl;
        }
    }

    out << std::endl;
}

typedef std::pair< int, int > AssertionId;

template< typename OuterGraph >
struct GraphTypesDVEImpl {
    typedef DveTransition Trans;
    typedef dve_expression_t Expression;

    //! Returns the transition object between the two states.
    static Trans findTransition( OuterGraph &g, typename OuterGraph::Node from, typename OuterGraph::Node to ) {
        if ( ( !from.valid() ) || ( !to.valid() ) ) return Trans();

        dve_explicit_system_t *dvesys = get_dve( g );

        state_t old_from = g.allocator().legacy_state( from );
        state_t old_to = g.allocator().legacy_state( to );

        succ_container_t succs;
        enabled_trans_container_t transitions( *dvesys );
        dvesys->get_succs( old_from, succs, transitions );
        size_t result_ix = -1;
        for ( succ_container_t::iterator i = succs.begin(); i != succs.end(); i ++ ) {
            if ( *i == old_to ) result_ix = i - succs.begin();
        }

        assert( result_ix != -1 );

        // cleanup
        for ( succ_container_t::iterator i = succs.begin(); i != succs.end(); i ++ ) {
            g.release( g.allocator().unlegacy_state( *i ) );
        }

        return DveTransition( transitions[ result_ix ] );
    }

    static AssertionId violatesAssertion( OuterGraph &g, typename OuterGraph::Node n ) {
        dve_explicit_system_t *dvesys = get_dve( g );

        size_t count = dvesys->get_process_count();
        state_t old_node = g.allocator().legacy_state( n );
        bool eval_err = false;

        for ( size_t proc_id = 0; proc_id < count; proc_id ++ ) {
            size_t proc_state = dvesys->get_state_of_process( old_node, proc_id );
            dve_process_t *dve_proc = dynamic_cast< dve_process_t* >( dvesys->get_process( proc_id ) );

            for ( size_int_t assert_ix = 0; assert_ix < dve_proc->get_assertion_count( proc_state ); assert_ix ++ ) {
                dve_expression_t * assert = dve_proc->get_assertion( proc_state, assert_ix );
                if ( !dvesys->eval_expr( assert, old_node, eval_err ) ) {
                    return make_pair( proc_id, assert_ix );
                }
            }
        }
        return make_pair( -1, -1 );
    }

    //          expressions    
    //      ----------------------
    static dve_expression_t *newExpression( OuterGraph &g, std::string str ) {
        return new dve_expression_t( str, get_dve( g ) );
    }

    static int evaluateExpression( OuterGraph &g, Expression *expr, typename OuterGraph::Node n ) {
        bool err = false;
        int res = get_dve( g )->eval_expr( expr, g.allocator().legacy_state( n ), err );
        assert( !err );
        return res;
    }
};

//! Fallback graph helpers for graph types without a DVE specialisation:
//! every operation reports (via notDVE) that Explain needs a DVE model.
template< typename Graph >
struct GraphTypes {
    typedef NoTransition Trans;
    typedef wibble::Unit Expression;

    static AssertionId violatesAssertion( Graph &, typename Graph::Node ) {
        notDVE();
        return AssertionId();
    }

    static Trans findTransition( Graph &, typename Graph::Node, typename Graph::Node ) {
        notDVE();
        return Trans();
    }

    static Expression *newExpression( Graph &, std::string ) {
        notDVE();
        return NULL;
    }

    static int evaluateExpression( Graph &, Expression *, typename Graph::Node ) {
        notDVE();
        return 0;
    }
};

// DVE-backed graph types use the DVE implementation
// (struct inheritance is public by default).
template<>
struct GraphTypes< _DVE > : GraphTypesDVEImpl< _DVE > {};

template<>
struct GraphTypes< divine::algorithm::NonPORGraph< _DVE > >
    : GraphTypesDVEImpl< divine::algorithm::NonPORGraph< _DVE > > {};

template< typename S >
struct GraphTypes< divine::algorithm::PORGraph< _DVE, S > >
    : GraphTypesDVEImpl< divine::algorithm::PORGraph< _DVE, S > > {};


/**
 * Explains counterexamples of safety properties violations based on
 * A. Groce and W. Visser. What Went Wrong: Explaining Counterexamples.
 * In SPIN Workshop on Model Checking of Software, pages 121–135. Springer-Verlag, 2003.
 */
template< typename G, typename Statistics >
struct Explain: Algorithm, DomainWorker< Explain< G, Statistics > >
{
    //      support
    //      -------
    typedef typename G::Node Node;
    typedef Explain< G, Statistics > Us;
    // DVE graphs get GraphTypesDVEImpl; other graphs get abort-only stubs.
    typedef GraphTypes< G > GT;
    typedef typename GT::Trans Transition;
    // The transitions taken along one state trail (built by extract_transitions).
    typedef std::vector< Transition > ActionTrail;
    typedef typename GT::Expression Expression;

    enum {
        // Number of alternative parent slots in Extension::parents.
        // NOTE(review): 0 makes Extension::parents a zero-length array, which is
        // a GNU extension rather than standard C++; the parents[] loops never run.
        NUM_PARENTS = 0,
        // Weight of the default parent in get_parent: it is taken whenever
        // rand() % ( count + PREFER_DEFAULT_PARENT ) < PREFER_DEFAULT_PARENT.
        PREFER_DEFAULT_PARENT = 6
    };

    struct TrailSeedEdge {
        Node first, second;

        //! should we follow ext_parent first?
        bool is_extension;

        TrailSeedEdge( Node f, Node s, bool e ): first( f ), second( s ), is_extension( e ) {};
    };

    //! A positive slice [tp_begin, tp_end) and a negative slice
    //! [tn_begin, tn_end) of state trails, ordered by their combined length.
    struct Transformation {
        size_t kt;

        typedef typename std::deque< Node >::const_iterator Iter;
        Iter tp_begin;
        Iter tp_end;
        //! index of the positive trail, shown as "p<index>" by dump()
        size_t pos_index;

        Iter tn_begin;
        Iter tn_end;
        //! index of the negative trail, shown as "n<index>" by dump()
        size_t neg_index;

        // Initializers listed in declaration order, which is the order they
        // actually run in (the previous ordering triggered -Wreorder).
        Transformation( size_t _kt, Iter _tp_begin, Iter _tp_end, Iter _tn_begin, Iter _tn_end ):
            kt( _kt ), tp_begin( _tp_begin ), tp_end( _tp_end ), pos_index( 0 ),
            tn_begin( _tn_begin ), tn_end( _tn_end ), neg_index( 0 ) {}

        Transformation(): kt( 0 ), pos_index( 0 ), neg_index( 0 ) {}

        bool operator<( const Transformation &other ) const {
            return size() < other.size();
        }

        //! Total number of states in both slices.
        size_t size() const {
            return ( tp_end - tp_begin ) + ( tn_end - tn_begin );
        }

        std::ostream &dump( std::ostream &o, G &graph ) const {
            o << "Transformation size " << size() << std::endl;
            o << "positive part (from p" << pos_index << "):" << std::endl;
            for ( typename Transformation::Iter i = tp_begin; i != tp_end; i ++ ) {
                o << graph.showNode( *i ) << std::endl;
            }

            o << "negative part (from n" << neg_index << "):" << std::endl;
            for ( typename Transformation::Iter i = tn_begin; i != tn_end; i ++ ) {
                o << graph.showNode( *i ) << std::endl;
            }
            return o;
        }
    };

    //! Per-worker state; other workers read it through domain().shared( i )
    //! (see do_shared / reduce_shared).
    struct Shared {
        G g;
        CeShared< Node > ce;
        Node initial;        // start search here
        int initialTable;

        //! the error state we're trying to explain
        Node goal;

        GoalType::Enum goalType;

        //! Which assertion failed (only when the goal is an assertion violation)
        AssertionId assertion;

        //! The last edge of each negative (leading to an error state).
        std::vector< TrailSeedEdge > negatives;

        //! The last edge of positives (leading to a non-error state).
        std::vector< TrailSeedEdge > positives;

        //! The currently processed node in extension algorithm.
        Node ext_current;

        //! The counterexample
        std::vector< Node > ce_trail;

        //! Maximum search depth (from config)
        int depth;

        //! How many positives/negatives were found using extension.
        int ext_stat;

        //! Member-wise copy of every field.
        Shared &operator=( const Shared &other ) {
            g = other.g;
            ce = other.ce;
            initial = other.initial;
            initialTable = other.initialTable;
            goal = other.goal;
            goalType = other.goalType;
            negatives = other.negatives;
            positives = other.positives;
            ext_current = other.ext_current;
            ce_trail = other.ce_trail;
            depth = other.depth;
            ext_stat = other.ext_stat;
            assertion = other.assertion;
            // Falling off the end of a function returning Shared& is undefined behaviour.
            return *this;
        }

        Shared(): initialTable( 1024 ), goalType( static_cast< GoalType::Enum >( GoalType::All ) ),
                  depth( 100 ), ext_stat( 0 ) {}

    } shared;

    //! Per-state data stored next to each state in the hash table.
    struct Extension {
        Node parent;
        int distance;

        //! Expansion of this state in explain has begun.
        //! (get_parent also stamps it with m_iteration.)
        int visited:30;
        bool on_negative:1;

        //! Which state on counterexample this state matches?
        int ext_match;

        //! Used to find the path in extension procedure.
        Node ext_parent;

        Node parents[NUM_PARENTS];

        //! TODO:combine with some other field if possible to save memory
        int ext_has_succ;

        //! Records \p n as a predecessor.  The first valid \p n becomes the
        //! default parent; later distinct ones take the first free slot of
        //! parents[].  \returns true iff \p n was stored.
        bool push_parent( Node n ) {
            if ( !n.valid() )
                return false;

            if ( !parent.valid() ) {
                parent = n;
                return true;
            }

            if ( parent == n )
                return false;

            for ( size_t slot = 0; slot < size_t( NUM_PARENTS ); ++slot ) {
                if ( !parents[ slot ].valid() ) {
                    parents[ slot ] = n;
                    return true;
                }
                if ( parents[ slot ] == n )
                    return false;
            }

            return false;
        }
    };

    // Counterexample helper; set up in _parentTrace and trail_from.
    LtlCE< G, Shared, Extension > ce;
    // Seed edges gathered from every worker's Shared (collect_pos / collect_neg).
    std::vector< TrailSeedEdge > m_pos_starts;
    std::vector< TrailSeedEdge > m_neg_starts;

    // A trail is a sequence of states starting at the initial state
    // (fast_follow_parents pushes states to the front until it reaches it).
    typedef std::deque< Node > Trail;
    std::vector< Trail > m_positives;
    std::vector< Trail > m_negatives;

    // Presumably the instrumentation expressions and their source text;
    // TODO confirm (not used in the part of the file visible here).
    std::vector< Expression* > m_instrumentation;
    std::vector< std::string > m_instr_string;
    // When set, transition_analysis prints the action trails and transition sets.
    bool m_exptest;
    // Incremented after each trail attempt in extract_traces; get_parent
    // stamps Extension::visited with it.
    int m_iteration;

    //! Initial states for extension (behind the depth limit)
    // We need this per-thread, but without the autocopying in Shared.
    // HACK: We rely on runInRing not creating a new thread (and thus clearing this).
    std::vector< Node > m_ext_init;

    //! The domain this worker belongs to.
    Domain< Us > &domain() {
        typedef DomainWorker< Us > Worker;
        return Worker::domain();
    }

    //! The Extension record attached to state \p n.
    Extension &extension( Node n ) {
        Extension &ext = n.template get< Extension >();
        return ext;
    }

    //      reachability & common
    //      ---------------------

    //! During reachability every state is expanded.
    visitor::ExpansionAction reachExpansion( Node ) {
        return visitor::ExpandState;
    }

    //! Records \p f as a parent of \p t; a state that got recorded is made
    //! permanent via visitor::setPermanent.
    void setParent( Node f, Node t ) {
        const bool recorded = extension( t ).push_parent( f );
        if ( recorded )
            visitor::setPermanent( f );
    }

    //! Reachability edge handler: records the parent of \p t; when \p t is a
    //! goal, stores it (with its goal type and, for assertion goals, the
    //! violated assertion) in shared and asks the visitor to terminate.
    visitor::TransitionAction reachTransition( Node f, Node t ) {
        setParent( f, t );

        std::pair< bool, GoalType::Enum > goal = shared.g.isGoal( t );
        if ( !goal.first )
            return visitor::FollowTransition;

        shared.goal = t;
        shared.goalType = goal.second;
        if ( goal.second == GoalType::Assert ) {
            shared.assertion = GT::violatesAssertion( shared.g, t );
            assert_neq( -1, shared.assertion.first );
        }
        return visitor::TerminateOnTransition;
    }

    //! Parallel reachability phase: explores from the initial state, recording
    //! parents via reachTransition, which stores the first goal state found in
    //! shared.goal and requests termination.
    void _reachability() { // parallel
        typedef visitor::Setup< G, Us, Table, Statistics,
            &Us::reachTransition,
            &Us::reachExpansion > Setup;
        typedef visitor::Parallel< Setup, Us, Hasher > Visitor;

        m_initialTable = &shared.initialTable;
        // NOTE(review): the Hasher is constructed with sizeof( Extension ),
        // presumably so it accounts for the per-state Extension data; confirm.
        Visitor visitor( shared.g, *this, *this,
                         Hasher( sizeof( Extension ) ), &table() );
        visitor.exploreFrom( shared.g.initial() );
    }

    //      explain algorithm
    //      -----------------

    //! Cap on the number of trails collected (had_enough_traces,
    //! keep_adding_traces) and printed (dump).
    static const int MAX_TRAILS = 50;

    //! Explain-phase expansion: expands every state and marks the configured
    //! start state (shared.initial) as visited when it is expanded.
    visitor::ExpansionAction expansion( Node st ) {
        const bool is_start = ( st == shared.initial );
        if ( is_start )
            extension( st ).visited = true;
        return visitor::ExpandState;
    }

    //! Bit flags telling check_state how to classify an edge.
    struct CheckType {
        enum Enum {
            Positive = 1 << 0,
            Negative = 1 << 1,
            Extension = 1 << 2,
            OnlyOne = 1 << 3
        };
    };

    //! \returns true if the goal of \c n is the same as the goal stored in
    //! shared: same goal type and, for assertion goals, the same assertion.
    bool same_goal( std::pair< bool, GoalType::Enum > prop, Node n ) {
        if ( !prop.first || prop.second != shared.goalType )
            return false;
        if ( prop.second != GoalType::Assert )
            return true;
        return GT::violatesAssertion( shared.g, n ) == shared.assertion;
    }

    //! Records edge f -> t as a seed when t's goal status matches \p flags:
    //! Negative wants t to be the same goal as shared's, otherwise t must not
    //! be a goal.  The edge goes to shared.positives when Positive is set,
    //! else to shared.negatives; Extension marks it and bumps ext_stat;
    //! OnlyOne additionally requires ext_no_succ_yet( f ).
    void check_state( Node f, Node t, int flags ) {
        std::pair< bool, GoalType::Enum > prop = shared.g.isGoal( t );
        const bool want_goal = ( flags & CheckType::Negative ) != 0;

        if ( prop.first != want_goal )
            return;
        if ( want_goal && !same_goal( prop, t ) )
            return;
        if ( ( flags & CheckType::OnlyOne ) && !ext_no_succ_yet( f ) )
            return;

        std::vector< TrailSeedEdge > &bucket =
            ( flags & CheckType::Positive ) ? shared.positives : shared.negatives;
        const bool via_extension = ( flags & CheckType::Extension ) != 0;

        bucket.push_back( TrailSeedEdge( f, t, via_extension ) );
        if ( via_extension )
            shared.ext_stat ++;
    }

    //! Explain-phase edge handler.  Records the parent of \p t and its
    //! distance, and classifies the edge when t (negative candidate) or f
    //! (positive candidate) is in the same locations as the goal.  States
    //! beyond shared.depth, or already visited, are queued in m_ext_init and
    //! not expanded; all others are marked visited and expanded.
    visitor::TransitionAction transition( Node f, Node t ) {
        setParent( f, t );

        if ( f.valid() ) {
            // first time we reach t: it lies one step further than f
            if ( !extension( t ).visited )
                extension( t ).distance = extension( f ).distance + 1;

            const bool t_at_goal_loc = sameLocations( shared.g, t, shared.goal );
            const bool f_at_goal_loc = sameLocations( shared.g, f, shared.goal );

            if ( t_at_goal_loc )
                check_state( f, t, CheckType::Negative );

            // a positive must have a successor and a good one; 't' is added as
            // well, even though it's rather irrelevant, as the definition says
            if ( f_at_goal_loc && !shared.g.isGoal( f ).first )
                check_state( f, t, CheckType::Positive );
        }

        const bool too_deep = extension( t ).distance > shared.depth;
        if ( too_deep || extension( t ).visited ) {
            m_ext_init.push_back( t );
            return visitor::ForgetTransition;
        }

        extension( t ).visited = true;
        return visitor::ExpandTransition;
    }

    //! Parallel explain phase: explores from a copy of shared.initial with
    //! transition() / expansion(), which collect seed edges into
    //! shared.positives and shared.negatives.
    void _explain() { // parallel
        m_initialTable = &shared.initialTable;
        typedef visitor::Setup< G, Us, Table, Statistics,
                &Us::transition,
                &Us::expansion > Setup;
        visitor::Parallel< Setup, Us, Hasher >
            vis( shared.g, *this, *this, hasher, &table() );

        // exploreFrom deletes the initial state
        // (hence a duplicate is passed and shared.initial itself is kept)
        Node init = shared.g.allocator().duplicate_state( shared.initial );
        assert( init.valid() );
        vis.exploreFrom( init );
    }

    //      functional style
    //      ----------------

    // got tired of iterating over shareds
    //! Calls member \p fn on every peer's Shared, in peer order.
    template< typename F >
    void do_shared( F fn ) {
        int peer = 0;
        while ( peer < domain().peers() ) {
            (this->*fn)( domain().shared( peer ) );
            ++peer;
        }
    }

    //! Left fold of member \p fn over every peer's Shared, starting from \p start.
    template< typename I >
    I reduce_shared( I (Us::*fn)(I, Shared&), I start ) {
        I acc = start;
        for ( int peer = 0; peer < domain().peers(); ++peer )
            acc = (this->*fn)( acc, domain().shared( peer ) );
        return acc;
    }

    //! Left fold: returns fn( ... fn( fn( start, *begin ), *(begin + 1) ) ..., *(end - 1) ),
    //! or \p start for an empty range.
    template< typename I, typename F, typename Iter >
    I reduce( F fn, I start, Iter begin, Iter end ) {
        I acc = start;
        while ( begin != end ) {
            acc = fn( acc, *begin );
            ++begin;
        }
        return acc;
    }


    //! Removes from \p cont every element that does not occur in the sorted
    //! range [begin, end).  Linear merge over both sequences.
    template< typename T, typename Iterator >
    static void intersection_inplace( std::set< T > &cont, Iterator begin, Iterator end ) {
        typename std::set< T >::iterator keep = cont.begin();
        Iterator other = begin;

        while ( keep != cont.end() && other != end ) {
            if ( *keep < *other ) {
                cont.erase( keep++ );   // absent from the other range
            } else if ( *other < *keep ) {
                ++other;
            } else {
                ++keep;
                ++other;
            }
        }

        // cut off the remains
        cont.erase( keep, cont.end() );
    }

    //! Replaces \p target by the set of elements common to every range in
    //! [begin, end); each range must be sorted.  An empty [begin, end) leaves
    //! \p target untouched.
    template< typename InIter, typename T >
    static void big_intersection( InIter begin, InIter end, std::set< T > &target ) {
        if ( begin == end ) return;

        // Seed with the first range.  The result never depended on the old
        // contents of target (it was intersected with the first range), and
        // intersecting the seed with its own source range again is a no-op,
        // so start the intersections at the second range.
        InIter i = begin;
        target.clear();
        target.insert( i->begin(), i->end() );

        for ( ++i; i != end; ++i ) {
            intersection_inplace( target, i->begin(), i->end() );
        }
    }


    //      main
    //      ----

    //! Runs LtlCE's parent trace on this worker's table.
    void _parentTrace() {
        // XXX this will be done many times needlessly
        ce.setup( shared.g, shared );
        ce._parentTrace( *this, hasher, equal, table() );
    }

    //! Builds a linear trail from \p n using LtlCE::linear.
    void trail_from( Node n ) {
        Shared &sh = shared;
        sh.ce.initial = n;
        ce.setup( sh.g, sh );
        ce.linear( domain(), *this );
    }

    void collect_pos( Shared &sh ) {
        m_pos_starts.insert( m_pos_starts.end(), sh.positives.begin(), sh.positives.end() );
    }

    void collect_neg( Shared &sh ) {
        m_neg_starts.insert( m_neg_starts.end(), sh.negatives.begin(), sh.negatives.end() );
    }

    struct GoalInfo {
        Node n;
        GoalType::Enum type;
        AssertionId assertion;

        GoalInfo() {}
        GoalInfo( Node _n, GoalType::Enum _type, AssertionId _assertion ): n( _n ), type( _type ), assertion( _assertion ) {}
    };

    GoalInfo collectGoal( GoalInfo val, Shared &sh ) {
        if ( val.n.valid() ) return val;
        return GoalInfo( sh.goal, sh.goalType, sh.assertion );
    }

    int sum_ext_stat( int val, Shared &sh ) {
        return val + sh.ext_stat;
    }

    //! Chooses the parent to follow from a state with extension data \p node_ext.
    //! For extension trails a valid ext_parent always wins.  Otherwise the
    //! default parent is chosen with weight PREFER_DEFAULT_PARENT against one
    //! unit per alternative in parents[]; an alternative already stamped with
    //! the current m_iteration falls back to the default.  The chosen state is
    //! stamped with m_iteration.  \p was_default is false only when an
    //! alternative parent was taken.
    Node get_parent( Extension &node_ext, bool is_extension, bool &was_default ) {
        if ( is_extension && node_ext.ext_parent.valid() ) {
            was_default = true;
            return node_ext.ext_parent;
        }

        // valid alternatives form a prefix of parents[]
        size_t alternatives = 0;
        while ( alternatives < size_t( NUM_PARENTS ) && node_ext.parents[ alternatives ].valid() )
            ++alternatives;

        int pick = rand() % ( alternatives + PREFER_DEFAULT_PARENT );

        Node chosen = node_ext.parent;
        was_default = true;
        if ( pick >= PREFER_DEFAULT_PARENT ) {
            Node alt = node_ext.parents[ pick - PREFER_DEFAULT_PARENT ];
            if ( extension( alt ).visited != m_iteration ) {
                chosen = alt;
                was_default = false;
            }
        }

        extension( chosen ).visited = m_iteration;
        return chosen;
    }

    // Singlethreaded version of LtlCE::parentTrace which avoids
    // the overhead with copying shared etc.. MPI version won't be that easy.
    //! Walks parent links (via get_parent) from \p start back to the initial
    //! state, pushing every state to the front of \p target, so \p target ends
    //! up ordered from the initial state to \p start.  For negative trails
    //! (\p is_positive false) each non-initial state is marked on_negative.
    //! \param is_extension If this is a trail from the extension algorithm.
    //! \returns false if at some point we used a different parent than the default
    template< typename T >
    bool fast_follow_parents( Node start, T &target, bool is_extension, bool is_positive ) {
        Node root = shared.g.initial();
        Node cur = start;
        bool only_default = true;

        for ( ;; ) {
            target.push_front( cur );
            if ( equal( cur, root ) )
                break;

            Extension &ext = foreign_extension( cur );
            if ( !is_positive )
                ext.on_negative = true;

            bool took_default;
            cur = get_parent( ext, is_extension, took_default );
            only_default = only_default && took_default;
        }

        return only_default;
    }

    //! Positive collection stops once MAX_TRAILS trails exist; negatives never stop.
    bool had_enough_traces( size_t count, bool doing_positives ) {
        return doing_positives && count >= MAX_TRAILS;
    }

    //! Positives are always stored; negatives only while fewer than MAX_TRAILS exist.
    bool keep_adding_traces( size_t count, bool doing_positives ) {
        return doing_positives || count < MAX_TRAILS;
    }

    //! Turns seed edges from \p source into state trails appended to \p target.
    //! For each seed, parents are followed from i->first back to the initial
    //! state and i->second is appended.  Positives stop after MAX_TRAILS trails
    //! and drop trails whose last state is marked on_negative (goes_bad).
    //! Negatives keep being traced (marking on_negative along the path) but
    //! only the first MAX_TRAILS are stored.  A seed is re-traced until
    //! fast_follow_parents reports a trail that used only default parents;
    //! m_iteration advances after every attempt.
    void extract_traces( std::vector< TrailSeedEdge > &source, std::vector< Trail > &target, bool doing_positives ) {
        for ( typename std::vector< TrailSeedEdge >::iterator i = source.begin(); i != source.end(); i++ ) {
            bool did_last = false;
            while ( !did_last && !had_enough_traces( target.size(), doing_positives ) ) {
                if ( keep_adding_traces( target.size(), doing_positives ) ) {
                    Trail trail;
                    did_last = fast_follow_parents( i->first, trail, i->is_extension, doing_positives );

                    // add last state
                    trail.push_back( i->second );
                    if ( !doing_positives ) {
                        foreign_extension( i->second ).on_negative = true;
                        //std::cerr << "NN " << shared.g.showNode( i->second ) << std::endl;
                    }

                    // if asked to, check if it's a prefix of something
                    // (positives whose last state lies on a negative trail are dropped)
                    if ( !( doing_positives && goes_bad( trail ) ) ) {
                        target.push_back( trail );
                    }
                } else {
                    // just mark the states along the path
                    NullContainer sink;
                    fast_follow_parents( i->first, sink, i->is_extension, doing_positives );
                    did_last = true;
                }


                m_iteration ++;
            }
        }
    }

    int owner( Node n ) {
        return hasher( n ) % this->peers();
    }

    Extension &foreign_extension( Node n ) {
        Node stored = domain().parallel().instance( owner( n ) ).table().get( n ).key;
        return extension( stored );
    }

    bool goes_bad( const Trail &pos ) {
        Node last = pos.back();
        return foreign_extension( last ).on_negative;
    }

    //! Prints one state on its own line; returns 0.
    int dump( Node n ) {
        std::cout << shared.g.showNode( n ) << std::endl;
        return 0;
    }

    //! Prints a "Trail:----" header, then each state of \p t on its own line; returns 0.
    int dump( const Trail &t ) {
        std::cout << "Trail:----" << std::endl;
        for ( typename Trail::const_iterator i = t.begin(); i != t.end(); i++ ) {
            Node x = *i;
            std::cout << shared.g.showNode( x ) << std::endl;
        }
        return 0;
    }

    //! Prints at most MAX_TRAILS trails, each prefixed by \p type and its
    //! index; returns 0.
    int dump( const std::vector< Trail > &trails, char type ) {
        for ( typename std::vector< Trail >::const_iterator i = trails.begin(); i != trails.end(); i++ ) {
            // there may be a lot of negatives; indices 0 .. MAX_TRAILS - 1
            // (a `>` test here let MAX_TRAILS + 1 trails through)
            if ( i - trails.begin() >= MAX_TRAILS ) break;
            std::cout << type << i - trails.begin() << " ";
            dump( *i );
        }
        // Flowing off the end of a non-void function is undefined behaviour.
        return 0;
    }

    //! For each state trail in \p source, stores in the matching slot of
    //! \p target the transitions between consecutive states (found with
    //! GT::findTransition).  Trails with fewer than two states yield an empty
    //! action trail.
    void extract_transitions( const std::vector< Trail > &source, std::vector< ActionTrail > &target ) {
        target.resize( source.size() );

        for ( typename std::vector< Trail >::const_iterator tr = source.begin(); tr != source.end(); tr++ ) {
            ActionTrail &target_trail = target.at( tr - source.begin() );

            // size() - 1 would wrap around for an empty trail (and the reserve
            // would then throw), so handle short trails up front.
            if ( tr->size() < 2 ) continue;

            const size_t steps = tr->size() - 1;
            target_trail.reserve( steps );

            for ( size_t a = 0; a < steps; a ++ ) {
                Transition action = GT::findTransition( shared.g, tr->at( a ), tr->at( a + 1 ) );
                target_trail.push_back( action );
            }
        }
    }

    //! Sorts \p v in ascending order according to T's operator<.
    template< typename T >
    static void sort_vec( std::vector< T > &v ) {
        std::sort( v.begin(), v.end() );
    }

    // Transition analysis: compares which transitions (actions) occur in the
    // positive trails versus the negative trails and prints these sets:
    //   ALL   - transitions present in every trail of the given kind
    //   ONLY  - transitions that occur in trails of one kind only
    //   CAUSE - transitions in both ALL and ONLY for the same kind
    // With m_exptest set, it also prints every action trail and the TRANS sets
    // (all transitions that appear anywhere).
    void transition_analysis() {
        std::vector< ActionTrail > pos_atrails;
        std::vector< ActionTrail > neg_atrails;

        // state trails -> action trails (one Transition per consecutive state pair)
        extract_transitions( m_positives, pos_atrails );
        extract_transitions( m_negatives, neg_atrails );

        if ( m_exptest ) {
            std::cout << "Positive actions:" << std::endl;
            for_each( pos_atrails.begin(), pos_atrails.end(), dumper< ActionTrail >() );
            std::cout << "Negative actions:" << std::endl;
            for_each( neg_atrails.begin(), neg_atrails.end(), dumper< ActionTrail >() );
        }

        // prepare for set operations - sort
        // (each action trail is sorted in place; presumably big_intersection
        // below expects sorted ranges - NOTE(review): confirm its contract)
        for_each( pos_atrails.begin(), pos_atrails.end(), Us::template sort_vec< Transition > );
        for_each( neg_atrails.begin(), neg_atrails.end(), Us::template sort_vec< Transition > );

        // trans -- all transitions which appear somewhere
        std::set< Transition > trans_pos;
        std::set< Transition > trans_neg;
        for ( typename std::vector< ActionTrail >::const_iterator i = pos_atrails.begin(); i != pos_atrails.end(); i++ ) {
            trans_pos.insert( i->begin(), i->end() );
        }

        for ( typename std::vector< ActionTrail >::const_iterator i = neg_atrails.begin(); i != neg_atrails.end(); i++ ) {
            trans_neg.insert( i->begin(), i->end() );
        }

        // dumpset prints elements with the default dumper< Transition >;
        // detail_dump uses one constructed with ext = true
        dumper< std::set< Transition > > dumpset;
        dumper< std::set< Transition > > detail_dump;
        detail_dump.m_minor = dumper< Transition >( true );
        if ( m_exptest ) {
            std::cout << "TRANS: (pos)";
            dumpset( trans_pos );
            std::cout << "TRANS: (neg)";
            dumpset( trans_neg );
        }

        // all -- the transitions present in every trail
        std::set< Transition > all_pos;
        std::set< Transition > all_neg;
        big_intersection( pos_atrails.begin(), pos_atrails.end(), all_pos );
        big_intersection( neg_atrails.begin(), neg_atrails.end(), all_neg );
        std::cout << "ALL: (pos)";
        detail_dump( all_pos );
        std::cout << "ALL: (neg)";
        detail_dump( all_neg );

        // only -- the transitions occuring in only one type of traces
        // (std::set iterates in sorted order, as set_difference requires)
        std::set< Transition > only_pos;
        std::set< Transition > only_neg;
        std::set_difference( trans_pos.begin(), trans_pos.end(), trans_neg.begin(), trans_neg.end(), std::inserter( only_pos, only_pos.end() ) );
        std::set_difference( trans_neg.begin(), trans_neg.end(), trans_pos.begin(), trans_pos.end(), std::inserter( only_neg, only_neg.end() ) );
        std::cout << "ONLY: (pos)";
        detail_dump( only_pos );
        std::cout << "ONLY: (neg)";
        detail_dump( only_neg );

        // cause -- the transitions occuring in all traces of one type and only in traces of that type
        std::set< Transition > cause_pos;
        std::set< Transition > cause_neg;
        std::set_intersection( all_pos.begin(), all_pos.end(), only_pos.begin(), only_pos.end(), std::inserter( cause_pos, cause_pos.begin() ) );
        std::set_intersection( all_neg.begin(), all_neg.end(), only_neg.begin(), only_neg.end(), std::inserter( cause_neg, cause_neg.begin() ) );
        std::cout << "CAUSE: (pos)";
        detail_dump( cause_pos );
        std::cout << "CAUSE: (neg)";
        detail_dump( cause_neg );
    }

    void ianalysis_group( std::ostream &out, std::vector< Trail > &trails, std::string suffix ) {
        for ( typename std::vector< Trail >::iterator trail = trails.begin(); trail != trails.end(); trail++ ) {
            for ( typename Trail::iterator st = trail->begin(); st != trail->end(); st ++ ) {
                for ( typename std::vector< Expression* >::iterator i = m_instrumentation.begin(); i != m_instrumentation.end(); i++ ) {
                    if ( GT::evaluateExpression( shared.g, *i, *st ) ) {
                        out << "instrument" << i - m_instrumentation.begin() << "." << suffix << ":::POINT" << std::endl;

                        writeStateVars( shared.g, *st, out );
                    }
                }
            }
        }
    }

    // Invariant analysis. It does nothing unless instrumentation expressions
    // were configured. Otherwise it builds a trace with writeVarDecls plus
    // ianalysis_group for positives ("pos") and negatives ("neg"), echoes it
    // to stdout when m_exptest is set, writes it to "explain.dtrace" in the
    // working directory, and runs Daikon on it via `java -cp daikon.jar`.
    // A non-zero exit status from system() only produces a message on stderr.
    void invariant_analysis() {
        if ( m_instrumentation.empty() )
            return;

        std::ostringstream trace;
        writeVarDecls( shared.g, trace, m_instrumentation.size() );
        ianalysis_group( trace, m_positives, "pos" );
        ianalysis_group( trace, m_negatives, "neg" );

        const std::string contents = trace.str();
        if ( m_exptest )
            std::cout << contents;

        // flush our own output before the external process starts writing
        std::cout.flush();
        std::cerr.flush();

        // run Daikon
        {
            std::ofstream file( "explain.dtrace" );
            file << contents;
        } // file is closed here, before Daikon reads it

        const int status = system( "java -cp daikon.jar daikon.Daikon explain.dtrace" );
        if ( status != 0 )
            std::cerr << "Error running Daikon." << std::endl;
    }

    // Computes the Transformation between a positive and a negative trail.
    //  - p: length of the longest common prefix (states compared with ==).
    //  - u: length of the longest common suffix, compared with sameLocations
    //    (process locations only). The second-to-last state of `pos` is
    //    aligned with the last state of `neg`, i.e. the last state of the
    //    positive is excluded.
    // Returns the differing middle parts of both trails as iterator ranges
    // into `pos` and `neg`. Each range starts at the last prefix state (if
    // any) and ends with the first state of the matched suffix (if any). The
    // ranges alias the argument trails, which must therefore outlive the
    // result. An empty range (end, end) is used when the prefix covers the
    // whole trail.
    // NOTE(review): assumes pos is non-empty; pos.size() - 1 wraps otherwise.
    Transformation transform( const Trail &pos, const Trail &neg ) {
        // max prefix
        size_t p;
        for ( p = 0; p < std::min( pos.size(), neg.size() ); p++ ) {
            if ( !( pos[p] == neg[p] ) ) break;
        }

        // max suffix
        size_t u; // length of the suffix, excluding the last state of the positive (starting at the end of negative)
        for ( u = 0; u < std::min( pos.size() - 1, neg.size() ); u++ ) {
            if ( u + p > std::min( pos.size() - 1, neg.size() ) ) break;
            if ( !sameLocations( shared.g, pos[ ( pos.size() - 1 ) - 1 - u ], neg[ ( neg.size() - 1 ) - u ] ) ) break;
        }

        typename Trail::const_iterator tp_begin, tp_end, tn_begin, tn_end;
        if ( p <= pos.size() - ( u + 1 ) ) {
            // include the last state of prefix (if there is a prefix)
            // it may be more informative if it's there. In the paper, the example has it, the definition
            // does not. So we can pick to our liking.
            tp_begin = pos.begin() + std::max( 0, (int)p - 1 );
            // matched suffix of pos is [size-1-u, size-2]; keep its first state
            tp_end = pos.end() - u;
        } else {
            // empty
            tp_begin = pos.end();
            tp_end = pos.end();
        }

        if ( p <= neg.size() - u ) {
            tn_begin = neg.begin() + std::max( 0, (int)p - 1 );
            // Matched suffix of neg is [size-u, size-1]; keep its first state,
            // as for pos above. With no matched suffix (u == 0) the range must
            // end at neg.end(). The former "neg.end() - ( u - 1 )" wrapped to
            // neg.end() + 1 there, one past the end.
            tn_end = neg.end() - ( u > 0 ? u - 1 : 0 );
        } else {
            tn_begin = neg.end();
            tn_end = neg.end();
        }

        return Transformation( p, tp_begin, tp_end, tn_begin, tn_end );
    }

    void transformation_analysis() {
        std::vector< Transformation > transformations;
        for ( typename std::vector< Trail >::iterator pos = m_positives.begin(); pos != m_positives.end(); pos ++ ) {
            for ( typename std::vector< Trail >::iterator neg = m_negatives.begin(); neg != m_negatives.end(); neg ++ ) {
                Transformation t = transform( *pos, *neg );
                if ( t.size() == 0 ) continue;
                t.pos_index = pos - m_positives.begin();
                t.neg_index = neg - m_negatives.begin();

                transformations.push_back( t );
            }
        }

        // transformations use iterators, don't touch the negatives from now on!

        // sort
        std::sort( transformations.begin(), transformations.end() );

        // print
        for ( typename std::vector< Transformation >::const_iterator i = transformations.begin(); i != transformations.end(); i ++ ) {
            if ( ( i - transformations.begin() >= 3 ) ) break;
            i->dump( std::cout, shared.g );
        }
    }

    // Expansion hook of the extension search. It clears ext_has_succ on the
    // state being expanded (so ext_no_succ_yet can claim it again) and always
    // asks the visitor to expand the state.
    visitor::ExpansionAction ext_expansion( Node state ) {
        extension( state ).ext_has_succ = 0;
        return visitor::ExpandState;
    }

    // Follows a transition into `t` only the first time `t` is reached.
    // It marks `t` as visited and expands it; later arrivals are forgotten.
    visitor::TransitionAction ext_continue( Node t ) {
        if ( extension( t ).visited )
            return visitor::ForgetTransition;

        extension( t ).visited = true;
        return visitor::ExpandTransition;
    }

    // Atomically switches ext_has_succ of `from` from 0 to 1. Returns true only
    // for the call that performed the switch; every later call returns false
    // until the flag is reset to 0 (ext_expansion does that). It uses the GCC
    // __sync compare-and-swap builtin.
    bool ext_no_succ_yet( Node from ) {
        const bool claimed = __sync_bool_compare_and_swap( &extension( from ).ext_has_succ, 0, 1 );
        return claimed;
    }

    // Handles a transition f -> t while f still matches an inner position of
    // the counterexample trail. If t has the same process locations as the
    // next counterexample state, and f has not yet accepted a successor, t
    // records the new match position and f as its parent. On reaching the
    // last counterexample state, (f, t) is checked as a negative extension.
    // Every other transition is forgotten.
    visitor::TransitionAction ext_inner_trans( Node f, Node t ) {
        const int new_match = extension( f ).ext_match + 1;
        Node t_match = shared.ce_trail[ new_match ];

        if ( !sameLocations( shared.g, t, t_match ) )
            return visitor::ForgetTransition;

        // we got a candidate, but f may follow at most one of them
        if ( !ext_no_succ_yet( f ) )
            return visitor::ForgetTransition;

        extension( t ).ext_match = new_match;
        extension( t ).ext_parent = f;

        if ( extension( t ).ext_match == shared.ce_trail.size() - 1 )
            check_state( f, t, CheckType::Negative | CheckType::Extension );

        return ext_continue( t );
    }

    // Transition hook of the extension search.
    //  - Invalid `f` (the starting state is queued with an empty parent):
    //    continue into t.
    //  - The next match is still inside the counterexample trail: delegate to
    //    ext_inner_trans.
    //  - f matched the last trail state: if f is not a goal, check (f, t) as a
    //    positive extension (only one is wanted). The transition is then
    //    forgotten.
    //  - Beyond that: forget the transition.
    visitor::TransitionAction ext_transition( Node f, Node t ) {
        if ( !f.valid() )
            return ext_continue( t );

        int new_match = extension( f ).ext_match + 1;
        if ( new_match < shared.ce_trail.size() )
            return ext_inner_trans( f, t );

        // positives need one more action
        if ( new_match == shared.ce_trail.size() && !shared.g.isGoal( f ).first )
            check_state( f, t, CheckType::Positive | CheckType::Extension | CheckType::OnlyOne );

        return visitor::ForgetTransition;
    }

    // performs a "simple DFS" from current state
    void _extension_from_state() {
        m_initialTable = &shared.initialTable;
        typedef visitor::Setup< G, Us, Table, Statistics,
            &Us::ext_transition,
            &Us::ext_expansion > Setup;
        typedef visitor::Parallel< Setup, Us, Hasher > Visitor;

        Visitor visitor( shared.g, *this, *this,
                         Hasher( sizeof( Extension ) ), &table() );
        assert( shared.ext_current.valid() );
        if ( visitor.owner( shared.ext_current ) == this->globalId() )
            visitor.queue( Blob(), shared.ext_current );
        visitor.processQueue();
    }

    // walk all states and start the extension procedure from those which
    // were just behind the depth limit
    void _extension() {
        for ( typename std::vector< Node >::iterator i = m_ext_init.begin(); i != m_ext_init.end(); i++ ) {
            shared.ext_current = *i;

            for ( int ext_index = shared.depth + 1; ext_index < shared.ce_trail.size(); ext_index ++ ) {
                if ( !sameLocations( shared.g, shared.ext_current, shared.ce_trail[ ext_index ] ) ) continue;

                extension( shared.ext_current ).ext_match = ext_index;
                domain().parallel().run( shared, &Us::_extension_from_state );
            }
        }

        m_ext_init.clear();
    }

    // One explanation round starting from `init`. It clears the shared
    // positive and negative buffers, runs _explain on all workers, then
    // collects negatives followed by positives.
    void explain_step( Node init ) {
        shared.positives.clear();
        shared.negatives.clear();
        shared.initial = init;

        domain().parallel().run( shared, &Us::_explain );

        do_shared( &Us::collect_neg );
        do_shared( &Us::collect_pos );
    }

    // One extension round. It runs _extension on the workers via runInRing,
    // then collects negatives and positives, in that order.
    void extension_step() {
        domain().parallel().runInRing( shared, &Us::_extension );
        do_shared( &Us::collect_neg );
        do_shared( &Us::collect_pos );
    }

    // Entry point of the Explain algorithm.
    //  1. Reachability search for a goal state (counterexample). Prints the
    //     safety banner.
    //  2. If a goal is found: builds the counterexample trail, then walks it
    //     backwards from the second-to-last state. Each step runs one explain
    //     round and one extension round.
    //  3. Extracts the negative and positive traces, runs the transition,
    //     invariant and transformation analyses, and prints a report whose
    //     start is marked by the "Explain-Depth:" line.
    //  4. Releases the trail's states.
    // Returns result().
    Result run() {
        srand( 1 ); // fixed seed, presumably for reproducible runs

        StopWatch reach_watch, explain_watch, extension_watch, atran_watch, ainv_watch, atransf_watch;
        double explain_total = 0, extension_total = 0;

        // find first counterexample
        reach_watch.start();
        std::cerr << "  searching for counterexample... \t" << std::flush;
        shared.ext_stat = 0;

        domain().parallel().run( shared, &Us::_reachability );

        GoalInfo goal_pair = reduce_shared( &Us::collectGoal, GoalInfo() );
        std::cerr << "done" << std::endl;
        safetyBanner( !goal_pair.n.valid() );

        if ( goal_pair.n.valid() ) {
            shared.goal = goal_pair.n;
            shared.goalType = goal_pair.type;
            shared.assertion = goal_pair.assertion;
            // trail_from fills `trail` through ce.m_linearTrace; the pointer is
            // detached again right after so it never dangles
            std::vector< Node > trail;
            ce.m_linearTrace = &trail;
            trail_from( goal_pair.n );
            ce.m_linearTrace = NULL;
            shared.ce_trail = trail;

            // NOTE(review): reach_watch is only stopped on this path; it is
            // also only reported on this path.
            reach_watch.stop();

            // start explanation
            std::cerr << "  searching for positives and negatives... " << std::endl;
            assert( trail.size() >= 2 );
            assert( goal_pair.n == trail.back() );

            int k = trail.size();
            int i = k - 1 - 1; // index of one before last

            // walk the counterexample backwards, one explain + one extension
            // round per state
            while ( i >= 0 ) {
                // explain
                explain_watch.start();
                explain_step( trail[i] );
                explain_watch.stop();

                // ext
                extension_watch.start();
                extension_step();
                extension_watch.stop();

                -- i;

                // NOTE(review): if StopWatch::duration() is cumulative across
                // start/stop pairs, these sums double-count earlier
                // iterations. Confirm StopWatch semantics.
                explain_total += explain_watch.duration();
                extension_total += extension_watch.duration();
            }

            // collect negatives first so we can check prefixes when collecting positives
            m_iteration = 2;   // start with a higher value to avoid conflict with true/false used in previous search
            extract_traces( m_neg_starts, m_negatives, false );
            extract_traces( m_pos_starts, m_positives, true );

            if ( m_exptest ) {
                std::cout << "positives: " << std::endl;
                dump( m_positives, 'p' );
                std::cout << "negatives: " << std::endl;
                dump( m_negatives, 'n' );
            }

            // the analyses
            atran_watch.start();
            transition_analysis();
            atran_watch.stop();

            ainv_watch.start();
            invariant_analysis();
            ainv_watch.stop();

            atransf_watch.start();
            transformation_analysis();
            atransf_watch.stop();

            // report
            // "Explain-Depth:" is used to detect the start of report
            std::cout << "Explain-Depth: " << shared.depth << std::endl;
            std::cout << "Explain-Max-Trails: " << MAX_TRAILS << std::endl;
            std::cout << "Positives-Count: " << m_positives.size() << std::endl;
            std::cout << "Negatives-Count: " << m_negatives.size() << std::endl;
            std::cout << "Found-Goal-Type: " << GoalType::print( goal_pair.type ) << std::endl;
            std::cout << "Extension-Stat: " << reduce_shared( &Us::sum_ext_stat, 0 ) << std::endl;
            std::cout << "Reachability-Time: " << reach_watch.duration() << std::endl;
            std::cout << "Explain-Time: " << explain_total << std::endl;
            std::cout << "Extension-Time: " << extension_total << std::endl;
            std::cout << "Transition-Analysis-Time: " << atran_watch.duration() << std::endl;
            std::cout << "Invariant-Analysis-Time: " << ainv_watch.duration() << std::endl;
            std::cout << "Transformation-Analysis-Time: " << atransf_watch.duration() << std::endl;

            for ( typename std::vector< std::string >::iterator si = m_instr_string.begin(); si != m_instr_string.end(); si ++ ) {
                std::cout << "Instrumentation." << si - m_instr_string.begin() << ": " << *si << std::endl;
            }

            // release the trail
            // (shared.ce_trail still holds copies of these Nodes)
            for ( typename std::vector< Node >::iterator it = trail.begin(); it != trail.end(); it++ ) {
                shared.g.release( *it );
            }
            trail.clear();
        }

        return result();
    }


    // Initialises the graph. With a Config it also becomes the master,
    // copies the explain options (initial table size, exptest flag, depth),
    // and compiles each instrumentation string into an Expression, which is
    // owned by this object and freed in the destructor. Without a Config only
    // the graph is initialised.
    Explain( Config *c = NULL ): Algorithm( c, sizeof( Extension ) ) {
        initGraph( shared.g );

        // keep a pointer to showNode so the function is always available in the debugger
        std::string ( G::*fn_p )( Node ) = &G::showNode;

        if ( !c )
            return;

        becomeMaster( &shared, workerCount( c ) );
        shared.initialTable = c->initialTableSize();
        m_exptest = c->m_exptest;
        shared.depth = c->m_expdepth;

        m_instr_string = c->m_instrumentation;
        for ( size_t k = 0; k < c->m_instrumentation.size(); ++k )
            m_instrumentation.push_back( GT::newExpression( shared.g, c->m_instrumentation[ k ] ) );
    }

    // Frees the instrumentation expressions created in the constructor,
    // front to back, then empties the list.
    ~Explain() {
        for ( size_t k = 0; k < m_instrumentation.size(); ++k )
            safe_delete( m_instrumentation[ k ] );
        m_instrumentation.clear();
    }
};

}
}

#endif