ラズパイ4 scikit-learn で機械学習(SVM)を学ぶ

概要

 前回までの投稿では、検出した電話着信音の周波数成分を一定のルールに基づいて判定しました。今回は、2つの電話着信音の分類にscikit-learn の認識モデルであるSVMを利用します。

 ※ scikit-learn : 機械学習ライブラリの名称
 ※ SVM : Support Vector Machine

SVM (調査内容)

 機械学習における分類と回帰を取り扱い、教師あり学習を用いるパターン認識モデルの一つ。基本は完全に直線・平面・超平面で線形分離可能なデータをマージン最大化という考え方に基づき2つのクラスに分類する境界を求める。但し、ソフトマージンによる線形分離出来ないデータの境界設定、曲線・曲面境界設定、複数クラスへの分類等にも対応するとのこと。 
 尚、教師あり学習は、膨大な入力データ(特徴値)と正解を学習させることで、ロジック自身を改良する方法である。

SCIKIT-LEARN インストール

 scikit-learn をインストールします。

sudo apt-get update
sudo apt-get install python-sklearn

  

プログラム【基礎】

 簡単なプログラムで学習モデル作成・保存方法を確認します。

① 学習モデルの作成・保存

 行番2・3はライブラリをインポートします。‘pickle’ は作成した学習モデルを保存する為に使います。
 行番6が入力データ(特徴値)で、行番7は入力データに対応する正解となります。それぞれ2つの特徴値を持つデータから2つのクラス(‘0’ or ‘1’)に分類します。行番8で ‘clf’ という学習モデルの学習アルゴリズムを設定し、行番9で学習モデルにデータセットを渡し、学習させます。行番12の検証データを用いて、行番13で学習モデル ‘clf’ による予測結果を得ます。 行番16・17で予測結果を出力します。
 行番21・22は、再利用目的で学習モデル ‘clf’ を保存します。

# -*- coding: utf-8 -*-
from sklearn import svm
import pickle

# 学習
train_data = [[0, 0], [0, 1], [1, 0], [1, 1]]   # 入力データ
train_label = [0, 0, 1, 1]                      # ラベル(正解)
clf = svm.SVC(C=10, gamma=0.1)                  # アルゴリズムを指定
clf.fit(train_data,train_label)                 # 学習

# テストデータに対して予測
test_data = [[0.6, 0], [1, 0], [0, 1], [0, 1]]  # テストデータ
test_label = clf.predict(test_data)             # 結果(予測値)

# テスト結果の表示
print("入力データ:{0}".format(test_data))
print("予測結果  :{0}".format(test_label))
print("")

# 学習モデル保存
with open('model.pickle',mode='wb') as f :
    pickle.dump(clf , f , protocol = 2 )

② 学習モデルの再利用

 ①で作成した学習モデルを再利用します。行番2で学習モデルを読み込む際に必要な ‘pickle’ をインポートします。行番5・6 で学習の済んだ認識モデルを呼び込みます。行番8 以降は呼び込んだ認識モデルを使用し、入力データに対する予測を行っています。

# -*- coding: utf-8 -*-
import pickle

# 学習モデルロード
with open('model.pickle',mode='rb') as f :
    clf = pickle.load(f)

# テストデータに対して予測(その1)
test_data = [[0.6, 0], [1, 0.9], [0, 1], [0, 1]]    # テストデータ
test_label = clf.predict(test_data)                 # 結果(予測値)
# 予測結果表示
print("入力データ:{0}".format(test_data))
print("予測結果  :{0}".format(test_label))
print("")

# テストデータに対して予測(その2)
test_data = [[0.2, 0], [0.5, 0.9]]                  # テストデータ
test_label = clf.predict(test_data)                 # 結果(予測値)
# 予測結果表示
print("入力データ:{0}".format(test_data))
print("予測結果  :{0}".format(test_label))
print("")

# テストデータに対して予測(その3)
inp_1 = float(input("数字(0--1)入力: "))
inp_2 = float(input("数字(0--1)入力: "))

temp_data = []
temp_data.append(inp_1)
temp_data.append(inp_2)

test_data = []
test_data.append(temp_data)                         # テストデータ

test_label = clf.predict(test_data)                 # 結果(予測値)
print("入力データ:{0}".format(test_data))
print("予測結果  :{0}".format(test_label))
print("")

  

プログラム【発展】

① 着信音検出用 学習モデル作成

 前回投稿で、各着信音のピーク周波数を次の様にCSV形式でファイルしました。今回、この着信音[01]・[02]のCSVデータを入力データ(特徴値)、各着信音に対するラベル(正解)を[0][1]として認識モデルの学習をします。

