Browse Source

1、添加BP神经网络预测算法

dev
mh 5 months ago
parent
commit
f0719ee78d
  1. 38
      2024数据库脚本.sql
  2. 135
      algorithm/pom.xml
  3. 8
      algorithm/src/main/java/com/mh/algorithm/bpnn/ActivationFunction.java
  4. 111
      algorithm/src/main/java/com/mh/algorithm/bpnn/BPModel.java
  5. 257
      algorithm/src/main/java/com/mh/algorithm/bpnn/BPNeuralNetworkFactory.java
  6. 106
      algorithm/src/main/java/com/mh/algorithm/bpnn/BPParameter.java
  7. 15
      algorithm/src/main/java/com/mh/algorithm/bpnn/Sigmoid.java
  8. 24
      algorithm/src/main/java/com/mh/algorithm/constants/OrderEnum.java
  9. 88
      algorithm/src/main/java/com/mh/algorithm/knn/KNN.java
  10. 646
      algorithm/src/main/java/com/mh/algorithm/matrix/Matrix.java
  11. 53
      algorithm/src/main/java/com/mh/algorithm/utils/CsvInfo.java
  12. 66
      algorithm/src/main/java/com/mh/algorithm/utils/CsvUtil.java
  13. 20
      algorithm/src/main/java/com/mh/algorithm/utils/DoubleUtil.java
  14. 285
      algorithm/src/main/java/com/mh/algorithm/utils/MatrixUtil.java
  15. 32
      algorithm/src/main/java/com/mh/algorithm/utils/SerializationUtil.java
  16. 71
      algorithm/src/test/java/com/mh/algorithm/bpnn/bpnnTest.java
  17. 46
      algorithm/src/test/java/com/mh/algorithm/knn/knnTest.java
  18. 1
      pom.xml
  19. 6
      user-service/src/main/java/com/mh/user/controller/ControlSetController.java
  20. 4
      user-service/src/main/java/com/mh/user/controller/SerialPortController.java
  21. 1
      user-service/src/main/java/com/mh/user/entity/MaintainInfoEntity.java
  22. 19
      user-service/src/main/java/com/mh/user/mapper/ControlSetMapper.java
  23. 8
      user-service/src/main/java/com/mh/user/mapper/MaintainInfoMapper.java
  24. 2
      user-service/src/main/java/com/mh/user/service/ControlSetService.java
  25. 9
      user-service/src/main/java/com/mh/user/service/impl/ControlSetServiceImpl.java
  26. 6
      user-service/src/main/java/com/mh/user/service/impl/DeviceControlServiceImpl.java

38
2024数据库脚本.sql

@ -0,0 +1,38 @@
-- 2024-05-07 维修表缺少字段
ALTER TABLE maintain_info ADD cost numeric(2,0) NULL;
EXEC sys.sp_addextendedproperty 'MS_Description', N'材料费用', 'schema', N'dbo', 'table', N'maintain_info', 'column', N'cost';
ALTER TABLE maintain_info ADD contents varchar(100) NULL;
EXEC sys.sp_addextendedproperty 'MS_Description', N'维保内容', 'schema', N'dbo', 'table', N'maintain_info', 'column', N'contents';
ALTER TABLE maintain_info ADD evaluate varchar(10) NULL;
EXEC sys.sp_addextendedproperty 'MS_Description', N'评价内容', 'schema', N'dbo', 'table', N'maintain_info', 'column', N'evaluate';
-- 训练集合:
select
eds.cur_date,
eds.building_id,
isnull(eds.water_value,
0) as water_value,
isnull(eds.elect_value,
0) as elect_value,
isnull(convert(numeric(24,2),t1.water_level),
0) as water_level
from
energy_day_sum eds
left join (
select
convert(date,
cur_date) as cur_date,
building_id,
avg(isnull(convert(numeric(24, 2), water_level), 0)) as water_level
from
history_data
group by
convert(date,
cur_date),
building_id
) t1 on
eds.cur_date = t1.cur_date and eds.building_id = t1.building_id
where eds.building_id != '所有'
order by
eds.building_id,
eds.cur_date

135
algorithm/pom.xml

@ -0,0 +1,135 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
<parent>
<groupId>com.mh</groupId>
<artifactId>chws</artifactId>
<version>1.0-SNAPSHOT</version>
</parent>
<modelVersion>4.0.0</modelVersion>
<groupId>com.mh</groupId>
<artifactId>algorithm</artifactId>
<version>1.0.0</version>
<packaging>jar</packaging>
<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
<encoding>UTF-8</encoding>
<java.version>1.8</java.version>
<maven.compiler.source>1.8</maven.compiler.source>
<maven.compiler.target>1.8</maven.compiler.target>
</properties>
<dependencies>
<!-- https://mvnrepository.com/artifact/net.sourceforge.javacsv/javacsv -->
<dependency>
<groupId>net.sourceforge.javacsv</groupId>
<artifactId>javacsv</artifactId>
<version>2.0</version>
</dependency>
<dependency>
<groupId>gov.nist.math</groupId>
<artifactId>jama</artifactId>
<version>1.0.3</version>
</dependency>
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
<version>RELEASE</version>
<scope>test</scope>
</dependency>
</dependencies>
<build>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.1</version>
<configuration>
<source>1.8</source>
<target>1.8</target>
</configuration>
</plugin>
</plugins>
</build>
<profiles>
<profile>
<id>default</id>
<activation>
<activeByDefault>true</activeByDefault>
</activation>
<build>
<plugins>
<!-- java版本 -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-compiler-plugin</artifactId>
<version>3.8.0</version>
<configuration>
<source>1.8</source>
<target>1.8</target>
<encoding>UTF-8</encoding>
</configuration>
</plugin>
<!-- 这是javadoc打包插件 -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-javadoc-plugin</artifactId>
<version>2.9.1</version>
<executions>
<execution>
<id>attach-javadocs</id>
<goals>
<goal>jar</goal>
</goals>
<!-- 该处屏蔽jdk1.8后javadoc的严格校验 -->
<configuration>
<additionalparam>-Xdoclint:none</additionalparam>
</configuration>
</execution>
</executions>
</plugin>
<!-- 打包源码插件 -->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-source-plugin</artifactId>
<version>2.3</version>
<executions>
<execution>
<id>attach-sources</id>
<goals>
<goal>jar</goal>
</goals>
</execution>
</executions>
</plugin>
<!--签名插件-->
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-gpg-plugin</artifactId>
<version>1.4</version>
<executions>
<execution>
<id>sign-artifacts</id>
<phase>verify</phase>
<goals>
<goal>sign</goal>
</goals>
</execution>
</executions>
</plugin>
<plugin>
<artifactId>maven-jar-plugin</artifactId>
<version>2.3.1</version>
<configuration>
<classesDirectory>target/classes</classesDirectory>
</configuration>
</plugin>
</plugins>
</build>
</profile>
</profiles>
</project>

8
algorithm/src/main/java/com/mh/algorithm/bpnn/ActivationFunction.java

@ -0,0 +1,8 @@
package com.mh.algorithm.bpnn;
public interface ActivationFunction {
//计算值
double computeValue(double val);
//计算导数
double computeDerivative(double val);
}

111
algorithm/src/main/java/com/mh/algorithm/bpnn/BPModel.java

@ -0,0 +1,111 @@
package com.mh.algorithm.bpnn;
import com.mh.algorithm.matrix.Matrix;
import java.io.Serializable;
public class BPModel implements Serializable {
//BP神经网络权值与阈值
private Matrix weightIJ;
private Matrix b1;
private Matrix weightJP;
private Matrix b2;
/*用于反归一化*/
private Matrix inputMax;
private Matrix inputMin;
private Matrix outputMax;
private Matrix outputMin;
/*BP神经网络训练参数*/
private BPParameter bpParameter;
/*BP神经网络训练情况*/
private double error;
private int times;
public Matrix getWeightIJ() {
return weightIJ;
}
public void setWeightIJ(Matrix weightIJ) {
this.weightIJ = weightIJ;
}
public Matrix getB1() {
return b1;
}
public void setB1(Matrix b1) {
this.b1 = b1;
}
public Matrix getWeightJP() {
return weightJP;
}
public void setWeightJP(Matrix weightJP) {
this.weightJP = weightJP;
}
public Matrix getB2() {
return b2;
}
public void setB2(Matrix b2) {
this.b2 = b2;
}
public Matrix getInputMax() {
return inputMax;
}
public void setInputMax(Matrix inputMax) {
this.inputMax = inputMax;
}
public Matrix getInputMin() {
return inputMin;
}
public void setInputMin(Matrix inputMin) {
this.inputMin = inputMin;
}
public Matrix getOutputMax() {
return outputMax;
}
public void setOutputMax(Matrix outputMax) {
this.outputMax = outputMax;
}
public Matrix getOutputMin() {
return outputMin;
}
public void setOutputMin(Matrix outputMin) {
this.outputMin = outputMin;
}
public BPParameter getBpParameter() {
return bpParameter;
}
public void setBpParameter(BPParameter bpParameter) {
this.bpParameter = bpParameter;
}
public double getError() {
return error;
}
public void setError(double error) {
this.error = error;
}
public int getTimes() {
return times;
}
public void setTimes(int times) {
this.times = times;
}
}

257
algorithm/src/main/java/com/mh/algorithm/bpnn/BPNeuralNetworkFactory.java

@ -0,0 +1,257 @@
package com.mh.algorithm.bpnn;
import com.mh.algorithm.matrix.Matrix;
import com.mh.algorithm.utils.MatrixUtil;
import java.util.*;
public class BPNeuralNetworkFactory {
/**
* 训练BP神经网络模型
* @param bpParameter
* @param inputAndOutput
* @return
*/
public BPModel trainBP(BPParameter bpParameter, Matrix inputAndOutput) throws Exception {
ActivationFunction activationFunction = bpParameter.getActivationFunction();
int inputCount = bpParameter.getInputLayerNeuronCount();
int hiddenCount = bpParameter.getHiddenLayerNeuronCount();
int outputCount = bpParameter.getOutputLayerNeuronCount();
double normalizationMin = bpParameter.getNormalizationMin();
double normalizationMax = bpParameter.getNormalizationMax();
double step = bpParameter.getStep();
double momentumFactor = bpParameter.getMomentumFactor();
double precision = bpParameter.getPrecision();
int maxTimes = bpParameter.getMaxTimes();
if(inputAndOutput.getMatrixColCount() != inputCount + outputCount){
throw new Exception("神经元个数不符,请修改");
}
// 初始化权值
Matrix weightIJ = initWeight(inputCount, hiddenCount);
Matrix weightJP = initWeight(hiddenCount, outputCount);
// 初始化阈值
Matrix b1 = initThreshold(hiddenCount);
Matrix b2 = initThreshold(outputCount);
// 动量项
Matrix deltaWeightIJ0 = new Matrix(inputCount, hiddenCount);
Matrix deltaWeightJP0 = new Matrix(hiddenCount, outputCount);
Matrix deltaB10 = new Matrix(1, hiddenCount);
Matrix deltaB20 = new Matrix(1, outputCount);
// 截取输入矩阵和输出矩阵
Matrix input = inputAndOutput.subMatrix(0,inputAndOutput.getMatrixRowCount(),0,inputCount);
Matrix output = inputAndOutput.subMatrix(0,inputAndOutput.getMatrixRowCount(),inputCount,outputCount);
// 归一化
Map<String,Object> inputAfterNormalize = MatrixUtil.normalize(input, normalizationMin, normalizationMax);
input = (Matrix) inputAfterNormalize.get("res");
Map<String,Object> outputAfterNormalize = MatrixUtil.normalize(output, normalizationMin, normalizationMax);
output = (Matrix) outputAfterNormalize.get("res");
int times = 1;
double E = 0;//误差
while (times < maxTimes) {
/*-----------------正向传播---------------------*/
// 隐含层输入
Matrix jIn = input.multiple(weightIJ);
// 扩充阈值
Matrix b1Copy = b1.extend(2,jIn.getMatrixRowCount());
// 加上阈值
jIn = jIn.plus(b1Copy);
// 隐含层输出
Matrix jOut = computeValue(jIn,activationFunction);
// 输出层输入
Matrix pIn = jOut.multiple(weightJP);
// 扩充阈值
Matrix b2Copy = b2.extend(2, pIn.getMatrixRowCount());
// 加上阈值
pIn = pIn.plus(b2Copy);
// 输出层输出
Matrix pOut = computeValue(pIn,activationFunction);
// 计算误差
Matrix e = output.subtract(pOut);
E = computeE(e);//误差
// 判断是否符合精度
if (Math.abs(E) <= precision) {
System.out.println("满足精度");
break;
}
/*-----------------反向传播---------------------*/
// J与P之间权值修正量
Matrix deltaWeightJP = e.multiple(step);
deltaWeightJP = deltaWeightJP.pointMultiple(computeDerivative(pIn,activationFunction));
deltaWeightJP = deltaWeightJP.transpose().multiple(jOut);
deltaWeightJP = deltaWeightJP.transpose();
// P层神经元阈值修正量
Matrix deltaThresholdP = e.multiple(step);
deltaThresholdP = deltaThresholdP.transpose().multiple(computeDerivative(pIn, activationFunction));
// I与J之间的权值修正量
Matrix deltaO = e.pointMultiple(computeDerivative(pIn,activationFunction));
Matrix tmp = weightJP.multiple(deltaO.transpose()).transpose();
Matrix deltaWeightIJ = tmp.pointMultiple(computeDerivative(jIn, activationFunction));
deltaWeightIJ = input.transpose().multiple(deltaWeightIJ);
deltaWeightIJ = deltaWeightIJ.multiple(step);
// J层神经元阈值修正量
Matrix deltaThresholdJ = tmp.transpose().multiple(computeDerivative(jIn, activationFunction));
deltaThresholdJ = deltaThresholdJ.multiple(-step);
if (times == 1) {
// 更新权值与阈值
weightIJ = weightIJ.plus(deltaWeightIJ);
weightJP = weightJP.plus(deltaWeightJP);
b1 = b1.plus(deltaThresholdJ);
b2 = b2.plus(deltaThresholdP);
}else{
// 加动量项
weightIJ = weightIJ.plus(deltaWeightIJ).plus(deltaWeightIJ0.multiple(momentumFactor));
weightJP = weightJP.plus(deltaWeightJP).plus(deltaWeightJP0.multiple(momentumFactor));
b1 = b1.plus(deltaThresholdJ).plus(deltaB10.multiple(momentumFactor));
b2 = b2.plus(deltaThresholdP).plus(deltaB20.multiple(momentumFactor));
}
deltaWeightIJ0 = deltaWeightIJ;
deltaWeightJP0 = deltaWeightJP;
deltaB10 = deltaThresholdJ;
deltaB20 = deltaThresholdP;
times++;
}
// BP神经网络的输出
BPModel result = new BPModel();
result.setInputMax((Matrix) inputAfterNormalize.get("max"));
result.setInputMin((Matrix) inputAfterNormalize.get("min"));
result.setOutputMax((Matrix) outputAfterNormalize.get("max"));
result.setOutputMin((Matrix) outputAfterNormalize.get("min"));
result.setWeightIJ(weightIJ);
result.setWeightJP(weightJP);
result.setB1(b1);
result.setB2(b2);
result.setError(E);
result.setTimes(times);
result.setBpParameter(bpParameter);
System.out.println("循环次数:" + times + ",误差:" + E);
return result;
}
/**
* 计算BP神经网络的值
* @param bpModel
* @param input
* @return
*/
public Matrix computeBP(BPModel bpModel,Matrix input) throws Exception {
if (input.getMatrixColCount() != bpModel.getBpParameter().getInputLayerNeuronCount()) {
throw new Exception("输入矩阵纬度有误");
}
ActivationFunction activationFunction = bpModel.getBpParameter().getActivationFunction();
Matrix weightIJ = bpModel.getWeightIJ();
Matrix weightJP = bpModel.getWeightJP();
Matrix b1 = bpModel.getB1();
Matrix b2 = bpModel.getB2();
double[][] normalizedInput = new double[input.getMatrixRowCount()][input.getMatrixColCount()];
for (int i = 0; i < input.getMatrixRowCount(); i++) {
for (int j = 0; j < input.getMatrixColCount(); j++) {
normalizedInput[i][j] = bpModel.getBpParameter().getNormalizationMin()
+ (input.getValOfIdx(i,j) - bpModel.getInputMin().getValOfIdx(0,j))
/ (bpModel.getInputMax().getValOfIdx(0,j) - bpModel.getInputMin().getValOfIdx(0,j))
* (bpModel.getBpParameter().getNormalizationMax() - bpModel.getBpParameter().getNormalizationMin());
}
}
Matrix normalizedInputMatrix = new Matrix(normalizedInput);
Matrix jIn = normalizedInputMatrix.multiple(weightIJ);
// 扩充阈值
Matrix b1Copy = b1.extend(2,jIn.getMatrixRowCount());
// 加上阈值
jIn = jIn.plus(b1Copy);
// 隐含层输出
Matrix jOut = computeValue(jIn,activationFunction);
// 输出层输入
Matrix pIn = jOut.multiple(weightJP);
// 扩充阈值
Matrix b2Copy = b2.extend(2,pIn.getMatrixRowCount());
// 加上阈值
pIn = pIn.plus(b2Copy);
// 输出层输出
Matrix pOut = computeValue(pIn,activationFunction);
// 反归一化
return MatrixUtil.inverseNormalize(pOut, bpModel.getBpParameter().getNormalizationMax(), bpModel.getBpParameter().getNormalizationMin(), bpModel.getOutputMax(), bpModel.getOutputMin());
}
// 初始化权值
private Matrix initWeight(int x,int y){
Random random=new Random();
double[][] weight = new double[x][y];
for (int i = 0; i < x; i++) {
for (int j = 0; j < y; j++) {
weight[i][j] = 2*random.nextDouble()-1;
}
}
return new Matrix(weight);
}
// 初始化阈值
private Matrix initThreshold(int x){
Random random = new Random();
double[][] result = new double[1][x];
for (int i = 0; i < x; i++) {
result[0][i] = 2*random.nextDouble()-1;
}
return new Matrix(result);
}
/**
* 计算激活函数的值
* @param a
* @return
*/
private Matrix computeValue(Matrix a, ActivationFunction activationFunction) throws Exception {
if (a.getMatrix() == null) {
throw new Exception("参数值为空");
}
double[][] result = new double[a.getMatrixRowCount()][a.getMatrixColCount()];
for (int i = 0; i < a.getMatrixRowCount(); i++) {
for (int j = 0; j < a.getMatrixColCount(); j++) {
result[i][j] = activationFunction.computeValue(a.getValOfIdx(i,j));
}
}
return new Matrix(result);
}
/**
* 激活函数导数的值
* @param a
* @return
*/
private Matrix computeDerivative(Matrix a , ActivationFunction activationFunction) throws Exception {
if (a.getMatrix() == null) {
throw new Exception("参数值为空");
}
double[][] result = new double[a.getMatrixRowCount()][a.getMatrixColCount()];
for (int i = 0; i < a.getMatrixRowCount(); i++) {
for (int j = 0; j < a.getMatrixColCount(); j++) {
result[i][j] = activationFunction.computeDerivative(a.getValOfIdx(i,j));
}
}
return new Matrix(result);
}
/**
* 计算误差
* @param e
* @return
*/
private double computeE(Matrix e){
e = e.square();
return 0.5*e.sumAll();
}
}

