cp-library-cpp

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub suisen-cp/cp-library-cpp

:heavy_check_mark: Weighted Union Find
(library/datastructure/union_find/weighted_union_find.hpp)

Weighted Union Find

UnionFindTree に関する知見の諸々 で「ポテンシャル付きUnionFindTree」と呼ばれているもの。

Depends on

Verified with

Code

#ifndef SUISEN_WEIGHTED_UNION_FIND
#define SUISEN_WEIGHTED_UNION_FIND

#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

#include "library/util/default_operator.hpp"

namespace suisen {

    // reference: https://noshi91.hatenablog.com/entry/2018/05/30/191943

    /**
     * Union-find that also tracks a weight (potential) difference between vertices of the same
     * connected component.
     *
     * Every vertex `v` stores `value[v]`, its weight relative to its current parent. The weight of
     * `v` relative to its root is the `op`-product of the stored values along the root-to-`v` path,
     * with the root side on the left. A root always stores `e()`.
     *
     * The implementation relies on `op` being associative, `e()` being its identity and `neg`
     * producing inverses; commutativity is not used.
     *
     * @tparam T    weight type
     * @tparam op   combines two weights (defaults to `default_operator_noref::add<T>`)
     * @tparam e    returns the identity weight (defaults to `default_operator_noref::zero<T>`)
     * @tparam neg  returns the inverse of a weight (defaults to `default_operator_noref::neg<T>`)
     */
    template <typename T, T(*op)(T, T) = default_operator_noref::add<T>, T(*e)() = default_operator_noref::zero<T>, T(*neg)(T) = default_operator_noref::neg<T>>
    struct WeightedUnionFind {
        // Create an empty structure with no vertices.
        WeightedUnionFind() = default;
        // Create `n` singleton components; every vertex is its own root and carries weight `e()`.
        explicit WeightedUnionFind(int n) : n(n), par(n), siz(n, 1), value(n, e()) {
            std::iota(par.begin(), par.end(), 0);
        }
        // Get the root of `x`. Equivalent to `operator[](x)`.
        // While walking up, every visited vertex is re-attached to its grandparent.
        int root(int x) {
            while (par[x] != x) {
                int& p = par[x];
                // Fold the parent's value into `x` so that value[x] becomes relative to the grandparent.
                value[x] = op(value[p], value[x]);
                // Re-attach `x` to its grandparent and continue from there.
                x = p = par[p];
            }
            return x;
        }
        // Get the root of `x`. Equivalent to `root(x)`.
        int operator[](int x) {
            return root(x);
        }
        // Merge the components of `x` and `y` so that the distance d = y - x holds.
        // Returns false without touching any weight if `x` and `y` are already in the same
        // component (in that case `d` is NOT checked against the existing distance);
        // returns true if two components were joined.
        bool merge(int x, int y, const T& d) {
            /**
             *   [root(x)] ----> [root(y)]
             *       |     ??=rd     |
             *  w(x) |               | w(y)
             *       v               v
             *      [x] ----------> [y]
             *               d
             */
            // From the diagram: w(x) + d = rd + w(y), hence rd = w(x) + d - w(y).
            T rd = op(op(weight(x), d), neg(weight(y)));
            x = root(x), y = root(y);
            if (x == y) return false;
            // Union by size: the smaller tree is hung below the larger one. Swapping the roles
            // reverses the direction of the new edge, so its weight has to be inverted.
            if (siz[x] < siz[y]) {
                std::swap(x, y);
                rd = neg(std::move(rd));
            }
            siz[x] += siz[y], par[y] = x;
            value[y] = std::move(rd);
            return true;
        }
        // Return the distance d = y - x.
        // Precondition (checked with `assert`): `x` and `y` belong to the same connected component.
        T diff(int x, int y) {
            assert(same(x, y));
            return op(neg(weight(x)), weight(y));
        }
        // Check if `x` and `y` belong to the same connected component.
        bool same(int x, int y) {
            return root(x) == root(y);
        }
        // Get the size of the connected component to which `x` belongs.
        int size(int x) {
            return siz[root(x)];
        }
        // Get all connected components as lists of vertices.
        // Components are ordered by their root index, vertices inside a component are ascending.
        std::vector<std::vector<int>> groups() {
            std::vector<std::vector<int>> res(n);
            for (int i = 0; i < n; ++i) res[root(i)].push_back(i);
            // Only root indices own a non-empty bucket; drop the rest.
            res.erase(std::remove_if(res.begin(), res.end(), [](const auto& g) { return g.empty(); }), res.end());
            return res;
        }
    private:
        // Number of vertices. Initialized to 0 so that a default-constructed instance is a valid
        // empty structure instead of holding an indeterminate value.
        int n = 0;
        // par[v]: parent of `v` (a root is its own parent); siz[r]: component size, valid for roots.
        std::vector<int> par, siz;
        // value[v]: weight of `v` relative to par[v].
        std::vector<T> value;

