cp-library-cpp

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub suisen-cp/cp-library-cpp

✔ Tree Isomorphism Classification
(library/tree/tree_isomorphism_classification.hpp)

Tree Isomorphism Classification

Depends on

Verified with

Code

#ifndef SUISEN_ROOTED_TREE_ISOMORPHISM_CLASSIFICATION
#define SUISEN_ROOTED_TREE_ISOMORPHISM_CLASSIFICATION

#include <algorithm>
#include <cassert>
#include <deque>
#include <map>
#include <optional>
#include <random>
#include <utility>
#include <vector>

#include "library/string/trie_map.hpp"
#include "library/util/hashes.hpp"
#include "library/tree/find_centroid.hpp"

namespace suisen {
    // Empty tag base: RootedTreeClassifier only accepts ID handlers derived from it
    // (enforced with std::is_base_of_v in its template parameter list).
    namespace internal::tree_classification {
        struct IDHandlerBase {};
    } // namespace internal::tree_classification

    /**
     * @brief Exact ID handler: a key is the list of child ids, and two keys receive the same id
     *        iff they are equal as multisets (the list is sorted before lookup).
     *
     * Keys with exactly one child are resolved through the flat table `mp1`; all other keys
     * (including the empty key of a leaf) are resolved through a trie over the sorted ids.
     */
    struct IDHandlerNaive : internal::tree_classification::IDHandlerBase {
        using base_type = internal::tree_classification::IDHandlerBase;
        using key_type = std::vector<int>;
    private:
        static constexpr int None = -1;
        // Trie node; `id` is the id assigned to the sorted child-id sequence ending here.
        struct TrieNode : MapTrieNode<int> {
            int id = None;
        };
        // mp1[c]: id of a root whose only child has id c (None if not assigned yet).
        std::vector<int> mp1{};
        MapTrie<TrieNode> mp{};
        int next_id = 0;

        // Grow mp1 so that index `id` is addressable; new slots are None.
        void ensure_mp1(int id) {
            const int required = id + 1;
            if (required > int(mp1.size())) mp1.resize(required, None);
        }
        // Return the id stored in `slot`, assigning a fresh one first if it is still None.
        int resolve(int& slot) {
            if (slot == None) slot = next_id++;
            return slot;
        }
    public:
        IDHandlerNaive() = default;

        /**
         * @brief Id of the multiset of child ids `ch_ids` (taken by value; sorted locally).
         */
        int get_id(key_type ch_ids) {
            if (ch_ids.size() != 1) {
                std::sort(ch_ids.begin(), ch_ids.end());
                return resolve(mp.add(ch_ids).id);
            }
            const int only_child = ch_ids.front();
            ensure_mp1(only_child);
            return resolve(mp1[only_child]);
        }
        // Append child id `id` to `key`.
        void add_child(key_type& key, int id) const {
            key.push_back(id);
        }
        // Remove one occurrence of child id `id` from `key`; it must be present.
        void rem_child(key_type& key, int id) const {
            const auto pos = std::find(key.begin(), key.end(), id);
            assert(pos != key.end());
            key.erase(pos);
        }
    };

    /**
     * @brief Hash-based ID handler: a key is the component-wise sum (mod 2^64) of the random
     *        weight vectors of the children, so adding/removing a child is O(hash_num).
     *
     * Each newly seen key gets the next id together with a random weight vector whose
     * components are all non-zero. Distinct multisets may collide (probabilistic).
     * @tparam hash_num number of independent 64-bit components per key
     */
    template <std::size_t hash_num = 2>
    struct IDHandlerZobrist : internal::tree_classification::IDHandlerBase {
        using base_type = internal::tree_classification::IDHandlerBase;
        using key_type = std::array<uint64_t, hash_num>;
    private:
        std::mt19937_64 rng{ std::random_device{}() };
        // h[id]: weight vector contributed by a child whose id is `id`.
        std::vector<key_type> h{};
        std::map<key_type, int> mp{};
        int next_id = 0;

        // Draw a weight vector, re-drawing any component that comes out as zero.
        key_type draw_weights() {
            key_type w{};
            for (auto& component : w) {
                do {
                    component = rng();
                } while (component == 0);
            }
            return w;
        }
    public:
        IDHandlerZobrist() = default;

