Closest Binary Search Tree Value II

Tree, Stack, Quick Select
https://leetcode.com/problems/closest-binary-search-tree-value-ii

# Definition for a Binary Tree Node

class TreeNode:
    """Binary tree node: a value plus links to the left and right children."""

    def __init__(self, val=0, left=None, right=None):
        # All three fields default to an isolated node holding 0.
        self.val, self.left, self.right = val, left, right
1
2
3
4
5

# Solution

Let $h$ be the height and $n$ be the number of nodes of the tree.

# Sort

Complexity

time: $O(n \log n)$ ($O(n)$ to build the inorder traversal and $O(n \log n)$ to sort it)
space: $O(n)$ (to store nums)

def closestKValues(self, root: TreeNode, target: float, k: int) -> List[int]:
    """Return the k values of the tree that are closest to target.

    Every value is collected with an inorder traversal, the values are
    ordered by absolute distance to ``target`` and the first ``k`` are kept.

    Args:
        root: Root node of the tree, or ``None`` for an empty tree.
        target: Value the results should be close to.
        k: Number of values to return; values below 1 yield an empty list.

    Returns:
        At most ``k`` values ordered from closest to farthest. Values at
        the same distance keep their inorder order (the sort is stable).
    """
    # Iterative inorder traversal with an explicit stack. Appending to one
    # list keeps the traversal linear (concatenating child lists copies
    # elements at every level) and avoids the recursion limit on deep trees.
    nums = []
    stack = []
    node = root
    while node or stack:
        while node:
            stack.append(node)
            node = node.left
        node = stack.pop()
        nums.append(node.val)
        node = node.right

    nums.sort(key=lambda x: abs(x - target))
    # Clamp k so a negative value cannot slice from the end of the list.
    return nums[:max(k, 0)]
1
2
3
4
5
6
7

# Heap

Complexity

time: $O(n \log k)$ (push $n$ elements into a heap of size $k$)
space: $O(k + h)$ (heap of $k$ elements and a recursion stack of tree height $h$)

from heapq import heappush, heappop
def closestKValues(self, root: TreeNode, target: float, k: int) -> List[int]:
    """Return the k tree values closest to target using a bounded heap.

    The tree is walked in inorder. Each value is pushed keyed by its
    negated distance to ``target``, so the heap root is always the farthest
    kept value; it is evicted whenever the heap grows beyond ``k`` entries.

    Args:
        root: Root node of the tree, or a falsy value for an empty tree.
        target: Value the results should be close to.
        k: Maximum number of values to keep.

    Returns:
        The kept values in the heap's internal order (not sorted by
        distance).
    """
    nearest = []

    def visit(node: TreeNode):
        if node:
            visit(node.left)
            distance = abs(node.val - target)
            # Negate the distance: heapq is a min-heap, so the farthest
            # value sits at the root and is the one popped on overflow.
            heappush(nearest, (-distance, node.val))
            if k < len(nearest):
                heappop(nearest)
            visit(node.right)

    visit(root)
    return [entry[1] for entry in nearest]
1
2
3
4
5
6
7
8
9
10
11
12
13
14

# Inorder Predecessor & Successor

The time complexity of getPredecessor and getSuccessor is amortized O(1) because each call is just one step of an inorder traversal. If each call cost O(log n), then a full inorder traversal would take O(n log n).

Complexity

time: $O(n)$ (worst case: $O(n \log n)$; getPredecessor and getSuccessor take amortized $O(1)$, but a single call can take $O(\log n)$ in the worst case)
space: $O(n)$

