Beyond the State-of-the-Art

最先端を超えたいと思ってる(大嘘)エンジニアのブログ

Go初心者がGoでk-meansを実装してみた

Qiitaからの移植です。

Goの勉強として、クラスタリングの手法の1つであるk-meansを実装してみました。今回は簡単のため、2次元平面上に点の集まりをクラスタリングしました。

Go初心者なので変な書き方があるかと思いますが、そのときはコメントでご指摘お願いします。

データ作り

こんな感じのデータを作りました。3つのクラスターに分かれていることが目視で確認できます。

f:id:renor:20200327002300p:plain

グラフはGnuplotで作りました。Go言語のプログラムでデータをファイルで出力して、それをGnuplotで読み込ませました。データは乱数を使って作りました。データづくりのコードは次の通りです。

package main

import (
    "fmt"
    "os"
    "math/rand"
)

type XY struct{ X, Y float64}

func main() {
    // 乱数初期化
    rand.Seed(int64(0))

    // 出力ファイルオープン
    file, err := os.Create("points.dat")
    if err != nil {
        panic(err)
    }
    defer file.Close()

    // データ作って、ファイルに出力
    clusterSize := 100
    writeData(file, createData(XY{0, 0}, clusterSize))
    writeData(file, createData(XY{-4, 4}, clusterSize))
    writeData(file, createData(XY{4, 4}, clusterSize))
}

// ファイルに出力(Gnuplotのデータ形式)
func writeData(file *os.File, data []XY) {
    for _, xy := range data {
        file.WriteString(fmt.Sprintf("%f %f\n", xy.X, xy.Y))
    }
}

func createData(center XY, num int) []XY {
    data := make([]XY, num)
    for i := range data {
        data[i].X = rand.NormFloat64() + center.X
        data[i].Y = rand.NormFloat64() + center.Y
    }
    return data
}

k-means

上で作ったデータを使ってk-meansでクラスタリングします。実装したコードは次の通りです。計算効率はあまり考えずに実装しました。

package main

import (
    "fmt"
    "os"
    "bufio"
    "strings"
    "strconv"
    "math"
    "math/rand"
)


type XY struct { X, Y float64 }

type Cluster []XY

func main() {
    rand.Seed(int64(114514))
    numCluster := 3

    clusters := readData("points.dat", numCluster)

    for n := 0; n < 10; n++ {
        // 重心計算
        centers := make([]XY, 3)
        for i := 0; i < numCluster; i++ {
            centers[i] = centerOfCluster(clusters[i])
        }
        // クラスターの更新
        clusters = updateClusters(clusters, centers, numCluster)
    }

    writeData(clusters, "result.dat", numCluster)
}

// ファイルからデータ読み込み、ついでに初期クラスタに分ける
func readData(fileName string, numCluster int) []Cluster {
    file, err := os.Open(fileName)
    if err != nil {
        panic(err)
    }
    defer file.Close()

    clusters := make([]Cluster, numCluster)

    scanner := bufio.NewScanner(file)
    for scanner.Scan() {
        line := scanner.Text()
        xy := strings.SplitN(line, " ", 2)

        x, _ := strconv.ParseFloat(xy[0], 64)
        y, _ := strconv.ParseFloat(xy[1], 64)
        p := XY{x, y}

        group := rand.Intn(numCluster)
        clusters[group] = append(clusters[group], p)
    }

    return clusters
}

// Gnuplotのデータ形式で書き込み
// 1クラスターのデータを1ブロックに
func writeData(clusters []Cluster, fileName string, numCluster int) {
    file, err := os.Create(fileName)
    if err != nil {
        panic(err)
    }
    defer file.Close()

    for i := 0; i < numCluster; i++ {
        for _, p := range clusters[i] {
            file.WriteString(fmt.Sprintf("%f %f\n", p.X, p.Y))
        }
        file.WriteString("\n")
    }
}

// クラスターの重心を計算
func centerOfCluster(cluster Cluster) XY {
    clusterSize := len(cluster)
    var sumX float64 = 0
    var sumY float64 = 0

    for _, p := range cluster {
        sumX += p.X
        sumY += p.Y
    }

    cX := sumX / float64(clusterSize)
    cY := sumY / float64(clusterSize)

    return XY{cX, cY}
}

// 2点間の距離
func distance(p XY, q XY) float64 {
    d2 := math.Pow(p.X - q.X, 2.0) + math.Pow(p.Y - q.Y, 2.0)
    return math.Sqrt(d2)
}

// クラスターの更新
func updateClusters(clusters []Cluster, centers []XY, numCluster int) []Cluster {
    newClusters := make([]Cluster, numCluster)

    for i := 0; i < numCluster; i++ {
        for _, p := range clusters[i] {
            group := 0
            minDistance := distance(p, centers[0])

            // 距離が最も近いクラスターを探す
            for j := 1; j < numCluster; j++ {
                d := distance(p, centers[j])
                if d < minDistance {
                    group = j
                    minDistance = d
                }
            }

            newClusters[group] = append(newClusters[group], p)
        }
    }

    return newClusters
}

本当は収束判定をしてk-meansの計算を打ち切るかどうかを決めるべきですが、今回は簡略化のため、回数を決め打ちしてk-meansの計算をしています。

実装してて思ったのは、typeで気軽に型に別名を付けられるのは便利ということですね。

さて、上のコードを使ってクラスタリングした結果はこのようになりました。

f:id:renor:20200327002311p:plain

うまくクラスタリングできていますね。めでたし、めでたし。