Commit 94f90604 by lichenggang

将封装添加进模型,且其值为参数

parent efc76efc
No preview for this file type
No preview for this file type
......@@ -7,6 +7,9 @@ from utils.robots import dd_send_msg
import pandas as pd
RIGHT_LEVEL = 0.7
SEQ_LEVEL = 0.5
class BasePredictor:
'''
预测类基类
......@@ -21,16 +24,16 @@ class BasePredictor:
self.robot_msg = dd_send_msg
self.pd = pd
def predict(self, key):
raise NotImplementedError
def get_single_predict(self, string: object) -> tuple:
def get_single_predict(self, obj: object) -> tuple:
'''
:param string: 接收单个要判断的字符串
:return tuple: 返回两个元素的元组, 第一个元素为判断结果, 第二个元素为可能性的字典
'''
string = str(obj)
series = self.pd.Series([string])
feature = self.extractor.transform(series)
predictions = self.classifier.predict(feature)
......@@ -44,9 +47,12 @@ class BasePredictor:
"""
序号列预测
"""
collect_order = [int(kw) for kw in data if isinstance(kw, float) or isinstance(kw, int)]
judge = self.isIncrease(collect_order, len(collect_order)) if len(collect_order) > 0 else False
return judge
collect_seq = [int(kw) for kw in data if isinstance(kw, float) or isinstance(kw, int)]
if len(collect_seq) < 1 or not self.isIncrease(collect_seq, len(collect_seq)):
return False
else:
rate = round(len(collect_seq) / len(data), 3)
return True if rate >= SEQ_LEVEL else False
def isnum(self, data):
"""
......@@ -64,7 +70,6 @@ class BasePredictor:
return True
return (arr[size - 1] >= arr[size - 2]) and self.isIncrease(arr, size - 1)
def isNumberCol(self, kw):
"""
是否是数量列
......
......@@ -24,6 +24,7 @@ MUST_STD_FIELDS = ['参数', '数量']
order_list = ['序号']
en_to_zh_map = {'brand_name': '品牌', 'param': '参数', 'goods_name': '型号'}
def fun(seri):
li_seri = seri.tolist()
for field in li_seri:
......@@ -92,11 +93,10 @@ class DicPredict(BasePredictor):
def model_predict(self, dic_data):
"""
该方法目前只对[参数, 型号, 数量, 品牌]进行预测
该方法目前只对[序号, 参数, 型号, 数量, 品牌]进行预测, 只有[参数, 数量]会强制有结果
:param dic_data:
:return:
"""
print(dic_data)
prob_columns = []
temp_pre_model_res = {}
......@@ -111,7 +111,7 @@ class DicPredict(BasePredictor):
continue
if self.isseq(v):
temp_pre_model_res[k] = '序号'
temp_dic_data = {k: list(filter(lambda x: x != PLACEHOLDER, dic_data[k]))for k in prob_columns}
temp_dic_data = {k: list(filter(lambda x: x != PLACEHOLDER, dic_data[k])) for k in prob_columns}
for k, v in temp_dic_data.items():
li_single_pred_res = []
for string in v:
......@@ -120,9 +120,8 @@ class DicPredict(BasePredictor):
result = Counter(li_single_pred_res)
# [('brand_name', 4), ('goods_name', 3), ('param', 2)]
li_sort = sorted(result.items(), key=lambda x: x[1], reverse=True)
print(k, li_sort)
temp_pre_model_res[k] = en_to_zh_map[li_sort[0][0]]
print(temp_pre_model_res)
sort_temp_pre_model_res = sorted(temp_pre_model_res.keys())
# pre_model_res = {}
......@@ -138,6 +137,7 @@ class DicPredict(BasePredictor):
model_id_res = self.model_predict(dic_data)
if model_id_res:
return model_id_res
elif predict_type == 'pre':
pre_id_res = self.pre_predict(dic_data)
if pre_id_res:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment