From f0719ee78d92ba09df8643c25b6f8f63cffe29c3 Mon Sep 17 00:00:00 2001 From: mh Date: Wed, 8 May 2024 18:22:21 +0800 Subject: [PATCH] =?UTF-8?q?1=E3=80=81=E6=B7=BB=E5=8A=A0BP=E7=A5=9E?= =?UTF-8?q?=E7=BB=8F=E7=BD=91=E7=BB=9C=E9=A2=84=E6=B5=8B=E7=AE=97=E6=B3=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- 2024数据库脚本.sql | 38 ++ algorithm/pom.xml | 135 ++++ .../mh/algorithm/bpnn/ActivationFunction.java | 8 + .../java/com/mh/algorithm/bpnn/BPModel.java | 111 +++ .../bpnn/BPNeuralNetworkFactory.java | 257 +++++++ .../com/mh/algorithm/bpnn/BPParameter.java | 106 +++ .../java/com/mh/algorithm/bpnn/Sigmoid.java | 15 + .../com/mh/algorithm/constants/OrderEnum.java | 24 + .../main/java/com/mh/algorithm/knn/KNN.java | 88 +++ .../java/com/mh/algorithm/matrix/Matrix.java | 646 ++++++++++++++++++ .../java/com/mh/algorithm/utils/CsvInfo.java | 53 ++ .../java/com/mh/algorithm/utils/CsvUtil.java | 66 ++ .../com/mh/algorithm/utils/DoubleUtil.java | 20 + .../com/mh/algorithm/utils/MatrixUtil.java | 285 ++++++++ .../mh/algorithm/utils/SerializationUtil.java | 32 + .../java/com/mh/algorithm/bpnn/bpnnTest.java | 71 ++ .../java/com/mh/algorithm/knn/knnTest.java | 46 ++ pom.xml | 1 + .../user/controller/ControlSetController.java | 6 +- .../user/controller/SerialPortController.java | 4 +- .../mh/user/entity/MaintainInfoEntity.java | 1 + .../com/mh/user/mapper/ControlSetMapper.java | 19 +- .../mh/user/mapper/MaintainInfoMapper.java | 8 +- .../mh/user/service/ControlSetService.java | 2 +- .../service/impl/ControlSetServiceImpl.java | 9 +- .../impl/DeviceControlServiceImpl.java | 6 + 26 files changed, 2043 insertions(+), 14 deletions(-) create mode 100644 2024数据库脚本.sql create mode 100644 algorithm/pom.xml create mode 100644 algorithm/src/main/java/com/mh/algorithm/bpnn/ActivationFunction.java create mode 100644 algorithm/src/main/java/com/mh/algorithm/bpnn/BPModel.java create mode 100644 algorithm/src/main/java/com/mh/algorithm/bpnn/BPNeuralNetworkFactory.java create mode 100644 algorithm/src/main/java/com/mh/algorithm/bpnn/BPParameter.java create mode 100644 algorithm/src/main/java/com/mh/algorithm/bpnn/Sigmoid.java create mode 100644 algorithm/src/main/java/com/mh/algorithm/constants/OrderEnum.java create mode 100644 algorithm/src/main/java/com/mh/algorithm/knn/KNN.java create mode 100644 algorithm/src/main/java/com/mh/algorithm/matrix/Matrix.java create mode 100644 algorithm/src/main/java/com/mh/algorithm/utils/CsvInfo.java create mode 100644 algorithm/src/main/java/com/mh/algorithm/utils/CsvUtil.java create mode 100644 algorithm/src/main/java/com/mh/algorithm/utils/DoubleUtil.java create mode 100644 algorithm/src/main/java/com/mh/algorithm/utils/MatrixUtil.java create mode 100644 algorithm/src/main/java/com/mh/algorithm/utils/SerializationUtil.java create mode 100644 algorithm/src/test/java/com/mh/algorithm/bpnn/bpnnTest.java create mode 100644 algorithm/src/test/java/com/mh/algorithm/knn/knnTest.java diff --git a/2024数据库脚本.sql b/2024数据库脚本.sql new file mode 100644 index 0000000..4f0c9a3 --- /dev/null +++ b/2024数据库脚本.sql @@ -0,0 +1,38 @@ +-- 2024-05-07 维修表缺少字段 +ALTER TABLE maintain_info ADD cost numeric(2,0) NULL; +EXEC sys.sp_addextendedproperty 'MS_Description', N'材料费用', 'schema', N'dbo', 'table', N'maintain_info', 'column', N'cost'; +ALTER TABLE maintain_info ADD contents varchar(100) NULL; +EXEC sys.sp_addextendedproperty 'MS_Description', N'维保内容', 'schema', N'dbo', 'table', N'maintain_info', 'column', N'contents'; +ALTER TABLE maintain_info ADD evaluate varchar(10) NULL; +EXEC sys.sp_addextendedproperty 'MS_Description', N'评价内容', 'schema', N'dbo', 'table', N'maintain_info', 'column', N'evaluate'; + +-- 训练集合: +select + eds.cur_date, + eds.building_id, + isnull(eds.water_value, + 0) as water_value, + isnull(eds.elect_value, + 0) as elect_value, + isnull(convert(numeric(24,2),t1.water_level), + 0) as water_level +from + energy_day_sum eds + left join ( + select + convert(date, + cur_date) as cur_date, + building_id, + avg(isnull(convert(numeric(24, 2), water_level), 0)) as water_level + from + history_data + group by + convert(date, + cur_date), + building_id + ) t1 on + eds.cur_date = t1.cur_date and eds.building_id = t1.building_id +where eds.building_id != '所有' +order by + eds.building_id, + eds.cur_date diff --git a/algorithm/pom.xml b/algorithm/pom.xml new file mode 100644 index 0000000..d7aab46 --- /dev/null +++ b/algorithm/pom.xml @@ -0,0 +1,135 @@ + + + + com.mh + chws + 1.0-SNAPSHOT + + 4.0.0 + + com.mh + algorithm + 1.0.0 + jar + + + UTF-8 + UTF-8 + 1.8 + 1.8 + 1.8 + + + + + net.sourceforge.javacsv + javacsv + 2.0 + + + gov.nist.math + jama + 1.0.3 + + + junit + junit + RELEASE + test + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.1 + + 1.8 + 1.8 + + + + + + + + default + + true + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.8.0 + + 1.8 + 1.8 + UTF-8 + + + + + org.apache.maven.plugins + maven-javadoc-plugin + 2.9.1 + + + attach-javadocs + + jar + + + + -Xdoclint:none + + + + + + + org.apache.maven.plugins + maven-source-plugin + 2.3 + + + attach-sources + + jar + + + + + + + org.apache.maven.plugins + maven-gpg-plugin + 1.4 + + + sign-artifacts + verify + + sign + + + + + + maven-jar-plugin + 2.3.1 + + target/classes + + + + + + + diff --git a/algorithm/src/main/java/com/mh/algorithm/bpnn/ActivationFunction.java b/algorithm/src/main/java/com/mh/algorithm/bpnn/ActivationFunction.java new file mode 100644 index 0000000..fff15fe --- /dev/null +++ b/algorithm/src/main/java/com/mh/algorithm/bpnn/ActivationFunction.java @@ -0,0 +1,8 @@ +package com.mh.algorithm.bpnn; + +public interface ActivationFunction { + //计算值 + double computeValue(double val); + //计算导数 + double computeDerivative(double val); +} diff --git a/algorithm/src/main/java/com/mh/algorithm/bpnn/BPModel.java b/algorithm/src/main/java/com/mh/algorithm/bpnn/BPModel.java new file mode 100644 index 0000000..719f8fb --- /dev/null +++ b/algorithm/src/main/java/com/mh/algorithm/bpnn/BPModel.java @@ -0,0 +1,111 @@ +package com.mh.algorithm.bpnn; + +import com.mh.algorithm.matrix.Matrix; + +import java.io.Serializable; + +public class BPModel implements Serializable { + //BP神经网络权值与阈值 + private Matrix weightIJ; + private Matrix b1; + private Matrix weightJP; + private Matrix b2; + /*用于反归一化*/ + private Matrix inputMax; + private Matrix inputMin; + private Matrix outputMax; + private Matrix outputMin; + /*BP神经网络训练参数*/ + private BPParameter bpParameter; + /*BP神经网络训练情况*/ + private double error; + private int times; + + public Matrix getWeightIJ() { + return weightIJ; + } + + public void setWeightIJ(Matrix weightIJ) { + this.weightIJ = weightIJ; + } + + public Matrix getB1() { + return b1; + } + + public void setB1(Matrix b1) { + this.b1 = b1; + } + + public Matrix getWeightJP() { + return weightJP; + } + + public void setWeightJP(Matrix weightJP) { + this.weightJP = weightJP; + } + + public Matrix getB2() { + return b2; + } + + public void setB2(Matrix b2) { + this.b2 = b2; + } + + public Matrix getInputMax() { + return inputMax; + } + + public void setInputMax(Matrix inputMax) { + this.inputMax = inputMax; + } + + public Matrix getInputMin() { + return inputMin; + } + + public void setInputMin(Matrix inputMin) { + this.inputMin = inputMin; + } + + public Matrix getOutputMax() { + return outputMax; + } + + public void setOutputMax(Matrix outputMax) { + this.outputMax = outputMax; + } + + public Matrix getOutputMin() { + return outputMin; + } + + public void setOutputMin(Matrix outputMin) { + this.outputMin = outputMin; + } + + public BPParameter getBpParameter() { + return bpParameter; + } + + public void setBpParameter(BPParameter bpParameter) { + this.bpParameter = bpParameter; + } + + public double getError() { + return error; + } + + public void setError(double error) { + this.error = error; + } + + public int getTimes() { + return times; + } + + public void setTimes(int times) { + this.times = times; + } +} diff --git a/algorithm/src/main/java/com/mh/algorithm/bpnn/BPNeuralNetworkFactory.java b/algorithm/src/main/java/com/mh/algorithm/bpnn/BPNeuralNetworkFactory.java new file mode 100644 index 0000000..3c6c8c8 --- /dev/null +++ b/algorithm/src/main/java/com/mh/algorithm/bpnn/BPNeuralNetworkFactory.java @@ -0,0 +1,257 @@ +package com.mh.algorithm.bpnn; + +import com.mh.algorithm.matrix.Matrix; +import com.mh.algorithm.utils.MatrixUtil; + +import java.util.*; + +public class BPNeuralNetworkFactory { + /** + * 训练BP神经网络模型 + * @param bpParameter + * @param inputAndOutput + * @return + */ + public BPModel trainBP(BPParameter bpParameter, Matrix inputAndOutput) throws Exception { + + ActivationFunction activationFunction = bpParameter.getActivationFunction(); + int inputCount = bpParameter.getInputLayerNeuronCount(); + int hiddenCount = bpParameter.getHiddenLayerNeuronCount(); + int outputCount = bpParameter.getOutputLayerNeuronCount(); + double normalizationMin = bpParameter.getNormalizationMin(); + double normalizationMax = bpParameter.getNormalizationMax(); + double step = bpParameter.getStep(); + double momentumFactor = bpParameter.getMomentumFactor(); + double precision = bpParameter.getPrecision(); + int maxTimes = bpParameter.getMaxTimes(); + + if(inputAndOutput.getMatrixColCount() != inputCount + outputCount){ + throw new Exception("神经元个数不符,请修改"); + } + // 初始化权值 + Matrix weightIJ = initWeight(inputCount, hiddenCount); + Matrix weightJP = initWeight(hiddenCount, outputCount); + + // 初始化阈值 + Matrix b1 = initThreshold(hiddenCount); + Matrix b2 = initThreshold(outputCount); + + // 动量项 + Matrix deltaWeightIJ0 = new Matrix(inputCount, hiddenCount); + Matrix deltaWeightJP0 = new Matrix(hiddenCount, outputCount); + Matrix deltaB10 = new Matrix(1, hiddenCount); + Matrix deltaB20 = new Matrix(1, outputCount); + + // 截取输入矩阵和输出矩阵 + Matrix input = inputAndOutput.subMatrix(0,inputAndOutput.getMatrixRowCount(),0,inputCount); + Matrix output = inputAndOutput.subMatrix(0,inputAndOutput.getMatrixRowCount(),inputCount,outputCount); + + // 归一化 + Map inputAfterNormalize = MatrixUtil.normalize(input, normalizationMin, normalizationMax); + input = (Matrix) inputAfterNormalize.get("res"); + + Map outputAfterNormalize = MatrixUtil.normalize(output, normalizationMin, normalizationMax); + output = (Matrix) outputAfterNormalize.get("res"); + + int times = 1; + double E = 0;//误差 + while (times < maxTimes) { + /*-----------------正向传播---------------------*/ + // 隐含层输入 + Matrix jIn = input.multiple(weightIJ); + // 扩充阈值 + Matrix b1Copy = b1.extend(2,jIn.getMatrixRowCount()); + // 加上阈值 + jIn = jIn.plus(b1Copy); + // 隐含层输出 + Matrix jOut = computeValue(jIn,activationFunction); + // 输出层输入 + Matrix pIn = jOut.multiple(weightJP); + // 扩充阈值 + Matrix b2Copy = b2.extend(2, pIn.getMatrixRowCount()); + // 加上阈值 + pIn = pIn.plus(b2Copy); + // 输出层输出 + Matrix pOut = computeValue(pIn,activationFunction); + // 计算误差 + Matrix e = output.subtract(pOut); + E = computeE(e);//误差 + // 判断是否符合精度 + if (Math.abs(E) <= precision) { + System.out.println("满足精度"); + break; + } + + /*-----------------反向传播---------------------*/ + // J与P之间权值修正量 + Matrix deltaWeightJP = e.multiple(step); + deltaWeightJP = deltaWeightJP.pointMultiple(computeDerivative(pIn,activationFunction)); + deltaWeightJP = deltaWeightJP.transpose().multiple(jOut); + deltaWeightJP = deltaWeightJP.transpose(); + // P层神经元阈值修正量 + Matrix deltaThresholdP = e.multiple(step); + deltaThresholdP = deltaThresholdP.transpose().multiple(computeDerivative(pIn, activationFunction)); + + // I与J之间的权值修正量 + Matrix deltaO = e.pointMultiple(computeDerivative(pIn,activationFunction)); + Matrix tmp = weightJP.multiple(deltaO.transpose()).transpose(); + Matrix deltaWeightIJ = tmp.pointMultiple(computeDerivative(jIn, activationFunction)); + deltaWeightIJ = input.transpose().multiple(deltaWeightIJ); + deltaWeightIJ = deltaWeightIJ.multiple(step); + + // J层神经元阈值修正量 + Matrix deltaThresholdJ = tmp.transpose().multiple(computeDerivative(jIn, activationFunction)); + deltaThresholdJ = deltaThresholdJ.multiple(-step); + + if (times == 1) { + // 更新权值与阈值 + weightIJ = weightIJ.plus(deltaWeightIJ); + weightJP = weightJP.plus(deltaWeightJP); + b1 = b1.plus(deltaThresholdJ); + b2 = b2.plus(deltaThresholdP); + }else{ + // 加动量项 + weightIJ = weightIJ.plus(deltaWeightIJ).plus(deltaWeightIJ0.multiple(momentumFactor)); + weightJP = weightJP.plus(deltaWeightJP).plus(deltaWeightJP0.multiple(momentumFactor)); + b1 = b1.plus(deltaThresholdJ).plus(deltaB10.multiple(momentumFactor)); + b2 = b2.plus(deltaThresholdP).plus(deltaB20.multiple(momentumFactor)); + } + + deltaWeightIJ0 = deltaWeightIJ; + deltaWeightJP0 = deltaWeightJP; + deltaB10 = deltaThresholdJ; + deltaB20 = deltaThresholdP; + + times++; + } + + // BP神经网络的输出 + BPModel result = new BPModel(); + result.setInputMax((Matrix) inputAfterNormalize.get("max")); + result.setInputMin((Matrix) inputAfterNormalize.get("min")); + result.setOutputMax((Matrix) outputAfterNormalize.get("max")); + result.setOutputMin((Matrix) outputAfterNormalize.get("min")); + result.setWeightIJ(weightIJ); + result.setWeightJP(weightJP); + result.setB1(b1); + result.setB2(b2); + result.setError(E); + result.setTimes(times); + result.setBpParameter(bpParameter); + System.out.println("循环次数:" + times + ",误差:" + E); + + return result; + } + + /** + * 计算BP神经网络的值 + * @param bpModel + * @param input + * @return + */ + public Matrix computeBP(BPModel bpModel,Matrix input) throws Exception { + if (input.getMatrixColCount() != bpModel.getBpParameter().getInputLayerNeuronCount()) { + throw new Exception("输入矩阵纬度有误"); + } + ActivationFunction activationFunction = bpModel.getBpParameter().getActivationFunction(); + Matrix weightIJ = bpModel.getWeightIJ(); + Matrix weightJP = bpModel.getWeightJP(); + Matrix b1 = bpModel.getB1(); + Matrix b2 = bpModel.getB2(); + double[][] normalizedInput = new double[input.getMatrixRowCount()][input.getMatrixColCount()]; + for (int i = 0; i < input.getMatrixRowCount(); i++) { + for (int j = 0; j < input.getMatrixColCount(); j++) { + normalizedInput[i][j] = bpModel.getBpParameter().getNormalizationMin() + + (input.getValOfIdx(i,j) - bpModel.getInputMin().getValOfIdx(0,j)) + / (bpModel.getInputMax().getValOfIdx(0,j) - bpModel.getInputMin().getValOfIdx(0,j)) + * (bpModel.getBpParameter().getNormalizationMax() - bpModel.getBpParameter().getNormalizationMin()); + } + } + Matrix normalizedInputMatrix = new Matrix(normalizedInput); + Matrix jIn = normalizedInputMatrix.multiple(weightIJ); + // 扩充阈值 + Matrix b1Copy = b1.extend(2,jIn.getMatrixRowCount()); + // 加上阈值 + jIn = jIn.plus(b1Copy); + // 隐含层输出 + Matrix jOut = computeValue(jIn,activationFunction); + // 输出层输入 + Matrix pIn = jOut.multiple(weightJP); + // 扩充阈值 + Matrix b2Copy = b2.extend(2,pIn.getMatrixRowCount()); + // 加上阈值 + pIn = pIn.plus(b2Copy); + // 输出层输出 + Matrix pOut = computeValue(pIn,activationFunction); + // 反归一化 + return MatrixUtil.inverseNormalize(pOut, bpModel.getBpParameter().getNormalizationMax(), bpModel.getBpParameter().getNormalizationMin(), bpModel.getOutputMax(), bpModel.getOutputMin()); + } + + // 初始化权值 + private Matrix initWeight(int x,int y){ + Random random=new Random(); + double[][] weight = new double[x][y]; + for (int i = 0; i < x; i++) { + for (int j = 0; j < y; j++) { + weight[i][j] = 2*random.nextDouble()-1; + } + } + return new Matrix(weight); + } + // 初始化阈值 + private Matrix initThreshold(int x){ + Random random = new Random(); + double[][] result = new double[1][x]; + for (int i = 0; i < x; i++) { + result[0][i] = 2*random.nextDouble()-1; + } + return new Matrix(result); + } + + /** + * 计算激活函数的值 + * @param a + * @return + */ + private Matrix computeValue(Matrix a, ActivationFunction activationFunction) throws Exception { + if (a.getMatrix() == null) { + throw new Exception("参数值为空"); + } + double[][] result = new double[a.getMatrixRowCount()][a.getMatrixColCount()]; + for (int i = 0; i < a.getMatrixRowCount(); i++) { + for (int j = 0; j < a.getMatrixColCount(); j++) { + result[i][j] = activationFunction.computeValue(a.getValOfIdx(i,j)); + } + } + return new Matrix(result); + } + + /** + * 激活函数导数的值 + * @param a + * @return + */ + private Matrix computeDerivative(Matrix a , ActivationFunction activationFunction) throws Exception { + if (a.getMatrix() == null) { + throw new Exception("参数值为空"); + } + double[][] result = new double[a.getMatrixRowCount()][a.getMatrixColCount()]; + for (int i = 0; i < a.getMatrixRowCount(); i++) { + for (int j = 0; j < a.getMatrixColCount(); j++) { + result[i][j] = activationFunction.computeDerivative(a.getValOfIdx(i,j)); + } + } + return new Matrix(result); + } + + + /** + * 计算误差 + * @param e + * @return + */ + private double computeE(Matrix e){ + e = e.square(); + return 0.5*e.sumAll(); + } +} diff --git a/algorithm/src/main/java/com/mh/algorithm/bpnn/BPParameter.java b/algorithm/src/main/java/com/mh/algorithm/bpnn/BPParameter.java new file mode 100644 index 0000000..7196136 --- /dev/null +++ b/algorithm/src/main/java/com/mh/algorithm/bpnn/BPParameter.java @@ -0,0 +1,106 @@ +package com.mh.algorithm.bpnn; + +import java.io.Serializable; + +public class BPParameter implements Serializable { + + //输入层神经元个数 + private int inputLayerNeuronCount = 3; + //隐含层神经元个数 + private int hiddenLayerNeuronCount = 3; + //输出层神经元个数 + private int outputLayerNeuronCount = 1; + //归一化区间 + private double normalizationMin = 0.2; + private double normalizationMax = 0.8; + //学习步长 + private double step = 0.05; + //动量因子 + private double momentumFactor = 0.2; + //激活函数 + private ActivationFunction activationFunction = new Sigmoid(); + //精度 + private double precision = 0.000001; + //最大循环次数 + private int maxTimes = 1000000; + + public double getMomentumFactor() { + return momentumFactor; + } + + public void setMomentumFactor(double momentumFactor) { + this.momentumFactor = momentumFactor; + } + + public double getStep() { + return step; + } + + public void setStep(double step) { + this.step = step; + } + + public double getNormalizationMin() { + return normalizationMin; + } + + public void setNormalizationMin(double normalizationMin) { + this.normalizationMin = normalizationMin; + } + + public double getNormalizationMax() { + return normalizationMax; + } + + public void setNormalizationMax(double normalizationMax) { + this.normalizationMax = normalizationMax; + } + + public int getInputLayerNeuronCount() { + return inputLayerNeuronCount; + } + + public void setInputLayerNeuronCount(int inputLayerNeuronCount) { + this.inputLayerNeuronCount = inputLayerNeuronCount; + } + + public int getHiddenLayerNeuronCount() { + return hiddenLayerNeuronCount; + } + + public void setHiddenLayerNeuronCount(int hiddenLayerNeuronCount) { + this.hiddenLayerNeuronCount = hiddenLayerNeuronCount; + } + + public int getOutputLayerNeuronCount() { + return outputLayerNeuronCount; + } + + public void setOutputLayerNeuronCount(int outputLayerNeuronCount) { + this.outputLayerNeuronCount = outputLayerNeuronCount; + } + + public ActivationFunction getActivationFunction() { + return activationFunction; + } + + public void setActivationFunction(ActivationFunction activationFunction) { + this.activationFunction = activationFunction; + } + + public double getPrecision() { + return precision; + } + + public void setPrecision(double precision) { + this.precision = precision; + } + + public int getMaxTimes() { + return maxTimes; + } + + public void setMaxTimes(int maxTimes) { + this.maxTimes = maxTimes; + } +} diff --git a/algorithm/src/main/java/com/mh/algorithm/bpnn/Sigmoid.java b/algorithm/src/main/java/com/mh/algorithm/bpnn/Sigmoid.java new file mode 100644 index 0000000..378060c --- /dev/null +++ b/algorithm/src/main/java/com/mh/algorithm/bpnn/Sigmoid.java @@ -0,0 +1,15 @@ +package com.mh.algorithm.bpnn; + +import java.io.Serializable; + +public class Sigmoid implements ActivationFunction, Serializable { + @Override + public double computeValue(double val) { + return 1 / (1 + Math.exp(-val)); + } + + @Override + public double computeDerivative(double val) { + return computeValue(val) * (1 - computeValue(val)); + } +} diff --git a/algorithm/src/main/java/com/mh/algorithm/constants/OrderEnum.java b/algorithm/src/main/java/com/mh/algorithm/constants/OrderEnum.java new file mode 100644 index 0000000..1e47167 --- /dev/null +++ b/algorithm/src/main/java/com/mh/algorithm/constants/OrderEnum.java @@ -0,0 +1,24 @@ +package com.mh.algorithm.constants; + +/** + * 排序枚举类 + */ +public enum OrderEnum { + + ASC(1,"升序"), + + DESC(2,"降序"); + + OrderEnum(int flag, String name) { + + this.flag = flag; + + this.name = name; + + } + + private int flag; + + private String name; + +} diff --git a/algorithm/src/main/java/com/mh/algorithm/knn/KNN.java b/algorithm/src/main/java/com/mh/algorithm/knn/KNN.java new file mode 100644 index 0000000..fd6fbdb --- /dev/null +++ b/algorithm/src/main/java/com/mh/algorithm/knn/KNN.java @@ -0,0 +1,88 @@ +package com.mh.algorithm.knn; + +import com.mh.algorithm.constants.OrderEnum; +import com.mh.algorithm.matrix.Matrix; +import com.mh.algorithm.utils.MatrixUtil; + +import java.util.*; + + +/** + * @program: top-algorithm-set + * @description: KNN k-临近算法进行分类 + * @author: Mr.Zhao + * @create: 2020-10-13 22:03 + **/ +public class KNN { + public static Matrix classify(Matrix input, Matrix dataSet, Matrix labels, int k) throws Exception { + if (dataSet.getMatrixRowCount() != labels.getMatrixRowCount()) { + throw new IllegalArgumentException("矩阵训练集与标签维度不一致"); + } + if (input.getMatrixColCount() != dataSet.getMatrixColCount()) { + throw new IllegalArgumentException("待分类矩阵列数与训练集列数不一致"); + } + if (dataSet.getMatrixRowCount() < k) { + throw new IllegalArgumentException("训练集样本数小于k"); + } + // 归一化 + int trainCount = dataSet.getMatrixRowCount(); + int testCount = input.getMatrixRowCount(); + Matrix trainAndTest = dataSet.splice(2, input); + Map normalize = MatrixUtil.normalize(trainAndTest, 0, 1); + trainAndTest = (Matrix) normalize.get("res"); + dataSet = trainAndTest.subMatrix(0, trainCount, 0, trainAndTest.getMatrixColCount()); + input = trainAndTest.subMatrix(0, testCount, 0, trainAndTest.getMatrixColCount()); + + // 获取标签信息 + List labelList = new ArrayList<>(); + for (int i = 0; i < labels.getMatrixRowCount(); i++) { + if (!labelList.contains(labels.getValOfIdx(i, 0))) { + labelList.add(labels.getValOfIdx(i, 0)); + } + } + + Matrix result = new Matrix(new double[input.getMatrixRowCount()][1]); + for (int i = 0; i < input.getMatrixRowCount(); i++) { + // 计算向量间的欧式距离 + // 将labels矩阵扩展 + Matrix labelMatrixCopied = input.getRowOfIdx(i).extend(2, dataSet.getMatrixRowCount()); + // 前面是计算欧氏距离,splice(1,labels)是将距离矩阵与labels矩阵合并 + Matrix distanceMatrix = dataSet.subtract(labelMatrixCopied).square().sumRow().pow(0.5).splice(1, labels); + // 将计算出的距离矩阵按照距离升序排序 + distanceMatrix.sort(0, OrderEnum.ASC); + // 遍历最近的k个变量 + Map map = new HashMap<>(); + for (int j = 0; j < k; j++) { + // 遍历标签种类数 + for (Double label : labelList) { + if (distanceMatrix.getValOfIdx(j, 1) == label) { + map.put(label, map.getOrDefault(label, 0) + 1); + } + } + } + result.setValue(i, 0, getKeyOfMaxValue(map)); + } + return result; + } + + /** + * 取map中值最大的key + * + * @param map + * @return + */ + private static Double getKeyOfMaxValue(Map map) { + if (map == null) + return null; + Double keyOfMaxValue = 0.0; + Integer maxValue = 0; + for (Double key : map.keySet()) { + if (map.get(key) > maxValue) { + keyOfMaxValue = key; + maxValue = map.get(key); + } + } + return keyOfMaxValue; + } + +} diff --git a/algorithm/src/main/java/com/mh/algorithm/matrix/Matrix.java b/algorithm/src/main/java/com/mh/algorithm/matrix/Matrix.java new file mode 100644 index 0000000..4820012 --- /dev/null +++ b/algorithm/src/main/java/com/mh/algorithm/matrix/Matrix.java @@ -0,0 +1,646 @@ +package com.mh.algorithm.matrix; + +import com.mh.algorithm.constants.OrderEnum; + +import java.io.Serializable; + +public class Matrix implements Serializable { + private double[][] matrix; + //矩阵列数 + private int matrixColCount; + //矩阵行数 + private int matrixRowCount; + + /** + * 构造一个空矩阵 + */ + public Matrix() { + this.matrix = null; + this.matrixColCount = 0; + this.matrixRowCount = 0; + } + + /** + * 构造一个matrix矩阵 + * @param matrix + */ + public Matrix(double[][] matrix) { + this.matrix = matrix; + this.matrixRowCount = matrix.length; + this.matrixColCount = matrix[0].length; + } + + /** + * 构造一个rowCount行colCount列值为0的矩阵 + * @param rowCount + * @param colCount + */ + public Matrix(int rowCount,int colCount) { + double[][] matrix = new double[rowCount][colCount]; + for (int i = 0; i < rowCount; i++) { + for (int j = 0; j < colCount; j++) { + matrix[i][j] = 0; + } + } + this.matrix = matrix; + this.matrixRowCount = rowCount; + this.matrixColCount = colCount; + } + + /** + * 构造一个rowCount行colCount列值为val的矩阵 + * @param val + * @param rowCount + * @param colCount + */ + public Matrix(double val,int rowCount,int colCount) { + double[][] matrix = new double[rowCount][colCount]; + for (int i = 0; i < rowCount; i++) { + for (int j = 0; j < colCount; j++) { + matrix[i][j] = val; + } + } + this.matrix = matrix; + this.matrixRowCount = rowCount; + this.matrixColCount = colCount; + } + + public double[][] getMatrix() { + return matrix; + } + + public void setMatrix(double[][] matrix) { + this.matrix = matrix; + this.matrixRowCount = matrix.length; + this.matrixColCount = matrix[0].length; + } + + public int getMatrixColCount() { + return matrixColCount; + } + + public int getMatrixRowCount() { + return matrixRowCount; + } + + /** + * 获取矩阵指定位置的值 + * + * @param x + * @param y + * @return + */ + public double getValOfIdx(int x, int y) throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + if (x > matrixRowCount - 1) { + throw new IllegalArgumentException("索引x越界"); + } + if (y > matrixColCount - 1) { + throw new IllegalArgumentException("索引y越界"); + } + return matrix[x][y]; + } + + /** + * 获取矩阵指定行 + * + * @param x + * @return + */ + public Matrix getRowOfIdx(int x) throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + if (x > matrixRowCount - 1) { + throw new IllegalArgumentException("索引x越界"); + } + double[][] result = new double[1][matrixColCount]; + result[0] = matrix[x]; + return new Matrix(result); + } + + /** + * 获取矩阵指定列 + * + * @param y + * @return + */ + public Matrix getColOfIdx(int y) throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + if (y > matrixColCount - 1) { + throw new IllegalArgumentException("索引y越界"); + } + double[][] result = new double[matrixRowCount][1]; + for (int i = 0; i < matrixRowCount; i++) { + result[i][0] = matrix[i][y]; + } + return new Matrix(result); + } + + /** + * 设置矩阵中x,y位置元素的值 + * @param x + * @param y + * @param val + */ + public void setValue(int x, int y, double val) { + if (x > this.matrixRowCount - 1) { + throw new IllegalArgumentException("行索引越界"); + } + if (y > this.matrixColCount - 1) { + throw new IllegalArgumentException("列索引越界"); + } + this.matrix[x][y] = val; + } + + /** + * 矩阵乘矩阵 + * + * @param a + * @return + * @throws IllegalArgumentException + */ + public Matrix multiple(Matrix a) throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + if (a.getMatrix() == null || a.getMatrixRowCount() == 0 || a.getMatrixColCount() == 0) { + throw new IllegalArgumentException("参数矩阵为空"); + } + if (matrixColCount != a.getMatrixRowCount()) { + throw new IllegalArgumentException("矩阵纬度不同,不可计算"); + } + double[][] result = new double[matrixRowCount][a.getMatrixColCount()]; + for (int i = 0; i < matrixRowCount; i++) { + for (int j = 0; j < a.getMatrixColCount(); j++) { + for (int k = 0; k < matrixColCount; k++) { + result[i][j] = result[i][j] + matrix[i][k] * a.getMatrix()[k][j]; + } + } + } + return new Matrix(result); + } + + /** + * 矩阵乘一个数字 + * + * @param a + * @return + */ + public Matrix multiple(double a) throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + double[][] result = new double[matrixRowCount][matrixColCount]; + for (int i = 0; i < matrixRowCount; i++) { + for (int j = 0; j < matrixColCount; j++) { + result[i][j] = matrix[i][j] * a; + } + } + return new Matrix(result); + } + + /** + * 矩阵点乘 + * + * @param a + * @return + */ + public Matrix pointMultiple(Matrix a) throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + if (a.getMatrix() == null || a.getMatrixRowCount() == 0 || a.getMatrixColCount() == 0) { + throw new IllegalArgumentException("参数矩阵为空"); + } + if (matrixRowCount != a.getMatrixRowCount() && matrixColCount != a.getMatrixColCount()) { + throw new IllegalArgumentException("矩阵纬度不同,不可计算"); + } + double[][] result = new double[matrixRowCount][matrixColCount]; + for (int i = 0; i < matrixRowCount; i++) { + for (int j = 0; j < matrixColCount; j++) { + result[i][j] = matrix[i][j] * a.getMatrix()[i][j]; + } + } + return new Matrix(result); + } + + /** + * 矩阵除一个数字 + * @param a + * @return + * @throws IllegalArgumentException + */ + public Matrix divide(double a) throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + double[][] result = new double[matrixRowCount][matrixColCount]; + for (int i = 0; i < matrixRowCount; i++) { + for (int j = 0; j < matrixColCount; j++) { + result[i][j] = matrix[i][j] / a; + } + } + return new Matrix(result); + } + + /** + * 矩阵加法 + * + * @param a + * @return + */ + public Matrix plus(Matrix a) throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + if (a.getMatrix() == null || a.getMatrixRowCount() == 0 || a.getMatrixColCount() == 0) { + throw new IllegalArgumentException("参数矩阵为空"); + } + if (matrixRowCount != a.getMatrixRowCount() && matrixColCount != a.getMatrixColCount()) { + throw new IllegalArgumentException("矩阵纬度不同,不可计算"); + } + double[][] result = new double[matrixRowCount][matrixColCount]; + for (int i = 0; i < matrixRowCount; i++) { + for (int j = 0; j < matrixColCount; j++) { + result[i][j] = matrix[i][j] + a.getMatrix()[i][j]; + } + } + return new Matrix(result); + } + + /** + * 矩阵加一个数字 + * @param a + * @return + * @throws IllegalArgumentException + */ + public Matrix plus(double a) throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + double[][] result = new double[matrixRowCount][matrixColCount]; + for (int i = 0; i < matrixRowCount; i++) { + for (int j = 0; j < matrixColCount; j++) { + result[i][j] = matrix[i][j] + a; + } + } + return new Matrix(result); + } + + /** + * 矩阵减法 + * + * @param a + * @return + */ + public Matrix subtract(Matrix a) throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + if (a.getMatrix() == null || a.getMatrixRowCount() == 0 || a.getMatrixColCount() == 0) { + throw new IllegalArgumentException("参数矩阵为空"); + } + if (matrixRowCount != a.getMatrixRowCount() && matrixColCount != a.getMatrixColCount()) { + throw new IllegalArgumentException("矩阵纬度不同,不可计算"); + } + double[][] result = new double[matrixRowCount][matrixColCount]; + for (int i = 0; i < matrixRowCount; i++) { + for (int j = 0; j < matrixColCount; j++) { + result[i][j] = matrix[i][j] - a.getMatrix()[i][j]; + } + } + return new Matrix(result); + } + + /** + * 矩阵减一个数字 + * @param a + * @return + * @throws IllegalArgumentException + */ + public Matrix subtract(double a) throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + double[][] result = new double[matrixRowCount][matrixColCount]; + for (int i = 0; i < matrixRowCount; i++) { + for (int j = 0; j < matrixColCount; j++) { + result[i][j] = matrix[i][j] - a; + } + } + return new Matrix(result); + } + + /** + * 矩阵行求和 + * + * @return + */ + public Matrix sumRow() throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + double[][] result = new double[matrixRowCount][1]; + for (int i = 0; i < matrixRowCount; i++) { + for (int j = 0; j < matrixColCount; j++) { + result[i][0] += matrix[i][j]; + } + } + return new Matrix(result); + } + + /** + * 矩阵列求和 + * + * @return + */ + public Matrix sumCol() throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + double[][] result = new double[1][matrixColCount]; + for (int i = 0; i < matrixRowCount; i++) { + for (int j = 0; j < matrixColCount; j++) { + result[0][j] += matrix[i][j]; + } + } + return new Matrix(result); + } + + /** + * 矩阵所有元素求和 + * + * @return + */ + public double sumAll() throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + double result = 0; + for (double[] doubles : matrix) { + for (int j = 0; j < matrixColCount; j++) { + result += doubles[j]; + } + } + return result; + } + + /** + * 矩阵所有元素求平方 + * + * @return + */ + public Matrix square() throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + double[][] result = new double[matrixRowCount][matrixColCount]; + for (int i = 0; i < matrixRowCount; i++) { + for (int j = 0; j < matrixColCount; j++) { + result[i][j] = matrix[i][j] * matrix[i][j]; + } + } + return new Matrix(result); + } + + /** + * 矩阵所有元素求N次方 + * + * @return + */ + public Matrix pow(double n) throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + double[][] result = new double[matrixRowCount][matrixColCount]; + for (int i = 0; i < matrixRowCount; i++) { + for (int j = 0; j < matrixColCount; j++) { + result[i][j] = Math.pow(matrix[i][j],n); + } + } + return new Matrix(result); + } + + /** + * 矩阵转置 + * + * @return + */ + public Matrix transpose() throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + double[][] result = new double[matrixColCount][matrixRowCount]; + for (int i = 0; i < matrixRowCount; i++) { + for (int j = 0; j < matrixColCount; j++) { + result[j][i] = matrix[i][j]; + } + } + return new Matrix(result); + } + + /** + * 截取矩阵 + * @param startRowIndex 开始行索引 + * @param rowCount 截取行数 + * @param startColIndex 开始列索引 + * @param colCount 截取列数 + * @return + * @throws IllegalArgumentException + */ + public Matrix subMatrix(int startRowIndex,int rowCount,int startColIndex,int colCount) throws IllegalArgumentException { + if (startRowIndex + rowCount > matrixRowCount) { + throw new IllegalArgumentException("行索引越界"); + } + if (startColIndex + colCount> matrixColCount) { + throw new IllegalArgumentException("列索引越界"); + } + double[][] result = new double[rowCount][colCount]; + for (int i = startRowIndex; i < startRowIndex + rowCount; i++) { + if (startColIndex + colCount - startColIndex >= 0) + System.arraycopy(matrix[i], startColIndex, result[i - startRowIndex], 0, colCount); + } + return new Matrix(result); + } + + /** + * 矩阵合并 + * @param direction 合并方向,1为横向,2为竖向 + * @param a + * @return + * @throws IllegalArgumentException + */ + public Matrix splice(int direction, Matrix a) throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + if (a.getMatrix() == null || a.getMatrixRowCount() == 0 || a.getMatrixColCount() == 0) { + throw new IllegalArgumentException("参数矩阵为空"); + } + if(direction == 1){ + //横向拼接 + if (matrixRowCount != a.getMatrixRowCount()) { + throw new IllegalArgumentException("矩阵行数不一致,无法拼接"); + } + double[][] result = new double[matrixRowCount][matrixColCount + a.getMatrixColCount()]; + for (int i = 0; i < matrixRowCount; i++) { + System.arraycopy(matrix[i],0,result[i],0,matrixColCount); + System.arraycopy(a.getMatrix()[i],0,result[i],matrixColCount,a.getMatrixColCount()); + } + return new Matrix(result); + }else if(direction == 2){ + //纵向拼接 + if (matrixColCount != a.getMatrixColCount()) { + throw new IllegalArgumentException("矩阵列数不一致,无法拼接"); + } + double[][] result = new double[matrixRowCount + a.getMatrixRowCount()][matrixColCount]; + for (int i = 0; i < matrixRowCount; i++) { + result[i] = matrix[i]; + } + for (int i = 0; i < a.getMatrixRowCount(); i++) { + result[matrixRowCount + i] = a.getMatrix()[i]; + } + return new Matrix(result); + }else{ + throw new IllegalArgumentException("方向参数有误"); + } + } + /** + * 扩展矩阵 + * @param direction 扩展方向,1为横向,2为竖向 + * @param a + * @return + * @throws IllegalArgumentException + */ + public Matrix extend(int direction , int a) throws IllegalArgumentException { + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + if(direction == 1){ + //横向复制 + double[][] result = new double[matrixRowCount][matrixColCount*a]; + for (int i = 0; i < matrixRowCount; i++) { + for (int j = 0; j < a; j++) { + System.arraycopy(matrix[i],0,result[i],j*matrixColCount,matrixColCount); + } + } + return new Matrix(result); + }else if(direction == 2){ + //纵向复制 + double[][] result = new double[matrixRowCount*a][matrixColCount]; + for (int i = 0; i < matrixRowCount*a; i++) { + result[i] = matrix[i%matrixRowCount]; + } + return new Matrix(result); + }else{ + throw new IllegalArgumentException("方向参数有误"); + } + } + /** + * 获取每列的平均值 + * @return + * @throws IllegalArgumentException + */ + public Matrix getColAvg() throws IllegalArgumentException { + Matrix tmp = this.sumCol(); + return tmp.divide(matrixRowCount); + } + + /** + * 矩阵行排序 + * @param index 根据第几列的数进行行排序 + * @param order 排序顺序,升序或降序 + * @return + * @throws IllegalArgumentException + */ + public void sort(int index, OrderEnum order) throws IllegalArgumentException{ + if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) { + throw new IllegalArgumentException("矩阵为空"); + } + if(index >= matrixColCount){ + throw new IllegalArgumentException("排序索引index越界"); + } + sort(index,order,0,this.matrixRowCount - 1); + } + + /** + * 判断是否是方阵 + * 行列数相等,并且不等于0 + * @return + */ + public boolean isSquareMatrix(){ + return matrixColCount == matrixRowCount && matrixColCount != 0; + } + + @Override + public String toString() { + StringBuilder stringBuilder = new StringBuilder(); + stringBuilder.append("\r\n"); + for (int i = 0; i < matrixRowCount; i++) { + stringBuilder.append("# "); + for (int j = 0; j < matrixColCount; j++) { + stringBuilder.append(matrix[i][j]).append("\t "); + } + stringBuilder.append("#\r\n"); + } + stringBuilder.append("\r\n"); + return stringBuilder.toString(); + } + + private void sort(int index,OrderEnum order,int start,int end){ + if(start >= end){ + return; + } + int tmp = partition(index,order,start,end); + sort(index,order, start, tmp - 1); + sort(index,order, tmp + 1, end); + } + + private int partition(int index,OrderEnum order,int start,int end){ + int l = start + 1,r = end; + double v = matrix[start][index]; + switch (order){ + case ASC: + while(true){ + while(matrix[r][index] >= v && r > start){ + r--; + } + while(matrix[l][index] <= v && l < end){ + l++; + } + if(l >= r){ + break; + } + double[] tmp = matrix[r]; + matrix[r] = matrix[l]; + matrix[l] = tmp; + } + break; + case DESC: + while(true){ + while(matrix[r][index] <= v && r > start){ + r--; + } + while(matrix[l][index] >= v && l < end){ + l++; + } + if(l >= r){ + break; + } + double[] tmp = matrix[r]; + matrix[r] = matrix[l]; + matrix[l] = tmp; + } + break; + } + double[] tmp = matrix[r]; + matrix[r] = matrix[start]; + matrix[start] = tmp; + return r; + } +} diff --git a/algorithm/src/main/java/com/mh/algorithm/utils/CsvInfo.java b/algorithm/src/main/java/com/mh/algorithm/utils/CsvInfo.java new file mode 100644 index 0000000..7c12d44 --- /dev/null +++ b/algorithm/src/main/java/com/mh/algorithm/utils/CsvInfo.java @@ -0,0 +1,53 @@ +package com.mh.algorithm.utils; + +import com.mh.algorithm.matrix.Matrix; + +import java.util.ArrayList; + +public class CsvInfo { + private String[] header; + private int csvRowCount; + private int csvColCount; + private ArrayList csvFileList; + + public String[] getHeader() { + return header; + } + + public void setHeader(String[] header) { + this.header = header; + } + + public int getCsvRowCount() { + return csvRowCount; + } + + public int getCsvColCount() { + return csvColCount; + } + + public ArrayList getCsvFileList() { + return csvFileList; + } + + public void setCsvFileList(ArrayList csvFileList) { + this.csvFileList = csvFileList; + this.csvColCount = csvFileList.get(0) != null?csvFileList.get(0).length:0; + this.csvRowCount = csvFileList.size(); + } + + public Matrix toMatrix() throws Exception { + double[][] arr = new double[csvFileList.size()][csvFileList.get(0).length]; + for (int i = 0; i < csvFileList.size(); i++) { + for (int j = 0; j < csvFileList.get(0).length; j++) { + try { + arr[i][j] = Double.parseDouble(csvFileList.get(i)[j]); + }catch (NumberFormatException e){ + throw new Exception("Csv中含有非数字字符,无法转换成Matrix对象"); + } + } + } + return new Matrix(arr); + } + +} diff --git a/algorithm/src/main/java/com/mh/algorithm/utils/CsvUtil.java b/algorithm/src/main/java/com/mh/algorithm/utils/CsvUtil.java new file mode 100644 index 0000000..557e0f1 --- /dev/null +++ b/algorithm/src/main/java/com/mh/algorithm/utils/CsvUtil.java @@ -0,0 +1,66 @@ +package com.mh.algorithm.utils; + +import com.csvreader.CsvReader; +import com.csvreader.CsvWriter; +import com.mh.algorithm.matrix.Matrix; + +import java.io.IOException; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; + +public class CsvUtil { + /** + * 获取CSV中的信息 + * @param hasHeader 是否含有表头 + * @param path CSV文件的路径 + * @return + * @throws IOException + */ + public static CsvInfo getCsvInfo(boolean hasHeader , String path) throws IOException { + //创建csv对象,存储csv中的信息 + CsvInfo csvInfo = new CsvInfo(); + //获取CsvReader流 + CsvReader csvReader = new CsvReader(path, ',', StandardCharsets.UTF_8); + if(hasHeader){ + csvReader.readHeaders(); + } + //获取Csv中的所有记录 + ArrayList csvFileList = new ArrayList(); + while (csvReader.readRecord()) { + csvFileList.add(csvReader.getValues()); + } + //赋值 + csvInfo.setHeader(csvReader.getHeaders()); + csvInfo.setCsvFileList(csvFileList); + //关闭流 + csvReader.close(); + return csvInfo; + } + + /** + * 将矩阵写入到csv文件中 + * @param header 表头 + * @param data 以矩阵形式存放的数据 + * @param path 写入的文件地址 + * @throws Exception + */ + public static void createCsvFile(String[] header,Matrix data,String path) throws Exception { + + if (header!=null && header.length != data.getMatrixColCount()) { + throw new Exception("表头列数与数据列数不符"); + } + CsvWriter csvWriter = new CsvWriter(path, ',', StandardCharsets.UTF_8); + + if (header != null) { + csvWriter.writeRecord(header); + } + for (int i = 0; i < data.getMatrixRowCount(); i++) { + String[] record = new String[data.getMatrixColCount()]; + for (int j = 0; j < data.getMatrixColCount(); j++) { + record[j] = data.getValOfIdx(i, j)+""; + } + csvWriter.writeRecord(record); + } + csvWriter.close(); + } +} diff --git a/algorithm/src/main/java/com/mh/algorithm/utils/DoubleUtil.java b/algorithm/src/main/java/com/mh/algorithm/utils/DoubleUtil.java new file mode 100644 index 0000000..396097b --- /dev/null +++ b/algorithm/src/main/java/com/mh/algorithm/utils/DoubleUtil.java @@ -0,0 +1,20 @@ +package com.mh.algorithm.utils; + +/** + * @program: top-algorithm-set + * @description: DoubleTool + * @author: Mr.Zhao + * @create: 2020-11-12 21:54 + **/ +public class DoubleUtil { + + private static final Double MAX_ERROR = 0.0001; + + public static boolean equals(Double a, Double b) { + return Math.abs(a - b)< MAX_ERROR; + } + + public static boolean equals(Double a, Double b,Double maxError) { + return Math.abs(a - b)< maxError; + } +} diff --git a/algorithm/src/main/java/com/mh/algorithm/utils/MatrixUtil.java b/algorithm/src/main/java/com/mh/algorithm/utils/MatrixUtil.java new file mode 100644 index 0000000..b242a28 --- /dev/null +++ b/algorithm/src/main/java/com/mh/algorithm/utils/MatrixUtil.java @@ -0,0 +1,285 @@ +package com.mh.algorithm.utils; + +import Jama.EigenvalueDecomposition; +import com.mh.algorithm.matrix.Matrix; + +import java.util.*; + +public class MatrixUtil { + /** + * 创建一个单位矩阵 + * @param matrixRowCount 单位矩阵的纬度 + * @return + */ + public static Matrix eye(int matrixRowCount){ + double[][] result = new double[matrixRowCount][matrixRowCount]; + for (int i = 0; i < matrixRowCount; i++) { + for (int j = 0; j < matrixRowCount; j++) { + if(i == j){ + result[i][j] = 1; + }else{ + result[i][j] = 0; + } + } + } + return new Matrix(result); + } + + /** + * 求矩阵的逆 + * 原理:AE=EA^-1 + * @param a + * @return + * @throws Exception + */ + public static Matrix inv(Matrix a) throws Exception { + if (!invable(a)) { + throw new Exception("矩阵不可逆"); + } + // [a|E] + Matrix b = a.splice(1, eye(a.getMatrixRowCount())); + double[][] data = b.getMatrix(); + int rowCount = b.getMatrixRowCount(); + int colCount = b.getMatrixColCount(); + //此处应用a的列数,为简化,直接用b的行数 + for (int j = 0; j < rowCount; j++) { + //若遇到0则交换两行 + int notZeroRow = -2; + if(data[j][j] == 0){ + notZeroRow = -1; + for (int l = j; l < rowCount; l++) { + if (data[l][j] != 0) { + notZeroRow = l; + break; + } + } + } + if (notZeroRow == -1) { + throw new Exception("矩阵不可逆"); + }else if(notZeroRow != -2){ + //交换j与notZeroRow两行 + double[] tmp = data[j]; + data[j] = data[notZeroRow]; + data[notZeroRow] = tmp; + } + //将第data[j][j]化为1 + if (data[j][j] != 1) { + double multiple = data[j][j]; + for (int colIdx = j; colIdx < colCount; colIdx++) { + data[j][colIdx] /= multiple; + } + } + //行与行相减 + for (int i = 0; i < rowCount; i++) { + if (i != j) { + double multiple = data[i][j] / data[j][j]; + //遍历行中的列 + for (int k = j; k < colCount; k++) { + data[i][k] = data[i][k] - multiple * data[j][k]; + } + } + } + } + Matrix result = new Matrix(data); + return result.subMatrix(0, rowCount, rowCount, rowCount); + } + + /** + * 求矩阵的伴随矩阵 + * 原理:A*=|A|A^-1 + * @param a + * @return + * @throws Exception + */ + public static Matrix adj(Matrix a) throws Exception { + return inv(a).multiple(det(a)); + } + + /** + * 矩阵转成上三角矩阵 + * @param a + * @return + * @throws Exception + */ + public static Matrix getTopTriangle(Matrix a) throws Exception { + if (!a.isSquareMatrix()) { + throw new Exception("不是方阵无法进行计算"); + } + int matrixHeight = a.getMatrixRowCount(); + double[][] result = a.getMatrix(); + //遍历列 + for (int j = 0; j < matrixHeight; j++) { + //遍历行 + for (int i = j+1; i < matrixHeight; i++) { + //若遇到0则交换两行 + int notZeroRow = -2; + if(result[j][j] == 0){ + notZeroRow = -1; + for (int l = i; l < matrixHeight; l++) { + if (result[l][j] != 0) { + notZeroRow = l; + break; + } + } + } + if (notZeroRow == -1) { + throw new Exception("矩阵不可逆"); + }else if(notZeroRow != -2){ + //交换j与notZeroRow两行 + double[] tmp = result[j]; + result[j] = result[notZeroRow]; + result[notZeroRow] = tmp; + } + + double multiple = result[i][j]/result[j][j]; + //遍历行中的列 + for (int k = j; k < matrixHeight; k++) { + result[i][k] = result[i][k] - multiple * result[j][k]; + } + } + } + return new Matrix(result); + } + + /** + * 计算矩阵的行列式 + * @param a + * @return + * @throws Exception + */ + public static double det(Matrix a) throws Exception { + //将矩阵转成上三角矩阵 + Matrix b = MatrixUtil.getTopTriangle(a); + double result = 1; + //计算矩阵行列式 + for (int i = 0; i < b.getMatrixRowCount(); i++) { + result *= b.getValOfIdx(i, i); + } + return result; + } + /** + * 获取协方差矩阵 + * @param a + * @return + * @throws Exception + */ + public static Matrix cov(Matrix a) throws Exception { + if (a.getMatrix() == null) { + throw new Exception("矩阵为空"); + } + Matrix avg = a.getColAvg().extend(2, a.getMatrixRowCount()); + Matrix tmp = a.subtract(avg); + return tmp.transpose().multiple(tmp).multiple(1/((double) a.getMatrixRowCount() -1)); + } + + /** + * 判断矩阵是否可逆 + * 如果可转为上三角矩阵则可逆 + * @param a + * @return + */ + public static boolean invable(Matrix a) { + try { + getTopTriangle(a); + return true; + } catch (Exception e) { + return false; + } + } + + /** + * 获取矩阵的特征值矩阵,调用Jama中的getV方法 + * @param a + * @return + */ + public static Matrix getV(Matrix a) { + EigenvalueDecomposition eig = new EigenvalueDecomposition(new Jama.Matrix(a.getMatrix())); + return new Matrix(eig.getV().getArray()); + } + + /** + * 取特征值实部 + * @param a + * @return + */ + public double[] getRealEigenvalues(Matrix a){ + EigenvalueDecomposition eig = new EigenvalueDecomposition(new Jama.Matrix(a.getMatrix())); + return eig.getRealEigenvalues(); + } + + /** + * 取特征值虚部 + * @param a + * @return + */ + public double[] getImagEigenvalues(Matrix a){ + EigenvalueDecomposition eig = new EigenvalueDecomposition(new Jama.Matrix(a.getMatrix())); + return eig.getImagEigenvalues(); + } + + /** + * 取块对角特征值矩阵 + * @param a + * @return + */ + public static Matrix getD(Matrix a) { + EigenvalueDecomposition eig = new EigenvalueDecomposition(new Jama.Matrix(a.getMatrix())); + return new Matrix(eig.getD().getArray()); + } + + /** + * 数据归一化 + * @param a 要归一化的数据 + * @param normalizationMin 要归一化的区间下限 + * @param normalizationMax 要归一化的区间上限 + * @return + */ + public static Map normalize(Matrix a, double normalizationMin, double normalizationMax) throws Exception { + HashMap result = new HashMap<>(); + double[][] maxArr = new double[1][a.getMatrixColCount()]; + double[][] minArr = new double[1][a.getMatrixColCount()]; + double[][] res = new double[a.getMatrixRowCount()][a.getMatrixColCount()]; + for (int i = 0; i < a.getMatrixColCount(); i++) { + List tmp = new ArrayList(); + for (int j = 0; j < a.getMatrixRowCount(); j++) { + tmp.add(a.getValOfIdx(j,i)); + } + double max = (double) Collections.max(tmp); + double min = (double) Collections.min(tmp); + //数据归一化(注:若max与min均为0则不需要归一化) + if (max != 0 || min != 0) { + for (int j = 0; j < a.getMatrixRowCount(); j++) { + res[j][i] = normalizationMin + (a.getValOfIdx(j,i) - min) / (max - min) * (normalizationMax - normalizationMin); + } + } + maxArr[0][i] = max; + minArr[0][i] = min; + } + result.put("max", new Matrix(maxArr)); + result.put("min", new Matrix(minArr)); + result.put("res", new Matrix(res)); + return result; + } + + /** + * 反归一化 + * @param a 要反归一化的数据 + * @param normalizationMin 要反归一化的区间下限 + * @param normalizationMax 要反归一化的区间上限 + * @param dataMax 数据最大值 + * @param dataMin 数据最小值 + * @return + */ + public static Matrix inverseNormalize(Matrix a, double normalizationMax, double normalizationMin , Matrix dataMax,Matrix dataMin){ + double[][] res = new double[a.getMatrixRowCount()][a.getMatrixColCount()]; + for (int i = 0; i < a.getMatrixColCount(); i++) { + //数据反归一化 + if (dataMin.getValOfIdx(0,i) != 0 || dataMax.getValOfIdx(0,i) != 0) { + for (int j = 0; j < a.getMatrixRowCount(); j++) { + res[j][i] = dataMin.getValOfIdx(0,i) + (dataMax.getValOfIdx(0,i) - dataMin.getValOfIdx(0,i)) * (a.getValOfIdx(j,i) - normalizationMin) / (normalizationMax - normalizationMin); + } + } + } + return new Matrix(res); + } +} diff --git a/algorithm/src/main/java/com/mh/algorithm/utils/SerializationUtil.java b/algorithm/src/main/java/com/mh/algorithm/utils/SerializationUtil.java new file mode 100644 index 0000000..b159328 --- /dev/null +++ b/algorithm/src/main/java/com/mh/algorithm/utils/SerializationUtil.java @@ -0,0 +1,32 @@ +package com.mh.algorithm.utils; + +import java.io.*; + +public class SerializationUtil { + /** + * 对象序列化到本地 + * @param object + * @throws IOException + */ + public static void serialize(Object object, String path) throws IOException { + File file = new File(path); + System.out.println(file.getAbsolutePath()); + ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(file)); + out.writeObject(object); + out.close(); + } + + /** + * 对象反序列化 + * @return + * @throws IOException + * @throws ClassNotFoundException + */ + public static Object deSerialization(String path) throws IOException, ClassNotFoundException { + File file = new File(path); + ObjectInputStream oin = new ObjectInputStream(new FileInputStream(file)); + Object object = oin.readObject(); + oin.close(); + return object; + } +} diff --git a/algorithm/src/test/java/com/mh/algorithm/bpnn/bpnnTest.java b/algorithm/src/test/java/com/mh/algorithm/bpnn/bpnnTest.java new file mode 100644 index 0000000..a8ad316 --- /dev/null +++ b/algorithm/src/test/java/com/mh/algorithm/bpnn/bpnnTest.java @@ -0,0 +1,71 @@ +package com.mh.algorithm.bpnn; + +import com.mh.algorithm.matrix.Matrix; +import com.mh.algorithm.utils.CsvInfo; +import com.mh.algorithm.utils.CsvUtil; +import com.mh.algorithm.utils.SerializationUtil; +import org.junit.Test; + +import java.util.Date; + +public class bpnnTest { + @Test + public void test() throws Exception { + // 创建训练集矩阵 + CsvInfo csvInfo = CsvUtil.getCsvInfo(true, "D:\\ljf\\my_pro\\top-algorithm-set-dev\\src\\trainDataElec.csv"); + Matrix trainSet = csvInfo.toMatrix(); + // 创建BPNN工厂对象 + BPNeuralNetworkFactory factory = new BPNeuralNetworkFactory(); + // 创建BP参数对象 + BPParameter bpParameter = new BPParameter(); + bpParameter.setInputLayerNeuronCount(2); + bpParameter.setHiddenLayerNeuronCount(2); + bpParameter.setOutputLayerNeuronCount(2); + bpParameter.setPrecision(0.01); + bpParameter.setMaxTimes(100000); + + // 训练BP神经网络 + System.out.println(new Date()); + BPModel bpModel = factory.trainBP(bpParameter, trainSet); + System.out.println(new Date()); + + // 将BPModel序列化到本地 + SerializationUtil.serialize(bpModel, "elec"); + + CsvInfo csvInfo2 = CsvUtil.getCsvInfo(true, "D:\\ljf\\my_pro\\top-algorithm-set-dev\\src\\testDataElec.csv"); + Matrix testSet = csvInfo2.toMatrix(); + + Matrix testData1 = testSet.subMatrix(0, testSet.getMatrixRowCount(), 0, testSet.getMatrixColCount() - 2); + Matrix testLabel = testSet.subMatrix(0, testSet.getMatrixRowCount(), testSet.getMatrixColCount() - 2, 1); + // 将BPModel反序列化 + BPModel bpModel1 = (BPModel) SerializationUtil.deSerialization("elec"); + Matrix result = factory.computeBP(bpModel1, testData1); + + int total = result.getMatrixRowCount(); + int correct = 0; + for (int i = 0; i < result.getMatrixRowCount(); i++) { + if(Math.round(result.getValOfIdx(i,0)) == testLabel.getValOfIdx(i,0)){ + correct++; + } + } + double correctRate = Double.valueOf(correct) / Double.valueOf(total); + System.out.println(correctRate); + } + + /** + * 使用示例 + * @throws Exception + */ + @Test + public void bpnnUsing() throws Exception{ + CsvInfo csvInfo = CsvUtil.getCsvInfo(false, "D:\\ljf\\my_pro\\top-algorithm-set-dev\\src\\dataElec.csv"); + Matrix data = csvInfo.toMatrix(); + // 将BPModel反序列化 + BPModel bpModel1 = (BPModel) SerializationUtil.deSerialization("elec"); + // 创建工厂 + BPNeuralNetworkFactory factory = new BPNeuralNetworkFactory(); + Matrix result = factory.computeBP(bpModel1, data); + CsvUtil.createCsvFile(null,result,"D:\\ljf\\my_pro\\top-algorithm-set-dev\\src\\computeResult.csv"); + } + +} diff --git a/algorithm/src/test/java/com/mh/algorithm/knn/knnTest.java b/algorithm/src/test/java/com/mh/algorithm/knn/knnTest.java new file mode 100644 index 0000000..54713a3 --- /dev/null +++ b/algorithm/src/test/java/com/mh/algorithm/knn/knnTest.java @@ -0,0 +1,46 @@ +package com.mh.algorithm.knn; + +import com.mh.algorithm.matrix.Matrix; +import com.mh.algorithm.utils.CsvInfo; +import com.mh.algorithm.utils.CsvUtil; +import com.mh.algorithm.utils.DoubleUtil; +import org.junit.Test; + +/** + * @program: top-algorithm-set + * @description: + * @author: Mr.Zhao + * @create: 2020-10-26 22:04 + **/ +public class knnTest { + @Test + public void test() throws Exception { + // 训练集 + CsvInfo csvInfo = CsvUtil.getCsvInfo(false, "E:\\jarTest\\trainData.csv"); + Matrix trainSet = csvInfo.toMatrix(); + Matrix trainSetLabels = trainSet.getColOfIdx(trainSet.getMatrixColCount() - 1); + Matrix trainSetData = trainSet.subMatrix(0, trainSet.getMatrixRowCount(), 0, trainSet.getMatrixColCount() - 1); + + CsvInfo csvInfo1 = CsvUtil.getCsvInfo(false, "E:\\jarTest\\testData.csv"); + Matrix testSet = csvInfo1.toMatrix(); + Matrix testSetData = trainSet.subMatrix(0, testSet.getMatrixRowCount(), 0, testSet.getMatrixColCount() - 1); + Matrix testSetLabels = trainSet.getColOfIdx(testSet.getMatrixColCount() - 1); + + // 分类 + long startTime = System.currentTimeMillis(); + Matrix result = KNN.classify(testSetData, trainSetData, trainSetLabels, 5); + long endTime = System.currentTimeMillis(); + System.out.println("run time:" + (endTime - startTime)); + // 正确率 + Matrix error = result.subtract(testSetLabels); + int total = error.getMatrixRowCount(); + int correct = 0; + for (int i = 0; i < error.getMatrixRowCount(); i++) { + if (DoubleUtil.equals(error.getValOfIdx(i, 0), 0.0)) { + correct++; + } + } + double correctRate = Double.valueOf(correct) / Double.valueOf(total); + System.out.println("correctRate:"+ correctRate); + } +} diff --git a/pom.xml b/pom.xml index 7c03fe1..0a21d9b 100644 --- a/pom.xml +++ b/pom.xml @@ -11,6 +11,7 @@ common user-service + algorithm diff --git a/user-service/src/main/java/com/mh/user/controller/ControlSetController.java b/user-service/src/main/java/com/mh/user/controller/ControlSetController.java index fae8cdc..cb1af68 100644 --- a/user-service/src/main/java/com/mh/user/controller/ControlSetController.java +++ b/user-service/src/main/java/com/mh/user/controller/ControlSetController.java @@ -46,11 +46,11 @@ public class ControlSetController { } //查询设置表 - @SysLogger(title="控制设置",optDesc = "查询设置值") + @SysLogger(title="控制设置",optDesc = "查询时控设置值") @PostMapping(value="/query") - public HttpResult queryControlSet(@RequestParam("buildingId") String buildingId) { + public HttpResult queryControlSet(@RequestParam("buildingId") String buildingId, @RequestParam(value = "timeName",required = false) String timeName) { try{ - ControlSetEntity control=controlSetService.queryControlSet(buildingId); + ControlSetEntity control=controlSetService.queryControlSet(buildingId, timeName); return HttpResult.ok(control); }catch (Exception e){ // e.printStackTrace(); diff --git a/user-service/src/main/java/com/mh/user/controller/SerialPortController.java b/user-service/src/main/java/com/mh/user/controller/SerialPortController.java index 8e2b0f0..7065846 100644 --- a/user-service/src/main/java/com/mh/user/controller/SerialPortController.java +++ b/user-service/src/main/java/com/mh/user/controller/SerialPortController.java @@ -424,9 +424,9 @@ public class SerialPortController { } @PostMapping(value = "/control") - public HttpResult queryControlSet(@RequestParam(value = "buildingId") String buildingId) { + public HttpResult queryControlSet(@RequestParam(value = "buildingId") String buildingId, @RequestParam(value = "timeName", required = false) String timeName) { try { - ControlSetEntity list = controlSetService.queryControlSet(buildingId); + ControlSetEntity list = controlSetService.queryControlSet(buildingId, timeName); return HttpResult.ok(list); } catch (Exception e) { // e.printStackTrace(); diff --git a/user-service/src/main/java/com/mh/user/entity/MaintainInfoEntity.java b/user-service/src/main/java/com/mh/user/entity/MaintainInfoEntity.java index 7cee919..518843d 100644 --- a/user-service/src/main/java/com/mh/user/entity/MaintainInfoEntity.java +++ b/user-service/src/main/java/com/mh/user/entity/MaintainInfoEntity.java @@ -18,6 +18,7 @@ public class MaintainInfoEntity { private String maintainPeople; //维护人员 private Double cost; //费用 private String contents; //维保内容 + private String evaluate; // 评价分数 } diff --git a/user-service/src/main/java/com/mh/user/mapper/ControlSetMapper.java b/user-service/src/main/java/com/mh/user/mapper/ControlSetMapper.java index e5c4a37..a410a2a 100644 --- a/user-service/src/main/java/com/mh/user/mapper/ControlSetMapper.java +++ b/user-service/src/main/java/com/mh/user/mapper/ControlSetMapper.java @@ -60,6 +60,21 @@ public interface ControlSetMapper { @Result(column = "back_water_temp", property = "backWaterTemp"), @Result(column = "up_water_temp", property = "upWaterTemp"), }) - @Select("select * from control_Set where building_id=#{buildingId}") - ControlSetEntity queryControlSet(@Param("buildingId") String buildingId); + @Select("select " + + " top 1 " + + " * " + + "from " + + " control_Set " + + "where " + + " building_id = #{buildingId} " + + " and exists ( " + + " select " + + " 1 " + + " from " + + " device_install di " + + " where " + + " di.building_id = building_id " + + " and di.device_name like concat(#{timeName}, '时控') " + + ") ") + ControlSetEntity queryControlSet(@Param("buildingId") String buildingId, @Param("timeName") String timeName); } diff --git a/user-service/src/main/java/com/mh/user/mapper/MaintainInfoMapper.java b/user-service/src/main/java/com/mh/user/mapper/MaintainInfoMapper.java index c94e90f..6de88fe 100644 --- a/user-service/src/main/java/com/mh/user/mapper/MaintainInfoMapper.java +++ b/user-service/src/main/java/com/mh/user/mapper/MaintainInfoMapper.java @@ -18,8 +18,8 @@ public interface MaintainInfoMapper { * 维修保养信息 * @param maintainInfoEntity */ - @Insert("insert into maintain_info(cur_date,building_id,device_type,device_addr,maintain_type,maintain_people,cost,contents) values (" + - " getDate(),#{buildingId},#{deviceType},#{deviceAddr},#{maintainType},#{maintainPeople},#{cost},#{contents})") + @Insert("insert into maintain_info(cur_date,building_id,device_type,device_addr,maintain_type,maintain_people,cost,contents, evaluate) values (" + + " getDate(),#{buildingId},#{deviceType},#{deviceAddr},#{maintainType},#{maintainPeople},#{cost},#{contents}, #{evaluate})") int saveMaintainInfo(MaintainInfoEntity maintainInfoEntity); /** @@ -36,6 +36,7 @@ public interface MaintainInfoMapper { " , maintain_people = #{maintainPeople} " + " , cost = #{cost} " + " , contents = #{contents} " + + " , evaluate = #{evaluate} " + " where id = #{id} " + "") int updateMaintainInfo(MaintainInfoEntity maintainInfoEntity); @@ -61,7 +62,8 @@ public interface MaintainInfoMapper { @Result(property="maintainPeople",column="maintain_people"), @Result(property="id",column="id"), @Result(property="cost",column="cost"), - @Result(property="contents",column="contents") + @Result(property="contents",column="contents"), + @Result(property="evaluate",column="evaluate") }) List queryMaintainInfo(@Param("curDate") String curDate, @Param("buildingId") String buildingId, diff --git a/user-service/src/main/java/com/mh/user/service/ControlSetService.java b/user-service/src/main/java/com/mh/user/service/ControlSetService.java index b1e01ce..0862a03 100644 --- a/user-service/src/main/java/com/mh/user/service/ControlSetService.java +++ b/user-service/src/main/java/com/mh/user/service/ControlSetService.java @@ -6,5 +6,5 @@ public interface ControlSetService { void saveControlSet(ControlSetEntity controlSetEntity); - ControlSetEntity queryControlSet(String buildingId); + ControlSetEntity queryControlSet(String buildingId, String timeName); } diff --git a/user-service/src/main/java/com/mh/user/service/impl/ControlSetServiceImpl.java b/user-service/src/main/java/com/mh/user/service/impl/ControlSetServiceImpl.java index 5ee0f39..fdeac96 100644 --- a/user-service/src/main/java/com/mh/user/service/impl/ControlSetServiceImpl.java +++ b/user-service/src/main/java/com/mh/user/service/impl/ControlSetServiceImpl.java @@ -1,5 +1,6 @@ package com.mh.user.service.impl; +import com.mh.common.utils.StringUtils; import com.mh.user.entity.ControlSetEntity; import com.mh.user.mapper.ControlSetMapper; import com.mh.user.service.ControlSetService; @@ -26,8 +27,10 @@ public class ControlSetServiceImpl implements ControlSetService { } @Override - public ControlSetEntity queryControlSet(String buildingId) { - - return controlSetMapper.queryControlSet(buildingId); + public ControlSetEntity queryControlSet(String buildingId, String timeName) { + if (StringUtils.isBlank(timeName)) { + timeName = "%"; + } + return controlSetMapper.queryControlSet(buildingId,timeName); } } diff --git a/user-service/src/main/java/com/mh/user/service/impl/DeviceControlServiceImpl.java b/user-service/src/main/java/com/mh/user/service/impl/DeviceControlServiceImpl.java index 70ca9af..12dde02 100644 --- a/user-service/src/main/java/com/mh/user/service/impl/DeviceControlServiceImpl.java +++ b/user-service/src/main/java/com/mh/user/service/impl/DeviceControlServiceImpl.java @@ -11,6 +11,7 @@ import com.mh.user.model.DeviceModel; import com.mh.user.model.SerialPortModel; import com.mh.user.serialport.SerialPortSingle2; import com.mh.user.service.*; +import com.mh.user.utils.ExchangeStringUtil; import lombok.extern.slf4j.Slf4j; import org.springframework.stereotype.Service; @@ -385,6 +386,11 @@ public class DeviceControlServiceImpl implements DeviceControlService { } else { deviceCodeParam.setRegisterAddr("00240000"); //寄存器地址 } + } else if ("startOrStop".equals(deviceCodeParam.getParam())) { + deviceCodeParam.setFunCode("06"); //功能码读数据 + if ("瑞星".equals(deviceCodeParam.getBrand())) { + deviceCodeParam.setRegisterAddr("0001"+ ExchangeStringUtil.addZeroForNum(deviceCodeParam.getDataValue(), 4)); + } } return rtData; }