106
algorithm/src/main/java/com/mh/algorithm/bpnn/BPParameter.java

@ -0,0 +1,106 @@
package com.mh.algorithm.bpnn;
import java.io.Serializable;
public class BPParameter implements Serializable {
//输入层神经元个数
private int inputLayerNeuronCount = 3;
//隐含层神经元个数
private int hiddenLayerNeuronCount = 3;
//输出层神经元个数
private int outputLayerNeuronCount = 1;
//归一化区间
private double normalizationMin = 0.2;
private double normalizationMax = 0.8;
//学习步长
private double step = 0.05;
//动量因子
private double momentumFactor = 0.2;
//激活函数
private ActivationFunction activationFunction = new Sigmoid();
//精度
private double precision = 0.000001;
//最大循环次数
private int maxTimes = 1000000;
public double getMomentumFactor() {
return momentumFactor;
}
public void setMomentumFactor(double momentumFactor) {
this.momentumFactor = momentumFactor;
}
public double getStep() {
return step;
}
public void setStep(double step) {
this.step = step;
}
public double getNormalizationMin() {
return normalizationMin;
}
public void setNormalizationMin(double normalizationMin) {
this.normalizationMin = normalizationMin;
}
public double getNormalizationMax() {
return normalizationMax;
}
public void setNormalizationMax(double normalizationMax) {
this.normalizationMax = normalizationMax;
}
public int getInputLayerNeuronCount() {
return inputLayerNeuronCount;
}
public void setInputLayerNeuronCount(int inputLayerNeuronCount) {
this.inputLayerNeuronCount = inputLayerNeuronCount;
}
public int getHiddenLayerNeuronCount() {
return hiddenLayerNeuronCount;
}
public void setHiddenLayerNeuronCount(int hiddenLayerNeuronCount) {
this.hiddenLayerNeuronCount = hiddenLayerNeuronCount;
}
public int getOutputLayerNeuronCount() {
return outputLayerNeuronCount;
}
public void setOutputLayerNeuronCount(int outputLayerNeuronCount) {
this.outputLayerNeuronCount = outputLayerNeuronCount;
}
public ActivationFunction getActivationFunction() {
return activationFunction;
}
public void setActivationFunction(ActivationFunction activationFunction) {
this.activationFunction = activationFunction;
}
public double getPrecision() {
return precision;
}
public void setPrecision(double precision) {
this.precision = precision;
}
public int getMaxTimes() {
return maxTimes;
}
public void setMaxTimes(int maxTimes) {
this.maxTimes = maxTimes;
}
}

15
algorithm/src/main/java/com/mh/algorithm/bpnn/Sigmoid.java

@ -0,0 +1,15 @@
package com.mh.algorithm.bpnn;
import java.io.Serializable;
public class Sigmoid implements ActivationFunction, Serializable {
@Override
public double computeValue(double val) {
return 1 / (1 + Math.exp(-val));
}
@Override
public double computeDerivative(double val) {
return computeValue(val) * (1 - computeValue(val));
}
}

24
algorithm/src/main/java/com/mh/algorithm/constants/OrderEnum.java

@ -0,0 +1,24 @@
package com.mh.algorithm.constants;
/**
* 排序枚举类
*/
public enum OrderEnum {
ASC(1,"升序"),
DESC(2,"降序");
OrderEnum(int flag, String name) {
this.flag = flag;
this.name = name;
}
private int flag;
private String name;
}

88
algorithm/src/main/java/com/mh/algorithm/knn/KNN.java

@ -0,0 +1,88 @@
package com.mh.algorithm.knn;
import com.mh.algorithm.constants.OrderEnum;
import com.mh.algorithm.matrix.Matrix;
import com.mh.algorithm.utils.MatrixUtil;
import java.util.*;
/**
* @program: top-algorithm-set
* @description: KNN k-临近算法进行分类
* @author: Mr.Zhao
* @create: 2020-10-13 22:03
**/
public class KNN {
public static Matrix classify(Matrix input, Matrix dataSet, Matrix labels, int k) throws Exception {
if (dataSet.getMatrixRowCount() != labels.getMatrixRowCount()) {
throw new IllegalArgumentException("矩阵训练集与标签维度不一致");
}
if (input.getMatrixColCount() != dataSet.getMatrixColCount()) {
throw new IllegalArgumentException("待分类矩阵列数与训练集列数不一致");
}
if (dataSet.getMatrixRowCount() < k) {
throw new IllegalArgumentException("训练集样本数小于k");
}
// 归一化
int trainCount = dataSet.getMatrixRowCount();
int testCount = input.getMatrixRowCount();
Matrix trainAndTest = dataSet.splice(2, input);
Map<String, Object> normalize = MatrixUtil.normalize(trainAndTest, 0, 1);
trainAndTest = (Matrix) normalize.get("res");
dataSet = trainAndTest.subMatrix(0, trainCount, 0, trainAndTest.getMatrixColCount());
input = trainAndTest.subMatrix(0, testCount, 0, trainAndTest.getMatrixColCount());
// 获取标签信息
List<Double> labelList = new ArrayList<>();
for (int i = 0; i < labels.getMatrixRowCount(); i++) {
if (!labelList.contains(labels.getValOfIdx(i, 0))) {
labelList.add(labels.getValOfIdx(i, 0));
}
}
Matrix result = new Matrix(new double[input.getMatrixRowCount()][1]);
for (int i = 0; i < input.getMatrixRowCount(); i++) {
// 计算向量间的欧式距离
// 将labels矩阵扩展
Matrix labelMatrixCopied = input.getRowOfIdx(i).extend(2, dataSet.getMatrixRowCount());
// 前面是计算欧氏距离,splice(1,labels)是将距离矩阵与labels矩阵合并
Matrix distanceMatrix = dataSet.subtract(labelMatrixCopied).square().sumRow().pow(0.5).splice(1, labels);
// 将计算出的距离矩阵按照距离升序排序
distanceMatrix.sort(0, OrderEnum.ASC);
// 遍历最近的k个变量
Map<Double, Integer> map = new HashMap<>();
for (int j = 0; j < k; j++) {
// 遍历标签种类数
for (Double label : labelList) {
if (distanceMatrix.getValOfIdx(j, 1) == label) {
map.put(label, map.getOrDefault(label, 0) + 1);
}
}
}
result.setValue(i, 0, getKeyOfMaxValue(map));
}
return result;
}
/**
* 取map中值最大的key
*
* @param map
* @return
*/
private static Double getKeyOfMaxValue(Map<Double, Integer> map) {
if (map == null)
return null;
Double keyOfMaxValue = 0.0;
Integer maxValue = 0;
for (Double key : map.keySet()) {
if (map.get(key) > maxValue) {
keyOfMaxValue = key;
maxValue = map.get(key);
}
}
return keyOfMaxValue;
}
}