def closestKValues(self, root: TreeNode, target: float, k: int) -> List[int]:
    """Return the k tree values closest to target with two inorder iterators.

    One stack iterates the values below ``target`` in descending order
    (predecessors), the other iterates the values at or above ``target`` in
    ascending order (successors). The closer of the two current candidates
    is taken repeatedly, as in a two-way merge.

    Args:
        root: Root node of a binary search tree, or ``None`` if empty.
        target: Value the results should be close to.
        k: Number of values to return; values below 1 yield an empty list.

    Returns:
        At most ``k`` values ordered from closest to farthest. On a tie the
        value below ``target`` comes first. Fewer than ``k`` values are
        returned when the tree holds fewer than ``k`` nodes.
    """
    res = []
    preStack = []  # nodes with val < target on the search path
    sucStack = []  # nodes with val >= target on the search path

    # Walk the search path for target, splitting its nodes into the two
    # stacks; the top of each stack is the nearest value on its side.
    node = root
    while node:
        if node.val < target:
            preStack.append(node)
            node = node.right
        else:
            sucStack.append(node)
            node = node.left

    def getPredecessor(stack):
        """Pop the next smaller node, or return None when exhausted."""
        if not stack:
            return None
        pre = stack.pop()
        # The next smaller values live on the right spine of the left child.
        p = pre.left
        while p:
            stack.append(p)
            p = p.right
        return pre

    def getSuccessor(stack):
        """Pop the next larger node, or return None when exhausted."""
        if not stack:
            return None
        suc = stack.pop()
        # The next larger values live on the left spine of the right child.
        p = suc.right
        while p:
            stack.append(p)
            p = p.left
        return suc

    pre = getPredecessor(preStack)
    suc = getSuccessor(sucStack)

    # Stop as soon as k values are collected or both sides are exhausted,
    # so a negative or oversized k cannot loop needlessly or forever.
    while len(res) < k and (pre is not None or suc is not None):
        take_pre = suc is None or (
            pre is not None
            and abs(pre.val - target) <= abs(suc.val - target)
        )
        if take_pre:
            res.append(pre.val)
            pre = getPredecessor(preStack)
        else:
            res.append(suc.val)
            suc = getSuccessor(sucStack)
    return res
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49

# Quick Select

The following partition uses Lomuto's scheme.

Complexity

time: $O(n)$ (worst case: $O(n^2)$)
space: $O(n)$

from random import randint
def closestKValues(self, root: TreeNode, target: float, k: int) -> List[int]:
    """Return the k tree values closest to target using quickselect.

    All values are collected with an inorder traversal and then partially
    ordered by distance to ``target`` (Lomuto partition, random pivot)
    until the first ``k`` positions hold the ``k`` closest values.

    Args:
        root: Root node of the tree, or ``None`` for an empty tree.
        target: Value the results should be close to.
        k: Number of values to return; values below 1 yield an empty list.

    Returns:
        At most ``k`` closest values in no particular order. When several
        values tie for the last slot, which of them is returned depends on
        the random pivots.
    """
    # Iterative inorder traversal: linear time and no recursion limit.
    nums = []
    stack = []
    node = root
    while node or stack:
        while node:
            stack.append(node)
            node = node.left
        node = stack.pop()
        nums.append(node.val)
        node = node.right

    def dist(idx):
        """Distance from the value stored at idx to target."""
        return abs(nums[idx] - target)

    def partition(pivot_idx, left, right):
        """Partition nums[left..right] around the pivot; return its index."""
        pivot_dist = dist(pivot_idx)

        # 1. move the pivot to the end
        nums[right], nums[pivot_idx] = nums[pivot_idx], nums[right]
        store_idx = left

        # 2. move strictly closer elements to the left
        for i in range(left, right):
            if dist(i) < pivot_dist:
                nums[i], nums[store_idx] = nums[store_idx], nums[i]
                store_idx += 1

        # 3. move the pivot to its final place
        nums[right], nums[store_idx] = nums[store_idx], nums[right]
        return store_idx

    # Selection is only needed when k splits the list; otherwise the answer
    # is empty or the whole list.
    if 0 < k < len(nums):
        left, right = 0, len(nums) - 1
        # Invariant: left <= k <= right, everything before left is no
        # farther than the window, everything after right is no closer.
        while left < right:
            true_idx = partition(randint(left, right), left, right)
            if true_idx == k:
                break
            # The pivot is final, so exclude it: the window shrinks on
            # every round and the loop terminates even with tied distances.
            if true_idx < k:
                left = true_idx + 1
            else:
                right = true_idx - 1

    return nums[:max(k, 0)]
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52