воскресенье, 1 сентября 2024 г.

Gradient, Boost, Boosting

ModelTrainingGradientBoost.py
import pandas as pd
import numpy as np
from sklearn.datasets import fetch_california_housing
from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
from sklearn.ensemble import GradientBoostingRegressor
from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
import matplotlib.pyplot as plt
from itertools import product

# Load the California housing dataset bundled with scikit-learn
cali_housing = fetch_california_housing()

# Features as a DataFrame with named columns, target as a Series
X = pd.DataFrame(data=cali_housing.data, columns=cali_housing.feature_names)
y = pd.Series(data=cali_housing.target)

# Split into training and test sets: 20% held out, fixed seed for repeatability
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=42,
)

# Candidate values for each hyperparameter that will be tuned
n_estimators_range = range(50, 301, 50)    # 50, 100, ..., 300
max_depth_range = range(3, 11, 2)          # 3, 5, 7, 9
learning_rate_range = [0.001, 0.01, 0.1, 0.5]
min_samples_split_range = [2, 5, 10, 15]
min_samples_leaf_range = [1, 2, 4]

# Search space: maps each hyperparameter name to its candidate values.
# The combinations themselves are not built here; this is only the mapping.
param_grid = {
    'n_estimators': n_estimators_range,
    'max_depth': max_depth_range,
    'learning_rate': learning_rate_range,
    'min_samples_split': min_samples_split_range,
    'min_samples_leaf': min_samples_leaf_range,
}


# Model evaluation helper
def evaluate_model(params):
    """Return the mean 5-fold cross-validated MSE for the given hyperparameters.

    Args:
        params: Mapping of keyword arguments forwarded to
            ``GradientBoostingRegressor``.

    Returns:
        The mean squared error averaged over the 5 folds of the training
        split (``X_train`` / ``y_train``), as a positive number.
    """
    regressor = GradientBoostingRegressor(**params)
    fold_scores = cross_val_score(
        regressor,
        X_train,
        y_train,
        cv=5,
        scoring='neg_mean_squared_error',
    )
    # The scorer reports negated MSE; flip the sign back to a plain MSE.
    return -fold_scores.mean()


# GridSearchCV configuration: search over param_grid, 5-fold CV per candidate
grid_search = GridSearchCV(
    GradientBoostingRegressor(random_state=42),
    param_grid,
    scoring='neg_mean_squared_error',  # negated MSE, so larger is better
    cv=5,
    n_jobs=-1,
    verbose=2,
    error_score='raise',  # surface fit failures instead of scoring them
)

try:
    # Run the grid search on the training split
    grid_search.fit(X_train, y_train)
except Exception as e:
    # Top-level boundary of the script: report the failure and stop here.
    print(f"Произошла ошибка при запуске GridSearchCV: {e}")
else:
    # Best model found by the search
    best_model = grid_search.best_estimator_

    # Evaluate the best model on the held-out test split.
    # The predictions are computed once and reused by every plot below.
    y_pred = best_model.predict(X_test)

    mse = mean_squared_error(y_test, y_pred)
    mae = mean_absolute_error(y_test, y_pred)
    r2 = r2_score(y_test, y_pred)

    print(f'Лучшие параметры: {grid_search.best_params_}')
    print(f'Mean Squared Error: {mse}')
    print(f'Mean Absolute Error: {mae}')
    print(f'R^2 Score: {r2}')

    # Feature importance bar chart
    feature_importance = best_model.feature_importances_
    feature_names = X.columns
    plt.figure(figsize=(12, 8))
    plt.bar(feature_names, feature_importance)
    plt.title('Feature Importance')
    plt.xlabel('Features')
    plt.ylabel('Importance')
    plt.xticks(rotation=90)
    plt.tight_layout()
    plt.show()

    # Predicted vs. actual values; the dashed red diagonal marks a perfect fit
    plt.figure(figsize=(10, 6))
    plt.scatter(y_test, y_pred)
    plt.plot([y_test.min(), y_test.max()], [y_test.min(), y_test.max()], 'r--', lw=2)
    plt.xlabel('Actual Values')
    plt.ylabel('Predicted Values')
    plt.title('Actual vs Predicted Values')
    plt.show()

    # Additional check on the test split.
    # NOTE(review): cross_val_score refits clones of the estimator on folds of
    # the test data, so these scores describe the chosen hyperparameters rather
    # than the already-fitted best_model -- confirm this is what is intended.
    scores = cross_val_score(best_model, X_test, y_test, cv=5, scoring='neg_mean_squared_error')
    print(f"Cross-validation scores on test set: {-scores}")

    # Distribution of residuals, with mean and median marked
    residuals = y_test - y_pred
    residual_mean = np.mean(residuals)
    residual_median = np.median(residuals)
    plt.figure(figsize=(10, 6))
    plt.hist(residuals, bins=30, edgecolor='black')
    plt.axvline(x=residual_mean, color='red', linestyle='dashed', linewidth=2,
                label=f'Mean: {residual_mean}')
    plt.axvline(x=residual_median, color='green', linestyle='dotted', linewidth=2,
                label=f'Median: {residual_median}')
    plt.title('Distribution of Residuals')
    plt.xlabel('Residuals')
    plt.ylabel('Frequency')
    plt.legend()
    plt.show()

    # Grouped feature importances: features whose importance rounds to the
    # same two-decimal value share one group (one colour / legend entry).
    group_feature_importance = {}
    for i, importance in enumerate(feature_importance):
        group_feature_importance.setdefault(round(importance, 2), []).append(i)

    plt.figure(figsize=(12, 6))
    for importance, indices in sorted(group_feature_importance.items()):
        # Bars sit at the features' own positions. The original passed a list
        # to range() here, which raises TypeError.
        plt.bar(indices, [feature_importance[i] for i in indices],
                label=f'{importance:.2f}')
    # Label each bar position with the corresponding feature name
    plt.xticks(range(len(feature_names)), feature_names, rotation=90)
    plt.title('Grouped Feature Importances')
    plt.xlabel('Features')
    plt.ylabel('Importance')
    plt.legend()
    plt.tight_layout()
    plt.show()

Комментариев нет:

Отправить комментарий