Search code examples

Feige-Fiat-Shamir identification protocol Java

I'm trying to implement Feige-Fiat-Shamir identification scheme as it is described in book "Handbook of Applied Cryptography"(page 410, section 10.4.2). I have a code, but the problem is than sometimes it succeeds but sometimes it fails. Can anybody help me to find a mistake in this code? Thank you.

    public static void main(String[] args) throws Exception {

    BigInteger p = BigInteger.probablePrime(16, new Random());
    BigInteger q = BigInteger.probablePrime(16, new Random());
    int k = 10;  // Receive k

    BigInteger trustedN = p.multiply(q);

    List<BigInteger> randomInts = new ArrayList<>();    //s1,
    BitSet randomBits = new BitSet(k);  // b1,b2..bk
    List<BigInteger> listV = new ArrayList<>();

    Random rand = new Random();

    Choose k positive numbers less than trustedN.
    Choose k bits 0 or 1

    for (int i = 0; i < k; i++) {
        // Generate random big ints less than trustedN
        randomInts.add(new BigInteger(trustedN.bitLength() + 1, rand).mod(trustedN));

        randomBits.set(i, rand.nextBoolean());
        // (-1)^bi
        BigInteger minus1pow = (((new BigInteger("-1")).pow(randomBits.get(i) ? 1 : 0))).mod(trustedN);

        // (s^2)^(-1)
        BigInteger randomIntPow = (randomInts.get(i).pow(2)).modInverse(trustedN);

        // vi = (-1)^bi * (s^2)^(-1)

    // Random r
    BigInteger randomR = new BigInteger(trustedN.bitLength() + 1, rand).mod(trustedN);
    // Random bit index
    int bitIndex = rand.nextInt(randomBits.length() + 1);

    // Calculate x
    BigInteger x = ((new BigInteger("-1")).pow(randomBits.get(bitIndex) ? 1 : 0).mod(trustedN)).multiply((randomR.pow(2)).mod(trustedN)).mod(trustedN);

    // Let pretend it was randomly selected vector (e1,e2,e3...)
    String eBits = "1100011010";

    BigInteger totalMultS = new BigInteger("1");
    for (int i = 0; i < k; i++) {
        totalMultS = totalMultS
                .multiply(randomInts.get(i).pow(eBits.charAt(i) == '1' ? 1 : 0));

    totalMultS = totalMultS.mod(trustedN).multiply(randomR.mod(trustedN)).mod(trustedN);
    BigInteger y = totalMultS;

    BigInteger totalMultV = new BigInteger("1");

    for (int i = 0; i < k; i++) {
        totalMultV = totalMultV
                .multiply(listV.get(i).pow(eBits.charAt(i) == '1' ? 1 : 0));

    totalMultV = totalMultV.mod(trustedN);
    BigInteger z = (y.pow(2).mod(trustedN)).multiply(totalMultV).mod(trustedN);

    if (z.toString().equals(x.toString())){
    else {

        System.out.println("x: " + x.toString());
        System.out.println("z: " + z.toString());




  • I've found the solution. The problem was in the last condition. See the code:

        public static void main(String[] args) throws Exception {
        BigInteger p = BigInteger.probablePrime(4, new Random());
        BigInteger q = BigInteger.probablePrime(4, new Random());
        System.out.println("p: " + p.toString());
        System.out.println("q: " + q.toString());
        int k = 3;  // Receive k
        BigInteger trustedN = p.multiply(q);
        System.out.println("n: " + trustedN.toString());
        List<BigInteger> randomInts = new ArrayList<>();    //s1,
        BitSet randomBits = new BitSet(k);  // b1,b2..bk
        List<BigInteger> listV = new ArrayList<>();
        Random rand = new Random();
        Choose k positive numbers less than trustedN.
        Choose k bits 0 or 1
        System.out.print("random s: ");
        for (int i = 0; i < k; i++) {
            BigInteger si = new BigInteger(trustedN.bitLength() + 1, rand).mod(trustedN);
            while (si.gcd(trustedN).intValue() != 1){
                si = new BigInteger(trustedN.bitLength() + 1, rand).mod(trustedN);
            // Generate random big ints less than trustedN
            randomBits.set(i, rand.nextBoolean());
            // (-1)^bi
            System.out.print(randomInts.get(i) + " " + randomBits.get(i) + " ");
            BigInteger minus1pow = (((new BigInteger("-1")).pow(randomBits.get(i) ? 1 : 0)));
            // (s^2)^(-1)
            BigInteger randomIntPow = minus1pow.multiply(randomInts.get(i).pow(2)).modInverse(trustedN);
            // vi = (-1)^bi * (s^2)^(-1)//            listV.add((minus1pow.multiply(randomIntPow)).mod(trustedN));
        System.out.print("\nlist v: ");
        for (BigInteger bi:
             listV) {
            System.out.print(bi.toString() + " ");
        // Random r
        BigInteger randomR = new BigInteger(trustedN.bitLength() + 1, rand).mod(trustedN);
        System.out.println("r: " + randomR.toString());
        // Random bit index
        int bitIndex = (int) (Math.random() * ( randomBits.length()  ));
        System.out.println("bitIndex: " + bitIndex + " bit value: " + randomBits.get(bitIndex));
        // Calculate x//        BigInteger x = ((new BigInteger("-1")).pow(randomBits.get(bitIndex) ? 1 : 0).mod(trustedN)).multiply((randomR.pow(2)).mod(trustedN)).mod(trustedN);
        BigInteger x = (((new BigInteger("-1")).pow(randomBits.get(bitIndex) ? 1 : 0)).multiply((randomR.pow(2)))).mod(trustedN);
        // Let pretend it was randomly selected vector (e1,e2,e3)
        String eBits = "100";
        BigInteger totalMultS = new BigInteger("1");
        for (int i = 0; i < k; i++) {
            totalMultS = totalMultS
                    .multiply(randomInts.get(i).pow(eBits.charAt(i) == '1' ? 1 : 0));
        BigInteger y = totalMultS.multiply(randomR.mod(trustedN)).mod(trustedN);
        System.out.println("y: " + y.toString());
        BigInteger totalMultV = new BigInteger("1");
        for (int i = 0; i < k; i++) {
            totalMultV = totalMultV
                    .multiply(listV.get(i).pow(eBits.charAt(i) == '1' ? 1 : 0));
        System.out.println("total mult v: " + totalMultV);
        if ((z.toString().equals(x.toString()) || z.toString().equals(x.negate().mod(trustedN).toString()))
                && !z.toString().equals("0")){
        else {
            System.out.println("x: " + x.toString());
            System.out.println("z: " + z.toString());