PROBLEM DESCRIPTION
Given the root of a complete binary tree, return the number of nodes in the tree.
According to Wikipedia, in a complete binary tree every level, except possibly the last, is completely filled, and all nodes in the last level are as far left as possible. The last level h can contain between 1 and 2^h nodes inclusive.
Design an algorithm that runs in less than O(n) time complexity.
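To make the definition concrete: a complete tree with n nodes is exactly the first n positions of a level-order numbering. In that numbering, position i has its children at 2i + 1 and 2i + 2, which is the usual array-heap layout. The sketch below builds such a tree and counts the nodes on one level. The TreeNode shape is assumed to mirror LeetCode's. The names CompleteTreeDemo, build and nodesOnLevel are hypothetical and chosen only for illustration.

```java
// Assumed node shape, mirroring LeetCode's TreeNode.
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}

class CompleteTreeDemo {
    // Hypothetical helper: builds the complete binary tree with n nodes.
    // Level-order position i gets children at positions 2i + 1 and 2i + 2.
    static TreeNode build(int n) {
        if (n <= 0) return null;
        TreeNode[] nodes = new TreeNode[n];
        for (int i = 0; i < n; i++) nodes[i] = new TreeNode(i);
        for (int i = 0; i < n; i++) {
            if (2 * i + 1 < n) nodes[i].left = nodes[2 * i + 1];
            if (2 * i + 2 < n) nodes[i].right = nodes[2 * i + 2];
        }
        return nodes[0];
    }

    // Hypothetical helper: counts the nodes on level h (the root is level 0).
    static int nodesOnLevel(TreeNode root, int h) {
        if (root == null) return 0;
        if (h == 0) return 1;
        return nodesOnLevel(root.left, h - 1) + nodesOnLevel(root.right, h - 1);
    }
}
```

Take n = 6. Level h = 2 holds positions 3, 4 and 5, which are three of the possible 2^2 = 4 nodes, packed to the left. So nodesOnLevel(build(6), 2) returns 3.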
SOLUTION
THE SIMPLE WAY
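Traverse the tree in any order (preorder, inorder, postorder or level order) and count every node visited. This takes O(n) time, so it does not meet the requirement above. The recursion also uses O(log n) stack space, because a complete tree has height about log2 n.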
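// TreeNode is the standard LeetCode definition (int val; TreeNode left, right;).
class Solution {
    public int countNodes(TreeNode root) {
        // An empty subtree contributes no nodes
        if (root == null) return 0;
        // Count this node plus every node in both subtrees
        return 1 + countNodes(root.left) + countNodes(root.right);
    }
}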
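Any traversal works here. As a sketch of a non-recursive alternative (not from the original post), a level-order walk with a queue counts the same nodes without recursion. It still visits every node, so it remains O(n) time, and the queue can hold O(n) nodes. The class name IterativeCount is hypothetical, and TreeNode is assumed to have LeetCode's shape.

```java
import java.util.ArrayDeque;
import java.util.Deque;

// Assumed node shape, mirroring LeetCode's TreeNode.
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}

class IterativeCount {
    // Level-order (BFS) traversal that counts every node it dequeues.
    static int countNodes(TreeNode root) {
        if (root == null) return 0;
        Deque<TreeNode> queue = new ArrayDeque<>(); // ArrayDeque rejects nulls, so only real children are added
        queue.add(root);
        int count = 0;
        while (!queue.isEmpty()) {
            TreeNode t = queue.poll();
            count++;
            if (t.left != null) queue.add(t.left);
            if (t.right != null) queue.add(t.right);
        }
        return count;
    }
}
```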
OPTIMIZED (USING BINARY SEARCH)
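Time Complexity: O(d^2) = O(log^2 N)
where d is the depth of the tree (the number of edges on the leftmost path) and N is the number of nodes. The binary search makes at most d + 1 probes over the 2^d positions of the last level, and each probe walks d edges down from the root. That gives O(d^2) work. Since d = floor(log2 N), this is O(log^2 N), which is below O(N). Extra space is O(1).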
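The same count can be written out step by step. This is a worked derivation; the N = 10^6 figure used afterwards is only an illustrative size.

```latex
\begin{aligned}
T_{\text{depth}} &= d \\
T_{\text{search}} &\le (d + 1)\, d
  \quad \text{(at most } d + 1 \text{ probes over } 2^d \text{ positions, } d \text{ edges each)} \\
T(N) &= T_{\text{depth}} + T_{\text{search}} \le d^2 + 2d = O(d^2),
  \qquad d = \lfloor \log_2 N \rfloor \\
&\Rightarrow\; T(N) = O(\log^2 N)
\end{aligned}
```

For N = 10^6, d = 19, because 2^19 = 524,288 <= 10^6 < 2^20. The bound is then 19^2 + 2 * 19 = 399 node steps, compared with 10^6 visits for the simple traversal.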
// TreeNode is the standard LeetCode definition (int val; TreeNode left, right;).
class Solution {
    public int countNodes(TreeNode root) {
        // Check if the root is null, meaning an empty tree
        if (root == null) return 0;
        // Depth d = number of edges on the leftmost path from the root to a leaf
        int depth = calculateDepth(root);
        // Levels 0..d-1 are full, so they hold 2^d - 1 nodes
        int nodesBeforeLastLevel = (1 << depth) - 1;
        // The last level has room for 2^d nodes (it holds between 1 and 2^d of them)
        int nodesLastLevel = 1 << depth;
        // Binary search boundaries over the last-level positions 0..2^d - 1
        int left = 0;
        int right = nodesLastLevel - 1;
        int presentTillIndex = -1; // Highest last-level index found to be present
        // Present positions form a prefix 0..k, so binary search finds its end
        while (left <= right) {
            int pivot = left + (right - left) / 2;
            // Check if the leaf node at index 'pivot' is present
            if (isLeafPresent(root, depth, nodesLastLevel, pivot)) {
                presentTillIndex = Math.max(presentTillIndex, pivot); // Update answer
                left = pivot + 1;  // Search for higher indices
            } else {
                right = pivot - 1; // Search for lower indices
            }
        }
        // Full levels plus the present leaves at indices 0..presentTillIndex
        return nodesBeforeLastLevel + (presentTillIndex + 1);
    }

    // Check whether the last-level position 'idx' holds a node
    public boolean isLeafPresent(TreeNode root, int depth, int nodesLastLevel, int idx) {
        // [left, right] is the range of last-level indices below the current node
        int left = 0;
        int right = nodesLastLevel - 1;
        TreeNode t = root;
        // Walk down d levels; halving the range each step follows the binary digits of idx.
        // Levels above the last are full, so t can only become null on the final step.
        for (int i = 0; i < depth; i++) {
            int pivot = left + (right - left) / 2;
            if (idx <= pivot) {
                // idx lies in the left half [left, pivot]
                t = t.left;
                right = pivot;
            } else {
                // idx lies in the right half [pivot + 1, right]
                t = t.right;
                left = pivot + 1;
            }
        }
        // If 't' is not null, the leaf node at the given index is present
        return t != null;
    }

    // Calculate the depth of the leftmost path in the tree.
    // In a complete binary tree the leftmost path is the longest, so following left children is enough.
    public int calculateDepth(TreeNode root) {
        int depth = 0;
        TreeNode t = root;
        // Traverse down the leftmost path while counting the edges
        while (t.left != null) {
            depth++;
            t = t.left;
        }
        return depth;
    }
}
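The left/right decisions in isLeafPresent trace the binary digits of idx, so the walk can also read those bits directly instead of maintaining a [left, right] range. Below is a minimal sketch of that variant, assuming LeetCode's TreeNode shape. The names BitPathCount and exists, and the build test helper, are hypothetical. The binary search itself is unchanged.

```java
// Assumed node shape, mirroring LeetCode's TreeNode.
class TreeNode {
    int val;
    TreeNode left, right;
    TreeNode(int val) { this.val = val; }
}

class BitPathCount {
    // Same binary search as the solution above; only the root-to-leaf walk differs.
    static int countNodes(TreeNode root) {
        if (root == null) return 0;
        int depth = 0; // edges on the leftmost path
        for (TreeNode t = root.left; t != null; t = t.left) depth++;
        int left = 0, right = (1 << depth) - 1, last = -1;
        while (left <= right) {
            int pivot = left + (right - left) / 2;
            if (exists(root, depth, pivot)) {
                last = pivot;      // pivot is present: look further right
                left = pivot + 1;
            } else {
                right = pivot - 1; // pivot is missing: look further left
            }
        }
        // 2^d - 1 nodes in the full levels, plus last + 1 nodes on the bottom level
        return (1 << depth) - 1 + last + 1;
    }

    // Reads the d bits of idx from most to least significant: 0 = go left, 1 = go right.
    static boolean exists(TreeNode root, int depth, int idx) {
        TreeNode t = root;
        for (int bit = depth - 1; bit >= 0; bit--) {
            t = ((idx >> bit) & 1) == 0 ? t.left : t.right;
        }
        return t != null;
    }

    // Hypothetical test helper (not part of the original): complete tree with n nodes,
    // where level-order position i has children at 2i + 1 and 2i + 2.
    static TreeNode build(int n) {
        if (n <= 0) return null;
        TreeNode[] nodes = new TreeNode[n];
        for (int i = 0; i < n; i++) nodes[i] = new TreeNode(i);
        for (int i = 0; i < n; i++) {
            if (2 * i + 1 < n) nodes[i].left = nodes[2 * i + 1];
            if (2 * i + 2 < n) nodes[i].right = nodes[2 * i + 2];
        }
        return nodes[0];
    }
}
```

Bit d - 1 is read first, so the most significant bit chooses the root's child. Last-level position 0 is therefore the leftmost leaf and position 2^d - 1 the rightmost, which matches the indexing used by isLeafPresent.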