LeetCode 493 - Reverse Pairs

作者 QIFAN 日期 2017-03-22
LeetCode 493 - Reverse Pairs

原题链接: 493. Reverse Pairs


题干

给定一个数组 nums ,reverse pair 的定义是数组中的对于两个位置(i, j)i < jnums[i] > nums[j] * 2 ,返回 nums 中 reverse pair 的数量

思路

受 tag 启发用二叉树做。树节点包括了值 val ,频数 freq ,右边子树的总频数 rightNum ,以及左右子树节点 left right 。遍历数字逐个建立二叉搜索树,对每一个元素 num ,先将 num 插入树中,往右走时更新节点 rightNum++ 。然后搜索比 target = num * 2 大的节点树 res , 搜索的过程中:
往左走时,res += update(左节点) + freq + rightNum
往右走时,res += update(右节点)
相等时,res += rightNum

这个思路做的出来,但是在最差情况数组递增或递减时时间复杂度为 $O(N^2)$ ,因为树是完全平的。我最后也没有找到改进的方法。

后来看了 discussion 的提示用类似 merge sort 的方法。这就很厉害了。两个将要 merge 的区块分别是有序的,而且可以保证左边的区块里数字的原位置一定在右边所有数字的前边,所以和 mergesort 唯一的一点改动就是在 merge 前遍历右边数组元素,并加总左边数字中除以二大于该数字的数量的和(此处可用二叉查找)。
两个分别有序的数组 a 和 b ,对于每一个 b[i],计算满足 a[j] > b[i]/2 的对数。

最后时间复杂度为 $O(NlogN)$

坑:Integer Overflow

代码:

int res = 0;
int MAX = Integer.MAX_VALUE / 2;
int MIN = Integer.MIN_VALUE / 2;
public int reversePairs(int[] nums) {
if (nums.length < 2) {
return 0;
}
int[] temp = new int[nums.length];
mergeSort(nums, temp, 0, nums.length - 1);
return res;
}
private void mergeSort(int[] nums, int[] temp, int lo, int hi) {
if (lo >= hi) {
return;
}
int mid = lo + (hi - lo) / 2;
mergeSort(nums, temp, lo, mid);
mergeSort(nums, temp, mid + 1, hi);
calcReversePair(nums, lo, mid, hi);
int li = lo, i = lo, ri = mid + 1;
while (li <= mid && ri <= hi) {
while (li <= mid && nums[li] <= nums[ri]) {
temp[i++] = nums[li++];
}
while (ri <= hi && nums[li] > nums[ri]) {
temp[i++] = nums[ri++];
}
}
while(li <= mid) {
temp[i++] = nums[li++];
}
while (ri <= hi) {
temp[i++] = nums[ri++];
}
while (i > lo) {
nums[--i] = temp[i];
}
}
private void calcReversePair(int[] nums, int lo, int mid, int hi) {
int lastPos = lo;
for (int i = mid + 1; i <= hi; i++) {
if (nums[i] > MAX) {
continue;
} else if (nums[i] < MIN) {
res += mid - lo + 1;
} else {
int target = nums[i] * 2;
int offset = binarySearchGreater(nums, lastPos, mid, target);
if (offset > mid) {
break;
}
res += mid - offset + 1;
lastPos = offset;
}
}
}
private int binarySearchGreater(int[] nums, int lo, int hi, int target) {
if (lo > hi) {
return hi + 1;
}
int mid = lo + (hi - lo) / 2;
if (nums[mid] <= target) {
if (mid < nums.length - 1 && nums[mid + 1] > target) {
return mid + 1;
}
return binarySearchGreater(nums, mid + 1, hi, target);
}
return binarySearchGreater(nums, lo, mid - 1, target);
}