Commit 94698314 by lichenggang

增加封装的识别

parent bb699679
...@@ -6,14 +6,13 @@ from utils.log_manager import bom_log ...@@ -6,14 +6,13 @@ from utils.log_manager import bom_log
from protobuf import classify_pb2 from protobuf import classify_pb2
from protobuf import classify_pb2_grpc from protobuf import classify_pb2_grpc
import traceback import traceback
from utils.config import model_config
from predic_fac import PredictorFac from predic_fac import PredictorFac
class Classify(classify_pb2_grpc.classifyServicer): class Classify(classify_pb2_grpc.classifyServicer):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.predictorfac = PredictorFac(model_config) self.predictorfac = PredictorFac()
def Classify(self, request, context): def Classify(self, request, context):
bom_log.info(f'grpc收到数据: {request.keyword}') bom_log.info(f'grpc收到数据: {request.keyword}')
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
import sys import sys, os
ENVIRONMENT = 'test' if sys.platform in ('darwin', 'win32') else 'produce' ENVIRONMENT = 'test' if sys.platform in ('darwin', 'win32') else 'produce'
model_config = {} model_config = {}
project_path = os.path.abspath(os.path.dirname(__file__))
if ENVIRONMENT == 'test': if ENVIRONMENT == 'test':
model_config['model_path'] = r'model' model_config['modextr_path'] = project_path + '\\' + 'models_and_extractors\\'
model_config['extractor_path'] = r'extractor'
else: else:
model_config['model_path'] = r'model' model_config['modextr_path'] = project_path + '\\' + 'models_and_extractors\\'
model_config['extractor_path'] = r'extractor'
...@@ -8,21 +8,21 @@ import tornado.ioloop ...@@ -8,21 +8,21 @@ import tornado.ioloop
import traceback import traceback
from utils.log_manager import bom_log from utils.log_manager import bom_log
from utils.config import model_config
from predic_fac import PredictorFac from predic_fac import PredictorFac
class BaseHandler(tornado.web.RequestHandler): class BaseHandler(tornado.web.RequestHandler):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self.predictorfac = PredictorFac(model_config) self.predictorfac = PredictorFac()
class KwHandler(BaseHandler): class KwHandler(BaseHandler):
async def get(self): async def get(self):
try: try:
target = unquote(self.get_argument('keyword')) target = unquote(self.get_argument('keyword'))
bom_log.info(f'http收到识别关键词: {target}') predict_type = unquote(self.get_argument('type', 'all'))
res = self.predictorfac.predict(target) bom_log.info(f'http收到识别关键词: {target}, 预测类型为{predict_type}')
res = self.predictorfac.predict(target, predict_type)
res['status'] = 1 res['status'] = 1
except: except:
res = { res = {
......
No preview for this file type
No preview for this file type
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
from sklearn.externals import joblib import json
import pickle, json
from predict import dict_predict, kw_predict from predict import dict_predict, kw_predict
class PredictorFac(): class PredictorFac:
def __init__(self, config): def __init__(self):
with open(config['extractor_path'], 'rb') as f:
extractor = pickle.load(f) self.kw_predictor = kw_predict.KwPredict('single')
classifier = joblib.load(config['model_path']) self.dict_predictor = dict_predict.DicPredict('dict')
self.kw_predictor = kw_predict.KwPredict('single', extractor, classifier)
self.dict_predictor = dict_predict.DicPredict('dict', extractor, classifier)
def _get_predictor(self, data): def _get_predictor(self, data):
...@@ -30,16 +27,13 @@ if __name__ == "__main__": ...@@ -30,16 +27,13 @@ if __name__ == "__main__":
def get_test_data(): def get_test_data():
import pandas as pd import pandas as pd
import json import json
df = pd.read_excel(r'C:\Users\ICHUNT\Desktop\bomlist\立创.xlsx', header=None) df = pd.read_excel(r'C:\Users\ICHUNT\Desktop\bomlist\sky-新建 XLSX 工作表 (2).xlsx', header=None)
df.fillna('?', inplace=True) df.fillna('?', inplace=True)
dic_dft = df.to_dict(orient='list') dic_dft = df.to_dict(orient='list')
return json.dumps(dic_dft) return json.dumps(dic_dft)
from utils.config import model_config
data = get_test_data() data = get_test_data()
p = PredictorFac(model_config) p = PredictorFac()
data = json.loads(data) data = json.loads(data)
pretty_col = {'第%s列' % k: v for k, v in data.items()} pretty_col = {'第%s列' % k: v for k, v in data.items()}
print(pretty_col) print(pretty_col)
......
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
import re import re, os, pickle
from collections import Counter from collections import Counter
from sklearn.externals import joblib
from config import model_config
from utils.log_manager import get_logger from utils.log_manager import get_logger
from utils.robots import dd_send_msg from utils.robots import dd_send_msg
import pandas as pd import pandas as pd
...@@ -19,39 +21,56 @@ MULTI_SAME_LEVEL = 3 ...@@ -19,39 +21,56 @@ MULTI_SAME_LEVEL = 3
# 位号列阈值 # 位号列阈值
REF_LEVEL = 0.5 REF_LEVEL = 0.5
class ClassBasePredictorMeta(type):
def __init__(cls, name, bases, attrs):
super(ClassBasePredictorMeta, cls).__init__(name, bases, attrs)
modextr_path = model_config['modextr_path']
modextrs = os.listdir(modextr_path)
for i in modextrs:
if i.endswith('tractor'):
with open(modextr_path + i, 'rb') as f:
extractor = pickle.load(f)
setattr(cls, i, extractor)
else:
model = joblib.load(modextr_path + i)
setattr(cls, i, model)
def __new__(mcs, name, bases, attrs):
return type.__new__(mcs, name, bases, attrs)
class BasePredictor: class BasePredictor(metaclass=ClassBasePredictorMeta):
''' '''
预测类基类 预测类基类
''' '''
# 占位符 # 占位符
PLACEHOLDER = '?' PLACEHOLDER = '?'
pd = pd
def __init__(self, name, extractor, classifier): def __init__(self, name):
self.name = name self.name = name
self.classifier = classifier
self.extractor = extractor
self.error = get_logger(self.name + '_error') self.error = get_logger(self.name + '_error')
self.info = get_logger(self.name + '_info') self.info = get_logger(self.name + '_info')
self.robot_msg = dd_send_msg self.robot_msg = dd_send_msg
self.pd = pd
def predict(self, key): def predict(self, key):
raise NotImplementedError raise NotImplementedError
def get_single_predict(self, obj: object) -> tuple: @classmethod
def get_single_predict(cls, obj: object, model, extractor) -> tuple:
''' '''
:param obj: 接收单个要判断的元素
:param string: 接收单个要判断的字符串 :param model: 用到的模型
:param extractor: 用到的提取器
:return tuple: 返回两个元素的元组, 第一个元素为判断结果, 第二个元素为可能性的字典 :return tuple: 返回两个元素的元组, 第一个元素为判断结果, 第二个元素为可能性的字典
''' '''
string = str(obj) string = str(obj)
series = self.pd.Series([string]) series = cls.pd.Series([string])
feature = self.extractor.transform(series) feature = extractor.transform(series)
predictions = self.classifier.predict(feature) predictions = model.predict(feature)
proba = self.classifier.predict_proba(feature) proba = model.predict_proba(feature)
classes = self.classifier.classes_ classes = model.classes_
deal_list = [round(i, 3) for i in proba[0].tolist()] deal_list = [round(i, 3) for i in proba[0].tolist()]
dic_proba = {k: v for k, v in zip(classes, deal_list)} dic_proba = {k: v for k, v in zip(classes, deal_list)}
return predictions[0], dic_proba return predictions[0], dic_proba
...@@ -149,7 +168,7 @@ class BasePredictor: ...@@ -149,7 +168,7 @@ class BasePredictor:
@classmethod @classmethod
def get_comprehensive_res(cls, head_result, model_result, major='head', vote_type='standard'): def get_comprehensive_res(cls, head_result, model_result, major='head', vote_type='standard'):
if vote_type == 'standard': if vote_type == 'standard':
vote_count = {k: []for k in STD_FIELDS_MAP.keys()} vote_count = {k: [] for k in STD_FIELDS_MAP.keys()}
else: else:
set_k = set() set_k = set()
for v in head_result.values(): for v in head_result.values():
...@@ -211,5 +230,60 @@ class BasePredictor: ...@@ -211,5 +230,60 @@ class BasePredictor:
break break
return round(count / len(data), 3) return round(count / len(data), 3)
@classmethod
def get_encap_rate(cls, data):
"""
封装率
"""
count = 0
for item in data:
single_pred_res, probdic = cls.get_single_predict(item, cls.encap_model, cls.encap_extractor)
if single_pred_res == '1':
count += 1
return round(count / len(data), 3)
@classmethod
def get_nan_rate(cls, data):
'''
:param data:需要判断的数组
:return float: 空置率
'''
count = 0
for item in data:
if str(item).strip() == BasePredictor.PLACEHOLDER:
count += 1
return round(count / len(data), 3)
if __name__ == "__main__": if __name__ == "__main__":
print(BasePredictor.is_ref(['?', '位号', 'C75,C149', 'C59 C252,C5302', 'C49,C50', 'C5323,C5324,C5337', 'C5261,C5262', 'C5312', 'C419', 'C7,C26,C27,C133', 'C37,C108,C109,C110,C111,C127,C194,C248,C123,C420', 'C5189,C5291,C5332', 'C255,C5190,C5292,C5333 C5267,C5272', 'C107,C116', 'C5309,C5340,C5343', 'C423,C426', 'C66,C76,C5265,C5266,C5270,C5271,C5326,C5253,C5254,C5255,C5256,C5257,C5258,C5259,C5260,C5263,C5264,C5306', 'C2,C3,C4,C10,C14,C15,C16,C35,C52,C54,C55,C56,C57,C72,C73,C77,C78,C83,C85,C86,C87,C88,C89,C91,C92,C93,C94,C95,C96,C97,C98,C99,C100,C101,C102,C103,C104,C105,C128,C154,C156,C159,C162,C163,C165,C166,C169,C170,C180,C189,C193,C197,C205,C235,C243,C244,C246,C251,C258,C262,C263,C264,C266,C271,C334,C335,C615,C5188,C5229,C5235,C5290,C5294,C5296,C5299,C5300,C5301,C5316,C5331,C5335,C5345,C18,C117,C353,C362,C363,C377,C378,C380,C424,C425,C5206,C5223,C5226,C5228,C5230,C5231,C5233,C5234,C5240,C5242,C5243,C5245,C5247,C5248,C5249,C5251,C5252,C5268,C5269,C5342,C5344,C5346,C421,C5303,C5305', 'C39,C68,C82,C84,C90,C142,C247,C250,C280,C435', 'C5338,C5339', 'C151 C427', 'C38,C40,C155,C157,C215,C254,C257,C260,C261,C270,C5314 C150 C181 C5224', 'C17,C29,C62,C63,C64,C65,C67,C106,C121,C122,C161,C164,C303,C422,C5244,C5298,C5341,C183,C5304', 'C1,C8,C9,C12,C19,C51,C53,C136,C168,C5193,C5207,C5208,C5293,C5315,C428,C5227,C5232,C5241,C5246,C5250,C5276,C5313,C5327,C5328,C5329,C5330', 'C13,C21,C69,C70,C81', 'C396', 'C364', 'CE5', 'CE3,CE4', 'CE1,CE2', 'R8,R9,R10,R56,R66,R84,R164,R226,R238,R240,R311,R312,R348,R5812,R5889,R5890 R5887', 'R5731,R5732,R5742,R5743,R5744,R5745,R5922,R5923', 'R58,R78,R172,R173', 'R143,R398,R399,R5801,R5802,R5862,R5863,R5864,R5865,R5866,R5867', 'R4,R5,R11,R25,R28,R67,R5750,R5751,R5752,R5753,R5787,R5810,R5822', 'R59', 'R36,R37', 'R5828', 'R239,R241', 'R38', 'R70,R71,R123', 'R1,R280,R5785,R5789,R5797,R5824,R5825,R5830,R5831,R5897', 'R405,R5768 R5888', 'R115 R400,R401,R5920', 'R5826,R5827,R5829,R5832', 'R151', 'R407', 'R5807', 'R101,R117,R144,R150,R153,R156,R157,R159,R161,R162,R307,R5704,R5706,R5748,R5749,R5869,R5874', 'R96,R137', 'R5723,R5857,R5916', 'R136', 'R5771 R5858', 'R12,R31,R32,R33,R34,R35,R42,R44,R61,R77,R80,R125,R132,R140,R141,R193,R212,R267,R419,R5144,R5765,R5766,R5767,R5786,R5800,R5808,R5809,R5815,R5816,R5817,R5894,R5892,R5895,R5896,R5907,R5908,R5910,R5911,R5913,R5914,R5921', 'R139 R5861', 'R92,R138', 'R5147,R5891', 'R45,R5903,R5905,R5906', 'R64', 'R21', 'R65,R5770,R5772', 'R5724,R5769', 'R5782', 'R23,R26,R30,R39,R126,R127,R5722,R5779,R5781,R5783,R5784,R5798,R5799,R5805,R5813,R5814,R5823,R5856,R5915,R5924 R5917', 'R20,R24,R27 R5859', 'R29', 'R43,R49', 'R5725', 'R128,R129', 'R5898 R5918', 'L6,L16 R13', 'R2,R3', 'R130,R131', 'L5103', 'RN2,RN3', 'RP1,RP2', 'FB9,FB10,FB12 L53,L5107,L5108', 'L22', 'L7', 'L24 L5101,L5106,L5109', 'L3,L4', 'L2,L5 L9', 'X6', 'Y1', 'Y2', 'Y3', 'Y5', 'F1', 'D7 D33,D34 D45', 'D40,D41', 'D4,D47', 'D1', 'D5', 'D46', 'Q1,Q3 Q11', 'Q23 1Q1', 'U5116', 'U1', 'U5119', 'U2', 'U5115', 'U5110,U5128,U5133', 'U9,U12', 'U8,U11', 'U13', 'U5120 U5124 ', 'U19 U5118,U5121,U5123', 'U10', 'U20', 'U31', 'U5113', 'U5122', 'U5130', 'U5132', 'U5134', 'U3', 'U5,U6', 'U5126', 'P11', 'ESD1,D13,D14,D15,D16,D17,D18,D19,ESD39,ESD40,ESD41,ESD42,ESD43,ESD44,ESD45,ESD46,ESD47,ESD48,D48,ESD49,ESD55,ESD56,ESD57,ESD61,ESD62,ESD63,ESD64,ESD65,ESD66,ESD67,ESD68,ESD69,ESD70', 'TVS2', ',P1,P2,P5,P10,P4617,P4618,P4620,P4621,P4625', 'P7', 'K1', 'J5', 'J33', 'ANT1', 'J4604,MIC1,MIC2', 'J2 CN2', 'J4601,J4602', 'CTP', 'J4605', 'SW1', 'BAT1', '?'])) print(BasePredictor.is_ref(
['?', '位号', 'C75,C149', 'C59 C252,C5302', 'C49,C50', 'C5323,C5324,C5337', 'C5261,C5262', 'C5312', 'C419',
'C7,C26,C27,C133', 'C37,C108,C109,C110,C111,C127,C194,C248,C123,C420', 'C5189,C5291,C5332',
'C255,C5190,C5292,C5333 C5267,C5272', 'C107,C116', 'C5309,C5340,C5343', 'C423,C426',
'C66,C76,C5265,C5266,C5270,C5271,C5326,C5253,C5254,C5255,C5256,C5257,C5258,C5259,C5260,C5263,C5264,C5306',
'C2,C3,C4,C10,C14,C15,C16,C35,C52,C54,C55,C56,C57,C72,C73,C77,C78,C83,C85,C86,C87,C88,C89,C91,C92,C93,C94,C95,C96,C97,C98,C99,C100,C101,C102,C103,C104,C105,C128,C154,C156,C159,C162,C163,C165,C166,C169,C170,C180,C189,C193,C197,C205,C235,C243,C244,C246,C251,C258,C262,C263,C264,C266,C271,C334,C335,C615,C5188,C5229,C5235,C5290,C5294,C5296,C5299,C5300,C5301,C5316,C5331,C5335,C5345,C18,C117,C353,C362,C363,C377,C378,C380,C424,C425,C5206,C5223,C5226,C5228,C5230,C5231,C5233,C5234,C5240,C5242,C5243,C5245,C5247,C5248,C5249,C5251,C5252,C5268,C5269,C5342,C5344,C5346,C421,C5303,C5305',
'C39,C68,C82,C84,C90,C142,C247,C250,C280,C435', 'C5338,C5339', 'C151 C427',
'C38,C40,C155,C157,C215,C254,C257,C260,C261,C270,C5314 C150 C181 C5224',
'C17,C29,C62,C63,C64,C65,C67,C106,C121,C122,C161,C164,C303,C422,C5244,C5298,C5341,C183,C5304',
'C1,C8,C9,C12,C19,C51,C53,C136,C168,C5193,C5207,C5208,C5293,C5315,C428,C5227,C5232,C5241,C5246,C5250,C5276,C5313,C5327,C5328,C5329,C5330',
'C13,C21,C69,C70,C81', 'C396', 'C364', 'CE5', 'CE3,CE4', 'CE1,CE2',
'R8,R9,R10,R56,R66,R84,R164,R226,R238,R240,R311,R312,R348,R5812,R5889,R5890 R5887',
'R5731,R5732,R5742,R5743,R5744,R5745,R5922,R5923', 'R58,R78,R172,R173',
'R143,R398,R399,R5801,R5802,R5862,R5863,R5864,R5865,R5866,R5867',
'R4,R5,R11,R25,R28,R67,R5750,R5751,R5752,R5753,R5787,R5810,R5822', 'R59', 'R36,R37', 'R5828', 'R239,R241',
'R38', 'R70,R71,R123', 'R1,R280,R5785,R5789,R5797,R5824,R5825,R5830,R5831,R5897', 'R405,R5768 R5888',
'R115 R400,R401,R5920', 'R5826,R5827,R5829,R5832', 'R151', 'R407', 'R5807',
'R101,R117,R144,R150,R153,R156,R157,R159,R161,R162,R307,R5704,R5706,R5748,R5749,R5869,R5874', 'R96,R137',
'R5723,R5857,R5916', 'R136', 'R5771 R5858',
'R12,R31,R32,R33,R34,R35,R42,R44,R61,R77,R80,R125,R132,R140,R141,R193,R212,R267,R419,R5144,R5765,R5766,R5767,R5786,R5800,R5808,R5809,R5815,R5816,R5817,R5894,R5892,R5895,R5896,R5907,R5908,R5910,R5911,R5913,R5914,R5921',
'R139 R5861', 'R92,R138', 'R5147,R5891', 'R45,R5903,R5905,R5906', 'R64', 'R21', 'R65,R5770,R5772',
'R5724,R5769', 'R5782',
'R23,R26,R30,R39,R126,R127,R5722,R5779,R5781,R5783,R5784,R5798,R5799,R5805,R5813,R5814,R5823,R5856,R5915,R5924 R5917',
'R20,R24,R27 R5859', 'R29', 'R43,R49', 'R5725', 'R128,R129', 'R5898 R5918', 'L6,L16 R13', 'R2,R3', 'R130,R131',
'L5103', 'RN2,RN3', 'RP1,RP2', 'FB9,FB10,FB12 L53,L5107,L5108', 'L22', 'L7', 'L24 L5101,L5106,L5109', 'L3,L4',
'L2,L5 L9', 'X6', 'Y1', 'Y2', 'Y3', 'Y5', 'F1', 'D7 D33,D34 D45', 'D40,D41', 'D4,D47', 'D1', 'D5', 'D46',
'Q1,Q3 Q11', 'Q23 1Q1', 'U5116', 'U1', 'U5119', 'U2', 'U5115', 'U5110,U5128,U5133', 'U9,U12', 'U8,U11', 'U13',
'U5120 U5124 ', 'U19 U5118,U5121,U5123', 'U10', 'U20', 'U31', 'U5113', 'U5122', 'U5130', 'U5132', 'U5134',
'U3', 'U5,U6', 'U5126', 'P11',
'ESD1,D13,D14,D15,D16,D17,D18,D19,ESD39,ESD40,ESD41,ESD42,ESD43,ESD44,ESD45,ESD46,ESD47,ESD48,D48,ESD49,ESD55,ESD56,ESD57,ESD61,ESD62,ESD63,ESD64,ESD65,ESD66,ESD67,ESD68,ESD69,ESD70',
'TVS2', ',P1,P2,P5,P10,P4617,P4618,P4620,P4621,P4625', 'P7', 'K1', 'J5', 'J33', 'ANT1', 'J4604,MIC1,MIC2',
'J2 CN2', 'J4601,J4602', 'CTP', 'J4605', 'SW1', 'BAT1', '?']))
...@@ -7,7 +7,7 @@ from static_config import * ...@@ -7,7 +7,7 @@ from static_config import *
def fun(seri): def fun(seri):
li_seri = seri.tolist() li_seri = seri.tolist()
for index, field in enumerate(li_seri): for index, field in enumerate(li_seri):
if str(field).lower() in PROB_FIELDS: if str(field).lower() in ALL_FIELDS:
return field, seri.name return field, seri.name
...@@ -16,15 +16,16 @@ def get_head_row(li: list) -> int: ...@@ -16,15 +16,16 @@ def get_head_row(li: list) -> int:
返回第一行有效数据所在的行 返回第一行有效数据所在的行
""" """
for index, i in enumerate(li): for index, i in enumerate(li):
if str(i).lower() in PROB_FIELDS: if str(i).lower() in ALL_FIELDS:
return index + 1 return index + 1
# 取前多少行 # 取前多少行做表头预测
HEAD_ROW = 7 HEAD_ROW = 7
# 空置率阈值 # 空置率阈值
NAN_RATE = 0.8 NAN_RATE = 0.8
# 参数列空置率阈值
PARAM_NAN_RATE = 0.3
class DicPredict(BasePredictor): class DicPredict(BasePredictor):
...@@ -38,7 +39,6 @@ class DicPredict(BasePredictor): ...@@ -38,7 +39,6 @@ class DicPredict(BasePredictor):
li_fie = series.tolist() li_fie = series.tolist()
li_res = [] li_res = []
for field, column_name in li_fie: for field, column_name in li_fie:
print(field)
dic = { dic = {
'std_name': '', 'std_name': '',
'pronoun': field, 'pronoun': field,
...@@ -102,9 +102,9 @@ class DicPredict(BasePredictor): ...@@ -102,9 +102,9 @@ class DicPredict(BasePredictor):
def model_predict(self, dic_data): def model_predict(self, dic_data):
""" """
该方法目前只对[序号(非标准), 数量, 类别, 参数, 型号, 品牌, other]进行预测, 前三者是非模型预测 模型预测目前只对[数量, 类别, 参数, 型号, 品牌, 封装, 序号, 位号, 单位]进行预测
:param dic_data: :param dic_data:
:return :只有[参数, 数量]会强制有结果 :return :结果字典
""" """
prob_columns = [] prob_columns = []
temp_pre_model_res = {} temp_pre_model_res = {}
...@@ -139,8 +139,8 @@ class DicPredict(BasePredictor): ...@@ -139,8 +139,8 @@ class DicPredict(BasePredictor):
prob_columns} prob_columns}
for k, v in set_not_null_dic_data.items(): for k, v in set_not_null_dic_data.items():
li_single_pred_res = [] li_single_pred_res = []
for string in v: for item in v:
single_pred_res, probdic = self.get_single_predict(string) single_pred_res, probdic = self.get_single_predict(item, BasePredictor.model, BasePredictor.extractor)
li_single_pred_res.append(single_pred_res) li_single_pred_res.append(single_pred_res)
result = BasePredictor.repeat_max(li_single_pred_res) result = BasePredictor.repeat_max(li_single_pred_res)
# 如果该列被预测为其他, 则不做改动 # 如果该列被预测为其他, 则不做改动
...@@ -149,28 +149,41 @@ class DicPredict(BasePredictor): ...@@ -149,28 +149,41 @@ class DicPredict(BasePredictor):
# 如果类别列被预测为品牌, 则不做改动, 此处是基于目前的模型容易把中文多的列预测为品牌, 无奈之下所做的逻辑 # 如果类别列被预测为品牌, 则不做改动, 此处是基于目前的模型容易把中文多的列预测为品牌, 无奈之下所做的逻辑
if temp_pre_model_res.get(k) == '类别' and result == 'brand_name': if temp_pre_model_res.get(k) == '类别' and result == 'brand_name':
continue continue
# 参数列单独要求空置率不能高于PARAM_NAN_RATE
if result == 'param':
if BasePredictor.get_nan_rate(v) >= PARAM_NAN_RATE:
continue
temp_pre_model_res[k] = EN_TO_ZH_MAP[result] temp_pre_model_res[k] = EN_TO_ZH_MAP[result]
# 若有多个型号列,则进行不同率的比较, 不同率最高的选为目标列
not_null_dic_data = {k: list(filter(lambda x: x != BasePredictor.PLACEHOLDER, dic_data[k])) for k in not_null_dic_data = {k: list(filter(lambda x: x != BasePredictor.PLACEHOLDER, dic_data[k])) for k in
prob_columns} prob_columns}
# prob_param_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '参数']
# 若有多个型号列,则进行不同率的比较, 不同率最高的选为目标列
prob_gn_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '型号'] prob_gn_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '型号']
for param_or_gn_col_list in [prob_gn_cols]: if len(prob_gn_cols) >= 2:
if len(param_or_gn_col_list) >= 2:
li_diffrate = [(col, BasePredictor.get_diffrate(not_null_dic_data[col])) for col in li_diffrate = [(col, BasePredictor.get_diffrate(not_null_dic_data[col])) for col in
param_or_gn_col_list] prob_gn_cols]
sort_li_diffrate = sorted(li_diffrate, key=lambda x: x[1], reverse=True) sort_li_diffrate = sorted(li_diffrate, key=lambda x: x[1], reverse=True)
for col_diffrate in sort_li_diffrate[1:]: for col_diffrate in sort_li_diffrate[1:]:
temp_pre_model_res.pop(col_diffrate[0]) temp_pre_model_res.pop(col_diffrate[0])
# 若有多个参数列,进行参数特征的数量比较, 特征最多的选为目标列 # 若有多个参数列, 先进行封装列的提取(封装率需要大于0), 再进行参数特征的数量比较, 特征最多的选为目标列
prob_param_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '参数'] prob_param_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '参数']
self.info.info(f'可能的参数列有{str(prob_param_cols)}')
if len(prob_param_cols) >= 2: if len(prob_param_cols) >= 2:
li_encap_rate = [(col, BasePredictor.get_encap_rate(not_null_dic_data[col])) for col in
prob_param_cols]
sort_li_encaprate = sorted(li_encap_rate, key=lambda x: x[1], reverse=True)
if sort_li_encaprate[0][1] >= 0:
temp_pre_model_res[sort_li_encaprate[0][0]] = '封装'
prob_param_cols.remove(sort_li_encaprate[0][0])
li_feature_rate = [(col, BasePredictor.get_param_featurerate(not_null_dic_data[col])) for col in li_feature_rate = [(col, BasePredictor.get_param_featurerate(not_null_dic_data[col])) for col in
prob_param_cols] prob_param_cols]
sort_li_fearate = sorted(li_feature_rate, key=lambda x: x[1], reverse=True) sort_li_fearate = sorted(li_feature_rate, key=lambda x: x[1], reverse=True)
for col_fearate in sort_li_fearate[1:]: for col_fearate in sort_li_fearate[1:]:
self.info.info(f'参数列丢弃{col_fearate[0]}')
temp_pre_model_res.pop(col_fearate[0]) temp_pre_model_res.pop(col_fearate[0])
# 若有多个数量列,则进行空置率的比较, 空置率最低的选为目标列, #TODO 后续可能需要改成数量元素的占比率 # 若有多个数量列,则进行空置率的比较, 空置率最低的选为目标列, #TODO 后续可能需要改成数量元素的占比率
...@@ -178,11 +191,8 @@ class DicPredict(BasePredictor): ...@@ -178,11 +191,8 @@ class DicPredict(BasePredictor):
if len(prob_num_cols) >= 2: if len(prob_num_cols) >= 2:
li_nullrate = [] li_nullrate = []
for prob_num_col in prob_num_cols: for prob_num_col in prob_num_cols:
counter = 0 nanrate = BasePredictor.get_nan_rate(dic_data[prob_num_col])
for item in dic_data[prob_num_col]: li_nullrate.append((prob_num_col, nanrate))
if str(item).strip() == BasePredictor.PLACEHOLDER:
counter += 1
li_nullrate.append((prob_num_col, counter / len(dic_data[prob_num_col])))
sort_li_nullrate = sorted(li_nullrate, key=lambda x: x[1]) sort_li_nullrate = sorted(li_nullrate, key=lambda x: x[1])
for col_nullrate in sort_li_nullrate[1:]: for col_nullrate in sort_li_nullrate[1:]:
temp_pre_model_res.pop(col_nullrate[0]) temp_pre_model_res.pop(col_nullrate[0])
...@@ -192,11 +202,8 @@ class DicPredict(BasePredictor): ...@@ -192,11 +202,8 @@ class DicPredict(BasePredictor):
if len(prob_brand_cols) >= 2: if len(prob_brand_cols) >= 2:
li_nullrate = [] li_nullrate = []
for prob_brand_col in prob_brand_cols: for prob_brand_col in prob_brand_cols:
counter = 0 nanrate = BasePredictor.get_nan_rate(dic_data[prob_brand_col])
for item in dic_data[prob_brand_col]: li_nullrate.append((prob_brand_col, nanrate))
if str(item).strip() == BasePredictor.PLACEHOLDER:
counter += 1
li_nullrate.append((prob_brand_col, counter / len(dic_data[prob_brand_col])))
sort_li_nullrate = sorted(li_nullrate, key=lambda x: x[1]) sort_li_nullrate = sorted(li_nullrate, key=lambda x: x[1])
for col_nullrate in sort_li_nullrate[1:]: for col_nullrate in sort_li_nullrate[1:]:
temp_pre_model_res.pop(col_nullrate[0]) temp_pre_model_res.pop(col_nullrate[0])
......
...@@ -7,8 +7,14 @@ class KwPredict(BasePredictor): ...@@ -7,8 +7,14 @@ class KwPredict(BasePredictor):
def predict(self, string, predict_type='all'): def predict(self, string, predict_type='all'):
self.info.info('预测类型为: %s, 接收数据: %s' % (predict_type, string)) self.info.info('预测类型为: %s, 接收数据: %s' % (predict_type, string))
result, prab = self.get_single_predict(string) if predict_type == 'all':
result, prab = self.get_single_predict(string, BasePredictor.model, BasePredictor.extractor)
zh_prob = {EN_TO_ZH_MAP[k]: v for k, v in prab.items()} zh_prob = {EN_TO_ZH_MAP[k]: v for k, v in prab.items()}
res = {'result': EN_TO_ZH_MAP[result], 'probably_dict': zh_prob, 'predict_type': predict_type}
else:
model = getattr(BasePredictor, predict_type + '_model')
extractor = getattr(BasePredictor, predict_type + '_extractor')
result, prab = self.get_single_predict(string, model, extractor)
res = {'result': int(result), 'probably_dict': prab, 'predict_type': predict_type}
res = {'result': EN_TO_ZH_MAP[result], 'probably_dict': zh_prob}
return res return res
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
# 可能的头部字段
PROB_FIELDS = ["序号", "名称", "规格", "mpn", "用量(pcs)", "用量", "pcs", "位号", "描述", "值", "数量", "封装", "类别", "a面位置", "b面位置", "备注",
"需求数量", "售价", "封装", "封装规格", '型号', '参数', '数量', '品牌', '型号', '类型', "quantity", "qty", "buy qty",
"buy quantity", "需求用量","manufacturer制造商", "manufacturer", "description"
"参考品牌", "品牌", "item", "厂商编码", "品牌/厂商", "参考料号", "参考供应商", "top面", "bottom面", "designator", "remark", "元器件",
"标号", "需求型号", "Footprint封装", "Footprint", "产品分类", "单板数量", "规格型号", "packagereference", "footprint封装"]
PROB_FIELDS_1 = ['*' + i for i in PROB_FIELDS]
PROB_FIELDS.extend(PROB_FIELDS_1)
# 标准名和代名词的映射 # 标准名和代名词的映射
li_category = ["类别", "分类", "名称", "类别名称", "类型", "产品分类"] li_category = ["类别", "分类", "名称", "类别名称", "类型", "产品分类"]
li_param = ["参数", "规格", "描述", "值", "description"] li_param = ["参数", "规格", "描述", "值", "description"]
li_gn = ["型号", "参考料号", "料号", "mpn", "厂商编码", "元器件", "需求型号", "规格型号"] li_gn = ["型号", "参考料号", "料号", "mpn", "厂商编码", "元器件", "需求型号", "规格型号"]
li_num = ["数量", "用量(pcs)", "用量", "pcs", "quantity", "qty", "buy qty", "buy quantity", "需求用量", "单板数量"] li_num = ["数量", "用量(pcs)", "用量", "pcs", "quantity", "qty", "buy qty", "buy quantity", "需求用量", "单板数量", "采购数量"]
li_brand = ["品牌", "品牌/厂商", "参考品牌", "参考供应商", "厂商", "参考供应商", "参考厂商", "manufacturer制造商", "manufacturer"] li_brand = ["品牌", "品牌/厂商", "参考品牌", "参考供应商", "厂商", "参考供应商", "参考厂商", "manufacturer制造商", "manufacturer", "厂牌"]
li_encap = ["封装", "封装规格", "encapsulation", "footprint封装", 'packagereference']
li_category.extend(['*' + i for i in li_category]) li_category.extend(['*' + i for i in li_category])
li_param.extend(['*' + i for i in li_param]) li_param.extend(['*' + i for i in li_param])
li_gn.extend(['*' + i for i in li_gn]) li_gn.extend(['*' + i for i in li_gn])
li_num.extend(['*' + i for i in li_num]) li_num.extend(['*' + i for i in li_num])
li_brand.extend(['*' + i for i in li_brand]) li_brand.extend(['*' + i for i in li_brand])
li_encap.extend(['*' + i for i in li_encap])
PROB_FIELDS = ["序号", "位号", "a面位置", "位置", "b面位置", "备注", "售价", "item", "top面", "bottom面", "designator", "remark", "标号"]
AB_FIELDS = PROB_FIELDS + ['*' + i for i in PROB_FIELDS]
# 可能的头部字段
ALL_FIELDS = AB_FIELDS + li_category + li_param + li_gn + li_num + li_brand + li_encap
STD_FIELDS_MAP = { STD_FIELDS_MAP = {
"类别": li_category, "类别": li_category,
"参数": li_param, "参数": li_param,
"型号": li_gn, "型号": li_gn,
"数量": li_num, "数量": li_num,
"品牌": li_brand "品牌": li_brand,
"封装": li_encap
} }
# 必须返回也必须验证的标准字段 # 必须返回也必须验证的标准字段
MUST_STD_FIELDS = ['参数', '数量'] MUST_STD_FIELDS = ['参数', '数量']
# 参数名和中文的映射 # 参数名和中文的映射
EN_TO_ZH_MAP = {'brand_name': '品牌', 'param': '参数', 'goods_name': '型号', 'other': '其他'} EN_TO_ZH_MAP = {'brand_name': '品牌', 'param': '参数', 'goods_name': '型号', 'other': '其他', 'encap': '封装'}
# 类别合集 从learning_data.lie_category导入, 并添加了部分短英文Category # 类别合集 从learning_data.lie_category导入, 并添加了部分短英文Category
CATEGORY = ["半导体", "嵌入式", "光电子", "光源", "无源", "连接器", "断路器", "指示灯", "声源", "接触器", "铁氧芯", "冷热系统", "电源", "电线", "机械", CATEGORY = ["半导体", "嵌入式", "光电子", "光源", "无源", "连接器", "断路器", "指示灯", "声源", "接触器", "铁氧芯", "冷热系统", "电源", "电线", "机械",
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment