TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM,相信很多沒有經(jīng)驗(yàn)的人對(duì)此束手無(wú)策,為此本文總結(jié)了問題出現(xiàn)的原因和解決方法,通過這篇文章希望你能解決這個(gè)問題。

創(chuàng)新互聯(lián)公司是一家業(yè)務(wù)范圍包括IDC托管業(yè)務(wù),虛擬空間、主機(jī)租用、主機(jī)托管,四川、重慶、廣東電信服務(wù)器租用,資陽(yáng)主機(jī)托管,成都網(wǎng)通服務(wù)器托管,成都服務(wù)器租用,業(yè)務(wù)范圍遍及中國(guó)大陸、港澳臺(tái)以及歐美等多個(gè)國(guó)家及地區(qū)的互聯(lián)網(wǎng)數(shù)據(jù)服務(wù)公司。

今天要說的是線性可分情況下的支持向量機(jī)的實(shí)現(xiàn),如果對(duì)于平面內(nèi)的點(diǎn),支持向量機(jī)的目的是找到一條直線,把訓(xùn)練樣本分開,使得直線到兩個(gè)樣本的距離相等,如果是高維空間,就是一個(gè)超平面。

然后我們簡(jiǎn)單看下對(duì)于線性可分的svm原理是啥,對(duì)于線性模型:

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

訓(xùn)練樣本為

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

標(biāo)簽為:

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

如果

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

那么樣本就歸為正類, 否則歸為負(fù)類。

這樣svm的目標(biāo)是找到W(向量)和b,然后假設(shè)我們找到了這樣的一條直線,可以把數(shù)據(jù)分開,那么這些數(shù)據(jù)到這條直線的距離為:

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

然后我們把超平面兩邊到超平面的距離叫做間隔(margin),優(yōu)化目標(biāo)是使得這個(gè)margin最大,使得這樣得到的超平面具有良好的泛化能力(用別的數(shù)據(jù)也能正確分類),

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

SVM的優(yōu)化目標(biāo)是:

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

條件是:

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

注意這里,因?yàn)閠n可以取+1和-1,當(dāng)取-1的時(shí)候,不等式兩邊都會(huì)乘以-1,所以不等號(hào)的方向會(huì)變。求解這個(gè)優(yōu)化問題(二次規(guī)劃),可以用拉格朗日乘子法,其中alpha是拉格朗日乘子。

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

對(duì)w和b求導(dǎo),可以得到:

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

然后把這個(gè)求解的結(jié)果代到上面的L里面,可以得到L的對(duì)偶形式,得L~:

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

對(duì)偶形式的條件是:

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

然后將開始的那個(gè)線性模型中的參數(shù)W用核函數(shù)代替得到:

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

上面L的對(duì)偶形式,就是一個(gè)簡(jiǎn)單的二次規(guī)劃問題,可以利用KKT條件求解:

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

然后把上面的y(xn)帶入到這個(gè)等式里面,就得到下面這個(gè)式子:

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

求解上式,得b為:

TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

其中Ns表示的就是支持向量,K(Xn,Xm)表示核函數(shù)。

下面舉個(gè)核函數(shù)的栗子,對(duì)于二維平面內(nèi)的點(diǎn),TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM

花了兩個(gè)多小時(shí),終于算是把代碼調(diào)通了,雖然不難,但是還是覺得自己水平有限,實(shí)現(xiàn)起來(lái)還是會(huì)有很多問題

import numpy as np
import tensorflow as tf
from sklearn import datasets


x_data = tf.placeholder(shape=[None, 2], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)

# 獲得batch大小的數(shù)據(jù)
def gen_data(batch_size):
   iris = datasets.load_iris()
   iris_X = np.array([[x[0], x[3]] for x in iris.data])
   iris_y = np.array([1 if y == 0 else -1 for y in iris.target])
   train_indices = np.random.choice(len(iris_X),
               int(round(len(iris_X) * 0.8)), replace=False)
   train_x = iris_X[train_indices]
   train_y = iris_y[train_indices]
   rand_index = np.random.choice(len(train_x), size=batch_size)
   batch_train_x = train_x[rand_index]
   batch_train_y = np.transpose([train_y[rand_index]])
   test_indices = np.array(
       list(set(range(len(iris_X))) - set(train_indices)))
   test_x = iris_X[test_indices]
   test_y = iris_y[test_indices]
   return batch_train_x, batch_train_y, test_x, test_y

# 定義模型
def svm():
   A = tf.Variable(tf.random_normal(shape=[2, 1]))
   b = tf.Variable(tf.random_normal(shape=[1, 1]))
   model_output = tf.subtract(tf.matmul(x_data, A), b)
   l2_norm = tf.reduce_sum(tf.square(A))
   alpha = tf.constant([0.01])
   classification_term = tf.reduce_mean(tf.maximum(0.,
        tf.subtract(1., tf.multiply(model_output, y_target))))
   loss = tf.add(classification_term, tf.multiply(alpha, l2_norm))
   my_opt = tf.train.GradientDescentOptimizer(0.01)
   train_step = my_opt.minimize(loss)
   return model_output, loss, train_step


def train(sess, batch_size):
   print("# Training loop")

   for i in range(100):
       x_vals_train, y_vals_train,\
       x_vals_test, y_vals_test = gen_data(batch_size)
       model_output, loss, train_step = svm()

       init = tf.global_variables_initializer()
       sess.run(init)

       prediction = tf.sign(model_output)
       accuracy = tf.reduce_mean(tf.cast(
           tf.equal(prediction, y_target), tf.float32))
       sess.run(train_step, feed_dict=
                                  {
                                   x_data: x_vals_train,
                       y_target: y_vals_train
                                   })

       train_loss = sess.run(loss, feed_dict=
                                          {
                                          x_data: x_vals_train,
                            y_target: y_vals_train
                                          })
       train_acc = sess.run(accuracy, feed_dict=
                                           {
                                           x_data: x_vals_train,
                             y_target: y_vals_train
                                           })

       test_acc = sess.run(accuracy, feed_dict=
                                       {
                                 x_data: x_vals_test,
                      y_target: np.transpose([y_vals_test])
                                       })

       if i % 10 == 1:
           print("train loss: {:.6f}, train accuracy : {:.6f}".
                 format(train_loss[0], train_acc))
           print
           print("test accuracy : {:.6f}".format(test_acc))
           print("- * - "*15)


def main(_):
   with tf.Session() as sess:
       train(sess, batch_size=16)


if __name__ == "__main__":
   tf.app.run()
   
   

總結(jié)一下,SVM里面的坑,首先要知道SVM的目的找到一條線或者超平面,然后會(huì)計(jì)算點(diǎn)到超平面的距離,然后把這個(gè)距離轉(zhuǎn)化為一個(gè)二次規(guī)劃問題,然后就是使用拉格朗日方法求解這個(gè)優(yōu)化問題,最后會(huì)涉及核函數(shù)方法。

看完上述內(nèi)容,你們掌握TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM的方法了嗎?如果還想學(xué)到更多技能或想了解更多相關(guān)內(nèi)容,歡迎關(guān)注創(chuàng)新互聯(lián)行業(yè)資訊頻道,感謝各位的閱讀!

網(wǎng)頁(yè)標(biāo)題:TensorFlow如何實(shí)現(xiàn)線性支持向量機(jī)SVM
分享路徑:http://bm7419.com/article20/ijhsjo.html

成都網(wǎng)站建設(shè)公司_創(chuàng)新互聯(lián),為您提供網(wǎng)站營(yíng)銷、Google服務(wù)器托管、品牌網(wǎng)站設(shè)計(jì)、網(wǎng)站排名域名注冊(cè)

廣告

聲明:本網(wǎng)站發(fā)布的內(nèi)容(圖片、視頻和文字)以用戶投稿、用戶轉(zhuǎn)載內(nèi)容為主,如果涉及侵權(quán)請(qǐng)盡快告知,我們將會(huì)在第一時(shí)間刪除。文章觀點(diǎn)不代表本網(wǎng)站立場(chǎng),如需處理請(qǐng)聯(lián)系客服。電話:028-86922220;郵箱:631063699@qq.com。內(nèi)容未經(jīng)允許不得轉(zhuǎn)載,或轉(zhuǎn)載時(shí)需注明來(lái)源: 創(chuàng)新互聯(lián)

搜索引擎優(yōu)化