Each call to HeadAndTail would have to enumerate the sequence again (unless some sort of caching is used). For example, consider the following:
    var a = HeadAndTail(sequence);
    Console.WriteLine(HeadAndTail(a.Tail).Head);
    // Element #2; the enumerator is at least at #2 now.
    var b = HeadAndTail(sequence);
    Console.WriteLine(b.Head);
    // Element #1; there is no way to get #1 unless we enumerate the sequence again.
For the same reason, HeadAndTail cannot be split into separate Head and Tail methods (unless you accept that even the first call to Tail enumerates the sequence again, although a call to Head has already enumerated it).
Additionally, HeadAndTail should not expose the tail as an IEnumerable, since an IEnumerable can be enumerated multiple times.
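To make the re-enumeration problem concrete, here is a minimal sketch that counts how many elements a lazy source produces when the tail is a plain IEnumerable. The names ReEnumerationDemo, Source and Produced are hypothetical, and Skip(1) only stands in for a naive IEnumerable-based tail:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

static class ReEnumerationDemo {
    // Hypothetical counter: number of elements the source has produced so far.
    public static int Produced;

    // A lazy source that records every element it yields.
    public static IEnumerable<int> Source() {
        for (int i = 1; i <= 3; i++) {
            Produced++;
            yield return i;
        }
    }

    public static void Main() {
        IEnumerable<int> sequence = Source();
        // A naive IEnumerable-based tail: nothing is enumerated yet.
        IEnumerable<int> tail = sequence.Skip(1);

        Console.WriteLine(sequence.First()); // 1 (first pass over the source)
        Console.WriteLine(tail.First());     // 2 (second pass, starting from the beginning)
        Console.WriteLine(Produced);         // 3: one element for the first pass, two for the second
    }
}
```

Reading the head and then the head of the tail touches only two distinct elements, yet the source produces three, because the second pass starts over from the first element.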
This leaves only one option: HeadAndTail should return the tail as an IEnumerator and, to make things more obvious, it should accept an IEnumerator as well (we are just moving the call to GetEnumerator from inside HeadAndTail to the outside, to emphasize that the enumerator is for one-time use only).
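The one-time-use nature of an IEnumerator can be illustrated with a small sketch: two variables that refer to the same enumerator share a single position, so advancing one advances the other. OneTimeUseDemo and CurrentAfterAliasAdvance are hypothetical names used only for this illustration:

```csharp
using System;
using System.Collections.Generic;

static class OneTimeUseDemo {
    // Reads the first element, then advances through a second reference
    // to the same enumerator and reports what the first reference sees.
    public static string CurrentAfterAliasAdvance(IEnumerable<string> source) {
        using (IEnumerator<string> enumerator = source.GetEnumerator()) {
            enumerator.MoveNext();                   // positioned at the first element
            IEnumerator<string> alias = enumerator;  // same object, not a copy
            alias.MoveNext();                        // advances the shared position
            return enumerator.Current;               // now the second element
        }
    }

    public static void Main() {
        Console.WriteLine(CurrentAfterAliasAdvance(new[] { "a", "b", "c" })); // b
    }
}
```

This is why a head has to be copied out of the enumerator before anyone touches the tail: once the tail is advanced, the earlier position is gone.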
Now that we have worked out the requirements, the implementation is pretty straightforward:
    // Pairs the first element with the enumerator that will yield the remaining elements.
    class HeadAndTail<T> {
        public readonly T Head;
        public readonly IEnumerator<T> Tail;
        public HeadAndTail(T head, IEnumerator<T> tail) {
            Head = head;
            Tail = tail;
        }
    }
    static class IEnumeratorExtensions {
        // Returns null when the enumerator has no more elements.
        public static HeadAndTail<T> HeadAndTail<T>(this IEnumerator<T> enumerator) {
            if (!enumerator.MoveNext()) return null;
            return new HeadAndTail<T>(enumerator.Current, enumerator);
        }
    }
And now it can be used like this:
    Console.WriteLine(sequence.GetEnumerator().HeadAndTail().Tail.HeadAndTail().Head);
    // Element #2
Or in recursive functions like this:
    TResult FoldR<TSource, TResult>(
        IEnumerator<TSource> sequence,
        TResult seed,
        Func<TSource, TResult, TResult> f
    ) {
        var headAndTail = sequence.HeadAndTail();
        if (headAndTail == null) return seed; // end of sequence
        return f(headAndTail.Head, FoldR(headAndTail.Tail, seed, f));
    }
    int Sum(IEnumerator<int> sequence) {
        return FoldR(sequence, 0, (x, y) => x + y);
    }
    var numbers = Enumerable.Range(1, 5);
    Console.WriteLine(Sum(numbers.GetEnumerator())); // 1+(2+(3+(4+(5+0)))) = 15
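The same pattern should work for other recursive functions. As a sketch only, here is a left fold built on the same extension; FoldL and FoldLSketch are hypothetical names, and the HeadAndTail type and extension method are repeated so that the block compiles on its own:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

// Same shape as the HeadAndTail type above, repeated so the sketch is self-contained.
class HeadAndTail<T> {
    public readonly T Head;
    public readonly IEnumerator<T> Tail;
    public HeadAndTail(T head, IEnumerator<T> tail) {
        Head = head;
        Tail = tail;
    }
}

static class FoldLSketch {
    // Returns null when the enumerator has no more elements.
    public static HeadAndTail<T> HeadAndTail<T>(this IEnumerator<T> enumerator) {
        if (!enumerator.MoveNext()) return null;
        return new HeadAndTail<T>(enumerator.Current, enumerator);
    }

    // Hypothetical left fold: combines the accumulator with the head first,
    // then recurses on the tail with the updated accumulator.
    public static TResult FoldL<TSource, TResult>(
        IEnumerator<TSource> sequence,
        TResult seed,
        Func<TResult, TSource, TResult> f
    ) {
        var headAndTail = sequence.HeadAndTail();
        if (headAndTail == null) return seed;
        return FoldL(headAndTail.Tail, f(seed, headAndTail.Head), f);
    }

    public static void Main() {
        var numbers = Enumerable.Range(1, 3);
        // Building a string makes the grouping visible.
        Console.WriteLine(FoldL(numbers.GetEnumerator(), "0", (acc, x) => "(" + acc + "+" + x + ")"));
        // (((0+1)+2)+3)
    }
}
```

Unlike FoldR, this version applies f before the recursive call, so the parentheses nest to the left.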