В какой библиотеке можно постоить нелинейную регрессию?
У меня есть данные (df.head и .tail) датафрейм:
date | mentor_cnt | mentee_cnt | mentee_dynamic | |
---|---|---|---|---|
0 | 2021-02-12 | 2 | 2 | nan |
1 | 2021-02-19 | 1 | 1 | 0,5 |
2 | 2021-03-06 | 1 | 1 | 1 |
3 | 2021-03-27 | 2 | 2 | 2 |
4 | 2021-04-10 | 4 | 4 | 2 |
5 | 2021-04-18 | 3 | 3 | 0,75 |
37 | 2022-07-08 | 279 | 368 | 1.319 |
38 | 2022-07-22 | 291 | 388 | 1.333 |
39 | 2022-08-05 | 341 | 478 | 1.402 |
40 | 2022-08-19 | 373 | 521 | 1.397 |
41 | 2022-09-02 | 460 | 725 | 1.576 |
Задача: сделать прогноз mentor_cnt на основе mentee_cnt.
Я уже построил нелинейную регрессию, однако она не годится для использования (к тому же подбирать вид уравнения вручную может быть сложно).
Я использовал scipy.optimize, sklearn и scipy.stats. Какой вариант еще стоит попробовать в моем случае?
P. S. LinearRegression и PolynomialFeatures из sklearn уже пробовал.
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.metrics import mean_squared_error, r2_score
from scipy import stats
import scipy.optimize as opt
df = pd.read_csv(...)
# Создание y, x
x_non_linear = df['mentee_cnt'].to_numpy()
y_non_linear = df['mentor_cnt'].to_numpy()
# Разбивка на тренировочную и тестовую выборки
x_train_non_linear, x_test_non_linear, y_train_non_linear, y_test_non_linear = train_test_split(x_non_linear, y_non_linear, test_size=0.2)
def func(X, b, c):
return b*(X**c)
plt.plot(x_train_non_linear, y_train_non_linear, 'ko', label='Test values')
x_lin = np.linspace(0, x_train_non_linear.max(), 50)
p0 =[2, 0.8]
w, _ = opt.curve_fit(func, x_train_non_linear, y_train_non_linear, p0=p0)
y_model = func(x_lin, *w)
plt.plot(x_train_non_linear, y_train_non_linear, 'bo', label='Train values')
plt.plot(x_lin, y_model, 'r--', label='Model')
plt.title('Least squares regression')
plt.legend()
plt.show()
x_lin_test = np.linspace(0, x_test_non_linear.max(), 50)
y_model_test = func(x_lin_test, *w)
y_pred_for_test = func(x_test_non_linear, *w)
residuals_non_linear = y_test_non_linear - y_pred_for_test
residuals_non_linear_ptc = (y_test_non_linear - y_pred_for_test)/y_test_non_linear*100
rmse_non_linear = mean_squared_error(y_test_non_linear, y_pred_for_test, squared = False)
mse_non_linear = mean_squared_error(y_test_non_linear, y_pred_for_test)
r2_score_non_linear = r2_score(y_test_non_linear, y_pred_for_test)
print(f'rmse_non_linear: {rmse_non_linear:.3f}')
print(f'mse_non_linear: {mse_non_linear:.3f}')
print(f'r2_score_non_linear: {r2_score_non_linear}')
print(f'mentor_cnt = {w[0]:.3f} * (mentee_cnt ** {w[1]:.3f})')
# проверяем нормальность распределения остатков
stat_residuals_non_linear, p_residuals_non_linear = stats.shapiro(residuals_non_linear)
print(f'Shapiro: stat_residuals = {stat_residuals_non_linear:.3f}, p_residuals = {p_residuals_non_linear:.3f}')
plt.plot(x_test_non_linear, y_test_non_linear, 'bo', label='Test values')
plt.plot(x_lin_test, y_model_test, 'r--', label='Model')
plt.xlabel("mentee")
plt.ylabel("mentor")
plt.legend(loc="best")
plt.title(
f'''LinearRegression predict vs. test values\n
RMSE_test = {rmse_non_linear.mean():.3f}
mentor_cnt = {w[0]:.3f} * (mentee_cnt ** {w[1]:.3f})'''
)
plt.show()
# график для проверки нормальности распределения остатков
res = stats.probplot(residuals_non_linear, plot=plt)
plt.ylabel('Residuals')
plt.xlabel('Excepted normal Value')
plt.title(f'Normal probability plot of residuals\nShapirp: stat_residuals={stat_residuals_non_linear:.3f}, p_residuals={p_residuals_non_linear:.3f}')
plt.show()
# проверка однородности дисперсии остатков
sns.scatterplot(y=residuals_non_linear, x=y_pred_for_test)
plt.title('Predicted vs. Residual Values')
plt.xlabel('Predicted values')
plt.ylabel('Residuals')
plt.show()
Результат работы кода:
Модель vs тренировочная выборка
Модель vs тестовая выборка
rmse_non_linear: 5.507
mse_non_linear: 30.322
r2_score_non_linear: 0.9960486902733482
mentor_cnt = -9.466 + 3.113 * (mentee_cnt ** 0.766)
Shapiro: stat_residuals = 0.981, p_residuals = 0.970
Источник: Stack Overflow на русском