
Java RSA/ECB/OAEPWITHSHA-256ANDMGF1PADDING Migrate To Go


I want to migrate code from Java to Go. The two implementations should be able to decrypt each other's output, but they produce different results and neither can decrypt the other's ciphertext:

Java code

public static byte[] encrypt(byte[] data, PublicKey publicKeyObject)
        throws BadPaddingException, IllegalBlockSizeException,
        InvalidKeyException, NoSuchPaddingException,
        NoSuchAlgorithmException {
    Cipher cipher = Cipher
            .getInstance("RSA/ECB/OAEPWITHSHA-256ANDMGF1PADDING");

    OAEPParameterSpec oaepParameterSpec = new OAEPParameterSpec("SHA-256",
            "MGF1", MGF1ParameterSpec.SHA1, PSource.PSpecified.DEFAULT);
    try {
        cipher.init(Cipher.ENCRYPT_MODE, publicKeyObject,
                oaepParameterSpec);
    } catch (InvalidAlgorithmParameterException e) {
        e.printStackTrace();
        return null;
    }
    return cipher.doFinal(data);
}

private static byte[] decrypt(byte[] data, PrivateKey privateKeyObj)
        throws NoSuchPaddingException, NoSuchAlgorithmException,
        InvalidKeyException, BadPaddingException,
        IllegalBlockSizeException {
    Cipher cipher = Cipher
            .getInstance("RSA/ECB/OAEPWITHSHA-256ANDMGF1PADDING");

    OAEPParameterSpec oaepParameterSpec = new OAEPParameterSpec("SHA-256",
            "MGF1", MGF1ParameterSpec.SHA1, PSource.PSpecified.DEFAULT);
    try {
        cipher.init(Cipher.DECRYPT_MODE, privateKeyObj, oaepParameterSpec);

    } catch (InvalidAlgorithmParameterException e) {
        e.printStackTrace();
        return null;
    }

    return cipher.doFinal(data);
}

Go code

rng := rand.Reader

ciphertext, err := rsa.EncryptOAEP(sha256.New(), rng, rsaPublicKey, secretMessage, label)
if err != nil {
    fmt.Printf("Error from encryption: %s\n", err)
    return
}
clearText, err := rsa.DecryptOAEP(sha256.New(), rng, rsaPrivateKey, ciphertext, label)
if err != nil {
    fmt.Printf("Error from decryption: %s\n", err)
    return
}

I even tried sha1 as the first parameter in Go, but the results are still different.

Can anybody help me with this?


Solution

  • The reason why the two code snippets are not compatible has already been explained in the comments and the other answer: unlike the Java code, Go's crypto/rsa package does not allow the OAEP digest and the MGF1 digest to be specified separately. EncryptOAEP() and DecryptOAEP() use the passed hash for both, so with sha256.New() Go uses SHA-256 for MGF1 while the Java code uses SHA-1, and with sha1.New() the MGF1 digests match but the OAEP digests no longer do.
    This answer instead focuses on adapting the crypto/rsa package to fix the issue.

    The meaning of both digests is described in RFC 8017, more precisely in section 7.1 RSAES-OAEP, where OAEP is defined.
    The encryption operation takes as options and inputs a digest (the OAEP digest), a mask generation function, a label, the message and the public key, see section 7.1.1 (Encryption Operation). The OAEP digest is used to hash the label, see 7.1.1, step 2a. The only mask generation function RFC 8017 defines is MGF1 (see B.2.1 MGF1), which is therefore what OAEP generally uses. MGF1 itself is based on a digest (the MGF1 digest).
    RFC 8017 specifies the following defaults, see A.2.1 RSAES-OAEP: MGF1 as the mask generation function, SHA-1 for both the OAEP and the MGF1 digest, and an empty label.

    Although SHA-1 is now considered insecure, there are no known weaknesses in the context of OAEP, see here. Nevertheless, SHA-256 is now often used, either as a precaution or as part of phasing SHA-1 out of the ecosystem.
    Furthermore, RFC 8017 does not preclude using different digests for OAEP and MGF1, as in your example.
    An implementation should therefore allow both digests to be specified independently, which the crypto/rsa package fails to do.
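    Newer Go releases partly close this gap: since Go 1.20, rsa.OAEPOptions has an MGFHash field that (*rsa.PrivateKey).Decrypt uses for MGF1 (a zero value means the OAEP hash is used for MGF1 as well). rsa.EncryptOAEP still takes a single hash, so this only helps on the decryption side. A minimal sketch, assuming Go 1.20 or later; decryptJavaOAEP and checkMGFHashMatters are hypothetical helper names:

```go
package main

import (
	"bytes"
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	_ "crypto/sha1" // links SHA-1 so that crypto.SHA1.New() is available
	"crypto/sha256"
	"errors"
)

// decryptJavaOAEP decrypts a ciphertext produced with the Java parameters from
// the question: OAEP digest SHA-256, MGF1 digest SHA-1, empty label.
// Assumes Go 1.20+ (rsa.OAEPOptions.MGFHash).
func decryptJavaOAEP(priv *rsa.PrivateKey, ciphertext []byte) ([]byte, error) {
	opts := &rsa.OAEPOptions{
		Hash:    crypto.SHA256, // OAEP digest (hashes the label)
		MGFHash: crypto.SHA1,   // MGF1 digest; zero would mean "same as Hash"
	}
	return priv.Decrypt(rand.Reader, ciphertext, opts)
}

// checkMGFHashMatters encrypts with rsa.EncryptOAEP(sha256.New(), ...), which
// also uses SHA-256 for MGF1, and checks that only a matching MGFHash decrypts it.
func checkMGFHashMatters() error {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return err
	}
	msg := []byte("The quick brown fox jumps over the lazy dog")
	ct, err := rsa.EncryptOAEP(sha256.New(), rand.Reader, &priv.PublicKey, msg, nil)
	if err != nil {
		return err
	}
	if _, err := decryptJavaOAEP(priv, ct); err == nil {
		return errors.New("MGF1-SHA1 should not decrypt an MGF1-SHA256 ciphertext")
	}
	opts := &rsa.OAEPOptions{Hash: crypto.SHA256, MGFHash: crypto.SHA256}
	pt, err := priv.Decrypt(rand.Reader, ct, opts)
	if err != nil || !bytes.Equal(pt, msg) {
		return errors.New("a matching MGF1 digest should decrypt the ciphertext")
	}
	return nil
}
```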

    To allow both digests to be specified separately, EncryptOAEP() and DecryptOAEP() need an additional parameter that passes the MGF1 digest, which is then used for MGF1:

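    func EncryptOAEP(hash hash.Hash, hashMGF1 hash.Hash, random io.Reader, pub *rsa.PublicKey, msg []byte, label []byte) ([]byte, error) {
        // ... label hashing (with hash), DB and seed setup unchanged
        hashMGF1.Reset()
        mgf1XOR(db, hashMGF1, seed) // maskedDB = DB XOR MGF1(seed), using the MGF1 digest
        mgf1XOR(seed, hashMGF1, db) // maskedSeed = seed XOR MGF1(maskedDB)
        // ... RSA encryption unchanged
    }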
    

    and analogously for DecryptOAEP(), where the two mgf1XOR() calls are applied in reverse order (the seed is unmasked first).

    This would best be implemented by adjusting the crypto/rsa package itself. Alternatively, as a workaround, the required functions can be copied from the crypto/rsa package and adapted, as in the following code, which includes a test:

    package main
    
    import (
        "crypto/rand"
        "crypto/rsa"
        "crypto/sha1"
        "crypto/sha256"
        "crypto/subtle"
        "crypto/x509"
        "encoding/base64"
        "encoding/pem"
        "errors"
        "fmt"
        "hash"
        "io"
        "math/big"
        "sync"
    )
    
    func main() {
    
        var publicKeyData = `-----BEGIN PUBLIC KEY-----
    MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAoZ67dtUTLxoXnNEzRBFB
    mwukEJGC+y69cGgpNbtElQj3m4Aft/7cu9qYbTNguTSnCDt7uovZNb21u1vpZwKH
    yVgFEGO4SA8RNnjhJt2D7z8RDMWX3saody7jo9TKlrPABLZGo2o8vadW8Dly/v+I
    d0YDheCkVCoCEeUjQ8koXZhTwhYkGPu+vkdiqX5cUaiVTu1uzt591aO5Vw/hV4DI
    hFKnOTnYXnpXiwRwtPyYoGTa64yWfi2t0bv99qz0BgDjQjD0civCe8LRXGGhyB1U
    1aHjDDGEnulTYJyEqCzNGwBpzEHUjqIOXElFjt55AFGpCHAuyuoXoP3gQvoSj6RC
    sQIDAQAB
    -----END PUBLIC KEY-----`
    
        var privateKeyData = `-----BEGIN RSA PRIVATE KEY-----
    MIIEowIBAAKCAQEAoZ67dtUTLxoXnNEzRBFBmwukEJGC+y69cGgpNbtElQj3m4Af
    t/7cu9qYbTNguTSnCDt7uovZNb21u1vpZwKHyVgFEGO4SA8RNnjhJt2D7z8RDMWX
    3saody7jo9TKlrPABLZGo2o8vadW8Dly/v+Id0YDheCkVCoCEeUjQ8koXZhTwhYk
    GPu+vkdiqX5cUaiVTu1uzt591aO5Vw/hV4DIhFKnOTnYXnpXiwRwtPyYoGTa64yW
    fi2t0bv99qz0BgDjQjD0civCe8LRXGGhyB1U1aHjDDGEnulTYJyEqCzNGwBpzEHU
    jqIOXElFjt55AFGpCHAuyuoXoP3gQvoSj6RCsQIDAQABAoIBAGoYl5ukuJk9Ga8a
    LftLELRFaghuXXui7T0zQ4pASv9DCbiM3UWeCy1OjK1zAtXR2Kywz8JgN9DtnrVF
    2uyCXr0wCPL/Y2P6cCRAKh2nYQrXbcvikpXt9311zH4qHGvdx/nP5oM0JHejuJCu
    Re1btiwGTB3AoF+XzBAPSZ0gGl2FqDQ7qLqqwG9Xr+78STLdN8UOUCsKV3qdTM6N
    XLeXliI0XIFQgT6XMiRGEvhJVaUTJ/3q23xza87k8jpqGsh5ArtnG6LUON26rEed
    BL2ome7HNV+IOR143PXVrBMyn6qnwAas+Zt+WfCbBCP0k68oL7mzLmP6IzY4KBE9
    BFEo04ECgYEA9GMgi2Xm9OqjUmihMt0oPnPcMx0DR+4mZezPVED2f3garOKcWvOV
    y1N/Mn5A9L785jPjWE+ui7i5DT6AMJiWxkeEdYjXmZhpG9I3pha1yaLzBXjl+Dri
    /dCXZxQq+Z7axnBxwIhDNHAeeCAau6hLfzsGgv5YAvSeg6KU7Af16dkCgYEAqUzG
    jvZxfV/2qPMdNh9oUcvVbIcnIphnTP1Ma7BAD6anTnSru2EDLR66yiRtdrC9E54d
    4xWeTNHsSUcaQBkAsyp7Cpewgy4vmo8GE3qUu91Jk3/1ZN6jxLyMoakyzhYTmq4s
    QsTPC1daUXqpRjGYzP/8dMMzlKQ2Vncp+2BXgJkCgYEAinzJ6nSahluYpZBpGLu+
    nHVnaQed3lsUI1oouyP9C4ryAtp/pAK49fmg8OoewRKhmYn54Qd2b/MD2n96gQ9X
    EZFhfIFJO97kYUGlC1d/OH5AnO8/0oT8MLzNrzn8iGv+qcj6jRIqk0Kd4ZC/1Wuv
    LLA0JnMfSL16PjoZjg+MyTECgYBRq47RooMnBycXY4hA9q+9XcZMP3qajsiudDbs
    cC7HHg7xowjBMNB2cK+NGjuQGTxs/UbPqDsgNdh1lQ5Nw4H57FFEz94/ugUO21YE
    CYs8gUigFgdMLLb2DjsNNXEjx7SXVtRVNVnnz7DrQ2/rQ7vBkO+5Z/03BGyOE5g2
    AsjTaQKBgDLpbXN2p3eubQGJqv/K6f/9LBux/RWGXnZ+C1oCtGrUj+Ja8N6+cd6G
    Mz9Go00GCdCUZXByx6rAZQaw7kWcI646miaplX4YtbX1d2mwbnmmz9EH4aRhzdby
    9VDoPXBgf4dufgNoS3xP4NS4H5oPg0gPS0vwpWspWqplLM+N/kGj
    -----END RSA PRIVATE KEY-----`
        
        secretMessage := []byte("The quick brown fox jumps over the lazy dog")
        label := []byte("")
        rng := rand.Reader
    
        // Encryption -------------------------------------------------------------
        
        // Load public key
        pubKeyBlock, _ := pem.Decode([]byte(publicKeyData))
        if pubKeyBlock == nil {
            panic("failed to decode public key PEM block")
        }
        pubInterface, parseErr := x509.ParsePKIXPublicKey(pubKeyBlock.Bytes)
        if parseErr != nil {
            fmt.Println("Load public key error")
            panic(parseErr)
        }
        rsaPublicKey, ok := pubInterface.(*rsa.PublicKey)
        if !ok {
            panic("public key is not an RSA key")
        }
    
        // OAEP digest SHA-256, MGF1 digest SHA-1 (same as the Java code)
        ciphertext, err := EncryptOAEP(sha256.New(), sha1.New(), rng, rsaPublicKey, secretMessage, label)
        if err != nil {
            fmt.Printf("Error from encryption: %s\n", err)
            return
        }
    
        // Decryption -------------------------------------------------------------
        
        // Load private key
        privateKeyBlock, _ := pem.Decode([]byte(privateKeyData))
        if privateKeyBlock == nil {
            panic("failed to decode private key PEM block")
        }
        rsaPrivateKey, err := x509.ParsePKCS1PrivateKey(privateKeyBlock.Bytes)
        if err != nil {
            fmt.Println("Load private key error")
            panic(err)
        }

        decrypted, err := DecryptOAEP(sha256.New(), sha1.New(), rng, rsaPrivateKey, ciphertext, label)
        if err != nil {
            fmt.Printf("Error from decryption: %s\n", err)
            return
        }
        fmt.Println("Go Encryption/Decryption : " + string(decrypted))  
        
        // Cross-platform test: ciphertext from Java
        /*
            Cipher cipher = Cipher.getInstance("RSA/ECB/OAEPWITHSHA-256ANDMGF1PADDING");
            OAEPParameterSpec oaepParameterSpec = new OAEPParameterSpec("SHA-256", "MGF1", MGF1ParameterSpec.SHA1, PSource.PSpecified.DEFAULT);
            cipher.init(Cipher.ENCRYPT_MODE, publicKeyObject, oaepParameterSpec);
            String ciphertextB64 = Base64.getEncoder().encodeToString(cipher.doFinal(data));
        */
        ciphertext,_ = base64.StdEncoding.DecodeString("cCrJasWOwVFrAQ8S+p7Cdn7OnCJn/FiCjZLzDkDISOSv15u1HcLbVAqNa7ory2AW/tsV5tNz5Y53azs6SN7dwYlu58YH7kwqkwfmvUwK8pLdPPRXGaUy8/gEbM4wkwHUuxbYm/bpoEjpmICBtWzb5VOsE1RWHnZu1G2BqGKe1+sE1XadVKQpBqNSahYdthY2Dk21i/PStO5S4eRrgW2nDdmxCs9UtV4MBU8BVYHYF0TYweA/udBoGTizSDjgmWn0RXYJruGvFMHWCRRlPnj+pcelatIfY4YKOHREYifKVkphkB7PT/JaVFyMZWzOtqzE13ZBWBwBmA/yCNLE/7krcg==")  
        decrypted, err = DecryptOAEP(sha256.New(), sha1.New(), rng, rsaPrivateKey, ciphertext, label)
        if err != nil {
            fmt.Printf("Error from decryption: %s\n", err)
            return
        }
        fmt.Println("Cross platform decryption: " + string(decrypted))  
        
    
    }
    
    // From rsa package - Encryption -------------------------------------------------------------
    
    func EncryptOAEP(hash hash.Hash, hashMGF1 hash.Hash, random io.Reader, pub *rsa.PublicKey, msg []byte, label []byte) ([]byte, error) {
        if err := checkPub(pub); err != nil {
            return nil, err
        }
        hash.Reset()
        k := pub.Size()
        if len(msg) > k-2*hash.Size()-2 {
            return nil, rsa.ErrMessageTooLong
        }
    
        hash.Write(label)
        lHash := hash.Sum(nil)
        hash.Reset()
    
        em := make([]byte, k)
        seed := em[1 : 1+hash.Size()]
        db := em[1+hash.Size():]
    
        copy(db[0:hash.Size()], lHash)
        db[len(db)-len(msg)-1] = 1
        copy(db[len(db)-len(msg):], msg)
    
        _, err := io.ReadFull(random, seed)
        if err != nil {
            return nil, err
        }
    
        hashMGF1.Reset()
        mgf1XOR(db, hashMGF1, seed)
        mgf1XOR(seed, hashMGF1, db)
    
        m := new(big.Int)
        m.SetBytes(em)
        c := encrypt(new(big.Int), pub, m)
    
        out := make([]byte, k)
        return c.FillBytes(out), nil
    }
    
    func encrypt(c *big.Int, pub *rsa.PublicKey, m *big.Int) *big.Int {
        e := big.NewInt(int64(pub.E))
        c.Exp(m, e, pub.N)
        return c
    }
    
    // From rsa package - Decryption -------------------------------------------------------------
    
    func DecryptOAEP(hash hash.Hash, hashMGF1 hash.Hash, random io.Reader, priv *rsa.PrivateKey, ciphertext []byte, label []byte) ([]byte, error) { // hashMGF1 hash.Hash added
        if err := checkPub(&priv.PublicKey); err != nil {
            return nil, err
        }
        k := priv.Size()
        if len(ciphertext) > k ||
            k < hash.Size()*2+2 {
            return nil, rsa.ErrDecryption
        }
    
        c := new(big.Int).SetBytes(ciphertext)
    
        m, err := decrypt(random, priv, c)
        if err != nil {
            return nil, err
        }
    
        hash.Write(label)
        lHash := hash.Sum(nil)
        hash.Reset()
    
        // We probably leak the number of leading zeros.
        // It's not clear that we can do anything about this.
        em := m.FillBytes(make([]byte, k))
    
        firstByteIsZero := subtle.ConstantTimeByteEq(em[0], 0)
    
        seed := em[1 : hash.Size()+1]
        db := em[hash.Size()+1:]
    
        hashMGF1.Reset()
        mgf1XOR(seed, hashMGF1, db) // apply hashMGF1
        mgf1XOR(db, hashMGF1, seed) // apply hashMGF1
    
        lHash2 := db[0:hash.Size()]
    
        // We have to validate the plaintext in constant time in order to avoid
        // attacks like: J. Manger. A Chosen Ciphertext Attack on RSA Optimal
        // Asymmetric Encryption Padding (OAEP) as Standardized in PKCS #1
        // v2.0. In J. Kilian, editor, Advances in Cryptology.
        lHash2Good := subtle.ConstantTimeCompare(lHash, lHash2)
    
        // The remainder of the plaintext must be zero or more 0x00, followed
        // by 0x01, followed by the message.
        //   lookingForIndex: 1 iff we are still looking for the 0x01
        //   index: the offset of the first 0x01 byte
        //   invalid: 1 iff we saw a non-zero byte before the 0x01.
        var lookingForIndex, index, invalid int
        lookingForIndex = 1
        rest := db[hash.Size():]
    
        for i := 0; i < len(rest); i++ {
            equals0 := subtle.ConstantTimeByteEq(rest[i], 0)
            equals1 := subtle.ConstantTimeByteEq(rest[i], 1)
            index = subtle.ConstantTimeSelect(lookingForIndex&equals1, i, index)
            lookingForIndex = subtle.ConstantTimeSelect(equals1, 0, lookingForIndex)
            invalid = subtle.ConstantTimeSelect(lookingForIndex&^equals0, 1, invalid)
        }
    
        if firstByteIsZero&lHash2Good&^invalid&^lookingForIndex != 1 {
            return nil, rsa.ErrDecryption
        }
    
        return rest[index+1:], nil
    }
    
    var bigZero = big.NewInt(0)
    var bigOne = big.NewInt(1)
    
    func decrypt(random io.Reader, priv *rsa.PrivateKey, c *big.Int) (m *big.Int, err error) {
        // TODO(agl): can we get away with reusing blinds?
        if c.Cmp(priv.N) > 0 {
            err = rsa.ErrDecryption
            return
        }
        if priv.N.Sign() == 0 {
            return nil, rsa.ErrDecryption
        }
    
        var ir *big.Int
        if random != nil {
            MaybeReadByte(random)
    
            // Blinding enabled. Blinding involves multiplying c by r^e.
            // Then the decryption operation performs (m^e * r^e)^d mod n
            // which equals mr mod n. The factor of r can then be removed
            // by multiplying by the multiplicative inverse of r.
    
            var r *big.Int
            ir = new(big.Int)
            for {
                r, err = rand.Int(random, priv.N)
                if err != nil {
                    return
                }
                if r.Cmp(bigZero) == 0 {
                    r = bigOne
                }
                ok := ir.ModInverse(r, priv.N)
                if ok != nil {
                    break
                }
            }
            bigE := big.NewInt(int64(priv.E))
            rpowe := new(big.Int).Exp(r, bigE, priv.N) // N != 0
            cCopy := new(big.Int).Set(c)
            cCopy.Mul(cCopy, rpowe)
            cCopy.Mod(cCopy, priv.N)
            c = cCopy
        }
    
        if priv.Precomputed.Dp == nil {
            m = new(big.Int).Exp(c, priv.D, priv.N)
        } else {
            // We have the precalculated values needed for the CRT.
            m = new(big.Int).Exp(c, priv.Precomputed.Dp, priv.Primes[0])
            m2 := new(big.Int).Exp(c, priv.Precomputed.Dq, priv.Primes[1])
            m.Sub(m, m2)
            if m.Sign() < 0 {
                m.Add(m, priv.Primes[0])
            }
            m.Mul(m, priv.Precomputed.Qinv)
            m.Mod(m, priv.Primes[0])
            m.Mul(m, priv.Primes[1])
            m.Add(m, m2)
    
            for i, values := range priv.Precomputed.CRTValues {
                prime := priv.Primes[2+i]
                m2.Exp(c, values.Exp, prime)
                m2.Sub(m2, m)
                m2.Mul(m2, values.Coeff)
                m2.Mod(m2, prime)
                if m2.Sign() < 0 {
                    m2.Add(m2, prime)
                }
                m2.Mul(m2, values.R)
                m.Add(m, m2)
            }
        }
    
        if ir != nil {
            // Unblind.
            m.Mul(m, ir)
            m.Mod(m, priv.N)
        }
    
        return
    }
    
    var (
        closedChanOnce sync.Once
        closedChan     chan struct{}
    )
    
    func MaybeReadByte(r io.Reader) { // from "crypto/internal/randutil"
        closedChanOnce.Do(func() {
            closedChan = make(chan struct{})
            close(closedChan)
        })
    
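        // Both cases are always ready, so select picks one at random: with
        // probability 1/2 one byte is read from r and discarded.
        select {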
        case <-closedChan:
            return
        case <-closedChan:
            var buf [1]byte
            r.Read(buf[:])
        }
    }
    
    // From rsa package - both -------------------------------------------------------------
    
    func mgf1XOR(out []byte, hash hash.Hash, seed []byte) {
        var counter [4]byte
        var digest []byte
    
        done := 0
        for done < len(out) {
            hash.Write(seed)
            hash.Write(counter[0:4])
            digest = hash.Sum(digest[:0])
            hash.Reset()
    
            for i := 0; i < len(digest) && done < len(out); i++ {
                out[done] ^= digest[i]
                done++
            }
            incCounter(&counter)
        }
    }
    
    func checkPub(pub *rsa.PublicKey) error {
        if pub.N == nil {
            return errPublicModulus
        }
        if pub.E < 2 {
            return errPublicExponentSmall
        }
        if pub.E > 1<<31-1 {
            return errPublicExponentLarge
        }
        return nil
    }
    
    var (
        errPublicModulus       = errors.New("crypto/rsa: missing public modulus")
        errPublicExponentSmall = errors.New("crypto/rsa: public exponent too small")
        errPublicExponentLarge = errors.New("crypto/rsa: public exponent too large")
    )
    
    func incCounter(c *[4]byte) {
        if c[3]++; c[3] != 0 {
            return
        }
        if c[2]++; c[2] != 0 {
            return
        }
        if c[1]++; c[1] != 0 {
            return
        }
        c[0]++
    }
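    As an additional sanity check, the adapted encryption can be compared against the standard library. With the same digest for OAEP and MGF1, its output should be accepted by rsa.DecryptOAEP. With SHA-256/MGF1-SHA1 (the Java parameters), rsa.DecryptOAEP should reject it while (*rsa.PrivateKey).Decrypt with OAEPOptions.MGFHash accepts it. This is a minimal sketch, assuming Go 1.20 or later. encryptOAEPMGF and checkInterop are hypothetical names, and encryptOAEPMGF is a condensed copy of EncryptOAEP above without checkPub.

```go
package main

import (
	"bytes"
	"crypto"
	"crypto/rand"
	"crypto/rsa"
	"crypto/sha1"
	"crypto/sha256"
	"encoding/binary"
	"errors"
	"hash"
	"io"
	"math/big"
)

// mgf1XOR XORs out with MGF1(seed) computed with h (RFC 8017, B.2.1).
func mgf1XOR(out []byte, h hash.Hash, seed []byte) {
	var counter [4]byte
	for done, c := 0, uint32(0); done < len(out); c++ {
		binary.BigEndian.PutUint32(counter[:], c)
		h.Reset()
		h.Write(seed)
		h.Write(counter[:])
		for _, b := range h.Sum(nil) {
			if done == len(out) {
				break
			}
			out[done] ^= b
			done++
		}
	}
}

// encryptOAEPMGF: oaepHash hashes the label, mgfHash drives MGF1.
func encryptOAEPMGF(oaepHash, mgfHash hash.Hash, random io.Reader, pub *rsa.PublicKey, msg, label []byte) ([]byte, error) {
	k, hLen := pub.Size(), oaepHash.Size()
	if len(msg) > k-2*hLen-2 {
		return nil, rsa.ErrMessageTooLong
	}
	oaepHash.Reset()
	oaepHash.Write(label)
	em := make([]byte, k) // EM = 0x00 || seed || DB
	seed, db := em[1:1+hLen], em[1+hLen:]
	copy(db, oaepHash.Sum(nil)) // DB = lHash || PS || 0x01 || M
	db[len(db)-len(msg)-1] = 0x01
	copy(db[len(db)-len(msg):], msg)
	if _, err := io.ReadFull(random, seed); err != nil {
		return nil, err
	}
	mgf1XOR(db, mgfHash, seed) // maskedDB
	mgf1XOR(seed, mgfHash, db) // maskedSeed
	c := new(big.Int).Exp(new(big.Int).SetBytes(em), big.NewInt(int64(pub.E)), pub.N)
	return c.FillBytes(make([]byte, k)), nil
}

// checkInterop encrypts with encryptOAEPMGF and decrypts with the standard library.
func checkInterop() error {
	priv, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return err
	}
	msg := []byte("The quick brown fox jumps over the lazy dog")

	// Same digest for OAEP and MGF1: rsa.DecryptOAEP must accept it.
	ct, err := encryptOAEPMGF(sha256.New(), sha256.New(), rand.Reader, &priv.PublicKey, msg, nil)
	if err != nil {
		return err
	}
	if pt, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, priv, ct, nil); err != nil || !bytes.Equal(pt, msg) {
		return errors.New("SHA-256/MGF1-SHA256 round trip failed")
	}

	// Java parameters: SHA-256 for OAEP, SHA-1 for MGF1.
	ct, err = encryptOAEPMGF(sha256.New(), sha1.New(), rand.Reader, &priv.PublicKey, msg, nil)
	if err != nil {
		return err
	}
	if _, err := rsa.DecryptOAEP(sha256.New(), rand.Reader, priv, ct, nil); err == nil {
		return errors.New("rsa.DecryptOAEP should reject an MGF1-SHA1 ciphertext")
	}
	opts := &rsa.OAEPOptions{Hash: crypto.SHA256, MGFHash: crypto.SHA1} // Go 1.20+
	if pt, err := priv.Decrypt(rand.Reader, ct, opts); err != nil || !bytes.Equal(pt, msg) {
		return errors.New("SHA-256/MGF1-SHA1 round trip failed")
	}
	return nil
}
```

    Calling checkInterop() from main() should return nil; it generates a fresh 2048-bit key, so it may take a moment.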