from ray import tune
# Ray Tune search space for AutoNHITS. Plain values are passed through as-is;
# the tune.* entries are sampled by the search algorithm for every trial.
nhits_config = dict(
    # Fixed training settings.
    start_padding_enabled=True,  # Enable start padding
    max_steps=100,  # Number of SGD steps
    input_size=24,  # Size of input (lookback) window
    # Initial learning rate, drawn on a log scale between 1e-5 and 1e-1.
    learning_rate=tune.loguniform(1e-5, 1e-1),
    # MaxPool kernel-size candidates (three-element lists; presumably one
    # value per NHITS stack -- confirm against the model's stack count).
    n_pool_kernel_size=tune.choice([[2, 2, 2], [16, 8, 1]]),
    # Interpolation expressivity (downsampling) ratio candidates, same layout.
    n_freq_downsample=tune.choice([[168, 24, 1], [24, 12, 1], [1, 1, 1]]),
    val_check_steps=50,  # Compute validation every 50 steps
    # tune.randint excludes its upper bound, so seeds come from 1..9.
    random_seed=tune.randint(1, 10),
)
# models = [NBEATS(input_size=2 * horizon, h=horizon, max_steps=50, start_padding_enabled=True),
#     NHITS(input_size=2 * horizon, h=horizon, max_steps=50, start_padding_enabled=True)]
from ray.tune.search.hyperopt import HyperOptSearch
from neuralforecast import NeuralForecast  # was missing: used below, caused NameError
from neuralforecast.losses.pytorch import MAE
from neuralforecast.auto import AutoNHITS
import ray

# Start the local Ray runtime used for the tuning trials. The guard lets this
# cell/script be re-run in the same process: a second bare ray.init() raises
# RuntimeError ("Maybe you called ray.init twice by accident?").
if not ray.is_initialized():
    ray.init(num_cpus=16, num_gpus=1)

# AutoNHITS runs a Ray Tune search over nhits_config: num_samples=10 trials,
# proposed by HyperOptSearch, each trained with MAE loss for a 12-step horizon.
models = [
    AutoNHITS(
        h=12,
        loss=MAE(),
        config=nhits_config,
        search_alg=HyperOptSearch(),
        backend='ray',
        num_samples=10,
    )
]

# NOTE(review): 'M' is the pandas month-end alias; pandas >= 2.2 deprecates it
# in favour of 'ME' -- switch once the pinned pandas version allows it.
nf = NeuralForecast(models=models, freq='M')
# NOTE(review): Y_train_df is not defined in the visible code; presumably it is
# the training frame built by an earlier cell -- confirm.
nf.fit(df=Y_train_df)
Y_hat_df = nf.predict().reset_index()