        // Id of `key`; an unseen key gets a fresh id and a fresh weight vector.
        int get_id(key_type key) {
            const auto [pos, fresh] = mp.try_emplace(key, next_id);
            if (fresh) {
                ++next_id;
                h.push_back(draw_weights());
            }
            return pos->second;
        }
        // key += h[id] component-wise.
        void add_child(key_type& key, int id) const {
            const key_type& w = h[id];
            for (std::size_t i = 0; i != hash_num; ++i) key[i] += w[i];
        }
        // key -= h[id] component-wise.
        void rem_child(key_type& key, int id) const {
            const key_type& w = h[id];
            for (std::size_t i = 0; i != hash_num; ++i) key[i] -= w[i];
        }
    };

    /**
     * @brief Classifies rooted trees (and all of their subtrees) up to isomorphism.
     *
     * The id of a vertex is obtained from the multiset of its children's ids via
     * IDHandler::get_id. With IDHandlerNaive, equal ids mean isomorphic rooted trees exactly;
     * with a hash-based handler such as IDHandlerZobrist, equality is probabilistic.
     * Ids are only comparable between trees classified by the same instance, because the
     * id assignment lives in the `_id_handler` member.
     *
     * @tparam IDHandler handler type derived from internal::tree_classification::IDHandlerBase
     */
    template <
        typename IDHandler = IDHandlerNaive,
        std::enable_if_t<std::is_base_of_v<internal::tree_classification::IDHandlerBase, IDHandler>, std::nullptr_t> = nullptr
    >
    struct RootedTreeClassifier {
        using key_type = typename IDHandler::key_type;
    public:
        RootedTreeClassifier() = default;

        /**
         * @brief Classify every subtree of g rooted at `root` by isomorphism.
         * @param g tree as adjacency lists (g[v] iterates over the neighbours of v)
         * @param root root of g
         * @return ids, where ids[v] is the class id of the subtree rooted at v
         */
        template <typename GraphType>
        std::vector<int> classify_subtrees(const GraphType& g, int root) {
            const int n = g.size();
            std::vector<int> ids(n), eid(n), par(n, -1);
            // Iterative DFS: eid[cur] is the next adjacency index to explore. A vertex is
            // classified once all of its neighbours have been examined (post-order).
            for (int cur = root; cur != -1;) {
                if (eid[cur] == int(g[cur].size())) {
                    key_type ch_ids{};
                    for (int v : g[cur]) if (v != par[cur]) {
                        _id_handler.add_child(ch_ids, ids[v]);
                    }
                    ids[cur] = classify(ch_ids);
                    cur = par[cur];
                } else if (int nxt = g[cur][eid[cur]++]; nxt != par[cur]) {
                    par[nxt] = cur;
                    cur = nxt;
                }
            }
            return ids;
        }

        // Class id of g rooted at `root`.
        template <typename GraphType>
        int classify(const GraphType& g, int root) {
            return classify_subtrees(g, root)[root];
        }
        // Class id of a root whose children are described by `ch_ids`.
        int classify(const key_type& ch_ids) {
            return _id_handler.get_id(ch_ids);
        }

        /**
         * @brief For every vertex r, compute the class id of the whole tree g rooted at r.
         * @param g tree as adjacency lists
         * @return ids, where ids[r] is the class id of g rooted at r (empty if g is empty)
         * @note Each reroot step copies the parent's key. With IDHandlerNaive that copy (and
         *       the sort in get_id) is proportional to the parent's degree, so a vertex of
         *       degree d costs roughly O(d^2) here.
         */
        template <typename GraphType>
        std::vector<int> classify_rerooting(const GraphType& g) {
            const int n = g.size();
            if (n == 0) return {};
            std::vector<key_type> ch_ids(n);
            std::vector<int> sub_ids(n), eid(n), par(n, -1);
            std::vector<int> pre(n);
            std::vector<int>::iterator it = pre.begin();
            // Pass 1 (rooted at 0): record the pre-order, and bottom-up build each vertex's
            // child key ch_ids[v] and subtree id sub_ids[v].
            for (int cur = 0; cur != -1;) {
                if (eid[cur] == 0) *it++ = cur;
                if (eid[cur] == int(g[cur].size())) {
                    for (int v : g[cur]) if (v != par[cur]) {
                        _id_handler.add_child(ch_ids[cur], sub_ids[v]);
                    }
                    sub_ids[cur] = classify(ch_ids[cur]);
                    cur = par[cur];
                } else if (int nxt = g[cur][eid[cur]++]; nxt != par[cur]) {
                    par[nxt] = cur;
                    cur = nxt;
                }
            }
            // Pass 2 (top-down in pre-order): when u is reached, ch_ids[u] already includes the
            // branch towards u's parent and ids[u] is the id of the whole tree rooted at u.
            // Moving the root to a child v removes v from a copy of u's key, and adds the
            // resulting class to v's key.
            std::vector<int> ids(n);
            ids[0] = sub_ids[0];
            for (int u : pre) {
                for (int v : g[u]) if (v != par[u]) {
                    key_type ku = ch_ids[u];
                    int iu = ids[u];
                    reroot(ku, iu, ch_ids[v], sub_ids[v]);
                    ids[v] = sub_ids[v];
                }
            }
            return ids;
        }

        /**
         * @brief Move the root from `old_par` to its child `new_par`.
         * @param ch_ids_old_par key of old_par's children (must contain id_new_par); new_par is removed
         * @param id_old_par receives the class of old_par's tree without the new_par branch
         * @param ch_ids_new_par key of new_par's children; the old_par branch is added
         * @param id_new_par on entry the class of new_par's subtree; on exit the class of the tree rooted at new_par
         */
        void reroot(key_type& ch_ids_old_par, int& id_old_par, key_type& ch_ids_new_par, int& id_new_par) {
            _id_handler.rem_child(ch_ids_old_par, id_new_par);
            id_old_par = classify(ch_ids_old_par);
            _id_handler.add_child(ch_ids_new_par, id_old_par);
            id_new_par = classify(ch_ids_new_par);
        }

        /**
         * @brief Test whether the unrooted trees g1 and g2 are isomorphic.
         *
         * Any isomorphism maps centroids to centroids, so it suffices to compare the trees
         * rooted at their (at most two) centroids.
         * @return a pair (c1, c2) of centroids such that g1 rooted at c1 is isomorphic to g2
         *         rooted at c2, or std::nullopt if no such pair exists
         */
        template <typename GraphType>
        std::optional<std::pair<int, int>> is_isomorphic(const GraphType& g1, const GraphType& g2) {
            // Trees with different vertex counts are never isomorphic.
            if (g1.size() != g2.size()) return std::nullopt;

            std::vector<int> cs1 = find_centroids(g1);
            std::vector<int> cs2 = find_centroids(g2);
            const int cnum1 = cs1.size(), cnum2 = cs2.size();

            std::vector<int> ids10 = classify_subtrees(g1, cs1[0]);
            std::vector<int> ids20 = classify_subtrees(g2, cs2[0]);
            const int id10 = ids10[cs1[0]], id20 = ids20[cs2[0]];

            if (id10 == id20) return std::pair{ cs1[0], cs2[0] };

            // Whole-tree ids when rooted at the second centroid. When there is no second
            // centroid, distinct negative sentinels are used so they never match a real id
            // (or each other), and cs[1] is never accessed.
            const int id11 = cnum1 == 2 ? _classify_rerooted_to_neighbor(g1, ids10, cs1[0], cs1[1]) : -1;
            const int id21 = cnum2 == 2 ? _classify_rerooted_to_neighbor(g2, ids20, cs2[0], cs2[1]) : -2;

            if (id11 == id20) return std::pair{ cs1[1], cs2[0] };
            if (id10 == id21) return std::pair{ cs1[0], cs2[1] };
            if (id11 == id21) return std::pair{ cs1[1], cs2[1] };
            return std::nullopt;
        }
        // Whether g1 rooted at root1 is isomorphic to g2 rooted at root2.
        template <typename GraphType>
        bool is_isomorphic_rooted(const GraphType& g1, int root1, const GraphType& g2, int root2) {
            return classify(g1, root1) == classify(g2, root2);
        }
    private:
        IDHandler _id_handler;

        // Class id of the whole tree g rooted at `new_root`, where new_root is adjacent to
        // `old_root` and ids == classify_subtrees(g, old_root).
        template <typename GraphType>
        int _classify_rerooted_to_neighbor(const GraphType& g, const std::vector<int>& ids, int old_root, int new_root) {
            key_type ch_ids_old_par{};
            for (int v : g[old_root]) {
                _id_handler.add_child(ch_ids_old_par, ids[v]);
            }
            key_type ch_ids_new_par{};
            for (int v : g[new_root]) if (v != old_root) {
                _id_handler.add_child(ch_ids_new_par, ids[v]);
            }
            int id_old_par = ids[old_root];
            int id_new_par = ids[new_root];
            reroot(ch_ids_old_par, id_old_par, ch_ids_new_par, id_new_par);
            return id_new_par;
        }
    };
} // namespace suisen

#endif // SUISEN_ROOTED_TREE_ISOMORPHISM_CLASSIFICATION
#line 1 "library/tree/tree_isomorphism_classification.hpp"



#include <algorithm>
#include <cassert>
#include <deque>
#include <map>
#include <optional>
#include <random>
#include <utility>
#include <vector>

#line 1 "library/string/trie_map.hpp"



#line 5 "library/string/trie_map.hpp"
#include <unordered_map>
#line 7 "library/string/trie_map.hpp"

namespace suisen {
    /**
     * @brief Trie node: a map (std::map if use_ordered_map, else std::unordered_map) from an
     *        edge label to the index of the child node inside MapTrie::nodes.
     * @tparam T edge label type
     * @tparam use_ordered_map selects std::map (true) or std::unordered_map (false)
     */
    template <typename T, bool use_ordered_map = true>
    struct MapTrieNode : std::conditional_t<use_ordered_map, std::map<T, int>, std::unordered_map<T, int>> {
        static constexpr int none = -1;
        static constexpr bool ordered = use_ordered_map;

        using key_type = T;

        // Child index for label c, or `none` if absent (does not insert).
        int operator[](const key_type& c) const {
            auto it = this->find(c);
            return it == this->end() ? none : it->second;
        }
        // Reference to the child index for label c; inserts `none` if absent.
        int& operator[](const key_type& c) {
            return this->try_emplace(c, none).first->second;
        }
    };
    /**
     * @brief Trie with nodes stored contiguously in `nodes`; node 0 is the root.
     * @tparam NodeType a type derived from MapTrieNode (may carry extra per-node data)
     */
    template <
        typename NodeType,
        std::enable_if_t<std::is_base_of_v<MapTrieNode<typename NodeType::key_type, NodeType::ordered>, NodeType>, std::nullptr_t> = nullptr
    >
    struct MapTrie {
        using node_type = NodeType;
        using key_type = typename node_type::key_type;
        // Must match the base actually required of NodeType (including its `ordered` flag).
        using base_node_type = MapTrieNode<key_type, node_type::ordered>;

        static constexpr int none = node_type::none;

        std::vector<node_type> nodes;

        MapTrie() { nodes.emplace_back(); }

        void reserve(int capacity) {
            nodes.reserve(capacity);
        }

        /**
         * @brief Insert the label sequence `s`, walking from node `start`, and return the node it ends at.
         * @note The returned reference is invalidated by later insertions, since `nodes` may reallocate.
         */
        template <typename Container, std::enable_if_t<std::is_constructible_v<key_type, typename Container::value_type>, std::nullptr_t> = nullptr>
        node_type& add(const Container& s, int start = 0) {
            int cur = start;
            for (key_type c : s) {
                auto [it, inserted] = nodes[cur].try_emplace(c, int(nodes.size()));
                // Read the child index BEFORE emplace_back. Growing `nodes` may relocate
                // nodes[cur] (by copy when the map's move constructor is not noexcept),
                // which would leave `it` dangling.
                const int nxt = it->second;
                if (inserted) nodes.emplace_back();
                cur = nxt;
            }
            return nodes[cur];
        }
        const node_type& operator[](int i) const {
            return nodes[i];
        }
        node_type& operator[](int i) {
            return nodes[i];
        }
    };
} // namespace suisen



#line 1 "library/util/hashes.hpp"



#include <array>
#include <cstdint>
#include <tuple>
#line 8 "library/util/hashes.hpp"

// Hashing helpers for std::pair / std::tuple / std::array, so they can be used as keys of
// unordered containers.
// NOTE(review): the standard only allows std::hash specializations that depend on a
// program-defined type; these specializations for pure std types are formally undefined
// behaviour, even though mainstream toolchains accept them.
namespace std {
    namespace {
        // Boost-style hash_combine: mixes hash<T>()(v) into `seed`.
        template <class T>
        inline void hash_combine(std::size_t& seed, T const& v) {
            const std::size_t element_hash = hash<T>()(v);
            seed ^= element_hash + 0x9e3779b9 + (seed << 6) + (seed >> 2);
        }

        // Folds get<0>(t), get<1>(t), ..., get<Index>(t) into `seed`, in increasing index order.
        template <class Tuple, size_t Index = std::tuple_size<Tuple>::value - 1>
        struct HashValueImpl {
            static void apply(size_t& seed, Tuple const& t) {
                if constexpr (Index != 0) {
                    HashValueImpl<Tuple, Index - 1>::apply(seed, t);
                }
                hash_combine(seed, get<Index>(t));
            }
        };
    }

    template <typename T, typename U>
    struct hash<std::pair<T, U>> {
        size_t operator()(std::pair<T, U> const& tt) const {
            using tuple_like = std::pair<T, U>;
            size_t seed{};
            HashValueImpl<tuple_like>::apply(seed, tt);
            return seed;
        }
    };
    template <typename ...Args>
    struct hash<std::tuple<Args...>> {
        size_t operator()(std::tuple<Args...> const& tt) const {
            using tuple_like = std::tuple<Args...>;
            size_t seed{};
            HashValueImpl<tuple_like>::apply(seed, tt);
            return seed;
        }
    };
    template <typename T, std::size_t N>
    struct hash<std::array<T, N>> {
        size_t operator()(std::array<T, N> const& tt) const {
            using tuple_like = std::array<T, N>;
            size_t seed{};
            HashValueImpl<tuple_like>::apply(seed, tt);
            return seed;
        }
    };
}


#line 1 "library/tree/find_centroid.hpp"



#line 7 "library/tree/find_centroid.hpp"

namespace suisen {
    /**
     * @brief Find the centroid(s) of a tree: the vertices whose removal leaves components
     *        of at most n/2 vertices each.
     * @param g tree as adjacency lists over vertices 0..n-1
     * @return the one or two centroids, in the order their DFS (started at vertex 0) finishes
     */
    template <typename GraphType>
    std::vector<int> find_centroids(const GraphType& g) {
        const int n = g.size();
        std::vector<int> centroids;
        // balanced[v]: every child subtree of v (DFS tree rooted at 0) has size <= n/2.
        std::vector<int8_t> balanced(n, true);
        std::vector<int> next_edge(n), parent(n, -1), sub_size(n, 1);
        int v = 0;
        while (v >= 0) {
            if (next_edge[v] < int(g[v].size())) {
                const int to = g[v][next_edge[v]++];
                if (to != parent[v]) {
                    parent[to] = v;
                    v = to;
                }
                continue;
            }
            // All neighbours of v explored: report v's subtree to its parent, then test v.
            const int p = parent[v];
            if (p >= 0) {
                sub_size[p] += sub_size[v];
                balanced[p] &= 2 * sub_size[v] <= n;
            }
            // The part of the tree outside v's subtree has n - sub_size[v] vertices.
            if (balanced[v] && 2 * sub_size[v] >= n) centroids.push_back(v);
            v = p;
        }
        assert(centroids.size() == 1 || centroids.size() == 2);
        return centroids;
    }
} // namespace suisen



#line 16 "library/tree/tree_isomorphism_classification.hpp"

namespace suisen {
    // Empty tag base: RootedTreeClassifier only accepts ID handlers derived from it
    // (enforced with std::is_base_of_v in its template parameter list).
    namespace internal::tree_classification {
        struct IDHandlerBase {};
    } // namespace internal::tree_classification

    /**
     * @brief Exact ID handler: a key is the list of child ids, and two keys receive the same id
     *        iff they are equal as multisets (the list is sorted before lookup).
     *
     * Keys with exactly one child are resolved through the flat table `mp1`; all other keys
     * (including the empty key of a leaf) are resolved through a trie over the sorted ids.
     */
    struct IDHandlerNaive : internal::tree_classification::IDHandlerBase {
        using base_type = internal::tree_classification::IDHandlerBase;
        using key_type = std::vector<int>;
    private:
        static constexpr int None = -1;
        // Trie node; `id` is the id assigned to the sorted child-id sequence ending here.
        struct TrieNode : MapTrieNode<int> {
            int id = None;
        };
        // mp1[c]: id of a root whose only child has id c (None if not assigned yet).
        std::vector<int> mp1{};
        MapTrie<TrieNode> mp{};
        int next_id = 0;

        // Grow mp1 so that index `id` is addressable; new slots are None.
        void ensure_mp1(int id) {
            const int required = id + 1;
            if (required > int(mp1.size())) mp1.resize(required, None);
        }
        // Return the id stored in `slot`, assigning a fresh one first if it is still None.
        int resolve(int& slot) {
            if (slot == None) slot = next_id++;
            return slot;
        }
    public:
        IDHandlerNaive() = default;

        /**
         * @brief Id of the multiset of child ids `ch_ids` (taken by value; sorted locally).
         */
        int get_id(key_type ch_ids) {
            if (ch_ids.size() != 1) {
                std::sort(ch_ids.begin(), ch_ids.end());
                return resolve(mp.add(ch_ids).id);
            }
            const int only_child = ch_ids.front();
            ensure_mp1(only_child);
            return resolve(mp1[only_child]);
        }
        // Append child id `id` to `key`.
        void add_child(key_type& key, int id) const {
            key.push_back(id);
        }
        // Remove one occurrence of child id `id` from `key`; it must be present.
        void rem_child(key_type& key, int id) const {
            const auto pos = std::find(key.begin(), key.end(), id);
            assert(pos != key.end());
            key.erase(pos);
        }
    };

    /**
     * @brief Hash-based ID handler: a key is the component-wise sum (mod 2^64) of the random
     *        weight vectors of the children, so adding/removing a child is O(hash_num).
     *
     * Each newly seen key gets the next id together with a random weight vector whose
     * components are all non-zero. Distinct multisets may collide (probabilistic).
     * @tparam hash_num number of independent 64-bit components per key
     */
    template <std::size_t hash_num = 2>
    struct IDHandlerZobrist : internal::tree_classification::IDHandlerBase {
        using base_type = internal::tree_classification::IDHandlerBase;
        using key_type = std::array<uint64_t, hash_num>;
    private:
        std::mt19937_64 rng{ std::random_device{}() };
        // h[id]: weight vector contributed by a child whose id is `id`.
        std::vector<key_type> h{};
        std::map<key_type, int> mp{};
        int next_id = 0;

        // Draw a weight vector, re-drawing any component that comes out as zero.
        key_type draw_weights() {
            key_type w{};
            for (auto& component : w) {
                do {
                    component = rng();
                } while (component == 0);
            }
            return w;
        }
    public:
        IDHandlerZobrist() = default;

        // Id of `key`; an unseen key gets a fresh id and a fresh weight vector.
        int get_id(key_type key) {
            const auto [pos, fresh] = mp.try_emplace(key, next_id);
            if (fresh) {
                ++next_id;
                h.push_back(draw_weights());
            }
            return pos->second;
        }
        // key += h[id] component-wise.
        void add_child(key_type& key, int id) const {
            const key_type& w = h[id];
            for (std::size_t i = 0; i != hash_num; ++i) key[i] += w[i];
        }
        // key -= h[id] component-wise.
        void rem_child(key_type& key, int id) const {
            const key_type& w = h[id];
            for (std::size_t i = 0; i != hash_num; ++i) key[i] -= w[i];
        }
    };

