Welcome to ShenZhenJia Knowledge Sharing Community for programmer and developer-Open, Learning and Share
menu search
person
Welcome To Ask or Share your Answers For Others

Categories

I am following this tutorial: https://www.rpubs.com/loveb/som. It shows how to use a Kohonen network (also called a self-organizing map, or SOM, a type of machine learning algorithm) on the iris data.

I ran this code from the tutorial:

library(kohonen) #fitting SOMs
library(ggplot2) #plots
library(GGally) #plots
library(RColorBrewer) #colors, using predefined palettes

# Keep only complete rows, then drop duplicated rows before fitting.
iris_complete <- iris[complete.cases(iris), ]
iris_unique <- unique(iris_complete) # Remove duplicates

# Scale the four numeric measurements. Species is a factor and cannot be
# scaled; it is only used later in predictive SOMs fitted with xyf().
iris.sc <- scale(iris_unique[, 1:4])

# Build a 10 x 10 toroidal hexagonal grid of neurons.
iris.grid <- somgrid(xdim = 10, ydim = 10, topo = "hexagonal", toroidal = TRUE)

set.seed(33) # for reproducibility
iris.som <- som(iris.sc, grid = iris.grid, rlen = 700,
                alpha = c(0.05, 0.01), keep.data = TRUE)

# Plot 1: how many observations were mapped to each neuron.
plot(iris.som, type = "count")

# Plot 2: the trained code value of one input variable at every neuron.
var <- 1 # column index of the variable to plot
plot(iris.som, type = "property", property = getCodes(iris.som)[, var],
     main = colnames(getCodes(iris.som))[var], palette.name = terrain.colors)

The above code fits a Kohonen network to the iris data. Each observation in the data set is assigned to one of the colored circles (also called "neurons") shown in the pictures below.

My question: in these plots, how can I identify which observations were assigned to which circles? For example, suppose I want to know which observations belong to the circles outlined with the black triangles below:

enter image description here enter image description here

Is it possible to do this? Right now, I am trying to use iris.som$unit.classif to trace which points are in which circle. Is there a better way to do this?

UPDATE: @Jonny Phelps showed me how to identify observations within a triangular region (see answer below). But I am still not sure whether it is possible to identify irregularly shaped regions, e.g.: enter image description here

In a previous post (Labelling Points on a Plot (R Language)), a user showed me how to assign arbitrary numbers to each circle on the grid:

enter image description here

Based on the above plot, how could you use iris.som$unit.classif to find out which observations were assigned to circles 92, 91, 82, 81, 72 and 71?

Thanks

See the original question and answers for more detail.

与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
thumb_up_alt 0 like thumb_down_alt 0 dislike
546 views
Welcome To Ask or Share your Answers For Others

1 Answer

EDIT: Now with Shiny App!

A plotly solution is also possible: you can hover over individual neurons to display the row names of the iris observations assigned to them (called id here). Based on your iris.som data and Jonny Phelps' grid approach, you can assign the row numbers to the individual neurons as concatenated strings and have them shown on hover:

library(ggplot2)
library(plotly)

# One row per observation: the neuron it was mapped to (g) and its label
# (sample). The label is the observation's row name in the SOM training data
# when one was kept (presumably the original iris row number, which survives
# the duplicate removal -- confirm); otherwise rownames() falls back to the
# observation's position ("1", "2", ...).
ga <- data.frame(
  g = iris.som$unit.classif,
  sample = rownames(iris.som$data[[1]], do.NULL = FALSE, prefix = "")
)

