Skip to content
Rain Hu's Workspace
Go back

[C++] Segment Tree

Rain Hu

線段樹 Segment Tree

簡介

建樹 build

// Recursively fills tree node p, which covers arr[s..t] (inclusive).
// A leaf (s == t) stores SegmentItem(arr[s], 1); an internal node stores the
// combination (SegmentItem's operator+) of its children 2p and 2p+1.
void build(int s, int t, int p, const vector<int>& arr){
    if (s != t){
        const int mid = s + ((t - s) >> 1);   // midpoint without computing s + t
        const int lc = p * 2;                 // left child index
        const int rc = lc + 1;                // right child index
        build(s, mid, lc, arr);
        build(mid + 1, t, rc, arr);
        tree[p] = tree[lc] + tree[rc];        // push_up
    } else {
        tree[p] = SegmentItem(arr[s], 1);
    }
}

查詢 query

// Returns the aggregate of the query range [l, r] restricted to node p, which
// covers [s, t]. Expects l <= r and [l, r] to lie within the range of the
// top-level call; each recursive call only descends into halves that overlap
// [l, r], so a node is never visited without overlap.
SegmentItem find(int l, int r, int s, int t, int p){
    if (l <= s && t <= r){
        return tree[p];                                  // node fully inside the query
    }
    int m = s + ((t - s) >> 1);                          // midpoint without computing s + t
    if (r <= m) return find(l, r, s, m, p*2);            // query lies in the left half only
    if (l > m) return find(l, r, m+1, t, p*2+1);         // query lies in the right half only
    return find(l, r, s, m, p*2) + find(l, r, m+1, t, p*2+1);  // query straddles m
}

zkw 線段樹

結構

建樹 build

class Tree {
private:
    vector<int> arr;
    int n, m;   // n: number of stored values; m: number of leaf slots (leaves are arr[m..2m-1])
public: 
    // Sizes the zkw tree for nums; call build(nums) afterwards to fill it.
    Tree (const vector<int>& nums){
        n = nums.size();
        // Smallest power of two strictly GREATER than n (the loop runs while m <= n).
        // Keeping m > n means an open-interval walk that touches leaf
        // m + right + 1 (with right <= n - 1) stays inside arr.
        for (m = 1; m <= n; m <<= 1);
        arr.assign(2*m, 0);     // 2m slots: internal nodes at [1, m), leaves at [m, 2m); index 0 is unused
    }
    // Copies nums into the leaves from left to right, then recomputes every
    // internal node bottom-up as the sum of its two children.
    // nums must have at least n elements (normally the vector given to the constructor).
    void build(const vector<int>& nums){
        for (int i = 0; i < n; i++) {
            arr[i+m] = nums[i];
        }
        for (int i = m-1; i; i--){
            arr[i] = arr[i<<1] + arr[i<<1|1];   // range-sum push-up; use max/min instead for those aggregates
        }
    }
};
    // Push-up for a range-sum tree: parent = left child + right child
    arr[i] = arr[i<<1] + arr[i<<1|1];
    // Push-up for a range-maximum tree: parent = max of the two children
    arr[i] = max(arr[i<<1], arr[i<<1|1]);
    // Push-up for a range-minimum tree: parent = min of the two children
    arr[i] = min(arr[i<<1], arr[i<<1|1]);

更新 update

// Point update for the zkw sum tree: sets element i to val by adding the
// difference to leaf m+i and to every ancestor up to the root (index 1).
// arr and m are the tree's node array and leaf offset.
void update(int i, int val){
    int diff = val - arr[m+i];   // new value val minus the old value arr[m+i]
    for (i += m; i; i >>= 1){
        arr[i] += diff;
    }
}

查詢 query

// Closed-interval zkw sum query over [left, right].
// Both cursors start on their leaves and climb; a cursor that is a right
// child (lo) or a left child (hi) is fully inside the range, so it is added
// and the cursor steps inward before moving to the parent level.
int query(int left, int right){
    int total = 0;
    int lo = m + left;
    int hi = m + right;
    while (lo <= hi) {
        if (lo & 1) {
            total += arr[lo];
            ++lo;
        }
        if ((hi & 1) == 0) {
            total += arr[hi];
            --hi;
        }
        lo >>= 1;
        hi >>= 1;
    }
    return total;
}
// Open-interval zkw sum query over [left, right].
// lo and hi start one leaf outside each end of the range and climb until
// they are siblings. Whenever lo is a left child its right sibling lies inside
// the range, and whenever hi is a right child its left sibling does.
int query(int left, int right){
    int lo = m + left - 1;
    int hi = m + right + 1;
    int total = 0;
    while ((lo ^ hi) != 1) {
        if ((lo & 1) == 0) total += arr[lo ^ 1];
        if ((hi & 1) == 1) total += arr[hi ^ 1];
        lo >>= 1;
        hi >>= 1;
    }
    return total;
}

