Commit 94f90604 by lichenggang

将封装添加进模型,且其值为参数

parent efc76efc
No preview for this file type
No preview for this file type
...@@ -7,6 +7,9 @@ from utils.robots import dd_send_msg ...@@ -7,6 +7,9 @@ from utils.robots import dd_send_msg
import pandas as pd import pandas as pd
RIGHT_LEVEL = 0.7 RIGHT_LEVEL = 0.7
SEQ_LEVEL = 0.5
class BasePredictor: class BasePredictor:
''' '''
预测类基类 预测类基类
...@@ -21,16 +24,16 @@ class BasePredictor: ...@@ -21,16 +24,16 @@ class BasePredictor:
self.robot_msg = dd_send_msg self.robot_msg = dd_send_msg
self.pd = pd self.pd = pd
def predict(self, key): def predict(self, key):
raise NotImplementedError raise NotImplementedError
def get_single_predict(self, string: object) -> tuple: def get_single_predict(self, obj: object) -> tuple:
''' '''
:param string: 接收单个要判断的字符串 :param string: 接收单个要判断的字符串
:return tuple: 返回两个元素的元组, 第一个元素为判断结果, 第二个元素为可能性的字典 :return tuple: 返回两个元素的元组, 第一个元素为判断结果, 第二个元素为可能性的字典
''' '''
string = str(obj)
series = self.pd.Series([string]) series = self.pd.Series([string])
feature = self.extractor.transform(series) feature = self.extractor.transform(series)
predictions = self.classifier.predict(feature) predictions = self.classifier.predict(feature)
...@@ -44,9 +47,12 @@ class BasePredictor: ...@@ -44,9 +47,12 @@ class BasePredictor:
""" """
序号列预测 序号列预测
""" """
collect_order = [int(kw) for kw in data if isinstance(kw, float) or isinstance(kw, int)] collect_seq = [int(kw) for kw in data if isinstance(kw, float) or isinstance(kw, int)]
judge = self.isIncrease(collect_order, len(collect_order)) if len(collect_order) > 0 else False if len(collect_seq) < 1 or not self.isIncrease(collect_seq, len(collect_seq)):
return judge return False
else:
rate = round(len(collect_seq) / len(data), 3)
return True if rate >= SEQ_LEVEL else False
def isnum(self, data): def isnum(self, data):
""" """
...@@ -64,7 +70,6 @@ class BasePredictor: ...@@ -64,7 +70,6 @@ class BasePredictor:
return True return True
return (arr[size - 1] >= arr[size - 2]) and self.isIncrease(arr, size - 1) return (arr[size - 1] >= arr[size - 2]) and self.isIncrease(arr, size - 1)
def isNumberCol(self, kw): def isNumberCol(self, kw):
""" """
是否是数量列 是否是数量列
......
...@@ -24,6 +24,7 @@ MUST_STD_FIELDS = ['参数', '数量'] ...@@ -24,6 +24,7 @@ MUST_STD_FIELDS = ['参数', '数量']
order_list = ['序号'] order_list = ['序号']
en_to_zh_map = {'brand_name': '品牌', 'param': '参数', 'goods_name': '型号'} en_to_zh_map = {'brand_name': '品牌', 'param': '参数', 'goods_name': '型号'}
def fun(seri): def fun(seri):
li_seri = seri.tolist() li_seri = seri.tolist()
for field in li_seri: for field in li_seri:
...@@ -92,11 +93,10 @@ class DicPredict(BasePredictor): ...@@ -92,11 +93,10 @@ class DicPredict(BasePredictor):
def model_predict(self, dic_data): def model_predict(self, dic_data):
""" """
该方法目前只对[参数, 型号, 数量, 品牌]进行预测 该方法目前只对[序号, 参数, 型号, 数量, 品牌]进行预测, 只有[参数, 数量]会强制有结果
:param dic_data: :param dic_data:
:return: :return:
""" """
print(dic_data)
prob_columns = [] prob_columns = []
temp_pre_model_res = {} temp_pre_model_res = {}
...@@ -111,7 +111,7 @@ class DicPredict(BasePredictor): ...@@ -111,7 +111,7 @@ class DicPredict(BasePredictor):
continue continue
if self.isseq(v): if self.isseq(v):
temp_pre_model_res[k] = '序号' temp_pre_model_res[k] = '序号'
temp_dic_data = {k: list(filter(lambda x: x != PLACEHOLDER, dic_data[k]))for k in prob_columns} temp_dic_data = {k: list(filter(lambda x: x != PLACEHOLDER, dic_data[k])) for k in prob_columns}
for k, v in temp_dic_data.items(): for k, v in temp_dic_data.items():
li_single_pred_res = [] li_single_pred_res = []
for string in v: for string in v:
...@@ -120,9 +120,8 @@ class DicPredict(BasePredictor): ...@@ -120,9 +120,8 @@ class DicPredict(BasePredictor):
result = Counter(li_single_pred_res) result = Counter(li_single_pred_res)
# [('brand_name', 4), ('goods_name', 3), ('param', 2)] # [('brand_name', 4), ('goods_name', 3), ('param', 2)]
li_sort = sorted(result.items(), key=lambda x: x[1], reverse=True) li_sort = sorted(result.items(), key=lambda x: x[1], reverse=True)
print(k, li_sort)
temp_pre_model_res[k] = en_to_zh_map[li_sort[0][0]] temp_pre_model_res[k] = en_to_zh_map[li_sort[0][0]]
print(temp_pre_model_res) sort_temp_pre_model_res = sorted(temp_pre_model_res.keys())
# pre_model_res = {} # pre_model_res = {}
...@@ -138,6 +137,7 @@ class DicPredict(BasePredictor): ...@@ -138,6 +137,7 @@ class DicPredict(BasePredictor):
model_id_res = self.model_predict(dic_data) model_id_res = self.model_predict(dic_data)
if model_id_res: if model_id_res:
return model_id_res return model_id_res
elif predict_type == 'pre': elif predict_type == 'pre':
pre_id_res = self.pre_predict(dic_data) pre_id_res = self.pre_predict(dic_data)
if pre_id_res: if pre_id_res:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment