
Efficient algorithm to generate all solutions of a linear Diophantine equation with all a_i = 1


I am trying to generate all the solutions in positive integers (every x_i >= 1) of the following equations, for a given H.

With H = 4:

1) ALL solutions for x_1 + x_2 + x_3 + x_4 = 4
2) ALL solutions for x_1 + x_2 + x_3 = 4
3) ALL solutions for x_1 + x_2 = 4
4) ALL solutions for x_1 = 4

For my problem, there are always H equations to solve (four when H = 4), each independently of the others. There are 2^(H-1) solutions in total. For the example above, the solutions are:

1) 1 1 1 1
2) 1 1 2 and 1 2 1 and 2 1 1
3) 1 3 and 3 1 and 2 2
4) 4
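The 2^(H-1) total can be checked equation by equation. By the standard stars-and-bars argument, x_1 + ... + x_k = H has C(H-1, k-1) solutions in positive integers (choose k-1 of the H-1 gaps between H units), and summing over k = 1..H gives 2^(H-1) by the binomial theorem. A small Python check; the helper name `count_solutions` is only for illustration, and `math.comb` needs Python 3.8 or later:

```python
from math import comb


def count_solutions(H, k):
    """Number of positive-integer solutions of x_1 + ... + x_k = H (stars and bars)."""
    return comb(H - 1, k - 1)


H = 4
# Equations 1) to 4) above use k = 4, 3, 2, 1 variables.
per_equation = [count_solutions(H, k) for k in range(H, 0, -1)]
print(per_equation)               # [1, 3, 3, 1]
print(sum(per_equation))          # 8, i.e. 2 ** (H - 1)
```

The counts 1, 3, 3, 1 match the number of solutions listed for each equation.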

Here is an R implementation that solves the problem:

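library(gtools)
H <- 4

for (i in seq(H))
{
    # All i-tuples over 1..(H-i+1): with i positive parts summing to H,
    # no part can be larger than H-i+1.
    res <- permutations(H - i + 1, i, repeats.allowed = TRUE)
    # Keep only the tuples whose entries sum to H.
    id <- which(rowSums(res) == H)

    print(paste("solutions with ", i, " variables", sep = ""))
    print(res[id, , drop = FALSE])
}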
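To put a number on the extra work: assuming gtools' `permutations(n, r, repeats.allowed = TRUE)` returns all n^r ordered r-tuples over 1..n (which is what the filtering step relies on), the loop materializes (H-i+1)^i rows for i variables. A quick Python tally of that cost next to the 2^(H-1) actual solutions; the function name is hypothetical:

```python
def tuples_enumerated(H):
    """Rows produced by permutations(H - i + 1, i, repeats.allowed = TRUE), summed over i = 1..H."""
    return sum((H - i + 1) ** i for i in range(1, H + 1))


print(tuples_enumerated(4), 2 ** (4 - 1))   # 22 8
for H in (8, 12, 16):
    print(H, tuples_enumerated(H), 2 ** (H - 1))
```

For H = 4 that is 4 + 9 + 8 + 1 = 22 generated tuples to find 8 solutions, and the gap widens quickly as H grows.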

However, this algorithm does more work than needed: it builds every tuple first and only then filters by sum. I am sure it can be made faster, by which I mean not generating the permutations whose sum is > H in the first place.

Any ideas for a better algorithm for a given H?
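A minimal sketch of the pruning the question asks for, written in Python rather than R as a recursive generator (the name `solutions` is hypothetical). Because every x_j >= 1, the first part of a k-variable solution can be at most H - (k - 1); bounding it that way means no partial sum ever exceeds H, and each equation is solved directly instead of by filtering:

```python
def solutions(H, k):
    """Yield every positive-integer solution of x_1 + ... + x_k = H, in lexicographic order."""
    if k < 1 or H < k:
        return                                 # no positive solution exists
    if k == 1:
        yield (H,)
        return
    # Leave at least 1 for each of the remaining k - 1 variables.
    for first in range(1, H - k + 2):
        for rest in solutions(H - first, k - 1):
            yield (first,) + rest


H = 4
for k in range(H, 0, -1):
    print(k, list(solutions(H, k)))
```

Every tuple it yields is a solution, so the work is proportional to the output rather than to all (H-i+1)^i candidate tuples.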


Solution

  • Here's an implementation in C++:

    blah.cpp:

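    #include <cstdlib>
    #include <iostream>
    #include <vector>
    
    using namespace std;
    
    // Parts of the solution currently being built.
    vector<int> ilist;
    
    // Print every way to write n as an ordered sum of positive integers:
    // choose the first part i (from n down to 1), then recurse on n - i.
    // A part never exceeds what is left, so no partial sum overshoots.
    void diophantine(int n)
    {
        if (n == 0)
        {
            for (size_t j = 0; j < ilist.size(); j++) cout << " " << ilist[j];
            cout << endl;
        }
        else
        {
            for (int i = n; i > 0; i--)
            {
                ilist.push_back(i);
                diophantine(n - i);
                ilist.pop_back();
            }
        }
    }
    
    
    int main(int argc, char** argv)
    {
        int n = 0;
    
        // Accept exactly one strictly positive integer argument
        // (a negative value would otherwise never terminate).
        if (argc == 2 && (n = strtol(argv[1], NULL, 10)) > 0)
        {
            diophantine(n);
        }
        else cout << "usage: " << argv[0] << " <Z+>" << endl;
    
        return 0;
    }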
    


    Compiling and running it:

    $ g++ -oblah blah.cpp
    $ ./blah 4
     4
     3 1
     2 2
     2 1 1
     1 3
     1 2 1
     1 1 2
     1 1 1 1
    $
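The same depth-first search can also be written without recursion by keeping (prefix, remaining) pairs on an explicit stack. A Python sketch of that idea (the generator name `compositions` is illustrative, not from the answer); pushing the children smallest-first makes the largest first part come off the stack first, so it should reproduce the order of the command-line output above:

```python
def compositions(n):
    """Yield every ordered sum of positive integers equal to n, largest first part first."""
    stack = [((), n)]                      # (parts chosen so far, amount still to cover)
    while stack:
        prefix, remaining = stack.pop()
        if remaining == 0:
            yield prefix
            continue
        # Push 1 .. remaining so that `remaining` itself is popped next.
        for i in range(1, remaining + 1):
            stack.append((prefix + (i,), remaining - i))


for c in compositions(4):
    print(" ".join(map(str, c)))
```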
    


    Here's an implementation in bash:

    blah.sh:

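    #!/bin/bash
    
    # Print every composition of $1: choose the first part i (from n down
    # to 1), append it to ilist, recurse on n - i, then remove it again.
    diophantine()
    {
        local i
        local n=$1
        if (( n == 0 )); then
            echo "${ilist[@]}"
        else
            for ((i = n; i > 0; i--))
            do
                ilist+=("${i}")
                diophantine $((n - i))
                unset "ilist[${#ilist[@]}-1]"
            done
        fi
    }
    
    # A positive integer: a non-zero leading digit followed by any digits
    # (the old pattern ^[1-9]+$ rejected values such as 10 or 20).
    RE_POS_INTEGER='^[1-9][0-9]*$'
    if [[ $# -ne 1 || ! $1 =~ $RE_POS_INTEGER ]]; then
        echo "usage: $(basename "$0") <Z+>"
        exit 1
    fi
    
    declare -a ilist=()
    diophantine "$1"
    exit 0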
    


    Here's an implementation in Python:

    blah.py:

    #!/usr/bin/env python3
    
    import sys
    import time
    
    
    def output(l):
        # ilist is a nested pair (((), a), b)...; print its numbers left to right
        if isinstance(l, tuple):
            for item in l:
                output(item)
        else:
            print(l, end=" ")
    
    
    #more boring faster way -----------------------
    def diophantine_f(ilist, n):
        if n == 0:
            output(ilist)
            print()
        else:
            # first part i runs from n down to 1; recurse on the remainder
            for i in range(n, 0, -1):
                diophantine_f((ilist, i), n - i)
    
    
    #crazy fully recursive way --------------------
    # i is the next candidate for the current part; each call returns
    # (last part of ilist) - 1, which becomes the caller's next candidate.
    def diophantine(ilist, n, i):
        if n == 0:
            output(ilist)
            print()
        elif i > 0:
            diophantine(ilist, n, diophantine((ilist, i), n - i, n - i))
        return 0 if len(ilist) == 0 else ilist[-1] - 1
    
    
    ##########################
    #main
    ##########################
    try:
    
        if    len(sys.argv) == 1:  x = int(input())
        elif  len(sys.argv) == 2:  x = int(sys.argv[1])
        else: raise ValueError
    
        if x < 1: raise ValueError
    
        print("\n")
        #diophantine((), x, x)
        diophantine_f((), x)
        print("\nelapsed: ", time.process_time())
    
    except ValueError:
        print("usage: ", sys.argv[0], " <Z+>")
        sys.exit(1)
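A different, non-recursive technique (not part of the answer above): lay H units in a row; a solution corresponds to choosing which of the H - 1 gaps between them to cut, so the integers 0 .. 2^(H-1) - 1, read as bitmasks, enumerate every solution exactly once. That is also where the 2^(H-1) total in the question comes from. A Python sketch with a hypothetical function name; its output order differs from the recursive versions, and grouping by length recovers the separate equations:

```python
def compositions_by_cuts(n):
    """Yield every composition of n; bit j of mask set means 'cut after unit j + 1'."""
    for mask in range(2 ** (n - 1)):
        parts, run = [], 1
        for j in range(n - 1):
            if mask >> j & 1:              # cut here: close the current part
                parts.append(run)
                run = 1
            else:                          # no cut: the current part grows by one unit
                run += 1
        parts.append(run)
        yield tuple(parts)


# Group by number of variables, i.e. one list per equation.
H = 4
by_parts = {}
for c in compositions_by_cuts(H):
    by_parts.setdefault(len(c), []).append(c)
for k in sorted(by_parts, reverse=True):
    print(k, by_parts[k])
```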