You can optimise the calculate method by noting that if c % 2 == 0 is false, then c % 2 != 0 must be true, so there is no need to test it a second time. You can also use the fact that when c is odd, c * 3 + 1 must be an even number, so you can calculate (c * 3 + 1) / 2 directly and add two to numSteps. You can use a loop instead of recursion, as Java doesn't have tail-call optimisation.
You can get a bigger improvement by using memoisation: for each number, remember the result you get, and if the number has been calculated before, just return that value. You might want to place an upper bound on memoisation, e.g. no higher than the last number you want to calculate. If you don't do this, the cache can grow enormous, because some of the intermediate values will be many times larger than the largest starting value.
For your interest, here is an implementation that uses these optimisations:
/**
 * Searches for the starting value with the longest Collatz sequence that can be found in about
 * one minute, counting steps with an iterative loop and a memo table indexed by starting value.
 */
public class Collatz {
    /**
     * Memoised step counts indexed by starting value; 0 means "not computed yet".
     *
     * <p>The default of 2,000,000,000 ints needs about 8 GB, so the JVM must be started with a
     * correspondingly large {@code -Xmx}. Set {@code -Dcollatz.cacheSize=N} to use a smaller
     * table; N must be non-negative.
     */
    static final int[] CALC_CACHE =
            new int[Integer.getInteger("collatz.cacheSize", 2_000_000_000)];

    /**
     * Returns the number of Collatz steps needed to reach 1 from {@code n}.
     *
     * <p>An odd {@code c} is always followed by the even value {@code 3c + 1}, so both steps are
     * taken at once as {@code (3c + 1) / 2} and counted as two. The walk stops early when it
     * reaches a value whose step count is already cached. Only the result for {@code n} itself
     * is stored, and only if {@code n < CALC_CACHE.length}.
     *
     * @param n starting value, must be at least 1
     * @return number of steps from {@code n} to 1 (0 when {@code n == 1})
     * @throws IllegalArgumentException if {@code n < 1}, since such a sequence never reaches 1
     * @throws IllegalStateException if an intermediate value would overflow a {@code long}
     */
    static int calculate(long n) {
        if (n < 1) {
            // 0 halves to 0 forever and negatives cycle, so without this the loop never ends.
            throw new IllegalArgumentException("n must be at least 1: " + n);
        }
        int numSteps = 0;
        long c = n;
        while (c != 1) {
            if (c < CALC_CACHE.length) {
                int steps = CALC_CACHE[(int) c];
                // 0 marks "not cached"; every cached count for a value above 1 is positive.
                if (steps > 0) {
                    numSteps += steps;
                    break;
                }
            }
            if (c % 2 == 0) {
                numSteps++;
                c /= 2;
            } else {
                numSteps += 2;
                // c * 3 + 1 fits in a long only while c <= Long.MAX_VALUE / 3.
                if (c > Long.MAX_VALUE / 3) {
                    throw new IllegalStateException("c is too large " + c);
                }
                c = (c * 3 + 1) / 2;
            }
        }
        if (n < CALC_CACHE.length) {
            CALC_CACHE[(int) n] = numSteps;
        }
        return numSteps;
    }

    /**
     * Computes step counts for 1, 2, 3, ... for about 60 seconds.
     *
     * <p>Prints progress every 10,000,000 values. At the end it reports the highest starting
     * value checked and the value with the most steps.
     */
    public static void main(String[] args) {
        long n = 1, maxN = 0, maxSteps = 0;
        long startTime = System.currentTimeMillis();
        while (System.currentTimeMillis() < startTime + 60000) {
            // The clock is only consulted once per batch of 10 values.
            for (int i = 0; i < 10; i++) {
                int steps = calculate(n);
                if (steps > maxSteps) {
                    maxSteps = steps;
                    maxN = n;
                }
                n++;
            }
            if (n % 10000000 == 1) {
                System.out.printf("%,d%n", n);
            }
        }
        // n is the next value that would be checked, so the last one actually checked is n - 1.
        System.out.printf("The highest number was: %,d, maxSteps: %,d for: %,d%n",
                n - 1, maxSteps, maxN);
    }
}
prints
The highest number was: 1,672,915,631, maxSteps: 1,000 for: 1,412,987,847
A more advanced solution is to use multiple threads. In this case, using recursion with memoisation was easier to implement.
import java.util.stream.LongStream;
/**
 * Multi-threaded search for the starting value below 6,000,000,000 with the longest Collatz
 * sequence, memoising step counts in a table shared by all worker threads.
 */
public class Collatz {
    /**
     * Memoised step counts indexed by value; 0 means "not computed yet".
     *
     * <p>The default of {@code Integer.MAX_VALUE - 8} shorts needs about 4 GB, so the JVM must be
     * started with a large {@code -Xmx}. Set {@code -Dcollatz.cacheSize=N} to use a smaller
     * table; N must be non-negative.
     *
     * <p>The parallel stream in {@link #main} reads and writes this array without
     * synchronisation. The race is benign because array elements are never torn (JLS 17.6) and
     * a value's step count is deterministic. A reader therefore sees either 0, and recomputes,
     * or the correct count.
     */
    static final short[] CALC_CACHE =
            new short[Integer.getInteger("collatz.cacheSize", Integer.MAX_VALUE - 8)];

