Skip to content

Commit

Permalink
fix: add more locks (#16)
Browse files Browse the repository at this point in the history
  • Loading branch information
pancsta committed Feb 22, 2024
1 parent 9fcadfa commit 3586d3a
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 14 deletions.
73 changes: 59 additions & 14 deletions pkg/machine/machine.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ type Machine struct {
logger Logger
queueLock sync.RWMutex
queueProcessing atomic.Value
activeStatesLock sync.Mutex
activeStatesLock sync.RWMutex
panicCaught bool
disposed bool
indexWhen indexWhen
Expand Down Expand Up @@ -247,17 +247,19 @@ func (m *Machine) GetRelationsOf(fromState string) []Relation {
func (m *Machine) When(states []string, ctx context.Context) chan struct{} {
ch := make(chan struct{})

m.activeStatesLock.Lock()
defer m.activeStatesLock.Unlock()

// if all active, close early
if m.Is(states) {
if m.is(states) {
close(ch)
return ch
}

m.activeStatesLock.Lock()
setMap := stateIsActive{}
matched := 0
for _, s := range states {
setMap[s] = m.Is(S{s})
setMap[s] = m.is(S{s})
if setMap[s] {
matched++
}
Expand All @@ -278,13 +280,15 @@ func (m *Machine) When(states []string, ctx context.Context) chan struct{} {
if m.disposed {
return
}

m.activeStatesLock.Lock()
defer m.activeStatesLock.Unlock()

for _, s := range states {
if _, ok := m.indexWhen[s]; ok {
m.indexWhen[s] = lo.Without(m.indexWhen[s], binding)
}
}
m.activeStatesLock.Unlock()
}()
}
// insert the binding
Expand All @@ -295,7 +299,6 @@ func (m *Machine) When(states []string, ctx context.Context) chan struct{} {
m.indexWhen[s] = append(m.indexWhen[s], binding)
}
}
m.activeStatesLock.Unlock()

return ch
}
Expand All @@ -314,10 +317,12 @@ func (m *Machine) WhenNot(states []string, ctx context.Context) chan struct{} {
}

m.activeStatesLock.Lock()
defer m.activeStatesLock.Unlock()

setMap := stateIsActive{}
matched := 0
for _, s := range states {
setMap[s] = m.Is(S{s})
setMap[s] = m.is(S{s})
if !setMap[s] {
matched++
}
Expand All @@ -338,13 +343,15 @@ func (m *Machine) WhenNot(states []string, ctx context.Context) chan struct{} {
if m.disposed {
return
}

m.activeStatesLock.Lock()
defer m.activeStatesLock.Unlock()

for _, s := range states {
if _, ok := m.indexWhen[s]; ok {
m.indexWhen[s] = lo.Without(m.indexWhen[s], binding)
}
}
m.activeStatesLock.Unlock()
}()
}
// insert the binding
Expand All @@ -355,7 +362,6 @@ func (m *Machine) WhenNot(states []string, ctx context.Context) chan struct{} {
m.indexWhen[s] = append(m.indexWhen[s], binding)
}
}
m.activeStatesLock.Unlock()

return ch
}
Expand All @@ -365,6 +371,9 @@ func (m *Machine) WhenNot(states []string, ctx context.Context) chan struct{} {
// states: optionally passing a list of states param guarantees a deterministic
// order of the result.
func (m *Machine) Time(states S) T {
m.activeStatesLock.RLock()
defer m.activeStatesLock.RUnlock()

if states == nil {
states = m.StateNames
}
Expand Down Expand Up @@ -438,6 +447,14 @@ func (m *Machine) Set(states S, args A) Result {
// machine.Is(S{"Foo", "Bar"}) // false
// ```
func (m *Machine) Is(states S) bool {
m.activeStatesLock.RLock()
defer m.activeStatesLock.RUnlock()

return m.is(states)
}

// thread-unsafe version of Is(), make sure to acquire m.activeStatesLock
func (m *Machine) is(states S) bool {
return lo.Every(m.ActiveStates, m.MustParseStates(states))
}

Expand All @@ -452,6 +469,9 @@ func (m *Machine) Is(states S) bool {
// machine.Not(S{"C", "D"}) // true
// ```
func (m *Machine) Not(states S) bool {
m.activeStatesLock.RLock()
defer m.activeStatesLock.RUnlock()

return lo.None(m.MustParseStates(states), m.ActiveStates)
}

Expand Down Expand Up @@ -521,14 +541,16 @@ func (m *Machine) GetStateCtx(state string) context.Context {
cancel()
return stateCtx
}

m.activeStatesLock.Lock()
defer m.activeStatesLock.Unlock()

// add an index
if _, ok := m.indexStateCtx[state]; !ok {
m.indexStateCtx[state] = []context.CancelFunc{cancel}
} else {
m.indexStateCtx[state] = append(m.indexStateCtx[state], cancel)
}
m.activeStatesLock.Unlock()
return stateCtx
}

Expand Down Expand Up @@ -622,7 +644,9 @@ func (m *Machine) recoverToErr(emitter *emitter, r any) {
finals := S{}
finals = append(finals, t.Exits...)
finals = append(finals, t.Enters...)
m.activeStatesLock.RLock()
activeStates := m.ActiveStates
m.activeStatesLock.RUnlock()
found := false
// walk over enter/exits and remove states after the last step,
// as their final handlers haven't been executed
Expand Down Expand Up @@ -687,6 +711,8 @@ func (m *Machine) setActiveStates(calledStates S, targetStates S,
isAuto bool,
) S {
m.activeStatesLock.Lock()
defer m.activeStatesLock.Unlock()

previous := m.ActiveStates
newStates := DiffStates(targetStates, m.ActiveStates)
removedStates := DiffStates(m.ActiveStates, targetStates)
Expand Down Expand Up @@ -720,7 +746,6 @@ func (m *Machine) setActiveStates(calledStates S, targetStates S,
m.log(LogChanges, "[state%s]"+logMsg, autoLabel)
}

m.activeStatesLock.Unlock()
return previous
}

Expand Down Expand Up @@ -775,28 +800,29 @@ func (m *Machine) processQueue() Result {
}

func (m *Machine) processStateCtxBindings() {
m.activeStatesLock.RLock()

deactivated := DiffStates(m.Transition.StatesBefore, m.ActiveStates)
m.activeStatesLock.Lock()
var toCancel []context.CancelFunc
for _, s := range deactivated {
toCancel = append(toCancel, m.indexStateCtx[s]...)
delete(m.indexStateCtx, s)
}
m.activeStatesLock.Unlock()
m.activeStatesLock.RUnlock()
// cancel all the state contexts outside the critical zone
for _, cancel := range toCancel {
cancel()
}
}

func (m *Machine) processWhenBindings() {
m.activeStatesLock.Lock()
activated := DiffStates(m.ActiveStates, m.Transition.StatesBefore)
deactivated := DiffStates(m.Transition.StatesBefore, m.ActiveStates)
all := S{}
all = append(all, activated...)
all = append(all, deactivated...)
var toClose []chan struct{}
m.activeStatesLock.Lock()
for _, s := range all {
for k, binding := range m.indexWhen[s] {
if lo.Contains(activated, s) {
Expand Down Expand Up @@ -851,6 +877,7 @@ func (m *Machine) processWhenBindings() {
}
}
m.activeStatesLock.Unlock()
// notify outside the critical zone
for ch := range toClose {
closeSafe(toClose[ch])
}
Expand Down Expand Up @@ -1055,6 +1082,9 @@ func (m *Machine) DuringTransition() bool {

// Clock return the current tick for a state.
func (m *Machine) Clock(state string) uint64 {
m.activeStatesLock.RLock()
defer m.activeStatesLock.RUnlock()

return m.clock[state]
}

Expand All @@ -1079,6 +1109,9 @@ func (m *Machine) To() S {
func (m *Machine) IsQueued(mutationType MutationType, states S,
withoutArgsOnly bool, statesStrictEqual bool, startIndex int,
) int {
m.queueLock.RLock()
defer m.queueLock.RUnlock()

for index, item := range m.Queue {
if index >= startIndex &&
item.Type == mutationType &&
Expand All @@ -1105,6 +1138,9 @@ func (m *Machine) Has(states S) bool {
// HasStateChanged checks current active states have changed from the passed
// ones.
func (m *Machine) HasStateChanged(before S) bool {
m.activeStatesLock.RLock()
defer m.activeStatesLock.RUnlock()

lenEqual := len(before) == len(m.ActiveStates)
return !lenEqual || len(DiffStates(before, m.ActiveStates)) > 0
}
Expand All @@ -1113,6 +1149,9 @@ func (m *Machine) HasStateChanged(before S) bool {
// with their clock values. Inactive states are omitted.
// Eg: (Foo:2 Bar:1)
func (m *Machine) String() string {
m.activeStatesLock.RLock()
defer m.activeStatesLock.RUnlock()

ret := "("
for _, state := range m.ActiveStates {
if ret != "(" {
Expand All @@ -1127,6 +1166,9 @@ func (m *Machine) String() string {
// clock values. Inactive states are in square brackets.
// Eg: (Foo:2 Bar:1)[Baz:0]
func (m *Machine) StringAll() string {
m.activeStatesLock.RLock()
defer m.activeStatesLock.RUnlock()

ret := "("
ret2 := "["
for _, state := range m.StateNames {
Expand All @@ -1149,6 +1191,9 @@ func (m *Machine) StringAll() string {
// relations, clocks).
// states: param for ordered or partial results.
func (m *Machine) Inspect(states S) string {
m.activeStatesLock.RLock()
defer m.activeStatesLock.RUnlock()

if states == nil {
states = m.StateNames
}
Expand Down
3 changes: 3 additions & 0 deletions pkg/machine/transition.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,9 @@ type Transition struct {

// newTransition creates a new transition for the given mutation.
func newTransition(m *Machine, item *Mutation) *Transition {
m.activeStatesLock.RLock()
defer m.activeStatesLock.RUnlock()

t := &Transition{
Mutation: item,
StatesBefore: m.ActiveStates,
Expand Down

0 comments on commit 3586d3a

Please sign in to comment.