646
algorithm/src/main/java/com/mh/algorithm/matrix/Matrix.java

@ -0,0 +1,646 @@
package com.mh.algorithm.matrix;
import com.mh.algorithm.constants.OrderEnum;
import java.io.Serializable;
public class Matrix implements Serializable {
private double[][] matrix;
//矩阵列数
private int matrixColCount;
//矩阵行数
private int matrixRowCount;
/**
* 构造一个空矩阵
*/
public Matrix() {
this.matrix = null;
this.matrixColCount = 0;
this.matrixRowCount = 0;
}
/**
* 构造一个matrix矩阵
* @param matrix
*/
public Matrix(double[][] matrix) {
this.matrix = matrix;
this.matrixRowCount = matrix.length;
this.matrixColCount = matrix[0].length;
}
/**
* 构造一个rowCount行colCount列值为0的矩阵
* @param rowCount
* @param colCount
*/
public Matrix(int rowCount,int colCount) {
double[][] matrix = new double[rowCount][colCount];
for (int i = 0; i < rowCount; i++) {
for (int j = 0; j < colCount; j++) {
matrix[i][j] = 0;
}
}
this.matrix = matrix;
this.matrixRowCount = rowCount;
this.matrixColCount = colCount;
}
/**
* 构造一个rowCount行colCount列值为val的矩阵
* @param val
* @param rowCount
* @param colCount
*/
public Matrix(double val,int rowCount,int colCount) {
double[][] matrix = new double[rowCount][colCount];
for (int i = 0; i < rowCount; i++) {
for (int j = 0; j < colCount; j++) {
matrix[i][j] = val;
}
}
this.matrix = matrix;
this.matrixRowCount = rowCount;
this.matrixColCount = colCount;
}
public double[][] getMatrix() {
return matrix;
}
public void setMatrix(double[][] matrix) {
this.matrix = matrix;
this.matrixRowCount = matrix.length;
this.matrixColCount = matrix[0].length;
}
public int getMatrixColCount() {
return matrixColCount;
}
public int getMatrixRowCount() {
return matrixRowCount;
}
/**
* 获取矩阵指定位置的值
*
* @param x
* @param y
* @return
*/
public double getValOfIdx(int x, int y) throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
if (x > matrixRowCount - 1) {
throw new IllegalArgumentException("索引x越界");
}
if (y > matrixColCount - 1) {
throw new IllegalArgumentException("索引y越界");
}
return matrix[x][y];
}
/**
* 获取矩阵指定行
*
* @param x
* @return
*/
public Matrix getRowOfIdx(int x) throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
if (x > matrixRowCount - 1) {
throw new IllegalArgumentException("索引x越界");
}
double[][] result = new double[1][matrixColCount];
result[0] = matrix[x];
return new Matrix(result);
}
/**
* 获取矩阵指定列
*
* @param y
* @return
*/
public Matrix getColOfIdx(int y) throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
if (y > matrixColCount - 1) {
throw new IllegalArgumentException("索引y越界");
}
double[][] result = new double[matrixRowCount][1];
for (int i = 0; i < matrixRowCount; i++) {
result[i][0] = matrix[i][y];
}
return new Matrix(result);
}
/**
* 设置矩阵中x,y位置元素的值
* @param x
* @param y
* @param val
*/
public void setValue(int x, int y, double val) {
if (x > this.matrixRowCount - 1) {
throw new IllegalArgumentException("行索引越界");
}
if (y > this.matrixColCount - 1) {
throw new IllegalArgumentException("列索引越界");
}
this.matrix[x][y] = val;
}
/**
* 矩阵乘矩阵
*
* @param a
* @return
* @throws IllegalArgumentException
*/
public Matrix multiple(Matrix a) throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
if (a.getMatrix() == null || a.getMatrixRowCount() == 0 || a.getMatrixColCount() == 0) {
throw new IllegalArgumentException("参数矩阵为空");
}
if (matrixColCount != a.getMatrixRowCount()) {
throw new IllegalArgumentException("矩阵纬度不同,不可计算");
}
double[][] result = new double[matrixRowCount][a.getMatrixColCount()];
for (int i = 0; i < matrixRowCount; i++) {
for (int j = 0; j < a.getMatrixColCount(); j++) {
for (int k = 0; k < matrixColCount; k++) {
result[i][j] = result[i][j] + matrix[i][k] * a.getMatrix()[k][j];
}
}
}
return new Matrix(result);
}
/**
* 矩阵乘一个数字
*
* @param a
* @return
*/
public Matrix multiple(double a) throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
double[][] result = new double[matrixRowCount][matrixColCount];
for (int i = 0; i < matrixRowCount; i++) {
for (int j = 0; j < matrixColCount; j++) {
result[i][j] = matrix[i][j] * a;
}
}
return new Matrix(result);
}
/**
* 矩阵点乘
*
* @param a
* @return
*/
public Matrix pointMultiple(Matrix a) throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
if (a.getMatrix() == null || a.getMatrixRowCount() == 0 || a.getMatrixColCount() == 0) {
throw new IllegalArgumentException("参数矩阵为空");
}
if (matrixRowCount != a.getMatrixRowCount() && matrixColCount != a.getMatrixColCount()) {
throw new IllegalArgumentException("矩阵纬度不同,不可计算");
}
double[][] result = new double[matrixRowCount][matrixColCount];
for (int i = 0; i < matrixRowCount; i++) {
for (int j = 0; j < matrixColCount; j++) {
result[i][j] = matrix[i][j] * a.getMatrix()[i][j];
}
}
return new Matrix(result);
}
/**
* 矩阵除一个数字
* @param a
* @return
* @throws IllegalArgumentException
*/
public Matrix divide(double a) throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
double[][] result = new double[matrixRowCount][matrixColCount];
for (int i = 0; i < matrixRowCount; i++) {
for (int j = 0; j < matrixColCount; j++) {
result[i][j] = matrix[i][j] / a;
}
}
return new Matrix(result);
}
/**
* 矩阵加法
*
* @param a
* @return
*/
public Matrix plus(Matrix a) throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
if (a.getMatrix() == null || a.getMatrixRowCount() == 0 || a.getMatrixColCount() == 0) {
throw new IllegalArgumentException("参数矩阵为空");
}
if (matrixRowCount != a.getMatrixRowCount() && matrixColCount != a.getMatrixColCount()) {
throw new IllegalArgumentException("矩阵纬度不同,不可计算");
}
double[][] result = new double[matrixRowCount][matrixColCount];
for (int i = 0; i < matrixRowCount; i++) {
for (int j = 0; j < matrixColCount; j++) {
result[i][j] = matrix[i][j] + a.getMatrix()[i][j];
}
}
return new Matrix(result);
}
/**
* 矩阵加一个数字
* @param a
* @return
* @throws IllegalArgumentException
*/
public Matrix plus(double a) throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
double[][] result = new double[matrixRowCount][matrixColCount];
for (int i = 0; i < matrixRowCount; i++) {
for (int j = 0; j < matrixColCount; j++) {
result[i][j] = matrix[i][j] + a;
}
}
return new Matrix(result);
}
/**
* 矩阵减法
*
* @param a
* @return
*/
public Matrix subtract(Matrix a) throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
if (a.getMatrix() == null || a.getMatrixRowCount() == 0 || a.getMatrixColCount() == 0) {
throw new IllegalArgumentException("参数矩阵为空");
}
if (matrixRowCount != a.getMatrixRowCount() && matrixColCount != a.getMatrixColCount()) {
throw new IllegalArgumentException("矩阵纬度不同,不可计算");
}
double[][] result = new double[matrixRowCount][matrixColCount];
for (int i = 0; i < matrixRowCount; i++) {
for (int j = 0; j < matrixColCount; j++) {
result[i][j] = matrix[i][j] - a.getMatrix()[i][j];
}
}
return new Matrix(result);
}
/**
* 矩阵减一个数字
* @param a
* @return
* @throws IllegalArgumentException
*/
public Matrix subtract(double a) throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
double[][] result = new double[matrixRowCount][matrixColCount];
for (int i = 0; i < matrixRowCount; i++) {
for (int j = 0; j < matrixColCount; j++) {
result[i][j] = matrix[i][j] - a;
}
}
return new Matrix(result);
}
/**
* 矩阵行求和
*
* @return
*/
public Matrix sumRow() throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
double[][] result = new double[matrixRowCount][1];
for (int i = 0; i < matrixRowCount; i++) {
for (int j = 0; j < matrixColCount; j++) {
result[i][0] += matrix[i][j];
}
}
return new Matrix(result);
}
/**
* 矩阵列求和
*
* @return
*/
public Matrix sumCol() throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
double[][] result = new double[1][matrixColCount];
for (int i = 0; i < matrixRowCount; i++) {
for (int j = 0; j < matrixColCount; j++) {
result[0][j] += matrix[i][j];
}
}
return new Matrix(result);
}
/**
* 矩阵所有元素求和
*
* @return
*/
public double sumAll() throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
double result = 0;
for (double[] doubles : matrix) {
for (int j = 0; j < matrixColCount; j++) {
result += doubles[j];
}
}
return result;
}
/**
* 矩阵所有元素求平方
*
* @return
*/
public Matrix square() throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
double[][] result = new double[matrixRowCount][matrixColCount];
for (int i = 0; i < matrixRowCount; i++) {
for (int j = 0; j < matrixColCount; j++) {
result[i][j] = matrix[i][j] * matrix[i][j];
}
}
return new Matrix(result);
}
/**
* 矩阵所有元素求N次方
*
* @return
*/
public Matrix pow(double n) throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
double[][] result = new double[matrixRowCount][matrixColCount];
for (int i = 0; i < matrixRowCount; i++) {
for (int j = 0; j < matrixColCount; j++) {
result[i][j] = Math.pow(matrix[i][j],n);
}
}
return new Matrix(result);
}
/**
* 矩阵转置
*
* @return
*/
public Matrix transpose() throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
double[][] result = new double[matrixColCount][matrixRowCount];
for (int i = 0; i < matrixRowCount; i++) {
for (int j = 0; j < matrixColCount; j++) {
result[j][i] = matrix[i][j];
}
}
return new Matrix(result);
}
/**
* 截取矩阵
* @param startRowIndex 开始行索引
* @param rowCount 截取行数
* @param startColIndex 开始列索引
* @param colCount 截取列数
* @return
* @throws IllegalArgumentException
*/
public Matrix subMatrix(int startRowIndex,int rowCount,int startColIndex,int colCount) throws IllegalArgumentException {
if (startRowIndex + rowCount > matrixRowCount) {
throw new IllegalArgumentException("行索引越界");
}
if (startColIndex + colCount> matrixColCount) {
throw new IllegalArgumentException("列索引越界");
}
double[][] result = new double[rowCount][colCount];
for (int i = startRowIndex; i < startRowIndex + rowCount; i++) {
if (startColIndex + colCount - startColIndex >= 0)
System.arraycopy(matrix[i], startColIndex, result[i - startRowIndex], 0, colCount);
}
return new Matrix(result);
}
/**
* 矩阵合并
* @param direction 合并方向1为横向2为竖向
* @param a
* @return
* @throws IllegalArgumentException
*/
public Matrix splice(int direction, Matrix a) throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
if (a.getMatrix() == null || a.getMatrixRowCount() == 0 || a.getMatrixColCount() == 0) {
throw new IllegalArgumentException("参数矩阵为空");
}
if(direction == 1){
//横向拼接
if (matrixRowCount != a.getMatrixRowCount()) {
throw new IllegalArgumentException("矩阵行数不一致,无法拼接");
}
double[][] result = new double[matrixRowCount][matrixColCount + a.getMatrixColCount()];
for (int i = 0; i < matrixRowCount; i++) {
System.arraycopy(matrix[i],0,result[i],0,matrixColCount);
System.arraycopy(a.getMatrix()[i],0,result[i],matrixColCount,a.getMatrixColCount());
}
return new Matrix(result);
}else if(direction == 2){
//纵向拼接
if (matrixColCount != a.getMatrixColCount()) {
throw new IllegalArgumentException("矩阵列数不一致,无法拼接");
}
double[][] result = new double[matrixRowCount + a.getMatrixRowCount()][matrixColCount];
for (int i = 0; i < matrixRowCount; i++) {
result[i] = matrix[i];
}
for (int i = 0; i < a.getMatrixRowCount(); i++) {
result[matrixRowCount + i] = a.getMatrix()[i];
}
return new Matrix(result);
}else{
throw new IllegalArgumentException("方向参数有误");
}
}
/**
* 扩展矩阵
* @param direction 扩展方向1为横向2为竖向
* @param a
* @return
* @throws IllegalArgumentException
*/
public Matrix extend(int direction , int a) throws IllegalArgumentException {
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
if(direction == 1){
//横向复制
double[][] result = new double[matrixRowCount][matrixColCount*a];
for (int i = 0; i < matrixRowCount; i++) {
for (int j = 0; j < a; j++) {
System.arraycopy(matrix[i],0,result[i],j*matrixColCount,matrixColCount);
}
}
return new Matrix(result);
}else if(direction == 2){
//纵向复制
double[][] result = new double[matrixRowCount*a][matrixColCount];
for (int i = 0; i < matrixRowCount*a; i++) {
result[i] = matrix[i%matrixRowCount];
}
return new Matrix(result);
}else{
throw new IllegalArgumentException("方向参数有误");
}
}
/**
* 获取每列的平均值
* @return
* @throws IllegalArgumentException
*/
public Matrix getColAvg() throws IllegalArgumentException {
Matrix tmp = this.sumCol();
return tmp.divide(matrixRowCount);
}
/**
* 矩阵行排序
* @param index 根据第几列的数进行行排序
* @param order 排序顺序升序或降序
* @return
* @throws IllegalArgumentException
*/
public void sort(int index, OrderEnum order) throws IllegalArgumentException{
if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
throw new IllegalArgumentException("矩阵为空");
}
if(index >= matrixColCount){
throw new IllegalArgumentException("排序索引index越界");
}
sort(index,order,0,this.matrixRowCount - 1);
}
/**
* 判断是否是方阵
* 行列数相等并且不等于0
* @return
*/
public boolean isSquareMatrix(){
return matrixColCount == matrixRowCount && matrixColCount != 0;
}
@Override
public String toString() {
StringBuilder stringBuilder = new StringBuilder();
stringBuilder.append("\r\n");
for (int i = 0; i < matrixRowCount; i++) {
stringBuilder.append("# ");
for (int j = 0; j < matrixColCount; j++) {
stringBuilder.append(matrix[i][j]).append("\t ");
}
stringBuilder.append("#\r\n");
}
stringBuilder.append("\r\n");
return stringBuilder.toString();
}
private void sort(int index,OrderEnum order,int start,int end){
if(start >= end){
return;
}
int tmp = partition(index,order,start,end);
sort(index,order, start, tmp - 1);
sort(index,order, tmp + 1, end);
}
private int partition(int index,OrderEnum order,int start,int end){
int l = start + 1,r = end;
double v = matrix[start][index];
switch (order){
case ASC:
while(true){
while(matrix[r][index] >= v && r > start){
r--;
}
while(matrix[l][index] <= v && l < end){
l++;
}
if(l >= r){
break;
}
double[] tmp = matrix[r];
matrix[r] = matrix[l];
matrix[l] = tmp;
}
break;
case DESC:
while(true){
while(matrix[r][index] <= v && r > start){
r--;
}
while(matrix[l][index] >= v && l < end){
l++;
}
if(l >= r){
break;
}
double[] tmp = matrix[r];
matrix[r] = matrix[l];
matrix[l] = tmp;
}
break;
}
double[] tmp = matrix[r];
matrix[r] = matrix[start];
matrix[start] = tmp;
return r;
}
}

