Anytime I call SendAsync, it should automatically authenticate the API if the access token has expired.
The issue is that _accessTokenExpirationTime
may lead to a race condition. How do I protect against it, SemaphoreSlim?
I also don't think the way I did it, is correct. I mean it works, but DateTime.UtcNow.AddSeconds
is probably wrong because what if the response delays with 1 minute and the access token's lifetime remains the same?
Any suggestions are welcome.
public sealed partial class Client : IDisposable
{
private readonly string _clientId;
private readonly string _clientSecret;
private readonly ILogger<Client> _logger;
private readonly WebSocketClient _client;
private readonly ConcurrentDictionary<Guid, TaskCompletionSource<ReadOnlyMemory<byte>>> _outstandingRequests = new();
private DateTime _accessTokenExpirationTime = DateTime.UtcNow;
public Client(string clientId, string clientSecret, ILoggerFactory? loggerFactory = default)
{
_logger = (loggerFactory ?? NullLoggerFactory.Instance).CreateLogger<Client>();
_clientId = clientId;
_clientSecret = clientSecret;
_client = new WebSocketClient("wss://www.deribit.com/ws/api/v2", loggerFactory);
_client.Connected += OnConnected;
_client.Disconnected += OnDisconnected;
_client.MessageReceived += OnMessageReceived;
}
public void Dispose()
{
_client.Connected -= OnConnected;
_client.Disconnected -= OnDisconnected;
_client.MessageReceived -= OnMessageReceived;
_client.Dispose();
}
public Task StartAsync()
{
return _client.StartAsync();
}
public Task StopAsync()
{
return _client.StopAsync();
}
private async ValueTask EnsureAuthenticated()
{
// TODO: Prevent race condition?
if ((_accessTokenExpirationTime - DateTime.UtcNow).TotalSeconds < 30)
{
var response = await AuthAsync(_clientId, _clientSecret).ConfigureAwait(false);
if (response is { Error: { } })
{
_logger.LogError("Error requesting authentication token: {Message}", response.Error.Message);
}
else if (response is { Result: { } })
{
_accessTokenExpirationTime = DateTime.UtcNow.AddSeconds(response.Result.ExpiresIn);
_logger.LogInformation("Authentication token received. Expires in {ExpiresIn} seconds.", response.Result.ExpiresIn);
}
}
}
private async ValueTask<TResponse> SendAsync<TResponse>(string method, object @params, bool isPublic)
{
if (isPublic)
{
await EnsureAuthenticated().ConfigureAwait(false);
}
var request = new JsonRpcRequest("2.0", Guid.NewGuid(), method, @params);
var tcs = new TaskCompletionSource<ReadOnlyMemory<byte>>(TaskCreationOptions.RunContinuationsAsynchronously);
_outstandingRequests.TryAdd(request.Id, tcs);
return await SendRequestAndWaitForResponseAsync().ConfigureAwait(false);
async ValueTask<TResponse> SendRequestAndWaitForResponseAsync()
{
var json = JsonSerializer.Serialize(request);
var message = new Message(Encoding.UTF8.GetBytes(json));
await _client.SendAsync(message).ConfigureAwait(false);
var completedTask = await Task.WhenAny(tcs.Task, Task.Delay(5000)).ConfigureAwait(false);
if (completedTask != tcs.Task)
{
tcs.SetException(new TimeoutException("The operation timed out"));
_outstandingRequests.TryRemove(request.Id, out _);
}
var response = await tcs.Task.ConfigureAwait(false);
return JsonSerializer.Deserialize<TResponse>(response.Span) ?? throw new JsonException("Could not deserialize the object");
}
}
public ValueTask<JsonRpcResponse<AuthInfo>> AuthAsync(string clientId, string clientSecret)
{
var @params = new Dictionary<string, string>
{
{ "grant_type", "client_credentials" },
{ "client_id", clientId },
{ "client_secret", clientSecret }
};
return SendAsync<JsonRpcResponse<AuthInfo>>("/public/auth", @params, false);
}
private void OnConnected(object? sender, EventArgs e)
{
_accessTokenExpirationTime = DateTime.UtcNow;
}
private void OnDisconnected(object? sender, EventArgs e)
{
}
private void OnNotification(Notification notification)
{
var callbacks = _subscriptionManager.GetCallbacks(notification.Params.Channel);
foreach (var callback in callbacks)
{
try
{
Task.Run(() => callback(notification));
}
catch (Exception ex)
{
_logger.LogError(ex, "OnNotification: Error during event callback call");
}
}
}
private void OnMessageReceived(object? sender, MessageReceivedEventArgs e)
{
try
{
using var document = JsonDocument.Parse(e.Message.Buffer);
if (document.RootElement.TryGetProperty("method", out var methodElement))
{
// Handle request
var method = methodElement.GetString();
Debug.Assert(method != null);
if (method == "subscription")
{
var notification = document.RootElement.Deserialize<Notification>();
Debug.Assert(notification != null);
OnNotification(notification);
}
else if (method == "heartbeat")
{
Console.WriteLine("Heartbeat");
}
else
{
_logger.LogWarning("Unknown server request");
}
}
else
{
// Handle response
var response = document.RootElement.Deserialize<JsonRpcResponse<object>>();
Debug.Assert(response != null);
if (!_outstandingRequests.TryRemove(response.Id, out var tcs))
{
_logger.LogWarning("Could not find request id {Id}", response.Id);
return;
}
if (response.Error != null)
{
tcs.SetException(new JsonRpcException(response));
}
else
{
tcs.TrySetResult(e.Message.Buffer);
}
}
}
catch (Exception ex)
{
_logger.LogError(ex, "Error handling message response: {ExceptionMessage}", ex.Message);
}
}
}
public record AuthInfo(
[property: JsonPropertyName("access_token")] string AccessToken,
[property: JsonPropertyName("expires_in")] int ExpiresIn,
[property: JsonPropertyName("refresh_token")] string RefreshToken,
[property: JsonPropertyName("scope")] string Scope,
[property: JsonPropertyName("state")] string? State,
[property: JsonPropertyName("token_type")] string TokenType);
public record JsonRpcResponse<T>(
[property: JsonPropertyName("jsonrpc")] string JsonRpc,
[property: JsonPropertyName("id")] Guid Id,
[property: JsonPropertyName("result")] T? Result,
[property: JsonPropertyName("error")] JsonRpcError? Error,
[property: JsonPropertyName("testnet")] bool Testnet,
[property: JsonPropertyName("usIn")] [property: JsonConverter(typeof(MicrosecondEpochDateTimeConverter))] DateTime UsIn,
[property: JsonPropertyName("usOut")] [property: JsonConverter(typeof(MicrosecondEpochDateTimeConverter))] DateTime UsOut,
[property: JsonPropertyName("usDiff")] long UsDiff)
where T : class;