diff options
| author | Paul Buetow <paul@buetow.org> | 2026-02-22 09:41:45 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-02-22 09:41:45 +0200 |
| commit | 830b314656b0353e6a52652b650b686edfefa97a (patch) | |
| tree | 6eb0330566e9b785c697a3348d7e5e1f3824c140 | |
| parent | f78a185364ee24bc8c46d4aa6cc96d705faab326 (diff) | |
Implement crypto package with AES-256-CBC and golden tests (task 354)
Add internal/crypto/crypto.go with byte-identical AES-256-CBC encryption
and decryption matching the Ruby OpenSSL reference in geheim.rb, including
key doubling (enforceKeyLength), IV derivation (buildIV), PKCS7 pad/unpad,
and NewCipher/Encrypt/Decrypt. Add internal/crypto/crypto_test.go with
table-driven golden tests (6 vectors from Ruby), roundtrip tests (8 cases),
enforceKeyLength/buildIV/PKCS7 unit tests, and error-path coverage (22 total).
Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
| -rw-r--r-- | internal/crypto/crypto.go | 162 | ||||
| -rw-r--r-- | internal/crypto/crypto_test.go | 450 |
2 files changed, 611 insertions, 1 deletions
diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go index 0e985f4..83a4ae1 100644 --- a/internal/crypto/crypto.go +++ b/internal/crypto/crypto.go @@ -1,2 +1,162 @@ -// Package crypto provides encryption and decryption primitives for geheim. +// Package crypto provides AES-256-CBC encryption and decryption primitives +// for geheim. The implementation is byte-identical to Ruby's OpenSSL +// AES-256-CBC cipher so that files encrypted by the Ruby CLI can be +// decrypted by the Go implementation and vice-versa. package crypto + +import ( + "crypto/aes" + "crypto/cipher" + "errors" + "fmt" + "os" +) + +const defaultBlockSize = 16 + +// Cipher holds the derived key and IV used for every encrypt/decrypt call. +// Both values are fixed for the lifetime of the struct; create a new Cipher +// if the key file or PIN changes. +type Cipher struct { + key []byte // exactly keyLength bytes (default 32 for AES-256) + iv []byte // exactly 16 bytes (AES block size) +} + +// NewCipher reads the raw key material from keyFile, pads/truncates it to +// keyLength bytes using the same doubling strategy as the Ruby reference +// implementation, and derives the 16-byte IV from pin and addToIV. +// +// keyLength is typically 32 (AES-256). pin and addToIV must be ASCII strings +// because the IV is constructed at the byte level (not the rune level). +func NewCipher(keyFile string, keyLength int, pin string, addToIV string) (*Cipher, error) { + raw, err := os.ReadFile(keyFile) + if err != nil { + return nil, fmt.Errorf("reading key file %q: %w", keyFile, err) + } + if len(raw) == 0 { + return nil, errors.New("key file is empty") + } + + return &Cipher{ + key: enforceKeyLength(raw, keyLength), + iv: buildIV(pin, addToIV), + }, nil +} + +// enforceKeyLength replicates the Ruby `enforce_key_length` method: +// +// new_key += key while new_key.size < force_size +// new_key[0, force_size] +// +// If key is already exactly size bytes it is returned unchanged (after a copy). +// If it is longer it is simply truncated. If it is shorter the key is +// concatenated with itself until it reaches at least size bytes, then +// truncated to exactly size bytes. +func enforceKeyLength(key []byte, size int) []byte { + newKey := make([]byte, len(key)) + copy(newKey, key) + + // Keep appending the original key until we have enough bytes. + for len(newKey) < size { + newKey = append(newKey, key...) + } + + return newKey[:size] +} + +// buildIV constructs the 16-byte initialization vector the same way the Ruby +// reference does: +// +// iv_str = pin * 2 + add_to_iv + pin * 2 +// iv = iv_str.byteslice(0, 16) +// +// The slice is performed on bytes, not runes, so ASCII PINs are required for +// correct cross-language compatibility. +func buildIV(pin, addToIV string) []byte { + ivStr := pin + pin + addToIV + pin + pin + return []byte(ivStr)[:16] +} + +// Encrypt encrypts plaintext using AES-256-CBC with PKCS7 padding and returns +// the raw binary ciphertext (no base64 encoding). PKCS7 always adds a full +// extra block when the plaintext length is already a multiple of 16. +func (c *Cipher) Encrypt(plaintext []byte) ([]byte, error) { + padded := pkcs7Pad(plaintext, defaultBlockSize) + + block, err := aes.NewCipher(c.key) + if err != nil { + return nil, fmt.Errorf("creating AES cipher: %w", err) + } + + ciphertext := make([]byte, len(padded)) + mode := cipher.NewCBCEncrypter(block, c.iv) + mode.CryptBlocks(ciphertext, padded) + + return ciphertext, nil +} + +// Decrypt decrypts AES-256-CBC ciphertext (raw binary, no base64) and strips +// PKCS7 padding, returning the original plaintext bytes. +func (c *Cipher) Decrypt(ciphertext []byte) ([]byte, error) { + if len(ciphertext)%defaultBlockSize != 0 { + return nil, fmt.Errorf("ciphertext length %d is not a multiple of block size %d", + len(ciphertext), defaultBlockSize) + } + if len(ciphertext) == 0 { + return nil, errors.New("ciphertext is empty") + } + + block, err := aes.NewCipher(c.key) + if err != nil { + return nil, fmt.Errorf("creating AES cipher: %w", err) + } + + // Decrypt in-place: CBC decrypter writes back into the same slice. + plaintext := make([]byte, len(ciphertext)) + mode := cipher.NewCBCDecrypter(block, c.iv) + mode.CryptBlocks(plaintext, ciphertext) + + return pkcs7Unpad(plaintext) +} + +// pkcs7Pad appends PKCS7 padding so that len(result) is a multiple of +// blockSize. A full extra block is added when the input is already aligned, +// matching OpenSSL's default behaviour. +func pkcs7Pad(data []byte, blockSize int) []byte { + // padding value is the number of bytes that need to be added; + // at minimum 1, at maximum blockSize (full block when already aligned). + padding := blockSize - (len(data) % blockSize) + padded := make([]byte, len(data)+padding) + copy(padded, data) + for i := len(data); i < len(padded); i++ { + padded[i] = byte(padding) + } + return padded +} + +// pkcs7Unpad validates and removes PKCS7 padding from decrypted data. +// It returns an error if the padding byte value is out of range (0 or >16) or +// if any of the trailing padding bytes do not equal the padding value. +func pkcs7Unpad(data []byte) ([]byte, error) { + if len(data) == 0 { + return nil, errors.New("pkcs7Unpad: data is empty") + } + + padLen := int(data[len(data)-1]) + if padLen < 1 || padLen > defaultBlockSize { + return nil, fmt.Errorf("pkcs7Unpad: invalid padding byte %d", padLen) + } + if padLen > len(data) { + return nil, fmt.Errorf("pkcs7Unpad: padding length %d exceeds data length %d", padLen, len(data)) + } + + // Validate that every padding byte equals padLen. + for i := len(data) - padLen; i < len(data); i++ { + if data[i] != byte(padLen) { + return nil, fmt.Errorf("pkcs7Unpad: invalid padding at byte %d: got %d, want %d", + i, data[i], padLen) + } + } + + return data[:len(data)-padLen], nil +} diff --git a/internal/crypto/crypto_test.go b/internal/crypto/crypto_test.go new file mode 100644 index 0000000..556a12e --- /dev/null +++ b/internal/crypto/crypto_test.go @@ -0,0 +1,450 @@ +// Package crypto tests verify byte-for-byte compatibility with the Ruby +// OpenSSL AES-256-CBC reference implementation in geheim.rb. +// +// Golden hex values were generated with: +// +// ruby -e ' +// require "openssl" +// def enforce_key(key, size) +// k = key.dup; k += key while k.size < size; k[0, size] +// end +// def do_enc(plain, pin, key_content, add_to_iv="Hello world", key_length=32) +// key = enforce_key(key_content, key_length) +// iv_str = pin * 2 + add_to_iv + pin * 2 +// iv = iv_str.byteslice(0, 16) +// aes = OpenSSL::Cipher.new("AES-256-CBC") +// aes.encrypt; aes.key = key; aes.iv = iv +// ct = aes.update(plain) + aes.final +// puts ct.bytes.map { |b| "%02x" % b }.join +// end +// do_enc("Hello, world!", "1234", "shortkey") +// do_enc("Hello, world!", "ab", "x" * 32) +// do_enc("Hello, world!", "abcd1234", "y" * 64) +// do_enc("a" * 16, "1234", "shortkey") +// do_enc("b" * 48, "1234", "shortkey") +// do_enc("\x00\x01\x02\xff", "1234", "shortkey") +// ' +package crypto + +import ( + "encoding/hex" + "os" + "path/filepath" + "testing" +) + +// --- helpers ----------------------------------------------------------------- + +// writeKeyFile writes content to a temporary file and returns the path. +// The file is removed when the test completes. +func writeKeyFile(t *testing.T, content string) string { + t.Helper() + dir := t.TempDir() + path := filepath.Join(dir, "keyfile") + if err := os.WriteFile(path, []byte(content), 0o600); err != nil { + t.Fatalf("writeKeyFile: %v", err) + } + return path +} + +// mustHex decodes a hex string, failing the test on any error. +func mustHex(t *testing.T, s string) []byte { + t.Helper() + b, err := hex.DecodeString(s) + if err != nil { + t.Fatalf("mustHex(%q): %v", s, err) + } + return b +} + +// --- TestEnforceKeyLength ---------------------------------------------------- + +// TestEnforceKeyLength covers the four interesting edge cases for the key +// extension algorithm that mirrors Ruby's `enforce_key_length`. +func TestEnforceKeyLength(t *testing.T) { + cases := []struct { + name string + key []byte + size int + want []byte + }{ + { + name: "key shorter than size — doubled until long enough then truncated", + key: []byte("ab"), + size: 5, + want: []byte("ababa"), + }, + { + name: "key exact size — returned unchanged", + key: []byte("abcde"), + size: 5, + want: []byte("abcde"), + }, + { + name: "key longer than size — truncated", + key: []byte("abcdefgh"), + size: 5, + want: []byte("abcde"), + }, + { + name: "single-byte key expanded to 32 bytes", + key: []byte("x"), + size: 32, + want: []byte("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"), + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + got := enforceKeyLength(tc.key, tc.size) + if string(got) != string(tc.want) { + t.Errorf("enforceKeyLength(%q, %d) = %q; want %q", + tc.key, tc.size, got, tc.want) + } + }) + } +} + +// --- TestBuildIV ------------------------------------------------------------- + +// TestBuildIV verifies that the IV derivation matches the Ruby reference: +// +// iv_str = pin * 2 + add_to_iv + pin * 2, then byteslice(0, 16). +// +// Verified with Ruby: pin="1234" → "12341234Hello wo"; pin="ab" → "ababHello worlda" +func TestBuildIV(t *testing.T) { + cases := []struct { + pin string + addToIV string + // wantStr is the expected string content of the 16-byte IV. + // Verified via: ruby -e 'pin="X"; iv_str=pin*2+"ADD"+pin*2; p iv_str.byteslice(0,16)' + wantStr string + }{ + { + // "1234"*2 + "Hello world" + "1234"*2 = "12341234Hello world12341234" + // first 16 bytes: "12341234Hello wo" + pin: "1234", + addToIV: "Hello world", + wantStr: "12341234Hello wo", + }, + { + // "ab"*2 + "Hello world" + "ab"*2 = "ababHello worldabab" + // first 16 bytes: "ababHello worlda" + pin: "ab", + addToIV: "Hello world", + wantStr: "ababHello worlda", + }, + { + // pin="" → addToIV fills all 16 bytes + pin: "", + addToIV: "0123456789abcdef", + wantStr: "0123456789abcdef", + }, + } + + for _, tc := range cases { + t.Run(tc.pin+"|"+tc.addToIV, func(t *testing.T) { + got := buildIV(tc.pin, tc.addToIV) + if len(got) != 16 { + t.Errorf("buildIV returned %d bytes; want 16", len(got)) + } + gotHex := hex.EncodeToString(got) + wantHex := hex.EncodeToString([]byte(tc.wantStr)) + if gotHex != wantHex { + t.Errorf("buildIV(%q, %q)\n got hex=%q (%q)\n want hex=%q (%q)", + tc.pin, tc.addToIV, gotHex, got, wantHex, tc.wantStr) + } + }) + } +} + +// --- TestPKCS7PadUnpad ------------------------------------------------------- + +// TestPKCS7PadUnpad checks padding for inputs of various lengths including the +// critical case where the input is already block-aligned (must add a full extra +// block), and validates that pkcs7Unpad rejects corrupted padding. +func TestPKCS7PadUnpad(t *testing.T) { + t.Run("pad 15-byte input to 16", func(t *testing.T) { + data := make([]byte, 15) + got := pkcs7Pad(data, 16) + if len(got) != 16 { + t.Fatalf("expected 16 bytes; got %d", len(got)) + } + if got[15] != 0x01 { + t.Errorf("last byte = 0x%02x; want 0x01", got[15]) + } + }) + + t.Run("pad 16-byte input adds full extra block of 0x10", func(t *testing.T) { + data := make([]byte, 16) + got := pkcs7Pad(data, 16) + if len(got) != 32 { + t.Fatalf("expected 32 bytes; got %d", len(got)) + } + // All 16 padding bytes must equal 0x10. + for i := 16; i < 32; i++ { + if got[i] != 0x10 { + t.Errorf("padding byte %d = 0x%02x; want 0x10", i, got[i]) + } + } + }) + + t.Run("pad 0-byte input to 16 bytes of 0x10", func(t *testing.T) { + got := pkcs7Pad([]byte{}, 16) + if len(got) != 16 { + t.Fatalf("expected 16 bytes; got %d", len(got)) + } + for i, b := range got { + if b != 0x10 { + t.Errorf("byte %d = 0x%02x; want 0x10", i, b) + } + } + }) + + t.Run("unpad valid padding — 3 data bytes, 1 byte of padding 0x01", func(t *testing.T) { + // 15 data bytes (0x00..0x0e) followed by 0x01 padding → unpad yields 15 bytes. + data := []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d, 0x0e, 0x01, + } + got, err := pkcs7Unpad(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != 15 { + t.Errorf("expected 15 bytes after unpad; got %d", len(got)) + } + }) + + t.Run("unpad valid padding — 12 data bytes, 4 bytes of padding 0x04", func(t *testing.T) { + // 12 data bytes followed by four 0x04 bytes → unpad yields 12 bytes. + data := []byte{ + 0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, + 0x08, 0x09, 0x0a, 0x0b, 0x04, 0x04, 0x04, 0x04, + } + got, err := pkcs7Unpad(data) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if len(got) != 12 { + t.Errorf("expected 12 bytes after unpad; got %d", len(got)) + } + }) + + t.Run("unpad invalid padding byte value 0", func(t *testing.T) { + // A padding byte of 0 is never valid in PKCS7. + data := make([]byte, 16) // all zeros → last byte is 0 + _, err := pkcs7Unpad(data) + if err == nil { + t.Error("expected error for padding byte 0; got nil") + } + }) + + t.Run("unpad invalid padding byte value 17 (> blockSize)", func(t *testing.T) { + data := make([]byte, 16) + data[15] = 0x11 // 17 > block size of 16 + _, err := pkcs7Unpad(data) + if err == nil { + t.Error("expected error for padding byte 17; got nil") + } + }) + + t.Run("unpad corrupted padding bytes", func(t *testing.T) { + // Claim 3 bytes of padding but the second-to-last bytes don't match. + data := []byte{0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x99, 0x03, 0x03} + _, err := pkcs7Unpad(data) + if err == nil { + t.Error("expected error for corrupted padding bytes; got nil") + } + }) +} + +// --- TestEncryptGolden ------------------------------------------------------- + +// TestEncryptGolden compares Go Encrypt output against hex values generated by +// the Ruby OpenSSL reference implementation, ensuring byte-for-byte output +// compatibility. +func TestEncryptGolden(t *testing.T) { + cases := []struct { + name string + plaintext []byte + pin string + keyData string + addToIV string + keyLength int + wantHex string + }{ + { + name: "Hello world / pin=1234 / shortkey", + plaintext: []byte("Hello, world!"), + pin: "1234", + keyData: "shortkey", + addToIV: "Hello world", + keyLength: 32, + wantHex: "78c08330c963e089ab15700bf9453700", + }, + { + name: "Hello world / pin=ab / 32x 'x'", + plaintext: []byte("Hello, world!"), + pin: "ab", + keyData: "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", // 32 bytes + addToIV: "Hello world", + keyLength: 32, + wantHex: "6190f985f42374d24dd8e17b3b2d6057", + }, + { + name: "Hello world / pin=abcd1234 / 64x 'y'", + plaintext: []byte("Hello, world!"), + pin: "abcd1234", + // 64 bytes of 'y': key is already 2x the required 32 bytes so it gets truncated. + keyData: "yyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyyy", + addToIV: "Hello world", + keyLength: 32, + wantHex: "b2a0c920a53676a3d32c7422e5f7fb4f", + }, + { + name: "16x 'a' (block-aligned) / pin=1234 / shortkey", + // A block-aligned plaintext still gets a full extra block of padding. + plaintext: []byte("aaaaaaaaaaaaaaaa"), + pin: "1234", + keyData: "shortkey", + addToIV: "Hello world", + keyLength: 32, + wantHex: "8968368e480298e8c3273c5d6169f57cf4827f5e4697c2772428c0e603487367", + }, + { + name: "48x 'b' / pin=1234 / shortkey", + plaintext: []byte("bbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbbb"), + pin: "1234", + keyData: "shortkey", + addToIV: "Hello world", + keyLength: 32, + // 128 hex chars = 64 bytes (48-byte plaintext + 16-byte padding block) + wantHex: "3c3cb309ea80422e958e454528f965fc40e46409ebb0ee459c769dd2be14976938a1551e3b907e7cb165da78196caa3cf50bdf6fcbe1128c14a39024f84eb168", + }, + { + name: "binary input 00 01 02 ff / pin=1234 / shortkey", + plaintext: []byte{0x00, 0x01, 0x02, 0xff}, + pin: "1234", + keyData: "shortkey", + addToIV: "Hello world", + keyLength: 32, + wantHex: "ccaf7cb5d2ce8703e20716beb9ecfc82", + }, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + keyFile := writeKeyFile(t, tc.keyData) + c, err := NewCipher(keyFile, tc.keyLength, tc.pin, tc.addToIV) + if err != nil { + t.Fatalf("NewCipher: %v", err) + } + + got, err := c.Encrypt(tc.plaintext) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + + gotHex := hex.EncodeToString(got) + if gotHex != tc.wantHex { + t.Errorf("Encrypt mismatch:\n got %s\n want %s", gotHex, tc.wantHex) + } + }) + } +} + +// --- TestEncryptDecryptRoundtrip --------------------------------------------- + +// TestEncryptDecryptRoundtrip verifies that Decrypt(Encrypt(plain)) == plain +// for a variety of inputs. It does not rely on golden values, so it catches +// padding or mode errors that the golden test might miss if both paths share +// the same bug. +func TestEncryptDecryptRoundtrip(t *testing.T) { + cases := []struct { + name string + plaintext []byte + pin string + keyData string + }{ + {"empty input", []byte{}, "pin", "somekey"}, + {"short ASCII", []byte("hello"), "1234", "shortkey"}, + {"exactly 16 bytes", []byte("0123456789abcdef"), "1234", "shortkey"}, + {"17 bytes", []byte("0123456789abcdefX"), "pin99", "mykey"}, + {"32 bytes", make([]byte, 32), "abcd", "k"}, + {"binary data", []byte{0x00, 0x01, 0x02, 0xfe, 0xff}, "zz", "binarykey"}, + {"127 bytes", make([]byte, 127), "longpin12345678", "keydata"}, + {"128 bytes", make([]byte, 128), "x", "y"}, + } + + for _, tc := range cases { + t.Run(tc.name, func(t *testing.T) { + keyFile := writeKeyFile(t, tc.keyData) + c, err := NewCipher(keyFile, 32, tc.pin, "Hello world") + if err != nil { + t.Fatalf("NewCipher: %v", err) + } + + ciphertext, err := c.Encrypt(tc.plaintext) + if err != nil { + t.Fatalf("Encrypt: %v", err) + } + + recovered, err := c.Decrypt(ciphertext) + if err != nil { + t.Fatalf("Decrypt: %v", err) + } + + if string(recovered) != string(tc.plaintext) { + t.Errorf("roundtrip mismatch:\n got %x\n want %x", recovered, tc.plaintext) + } + }) + } +} + +// --- TestNewCipherErrors ----------------------------------------------------- + +// TestNewCipherErrors exercises the error paths in NewCipher. +func TestNewCipherErrors(t *testing.T) { + t.Run("non-existent key file", func(t *testing.T) { + _, err := NewCipher("/nonexistent/path/keyfile", 32, "pin", "addiv") + if err == nil { + t.Error("expected error for missing key file; got nil") + } + }) + + t.Run("empty key file", func(t *testing.T) { + keyFile := writeKeyFile(t, "") + _, err := NewCipher(keyFile, 32, "pin", "addiv") + if err == nil { + t.Error("expected error for empty key file; got nil") + } + }) +} + +// --- TestDecryptErrors ------------------------------------------------------- + +// TestDecryptErrors verifies that Decrypt returns sensible errors for +// malformed input rather than panicking or silently returning garbage. +func TestDecryptErrors(t *testing.T) { + keyFile := writeKeyFile(t, "somekey") + c, err := NewCipher(keyFile, 32, "1234", "Hello world") + if err != nil { + t.Fatalf("NewCipher: %v", err) + } + + t.Run("empty ciphertext", func(t *testing.T) { + _, err := c.Decrypt([]byte{}) + if err == nil { + t.Error("expected error for empty ciphertext; got nil") + } + }) + + t.Run("ciphertext not multiple of block size", func(t *testing.T) { + _, err := c.Decrypt(mustHex(t, "deadbeef01")) + if err == nil { + t.Error("expected error for non-aligned ciphertext; got nil") + } + }) +} |
