P4C
The P4 Compiler
Loading...
Searching...
No Matches
enumerator.h
1/*
2Copyright 2013-present Barefoot Networks, Inc.
3
4Licensed under the Apache License, Version 2.0 (the "License");
5you may not use this file except in compliance with the License.
6You may obtain a copy of the License at
7
8 http://www.apache.org/licenses/LICENSE-2.0
9
10Unless required by applicable law or agreed to in writing, software
11distributed under the License is distributed on an "AS IS" BASIS,
12WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13See the License for the specific language governing permissions and
14limitations under the License.
15*/
16
17/* -*-c++-*-
18 C#-like enumerator interface */
19
20#ifndef LIB_ENUMERATOR_H_
21#define LIB_ENUMERATOR_H_
22
#include <cstdint>
#include <functional>
#include <initializer_list>
#include <iterator>
#include <stdexcept>
#include <string>
#include <tuple>
#include <type_traits>
#include <typeinfo>
#include <utility>
#include <vector>

#include "iterator_range.h"
33
34namespace P4::Util {
/// Position of an Enumerator's cursor relative to its sequence.
enum class EnumeratorState {
    NotStarted,  // moveNext() not called yet (this is also the state reset() restores)
    Valid,       // there is a current element
    PastEnd      // the sequence has been exhausted
};

template <class T>
class Enumerator;
39
43// FIXME: It is not a proper iterator (see reference type above) and should be removed
44// in favor of more standard approach. Note that Enumerator<T>::getCurrent() always
45// returns element by value, so more or less suitable only for copyable types that are cheap
46// to copy.
47template <typename T>
48class EnumeratorHandle {
49 private:
50 Enumerator<T> *enumerator = nullptr; // when nullptr it represents end()
51 explicit EnumeratorHandle(Enumerator<T> *enumerator) : enumerator(enumerator) {}
52 friend class Enumerator<T>;
53
54 public:
55 using iterator_category = std::input_iterator_tag;
56 using difference_type = std::ptrdiff_t;
57 using value_type = T;
58 using reference = T;
59 using pointer = void;
60
61 reference operator*() const;
62 const EnumeratorHandle<T> &operator++();
63 bool operator==(const EnumeratorHandle<T> &other) const;
64 bool operator!=(const EnumeratorHandle<T> &other) const;
65};
66
68template <class T>
69class Enumerator {
70 protected:
71 EnumeratorState state = EnumeratorState::NotStarted;
72
73 // This is a weird oddity of C++: this class is a friend of itself with different templates
74 template <class S>
75 friend class Enumerator;
76 static std::vector<T> emptyVector;
77 template <typename S>
78 friend class EnumeratorHandle;
79
80 public:
81 using value_type = T;
82
83 Enumerator() { this->reset(); }
84
85 virtual ~Enumerator() = default;
86
89 virtual bool moveNext() = 0;
91 virtual T getCurrent() const = 0;
93 virtual void reset() { this->state = EnumeratorState::NotStarted; }
94
95 EnumeratorHandle<T> begin() {
96 this->moveNext();
97 return EnumeratorHandle<T>(this);
98 }
99 EnumeratorHandle<T> end() { return EnumeratorHandle<T>(nullptr); }
100
101 const char *stateName() const {
102 switch (this->state) {
103 case EnumeratorState::NotStarted:
104 return "NotStarted";
105 case EnumeratorState::Valid:
106 return "Valid";
107 case EnumeratorState::PastEnd:
108 return "PastEnd";
109 }
110 throw std::logic_error("Unexpected state " + std::to_string(static_cast<int>(this->state)));
111 }
112
114 template <typename Container>
115 [[deprecated(
116 "Use Util::enumerate() instead")]] static Enumerator<typename Container::value_type> *
117 createEnumerator(const Container &data);
118 static Enumerator<T> *emptyEnumerator(); // empty data
119 template <typename Iter>
120 [[deprecated("Use Util::enumerate() instead")]] static Enumerator<typename Iter::value_type> *
121 createEnumerator(Iter begin, Iter end);
122 template <typename Iter>
123 [[deprecated("Use Util::enumerate() instead")]] static Enumerator<typename Iter::value_type> *
124 createEnumerator(iterator_range<Iter> range);
125
127 template <typename Filter>
128 Enumerator<T> *where(Filter filter);
130 template <typename Mapper>
131 Enumerator<std::invoke_result_t<Mapper, T>> *map(Mapper map);
133 template <typename S>
134 Enumerator<S> *as();
136 virtual Enumerator<T> *concat(Enumerator<T> *other);
138 static Enumerator<T> *concatAll(Enumerator<Enumerator<T> *> *inputs);
139
140 std::vector<T> toVector() {
141 std::vector<T> result;
142 while (moveNext()) result.push_back(getCurrent());
143 return result;
144 }
145
147 uint64_t count() {
148 uint64_t found = 0;
149 while (this->moveNext()) found++;
150 return found;
151 }
152
154 bool any() { return this->moveNext(); }
155
157 T single() {
158 bool next = moveNext();
159 if (!next) throw std::logic_error("There is no element for `single()'");
160 T result = getCurrent();
161 next = moveNext();
162 if (next) throw std::logic_error("There are multiple elements when calling `single()'");
163 return result;
164 }
165
169 bool next = moveNext();
170 if (!next) return T{};
171 T result = getCurrent();
172 next = moveNext();
173 if (next) throw std::logic_error("There are multiple elements when calling `single()'");
174 return result;
175 }
176
179 bool next = moveNext();
180 if (!next) return T{};
181 return getCurrent();
182 }
183
185 T next() {
186 bool next = moveNext();
187 if (!next) throw std::logic_error("There is no element for `next()'");
188 return getCurrent();
189 }
190};
191
193// the implementation must be in the header file due to the templates
194
197
199template <typename Iter>
200class IteratorEnumerator : public Enumerator<typename std::iterator_traits<Iter>::value_type> {
201 protected:
202 Iter begin;
203 Iter end;
204 Iter current;
205 const char *name;
206 friend class Enumerator<typename std::iterator_traits<Iter>::value_type>;
207
208 public:
209 IteratorEnumerator(Iter begin, Iter end, const char *name)
210 : Enumerator<typename std::iterator_traits<Iter>::value_type>(),
211 begin(begin),
212 end(end),
213 current(begin),
214 name(name) {}
215
216 [[nodiscard]] std::string toString() const {
217 return std::string(this->name) + ":" + this->stateName();
218 }
219
220 bool moveNext() {
221 switch (this->state) {
222 case EnumeratorState::NotStarted:
223 this->current = this->begin;
224 if (this->current == this->end) {
225 this->state = EnumeratorState::PastEnd;
226 return false;
227 } else {
228 this->state = EnumeratorState::Valid;
229 }
230 return true;
231 case EnumeratorState::PastEnd:
232 return false;
233 case EnumeratorState::Valid:
234 ++this->current;
235 if (this->current == this->end) {
236 this->state = EnumeratorState::PastEnd;
237 return false;
238 }
239 return true;
240 }
241
242 throw std::runtime_error("Unexpected enumerator state");
243 }
244
245 typename std::iterator_traits<Iter>::value_type getCurrent() const {
246 switch (this->state) {
247 case EnumeratorState::NotStarted:
248 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
249 case EnumeratorState::PastEnd:
250 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
251 case EnumeratorState::Valid:
252 return *this->current;
253 }
254 throw std::runtime_error("Unexpected enumerator state");
255 }
256};
257
259
260template <typename T>
261class SingleEnumerator : public Enumerator<T> {
262 T value;
263
264 public:
265 explicit SingleEnumerator(T v) : Enumerator<T>(), value(v) {}
266 bool moveNext() {
267 switch (this->state) {
268 case EnumeratorState::NotStarted:
269 this->state = EnumeratorState::Valid;
270 return true;
271 case EnumeratorState::PastEnd:
272 return false;
273 case EnumeratorState::Valid:
274 this->state = EnumeratorState::PastEnd;
275 return false;
276 }
277 throw std::runtime_error("Unexpected enumerator state");
278 }
279 T getCurrent() const {
280 switch (this->state) {
281 case EnumeratorState::NotStarted:
282 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
283 case EnumeratorState::PastEnd:
284 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
285 case EnumeratorState::Valid:
286 return this->value;
287 }
288 throw std::runtime_error("Unexpected enumerator state");
289 }
290};
291
293
295template <typename T>
296class EmptyEnumerator : public Enumerator<T> {
297 public:
298 [[nodiscard]] std::string toString() const { return "EmptyEnumerator"; }
300 bool moveNext() { return false; }
301 T getCurrent() const {
302 throw std::logic_error("You cannot call 'getCurrent' on an EmptyEnumerator");
303 }
304};
305
307
313template <typename T, typename Filter>
314class FilterEnumerator final : public Enumerator<T> {
315 Enumerator<T> *input;
316 Filter filter;
317 T current; // must prevent repeated evaluation
318
319 public:
320 FilterEnumerator(Enumerator<T> *input, Filter filter)
321 : input(input), filter(std::move(filter)) {}
322
323 private:
324 bool advance() {
325 this->state = EnumeratorState::Valid;
326 while (this->input->moveNext()) {
327 this->current = this->input->getCurrent();
328 bool match = this->filter(this->current);
329 if (match) return true;
330 }
331 this->state = EnumeratorState::PastEnd;
332 return false;
333 }
334
335 public:
336 [[nodiscard]] std::string toString() const {
337 return "FilterEnumerator(" + this->input->toString() + "):" + this->stateName();
338 }
339
340 void reset() {
341 this->input->reset();
343 }
344
345 bool moveNext() {
346 switch (this->state) {
347 case EnumeratorState::NotStarted:
348 case EnumeratorState::Valid:
349 return this->advance();
350 case EnumeratorState::PastEnd:
351 return false;
352 }
353 throw std::runtime_error("Unexpected enumerator state");
354 }
355
356 T getCurrent() const {
357 switch (this->state) {
358 case EnumeratorState::NotStarted:
359 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
360 case EnumeratorState::PastEnd:
361 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
362 case EnumeratorState::Valid:
363 return this->current;
364 }
365 throw std::runtime_error("Unexpected enumerator state");
366 }
367};
368
370
371namespace Detail {
372// See if we can use ICastable interface to cast from T to S. This is only possible if:
373// - Both T and S are pointer types (let's denote T = From* and S = To*)
374// - Expression (From*)()->to<To>() is well-formed
375// Essentially this means the following code is well-formed:
376// From *current = input->getCurrent(); current->to<To>();
377template <typename From, typename To, typename = void>
378static constexpr bool can_be_casted = false;
379
380template <typename From, typename To>
381static constexpr bool
382 can_be_casted<From *, To *, std::void_t<decltype(std::declval<From *>()->template to<To>())>> =
383 true;
384} // namespace Detail
385
387template <typename T, typename S>
388class AsEnumerator final : public Enumerator<S> {
389 template <typename U = S>
390 typename std::enable_if_t<!Detail::can_be_casted<T, S>, U> getCurrentImpl() const {
391 T current = input->getCurrent();
392 return dynamic_cast<S>(current);
393 }
394
395 template <typename U = S>
396 typename std::enable_if_t<Detail::can_be_casted<T, S>, U> getCurrentImpl() const {
397 T current = input->getCurrent();
398 return current->template to<std::remove_pointer_t<S>>();
399 }
400
401 protected:
402 Enumerator<T> *input;
403
404 public:
405 explicit AsEnumerator(Enumerator<T> *input) : input(input) {}
406
407 std::string toString() const {
408 return "AsEnumerator(" + this->input->toString() + "):" + this->stateName();
409 }
410
411 void reset() override {
413 this->input->reset();
414 }
415
416 bool moveNext() override {
417 bool result = this->input->moveNext();
418 if (result)
419 this->state = EnumeratorState::Valid;
420 else
421 this->state = EnumeratorState::PastEnd;
422 return result;
423 }
424
425 S getCurrent() const override { return getCurrentImpl(); }
426};
427
429
431template <typename T, typename S, typename Mapper>
432class MapEnumerator final : public Enumerator<S> {
433 protected:
434 Enumerator<T> *input;
435 Mapper map;
436 S current;
437
438 public:
439 MapEnumerator(Enumerator<T> *input, Mapper map) : input(input), map(std::move(map)) {}
440
441 void reset() {
442 this->input->reset();
444 }
445
446 [[nodiscard]] std::string toString() const {
447 return "MapEnumerator(" + this->input->toString() + "):" + this->stateName();
448 }
449
450 bool moveNext() {
451 switch (this->state) {
452 case EnumeratorState::NotStarted:
453 case EnumeratorState::Valid: {
454 bool success = input->moveNext();
455 if (success) {
456 T currentInput = this->input->getCurrent();
457 this->current = this->map(currentInput);
458 this->state = EnumeratorState::Valid;
459 return true;
460 } else {
461 this->state = EnumeratorState::PastEnd;
462 return false;
463 }
464 }
465 case EnumeratorState::PastEnd:
466 return false;
467 }
468 throw std::runtime_error("Unexpected enumerator state");
469 }
470
471 S getCurrent() const {
472 switch (this->state) {
473 case EnumeratorState::NotStarted:
474 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
475 case EnumeratorState::PastEnd:
476 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
477 case EnumeratorState::Valid:
478 return this->current;
479 }
480 throw std::runtime_error("Unexpected enumerator state");
481 }
482};
483
484template <typename T, typename Mapper>
485MapEnumerator(Enumerator<T> *,
486 Mapper) -> MapEnumerator<T, typename std::invoke_result_t<Mapper, T>, Mapper>;
487
489
491template <typename T>
492class ConcatEnumerator final : public Enumerator<T> {
493 std::vector<Enumerator<T> *> inputs;
494 T currentResult;
495
496 public:
497 ConcatEnumerator() = default;
498 // We take ownership of the vector
499 explicit ConcatEnumerator(std::vector<Enumerator<T> *> &&inputs) : inputs(std::move(inputs)) {
500 for (auto *currentInput : inputs)
501 if (currentInput == nullptr) throw std::logic_error("Null iterator in concatenation");
502 }
503
504 ConcatEnumerator(std::initializer_list<Enumerator<T> *> inputs) : inputs(inputs) {
505 for (auto *currentInput : inputs)
506 if (currentInput == nullptr) throw std::logic_error("Null iterator in concatenation");
507 }
508 explicit ConcatEnumerator(Enumerator<Enumerator<T> *> *inputs)
509 : ConcatEnumerator(inputs->toVector()) {}
510
511 [[nodiscard]] std::string toString() const { return "ConcatEnumerator:" + this->stateName(); }
512
513 private:
514 bool advance() {
515 this->state = EnumeratorState::Valid;
516 for (auto *currentInput : inputs) {
517 if (currentInput->moveNext()) {
518 this->currentResult = currentInput->getCurrent();
519 return true;
520 }
521 }
522
523 this->state = EnumeratorState::PastEnd;
524 return false;
525 }
526
527 public:
528 Enumerator<T> *concat(Enumerator<T> *other) override {
529 // Too late to add
530 if (this->state == EnumeratorState::PastEnd)
531 throw std::runtime_error("Invalid enumerator state to concatenate");
532
533 inputs.push_back(other);
534
535 return this;
536 }
537
538 void reset() override {
539 for (auto *currentInput : inputs) currentInput->reset();
541 }
542
543 bool moveNext() override {
544 switch (this->state) {
545 case EnumeratorState::NotStarted:
546 case EnumeratorState::Valid:
547 return this->advance();
548 case EnumeratorState::PastEnd:
549 return false;
550 }
551 throw std::runtime_error("Unexpected enumerator state");
552 }
553
554 T getCurrent() const override {
555 switch (this->state) {
556 case EnumeratorState::NotStarted:
557 throw std::logic_error("You cannot call 'getCurrent' before 'moveNext'");
558 case EnumeratorState::PastEnd:
559 throw std::logic_error("You cannot call 'getCurrent' past the collection end");
560 case EnumeratorState::Valid:
561 return this->currentResult;
562 }
563 throw std::runtime_error("Unexpected enumerator state");
564 }
565};
566
568
569template <typename T>
570template <typename Mapper>
571Enumerator<std::invoke_result_t<Mapper, T>> *Enumerator<T>::map(Mapper map) {
572 return new MapEnumerator(this, std::move(map));
573}
574
575template <typename T>
576template <typename S>
577Enumerator<S> *Enumerator<T>::as() {
578 return new AsEnumerator<T, S>(this);
579}
580
581template <typename T>
582template <typename Filter>
583Enumerator<T> *Enumerator<T>::where(Filter filter) {
584 return new FilterEnumerator(this, std::move(filter));
585}
586
587template <typename T>
588template <typename Container>
589Enumerator<typename Container::value_type> *Enumerator<T>::createEnumerator(const Container &data) {
590 return new IteratorEnumerator(data.begin(), data.end(), typeid(Container).name());
591}
592
593template <typename T>
594Enumerator<T> *Enumerator<T>::emptyEnumerator() {
595 return new EmptyEnumerator<T>();
596}
597
598template <typename T>
599template <typename Iter>
600Enumerator<typename Iter::value_type> *Enumerator<T>::createEnumerator(Iter begin, Iter end) {
601 return new IteratorEnumerator(begin, end, "iterator");
602}
603
604template <typename T>
605template <typename Iter>
606Enumerator<typename Iter::value_type> *Enumerator<T>::createEnumerator(iterator_range<Iter> range) {
607 return new IteratorEnumerator(range.begin(), range.end(), "range");
608}
609
610template <typename T>
611Enumerator<T> *Enumerator<T>::concatAll(Enumerator<Enumerator<T> *> *inputs) {
612 return new ConcatEnumerator<T>(inputs);
613}
614
615template <typename T>
616Enumerator<T> *Enumerator<T>::concat(Enumerator<T> *other) {
617 return new ConcatEnumerator<T>({this, other});
618}
619
621
622template <typename T>
623T EnumeratorHandle<T>::operator*() const {
624 if (enumerator == nullptr) throw std::logic_error("Dereferencing end() iterator");
625 return enumerator->getCurrent();
626}
627
628template <typename T>
629const EnumeratorHandle<T> &EnumeratorHandle<T>::operator++() {
630 enumerator->moveNext();
631 return *this;
632}
633
634template <typename T>
635bool EnumeratorHandle<T>::operator==(const EnumeratorHandle<T> &other) const {
636 return !(*this != other);
637}
638
639template <typename T>
640bool EnumeratorHandle<T>::operator!=(const EnumeratorHandle<T> &other) const {
641 if (this->enumerator == other.enumerator) return true;
642 if (other.enumerator != nullptr) throw std::logic_error("Comparison with different iterator");
643 return this->enumerator->state == EnumeratorState::Valid;
644}
645
646template <typename Iter>
647Enumerator<typename std::iterator_traits<Iter>::value_type> *enumerate(Iter begin, Iter end) {
648 return new IteratorEnumerator(begin, end, "iterator");
649}
650
651template <typename Iter>
653 return new IteratorEnumerator(range.begin(), range.end(), "range");
654}
655
656template <typename Container>
657Enumerator<typename Container::value_type> *enumerate(const Container &data) {
658 using std::begin;
659 using std::end;
660 return new IteratorEnumerator(begin(data), end(data), typeid(data).name());
661}
662
663// TODO: Flatten ConcatEnumerator's during concatenation
664template <typename T>
665Enumerator<T> *concat(std::initializer_list<Enumerator<T> *> inputs) {
666 return new ConcatEnumerator<T>(inputs);
667}
668
669template <typename... Args>
670auto concat(Args &&...inputs) {
671 using FirstEnumeratorTy =
672 std::remove_pointer_t<std::decay_t<std::tuple_element_t<0, std::tuple<Args...>>>>;
673 std::initializer_list<Enumerator<typename FirstEnumeratorTy::value_type> *> init{
674 std::forward<Args>(inputs)...};
675 return concat(init);
676}
677
678} // namespace P4::Util
679
680#endif /* LIB_ENUMERATOR_H_ */
Casts each element.
Definition enumerator.h:388
S getCurrent() const override
Get current element in the collection.
Definition enumerator.h:425
void reset() override
Move back to the beginning of the collection.
Definition enumerator.h:411
bool moveNext() override
Definition enumerator.h:416
Concatenation.
Definition enumerator.h:492
void reset() override
Move back to the beginning of the collection.
Definition enumerator.h:538
bool moveNext() override
Definition enumerator.h:543
T getCurrent() const override
Get current element in the collection.
Definition enumerator.h:554
Enumerator< T > * concat(Enumerator< T > *other) override
Append all elements of other after all elements of this.
Definition enumerator.h:528
Always empty iterator (equivalent to end())
Definition enumerator.h:296
T getCurrent() const
Get current element in the collection.
Definition enumerator.h:301
bool moveNext()
Always returns false.
Definition enumerator.h:300
Definition enumerator.h:48
Type-erased Enumerator interface.
Definition enumerator.h:69
Enumerator< S > * as()
Cast to an enumerator of S objects.
Definition enumerator.h:577
Enumerator< std::invoke_result_t< Mapper, T > > * map(Mapper map)
Apply specified function to all elements of this enumerator.
Definition enumerator.h:571
virtual bool moveNext()=0
virtual Enumerator< T > * concat(Enumerator< T > *other)
Append all elements of other after all elements of this.
Definition enumerator.h:616
T nextOrDefault()
Next element, or the default value if none exists.
Definition enumerator.h:178
T single()
The only next element; throws if the enumerator does not have exactly 1 element.
Definition enumerator.h:157
virtual void reset()
Move back to the beginning of the collection.
Definition enumerator.h:93
bool any()
True if the enumerator has at least one element.
Definition enumerator.h:154
virtual T getCurrent() const =0
Get current element in the collection.
T next()
Next element; throws if there are no elements.
Definition enumerator.h:185
uint64_t count()
Enumerate all elements and return the count.
Definition enumerator.h:147
T singleOrDefault()
Definition enumerator.h:168
Enumerator< T > * where(Filter filter)
Return an enumerator returning all elements that pass the filter.
Definition enumerator.h:583
static Enumerator< T > * concatAll(Enumerator< Enumerator< T > * > *inputs)
Concatenate all these collections into a single one.
Definition enumerator.h:611
Definition enumerator.h:314
T getCurrent() const
Get current element in the collection.
Definition enumerator.h:356
void reset()
Move back to the beginning of the collection.
Definition enumerator.h:340
bool moveNext()
Definition enumerator.h:345
A generic iterator returning elements of type T.
Definition enumerator.h:200
bool moveNext()
Definition enumerator.h:220
std::iterator_traits< Iter >::value_type getCurrent() const
Get current element in the collection.
Definition enumerator.h:245
Transforms all elements from type T to type S.
Definition enumerator.h:432
void reset()
Move back to the beginning of the collection.
Definition enumerator.h:441
bool moveNext()
Definition enumerator.h:450
S getCurrent() const
Get current element in the collection.
Definition enumerator.h:471
bool moveNext()
Definition enumerator.h:266
T getCurrent() const
Get current element in the collection.
Definition enumerator.h:279
Definition iterator_range.h:44
STL namespace.