Motivation

Learn a fundamental machine learning skill set

Chapter 3. A Tour of Machine Learning Classifiers Using Scikit-learn

Goal of this Chapter

  • Take a tour through a selection of popular and powerful machine learning algorithms
  • Learn about the differences between several supervised learning algorithms for classification
  • Develop an intuitive appreciation of their strengths and weaknesses

Summary of this Chapter

  • Introduction to the concepts of popular classification algorithms
  • Using the scikit-learn machine learning library
  • Questions to ask when selecting a machine learning algorithm (how do we choose?)

3.1 Choosing a classification algorithm

  • no single classifier works best across all possible scenarios
    • practice is necessary to choose an appropriate algorithm for each problem
  • for a particular problem, it is necessary to compare the results of several algorithms and select the best model based on those results

Five main steps for training

  1. Selection of features
  2. Choosing a performance metric
  3. Choosing a classifier and optimization algorithm
  4. Evaluating the performance of the model
  5. Tuning the algorithm

3.2 First steps with scikit-learn

  • scikit-learn is a library that offers:
    • a large variety of learning algorithms
    • functions to preprocess data
    • functions to fine-tune and evaluate models

3.2.1 Training a perceptron via scikit-learn

  • training a perceptron model on the Iris dataset
    • [`datasets.load_iris()`](http://scikit-learn.org/stable/modules/generated/sklearn.datasets.load_iris.html#sklearn.datasets.load_iris)
    • 'data': the data to learn
    • 'target': the classification labels
from sklearn import datasets
import numpy as np

iris = datasets.load_iris()
X = iris.data[:, [2, 3]]
y = iris.target

print('Class labels:', np.unique(y))

> Class labels: [0 1 2]
  • the Iris dataset object contains much more than just the feature values
iris = datasets.load_iris()
print(iris)

>    {'target': array([0, 0, 0, ..., 2, 2, 2]),
     'data': array([[ 5.1,  3.5,  1.4,  0.2],
            [ 4.9,  3. ,  1.4,  0.2],
            ...,
            [ 5.9,  3. ,  5.1,  1.8]]),
     'target_names': array(['setosa', 'versicolor', 'virginica'], dtype='<U10'),
     'DESCR': 'Iris Plants Database\n...',
     'feature_names': ['sepal length (cm)', 'sepal width (cm)', 'petal length (cm)', 'petal width (cm)']}

  • (output truncated above) the dataset holds 150 samples, 50 per class; DESCR contains the full description: 4 numeric attributes (sepal length/width and petal length/width, in cm), created by R.A. Fisher, with one class linearly separable from the other two while the latter two are not linearly separable from each other
  • converting class labels to integers is recommended for the optimal performance of many machine learning libraries (a small sketch follows)
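A minimal sketch (not from the book): the Iris targets are already integer-encoded, but for datasets with string class labels scikit-learn's LabelEncoder performs this conversion.

from sklearn.preprocessing import LabelEncoder

# map string class labels to integers (assigned in sorted label order)
le = LabelEncoder()
print(le.fit_transform(['setosa', 'versicolor', 'virginica', 'setosa']))

> [0 1 2 0]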

  • split dataset into training dataset and test dataset

from sklearn.cross_validation import train_test_split

# hold out 30% of the samples as a test set; random_state makes the split reproducible
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=0)
  • Many machine learning and optimization algorithms also require feature scaling for optimal performance
    • standardize the features using the StandardScaler class from scikit-learn's preprocessing module
    • μ (sample mean) and σ (standard deviation) are estimated for each feature dimension from the training data
from sklearn.preprocessing import StandardScaler

sc = StandardScaler()
sc.fit(X_train)
X_train_std = sc.transform(X_train)
X_test_std = sc.transform(X_test)
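As a sanity check, the scaler's transformation can be reproduced by hand (a minimal sketch, assuming the arrays above): each feature is centered by its training-set mean μ and divided by its training-set standard deviation σ.

# reproduce StandardScaler manually: x' = (x - μ) / σ, where μ and σ are
# estimated on the training set only (np.std defaults to the biased
# estimator, which is what StandardScaler uses)
mu, sigma = X_train.mean(axis=0), X_train.std(axis=0)
assert np.allclose((X_train - mu) / sigma, X_train_std)
assert np.allclose((X_test - mu) / sigma, X_test_std)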
  • Having standardized the training data, we are now ready to train a perceptron model
from sklearn.linear_model import Perceptron
ppn = Perceptron(n_iter=40, eta0=0.1, random_state=0)
ppn.fit(X_train_std, y_train)

> Perceptron(alpha=0.0001, class_weight=None, eta0=0.1, fit_intercept=True,
    n_iter=40, n_jobs=1, penalty=None, random_state=0, shuffle=True,
    verbose=0, warm_start=False)
  • Finding an appropriate learning rate requires some experimentation (see the sketch after this list)
    • If the learning rate is too large, the algorithm will overshoot the global cost minimum
    • If the learning rate is too small, the algorithm requires more epochs until convergence, which can make the learning slow
  • Use the random_state parameter to reproduce the initial shuffling of the training dataset after each epoch
    • http://scikit-learn.org/stable/modules/generated/sklearn.linear_model.Perceptron.html
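To make the learning-rate discussion concrete, a few values of eta0 can be compared empirically (a sketch reusing the variables above; the exact error counts depend on the data split, so none are claimed here):

# compare several learning rates and count test-set misclassifications
for eta in (0.0001, 0.01, 1.0):
    p = Perceptron(n_iter=40, eta0=eta, random_state=0)
    p.fit(X_train_std, y_train)
    print('eta0=%s -> misclassified: %d' % (eta, (y_test != p.predict(X_test_std)).sum()))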
y_pred = ppn.predict(X_test_std)
print('Misclassified samples: %d' % (y_test != y_pred).sum())

>  Misclassified samples: 4
  • the perceptron misclassifies 4 out of the 45 flower samples
    • the misclassification error on the test dataset is 0.089, or 8.9 percent (4 / 45 ≈ 0.089)
    • accuracy = 1 − misclassification error = 0.911, or 91.1 percent
  • the metrics module gives us a function to calculate this accuracy as follows:
from sklearn.metrics import accuracy_score
print('Accuracy: %.2f' % accuracy_score(y_test, y_pred))

> Accuracy: 0.91
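Equivalently, the accuracy can be computed directly from the predictions (a quick check using y_test and y_pred from above):

# accuracy = 1 - misclassification error
print('Accuracy: %.2f' % (y_test == y_pred).mean())

> Accuracy: 0.91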
  • Plot the decision regions of our newly trained perceptron model
    • visualize the result of classification
from matplotlib.colors import ListedColormap
import matplotlib.pyplot as plt
import warnings


def versiontuple(v):
    return tuple(map(int, (v.split("."))))


def plot_decision_regions(X, y, classifier, test_idx=None, resolution=0.02):

    # setup marker generator and color map
    markers = ('s', 'x', 'o', '^', 'v')
    colors = ('red', 'blue', 'lightgreen', 'gray', 'cyan')
    cmap = ListedColormap(colors[:len(np.unique(y))])

    # plot the decision surface
    x1_min, x1_max = X[:, 0].min() - 1, X[:, 0].max() + 1
    x2_min, x2_max = X[:, 1].min() - 1, X[:, 1].max() + 1
    xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, resolution),
                           np.arange(x2_min, x2_max, resolution))
    Z = classifier.predict(np.array([xx1.ravel(), xx2.ravel()]).T)
    Z = Z.reshape(xx1.shape)
    plt.contourf(xx1, xx2, Z, alpha=0.4, cmap=cmap)
    plt.xlim(xx1.min(), xx1.max())
    plt.ylim(xx2.min(), xx2.max())

    for idx, cl in enumerate(np.unique(y)):
        plt.scatter(x=X[y == cl, 0], y=X[y == cl, 1],
                    alpha=0.8, c=cmap(idx),
                    marker=markers[idx], label=cl)

    # highlight test samples
    if test_idx:
        # plot all samples
        if not versiontuple(np.__version__) >= versiontuple('1.9.0'):
            X_test, y_test = X[list(test_idx), :], y[list(test_idx)]
            warnings.warn('Please update to NumPy 1.9.0 or newer')
        else:
            X_test, y_test = X[test_idx, :], y[test_idx]

        plt.scatter(X_test[:, 0],
                    X_test[:, 1],
                    c='',
                    alpha=1.0,
                    linewidths=1,
                    marker='o',
                    s=55, label='test set')
  • plot the result
X_combined_std = np.vstack((X_train_std, X_test_std))
y_combined = np.hstack((y_train, y_test))

plot_decision_regions(X=X_combined_std, y=y_combined,
                      classifier=ppn, test_idx=range(105, 150))
plt.xlabel('petal length [standardized]')
plt.ylabel('petal width [standardized]')
plt.legend(loc='upper left')

plt.tight_layout()
# plt.savefig('./figures/iris_perceptron_scikit.png', dpi=300)
plt.show()
  • The three flower classes cannot be perfectly separated by a linear decision boundary
  • The perceptron algorithm never converges on datasets that aren't perfectly linearly separable
    • this is why the perceptron algorithm is typically not recommended in practice

Tips

  • the cross_validation module has been deprecated; use the model_selection module instead (replacement import shown after the warning below)
/Users/takayuki-watanabe/.pyenv/versions/3.5.1/lib/python3.5/site-packages/sklearn/cross_validation.py:44: DeprecationWarning: This module was deprecated in version 0.18 in favor of the model_selection module into which all the refactored classes and functions are moved. Also note that the interface of the new CV iterators are different from that of this module. This module will be removed in 0.20.
"This module will be removed in 0.20.", DeprecationWarning)

Question?

  • sklearn.cross_validation
    • http://ogrisel.github.io/scikit-learn.org/sklearn-tutorial/modules/classes.html#module-sklearn.cross_validation
  • what is train_test_split?
    • a function that splits arrays or matrices into random train and test subsets
    • http://ogrisel.github.io/scikit-learn.org/sklearn-tutorial/modules/generated/sklearn.cross_validation.train_test_split.html#sklearn.cross_validation.train_test_split
  • what is Cross-validation ?
    • https://en.wikipedia.org/wiki/Cross-validation_(statistics)
    • in statistics, a technique of partitioning sample data, first analyzing one part and then using the remaining part to test that analysis, in order to validate the analysis itself (see the sketch after this list)
  • Why is feature scaling also necessary?
    • Chapter 2 gives an example with gradient descent
  • the plot_decision_regions function
    • a function implemented with matplotlib.colors.ListedColormap and matplotlib.pyplot
  • What is a linearly separable dataset?
    • http://hokuts.com/2015/11/24/ml1_func/
    • one that can be separated by a straight line in 2-dimensional space, or by a plane in 3-dimensional space
  • the n_iter parameter is the number of epochs (the number of passes over the training dataset)
  • What is a decision region? → the region of the feature space in which a classifier predicts the same class label
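A minimal cross-validation sketch (not from the book; it assumes ppn, X_train_std, and y_train from the snippets above, plus the newer model_selection module):

from sklearn.model_selection import cross_val_score

# 5-fold cross-validation: fit on 4 folds, score on the held-out fold,
# rotating so every fold serves as the validation part once
scores = cross_val_score(ppn, X_train_std, y_train, cv=5)
print('CV accuracy: %.2f +/- %.2f' % (scores.mean(), scores.std()))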
