initial commit
1 file changed, 483 insertions(+), 0 deletions(-)
changed files
A state/state_test.go
@@ -0,0 +1,483 @@ +package state + +import ( + "bytes" + "fmt" + "log" + "sync" + "testing" + "time" +) + +// MockCommand implements the command.Command interface for testing +type MockCommand struct { + mu sync.Mutex + startCalled bool + stopCalled bool + waitCalled bool + startErr error + waitErr error + stopErr error + waitChan chan struct{} + started bool +} + +func NewMockCommand() *MockCommand { + return &MockCommand{ + waitChan: make(chan struct{}), + } +} + +func (m *MockCommand) Start() error { + m.mu.Lock() + defer m.mu.Unlock() + m.startCalled = true + m.started = true + + return m.startErr +} + +func (m *MockCommand) Wait() error { + m.mu.Lock() + m.waitCalled = true + m.mu.Unlock() + + // Wait for signal to simulate command completion + <-m.waitChan + + return m.waitErr +} + +func (m *MockCommand) Stop() error { + m.mu.Lock() + defer m.mu.Unlock() + m.stopCalled = true + m.started = false + // Signal Wait to return + select { + case m.waitChan <- struct{}{}: + default: + } + + return m.stopErr +} + +func (m *MockCommand) SimulateExit() { + m.mu.Lock() + m.started = false + m.mu.Unlock() + // Signal Wait to return + select { + case m.waitChan <- struct{}{}: + default: + } +} + +func (m *MockCommand) WasStartCalled() bool { + m.mu.Lock() + defer m.mu.Unlock() + + return m.startCalled +} + +func (m *MockCommand) WasStopCalled() bool { + m.mu.Lock() + defer m.mu.Unlock() + + return m.stopCalled +} + +func (m *MockCommand) WasWaitCalled() bool { + m.mu.Lock() + defer m.mu.Unlock() + + return m.waitCalled +} + +func (m *MockCommand) SetStartError(err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.startErr = err +} + +func (m *MockCommand) SetWaitError(err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.waitErr = err +} + +func (m *MockCommand) SetStopError(err error) { + m.mu.Lock() + defer m.mu.Unlock() + m.stopErr = err +} + +func (m *MockCommand) Reset() { + m.mu.Lock() + defer m.mu.Unlock() + m.startCalled = false + m.stopCalled = false + m.waitCalled = false + m.startErr = nil + m.waitErr = nil + m.stopErr = nil + m.started = false + m.waitChan = make(chan struct{}) +} + +func createTestStateMachine() (*StateMachine, *MockCommand, *bytes.Buffer) { + mockCmd := NewMockCommand() + var buf bytes.Buffer + logger := log.New(&buf, "", 0) + + sm := New(mockCmd, logger) + + return sm, mockCmd, &buf +} + +func TestStateString(t *testing.T) { + t.Parallel() + + tests := []struct { + state State + expected string + }{ + {NotStarted, "NotStarted"}, + {Running, "Running"}, + {Exited, "Exited"}, + {State(99), "Unknown"}, + } + + for _, tt := range tests { + if got := tt.state.String(); got != tt.expected { + t.Errorf("State.String() = %v, want %v", got, tt.expected) + } + } +} + +func TestEventString(t *testing.T) { + t.Parallel() + + tests := []struct { + event Event + expected string + }{ + {Start, "Start"}, + {Signal, "Shutdown"}, + {Exit, "Stopped"}, + {Restart, "Restart"}, + {Event(99), "Unknown"}, + } + + for _, tt := range tests { + if got := tt.event.String(); got != tt.expected { + t.Errorf("Event.String() = %v, want %v", got, tt.expected) + } + } +} + +func TestNewStateMachine(t *testing.T) { + t.Parallel() + + sm, _, _ := createTestStateMachine() + + if sm.currentState != NotStarted { + t.Errorf("Initial state should be NotStarted, got %v", sm.currentState) + } +} + +func TestStartTransition(t *testing.T) { + t.Parallel() + + sm, mockCmd, _ := createTestStateMachine() + + // Send Start event + sm.SendEvent(Start) + + // Give some time for goroutine to start + time.Sleep(10 * time.Millisecond) + + if sm.currentState != Running { + t.Errorf("State should be Running after Start event, got %v", sm.currentState) + } + + if !mockCmd.WasStartCalled() { + t.Error("Start should have been called on command") + } + + if !mockCmd.WasWaitCalled() { + t.Error("Wait should have been called on command") + } +} + +func TestSignalTransition(t *testing.T) { + t.Parallel() + + sm, mockCmd, _ := createTestStateMachine() + + // Start the state machine + sm.SendEvent(Start) + time.Sleep(10 * time.Millisecond) + + // Reset mock to clear start calls + mockCmd.Reset() + + // Send Signal event + sm.SendEvent(Signal) + time.Sleep(10 * time.Millisecond) + + if sm.currentState != Exited { + t.Errorf("State should be Exited after Signal event from Running, got %v", sm.currentState) + } + + if !mockCmd.WasStopCalled() { + t.Error("Stop should have been called on command") + } +} + +func TestExitTransition(t *testing.T) { + t.Parallel() + + sm, mockCmd, _ := createTestStateMachine() + + // Start the state machine + sm.SendEvent(Start) + time.Sleep(10 * time.Millisecond) + + // Simulate command exit + mockCmd.SimulateExit() + time.Sleep(10 * time.Millisecond) + + if sm.currentState != Exited { + t.Errorf("State should be Exited after Exit event, got %v", sm.currentState) + } +} + +func TestRestartFromRunning(t *testing.T) { + t.Parallel() + + sm, mockCmd, _ := createTestStateMachine() + + // Start the state machine + sm.SendEvent(Start) + time.Sleep(10 * time.Millisecond) + + // Reset mock to clear initial calls + mockCmd.Reset() + + // Send Restart event + sm.SendEvent(Restart) + time.Sleep(10 * time.Millisecond) + + if sm.currentState != Running { + t.Errorf("State should still be Running after Restart event, got %v", sm.currentState) + } + + if !mockCmd.WasStopCalled() { + t.Error("Stop should have been called during restart") + } + + if !mockCmd.WasStartCalled() { + t.Error("Start should have been called during restart") + } +} + +func TestRestartFromExited(t *testing.T) { + t.Parallel() + + sm, mockCmd, _ := createTestStateMachine() + + // Start and then exit + sm.SendEvent(Start) + time.Sleep(10 * time.Millisecond) + mockCmd.SimulateExit() + time.Sleep(10 * time.Millisecond) + + // Reset mock + mockCmd.Reset() + + // Send Restart event from Exited state + sm.SendEvent(Restart) + time.Sleep(10 * time.Millisecond) + + if sm.currentState != Running { + t.Errorf("State should be Running after Restart from Exited, got %v", sm.currentState) + } + + if !mockCmd.WasStartCalled() { + t.Error("Start should have been called during restart from Exited") + } +} + +func TestStartFromExited(t *testing.T) { + t.Parallel() + + sm, mockCmd, _ := createTestStateMachine() + + // Start and then exit + sm.SendEvent(Start) + time.Sleep(10 * time.Millisecond) + mockCmd.SimulateExit() + time.Sleep(10 * time.Millisecond) + + // Reset mock + mockCmd.Reset() + + // Send Start event from Exited state + sm.SendEvent(Start) + time.Sleep(10 * time.Millisecond) + + if sm.currentState != Running { + t.Errorf("State should be Running after Start from Exited, got %v", sm.currentState) + } + + if !mockCmd.WasStartCalled() { + t.Error("Start should have been called") + } +} + +func TestInvalidTransitions(t *testing.T) { + t.Parallel() + + sm, _, buf := createTestStateMachine() + + tests := []struct { + initialState State + event Event + description string + }{ + {NotStarted, Signal, "Signal from NotStarted"}, + {NotStarted, Exit, "Exit from NotStarted"}, + {NotStarted, Restart, "Restart from NotStarted"}, + {Exited, Signal, "Signal from Exited"}, + {Exited, Exit, "Exit from Exited"}, + } + + for _, tt := range tests { + // Reset state machine to initial state + sm.currentState = tt.initialState + buf.Reset() + + sm.SendEvent(tt.event) + + if sm.currentState != tt.initialState { + t.Errorf( + "%s: state should remain %v, got %v", + tt.description, + tt.initialState, + sm.currentState, + ) + } + + logOutput := buf.String() + if logOutput != "" { + t.Errorf("%s: should not log invalid transition", tt.description) + } + } +} + +func TestCommandStartError(t *testing.T) { + t.Parallel() + + sm, mockCmd, buf := createTestStateMachine() + + // Set start to return an error + mockCmd.SetStartError(fmt.Errorf("test error")) + + sm.SendEvent(Start) + time.Sleep(10 * time.Millisecond) + + // Should still transition to Running state even if start fails + if sm.currentState != Running { + t.Errorf("State should be Running even if Start fails, got %v", sm.currentState) + } + + logOutput := buf.String() + if logOutput == "" { + t.Error("Should have logged start message") + } +} + +func TestCommandStopError(t *testing.T) { + t.Parallel() + + sm, mockCmd, buf := createTestStateMachine() + + // Start first + sm.SendEvent(Start) + time.Sleep(10 * time.Millisecond) + + // Set stop to return an error + mockCmd.SetStopError(fmt.Errorf("stop error")) + buf.Reset() + + sm.SendEvent(Signal) + time.Sleep(10 * time.Millisecond) + + if sm.currentState != Exited { + t.Errorf("State should be Exited even if Stop fails, got %v", sm.currentState) + } + + logOutput := buf.String() + if logOutput == "" { + t.Error("Should have logged stop message") + } +} + +func TestConcurrentEvents(t *testing.T) { + t.Parallel() + + sm, mockCmd, _ := createTestStateMachine() + + // Send multiple events concurrently + go sm.SendEvent(Start) + go sm.SendEvent(Start) + go sm.SendEvent(Start) + + time.Sleep(50 * time.Millisecond) + + if sm.currentState != Running { + t.Errorf("State should be Running after concurrent Start events, got %v", sm.currentState) + } + + if !mockCmd.WasStartCalled() { + t.Error("Start should have been called") + } +} + +func TestCompleteLifecycle(t *testing.T) { + t.Parallel() + + sm, mockCmd, _ := createTestStateMachine() + + // Complete lifecycle: NotStarted -> Running -> Exited -> Running -> Exited + + // Start + sm.SendEvent(Start) + time.Sleep(10 * time.Millisecond) + if sm.currentState != Running { + t.Errorf("Expected Running, got %v", sm.currentState) + } + + // Signal to stop + sm.SendEvent(Signal) + time.Sleep(10 * time.Millisecond) + if sm.currentState != Exited { + t.Errorf("Expected Exited, got %v", sm.currentState) + } + + // Start again + mockCmd.Reset() + sm.SendEvent(Start) + time.Sleep(10 * time.Millisecond) + if sm.currentState != Running { + t.Errorf("Expected Running after restart, got %v", sm.currentState) + } + + // Exit naturally + mockCmd.SimulateExit() + time.Sleep(10 * time.Millisecond) + if sm.currentState != Exited { + t.Errorf("Expected Exited after natural exit, got %v", sm.currentState) + } +}