        // Return the weight of `x` relative to its root.
        // Shortens the path in the same way as `root()` while accumulating the weight.
        T weight(int x) {
            T res = e();
            while (par[x] != x) {
                int& p = par[x];
                value[x] = op(value[p], value[x]);
                // value[x] now spans grandparent -> x; prepend it to what was collected below.
                res = op(value[x], res);
                x = p = par[p];
            }
            return res;
        }
    };
} // namespace suisen

#endif // SUISEN_WEIGHTED_UNION_FIND
#line 1 "library/datastructure/union_find/weighted_union_find.hpp"



#include <algorithm>
#include <cassert>
#include <numeric>
#include <vector>

#line 1 "library/util/default_operator.hpp"



namespace suisen {
    // Generic arithmetic helpers whose operands are taken by const reference.
    // Each return type is spelled with `decltype`, so a helper is only usable for a `T`
    // for which the corresponding expression is well-formed.
    namespace default_operator {
        // `T` brace-initialized from 0.
        template <class T>
        auto zero() -> decltype(T { 0 }) { return T { 0 }; }
        // `T` brace-initialized from 1.
        template <class T>
        auto one()  -> decltype(T { 1 }) { return T { 1 }; }
        // lhs + rhs
        template <class T>
        auto add(const T &lhs, const T &rhs) -> decltype(lhs + rhs) { return lhs + rhs; }
        // lhs - rhs
        template <class T>
        auto sub(const T &lhs, const T &rhs) -> decltype(lhs - rhs) { return lhs - rhs; }
        // lhs * rhs
        template <class T>
        auto mul(const T &lhs, const T &rhs) -> decltype(lhs * rhs) { return lhs * rhs; }
        // lhs / rhs
        template <class T>
        auto div(const T &lhs, const T &rhs) -> decltype(lhs / rhs) { return lhs / rhs; }
        // lhs % rhs
        template <class T>
        auto mod(const T &lhs, const T &rhs) -> decltype(lhs % rhs) { return lhs % rhs; }
        // -operand
        template <class T>
        auto neg(const T &operand) -> decltype(-operand) { return -operand; }
        // one<T>() / operand
        template <class T>
        auto inv(const T &operand) -> decltype(one<T>() / operand)  { return one<T>() / operand; }
    } // default_operator
    // Same helpers as above, but with operands taken by value.
    namespace default_operator_noref {
        // `T` brace-initialized from 0.
        template <class T>
        auto zero() -> decltype(T { 0 }) { return T { 0 }; }
        // `T` brace-initialized from 1.
        template <class T>
        auto one()  -> decltype(T { 1 }) { return T { 1 }; }
        // lhs + rhs
        template <class T>
        auto add(T lhs, T rhs) -> decltype(lhs + rhs) { return lhs + rhs; }
        // lhs - rhs
        template <class T>
        auto sub(T lhs, T rhs) -> decltype(lhs - rhs) { return lhs - rhs; }
        // lhs * rhs
        template <class T>
        auto mul(T lhs, T rhs) -> decltype(lhs * rhs) { return lhs * rhs; }
        // lhs / rhs
        template <class T>
        auto div(T lhs, T rhs) -> decltype(lhs / rhs) { return lhs / rhs; }
        // lhs % rhs
        template <class T>
        auto mod(T lhs, T rhs) -> decltype(lhs % rhs) { return lhs % rhs; }
        // -operand
        template <class T>
        auto neg(T operand) -> decltype(-operand) { return -operand; }
        // one<T>() / operand
        template <class T>
        auto inv(T operand) -> decltype(one<T>() / operand)  { return one<T>() / operand; }
    } // default_operator_noref
} // namespace suisen


#line 10 "library/datastructure/union_find/weighted_union_find.hpp"

namespace suisen {

    // reference: https://noshi91.hatenablog.com/entry/2018/05/30/191943

