Given an integer array nums, return the number of range sums that lie in [lower, upper] inclusive. Range sum S(i, j) is defined as the sum of the elements in nums between indices i and j (i ≤ j), inclusive.
Note: A naive $O(n^2)$ algorithm is trivial. You MUST do better than that.
Example: Given nums = [-2, 5, -1], lower = -2, upper = 2, return 3. The three ranges are [0, 0], [2, 2], and [0, 2], and their respective sums are -2, -1, and 2.
Solution
/**
 * Counts the contiguous subarrays whose sums fall within an inclusive range, in
 * O(n log n) time, using a merge sort over prefix sums.
 *
 * <p>With prefix sums {@code P[0] = 0} and {@code P[t + 1] = P[t] + nums[t]}, the range
 * sum {@code S(i, j)} equals {@code P[j + 1] - P[i]}. The task therefore reduces to
 * counting index pairs {@code a < b} with {@code lower <= P[b] - P[a] <= upper}. Each
 * merge step counts the pairs that straddle its split point while both halves are sorted.
 */
public class Solution {

    /**
     * Returns the number of range sums of {@code nums} that lie in {@code [lower, upper]}.
     *
     * @param nums  the input array; {@code null} or empty yields 0
     * @param lower inclusive lower bound on a range sum
     * @param upper inclusive upper bound on a range sum; if {@code upper < lower} no sum
     *              can qualify and 0 is returned
     * @return the number of pairs {@code (i, j)} with {@code i <= j} and
     *         {@code lower <= S(i, j) <= upper}; the count is assumed to fit in an {@code int}
     */
    public int countRangeSum(int[] nums, int lower, int upper) {
        if (nums == null || nums.length == 0 || lower > upper) {
            return 0;
        }
        // Prefix sums are kept as long because sums of int elements can exceed the int range.
        long[] prefix = new long[nums.length + 1];
        for (int t = 0; t < nums.length; t++) {
            prefix[t + 1] = prefix[t] + nums[t];
        }
        return splitMerge(prefix, lower, upper, 0, prefix.length - 1);
    }

    /**
     * Sorts {@code nums[i..j]} ascending in place and returns the number of index pairs
     * {@code a < b} inside {@code [i, j]} with {@code lower <= nums[b] - nums[a] <= upper}.
     * Pairs are taken with respect to the order of the sub-range on entry.
     *
     * @param nums  values to process; the sub-range {@code [i, j]} is reordered by this call
     * @param lower inclusive lower bound on {@code nums[b] - nums[a]}
     * @param upper inclusive upper bound on {@code nums[b] - nums[a]}
     * @param i     first index of the sub-range (inclusive)
     * @param j     last index of the sub-range (inclusive)
     * @return the pair count for the sub-range; 0 when {@code i >= j} or {@code lower > upper}
     */
    public int splitMerge(long[] nums, int lower, int upper, int i, int j) {
        if (i >= j) {
            return 0;
        }
        int mid = (i + j) >>> 1;
        int res = splitMerge(nums, lower, upper, i, mid)
                + splitMerge(nums, lower, upper, mid + 1, j);
        // Both halves are now sorted, so the cross pairs can be counted with two sweeping pointers.
        res += countCrossPairs(nums, lower, upper, i, mid, j);
        merge(nums, i, mid, j);
        return res;
    }

    /**
     * Counts pairs {@code (a, b)} with {@code a} in {@code [i, mid]} and {@code b} in
     * {@code [mid + 1, j]} such that {@code lower <= nums[b] - nums[a] <= upper}.
     * Both halves must already be sorted ascending.
     */
    private static int countCrossPairs(long[] nums, int lower, int upper, int i, int mid, int j) {
        int count = 0;
        // As nums[a] increases, both thresholds only move right, so the pointers never reset.
        int low = mid + 1;
        int high = mid + 1;
        for (int a = i; a <= mid; a++) {
            // low: first b with nums[b] - nums[a] >= lower
            while (low <= j && nums[low] - nums[a] < lower) {
                low++;
            }
            // high: first b with nums[b] - nums[a] > upper
            while (high <= j && nums[high] - nums[a] <= upper) {
                high++;
            }
            // When lower > upper, high can end up before low. That window is empty and must
            // contribute nothing rather than a negative amount.
            if (high > low) {
                count += high - low;
            }
        }
        return count;
    }

    /** Merges the sorted runs {@code nums[i..mid]} and {@code nums[mid + 1..j]} into {@code nums[i..j]}. */
    private static void merge(long[] nums, int i, int mid, int j) {
        long[] sorted = new long[j - i + 1];
        int left = i;
        int right = mid + 1;
        int out = 0;
        while (left <= mid && right <= j) {
            sorted[out++] = nums[left] <= nums[right] ? nums[left++] : nums[right++];
        }
        while (left <= mid) {
            sorted[out++] = nums[left++];
        }
        while (right <= j) {
            sorted[out++] = nums[right++];
        }
        System.arraycopy(sorted, 0, nums, i, sorted.length);
    }
}