P4C
The P4 Compiler
Loading...
Searching...
No Matches
enumerator.h
/*
Copyright 2013-present Barefoot Networks, Inc.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
16
/* -*-c++-*-
   C#-like enumerator interface */
19
20#ifndef LIB_ENUMERATOR_H_
21#define LIB_ENUMERATOR_H_
22
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

#include "iterator_range.h"
33
34namespace P4::Util {
/// Position of an Enumerator within the sequence it walks.
enum class EnumeratorState {
    NotStarted = 0,  ///< before the first moveNext() (also the state set by reset())
    Valid = 1,       ///< positioned on an element; getCurrent() may be called
    PastEnd = 2,     ///< the sequence is exhausted
};

/// Forward declaration of the type-erased enumerator interface.
template <typename T>
class Enumerator;
39
// FIXME: This is not a proper iterator (its `reference` type is `T`, not `T &`; see the
// typedefs below) and should be removed in favor of a more standard approach. Note that
// Enumerator<T>::getCurrent() always returns elements by value, so this is more or less
// suitable only for copyable types that are cheap to copy.
47template <typename T>
49 private:
50 Enumerator<T> *enumerator = nullptr; // when nullptr it represents end()
51 explicit EnumeratorHandle(Enumerator<T> *enumerator) : enumerator(enumerator) {}
52 friend class Enumerator<T>;
53
54 public:
55 using iterator_category = std::input_iterator_tag;
56 using difference_type = std::ptrdiff_t;
57 using value_type = T;
58 using reference = T;
59 using pointer = void;
60
61 reference operator*() const;
62 const EnumeratorHandle<T> &operator++();
63 bool operator!=(const EnumeratorHandle<T> &other) const;
64};
65
67template <class T>
69 protected:
70 EnumeratorState state = EnumeratorState::NotStarted;
71
72 // This is a weird oddity of C++: this class is a friend of itself with different templates
73 template <class S>
74 friend class Enumerator;
75 static std::vector<T> emptyVector;
76 template <typename S>
77 friend class EnumeratorHandle;
78
79 public:
80 using value_type = T;
81
82 Enumerator() { this->reset(); }
83
84 virtual ~Enumerator() = default;
85
88 virtual bool moveNext() = 0;
90 virtual T getCurrent() const = 0;
92 virtual void reset() { this->state = EnumeratorState::NotStarted; }
93
94 EnumeratorHandle<T> begin() {
95 this->moveNext();
96 return EnumeratorHandle<T>(this);
97 }
98 EnumeratorHandle<T> end() { return EnumeratorHandle<T>(nullptr); }
99
100 const char *stateName() const {
101 switch (this->state) {
102 case EnumeratorState::NotStarted:
103 return "NotStarted";
104 case EnumeratorState::Valid:
105 return "Valid";
106 case EnumeratorState::PastEnd:
107 return "PastEnd";
108 }
109 throw std::logic_error("Unexpected state " + std::to_string(static_cast<int>(this->state)));
110 }
111
113 template <typename Container>
114 [[deprecated(
115 "Use Util::enumerate() instead")]] static Enumerator<typename Container::value_type> *
116 createEnumerator(const Container &data);
117 static Enumerator<T> *emptyEnumerator(); // empty data
118 template <typename Iter>
119 [[deprecated("Use Util::enumerate() instead")]] static Enumerator<typename Iter::value_type> *
120 createEnumerator(Iter begin, Iter end);
121 template <typename Iter>
122 [[deprecated("Use Util::enumerate() instead")]] static Enumerator<typename Iter::value_type> *
123 createEnumerator(iterator_range<Iter> range);
124
126 template <typename Filter>
127 Enumerator<T> *where(Filter filter);
129 template <typename Mapper>
132 template <typename S>
138
139 std::vector<T> toVector() {
140 std::vector<T> result;
141 while (moveNext()) result.push_back(getCurrent());
142 return result;
143 }
144
146 uint64_t count() {
147 uint64_t found = 0;
148 while (this->moveNext()) found++;
149 return found;
150 }
151
153 bool any() { return this->moveNext(); }
154
156 T single() {
157 bool next = moveNext();
158 if (!next) throw std::logic_error("There is no element for `single()'");
159 T result = getCurrent();
160 next = moveNext();
161 if (next) throw std::logic_error("There are multiple elements when calling `single()'");
162 return result;
163 }
164
168 bool next = moveNext();
169 if (!next) return T{};
170 T result = getCurrent();
171 next = moveNext();
172 if (next) throw std::logic_error("There are multiple elements when calling `single()'");
173 return result;
174 }
175
178 bool next = moveNext();
179 if (!next) return T{};
180 return getCurrent();
181 }
182
184 T next() {
185 bool next = moveNext();
186 if (!next) throw std::logic_error("There is no element for `next()'");
187 return getCurrent();
188 }
189};
190
// The implementation must be in the header file because these are templates.
193
196
198template <typename Iter>
199class IteratorEnumerator : public Enumerator<typename std::iterator_traits<Iter>::value_type> {
200 protected:
201 Iter begin;
202 Iter end;
203 Iter current;
204 const char *name;
205 friend class Enumerator<typename std::iterator_traits<Iter>::value_type>;
206
207 public:
208 IteratorEnumerator(Iter begin, Iter end, const char *name)
209 : Enumerator<typename std::iterator_traits<Iter>::value_type>(),
210 begin(begin),
211 end(end),
212 current(begin),
213 name(name) {}
214
215 [[nodiscard]] std::string toString() const {
216 return std::string(this->name) + ":" + this->stateName();
217 }
218
219 bool moveNext() {
220 switch (this->state) {
221 case EnumeratorState::NotStarted:
222 this->current = this->begin;
223 if (this->current == this->end) {
224 this->state = EnumeratorState::PastEnd;
225 return false;
226 } else {
227 this->state = EnumeratorState::Valid;
228 }
229 return true;
230 case EnumeratorState::PastEnd:
231 return false;
232 case EnumeratorState::Valid:
233 ++this->current;
234 if (this->current == this->end) {
235 this->state = EnumeratorState::PastEnd;
236 return false;
237 }
238 return true;
239 }
240
241 throw std::runtime_error("Unexpected enumerator state");
242 }
243
244 typename std::iterator_traits<Iter>::value_type getCurrent() const {
245 switch (this->state) {
246 case EnumeratorState::NotStarted:
247 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
248 case EnumeratorState::PastEnd:
249 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
250 case EnumeratorState::Valid:
251 return *this->current;
252 }
253 throw std::runtime_error("Unexpected enumerator state");
254 }
255};
256
258
259template <typename T>
260class SingleEnumerator : public Enumerator<T> {
261 T value;
262
263 public:
264 explicit SingleEnumerator(T v) : Enumerator<T>(), value(v) {}
265 bool moveNext() {
266 switch (this->state) {
267 case EnumeratorState::NotStarted:
268 this->state = EnumeratorState::Valid;
269 return true;
270 case EnumeratorState::PastEnd:
271 return false;
272 case EnumeratorState::Valid:
273 this->state = EnumeratorState::PastEnd;
274 return false;
275 }
276 throw std::runtime_error("Unexpected enumerator state");
277 }
278 T getCurrent() const {
279 switch (this->state) {
280 case EnumeratorState::NotStarted:
281 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
282 case EnumeratorState::PastEnd:
283 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
284 case EnumeratorState::Valid:
285 return this->value;
286 }
287 throw std::runtime_error("Unexpected enumerator state");
288 }
289};
290
292
294template <typename T>
295class EmptyEnumerator : public Enumerator<T> {
296 public:
297 [[nodiscard]] std::string toString() const { return "EmptyEnumerator"; }
299 bool moveNext() { return false; }
300 T getCurrent() const {
301 throw std::logic_error("You cannot call 'getCurrent' on an EmptyEnumerator");
302 }
303};
304
306
312template <typename T, typename Filter>
313class FilterEnumerator final : public Enumerator<T> {
314 Enumerator<T> *input;
315 Filter filter;
316 T current; // must prevent repeated evaluation
317
318 public:
319 FilterEnumerator(Enumerator<T> *input, Filter filter)
320 : input(input), filter(std::move(filter)) {}
321
322 private:
323 bool advance() {
324 this->state = EnumeratorState::Valid;
325 while (this->input->moveNext()) {
326 this->current = this->input->getCurrent();
327 bool match = this->filter(this->current);
328 if (match) return true;
329 }
330 this->state = EnumeratorState::PastEnd;
331 return false;
332 }
333
334 public:
335 [[nodiscard]] std::string toString() const {
336 return "FilterEnumerator(" + this->input->toString() + "):" + this->stateName();
337 }
338
339 void reset() {
340 this->input->reset();
342 }
343
344 bool moveNext() {
345 switch (this->state) {
346 case EnumeratorState::NotStarted:
347 case EnumeratorState::Valid:
348 return this->advance();
349 case EnumeratorState::PastEnd:
350 return false;
351 }
352 throw std::runtime_error("Unexpected enumerator state");
353 }
354
355 T getCurrent() const {
356 switch (this->state) {
357 case EnumeratorState::NotStarted:
358 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
359 case EnumeratorState::PastEnd:
360 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
361 case EnumeratorState::Valid:
362 return this->current;
363 }
364 throw std::runtime_error("Unexpected enumerator state");
365 }
366};
367
369
namespace Detail {
// See if we can use ICastable interface to cast from T to S. This is only possible if:
// - Both T and S are pointer types (let's denote T = From* and S = To*)
// - Expression (From*)()->to<To>() is well-formed
// Essentially this means the following code is well-formed:
// From *current = input->getCurrent(); current->to<To>();
//
// These are `inline constexpr` (C++17) rather than `static constexpr`: they live in a
// header, so they get one definition shared by all translation units instead of an
// internal-linkage copy per TU.
template <typename From, typename To, typename = void>
inline constexpr bool can_be_casted = false;

template <typename From, typename To>
inline constexpr bool
    can_be_casted<From *, To *, std::void_t<decltype(std::declval<From *>()->template to<To>())>> =
        true;
}  // namespace Detail
384
386template <typename T, typename S>
387class AsEnumerator final : public Enumerator<S> {
388 template <typename U = S>
389 typename std::enable_if_t<!Detail::can_be_casted<T, S>, U> getCurrentImpl() const {
390 T current = input->getCurrent();
391 return dynamic_cast<S>(current);
392 }
393
394 template <typename U = S>
395 typename std::enable_if_t<Detail::can_be_casted<T, S>, U> getCurrentImpl() const {
396 T current = input->getCurrent();
397 return current->template to<std::remove_pointer_t<S>>();
398 }
399
400 protected:
401 Enumerator<T> *input;
402
403 public:
404 explicit AsEnumerator(Enumerator<T> *input) : input(input) {}
405
406 std::string toString() const {
407 return "AsEnumerator(" + this->input->toString() + "):" + this->stateName();
408 }
409
410 void reset() override {
412 this->input->reset();
413 }
414
415 bool moveNext() override {
416 bool result = this->input->moveNext();
417 if (result)
418 this->state = EnumeratorState::Valid;
419 else
420 this->state = EnumeratorState::PastEnd;
421 return result;
422 }
423
424 S getCurrent() const override { return getCurrentImpl(); }
425};
426
428
430template <typename T, typename S, typename Mapper>
431class MapEnumerator final : public Enumerator<S> {
432 protected:
433 Enumerator<T> *input;
434 Mapper map;
435 S current;
436
437 public:
438 MapEnumerator(Enumerator<T> *input, Mapper map) : input(input), map(std::move(map)) {}
439
440 void reset() {
441 this->input->reset();
443 }
444
445 [[nodiscard]] std::string toString() const {
446 return "MapEnumerator(" + this->input->toString() + "):" + this->stateName();
447 }
448
449 bool moveNext() {
450 switch (this->state) {
451 case EnumeratorState::NotStarted:
452 case EnumeratorState::Valid: {
453 bool success = input->moveNext();
454 if (success) {
455 T currentInput = this->input->getCurrent();
456 this->current = this->map(currentInput);
457 this->state = EnumeratorState::Valid;
458 return true;
459 } else {
460 this->state = EnumeratorState::PastEnd;
461 return false;
462 }
463 }
464 case EnumeratorState::PastEnd:
465 return false;
466 }
467 throw std::runtime_error("Unexpected enumerator state");
468 }
469
470 S getCurrent() const {
471 switch (this->state) {
472 case EnumeratorState::NotStarted:
473 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
474 case EnumeratorState::PastEnd:
475 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
476 case EnumeratorState::Valid:
477 return this->current;
478 }
479 throw std::runtime_error("Unexpected enumerator state");
480 }
481};
482
483template <typename T, typename Mapper>
484MapEnumerator(Enumerator<T> *,
485 Mapper) -> MapEnumerator<T, typename std::invoke_result_t<Mapper, T>, Mapper>;
486
488
490template <typename T>
491class ConcatEnumerator final : public Enumerator<T> {
492 std::vector<Enumerator<T> *> inputs;
493 T currentResult;
494
495 public:
496 ConcatEnumerator() = default;
497 // We take ownership of the vector
498 explicit ConcatEnumerator(std::vector<Enumerator<T> *> &&inputs) : inputs(std::move(inputs)) {
499 for (auto *currentInput : inputs)
500 if (currentInput == nullptr) throw std::logic_error("Null iterator in concatenation");
501 }
502
503 ConcatEnumerator(std::initializer_list<Enumerator<T> *> inputs) : inputs(inputs) {
504 for (auto *currentInput : inputs)
505 if (currentInput == nullptr) throw std::logic_error("Null iterator in concatenation");
506 }
507 explicit ConcatEnumerator(Enumerator<Enumerator<T> *> *inputs)
508 : ConcatEnumerator(inputs->toVector()) {}
509
510 [[nodiscard]] std::string toString() const { return "ConcatEnumerator:" + this->stateName(); }
511
512 private:
513 bool advance() {
514 this->state = EnumeratorState::Valid;
515 for (auto *currentInput : inputs) {
516 if (currentInput->moveNext()) {
517 this->currentResult = currentInput->getCurrent();
518 return true;
519 }
520 }
521
522 this->state = EnumeratorState::PastEnd;
523 return false;
524 }
525
526 public:
528 // Too late to add
529 if (this->state == EnumeratorState::PastEnd)
530 throw std::runtime_error("Invalid enumerator state to concatenate");
531
532 inputs.push_back(other);
533
534 return this;
535 }
536
537 void reset() override {
538 for (auto *currentInput : inputs) currentInput->reset();
540 }
541
542 bool moveNext() override {
543 switch (this->state) {
544 case EnumeratorState::NotStarted:
545 case EnumeratorState::Valid:
546 return this->advance();
547 case EnumeratorState::PastEnd:
548 return false;
549 }
550 throw std::runtime_error("Unexpected enumerator state");
551 }
552
553 T getCurrent() const override {
554 switch (this->state) {
555 case EnumeratorState::NotStarted:
556 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
557 case EnumeratorState::PastEnd:
558 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
559 case EnumeratorState::Valid:
560 return this->currentResult;
561 }
562 throw std::runtime_error("Unexpected enumerator state");
563 }
564};
565
567
568template <typename T>
569template <typename Mapper>
571 return new MapEnumerator(this, std::move(map));
572}
573
574template <typename T>
575template <typename S>
577 return new AsEnumerator<T, S>(this);
578}
579
580template <typename T>
581template <typename Filter>
583 return new FilterEnumerator(this, std::move(filter));
584}
585
586template <typename T>
587template <typename Container>
589 return new IteratorEnumerator(data.begin(), data.end(), typeid(Container).name());
590}
591
592template <typename T>
593Enumerator<T> *Enumerator<T>::emptyEnumerator() {
594 return new EmptyEnumerator<T>();
595}
596
597template <typename T>
598template <typename Iter>
599Enumerator<typename Iter::value_type> *Enumerator<T>::createEnumerator(Iter begin, Iter end) {
600 return new IteratorEnumerator(begin, end, "iterator");
601}
602
603template <typename T>
604template <typename Iter>
605Enumerator<typename Iter::value_type> *Enumerator<T>::createEnumerator(iterator_range<Iter> range) {
606 return new IteratorEnumerator(range.begin(), range.end(), "range");
607}
608
609template <typename T>
613
614template <typename T>
616 return new ConcatEnumerator<T>({this, other});
617}
618
620
621template <typename T>
623 if (enumerator == nullptr) throw std::logic_error("Dereferencing end() iterator");
624 return enumerator->getCurrent();
625}
626
627template <typename T>
628const EnumeratorHandle<T> &EnumeratorHandle<T>::operator++() {
629 enumerator->moveNext();
630 return *this;
631}
632
633template <typename T>
634bool EnumeratorHandle<T>::operator!=(const EnumeratorHandle<T> &other) const {
635 if (this->enumerator == other.enumerator) return true;
636 if (other.enumerator != nullptr) throw std::logic_error("Comparison with different iterator");
637 return this->enumerator->state == EnumeratorState::Valid;
638}
639
640template <typename Iter>
641Enumerator<typename std::iterator_traits<Iter>::value_type> *enumerate(Iter begin, Iter end) {
642 return new IteratorEnumerator(begin, end, "iterator");
643}
644
645template <typename Iter>
646Enumerator<typename std::iterator_traits<Iter>::value_type> *enumerate(iterator_range<Iter> range) {
647 return new IteratorEnumerator(range.begin(), range.end(), "range");
648}
649
650template <typename Container>
651Enumerator<typename Container::value_type> *enumerate(const Container &data) {
652 using std::begin;
653 using std::end;
654 return new IteratorEnumerator(begin(data), end(data), typeid(data).name());
655}
656
657// TODO: Flatten ConcatEnumerator's during concatenation
658template <typename T>
659Enumerator<T> *concat(std::initializer_list<Enumerator<T> *> inputs) {
660 return new ConcatEnumerator<T>(inputs);
661}
662
663template <typename... Args>
664auto concat(Args &&...inputs) {
665 using FirstEnumeratorTy =
666 std::remove_pointer_t<std::decay_t<std::tuple_element_t<0, std::tuple<Args...>>>>;
667 std::initializer_list<Enumerator<typename FirstEnumeratorTy::value_type> *> init{
668 std::forward<Args>(inputs)...};
669 return concat(init);
670}
671
672} // namespace P4::Util
673
674#endif /* LIB_ENUMERATOR_H_ */
Casts each element.
Definition enumerator.h:387
S getCurrent() const override
Get current element in the collection.
Definition enumerator.h:424
void reset() override
Move back to the beginning of the collection.
Definition enumerator.h:410
bool moveNext() override
Definition enumerator.h:415
Concatenation.
Definition enumerator.h:491
void reset() override
Move back to the beginning of the collection.
Definition enumerator.h:537
bool moveNext() override
Definition enumerator.h:542
T getCurrent() const override
Get current element in the collection.
Definition enumerator.h:553
Enumerator< T > * concat(Enumerator< T > *other) override
Append all elements of other after all elements of this.
Definition enumerator.h:527
Always empty iterator (equivalent to end())
Definition enumerator.h:295
T getCurrent() const
Get current element in the collection.
Definition enumerator.h:300
bool moveNext()
Always returns false.
Definition enumerator.h:299
Definition enumerator.h:48
Type-erased Enumerator interface.
Definition enumerator.h:68
Enumerator< S > * as()
Cast to an enumerator of S objects.
Definition enumerator.h:576
Enumerator< std::invoke_result_t< Mapper, T > > * map(Mapper map)
Apply specified function to all elements of this enumerator.
Definition enumerator.h:570
virtual bool moveNext()=0
virtual Enumerator< T > * concat(Enumerator< T > *other)
Append all elements of other after all elements of this.
Definition enumerator.h:615
T nextOrDefault()
Next element, or the default value if none exists.
Definition enumerator.h:177
T single()
The only next element; throws if the enumerator does not have exactly 1 element.
Definition enumerator.h:156
virtual void reset()
Move back to the beginning of the collection.
Definition enumerator.h:92
bool any()
True if the enumerator has at least one element.
Definition enumerator.h:153
virtual T getCurrent() const =0
Get current element in the collection.
T next()
Next element; throws if there are no elements.
Definition enumerator.h:184
uint64_t count()
Enumerate all elements and return the count.
Definition enumerator.h:146
T singleOrDefault()
Definition enumerator.h:167
Enumerator< T > * where(Filter filter)
Return an enumerator returning all elements that pass the filter.
Definition enumerator.h:582
static Enumerator< T > * concatAll(Enumerator< Enumerator< T > * > *inputs)
Concatenate all these collections into a single one.
Definition enumerator.h:610
Definition enumerator.h:313
T getCurrent() const
Get current element in the collection.
Definition enumerator.h:355
void reset()
Move back to the beginning of the collection.
Definition enumerator.h:339
bool moveNext()
Definition enumerator.h:344
A generic iterator returning elements of type T.
Definition enumerator.h:199
bool moveNext()
Definition enumerator.h:219
std::iterator_traits< Iter >::value_type getCurrent() const
Get current element in the collection.
Definition enumerator.h:244
Transforms all elements from type T to type S.
Definition enumerator.h:431
void reset()
Move back to the beginning of the collection.
Definition enumerator.h:440
bool moveNext()
Definition enumerator.h:449
S getCurrent() const
Get current element in the collection.
Definition enumerator.h:470
Definition enumerator.h:260
bool moveNext()
Definition enumerator.h:265
T getCurrent() const
Get current element in the collection.
Definition enumerator.h:278
STL namespace.