当前位置:首页 > > 充电吧
[导读]学习一种工具最简单和最有效的方法是download一个demo,根据教程模拟。Caffe作为深度学习框架,它也是一种工具,官方提供了一些demo,主要是在Caffe运行的网络架构文件。那么如何跑起一个

学习一种工具最简单和最有效的方法是download一个demo,根据教程模拟。Caffe作为深度学习框架,它也是一种工具,官方提供了一些demo,主要是在Caffe运行的网络架构文件。那么如何跑起一个demo呢?或者如何用demo直接做预测呢?

训练:caffe train --solver solver.txt 这样就可以了,如果有已训练好的参数或者进行学习迁移finetuning,那么训练的参数可以添加 "--weight init.caffemodel"或者"--snapshot snapshotfile.solvestate"

测试或预测:caffe test --weight test.caffemodel --model test.txt --iteration test_iteration

更加深入的理解,还要从源码出发,首先Caffe源码的tools目录下的caffe.cpp是生成可执行文件的源码,其中定义了train()和test()两个函数分别执行训练和测试,也可以自定义函数执行特定操作。

首先看train()函数:


  CHECK_GT(FLAGS_solver.size(), 0) << "Need a solver definition to train.";
  CHECK(!FLAGS_snapshot.size() || !FLAGS_weights.size())
      << "Give a snapshot to resume training or weights to finetune "
      "but not both.";

首先检查是否有solver.txt文件,文件名可自定义,但是内容必须符合solver结构,在src/caffe/proto/caffe.proto中有此定义。然后检查snapshot和weight,这两个参数分别用于中断后继续训练和学习迁移的。
参数检查完毕,caffe开始加载solver文件,之后是检查训练要工作在CPU还是GPU。


  caffe::SolverParameter solver_param;
  caffe::ReadSolverParamsFromTextFileOrDie(FLAGS_solver, &solver_param);

工作设备设定后,根据加载的solver_param参数创建solver对象,并根据snapshot和weight参数加载模型参数,如果两个参数都没有设置,则模型根据提供的初始化类型或者默认值进行初始化。


shared_ptr<caffe::Solver>  solver(caffe::SolverRegistry::CreateSolver(solver_param));

  solver->SetActionFunction(signal_handler.GetActionFunction());

  if (FLAGS_snapshot.size()) {
    LOG(INFO) << "Resuming from " << FLAGS_snapshot;
    solver->Restore(FLAGS_snapshot.c_str());
  } else if (FLAGS_weights.size()) {
    CopyLayers(solver.get(), FLAGS_weights);
  }

设置和初始化完成,就可以训练了。CPU版本的就是solver对象调用其Solver()函数。


  LOG(INFO) << "Starting Optimization";
  if (gpus.size() > 1) {
#ifdef USE_NCCL
    caffe::NCCLnccl(solver);
    nccl.Run(gpus, FLAGS_snapshot.size() > 0 ? FLAGS_snapshot.c_str() : NULL);
#else
    LOG(FATAL) << "Multi-GPU execution not available - rebuild with USE_NCCL";
#endif
  } else {
    solver->Solve();
  }


再看test()函数:

首先也是检查solver文件和权值文件weight,但是此时weight必须提供,否则无法预测


  CHECK_GT(FLAGS_model.size(), 0) << "Need a model definition to score.";
  CHECK_GT(FLAGS_weights.size(), 0) << "Need model weights to score.";

然后检测当前平台是否支持GPU,如果支持,则默认使用GPU进行预测


// Set device id and mode
  vectorgpus;
  get_gpus(&gpus);
  if (gpus.size() != 0) {
    LOG(INFO) << "Use GPU with device ID " << gpus[0];
#ifndef CPU_ONLY
    cudaDeviceProp device_prop;
    cudaGetDeviceProperties(&device_prop, gpus[0]);
    LOG(INFO) << "GPU device name: " << device_prop.name;
#endif
    Caffe::SetDevice(gpus[0]);
    Caffe::set_mode(Caffe::GPU);
  } else {
    LOG(INFO) << "Use CPU.";
    Caffe::set_mode(Caffe::CPU);
  }

