P4C
The P4 Compiler
Loading...
Searching...
No Matches
enumerator.h
1/*
2 * SPDX-FileCopyrightText: 2013 Barefoot Networks, Inc.
3 * Copyright 2013-present Barefoot Networks, Inc.
4 *
5 * SPDX-License-Identifier: Apache-2.0
6 */
7
8/* -*-c++-*-
9 C#-like enumerator interface */
10
11#ifndef LIB_ENUMERATOR_H_
12#define LIB_ENUMERATOR_H_
13
#include <cstddef>
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

#include "iterator_range.h"
24
25namespace P4::Util {
/// Lifecycle of an Enumerator.
enum class EnumeratorState {
    NotStarted,  // moveNext() has not been called yet
    Valid,       // positioned on an element; getCurrent() may be called
    PastEnd      // the collection has been exhausted
};

// Forward declaration; EnumeratorHandle below needs to name it.
template <typename T>
class Enumerator;
30
34// FIXME: It is not a proper iterator (see reference type above) and should be removed
35// in favor of more standard approach. Note that Enumerator<T>::getCurrent() always
36// returns element by value, so more or less suitable only for copyable types that are cheap
37// to copy.
38template <typename T>
39class EnumeratorHandle {
40 private:
41 Enumerator<T> *enumerator = nullptr; // when nullptr it represents end()
42 explicit EnumeratorHandle(Enumerator<T> *enumerator) : enumerator(enumerator) {}
43 friend class Enumerator<T>;
44
45 public:
46 using iterator_category = std::input_iterator_tag;
47 using difference_type = std::ptrdiff_t;
48 using value_type = T;
49 using reference = T;
50 using pointer = void;
51
52 reference operator*() const;
53 const EnumeratorHandle<T> &operator++();
54 bool operator==(const EnumeratorHandle<T> &other) const;
55 bool operator!=(const EnumeratorHandle<T> &other) const;
56};
57
59template <class T>
60class Enumerator {
61 protected:
62 EnumeratorState state = EnumeratorState::NotStarted;
63
64 // This is a weird oddity of C++: this class is a friend of itself with different templates
65 template <class S>
66 friend class Enumerator;
67 static std::vector<T> emptyVector;
68 template <typename S>
69 friend class EnumeratorHandle;
70
71 public:
72 using value_type = T;
73
74 Enumerator() { this->reset(); }
75
76 virtual ~Enumerator() = default;
77
80 virtual bool moveNext() = 0;
82 virtual T getCurrent() const = 0;
84 virtual void reset() { this->state = EnumeratorState::NotStarted; }
85
86 EnumeratorHandle<T> begin() {
87 this->moveNext();
88 return EnumeratorHandle<T>(this);
89 }
90 EnumeratorHandle<T> end() { return EnumeratorHandle<T>(nullptr); }
91
92 const char *stateName() const {
93 switch (this->state) {
94 case EnumeratorState::NotStarted:
95 return "NotStarted";
96 case EnumeratorState::Valid:
97 return "Valid";
98 case EnumeratorState::PastEnd:
99 return "PastEnd";
100 }
101 throw std::logic_error("Unexpected state " + std::to_string(static_cast<int>(this->state)));
102 }
103
105 template <typename Container>
106 [[deprecated(
107 "Use Util::enumerate() instead")]] static Enumerator<typename Container::value_type> *
108 createEnumerator(const Container &data);
109 static Enumerator<T> *emptyEnumerator(); // empty data
110 template <typename Iter>
111 [[deprecated("Use Util::enumerate() instead")]] static Enumerator<typename Iter::value_type> *
112 createEnumerator(Iter begin, Iter end);
113 template <typename Iter>
114 [[deprecated("Use Util::enumerate() instead")]] static Enumerator<typename Iter::value_type> *
115 createEnumerator(iterator_range<Iter> range);
116
118 template <typename Filter>
119 Enumerator<T> *where(Filter filter);
121 template <typename Mapper>
122 Enumerator<std::invoke_result_t<Mapper, T>> *map(Mapper map);
124 template <typename S>
125 Enumerator<S> *as();
127 virtual Enumerator<T> *concat(Enumerator<T> *other);
129 static Enumerator<T> *concatAll(Enumerator<Enumerator<T> *> *inputs);
130
131 std::vector<T> toVector() {
132 std::vector<T> result;
133 while (moveNext()) result.push_back(getCurrent());
134 return result;
135 }
136
138 uint64_t count() {
139 uint64_t found = 0;
140 while (this->moveNext()) found++;
141 return found;
142 }
143
145 bool any() { return this->moveNext(); }
146
148 T single() {
149 bool next = moveNext();
150 if (!next) throw std::logic_error("There is no element for `single()'");
151 T result = getCurrent();
152 next = moveNext();
153 if (next) throw std::logic_error("There are multiple elements when calling `single()'");
154 return result;
155 }
156
160 bool next = moveNext();
161 if (!next) return T{};
162 T result = getCurrent();
163 next = moveNext();
164 if (next) throw std::logic_error("There are multiple elements when calling `single()'");
165 return result;
166 }
167
170 bool next = moveNext();
171 if (!next) return T{};
172 return getCurrent();
173 }
174
176 T next() {
177 bool next = moveNext();
178 if (!next) throw std::logic_error("There is no element for `next()'");
179 return getCurrent();
180 }
181};
182
184// the implementation must be in the header file due to the templates
185
188
190template <typename Iter>
191class IteratorEnumerator : public Enumerator<typename std::iterator_traits<Iter>::value_type> {
192 protected:
193 Iter begin;
194 Iter end;
195 Iter current;
196 const char *name;
197 friend class Enumerator<typename std::iterator_traits<Iter>::value_type>;
198
199 public:
200 IteratorEnumerator(Iter begin, Iter end, const char *name)
201 : Enumerator<typename std::iterator_traits<Iter>::value_type>(),
202 begin(begin),
203 end(end),
204 current(begin),
205 name(name) {}
206
207 [[nodiscard]] std::string toString() const {
208 return std::string(this->name) + ":" + this->stateName();
209 }
210
211 bool moveNext() override {
212 switch (this->state) {
213 case EnumeratorState::NotStarted:
214 this->current = this->begin;
215 if (this->current == this->end) {
216 this->state = EnumeratorState::PastEnd;
217 return false;
218 } else {
219 this->state = EnumeratorState::Valid;
220 }
221 return true;
222 case EnumeratorState::PastEnd:
223 return false;
224 case EnumeratorState::Valid:
225 ++this->current;
226 if (this->current == this->end) {
227 this->state = EnumeratorState::PastEnd;
228 return false;
229 }
230 return true;
231 }
232
233 throw std::runtime_error("Unexpected enumerator state");
234 }
235
236 typename std::iterator_traits<Iter>::value_type getCurrent() const override {
237 switch (this->state) {
238 case EnumeratorState::NotStarted:
239 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
240 case EnumeratorState::PastEnd:
241 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
242 case EnumeratorState::Valid:
243 return *this->current;
244 }
245 throw std::runtime_error("Unexpected enumerator state");
246 }
247};
248
249template <typename Iter>
250IteratorEnumerator(Iter begin, Iter end, const char *name) -> IteratorEnumerator<Iter>;
251
253
254template <typename T>
255class SingleEnumerator : public Enumerator<T> {
256 T value;
257
258 public:
259 explicit SingleEnumerator(T v) : Enumerator<T>(), value(v) {}
260 bool moveNext() {
261 switch (this->state) {
262 case EnumeratorState::NotStarted:
263 this->state = EnumeratorState::Valid;
264 return true;
265 case EnumeratorState::PastEnd:
266 return false;
267 case EnumeratorState::Valid:
268 this->state = EnumeratorState::PastEnd;
269 return false;
270 }
271 throw std::runtime_error("Unexpected enumerator state");
272 }
273 T getCurrent() const {
274 switch (this->state) {
275 case EnumeratorState::NotStarted:
276 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
277 case EnumeratorState::PastEnd:
278 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
279 case EnumeratorState::Valid:
280 return this->value;
281 }
282 throw std::runtime_error("Unexpected enumerator state");
283 }
284};
285
287
289template <typename T>
290class EmptyEnumerator : public Enumerator<T> {
291 public:
292 [[nodiscard]] std::string toString() const { return "EmptyEnumerator"; }
294 bool moveNext() override { return false; }
295 T getCurrent() const override {
296 throw std::logic_error("You cannot call 'getCurrent' on an EmptyEnumerator");
297 }
298};
299
301
307template <typename T, typename Filter>
308class FilterEnumerator final : public Enumerator<T> {
309 Enumerator<T> *input;
310 Filter filter;
311 T current; // must prevent repeated evaluation
312
313 public:
314 FilterEnumerator(Enumerator<T> *input, Filter filter)
315 : input(input), filter(std::move(filter)) {}
316
317 private:
318 bool advance() {
319 this->state = EnumeratorState::Valid;
320 while (this->input->moveNext()) {
321 this->current = this->input->getCurrent();
322 bool match = this->filter(this->current);
323 if (match) return true;
324 }
325 this->state = EnumeratorState::PastEnd;
326 return false;
327 }
328
329 public:
330 [[nodiscard]] std::string toString() const {
331 return "FilterEnumerator(" + this->input->toString() + "):" + this->stateName();
332 }
333
334 void reset() {
335 this->input->reset();
337 }
338
339 bool moveNext() {
340 switch (this->state) {
341 case EnumeratorState::NotStarted:
342 case EnumeratorState::Valid:
343 return this->advance();
344 case EnumeratorState::PastEnd:
345 return false;
346 }
347 throw std::runtime_error("Unexpected enumerator state");
348 }
349
350 T getCurrent() const {
351 switch (this->state) {
352 case EnumeratorState::NotStarted:
353 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
354 case EnumeratorState::PastEnd:
355 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
356 case EnumeratorState::Valid:
357 return this->current;
358 }
359 throw std::runtime_error("Unexpected enumerator state");
360 }
361};
362
364
namespace Detail {
// See if we can use ICastable interface to cast from T to S. This is only possible if:
// - Both T and S are pointer types (let's denote T = From* and S = To*)
// - Expression (From*)()->to<To>() is well-formed
// Essentially this means the following code is well-formed:
// From *current = input->getCurrent(); current->to<To>();
//
// Declared `inline constexpr` (rather than `static constexpr`) so that every translation unit
// including this header shares one definition of each instance instead of a private copy.
template <typename From, typename To, typename = void>
inline constexpr bool can_be_casted = false;

template <typename From, typename To>
inline constexpr bool
    can_be_casted<From *, To *, std::void_t<decltype(std::declval<From *>()->template to<To>())>> =
        true;
}  // namespace Detail
379
381template <typename T, typename S>
382class AsEnumerator final : public Enumerator<S> {
383 template <typename U = S>
384 typename std::enable_if_t<!Detail::can_be_casted<T, S>, U> getCurrentImpl() const {
385 T current = input->getCurrent();
386 return dynamic_cast<S>(current);
387 }
388
389 template <typename U = S>
390 typename std::enable_if_t<Detail::can_be_casted<T, S>, U> getCurrentImpl() const {
391 T current = input->getCurrent();
392 return current->template to<std::remove_pointer_t<S>>();
393 }
394
395 protected:
396 Enumerator<T> *input;
397
398 public:
399 explicit AsEnumerator(Enumerator<T> *input) : input(input) {}
400
401 std::string toString() const {
402 return "AsEnumerator(" + this->input->toString() + "):" + this->stateName();
403 }
404
405 void reset() override {
407 this->input->reset();
408 }
409
410 bool moveNext() override {
411 bool result = this->input->moveNext();
412 if (result)
413 this->state = EnumeratorState::Valid;
414 else
415 this->state = EnumeratorState::PastEnd;
416 return result;
417 }
418
419 S getCurrent() const override { return getCurrentImpl(); }
420};
421
423
425template <typename T, typename S, typename Mapper>
426class MapEnumerator final : public Enumerator<S> {
427 protected:
428 Enumerator<T> *input;
429 Mapper map;
430 S current;
431
432 public:
433 MapEnumerator(Enumerator<T> *input, Mapper map) : input(input), map(std::move(map)) {}
434
435 void reset() {
436 this->input->reset();
438 }
439
440 [[nodiscard]] std::string toString() const {
441 return "MapEnumerator(" + this->input->toString() + "):" + this->stateName();
442 }
443
444 bool moveNext() {
445 switch (this->state) {
446 case EnumeratorState::NotStarted:
447 case EnumeratorState::Valid: {
448 bool success = input->moveNext();
449 if (success) {
450 T currentInput = this->input->getCurrent();
451 this->current = this->map(currentInput);
452 this->state = EnumeratorState::Valid;
453 return true;
454 } else {
455 this->state = EnumeratorState::PastEnd;
456 return false;
457 }
458 }
459 case EnumeratorState::PastEnd:
460 return false;
461 }
462 throw std::runtime_error("Unexpected enumerator state");
463 }
464
465 S getCurrent() const {
466 switch (this->state) {
467 case EnumeratorState::NotStarted:
468 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
469 case EnumeratorState::PastEnd:
470 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
471 case EnumeratorState::Valid:
472 return this->current;
473 }
474 throw std::runtime_error("Unexpected enumerator state");
475 }
476};
477
478template <typename T, typename Mapper>
479MapEnumerator(Enumerator<T> *,
480 Mapper) -> MapEnumerator<T, typename std::invoke_result_t<Mapper, T>, Mapper>;
481
483
485template <typename T>
486class ConcatEnumerator final : public Enumerator<T> {
487 std::vector<Enumerator<T> *> inputs;
488 T currentResult;
489
490 public:
491 ConcatEnumerator() = default;
492 // We take ownership of the vector
493 explicit ConcatEnumerator(std::vector<Enumerator<T> *> &&inputs) : inputs(std::move(inputs)) {
494 for (auto *currentInput : inputs)
495 if (currentInput == nullptr) throw std::logic_error("Null iterator in concatenation");
496 }
497
498 ConcatEnumerator(std::initializer_list<Enumerator<T> *> inputs) : inputs(inputs) {
499 for (auto *currentInput : inputs)
500 if (currentInput == nullptr) throw std::logic_error("Null iterator in concatenation");
501 }
502 explicit ConcatEnumerator(Enumerator<Enumerator<T> *> *inputs)
503 : ConcatEnumerator(inputs->toVector()) {}
504
505 [[nodiscard]] std::string toString() const { return "ConcatEnumerator:" + this->stateName(); }
506
507 private:
508 bool advance() {
509 this->state = EnumeratorState::Valid;
510 for (auto *currentInput : inputs) {
511 if (currentInput->moveNext()) {
512 this->currentResult = currentInput->getCurrent();
513 return true;
514 }
515 }
516
517 this->state = EnumeratorState::PastEnd;
518 return false;
519 }
520
521 public:
522 Enumerator<T> *concat(Enumerator<T> *other) override {
523 // Too late to add
524 if (this->state == EnumeratorState::PastEnd)
525 throw std::runtime_error("Invalid enumerator state to concatenate");
526
527 inputs.push_back(other);
528
529 return this;
530 }
531
532 void reset() override {
533 for (auto *currentInput : inputs) currentInput->reset();
535 }
536
537 bool moveNext() override {
538 switch (this->state) {
539 case EnumeratorState::NotStarted:
540 case EnumeratorState::Valid:
541 return this->advance();
542 case EnumeratorState::PastEnd:
543 return false;
544 }
545 throw std::runtime_error("Unexpected enumerator state");
546 }
547
548 T getCurrent() const override {
549 switch (this->state) {
550 case EnumeratorState::NotStarted:
551 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
552 case EnumeratorState::PastEnd:
553 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
554 case EnumeratorState::Valid:
555 return this->currentResult;
556 }
557 throw std::runtime_error("Unexpected enumerator state");
558 }
559};
560
562
563template <typename T>
564template <typename Mapper>
565Enumerator<std::invoke_result_t<Mapper, T>> *Enumerator<T>::map(Mapper map) {
566 return new MapEnumerator(this, std::move(map));
567}
568
569template <typename T>
570template <typename S>
571Enumerator<S> *Enumerator<T>::as() {
572 return new AsEnumerator<T, S>(this);
573}
574
575template <typename T>
576template <typename Filter>
577Enumerator<T> *Enumerator<T>::where(Filter filter) {
578 return new FilterEnumerator(this, std::move(filter));
579}
580
581template <typename T>
582template <typename Container>
583Enumerator<typename Container::value_type> *Enumerator<T>::createEnumerator(const Container &data) {
584 return new IteratorEnumerator(data.begin(), data.end(), typeid(Container).name());
585}
586
587template <typename T>
588Enumerator<T> *Enumerator<T>::emptyEnumerator() {
589 return new EmptyEnumerator<T>();
590}
591
592template <typename T>
593template <typename Iter>
594Enumerator<typename Iter::value_type> *Enumerator<T>::createEnumerator(Iter begin, Iter end) {
595 return new IteratorEnumerator(begin, end, "iterator");
596}
597
598template <typename T>
599template <typename Iter>
600Enumerator<typename Iter::value_type> *Enumerator<T>::createEnumerator(iterator_range<Iter> range) {
601 return new IteratorEnumerator(range.begin(), range.end(), "range");
602}
603
604template <typename T>
605Enumerator<T> *Enumerator<T>::concatAll(Enumerator<Enumerator<T> *> *inputs) {
606 return new ConcatEnumerator<T>(inputs);
607}
608
609template <typename T>
610Enumerator<T> *Enumerator<T>::concat(Enumerator<T> *other) {
611 return new ConcatEnumerator<T>({this, other});
612}
613
615
616template <typename T>
617T EnumeratorHandle<T>::operator*() const {
618 if (enumerator == nullptr) throw std::logic_error("Dereferencing end() iterator");
619 return enumerator->getCurrent();
620}
621
622template <typename T>
623const EnumeratorHandle<T> &EnumeratorHandle<T>::operator++() {
624 enumerator->moveNext();
625 return *this;
626}
627
628template <typename T>
629bool EnumeratorHandle<T>::operator==(const EnumeratorHandle<T> &other) const {
630 return !(*this != other);
631}
632
633template <typename T>
634bool EnumeratorHandle<T>::operator!=(const EnumeratorHandle<T> &other) const {
635 if (this->enumerator == other.enumerator) return true;
636 if (other.enumerator != nullptr) throw std::logic_error("Comparison with different iterator");
637 return this->enumerator->state == EnumeratorState::Valid;
638}
639
640template <typename Iter>
641Enumerator<typename std::iterator_traits<Iter>::value_type> *enumerate(Iter begin, Iter end) {
642 return new IteratorEnumerator(begin, end, "iterator");
643}
644
// Enumerate the elements of an iterator_range; the enumerator is labelled "range".
// NOTE(review): the signature line of this overload is missing from this listing (it presumably
// reads `Enumerator<...> *enumerate(iterator_range<Iter> range) {`) -- confirm against the header.
645template <typename Iter>
647 return new IteratorEnumerator(range.begin(), range.end(), "range");
648}
649
650template <typename Container>
651Enumerator<typename Container::value_type> *enumerate(const Container &data) {
652 using std::begin;
653 using std::end;
654 return new IteratorEnumerator(begin(data), end(data), typeid(data).name());
655}
656
657// TODO: Flatten ConcatEnumerator's during concatenation
658template <typename T>
659Enumerator<T> *concat(std::initializer_list<Enumerator<T> *> inputs) {
660 return new ConcatEnumerator<T>(inputs);
661}
662
// Concatenate enumerators passed as separate arguments, e.g. concat(a, b, c).
// The element type comes from the value_type of the first argument's pointee; every argument is
// converted to Enumerator<value_type>* and the list is handed to the initializer_list overload
// of concat (chosen over this variadic template because it is more specialized).
663template <typename... Args>
664auto concat(Args &&...inputs) {
665 using FirstEnumeratorTy =
666 std::remove_pointer_t<std::decay_t<std::tuple_element_t<0, std::tuple<Args...>>>>;
667 std::initializer_list<Enumerator<typename FirstEnumeratorTy::value_type> *> init{
668 std::forward<Args>(inputs)...};
669 return concat(init);
670}
671
672} // namespace P4::Util
673
674#endif /* LIB_ENUMERATOR_H_ */
Casts each element.
Definition enumerator.h:382
S getCurrent() const override
Get current element in the collection.
Definition enumerator.h:419
void reset() override
Move back to the beginning of the collection.
Definition enumerator.h:405
bool moveNext() override
Definition enumerator.h:410
Concatenation.
Definition enumerator.h:486
void reset() override
Move back to the beginning of the collection.
Definition enumerator.h:532
bool moveNext() override
Definition enumerator.h:537
T getCurrent() const override
Get current element in the collection.
Definition enumerator.h:548
Enumerator< T > * concat(Enumerator< T > *other) override
Append all elements of other after all elements of this.
Definition enumerator.h:522
Always-empty enumerator (equivalent to end())
Definition enumerator.h:290
bool moveNext() override
Always returns false.
Definition enumerator.h:294
T getCurrent() const override
Get current element in the collection.
Definition enumerator.h:295
Definition enumerator.h:39
Type-erased Enumerator interface.
Definition enumerator.h:60
Enumerator< S > * as()
Cast to an enumerator of S objects.
Definition enumerator.h:571
Enumerator< std::invoke_result_t< Mapper, T > > * map(Mapper map)
Apply specified function to all elements of this enumerator.
Definition enumerator.h:565
virtual bool moveNext()=0
virtual Enumerator< T > * concat(Enumerator< T > *other)
Append all elements of other after all elements of this.
Definition enumerator.h:610
T nextOrDefault()
Next element, or the default value if none exists.
Definition enumerator.h:169
T single()
The only next element; throws if the enumerator does not have exactly 1 element.
Definition enumerator.h:148
virtual void reset()
Move back to the beginning of the collection.
Definition enumerator.h:84
bool any()
True if the enumerator has at least one element.
Definition enumerator.h:145
virtual T getCurrent() const =0
Get current element in the collection.
T next()
Next element; throws if there are no elements.
Definition enumerator.h:176
uint64_t count()
Enumerate all elements and return the count.
Definition enumerator.h:138
T singleOrDefault()
Definition enumerator.h:159
Enumerator< T > * where(Filter filter)
Return an enumerator returning all elements that pass the filter.
Definition enumerator.h:577
static Enumerator< T > * concatAll(Enumerator< Enumerator< T > * > *inputs)
Concatenate all these collections into a single one.
Definition enumerator.h:605
Definition enumerator.h:308
T getCurrent() const
Get current element in the collection.
Definition enumerator.h:350
void reset()
Move back to the beginning of the collection.
Definition enumerator.h:334
bool moveNext()
Definition enumerator.h:339
A generic enumerator over an iterator range [begin, end), returning elements of the iterator's value_type.
Definition enumerator.h:191
bool moveNext() override
Definition enumerator.h:211
std::iterator_traits< Iter >::value_type getCurrent() const override
Get current element in the collection.
Definition enumerator.h:236
Transforms all elements from type T to type S.
Definition enumerator.h:426
void reset()
Move back to the beginning of the collection.
Definition enumerator.h:435
bool moveNext()
Definition enumerator.h:444
S getCurrent() const
Get current element in the collection.
Definition enumerator.h:465
bool moveNext()
Definition enumerator.h:260
T getCurrent() const
Get current element in the collection.
Definition enumerator.h:273
Definition iterator_range.h:44
STL namespace.