/* QuartilePlotRenderer.java
 * Copyright 2007  Casey Marshall <csm@soe.ucsc.edu>
 */


package org.jfree.chart.renderer.category;

import java.awt.BasicStroke;
import java.awt.Color;
import java.awt.Graphics2D;
import java.awt.Shape;
import java.awt.Stroke;
import java.awt.geom.AffineTransform;
import java.awt.geom.Ellipse2D;
import java.awt.geom.Line2D;
import java.awt.geom.Rectangle2D;
import java.io.Serializable;

import org.jfree.chart.axis.CategoryAxis;
import org.jfree.chart.axis.ValueAxis;
import org.jfree.chart.plot.CategoryPlot;
import org.jfree.chart.plot.PlotOrientation;
import org.jfree.chart.plot.PlotRenderingInfo;
import org.jfree.data.category.CategoryDataset;
import org.jfree.data.statistics.BoxAndWhiskerCategoryDataset;
import org.jfree.ui.RectangleEdge;
import org.jfree.util.PublicCloneable;

/**
 * Draws a revision of the "range bar," "box plot," or "box-and-whisker plot"
 * with the method suggested by E. R. Tufte, in <em>The Visual Display of
 * Quantitative Information,</em> Second Edition, pages 123&ndash;125.
 * <p>
 * Each item of a {@link BoxAndWhiskerCategoryDataset} is drawn as two
 * whiskers, one from the minimum regular value to the first quartile and one
 * from the third quartile to the maximum regular value. The interquartile
 * range and the median are then marked according to the current
 * {@link Style}. Optionally the mean is marked with a short tick. Items whose
 * minimum, maximum, quartile or median value is {@code null} are skipped.
 *
 * @author csm
 */