區間修改

// Range add on a zkw sum tree with lazy marks (no push-down): adds diff to
// every element of [left, right].
//   arr[p]  - sum of node p's range, including marks at p and below it
//   mark[p] - pending add that applies to every leaf under p
// left/right are open-interval cursors: they start one leaf outside the range
// and climb until they become siblings.
void update(int left, int right, int diff){
    int len = 1;              // number of leaves under a node on the current level
    int cntl = 0, cntr = 0;   // updated elements so far beneath the current left / right cursor
    left += m-1;
    right += m+1;
    // The increments must ASSIGN (>>=, <<=); a bare `left >> 1` never terminates.
    for (; left^right^1; left >>= 1, right >>= 1, len <<= 1){
        // A cursor's own sum must include the elements already updated beneath it.
        arr[left] += cntl * diff;
        arr[right] += cntr * diff;
        if (~left & 1) {            // left cursor is a left child: its right sibling is inside
            arr[left^1] += diff * len;
            mark[left^1] += diff;
            cntl += len;
        }
        if (right & 1) {            // right cursor is a right child: its left sibling is inside
            arr[right^1] += diff * len;
            mark[right^1] += diff;
            cntr += len;
        }
    }
    // Propagate to the remaining ancestors. Loop on `right`, not `left`: when the
    // range starts at 0 the left cursor reaches 0 while right is still the root (1),
    // and the root must still receive cntr * diff. For left > 0 both cursors
    // reach 0 together, so this matches looping on `left`.
    for (; right; left >>= 1, right >>= 1){
        arr[left] += cntl * diff;
        arr[right] += cntr * diff;
    }
}
// Range-sum query over [left, right] on the zkw sum tree with lazy marks
// (arr[p] includes marks at p and below; mark[p] applies to every leaf under p).
// Marks found on the climbing cursors are added for every element already
// collected beneath them.
int query(int left, int right){
    int sum = 0;
    int len = 1;              // number of leaves under a node on the current level
    int cntl = 0, cntr = 0;   // elements collected so far beneath the left / right cursor
    left += m - 1;
    right += m + 1;
    // The increments must ASSIGN (>>=, <<=); `len << 1` alone leaves len at 1.
    for (; left^right^1; left >>= 1, right >>= 1, len <<= 1){
        sum += cntl * mark[left] + cntr * mark[right];
        if (~left & 1) sum += arr[left^1], cntl += len;   // right sibling of left cursor is inside
        if (right & 1) sum += arr[right^1], cntr += len;  // left sibling of right cursor is inside
    }
    // Marks on the remaining ancestors also cover every collected element.
    for (; left; left >>= 1, right >>= 1){
        sum += cntl * mark[left] + cntr * mark[right];
    }
    return sum;
}
// Range add for a zkw max tree with lazy marks (no push-down): adds d to every
// element of [l, r].
//   tree[p] - maximum over node p's range, including marks at p and below it
//   mark[p] - pending add that applies to every leaf under p
// N is the leaf offset. l/r are open-interval cursors that start one leaf
// outside the range and climb until they are siblings.
// NOTE(review): with l == 0 the left cursor starts at N-1 (not a leaf), the
// first loop ends with l == 0, and the second loop never runs, so tree[1] may
// be left stale. Confirm whether anything reads the root directly.
void update(int l, int r, int d) {
    for (l += N - 1, r += N + 1; l ^ r ^ 1; l >>= 1, r >>= 1)
    {
        // Recompute internal cursor nodes (index < N) from their children plus
        // their own mark; both assignments are one statement joined by a comma.
        if (l < N) tree[l] = max(tree[l << 1], tree[l << 1 | 1]) + mark[l],
                    tree[r] = max(tree[r << 1], tree[r << 1 | 1]) + mark[r];
        // A left-child cursor's right sibling (and a right-child cursor's left
        // sibling) lies fully inside the range: apply d and record it in mark.
        if (~l & 1) tree[l ^ 1] += d, mark[l ^ 1] += d;
        if (r & 1) tree[r ^ 1] += d, mark[r ^ 1] += d;
    }
    // Keep recomputing the remaining ancestors on the way to the root.
    for (; l; l >>= 1, r >>= 1)
        if (l < N) tree[l] = max(tree[l << 1], tree[l << 1 | 1]) + mark[l],
                    tree[r] = max(tree[r << 1], tree[r << 1 | 1]) + mark[r];
};
// Range-max query over [l, r] on the zkw max tree with lazy marks.
// maxl/maxr hold the best value collected beneath the left/right cursor; a
// mark on a cursor (or on a later ancestor) applies to everything collected
// beneath it, so it is added to the running maximum.
// NOTE(review): INF and cmax are defined elsewhere; cmax is presumably
// "a = max(a, b)", and -INF plus the accumulated marks is assumed not to
// overflow int. Confirm both.
int query(int l, int r) {
    int maxl = -INF, maxr = -INF;
    for (l += N - 1, r += N + 1; l ^ r ^ 1; l >>= 1, r >>= 1)
    {
        // Apply the cursors' marks to the values collected beneath them so far.
        maxl += mark[l], maxr += mark[r];
        // Fold in each cursor's in-range sibling.
        if (~l & 1) cmax(maxl, tree[l ^ 1]);
        if (r & 1) cmax(maxr, tree[r ^ 1]);
    }
    // Apply the marks of the remaining ancestors.
    for (; l; l >>= 1, r >>= 1)
        maxl += mark[l], maxr += mark[r];
    return max(maxl, maxr);
};

