Find Kth Smallest Element in an Array

Difficulty: Medium, Asked-in: Google, Microsoft, Amazon, Cisco, SAP Labs, VMWare

Key takeaways

  • This is an excellent problem to learn problem-solving using the heap data structure.
  • The quick-select algorithm is intuitive and worth exploring. It is based on the divide and conquer approach, similar to quicksort, and works in O(n) time on average.

Let's understand the problem

Given an array X[] and a positive integer k, write a program to find the kth smallest element in the array.

  • It is given that all array elements are distinct.
  • We can assume that k is always valid, 1 ≤ k ≤ n.

Important note: Before moving on to the solutions, we recommend trying this problem on paper for 15 to 30 minutes. Enjoy problem-solving!

Examples:

Input: X[] = [4, 3, 13, 2, 12, 7, 23], k = 4

Output: 7, i.e., 7 is the 4th smallest element in the array.

Input: X[] = [-12, -8, 16, 23], k = 2

Output: -8, i.e., -8 is the 2nd smallest element in the array.

Discussed solution approaches

  • Brute force approach using sorting
  • Using the min-heap data structure
  • Using the max-heap data structure
  • Quick-select: Using divide and conquer idea similar to quick-sort

Brute force approach using sorting

Solution idea

As we know, all elements in the array are distinct. So one basic idea would be to sort the array in increasing order and directly return the kth number from the start, i.e., return X[k - 1].

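Solution code C++

#include <algorithm>
using namespace std;

int KthSmallestArray(int X[], int n, int k)
{
    // Heap sort in place: make_heap builds a max-heap of X[0 ... n-1],
    // then sort_heap rearranges it into increasing order in O(nlogn)
    make_heap(X, X + n);
    sort_heap(X, X + n);
    // After sorting, the kth smallest element sits at index k - 1
    return X[k - 1];
}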
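For readers following along in Python, here is a minimal sketch of the same sorting idea (the function name kth_smallest_by_sorting is our own). One assumption changes the analysis: Python's built-in sorted() uses Timsort rather than heap sort, so it still runs in O(nlogn) time, but it returns a new list and therefore uses O(n) extra memory instead of O(1).

```python
def kth_smallest_by_sorting(X, k):
    # sorted() returns a new list in increasing order; X itself is left unchanged
    ordered = sorted(X)
    # The kth smallest element (k is 1-based) sits at index k - 1
    return ordered[k - 1]


print(kth_smallest_by_sorting([4, 3, 13, 2, 12, 7, 23], 4))  # 7
print(kth_smallest_by_sorting([-12, -8, 16, 23], 2))  # -8
```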

Solution analysis

Suppose we use heap sort, an efficient O(nlogn) sorting algorithm. Then the time complexity equals the time complexity of heap sort plus the time complexity of accessing the kth smallest element, which is O(nlogn) + O(1) = O(nlogn).

The space complexity is O(1) because heap sort is an in-place sorting algorithm.

Using the min-heap data structure

The time complexity of the above solution is dominated by the sorting algorithm. Now the critical question is: Can we improve the time complexity further? Can we solve this problem without using sorting? Can we think of using some efficient mechanism to find min or max elements like the min-priority queue or heap data structure? Think!

Solution idea and steps

A min-heap is an array-based complete binary tree structure where the value in each node is smaller than or equal to the values in the children of that node. So the minimum element is always present at the root, i.e., X[0].

We also use the min-heap for the efficient implementation of a min-priority queue. Here are some critical min-heap operations:

  • getMin(): provides fast access to the minimum element in O(1) time.
  • deleteMin(): deletes the minimum element in O(logn) time.
  • Insert(): adds an element in O(logn) time.
  • We can build a min-heap in O(n) time using the bottom-up approach.
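As an illustration of these operations, Python's standard heapq module keeps a binary min-heap inside a plain list: heapq.heapify builds it bottom-up in linear time, heap[0] plays the role of getMin(), heapq.heappop the role of deleteMin(), and heapq.heappush the role of Insert(). This is a small demonstration under that mapping, not the article's MinHeap class.

```python
import heapq

values = [4, 3, 13, 2, 12, 7, 23]
heap = list(values)         # copy, so the original list stays untouched
heapq.heapify(heap)         # build a min-heap bottom-up in O(n)

print(heap[0])              # getMin(): 2, the minimum is at the root
print(heapq.heappop(heap))  # deleteMin(): removes and returns 2 in O(log n)
print(heap[0])              # the new root is the next smallest value, 3
heapq.heappush(heap, 1)     # Insert(): adds 1 in O(log n)
print(heap[0])              # 1
```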

So how do we optimize the time complexity and find the kth smallest element using the above min-heap operations? Here is an idea: we first build a min-heap of all n array elements and remove k - 1 elements by repeatedly performing the deleteMin() operation. After this, the kth smallest element will be present at the root of the min-heap, so we can get it in O(1) time by calling the getMin() operation.

Solution pseudocode

Finding kth smallest element using min-heap

int KthSmallestArray(int X[], int n, int k)
{
    MinHeap heap(X, n) 
    for (int i = 0; i < k - 1; i = i + 1)
        heap.deleteMin()    
    
    return heap.getMin()
}

Solution code C++ with min heap implementation

#include <climits>
#include <utility>
using namespace std;

class MinHeap
{
private:

    int *heapArray;
    int heapCapacity;
    int heapSize;
    
    void minHeapify(int i)
    {
        int l = leftChild(i);
        int r = rightChild(i);
        int smallest = i;

        if (l < heapSize && heapArray[l] < heapArray[i])
            smallest = l;
        if (r < heapSize && heapArray[r] < heapArray[smallest])
            smallest = r;

        if(smallest != i)
        {
            swap(heapArray[i], heapArray[smallest]);
            minHeapify(smallest);
        }
    }

public:

    MinHeap(int X[], int size)
    {
        heapSize = size;
        heapArray = X;
        int i = (heapSize - 1)/2;

        while (i >= 0)
        {
            minHeapify(i);
            i = i - 1;
        }
    }

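    int deleteMin()
    { 
        if (heapSize == 0)
            return INT_MAX;   // sentinel: the heap is empty
        int min = heapArray[0];

        // Move the last element to the root and shrink the heap first,
        // so minHeapify(0) only sees the remaining heapSize elements
        heapArray[0] = heapArray[heapSize - 1];
        heapSize = heapSize - 1;
        minHeapify(0);

        return min;
    }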

    int parent(int i)
    {
        return (i - 1)/2;
    }

    int leftChild(int i)
    {
        return (2 * i + 1);
    }

    int rightChild(int i)
    {
        return (2 * i + 2);
    }

    int getMin()
    {
        return heapArray[0];
    }
};

int KthSmallestArray(int X[], int n, int k)
{
    MinHeap heap(X, n); 
    for (int i = 0; i < k - 1; i = i + 1)
        heap.deleteMin();    

    return heap.getMin();
}
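A Python sketch of the same min-heap approach, assuming the standard heapq module in place of the hand-written MinHeap class (the function name kth_smallest_min_heap is our own). One difference from the C++ version: it heapifies a copy of X, so it uses O(n) extra space instead of building the heap in place.

```python
import heapq


def kth_smallest_min_heap(X, k):
    # Build a min-heap of all n elements in O(n), on a copy of X
    heap = list(X)
    heapq.heapify(heap)
    # Remove the k - 1 smallest elements, O(log n) each
    for _ in range(k - 1):
        heapq.heappop(heap)
    # The root now holds the kth smallest element
    return heap[0]


print(kth_smallest_min_heap([4, 3, 13, 2, 12, 7, 23], 4))  # 7
```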

Solution analysis

Time complexity = Time complexity of building the min-heap of size n + Time complexity of deleting k - 1 elements from the min-heap + Time complexity of accessing the kth element from the min-heap.

  • The time complexity of building the min-heap of size n = O(n).
  • After each deletion, the min-heap size reduces by 1, so the k - 1 deletions run on heaps of size n, n - 1, ..., n - k + 2. Their total time is log n + log (n - 1) + ... + log (n - k + 2) = log [n!/(n - k + 1)!] ≤ log(n^(k-1)) = (k - 1) log n = O(k log n). Here the time complexity depends on the value of k (1 ≤ k ≤ n).
  • The time complexity of accessing the kth minimum element = O(1).

So overall time complexity = O(n) + O(k log n) + O(1) = O(n + k log n).

  • The best-case scenario occurs when k = 1, and time complexity is O(n + log n) = O(n).
  • The worst-case scenario occurs when k = n, and time complexity is O(n + n log n) = O(n log n).

The space complexity is O(1) because we can build the min-heap in place using the same array. Therefore, we are using constant extra memory. Now a critical question would be: can we optimize the solution further?

Using max-heap data structure

Solution idea and steps

Similar to a min-heap, the max-heap is an array-based complete binary tree structure where the value in each node is larger than or equal to the values in the children of that node. So the maximum element is always present at the root, i.e., X[0].

We also use a max-heap for the efficient implementation of a max-priority queue. Here are some critical max-heap operations:

  • getMax(): provides fast access to the maximum element in O(1) time.
  • deleteMax(): removes the maximum element in O(logn) time.
  • Insert(): adds an element in O(logn) time.
  • We can build a max-heap in O(n) time using the bottom-up approach.

How do we solve this problem using a max-heap? A solution insight would be: If we have a max-heap of the k smallest elements of the array, then the kth smallest element will be present at the root of the max-heap, and we can get the root value in O(1) time. But the critical question is: how do we generate the max-heap of k smallest elements in the array? Let's think!

  • We start by building a max-heap of the first k elements of the array. There is no need to allocate extra space, and we can use the same array from index i = 0 to k - 1.
  • After this process, the maximum element of the first k elements will be present at the root of the heap, i.e., X[0].
  • Now we need to track the k smallest elements using the k size max-heap. For this, we scan the array from index i = k to n - 1 and update the max-heap with all the values smaller than the root of the max-heap.
  • In other words, we run a loop from i = k to n - 1 and check if (X[i] < root). If yes, we replace the max-heap root with X[i] and run maxHeapify() on the root to maintain the max-heap property. Otherwise, we ignore the current element and move to the next element in the array. Here we access the root value using the getMax() function and replace the heap root using the replaceMax(X[i]) function.
  • By the end of the above loop, the max-heap contains the k smallest elements of the array, and the kth smallest element (the largest of them) is present at the root. We return this value.
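To see these steps on the first example (X[] = [4, 3, 13, 2, 12, 7, 23], k = 4), here is a hedged Python sketch that records the max-heap root after each scanned element. Python's heapq module only provides a min-heap, so the sketch stores negated values: -heap[0] plays the role of getMax() and heapq.heapreplace plays the role of replaceMax(). The function name trace_max_heap_approach is our own.

```python
import heapq


def trace_max_heap_approach(X, k):
    # Max-heap of the first k elements, simulated by negating values
    heap = [-x for x in X[:k]]
    heapq.heapify(heap)
    steps = []
    for i in range(k, len(X)):
        root = -heap[0]                     # getMax()
        if X[i] < root:
            heapq.heapreplace(heap, -X[i])  # replaceMax(X[i]): pop root, push X[i]
            action = "replace"
        else:
            action = "ignore"
        # Record (index, scanned value, action, max-heap root afterwards)
        steps.append((i, X[i], action, -heap[0]))
    return steps


for step in trace_max_heap_approach([4, 3, 13, 2, 12, 7, 23], 4):
    print(step)
# (4, 12, 'replace', 12)
# (5, 7, 'replace', 7)
# (6, 23, 'ignore', 7)
```

The heap starts as {4, 3, 13, 2} with root 13; 12 and 7 each replace the root, 23 is ignored, and the final root 7 is the 4th smallest element.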

Solution pseudocode

Finding the kth smallest element

int KthSmallestArray(int X[], int n, int k)
{
    MaxHeap heap(X, k)
    for (int i = k; i < n; i = i + 1)
    {
        if(X[i] < heap.getMax())
            heap.replaceMax(X[i])
    }
    return heap.getMax()
}

Solution code C++ with max heap implementation

#include <climits>
#include <utility>
using namespace std;

class MaxHeap
{
    private:
        int* heapArray;
        int heapCapacity;
        int heapSize;
        
        void maxHeapify(int i)
        {
            int l = leftChild(i);
            int r = rightChild(i);
            int largest = i;

            if (l < heapSize && heapArray[l] > heapArray[i])
                largest = l;
            if (r < heapSize && heapArray[r] > heapArray[largest])
                largest = r;

            if (largest != i)
            {
                swap(heapArray[i], heapArray[largest]);
                maxHeapify(largest);
            }
        }

    public:
        MaxHeap(int X[], int size)
        {
            heapSize = size;
            heapArray = X;
            int i = (heapSize - 1) / 2;

            while (i >= 0)
            {
                maxHeapify(i);
                i = i - 1;
            }
        }

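        int deleteMax()
        {
            if (heapSize == 0)
                return INT_MAX;   // sentinel: the heap is empty
            int max = heapArray[0];

            // Move the last element to the root and shrink the heap first,
            // so maxHeapify(0) only sees the remaining heapSize elements
            heapArray[0] = heapArray[heapSize - 1];
            heapSize = heapSize - 1;
            maxHeapify(0);

            return max;
        }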

        int parent(int i)
        {
            return (i - 1)/2;
        }

        int leftChild(int i)
        {
            return (2 * i + 1);
        } 

        int rightChild(int i)
        {
            return (2 * i + 2);
        }

        int getMax()
        {
            return heapArray[0];
        }

        void replaceMax(int value)
        {
            heapArray[0] = value;
            maxHeapify(0);
        }
};

int KthSmallestArray(int X[], int n, int k)
{
    MaxHeap heap(X, k);
    for (int i = k; i < n; i = i + 1)
    {
        if(X[i] < heap.getMax())
            heap.replaceMax(X[i]);
    }       
    return heap.getMax();
}

Solution analysis

  • The time complexity of building a k-size max-heap is O(k).
  • In the worst case, the condition (X[i] < heap.getMax()) is true at every iteration of the loop, so we run the replaceMax(X[i]) operation every time. Each replaceMax() operation on a k-size heap takes O(log k) time.
  • So the overall time complexity of the (n - k) replaceMax() operations is (n - k) * O(log k) = O((n - k) log k).
  • The overall worst-case time complexity is O(k) + O((n - k) log k) + O(1) = O(k + (n - k) log k). Since k ≤ n, this is bounded by O(n log k) for k ≥ 2.
  • The space complexity is O(1) because the heap is built in place; we use constant extra space.

Quick-select approach (Divide and conquer idea similar to quick-sort)

Solution idea

Now, we will discuss an interesting quick-select approach that solves the problem efficiently by using a divide-and-conquer idea similar to the quick-sort algorithm.

The solution intuition comes from the quick-sort partition process: it divides the array around a pivot and returns the index at which the pivot would appear in the sorted array. After partitioning X[l ... r], the elements look like this: X[l ... pos-1] < pivot < X[pos+1 ... r]. Here the pivot element is present at index pos, and pos - l elements of X[l ... r] are smaller than the pivot. So the pivot is the (pos - l + 1)th smallest element of X[l ... r].

  • if ((pos - l + 1) == k): the pivot is the kth smallest, and we return X[pos].
  • if ((pos - l + 1) > k): the kth smallest must be present in the left subarray.
  • if ((pos - l + 1) < k): the kth smallest must be present in the right subarray.
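To make these cases concrete, the sketch below runs a Lomuto partition (the same scheme as the article's partition code, with pivot = X[r]) on a hypothetical reordering of the first example, [4, 3, 13, 2, 12, 23, 7], chosen so that 7 becomes the pivot.

```python
def partition(X, l, r):
    # Lomuto partition around the last element X[r]
    pivot = X[r]
    i = l - 1
    for j in range(l, r):
        if X[j] <= pivot:
            i += 1
            X[i], X[j] = X[j], X[i]
    X[i + 1], X[r] = X[r], X[i + 1]
    return i + 1


X = [4, 3, 13, 2, 12, 23, 7]  # hypothetical reordering: 7 is the last element
l, r = 0, len(X) - 1
pos = partition(X, l, r)
print(X, pos)        # [4, 3, 2, 7, 12, 23, 13] 3
print(pos - l + 1)   # 4: the pivot 7 is the 4th smallest element of X[l ... r]
```

With k = 4, (pos - l + 1) == k, so the first case applies and X[pos] = 7 is returned immediately; with k = 2 the search would continue in [4, 3, 2], and with k = 6 in [12, 23, 13] as the (6 - 4)th smallest.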

So from the above insight, we can use the partition process to find the kth smallest element recursively. But unlike quick-sort, which processes both subarrays recursively, we process only one subarray: we recurse into either the left or the right side based on comparing k with the pivot's position.

Solution steps

  1. We define a function KthSmallestArray(int X[], int l, int r, int k), where the left and right ends of the array are provided in the input. Initially, l = 0 and r = n - 1.
  2. The base case is the scenario of a single-element array, i.e., if l == r, return X[l].
  3. Now we partition the array X[l ... r] into two subarrays (X[l ... pos-1] and X[pos + 1 ... r]) using the partition process. Here pos is the index of the pivot returned by the partition.
  4. After the partition, we calculate the pivot order in the array, i.e., i = pos - l + 1. In other words, the pivot is the (pos - l + 1)th smallest element.
  5. If i == k, we return X[pos]. Otherwise, we determine in which of the two subarrays (X[l ... pos-1] and X[pos + 1 ... r]) the kth smallest is present.
  6. If i > k, the desired element must lie in the left subarray. We call the same function recursively for the left part, i.e., KthSmallestArray(X, l, pos-1, k).
  7. If i < k, the desired element must lie in the right subarray. We call the same function recursively for the right part, i.e., KthSmallestArray(X, pos + 1, r, k - i). Note: the desired element is the (k - i)th smallest element of the right side because we already know that i values (the whole left subarray plus the pivot) are smaller than the kth smallest element.

Solution code C++

#include <utility>
using namespace std;

int partition (int X[], int l, int r)
{
    int pivot = X[r];
    int i = l - 1;
    for (int j = l; j < r; j = j + 1)
    {
        if (X[j] <= pivot)
        {
            i = i + 1;
            swap (X[i], X[j]);
        }
    }
    swap (X[i + 1], X[r]);
    return i + 1;
}

int KthSmallestArray(int X[], int l, int r, int k)
{
    if(l == r)
        return X[l];
    int pos = partition(X, l, r);
    int i = pos - l + 1;
    
    if(i == k)
        return X[pos];
    else if(i > k)
        return KthSmallestArray(X, l, pos - 1, k);
    else 
        return KthSmallestArray(X, pos + 1, r, k - i);
}

int getKthSmallest(int X[], int n, int k)
{
    return KthSmallestArray(X, 0, n - 1, k);
}

Solution code Python

def partition(X, l, r):
    pivot = X[r]
    i = l - 1
    for j in range(l, r):
        if X[j] <= pivot:
            i = i + 1
            X[i], X[j] = X[j], X[i]
    X[i + 1], X[r] = X[r], X[i + 1]
    return i + 1

def KthSmallestArray(X, l, r, k):
    if l == r:
        return X[l]
    pos = partition(X, l, r)
    i = pos - l + 1
    
    if i == k:
        return X[pos]
    elif i > k:
        return KthSmallestArray(X, l, pos - 1, k)
    else: 
        return KthSmallestArray(X, pos + 1, r, k - i)

def getKthSmallest(X, n, k):
    return KthSmallestArray(X, 0, n - 1, k)

Time and space complexity analysis of the quick-select

This is a divide and conquer algorithm, where we are solving only one subproblem.

Worst-case analysis: The worst-case situation occurs when the partition is bad and highly unbalanced. There would be two scenarios: either 0 elements in the left subarray and (n - 1) elements in the right subarray or (n - 1) elements in the left subarray and 0 elements in the right subarray. So in the worst case, the algorithm always explores the larger sub-problem of size n - 1.

  • For the partition algorithm, we run a single loop with a constant operation at each iteration. So the time complexity of the partition algorithm = O(n).
  • So, the worst-case time complexity T(n) = Time complexity of the partition algorithm + Worst-case time complexity of one subproblem + Time complexity of extra operations = O(n) + T(n-1) + O(1) = T(n - 1) + O(n) = T(n - 1) + cn.
  • The solution to the above recurrence is O(n^2). We can easily solve the above recurrence using the recursion tree method. Refer to the worst-case analysis of quicksort.
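Unrolling the worst-case recurrence makes the quadratic bound explicit. This worked derivation assumes the base case T(1) = c, the same constant as in the recurrence:

$$
\begin{aligned}
T(n) &= T(n-1) + cn \\
&= T(n-2) + c(n-1) + cn \\
&\;\;\vdots \\
&= T(1) + c\,(2 + 3 + \dots + n) \\
&= c\,(1 + 2 + \dots + n) = \frac{c\,n(n+1)}{2} = O(n^2)
\end{aligned}
$$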

The quick-select algorithm looks highly inefficient in the worst case. But its real strength shows up in the average-case analysis: the algorithm runs in O(n) time on average. How? Let's think.

Average case analysis: We assume that the partition process chooses the pivot uniformly at random. In other words, all n possible pivot positions (ranks 1 to n) are equally likely, each occurring with probability 1/n.
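The quick-select code earlier always uses X[r] as the pivot, so it does not satisfy this assumption as written: on already sorted input with a small k (for example k = 1), it hits the worst case. One common way to make the assumption hold is to swap a uniformly random element into position r before partitioning. This is a hedged Python sketch of that change; the names randomized_partition and kth_smallest_randomized are our own.

```python
import random


def partition(X, l, r):
    # Lomuto partition around the last element, as in the article's code
    pivot = X[r]
    i = l - 1
    for j in range(l, r):
        if X[j] <= pivot:
            i += 1
            X[i], X[j] = X[j], X[i]
    X[i + 1], X[r] = X[r], X[i + 1]
    return i + 1


def randomized_partition(X, l, r):
    # Swap a uniformly random element of X[l..r] into position r first,
    # so every element of X[l..r] is equally likely to become the pivot
    p = random.randint(l, r)  # inclusive on both ends
    X[p], X[r] = X[r], X[p]
    return partition(X, l, r)


def kth_smallest_randomized(X, l, r, k):
    # Same recursion as KthSmallestArray; only the pivot choice differs
    if l == r:
        return X[l]
    pos = randomized_partition(X, l, r)
    i = pos - l + 1
    if i == k:
        return X[pos]
    elif i > k:
        return kth_smallest_randomized(X, l, pos - 1, k)
    else:
        return kth_smallest_randomized(X, pos + 1, r, k - i)


print(kth_smallest_randomized([4, 3, 13, 2, 12, 7, 23], 0, 6, 4))  # 7
```

The answer never depends on the random choices; only the running time does.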

  • If the pivot lands at rank i, the left subproblem has i - 1 elements and the right subproblem has n - i elements, where i ranges from 1 to n. The time complexity of the left subproblem is T(i - 1), and that of the right subproblem is T(n - i).
  • We solve only one subproblem recursively, so let's take an upper bound: the time needed for the recursive call on the larger input. In other words, to obtain an upper bound, we assume that the kth smallest element is always on the side of the partition with more elements. So the time complexity of the recursive call is at most T(max(i - 1, n - i)), where 1 ≤ i ≤ n.
  • To get the average case, we average T(max(i - 1, n - i)) over all n equally likely values of i and add the O(n) cost of partitioning. The recurrence relation for the average case is:

$$T(n) \le \frac{1}{n} \sum_{i=1}^{n} T\big(\max(i - 1,\; n - i)\big) + O(n)$$

  • If i > n/2, max(i - 1, n - i) = i - 1.
  • If i <= n/2, max(i - 1, n - i) = n - i.

So, as i runs from 1 to n, each term T(j) with ⌊n/2⌋ ≤ j ≤ n - 1 appears at most twice in the sum above (think!). Replacing O(n) with cn for some constant c, we can simplify the formula:

$$T(n) \le \frac{2}{n} \sum_{i=\lfloor n/2 \rfloor}^{n-1} T(i) + cn$$

We can solve this using the “guess and check” (substitution) method. Let's guess that T(n) is linear, i.e., T(n) ≤ pn for some constant p, and assume as the inductive hypothesis that T(i) ≤ pi for all i < n. Substituting this into the recurrence gives:

$$T(n) \le \frac{2}{n} \sum_{i=\lfloor n/2 \rfloor}^{n-1} p\,i + cn$$

Let's further simplify the expression on the right-hand side, using ⌊n/2⌋ ≥ n/2 - 1 in the inequality step.

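$$
\begin{aligned}
\frac{2}{n} \sum_{i=\lfloor n/2 \rfloor}^{n-1} p\,i + cn
&= \frac{2p}{n} \left( \sum_{i=0}^{n-1} i - \sum_{i=0}^{\lfloor n/2 \rfloor - 1} i \right) + cn \\
&= \frac{2p}{n} \left( \frac{n(n-1)}{2} - \frac{\lfloor n/2 \rfloor \left(\lfloor n/2 \rfloor - 1\right)}{2} \right) + cn \\
&\le \frac{2p}{n} \left( \frac{n(n-1)}{2} - \frac{(n/2 - 1)(n/2 - 2)}{2} \right) + cn \\
&= \frac{p}{n} \left( n^2 - n - \frac{n^2}{4} + \frac{3n}{2} - 2 \right) + cn \\
&= \frac{p}{n} \left( \frac{3n^2}{4} + \frac{n}{2} - 2 \right) + cn \\
&= \frac{3pn}{4} + \frac{p}{2} - \frac{2p}{n} + cn \\
&\le pn - \left( \frac{pn}{4} - cn - \frac{p}{2} \right)
\end{aligned}
$$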

To complete the proof, we need to show that for sufficiently large n, this expression is at most pn, or equivalently (pn/4 - cn - p/2) ≥ 0, i.e., n(p/4 - c) ≥ p/2. If we choose the constant p > 4c, then p/4 - c > 0, and the condition becomes n ≥ (p/2)/(p/4 - c) = 2p/(p - 4c), which is a constant. So the guess holds for all n ≥ 2p/(p - 4c), and the finitely many smaller values of n cost only O(1), which a large enough p also covers. Therefore, the average-case time complexity of the quick-select algorithm is T(n) ≤ pn = O(n).
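As a numeric sanity check (not a replacement for the proof), this sketch evaluates the average-case recurrence with equality, the base case T(0) = 0, and a hypothetical constant c = 1; all three are our assumptions. In this setup, T(n)/n stays below 4c for every n up to 1000, which is consistent with the proof's requirement p > 4c.

```python
def average_case_recurrence(n_max, c=1.0):
    # T[n] = (1/n) * sum_{i=1..n} T[max(i - 1, n - i)] + c*n, with T[0] = 0
    T = [0.0] * (n_max + 1)
    for n in range(1, n_max + 1):
        total = sum(T[max(i - 1, n - i)] for i in range(1, n + 1))
        T[n] = total / n + c * n
    return T


T = average_case_recurrence(1000)
for n in (10, 100, 1000):
    print(n, round(T[n] / n, 3))  # the ratio T(n)/n as n grows
```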

Regarding space complexity, the worst case occurs when the partition is highly unbalanced: each recursive call reduces the input size by only 1, so the recursion depth is O(n), and the call stack uses O(n) space. In the average case, the recursion depth, and hence the space complexity, is O(logn). (Think!)
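Because quick-select makes only one recursive call, and that call is the last thing the function does, the recursion can be replaced by a loop that just moves the bounds l and r. CPython does not eliminate tail calls, so writing the loop explicitly matters there. This hedged sketch (the name kth_smallest_iterative is our own) keeps the article's last-element pivot, modifies X in place like the recursive version, and needs only O(1) extra space.

```python
def kth_smallest_iterative(X, k):
    # Only one side is ever explored, so the bounds l..r replace the call stack
    l, r = 0, len(X) - 1
    while l < r:
        # Lomuto partition of X[l..r] around pivot X[r], as in the article's code
        pivot = X[r]
        i = l - 1
        for j in range(l, r):
            if X[j] <= pivot:
                i += 1
                X[i], X[j] = X[j], X[i]
        X[i + 1], X[r] = X[r], X[i + 1]
        pos = i + 1
        rank = pos - l + 1   # the pivot is the rank-th smallest of X[l..r]
        if rank == k:
            return X[pos]
        elif rank > k:
            r = pos - 1      # continue in the left part with the same k
        else:
            l = pos + 1      # continue in the right part
            k -= rank        # skip the left part and the pivot
    return X[l]


print(kth_smallest_iterative([4, 3, 13, 2, 12, 7, 23], 4))  # 7
```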

Critical ideas to think!

  • Can we improve the time complexity of the fourth approach to O(n) in the worst-case scenario? Explore the median-of-medians algorithm for finding the kth smallest element.
  • How do we modify the above algorithms to handle duplicates?
  • In the first approach, do we need to sort the entire array?
  • What would be other ways to perform the partition of an array?
  • Why is the average case space complexity O(logn) in the fourth approach?
  • Can this problem be solved using a Binary Search Tree? What would be the complexity in that case?
  • Why is Heap preferred over BST for the Priority Queue implementation?
  • In the third approach, why do we ignore elements larger than the root in the second part of the algorithm? What would be the worst and best-case input?

Time and space complexities comparison

  • Using sorting (heap sort): Time = O(nlogn), Space = O(1)
  • Using min-heap: Time = O(n + klogn), Space = O(1)
  • Using max-heap: Time = O(nlogk), Space = O(1)
  • Using Quick-Select: Time = O(n) average, O(n^2) worst; Space = O(logn) average, O(n) worst

Similar coding questions to practice

  • Find the Kth largest element.
  • Find the k smallest elements in an array.
  • Median of two sorted arrays of the same size.
  • The kth missing element in an unsorted array.
  • Find the kth largest element in a stream.
  • Kth Smallest Element in a Sorted Matrix.
  • Kth element of two sorted arrays.
  • Find the top k frequent elements.

If you have any queries/doubts/feedback, please write to us at contact@enjoyalgorithms.com. Enjoy learning, Enjoy algorithms!
