Given a non-empty binary search tree and a target value, find k values in the BST that are closest to the target.
Note: The given target value is a floating-point number. You may assume k is always valid, that is: k ≤ total nodes. You are guaranteed to have only one unique set of k values in the BST that are closest to the target. Follow up: Assume that the BST is balanced, could you solve it in less than O(n) runtime (where n = total nodes)?
Hint:
Consider implementing these two helper functions:
getPredecessor(N), which returns the next smaller node to N.
getSuccessor(N), which returns the next larger node to N.
Try assuming that each node has a parent pointer; it makes the problem much easier.
Without a parent pointer, we just need to keep track of the path from the root to the current node using a stack.
You would need two stacks, one to track the path when finding the predecessor node and one when finding the successor node.
Solution
/**
* Definition for a binary tree node.
* public class TreeNode {
* int val;
* TreeNode left;
* TreeNode right;
* TreeNode(int x) { val = x; }
* }
*/
public class Solution {
    /**
     * Finds the k values stored in a binary search tree that are closest to
     * {@code target}.
     *
     * <p>The method keeps a window of nodes that are consecutive in in-order
     * sequence. The window starts at the root, is grown to k nodes by taking
     * whichever in-order neighbour is closer to the target, and is then slid
     * toward the target for as long as the incoming neighbour is closer than
     * the node leaving at the opposite end.
     *
     * @param root   root of the binary search tree
     * @param target value to get close to
     * @param k      number of values wanted
     * @return the selected values, smallest first; an empty list when
     *         {@code root} is null or {@code k <= 0}; every value in the tree
     *         when the tree has fewer than k nodes
     */
    public List<Integer> closestKValues(TreeNode root, double target, int k) {
        List<Integer> res = new ArrayList<Integer>();
        if (root == null || k <= 0) {
            // Nothing to select; previously this path failed or returned the root value.
            return res;
        }

        LinkedList<TreeNode> window = new LinkedList<TreeNode>();
        // Each stack holds the root-to-node path used to walk in one direction.
        Stack<TreeNode> leftStack = new Stack<TreeNode>();
        Stack<TreeNode> rightStack = new Stack<TreeNode>();
        window.addLast(root);
        leftStack.push(root);
        rightStack.push(root);

        // Step 1: grow the window around the root until it holds k nodes.
        while (window.size() < k) {
            TreeNode left = prev(leftStack, window.peekFirst());
            TreeNode right = next(rightStack, window.peekLast());
            if (left == null && right == null) {
                // Tree exhausted on both sides: fewer than k nodes exist.
                break;
            }
            // Ties go to the smaller (left) neighbour.
            boolean takeLeft = right == null
                    || (left != null && dist(left.val, target) <= dist(right.val, target));
            if (takeLeft) {
                window.addFirst(left);
            } else {
                window.addLast(right);
            }
        }

        // Step 2: slide the window in one direction only, either left or right.
        TreeNode left = prev(leftStack, window.peekFirst());
        TreeNode right = next(rightStack, window.peekLast());
        if (left != null && dist(left.val, target) < dist(window.peekLast().val, target)) {
            while (left != null && dist(left.val, target) < dist(window.peekLast().val, target)) {
                window.addFirst(left);
                window.pollLast();
                left = prev(leftStack, window.peekFirst());
            }
        } else {
            while (right != null && dist(right.val, target) < dist(window.peekFirst().val, target)) {
                window.addLast(right);
                window.pollFirst();
                right = next(rightStack, window.peekLast());
            }
        }

        for (TreeNode node : window) {
            res.add(node.val);
        }
        return res;
    }

    /**
     * Absolute distance between a node value and the target.
     *
     * @param val    node value
     * @param target target value
     * @return {@code |val - target|} computed in double precision
     */
    public double dist(int val, double target) {
        return Math.abs((double) val - target);
    }

    /**
     * Returns the in-order predecessor of {@code cur}, moving {@code stack} so
     * that the predecessor ends up on top.
     *
     * <p>If the top of the stack is not {@code cur}, an earlier call already
     * advanced the stack and its result was not consumed by the caller, so the
     * same node is returned again without moving.
     *
     * @param stack path from the root down to {@code cur}, or a stack already
     *              advanced by a previous call
     * @param cur   node whose predecessor is wanted
     * @return the next smaller node, or null when there is none
     */
    public TreeNode prev(Stack<TreeNode> stack, TreeNode cur) {
        if (stack.isEmpty()) {
            return null;
        }
        if (stack.peek() != cur) {
            return stack.peek();
        }
        if (cur.left != null) {
            // Predecessor is the rightmost node of the left subtree.
            stack.push(cur.left);
            while (stack.peek().right != null) {
                stack.push(stack.peek().right);
            }
        } else {
            // Climb until we leave a right child; that parent is the predecessor.
            stack.pop();
            while (!stack.isEmpty() && stack.peek().left == cur) {
                cur = stack.pop();
            }
        }
        return stack.isEmpty() ? null : stack.peek();
    }

    /**
     * Returns the in-order successor of {@code cur}, moving {@code stack} so
     * that the successor ends up on top.
     *
     * <p>If the top of the stack is not {@code cur}, an earlier call already
     * advanced the stack and its result was not consumed by the caller, so the
     * same node is returned again without moving.
     *
     * @param stack path from the root down to {@code cur}, or a stack already
     *              advanced by a previous call
     * @param cur   node whose successor is wanted
     * @return the next larger node, or null when there is none
     */
    public TreeNode next(Stack<TreeNode> stack, TreeNode cur) {
        if (stack.isEmpty()) {
            return null;
        }
        if (stack.peek() != cur) {
            return stack.peek();
        }
        if (cur.right != null) {
            // Successor is the leftmost node of the right subtree.
            stack.push(cur.right);
            while (stack.peek().left != null) {
                stack.push(stack.peek().left);
            }
        } else {
            // Climb until we leave a left child; that parent is the successor.
            stack.pop();
            while (!stack.isEmpty() && stack.peek().right == cur) {
                cur = stack.pop();
            }
        }
        return stack.isEmpty() ? null : stack.peek();
    }
}