Tuesday, June 14, 2016

Multi-feature gradient descent in Java

/**
 * Batch gradient descent demo for a linear model with two features.
 *
 * <p>Fits {@code y = theta0 + theta1*x1 + theta2*x2} to a small hard-coded data set,
 * prints the fitted coefficients, and then prints one sample prediction.
 */
public class GDescent {

    /** Number of gradient descent steps to run. */
    private static final int ITERATIONS = 3430000;

    /** Learning rate: scales every parameter update. */
    private static final float ALPHA = 0.003f;

    /**
     * Runs the gradient descent and prints the fitted model followed by a sample prediction.
     *
     * @param args command-line arguments; not used
     * @throws IllegalStateException if the feature and target arrays differ in length
     */
    public static void main(String[] args) {

        float[] x1 = {1f, 1f, 1f, 1f, 2f, 2f, 2f, 3f, 3f, 3f};
        float[] x2 = {6f, 6f, 6f, 6f, 9f, 9f, 9f, 13f, 13f, 13f};
        float[] y = {2f, 2f, 3f, 4f, 5f, 5f, 6f, 5.5f, 6f, 7f};

        // Derive the number of examples from the data so it cannot drift out of
        // sync with the arrays when the samples are edited.
        if (x1.length != y.length || x2.length != y.length) {
            throw new IllegalStateException(
                    "x1, x2 and y must have the same length but were "
                            + x1.length + ", " + x2.length + " and " + y.length);
        }
        int m = y.length;

        float theta0 = 0.0f;
        float theta1 = 0.0f;
        float theta2 = 0.0f;

        for (int i = 0; i < ITERATIONS; i++) {

            // Sums over all examples of the prediction error (h0), and of the
            // error weighted by each feature (h1, h2).
            float h0 = 0.0f;
            float h1 = 0.0f;
            float h2 = 0.0f;
            for (int j = 0; j < m; j++) {
                // Prediction error for example j, computed once and reused.
                float error = (theta0 + x1[j] * theta1 + x2[j] * theta2) - y[j];
                h0 = h0 + error;
                h1 = h1 + error * x1[j];
                h2 = h2 + error * x2[j];
            }

            // All three sums were built from the previous thetas, so assigning
            // the new values one after another is still a simultaneous update.
            theta0 = theta0 - (ALPHA * h0) / (float) m;
            theta1 = theta1 - (ALPHA * h1) / (float) m;
            theta2 = theta2 - (ALPHA * h2) / (float) m;
        }

        System.out.println("" + theta2 + "x2 + " + theta1 + "x1 + " + theta0);

        testGradientDescent(2f, theta0, theta1, theta2);
    }

    /**
     * Prints the model's prediction for the fixed point {@code x1 = 1, x2 = 6},
     * which is the first training example used in {@link #main(String[])}.
     *
     * @param n      not used by the computation; kept so the signature is unchanged
     * @param theta0 intercept of the fitted model
     * @param theta1 coefficient applied to x1
     * @param theta2 coefficient applied to x2
     */
    private static void testGradientDescent(float n, float theta0, float theta1, float theta2) {
        float result = theta0 + (theta1 * 1) + (theta2 * 6);
        System.out.println("Result: " + result);
    }

}


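Notice that the sums behind the partial derivatives of the cost function (with respect to theta0, theta1 and theta2) are accumulated over all examples here; each sum is divided by m when the thetas are updated: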

h0 = h0 + ((theta0 +  x1[j]*theta1 + x2[j]*theta2) - y[j]);
h1 = h1 + ((theta0 +  x1[j]*theta1 + x2[j]*theta2) - y[j])*x1[j];
h2 = h2 + ((theta0 +  x1[j]*theta1 + x2[j]*theta2) - y[j])*x2[j];

No comments:

Post a Comment