cp-library-cpp

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub suisen-cp/cp-library-cpp

✔ test/src/datastructure/union_find/merge_history_forest/abc235_h.test.cpp

Depends on: library/datastructure/union_find/merge_history_forest.hpp

Code

#define PROBLEM "https://atcoder.jp/contests/abc235/tasks/abc235_Ex"

#include <iostream>
#include <map>

#include <atcoder/modint>
using mint = atcoder::modint998244353;

#include <atcoder/dsu>
#include <algorithm>
#include <deque>
#include <numeric>
#include <optional>
#include <queue>

#include "library/datastructure/union_find/merge_history_forest.hpp"
using suisen::MergeHistoryForest;

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
 
    int n, m, k;
    std::cin >> n >> m >> k;
 
    std::map<int, std::vector<std::pair<int, int>>> edges;
    while (m --> 0) {
        int u, v, w;
        std::cin >> u >> v >> w;
        edges[w].emplace_back(u - 1, v - 1);
    }
 
    MergeHistoryForest uf(n);
    for (const auto &e : edges) {
        const auto &es = e.second;
        if (es.size() == 1) {
            uf.merge(es.front());
        } else {
            uf.merge_simultaneously(e.second);
        }
    }
 
    const auto &g = uf.get_forest();
 
    auto merge = [&](auto &f, auto &g) {
        int szf = f.size(), szg = g.size();
        std::vector<mint> nf(std::min(szf + szg - 1, k + 1), 0);
        for (int i = 0; i < szf; ++i) for (int j = 0; j < szg; ++j) {
            if (i + j > k) break;
            nf[i + j] += f[i] * g[j];
        }
        return nf;
    };
 
    std::vector dp(g.size(), std::vector<mint>{});
    auto dfs = [&](auto dfs, int u) -> void {
        if (g[u].empty()) {
            dp[u] = { 1, 1 };
            return;
        }
        dp[u] = { 1 };
        for (int v : g[u]) {
            dfs(dfs, v);
            dp[u] = merge(dp[u], dp[v]);
        }
        dp[u][1] += 1;
        if (int c = g[u].size(); c <= k) dp[u][c] -= 1;
    };
 
    std::vector<mint> f { 1 };
    for (int root : uf.forest_roots()) {
        dfs(dfs, root);
        f = merge(f, dp[root]);
    }
 
    std::cout << std::accumulate(f.begin(), f.end(), mint(0)).val() << std::endl;
 
    return 0;
}
#line 1 "test/src/datastructure/union_find/merge_history_forest/abc235_h.test.cpp"
#define PROBLEM "https://atcoder.jp/contests/abc235/tasks/abc235_Ex"

#include <iostream>
#include <map>

#include <atcoder/modint>
using mint = atcoder::modint998244353;

#include <atcoder/dsu>
#include <algorithm>
#include <deque>
#include <numeric>
#include <optional>
#include <queue>

#line 1 "library/datastructure/union_find/merge_history_forest.hpp"



#line 8 "library/datastructure/union_find/merge_history_forest.hpp"
#include <limits>

namespace suisen {
    struct MergeHistoryForest : public atcoder::dsu {
        using base_type = atcoder::dsu;

        MergeHistoryForest() : MergeHistoryForest(0) {}
        explicit MergeHistoryForest(int n) : base_type(n), _g(n), _parent(n, -1), _root(n), _time(0), _created_time(n, _time) {
            std::iota(_root.begin(), _root.end(), 0);
        }

        int node_num() const { return _g.size(); }
        int leaf_num() const { return _root.size(); }

        const auto& get_forest() const { return _g; }
    
        int forest_root(int i) { return _root[leader(i)]; }
        int forest_parent(int vid) const { return _parent[vid]; }
        const auto& forest_children(int vid) { return _g[vid]; }
        bool is_forest_root(int vid) const { return _parent[vid] < 0; }
        bool is_forest_leaf(int vid) const { return vid < leaf_num(); }

        std::vector<int> forest_roots() {
            const int n = leaf_num();
            std::vector<int> roots;
            for (int i = 0; i < n; ++i) if (leader(i) == i) roots.push_back(_root[i]);
            return roots;
        }
 
        void merge(int u, int v) {
            ++_time;
            const int ru = leader(u), rv = leader(v);
            if (ru == rv) return;
            const int new_root = create_node();
            create_edge(new_root, _root[ru]), create_edge(new_root, _root[rv]);
            _root[base_type::merge(ru, rv)] = new_root;
        }
        void merge(const std::pair<int, int> &edge) { merge(edge.first, edge.second); }