    /**
     * Returns the number of Collatz steps from {@code c} to 1.
     *
     * <p>Recurses along the sequence and memoises every visited value below
     * {@code CALC_CACHE.length}. An odd {@code c} is always followed by the even value
     * {@code 3c + 1}, so both steps are taken together as {@code (3c + 1) / 2} and counted as
     * two.
     *
     * @param c starting value, must be at least 1
     * @return number of steps from {@code c} to 1 (0 when {@code c == 1})
     * @throws IllegalArgumentException if {@code c < 1}, since such a sequence never reaches 1
     * @throws IllegalStateException if an intermediate value would overflow a {@code long}
     * @throws AssertionError if a step count to be cached exceeds {@code Short.MAX_VALUE}
     */
    public static int calculate(long c) {
        if (c < 1) {
            // 0 would recurse forever and negatives would index the cache out of bounds.
            throw new IllegalArgumentException("c must be at least 1: " + c);
        }
        if (c == 1) {
            return 0;
        }
        int steps;
        if (c < CALC_CACHE.length) {
            steps = CALC_CACHE[(int) c];
            if (steps > 0) {
                return steps;
            }
        }
        if (c % 2 == 0) {
            steps = calculate(c / 2) + 1;
        } else {
            // Same overflow guard as calculate2: c * 3 + 1 fits only while c <= MAX / 3.
            if (c > Long.MAX_VALUE / 3) {
                throw new IllegalStateException("c is too large " + c);
            }
            steps = calculate((c * 3 + 1) / 2) + 2;
        }
        if (c < CALC_CACHE.length) {
            storeSteps(c, steps);
        }
        return steps;
    }

    /**
     * Iterative equivalent of {@link #calculate}.
     *
     * <p>Walks the sequence until it reaches 1 or a cached value. It then caches the result for
     * {@code n} only, if {@code n < CALC_CACHE.length}.
     *
     * @param n starting value, must be at least 1
     * @return number of steps from {@code n} to 1 (0 when {@code n == 1})
     * @throws IllegalArgumentException if {@code n < 1}, since such a sequence never reaches 1
     * @throws IllegalStateException if an intermediate value would overflow a {@code long}
     * @throws AssertionError if the step count to be cached exceeds {@code Short.MAX_VALUE}
     */
    static int calculate2(long n) {
        if (n < 1) {
            throw new IllegalArgumentException("n must be at least 1: " + n);
        }
        int numSteps = 0;
        long c = n;
        while (c != 1) {
            if (c < CALC_CACHE.length) {
                int steps = CALC_CACHE[(int) c];
                if (steps > 0) {
                    numSteps += steps;
                    break;
                }
            }
            if (c % 2 == 0) {
                numSteps++;
                c /= 2;
            } else {
                numSteps += 2;
                if (c > Long.MAX_VALUE / 3) {
                    throw new IllegalStateException("c is too large " + c);
                }
                c = (c * 3 + 1) / 2;
            }
        }
        if (n < CALC_CACHE.length) {
            storeSteps(n, numSteps);
        }
        return numSteps;
    }

    /**
     * Stores {@code steps} as the cached count for {@code c}.
     *
     * <p>It refuses counts that a short cannot hold, rather than caching a wrapped value. The
     * caller must ensure {@code 0 <= c < CALC_CACHE.length}.
     */
    private static void storeSteps(long c, int steps) {
        if (steps > Short.MAX_VALUE) {
            throw new AssertionError("step count " + steps + " for " + c + " does not fit in a short");
        }
        CALC_CACHE[(int) c] = (short) steps;
    }

    /**
     * Computes step counts for every value in [1, 6,000,000,000) in parallel.
     *
     * <p>Prints the elapsed time, the largest step count, and the value that produced it.
     */
    public static void main(String[] args) {
        long startTime = System.currentTimeMillis();
        // Each partial result is {maxSteps, n}. A strict '>' keeps the first value seen with a
        // given count, and the combiner keeps the left result on ties.
        long[] res = LongStream.range(1, 6_000_000_000L).parallel().collect(
                () -> new long[2],
                (long[] arr, long n) -> {
                    int steps = calculate(n);
                    if (steps > arr[0]) {
                        arr[0] = steps;
                        arr[1] = n;
                    }
                },
                (a, b) -> {
                    if (a[0] < b[0]) {
                        a[0] = b[0];
                        a[1] = b[1];
                    }
                });
        long maxSteps = res[0];
        long maxN = res[1];
        long time = System.currentTimeMillis() - startTime;
        System.out.printf("After %.3f seconds, maxSteps: %,d for: %,d%n", time / 1e3, maxSteps, maxN);
    }
}
prints
After 52.461 seconds, maxSteps: 1,131 for: 4,890,328,815
Note: if I change the second recursive calculate call (the odd case) so that it takes only one step at a time instead of combining two, i.e. to
steps = calculate((c * 3 + 1) ) + 1;
it prints
After 63.065 seconds, maxSteps: 1,131 for: 4,890,328,815