Basic statistics concepts for machine learning in Scala Spark
Before applying any distribution algorithm, probability density function, or probability mass function, we need to understand some basic concepts of statistics. These concepts may have been taught back in school, so we shall start by brushing them up and implementing them in Spark with Scala. As an overview, I will be covering Mean, Median and Mode, along with Variance and Standard Deviation.
1) Mean
This is probably the simplest concept of all: the average value of a set of discrete numerical values. The formula is simple: (sum of elements) / (number of elements); for example, the mean of [1,2,3,4,5] is 15 / 5 = 3. So let's begin with an example of how to calculate the mean in Spark using Scala.
Now let's download the sample data from here: [Annual enterprise survey: 2018 financial year (provisional) — size bands CSV]
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object Main extends App {
  val spark = SparkSession
    .builder()
    .appName("test")
    .config("spark.master", "local")
    .getOrCreate()

  // Read the CSV with its header row
  var data = spark.read.format("csv")
    .option("header", true)
    .load("/<path to downloaded file>/<name>.csv")

  // Number the rows so the leading unwanted records can be dropped
  data = data.withColumn("rn", row_number().over(Window.orderBy("year")))
  data = data.filter(data("rn") > 2)

  // Compute the mean of the "value" column
  data.select(mean("value") as "mean").show()
}
So I am creating a standalone Spark session and assigning it to a variable, creating a DataFrame from the CSV, generating a row number using Window (imported from org.apache.spark.sql.expressions.Window), removing the unwanted records from the CSV, and finally generating the mean on the last line: data.select(mean("value") as "mean").show()
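To connect the code back to the (sum of elements) / (number of elements) formula, here is a minimal sketch (the object name and the in-memory sample values are my own) that computes the same mean explicitly as a sum divided by a count:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object MeanByFormula extends App {
  val spark = SparkSession
    .builder()
    .appName("mean-formula")
    .config("spark.master", "local")
    .getOrCreate()
  import spark.implicits._

  val data = Seq(1.0, 2.0, 3.0, 4.0, 5.0).toDF("value")

  // (sum of elements) / (number of elements) = 15 / 5 = 3.0
  data.agg((sum("value") / count("value")) as "mean").show()
}
Both forms give the same result; mean("value") is just the built-in shorthand for this division.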
2) Median
The median is the middle value in a sorted set; for example, in [1,2,3,4,5] the value 3 is the median. So it is easy when we have an odd number of values in the list. If we have an even number of values, e.g. [1,2,3,4], the middle falls between 2 and 3, so the median is (2 + 3) / 2 = 2.5.
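Before the Spark version, here is that rule as a small plain-Scala sketch (the function name is my own):
object MedianOfList extends App {
  // Median of a sorted list: the middle element if the length is odd,
  // the average of the two middle elements if it is even
  def median(xs: Seq[Double]): Double = {
    val sorted = xs.sorted
    val n = sorted.length
    if (n % 2 == 1) sorted(n / 2)
    else (sorted(n / 2 - 1) + sorted(n / 2)) / 2.0
  }

  println(median(Seq(1.0, 2.0, 3.0, 4.0, 5.0))) // 3.0
  println(median(Seq(1.0, 2.0, 3.0, 4.0)))      // 2.5
}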
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types.FloatType

object Main extends App {
  val spark = SparkSession
    .builder()
    .appName("test")
    .config("spark.master", "local")
    .getOrCreate()

  var data = spark.read.format("csv")
    .option("header", true)
    .load("/<path to downloaded file>/<name>.csv")

  data = data.withColumn("rn", row_number().over(Window.orderBy("year")))
  data = data.filter(data("rn") > 2)

  // Cast the string column to a numeric type, then take the 0.5 quantile
  val median = data
    .withColumn("units", data("value").cast(FloatType))
    .select("units")
    .stat.approxQuantile("units", Array(0.5), 0.0001)
    .head

  println("==================================")
  println(median)
  println("==================================")
}
The above code fetches the median value. As the value comes from CSV data, its type will be String; unlike the SQL functions, the stat functions will not cast the data, so we have to cast it manually.
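One detail worth knowing about approxQuantile: the last argument is the relative error, and per the Spark documentation, setting it to zero computes the exact quantile (at a higher cost on large data). Also note that approxQuantile returns an actual element of the column, so for an even number of rows it may return one of the two middle values rather than their average. A small sketch (names and sample values are my own) on the [1,2,3,4] example:
import org.apache.spark.sql.SparkSession

object ExactMedian extends App {
  val spark = SparkSession
    .builder()
    .appName("exact-median")
    .config("spark.master", "local")
    .getOrCreate()
  import spark.implicits._

  val data = Seq(1.0, 2.0, 3.0, 4.0).toDF("units")

  // relativeError = 0.0 asks for the exact quantile; expect one of the
  // two middle elements here, not the interpolated 2.5
  val median = data.stat.approxQuantile("units", Array(0.5), 0.0)
  println(median.head)
}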
3) Mode
The number which appears most frequently in a set is called the mode. For example, in [1,1,2,5,6,5,5] the occurrence counts are [1:2, 2:1, 5:3, 6:1], so the mode is 5. Let's see some sample code to find the mode in Apache Spark.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object Main extends App {
  val spark = SparkSession
    .builder()
    .appName("test")
    .config("spark.master", "local")
    .getOrCreate()

  var data = spark.read.format("csv")
    .option("header", true)
    .load("/home/ashrith/data/test.csv")

  data = data.withColumn("rn", row_number().over(Window.orderBy("year")))
  data = data.filter(data("rn") > 2)

  // Drop the confidential "C" markers, then count how often each value occurs
  data = data.filter(data("value") =!= "C").groupBy("value").count()

  // The mode is the value with the highest count
  println(data.orderBy(data("count").desc).first().get(0))
}
The above snippet finds the mode value from the data. In Spark it is not straightforward: first we group the values and generate a count of occurrences, then order by that count in descending order and fetch the first row.
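One caveat with first(): when several values share the highest count (e.g. [1,1,2,5,6,5], where 1 and 5 both appear twice), it returns only one of the modes. Here is a sketch of one way to surface all of them (the object name and sample values are my own), by filtering on the maximum count:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object ModeWithTies extends App {
  val spark = SparkSession
    .builder()
    .appName("mode-ties")
    .config("spark.master", "local")
    .getOrCreate()
  import spark.implicits._

  // 1 and 5 both appear twice, so this set has two modes
  val counts = Seq(1, 1, 2, 5, 6, 5).toDF("value").groupBy("value").count()

  // Keep every value whose count equals the maximum count
  val maxCount = counts.agg(max("count")).first().getLong(0)
  counts.filter(counts("count") === maxCount).show() // shows both 1 and 5
}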
4) Variance
Variance is a measure of how far the data set is spread out. To calculate the variance (a plain-Scala sketch of these steps follows the list):
- Find the mean value of the data set.
- Subtract the mean from each number in the data set and square the difference.
- Then work out the average of those squared differences.
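As a quick illustration, here is a plain-Scala sketch of the three steps (the object name and list values are my own) on the set [1,2,3,4,5]:
object VarianceBySteps extends App {
  val xs = List(1.0, 2.0, 3.0, 4.0, 5.0)

  // Step 1: the mean of the data set
  val mean = xs.sum / xs.length // 3.0

  // Step 2: subtract the mean from each number and square the difference
  val squaredDiffs = xs.map(x => math.pow(x - mean, 2)) // 4, 1, 0, 1, 4

  // Step 3: the average of the squared differences (population variance)
  println(squaredDiffs.sum / squaredDiffs.length) // 10 / 5 = 2.0
}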
To find the variance in Spark using Scala:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object Main extends App {
  val spark = SparkSession
    .builder()
    .appName("test")
    .config("spark.master", "local")
    .getOrCreate()

  var data = spark.read.format("csv")
    .option("header", true)
    .load("/home/ashrith/data/test.csv")

  data = data.withColumn("rn", row_number().over(Window.orderBy("year")))
  data = data.filter(data("rn") > 2)

  // Skip the confidential "C" markers and aggregate the variance
  data.filter(data("value") =!= "C").agg(variance(data("value"))).show()
}
5) Standard Deviation
A quantity expressing by how much the members of a group differ from the mean value for the group; it is simply the square root of the variance. This is very useful for finding outliers in a histogram: outliers lie an abnormal distance from the rest of the group, and the occurrence of such numbers is uncommon. For example, if you take the average household income in a region, billionaires are the outliers.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.expressions.Window
import org.apache.spark.sql.functions._

object Main extends App {
  val spark = SparkSession
    .builder()
    .appName("test")
    .config("spark.master", "local")
    .getOrCreate()

  var data = spark.read.format("csv")
    .option("header", true)
    .load("/home/ashrith/data/test.csv")

  data = data.withColumn("rn", row_number().over(Window.orderBy("year")))
  data = data.filter(data("rn") > 2)

  // Skip the confidential "C" markers and aggregate the standard deviation
  data.filter(data("value") =!= "C").agg(stddev(data("value"))).show()
}
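To tie this back to the outlier example: a common rule of thumb is to flag values that lie more than a few standard deviations from the mean. Here is a minimal sketch (the incomes and the 3-sigma cut-off are my own illustrative choices) that picks the billionaire out of a set of ordinary household incomes:
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

object OutliersByStddev extends App {
  val spark = SparkSession
    .builder()
    .appName("outliers")
    .config("spark.master", "local")
    .getOrCreate()
  import spark.implicits._

  // Twelve ordinary household incomes and one billionaire
  val incomes = Seq(40000.0, 52000.0, 47000.0, 61000.0, 55000.0, 43000.0,
    58000.0, 49000.0, 51000.0, 46000.0, 53000.0, 45000.0, 1.0e9).toDF("income")

  // Compute the mean and standard deviation of the column
  val stats = incomes.agg(mean("income") as "mu", stddev("income") as "sigma").first()
  val (mu, sigma) = (stats.getDouble(0), stats.getDouble(1))

  // Flag anything more than 3 standard deviations away from the mean
  incomes.filter(abs($"income" - mu) > 3 * sigma).show() // only the billionaire
}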