平台设定后,创建Net对象初始化预测的神经网络,然后使用weight参数初始化网络权值


  Netcaffe_net(FLAGS_model, caffe::TEST, FLAGS_level, &stages);
  caffe_net.CopyTrainedLayersFrom(FLAGS_weights);

然后就可以开始预测了


  for (int i = 0; i < FLAGS_iterations; ++i) {
    float iter_loss;
    const vector<Blob*>& result =
        caffe_net.Forward(&iter_loss);
    loss += iter_loss;
    int idx = 0;
    for (int j = 0; j < result.size(); ++j) {
      const float* result_vec = result[j]->cpu_data();
      for (int k = 0; k < result[j]->count(); ++k, ++idx) {
        const float score = result_vec[k];
        if (i == 0) {
          test_score.push_back(score);
          test_score_output_id.push_back(j);
        } else {
          test_score[idx] += score;
        }
        const std::string& output_name = caffe_net.blob_names()[
            caffe_net.output_blob_indices()[j]];
        LOG(INFO) << "Batch " << i << ", " << output_name << " = " << score;
      }
    }
  }


主要是caffe_net.Forward(&iter_loss);这一句,其他都是为了可视化的参数。











本站声明: 本文章由作者或相关机构授权发布,目的在于传递更多信息,并不代表本站赞同其观点,本站亦不保证或承诺内容真实性等。需要转载请联系该专栏作者,如若文章内容侵犯您的权益,请及时联系本站删除。
换一批
延伸阅读

LED驱动电源的输入包括高压工频交流(即市电)、低压直流、高压直流、低压高频交流(如电子变压器的输出)等。

关键字: 驱动电源

在工业自动化蓬勃发展的当下,工业电机作为核心动力设备,其驱动电源的性能直接关系到整个系统的稳定性和可靠性。其中,反电动势抑制与过流保护是驱动电源设计中至关重要的两个环节,集成化方案的设计成为提升电机驱动性能的关键。

关键字: 工业电机 驱动电源

LED 驱动电源作为 LED 照明系统的 “心脏”,其稳定性直接决定了整个照明设备的使用寿命。然而,在实际应用中,LED 驱动电源易损坏的问题却十分常见,不仅增加了维护成本,还影响了用户体验。要解决这一问题,需从设计、生...

关键字: 驱动电源 照明系统 散热

根据LED驱动电源的公式,电感内电流波动大小和电感值成反比,输出纹波和输出电容值成反比。所以加大电感值和输出电容值可以减小纹波。

关键字: LED 设计 驱动电源

电动汽车(EV)作为新能源汽车的重要代表,正逐渐成为全球汽车产业的重要发展方向。电动汽车的核心技术之一是电机驱动控制系统,而绝缘栅双极型晶体管(IGBT)作为电机驱动系统中的关键元件,其性能直接影响到电动汽车的动力性能和...

关键字: 电动汽车 新能源 驱动电源

在现代城市建设中,街道及停车场照明作为基础设施的重要组成部分,其质量和效率直接关系到城市的公共安全、居民生活质量和能源利用效率。随着科技的进步,高亮度白光发光二极管(LED)因其独特的优势逐渐取代传统光源,成为大功率区域...

关键字: 发光二极管 驱动电源 LED

LED通用照明设计工程师会遇到许多挑战,如功率密度、功率因数校正(PFC)、空间受限和可靠性等。

关键字: LED 驱动电源 功率因数校正

在LED照明技术日益普及的今天,LED驱动电源的电磁干扰(EMI)问题成为了一个不可忽视的挑战。电磁干扰不仅会影响LED灯具的正常工作,还可能对周围电子设备造成不利影响,甚至引发系统故障。因此,采取有效的硬件措施来解决L...

关键字: LED照明技术 电磁干扰 驱动电源

开关电源具有效率高的特性,而且开关电源的变压器体积比串联稳压型电源的要小得多,电源电路比较整洁,整机重量也有所下降,所以,现在的LED驱动电源

关键字: LED 驱动电源 开关电源

LED驱动电源是把电源供应转换为特定的电压电流以驱动LED发光的电压转换器,通常情况下:LED驱动电源的输入包括高压工频交流(即市电)、低压直流、高压直流、低压高频交流(如电子变压器的输出)等。

关键字: LED 隧道灯 驱动电源
关闭