LeetCode 307 範例

  1. TreeNode 變形
class NumArray {
    // Node of a pointer-based segment tree holding the sum of nums[begin..end]
    // (inclusive). Leaves have begin == end and null children.
    class SegTree {
    public:
        int val = 0;
        int begin = 0, end = 0;
        SegTree* left = nullptr;    // initialised so leaves never carry garbage pointers
        SegTree* right = nullptr;
        SegTree(int v):val(v) {}
        SegTree(int v, int b, int e):val(v), begin(b), end(e) {}
        SegTree(int v, int b, int e, SegTree* l, SegTree* r)
            :val(v), begin(b), end(e), left(l), right(r) {}
    };
    
    SegTree* root = nullptr;   // owned by this object; released in ~NumArray
    
    // Builds the subtree for nums[b..e]; returns nullptr for an empty range.
    SegTree* build(vector<int>& nums, int b, int e){
        if (e < b) return nullptr;
        if (b == e) return new SegTree(nums[b], b, b);
        int mid = b + (e-b)/2;
        SegTree* left = build(nums, b, mid);
        SegTree* right = build(nums, mid+1, e);
        return new SegTree(left->val + right->val, b, e, left, right);
    }

    // Frees a subtree in post-order (children before parent).
    static void destroy(SegTree* node){
        if (node == nullptr) return;
        destroy(node->left);
        destroy(node->right);
        delete node;
    }
    
    // Sets nums[index] = val at the leaf and refreshes every sum on the path back up.
    void update(SegTree* node, int index, int val){
        if (node->begin == index && node->end == index){
            node->val = val;
        } else {
            int mid = node->begin + (node->end - node->begin)/2;
            if (index <= mid){
                update(node->left, index, val);
            } else {
                update(node->right, index, val);
            }
            node->val = node->left->val + node->right->val;
        }
    }
    // Sum of nums[left..right]; [left, right] must lie within [node->begin, node->end].
    int query(SegTree* node, int left, int right){
        if (node->begin == left && node->end == right){
            return node->val;
        }
        int mid = node->begin + (node->end - node->begin)/2;
        if (right <= mid){
            return query(node->left, left, right);
        } else if (left > mid){
            return query(node->right, left, right);
        }
        return query(node->left, left, mid) + query(node->right, mid+1, right);
    }
    
public:
    NumArray(vector<int>& nums) {
        // Cast before subtracting: nums.size()-1 on an empty vector wraps around as size_t.
        root = build(nums, 0, static_cast<int>(nums.size()) - 1);
    }
    ~NumArray() { destroy(root); }
    // The tree owns raw node pointers; a copy would share and then double-free them.
    NumArray(const NumArray&) = delete;
    NumArray& operator=(const NumArray&) = delete;
    
