15 changed files with 163 additions and 225 deletions
			
			
		@ -1,46 +1,46 @@
					 | 
				
			||||
package com.mh.algorithm.knn; | 
				
			||||
 | 
				
			||||
import com.mh.algorithm.matrix.Matrix; | 
				
			||||
import com.mh.algorithm.utils.CsvInfo; | 
				
			||||
import com.mh.algorithm.utils.CsvUtil; | 
				
			||||
import com.mh.algorithm.utils.DoubleUtil; | 
				
			||||
import org.junit.Test; | 
				
			||||
 | 
				
			||||
/** | 
				
			||||
 * @program: top-algorithm-set | 
				
			||||
 * @description: | 
				
			||||
 * @author: Mr.Zhao | 
				
			||||
 * @create: 2020-10-26 22:04 | 
				
			||||
 **/ | 
				
			||||
public class knnTest { | 
				
			||||
    @Test | 
				
			||||
    public void test() throws Exception { | 
				
			||||
        // 训练集
 | 
				
			||||
        CsvInfo csvInfo = CsvUtil.getCsvInfo(false, "E:\\jarTest\\trainData.csv"); | 
				
			||||
        Matrix trainSet = csvInfo.toMatrix(); | 
				
			||||
        Matrix trainSetLabels = trainSet.getColOfIdx(trainSet.getMatrixColCount() - 1); | 
				
			||||
        Matrix trainSetData = trainSet.subMatrix(0, trainSet.getMatrixRowCount(), 0, trainSet.getMatrixColCount() - 1); | 
				
			||||
 | 
				
			||||
        CsvInfo csvInfo1 = CsvUtil.getCsvInfo(false, "E:\\jarTest\\testData.csv"); | 
				
			||||
        Matrix testSet = csvInfo1.toMatrix(); | 
				
			||||
        Matrix testSetData = trainSet.subMatrix(0, testSet.getMatrixRowCount(), 0, testSet.getMatrixColCount() - 1); | 
				
			||||
        Matrix testSetLabels = trainSet.getColOfIdx(testSet.getMatrixColCount() - 1); | 
				
			||||
 | 
				
			||||
        // 分类
 | 
				
			||||
        long startTime = System.currentTimeMillis(); | 
				
			||||
        Matrix result = KNN.classify(testSetData, trainSetData, trainSetLabels, 5); | 
				
			||||
        long endTime = System.currentTimeMillis(); | 
				
			||||
        System.out.println("run time:" + (endTime - startTime)); | 
				
			||||
        // 正确率
 | 
				
			||||
        Matrix error = result.subtract(testSetLabels); | 
				
			||||
        int total = error.getMatrixRowCount(); | 
				
			||||
        int correct = 0; | 
				
			||||
        for (int i = 0; i < error.getMatrixRowCount(); i++) { | 
				
			||||
            if (DoubleUtil.equals(error.getValOfIdx(i, 0), 0.0)) { | 
				
			||||
                correct++; | 
				
			||||
            } | 
				
			||||
        } | 
				
			||||
        double correctRate = Double.valueOf(correct) / Double.valueOf(total); | 
				
			||||
        System.out.println("correctRate:"+ correctRate); | 
				
			||||
    } | 
				
			||||
} | 
				
			||||
//package com.mh.algorithm.knn;
 | 
				
			||||
//
 | 
				
			||||
//import com.mh.algorithm.matrix.Matrix;
 | 
				
			||||
//import com.mh.algorithm.utils.CsvInfo;
 | 
				
			||||
//import com.mh.algorithm.utils.CsvUtil;
 | 
				
			||||
//import com.mh.algorithm.utils.DoubleUtil;
 | 
				
			||||
//import org.junit.Test;
 | 
				
			||||
//
 | 
				
			||||
///**
 | 
				
			||||
// * @program: top-algorithm-set
 | 
				
			||||
// * @description:
 | 
				
			||||
// * @author: Mr.Zhao
 | 
				
			||||
// * @create: 2020-10-26 22:04
 | 
				
			||||
// **/
 | 
				
			||||
//public class knnTest {
 | 
				
			||||
//    @Test
 | 
				
			||||
//    public void test() throws Exception {
 | 
				
			||||
//        // 训练集
 | 
				
			||||
//        CsvInfo csvInfo = CsvUtil.getCsvInfo(false, "E:\\jarTest\\trainData.csv");
 | 
				
			||||
//        Matrix trainSet = csvInfo.toMatrix();
 | 
				
			||||
//        Matrix trainSetLabels = trainSet.getColOfIdx(trainSet.getMatrixColCount() - 1);
 | 
				
			||||
//        Matrix trainSetData = trainSet.subMatrix(0, trainSet.getMatrixRowCount(), 0, trainSet.getMatrixColCount() - 1);
 | 
				
			||||
//
 | 
				
			||||
//        CsvInfo csvInfo1 = CsvUtil.getCsvInfo(false, "E:\\jarTest\\testData.csv");
 | 
				
			||||
//        Matrix testSet = csvInfo1.toMatrix();
 | 
				
			||||
//        Matrix testSetData = trainSet.subMatrix(0, testSet.getMatrixRowCount(), 0, testSet.getMatrixColCount() - 1);
 | 
				
			||||
//        Matrix testSetLabels = trainSet.getColOfIdx(testSet.getMatrixColCount() - 1);
 | 
				
			||||
//
 | 
				
			||||
//        // 分类
 | 
				
			||||
//        long startTime = System.currentTimeMillis();
 | 
				
			||||
//        Matrix result = KNN.classify(testSetData, trainSetData, trainSetLabels, 5);
 | 
				
			||||
//        long endTime = System.currentTimeMillis();
 | 
				
			||||
//        System.out.println("run time:" + (endTime - startTime));
 | 
				
			||||
//        // 正确率
 | 
				
			||||
//        Matrix error = result.subtract(testSetLabels);
 | 
				
			||||
//        int total = error.getMatrixRowCount();
 | 
				
			||||
//        int correct = 0;
 | 
				
			||||
//        for (int i = 0; i < error.getMatrixRowCount(); i++) {
 | 
				
			||||
//            if (DoubleUtil.equals(error.getValOfIdx(i, 0), 0.0)) {
 | 
				
			||||
//                correct++;
 | 
				
			||||
//            }
 | 
				
			||||
//        }
 | 
				
			||||
//        double correctRate = Double.valueOf(correct) / Double.valueOf(total);
 | 
				
			||||
//        System.out.println("correctRate:"+ correctRate);
 | 
				
			||||
//    }
 | 
				
			||||
//}
 | 
				
			||||
					 | 
				
			||||
@ -1,13 +1,13 @@
					 | 
				
			||||
package com.mh.common.annotation; | 
				
			||||
 | 
				
			||||
import java.lang.annotation.*; | 
				
			||||
 | 
				
			||||
/** | 
				
			||||
 * Created by fangzhipeng on 2017/7/12. | 
				
			||||
 */ | 
				
			||||
@Target(ElementType.METHOD) | 
				
			||||
@Retention(RetentionPolicy.RUNTIME) | 
				
			||||
@Documented | 
				
			||||
public @interface SysLogger { | 
				
			||||
    String value() default ""; | 
				
			||||
} | 
				
			||||
//package com.mh.common.annotation;
 | 
				
			||||
//
 | 
				
			||||
//import java.lang.annotation.*;
 | 
				
			||||
//
 | 
				
			||||
///**
 | 
				
			||||
// * Created by fangzhipeng on 2017/7/12.
 | 
				
			||||
// */
 | 
				
			||||
//@Target(ElementType.METHOD)
 | 
				
			||||
//@Retention(RetentionPolicy.RUNTIME)
 | 
				
			||||
//@Documented
 | 
				
			||||
//public @interface SysLogger {
 | 
				
			||||
//    String value() default "";
 | 
				
			||||
//}
 | 
				
			||||
					 | 
				
			||||
@ -1,42 +1,42 @@
					 | 
				
			||||
package com.mh.common.utils; | 
				
			||||
 | 
				
			||||
import javax.servlet.http.HttpServletResponse; | 
				
			||||
import java.io.BufferedInputStream; | 
				
			||||
import java.io.BufferedOutputStream; | 
				
			||||
import java.io.File; | 
				
			||||
import java.io.FileInputStream; | 
				
			||||
