summaryrefslogblamecommitdiffstats
path: root/lib/rsa.go
blob: 772b898a27b2bbe0f100f194c4a1352abd976f8e (plain) (tree)
1
2
3
4
5
6
7
8
9
10





                                                             
                     


                  















                              








                                        
                                              
































                                                     





























                                                                          
                                      




























                                                                     










                                             
                               
                  

 



                               








                                                    
// Copyright © 2021 siddharth ravikumar <s@ricketyspace.net>
// SPDX-License-Identifier: ISC

package lib

import (
	"crypto/rand"
	"math/big"
)

// RSAPair represents an RSA key pair: it bundles a public key with
// a private key.
type RSAPair struct {
	// Public is the public half of the pair.
	Public  *RSAPub
	// Private is the private half of the pair.
	Private *RSAPrivate
}

// RSAPub is an RSA public key.
type RSAPub struct {
	// e is the exponent and n the modulus applied by Encrypt.
	e, n *big.Int
}

// RSAPrivate is an RSA private key.
type RSAPrivate struct {
	// d is the exponent and n the modulus applied by Decrypt.
	d, n *big.Int
}

// biCopy sets a to the value of b and returns a.
//
// It delegates to big.Int.Set, which copies sign and magnitude
// directly. Unlike a round trip through Bytes/SetBytes (Bytes holds
// only the absolute value), this needs no separate sign fix-up and
// stays correct when a and b are the same *big.Int.
func biCopy(a, b *big.Int) *big.Int {
	return a.Set(b)
}

func InvMod(a, n *big.Int) (*big.Int, error) {
	// Initialize.
	t0 := big.NewInt(0)
	t1 := big.NewInt(1)
	r0 := biCopy(big.NewInt(0), n)
	r1 := biCopy(big.NewInt(0), a)

	for r1.Cmp(big.NewInt(0)) != 0 {
		q := big.NewInt(0)
		q.Div(r0, r1)

		tt := big.NewInt(0)
		tt = tt.Mul(q, t1)
		tt = tt.Sub(t0, tt)

		biCopy(t0, t1)
		biCopy(t1, tt)

		tr := big.NewInt(0)
		tr = tr.Mul(q, r1)
		tr = tr.Sub(r0, tr)

		biCopy(r0, r1)
		biCopy(r1, tr)
	}

	if r0.Cmp(big.NewInt(1)) > 0 {
		return nil, CPError{"not invertible"}
	}
	if t0.Cmp(big.NewInt(0)) < 0 {
		t0.Add(t0, n)
	}
	return t0, nil
}

// RSAGenKey generates a new RSA key pair with public exponent 3 and
// a modulus that is the product of two random 1024-bit primes.
//
// Prime pairs are drawn from crypto/rand until the public exponent
// has an inverse modulo the totient (p-1)(q-1); that inverse becomes
// the private exponent. A CPError is returned if a prime cannot be
// generated.
func RSAGenKey() (*RSAPair, error) {
	one := big.NewInt(1)
	exponent := big.NewInt(3)

	var modulus, secret *big.Int
	for {
		p, err := rand.Prime(rand.Reader, 1024)
		if err != nil {
			return nil, CPError{"unable to generate p"}
		}
		q, err := rand.Prime(rand.Reader, 1024)
		if err != nil {
			return nil, CPError{"unable to generate q"}
		}

		// n = p*q
		modulus = new(big.Int).Mul(p, q)

		// Totient of n is (p-1)*(q-1).
		totient := new(big.Int).Sub(p, one)
		totient.Mul(totient, new(big.Int).Sub(q, one))

		// d = e^-1 mod totient.
		secret, err = InvMod(exponent, totient)
		if err == nil {
			break
		}
		// The inverse does not exist for this prime pair; draw new primes.
	}

	// Sanity checks on the computed values.
	if modulus.Sign() <= 0 {
		return nil, CPError{"unable to compute n"}
	}
	if secret.Sign() <= 0 {
		return nil, CPError{"unable to compute d"}
	}

	// Each key gets its own copy of the modulus.
	return &RSAPair{
		Public:  &RSAPub{e: exponent, n: new(big.Int).Set(modulus)},
		Private: &RSAPrivate{d: secret, n: new(big.Int).Set(modulus)},
	}, nil
}

// Encrypt applies the raw RSA operation to msg and returns the
// result as bytes.
//
// msg is read as a big-endian unsigned integer m, and the returned
// slice is the big-endian encoding of m^e mod n. No padding is
// applied, and leading zero bytes of msg do not affect the result.
func (r *RSAPub) Encrypt(msg []byte) []byte {
	var plain, sealed big.Int
	plain.SetBytes(msg)
	sealed.Exp(&plain, r.e, r.n)
	return sealed.Bytes()
}

// E returns the public key's exponent e. The returned pointer is the
// key's own value, not a copy.
func (r *RSAPub) E() *big.Int {
	return r.e
}

// N returns the public key's modulus n. The returned pointer is the
// key's own value, not a copy.
func (r *RSAPub) N() *big.Int {
	return r.n
}

// Decrypt applies the raw RSA private-key operation to cipher and
// returns the result as bytes.
//
// cipher is read as a big-endian unsigned integer c, and the
// returned slice is the big-endian encoding of c^d mod n. No padding
// is removed, and the result carries no leading zero bytes.
func (r *RSAPrivate) Decrypt(cipher []byte) []byte {
	var sealed, plain big.Int
	sealed.SetBytes(cipher)
	plain.Exp(&sealed, r.d, r.n)
	return plain.Bytes()
}