    /**
     * @brief Classifies rooted trees (and all of their subtrees) up to isomorphism.
     *
     * The id of a vertex is obtained from the multiset of its children's ids via
     * IDHandler::get_id. With IDHandlerNaive, equal ids mean isomorphic rooted trees exactly;
     * with a hash-based handler such as IDHandlerZobrist, equality is probabilistic.
     * Ids are only comparable between trees classified by the same instance, because the
     * id assignment lives in the `_id_handler` member.
     *
     * @tparam IDHandler handler type derived from internal::tree_classification::IDHandlerBase
     */
    template <
        typename IDHandler = IDHandlerNaive,
        std::enable_if_t<std::is_base_of_v<internal::tree_classification::IDHandlerBase, IDHandler>, std::nullptr_t> = nullptr
    >
    struct RootedTreeClassifier {
        using key_type = typename IDHandler::key_type;
    public:
        RootedTreeClassifier() = default;

        /**
         * @brief Classify every subtree of g rooted at `root` by isomorphism.
         * @param g tree as adjacency lists (g[v] iterates over the neighbours of v)
         * @param root root of g
         * @return ids, where ids[v] is the class id of the subtree rooted at v
         */
        template <typename GraphType>
        std::vector<int> classify_subtrees(const GraphType& g, int root) {
            const int n = g.size();
            std::vector<int> ids(n), eid(n), par(n, -1);
            // Iterative DFS: eid[cur] is the next adjacency index to explore. A vertex is
            // classified once all of its neighbours have been examined (post-order).
            for (int cur = root; cur != -1;) {
                if (eid[cur] == int(g[cur].size())) {
                    key_type ch_ids{};
                    for (int v : g[cur]) if (v != par[cur]) {
                        _id_handler.add_child(ch_ids, ids[v]);
                    }
                    ids[cur] = classify(ch_ids);
                    cur = par[cur];
                } else if (int nxt = g[cur][eid[cur]++]; nxt != par[cur]) {
                    par[nxt] = cur;
                    cur = nxt;
                }
            }
            return ids;
        }

        // Class id of g rooted at `root`.
        template <typename GraphType>
        int classify(const GraphType& g, int root) {
            return classify_subtrees(g, root)[root];
        }
        // Class id of a root whose children are described by `ch_ids`.
        int classify(const key_type& ch_ids) {
            return _id_handler.get_id(ch_ids);
        }

        /**
         * @brief For every vertex r, compute the class id of the whole tree g rooted at r.
         * @param g tree as adjacency lists
         * @return ids, where ids[r] is the class id of g rooted at r (empty if g is empty)
         * @note Each reroot step copies the parent's key. With IDHandlerNaive that copy (and
         *       the sort in get_id) is proportional to the parent's degree, so a vertex of
         *       degree d costs roughly O(d^2) here.
         */
        template <typename GraphType>
        std::vector<int> classify_rerooting(const GraphType& g) {
            const int n = g.size();
            if (n == 0) return {};
            std::vector<key_type> ch_ids(n);
            std::vector<int> sub_ids(n), eid(n), par(n, -1);
            std::vector<int> pre(n);
            std::vector<int>::iterator it = pre.begin();
            // Pass 1 (rooted at 0): record the pre-order, and bottom-up build each vertex's
            // child key ch_ids[v] and subtree id sub_ids[v].
            for (int cur = 0; cur != -1;) {
                if (eid[cur] == 0) *it++ = cur;
                if (eid[cur] == int(g[cur].size())) {
                    for (int v : g[cur]) if (v != par[cur]) {
                        _id_handler.add_child(ch_ids[cur], sub_ids[v]);
                    }
                    sub_ids[cur] = classify(ch_ids[cur]);
                    cur = par[cur];
                } else if (int nxt = g[cur][eid[cur]++]; nxt != par[cur]) {
                    par[nxt] = cur;
                    cur = nxt;
                }
            }
            // Pass 2 (top-down in pre-order): when u is reached, ch_ids[u] already includes the
            // branch towards u's parent and ids[u] is the id of the whole tree rooted at u.
            // Moving the root to a child v removes v from a copy of u's key, and adds the
            // resulting class to v's key.
            std::vector<int> ids(n);
            ids[0] = sub_ids[0];
            for (int u : pre) {
                for (int v : g[u]) if (v != par[u]) {
                    key_type ku = ch_ids[u];
                    int iu = ids[u];
                    reroot(ku, iu, ch_ids[v], sub_ids[v]);
                    ids[v] = sub_ids[v];
                }
            }
            return ids;
        }

