diff options
author | siddharth <s@ricketyspace.net> | 2022-04-11 20:03:08 -0400 |
---|---|---|
committer | siddharth <s@ricketyspace.net> | 2022-04-11 20:03:08 -0400 |
commit | 8e1700059a73f7090528194fab0b36751d6d1693 (patch) | |
tree | 09960e62ffbc4b2f2dcaf145d615af7a1ba2dacb /lib | |
parent | f613d9d03211efed74dcc9e5674f7eb8d9a94325 (diff) |
lib: add srp session key functions
Diffstat (limited to 'lib')
-rw-r--r-- | lib/srp.go | 77 | ||||
-rw-r--r-- | lib/srp_test.go | 87 |
2 files changed, 164 insertions, 0 deletions
@@ -199,6 +199,30 @@ func (u *SRPUser) SetScramblingParam(a *big.Int) error { return nil } +func (u *SRPUser) ComputeSessionKey(a *big.Int) error { + if a.Cmp(big.NewInt(0)) != 1 { + return CPError{"a is invalid"} + } + + // v^u + vu := new(big.Int) + vu.Exp(u.v, u.u, u.n) + + // S = (A * v^u) ^ b + s := new(big.Int) + s.Mul(a, vu) + s.Exp(s, u.b, u.n) + sb := s.Bytes() + + // K = H(S) + m := make([]byte, 0) + m = append(m, sb...) + u.h.Message(m) + u.sk = u.h.Hash() + + return nil +} + func NewSRPClientSession(n, g, k, ident string) (*SRPClientSession, error) { var ok bool @@ -270,3 +294,56 @@ func (s *SRPClientSession) SetScramblingParam(b *big.Int) error { } return nil } + +func (s *SRPClientSession) ComputeSessionKey(salt []byte, + pass string, b *big.Int) error { + if len(salt) < 1 { + return CPError{"salt invalid"} + } + if len(pass) < 1 { + return CPError{"pass invalid"} + } + + // salt+pass + sp := make([]byte, 0) + copy(sp, salt) + sp = append(sp, StrToBytes(pass)...) + + // x = H(salt+pass) + x := new(big.Int) + s.h.Message(sp) + x.SetBytes(s.h.Hash()) + + // g^x + gx := new(big.Int) + gx.Exp(s.g, x, s.n) + + // k * g^x + kgx := new(big.Int) + kgx.Mul(s.k, gx) + + // B - (k * g^x) + bkgx := new(big.Int) + bkgx.Sub(b, kgx) + + // u * x + ux := new(big.Int) + ux.Mul(s.u, x) + + // a + u*x + aux := new(big.Int) + aux.Add(s.a, ux) + + // S = (B - (k * g^x)) ^ (a + u*x) + sec := new(big.Int) + sec.Exp(bkgx, aux, s.n) + sb := sec.Bytes() + + // K = H(S) + m := make([]byte, 0) + m = append(m, sb...) + s.h.Message(m) + s.sk = s.h.Hash() + + return nil +} diff --git a/lib/srp_test.go b/lib/srp_test.go index a819b15..1445e01 100644 --- a/lib/srp_test.go +++ b/lib/srp_test.go @@ -248,3 +248,90 @@ func TestSRPScramblingParamter(t *testing.T) { return } } + +func TestSRPSessionKey(t *testing.T) { + n := StripSpaceChars( + `ffffffffffffffffc90fdaa22168c234c4c6628b80dc1cd129024 + e088a67cc74020bbea63b139b22514a08798e3404ddef9519b3cd + 3a431b302b0a6df25f14374fe1356d6d51c245e485b576625e7ec + 6f44c42e9a637ed6b0bff5cb6f406b7edee386bfb5a899fa5ae9f + 24117c4b1fe649286651ece45b3dc2007cb8a163bf0598da48361 + c55d39a69163fa8fd24cf5f83655d23dca3ad961c62f356208552 + bb9ed529077096966d670c354e4abc9804f1746c08ca237327fff + fffffffffffff`) + g := "2" + k := "3" + ident := "s@ricketyspace.net" + pass := "d59d6c93af0f37f272d924979" + + // Init srp server user. + user, err := NewSRPUser(n, g, k, ident, pass) + if err != nil { + t.Errorf("unable to create user on server: %v\n", err) + return + } + + // Get server's pub for user. + user.EphemeralKeyGen() + pubB, err := user.EphemeralKeyPub() + if err != nil { + t.Errorf("server ephemeral pub error: %v\n", err) + return + } + + // Init srp client session. + session, err := NewSRPClientSession(n, g, k, ident) + if err != nil { + t.Errorf("unable to create client session: %v\n", err) + return + } + + // Get client's pub for user. + pubA, err := session.EphemeralKeyPub() + if err != nil { + t.Errorf("client ephemeral pub error: %v\n", err) + return + } + + // Compute server's scrambling parameter for user. + err = user.SetScramblingParam(pubA) + if err != nil { + t.Errorf("unable generate server scrambling parameter: %v\n", err) + return + } + + // Compute client's scrambling paramter for user. + err = session.SetScramblingParam(pubB) + if err != nil { + t.Errorf("unable generate client scrambling parameter: %v\n", err) + return + } + + // The server's and client's scrambling parameter must be the + // same. + if user.u.Cmp(session.u) != 0 { + t.Error("Error: scrambling parameter of server != client\n") + return + } + + // Compute server's session key for user. + err = user.ComputeSessionKey(pubA) + if err != nil { + t.Errorf("unable to compute server's session key: %v", err) + return + } + + // Compute client's session key for for user. + err = session.ComputeSessionKey(user.salt, pass, pubB) + if err != nil { + t.Errorf("unable to compute client's session key: %v", err) + return + } + + // Verify that the session key is the same. + if !BytesEqual(user.sk, session.sk) { + t.Errorf("server's and client's session key not equal:"+ + " server_sk(%v): client_sk(%v)", user.sk, session.sk) + return + } +} |