用Python实战EEG癫痫检测:从CHB-MIT数据集到SVM分类的保姆级教程

张开发
2026/4/17 17:27:21 15 分钟阅读

分享文章

用Python实战EEG癫痫检测:从CHB-MIT数据集到SVM分类的保姆级教程
用Python实战EEG癫痫检测从CHB-MIT数据集到SVM分类的保姆级教程在生物医学信号处理领域脑电图EEG分析一直是最具挑战性的任务之一。癫痫发作检测作为EEG分析的重要应用场景不仅对临床诊断具有重大意义也为机器学习工程师提供了绝佳的技术试验场。本教程将带您从零开始构建一个完整的癫痫检测流程使用Python生态中的强大工具链包括MNE-Python、scikit-learn和NumPy等库逐步实现数据加载、预处理、特征提取和模型训练的全过程。1. 环境准备与数据加载工欲善其事必先利其器。在开始EEG分析之前我们需要配置合适的Python环境并获取CHB-MIT数据集。这个由波士顿儿童医院和麻省理工学院联合收集的公开数据集包含了23名癫痫患者的长期头皮EEG记录采样率为256Hz采用国际10-20系统放置电极。首先安装必要的Python包pip install mne scikit-learn numpy pandas matplotlib seaborn加载CHB-MIT数据集的核心代码如下import mne import os # 设置数据路径 data_dir path_to_CHB-MIT_dataset raw_fname os.path.join(data_dir, chb01, chb01_03.edf) # 读取EDF文件 raw mne.io.read_raw_edf(raw_fname, preloadTrue) raw.info[bads] [] # 标记坏通道如有 # 查看基本信息 print(raw.info)数据集中的每个EDF文件都包含多通道EEG信号和对应的标注信息。特别需要注意的是CHB-MIT数据集中的发作事件seizure events已经由专业医师标注我们可以利用这些标注来构建监督学习模型。2. EEG信号预处理实战原始EEG信号通常包含各种噪声和伪迹有效的预处理是后续分析成功的关键。我们将采用一系列专业级处理方法2.1 滤波处理# 带通滤波0.5-50Hz raw.filter(0.5, 50., fir_designfirwin) # 陷波滤波去除工频干扰 raw.notch_filter([50., 60.]) # 针对50Hz或60Hz电源噪声2.2 伪迹去除独立成分分析(ICA)是去除眼动和肌电伪迹的有效方法# 设置并拟合ICA ica mne.preprocessing.ICA(n_components20, random_state97) ica.fit(raw) # 自动检测并去除眼动伪迹 eog_indices, eog_scores ica.find_bads_eog(raw) ica.exclude eog_indices[:2] # 排除前两个眼动相关成分 # 应用ICA raw_clean ica.apply(raw)2.3 重参考与分段# 转换为平均参考 raw_clean.set_eeg_reference(ref_channelsaverage) # 提取发作期和发作间期数据 events, event_ids mne.events_from_annotations(raw_clean) epochs mne.Epochs(raw_clean, events, event_idevent_ids, tmin-60, tmax60, baseline(-60, -30))3. 特征工程从时域到时频域特征提取是将原始信号转化为机器学习可理解形式的核心步骤。对于EEG信号我们需要从多个维度提取有区分力的特征。3.1 时域特征import numpy as np def extract_time_features(epoch): features [] for ch in epoch: signal epoch[ch] features.extend([ np.mean(signal), # 均值 np.std(signal), # 标准差 np.median(signal), # 中位数 np.max(signal) - np.min(signal), # 峰峰值 np.sum(np.diff(np.sign(signal)) ! 0) / len(signal) # 零交叉率 ]) return features3.2 频域特征from scipy.signal import welch def extract_freq_features(epoch, sfreq256): features [] for ch in epoch: freqs, psd welch(epoch[ch], fssfreq, nperseg256) for band, (l_freq, h_freq) in {delta: (0.5,4), theta: (4,8), alpha: (8,13), beta: (13,30)}.items(): band_mask (freqs l_freq) (freqs h_freq) features.append(np.log10(psd[band_mask].mean())) return features3.3 非线性特征from antropy import sample_entropy, petrosian_fd def extract_nonlinear_features(epoch): features [] for ch in epoch: signal epoch[ch] features.extend([ sample_entropy(signal, order2, metricchebyshev), petrosian_fd(signal), np.mean(np.abs(np.diff(signal))) # 平均绝对差分 ]) return features4. 构建SVM分类器支持向量机(SVM)在小样本、高维特征空间表现优异非常适合EEG分类任务。我们将使用scikit-learn实现一个完整的分类流程。4.1 特征整合与标准化from sklearn.preprocessing import StandardScaler from sklearn.pipeline import make_pipeline from sklearn.svm import SVC from sklearn.model_selection import cross_val_score # 假设X是特征矩阵y是标签(0:非发作, 1:发作) scaler StandardScaler() X_scaled scaler.fit_transform(X) # 构建SVM管道 svm_pipe make_pipeline( StandardScaler(), SVC(kernelrbf, C1.0, gammascale, class_weightbalanced) )4.2 交叉验证评估# 10折交叉验证 cv_scores cross_val_score(svm_pipe, X, y, cv10, scoringroc_auc, n_jobs-1) print(f平均AUC: {np.mean(cv_scores):.3f} ± {np.std(cv_scores):.3f})4.3 参数优化from sklearn.model_selection import GridSearchCV param_grid { svc__C: [0.1, 1, 10, 100], svc__gamma: [scale, auto, 0.001, 0.01, 0.1] } grid_search GridSearchCV(svm_pipe, param_grid, cv5, scoringroc_auc, n_jobs-1) grid_search.fit(X_train, y_train) print(f最佳参数: {grid_search.best_params_}) print(f最佳AUC: {grid_search.best_score_:.3f})5. 实战技巧与常见问题解决在实际项目中EEG癫痫检测会遇到各种挑战。以下是几个关键问题的解决方案5.1 类别不平衡处理癫痫发作数据通常远少于非发作数据我们可以采用以下策略from imblearn.over_sampling import SMOTE smote SMOTE(random_state42) X_resampled, y_resampled smote.fit_resample(X_train, y_train)5.2 通道选择优化并非所有电极通道都包含有用信息我们可以使用递归特征消除(RFE)选择最佳通道组合from sklearn.feature_selection import RFECV selector RFECV(SVC(kernellinear), step1, cv5, scoringroc_auc, n_jobs-1) selector selector.fit(X, y) print(f最优特征数: {selector.n_features_})5.3 实时检测实现对于实时应用我们可以使用滑动窗口技术def real_time_detection(raw, model, window_size5, step1): sfreq raw.info[sfreq] n_samples int(window_size * sfreq) step_samples int(step * sfreq) for start in range(0, len(raw)-n_samples, step_samples): window raw[:, start:startn_samples][0] features extract_features(window) proba model.predict_proba([features])[0,1] if proba 0.8: # 设置阈值 print(f检测到癫痫发作时间: {start/sfreq:.1f}s)6. 模型解释与可视化理解模型决策过程对临床应用至关重要。我们可以使用SHAP值来解释SVM的预测import shap # 使用核SHAP解释SVM explainer shap.KernelExplainer(svm_pipe.predict_proba, X_train[:100]) shap_values explainer.shap_values(X_test[:10]) # 可视化单个预测的解释 shap.force_plot(explainer.expected_value[1], shap_values[1][0], X_test[0])对于EEG信号我们还可以绘制特征重要性图import matplotlib.pyplot as plt # 获取线性SVM的权重 svm_linear SVC(kernellinear).fit(X_scaled, y) coef svm_linear.coef_[0] # 绘制特征重要性 plt.figure(figsize(12,6)) plt.bar(range(len(coef)), np.abs(coef)) plt.xticks(range(len(coef)), feature_names, rotation90) plt.title(SVM特征重要性) plt.tight_layout()7. 性能提升策略当基础模型性能不足时可以考虑以下进阶技术7.1 集成学习方法from sklearn.ensemble import BaggingClassifier bagging BaggingClassifier( SVC(kernelrbf, probabilityTrue), n_estimators10, max_samples0.8, n_jobs-1, random_state42 ) bagging.fit(X_train, y_train)7.2 深度学习融合虽然本教程聚焦传统机器学习但可以结合深度学习特征from tensorflow.keras.models import Sequential from tensorflow.keras.layers import Conv1D, MaxPooling1D, Flatten # 构建简单的1D CNN提取特征 cnn Sequential([ Conv1D(32, 3, activationrelu, input_shape(n_timesteps, n_channels)), MaxPooling1D(2), Flatten() ]) # 提取CNN特征 cnn_features cnn.predict(X_reshaped) # 将CNN特征与传统特征结合 X_combined np.hstack([X_features, cnn_features])7.3 患者特异性调优from sklearn.model_selection import LeaveOneGroupOut # 按患者分组交叉验证 logo LeaveOneGroupOut() for train_idx, test_idx in logo.split(X, y, groupspatient_ids): X_train, X_test X[train_idx], X[test_idx] y_train, y_test y[train_idx], y[test_idx] # 训练患者特定模型 model.fit(X_train, y_train) # 评估...8. 部署与生产化考虑将研究原型转化为实际应用需要考虑以下方面8.1 模型轻量化from sklearn.linear_model import LogisticRegression # 使用L1正则化进行特征选择 lr_l1 LogisticRegression(penaltyl1, solverliblinear, C0.1) lr_l1.fit(X_scaled, y) # 查看非零系数数量 print(f非零特征数: {np.sum(lr_l1.coef_ ! 0)})8.2 实时性能优化from numba import jit jit(nopythonTrue) def fast_feature_extraction(signal_window): # 优化关键特征提取函数 # ... return features8.3 模型监控from sklearn.metrics import classification_report # 定期评估模型性能 y_pred model.predict(X_new) print(classification_report(y_true, y_pred)) # 监控特征分布变化 plt.figure() plt.hist(X_train[:,0], alpha0.5, label训练数据) plt.hist(X_new[:,0], alpha0.5, label新数据) plt.legend()

更多文章