The Blind 75 Leetcode Series: 3Sum

Jonathan Chao
6 min read · Apr 2, 2022

Today, we will be working on 3Sum.

Given an integer array nums, return all the triplets [nums[i], nums[j], nums[k]] such that i != j, i != k, and j != k, and nums[i] + nums[j] + nums[k] == 0.

Notice that the solution set must not contain duplicate triplets.

The question itself is clear enough. Some clarifying questions to ask would be:

  1. do we worry about empty list/array?
  2. do we worry about a list that has fewer than 3 integers? (so the conditions will always be invalid)

If the answer to both is yes, we want to handle these cases as well. A short-circuit return at the top of the function will do.

The initial approach is, once again, brute force.

Notice that the last line of the question asks us to dedupe the triplets. With the brute-force approach, we can use one trick: sorting. If the given list is [-1,0,1,2,-1,-4], then [-1, 0, 1] and [0, 1, -1] are duplicates of each other. If we sort the list first, it becomes [-4, -1, -1, 0, 1, 2], and every triplet we pick with i < j < k comes out in ascending order. Duplicates are then identical tuples, so we can dedupe them by collecting the results in a set.
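
As a small illustration of why ordering matters for the set-based dedupe, here is a sketch. The sample triplets are made up for the demonstration and are not produced by any code in this article:

```python
# Three orderings of the same values, as an unsorted scan might produce them.
found = [(-1, 0, 1), (0, 1, -1), (1, -1, 0)]

# Without a canonical order, a set sees three different tuples.
assert len(set(found)) == 3

# Sorting each triplet (or sorting nums up front) leaves one canonical form.
canonical = {tuple(sorted(t)) for t in found}
assert canonical == {(-1, 0, 1)}
```

Sorting nums once up front has the same effect as sorting every triplet, because any three values picked left to right from a sorted list are already in ascending order.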

Now, the brute-force approach would look like this:

def threeSum(nums):
    if len(nums) < 3:
        return []
    n = len(nums)
    results = set()
    nums.sort()
    for i in range(n-2):
        for j in range(i+1, n-1):
            for k in range(j+1, n):
                if nums[i] + nums[j] + nums[k] == 0:
                    results.add((nums[i], nums[j], nums[k]))
    return results

Be aware of where the indices begin and end in each for-loop. j always starts after i, and k after j. i stops two positions early to leave room for j and k, so no index goes out of range. This approach checks every combination of three elements of the sorted nums using three nested loops, which produces O(n³) time complexity. To put this in terms of the code:

def threeSum(nums):
    if len(nums) < 3:
        return []
    n = len(nums)
    results = set()
    nums.sort() # O(n logn)
    for i in range(n-2): # O(n)
        for j in range(i+1, n-1): # O(n)
            for k in range(j+1, n): # O(n)
                if nums[i] + nums[j] + nums[k] == 0: # O(1)
                    results.add((nums[i], nums[j], nums[k]))
    return results
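
To check the loop bounds described above, the following sketch lists the index triples that the three nested loops visit. The helper name index_triples is hypothetical and exists only for this illustration:

```python
from math import comb

def index_triples(n):
    """Yield the (i, j, k) positions visited by the three nested loops."""
    for i in range(n - 2):
        for j in range(i + 1, n - 1):
            for k in range(j + 1, n):
                yield (i, j, k)

# For a list of length 4, every triple with i < j < k appears exactly once.
assert list(index_triples(4)) == [(0, 1, 2), (0, 1, 3), (0, 2, 3), (1, 2, 3)]

# The number of visits is "n choose 3", which grows on the order of n^3.
assert all(len(list(index_triples(n))) == comb(n, 3) for n in range(3, 12))
```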

Definitely room for improvement.

If you have been following this series, you might remember the other similar question we worked on: The Two Sum.

We are simply adding one more number, so is it possible to dissect the problem and fit a sub-problem into the 2Sum scheme?

Of course! Since we have 3 for-loops, we can keep 1 for loop and use the 2Sum strategy to reduce the time complexity.

Here we use the two-pointer approach. Keep the first for-loop with i, and set j = i + 1 and k = len(nums) - 1 as the left and right pointers. Then we sum the three values. Because the list is sorted, if the result is too small, we move the left pointer one step to the right. If the result is too large, we move the right pointer one step to the left.

Now, in the 2Sum question, we return immediately when we find an answer. Not this time. We need to keep going, so we add the triplet to the result set and move both j and k toward the middle.

The code then looks like this:
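
The inner search described above can be looked at in isolation. The sketch below is a hypothetical helper (pairs_with_sum is not part of the original solution) that assumes its input is already sorted in ascending order and collects every pair to the right of a given position that adds up to a target:

```python
def pairs_with_sum(sorted_nums, start, target):
    """Return the value pairs in sorted_nums[start:] that add up to target.

    Illustrative helper; assumes sorted_nums is sorted in ascending order.
    """
    pairs = set()
    j, k = start, len(sorted_nums) - 1
    while j < k:
        total = sorted_nums[j] + sorted_nums[k]
        if total == target:
            pairs.add((sorted_nums[j], sorted_nums[k]))
            # Keep searching instead of returning: move both pointers inward.
            j += 1
            k -= 1
        elif total < target:
            j += 1  # too small: a larger left value is needed
        else:
            k -= 1  # too large: a smaller right value is needed
    return pairs

nums = sorted([-1, 0, 1, 2, -1, -4])  # [-4, -1, -1, 0, 1, 2]
# With nums[1] == -1 fixed, the other two values must add up to 1.
assert pairs_with_sum(nums, 2, 1) == {(-1, 2), (0, 1)}
```

For 3Sum, the target for the pair is simply the negation of nums[i], and the search starts at i + 1.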

def threeSum(nums):
    if len(nums) < 3:
        return []
    n = len(nums)
    results = set()
    nums.sort()
    for i in range(n-2):
        # here, we enter the 2Sum
        j, k = i + 1, len(nums) - 1
        while j < k:
            if nums[i] + nums[j] + nums[k] == 0:
                results.add((nums[i], nums[j], nums[k]))
                j += 1
                k -= 1
            elif nums[i] + nums[j] + nums[k] > 0:
                k -= 1
            else:
                j += 1
    return results

Now, besides the short-circuit return, we can optimize a bit more. Say we have the sorted list [-1, -1, -1, -1, -1, -1, 0, 0, 1]. The only answer is [-1, 0, 1], so we can skip the duplicated numbers. When we find an answer and add it to the result set, we can check whether the numbers that follow are the same as the current ones. If they are, there is no point checking the sum again. Simply skip them!

def threeSum(nums):
    if len(nums) < 3:
        return []
    n = len(nums)
    results = set()
    nums.sort()
    for i in range(n-2):
        j, k = i + 1, len(nums) - 1
        while j < k:
            if nums[i] + nums[j] + nums[k] == 0:
                results.add((nums[i], nums[j], nums[k]))
                # here's the fast forward part
                while j < k and nums[j] == nums[j+1]:
                    j += 1
                while j < k and nums[k] == nums[k-1]:
                    k -= 1
                j += 1
                k -= 1
            elif nums[i] + nums[j] + nums[k] > 0:
                k -= 1
            else:
                j += 1
    return results
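
The same skipping idea can be pushed one step further. The sketch below is a variation, not the code above: it also skips repeated values of nums[i] and stops once nums[i] is positive. Both of those checks are additions assumed here. With all three positions deduplicated, a plain list can replace the set:

```python
def three_sum_skip_duplicates(nums):
    nums = sorted(nums)  # sorted copy, so the caller's list is left untouched
    n = len(nums)
    results = []
    for i in range(n - 2):
        if nums[i] > 0:
            break  # the smallest remaining value is positive: no sum can be 0
        if i > 0 and nums[i] == nums[i - 1]:
            continue  # same first value as the previous iteration
        j, k = i + 1, n - 1
        while j < k:
            total = nums[i] + nums[j] + nums[k]
            if total == 0:
                results.append([nums[i], nums[j], nums[k]])
                # fast forward past repeated values on both sides
                while j < k and nums[j] == nums[j + 1]:
                    j += 1
                while j < k and nums[k] == nums[k - 1]:
                    k -= 1
                j += 1
                k -= 1
            elif total > 0:
                k -= 1
            else:
                j += 1
    return results

assert three_sum_skip_duplicates([-1, 0, 1, 2, -1, -4]) == [[-1, -1, 2], [-1, 0, 1]]
assert three_sum_skip_duplicates([0, 0, 0, 0]) == [[0, 0, 0]]
```

Returning a list of lists matches the shape the problem statement asks for, whereas the set-based versions return a set of tuples.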

Last but not least, let’s check our time complexity one more time.

def threeSum(nums):
    if len(nums) < 3:
        return []
    n = len(nums)
    results = set()
    nums.sort() # still O(n logn)
    for i in range(n-2): # O(n)
        j, k = i + 1, len(nums) - 1
        while j < k: # O(n) as we discussed in 2Sum
            if nums[i] + nums[j] + nums[k] == 0:
                results.add((nums[i], nums[j], nums[k]))
                # here's the fast forward part
                while j < k and nums[j] == nums[j+1]:
                    j += 1
                while j < k and nums[k] == nums[k-1]:
                    k -= 1
                j += 1
                k -= 1
            elif nums[i] + nums[j] + nums[k] > 0:
                k -= 1
            else:
                j += 1
    return results

The overall time complexity became O(n²). We essentially avoided one additional loop!
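
One way to gain confidence that the faster version still finds the same triplets is to compare it with a brute-force reference on random inputs. This is a sketch: both function names are hypothetical, and the reference uses itertools.combinations rather than the hand-written nested loops:

```python
import random
from itertools import combinations

def three_sum_reference(nums):
    """O(n^3) reference: try every combination of three positions."""
    return {tuple(sorted(c)) for c in combinations(nums, 3) if sum(c) == 0}

def three_sum_two_pointers(nums):
    """O(n^2) two-pointer version, working on a sorted copy."""
    nums = sorted(nums)
    results = set()
    for i in range(len(nums) - 2):
        j, k = i + 1, len(nums) - 1
        while j < k:
            total = nums[i] + nums[j] + nums[k]
            if total == 0:
                results.add((nums[i], nums[j], nums[k]))
                j += 1
                k -= 1
            elif total > 0:
                k -= 1
            else:
                j += 1
    return results

random.seed(0)  # fixed seed so the check is repeatable
for _ in range(200):
    sample = [random.randint(-5, 5) for _ in range(random.randint(0, 12))]
    assert three_sum_two_pointers(sample) == three_sum_reference(sample)
```

The small value range is deliberate: it produces many repeated numbers, which is where dedupe mistakes tend to show up.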

But wait! Is there another solution? I found one in the discussion channel that is quite interesting. It is tailored to this specific problem rather than being a generic approach like the 2Sum one. I cannot claim credit for it. Here’s the link to the solution. I’ll do my best to explain.

Because the problem specifically states that the sum of the triplet should be zero, this creates an interesting case where we can separate the list of integers into a set of positive numbers, a set of negatives, and the zeros.

We can then check several different scenarios: the case of (-x, 0, x); the case of (0, 0, 0); the case of (-x1, -x2, x3); and the case of (-x1, x2, x3), where a variable with a minus sign attached stands for a negative number.

Let’s start with the preparation:

results = set()
positives, negatives, zeros = [], [], []
for num in nums:
    if num < 0:
        negatives.append(num)
    elif num > 0:
        positives.append(num)
    else:
        zeros.append(num)
P = set(positives) # to get constant lookup time
N = set(negatives)

Now we have 3 sets of numbers, and we can check each scenario listed above:

# all zeros (0, 0, 0), check if there are 3 zeros in the list
if len(zeros) >= 3:
    results.add((0, 0, 0))
# (-x, 0, x)
if zeros:
    for n in positives:
        if -n in N:
            results.add((-n, 0, n))
# (-x1, -x2, x3)
for i, n in enumerate(negatives):
    for j in range(i+1, len(negatives)):
        target = -(n + negatives[j])
        if target in P:
            # sort this to avoid cases like (-2, -1, 3) and (-1, -2, 3),
            # since we did not sort the entire list in this approach
            results.add(tuple(sorted([n, negatives[j], target])))
# similarly, for (x1, x2, -x3)
for i, n in enumerate(positives):
    for j in range(i+1, len(positives)):
        target = -(n + positives[j])
        if target in N:
            results.add(tuple(sorted([n, positives[j], target])))
return results

This approach, although lengthier, actually runs faster on Leetcode. It uses more memory and the time complexity is still O(n²), but with repeated numbers and a balanced split between the positive list and the negative list, the amount of work done in practice is much lower than the O(n²) worst case.
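
For reference, the two fragments above can be assembled into one callable function. This is a sketch rather than the linked solution itself: the function name three_sum_by_sign and the local names a and b are assumptions made here, and the result is converted to a sorted list only so that the example output is stable:

```python
def three_sum_by_sign(nums):
    results = set()
    positives, negatives, zeros = [], [], []
    for num in nums:
        if num < 0:
            negatives.append(num)
        elif num > 0:
            positives.append(num)
        else:
            zeros.append(num)
    P, N = set(positives), set(negatives)  # constant-time membership tests

    # (0, 0, 0)
    if len(zeros) >= 3:
        results.add((0, 0, 0))
    # (-x, 0, x)
    if zeros:
        for p in P:
            if -p in N:
                results.add((-p, 0, p))
    # (-x1, -x2, x3): two negatives plus the positive that cancels them
    for i, a in enumerate(negatives):
        for b in negatives[i + 1:]:
            if -(a + b) in P:
                results.add(tuple(sorted((a, b, -(a + b)))))
    # (-x1, x2, x3): two positives plus the negative that cancels them
    for i, a in enumerate(positives):
        for b in positives[i + 1:]:
            if -(a + b) in N:
                results.add(tuple(sorted((a, b, -(a + b)))))
    return sorted(results)

assert three_sum_by_sign([-1, 0, 1, 2, -1, -4]) == [(-1, -1, 2), (-1, 0, 1)]
```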

That’s it! Another problem down!

Buy me a coffee: https://www.buymeacoffee.com/jonathanckz

Jonathan Chao

I am a software developer who has been in this industry for close to a decade. I share my experience with people who are in, or want to get into, the industry.