# One row per neuron with its plotting position (x, y) plus grid metadata.
grid_pts <- as.data.frame(iris.som$grid$pts)
# Units are numbered with the column index cycling fastest. The original used
# rep(..., by = ydim): `by` is not a rep() argument and was silently ignored,
# so it only worked because data.frame recycled the short vector.
grid_pts$column <- rep(seq_len(iris.som$grid$xdim), times = iris.som$grid$ydim)
grid_pts$row <- rep(seq_len(iris.som$grid$ydim), each = iris.som$grid$xdim)
grid_pts$classif <- seq_len(nrow(grid_pts))
# Comma-separated labels of the observations assigned to each neuron
# ("" for empty neurons); shown in the plotly tooltip.
grid_pts$id <- vapply(
  grid_pts$classif,
  function(unit) paste(ga$sample[ga$g == unit], collapse = ", "),
  character(1)
)
# Observations per neuron, counted in one pass (empty neurons get 0), then
# turned into a factor with every level 0..max so colours are discrete.
grid_pts$count <- tabulate(ga$g, nbins = nrow(grid_pts))
grid_pts$count <- factor(grid_pts$count, levels = 0:max(grid_pts$count))

# The non-standard aesthetics (row, column, id) are ignored by ggplot2 but
# are carried into the plotly hover tooltip by ggplotly().
p1 <- ggplot(grid_pts, aes(x = x, y = y, colour = count, row = row,
                           column = column, id = id)) +
  geom_point(size = 8) +
  # Palette named by count level: empty neurons ("0") are always grey and
  # counts 1..max get heat colours, whichever counts actually occur.
  scale_colour_manual(values = setNames(
    c("grey50", heat.colors(nlevels(grid_pts$count) - 1L)),
    levels(grid_pts$count)
  )) +
  theme_void() +
  theme(plot.margin = unit(c(1, rep(0.3, 3)), "cm"))
ggplotly(p1)

Here is a full Shiny app that allows lasso selection and shows a table of the selected neurons:

# Attach every package the app uses. library() stops with an informative
# error when a package is missing; require() only returns FALSE, and that
# value was discarded, so a missing package surfaced later as a confusing
# "could not find function" error.
invisible(lapply(
  c("shiny", "dplyr", "ggplot2", "plotly", "kohonen", "GGally", "DT"),
  function(pkg) {
    suppressPackageStartupMessages(library(pkg, character.only = TRUE))
  }
))

# Keep only complete rows, then drop duplicated rows before fitting.
iris_complete <- iris[complete.cases(iris), ]
iris_unique <- unique(iris_complete) # Remove duplicates

# Scale the four numeric measurements. Species is a factor and cannot be
# scaled; it is only used later in predictive SOMs fitted with xyf().
iris.sc <- scale(iris_unique[, 1:4])

# Build a 10 x 10 toroidal hexagonal grid of neurons.
iris.grid <- somgrid(xdim = 10, ydim = 10, topo = "hexagonal", toroidal = TRUE)

set.seed(33) # for reproducibility
iris.som <- som(iris.sc, grid = iris.grid, rlen = 700,
                alpha = c(0.05, 0.01), keep.data = TRUE)

# One row per observation: the neuron it was mapped to (g) and its label
# (sample). The label is the observation's row name in the SOM training data
# when one was kept (presumably the original iris row number, which survives
# the duplicate removal -- confirm); otherwise rownames() falls back to the
# observation's position ("1", "2", ...).
ga <- data.frame(
  g = iris.som$unit.classif,
  sample = rownames(iris.som$data[[1]], do.NULL = FALSE, prefix = "")
)

# One row per neuron with its plotting position (x, y) plus grid metadata.
grid_pts <- as.data.frame(iris.som$grid$pts)
# Units are numbered with the column index cycling fastest. `times =` is
# required here: rep() silently ignores an unknown `by =` argument.
grid_pts$column <- rep(seq_len(iris.som$grid$xdim), times = iris.som$grid$ydim)
grid_pts$row <- rep(seq_len(iris.som$grid$ydim), each = iris.som$grid$xdim)
grid_pts$classif <- seq_len(nrow(grid_pts))
# Comma-separated labels of the observations assigned to each neuron
# ("" for empty neurons); shown in the tooltip and the selection table.
grid_pts$id <- vapply(
  grid_pts$classif,
  function(unit) paste(ga$sample[ga$g == unit], collapse = ", "),
  character(1)
)
# Observations per neuron, counted in one pass (empty neurons get 0), then
# turned into a factor with every level 0..max so colours are discrete.
grid_pts$count <- tabulate(ga$g, nbins = nrow(grid_pts))
grid_pts$count <- factor(grid_pts$count, levels = 0:max(grid_pts$count))