    void update(int index, int val) {
        if (root != nullptr) update(root, index, val);
    }
    
    int sumRange(int left, int right) {
        return root != nullptr ? query(root, left, right) : 0;
    }
};
  2. zkw 線段樹
class NumArray {
    // zkw (bottom-up, iterative) segment tree for range sums.
    // Leaves live at arr[m .. m+n-1]; internal node i (1 <= i < m) holds
    // arr[2i] + arr[2i+1]; arr[0] is unused.
    class SegTree {
        vector<int> arr;
        int m, n;
    public:
        SegTree(vector<int>& nums) {
            n = nums.size();
            for (m = 1; m < n; m <<= 1);   // smallest power of two >= n (at least 1)
            build(nums);
        }
        // Loads nums into the leaves and recomputes every internal node bottom-up.
        void build(vector<int>& nums) {
            arr.assign(2*m, 0);
            for (int i = 0; i < n; ++i) arr[m+i] = nums[i];
            for (int i = m-1; i; --i) arr[i] = arr[i<<1] + arr[i<<1|1];
        }
        // Sets nums[index] = val by adding the difference along the leaf-to-root path.
        void update(int index, int val) {
            int diff = val - arr[m+index];
            for (index += m; index; index >>= 1) arr[index] += diff;
        }
        // Sum of nums[left..right] (closed interval).
        int query(int left, int right) {
            int sum = 0;
            for (int i = left+m, j = right+m; i <= j; i >>= 1, j >>= 1){
                if (i & 1) sum += arr[i++];       // i is a right child: take it and step right
                if (!(j & 1)) sum += arr[j--];    // j is a left child: take it and step left
            }
            return sum;
        }
    };
public:
    SegTree* root;   // owned by this object; released in ~NumArray
    NumArray(vector<int>& nums) {
        root = new SegTree(nums);
    }
    ~NumArray() { delete root; }
    // root is an owning raw pointer; a copy would double-delete it.
    NumArray(const NumArray&) = delete;
    NumArray& operator=(const NumArray&) = delete;
    
    void update(int index, int val) {
        root->update(index, val);
    }
    
    int sumRange(int left, int right) {
        return root->query(left, right);
    }
};

BIT (binary indexed tree)

class NumArray {
public:
    // Binary indexed (Fenwick) tree over nums, 1-based: bit[i] holds the sum
    // of the (i & -i) elements ending at position i.
    class Bit {
    public:
        vector<int> bit;
        int n;
        Bit(vector<int>& nums) {
            n = nums.size();
            bit.assign(n+1, 0);
            for (int i = 0; i < n; i++){
                build(i+1, nums[i]);
            }
        }
        // Adds val at 1-based position index (and to every node covering it).
        void build(int index, int val) {
            while (index <= n){
                bit[index] += val;
                index = next(index);
            }
        }
        // Next node whose range covers index.
        int next(int index) {
            return index + (index & -index);
        } 
        // Node covering the block just before index's block.
        int parent(int index) {
            return index - (index & -index);
        }
        // Prefix sum of the first `index` elements (1-based, inclusive).
        int getSum(int index) {
            int sum = 0;
            while (index){
                sum += bit[index];
                index = parent(index);
            }
            return sum;
        }
    };
    Bit* bit;   // owned by this object; released in ~NumArray
    NumArray(vector<int>& nums) {
        bit = new Bit(nums);
    }
    ~NumArray() { delete bit; }
    // bit is an owning raw pointer; a copy would double-delete it.
    NumArray(const NumArray&) = delete;
    NumArray& operator=(const NumArray&) = delete;
    
    // Sets nums[index] = val by adding the difference from the current value.
    void update(int index, int val) {
        int diff = val - sumRange(index, index);
        bit->build(index+1, diff);
    }
    
    // Sum of nums[left..right] as a difference of two prefix sums.
    int sumRange(int left, int right) {
        return bit->getSum(right+1) - bit->getSum(left);
    }
};

Share this post on:

Previous
[Problem] Version Query
Next
[C++] stringstream 類範例 - split 與 concat