53
algorithm/src/main/java/com/mh/algorithm/utils/CsvInfo.java

@ -0,0 +1,53 @@
package com.mh.algorithm.utils;
import com.mh.algorithm.matrix.Matrix;
import java.util.ArrayList;
public class CsvInfo {
private String[] header;
private int csvRowCount;
private int csvColCount;
private ArrayList<String[]> csvFileList;
public String[] getHeader() {
return header;
}
public void setHeader(String[] header) {
this.header = header;
}
public int getCsvRowCount() {
return csvRowCount;
}
public int getCsvColCount() {
return csvColCount;
}
public ArrayList<String[]> getCsvFileList() {
return csvFileList;
}
public void setCsvFileList(ArrayList<String[]> csvFileList) {
this.csvFileList = csvFileList;
this.csvColCount = csvFileList.get(0) != null?csvFileList.get(0).length:0;
this.csvRowCount = csvFileList.size();
}
public Matrix toMatrix() throws Exception {
double[][] arr = new double[csvFileList.size()][csvFileList.get(0).length];
for (int i = 0; i < csvFileList.size(); i++) {
for (int j = 0; j < csvFileList.get(0).length; j++) {
try {
arr[i][j] = Double.parseDouble(csvFileList.get(i)[j]);
}catch (NumberFormatException e){
throw new Exception("Csv中含有非数字字符,无法转换成Matrix对象");
}
}
}
return new Matrix(arr);
}
}

