Commit d8dd3982 by lichenggang

模型预测

parent f716998b
Showing with 57 additions and 15 deletions
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
from collections import Counter
from predict.base_handler import BasePredictor from predict.base_handler import BasePredictor
# 可能的头部字段 # 可能的头部字段
...@@ -16,9 +18,11 @@ fields_map = {"序号": ["序号"], ...@@ -16,9 +18,11 @@ fields_map = {"序号": ["序号"],
"封装": ["封装", "封装规格"], "封装": ["封装", "封装规格"],
"品牌": ["品牌", "品牌/厂商", "参考品牌", "参考供应商", "厂商", "参考供应商", "参考厂商"]} "品牌": ["品牌", "品牌/厂商", "参考品牌", "参考供应商", "厂商", "参考供应商", "参考厂商"]}
# Standard fields that must be both returned and validated.
MUST_STD_FIELDS = ['参数', '数量']
# #
order_list = ['序号'] order_list = ['序号']
# Maps the English labels looked up in DicPredict.model_predict (the first
# element returned by get_single_predict) to the Chinese standard field names.
en_to_zh_map = {'brand_name': '品牌', 'param': '参数', 'goods_name': '型号'}
def fun(seri): def fun(seri):
li_seri = seri.tolist() li_seri = seri.tolist()
...@@ -31,7 +35,8 @@ def fun(seri): ...@@ -31,7 +35,8 @@ def fun(seri):
HEAD_ROW = 7 HEAD_ROW = 7
# 空置率阈值 # 空置率阈值
NAN_RATE = 0.8 NAN_RATE = 0.8
# Placeholder string: pre_deal counts cells equal to it when measuring how
# empty a column is, and model_predict filters such cells out before voting.
PLACEHOLDER = '?'
class DicPredict(BasePredictor): class DicPredict(BasePredictor):
...@@ -63,14 +68,6 @@ class DicPredict(BasePredictor): ...@@ -63,14 +68,6 @@ class DicPredict(BasePredictor):
columns = [] columns = []
li_data = [] li_data = []
for k, v in dict_data.items(): for k, v in dict_data.items():
# 去掉空置率大于等于0.8的列
counter = 0
for item in v:
if not str(item).strip():
counter += 1
if counter / len(v) >= NAN_RATE:
continue
columns.append(k) columns.append(k)
li_data.append(v) li_data.append(v)
...@@ -87,16 +84,44 @@ class DicPredict(BasePredictor): ...@@ -87,16 +84,44 @@ class DicPredict(BasePredictor):
else: else:
dic_ab_res = {i['column_name']: i['pronoun']} dic_ab_res = {i['column_name']: i['pronoun']}
ab_result.append(dic_ab_res) ab_result.append(dic_ab_res)
id_res = { pre_id_res = {
'std_result': std_result, 'std_result': std_result,
'ab_result': ab_result, 'ab_result': ab_result,
} }
return id_res return pre_id_res
def model_predict(self, dic_data):
    """Guess the standard field of each candidate column by majority vote.

    Original author's note: this method currently only predicts for
    [参数, 型号, 数量, 品牌].

    For every column accepted by ``self.valid_chain``, the placeholder
    cells are removed, each remaining cell is classified with
    ``self.get_single_predict``, and the most frequent label is translated
    through ``en_to_zh_map``.

    NOTE(review): ``valid_chain`` as currently visible in this file is a
    stub that returns None, so no column is selected until it is
    implemented.

    :param dic_data: mapping of column name -> sequence of cell values.
    :return: dict mapping column name -> predicted Chinese field name.
        Columns that have no non-placeholder cell are left out.
    :raises KeyError: if the winning label is missing from ``en_to_zh_map``
        (NOTE(review): confirm which labels get_single_predict can emit).
    """
    pre_model_res = {}
    for column, values in dic_data.items():
        if not self.valid_chain(values):
            continue
        cells = [cell for cell in values if cell != PLACEHOLDER]
        if not cells:
            # Nothing to vote on; indexing the winner would raise IndexError.
            continue
        labels = []
        for cell in cells:
            # The second element of the pair is not needed here.
            label, _ = self.get_single_predict(cell)
            labels.append(label)
        # most_common keeps first-seen order among ties, the same as the
        # stable descending sort it replaces.
        winner = Counter(labels).most_common(1)[0][0]
        pre_model_res[column] = en_to_zh_map[winner]
    # Debug output kept from the original implementation.
    print(pre_model_res)
    return pre_model_res
def predict(self, dic_data): def predict(self, dic_data):
res = self.pre_predict(dic_data) dic_data = self.pre_deal(dic_data)
if res: pre_id_res = self.pre_predict(dic_data)
return res if pre_id_res:
return pre_id_res
if len(dic_data) > 0: if len(dic_data) > 0:
self.order_predict(dic_data[0]) self.order_predict(dic_data[0])
pass pass
...@@ -115,3 +140,19 @@ class DicPredict(BasePredictor): ...@@ -115,3 +140,19 @@ class DicPredict(BasePredictor):
if size == 1: if size == 1:
return True return True
return (arr[size - 1] >= arr[size - 2]) and self.IsIncrease(arr, size - 1) return (arr[size - 1] >= arr[size - 2]) and self.IsIncrease(arr, size - 1)
def pre_deal(self, dic_data):
    """Return a copy of ``dic_data`` without columns that are mostly empty.

    A cell counts as empty when its stripped string form equals
    ``PLACEHOLDER``. A column is dropped when its share of empty cells is
    greater than or equal to ``NAN_RATE``; columns without any cell are
    dropped as well.

    :param dic_data: mapping of column name -> sequence of cell values.
    :return: new dict holding only the columns that were kept; the value
        objects themselves are not copied.
    """
    new_dic_data = {}
    for column, values in dic_data.items():
        total = len(values)
        if total == 0:
            # No cells to measure; dividing by the length would raise
            # ZeroDivisionError.
            continue
        empty = sum(1 for item in values if str(item).strip() == PLACEHOLDER)
        # Strictly below the threshold: a column at exactly NAN_RATE is
        # dropped, matching the ">= 0.8" rule this filter implements.
        if empty / total < NAN_RATE:
            new_dic_data[column] = values
    return new_dic_data
def valid_chain(self, li):
    """Unimplemented validation hook; always returns None.

    :param li: the values of one column (as passed by model_predict).
    :return: None, which callers treat as "not valid".
    """
    return None
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment