diff --git a/2024数据库脚本.sql b/2024数据库脚本.sql
new file mode 100644
index 0000000..4f0c9a3
--- /dev/null
+++ b/2024数据库脚本.sql
@@ -0,0 +1,38 @@
+-- 2024-05-07 维修表缺少字段
+ALTER TABLE maintain_info ADD cost numeric(2,0) NULL;
+EXEC sys.sp_addextendedproperty 'MS_Description', N'材料费用', 'schema', N'dbo', 'table', N'maintain_info', 'column', N'cost';
+ALTER TABLE maintain_info ADD contents varchar(100) NULL;
+EXEC sys.sp_addextendedproperty 'MS_Description', N'维保内容', 'schema', N'dbo', 'table', N'maintain_info', 'column', N'contents';
+ALTER TABLE maintain_info ADD evaluate varchar(10) NULL;
+EXEC sys.sp_addextendedproperty 'MS_Description', N'评价内容', 'schema', N'dbo', 'table', N'maintain_info', 'column', N'evaluate';
+
+-- 训练集合:
+select
+ eds.cur_date,
+ eds.building_id,
+ isnull(eds.water_value,
+ 0) as water_value,
+ isnull(eds.elect_value,
+ 0) as elect_value,
+ isnull(convert(numeric(24,2),t1.water_level),
+ 0) as water_level
+from
+ energy_day_sum eds
+ left join (
+ select
+ convert(date,
+ cur_date) as cur_date,
+ building_id,
+ avg(isnull(convert(numeric(24, 2), water_level), 0)) as water_level
+ from
+ history_data
+ group by
+ convert(date,
+ cur_date),
+ building_id
+ ) t1 on
+ eds.cur_date = t1.cur_date and eds.building_id = t1.building_id
+where eds.building_id != '所有'
+order by
+ eds.building_id,
+ eds.cur_date
diff --git a/algorithm/pom.xml b/algorithm/pom.xml
new file mode 100644
index 0000000..d7aab46
--- /dev/null
+++ b/algorithm/pom.xml
@@ -0,0 +1,135 @@
+
+
+
+ com.mh
+ chws
+ 1.0-SNAPSHOT
+
+ 4.0.0
+
+ com.mh
+ algorithm
+ 1.0.0
+ jar
+
+
+ UTF-8
+ UTF-8
+ 1.8
+ 1.8
+ 1.8
+
+
+
+
+ net.sourceforge.javacsv
+ javacsv
+ 2.0
+
+
+ gov.nist.math
+ jama
+ 1.0.3
+
+
+ junit
+ junit
+ RELEASE
+ test
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+ 3.1
+
+
+ 1.8
+
+
+
+
+
+
+
+ default
+
+ true
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-compiler-plugin
+ 3.8.0
+
+
+ 1.8
+ UTF-8
+
+
+
+
+ org.apache.maven.plugins
+ maven-javadoc-plugin
+ 2.9.1
+
+
+ attach-javadocs
+
+ jar
+
+
+
+ -Xdoclint:none
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-source-plugin
+ 2.3
+
+
+ attach-sources
+
+ jar
+
+
+
+
+
+
+ org.apache.maven.plugins
+ maven-gpg-plugin
+ 1.4
+
+
+ sign-artifacts
+ verify
+
+ sign
+
+
+
+
+
+ maven-jar-plugin
+ 2.3.1
+
+ target/classes
+
+
+
+
+
+
+
diff --git a/algorithm/src/main/java/com/mh/algorithm/bpnn/ActivationFunction.java b/algorithm/src/main/java/com/mh/algorithm/bpnn/ActivationFunction.java
new file mode 100644
index 0000000..fff15fe
--- /dev/null
+++ b/algorithm/src/main/java/com/mh/algorithm/bpnn/ActivationFunction.java
@@ -0,0 +1,8 @@
+package com.mh.algorithm.bpnn;
+
+public interface ActivationFunction {
+ //计算值
+ double computeValue(double val);
+ //计算导数
+ double computeDerivative(double val);
+}
diff --git a/algorithm/src/main/java/com/mh/algorithm/bpnn/BPModel.java b/algorithm/src/main/java/com/mh/algorithm/bpnn/BPModel.java
new file mode 100644
index 0000000..719f8fb
--- /dev/null
+++ b/algorithm/src/main/java/com/mh/algorithm/bpnn/BPModel.java
@@ -0,0 +1,111 @@
+package com.mh.algorithm.bpnn;
+
+import com.mh.algorithm.matrix.Matrix;
+
+import java.io.Serializable;
+
+public class BPModel implements Serializable {
+ //BP神经网络权值与阈值
+ private Matrix weightIJ;
+ private Matrix b1;
+ private Matrix weightJP;
+ private Matrix b2;
+ /*用于反归一化*/
+ private Matrix inputMax;
+ private Matrix inputMin;
+ private Matrix outputMax;
+ private Matrix outputMin;
+ /*BP神经网络训练参数*/
+ private BPParameter bpParameter;
+ /*BP神经网络训练情况*/
+ private double error;
+ private int times;
+
+ public Matrix getWeightIJ() {
+ return weightIJ;
+ }
+
+ public void setWeightIJ(Matrix weightIJ) {
+ this.weightIJ = weightIJ;
+ }
+
+ public Matrix getB1() {
+ return b1;
+ }
+
+ public void setB1(Matrix b1) {
+ this.b1 = b1;
+ }
+
+ public Matrix getWeightJP() {
+ return weightJP;
+ }
+
+ public void setWeightJP(Matrix weightJP) {
+ this.weightJP = weightJP;
+ }
+
+ public Matrix getB2() {
+ return b2;
+ }
+
+ public void setB2(Matrix b2) {
+ this.b2 = b2;
+ }
+
+ public Matrix getInputMax() {
+ return inputMax;
+ }
+
+ public void setInputMax(Matrix inputMax) {
+ this.inputMax = inputMax;
+ }
+
+ public Matrix getInputMin() {
+ return inputMin;
+ }
+
+ public void setInputMin(Matrix inputMin) {
+ this.inputMin = inputMin;
+ }
+
+ public Matrix getOutputMax() {
+ return outputMax;
+ }
+
+ public void setOutputMax(Matrix outputMax) {
+ this.outputMax = outputMax;
+ }
+
+ public Matrix getOutputMin() {
+ return outputMin;
+ }
+
+ public void setOutputMin(Matrix outputMin) {
+ this.outputMin = outputMin;
+ }
+
+ public BPParameter getBpParameter() {
+ return bpParameter;
+ }
+
+ public void setBpParameter(BPParameter bpParameter) {
+ this.bpParameter = bpParameter;
+ }
+
+ public double getError() {
+ return error;
+ }
+
+ public void setError(double error) {
+ this.error = error;
+ }
+
+ public int getTimes() {
+ return times;
+ }
+
+ public void setTimes(int times) {
+ this.times = times;
+ }
+}
diff --git a/algorithm/src/main/java/com/mh/algorithm/bpnn/BPNeuralNetworkFactory.java b/algorithm/src/main/java/com/mh/algorithm/bpnn/BPNeuralNetworkFactory.java
new file mode 100644
index 0000000..3c6c8c8
--- /dev/null
+++ b/algorithm/src/main/java/com/mh/algorithm/bpnn/BPNeuralNetworkFactory.java
@@ -0,0 +1,257 @@
+package com.mh.algorithm.bpnn;
+
+import com.mh.algorithm.matrix.Matrix;
+import com.mh.algorithm.utils.MatrixUtil;
+
+import java.util.*;
+
+public class BPNeuralNetworkFactory {
+ /**
+ * 训练BP神经网络模型
+ * @param bpParameter
+ * @param inputAndOutput
+ * @return
+ */
+ public BPModel trainBP(BPParameter bpParameter, Matrix inputAndOutput) throws Exception {
+
+ ActivationFunction activationFunction = bpParameter.getActivationFunction();
+ int inputCount = bpParameter.getInputLayerNeuronCount();
+ int hiddenCount = bpParameter.getHiddenLayerNeuronCount();
+ int outputCount = bpParameter.getOutputLayerNeuronCount();
+ double normalizationMin = bpParameter.getNormalizationMin();
+ double normalizationMax = bpParameter.getNormalizationMax();
+ double step = bpParameter.getStep();
+ double momentumFactor = bpParameter.getMomentumFactor();
+ double precision = bpParameter.getPrecision();
+ int maxTimes = bpParameter.getMaxTimes();
+
+ if(inputAndOutput.getMatrixColCount() != inputCount + outputCount){
+ throw new Exception("神经元个数不符,请修改");
+ }
+ // 初始化权值
+ Matrix weightIJ = initWeight(inputCount, hiddenCount);
+ Matrix weightJP = initWeight(hiddenCount, outputCount);
+
+ // 初始化阈值
+ Matrix b1 = initThreshold(hiddenCount);
+ Matrix b2 = initThreshold(outputCount);
+
+ // 动量项
+ Matrix deltaWeightIJ0 = new Matrix(inputCount, hiddenCount);
+ Matrix deltaWeightJP0 = new Matrix(hiddenCount, outputCount);
+ Matrix deltaB10 = new Matrix(1, hiddenCount);
+ Matrix deltaB20 = new Matrix(1, outputCount);
+
+ // 截取输入矩阵和输出矩阵
+ Matrix input = inputAndOutput.subMatrix(0,inputAndOutput.getMatrixRowCount(),0,inputCount);
+ Matrix output = inputAndOutput.subMatrix(0,inputAndOutput.getMatrixRowCount(),inputCount,outputCount);
+
+ // 归一化
+ Map inputAfterNormalize = MatrixUtil.normalize(input, normalizationMin, normalizationMax);
+ input = (Matrix) inputAfterNormalize.get("res");
+
+ Map outputAfterNormalize = MatrixUtil.normalize(output, normalizationMin, normalizationMax);
+ output = (Matrix) outputAfterNormalize.get("res");
+
+ int times = 1;
+ double E = 0;//误差
+ while (times < maxTimes) {
+ /*-----------------正向传播---------------------*/
+ // 隐含层输入
+ Matrix jIn = input.multiple(weightIJ);
+ // 扩充阈值
+ Matrix b1Copy = b1.extend(2,jIn.getMatrixRowCount());
+ // 加上阈值
+ jIn = jIn.plus(b1Copy);
+ // 隐含层输出
+ Matrix jOut = computeValue(jIn,activationFunction);
+ // 输出层输入
+ Matrix pIn = jOut.multiple(weightJP);
+ // 扩充阈值
+ Matrix b2Copy = b2.extend(2, pIn.getMatrixRowCount());
+ // 加上阈值
+ pIn = pIn.plus(b2Copy);
+ // 输出层输出
+ Matrix pOut = computeValue(pIn,activationFunction);
+ // 计算误差
+ Matrix e = output.subtract(pOut);
+ E = computeE(e);//误差
+ // 判断是否符合精度
+ if (Math.abs(E) <= precision) {
+ System.out.println("满足精度");
+ break;
+ }
+
+ /*-----------------反向传播---------------------*/
+ // J与P之间权值修正量
+ Matrix deltaWeightJP = e.multiple(step);
+ deltaWeightJP = deltaWeightJP.pointMultiple(computeDerivative(pIn,activationFunction));
+ deltaWeightJP = deltaWeightJP.transpose().multiple(jOut);
+ deltaWeightJP = deltaWeightJP.transpose();
+ // P层神经元阈值修正量
+ Matrix deltaThresholdP = e.multiple(step);
+ deltaThresholdP = deltaThresholdP.transpose().multiple(computeDerivative(pIn, activationFunction));
+
+ // I与J之间的权值修正量
+ Matrix deltaO = e.pointMultiple(computeDerivative(pIn,activationFunction));
+ Matrix tmp = weightJP.multiple(deltaO.transpose()).transpose();
+ Matrix deltaWeightIJ = tmp.pointMultiple(computeDerivative(jIn, activationFunction));
+ deltaWeightIJ = input.transpose().multiple(deltaWeightIJ);
+ deltaWeightIJ = deltaWeightIJ.multiple(step);
+
+ // J层神经元阈值修正量
+ Matrix deltaThresholdJ = tmp.transpose().multiple(computeDerivative(jIn, activationFunction));
+ deltaThresholdJ = deltaThresholdJ.multiple(-step);
+
+ if (times == 1) {
+ // 更新权值与阈值
+ weightIJ = weightIJ.plus(deltaWeightIJ);
+ weightJP = weightJP.plus(deltaWeightJP);
+ b1 = b1.plus(deltaThresholdJ);
+ b2 = b2.plus(deltaThresholdP);
+ }else{
+ // 加动量项
+ weightIJ = weightIJ.plus(deltaWeightIJ).plus(deltaWeightIJ0.multiple(momentumFactor));
+ weightJP = weightJP.plus(deltaWeightJP).plus(deltaWeightJP0.multiple(momentumFactor));
+ b1 = b1.plus(deltaThresholdJ).plus(deltaB10.multiple(momentumFactor));
+ b2 = b2.plus(deltaThresholdP).plus(deltaB20.multiple(momentumFactor));
+ }
+
+ deltaWeightIJ0 = deltaWeightIJ;
+ deltaWeightJP0 = deltaWeightJP;
+ deltaB10 = deltaThresholdJ;
+ deltaB20 = deltaThresholdP;
+
+ times++;
+ }
+
+ // BP神经网络的输出
+ BPModel result = new BPModel();
+ result.setInputMax((Matrix) inputAfterNormalize.get("max"));
+ result.setInputMin((Matrix) inputAfterNormalize.get("min"));
+ result.setOutputMax((Matrix) outputAfterNormalize.get("max"));
+ result.setOutputMin((Matrix) outputAfterNormalize.get("min"));
+ result.setWeightIJ(weightIJ);
+ result.setWeightJP(weightJP);
+ result.setB1(b1);
+ result.setB2(b2);
+ result.setError(E);
+ result.setTimes(times);
+ result.setBpParameter(bpParameter);
+ System.out.println("循环次数:" + times + ",误差:" + E);
+
+ return result;
+ }
+
+ /**
+ * 计算BP神经网络的值
+ * @param bpModel
+ * @param input
+ * @return
+ */
+ public Matrix computeBP(BPModel bpModel,Matrix input) throws Exception {
+ if (input.getMatrixColCount() != bpModel.getBpParameter().getInputLayerNeuronCount()) {
+ throw new Exception("输入矩阵纬度有误");
+ }
+ ActivationFunction activationFunction = bpModel.getBpParameter().getActivationFunction();
+ Matrix weightIJ = bpModel.getWeightIJ();
+ Matrix weightJP = bpModel.getWeightJP();
+ Matrix b1 = bpModel.getB1();
+ Matrix b2 = bpModel.getB2();
+ double[][] normalizedInput = new double[input.getMatrixRowCount()][input.getMatrixColCount()];
+ for (int i = 0; i < input.getMatrixRowCount(); i++) {
+ for (int j = 0; j < input.getMatrixColCount(); j++) {
+ normalizedInput[i][j] = bpModel.getBpParameter().getNormalizationMin()
+ + (input.getValOfIdx(i,j) - bpModel.getInputMin().getValOfIdx(0,j))
+ / (bpModel.getInputMax().getValOfIdx(0,j) - bpModel.getInputMin().getValOfIdx(0,j))
+ * (bpModel.getBpParameter().getNormalizationMax() - bpModel.getBpParameter().getNormalizationMin());
+ }
+ }
+ Matrix normalizedInputMatrix = new Matrix(normalizedInput);
+ Matrix jIn = normalizedInputMatrix.multiple(weightIJ);
+ // 扩充阈值
+ Matrix b1Copy = b1.extend(2,jIn.getMatrixRowCount());
+ // 加上阈值
+ jIn = jIn.plus(b1Copy);
+ // 隐含层输出
+ Matrix jOut = computeValue(jIn,activationFunction);
+ // 输出层输入
+ Matrix pIn = jOut.multiple(weightJP);
+ // 扩充阈值
+ Matrix b2Copy = b2.extend(2,pIn.getMatrixRowCount());
+ // 加上阈值
+ pIn = pIn.plus(b2Copy);
+ // 输出层输出
+ Matrix pOut = computeValue(pIn,activationFunction);
+ // 反归一化
+ return MatrixUtil.inverseNormalize(pOut, bpModel.getBpParameter().getNormalizationMax(), bpModel.getBpParameter().getNormalizationMin(), bpModel.getOutputMax(), bpModel.getOutputMin());
+ }
+
+ // 初始化权值
+ private Matrix initWeight(int x,int y){
+ Random random=new Random();
+ double[][] weight = new double[x][y];
+ for (int i = 0; i < x; i++) {
+ for (int j = 0; j < y; j++) {
+ weight[i][j] = 2*random.nextDouble()-1;
+ }
+ }
+ return new Matrix(weight);
+ }
+ // 初始化阈值
+ private Matrix initThreshold(int x){
+ Random random = new Random();
+ double[][] result = new double[1][x];
+ for (int i = 0; i < x; i++) {
+ result[0][i] = 2*random.nextDouble()-1;
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 计算激活函数的值
+ * @param a
+ * @return
+ */
+ private Matrix computeValue(Matrix a, ActivationFunction activationFunction) throws Exception {
+ if (a.getMatrix() == null) {
+ throw new Exception("参数值为空");
+ }
+ double[][] result = new double[a.getMatrixRowCount()][a.getMatrixColCount()];
+ for (int i = 0; i < a.getMatrixRowCount(); i++) {
+ for (int j = 0; j < a.getMatrixColCount(); j++) {
+ result[i][j] = activationFunction.computeValue(a.getValOfIdx(i,j));
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 激活函数导数的值
+ * @param a
+ * @return
+ */
+ private Matrix computeDerivative(Matrix a , ActivationFunction activationFunction) throws Exception {
+ if (a.getMatrix() == null) {
+ throw new Exception("参数值为空");
+ }
+ double[][] result = new double[a.getMatrixRowCount()][a.getMatrixColCount()];
+ for (int i = 0; i < a.getMatrixRowCount(); i++) {
+ for (int j = 0; j < a.getMatrixColCount(); j++) {
+ result[i][j] = activationFunction.computeDerivative(a.getValOfIdx(i,j));
+ }
+ }
+ return new Matrix(result);
+ }
+
+
+ /**
+ * 计算误差
+ * @param e
+ * @return
+ */
+ private double computeE(Matrix e){
+ e = e.square();
+ return 0.5*e.sumAll();
+ }
+}
diff --git a/algorithm/src/main/java/com/mh/algorithm/bpnn/BPParameter.java b/algorithm/src/main/java/com/mh/algorithm/bpnn/BPParameter.java
new file mode 100644
index 0000000..7196136
--- /dev/null
+++ b/algorithm/src/main/java/com/mh/algorithm/bpnn/BPParameter.java
@@ -0,0 +1,106 @@
+package com.mh.algorithm.bpnn;
+
+import java.io.Serializable;
+
+public class BPParameter implements Serializable {
+
+ //输入层神经元个数
+ private int inputLayerNeuronCount = 3;
+ //隐含层神经元个数
+ private int hiddenLayerNeuronCount = 3;
+ //输出层神经元个数
+ private int outputLayerNeuronCount = 1;
+ //归一化区间
+ private double normalizationMin = 0.2;
+ private double normalizationMax = 0.8;
+ //学习步长
+ private double step = 0.05;
+ //动量因子
+ private double momentumFactor = 0.2;
+ //激活函数
+ private ActivationFunction activationFunction = new Sigmoid();
+ //精度
+ private double precision = 0.000001;
+ //最大循环次数
+ private int maxTimes = 1000000;
+
+ public double getMomentumFactor() {
+ return momentumFactor;
+ }
+
+ public void setMomentumFactor(double momentumFactor) {
+ this.momentumFactor = momentumFactor;
+ }
+
+ public double getStep() {
+ return step;
+ }
+
+ public void setStep(double step) {
+ this.step = step;
+ }
+
+ public double getNormalizationMin() {
+ return normalizationMin;
+ }
+
+ public void setNormalizationMin(double normalizationMin) {
+ this.normalizationMin = normalizationMin;
+ }
+
+ public double getNormalizationMax() {
+ return normalizationMax;
+ }
+
+ public void setNormalizationMax(double normalizationMax) {
+ this.normalizationMax = normalizationMax;
+ }
+
+ public int getInputLayerNeuronCount() {
+ return inputLayerNeuronCount;
+ }
+
+ public void setInputLayerNeuronCount(int inputLayerNeuronCount) {
+ this.inputLayerNeuronCount = inputLayerNeuronCount;
+ }
+
+ public int getHiddenLayerNeuronCount() {
+ return hiddenLayerNeuronCount;
+ }
+
+ public void setHiddenLayerNeuronCount(int hiddenLayerNeuronCount) {
+ this.hiddenLayerNeuronCount = hiddenLayerNeuronCount;
+ }
+
+ public int getOutputLayerNeuronCount() {
+ return outputLayerNeuronCount;
+ }
+
+ public void setOutputLayerNeuronCount(int outputLayerNeuronCount) {
+ this.outputLayerNeuronCount = outputLayerNeuronCount;
+ }
+
+ public ActivationFunction getActivationFunction() {
+ return activationFunction;
+ }
+
+ public void setActivationFunction(ActivationFunction activationFunction) {
+ this.activationFunction = activationFunction;
+ }
+
+ public double getPrecision() {
+ return precision;
+ }
+
+ public void setPrecision(double precision) {
+ this.precision = precision;
+ }
+
+ public int getMaxTimes() {
+ return maxTimes;
+ }
+
+ public void setMaxTimes(int maxTimes) {
+ this.maxTimes = maxTimes;
+ }
+}
diff --git a/algorithm/src/main/java/com/mh/algorithm/bpnn/Sigmoid.java b/algorithm/src/main/java/com/mh/algorithm/bpnn/Sigmoid.java
new file mode 100644
index 0000000..378060c
--- /dev/null
+++ b/algorithm/src/main/java/com/mh/algorithm/bpnn/Sigmoid.java
@@ -0,0 +1,15 @@
+package com.mh.algorithm.bpnn;
+
+import java.io.Serializable;
+
+public class Sigmoid implements ActivationFunction, Serializable {
+ @Override
+ public double computeValue(double val) {
+ return 1 / (1 + Math.exp(-val));
+ }
+
+ @Override
+ public double computeDerivative(double val) {
+ return computeValue(val) * (1 - computeValue(val));
+ }
+}
diff --git a/algorithm/src/main/java/com/mh/algorithm/constants/OrderEnum.java b/algorithm/src/main/java/com/mh/algorithm/constants/OrderEnum.java
new file mode 100644
index 0000000..1e47167
--- /dev/null
+++ b/algorithm/src/main/java/com/mh/algorithm/constants/OrderEnum.java
@@ -0,0 +1,24 @@
+package com.mh.algorithm.constants;
+
+/**
+ * 排序枚举类
+ */
+public enum OrderEnum {
+
+ ASC(1,"升序"),
+
+ DESC(2,"降序");
+
+ OrderEnum(int flag, String name) {
+
+ this.flag = flag;
+
+ this.name = name;
+
+ }
+
+ private int flag;
+
+ private String name;
+
+}
diff --git a/algorithm/src/main/java/com/mh/algorithm/knn/KNN.java b/algorithm/src/main/java/com/mh/algorithm/knn/KNN.java
new file mode 100644
index 0000000..fd6fbdb
--- /dev/null
+++ b/algorithm/src/main/java/com/mh/algorithm/knn/KNN.java
@@ -0,0 +1,88 @@
+package com.mh.algorithm.knn;
+
+import com.mh.algorithm.constants.OrderEnum;
+import com.mh.algorithm.matrix.Matrix;
+import com.mh.algorithm.utils.MatrixUtil;
+
+import java.util.*;
+
+
+/**
+ * @program: top-algorithm-set
+ * @description: KNN k-临近算法进行分类
+ * @author: Mr.Zhao
+ * @create: 2020-10-13 22:03
+ **/
+public class KNN {
+ public static Matrix classify(Matrix input, Matrix dataSet, Matrix labels, int k) throws Exception {
+ if (dataSet.getMatrixRowCount() != labels.getMatrixRowCount()) {
+ throw new IllegalArgumentException("矩阵训练集与标签维度不一致");
+ }
+ if (input.getMatrixColCount() != dataSet.getMatrixColCount()) {
+ throw new IllegalArgumentException("待分类矩阵列数与训练集列数不一致");
+ }
+ if (dataSet.getMatrixRowCount() < k) {
+ throw new IllegalArgumentException("训练集样本数小于k");
+ }
+ // 归一化
+ int trainCount = dataSet.getMatrixRowCount();
+ int testCount = input.getMatrixRowCount();
+ Matrix trainAndTest = dataSet.splice(2, input);
+ Map normalize = MatrixUtil.normalize(trainAndTest, 0, 1);
+ trainAndTest = (Matrix) normalize.get("res");
+ dataSet = trainAndTest.subMatrix(0, trainCount, 0, trainAndTest.getMatrixColCount());
+ input = trainAndTest.subMatrix(0, testCount, 0, trainAndTest.getMatrixColCount());
+
+ // 获取标签信息
+ List labelList = new ArrayList<>();
+ for (int i = 0; i < labels.getMatrixRowCount(); i++) {
+ if (!labelList.contains(labels.getValOfIdx(i, 0))) {
+ labelList.add(labels.getValOfIdx(i, 0));
+ }
+ }
+
+ Matrix result = new Matrix(new double[input.getMatrixRowCount()][1]);
+ for (int i = 0; i < input.getMatrixRowCount(); i++) {
+ // 计算向量间的欧式距离
+ // 将labels矩阵扩展
+ Matrix labelMatrixCopied = input.getRowOfIdx(i).extend(2, dataSet.getMatrixRowCount());
+ // 前面是计算欧氏距离,splice(1,labels)是将距离矩阵与labels矩阵合并
+ Matrix distanceMatrix = dataSet.subtract(labelMatrixCopied).square().sumRow().pow(0.5).splice(1, labels);
+ // 将计算出的距离矩阵按照距离升序排序
+ distanceMatrix.sort(0, OrderEnum.ASC);
+ // 遍历最近的k个变量
+ Map map = new HashMap<>();
+ for (int j = 0; j < k; j++) {
+ // 遍历标签种类数
+ for (Double label : labelList) {
+ if (distanceMatrix.getValOfIdx(j, 1) == label) {
+ map.put(label, map.getOrDefault(label, 0) + 1);
+ }
+ }
+ }
+ result.setValue(i, 0, getKeyOfMaxValue(map));
+ }
+ return result;
+ }
+
+ /**
+ * 取map中值最大的key
+ *
+ * @param map
+ * @return
+ */
+ private static Double getKeyOfMaxValue(Map map) {
+ if (map == null)
+ return null;
+ Double keyOfMaxValue = 0.0;
+ Integer maxValue = 0;
+ for (Double key : map.keySet()) {
+ if (map.get(key) > maxValue) {
+ keyOfMaxValue = key;
+ maxValue = map.get(key);
+ }
+ }
+ return keyOfMaxValue;
+ }
+
+}
diff --git a/algorithm/src/main/java/com/mh/algorithm/matrix/Matrix.java b/algorithm/src/main/java/com/mh/algorithm/matrix/Matrix.java
new file mode 100644
index 0000000..4820012
--- /dev/null
+++ b/algorithm/src/main/java/com/mh/algorithm/matrix/Matrix.java
@@ -0,0 +1,646 @@
+package com.mh.algorithm.matrix;
+
+import com.mh.algorithm.constants.OrderEnum;
+
+import java.io.Serializable;
+
+public class Matrix implements Serializable {
+ private double[][] matrix;
+ //矩阵列数
+ private int matrixColCount;
+ //矩阵行数
+ private int matrixRowCount;
+
+ /**
+ * 构造一个空矩阵
+ */
+ public Matrix() {
+ this.matrix = null;
+ this.matrixColCount = 0;
+ this.matrixRowCount = 0;
+ }
+
+ /**
+ * 构造一个matrix矩阵
+ * @param matrix
+ */
+ public Matrix(double[][] matrix) {
+ this.matrix = matrix;
+ this.matrixRowCount = matrix.length;
+ this.matrixColCount = matrix[0].length;
+ }
+
+ /**
+ * 构造一个rowCount行colCount列值为0的矩阵
+ * @param rowCount
+ * @param colCount
+ */
+ public Matrix(int rowCount,int colCount) {
+ double[][] matrix = new double[rowCount][colCount];
+ for (int i = 0; i < rowCount; i++) {
+ for (int j = 0; j < colCount; j++) {
+ matrix[i][j] = 0;
+ }
+ }
+ this.matrix = matrix;
+ this.matrixRowCount = rowCount;
+ this.matrixColCount = colCount;
+ }
+
+ /**
+ * 构造一个rowCount行colCount列值为val的矩阵
+ * @param val
+ * @param rowCount
+ * @param colCount
+ */
+ public Matrix(double val,int rowCount,int colCount) {
+ double[][] matrix = new double[rowCount][colCount];
+ for (int i = 0; i < rowCount; i++) {
+ for (int j = 0; j < colCount; j++) {
+ matrix[i][j] = val;
+ }
+ }
+ this.matrix = matrix;
+ this.matrixRowCount = rowCount;
+ this.matrixColCount = colCount;
+ }
+
+ public double[][] getMatrix() {
+ return matrix;
+ }
+
+ public void setMatrix(double[][] matrix) {
+ this.matrix = matrix;
+ this.matrixRowCount = matrix.length;
+ this.matrixColCount = matrix[0].length;
+ }
+
+ public int getMatrixColCount() {
+ return matrixColCount;
+ }
+
+ public int getMatrixRowCount() {
+ return matrixRowCount;
+ }
+
+ /**
+ * 获取矩阵指定位置的值
+ *
+ * @param x
+ * @param y
+ * @return
+ */
+ public double getValOfIdx(int x, int y) throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ if (x > matrixRowCount - 1) {
+ throw new IllegalArgumentException("索引x越界");
+ }
+ if (y > matrixColCount - 1) {
+ throw new IllegalArgumentException("索引y越界");
+ }
+ return matrix[x][y];
+ }
+
+ /**
+ * 获取矩阵指定行
+ *
+ * @param x
+ * @return
+ */
+ public Matrix getRowOfIdx(int x) throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ if (x > matrixRowCount - 1) {
+ throw new IllegalArgumentException("索引x越界");
+ }
+ double[][] result = new double[1][matrixColCount];
+ result[0] = matrix[x];
+ return new Matrix(result);
+ }
+
+ /**
+ * 获取矩阵指定列
+ *
+ * @param y
+ * @return
+ */
+ public Matrix getColOfIdx(int y) throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ if (y > matrixColCount - 1) {
+ throw new IllegalArgumentException("索引y越界");
+ }
+ double[][] result = new double[matrixRowCount][1];
+ for (int i = 0; i < matrixRowCount; i++) {
+ result[i][0] = matrix[i][y];
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 设置矩阵中x,y位置元素的值
+ * @param x
+ * @param y
+ * @param val
+ */
+ public void setValue(int x, int y, double val) {
+ if (x > this.matrixRowCount - 1) {
+ throw new IllegalArgumentException("行索引越界");
+ }
+ if (y > this.matrixColCount - 1) {
+ throw new IllegalArgumentException("列索引越界");
+ }
+ this.matrix[x][y] = val;
+ }
+
+ /**
+ * 矩阵乘矩阵
+ *
+ * @param a
+ * @return
+ * @throws IllegalArgumentException
+ */
+ public Matrix multiple(Matrix a) throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ if (a.getMatrix() == null || a.getMatrixRowCount() == 0 || a.getMatrixColCount() == 0) {
+ throw new IllegalArgumentException("参数矩阵为空");
+ }
+ if (matrixColCount != a.getMatrixRowCount()) {
+ throw new IllegalArgumentException("矩阵纬度不同,不可计算");
+ }
+ double[][] result = new double[matrixRowCount][a.getMatrixColCount()];
+ for (int i = 0; i < matrixRowCount; i++) {
+ for (int j = 0; j < a.getMatrixColCount(); j++) {
+ for (int k = 0; k < matrixColCount; k++) {
+ result[i][j] = result[i][j] + matrix[i][k] * a.getMatrix()[k][j];
+ }
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 矩阵乘一个数字
+ *
+ * @param a
+ * @return
+ */
+ public Matrix multiple(double a) throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ double[][] result = new double[matrixRowCount][matrixColCount];
+ for (int i = 0; i < matrixRowCount; i++) {
+ for (int j = 0; j < matrixColCount; j++) {
+ result[i][j] = matrix[i][j] * a;
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 矩阵点乘
+ *
+ * @param a
+ * @return
+ */
+ public Matrix pointMultiple(Matrix a) throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ if (a.getMatrix() == null || a.getMatrixRowCount() == 0 || a.getMatrixColCount() == 0) {
+ throw new IllegalArgumentException("参数矩阵为空");
+ }
+ if (matrixRowCount != a.getMatrixRowCount() && matrixColCount != a.getMatrixColCount()) {
+ throw new IllegalArgumentException("矩阵纬度不同,不可计算");
+ }
+ double[][] result = new double[matrixRowCount][matrixColCount];
+ for (int i = 0; i < matrixRowCount; i++) {
+ for (int j = 0; j < matrixColCount; j++) {
+ result[i][j] = matrix[i][j] * a.getMatrix()[i][j];
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 矩阵除一个数字
+ * @param a
+ * @return
+ * @throws IllegalArgumentException
+ */
+ public Matrix divide(double a) throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ double[][] result = new double[matrixRowCount][matrixColCount];
+ for (int i = 0; i < matrixRowCount; i++) {
+ for (int j = 0; j < matrixColCount; j++) {
+ result[i][j] = matrix[i][j] / a;
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 矩阵加法
+ *
+ * @param a
+ * @return
+ */
+ public Matrix plus(Matrix a) throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ if (a.getMatrix() == null || a.getMatrixRowCount() == 0 || a.getMatrixColCount() == 0) {
+ throw new IllegalArgumentException("参数矩阵为空");
+ }
+ if (matrixRowCount != a.getMatrixRowCount() && matrixColCount != a.getMatrixColCount()) {
+ throw new IllegalArgumentException("矩阵纬度不同,不可计算");
+ }
+ double[][] result = new double[matrixRowCount][matrixColCount];
+ for (int i = 0; i < matrixRowCount; i++) {
+ for (int j = 0; j < matrixColCount; j++) {
+ result[i][j] = matrix[i][j] + a.getMatrix()[i][j];
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 矩阵加一个数字
+ * @param a
+ * @return
+ * @throws IllegalArgumentException
+ */
+ public Matrix plus(double a) throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ double[][] result = new double[matrixRowCount][matrixColCount];
+ for (int i = 0; i < matrixRowCount; i++) {
+ for (int j = 0; j < matrixColCount; j++) {
+ result[i][j] = matrix[i][j] + a;
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 矩阵减法
+ *
+ * @param a
+ * @return
+ */
+ public Matrix subtract(Matrix a) throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ if (a.getMatrix() == null || a.getMatrixRowCount() == 0 || a.getMatrixColCount() == 0) {
+ throw new IllegalArgumentException("参数矩阵为空");
+ }
+ if (matrixRowCount != a.getMatrixRowCount() && matrixColCount != a.getMatrixColCount()) {
+ throw new IllegalArgumentException("矩阵纬度不同,不可计算");
+ }
+ double[][] result = new double[matrixRowCount][matrixColCount];
+ for (int i = 0; i < matrixRowCount; i++) {
+ for (int j = 0; j < matrixColCount; j++) {
+ result[i][j] = matrix[i][j] - a.getMatrix()[i][j];
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 矩阵减一个数字
+ * @param a
+ * @return
+ * @throws IllegalArgumentException
+ */
+ public Matrix subtract(double a) throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ double[][] result = new double[matrixRowCount][matrixColCount];
+ for (int i = 0; i < matrixRowCount; i++) {
+ for (int j = 0; j < matrixColCount; j++) {
+ result[i][j] = matrix[i][j] - a;
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 矩阵行求和
+ *
+ * @return
+ */
+ public Matrix sumRow() throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ double[][] result = new double[matrixRowCount][1];
+ for (int i = 0; i < matrixRowCount; i++) {
+ for (int j = 0; j < matrixColCount; j++) {
+ result[i][0] += matrix[i][j];
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 矩阵列求和
+ *
+ * @return
+ */
+ public Matrix sumCol() throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ double[][] result = new double[1][matrixColCount];
+ for (int i = 0; i < matrixRowCount; i++) {
+ for (int j = 0; j < matrixColCount; j++) {
+ result[0][j] += matrix[i][j];
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 矩阵所有元素求和
+ *
+ * @return
+ */
+ public double sumAll() throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ double result = 0;
+ for (double[] doubles : matrix) {
+ for (int j = 0; j < matrixColCount; j++) {
+ result += doubles[j];
+ }
+ }
+ return result;
+ }
+
+ /**
+ * 矩阵所有元素求平方
+ *
+ * @return
+ */
+ public Matrix square() throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ double[][] result = new double[matrixRowCount][matrixColCount];
+ for (int i = 0; i < matrixRowCount; i++) {
+ for (int j = 0; j < matrixColCount; j++) {
+ result[i][j] = matrix[i][j] * matrix[i][j];
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 矩阵所有元素求N次方
+ *
+ * @return
+ */
+ public Matrix pow(double n) throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ double[][] result = new double[matrixRowCount][matrixColCount];
+ for (int i = 0; i < matrixRowCount; i++) {
+ for (int j = 0; j < matrixColCount; j++) {
+ result[i][j] = Math.pow(matrix[i][j],n);
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 矩阵转置
+ *
+ * @return
+ */
+ public Matrix transpose() throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ double[][] result = new double[matrixColCount][matrixRowCount];
+ for (int i = 0; i < matrixRowCount; i++) {
+ for (int j = 0; j < matrixColCount; j++) {
+ result[j][i] = matrix[i][j];
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 截取矩阵
+ * @param startRowIndex 开始行索引
+ * @param rowCount 截取行数
+ * @param startColIndex 开始列索引
+ * @param colCount 截取列数
+ * @return
+ * @throws IllegalArgumentException
+ */
+ public Matrix subMatrix(int startRowIndex,int rowCount,int startColIndex,int colCount) throws IllegalArgumentException {
+ if (startRowIndex + rowCount > matrixRowCount) {
+ throw new IllegalArgumentException("行索引越界");
+ }
+ if (startColIndex + colCount> matrixColCount) {
+ throw new IllegalArgumentException("列索引越界");
+ }
+ double[][] result = new double[rowCount][colCount];
+ for (int i = startRowIndex; i < startRowIndex + rowCount; i++) {
+ if (startColIndex + colCount - startColIndex >= 0)
+ System.arraycopy(matrix[i], startColIndex, result[i - startRowIndex], 0, colCount);
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 矩阵合并
+ * @param direction 合并方向,1为横向,2为竖向
+ * @param a
+ * @return
+ * @throws IllegalArgumentException
+ */
+ public Matrix splice(int direction, Matrix a) throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ if (a.getMatrix() == null || a.getMatrixRowCount() == 0 || a.getMatrixColCount() == 0) {
+ throw new IllegalArgumentException("参数矩阵为空");
+ }
+ if(direction == 1){
+ //横向拼接
+ if (matrixRowCount != a.getMatrixRowCount()) {
+ throw new IllegalArgumentException("矩阵行数不一致,无法拼接");
+ }
+ double[][] result = new double[matrixRowCount][matrixColCount + a.getMatrixColCount()];
+ for (int i = 0; i < matrixRowCount; i++) {
+ System.arraycopy(matrix[i],0,result[i],0,matrixColCount);
+ System.arraycopy(a.getMatrix()[i],0,result[i],matrixColCount,a.getMatrixColCount());
+ }
+ return new Matrix(result);
+ }else if(direction == 2){
+ //纵向拼接
+ if (matrixColCount != a.getMatrixColCount()) {
+ throw new IllegalArgumentException("矩阵列数不一致,无法拼接");
+ }
+ double[][] result = new double[matrixRowCount + a.getMatrixRowCount()][matrixColCount];
+ for (int i = 0; i < matrixRowCount; i++) {
+ result[i] = matrix[i];
+ }
+ for (int i = 0; i < a.getMatrixRowCount(); i++) {
+ result[matrixRowCount + i] = a.getMatrix()[i];
+ }
+ return new Matrix(result);
+ }else{
+ throw new IllegalArgumentException("方向参数有误");
+ }
+ }
+ /**
+ * 扩展矩阵
+ * @param direction 扩展方向,1为横向,2为竖向
+ * @param a
+ * @return
+ * @throws IllegalArgumentException
+ */
+ public Matrix extend(int direction , int a) throws IllegalArgumentException {
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ if(direction == 1){
+ //横向复制
+ double[][] result = new double[matrixRowCount][matrixColCount*a];
+ for (int i = 0; i < matrixRowCount; i++) {
+ for (int j = 0; j < a; j++) {
+ System.arraycopy(matrix[i],0,result[i],j*matrixColCount,matrixColCount);
+ }
+ }
+ return new Matrix(result);
+ }else if(direction == 2){
+ //纵向复制
+ double[][] result = new double[matrixRowCount*a][matrixColCount];
+ for (int i = 0; i < matrixRowCount*a; i++) {
+ result[i] = matrix[i%matrixRowCount];
+ }
+ return new Matrix(result);
+ }else{
+ throw new IllegalArgumentException("方向参数有误");
+ }
+ }
+ /**
+ * 获取每列的平均值
+ * @return
+ * @throws IllegalArgumentException
+ */
+ public Matrix getColAvg() throws IllegalArgumentException {
+ Matrix tmp = this.sumCol();
+ return tmp.divide(matrixRowCount);
+ }
+
+ /**
+ * 矩阵行排序
+ * @param index 根据第几列的数进行行排序
+ * @param order 排序顺序,升序或降序
+ * @return
+ * @throws IllegalArgumentException
+ */
+ public void sort(int index, OrderEnum order) throws IllegalArgumentException{
+ if (matrix == null || matrixRowCount == 0 || matrixColCount == 0) {
+ throw new IllegalArgumentException("矩阵为空");
+ }
+ if(index >= matrixColCount){
+ throw new IllegalArgumentException("排序索引index越界");
+ }
+ sort(index,order,0,this.matrixRowCount - 1);
+ }
+
+ /**
+ * 判断是否是方阵
+ * 行列数相等,并且不等于0
+ * @return
+ */
+ public boolean isSquareMatrix(){
+ return matrixColCount == matrixRowCount && matrixColCount != 0;
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder stringBuilder = new StringBuilder();
+ stringBuilder.append("\r\n");
+ for (int i = 0; i < matrixRowCount; i++) {
+ stringBuilder.append("# ");
+ for (int j = 0; j < matrixColCount; j++) {
+ stringBuilder.append(matrix[i][j]).append("\t ");
+ }
+ stringBuilder.append("#\r\n");
+ }
+ stringBuilder.append("\r\n");
+ return stringBuilder.toString();
+ }
+
+ private void sort(int index,OrderEnum order,int start,int end){
+ if(start >= end){
+ return;
+ }
+ int tmp = partition(index,order,start,end);
+ sort(index,order, start, tmp - 1);
+ sort(index,order, tmp + 1, end);
+ }
+
+ private int partition(int index,OrderEnum order,int start,int end){
+ int l = start + 1,r = end;
+ double v = matrix[start][index];
+ switch (order){
+ case ASC:
+ while(true){
+ while(matrix[r][index] >= v && r > start){
+ r--;
+ }
+ while(matrix[l][index] <= v && l < end){
+ l++;
+ }
+ if(l >= r){
+ break;
+ }
+ double[] tmp = matrix[r];
+ matrix[r] = matrix[l];
+ matrix[l] = tmp;
+ }
+ break;
+ case DESC:
+ while(true){
+ while(matrix[r][index] <= v && r > start){
+ r--;
+ }
+ while(matrix[l][index] >= v && l < end){
+ l++;
+ }
+ if(l >= r){
+ break;
+ }
+ double[] tmp = matrix[r];
+ matrix[r] = matrix[l];
+ matrix[l] = tmp;
+ }
+ break;
+ }
+ double[] tmp = matrix[r];
+ matrix[r] = matrix[start];
+ matrix[start] = tmp;
+ return r;
+ }
+}
diff --git a/algorithm/src/main/java/com/mh/algorithm/utils/CsvInfo.java b/algorithm/src/main/java/com/mh/algorithm/utils/CsvInfo.java
new file mode 100644
index 0000000..7c12d44
--- /dev/null
+++ b/algorithm/src/main/java/com/mh/algorithm/utils/CsvInfo.java
@@ -0,0 +1,53 @@
+package com.mh.algorithm.utils;
+
+import com.mh.algorithm.matrix.Matrix;
+
+import java.util.ArrayList;
+
+public class CsvInfo {
+ private String[] header;
+ private int csvRowCount;
+ private int csvColCount;
+ private ArrayList csvFileList;
+
+ public String[] getHeader() {
+ return header;
+ }
+
+ public void setHeader(String[] header) {
+ this.header = header;
+ }
+
+ public int getCsvRowCount() {
+ return csvRowCount;
+ }
+
+ public int getCsvColCount() {
+ return csvColCount;
+ }
+
+ public ArrayList getCsvFileList() {
+ return csvFileList;
+ }
+
+ public void setCsvFileList(ArrayList csvFileList) {
+ this.csvFileList = csvFileList;
+ this.csvColCount = csvFileList.get(0) != null?csvFileList.get(0).length:0;
+ this.csvRowCount = csvFileList.size();
+ }
+
+ public Matrix toMatrix() throws Exception {
+ double[][] arr = new double[csvFileList.size()][csvFileList.get(0).length];
+ for (int i = 0; i < csvFileList.size(); i++) {
+ for (int j = 0; j < csvFileList.get(0).length; j++) {
+ try {
+ arr[i][j] = Double.parseDouble(csvFileList.get(i)[j]);
+ }catch (NumberFormatException e){
+ throw new Exception("Csv中含有非数字字符,无法转换成Matrix对象");
+ }
+ }
+ }
+ return new Matrix(arr);
+ }
+
+}
diff --git a/algorithm/src/main/java/com/mh/algorithm/utils/CsvUtil.java b/algorithm/src/main/java/com/mh/algorithm/utils/CsvUtil.java
new file mode 100644
index 0000000..557e0f1
--- /dev/null
+++ b/algorithm/src/main/java/com/mh/algorithm/utils/CsvUtil.java
@@ -0,0 +1,66 @@
+package com.mh.algorithm.utils;
+
+import com.csvreader.CsvReader;
+import com.csvreader.CsvWriter;
+import com.mh.algorithm.matrix.Matrix;
+
+import java.io.IOException;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+
+public class CsvUtil {
+ /**
+ * 获取CSV中的信息
+ * @param hasHeader 是否含有表头
+ * @param path CSV文件的路径
+ * @return
+ * @throws IOException
+ */
+ public static CsvInfo getCsvInfo(boolean hasHeader , String path) throws IOException {
+ //创建csv对象,存储csv中的信息
+ CsvInfo csvInfo = new CsvInfo();
+ //获取CsvReader流
+ CsvReader csvReader = new CsvReader(path, ',', StandardCharsets.UTF_8);
+ if(hasHeader){
+ csvReader.readHeaders();
+ }
+ //获取Csv中的所有记录
+ ArrayList csvFileList = new ArrayList();
+ while (csvReader.readRecord()) {
+ csvFileList.add(csvReader.getValues());
+ }
+ //赋值
+ csvInfo.setHeader(csvReader.getHeaders());
+ csvInfo.setCsvFileList(csvFileList);
+ //关闭流
+ csvReader.close();
+ return csvInfo;
+ }
+
+ /**
+ * 将矩阵写入到csv文件中
+ * @param header 表头
+ * @param data 以矩阵形式存放的数据
+ * @param path 写入的文件地址
+ * @throws Exception
+ */
+ public static void createCsvFile(String[] header,Matrix data,String path) throws Exception {
+
+ if (header!=null && header.length != data.getMatrixColCount()) {
+ throw new Exception("表头列数与数据列数不符");
+ }
+ CsvWriter csvWriter = new CsvWriter(path, ',', StandardCharsets.UTF_8);
+
+ if (header != null) {
+ csvWriter.writeRecord(header);
+ }
+ for (int i = 0; i < data.getMatrixRowCount(); i++) {
+ String[] record = new String[data.getMatrixColCount()];
+ for (int j = 0; j < data.getMatrixColCount(); j++) {
+ record[j] = data.getValOfIdx(i, j)+"";
+ }
+ csvWriter.writeRecord(record);
+ }
+ csvWriter.close();
+ }
+}
diff --git a/algorithm/src/main/java/com/mh/algorithm/utils/DoubleUtil.java b/algorithm/src/main/java/com/mh/algorithm/utils/DoubleUtil.java
new file mode 100644
index 0000000..396097b
--- /dev/null
+++ b/algorithm/src/main/java/com/mh/algorithm/utils/DoubleUtil.java
@@ -0,0 +1,20 @@
+package com.mh.algorithm.utils;
+
+/**
+ * @program: top-algorithm-set
+ * @description: DoubleTool
+ * @author: Mr.Zhao
+ * @create: 2020-11-12 21:54
+ **/
+public class DoubleUtil {
+
+ private static final Double MAX_ERROR = 0.0001;
+
+ public static boolean equals(Double a, Double b) {
+ return Math.abs(a - b)< MAX_ERROR;
+ }
+
+ public static boolean equals(Double a, Double b,Double maxError) {
+ return Math.abs(a - b)< maxError;
+ }
+}
diff --git a/algorithm/src/main/java/com/mh/algorithm/utils/MatrixUtil.java b/algorithm/src/main/java/com/mh/algorithm/utils/MatrixUtil.java
new file mode 100644
index 0000000..b242a28
--- /dev/null
+++ b/algorithm/src/main/java/com/mh/algorithm/utils/MatrixUtil.java
@@ -0,0 +1,285 @@
+package com.mh.algorithm.utils;
+
+import Jama.EigenvalueDecomposition;
+import com.mh.algorithm.matrix.Matrix;
+
+import java.util.*;
+
+public class MatrixUtil {
+ /**
+ * 创建一个单位矩阵
+ * @param matrixRowCount 单位矩阵的纬度
+ * @return
+ */
+ public static Matrix eye(int matrixRowCount){
+ double[][] result = new double[matrixRowCount][matrixRowCount];
+ for (int i = 0; i < matrixRowCount; i++) {
+ for (int j = 0; j < matrixRowCount; j++) {
+ if(i == j){
+ result[i][j] = 1;
+ }else{
+ result[i][j] = 0;
+ }
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 求矩阵的逆
+ * 原理:AE=EA^-1
+ * @param a
+ * @return
+ * @throws Exception
+ */
+ public static Matrix inv(Matrix a) throws Exception {
+ if (!invable(a)) {
+ throw new Exception("矩阵不可逆");
+ }
+ // [a|E]
+ Matrix b = a.splice(1, eye(a.getMatrixRowCount()));
+ double[][] data = b.getMatrix();
+ int rowCount = b.getMatrixRowCount();
+ int colCount = b.getMatrixColCount();
+ //此处应用a的列数,为简化,直接用b的行数
+ for (int j = 0; j < rowCount; j++) {
+ //若遇到0则交换两行
+ int notZeroRow = -2;
+ if(data[j][j] == 0){
+ notZeroRow = -1;
+ for (int l = j; l < rowCount; l++) {
+ if (data[l][j] != 0) {
+ notZeroRow = l;
+ break;
+ }
+ }
+ }
+ if (notZeroRow == -1) {
+ throw new Exception("矩阵不可逆");
+ }else if(notZeroRow != -2){
+ //交换j与notZeroRow两行
+ double[] tmp = data[j];
+ data[j] = data[notZeroRow];
+ data[notZeroRow] = tmp;
+ }
+ //将第data[j][j]化为1
+ if (data[j][j] != 1) {
+ double multiple = data[j][j];
+ for (int colIdx = j; colIdx < colCount; colIdx++) {
+ data[j][colIdx] /= multiple;
+ }
+ }
+ //行与行相减
+ for (int i = 0; i < rowCount; i++) {
+ if (i != j) {
+ double multiple = data[i][j] / data[j][j];
+ //遍历行中的列
+ for (int k = j; k < colCount; k++) {
+ data[i][k] = data[i][k] - multiple * data[j][k];
+ }
+ }
+ }
+ }
+ Matrix result = new Matrix(data);
+ return result.subMatrix(0, rowCount, rowCount, rowCount);
+ }
+
+ /**
+ * 求矩阵的伴随矩阵
+ * 原理:A*=|A|A^-1
+ * @param a
+ * @return
+ * @throws Exception
+ */
+ public static Matrix adj(Matrix a) throws Exception {
+ return inv(a).multiple(det(a));
+ }
+
+ /**
+ * 矩阵转成上三角矩阵
+ * @param a
+ * @return
+ * @throws Exception
+ */
+ public static Matrix getTopTriangle(Matrix a) throws Exception {
+ if (!a.isSquareMatrix()) {
+ throw new Exception("不是方阵无法进行计算");
+ }
+ int matrixHeight = a.getMatrixRowCount();
+ double[][] result = a.getMatrix();
+ //遍历列
+ for (int j = 0; j < matrixHeight; j++) {
+ //遍历行
+ for (int i = j+1; i < matrixHeight; i++) {
+ //若遇到0则交换两行
+ int notZeroRow = -2;
+ if(result[j][j] == 0){
+ notZeroRow = -1;
+ for (int l = i; l < matrixHeight; l++) {
+ if (result[l][j] != 0) {
+ notZeroRow = l;
+ break;
+ }
+ }
+ }
+ if (notZeroRow == -1) {
+ throw new Exception("矩阵不可逆");
+ }else if(notZeroRow != -2){
+ //交换j与notZeroRow两行
+ double[] tmp = result[j];
+ result[j] = result[notZeroRow];
+ result[notZeroRow] = tmp;
+ }
+
+ double multiple = result[i][j]/result[j][j];
+ //遍历行中的列
+ for (int k = j; k < matrixHeight; k++) {
+ result[i][k] = result[i][k] - multiple * result[j][k];
+ }
+ }
+ }
+ return new Matrix(result);
+ }
+
+ /**
+ * 计算矩阵的行列式
+ * @param a
+ * @return
+ * @throws Exception
+ */
+ public static double det(Matrix a) throws Exception {
+ //将矩阵转成上三角矩阵
+ Matrix b = MatrixUtil.getTopTriangle(a);
+ double result = 1;
+ //计算矩阵行列式
+ for (int i = 0; i < b.getMatrixRowCount(); i++) {
+ result *= b.getValOfIdx(i, i);
+ }
+ return result;
+ }
+ /**
+ * 获取协方差矩阵
+ * @param a
+ * @return
+ * @throws Exception
+ */
+ public static Matrix cov(Matrix a) throws Exception {
+ if (a.getMatrix() == null) {
+ throw new Exception("矩阵为空");
+ }
+ Matrix avg = a.getColAvg().extend(2, a.getMatrixRowCount());
+ Matrix tmp = a.subtract(avg);
+ return tmp.transpose().multiple(tmp).multiple(1/((double) a.getMatrixRowCount() -1));
+ }
+
+ /**
+ * 判断矩阵是否可逆
+ * 如果可转为上三角矩阵则可逆
+ * @param a
+ * @return
+ */
+ public static boolean invable(Matrix a) {
+ try {
+ getTopTriangle(a);
+ return true;
+ } catch (Exception e) {
+ return false;
+ }
+ }
+
+ /**
+ * 获取矩阵的特征值矩阵,调用Jama中的getV方法
+ * @param a
+ * @return
+ */
+ public static Matrix getV(Matrix a) {
+ EigenvalueDecomposition eig = new EigenvalueDecomposition(new Jama.Matrix(a.getMatrix()));
+ return new Matrix(eig.getV().getArray());
+ }
+
+ /**
+ * 取特征值实部
+ * @param a
+ * @return
+ */
+ public double[] getRealEigenvalues(Matrix a){
+ EigenvalueDecomposition eig = new EigenvalueDecomposition(new Jama.Matrix(a.getMatrix()));
+ return eig.getRealEigenvalues();
+ }
+
+ /**
+ * 取特征值虚部
+ * @param a
+ * @return
+ */
+ public double[] getImagEigenvalues(Matrix a){
+ EigenvalueDecomposition eig = new EigenvalueDecomposition(new Jama.Matrix(a.getMatrix()));
+ return eig.getImagEigenvalues();
+ }
+
+ /**
+ * 取块对角特征值矩阵
+ * @param a
+ * @return
+ */
+ public static Matrix getD(Matrix a) {
+ EigenvalueDecomposition eig = new EigenvalueDecomposition(new Jama.Matrix(a.getMatrix()));
+ return new Matrix(eig.getD().getArray());
+ }
+
+ /**
+ * 数据归一化
+ * @param a 要归一化的数据
+ * @param normalizationMin 要归一化的区间下限
+ * @param normalizationMax 要归一化的区间上限
+ * @return
+ */
+ public static Map normalize(Matrix a, double normalizationMin, double normalizationMax) throws Exception {
+ HashMap result = new HashMap<>();
+ double[][] maxArr = new double[1][a.getMatrixColCount()];
+ double[][] minArr = new double[1][a.getMatrixColCount()];
+ double[][] res = new double[a.getMatrixRowCount()][a.getMatrixColCount()];
+ for (int i = 0; i < a.getMatrixColCount(); i++) {
+ List tmp = new ArrayList();
+ for (int j = 0; j < a.getMatrixRowCount(); j++) {
+ tmp.add(a.getValOfIdx(j,i));
+ }
+ double max = (double) Collections.max(tmp);
+ double min = (double) Collections.min(tmp);
+ //数据归一化(注:若max与min均为0则不需要归一化)
+ if (max != 0 || min != 0) {
+ for (int j = 0; j < a.getMatrixRowCount(); j++) {
+ res[j][i] = normalizationMin + (a.getValOfIdx(j,i) - min) / (max - min) * (normalizationMax - normalizationMin);
+ }
+ }
+ maxArr[0][i] = max;
+ minArr[0][i] = min;
+ }
+ result.put("max", new Matrix(maxArr));
+ result.put("min", new Matrix(minArr));
+ result.put("res", new Matrix(res));
+ return result;
+ }
+
+ /**
+ * 反归一化
+ * @param a 要反归一化的数据
+ * @param normalizationMin 要反归一化的区间下限
+ * @param normalizationMax 要反归一化的区间上限
+ * @param dataMax 数据最大值
+ * @param dataMin 数据最小值
+ * @return
+ */
+ public static Matrix inverseNormalize(Matrix a, double normalizationMax, double normalizationMin , Matrix dataMax,Matrix dataMin){
+ double[][] res = new double[a.getMatrixRowCount()][a.getMatrixColCount()];
+ for (int i = 0; i < a.getMatrixColCount(); i++) {
+ //数据反归一化
+ if (dataMin.getValOfIdx(0,i) != 0 || dataMax.getValOfIdx(0,i) != 0) {
+ for (int j = 0; j < a.getMatrixRowCount(); j++) {
+ res[j][i] = dataMin.getValOfIdx(0,i) + (dataMax.getValOfIdx(0,i) - dataMin.getValOfIdx(0,i)) * (a.getValOfIdx(j,i) - normalizationMin) / (normalizationMax - normalizationMin);
+ }
+ }
+ }
+ return new Matrix(res);
+ }
+}
diff --git a/algorithm/src/main/java/com/mh/algorithm/utils/SerializationUtil.java b/algorithm/src/main/java/com/mh/algorithm/utils/SerializationUtil.java
new file mode 100644
index 0000000..b159328
--- /dev/null
+++ b/algorithm/src/main/java/com/mh/algorithm/utils/SerializationUtil.java
@@ -0,0 +1,32 @@
+package com.mh.algorithm.utils;
+
+import java.io.*;
+
+public class SerializationUtil {
+ /**
+ * 对象序列化到本地
+ * @param object
+ * @throws IOException
+ */
+ public static void serialize(Object object, String path) throws IOException {
+ File file = new File(path);
+ System.out.println(file.getAbsolutePath());
+ ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(file));
+ out.writeObject(object);
+ out.close();
+ }
+
+ /**
+ * 对象反序列化
+ * @return
+ * @throws IOException
+ * @throws ClassNotFoundException
+ */
+ public static Object deSerialization(String path) throws IOException, ClassNotFoundException {
+ File file = new File(path);
+ ObjectInputStream oin = new ObjectInputStream(new FileInputStream(file));
+ Object object = oin.readObject();
+ oin.close();
+ return object;
+ }
+}
diff --git a/algorithm/src/test/java/com/mh/algorithm/bpnn/bpnnTest.java b/algorithm/src/test/java/com/mh/algorithm/bpnn/bpnnTest.java
new file mode 100644
index 0000000..a8ad316
--- /dev/null
+++ b/algorithm/src/test/java/com/mh/algorithm/bpnn/bpnnTest.java
@@ -0,0 +1,71 @@
+package com.mh.algorithm.bpnn;
+
+import com.mh.algorithm.matrix.Matrix;
+import com.mh.algorithm.utils.CsvInfo;
+import com.mh.algorithm.utils.CsvUtil;
+import com.mh.algorithm.utils.SerializationUtil;
+import org.junit.Test;
+
+import java.util.Date;
+
+public class bpnnTest {
+ @Test
+ public void test() throws Exception {
+ // 创建训练集矩阵
+ CsvInfo csvInfo = CsvUtil.getCsvInfo(true, "D:\\ljf\\my_pro\\top-algorithm-set-dev\\src\\trainDataElec.csv");
+ Matrix trainSet = csvInfo.toMatrix();
+ // 创建BPNN工厂对象
+ BPNeuralNetworkFactory factory = new BPNeuralNetworkFactory();
+ // 创建BP参数对象
+ BPParameter bpParameter = new BPParameter();
+ bpParameter.setInputLayerNeuronCount(2);
+ bpParameter.setHiddenLayerNeuronCount(2);
+ bpParameter.setOutputLayerNeuronCount(2);
+ bpParameter.setPrecision(0.01);
+ bpParameter.setMaxTimes(100000);
+
+ // 训练BP神经网络
+ System.out.println(new Date());
+ BPModel bpModel = factory.trainBP(bpParameter, trainSet);
+ System.out.println(new Date());
+
+ // 将BPModel序列化到本地
+ SerializationUtil.serialize(bpModel, "elec");
+
+ CsvInfo csvInfo2 = CsvUtil.getCsvInfo(true, "D:\\ljf\\my_pro\\top-algorithm-set-dev\\src\\testDataElec.csv");
+ Matrix testSet = csvInfo2.toMatrix();
+
+ Matrix testData1 = testSet.subMatrix(0, testSet.getMatrixRowCount(), 0, testSet.getMatrixColCount() - 2);
+ Matrix testLabel = testSet.subMatrix(0, testSet.getMatrixRowCount(), testSet.getMatrixColCount() - 2, 1);
+ // 将BPModel反序列化
+ BPModel bpModel1 = (BPModel) SerializationUtil.deSerialization("elec");
+ Matrix result = factory.computeBP(bpModel1, testData1);
+
+ int total = result.getMatrixRowCount();
+ int correct = 0;
+ for (int i = 0; i < result.getMatrixRowCount(); i++) {
+ if(Math.round(result.getValOfIdx(i,0)) == testLabel.getValOfIdx(i,0)){
+ correct++;
+ }
+ }
+ double correctRate = Double.valueOf(correct) / Double.valueOf(total);
+ System.out.println(correctRate);
+ }
+
+ /**
+ * 使用示例
+ * @throws Exception
+ */
+ @Test
+ public void bpnnUsing() throws Exception{
+ CsvInfo csvInfo = CsvUtil.getCsvInfo(false, "D:\\ljf\\my_pro\\top-algorithm-set-dev\\src\\dataElec.csv");
+ Matrix data = csvInfo.toMatrix();
+ // 将BPModel反序列化
+ BPModel bpModel1 = (BPModel) SerializationUtil.deSerialization("elec");
+ // 创建工厂
+ BPNeuralNetworkFactory factory = new BPNeuralNetworkFactory();
+ Matrix result = factory.computeBP(bpModel1, data);
+ CsvUtil.createCsvFile(null,result,"D:\\ljf\\my_pro\\top-algorithm-set-dev\\src\\computeResult.csv");
+ }
+
+}
diff --git a/algorithm/src/test/java/com/mh/algorithm/knn/knnTest.java b/algorithm/src/test/java/com/mh/algorithm/knn/knnTest.java
new file mode 100644
index 0000000..54713a3
--- /dev/null
+++ b/algorithm/src/test/java/com/mh/algorithm/knn/knnTest.java
@@ -0,0 +1,46 @@
+package com.mh.algorithm.knn;
+
+import com.mh.algorithm.matrix.Matrix;
+import com.mh.algorithm.utils.CsvInfo;
+import com.mh.algorithm.utils.CsvUtil;
+import com.mh.algorithm.utils.DoubleUtil;
+import org.junit.Test;
+
+/**
+ * @program: top-algorithm-set
+ * @description:
+ * @author: Mr.Zhao
+ * @create: 2020-10-26 22:04
+ **/
+public class knnTest {
+ @Test
+ public void test() throws Exception {
+ // 训练集
+ CsvInfo csvInfo = CsvUtil.getCsvInfo(false, "E:\\jarTest\\trainData.csv");
+ Matrix trainSet = csvInfo.toMatrix();
+ Matrix trainSetLabels = trainSet.getColOfIdx(trainSet.getMatrixColCount() - 1);
+ Matrix trainSetData = trainSet.subMatrix(0, trainSet.getMatrixRowCount(), 0, trainSet.getMatrixColCount() - 1);
+
+ CsvInfo csvInfo1 = CsvUtil.getCsvInfo(false, "E:\\jarTest\\testData.csv");
+ Matrix testSet = csvInfo1.toMatrix();
+ Matrix testSetData = trainSet.subMatrix(0, testSet.getMatrixRowCount(), 0, testSet.getMatrixColCount() - 1);
+ Matrix testSetLabels = trainSet.getColOfIdx(testSet.getMatrixColCount() - 1);
+
+ // 分类
+ long startTime = System.currentTimeMillis();
+ Matrix result = KNN.classify(testSetData, trainSetData, trainSetLabels, 5);
+ long endTime = System.currentTimeMillis();
+ System.out.println("run time:" + (endTime - startTime));
+ // 正确率
+ Matrix error = result.subtract(testSetLabels);
+ int total = error.getMatrixRowCount();
+ int correct = 0;
+ for (int i = 0; i < error.getMatrixRowCount(); i++) {
+ if (DoubleUtil.equals(error.getValOfIdx(i, 0), 0.0)) {
+ correct++;
+ }
+ }
+ double correctRate = Double.valueOf(correct) / Double.valueOf(total);
+ System.out.println("correctRate:"+ correctRate);
+ }
+}
diff --git a/pom.xml b/pom.xml
index 7c03fe1..0a21d9b 100644
--- a/pom.xml
+++ b/pom.xml
@@ -11,6 +11,7 @@
common
user-service
+ algorithm
diff --git a/user-service/src/main/java/com/mh/user/controller/ControlSetController.java b/user-service/src/main/java/com/mh/user/controller/ControlSetController.java
index fae8cdc..cb1af68 100644
--- a/user-service/src/main/java/com/mh/user/controller/ControlSetController.java
+++ b/user-service/src/main/java/com/mh/user/controller/ControlSetController.java
@@ -46,11 +46,11 @@ public class ControlSetController {
}
//查询设置表
- @SysLogger(title="控制设置",optDesc = "查询设置值")
+ @SysLogger(title="控制设置",optDesc = "查询时控设置值")
@PostMapping(value="/query")
- public HttpResult queryControlSet(@RequestParam("buildingId") String buildingId) {
+ public HttpResult queryControlSet(@RequestParam("buildingId") String buildingId, @RequestParam(value = "timeName",required = false) String timeName) {
try{
- ControlSetEntity control=controlSetService.queryControlSet(buildingId);
+ ControlSetEntity control=controlSetService.queryControlSet(buildingId, timeName);
return HttpResult.ok(control);
}catch (Exception e){
// e.printStackTrace();
diff --git a/user-service/src/main/java/com/mh/user/controller/SerialPortController.java b/user-service/src/main/java/com/mh/user/controller/SerialPortController.java
index 8e2b0f0..7065846 100644
--- a/user-service/src/main/java/com/mh/user/controller/SerialPortController.java
+++ b/user-service/src/main/java/com/mh/user/controller/SerialPortController.java
@@ -424,9 +424,9 @@ public class SerialPortController {
}
@PostMapping(value = "/control")
- public HttpResult queryControlSet(@RequestParam(value = "buildingId") String buildingId) {
+ public HttpResult queryControlSet(@RequestParam(value = "buildingId") String buildingId, @RequestParam(value = "timeName", required = false) String timeName) {
try {
- ControlSetEntity list = controlSetService.queryControlSet(buildingId);
+ ControlSetEntity list = controlSetService.queryControlSet(buildingId, timeName);
return HttpResult.ok(list);
} catch (Exception e) {
// e.printStackTrace();
diff --git a/user-service/src/main/java/com/mh/user/entity/MaintainInfoEntity.java b/user-service/src/main/java/com/mh/user/entity/MaintainInfoEntity.java
index 7cee919..518843d 100644
--- a/user-service/src/main/java/com/mh/user/entity/MaintainInfoEntity.java
+++ b/user-service/src/main/java/com/mh/user/entity/MaintainInfoEntity.java
@@ -18,6 +18,7 @@ public class MaintainInfoEntity {
private String maintainPeople; //维护人员
private Double cost; //费用
private String contents; //维保内容
+ private String evaluate; // 评价分数
}
diff --git a/user-service/src/main/java/com/mh/user/mapper/ControlSetMapper.java b/user-service/src/main/java/com/mh/user/mapper/ControlSetMapper.java
index e5c4a37..a410a2a 100644
--- a/user-service/src/main/java/com/mh/user/mapper/ControlSetMapper.java
+++ b/user-service/src/main/java/com/mh/user/mapper/ControlSetMapper.java
@@ -60,6 +60,21 @@ public interface ControlSetMapper {
@Result(column = "back_water_temp", property = "backWaterTemp"),
@Result(column = "up_water_temp", property = "upWaterTemp"),
})
- @Select("select * from control_Set where building_id=#{buildingId}")
- ControlSetEntity queryControlSet(@Param("buildingId") String buildingId);
+ @Select("select " +
+ " top 1 " +
+ " * " +
+ "from " +
+ " control_Set " +
+ "where " +
+ " building_id = #{buildingId} " +
+ " and exists ( " +
+ " select " +
+ " 1 " +
+ " from " +
+ " device_install di " +
+ " where " +
+ " di.building_id = building_id " +
+ " and di.device_name like concat(#{timeName}, '时控') " +
+ ") ")
+ ControlSetEntity queryControlSet(@Param("buildingId") String buildingId, @Param("timeName") String timeName);
}
diff --git a/user-service/src/main/java/com/mh/user/mapper/MaintainInfoMapper.java b/user-service/src/main/java/com/mh/user/mapper/MaintainInfoMapper.java
index c94e90f..6de88fe 100644
--- a/user-service/src/main/java/com/mh/user/mapper/MaintainInfoMapper.java
+++ b/user-service/src/main/java/com/mh/user/mapper/MaintainInfoMapper.java
@@ -18,8 +18,8 @@ public interface MaintainInfoMapper {
* 维修保养信息
* @param maintainInfoEntity
*/
- @Insert("insert into maintain_info(cur_date,building_id,device_type,device_addr,maintain_type,maintain_people,cost,contents) values (" +
- " getDate(),#{buildingId},#{deviceType},#{deviceAddr},#{maintainType},#{maintainPeople},#{cost},#{contents})")
+ @Insert("insert into maintain_info(cur_date,building_id,device_type,device_addr,maintain_type,maintain_people,cost,contents, evaluate) values (" +
+ " getDate(),#{buildingId},#{deviceType},#{deviceAddr},#{maintainType},#{maintainPeople},#{cost},#{contents}, #{evaluate})")
int saveMaintainInfo(MaintainInfoEntity maintainInfoEntity);
/**
@@ -36,6 +36,7 @@ public interface MaintainInfoMapper {
" , maintain_people = #{maintainPeople} " +
" , cost = #{cost} " +
" , contents = #{contents} " +
+ " , evaluate = #{evaluate} " +
" where id = #{id} " +
"")
int updateMaintainInfo(MaintainInfoEntity maintainInfoEntity);
@@ -61,7 +62,8 @@ public interface MaintainInfoMapper {
@Result(property="maintainPeople",column="maintain_people"),
@Result(property="id",column="id"),
@Result(property="cost",column="cost"),
- @Result(property="contents",column="contents")
+ @Result(property="contents",column="contents"),
+ @Result(property="evaluate",column="evaluate")
})
List queryMaintainInfo(@Param("curDate") String curDate,
@Param("buildingId") String buildingId,
diff --git a/user-service/src/main/java/com/mh/user/service/ControlSetService.java b/user-service/src/main/java/com/mh/user/service/ControlSetService.java
index b1e01ce..0862a03 100644
--- a/user-service/src/main/java/com/mh/user/service/ControlSetService.java
+++ b/user-service/src/main/java/com/mh/user/service/ControlSetService.java
@@ -6,5 +6,5 @@ public interface ControlSetService {
void saveControlSet(ControlSetEntity controlSetEntity);
- ControlSetEntity queryControlSet(String buildingId);
+ ControlSetEntity queryControlSet(String buildingId, String timeName);
}
diff --git a/user-service/src/main/java/com/mh/user/service/impl/ControlSetServiceImpl.java b/user-service/src/main/java/com/mh/user/service/impl/ControlSetServiceImpl.java
index 5ee0f39..fdeac96 100644
--- a/user-service/src/main/java/com/mh/user/service/impl/ControlSetServiceImpl.java
+++ b/user-service/src/main/java/com/mh/user/service/impl/ControlSetServiceImpl.java
@@ -1,5 +1,6 @@
package com.mh.user.service.impl;
+import com.mh.common.utils.StringUtils;
import com.mh.user.entity.ControlSetEntity;
import com.mh.user.mapper.ControlSetMapper;
import com.mh.user.service.ControlSetService;
@@ -26,8 +27,10 @@ public class ControlSetServiceImpl implements ControlSetService {
}
@Override
- public ControlSetEntity queryControlSet(String buildingId) {
-
- return controlSetMapper.queryControlSet(buildingId);
+ public ControlSetEntity queryControlSet(String buildingId, String timeName) {
+ if (StringUtils.isBlank(timeName)) {
+ timeName = "%";
+ }
+ return controlSetMapper.queryControlSet(buildingId,timeName);
}
}
diff --git a/user-service/src/main/java/com/mh/user/service/impl/DeviceControlServiceImpl.java b/user-service/src/main/java/com/mh/user/service/impl/DeviceControlServiceImpl.java
index 70ca9af..12dde02 100644
--- a/user-service/src/main/java/com/mh/user/service/impl/DeviceControlServiceImpl.java
+++ b/user-service/src/main/java/com/mh/user/service/impl/DeviceControlServiceImpl.java
@@ -11,6 +11,7 @@ import com.mh.user.model.DeviceModel;
import com.mh.user.model.SerialPortModel;
import com.mh.user.serialport.SerialPortSingle2;
import com.mh.user.service.*;
+import com.mh.user.utils.ExchangeStringUtil;
import lombok.extern.slf4j.Slf4j;
import org.springframework.stereotype.Service;
@@ -385,6 +386,11 @@ public class DeviceControlServiceImpl implements DeviceControlService {
} else {
deviceCodeParam.setRegisterAddr("00240000"); //寄存器地址
}
+ } else if ("startOrStop".equals(deviceCodeParam.getParam())) {
+ deviceCodeParam.setFunCode("06"); //功能码读数据
+ if ("瑞星".equals(deviceCodeParam.getBrand())) {
+ deviceCodeParam.setRegisterAddr("0001"+ ExchangeStringUtil.addZeroForNum(deviceCodeParam.getDataValue(), 4));
+ }
}
return rtData;
}