66
algorithm/src/main/java/com/mh/algorithm/utils/CsvUtil.java

@ -0,0 +1,66 @@
package com.mh.algorithm.utils;
import com.csvreader.CsvReader;
import com.csvreader.CsvWriter;
import com.mh.algorithm.matrix.Matrix;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
public class CsvUtil {
/**
* 获取CSV中的信息
* @param hasHeader 是否含有表头
* @param path CSV文件的路径
* @return
* @throws IOException
*/
public static CsvInfo getCsvInfo(boolean hasHeader , String path) throws IOException {
//创建csv对象,存储csv中的信息
CsvInfo csvInfo = new CsvInfo();
//获取CsvReader流
CsvReader csvReader = new CsvReader(path, ',', StandardCharsets.UTF_8);
if(hasHeader){
csvReader.readHeaders();
}
//获取Csv中的所有记录
ArrayList<String[]> csvFileList = new ArrayList<String[]>();
while (csvReader.readRecord()) {
csvFileList.add(csvReader.getValues());
}
//赋值
csvInfo.setHeader(csvReader.getHeaders());
csvInfo.setCsvFileList(csvFileList);
//关闭流
csvReader.close();
return csvInfo;
}
/**
* 将矩阵写入到csv文件中
* @param header 表头
* @param data 以矩阵形式存放的数据
* @param path 写入的文件地址
* @throws Exception
*/
public static void createCsvFile(String[] header,Matrix data,String path) throws Exception {
if (header!=null && header.length != data.getMatrixColCount()) {
throw new Exception("表头列数与数据列数不符");
}
CsvWriter csvWriter = new CsvWriter(path, ',', StandardCharsets.UTF_8);
if (header != null) {
csvWriter.writeRecord(header);
}
for (int i = 0; i < data.getMatrixRowCount(); i++) {
String[] record = new String[data.getMatrixColCount()];
for (int j = 0; j < data.getMatrixColCount(); j++) {
record[j] = data.getValOfIdx(i, j)+"";
}
csvWriter.writeRecord(record);
}
csvWriter.close();
}
}

20
algorithm/src/main/java/com/mh/algorithm/utils/DoubleUtil.java

@ -0,0 +1,20 @@
package com.mh.algorithm.utils;
/**
* @program: top-algorithm-set
* @description: DoubleTool
* @author: Mr.Zhao
* @create: 2020-11-12 21:54
**/
public class DoubleUtil {
private static final Double MAX_ERROR = 0.0001;
public static boolean equals(Double a, Double b) {
return Math.abs(a - b)< MAX_ERROR;
}
public static boolean equals(Double a, Double b,Double maxError) {
return Math.abs(a - b)< maxError;
}
}

285
algorithm/src/main/java/com/mh/algorithm/utils/MatrixUtil.java

@ -0,0 +1,285 @@
package com.mh.algorithm.utils;
import Jama.EigenvalueDecomposition;
import com.mh.algorithm.matrix.Matrix;
import java.util.*;
public class MatrixUtil {
/**
* 创建一个单位矩阵
* @param matrixRowCount 单位矩阵的纬度
* @return
*/
public static Matrix eye(int matrixRowCount){
double[][] result = new double[matrixRowCount][matrixRowCount];
for (int i = 0; i < matrixRowCount; i++) {
for (int j = 0; j < matrixRowCount; j++) {
if(i == j){
result[i][j] = 1;
}else{
result[i][j] = 0;
}
}
}
return new Matrix(result);
}
/**
* 求矩阵的逆
* 原理:AE=EA^-1
* @param a
* @return
* @throws Exception
*/
public static Matrix inv(Matrix a) throws Exception {
if (!invable(a)) {
throw new Exception("矩阵不可逆");
}
// [a|E]
Matrix b = a.splice(1, eye(a.getMatrixRowCount()));
double[][] data = b.getMatrix();
int rowCount = b.getMatrixRowCount();
int colCount = b.getMatrixColCount();
//此处应用a的列数,为简化,直接用b的行数
for (int j = 0; j < rowCount; j++) {
//若遇到0则交换两行
int notZeroRow = -2;
if(data[j][j] == 0){
notZeroRow = -1;
for (int l = j; l < rowCount; l++) {
if (data[l][j] != 0) {
notZeroRow = l;
break;
}
}
}
if (notZeroRow == -1) {
throw new Exception("矩阵不可逆");
}else if(notZeroRow != -2){
//交换j与notZeroRow两行
double[] tmp = data[j];
data[j] = data[notZeroRow];
data[notZeroRow] = tmp;
}
//将第data[j][j]化为1
if (data[j][j] != 1) {
double multiple = data[j][j];
for (int colIdx = j; colIdx < colCount; colIdx++) {
data[j][colIdx] /= multiple;
}
}
//行与行相减
for (int i = 0; i < rowCount; i++) {
if (i != j) {
double multiple = data[i][j] / data[j][j];
//遍历行中的列
for (int k = j; k < colCount; k++) {
data[i][k] = data[i][k] - multiple * data[j][k];
}
}
}
}
Matrix result = new Matrix(data);
return result.subMatrix(0, rowCount, rowCount, rowCount);
}
/**
* 求矩阵的伴随矩阵
* 原理:A*=|A|A^-1
* @param a
* @return
* @throws Exception
*/
public static Matrix adj(Matrix a) throws Exception {
return inv(a).multiple(det(a));
}
/**
* 矩阵转成上三角矩阵
* @param a
* @return
* @throws Exception
*/
public static Matrix getTopTriangle(Matrix a) throws Exception {
if (!a.isSquareMatrix()) {
throw new Exception("不是方阵无法进行计算");
}
int matrixHeight = a.getMatrixRowCount();
double[][] result = a.getMatrix();
//遍历列
for (int j = 0; j < matrixHeight; j++) {
//遍历行
for (int i = j+1; i < matrixHeight; i++) {
//若遇到0则交换两行
int notZeroRow = -2;
if(result[j][j] == 0){
notZeroRow = -1;
for (int l = i; l < matrixHeight; l++) {
if (result[l][j] != 0) {
notZeroRow = l;
break;
}
}
}
if (notZeroRow == -1) {
throw new Exception("矩阵不可逆");
}else if(notZeroRow != -2){
//交换j与notZeroRow两行
double[] tmp = result[j];
result[j] = result[notZeroRow];
result[notZeroRow] = tmp;
}
double multiple = result[i][j]/result[j][j];
//遍历行中的列
for (int k = j; k < matrixHeight; k++) {
result[i][k] = result[i][k] - multiple * result[j][k];
}
}
}
return new Matrix(result);
}
/**
* 计算矩阵的行列式
* @param a
* @return
* @throws Exception
*/
public static double det(Matrix a) throws Exception {
//将矩阵转成上三角矩阵
Matrix b = MatrixUtil.getTopTriangle(a);
double result = 1;
//计算矩阵行列式
for (int i = 0; i < b.getMatrixRowCount(); i++) {
result *= b.getValOfIdx(i, i);
}
return result;
}
/**
* 获取协方差矩阵
* @param a
* @return
* @throws Exception
*/
public static Matrix cov(Matrix a) throws Exception {
if (a.getMatrix() == null) {
throw new Exception("矩阵为空");
}
Matrix avg = a.getColAvg().extend(2, a.getMatrixRowCount());
Matrix tmp = a.subtract(avg);
return tmp.transpose().multiple(tmp).multiple(1/((double) a.getMatrixRowCount() -1));
}
/**
* 判断矩阵是否可逆
* 如果可转为上三角矩阵则可逆
* @param a
* @return
*/
public static boolean invable(Matrix a) {
try {
getTopTriangle(a);
return true;
} catch (Exception e) {
return false;
}
}
/**
* 获取矩阵的特征值矩阵调用Jama中的getV方法
* @param a
* @return
*/
public static Matrix getV(Matrix a) {
EigenvalueDecomposition eig = new EigenvalueDecomposition(new Jama.Matrix(a.getMatrix()));
return new Matrix(eig.getV().getArray());
}
/**
* 取特征值实部
* @param a
* @return
*/
public double[] getRealEigenvalues(Matrix a){
EigenvalueDecomposition eig = new EigenvalueDecomposition(new Jama.Matrix(a.getMatrix()));
return eig.getRealEigenvalues();
}
/**
* 取特征值虚部
* @param a
* @return
*/
public double[] getImagEigenvalues(Matrix a){
EigenvalueDecomposition eig = new EigenvalueDecomposition(new Jama.Matrix(a.getMatrix()));
return eig.getImagEigenvalues();
}
/**
* 取块对角特征值矩阵
* @param a
* @return
*/
public static Matrix getD(Matrix a) {
EigenvalueDecomposition eig = new EigenvalueDecomposition(new Jama.Matrix(a.getMatrix()));
return new Matrix(eig.getD().getArray());
}
/**
* 数据归一化
* @param a 要归一化的数据
* @param normalizationMin 要归一化的区间下限
* @param normalizationMax 要归一化的区间上限
* @return
*/
public static Map<String, Object> normalize(Matrix a, double normalizationMin, double normalizationMax) throws Exception {
HashMap<String, Object> result = new HashMap<>();
double[][] maxArr = new double[1][a.getMatrixColCount()];
double[][] minArr = new double[1][a.getMatrixColCount()];
double[][] res = new double[a.getMatrixRowCount()][a.getMatrixColCount()];
for (int i = 0; i < a.getMatrixColCount(); i++) {
List tmp = new ArrayList();
for (int j = 0; j < a.getMatrixRowCount(); j++) {
tmp.add(a.getValOfIdx(j,i));
}
double max = (double) Collections.max(tmp);
double min = (double) Collections.min(tmp);
//数据归一化(注:若max与min均为0则不需要归一化)
if (max != 0 || min != 0) {
for (int j = 0; j < a.getMatrixRowCount(); j++) {
res[j][i] = normalizationMin + (a.getValOfIdx(j,i) - min) / (max - min) * (normalizationMax - normalizationMin);
}
}
maxArr[0][i] = max;
minArr[0][i] = min;
}
result.put("max", new Matrix(maxArr));
result.put("min", new Matrix(minArr));
result.put("res", new Matrix(res));
return result;
}
/**
* 反归一化
* @param a 要反归一化的数据
* @param normalizationMin 要反归一化的区间下限
* @param normalizationMax 要反归一化的区间上限
* @param dataMax 数据最大值
* @param dataMin 数据最小值
* @return
*/
public static Matrix inverseNormalize(Matrix a, double normalizationMax, double normalizationMin , Matrix dataMax,Matrix dataMin){
double[][] res = new double[a.getMatrixRowCount()][a.getMatrixColCount()];
for (int i = 0; i < a.getMatrixColCount(); i++) {
//数据反归一化
if (dataMin.getValOfIdx(0,i) != 0 || dataMax.getValOfIdx(0,i) != 0) {
for (int j = 0; j < a.getMatrixRowCount(); j++) {
res[j][i] = dataMin.getValOfIdx(0,i) + (dataMax.getValOfIdx(0,i) - dataMin.getValOfIdx(0,i)) * (a.getValOfIdx(j,i) - normalizationMin) / (normalizationMax - normalizationMin);
}
}
}
return new Matrix(res);
}
}

