Commit ebf97191 by lichenggang

增加kw_predict

parent c82b2a08
...@@ -18,12 +18,14 @@ class Classify(classify_pb2_grpc.classifyServicer): ...@@ -18,12 +18,14 @@ class Classify(classify_pb2_grpc.classifyServicer):
def Classify(self, request, context): def Classify(self, request, context):
try: try:
dic_data = json.loads(request.keyword) dic_data = json.loads(request.keyword)
res = self.predictorfac.predict(dic_data) res = self.predictorfac.predict(dic_data)
res['status'] = 1 res['status'] = 1
except: except json.decoder.JSONDecodeError:
res = self.predictorfac.predict(request.keyword)
res['status'] = 1
except :
res = { res = {
'status': 0, 'status': 0,
'err_msg': '未知' 'err_msg': '未知'
......
...@@ -14,10 +14,11 @@ class PredictorFac(): ...@@ -14,10 +14,11 @@ class PredictorFac():
self.dict_predictor = dict_predict.DicPredict('dict', extractor, classifier) self.dict_predictor = dict_predict.DicPredict('dict', extractor, classifier)
def _get_predictor(self, data): def _get_predictor(self, data):
if isinstance(data, str):
return self.kw_predictor if isinstance(data, dict):
elif isinstance(data, dict):
return self.dict_predictor return self.dict_predictor
else:
return self.kw_predictor
def predict(self, dic_data, predict_type='all'): def predict(self, dic_data, predict_type='all'):
predictor = self._get_predictor(dic_data) predictor = self._get_predictor(dic_data)
...@@ -29,7 +30,7 @@ if __name__ == "__main__": ...@@ -29,7 +30,7 @@ if __name__ == "__main__":
def get_test_data(): def get_test_data():
import pandas as pd import pandas as pd
import json import json
df = pd.read_excel(r'C:\Users\ICHUNT\Desktop\bomlist\51AB0571_ CCTV ASST询价_SZIMS.xlsx', header=None, sheet_name='3') df = pd.read_excel(r'C:\Users\ICHUNT\Desktop\bomlist\立创.xlsx', header=None)
df.fillna('?', inplace=True) df.fillna('?', inplace=True)
dic_dft = df.to_dict(orient='list') dic_dft = df.to_dict(orient='list')
return json.dumps(dic_dft) return json.dumps(dic_dft)
...@@ -40,6 +41,7 @@ if __name__ == "__main__": ...@@ -40,6 +41,7 @@ if __name__ == "__main__":
data = get_test_data() data = get_test_data()
p = PredictorFac(model_config) p = PredictorFac(model_config)
data = json.loads(data) data = json.loads(data)
print(data) pretty_col = {'第%s列' % k: v for k, v in data.items()}
res = p.predict(data,predict_type='all') print(pretty_col)
res = p.predict(pretty_col, predict_type='all')
print(res) print(res)
...@@ -137,7 +137,7 @@ class DicPredict(BasePredictor): ...@@ -137,7 +137,7 @@ class DicPredict(BasePredictor):
return model_id_res return model_id_res
def predict(self, dic_data, predict_type='all'): def predict(self, dic_data, predict_type='all'):
self.info.info('预测类型为%s接收数据: ' % predict_type + str(dic_data)) self.info.info('预测类型为: %s, 接收数据: %s' % (predict_type, str(dic_data)))
dic_data = self.pre_deal(dic_data) dic_data = self.pre_deal(dic_data)
if predict_type == 'all': if predict_type == 'all':
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
from predict.base_handler import BasePredictor from predict.base_handler import BasePredictor
from static_config import EN_TO_ZH_MAP
class KwPredict(BasePredictor): class KwPredict(BasePredictor):
def predict(self, key): def predict(self, string, predict_type='all'):
res, prab = self.get_single_predict(key) self.info.info('预测类型为: %s, 接收数据: %s' % (predict_type, string))
result, prab = self.get_single_predict(string)
zh_prob = {EN_TO_ZH_MAP[k]: v for k, v in prab.items()}
res = {'result': EN_TO_ZH_MAP[result], 'probably_dict': zh_prob}
return res return res
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment