I have a DataFrame with several columns and date as index. I use sns.heatmap
to plot it, with the date on the y-axis. I would like to force the ticks to display the 1st of October of every year only. I used the solution given by @Ayrton Bourn on Date axis in heatmap seaborn, which allows me to change the frequency of ticks but not at which day to display the date.
His method is the only one that allows me to choose the frequency of y-ticks so far. I tried using mdates.YearLocator()
or set_major_locator
without success.
With the code below, do you have any suggestion that would allow me to choose the frequency of date ticks (every year) and the day displayed (every '200x-10-01' for example) ?
import numpy as np
from datetime import date, datetime, timedelta
import pandas as pd
import seaborn as sns
import matplotlib.dates as mdates
import matplotlib.pyplot as plt
from collections.abc import Iterable
from sklearn import linear_model
class AxTransformer:
def __init__(self, datetime_vals=False):
self.datetime_vals = datetime_vals
self.lr = linear_model.LinearRegression()
return
def process_tick_vals(self, tick_vals):
if not isinstance(tick_vals, Iterable) or isinstance(tick_vals, str):
tick_vals = [tick_vals]
if self.datetime_vals == True:
tick_vals = pd.to_datetime(tick_vals).astype(int).values
tick_vals = np.array(tick_vals)
return tick_vals
def fit(self, ax, axis='x'):
axis = getattr(ax, f'get_{axis}axis')()
tick_locs = axis.get_ticklocs()
tick_vals = self.process_tick_vals([label._text for label in axis.get_ticklabels()])
self.lr.fit(tick_vals.reshape(-1, 1), tick_locs)
return
def transform(self, tick_vals):
tick_vals = self.process_tick_vals(tick_vals)
tick_locs = self.lr.predict(np.array(tick_vals).reshape(-1, 1))
return tick_locs
def set_date_ticks(ax, start_date, end_date, axis='y', date_format='%Y-%m-%d', **date_range_kwargs):
dt_rng = pd.date_range(start_date, end_date, **date_range_kwargs)
ax_transformer = AxTransformer(datetime_vals=True)
ax_transformer.fit(ax, axis=axis)
getattr(ax, f'set_{axis}ticks')(ax_transformer.transform(dt_rng))
getattr(ax, f'set_{axis}ticklabels')(dt_rng.strftime(date_format))
ax.tick_params(axis=axis, which='both', bottom=True, top=False, labelbottom=True)
return ax
base = datetime(2000, 1, 1)
arr = np.array([base + timedelta(days=i) for i in range(366*3)])
val = np.random.rand(len(arr),3)
df = pd.DataFrame(val, index = arr)
f, ax = plt.subplots(figsize=(20,20))
ax = sns.heatmap(df, ax = ax)
set_date_ticks(ax, '2000-01-01', '2003-12-01', freq='1Y')
ax.format_ydata = mdates.DateFormatter('% Y')
plt.show()