        void merge_simultaneously(const std::vector<std::pair<int, int>> &edges) {
            ++_time;
            std::vector<int> vs;
            for (const auto &[u, v] : edges) {
                const int ru = leader(u), rv = leader(v);
                if (ru == rv) continue;
                const int r = base_type::merge(ru, rv), c = ru ^ rv ^ r;
                _g[r].push_back(c);
                vs.push_back(r);
            }
            for (int s : vs) if (s == leader(s)) {
                const int new_root = create_node();
                merge_dfs(s, new_root);
                _root[s] = new_root;
            }
        }

        int current_time() const { return _time; }
        int created_time(int vid) const { return _created_time[vid]; }

        std::vector<int> group(int i, int time = std::numeric_limits<int>::max()) {
            int root = i;
            while (_parent[root] >= 0 and _created_time[_parent[root]] <= time) root = _parent[root];
            std::vector<int> res;
            auto dfs = [&, this](auto dfs, int u) -> void {
                if (is_forest_leaf(u)) {
                    res.push_back(u);
                } else {
                    for (int v : _g[u]) dfs(dfs, v);
                }
            };
            dfs(dfs, root);
            return res;
        }
        std::vector<std::vector<int>> groups(int time = std::numeric_limits<int>::max()) {
            std::vector<std::vector<int>> res;
            const int n = leaf_num();
            std::vector<bool> seen(n, false);
            for (int i = 0; i < n; ++i) if (not seen[i]) for (int v : res.emplace_back(group(i, time))) seen[v] = true;
            return res;
        }

        template <typename GetLCA>
        bool same(int u, int v, int time, GetLCA&& get_lca) {
            if (not base_type::same(u, v)) return false;
            int a = get_lca(u, v);
            return _created_time[a] <= time;
        }

        using base_type::same;

    private:
        std::vector<std::vector<int>> _g;
        std::vector<int> _parent;
        std::vector<int> _root;

        // sum of the number of calls of function `merge` and those of `merge_simultaneously`
        int _time;
        std::vector<int> _created_time;

        void merge_dfs(int u, int new_root) {
            for (int v : _g[u]) merge_dfs(v, new_root), _g[v].shrink_to_fit();
            create_edge(new_root, _root[u]);
            _g[u].clear();
        }

        int create_node() {
            _g.emplace_back();
            _created_time.push_back(_time);
            _parent.push_back(-1);
            return _g.size() - 1;
        }
        void create_edge(int new_root, int old_root) {
            _g[new_root].push_back(old_root);
            _parent[old_root] = new_root;
        }
        static int floor_log2(int n) {
            int res = 0;
            while (1 << (res + 1) <= n) ++res;
            return res;
        }
    };
} // namespace suisen



#line 17 "test/src/datastructure/union_find/merge_history_forest/abc235_h.test.cpp"
using suisen::MergeHistoryForest;

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
 
    int n, m, k;
    std::cin >> n >> m >> k;
 
    std::map<int, std::vector<std::pair<int, int>>> edges;
    while (m --> 0) {
        int u, v, w;
        std::cin >> u >> v >> w;
        edges[w].emplace_back(u - 1, v - 1);
    }
 
    MergeHistoryForest uf(n);
    for (const auto &e : edges) {
        const auto &es = e.second;
        if (es.size() == 1) {
            uf.merge(es.front());
        } else {
            uf.merge_simultaneously(e.second);
        }
    }
 
    const auto &g = uf.get_forest();
 
    auto merge = [&](auto &f, auto &g) {
        int szf = f.size(), szg = g.size();
        std::vector<mint> nf(std::min(szf + szg - 1, k + 1), 0);
        for (int i = 0; i < szf; ++i) for (int j = 0; j < szg; ++j) {
            if (i + j > k) break;
            nf[i + j] += f[i] * g[j];
        }
        return nf;
    };
 
    std::vector dp(g.size(), std::vector<mint>{});
    auto dfs = [&](auto dfs, int u) -> void {
        if (g[u].empty()) {
            dp[u] = { 1, 1 };
            return;
        }
        dp[u] = { 1 };
        for (int v : g[u]) {
            dfs(dfs, v);
            dp[u] = merge(dp[u], dp[v]);
        }
        dp[u][1] += 1;
        if (int c = g[u].size(); c <= k) dp[u][c] -= 1;
    };
 
    std::vector<mint> f { 1 };
    for (int root : uf.forest_roots()) {
        dfs(dfs, root);
        f = merge(f, dp[root]);
    }
 
    std::cout << std::accumulate(f.begin(), f.end(), mint(0)).val() << std::endl;
 
    return 0;
}
Back to top page