32
algorithm/src/main/java/com/mh/algorithm/utils/SerializationUtil.java

@ -0,0 +1,32 @@
package com.mh.algorithm.utils;
import java.io.*;
public class SerializationUtil {
/**
* 对象序列化到本地
* @param object
* @throws IOException
*/
public static void serialize(Object object, String path) throws IOException {
File file = new File(path);
System.out.println(file.getAbsolutePath());
ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(file));
out.writeObject(object);
out.close();
}
/**
* 对象反序列化
* @return
* @throws IOException
* @throws ClassNotFoundException
*/
public static Object deSerialization(String path) throws IOException, ClassNotFoundException {
File file = new File(path);
ObjectInputStream oin = new ObjectInputStream(new FileInputStream(file));
Object object = oin.readObject();
oin.close();
return object;
}
}

71
algorithm/src/test/java/com/mh/algorithm/bpnn/bpnnTest.java

@ -0,0 +1,71 @@
package com.mh.algorithm.bpnn;
import com.mh.algorithm.matrix.Matrix;
import com.mh.algorithm.utils.CsvInfo;
import com.mh.algorithm.utils.CsvUtil;
import com.mh.algorithm.utils.SerializationUtil;
import org.junit.Test;
import java.util.Date;
public class bpnnTest {
@Test
public void test() throws Exception {
// 创建训练集矩阵
CsvInfo csvInfo = CsvUtil.getCsvInfo(true, "D:\\ljf\\my_pro\\top-algorithm-set-dev\\src\\trainDataElec.csv");
Matrix trainSet = csvInfo.toMatrix();
// 创建BPNN工厂对象
BPNeuralNetworkFactory factory = new BPNeuralNetworkFactory();
// 创建BP参数对象
BPParameter bpParameter = new BPParameter();
bpParameter.setInputLayerNeuronCount(2);
bpParameter.setHiddenLayerNeuronCount(2);
bpParameter.setOutputLayerNeuronCount(2);
bpParameter.setPrecision(0.01);
bpParameter.setMaxTimes(100000);
// 训练BP神经网络
System.out.println(new Date());
BPModel bpModel = factory.trainBP(bpParameter, trainSet);
System.out.println(new Date());
// 将BPModel序列化到本地
SerializationUtil.serialize(bpModel, "elec");
CsvInfo csvInfo2 = CsvUtil.getCsvInfo(true, "D:\\ljf\\my_pro\\top-algorithm-set-dev\\src\\testDataElec.csv");
Matrix testSet = csvInfo2.toMatrix();
Matrix testData1 = testSet.subMatrix(0, testSet.getMatrixRowCount(), 0, testSet.getMatrixColCount() - 2);
Matrix testLabel = testSet.subMatrix(0, testSet.getMatrixRowCount(), testSet.getMatrixColCount() - 2, 1);
// 将BPModel反序列化
BPModel bpModel1 = (BPModel) SerializationUtil.deSerialization("elec");
Matrix result = factory.computeBP(bpModel1, testData1);
int total = result.getMatrixRowCount();
int correct = 0;
for (int i = 0; i < result.getMatrixRowCount(); i++) {
if(Math.round(result.getValOfIdx(i,0)) == testLabel.getValOfIdx(i,0)){
correct++;
}
}
double correctRate = Double.valueOf(correct) / Double.valueOf(total);
System.out.println(correctRate);
}
/**
* 使用示例
* @throws Exception
*/
@Test
public void bpnnUsing() throws Exception{
CsvInfo csvInfo = CsvUtil.getCsvInfo(false, "D:\\ljf\\my_pro\\top-algorithm-set-dev\\src\\dataElec.csv");
Matrix data = csvInfo.toMatrix();
// 将BPModel反序列化
BPModel bpModel1 = (BPModel) SerializationUtil.deSerialization("elec");
// 创建工厂
BPNeuralNetworkFactory factory = new BPNeuralNetworkFactory();
Matrix result = factory.computeBP(bpModel1, data);
CsvUtil.createCsvFile(null,result,"D:\\ljf\\my_pro\\top-algorithm-set-dev\\src\\computeResult.csv");
}
}

46
algorithm/src/test/java/com/mh/algorithm/knn/knnTest.java

