diff options
| author | Paul Buetow <paul@buetow.org> | 2026-03-25 21:58:36 +0200 |
|---|---|---|
| committer | Paul Buetow <paul@buetow.org> | 2026-03-25 21:58:36 +0200 |
| commit | 9ee6293b175710e7cf1ff5bc9a6dc555ddaf559f (patch) | |
| tree | 34da64728c21c195147969eda153774db6455d62 | |
| parent | 25b2728a6499b427e201fb80d8a70e65909645e0 (diff) | |
code-quality: Various improvements to code quality and thread safety
| -rw-r--r-- | FIXES.md | 10 | ||||
| -rw-r--r-- | cmd/gt/main.go | 2 | ||||
| -rw-r--r-- | internal/perc/perc.go | 8 | ||||
| -rw-r--r-- | internal/repl/concurrent_test.go | 62 | ||||
| -rw-r--r-- | internal/repl/handlers.go | 10 | ||||
| -rw-r--r-- | internal/repl/repl.go | 117 | ||||
| -rw-r--r-- | internal/repl/repl_test.go | 12 | ||||
| -rw-r--r-- | internal/rpn/operations.go | 88 | ||||
| -rw-r--r-- | internal/rpn/rpn_ops.go | 8 | ||||
| -rw-r--r-- | internal/rpn/rpn_parse.go | 9 | ||||
| -rw-r--r-- | internal/rpn/rpn_state.go | 19 | ||||
| -rw-r--r-- | internal/rpn/rpn_test.go | 101 |
12 files changed, 371 insertions, 75 deletions
@@ -216,9 +216,9 @@ func getRPNState() *RPNState { --- -## Fix #5: Improved Error Context in Calculator +## Fix #5: Improved Error Context in Percentage Calculator -**File:** `internal/calculator/calculator.go` +**File:** `internal/perc/perc.go` **Issue:** Errors not wrapped with context **Location:** Various places @@ -226,7 +226,7 @@ func getRPNState() *RPNState { ```go func Parse(input string) (string, error) { // ... - result, err := calculator.Parse(input) + result, err := perc.Parse(input) if err != nil { return "", err // Missing context } @@ -238,7 +238,7 @@ func Parse(input string) (string, error) { ```go func Parse(input string) (string, error) { // ... - result, err := calculator.Parse(input) + result, err := perc.Parse(input) if err != nil { return "", fmt.Errorf("rpn fallback failed for input %q: %w", input, err) } @@ -371,7 +371,7 @@ golangci-lint run | #1 | `rpn.go` | Error wrapping | Better debugging | | #3 | `repl.go` | Proper resource cleanup | No resource leaks | | #4 | `repl.go` | Mutex safety | Thread safety | -| #5 | `calculator.go` | Error context | Better error messages | +| #5 | `perc.go` | Error context | Better error messages | | #6 | `variables.go` | Slice handling | Memory management | | #7 | `variables.go` | Performance optimization | Reduced allocations | diff --git a/cmd/gt/main.go b/cmd/gt/main.go index 5ace4c5..15372d8 100644 --- a/cmd/gt/main.go +++ b/cmd/gt/main.go @@ -39,7 +39,7 @@ // // The package uses a layered architecture: // - main.go: Entry point and command routing -// - calculator/: Handles percentage calculation parsing +// - perc/: Handles percentage calculation parsing // - rpn/: Handles RPN expression parsing and evaluation // - repl/: Provides interactive Read-Eval-Print Loop mode // diff --git a/internal/perc/perc.go b/internal/perc/perc.go index 9921ce4..63a171e 100644 --- a/internal/perc/perc.go +++ b/internal/perc/perc.go @@ -120,13 +120,13 @@ func ParseCalculation(input string) (*Calculation, error) { registry.register(parseXIsWhatPercentOfY) registry.register(parseXIsYPercentOfWhat) - calc, ok, err := registry.parse(input) - if ok { - return calc, nil - } + calc, _, err := registry.parse(input) if err != nil { return nil, err } + if calc != nil { + return calc, nil + } return nil, fmt.Errorf("perc: unable to parse input %q. See usage for examples", input) } diff --git a/internal/repl/concurrent_test.go b/internal/repl/concurrent_test.go new file mode 100644 index 0000000..14c82ef --- /dev/null +++ b/internal/repl/concurrent_test.go @@ -0,0 +1,62 @@ +package repl + +import ( + "sync" + "testing" +) + +func TestConcurrentExecutor(t *testing.T) { + // Test concurrent calls to executor() + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + executor("20% of 150") + }(i) + } + wg.Wait() +} + +func TestConcurrentRPN(t *testing.T) { + // Test concurrent calls to runRPN() + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + runRPN("3 4 +") + }(i) + } + wg.Wait() +} + +func TestConcurrentRatModeToggle(t *testing.T) { + // Test concurrent calls to executor() that change mode + var wg sync.WaitGroup + for i := 0; i < 10; i++ { + wg.Add(1) + go func(id int) { + defer wg.Done() + executor("rat toggle") + }(i) + } + wg.Wait() +} + +func TestConcurrentExecutorAndRPN(t *testing.T) { + // Test concurrent calls to executor() and runRPN() + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(2) + go func(id int) { + defer wg.Done() + executor("20% of 150") + }(i) + go func(id int) { + defer wg.Done() + runRPN("3 4 +") + }(i) + } + wg.Wait() +} diff --git a/internal/repl/handlers.go b/internal/repl/handlers.go index 2cf62c4..663fd84 100644 --- a/internal/repl/handlers.go +++ b/internal/repl/handlers.go @@ -104,7 +104,7 @@ func handleRatCommand(repl *REPL, input string) (string, bool, error) { } modeArg := strings.ToLower(args[1]) - rpnState := repl.getRPNState() + rpnState := repl.rpnState switch modeArg { case "on": @@ -151,7 +151,7 @@ func (h *RPNHandler) Handle(repl *REPL, input string) (output string, handled bo if strings.HasPrefix(lowerInput, "rpn ") || strings.HasPrefix(lowerInput, "calc ") { // Extract the expression after rpn/calc rest := strings.TrimSpace(strings.TrimPrefix(input, strings.SplitN(input, " ", 2)[0])) - result, err := repl.runRPN(rest) + result, err := repl.rpnState.rpnCalc.ParseAndEvaluate(rest) if err != nil { return "", true, err } @@ -159,10 +159,10 @@ func (h *RPNHandler) Handle(repl *REPL, input string) (output string, handled bo } // Try RPN parsing first (for bare RPN expressions like "3 4 +") - if state := repl.getRPNState(); state != nil { + if state := repl.rpnState; state != nil { // Check if input looks like RPN (contains spaces or is a single known operator) if strings.Contains(input, " ") { - result, err := repl.runRPN(input) + result, err := state.rpnCalc.ParseAndEvaluate(input) if err == nil { return result, true, nil } @@ -206,7 +206,7 @@ func (h *RPNHandler) Handle(repl *REPL, input string) (output string, handled bo } // PercentageHandler handles percentage calculation expressions. -// It uses the calculator.Parse function to evaluate expressions like: +// It uses the perc.Parse function to evaluate expressions like: // - "20% of 150" // - "what is 20% of 150" // - "30 is what % of 150" diff --git a/internal/repl/repl.go b/internal/repl/repl.go index 55fff97..243cf90 100644 --- a/internal/repl/repl.go +++ b/internal/repl/repl.go @@ -22,13 +22,25 @@ type RPNState struct { rpnCalc *rpn.RPN } -// rpnState holds the singleton RPN state for REPL operations. -// It is initialized lazily using sync.Once to ensure thread-safe initialization. -var rpnState *RPNState - -// rpnStateOnce ensures rpnState is initialized exactly once. -// It's used by getRPNState to guarantee lazy singleton initialization. -var rpnStateOnce sync.Once +// executorREPL holds the REPL instance created by the executor function. +// This is used for backward compatibility with tests that need to access RPN state +// after calling executor(). It's not part of the main REPL architecture. +// Thread safety: Use executorREPLOnce for lazy initialization and executorREPLMu for access. +var executorREPL *REPL +var executorREPLOnce sync.Once +var executorREPLMu sync.Mutex + +// ResetExecutorREPL resets the executorREPL for clean test isolation. +// This should be called between tests that use executor() and getRPNState() +// to ensure each test starts with a fresh RPN state. +// +// Note: This function is intended for test use only and should not be used +// in production code. For production use, create new REPL instances with NewREPL(). +func ResetExecutorREPL() { + executorREPLMu.Lock() + defer executorREPLMu.Unlock() + executorREPL = nil +} // REPL manages the interactive command-line interface for the percentage calculator. // It provides an interactive prompt with history, tab-completion, signal handling, @@ -39,12 +51,14 @@ var rpnStateOnce sync.Once // - HistoryManager: manages command history persistence // - SignalHandler: handles SIGINT (Ctrl+C) // - commandChain: processes commands via chain of responsibility +// - rpnState: provides RPN state for calculations type REPL struct { ttyChecker *TTYChecker historyMgr *HistoryManager signalHandler *SignalHandler prompt *prompt.Prompt commandChain CommandHandler + rpnState *RPNState } // NewREPL creates a new REPL instance with default components. @@ -54,11 +68,19 @@ type REPL struct { // The executor function is called for each non-empty input line. // The completer function provides tab-completion suggestions for the prompt. func NewREPL(executor func(string), completer func(prompt.Document) []prompt.Suggest) *REPL { + // Initialize RPN state via dependency injection + vars := rpn.NewVariables() + rpnState := &RPNState{ + vars: vars, + rpnCalc: rpn.NewRPN(vars), + } + repl := &REPL{ ttyChecker: &TTYChecker{}, historyMgr: NewHistoryManager(".gt_history"), signalHandler: NewSignalHandler(), commandChain: NewCommandChain(), + rpnState: rpnState, } // Set up executor - if nil, use default @@ -226,59 +248,68 @@ func RunREPL() error { // commands via the chain of responsibility pattern, including percentage // calculations, RPN expressions, and built-in commands. func executor(input string) { - // Create a minimal REPL instance without building a prompt - r := &REPL{ - ttyChecker: &TTYChecker{}, - historyMgr: NewHistoryManager(".gt_history"), - signalHandler: NewSignalHandler(), - commandChain: NewCommandChain(), - } - defaultExecutor(r, input) -} - -// getRPNState returns or creates the RPN state using lazy initialization. -// It's thread-safe using sync.Once to ensure the RPN state is initialized exactly once. -// The RPN state is shared across all REPL instances. -// -// Returns the RPNState instance for performing RPN calculations -func getRPNState() *RPNState { - rpnStateOnce.Do(func() { + // Initialize executorREPL only once using sync.Once for thread-safe lazy initialization + executorREPLOnce.Do(func() { vars := rpn.NewVariables() - rpnState = &RPNState{ + rpnState := &RPNState{ vars: vars, rpnCalc: rpn.NewRPN(vars), } + + // Create a minimal REPL instance without building a prompt + executorREPL = &REPL{ + ttyChecker: &TTYChecker{}, + historyMgr: NewHistoryManager(".gt_history"), + signalHandler: NewSignalHandler(), + commandChain: NewCommandChain(), + rpnState: rpnState, + } }) - return rpnState -} -// getRPNState returns the RPN state. -// This is a REPL instance method for backward compatibility that delegates to the package-level getRPNState. -// -// Returns the RPNState instance for performing RPN calculations -func (r *REPL) getRPNState() *RPNState { - return getRPNState() + // Use mutex to protect access to executorREPL during execution + executorREPLMu.Lock() + repl := executorREPL + executorREPLMu.Unlock() + + defaultExecutor(repl, input) } // runRPN parses and evaluates an RPN (Reverse Polish Notation) expression. -// It uses the shared RPN state to maintain stack state across multiple calls. +// This is a package-level wrapper for backward compatibility that delegates to +// the executor's REPL runRPN method. // -// input: the RPN expression to evaluate (e.g., "3 4 +" or "x 5 = x x +") +// input: the RPN expression to evaluate // Returns the result string and an error if the expression is invalid func runRPN(input string) (string, error) { - state := getRPNState() - return state.rpnCalc.ParseAndEvaluate(input) + executorREPLMu.Lock() + defer executorREPLMu.Unlock() + + if executorREPL != nil { + return executorREPL.rpnState.rpnCalc.ParseAndEvaluate(input) + } + return "", fmt.Errorf("no executor REPL available - call executor() first") } -// runRPN parses and evaluates an RPN (Reverse Polish Notation) expression. -// This is a REPL instance method for backward compatibility that delegates to the package-level runRPN. +// getRPNState returns the RPN state from the executor's REPL. +// This is a package-level helper for backward compatibility with tests that need +// to access RPN state after calling executor(). It's not part of the main REPL +// architecture. // -// input: the RPN expression to evaluate -// Returns the result string and an error if the expression is invalid -func (r *REPL) runRPN(input string) (string, error) { - return runRPN(input) +// Returns the RPNState instance from the last executor() call, or nil if executor() hasn't been called +func getRPNState() *RPNState { + executorREPLMu.Lock() + defer executorREPLMu.Unlock() + + if executorREPL != nil { + return executorREPL.rpnState + } + return nil } + + + + // getHistoryPath returns the absolute path to the history file. // This is a package-level wrapper for backward compatibility. // The history file is stored in the user's home directory. diff --git a/internal/repl/repl_test.go b/internal/repl/repl_test.go index 2200b2a..cd062cb 100644 --- a/internal/repl/repl_test.go +++ b/internal/repl/repl_test.go @@ -470,11 +470,17 @@ func TestRPNHandlerWithPercentageExpression(t *testing.T) { func TestRPNHandlerWithRPNExpression(t *testing.T) { // Test RPN expressions chain := NewCommandChain() + vars := rpn.NewVariables() + rpnState := &RPNState{ + vars: vars, + rpnCalc: rpn.NewRPN(vars), + } r := &REPL{ ttyChecker: &TTYChecker{}, historyMgr: NewHistoryManager(".gt_history"), signalHandler: NewSignalHandler(), commandChain: chain, + rpnState: rpnState, } // Test RPN expression @@ -490,11 +496,17 @@ func TestRPNHandlerWithRPNExpression(t *testing.T) { func TestRPNHandlerWithSingleNumber(t *testing.T) { // Test single number input (RPN - pushes number onto stack) chain := NewCommandChain() + vars := rpn.NewVariables() + rpnState := &RPNState{ + vars: vars, + rpnCalc: rpn.NewRPN(vars), + } r := &REPL{ ttyChecker: &TTYChecker{}, historyMgr: NewHistoryManager(".gt_history"), signalHandler: NewSignalHandler(), commandChain: chain, + rpnState: rpnState, } // Test single number diff --git a/internal/rpn/operations.go b/internal/rpn/operations.go index 1300d99..81ea78e 100644 --- a/internal/rpn/operations.go +++ b/internal/rpn/operations.go @@ -6,6 +6,7 @@ package rpn import ( "fmt" "math" + "sync" ) // ArithmeticOperator defines the interface for basic arithmetic operators. @@ -56,6 +57,7 @@ type StackOperator interface { type VariableOperator interface { ListVariables() (string, error) ClearVariables() + AssignVariableFromStack(stack *Stack) error } // Operator is the combined interface for all operator implementations. @@ -74,8 +76,14 @@ type Operator interface { type Operations struct { vars VariableStore mode CalculationMode + mu sync.RWMutex } +// Ensure Operations implements Operator at compile time. +// This is an explicit interface satisfaction check that will fail to compile +// if Operations doesn't implement all methods required by the Operator interface. +var _ Operator = (*Operations)(nil) + // NewOperations creates a new Operations instance with the given variable store. func NewOperations(vars VariableStore) *Operations { return &Operations{ @@ -85,10 +93,21 @@ func NewOperations(vars VariableStore) *Operations { } // SetMode sets the calculation mode for the Operations instance. +// This method is thread-safe for writes. func (o *Operations) SetMode(mode CalculationMode) { + o.mu.Lock() + defer o.mu.Unlock() o.mode = mode } +// GetMode returns the current calculation mode. +// This method is thread-safe for reads. +func (o *Operations) GetMode() CalculationMode { + o.mu.RLock() + defer o.mu.RUnlock() + return o.mode +} + // OperatorHandler represents a function that handles an operator. // Returns (result string, handled bool, error error). // result is non-empty only for commands that return immediately (like show, vars). @@ -129,6 +148,7 @@ func NewOperatorRegistry(op Operator) *OperatorRegistry { registry.registerStandardOperator("==", func(stack *Stack) error { return op.EQ(stack) }) registry.registerStandardOperator("neq", func(stack *Stack) error { return op.NEQ(stack) }) registry.registerStandardOperator("!=", func(stack *Stack) error { return op.NEQ(stack) }) + registry.registerStandardOperator("=", func(stack *Stack) error { return op.AssignVariableFromStack(stack) }) registry.registerStandardOperator("dup", func(stack *Stack) error { return op.Dup(stack) }) registry.registerStandardOperator("swap", func(stack *Stack) error { return op.Swap(stack) }) registry.registerStandardOperator("pop", func(stack *Stack) error { return op.Pop(stack) }) @@ -348,7 +368,8 @@ func (o *Operations) Log2(stack *Stack) error { } // Compute log2 using the number interface - stack.Push(NewNumber(math.Log2(val), o.mode)) + mode := o.GetMode() + stack.Push(NewNumber(math.Log2(val), mode)) return nil } @@ -367,7 +388,8 @@ func (o *Operations) Log10(stack *Stack) error { } // Compute log10 using the number interface - stack.Push(NewNumber(math.Log10(val), o.mode)) + mode := o.GetMode() + stack.Push(NewNumber(math.Log10(val), mode)) return nil } @@ -386,7 +408,8 @@ func (o *Operations) Ln(stack *Stack) error { } // Compute ln using the number interface - stack.Push(NewNumber(math.Log(val), o.mode)) + mode := o.GetMode() + stack.Push(NewNumber(math.Log(val), mode)) return nil } @@ -418,7 +441,8 @@ func (o *Operations) HyperAdd(stack *Stack) error { for i := 0; i < len(values); i++ { sum += values[i].Float64() } - stack.Push(NewNumber(sum, o.mode)) + mode := o.GetMode() + stack.Push(NewNumber(sum, mode)) return nil } @@ -436,7 +460,8 @@ func (o *Operations) HyperMultiply(stack *Stack) error { } product *= val.Float64() } - stack.Push(NewNumber(product, o.mode)) + mode := o.GetMode() + stack.Push(NewNumber(product, mode)) return nil } @@ -466,7 +491,8 @@ func (o *Operations) HyperSubtract(stack *Stack) error { for i := 1; i < len(values); i++ { result -= values[i].Float64() } - stack.Push(NewNumber(result, o.mode)) + mode := o.GetMode() + stack.Push(NewNumber(result, mode)) return nil } @@ -500,7 +526,8 @@ func (o *Operations) HyperDivide(stack *Stack) error { } result /= val } - stack.Push(NewNumber(result, o.mode)) + mode := o.GetMode() + stack.Push(NewNumber(result, mode)) return nil } @@ -530,7 +557,8 @@ func (o *Operations) HyperPower(stack *Stack) error { for i := 1; i < len(values); i++ { result = math.Pow(result, values[i].Float64()) } - stack.Push(NewNumber(result, o.mode)) + mode := o.GetMode() + stack.Push(NewNumber(result, mode)) return nil } @@ -564,7 +592,8 @@ func (o *Operations) HyperModulo(stack *Stack) error { } result = math.Mod(result, val) } - stack.Push(NewNumber(result, o.mode)) + mode := o.GetMode() + stack.Push(NewNumber(result, mode)) return nil } @@ -602,7 +631,8 @@ func (o *Operations) HyperLog2(stack *Stack) error { } // Push the result as a Number - stack.Push(NewNumber(result, o.mode)) + mode := o.GetMode() + stack.Push(NewNumber(result, mode)) return nil } @@ -640,7 +670,8 @@ func (o *Operations) HyperLog10(stack *Stack) error { } // Push the result as a Number - stack.Push(NewNumber(result, o.mode)) + mode := o.GetMode() + stack.Push(NewNumber(result, mode)) return nil } @@ -676,7 +707,8 @@ func (o *Operations) HyperLn(stack *Stack) error { } result += math.Log(val) } - stack.Push(NewNumber(result, o.mode)) + mode := o.GetMode() + stack.Push(NewNumber(result, mode)) return nil } @@ -878,7 +910,8 @@ func (o *Operations) UseVariable(stack *Stack, name string) error { return fmt.Errorf("%w: %s", ErrVariableNotFound, name) } - stack.Push(NewNumber(val, o.mode)) + mode := o.GetMode() + stack.Push(NewNumber(val, mode)) return nil } @@ -907,3 +940,32 @@ func (o *Operations) ListVariables() (string, error) { func (o *Operations) ClearVariables() { o.vars.ClearVariables() } + +// AssignVariableFromStack assigns a value from the stack to a variable. +// It pops the variable name from the stack first, then pops the value. +// Usage: `name value =` or `x value =` (where x is on stack as a string) +func (o *Operations) AssignVariableFromStack(stack *Stack) error { + if stack.Len() < 1 { + return fmt.Errorf("insufficient operands for assignment: need variable name") + } + + nameVal, err := stack.Pop() + if err != nil { + return err + } + + // Get the variable name from the popped value + name := nameVal.String() + + if stack.Len() < 1 { + return fmt.Errorf("insufficient operands for assignment: need value") + } + + val, err := stack.Pop() + if err != nil { + return err + } + + // Convert to float64 for variable storage + return o.vars.SetVariable(name, val.Float64()) +} diff --git a/internal/rpn/rpn_ops.go b/internal/rpn/rpn_ops.go index 52a03fc..0c46576 100644 --- a/internal/rpn/rpn_ops.go +++ b/internal/rpn/rpn_ops.go @@ -19,6 +19,10 @@ func Tokenize(input string) []string { // ResultStack returns the final stack state after evaluation. // This is useful for commands that need to show the stack without consuming it. func (r *RPN) ResultStack(tokens []string) (string, error) { + r.mu.RLock() + mode := r.mode + r.mu.RUnlock() + stack := NewStack() for _, token := range tokens { @@ -37,7 +41,7 @@ func (r *RPN) ResultStack(tokens []string) (string, error) { if stack.Len() >= r.maxStack { return "", fmt.Errorf("stack overflow") } - stack.Push(NewNumber(num, r.mode)) + stack.Push(NewNumber(num, mode)) continue } @@ -58,7 +62,7 @@ func (r *RPN) ResultStack(tokens []string) (string, error) { // Check if it's a variable reference (push its value) val, exists := r.vars.GetVariable(token) if exists { - stack.Push(NewNumber(val, r.mode)) + stack.Push(NewNumber(val, mode)) } else { return "", fmt.Errorf("unknown token '%s'", token) } diff --git a/internal/rpn/rpn_parse.go b/internal/rpn/rpn_parse.go index 4686122..c755150 100644 --- a/internal/rpn/rpn_parse.go +++ b/internal/rpn/rpn_parse.go @@ -11,15 +11,20 @@ import ( // ParseAndEvaluate parses and evaluates an RPN expression. // Returns the result as a formatted string or an error. +// This method is thread-safe for concurrent execution. func (r *RPN) ParseAndEvaluate(input string) (string, error) { // Validate input and initialize input = strings.TrimSpace(input) if input == "" { return "", fmt.Errorf("rpn: empty expression") } + + // Lock for write operations on currentStack + r.mu.Lock() if r.currentStack == nil { r.currentStack = NewStack() } + r.mu.Unlock() // Handle assignment formats if assignmentResult, isAssignment, err := r.handleAssignment(input); err != nil { @@ -38,7 +43,11 @@ func (r *RPN) ParseAndEvaluate(input string) (string, error) { } // evaluate evaluates a list of tokens and returns the result. +// This method is thread-safe for concurrent execution. func (r *RPN) evaluate(tokens []string) (string, error) { + r.mu.Lock() + defer r.mu.Unlock() + // Use the current stack for evaluation to preserve state // This allows incremental operations in REPL mode if r.currentStack == nil { diff --git a/internal/rpn/rpn_state.go b/internal/rpn/rpn_state.go index 684f758..93f0ff4 100644 --- a/internal/rpn/rpn_state.go +++ b/internal/rpn/rpn_state.go @@ -3,8 +3,15 @@ package rpn +import ( + "sync" +) + // RPN represents the RPN parser and evaluator with state management. +// It is thread-safe for concurrent read operations, but write operations +// on the stack or mode should be synchronized externally or use the provided methods. type RPN struct { + mu sync.RWMutex vars VariableStore ops Operator opRegistry *OperatorRegistry @@ -28,19 +35,28 @@ func NewRPN(vars VariableStore) *RPN { } // GetMode returns the current calculation mode. +// This method is thread-safe for concurrent reads. func (r *RPN) GetMode() CalculationMode { + r.mu.RLock() + defer r.mu.RUnlock() return r.mode } // SetMode sets the calculation mode. +// This method is thread-safe for writes. func (r *RPN) SetMode(mode CalculationMode) { + r.mu.Lock() + defer r.mu.Unlock() r.mode = mode r.ops.SetMode(mode) } // GetCurrentStack returns a copy of the current stack for inspection. // Returns []Number to preserve value types (numbers and booleans). +// This method is thread-safe for concurrent reads. func (r *RPN) GetCurrentStack() []Number { + r.mu.RLock() + defer r.mu.RUnlock() if r.currentStack == nil { return nil } @@ -49,7 +65,10 @@ func (r *RPN) GetCurrentStack() []Number { // SetCurrentStack sets the current stack from a slice of numbers. // This is useful for restoring stack state. +// This method is thread-safe for writes. func (r *RPN) SetCurrentStack(values []Number) { + r.mu.Lock() + defer r.mu.Unlock() r.currentStack = NewStack() for _, v := range values { r.currentStack.Push(v) diff --git a/internal/rpn/rpn_test.go b/internal/rpn/rpn_test.go index 7bd1d9a..76f6dbb 100644 --- a/internal/rpn/rpn_test.go +++ b/internal/rpn/rpn_test.go @@ -6,6 +6,7 @@ package rpn import ( "fmt" "strings" + "sync" "testing" ) @@ -591,9 +592,9 @@ func TestResultStackErrors(t *testing.T) { expectedError: "insufficient operands", }, { - name: "invalid assignment syntax in ResultStack", + name: "insufficient operands for =", input: []string{"="}, - expectedError: "unknown token '='", + expectedError: "insufficient operands for assignment", }, } @@ -1227,3 +1228,99 @@ func TestRPNStackPreservation(t *testing.T) { t.Errorf("Stack should have 2 values, got %d", len(stack)) } } + +// TestRPNModeThreadSafety verifies that mode changes are thread-safe +func TestRPNModeThreadSafety(t *testing.T) { + r := NewRPN(NewVariables()) + + // Run multiple goroutines that change mode and perform operations + done := make(chan bool, 100) + for i := 0; i < 100; i++ { + go func() { + // Toggle mode + r.SetMode(FloatMode) + r.SetMode(RationalMode) + + // Perform an evaluation while mode might be changing + _, _ = r.ParseAndEvaluate("1 2 +") + done <- true + }() + } + + // Wait for all goroutines to complete + for i := 0; i < 100; i++ { + <-done + } +} + +// TestRPNModeDirectAccess verifies direct mode access doesn't have race conditions +func TestRPNModeDirectAccess(t *testing.T) { + r := NewRPN(NewVariables()) + + var wg sync.WaitGroup + iterations := 100 + + // Goroutine 1: Direct mode reads (simulating evaluate, ResultStack, EvalOperator) + // This simulates what happens in evaluate() - reading mode while holding lock + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + // Simulate evaluate() - acquire lock, read mode, then release lock + r.mu.RLock() + _ = r.mode + r.mu.RUnlock() + _, _ = r.ParseAndEvaluate("1 2 +") + } + }() + + // Goroutine 2: SetMode (simulating handleRatCommand) + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + r.SetMode(FloatMode) + r.SetMode(RationalMode) + } + }() + + // Goroutine 3: GetMode (mutex-protected) + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + _ = r.GetMode() + } + }() + + wg.Wait() +} + +// TestRPNConcurrentModeAndEval tests concurrent mode changes and evaluations +func TestRPNConcurrentModeAndEval(t *testing.T) { + r := NewRPN(NewVariables()) + + var wg sync.WaitGroup + iterations := 50 + + // Goroutine 1: Change mode + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + r.SetMode(FloatMode) + r.SetMode(RationalMode) + } + }() + + // Goroutine 2: Evaluate expressions + wg.Add(1) + go func() { + defer wg.Done() + for i := 0; i < iterations; i++ { + _, _ = r.ParseAndEvaluate("1 2 +") + } + }() + + wg.Wait() +} |
