From 0325d6f982eeb13a66dedff9d3065f3748f6311e Mon Sep 17 00:00:00 2001
From: siddharth <s@ricketyspace.net>
Date: Sun, 10 Oct 2021 17:14:34 -0400
Subject: lib: implement md4

---
 lib/md4.go      | 253 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 lib/md4_test.go |  58 +++++++++++++
 2 files changed, 311 insertions(+)
 create mode 100644 lib/md4.go
 create mode 100644 lib/md4_test.go

(limited to 'lib')

diff --git a/lib/md4.go b/lib/md4.go
new file mode 100644
index 0000000..a1bebc1
--- /dev/null
+++ b/lib/md4.go
@@ -0,0 +1,253 @@
+// Copyright © 2021 rsiddharth <s@ricketyspace.net>
+// SPDX-License-Identifier: ISC
+
+package lib
+
+// MD4 implementation.
+// Reference https://datatracker.ietf.org/doc/html/rfc1320
+
+type Md4 struct {
+	hvs    []uint32
+	Msg    []byte
+	MsgLen int
+}
+
+// Initial hash value.
+var md4IHashValues []uint32 = []uint32{
+	0x67452301,
+	0xefcdab89,
+	0x98badcfe,
+	0x10325476,
+}
+
+func md4F(x, y, z uint32) uint32 {
+	return (x & y) | (^x & z)
+}
+
+func md4G(x, y, z uint32) uint32 {
+	return (x & y) | (x & z) | (y & z)
+}
+
+func md4H(x, y, z uint32) uint32 {
+	return x ^ y ^ z
+}
+
+func md4RoundOneFunc(a, b, c, d, x uint32, s uint) uint32 {
+	a = shaAdd(a, md4F(b, c, d), x)
+	a = shaRotl(a, s)
+
+	return a
+}
+
+func md4RoundTwoFunc(a, b, c, d, x uint32, s uint) uint32 {
+	a = shaAdd(a, md4G(b, c, d), x, 0x5A827999)
+	a = shaRotl(a, s)
+
+	return a
+}
+
+func md4RoundThreeFunc(a, b, c, d, x uint32, s uint) uint32 {
+	a = shaAdd(a, md4H(b, c, d), x, 0x6ED9EBA1)
+	a = shaRotl(a, s)
+
+	return a
+}
+
+func md4Padding(l int) []byte {
+	l = l * 8 // msg size in bits
+
+	// Reckon value of `k`
+	k := 0
+	for ((l + 1 + k) % 512) != 448 {
+		k += 1
+	}
+
+	// Initialize padding bytes
+	pbs := make([]byte, 0)
+
+	// Add bit `1` as byte block.
+	pbs = append(pbs, 0x80)
+	f := 7 // unclaimed bits in last byte of `pbs`
+
+	// Add `k` bit `0`s
+	for i := 0; i < k; i++ {
+		if f == 0 {
+			pbs = append(pbs, 0x0)
+			f = 8
+		}
+		f = f - 1
+	}
+
+	// Add `l` in a 64 bit block in `pbs`
+	l64 := uint64(l)
+	b64 := make([]byte, 8) // last 64-bits
+	for i := 0; i <= 7; i++ {
+		// Get 8 last bits.
+		b64[i] = byte(l64 & 0xFF)
+
+		// Get rid of the last 8 bits.
+		l64 = l64 >> 8
+	}
+	pbs = append(pbs, b64...)
+
+	return pbs
+}
+
+func md4MessageBlocks(pm []byte) [][]uint32 {
+	// Break into 512-bit blocks
+	bs := BreakIntoBlocks(pm, 64)
+
+	mbs := make([][]uint32, 0) // Message blocks.
+	for i := 0; i < len(bs); i++ {
+		ws := make([]uint32, 0) // 32-bit words.
+
+		// Break 512-bit (64 bytes) into 32-bit words.
+		for j := 0; j < 64; j = j + 4 {
+			// Pack 4 bytes into a 32-bit word.
+			w := (uint32(bs[i][j]) |
+				uint32(bs[i][j+1])<<8 |
+				uint32(bs[i][j+2])<<16 |
+				uint32(bs[i][j+3])<<24)
+			ws = append(ws, w)
+		}
+		mbs = append(mbs, ws)
+	}
+	return mbs
+}
+
+func (md *Md4) Init(hvs []uint32) {
+	// Set Initial Hash Values.
+	h := make([]uint32, 4)
+	if len(hvs) == 4 {
+		copy(h, hvs)
+		md.hvs = h
+	} else {
+		copy(h, md4IHashValues)
+		md.hvs = h
+	}
+}
+
+func (md *Md4) Message(m []byte) {
+	md.Msg = m
+	md.MsgLen = len(m)
+}
+
+// MD4 - Pad message such that its length is a multiple of 512.
+func (md *Md4) Pad() []byte {
+	// Initialize padded message
+	pm := make([]byte, len(md.Msg))
+	copy(pm, md.Msg)
+
+	// Add padding.
+	pm = append(pm, md4Padding(md.MsgLen)...)
+
+	return pm
+}
+
+func (md *Md4) Hash() []byte {
+	// Pad message.
+	pm := md.Pad()
+
+	// Break into message blocks.
+	mbs := md4MessageBlocks(pm)
+
+	// Initialize hash values.
+	h := make([]uint32, 4)
+	copy(h, md.hvs) // Initial hash values.
+
+	// Process each message block.
+	for _, mb := range mbs {
+
+		// Initialize working variables.
+		a := h[0]
+		b := h[1]
+		c := h[2]
+		d := h[3]
+
+		//  Round 1.
+		a = md4RoundOneFunc(a, b, c, d, mb[0], 3)
+		d = md4RoundOneFunc(d, a, b, c, mb[1], 7)
+		c = md4RoundOneFunc(c, d, a, b, mb[2], 11)
+		b = md4RoundOneFunc(b, c, d, a, mb[3], 19)
+
+		a = md4RoundOneFunc(a, b, c, d, mb[4], 3)
+		d = md4RoundOneFunc(d, a, b, c, mb[5], 7)
+		c = md4RoundOneFunc(c, d, a, b, mb[6], 11)
+		b = md4RoundOneFunc(b, c, d, a, mb[7], 19)
+
+		a = md4RoundOneFunc(a, b, c, d, mb[8], 3)
+		d = md4RoundOneFunc(d, a, b, c, mb[9], 7)
+		c = md4RoundOneFunc(c, d, a, b, mb[10], 11)
+		b = md4RoundOneFunc(b, c, d, a, mb[11], 19)
+
+		a = md4RoundOneFunc(a, b, c, d, mb[12], 3)
+		d = md4RoundOneFunc(d, a, b, c, mb[13], 7)
+		c = md4RoundOneFunc(c, d, a, b, mb[14], 11)
+		b = md4RoundOneFunc(b, c, d, a, mb[15], 19)
+
+		// Round 2
+		a = md4RoundTwoFunc(a, b, c, d, mb[0], 3)
+		d = md4RoundTwoFunc(d, a, b, c, mb[4], 5)
+		c = md4RoundTwoFunc(c, d, a, b, mb[8], 9)
+		b = md4RoundTwoFunc(b, c, d, a, mb[12], 13)
+
+		a = md4RoundTwoFunc(a, b, c, d, mb[1], 3)
+		d = md4RoundTwoFunc(d, a, b, c, mb[5], 5)
+		c = md4RoundTwoFunc(c, d, a, b, mb[9], 9)
+		b = md4RoundTwoFunc(b, c, d, a, mb[13], 13)
+
+		a = md4RoundTwoFunc(a, b, c, d, mb[2], 3)
+		d = md4RoundTwoFunc(d, a, b, c, mb[6], 5)
+		c = md4RoundTwoFunc(c, d, a, b, mb[10], 9)
+		b = md4RoundTwoFunc(b, c, d, a, mb[14], 13)
+
+		a = md4RoundTwoFunc(a, b, c, d, mb[3], 3)
+		d = md4RoundTwoFunc(d, a, b, c, mb[7], 5)
+		c = md4RoundTwoFunc(c, d, a, b, mb[11], 9)
+		b = md4RoundTwoFunc(b, c, d, a, mb[15], 13)
+
+		// Round 3
+		a = md4RoundThreeFunc(a, b, c, d, mb[0], 3)
+		d = md4RoundThreeFunc(d, a, b, c, mb[8], 9)
+		c = md4RoundThreeFunc(c, d, a, b, mb[4], 11)
+		b = md4RoundThreeFunc(b, c, d, a, mb[12], 15)
+
+		a = md4RoundThreeFunc(a, b, c, d, mb[2], 3)
+		d = md4RoundThreeFunc(d, a, b, c, mb[10], 9)
+		c = md4RoundThreeFunc(c, d, a, b, mb[6], 11)
+		b = md4RoundThreeFunc(b, c, d, a, mb[14], 15)
+
+		a = md4RoundThreeFunc(a, b, c, d, mb[1], 3)
+		d = md4RoundThreeFunc(d, a, b, c, mb[9], 9)
+		c = md4RoundThreeFunc(c, d, a, b, mb[5], 11)
+		b = md4RoundThreeFunc(b, c, d, a, mb[13], 15)
+
+		a = md4RoundThreeFunc(a, b, c, d, mb[3], 3)
+		d = md4RoundThreeFunc(d, a, b, c, mb[11], 9)
+		c = md4RoundThreeFunc(c, d, a, b, mb[7], 11)
+		b = md4RoundThreeFunc(b, c, d, a, mb[15], 15)
+
+		// Compute intermediate hash values.
+		h[0] = shaAdd(a, h[0])
+		h[1] = shaAdd(b, h[1])
+		h[2] = shaAdd(c, h[2])
+		h[3] = shaAdd(d, h[3])
+	}
+
+	// Slurp md4 digest from hash values.
+	d := make([]byte, 0) // md4 digest
+	for i := 0; i < 4; i++ {
+		// Break 32-bit hash value into 4 bytes.
+		hb := make([]byte, 4)
+		for j := 0; j <= 3; j++ {
+			// Get last 8 bits.
+			hb[j] = byte(h[i] & 0xFF)
+
+			// Get rid of last 8 bits.
+			h[i] = h[i] >> 8
+		}
+		d = append(d, hb...)
+	}
+
+	return d
+}
diff --git a/lib/md4_test.go b/lib/md4_test.go
new file mode 100644
index 0000000..31b7b31
--- /dev/null
+++ b/lib/md4_test.go
@@ -0,0 +1,58 @@
+// Copyright © 2021 rsiddharth <s@ricketyspace.net>
+// SPDX-License-Identifier: ISC
+
+package lib
+
+import "testing"
+
+func TestMd4Hash(t *testing.T) {
+	md4 := Md4{}
+	md4.Init([]uint32{})
+
+	// Test 1
+	m := "abc"
+	md4.Message(StrToBytes(m))
+	h := md4.Hash()
+	e := "a448017aaf21d8525fc10ae87aa6729d" // Expected hash.
+	if BytesToHexStr(h) != e {
+		t.Errorf("md4 test 1 failed: %x != %s\n", h, e)
+	}
+
+	// Test 2
+	m = "message digest"
+	md4.Message(StrToBytes(m))
+	h = md4.Hash()
+	e = "d9130a8164549fe818874806e1c7014b" // Expected hash.
+	if BytesToHexStr(h) != e {
+		t.Errorf("md4 test 2 failed: %x != %s\n", h, e)
+	}
+
+	// Test 3
+	m = "abcdefghijklmnopqrstuvwxyz"
+	md4.Message(StrToBytes(m))
+	h = md4.Hash()
+	e = "d79e1c308aa5bbcdeea8ed63df412da9" // Expected hash.
+	if BytesToHexStr(h) != e {
+		t.Errorf("md4 test 3 failed: %x != %s\n", h, e)
+	}
+
+	// Test 4
+	m = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789"
+	md4.Message(StrToBytes(m))
+	h = md4.Hash()
+	e = "043f8582f241db351ce627e153e7f0e4" // Expected hash.
+	if BytesToHexStr(h) != e {
+		t.Errorf("md4 test 4 failed: %x != %s\n", h, e)
+	}
+
+	// Test 5
+	m = "123456789012345678901234567890123456789012345678901234567890123456"
+	m += "78901234567890"
+	md4.Message(StrToBytes(m))
+	h = md4.Hash()
+	e = "e33b4ddc9c38f2199c3e7b164fcc0536" // Expected hash.
+	if BytesToHexStr(h) != e {
+		t.Errorf("md4 test 5 failed: %x != %s\n", h, e)
+	}
+
+}
-- 
cgit v1.2.3