summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorPaul Buetow <paul@buetow.org>2026-03-25 21:58:36 +0200
committerPaul Buetow <paul@buetow.org>2026-03-25 21:58:36 +0200
commit9ee6293b175710e7cf1ff5bc9a6dc555ddaf559f (patch)
tree34da64728c21c195147969eda153774db6455d62
parent25b2728a6499b427e201fb80d8a70e65909645e0 (diff)
code-quality: Various improvements to code quality and thread safety
-rw-r--r--FIXES.md10
-rw-r--r--cmd/gt/main.go2
-rw-r--r--internal/perc/perc.go8
-rw-r--r--internal/repl/concurrent_test.go62
-rw-r--r--internal/repl/handlers.go10
-rw-r--r--internal/repl/repl.go117
-rw-r--r--internal/repl/repl_test.go12
-rw-r--r--internal/rpn/operations.go88
-rw-r--r--internal/rpn/rpn_ops.go8
-rw-r--r--internal/rpn/rpn_parse.go9
-rw-r--r--internal/rpn/rpn_state.go19
-rw-r--r--internal/rpn/rpn_test.go101
12 files changed, 371 insertions, 75 deletions
diff --git a/FIXES.md b/FIXES.md
index baea6b7..d6cc694 100644
--- a/FIXES.md
+++ b/FIXES.md
@@ -216,9 +216,9 @@ func getRPNState() *RPNState {
---
-## Fix #5: Improved Error Context in Calculator
+## Fix #5: Improved Error Context in Percentage Calculator
-**File:** `internal/calculator/calculator.go`
+**File:** `internal/perc/perc.go`
**Issue:** Errors not wrapped with context
**Location:** Various places
@@ -226,7 +226,7 @@ func getRPNState() *RPNState {
```go
func Parse(input string) (string, error) {
// ...
- result, err := calculator.Parse(input)
+ result, err := perc.Parse(input)
if err != nil {
return "", err // Missing context
}
@@ -238,7 +238,7 @@ func Parse(input string) (string, error) {
```go
func Parse(input string) (string, error) {
// ...
- result, err := calculator.Parse(input)
+ result, err := perc.Parse(input)
if err != nil {
return "", fmt.Errorf("rpn fallback failed for input %q: %w", input, err)
}
@@ -371,7 +371,7 @@ golangci-lint run
| #1 | `rpn.go` | Error wrapping | Better debugging |
| #3 | `repl.go` | Proper resource cleanup | No resource leaks |
| #4 | `repl.go` | Mutex safety | Thread safety |
-| #5 | `calculator.go` | Error context | Better error messages |
+| #5 | `perc.go` | Error context | Better error messages |
| #6 | `variables.go` | Slice handling | Memory management |
| #7 | `variables.go` | Performance optimization | Reduced allocations |
diff --git a/cmd/gt/main.go b/cmd/gt/main.go
index 5ace4c5..15372d8 100644
--- a/cmd/gt/main.go
+++ b/cmd/gt/main.go
@@ -39,7 +39,7 @@
//
// The package uses a layered architecture:
// - main.go: Entry point and command routing
-// - calculator/: Handles percentage calculation parsing
+// - perc/: Handles percentage calculation parsing
// - rpn/: Handles RPN expression parsing and evaluation
// - repl/: Provides interactive Read-Eval-Print Loop mode
//
diff --git a/internal/perc/perc.go b/internal/perc/perc.go
index 9921ce4..63a171e 100644
--- a/internal/perc/perc.go
+++ b/internal/perc/perc.go
@@ -120,13 +120,13 @@ func ParseCalculation(input string) (*Calculation, error) {
registry.register(parseXIsWhatPercentOfY)
registry.register(parseXIsYPercentOfWhat)
- calc, ok, err := registry.parse(input)
- if ok {
- return calc, nil
- }
+ calc, _, err := registry.parse(input)
if err != nil {
return nil, err
}
+ if calc != nil {
+ return calc, nil
+ }
return nil, fmt.Errorf("perc: unable to parse input %q. See usage for examples", input)
}
diff --git a/internal/repl/concurrent_test.go b/internal/repl/concurrent_test.go
new file mode 100644
index 0000000..14c82ef
--- /dev/null
+++ b/internal/repl/concurrent_test.go
@@ -0,0 +1,62 @@
+package repl
+
+import (
+ "sync"
+ "testing"
+)
+
+func TestConcurrentExecutor(t *testing.T) {
+ // Test concurrent calls to executor()
+ var wg sync.WaitGroup
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func(id int) {
+ defer wg.Done()
+ executor("20% of 150")
+ }(i)
+ }
+ wg.Wait()
+}
+
+func TestConcurrentRPN(t *testing.T) {
+ // Test concurrent calls to runRPN()
+ var wg sync.WaitGroup
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func(id int) {
+ defer wg.Done()
+ runRPN("3 4 +")
+ }(i)
+ }
+ wg.Wait()
+}
+
+func TestConcurrentRatModeToggle(t *testing.T) {
+ // Test concurrent calls to executor() that change mode
+ var wg sync.WaitGroup
+ for i := 0; i < 10; i++ {
+ wg.Add(1)
+ go func(id int) {
+ defer wg.Done()
+ executor("rat toggle")
+ }(i)
+ }
+ wg.Wait()
+}
+
+func TestConcurrentExecutorAndRPN(t *testing.T) {
+ // Test concurrent calls to executor() and runRPN()
+ var wg sync.WaitGroup
+ for i := 0; i < 5; i++ {
+ wg.Add(2)
+ go func(id int) {
+ defer wg.Done()
+ executor("20% of 150")
+ }(i)
+ go func(id int) {
+ defer wg.Done()
+ runRPN("3 4 +")
+ }(i)
+ }
+ wg.Wait()
+}
diff --git a/internal/repl/handlers.go b/internal/repl/handlers.go
index 2cf62c4..663fd84 100644
--- a/internal/repl/handlers.go
+++ b/internal/repl/handlers.go
@@ -104,7 +104,7 @@ func handleRatCommand(repl *REPL, input string) (string, bool, error) {
}
modeArg := strings.ToLower(args[1])
- rpnState := repl.getRPNState()
+ rpnState := repl.rpnState
switch modeArg {
case "on":
@@ -151,7 +151,7 @@ func (h *RPNHandler) Handle(repl *REPL, input string) (output string, handled bo
if strings.HasPrefix(lowerInput, "rpn ") || strings.HasPrefix(lowerInput, "calc ") {
// Extract the expression after rpn/calc
rest := strings.TrimSpace(strings.TrimPrefix(input, strings.SplitN(input, " ", 2)[0]))
- result, err := repl.runRPN(rest)
+ result, err := repl.rpnState.rpnCalc.ParseAndEvaluate(rest)
if err != nil {
return "", true, err
}
@@ -159,10 +159,10 @@ func (h *RPNHandler) Handle(repl *REPL, input string) (output string, handled bo
}
// Try RPN parsing first (for bare RPN expressions like "3 4 +")
- if state := repl.getRPNState(); state != nil {
+ if state := repl.rpnState; state != nil {
// Check if input looks like RPN (contains spaces or is a single known operator)
if strings.Contains(input, " ") {
- result, err := repl.runRPN(input)
+ result, err := state.rpnCalc.ParseAndEvaluate(input)
if err == nil {
return result, true, nil
}
@@ -206,7 +206,7 @@ func (h *RPNHandler) Handle(repl *REPL, input string) (output string, handled bo
}
// PercentageHandler handles percentage calculation expressions.
-// It uses the calculator.Parse function to evaluate expressions like:
+// It uses the perc.Parse function to evaluate expressions like:
// - "20% of 150"
// - "what is 20% of 150"
// - "30 is what % of 150"
diff --git a/internal/repl/repl.go b/internal/repl/repl.go
index 55fff97..243cf90 100644
--- a/internal/repl/repl.go
+++ b/internal/repl/repl.go
@@ -22,13 +22,25 @@ type RPNState struct {
rpnCalc *rpn.RPN
}
-// rpnState holds the singleton RPN state for REPL operations.
-// It is initialized lazily using sync.Once to ensure thread-safe initialization.
-var rpnState *RPNState
-
-// rpnStateOnce ensures rpnState is initialized exactly once.
-// It's used by getRPNState to guarantee lazy singleton initialization.
-var rpnStateOnce sync.Once
+// executorREPL holds the REPL instance created by the executor function.
+// This is used for backward compatibility with tests that need to access RPN state
+// after calling executor(). It's not part of the main REPL architecture.
+// Thread safety: Use executorREPLOnce for lazy initialization and executorREPLMu for access.
+var executorREPL *REPL
+var executorREPLOnce sync.Once
+var executorREPLMu sync.Mutex
+
+// ResetExecutorREPL resets the executorREPL for clean test isolation.
+// This should be called between tests that use executor() and getRPNState()
+// to ensure each test starts with a fresh RPN state.
+//
+// Note: This function is intended for test use only and should not be used
+// in production code. For production use, create new REPL instances with NewREPL().
+func ResetExecutorREPL() {
+ executorREPLMu.Lock()
+ defer executorREPLMu.Unlock()
+ executorREPL = nil
+}
// REPL manages the interactive command-line interface for the percentage calculator.
// It provides an interactive prompt with history, tab-completion, signal handling,
@@ -39,12 +51,14 @@ var rpnStateOnce sync.Once
// - HistoryManager: manages command history persistence
// - SignalHandler: handles SIGINT (Ctrl+C)
// - commandChain: processes commands via chain of responsibility
+// - rpnState: provides RPN state for calculations
type REPL struct {
ttyChecker *TTYChecker
historyMgr *HistoryManager
signalHandler *SignalHandler
prompt *prompt.Prompt
commandChain CommandHandler
+ rpnState *RPNState
}
// NewREPL creates a new REPL instance with default components.
@@ -54,11 +68,19 @@ type REPL struct {
// The executor function is called for each non-empty input line.
// The completer function provides tab-completion suggestions for the prompt.
func NewREPL(executor func(string), completer func(prompt.Document) []prompt.Suggest) *REPL {
+ // Initialize RPN state via dependency injection
+ vars := rpn.NewVariables()
+ rpnState := &RPNState{
+ vars: vars,
+ rpnCalc: rpn.NewRPN(vars),
+ }
+
repl := &REPL{
ttyChecker: &TTYChecker{},
historyMgr: NewHistoryManager(".gt_history"),
signalHandler: NewSignalHandler(),
commandChain: NewCommandChain(),
+ rpnState: rpnState,
}
// Set up executor - if nil, use default
@@ -226,59 +248,68 @@ func RunREPL() error {
// commands via the chain of responsibility pattern, including percentage
// calculations, RPN expressions, and built-in commands.
func executor(input string) {
- // Create a minimal REPL instance without building a prompt
- r := &REPL{
- ttyChecker: &TTYChecker{},
- historyMgr: NewHistoryManager(".gt_history"),
- signalHandler: NewSignalHandler(),
- commandChain: NewCommandChain(),
- }
- defaultExecutor(r, input)
-}
-
-// getRPNState returns or creates the RPN state using lazy initialization.
-// It's thread-safe using sync.Once to ensure the RPN state is initialized exactly once.
-// The RPN state is shared across all REPL instances.
-//
-// Returns the RPNState instance for performing RPN calculations
-func getRPNState() *RPNState {
- rpnStateOnce.Do(func() {
+ // Initialize executorREPL only once using sync.Once for thread-safe lazy initialization
+ executorREPLOnce.Do(func() {
vars := rpn.NewVariables()
- rpnState = &RPNState{
+ rpnState := &RPNState{
vars: vars,
rpnCalc: rpn.NewRPN(vars),
}
+
+ // Create a minimal REPL instance without building a prompt
+ executorREPL = &REPL{
+ ttyChecker: &TTYChecker{},
+ historyMgr: NewHistoryManager(".gt_history"),
+ signalHandler: NewSignalHandler(),
+ commandChain: NewCommandChain(),
+ rpnState: rpnState,
+ }
})
- return rpnState
-}
-// getRPNState returns the RPN state.
-// This is a REPL instance method for backward compatibility that delegates to the package-level getRPNState.
-//
-// Returns the RPNState instance for performing RPN calculations
-func (r *REPL) getRPNState() *RPNState {
- return getRPNState()
+ // Use mutex to protect access to executorREPL during execution
+ executorREPLMu.Lock()
+ repl := executorREPL
+ executorREPLMu.Unlock()
+
+ defaultExecutor(repl, input)
}
// runRPN parses and evaluates an RPN (Reverse Polish Notation) expression.
-// It uses the shared RPN state to maintain stack state across multiple calls.
+// This is a package-level wrapper for backward compatibility that delegates to
+// the executor's REPL runRPN method.
//
-// input: the RPN expression to evaluate (e.g., "3 4 +" or "x 5 = x x +")
+// input: the RPN expression to evaluate
// Returns the result string and an error if the expression is invalid
func runRPN(input string) (string, error) {
- state := getRPNState()
- return state.rpnCalc.ParseAndEvaluate(input)
+ executorREPLMu.Lock()
+ defer executorREPLMu.Unlock()
+
+ if executorREPL != nil {
+ return executorREPL.rpnState.rpnCalc.ParseAndEvaluate(input)
+ }
+ return "", fmt.Errorf("no executor REPL available - call executor() first")
}
-// runRPN parses and evaluates an RPN (Reverse Polish Notation) expression.
-// This is a REPL instance method for backward compatibility that delegates to the package-level runRPN.
+// getRPNState returns the RPN state from the executor's REPL.
+// This is a package-level helper for backward compatibility with tests that need
+// to access RPN state after calling executor(). It's not part of the main REPL
+// architecture.
//
-// input: the RPN expression to evaluate
-// Returns the result string and an error if the expression is invalid
-func (r *REPL) runRPN(input string) (string, error) {
- return runRPN(input)
+// Returns the RPNState instance from the last executor() call, or nil if executor() hasn't been called
+func getRPNState() *RPNState {
+ executorREPLMu.Lock()
+ defer executorREPLMu.Unlock()
+
+ if executorREPL != nil {
+ return executorREPL.rpnState
+ }
+ return nil
}
+
+
+
+
// getHistoryPath returns the absolute path to the history file.
// This is a package-level wrapper for backward compatibility.
// The history file is stored in the user's home directory.
diff --git a/internal/repl/repl_test.go b/internal/repl/repl_test.go
index 2200b2a..cd062cb 100644
--- a/internal/repl/repl_test.go
+++ b/internal/repl/repl_test.go
@@ -470,11 +470,17 @@ func TestRPNHandlerWithPercentageExpression(t *testing.T) {
func TestRPNHandlerWithRPNExpression(t *testing.T) {
// Test RPN expressions
chain := NewCommandChain()
+ vars := rpn.NewVariables()
+ rpnState := &RPNState{
+ vars: vars,
+ rpnCalc: rpn.NewRPN(vars),
+ }
r := &REPL{
ttyChecker: &TTYChecker{},
historyMgr: NewHistoryManager(".gt_history"),
signalHandler: NewSignalHandler(),
commandChain: chain,
+ rpnState: rpnState,
}
// Test RPN expression
@@ -490,11 +496,17 @@ func TestRPNHandlerWithRPNExpression(t *testing.T) {
func TestRPNHandlerWithSingleNumber(t *testing.T) {
// Test single number input (RPN - pushes number onto stack)
chain := NewCommandChain()
+ vars := rpn.NewVariables()
+ rpnState := &RPNState{
+ vars: vars,
+ rpnCalc: rpn.NewRPN(vars),
+ }
r := &REPL{
ttyChecker: &TTYChecker{},
historyMgr: NewHistoryManager(".gt_history"),
signalHandler: NewSignalHandler(),
commandChain: chain,
+ rpnState: rpnState,
}
// Test single number
diff --git a/internal/rpn/operations.go b/internal/rpn/operations.go
index 1300d99..81ea78e 100644
--- a/internal/rpn/operations.go
+++ b/internal/rpn/operations.go
@@ -6,6 +6,7 @@ package rpn
import (
"fmt"
"math"
+ "sync"
)
// ArithmeticOperator defines the interface for basic arithmetic operators.
@@ -56,6 +57,7 @@ type StackOperator interface {
type VariableOperator interface {
ListVariables() (string, error)
ClearVariables()
+ AssignVariableFromStack(stack *Stack) error
}
// Operator is the combined interface for all operator implementations.
@@ -74,8 +76,14 @@ type Operator interface {
type Operations struct {
vars VariableStore
mode CalculationMode
+ mu sync.RWMutex
}
+// Ensure Operations implements Operator at compile time.
+// This is an explicit interface satisfaction check that will fail to compile
+// if Operations doesn't implement all methods required by the Operator interface.
+var _ Operator = (*Operations)(nil)
+
// NewOperations creates a new Operations instance with the given variable store.
func NewOperations(vars VariableStore) *Operations {
return &Operations{
@@ -85,10 +93,21 @@ func NewOperations(vars VariableStore) *Operations {
}
// SetMode sets the calculation mode for the Operations instance.
+// This method is thread-safe for writes.
func (o *Operations) SetMode(mode CalculationMode) {
+ o.mu.Lock()
+ defer o.mu.Unlock()
o.mode = mode
}
+// GetMode returns the current calculation mode.
+// This method is thread-safe for reads.
+func (o *Operations) GetMode() CalculationMode {
+ o.mu.RLock()
+ defer o.mu.RUnlock()
+ return o.mode
+}
+
// OperatorHandler represents a function that handles an operator.
// Returns (result string, handled bool, error error).
// result is non-empty only for commands that return immediately (like show, vars).
@@ -129,6 +148,7 @@ func NewOperatorRegistry(op Operator) *OperatorRegistry {
registry.registerStandardOperator("==", func(stack *Stack) error { return op.EQ(stack) })
registry.registerStandardOperator("neq", func(stack *Stack) error { return op.NEQ(stack) })
registry.registerStandardOperator("!=", func(stack *Stack) error { return op.NEQ(stack) })
+ registry.registerStandardOperator("=", func(stack *Stack) error { return op.AssignVariableFromStack(stack) })
registry.registerStandardOperator("dup", func(stack *Stack) error { return op.Dup(stack) })
registry.registerStandardOperator("swap", func(stack *Stack) error { return op.Swap(stack) })
registry.registerStandardOperator("pop", func(stack *Stack) error { return op.Pop(stack) })
@@ -348,7 +368,8 @@ func (o *Operations) Log2(stack *Stack) error {
}
// Compute log2 using the number interface
- stack.Push(NewNumber(math.Log2(val), o.mode))
+ mode := o.GetMode()
+ stack.Push(NewNumber(math.Log2(val), mode))
return nil
}
@@ -367,7 +388,8 @@ func (o *Operations) Log10(stack *Stack) error {
}
// Compute log10 using the number interface
- stack.Push(NewNumber(math.Log10(val), o.mode))
+ mode := o.GetMode()
+ stack.Push(NewNumber(math.Log10(val), mode))
return nil
}
@@ -386,7 +408,8 @@ func (o *Operations) Ln(stack *Stack) error {
}
// Compute ln using the number interface
- stack.Push(NewNumber(math.Log(val), o.mode))
+ mode := o.GetMode()
+ stack.Push(NewNumber(math.Log(val), mode))
return nil
}
@@ -418,7 +441,8 @@ func (o *Operations) HyperAdd(stack *Stack) error {
for i := 0; i < len(values); i++ {
sum += values[i].Float64()
}
- stack.Push(NewNumber(sum, o.mode))
+ mode := o.GetMode()
+ stack.Push(NewNumber(sum, mode))
return nil
}
@@ -436,7 +460,8 @@ func (o *Operations) HyperMultiply(stack *Stack) error {
}
product *= val.Float64()
}
- stack.Push(NewNumber(product, o.mode))
+ mode := o.GetMode()
+ stack.Push(NewNumber(product, mode))
return nil
}
@@ -466,7 +491,8 @@ func (o *Operations) HyperSubtract(stack *Stack) error {
for i := 1; i < len(values); i++ {
result -= values[i].Float64()
}
- stack.Push(NewNumber(result, o.mode))
+ mode := o.GetMode()
+ stack.Push(NewNumber(result, mode))
return nil
}
@@ -500,7 +526,8 @@ func (o *Operations) HyperDivide(stack *Stack) error {
}
result /= val
}
- stack.Push(NewNumber(result, o.mode))
+ mode := o.GetMode()
+ stack.Push(NewNumber(result, mode))
return nil
}
@@ -530,7 +557,8 @@ func (o *Operations) HyperPower(stack *Stack) error {
for i := 1; i < len(values); i++ {
result = math.Pow(result, values[i].Float64())
}
- stack.Push(NewNumber(result, o.mode))
+ mode := o.GetMode()
+ stack.Push(NewNumber(result, mode))
return nil
}
@@ -564,7 +592,8 @@ func (o *Operations) HyperModulo(stack *Stack) error {
}
result = math.Mod(result, val)
}
- stack.Push(NewNumber(result, o.mode))
+ mode := o.GetMode()
+ stack.Push(NewNumber(result, mode))
return nil
}
@@ -602,7 +631,8 @@ func (o *Operations) HyperLog2(stack *Stack) error {
}
// Push the result as a Number
- stack.Push(NewNumber(result, o.mode))
+ mode := o.GetMode()
+ stack.Push(NewNumber(result, mode))
return nil
}
@@ -640,7 +670,8 @@ func (o *Operations) HyperLog10(stack *Stack) error {
}
// Push the result as a Number
- stack.Push(NewNumber(result, o.mode))
+ mode := o.GetMode()
+ stack.Push(NewNumber(result, mode))
return nil
}
@@ -676,7 +707,8 @@ func (o *Operations) HyperLn(stack *Stack) error {
}
result += math.Log(val)
}
- stack.Push(NewNumber(result, o.mode))
+ mode := o.GetMode()
+ stack.Push(NewNumber(result, mode))
return nil
}
@@ -878,7 +910,8 @@ func (o *Operations) UseVariable(stack *Stack, name string) error {
return fmt.Errorf("%w: %s", ErrVariableNotFound, name)
}
- stack.Push(NewNumber(val, o.mode))
+ mode := o.GetMode()
+ stack.Push(NewNumber(val, mode))
return nil
}
@@ -907,3 +940,32 @@ func (o *Operations) ListVariables() (string, error) {
func (o *Operations) ClearVariables() {
o.vars.ClearVariables()
}
+
+// AssignVariableFromStack assigns a value from the stack to a variable.
+// It pops the variable name from the stack first, then pops the value.
+// Usage: `name value =` or `x value =` (where x is on stack as a string)
+func (o *Operations) AssignVariableFromStack(stack *Stack) error {
+ if stack.Len() < 1 {
+ return fmt.Errorf("insufficient operands for assignment: need variable name")
+ }
+
+ nameVal, err := stack.Pop()
+ if err != nil {
+ return err
+ }
+
+ // Get the variable name from the popped value
+ name := nameVal.String()
+
+ if stack.Len() < 1 {
+ return fmt.Errorf("insufficient operands for assignment: need value")
+ }
+
+ val, err := stack.Pop()
+ if err != nil {
+ return err
+ }
+
+ // Convert to float64 for variable storage
+ return o.vars.SetVariable(name, val.Float64())
+}
diff --git a/internal/rpn/rpn_ops.go b/internal/rpn/rpn_ops.go
index 52a03fc..0c46576 100644
--- a/internal/rpn/rpn_ops.go
+++ b/internal/rpn/rpn_ops.go
@@ -19,6 +19,10 @@ func Tokenize(input string) []string {
// ResultStack returns the final stack state after evaluation.
// This is useful for commands that need to show the stack without consuming it.
func (r *RPN) ResultStack(tokens []string) (string, error) {
+ r.mu.RLock()
+ mode := r.mode
+ r.mu.RUnlock()
+
stack := NewStack()
for _, token := range tokens {
@@ -37,7 +41,7 @@ func (r *RPN) ResultStack(tokens []string) (string, error) {
if stack.Len() >= r.maxStack {
return "", fmt.Errorf("stack overflow")
}
- stack.Push(NewNumber(num, r.mode))
+ stack.Push(NewNumber(num, mode))
continue
}
@@ -58,7 +62,7 @@ func (r *RPN) ResultStack(tokens []string) (string, error) {
// Check if it's a variable reference (push its value)
val, exists := r.vars.GetVariable(token)
if exists {
- stack.Push(NewNumber(val, r.mode))
+ stack.Push(NewNumber(val, mode))
} else {
return "", fmt.Errorf("unknown token '%s'", token)
}
diff --git a/internal/rpn/rpn_parse.go b/internal/rpn/rpn_parse.go
index 4686122..c755150 100644
--- a/internal/rpn/rpn_parse.go
+++ b/internal/rpn/rpn_parse.go
@@ -11,15 +11,20 @@ import (
// ParseAndEvaluate parses and evaluates an RPN expression.
// Returns the result as a formatted string or an error.
+// This method is thread-safe for concurrent execution.
func (r *RPN) ParseAndEvaluate(input string) (string, error) {
// Validate input and initialize
input = strings.TrimSpace(input)
if input == "" {
return "", fmt.Errorf("rpn: empty expression")
}
+
+ // Lock for write operations on currentStack
+ r.mu.Lock()
if r.currentStack == nil {
r.currentStack = NewStack()
}
+ r.mu.Unlock()
// Handle assignment formats
if assignmentResult, isAssignment, err := r.handleAssignment(input); err != nil {
@@ -38,7 +43,11 @@ func (r *RPN) ParseAndEvaluate(input string) (string, error) {
}
// evaluate evaluates a list of tokens and returns the result.
+// This method is thread-safe for concurrent execution.
func (r *RPN) evaluate(tokens []string) (string, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
// Use the current stack for evaluation to preserve state
// This allows incremental operations in REPL mode
if r.currentStack == nil {
diff --git a/internal/rpn/rpn_state.go b/internal/rpn/rpn_state.go
index 684f758..93f0ff4 100644
--- a/internal/rpn/rpn_state.go
+++ b/internal/rpn/rpn_state.go
@@ -3,8 +3,15 @@
package rpn
+import (
+ "sync"
+)
+
// RPN represents the RPN parser and evaluator with state management.
+// It is thread-safe for concurrent read operations, but write operations
+// on the stack or mode should be synchronized externally or use the provided methods.
type RPN struct {
+ mu sync.RWMutex
vars VariableStore
ops Operator
opRegistry *OperatorRegistry
@@ -28,19 +35,28 @@ func NewRPN(vars VariableStore) *RPN {
}
// GetMode returns the current calculation mode.
+// This method is thread-safe for concurrent reads.
func (r *RPN) GetMode() CalculationMode {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
return r.mode
}
// SetMode sets the calculation mode.
+// This method is thread-safe for writes.
func (r *RPN) SetMode(mode CalculationMode) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
r.mode = mode
r.ops.SetMode(mode)
}
// GetCurrentStack returns a copy of the current stack for inspection.
// Returns []Number to preserve value types (numbers and booleans).
+// This method is thread-safe for concurrent reads.
func (r *RPN) GetCurrentStack() []Number {
+ r.mu.RLock()
+ defer r.mu.RUnlock()
if r.currentStack == nil {
return nil
}
@@ -49,7 +65,10 @@ func (r *RPN) GetCurrentStack() []Number {
// SetCurrentStack sets the current stack from a slice of numbers.
// This is useful for restoring stack state.
+// This method is thread-safe for writes.
func (r *RPN) SetCurrentStack(values []Number) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
r.currentStack = NewStack()
for _, v := range values {
r.currentStack.Push(v)
diff --git a/internal/rpn/rpn_test.go b/internal/rpn/rpn_test.go
index 7bd1d9a..76f6dbb 100644
--- a/internal/rpn/rpn_test.go
+++ b/internal/rpn/rpn_test.go
@@ -6,6 +6,7 @@ package rpn
import (
"fmt"
"strings"
+ "sync"
"testing"
)
@@ -591,9 +592,9 @@ func TestResultStackErrors(t *testing.T) {
expectedError: "insufficient operands",
},
{
- name: "invalid assignment syntax in ResultStack",
+ name: "insufficient operands for =",
input: []string{"="},
- expectedError: "unknown token '='",
+ expectedError: "insufficient operands for assignment",
},
}
@@ -1227,3 +1228,99 @@ func TestRPNStackPreservation(t *testing.T) {
t.Errorf("Stack should have 2 values, got %d", len(stack))
}
}
+
+// TestRPNModeThreadSafety verifies that mode changes are thread-safe
+func TestRPNModeThreadSafety(t *testing.T) {
+ r := NewRPN(NewVariables())
+
+ // Run multiple goroutines that change mode and perform operations
+ done := make(chan bool, 100)
+ for i := 0; i < 100; i++ {
+ go func() {
+ // Toggle mode
+ r.SetMode(FloatMode)
+ r.SetMode(RationalMode)
+
+ // Perform an evaluation while mode might be changing
+ _, _ = r.ParseAndEvaluate("1 2 +")
+ done <- true
+ }()
+ }
+
+ // Wait for all goroutines to complete
+ for i := 0; i < 100; i++ {
+ <-done
+ }
+}
+
+// TestRPNModeDirectAccess verifies direct mode access doesn't have race conditions
+func TestRPNModeDirectAccess(t *testing.T) {
+ r := NewRPN(NewVariables())
+
+ var wg sync.WaitGroup
+ iterations := 100
+
+ // Goroutine 1: Direct mode reads (simulating evaluate, ResultStack, EvalOperator)
+ // This simulates what happens in evaluate() - reading mode while holding lock
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for i := 0; i < iterations; i++ {
+ // Simulate evaluate() - acquire lock, read mode, then release lock
+ r.mu.RLock()
+ _ = r.mode
+ r.mu.RUnlock()
+ _, _ = r.ParseAndEvaluate("1 2 +")
+ }
+ }()
+
+ // Goroutine 2: SetMode (simulating handleRatCommand)
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for i := 0; i < iterations; i++ {
+ r.SetMode(FloatMode)
+ r.SetMode(RationalMode)
+ }
+ }()
+
+ // Goroutine 3: GetMode (mutex-protected)
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for i := 0; i < iterations; i++ {
+ _ = r.GetMode()
+ }
+ }()
+
+ wg.Wait()
+}
+
+// TestRPNConcurrentModeAndEval tests concurrent mode changes and evaluations
+func TestRPNConcurrentModeAndEval(t *testing.T) {
+ r := NewRPN(NewVariables())
+
+ var wg sync.WaitGroup
+ iterations := 50
+
+ // Goroutine 1: Change mode
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for i := 0; i < iterations; i++ {
+ r.SetMode(FloatMode)
+ r.SetMode(RationalMode)
+ }
+ }()
+
+ // Goroutine 2: Evaluate expressions
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ for i := 0; i < iterations; i++ {
+ _, _ = r.ParseAndEvaluate("1 2 +")
+ }
+ }()
+
+ wg.Wait()
+}