
In Julia I can use argmax(X) to find the index of the maximum element. If I want to find all elements satisfying a condition C, I can use findall(C, X). But how can I combine the two? What's the most efficient/idiomatic/concise way to find the index of the maximum element satisfying some condition in Julia?

alagris

3 Answers


If you'd like to avoid allocations, filtering the array lazily would work:

idx_filtered = (i for (i, el) in pairs(X) if C(el))
argmax(i -> X[i], idx_filtered)

Unfortunately, this is about twice as slow as a hand-written version. (Edit: in my benchmarks, it is 2x slower on an Intel Xeon Platinum but nearly equal on an Apple M1.)

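function byhand(C, X)
    # Start at the first element satisfying C; return nothing if there is none
    start = findfirst(C, X)
    isnothing(start) && return nothing

    # Track the index and value of the largest matching element seen so far
    # (named xmax to avoid shadowing Base.max)
    imax, xmax = start, X[start]
    for i = start:lastindex(X)
        if C(X[i]) && X[i] > xmax
            imax, xmax = i, X[i]
        end
    end
    return imax, xmax
end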
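The findfirst step can also be folded into the loop itself, so that the "no match" case is handled by the same loop. The following is a hypothetical variant (`byhand_onepass` is an illustrative name, and it has not been benchmarked); like byhand, it returns `(index, value)` or `nothing`:

```julia
# Hypothetical variant of byhand: one loop over eachindex(X), no findfirst call.
# Returns (index, value) of the largest element satisfying C, or `nothing`.
function byhand_onepass(C, X)
    best = nothing                    # (index, value) of the best match so far
    for i in eachindex(X)
        x = X[i]
        # short-circuit: best[2] is only read once a match has been recorded
        if C(x) && (best === nothing || x > best[2])
            best = (i, x)
        end
    end
    return best
end

byhand_onepass(<(0), [5, 4, -3, -5])  # (3, -3)
byhand_onepass(<(0), [1, 2])          # nothing
```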
August

You can store the indices returned by findall, take argmax over the elements at those indices, and map the result back through the stored indices.

X = [5, 4, -3, -5]
C = <(0)

i = findall(C, X);
i[argmax(X[i])]
#3

Or combine both:

argmax(i -> X[i], findall(C, X))
#3

This assumes that findall returns at least one index; otherwise, check for an empty result first, e.g. with isempty.
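A minimal wrapper that performs the isempty check before calling argmax might look like this. The name `argmax_where` is hypothetical, and the `argmax(f, domain)` method used here requires Julia 1.7 or later:

```julia
# Hypothetical helper: index of the largest element of X satisfying C,
# or `nothing` if no element satisfies C.
# argmax(f, domain) returns the element of `domain` that maximizes f (Julia >= 1.7).
function argmax_where(C, X)
    idx = findall(C, X)
    isempty(idx) && return nothing   # argmax would throw on an empty collection
    return argmax(i -> X[i], idx)
end

argmax_where(<(0), [5, 4, -3, -5])  # 3
argmax_where(<(0), [1, 2])          # nothing
```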


Benchmark

#Functions
function August(C, X)
    idx_filtered = (i for (i, el) in pairs(X) if C(el))
    argmax(i -> X[i], idx_filtered)
end

function byhand(C, X)
    start = findfirst(C, X)
    isnothing(start) && return nothing

    imax, max = start, X[start]
    for i = start:lastindex(X)
        if C(X[i]) && X[i] > max
            imax, max = i, X[i]
        end
    end
    imax, max
end

function GKi1(C, X)
    i = findall(C, X);
    i[argmax(X[i])]
end

GKi2(C, X) = argmax(i -> X[i], findall(C, X))
#Data
using Random
Random.seed!(42)
n = 100000
X = randn(n)
C = <(0)
#Benchmark
using BenchmarkTools

suite = BenchmarkGroup()
suite["August"] = @benchmarkable August(C, $X)
suite["byhand"] = @benchmarkable byhand(C, $X)
suite["GKi1"] = @benchmarkable GKi1(C, $X)
suite["GKi2"] = @benchmarkable GKi2(C, $X)

tune!(suite);
results = run(suite)
#Results
results
#4-element BenchmarkTools.BenchmarkGroup:
#  tags: []
#  "August" => Trial(641.061 μs)
#  "byhand" => Trial(261.135 μs)
#  "GKi2" => Trial(259.260 μs)
#  "GKi1" => Trial(339.570 μs)

results.data["August"]
#BenchmarkTools.Trial: 7622 samples with 1 evaluation.
# Range (min … max):  641.061 μs … 861.379 μs  ┊ GC (min … max): 0.00% … 0.00%
# Time  (median):     643.640 μs               ┊ GC (median):    0.00%
# Time  (mean ± σ):   653.027 μs ±  18.123 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%
#
#  ▄█▅▄▃   ▂▂▃▁ ▁▃▃▂▂     ▁▃    ▁▁                               ▁
#  ██████▇████████████▇▆▆▇████▇▆██▇▇▇▆▆▆▅▇▆▅▅▅▅▆██▅▆▆▆▇▆▇▇▆▇▆▆▆▅ █
#  641 μs        Histogram: log(frequency) by time        718 μs <
#
# Memory estimate: 16 bytes, allocs estimate: 1.

results.data["byhand"]
#BenchmarkTools.Trial: 10000 samples with 1 evaluation.
# Range (min … max):  261.135 μs … 621.141 μs  ┊ GC (min … max): 0.00% … 0.00%
# Time  (median):     261.356 μs               ┊ GC (median):    0.00%
# Time  (mean ± σ):   264.382 μs ±  11.638 μs  ┊ GC (mean ± σ):  0.00% ± 0.00%
#
#  █    ▁▁▁▁     ▂      ▁▁      ▂  ▁                        ▁    ▁
#  █▅▂▂▅████▅▄▃▄▆█▇▇▆▄▅███▇▄▄▅▆▆█▄▇█▅▄▅▅▆▇▇▅▄▅▄▄▄▃▄▃▃▃▄▅▆▅▄▇█▆▅▄ █
#  261 μs        Histogram: log(frequency) by time        292 μs <
#
# Memory estimate: 32 bytes, allocs estimate: 1.

results.data["GKi1"]
#BenchmarkTools.Trial: 10000 samples with 1 evaluation.
# Range (min … max):  339.570 μs …  1.447 ms  ┊ GC (min … max): 0.00% … 0.00%
# Time  (median):     342.579 μs              ┊ GC (median):    0.00%
# Time  (mean ± σ):   355.167 μs ± 52.935 μs  ┊ GC (mean ± σ):  1.90% ± 6.85%
#
#  █▆▄▅▃▂▁▁                                                   ▁ ▁
#  ████████▇▆▆▅▅▅▆▄▄▄▄▁▃▁▁▃▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁█ █
#  340 μs        Histogram: log(frequency) by time       722 μs <
#
# Memory estimate: 800.39 KiB, allocs estimate: 11.

results.data["GKi2"]
#BenchmarkTools.Trial: 10000 samples with 1 evaluation.
# Range (min … max):  259.260 μs … 752.773 μs  ┊ GC (min … max): 0.00% … 54.40%
# Time  (median):     260.692 μs               ┊ GC (median):    0.00%
# Time  (mean ± σ):   270.300 μs ±  40.094 μs  ┊ GC (mean ± σ):  1.31% ±  5.60%
#
#  █▁▁▅▄▂▂▄▃▂▁▁▁  ▁                                              ▁
#  █████████████████▇██▆▆▇▆▅▄▆▆▆▄▅▄▆▅▇▇▆▆▅▅▄▅▃▃▅▃▄▁▁▁▃▁▃▃▃▄▃▃▁▃▃ █
#  259 μs        Histogram: log(frequency) by time        390 μs <
#
# Memory estimate: 408.53 KiB, allocs estimate: 9.
versioninfo()
#Julia Version 1.8.0
#Commit 5544a0fab7 (2022-08-17 13:38 UTC)
#Platform Info:
#  OS: Linux (x86_64-linux-gnu)
#  CPU: 8 × Intel(R) Core(TM) i7-2600K CPU @ 3.40GHz
#  WORD_SIZE: 64
#  LIBM: libopenlibm
#  LLVM: libLLVM-13.0.1 (ORCJIT, sandybridge)
#  Threads: 1 on 8 virtual cores

In this example, argmax(i -> X[i], findall(C, X)) performs close to @August's hand-written function, although it uses more memory. It can even be faster when the data is sorted:

sort!(X)
results = run(suite)
#4-element BenchmarkTools.BenchmarkGroup:
#  tags: []
#  "August" => Trial(297.519 μs)
#  "byhand" => Trial(270.486 μs)
#  "GKi2" => Trial(242.320 μs)
#  "GKi1" => Trial(319.732 μs)
GKi
  • Lovely answer! Covers the alternatives, highlights the performance of the simplest approach. Oh, and it gets the right answer! – Ted Dunning Aug 31 '22 at 15:49

From what I understand from your question, you can use findmax() on the result of findall() to find the largest index that satisfies the condition:

julia> v = [10, 20, 30, 40, 50]
5-element Vector{Int64}:
 10
 20
 30
 40
 50

julia> findmax(findall(x -> x > 30, v))[1]
5

Performance of the above function:

julia> v = collect(10:1:10_000_000);

julia> @btime findmax(findall(x -> x > 30, v))[1]
  33.471 ms (10 allocations: 77.49 MiB)
9999991

Update: the alternatives suggested by @dan-getz, last() and findlast(), both perform better than findmax(), and findlast() is the clear winner:

julia> @btime last(findall(x -> x > 30, v))
  19.961 ms (9 allocations: 77.49 MiB)
9999991

julia> @btime findlast(x -> x > 30, v)
  81.422 ns (2 allocations: 32 bytes)

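Update 2: It looks like the OP wanted the maximum element and not only the index. In that case, the solution would be the following (note that it returns the element at the largest matching index, which is the largest matching element only when v is sorted in ascending order):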

julia> v[findmax(findall(x -> x > 30, v))[1]]
50
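AboAmmar's counterexample in the comments can be checked directly. This sketch (variable names are illustrative) contrasts the element at the largest matching index with the largest matching element:

```julia
v = [5, 4, -3, -5]
C = <(0)
idx = findall(C, v)               # indices of the matching elements: [3, 4]

# Element at the largest matching index, as selected by findmax(findall(...))[1]
last_match = v[findmax(idx)[1]]   # v[4] == -5

# Largest matching element: findmax returns (value, position within v[idx])
best_val, k = findmax(v[idx])     # (-3, 1)
best_idx = idx[k]                 # 3, so v[best_idx] == -3
```

The two agree for the sorted vector in the answer above, but not in general.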
  • I think the intention is the maximum element, not index. – August Aug 25 '22 at 21:12
  • @August OP said: _"... way to find maximum index ..."_ – Chirag Anand Aug 25 '22 at 21:21
  • This is not correct. What if `v = [5, 4, -3, -5]` and the condition is `<(0)`? – AboAmmar Aug 26 '22 at 05:27
  • @ChiragAnand sorry for the confusion. I meant the index of the maximum value. The question in the title is the one I actually wanted to solve. – alagris Aug 26 '22 at 09:05
  • But the problem here is that `findall` returns indexes. Finding the maximum index is not the right answer. The question asked for finding the index of the maximum element subject to a constraint. – Ted Dunning Aug 31 '22 at 15:44
  • Yes, I understood the question incorrectly. Nonetheless, I am keeping the answer there just in case it is useful to others. – Chirag Anand Sep 09 '22 at 09:44
  • @AboAmmar The answer works correctly for finding the `index` of the last element for your use case too: `v[findmax(findall(x -> x < 0, v))[1]]` `5` – Chirag Anand Sep 09 '22 at 10:04
  • FWIW, I have added an update to include finding the max element and not only the index. – Chirag Anand Sep 09 '22 at 10:20