public class QuartilePlotRenderer extends AbstractCategoryItemRenderer
    implements Cloneable, PublicCloneable, Serializable
{

  /** For serialization. */
  private static final long serialVersionUID = -6342465252163802843L;

  /**
   * Quartile plot styles supported by this renderer.
   */
  public static enum Style
  {
    /**
     * The "dot and whisker" style draws the median as a dot, the minimum to
     * lower quartile and maximum to upper quartile as a line, and the
     * interquartile range is the white space in between.
     *
     * You can use the shape and fill paint attributes to control the style of
     * the dot, on a per-series basis, if you prefer.
     *
     * This is the default style.
     */
    DOT_AND_WHISKER,

    /**
     * The line weight style. The minimum and maximum range is shown by a line,
     * the interquartile range is shown by thickening the line, and the median
     * is shown by a notch.
     */
    LINE_WEIGHT,

    /**
     * The line offset style. This is similar to the {@link #LINE_WEIGHT} style,
     * but the interquartile range is offset, not thickened.
     */
    LINE_OFFSET
  }

  /** The plot style; never {@code null}. */
  private Style style;

  /** Whether a tick mark is also drawn at the mean of each item. */
  private boolean drawsMean;

  /**
   * Creates a renderer with the {@link Style#DOT_AND_WHISKER} style and mean
   * drawing disabled. The base stroke is a {@link BasicStroke} of width 1.0,
   * the base paint is black, the base shape is a 1x1 ellipse and the base
   * fill paint is red.
   */
  public QuartilePlotRenderer()
  {
    style = Style.DOT_AND_WHISKER;
    drawsMean = false;
    setBaseStroke(new BasicStroke(1.0f));
    setBasePaint(Color.black);
    setBaseShape(new Ellipse2D.Double(0.0, 0.0, 1.0, 1.0));
    setBaseFillPaint(Color.red);
  }

  /**
   * Tells if this plot will draw the mean value, as a tick mark.
   *
   * @return True if the mean is also drawn.
   */
  public boolean drawsMean()
  {
    return drawsMean;
  }

  /**
   * Returns the plot style.
   *
   * @return The plot style (never {@code null}).
   */
  public Style getStyle()
  {
    return style;
  }

  /**
   * Set whether or not this renderer should include the mean, in addition to
   * the median, quartiles, and maximum and minimum inliers.
   *
   * @param drawsMean The flag value.
   */
  public void setDrawsMean(boolean drawsMean)
  {
    this.drawsMean = drawsMean;
  }

  /**
   * Sets the plot style.
   *
   * @param style The plot style ({@code null} not permitted).
   * @throws IllegalArgumentException if {@code style} is {@code null}.
   */
  public void setStyle(Style style)
  {
    if (style == null)
      {
        throw new IllegalArgumentException("Null 'style' argument.");
      }
    this.style = style;
  }

  /**
   * Initialises the renderer for a drawing pass. In addition to the work done
   * by the superclass, this records in the returned state the width, along
   * the domain axis, to reserve for each item. {@code drawItem} uses that
   * width to space several series within one category.
   */
  @Override
  public CategoryItemRendererState initialise(Graphics2D g2,
                                              Rectangle2D dataArea,
                                              CategoryPlot plot,
                                              int rendererIndex,
                                              PlotRenderingInfo info)
  {
    CategoryItemRendererState state = super.initialise(g2, dataArea, plot,
                                                       rendererIndex, info);
    // Look the shape and stroke up per item, exactly as the drawing code
    // does, so the width reserved here matches the glyph actually drawn.
    switch (style)
    {
      case DOT_AND_WHISKER:
        {
          Shape shape = getItemShape(0, 0);
          if (shape != null)
            {
              state.setBarWidth(domainExtent(shape.getBounds2D(),
                                             plot.getOrientation()));
            }
        }
        break;

      case LINE_OFFSET:
      case LINE_WEIGHT:
        {
          // Rough estimate: the whisker line and the offset line, plus room
          // for the mean tick when enabled, each about one stroke width.
          // Only a BasicStroke exposes a width; otherwise the bar width is
          // left as initialised by the superclass.
          Stroke stroke = getItemStroke(0, 0);
          if (stroke instanceof BasicStroke)
            {
              state.setBarWidth(((BasicStroke) stroke).getLineWidth()
                                * (drawsMean ? 3 : 2));
            }
        }
        break;
    }
    return state;
  }

  /**
   * Draws one item (one series within one category).
   *
   * @throws IllegalArgumentException if {@code dataset} is not a
   *         {@link BoxAndWhiskerCategoryDataset}.
   */
  public void drawItem(Graphics2D g2, CategoryItemRendererState state,
                       Rectangle2D dataArea, CategoryPlot plot,
                       CategoryAxis domainAxis, ValueAxis rangeAxis,
                       CategoryDataset dataset, int row, int column, int pass)
  {
    if (!(dataset instanceof BoxAndWhiskerCategoryDataset))
      {
        throw new IllegalArgumentException("expecting BoxAndWhiskerCategoryDataset");
      }

    PlotOrientation orientation = plot.getOrientation();
    if (orientation != PlotOrientation.HORIZONTAL
        && orientation != PlotOrientation.VERTICAL)
      {
        return;
      }

    BoxAndWhiskerCategoryDataset data = (BoxAndWhiskerCategoryDataset) dataset;
    Number maxVal = data.getMaxRegularValue(row, column);
    Number minVal = data.getMinRegularValue(row, column);
    Number q1Val = data.getQ1Value(row, column);
    Number q3Val = data.getQ3Value(row, column);
    Number medianVal = data.getMedianValue(row, column);
    if (maxVal == null || minVal == null || q1Val == null || q3Val == null
        || medianVal == null)
      {
        // Nothing meaningful can be drawn for an incomplete item.
        return;
      }

    Stroke stroke = getItemStroke(row, column);
    g2.setStroke(stroke);
    g2.setPaint(getItemPaint(row, column));

    // Every coordinate below is either a position along the domain axis (the
    // item's centre line) or along the range axis (a data value); segment()
    // maps such pairs onto Java2D x/y according to the plot orientation.
    double centre = calculateItemCentre(state, dataArea, plot, domainAxis,
                                        row, column);
    RectangleEdge edge = plot.getRangeAxisEdge();
    double max = rangeAxis.valueToJava2D(maxVal.doubleValue(), dataArea, edge);
    double min = rangeAxis.valueToJava2D(minVal.doubleValue(), dataArea, edge);
    double q1 = rangeAxis.valueToJava2D(q1Val.doubleValue(), dataArea, edge);
    double q3 = rangeAxis.valueToJava2D(q3Val.doubleValue(), dataArea, edge);
    double median = rangeAxis.valueToJava2D(medianVal.doubleValue(), dataArea,
                                            edge);

    // All styles draw the outer whiskers the same way.
    g2.draw(segment(orientation, centre, max, centre, q3));
    g2.draw(segment(orientation, centre, q1, centre, min));

    if (drawsMean)
      {
        Number meanVal = data.getMeanValue(row, column);
        if (meanVal != null)
          {
            double length = 1.5;
            if (style == Style.DOT_AND_WHISKER)
              {
                Shape shape = getItemShape(row, column);
                if (shape != null)
                  {
                    length = domainExtent(shape.getBounds2D(), orientation);
                  }
              }
            else if (stroke instanceof BasicStroke)
              {
                length = 2.0 * ((BasicStroke) stroke).getLineWidth();
              }
            double mean = rangeAxis.valueToJava2D(meanVal.doubleValue(),
                                                  dataArea, edge);
            // A short tick that ends on the item's centre line.
            g2.draw(segment(orientation, centre - length, mean, centre, mean));
          }
      }

    // One stroke width: half the notch gap around the median, and the
    // distance of the offset line from the whisker line.
    double offset = 1.0;
    if (stroke instanceof BasicStroke)
      {
        offset = ((BasicStroke) stroke).getLineWidth();
      }

    switch (style)
    {
      case DOT_AND_WHISKER:
        drawMedianDot(g2, orientation, centre, median, row, column);
        break;

      case LINE_WEIGHT:
        // Draw the interquartile range on the whisker line as well; together
        // with the adjacent offset line drawn below, it appears thickened.
        drawNotchedRange(g2, orientation, centre, q1, q3, median, offset);
        // Fall-through.

      case LINE_OFFSET:
        drawNotchedRange(g2, orientation, centre + offset, q1, q3, median,
                         offset);
        break;
    }
  }

  /**
   * Returns the Java2D coordinate, along the domain axis, of the centre of the
   * slot reserved for series {@code row} within category {@code column}. The
   * category is divided into {@code getRowCount()} slots of
   * {@code state.getBarWidth()} each, with equal gaps between and around
   * them. For a single series this is the middle of the category.
   */
  private double calculateItemCentre(CategoryItemRendererState state,
                                     Rectangle2D dataArea, CategoryPlot plot,
                                     CategoryAxis domainAxis, int row,
                                     int column)
  {
    RectangleEdge domainEdge = plot.getDomainAxisEdge();
    double start = domainAxis.getCategoryStart(column, getColumnCount(),
                                               dataArea, domainEdge);
    double end = domainAxis.getCategoryEnd(column, getColumnCount(),
                                           dataArea, domainEdge);
    int rows = getRowCount();
    double barWidth = state.getBarWidth();
    double spacing = ((end - start) - rows * barWidth) / (rows + 1);
    return start + (row + 1) * spacing + row * barWidth + barWidth / 2;
  }

  /**
   * Fills the item's shape, centred on the median, with the item's fill paint.
   */
  private void drawMedianDot(Graphics2D g2, PlotOrientation orientation,
                             double centre, double median, int row, int column)
  {
    Shape shape = getItemShape(row, column);
    if (shape == null)
      {
        return;
      }
    Rectangle2D bounds = shape.getBounds2D();
    boolean horizontal = (orientation == PlotOrientation.HORIZONTAL);
    double cx = horizontal ? median : centre;
    double cy = horizontal ? centre : median;
    double tx = cx - bounds.getWidth() / 2 - bounds.getX();
    double ty = cy - bounds.getHeight() / 2 - bounds.getY();
    g2.setPaint(getItemFillPaint(row, column));
    // Translate the shape rather than the Graphics2D so the graphics
    // transform is never left modified.
    g2.fill(AffineTransform.getTranslateInstance(tx, ty)
              .createTransformedShape(shape));
  }

  /**
   * Draws the interquartile range at the given domain position as two
   * segments, q1 to the median and the median to q3. Each segment stops
   * {@code gap} short of the median, so the median shows as a notch.
   */
  private static void drawNotchedRange(Graphics2D g2,
                                       PlotOrientation orientation,
                                       double domain, double q1, double q3,
                                       double median, double gap)
  {
    g2.draw(segment(orientation, domain, q1, domain,
                    notchEnd(q1, median, gap)));
    g2.draw(segment(orientation, domain, notchEnd(q3, median, gap), domain,
                    q3));
  }

  /**
   * Returns the point {@code gap} short of {@code median} on the side facing
   * {@code from}. This works whether Java2D coordinates increase or decrease
   * with the data value (vertical plots and inverted axes flip the direction).
   */
  private static double notchEnd(double from, double median, double gap)
  {
    return median - Math.signum(median - from) * gap;
  }

  /**
   * Builds a line between two (domain, range) points. In a horizontal plot
   * the range axis runs along Java2D x; in a vertical plot it runs along y.
   */
  private static Line2D segment(PlotOrientation orientation,
                                double domain1, double range1,
                                double domain2, double range2)
  {
    if (orientation == PlotOrientation.HORIZONTAL)
      {
        return new Line2D.Double(range1, domain1, range2, domain2);
      }
    return new Line2D.Double(domain1, range1, domain2, range2);
  }

  /**
   * Returns the size of {@code bounds} along the domain axis: its height in a
   * horizontal plot, its width otherwise.
   */
  private static double domainExtent(Rectangle2D bounds,
                                     PlotOrientation orientation)
  {
    return (orientation == PlotOrientation.HORIZONTAL)
           ? bounds.getHeight() : bounds.getWidth();
  }

  /**
   * Tests this renderer for equality with an arbitrary object. Besides the
   * state compared by the superclass, the style and mean flag must match.
   *
   * @param obj The object to compare against ({@code null} permitted).
   * @return True if the objects are equal.
   */
  @Override
  public boolean equals(Object obj)
  {
    if (obj == this)
      {
        return true;
      }
    if (!(obj instanceof QuartilePlotRenderer))
      {
        return false;
      }
    QuartilePlotRenderer that = (QuartilePlotRenderer) obj;
    if (this.style != that.style || this.drawsMean != that.drawsMean)
      {
        return false;
      }
    return super.equals(obj);
  }

  /**
   * Returns a hash code consistent with {@link #equals(Object)}.
   *
   * @return The hash code.
   */
  @Override
  public int hashCode()
  {
    int result = super.hashCode();
    result = 37 * result + style.hashCode();
    result = 37 * result + (drawsMean ? 1 : 0);
    return result;
  }
}
