
Pythagorean Triplet with given sum


The following code prints a Pythagorean triplet whose sum equals the input n, but it takes a long time to answer for large inputs such as 90,000. How can I optimize it? The constraint is 1 ≤ n ≤ 90 000.

def pythagoreanTriplet(n):

    # Considering triplets in
    # sorted order. The value
    # of first element in sorted
    # triplet can be at-most n/3.
    for i in range(1, int(n / 3) + 1):

        # The value of second element
        # must be less than equal to n/2
        for j in range(i + 1,
                       int(n / 2) + 1):

            k = n - i - j
            if (i * i + j * j == k * k):
                print(i, ", ", j, ", ",
                      k, sep="")
                return

    print("Impossible")
# Driver Code
vorodi = int(input())
pythagoreanTriplet(vorodi)

Solution

  • Your code does a brute-force search over all (i, j) pairs, which takes O(n²) time, so it is slow for large n.

    Faster Code

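    def solve_pythagorean_triplets(n):
      """Return all triplets (a, b, c), a < b < c, with a^2 + b^2 = c^2 and a + b + c = n."""
      solutions = []
      for a in range(1, n):
        # With a fixed, c has the closed form c = (n^2 - 2an + 2a^2) / (2(n - a))
        denom = 2*(n-a)
        num = 2*a**2 + n**2 - 2*n*a
        # c must be an integer, so denom has to divide num exactly
        if denom > 0 and num % denom == 0:
          c = num // denom
          b = n - a - c
          # b > a keeps each triplet once, in sorted order
          if b > a:
            solutions.append((a, b, c))

      return solutions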
    

    OP code

    The OP's code, modified to return all solutions instead of printing the first one found, so that the performance of the two functions can be compared.

    def pythagoreanTriplet(n): 
      
        # Considering triplets in  
        # sorted order. The value  
        # of first element in sorted  
        # triplet can be at-most n/3. 
        results = []
        for i in range(1, int(n / 3) + 1):  
              
            # The value of second element  
            # must be less than equal to n/2 
            for j in range(i + 1,  
                           int(n / 2) + 1):  
      
                k = n - i - j 
                if (i * i + j * j == k * k):
                    results.append((i, j, k))
          
        return results
    

    Timing

      n      pythagoreanTriplet (OP code)       solve_pythagorean_triplets (new)
      900    0.084 seconds                      0.039 seconds
      5000   3.130 seconds                      0.012 seconds
      90000  Timed out after several minutes    0.430 seconds
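
A comparison like the one in the table can be reproduced with a small harness. The following is only a sketch: the names `brute_force`, `closed_form` and `time_call` are placeholders introduced here (compact restatements of the two functions above), and absolute timings will differ from machine to machine.

```python
import time


def brute_force(n):
    """Compact restatement of the OP's O(n^2) search, returning all triplets."""
    results = []
    for i in range(1, n // 3 + 1):
        for j in range(i + 1, n // 2 + 1):
            k = n - i - j
            if i * i + j * j == k * k:
                results.append((i, j, k))
    return results


def closed_form(n):
    """Compact restatement of the O(n) approach: solve for c with a fixed."""
    results = []
    for a in range(1, n):
        denom = 2 * (n - a)
        num = 2 * a * a + n * n - 2 * n * a
        if num % denom == 0:
            c = num // denom
            b = n - a - c
            if b > a:
                results.append((a, b, c))
    return results


def time_call(func, n):
    """Run func(n) once and return (result, elapsed seconds)."""
    start = time.perf_counter()
    result = func(n)
    return result, time.perf_counter() - start


if __name__ == "__main__":
    # Small inputs only, so the brute-force version finishes quickly.
    for n in (120, 900):
        slow, t_slow = time_call(brute_force, n)
        fast, t_fast = time_call(closed_form, n)
        assert slow == fast  # both approaches should agree
        print(f"n={n}: brute force {t_slow:.4f}s, closed form {t_fast:.4f}s")
```

A single run per input is noisy; repeating each call several times and taking the minimum gives steadier numbers.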
    

    Explanation

    Function solve_pythagorean_triplets is an O(n) algorithm that works as follows.

    1. Searching for:

      a^2 + b^2 = c^2 (triplet)
      a + b + c = n   (sum equals input)
      
    2. Solve by searching over a (i.e. a is fixed within each iteration). With a fixed, we have two equations in two unknowns (b, c):

      b + c = n - a
      c^2 - b^2 = a^2
      
    3. Solution is:

      denom = 2*(n-a)
      num = 2*a**2 + n**2 - 2*n*a
      if denom > 0 and num % denom == 0:
          c = num // denom
          b = n - a - c
          if b > a:
              (a, b, c) # is a solution
      
    4. Iterate a over range(1, n) to get the different solutions.
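
One possible refinement of step 4, not part of the original answer: because a < b < c and a + b + c = n, a is always below n/3, so the loop does not have to run all the way to n. A minimal sketch under that assumption follows; the function name `triplets_with_sum` is hypothetical.

```python
def triplets_with_sum(n):
    """Return all (a, b, c) with a < b < c, a^2 + b^2 = c^2 and a + b + c = n.

    Assumption: a is the smallest element, so a < n / 3 and the search
    can stop at n // 3 instead of n.
    """
    solutions = []
    for a in range(1, n // 3 + 1):
        denom = 2 * (n - a)                   # positive, since a <= n // 3 < n
        num = 2 * a * a + n * n - 2 * n * a
        if num % denom == 0:                  # c must be an integer
            c = num // denom
            b = n - a - c
            if b > a:                         # keep sorted triplets only
                solutions.append((a, b, c))
    return solutions


if __name__ == "__main__":
    print(triplets_with_sum(12))    # [(3, 4, 5)]
    print(triplets_with_sum(1000))  # [(200, 375, 425)]
```

This stays O(n) but does roughly a third of the iterations; the `b > a` test is kept as a safeguard.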

    Edit June 2022 by @AbhijitSarkar:

    For those who like to see the missing steps:

    c^2 - b^2 = a^2
    b + c = n - a
    => b = n - a - c
    
    c^2 - (n - a - c)^2 = a^2
    => c^2 - (n - a - c) * (n - a - c) = a^2
    => c^2 - n(n - a - c) + a(n - a - c) + c(n - a - c) = a^2
    => c^2 - n^2 + an + nc + an - a^2 - ac + cn - ac - c^2 = a^2
    => -n^2 + 2an + 2nc - a^2 - 2ac = a^2
    => -n^2 + 2an + 2nc - 2a^2 - 2ac = 0
    => 2c(n - a) = n^2 - 2an + 2a^2
    => c = (n^2 - 2an + 2a^2) / (2(n - a))
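
The final formula can be sanity-checked numerically. The sketch below uses exact rational arithmetic to confirm that, for c given by the closed form and b = n - a - c, the relation c^2 - b^2 = a^2 holds over a range of small n and a. The helper names `c_from_formula` and `check_identity` are made up for this illustration.

```python
from fractions import Fraction


def c_from_formula(n, a):
    """Closed form for c from the derivation above, as an exact fraction."""
    return Fraction(n * n - 2 * a * n + 2 * a * a, 2 * (n - a))


def check_identity(limit=50):
    """Check c^2 - b^2 == a^2 for all 1 <= a < n < limit, with b = n - a - c."""
    for n in range(2, limit):
        for a in range(1, n):
            c = c_from_formula(n, a)
            b = n - a - c
            if c * c - b * b != a * a:
                return False
    return True


if __name__ == "__main__":
    # n = 12, a = 3: c = (144 - 72 + 18) / 18 = 5, hence b = 4
    print(c_from_formula(12, 3))  # 5
    print(check_identity())       # True
```

The identity holds for every a < n, not only for integer solutions; the divisibility test in the code is what selects the cases where c is an integer.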