    /**
     * Union-find that also tracks a weight (potential) difference between vertices of the same
     * connected component.
     *
     * Every vertex `v` stores `value[v]`, its weight relative to its current parent. The weight of
     * `v` relative to its root is the `op`-product of the stored values along the root-to-`v` path,
     * with the root side on the left. A root always stores `e()`.
     *
     * The implementation relies on `op` being associative, `e()` being its identity and `neg`
     * producing inverses; commutativity is not used.
     *
     * @tparam T    weight type
     * @tparam op   combines two weights (defaults to `default_operator_noref::add<T>`)
     * @tparam e    returns the identity weight (defaults to `default_operator_noref::zero<T>`)
     * @tparam neg  returns the inverse of a weight (defaults to `default_operator_noref::neg<T>`)
     */
    template <typename T, T(*op)(T, T) = default_operator_noref::add<T>, T(*e)() = default_operator_noref::zero<T>, T(*neg)(T) = default_operator_noref::neg<T>>
    struct WeightedUnionFind {
        // Create an empty structure with no vertices.
        WeightedUnionFind() = default;
        // Create `n` singleton components; every vertex is its own root and carries weight `e()`.
        explicit WeightedUnionFind(int n) : n(n), par(n), siz(n, 1), value(n, e()) {
            std::iota(par.begin(), par.end(), 0);
        }
        // Get the root of `x`. Equivalent to `operator[](x)`.
        // While walking up, every visited vertex is re-attached to its grandparent.
        int root(int x) {
            while (par[x] != x) {
                int& p = par[x];
                // Fold the parent's value into `x` so that value[x] becomes relative to the grandparent.
                value[x] = op(value[p], value[x]);
                // Re-attach `x` to its grandparent and continue from there.
                x = p = par[p];
            }
            return x;
        }
        // Get the root of `x`. Equivalent to `root(x)`.
        int operator[](int x) {
            return root(x);
        }
        // Merge the components of `x` and `y` so that the distance d = y - x holds.
        // Returns false without touching any weight if `x` and `y` are already in the same
        // component (in that case `d` is NOT checked against the existing distance);
        // returns true if two components were joined.
        bool merge(int x, int y, const T& d) {
            /**
             *   [root(x)] ----> [root(y)]
             *       |     ??=rd     |
             *  w(x) |               | w(y)
             *       v               v
             *      [x] ----------> [y]
             *               d
             */
            // From the diagram: w(x) + d = rd + w(y), hence rd = w(x) + d - w(y).
            T rd = op(op(weight(x), d), neg(weight(y)));
            x = root(x), y = root(y);
            if (x == y) return false;
            // Union by size: the smaller tree is hung below the larger one. Swapping the roles
            // reverses the direction of the new edge, so its weight has to be inverted.
            if (siz[x] < siz[y]) {
                std::swap(x, y);
                rd = neg(std::move(rd));
            }
            siz[x] += siz[y], par[y] = x;
            value[y] = std::move(rd);
            return true;
        }
        // Return the distance d = y - x.
        // Precondition (checked with `assert`): `x` and `y` belong to the same connected component.
        T diff(int x, int y) {
            assert(same(x, y));
            return op(neg(weight(x)), weight(y));
        }
        // Check if `x` and `y` belong to the same connected component.
        bool same(int x, int y) {
            return root(x) == root(y);
        }
        // Get the size of the connected component to which `x` belongs.
        int size(int x) {
            return siz[root(x)];
        }
        // Get all connected components as lists of vertices.
        // Components are ordered by their root index, vertices inside a component are ascending.
        std::vector<std::vector<int>> groups() {
            std::vector<std::vector<int>> res(n);
            for (int i = 0; i < n; ++i) res[root(i)].push_back(i);
            // Only root indices own a non-empty bucket; drop the rest.
            res.erase(std::remove_if(res.begin(), res.end(), [](const auto& g) { return g.empty(); }), res.end());
            return res;
        }
    private:
        // Number of vertices. Initialized to 0 so that a default-constructed instance is a valid
        // empty structure instead of holding an indeterminate value.
        int n = 0;
        // par[v]: parent of `v` (a root is its own parent); siz[r]: component size, valid for roots.
        std::vector<int> par, siz;
        // value[v]: weight of `v` relative to par[v].
        std::vector<T> value;

        // Return the weight of `x` relative to its root.
        // Shortens the path in the same way as `root()` while accumulating the weight.
        T weight(int x) {
            T res = e();
            while (par[x] != x) {
                int& p = par[x];
                value[x] = op(value[p], value[x]);
                // value[x] now spans grandparent -> x; prepend it to what was collected below.
                res = op(value[x], res);
                x = p = par[p];
            }
            return res;
        }
    };
} // namespace suisen
Back to top page