1 /*
2 * Copyright (c) 2024 Huawei Device Co., Ltd.
3 * Licensed under the Apache License, Version 2.0 (the "License");
4 * you may not use this file except in compliance with the License.
5 * You may obtain a copy of the License at
6 *
7 * http://www.apache.org/licenses/LICENSE-2.0
8 *
9 * Unless required by applicable law or agreed to in writing, software
10 * distributed under the License is distributed on an "AS IS" BASIS,
11 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 * See the License for the specific language governing permissions and
13 * limitations under the License.
14 */
15
16 #ifndef HGM_STATE_MACHINE_H
17 #define HGM_STATE_MACHINE_H
18
#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>

#include "hgm_log.h"
#include "rs_trace.h"
26
27 namespace OHOS::Rosen {
/**
 * @brief Generic state machine with per-state enter/exit callbacks and per-event callbacks.
 *
 * @tparam StateType Type of a state value. It is stored in std::atomic and used as an
 *         std::unordered_map key (integers and enums, scoped or unscoped, both work).
 * @tparam EventType Type of an event value; used as an std::unordered_map key.
 *
 * NOTE(review): only state_ and the callback-id counter are atomic; the callback tables are plain
 * std::unordered_map without a lock, so registration, ChangeState and OnEvent presumably need to be
 * serialized by the owner (see ExecuteCallback) — confirm with callers.
 */
template<typename StateType, typename EventType>
class HgmStateMachine {
public:
    using State = StateType;
    using Event = EventType;

    // Invoked with the event that triggered it.
    using EventCallback = std::function<void(Event)>;
    // Invoked with (lastState, newState).
    using StateChangeCallback = std::function<void(State, State)>;
    // state -> {callbackId -> stateChangeCallback}
    using StateChangeCallbacksType = std::unordered_map<State, std::unordered_map<int32_t, StateChangeCallback>>;

    explicit HgmStateMachine(State state) : state_(state) {}
    virtual ~HgmStateMachine() = default;

    // Returns the current state; const because it only performs an atomic load.
    State GetState() const { return state_.load(); }
    // Requests a transition to `state`; the transition itself runs inside ExecuteCallback.
    void ChangeState(State state);
    // Runs the callback registered for `event` (if any) through ExecuteCallback.
    void OnEvent(Event event);

    // Register* returns an id; pass it with the same state to the matching UnRegister* call.
    int32_t RegisterEnterStateCallback(State state, const StateChangeCallback& callback);
    void UnRegisterEnterStateCallback(State state, int32_t callbackId);
    int32_t RegisterExitStateCallback(State state, const StateChangeCallback& callback);
    void UnRegisterExitStateCallback(State state, int32_t callbackId);

    // At most one callback per event: registering again replaces the previous callback.
    void RegisterEventCallback(Event event, const EventCallback& callback);
    void UnRegisterEventCallback(Event event);

protected:
    /**
     * Textual form of a state, used in trace/log output. The default prints the numeric value;
     * enums are converted through their underlying type so scoped `enum class` states compile too.
     * Override to print readable names.
     */
    virtual std::string State2String(State state) const
    {
        if constexpr (std::is_enum_v<State>) {
            return std::to_string(static_cast<std::underlying_type_t<State>>(state));
        } else {
            return std::to_string(state);
        }
    }

    // Hook to veto a transition from lastState to newState; the default accepts every transition.
    virtual bool CheckChangeStateValid(State lastState, State newState) { return true; }

    /**
     * Every state transition and event callback is funneled through this hook. The default runs
     * `callback` synchronously on the calling thread. Overrides may dispatch it elsewhere, but the
     * callbacks should always run on the same thread.
     */
    virtual void ExecuteCallback(const std::function<void()>& callback)
    {
        if (callback != nullptr) {
            callback();
        }
    }

private:
    // Stores `callback` in callbacks[state] under a freshly allocated id and returns that id.
    int32_t RegisterStateChangeCallback(
        StateChangeCallbacksType& callbacks, State state, const StateChangeCallback& callback);
    // Removes the callback identified by `callbackId` from callbacks[state].
    void UnRegisterStateChangeCallback(StateChangeCallbacksType& callbacks, State state, int32_t callbackId);

    std::atomic<State> state_;
    StateChangeCallbacksType enterStateCallbacks_;
    StateChangeCallbacksType exitStateCallbacks_;
    // Id source shared by enter and exit registrations.
    std::atomic<int32_t> stateCallbackId_{0};

    std::unordered_map<Event, EventCallback> eventCallbacks_;
};
79
80 template<typename State, typename Event>
ChangeState(State state)81 void HgmStateMachine<State, Event>::ChangeState(State state)
82 {
83 ExecuteCallback([this, state = state]() {
84 auto lastState = state_.load();
85 if (lastState == state) {
86 return;
87 }
88 if (!CheckChangeStateValid(lastState, state)) {
89 return;
90 }
91
92 // exit state callback
93 for (auto &[id, callback] : exitStateCallbacks_[lastState]) {
94 if (callback != nullptr) {
95 callback(lastState, state);
96 }
97 }
98
99 // change state
100 RS_TRACE_NAME_FMT("StateMachine state change: %s -> %s",
101 State2String(lastState).c_str(), State2String(state).c_str());
102 HGM_LOGI("StateMachine state change: %{public}s -> %{public}s",
103 State2String(lastState).c_str(), State2String(state).c_str());
104 state_.store(state);
105
106 // enter state callback
107 for (auto &[id, callback] : enterStateCallbacks_[state]) {
108 if (callback != nullptr) {
109 callback(lastState, state);
110 }
111 }
112 });
113 }
114
115 template<typename State, typename Event>
OnEvent(Event event)116 void HgmStateMachine<State, Event>::OnEvent(Event event)
117 {
118 if (auto iter = eventCallbacks_.find(event); iter != eventCallbacks_.end()) {
119 if (iter->second != nullptr) {
120 ExecuteCallback([callback = iter->second, event = event] () { callback(event); });
121 }
122 }
123 }
124
125 template<typename State, typename Event>
RegisterEnterStateCallback(State state,const StateChangeCallback & callback)126 int32_t HgmStateMachine<State, Event>::RegisterEnterStateCallback(State state, const StateChangeCallback& callback)
127 {
128 return RegisterStateChangeCallback(enterStateCallbacks_, state, callback);
129 }
130
131 template<typename State, typename Event>
UnRegisterEnterStateCallback(State state,int32_t callbackId)132 void HgmStateMachine<State, Event>::UnRegisterEnterStateCallback(State state, int32_t callbackId)
133 {
134 return UnRegisterStateChangeCallback(enterStateCallbacks_, state, callbackId);
135 }
136
137 template<typename State, typename Event>
RegisterExitStateCallback(State state,const StateChangeCallback & callback)138 int32_t HgmStateMachine<State, Event>::RegisterExitStateCallback(State state, const StateChangeCallback& callback)
139 {
140 return RegisterStateChangeCallback(exitStateCallbacks_, state, callback);
141 }
142
143 template<typename State, typename Event>
UnRegisterExitStateCallback(State state,int32_t callbackId)144 void HgmStateMachine<State, Event>::UnRegisterExitStateCallback(State state, int32_t callbackId)
145 {
146 return UnRegisterStateChangeCallback(exitStateCallbacks_, state, callbackId);
147 }
148
149 template<typename State, typename Event>
RegisterEventCallback(Event event,const EventCallback & callback)150 void HgmStateMachine<State, Event>::RegisterEventCallback(Event event, const EventCallback& callback)
151 {
152 eventCallbacks_[event] = callback;
153 }
154
155 template<typename State, typename Event>
UnRegisterEventCallback(Event event)156 void HgmStateMachine<State, Event>::UnRegisterEventCallback(Event event)
157 {
158 if (auto iter = eventCallbacks_.find(event); iter != eventCallbacks_.end()) {
159 eventCallbacks_.erase(iter);
160 }
161 }
162
163 template<typename State, typename Event>
RegisterStateChangeCallback(StateChangeCallbacksType & callbacks,State state,const StateChangeCallback & callback)164 int32_t HgmStateMachine<State, Event>::RegisterStateChangeCallback(
165 StateChangeCallbacksType& callbacks, State state, const StateChangeCallback& callback)
166 {
167 stateCallbackId_++;
168 callbacks[state][stateCallbackId_.load()] = callback;
169 return stateCallbackId_;
170 }
171
172 template<typename State, typename Event>
UnRegisterStateChangeCallback(StateChangeCallbacksType & callbacks,State state,int32_t callbackId)173 void HgmStateMachine<State, Event>::UnRegisterStateChangeCallback(
174 StateChangeCallbacksType& callbacks, State state, int32_t callbackId)
175 {
176 if (auto iter = callbacks.find(state); iter != callbacks.end()) {
177 if (auto toDelCallbackIter = iter->second; toDelCallbackIter != iter->second.end()) {
178 iter->second.erase(toDelCallbackIter);
179 }
180 }
181 }
182 }
183 #endif // HGM_STATE_MACHINE_H