◆着信音[01]ピーク周波数 CSVファイル保存例◆
2654.639205460991,5303.6797234348
2654.639205460991,5299.373040752351
2653.777868924501,5300.234377288841
2648.1791814373187,5302.81838689831
2654.639205460991,5300.234377288841
2654.639205460991,5304.110391703044
 <以下、省略(全201行)>

 行番7〜30で 着信音[01][02]の CSVファイルを続けて読み込んで、入力データを配列 ‘dt_frq_arr’ に、正解を 配列 ‘dt_rtn_arr’ に格納します。
 配列格納データを利用し、行番33・34で学習し、行番46・47で学習結果をファイル保存します。

# -*- coding: utf-8 -*-
from sklearn import svm
import numpy as np
import pickle
import csv

csv_fil = ['SVM_01.csv','SVM_02.csv']                               # CSVファイル名
dt_frq_arr = np.empty((0,2) , int)                                  # 入力データ格納 numpy.ndarray
dt_rtn_arr = np.empty((0) , int)          # ラベル(正解)格納配列

for f_num in range(len(csv_fil)) :                                  # ファイル数分ループ処理
    dat_frg = []                                                    # 一時保存リスト
    with open('/home/pi/MySndTest/csv/' + csv_fil[f_num]) as f :    # CSVファイル開く
        reader = csv.reader(f)
        for row in reader :                                         # ファイル行ループ
            if len(row)>=2 :                                        # 1行に2データ以上を対象
                tmp=[]
                tmp.append(int(float(row[0])))                      # 文字列をintに変換
                tmp.append(int(float(row[1])))                      # 文字列をintに変換  
                dat_frg.append(tmp)                                 # 一時保存リストに追加

                dt_rtn_arr = np.append(dt_rtn_arr,f_num)
                

    dt_frq = np.array(dat_frg)                                      # numpy.ndarrayに変換 
    dt_frq_ave = np.average(dt_frq , axis=0)                        # 平均
    dt_frq_max = np.max(dt_frq , axis=0)                            # 最大
    dt_frq_min = np.min(dt_frq , axis=0)                            # 最小

    dt_frq_arr = np.append( dt_frq_arr , dt_frq , axis=0 )

# 学習
clf = svm.SVC(C=10 , gamma=0.1)                                 # アルゴリズムを指定
clf.fit(dt_frq_arr , dt_rtn_arr)                                # 学習

# テストデータに対して予測
test_data = [[2858,5684],[2841,5685],[2660,5296],[2642,5311]]   # テストデータ
test_label = clf.predict(test_data)                             # 結果(予測値)

# テスト結果の表示
print("入力データ:{0}".format(test_data))
print("予測結果  :{0}".format(test_label))
print("")

# 学習モデル保存
pickle.dump(clf , open('frq_feature.pickle',mode='wb'))
print("finished")

② 着信音検出用 学習モデル試行

 保存した学習モデルの簡易テストに使用しました。

# -*- coding: utf-8 -*-
import pickle

# 学習モデルロード
clf = pickle.load(open('frq_feature.pickle',mode='rb'))

# テストデータに対して予測(その1)
test_data = [[2858,5684],[2841,5685],[2660,5296],[2642,5311]]   # テストデータ
test_label = clf.predict(test_data)                             # 結果(予測値)
# 予測結果表示
print("入力データ:{0}".format(test_data))
print("予測結果  :{0}".format(test_label))
print("")

# テストデータに対して予測(その2)
inp_1 = float(input("周波数入力: "))
inp_2 = float(input("周波数入力: "))

temp_data = []
temp_data.append(inp_1)
temp_data.append(inp_2)

test_data = []
test_data.append(temp_data)                         # テストデータ

test_label = clf.predict(test_data)                 # 結果(予測値)
print("入力データ:{0}".format(test_data))
print("予測結果  :{0}".format(test_label))
print("")

③ 着信音検出プログラム

 11行目で保存した学習モデルを読み込みます。行番71でデータ(特徴値)を入力し、結果(予測値)を取得します。雑音の影響を受けにくい様に、行番57で各周波数成分の大きさ等でフィルタリングしています。

import pyaudio
import wave
import numpy as np
import matplotlib.pyplot as plt
import csv
import os
from scipy.signal import argrelmax
import pickle


clf = pickle.load(open('frq_feature.pickle',mode='rb')) # 学習モデルロード

frm_dt = pyaudio.paInt16                                # 16-bit resolution
chnl = 1                                                # 1 channel

s_rate = 44100                                          # 44.1kHz サンプリング周波数
N = 100                                                 # chunk / s_rate = samp_time (sec)
chunk = 1024 * N                                        # 2^12 一回の取得データ数 (1024*50/44100= 1.16 (sec), 1024*100/44100= 2.32 (sec))
dev_idx = 5                                             # デバイス番号(必須ではない?(優先度指定している為?))

