From 831137abdecfcafeb21fb5f3de45156819f35ed4 Mon Sep 17 00:00:00 2001 From: Paul Buetow Date: Sun, 18 Jun 2023 19:59:20 +0300 Subject: refactor --- internal/client/tcpclient.go | 6 ++--- internal/iorw/iorw.go | 36 +++++++++++++++++++++++++ internal/iorw/iorw_test.go | 64 ++++++++++++++++++++++++++++++++++++++++++++ internal/server/tcpserver.go | 6 ++--- internal/tcp/tcp.go | 44 ------------------------------ internal/tcp/tcp_test.go | 63 ------------------------------------------- 6 files changed, 106 insertions(+), 113 deletions(-) create mode 100644 internal/iorw/iorw.go create mode 100644 internal/iorw/iorw_test.go delete mode 100644 internal/tcp/tcp.go delete mode 100644 internal/tcp/tcp_test.go (limited to 'internal') diff --git a/internal/client/tcpclient.go b/internal/client/tcpclient.go index 3181fd4..db4334f 100644 --- a/internal/client/tcpclient.go +++ b/internal/client/tcpclient.go @@ -6,7 +6,7 @@ import ( "log" "net" - "codeberg.org/snonux/gorum/internal/tcp" + "codeberg.org/snonux/gorum/internal/iorw" "codeberg.org/snonux/gorum/internal/vote" ) @@ -29,11 +29,11 @@ func tcpClientRun(ctx context.Context, node string, ch <-chan vote.Vote) error { } log.Println("tcpclient: sending", message, "to node", node) - if err := tcp.WriteStr(conn, message); err != nil { + if err := iorw.WriteStr(conn, message); err != nil { return err } - response, err := tcp.ReadStr(conn) + response, err := iorw.ReadStr(conn) if err != nil { return err } diff --git a/internal/iorw/iorw.go b/internal/iorw/iorw.go new file mode 100644 index 0000000..a91a5b2 --- /dev/null +++ b/internal/iorw/iorw.go @@ -0,0 +1,36 @@ +package iorw + +import ( + "encoding/binary" + "io" +) + +func WriteStr(w io.Writer, message string) error { + messageBytes := []byte(message) + sizeBytes := make([]byte, 8) + binary.BigEndian.PutUint64(sizeBytes, uint64(len(messageBytes))) + + if _, err := w.Write(sizeBytes); err != nil { + return err + } + if _, err := w.Write(messageBytes); err != nil { + return err + } + + return nil +} + +func ReadStr(r io.Reader) (string, error) { + sizeBytes := make([]byte, 8) + if _, err := io.ReadFull(r, sizeBytes); err != nil { + return "", err + } + messageSize := binary.BigEndian.Uint64(sizeBytes) + + messageBytes := make([]byte, messageSize) + if _, err := io.ReadFull(r, messageBytes); err != nil { + return "", err + } + + return string(messageBytes), nil +} diff --git a/internal/iorw/iorw_test.go b/internal/iorw/iorw_test.go new file mode 100644 index 0000000..c54c059 --- /dev/null +++ b/internal/iorw/iorw_test.go @@ -0,0 +1,64 @@ +package iorw + +import ( + "testing" +) + +type readWriteTest struct { + sizeWritten *bool + sizeRead *bool + sizeBytes []byte + messageBytes []byte +} + +func (rwt readWriteTest) Write(b []byte) (n int, err error) { + if !*rwt.sizeWritten { + copy(rwt.sizeBytes, b) + *rwt.sizeWritten = true + } else { + copy(rwt.messageBytes, b) + } + + return len(b), nil +} + +func (rwt readWriteTest) Read(b []byte) (n int, err error) { + if !*rwt.sizeRead { + copy(b, rwt.sizeBytes) + *rwt.sizeRead = true + } else { + copy(b, rwt.messageBytes) + } + return len(b), nil +} + +func TestReadWrite(t *testing.T) { + t.Parallel() + + var ( + message = "Hello world!" + sizeWritten = false + sizeRead = false + ) + + rwt := readWriteTest{ + sizeWritten: &sizeWritten, + sizeRead: &sizeRead, + sizeBytes: make([]byte, 8), + messageBytes: make([]byte, len([]byte(message))), + } + + if err := WriteStr(rwt, message); err != nil { + t.Errorf(err.Error()) + } + + response, err := ReadStr(rwt) + if err != nil { + t.Errorf(err.Error()) + } + + if response != message { + t.Errorf("Expected response '%s' to be equal to original message '%s'!", + response, message) + } +} diff --git a/internal/server/tcpserver.go b/internal/server/tcpserver.go index 21324c3..d952972 100644 --- a/internal/server/tcpserver.go +++ b/internal/server/tcpserver.go @@ -7,7 +7,7 @@ import ( "net" "codeberg.org/snonux/gorum/internal/config" - "codeberg.org/snonux/gorum/internal/tcp" + "codeberg.org/snonux/gorum/internal/iorw" ) type handlerCb func(message string) string @@ -49,7 +49,7 @@ func handleConnection(ctx context.Context, conn net.Conn, cb handlerCb) { log.Println("server: context done, disconnecting client:", remoteAddr) return default: - message, err := tcp.ReadStr(conn) + message, err := iorw.ReadStr(conn) if err != nil { log.Println("server: unable to read message", remoteAddr, err) return @@ -58,7 +58,7 @@ func handleConnection(ctx context.Context, conn net.Conn, cb handlerCb) { log.Println("server: received message", message, "from", remoteAddr) response := cb(message) - if err := tcp.WriteStr(conn, response); err != nil { + if err := iorw.WriteStr(conn, response); err != nil { log.Println("error:", err) } } diff --git a/internal/tcp/tcp.go b/internal/tcp/tcp.go deleted file mode 100644 index 3f9bafc..0000000 --- a/internal/tcp/tcp.go +++ /dev/null @@ -1,44 +0,0 @@ -package tcp - -import ( - "encoding/binary" - "io" -) - -type Writer interface { - Write(b []byte) (n int, err error) -} - -type Reader interface { - Read(b []byte) (n int, err error) -} - -func WriteStr(w Writer, message string) error { - messageBytes := []byte(message) - sizeBytes := make([]byte, 8) - binary.BigEndian.PutUint64(sizeBytes, uint64(len(messageBytes))) - - if _, err := w.Write(sizeBytes); err != nil { - return err - } - if _, err := w.Write(messageBytes); err != nil { - return err - } - - return nil -} - -func ReadStr(r Reader) (string, error) { - sizeBytes := make([]byte, 8) - if _, err := io.ReadFull(r, sizeBytes); err != nil { - return "", err - } - messageSize := binary.BigEndian.Uint64(sizeBytes) - - messageBytes := make([]byte, messageSize) - if _, err := io.ReadFull(r, messageBytes); err != nil { - return "", err - } - - return string(messageBytes), nil -} diff --git a/internal/tcp/tcp_test.go b/internal/tcp/tcp_test.go deleted file mode 100644 index 668758c..0000000 --- a/internal/tcp/tcp_test.go +++ /dev/null @@ -1,63 +0,0 @@ -package tcp - -import ( - "testing" -) - -type readWriteTest struct { - sizeWritten *bool - sizeRead *bool - sizeBytes []byte - messageBytes []byte -} - -func (rwt readWriteTest) Write(b []byte) (n int, err error) { - if !*rwt.sizeWritten { - copy(rwt.sizeBytes, b) - *rwt.sizeWritten = true - } else { - copy(rwt.messageBytes, b) - } - - return len(b), nil -} - -func (rwt readWriteTest) Read(b []byte) (n int, err error) { - if !*rwt.sizeRead { - copy(b, rwt.sizeBytes) - *rwt.sizeRead = true - } else { - copy(b, rwt.messageBytes) - } - return len(b), nil -} - -func TestReadWrite(t *testing.T) { - t.Parallel() - - message := "Hello world!" - - var sizeWritten bool - var sizeRead bool - - rwt := readWriteTest{ - sizeWritten: &sizeWritten, - sizeRead: &sizeRead, - sizeBytes: make([]byte, 8), - messageBytes: make([]byte, len([]byte(message))), - } - - if err := WriteStr(rwt, message); err != nil { - t.Errorf(err.Error()) - } - - response, err := ReadStr(rwt) - if err != nil { - t.Errorf(err.Error()) - } - - if response != message { - t.Errorf("Expected response '%s' to be equal to original message '%s'!", - response, message) - } -} -- cgit v1.2.3