# Shiny app, adapted from https://gist.github.com/dgrapov/128e3be71965bf00495768e47f0428b9

# Page layout: the neuron-grid plot on top and, below it, the table of
# neurons inside the current lasso selection. Each spans all 12 columns.
ui <- local({
  plot_panel <- column(width = 12,
                       plotlyOutput(outputId = "plot", height = "600px"))
  table_panel <- column(width = 12,
                        DT::dataTableOutput(outputId = "data_table"))
  fluidPage(fluidRow(plot_panel, table_panel))
})


# Shiny server: draws the neuron grid (global `grid_pts`) as a lasso-selectable
# plotly chart, highlights the selected neurons in blue, and lists them in a
# filterable table.
server <- function(input, output) {

  # Points from the latest lasso/box selection on the plot registered with
  # source "p1"; NULL before anything has been selected.
  selected <- reactive({
    event_data("plotly_selected", source = "p1")
  })

  # list(data = every neuron, sel = neurons whose (x, y) position matches a
  # selected point). Positions are compared as "x_y" strings. Named
  # grid_data rather than data to avoid masking utils::data().
  grid_data <- reactive({
    sel <- tryCatch(
      {
        sel_points <- selected()
        # NULL selection -> character(0) keys -> zero matching neurons.
        sel_keys <- paste(sel_points$x, sel_points$y, sep = "_")
        dplyr::filter(grid_pts, paste(x, y, sep = "_") %in% sel_keys)
      },
      error = function(e) {
        # Previously this returned NULL silently, which then crashed the
        # nrow() check in the plot. Report it and use an empty selection.
        warning("Could not match the plotly selection to grid units: ",
                conditionMessage(e), call. = FALSE)
        grid_pts[0, , drop = FALSE]
      }
    )
    list(data = grid_pts, sel = sel)
  })

  output$plot <- renderPlotly({
    req(grid_data())
    plot_df <- grid_data()$data

    # Non-standard aesthetics (classif, row, column, id) are ignored by
    # ggplot2 but end up in the plotly hover tooltip.
    p <- ggplot(plot_df, aes(x = x, y = y, classif = classif, colour = count,
                             row = row, column = column, id = id)) +
      geom_point(size = 8) +
      # Palette named by count level, built from the plotted data: empty
      # neurons ("0") are always grey, counts 1..max get heat colours.
      scale_colour_manual(values = setNames(
        c("grey50", heat.colors(nlevels(plot_df$count) - 1L)),
        levels(plot_df$count)
      )) +
      theme_void() +
      theme(plot.margin = unit(c(1, rep(0.3, 3)), "cm"))

    sel <- grid_data()$sel
    if (!is.null(sel) && nrow(sel) > 0) {
      # Overlay the selected neurons as smaller blue points.
      p <- p + geom_point(
        data = sel,
        mapping = aes(x = x, y = y, classif = classif, count = count,
                      row = row, column = column, id = id),
        color = "blue", size = 5, inherit.aes = FALSE
      )
    }
    # Namespace-qualified so graphics::layout can never be picked up instead.
    plotly::layout(ggplotly(p, source = "p1"), dragmode = "lasso")
  })

  output$data_table <- DT::renderDataTable(
    grid_data()$sel,
    filter = "top",
    options = list(pageLength = 5, autoWidth = TRUE)
  )
}

# Launch the app from the page layout and server logic defined above.
shinyApp(ui = ui, server = server)


与恶龙缠斗过久,自身亦成为恶龙;凝视深渊过久,深渊将回以凝视…
thumb_up_alt 0 like thumb_down_alt 0 dislike
Welcome to ShenZhenJia Knowledge Sharing Community for programmer and developer-Open, Learning and Share
...