1 /*
2  * Copyright (c) 2024 Huawei Device Co., Ltd.
3  * Licensed under the Apache License, Version 2.0 (the "License");
4  * you may not use this file except in compliance with the License.
5  * You may obtain a copy of the License at
6  *
7  *     http://www.apache.org/licenses/LICENSE-2.0
8  *
9  * Unless required by applicable law or agreed to in writing, software
10  * distributed under the License is distributed on an "AS IS" BASIS,
11  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12  * See the License for the specific language governing permissions and
13  * limitations under the License.
14  */
15 
16 #ifndef HGM_STATE_MACHINE_H
17 #define HGM_STATE_MACHINE_H
18 
#include <atomic>
#include <cstdint>
#include <functional>
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>

#include "hgm_log.h"
#include "rs_trace.h"
26 
27 namespace OHOS::Rosen {
/**
 * Generic state machine parameterized by a state type and an event type.
 *
 * - Transitions are requested with ChangeState(); per-state "exit" and "enter" callbacks can be
 *   registered and are identified by the int32_t id returned from the Register* call.
 * - Events are delivered with OnEvent(); at most one callback is stored per event.
 * - Every transition and event callback is dispatched through the virtual ExecuteCallback(),
 *   which by default runs it synchronously on the calling thread.
 *
 * @tparam StateType stored in std::atomic (so it must be trivially copyable) and used as an
 *         std::unordered_map key (so it needs a std::hash specialization), e.g. an integral or enum type.
 * @tparam EventType used as an std::unordered_map key.
 *
 * NOTE(review): only state_ and the callback-id counter are atomic; the callback maps have no
 * lock, so registering/unregistering concurrently with dispatch must be serialized by the caller.
 */
template<typename StateType, typename EventType>
class HgmStateMachine {
public:
    using State = StateType;
    using Event = EventType;

    using EventCallback = std::function<void(Event)>;
    // lastState, newState
    using StateChangeCallback = std::function<void(State, State)>;
    // state, {callbackId, stateChangeCallback}
    using StateChangeCallbacksType = std::unordered_map<State, std::unordered_map<int32_t, StateChangeCallback>>;

    explicit HgmStateMachine(State state) : state_(state) {}
    virtual ~HgmStateMachine() = default;

    // Current state (atomic load, so it does not mutate the object and may be called on const instances).
    State GetState() const { return state_.load(); }
    // Requests a transition to `state`: exit callbacks of the current state run first, then the state is
    // updated, then enter callbacks of `state` run. Skipped if already in `state` or if
    // CheckChangeStateValid() rejects it. Dispatched via ExecuteCallback().
    void ChangeState(State state);
    // Runs the callback registered for `event`, if any, via ExecuteCallback().
    void OnEvent(Event event);

    // Enter/exit callbacks are invoked with (lastState, newState). Register* returns the id that the
    // matching UnRegister* call needs to remove that callback.
    int32_t RegisterEnterStateCallback(State state, const StateChangeCallback& callback);
    void UnRegisterEnterStateCallback(State state, int32_t callbackId);
    int32_t RegisterExitStateCallback(State state, const StateChangeCallback& callback);
    void UnRegisterExitStateCallback(State state, int32_t callbackId);

    // At most one callback per event: registering again for the same event replaces the previous one.
    void RegisterEventCallback(Event event, const EventCallback& callback);
    void UnRegisterEventCallback(Event event);

protected:
    // Textual form of a state, used in the transition trace/log. The default prints the numeric value.
    // Enums are converted to their underlying integer type first, because std::to_string has no overload
    // for scoped enums (enum class). Subclasses may override this to print symbolic names.
    virtual std::string State2String(State state) const
    {
        if constexpr (std::is_enum_v<State>) {
            return std::to_string(static_cast<std::underlying_type_t<State>>(state));
        } else {
            return std::to_string(state);
        }
    }
    // Hook to veto a transition from lastState to newState; the default accepts every change.
    virtual bool CheckChangeStateValid(State lastState, State newState) { return true; }
    // Runs a dispatched task (state transition or event callback); a null task is ignored.
    // The default runs it synchronously on the calling thread. Callbacks are expected to run on the
    // same thread; an override that defers execution must preserve that and must keep this object
    // alive until the deferred task runs (transition tasks capture `this`).
    virtual void ExecuteCallback(const std::function<void()>& callback)
    {
        if (callback != nullptr) {
            callback();
        }
    }

private:
    // return callbackId
    int32_t RegisterStateChangeCallback(
        StateChangeCallbacksType& callbacks, State state, const StateChangeCallback& callback);
    void UnRegisterStateChangeCallback(StateChangeCallbacksType& callbacks, State state, int32_t callbackId);

    std::atomic<State> state_;
    StateChangeCallbacksType enterStateCallbacks_;
    StateChangeCallbacksType exitStateCallbacks_;
    // Monotonic source of callback ids shared by enter and exit registrations.
    std::atomic<int32_t> stateCallbackId_{0};

    std::unordered_map<Event, EventCallback> eventCallbacks_;
};
79 
80 template<typename State, typename Event>
ChangeState(State state)81 void HgmStateMachine<State, Event>::ChangeState(State state)
82 {
83     ExecuteCallback([this, state = state]() {
84         auto lastState = state_.load();
85         if (lastState == state) {
86             return;
87         }
88         if (!CheckChangeStateValid(lastState, state)) {
89             return;
90         }
91 
92         // exit state callback
93         for (auto &[id, callback] : exitStateCallbacks_[lastState]) {
94             if (callback != nullptr) {
95                 callback(lastState, state);
96             }
97         }
98 
99         // change state
100         RS_TRACE_NAME_FMT("StateMachine state change: %s -> %s",
101             State2String(lastState).c_str(), State2String(state).c_str());
102         HGM_LOGI("StateMachine state change: %{public}s -> %{public}s",
103             State2String(lastState).c_str(), State2String(state).c_str());
104         state_.store(state);
105 
106         // enter state callback
107         for (auto &[id, callback] : enterStateCallbacks_[state]) {
108             if (callback != nullptr) {
109                 callback(lastState, state);
110             }
111         }
112     });
113 }
114 
115 template<typename State, typename Event>
OnEvent(Event event)116 void HgmStateMachine<State, Event>::OnEvent(Event event)
117 {
118     if (auto iter = eventCallbacks_.find(event); iter != eventCallbacks_.end()) {
119         if (iter->second != nullptr) {
120             ExecuteCallback([callback = iter->second, event = event] () { callback(event); });
121         }
122     }
123 }
124 
125 template<typename State, typename Event>
RegisterEnterStateCallback(State state,const StateChangeCallback & callback)126 int32_t HgmStateMachine<State, Event>::RegisterEnterStateCallback(State state, const StateChangeCallback& callback)
127 {
128     return RegisterStateChangeCallback(enterStateCallbacks_, state, callback);
129 }
130 
131 template<typename State, typename Event>
UnRegisterEnterStateCallback(State state,int32_t callbackId)132 void HgmStateMachine<State, Event>::UnRegisterEnterStateCallback(State state, int32_t callbackId)
133 {
134     return UnRegisterStateChangeCallback(enterStateCallbacks_, state, callbackId);
135 }
136 
137 template<typename State, typename Event>
RegisterExitStateCallback(State state,const StateChangeCallback & callback)138 int32_t HgmStateMachine<State, Event>::RegisterExitStateCallback(State state, const StateChangeCallback& callback)
139 {
140     return RegisterStateChangeCallback(exitStateCallbacks_, state, callback);
141 }
142 
143 template<typename State, typename Event>
UnRegisterExitStateCallback(State state,int32_t callbackId)144 void HgmStateMachine<State, Event>::UnRegisterExitStateCallback(State state, int32_t callbackId)
145 {
146     return UnRegisterStateChangeCallback(exitStateCallbacks_, state, callbackId);
147 }
148 
149 template<typename State, typename Event>
RegisterEventCallback(Event event,const EventCallback & callback)150 void HgmStateMachine<State, Event>::RegisterEventCallback(Event event, const EventCallback& callback)
151 {
152     eventCallbacks_[event] = callback;
153 }
154 
155 template<typename State, typename Event>
UnRegisterEventCallback(Event event)156 void HgmStateMachine<State, Event>::UnRegisterEventCallback(Event event)
157 {
158     if (auto iter = eventCallbacks_.find(event); iter != eventCallbacks_.end()) {
159         eventCallbacks_.erase(iter);
160     }
161 }
162 
163 template<typename State, typename Event>
RegisterStateChangeCallback(StateChangeCallbacksType & callbacks,State state,const StateChangeCallback & callback)164 int32_t HgmStateMachine<State, Event>::RegisterStateChangeCallback(
165     StateChangeCallbacksType& callbacks, State state, const StateChangeCallback& callback)
166 {
167     stateCallbackId_++;
168     callbacks[state][stateCallbackId_.load()] = callback;
169     return stateCallbackId_;
170 }
171 
172 template<typename State, typename Event>
UnRegisterStateChangeCallback(StateChangeCallbacksType & callbacks,State state,int32_t callbackId)173 void HgmStateMachine<State, Event>::UnRegisterStateChangeCallback(
174     StateChangeCallbacksType& callbacks, State state, int32_t callbackId)
175 {
176     if (auto iter = callbacks.find(state); iter != callbacks.end()) {
177         if (auto toDelCallbackIter = iter->second; toDelCallbackIter != iter->second.end()) {
178             iter->second.erase(toDelCallbackIter);
179         }
180     }
181 }
182 }
183 #endif // HGM_STATE_MACHINE_H