[Medium] 18. 4Sum

Given an array nums of n integers, return an array of all the unique quadruplets [nums[a], nums[b], nums[c], nums[d]] such that:

  • 0 <= a, b, c, d < n
  • a, b, c, and d are distinct.
  • nums[a] + nums[b] + nums[c] + nums[d] == target

You may return the answer in any order.

Examples

Example 1:

Input: nums = [1,0,-1,0,-2,2], target = 0
Output: [[-2,-1,1,2],[-2,0,0,2],[-1,0,0,1]]

Example 2:

Input: nums = [2,2,2,2,2], target = 8
Output: [[2,2,2,2]]

Constraints

  • 1 <= nums.length <= 200
  • -10^9 <= nums[i] <= 10^9
  • -10^9 <= target <= 10^9

Clarification Questions

Before diving into the solution, here are 5 important clarifications and assumptions to discuss during an interview:

  1. Quadruplet definition: What is a valid quadruplet? (Assumption: Four distinct indices (i, j, k, l) where nums[i] + nums[j] + nums[k] + nums[l] == target)

  2. Uniqueness: Should quadruplets be unique? (Assumption: Yes - no duplicate quadruplets in result, but can use same value at different indices)

  3. Return format: What should we return? (Assumption: List of all unique quadruplets - list of lists)

  4. Order requirement: Does order of quadruplets matter? (Assumption: No - any order is accepted, although a solution that sorts the input naturally emits each quadruplet in ascending order)

  5. Duplicate values: How should we handle duplicate values? (Assumption: Can use same value multiple times if at different indices, but avoid duplicate quadruplets)

Interview Deduction Process (20 minutes)

Step 1: Brute-Force Approach (5 minutes)

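Use four nested loops to try all combinations of four indices (i, j, k, l). Check if nums[i] + nums[j] + nums[k] + nums[l] == target. Collect all valid quadruplets, then remove duplicates. This approach has O(n⁴) time complexity: for n = 200 that is about 6.5 × 10⁷ combinations when the loops enforce i < j < k < l, and up to 200⁴ = 1.6 × 10⁹ iterations if every loop scans the whole array. Duplicate removal adds further overhead, since each match has to be normalized (for example, sorted) and checked against the results found so far.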
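
As a reference point, the brute-force idea above can be sketched as follows. This is a minimal illustration rather than a recommended solution; the function name four_sum_brute_force is an assumption, and itertools.combinations stands in for the four hand-written loops.

```python
from itertools import combinations


def four_sum_brute_force(nums, target):
    """Check every choice of four positions; O(n^4) time."""
    found = set()
    # combinations() picks four distinct positions a < b < c < d.
    for quad in combinations(nums, 4):
        if sum(quad) == target:
            # Sort the values so reorderings of one quadruplet share a key.
            found.add(tuple(sorted(quad)))
    return [list(q) for q in sorted(found)]


print(four_sum_brute_force([1, 0, -1, 0, -2, 2], 0))
# [[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]]
```

A version like this is mainly useful as an oracle for testing faster implementations on small inputs.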

Step 2: Semi-Optimized Approach (7 minutes)

Use three nested loops and a hash set for the fourth value. For each combination of three indices (i, j, k), check if target - (nums[i] + nums[j] + nums[k]) exists in a hash set of remaining values. This reduces to O(n³) time complexity. However, handling duplicates and ensuring we use distinct indices requires careful bookkeeping. Sorting the array first helps with duplicate handling.
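
One possible way to do that bookkeeping is sketched below. It is an assumption about how the hash-set idea might be coded, not the only option: the set named seen is rebuilt for each (i, j) pair and only holds values at indices strictly between j and k, which is what keeps the four indices distinct. The function name four_sum_hash_set is hypothetical.

```python
def four_sum_hash_set(nums, target):
    """Three nested loops plus a hash-set lookup; O(n^3) time."""
    nums = sorted(nums)  # sorted copy, so each quadruplet comes out in order
    n = len(nums)
    found = set()        # set of tuples removes duplicate quadruplets
    for i in range(n - 3):
        for j in range(i + 1, n - 2):
            seen = set()  # values at indices strictly between j and k
            for k in range(j + 1, n):
                need = target - nums[i] - nums[j] - nums[k]
                if need in seen:
                    # need lives at an index between j and k, so all
                    # four indices are different.
                    found.add((nums[i], nums[j], need, nums[k]))
                seen.add(nums[k])
    return [list(q) for q in sorted(found)]


print(four_sum_hash_set([2, 2, 2, 2, 2], 8))
# [[2, 2, 2, 2]]
```

Note the extra O(n) set per pair and the result set of tuples; the two-pointer version avoids both.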

Step 3: Optimized Solution (8 minutes)

Sort the array first, then use two nested loops for the first two numbers, and two pointers for the remaining two numbers. For each pair (i, j), use left and right pointers to find pairs (k, l) such that nums[i] + nums[j] + nums[k] + nums[l] == target. Skip duplicates by advancing pointers when values repeat. This achieves O(n³) time complexity with O(1) extra space (excluding the output and whatever workspace the sort itself needs). The key insight is that sorting enables the two-pointer technique and makes duplicate skipping straightforward: equal values end up adjacent, so a repeat is detected by comparing an element with its neighbor.

Solution: Two Pointers with Nested Loops

Time Complexity: O(n³)
Space Complexity: O(1) excluding the output array (Python's list.sort may use up to O(n) temporary space internally)

This solution extends the 3Sum approach with an additional nested loop. We use two nested loops to fix the first two numbers, then use two pointers to find the remaining two numbers that sum to the target.

class Solution:
    def fourSum(self, nums, target):
        rtn = []
        n = len(nums)

        if n < 4:
            return rtn

        nums.sort()

        for i in range(n - 3):
            if i > 0 and nums[i] == nums[i - 1]:
                continue

            for j in range(i + 1, n - 2):
                if j > i + 1 and nums[j] == nums[j - 1]:
                    continue

                left = j + 1
                right = n - 1

                while left < right:
                    total = nums[i] + nums[j] + nums[left] + nums[right]

                    if total == target:
                        rtn.append([nums[i], nums[j], nums[left], nums[right]])

                        while left < right and nums[left] == nums[left + 1]:
                            left += 1
                        while left < right and nums[right] == nums[right - 1]:
                            right -= 1

                        left += 1
                        right -= 1

                    elif total < target:
                        left += 1
                    else:
                        right -= 1

        return rtn

How the Algorithm Works

Step-by-Step Example: nums = [1,0,-1,0,-2,2], target = 0

After sorting: [-2, -1, 0, 0, 1, 2]

  1. i = 0, nums[i] = -2: Fix first number
    • j = 1, nums[j] = -1: Fix second number
      • left = 2, right = 5: -2 + (-1) + 0 + 2 = -1 < 0 → left++
      • left = 3, right = 5: -2 + (-1) + 0 + 2 = -1 < 0 → left++
      • left = 4, right = 5: -2 + (-1) + 1 + 2 = 0 → Found! [-2, -1, 1, 2]
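    • j = 2, nums[j] = 0: Fix second number
      • left = 3, right = 5: -2 + 0 + 0 + 2 = 0 → Found! [-2, 0, 0, 2]
    • j = 3, nums[j] = 0: Same value as nums[2] → skipped
  2. i = 1, nums[i] = -1: Fix first number
    • j = 2, nums[j] = 0: Fix second number
      • left = 3, right = 5: -1 + 0 + 0 + 2 = 1 > 0 → right--
      • left = 3, right = 4: -1 + 0 + 0 + 1 = 0 → Found! [-1, 0, 0, 1]
    • j = 3, nums[j] = 0: Same value as nums[2] → skipped
  3. i = 2, nums[i] = 0: Fix first number
    • j = 3, nums[j] = 0: Fix second number
      • left = 4, right = 5: 0 + 0 + 1 + 2 = 3 > 0 → right--, pointers meet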

Final Answer: [[-2,-1,1,2],[-2,0,0,2],[-1,0,0,1]]
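
To check a hand trace like this one, the pointer positions can be printed directly. The generator below is a sketch written for that purpose; trace_four_sum is a hypothetical helper, not part of the solution, and it yields one tuple (i, j, left, right, total) per two-pointer comparison.

```python
def trace_four_sum(nums, target):
    """Yield (i, j, left, right, total) for every two-pointer comparison."""
    nums = sorted(nums)
    n = len(nums)
    for i in range(n - 3):
        if i > 0 and nums[i] == nums[i - 1]:
            continue  # same first value as the previous iteration
        for j in range(i + 1, n - 2):
            if j > i + 1 and nums[j] == nums[j - 1]:
                continue  # same second value for this i
            left, right = j + 1, n - 1
            while left < right:
                total = nums[i] + nums[j] + nums[left] + nums[right]
                yield i, j, left, right, total
                if total == target:
                    # Step past repeated values on both sides.
                    while left < right and nums[left] == nums[left + 1]:
                        left += 1
                    while left < right and nums[right] == nums[right - 1]:
                        right -= 1
                    left += 1
                    right -= 1
                elif total < target:
                    left += 1
                else:
                    right -= 1


for step in trace_four_sum([1, 0, -1, 0, -2, 2], 0):
    print(step)
```

Each printed tuple uses indices into the sorted array, so the first line is (0, 1, 2, 5, -1).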

Key Insights

  1. Sorting First: Sorting allows us to use two pointers and skip duplicates efficiently
  2. Nested Loops: Two outer loops fix the first two numbers (i, j)
  3. Two Pointers: Inner loop uses two pointers to find the remaining two numbers
  4. Skip Duplicates: After finding a valid quadruplet, skip all duplicate values
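  5. 64-bit Sum: Python integers have arbitrary precision, so the sum cannot overflow here; in C++ or Java, hold the sum in a 64-bit type (long long / long), because four values near 10^9 exceed the 32-bit range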

Algorithm Breakdown

1. Base Case Check

def four_sum_after_setup(nums, target):
    rtn = []
    n = len(nums)
    if n < 4:
        return rtn
    # ... sort and nested loops follow

2. Sort the Array

nums.sort()

3. Outer Loop (First Number)

for i in range(n - 3):
    if i > 0 and nums[i - 1] == nums[i]:
        continue  # Skip duplicates

4. Inner Loop (Second Number)

for j in range(i + 1, n - 2):
    if j > i + 1 and nums[j - 1] == nums[j]:
        continue  # Skip duplicates

5. Two Pointers (Third and Fourth Numbers)

left, right = j + 1, n - 1
while left < right:
    total = nums[i] + nums[j] + nums[left] + nums[right]
    if total == target:
        rtn.append([nums[i], nums[j], nums[left], nums[right]])
        while left < right and nums[left] == nums[left + 1]:
            left += 1
        while left < right and nums[right] == nums[right - 1]:
            right -= 1
        left += 1
        right -= 1
    elif total < target:
        left += 1
    else:
        right -= 1

Complexity Analysis

Aspect    Complexity
Time      O(n³): two nested loops, O(n²), times a two-pointer scan, O(n)
Space     O(1): only a few variables, excluding the output array

Edge Cases

  1. Array length < 4: Return empty array
  2. All same numbers: [2,2,2,2,2], target = 8 → [[2,2,2,2]]
  3. No solution: Return empty array
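  4. Large numbers: Four values near ±10^9 sum to roughly ±4 × 10^9, which does not fit in 32 bits. Python handles this automatically; in C++ or Java, use long long / long for the sum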

Why This Solution Works Well

  1. Efficient: O(n³) is the standard bound for the sort-and-two-pointer approach and is easily fast enough for n ≤ 200 (at most about 8 × 10⁶ pointer steps)
  2. No Duplicates: Properly skips duplicates using sorted array property
  3. Early Termination: Could add early termination if nums[i] + nums[i+1] + nums[i+2] + nums[i+3] > target
  4. Space Efficient: Only uses O(1) extra space

Common Mistakes

  1. Integer Overflow: In C++ or Java, storing the sum in a 32-bit int instead of long long / long (not an issue in Python, whose integers do not overflow)
  2. Duplicate Handling: Not properly skipping duplicates after finding a match
  3. Index Bounds: Not checking n < 4 before processing
  4. Wrong Skip Condition: Using j > 0 instead of j > i + 1 for the second number
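
Two of these mistakes can be made concrete with a small experiment. The sketch below adds a hypothetical buggy_skip flag, introduced here purely for illustration, that switches the second-number check from j > i + 1 to j > 0. It also runs an input whose true sum, 4 × 10^9, would wrap around to -294967296 in 32-bit arithmetic.

```python
def four_sum(nums, target, buggy_skip=False):
    """Two-pointer 4Sum; buggy_skip=True reproduces the wrong skip condition."""
    nums = sorted(nums)
    n = len(nums)
    result = []
    for i in range(n - 3):
        if i > 0 and nums[i] == nums[i - 1]:
            continue
        for j in range(i + 1, n - 2):
            # Correct: only skip when j is past the first candidate, i + 1.
            # Buggy: j > 0 also compares nums[i + 1] against nums[i].
            first_allowed = 0 if buggy_skip else i + 1
            if j > first_allowed and nums[j] == nums[j - 1]:
                continue
            left, right = j + 1, n - 1
            while left < right:
                total = nums[i] + nums[j] + nums[left] + nums[right]
                if total == target:
                    result.append([nums[i], nums[j], nums[left], nums[right]])
                    while left < right and nums[left] == nums[left + 1]:
                        left += 1
                    while left < right and nums[right] == nums[right - 1]:
                        right -= 1
                    left += 1
                    right -= 1
                elif total < target:
                    left += 1
                else:
                    right -= 1
    return result


print(four_sum([2, 2, 2, 2, 2], 8))                   # [[2, 2, 2, 2]]
print(four_sum([2, 2, 2, 2, 2], 8, buggy_skip=True))  # []

# 4 * 10**9 wraps to -294967296 in 32-bit arithmetic; Python's exact
# integers correctly report no match.
print(four_sum([10**9] * 4, -294967296))              # []
```

With the buggy condition, every j whose value equals nums[i] is skipped, so inputs with repeated values silently lose valid quadruplets.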

Optimization Tips

Early Termination

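# Outer loop on the sorted array: prune i values that cannot reach target
for i in range(n - 3):
    # Smallest sum that uses nums[i] is already too big; later i only grow it
    if nums[i] + nums[i + 1] + nums[i + 2] + nums[i + 3] > target:
        break
    # Largest sum that uses nums[i] is still too small; try the next i
    if nums[i] + nums[n - 3] + nums[n - 2] + nums[n - 1] < target:
        continue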

Similar Optimization for Inner Loop

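# Inner loop (i is fixed): the same two bounds, applied to j
for j in range(i + 1, n - 2):
    # Smallest sum that uses nums[i] and nums[j] is already too big
    if nums[i] + nums[j] + nums[j + 1] + nums[j + 2] > target:
        break
    # Largest sum that uses nums[i] and nums[j] is still too small
    if nums[i] + nums[j] + nums[n - 2] + nums[n - 1] < target:
        continue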
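
Putting the pruning checks together with the duplicate skips gives something like the sketch below. The function name four_sum_pruned is an assumption, and placing the duplicate check before the bound checks is one reasonable ordering rather than a requirement.

```python
def four_sum_pruned(nums, target):
    """Two-pointer 4Sum with min/max pruning in both outer loops."""
    nums = sorted(nums)
    n = len(nums)
    result = []
    for i in range(n - 3):
        if i > 0 and nums[i] == nums[i - 1]:
            continue
        # Smallest sum using nums[i] exceeds target: no later i can work.
        if nums[i] + nums[i + 1] + nums[i + 2] + nums[i + 3] > target:
            break
        # Largest sum using nums[i] falls short: move to a bigger nums[i].
        if nums[i] + nums[n - 3] + nums[n - 2] + nums[n - 1] < target:
            continue
        for j in range(i + 1, n - 2):
            if j > i + 1 and nums[j] == nums[j - 1]:
                continue
            if nums[i] + nums[j] + nums[j + 1] + nums[j + 2] > target:
                break
            if nums[i] + nums[j] + nums[n - 2] + nums[n - 1] < target:
                continue
            left, right = j + 1, n - 1
            while left < right:
                total = nums[i] + nums[j] + nums[left] + nums[right]
                if total == target:
                    result.append([nums[i], nums[j], nums[left], nums[right]])
                    while left < right and nums[left] == nums[left + 1]:
                        left += 1
                    while left < right and nums[right] == nums[right - 1]:
                        right -= 1
                    left += 1
                    right -= 1
                elif total < target:
                    left += 1
                else:
                    right -= 1
    return result


print(four_sum_pruned([1, 0, -1, 0, -2, 2], 0))
# [[-2, -1, 1, 2], [-2, 0, 0, 2], [-1, 0, 0, 1]]
```

The pruning does not change the O(n³) worst case, but it can cut the work substantially when most prefixes are clearly too large or too small.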