P4C
The P4 Compiler
Loading...
Searching...
No Matches
enumerator.h
1/*
2Copyright 2013-present Barefoot Networks, Inc.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17/* -*-c++-*-
18 C#-like enumerator interface */
19
20#ifndef LIB_ENUMERATOR_H_
21#define LIB_ENUMERATOR_H_
22
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

#include "iterator_range.h"
33
34namespace P4::Util {
/// Lifecycle of an Enumerator.
enum class EnumeratorState {
    NotStarted,  // moveNext() not yet called (or reset() was called)
    Valid,       // positioned on an element; getCurrent() may be called
    PastEnd,     // moveNext() has returned false
};
36
template <typename T>
class Enumerator;

/// Adapter that lets an Enumerator<T> drive a range-based for loop.
///
/// Dereferencing yields the element by value: `reference` is `T`, not `T &`.
// FIXME: It is not a proper iterator (see the `reference` alias below) and should be
// removed in favor of a more standard approach. Note that Enumerator<T>::getCurrent()
// always returns the element by value, so this is more or less suitable only for
// copyable types that are cheap to copy.
template <typename T>
class EnumeratorHandle {
 public:
    using iterator_category = std::input_iterator_tag;
    using difference_type = std::ptrdiff_t;
    using value_type = T;
    using reference = T;  // by value, matching Enumerator<T>::getCurrent()
    using pointer = void;

    /// The enumerator's current element.
    reference operator*() const;
    /// Advance the underlying enumerator.
    const EnumeratorHandle<T> &operator++();
    bool operator==(const EnumeratorHandle<T> &other) const;
    bool operator!=(const EnumeratorHandle<T> &other) const;

 private:
    // Handles are created by Enumerator<T> (e.g. begin()/end()).
    friend class Enumerator<T>;
    explicit EnumeratorHandle(Enumerator<T> *enumerator) : enumerator(enumerator) {}

    /// The enumerator being walked; nullptr represents end().
    Enumerator<T> *enumerator = nullptr;
};
66
68template <class T>
69class Enumerator {
70 protected:
71 EnumeratorState state = EnumeratorState::NotStarted;
72
73 // This is a weird oddity of C++: this class is a friend of itself with different templates
74 template <class S>
75 friend class Enumerator;
76 static std::vector<T> emptyVector;
77 template <typename S>
78 friend class EnumeratorHandle;
79
80 public:
81 using value_type = T;
82
83 Enumerator() { this->reset(); }
84
85 virtual ~Enumerator() = default;
86
89 virtual bool moveNext() = 0;
91 virtual T getCurrent() const = 0;
93 virtual void reset() { this->state = EnumeratorState::NotStarted; }
94
95 EnumeratorHandle<T> begin() {
96 this->moveNext();
97 return EnumeratorHandle<T>(this);
98 }
99 EnumeratorHandle<T> end() { return EnumeratorHandle<T>(nullptr); }
100
101 const char *stateName() const {
102 switch (this->state) {
103 case EnumeratorState::NotStarted:
104 return "NotStarted";
105 case EnumeratorState::Valid:
106 return "Valid";
107 case EnumeratorState::PastEnd:
108 return "PastEnd";
109 }
110 throw std::logic_error("Unexpected state " + std::to_string(static_cast<int>(this->state)));
111 }
112
114 template <typename Container>
115 [[deprecated(
116 "Use Util::enumerate() instead")]] static Enumerator<typename Container::value_type> *
117 createEnumerator(const Container &data);
118 static Enumerator<T> *emptyEnumerator(); // empty data
119 template <typename Iter>
120 [[deprecated("Use Util::enumerate() instead")]] static Enumerator<typename Iter::value_type> *
121 createEnumerator(Iter begin, Iter end);
122 template <typename Iter>
123 [[deprecated("Use Util::enumerate() instead")]] static Enumerator<typename Iter::value_type> *
124 createEnumerator(iterator_range<Iter> range);
125
127 template <typename Filter>
128 Enumerator<T> *where(Filter filter);
130 template <typename Mapper>
131 Enumerator<std::invoke_result_t<Mapper, T>> *map(Mapper map);
133 template <typename S>
134 Enumerator<S> *as();
136 virtual Enumerator<T> *concat(Enumerator<T> *other);
138 static Enumerator<T> *concatAll(Enumerator<Enumerator<T> *> *inputs);
139
140 std::vector<T> toVector() {
141 std::vector<T> result;
142 while (moveNext()) result.push_back(getCurrent());
143 return result;
144 }
145
147 uint64_t count() {
148 uint64_t found = 0;
149 while (this->moveNext()) found++;
150 return found;
151 }
152
154 bool any() { return this->moveNext(); }
155
157 T single() {
158 bool next = moveNext();
159 if (!next) throw std::logic_error("There is no element for `single()'");
160 T result = getCurrent();
161 next = moveNext();
162 if (next) throw std::logic_error("There are multiple elements when calling `single()'");
163 return result;
164 }
165
169 bool next = moveNext();
170 if (!next) return T{};
171 T result = getCurrent();
172 next = moveNext();
173 if (next) throw std::logic_error("There are multiple elements when calling `single()'");
174 return result;
175 }
176
179 bool next = moveNext();
180 if (!next) return T{};
181 return getCurrent();
182 }
183
185 T next() {
186 bool next = moveNext();
187 if (!next) throw std::logic_error("There is no element for `next()'");
188 return getCurrent();
189 }
190};
191
193// the implementation must be in the header file due to the templates
194
197
199template <typename Iter>
200class IteratorEnumerator : public Enumerator<typename std::iterator_traits<Iter>::value_type> {
201 protected:
202 Iter begin;
203 Iter end;
204 Iter current;
205 const char *name;
206 friend class Enumerator<typename std::iterator_traits<Iter>::value_type>;
207
208 public:
209 IteratorEnumerator(Iter begin, Iter end, const char *name)
210 : Enumerator<typename std::iterator_traits<Iter>::value_type>(),
211 begin(begin),
212 end(end),
213 current(begin),
214 name(name) {}
215
216 [[nodiscard]] std::string toString() const {
217 return std::string(this->name) + ":" + this->stateName();
218 }
219
220 bool moveNext() override {
221 switch (this->state) {
222 case EnumeratorState::NotStarted:
223 this->current = this->begin;
224 if (this->current == this->end) {
225 this->state = EnumeratorState::PastEnd;
226 return false;
227 } else {
228 this->state = EnumeratorState::Valid;
229 }
230 return true;
231 case EnumeratorState::PastEnd:
232 return false;
233 case EnumeratorState::Valid:
234 ++this->current;
235 if (this->current == this->end) {
236 this->state = EnumeratorState::PastEnd;
237 return false;
238 }
239 return true;
240 }
241
242 throw std::runtime_error("Unexpected enumerator state");
243 }
244
245 typename std::iterator_traits<Iter>::value_type getCurrent() const override {
246 switch (this->state) {
247 case EnumeratorState::NotStarted:
248 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
249 case EnumeratorState::PastEnd:
250 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
251 case EnumeratorState::Valid:
252 return *this->current;
253 }
254 throw std::runtime_error("Unexpected enumerator state");
255 }
256};
257
258template <typename Iter>
259IteratorEnumerator(Iter begin, Iter end, const char *name) -> IteratorEnumerator<Iter>;
260
262
263template <typename T>
264class SingleEnumerator : public Enumerator<T> {
265 T value;
266
267 public:
268 explicit SingleEnumerator(T v) : Enumerator<T>(), value(v) {}
269 bool moveNext() {
270 switch (this->state) {
271 case EnumeratorState::NotStarted:
272 this->state = EnumeratorState::Valid;
273 return true;
274 case EnumeratorState::PastEnd:
275 return false;
276 case EnumeratorState::Valid:
277 this->state = EnumeratorState::PastEnd;
278 return false;
279 }
280 throw std::runtime_error("Unexpected enumerator state");
281 }
282 T getCurrent() const {
283 switch (this->state) {
284 case EnumeratorState::NotStarted:
285 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
286 case EnumeratorState::PastEnd:
287 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
288 case EnumeratorState::Valid:
289 return this->value;
290 }
291 throw std::runtime_error("Unexpected enumerator state");
292 }
293};
294
296
298template <typename T>
299class EmptyEnumerator : public Enumerator<T> {
300 public:
301 [[nodiscard]] std::string toString() const { return "EmptyEnumerator"; }
303 bool moveNext() override { return false; }
304 T getCurrent() const override {
305 throw std::logic_error("You cannot call 'getCurrent' on an EmptyEnumerator");
306 }
307};
308
310
316template <typename T, typename Filter>
317class FilterEnumerator final : public Enumerator<T> {
318 Enumerator<T> *input;
319 Filter filter;
320 T current; // must prevent repeated evaluation
321
322 public:
323 FilterEnumerator(Enumerator<T> *input, Filter filter)
324 : input(input), filter(std::move(filter)) {}
325
326 private:
327 bool advance() {
328 this->state = EnumeratorState::Valid;
329 while (this->input->moveNext()) {
330 this->current = this->input->getCurrent();
331 bool match = this->filter(this->current);
332 if (match) return true;
333 }
334 this->state = EnumeratorState::PastEnd;
335 return false;
336 }
337
338 public:
339 [[nodiscard]] std::string toString() const {
340 return "FilterEnumerator(" + this->input->toString() + "):" + this->stateName();
341 }
342
343 void reset() {
344 this->input->reset();
346 }
347
348 bool moveNext() {
349 switch (this->state) {
350 case EnumeratorState::NotStarted:
351 case EnumeratorState::Valid:
352 return this->advance();
353 case EnumeratorState::PastEnd:
354 return false;
355 }
356 throw std::runtime_error("Unexpected enumerator state");
357 }
358
359 T getCurrent() const {
360 switch (this->state) {
361 case EnumeratorState::NotStarted:
362 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
363 case EnumeratorState::PastEnd:
364 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
365 case EnumeratorState::Valid:
366 return this->current;
367 }
368 throw std::runtime_error("Unexpected enumerator state");
369 }
370};
371
373
374namespace Detail {
375// See if we can use ICastable interface to cast from T to S. This is only possible if:
376// - Both T and S are pointer types (let's denote T = From* and S = To*)
377// - Expression (From*)()->to<To>() is well-formed
378// Essentially this means the following code is well-formed:
379// From *current = input->getCurrent(); current->to<To>();
380template <typename From, typename To, typename = void>
381static constexpr bool can_be_casted = false;
382
383template <typename From, typename To>
384static constexpr bool
385 can_be_casted<From *, To *, std::void_t<decltype(std::declval<From *>()->template to<To>())>> =
386 true;
387} // namespace Detail
388
390template <typename T, typename S>
391class AsEnumerator final : public Enumerator<S> {
392 template <typename U = S>
393 typename std::enable_if_t<!Detail::can_be_casted<T, S>, U> getCurrentImpl() const {
394 T current = input->getCurrent();
395 return dynamic_cast<S>(current);
396 }
397
398 template <typename U = S>
399 typename std::enable_if_t<Detail::can_be_casted<T, S>, U> getCurrentImpl() const {
400 T current = input->getCurrent();
401 return current->template to<std::remove_pointer_t<S>>();
402 }
403
404 protected:
405 Enumerator<T> *input;
406
407 public:
408 explicit AsEnumerator(Enumerator<T> *input) : input(input) {}
409
410 std::string toString() const {
411 return "AsEnumerator(" + this->input->toString() + "):" + this->stateName();
412 }
413
414 void reset() override {
416 this->input->reset();
417 }
418
419 bool moveNext() override {
420 bool result = this->input->moveNext();
421 if (result)
422 this->state = EnumeratorState::Valid;
423 else
424 this->state = EnumeratorState::PastEnd;
425 return result;
426 }
427
428 S getCurrent() const override { return getCurrentImpl(); }
429};
430
432
434template <typename T, typename S, typename Mapper>
435class MapEnumerator final : public Enumerator<S> {
436 protected:
437 Enumerator<T> *input;
438 Mapper map;
439 S current;
440
441 public:
442 MapEnumerator(Enumerator<T> *input, Mapper map) : input(input), map(std::move(map)) {}
443
444 void reset() {
445 this->input->reset();
447 }
448
449 [[nodiscard]] std::string toString() const {
450 return "MapEnumerator(" + this->input->toString() + "):" + this->stateName();
451 }
452
453 bool moveNext() {
454 switch (this->state) {
455 case EnumeratorState::NotStarted:
456 case EnumeratorState::Valid: {
457 bool success = input->moveNext();
458 if (success) {
459 T currentInput = this->input->getCurrent();
460 this->current = this->map(currentInput);
461 this->state = EnumeratorState::Valid;
462 return true;
463 } else {
464 this->state = EnumeratorState::PastEnd;
465 return false;
466 }
467 }
468 case EnumeratorState::PastEnd:
469 return false;
470 }
471 throw std::runtime_error("Unexpected enumerator state");
472 }
473
474 S getCurrent() const {
475 switch (this->state) {
476 case EnumeratorState::NotStarted:
477 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
478 case EnumeratorState::PastEnd:
479 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
480 case EnumeratorState::Valid:
481 return this->current;
482 }
483 throw std::runtime_error("Unexpected enumerator state");
484 }
485};
486
487template <typename T, typename Mapper>
488MapEnumerator(Enumerator<T> *,
489 Mapper) -> MapEnumerator<T, typename std::invoke_result_t<Mapper, T>, Mapper>;
490
492
494template <typename T>
495class ConcatEnumerator final : public Enumerator<T> {
496 std::vector<Enumerator<T> *> inputs;
497 T currentResult;
498
499 public:
500 ConcatEnumerator() = default;
501 // We take ownership of the vector
502 explicit ConcatEnumerator(std::vector<Enumerator<T> *> &&inputs) : inputs(std::move(inputs)) {
503 for (auto *currentInput : inputs)
504 if (currentInput == nullptr) throw std::logic_error("Null iterator in concatenation");
505 }
506
507 ConcatEnumerator(std::initializer_list<Enumerator<T> *> inputs) : inputs(inputs) {
508 for (auto *currentInput : inputs)
509 if (currentInput == nullptr) throw std::logic_error("Null iterator in concatenation");
510 }
511 explicit ConcatEnumerator(Enumerator<Enumerator<T> *> *inputs)
512 : ConcatEnumerator(inputs->toVector()) {}
513
514 [[nodiscard]] std::string toString() const { return "ConcatEnumerator:" + this->stateName(); }
515
516 private:
517 bool advance() {
518 this->state = EnumeratorState::Valid;
519 for (auto *currentInput : inputs) {
520 if (currentInput->moveNext()) {
521 this->currentResult = currentInput->getCurrent();
522 return true;
523 }
524 }
525
526 this->state = EnumeratorState::PastEnd;
527 return false;
528 }
529
530 public:
531 Enumerator<T> *concat(Enumerator<T> *other) override {
532 // Too late to add
533 if (this->state == EnumeratorState::PastEnd)
534 throw std::runtime_error("Invalid enumerator state to concatenate");
535
536 inputs.push_back(other);
537
538 return this;
539 }
540
541 void reset() override {
542 for (auto *currentInput : inputs) currentInput->reset();
544 }
545
546 bool moveNext() override {
547 switch (this->state) {
548 case EnumeratorState::NotStarted:
549 case EnumeratorState::Valid:
550 return this->advance();
551 case EnumeratorState::PastEnd:
552 return false;
553 }
554 throw std::runtime_error("Unexpected enumerator state");
555 }
556
557 T getCurrent() const override {
558 switch (this->state) {
559 case EnumeratorState::NotStarted:
560 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
561 case EnumeratorState::PastEnd:
562 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
563 case EnumeratorState::Valid:
564 return this->currentResult;
565 }
566 throw std::runtime_error("Unexpected enumerator state");
567 }
568};
569
571
572template <typename T>
573template <typename Mapper>
574Enumerator<std::invoke_result_t<Mapper, T>> *Enumerator<T>::map(Mapper map) {
575 return new MapEnumerator(this, std::move(map));
576}
577
578template <typename T>
579template <typename S>
580Enumerator<S> *Enumerator<T>::as() {
581 return new AsEnumerator<T, S>(this);
582}
583
584template <typename T>
585template <typename Filter>
586Enumerator<T> *Enumerator<T>::where(Filter filter) {
587 return new FilterEnumerator(this, std::move(filter));
588}
589
590template <typename T>
591template <typename Container>
592Enumerator<typename Container::value_type> *Enumerator<T>::createEnumerator(const Container &data) {
593 return new IteratorEnumerator(data.begin(), data.end(), typeid(Container).name());
594}
595
596template <typename T>
597Enumerator<T> *Enumerator<T>::emptyEnumerator() {
598 return new EmptyEnumerator<T>();
599}
600
601template <typename T>
602template <typename Iter>
603Enumerator<typename Iter::value_type> *Enumerator<T>::createEnumerator(Iter begin, Iter end) {
604 return new IteratorEnumerator(begin, end, "iterator");
605}
606
607template <typename T>
608template <typename Iter>
609Enumerator<typename Iter::value_type> *Enumerator<T>::createEnumerator(iterator_range<Iter> range) {
610 return new IteratorEnumerator(range.begin(), range.end(), "range");
611}
612
613template <typename T>
614Enumerator<T> *Enumerator<T>::concatAll(Enumerator<Enumerator<T> *> *inputs) {
615 return new ConcatEnumerator<T>(inputs);
616}
617
618template <typename T>
619Enumerator<T> *Enumerator<T>::concat(Enumerator<T> *other) {
620 return new ConcatEnumerator<T>({this, other});
621}
622
624
625template <typename T>
626T EnumeratorHandle<T>::operator*() const {
627 if (enumerator == nullptr) throw std::logic_error("Dereferencing end() iterator");
628 return enumerator->getCurrent();
629}
630
631template <typename T>
632const EnumeratorHandle<T> &EnumeratorHandle<T>::operator++() {
633 enumerator->moveNext();
634 return *this;
635}
636
637template <typename T>
638bool EnumeratorHandle<T>::operator==(const EnumeratorHandle<T> &other) const {
639 return !(*this != other);
640}
641
642template <typename T>
643bool EnumeratorHandle<T>::operator!=(const EnumeratorHandle<T> &other) const {
644 if (this->enumerator == other.enumerator) return true;
645 if (other.enumerator != nullptr) throw std::logic_error("Comparison with different iterator");
646 return this->enumerator->state == EnumeratorState::Valid;
647}
648
649template <typename Iter>
650Enumerator<typename std::iterator_traits<Iter>::value_type> *enumerate(Iter begin, Iter end) {
651 return new IteratorEnumerator(begin, end, "iterator");
652}
653
654template <typename Iter>
656 return new IteratorEnumerator(range.begin(), range.end(), "range");
657}
658
659template <typename Container>
660Enumerator<typename Container::value_type> *enumerate(const Container &data) {
661 using std::begin;
662 using std::end;
663 return new IteratorEnumerator(begin(data), end(data), typeid(data).name());
664}
665
666// TODO: Flatten ConcatEnumerator's during concatenation
667template <typename T>
668Enumerator<T> *concat(std::initializer_list<Enumerator<T> *> inputs) {
669 return new ConcatEnumerator<T>(inputs);
670}
671
672template <typename... Args>
673auto concat(Args &&...inputs) {
674 using FirstEnumeratorTy =
675 std::remove_pointer_t<std::decay_t<std::tuple_element_t<0, std::tuple<Args...>>>>;
676 std::initializer_list<Enumerator<typename FirstEnumeratorTy::value_type> *> init{
677 std::forward<Args>(inputs)...};
678 return concat(init);
679}
680
681} // namespace P4::Util
682
683#endif /* LIB_ENUMERATOR_H_ */
Casts each element.
Definition enumerator.h:391
S getCurrent() const override
Get current element in the collection.
Definition enumerator.h:428
void reset() override
Move back to the beginning of the collection.
Definition enumerator.h:414
bool moveNext() override
Definition enumerator.h:419
Concatenation.
Definition enumerator.h:495
void reset() override
Move back to the beginning of the collection.
Definition enumerator.h:541
bool moveNext() override
Definition enumerator.h:546
T getCurrent() const override
Get current element in the collection.
Definition enumerator.h:557
Enumerator< T > * concat(Enumerator< T > *other) override
Append all elements of other after all elements of this.
Definition enumerator.h:531
Always empty iterator (equivalent to end())
Definition enumerator.h:299
bool moveNext() override
Always returns false.
Definition enumerator.h:303
T getCurrent() const override
Get current element in the collection.
Definition enumerator.h:304
Definition enumerator.h:48
Type-erased Enumerator interface.
Definition enumerator.h:69
Enumerator< S > * as()
Cast to an enumerator of S objects.
Definition enumerator.h:580
Enumerator< std::invoke_result_t< Mapper, T > > * map(Mapper map)
Apply specified function to all elements of this enumerator.
Definition enumerator.h:574
virtual bool moveNext()=0
virtual Enumerator< T > * concat(Enumerator< T > *other)
Append all elements of other after all elements of this.
Definition enumerator.h:619
T nextOrDefault()
Next element, or the default value if none exists.
Definition enumerator.h:178
T single()
The only next element; throws if the enumerator does not have exactly 1 element.
Definition enumerator.h:157
virtual void reset()
Move back to the beginning of the collection.
Definition enumerator.h:93
bool any()
True if the enumerator has at least one element.
Definition enumerator.h:154
virtual T getCurrent() const =0
Get current element in the collection.
T next()
Next element; throws if there are no elements.
Definition enumerator.h:185
uint64_t count()
Enumerate all elements and return the count.
Definition enumerator.h:147
T singleOrDefault()
Definition enumerator.h:168
Enumerator< T > * where(Filter filter)
Return an enumerator returning all elements that pass the filter.
Definition enumerator.h:586
static Enumerator< T > * concatAll(Enumerator< Enumerator< T > * > *inputs)
Concatenate all these collections into a single one.
Definition enumerator.h:614
Definition enumerator.h:317
T getCurrent() const
Get current element in the collection.
Definition enumerator.h:359
void reset()
Move back to the beginning of the collection.
Definition enumerator.h:343
bool moveNext()
Definition enumerator.h:348
A generic iterator returning elements of type T.
Definition enumerator.h:200
bool moveNext() override
Definition enumerator.h:220
std::iterator_traits< Iter >::value_type getCurrent() const override
Get current element in the collection.
Definition enumerator.h:245
Transforms all elements from type T to type S.
Definition enumerator.h:435
void reset()
Move back to the beginning of the collection.
Definition enumerator.h:444
bool moveNext()
Definition enumerator.h:453
S getCurrent() const
Get current element in the collection.
Definition enumerator.h:474
bool moveNext()
Definition enumerator.h:269
T getCurrent() const
Get current element in the collection.
Definition enumerator.h:282
Definition iterator_range.h:44
STL namespace.