import java.io.InputStream; | 
				
			||||
 | 
				
			||||
/** | 
				
			||||
 * 文件相关操作 | 
				
			||||
 * @author Louis | 
				
			||||
 * @date Jan 14, 2019 | 
				
			||||
 */ | 
				
			||||
public class FileUtils { | 
				
			||||
 | 
				
			||||
	/** | 
				
			||||
	 * 下载文件 | 
				
			||||
	 * @param response | 
				
			||||
	 * @param file | 
				
			||||
	 * @param newFileName | 
				
			||||
	 */ | 
				
			||||
	public static void downloadFile(HttpServletResponse response, File file, String newFileName) { | 
				
			||||
		try { | 
				
			||||
			response.setHeader("Content-Disposition", "attachment; filename=" + new String(newFileName.getBytes("ISO-8859-1"), "UTF-8")); | 
				
			||||
			BufferedOutputStream bos = new BufferedOutputStream(response.getOutputStream()); | 
				
			||||
			InputStream is = new FileInputStream(file.getAbsolutePath()); | 
				
			||||
			BufferedInputStream bis = new BufferedInputStream(is); | 
				
			||||
			int length = 0; | 
				
			||||
			byte[] temp = new byte[1 * 1024 * 10]; | 
				
			||||
			while ((length = bis.read(temp)) != -1) { | 
				
			||||
				bos.write(temp, 0, length); | 
				
			||||
			} | 
				
			||||
			bos.flush(); | 
				
			||||
			bis.close(); | 
				
			||||
			bos.close(); | 
				
			||||
			is.close(); | 
				
			||||
		} catch (Exception e) { | 
				
			||||
			e.printStackTrace(); | 
				
			||||
		} | 
				
			||||
	} | 
				
			||||
} | 
				
			||||
//package com.mh.common.utils;
 | 
				
			||||
//
 | 
				
			||||
//import javax.servlet.http.HttpServletResponse;
 | 
				
			||||
//import java.io.BufferedInputStream;
 | 
				
			||||
//import java.io.BufferedOutputStream;
 | 
				
			||||
//import java.io.File;
 | 
				
			||||
//import java.io.FileInputStream;
 | 
				
			||||
//import java.io.InputStream;
 | 
				
			||||
//
 | 
				
			||||
///**
 | 
				
			||||
// * 文件相关操作
 | 
				
			||||
// * @author Louis
 | 
				
			||||
// * @date Jan 14, 2019
 | 
				
			||||
// */
 | 
				
			||||
//public class FileUtils {
 | 
				
			||||
//
 | 
				
			||||
//	/**
 | 
				
			||||
//	 * 下载文件
 | 
				
			||||
//	 * @param response
 | 
				
			||||
//	 * @param file
 | 
				
			||||
//	 * @param newFileName
 | 
				
			||||
//	 */
 | 
				
			||||
//	public static void downloadFile(HttpServletResponse response, File file, String newFileName) {
 | 
				
			||||
//		try {
 | 
				
			||||
//			response.setHeader("Content-Disposition", "attachment; filename=" + new String(newFileName.getBytes("ISO-8859-1"), "UTF-8"));
 | 
				
			||||
//			BufferedOutputStream bos = new BufferedOutputStream(response.getOutputStream());
 | 
				
			||||
//			InputStream is = new FileInputStream(file.getAbsolutePath());
 | 
				
			||||
//			BufferedInputStream bis = new BufferedInputStream(is);
 | 
				
			||||
//			int length = 0;
 | 
				
			||||
//			byte[] temp = new byte[1 * 1024 * 10];
 | 
				
			||||
//			while ((length = bis.read(temp)) != -1) {
 | 
				
			||||
//				bos.write(temp, 0, length);
 | 
				
			||||
//			}
 | 
				
			||||
//			bos.flush();
 | 
				
			||||
//			bis.close();
 | 
				
			||||
//			bos.close();
 | 
				
			||||
//			is.close();
 | 
				
			||||
//		} catch (Exception e) {
 | 
				
			||||
//			e.printStackTrace();
 | 
				
			||||
//		}
 | 
				
			||||
//	}
 | 
				
			||||
//}
 | 
				
			||||
					 | 
				
			||||
					Loading…
					
					
				
		Reference in new issue