Count of Smaller Numbers After Self

2016/2/14 posted in  LeetCode comments

算法描述

You are given an integer array nums and you have to return a new counts array. The counts array has the property where counts[i] is the number of smaller elements to the right of nums[i].

Example:

Given nums = [5, 2, 6, 1]
To the right of 5 there are 2 smaller elements (2 and 1).
To the right of 2 there is only 1 smaller element (1).
To the right of 6 there is 1 smaller element (1).
To the right of 1 there is 0 smaller element.

Return the array [2, 1, 1, 0].

题目大意

给定一个数组nums,要求返回一个数组counts,其中counts数组中的第i个元素是在nums数组中位于nums[i]的右边且比nums[i]小的元素的个数。

解题思路

题目要求是计算位于数组元素右端且小于该数组元素的元素数目,因此最直观的想法是从右到左遍历数组,同时维护一个辅助数组,辅助数组下标从小到大分别代表排序后的nums中的元素。我们用辅助数组来记录已经遇到过的元素的出现次数。这样求小于该数组元素的元素数目只需要对辅助数组求前缀和即可。由于普通数组求前缀和时间复杂度为O(n),因此我们可以考虑使用树状数组来作为辅助数组,这样可以降低时间复杂度到O(log n)。关于树状数组的介绍,可以参考Fenwick Tree

另一种思路,我们可以用二分查找树(Binary Search Tree),树的每一个结点带有一个count值,表示该结点元素出现次数。我们在从右至左遍历nums数组时,同时更新BST树,counts的值可以由搜索路径所经过的结点的counts值之和得到。

解法I:Fenwick Tree

class Solution(object):
    def countSmaller(self, nums):
        """
        :type nums: List[int]
        :rtype: List[int]
        """
        result = [0] * len(nums)
        order = {}
        for i, num in enumerate(sorted(set(nums))):
            order[num] = i + 1

        tree = FenwickTree(len(nums))
        for i in xrange(len(nums) - 1, -1, -1):
            result[i] = tree.sum(order[nums[i]] - 1)
            tree.add(order[nums[i]], 1)
        return result


class FenwickTree(object):
    def __init__(self, n):
        self.sum_array = [0] * (n + 1)
        self.n = n

    def lowbit(self, x):
        return x & -x

    def add(self, x, val):
        while x <= self.n:
            self.sum_array[x] += val
            x += self.lowbit(x)

    def sum(self, x):
        ret = 0
        while x > 0:
            ret += self.sum_array[x]
            x -= self.lowbit(x)
        return ret

解法II:Binary Search Tree

class Solution(object):
    def __init__(self):
        self.root = None

    def countSmaller(self, nums):
        """
        :type nums: List[int]
        :rtype: List[int]
        """
        counts = [0] * len(nums)
        for i in range(len(nums) - 1, -1, -1):
            counts[i] = self.traverse(nums[i])
        return counts

    def traverse(self, val):
        if not self.root:
            self.root = Node(val)
            return 0
        count = 0
        p = self.root
        while p:
            if val < p.val:
                p.small_cnt += 1
                if not p.left:
                    p.left = Node(val)
                    break
                p = p.left
            elif val > p.val:
                count += p.small_cnt + p.count
                if not p.right:
                    p.right = Node(val)
                    break
                p = p.right
            else:
                count += p.small_cnt
                p.count += 1
                break
        return count


class Node(object):
    def __init__(self, val):
        self.small_cnt = 0
        self.count = 1
        self.val = val
        self.left = None
        self.right = None