cp-library-cpp

This documentation is automatically generated by online-judge-tools/verification-helper

View the Project on GitHub suisen-cp/cp-library-cpp

:heavy_check_mark: test/src/datastructure/bbst/implicit_treap_segtree/dummy.test.cpp

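Depends on

- library/datastructure/bbst/implicit_treap_segtree.hpp
- library/datastructure/bbst/implicit_treap_base.hpp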

Code

#define PROBLEM "https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=ITP1_1_A"

#include <algorithm>
#include <iostream>
#include <limits>
#include <vector>

/**
 * Debug printer for vectors: writes "{a,b,c,}". Every element, including the
 * last, is followed by a comma; an empty vector prints "{}".
 */
template <typename T>
std::ostream& operator<<(std::ostream &out, const std::vector<T> &a) {
    out << '{';
    for (auto it = a.begin(); it != a.end(); ++it) {
        out << *it << ',';
    }
    out << '}';
    return out;
}

#include "library/datastructure/bbst/implicit_treap_segtree.hpp"

template <typename T, T(*op)(T, T), T(*e)()>
struct NaiveSolutionForSegmentTree {
    NaiveSolutionForSegmentTree() = default;
    NaiveSolutionForSegmentTree(const std::vector<T> &dat) : _n(dat.size()), _dat(dat) {}

    T get(int i) const {
        assert(0 <= i and i < _n);
        return _dat[i];
    }
    void set(int i, const T& val) {
        assert(0 <= i and i < _n);
        _dat[i] = val;
    }

    void insert(int i, const T& val) {
        assert(0 <= i and i <= _n);
        ++_n;
        _dat.insert(_dat.begin() + i, val);
    }
    void erase(int i) {
        assert(0 <= i and i < _n);
        --_n;
        _dat.erase(_dat.begin() + i);
    }

    T prod_all() const {
        return prod(0, _n);
    }
    T prod(int l, int r) const {
        assert(0 <= l and l <= r and r <= _n);
        T res = e();
        for (int i = l; i < r; ++i) res = op(res, _dat[i]);
        return res;
    }

    void reverse_all() {
        reverse(0, _n);
    }
    void reverse(int l, int r) {
        assert(0 <= l and l <= r and r <= _n);
        for (--r; l < r; ++l, --r) {
            std::swap(_dat[l], _dat[r]);
        }
    }

    void rotate(int i) {
        assert(0 <= i and i <= _n);
        std::rotate(_dat.begin(), _dat.begin() + i, _dat.end());
    }

    template <typename Pred>
    int max_right(int l, Pred &&pred) const {
        assert(0 <= l and l <= _n);
        T sum = e();
        for (int r = l; r < _n; ++r) {
            T next_sum = op(sum, _dat[r]);
            if (not pred(next_sum)) return r;
            sum = std::move(next_sum);
        }
        return _n;
    }

    template <typename Pred>
    int min_left(int r, Pred &&pred) const {
        assert(0 <= r and r <= _n);
        T sum = e();
        for (int l = r; l > 0; --l) {
            T next_sum = op(_dat[l - 1], sum);
            if (not pred(next_sum)) return l;
            sum = std::move(next_sum);
        }
        return 0;
    }

    std::vector<T> dump() { return _dat; }
private:
    int _n;
    std::vector<T> _dat;
};

/**
 * Point Set Range Sum
 *
 * Monoid shared by both implementations under test: S = long long,
 * op = addition, identity e() = 0.
 */

using S = long long;

// Associative operation of the monoid: plain addition.
S op(const S x, const S y) {
    const S total = x + y;
    return total;
}
// Identity element of op.
S e() {
    return S{};
}

// Brute-force reference and the treap-based structure under test, both over the (S, op, e) monoid.
using Naive = NaiveSolutionForSegmentTree<S, op, e>;
using Tree = suisen::DynamicSegmentTree<S, op, e>;

#include <random>
#include <algorithm>

// Query kinds of the randomized test; a kind is drawn as rng() % QueryTypeNum.
constexpr int Q_get = 0;
constexpr int Q_set = 1;
constexpr int Q_prod = 2;
constexpr int Q_prod_all = 3;
constexpr int Q_insert = 4;
constexpr int Q_erase = 5;
constexpr int Q_max_right = 6;
constexpr int Q_min_left = 7;
constexpr int Q_rotate = 8;
constexpr int QueryTypeNum = Q_rotate + 1;

/**
 * Randomized differential test.
 *
 * Applies Q random operations to both the treap-backed Tree (t1) and the
 * brute-force Naive reference (t2), and asserts that every query (get, prod,
 * prod_all, max_right, min_left) agrees. N tracks the current sequence
 * length; element values lie in [0, MAX_VAL).
 */
void test() {
    int N = 3000, Q = 3000, MAX_VAL = 1000000000;

    std::mt19937 rng{std::random_device{}()};

    std::vector<S> init(N);
    for (int i = 0; i < N; ++i) init[i] = rng() % MAX_VAL;

    Tree t1(init);
    Naive t2(init);

    for (int q = 0; q < Q; ++q) {
        const int query_type = rng() % QueryTypeNum;

        // get / set / erase pick an existing position via `rng() % N`, which would
        // divide by zero once the sequence has been erased down to nothing.
        const bool needs_element = query_type == Q_get or query_type == Q_set or query_type == Q_erase;
        if (needs_element and N == 0) continue;

        if (query_type == Q_get) {
            const int k = rng() % N;
            assert(t1.get(k) == t2.get(k));
        } else if (query_type == Q_set) {
            const int k = rng() % N;
            const S v = rng() % MAX_VAL;
            t1.set(k, v);
            t2.set(k, v);
        } else if (query_type == Q_insert) {
            const int k = rng() % (N + 1);
            const S v = rng() % MAX_VAL;
            t1.insert(k, v);
            t2.insert(k, v);
            ++N;
        } else if (query_type == Q_erase) {
            const int k = rng() % N;
            t1.erase(k);
            t2.erase(k);
            --N;
        } else if (query_type == Q_rotate) {
            const int k = rng() % (N + 1);
            t1.rotate(k);
            t2.rotate(k);
        } else if (query_type == Q_prod) {
            const int l = rng() % (N + 1);
            const int r = l + rng() % (N - l + 1);
            assert(t1.prod(l, r) == t2.prod(l, r));
        } else if (query_type == Q_prod_all) {
            assert(t1.prod_all() == t2.prod_all());
        } else if (query_type == Q_max_right) {
            const int l = rng() % (N + 1);
            const int r = l + rng() % (N - l + 1);
            // Threshold: prod(l, r) shifted by a random offset in [-MAX_VAL/2, MAX_VAL/2), clamped at 0.
            long long sum = std::max(0LL, t2.prod(l, r) + int(rng() % MAX_VAL) - MAX_VAL / 2);
            auto pred = [&](const S &x) { return x <= sum; };

            int r1 = t1.max_right(l, pred).index();
            int r2 = t2.max_right(l, pred);

            assert(r1 == r2);
        } else if (query_type == Q_min_left) {
            const int l = rng() % (N + 1);
            const int r = l + rng() % (N - l + 1);
            // Same randomized threshold as in the max_right case.
            long long sum = std::max(0LL, t2.prod(l, r) + int(rng() % MAX_VAL) - MAX_VAL / 2);
            auto pred = [&](const S &x) { return x <= sum; };

            int l1 = t1.min_left(r, pred).index();
            int l2 = t2.min_left(r, pred);

            assert(l1 == l2);
        } else {
            assert(false);
        }
    }
}

void test2() {
    std::mt19937 rng{ std::random_device{}() };
    Tree seq;
    const int n = 300, k = 20;

    std::vector<S> q(n * k);
    for (int i = 0; i < n * k; ++i) {
        q[i] = i % n;
    }
    std::shuffle(q.begin(), q.end(), rng);

    for (int v : q) {
        if (rng() % 2) {
            auto it = seq.lower_bound(v);
            seq.insert(it, v);
            int k = it.index();
            assert(k == 0 or seq[k - 1] < v);
            assert(k == seq.size() - 1 or seq[k + 1] >= v);
        } else {
            auto it = seq.upper_bound(v);
            seq.insert(it, v);
            int k = it.index();
            assert(k == 0 or seq[k - 1] <= v);
            assert(k == seq.size() - 1 or seq[k + 1] > v);
        }
    }

    for (int v : q) {
        auto it = seq.lower_bound(v);
        int k = it.index();
        seq.erase(it);
        assert(k == 0 or seq[k - 1] < v);
        assert(k == seq.size() or seq[k] >= v);
    }

    std::vector<S> sorted = q;
    std::sort(sorted.begin(), sorted.end());

    seq = sorted;

    assert(std::equal(sorted.begin(), sorted.end(), seq.begin()));
    assert(std::equal(sorted.rbegin(), sorted.rend(), seq.rbegin()));

    for (int i = 0; i < n * k; ++i) {
        assert(seq.begin()[i] == i / k);
    }

    {
        auto it = seq.begin();
        for (int q = 0; q < 100000; ++q) {
            int a = rng() % (n * k + 1);
            auto it2 = seq.begin() + a;
            it += it2 - it;
            if (a < n * k) {
                assert(*it == a / k);
            }
        }
    }

    for (int q = 0; q < 10000; ++q) {
        int a = rng() % (n * k + 1);
        int b = rng() % (n * k + 1);
        int d = b - a;
        assert((seq.begin() + a) + d == seq.begin() + b);
        assert(seq.begin() + a == (seq.begin() + b) - d);

        auto it1 = seq.begin() + a;
        auto it2 = seq.begin() + b;

        if (d > 0) {
            assert(not (it1 == it2));
            assert(it1 != it2);
            assert(not (it1 > it2));
            assert(not (it1 >= it2));
            assert(it1 < it2);
            assert(it1 <= it2);
        } else if (d < 0) {
            assert(not (it1 == it2));
            assert(it1 != it2);
            assert(not (it1 < it2));
            assert(not (it1 <= it2));
            assert(it1 > it2);
            assert(it1 >= it2);
        } else {
            assert(not (it1 != it2));
            assert(it1 == it2);
            assert(not (it1 > it2));
            assert(not (it1 < it2));
            assert(it1 <= it2);
            assert(it1 >= it2);
        }

        if (a != n * k and b != n * k) {
            assert(*it1 == a / k);
            assert(*it2 == b / k);

            it1 += d;

            assert(not (it1 != it2));
            assert(it1 == it2);
            assert(not (it1 < it2));
            assert(not (it1 > it2));
            assert(it1 <= it2);
            assert(it1 >= it2);
            assert(*it1 == *it2);
        }
    }

    std::vector<S> naive = sorted;
    assert(std::equal(naive.begin(), naive.end(), seq.begin()));

    // for (S& e : seq) --e; // Compile Error 
    // for (S& e : naive) --e;
    // assert(std::equal(naive.begin(), naive.end(), seq.begin()));
    // assert(std::equal(naive.rbegin(), naive.rend(), seq.rbegin()));

    const Tree& const_seq = const_cast<const Tree&>(seq);
    assert(std::equal(naive.begin(), naive.end(), const_seq.begin()));
    assert(std::equal(naive.rbegin(), naive.rend(), const_seq.rbegin()));

    for (int i = 0; i < n * k; ++i) {
        assert(const_seq[i] == naive[i]);
    }
}

// Runs both self-checks (any failing assert aborts first), then prints the fixed line
// that the dummy PROBLEM above is presumably judged against.
int main() {
    test();
    test2();
    std::cout << "Hello World" << std::endl;
}
#line 1 "test/src/datastructure/bbst/implicit_treap_segtree/dummy.test.cpp"
#define PROBLEM "https://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=ITP1_1_A"

#include <algorithm>
#include <iostream>
#include <limits>
#include <vector>

/**
 * Debug printer for vectors: writes "{a,b,c,}". Every element, including the
 * last, is followed by a comma; an empty vector prints "{}".
 */
template <typename T>
std::ostream& operator<<(std::ostream &out, const std::vector<T> &a) {
    out << '{';
    for (auto it = a.begin(); it != a.end(); ++it) {
        out << *it << ',';
    }
    out << '}';
    return out;
}

#line 1 "library/datastructure/bbst/implicit_treap_segtree.hpp"



#line 1 "library/datastructure/bbst/implicit_treap_base.hpp"



#line 5 "library/datastructure/bbst/implicit_treap_base.hpp"
#include <cassert>
#include <cstdint>
#include <optional>
#include <string>
#include <random>
#include <tuple>
#line 12 "library/datastructure/bbst/implicit_treap_base.hpp"
#include <utility>

namespace suisen::internal::implicit_treap {
    /**
     * CRTP base of an array-backed implicit treap: a randomized balanced BST
     * ordered by position.
     *
     * Node storage and addressing:
     * - All nodes of one (T, Derived) instantiation live in the static pool
     *   `_nodes` and are addressed by 32-bit indices (`node_pointer`).
     * - `null` (all bits set) denotes the empty tree.
     * - Freed indices are recycled through `_erased`.
     * - Besides the child links, each node stores `_prev` / `_next`, the
     *   indices of its in-order neighbours. Iterators use them for short moves.
     *
     * Structure:
     * - Heap order is a max-heap: a parent's priority is >= its children's.
     * - Every structural operation calls `node_type::update`, so a Derived class
     *   can shadow `update` to maintain extra aggregates.
     *
     * NOTE(review): the pool and the scratch buffers inside split / insert_impl
     * are static, so this type is not thread-safe and those two functions are
     * not reentrant.
     *
     * @tparam T       element type
     * @tparam Derived concrete node type (CRTP)
     */
    template <typename T, typename Derived>
    struct Node {
        using random_engine = std::mt19937;
        // Shared priority generator, seeded once from std::random_device.
        static inline random_engine rng{ std::random_device{}() };

        using priority_type = std::invoke_result_t<random_engine>;

        static priority_type random_priority() { return rng(); }

        using node_type = Derived;
        // Nodes are referred to by index into `_nodes`, so indices stay valid when the pool grows.
        using node_pointer = uint32_t;

        using size_type = uint32_t;

        using difference_type = int32_t;
        using value_type = T;
        using pointer = value_type*;
        using const_pointer = const value_type*;
        using reference = value_type&;
        using const_reference = const value_type&;

        // Node pool shared by every tree of this instantiation.
        static inline std::vector<node_type> _nodes{};
        // Indices of deleted nodes; create_node reuses these before growing `_nodes`.
        static inline std::vector<node_pointer> _erased{};

        // Sentinel index meaning "no node" / the empty tree.
        static constexpr node_pointer null = ~node_pointer(0);

        node_pointer _ch[2]{ null, null }; // children: [0] left, [1] right
        value_type _val;
        size_type _size;                   // number of nodes in this subtree
        priority_type _priority;

        // In-order predecessor / successor, maintained via link(). They are not reset by
        // split/erase, so the values at the two ends of a sequence may be stale.
        node_pointer _prev = null, _next = null;

        Node(const value_type val = {}): _val(val), _size(1), _priority(random_priority()) {}

        static void reserve(size_type capacity) { _nodes.reserve(capacity); }

        static bool is_null(node_pointer t) { return t == null; }
        static bool is_not_null(node_pointer t) { return not is_null(t); }

        // The returned references point into `_nodes`. A later create_node may reallocate
        // the pool and invalidate them.
        static node_type& node(node_pointer t) { return _nodes[t]; }
        static const node_type& const_node(node_pointer t) { return _nodes[t]; }

        static value_type& value(node_pointer t) { return node(t)._val; }
        // Replaces the stored value and returns the old one. Does not call update(),
        // so aggregates stored in t or its ancestors are not recomputed.
        static value_type set_value(node_pointer t, const value_type& new_val) { return std::exchange(value(t), new_val); }

        static bool empty(node_pointer t) { return is_null(t); }
        static size_type& size(node_pointer t) { return node(t)._size; }
        // Subtree size; 0 for the empty tree.
        static size_type safe_size(node_pointer t) { return empty(t) ? 0 : size(t); }

        static priority_type& priority(node_pointer t) { return node(t)._priority; }
        static void set_priority(node_pointer t, priority_type new_priority) { priority(t) = new_priority; }

        static node_pointer& prev(node_pointer t) { return node(t)._prev; }
        static node_pointer& next(node_pointer t) { return node(t)._next; }
        // Records that r immediately follows l in in-order sequence.
        static void link(node_pointer l, node_pointer r) { next(l) = r, prev(r) = l; }

        // Leftmost (first) node of the non-empty subtree t.
        static node_pointer min(node_pointer t) {
            while (true) {
                node_pointer nt = child0(t);
                if (is_null(nt)) return t;
                t = nt;
            }
        }
        // Rightmost (last) node of the non-empty subtree t.
        static node_pointer max(node_pointer t) {
            while (true) {
                node_pointer nt = child1(t);
                if (is_null(nt)) return t;
                t = nt;
            }
        }

        static node_pointer& child0(node_pointer t) { return node(t)._ch[0]; }
        static node_pointer& child1(node_pointer t) { return node(t)._ch[1]; }
        static node_pointer& child(node_pointer t, bool b) { return node(t)._ch[b]; }
        // The set_child* helpers install a new child and return the previous one.
        static node_pointer set_child0(node_pointer t, node_pointer cid) { return std::exchange(child0(t), cid); }
        static node_pointer set_child1(node_pointer t, node_pointer cid) { return std::exchange(child1(t), cid); }
        static node_pointer set_child(node_pointer t, bool b, node_pointer cid) { return std::exchange(child(t, b), cid); }

        // Recomputes the subtree size of t from its children and returns t.
        // Derived node types shadow this to also maintain their aggregates.
        static node_pointer update(node_pointer t) { // t : not null
            size(t) = safe_size(child0(t)) + safe_size(child1(t)) + 1;
            return t;
        }

        static node_pointer empty_node() { return null; }
        // Constructs a node_type from args, reusing a freed index when one is available.
        template <typename ...Args>
        static node_pointer create_node(Args &&...args) {
            if (_erased.size()) {
                node_pointer res = _erased.back();
                _erased.pop_back();
                node(res) = node_type(std::forward<Args>(args)...);
                return res;
            } else {
                node_pointer res = _nodes.size();
                _nodes.emplace_back(std::forward<Args>(args)...);
                return res;
            }
        }
        // Marks t as reusable. Its storage is left untouched until create_node reuses it.
        static void delete_node(node_pointer t) { _erased.push_back(t); }
        // Releases every node of subtree t (recursion depth = tree height).
        static void delete_tree(node_pointer t) {
            if (is_null(t)) return;
            delete_tree(child0(t));
            delete_tree(child1(t));
            delete_node(t);
        }

        /**
         * Builds a treap holding `std::vector<value_type>(args...)` in O(n).
         *
         * - Random priorities are arranged into a max-heap.
         * - The heap's array layout (children of slot i at 2i+1 and 2i+2) becomes the
         *   tree shape.
         * - An in-order walk of that layout assigns the elements in order and links
         *   consecutive nodes through _prev/_next.
         *
         * Returns the root (null for an empty input).
         */
        template <typename ...Args>
        static node_pointer build(Args &&... args) {
            std::vector<value_type> dat(std::forward<Args>(args)...);

            const size_t n = dat.size();

            std::vector<priority_type> priorities(n);
            std::generate(priorities.begin(), priorities.end(), random_priority);
            std::make_heap(priorities.begin(), priorities.end());

            // nodes[i] = node holding dat[i], needed to link in-order neighbours.
            std::vector<node_pointer> nodes(n);

            // Builds the subtree rooted at heap slot heap_index whose first element is
            // dat[dat_index_offset]; returns { subtree size, subtree root }.
            auto rec = [&](auto rec, size_t heap_index, size_t dat_index_offset) -> std::pair<size_t, node_pointer> {
                if (heap_index >= n) return { 0, null };
                auto [lsiz, lch] = rec(rec, 2 * heap_index + 1, dat_index_offset);
                dat_index_offset += lsiz;
                node_pointer root = create_node(std::move(dat[dat_index_offset]));
                nodes[dat_index_offset] = root;
                set_priority(root, priorities[heap_index]);
                // Elements are created in in-order sequence, so the predecessor already exists.
                if (dat_index_offset) {
                    link(nodes[dat_index_offset - 1], root);
                }
                dat_index_offset += 1;
                auto [rsiz, rch] = rec(rec, 2 * heap_index + 2, dat_index_offset);
                set_child0(root, lch);
                set_child1(root, rch);
                return { lsiz + 1 + rsiz, node_type::update(root) };
            };
            return rec(rec, 0, 0).second;
        }

        /**
         * Splits t into (first k elements, remaining elements). Requires k <= size(t);
         * t may be null only when k == 0.
         *
         * This is an iterative top-down split:
         * - `lp` / `rp` collect the nodes that end up in the left / right part.
         * - Each newly found node is stitched under the previous one.
         * - Sizes/aggregates are then recomputed bottom-up.
         *
         * The _prev/_next link across the cut is left as it was.
         */
        static std::pair<node_pointer, node_pointer> split(node_pointer t, size_type k) {
            if (k == 0) return { null, t };
            if (k == size(t)) return { t, null };

            static std::vector<node_pointer> lp{}, rp{};

            while (true) {
                if (const size_type lsiz = safe_size(child0(t)); k <= lsiz) {
                    // t (with its right subtree) belongs to the right part.
                    if (rp.size()) set_child0(rp.back(), t);
                    rp.push_back(t);
                    if (k == lsiz) {
                        // The cut falls between child0(t) and t: child0(t) closes the left part.
                        if (lp.size()) set_child1(lp.back(), child0(t));

                        node_pointer lt = set_child0(t, null), rt = null;

                        // Update bottom-up; the last node popped (the topmost) is each part's root.
                        while (lp.size()) node_type::update(lt = lp.back()), lp.pop_back();
                        while (rp.size()) node_type::update(rt = rp.back()), rp.pop_back();

                        return { lt, rt };
                    }
                    t = child0(t);
                } else {
                    // t (with its left subtree) belongs to the left part.
                    if (lp.size()) set_child1(lp.back(), t);
                    lp.push_back(t);
                    t = child1(t);
                    k -= lsiz + 1;
                }
            }
        }
        // Splits t into ([0, l), [l, r), [r, size(t))); requires l <= r <= size(t).
        static std::tuple<node_pointer, node_pointer, node_pointer> split(node_pointer t, size_type l, size_type r) {
            auto [tlm, tr] = split(t, r);
            auto [tl, tm] = split(tlm, l);
            return { tl, tm, tr };
        }

        // Merges two non-empty treaps where every element of tl precedes every element
        // of tr. The higher-priority root stays on top. When one side is attached as a
        // previously missing child, the boundary pair max(tl) -> min(tr) is linked.
        static node_pointer merge_impl(node_pointer tl, node_pointer tr) {
            if (priority(tl) < priority(tr)) {
                if (node_pointer tm = child0(tr); is_null(tm)) {
                    link(max(tl), tr);
                    set_child0(tr, tl);
                } else {
                    set_child0(tr, merge(tl, tm));
                }
                return node_type::update(tr);
            } else {
                if (node_pointer tm = child1(tl); is_null(tm)) {
                    link(tl, min(tr));
                    set_child1(tl, tr);
                } else {
                    set_child1(tl, merge(tm, tr));
                }
                return node_type::update(tl);
            }
        }
        // Concatenates tl and tr (either may be empty) and returns the new root.
        static node_pointer merge(node_pointer tl, node_pointer tr) {
            if (is_null(tl)) return tr;
            if (is_null(tr)) return tl;
            return merge_impl(tl, tr);
        }
        // Concatenates tl, tm and tr.
        static node_pointer merge(node_pointer tl, node_pointer tm, node_pointer tr) {
            return merge(merge(tl, tm), tr);
        }

        /**
         * Places the already-allocated node new_node at position k of t (k <= size(t))
         * and returns the new root.
         *
         * - Descends while the visited node's priority is >= new_node's, recording the
         *   path in `st`.
         * - When passing the node that becomes new_node's neighbour, links the two.
         * - At the first node with a lower priority (or an empty slot), that subtree is
         *   split at k and hung under new_node.
         * - Finally the path is updated bottom-up.
         */
        static node_pointer insert_impl(node_pointer t, size_type k, node_pointer new_node) {
            if (is_null(t)) return new_node;
            static std::vector<node_pointer> st;
            bool b = false; // which child of st.back() the current subtree is

            while (true) {
                if (is_null(t) or priority(new_node) > priority(t)) {
                    if (is_null(t)) {
                        t = new_node;
                    } else {
                        auto [tl, tr] = split(t, k);
                        if (is_not_null(tl)) link(max(tl), new_node);
                        if (is_not_null(tr)) link(new_node, min(tr));
                        set_child0(new_node, tl);
                        set_child1(new_node, tr);
                        t = node_type::update(new_node);
                    }
                    if (st.size()) {
                        set_child(st.back(), b, t);
                        do t = node_type::update(st.back()), st.pop_back(); while (st.size());
                    }
                    return t;
                } else {
                    if (const size_type lsiz = safe_size(child0(t)); k <= lsiz) {
                        // Inserting at the end of t's left subtree: t becomes new_node's successor.
                        if (k == lsiz) link(new_node, t);
                        st.push_back(t), b = false;
                        t = child0(t);
                    } else {
                        // Inserting at the front of t's right subtree: t becomes new_node's predecessor.
                        if (k == lsiz + 1) link(t, new_node);
                        st.push_back(t), b = true;
                        t = child1(t);
                        k -= lsiz + 1;
                    }
                }
            }
        }
        // Allocates a node from args, inserts it at position k of t and returns the new root.
        template <typename ...Args>
        static node_pointer insert(node_pointer t, size_type k, Args &&...args) {
            return insert_impl(t, k, create_node(std::forward<Args>(args)...));
        }

        /**
         * Removes element k (k < size(t)) from the non-empty treap t (recursive).
         * Returns { new root, removed value }.
         *
         * The removed node is pushed onto the free list before its children and value
         * are read. This relies on delete_node leaving the node's storage untouched.
         *
         * On the way back up, the _prev/_next links around the removed element are
         * re-established.
         */
        static std::pair<node_pointer, value_type> erase(node_pointer t, size_type k) {
            if (const size_type lsiz = safe_size(child0(t)); k == lsiz) {
                delete_node(t);
                return { merge(child0(t), child1(t)), std::move(value(t)) };
            } else if (k < lsiz) {
                auto [c0, v] = erase(child0(t), k);
                set_child0(t, c0);
                // The removed element was t's predecessor: t's new predecessor is max(c0).
                if (is_not_null(c0) and k == lsiz - 1) link(max(c0), t);
                return { node_type::update(t), std::move(v) };
            } else {
                auto [c1, v] = erase(child1(t), k - (lsiz + 1));
                set_child1(t, c1);
                // The removed element was t's successor: t's new successor is min(c1).
                if (is_not_null(c1) and k == lsiz + 1) link(t, min(c1));
                return { node_type::update(t), std::move(v) };
            }
        }

        // Moves the first k elements to the back (like std::rotate with middle = k). Returns the new root.
        static node_pointer rotate(node_pointer t, size_type k) {
            auto [tl, tr] = split(t, k);
            return merge(tr, tl);
        }
        // Rotates [l, r) so that element m becomes its first element; requires l <= m <= r <= size(t).
        static node_pointer rotate(node_pointer t, size_type l, size_type m, size_type r) {
            auto [tl, tm, tr] = split(t, l, r);
            return merge(tl, rotate(tm, m - l), tr);
        }

        // Replaces element k with f(current value) and updates every node on the root-to-k
        // path (recursive). Returns t.
        template <typename Func>
        static node_pointer set_update(node_pointer t, size_type k, const Func& f) {
            if (const size_type lsiz = safe_size(child0(t)); k == lsiz) {
                value_type& val = value(t);
                val = f(const_cast<const value_type&>(val));
            } else if (k < lsiz) {
                set_child0(t, set_update(child0(t), k, f));
            } else {
                set_child1(t, set_update(child1(t), k - (lsiz + 1), f));
            }
            return node_type::update(t);
        }

        // Returns a copy of the elements of t in order (recursive in-order walk).
        static std::vector<value_type> dump(node_pointer t) {
            std::vector<value_type> res;
            res.reserve(safe_size(t));
            auto rec = [&](auto rec, node_pointer t) -> void {
                if (is_null(t)) return;
                rec(rec, child0(t));
                res.push_back(value(t));
                rec(rec, child1(t));
            };
            rec(rec, t);
            return res;
        }

        /**
         * Random-access iterator over the treap rooted at `_root`.
         *
         * Position:
         * - `_index` is the primary state. It counts from the front, or from the back
         *   when `reversed`.
         * - Differences and all comparisons look only at `_index`, so comparing
         *   iterators of different trees is meaningless.
         *
         * Node cache:
         * - `_cur` caches the node at `_index`, or is null when unknown.
         * - It is unknown at end(), after jumps of 20 or more, and in freshly
         *   constructed iterators.
         * - When unknown, operator* resolves it by a root-to-node walk. Short moves
         *   follow _prev/_next instead.
         *
         * NOTE(review): the iterator stores the root index and a cached node; it
         * presumably must not be dereferenced after the tree is restructured unless
         * re-targeted (set_root). Confirm against the Derived wrapper.
         */
        template <bool reversed_, bool constant_>
        struct NodeIterator {
            static constexpr bool constant = constant_;
            static constexpr bool reversed = reversed_;

            friend Node;
            friend Derived;

            using difference_type = Node::difference_type;
            using value_type = Node::value_type;
            using pointer = std::conditional_t<constant, Node::const_pointer, Node::pointer>;
            using reference = std::conditional_t<constant, Node::const_reference, Node::reference>;
            using iterator_category = std::random_access_iterator_tag;

            NodeIterator(): NodeIterator(null) {}
            explicit NodeIterator(node_pointer root): NodeIterator(root, 0, null) {}
            // Converts between the constant and mutable flavours with the same direction.
            // NOTE(review): this works in BOTH directions, so a const iterator also converts
            // implicitly to a mutable one; consider restricting it to mutable -> const.
            NodeIterator(const NodeIterator<reversed, not constant>& it): NodeIterator(it._root, it._index, it._cur) {}

            // Resolves `_cur` by descending from the root when it is unknown (O(height)).
            // Dereferencing end() is undefined (it would read _nodes[null]).
            reference operator*() const {
                if (is_null(_cur) and _index != safe_size(_root)) {
                    _cur = _root;
                    // "child(_, reversed)" is the side that comes first in this iterator's direction.
                    for (size_type k = _index;;) {
                        if (size_type siz = safe_size(child(_cur, reversed)); k == siz) {
                            break;
                        } else if (k < siz) {
                            _cur = child(_cur, reversed);
                        } else {
                            _cur = child(_cur, not reversed);
                            k -= siz + 1;
                        }
                    }
                }
                return value(_cur);
            }
            reference operator[](difference_type k) const { return *((*this) + k); }

            NodeIterator& operator++() { return *this += 1; }
            NodeIterator& operator--() { return *this -= 1; }
            NodeIterator& operator+=(difference_type k) { return suc(+k), * this; }
            NodeIterator& operator-=(difference_type k) { return suc(-k), * this; }
            NodeIterator operator++(int) { NodeIterator res = *this; ++(*this); return res; }
            NodeIterator operator--(int) { NodeIterator res = *this; --(*this); return res; }
            friend NodeIterator operator+(NodeIterator it, difference_type k) { return it += k; }
            friend NodeIterator operator+(difference_type k, NodeIterator it) { return it += k; }
            friend NodeIterator operator-(NodeIterator it, difference_type k) { return it -= k; }

            friend difference_type operator-(const NodeIterator& lhs, const NodeIterator& rhs) { return lhs._index - rhs._index; }

            friend bool operator==(const NodeIterator& lhs, const NodeIterator& rhs) { return lhs._index == rhs._index; }
            friend bool operator!=(const NodeIterator& lhs, const NodeIterator& rhs) { return lhs._index != rhs._index; }
            friend bool operator<(const NodeIterator& lhs, const NodeIterator& rhs) { return lhs._index < rhs._index; }
            friend bool operator>(const NodeIterator& lhs, const NodeIterator& rhs) { return lhs._index > rhs._index; }
            friend bool operator<=(const NodeIterator& lhs, const NodeIterator& rhs) { return lhs._index <= rhs._index; }
            friend bool operator>=(const NodeIterator& lhs, const NodeIterator& rhs) { return lhs._index >= rhs._index; }

            static NodeIterator begin(node_pointer root) { return NodeIterator(root, 0, null); }
            static NodeIterator end(node_pointer root) { return NodeIterator(root, safe_size(root), null); }

            int size() const { return safe_size(_root); }
            int index() const { return _index; }
        private:
            node_pointer _root;
            size_type _index;
            mutable node_pointer _cur; // node at _index; null when it==end() or not yet resolved (only _index was updated)

            NodeIterator(node_pointer root, size_type index, node_pointer cur): _root(root), _index(index), _cur(cur) {}

            // Advances by k positions in this iterator's direction.
            // - Landing on end() or |k| >= 20 drops the cached node, to be resolved lazily.
            // - Otherwise the cache walks |k| steps along _next/_prev; reverse iterators
            //   walk the opposite links.
            void suc(difference_type k) {
                _index += k;
                if (_index == safe_size(_root) or std::abs(k) >= 20) _cur = null;
                if (is_null(_cur)) return;

                // true: walk along _next (forward in tree order); k is made non-negative.
                const bool positive = k < 0 ? (k = -k, reversed) : not reversed;

                if (positive) {
                    while (k-- > 0) _cur = next(_cur);
                } else {
                    while (k-- > 0) _cur = prev(_cur);
                }
            }

            node_pointer root() const { return _root; }
            // Re-targets the iterator at another root/index; `_cur` is kept as is.
            void set_root(node_pointer new_root, size_type new_index) { _root = new_root, _index = new_index; }

            node_pointer get_child0() const { return child0(_cur); }
            node_pointer get_child1() const { return child1(_cur); }

            // Descends from root t and returns the leftmost position whose iterator satisfies f.
            // f is expected to be monotone (false ... false true ... true). Returns end()
            // when no position satisfies f.
            // NOTE(review): the index bookkeeping uses child0/child1 and forward indices, so
            // this looks correct only for non-reversed iterators; confirm before using it
            // with reverse ones.
            template <typename Predicate>
            static NodeIterator binary_search(node_pointer t, const Predicate& f) {
                NodeIterator res(t, safe_size(t), null);
                if (is_null(t)) return res;

                NodeIterator it(t, safe_size(child0(t)), t);
                while (is_not_null(it._cur)) {
                    if (f(it)) {
                        res = it;
                        it._cur = it.get_child0();
                        it._index -= is_null(it._cur) ? 1 : safe_size(it.get_child1()) + 1;
                    } else {
                        it._cur = it.get_child1();
                        it._index += is_null(it._cur) ? 1 : safe_size(it.get_child0()) + 1;
                    }
                }
                return res;
            }

            // Forward gap index where insert(it, ...) puts a new element: just before *it
            // in this iterator's own direction (i.e. after it in forward order when reversed).
            size_type get_gap_index_left() const {
                if constexpr (reversed) return size() - index();
                else return index();
            }
            // Forward index of the element this iterator refers to.
            size_type get_element_index_left() const {
                if constexpr (reversed) return size() - index() - 1;
                else return index();
            }
        };
        using iterator = NodeIterator<false, false>;
        using reverse_iterator = NodeIterator<true, false>;
        using const_iterator = NodeIterator<false, true>;
        using const_reverse_iterator = NodeIterator<true, true>;

        // Trait: true exactly for the NodeIterator specializations above.
        template <typename>
        struct is_node_iterator: std::false_type {};
        template <bool reversed_, bool constant_>
        struct is_node_iterator<NodeIterator<reversed_, constant_>>: std::true_type {};
        template <typename X>
        static constexpr bool is_node_iterator_v = is_node_iterator<X>::value;

        static iterator begin(node_pointer t) { return iterator::begin(t); }
        static iterator end(node_pointer t) { return iterator::end(t); }
        static reverse_iterator rbegin(node_pointer t) { return reverse_iterator::begin(t); }
        static reverse_iterator rend(node_pointer t) { return reverse_iterator::end(t); }
        static const_iterator cbegin(node_pointer t) { return const_iterator::begin(t); }
        static const_iterator cend(node_pointer t) { return const_iterator::end(t); }
        static const_reverse_iterator crbegin(node_pointer t) { return const_reverse_iterator::begin(t); }
        static const_reverse_iterator crend(node_pointer t) { return const_reverse_iterator::end(t); }

        // Find the first element that satisfies the condition f : iterator -> { false, true }.
        // Returns an iterator of the requested type Iterator (end() when no element satisfies f).
        template <typename Iterator, typename Predicate, std::enable_if_t<is_node_iterator_v<Iterator>, std::nullptr_t> = nullptr>
        static Iterator binary_search(node_pointer t, const Predicate& f) {
            return Iterator::binary_search(t, f);
        }
        // comp(T t, U u) = (t < u)
        // First element x with not comp(x, target), i.e. x >= target for std::less; the
        // sequence must be sorted w.r.t. comp. comp has no default argument, so it must be passed.
        template <typename Iterator, typename U, typename Compare = std::less<>, std::enable_if_t<is_node_iterator_v<Iterator>, std::nullptr_t> = nullptr>
        static Iterator lower_bound(node_pointer t, const U& target, Compare comp) {
            return binary_search<Iterator>(t, [&](Iterator it) { return not comp(*it, target); });
        }
        // comp(T u, U t) = (u < t)
        // First element x with comp(target, x), i.e. x > target for std::less.
        template <typename Iterator, typename U, typename Compare = std::less<>, std::enable_if_t<is_node_iterator_v<Iterator>, std::nullptr_t> = nullptr>
        static Iterator upper_bound(node_pointer t, const U& target, Compare comp) {
            return binary_search<Iterator>(t, [&](Iterator it) { return comp(target, *it); });
        }

        // Inserts val at the gap before `it` (in the iterator's direction); returns the new root.
        template <typename Iterator, std::enable_if_t<is_node_iterator_v<Iterator>, std::nullptr_t> = nullptr>
        static node_pointer insert(Iterator it, const value_type& val) {
            return insert(it.root(), it.get_gap_index_left(), val);
        }
        // Removes the element `it` refers to; returns { new root, removed value }.
        template <typename Iterator, std::enable_if_t<is_node_iterator_v<Iterator>, std::nullptr_t> = nullptr>
        static std::pair<node_pointer, value_type> erase(Iterator it) {
            return erase(it.root(), it.get_element_index_left());
        }
        // Splits the tree at the gap before `it` (in the iterator's direction).
        template <typename Iterator, std::enable_if_t<is_node_iterator_v<Iterator>, std::nullptr_t> = nullptr>
        static std::pair<node_pointer, node_pointer> split(Iterator it) {
            return split(it.root(), it.get_gap_index_left());
        }
    };
} // namespace suisen::internal::implicit_treap


#line 5 "library/datastructure/bbst/implicit_treap_segtree.hpp"

namespace suisen {
    namespace internal::implicit_treap {
        // Implicit-treap node that additionally caches `_sum`, the op-product of all values in
        // its subtree (in sequence order). DynamicSegmentTree is built on top of this node type.
        // Presumably requires op to be associative with identity e(), as for a segment tree.
        template <typename T, T(*op)(T, T), T(*e)()>
        struct RangeProductNode: Node<T, RangeProductNode<T, op, e>> {
            using base = Node<T, RangeProductNode<T, op, e>>;
            using node_pointer = typename base::node_pointer;
            using value_type = typename base::value_type;

            // op(values in child0 subtree..., own value, values in child1 subtree...).
            value_type _sum;
            RangeProductNode(const value_type& val): base(val), _sum(val) {}

            // ----- override ----- //
            // Refreshes the base bookkeeping, then recomputes the cached product from the
            // children (assumed to be up to date already).
            static node_pointer update(node_pointer t) {
                base::update(t);
                prod_all(t) = op(op(safe_prod(base::child0(t)), base::value(t)), safe_prod(base::child1(t)));
                return t;
            }

            // ----- new features ----- //
            // Cached product of the subtree rooted at the non-null node t (by reference so that
            // update() can overwrite it).
            static value_type& prod_all(node_pointer t) {
                return base::node(t)._sum;
            }
            // Like prod_all, but yields e() for the null node.
            static value_type safe_prod(node_pointer t) {
                return base::is_null(t) ? e() : prod_all(t);
            }
            // op(A[l], ..., A[r-1]) of the sequence rooted at t. The tree is split into
            // [0, l), [l, r), [r, n), the middle product is read, and the parts are merged back.
            // Returns { new root, product }.
            static std::pair<node_pointer, value_type> prod(node_pointer t, size_t l, size_t r) {
                auto [tl, tm, tr] = base::split(t, l, r);
                value_type res = safe_prod(tm);
                return { base::merge(tl, tm, tr), std::move(res) };
            }
            // Replaces A[k] via base::set_update with the callback f (called with the current
            // value, judging by DynamicSegmentTree::apply). Returns the new root.
            template <typename Func>
            static node_pointer set(node_pointer t, size_t k, const Func& f) {
                return base::set_update(t, k, f);
            }

            using const_iterator = typename base::const_iterator;

            // Finds the largest r such that f(op(A[l], ..., A[r-1])) holds, by descending into the
            // suffix [l, n) while accumulating the product of everything left of the probe.
            // Requires f(e()) == true (asserted). Returns { new root, iterator at index r }.
            template <typename Predicate>
            static std::pair<node_pointer, const_iterator> max_right(node_pointer t, size_t l, const Predicate& f) {
                auto [tl, tr] = base::split(t, l);
                value_type sum = e();
                assert(f(sum));
                const_iterator it = base::template binary_search<const_iterator>(
                    tr, [&](const_iterator node_it) {
                        value_type nxt_sum = op(op(sum, safe_prod(node_it.get_child0())), *node_it);
                        if (f(nxt_sum)) {
                            // The whole left subtree and this element still satisfy f:
                            // absorb them and continue to the right.
                            sum = std::move(nxt_sum);
                            return false;
                        }
                        return true;
                    }
                );
                // Read the index (relative to tr) before merge() restructures the tree; as two
                // arguments of one call, their evaluation order would be unspecified.
                const auto pos = l + it.index();
                t = base::merge(tl, tr);
                it.set_root(t, pos);
                return { t, it };
            }
            // Finds the smallest l such that f(op(A[l], ..., A[r-1])) holds, by descending into the
            // prefix [0, r) while accumulating the product of everything right of the probe.
            // Requires f(e()) == true (asserted). Returns { new root, iterator at index l }.
            template <typename Predicate>
            static std::pair<node_pointer, const_iterator> min_left(node_pointer t, size_t r, const Predicate& f) {
                auto [tl, tr] = base::split(t, r);
                value_type sum = e();
                assert(f(sum));
                const_iterator it = base::template binary_search<const_iterator>(
                    tl, [&](const_iterator node_it) {
                        value_type nxt_sum = op(*node_it, op(safe_prod(node_it.get_child1()), sum));
                        if (f(nxt_sum)) {
                            // This element and its right subtree can be included:
                            // absorb them and continue to the left.
                            sum = std::move(nxt_sum);
                            return true;
                        }
                        return false;
                    }
                );
                // tl is a prefix, so the index within tl equals the index in the merged tree.
                // Read it before merge() restructures the nodes.
                const auto pos = it.index();
                t = base::merge(tl, tr);
                it.set_root(t, pos);
                return { t, it };
            }
        };
    }

    template <typename T, T(*op)(T, T), T(*e)()>
    class DynamicSegmentTree {
        using node_type = internal::implicit_treap::RangeProductNode<T, op, e>;
        using node_pointer = typename node_type::node_pointer;

        node_pointer _root;

        struct node_pointer_construct {};
        DynamicSegmentTree(node_pointer root, node_pointer_construct): _root(root) {}

    public:
        using value_type = typename node_type::value_type;

        DynamicSegmentTree(): _root(node_type::empty_node()) {}
        explicit DynamicSegmentTree(size_t n, const value_type& fill_value = {}): _root(node_type::build(n, fill_value)) {}
        template <typename U>
        DynamicSegmentTree(const std::vector<U>& dat) : _root(node_type::build(dat.begin(), dat.end())) {}

        void free() {
            node_type::delete_tree(_root);
            _root = node_type::empty_node();
        }
        void clear() { free(); }

        static void reserve(size_t capacity) { node_type::reserve(capacity); }

        bool empty() const { return node_type::empty(_root); }
        int size() const { return node_type::safe_size(_root); }

        const value_type& operator[](size_t k) const { return get(k); }
        const value_type& get(size_t k) const {
            assert(k < size_t(size()));
            return cbegin()[k];
        }
        const value_type& front() const { return *cbegin(); }
        const value_type& back() const { return *crbegin(); }

        void set(size_t k, const value_type& val) {
            assert(k < size_t(size()));
            _root = node_type::set(_root, k, [&](const value_type&) { return val; });
        }
        template <typename Func>
        void apply(size_t k, const Func& f) {
            assert(k < size_t(size()));
            _root = node_type::set(_root, k, [&](const value_type& val) { return f(val); });
        }

        value_type prod_all() const { return node_type::safe_prod(_root); }
        value_type prod(size_t l, size_t r) {
            value_type res;
            std::tie(_root, res) = node_type::prod(_root, l, r);
            return res;
        }

        void insert(size_t k, const value_type& val) {
            assert(k <= size_t(size()));
            _root = node_type::insert(_root, k, val);
        }
        void push_front(const value_type& val) { insert(0, val); }
        void push_back(const value_type& val) { insert(size(), val); }

        value_type erase(size_t k) {
            assert(k <= size_t(size()));
            value_type v;
            std::tie(_root, v) = node_type::erase(_root, k);
            return v;
        }
        value_type pop_front() { return erase(0); }
        value_type pop_back() { return erase(size() - 1); }

        // Split immediately before the k-th element.
        DynamicSegmentTree split(size_t k) {
            assert(k <= size_t(size()));
            node_pointer root_r;
            std::tie(_root, root_r) = node_type::split(_root, k);
            return DynamicSegmentTree(root_r, node_pointer_construct{});
        }

        void merge(DynamicSegmentTree r) { _root = node_type::merge(_root, r._root); }

        void rotate(size_t k) {
            assert(k <= size_t(size()));
            _root = node_type::rotate(_root, k);
        }
        void rotate(size_t l, size_t m, size_t r) {
            assert(l <= m and m <= r and r <= size_t(size()));
            _root = node_type::rotate(_root, l, m, r);
        }

        std::vector<value_type> dump() const { return node_type::dump(_root); }

        using iterator = typename node_type::const_iterator;
        using reverse_iterator = typename node_type::const_reverse_iterator;
        using const_iterator = typename node_type::const_iterator;
        using const_reverse_iterator = typename node_type::const_reverse_iterator;

        iterator begin() const { return cbegin(); }
        iterator end() const { return cend(); }
        reverse_iterator rbegin() const { return crbegin(); }
        reverse_iterator rend() const { return crend(); }
        const_iterator cbegin() const { return node_type::cbegin(_root); }
        const_iterator cend() const { return node_type::cend(_root); }
        const_reverse_iterator crbegin() const { return node_type::crbegin(_root); }
        const_reverse_iterator crend() const { return node_type::crend(_root); }

        // Returns the iterator with index max{ r | f(op(A[l], ..., A[r-1])) = true } (0 <= r <= size())
        template <typename Predicate>
        iterator max_right(size_t l, const Predicate& f) {
            assert(l <= size_t(size()));
            iterator it;
            std::tie(_root, it) = node_type::max_right(_root, l, f);
            return it;
        }
        // Returns the iterator with index min{ l | f(op(A[l], ..., A[r-1])) = true } (0 <= l <= size())
        template <typename Predicate>
        iterator min_left(size_t r, const Predicate& f) {
            assert(r <= size_t(size()));
            iterator it;
            std::tie(_root, it) = node_type::min_left(_root, r, f);
            return it;
        }

        // Find the first element that satisfies the condition f.
        // Returns { position, optional(value) }
        // Requirements: f(A[i]) must be monotonic
        template <typename Predicate>
        iterator binary_search(const Predicate& f) {
            return node_type::template binary_search<iterator>(_root, f);
        }
        // comp(T t, U u) = (t < u)
        // Requirements: sequence is sorted
        template <typename U, typename Compare = std::less<>>
        iterator lower_bound(const U& target, Compare comp = {}) {
            return node_type::template lower_bound<iterator>(_root, target, comp);
        }
        // comp(T u, U t) = (u < t)
        // Requirements: sequence is sorted
        template <typename U, typename Compare = std::less<>>
        iterator upper_bound(const U& target, Compare comp = {}) {
            return node_type::template upper_bound<iterator>(_root, target, comp);
        }
        // Find the first element that satisfies the condition f.
        // Returns { position, optional(value) }
        // Requirements: f(A[i]) must be monotonic
        template <typename Predicate>
        const_iterator binary_search(const Predicate& f) const {
            return node_type::template binary_search<const_iterator>(_root, f);
        }
        // comp(T t, U u) = (t < u)
        // Requirements: sequence is sorted
        template <typename U, typename Compare = std::less<>>
        const_iterator lower_bound(const U& target, Compare comp = {}) const {
            return node_type::template lower_bound<const_iterator>(_root, target, comp);
        }
        // comp(T u, U t) = (u < t)
        // Requirements: sequence is sorted
        template <typename U, typename Compare = std::less<>>
        const_iterator upper_bound(const U& target, Compare comp = {}) const {
            return node_type::template upper_bound<const_iterator>(_root, target, comp);
        }
 
        template <typename Iterator, std::enable_if_t<node_type::template is_node_iterator_v<Iterator>, std::nullptr_t> = nullptr>
        void insert(Iterator it, const value_type &val) {
            _root = node_type::insert(it, val);
        }
        template <typename Iterator, std::enable_if_t<node_type::template is_node_iterator_v<Iterator>, std::nullptr_t> = nullptr>
        value_type erase(Iterator it) {
            value_type erased;
            std::tie(_root, erased) = node_type::erase(it);
            return erased;
        }
        template <typename Iterator, std::enable_if_t<node_type::template is_node_iterator_v<Iterator>, std::nullptr_t> = nullptr>
        DynamicSegmentTree split(Iterator it) {
            node_pointer root_r;
            std::tie(_root, root_r) = node_type::split(it);
            return DynamicSegmentTree(root_r, node_pointer_construct{});
        }

        // handling internal nodes
        using internal_node = node_type;
        using internal_node_pointer = node_pointer;

        internal_node_pointer& root_node() { return _root; }
        const internal_node_pointer& root_node() const { return _root; }
        void set_root_node(internal_node_pointer new_root) { root_node() = new_root; }
    };
} // namespace suisen



#line 16 "test/src/datastructure/bbst/implicit_treap_segtree/dummy.test.cpp"

// Brute-force reference used to cross-check DynamicSegmentTree: the sequence is a plain
// std::vector and every query is computed directly (range queries walk the range linearly).
// Positions are 0-based int indices; preconditions are checked with assert.
template <typename T, T(*op)(T, T), T(*e)()>
struct NaiveSolutionForSegmentTree {
    // Empty sequence (valid because _n has an in-class initializer).
    NaiveSolutionForSegmentTree() = default;
    NaiveSolutionForSegmentTree(const std::vector<T> &dat) : _n(static_cast<int>(dat.size())), _dat(dat) {}

    // A[i]; requires 0 <= i < size.
    T get(int i) const {
        assert(0 <= i and i < _n);
        return _dat[i];
    }
    // A[i] <- val; requires 0 <= i < size.
    void set(int i, const T& val) {
        assert(0 <= i and i < _n);
        _dat[i] = val;
    }

    // Inserts val so that it becomes A[i]; requires 0 <= i <= size.
    void insert(int i, const T& val) {
        assert(0 <= i and i <= _n);
        ++_n;
        _dat.insert(_dat.begin() + i, val);
    }
    // Removes A[i]; requires 0 <= i < size.
    void erase(int i) {
        assert(0 <= i and i < _n);
        --_n;
        _dat.erase(_dat.begin() + i);
    }

    // op over the whole sequence; e() when empty.
    T prod_all() const {
        return prod(0, _n);
    }
    // op(A[l], ..., A[r-1]), folded from the left starting at e(); requires 0 <= l <= r <= size.
    T prod(int l, int r) const {
        assert(0 <= l and l <= r and r <= _n);
        T res = e();
        for (int i = l; i < r; ++i) res = op(res, _dat[i]);
        return res;
    }

    void reverse_all() {
        reverse(0, _n);
    }
    // Reverses A[l, r) in place; requires 0 <= l <= r <= size.
    void reverse(int l, int r) {
        assert(0 <= l and l <= r and r <= _n);
        for (--r; l < r; ++l, --r) {
            std::swap(_dat[l], _dat[r]);
        }
    }

    // Makes A[i] the first element, as std::rotate(begin, begin + i, end).
    void rotate(int i) {
        assert(0 <= i and i <= _n);
        std::rotate(_dat.begin(), _dat.begin() + i, _dat.end());
    }

    // Scans r = l, l+1, ... and returns the first r at which extending the prefix product
    // op(A[l], ..., A[r]) makes pred false; returns size if pred never fails.
    template <typename Pred>
    int max_right(int l, Pred &&pred) const {
        assert(0 <= l and l <= _n);
        T sum = e();
        for (int r = l; r < _n; ++r) {
            T next_sum = op(sum, _dat[r]);
            if (not pred(next_sum)) return r;
            sum = std::move(next_sum);
        }
        return _n;
    }

    // Scans leftwards from r and returns the first l at which extending the suffix product
    // op(A[l-1], ..., A[r-1]) makes pred false; returns 0 if pred never fails.
    template <typename Pred>
    int min_left(int r, Pred &&pred) const {
        assert(0 <= r and r <= _n);
        T sum = e();
        for (int l = r; l > 0; --l) {
            T next_sum = op(_dat[l - 1], sum);
            if (not pred(next_sum)) return l;
            sum = std::move(next_sum);
        }
        return 0;
    }

    // Copy of the current sequence.
    std::vector<T> dump() const { return _dat; }
private:
    // Number of elements, kept equal to _dat.size(). In-class initialized so that a
    // default-constructed object is a valid empty sequence rather than holding an
    // indeterminate count.
    int _n = 0;
    std::vector<T> _dat;
};

/**
 * Point Set Range Sum
 */

// Monoid under test: (long long, +, 0).
using S = long long;

// Associative operation: plain addition.
S op(S x, S y) {
    const S total = x + y;
    return total;
}
// Identity element of op (value-initialized long long, i.e. 0).
S e() {
    return S{};
}

// Brute-force reference and the structure under test, both over the (S, op, e) monoid.
using Naive = NaiveSolutionForSegmentTree<S, op, e>;
using Tree = suisen::DynamicSegmentTree<S, op, e>;

#line 115 "test/src/datastructure/bbst/implicit_treap_segtree/dummy.test.cpp"

// Operation codes for the randomized differential test, listed in numeric order.
// test() draws one per step as rng() % QueryTypeNum.
constexpr int Q_get = 0;
constexpr int Q_set = 1;
constexpr int Q_prod = 2;
constexpr int Q_prod_all = 3;
constexpr int Q_insert = 4;
constexpr int Q_erase = 5;
constexpr int Q_max_right = 6;
constexpr int Q_min_left = 7;
constexpr int Q_rotate = 8;
constexpr int QueryTypeNum = Q_rotate + 1;

// Randomized differential test: applies the same random operations to the treap-based
// DynamicSegmentTree (t1) and the brute-force NaiveSolutionForSegmentTree (t2), asserting
// that every query agrees and that both hold the same sequence at the end.
void test() {
    int N = 3000, Q = 3000, MAX_VAL = 1000000000;

    std::mt19937 rng{std::random_device{}()};

    std::vector<S> init(N);
    for (int i = 0; i < N; ++i) init[i] = rng() % MAX_VAL;
    
    Tree t1(init);
    Naive t2(init);

    for (int query = 0; query < Q; ++query) {
        const int query_type = rng() % QueryTypeNum;

        // get / set / erase pick an existing index with `rng() % N`, which would divide by
        // zero once erases have emptied the sequence; skip them in that case.
        const bool needs_element = query_type == Q_get or query_type == Q_set or query_type == Q_erase;
        if (needs_element and N == 0) continue;

        if (query_type == Q_get) {
            const int i = rng() % N;
            assert(t1.get(i) == t2.get(i));
        } else if (query_type == Q_set) {
            const int i = rng() % N;
            const S v = rng() % MAX_VAL;
            t1.set(i, v);
            t2.set(i, v);
        } else if (query_type == Q_insert) {
            const int i = rng() % (N + 1);
            const S v = rng() % MAX_VAL;
            t1.insert(i, v);
            t2.insert(i, v);
            ++N;
        } else if (query_type == Q_erase) {
            const int i = rng() % N;
            t1.erase(i);
            t2.erase(i);
            --N;
        } else if (query_type == Q_rotate) {
            const int i = rng() % (N + 1);
            t1.rotate(i);
            t2.rotate(i);
        } else if (query_type == Q_prod) {
            const int l = rng() % (N + 1);
            const int r = l + rng() % (N - l + 1);
            assert(t1.prod(l, r) == t2.prod(l, r));
        } else if (query_type == Q_prod_all) {
            assert(t1.prod_all() == t2.prod_all());
        } else if (query_type == Q_max_right) {
            const int l = rng() % (N + 1);
            const int r = l + rng() % (N - l + 1);
            // Threshold near an actual range sum so that both outcomes of pred are exercised.
            long long sum = std::max(0LL, t2.prod(l, r) + int(rng() % MAX_VAL) - MAX_VAL / 2);
            auto pred = [&](const S &x) { return x <= sum; };

            int r1 = t1.max_right(l, pred).index();
            int r2 = t2.max_right(l, pred);

            assert(r1 == r2);
        } else if (query_type == Q_min_left) {
            const int l = rng() % (N + 1);
            const int r = l + rng() % (N - l + 1);
            long long sum = std::max(0LL, t2.prod(l, r) + int(rng() % MAX_VAL) - MAX_VAL / 2);
            auto pred = [&](const S &x) { return x <= sum; };

            int l1 = t1.min_left(r, pred).index();
            int l2 = t2.min_left(r, pred);

            assert(l1 == l2);
        } else {
            assert(false);
        }
    }

    // Whole-sequence comparison: catches divergences that no sampled query happened to touch.
    const std::vector<S> expected = t2.dump();
    assert(t1.size() == int(expected.size()));
    assert(std::equal(expected.begin(), expected.end(), t1.begin()));
}

// Exercises DynamicSegmentTree as an ordered sequence through its iterator API: sorted
// insertion via lower_bound / upper_bound, erase by iterator, random-access iterator
// arithmetic and comparisons, and forward / reverse / const iteration.
void test2() {
    std::mt19937 rng{ std::random_device{}() };
    Tree seq;
    const int n = 300, k = 20;

    // Every value in [0, n) occurs exactly k times, in random order.
    std::vector<S> q(n * k);
    for (int i = 0; i < n * k; ++i) {
        q[i] = i % n;
    }
    std::shuffle(q.begin(), q.end(), rng);

    // Sorted insertion, either before equal elements (lower_bound) or after them (upper_bound).
    // The position is named `pos` so it does not shadow the multiplicity `k`.
    for (int v : q) {
        if (rng() % 2) {
            auto it = seq.lower_bound(v);
            seq.insert(it, v);
            const int pos = it.index();
            assert(pos == 0 or seq[pos - 1] < v);
            assert(pos == seq.size() - 1 or seq[pos + 1] >= v);
        } else {
            auto it = seq.upper_bound(v);
            seq.insert(it, v);
            const int pos = it.index();
            assert(pos == 0 or seq[pos - 1] <= v);
            assert(pos == seq.size() - 1 or seq[pos + 1] > v);
        }
    }

    // Remove every value again via lower_bound + erase(iterator); order must be preserved.
    for (int v : q) {
        auto it = seq.lower_bound(v);
        const int pos = it.index();
        seq.erase(it);
        assert(pos == 0 or seq[pos - 1] < v);
        assert(pos == seq.size() or seq[pos] >= v);
    }

    std::vector<S> sorted = q;
    std::sort(sorted.begin(), sorted.end());

    // seq is empty at this point (everything was erased above), so rebuilding it through the
    // implicit vector constructor does not leak nodes.
    seq = sorted;

    assert(std::equal(sorted.begin(), sorted.end(), seq.begin()));
    assert(std::equal(sorted.rbegin(), sorted.rend(), seq.rbegin()));

    for (int i = 0; i < n * k; ++i) {
        assert(seq.begin()[i] == i / k);
    }

    // Random jumps of a single iterator: it += (target - it) must land on target.
    {
        auto it = seq.begin();
        for (int trial = 0; trial < 100000; ++trial) {
            int a = rng() % (n * k + 1);
            auto it2 = seq.begin() + a;
            it += it2 - it;
            if (a < n * k) {
                assert(*it == a / k);
            }
        }
    }

    // Iterator arithmetic and all six comparison operators on random pairs of positions.
    for (int trial = 0; trial < 10000; ++trial) {
        int a = rng() % (n * k + 1);
        int b = rng() % (n * k + 1);
        int d = b - a;
        assert((seq.begin() + a) + d == seq.begin() + b);
        assert(seq.begin() + a == (seq.begin() + b) - d);

        auto it1 = seq.begin() + a;
        auto it2 = seq.begin() + b;

        if (d > 0) {
            assert(not (it1 == it2));
            assert(it1 != it2);
            assert(not (it1 > it2));
            assert(not (it1 >= it2));
            assert(it1 < it2);
            assert(it1 <= it2);
        } else if (d < 0) {
            assert(not (it1 == it2));
            assert(it1 != it2);
            assert(not (it1 < it2));
            assert(not (it1 <= it2));
            assert(it1 > it2);
            assert(it1 >= it2);
        } else {
            assert(not (it1 != it2));
            assert(it1 == it2);
            assert(not (it1 > it2));
            assert(not (it1 < it2));
            assert(it1 <= it2);
            assert(it1 >= it2);
        }

        if (a != n * k and b != n * k) {
            assert(*it1 == a / k);
            assert(*it2 == b / k);

            it1 += d;

            assert(not (it1 != it2));
            assert(it1 == it2);
            assert(not (it1 < it2));
            assert(not (it1 > it2));
            assert(it1 <= it2);
            assert(it1 >= it2);
            assert(*it1 == *it2);
        }
    }

    std::vector<S> naive = sorted;
    assert(std::equal(naive.begin(), naive.end(), seq.begin()));

    // for (S& e : seq) --e; // Compile Error
    // for (S& e : naive) --e;
    // assert(std::equal(naive.begin(), naive.end(), seq.begin()));
    // assert(std::equal(naive.rbegin(), naive.rend(), seq.rbegin()));

    // Iteration and indexing through a const reference (binding adds const implicitly).
    const Tree& const_seq = seq;
    assert(std::equal(naive.begin(), naive.end(), const_seq.begin()));
    assert(std::equal(naive.rbegin(), naive.rend(), const_seq.rbegin()));

    for (int i = 0; i < n * k; ++i) {
        assert(const_seq[i] == naive[i]);
    }
}

int main() {
    // Randomized self-checks first; with assertions enabled (NDEBUG not defined), any
    // mismatch aborts before anything is printed.
    test();
    test2();

    // The judged problem (ITP1_1_A) only checks for this exact output line.
    std::cout << "Hello World" << std::endl;
    return 0;
}
Back to top page