291 lines
9.9 KiB
Diff
291 lines
9.9 KiB
Diff
From 65d2dc509e851d138154ee6a8a3ff3acb3780a30 Mon Sep 17 00:00:00 2001
|
|
From: gaoruoshu <gaoruoshu@huawei.com>
|
|
Date: Wed, 9 Aug 2023 19:15:07 +0800
|
|
Subject: [PATCH 2/3] bugfix: trained model can only be saved to the specified
|
|
dir
|
|
|
|
---
|
|
analysis/default_config.py | 1 +
|
|
analysis/engine/parser.py | 4 +--
|
|
analysis/engine/train.py | 38 ++++++++++++++++++----
|
|
api/profile/profile.pb.go | 6 ++--
|
|
api/profile/profile.proto | 2 +-
|
|
common/models/training.go | 2 +-
|
|
modules/client/profile/profile_train.go | 42 +++++++++----------------
|
|
modules/server/profile/profile.go | 4 +--
|
|
8 files changed, 56 insertions(+), 43 deletions(-)
|
|
|
|
diff --git a/analysis/default_config.py b/analysis/default_config.py
|
|
index cf56ac2..7c921cc 100644
|
|
--- a/analysis/default_config.py
|
|
+++ b/analysis/default_config.py
|
|
@@ -23,6 +23,7 @@ GRPC_CERT_PATH = '/etc/atuned/grpc_certs'
|
|
ANALYSIS_DATA_PATH = '/var/atune_data/analysis/'
|
|
TUNING_DATA_PATH = '/var/atune_data/tuning/'
|
|
TUNING_DATA_DIRS = ['running', 'finished', 'error']
|
|
+TRAINING_MODEL_PATH = '/usr/libexec/atuned/analysis/models/'
|
|
|
|
|
|
def get_or_default(config, section, key, value):
|
|
diff --git a/analysis/engine/parser.py b/analysis/engine/parser.py
|
|
index c16089f..c36c74d 100644
|
|
--- a/analysis/engine/parser.py
|
|
+++ b/analysis/engine/parser.py
|
|
@@ -69,8 +69,8 @@ CLASSIFICATION_POST_PARSER.add_argument('model',
|
|
TRAIN_POST_PARSER = reqparse.RequestParser()
|
|
TRAIN_POST_PARSER.add_argument('datapath', required=True,
|
|
help="The datapath can not be null")
|
|
-TRAIN_POST_PARSER.add_argument('outputpath', required=True,
|
|
- help="The output path can not be null")
|
|
+TRAIN_POST_PARSER.add_argument('modelname', required=True,
|
|
+ help="The model name can not be null")
|
|
TRAIN_POST_PARSER.add_argument('modelpath', required=True,
|
|
help="The model path can not be null")
|
|
|
|
diff --git a/analysis/engine/train.py b/analysis/engine/train.py
|
|
index 9fdca46..7608660 100644
|
|
--- a/analysis/engine/train.py
|
|
+++ b/analysis/engine/train.py
|
|
@@ -22,6 +22,7 @@ from flask_restful import Resource
|
|
|
|
from analysis.optimizer.workload_characterization import WorkloadCharacterization
|
|
from analysis.engine.parser import TRAIN_POST_PARSER
|
|
+from analysis.default_config import TRAINING_MODEL_PATH
|
|
|
|
LOGGER = logging.getLogger(__name__)
|
|
|
|
@@ -29,7 +30,7 @@ LOGGER = logging.getLogger(__name__)
|
|
class Training(Resource):
|
|
"""provide the method of post for training"""
|
|
model_path = "modelpath"
|
|
- output_path = "outputpath"
|
|
+ model_name = "modelname"
|
|
data_path = "datapath"
|
|
|
|
def post(self):
|
|
@@ -40,18 +41,43 @@ class Training(Resource):
|
|
LOGGER.info(args)
|
|
|
|
model_path = args.get(self.model_path)
|
|
- output_path = args.get(self.output_path)
|
|
+ model_name = args.get(self.model_name)
|
|
data_path = args.get(self.data_path)
|
|
|
|
+ valid, err = valid_model_name(model_name)
|
|
+ if not valid:
|
|
+ return "Illegal model name provide: {}".format(err), 400
|
|
+
|
|
characterization = WorkloadCharacterization(model_path)
|
|
try:
|
|
+ output_path = TRAINING_MODEL_PATH + model_name
|
|
characterization.retrain(data_path, output_path)
|
|
except Exception as err:
|
|
LOGGER.error(err)
|
|
abort(500)
|
|
|
|
- if os.path.isdir(data_path):
|
|
- shutil.rmtree(data_path)
|
|
- else:
|
|
- os.remove(data_path)
|
|
return {}, 200
|
|
+
|
|
+
|
|
+def valid_model_name(name):
|
|
+ file_name, file_ext = os.path.splitext(name)
|
|
+
|
|
+ if file_ext != ".m":
|
|
+ return False, "the ext of model name should be .m"
|
|
+
|
|
+ if file_name in ['scaler', 'aencoder', 'tencoder', 'default_clf', 'total_clf', 'throughput_performance_clf']:
|
|
+ return False, "model name cannot be set as default_clf/scaler/aencoder/tencoder/throughput_performance_clf/total_clf"
|
|
+
|
|
+ for ind, char in enumerate(file_name):
|
|
+ if 'a' <= char <= 'z':
|
|
+ continue
|
|
+ if 'A' <= char <= 'Z':
|
|
+ continue
|
|
+ if '0' <= char <= '9':
|
|
+ continue
|
|
+ if ind != 0 and ind != len(file_name) - 1 and char == '_':
|
|
+ continue
|
|
+ return False, "model name cannot contains special character"
|
|
+
|
|
+ return True, None
|
|
+
|
|
diff --git a/api/profile/profile.pb.go b/api/profile/profile.pb.go
|
|
index dba166d..81a6757 100644
|
|
--- a/api/profile/profile.pb.go
|
|
+++ b/api/profile/profile.pb.go
|
|
@@ -493,7 +493,7 @@ func (m *CollectFlag) GetType() string {
|
|
|
|
type TrainMessage struct {
|
|
DataPath string `protobuf:"bytes,1,opt,name=DataPath,proto3" json:"DataPath,omitempty"`
|
|
- OutputPath string `protobuf:"bytes,2,opt,name=OutputPath,proto3" json:"OutputPath,omitempty"`
|
|
+ ModelName string `protobuf:"bytes,2,opt,name=ModelName,proto3" json:"ModelName,omitempty"`
|
|
XXX_NoUnkeyedLiteral struct{} `json:"-"`
|
|
XXX_unrecognized []byte `json:"-"`
|
|
XXX_sizecache int32 `json:"-"`
|
|
@@ -531,9 +531,9 @@ func (m *TrainMessage) GetDataPath() string {
|
|
return ""
|
|
}
|
|
|
|
-func (m *TrainMessage) GetOutputPath() string {
|
|
+func (m *TrainMessage) GetModelName() string {
|
|
if m != nil {
|
|
- return m.OutputPath
|
|
+ return m.ModelName
|
|
}
|
|
return ""
|
|
}
|
|
diff --git a/api/profile/profile.proto b/api/profile/profile.proto
|
|
index 2a6e751..29cf91e 100755
|
|
--- a/api/profile/profile.proto
|
|
+++ b/api/profile/profile.proto
|
|
@@ -84,7 +84,7 @@ message CollectFlag {
|
|
|
|
message TrainMessage {
|
|
string DataPath = 1;
|
|
- string OutputPath = 2;
|
|
+ string ModelName = 2;
|
|
}
|
|
|
|
message DetectMessage {
|
|
diff --git a/common/models/training.go b/common/models/training.go
|
|
index 3cc9d60..9497261 100644
|
|
--- a/common/models/training.go
|
|
+++ b/common/models/training.go
|
|
@@ -24,7 +24,7 @@ import (
|
|
type Training struct {
|
|
DataPath string `json:"datapath"`
|
|
ModelPath string `json:"modelpath"`
|
|
- OutputPath string `json:"outputpath"`
|
|
+ ModelName string `json:"modelname"`
|
|
}
|
|
|
|
// Post method call training service
|
|
diff --git a/modules/client/profile/profile_train.go b/modules/client/profile/profile_train.go
|
|
index f4e68cb..a645bc3 100644
|
|
--- a/modules/client/profile/profile_train.go
|
|
+++ b/modules/client/profile/profile_train.go
|
|
@@ -18,9 +18,9 @@ import (
|
|
"gitee.com/openeuler/A-Tune/common/client"
|
|
SVC "gitee.com/openeuler/A-Tune/common/service"
|
|
"gitee.com/openeuler/A-Tune/common/utils"
|
|
+ "gitee.com/openeuler/A-Tune/common/config"
|
|
"fmt"
|
|
"io"
|
|
- "os"
|
|
"path/filepath"
|
|
|
|
"github.com/urfave/cli"
|
|
@@ -39,8 +39,8 @@ var trainCommand = cli.Command{
|
|
Value: "",
|
|
},
|
|
cli.StringFlag{
|
|
- Name: "output_file,o",
|
|
- Usage: "the model to be generated",
|
|
+ Name: "model_name,m",
|
|
+ Usage: "the model name of generate model",
|
|
Value: "",
|
|
},
|
|
},
|
|
@@ -48,9 +48,9 @@ var trainCommand = cli.Command{
|
|
desc := `
|
|
training a new model with the self collected data, data_path option specified
|
|
the path that storage the collected data, the collected data must have more
|
|
- than two workload type. output_file specified the file path where to store
|
|
+ than two workload type. model_name specified the name of model to be generated
|
|
the trained model, which must be end with .m.
|
|
- example: atune-adm train --data_path=./data --output_file=./model/trained.m`
|
|
+ example: atune-adm train --data_path=/home/data --model_name=trained.m`
|
|
return desc
|
|
}(),
|
|
Action: train,
|
|
@@ -82,13 +82,13 @@ func checkTrainCtx(ctx *cli.Context) error {
|
|
return fmt.Errorf("input:%s is invalid", dataPath)
|
|
}
|
|
|
|
- outputPath := ctx.String("output_file")
|
|
- if outputPath == "" {
|
|
+ modelName := ctx.String("model_name")
|
|
+ if modelName == "" {
|
|
_ = cli.ShowCommandHelp(ctx, "train")
|
|
- return fmt.Errorf("error: output_file must be specified")
|
|
+ return fmt.Errorf("error: model_name must be specified")
|
|
}
|
|
- if !utils.IsInputStringValid(outputPath) {
|
|
- return fmt.Errorf("input:%s is invalid", outputPath)
|
|
+ if !utils.IsInputStringValid(modelName) {
|
|
+ return fmt.Errorf("input:%s is invalid", modelName)
|
|
}
|
|
|
|
return nil
|
|
@@ -103,23 +103,8 @@ func train(ctx *cli.Context) error {
|
|
if err != nil {
|
|
return err
|
|
}
|
|
- outputPath, err := filepath.Abs(ctx.String("output_file"))
|
|
- if err != nil {
|
|
- return err
|
|
- }
|
|
-
|
|
- dir := filepath.Dir(outputPath)
|
|
|
|
- exist, err := utils.PathExist(dir)
|
|
- if err != nil {
|
|
- return err
|
|
- }
|
|
- if !exist {
|
|
- err = os.MkdirAll(dir, utils.FilePerm)
|
|
- if err != nil {
|
|
- return err
|
|
- }
|
|
- }
|
|
+ modelName := ctx.String("model_name")
|
|
|
|
c, err := client.NewClientFromContext(ctx)
|
|
if err != nil {
|
|
@@ -128,7 +113,7 @@ func train(ctx *cli.Context) error {
|
|
defer c.Close()
|
|
|
|
svc := PB.NewProfileMgrClient(c.Connection())
|
|
- stream, err := svc.Training(CTX.Background(), &PB.TrainMessage{DataPath: dataPath, OutputPath: outputPath})
|
|
+ stream, err := svc.Training(CTX.Background(), &PB.TrainMessage{DataPath: dataPath, ModelName: modelName})
|
|
if err != nil {
|
|
return err
|
|
}
|
|
@@ -145,6 +130,7 @@ func train(ctx *cli.Context) error {
|
|
}
|
|
utils.Print(reply)
|
|
}
|
|
- fmt.Println("the model generate path:", outputPath)
|
|
+ modelPath := fmt.Sprintf("%s/models/%s", config.DefaultAnalysisPath, modelName)
|
|
+ fmt.Println("the model generate path:", modelPath)
|
|
return nil
|
|
}
|
|
diff --git a/modules/server/profile/profile.go b/modules/server/profile/profile.go
|
|
index 3264167..70047ca 100644
|
|
--- a/modules/server/profile/profile.go
|
|
+++ b/modules/server/profile/profile.go
|
|
@@ -1137,7 +1137,7 @@ func (s *ProfileServer) Training(message *PB.TrainMessage, stream PB.ProfileMgr_
|
|
}
|
|
|
|
DataPath := message.GetDataPath()
|
|
- OutputPath := message.GetOutputPath()
|
|
+ ModelName := message.GetModelName()
|
|
|
|
compressPath, err := utils.CreateCompressFile(DataPath)
|
|
if err != nil {
|
|
@@ -1156,7 +1156,7 @@ func (s *ProfileServer) Training(message *PB.TrainMessage, stream PB.ProfileMgr_
|
|
|
|
trainBody := new(models.Training)
|
|
trainBody.DataPath = trainPath
|
|
- trainBody.OutputPath = OutputPath
|
|
+ trainBody.ModelName = ModelName
|
|
trainBody.ModelPath = path.Join(config.DefaultAnalysisPath, "models")
|
|
|
|
success, err := trainBody.Post()
|
|
--
|
|
2.27.0
|
|
|