        /**
         * @brief Move the root from `old_par` to its child `new_par`.
         * @param ch_ids_old_par key of old_par's children (must contain id_new_par); new_par is removed
         * @param id_old_par receives the class of old_par's tree without the new_par branch
         * @param ch_ids_new_par key of new_par's children; the old_par branch is added
         * @param id_new_par on entry the class of new_par's subtree; on exit the class of the tree rooted at new_par
         */
        void reroot(key_type& ch_ids_old_par, int& id_old_par, key_type& ch_ids_new_par, int& id_new_par) {
            _id_handler.rem_child(ch_ids_old_par, id_new_par);
            id_old_par = classify(ch_ids_old_par);
            _id_handler.add_child(ch_ids_new_par, id_old_par);
            id_new_par = classify(ch_ids_new_par);
        }

        /**
         * @brief Test whether the unrooted trees g1 and g2 are isomorphic.
         *
         * Any isomorphism maps centroids to centroids, so it suffices to compare the trees
         * rooted at their (at most two) centroids.
         * @return a pair (c1, c2) of centroids such that g1 rooted at c1 is isomorphic to g2
         *         rooted at c2, or std::nullopt if no such pair exists
         */
        template <typename GraphType>
        std::optional<std::pair<int, int>> is_isomorphic(const GraphType& g1, const GraphType& g2) {
            // Trees with different vertex counts are never isomorphic.
            if (g1.size() != g2.size()) return std::nullopt;

            std::vector<int> cs1 = find_centroids(g1);
            std::vector<int> cs2 = find_centroids(g2);
            const int cnum1 = cs1.size(), cnum2 = cs2.size();

            std::vector<int> ids10 = classify_subtrees(g1, cs1[0]);
            std::vector<int> ids20 = classify_subtrees(g2, cs2[0]);
            const int id10 = ids10[cs1[0]], id20 = ids20[cs2[0]];

            if (id10 == id20) return std::pair{ cs1[0], cs2[0] };

            // Whole-tree ids when rooted at the second centroid. When there is no second
            // centroid, distinct negative sentinels are used so they never match a real id
            // (or each other), and cs[1] is never accessed.
            const int id11 = cnum1 == 2 ? _classify_rerooted_to_neighbor(g1, ids10, cs1[0], cs1[1]) : -1;
            const int id21 = cnum2 == 2 ? _classify_rerooted_to_neighbor(g2, ids20, cs2[0], cs2[1]) : -2;

            if (id11 == id20) return std::pair{ cs1[1], cs2[0] };
            if (id10 == id21) return std::pair{ cs1[0], cs2[1] };
            if (id11 == id21) return std::pair{ cs1[1], cs2[1] };
            return std::nullopt;
        }
        // Whether g1 rooted at root1 is isomorphic to g2 rooted at root2.
        template <typename GraphType>
        bool is_isomorphic_rooted(const GraphType& g1, int root1, const GraphType& g2, int root2) {
            return classify(g1, root1) == classify(g2, root2);
        }
    private:
        IDHandler _id_handler;

        // Class id of the whole tree g rooted at `new_root`, where new_root is adjacent to
        // `old_root` and ids == classify_subtrees(g, old_root).
        template <typename GraphType>
        int _classify_rerooted_to_neighbor(const GraphType& g, const std::vector<int>& ids, int old_root, int new_root) {
            key_type ch_ids_old_par{};
            for (int v : g[old_root]) {
                _id_handler.add_child(ch_ids_old_par, ids[v]);
            }
            key_type ch_ids_new_par{};
            for (int v : g[new_root]) if (v != old_root) {
                _id_handler.add_child(ch_ids_new_par, ids[v]);
            }
            int id_old_par = ids[old_root];
            int id_new_par = ids[new_root];
            reroot(ch_ids_old_par, id_old_par, ch_ids_new_par, id_new_par);
            return id_new_par;
        }
    };
} // namespace suisen
Back to top page