@ -0,0 +1,46 @@
package com.mh.algorithm.knn;
import com.mh.algorithm.matrix.Matrix;
import com.mh.algorithm.utils.CsvInfo;
import com.mh.algorithm.utils.CsvUtil;
import com.mh.algorithm.utils.DoubleUtil;
import org.junit.Test;
/**
* @program: top-algorithm-set
* @description:
* @author: Mr.Zhao
* @create: 2020-10-26 22:04
**/
public class knnTest {
@Test
public void test() throws Exception {
// 训练集
CsvInfo csvInfo = CsvUtil.getCsvInfo(false, "E:\\jarTest\\trainData.csv");
Matrix trainSet = csvInfo.toMatrix();
Matrix trainSetLabels = trainSet.getColOfIdx(trainSet.getMatrixColCount() - 1);
Matrix trainSetData = trainSet.subMatrix(0, trainSet.getMatrixRowCount(), 0, trainSet.getMatrixColCount() - 1);
CsvInfo csvInfo1 = CsvUtil.getCsvInfo(false, "E:\\jarTest\\testData.csv");
Matrix testSet = csvInfo1.toMatrix();
Matrix testSetData = trainSet.subMatrix(0, testSet.getMatrixRowCount(), 0, testSet.getMatrixColCount() - 1);
Matrix testSetLabels = trainSet.getColOfIdx(testSet.getMatrixColCount() - 1);
// 分类
long startTime = System.currentTimeMillis();
Matrix result = KNN.classify(testSetData, trainSetData, trainSetLabels, 5);
long endTime = System.currentTimeMillis();
System.out.println("run time:" + (endTime - startTime));
// 正确率
Matrix error = result.subtract(testSetLabels);
int total = error.getMatrixRowCount();
int correct = 0;
for (int i = 0; i < error.getMatrixRowCount(); i++) {
if (DoubleUtil.equals(error.getValOfIdx(i, 0), 0.0)) {
correct++;
}
}
double correctRate = Double.valueOf(correct) / Double.valueOf(total);
System.out.println("correctRate:"+ correctRate);
}
}

1
pom.xml

@ -11,6 +11,7 @@
<modules>
<module>common</module>
<module>user-service</module>
<module>algorithm</module>
</modules>
<parent>

6
user-service/src/main/java/com/mh/user/controller/ControlSetController.java

@ -46,11 +46,11 @@ public class ControlSetController {
}
//查询设置表
@SysLogger(title="控制设置",optDesc = "查询设置值")
@SysLogger(title="控制设置",optDesc = "查询时控设置值")
@PostMapping(value="/query")
public HttpResult queryControlSet(@RequestParam("buildingId") String buildingId) {
public HttpResult queryControlSet(@RequestParam("buildingId") String buildingId, @RequestParam(value = "timeName",required = false) String timeName) {
try{
ControlSetEntity control=controlSetService.queryControlSet(buildingId);
ControlSetEntity control=controlSetService.queryControlSet(buildingId, timeName);
return HttpResult.ok(control);
}catch (Exception e){
// e.printStackTrace();

4
user-service/src/main/java/com/mh/user/controller/SerialPortController.java

@ -424,9 +424,9 @@ public class SerialPortController {
}
@PostMapping(value = "/control")
public HttpResult queryControlSet(@RequestParam(value = "buildingId") String buildingId) {
public HttpResult queryControlSet(@RequestParam(value = "buildingId") String buildingId, @RequestParam(value = "timeName", required = false) String timeName) {
try {
ControlSetEntity list = controlSetService.queryControlSet(buildingId);
ControlSetEntity list = controlSetService.queryControlSet(buildingId, timeName);
return HttpResult.ok(list);
} catch (Exception e) {
// e.printStackTrace();

1
user-service/src/main/java/com/mh/user/entity/MaintainInfoEntity.java

@ -18,6 +18,7 @@ public class MaintainInfoEntity {
private String maintainPeople; //维护人员
private Double cost; //费用
private String contents; //维保内容
private String evaluate; // 评价分数
}

19
user-service/src/main/java/com/mh/user/mapper/ControlSetMapper.java

@ -60,6 +60,21 @@ public interface ControlSetMapper {
@Result(column = "back_water_temp", property = "backWaterTemp"),
@Result(column = "up_water_temp", property = "upWaterTemp"),
})
@Select("select * from control_Set where building_id=#{buildingId}")
ControlSetEntity queryControlSet(@Param("buildingId") String buildingId);
@Select("select " +
" top 1 " +
" * " +
"from " +
" control_Set " +
"where " +
" building_id = #{buildingId} " +
" and exists ( " +
" select " +
" 1 " +
" from " +
" device_install di " +
" where " +
" di.building_id = building_id " +
" and di.device_name like concat(#{timeName}, '时控') " +
") ")
ControlSetEntity queryControlSet(@Param("buildingId") String buildingId, @Param("timeName") String timeName);
}

8
user-service/src/main/java/com/mh/user/mapper/MaintainInfoMapper.java

@ -18,8 +18,8 @@ public interface MaintainInfoMapper {
* 维修保养信息
* @param maintainInfoEntity
*/
@Insert("insert into maintain_info(cur_date,building_id,device_type,device_addr,maintain_type,maintain_people,cost,contents) values (" +
" getDate(),#{buildingId},#{deviceType},#{deviceAddr},#{maintainType},#{maintainPeople},#{cost},#{contents})")
@Insert("insert into maintain_info(cur_date,building_id,device_type,device_addr,maintain_type,maintain_people,cost,contents, evaluate) values (" +
" getDate(),#{buildingId},#{deviceType},#{deviceAddr},#{maintainType},#{maintainPeople},#{cost},#{contents}, #{evaluate})")
int saveMaintainInfo(MaintainInfoEntity maintainInfoEntity);
/**
@ -36,6 +36,7 @@ public interface MaintainInfoMapper {
" <if test='maintainPeople!=null'> , maintain_people = #{maintainPeople} </if>" +
" <if test='cost!=null'> , cost = #{cost} </if>" +
" <if test='contents!=null'> , contents = #{contents} </if>" +
" <if test='evaluate!=null'> , evaluate = #{evaluate} </if>" +
" where id = #{id} " +
"</script>")
int updateMaintainInfo(MaintainInfoEntity maintainInfoEntity);
@ -61,7 +62,8 @@ public interface MaintainInfoMapper {
@Result(property="maintainPeople",column="maintain_people"),
@Result(property="id",column="id"),
@Result(property="cost",column="cost"),
@Result(property="contents",column="contents")
@Result(property="contents",column="contents"),
@Result(property="evaluate",column="evaluate")
})
List<MaintainInfoEntity> queryMaintainInfo(@Param("curDate") String curDate,
@Param("buildingId") String buildingId,

2
user-service/src/main/java/com/mh/user/service/ControlSetService.java

@ -6,5 +6,5 @@ public interface ControlSetService {
void saveControlSet(ControlSetEntity controlSetEntity);
ControlSetEntity queryControlSet(String buildingId);
ControlSetEntity queryControlSet(String buildingId, String timeName);
}

9
user-service/src/main/java/com/mh/user/service/impl/ControlSetServiceImpl.java

@ -1,5 +1,6 @@
package com.mh.user.service.impl;
import com.mh.common.utils.StringUtils;
import com.mh.user.entity.ControlSetEntity;
import com.mh.user.mapper.ControlSetMapper;
import com.mh.user.service.ControlSetService;
@ -26,8 +27,10 @@ public class ControlSetServiceImpl implements ControlSetService {
}
@Override
public ControlSetEntity queryControlSet(String buildingId) {
return controlSetMapper.queryControlSet(buildingId);
public ControlSetEntity queryControlSet(String buildingId, String timeName) {
if (StringUtils.isBlank(timeName)) {
timeName = "%";
}
return controlSetMapper.queryControlSet(buildingId,timeName);
}
}

6
user-service/src/main/java/com/mh/user/service/impl/DeviceControlServiceImpl.java

@ -11,6 +11,7 @@ import com.mh.user.model.DeviceModel;
import com.mh.user.model.SerialPortModel;
import com.mh.user.serialport.SerialPortSingle2;
import com.mh.user.service.*;
import com.mh.user.utils.ExchangeStringUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@ -385,6 +386,11 @@ public class DeviceControlServiceImpl implements DeviceControlService {
} else {
deviceCodeParam.setRegisterAddr("00240000"); //寄存器地址
}
} else if ("startOrStop".equals(deviceCodeParam.getParam())) {
deviceCodeParam.setFunCode("06"); //功能码读数据
if ("瑞星".equals(deviceCodeParam.getBrand())) {
deviceCodeParam.setRegisterAddr("0001"+ ExchangeStringUtil.addZeroForNum(deviceCodeParam.getDataValue(), 4));
}
}
return rtData;
}

Loading…
Cancel
Save