Note that the largest possible sum is 3*1024**2 = 3145728, so the largest possible square root is 1773 when floored, or 1774 when rounded.
So you could simply take 0 as a starting guess, and repeatedly add 1 until the square exceeds the sum. That can't take more than about 1770 iterations.
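Here is a minimal sketch of that linear search in the same Python as the rest of this answer. The name isqrt_linear is only for illustration. To keep the loop cheap, it tracks (r + 1)**2 incrementally using (r + 1)**2 = r**2 + 2*r + 1, so each step needs only a shift and additions, with no multiplication:

```python
def isqrt_linear(n):
    """Floor of the square root of n (n >= 0) by counting up from 0.

    isqrt_linear is a hypothetical name used only for this illustration.
    """
    r = 0
    next_sq = 1  # always equals (r + 1) ** 2
    while next_sq <= n:
        r += 1
        # next_sq now holds r**2; add 2*r + 1 to get (r + 1)**2
        next_sq += (r << 1) + 1
    return r
```

The loop body runs exactly r times, so isqrt_linear(3*1024**2) returns 1773 after 1773 iterations.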
Of course that's probably too slow. A straightforward binary search over 0 through 1773 can cut that to 11 iterations (since 2**11 = 2048 > 1774), and doesn't require division: I'm assuming the MCU can shift right by 1 bit, which is the same as floor division by 2.
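To see where the 11 comes from, here's a sketch that searches only 0 <= r < 1774, using the bound above, and counts loop iterations. The name isqrt_bounded and its iteration counter are illustrative additions. The loop keeps lo*lo <= n < hi*hi and roughly halves the gap hi - lo each time, and a gap of 1774 shrinks to 1 in at most 11 halvings:

```python
def isqrt_bounded(n, hi=1774):
    """Return (floor(sqrt(n)), iterations) for 0 <= n < hi*hi.

    isqrt_bounded is a hypothetical helper. The default hi=1774 assumes the
    largest input is 3*1024**2 = 3145728, since 1774**2 = 3147076 > 3145728.
    """
    assert 0 <= n < hi * hi, "input outside the assumed range"
    lo = 0  # invariant: lo*lo <= n < hi*hi
    iterations = 0
    while hi - lo > 1:
        mid = (lo + hi) >> 1  # right shift instead of division by 2
        if mid * mid <= n:
            lo = mid
        else:
            hi = mid
        iterations += 1
    return lo, iterations
```

Each call needs at most 11 iterations, no matter which n in range it is given.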
EDIT
Here's some code for a binary search that returns the floor of the true square root:
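def isqrt(n):
    # Return floor(sqrt(n)) for integer n >= 0.
    if n <= 1:
        return n
    lo = 0
    hi = n >> 1                # for n >= 2, floor(sqrt(n)) <= n // 2
    while lo <= hi:
        mid = (lo + hi) >> 1   # shift instead of dividing by 2
        sq = mid * mid
        if sq == n:
            return mid         # exact square root
        elif sq < n:
            lo = mid + 1
            result = mid       # best candidate so far: mid*mid < n
        else:
            hi = mid - 1
    return result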
To check, run:
from math import sqrt
assert all(isqrt(i) == int(sqrt(i)) for i in range(3*1024**2 + 1))
That checks all possible inputs given what you said. Since binary search is notoriously tricky to get right in all cases, it's good to check every one! It doesn't take long on a "real" machine ;-)
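The check above leans on floating-point sqrt, which is fine for inputs this small. As an alternative, here's a sketch that uses only integer arithmetic: it tests the defining property r*r <= n < (r+1)*(r+1) directly and reports the first failing input. The helper names are hypothetical, and rounded_sqrt is a deliberately wrong candidate, included only to show that the check catches an off-by-one:

```python
from math import sqrt


def is_floor_sqrt(n, r):
    """True exactly when r == floor(sqrt(n)), using only integer arithmetic."""
    return r * r <= n < (r + 1) * (r + 1)


def first_failure(f, limit):
    """Return the first n in 0..limit where f(n) is not floor(sqrt(n)), else None."""
    for n in range(limit + 1):
        if not is_floor_sqrt(n, f(n)):
            return n
    return None


def rounded_sqrt(n):
    # Deliberately wrong: rounds to nearest instead of flooring.
    return int(round(sqrt(n)))
```

first_failure(rounded_sqrt, 3*1024**2) returns 3, because round(sqrt(3)) is 2 and 2*2 > 3. Running first_failure(isqrt, 3*1024**2) on either isqrt in this answer should return None.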
PROBABLY IMPORTANT
To guard against possible overflow (with hi = n >> 1, mid * mid can be far larger than n), and to speed it up significantly, change the initialization of lo and hi to this:
hi = 1
while hi * hi <= n:
    hi <<= 1
lo = hi >> 1
Then the runtime becomes proportional to the number of bits in the result, greatly speeding up small results. Indeed, for sloppy enough definitions of "close", you could stop right there: lo <= sqrt(n) < hi, and hi is exactly twice lo.
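Here's a sketch of that doubling phase on its own, with a counter added to illustrate the "number of bits" claim. The name sqrt_bracket and the counter are illustrative. hi ends as the smallest power of 2 whose square exceeds n, so the number of doublings equals the bit length of floor(sqrt(n)):

```python
def sqrt_bracket(n):
    """Doubling phase only, for integer n >= 0.

    Returns (lo, hi, doublings) with lo*lo <= n < hi*hi, where hi is the
    smallest power of 2 whose square exceeds n and lo = hi >> 1.
    sqrt_bracket is a hypothetical name used for illustration.
    """
    hi = 1
    doublings = 0
    while hi * hi <= n:
        hi <<= 1
        doublings += 1
    return hi >> 1, hi, doublings
```

For the largest input, sqrt_bracket(3*1024**2) returns (1024, 2048, 11). The true root, 1773, lies between 1024 and 2048 and needs 11 bits.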
FOR POSTERITY ;-)
Looks like the OP doesn't actually need square roots at all. But for anyone who does, and can't afford division, here's a simplified version of the code that also removes multiplications from the initialization. Note: I'm not using .bit_length() because lots of deployed Python versions don't support it.
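def isqrt(n):
    # Return floor(sqrt(n)) for integer n >= 0, using no division.
    if n <= 1:
        return n
    # Find the smallest power of 2, hi, with hi*hi > n.
    # hisq tracks hi*hi: doubling hi quadruples its square.
    hi, hisq = 2, 4
    while hisq <= n:
        hi <<= 1
        hisq <<= 2
    lo = hi >> 1
    # Invariant: lo*lo <= n < hi*hi; bisect until the gap is 1.
    while hi - lo > 1:
        mid = (lo + hi) >> 1
        if mid * mid <= n:
            lo = mid
        else:
            hi = mid
    assert lo + 1 == hi
    assert lo**2 <= n < hi**2
    return lo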
from math import sqrt
assert all(isqrt(i) == int(sqrt(i)) for i in range(3*1024**2 + 1))
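The version above still multiplies inside the bisection loop. If multiplication is also expensive on the target, a different, standard technique needs only shifts, additions, subtractions and comparisons: the bit-by-bit (or "digit-by-digit") integer square root. It is swapped in here and is not the method used in the code above. A minimal sketch, under the hypothetical name isqrt_shift:

```python
def isqrt_shift(n):
    """Floor of sqrt(n) for integer n >= 0, using no multiplication or division.

    isqrt_shift is a hypothetical name. On a fixed-width MCU you could instead
    start `bit` at a fixed power of 4 at least as large as the largest input,
    e.g. 1 << 22 (4194304 > 3*1024**2), rather than growing it in the first loop.
    """
    if n < 0:
        raise ValueError("n must be non-negative")
    # Find the largest power of 4 that is <= n (ends as 0 when n == 0).
    bit = 1
    while bit <= n:
        bit <<= 2
    bit >>= 2
    rem = n   # remainder still to be accounted for
    res = 0
    # Decide one bit of the result per iteration, highest bit first.
    while bit:
        if rem >= res + bit:
            rem -= res + bit
            res = (res >> 1) + bit
        else:
            res >>= 1
        bit >>= 2
    return res
```

The main loop runs once per power of 4 from the starting bit down to 1, so for inputs up to 3*1024**2 it takes at most 11 iterations.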