1// SPDX-License-Identifier: GPL-3.0-or-later
2#include "units/unit.hpp"
3
4#include "global.hpp"
5#include "logging.hpp"
6
7#include <vector>
8
// Sentinel returned by PopArg when a required string argument is absent.
static constexpr const char *INVALID_ARGUMENT = "<invalid>";
10
// Returns true for the keys a unit may omit without a "missing key"
// diagnostic: the dependency lists "depends_on" and "part_of".
static bool is_optional_key(const std::string_view key)
{
    if (key == "depends_on")
        return true;
    return key == "part_of";
}
15
16const std::map<std::string, UnitCreatorType> &Unit::Creator(std::optional<std::pair<std::string, UnitCreatorType>> creator)
17{
18 static std::map<std::string, UnitCreatorType> creators;
19 if (creator)
20 creators.insert(x&: *creator);
21 return creators;
22}
23
24const std::map<std::string, UnitInstantiator> &Unit::Instantiator(std::optional<std::pair<std::string, UnitInstantiator>> instantiator)
25{
26 static std::map<std::string, UnitInstantiator> instantiators;
27 if (instantiator)
28 instantiators.insert(x&: *instantiator);
29 return instantiators;
30}
31
32void Unit::VerifyUnitArguments(const std::string &id, const toml::table &table)
33{
34 if (table.empty())
35 return;
36
37 for (const auto &[key, value] : table)
38 {
39 if (key == "options")
40 {
41 if (!value.is_table())
42 {
43 std::cerr << "unit " << id << " has bad options" << std::endl;
44 continue;
45 }
46
47 for (const auto &[key, value] : *value.as_table())
48 std::cerr << "unit " << id << " has unknown options: " << key << std::endl;
49 }
50 else
51 {
52 std::cerr << "unit " << id << " has unknown keys: " << key << std::endl;
53 }
54 }
55}
56
57std::shared_ptr<IUnit> Unit::Create(const std::string &id, const toml::table &data)
58{
59 // extract type from id, e.g. mos.service@abc -> service
60 const auto type_string = id.substr(pos: id.find(c: '.') + 1, n: id.find(c: '@') - id.find(c: '.') - 1);
61 if (type_string.empty())
62 {
63 std::cerr << "bad unit id" << std::endl;
64 return nullptr;
65 }
66
67 const auto &creators = Unit::Creator();
68 if (const auto it = creators.find(x: type_string); it != creators.end())
69 {
70 Debug << "creating unit " << id << " of type " << type_string << std::endl;
71 auto data_copy = data;
72 if (const auto unit = it->second(id, data_copy); unit)
73 {
74 VerifyUnitArguments(id, table: data_copy);
75 return unit;
76 }
77
78 std::cerr << "failed to create unit" << std::endl;
79 return nullptr;
80 }
81 else
82 {
83 std::cerr << RED("unknown type ") << type_string << std::endl;
84 return nullptr;
85 }
86}
87
88std::shared_ptr<IUnit> Unit::Instantiate(const std::string &id, std::shared_ptr<const Template> template_, const ArgumentMap &args)
89{
90 const auto &instantiators = Unit::Instantiator();
91
92 const auto type_string = id.substr(pos: id.find(c: '.') + 1, n: id.find(c: '@') - id.find(c: '.') - 1);
93 if (type_string.empty())
94 {
95 std::cerr << "bad unit id" << std::endl;
96 return nullptr;
97 }
98
99 if (const auto it = instantiators.find(x: type_string); it != instantiators.end())
100 {
101 Debug << "instantiating unit " << id << " of type " << type_string << std::endl;
102 if (const auto unit = it->second(id, template_, args); unit)
103 return unit;
104
105 std::cerr << "failed to instantiate unit" << std::endl;
106 return nullptr;
107 }
108
109 std::cerr << RED("unknown type ") << type_string << std::endl;
110 return nullptr;
111}
112
113std::ostream &operator<<(std::ostream &os, const Unit &unit)
114{
115 os << unit.description << " (" << unit.id << ")" << std::endl;
116 os << " depends_on: ";
117 for (const auto &dep : unit.dependsOn)
118 os << dep << " ";
119 if (unit.dependsOn.empty())
120 os << "(none)";
121 os << std::endl;
122 os << " part_of: ";
123 for (const auto &part : unit.partOf)
124 os << part << " ";
125 if (unit.partOf.empty())
126 os << "(none)";
127 os << std::endl;
128 unit.onPrint(os);
129 return os;
130}
131
132Unit::Unit(const std::string &id, toml::table &table, std::shared_ptr<const Template> template_, const ArgumentMap &args)
133 : IUnit(id), //
134 arguments(args), //
135 description(PopArg(table, key: "description", toplevel)), //
136 dependsOn(GetArrayArg(table, key: "depends_on", toplevel)), //
137 partOf(GetArrayArg(table, key: "part_of", toplevel)), //
138 template_(template_)
139{
140 Debug << "creating unit " << id << std::endl;
141}
142
143std::string Unit::PopArg(toml::table &table, std::string_view key)
144{
145 if (!table["options"].is_table())
146 return INVALID_ARGUMENT;
147 return PopArg(table&: *table["options"].as_table(), key, toplevel);
148}
149
150std::vector<std::string> Unit::GetArrayArg(toml::table &table, std::string_view key)
151{
152 if (!table["options"].is_table())
153 return {};
154 return GetArrayArg(table&: *table["options"].as_table(), key, toplevel);
155}
156
157std::string Unit::PopArg(toml::table &table, std::string_view key, toplevel_t)
158{
159 const auto tomlval = table[key];
160 if (!tomlval || !tomlval.is_string())
161 {
162 if (is_optional_key(key))
163 return "";
164 std::cerr << "unit " << id << " missing key " << key << std::endl;
165 return INVALID_ARGUMENT;
166 }
167
168 const auto value = ReplaceArgs(str: tomlval.as_string()->get());
169 table.erase(key);
170 return value;
171}
172
173std::vector<std::string> Unit::GetArrayArg(toml::table &table, std::string_view key, toplevel_t)
174{
175 const auto tomlval = table[key];
176 if (!tomlval || !tomlval.is_array())
177 {
178 if (is_optional_key(key))
179 return {};
180 std::cerr << "unit " << id << " missing key " << key << std::endl;
181 return {};
182 }
183
184 const auto value = ReplaceArgs(array: tomlval.as_array());
185 table.erase(key);
186 return value;
187}
188
189std::string Unit::ReplaceArgs(const std::string &str) const
190{
191 // replace any $key with value in args
192 std::string result = str;
193 for (const auto &[key, value] : arguments)
194 result = replace_all(str: result, matcher: "[" + key + "]", replacement: value);
195 return result;
196}
197
198std::vector<std::string> Unit::ReplaceArgs(const toml::array *array)
199{
200 std::vector<std::string> result;
201 for (const auto &e : *array)
202 {
203 if (!e.is_string())
204 {
205 std::cerr << "Invalid array element" << std::endl;
206 continue;
207 }
208 result.push_back(x: ReplaceArgs(str: e.as_string()->get()));
209 }
210 return result;
211}
212