audio = pyaudio.PyAudio()                               # create pyaudio instantiation

# ストリーム作成
stream = audio.open(format = frm_dt , rate = s_rate , channels = chnl, input_device_index = dev_idx , \
                    input = True , frames_per_buffer = chunk )

# x軸( frquency )
wv_x1 = np.linspace(0, s_rate, chunk)
chnk_lmt = int(chunk/2)                                 # FFT有効範囲考慮
wv_x2 = wv_x1[0:chnk_lmt]



# 音声取得/FFT処理
print("Started")
while True:
    try:
        # 音声データ取得
        data = stream.read(chunk , exception_on_overflow=False )
        ndarray = np.frombuffer(data, dtype='int16')
        
        # y軸(FFT : amplitude)
        wv_y1 = np.fft.fft(ndarray)
        wv_y1 = np.abs(wv_y1)
        wv_y2 = wv_y1[0:chnk_lmt]

        peak_args = argrelmax(wv_y2,order=200)                  # ピーク検出(argrelmax:極大値インデックス ,order:次極大値判定範囲/間隔(最小:int 1))

        # peak_args(index) >降順sort> peak_args_sort(index)に格納 # ◆未使用(参考)◆
        f_peak = wv_y2[peak_args]                               # インデックスを強度値に変換
        f_peak_argsort = f_peak.argsort()[::-1]                 # 強度降順並替(argsort:値昇順ソート(インデックス返す),[::-1]降順ソート時)
        peak_args_sort = peak_args[0][f_peak_argsort]           # 強度降順並替後のインデックス配列を指定し、強度降順配列を取得

        # 検出ピークの対象を絞り込む
        wv_x3 = wv_x2[peak_args]                                 
        wv_y3 = wv_y2[peak_args]
        idx_tgt = np.where((wv_y3>0.2*10**8) & (wv_x3>=2500))   # 有効範囲設定

        wv_x4 = wv_x3[idx_tgt]
        wv_y4 = wv_y3[idx_tgt]

        # ピークデータ検出時、判定処理
        msg = ""
        if len(wv_x4) > 1:
            temp_data = []
            temp_data.append(int(wv_x4[0]))
            temp_data.append(int(wv_x4[1]))
            
            inp_data = []
            inp_data.append(temp_data)                         # 入力データ(検出周波数)
            rtn_label = clf.predict(inp_data)                  # 結果(予測値)
            msg = "["+("00"+str(rtn_label[0]+1))[-2:]+"]" 
            
        # グラフ表示データ設定
        plt.plot(wv_x2 , wv_y2)                                 # FFT結果 
        plt.plot(wv_x4 , wv_y4 ,'ro')                           # ピーク値

        # タイトル(判定結果表示)
        if msg == "" :
            plt.title("No phone is detected." , fontsize=16 , color='black')         
        else :
            plt.title(msg + " is detected.", fontsize=22 , color='red') 

        plt.xlabel("frquency [Hz]")                             # X軸ラベル
        plt.ylabel("amplitude")                                 # Y軸ラベル
        plt.xlim(0, 10000)                                      # X軸表示範囲
        plt.ylim(0, 3 * 10 ** 8)                                # Y軸表示範囲
        
        plt.minorticks_on()                                     # 補助目盛り表示
        plt.grid(which="major", color="black", alpha=0.5)       # 目盛り線の表示
        plt.grid(which="minor", color="gray", linestyle=":")    # 補助目盛り線表示
        
        plt.draw()                                              # リアルタイム更新
        plt.pause(0.0001)                                       # リアルタイム更新
        plt.cla()                                               # 現在軸クリア
        
    except KeyboardInterrupt:
        print("Ctrl+Cで停止しました")
        break

print("Finished")
plt.clf()
plt.close()

# ストリーム終了
stream.stop_stream()
stream.close()
audio.terminate()

   

検出状況(動画)

 プログラム【発展】ー ③ 着信音検出プログラムの実行状況(動画)です。雑音の影響を受けにくい場所では、安定した結果を得ることが出来ます。(動画に音声はつけていません。)

   

まとめ

 実績データを入力・学習させると、判定ロジックを自動で作成してくれるなんてすばらしい。
 でも、まったくかけ離れたデータを入力してもどちらかに判定するのだと思います。雑音の影響を少なくしたり、ある程度入力前にデータをチェックするなどの検討も必要だと思います。
 私としては機械学習については初めてのことなので、これからもっと良い方法に出会えると思います。


 

コメントを残す

メールアドレスが公開されることはありません。 * が付いている欄は必須項目です